diff options
Diffstat (limited to 'src/armnnTfParser/TfParser.cpp')
-rw-r--r-- | src/armnnTfParser/TfParser.cpp | 20 |
1 files changed, 7 insertions, 13 deletions
diff --git a/src/armnnTfParser/TfParser.cpp b/src/armnnTfParser/TfParser.cpp index b40b05409a..73bdb656b5 100644 --- a/src/armnnTfParser/TfParser.cpp +++ b/src/armnnTfParser/TfParser.cpp @@ -1421,9 +1421,14 @@ ParsedTfOperationPtr TfParser::ParseFusedBatchNorm(const tensorflow::NodeDef& no ParsedConstTfOperation<float>* varianceNode = boost::polymorphic_downcast<ParsedConstTfOperation<float> *>(inputs[4].m_IndexedValue); + const std::string dataFormat = ReadMandatoryNodeStringAttribute(nodeDef, "data_format"); + + CHECK_DATA_FORMAT(nodeDef, dataFormat, "FusedBatchNorm"); + // The descriptor only has the epsilon attribute. BatchNormalizationDescriptor desc; desc.m_Eps = ReadMandatoryNodeFloatAttribute(nodeDef, "epsilon"); + desc.m_DataLayout = dataFormat == "NHWC" ? DataLayout::NHWC : DataLayout::NCHW; // Data for the parsed tensor args (scale, offset, mean, variance) must be stored // locally until the layer is added. @@ -1448,19 +1453,8 @@ ParsedTfOperationPtr TfParser::ParseFusedBatchNorm(const tensorflow::NodeDef& no IOutputSlot& inputSlot = inputs[0].m_IndexedValue->ResolveArmnnOutputSlot(inputs[0].m_Index); - const std::string dataFormat = ReadMandatoryNodeStringAttribute(nodeDef, "data_format"); - - if (dataFormat == "NHWC") - { - const TensorInfo outputTensorInfo = armnnUtils::Permuted(inputSlot.GetTensorInfo(), NHWCToArmNN); - layer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo); - layer = SwizzleInDeswizzleOut(*m_Network, inputSlot, *layer, nodeDef.name()); - } - else - { - layer->GetOutputSlot(0).SetTensorInfo(inputSlot.GetTensorInfo()); - inputSlot.Connect(layer->GetInputSlot(0)); - } + layer->GetOutputSlot(0).SetTensorInfo(inputSlot.GetTensorInfo()); + inputSlot.Connect(layer->GetInputSlot(0)); return std::make_unique<SingleLayerParsedTfOperation>(this, nodeDef, layer); } |