aboutsummaryrefslogtreecommitdiff
path: root/src/backends/backendsCommon/test/BatchNormTestImpl.hpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/backends/backendsCommon/test/BatchNormTestImpl.hpp')
-rw-r--r--src/backends/backendsCommon/test/BatchNormTestImpl.hpp27
1 files changed, 17 insertions, 10 deletions
diff --git a/src/backends/backendsCommon/test/BatchNormTestImpl.hpp b/src/backends/backendsCommon/test/BatchNormTestImpl.hpp
index 67282ed819..6325130218 100644
--- a/src/backends/backendsCommon/test/BatchNormTestImpl.hpp
+++ b/src/backends/backendsCommon/test/BatchNormTestImpl.hpp
@@ -4,23 +4,28 @@
//
#pragma once
+#include "WorkloadTestUtils.hpp"
+
#include <armnn/ArmNN.hpp>
#include <armnn/Tensor.hpp>
#include <backendsCommon/CpuTensorHandle.hpp>
+#include <backendsCommon/IBackendInternal.hpp>
#include <backendsCommon/WorkloadFactory.hpp>
#include <backendsCommon/test/QuantizeHelper.hpp>
#include <test/TensorHelpers.hpp>
template<typename T>
-LayerTestResult<T, 4> BatchNormTestImpl(armnn::IWorkloadFactory& workloadFactory,
- const armnn::TensorShape& inputOutputTensorShape,
- const std::vector<float>& inputValues,
- const std::vector<float>& expectedOutputValues,
- float qScale,
- int32_t qOffset,
- armnn::DataLayout dataLayout)
+LayerTestResult<T, 4> BatchNormTestImpl(
+ armnn::IWorkloadFactory& workloadFactory,
+ const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+ const armnn::TensorShape& inputOutputTensorShape,
+ const std::vector<float>& inputValues,
+ const std::vector<float>& expectedOutputValues,
+ float qScale,
+ int32_t qOffset,
+ armnn::DataLayout dataLayout)
{
armnn::TensorInfo inputTensorInfo(inputOutputTensorShape, armnn::GetDataType<T>());
armnn::TensorInfo outputTensorInfo(inputOutputTensorShape, armnn::GetDataType<T>());
@@ -96,9 +101,11 @@ LayerTestResult<T, 4> BatchNormTestImpl(armnn::IWorkloadFactory& workloadFactory
template<typename T>
-LayerTestResult<T,4> BatchNormTestNhwcImpl(armnn::IWorkloadFactory& workloadFactory,
- float qScale,
- int32_t qOffset)
+LayerTestResult<T,4> BatchNormTestNhwcImpl(
+ armnn::IWorkloadFactory& workloadFactory,
+ const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+ float qScale,
+ int32_t qOffset)
{
const unsigned int width = 2;
const unsigned int height = 3;