ArmNN
 20.05
Tensor.cpp
Go to the documentation of this file.
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
#include "armnn/Tensor.hpp"
#include "armnn/Utils.hpp"
#include "armnn/Exceptions.hpp"
#include "armnn/TypesUtils.hpp"

#include <armnn/utility/Assert.hpp>

#include <boost/numeric/conversion/cast.hpp>

#include <algorithm>
#include <limits>
#include <sstream>
16 
17 namespace armnn
18 {
19 
20 // ---
21 // --- TensorShape
22 // ---
23 
25  : m_NumDimensions(0)
26 {
27 }
28 
29 TensorShape::TensorShape(unsigned int numDimensions)
30  : m_NumDimensions(numDimensions)
31 {
32  if (numDimensions < 1)
33  {
34  throw InvalidArgumentException("Tensor numDimensions must be greater than 0");
35  }
36 
37  if (numDimensions > MaxNumOfTensorDimensions)
38  {
39  throw InvalidArgumentException("Tensor numDimensions must be less than or equal to MaxNumOfTensorDimensions");
40  }
41 
42  std::fill(m_Dimensions.begin(), m_Dimensions.begin() + m_NumDimensions, 0);
43 }
44 
45 TensorShape::TensorShape(const unsigned int numDimensions, const unsigned int* const dimensionSizes)
46  : m_NumDimensions(numDimensions)
47 {
48  if (numDimensions < 1)
49  {
50  throw InvalidArgumentException("Tensor numDimensions must be greater than 0");
51  }
52 
53  if (numDimensions > MaxNumOfTensorDimensions)
54  {
55  throw InvalidArgumentException("Tensor numDimensions must be less than or equal to MaxNumOfTensorDimensions");
56  }
57 
58  if (dimensionSizes == nullptr)
59  {
60  throw InvalidArgumentException("Tensor dimensionSizes must not be NULL");
61  }
62 
63  std::copy(dimensionSizes, dimensionSizes + numDimensions, m_Dimensions.begin());
64 }
65 
66 TensorShape::TensorShape(std::initializer_list<unsigned int> dimensionSizeList)
67  : TensorShape(boost::numeric_cast<unsigned int>(dimensionSizeList.size()), dimensionSizeList.begin())
68 {
69 }
70 
72  : m_NumDimensions(other.m_NumDimensions)
73 {
74  std::copy(other.m_Dimensions.cbegin(), other.m_Dimensions.cbegin() + other.m_NumDimensions, m_Dimensions.begin());
75 }
76 
78 {
79  m_NumDimensions = other.m_NumDimensions;
80  std::copy(other.m_Dimensions.cbegin(), other.m_Dimensions.cbegin() + other.m_NumDimensions, m_Dimensions.begin());
81  return *this;
82 }
83 
84 unsigned int TensorShape::operator[](unsigned int i) const
85 {
86  CheckDimensionIndex(i);
87  return m_Dimensions.at(i);
88 }
89 
90 unsigned int& TensorShape::operator[](unsigned int i)
91 {
92  CheckDimensionIndex(i);
93  return m_Dimensions.at(i);
94 }
95 
96 bool TensorShape::operator==(const TensorShape& other) const
97 {
98  return ((m_NumDimensions == other.m_NumDimensions) &&
99  std::equal(m_Dimensions.cbegin(), m_Dimensions.cbegin() + m_NumDimensions, other.m_Dimensions.cbegin()));
100 }
101 
102 bool TensorShape::operator!=(const TensorShape& other) const
103 {
104  return !(*this == other);
105 }
106 
107 unsigned int TensorShape::GetNumElements() const
108 {
109  if (m_NumDimensions == 0)
110  {
111  return 0;
112  }
113 
114  unsigned int count = 1;
115  for (unsigned int i = 0; i < m_NumDimensions; i++)
116  {
117  count *= m_Dimensions[i];
118  }
119 
120  return count;
121 }
122 
123 void TensorShape::CheckDimensionIndex(unsigned int i) const
124 {
125  if (i >= m_NumDimensions)
126  {
127  std::stringstream errorMessage;
128  errorMessage << "Invalid dimension index: " << i << " (number of dimensions is " << m_NumDimensions << ")";
129  throw InvalidArgumentException(errorMessage.str(), CHECK_LOCATION());
130  }
131 }
132 
133 // ---
134 // --- TensorInfo
135 // ---
136 
138 : m_DataType(DataType::Float32)
139 {
140 }
141 
143  DataType dataType,
144  float quantizationScale,
145  int32_t quantizationOffset)
146  : m_Shape(shape)
147  , m_DataType(dataType)
148 {
149  SetQuantizationScale(quantizationScale);
150  SetQuantizationOffset(quantizationOffset);
151 }
152 
153 TensorInfo::TensorInfo(unsigned int numDimensions,
154  const unsigned int* dimensionSizes,
155  DataType dataType,
156  float quantizationScale,
157  int32_t quantizationOffset)
158  : m_Shape(numDimensions, dimensionSizes)
159  , m_DataType(dataType)
160 {
161  SetQuantizationScale(quantizationScale);
162  SetQuantizationOffset(quantizationOffset);
163 }
164 
166  DataType dataType,
167  const std::vector<float>& quantizationScales,
168  unsigned int quantizationDim)
169  : m_Shape(shape)
170  , m_DataType(dataType)
171 {
172  SetQuantizationScales(quantizationScales);
173  SetQuantizationDim(MakeOptional<unsigned int>(quantizationDim));
174 }
175 
176 TensorInfo::TensorInfo(unsigned int numDimensions,
177  const unsigned int* dimensionSizes,
178  DataType dataType,
179  const std::vector<float>& quantizationScales,
180  unsigned int quantizationDim)
181  : m_Shape(numDimensions, dimensionSizes)
182  , m_DataType(dataType)
183 {
184  SetQuantizationScales(quantizationScales);
185  SetQuantizationDim(MakeOptional<unsigned int>(quantizationDim));
186 }
187 
189 : m_Shape(other.m_Shape)
190 , m_DataType(other.m_DataType)
191 , m_Quantization(other.m_Quantization)
192 {}
193 
195 {
196  m_Shape = other.m_Shape;
197  m_DataType = other.m_DataType;
198  m_Quantization = other.m_Quantization;
199  return *this;
200 }
201 
202 bool TensorInfo::operator==(const TensorInfo& other) const
203 {
204  return ((m_Shape == other.m_Shape) &&
205  (m_DataType == other.m_DataType) &&
206  (m_Quantization == other.m_Quantization));
207 }
208 
209 bool TensorInfo::operator!=(const TensorInfo& other) const
210 {
211  return !(*this == other);
212 }
213 
214 unsigned int TensorInfo::GetNumBytes() const
215 {
216  return GetDataTypeSize(m_DataType) * GetNumElements();
217 }
218 
219 bool TensorInfo::IsTypeSpaceMatch(const TensorInfo& other) const
220 {
221  bool match = true;
222 
223  match &= m_DataType == other.m_DataType;
224 
226  {
227  match &= GetQuantizationScale() == other.GetQuantizationScale() &&
229  }
230  return match;
231 }
232 
234 {
235  return HasMultipleQuantizationScales() || m_Quantization.m_QuantizationDim.has_value();
236 }
237 
238 std::vector<float> TensorInfo::GetQuantizationScales() const
239 {
240  return m_Quantization.m_Scales;
241 }
242 
243 void TensorInfo::SetQuantizationScales(const std::vector<float>& scales)
244 {
245  m_Quantization.m_Scales = scales;
246 }
247 
249 {
250  if (m_Quantization.m_Scales.empty())
251  {
252  // NOTE: old default for backward compatibility
253  return 1.0f;
254  }
255 
257  return m_Quantization.m_Scales[0];
258 }
259 
261 {
262  m_Quantization.m_Scales = { scale };
263 }
264 
266 {
267  if (!m_Quantization.m_Offset.has_value())
268  {
269  // NOTE: old default for backward compatibility
270  return 0;
271  }
272 
273  return m_Quantization.m_Offset.value();
274 }
275 
277 {
278  m_Quantization.m_Offset = MakeOptional<int32_t>(offset);
279 }
280 
282 {
283  return m_Quantization.m_QuantizationDim;
284 }
285 
287 {
288  m_Quantization.m_QuantizationDim = quantizationDim;
289 }
290 
292 {
293  return IsQuantizedType(m_DataType);
294 }
295 
296 // ---
297 // --- BaseTensor
298 // ---
299 
300 template<typename MemoryType>
302  : m_MemoryArea(nullptr)
303 {
304 }
305 
306 template<typename MemoryType>
307 BaseTensor<MemoryType>::BaseTensor(const TensorInfo& info, MemoryType memoryArea)
308  : m_MemoryArea(memoryArea)
309  , m_Info(info)
310 {
311 }
312 
313 template<typename MemoryType>
315  : m_MemoryArea(other.m_MemoryArea)
316  , m_Info(other.GetInfo())
317 {
318 }
319 
320 template<typename MemoryType>
322 {
323  m_Info = other.m_Info;
324  m_MemoryArea = other.m_MemoryArea;
325  return *this;
326 }
327 
328 // Explicit instantiations.
329 template class BaseTensor<const void*>;
330 template class BaseTensor<void*>;
331 
332 } // namespace armnn
unsigned int GetNumElements() const
Definition: Tensor.cpp:107
bool operator!=(const TensorShape &other) const
Definition: Tensor.cpp:102
unsigned int operator[](unsigned int i) const
Definition: Tensor.cpp:84
bool IsTypeSpaceMatch(const TensorInfo &other) const
Check that the types are the same and, if quantized, that the quantization parameters are the same...
Definition: Tensor.cpp:219
TensorShape & operator=(const TensorShape &other)
Definition: Tensor.cpp:77
constexpr bool IsQuantizedType()
Definition: TypesUtils.hpp:236
bool HasPerAxisQuantization() const
Definition: Tensor.cpp:233
Optional< unsigned int > GetQuantizationDim() const
Definition: Tensor.cpp:281
unsigned int GetNumBytes() const
Definition: Tensor.cpp:214
Copyright (c) 2020 ARM Limited.
std::vector< float > GetQuantizationScales() const
Definition: Tensor.cpp:238
bool HasMultipleQuantizationScales() const
Definition: Tensor.hpp:98
bool operator==(const TensorShape &other) const
Definition: Tensor.cpp:96
TensorShape()
Empty (invalid) constructor.
Definition: Tensor.cpp:24
TensorInfo()
Empty (invalid) constructor.
Definition: Tensor.cpp:137
DataType
Definition: Types.hpp:32
int32_t GetQuantizationOffset() const
Definition: Tensor.cpp:265
float GetQuantizationScale() const
Definition: Tensor.cpp:248
void SetQuantizationScale(float scale)
Definition: Tensor.cpp:260
#define ARMNN_ASSERT(COND)
Definition: Assert.hpp:14
std::enable_if_t< std::is_unsigned< Source >::value &&std::is_unsigned< Dest >::value, Dest > numeric_cast(Source source)
Definition: NumericCast.hpp:33
MemoryType m_MemoryArea
Definition: Tensor.hpp:184
const TensorInfo & GetInfo() const
Definition: Tensor.hpp:167
#define CHECK_LOCATION()
Definition: Exceptions.hpp:192
TensorInfo & operator=(const TensorInfo &other)
Definition: Tensor.cpp:194
void SetQuantizationDim(const Optional< unsigned int > &quantizationDim)
Definition: Tensor.cpp:286
bool operator==(const TensorInfo &other) const
Definition: Tensor.cpp:202
bool operator!=(const TensorInfo &other) const
Definition: Tensor.cpp:209
BaseTensor()
Empty (invalid) constructor.
Definition: Tensor.cpp:301
void SetQuantizationOffset(int32_t offset)
Definition: Tensor.cpp:276
void SetQuantizationScales(const std::vector< float > &scales)
Definition: Tensor.cpp:243
bool IsQuantized() const
Definition: Tensor.cpp:291
constexpr unsigned int MaxNumOfTensorDimensions
Definition: Types.hpp:18
unsigned int GetNumElements() const
Definition: Tensor.hpp:93
constexpr unsigned int GetDataTypeSize(DataType dataType)
Definition: TypesUtils.hpp:115