diff options
Diffstat (limited to 'src/backends/reference/test/RefCreateWorkloadTests.cpp')
-rw-r--r-- | src/backends/reference/test/RefCreateWorkloadTests.cpp | 38 |
1 files changed, 31 insertions, 7 deletions
diff --git a/src/backends/reference/test/RefCreateWorkloadTests.cpp b/src/backends/reference/test/RefCreateWorkloadTests.cpp index a8901d2cc5..dc0348dc10 100644 --- a/src/backends/reference/test/RefCreateWorkloadTests.cpp +++ b/src/backends/reference/test/RefCreateWorkloadTests.cpp @@ -449,18 +449,42 @@ BOOST_AUTO_TEST_CASE(CreateResizeBilinearFloat32Nhwc) RefCreateResizeBilinearTest<RefResizeBilinearFloat32Workload, armnn::DataType::Float32>(DataLayout::NHWC); } -BOOST_AUTO_TEST_CASE(CreateL2NormalizationFloat32) +template <typename L2NormalizationWorkloadType, armnn::DataType DataType> +static void RefCreateL2NormalizationTest(DataLayout dataLayout) { Graph graph; RefWorkloadFactory factory; - auto workload = CreateL2NormalizationWorkloadTest<RefL2NormalizationFloat32Workload, armnn::DataType::Float32> - (factory, graph); + auto workload = + CreateL2NormalizationWorkloadTest<L2NormalizationWorkloadType, DataType>(factory, graph, dataLayout); + + TensorShape inputShape; + TensorShape outputShape; + + switch (dataLayout) + { + case DataLayout::NHWC: + inputShape = { 5, 50, 67, 20 }; + outputShape = { 5, 50, 67, 20 }; + break; + case DataLayout::NCHW: + default: + inputShape = { 5, 20, 50, 67 }; + outputShape = { 5, 20, 50, 67 }; + break; + } // Checks that outputs and inputs are as we expect them (see definition of CreateL2NormalizationWorkloadTest). - CheckInputOutput( - std::move(workload), - TensorInfo({ 5, 20, 50, 67 }, armnn::DataType::Float32), - TensorInfo({ 5, 20, 50, 67 }, armnn::DataType::Float32)); + CheckInputOutput(std::move(workload), TensorInfo(inputShape, DataType), TensorInfo(outputShape, DataType)); +} + +BOOST_AUTO_TEST_CASE(CreateL2NormalizationFloat32) +{ + RefCreateL2NormalizationTest<RefL2NormalizationFloat32Workload, armnn::DataType::Float32>(DataLayout::NCHW); +} + +BOOST_AUTO_TEST_CASE(CreateL2NormalizationFloat32Nhwc) +{ + RefCreateL2NormalizationTest<RefL2NormalizationFloat32Workload, armnn::DataType::Float32>(DataLayout::NHWC); } template <typename ReshapeWorkloadType, armnn::DataType DataType> |