// ArmNN 20.05: DepthwiseConvolution2d.cpp (TensorFlow parser unit tests)
//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//

10 #include <armnnUtils/Permute.hpp>
11 
12 #include <boost/test/unit_test.hpp>
13 
14 #include <string>
15 #include <iostream>
16 
17 using namespace armnnUtils;
18 using namespace armnn;
19 
20 BOOST_AUTO_TEST_SUITE(TensorflowParser)
21 
22 struct DepthwiseConvolution2dFixture : public armnnUtils::ParserPrototxtFixture<armnnTfParser::ITfParser>
23 {
24  explicit DepthwiseConvolution2dFixture(const std::string& dataLayout, const char* paddingType)
25  {
26  m_Prototext = "node { \n"
27  " name: \"graphInput\" \n"
28  " op: \"Placeholder\" \n"
29  " attr { \n"
30  " key: \"dtype\" \n"
31  " value { \n"
32  " type: DT_FLOAT \n"
33  " } \n"
34  " } \n"
35  " attr { \n"
36  " key: \"shape\" \n"
37  " value { \n"
38  " shape { \n"
39  " } \n"
40  " } \n"
41  " } \n"
42  " } \n"
43  " node { \n"
44  " name: \"Const_1\" \n"
45  " op: \"Const\" \n"
46  " attr { \n"
47  " key: \"dtype\" \n"
48  " value { \n"
49  " type: DT_FLOAT \n"
50  " } \n"
51  " } \n"
52  " attr { \n"
53  " key: \"value\" \n"
54  " value { \n"
55  " tensor { \n"
56  " dtype: DT_FLOAT \n"
57  " tensor_shape { \n"
58  " dim { \n"
59  " size: 1 \n"
60  " } \n"
61  " dim { \n"
62  " size: 3 \n"
63  " } \n"
64  " dim { \n"
65  " size: 3 \n"
66  " } \n"
67  " dim { \n"
68  " size: 3 \n"
69  " } \n"
70  " } \n"
71  " tensor_content: \"\\000\\000\\000?\\000\\000\\200?\\000\\000\\000?"
72  "\\000\\000\\000?\\000\\000\\200?\\000\\000\\000?"
73  "\\000\\000\\000?\\000\\000\\200?\\000\\000\\000?"
74  "\\000\\000\\000?\\000\\000\\200?\\000\\000\\000?"
75  "\\000\\000\\000?\\000\\000\\200?\\000\\000\\000?"
76  "\\000\\000\\000?\\000\\000\\200?\\000\\000\\000?"
77  "\\000\\000\\000?\\000\\000\\200?\\000\\000\\000?"
78  "\\000\\000\\000?\\000\\000\\200?\\000\\000\\000?"
79  "\\000\\000\\000?\\000\\000\\200?\\000\\000\\000?\" \n"
80  " } \n"
81  " } \n"
82  " } \n"
83  "} \n"
84  "node { \n"
85  " name: \"potato\" \n"
86  " op: \"DepthwiseConv2dNative\" \n"
87  " input: \"graphInput\" \n"
88  " input: \"Const_1\" \n"
89  " attr { \n"
90  " key: \"T\" \n"
91  " value { \n"
92  " type: DT_FLOAT \n"
93  " } \n"
94  " } \n"
95  " attr { \n"
96  " key: \"data_format\" \n"
97  " value { \n"
98  " s: \"";
99  m_Prototext.append(dataLayout);
100  m_Prototext.append("\"\n"
101  " } \n"
102  " } \n"
103  " attr { \n"
104  " key: \"padding\" \n"
105  " value { \n"
106  " s: \"");
107  m_Prototext.append(paddingType);
108  m_Prototext.append("\"\n"
109  " } \n"
110  " } \n"
111  " attr { \n"
112  " key: \"strides\" \n"
113  " value { \n"
114  " list { \n"
115  " i: 1 \n"
116  " i: 1 \n"
117  " i: 1 \n"
118  " i: 1 \n"
119  " } \n"
120  " } \n"
121  " } \n"
122  " attr { \n"
123  " key: \"use_cudnn_on_gpu\" \n"
124  " value { \n"
125  " b: false \n"
126  " } \n"
127  " } \n"
128  "} \n");
129 
130  if(dataLayout == "NHWC")
131  {
132  SetupSingleInputSingleOutput({ 1u, 1u, 3u, 3u }, "graphInput", "potato");
133  }
134  else
135  {
136  SetupSingleInputSingleOutput({ 1u, 3u, 1u, 3u }, "graphInput", "potato");
137  }
138  }
139 };
140 
141 struct DepthwiseConvolution2dNhwcSameFixture : DepthwiseConvolution2dFixture
142 {
143  DepthwiseConvolution2dNhwcSameFixture() : DepthwiseConvolution2dFixture("NHWC", "SAME") { }
144 };
145 
146 BOOST_FIXTURE_TEST_CASE(ParseDepthwiseConv2DNhwcSame, DepthwiseConvolution2dNhwcSameFixture)
147 {
148  RunTest<4>({ 1, 2, 3, 4, 5, 6, 7, 8, 9 },
149  { 2.5f, 5.f, 2.5f, 3.5f, 7.f, 3.5f, 4.5f, 9.f, 4.5f,
150  6.f, 12.f, 6.f, 7.5f, 15.f, 7.5f, 9.f, 18.f, 9.f,
151  5.5f, 11.f, 5.5f, 6.5f, 13.f, 6.5f, 7.5f, 15.f, 7.5f });
152 }
153 
154 struct DepthwiseConvolution2dNchwSameFixture : DepthwiseConvolution2dFixture
155 {
156  DepthwiseConvolution2dNchwSameFixture() : DepthwiseConvolution2dFixture("NCHW", "SAME") { }
157 };
158 
159 BOOST_FIXTURE_TEST_CASE(ParseDepthwiseConv2DNchwSame, DepthwiseConvolution2dNchwSameFixture)
160 {
161  RunTest<4>({ 1, 4, 7, 2, 5, 8, 3, 6, 9 },
162  { 2.5f, 6.f, 5.5f, 5.f, 12.f, 11.f, 2.5f, 6.f, 5.5f,
163  3.5f, 7.5f, 6.5f, 7.f, 15.f, 13.f, 3.5f, 7.5f, 6.5f,
164  4.5f, 9.f, 7.5f, 9.f, 18.f, 15.f, 4.5f, 9.f, 7.5f });
165 }
166 
167 struct DepthwiseConvolution2dNhwcValidFixture : DepthwiseConvolution2dFixture
168 {
169  DepthwiseConvolution2dNhwcValidFixture() : DepthwiseConvolution2dFixture("NHWC", "VALID") { }
170 };
171 
172 BOOST_FIXTURE_TEST_CASE(ParseDepthwiseConv2DNhwcValid, DepthwiseConvolution2dNhwcValidFixture)
173 {
174  RunTest<4>({ 1, 2, 3, 4, 5, 6, 7, 8, 9 }, // input data
175  { 6.f, 12.f, 6.f, 7.5f, 15.f, 7.5f, 9.f, 18.f, 9.f }); // output expected data
176 }
177 
178 struct DepthwiseConvolution2dNchwValidFixture : DepthwiseConvolution2dFixture
179 {
180  DepthwiseConvolution2dNchwValidFixture() : DepthwiseConvolution2dFixture("NCHW", "VALID") { }
181 };
182 
183 BOOST_FIXTURE_TEST_CASE(ParseDepthwiseConv2DNchwValid, DepthwiseConvolution2dNchwValidFixture)
184 {
185  RunTest<4>({ 1, 4, 7, 2, 5, 8, 3, 6, 9 },
186  { 6.f, 12.f, 6.f, 7.5f, 15.f, 7.5f, 9.f, 18.f, 9.f });
187 }
188 
189 
BOOST_AUTO_TEST_SUITE(TensorflowLiteParser)
Copyright (c) 2020 ARM Limited.
BOOST_AUTO_TEST_SUITE_END()
BOOST_FIXTURE_TEST_CASE(ParseDepthwiseConv2DNhwcSame, DepthwiseConvolution2dNhwcSameFixture)