diff options
Diffstat (limited to 'src/armnn/backends/test/CreateWorkloadCl.cpp')
-rw-r--r-- | src/armnn/backends/test/CreateWorkloadCl.cpp | 101 |
1 files changed, 64 insertions, 37 deletions
diff --git a/src/armnn/backends/test/CreateWorkloadCl.cpp b/src/armnn/backends/test/CreateWorkloadCl.cpp index 96001a4b78..340279e619 100644 --- a/src/armnn/backends/test/CreateWorkloadCl.cpp +++ b/src/armnn/backends/test/CreateWorkloadCl.cpp @@ -47,15 +47,18 @@ BOOST_AUTO_TEST_CASE(CreateActivationFloat16Workload) ClCreateActivationWorkloadTest<ClActivationFloatWorkload, armnn::DataType::Float16>(); } -template <typename AdditionWorkloadType, armnn::DataType DataType> -static void ClCreateAdditionWorkloadTest() +template <typename WorkloadType, + typename DescriptorType, + typename LayerType, + armnn::DataType DataType> +static void ClCreateArithmethicWorkloadTest() { Graph graph; ClWorkloadFactory factory; - auto workload = CreateAdditionWorkloadTest<AdditionWorkloadType, DataType>(factory, graph); + auto workload = CreateArithmeticWorkloadTest<WorkloadType, DescriptorType, LayerType, DataType>(factory, graph); - // Checks that inputs/outputs are as we expect them (see definition of CreateAdditionWorkloadTest). - AdditionQueueDescriptor queueDescriptor = workload->GetData(); + // Checks that inputs/outputs are as we expect them (see definition of CreateSubtractionWorkloadTest). + DescriptorType queueDescriptor = workload->GetData(); auto inputHandle1 = boost::polymorphic_downcast<IClTensorHandle*>(queueDescriptor.m_Inputs[0]); auto inputHandle2 = boost::polymorphic_downcast<IClTensorHandle*>(queueDescriptor.m_Inputs[1]); auto outputHandle = boost::polymorphic_downcast<IClTensorHandle*>(queueDescriptor.m_Outputs[0]); @@ -66,12 +69,66 @@ static void ClCreateAdditionWorkloadTest() BOOST_AUTO_TEST_CASE(CreateAdditionFloatWorkload) { - ClCreateAdditionWorkloadTest<ClAdditionFloatWorkload, armnn::DataType::Float32>(); + ClCreateArithmethicWorkloadTest<ClAdditionFloatWorkload, + AdditionQueueDescriptor, + AdditionLayer, + armnn::DataType::Float32>(); } BOOST_AUTO_TEST_CASE(CreateAdditionFloat16Workload) { - ClCreateAdditionWorkloadTest<ClAdditionFloatWorkload, armnn::DataType::Float16>(); + ClCreateArithmethicWorkloadTest<ClAdditionFloatWorkload, + AdditionQueueDescriptor, + AdditionLayer, + armnn::DataType::Float16>(); +} + +BOOST_AUTO_TEST_CASE(CreateSubtractionFloatWorkload) +{ + ClCreateArithmethicWorkloadTest<ClSubtractionFloatWorkload, + SubtractionQueueDescriptor, + SubtractionLayer, + armnn::DataType::Float32>(); +} + +BOOST_AUTO_TEST_CASE(CreateSubtractionFloat16Workload) +{ + ClCreateArithmethicWorkloadTest<ClSubtractionFloatWorkload, + SubtractionQueueDescriptor, + SubtractionLayer, + armnn::DataType::Float16>(); +} + +BOOST_AUTO_TEST_CASE(CreateMultiplicationFloatWorkloadTest) +{ + ClCreateArithmethicWorkloadTest<ClMultiplicationFloatWorkload, + MultiplicationQueueDescriptor, + MultiplicationLayer, + armnn::DataType::Float32>(); +} + +BOOST_AUTO_TEST_CASE(CreateMultiplicationFloat16WorkloadTest) +{ + ClCreateArithmethicWorkloadTest<ClMultiplicationFloatWorkload, + MultiplicationQueueDescriptor, + MultiplicationLayer, + armnn::DataType::Float16>(); +} + +BOOST_AUTO_TEST_CASE(CreateDivisionFloatWorkloadTest) +{ + ClCreateArithmethicWorkloadTest<ClDivisionFloatWorkload, + DivisionQueueDescriptor, + DivisionLayer, + armnn::DataType::Float32>(); +} + +BOOST_AUTO_TEST_CASE(CreateDivisionFloat16WorkloadTest) +{ + ClCreateArithmethicWorkloadTest<ClDivisionFloatWorkload, + DivisionQueueDescriptor, + DivisionLayer, + armnn::DataType::Float16>(); } template <typename BatchNormalizationWorkloadType, armnn::DataType DataType> @@ -219,36 +276,6 @@ BOOST_AUTO_TEST_CASE(CreateFullyConnectedFloat16WorkloadTest) ClCreateFullyConnectedWorkloadTest<ClFullyConnectedFloatWorkload, armnn::DataType::Float16>(); } - -template <typename MultiplicationWorkloadType, typename armnn::DataType DataType> -static void ClCreateMultiplicationWorkloadTest() -{ - Graph graph; - ClWorkloadFactory factory; - - auto workload = - CreateMultiplicationWorkloadTest<MultiplicationWorkloadType, DataType>(factory, graph); - - // Checks that inputs/outputs are as we expect them (see definition of CreateMultiplicationWorkloadTest). - MultiplicationQueueDescriptor queueDescriptor = workload->GetData(); - auto inputHandle1 = boost::polymorphic_downcast<IClTensorHandle*>(queueDescriptor.m_Inputs[0]); - auto inputHandle2 = boost::polymorphic_downcast<IClTensorHandle*>(queueDescriptor.m_Inputs[1]); - auto outputHandle = boost::polymorphic_downcast<IClTensorHandle*>(queueDescriptor.m_Outputs[0]); - BOOST_TEST(CompareIClTensorHandleShape(inputHandle1, {2, 3})); - BOOST_TEST(CompareIClTensorHandleShape(inputHandle2, {2, 3})); - BOOST_TEST(CompareIClTensorHandleShape(outputHandle, {2, 3})); -} - -BOOST_AUTO_TEST_CASE(CreateMultiplicationFloatWorkloadTest) -{ - ClCreateMultiplicationWorkloadTest<ClMultiplicationFloatWorkload, armnn::DataType::Float32>(); -} - -BOOST_AUTO_TEST_CASE(CreateMultiplicationFloat16WorkloadTest) -{ - ClCreateMultiplicationWorkloadTest<ClMultiplicationFloatWorkload, armnn::DataType::Float16>(); -} - template <typename NormalizationWorkloadType, typename armnn::DataType DataType> static void ClNormalizationWorkloadTest() { |