ArmNN
 21.02
Reshape.cpp
Go to the documentation of this file.
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
6 #include <boost/test/unit_test.hpp>
9 
10 BOOST_AUTO_TEST_SUITE(OnnxParser)
11 
12 struct ReshapeMainFixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxParser::IOnnxParser>
13 {
14  ReshapeMainFixture(const std::string& dataType)
15  {
16  m_Prototext = R"(
17  ir_version: 3
18  producer_name: "CNTK"
19  producer_version: "2.5.1"
20  domain: "ai.cntk"
21  model_version: 1
22  graph {
23  name: "CNTKGraph"
24  input {
25  name: "Input"
26  type {
27  tensor_type {
28  elem_type: )" + dataType + R"(
29  shape {
30  dim {
31  dim_value: 4
32  }
33  }
34  }
35  }
36  }
37  input {
38  name: "Shape"
39  type {
40  tensor_type {
41  elem_type: 7
42  shape {
43  dim {
44  dim_value: 2
45  }
46  }
47  }
48  }
49  }
50  node {
51  input: "Input"
52  input: "Shape"
53  output: "Output"
54  name: "reshape"
55  op_type: "Reshape"
56 
57  }
58  initializer {
59  dims: 2
60  data_type: 7
61  int64_data: 2
62  int64_data: 2
63  name: "Shape"
64  }
65  output {
66  name: "Output"
67  type {
68  tensor_type {
69  elem_type: 1
70  shape {
71  dim {
72  dim_value: 2
73  }
74  dim {
75  dim_value: 2
76  }
77  }
78  }
79  }
80  }
81  }
82  opset_import {
83  version: 7
84  })";
85  }
86 };
87 
88 struct ReshapeRank4Fixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxParser::IOnnxParser>
89 {
90  ReshapeRank4Fixture(const std::string& dataType)
91  {
92  m_Prototext = R"(
93  ir_version: 3
94  producer_name: "CNTK"
95  producer_version: "2.5.1"
96  domain: "ai.cntk"
97  model_version: 1
98  graph {
99  name: "CNTKGraph"
100  input {
101  name: "Input"
102  type {
103  tensor_type {
104  elem_type: )" + dataType + R"(
105  shape {
106  dim {
107  dim_value: 2
108  }
109  dim {
110  dim_value: 2
111  }
112  dim {
113  dim_value: 3
114  }
115  dim {
116  dim_value: 3
117  }
118  }
119  }
120  }
121  }
122  input {
123  name: "Shape"
124  type {
125  tensor_type {
126  elem_type: 7
127  shape {
128  dim {
129  dim_value: 2
130  }
131  }
132  }
133  }
134  }
135  node {
136  input: "Input"
137  input: "Shape"
138  output: "Output"
139  name: "reshape"
140  op_type: "Reshape"
141 
142  }
143  initializer {
144  dims: 2
145  data_type: 7
146  int64_data: 2
147  int64_data: 2
148  name: "Shape"
149  }
150  output {
151  name: "Output"
152  type {
153  tensor_type {
154  elem_type: 1
155  shape {
156  dim {
157  dim_value: 6
158  }
159  dim {
160  dim_value: 6
161  }
162  }
163  }
164  }
165  }
166  }
167  opset_import {
168  version: 7
169  })";
170  }
171 };
172 
173 struct ReshapeValidFixture : ReshapeMainFixture
174 {
175  ReshapeValidFixture() : ReshapeMainFixture("1") {
176  Setup();
177  }
178 };
179 
180 struct ReshapeValidRank4Fixture : ReshapeRank4Fixture
181 {
182  ReshapeValidRank4Fixture() : ReshapeRank4Fixture("1") {
183  Setup();
184  }
185 };
186 
187 struct ReshapeInvalidFixture : ReshapeMainFixture
188 {
189  ReshapeInvalidFixture() : ReshapeMainFixture("10") { }
190 };
191 
192 BOOST_FIXTURE_TEST_CASE(ValidReshapeTest, ReshapeValidFixture)
193 {
194  RunTest<2>({{"Input", { 0.0f, 1.0f, 2.0f, 3.0f }}}, {{"Output", { 0.0f, 1.0f, 2.0f, 3.0f }}});
195 }
196 
197 BOOST_FIXTURE_TEST_CASE(ValidRank4ReshapeTest, ReshapeValidRank4Fixture)
198 {
199  RunTest<2>(
200  {{"Input",
201  {1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f,
202  1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f,
203  1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f}}},
204  {{"Output",
205  {1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f,
206  1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f,
207  1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f}}});
208 }
209 
210 BOOST_FIXTURE_TEST_CASE(IncorrectDataTypeReshape, ReshapeInvalidFixture)
211 {
212  BOOST_CHECK_THROW(Setup(), armnn::ParseException);
213 }
214 
BOOST_AUTO_TEST_SUITE(TensorflowLiteParser)
BOOST_FIXTURE_TEST_CASE(ValidReshapeTest, ReshapeValidFixture)
Definition: Reshape.cpp:192
BOOST_AUTO_TEST_SUITE_END()