// ArmNN 20.02
// DeserializeComparison.cpp
1 //
2 // Copyright © 2019 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
7 #include "../Deserializer.hpp"
8 
9 #include <QuantizeHelper.hpp>
10 #include <ResolveType.hpp>
11 
12 #include <boost/test/unit_test.hpp>
13 
14 #include <string>
15 
// Registers every test case in this file under the Boost.Test suite "Deserializer"
// (closed by BOOST_AUTO_TEST_SUITE_END() at the bottom of the file).
BOOST_AUTO_TEST_SUITE(Deserializer)

/// Declares `Simple<operation><dataType>Fixture`, a SimpleComparisonFixture configured with the
/// spelling of the two macro arguments: `dataType` is passed as the input data type token and
/// `operation` as the comparison operation token (both via preprocessor stringification).
#define DECLARE_SIMPLE_COMPARISON_FIXTURE(operation, dataType)        \
struct Simple##operation##dataType##Fixture : SimpleComparisonFixture \
{                                                                     \
    Simple##operation##dataType##Fixture()                            \
        : SimpleComparisonFixture{ #dataType, #operation }            \
    {                                                                 \
    }                                                                 \
};

/// Declares the fixture for (operation, dataType) and a Boost test case named `<operation><dataType>`.
/// The test converts s_TestData's two float inputs with armnnUtils::QuantizedVector (scale 1,
/// offset 0) into the input type, runs the network via RunTest, and compares the Boolean output
/// with s_TestData.m_Output<operation>.
#define DECLARE_SIMPLE_COMPARISON_TEST_CASE(operation, dataType)                                        \
DECLARE_SIMPLE_COMPARISON_FIXTURE(operation, dataType)                                                  \
BOOST_FIXTURE_TEST_CASE(operation##dataType, Simple##operation##dataType##Fixture)                      \
{                                                                                                       \
    using InputValueType = armnn::ResolveType<armnn::DataType::dataType>;                               \
                                                                                                        \
    /* Identity quantisation parameters (scale 1, offset 0). */                                         \
    constexpr float   scale  = 1.f;                                                                     \
    constexpr int32_t offset = 0;                                                                       \
                                                                                                        \
    const auto input0 = armnnUtils::QuantizedVector<InputValueType>(s_TestData.m_InputData0, scale, offset); \
    const auto input1 = armnnUtils::QuantizedVector<InputValueType>(s_TestData.m_InputData1, scale, offset); \
                                                                                                        \
    RunTest<4, armnn::DataType::dataType, armnn::DataType::Boolean>(                                    \
        0,                                                                                              \
        {{ "InputLayer0", input0 }, { "InputLayer1", input1 }},                                         \
        {{ "OutputLayer", s_TestData.m_Output##operation }});                                           \
}

/// Test fixture that describes a four-layer network in the deserializer's JSON text form and builds it:
///   InputLayer0 (index 0, binding id 0) and InputLayer1 (index 1, binding id 1)
///     -> ComparisonLayer (index 2, Boolean output tensor)
///     -> OutputLayer (index 3, binding id 0, Boolean tensor).
/// The constructor arguments are concatenated into the JSON verbatim (no quoting or validation),
/// so each must already be a valid token/value for the serialized schema.
struct ComparisonFixture : public ParserFlatbuffersSerializeFixture
{
    /// @param inputShape0         Dimensions of InputLayer0's output tensor, e.g. "[ 2, 2, 2, 2 ]".
    /// @param inputShape1         Dimensions of InputLayer1's output tensor.
    /// @param outputShape         Dimensions of both the ComparisonLayer and OutputLayer tensors.
    /// @param inputDataType       Data type token applied to both input tensors (e.g. "Float32").
    /// @param comparisonOperation Operation token written into the ComparisonLayer descriptor (e.g. "Equal").
    explicit ComparisonFixture(const std::string& inputShape0,
                               const std::string& inputShape1,
                               const std::string& outputShape,
                               const std::string& inputDataType,
                               const std::string& comparisonOperation)
    {
        // Everything between R"( and )" is JSON text, not C++: comments must stay outside it.
        m_JsonString = R"(
            {
                inputIds: [0, 1],
                outputIds: [3],
                layers: [
                    {
                        layer_type: "InputLayer",
                        layer: {
                            base: {
                                layerBindingId: 0,
                                base: {
                                    index: 0,
                                    layerName: "InputLayer0",
                                    layerType: "Input",
                                    inputSlots: [{
                                        index: 0,
                                        connection: { sourceLayerIndex:0, outputSlotIndex:0 },
                                    }],
                                    outputSlots: [{
                                        index: 0,
                                        tensorInfo: {
                                            dimensions: )" + inputShape0 + R"(,
                                            dataType: )" + inputDataType + R"(
                                        },
                                    }],
                                },
                            }
                        },
                    },
                    {
                        layer_type: "InputLayer",
                        layer: {
                            base: {
                                layerBindingId: 1,
                                base: {
                                    index:1,
                                    layerName: "InputLayer1",
                                    layerType: "Input",
                                    inputSlots: [{
                                        index: 0,
                                        connection: { sourceLayerIndex:0, outputSlotIndex:0 },
                                    }],
                                    outputSlots: [{
                                        index: 0,
                                        tensorInfo: {
                                            dimensions: )" + inputShape1 + R"(,
                                            dataType: )" + inputDataType + R"(
                                        },
                                    }],
                                },
                            }
                        },
                    },
                    {
                        layer_type: "ComparisonLayer",
                        layer: {
                            base: {
                                index:2,
                                layerName: "ComparisonLayer",
                                layerType: "Comparison",
                                inputSlots: [{
                                    index: 0,
                                    connection: { sourceLayerIndex:0, outputSlotIndex:0 },
                                },
                                {
                                    index: 1,
                                    connection: { sourceLayerIndex:1, outputSlotIndex:0 },
                                }],
                                outputSlots: [{
                                    index: 0,
                                    tensorInfo: {
                                        dimensions: )" + outputShape + R"(,
                                        dataType: Boolean
                                    },
                                }],
                            },
                            descriptor: {
                                operation: )" + comparisonOperation + R"(
                            }
                        },
                    },
                    {
                        layer_type: "OutputLayer",
                        layer: {
                            base:{
                                layerBindingId: 0,
                                base: {
                                    index: 3,
                                    layerName: "OutputLayer",
                                    layerType: "Output",
                                    inputSlots: [{
                                        index: 0,
                                        connection: { sourceLayerIndex:2, outputSlotIndex:0 },
                                    }],
                                    outputSlots: [{
                                        index: 0,
                                        tensorInfo: {
                                            dimensions: )" + outputShape + R"(,
                                            dataType: Boolean
                                        },
                                    }],
                                }
                            }
                        },
                    }
                ]
            }
        )";
        // NOTE(review): Setup() is inherited from ParserFlatbuffersSerializeFixture; presumably it
        // turns m_JsonString into a deserialized network ready for RunTest - confirm there.
        Setup();
    }
};

/// Shared inputs and expected results for the element-wise comparison tests.
/// Both inputs hold 16 floats arranged as four runs of four values; the runs pair the inputs as
/// (1 vs 1), (5 vs 3), (3 vs 5) and (4 vs 4), i.e. equal, greater, less and equal again.
/// Each m_Output<Operation> vector holds the expected Boolean result per element (1 = true, 0 = false).
struct SimpleComparisonTestData
{
    std::vector<float> m_InputData0{
        1.f, 1.f, 1.f, 1.f, 5.f, 5.f, 5.f, 5.f,
        3.f, 3.f, 3.f, 3.f, 4.f, 4.f, 4.f, 4.f };

    std::vector<float> m_InputData1{
        1.f, 1.f, 1.f, 1.f, 3.f, 3.f, 3.f, 3.f,
        5.f, 5.f, 5.f, 5.f, 4.f, 4.f, 4.f, 4.f };

    std::vector<uint8_t> m_OutputEqual{
        1, 1, 1, 1, 0, 0, 0, 0,
        0, 0, 0, 0, 1, 1, 1, 1 };

    std::vector<uint8_t> m_OutputGreater{
        0, 0, 0, 0, 1, 1, 1, 1,
        0, 0, 0, 0, 0, 0, 0, 0 };

    std::vector<uint8_t> m_OutputGreaterOrEqual{
        1, 1, 1, 1, 1, 1, 1, 1,
        0, 0, 0, 0, 1, 1, 1, 1 };

    std::vector<uint8_t> m_OutputLess{
        0, 0, 0, 0, 0, 0, 0, 0,
        1, 1, 1, 1, 0, 0, 0, 0 };

    std::vector<uint8_t> m_OutputLessOrEqual{
        1, 1, 1, 1, 0, 0, 0, 0,
        1, 1, 1, 1, 1, 1, 1, 1 };

    std::vector<uint8_t> m_OutputNotEqual{
        0, 0, 0, 0, 1, 1, 1, 1,
        1, 1, 1, 1, 0, 0, 0, 0 };
};

223 struct SimpleComparisonFixture : public ComparisonFixture
224 {
225  SimpleComparisonFixture(const std::string& inputDataType,
226  const std::string& comparisonOperation)
227  : ComparisonFixture("[ 2, 2, 2, 2 ]", // inputShape0
228  "[ 2, 2, 2, 2 ]", // inputShape1
229  "[ 2, 2, 2, 2 ]", // outputShape,
230  inputDataType,
231  comparisonOperation) {}
232 
233  static SimpleComparisonTestData s_TestData;
234 };
235 
236 SimpleComparisonTestData SimpleComparisonFixture::s_TestData;
237 
240 DECLARE_SIMPLE_COMPARISON_TEST_CASE(GreaterOrEqual, Float32)
244 
245 
247 DECLARE_SIMPLE_COMPARISON_TEST_CASE(Equal, QuantisedAsymm8)
248 DECLARE_SIMPLE_COMPARISON_TEST_CASE(Greater, QuantisedAsymm8)
249 DECLARE_SIMPLE_COMPARISON_TEST_CASE(GreaterOrEqual, QuantisedAsymm8)
251 DECLARE_SIMPLE_COMPARISON_TEST_CASE(LessOrEqual, QuantisedAsymm8)
252 DECLARE_SIMPLE_COMPARISON_TEST_CASE(NotEqual, QuantisedAsymm8)
254 
257 DECLARE_SIMPLE_COMPARISON_TEST_CASE(GreaterOrEqual, QAsymmU8)
261 
BOOST_AUTO_TEST_SUITE(TensorflowLiteParser)
#define ARMNN_NO_DEPRECATE_WARN_BEGIN
Definition: Deprecated.hpp:33
#define DECLARE_SIMPLE_COMPARISON_TEST_CASE(operation, dataType)
#define ARMNN_NO_DEPRECATE_WARN_END
Definition: Deprecated.hpp:34
BOOST_AUTO_TEST_SUITE_END()