/* * Copyright (c) 2016-2024 Arm Limited. * * SPDX-License-Identifier: MIT * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to * deal in the Software without restriction, including without limitation the * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or * sell copies of the Software, and to permit persons to whom the Software is * furnished to do so, subject to the following conditions: * * The above copyright notice and this permission notice shall be included in all * copies or substantial portions of the Software. * * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE * SOFTWARE. */ #ifndef ACL_ARM_COMPUTE_CORE_UTILS_DATATYPEUTILS_H #define ACL_ARM_COMPUTE_CORE_UTILS_DATATYPEUTILS_H #include "arm_compute/core/PixelValue.h" #include "arm_compute/core/Types.h" namespace arm_compute { /** The size in bytes of the data type * * @param[in] data_type Input data type * * @return The size in bytes of the data type */ inline size_t data_size_from_type(DataType data_type) { switch (data_type) { case DataType::U8: case DataType::S8: case DataType::QSYMM8: case DataType::QASYMM8: case DataType::QASYMM8_SIGNED: case DataType::QSYMM8_PER_CHANNEL: return 1; case DataType::U16: case DataType::S16: case DataType::QSYMM16: case DataType::QASYMM16: case DataType::BFLOAT16: case DataType::F16: return 2; case DataType::F32: case DataType::U32: case DataType::S32: return 4; case DataType::F64: case DataType::U64: case DataType::S64: return 8; case DataType::SIZET: return sizeof(size_t); default: ARM_COMPUTE_ERROR("Invalid data type"); return 0; } } /** The size in bytes of the data type * * @param[in] dt Input data type * * @return The size in bytes of the data type */ inline size_t element_size_from_data_type(DataType dt) { switch (dt) { case DataType::S8: case DataType::U8: case DataType::QSYMM8: case DataType::QASYMM8: case DataType::QASYMM8_SIGNED: case DataType::QSYMM8_PER_CHANNEL: return 1; case DataType::U16: case DataType::S16: case DataType::QSYMM16: case DataType::QASYMM16: case DataType::BFLOAT16: case DataType::F16: return 2; case DataType::U32: case DataType::S32: case DataType::F32: return 4; case DataType::U64: case DataType::S64: return 8; default: ARM_COMPUTE_ERROR("Undefined element size for given data type"); return 0; } } /** Return the data type used by a given single-planar pixel format * * @param[in] format Input format * * @return The size in bytes of the pixel format */ inline DataType data_type_from_format(Format format) { switch (format) { case Format::U8: case Format::UV88: case Format::RGB888: case Format::RGBA8888: case Format::YUYV422: case Format::UYVY422: return DataType::U8; case Format::U16: return DataType::U16; case Format::S16: return DataType::S16; case Format::U32: return DataType::U32; case Format::S32: return DataType::S32; case Format::BFLOAT16: return DataType::BFLOAT16; case Format::F16: return DataType::F16; case Format::F32: return DataType::F32; //Doesn't make sense for planar formats: case Format::NV12: case Format::NV21: case Format::IYUV: case Format::YUV444: default: ARM_COMPUTE_ERROR("Not supported data_type for given format"); return DataType::UNKNOWN; } } /** Return the promoted data type of a given data type. * * @note If promoted data type is not supported an error will be thrown * * @param[in] dt Data type to get the promoted type of. * * @return Promoted data type */ inline DataType get_promoted_data_type(DataType dt) { switch (dt) { case DataType::U8: return DataType::U16; case DataType::S8: return DataType::S16; case DataType::U16: return DataType::U32; case DataType::S16: return DataType::S32; case DataType::QSYMM8: case DataType::QASYMM8: case DataType::QASYMM8_SIGNED: case DataType::QSYMM8_PER_CHANNEL: case DataType::QSYMM16: case DataType::QASYMM16: case DataType::BFLOAT16: case DataType::F16: case DataType::U32: case DataType::S32: case DataType::F32: ARM_COMPUTE_ERROR("Unsupported data type promotions!"); default: ARM_COMPUTE_ERROR("Undefined data type!"); } return DataType::UNKNOWN; } /** Compute the mininum and maximum values a data type can take * * @param[in] dt Data type to get the min/max bounds of * * @return A tuple (min,max) with the minimum and maximum values respectively wrapped in PixelValue. */ inline std::tuple get_min_max(DataType dt) { PixelValue min{}; PixelValue max{}; switch (dt) { case DataType::U8: case DataType::QASYMM8: { min = PixelValue(static_cast(std::numeric_limits::lowest())); max = PixelValue(static_cast(std::numeric_limits::max())); break; } case DataType::S8: case DataType::QSYMM8: case DataType::QASYMM8_SIGNED: case DataType::QSYMM8_PER_CHANNEL: { min = PixelValue(static_cast(std::numeric_limits::lowest())); max = PixelValue(static_cast(std::numeric_limits::max())); break; } case DataType::U16: case DataType::QASYMM16: { min = PixelValue(static_cast(std::numeric_limits::lowest())); max = PixelValue(static_cast(std::numeric_limits::max())); break; } case DataType::S16: case DataType::QSYMM16: { min = PixelValue(static_cast(std::numeric_limits::lowest())); max = PixelValue(static_cast(std::numeric_limits::max())); break; } case DataType::U32: { min = PixelValue(std::numeric_limits::lowest()); max = PixelValue(std::numeric_limits::max()); break; } case DataType::S32: { min = PixelValue(std::numeric_limits::lowest()); max = PixelValue(std::numeric_limits::max()); break; } case DataType::BFLOAT16: { min = PixelValue(bfloat16::lowest()); max = PixelValue(bfloat16::max()); break; } case DataType::F16: { min = PixelValue(std::numeric_limits::lowest()); max = PixelValue(std::numeric_limits::max()); break; } case DataType::F32: { min = PixelValue(std::numeric_limits::lowest()); max = PixelValue(std::numeric_limits::max()); break; } default: ARM_COMPUTE_ERROR("Undefined data type!"); } return std::make_tuple(min, max); } /** Convert a data type identity into a string. * * @param[in] dt @ref DataType to be translated to string. * * @return The string describing the data type. */ const std::string &string_from_data_type(DataType dt); /** Convert a string to DataType * * @param[in] name The name of the data type * * @return DataType */ DataType data_type_from_name(const std::string &name); /** Input Stream operator for @ref DataType * * @param[in] stream Stream to parse * @param[out] data_type Output data type * * @return Updated stream */ inline ::std::istream &operator>>(::std::istream &stream, DataType &data_type) { std::string value; stream >> value; data_type = data_type_from_name(value); return stream; } /** Check if a given data type is of floating point type * * @param[in] dt Input data type. * * @return True if data type is of floating point type, else false. */ inline bool is_data_type_float(DataType dt) { switch (dt) { case DataType::F16: case DataType::F32: return true; default: return false; } } /** Check if a given data type is of quantized type * * @note Quantized is considered a super-set of fixed-point and asymmetric data types. * * @param[in] dt Input data type. * * @return True if data type is of quantized type, else false. */ inline bool is_data_type_quantized(DataType dt) { switch (dt) { case DataType::QSYMM8: case DataType::QASYMM8: case DataType::QASYMM8_SIGNED: case DataType::QSYMM8_PER_CHANNEL: case DataType::QSYMM16: case DataType::QASYMM16: return true; default: return false; } } /** Check if a given data type is of asymmetric quantized type * * @param[in] dt Input data type. * * @return True if data type is of asymmetric quantized type, else false. */ inline bool is_data_type_quantized_asymmetric(DataType dt) { switch (dt) { case DataType::QASYMM8: case DataType::QASYMM8_SIGNED: case DataType::QASYMM16: return true; default: return false; } } /** Check if a given data type is of asymmetric quantized signed type * * @param[in] dt Input data type. * * @return True if data type is of asymmetric quantized signed type, else false. */ inline bool is_data_type_quantized_asymmetric_signed(DataType dt) { switch (dt) { case DataType::QASYMM8_SIGNED: return true; default: return false; } } /** Check if a given data type is of 8-bit asymmetric quantized signed type * * @param[in] dt Input data type. * * @return True if data type is of 8-bit asymmetric quantized signed type, else false. */ inline bool is_data_type_quantized_asymmetric_char(DataType dt) { switch (dt) { case DataType::QASYMM8_SIGNED: case DataType::QASYMM8: return true; default: return false; } } /** Check if a given data type is of symmetric quantized type * * @param[in] dt Input data type. * * @return True if data type is of symmetric quantized type, else false. */ inline bool is_data_type_quantized_symmetric(DataType dt) { switch (dt) { case DataType::QSYMM8: case DataType::QSYMM8_PER_CHANNEL: case DataType::QSYMM16: return true; default: return false; } } /** Check if a given data type is of per channel type * * @param[in] dt Input data type. * * @return True if data type is of per channel type, else false. */ inline bool is_data_type_quantized_per_channel(DataType dt) { switch (dt) { case DataType::QSYMM8_PER_CHANNEL: return true; default: return false; } } /** Returns true if the value can be represented by the given data type * * @param[in] val value to be checked * @param[in] dt data type that is checked * @param[in] qinfo (Optional) quantization info if the data type is QASYMM8 * * @return true if the data type can hold the value. */ template bool check_value_range(T val, DataType dt, QuantizationInfo qinfo = QuantizationInfo()) { switch (dt) { case DataType::U8: { const auto val_u8 = static_cast(val); return ((val_u8 == val) && val >= std::numeric_limits::lowest() && val <= std::numeric_limits::max()); } case DataType::QASYMM8: { double min = static_cast(dequantize_qasymm8(0, qinfo)); double max = static_cast(dequantize_qasymm8(std::numeric_limits::max(), qinfo)); return ((double)val >= min && (double)val <= max); } case DataType::S8: { const auto val_s8 = static_cast(val); return ((val_s8 == val) && val >= std::numeric_limits::lowest() && val <= std::numeric_limits::max()); } case DataType::U16: { const auto val_u16 = static_cast(val); return ((val_u16 == val) && val >= std::numeric_limits::lowest() && val <= std::numeric_limits::max()); } case DataType::S16: { const auto val_s16 = static_cast(val); return ((val_s16 == val) && val >= std::numeric_limits::lowest() && val <= std::numeric_limits::max()); } case DataType::U32: { const auto val_d64 = static_cast(val); const auto val_u32 = static_cast(val); return ((val_u32 == val_d64) && val_d64 >= std::numeric_limits::lowest() && val_d64 <= std::numeric_limits::max()); } case DataType::S32: { const auto val_d64 = static_cast(val); const auto val_s32 = static_cast(val); return ((val_s32 == val_d64) && val_d64 >= std::numeric_limits::lowest() && val_d64 <= std::numeric_limits::max()); } case DataType::BFLOAT16: return (val >= bfloat16::lowest() && val <= bfloat16::max()); case DataType::F16: return (val >= std::numeric_limits::lowest() && val <= std::numeric_limits::max()); case DataType::F32: return (val >= std::numeric_limits::lowest() && val <= std::numeric_limits::max()); default: ARM_COMPUTE_ERROR("Data type not supported"); return false; } } /** Returns the suffix string of CPU kernel implementation names based on the given data type * * @param[in] data_type The data type the CPU kernel implemetation uses * * @return the suffix string of CPU kernel implementations */ inline std::string cpu_impl_dt(const DataType &data_type) { std::string ret = ""; switch (data_type) { case DataType::F32: ret = "fp32"; break; case DataType::F16: ret = "fp16"; break; case DataType::U8: ret = "u8"; break; case DataType::S16: ret = "s16"; break; case DataType::S32: ret = "s32"; break; case DataType::QASYMM8: ret = "qu8"; break; case DataType::QASYMM8_SIGNED: ret = "qs8"; break; case DataType::QSYMM16: ret = "qs16"; break; case DataType::QSYMM8_PER_CHANNEL: ret = "qp8"; break; case DataType::BFLOAT16: ret = "bf16"; break; default: ARM_COMPUTE_ERROR("Unsupported."); } return ret; } } // namespace arm_compute #endif // ACL_ARM_COMPUTE_CORE_UTILS_DATATYPEUTILS_H