aboutsummaryrefslogtreecommitdiff
path: root/src/backends/reference/test/RefCreateWorkloadTests.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/backends/reference/test/RefCreateWorkloadTests.cpp')
-rw-r--r--src/backends/reference/test/RefCreateWorkloadTests.cpp52
1 files changed, 46 insertions, 6 deletions
diff --git a/src/backends/reference/test/RefCreateWorkloadTests.cpp b/src/backends/reference/test/RefCreateWorkloadTests.cpp
index 1ec7749168..d258b81932 100644
--- a/src/backends/reference/test/RefCreateWorkloadTests.cpp
+++ b/src/backends/reference/test/RefCreateWorkloadTests.cpp
@@ -143,16 +143,56 @@ BOOST_AUTO_TEST_CASE(CreateDivisionUint8Workload)
armnn::DataType::QuantisedAsymm8>();
}
-BOOST_AUTO_TEST_CASE(CreateBatchNormalizationWorkload)
+template <typename BatchNormalizationWorkloadType, armnn::DataType DataType>
+static void RefCreateBatchNormalizationWorkloadTest(DataLayout dataLayout)
{
- Graph graph;
+ Graph graph;
RefWorkloadFactory factory;
- auto workload = CreateBatchNormalizationWorkloadTest<RefBatchNormalizationFloat32Workload, armnn::DataType::Float32>
- (factory, graph);
+ auto workload =
+ CreateBatchNormalizationWorkloadTest<BatchNormalizationWorkloadType, DataType>(factory, graph, dataLayout);
+
+ TensorShape inputShape;
+ TensorShape outputShape;
+
+ switch (dataLayout)
+ {
+ case DataLayout::NHWC:
+ inputShape = { 2, 1, 1, 3 };
+ outputShape = { 2, 1, 1, 3 };
+ break;
+ case DataLayout::NCHW:
+ default:
+ inputShape = { 2, 3, 1, 1 };
+ outputShape = { 2, 3, 1, 1 };
+ break;
+ }
// Checks that outputs and inputs are as we expect them (see definition of CreateBatchNormalizationWorkloadTest).
- CheckInputOutput(
- std::move(workload), TensorInfo({2, 3, 1, 1}, DataType::Float32), TensorInfo({2, 3, 1, 1}, DataType::Float32));
+ CheckInputOutput(std::move(workload), TensorInfo(inputShape, DataType), TensorInfo(outputShape, DataType));
+}
+
+BOOST_AUTO_TEST_CASE(CreateBatchNormalizationFloat32Workload)
+{
+ RefCreateBatchNormalizationWorkloadTest<RefBatchNormalizationFloat32Workload,armnn::DataType::Float32>
+ (DataLayout::NCHW);
+}
+
+BOOST_AUTO_TEST_CASE(CreateBatchNormalizationFloat32WorkloadNhwc)
+{
+ RefCreateBatchNormalizationWorkloadTest<RefBatchNormalizationFloat32Workload, armnn::DataType::Float32>
+ (DataLayout::NHWC);
+}
+
+BOOST_AUTO_TEST_CASE(CreateBatchNormalizationUint8Workload)
+{
+ RefCreateBatchNormalizationWorkloadTest<RefBatchNormalizationUint8Workload, armnn::DataType::QuantisedAsymm8>
+ (DataLayout::NCHW);
+}
+
+BOOST_AUTO_TEST_CASE(CreateBatchNormalizationUint8WorkloadNhwc)
+{
+ RefCreateBatchNormalizationWorkloadTest<RefBatchNormalizationUint8Workload, armnn::DataType::QuantisedAsymm8>
+ (DataLayout::NHWC);
}
BOOST_AUTO_TEST_CASE(CreateConvertFp16ToFp32Float32Workload)