diff options
Diffstat (limited to 'src/backends/backendsCommon')
-rw-r--r-- | src/backends/backendsCommon/test/layerTests/TransposeTestImpl.hpp | 41 |
1 files changed, 21 insertions, 20 deletions
diff --git a/src/backends/backendsCommon/test/layerTests/TransposeTestImpl.hpp b/src/backends/backendsCommon/test/layerTests/TransposeTestImpl.hpp index 5721952066..c014078d12 100644 --- a/src/backends/backendsCommon/test/layerTests/TransposeTestImpl.hpp +++ b/src/backends/backendsCommon/test/layerTests/TransposeTestImpl.hpp @@ -11,11 +11,12 @@ #include <armnn/backends/IBackendInternal.hpp> #include <backendsCommon/WorkloadFactory.hpp> +#include <backendsCommon/test/WorkloadFactoryHelper.hpp> #include <backendsCommon/test/WorkloadTestUtils.hpp> #include <test/TensorHelpers.hpp> -template<typename T> +template<typename FactoryType, typename T> LayerTestResult<T, 4> SimpleTransposeTestImpl( armnn::IWorkloadFactory& workloadFactory, const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager, @@ -25,14 +26,14 @@ LayerTestResult<T, 4> SimpleTransposeTestImpl( const std::vector<T>& inputData, const std::vector<T>& outputExpectedData) { - IgnoreUnused(memoryManager); auto input = MakeTensor<T, 4>(inputTensorInfo, inputData); LayerTestResult<T, 4> ret(outputTensorInfo); ret.outputExpected = MakeTensor<T, 4>(outputTensorInfo, outputExpectedData); - std::unique_ptr<armnn::ITensorHandle> inputHandle = workloadFactory.CreateTensorHandle(inputTensorInfo); - std::unique_ptr<armnn::ITensorHandle> outputHandle = workloadFactory.CreateTensorHandle(outputTensorInfo); + auto tensorHandleFactory = WorkloadFactoryHelper<FactoryType>::GetTensorHandleFactory(memoryManager); + std::unique_ptr<armnn::ITensorHandle> inputHandle = tensorHandleFactory.CreateTensorHandle(inputTensorInfo); + std::unique_ptr<armnn::ITensorHandle> outputHandle = tensorHandleFactory.CreateTensorHandle(outputTensorInfo); armnn::TransposeQueueDescriptor data; data.m_Parameters = descriptor; @@ -54,7 +55,7 @@ LayerTestResult<T, 4> SimpleTransposeTestImpl( return ret; } -template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>> +template<typename FactoryType, armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>> LayerTestResult<T, 4> SimpleTransposeTest( armnn::IWorkloadFactory& workloadFactory, const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager) @@ -98,12 +99,12 @@ LayerTestResult<T, 4> SimpleTransposeTest( }, qScale, qOffset); - return SimpleTransposeTestImpl<T>(workloadFactory, memoryManager, - descriptor, inputTensorInfo, - outputTensorInfo, input, outputExpected); + return SimpleTransposeTestImpl<FactoryType, T>(workloadFactory, memoryManager, + descriptor, inputTensorInfo, + outputTensorInfo, input, outputExpected); } -template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>> +template<typename FactoryType, armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>> LayerTestResult<T, 4> TransposeValueSet1Test( armnn::IWorkloadFactory& workloadFactory, const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager) @@ -148,12 +149,12 @@ LayerTestResult<T, 4> TransposeValueSet1Test( }, qScale, qOffset); - return SimpleTransposeTestImpl<T>(workloadFactory, memoryManager, - descriptor, inputTensorInfo, - outputTensorInfo, input, outputExpected); + return SimpleTransposeTestImpl<FactoryType, T>(workloadFactory, memoryManager, + descriptor, inputTensorInfo, + outputTensorInfo, input, outputExpected); } -template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>> +template<typename FactoryType, armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>> LayerTestResult<T, 4> TransposeValueSet2Test( armnn::IWorkloadFactory& workloadFactory, const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager) @@ -198,12 +199,12 @@ LayerTestResult<T, 4> TransposeValueSet2Test( }, qScale, qOffset); - return SimpleTransposeTestImpl<T>(workloadFactory, memoryManager, - descriptor, inputTensorInfo, - outputTensorInfo, input, outputExpected); + return SimpleTransposeTestImpl<FactoryType, T>(workloadFactory, memoryManager, + descriptor, inputTensorInfo, + outputTensorInfo, input, outputExpected); } -template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>> +template<typename FactoryType, armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>> LayerTestResult<T, 4> TransposeValueSet3Test( armnn::IWorkloadFactory& workloadFactory, const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager) @@ -250,7 +251,7 @@ LayerTestResult<T, 4> TransposeValueSet3Test( }, qScale, qOffset); - return SimpleTransposeTestImpl<T>(workloadFactory, memoryManager, - descriptor, inputTensorInfo, - outputTensorInfo, input, outputExpected); + return SimpleTransposeTestImpl<FactoryType, T>(workloadFactory, memoryManager, + descriptor, inputTensorInfo, + outputTensorInfo, input, outputExpected); } |