diff options
author | FrancisMurtagh <francis.murtagh@arm.com> | 2018-12-06 15:26:04 +0000 |
---|---|---|
committer | Matteo Martincigh <matteo.martincigh@arm.com> | 2018-12-06 15:48:16 +0000 |
commit | f005e313e501a28669e214fd05db04c12c8af7fc (patch) | |
tree | 6acc10eff98ce306c6b54d6b53df7602a1a8c31f /src | |
parent | ba8815f4c38966cc15bb5bcd0960fdd23d89e365 (diff) | |
download | armnn-f005e313e501a28669e214fd05db04c12c8af7fc.tar.gz |
IVGCVSW-2277 Remove the input swizzling from ParsePooling2d
* Remove input swizzling from ParsePooling2d and add parameterized
tests for NCHW, NHWC and padding="SAME".
Change-Id: I4786fcc31b6ac46bf19d887f007963eb924f0f9f
Diffstat (limited to 'src')
-rw-r--r-- | src/armnnTfParser/TfParser.cpp | 94 | ||||
-rw-r--r-- | src/armnnTfParser/test/Pooling.cpp | 99 |
2 files changed, 133 insertions, 60 deletions
diff --git a/src/armnnTfParser/TfParser.cpp b/src/armnnTfParser/TfParser.cpp index c00722c4ad..d5372a598b 100644 --- a/src/armnnTfParser/TfParser.cpp +++ b/src/armnnTfParser/TfParser.cpp @@ -2304,68 +2304,73 @@ ParsedTfOperationPtr TfParser::ParsePooling2d(const tensorflow::NodeDef& nodeDef std::vector<uint32_t> ksize = ReadMandatoryNodeUint32ListAttribute(nodeDef, "ksize"); // size of pool windows Pooling2dDescriptor pooling2dDescriptor; - pooling2dDescriptor.m_PoolType = pooltype; - pooling2dDescriptor.m_PaddingMethod = PaddingMethod::Exclude; + pooling2dDescriptor.m_PoolType = pooltype; + pooling2dDescriptor.m_PaddingMethod = PaddingMethod::Exclude; pooling2dDescriptor.m_OutputShapeRounding = OutputShapeRounding::Floor; CHECK_DATA_FORMAT(nodeDef, dataFormat, "Pooling2D"); + DataLayout dataLayout = dataFormat == "NHWC" ? DataLayout::NHWC : DataLayout::NCHW; + pooling2dDescriptor.m_DataLayout = dataLayout; + DataLayoutIndexed dataLayoutIndexed(dataLayout); - if (dataFormat == "NHWC") - { - pooling2dDescriptor.m_StrideX = strides[2]; - pooling2dDescriptor.m_StrideY = strides[1]; - pooling2dDescriptor.m_PoolWidth = ksize[2]; - pooling2dDescriptor.m_PoolHeight = ksize[1]; - // Swizzles input to supported memory layout. 
- inputTensorInfo = armnnUtils::Permuted(inputSlot.GetTensorInfo(), NHWCToArmNN); - } - else if (dataFormat == "NCHW") - { - pooling2dDescriptor.m_StrideX = strides[3]; - pooling2dDescriptor.m_StrideY = strides[2]; - pooling2dDescriptor.m_PoolWidth = ksize[3]; - pooling2dDescriptor.m_PoolHeight = ksize[2]; - } + pooling2dDescriptor.m_StrideX = strides[dataLayoutIndexed.GetWidthIndex()]; + pooling2dDescriptor.m_StrideY = strides[dataLayoutIndexed.GetHeightIndex()]; + pooling2dDescriptor.m_PoolWidth = ksize[dataLayoutIndexed.GetWidthIndex()]; + pooling2dDescriptor.m_PoolHeight = ksize[dataLayoutIndexed.GetHeightIndex()]; - uint32_t inputHeight = inputTensorInfo.GetShape()[2]; - uint32_t inputWidth = inputTensorInfo.GetShape()[3]; + uint32_t inputHeight = inputTensorInfo.GetShape()[dataLayoutIndexed.GetHeightIndex()]; + uint32_t inputWidth = inputTensorInfo.GetShape()[dataLayoutIndexed.GetWidthIndex()]; bool padding = false; TensorInfo outputInfo; + unsigned int outputHeight = 0; + unsigned int outputWidth = 0; CHECK_PADDING_TYPE(nodeDef, paddingString); if (paddingString == "SAME") { padding = true; - outputInfo = TensorInfo({ inputTensorInfo.GetShape()[0], - inputTensorInfo.GetShape()[1], - static_cast<uint32_t>(ceil( - static_cast<float>(inputHeight) / - static_cast<float>(pooling2dDescriptor.m_StrideY))), - static_cast<uint32_t>(ceil( - static_cast<float>(inputWidth) / - static_cast<float>(pooling2dDescriptor.m_StrideX))) - }, DataType::Float32); + + outputHeight = static_cast<uint32_t>(ceil(static_cast<float>(inputHeight) / + static_cast<float>(pooling2dDescriptor.m_StrideY))); + outputWidth = static_cast<uint32_t>(ceil(static_cast<float>(inputWidth) / + static_cast<float>(pooling2dDescriptor.m_StrideX))); } else if (paddingString == "VALID") { padding = false; - outputInfo = TensorInfo({ inputTensorInfo.GetShape()[0], - inputTensorInfo.GetShape()[1], - static_cast<uint32_t>(ceil( - static_cast<float>(inputHeight - pooling2dDescriptor.m_PoolHeight + 1) / - 
static_cast<float>(pooling2dDescriptor.m_StrideY))), - static_cast<uint32_t>(ceil( - static_cast<float>(inputWidth - pooling2dDescriptor.m_PoolWidth + 1) / - static_cast<float>(pooling2dDescriptor.m_StrideX))) - }, DataType::Float32); + + outputHeight = static_cast<uint32_t>(ceil( + static_cast<float>(inputHeight - pooling2dDescriptor.m_PoolHeight + 1) / + static_cast<float>(pooling2dDescriptor.m_StrideY))); + outputWidth = static_cast<uint32_t>(ceil( + static_cast<float>(inputWidth - pooling2dDescriptor.m_PoolWidth + 1) / + static_cast<float>(pooling2dDescriptor.m_StrideX))); + } + + switch (dataLayout) + { + case DataLayout::NHWC: + outputInfo = TensorInfo({ inputTensorInfo.GetShape()[0], + outputHeight, + outputWidth, + inputTensorInfo.GetShape()[3] }, + DataType::Float32); + break; + case DataLayout::NCHW: + outputInfo = TensorInfo({ inputTensorInfo.GetShape()[0], + inputTensorInfo.GetShape()[1], + outputHeight, + outputWidth }, + DataType::Float32); + break; } CalcPadding(inputWidth, pooling2dDescriptor.m_PoolWidth, pooling2dDescriptor.m_StrideX, - pooling2dDescriptor.m_PadLeft, pooling2dDescriptor.m_PadRight, padding); + pooling2dDescriptor.m_PadLeft, pooling2dDescriptor.m_PadRight, padding); CalcPadding(inputHeight, pooling2dDescriptor.m_PoolHeight, pooling2dDescriptor.m_StrideY, - pooling2dDescriptor.m_PadTop, pooling2dDescriptor.m_PadBottom, padding); + pooling2dDescriptor.m_PadTop, pooling2dDescriptor.m_PadBottom, padding); IConnectableLayer* layer = m_Network->AddPooling2dLayer(pooling2dDescriptor, nodeDef.name().c_str()); @@ -2381,14 +2386,7 @@ ParsedTfOperationPtr TfParser::ParsePooling2d(const tensorflow::NodeDef& nodeDef layer->GetOutputSlot(0).SetTensorInfo(outputInfo); - if (dataFormat == "NHWC") - { - layer = SwizzleInDeswizzleOut(*m_Network, inputSlot, *layer, nodeDef.name()); - } - else - { - inputSlot.Connect(layer->GetInputSlot(0)); - } + inputSlot.Connect(layer->GetInputSlot(0)); return std::make_unique<SingleLayerParsedTfOperation>(this, 
nodeDef, layer); } diff --git a/src/armnnTfParser/test/Pooling.cpp b/src/armnnTfParser/test/Pooling.cpp index 346599fb43..f6de44c95f 100644 --- a/src/armnnTfParser/test/Pooling.cpp +++ b/src/armnnTfParser/test/Pooling.cpp @@ -11,7 +11,7 @@ BOOST_AUTO_TEST_SUITE(TensorflowParser) struct Pooling2dFixture : public armnnUtils::ParserPrototxtFixture<armnnTfParser::ITfParser> { - explicit Pooling2dFixture(const char* poolingtype) + explicit Pooling2dFixture(const char* poolingtype, std::string dataLayout, std::string paddingOption) { m_Prototext = "node {\n" " name: \"Placeholder\"\n" @@ -50,24 +50,40 @@ struct Pooling2dFixture : public armnnUtils::ParserPrototxtFixture<armnnTfParser " attr {\n" " key: \"data_format\"\n" " value {\n" - " s: \"NHWC\"\n" + " s: \""); + m_Prototext.append(dataLayout); + m_Prototext.append("\"\n" " }\n" " }\n" " attr {\n" " key: \"ksize\"\n" " value {\n" " list {\n" - " i: 1\n" + + " i: 1\n"); + if(dataLayout == "NHWC") + { + m_Prototext.append(" i: 2\n" " i: 2\n" + " i: 1\n"); + } + else + { + m_Prototext.append(" i: 1\n" " i: 2\n" - " i: 1\n" + " i: 2\n"); + } + m_Prototext.append( " }\n" " }\n" " }\n" " attr {\n" " key: \"padding\"\n" " value {\n" - " s: \"VALID\"\n" + " s: \""); + m_Prototext.append(paddingOption); + m_Prototext.append( + "\"\n" " }\n" " }\n" " attr {\n" @@ -83,29 +99,88 @@ struct Pooling2dFixture : public armnnUtils::ParserPrototxtFixture<armnnTfParser " }\n" "}\n"); - SetupSingleInputSingleOutput({ 1, 2, 2, 1 }, "Placeholder", poolingtype); + if(dataLayout == "NHWC") + { + SetupSingleInputSingleOutput({ 1, 2, 2, 1 }, "Placeholder", poolingtype); + } + else + { + SetupSingleInputSingleOutput({ 1, 1, 2, 2 }, "Placeholder", poolingtype); + } } }; -struct MaxPoolFixture : Pooling2dFixture +struct MaxPoolFixtureNhwcValid : Pooling2dFixture { - MaxPoolFixture() : Pooling2dFixture("MaxPool") {} + MaxPoolFixtureNhwcValid() : Pooling2dFixture("MaxPool", "NHWC", "VALID") {} }; -BOOST_FIXTURE_TEST_CASE(ParseMaxPool, 
MaxPoolFixture) +BOOST_FIXTURE_TEST_CASE(ParseMaxPoolNhwcValid, MaxPoolFixtureNhwcValid) { RunTest<4>({1.0f, 2.0f, 3.0f, -4.0f}, {3.0f}); } +struct MaxPoolFixtureNchwValid : Pooling2dFixture +{ + MaxPoolFixtureNchwValid() : Pooling2dFixture("MaxPool", "NCHW", "VALID") {} +}; +BOOST_FIXTURE_TEST_CASE(ParseMaxPoolNchwValid, MaxPoolFixtureNchwValid) +{ + RunTest<4>({1.0f, 2.0f, 3.0f, -4.0f}, {3.0f}); +} + +struct MaxPoolFixtureNhwcSame : Pooling2dFixture +{ + MaxPoolFixtureNhwcSame() : Pooling2dFixture("MaxPool", "NHWC", "SAME") {} +}; +BOOST_FIXTURE_TEST_CASE(ParseMaxPoolNhwcSame, MaxPoolFixtureNhwcSame) +{ + RunTest<4>({1.0f, 2.0f, 3.0f, -4.0f}, {3.0f, 2.0f, 3.0f, -4.0f}); +} + +struct MaxPoolFixtureNchwSame : Pooling2dFixture +{ + MaxPoolFixtureNchwSame() : Pooling2dFixture("MaxPool", "NCHW", "SAME") {} +}; +BOOST_FIXTURE_TEST_CASE(ParseMaxPoolNchwSame, MaxPoolFixtureNchwSame) +{ + RunTest<4>({1.0f, 2.0f, 3.0f, -4.0f}, {3.0f, 2.0f, 3.0f, -4.0f}); +} + +struct AvgPoolFixtureNhwcValid : Pooling2dFixture +{ + AvgPoolFixtureNhwcValid() : Pooling2dFixture("AvgPool", "NHWC", "VALID") {} +}; +BOOST_FIXTURE_TEST_CASE(ParseAvgPoolNhwcValid, AvgPoolFixtureNhwcValid) +{ + RunTest<4>({1.0f, 2.0f, 3.0f, 4.0f}, {2.5f}); +} -struct AvgPoolFixture : Pooling2dFixture +struct AvgPoolFixtureNchwValid : Pooling2dFixture { - AvgPoolFixture() : Pooling2dFixture("AvgPool") {} + AvgPoolFixtureNchwValid() : Pooling2dFixture("AvgPool", "NCHW", "VALID") {} }; -BOOST_FIXTURE_TEST_CASE(ParseAvgPool, AvgPoolFixture) +BOOST_FIXTURE_TEST_CASE(ParseAvgPoolNchwValid, AvgPoolFixtureNchwValid) { RunTest<4>({1.0f, 2.0f, 3.0f, 4.0f}, {2.5f}); } +struct AvgPoolFixtureNhwcSame : Pooling2dFixture +{ + AvgPoolFixtureNhwcSame() : Pooling2dFixture("AvgPool", "NHWC", "SAME") {} +}; +BOOST_FIXTURE_TEST_CASE(ParseAvgPoolNhwcSame, AvgPoolFixtureNhwcSame) +{ + RunTest<4>({1.0f, 2.0f, 3.0f, 4.0f}, {2.5f, 3.0f, 3.5f, 4.0f}); +} + +struct AvgPoolFixtureNchwSame : Pooling2dFixture +{ + AvgPoolFixtureNchwSame() : 
Pooling2dFixture("AvgPool", "NCHW", "SAME") {} +}; +BOOST_FIXTURE_TEST_CASE(ParseAvgPoolNchwSame, AvgPoolFixtureNchwSame) +{ + RunTest<4>({1.0f, 2.0f, 3.0f, 4.0f}, {2.5f, 3.0f, 3.5f, 4.0f}); +} BOOST_AUTO_TEST_SUITE_END() |