aboutsummaryrefslogtreecommitdiff
path: root/src/armnn/backends/test/CreateWorkloadRef.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/armnn/backends/test/CreateWorkloadRef.cpp')
-rw-r--r--src/armnn/backends/test/CreateWorkloadRef.cpp92
1 files changed, 62 insertions, 30 deletions
diff --git a/src/armnn/backends/test/CreateWorkloadRef.cpp b/src/armnn/backends/test/CreateWorkloadRef.cpp
index 46ee3225a0..41419dafd0 100644
--- a/src/armnn/backends/test/CreateWorkloadRef.cpp
+++ b/src/armnn/backends/test/CreateWorkloadRef.cpp
@@ -62,14 +62,16 @@ BOOST_AUTO_TEST_CASE(CreateActivationUint8Workload)
RefCreateActivationWorkloadTest<RefActivationUint8Workload, armnn::DataType::QuantisedAsymm8>();
}
-template <typename AdditionWorkloadType, armnn::DataType DataType>
-static void RefCreateAdditionWorkloadTest()
+template <typename WorkloadType,
+ typename DescriptorType,
+ typename LayerType,
+ armnn::DataType DataType>
+static void RefCreateArithmethicWorkloadTest()
{
Graph graph;
RefWorkloadFactory factory;
- auto workload = CreateAdditionWorkloadTest<AdditionWorkloadType, DataType>(factory, graph);
+ auto workload = CreateArithmeticWorkloadTest<WorkloadType, DescriptorType, LayerType, DataType>(factory, graph);
- // Checks that outputs are as we expect them (see definition of CreateAdditionWorkloadTest).
CheckInputsOutput(std::move(workload),
TensorInfo({ 2, 3 }, DataType),
TensorInfo({ 2, 3 }, DataType),
@@ -78,12 +80,66 @@ static void RefCreateAdditionWorkloadTest()
BOOST_AUTO_TEST_CASE(CreateAdditionFloatWorkload)
{
- RefCreateAdditionWorkloadTest<RefAdditionFloat32Workload, armnn::DataType::Float32>();
+ RefCreateArithmethicWorkloadTest<RefAdditionFloat32Workload,
+ AdditionQueueDescriptor,
+ AdditionLayer,
+ armnn::DataType::Float32>();
}
BOOST_AUTO_TEST_CASE(CreateAdditionUint8Workload)
{
- RefCreateAdditionWorkloadTest<RefAdditionUint8Workload, armnn::DataType::QuantisedAsymm8>();
+ RefCreateArithmethicWorkloadTest<RefAdditionUint8Workload,
+ AdditionQueueDescriptor,
+ AdditionLayer,
+ armnn::DataType::QuantisedAsymm8>();
+}
+
+BOOST_AUTO_TEST_CASE(CreateSubtractionFloatWorkload)
+{
+ RefCreateArithmethicWorkloadTest<RefSubtractionFloat32Workload,
+ SubtractionQueueDescriptor,
+ SubtractionLayer,
+ armnn::DataType::Float32>();
+}
+
+BOOST_AUTO_TEST_CASE(CreateSubtractionUint8Workload)
+{
+ RefCreateArithmethicWorkloadTest<RefSubtractionUint8Workload,
+ SubtractionQueueDescriptor,
+ SubtractionLayer,
+ armnn::DataType::QuantisedAsymm8>();
+}
+
+BOOST_AUTO_TEST_CASE(CreateMultiplicationFloatWorkload)
+{
+ RefCreateArithmethicWorkloadTest<RefMultiplicationFloat32Workload,
+ MultiplicationQueueDescriptor,
+ MultiplicationLayer,
+ armnn::DataType::Float32>();
+}
+
+BOOST_AUTO_TEST_CASE(CreateMultiplicationUint8Workload)
+{
+ RefCreateArithmethicWorkloadTest<RefMultiplicationUint8Workload,
+ MultiplicationQueueDescriptor,
+ MultiplicationLayer,
+ armnn::DataType::QuantisedAsymm8>();
+}
+
+BOOST_AUTO_TEST_CASE(CreateDivisionFloatWorkload)
+{
+ RefCreateArithmethicWorkloadTest<RefDivisionFloat32Workload,
+ DivisionQueueDescriptor,
+ DivisionLayer,
+ armnn::DataType::Float32>();
+}
+
+BOOST_AUTO_TEST_CASE(CreateDivisionUint8Workload)
+{
+ RefCreateArithmethicWorkloadTest<RefDivisionUint8Workload,
+ DivisionQueueDescriptor,
+ DivisionLayer,
+ armnn::DataType::QuantisedAsymm8>();
}
BOOST_AUTO_TEST_CASE(CreateBatchNormalizationWorkload)
@@ -171,30 +227,6 @@ BOOST_AUTO_TEST_CASE(CreateFullyConnectedUint8Workload)
RefCreateFullyConnectedWorkloadTest<RefFullyConnectedUint8Workload, armnn::DataType::QuantisedAsymm8>();
}
-template <typename MultiplicationWorkloadType, armnn::DataType DataType>
-static void RefCreateMultiplicationWorkloadTest()
-{
- Graph graph;
- RefWorkloadFactory factory;
- auto workload = CreateMultiplicationWorkloadTest<MultiplicationWorkloadType, DataType>(factory, graph);
-
- // Checks that outputs are as we expect them (see definition of CreateMultiplicationWorkloadTest).
- CheckInputsOutput(std::move(workload),
- TensorInfo({ 2, 3 }, DataType),
- TensorInfo({ 2, 3 }, DataType),
- TensorInfo({ 2, 3 }, DataType));
-}
-
-BOOST_AUTO_TEST_CASE(CreateMultiplicationFloatWorkload)
-{
- RefCreateMultiplicationWorkloadTest<RefMultiplicationFloat32Workload, armnn::DataType::Float32>();
-}
-
-BOOST_AUTO_TEST_CASE(CreateMultiplicationUint8Workload)
-{
- RefCreateMultiplicationWorkloadTest<RefMultiplicationUint8Workload, armnn::DataType::QuantisedAsymm8>();
-}
-
BOOST_AUTO_TEST_CASE(CreateNormalizationWorkload)
{
Graph graph;