ArmNN
 21.02
Pooling.cpp
Go to the documentation of this file.
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
6 #include <boost/test/unit_test.hpp>
9 
10 BOOST_AUTO_TEST_SUITE(TensorflowParser)
11 
12 struct Pooling2dFixture : public armnnUtils::ParserPrototxtFixture<armnnTfParser::ITfParser>
13 {
14  explicit Pooling2dFixture(const char* poolingtype, std::string dataLayout, std::string paddingOption)
15  {
16  m_Prototext = "node {\n"
17  " name: \"Placeholder\"\n"
18  " op: \"Placeholder\"\n"
19  " attr {\n"
20  " key: \"dtype\"\n"
21  " value {\n"
22  " type: DT_FLOAT\n"
23  " }\n"
24  " }\n"
25  " attr {\n"
26  " key: \"value\"\n"
27  " value {\n"
28  " tensor {\n"
29  " dtype: DT_FLOAT\n"
30  " tensor_shape {\n"
31  " }\n"
32  " }\n"
33  " }\n"
34  " }\n"
35  " }\n"
36  "node {\n"
37  " name: \"";
38  m_Prototext.append(poolingtype);
39  m_Prototext.append("\"\n"
40  " op: \"");
41  m_Prototext.append(poolingtype);
42  m_Prototext.append("\"\n"
43  " input: \"Placeholder\"\n"
44  " attr {\n"
45  " key: \"T\"\n"
46  " value {\n"
47  " type: DT_FLOAT\n"
48  " }\n"
49  " }\n"
50  " attr {\n"
51  " key: \"data_format\"\n"
52  " value {\n"
53  " s: \"");
54  m_Prototext.append(dataLayout);
55  m_Prototext.append("\"\n"
56  " }\n"
57  " }\n"
58  " attr {\n"
59  " key: \"ksize\"\n"
60  " value {\n"
61  " list {\n"
62 
63  " i: 1\n");
64  if(dataLayout == "NHWC")
65  {
66  m_Prototext.append(" i: 2\n"
67  " i: 2\n"
68  " i: 1\n");
69  }
70  else
71  {
72  m_Prototext.append(" i: 1\n"
73  " i: 2\n"
74  " i: 2\n");
75  }
76  m_Prototext.append(
77  " }\n"
78  " }\n"
79  " }\n"
80  " attr {\n"
81  " key: \"padding\"\n"
82  " value {\n"
83  " s: \"");
84  m_Prototext.append(paddingOption);
85  m_Prototext.append(
86  "\"\n"
87  " }\n"
88  " }\n"
89  " attr {\n"
90  " key: \"strides\"\n"
91  " value {\n"
92  " list {\n"
93  " i: 1\n"
94  " i: 1\n"
95  " i: 1\n"
96  " i: 1\n"
97  " }\n"
98  " }\n"
99  " }\n"
100  "}\n");
101 
102  if(dataLayout == "NHWC")
103  {
104  SetupSingleInputSingleOutput({ 1, 2, 2, 1 }, "Placeholder", poolingtype);
105  }
106  else
107  {
108  SetupSingleInputSingleOutput({ 1, 1, 2, 2 }, "Placeholder", poolingtype);
109  }
110  }
111 };
112 
113 
114 struct MaxPoolFixtureNhwcValid : Pooling2dFixture
115 {
116  MaxPoolFixtureNhwcValid() : Pooling2dFixture("MaxPool", "NHWC", "VALID") {}
117 };
118 BOOST_FIXTURE_TEST_CASE(ParseMaxPoolNhwcValid, MaxPoolFixtureNhwcValid)
119 {
120  RunTest<4>({1.0f, 2.0f, 3.0f, -4.0f}, {3.0f});
121 }
122 
123 struct MaxPoolFixtureNchwValid : Pooling2dFixture
124 {
125  MaxPoolFixtureNchwValid() : Pooling2dFixture("MaxPool", "NCHW", "VALID") {}
126 };
127 BOOST_FIXTURE_TEST_CASE(ParseMaxPoolNchwValid, MaxPoolFixtureNchwValid)
128 {
129  RunTest<4>({1.0f, 2.0f, 3.0f, -4.0f}, {3.0f});
130 }
131 
132 struct MaxPoolFixtureNhwcSame : Pooling2dFixture
133 {
134  MaxPoolFixtureNhwcSame() : Pooling2dFixture("MaxPool", "NHWC", "SAME") {}
135 };
136 BOOST_FIXTURE_TEST_CASE(ParseMaxPoolNhwcSame, MaxPoolFixtureNhwcSame)
137 {
138  RunTest<4>({1.0f, 2.0f, 3.0f, -4.0f}, {3.0f, 2.0f, 3.0f, -4.0f});
139 }
140 
141 struct MaxPoolFixtureNchwSame : Pooling2dFixture
142 {
143  MaxPoolFixtureNchwSame() : Pooling2dFixture("MaxPool", "NCHW", "SAME") {}
144 };
145 BOOST_FIXTURE_TEST_CASE(ParseMaxPoolNchwSame, MaxPoolFixtureNchwSame)
146 {
147  RunTest<4>({1.0f, 2.0f, 3.0f, -4.0f}, {3.0f, 2.0f, 3.0f, -4.0f});
148 }
149 
150 struct AvgPoolFixtureNhwcValid : Pooling2dFixture
151 {
152  AvgPoolFixtureNhwcValid() : Pooling2dFixture("AvgPool", "NHWC", "VALID") {}
153 };
154 BOOST_FIXTURE_TEST_CASE(ParseAvgPoolNhwcValid, AvgPoolFixtureNhwcValid)
155 {
156  RunTest<4>({1.0f, 2.0f, 3.0f, 4.0f}, {2.5f});
157 }
158 
159 struct AvgPoolFixtureNchwValid : Pooling2dFixture
160 {
161  AvgPoolFixtureNchwValid() : Pooling2dFixture("AvgPool", "NCHW", "VALID") {}
162 };
163 BOOST_FIXTURE_TEST_CASE(ParseAvgPoolNchwValid, AvgPoolFixtureNchwValid)
164 {
165  RunTest<4>({1.0f, 2.0f, 3.0f, 4.0f}, {2.5f});
166 }
167 
168 struct AvgPoolFixtureNhwcSame : Pooling2dFixture
169 {
170  AvgPoolFixtureNhwcSame() : Pooling2dFixture("AvgPool", "NHWC", "SAME") {}
171 };
172 BOOST_FIXTURE_TEST_CASE(ParseAvgPoolNhwcSame, AvgPoolFixtureNhwcSame)
173 {
174  RunTest<4>({1.0f, 2.0f, 3.0f, 4.0f}, {2.5f, 3.0f, 3.5f, 4.0f});
175 }
176 
177 struct AvgPoolFixtureNchwSame : Pooling2dFixture
178 {
179  AvgPoolFixtureNchwSame() : Pooling2dFixture("AvgPool", "NCHW", "SAME") {}
180 };
181 BOOST_FIXTURE_TEST_CASE(ParseAvgPoolNchwSame, AvgPoolFixtureNchwSame)
182 {
183  RunTest<4>({1.0f, 2.0f, 3.0f, 4.0f}, {2.5f, 3.0f, 3.5f, 4.0f});
184 }
185 
BOOST_AUTO_TEST_SUITE(TensorflowLiteParser)
BOOST_FIXTURE_TEST_CASE(ValidMaxPoolTest, MaxPoolValidFixture)
Definition: Pooling.cpp:113
BOOST_AUTO_TEST_SUITE_END()
void SetupSingleInputSingleOutput(const std::string &inputName, const std::string &outputName)
Parses and loads the network defined by the m_Prototext string.