// // Copyright © 2017 Arm Ltd. All rights reserved. // SPDX-License-Identifier: MIT // #include #include "armnnTfParser/ITfParser.hpp" #include "ParserPrototxtFixture.hpp" BOOST_AUTO_TEST_SUITE(TensorflowParser) struct Pooling2dFixture : public armnnUtils::ParserPrototxtFixture { explicit Pooling2dFixture(const char* poolingtype, std::string dataLayout, std::string paddingOption) { m_Prototext = "node {\n" " name: \"Placeholder\"\n" " op: \"Placeholder\"\n" " attr {\n" " key: \"dtype\"\n" " value {\n" " type: DT_FLOAT\n" " }\n" " }\n" " attr {\n" " key: \"value\"\n" " value {\n" " tensor {\n" " dtype: DT_FLOAT\n" " tensor_shape {\n" " }\n" " }\n" " }\n" " }\n" " }\n" "node {\n" " name: \""; m_Prototext.append(poolingtype); m_Prototext.append("\"\n" " op: \""); m_Prototext.append(poolingtype); m_Prototext.append("\"\n" " input: \"Placeholder\"\n" " attr {\n" " key: \"T\"\n" " value {\n" " type: DT_FLOAT\n" " }\n" " }\n" " attr {\n" " key: \"data_format\"\n" " value {\n" " s: \""); m_Prototext.append(dataLayout); m_Prototext.append("\"\n" " }\n" " }\n" " attr {\n" " key: \"ksize\"\n" " value {\n" " list {\n" " i: 1\n"); if(dataLayout == "NHWC") { m_Prototext.append(" i: 2\n" " i: 2\n" " i: 1\n"); } else { m_Prototext.append(" i: 1\n" " i: 2\n" " i: 2\n"); } m_Prototext.append( " }\n" " }\n" " }\n" " attr {\n" " key: \"padding\"\n" " value {\n" " s: \""); m_Prototext.append(paddingOption); m_Prototext.append( "\"\n" " }\n" " }\n" " attr {\n" " key: \"strides\"\n" " value {\n" " list {\n" " i: 1\n" " i: 1\n" " i: 1\n" " i: 1\n" " }\n" " }\n" " }\n" "}\n"); if(dataLayout == "NHWC") { SetupSingleInputSingleOutput({ 1, 2, 2, 1 }, "Placeholder", poolingtype); } else { SetupSingleInputSingleOutput({ 1, 1, 2, 2 }, "Placeholder", poolingtype); } } }; struct MaxPoolFixtureNhwcValid : Pooling2dFixture { MaxPoolFixtureNhwcValid() : Pooling2dFixture("MaxPool", "NHWC", "VALID") {} }; BOOST_FIXTURE_TEST_CASE(ParseMaxPoolNhwcValid, MaxPoolFixtureNhwcValid) { RunTest<4>({1.0f, 2.0f, 3.0f, -4.0f}, {3.0f}); } struct MaxPoolFixtureNchwValid : Pooling2dFixture { MaxPoolFixtureNchwValid() : Pooling2dFixture("MaxPool", "NCHW", "VALID") {} }; BOOST_FIXTURE_TEST_CASE(ParseMaxPoolNchwValid, MaxPoolFixtureNchwValid) { RunTest<4>({1.0f, 2.0f, 3.0f, -4.0f}, {3.0f}); } struct MaxPoolFixtureNhwcSame : Pooling2dFixture { MaxPoolFixtureNhwcSame() : Pooling2dFixture("MaxPool", "NHWC", "SAME") {} }; BOOST_FIXTURE_TEST_CASE(ParseMaxPoolNhwcSame, MaxPoolFixtureNhwcSame) { RunTest<4>({1.0f, 2.0f, 3.0f, -4.0f}, {3.0f, 2.0f, 3.0f, -4.0f}); } struct MaxPoolFixtureNchwSame : Pooling2dFixture { MaxPoolFixtureNchwSame() : Pooling2dFixture("MaxPool", "NCHW", "SAME") {} }; BOOST_FIXTURE_TEST_CASE(ParseMaxPoolNchwSame, MaxPoolFixtureNchwSame) { RunTest<4>({1.0f, 2.0f, 3.0f, -4.0f}, {3.0f, 2.0f, 3.0f, -4.0f}); } struct AvgPoolFixtureNhwcValid : Pooling2dFixture { AvgPoolFixtureNhwcValid() : Pooling2dFixture("AvgPool", "NHWC", "VALID") {} }; BOOST_FIXTURE_TEST_CASE(ParseAvgPoolNhwcValid, AvgPoolFixtureNhwcValid) { RunTest<4>({1.0f, 2.0f, 3.0f, 4.0f}, {2.5f}); } struct AvgPoolFixtureNchwValid : Pooling2dFixture { AvgPoolFixtureNchwValid() : Pooling2dFixture("AvgPool", "NCHW", "VALID") {} }; BOOST_FIXTURE_TEST_CASE(ParseAvgPoolNchwValid, AvgPoolFixtureNchwValid) { RunTest<4>({1.0f, 2.0f, 3.0f, 4.0f}, {2.5f}); } struct AvgPoolFixtureNhwcSame : Pooling2dFixture { AvgPoolFixtureNhwcSame() : Pooling2dFixture("AvgPool", "NHWC", "SAME") {} }; BOOST_FIXTURE_TEST_CASE(ParseAvgPoolNhwcSame, AvgPoolFixtureNhwcSame) { RunTest<4>({1.0f, 2.0f, 3.0f, 4.0f}, {2.5f, 3.0f, 3.5f, 4.0f}); } struct AvgPoolFixtureNchwSame : Pooling2dFixture { AvgPoolFixtureNchwSame() : Pooling2dFixture("AvgPool", "NCHW", "SAME") {} }; BOOST_FIXTURE_TEST_CASE(ParseAvgPoolNchwSame, AvgPoolFixtureNchwSame) { RunTest<4>({1.0f, 2.0f, 3.0f, 4.0f}, {2.5f, 3.0f, 3.5f, 4.0f}); } BOOST_AUTO_TEST_SUITE_END()