diff options
Diffstat (limited to 'src/backends/backendsCommon/test/layerTests/PreluTestImpl.hpp')
-rw-r--r-- | src/backends/backendsCommon/test/layerTests/PreluTestImpl.hpp | 10 |
1 files changed, 6 insertions, 4 deletions
diff --git a/src/backends/backendsCommon/test/layerTests/PreluTestImpl.hpp b/src/backends/backendsCommon/test/layerTests/PreluTestImpl.hpp index 3b6c2d8412..de0b27b6c1 100644 --- a/src/backends/backendsCommon/test/layerTests/PreluTestImpl.hpp +++ b/src/backends/backendsCommon/test/layerTests/PreluTestImpl.hpp @@ -15,11 +15,12 @@ #include <backendsCommon/WorkloadFactory.hpp> #include <backendsCommon/test/TensorCopyUtils.hpp> +#include <backendsCommon/test/WorkloadFactoryHelper.hpp> #include <backendsCommon/test/WorkloadTestUtils.hpp> #include <test/TensorHelpers.hpp> -template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>> +template<typename FactoryType, armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>> LayerTestResult<T, 4> PreluTest( armnn::IWorkloadFactory& workloadFactory, const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager) @@ -76,9 +77,10 @@ LayerTestResult<T, 4> PreluTest( outputTensorInfo.GetQuantizationScale(), outputTensorInfo.GetQuantizationOffset())); - std::unique_ptr <armnn::ITensorHandle> inputHandle = workloadFactory.CreateTensorHandle(inputTensorInfo); - std::unique_ptr <armnn::ITensorHandle> alphaHandle = workloadFactory.CreateTensorHandle(alphaTensorInfo); - std::unique_ptr <armnn::ITensorHandle> outputHandle = workloadFactory.CreateTensorHandle(outputTensorInfo); + auto tensorHandleFactory = WorkloadFactoryHelper<FactoryType>::GetTensorHandleFactory(memoryManager); + std::unique_ptr <armnn::ITensorHandle> inputHandle = tensorHandleFactory.CreateTensorHandle(inputTensorInfo); + std::unique_ptr <armnn::ITensorHandle> alphaHandle = tensorHandleFactory.CreateTensorHandle(alphaTensorInfo); + std::unique_ptr <armnn::ITensorHandle> outputHandle = tensorHandleFactory.CreateTensorHandle(outputTensorInfo); armnn::PreluQueueDescriptor descriptor; armnn::WorkloadInfo info; |