diff options
Diffstat (limited to 'src/backends/test/CreateWorkloadRef.cpp')
-rw-r--r-- | src/backends/test/CreateWorkloadRef.cpp | 17 |
1 files changed, 11 insertions, 6 deletions
diff --git a/src/backends/test/CreateWorkloadRef.cpp b/src/backends/test/CreateWorkloadRef.cpp index 9313ee851f..c30093da92 100644 --- a/src/backends/test/CreateWorkloadRef.cpp +++ b/src/backends/test/CreateWorkloadRef.cpp @@ -227,17 +227,22 @@ BOOST_AUTO_TEST_CASE(CreateFullyConnectedUint8Workload) RefCreateFullyConnectedWorkloadTest<RefFullyConnectedUint8Workload, armnn::DataType::QuantisedAsymm8>(); } -BOOST_AUTO_TEST_CASE(CreateNormalizationWorkload) +template <typename NormalizationWorkloadType, armnn::DataType DataType> +static void RefCreateNormalizationWorkloadTest() { - Graph graph; + Graph graph; RefWorkloadFactory factory; - auto workload = CreateNormalizationWorkloadTest<RefNormalizationFloat32Workload, - armnn::DataType::Float32>(factory, graph); + auto workload = CreateNormalizationWorkloadTest<NormalizationWorkloadType, DataType>(factory, graph); // Checks that outputs and inputs are as we expect them (see definition of CreateNormalizationWorkloadTest). CheckInputOutput(std::move(workload), - TensorInfo({3, 5, 5, 1}, DataType::Float32), - TensorInfo({3, 5, 5, 1}, DataType::Float32)); + TensorInfo({3, 5, 5, 1}, DataType), + TensorInfo({3, 5, 5, 1}, DataType)); +} + +BOOST_AUTO_TEST_CASE(CreateRefNormalizationNchwWorkload) +{ + RefCreateNormalizationWorkloadTest<RefNormalizationFloat32Workload, armnn::DataType::Float32>(); } template <typename Pooling2dWorkloadType, armnn::DataType DataType> |