diff options
Diffstat (limited to 'src/backends/backendsCommon/test/NormTestImpl.hpp')
-rw-r--r-- | src/backends/backendsCommon/test/NormTestImpl.hpp | 43 |
1 files changed, 23 insertions, 20 deletions
diff --git a/src/backends/backendsCommon/test/NormTestImpl.hpp b/src/backends/backendsCommon/test/NormTestImpl.hpp index 16893eb315..38a0053a56 100644 --- a/src/backends/backendsCommon/test/NormTestImpl.hpp +++ b/src/backends/backendsCommon/test/NormTestImpl.hpp @@ -3,16 +3,21 @@ // SPDX-License-Identifier: MIT // +#include "WorkloadTestUtils.hpp" + #include <armnn/Exceptions.hpp> #include <armnn/LayerSupport.hpp> -#include "armnn/Types.hpp" +#include <armnn/Types.hpp> #include <backendsCommon/CpuTensorHandle.hpp> +#include <backendsCommon/IBackendInternal.hpp> #include <backendsCommon/WorkloadFactory.hpp> -LayerTestResult<float,4> SimpleNormalizationTestImpl(armnn::IWorkloadFactory& workloadFactory, - armnn::NormalizationAlgorithmChannel normChannel, - armnn::NormalizationAlgorithmMethod normMethod) +LayerTestResult<float,4> SimpleNormalizationTestImpl( + armnn::IWorkloadFactory& workloadFactory, + const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager, + armnn::NormalizationAlgorithmChannel normChannel, + armnn::NormalizationAlgorithmMethod normMethod) { const unsigned int inputHeight = 2; const unsigned int inputWidth = 2; @@ -73,9 +78,7 @@ LayerTestResult<float,4> SimpleNormalizationTestImpl(armnn::IWorkloadFactory& wo CopyDataToITensorHandle(inputHandle.get(), &input[0][0][0][0]); - workloadFactory.Acquire(); - workload->Execute(); - workloadFactory.Release(); + ExecuteWorkload(*workload, memoryManager); CopyDataFromITensorHandle(&ret.output[0][0][0][0], outputHandle.get()); @@ -153,9 +156,11 @@ LayerTestResult<float,4> SimpleNormalizationTestImpl(armnn::IWorkloadFactory& wo return ret; } -LayerTestResult<float,4> SimpleNormalizationNhwcTestImpl(armnn::IWorkloadFactory& workloadFactory, - armnn::NormalizationAlgorithmChannel normChannel, - armnn::NormalizationAlgorithmMethod normMethod) +LayerTestResult<float,4> SimpleNormalizationNhwcTestImpl( + armnn::IWorkloadFactory& workloadFactory, + const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager, + armnn::NormalizationAlgorithmChannel normChannel, + armnn::NormalizationAlgorithmMethod normMethod) { const unsigned int inputHeight = 2; const unsigned int inputWidth = 2; @@ -216,9 +221,7 @@ LayerTestResult<float,4> SimpleNormalizationNhwcTestImpl(armnn::IWorkloadFactory CopyDataToITensorHandle(inputHandle.get(), &input[0][0][0][0]); - workloadFactory.Acquire(); - workload->Execute(); - workloadFactory.Release(); + ExecuteWorkload(*workload, memoryManager); CopyDataFromITensorHandle(&ret.output[0][0][0][0], outputHandle.get()); @@ -254,10 +257,12 @@ LayerTestResult<float,4> SimpleNormalizationNhwcTestImpl(armnn::IWorkloadFactory return ret; } -LayerTestResult<float,4> CompareNormalizationTestImpl(armnn::IWorkloadFactory& workloadFactory, - armnn::IWorkloadFactory& refWorkloadFactory, - armnn::NormalizationAlgorithmChannel normChannel, - armnn::NormalizationAlgorithmMethod normMethod) +LayerTestResult<float,4> CompareNormalizationTestImpl( + armnn::IWorkloadFactory& workloadFactory, + const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager, + armnn::IWorkloadFactory& refWorkloadFactory, + armnn::NormalizationAlgorithmChannel normChannel, + armnn::NormalizationAlgorithmMethod normMethod) { constexpr unsigned int inputNum = 5; constexpr unsigned int inputChannels = 3; @@ -332,9 +337,7 @@ LayerTestResult<float,4> CompareNormalizationTestImpl(armnn::IWorkloadFactory& w CopyDataToITensorHandle(inputHandle.get(), &input[0][0][0][0]); CopyDataToITensorHandle(inputHandleRef.get(), &input[0][0][0][0]); - workloadFactory.Acquire(); - workload->Execute(); - workloadFactory.Release(); + ExecuteWorkload(*workload, memoryManager); workloadRef->Execute(); |