aboutsummaryrefslogtreecommitdiff
path: root/src/armnnTfParser/TfParser.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/armnnTfParser/TfParser.cpp')
-rwxr-xr-xsrc/armnnTfParser/TfParser.cpp16
1 files changed, 14 insertions, 2 deletions
diff --git a/src/armnnTfParser/TfParser.cpp b/src/armnnTfParser/TfParser.cpp
index d085ed84e3..51423bf6a7 100755
--- a/src/armnnTfParser/TfParser.cpp
+++ b/src/armnnTfParser/TfParser.cpp
@@ -195,6 +195,19 @@ std::vector<uint32_t> ReadOptionalNodeUint32ListAttribute(const tensorflow::Node
return attriList;
}
+std::string ReadOptionalNodeStringAttribute(const tensorflow::NodeDef& nodeDef,
+ const std::string& name,
+ const std::string& defaultValue = "")
+{
+ std::string attribValue = defaultValue;
+ ReadOptionalNodeAttributeImpl(nodeDef, name, tensorflow::AttrValue::kS,
+ [&attribValue](const tensorflow::AttrValue& attrValue)
+ {
+ attribValue = attrValue.s();
+ });
+ return attribValue;
+}
+
bool ReadOptionalNodeBoolAttribute(const tensorflow::NodeDef& nodeDef,
const std::string& name,
bool defaultValue = false)
@@ -1594,8 +1607,7 @@ ParsedTfOperationPtr TfParser::ParseFusedBatchNorm(const tensorflow::NodeDef& no
ParsedConstTfOperation<float>* varianceNode =
boost::polymorphic_downcast<ParsedConstTfOperation<float> *>(inputs[4].m_IndexedValue);
- const std::string dataFormat = ReadMandatoryNodeStringAttribute(nodeDef, "data_format");
-
+ const std::string dataFormat = ReadOptionalNodeStringAttribute(nodeDef, "data_format", "NHWC");
CHECK_DATA_FORMAT(nodeDef, dataFormat, "FusedBatchNorm");
// The descriptor only has the epsilon attribute.