// // Copyright © 2017 Arm Ltd. All rights reserved. // SPDX-License-Identifier: MIT // #pragma once #include "WorkloadTestUtils.hpp" #include #include #include #include template , unsigned int paramsDim, unsigned int indicesDim, unsigned int OutputDim> LayerTestResult GatherTestImpl( armnn::IWorkloadFactory& workloadFactory, const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager, const armnn::TensorInfo& paramsInfo, const armnn::TensorInfo& indicesInfo, const armnn::TensorInfo& outputInfo, const std::vector& paramsData, const std::vector& indicesData, const std::vector& outputData) { auto params = MakeTensor(paramsInfo, paramsData); auto indices = MakeTensor(indicesInfo, indicesData); LayerTestResult result(outputInfo); result.outputExpected = MakeTensor(outputInfo, outputData); std::unique_ptr paramsHandle = workloadFactory.CreateTensorHandle(paramsInfo); std::unique_ptr indicesHandle = workloadFactory.CreateTensorHandle(indicesInfo); std::unique_ptr outputHandle = workloadFactory.CreateTensorHandle(outputInfo); armnn::GatherQueueDescriptor data; armnn::WorkloadInfo info; AddInputToWorkload(data, info, paramsInfo, paramsHandle.get()); AddInputToWorkload(data, info, indicesInfo, indicesHandle.get()); AddOutputToWorkload(data, info, outputInfo, outputHandle.get()); std::unique_ptr workload = workloadFactory.CreateGather(data, info); paramsHandle->Allocate(); indicesHandle->Allocate(); outputHandle->Allocate(); CopyDataToITensorHandle(paramsHandle.get(), params.origin()); CopyDataToITensorHandle(indicesHandle.get(), indices.origin()); workload->Execute(); CopyDataFromITensorHandle(result.output.origin(), outputHandle.get()); return result; } template > LayerTestResult Gather1DParamsTestImpl(armnn::IWorkloadFactory& workloadFactory, const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager) { armnn::TensorInfo paramsInfo({ 8 }, ArmnnType); armnn::TensorInfo indicesInfo({ 4 }, armnn::DataType::Signed32); armnn::TensorInfo outputInfo({ 4 }, ArmnnType); if (armnn::IsQuantizedType()) { paramsInfo.SetQuantizationScale(1.0f); paramsInfo.SetQuantizationOffset(1); outputInfo.SetQuantizationScale(1.0f); outputInfo.SetQuantizationOffset(1); } const std::vector params = std::vector({ 1, 2, 3, 4, 5, 6, 7, 8 }); const std::vector indices = std::vector({ 0, 2, 1, 5 }); const std::vector expectedOutput = std::vector({ 1, 3, 2, 6 }); return GatherTestImpl(workloadFactory, memoryManager, paramsInfo, indicesInfo, outputInfo, params,indices, expectedOutput); } template > LayerTestResult GatherMultiDimParamsTestImpl( armnn::IWorkloadFactory& workloadFactory, const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager) { armnn::TensorInfo paramsInfo({ 5, 2 }, ArmnnType); armnn::TensorInfo indicesInfo({ 3 }, armnn::DataType::Signed32); armnn::TensorInfo outputInfo({ 3, 2 }, ArmnnType); if (armnn::IsQuantizedType()) { paramsInfo.SetQuantizationScale(1.0f); paramsInfo.SetQuantizationOffset(1); outputInfo.SetQuantizationScale(1.0f); outputInfo.SetQuantizationOffset(1); } const std::vector params = std::vector({ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10 }); const std::vector indices = std::vector({ 1, 3, 4 }); const std::vector expectedOutput = std::vector({ 3, 4, 7, 8, 9, 10 }); return GatherTestImpl(workloadFactory, memoryManager, paramsInfo, indicesInfo, outputInfo, params,indices, expectedOutput); } template > LayerTestResult GatherMultiDimParamsMultiDimIndicesTestImpl( armnn::IWorkloadFactory& workloadFactory, const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager) { armnn::TensorInfo paramsInfo({ 3, 2, 3}, ArmnnType); armnn::TensorInfo indicesInfo({ 2, 3 }, armnn::DataType::Signed32); armnn::TensorInfo outputInfo({ 2, 3, 2, 3 }, ArmnnType); if (armnn::IsQuantizedType()) { paramsInfo.SetQuantizationScale(1.0f); paramsInfo.SetQuantizationOffset(1); outputInfo.SetQuantizationScale(1.0f); outputInfo.SetQuantizationOffset(1); } const std::vector params = std::vector({ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18 }); const std::vector indices = std::vector({ 1, 2, 1, 2, 1, 0 }); const std::vector expectedOutput = std::vector({ 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 7, 8, 9, 10, 11, 12, 1, 2, 3, 4, 5, 6 }); return GatherTestImpl(workloadFactory, memoryManager, paramsInfo, indicesInfo, outputInfo, params,indices, expectedOutput); }