aboutsummaryrefslogtreecommitdiff
path: root/src/backends/reference/test/RefCreateWorkloadTests.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/backends/reference/test/RefCreateWorkloadTests.cpp')
-rw-r--r--src/backends/reference/test/RefCreateWorkloadTests.cpp31
1 files changed, 25 insertions, 6 deletions
diff --git a/src/backends/reference/test/RefCreateWorkloadTests.cpp b/src/backends/reference/test/RefCreateWorkloadTests.cpp
index 236267c177..1ec7749168 100644
--- a/src/backends/reference/test/RefCreateWorkloadTests.cpp
+++ b/src/backends/reference/test/RefCreateWorkloadTests.cpp
@@ -231,21 +231,40 @@ BOOST_AUTO_TEST_CASE(CreateFullyConnectedUint8Workload)
}
template <typename NormalizationWorkloadType, armnn::DataType DataType>
-static void RefCreateNormalizationWorkloadTest()
+static void RefCreateNormalizationWorkloadTest(DataLayout dataLayout)
{
Graph graph;
RefWorkloadFactory factory;
- auto workload = CreateNormalizationWorkloadTest<NormalizationWorkloadType, DataType>(factory, graph);
+ auto workload = CreateNormalizationWorkloadTest<NormalizationWorkloadType, DataType>(factory, graph, dataLayout);
+
+ TensorShape inputShape;
+ TensorShape outputShape;
+
+ switch (dataLayout)
+ {
+ case DataLayout::NHWC:
+ inputShape = { 3, 1, 5, 5 };
+ outputShape = { 3, 1, 5, 5 };
+ break;
+ case DataLayout::NCHW:
+ default:
+ inputShape = { 3, 5, 5, 1 };
+ outputShape = { 3, 5, 5, 1 };
+ break;
+ }
// Checks that outputs and inputs are as we expect them (see definition of CreateNormalizationWorkloadTest).
- CheckInputOutput(std::move(workload),
- TensorInfo({3, 5, 5, 1}, DataType),
- TensorInfo({3, 5, 5, 1}, DataType));
+ CheckInputOutput(std::move(workload), TensorInfo(inputShape, DataType), TensorInfo(outputShape, DataType));
}
BOOST_AUTO_TEST_CASE(CreateRefNormalizationNchwWorkload)
{
- RefCreateNormalizationWorkloadTest<RefNormalizationFloat32Workload, armnn::DataType::Float32>();
+ RefCreateNormalizationWorkloadTest<RefNormalizationFloat32Workload, armnn::DataType::Float32>(DataLayout::NCHW);
+}
+
+BOOST_AUTO_TEST_CASE(CreateRefNormalizationNhwcWorkload)
+{
+ RefCreateNormalizationWorkloadTest<RefNormalizationFloat32Workload, armnn::DataType::Float32>(DataLayout::NHWC);
}
template <typename Pooling2dWorkloadType, armnn::DataType DataType>