diff options
Diffstat (limited to 'include/armnn/TypesUtils.hpp')
-rw-r--r-- | include/armnn/TypesUtils.hpp | 133 |
1 files changed, 83 insertions, 50 deletions
diff --git a/include/armnn/TypesUtils.hpp b/include/armnn/TypesUtils.hpp index c63b653ae3..3077ce111f 100644 --- a/include/armnn/TypesUtils.hpp +++ b/include/armnn/TypesUtils.hpp @@ -10,6 +10,7 @@ #include <ostream> #include <boost/assert.hpp> #include <boost/numeric/conversion/cast.hpp> +#include <set> namespace armnn { @@ -89,8 +90,9 @@ constexpr unsigned int GetDataTypeSize(DataType dataType) { switch (dataType) { - case DataType::Signed32: - case DataType::Float32: return 4U; + case DataType::Float16: return 2U; + case DataType::Float32: + case DataType::Signed32: return 4U; case DataType::QuantisedAsymm8: return 1U; default: return 0U; } @@ -107,17 +109,17 @@ constexpr bool StrEqual(const char* strA, const char (&strB)[N]) return isEqual; } -constexpr Compute ParseComputeDevice(const char* str) +constexpr armnn::Compute ParseComputeDevice(const char* str) { - if (StrEqual(str, "CpuAcc")) + if (armnn::StrEqual(str, "CpuAcc")) { return armnn::Compute::CpuAcc; } - else if (StrEqual(str, "CpuRef")) + else if (armnn::StrEqual(str, "CpuRef")) { return armnn::Compute::CpuRef; } - else if (StrEqual(str, "GpuAcc")) + else if (armnn::StrEqual(str, "GpuAcc")) { return armnn::Compute::GpuAcc; } @@ -131,59 +133,60 @@ constexpr const char* GetDataTypeName(DataType dataType) { switch (dataType) { - case DataType::Float32: return "Float32"; + case DataType::Float16: return "Float16"; + case DataType::Float32: return "Float32"; case DataType::QuantisedAsymm8: return "Unsigned8"; - case DataType::Signed32: return "Signed32"; - default: return "Unknown"; + case DataType::Signed32: return "Signed32"; + + default: + return "Unknown"; } } -template <typename T> -constexpr DataType GetDataType(); - -template <> -constexpr DataType GetDataType<float>() -{ - return DataType::Float32; -} -template <> -constexpr DataType GetDataType<uint8_t>() -{ - return DataType::QuantisedAsymm8; -} +template<typename T> +struct IsHalfType + : std::integral_constant<bool, std::is_floating_point<T>::value && sizeof(T) == 2> +{}; -template <> -constexpr DataType GetDataType<int32_t>() -{ - return DataType::Signed32; -} +template<typename T, typename U=T> +struct GetDataTypeImpl; template<typename T> -constexpr bool IsQuantizedType() +struct GetDataTypeImpl<T, typename std::enable_if_t<IsHalfType<T>::value, T>> { - return std::is_integral<T>::value; -} - + static constexpr DataType Value = DataType::Float16; +}; -template<DataType DT> -struct ResolveTypeImpl; +template<> +struct GetDataTypeImpl<float> +{ + static constexpr DataType Value = DataType::Float32; +}; template<> -struct ResolveTypeImpl<DataType::QuantisedAsymm8> +struct GetDataTypeImpl<uint8_t> { - using Type = uint8_t; + static constexpr DataType Value = DataType::QuantisedAsymm8; }; template<> -struct ResolveTypeImpl<DataType::Float32> +struct GetDataTypeImpl<int32_t> { - using Type = float; + static constexpr DataType Value = DataType::Signed32; }; -template<DataType DT> -using ResolveType = typename ResolveTypeImpl<DT>::Type; +template <typename T> +constexpr DataType GetDataType() +{ + return GetDataTypeImpl<T>::Value; +} +template<typename T> +constexpr bool IsQuantizedType() +{ + return std::is_integral<T>::value; +} inline std::ostream& operator<<(std::ostream& os, Status stat) { @@ -191,7 +194,23 @@ inline std::ostream& operator<<(std::ostream& os, Status stat) return os; } -inline std::ostream& operator<<(std::ostream& os, Compute compute) +inline std::ostream& operator<<(std::ostream& os, const std::vector<Compute>& compute) +{ + for (const Compute& comp : compute) { + os << GetComputeDeviceAsCString(comp) << " "; + } + return os; +} + +inline std::ostream& operator<<(std::ostream& os, const std::set<Compute>& compute) +{ + for (const Compute& comp : compute) { + os << GetComputeDeviceAsCString(comp) << " "; + } + return os; +} + +inline std::ostream& operator<<(std::ostream& os, const Compute& compute) { os << GetComputeDeviceAsCString(compute); return os; @@ -212,11 +231,11 @@ inline std::ostream & operator<<(std::ostream & os, const armnn::TensorShape & s return os; } -/// Quantize a floating point data type into an 8-bit data type -/// @param value The value to quantize -/// @param scale The scale (must be non-zero) -/// @param offset The offset -/// @return The quantized value calculated as round(value/scale)+offset +/// Quantize a floating point data type into an 8-bit data type. +/// @param value - The value to quantize. +/// @param scale - The scale (must be non-zero). +/// @param offset - The offset. +/// @return - The quantized value calculated as round(value/scale)+offset. /// template<typename QuantizedType> inline QuantizedType Quantize(float value, float scale, int32_t offset) @@ -234,11 +253,11 @@ inline QuantizedType Quantize(float value, float scale, int32_t offset) return quantizedBits; } -/// Dequantize an 8-bit data type into a floating point data type -/// @param value The value to dequantize -/// @param scale The scale (must be non-zero) -/// @param offset The offset -/// @return The dequantized value calculated as (value-offset)*scale +/// Dequantize an 8-bit data type into a floating point data type. +/// @param value - The value to dequantize. +/// @param scale - The scale (must be non-zero). +/// @param offset - The offset. +/// @return - The dequantized value calculated as (value-offset)*scale. /// template <typename QuantizedType> inline float Dequantize(QuantizedType value, float scale, int32_t offset) @@ -249,4 +268,18 @@ inline float Dequantize(QuantizedType value, float scale, int32_t offset) return dequantized; } +template <typename DataType> +void VerifyTensorInfoDataType(const armnn::TensorInfo & info) +{ + auto expectedType = armnn::GetDataType<DataType>(); + if (info.GetDataType() != expectedType) + { + std::stringstream ss; + ss << "Unexpected datatype:" << armnn::GetDataTypeName(info.GetDataType()) + << " for tensor:" << info.GetShape() + << ". The type expected to be: " << armnn::GetDataTypeName(expectedType); + throw armnn::Exception(ss.str()); + } +} + } //namespace armnn |