aboutsummaryrefslogtreecommitdiff
path: root/src/backends/backendsCommon/test/BatchNormTestImpl.hpp
diff options
context:
space:
mode:
authorNattapat Chaimanowong <nattapat.chaimanowong@arm.com>2019-01-22 16:10:44 +0000
committerNattapat Chaimanowong <nattapat.chaimanowong@arm.com>2019-01-22 16:10:44 +0000
commit649dd9515ddf4bd00a0bff64d51dfd835a6c7b39 (patch)
treec938bc8eb11dd24223c0cb00a57d4372a907b943 /src/backends/backendsCommon/test/BatchNormTestImpl.hpp
parent382e21ce95c04479a6900afca81a57949b369f1e (diff)
downloadarmnn-649dd9515ddf4bd00a0bff64d51dfd835a6c7b39.tar.gz
IVGCVSW-2467 Remove GetDataType<T> function
Change-Id: I7359617a307b9abb4c30b3d5f2364dc6d0f828f0
Diffstat (limited to 'src/backends/backendsCommon/test/BatchNormTestImpl.hpp')
-rw-r--r--src/backends/backendsCommon/test/BatchNormTestImpl.hpp17
1 files changed, 9 insertions, 8 deletions
diff --git a/src/backends/backendsCommon/test/BatchNormTestImpl.hpp b/src/backends/backendsCommon/test/BatchNormTestImpl.hpp
index d63f0b5610..ded4a067b4 100644
--- a/src/backends/backendsCommon/test/BatchNormTestImpl.hpp
+++ b/src/backends/backendsCommon/test/BatchNormTestImpl.hpp
@@ -4,6 +4,7 @@
//
#pragma once
+#include "TypeUtils.hpp"
#include "WorkloadTestUtils.hpp"
#include <armnn/ArmNN.hpp>
@@ -18,7 +19,7 @@
#include <DataLayoutIndexed.hpp>
-template<typename T>
+template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>>
LayerTestResult<T, 4> BatchNormTestImpl(
armnn::IWorkloadFactory& workloadFactory,
const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
@@ -29,13 +30,13 @@ LayerTestResult<T, 4> BatchNormTestImpl(
int32_t qOffset,
armnn::DataLayout dataLayout)
{
- armnn::TensorInfo inputTensorInfo(inputOutputTensorShape, armnn::GetDataType<T>());
- armnn::TensorInfo outputTensorInfo(inputOutputTensorShape, armnn::GetDataType<T>());
+ armnn::TensorInfo inputTensorInfo(inputOutputTensorShape, ArmnnType);
+ armnn::TensorInfo outputTensorInfo(inputOutputTensorShape, ArmnnType);
armnnUtils::DataLayoutIndexed dataLayoutIndexed(dataLayout);
armnn::TensorInfo tensorInfo({ inputOutputTensorShape[dataLayoutIndexed.GetChannelsIndex()] },
- armnn::GetDataType<T>());
+ ArmnnType);
// Set quantization parameters if the requested type is a quantized type.
if (armnn::IsQuantizedType<T>())
@@ -102,7 +103,7 @@ LayerTestResult<T, 4> BatchNormTestImpl(
}
-template<typename T>
+template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>>
LayerTestResult<T,4> BatchNormTestNhwcImpl(
armnn::IWorkloadFactory& workloadFactory,
const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
@@ -114,9 +115,9 @@ LayerTestResult<T,4> BatchNormTestNhwcImpl(
const unsigned int channels = 2;
const unsigned int num = 1;
- armnn::TensorInfo inputTensorInfo({num, height, width, channels}, armnn::GetDataType<T>());
- armnn::TensorInfo outputTensorInfo({num, height, width, channels}, armnn::GetDataType<T>());
- armnn::TensorInfo tensorInfo({channels}, armnn::GetDataType<T>());
+ armnn::TensorInfo inputTensorInfo({num, height, width, channels}, ArmnnType);
+ armnn::TensorInfo outputTensorInfo({num, height, width, channels}, ArmnnType);
+ armnn::TensorInfo tensorInfo({channels}, ArmnnType);
// Set quantization parameters if the requested type is a quantized type.
if(armnn::IsQuantizedType<T>())