ArmNN
 21.08
DeserializeComparison.cpp
Go to the documentation of this file.
1 //
2 // Copyright © 2019 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
8 
9 #include <QuantizeHelper.hpp>
10 #include <ResolveType.hpp>
11 
12 #include <string>
13 
14 TEST_SUITE("Deserializer_Comparison")
15 {
// Declares a fixture type named Simple<op><dtype>Fixture that derives from
// SimpleComparisonFixture and forwards the stringified data type and operation
// names (in that order) to its constructor.
//   op    : comparison operation token, e.g. Greater
//   dtype : armnn::DataType enumerator token, e.g. Float32
#define DECLARE_SIMPLE_COMPARISON_FIXTURE(op, dtype) \
struct Simple##op##dtype##Fixture : public SimpleComparisonFixture \
{ \
    Simple##op##dtype##Fixture() : SimpleComparisonFixture(#dtype, #op) {} \
};
22 
// Declares the Simple<op><dtype>Fixture type and a doctest case named "<op><dtype>".
// The case quantizes both s_TestData inputs to the element type of DataType::dtype
// (scale 1, offset 0), feeds them to "InputLayer0" / "InputLayer1" through RunTest
// and passes s_TestData.m_Output<op> as the reference data for "OutputLayer".
//   op    : comparison operation token, e.g. Greater
//   dtype : armnn::DataType enumerator token, e.g. Float32
#define DECLARE_SIMPLE_COMPARISON_TEST_CASE(op, dtype) \
DECLARE_SIMPLE_COMPARISON_FIXTURE(op, dtype) \
TEST_CASE_FIXTURE(Simple##op##dtype##Fixture, #op #dtype) \
{ \
    using InputT = armnn::ResolveType<armnn::DataType::dtype>; \
    constexpr float   scale  = 1.f; \
    constexpr int32_t offset = 0; \
    const auto lhs = armnnUtils::QuantizedVector<InputT>(s_TestData.m_InputData0, scale, offset); \
    const auto rhs = armnnUtils::QuantizedVector<InputT>(s_TestData.m_InputData1, scale, offset); \
    RunTest<4, armnn::DataType::dtype, armnn::DataType::Boolean>( \
        0, \
        {{ "InputLayer0", lhs }, { "InputLayer1", rhs }}, \
        {{ "OutputLayer", s_TestData.m_Output##op }}); \
}
36 
37 struct ComparisonFixture : public ParserFlatbuffersSerializeFixture
38 {
39  explicit ComparisonFixture(const std::string& inputShape0,
40  const std::string& inputShape1,
41  const std::string& outputShape,
42  const std::string& inputDataType,
43  const std::string& comparisonOperation)
44  {
45  m_JsonString = R"(
46  {
47  inputIds: [0, 1],
48  outputIds: [3],
49  layers: [
50  {
51  layer_type: "InputLayer",
52  layer: {
53  base: {
54  layerBindingId: 0,
55  base: {
56  index: 0,
57  layerName: "InputLayer0",
58  layerType: "Input",
59  inputSlots: [{
60  index: 0,
61  connection: { sourceLayerIndex:0, outputSlotIndex:0 },
62  }],
63  outputSlots: [{
64  index: 0,
65  tensorInfo: {
66  dimensions: )" + inputShape0 + R"(,
67  dataType: )" + inputDataType + R"(
68  },
69  }],
70  },
71  }
72  },
73  },
74  {
75  layer_type: "InputLayer",
76  layer: {
77  base: {
78  layerBindingId: 1,
79  base: {
80  index:1,
81  layerName: "InputLayer1",
82  layerType: "Input",
83  inputSlots: [{
84  index: 0,
85  connection: { sourceLayerIndex:0, outputSlotIndex:0 },
86  }],
87  outputSlots: [{
88  index: 0,
89  tensorInfo: {
90  dimensions: )" + inputShape1 + R"(,
91  dataType: )" + inputDataType + R"(
92  },
93  }],
94  },
95  }
96  },
97  },
98  {
99  layer_type: "ComparisonLayer",
100  layer: {
101  base: {
102  index:2,
103  layerName: "ComparisonLayer",
104  layerType: "Comparison",
105  inputSlots: [{
106  index: 0,
107  connection: { sourceLayerIndex:0, outputSlotIndex:0 },
108  },
109  {
110  index: 1,
111  connection: { sourceLayerIndex:1, outputSlotIndex:0 },
112  }],
113  outputSlots: [{
114  index: 0,
115  tensorInfo: {
116  dimensions: )" + outputShape + R"(,
117  dataType: Boolean
118  },
119  }],
120  },
121  descriptor: {
122  operation: )" + comparisonOperation + R"(
123  }
124  },
125  },
126  {
127  layer_type: "OutputLayer",
128  layer: {
129  base:{
130  layerBindingId: 0,
131  base: {
132  index: 3,
133  layerName: "OutputLayer",
134  layerType: "Output",
135  inputSlots: [{
136  index: 0,
137  connection: { sourceLayerIndex:2, outputSlotIndex:0 },
138  }],
139  outputSlots: [{
140  index: 0,
141  tensorInfo: {
142  dimensions: )" + outputShape + R"(,
143  dataType: Boolean
144  },
145  }],
146  }
147  }
148  },
149  }
150  ]
151  }
152  )";
153  Setup();
154  }
155 };
156 
// Shared inputs and expected Boolean outputs for the simple comparison tests.
// Each vector has 16 elements. Output i is 1 when the relation holds for
// (m_InputData0[i], m_InputData1[i]) and 0 otherwise.
// The inputs are grouped in runs of four so each quarter exercises a different relation:
//   [0, 4)   1 vs 1  -> equal
//   [4, 8)   5 vs 3  -> greater
//   [8, 12)  3 vs 5  -> less
//   [12, 16) 4 vs 4  -> equal
struct SimpleComparisonTestData
{
    std::vector<float> m_InputData0 {
        1.f, 1.f, 1.f, 1.f,   5.f, 5.f, 5.f, 5.f,
        3.f, 3.f, 3.f, 3.f,   4.f, 4.f, 4.f, 4.f
    };
    std::vector<float> m_InputData1 {
        1.f, 1.f, 1.f, 1.f,   3.f, 3.f, 3.f, 3.f,
        5.f, 5.f, 5.f, 5.f,   4.f, 4.f, 4.f, 4.f
    };

    std::vector<uint8_t> m_OutputEqual {
        1, 1, 1, 1,   0, 0, 0, 0,
        0, 0, 0, 0,   1, 1, 1, 1
    };
    std::vector<uint8_t> m_OutputGreater {
        0, 0, 0, 0,   1, 1, 1, 1,
        0, 0, 0, 0,   0, 0, 0, 0
    };
    std::vector<uint8_t> m_OutputGreaterOrEqual {
        1, 1, 1, 1,   1, 1, 1, 1,
        0, 0, 0, 0,   1, 1, 1, 1
    };
    std::vector<uint8_t> m_OutputLess {
        0, 0, 0, 0,   0, 0, 0, 0,
        1, 1, 1, 1,   0, 0, 0, 0
    };
    std::vector<uint8_t> m_OutputLessOrEqual {
        1, 1, 1, 1,   0, 0, 0, 0,
        1, 1, 1, 1,   1, 1, 1, 1
    };
    std::vector<uint8_t> m_OutputNotEqual {
        0, 0, 0, 0,   1, 1, 1, 1,
        1, 1, 1, 1,   0, 0, 0, 0
    };
};
220 
221 struct SimpleComparisonFixture : public ComparisonFixture
222 {
223  SimpleComparisonFixture(const std::string& inputDataType,
224  const std::string& comparisonOperation)
225  : ComparisonFixture("[ 2, 2, 2, 2 ]", // inputShape0
226  "[ 2, 2, 2, 2 ]", // inputShape1
227  "[ 2, 2, 2, 2 ]", // outputShape,
228  inputDataType,
229  comparisonOperation) {}
230 
231  static SimpleComparisonTestData s_TestData;
232 };
233 
234 SimpleComparisonTestData SimpleComparisonFixture::s_TestData;
235 
237 DECLARE_SIMPLE_COMPARISON_TEST_CASE(Greater, Float32)
238 DECLARE_SIMPLE_COMPARISON_TEST_CASE(GreaterOrEqual, Float32)
240 DECLARE_SIMPLE_COMPARISON_TEST_CASE(LessOrEqual, Float32)
241 DECLARE_SIMPLE_COMPARISON_TEST_CASE(NotEqual, Float32)
242 
243 
245 DECLARE_SIMPLE_COMPARISON_TEST_CASE(Equal, QuantisedAsymm8)
246 DECLARE_SIMPLE_COMPARISON_TEST_CASE(Greater, QuantisedAsymm8)
247 DECLARE_SIMPLE_COMPARISON_TEST_CASE(GreaterOrEqual, QuantisedAsymm8)
248 DECLARE_SIMPLE_COMPARISON_TEST_CASE(Less, QuantisedAsymm8)
249 DECLARE_SIMPLE_COMPARISON_TEST_CASE(LessOrEqual, QuantisedAsymm8)
250 DECLARE_SIMPLE_COMPARISON_TEST_CASE(NotEqual, QuantisedAsymm8)
252 
254 DECLARE_SIMPLE_COMPARISON_TEST_CASE(Greater, QAsymmU8)
255 DECLARE_SIMPLE_COMPARISON_TEST_CASE(GreaterOrEqual, QAsymmU8)
257 DECLARE_SIMPLE_COMPARISON_TEST_CASE(LessOrEqual, QAsymmU8)
258 DECLARE_SIMPLE_COMPARISON_TEST_CASE(NotEqual, QAsymmU8)
259 
260 }
#define ARMNN_NO_DEPRECATE_WARN_BEGIN
Definition: Deprecated.hpp:33
#define DECLARE_SIMPLE_COMPARISON_TEST_CASE(operation, dataType)
#define ARMNN_NO_DEPRECATE_WARN_END
Definition: Deprecated.hpp:34
TEST_SUITE("Deserializer_Comparison")