aboutsummaryrefslogtreecommitdiff
path: root/src/armnnTfParser
diff options
context:
space:
mode:
authorMatteo Martincigh <matteo.martincigh@arm.com>2018-12-05 13:10:45 +0000
committerMatteo Martincigh <matteo.martincigh@arm.com>2018-12-05 13:56:16 +0000
commit075c7504b6d6ecd50fa61ca53ab1bcccc8865843 (patch)
tree6cf8009a0c63fc13927e8f8e11d5713683cf1d72 /src/armnnTfParser
parent4631582a4c8b92917633d0af4ebcc8fff2abd04a (diff)
downloadarmnn-075c7504b6d6ecd50fa61ca53ab1bcccc8865843.tar.gz
IVGCVSW-2267 Remove the input swizzling from ParseFusedBatchNorm
* Removed the input swizzling when the data layout is NHWC * Split the unit test into NHWC and NCHW cases Change-Id: I6b9fef70bc4ba5e01d14cbfaea3c842a289b0a0e
Diffstat (limited to 'src/armnnTfParser')
-rw-r--r--src/armnnTfParser/TfParser.cpp20
-rw-r--r--src/armnnTfParser/test/FusedBatchNorm.cpp77
2 files changed, 61 insertions, 36 deletions
diff --git a/src/armnnTfParser/TfParser.cpp b/src/armnnTfParser/TfParser.cpp
index b40b05409a..73bdb656b5 100644
--- a/src/armnnTfParser/TfParser.cpp
+++ b/src/armnnTfParser/TfParser.cpp
@@ -1421,9 +1421,14 @@ ParsedTfOperationPtr TfParser::ParseFusedBatchNorm(const tensorflow::NodeDef& no
ParsedConstTfOperation<float>* varianceNode =
boost::polymorphic_downcast<ParsedConstTfOperation<float> *>(inputs[4].m_IndexedValue);
+ const std::string dataFormat = ReadMandatoryNodeStringAttribute(nodeDef, "data_format");
+
+ CHECK_DATA_FORMAT(nodeDef, dataFormat, "FusedBatchNorm");
+
// The descriptor only has the epsilon attribute.
BatchNormalizationDescriptor desc;
desc.m_Eps = ReadMandatoryNodeFloatAttribute(nodeDef, "epsilon");
+ desc.m_DataLayout = dataFormat == "NHWC" ? DataLayout::NHWC : DataLayout::NCHW;
// Data for the parsed tensor args (scale, offset, mean, variance) must be stored
// locally until the layer is added.
@@ -1448,19 +1453,8 @@ ParsedTfOperationPtr TfParser::ParseFusedBatchNorm(const tensorflow::NodeDef& no
IOutputSlot& inputSlot = inputs[0].m_IndexedValue->ResolveArmnnOutputSlot(inputs[0].m_Index);
- const std::string dataFormat = ReadMandatoryNodeStringAttribute(nodeDef, "data_format");
-
- if (dataFormat == "NHWC")
- {
- const TensorInfo outputTensorInfo = armnnUtils::Permuted(inputSlot.GetTensorInfo(), NHWCToArmNN);
- layer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo);
- layer = SwizzleInDeswizzleOut(*m_Network, inputSlot, *layer, nodeDef.name());
- }
- else
- {
- layer->GetOutputSlot(0).SetTensorInfo(inputSlot.GetTensorInfo());
- inputSlot.Connect(layer->GetInputSlot(0));
- }
+ layer->GetOutputSlot(0).SetTensorInfo(inputSlot.GetTensorInfo());
+ inputSlot.Connect(layer->GetInputSlot(0));
return std::make_unique<SingleLayerParsedTfOperation>(this, nodeDef, layer);
}
diff --git a/src/armnnTfParser/test/FusedBatchNorm.cpp b/src/armnnTfParser/test/FusedBatchNorm.cpp
index bb9e3ed863..98bdb26183 100644
--- a/src/armnnTfParser/test/FusedBatchNorm.cpp
+++ b/src/armnnTfParser/test/FusedBatchNorm.cpp
@@ -7,11 +7,13 @@
#include "armnnTfParser/ITfParser.hpp"
#include "ParserPrototxtFixture.hpp"
+#include <array>
+
BOOST_AUTO_TEST_SUITE(TensorflowParser)
struct FusedBatchNormFixture : public armnnUtils::ParserPrototxtFixture<armnnTfParser::ITfParser>
{
- FusedBatchNormFixture()
+ explicit FusedBatchNormFixture(const std::string& dataLayout)
{
m_Prototext = "node { \n"
" name: \"graphInput\" \n"
@@ -143,33 +145,62 @@ struct FusedBatchNormFixture : public armnnUtils::ParserPrototxtFixture<armnnTfP
" attr { \n"
" key: \"data_format\" \n"
" value { \n"
- " s: \"NHWC\" \n"
- " } \n"
- " } \n"
- " attr { \n"
- " key: \"epsilon\" \n"
- " value { \n"
- " f: 0.0010000000475 \n"
- " } \n"
- " } \n"
- " attr { \n"
- " key: \"is_training\" \n"
- " value { \n"
- " b: false \n"
- " } \n"
- " } \n"
- "} \n";
+ " s: \"";
+ m_Prototext.append(dataLayout);
+ m_Prototext.append("\" \n"
+ " } \n"
+ " } \n"
+ " attr { \n"
+ " key: \"epsilon\" \n"
+ " value { \n"
+ " f: 0.0010000000475 \n"
+ " } \n"
+ " } \n"
+ " attr { \n"
+ " key: \"is_training\" \n"
+ " value { \n"
+ " b: false \n"
+ " } \n"
+ " } \n"
+ "} \n");
+
+ // Set the input shape according to the data layout
+ std::array<unsigned int, 4> dims;
+ if (dataLayout == "NHWC")
+ {
+ dims = { 1u, 3u, 3u, 1u };
+ }
+ else // dataLayout == "NCHW"
+ {
+ dims = { 1u, 1u, 3u, 3u };
+ }
- SetupSingleInputSingleOutput({1, 3, 3, 1}, "graphInput", "output");
+ SetupSingleInputSingleOutput(armnn::TensorShape(4, dims.data()), "graphInput", "output");
}
};
-BOOST_FIXTURE_TEST_CASE(ParseFusedBatchNorm, FusedBatchNormFixture)
+struct FusedBatchNormNhwcFixture : FusedBatchNormFixture
+{
+ FusedBatchNormNhwcFixture() : FusedBatchNormFixture("NHWC"){}
+};
+BOOST_FIXTURE_TEST_CASE(ParseFusedBatchNormNhwc, FusedBatchNormNhwcFixture)
+{
+ RunTest<4>({ 1, 2, 3, 4, 5, 6, 7, 8, 9 }, // Input data.
+ { -2.8277204f, -2.12079024f, -1.4138602f,
+ -0.7069301f, 0.0f, 0.7069301f,
+ 1.4138602f, 2.12079024f, 2.8277204f }); // Expected output data.
+}
+
+struct FusedBatchNormNchwFixture : FusedBatchNormFixture
+{
+ FusedBatchNormNchwFixture() : FusedBatchNormFixture("NCHW"){}
+};
+BOOST_FIXTURE_TEST_CASE(ParseFusedBatchNormNchw, FusedBatchNormNchwFixture)
{
- RunTest<4>({1, 2, 3, 4, 5, 6, 7, 8, 9}, // Input data.
- {-2.8277204f, -2.12079024f, -1.4138602f,
- -0.7069301f, 0.0f, 0.7069301f,
- 1.4138602f, 2.12079024f, 2.8277204f}); // Expected output data.
+ RunTest<4>({ 1, 2, 3, 4, 5, 6, 7, 8, 9 }, // Input data.
+ { -2.8277204f, -2.12079024f, -1.4138602f,
+ -0.7069301f, 0.0f, 0.7069301f,
+ 1.4138602f, 2.12079024f, 2.8277204f }); // Expected output data.
}
BOOST_AUTO_TEST_SUITE_END()