diff options
Diffstat (limited to 'src/backends/backendsCommon/test/BatchNormTestImpl.hpp')
-rw-r--r-- | src/backends/backendsCommon/test/BatchNormTestImpl.hpp | 17 |
1 file changed, 9 insertions, 8 deletions
diff --git a/src/backends/backendsCommon/test/BatchNormTestImpl.hpp b/src/backends/backendsCommon/test/BatchNormTestImpl.hpp index d63f0b5610..ded4a067b4 100644 --- a/src/backends/backendsCommon/test/BatchNormTestImpl.hpp +++ b/src/backends/backendsCommon/test/BatchNormTestImpl.hpp @@ -4,6 +4,7 @@ // #pragma once +#include "TypeUtils.hpp" #include "WorkloadTestUtils.hpp" #include <armnn/ArmNN.hpp> @@ -18,7 +19,7 @@ #include <DataLayoutIndexed.hpp> -template<typename T> +template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>> LayerTestResult<T, 4> BatchNormTestImpl( armnn::IWorkloadFactory& workloadFactory, const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager, @@ -29,13 +30,13 @@ LayerTestResult<T, 4> BatchNormTestImpl( int32_t qOffset, armnn::DataLayout dataLayout) { - armnn::TensorInfo inputTensorInfo(inputOutputTensorShape, armnn::GetDataType<T>()); - armnn::TensorInfo outputTensorInfo(inputOutputTensorShape, armnn::GetDataType<T>()); + armnn::TensorInfo inputTensorInfo(inputOutputTensorShape, ArmnnType); + armnn::TensorInfo outputTensorInfo(inputOutputTensorShape, ArmnnType); armnnUtils::DataLayoutIndexed dataLayoutIndexed(dataLayout); armnn::TensorInfo tensorInfo({ inputOutputTensorShape[dataLayoutIndexed.GetChannelsIndex()] }, - armnn::GetDataType<T>()); + ArmnnType); // Set quantization parameters if the requested type is a quantized type. 
if (armnn::IsQuantizedType<T>()) @@ -102,7 +103,7 @@ LayerTestResult<T, 4> BatchNormTestImpl( } -template<typename T> +template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>> LayerTestResult<T,4> BatchNormTestNhwcImpl( armnn::IWorkloadFactory& workloadFactory, const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager, @@ -114,9 +115,9 @@ LayerTestResult<T,4> BatchNormTestNhwcImpl( const unsigned int channels = 2; const unsigned int num = 1; - armnn::TensorInfo inputTensorInfo({num, height, width, channels}, armnn::GetDataType<T>()); - armnn::TensorInfo outputTensorInfo({num, height, width, channels}, armnn::GetDataType<T>()); - armnn::TensorInfo tensorInfo({channels}, armnn::GetDataType<T>()); + armnn::TensorInfo inputTensorInfo({num, height, width, channels}, ArmnnType); + armnn::TensorInfo outputTensorInfo({num, height, width, channels}, ArmnnType); + armnn::TensorInfo tensorInfo({channels}, ArmnnType); // Set quantization parameters if the requested type is a quantized type. if(armnn::IsQuantizedType<T>()) |