diff options
Diffstat (limited to 'src/backends/backendsCommon/test/layerTests/GatherTestImpl.cpp')
-rw-r--r-- | src/backends/backendsCommon/test/layerTests/GatherTestImpl.cpp | 125 |
1 files changed, 79 insertions, 46 deletions
diff --git a/src/backends/backendsCommon/test/layerTests/GatherTestImpl.cpp b/src/backends/backendsCommon/test/layerTests/GatherTestImpl.cpp index b57f2ef569..7fabff6c1c 100644 --- a/src/backends/backendsCommon/test/layerTests/GatherTestImpl.cpp +++ b/src/backends/backendsCommon/test/layerTests/GatherTestImpl.cpp @@ -24,6 +24,7 @@ template <armnn::DataType ArmnnType, LayerTestResult<T, OutputDim> GatherTestImpl( armnn::IWorkloadFactory& workloadFactory, const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager, + const armnn::ITensorHandleFactory& tensorHandleFactory, const armnn::TensorInfo& paramsInfo, const armnn::TensorInfo& indicesInfo, const armnn::TensorInfo& outputInfo, @@ -38,11 +39,9 @@ LayerTestResult<T, OutputDim> GatherTestImpl( LayerTestResult<T, OutputDim> result(outputInfo); result.outputExpected = MakeTensor<T, OutputDim>(outputInfo, outputData); - ARMNN_NO_DEPRECATE_WARN_BEGIN - std::unique_ptr<armnn::ITensorHandle> paramsHandle = workloadFactory.CreateTensorHandle(paramsInfo); - std::unique_ptr<armnn::ITensorHandle> indicesHandle = workloadFactory.CreateTensorHandle(indicesInfo); - std::unique_ptr<armnn::ITensorHandle> outputHandle = workloadFactory.CreateTensorHandle(outputInfo); - ARMNN_NO_DEPRECATE_WARN_END + std::unique_ptr<armnn::ITensorHandle> paramsHandle = tensorHandleFactory.CreateTensorHandle(paramsInfo); + std::unique_ptr<armnn::ITensorHandle> indicesHandle = tensorHandleFactory.CreateTensorHandle(indicesInfo); + std::unique_ptr<armnn::ITensorHandle> outputHandle = tensorHandleFactory.CreateTensorHandle(outputInfo); armnn::GatherQueueDescriptor data; armnn::WorkloadInfo info; @@ -71,7 +70,8 @@ struct GatherTestHelper { static LayerTestResult<T, 1> Gather1dParamsTestImpl( armnn::IWorkloadFactory& workloadFactory, - const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager) + const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager, + const armnn::ITensorHandleFactory& tensorHandleFactory) { armnn::TensorInfo paramsInfo({ 8 }, ArmnnType); armnn::TensorInfo indicesInfo({ 4 }, armnn::DataType::Signed32); @@ -91,6 +91,7 @@ struct GatherTestHelper return GatherTestImpl<ArmnnType, T, 1, 1, 1>( workloadFactory, memoryManager, + tensorHandleFactory, paramsInfo, indicesInfo, outputInfo, @@ -101,7 +102,8 @@ struct GatherTestHelper static LayerTestResult<T, 2> GatherMultiDimParamsTestImpl( armnn::IWorkloadFactory& workloadFactory, - const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager) + const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager, + const armnn::ITensorHandleFactory& tensorHandleFactory) { armnn::TensorInfo paramsInfo({ 5, 2 }, ArmnnType); armnn::TensorInfo indicesInfo({ 3 }, armnn::DataType::Signed32); @@ -122,6 +124,7 @@ struct GatherTestHelper return GatherTestImpl<ArmnnType, T, 2, 1, 2>( workloadFactory, memoryManager, + tensorHandleFactory, paramsInfo, indicesInfo, outputInfo, @@ -132,7 +135,8 @@ struct GatherTestHelper static LayerTestResult<T, 4> GatherMultiDimParamsMultiDimIndicesTestImpl( armnn::IWorkloadFactory& workloadFactory, - const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager) + const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager, + const armnn::ITensorHandleFactory& tensorHandleFactory) { armnn::TensorInfo paramsInfo({ 3, 2, 3}, ArmnnType); armnn::TensorInfo indicesInfo({ 2, 3 }, armnn::DataType::Signed32); @@ -180,6 +184,7 @@ struct GatherTestHelper return GatherTestImpl<ArmnnType, T, 3, 2, 4>( workloadFactory, memoryManager, + tensorHandleFactory, paramsInfo, indicesInfo, outputInfo, @@ -194,7 +199,8 @@ struct GatherTestHelper<armnn::DataType::Float16, T> { static LayerTestResult<T, 1> Gather1dParamsTestImpl( armnn::IWorkloadFactory& workloadFactory, - const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager) + const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager, + const armnn::ITensorHandleFactory& tensorHandleFactory) { using namespace half_float::literal; @@ -209,6 +215,7 @@ struct GatherTestHelper<armnn::DataType::Float16, T> return GatherTestImpl<armnn::DataType::Float16, T, 1, 1, 1>( workloadFactory, memoryManager, + tensorHandleFactory, paramsInfo, indicesInfo, outputInfo, @@ -219,7 +226,8 @@ struct GatherTestHelper<armnn::DataType::Float16, T> static LayerTestResult<T, 2> GatherMultiDimParamsTestImpl( armnn::IWorkloadFactory& workloadFactory, - const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager) + const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager, + const armnn::ITensorHandleFactory& tensorHandleFactory) { using namespace half_float::literal; @@ -235,6 +243,7 @@ struct GatherTestHelper<armnn::DataType::Float16, T> return GatherTestImpl<armnn::DataType::Float16, T, 2, 1, 2>( workloadFactory, memoryManager, + tensorHandleFactory, paramsInfo, indicesInfo, outputInfo, @@ -245,7 +254,8 @@ struct GatherTestHelper<armnn::DataType::Float16, T> static LayerTestResult<T, 4> GatherMultiDimParamsMultiDimIndicesTestImpl( armnn::IWorkloadFactory& workloadFactory, - const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager) + const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager, + const armnn::ITensorHandleFactory& tensorHandleFactory) { using namespace half_float::literal; @@ -287,6 +297,7 @@ struct GatherTestHelper<armnn::DataType::Float16, T> return GatherTestImpl<armnn::DataType::Float16, T, 3, 2, 4>( workloadFactory, memoryManager, + tensorHandleFactory, paramsInfo, indicesInfo, outputInfo, @@ -300,113 +311,135 @@ struct GatherTestHelper<armnn::DataType::Float16, T> LayerTestResult<float, 1> Gather1dParamsFloat32Test( armnn::IWorkloadFactory& workloadFactory, - const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager) + const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager, + const armnn::ITensorHandleFactory& tensorHandleFactory) { - return GatherTestHelper<armnn::DataType::Float32>::Gather1dParamsTestImpl(workloadFactory, memoryManager); + return GatherTestHelper<armnn::DataType::Float32>::Gather1dParamsTestImpl( + workloadFactory, memoryManager, tensorHandleFactory); } LayerTestResult<armnn::Half, 1> Gather1dParamsFloat16Test( armnn::IWorkloadFactory& workloadFactory, - const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager) + const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager, + const armnn::ITensorHandleFactory& tensorHandleFactory) { - return GatherTestHelper<armnn::DataType::Float16>::Gather1dParamsTestImpl(workloadFactory, memoryManager); + return GatherTestHelper<armnn::DataType::Float16>::Gather1dParamsTestImpl( + workloadFactory, memoryManager, tensorHandleFactory); } LayerTestResult<uint8_t, 1> Gather1dParamsUint8Test( armnn::IWorkloadFactory& workloadFactory, - const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager) + const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager, + const armnn::ITensorHandleFactory& tensorHandleFactory) { - return GatherTestHelper<armnn::DataType::QAsymmU8>::Gather1dParamsTestImpl(workloadFactory, memoryManager); + return GatherTestHelper<armnn::DataType::QAsymmU8>::Gather1dParamsTestImpl( + workloadFactory, memoryManager, tensorHandleFactory); } LayerTestResult<int16_t, 1> Gather1dParamsInt16Test( armnn::IWorkloadFactory& workloadFactory, - const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager) + const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager, + const armnn::ITensorHandleFactory& tensorHandleFactory) { - return GatherTestHelper<armnn::DataType::QSymmS16>::Gather1dParamsTestImpl(workloadFactory, memoryManager); + return GatherTestHelper<armnn::DataType::QSymmS16>::Gather1dParamsTestImpl( + workloadFactory, memoryManager, tensorHandleFactory); } LayerTestResult<int32_t, 1> Gather1dParamsInt32Test( - armnn::IWorkloadFactory& workloadFactory, - const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager) + armnn::IWorkloadFactory& workloadFactory, + const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager, + const armnn::ITensorHandleFactory& tensorHandleFactory) { - return GatherTestHelper<armnn::DataType::Signed32>::Gather1dParamsTestImpl(workloadFactory, memoryManager); + return GatherTestHelper<armnn::DataType::Signed32>::Gather1dParamsTestImpl( + workloadFactory, memoryManager, tensorHandleFactory); } LayerTestResult<float, 2> GatherMultiDimParamsFloat32Test( armnn::IWorkloadFactory& workloadFactory, - const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager) + const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager, + const armnn::ITensorHandleFactory& tensorHandleFactory) { - return GatherTestHelper<armnn::DataType::Float32>::GatherMultiDimParamsTestImpl(workloadFactory, memoryManager); + return GatherTestHelper<armnn::DataType::Float32>::GatherMultiDimParamsTestImpl( + workloadFactory, memoryManager, tensorHandleFactory); } LayerTestResult<armnn::Half, 2> GatherMultiDimParamsFloat16Test( armnn::IWorkloadFactory& workloadFactory, - const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager) + const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager, + const armnn::ITensorHandleFactory& tensorHandleFactory) { - return GatherTestHelper<armnn::DataType::Float16>::GatherMultiDimParamsTestImpl(workloadFactory, memoryManager); + return GatherTestHelper<armnn::DataType::Float16>::GatherMultiDimParamsTestImpl( + workloadFactory, memoryManager, tensorHandleFactory); } LayerTestResult<uint8_t, 2> GatherMultiDimParamsUint8Test( armnn::IWorkloadFactory& workloadFactory, - const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager) + const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager, + const armnn::ITensorHandleFactory& tensorHandleFactory) { return GatherTestHelper<armnn::DataType::QAsymmU8>::GatherMultiDimParamsTestImpl( - workloadFactory, memoryManager); + workloadFactory, memoryManager, tensorHandleFactory); } LayerTestResult<int16_t, 2> GatherMultiDimParamsInt16Test( - armnn::IWorkloadFactory& workloadFactory, - const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager) + armnn::IWorkloadFactory& workloadFactory, + const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager, + const armnn::ITensorHandleFactory& tensorHandleFactory) { return GatherTestHelper<armnn::DataType::QSymmS16>::GatherMultiDimParamsTestImpl( - workloadFactory, memoryManager); + workloadFactory, memoryManager, tensorHandleFactory); } LayerTestResult<int32_t, 2> GatherMultiDimParamsInt32Test( - armnn::IWorkloadFactory& workloadFactory, - const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager) + armnn::IWorkloadFactory& workloadFactory, + const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager, + const armnn::ITensorHandleFactory& tensorHandleFactory) { return GatherTestHelper<armnn::DataType::Signed32>::GatherMultiDimParamsTestImpl( - workloadFactory, memoryManager); + workloadFactory, memoryManager, tensorHandleFactory); } LayerTestResult<float, 4> GatherMultiDimParamsMultiDimIndicesFloat32Test( armnn::IWorkloadFactory& workloadFactory, - const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager) + const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager, + const armnn::ITensorHandleFactory& tensorHandleFactory) { return GatherTestHelper<armnn::DataType::Float32>::GatherMultiDimParamsMultiDimIndicesTestImpl( - workloadFactory, memoryManager); + workloadFactory, memoryManager, tensorHandleFactory); } LayerTestResult<armnn::Half, 4> GatherMultiDimParamsMultiDimIndicesFloat16Test( armnn::IWorkloadFactory& workloadFactory, - const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager) + const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager, + const armnn::ITensorHandleFactory& tensorHandleFactory) { return GatherTestHelper<armnn::DataType::Float16>::GatherMultiDimParamsMultiDimIndicesTestImpl( - workloadFactory, memoryManager); + workloadFactory, memoryManager, tensorHandleFactory); } LayerTestResult<uint8_t, 4> GatherMultiDimParamsMultiDimIndicesUint8Test( armnn::IWorkloadFactory& workloadFactory, - const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager) + const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager, + const armnn::ITensorHandleFactory& tensorHandleFactory) { return GatherTestHelper<armnn::DataType::QAsymmU8>::GatherMultiDimParamsMultiDimIndicesTestImpl( - workloadFactory, memoryManager); + workloadFactory, memoryManager, tensorHandleFactory); } LayerTestResult<int16_t, 4> GatherMultiDimParamsMultiDimIndicesInt16Test( - armnn::IWorkloadFactory& workloadFactory, - const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager) + armnn::IWorkloadFactory& workloadFactory, + const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager, + const armnn::ITensorHandleFactory& tensorHandleFactory) { return GatherTestHelper<armnn::DataType::QSymmS16>::GatherMultiDimParamsMultiDimIndicesTestImpl( - workloadFactory, memoryManager); + workloadFactory, memoryManager, tensorHandleFactory); } LayerTestResult<int32_t, 4> GatherMultiDimParamsMultiDimIndicesInt32Test( - armnn::IWorkloadFactory& workloadFactory, - const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager) + armnn::IWorkloadFactory& workloadFactory, + const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager, + const armnn::ITensorHandleFactory& tensorHandleFactory) { return GatherTestHelper<armnn::DataType::Signed32>::GatherMultiDimParamsMultiDimIndicesTestImpl( - workloadFactory, memoryManager); + workloadFactory, memoryManager, tensorHandleFactory); }
\ No newline at end of file |