diff options
Diffstat (limited to 'arm_compute/core/utils/DataTypeUtils.h')
-rw-r--r-- | arm_compute/core/utils/DataTypeUtils.h | 549 |
1 files changed, 549 insertions, 0 deletions
diff --git a/arm_compute/core/utils/DataTypeUtils.h b/arm_compute/core/utils/DataTypeUtils.h new file mode 100644 index 0000000000..6fabb19b64 --- /dev/null +++ b/arm_compute/core/utils/DataTypeUtils.h @@ -0,0 +1,549 @@ +/* + * Copyright (c) 2016-2024 Arm Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ +#ifndef ACL_ARM_COMPUTE_CORE_UTILS_DATATYPEUTILS_H +#define ACL_ARM_COMPUTE_CORE_UTILS_DATATYPEUTILS_H + +#include "arm_compute/core/PixelValue.h" +#include "arm_compute/core/Types.h" + +namespace arm_compute +{ +/** The size in bytes of the data type + * + * @param[in] data_type Input data type + * + * @return The size in bytes of the data type + */ +inline size_t data_size_from_type(DataType data_type) +{ + switch (data_type) + { + case DataType::U8: + case DataType::S8: + case DataType::QSYMM8: + case DataType::QASYMM8: + case DataType::QASYMM8_SIGNED: + case DataType::QSYMM8_PER_CHANNEL: + return 1; + case DataType::U16: + case DataType::S16: + case DataType::QSYMM16: + case DataType::QASYMM16: + case DataType::BFLOAT16: + case DataType::F16: + return 2; + case DataType::F32: + case DataType::U32: + case DataType::S32: + return 4; + case DataType::F64: + case DataType::U64: + case DataType::S64: + return 8; + case DataType::SIZET: + return sizeof(size_t); + default: + ARM_COMPUTE_ERROR("Invalid data type"); + return 0; + } +} + +/** The size in bytes of the data type + * + * @param[in] dt Input data type + * + * @return The size in bytes of the data type + */ +inline size_t element_size_from_data_type(DataType dt) +{ + switch (dt) + { + case DataType::S8: + case DataType::U8: + case DataType::QSYMM8: + case DataType::QASYMM8: + case DataType::QASYMM8_SIGNED: + case DataType::QSYMM8_PER_CHANNEL: + return 1; + case DataType::U16: + case DataType::S16: + case DataType::QSYMM16: + case DataType::QASYMM16: + case DataType::BFLOAT16: + case DataType::F16: + return 2; + case DataType::U32: + case DataType::S32: + case DataType::F32: + return 4; + case DataType::U64: + case DataType::S64: + return 8; + default: + ARM_COMPUTE_ERROR("Undefined element size for given data type"); + return 0; + } +} + +/** Return the data type used by a given single-planar pixel format + * + * @param[in] format Input format + * + * @return The size in bytes of the pixel format + */ +inline DataType data_type_from_format(Format format) +{ + switch (format) + { + case Format::U8: + case Format::UV88: + case Format::RGB888: + case Format::RGBA8888: + case Format::YUYV422: + case Format::UYVY422: + return DataType::U8; + case Format::U16: + return DataType::U16; + case Format::S16: + return DataType::S16; + case Format::U32: + return DataType::U32; + case Format::S32: + return DataType::S32; + case Format::BFLOAT16: + return DataType::BFLOAT16; + case Format::F16: + return DataType::F16; + case Format::F32: + return DataType::F32; + //Doesn't make sense for planar formats: + case Format::NV12: + case Format::NV21: + case Format::IYUV: + case Format::YUV444: + default: + ARM_COMPUTE_ERROR("Not supported data_type for given format"); + return DataType::UNKNOWN; + } +} + +/** Return the promoted data type of a given data type. + * + * @note If promoted data type is not supported an error will be thrown + * + * @param[in] dt Data type to get the promoted type of. + * + * @return Promoted data type + */ +inline DataType get_promoted_data_type(DataType dt) +{ + switch (dt) + { + case DataType::U8: + return DataType::U16; + case DataType::S8: + return DataType::S16; + case DataType::U16: + return DataType::U32; + case DataType::S16: + return DataType::S32; + case DataType::QSYMM8: + case DataType::QASYMM8: + case DataType::QASYMM8_SIGNED: + case DataType::QSYMM8_PER_CHANNEL: + case DataType::QSYMM16: + case DataType::QASYMM16: + case DataType::BFLOAT16: + case DataType::F16: + case DataType::U32: + case DataType::S32: + case DataType::F32: + ARM_COMPUTE_ERROR("Unsupported data type promotions!"); + default: + ARM_COMPUTE_ERROR("Undefined data type!"); + } + return DataType::UNKNOWN; +} + +/** Compute the mininum and maximum values a data type can take + * + * @param[in] dt Data type to get the min/max bounds of + * + * @return A tuple (min,max) with the minimum and maximum values respectively wrapped in PixelValue. + */ +inline std::tuple<PixelValue, PixelValue> get_min_max(DataType dt) +{ + PixelValue min{}; + PixelValue max{}; + switch (dt) + { + case DataType::U8: + case DataType::QASYMM8: + { + min = PixelValue(static_cast<int32_t>(std::numeric_limits<uint8_t>::lowest())); + max = PixelValue(static_cast<int32_t>(std::numeric_limits<uint8_t>::max())); + break; + } + case DataType::S8: + case DataType::QSYMM8: + case DataType::QASYMM8_SIGNED: + case DataType::QSYMM8_PER_CHANNEL: + { + min = PixelValue(static_cast<int32_t>(std::numeric_limits<int8_t>::lowest())); + max = PixelValue(static_cast<int32_t>(std::numeric_limits<int8_t>::max())); + break; + } + case DataType::U16: + case DataType::QASYMM16: + { + min = PixelValue(static_cast<int32_t>(std::numeric_limits<uint16_t>::lowest())); + max = PixelValue(static_cast<int32_t>(std::numeric_limits<uint16_t>::max())); + break; + } + case DataType::S16: + case DataType::QSYMM16: + { + min = PixelValue(static_cast<int32_t>(std::numeric_limits<int16_t>::lowest())); + max = PixelValue(static_cast<int32_t>(std::numeric_limits<int16_t>::max())); + break; + } + case DataType::U32: + { + min = PixelValue(std::numeric_limits<uint32_t>::lowest()); + max = PixelValue(std::numeric_limits<uint32_t>::max()); + break; + } + case DataType::S32: + { + min = PixelValue(std::numeric_limits<int32_t>::lowest()); + max = PixelValue(std::numeric_limits<int32_t>::max()); + break; + } + case DataType::BFLOAT16: + { + min = PixelValue(bfloat16::lowest()); + max = PixelValue(bfloat16::max()); + break; + } + case DataType::F16: + { + min = PixelValue(std::numeric_limits<half>::lowest()); + max = PixelValue(std::numeric_limits<half>::max()); + break; + } + case DataType::F32: + { + min = PixelValue(std::numeric_limits<float>::lowest()); + max = PixelValue(std::numeric_limits<float>::max()); + break; + } + default: + ARM_COMPUTE_ERROR("Undefined data type!"); + } + return std::make_tuple(min, max); +} + +/** Convert a data type identity into a string. + * + * @param[in] dt @ref DataType to be translated to string. + * + * @return The string describing the data type. + */ +const std::string &string_from_data_type(DataType dt); + +/** Convert a string to DataType + * + * @param[in] name The name of the data type + * + * @return DataType + */ +DataType data_type_from_name(const std::string &name); + +/** Input Stream operator for @ref DataType + * + * @param[in] stream Stream to parse + * @param[out] data_type Output data type + * + * @return Updated stream + */ +inline ::std::istream &operator>>(::std::istream &stream, DataType &data_type) +{ + std::string value; + stream >> value; + data_type = data_type_from_name(value); + return stream; +} + +/** Check if a given data type is of floating point type + * + * @param[in] dt Input data type. + * + * @return True if data type is of floating point type, else false. + */ +inline bool is_data_type_float(DataType dt) +{ + switch (dt) + { + case DataType::F16: + case DataType::F32: + return true; + default: + return false; + } +} + +/** Check if a given data type is of quantized type + * + * @note Quantized is considered a super-set of fixed-point and asymmetric data types. + * + * @param[in] dt Input data type. + * + * @return True if data type is of quantized type, else false. + */ +inline bool is_data_type_quantized(DataType dt) +{ + switch (dt) + { + case DataType::QSYMM8: + case DataType::QASYMM8: + case DataType::QASYMM8_SIGNED: + case DataType::QSYMM8_PER_CHANNEL: + case DataType::QSYMM16: + case DataType::QASYMM16: + return true; + default: + return false; + } +} + +/** Check if a given data type is of asymmetric quantized type + * + * @param[in] dt Input data type. + * + * @return True if data type is of asymmetric quantized type, else false. + */ +inline bool is_data_type_quantized_asymmetric(DataType dt) +{ + switch (dt) + { + case DataType::QASYMM8: + case DataType::QASYMM8_SIGNED: + case DataType::QASYMM16: + return true; + default: + return false; + } +} + +/** Check if a given data type is of asymmetric quantized signed type + * + * @param[in] dt Input data type. + * + * @return True if data type is of asymmetric quantized signed type, else false. + */ +inline bool is_data_type_quantized_asymmetric_signed(DataType dt) +{ + switch (dt) + { + case DataType::QASYMM8_SIGNED: + return true; + default: + return false; + } +} + +/** Check if a given data type is of 8-bit asymmetric quantized signed type + * + * @param[in] dt Input data type. + * + * @return True if data type is of 8-bit asymmetric quantized signed type, else false. + */ +inline bool is_data_type_quantized_asymmetric_char(DataType dt) +{ + switch (dt) + { + case DataType::QASYMM8_SIGNED: + case DataType::QASYMM8: + return true; + default: + return false; + } +} + +/** Check if a given data type is of symmetric quantized type + * + * @param[in] dt Input data type. + * + * @return True if data type is of symmetric quantized type, else false. + */ +inline bool is_data_type_quantized_symmetric(DataType dt) +{ + switch (dt) + { + case DataType::QSYMM8: + case DataType::QSYMM8_PER_CHANNEL: + case DataType::QSYMM16: + return true; + default: + return false; + } +} + +/** Check if a given data type is of per channel type + * + * @param[in] dt Input data type. + * + * @return True if data type is of per channel type, else false. + */ +inline bool is_data_type_quantized_per_channel(DataType dt) +{ + switch (dt) + { + case DataType::QSYMM8_PER_CHANNEL: + return true; + default: + return false; + } +} + +/** Returns true if the value can be represented by the given data type + * + * @param[in] val value to be checked + * @param[in] dt data type that is checked + * @param[in] qinfo (Optional) quantization info if the data type is QASYMM8 + * + * @return true if the data type can hold the value. + */ +template <typename T> +bool check_value_range(T val, DataType dt, QuantizationInfo qinfo = QuantizationInfo()) +{ + switch (dt) + { + case DataType::U8: + { + const auto val_u8 = static_cast<uint8_t>(val); + return ((val_u8 == val) && val >= std::numeric_limits<uint8_t>::lowest() && + val <= std::numeric_limits<uint8_t>::max()); + } + case DataType::QASYMM8: + { + double min = static_cast<double>(dequantize_qasymm8(0, qinfo)); + double max = static_cast<double>(dequantize_qasymm8(std::numeric_limits<uint8_t>::max(), qinfo)); + return ((double)val >= min && (double)val <= max); + } + case DataType::S8: + { + const auto val_s8 = static_cast<int8_t>(val); + return ((val_s8 == val) && val >= std::numeric_limits<int8_t>::lowest() && + val <= std::numeric_limits<int8_t>::max()); + } + case DataType::U16: + { + const auto val_u16 = static_cast<uint16_t>(val); + return ((val_u16 == val) && val >= std::numeric_limits<uint16_t>::lowest() && + val <= std::numeric_limits<uint16_t>::max()); + } + case DataType::S16: + { + const auto val_s16 = static_cast<int16_t>(val); + return ((val_s16 == val) && val >= std::numeric_limits<int16_t>::lowest() && + val <= std::numeric_limits<int16_t>::max()); + } + case DataType::U32: + { + const auto val_d64 = static_cast<double>(val); + const auto val_u32 = static_cast<uint32_t>(val); + return ((val_u32 == val_d64) && val_d64 >= std::numeric_limits<uint32_t>::lowest() && + val_d64 <= std::numeric_limits<uint32_t>::max()); + } + case DataType::S32: + { + const auto val_d64 = static_cast<double>(val); + const auto val_s32 = static_cast<int32_t>(val); + return ((val_s32 == val_d64) && val_d64 >= std::numeric_limits<int32_t>::lowest() && + val_d64 <= std::numeric_limits<int32_t>::max()); + } + case DataType::BFLOAT16: + return (val >= bfloat16::lowest() && val <= bfloat16::max()); + case DataType::F16: + return (val >= std::numeric_limits<half>::lowest() && val <= std::numeric_limits<half>::max()); + case DataType::F32: + return (val >= std::numeric_limits<float>::lowest() && val <= std::numeric_limits<float>::max()); + default: + ARM_COMPUTE_ERROR("Data type not supported"); + return false; + } +} + +/** Returns the suffix string of CPU kernel implementation names based on the given data type + * + * @param[in] data_type The data type the CPU kernel implemetation uses + * + * @return the suffix string of CPU kernel implementations + */ +inline std::string cpu_impl_dt(const DataType &data_type) +{ + std::string ret = ""; + + switch (data_type) + { + case DataType::F32: + ret = "fp32"; + break; + case DataType::F16: + ret = "fp16"; + break; + case DataType::U8: + ret = "u8"; + break; + case DataType::S16: + ret = "s16"; + break; + case DataType::S32: + ret = "s32"; + break; + case DataType::QASYMM8: + ret = "qu8"; + break; + case DataType::QASYMM8_SIGNED: + ret = "qs8"; + break; + case DataType::QSYMM16: + ret = "qs16"; + break; + case DataType::QSYMM8_PER_CHANNEL: + ret = "qp8"; + break; + case DataType::BFLOAT16: + ret = "bf16"; + break; + default: + ARM_COMPUTE_ERROR("Unsupported."); + } + + return ret; +} + +} // namespace arm_compute +#endif // ACL_ARM_COMPUTE_CORE_UTILS_DATATYPEUTILS_H |