ArmNN
 21.08
RefPerAxisIteratorTests.cpp File Reference
#include <reference/workloads/Decoders.hpp>
#include <fmt/format.h>
#include <doctest/doctest.h>
#include <chrono>

Go to the source code of this file.

Functions

template<typename T >
void CompareVector (std::vector< T > vec1, std::vector< T > vec2)
 
 TEST_SUITE ("RefPerAxisIterator")
 

Function Documentation

◆ CompareVector()

void CompareVector ( std::vector< T >  vec1,
std::vector< T >  vec2 
)

Definition at line 15 of file RefPerAxisIteratorTests.cpp.

References TensorShape::GetNumElements(), and armnn::IgnoreUnused().

Referenced by TEST_SUITE().

16 { // Compares vec1 and vec2 element by element; logs every mismatch, then fails the test once at the end.
17  REQUIRE(vec1.size() == vec2.size()); // fatal on purpose: the loop below indexes vec2 with vec1's bounds
18 
19  bool mismatch = false;
20  for (uint32_t i = 0; i < vec1.size(); ++i)
21  {
22  if (vec1[i] != vec2[i])
23  {
24  MESSAGE(fmt::format("Vector value mismatch: index={} {} != {}",
25  i,
26  vec1[i],
27  vec2[i]));
28 
29  mismatch = true; // keep scanning so every differing index is reported
30  }
31  }
32 
33  if (mismatch)
34  {
35  FAIL("Error in CompareVector. Vectors don't match.");
36  }
37 }

◆ TEST_SUITE()

TEST_SUITE ( "RefPerAxisIterator"  )

Definition at line 80 of file RefPerAxisIteratorTests.cpp.

References CompareVector(), and armnn::QSymmS8.

