aboutsummaryrefslogtreecommitdiff
path: root/src/backends/test/CreateWorkloadCl.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/backends/test/CreateWorkloadCl.cpp')
-rw-r--r--src/backends/test/CreateWorkloadCl.cpp31
1 files changed, 23 insertions, 8 deletions
diff --git a/src/backends/test/CreateWorkloadCl.cpp b/src/backends/test/CreateWorkloadCl.cpp
index e7e39b0f70..0314f6d92a 100644
--- a/src/backends/test/CreateWorkloadCl.cpp
+++ b/src/backends/test/CreateWorkloadCl.cpp
@@ -320,30 +320,45 @@ BOOST_AUTO_TEST_CASE(CreateNormalizationFloat16NhwcWorkload)
}
template <typename Pooling2dWorkloadType, typename armnn::DataType DataType>
-static void ClPooling2dWorkloadTest()
+static void ClPooling2dWorkloadTest(DataLayout dataLayout)
{
Graph graph;
ClWorkloadFactory factory;
- auto workload = CreatePooling2dWorkloadTest<Pooling2dWorkloadType, DataType>(factory, graph);
+ auto workload = CreatePooling2dWorkloadTest<Pooling2dWorkloadType, DataType>(factory, graph, dataLayout);
+
+ std::initializer_list<unsigned int> inputShape = (dataLayout == DataLayout::NCHW) ?
+ std::initializer_list<unsigned int>({3, 2, 5, 5}) : std::initializer_list<unsigned int>({3, 5, 5, 2});
+ std::initializer_list<unsigned int> outputShape = (dataLayout == DataLayout::NCHW) ?
+ std::initializer_list<unsigned int>({3, 2, 2, 4}) : std::initializer_list<unsigned int>({3, 2, 4, 2});
// Check that inputs/outputs are as we expect them (see definition of CreatePooling2dWorkloadTest).
Pooling2dQueueDescriptor queueDescriptor = workload->GetData();
auto inputHandle = boost::polymorphic_downcast<IClTensorHandle*>(queueDescriptor.m_Inputs[0]);
auto outputHandle = boost::polymorphic_downcast<IClTensorHandle*>(queueDescriptor.m_Outputs[0]);
- BOOST_TEST(CompareIClTensorHandleShape(inputHandle, {3, 2, 5, 5}));
- BOOST_TEST(CompareIClTensorHandleShape(outputHandle, {3, 2, 2, 4}));
+ BOOST_TEST(CompareIClTensorHandleShape(inputHandle, inputShape));
+ BOOST_TEST(CompareIClTensorHandleShape(outputHandle, outputShape));
+}
+
+BOOST_AUTO_TEST_CASE(CreatePooling2dFloatNchwWorkload)
+{
+ ClPooling2dWorkloadTest<ClPooling2dFloatWorkload, armnn::DataType::Float32>(DataLayout::NCHW);
+}
+
+BOOST_AUTO_TEST_CASE(CreatePooling2dFloatNhwcWorkload)
+{
+ ClPooling2dWorkloadTest<ClPooling2dFloatWorkload, armnn::DataType::Float32>(DataLayout::NHWC);
}
-BOOST_AUTO_TEST_CASE(CreatePooling2dFloatWorkload)
+BOOST_AUTO_TEST_CASE(CreatePooling2dFloat16NchwWorkload)
{
- ClPooling2dWorkloadTest<ClPooling2dFloatWorkload, armnn::DataType::Float32>();
+ ClPooling2dWorkloadTest<ClPooling2dFloatWorkload, armnn::DataType::Float16>(DataLayout::NCHW);
}
-BOOST_AUTO_TEST_CASE(CreatePooling2dFloat16Workload)
+BOOST_AUTO_TEST_CASE(CreatePooling2dFloat16NhwcWorkload)
{
- ClPooling2dWorkloadTest<ClPooling2dFloatWorkload, armnn::DataType::Float16>();
+ ClPooling2dWorkloadTest<ClPooling2dFloatWorkload, armnn::DataType::Float16>(DataLayout::NHWC);
}
template <typename ReshapeWorkloadType, typename armnn::DataType DataType>