aboutsummaryrefslogtreecommitdiff
path: root/src/armnnTfParser/TfParser.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/armnnTfParser/TfParser.cpp')
-rw-r--r--src/armnnTfParser/TfParser.cpp20
1 files changed, 7 insertions, 13 deletions
diff --git a/src/armnnTfParser/TfParser.cpp b/src/armnnTfParser/TfParser.cpp
index b40b05409a..73bdb656b5 100644
--- a/src/armnnTfParser/TfParser.cpp
+++ b/src/armnnTfParser/TfParser.cpp
@@ -1421,9 +1421,14 @@ ParsedTfOperationPtr TfParser::ParseFusedBatchNorm(const tensorflow::NodeDef& no
ParsedConstTfOperation<float>* varianceNode =
boost::polymorphic_downcast<ParsedConstTfOperation<float> *>(inputs[4].m_IndexedValue);
+ const std::string dataFormat = ReadMandatoryNodeStringAttribute(nodeDef, "data_format");
+
+ CHECK_DATA_FORMAT(nodeDef, dataFormat, "FusedBatchNorm");
+
// The descriptor only has the epsilon attribute.
BatchNormalizationDescriptor desc;
desc.m_Eps = ReadMandatoryNodeFloatAttribute(nodeDef, "epsilon");
+ desc.m_DataLayout = dataFormat == "NHWC" ? DataLayout::NHWC : DataLayout::NCHW;
// Data for the parsed tensor args (scale, offset, mean, variance) must be stored
// locally until the layer is added.
@@ -1448,19 +1453,8 @@ ParsedTfOperationPtr TfParser::ParseFusedBatchNorm(const tensorflow::NodeDef& no
IOutputSlot& inputSlot = inputs[0].m_IndexedValue->ResolveArmnnOutputSlot(inputs[0].m_Index);
- const std::string dataFormat = ReadMandatoryNodeStringAttribute(nodeDef, "data_format");
-
- if (dataFormat == "NHWC")
- {
- const TensorInfo outputTensorInfo = armnnUtils::Permuted(inputSlot.GetTensorInfo(), NHWCToArmNN);
- layer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo);
- layer = SwizzleInDeswizzleOut(*m_Network, inputSlot, *layer, nodeDef.name());
- }
- else
- {
- layer->GetOutputSlot(0).SetTensorInfo(inputSlot.GetTensorInfo());
- inputSlot.Connect(layer->GetInputSlot(0));
- }
+ layer->GetOutputSlot(0).SetTensorInfo(inputSlot.GetTensorInfo());
+ inputSlot.Connect(layer->GetInputSlot(0));
return std::make_unique<SingleLayerParsedTfOperation>(this, nodeDef, layer);
}