30 const std::vector<T>& paramsData,
31 const std::vector<int32_t>& indicesData,
32 const std::vector<T>& outputData)
35 auto params = MakeTensor<T, ParamsDim>(paramsInfo, paramsData);
36 auto indices = MakeTensor<int32_t, IndicesDim>(indicesInfo, indicesData);
39 result.outputExpected = MakeTensor<T, OutputDim>(outputInfo, outputData);
42 std::unique_ptr<armnn::ITensorHandle> paramsHandle = workloadFactory.
CreateTensorHandle(paramsInfo);
43 std::unique_ptr<armnn::ITensorHandle> indicesHandle = workloadFactory.
CreateTensorHandle(indicesInfo);
44 std::unique_ptr<armnn::ITensorHandle> outputHandle = workloadFactory.
CreateTensorHandle(outputInfo);
49 AddInputToWorkload(data, info, paramsInfo, paramsHandle.get());
50 AddInputToWorkload(data, info, indicesInfo, indicesHandle.get());
51 AddOutputToWorkload(data, info, outputInfo, outputHandle.get());
53 std::unique_ptr<armnn::IWorkload> workload = workloadFactory.
CreateGather(data, info);
55 paramsHandle->Allocate();
56 indicesHandle->Allocate();
57 outputHandle->Allocate();
69 template<armnn::DataType ArmnnType,
typename T = armnn::ResolveType<ArmnnType>>
70 struct GatherTestHelper
80 if (armnn::IsQuantizedType<T>())
83 paramsInfo.SetQuantizationOffset(1);
84 outputInfo.SetQuantizationScale(1.0f);
85 outputInfo.SetQuantizationOffset(1);
87 const std::vector<T> params = std::vector<T>({ 1, 2, 3, 4, 5, 6, 7, 8 });
88 const std::vector<int32_t> indices = std::vector<int32_t>({ 0, 2, 1, 5 });
89 const std::vector<T> expectedOutput = std::vector<T>({ 1, 3, 2, 6 });
91 return GatherTestImpl<ArmnnType, T, 1, 1, 1>(
110 if (armnn::IsQuantizedType<T>())
113 paramsInfo.SetQuantizationOffset(1);
114 outputInfo.SetQuantizationScale(1.0f);
115 outputInfo.SetQuantizationOffset(1);
118 const std::vector<T> params = std::vector<T>({ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10 });
119 const std::vector<int32_t> indices = std::vector<int32_t>({ 1, 3, 4 });
120 const std::vector<T> expectedOutput = std::vector<T>({ 3, 4, 7, 8, 9, 10 });
122 return GatherTestImpl<ArmnnType, T, 2, 1, 2>(
141 if (armnn::IsQuantizedType<T>())
144 paramsInfo.SetQuantizationOffset(1);
145 outputInfo.SetQuantizationScale(1.0f);
146 outputInfo.SetQuantizationOffset(1);
149 const std::vector<T> params =
161 const std::vector<int32_t> indices = { 1, 2, 1, 2, 1, 0 };
163 const std::vector<T> expectedOutput =
180 return GatherTestImpl<ArmnnType, T, 3, 2, 4>(
193 struct GatherTestHelper<
armnn::DataType::Float16, T>
199 using namespace half_float::literal;
205 const std::vector<T> params = std::vector<T>({ 1._h, 2._h, 3._h, 4._h, 5._h, 6._h, 7._h, 8._h });
206 const std::vector<int32_t> indices = std::vector<int32_t>({ 0, 2, 1, 5 });
207 const std::vector<T> expectedOutput = std::vector<T>({ 1._h, 3._h, 2._h, 6._h });
209 return GatherTestImpl<armnn::DataType::Float16, T, 1, 1, 1>(
224 using namespace half_float::literal;
230 const std::vector<T> params = std::vector<T>({ 1._h, 2._h, 3._h, 4._h, 5._h, 6._h, 7._h, 8._h, 9._h, 10._h });
232 const std::vector<int32_t> indices = std::vector<int32_t>({ 1, 3, 4 });
233 const std::vector<T> expectedOutput = std::vector<T>({ 3._h, 4._h, 7._h, 8._h, 9._h, 10._h });
235 return GatherTestImpl<armnn::DataType::Float16, T, 2, 1, 2>(
250 using namespace half_float::literal;
256 const std::vector<T> params =
268 const std::vector<int32_t> indices = { 1, 2, 1, 2, 1, 0 };
270 const std::vector<T> expectedOutput =
287 return GatherTestImpl<armnn::DataType::Float16, T, 3, 2, 4>(
305 return GatherTestHelper<armnn::DataType::Float32>::Gather1dParamsTestImpl(workloadFactory, memoryManager);
312 return GatherTestHelper<armnn::DataType::Float16>::Gather1dParamsTestImpl(workloadFactory, memoryManager);
319 return GatherTestHelper<armnn::DataType::QAsymmU8>::Gather1dParamsTestImpl(workloadFactory, memoryManager);
326 return GatherTestHelper<armnn::DataType::QSymmS16>::Gather1dParamsTestImpl(workloadFactory, memoryManager);
333 return GatherTestHelper<armnn::DataType::Signed32>::Gather1dParamsTestImpl(workloadFactory, memoryManager);
340 return GatherTestHelper<armnn::DataType::Float32>::GatherMultiDimParamsTestImpl(workloadFactory, memoryManager);
347 return GatherTestHelper<armnn::DataType::Float16>::GatherMultiDimParamsTestImpl(workloadFactory, memoryManager);
354 return GatherTestHelper<armnn::DataType::QAsymmU8>::GatherMultiDimParamsTestImpl(
355 workloadFactory, memoryManager);
362 return GatherTestHelper<armnn::DataType::QSymmS16>::GatherMultiDimParamsTestImpl(
363 workloadFactory, memoryManager);
370 return GatherTestHelper<armnn::DataType::Signed32>::GatherMultiDimParamsTestImpl(
371 workloadFactory, memoryManager);
378 return GatherTestHelper<armnn::DataType::Float32>::GatherMultiDimParamsMultiDimIndicesTestImpl(
379 workloadFactory, memoryManager);
386 return GatherTestHelper<armnn::DataType::Float16>::GatherMultiDimParamsMultiDimIndicesTestImpl(
387 workloadFactory, memoryManager);
394 return GatherTestHelper<armnn::DataType::QAsymmU8>::GatherMultiDimParamsMultiDimIndicesTestImpl(
395 workloadFactory, memoryManager);
402 return GatherTestHelper<armnn::DataType::QSymmS16>::GatherMultiDimParamsMultiDimIndicesTestImpl(
403 workloadFactory, memoryManager);
410 return GatherTestHelper<armnn::DataType::Signed32>::GatherMultiDimParamsMultiDimIndicesTestImpl(
411 workloadFactory, memoryManager);
LayerTestResult< int32_t, 4 > GatherMultiDimParamsMultiDimIndicesInt32Test(armnn::IWorkloadFactory &workloadFactory, const armnn::IBackendInternal::IMemoryManagerSharedPtr &memoryManager)
LayerTestResult< armnn::Half, 4 > GatherMultiDimParamsMultiDimIndicesFloat16Test(armnn::IWorkloadFactory &workloadFactory, const armnn::IBackendInternal::IMemoryManagerSharedPtr &memoryManager)
LayerTestResult< uint8_t, 1 > Gather1dParamsUint8Test(armnn::IWorkloadFactory &workloadFactory, const armnn::IBackendInternal::IMemoryManagerSharedPtr &memoryManager)
LayerTestResult< int32_t, 2 > GatherMultiDimParamsInt32Test(armnn::IWorkloadFactory &workloadFactory, const armnn::IBackendInternal::IMemoryManagerSharedPtr &memoryManager)
#define ARMNN_NO_DEPRECATE_WARN_BEGIN
LayerTestResult< int16_t, 1 > Gather1dParamsInt16Test(armnn::IWorkloadFactory &workloadFactory, const armnn::IBackendInternal::IMemoryManagerSharedPtr &memoryManager)
LayerTestResult< armnn::Half, 1 > Gather1dParamsFloat16Test(armnn::IWorkloadFactory &workloadFactory, const armnn::IBackendInternal::IMemoryManagerSharedPtr &memoryManager)
typename ResolveTypeImpl< DT >::Type ResolveType
Copyright (c) 2020 ARM Limited.
void IgnoreUnused(Ts &&...)
LayerTestResult< uint8_t, 2 > GatherMultiDimParamsUint8Test(armnn::IWorkloadFactory &workloadFactory, const armnn::IBackendInternal::IMemoryManagerSharedPtr &memoryManager)
LayerTestResult< int16_t, 2 > GatherMultiDimParamsInt16Test(armnn::IWorkloadFactory &workloadFactory, const armnn::IBackendInternal::IMemoryManagerSharedPtr &memoryManager)
LayerTestResult< float, 4 > GatherMultiDimParamsMultiDimIndicesFloat32Test(armnn::IWorkloadFactory &workloadFactory, const armnn::IBackendInternal::IMemoryManagerSharedPtr &memoryManager)
LayerTestResult< int16_t, 4 > GatherMultiDimParamsMultiDimIndicesInt16Test(armnn::IWorkloadFactory &workloadFactory, const armnn::IBackendInternal::IMemoryManagerSharedPtr &memoryManager)
#define ARMNN_NO_DEPRECATE_WARN_END
std::shared_ptr< IMemoryManager > IMemoryManagerSharedPtr
LayerTestResult< int32_t, 1 > Gather1dParamsInt32Test(armnn::IWorkloadFactory &workloadFactory, const armnn::IBackendInternal::IMemoryManagerSharedPtr &memoryManager)
void SetQuantizationScale(float scale)
void CopyDataFromITensorHandle(void *memory, const armnn::ITensorHandle *tensorHandle)
LayerTestResult< uint8_t, 4 > GatherMultiDimParamsMultiDimIndicesUint8Test(armnn::IWorkloadFactory &workloadFactory, const armnn::IBackendInternal::IMemoryManagerSharedPtr &memoryManager)
virtual std::unique_ptr< ITensorHandle > CreateTensorHandle(const TensorInfo &tensorInfo, const bool IsMemoryManaged=true) const =0
LayerTestResult< armnn::Half, 2 > GatherMultiDimParamsFloat16Test(armnn::IWorkloadFactory &workloadFactory, const armnn::IBackendInternal::IMemoryManagerSharedPtr &memoryManager)
virtual std::unique_ptr< IWorkload > CreateGather(const GatherQueueDescriptor &descriptor, const WorkloadInfo &info) const
LayerTestResult< float, 1 > Gather1dParamsFloat32Test(armnn::IWorkloadFactory &workloadFactory, const armnn::IBackendInternal::IMemoryManagerSharedPtr &memoryManager)
LayerTestResult< float, 2 > GatherMultiDimParamsFloat32Test(armnn::IWorkloadFactory &workloadFactory, const armnn::IBackendInternal::IMemoryManagerSharedPtr &memoryManager)
Contains information about inputs and outputs to a layer.
void CopyDataToITensorHandle(armnn::ITensorHandle *tensorHandle, const void *memory)