ArmNN
 20.02
Pooling.cpp
Go to the documentation of this file.
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
6 #include <boost/test/unit_test.hpp>
9 
// The ONNX pooling parser test cases below are registered in the OnnxParser Boost.Test suite.
BOOST_AUTO_TEST_SUITE(OnnxParser)
11 
12 struct PoolingMainFixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxParser::IOnnxParser>
13 {
14  PoolingMainFixture(const std::string& dataType, const std::string& op)
15  {
16  m_Prototext = R"(
17  ir_version: 3
18  producer_name: "CNTK"
19  producer_version: "2.5.1"
20  domain: "ai.cntk"
21  model_version: 1
22  graph {
23  name: "CNTKGraph"
24  input {
25  name: "Input"
26  type {
27  tensor_type {
28  elem_type: )" + dataType + R"(
29  shape {
30  dim {
31  dim_value: 1
32  }
33  dim {
34  dim_value: 1
35  }
36  dim {
37  dim_value: 2
38  }
39  dim {
40  dim_value: 2
41  }
42  }
43  }
44  }
45  }
46  node {
47  input: "Input"
48  output: "Output"
49  name: "Pooling"
50  op_type: )" + op + R"(
51  attribute {
52  name: "kernel_shape"
53  ints: 2
54  ints: 2
55  type: INTS
56  }
57  attribute {
58  name: "strides"
59  ints: 1
60  ints: 1
61  type: INTS
62  }
63  attribute {
64  name: "pads"
65  ints: 0
66  ints: 0
67  ints: 0
68  ints: 0
69  type: INTS
70  }
71  }
72  output {
73  name: "Output"
74  type {
75  tensor_type {
76  elem_type: 1
77  shape {
78  dim {
79  dim_value: 1
80  }
81  dim {
82  dim_value: 1
83  }
84  dim {
85  dim_value: 1
86  }
87  dim {
88  dim_value: 1
89  }
90  }
91  }
92  }
93  }
94  }
95  opset_import {
96  version: 7
97  })";
98  }
99 };
100 
101 struct MaxPoolValidFixture : PoolingMainFixture
102 {
103  MaxPoolValidFixture() : PoolingMainFixture("1", "\"MaxPool\"") {
104  Setup();
105  }
106 };
107 
108 struct MaxPoolInvalidFixture : PoolingMainFixture
109 {
110  MaxPoolInvalidFixture() : PoolingMainFixture("10", "\"MaxPool\"") { }
111 };
112 
113 BOOST_FIXTURE_TEST_CASE(ValidMaxPoolTest, MaxPoolValidFixture)
114 {
115  RunTest<4>({{"Input", {1.0f, 2.0f, 3.0f, -4.0f}}}, {{"Output", {3.0f}}});
116 }
117 
118 struct AvgPoolValidFixture : PoolingMainFixture
119 {
120  AvgPoolValidFixture() : PoolingMainFixture("1", "\"AveragePool\"") {
121  Setup();
122  }
123 };
124 
125 struct PoolingWithPadFixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxParser::IOnnxParser>
126 {
127  PoolingWithPadFixture()
128  {
129  m_Prototext = R"(
130  ir_version: 3
131  producer_name: "CNTK"
132  producer_version: "2.5.1"
133  domain: "ai.cntk"
134  model_version: 1
135  graph {
136  name: "CNTKGraph"
137  input {
138  name: "Input"
139  type {
140  tensor_type {
141  elem_type: 1
142  shape {
143  dim {
144  dim_value: 1
145  }
146  dim {
147  dim_value: 1
148  }
149  dim {
150  dim_value: 2
151  }
152  dim {
153  dim_value: 2
154  }
155  }
156  }
157  }
158  }
159  node {
160  input: "Input"
161  output: "Output"
162  name: "Pooling"
163  op_type: "AveragePool"
164  attribute {
165  name: "kernel_shape"
166  ints: 4
167  ints: 4
168  type: INTS
169  }
170  attribute {
171  name: "strides"
172  ints: 1
173  ints: 1
174  type: INTS
175  }
176  attribute {
177  name: "pads"
178  ints: 1
179  ints: 1
180  ints: 1
181  ints: 1
182  type: INTS
183  }
184  attribute {
185  name: "count_include_pad"
186  i: 1
187  type: INT
188  }
189  }
190  output {
191  name: "Output"
192  type {
193  tensor_type {
194  elem_type: 1
195  shape {
196  dim {
197  dim_value: 1
198  }
199  dim {
200  dim_value: 1
201  }
202  dim {
203  dim_value: 1
204  }
205  dim {
206  dim_value: 1
207  }
208  }
209  }
210  }
211  }
212  }
213  opset_import {
214  version: 7
215  })";
216  Setup();
217  }
218 };
219 
220 BOOST_FIXTURE_TEST_CASE(AveragePoolValid, AvgPoolValidFixture)
221 {
222  RunTest<4>({{"Input", {1.0f, 2.0f, 3.0f, -4.0f}}}, {{"Output", {0.5}}});
223 }
224 
225 BOOST_FIXTURE_TEST_CASE(ValidAvgWithPadTest, PoolingWithPadFixture)
226 {
227  RunTest<4>({{"Input", {1.0f, 2.0f, 3.0f, -4.0f}}}, {{"Output", {1.0/8.0}}});
228 }
229 
230 struct GlobalAvgFixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxParser::IOnnxParser>
231 {
232  GlobalAvgFixture()
233  {
234  m_Prototext = R"(
235  ir_version: 3
236  producer_name: "CNTK"
237  producer_version: "2.5.1"
238  domain: "ai.cntk"
239  model_version: 1
240  graph {
241  name: "CNTKGraph"
242  input {
243  name: "Input"
244  type {
245  tensor_type {
246  elem_type: 1
247  shape {
248  dim {
249  dim_value: 1
250  }
251  dim {
252  dim_value: 2
253  }
254  dim {
255  dim_value: 2
256  }
257  dim {
258  dim_value: 2
259  }
260  }
261  }
262  }
263  }
264  node {
265  input: "Input"
266  output: "Output"
267  name: "Pooling"
268  op_type: "GlobalAveragePool"
269  }
270  output {
271  name: "Output"
272  type {
273  tensor_type {
274  elem_type: 1
275  shape {
276  dim {
277  dim_value: 1
278  }
279  dim {
280  dim_value: 2
281  }
282  dim {
283  dim_value: 1
284  }
285  dim {
286  dim_value: 1
287  }
288  }
289  }
290  }
291  }
292  }
293  opset_import {
294  version: 7
295  })";
296  Setup();
297  }
298 };
299 
300 BOOST_FIXTURE_TEST_CASE(GlobalAvgTest, GlobalAvgFixture)
301 {
302  RunTest<4>({{"Input", {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0}}}, {{"Output", {10/4.0, 26/4.0}}});
303 }
304 
305 BOOST_FIXTURE_TEST_CASE(IncorrectDataTypeMaxPool, MaxPoolInvalidFixture)
306 {
307  BOOST_CHECK_THROW(Setup(), armnn::ParseException);
308 }
309 
BOOST_AUTO_TEST_SUITE(TensorflowLiteParser)
BOOST_FIXTURE_TEST_CASE(ValidMaxPoolTest, MaxPoolValidFixture)
Definition: Pooling.cpp:113
BOOST_AUTO_TEST_SUITE_END()