ArmNN
 22.08
Flatten.cpp
Go to the documentation of this file.
1 //
2 // Copyright © 2020 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
8 
9 TEST_SUITE("OnnxParser_Flatter")
10 {
11 struct FlattenMainFixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxParser::IOnnxParser>
12 {
13  FlattenMainFixture(const std::string& dataType)
14  {
15  m_Prototext = R"(
16  ir_version: 3
17  producer_name: "CNTK"
18  producer_version: "2.5.1"
19  domain: "ai.cntk"
20  model_version: 1
21  graph {
22  name: "CNTKGraph"
23  input {
24  name: "Input"
25  type {
26  tensor_type {
27  elem_type: )" + dataType + R"(
28  shape {
29  dim {
30  dim_value: 2
31  }
32  dim {
33  dim_value: 2
34  }
35  dim {
36  dim_value: 3
37  }
38  dim {
39  dim_value: 3
40  }
41  }
42  }
43  }
44  }
45  node {
46  input: "Input"
47  output: "Output"
48  name: "flatten"
49  op_type: "Flatten"
50  attribute {
51  name: "axis"
52  i: 2
53  type: INT
54  }
55  }
56  output {
57  name: "Output"
58  type {
59  tensor_type {
60  elem_type: 1
61  shape {
62  dim {
63  dim_value: 4
64  }
65  dim {
66  dim_value: 9
67  }
68  }
69  }
70  }
71  }
72  }
73  opset_import {
74  version: 7
75  })";
76  }
77 };
78 
79 struct FlattenDefaultAxisFixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxParser::IOnnxParser>
80 {
81  FlattenDefaultAxisFixture(const std::string& dataType)
82  {
83  m_Prototext = R"(
84  ir_version: 3
85  producer_name: "CNTK"
86  producer_version: "2.5.1"
87  domain: "ai.cntk"
88  model_version: 1
89  graph {
90  name: "CNTKGraph"
91  input {
92  name: "Input"
93  type {
94  tensor_type {
95  elem_type: )" + dataType + R"(
96  shape {
97  dim {
98  dim_value: 2
99  }
100  dim {
101  dim_value: 2
102  }
103  dim {
104  dim_value: 3
105  }
106  dim {
107  dim_value: 3
108  }
109  }
110  }
111  }
112  }
113  node {
114  input: "Input"
115  output: "Output"
116  name: "flatten"
117  op_type: "Flatten"
118  }
119  output {
120  name: "Output"
121  type {
122  tensor_type {
123  elem_type: 1
124  shape {
125  dim {
126  dim_value: 2
127  }
128  dim {
129  dim_value: 18
130  }
131  }
132  }
133  }
134  }
135  }
136  opset_import {
137  version: 7
138  })";
139  }
140 };
141 
142 struct FlattenAxisZeroFixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxParser::IOnnxParser>
143 {
144  FlattenAxisZeroFixture(const std::string& dataType)
145  {
146  m_Prototext = R"(
147  ir_version: 3
148  producer_name: "CNTK"
149  producer_version: "2.5.1"
150  domain: "ai.cntk"
151  model_version: 1
152  graph {
153  name: "CNTKGraph"
154  input {
155  name: "Input"
156  type {
157  tensor_type {
158  elem_type: )" + dataType + R"(
159  shape {
160  dim {
161  dim_value: 2
162  }
163  dim {
164  dim_value: 2
165  }
166  dim {
167  dim_value: 3
168  }
169  dim {
170  dim_value: 3
171  }
172  }
173  }
174  }
175  }
176  node {
177  input: "Input"
178  output: "Output"
179  name: "flatten"
180  op_type: "Flatten"
181  attribute {
182  name: "axis"
183  i: 0
184  type: INT
185  }
186  }
187  output {
188  name: "Output"
189  type {
190  tensor_type {
191  elem_type: 1
192  shape {
193  dim {
194  dim_value: 1
195  }
196  dim {
197  dim_value: 36
198  }
199  }
200  }
201  }
202  }
203  }
204  opset_import {
205  version: 7
206  })";
207  }
208 };
209 
210 struct FlattenNegativeAxisFixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxParser::IOnnxParser>
211 {
212  FlattenNegativeAxisFixture(const std::string& dataType)
213  {
214  m_Prototext = R"(
215  ir_version: 3
216  producer_name: "CNTK"
217  producer_version: "2.5.1"
218  domain: "ai.cntk"
219  model_version: 1
220  graph {
221  name: "CNTKGraph"
222  input {
223  name: "Input"
224  type {
225  tensor_type {
226  elem_type: )" + dataType + R"(
227  shape {
228  dim {
229  dim_value: 2
230  }
231  dim {
232  dim_value: 2
233  }
234  dim {
235  dim_value: 3
236  }
237  dim {
238  dim_value: 3
239  }
240  }
241  }
242  }
243  }
244  node {
245  input: "Input"
246  output: "Output"
247  name: "flatten"
248  op_type: "Flatten"
249  attribute {
250  name: "axis"
251  i: -1
252  type: INT
253  }
254  }
255  output {
256  name: "Output"
257  type {
258  tensor_type {
259  elem_type: 1
260  shape {
261  dim {
262  dim_value: 12
263  }
264  dim {
265  dim_value: 3
266  }
267  }
268  }
269  }
270  }
271  }
272  opset_import {
273  version: 7
274  })";
275  }
276 };
277 
278 struct FlattenInvalidNegativeAxisFixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxParser::IOnnxParser>
279 {
280  FlattenInvalidNegativeAxisFixture(const std::string& dataType)
281  {
282  m_Prototext = R"(
283  ir_version: 3
284  producer_name: "CNTK"
285  producer_version: "2.5.1"
286  domain: "ai.cntk"
287  model_version: 1
288  graph {
289  name: "CNTKGraph"
290  input {
291  name: "Input"
292  type {
293  tensor_type {
294  elem_type: )" + dataType + R"(
295  shape {
296  dim {
297  dim_value: 2
298  }
299  dim {
300  dim_value: 2
301  }
302  dim {
303  dim_value: 3
304  }
305  dim {
306  dim_value: 3
307  }
308  }
309  }
310  }
311  }
312  node {
313  input: "Input"
314  output: "Output"
315  name: "flatten"
316  op_type: "Flatten"
317  attribute {
318  name: "axis"
319  i: -5
320  type: INT
321  }
322  }
323  output {
324  name: "Output"
325  type {
326  tensor_type {
327  elem_type: 1
328  shape {
329  dim {
330  dim_value: 12
331  }
332  dim {
333  dim_value: 3
334  }
335  }
336  }
337  }
338  }
339  }
340  opset_import {
341  version: 7
342  })";
343  }
344 };
345 
346 struct FlattenValidFixture : FlattenMainFixture
347 {
348  FlattenValidFixture() : FlattenMainFixture("1") {
349  Setup();
350  }
351 };
352 
353 struct FlattenDefaultValidFixture : FlattenDefaultAxisFixture
354 {
355  FlattenDefaultValidFixture() : FlattenDefaultAxisFixture("1") {
356  Setup();
357  }
358 };
359 
360 struct FlattenAxisZeroValidFixture : FlattenAxisZeroFixture
361 {
362  FlattenAxisZeroValidFixture() : FlattenAxisZeroFixture("1") {
363  Setup();
364  }
365 };
366 
367 struct FlattenNegativeAxisValidFixture : FlattenNegativeAxisFixture
368 {
369  FlattenNegativeAxisValidFixture() : FlattenNegativeAxisFixture("1") {
370  Setup();
371  }
372 };
373 
374 struct FlattenInvalidFixture : FlattenMainFixture
375 {
376  FlattenInvalidFixture() : FlattenMainFixture("10") { }
377 };
378 
379 struct FlattenInvalidAxisFixture : FlattenInvalidNegativeAxisFixture
380 {
381  FlattenInvalidAxisFixture() : FlattenInvalidNegativeAxisFixture("1") { }
382 };
383 
384 TEST_CASE_FIXTURE(FlattenValidFixture, "ValidFlattenTest")
385 {
386  RunTest<2>({{"Input",
387  { 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f,
388  1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f,
389  1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f }}},
390  {{"Output",
391  { 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f,
392  1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f,
393  1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f }}});
394 }
395 
396 TEST_CASE_FIXTURE(FlattenDefaultValidFixture, "ValidFlattenDefaultTest")
397 {
398  RunTest<2>({{"Input",
399  { 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f,
400  1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f,
401  1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f }}},
402  {{"Output",
403  { 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f,
404  1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f,
405  1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f }}});
406 }
407 
408 TEST_CASE_FIXTURE(FlattenAxisZeroValidFixture, "ValidFlattenAxisZeroTest")
409 {
410  RunTest<2>({{"Input",
411  { 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f,
412  1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f,
413  1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f }}},
414  {{"Output",
415  { 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f,
416  1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f,
417  1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f }}});
418 }
419 
420 TEST_CASE_FIXTURE(FlattenNegativeAxisValidFixture, "ValidFlattenNegativeAxisTest")
421 {
422  RunTest<2>({{"Input",
423  { 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f,
424  1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f,
425  1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f }}},
426  {{"Output",
427  { 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f,
428  1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f,
429  1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f }}});
430 }
431 
432 TEST_CASE_FIXTURE(FlattenInvalidFixture, "IncorrectDataTypeFlatten")
433 {
434  CHECK_THROWS_AS(Setup(), armnn::ParseException);
435 }
436 
437 TEST_CASE_FIXTURE(FlattenInvalidAxisFixture, "IncorrectAxisFlatten")
438 {
439  CHECK_THROWS_AS(Setup(), armnn::ParseException);
440 }
441 
442 }
TEST_SUITE("OnnxParser_Flatter")
Definition: Flatten.cpp:9
TEST_CASE_FIXTURE(ClContextControlFixture, "CopyBetweenNeonAndGpu")