diff options
Diffstat (limited to 'src/backends/backendsCommon/test/layerTests/ReverseV2TestImpl.cpp')
-rw-r--r-- | src/backends/backendsCommon/test/layerTests/ReverseV2TestImpl.cpp | 149 |
1 files changed, 75 insertions, 74 deletions
diff --git a/src/backends/backendsCommon/test/layerTests/ReverseV2TestImpl.cpp b/src/backends/backendsCommon/test/layerTests/ReverseV2TestImpl.cpp index f2774037bd..144bf9e8e1 100644 --- a/src/backends/backendsCommon/test/layerTests/ReverseV2TestImpl.cpp +++ b/src/backends/backendsCommon/test/layerTests/ReverseV2TestImpl.cpp @@ -18,73 +18,74 @@ namespace { - template<armnn::DataType ArmnnType, typename T, std::size_t NumDims> - LayerTestResult<T, NumDims> ReverseV2TestImpl( - armnn::IWorkloadFactory& workloadFactory, - const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager, - const armnn::ITensorHandleFactory& tensorHandleFactory, - const std::vector<T>& input, - const std::vector<int>& axis, - const std::vector<T>& outputExpected, - const armnn::TensorInfo& inputInfo, - const armnn::TensorInfo& axisInfo, - const armnn::TensorInfo& outputInfo) +template<armnn::DataType ArmnnType, typename T, std::size_t NumDims> +LayerTestResult<T, NumDims> ReverseV2TestImpl( + armnn::IWorkloadFactory& workloadFactory, + const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager, + const armnn::ITensorHandleFactory& tensorHandleFactory, + const std::vector<T>& input, + const std::vector<int>& axis, + const std::vector<T>& outputExpected, + const armnn::TensorInfo& inputInfo, + const armnn::TensorInfo& axisInfo, + const armnn::TensorInfo& outputInfo) +{ + LayerTestResult<T, NumDims> result(outputInfo); + std::vector<T> outputActual(outputInfo.GetNumElements()); + + std::unique_ptr<armnn::ITensorHandle> inputHandle = tensorHandleFactory.CreateTensorHandle(inputInfo); + std::unique_ptr<armnn::ITensorHandle> axisHandle = tensorHandleFactory.CreateTensorHandle(axisInfo); + std::unique_ptr<armnn::ITensorHandle> outputHandle = tensorHandleFactory.CreateTensorHandle(outputInfo); + + armnn::ReverseV2QueueDescriptor queueDescriptor; + armnn::WorkloadInfo workloadInfo; + + AddInputToWorkload(queueDescriptor, workloadInfo, inputInfo, inputHandle.get()); + AddInputToWorkload(queueDescriptor, workloadInfo, axisInfo, axisHandle.get()); + AddOutputToWorkload(queueDescriptor, workloadInfo, outputInfo, outputHandle.get()); + + // Don't execute if ReverseV2 is not supported, as an exception will be raised. + const armnn::BackendId& backend = workloadFactory.GetBackendId(); + std::string reasonIfUnsupported; + armnn::LayerSupportHandle handle = armnn::GetILayerSupportByBackendId(backend); + result.m_Supported = handle.IsReverseV2Supported(inputInfo, + axisInfo, + outputInfo, + reasonIfUnsupported); + if (!result.m_Supported) + { + return result; + } + + auto workload = workloadFactory.CreateWorkload(armnn::LayerType::ReverseV2, queueDescriptor, workloadInfo); + + inputHandle->Allocate(); + axisHandle->Allocate(); + outputHandle->Allocate(); + + if (input.data() != nullptr) + { + CopyDataToITensorHandle(inputHandle.get(), input.data()); + } + if (axis.data() != nullptr) + { + CopyDataToITensorHandle(axisHandle.get(), axis.data()); + } + + workload->PostAllocationConfigure(); + ExecuteWorkload(*workload, memoryManager); + + if (outputActual.data() != nullptr) { - LayerTestResult<T, NumDims> result(outputInfo); - std::vector<T> outputActual(outputInfo.GetNumElements()); - - std::unique_ptr<armnn::ITensorHandle> inputHandle = tensorHandleFactory.CreateTensorHandle(inputInfo); - std::unique_ptr<armnn::ITensorHandle> axisHandle = tensorHandleFactory.CreateTensorHandle(axisInfo); - std::unique_ptr<armnn::ITensorHandle> outputHandle = tensorHandleFactory.CreateTensorHandle(outputInfo); - - armnn::ReverseV2QueueDescriptor queueDescriptor; - armnn::WorkloadInfo workloadInfo; - - AddInputToWorkload(queueDescriptor, workloadInfo, inputInfo, inputHandle.get()); - AddInputToWorkload(queueDescriptor, workloadInfo, axisInfo, axisHandle.get()); - AddOutputToWorkload(queueDescriptor, workloadInfo, outputInfo, outputHandle.get()); - - // Don't execute if ReverseV2 is not supported, as an exception will be raised. - const armnn::BackendId& backend = workloadFactory.GetBackendId(); - std::string reasonIfUnsupported; - armnn::LayerSupportHandle handle = armnn::GetILayerSupportByBackendId(backend); - result.m_Supported = handle.IsReverseV2Supported(inputInfo, - axisInfo, - outputInfo, - reasonIfUnsupported); - if (!result.m_Supported) - { - return result; - } - - auto workload = workloadFactory.CreateWorkload(armnn::LayerType::ReverseV2, queueDescriptor, workloadInfo); - - inputHandle->Allocate(); - axisHandle->Allocate(); - outputHandle->Allocate(); - - if (input.data() != nullptr) - { - CopyDataToITensorHandle(inputHandle.get(), input.data()); - } - if (axis.data() != nullptr) - { - CopyDataToITensorHandle(axisHandle.get(), axis.data()); - } - - workload->PostAllocationConfigure(); - ExecuteWorkload(*workload, memoryManager); - - if (outputActual.data() != nullptr) - { - CopyDataFromITensorHandle(outputActual.data(), outputHandle.get()); - } - - return LayerTestResult<T, NumDims>(outputActual, - outputExpected, - outputHandle->GetShape(), - outputInfo.GetShape()); + CopyDataFromITensorHandle(outputActual.data(), outputHandle.get()); } + + return LayerTestResult<T, NumDims>(outputActual, + outputExpected, + outputHandle->GetShape(), + outputInfo.GetShape()); + +} } template<armnn::DataType ArmnnType, typename T> @@ -107,7 +108,7 @@ LayerTestResult<T, 2> ReverseV2SimpleTestEmptyAxis( 3, 4 }, qScale, qOffset); - std::vector<int> axis = armnnUtils::QuantizedVector<int>({1, 1}, qScale, qOffset); + std::vector<int> axis = armnnUtils::QuantizedVector<int>({}, qScale, qOffset); std::vector<T> outputExpected = armnnUtils::QuantizedVector<T>({ 1, 2, @@ -115,14 +116,14 @@ LayerTestResult<T, 2> ReverseV2SimpleTestEmptyAxis( }, qScale, qOffset); return ReverseV2TestImpl<ArmnnType, T, 2>(workloadFactory, - memoryManager, - tensorHandleFactory, - input, - axis, - outputExpected, - inputInfo, - axisInfo, - outputInfo); + memoryManager, + tensorHandleFactory, + input, + axis, + outputExpected, + inputInfo, + axisInfo, + outputInfo); } template<armnn::DataType ArmnnType, typename T> |