ArmNN
 21.08
Reshape.cpp
Go to the documentation of this file.
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
8 
9 TEST_SUITE("OnnxParser_Reshape")
10 {
// Fixture that stores an ONNX model, written as prototxt, containing a single
// Reshape node named `reshape`.
//
// The graph declares:
//   - `Input`  : 1-D tensor of 4 elements whose elem_type is supplied by the caller.
//   - `Shape`  : 1-D tensor of 2 elements with elem_type 7 (INT64 in ONNX's
//                TensorProto.DataType); it is also provided as an initializer
//                holding the values 2 and 2.
//   - `Output` : 2x2 tensor with elem_type 1 (FLOAT in ONNX's TensorProto.DataType).
//
// The constructor only assigns m_Prototext. It does not call Setup(), so the
// derived fixtures (or the test body) decide when Setup() is invoked.
11 struct ReshapeMainFixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxParser::IOnnxParser>
12 {
 // dataType is the ONNX elem_type number, given as text, that is spliced
 // into the type declaration of `Input` between the two raw-string pieces.
13  ReshapeMainFixture(const std::string& dataType)
14  {
15  m_Prototext = R"(
16  ir_version: 3
17  producer_name: "CNTK"
18  producer_version: "2.5.1"
19  domain: "ai.cntk"
20  model_version: 1
21  graph {
22  name: "CNTKGraph"
23  input {
24  name: "Input"
25  type {
26  tensor_type {
27  elem_type: )" + dataType + R"(
28  shape {
29  dim {
30  dim_value: 4
31  }
32  }
33  }
34  }
35  }
36  input {
37  name: "Shape"
38  type {
39  tensor_type {
40  elem_type: 7
41  shape {
42  dim {
43  dim_value: 2
44  }
45  }
46  }
47  }
48  }
49  node {
50  input: "Input"
51  input: "Shape"
52  output: "Output"
53  name: "reshape"
54  op_type: "Reshape"
55 
56  }
57  initializer {
58  dims: 2
59  data_type: 7
60  int64_data: 2
61  int64_data: 2
62  name: "Shape"
63  }
64  output {
65  name: "Output"
66  type {
67  tensor_type {
68  elem_type: 1
69  shape {
70  dim {
71  dim_value: 2
72  }
73  dim {
74  dim_value: 2
75  }
76  }
77  }
78  }
79  }
80  }
81  opset_import {
82  version: 7
83  })";
84  }
85 };
86 
// Fixture that stores an ONNX model, written as prototxt, in which a single
// Reshape node turns a rank-4 tensor into a rank-2 tensor.
//
// The graph declares:
//   - `Input`  : 2x2x3x3 tensor (36 elements) whose elem_type is supplied by the caller.
//   - `Shape`  : 1-D tensor of 2 elements with elem_type 7 (INT64 in ONNX's
//                TensorProto.DataType); it is also provided as an initializer.
//   - `Output` : 6x6 tensor (36 elements) with elem_type 1 (FLOAT in ONNX's
//                TensorProto.DataType).
//
// The `Shape` initializer holds 6 and 6 so that the requested target shape
// agrees with both the 36-element input and the declared 6x6 output. A target
// of 2x2 would describe only 4 elements, which an ONNX Reshape cannot produce
// from 36 input elements.
//
// The constructor only assigns m_Prototext. It does not call Setup(), so the
// derived fixtures (or the test body) decide when Setup() is invoked.
87 struct ReshapeRank4Fixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxParser::IOnnxParser>
88 {
 // dataType is the ONNX elem_type number, given as text, that is spliced
 // into the type declaration of `Input` between the two raw-string pieces.
89  ReshapeRank4Fixture(const std::string& dataType)
90  {
91  m_Prototext = R"(
92  ir_version: 3
93  producer_name: "CNTK"
94  producer_version: "2.5.1"
95  domain: "ai.cntk"
96  model_version: 1
97  graph {
98  name: "CNTKGraph"
99  input {
100  name: "Input"
101  type {
102  tensor_type {
103  elem_type: )" + dataType + R"(
104  shape {
105  dim {
106  dim_value: 2
107  }
108  dim {
109  dim_value: 2
110  }
111  dim {
112  dim_value: 3
113  }
114  dim {
115  dim_value: 3
116  }
117  }
118  }
119  }
120  }
121  input {
122  name: "Shape"
123  type {
124  tensor_type {
125  elem_type: 7
126  shape {
127  dim {
128  dim_value: 2
129  }
130  }
131  }
132  }
133  }
134  node {
135  input: "Input"
136  input: "Shape"
137  output: "Output"
138  name: "reshape"
139  op_type: "Reshape"
140 
141  }
142  initializer {
143  dims: 2
144  data_type: 7
145  int64_data: 6
146  int64_data: 6
147  name: "Shape"
148  }
149  output {
150  name: "Output"
151  type {
152  tensor_type {
153  elem_type: 1
154  shape {
155  dim {
156  dim_value: 6
157  }
158  dim {
159  dim_value: 6
160  }
161  }
162  }
163  }
164  }
165  }
166  opset_import {
167  version: 7
168  })";
169  }
170 };
171 
// Valid variant of ReshapeMainFixture: passes 1 as the elem_type of `Input`
// (FLOAT in ONNX's TensorProto.DataType) and calls Setup() during construction.
// NOTE(review): Setup() is inherited and not visible here; presumably it
// parses m_Prototext and prepares the network - confirm in ParserPrototxtFixture.
172 struct ReshapeValidFixture : ReshapeMainFixture
173 {
174  ReshapeValidFixture() : ReshapeMainFixture("1") {
175  Setup();
176  }
177 };
178 
// Valid variant of ReshapeRank4Fixture: passes 1 as the elem_type of `Input`
// (FLOAT in ONNX's TensorProto.DataType) and calls Setup() during construction.
// NOTE(review): Setup() is inherited and not visible here; presumably it
// parses m_Prototext and prepares the network - confirm in ParserPrototxtFixture.
179 struct ReshapeValidRank4Fixture : ReshapeRank4Fixture
180 {
181  ReshapeValidRank4Fixture() : ReshapeRank4Fixture("1") {
182  Setup();
183  }
184 };
185 
// Invalid variant of ReshapeMainFixture: passes 10 as the elem_type of `Input`
// (FLOAT16 in ONNX's TensorProto.DataType). The constructor intentionally does
// not call Setup(); the test body calls it so the expected exception can be
// asserted there.
186 struct ReshapeInvalidFixture : ReshapeMainFixture
187 {
188  ReshapeInvalidFixture() : ReshapeMainFixture("10") { }
189 };
190 
// Feeds four floats to `Input` and expects the same four values, in the same
// order, on `Output`.
// NOTE(review): the template argument 2 presumably is the number of output
// dimensions (Output is declared 2x2) - confirm against RunTest's declaration.
191 TEST_CASE_FIXTURE(ReshapeValidFixture, "ValidReshapeTest")
192 {
193  RunTest<2>({{"Input", { 0.0f, 1.0f, 2.0f, 3.0f }}}, {{"Output", { 0.0f, 1.0f, 2.0f, 3.0f }}});
194 }
195 
// Feeds 36 floats (the sequence 1..4 repeated nine times) to the rank-4
// `Input` and expects the identical 36 values, in the same order, on `Output`.
// NOTE(review): the template argument 2 presumably is the number of output
// dimensions (Output is declared 6x6) - confirm against RunTest's declaration.
196 TEST_CASE_FIXTURE(ReshapeValidRank4Fixture, "ValidRank4ReshapeTest")
197 {
198  RunTest<2>(
199  {{"Input",
200  {1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f,
201  1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f,
202  1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f}}},
203  {{"Output",
204  {1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f,
205  1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f,
206  1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f}}});
207 }
208 
// The fixture declares `Input` with elem_type 10 and has not called Setup()
// yet; this test asserts that calling Setup() throws armnn::ParseException.
209 TEST_CASE_FIXTURE(ReshapeInvalidFixture, "IncorrectDataTypeReshape")
210 {
211  CHECK_THROWS_AS(Setup(), armnn::ParseException);
212 }
213 
214 }
TEST_CASE_FIXTURE(ClContextControlFixture, "CopyBetweenNeonAndGpu")
TEST_SUITE("OnnxParser_Reshape")
Definition: Reshape.cpp:9