ArmNN  NotReleased
TypesUtils.hpp
Go to the documentation of this file.
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 #pragma once
6 
7 #include <armnn/Tensor.hpp>
8 #include <armnn/Types.hpp>
9 
10 #include <cmath>
11 #include <ostream>
12 #include <set>
13 
14 namespace armnn
15 {
16 
17 constexpr char const* GetStatusAsCString(Status status)
18 {
19  switch (status)
20  {
21  case armnn::Status::Success: return "Status::Success";
22  case armnn::Status::Failure: return "Status::Failure";
23  default: return "Unknown";
24  }
25 }
26 
27 constexpr char const* GetActivationFunctionAsCString(ActivationFunction activation)
28 {
29  switch (activation)
30  {
31  case ActivationFunction::Sigmoid: return "Sigmoid";
32  case ActivationFunction::TanH: return "TanH";
33  case ActivationFunction::Linear: return "Linear";
34  case ActivationFunction::ReLu: return "ReLu";
35  case ActivationFunction::BoundedReLu: return "BoundedReLu";
36  case ActivationFunction::SoftReLu: return "SoftReLu";
37  case ActivationFunction::LeakyReLu: return "LeakyReLu";
38  case ActivationFunction::Abs: return "Abs";
39  case ActivationFunction::Sqrt: return "Sqrt";
40  case ActivationFunction::Square: return "Square";
41  default: return "Unknown";
42  }
43 }
44 
45 constexpr char const* GetArgMinMaxFunctionAsCString(ArgMinMaxFunction function)
46 {
47  switch (function)
48  {
49  case ArgMinMaxFunction::Max: return "Max";
50  case ArgMinMaxFunction::Min: return "Min";
51  default: return "Unknown";
52  }
53 }
54 
55 constexpr char const* GetComparisonOperationAsCString(ComparisonOperation operation)
56 {
57  switch (operation)
58  {
59  case ComparisonOperation::Equal: return "Equal";
60  case ComparisonOperation::Greater: return "Greater";
61  case ComparisonOperation::GreaterOrEqual: return "GreaterOrEqual";
62  case ComparisonOperation::Less: return "Less";
63  case ComparisonOperation::LessOrEqual: return "LessOrEqual";
64  case ComparisonOperation::NotEqual: return "NotEqual";
65  default: return "Unknown";
66  }
67 }
68 
69 constexpr char const* GetUnaryOperationAsCString(UnaryOperation operation)
70 {
71  switch (operation)
72  {
73  case UnaryOperation::Abs: return "Abs";
74  case UnaryOperation::Exp: return "Exp";
75  case UnaryOperation::Sqrt: return "Sqrt";
76  case UnaryOperation::Rsqrt: return "Rsqrt";
77  case UnaryOperation::Neg: return "Neg";
78  default: return "Unknown";
79  }
80 }
81 
82 constexpr char const* GetPoolingAlgorithmAsCString(PoolingAlgorithm pooling)
83 {
84  switch (pooling)
85  {
86  case PoolingAlgorithm::Average: return "Average";
87  case PoolingAlgorithm::Max: return "Max";
88  case PoolingAlgorithm::L2: return "L2";
89  default: return "Unknown";
90  }
91 }
92 
94 {
95  switch (rounding)
96  {
97  case OutputShapeRounding::Ceiling: return "Ceiling";
98  case OutputShapeRounding::Floor: return "Floor";
99  default: return "Unknown";
100  }
101 }
102 
103 constexpr char const* GetPaddingMethodAsCString(PaddingMethod method)
104 {
105  switch (method)
106  {
107  case PaddingMethod::Exclude: return "Exclude";
108  case PaddingMethod::IgnoreValue: return "IgnoreValue";
109  default: return "Unknown";
110  }
111 }
112 
113 constexpr unsigned int GetDataTypeSize(DataType dataType)
114 {
115  switch (dataType)
116  {
117  case DataType::Float16: return 2U;
118  case DataType::Float32:
119  case DataType::Signed32: return 4U;
120  case DataType::QAsymmU8: return 1U;
121  case DataType::QAsymmS8: return 1U;
122  case DataType::QSymmS8: return 1U;
124  case DataType::QuantizedSymm8PerAxis: return 1U;
126  case DataType::QSymmS16: return 2U;
127  case DataType::Boolean: return 1U;
128  default: return 0U;
129  }
130 }
131 
/// Compares a C string against a fixed-size character array, element by element.
///
/// All N elements of strB are compared. When strB is a string literal this
/// includes its terminating '\0', so the result is true only for an exact match
/// (strA being a mere prefix of strB, or longer than strB, yields false).
///
/// @tparam N   Number of elements in strB (deduced).
/// @param strA String to test; a null pointer is treated as "not equal".
/// @param strB Reference array to compare against, typically a string literal.
/// @return true if strA[i] == strB[i] for every i in [0, N), false otherwise.
template <unsigned N>
constexpr bool StrEqual(const char* strA, const char (&strB)[N])
{
    // Guard against dereferencing a null pointer.
    if (strA == nullptr)
    {
        return false;
    }

    for (unsigned i = 0; i < N; ++i)
    {
        // Stop at the first mismatch so we never read past strA's terminator
        // when strA is shorter than strB.
        if (strA[i] != strB[i])
        {
            return false;
        }
    }
    return true;
}
142 
145 constexpr armnn::Compute ParseComputeDevice(const char* str)
146 {
147  if (armnn::StrEqual(str, "CpuAcc"))
148  {
149  return armnn::Compute::CpuAcc;
150  }
151  else if (armnn::StrEqual(str, "CpuRef"))
152  {
153  return armnn::Compute::CpuRef;
154  }
155  else if (armnn::StrEqual(str, "GpuAcc"))
156  {
157  return armnn::Compute::GpuAcc;
158  }
159  else
160  {
162  }
163 }
164 
165 constexpr const char* GetDataTypeName(DataType dataType)
166 {
167  switch (dataType)
168  {
169  case DataType::Float16: return "Float16";
170  case DataType::Float32: return "Float32";
171  case DataType::QAsymmU8: return "QAsymmU8";
172  case DataType::QAsymmS8: return "QAsymmS8";
173  case DataType::QSymmS8: return "QSymmS8";
175  case DataType::QuantizedSymm8PerAxis: return "QSymm8PerAxis";
177  case DataType::QSymmS16: return "QSymm16";
178  case DataType::Signed32: return "Signed32";
179  case DataType::Boolean: return "Boolean";
180 
181  default:
182  return "Unknown";
183  }
184 }
185 
186 constexpr const char* GetDataLayoutName(DataLayout dataLayout)
187 {
188  switch (dataLayout)
189  {
190  case DataLayout::NCHW: return "NCHW";
191  case DataLayout::NHWC: return "NHWC";
192  default: return "Unknown";
193  }
194 }
195 
197 {
198  switch (channel)
199  {
200  case NormalizationAlgorithmChannel::Across: return "Across";
201  case NormalizationAlgorithmChannel::Within: return "Within";
202  default: return "Unknown";
203  }
204 }
205 
207 {
208  switch (method)
209  {
210  case NormalizationAlgorithmMethod::LocalBrightness: return "LocalBrightness";
211  case NormalizationAlgorithmMethod::LocalContrast: return "LocalContrast";
212  default: return "Unknown";
213  }
214 }
215 
216 constexpr const char* GetResizeMethodAsCString(ResizeMethod method)
217 {
218  switch (method)
219  {
220  case ResizeMethod::Bilinear: return "Bilinear";
221  case ResizeMethod::NearestNeighbor: return "NearestNeighbour";
222  default: return "Unknown";
223  }
224 }
225 
/// Type trait whose `value` is true when T is a floating-point type
/// (per std::is_floating_point) that occupies exactly two bytes.
///
/// @tparam T The type to inspect.
template<typename T>
struct IsHalfType
    : std::integral_constant<bool, std::is_floating_point<T>::value && sizeof(T) == 2>
{};
230 
/// Compile-time check used to classify a C++ type as "quantized".
///
/// The test is simply std::is_integral, so every integral type (including
/// bool and the character types) is reported as quantized.
///
/// @tparam T The type to inspect.
/// @return true if T is an integral type, false otherwise.
template<typename T>
constexpr bool IsQuantizedType()
{
    using IntegralCheck = std::is_integral<T>;
    return IntegralCheck::value;
}
236 
237 constexpr bool IsQuantized8BitType(DataType dataType)
238 {
240  return dataType == DataType::QAsymmU8 ||
241  dataType == DataType::QAsymmS8 ||
242  dataType == DataType::QSymmS8 ||
245 }
246 
247 constexpr bool IsQuantizedType(DataType dataType)
248 {
249  return dataType == DataType::QSymmS16 || IsQuantized8BitType(dataType);
250 }
251 
252 inline std::ostream& operator<<(std::ostream& os, Status stat)
253 {
254  os << GetStatusAsCString(stat);
255  return os;
256 }
257 
258 
259 inline std::ostream & operator<<(std::ostream & os, const armnn::TensorShape & shape)
260 {
261  os << "[";
262  for (uint32_t i=0; i<shape.GetNumDimensions(); ++i)
263  {
264  if (i!=0)
265  {
266  os << ",";
267  }
268  os << shape[i];
269  }
270  os << "]";
271  return os;
272 }
273 
/// Converts a floating-point value to a quantized representation.
///
/// Declaration only: the definition is not part of this header.
/// NOTE(review): presumably computes round(value / scale) + offset, clamped to
/// the range of QuantizedType - confirm against the definition.
///
/// @tparam QuantizedType Type of the quantized result.
/// @param value  The value to quantize.
/// @param scale  The quantization scale.
/// @param offset The quantization offset.
/// @return The quantized value.
template<typename QuantizedType>
QuantizedType Quantize(float value, float scale, int32_t offset);
282 
/// Converts a quantized value back to floating point.
///
/// Declaration only: the definition is not part of this header.
/// NOTE(review): presumably computes (value - offset) * scale - confirm
/// against the definition.
///
/// @tparam QuantizedType Type of the quantized input.
/// @param value  The quantized value to convert.
/// @param scale  The quantization scale.
/// @param offset The quantization offset.
/// @return The dequantized floating-point value.
template <typename QuantizedType>
float Dequantize(QuantizedType value, float scale, int32_t offset);
291 
293 {
294  if (info.GetDataType() != dataType)
295  {
296  std::stringstream ss;
297  ss << "Unexpected datatype:" << armnn::GetDataTypeName(info.GetDataType())
298  << " for tensor:" << info.GetShape()
299  << ". The type expected to be: " << armnn::GetDataTypeName(dataType);
300  throw armnn::Exception(ss.str());
301  }
302 }
303 
304 } //namespace armnn
PaddingMethod
Definition: Types.hpp:115
constexpr bool StrEqual(const char *strA, const char(&strB)[N])
Definition: TypesUtils.hpp:133
unsigned int GetNumDimensions() const
Definition: Tensor.hpp:43
constexpr char const * GetComparisonOperationAsCString(ComparisonOperation operation)
Definition: TypesUtils.hpp:55
Status
Definition: Types.hpp:26
constexpr char const * GetOutputShapeRoundingAsCString(OutputShapeRounding rounding)
Definition: TypesUtils.hpp:93
constexpr const char * GetResizeMethodAsCString(ResizeMethod method)
Definition: TypesUtils.hpp:216
ResizeMethod
Definition: Types.hpp:100
void VerifyTensorInfoDataType(const armnn::TensorInfo &info, armnn::DataType dataType)
Definition: TypesUtils.hpp:292
ActivationFunction
Definition: Types.hpp:54
OutputShapeRounding
Definition: Types.hpp:137
constexpr char const * GetPaddingMethodAsCString(PaddingMethod method)
Definition: TypesUtils.hpp:103
constexpr unsigned int GetDataTypeSize(DataType dataType)
Definition: TypesUtils.hpp:113
Krichevsky 2012: Local Brightness Normalization.
constexpr char const * GetUnaryOperationAsCString(UnaryOperation operation)
Definition: TypesUtils.hpp:69
Jarret 2009: Local Contrast Normalization.
constexpr char const * GetActivationFunctionAsCString(ActivationFunction activation)
Definition: TypesUtils.hpp:27
constexpr bool IsQuantizedType()
Definition: TypesUtils.hpp:232
ComparisonOperation
Definition: Types.hpp:74
PoolingAlgorithm
Definition: Types.hpp:93
The padding fields count, but are ignored.
#define ARMNN_NO_DEPRECATE_WARN_BEGIN
Definition: Deprecated.hpp:33
The padding fields don't count and are ignored.
constexpr char const * GetArgMinMaxFunctionAsCString(ArgMinMaxFunction function)
Definition: TypesUtils.hpp:45
GPU Execution: OpenCL: ArmCompute.
NormalizationAlgorithmMethod
Definition: Types.hpp:129
Base class for all ArmNN exceptions so that users can filter to just those.
Definition: Exceptions.hpp:46
CPU Execution: Reference C++ kernels.
constexpr const char * GetNormalizationAlgorithmChannelAsCString(NormalizationAlgorithmChannel channel)
Definition: TypesUtils.hpp:196
CPU Execution: NEON: ArmCompute.
constexpr char const * GetPoolingAlgorithmAsCString(PoolingAlgorithm pooling)
Definition: TypesUtils.hpp:82
UnaryOperation
Definition: Types.hpp:84
DataLayout
Definition: Types.hpp:48
ArgMinMaxFunction
Definition: Types.hpp:68
DataType
Definition: Types.hpp:32
constexpr const char * GetNormalizationAlgorithmMethodAsCString(NormalizationAlgorithmMethod method)
Definition: TypesUtils.hpp:206
constexpr const char * GetDataLayoutName(DataLayout dataLayout)
Definition: TypesUtils.hpp:186
NormalizationAlgorithmChannel
Definition: Types.hpp:123
DataType GetDataType() const
Definition: Tensor.hpp:95
constexpr armnn::Compute ParseComputeDevice(const char *str)
Definition: TypesUtils.hpp:145
const TensorShape & GetShape() const
Definition: Tensor.hpp:88
constexpr const char * GetDataTypeName(DataType dataType)
Definition: TypesUtils.hpp:165
std::ostream & operator<<(std::ostream &os, const std::vector< Compute > &compute)
Definition: BackendId.hpp:47
constexpr char const * GetStatusAsCString(Status status)
Definition: TypesUtils.hpp:17
constexpr bool IsQuantized8BitType(DataType dataType)
Definition: TypesUtils.hpp:237
#define ARMNN_NO_DEPRECATE_WARN_END
Definition: Deprecated.hpp:34