ArmNN
 21.02
Addition.cpp
Go to the documentation of this file.
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
6 #include <boost/test/unit_test.hpp>
9 
10 BOOST_AUTO_TEST_SUITE(OnnxParser)
11 
12 struct AddMainFixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxParser::IOnnxParser>
13 {
14  AddMainFixture(const std::string& dataType)
15  {
16  m_Prototext = R"(
17  ir_version: 3
18  producer_name: "CNTK"
19  producer_version: "2.5.1"
20  domain: "ai.cntk"
21  model_version: 1
22  graph {
23  name: "CNTKGraph"
24  input {
25  name: "Input0"
26  type {
27  tensor_type {
28  elem_type: )" + dataType + R"(
29  shape {
30  dim {
31  dim_value: 1
32  }
33  dim {
34  dim_value: 1
35  }
36  dim {
37  dim_value: 2
38  }
39  dim {
40  dim_value: 2
41  }
42  }
43  }
44  }
45  }
46  input {
47  name: "Input1"
48  type {
49  tensor_type {
50  elem_type: )" + dataType + R"(
51  shape {
52  dim {
53  dim_value: 1
54  }
55  dim {
56  dim_value: 1
57  }
58  dim {
59  dim_value: 2
60  }
61  dim {
62  dim_value: 2
63  }
64  }
65  }
66  }
67  }
68  node {
69  input: "Input0"
70  input: "Input1"
71  output: "Output"
72  name: "addition"
73  op_type: "Add"
74  doc_string: ""
75  domain: ""
76  }
77  output {
78  name: "Output"
79  type {
80  tensor_type {
81  elem_type: 1
82  shape {
83  dim {
84  dim_value: 1
85  }
86  dim {
87  dim_value: 1
88  }
89  dim {
90  dim_value: 2
91  }
92  dim {
93  dim_value: 2
94  }
95  }
96  }
97  }
98  }
99  }
100  opset_import {
101  version: 7
102  })";
103  }
104 };
105 
106 struct AddValidFixture : AddMainFixture
107 {
108  AddValidFixture() : AddMainFixture("1") {
109  Setup();
110  }
111 };
112 
113 struct AddInvalidFixture : AddMainFixture
114 {
115  AddInvalidFixture() : AddMainFixture("6") { }
116 };
117 
118 struct AddValidBroadcastFixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxParser::IOnnxParser>
119 {
120  AddValidBroadcastFixture() {
121 
122  m_Prototext = R"(
123  ir_version: 3
124  producer_name: "CNTK"
125  producer_version: "2.5.1"
126  domain: "ai.cntk"
127  model_version: 1
128  graph {
129  name: "CNTKGraph"
130  input {
131  name: "Input0"
132  type {
133  tensor_type {
134  elem_type: 1
135  shape {
136  dim {
137  dim_value: 1
138  }
139  dim {
140  dim_value: 1
141  }
142  dim {
143  dim_value: 1
144  }
145  dim {
146  dim_value: 4
147  }
148  }
149  }
150  }
151  }
152  input {
153  name: "Input1"
154  type {
155  tensor_type {
156  elem_type: 1
157  shape {
158  dim {
159  dim_value: 4
160  }
161  }
162  }
163  }
164  }
165  node {
166  input: "Input0"
167  input: "Input1"
168  output: "Output"
169  name: "addition"
170  op_type: "Add"
171  doc_string: ""
172  domain: ""
173  }
174  output {
175  name: "Output"
176  type {
177  tensor_type {
178  elem_type: 1
179  shape {
180  dim {
181  dim_value: 1
182  }
183  dim {
184  dim_value: 1
185  }
186  dim {
187  dim_value: 1
188  }
189  dim {
190  dim_value: 4
191  }
192  }
193  }
194  }
195  }
196  }
197  opset_import {
198  version: 7
199  })";
200  Setup();
201  }
202 };
203 
204 struct AddInvalidBroadcastFixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxParser::IOnnxParser>
205 {
206  AddInvalidBroadcastFixture() {
207 
208  m_Prototext = R"(
209  ir_version: 3
210  producer_name: "CNTK"
211  producer_version: "2.5.1"
212  domain: "ai.cntk"
213  model_version: 1
214  graph {
215  name: "CNTKGraph"
216  input {
217  name: "Input0"
218  type {
219  tensor_type {
220  elem_type: 1
221  shape {
222  dim {
223  dim_value: 1
224  }
225  dim {
226  dim_value: 1
227  }
228  dim {
229  dim_value: 1
230  }
231  dim {
232  dim_value: 3
233  }
234  }
235  }
236  }
237  }
238  input {
239  name: "Input1"
240  type {
241  tensor_type {
242  elem_type: 1
243  shape {
244  dim {
245  dim_value: 4
246  }
247  }
248  }
249  }
250  }
251  node {
252  input: "Input0"
253  input: "Input1"
254  output: "Output"
255  name: "addition"
256  op_type: "Add"
257  doc_string: ""
258  domain: ""
259  }
260  output {
261  name: "Output"
262  type {
263  tensor_type {
264  elem_type: 1
265  shape {
266  dim {
267  dim_value: 1
268  }
269  dim {
270  dim_value: 1
271  }
272  dim {
273  dim_value: 1
274  }
275  dim {
276  dim_value: 4
277  }
278  }
279  }
280  }
281  }
282  }
283  opset_import {
284  version: 7
285  })";
286  }
287 };
288 
289 struct AddScalarFixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxParser::IOnnxParser>
290 {
291  AddScalarFixture(const std::string& dataType)
292  {
293  m_Prototext = R"(
294  ir_version: 3
295  producer_name: "CNTK"
296  producer_version: "2.5.1"
297  domain: "ai.cntk"
298  model_version: 1
299  graph {
300  name: "CNTKGraph"
301  input {
302  name: "Input0"
303  type {
304  tensor_type {
305  elem_type: )" + dataType + R"(
306  shape {
307  dim {
308  dim_value: 1
309  }
310  dim {
311  dim_value: 1
312  }
313  dim {
314  dim_value: 2
315  }
316  dim {
317  dim_value: 2
318  }
319  }
320  }
321  }
322  }
323  input {
324  name: "Input1"
325  type {
326  tensor_type {
327  elem_type: )" + dataType + R"(
328  shape {
329  dim {
330  dim_value: 1
331  }
332  }
333  }
334  }
335  }
336  node {
337  input: "Input0"
338  input: "Input1"
339  output: "Output"
340  name: "addition"
341  op_type: "Add"
342  doc_string: ""
343  domain: ""
344  }
345  output {
346  name: "Output"
347  type {
348  tensor_type {
349  elem_type: 1
350  shape {
351  dim {
352  dim_value: 1
353  }
354  dim {
355  dim_value: 1
356  }
357  dim {
358  dim_value: 2
359  }
360  dim {
361  dim_value: 2
362  }
363  }
364  }
365  }
366  }
367  }
368  opset_import {
369  version: 7
370  })";
371  }
372 };
373 
374 struct AddValidScalarFixture : AddScalarFixture
375 {
376  AddValidScalarFixture() : AddScalarFixture("1") {
377  Setup();
378  }
379 };
380 
381 struct AddInvalidScalarFixture : AddScalarFixture
382 {
383  AddInvalidScalarFixture() : AddScalarFixture("6") { }
384 };
385 
386 BOOST_FIXTURE_TEST_CASE(ValidAddTest, AddValidFixture)
387 {
388  RunTest<4>({{"Input0", {1.0f, 2.0f, -3.0f, -4.0f}},
389  {"Input1", {1.0f, 2.0f, 3.0, 4.0f}}}, {{"Output", {2.0, 4.0, 0, 0.0}}});
390 }
391 
392 BOOST_FIXTURE_TEST_CASE(IncorrectDataTypeAdd, AddInvalidFixture)
393 {
394  BOOST_CHECK_THROW(Setup(), armnn::ParseException);
395 }
396 
397 BOOST_FIXTURE_TEST_CASE(InvalidBroadcastAdd, AddInvalidBroadcastFixture)
398 {
399  BOOST_CHECK_THROW(Setup(), armnn::ParseException);
400 }
401 
402 BOOST_FIXTURE_TEST_CASE(ValidBroadcastAdd, AddValidBroadcastFixture)
403 {
404  RunTest<4>({{"Input0", {1.0f, 2.0f, -3.0f, -4.0f}},
405  {"Input1", {1.0f, 2.0f, 3.0, 4.0f}}}, {{"Output", {2.0, 4.0, 0, 0.0}}});
406 }
407 
408 BOOST_FIXTURE_TEST_CASE(ValidAddScalarTest, AddValidScalarFixture)
409 {
410  RunTest<4>({{"Input0", {1.0f, 2.0f, -3.0f, -4.0f}},
411  {"Input1", {-8.0f}}}, {{"Output", {-7.0, -6.0, -11.0, -12.0}}});
412 }
413 
414 BOOST_FIXTURE_TEST_CASE(IncorrectDataTypeAddScalar, AddInvalidScalarFixture)
415 {
416  BOOST_CHECK_THROW(Setup(), armnn::ParseException);
417 }
418 
BOOST_AUTO_TEST_SUITE(TensorflowLiteParser)
BOOST_FIXTURE_TEST_CASE(ValidAddTest, AddValidFixture)
Definition: Addition.cpp:386
BOOST_AUTO_TEST_SUITE_END()