ArmNN
 22.05.01
TensorTest.cpp File Reference
#include <armnn/Tensor.hpp>
#include <armnn/utility/IgnoreUnused.hpp>
#include <doctest/doctest.h>

Go to the source code of this file.

Functions

 TEST_SUITE ("Tensor")
 

Function Documentation

◆ TEST_SUITE()

TEST_SUITE ( "Tensor"  )

Definition at line 13 of file TensorTest.cpp.

14 {
15 struct TensorInfoFixture
16 {
17  TensorInfoFixture()
18  {
19  unsigned int sizes[] = {6,7,8,9};
20  m_TensorInfo = TensorInfo(4, sizes, DataType::Float32);
21  }
22  ~TensorInfoFixture() {};
23 
24  TensorInfo m_TensorInfo;
25 };
26 
27 TEST_CASE_FIXTURE(TensorInfoFixture, "ConstructShapeUsingListInitialization")
28 {
29  TensorShape listInitializedShape{ 6, 7, 8, 9 };
30  CHECK(listInitializedShape == m_TensorInfo.GetShape());
31 }
32 
33 TEST_CASE_FIXTURE(TensorInfoFixture, "ConstructTensorInfo")
34 {
35  CHECK(m_TensorInfo.GetNumDimensions() == 4);
36  CHECK(m_TensorInfo.GetShape()[0] == 6); // <= Outermost
37  CHECK(m_TensorInfo.GetShape()[1] == 7);
38  CHECK(m_TensorInfo.GetShape()[2] == 8);
39  CHECK(m_TensorInfo.GetShape()[3] == 9); // <= Innermost
40 }
41 
42 TEST_CASE_FIXTURE(TensorInfoFixture, "CopyConstructTensorInfo")
43 {
44  TensorInfo copyConstructed(m_TensorInfo);
45  CHECK(copyConstructed.GetNumDimensions() == 4);
46  CHECK(copyConstructed.GetShape()[0] == 6);
47  CHECK(copyConstructed.GetShape()[1] == 7);
48  CHECK(copyConstructed.GetShape()[2] == 8);
49  CHECK(copyConstructed.GetShape()[3] == 9);
50 }
51 
52 TEST_CASE_FIXTURE(TensorInfoFixture, "TensorInfoEquality")
53 {
54  TensorInfo copyConstructed(m_TensorInfo);
55  CHECK(copyConstructed == m_TensorInfo);
56 }
57 
58 TEST_CASE_FIXTURE(TensorInfoFixture, "TensorInfoInequality")
59 {
60  TensorInfo other;
61  unsigned int sizes[] = {2,3,4,5};
62  other = TensorInfo(4, sizes, DataType::Float32);
63 
64  CHECK(other != m_TensorInfo);
65 }
66 
67 TEST_CASE_FIXTURE(TensorInfoFixture, "TensorInfoAssignmentOperator")
68 {
69  TensorInfo copy;
70  copy = m_TensorInfo;
71  CHECK(copy == m_TensorInfo);
72 }
73 
74 TEST_CASE("CopyNoQuantizationTensorInfo")
75 {
76  TensorInfo infoA;
77  infoA.SetShape({ 5, 6, 7, 8 });
78  infoA.SetDataType(DataType::QAsymmU8);
79 
80  TensorInfo infoB;
81  infoB.SetShape({ 5, 6, 7, 8 });
82  infoB.SetDataType(DataType::QAsymmU8);
83  infoB.SetQuantizationScale(10.0f);
84  infoB.SetQuantizationOffset(5);
85  infoB.SetQuantizationDim(Optional<unsigned int>(1));
86 
87  CHECK((infoA.GetShape() == TensorShape({ 5, 6, 7, 8 })));
88  CHECK((infoA.GetDataType() == DataType::QAsymmU8));
89  CHECK(infoA.GetQuantizationScale() == 1);
90  CHECK(infoA.GetQuantizationOffset() == 0);
91  CHECK(!infoA.GetQuantizationDim().has_value());
92 
93  CHECK(infoA != infoB);
94  infoA = infoB;
95  CHECK(infoA == infoB);
96 
97  CHECK((infoA.GetShape() == TensorShape({ 5, 6, 7, 8 })));
98  CHECK((infoA.GetDataType() == DataType::QAsymmU8));
99  CHECK(infoA.GetQuantizationScale() == 10.0f);
100  CHECK(infoA.GetQuantizationOffset() == 5);
101  CHECK(infoA.GetQuantizationDim().value() == 1);
102 }
103 
104 TEST_CASE("CopyDifferentQuantizationTensorInfo")
105 {
106  TensorInfo infoA;
107  infoA.SetShape({ 5, 6, 7, 8 });
108  infoA.SetDataType(DataType::QAsymmU8);
109  infoA.SetQuantizationScale(10.0f);
110  infoA.SetQuantizationOffset(5);
111  infoA.SetQuantizationDim(Optional<unsigned int>(1));
112 
113  TensorInfo infoB;
114  infoB.SetShape({ 5, 6, 7, 8 });
115  infoB.SetDataType(DataType::QAsymmU8);
116  infoB.SetQuantizationScale(11.0f);
117  infoB.SetQuantizationOffset(6);
118  infoB.SetQuantizationDim(Optional<unsigned int>(2));
119 
120  CHECK((infoA.GetShape() == TensorShape({ 5, 6, 7, 8 })));
121  CHECK((infoA.GetDataType() == DataType::QAsymmU8));
122  CHECK(infoA.GetQuantizationScale() == 10.0f);
123  CHECK(infoA.GetQuantizationOffset() == 5);
124  CHECK(infoA.GetQuantizationDim().value() == 1);
125 
126  CHECK(infoA != infoB);
127  infoA = infoB;
128  CHECK(infoA == infoB);
129 
130  CHECK((infoA.GetShape() == TensorShape({ 5, 6, 7, 8 })));
131  CHECK((infoA.GetDataType() == DataType::QAsymmU8));
132  CHECK(infoA.GetQuantizationScale() == 11.0f);
133  CHECK(infoA.GetQuantizationOffset() == 6);
134  CHECK(infoA.GetQuantizationDim().value() == 2);
135 }
136 
137 void CheckTensor(const ConstTensor& t)
138 {
139  t.GetInfo();
140 }
141 
142 TEST_CASE("TensorVsConstTensor")
143 {
144  int mutableDatum = 2;
145  const int immutableDatum = 3;
146 
147  armnn::Tensor uninitializedTensor;
148  uninitializedTensor.GetInfo().SetConstant(true);
149  armnn::ConstTensor uninitializedTensor2;
150 
151  uninitializedTensor2 = uninitializedTensor;
152 
153  armnn::TensorInfo emptyTensorInfo;
154  emptyTensorInfo.SetConstant(true);
155  armnn::Tensor t(emptyTensorInfo, &mutableDatum);
156  armnn::ConstTensor ct(emptyTensorInfo, &immutableDatum);
157 
158  // Checks that both Tensor and ConstTensor can be passed as a ConstTensor.
159  CheckTensor(t);
160  CheckTensor(ct);
161 }
162 
163 TEST_CASE("ConstTensor_EmptyConstructorTensorInfoSet")
164 {
165  armnn::ConstTensor t;
166  CHECK(t.GetInfo().IsConstant() == true);
167 }
168 
169 TEST_CASE("ConstTensor_TensorInfoNotConstantError")
170 {
171  armnn::TensorInfo tensorInfo ({ 1 }, armnn::DataType::Float32);
172  std::vector<float> tensorData = { 1.0f };
173  try
174  {
175  armnn::ConstTensor ct(tensorInfo, tensorData);
176  FAIL("InvalidArgumentException should have been thrown");
177  }
178  catch(const InvalidArgumentException& exc)
179  {
180  CHECK(strcmp(exc.what(), "Invalid attempt to construct ConstTensor from non-constant TensorInfo.") == 0);
181  }
182 }
183 
184 TEST_CASE("PassTensorToConstTensor_TensorInfoNotConstantError")
185 {
186  try
187  {
188  armnn::ConstTensor t = ConstTensor(Tensor());
189  FAIL("InvalidArgumentException should have been thrown");
190  }
191  catch(const InvalidArgumentException& exc)
192  {
193  CHECK(strcmp(exc.what(), "Invalid attempt to construct ConstTensor from "
194  "Tensor due to non-constant TensorInfo") == 0);
195  }
196 }
197 
198 TEST_CASE("ModifyTensorInfo")
199 {
200  TensorInfo info;
201  info.SetShape({ 5, 6, 7, 8 });
202  CHECK((info.GetShape() == TensorShape({ 5, 6, 7, 8 })));
203  info.SetDataType(DataType::QAsymmU8);
204  CHECK((info.GetDataType() == DataType::QAsymmU8));
205  info.SetQuantizationScale(10.0f);
206  CHECK(info.GetQuantizationScale() == 10.0f);
207  info.SetQuantizationOffset(5);
208  CHECK(info.GetQuantizationOffset() == 5);
209 }
210 
211 TEST_CASE("TensorShapeOperatorBrackets")
212 {
213  const TensorShape constShape({0,1,2,3});
214  TensorShape shape({0,1,2,3});
215 
216  // Checks the non-const version of operator[], which returns a reference (and so can be assigned to).
217  CHECK(shape[2] == 2);
218  shape[2] = 20;
219  CHECK(shape[2] == 20);
220 
221  // Checks the const version of operator[], which returns an unsigned int.
222  CHECK(constShape[2] == 2);
223 }
224 
225 TEST_CASE("TensorInfoPerAxisQuantization")
226 {
227  // Old constructor
228  TensorInfo tensorInfo0({ 1, 1 }, DataType::Float32, 2.0f, 1);
229  CHECK(!tensorInfo0.HasMultipleQuantizationScales());
230  CHECK(tensorInfo0.GetQuantizationScale() == 2.0f);
231  CHECK(tensorInfo0.GetQuantizationOffset() == 1);
232  CHECK(tensorInfo0.GetQuantizationScales()[0] == 2.0f);
233  CHECK(!tensorInfo0.GetQuantizationDim().has_value());
234 
235  // Set per-axis quantization scales
236  std::vector<float> perAxisScales{ 3.0f, 4.0f };
237  tensorInfo0.SetQuantizationScales(perAxisScales);
238  CHECK(tensorInfo0.HasMultipleQuantizationScales());
239  CHECK(tensorInfo0.GetQuantizationScales() == perAxisScales);
240 
241  // Set per-tensor quantization scale
242  tensorInfo0.SetQuantizationScale(5.0f);
243  CHECK(!tensorInfo0.HasMultipleQuantizationScales());
244  CHECK(tensorInfo0.GetQuantizationScales()[0] == 5.0f);
245 
246  // Set quantization dimension
247  tensorInfo0.SetQuantizationDim(Optional<unsigned int>(1));
248  CHECK(tensorInfo0.GetQuantizationDim().value() == 1);
249 
250  // New constructor
251  perAxisScales = { 6.0f, 7.0f };
252  TensorInfo tensorInfo1({ 1, 1 }, DataType::Float32, perAxisScales, 1);
253  CHECK(tensorInfo1.HasMultipleQuantizationScales());
254  CHECK(tensorInfo1.GetQuantizationOffset() == 0);
255  CHECK(tensorInfo1.GetQuantizationScales() == perAxisScales);
256  CHECK(tensorInfo1.GetQuantizationDim().value() == 1);
257 }
258 
259 TEST_CASE("TensorShape_scalar")
260 {
261  float mutableDatum = 3.1416f;
262 
263  armnn::TensorShape shape ( armnn::Dimensionality::Scalar );
264  armnn::TensorInfo info ( shape, DataType::Float32 );
265  const armnn::Tensor tensor ( info, &mutableDatum );
266 
267  CHECK(armnn::Dimensionality::Scalar == shape.GetDimensionality());
268  float scalarValue = *reinterpret_cast<float*>(tensor.GetMemoryArea());
269  CHECK_MESSAGE(mutableDatum == scalarValue, "Scalar value is " << scalarValue);
270 
271  armnn::TensorShape shape_equal;
272  armnn::TensorShape shape_different;
273  shape_equal = shape;
274  CHECK(shape_equal == shape);
275  CHECK(shape_different != shape);
276  CHECK_MESSAGE(1 == shape.GetNumElements(), "Number of elements is " << shape.GetNumElements());
277  CHECK_MESSAGE(1 == shape.GetNumDimensions(), "Number of dimensions is " << shape.GetNumDimensions());
278  CHECK(true == shape.GetDimensionSpecificity(0));
279  CHECK(shape.AreAllDimensionsSpecified());
280  CHECK(shape.IsAtLeastOneDimensionSpecified());
281 
282  CHECK(1 == shape[0]);
283  CHECK(1 == tensor.GetShape()[0]);
284  CHECK(1 == tensor.GetInfo().GetShape()[0]);
285  CHECK_THROWS_AS( shape[1], InvalidArgumentException );
286 
287  float newMutableDatum = 42.f;
288  std::memcpy(tensor.GetMemoryArea(), &newMutableDatum, sizeof(float));
289  scalarValue = *reinterpret_cast<float*>(tensor.GetMemoryArea());
290  CHECK_MESSAGE(newMutableDatum == scalarValue, "Scalar value is " << scalarValue);
291 }
292 
293 TEST_CASE("TensorShape_DynamicTensorType1_unknownNumberDimensions")
294 {
295  float mutableDatum = 3.1416f;
296 
297  armnn::TensorShape shape ( armnn::Dimensionality::NotSpecified );
298  armnn::TensorInfo info ( shape, DataType::Float32 );
299  armnn::Tensor tensor ( info, &mutableDatum );
300 
301  CHECK(armnn::Dimensionality::NotSpecified == shape.GetDimensionality());
302  CHECK_THROWS_AS( shape[0], InvalidArgumentException );
303  CHECK_THROWS_AS( shape.GetNumElements(), InvalidArgumentException );
304  CHECK_THROWS_AS( shape.GetNumDimensions(), InvalidArgumentException );
305 
306  armnn::TensorShape shape_equal;
307  armnn::TensorShape shape_different;
308  shape_equal = shape;
309  CHECK(shape_equal == shape);
310  CHECK(shape_different != shape);
311 }
312 
313 TEST_CASE("TensorShape_DynamicTensorType1_unknownAllDimensionsSizes")
314 {
315  float mutableDatum = 3.1416f;
316 
317  armnn::TensorShape shape ( 3, false );
318  armnn::TensorInfo info ( shape, DataType::Float32 );
319  armnn::Tensor tensor ( info, &mutableDatum );
320 
321  CHECK(armnn::Dimensionality::Specified == shape.GetDimensionality());
322  CHECK_MESSAGE(0 == shape.GetNumElements(), "Number of elements is " << shape.GetNumElements());
323  CHECK_MESSAGE(3 == shape.GetNumDimensions(), "Number of dimensions is " << shape.GetNumDimensions());
324  CHECK(false == shape.GetDimensionSpecificity(0));
325  CHECK(false == shape.GetDimensionSpecificity(1));
326  CHECK(false == shape.GetDimensionSpecificity(2));
327  CHECK(!shape.AreAllDimensionsSpecified());
328  CHECK(!shape.IsAtLeastOneDimensionSpecified());
329 
330  armnn::TensorShape shape_equal;
331  armnn::TensorShape shape_different;
332  shape_equal = shape;
333  CHECK(shape_equal == shape);
334  CHECK(shape_different != shape);
335 }
336 
337 TEST_CASE("TensorShape_DynamicTensorType1_unknownSomeDimensionsSizes")
338 {
339  std::vector<float> mutableDatum { 42.f, 42.f, 42.f,
340  0.0f, 0.1f, 0.2f };
341 
342  armnn::TensorShape shape ( {2, 0, 3}, {true, false, true} );
343  armnn::TensorInfo info ( shape, DataType::Float32 );
344  armnn::Tensor tensor ( info, &mutableDatum );
345 
346  CHECK(armnn::Dimensionality::Specified == shape.GetDimensionality());
347  CHECK_MESSAGE(6 == shape.GetNumElements(), "Number of elements is " << shape.GetNumElements());
348  CHECK_MESSAGE(3 == shape.GetNumDimensions(), "Number of dimensions is " << shape.GetNumDimensions());
349  CHECK(true == shape.GetDimensionSpecificity(0));
350  CHECK(false == shape.GetDimensionSpecificity(1));
351  CHECK(true == shape.GetDimensionSpecificity(2));
352  CHECK(!shape.AreAllDimensionsSpecified());
353  CHECK(shape.IsAtLeastOneDimensionSpecified());
354 
355  CHECK_THROWS_AS(shape[1], InvalidArgumentException);
356  CHECK_THROWS_AS(tensor.GetShape()[1], InvalidArgumentException);
357  CHECK_THROWS_AS(tensor.GetInfo().GetShape()[1], InvalidArgumentException);
358 
359  CHECK(2 == shape[0]);
360  CHECK(2 == tensor.GetShape()[0]);
361  CHECK(2 == tensor.GetInfo().GetShape()[0]);
362  CHECK_THROWS_AS( shape[1], InvalidArgumentException );
363 
364  CHECK(3 == shape[2]);
365  CHECK(3 == tensor.GetShape()[2]);
366  CHECK(3 == tensor.GetInfo().GetShape()[2]);
367 
368  armnn::TensorShape shape_equal;
369  armnn::TensorShape shape_different;
370  shape_equal = shape;
371  CHECK(shape_equal == shape);
372  CHECK(shape_different != shape);
373 }
374 
375 TEST_CASE("TensorShape_DynamicTensorType1_transitionFromUnknownToKnownDimensionsSizes")
376 {
377  std::vector<float> mutableDatum { 42.f, 42.f, 42.f,
378  0.0f, 0.1f, 0.2f };
379 
380  armnn::TensorShape shape ( armnn::Dimensionality::NotSpecified );
381  armnn::TensorInfo info ( shape, DataType::Float32 );
382  armnn::Tensor tensor ( info, &mutableDatum );
383 
384  // Specify the number of dimensions
385  shape.SetNumDimensions(3);
386  CHECK(armnn::Dimensionality::Specified == shape.GetDimensionality());
387  CHECK_MESSAGE(3 == shape.GetNumDimensions(), "Number of dimensions is " << shape.GetNumDimensions());
388  CHECK(false == shape.GetDimensionSpecificity(0));
389  CHECK(false == shape.GetDimensionSpecificity(1));
390  CHECK(false == shape.GetDimensionSpecificity(2));
391  CHECK(!shape.AreAllDimensionsSpecified());
392  CHECK(!shape.IsAtLeastOneDimensionSpecified());
393 
394  // Specify dimension 0 and 2.
395  shape.SetDimensionSize(0, 2);
396  shape.SetDimensionSize(2, 3);
397  CHECK_MESSAGE(3 == shape.GetNumDimensions(), "Number of dimensions is " << shape.GetNumDimensions());
398  CHECK_MESSAGE(6 == shape.GetNumElements(), "Number of elements is " << shape.GetNumElements());
399  CHECK(true == shape.GetDimensionSpecificity(0));
400  CHECK(false == shape.GetDimensionSpecificity(1));
401  CHECK(true == shape.GetDimensionSpecificity(2));
402  CHECK(!shape.AreAllDimensionsSpecified());
403  CHECK(shape.IsAtLeastOneDimensionSpecified());
404 
405  info.SetShape(shape);
406  armnn::Tensor tensor2( info, &mutableDatum );
407  CHECK(2 == shape[0]);
408  CHECK(2 == tensor2.GetShape()[0]);
409  CHECK(2 == tensor2.GetInfo().GetShape()[0]);
410 
411  CHECK_THROWS_AS(shape[1], InvalidArgumentException);
412  CHECK_THROWS_AS(tensor.GetShape()[1], InvalidArgumentException);
413  CHECK_THROWS_AS(tensor.GetInfo().GetShape()[1], InvalidArgumentException);
414 
415  CHECK(3 == shape[2]);
416  CHECK(3 == tensor2.GetShape()[2]);
417  CHECK(3 == tensor2.GetInfo().GetShape()[2]);
418 
419  armnn::TensorShape shape_equal;
420  armnn::TensorShape shape_different;
421  shape_equal = shape;
422  CHECK(shape_equal == shape);
423  CHECK(shape_different != shape);
424 
425  // Specify dimension 1.
426  shape.SetDimensionSize(1, 5);
427  CHECK_MESSAGE(3 == shape.GetNumDimensions(), "Number of dimensions is " << shape.GetNumDimensions());
428  CHECK_MESSAGE(30 == shape.GetNumElements(), "Number of elements is " << shape.GetNumElements());
429  CHECK(true == shape.GetDimensionSpecificity(0));
430  CHECK(true == shape.GetDimensionSpecificity(1));
431  CHECK(true == shape.GetDimensionSpecificity(2));
432  CHECK(shape.AreAllDimensionsSpecified());
433  CHECK(shape.IsAtLeastOneDimensionSpecified());
434 }
435 
436 TEST_CASE("Tensor_emptyConstructors")
437 {
438  auto shape = armnn::TensorShape();
439  CHECK_MESSAGE( 0 == shape.GetNumDimensions(), "Number of dimensions is " << shape.GetNumDimensions());
440  CHECK_MESSAGE( 0 == shape.GetNumElements(), "Number of elements is " << shape.GetNumElements());
441  CHECK( armnn::Dimensionality::Specified == shape.GetDimensionality());
442  CHECK( shape.AreAllDimensionsSpecified());
443  CHECK_THROWS_AS( shape[0], InvalidArgumentException );
444 
445  auto tensor = armnn::Tensor();
446  CHECK_MESSAGE( 0 == tensor.GetNumDimensions(), "Number of dimensions is " << tensor.GetNumDimensions());
447  CHECK_MESSAGE( 0 == tensor.GetNumElements(), "Number of elements is " << tensor.GetNumElements());
448  CHECK_MESSAGE( 0 == tensor.GetShape().GetNumDimensions(), "Number of dimensions is " <<
449  tensor.GetShape().GetNumDimensions());
450  CHECK_MESSAGE( 0 == tensor.GetShape().GetNumElements(), "Number of dimensions is " <<
451  tensor.GetShape().GetNumElements());
452  CHECK( armnn::Dimensionality::Specified == tensor.GetShape().GetDimensionality());
453  CHECK( tensor.GetShape().AreAllDimensionsSpecified());
454  CHECK_THROWS_AS( tensor.GetShape()[0], InvalidArgumentException );
455 }
456 }
bool IsConstant() const
Definition: Tensor.cpp:509
const TensorShape & GetShape() const
Definition: Tensor.hpp:191
Optional< unsigned int > GetQuantizationDim() const
Definition: Tensor.cpp:494
virtual const char * what() const noexcept override
Definition: Exceptions.cpp:32
void SetShape(const TensorShape &newShape)
Definition: Tensor.hpp:193
A tensor defined by a TensorInfo (shape and data type) and a mutable backing store.
Definition: Tensor.hpp:319
TEST_CASE_FIXTURE(ClContextControlFixture, "CopyBetweenNeonAndGpu")
int32_t GetQuantizationOffset() const
Definition: Tensor.cpp:478
float GetQuantizationScale() const
Definition: Tensor.cpp:461
DataType GetDataType() const
Definition: Tensor.hpp:198
bool has_value() const noexcept
Definition: Optional.hpp:53
A tensor defined by a TensorInfo (shape and data type) and an immutable backing store.
Definition: Tensor.hpp:327
void SetQuantizationScale(float scale)
Definition: Tensor.cpp:473
const TensorInfo & GetInfo() const
Definition: Tensor.hpp:295
void SetDataType(DataType type)
Definition: Tensor.hpp:199
void SetQuantizationDim(const Optional< unsigned int > &quantizationDim)
Definition: Tensor.cpp:499
void SetConstant(const bool IsConstant=true)
Marks the data corresponding to this tensor info as constant.
Definition: Tensor.cpp:514
void SetQuantizationOffset(int32_t offset)
Definition: Tensor.cpp:489
void SetQuantizationScales(const std::vector< float > &scales)
Definition: Tensor.cpp:456
unsigned int GetNumDimensions() const
Definition: Tensor.hpp:195