aboutsummaryrefslogtreecommitdiff
path: root/include/armnn/TypesUtils.hpp
diff options
context:
space:
mode:
authortelsoa01 <telmo.soares@arm.com>2018-08-31 09:22:23 +0100
committertelsoa01 <telmo.soares@arm.com>2018-08-31 09:22:23 +0100
commitc577f2c6a3b4ddb6ba87a882723c53a248afbeba (patch)
treebd7d4c148df27f8be6649d313efb24f536b7cf34 /include/armnn/TypesUtils.hpp
parent4c7098bfeab1ffe1cdc77f6c15548d3e73274746 (diff)
downloadarmnn-c577f2c6a3b4ddb6ba87a882723c53a248afbeba.tar.gz
Release 18.08
Diffstat (limited to 'include/armnn/TypesUtils.hpp')
-rw-r--r--include/armnn/TypesUtils.hpp133
1 files changed, 83 insertions, 50 deletions
diff --git a/include/armnn/TypesUtils.hpp b/include/armnn/TypesUtils.hpp
index c63b653ae3..3077ce111f 100644
--- a/include/armnn/TypesUtils.hpp
+++ b/include/armnn/TypesUtils.hpp
@@ -10,6 +10,7 @@
#include <ostream>
#include <boost/assert.hpp>
#include <boost/numeric/conversion/cast.hpp>
+#include <set>
namespace armnn
{
@@ -89,8 +90,9 @@ constexpr unsigned int GetDataTypeSize(DataType dataType)
{
switch (dataType)
{
- case DataType::Signed32:
- case DataType::Float32: return 4U;
+ case DataType::Float16: return 2U;
+ case DataType::Float32:
+ case DataType::Signed32: return 4U;
case DataType::QuantisedAsymm8: return 1U;
default: return 0U;
}
@@ -107,17 +109,17 @@ constexpr bool StrEqual(const char* strA, const char (&strB)[N])
return isEqual;
}
-constexpr Compute ParseComputeDevice(const char* str)
+constexpr armnn::Compute ParseComputeDevice(const char* str)
{
- if (StrEqual(str, "CpuAcc"))
+ if (armnn::StrEqual(str, "CpuAcc"))
{
return armnn::Compute::CpuAcc;
}
- else if (StrEqual(str, "CpuRef"))
+ else if (armnn::StrEqual(str, "CpuRef"))
{
return armnn::Compute::CpuRef;
}
- else if (StrEqual(str, "GpuAcc"))
+ else if (armnn::StrEqual(str, "GpuAcc"))
{
return armnn::Compute::GpuAcc;
}
@@ -131,59 +133,60 @@ constexpr const char* GetDataTypeName(DataType dataType)
{
switch (dataType)
{
- case DataType::Float32: return "Float32";
+ case DataType::Float16: return "Float16";
+ case DataType::Float32: return "Float32";
case DataType::QuantisedAsymm8: return "Unsigned8";
- case DataType::Signed32: return "Signed32";
- default: return "Unknown";
+ case DataType::Signed32: return "Signed32";
+
+ default:
+ return "Unknown";
}
}
-template <typename T>
-constexpr DataType GetDataType();
-
-template <>
-constexpr DataType GetDataType<float>()
-{
- return DataType::Float32;
-}
-template <>
-constexpr DataType GetDataType<uint8_t>()
-{
- return DataType::QuantisedAsymm8;
-}
+template<typename T>
+struct IsHalfType
+ : std::integral_constant<bool, std::is_floating_point<T>::value && sizeof(T) == 2>
+{};
-template <>
-constexpr DataType GetDataType<int32_t>()
-{
- return DataType::Signed32;
-}
+template<typename T, typename U=T>
+struct GetDataTypeImpl;
template<typename T>
-constexpr bool IsQuantizedType()
+struct GetDataTypeImpl<T, typename std::enable_if_t<IsHalfType<T>::value, T>>
{
- return std::is_integral<T>::value;
-}
-
+ static constexpr DataType Value = DataType::Float16;
+};
-template<DataType DT>
-struct ResolveTypeImpl;
+template<>
+struct GetDataTypeImpl<float>
+{
+ static constexpr DataType Value = DataType::Float32;
+};
template<>
-struct ResolveTypeImpl<DataType::QuantisedAsymm8>
+struct GetDataTypeImpl<uint8_t>
{
- using Type = uint8_t;
+ static constexpr DataType Value = DataType::QuantisedAsymm8;
};
template<>
-struct ResolveTypeImpl<DataType::Float32>
+struct GetDataTypeImpl<int32_t>
{
- using Type = float;
+ static constexpr DataType Value = DataType::Signed32;
};
-template<DataType DT>
-using ResolveType = typename ResolveTypeImpl<DT>::Type;
+template <typename T>
+constexpr DataType GetDataType()
+{
+ return GetDataTypeImpl<T>::Value;
+}
+template<typename T>
+constexpr bool IsQuantizedType()
+{
+ return std::is_integral<T>::value;
+}
inline std::ostream& operator<<(std::ostream& os, Status stat)
{
@@ -191,7 +194,23 @@ inline std::ostream& operator<<(std::ostream& os, Status stat)
return os;
}
-inline std::ostream& operator<<(std::ostream& os, Compute compute)
+inline std::ostream& operator<<(std::ostream& os, const std::vector<Compute>& compute)
+{
+ for (const Compute& comp : compute) {
+ os << GetComputeDeviceAsCString(comp) << " ";
+ }
+ return os;
+}
+
+inline std::ostream& operator<<(std::ostream& os, const std::set<Compute>& compute)
+{
+ for (const Compute& comp : compute) {
+ os << GetComputeDeviceAsCString(comp) << " ";
+ }
+ return os;
+}
+
+inline std::ostream& operator<<(std::ostream& os, const Compute& compute)
{
os << GetComputeDeviceAsCString(compute);
return os;
@@ -212,11 +231,11 @@ inline std::ostream & operator<<(std::ostream & os, const armnn::TensorShape & s
return os;
}
-/// Quantize a floating point data type into an 8-bit data type
-/// @param value The value to quantize
-/// @param scale The scale (must be non-zero)
-/// @param offset The offset
-/// @return The quantized value calculated as round(value/scale)+offset
+/// Quantize a floating point data type into an 8-bit data type.
+/// @param value - The value to quantize.
+/// @param scale - The scale (must be non-zero).
+/// @param offset - The offset.
+/// @return - The quantized value calculated as round(value/scale)+offset.
///
template<typename QuantizedType>
inline QuantizedType Quantize(float value, float scale, int32_t offset)
@@ -234,11 +253,11 @@ inline QuantizedType Quantize(float value, float scale, int32_t offset)
return quantizedBits;
}
-/// Dequantize an 8-bit data type into a floating point data type
-/// @param value The value to dequantize
-/// @param scale The scale (must be non-zero)
-/// @param offset The offset
-/// @return The dequantized value calculated as (value-offset)*scale
+/// Dequantize an 8-bit data type into a floating point data type.
+/// @param value - The value to dequantize.
+/// @param scale - The scale (must be non-zero).
+/// @param offset - The offset.
+/// @return - The dequantized value calculated as (value-offset)*scale.
///
template <typename QuantizedType>
inline float Dequantize(QuantizedType value, float scale, int32_t offset)
@@ -249,4 +268,18 @@ inline float Dequantize(QuantizedType value, float scale, int32_t offset)
return dequantized;
}
+template <typename DataType>
+void VerifyTensorInfoDataType(const armnn::TensorInfo & info)
+{
+ auto expectedType = armnn::GetDataType<DataType>();
+ if (info.GetDataType() != expectedType)
+ {
+ std::stringstream ss;
+ ss << "Unexpected datatype:" << armnn::GetDataTypeName(info.GetDataType())
+ << " for tensor:" << info.GetShape()
+ << ". The type expected to be: " << armnn::GetDataTypeName(expectedType);
+ throw armnn::Exception(ss.str());
+ }
+}
+
} //namespace armnn