aboutsummaryrefslogtreecommitdiff
path: root/src/backends/reference/test
diff options
context:
space:
mode:
Diffstat (limited to 'src/backends/reference/test')
-rw-r--r--src/backends/reference/test/RefCreateWorkloadTests.cpp38
-rw-r--r--src/backends/reference/test/RefLayerTests.cpp9
2 files changed, 35 insertions, 12 deletions
diff --git a/src/backends/reference/test/RefCreateWorkloadTests.cpp b/src/backends/reference/test/RefCreateWorkloadTests.cpp
index a8901d2cc5..dc0348dc10 100644
--- a/src/backends/reference/test/RefCreateWorkloadTests.cpp
+++ b/src/backends/reference/test/RefCreateWorkloadTests.cpp
@@ -449,18 +449,42 @@ BOOST_AUTO_TEST_CASE(CreateResizeBilinearFloat32Nhwc)
RefCreateResizeBilinearTest<RefResizeBilinearFloat32Workload, armnn::DataType::Float32>(DataLayout::NHWC);
}
-BOOST_AUTO_TEST_CASE(CreateL2NormalizationFloat32)
+template <typename L2NormalizationWorkloadType, armnn::DataType DataType>
+static void RefCreateL2NormalizationTest(DataLayout dataLayout)
{
Graph graph;
RefWorkloadFactory factory;
- auto workload = CreateL2NormalizationWorkloadTest<RefL2NormalizationFloat32Workload, armnn::DataType::Float32>
- (factory, graph);
+ auto workload =
+ CreateL2NormalizationWorkloadTest<L2NormalizationWorkloadType, DataType>(factory, graph, dataLayout);
+
+ TensorShape inputShape;
+ TensorShape outputShape;
+
+ switch (dataLayout)
+ {
+ case DataLayout::NHWC:
+ inputShape = { 5, 50, 67, 20 };
+ outputShape = { 5, 50, 67, 20 };
+ break;
+ case DataLayout::NCHW:
+ default:
+ inputShape = { 5, 20, 50, 67 };
+ outputShape = { 5, 20, 50, 67 };
+ break;
+ }
// Checks that outputs and inputs are as we expect them (see definition of CreateL2NormalizationWorkloadTest).
- CheckInputOutput(
- std::move(workload),
- TensorInfo({ 5, 20, 50, 67 }, armnn::DataType::Float32),
- TensorInfo({ 5, 20, 50, 67 }, armnn::DataType::Float32));
+ CheckInputOutput(std::move(workload), TensorInfo(inputShape, DataType), TensorInfo(outputShape, DataType));
+}
+
+BOOST_AUTO_TEST_CASE(CreateL2NormalizationFloat32)
+{
+ RefCreateL2NormalizationTest<RefL2NormalizationFloat32Workload, armnn::DataType::Float32>(DataLayout::NCHW);
+}
+
+BOOST_AUTO_TEST_CASE(CreateL2NormalizationFloat32Nhwc)
+{
+ RefCreateL2NormalizationTest<RefL2NormalizationFloat32Workload, armnn::DataType::Float32>(DataLayout::NHWC);
}
template <typename ReshapeWorkloadType, armnn::DataType DataType>
diff --git a/src/backends/reference/test/RefLayerTests.cpp b/src/backends/reference/test/RefLayerTests.cpp
index 797051ee18..2815e342c0 100644
--- a/src/backends/reference/test/RefLayerTests.cpp
+++ b/src/backends/reference/test/RefLayerTests.cpp
@@ -211,11 +211,10 @@ ARMNN_AUTO_TEST_CASE(Pad2d, Pad2dTest)
ARMNN_AUTO_TEST_CASE(Pad3d, Pad3dTest)
ARMNN_AUTO_TEST_CASE(Pad4d, Pad4dTest)
-// NOTE: These tests are disabled until NHWC is supported by the reference L2Normalization implementation.
-//ARMNN_AUTO_TEST_CASE(L2Normalization1dNhwc, L2Normalization1dNhwcTest);
-//ARMNN_AUTO_TEST_CASE(L2Normalization2dNhwc, L2Normalization2dNhwcTest);
-//ARMNN_AUTO_TEST_CASE(L2Normalization3dNhwc, L2Normalization3dNhwcTest);
-//ARMNN_AUTO_TEST_CASE(L2Normalization4dNhwc, L2Normalization4dNhwcTest);
+ARMNN_AUTO_TEST_CASE(L2Normalization1dNhwc, L2Normalization1dNhwcTest)
+ARMNN_AUTO_TEST_CASE(L2Normalization2dNhwc, L2Normalization2dNhwcTest)
+ARMNN_AUTO_TEST_CASE(L2Normalization3dNhwc, L2Normalization3dNhwcTest)
+ARMNN_AUTO_TEST_CASE(L2Normalization4dNhwc, L2Normalization4dNhwcTest)
// Constant
ARMNN_AUTO_TEST_CASE(Constant, ConstantTest)