ArmNN
 21.11
Reshape.cpp File Reference

Go to the source code of this file.

Functions

 TEST_SUITE ("OnnxParser_Reshape")
 

Function Documentation

◆ TEST_SUITE()

TEST_SUITE ( "OnnxParser_Reshape"  )

Definition at line 10 of file Reshape.cpp.

References armnnUtils::ConstructTensorShapeString(), and TEST_CASE_FIXTURE().

11 {
12 struct ReshapeMainFixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxParser::IOnnxParser>
13 {
14  ReshapeMainFixture(const std::string& dataType)
15  {
16  m_Prototext = R"(
17  ir_version: 3
18  producer_name: "CNTK"
19  producer_version: "2.5.1"
20  domain: "ai.cntk"
21  model_version: 1
22  graph {
23  name: "CNTKGraph"
24  input {
25  name: "Input"
26  type {
27  tensor_type {
28  elem_type: )" + dataType + R"(
29  shape {
30  dim {
31  dim_value: 4
32  }
33  }
34  }
35  }
36  }
37  input {
38  name: "Shape"
39  type {
40  tensor_type {
41  elem_type: 7
42  shape {
43  dim {
44  dim_value: 2
45  }
46  }
47  }
48  }
49  }
50  node {
51  input: "Input"
52  input: "Shape"
53  output: "Output"
54  name: "reshape"
55  op_type: "Reshape"
56 
57  }
58  initializer {
59  dims: 2
60  data_type: 7
61  int64_data: 2
62  int64_data: 2
63  name: "Shape"
64  }
65  output {
66  name: "Output"
67  type {
68  tensor_type {
69  elem_type: 1
70  shape {
71  dim {
72  dim_value: 2
73  }
74  dim {
75  dim_value: 2
76  }
77  }
78  }
79  }
80  }
81  }
82  opset_import {
83  version: 7
84  })";
85  }
86 };
87 
88 struct ReshapeRank4Fixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxParser::IOnnxParser>
89 {
90  ReshapeRank4Fixture(const std::string& dataType)
91  {
92  m_Prototext = R"(
93  ir_version: 3
94  producer_name: "CNTK"
95  producer_version: "2.5.1"
96  domain: "ai.cntk"
97  model_version: 1
98  graph {
99  name: "CNTKGraph"
100  input {
101  name: "Input"
102  type {
103  tensor_type {
104  elem_type: )" + dataType + R"(
105  shape {
106  dim {
107  dim_value: 2
108  }
109  dim {
110  dim_value: 2
111  }
112  dim {
113  dim_value: 3
114  }
115  dim {
116  dim_value: 3
117  }
118  }
119  }
120  }
121  }
122  input {
123  name: "Shape"
124  type {
125  tensor_type {
126  elem_type: 7
127  shape {
128  dim {
129  dim_value: 2
130  }
131  }
132  }
133  }
134  }
135  node {
136  input: "Input"
137  input: "Shape"
138  output: "Output"
139  name: "reshape"
140  op_type: "Reshape"
141 
142  }
143  initializer {
144  dims: 2
145  data_type: 7
146  int64_data: 2
147  int64_data: 2
148  name: "Shape"
149  }
150  output {
151  name: "Output"
152  type {
153  tensor_type {
154  elem_type: 1
155  shape {
156  dim {
157  dim_value: 6
158  }
159  dim {
160  dim_value: 6
161  }
162  }
163  }
164  }
165  }
166  }
167  opset_import {
168  version: 7
169  })";
170  }
171 };
172 
173 struct ReshapeValidFixture : ReshapeMainFixture
174 {
175  ReshapeValidFixture() : ReshapeMainFixture("1") {
176  Setup();
177  }
178 };
179 
180 struct ReshapeValidRank4Fixture : ReshapeRank4Fixture
181 {
182  ReshapeValidRank4Fixture() : ReshapeRank4Fixture("1") {
183  Setup();
184  }
185 };
186 
187 struct ReshapeInvalidFixture : ReshapeMainFixture
188 {
189  ReshapeInvalidFixture() : ReshapeMainFixture("10") { }
190 };
191 
192 TEST_CASE_FIXTURE(ReshapeValidFixture, "ValidReshapeTest")
193 {
194  RunTest<2>({{"Input", { 0.0f, 1.0f, 2.0f, 3.0f }}}, {{"Output", { 0.0f, 1.0f, 2.0f, 3.0f }}});
195 }
196 
197 TEST_CASE_FIXTURE(ReshapeValidRank4Fixture, "ValidRank4ReshapeTest")
198 {
199  RunTest<2>(
200  {{"Input",
201  {1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f,
202  1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f,
203  1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f}}},
204  {{"Output",
205  {1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f,
206  1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f,
207  1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f}}});
208 }
209 
210 TEST_CASE_FIXTURE(ReshapeInvalidFixture, "IncorrectDataTypeReshape")
211 {
212  CHECK_THROWS_AS(Setup(), armnn::ParseException);
213 }
214 
215 struct ReshapeNegativeReshapeFixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxParser::IOnnxParser>
216 {
217  ReshapeNegativeReshapeFixture(const std::vector<int>& inputShape,
218  const std::vector<int>& shapeInputShape,
219  const std::vector<int>& outputShape,
220  const std::string& shape)
221  {
222  m_Prototext = R"(
223  ir_version: 3
224  producer_name: "onnx-example"
225  graph {
226  name: "ReshapeGrapn"
227  input {
228  name: "Input"
229  type {
230  tensor_type {
231  elem_type: 1
232  shape {
233  )" + armnnUtils::ConstructTensorShapeString(inputShape) + R"(
234  }
235  }
236  }
237  }
238  input {
239  name: "Shape"
240  type {
241  tensor_type {
242  elem_type: 7
243  shape {
244  )" + armnnUtils::ConstructTensorShapeString(shapeInputShape) + R"(
245  }
246  }
247  }
248  }
249  node {
250  input: "Input"
251  input: "Shape"
252  output: "Output"
253  name: "reshape"
254  op_type: "Reshape"
255  }
256  initializer {
257  dims: 2
258  data_type: 7
259  )" + shape + R"(
260  name: "Shape"
261  }
262  output {
263  name: "Output"
264  type {
265  tensor_type {
266  elem_type: 1
267  shape {
268  )" + armnnUtils::ConstructTensorShapeString(outputShape) + R"(
269  }
270  }
271  }
272  }
273  }
274  opset_import {
275  version: 7
276  })";
277  }
278 };
279 
280 struct ReshapeNegativeReshape1DFixture : ReshapeNegativeReshapeFixture
281 {
282  ReshapeNegativeReshape1DFixture() : ReshapeNegativeReshapeFixture({ 1, 3, 1, 2 }, { 1 }, { 6 }, "int64_data: -1")
283  {
284  Setup();
285  }
286 };
287 
288 struct ReshapeNegativeReshape2DFixture : ReshapeNegativeReshapeFixture
289 {
290  ReshapeNegativeReshape2DFixture() : ReshapeNegativeReshapeFixture({ 2, 3, 1, 2 },
291  { 2 },
292  { 2, 6 },
293  "int64_data: -1 int64_data: 6")
294  {
295  Setup();
296  }
297 };
298 
299 struct ReshapeNegativeReshape3DFixture : ReshapeNegativeReshapeFixture
300 {
301  ReshapeNegativeReshape3DFixture() : ReshapeNegativeReshapeFixture({ 2, 3, 1, 2 },
302  { 3 },
303  { 3, 1, 4 },
304  "int64_data: 3 int64_data: -1 int64_data: 4")
305  {
306  Setup();
307  }
308 };
309 
310 struct ReshapeNegativeReshape4DFixture : ReshapeNegativeReshapeFixture
311 {
312  ReshapeNegativeReshape4DFixture() : ReshapeNegativeReshapeFixture(
313  { 2, 3, 1, 2 },
314  { 4 },
315  { 3, 1, 2, 2 },
316  "int64_data: 3 int64_data: 1 int64_data: 2 int64_data: -1")
317  {
318  Setup();
319  }
320 };
321 
322 TEST_CASE_FIXTURE(ReshapeNegativeReshape1DFixture, "ReshapeNegativeReshape1DTest")
323 {
324  RunTest<1, float>({{"Input", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f }}},
325  {{"Output", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f }}});
326 }
327 
328 TEST_CASE_FIXTURE(ReshapeNegativeReshape2DFixture, "ReshapeNegativeReshape2DTest")
329 {
330  RunTest<2, float>({{"Input", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f,
331  7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f }}},
332  {{"Output", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f,
333  7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f }}});
334 }
335 
336 TEST_CASE_FIXTURE(ReshapeNegativeReshape3DFixture, "ReshapeNegativeReshape3DTest")
337 {
338  RunTest<3, float>({{"Input", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f,
339  7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f }}},
340  {{"Output", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f,
341  7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f }}});
342 }
343 
344 TEST_CASE_FIXTURE(ReshapeNegativeReshape4DFixture, "ReshapeNegativeReshape4DTest")
345 {
346  RunTest<4, float>({{"Input", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f,
347  7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f }}},
348  {{"Output", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f,
349  7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f }}});
350 }
351 
352 struct ReshapeNonConstShapeFixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxParser::IOnnxParser>
353 {
354  ReshapeNonConstShapeFixture(const std::vector<int>& inputShape,
355  const std::vector<int>& shapeInputShape,
356  const std::vector<int>& outputShape)
357  {
358  m_Prototext = R"(
359  ir_version: 3
360  producer_name: "onnx-example"
361  graph {
362  name: "ReshapeGrapn"
363  input {
364  name: "Input"
365  type {
366  tensor_type {
367  elem_type: 1
368  shape {
369  )" + armnnUtils::ConstructTensorShapeString(inputShape) + R"(
370  }
371  }
372  }
373  }
374  input {
375  name: "Shape"
376  type {
377  tensor_type {
378  elem_type: 7
379  shape {
380  )" + armnnUtils::ConstructTensorShapeString(shapeInputShape) + R"(
381  }
382  }
383  }
384  }
385  node {
386  input: "Input"
387  input: "Shape"
388  output: "Output"
389  name: "reshape"
390  op_type: "Reshape"
391  }
392  output {
393  name: "Output"
394  type {
395  tensor_type {
396  elem_type: 1
397  shape {
398  )" + armnnUtils::ConstructTensorShapeString(outputShape) + R"(
399  }
400  }
401  }
402  }
403  }
404  opset_import {
405  version: 7
406  })";
407  }
408 };
409 
410 struct ReshapeNonConst1DShapeFixture : ReshapeNonConstShapeFixture
411 {
412  ReshapeNonConst1DShapeFixture() : ReshapeNonConstShapeFixture({ 1, 3, 1, 2 }, { 1 }, { 6 })
413  {
414  Setup();
415  }
416 };
417 
418 struct ReshapeNonConst2DShapeFixture : ReshapeNonConstShapeFixture
419 {
420  ReshapeNonConst2DShapeFixture() : ReshapeNonConstShapeFixture({ 2, 3, 2, 2 }, { 2 }, { 2, 12 })
421  {
422  Setup();
423  }
424 };
425 
426 struct ReshapeInvalidNonConstShapeFixture : ReshapeNonConstShapeFixture
427 {
428  ReshapeInvalidNonConstShapeFixture() : ReshapeNonConstShapeFixture({ 2, 3, 2, 2 }, { 3 }, { 2, 3, 4 })
429  {
430  }
431 };
432 
433 struct ReshapeInvalidDimNonConstShapeFixture : ReshapeNonConstShapeFixture
434 {
435  ReshapeInvalidDimNonConstShapeFixture() : ReshapeNonConstShapeFixture({ 2, 3, 2, 2 }, { 1, 2 }, { 2, 3, 4 })
436  {
437  }
438 };
439 
440 TEST_CASE_FIXTURE(ReshapeNonConst1DShapeFixture, "ReshapeNonConst1DShapeTest")
441 {
442  RunTest<1, float>({{"Input", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f }}},
443  {{"Output", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f }}});
444 }
445 
446 TEST_CASE_FIXTURE(ReshapeNonConst2DShapeFixture, "ReshapeNonConst2DShapeTest")
447 {
448  RunTest<2, float>({{"Input", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f,
449  7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f,
450  13.0f, 14.0f, 15.0f, 16.0f, 17.0f, 18.0f,
451  19.0f, 20.0f, 21.0f, 22.0f, 23.0f, 24.0f }}},
452  {{"Output", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f,
453  7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f,
454  13.0f, 14.0f, 15.0f, 16.0f, 17.0f, 18.0f,
455  19.0f, 20.0f, 21.0f, 22.0f, 23.0f, 24.0f }}});
456 }
457 
458 TEST_CASE_FIXTURE(ReshapeInvalidNonConstShapeFixture, "ReshapeInvalidNonConstShapeTest")
459 {
460  CHECK_THROWS_AS(Setup(), armnn::ParseException);
461 }
462 
463 TEST_CASE_FIXTURE(ReshapeInvalidDimNonConstShapeFixture, "ReshapeInvalidDimNonConstShapeTest")
464 {
465  CHECK_THROWS_AS(Setup(), armnn::ParseException);
466 }
467 
468 }
std::string ConstructTensorShapeString(const std::vector< int > &shape)
TEST_CASE_FIXTURE(ClContextControlFixture, "CopyBetweenNeonAndGpu")