diff options
Diffstat (limited to 'src/armnn/backends/test/CreateWorkloadRef.cpp')
-rw-r--r-- | src/armnn/backends/test/CreateWorkloadRef.cpp | 92 |
1 files changed, 62 insertions, 30 deletions
diff --git a/src/armnn/backends/test/CreateWorkloadRef.cpp b/src/armnn/backends/test/CreateWorkloadRef.cpp index 46ee3225a0..41419dafd0 100644 --- a/src/armnn/backends/test/CreateWorkloadRef.cpp +++ b/src/armnn/backends/test/CreateWorkloadRef.cpp @@ -62,14 +62,16 @@ BOOST_AUTO_TEST_CASE(CreateActivationUint8Workload) RefCreateActivationWorkloadTest<RefActivationUint8Workload, armnn::DataType::QuantisedAsymm8>(); } -template <typename AdditionWorkloadType, armnn::DataType DataType> -static void RefCreateAdditionWorkloadTest() +template <typename WorkloadType, + typename DescriptorType, + typename LayerType, + armnn::DataType DataType> +static void RefCreateArithmethicWorkloadTest() { Graph graph; RefWorkloadFactory factory; - auto workload = CreateAdditionWorkloadTest<AdditionWorkloadType, DataType>(factory, graph); + auto workload = CreateArithmeticWorkloadTest<WorkloadType, DescriptorType, LayerType, DataType>(factory, graph); - // Checks that outputs are as we expect them (see definition of CreateAdditionWorkloadTest). CheckInputsOutput(std::move(workload), TensorInfo({ 2, 3 }, DataType), TensorInfo({ 2, 3 }, DataType), @@ -78,12 +80,66 @@ static void RefCreateAdditionWorkloadTest() BOOST_AUTO_TEST_CASE(CreateAdditionFloatWorkload) { - RefCreateAdditionWorkloadTest<RefAdditionFloat32Workload, armnn::DataType::Float32>(); + RefCreateArithmethicWorkloadTest<RefAdditionFloat32Workload, + AdditionQueueDescriptor, + AdditionLayer, + armnn::DataType::Float32>(); } BOOST_AUTO_TEST_CASE(CreateAdditionUint8Workload) { - RefCreateAdditionWorkloadTest<RefAdditionUint8Workload, armnn::DataType::QuantisedAsymm8>(); + RefCreateArithmethicWorkloadTest<RefAdditionUint8Workload, + AdditionQueueDescriptor, + AdditionLayer, + armnn::DataType::QuantisedAsymm8>(); +} + +BOOST_AUTO_TEST_CASE(CreateSubtractionFloatWorkload) +{ + RefCreateArithmethicWorkloadTest<RefSubtractionFloat32Workload, + SubtractionQueueDescriptor, + SubtractionLayer, + armnn::DataType::Float32>(); +} + +BOOST_AUTO_TEST_CASE(CreateSubtractionUint8Workload) +{ + RefCreateArithmethicWorkloadTest<RefSubtractionUint8Workload, + SubtractionQueueDescriptor, + SubtractionLayer, + armnn::DataType::QuantisedAsymm8>(); +} + +BOOST_AUTO_TEST_CASE(CreateMultiplicationFloatWorkload) +{ + RefCreateArithmethicWorkloadTest<RefMultiplicationFloat32Workload, + MultiplicationQueueDescriptor, + MultiplicationLayer, + armnn::DataType::Float32>(); +} + +BOOST_AUTO_TEST_CASE(CreateMultiplicationUint8Workload) +{ + RefCreateArithmethicWorkloadTest<RefMultiplicationUint8Workload, + MultiplicationQueueDescriptor, + MultiplicationLayer, + armnn::DataType::QuantisedAsymm8>(); +} + +BOOST_AUTO_TEST_CASE(CreateDivisionFloatWorkload) +{ + RefCreateArithmethicWorkloadTest<RefDivisionFloat32Workload, + DivisionQueueDescriptor, + DivisionLayer, + armnn::DataType::Float32>(); +} + +BOOST_AUTO_TEST_CASE(CreateDivisionUint8Workload) +{ + RefCreateArithmethicWorkloadTest<RefDivisionUint8Workload, + DivisionQueueDescriptor, + DivisionLayer, + armnn::DataType::QuantisedAsymm8>(); } BOOST_AUTO_TEST_CASE(CreateBatchNormalizationWorkload) @@ -171,30 +227,6 @@ BOOST_AUTO_TEST_CASE(CreateFullyConnectedUint8Workload) RefCreateFullyConnectedWorkloadTest<RefFullyConnectedUint8Workload, armnn::DataType::QuantisedAsymm8>(); } -template <typename MultiplicationWorkloadType, armnn::DataType DataType> -static void RefCreateMultiplicationWorkloadTest() -{ - Graph graph; - RefWorkloadFactory factory; - auto workload = CreateMultiplicationWorkloadTest<MultiplicationWorkloadType, DataType>(factory, graph); - - // Checks that outputs are as we expect them (see definition of CreateMultiplicationWorkloadTest). - CheckInputsOutput(std::move(workload), - TensorInfo({ 2, 3 }, DataType), - TensorInfo({ 2, 3 }, DataType), - TensorInfo({ 2, 3 }, DataType)); -} - -BOOST_AUTO_TEST_CASE(CreateMultiplicationFloatWorkload) -{ - RefCreateMultiplicationWorkloadTest<RefMultiplicationFloat32Workload, armnn::DataType::Float32>(); -} - -BOOST_AUTO_TEST_CASE(CreateMultiplicationUint8Workload) -{ - RefCreateMultiplicationWorkloadTest<RefMultiplicationUint8Workload, armnn::DataType::QuantisedAsymm8>(); -} - BOOST_AUTO_TEST_CASE(CreateNormalizationWorkload) { Graph graph; |