aboutsummaryrefslogtreecommitdiff
path: root/src/backends/test/CreateWorkloadNeon.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/backends/test/CreateWorkloadNeon.cpp')
-rw-r--r--src/backends/test/CreateWorkloadNeon.cpp27
1 files changed, 20 insertions, 7 deletions
diff --git a/src/backends/test/CreateWorkloadNeon.cpp b/src/backends/test/CreateWorkloadNeon.cpp
index a67e68d8a5..b2ec563a69 100644
--- a/src/backends/test/CreateWorkloadNeon.cpp
+++ b/src/backends/test/CreateWorkloadNeon.cpp
@@ -179,33 +179,46 @@ BOOST_AUTO_TEST_CASE(CreateBatchNormalizationFloatWorkload)
}
template <typename Convolution2dWorkloadType, typename armnn::DataType DataType>
-static void NeonCreateConvolution2dWorkloadTest()
+static void NeonCreateConvolution2dWorkloadTest(DataLayout dataLayout = DataLayout::NCHW)
{
Graph graph;
NeonWorkloadFactory factory;
auto workload = CreateConvolution2dWorkloadTest<Convolution2dWorkloadType,
- DataType>(factory, graph);
+ DataType>(factory, graph, dataLayout);
+
+ TensorShape inputShape = (dataLayout == DataLayout::NCHW) ? TensorShape{2, 3, 8, 16} : TensorShape{2, 8, 16, 3};
+ TensorShape outputShape = (dataLayout == DataLayout::NCHW) ? TensorShape{2, 2, 2, 10} : TensorShape{2, 2, 10, 2};
// Checks that outputs and inputs are as we expect them (see definition of CreateConvolution2dWorkloadTest).
Convolution2dQueueDescriptor queueDescriptor = workload->GetData();
auto inputHandle = boost::polymorphic_downcast<INeonTensorHandle*>(queueDescriptor.m_Inputs[0]);
auto outputHandle = boost::polymorphic_downcast<INeonTensorHandle*>(queueDescriptor.m_Outputs[0]);
- BOOST_TEST(TestNeonTensorHandleInfo(inputHandle, TensorInfo({2, 3, 8, 16}, DataType)));
- BOOST_TEST(TestNeonTensorHandleInfo(outputHandle, TensorInfo({2, 2, 2, 10}, DataType)));
+ BOOST_TEST(TestNeonTensorHandleInfo(inputHandle, TensorInfo(inputShape, DataType)));
+ BOOST_TEST(TestNeonTensorHandleInfo(outputHandle, TensorInfo(outputShape, DataType)));
}
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
-BOOST_AUTO_TEST_CASE(CreateConvolution2dFloat16Workload)
+BOOST_AUTO_TEST_CASE(CreateConvolution2dFloat16NchwWorkload)
{
NeonCreateConvolution2dWorkloadTest<NeonConvolution2dFloatWorkload, DataType::Float16>();
}
-#endif
-BOOST_AUTO_TEST_CASE(CreateConvolution2dFloatWorkload)
+BOOST_AUTO_TEST_CASE(CreateConvolution2dFloat16NhwcWorkload)
+{
+ NeonCreateConvolution2dWorkloadTest<NeonConvolution2dFloatWorkload, DataType::Float16>(DataLayout::NHWC);
+}
+
+#endif
+BOOST_AUTO_TEST_CASE(CreateConvolution2dFloatNchwWorkload)
{
NeonCreateConvolution2dWorkloadTest<NeonConvolution2dFloatWorkload, DataType::Float32>();
}
+BOOST_AUTO_TEST_CASE(CreateConvolution2dFloatNhwcWorkload)
+{
+ NeonCreateConvolution2dWorkloadTest<NeonConvolution2dFloatWorkload, DataType::Float32>(DataLayout::NHWC);
+}
+
template <typename FullyConnectedWorkloadType, typename armnn::DataType DataType>
static void NeonCreateFullyConnectedWorkloadTest()
{