From 075c7504b6d6ecd50fa61ca53ab1bcccc8865843 Mon Sep 17 00:00:00 2001 From: Matteo Martincigh Date: Wed, 5 Dec 2018 13:10:45 +0000 Subject: IVGCVSW-2267 Remove the input swizzling from ParseFusedBatchNorm * Removed the input swizzling when the data layout is NHWC * Split the unit test into NHWC and NCHW cases Change-Id: I6b9fef70bc4ba5e01d14cbfaea3c842a289b0a0e --- src/armnnTfParser/test/FusedBatchNorm.cpp | 77 ++++++++++++++++++++++--------- 1 file changed, 54 insertions(+), 23 deletions(-) (limited to 'src/armnnTfParser/test') diff --git a/src/armnnTfParser/test/FusedBatchNorm.cpp b/src/armnnTfParser/test/FusedBatchNorm.cpp index bb9e3ed863..98bdb26183 100644 --- a/src/armnnTfParser/test/FusedBatchNorm.cpp +++ b/src/armnnTfParser/test/FusedBatchNorm.cpp @@ -7,11 +7,13 @@ #include "armnnTfParser/ITfParser.hpp" #include "ParserPrototxtFixture.hpp" +#include + BOOST_AUTO_TEST_SUITE(TensorflowParser) struct FusedBatchNormFixture : public armnnUtils::ParserPrototxtFixture { - FusedBatchNormFixture() + explicit FusedBatchNormFixture(const std::string& dataLayout) { m_Prototext = "node { \n" " name: \"graphInput\" \n" @@ -143,33 +145,62 @@ struct FusedBatchNormFixture : public armnnUtils::ParserPrototxtFixture dims; + if (dataLayout == "NHWC") + { + dims = { 1u, 3u, 3u, 1u }; + } + else // dataLayout == "NCHW" + { + dims = { 1u, 1u, 3u, 3u }; + } - SetupSingleInputSingleOutput({1, 3, 3, 1}, "graphInput", "output"); + SetupSingleInputSingleOutput(armnn::TensorShape(4, dims.data()), "graphInput", "output"); } }; -BOOST_FIXTURE_TEST_CASE(ParseFusedBatchNorm, FusedBatchNormFixture) +struct FusedBatchNormNhwcFixture : FusedBatchNormFixture +{ + FusedBatchNormNhwcFixture() : FusedBatchNormFixture("NHWC"){} +}; +BOOST_FIXTURE_TEST_CASE(ParseFusedBatchNormNhwc, FusedBatchNormNhwcFixture) +{ + RunTest<4>({ 1, 2, 3, 4, 5, 6, 7, 8, 9 }, // Input data. + { -2.8277204f, -2.12079024f, -1.4138602f, + -0.7069301f, 0.0f, 0.7069301f, + 1.4138602f, 2.12079024f, 2.8277204f }); // Expected output data. +} + +struct FusedBatchNormNchwFixture : FusedBatchNormFixture +{ + FusedBatchNormNchwFixture() : FusedBatchNormFixture("NCHW"){} +}; +BOOST_FIXTURE_TEST_CASE(ParseFusedBatchNormNchw, FusedBatchNormNchwFixture) { - RunTest<4>({1, 2, 3, 4, 5, 6, 7, 8, 9}, // Input data. - {-2.8277204f, -2.12079024f, -1.4138602f, - -0.7069301f, 0.0f, 0.7069301f, - 1.4138602f, 2.12079024f, 2.8277204f}); // Expected output data. + RunTest<4>({ 1, 2, 3, 4, 5, 6, 7, 8, 9 }, // Input data. + { -2.8277204f, -2.12079024f, -1.4138602f, + -0.7069301f, 0.0f, 0.7069301f, + 1.4138602f, 2.12079024f, 2.8277204f }); // Expected output data. } BOOST_AUTO_TEST_SUITE_END() -- cgit v1.2.1