aboutsummaryrefslogtreecommitdiff
path: root/src/backends/backendsCommon/test/NormTestImpl.hpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/backends/backendsCommon/test/NormTestImpl.hpp')
-rw-r--r--src/backends/backendsCommon/test/NormTestImpl.hpp43
1 files changed, 23 insertions, 20 deletions
diff --git a/src/backends/backendsCommon/test/NormTestImpl.hpp b/src/backends/backendsCommon/test/NormTestImpl.hpp
index 16893eb315..38a0053a56 100644
--- a/src/backends/backendsCommon/test/NormTestImpl.hpp
+++ b/src/backends/backendsCommon/test/NormTestImpl.hpp
@@ -3,16 +3,21 @@
// SPDX-License-Identifier: MIT
//
+#include "WorkloadTestUtils.hpp"
+
#include <armnn/Exceptions.hpp>
#include <armnn/LayerSupport.hpp>
-#include "armnn/Types.hpp"
+#include <armnn/Types.hpp>
#include <backendsCommon/CpuTensorHandle.hpp>
+#include <backendsCommon/IBackendInternal.hpp>
#include <backendsCommon/WorkloadFactory.hpp>
-LayerTestResult<float,4> SimpleNormalizationTestImpl(armnn::IWorkloadFactory& workloadFactory,
- armnn::NormalizationAlgorithmChannel normChannel,
- armnn::NormalizationAlgorithmMethod normMethod)
+LayerTestResult<float,4> SimpleNormalizationTestImpl(
+ armnn::IWorkloadFactory& workloadFactory,
+ const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+ armnn::NormalizationAlgorithmChannel normChannel,
+ armnn::NormalizationAlgorithmMethod normMethod)
{
const unsigned int inputHeight = 2;
const unsigned int inputWidth = 2;
@@ -73,9 +78,7 @@ LayerTestResult<float,4> SimpleNormalizationTestImpl(armnn::IWorkloadFactory& wo
CopyDataToITensorHandle(inputHandle.get(), &input[0][0][0][0]);
- workloadFactory.Acquire();
- workload->Execute();
- workloadFactory.Release();
+ ExecuteWorkload(*workload, memoryManager);
CopyDataFromITensorHandle(&ret.output[0][0][0][0], outputHandle.get());
@@ -153,9 +156,11 @@ LayerTestResult<float,4> SimpleNormalizationTestImpl(armnn::IWorkloadFactory& wo
return ret;
}
-LayerTestResult<float,4> SimpleNormalizationNhwcTestImpl(armnn::IWorkloadFactory& workloadFactory,
- armnn::NormalizationAlgorithmChannel normChannel,
- armnn::NormalizationAlgorithmMethod normMethod)
+LayerTestResult<float,4> SimpleNormalizationNhwcTestImpl(
+ armnn::IWorkloadFactory& workloadFactory,
+ const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+ armnn::NormalizationAlgorithmChannel normChannel,
+ armnn::NormalizationAlgorithmMethod normMethod)
{
const unsigned int inputHeight = 2;
const unsigned int inputWidth = 2;
@@ -216,9 +221,7 @@ LayerTestResult<float,4> SimpleNormalizationNhwcTestImpl(armnn::IWorkloadFactory
CopyDataToITensorHandle(inputHandle.get(), &input[0][0][0][0]);
- workloadFactory.Acquire();
- workload->Execute();
- workloadFactory.Release();
+ ExecuteWorkload(*workload, memoryManager);
CopyDataFromITensorHandle(&ret.output[0][0][0][0], outputHandle.get());
@@ -254,10 +257,12 @@ LayerTestResult<float,4> SimpleNormalizationNhwcTestImpl(armnn::IWorkloadFactory
return ret;
}
-LayerTestResult<float,4> CompareNormalizationTestImpl(armnn::IWorkloadFactory& workloadFactory,
- armnn::IWorkloadFactory& refWorkloadFactory,
- armnn::NormalizationAlgorithmChannel normChannel,
- armnn::NormalizationAlgorithmMethod normMethod)
+LayerTestResult<float,4> CompareNormalizationTestImpl(
+ armnn::IWorkloadFactory& workloadFactory,
+ const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+ armnn::IWorkloadFactory& refWorkloadFactory,
+ armnn::NormalizationAlgorithmChannel normChannel,
+ armnn::NormalizationAlgorithmMethod normMethod)
{
constexpr unsigned int inputNum = 5;
constexpr unsigned int inputChannels = 3;
@@ -332,9 +337,7 @@ LayerTestResult<float,4> CompareNormalizationTestImpl(armnn::IWorkloadFactory& w
CopyDataToITensorHandle(inputHandle.get(), &input[0][0][0][0]);
CopyDataToITensorHandle(inputHandleRef.get(), &input[0][0][0][0]);
- workloadFactory.Acquire();
- workload->Execute();
- workloadFactory.Release();
+ ExecuteWorkload(*workload, memoryManager);
workloadRef->Execute();