From 7ad6257ff09c94aade46ce5d02b644821235121a Mon Sep 17 00:00:00 2001 From: Manuel Bottini Date: Wed, 16 Jan 2019 11:18:15 +0000 Subject: COMPMID-1727: CL: Implement Gather - support S32 indices Change-Id: Ib0298dc75d25acf8db2262c0ab73ecaa6fec636c Reviewed-on: https://review.mlplatform.org/522 Reviewed-by: Georgios Pinitas Tested-by: Arm Jenkins --- src/core/CL/cl_kernels/gather.cl | 2 +- src/core/CL/kernels/CLGatherKernel.cpp | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) (limited to 'src/core') diff --git a/src/core/CL/cl_kernels/gather.cl b/src/core/CL/cl_kernels/gather.cl index 34593ef60f..d6fe52d9d5 100644 --- a/src/core/CL/cl_kernels/gather.cl +++ b/src/core/CL/cl_kernels/gather.cl @@ -42,7 +42,7 @@ * @param[in] input_stride_w Stride of the source tensor in Z dimension (in bytes) * @param[in] input_step_w input_stride_w * number of elements along W processed per work item (in bytes) * @param[in] input_offset_first_element_in_bytes Offset of the first element in the source tensor - * @param[in] indices_ptr Pointer to the indices vector. Supported data types: U32. + * @param[in] indices_ptr Pointer to the indices vector. Supported data types: S32/U32. * @param[in] indices_stride_x Stride of the indices vector in X dimension (in bytes) * @param[in] indices_step_x input_stride_x * number of elements along X processed per work item (in bytes) * @param[in] indices_offset_first_element_in_bytes Offset of the first element in the indices vector diff --git a/src/core/CL/kernels/CLGatherKernel.cpp b/src/core/CL/kernels/CLGatherKernel.cpp index 006e755b30..a4a8808f25 100644 --- a/src/core/CL/kernels/CLGatherKernel.cpp +++ b/src/core/CL/kernels/CLGatherKernel.cpp @@ -60,7 +60,7 @@ inline Status validate_arguments(const ITensorInfo *input, const ITensorInfo *in ARM_COMPUTE_RETURN_ERROR_ON(output_shape.total_size() != output->tensor_shape().total_size()); } - ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(indices, 1, DataType::U32); + ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(indices, 1, DataType::U32, DataType::S32); return Status{}; } -- cgit v1.2.1