diff options
Diffstat (limited to 'src/backends/backendsCommon/test/BatchNormTestImpl.hpp')
-rw-r--r-- | src/backends/backendsCommon/test/BatchNormTestImpl.hpp | 27 |
1 files changed, 17 insertions, 10 deletions
diff --git a/src/backends/backendsCommon/test/BatchNormTestImpl.hpp b/src/backends/backendsCommon/test/BatchNormTestImpl.hpp index 67282ed819..6325130218 100644 --- a/src/backends/backendsCommon/test/BatchNormTestImpl.hpp +++ b/src/backends/backendsCommon/test/BatchNormTestImpl.hpp @@ -4,23 +4,28 @@ // #pragma once +#include "WorkloadTestUtils.hpp" + #include <armnn/ArmNN.hpp> #include <armnn/Tensor.hpp> #include <backendsCommon/CpuTensorHandle.hpp> +#include <backendsCommon/IBackendInternal.hpp> #include <backendsCommon/WorkloadFactory.hpp> #include <backendsCommon/test/QuantizeHelper.hpp> #include <test/TensorHelpers.hpp> template<typename T> -LayerTestResult<T, 4> BatchNormTestImpl(armnn::IWorkloadFactory& workloadFactory, - const armnn::TensorShape& inputOutputTensorShape, - const std::vector<float>& inputValues, - const std::vector<float>& expectedOutputValues, - float qScale, - int32_t qOffset, - armnn::DataLayout dataLayout) +LayerTestResult<T, 4> BatchNormTestImpl( + armnn::IWorkloadFactory& workloadFactory, + const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager, + const armnn::TensorShape& inputOutputTensorShape, + const std::vector<float>& inputValues, + const std::vector<float>& expectedOutputValues, + float qScale, + int32_t qOffset, + armnn::DataLayout dataLayout) { armnn::TensorInfo inputTensorInfo(inputOutputTensorShape, armnn::GetDataType<T>()); armnn::TensorInfo outputTensorInfo(inputOutputTensorShape, armnn::GetDataType<T>()); @@ -96,9 +101,11 @@ LayerTestResult<T, 4> BatchNormTestImpl(armnn::IWorkloadFactory& workloadFactory template<typename T> -LayerTestResult<T,4> BatchNormTestNhwcImpl(armnn::IWorkloadFactory& workloadFactory, - float qScale, - int32_t qOffset) +LayerTestResult<T,4> BatchNormTestNhwcImpl( + armnn::IWorkloadFactory& workloadFactory, + const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager, + float qScale, + int32_t qOffset) { const unsigned int width = 2; const unsigned int height = 3; |