aboutsummaryrefslogtreecommitdiff
path: root/src/armnn/backends/test/CreateWorkloadCl.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/armnn/backends/test/CreateWorkloadCl.cpp')
-rw-r--r--src/armnn/backends/test/CreateWorkloadCl.cpp101
1 files changed, 64 insertions, 37 deletions
diff --git a/src/armnn/backends/test/CreateWorkloadCl.cpp b/src/armnn/backends/test/CreateWorkloadCl.cpp
index 96001a4b78..340279e619 100644
--- a/src/armnn/backends/test/CreateWorkloadCl.cpp
+++ b/src/armnn/backends/test/CreateWorkloadCl.cpp
@@ -47,15 +47,18 @@ BOOST_AUTO_TEST_CASE(CreateActivationFloat16Workload)
ClCreateActivationWorkloadTest<ClActivationFloatWorkload, armnn::DataType::Float16>();
}
-template <typename AdditionWorkloadType, armnn::DataType DataType>
-static void ClCreateAdditionWorkloadTest()
+template <typename WorkloadType,
+ typename DescriptorType,
+ typename LayerType,
+ armnn::DataType DataType>
+static void ClCreateArithmethicWorkloadTest()
{
Graph graph;
ClWorkloadFactory factory;
- auto workload = CreateAdditionWorkloadTest<AdditionWorkloadType, DataType>(factory, graph);
+ auto workload = CreateArithmeticWorkloadTest<WorkloadType, DescriptorType, LayerType, DataType>(factory, graph);
- // Checks that inputs/outputs are as we expect them (see definition of CreateAdditionWorkloadTest).
- AdditionQueueDescriptor queueDescriptor = workload->GetData();
+ // Checks that inputs/outputs are as we expect them (see definition of CreateSubtractionWorkloadTest).
+ DescriptorType queueDescriptor = workload->GetData();
auto inputHandle1 = boost::polymorphic_downcast<IClTensorHandle*>(queueDescriptor.m_Inputs[0]);
auto inputHandle2 = boost::polymorphic_downcast<IClTensorHandle*>(queueDescriptor.m_Inputs[1]);
auto outputHandle = boost::polymorphic_downcast<IClTensorHandle*>(queueDescriptor.m_Outputs[0]);
@@ -66,12 +69,66 @@ static void ClCreateAdditionWorkloadTest()
BOOST_AUTO_TEST_CASE(CreateAdditionFloatWorkload)
{
- ClCreateAdditionWorkloadTest<ClAdditionFloatWorkload, armnn::DataType::Float32>();
+ ClCreateArithmethicWorkloadTest<ClAdditionFloatWorkload,
+ AdditionQueueDescriptor,
+ AdditionLayer,
+ armnn::DataType::Float32>();
}
BOOST_AUTO_TEST_CASE(CreateAdditionFloat16Workload)
{
- ClCreateAdditionWorkloadTest<ClAdditionFloatWorkload, armnn::DataType::Float16>();
+ ClCreateArithmethicWorkloadTest<ClAdditionFloatWorkload,
+ AdditionQueueDescriptor,
+ AdditionLayer,
+ armnn::DataType::Float16>();
+}
+
+BOOST_AUTO_TEST_CASE(CreateSubtractionFloatWorkload)
+{
+ ClCreateArithmethicWorkloadTest<ClSubtractionFloatWorkload,
+ SubtractionQueueDescriptor,
+ SubtractionLayer,
+ armnn::DataType::Float32>();
+}
+
+BOOST_AUTO_TEST_CASE(CreateSubtractionFloat16Workload)
+{
+ ClCreateArithmethicWorkloadTest<ClSubtractionFloatWorkload,
+ SubtractionQueueDescriptor,
+ SubtractionLayer,
+ armnn::DataType::Float16>();
+}
+
+BOOST_AUTO_TEST_CASE(CreateMultiplicationFloatWorkloadTest)
+{
+ ClCreateArithmethicWorkloadTest<ClMultiplicationFloatWorkload,
+ MultiplicationQueueDescriptor,
+ MultiplicationLayer,
+ armnn::DataType::Float32>();
+}
+
+BOOST_AUTO_TEST_CASE(CreateMultiplicationFloat16WorkloadTest)
+{
+ ClCreateArithmethicWorkloadTest<ClMultiplicationFloatWorkload,
+ MultiplicationQueueDescriptor,
+ MultiplicationLayer,
+ armnn::DataType::Float16>();
+}
+
+BOOST_AUTO_TEST_CASE(CreateDivisionFloatWorkloadTest)
+{
+ ClCreateArithmethicWorkloadTest<ClDivisionFloatWorkload,
+ DivisionQueueDescriptor,
+ DivisionLayer,
+ armnn::DataType::Float32>();
+}
+
+BOOST_AUTO_TEST_CASE(CreateDivisionFloat16WorkloadTest)
+{
+ ClCreateArithmethicWorkloadTest<ClDivisionFloatWorkload,
+ DivisionQueueDescriptor,
+ DivisionLayer,
+ armnn::DataType::Float16>();
}
template <typename BatchNormalizationWorkloadType, armnn::DataType DataType>
@@ -219,36 +276,6 @@ BOOST_AUTO_TEST_CASE(CreateFullyConnectedFloat16WorkloadTest)
ClCreateFullyConnectedWorkloadTest<ClFullyConnectedFloatWorkload, armnn::DataType::Float16>();
}
-
-template <typename MultiplicationWorkloadType, typename armnn::DataType DataType>
-static void ClCreateMultiplicationWorkloadTest()
-{
- Graph graph;
- ClWorkloadFactory factory;
-
- auto workload =
- CreateMultiplicationWorkloadTest<MultiplicationWorkloadType, DataType>(factory, graph);
-
- // Checks that inputs/outputs are as we expect them (see definition of CreateMultiplicationWorkloadTest).
- MultiplicationQueueDescriptor queueDescriptor = workload->GetData();
- auto inputHandle1 = boost::polymorphic_downcast<IClTensorHandle*>(queueDescriptor.m_Inputs[0]);
- auto inputHandle2 = boost::polymorphic_downcast<IClTensorHandle*>(queueDescriptor.m_Inputs[1]);
- auto outputHandle = boost::polymorphic_downcast<IClTensorHandle*>(queueDescriptor.m_Outputs[0]);
- BOOST_TEST(CompareIClTensorHandleShape(inputHandle1, {2, 3}));
- BOOST_TEST(CompareIClTensorHandleShape(inputHandle2, {2, 3}));
- BOOST_TEST(CompareIClTensorHandleShape(outputHandle, {2, 3}));
-}
-
-BOOST_AUTO_TEST_CASE(CreateMultiplicationFloatWorkloadTest)
-{
- ClCreateMultiplicationWorkloadTest<ClMultiplicationFloatWorkload, armnn::DataType::Float32>();
-}
-
-BOOST_AUTO_TEST_CASE(CreateMultiplicationFloat16WorkloadTest)
-{
- ClCreateMultiplicationWorkloadTest<ClMultiplicationFloatWorkload, armnn::DataType::Float16>();
-}
-
template <typename NormalizationWorkloadType, typename armnn::DataType DataType>
static void ClNormalizationWorkloadTest()
{