6 #include <boost/test/unit_test.hpp> 18 explicit Convolution2dFixture(
const std::string& dataLayout,
const std::string& paddingType)
19 : Convolution2dFixture(dataLayout, paddingType, 1)
24 explicit Convolution2dFixture(
const std::string& dataLayout,
const std::string& paddingType,
25 int stride,
int dilation = 0)
27 std::string strideString (
" i: 1 \n" 29 if (dataLayout ==
"NHWC")
31 strideString.append(
" i: " + std::to_string(stride) +
" \n" 36 strideString.append(
" i: 1 \n" 37 " i: " + std::to_string(stride) +
" \n");
40 std::string dilationString;
41 if (dataLayout ==
"NHWC")
43 dilationString.append(
" i: 1 \n" 44 " i: " + std::to_string(dilation) +
" \n" 45 " i: " + std::to_string(dilation) +
" \n" 50 dilationString.append(
" i: 1 \n" 52 " i: " + std::to_string(dilation) +
" \n" 53 " i: " + std::to_string(dilation) +
" \n");
57 " name: \"graphInput\" \n" 58 " op: \"Placeholder\" \n" 74 " name: \"Const_1\" \n" 101 " tensor_content: \"\\000\\000\\000?\\000\\000\\200?\\000\\000\\000?\" \n" 107 " name: \"potato\" \n" 109 " input: \"graphInput\" \n" 110 " input: \"Const_1\" \n" 118 " key: \"data_format\" \n" 126 " key: \"padding\" \n" 134 " key: \"strides\" \n" 146 " key: \"dilations\" \n" 156 " key: \"use_cudnn_on_gpu\" \n" 164 ARMNN_ASSERT_MSG(stride == 1 || stride == 2,
"Add support for strides other than 1 or 2.");
165 std::array<unsigned int, 4> dims;
166 if (dataLayout ==
"NHWC")
168 dims = { 1u, (stride == 2 ? 3u : 2u), 3u, 1u };
172 dims = { 1u, 1u, (stride == 2 ? 3u : 2u), 3u };
179 struct Convolution2dNhwcSameFixture : Convolution2dFixture
181 Convolution2dNhwcSameFixture() : Convolution2dFixture(
"NHWC",
"SAME", 1){}
185 RunTest<4>({1, 2, 3, 4, 5, 6}, {2, 4, 4, 6.5f, 10 , 8.5f});
188 struct Convolution2dNchwSameFixture : Convolution2dFixture
190 Convolution2dNchwSameFixture() : Convolution2dFixture(
"NCHW",
"SAME", 1){}
194 RunTest<4>({1, 2, 3, 4, 5, 6}, {2, 4, 4, 6.5f, 10 , 8.5f});
198 struct Convolution2dNhwcValidFixture : Convolution2dFixture
200 Convolution2dNhwcValidFixture() : Convolution2dFixture(
"NHWC",
"VALID", 1){}
204 RunTest<4>({1, 2, 3, 4, 5, 6}, {4, 10});
207 struct Convolution2dNchwValidFixture : Convolution2dFixture
209 Convolution2dNchwValidFixture() : Convolution2dFixture(
"NCHW",
"VALID", 1){}
213 RunTest<4>({1, 2, 3, 4, 5, 6}, {4, 10});
217 struct Convolution2dStride2NhwcSameFixture : Convolution2dFixture
219 Convolution2dStride2NhwcSameFixture() : Convolution2dFixture(
"NHWC",
"SAME", 2){}
223 RunTest<4>({1, 2, 3, 4, 5, 6, 7, 8, 9}, {2, 4, 6.5, 8.5, 11, 13});
226 struct Convolution2dStride2NchwSameFixture : Convolution2dFixture
228 Convolution2dStride2NchwSameFixture() : Convolution2dFixture(
"NCHW",
"SAME", 2){}
232 RunTest<4>({1, 2, 3, 4, 5, 6, 7, 8, 9}, {2, 4, 6.5, 8.5, 11, 13});
236 struct Convolution2dStride2NhwcValidFixture : Convolution2dFixture
238 Convolution2dStride2NhwcValidFixture() : Convolution2dFixture(
"NHWC",
"VALID", 2){}
242 RunTest<4>({1, 2, 3, 4, 5, 6, 7, 8, 9}, {4, 10, 16});
245 struct Convolution2dStride2NchwValidFixture : Convolution2dFixture
247 Convolution2dStride2NchwValidFixture() : Convolution2dFixture(
"NCHW",
"VALID", 2){}
251 RunTest<4>({1, 2, 3, 4, 5, 6, 7, 8, 9}, {4, 10, 16});
255 struct Convolution2dDilation1NhwcFixture : Convolution2dFixture
257 Convolution2dDilation1NhwcFixture() : Convolution2dFixture(
"NHWC",
"SAME", 1, 1){}
261 RunTest<4>({1, 2, 3, 4, 5, 6}, {2, 4, 4, 6.5f, 10 , 8.5f});
264 struct Convolution2dDilation1NchwFixture : Convolution2dFixture
266 Convolution2dDilation1NchwFixture() : Convolution2dFixture(
"NCHW",
"SAME", 1, 1){}
270 RunTest<4>({1, 2, 3, 4, 5, 6}, {2, 4, 4, 6.5f, 10 , 8.5f});
275 explicit Convolution2dDilationFixture(
const std::string& dataLayout,
const std::string& paddingType)
276 : Convolution2dDilationFixture(dataLayout, paddingType, 1)
279 explicit Convolution2dDilationFixture(
const std::string& dataLayout,
const std::string& paddingType,
280 int stride,
int dilation = 0)
282 std::string strideString;
283 if (dataLayout ==
"NHWC")
285 strideString.append(
" i: 1 \n" 286 " i: " + std::to_string(stride) +
" \n" 287 " i: " + std::to_string(stride) +
" \n" 292 strideString.append(
" i: 1 \n" 294 " i: " + std::to_string(stride) +
" \n" 295 " i: " + std::to_string(stride) +
" \n");
298 std::string dilationString;
299 if (dataLayout ==
"NHWC")
301 dilationString.append(
" i: 1 \n" 302 " i: " + std::to_string(dilation) +
" \n" 303 " i: " + std::to_string(dilation) +
" \n" 308 dilationString.append(
" i: 1 \n" 310 " i: " + std::to_string(dilation) +
" \n" 311 " i: " + std::to_string(dilation) +
" \n");
314 m_Prototext =
"node { \n" 315 " name: \"graphInput\" \n" 316 " op: \"Placeholder\" \n" 332 " name: \"Const_1\" \n" 344 " dtype: DT_FLOAT \n" 359 " tensor_content: \"\\001\\000\\000?\\000\\000\\000?\\001\\000\\000?\" \n" 365 " name: \"potato\" \n" 367 " input: \"graphInput\" \n" 368 " input: \"Const_1\" \n" 376 " key: \"data_format\" \n" 379 m_Prototext.append(dataLayout);
380 m_Prototext.append(
"\"\n" 384 " key: \"padding\" \n" 387 m_Prototext.append(paddingType);
388 m_Prototext.append(
"\"\n" 392 " key: \"strides\" \n" 395 m_Prototext.append(strideString);
397 m_Prototext.append(
" } \n" 403 m_Prototext.append(
" attr { \n" 404 " key: \"dilations\" \n" 407 m_Prototext.append(dilationString);
409 m_Prototext.append(
" } \n" 413 m_Prototext.append(
" attr { \n" 414 " key: \"use_cudnn_on_gpu\" \n" 422 std::array<unsigned int, 4> dims = { 1u, 1u, 6u, 6u };;
424 SetupSingleInputSingleOutput(
armnn::TensorShape(4, dims.data()),
"graphInput",
"potato");
428 struct Convolution2dDilation2NchwValidFixture : Convolution2dDilationFixture
430 Convolution2dDilation2NchwValidFixture() : Convolution2dDilationFixture(
"NCHW",
"VALID", 1, 2){}
434 RunTest<4>({1.0, 2.0, 3.0, 4.0, 5.0, 6.0,
435 7.0, 8.0, 9.0, 10.0, 11.0, 12.0,
436 1.0, 2.0, 3.0, 4.0, 5.0, 6.0,
437 7.0, 8.0, 9.0, 10.0, 11.0, 12.0,
438 1.0, 2.0, 3.0, 4.0, 5.0, 6.0,
439 7.0, 8.0, 9.0, 10.0, 11.0, 12.0},
440 {1.5f, 3.0f, 4.5f, 6.0f, 7.5f, 9.0f, 10.5f, 12.f, 13.5f, 15.0f, 16.5f, 18.0f});
BOOST_AUTO_TEST_SUITE(TensorflowLiteParser)
BOOST_FIXTURE_TEST_CASE(ParseConv2dNhwcSame, Convolution2dNhwcSameFixture)
#define ARMNN_ASSERT_MSG(COND, MSG)
BOOST_AUTO_TEST_SUITE_END()
void SetupSingleInputSingleOutput(const std::string &inputName, const std::string &outputName)
Parses and loads the network defined by the m_Prototext string.