author    Mohammed Suhail Munshi <MohammedSuhail.Munshi@arm.com>  2024-03-25 15:55:42 +0000
committer Suhail M <MohammedSuhail.Munshi@arm.com>  2024-04-22 14:44:09 +0000
commit    7377107378d6c26439320fce78a551e85b5ad36a (patch)
tree      3aa9c74c59993f9d51924fc123eefa17e3376a79 /src
parent    5057ce9e1866ffa0388543d81af32083b5b1c684 (diff)
download  ComputeLibrary-7377107378d6c26439320fce78a551e85b5ad36a.tar.gz
Scatter GPU Kernel Implementation for 1D tensors.
Resolves: [COMPMID-6891, COMPMID-6892]
Change-Id: I5b094fff1bff4c4c59cc44f7d6beab0e40133d8e
Signed-off-by: Mohammed Suhail Munshi <MohammedSuhail.Munshi@arm.com>
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/11394
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Gunes Bayir <gunes.bayir@arm.com>
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Benchmark: Arm Jenkins <bsgcomp@arm.com>
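[Editor's note] For readers unfamiliar with the operator: ScatterND restricted to 1D data writes each update value into the destination element selected by the matching index, combining the old and new values with the chosen reduce function. A minimal host-side sketch of these semantics, with illustrative values and names that are not part of the library API:

// Reference loop for the 1D scatter semantics (ADD_OP variant); purely
// illustrative, not library code.
#include <cstdio>
#include <vector>

int main()
{
    std::vector<float> output  = {0.f, 0.f, 0.f, 0.f, 0.f, 0.f}; // dst, OUT_SHAPE_X = 6
    std::vector<int>   indices = {4, 1, 4};                      // S32 indices, INDICES_SHAPE_Y = 3
    std::vector<float> updates = {1.f, 2.f, 3.f};

    for (std::size_t i = 0; i < indices.size(); ++i)
    {
        const int idx = indices[i];
        // Negative or out-of-range indices are skipped, as in the kernel.
        if (idx >= 0 && idx < static_cast<int>(output.size()))
        {
            output[idx] = output[idx] + updates[i]; // ADD_OP; other reduce ops swap this line
        }
    }

    for (float v : output)
    {
        std::printf("%g ", v); // prints: 0 2 0 0 4 0 (index 4 accumulates 1 + 3)
    }
    std::printf("\n");
    return 0;
}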
Diffstat (limited to 'src')
-rw-r--r--  src/core/CL/cl_kernels/common/scatter.cl | 117
-rw-r--r--  src/gpu/cl/ClKernelLibrary.cpp           |   7
-rw-r--r--  src/gpu/cl/kernels/ClScatterKernel.cpp   |  98
-rw-r--r--  src/gpu/cl/kernels/ClScatterKernel.h     |  13
-rw-r--r--  src/gpu/cl/operators/ClScatter.cpp       |  59
-rw-r--r--  src/gpu/cl/operators/ClScatter.h         |   8
6 files changed, 265 insertions, 37 deletions
diff --git a/src/core/CL/cl_kernels/common/scatter.cl b/src/core/CL/cl_kernels/common/scatter.cl
new file mode 100644
index 0000000000..73b714e042
--- /dev/null
+++ b/src/core/CL/cl_kernels/common/scatter.cl
@@ -0,0 +1,117 @@
+/*
+ * Copyright (c) 2024 Arm Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+#include "helpers.h"
+
+#if defined(INDICES_SHAPE_Y) && defined(DATA_TYPE) && defined(OUT_SHAPE_X) && defined(SCATTER_FUNCTION)
+
+// The below defines the various reduce operations for our purposes,
+// where a is the existing value and b is the new value.
+#define ADD_OP(a, b) ((a) + (b))
+#define SUB_OP(a, b) ((a) - (b))
+#define MAX_OP(a, b) fmax(a, b)
+#define MIN_OP(a, b) fmin(a, b)
+#define UPDATE_OP(a, b) (b)
+
+/** Performs the ScatterND operation
+ * @note Datatype should be given as a preprocessor argument using -DDATA_TYPE=type. e.g. -DDATA_TYPE=short
+ * @note The size of the dst tensor in the "x" dimension should be passed using -DOUT_SHAPE_X at compile time.
+ * @note The number of values in the indices tensor in the y-dim should be passed with -DINDICES_SHAPE_Y at compile time.
+ * @note Negative indices are treated as out of bounds.
+ *
+ * @param[in] updates_ptr Pointer to the source tensor. Supported data types: All
+ * @param[in] updates_stride_x Stride of the source tensor in X dimension (in bytes)
+ * @param[in] updates_step_x updates_stride_x * number of elements along X processed per work item (in bytes)
+ * @param[in] updates_stride_y Stride of the source tensor in Y dimension (in bytes)
+ * @param[in] updates_step_y updates_stride_y * number of elements along Y processed per work item (in bytes)
+ * @param[in] updates_stride_z Stride of the source tensor in Z dimension (in bytes)
+ * @param[in] updates_step_z updates_stride_z * number of elements along Z processed per work item (in bytes)
+ * @param[in] updates_stride_w Stride of the source tensor in W dimension (in bytes)
+ * @param[in] updates_step_w updates_stride_w * number of elements along W processed per work item (in bytes)
+ * @param[in] updates_offset_first_element_in_bytes Offset of the first element in the source tensor
+ * @param[in] indices_ptr Pointer to the indices vector. Supported data types: S32.
+ * @param[in] indices_stride_x Stride of the indices vector in X dimension (in bytes)
+ * @param[in] indices_step_x indices_stride_x * number of elements along X processed per work item (in bytes)
+ * @param[in] indices_offset_first_element_in_bytes Offset of the first element in the indices vector
+ * @param[out] output_ptr Pointer to the destination tensor. Supported data types: same as @p updates_ptr
+ * @param[in] output_stride_x Stride of the destination tensor in X dimension (in bytes)
+ * @param[in] output_step_x output_stride_x * number of elements along X processed per work item (in bytes)
+ * @param[in] output_stride_y Stride of the destination tensor in Y dimension (in bytes)
+ * @param[in] output_step_y output_stride_y * number of elements along Y processed per work item (in bytes)
+ * @param[in] output_stride_z Stride of the destination tensor in Z dimension (in bytes)
+ * @param[in] output_step_z output_stride_z * number of elements along Z processed per work item (in bytes)
+ * @param[in] output_stride_w Stride of the destination tensor in W dimension (in bytes)
+ * @param[in] output_step_w output_stride_w * number of elements along W processed per work item (in bytes)
+ * @param[in] output_offset_first_element_in_bytes Offset of the first element in the destination tensor
+ */
+// The below kernel code is expected to be executed sequentially with a single thread to ensure a deterministic outcome.
+__kernel void scatter1D(
+ TENSOR4D_DECLARATION(updates),
+ TENSOR4D_DECLARATION(indices),
+ TENSOR4D_DECLARATION(output))
+{
+ // Currently 1D - only iterate through y dimension of indices.
+ __global int* indices_start_offset = (__global int*)(indices_ptr + indices_offset_first_element_in_bytes);
+ __global DATA_TYPE* updates_start_offset = (__global DATA_TYPE*)(updates_ptr + updates_offset_first_element_in_bytes);
+ __global DATA_TYPE* out_start_offset = (__global DATA_TYPE*)(output_ptr + output_offset_first_element_in_bytes);
+ for (int px = 0; px < INDICES_SHAPE_Y; px++)
+ {
+ const int index_value = *(indices_start_offset);
+ __global DATA_TYPE* out_addr = out_start_offset + index_value;
+ if((index_value < OUT_SHAPE_X) && (index_value >= 0))
+ {
+ *out_addr = SCATTER_FUNCTION(*out_addr, *updates_start_offset);
+ }
+ // Increment pointers.
+ indices_start_offset++;
+ updates_start_offset++;
+ }
+}
+
+#endif //defined(DATA_TYPE) && defined(SCATTER_FUNCTION) && defined(OUT_SHAPE_X) && defined(INDICES_SHAPE_Y)
+
+#if defined(DATA_TYPE) && defined(SCATTER_FUNCTION) && defined(OUT_SHAPE_X) && !defined(INDICES_SHAPE_Y)
+
+// NOTE : This code is non-deterministic and can only be executed with the "update" ScatterFunction
+// This code is currently unused as it requires changes to the existing test suite.
+/** Performs the Scatter1D operation with multiple threads.
+ * Similar to @ref scatter1D()
+ */
+__kernel void scatter1D_parallel(
+ TENSOR4D_DECLARATION(updates),
+ TENSOR4D_DECLARATION(indices),
+ TENSOR4D_DECLARATION(output))
+{
+ // Currently 1D - only iterate through x dimension of indices.
+ const int px = get_global_id(0);
+ const int index_value = *(__global int*)(indices_ptr + indices_offset_first_element_in_bytes + (sizeof(int) * px));
+
+ if((index_value < OUT_SHAPE_X) && (index_value >= 0))
+ {
+ const DATA_TYPE update = *(__global DATA_TYPE *)(updates_ptr + updates_offset_first_element_in_bytes + (sizeof(DATA_TYPE) * px));
+ __global uchar *out_addr = output_ptr + output_offset_first_element_in_bytes + (sizeof(DATA_TYPE) * index_value);
+ *(__global DATA_TYPE *)(out_addr) = update;
+ }
+}
+
+#endif //defined(DATA_TYPE) && defined(SCATTER_FUNCTION) && defined(OUT_SHAPE_X) && !defined(INDICES_SHAPE_Y)
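[Editor's note] The kernel file is specialised at runtime through preprocessor definitions: defining INDICES_SHAPE_Y compiles the sequential scatter1D kernel, while leaving it undefined (with the other guards set) would compile the currently unused scatter1D_parallel variant. A hedged sketch of the option string the host side assembles; the concrete values are illustrative, and the real construction is in ClScatterKernel::configure() below:

#include <string>

// Example -D options for a dst of 6 elements, 3 indices, and an Add reduction.
// Each value is derived from the tensor shapes at configure() time.
const std::string build_opts = "-DDATA_TYPE=float "         // get_cl_type_from_data_type(dst->data_type())
                               "-DOUT_SHAPE_X=6 "           // dst->tensor_shape().x()
                               "-DINDICES_SHAPE_Y=3 "       // indices->tensor_shape().y(); selects scatter1D
                               "-DSCATTER_FUNCTION=ADD_OP"; // from ScatterFunction::Add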
diff --git a/src/gpu/cl/ClKernelLibrary.cpp b/src/gpu/cl/ClKernelLibrary.cpp
index 4544a66e39..3e32a27d03 100644
--- a/src/gpu/cl/ClKernelLibrary.cpp
+++ b/src/gpu/cl/ClKernelLibrary.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2016-2023 Arm Limited.
+ * Copyright (c) 2016-2024 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -441,6 +441,7 @@ const std::map<std::string, std::string> ClKernelLibrary::_kernel_program_map =
{"reorg_layer_nhwc", "nhwc/reorg_layer.cl"},
{"scale_nearest_neighbour_nhwc", "nhwc/scale.cl"},
{"scale_bilinear_nhwc", "nhwc/scale.cl"},
+ {"scatter1D", "common/scatter.cl"},
{"space_to_batch_nhwc", "nhwc/space_to_batch.cl"},
{"space_to_batch_static_nhwc", "nhwc/space_to_batch.cl"},
{"space_to_depth_nhwc", "nhwc/space_to_depth.cl"},
@@ -591,6 +592,10 @@ const std::map<std::string, std::string> ClKernelLibrary::_program_source_map =
#include "./cl_kernels/common/gather.clembed"
},
{
+ "common/scatter.cl",
+#include "./cl_kernels/common/scatter.clembed"
+ },
+ {
"common/gemm.cl",
#include "./cl_kernels/common/gemm.clembed"
},
diff --git a/src/gpu/cl/kernels/ClScatterKernel.cpp b/src/gpu/cl/kernels/ClScatterKernel.cpp
index 720164366e..c95e156679 100644
--- a/src/gpu/cl/kernels/ClScatterKernel.cpp
+++ b/src/gpu/cl/kernels/ClScatterKernel.cpp
@@ -26,6 +26,11 @@
#include "arm_compute/core/CL/ICLTensor.h"
#include "arm_compute/core/ITensorPack.h"
#include "arm_compute/core/TensorInfo.h"
+#include "arm_compute/core/Utils.h"
+
+#include "src/common/utils/Log.h"
+#include "src/core/helpers/WindowHelpers.h"
+#include "support/Cast.h"
namespace arm_compute
{
@@ -37,40 +42,101 @@ ClScatterKernel::ClScatterKernel()
{
}
-Status ClScatterKernel::validate(const ITensorInfo *src,
- const ITensorInfo *updates,
+Status ClScatterKernel::validate(const ITensorInfo *updates,
const ITensorInfo *indices,
const ITensorInfo *dst,
const ScatterInfo &info)
{
- ARM_COMPUTE_UNUSED(src);
- ARM_COMPUTE_UNUSED(updates);
- ARM_COMPUTE_UNUSED(indices);
- ARM_COMPUTE_UNUSED(dst);
+ ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(updates, dst);
+ ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_NOT_IN(indices, DataType::S32);
+ ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_NOT_IN(dst, DataType::F32);
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(dst->num_dimensions() > 1, "Only 1D output tensors are currently supported.");
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(indices->num_dimensions() > 2, "Only 2D indices tensors are currently supported.");
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(updates->num_dimensions() > 1, "Only 1D update tensors are currently supported.");
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(
+ indices->tensor_shape().y() != updates->tensor_shape()[updates->num_dimensions() - 1],
+ "Height of indices tensor should match size of highest dimension in updates tensor.");
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(updates->num_dimensions() > dst->num_dimensions(),
+ "Update tensor cannot have more dims than output tensor.");
ARM_COMPUTE_UNUSED(info);
return Status{};
}
void ClScatterKernel::configure(const ClCompileContext &compile_context,
- const ITensorInfo *src,
const ITensorInfo *updates,
const ITensorInfo *indices,
ITensorInfo *dst,
const ScatterInfo &info)
{
- ARM_COMPUTE_UNUSED(compile_context);
- ARM_COMPUTE_UNUSED(src);
- ARM_COMPUTE_UNUSED(updates);
- ARM_COMPUTE_UNUSED(indices);
- ARM_COMPUTE_UNUSED(dst);
- ARM_COMPUTE_UNUSED(info);
+ ARM_COMPUTE_ERROR_ON_NULLPTR(updates, dst, indices);
+ ARM_COMPUTE_LOG_PARAMS(updates, indices, dst, info);
+
+ // Configure kernel window
+ const auto indices_shape = indices->tensor_shape();
+ Window win = calculate_max_window(
+ *indices, Steps(indices_shape.x(), indices_shape.y())); // Ensures single thread for deterministic output.
+
+ // Set build options
+ CLBuildOptions build_opts;
+ build_opts.add_option("-DDATA_TYPE=" + get_cl_type_from_data_type(dst->data_type()));
+ build_opts.add_option("-DINDICES_DIMS=" + support::cpp11::to_string(indices->num_dimensions()));
+ build_opts.add_option("-DINDICES_SHAPE_Y=" + support::cpp11::to_string(indices_shape.y()));
+ build_opts.add_option("-DOUT_SHAPE_X=" + support::cpp11::to_string(dst->tensor_shape().x()));
+
+ switch (info.func)
+ {
+ case ScatterFunction::Update:
+ build_opts.add_option("-DSCATTER_FUNCTION=UPDATE_OP");
+ break;
+ case ScatterFunction::Add:
+ build_opts.add_option("-DSCATTER_FUNCTION=ADD_OP");
+ break;
+ case ScatterFunction::Sub:
+ build_opts.add_option("-DSCATTER_FUNCTION=SUB_OP");
+ break;
+ case ScatterFunction::Max:
+ build_opts.add_option("-DSCATTER_FUNCTION=MAX_OP");
+ break;
+ case ScatterFunction::Min:
+ build_opts.add_option("-DSCATTER_FUNCTION=MIN_OP");
+ break;
+ default:
+ ARM_COMPUTE_ERROR("Not implemented");
+ }
+
+ // Create kernel
+ std::string kernel_name("scatter1D");
+ ICLKernel::configure_internal(win);
+ _kernel = create_kernel(compile_context, kernel_name, build_opts.options());
+ // Set config_id for enabling LWS tuning
+ _config_id = kernel_name;
+ _config_id += "_";
+ _config_id += lower_string(string_from_data_type(updates->data_type()));
+ _config_id += "_";
+ _config_id += support::cpp11::to_string(dst->dimension(1));
+ _config_id += "_";
+ _config_id += support::cpp11::to_string(dst->dimension(0));
+ _config_id += "_";
+ _config_id += support::cpp11::to_string(dst->dimension(2));
+ _config_id += "_";
}
void ClScatterKernel::run_op(ITensorPack &tensors, const Window &window, cl::CommandQueue &queue)
{
- ARM_COMPUTE_UNUSED(tensors);
- ARM_COMPUTE_UNUSED(window);
- ARM_COMPUTE_UNUSED(queue);
+ unsigned int idx = 0;
+
+ Window window_collapsed = window.collapse_if_possible(ICLKernel::window(), Window::DimZ);
+
+ const auto updates =
+ utils::cast::polymorphic_downcast<const ICLTensor *>(tensors.get_const_tensor(TensorType::ACL_SRC_0));
+ const auto indices =
+ utils::cast::polymorphic_downcast<const ICLTensor *>(tensors.get_const_tensor(TensorType::ACL_SRC_1));
+ auto dst = utils::cast::polymorphic_downcast<ICLTensor *>(tensors.get_tensor(TensorType::ACL_DST));
+ add_4D_tensor_argument(idx, updates, window_collapsed);
+ add_4D_tensor_argument(idx, indices, window_collapsed);
+ add_4D_tensor_argument(idx, dst, window_collapsed);
+
+ enqueue(queue, *this, window, lws_hint());
}
} // namespace kernels
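[Editor's note] The reworked validate() can now be exercised without a source tensor. Below is a hedged sketch of a call that satisfies the constraints enforced above (1D updates and dst, 2D S32 indices whose y-size matches the update count, F32 data). The shapes are illustrative, the header is internal to the library, and ScatterInfo is assumed to take the reduce function plus a zero-initialisation flag:

#include "arm_compute/core/TensorInfo.h"
#include "src/gpu/cl/kernels/ClScatterKernel.h" // internal header; sketch only

using namespace arm_compute;

const TensorInfo updates(TensorShape(3U), 1, DataType::F32);
const TensorInfo indices(TensorShape(1U, 3U), 1, DataType::S32); // x = 1 index per row, y() = 3 rows matches updates
const TensorInfo dst(TensorShape(6U), 1, DataType::F32);

// ScatterInfo is assumed here to carry the reduce function and a zero-init flag.
const Status s = opencl::kernels::ClScatterKernel::validate(&updates, &indices, &dst,
                                                            ScatterInfo(ScatterFunction::Add, true));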
diff --git a/src/gpu/cl/kernels/ClScatterKernel.h b/src/gpu/cl/kernels/ClScatterKernel.h
index dda614ff3e..d2a41adde9 100644
--- a/src/gpu/cl/kernels/ClScatterKernel.h
+++ b/src/gpu/cl/kernels/ClScatterKernel.h
@@ -44,15 +44,15 @@ public:
ARM_COMPUTE_DISALLOW_COPY_ALLOW_MOVE(ClScatterKernel);
/** Initialise the kernel's input and output.
*
+ * @note Negative indices are treated as out of bounds.
+ *
* @param[in] compile_context The compile context to be used.
- * @param[in] src Input tensor info for the source matrix.
* @param[in] updates Input tensor info for the Update matrix. Data type supported: F32
- * @param[in] indices Input tensor info for the Indices matrix. Data type supported: U32.
+ * @param[in] indices Input tensor info for the Indices matrix. Data type supported: S32.
* @param[out] dst Output tensor info. Data type supported: same as @p updates
* @param[in] info Attributes for Scatter Kernel
*/
void configure(const ClCompileContext &compile_context,
- const ITensorInfo *src,
const ITensorInfo *updates,
const ITensorInfo *indices,
ITensorInfo *dst,
@@ -63,11 +63,8 @@ public:
*
* @return a status
*/
- static Status validate(const ITensorInfo *src,
- const ITensorInfo *updates,
- const ITensorInfo *indices,
- const ITensorInfo *dst,
- const ScatterInfo &info);
+ static Status
+ validate(const ITensorInfo *updates, const ITensorInfo *indices, const ITensorInfo *dst, const ScatterInfo &info);
// Inherited methods overridden:
void run_op(ITensorPack &tensors, const Window &window, cl::CommandQueue &queue) override;
diff --git a/src/gpu/cl/operators/ClScatter.cpp b/src/gpu/cl/operators/ClScatter.cpp
index af5fbb86f3..62711ddfe8 100644
--- a/src/gpu/cl/operators/ClScatter.cpp
+++ b/src/gpu/cl/operators/ClScatter.cpp
@@ -27,6 +27,7 @@
#include "arm_compute/runtime/CL/CLScheduler.h"
#include "src/common/utils/Log.h"
+#include "src/gpu/cl/kernels/ClCopyKernel.h"
#include "src/gpu/cl/kernels/ClFillKernel.h"
#include "src/gpu/cl/kernels/ClScatterKernel.h"
@@ -47,9 +48,21 @@ Status ClScatter::validate(const ITensorInfo *src,
const ScatterInfo &info)
{
ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(updates, indices, dst);
- ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(indices, 1, DataType::U32);
+ ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(indices, 1, DataType::S32);
+ ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_NOT_IN(dst, DataType::F32); // Currently, other data types are not supported.
+ if (src != nullptr)
+ {
+ // Check dst/src are same shape and datatype.
+ ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DIMENSIONS(src->tensor_shape(), dst->tensor_shape());
+ ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(src, updates, dst);
+ ARM_COMPUTE_RETURN_ON_ERROR(kernels::ClCopyKernel::validate(src, dst)); // Validate Copy kernel
+ }
+ if (src != dst)
+ {
+ ARM_COMPUTE_RETURN_ON_ERROR(kernels::ClFillKernel::validate(dst, PixelValue(0.0f))); // Validate Fill kernel.
+ }
- return kernels::ClScatterKernel::validate(src, updates, indices, dst, info);
+ return kernels::ClScatterKernel::validate(updates, indices, dst, info);
}
void ClScatter::configure(const CLCompileContext &compile_context,
@@ -61,11 +74,6 @@ void ClScatter::configure(const CLCompileContext &compile_context,
{
ARM_COMPUTE_ERROR_ON_NULLPTR(updates, indices, dst);
ARM_COMPUTE_LOG_PARAMS(src, indices, dst, info);
- ARM_COMPUTE_UNUSED(src);
- ARM_COMPUTE_UNUSED(updates);
- ARM_COMPUTE_UNUSED(indices);
- ARM_COMPUTE_UNUSED(dst);
- ARM_COMPUTE_UNUSED(info);
// Perform validation step
ARM_COMPUTE_ERROR_THROW_ON(validate(src, updates, indices, dst, info));
@@ -74,19 +82,50 @@ void ClScatter::configure(const CLCompileContext &compile_context,
// If necessary, create fill kernel to fill dst tensor.
if (_fill_zero)
{
- _fill_kernel = std::make_unique<kernels::ClFillKernel>();
+ auto f = std::make_unique<kernels::ClFillKernel>();
+ f->configure(compile_context, dst, PixelValue(0.0f));
+ _fill_kernel = std::move(f);
+ }
+ else if (src != dst) // Check whether copying is necessary
+ {
+ // Fill dst with src copy here.
+ auto j = std::make_unique<kernels::ClCopyKernel>();
+ j->configure(compile_context, src, dst);
+ _copy_kernel = std::move(j);
+ _run_copy = true;
}
// Configure ClScatterKernel
auto k = std::make_unique<kernels::ClScatterKernel>();
k->set_target(CLScheduler::get().target());
- k->configure(compile_context, src, updates, indices, dst, info);
+ k->configure(compile_context, updates, indices, dst, info);
_scatter_kernel = std::move(k);
}
void ClScatter::run(ITensorPack &tensors)
{
- ARM_COMPUTE_UNUSED(tensors);
+ // Get tensors.
+ auto src = tensors.get_const_tensor(ACL_SRC_0);
+ auto updates = tensors.get_const_tensor(ACL_SRC_1);
+ auto indices = tensors.get_const_tensor(ACL_SRC_2);
+ auto dst = tensors.get_tensor(ACL_DST);
+
+ if (_fill_zero)
+ {
+ // Fill destination tensor with 0 values if zero init.
+ ITensorPack fill_pack{{ACL_SRC, dst}};
+ CLScheduler::get().enqueue_op(*_fill_kernel, fill_pack, false);
+ }
+
+ if (_run_copy)
+ {
+ // Copy src to dst before the scatter op.
+ ITensorPack copy_pack{{ACL_SRC, src}, {ACL_DST, dst}};
+ CLScheduler::get().enqueue_op(*_copy_kernel, copy_pack, false);
+ }
+
+ ITensorPack scatter_pack{{ACL_SRC_0, updates}, {ACL_SRC_1, indices}, {ACL_DST, dst}};
+ CLScheduler::get().enqueue_op(*_scatter_kernel, scatter_pack, false);
}
} // namespace opencl
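[Editor's note] At runtime the operator enqueues at most two kernels: a fill (zero-initialisation) or a copy of src into dst, followed by the scatter itself. A hedged sketch of how a caller would assemble the tensor pack read by run() above, assuming the ICLTensor pointers are created and allocated elsewhere:

// Slots match ClScatter::run(): ACL_SRC_0 = src, ACL_SRC_1 = updates,
// ACL_SRC_2 = indices, ACL_DST = dst.
arm_compute::ITensorPack pack;
pack.add_const_tensor(arm_compute::TensorType::ACL_SRC_0, src); // may be nullptr for zero-init scatter
pack.add_const_tensor(arm_compute::TensorType::ACL_SRC_1, updates);
pack.add_const_tensor(arm_compute::TensorType::ACL_SRC_2, indices);
pack.add_tensor(arm_compute::TensorType::ACL_DST, dst);
scatter.run(pack); // fill or copy is enqueued first, then the scatter kernel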
diff --git a/src/gpu/cl/operators/ClScatter.h b/src/gpu/cl/operators/ClScatter.h
index 433f7ca3a4..a1b32fed45 100644
--- a/src/gpu/cl/operators/ClScatter.h
+++ b/src/gpu/cl/operators/ClScatter.h
@@ -39,6 +39,7 @@ namespace opencl
// Forward declaration
class ClFillKernel;
class ClScatterKernel;
+class ClCopyKernel;
/** Basic operator to execute Scatter on OpenCL. This operator calls the following OpenCL kernels:
*
@@ -56,13 +57,14 @@ public:
* Valid data layouts:
* - All
*
- * @note indices must always be U32
+ * @note indices must always be S32.
+ * @note Negative indices are treated as out of bounds.
* @note src, updates and dst tensors must be the same data type.
*
* @param[in] compile_context The compile context to be used.
* @param[in] src Source input tensor info. Can be nullptr when using "Add" Scatter Function with zero initialization.
* @param[in] updates Tensor info for tensor storing update values to use for scatter function. Data types supported: same as @p src.
- * @param[in] indices Tensor info for tensor storing indices to use for scatter function. Data types supported: U32 only.
+ * @param[in] indices Tensor info for tensor storing indices to use for scatter function. Data types supported: S32 only.
* @param[out] dst Output tensor to store the result of the Scatter Function. Data types supported: same as @p src and @p updates.
* @param[in] info Contains Scatter operation information described in @ref ScatterInfo.
*/
@@ -89,7 +91,9 @@ public:
private:
std::unique_ptr<opencl::IClKernel> _scatter_kernel{nullptr};
std::unique_ptr<opencl::IClKernel> _fill_kernel{nullptr};
+ std::unique_ptr<opencl::IClKernel> _copy_kernel{nullptr};
bool _fill_zero{false};
+ bool _run_copy{false};
};
} // namespace opencl
} // namespace arm_compute
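[Editor's note] As the updated documentation notes, src may be nullptr when the destination is zero-initialised before scattering. A hedged end-to-end sketch of that configuration path (tensor creation and allocation omitted; the compile-context accessor and the ScatterInfo constructor with a zero-initialisation flag are assumptions about the surrounding API):

using namespace arm_compute;

opencl::ClScatter scatter;
scatter.configure(CLKernelLibrary::get().get_compile_context(),
                  nullptr,          // src omitted: dst is zero-initialised by the fill kernel
                  updates->info(),  // 1D F32 updates
                  indices->info(),  // 2D S32 indices
                  dst->info(),      // 1D F32 destination
                  ScatterInfo(ScatterFunction::Add, true)); // assumed zero-init flag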