ArmNN
 21.02
Flatten.cpp
Go to the documentation of this file.
1 //
2 // Copyright © 2020 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
6 #include <boost/test/unit_test.hpp>
9 
10 BOOST_AUTO_TEST_SUITE(OnnxParser)
11 
12 struct FlattenMainFixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxParser::IOnnxParser>
13 {
14  FlattenMainFixture(const std::string& dataType)
15  {
16  m_Prototext = R"(
17  ir_version: 3
18  producer_name: "CNTK"
19  producer_version: "2.5.1"
20  domain: "ai.cntk"
21  model_version: 1
22  graph {
23  name: "CNTKGraph"
24  input {
25  name: "Input"
26  type {
27  tensor_type {
28  elem_type: )" + dataType + R"(
29  shape {
30  dim {
31  dim_value: 2
32  }
33  dim {
34  dim_value: 2
35  }
36  dim {
37  dim_value: 3
38  }
39  dim {
40  dim_value: 3
41  }
42  }
43  }
44  }
45  }
46  node {
47  input: "Input"
48  output: "Output"
49  name: "flatten"
50  op_type: "Flatten"
51  attribute {
52  name: "axis"
53  i: 2
54  type: INT
55  }
56  }
57  output {
58  name: "Output"
59  type {
60  tensor_type {
61  elem_type: 1
62  shape {
63  dim {
64  dim_value: 4
65  }
66  dim {
67  dim_value: 9
68  }
69  }
70  }
71  }
72  }
73  }
74  opset_import {
75  version: 7
76  })";
77  }
78 };
79 
80 struct FlattenDefaultAxisFixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxParser::IOnnxParser>
81 {
82  FlattenDefaultAxisFixture(const std::string& dataType)
83  {
84  m_Prototext = R"(
85  ir_version: 3
86  producer_name: "CNTK"
87  producer_version: "2.5.1"
88  domain: "ai.cntk"
89  model_version: 1
90  graph {
91  name: "CNTKGraph"
92  input {
93  name: "Input"
94  type {
95  tensor_type {
96  elem_type: )" + dataType + R"(
97  shape {
98  dim {
99  dim_value: 2
100  }
101  dim {
102  dim_value: 2
103  }
104  dim {
105  dim_value: 3
106  }
107  dim {
108  dim_value: 3
109  }
110  }
111  }
112  }
113  }
114  node {
115  input: "Input"
116  output: "Output"
117  name: "flatten"
118  op_type: "Flatten"
119  }
120  output {
121  name: "Output"
122  type {
123  tensor_type {
124  elem_type: 1
125  shape {
126  dim {
127  dim_value: 2
128  }
129  dim {
130  dim_value: 18
131  }
132  }
133  }
134  }
135  }
136  }
137  opset_import {
138  version: 7
139  })";
140  }
141 };
142 
143 struct FlattenAxisZeroFixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxParser::IOnnxParser>
144 {
145  FlattenAxisZeroFixture(const std::string& dataType)
146  {
147  m_Prototext = R"(
148  ir_version: 3
149  producer_name: "CNTK"
150  producer_version: "2.5.1"
151  domain: "ai.cntk"
152  model_version: 1
153  graph {
154  name: "CNTKGraph"
155  input {
156  name: "Input"
157  type {
158  tensor_type {
159  elem_type: )" + dataType + R"(
160  shape {
161  dim {
162  dim_value: 2
163  }
164  dim {
165  dim_value: 2
166  }
167  dim {
168  dim_value: 3
169  }
170  dim {
171  dim_value: 3
172  }
173  }
174  }
175  }
176  }
177  node {
178  input: "Input"
179  output: "Output"
180  name: "flatten"
181  op_type: "Flatten"
182  attribute {
183  name: "axis"
184  i: 0
185  type: INT
186  }
187  }
188  output {
189  name: "Output"
190  type {
191  tensor_type {
192  elem_type: 1
193  shape {
194  dim {
195  dim_value: 1
196  }
197  dim {
198  dim_value: 36
199  }
200  }
201  }
202  }
203  }
204  }
205  opset_import {
206  version: 7
207  })";
208  }
209 };
210 
211 struct FlattenNegativeAxisFixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxParser::IOnnxParser>
212 {
213  FlattenNegativeAxisFixture(const std::string& dataType)
214  {
215  m_Prototext = R"(
216  ir_version: 3
217  producer_name: "CNTK"
218  producer_version: "2.5.1"
219  domain: "ai.cntk"
220  model_version: 1
221  graph {
222  name: "CNTKGraph"
223  input {
224  name: "Input"
225  type {
226  tensor_type {
227  elem_type: )" + dataType + R"(
228  shape {
229  dim {
230  dim_value: 2
231  }
232  dim {
233  dim_value: 2
234  }
235  dim {
236  dim_value: 3
237  }
238  dim {
239  dim_value: 3
240  }
241  }
242  }
243  }
244  }
245  node {
246  input: "Input"
247  output: "Output"
248  name: "flatten"
249  op_type: "Flatten"
250  attribute {
251  name: "axis"
252  i: -1
253  type: INT
254  }
255  }
256  output {
257  name: "Output"
258  type {
259  tensor_type {
260  elem_type: 1
261  shape {
262  dim {
263  dim_value: 12
264  }
265  dim {
266  dim_value: 3
267  }
268  }
269  }
270  }
271  }
272  }
273  opset_import {
274  version: 7
275  })";
276  }
277 };
278 
279 struct FlattenInvalidNegativeAxisFixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxParser::IOnnxParser>
280 {
281  FlattenInvalidNegativeAxisFixture(const std::string& dataType)
282  {
283  m_Prototext = R"(
284  ir_version: 3
285  producer_name: "CNTK"
286  producer_version: "2.5.1"
287  domain: "ai.cntk"
288  model_version: 1
289  graph {
290  name: "CNTKGraph"
291  input {
292  name: "Input"
293  type {
294  tensor_type {
295  elem_type: )" + dataType + R"(
296  shape {
297  dim {
298  dim_value: 2
299  }
300  dim {
301  dim_value: 2
302  }
303  dim {
304  dim_value: 3
305  }
306  dim {
307  dim_value: 3
308  }
309  }
310  }
311  }
312  }
313  node {
314  input: "Input"
315  output: "Output"
316  name: "flatten"
317  op_type: "Flatten"
318  attribute {
319  name: "axis"
320  i: -5
321  type: INT
322  }
323  }
324  output {
325  name: "Output"
326  type {
327  tensor_type {
328  elem_type: 1
329  shape {
330  dim {
331  dim_value: 12
332  }
333  dim {
334  dim_value: 3
335  }
336  }
337  }
338  }
339  }
340  }
341  opset_import {
342  version: 7
343  })";
344  }
345 };
346 
347 struct FlattenValidFixture : FlattenMainFixture
348 {
349  FlattenValidFixture() : FlattenMainFixture("1") {
350  Setup();
351  }
352 };
353 
354 struct FlattenDefaultValidFixture : FlattenDefaultAxisFixture
355 {
356  FlattenDefaultValidFixture() : FlattenDefaultAxisFixture("1") {
357  Setup();
358  }
359 };
360 
361 struct FlattenAxisZeroValidFixture : FlattenAxisZeroFixture
362 {
363  FlattenAxisZeroValidFixture() : FlattenAxisZeroFixture("1") {
364  Setup();
365  }
366 };
367 
368 struct FlattenNegativeAxisValidFixture : FlattenNegativeAxisFixture
369 {
370  FlattenNegativeAxisValidFixture() : FlattenNegativeAxisFixture("1") {
371  Setup();
372  }
373 };
374 
375 struct FlattenInvalidFixture : FlattenMainFixture
376 {
377  FlattenInvalidFixture() : FlattenMainFixture("10") { }
378 };
379 
380 struct FlattenInvalidAxisFixture : FlattenInvalidNegativeAxisFixture
381 {
382  FlattenInvalidAxisFixture() : FlattenInvalidNegativeAxisFixture("1") { }
383 };
384 
385 BOOST_FIXTURE_TEST_CASE(ValidFlattenTest, FlattenValidFixture)
386 {
387  RunTest<2>({{"Input",
388  { 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f,
389  1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f,
390  1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f }}},
391  {{"Output",
392  { 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f,
393  1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f,
394  1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f }}});
395 }
396 
397 BOOST_FIXTURE_TEST_CASE(ValidFlattenDefaultTest, FlattenDefaultValidFixture)
398 {
399  RunTest<2>({{"Input",
400  { 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f,
401  1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f,
402  1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f }}},
403  {{"Output",
404  { 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f,
405  1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f,
406  1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f }}});
407 }
408 
409 BOOST_FIXTURE_TEST_CASE(ValidFlattenAxisZeroTest, FlattenAxisZeroValidFixture)
410 {
411  RunTest<2>({{"Input",
412  { 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f,
413  1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f,
414  1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f }}},
415  {{"Output",
416  { 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f,
417  1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f,
418  1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f }}});
419 }
420 
421 BOOST_FIXTURE_TEST_CASE(ValidFlattenNegativeAxisTest, FlattenNegativeAxisValidFixture)
422 {
423  RunTest<2>({{"Input",
424  { 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f,
425  1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f,
426  1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f }}},
427  {{"Output",
428  { 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f,
429  1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f,
430  1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f }}});
431 }
432 
433 BOOST_FIXTURE_TEST_CASE(IncorrectDataTypeFlatten, FlattenInvalidFixture)
434 {
435  BOOST_CHECK_THROW(Setup(), armnn::ParseException);
436 }
437 
438 BOOST_FIXTURE_TEST_CASE(IncorrectAxisFlatten, FlattenInvalidAxisFixture)
439 {
440  BOOST_CHECK_THROW(Setup(), armnn::ParseException);
441 }
442 
BOOST_AUTO_TEST_SUITE(TensorflowLiteParser)
BOOST_AUTO_TEST_SUITE_END()
BOOST_FIXTURE_TEST_CASE(ValidFlattenTest, FlattenValidFixture)
Definition: Flatten.cpp:385