aboutsummaryrefslogtreecommitdiff
path: root/src/armnnTfParser/test/FusedBatchNorm.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/armnnTfParser/test/FusedBatchNorm.cpp')
-rw-r--r--src/armnnTfParser/test/FusedBatchNorm.cpp26
1 files changed, 16 insertions, 10 deletions
diff --git a/src/armnnTfParser/test/FusedBatchNorm.cpp b/src/armnnTfParser/test/FusedBatchNorm.cpp
index 98bdb26183..b93a4728d0 100644
--- a/src/armnnTfParser/test/FusedBatchNorm.cpp
+++ b/src/armnnTfParser/test/FusedBatchNorm.cpp
@@ -141,16 +141,22 @@ struct FusedBatchNormFixture : public armnnUtils::ParserPrototxtFixture<armnnTfP
" value { \n"
" type: DT_FLOAT \n"
" } \n"
- " } \n"
- " attr { \n"
- " key: \"data_format\" \n"
- " value { \n"
- " s: \"";
- m_Prototext.append(dataLayout);
- m_Prototext.append("\" \n"
- " } \n"
- " } \n"
- " attr { \n"
+ " } \n";
+
+ // NOTE: we only explicitly set data_format when it is not the default NHWC
+ if (dataLayout != "NHWC")
+ {
+ m_Prototext.append(" attr { \n"
+ " key: \"data_format\" \n"
+ " value { \n"
+ " s: \"");
+ m_Prototext.append(dataLayout);
+ m_Prototext.append("\" \n"
+ " } \n"
+ " } \n");
+ }
+
+ m_Prototext.append(" attr { \n"
" key: \"epsilon\" \n"
" value { \n"
" f: 0.0010000000475 \n"