From f005e313e501a28669e214fd05db04c12c8af7fc Mon Sep 17 00:00:00 2001 From: FrancisMurtagh Date: Thu, 6 Dec 2018 15:26:04 +0000 Subject: IVGCVSW-2277 Remove the input swizzling from ParsePooling2d * Remove input swizzling from ParsePooling2D and add parameterized tests for NCHW, NHWC and Padding="SAME". Change-Id: I4786fcc31b6ac46bf19d887f007963eb924f0f9f --- src/armnnTfParser/test/Pooling.cpp | 99 +++++++++++++++++++++++++++++++++----- 1 file changed, 87 insertions(+), 12 deletions(-) (limited to 'src/armnnTfParser/test') diff --git a/src/armnnTfParser/test/Pooling.cpp b/src/armnnTfParser/test/Pooling.cpp index 346599fb43..f6de44c95f 100644 --- a/src/armnnTfParser/test/Pooling.cpp +++ b/src/armnnTfParser/test/Pooling.cpp @@ -11,7 +11,7 @@ BOOST_AUTO_TEST_SUITE(TensorflowParser) struct Pooling2dFixture : public armnnUtils::ParserPrototxtFixture { - explicit Pooling2dFixture(const char* poolingtype) + explicit Pooling2dFixture(const char* poolingtype, std::string dataLayout, std::string paddingOption) { m_Prototext = "node {\n" " name: \"Placeholder\"\n" @@ -50,24 +50,40 @@ struct Pooling2dFixture : public armnnUtils::ParserPrototxtFixture({1.0f, 2.0f, 3.0f, -4.0f}, {3.0f}); } +struct MaxPoolFixtureNchwValid : Pooling2dFixture +{ + MaxPoolFixtureNchwValid() : Pooling2dFixture("MaxPool", "NCHW", "VALID") {} +}; +BOOST_FIXTURE_TEST_CASE(ParseMaxPoolNchwValid, MaxPoolFixtureNchwValid) +{ + RunTest<4>({1.0f, 2.0f, 3.0f, -4.0f}, {3.0f}); +} + +struct MaxPoolFixtureNhwcSame : Pooling2dFixture +{ + MaxPoolFixtureNhwcSame() : Pooling2dFixture("MaxPool", "NHWC", "SAME") {} +}; +BOOST_FIXTURE_TEST_CASE(ParseMaxPoolNhwcSame, MaxPoolFixtureNhwcSame) +{ + RunTest<4>({1.0f, 2.0f, 3.0f, -4.0f}, {3.0f, 2.0f, 3.0f, -4.0f}); +} + +struct MaxPoolFixtureNchwSame : Pooling2dFixture +{ + MaxPoolFixtureNchwSame() : Pooling2dFixture("MaxPool", "NCHW", "SAME") {} +}; +BOOST_FIXTURE_TEST_CASE(ParseMaxPoolNchwSame, MaxPoolFixtureNchwSame) +{ + RunTest<4>({1.0f, 2.0f, 3.0f, -4.0f}, {3.0f, 2.0f, 3.0f, -4.0f}); +} + +struct AvgPoolFixtureNhwcValid : Pooling2dFixture +{ + AvgPoolFixtureNhwcValid() : Pooling2dFixture("AvgPool", "NHWC", "VALID") {} +}; +BOOST_FIXTURE_TEST_CASE(ParseAvgPoolNhwcValid, AvgPoolFixtureNhwcValid) +{ + RunTest<4>({1.0f, 2.0f, 3.0f, 4.0f}, {2.5f}); +} -struct AvgPoolFixture : Pooling2dFixture +struct AvgPoolFixtureNchwValid : Pooling2dFixture { - AvgPoolFixture() : Pooling2dFixture("AvgPool") {} + AvgPoolFixtureNchwValid() : Pooling2dFixture("AvgPool", "NCHW", "VALID") {} }; -BOOST_FIXTURE_TEST_CASE(ParseAvgPool, AvgPoolFixture) +BOOST_FIXTURE_TEST_CASE(ParseAvgPoolNchwValid, AvgPoolFixtureNchwValid) { RunTest<4>({1.0f, 2.0f, 3.0f, 4.0f}, {2.5f}); } +struct AvgPoolFixtureNhwcSame : Pooling2dFixture +{ + AvgPoolFixtureNhwcSame() : Pooling2dFixture("AvgPool", "NHWC", "SAME") {} +}; +BOOST_FIXTURE_TEST_CASE(ParseAvgPoolNhwcSame, AvgPoolFixtureNhwcSame) +{ + RunTest<4>({1.0f, 2.0f, 3.0f, 4.0f}, {2.5f, 3.0f, 3.5f, 4.0f}); +} + +struct AvgPoolFixtureNchwSame : Pooling2dFixture +{ + AvgPoolFixtureNchwSame() : Pooling2dFixture("AvgPool", "NCHW", "SAME") {} +}; +BOOST_FIXTURE_TEST_CASE(ParseAvgPoolNchwSame, AvgPoolFixtureNchwSame) +{ + RunTest<4>({1.0f, 2.0f, 3.0f, 4.0f}, {2.5f, 3.0f, 3.5f, 4.0f}); +} BOOST_AUTO_TEST_SUITE_END() -- cgit v1.2.1