aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
author Manuel Bottini <manuel.bottini@arm.com> 2019-01-16 11:18:15 +0000
committer Manuel Bottini <manuel.bottini@arm.com> 2019-01-21 13:54:01 +0000
commit 7ad6257ff09c94aade46ce5d02b644821235121a (patch)
tree 246c3c266413ef431ff2794a46d5666ee92c8438
parent 3de4c73ffee79570c876e191541afc79f143d7a0 (diff)
download ComputeLibrary-7ad6257ff09c94aade46ce5d02b644821235121a.tar.gz
COMPMID-1727: CL: Implement Gather - support S32 indices
Change-Id: Ib0298dc75d25acf8db2262c0ab73ecaa6fec636c
Reviewed-on: https://review.mlplatform.org/522
Reviewed-by: Georgios Pinitas <georgios.pinitas@arm.com>
Tested-by: Arm Jenkins <bsgcomp@arm.com>
-rw-r--r-- arm_compute/core/CL/kernels/CLGatherKernel.h 4
-rw-r--r-- arm_compute/runtime/CL/functions/CLGather.h 4
-rw-r--r-- src/core/CL/cl_kernels/gather.cl 2
-rw-r--r-- src/core/CL/kernels/CLGatherKernel.cpp 2
4 files changed, 6 insertions, 6 deletions
diff --git a/arm_compute/core/CL/kernels/CLGatherKernel.h b/arm_compute/core/CL/kernels/CLGatherKernel.h
index 4dac6b0d1f..a348977767 100644
--- a/arm_compute/core/CL/kernels/CLGatherKernel.h
+++ b/arm_compute/core/CL/kernels/CLGatherKernel.h
@@ -50,7 +50,7 @@ public:
/** Initialise the kernel's inputs and outputs
*
* @param[in] input Source tensor. Supported tensor rank: up to 4. Data type supported: U8/S8/QASYMM8/U16/S16/U32/S32/F16/F32
- * @param[in] indices Indices tensor. Supported tensor rank: up to 1. Must be one of the following type: S64. Each value Must be in range [0, input.shape[@p axis])
+ * @param[in] indices Indices tensor. Supported tensor rank: up to 1. Must be one of the following types: U32/S32. Each value must be in range [0, input.shape[@p axis])
* @param[out] output Destination tensor. Data type supported: Same as @p input
* @param[in] axis (Optional) The axis in @p input to gather @p indices from. Negative values wrap around. Defaults to 0
*/
@@ -59,7 +59,7 @@ public:
/** Static function to check if given info will lead to a valid configuration of @ref CLGatherKernel
*
* @param[in] input Source tensor info. Supported tensor rank: up to 4. Data type supported: U8/S8/QASYMM8/U16/S16/U32/S32/F16/F32
- * @param[in] indices Indices tensor info. Supported tensor rank: up to 4. Must be one of the following type: S64. Each value Must be in range [0, input.shape[@p axis])
+ * @param[in] indices Indices tensor info. Supported tensor rank: up to 4. Must be one of the following types: U32/S32. Each value must be in range [0, input.shape[@p axis])
* @param[in] output Destination tensor info. Data type supported: Same as @p input
* @param[in] axis (Optional) The axis in @p input to gather @p indices from. Negative values wrap around. Defaults to 0
*
diff --git a/arm_compute/runtime/CL/functions/CLGather.h b/arm_compute/runtime/CL/functions/CLGather.h
index 048804dfb2..78bf82594a 100644
--- a/arm_compute/runtime/CL/functions/CLGather.h
+++ b/arm_compute/runtime/CL/functions/CLGather.h
@@ -38,7 +38,7 @@ public:
/** Initialise the kernel's inputs and outputs
*
* @param[in] input Source tensor. Supported tensor rank: up to 4. Data type supported: U8/S8/QASYMM8/U16/S16/U32/S32/F16/F32
- * @param[in] indices Indices tensor. Supported tensor rank: up to 1. Must be one of the following type: S64. Each value Must be in range [0, input.shape[@p axis])
+ * @param[in] indices Indices tensor. Supported tensor rank: up to 1. Must be one of the following types: U32/S32. Each value must be in range [0, input.shape[@p axis])
* @param[out] output Destination tensor. Data type supported: Same as @p input
* @param[in] axis (Optional) The axis in @p input to gather @p indices from. Defaults to 0
*/
@@ -47,7 +47,7 @@ public:
/** Static function to check if given info will lead to a valid configuration of @ref CLGatherKernel
*
* @param[in] input Source tensor info. Supported tensor rank: up to 4. Data type supported: U8/S8/QASYMM8/U16/S16/U32/S32/F16/F32
- * @param[in] indices Indices tensor info. Supported tensor rank: up to 4. Must be one of the following types: S64. Each value Must be in range [0, input.shape[@p axis])
+ * @param[in] indices Indices tensor info. Supported tensor rank: up to 4. Must be one of the following types: U32/S32. Each value must be in range [0, input.shape[@p axis])
* @param[in] output Destination tensor info. Data type supported: Same as @p input
* @param[in] axis (Optional) The axis in @p input to gather @p indices from. Defaults to 0
*
diff --git a/src/core/CL/cl_kernels/gather.cl b/src/core/CL/cl_kernels/gather.cl
index 34593ef60f..d6fe52d9d5 100644
--- a/src/core/CL/cl_kernels/gather.cl
+++ b/src/core/CL/cl_kernels/gather.cl
@@ -42,7 +42,7 @@
* @param[in] input_stride_w Stride of the source tensor in W dimension (in bytes)
* @param[in] input_step_w input_stride_w * number of elements along W processed per work item (in bytes)
* @param[in] input_offset_first_element_in_bytes Offset of the first element in the source tensor
- * @param[in] indices_ptr Pointer to the indices vector. Supported data types: U32.
+ * @param[in] indices_ptr Pointer to the indices vector. Supported data types: S32/U32.
* @param[in] indices_stride_x Stride of the indices vector in X dimension (in bytes)
* @param[in] indices_step_x indices_stride_x * number of elements along X processed per work item (in bytes)
* @param[in] indices_offset_first_element_in_bytes Offset of the first element in the indices vector
diff --git a/src/core/CL/kernels/CLGatherKernel.cpp b/src/core/CL/kernels/CLGatherKernel.cpp
index 006e755b30..a4a8808f25 100644
--- a/src/core/CL/kernels/CLGatherKernel.cpp
+++ b/src/core/CL/kernels/CLGatherKernel.cpp
@@ -60,7 +60,7 @@ inline Status validate_arguments(const ITensorInfo *input, const ITensorInfo *in
ARM_COMPUTE_RETURN_ERROR_ON(output_shape.total_size() != output->tensor_shape().total_size());
}
- ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(indices, 1, DataType::U32);
+ ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(indices, 1, DataType::U32, DataType::S32);
return Status{};
}