ArmNN
 21.11
Pooling.cpp
Go to the documentation of this file.
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
8 
9 TEST_SUITE("OnnxParser_Pooling")
10 {
11 struct PoolingMainFixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxParser::IOnnxParser>
12 {
13  PoolingMainFixture(const std::string& dataType, const std::string& op)
14  {
15  m_Prototext = R"(
16  ir_version: 3
17  producer_name: "CNTK"
18  producer_version: "2.5.1"
19  domain: "ai.cntk"
20  model_version: 1
21  graph {
22  name: "CNTKGraph"
23  input {
24  name: "Input"
25  type {
26  tensor_type {
27  elem_type: )" + dataType + R"(
28  shape {
29  dim {
30  dim_value: 1
31  }
32  dim {
33  dim_value: 1
34  }
35  dim {
36  dim_value: 2
37  }
38  dim {
39  dim_value: 2
40  }
41  }
42  }
43  }
44  }
45  node {
46  input: "Input"
47  output: "Output"
48  name: "Pooling"
49  op_type: )" + op + R"(
50  attribute {
51  name: "kernel_shape"
52  ints: 2
53  ints: 2
54  type: INTS
55  }
56  attribute {
57  name: "strides"
58  ints: 1
59  ints: 1
60  type: INTS
61  }
62  attribute {
63  name: "pads"
64  ints: 0
65  ints: 0
66  ints: 0
67  ints: 0
68  type: INTS
69  }
70  }
71  output {
72  name: "Output"
73  type {
74  tensor_type {
75  elem_type: 1
76  shape {
77  dim {
78  dim_value: 1
79  }
80  dim {
81  dim_value: 1
82  }
83  dim {
84  dim_value: 1
85  }
86  dim {
87  dim_value: 1
88  }
89  }
90  }
91  }
92  }
93  }
94  opset_import {
95  version: 7
96  })";
97  }
98 };
99 
100 struct MaxPoolValidFixture : PoolingMainFixture
101 {
102  MaxPoolValidFixture() : PoolingMainFixture("1", "\"MaxPool\"") {
103  Setup();
104  }
105 };
106 
107 struct MaxPoolInvalidFixture : PoolingMainFixture
108 {
109  MaxPoolInvalidFixture() : PoolingMainFixture("10", "\"MaxPool\"") { }
110 };
111 
112 TEST_CASE_FIXTURE(MaxPoolValidFixture, "ValidMaxPoolTest")
113 {
114  RunTest<4>({{"Input", {1.0f, 2.0f, 3.0f, -4.0f}}}, {{"Output", {3.0f}}});
115 }
116 
117 struct AvgPoolValidFixture : PoolingMainFixture
118 {
119  AvgPoolValidFixture() : PoolingMainFixture("1", "\"AveragePool\"") {
120  Setup();
121  }
122 };
123 
124 struct PoolingWithPadFixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxParser::IOnnxParser>
125 {
126  PoolingWithPadFixture()
127  {
128  m_Prototext = R"(
129  ir_version: 3
130  producer_name: "CNTK"
131  producer_version: "2.5.1"
132  domain: "ai.cntk"
133  model_version: 1
134  graph {
135  name: "CNTKGraph"
136  input {
137  name: "Input"
138  type {
139  tensor_type {
140  elem_type: 1
141  shape {
142  dim {
143  dim_value: 1
144  }
145  dim {
146  dim_value: 1
147  }
148  dim {
149  dim_value: 2
150  }
151  dim {
152  dim_value: 2
153  }
154  }
155  }
156  }
157  }
158  node {
159  input: "Input"
160  output: "Output"
161  name: "Pooling"
162  op_type: "AveragePool"
163  attribute {
164  name: "kernel_shape"
165  ints: 4
166  ints: 4
167  type: INTS
168  }
169  attribute {
170  name: "strides"
171  ints: 1
172  ints: 1
173  type: INTS
174  }
175  attribute {
176  name: "pads"
177  ints: 1
178  ints: 1
179  ints: 1
180  ints: 1
181  type: INTS
182  }
183  attribute {
184  name: "count_include_pad"
185  i: 1
186  type: INT
187  }
188  }
189  output {
190  name: "Output"
191  type {
192  tensor_type {
193  elem_type: 1
194  shape {
195  dim {
196  dim_value: 1
197  }
198  dim {
199  dim_value: 1
200  }
201  dim {
202  dim_value: 1
203  }
204  dim {
205  dim_value: 1
206  }
207  }
208  }
209  }
210  }
211  }
212  opset_import {
213  version: 7
214  })";
215  Setup();
216  }
217 };
218 
219 TEST_CASE_FIXTURE(AvgPoolValidFixture, "AveragePoolValid")
220 {
221  RunTest<4>({{"Input", {1.0f, 2.0f, 3.0f, -4.0f}}}, {{"Output", {0.5}}});
222 }
223 
224 TEST_CASE_FIXTURE(PoolingWithPadFixture, "ValidAvgWithPadTest")
225 {
226  RunTest<4>({{"Input", {1.0f, 2.0f, 3.0f, -4.0f}}}, {{"Output", {1.0/8.0}}});
227 }
228 
229 struct GlobalAvgFixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxParser::IOnnxParser>
230 {
231  GlobalAvgFixture()
232  {
233  m_Prototext = R"(
234  ir_version: 3
235  producer_name: "CNTK"
236  producer_version: "2.5.1"
237  domain: "ai.cntk"
238  model_version: 1
239  graph {
240  name: "CNTKGraph"
241  input {
242  name: "Input"
243  type {
244  tensor_type {
245  elem_type: 1
246  shape {
247  dim {
248  dim_value: 1
249  }
250  dim {
251  dim_value: 2
252  }
253  dim {
254  dim_value: 2
255  }
256  dim {
257  dim_value: 2
258  }
259  }
260  }
261  }
262  }
263  node {
264  input: "Input"
265  output: "Output"
266  name: "Pooling"
267  op_type: "GlobalAveragePool"
268  }
269  output {
270  name: "Output"
271  type {
272  tensor_type {
273  elem_type: 1
274  shape {
275  dim {
276  dim_value: 1
277  }
278  dim {
279  dim_value: 2
280  }
281  dim {
282  dim_value: 1
283  }
284  dim {
285  dim_value: 1
286  }
287  }
288  }
289  }
290  }
291  }
292  opset_import {
293  version: 7
294  })";
295  Setup();
296  }
297 };
298 
299 TEST_CASE_FIXTURE(GlobalAvgFixture, "GlobalAvgTest")
300 {
301  RunTest<4>({{"Input", {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0}}}, {{"Output", {10/4.0, 26/4.0}}});
302 }
303 
304 TEST_CASE_FIXTURE(MaxPoolInvalidFixture, "IncorrectDataTypeMaxPool")
305 {
306  CHECK_THROWS_AS(Setup(), armnn::ParseException);
307 }
308 
309 }
TEST_CASE_FIXTURE(ClContextControlFixture, "CopyBetweenNeonAndGpu")
TEST_SUITE("OnnxParser_Pooling")
Definition: Pooling.cpp:9