diff options
Diffstat (limited to 'src/armnnTfParser/TfParser.cpp')
-rwxr-xr-x | src/armnnTfParser/TfParser.cpp | 16 |
1 files changed, 14 insertions, 2 deletions
diff --git a/src/armnnTfParser/TfParser.cpp b/src/armnnTfParser/TfParser.cpp index d085ed84e3..51423bf6a7 100755 --- a/src/armnnTfParser/TfParser.cpp +++ b/src/armnnTfParser/TfParser.cpp @@ -195,6 +195,19 @@ std::vector<uint32_t> ReadOptionalNodeUint32ListAttribute(const tensorflow::Node return attriList; } +std::string ReadOptionalNodeStringAttribute(const tensorflow::NodeDef& nodeDef, + const std::string& name, + const std::string& defaultValue = "") +{ + std::string attribValue = defaultValue; + ReadOptionalNodeAttributeImpl(nodeDef, name, tensorflow::AttrValue::kS, + [&attribValue](const tensorflow::AttrValue& attrValue) + { + attribValue = attrValue.s(); + }); + return attribValue; +} + bool ReadOptionalNodeBoolAttribute(const tensorflow::NodeDef& nodeDef, const std::string& name, bool defaultValue = false) @@ -1594,8 +1607,7 @@ ParsedTfOperationPtr TfParser::ParseFusedBatchNorm(const tensorflow::NodeDef& no ParsedConstTfOperation<float>* varianceNode = boost::polymorphic_downcast<ParsedConstTfOperation<float> *>(inputs[4].m_IndexedValue); - const std::string dataFormat = ReadMandatoryNodeStringAttribute(nodeDef, "data_format"); - + const std::string dataFormat = ReadOptionalNodeStringAttribute(nodeDef, "data_format", "NHWC"); CHECK_DATA_FORMAT(nodeDef, dataFormat, "FusedBatchNorm"); // The descriptor only has the epsilon attribute. |