ArmNN
 21.11
RefPerAxisIteratorTests.cpp
Go to the documentation of this file.
1 //
2 // Copyright © 2021 Arm Ltd and Contributors. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
7 
8 #include <fmt/format.h>
9 
10 #include <doctest/doctest.h>
11 
12 #include <chrono>
13 
14 template<typename T>
15 void CompareVector(std::vector<T> vec1, std::vector<T> vec2)
16 {
17  CHECK(vec1.size() == vec2.size());
18 
19  bool mismatch = false;
20  for (uint32_t i = 0; i < vec1.size(); ++i)
21  {
22  if (vec1[i] != vec2[i])
23  {
24  MESSAGE(fmt::format("Vector value mismatch: index={} {} != {}",
25  i,
26  vec1[i],
27  vec2[i]));
28 
29  mismatch = true;
30  }
31  }
32 
33  if (mismatch)
34  {
35  FAIL("Error in CompareVector. Vectors don't match.");
36  }
37 }
38 
39 using namespace armnn;
40 
41 // Basically a per axis decoder but without any decoding/quantization
42 class MockPerAxisIterator : public PerAxisIterator<const int8_t, Decoder<int8_t>>
43 {
44 public:
45  MockPerAxisIterator(const int8_t* data, const armnn::TensorShape& tensorShape, const unsigned int axis)
46  : PerAxisIterator(data, tensorShape, axis), m_NumElements(tensorShape.GetNumElements())
47  {}
48 
49  int8_t Get() const override
50  {
51  return *m_Iterator;
52  }
53 
54  virtual std::vector<float> DecodeTensor(const TensorShape &tensorShape,
55  bool isDepthwise = false) override
56  {
57  IgnoreUnused(tensorShape, isDepthwise);
58  return std::vector<float>{};
59  };
60 
61  // Iterates over data using operator[] and returns vector
62  std::vector<int8_t> Loop()
63  {
64  std::vector<int8_t> vec;
65  for (uint32_t i = 0; i < m_NumElements; ++i)
66  {
67  this->operator[](i);
68  vec.emplace_back(Get());
69  }
70  return vec;
71  }
72 
73  unsigned int GetAxisIndex()
74  {
75  return m_AxisIndex;
76  }
77  unsigned int m_NumElements;
78 };
79 
80 TEST_SUITE("RefPerAxisIterator")
81 {
82 // Test Loop (Equivalent to DecodeTensor) and Axis = 0
83 TEST_CASE("PerAxisIteratorTest1")
84 {
85  std::vector<int8_t> input = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11};
86  TensorInfo tensorInfo ({3,1,2,2},DataType::QSymmS8);
87 
88  // test axis=0
89  std::vector<int8_t> expOutput = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11};
90  auto iterator = MockPerAxisIterator(input.data(), tensorInfo.GetShape(), 0);
91  std::vector<int8_t> output = iterator.Loop();
92  CompareVector(output, expOutput);
93 
94  // Set iterator to index and check if the axis index is correct
95  iterator[5];
96  CHECK(iterator.GetAxisIndex() == 1u);
97 
98  iterator[1];
99  CHECK(iterator.GetAxisIndex() == 0u);
100 
101  iterator[10];
102  CHECK(iterator.GetAxisIndex() == 2u);
103 }
104 
105 // Test Axis = 1
106 TEST_CASE("PerAxisIteratorTest2")
107 {
108  std::vector<int8_t> input = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11};
109  TensorInfo tensorInfo ({3,1,2,2},DataType::QSymmS8);
110 
111  // test axis=1
112  std::vector<int8_t> expOutput = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11};
113  auto iterator = MockPerAxisIterator(input.data(), tensorInfo.GetShape(), 1);
114  std::vector<int8_t> output = iterator.Loop();
115  CompareVector(output, expOutput);
116 
117  // Set iterator to index and check if the axis index is correct
118  iterator[5];
119  CHECK(iterator.GetAxisIndex() == 0u);
120 
121  iterator[1];
122  CHECK(iterator.GetAxisIndex() == 0u);
123 
124  iterator[10];
125  CHECK(iterator.GetAxisIndex() == 0u);
126 }
127 
128 // Test Axis = 2
129 TEST_CASE("PerAxisIteratorTest3")
130 {
131  std::vector<int8_t> input = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11};
132  TensorInfo tensorInfo ({3,1,2,2},DataType::QSymmS8);
133 
134  // test axis=2
135  std::vector<int8_t> expOutput = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11};
136  auto iterator = MockPerAxisIterator(input.data(), tensorInfo.GetShape(), 2);
137  std::vector<int8_t> output = iterator.Loop();
138  CompareVector(output, expOutput);
139 
140  // Set iterator to index and check if the axis index is correct
141  iterator[5];
142  CHECK(iterator.GetAxisIndex() == 0u);
143 
144  iterator[1];
145  CHECK(iterator.GetAxisIndex() == 0u);
146 
147  iterator[10];
148  CHECK(iterator.GetAxisIndex() == 1u);
149 }
150 
151 // Test Axis = 3
152 TEST_CASE("PerAxisIteratorTest4")
153 {
154  std::vector<int8_t> input = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11};
155  TensorInfo tensorInfo ({3,1,2,2},DataType::QSymmS8);
156 
157  // test axis=3
158  std::vector<int8_t> expOutput = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11};
159  auto iterator = MockPerAxisIterator(input.data(), tensorInfo.GetShape(), 3);
160  std::vector<int8_t> output = iterator.Loop();
161  CompareVector(output, expOutput);
162 
163  // Set iterator to index and check if the axis index is correct
164  iterator[5];
165  CHECK(iterator.GetAxisIndex() == 1u);
166 
167  iterator[1];
168  CHECK(iterator.GetAxisIndex() == 1u);
169 
170  iterator[10];
171  CHECK(iterator.GetAxisIndex() == 0u);
172 }
173 
174 // Test Axis = 1. Different tensor shape
175 TEST_CASE("PerAxisIteratorTest5")
176 {
177  using namespace armnn;
178  std::vector<int8_t> input =
179  {
180  0, 1, 2, 3,
181  4, 5, 6, 7,
182  8, 9, 10, 11,
183  12, 13, 14, 15
184  };
185 
186  std::vector<int8_t> expOutput =
187  {
188  0, 1, 2, 3,
189  4, 5, 6, 7,
190  8, 9, 10, 11,
191  12, 13, 14, 15
192  };
193 
194  TensorInfo tensorInfo ({2,2,2,2},DataType::QSymmS8);
195  auto iterator = MockPerAxisIterator(input.data(), tensorInfo.GetShape(), 1);
196  std::vector<int8_t> output = iterator.Loop();
197  CompareVector(output, expOutput);
198 
199  // Set iterator to index and check if the axis index is correct
200  iterator[5];
201  CHECK(iterator.GetAxisIndex() == 1u);
202 
203  iterator[1];
204  CHECK(iterator.GetAxisIndex() == 0u);
205 
206  iterator[10];
207  CHECK(iterator.GetAxisIndex() == 0u);
208 }
209 
210 // Test the increment and decrement operator
211 TEST_CASE("PerAxisIteratorTest7")
212 {
213  using namespace armnn;
214  std::vector<int8_t> input =
215  {
216  0, 1, 2, 3,
217  4, 5, 6, 7,
218  8, 9, 10, 11
219  };
220 
221  std::vector<int8_t> expOutput =
222  {
223  0, 1, 2, 3,
224  4, 5, 6, 7,
225  8, 9, 10, 11
226  };
227 
228  TensorInfo tensorInfo ({3,1,2,2},DataType::QSymmS8);
229  auto iterator = MockPerAxisIterator(input.data(), tensorInfo.GetShape(), 2);
230 
231  iterator += 3;
232  CHECK(iterator.Get() == expOutput[3]);
233  CHECK(iterator.GetAxisIndex() == 1u);
234 
235  iterator += 3;
236  CHECK(iterator.Get() == expOutput[6]);
237  CHECK(iterator.GetAxisIndex() == 1u);
238 
239  iterator -= 2;
240  CHECK(iterator.Get() == expOutput[4]);
241  CHECK(iterator.GetAxisIndex() == 0u);
242 
243  iterator -= 1;
244  CHECK(iterator.Get() == expOutput[3]);
245  CHECK(iterator.GetAxisIndex() == 1u);
246 }
247 
248 }
TEST_SUITE("TestConstTensorLayerVisitor")
unsigned int GetNumElements() const
Function that calculates the number of tensor elements by multiplying all dimension sizes which are specified...
Definition: Tensor.cpp:181
PerAxisIterator for per-axis quantization.
void CompareVector(std::vector< T > vec1, std::vector< T > vec2)
Copyright (c) 2021 ARM Limited and Contributors.
void IgnoreUnused(Ts &&...)