From f005e313e501a28669e214fd05db04c12c8af7fc Mon Sep 17 00:00:00 2001 From: FrancisMurtagh Date: Thu, 6 Dec 2018 15:26:04 +0000 Subject: IVGCVSW-2277 Remove the input swizzling from ParsePooling2d * Remove input swizzling from ParsePooling2D and add parameterized tests for NCHW, NHWC and Padding="SAME". Change-Id: I4786fcc31b6ac46bf19d887f007963eb924f0f9f --- src/armnnTfParser/TfParser.cpp | 94 ++++++++++++++++++------------------ src/armnnTfParser/test/Pooling.cpp | 99 +++++++++++++++++++++++++++++++++----- 2 files changed, 133 insertions(+), 60 deletions(-) diff --git a/src/armnnTfParser/TfParser.cpp b/src/armnnTfParser/TfParser.cpp index c00722c4ad..d5372a598b 100644 --- a/src/armnnTfParser/TfParser.cpp +++ b/src/armnnTfParser/TfParser.cpp @@ -2304,68 +2304,73 @@ ParsedTfOperationPtr TfParser::ParsePooling2d(const tensorflow::NodeDef& nodeDef std::vector ksize = ReadMandatoryNodeUint32ListAttribute(nodeDef, "ksize"); // size of pool windows Pooling2dDescriptor pooling2dDescriptor; - pooling2dDescriptor.m_PoolType = pooltype; - pooling2dDescriptor.m_PaddingMethod = PaddingMethod::Exclude; + pooling2dDescriptor.m_PoolType = pooltype; + pooling2dDescriptor.m_PaddingMethod = PaddingMethod::Exclude; pooling2dDescriptor.m_OutputShapeRounding = OutputShapeRounding::Floor; CHECK_DATA_FORMAT(nodeDef, dataFormat, "Pooling2D"); + DataLayout dataLayout = dataFormat == "NHWC" ? DataLayout::NHWC : DataLayout::NCHW; + pooling2dDescriptor.m_DataLayout = dataLayout; + DataLayoutIndexed dataLayoutIndexed(dataLayout); - if (dataFormat == "NHWC") - { - pooling2dDescriptor.m_StrideX = strides[2]; - pooling2dDescriptor.m_StrideY = strides[1]; - pooling2dDescriptor.m_PoolWidth = ksize[2]; - pooling2dDescriptor.m_PoolHeight = ksize[1]; - // Swizzles input to supported memory layout. - inputTensorInfo = armnnUtils::Permuted(inputSlot.GetTensorInfo(), NHWCToArmNN); - } - else if (dataFormat == "NCHW") - { - pooling2dDescriptor.m_StrideX = strides[3]; - pooling2dDescriptor.m_StrideY = strides[2]; - pooling2dDescriptor.m_PoolWidth = ksize[3]; - pooling2dDescriptor.m_PoolHeight = ksize[2]; - } + pooling2dDescriptor.m_StrideX = strides[dataLayoutIndexed.GetWidthIndex()]; + pooling2dDescriptor.m_StrideY = strides[dataLayoutIndexed.GetHeightIndex()]; + pooling2dDescriptor.m_PoolWidth = ksize[dataLayoutIndexed.GetWidthIndex()]; + pooling2dDescriptor.m_PoolHeight = ksize[dataLayoutIndexed.GetHeightIndex()]; - uint32_t inputHeight = inputTensorInfo.GetShape()[2]; - uint32_t inputWidth = inputTensorInfo.GetShape()[3]; + uint32_t inputHeight = inputTensorInfo.GetShape()[dataLayoutIndexed.GetHeightIndex()]; + uint32_t inputWidth = inputTensorInfo.GetShape()[dataLayoutIndexed.GetWidthIndex()]; bool padding = false; TensorInfo outputInfo; + unsigned int outputHeight = 0; + unsigned int outputWidth = 0; CHECK_PADDING_TYPE(nodeDef, paddingString); if (paddingString == "SAME") { padding = true; - outputInfo = TensorInfo({ inputTensorInfo.GetShape()[0], - inputTensorInfo.GetShape()[1], - static_cast(ceil( - static_cast(inputHeight) / - static_cast(pooling2dDescriptor.m_StrideY))), - static_cast(ceil( - static_cast(inputWidth) / - static_cast(pooling2dDescriptor.m_StrideX))) - }, DataType::Float32); + + outputHeight = static_cast(ceil(static_cast(inputHeight) / + static_cast(pooling2dDescriptor.m_StrideY))); + outputWidth = static_cast(ceil(static_cast(inputWidth) / + static_cast(pooling2dDescriptor.m_StrideX))); } else if (paddingString == "VALID") { padding = false; - outputInfo = TensorInfo({ inputTensorInfo.GetShape()[0], - inputTensorInfo.GetShape()[1], - static_cast(ceil( - static_cast(inputHeight - pooling2dDescriptor.m_PoolHeight + 1) / - static_cast(pooling2dDescriptor.m_StrideY))), - static_cast(ceil( - static_cast(inputWidth - pooling2dDescriptor.m_PoolWidth + 1) / - static_cast(pooling2dDescriptor.m_StrideX))) - }, DataType::Float32); + + outputHeight = static_cast(ceil( + static_cast(inputHeight - pooling2dDescriptor.m_PoolHeight + 1) / + static_cast(pooling2dDescriptor.m_StrideY))); + outputWidth = static_cast(ceil( + static_cast(inputWidth - pooling2dDescriptor.m_PoolWidth + 1) / + static_cast(pooling2dDescriptor.m_StrideX))); + } + + switch (dataLayout) + { + case DataLayout::NHWC: + outputInfo = TensorInfo({ inputTensorInfo.GetShape()[0], + outputHeight, + outputWidth, + inputTensorInfo.GetShape()[3] }, + DataType::Float32); + break; + case DataLayout::NCHW: + outputInfo = TensorInfo({ inputTensorInfo.GetShape()[0], + inputTensorInfo.GetShape()[1], + outputHeight, + outputWidth }, + DataType::Float32); + break; } CalcPadding(inputWidth, pooling2dDescriptor.m_PoolWidth, pooling2dDescriptor.m_StrideX, - pooling2dDescriptor.m_PadLeft, pooling2dDescriptor.m_PadRight, padding); + pooling2dDescriptor.m_PadLeft, pooling2dDescriptor.m_PadRight, padding); CalcPadding(inputHeight, pooling2dDescriptor.m_PoolHeight, pooling2dDescriptor.m_StrideY, - pooling2dDescriptor.m_PadTop, pooling2dDescriptor.m_PadBottom, padding); + pooling2dDescriptor.m_PadTop, pooling2dDescriptor.m_PadBottom, padding); IConnectableLayer* layer = m_Network->AddPooling2dLayer(pooling2dDescriptor, nodeDef.name().c_str()); @@ -2381,14 +2386,7 @@ ParsedTfOperationPtr TfParser::ParsePooling2d(const tensorflow::NodeDef& nodeDef layer->GetOutputSlot(0).SetTensorInfo(outputInfo); - if (dataFormat == "NHWC") - { - layer = SwizzleInDeswizzleOut(*m_Network, inputSlot, *layer, nodeDef.name()); - } - else - { - inputSlot.Connect(layer->GetInputSlot(0)); - } + inputSlot.Connect(layer->GetInputSlot(0)); return std::make_unique(this, nodeDef, layer); } diff --git a/src/armnnTfParser/test/Pooling.cpp b/src/armnnTfParser/test/Pooling.cpp index 346599fb43..f6de44c95f 100644 --- a/src/armnnTfParser/test/Pooling.cpp +++ b/src/armnnTfParser/test/Pooling.cpp @@ -11,7 +11,7 @@ BOOST_AUTO_TEST_SUITE(TensorflowParser) struct Pooling2dFixture : public armnnUtils::ParserPrototxtFixture { - explicit Pooling2dFixture(const char* poolingtype) + explicit Pooling2dFixture(const char* poolingtype, std::string dataLayout, std::string paddingOption) { m_Prototext = "node {\n" " name: \"Placeholder\"\n" @@ -50,24 +50,40 @@ struct Pooling2dFixture : public armnnUtils::ParserPrototxtFixture({1.0f, 2.0f, 3.0f, -4.0f}, {3.0f}); } +struct MaxPoolFixtureNchwValid : Pooling2dFixture +{ + MaxPoolFixtureNchwValid() : Pooling2dFixture("MaxPool", "NCHW", "VALID") {} +}; +BOOST_FIXTURE_TEST_CASE(ParseMaxPoolNchwValid, MaxPoolFixtureNchwValid) +{ + RunTest<4>({1.0f, 2.0f, 3.0f, -4.0f}, {3.0f}); +} + +struct MaxPoolFixtureNhwcSame : Pooling2dFixture +{ + MaxPoolFixtureNhwcSame() : Pooling2dFixture("MaxPool", "NHWC", "SAME") {} +}; +BOOST_FIXTURE_TEST_CASE(ParseMaxPoolNhwcSame, MaxPoolFixtureNhwcSame) +{ + RunTest<4>({1.0f, 2.0f, 3.0f, -4.0f}, {3.0f, 2.0f, 3.0f, -4.0f}); +} + +struct MaxPoolFixtureNchwSame : Pooling2dFixture +{ + MaxPoolFixtureNchwSame() : Pooling2dFixture("MaxPool", "NCHW", "SAME") {} +}; +BOOST_FIXTURE_TEST_CASE(ParseMaxPoolNchwSame, MaxPoolFixtureNchwSame) +{ + RunTest<4>({1.0f, 2.0f, 3.0f, -4.0f}, {3.0f, 2.0f, 3.0f, -4.0f}); +} + +struct AvgPoolFixtureNhwcValid : Pooling2dFixture +{ + AvgPoolFixtureNhwcValid() : Pooling2dFixture("AvgPool", "NHWC", "VALID") {} +}; +BOOST_FIXTURE_TEST_CASE(ParseAvgPoolNhwcValid, AvgPoolFixtureNhwcValid) +{ + RunTest<4>({1.0f, 2.0f, 3.0f, 4.0f}, {2.5f}); +} -struct AvgPoolFixture : Pooling2dFixture +struct AvgPoolFixtureNchwValid : Pooling2dFixture { - AvgPoolFixture() : Pooling2dFixture("AvgPool") {} + AvgPoolFixtureNchwValid() : Pooling2dFixture("AvgPool", "NCHW", "VALID") {} }; -BOOST_FIXTURE_TEST_CASE(ParseAvgPool, AvgPoolFixture) +BOOST_FIXTURE_TEST_CASE(ParseAvgPoolNchwValid, AvgPoolFixtureNchwValid) { RunTest<4>({1.0f, 2.0f, 3.0f, 4.0f}, {2.5f}); } +struct AvgPoolFixtureNhwcSame : Pooling2dFixture +{ + AvgPoolFixtureNhwcSame() : Pooling2dFixture("AvgPool", "NHWC", "SAME") {} +}; +BOOST_FIXTURE_TEST_CASE(ParseAvgPoolNhwcSame, AvgPoolFixtureNhwcSame) +{ + RunTest<4>({1.0f, 2.0f, 3.0f, 4.0f}, {2.5f, 3.0f, 3.5f, 4.0f}); +} + +struct AvgPoolFixtureNchwSame : Pooling2dFixture +{ + AvgPoolFixtureNchwSame() : Pooling2dFixture("AvgPool", "NCHW", "SAME") {} +}; +BOOST_FIXTURE_TEST_CASE(ParseAvgPoolNchwSame, AvgPoolFixtureNchwSame) +{ + RunTest<4>({1.0f, 2.0f, 3.0f, 4.0f}, {2.5f, 3.0f, 3.5f, 4.0f}); +} BOOST_AUTO_TEST_SUITE_END() -- cgit v1.2.1