diff options
author | Aron Virginas-Tar <Aron.Virginas-Tar@arm.com> | 2019-11-27 13:29:51 +0000 |
---|---|---|
committer | Áron Virginás-Tar <aron.virginas-tar@arm.com> | 2019-11-29 10:28:32 +0000 |
commit | 2e259276fba9fa5c6c2e146de3b26e3d6c6cccc6 (patch) | |
tree | 935888a1252056036ac1cfdd623f60b46bd3579d | |
parent | 2445744f46921ba5e30f9110c26b9485ddfe90b5 (diff) | |
download | armnn-2e259276fba9fa5c6c2e146de3b26e3d6c6cccc6.tar.gz |
Github #306 Treat data_format attribute as optional in TfParser::ParseFusedBatchNorm()
Signed-off-by: Aron Virginas-Tar <Aron.Virginas-Tar@arm.com>
Change-Id: I1c6583e4abb43b864dc636f8cdcd9011c763a6fe
-rwxr-xr-x | src/armnnTfParser/TfParser.cpp | 16 | ||||
-rw-r--r-- | src/armnnTfParser/test/FusedBatchNorm.cpp | 26 |
2 files changed, 30 insertions, 12 deletions
diff --git a/src/armnnTfParser/TfParser.cpp b/src/armnnTfParser/TfParser.cpp index d085ed84e3..51423bf6a7 100755 --- a/src/armnnTfParser/TfParser.cpp +++ b/src/armnnTfParser/TfParser.cpp @@ -195,6 +195,19 @@ std::vector<uint32_t> ReadOptionalNodeUint32ListAttribute(const tensorflow::Node return attriList; } +std::string ReadOptionalNodeStringAttribute(const tensorflow::NodeDef& nodeDef, + const std::string& name, + const std::string& defaultValue = "") +{ + std::string attribValue = defaultValue; + ReadOptionalNodeAttributeImpl(nodeDef, name, tensorflow::AttrValue::kS, + [&attribValue](const tensorflow::AttrValue& attrValue) + { + attribValue = attrValue.s(); + }); + return attribValue; +} + bool ReadOptionalNodeBoolAttribute(const tensorflow::NodeDef& nodeDef, const std::string& name, bool defaultValue = false) @@ -1594,8 +1607,7 @@ ParsedTfOperationPtr TfParser::ParseFusedBatchNorm(const tensorflow::NodeDef& no ParsedConstTfOperation<float>* varianceNode = boost::polymorphic_downcast<ParsedConstTfOperation<float> *>(inputs[4].m_IndexedValue); - const std::string dataFormat = ReadMandatoryNodeStringAttribute(nodeDef, "data_format"); - + const std::string dataFormat = ReadOptionalNodeStringAttribute(nodeDef, "data_format", "NHWC"); CHECK_DATA_FORMAT(nodeDef, dataFormat, "FusedBatchNorm"); // The descriptor only has the epsilon attribute. diff --git a/src/armnnTfParser/test/FusedBatchNorm.cpp b/src/armnnTfParser/test/FusedBatchNorm.cpp index 98bdb26183..b93a4728d0 100644 --- a/src/armnnTfParser/test/FusedBatchNorm.cpp +++ b/src/armnnTfParser/test/FusedBatchNorm.cpp @@ -141,16 +141,22 @@ struct FusedBatchNormFixture : public armnnUtils::ParserPrototxtFixture<armnnTfP " value { \n" " type: DT_FLOAT \n" " } \n" - " } \n" - " attr { \n" - " key: \"data_format\" \n" - " value { \n" - " s: \""; - m_Prototext.append(dataLayout); - m_Prototext.append("\" \n" - " } \n" - " } \n" - " attr { \n" + " } \n"; + + // NOTE: we only explicitly set data_format when it is not the default NHWC + if (dataLayout != "NHWC") + { + m_Prototext.append(" attr { \n" + " key: \"data_format\" \n" + " value { \n" + " s: \""); + m_Prototext.append(dataLayout); + m_Prototext.append("\" \n" + " } \n" + " } \n"); + } + + m_Prototext.append(" attr { \n" " key: \"epsilon\" \n" " value { \n" " f: 0.0010000000475 \n" |