aboutsummaryrefslogtreecommitdiff
path: root/src/backends/backendsCommon/test/BatchNormTestImpl.hpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/backends/backendsCommon/test/BatchNormTestImpl.hpp')
-rw-r--r--src/backends/backendsCommon/test/BatchNormTestImpl.hpp17
1 files changed, 9 insertions, 8 deletions
diff --git a/src/backends/backendsCommon/test/BatchNormTestImpl.hpp b/src/backends/backendsCommon/test/BatchNormTestImpl.hpp
index d63f0b5610..ded4a067b4 100644
--- a/src/backends/backendsCommon/test/BatchNormTestImpl.hpp
+++ b/src/backends/backendsCommon/test/BatchNormTestImpl.hpp
@@ -4,6 +4,7 @@
//
#pragma once
+#include "TypeUtils.hpp"
#include "WorkloadTestUtils.hpp"
#include <armnn/ArmNN.hpp>
@@ -18,7 +19,7 @@
#include <DataLayoutIndexed.hpp>
-template<typename T>
+template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>>
LayerTestResult<T, 4> BatchNormTestImpl(
armnn::IWorkloadFactory& workloadFactory,
const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
@@ -29,13 +30,13 @@ LayerTestResult<T, 4> BatchNormTestImpl(
int32_t qOffset,
armnn::DataLayout dataLayout)
{
- armnn::TensorInfo inputTensorInfo(inputOutputTensorShape, armnn::GetDataType<T>());
- armnn::TensorInfo outputTensorInfo(inputOutputTensorShape, armnn::GetDataType<T>());
+ armnn::TensorInfo inputTensorInfo(inputOutputTensorShape, ArmnnType);
+ armnn::TensorInfo outputTensorInfo(inputOutputTensorShape, ArmnnType);
armnnUtils::DataLayoutIndexed dataLayoutIndexed(dataLayout);
armnn::TensorInfo tensorInfo({ inputOutputTensorShape[dataLayoutIndexed.GetChannelsIndex()] },
- armnn::GetDataType<T>());
+ ArmnnType);
// Set quantization parameters if the requested type is a quantized type.
if (armnn::IsQuantizedType<T>())
@@ -102,7 +103,7 @@ LayerTestResult<T, 4> BatchNormTestImpl(
}
-template<typename T>
+template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>>
LayerTestResult<T,4> BatchNormTestNhwcImpl(
armnn::IWorkloadFactory& workloadFactory,
const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
@@ -114,9 +115,9 @@ LayerTestResult<T,4> BatchNormTestNhwcImpl(
const unsigned int channels = 2;
const unsigned int num = 1;
- armnn::TensorInfo inputTensorInfo({num, height, width, channels}, armnn::GetDataType<T>());
- armnn::TensorInfo outputTensorInfo({num, height, width, channels}, armnn::GetDataType<T>());
- armnn::TensorInfo tensorInfo({channels}, armnn::GetDataType<T>());
+ armnn::TensorInfo inputTensorInfo({num, height, width, channels}, ArmnnType);
+ armnn::TensorInfo outputTensorInfo({num, height, width, channels}, ArmnnType);
+ armnn::TensorInfo tensorInfo({channels}, ArmnnType);
// Set quantization parameters if the requested type is a quantized type.
if(armnn::IsQuantizedType<T>())