ArmNN
 21.02
Convolution2d.cpp
Go to the documentation of this file.
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
6 #include <boost/test/unit_test.hpp>
9 
10 #include <array>
11 #include <string>
12 #include <iostream>
13 
14 BOOST_AUTO_TEST_SUITE(TensorflowParser)
15 
16 struct Convolution2dFixture : public armnnUtils::ParserPrototxtFixture<armnnTfParser::ITfParser>
17 {
18  explicit Convolution2dFixture(const std::string& dataLayout, const std::string& paddingType)
19  : Convolution2dFixture(dataLayout, paddingType, 1)
20  {}
21 
22  // Dilation: 0 - dilations attribute is not included;
23  // Dilation: >0 - dilations attribute set to [1,v,v,1], where v is the value of the dilation arg
24  explicit Convolution2dFixture(const std::string& dataLayout, const std::string& paddingType,
25  int stride, int dilation = 0)
26  {
27  std::string strideString (" i: 1 \n"
28  " i: 1 \n");
29  if (dataLayout == "NHWC")
30  {
31  strideString.append(" i: " + std::to_string(stride) + " \n"
32  " i: 1 \n");
33  }
34  else // dataLayout == "NCHW"
35  {
36  strideString.append(" i: 1 \n"
37  " i: " + std::to_string(stride) + " \n");
38  }
39 
40  std::string dilationString;
41  if (dataLayout == "NHWC")
42  {
43  dilationString.append(" i: 1 \n"
44  " i: " + std::to_string(dilation) + " \n"
45  " i: " + std::to_string(dilation) + " \n"
46  " i: 1 \n");
47  }
48  else // dataLayout == "NCHW"
49  {
50  dilationString.append(" i: 1 \n"
51  " i: 1 \n"
52  " i: " + std::to_string(dilation) + " \n"
53  " i: " + std::to_string(dilation) + " \n");
54  }
55 
56  m_Prototext = "node { \n"
57  " name: \"graphInput\" \n"
58  " op: \"Placeholder\" \n"
59  " attr { \n"
60  " key: \"dtype\" \n"
61  " value { \n"
62  " type: DT_FLOAT \n"
63  " } \n"
64  " } \n"
65  " attr { \n"
66  " key: \"shape\" \n"
67  " value { \n"
68  " shape { \n"
69  " } \n"
70  " } \n"
71  " } \n"
72  " } \n"
73  " node { \n"
74  " name: \"Const_1\" \n"
75  " op: \"Const\" \n"
76  " attr { \n"
77  " key: \"dtype\" \n"
78  " value { \n"
79  " type: DT_FLOAT \n"
80  " } \n"
81  " } \n"
82  " attr { \n"
83  " key: \"value\" \n"
84  " value { \n"
85  " tensor { \n"
86  " dtype: DT_FLOAT \n"
87  " tensor_shape { \n"
88  " dim { \n"
89  " size: 1 \n"
90  " } \n"
91  " dim { \n"
92  " size: 3 \n"
93  " } \n"
94  " dim { \n"
95  " size: 1 \n"
96  " } \n"
97  " dim { \n"
98  " size: 1 \n"
99  " } \n"
100  " } \n"
101  " tensor_content: \"\\000\\000\\000?\\000\\000\\200?\\000\\000\\000?\" \n"
102  " } \n"
103  " } \n"
104  " } \n"
105  "} \n"
106  "node { \n"
107  " name: \"potato\" \n"
108  " op: \"Conv2D\" \n"
109  " input: \"graphInput\" \n"
110  " input: \"Const_1\" \n"
111  " attr { \n"
112  " key: \"T\" \n"
113  " value { \n"
114  " type: DT_FLOAT \n"
115  " } \n"
116  " } \n"
117  " attr { \n"
118  " key: \"data_format\" \n"
119  " value { \n"
120  " s: \"";
121  m_Prototext.append(dataLayout);
122  m_Prototext.append("\"\n"
123  " } \n"
124  " } \n"
125  " attr { \n"
126  " key: \"padding\" \n"
127  " value { \n"
128  " s: \"");
129  m_Prototext.append(paddingType);
130  m_Prototext.append("\"\n"
131  " } \n"
132  " } \n"
133  " attr { \n"
134  " key: \"strides\" \n"
135  " value { \n"
136  " list { \n");
137  m_Prototext.append(strideString);
138 
139  m_Prototext.append(" } \n"
140  " } \n"
141  " } \n");
142 
143  if (dilation > 0)
144  {
145  m_Prototext.append(" attr { \n"
146  " key: \"dilations\" \n"
147  " value { \n"
148  " list { \n");
149  m_Prototext.append(dilationString);
150 
151  m_Prototext.append(" } \n"
152  " } \n"
153  " } \n");
154  }
155  m_Prototext.append(" attr { \n"
156  " key: \"use_cudnn_on_gpu\" \n"
157  " value { \n"
158  " b: false \n"
159  " } \n"
160  " } \n"
161  "} \n");
162 
163  // Manual height computation based on stride parameter.
164  ARMNN_ASSERT_MSG(stride == 1 || stride == 2, "Add support for strides other than 1 or 2.");
165  std::array<unsigned int, 4> dims;
166  if (dataLayout == "NHWC")
167  {
168  dims = { 1u, (stride == 2 ? 3u : 2u), 3u, 1u };
169  }
170  else // dataLayout == "NCHW"
171  {
172  dims = { 1u, 1u, (stride == 2 ? 3u : 2u), 3u };
173  }
174 
175  SetupSingleInputSingleOutput(armnn::TensorShape(4, dims.data()), "graphInput", "potato");
176  }
177 };
178 
179 struct Convolution2dNhwcSameFixture : Convolution2dFixture
180 {
181  Convolution2dNhwcSameFixture() : Convolution2dFixture("NHWC", "SAME", 1){}
182 };
183 BOOST_FIXTURE_TEST_CASE(ParseConv2dNhwcSame, Convolution2dNhwcSameFixture)
184 {
185  RunTest<4>({1, 2, 3, 4, 5, 6}, {2, 4, 4, 6.5f, 10 , 8.5f});
186 }
187 
188 struct Convolution2dNchwSameFixture : Convolution2dFixture
189 {
190  Convolution2dNchwSameFixture() : Convolution2dFixture("NCHW", "SAME", 1){}
191 };
192 BOOST_FIXTURE_TEST_CASE(ParseConv2dNchwSame, Convolution2dNchwSameFixture)
193 {
194  RunTest<4>({1, 2, 3, 4, 5, 6}, {2, 4, 4, 6.5f, 10 , 8.5f});
195 }
196 
197 
198 struct Convolution2dNhwcValidFixture : Convolution2dFixture
199 {
200  Convolution2dNhwcValidFixture() : Convolution2dFixture("NHWC", "VALID", 1){}
201 };
202 BOOST_FIXTURE_TEST_CASE(ParseConv2dNhwcValid, Convolution2dNhwcValidFixture)
203 {
204  RunTest<4>({1, 2, 3, 4, 5, 6}, {4, 10});
205 }
206 
207 struct Convolution2dNchwValidFixture : Convolution2dFixture
208 {
209  Convolution2dNchwValidFixture() : Convolution2dFixture("NCHW", "VALID", 1){}
210 };
211 BOOST_FIXTURE_TEST_CASE(ParseConv2dNchwValid, Convolution2dNchwValidFixture)
212 {
213  RunTest<4>({1, 2, 3, 4, 5, 6}, {4, 10});
214 }
215 
216 
217 struct Convolution2dStride2NhwcSameFixture : Convolution2dFixture
218 {
219  Convolution2dStride2NhwcSameFixture() : Convolution2dFixture("NHWC", "SAME", 2){}
220 };
221 BOOST_FIXTURE_TEST_CASE(ParseConv2dStride2NhwcSame, Convolution2dStride2NhwcSameFixture)
222 {
223  RunTest<4>({1, 2, 3, 4, 5, 6, 7, 8, 9}, {2, 4, 6.5, 8.5, 11, 13});
224 }
225 
226 struct Convolution2dStride2NchwSameFixture : Convolution2dFixture
227 {
228  Convolution2dStride2NchwSameFixture() : Convolution2dFixture("NCHW", "SAME", 2){}
229 };
230 BOOST_FIXTURE_TEST_CASE(ParseConv2dStride2NchwSame, Convolution2dStride2NchwSameFixture)
231 {
232  RunTest<4>({1, 2, 3, 4, 5, 6, 7, 8, 9}, {2, 4, 6.5, 8.5, 11, 13});
233 }
234 
235 
236 struct Convolution2dStride2NhwcValidFixture : Convolution2dFixture
237 {
238  Convolution2dStride2NhwcValidFixture() : Convolution2dFixture("NHWC", "VALID", 2){}
239 };
240 BOOST_FIXTURE_TEST_CASE(ParseConv2dStride2NhwcValid, Convolution2dStride2NhwcValidFixture)
241 {
242  RunTest<4>({1, 2, 3, 4, 5, 6, 7, 8, 9}, {4, 10, 16});
243 }
244 
245 struct Convolution2dStride2NchwValidFixture : Convolution2dFixture
246 {
247  Convolution2dStride2NchwValidFixture() : Convolution2dFixture("NCHW", "VALID", 2){}
248 };
249 BOOST_FIXTURE_TEST_CASE(ParseConv2dStride2NchwValid, Convolution2dStride2NchwValidFixture)
250 {
251  RunTest<4>({1, 2, 3, 4, 5, 6, 7, 8, 9}, {4, 10, 16});
252 }
253 
254 
255 struct Convolution2dDilation1NhwcFixture : Convolution2dFixture
256 {
257  Convolution2dDilation1NhwcFixture() : Convolution2dFixture("NHWC", "SAME", 1, 1){}
258 };
259 BOOST_FIXTURE_TEST_CASE(ParseConv2dDilation1Nhwc, Convolution2dDilation1NhwcFixture)
260 {
261  RunTest<4>({1, 2, 3, 4, 5, 6}, {2, 4, 4, 6.5f, 10 , 8.5f});
262 }
263 
264 struct Convolution2dDilation1NchwFixture : Convolution2dFixture
265 {
266  Convolution2dDilation1NchwFixture() : Convolution2dFixture("NCHW", "SAME", 1, 1){}
267 };
268 BOOST_FIXTURE_TEST_CASE(ParseConv2dDilation1Nchw, Convolution2dDilation1NchwFixture)
269 {
270  RunTest<4>({1, 2, 3, 4, 5, 6}, {2, 4, 4, 6.5f, 10 , 8.5f});
271 }
272 
273 struct Convolution2dDilationFixture : public armnnUtils::ParserPrototxtFixture<armnnTfParser::ITfParser>
274 {
275  explicit Convolution2dDilationFixture(const std::string& dataLayout, const std::string& paddingType)
276  : Convolution2dDilationFixture(dataLayout, paddingType, 1)
277  {}
278 
279  explicit Convolution2dDilationFixture(const std::string& dataLayout, const std::string& paddingType,
280  int stride, int dilation = 0)
281  {
282  std::string strideString;
283  if (dataLayout == "NHWC")
284  {
285  strideString.append(" i: 1 \n"
286  " i: " + std::to_string(stride) + " \n"
287  " i: " + std::to_string(stride) + " \n"
288  " i: 1 \n");
289  }
290  else // dataLayout == "NCHW"
291  {
292  strideString.append(" i: 1 \n"
293  " i: 1 \n"
294  " i: " + std::to_string(stride) + " \n"
295  " i: " + std::to_string(stride) + " \n");
296  }
297 
298  std::string dilationString;
299  if (dataLayout == "NHWC")
300  {
301  dilationString.append(" i: 1 \n"
302  " i: " + std::to_string(dilation) + " \n"
303  " i: " + std::to_string(dilation) + " \n"
304  " i: 1 \n");
305  }
306  else // dataLayout == "NCHW"
307  {
308  dilationString.append(" i: 1 \n"
309  " i: 1 \n"
310  " i: " + std::to_string(dilation) + " \n"
311  " i: " + std::to_string(dilation) + " \n");
312  }
313 
314  m_Prototext = "node { \n"
315  " name: \"graphInput\" \n"
316  " op: \"Placeholder\" \n"
317  " attr { \n"
318  " key: \"dtype\" \n"
319  " value { \n"
320  " type: DT_FLOAT \n"
321  " } \n"
322  " } \n"
323  " attr { \n"
324  " key: \"shape\" \n"
325  " value { \n"
326  " shape { \n"
327  " } \n"
328  " } \n"
329  " } \n"
330  " } \n"
331  " node { \n"
332  " name: \"Const_1\" \n"
333  " op: \"Const\" \n"
334  " attr { \n"
335  " key: \"dtype\" \n"
336  " value { \n"
337  " type: DT_FLOAT \n"
338  " } \n"
339  " } \n"
340  " attr { \n"
341  " key: \"value\" \n"
342  " value { \n"
343  " tensor { \n"
344  " dtype: DT_FLOAT \n"
345  " tensor_shape { \n"
346  " dim { \n"
347  " size: 3 \n"
348  " } \n"
349  " dim { \n"
350  " size: 1 \n"
351  " } \n"
352  " dim { \n"
353  " size: 1 \n"
354  " } \n"
355  " dim { \n"
356  " size: 1 \n"
357  " } \n"
358  " } \n"
359  " tensor_content: \"\\001\\000\\000?\\000\\000\\000?\\001\\000\\000?\" \n"
360  " } \n"
361  " } \n"
362  " } \n"
363  "} \n"
364  "node { \n"
365  " name: \"potato\" \n"
366  " op: \"Conv2D\" \n"
367  " input: \"graphInput\" \n"
368  " input: \"Const_1\" \n"
369  " attr { \n"
370  " key: \"T\" \n"
371  " value { \n"
372  " type: DT_FLOAT \n"
373  " } \n"
374  " } \n"
375  " attr { \n"
376  " key: \"data_format\" \n"
377  " value { \n"
378  " s: \"";
379  m_Prototext.append(dataLayout);
380  m_Prototext.append("\"\n"
381  " } \n"
382  " } \n"
383  " attr { \n"
384  " key: \"padding\" \n"
385  " value { \n"
386  " s: \"");
387  m_Prototext.append(paddingType);
388  m_Prototext.append("\"\n"
389  " } \n"
390  " } \n"
391  " attr { \n"
392  " key: \"strides\" \n"
393  " value { \n"
394  " list { \n");
395  m_Prototext.append(strideString);
396 
397  m_Prototext.append(" } \n"
398  " } \n"
399  " } \n");
400 
401  if (dilation > 0)
402  {
403  m_Prototext.append(" attr { \n"
404  " key: \"dilations\" \n"
405  " value { \n"
406  " list { \n");
407  m_Prototext.append(dilationString);
408 
409  m_Prototext.append(" } \n"
410  " } \n"
411  " } \n");
412  }
413  m_Prototext.append(" attr { \n"
414  " key: \"use_cudnn_on_gpu\" \n"
415  " value { \n"
416  " b: false \n"
417  " } \n"
418  " } \n"
419  "} \n");
420 
421  // Manual height computation based on stride parameter.
422  std::array<unsigned int, 4> dims = { 1u, 1u, 6u, 6u };;
423 
424  SetupSingleInputSingleOutput(armnn::TensorShape(4, dims.data()), "graphInput", "potato");
425  }
426 };
427 
428 struct Convolution2dDilation2NchwValidFixture : Convolution2dDilationFixture
429 {
430  Convolution2dDilation2NchwValidFixture() : Convolution2dDilationFixture("NCHW", "VALID", 1, 2){}
431 };
432 BOOST_FIXTURE_TEST_CASE(ParseConv2dDilation2NchwValid, Convolution2dDilation2NchwValidFixture)
433 {
434  RunTest<4>({1.0, 2.0, 3.0, 4.0, 5.0, 6.0,
435  7.0, 8.0, 9.0, 10.0, 11.0, 12.0,
436  1.0, 2.0, 3.0, 4.0, 5.0, 6.0,
437  7.0, 8.0, 9.0, 10.0, 11.0, 12.0,
438  1.0, 2.0, 3.0, 4.0, 5.0, 6.0,
439  7.0, 8.0, 9.0, 10.0, 11.0, 12.0},
440  {1.5f, 3.0f, 4.5f, 6.0f, 7.5f, 9.0f, 10.5f, 12.f, 13.5f, 15.0f, 16.5f, 18.0f});
441 }
442 
443 
BOOST_AUTO_TEST_SUITE(TensorflowLiteParser)
BOOST_FIXTURE_TEST_CASE(ParseConv2dNhwcSame, Convolution2dNhwcSameFixture)
#define ARMNN_ASSERT_MSG(COND, MSG)
Definition: Assert.hpp:15
BOOST_AUTO_TEST_SUITE_END()
void SetupSingleInputSingleOutput(const std::string &inputName, const std::string &outputName)
Parses and loads the network defined by the m_Prototext string.