81 {
82 // Test Loop (Equivalent to DecodeTensor) and the reported axis index for Axis = 0
83 TEST_CASE("PerAxisIteratorTest1")
84 {
85  std::vector<int8_t> input = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11};
86  TensorInfo tensorInfo ({3,1,2,2},DataType::QSymmS8);
87 
88  // test axis=0: Loop() is expected to hand back the input values unchanged and in order
89  std::vector<int8_t> expOutput = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11};
90  auto iterator = MockPerAxisIterator(input.data(), tensorInfo.GetShape(), 0);
91  std::vector<int8_t> output = iterator.Loop();
92  CompareVector(output, expOutput);
93 
94  // Set iterator to index and check if the axis index is correct (axis 0 has size 3, 1*2*2 = 4 elements per step)
95  iterator[5];
96  CHECK(iterator.GetAxisIndex() == 1u); // expected value equals (5 / 4) % 3
97 
98  iterator[1];
99  CHECK(iterator.GetAxisIndex() == 0u); // expected value equals (1 / 4) % 3
100 
101  iterator[10];
102  CHECK(iterator.GetAxisIndex() == 2u); // expected value equals (10 / 4) % 3
103 }
104 
105 // Test Axis = 1 (an axis of size 1 in shape {3,1,2,2})
106 TEST_CASE("PerAxisIteratorTest2")
107 {
108  std::vector<int8_t> input = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11};
109  TensorInfo tensorInfo ({3,1,2,2},DataType::QSymmS8);
110 
111  // test axis=1: Loop() is expected to hand back the input values unchanged and in order
112  std::vector<int8_t> expOutput = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11};
113  auto iterator = MockPerAxisIterator(input.data(), tensorInfo.GetShape(), 1);
114  std::vector<int8_t> output = iterator.Loop();
115  CompareVector(output, expOutput);
116 
117  // Set iterator to index and check if the axis index is correct (axis 1 has size 1, so 0 is expected everywhere)
118  iterator[5];
119  CHECK(iterator.GetAxisIndex() == 0u);
120 
121  iterator[1];
122  CHECK(iterator.GetAxisIndex() == 0u);
123 
124  iterator[10];
125  CHECK(iterator.GetAxisIndex() == 0u);
126 }
127 
128 // Test Axis = 2 (size 2 in shape {3,1,2,2})
129 TEST_CASE("PerAxisIteratorTest3")
130 {
131  std::vector<int8_t> input = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11};
132  TensorInfo tensorInfo ({3,1,2,2},DataType::QSymmS8);
133 
134  // test axis=2: Loop() is expected to hand back the input values unchanged and in order
135  std::vector<int8_t> expOutput = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11};
136  auto iterator = MockPerAxisIterator(input.data(), tensorInfo.GetShape(), 2);
137  std::vector<int8_t> output = iterator.Loop();
138  CompareVector(output, expOutput);
139 
140  // Set iterator to index and check if the axis index is correct (axis 2 has size 2, 2 elements per step)
141  iterator[5];
142  CHECK(iterator.GetAxisIndex() == 0u); // expected value equals (5 / 2) % 2
143 
144  iterator[1];
145  CHECK(iterator.GetAxisIndex() == 0u); // expected value equals (1 / 2) % 2
146 
147  iterator[10];
148  CHECK(iterator.GetAxisIndex() == 1u); // expected value equals (10 / 2) % 2
149 }
150 
151 // Test Axis = 3 (the innermost axis, size 2 in shape {3,1,2,2})
152 TEST_CASE("PerAxisIteratorTest4")
153 {
154  std::vector<int8_t> input = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11};
155  TensorInfo tensorInfo ({3,1,2,2},DataType::QSymmS8);
156 
157  // test axis=3: Loop() is expected to hand back the input values unchanged and in order
158  std::vector<int8_t> expOutput = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11};
159  auto iterator = MockPerAxisIterator(input.data(), tensorInfo.GetShape(), 3);
160  std::vector<int8_t> output = iterator.Loop();
161  CompareVector(output, expOutput);
162 
163  // Set iterator to index and check if the axis index is correct (axis 3 has size 2, 1 element per step)
164  iterator[5];
165  CHECK(iterator.GetAxisIndex() == 1u); // expected value equals 5 % 2
166 
167  iterator[1];
168  CHECK(iterator.GetAxisIndex() == 1u); // expected value equals 1 % 2
169 
170  iterator[10];
171  CHECK(iterator.GetAxisIndex() == 0u); // expected value equals 10 % 2
172 }
173 
174 // Test Axis = 1. Different tensor shape ({2,2,2,2}, 16 elements)
175 TEST_CASE("PerAxisIteratorTest5")
176 {
177  using namespace armnn;
178  std::vector<int8_t> input =
179  {
180  0, 1, 2, 3,
181  4, 5, 6, 7,
182  8, 9, 10, 11,
183  12, 13, 14, 15
184  };
185 
186  std::vector<int8_t> expOutput = // identical to input: Loop() is expected to preserve values and order
187  {
188  0, 1, 2, 3,
189  4, 5, 6, 7,
190  8, 9, 10, 11,
191  12, 13, 14, 15
192  };
193 
194  TensorInfo tensorInfo ({2,2,2,2},DataType::QSymmS8);
195  auto iterator = MockPerAxisIterator(input.data(), tensorInfo.GetShape(), 1);
196  std::vector<int8_t> output = iterator.Loop();
197  CompareVector(output, expOutput);
198 
199  // Set iterator to index and check if the axis index is correct (axis 1 has size 2, 2*2 = 4 elements per step)
200  iterator[5];
201  CHECK(iterator.GetAxisIndex() == 1u); // expected value equals (5 / 4) % 2
202 
203  iterator[1];
204  CHECK(iterator.GetAxisIndex() == 0u); // expected value equals (1 / 4) % 2
205 
206  iterator[10];
207  CHECK(iterator.GetAxisIndex() == 0u); // expected value equals (10 / 4) % 2
208 }
209 
210 // Test the compound increment (+=) and decrement (-=) operators
211 TEST_CASE("PerAxisIteratorTest7")
212 {
213  using namespace armnn;
214  std::vector<int8_t> input =
215  {
216  0, 1, 2, 3,
217  4, 5, 6, 7,
218  8, 9, 10, 11
219  };
220 
221  std::vector<int8_t> expOutput = // identical to input, so expOutput[k] is the value expected at flat position k
222  {
223  0, 1, 2, 3,
224  4, 5, 6, 7,
225  8, 9, 10, 11
226  };
227 
228  TensorInfo tensorInfo ({3,1,2,2},DataType::QSymmS8);
229  auto iterator = MockPerAxisIterator(input.data(), tensorInfo.GetShape(), 2);
230  // Axis 2 has size 2 with 2 elements per step; the expected axis indices below equal (position / 2) % 2
231  iterator += 3;
232  CHECK(iterator.Get() == expOutput[3]);
233  CHECK(iterator.GetAxisIndex() == 1u); // (3 / 2) % 2
234 
235  iterator += 3;
236  CHECK(iterator.Get() == expOutput[6]);
237  CHECK(iterator.GetAxisIndex() == 1u); // (6 / 2) % 2
238 
239  iterator -= 2;
240  CHECK(iterator.Get() == expOutput[4]);
241  CHECK(iterator.GetAxisIndex() == 0u); // (4 / 2) % 2
242 
243  iterator -= 1;
244  CHECK(iterator.Get() == expOutput[3]);
245  CHECK(iterator.GetAxisIndex() == 1u); // (3 / 2) % 2
246 }
247 
248 }
void CompareVector(std::vector< T > vec1, std::vector< T > vec2)
Copyright (c) 2021 ARM Limited and Contributors.