From 7ad6257ff09c94aade46ce5d02b644821235121a Mon Sep 17 00:00:00 2001 From: Manuel Bottini Date: Wed, 16 Jan 2019 11:18:15 +0000 Subject: COMPMID-1727: CL: Implement Gather - support S32 indices Change-Id: Ib0298dc75d25acf8db2262c0ab73ecaa6fec636c Reviewed-on: https://review.mlplatform.org/522 Reviewed-by: Georgios Pinitas Tested-by: Arm Jenkins --- arm_compute/core/CL/kernels/CLGatherKernel.h | 4 ++-- arm_compute/runtime/CL/functions/CLGather.h | 4 ++-- src/core/CL/cl_kernels/gather.cl | 2 +- src/core/CL/kernels/CLGatherKernel.cpp | 2 +- 4 files changed, 6 insertions(+), 6 deletions(-) diff --git a/arm_compute/core/CL/kernels/CLGatherKernel.h b/arm_compute/core/CL/kernels/CLGatherKernel.h index 4dac6b0d1f..a348977767 100644 --- a/arm_compute/core/CL/kernels/CLGatherKernel.h +++ b/arm_compute/core/CL/kernels/CLGatherKernel.h @@ -50,7 +50,7 @@ public: /** Initialise the kernel's inputs and outputs * * @param[in] input Source tensor. Supported tensor rank: up to 4. Data type supported: U8/S8/QASYMM8/U16/S16/U32/S32/F16/F32 - * @param[in] indices Indices tensor. Supported tensor rank: up to 1. Must be one of the following type: S64. Each value Must be in range [0, input.shape[@p axis]) + * @param[in] indices Indices tensor. Supported tensor rank: up to 1. Must be one of the following types: U32/S32. Each value must be in range [0, input.shape[@p axis]) * @param[out] output Destination tensor. Data type supported: Same as @p input * @param[in] axis (Optional) The axis in @p input to gather @p indices from. Negative values wrap around. Defaults to 0 */ @@ -59,7 +59,7 @@ public: /** Static function to check if given info will lead to a valid configuration of @ref CLGatherKernel * * @param[in] input Source tensor info. Supported tensor rank: up to 4. Data type supported: U8/S8/QASYMM8/U16/S16/U32/S32/F16/F32 - * @param[in] indices Indices tensor info. Supported tensor rank: up to 4. Must be one of the following type: S64. Each value Must be in range [0, input.shape[@p axis]) + * @param[in] indices Indices tensor info. Supported tensor rank: up to 4. Must be one of the following types: U32/S32. Each value must be in range [0, input.shape[@p axis]) * @param[in] output Destination tensor info. Data type supported: Same as @p input * @param[in] axis (Optional) The axis in @p input to gather @p indices from. Negative values wrap around. Defaults to 0 * diff --git a/arm_compute/runtime/CL/functions/CLGather.h b/arm_compute/runtime/CL/functions/CLGather.h index 048804dfb2..78bf82594a 100644 --- a/arm_compute/runtime/CL/functions/CLGather.h +++ b/arm_compute/runtime/CL/functions/CLGather.h @@ -38,7 +38,7 @@ public: /** Initialise the kernel's inputs and outputs * * @param[in] input Source tensor. Supported tensor rank: up to 4. Data type supported: U8/S8/QASYMM8/U16/S16/U32/S32/F16/F32 - * @param[in] indices Indices tensor. Supported tensor rank: up to 1. Must be one of the following type: S64. Each value Must be in range [0, input.shape[@p axis]) + * @param[in] indices Indices tensor. Supported tensor rank: up to 1. Must be one of the following types: U32/S32. Each value must be in range [0, input.shape[@p axis]) * @param[out] output Destination tensor. Data type supported: Same as @p input * @param[in] axis (Optional) The axis in @p input to gather @p indices from. Defaults to 0 */ @@ -47,7 +47,7 @@ public: /** Static function to check if given info will lead to a valid configuration of @ref CLGatherKernel * * @param[in] input Source tensor info. Supported tensor rank: up to 4. Data type supported: U8/S8/QASYMM8/U16/S16/U32/S32/F16/F32 - * @param[in] indices Indices tensor info. Supported tensor rank: up to 4. Must be one of the following types: S64. Each value Must be in range [0, input.shape[@p axis]) + * @param[in] indices Indices tensor info. Supported tensor rank: up to 4. Must be one of the following types: U32/S32. Each value must be in range [0, input.shape[@p axis]) * @param[in] output Destination tensor info. Data type supported: Same as @p input * @param[in] axis (Optional) The axis in @p input to gather @p indices from. Defaults to 0 * diff --git a/src/core/CL/cl_kernels/gather.cl b/src/core/CL/cl_kernels/gather.cl index 34593ef60f..d6fe52d9d5 100644 --- a/src/core/CL/cl_kernels/gather.cl +++ b/src/core/CL/cl_kernels/gather.cl @@ -42,7 +42,7 @@ * @param[in] input_stride_w Stride of the source tensor in Z dimension (in bytes) * @param[in] input_step_w input_stride_w * number of elements along W processed per work item (in bytes) * @param[in] input_offset_first_element_in_bytes Offset of the first element in the source tensor - * @param[in] indices_ptr Pointer to the indices vector. Supported data types: U32. + * @param[in] indices_ptr Pointer to the indices vector. Supported data types: S32/U32. * @param[in] indices_stride_x Stride of the indices vector in X dimension (in bytes) * @param[in] indices_step_x input_stride_x * number of elements along X processed per work item (in bytes) * @param[in] indices_offset_first_element_in_bytes Offset of the first element in the indices vector diff --git a/src/core/CL/kernels/CLGatherKernel.cpp b/src/core/CL/kernels/CLGatherKernel.cpp index 006e755b30..a4a8808f25 100644 --- a/src/core/CL/kernels/CLGatherKernel.cpp +++ b/src/core/CL/kernels/CLGatherKernel.cpp @@ -60,7 +60,7 @@ inline Status validate_arguments(const ITensorInfo *input, const ITensorInfo *in ARM_COMPUTE_RETURN_ERROR_ON(output_shape.total_size() != output->tensor_shape().total_size()); } - ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(indices, 1, DataType::U32); + ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(indices, 1, DataType::U32, DataType::S32); return Status{}; } -- cgit v1.2.1