aboutsummaryrefslogtreecommitdiff
path: root/src/backends/reference/test/RefCreateWorkloadTests.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/backends/reference/test/RefCreateWorkloadTests.cpp')
-rw-r--r--src/backends/reference/test/RefCreateWorkloadTests.cpp27
1 files changed, 21 insertions, 6 deletions
diff --git a/src/backends/reference/test/RefCreateWorkloadTests.cpp b/src/backends/reference/test/RefCreateWorkloadTests.cpp
index dc0348dc10..236267c177 100644
--- a/src/backends/reference/test/RefCreateWorkloadTests.cpp
+++ b/src/backends/reference/test/RefCreateWorkloadTests.cpp
@@ -177,17 +177,32 @@ BOOST_AUTO_TEST_CASE(CreateConvertFp32ToFp16Float16Workload)
std::move(workload), TensorInfo({1, 3, 2, 3}, DataType::Float32), TensorInfo({1, 3, 2, 3}, DataType::Float16));
}
-BOOST_AUTO_TEST_CASE(CreateConvolution2dWorkload)
+static void RefCreateConvolution2dWorkloadTest(DataLayout dataLayout = DataLayout::NCHW)
{
- Graph graph;
+ Graph graph;
RefWorkloadFactory factory;
- auto workload = CreateConvolution2dWorkloadTest<RefConvolution2dFloat32Workload,
- DataType::Float32>(factory, graph);
+ auto workload = CreateConvolution2dWorkloadTest<RefConvolution2dFloat32Workload, DataType::Float32>
+ (factory, graph, dataLayout);
+
+ std::initializer_list<unsigned int> inputShape = (dataLayout == DataLayout::NCHW) ?
+ std::initializer_list<unsigned int>({2, 3, 8, 16}) : std::initializer_list<unsigned int>({2, 8, 16, 3});
+ std::initializer_list<unsigned int> outputShape = (dataLayout == DataLayout::NCHW) ?
+ std::initializer_list<unsigned int>({2, 2, 2, 10}) : std::initializer_list<unsigned int>({2, 2, 10, 2});
// Checks that outputs and inputs are as we expect them (see definition of CreateConvolution2dWorkloadTest).
CheckInputOutput(std::move(workload),
- TensorInfo({2, 3, 8, 16}, DataType::Float32),
- TensorInfo({2, 2, 2, 10}, DataType::Float32));
+ TensorInfo(inputShape, DataType::Float32),
+ TensorInfo(outputShape, DataType::Float32));
+}
+
+BOOST_AUTO_TEST_CASE(CreateConvolution2dFloatNchwWorkload)
+{
+ RefCreateConvolution2dWorkloadTest(DataLayout::NCHW);
+}
+
+BOOST_AUTO_TEST_CASE(CreateConvolution2dFloatNhwcWorkload)
+{
+ RefCreateConvolution2dWorkloadTest(DataLayout::NHWC);
}
template <typename FullyConnectedWorkloadType, armnn::DataType DataType>