aboutsummaryrefslogtreecommitdiff
path: root/src/backends/test/CreateWorkloadNeon.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/backends/test/CreateWorkloadNeon.cpp')
-rw-r--r--src/backends/test/CreateWorkloadNeon.cpp29
1 files changed, 21 insertions, 8 deletions
diff --git a/src/backends/test/CreateWorkloadNeon.cpp b/src/backends/test/CreateWorkloadNeon.cpp
index a6f3540994..a67e68d8a5 100644
--- a/src/backends/test/CreateWorkloadNeon.cpp
+++ b/src/backends/test/CreateWorkloadNeon.cpp
@@ -273,19 +273,22 @@ BOOST_AUTO_TEST_CASE(CreateNormalizationFloatNhwcWorkload)
template <typename Pooling2dWorkloadType, typename armnn::DataType DataType>
-static void NeonCreatePooling2dWorkloadTest()
+static void NeonCreatePooling2dWorkloadTest(DataLayout dataLayout = DataLayout::NCHW)
{
Graph graph;
NeonWorkloadFactory factory;
auto workload = CreatePooling2dWorkloadTest<Pooling2dWorkloadType, DataType>
- (factory, graph);
+ (factory, graph, dataLayout);
+
+ TensorShape inputShape = (dataLayout == DataLayout::NCHW) ? TensorShape{3, 2, 5, 5} : TensorShape{3, 5, 5, 2};
+ TensorShape outputShape = (dataLayout == DataLayout::NCHW) ? TensorShape{3, 2, 2, 4} : TensorShape{3, 2, 4, 2};
// Checks that outputs and inputs are as we expect them (see definition of CreatePooling2dWorkloadTest).
Pooling2dQueueDescriptor queueDescriptor = workload->GetData();
auto inputHandle = boost::polymorphic_downcast<INeonTensorHandle*>(queueDescriptor.m_Inputs[0]);
auto outputHandle = boost::polymorphic_downcast<INeonTensorHandle*>(queueDescriptor.m_Outputs[0]);
- BOOST_TEST(TestNeonTensorHandleInfo(inputHandle, TensorInfo({3, 2, 5, 5}, DataType)));
- BOOST_TEST(TestNeonTensorHandleInfo(outputHandle, TensorInfo({3, 2, 2, 4}, DataType)));
+ BOOST_TEST(TestNeonTensorHandleInfo(inputHandle, TensorInfo(inputShape, DataType)));
+ BOOST_TEST(TestNeonTensorHandleInfo(outputHandle, TensorInfo(outputShape, DataType)));
}
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
@@ -295,14 +298,24 @@ BOOST_AUTO_TEST_CASE(CreatePooling2dFloat16Workload)
}
#endif
-BOOST_AUTO_TEST_CASE(CreatePooling2dFloatWorkload)
+BOOST_AUTO_TEST_CASE(CreatePooling2dFloatNchwWorkload)
+{
+ NeonCreatePooling2dWorkloadTest<NeonPooling2dFloatWorkload, DataType::Float32>(DataLayout::NCHW);
+}
+
+BOOST_AUTO_TEST_CASE(CreatePooling2dFloatNhwcWorkload)
+{
+ NeonCreatePooling2dWorkloadTest<NeonPooling2dFloatWorkload, DataType::Float32>(DataLayout::NHWC);
+}
+
+BOOST_AUTO_TEST_CASE(CreatePooling2dUint8NchwWorkload)
{
- NeonCreatePooling2dWorkloadTest<NeonPooling2dFloatWorkload, DataType::Float32>();
+ NeonCreatePooling2dWorkloadTest<NeonPooling2dUint8Workload, DataType::QuantisedAsymm8>(DataLayout::NCHW);
}
-BOOST_AUTO_TEST_CASE(CreatePooling2dUint8Workload)
+BOOST_AUTO_TEST_CASE(CreatePooling2dUint8NhwcWorkload)
{
- NeonCreatePooling2dWorkloadTest<NeonPooling2dUint8Workload, DataType::QuantisedAsymm8>();
+ NeonCreatePooling2dWorkloadTest<NeonPooling2dUint8Workload, DataType::QuantisedAsymm8>(DataLayout::NHWC);
}
template <typename ReshapeWorkloadType, typename armnn::DataType DataType>