ArmNN
 22.08
Addition.cpp
Go to the documentation of this file.
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
8 
9 TEST_SUITE("OnnxParser_Addition")
10 {
11 struct AddMainFixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxParser::IOnnxParser>
12 {
13  AddMainFixture(const std::string& dataType)
14  {
15  m_Prototext = R"(
16  ir_version: 3
17  producer_name: "CNTK"
18  producer_version: "2.5.1"
19  domain: "ai.cntk"
20  model_version: 1
21  graph {
22  name: "CNTKGraph"
23  input {
24  name: "Input0"
25  type {
26  tensor_type {
27  elem_type: )" + dataType + R"(
28  shape {
29  dim {
30  dim_value: 1
31  }
32  dim {
33  dim_value: 1
34  }
35  dim {
36  dim_value: 2
37  }
38  dim {
39  dim_value: 2
40  }
41  }
42  }
43  }
44  }
45  input {
46  name: "Input1"
47  type {
48  tensor_type {
49  elem_type: )" + dataType + R"(
50  shape {
51  dim {
52  dim_value: 1
53  }
54  dim {
55  dim_value: 1
56  }
57  dim {
58  dim_value: 2
59  }
60  dim {
61  dim_value: 2
62  }
63  }
64  }
65  }
66  }
67  node {
68  input: "Input0"
69  input: "Input1"
70  output: "Output"
71  name: "addition"
72  op_type: "Add"
73  doc_string: ""
74  domain: ""
75  }
76  output {
77  name: "Output"
78  type {
79  tensor_type {
80  elem_type: 1
81  shape {
82  dim {
83  dim_value: 1
84  }
85  dim {
86  dim_value: 1
87  }
88  dim {
89  dim_value: 2
90  }
91  dim {
92  dim_value: 2
93  }
94  }
95  }
96  }
97  }
98  }
99  opset_import {
100  version: 7
101  })";
102  }
103 };
104 
105 struct AddValidFixture : AddMainFixture
106 {
107  AddValidFixture() : AddMainFixture("1") {
108  Setup();
109  }
110 };
111 
112 struct AddInvalidFixture : AddMainFixture
113 {
114  AddInvalidFixture() : AddMainFixture("6") { }
115 };
116 
117 struct AddValidBroadcastFixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxParser::IOnnxParser>
118 {
119  AddValidBroadcastFixture() {
120 
121  m_Prototext = R"(
122  ir_version: 3
123  producer_name: "CNTK"
124  producer_version: "2.5.1"
125  domain: "ai.cntk"
126  model_version: 1
127  graph {
128  name: "CNTKGraph"
129  input {
130  name: "Input0"
131  type {
132  tensor_type {
133  elem_type: 1
134  shape {
135  dim {
136  dim_value: 1
137  }
138  dim {
139  dim_value: 1
140  }
141  dim {
142  dim_value: 1
143  }
144  dim {
145  dim_value: 4
146  }
147  }
148  }
149  }
150  }
151  input {
152  name: "Input1"
153  type {
154  tensor_type {
155  elem_type: 1
156  shape {
157  dim {
158  dim_value: 4
159  }
160  }
161  }
162  }
163  }
164  node {
165  input: "Input0"
166  input: "Input1"
167  output: "Output"
168  name: "addition"
169  op_type: "Add"
170  doc_string: ""
171  domain: ""
172  }
173  output {
174  name: "Output"
175  type {
176  tensor_type {
177  elem_type: 1
178  shape {
179  dim {
180  dim_value: 1
181  }
182  dim {
183  dim_value: 1
184  }
185  dim {
186  dim_value: 1
187  }
188  dim {
189  dim_value: 4
190  }
191  }
192  }
193  }
194  }
195  }
196  opset_import {
197  version: 7
198  })";
199  Setup();
200  }
201 };
202 
203 struct AddInvalidBroadcastFixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxParser::IOnnxParser>
204 {
205  AddInvalidBroadcastFixture() {
206 
207  m_Prototext = R"(
208  ir_version: 3
209  producer_name: "CNTK"
210  producer_version: "2.5.1"
211  domain: "ai.cntk"
212  model_version: 1
213  graph {
214  name: "CNTKGraph"
215  input {
216  name: "Input0"
217  type {
218  tensor_type {
219  elem_type: 1
220  shape {
221  dim {
222  dim_value: 1
223  }
224  dim {
225  dim_value: 1
226  }
227  dim {
228  dim_value: 1
229  }
230  dim {
231  dim_value: 3
232  }
233  }
234  }
235  }
236  }
237  input {
238  name: "Input1"
239  type {
240  tensor_type {
241  elem_type: 1
242  shape {
243  dim {
244  dim_value: 4
245  }
246  }
247  }
248  }
249  }
250  node {
251  input: "Input0"
252  input: "Input1"
253  output: "Output"
254  name: "addition"
255  op_type: "Add"
256  doc_string: ""
257  domain: ""
258  }
259  output {
260  name: "Output"
261  type {
262  tensor_type {
263  elem_type: 1
264  shape {
265  dim {
266  dim_value: 1
267  }
268  dim {
269  dim_value: 1
270  }
271  dim {
272  dim_value: 1
273  }
274  dim {
275  dim_value: 4
276  }
277  }
278  }
279  }
280  }
281  }
282  opset_import {
283  version: 7
284  })";
285  }
286 };
287 
288 struct AddScalarFixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxParser::IOnnxParser>
289 {
290  AddScalarFixture(const std::string& dataType)
291  {
292  m_Prototext = R"(
293  ir_version: 3
294  producer_name: "CNTK"
295  producer_version: "2.5.1"
296  domain: "ai.cntk"
297  model_version: 1
298  graph {
299  name: "CNTKGraph"
300  input {
301  name: "Input0"
302  type {
303  tensor_type {
304  elem_type: )" + dataType + R"(
305  shape {
306  dim {
307  dim_value: 1
308  }
309  dim {
310  dim_value: 1
311  }
312  dim {
313  dim_value: 2
314  }
315  dim {
316  dim_value: 2
317  }
318  }
319  }
320  }
321  }
322  input {
323  name: "Input1"
324  type {
325  tensor_type {
326  elem_type: )" + dataType + R"(
327  shape {
328  dim {
329  dim_value: 1
330  }
331  }
332  }
333  }
334  }
335  node {
336  input: "Input0"
337  input: "Input1"
338  output: "Output"
339  name: "addition"
340  op_type: "Add"
341  doc_string: ""
342  domain: ""
343  }
344  output {
345  name: "Output"
346  type {
347  tensor_type {
348  elem_type: 1
349  shape {
350  dim {
351  dim_value: 1
352  }
353  dim {
354  dim_value: 1
355  }
356  dim {
357  dim_value: 2
358  }
359  dim {
360  dim_value: 2
361  }
362  }
363  }
364  }
365  }
366  }
367  opset_import {
368  version: 7
369  })";
370  }
371 };
372 
373 struct AddValidScalarFixture : AddScalarFixture
374 {
375  AddValidScalarFixture() : AddScalarFixture("1") {
376  Setup();
377  }
378 };
379 
380 struct AddInvalidScalarFixture : AddScalarFixture
381 {
382  AddInvalidScalarFixture() : AddScalarFixture("6") { }
383 };
384 
385 TEST_CASE_FIXTURE(AddValidFixture, "ValidAddTest")
386 {
387  RunTest<4>({{"Input0", {1.0f, 2.0f, -3.0f, -4.0f}},
388  {"Input1", {1.0f, 2.0f, 3.0, 4.0f}}}, {{"Output", {2.0, 4.0, 0, 0.0}}});
389 }
390 
391 TEST_CASE_FIXTURE(AddInvalidFixture, "IncorrectDataTypeAdd")
392 {
393  CHECK_THROWS_AS(Setup(), armnn::ParseException);
394 }
395 
396 TEST_CASE_FIXTURE(AddInvalidBroadcastFixture, "InvalidBroadcastAdd")
397 {
398  CHECK_THROWS_AS(Setup(), armnn::ParseException);
399 }
400 
401 TEST_CASE_FIXTURE(AddValidBroadcastFixture, "ValidBroadcastAdd")
402 {
403  RunTest<4>({{"Input0", {1.0f, 2.0f, -3.0f, -4.0f}},
404  {"Input1", {1.0f, 2.0f, 3.0, 4.0f}}}, {{"Output", {2.0, 4.0, 0, 0.0}}});
405 }
406 
407 TEST_CASE_FIXTURE(AddValidScalarFixture, "ValidAddScalarTest")
408 {
409  RunTest<4>({{"Input0", {1.0f, 2.0f, -3.0f, -4.0f}},
410  {"Input1", {-8.0f}}}, {{"Output", {-7.0, -6.0, -11.0, -12.0}}});
411 }
412 
413 TEST_CASE_FIXTURE(AddInvalidScalarFixture, "IncorrectDataTypeAddScalar")
414 {
415  CHECK_THROWS_AS(Setup(), armnn::ParseException);
416 }
417 
418 }
TEST_SUITE("OnnxParser_Addition")
Definition: Addition.cpp:9
TEST_CASE_FIXTURE(ClContextControlFixture, "CopyBetweenNeonAndGpu")