aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAron Virginas-Tar <Aron.Virginas-Tar@arm.com>2019-11-27 13:29:51 +0000
committerÁron Virginás-Tar <aron.virginas-tar@arm.com>2019-11-29 10:28:32 +0000
commit2e259276fba9fa5c6c2e146de3b26e3d6c6cccc6 (patch)
tree935888a1252056036ac1cfdd623f60b46bd3579d
parent2445744f46921ba5e30f9110c26b9485ddfe90b5 (diff)
downloadarmnn-2e259276fba9fa5c6c2e146de3b26e3d6c6cccc6.tar.gz
Github #306 Treat data_format attribute as optional in TfParser::ParseFusedBatchNorm()
Signed-off-by: Aron Virginas-Tar <Aron.Virginas-Tar@arm.com> Change-Id: I1c6583e4abb43b864dc636f8cdcd9011c763a6fe
-rwxr-xr-xsrc/armnnTfParser/TfParser.cpp16
-rw-r--r--src/armnnTfParser/test/FusedBatchNorm.cpp26
2 files changed, 30 insertions, 12 deletions
diff --git a/src/armnnTfParser/TfParser.cpp b/src/armnnTfParser/TfParser.cpp
index d085ed84e3..51423bf6a7 100755
--- a/src/armnnTfParser/TfParser.cpp
+++ b/src/armnnTfParser/TfParser.cpp
@@ -195,6 +195,19 @@ std::vector<uint32_t> ReadOptionalNodeUint32ListAttribute(const tensorflow::Node
return attriList;
}
+std::string ReadOptionalNodeStringAttribute(const tensorflow::NodeDef& nodeDef,
+ const std::string& name,
+ const std::string& defaultValue = "")
+{
+ std::string attribValue = defaultValue;
+ ReadOptionalNodeAttributeImpl(nodeDef, name, tensorflow::AttrValue::kS,
+ [&attribValue](const tensorflow::AttrValue& attrValue)
+ {
+ attribValue = attrValue.s();
+ });
+ return attribValue;
+}
+
bool ReadOptionalNodeBoolAttribute(const tensorflow::NodeDef& nodeDef,
const std::string& name,
bool defaultValue = false)
@@ -1594,8 +1607,7 @@ ParsedTfOperationPtr TfParser::ParseFusedBatchNorm(const tensorflow::NodeDef& no
ParsedConstTfOperation<float>* varianceNode =
boost::polymorphic_downcast<ParsedConstTfOperation<float> *>(inputs[4].m_IndexedValue);
- const std::string dataFormat = ReadMandatoryNodeStringAttribute(nodeDef, "data_format");
-
+ const std::string dataFormat = ReadOptionalNodeStringAttribute(nodeDef, "data_format", "NHWC");
CHECK_DATA_FORMAT(nodeDef, dataFormat, "FusedBatchNorm");
// The descriptor only has the epsilon attribute.
diff --git a/src/armnnTfParser/test/FusedBatchNorm.cpp b/src/armnnTfParser/test/FusedBatchNorm.cpp
index 98bdb26183..b93a4728d0 100644
--- a/src/armnnTfParser/test/FusedBatchNorm.cpp
+++ b/src/armnnTfParser/test/FusedBatchNorm.cpp
@@ -141,16 +141,22 @@ struct FusedBatchNormFixture : public armnnUtils::ParserPrototxtFixture<armnnTfP
" value { \n"
" type: DT_FLOAT \n"
" } \n"
- " } \n"
- " attr { \n"
- " key: \"data_format\" \n"
- " value { \n"
- " s: \"";
- m_Prototext.append(dataLayout);
- m_Prototext.append("\" \n"
- " } \n"
- " } \n"
- " attr { \n"
+ " } \n";
+
+ // NOTE: we only explicitly set data_format when it is not the default NHWC
+ if (dataLayout != "NHWC")
+ {
+ m_Prototext.append(" attr { \n"
+ " key: \"data_format\" \n"
+ " value { \n"
+ " s: \"");
+ m_Prototext.append(dataLayout);
+ m_Prototext.append("\" \n"
+ " } \n"
+ " } \n");
+ }
+
+ m_Prototext.append(" attr { \n"
" key: \"epsilon\" \n"
" value { \n"
" f: 0.0010000000475 \n"