From b63973ee1134336434a490fc9af8bba6cde79820 Mon Sep 17 00:00:00 2001 From: Matteo Martincigh Date: Tue, 16 Oct 2018 16:23:33 +0100 Subject: IVGCVSW-2018 Support NHWC in the current ref implementation * Enabled the now supported ref layer tests * Re-enabled the failing test now that the bug has been fixed in ACL 1903a9976ae24f40cb2203364211ed62fcfbb985 * Added CreateWorkload test for ref L2Normalization NHWC * Refactoring the ref L2Normalization for clarity !armnn:153723 Change-Id: Id0067e49072b3e057ffe3ae3b70d928be6091c0f --- .../reference/test/RefCreateWorkloadTests.cpp | 38 ++++++++++++++++++---- src/backends/reference/test/RefLayerTests.cpp | 9 +++-- 2 files changed, 35 insertions(+), 12 deletions(-) (limited to 'src/backends/reference/test') diff --git a/src/backends/reference/test/RefCreateWorkloadTests.cpp b/src/backends/reference/test/RefCreateWorkloadTests.cpp index a8901d2cc5..dc0348dc10 100644 --- a/src/backends/reference/test/RefCreateWorkloadTests.cpp +++ b/src/backends/reference/test/RefCreateWorkloadTests.cpp @@ -449,18 +449,42 @@ BOOST_AUTO_TEST_CASE(CreateResizeBilinearFloat32Nhwc) RefCreateResizeBilinearTest(DataLayout::NHWC); } -BOOST_AUTO_TEST_CASE(CreateL2NormalizationFloat32) +template +static void RefCreateL2NormalizationTest(DataLayout dataLayout) { Graph graph; RefWorkloadFactory factory; - auto workload = CreateL2NormalizationWorkloadTest - (factory, graph); + auto workload = + CreateL2NormalizationWorkloadTest(factory, graph, dataLayout); + + TensorShape inputShape; + TensorShape outputShape; + + switch (dataLayout) + { + case DataLayout::NHWC: + inputShape = { 5, 50, 67, 20 }; + outputShape = { 5, 50, 67, 20 }; + break; + case DataLayout::NCHW: + default: + inputShape = { 5, 20, 50, 67 }; + outputShape = { 5, 20, 50, 67 }; + break; + } // Checks that outputs and inputs are as we expect them (see definition of CreateL2NormalizationWorkloadTest). - CheckInputOutput( - std::move(workload), - TensorInfo({ 5, 20, 50, 67 }, armnn::DataType::Float32), - TensorInfo({ 5, 20, 50, 67 }, armnn::DataType::Float32)); + CheckInputOutput(std::move(workload), TensorInfo(inputShape, DataType), TensorInfo(outputShape, DataType)); +} + +BOOST_AUTO_TEST_CASE(CreateL2NormalizationFloat32) +{ + RefCreateL2NormalizationTest(DataLayout::NCHW); +} + +BOOST_AUTO_TEST_CASE(CreateL2NormalizationFloat32Nhwc) +{ + RefCreateL2NormalizationTest(DataLayout::NHWC); } template diff --git a/src/backends/reference/test/RefLayerTests.cpp b/src/backends/reference/test/RefLayerTests.cpp index 797051ee18..2815e342c0 100644 --- a/src/backends/reference/test/RefLayerTests.cpp +++ b/src/backends/reference/test/RefLayerTests.cpp @@ -211,11 +211,10 @@ ARMNN_AUTO_TEST_CASE(Pad2d, Pad2dTest) ARMNN_AUTO_TEST_CASE(Pad3d, Pad3dTest) ARMNN_AUTO_TEST_CASE(Pad4d, Pad4dTest) -// NOTE: These tests are disabled until NHWC is supported by the reference L2Normalization implementation. -//ARMNN_AUTO_TEST_CASE(L2Normalization1dNhwc, L2Normalization1dNhwcTest); -//ARMNN_AUTO_TEST_CASE(L2Normalization2dNhwc, L2Normalization2dNhwcTest); -//ARMNN_AUTO_TEST_CASE(L2Normalization3dNhwc, L2Normalization3dNhwcTest); -//ARMNN_AUTO_TEST_CASE(L2Normalization4dNhwc, L2Normalization4dNhwcTest); +ARMNN_AUTO_TEST_CASE(L2Normalization1dNhwc, L2Normalization1dNhwcTest) +ARMNN_AUTO_TEST_CASE(L2Normalization2dNhwc, L2Normalization2dNhwcTest) +ARMNN_AUTO_TEST_CASE(L2Normalization3dNhwc, L2Normalization3dNhwcTest) +ARMNN_AUTO_TEST_CASE(L2Normalization4dNhwc, L2Normalization4dNhwcTest) // Constant ARMNN_AUTO_TEST_CASE(Constant, ConstantTest) -- cgit v1.2.1