ArmNN
 20.02
Tensor.cpp
Go to the documentation of this file.
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
6 #include "armnn/Tensor.hpp"
7 #include "armnn/Utils.hpp"
8 #include "armnn/Exceptions.hpp"
9 #include "armnn/TypesUtils.hpp"
10 
11 #include <boost/assert.hpp>
12 #include <boost/numeric/conversion/cast.hpp>
13 
14 #include <sstream>
15 
16 namespace armnn
17 {
18 
19 // ---
20 // --- TensorShape
21 // ---
22 
24  : m_NumDimensions(0)
25 {
26 }
27 
28 TensorShape::TensorShape(unsigned int numDimensions)
29  : m_NumDimensions(numDimensions)
30 {
31  if (numDimensions < 1)
32  {
33  throw InvalidArgumentException("Tensor numDimensions must be greater than 0");
34  }
35 
36  if (numDimensions > MaxNumOfTensorDimensions)
37  {
38  throw InvalidArgumentException("Tensor numDimensions must be less than or equal to MaxNumOfTensorDimensions");
39  }
40 
41  std::fill(m_Dimensions.begin(), m_Dimensions.begin() + m_NumDimensions, 0);
42 }
43 
44 TensorShape::TensorShape(const unsigned int numDimensions, const unsigned int* const dimensionSizes)
45  : m_NumDimensions(numDimensions)
46 {
47  if (numDimensions < 1)
48  {
49  throw InvalidArgumentException("Tensor numDimensions must be greater than 0");
50  }
51 
52  if (numDimensions > MaxNumOfTensorDimensions)
53  {
54  throw InvalidArgumentException("Tensor numDimensions must be less than or equal to MaxNumOfTensorDimensions");
55  }
56 
57  if (dimensionSizes == nullptr)
58  {
59  throw InvalidArgumentException("Tensor dimensionSizes must not be NULL");
60  }
61 
62  std::copy(dimensionSizes, dimensionSizes + numDimensions, m_Dimensions.begin());
63 }
64 
65 TensorShape::TensorShape(std::initializer_list<unsigned int> dimensionSizeList)
66  : TensorShape(boost::numeric_cast<unsigned int>(dimensionSizeList.size()), dimensionSizeList.begin())
67 {
68 }
69 
71  : m_NumDimensions(other.m_NumDimensions)
72 {
73  std::copy(other.m_Dimensions.cbegin(), other.m_Dimensions.cbegin() + other.m_NumDimensions, m_Dimensions.begin());
74 }
75 
77 {
78  m_NumDimensions = other.m_NumDimensions;
79  std::copy(other.m_Dimensions.cbegin(), other.m_Dimensions.cbegin() + other.m_NumDimensions, m_Dimensions.begin());
80  return *this;
81 }
82 
83 unsigned int TensorShape::operator[](unsigned int i) const
84 {
85  CheckDimensionIndex(i);
86  return m_Dimensions.at(i);
87 }
88 
89 unsigned int& TensorShape::operator[](unsigned int i)
90 {
91  CheckDimensionIndex(i);
92  return m_Dimensions.at(i);
93 }
94 
95 bool TensorShape::operator==(const TensorShape& other) const
96 {
97  return ((m_NumDimensions == other.m_NumDimensions) &&
98  std::equal(m_Dimensions.cbegin(), m_Dimensions.cbegin() + m_NumDimensions, other.m_Dimensions.cbegin()));
99 }
100 
101 bool TensorShape::operator!=(const TensorShape& other) const
102 {
103  return !(*this == other);
104 }
105 
106 unsigned int TensorShape::GetNumElements() const
107 {
108  if (m_NumDimensions == 0)
109  {
110  return 0;
111  }
112 
113  unsigned int count = 1;
114  for (unsigned int i = 0; i < m_NumDimensions; i++)
115  {
116  count *= m_Dimensions[i];
117  }
118 
119  return count;
120 }
121 
122 void TensorShape::CheckDimensionIndex(unsigned int i) const
123 {
124  if (i >= m_NumDimensions)
125  {
126  std::stringstream errorMessage;
127  errorMessage << "Invalid dimension index: " << i << " (number of dimensions is " << m_NumDimensions << ")";
128  throw InvalidArgumentException(errorMessage.str(), CHECK_LOCATION());
129  }
130 }
131 
132 // ---
133 // --- TensorInfo
134 // ---
135 
137 : m_DataType(DataType::Float32)
138 {
139 }
140 
// TensorInfo constructor taking a shape, a data type and a single (per-tensor)
// quantization scale and offset. The scale and offset are stored through
// SetQuantizationScale / SetQuantizationOffset.
// NOTE(review): the declarator line (doc line 141) is missing from this extraction.
// Judging by the `m_Shape(shape)` initializer and the numDimensions overload that
// follows, it is presumably `TensorInfo::TensorInfo(const TensorShape& shape,`.
// TODO confirm against Tensor.hpp.
142  DataType dataType,
143  float quantizationScale,
144  int32_t quantizationOffset)
145  : m_Shape(shape)
146  , m_DataType(dataType)
147 {
148  SetQuantizationScale(quantizationScale);
149  SetQuantizationOffset(quantizationOffset);
150 }
151 
152 TensorInfo::TensorInfo(unsigned int numDimensions,
153  const unsigned int* dimensionSizes,
154  DataType dataType,
155  float quantizationScale,
156  int32_t quantizationOffset)
157  : m_Shape(numDimensions, dimensionSizes)
158  , m_DataType(dataType)
159 {
160  SetQuantizationScale(quantizationScale);
161  SetQuantizationOffset(quantizationOffset);
162 }
163 
// TensorInfo constructor taking a shape, a data type, a list of quantization
// scales and a quantization dimension (presumably the axis the scales apply
// along; TODO confirm). The dimension is always stored as an engaged Optional
// via MakeOptional.
// NOTE(review): the declarator line (doc line 164) is missing from this extraction.
// It is presumably `TensorInfo::TensorInfo(const TensorShape& shape,`.
// TODO confirm against Tensor.hpp.
165  DataType dataType,
166  const std::vector<float>& quantizationScales,
167  unsigned int quantizationDim)
168  : m_Shape(shape)
169  , m_DataType(dataType)
170 {
171  SetQuantizationScales(quantizationScales);
172  SetQuantizationDim(MakeOptional<unsigned int>(quantizationDim));
173 }
174 
175 TensorInfo::TensorInfo(unsigned int numDimensions,
176  const unsigned int* dimensionSizes,
177  DataType dataType,
178  const std::vector<float>& quantizationScales,
179  unsigned int quantizationDim)
180  : m_Shape(numDimensions, dimensionSizes)
181  , m_DataType(dataType)
182 {
183  SetQuantizationScales(quantizationScales);
184  SetQuantizationDim(MakeOptional<unsigned int>(quantizationDim));
185 }
186 
188 : m_Shape(other.m_Shape)
189 , m_DataType(other.m_DataType)
190 , m_Quantization(other.m_Quantization)
191 {}
192 
194 {
195  m_Shape = other.m_Shape;
196  m_DataType = other.m_DataType;
197  m_Quantization = other.m_Quantization;
198  return *this;
199 }
200 
201 bool TensorInfo::operator==(const TensorInfo& other) const
202 {
203  return ((m_Shape == other.m_Shape) &&
204  (m_DataType == other.m_DataType) &&
205  (m_Quantization == other.m_Quantization));
206 }
207 
208 bool TensorInfo::operator!=(const TensorInfo& other) const
209 {
210  return !(*this == other);
211 }
212 
213 unsigned int TensorInfo::GetNumBytes() const
214 {
215  return GetDataTypeSize(m_DataType) * GetNumElements();
216 }
217 
// Returns whether `other` describes the same "type space" as this tensor. The data
// types must be equal, and (per the header documentation) when the tensor is
// quantized the quantization parameters must match as well. No visible line
// compares shapes.
// NOTE(review): doc lines 224 and 227 are missing from this extraction.
// - Line 224 is the condition guarding the braced block below. It presumably
//   involves IsQuantized(), possibly combined with a single-scale check, since
//   GetQuantizationScale() asserts !HasMultipleQuantizationScales().
// - Line 227 is the right-hand operand of the trailing `&&`, presumably an
//   offset comparison.
// TODO confirm both against upstream.
218 bool TensorInfo::IsTypeSpaceMatch(const TensorInfo& other) const
219 {
220  bool match = true;
221 
222  match &= m_DataType == other.m_DataType;
223 
225  {
226  match &= GetQuantizationScale() == other.GetQuantizationScale() &&
228  }
229  return match;
230 }
231 
233 {
234  return HasMultipleQuantizationScales() || m_Quantization.m_QuantizationDim.has_value();
235 }
236 
237 std::vector<float> TensorInfo::GetQuantizationScales() const
238 {
239  return m_Quantization.m_Scales;
240 }
241 
242 void TensorInfo::SetQuantizationScales(const std::vector<float>& scales)
243 {
244  m_Quantization.m_Scales = scales;
245 }
246 
248 {
249  if (m_Quantization.m_Scales.empty())
250  {
251  // NOTE: old default for backward compatibility
252  return 1.0f;
253  }
254 
255  BOOST_ASSERT(!HasMultipleQuantizationScales());
256  return m_Quantization.m_Scales[0];
257 }
258 
260 {
261  m_Quantization.m_Scales = { scale };
262 }
263 
265 {
266  if (!m_Quantization.m_Offset.has_value())
267  {
268  // NOTE: old default for backward compatibility
269  return 0;
270  }
271 
272  return m_Quantization.m_Offset.value();
273 }
274 
276 {
277  m_Quantization.m_Offset = MakeOptional<int32_t>(offset);
278 }
279 
281 {
282  return m_Quantization.m_QuantizationDim;
283 }
284 
286 {
287  m_Quantization.m_QuantizationDim = quantizationDim;
288 }
289 
291 {
292  return IsQuantizedType(m_DataType);
293 }
294 
295 // ---
296 // --- BaseTensor
297 // ---
298 
299 template<typename MemoryType>
301  : m_MemoryArea(nullptr)
302 {
303 }
304 
305 template<typename MemoryType>
306 BaseTensor<MemoryType>::BaseTensor(const TensorInfo& info, MemoryType memoryArea)
307  : m_MemoryArea(memoryArea)
308  , m_Info(info)
309 {
310 }
311 
312 template<typename MemoryType>
314  : m_MemoryArea(other.m_MemoryArea)
315  , m_Info(other.GetInfo())
316 {
317 }
318 
319 template<typename MemoryType>
321 {
322  m_Info = other.m_Info;
323  m_MemoryArea = other.m_MemoryArea;
324  return *this;
325 }
326 
327 // Explicit instantiations.
328 template class BaseTensor<const void*>;
329 template class BaseTensor<void*>;
330 
331 } // namespace armnn
unsigned int GetNumElements() const
Definition: Tensor.cpp:106
bool operator!=(const TensorShape &other) const
Definition: Tensor.cpp:101
unsigned int operator[](unsigned int i) const
Definition: Tensor.cpp:83
bool IsTypeSpaceMatch(const TensorInfo &other) const
Check that the types are the same and, if quantized, that the quantization parameters are the same...
Definition: Tensor.cpp:218
TensorShape & operator=(const TensorShape &other)
Definition: Tensor.cpp:76
constexpr bool IsQuantizedType()
Definition: TypesUtils.hpp:236
bool HasPerAxisQuantization() const
Definition: Tensor.cpp:232
Optional< unsigned int > GetQuantizationDim() const
Definition: Tensor.cpp:280
unsigned int GetNumBytes() const
Definition: Tensor.cpp:213
Copyright (c) 2020 ARM Limited.
std::vector< float > GetQuantizationScales() const
Definition: Tensor.cpp:237
bool HasMultipleQuantizationScales() const
Definition: Tensor.hpp:98
bool operator==(const TensorShape &other) const
Definition: Tensor.cpp:95
TensorShape()
Empty (invalid) constructor.
Definition: Tensor.cpp:23
TensorInfo()
Empty (invalid) constructor.
Definition: Tensor.cpp:136
DataType
Definition: Types.hpp:32
int32_t GetQuantizationOffset() const
Definition: Tensor.cpp:264
float GetQuantizationScale() const
Definition: Tensor.cpp:247
void SetQuantizationScale(float scale)
Definition: Tensor.cpp:259
std::enable_if_t< std::is_unsigned< Source >::value &&std::is_unsigned< Dest >::value, Dest > numeric_cast(Source source)
Definition: NumericCast.hpp:33
MemoryType m_MemoryArea
Definition: Tensor.hpp:184
const TensorInfo & GetInfo() const
Definition: Tensor.hpp:167
#define CHECK_LOCATION()
Definition: Exceptions.hpp:192
TensorInfo & operator=(const TensorInfo &other)
Definition: Tensor.cpp:193
void SetQuantizationDim(const Optional< unsigned int > &quantizationDim)
Definition: Tensor.cpp:285
bool operator==(const TensorInfo &other) const
Definition: Tensor.cpp:201
bool operator!=(const TensorInfo &other) const
Definition: Tensor.cpp:208
BaseTensor()
Empty (invalid) constructor.
Definition: Tensor.cpp:300
void SetQuantizationOffset(int32_t offset)
Definition: Tensor.cpp:275
void SetQuantizationScales(const std::vector< float > &scales)
Definition: Tensor.cpp:242
bool IsQuantized() const
Definition: Tensor.cpp:290
constexpr unsigned int MaxNumOfTensorDimensions
Definition: Types.hpp:18
unsigned int GetNumElements() const
Definition: Tensor.hpp:93
constexpr unsigned int GetDataTypeSize(DataType dataType)
Definition: TypesUtils.hpp:115