diff options
author | Matteo Martincigh <matteo.martincigh@arm.com> | 2018-10-16 16:23:33 +0100 |
---|---|---|
committer | Matthew Bentham <matthew.bentham@arm.com> | 2018-10-22 16:57:54 +0100 |
commit | b63973ee1134336434a490fc9af8bba6cde79820 (patch) | |
tree | 1304b693044697454bc10cd52b7a4746444b5feb /src/backends/reference/test | |
parent | 177d8d26925a58a579943e010d28d1ceaa033d64 (diff) | |
download | armnn-b63973ee1134336434a490fc9af8bba6cde79820.tar.gz |
IVGCVSW-2018 Support NHWC in the current ref implementation
* Enabled the now supported ref layer tests
* Re-enabled the failing test now that the bug has been fixed in
ACL 1903a9976ae24f40cb2203364211ed62fcfbb985
* Added CreateWorkload test for ref L2Normalization NHWC
* Refactoring the ref L2Normalization for clarity
!armnn:153723
Change-Id: Id0067e49072b3e057ffe3ae3b70d928be6091c0f
Diffstat (limited to 'src/backends/reference/test')
-rw-r--r-- | src/backends/reference/test/RefCreateWorkloadTests.cpp | 38 | ||||
-rw-r--r-- | src/backends/reference/test/RefLayerTests.cpp | 9 |
2 files changed, 35 insertions, 12 deletions
diff --git a/src/backends/reference/test/RefCreateWorkloadTests.cpp b/src/backends/reference/test/RefCreateWorkloadTests.cpp index a8901d2cc5..dc0348dc10 100644 --- a/src/backends/reference/test/RefCreateWorkloadTests.cpp +++ b/src/backends/reference/test/RefCreateWorkloadTests.cpp @@ -449,18 +449,42 @@ BOOST_AUTO_TEST_CASE(CreateResizeBilinearFloat32Nhwc) RefCreateResizeBilinearTest<RefResizeBilinearFloat32Workload, armnn::DataType::Float32>(DataLayout::NHWC); } -BOOST_AUTO_TEST_CASE(CreateL2NormalizationFloat32) +template <typename L2NormalizationWorkloadType, armnn::DataType DataType> +static void RefCreateL2NormalizationTest(DataLayout dataLayout) { Graph graph; RefWorkloadFactory factory; - auto workload = CreateL2NormalizationWorkloadTest<RefL2NormalizationFloat32Workload, armnn::DataType::Float32> - (factory, graph); + auto workload = + CreateL2NormalizationWorkloadTest<L2NormalizationWorkloadType, DataType>(factory, graph, dataLayout); + + TensorShape inputShape; + TensorShape outputShape; + + switch (dataLayout) + { + case DataLayout::NHWC: + inputShape = { 5, 50, 67, 20 }; + outputShape = { 5, 50, 67, 20 }; + break; + case DataLayout::NCHW: + default: + inputShape = { 5, 20, 50, 67 }; + outputShape = { 5, 20, 50, 67 }; + break; + } // Checks that outputs and inputs are as we expect them (see definition of CreateL2NormalizationWorkloadTest). - CheckInputOutput( - std::move(workload), - TensorInfo({ 5, 20, 50, 67 }, armnn::DataType::Float32), - TensorInfo({ 5, 20, 50, 67 }, armnn::DataType::Float32)); + CheckInputOutput(std::move(workload), TensorInfo(inputShape, DataType), TensorInfo(outputShape, DataType)); +} + +BOOST_AUTO_TEST_CASE(CreateL2NormalizationFloat32) +{ + RefCreateL2NormalizationTest<RefL2NormalizationFloat32Workload, armnn::DataType::Float32>(DataLayout::NCHW); +} + +BOOST_AUTO_TEST_CASE(CreateL2NormalizationFloat32Nhwc) +{ + RefCreateL2NormalizationTest<RefL2NormalizationFloat32Workload, armnn::DataType::Float32>(DataLayout::NHWC); } template <typename ReshapeWorkloadType, armnn::DataType DataType> diff --git a/src/backends/reference/test/RefLayerTests.cpp b/src/backends/reference/test/RefLayerTests.cpp index 797051ee18..2815e342c0 100644 --- a/src/backends/reference/test/RefLayerTests.cpp +++ b/src/backends/reference/test/RefLayerTests.cpp @@ -211,11 +211,10 @@ ARMNN_AUTO_TEST_CASE(Pad2d, Pad2dTest) ARMNN_AUTO_TEST_CASE(Pad3d, Pad3dTest) ARMNN_AUTO_TEST_CASE(Pad4d, Pad4dTest) -// NOTE: These tests are disabled until NHWC is supported by the reference L2Normalization implementation. -//ARMNN_AUTO_TEST_CASE(L2Normalization1dNhwc, L2Normalization1dNhwcTest); -//ARMNN_AUTO_TEST_CASE(L2Normalization2dNhwc, L2Normalization2dNhwcTest); -//ARMNN_AUTO_TEST_CASE(L2Normalization3dNhwc, L2Normalization3dNhwcTest); -//ARMNN_AUTO_TEST_CASE(L2Normalization4dNhwc, L2Normalization4dNhwcTest); +ARMNN_AUTO_TEST_CASE(L2Normalization1dNhwc, L2Normalization1dNhwcTest) +ARMNN_AUTO_TEST_CASE(L2Normalization2dNhwc, L2Normalization2dNhwcTest) +ARMNN_AUTO_TEST_CASE(L2Normalization3dNhwc, L2Normalization3dNhwcTest) +ARMNN_AUTO_TEST_CASE(L2Normalization4dNhwc, L2Normalization4dNhwcTest) // Constant ARMNN_AUTO_TEST_CASE(Constant, ConstantTest) |