diff options
-rw-r--r-- | src/armnnTfParser/TfParser.cpp | 20 | ||||
-rw-r--r-- | src/armnnTfParser/test/FusedBatchNorm.cpp | 77 |
2 files changed, 61 insertions, 36 deletions
diff --git a/src/armnnTfParser/TfParser.cpp b/src/armnnTfParser/TfParser.cpp index b40b05409a..73bdb656b5 100644 --- a/src/armnnTfParser/TfParser.cpp +++ b/src/armnnTfParser/TfParser.cpp @@ -1421,9 +1421,14 @@ ParsedTfOperationPtr TfParser::ParseFusedBatchNorm(const tensorflow::NodeDef& no ParsedConstTfOperation<float>* varianceNode = boost::polymorphic_downcast<ParsedConstTfOperation<float> *>(inputs[4].m_IndexedValue); + const std::string dataFormat = ReadMandatoryNodeStringAttribute(nodeDef, "data_format"); + + CHECK_DATA_FORMAT(nodeDef, dataFormat, "FusedBatchNorm"); + // The descriptor only has the epsilon attribute. BatchNormalizationDescriptor desc; desc.m_Eps = ReadMandatoryNodeFloatAttribute(nodeDef, "epsilon"); + desc.m_DataLayout = dataFormat == "NHWC" ? DataLayout::NHWC : DataLayout::NCHW; // Data for the parsed tensor args (scale, offset, mean, variance) must be stored // locally until the layer is added. @@ -1448,19 +1453,8 @@ ParsedTfOperationPtr TfParser::ParseFusedBatchNorm(const tensorflow::NodeDef& no IOutputSlot& inputSlot = inputs[0].m_IndexedValue->ResolveArmnnOutputSlot(inputs[0].m_Index); - const std::string dataFormat = ReadMandatoryNodeStringAttribute(nodeDef, "data_format"); - - if (dataFormat == "NHWC") - { - const TensorInfo outputTensorInfo = armnnUtils::Permuted(inputSlot.GetTensorInfo(), NHWCToArmNN); - layer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo); - layer = SwizzleInDeswizzleOut(*m_Network, inputSlot, *layer, nodeDef.name()); - } - else - { - layer->GetOutputSlot(0).SetTensorInfo(inputSlot.GetTensorInfo()); - inputSlot.Connect(layer->GetInputSlot(0)); - } + layer->GetOutputSlot(0).SetTensorInfo(inputSlot.GetTensorInfo()); + inputSlot.Connect(layer->GetInputSlot(0)); return std::make_unique<SingleLayerParsedTfOperation>(this, nodeDef, layer); } diff --git a/src/armnnTfParser/test/FusedBatchNorm.cpp b/src/armnnTfParser/test/FusedBatchNorm.cpp index bb9e3ed863..98bdb26183 100644 --- a/src/armnnTfParser/test/FusedBatchNorm.cpp +++ b/src/armnnTfParser/test/FusedBatchNorm.cpp @@ -7,11 +7,13 @@ #include "armnnTfParser/ITfParser.hpp" #include "ParserPrototxtFixture.hpp" +#include <array> + BOOST_AUTO_TEST_SUITE(TensorflowParser) struct FusedBatchNormFixture : public armnnUtils::ParserPrototxtFixture<armnnTfParser::ITfParser> { - FusedBatchNormFixture() + explicit FusedBatchNormFixture(const std::string& dataLayout) { m_Prototext = "node { \n" " name: \"graphInput\" \n" @@ -143,33 +145,62 @@ struct FusedBatchNormFixture : public armnnUtils::ParserPrototxtFixture<armnnTfP " attr { \n" " key: \"data_format\" \n" " value { \n" - " s: \"NHWC\" \n" - " } \n" - " } \n" - " attr { \n" - " key: \"epsilon\" \n" - " value { \n" - " f: 0.0010000000475 \n" - " } \n" - " } \n" - " attr { \n" - " key: \"is_training\" \n" - " value { \n" - " b: false \n" - " } \n" - " } \n" - "} \n"; + " s: \""; + m_Prototext.append(dataLayout); + m_Prototext.append("\" \n" + " } \n" + " } \n" + " attr { \n" + " key: \"epsilon\" \n" + " value { \n" + " f: 0.0010000000475 \n" + " } \n" + " } \n" + " attr { \n" + " key: \"is_training\" \n" + " value { \n" + " b: false \n" + " } \n" + " } \n" + "} \n"); + + // Set the input shape according to the data layout + std::array<unsigned int, 4> dims; + if (dataLayout == "NHWC") + { + dims = { 1u, 3u, 3u, 1u }; + } + else // dataLayout == "NCHW" + { + dims = { 1u, 1u, 3u, 3u }; + } - SetupSingleInputSingleOutput({1, 3, 3, 1}, "graphInput", "output"); + SetupSingleInputSingleOutput(armnn::TensorShape(4, dims.data()), "graphInput", "output"); } }; -BOOST_FIXTURE_TEST_CASE(ParseFusedBatchNorm, FusedBatchNormFixture) +struct FusedBatchNormNhwcFixture : FusedBatchNormFixture +{ + FusedBatchNormNhwcFixture() : FusedBatchNormFixture("NHWC"){} +}; +BOOST_FIXTURE_TEST_CASE(ParseFusedBatchNormNhwc, FusedBatchNormNhwcFixture) +{ + RunTest<4>({ 1, 2, 3, 4, 5, 6, 7, 8, 9 }, // Input data. + { -2.8277204f, -2.12079024f, -1.4138602f, + -0.7069301f, 0.0f, 0.7069301f, + 1.4138602f, 2.12079024f, 2.8277204f }); // Expected output data. +} + +struct FusedBatchNormNchwFixture : FusedBatchNormFixture +{ + FusedBatchNormNchwFixture() : FusedBatchNormFixture("NCHW"){} +}; +BOOST_FIXTURE_TEST_CASE(ParseFusedBatchNormNchw, FusedBatchNormNchwFixture) { - RunTest<4>({1, 2, 3, 4, 5, 6, 7, 8, 9}, // Input data. - {-2.8277204f, -2.12079024f, -1.4138602f, - -0.7069301f, 0.0f, 0.7069301f, - 1.4138602f, 2.12079024f, 2.8277204f}); // Expected output data. + RunTest<4>({ 1, 2, 3, 4, 5, 6, 7, 8, 9 }, // Input data. + { -2.8277204f, -2.12079024f, -1.4138602f, + -0.7069301f, 0.0f, 0.7069301f, + 1.4138602f, 2.12079024f, 2.8277204f }); // Expected output data. } BOOST_AUTO_TEST_SUITE_END() |