diff options
Diffstat (limited to 'src/backends/test/CreateWorkloadCl.cpp')
-rw-r--r-- | src/backends/test/CreateWorkloadCl.cpp | 31 |
1 files changed, 23 insertions, 8 deletions
diff --git a/src/backends/test/CreateWorkloadCl.cpp b/src/backends/test/CreateWorkloadCl.cpp index e7e39b0f70..0314f6d92a 100644 --- a/src/backends/test/CreateWorkloadCl.cpp +++ b/src/backends/test/CreateWorkloadCl.cpp @@ -320,30 +320,45 @@ BOOST_AUTO_TEST_CASE(CreateNormalizationFloat16NhwcWorkload) } template <typename Pooling2dWorkloadType, typename armnn::DataType DataType> -static void ClPooling2dWorkloadTest() +static void ClPooling2dWorkloadTest(DataLayout dataLayout) { Graph graph; ClWorkloadFactory factory; - auto workload = CreatePooling2dWorkloadTest<Pooling2dWorkloadType, DataType>(factory, graph); + auto workload = CreatePooling2dWorkloadTest<Pooling2dWorkloadType, DataType>(factory, graph, dataLayout); + + std::initializer_list<unsigned int> inputShape = (dataLayout == DataLayout::NCHW) ? + std::initializer_list<unsigned int>({3, 2, 5, 5}) : std::initializer_list<unsigned int>({3, 5, 5, 2}); + std::initializer_list<unsigned int> outputShape = (dataLayout == DataLayout::NCHW) ? + std::initializer_list<unsigned int>({3, 2, 2, 4}) : std::initializer_list<unsigned int>({3, 2, 4, 2}); // Check that inputs/outputs are as we expect them (see definition of CreatePooling2dWorkloadTest). Pooling2dQueueDescriptor queueDescriptor = workload->GetData(); auto inputHandle = boost::polymorphic_downcast<IClTensorHandle*>(queueDescriptor.m_Inputs[0]); auto outputHandle = boost::polymorphic_downcast<IClTensorHandle*>(queueDescriptor.m_Outputs[0]); - BOOST_TEST(CompareIClTensorHandleShape(inputHandle, {3, 2, 5, 5})); - BOOST_TEST(CompareIClTensorHandleShape(outputHandle, {3, 2, 2, 4})); + BOOST_TEST(CompareIClTensorHandleShape(inputHandle, inputShape)); + BOOST_TEST(CompareIClTensorHandleShape(outputHandle, outputShape)); +} + +BOOST_AUTO_TEST_CASE(CreatePooling2dFloatNchwWorkload) +{ + ClPooling2dWorkloadTest<ClPooling2dFloatWorkload, armnn::DataType::Float32>(DataLayout::NCHW); +} + +BOOST_AUTO_TEST_CASE(CreatePooling2dFloatNhwcWorkload) +{ + ClPooling2dWorkloadTest<ClPooling2dFloatWorkload, armnn::DataType::Float32>(DataLayout::NHWC); } -BOOST_AUTO_TEST_CASE(CreatePooling2dFloatWorkload) +BOOST_AUTO_TEST_CASE(CreatePooling2dFloat16NchwWorkload) { - ClPooling2dWorkloadTest<ClPooling2dFloatWorkload, armnn::DataType::Float32>(); + ClPooling2dWorkloadTest<ClPooling2dFloatWorkload, armnn::DataType::Float16>(DataLayout::NCHW); } -BOOST_AUTO_TEST_CASE(CreatePooling2dFloat16Workload) +BOOST_AUTO_TEST_CASE(CreatePooling2dFloat16NhwcWorkload) { - ClPooling2dWorkloadTest<ClPooling2dFloatWorkload, armnn::DataType::Float16>(); + ClPooling2dWorkloadTest<ClPooling2dFloatWorkload, armnn::DataType::Float16>(DataLayout::NHWC); } template <typename ReshapeWorkloadType, typename armnn::DataType DataType> |