author    Mohammed Suhail Munshi <MohammedSuhail.Munshi@arm.com>  2024-03-25 15:55:42 +0000
committer Suhail M <MohammedSuhail.Munshi@arm.com>  2024-04-22 14:44:09 +0000
commit    7377107378d6c26439320fce78a551e85b5ad36a (patch)
tree      3aa9c74c59993f9d51924fc123eefa17e3376a79
parent    5057ce9e1866ffa0388543d81af32083b5b1c684 (diff)
Scatter GPU Kernel Implementation for 1D tensors.
Resolves: [COMPMID-6891, COMPMID-6892]

Change-Id: I5b094fff1bff4c4c59cc44f7d6beab0e40133d8e
Signed-off-by: Mohammed Suhail Munshi <MohammedSuhail.Munshi@arm.com>
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/11394
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Gunes Bayir <gunes.bayir@arm.com>
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Benchmark: Arm Jenkins <bsgcomp@arm.com>
-rw-r--r--Android.bp1
-rw-r--r--SConscript1
-rw-r--r--arm_compute/runtime/CL/functions/CLScatter.h7
-rw-r--r--src/core/CL/cl_kernels/common/scatter.cl117
-rw-r--r--src/gpu/cl/ClKernelLibrary.cpp7
-rw-r--r--src/gpu/cl/kernels/ClScatterKernel.cpp98
-rw-r--r--src/gpu/cl/kernels/ClScatterKernel.h13
-rw-r--r--src/gpu/cl/operators/ClScatter.cpp59
-rw-r--r--src/gpu/cl/operators/ClScatter.h8
-rw-r--r--tests/datasets/ScatterDataset.h4
-rw-r--r--tests/validation/CL/ScatterLayer.cpp38
-rw-r--r--tests/validation/fixtures/ScatterLayerFixture.h19
-rw-r--r--tests/validation/reference/ScatterLayer.cpp10
-rw-r--r--tests/validation/reference/ScatterLayer.h4
14 files changed, 309 insertions, 77 deletions
diff --git a/Android.bp b/Android.bp
index 6cc85f1928..ab554a8ca2 100644
--- a/Android.bp
+++ b/Android.bp
@@ -65,6 +65,7 @@ opencl_srcs = [
"src/core/CL/cl_kernels/common/roi_align_layer.cl",
"src/core/CL/cl_kernels/common/roi_align_layer_quantized.cl",
"src/core/CL/cl_kernels/common/roi_pooling_layer.cl",
+ "src/core/CL/cl_kernels/common/scatter.cl",
"src/core/CL/cl_kernels/common/select.cl",
"src/core/CL/cl_kernels/common/slice_ops.cl",
"src/core/CL/cl_kernels/common/softmax_layer.cl",
diff --git a/SConscript b/SConscript
index c1eef44ebe..80aa87cae8 100644
--- a/SConscript
+++ b/SConscript
@@ -429,6 +429,7 @@ if env['opencl'] and env['embed_kernels']:
'src/core/CL/cl_kernels/common/fill_border.cl',
'src/core/CL/cl_kernels/common/floor.cl',
'src/core/CL/cl_kernels/common/gather.cl',
+ 'src/core/CL/cl_kernels/common/scatter.cl',
'src/core/CL/cl_kernels/common/gemm.cl',
'src/core/CL/cl_kernels/common/gemm_reshaped_only_rhs_mmul.cl',
'src/core/CL/cl_kernels/common/gemm_utils.cl',
diff --git a/arm_compute/runtime/CL/functions/CLScatter.h b/arm_compute/runtime/CL/functions/CLScatter.h
index 1c90d208bd..973953624e 100644
--- a/arm_compute/runtime/CL/functions/CLScatter.h
+++ b/arm_compute/runtime/CL/functions/CLScatter.h
@@ -55,14 +55,15 @@ public:
~CLScatter();
/** Initialise the kernel's inputs and outputs
*
+ * @note Negative indices are treated as out of bounds.
+ *
* Valid data layouts:
* - All
*
- *
* @param[in] compile_context The compile context to be used.
* @param[in] src Source tensor. Values used to fill output. Can be nullptr when zero initialization is true.
* @param[in] updates Tensor containing values used to update output tensor. Data types supported: same as @p src
- * @param[in] indices Tensor containing Indices to change in the output Tensor. Data types supported : U32
+ * @param[in] indices Tensor containing Indices to change in the output Tensor. Data types supported : S32
* @param[out] output Destination tensor. Data types supported: same as @p src.
* @param[in] info Scatter info object.
*/
@@ -85,7 +86,7 @@ public:
*
* @param[in] src Source tensor.
* @param[in] updates Tensor containing values used for updating the output Tensor. Data types supported : same as @p src
- * @param[in] indices Tensor containing Indices to change in the output Tensor. Data types supported : U32
+ * @param[in] indices Tensor containing Indices to change in the output Tensor. Data types supported : S32
* @param[in] output Destination tensor. Data types supported: same as @p src.
* @param[in] info Scatter info containing type of scatter.
*
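For context, a minimal sketch of how the updated runtime API might be driven (illustrative only: the tensor shapes, the fill step and the configure() overload without a compile context are assumptions, not part of this patch):

    #include "arm_compute/core/TensorInfo.h"
    #include "arm_compute/runtime/CL/CLScheduler.h"
    #include "arm_compute/runtime/CL/CLTensor.h"
    #include "arm_compute/runtime/CL/functions/CLScatter.h"

    using namespace arm_compute;

    int main()
    {
        CLScheduler::get().default_init();

        // 1D scatter: 10-element F32 destination, 2 updates, 2 indices.
        CLTensor src, updates, indices, output;
        src.allocator()->init(TensorInfo(TensorShape(10U), 1, DataType::F32));
        updates.allocator()->init(TensorInfo(TensorShape(2U), 1, DataType::F32));
        indices.allocator()->init(TensorInfo(TensorShape(1U, 2U), 1, DataType::S32)); // S32 now, not U32
        output.allocator()->init(TensorInfo(TensorShape(10U), 1, DataType::F32));

        CLScatter scatter;
        scatter.configure(&src, &updates, &indices, &output, ScatterInfo(ScatterFunction::Add, false));

        src.allocator()->allocate();
        updates.allocator()->allocate();
        indices.allocator()->allocate();
        output.allocator()->allocate();
        // ... map and fill src, updates and indices here ...

        scatter.run();
        return 0;
    }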
diff --git a/src/core/CL/cl_kernels/common/scatter.cl b/src/core/CL/cl_kernels/common/scatter.cl
new file mode 100644
index 0000000000..73b714e042
--- /dev/null
+++ b/src/core/CL/cl_kernels/common/scatter.cl
@@ -0,0 +1,117 @@
+/*
+ * Copyright (c) 2024 Arm Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+#include "helpers.h"
+
+#if defined(INDICES_SHAPE_Y) && defined(DATA_TYPE) && defined(OUT_SHAPE_X) && defined(SCATTER_FUNCTION)
+
+// The defines below implement the reduce operations used by this kernel,
+// where a is the existing value and b is the new value.
+#define ADD_OP(a, b) ((a) + (b))
+#define SUB_OP(a, b) ((a) - (b))
+#define MAX_OP(a, b) fmax(a, b)
+#define MIN_OP(a, b) fmin(a, b)
+#define UPDATE_OP(a, b) (b)
+
+/** Performs the ScatterND operation
+ * @note Datatype should be given as a preprocessor argument using -DDATA_TYPE=type. e.g. -DDATA_TYPE=short
+ * @note The size of the dst tensor in the "x" dimension should be passed using -DOUT_SHAPE_X at compile time.
+ * @note The number of values in the indices tensor in the y-dim should be passed with -DINDICES_SHAPE_Y at compile time.
+ * @note Negative indices are treated as out of bounds.
+ *
+ * @param[in] updates_ptr Pointer to the updates tensor. Supported data types: All
+ * @param[in] updates_stride_x Stride of the updates tensor in X dimension (in bytes)
+ * @param[in] updates_step_x updates_stride_x * number of elements along X processed per work item (in bytes)
+ * @param[in] updates_stride_y Stride of the updates tensor in Y dimension (in bytes)
+ * @param[in] updates_step_y updates_stride_y * number of elements along Y processed per work item (in bytes)
+ * @param[in] updates_stride_z Stride of the updates tensor in Z dimension (in bytes)
+ * @param[in] updates_step_z updates_stride_z * number of elements along Z processed per work item (in bytes)
+ * @param[in] updates_stride_w Stride of the updates tensor in W dimension (in bytes)
+ * @param[in] updates_step_w updates_stride_w * number of elements along W processed per work item (in bytes)
+ * @param[in] updates_offset_first_element_in_bytes Offset of the first element in the updates tensor
+ * @param[in] indices_ptr Pointer to the indices vector. Supported data types: S32.
+ * @param[in] indices_stride_x Stride of the indices vector in X dimension (in bytes)
+ * @param[in] indices_step_x indices_stride_x * number of elements along X processed per work item (in bytes)
+ * @param[in] indices_offset_first_element_in_bytes Offset of the first element in the indices vector
+ * @param[out] output_ptr Pointer to the destination tensor. Supported data types: same as @p updates_ptr
+ * @param[in] output_stride_x Stride of the destination tensor in X dimension (in bytes)
+ * @param[in] output_step_x output_stride_x * number of elements along X processed per work item (in bytes)
+ * @param[in] output_stride_y Stride of the destination tensor in Y dimension (in bytes)
+ * @param[in] output_step_y output_stride_y * number of elements along Y processed per work item (in bytes)
+ * @param[in] output_stride_z Stride of the destination tensor in Z dimension (in bytes)
+ * @param[in] output_step_z output_stride_z * number of elements along Z processed per work item (in bytes)
+ * @param[in] output_stride_w Stride of the destination tensor in W dimension (in bytes)
+ * @param[in] output_step_w output_stride_w * number of elements along W processed per work item (in bytes)
+ * @param[in] output_offset_first_element_in_bytes Offset of the first element in the destination tensor
+ */
+// The kernel code below is expected to be executed sequentially by a single thread to ensure a deterministic outcome.
+__kernel void scatter1D(
+ TENSOR4D_DECLARATION(updates),
+ TENSOR4D_DECLARATION(indices),
+ TENSOR4D_DECLARATION(output))
+{
+ // Currently 1D - only iterate through y dimension of indices.
+ __global int *indices_start_offset = (__global int *)(indices_ptr + indices_offset_first_element_in_bytes);
+ __global DATA_TYPE *updates_start_offset = (__global DATA_TYPE *)(updates_ptr + updates_offset_first_element_in_bytes);
+ __global DATA_TYPE *out_start_offset = (__global DATA_TYPE *)(output_ptr + output_offset_first_element_in_bytes);
+ for (int px = 0; px < INDICES_SHAPE_Y; px++)
+ {
+ const int index_value = *(indices_start_offset);
+ __global DATA_TYPE *out_addr = out_start_offset + index_value;
+ if((index_value < OUT_SHAPE_X) && (index_value >= 0))
+ {
+ *out_addr = SCATTER_FUNCTION(*out_addr, *updates_start_offset);
+ }
+ // Increment pointers.
+ indices_start_offset++;
+ updates_start_offset++;
+ }
+}
+
+#endif //defined(DATA_TYPE) && defined(SCATTER_FUNCTION) && defined(OUT_SHAPE_X) && defined(INDICES_SHAPE_Y)
+
+#if defined(DATA_TYPE) && defined(SCATTER_FUNCTION) && defined(OUT_SHAPE_X) && !defined(INDICES_SHAPE_Y)
+
+// NOTE : This code is non-deterministic and can only be executed with the "update" ScatterFunction
+// This code is currently unused as it requires changes to the existing test suite.
+/** Performs the Scatter1D operation with multiple threads.
+ * Similar to @ref scatter1D()
+ */
+__kernel void scatter1D_parallel(
+ TENSOR4D_DECLARATION(updates),
+ TENSOR4D_DECLARATION(indices),
+ TENSOR4D_DECLARATION(output))
+{
+ // Currently 1D - only iterate through x dimension of indices.
+ const int px = get_global_id(0);
+ const int index_value = *(__global int *)(indices_ptr + indices_offset_first_element_in_bytes + (sizeof(int) * px));
+
+ if((index_value >= 0) && (index_value < OUT_SHAPE_X))
+ {
+ const DATA_TYPE update = *(__global DATA_TYPE *)(updates_ptr + updates_offset_first_element_in_bytes + (sizeof(DATA_TYPE) * px));
+ __global uchar *out_addr = output_ptr + output_offset_first_element_in_bytes + (sizeof(DATA_TYPE) * index_value);
+ *(__global DATA_TYPE *)(out_addr) = update;
+ }
+}
+
+#endif //defined(DATA_TYPE) && defined(SCATTER_FUNCTION) && defined(OUT_SHAPE_X) && !defined(INDICES_SHAPE_Y)
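For reference, the sequential scatter1D kernel above has the same semantics as this host-side sketch (illustration only; OUT_SHAPE_X and SCATTER_FUNCTION come from the -D build options, specialised here to ADD_OP on float):

    // Host-side equivalent of scatter1D built with -DSCATTER_FUNCTION=ADD_OP.
    void scatter1d_add(const float *updates, const int *indices, float *out,
                       int num_indices, int out_shape_x)
    {
        for (int i = 0; i < num_indices; ++i)
        {
            const int idx = indices[i];
            if (idx >= 0 && idx < out_shape_x) // negative indices are out of bounds
            {
                out[idx] = out[idx] + updates[i]; // ADD_OP(a, b)
            }
        }
    }

For example, with out = {0, 0, 0, 0}, indices = {1, 3, 1} and updates = {5, 2, 7}, the result is {0, 12, 0, 2}. Running the same inputs across multiple work items would make the two writes to index 1 race, which is why all functions other than "update" need the single-threaded kernel.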
diff --git a/src/gpu/cl/ClKernelLibrary.cpp b/src/gpu/cl/ClKernelLibrary.cpp
index 4544a66e39..3e32a27d03 100644
--- a/src/gpu/cl/ClKernelLibrary.cpp
+++ b/src/gpu/cl/ClKernelLibrary.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2016-2023 Arm Limited.
+ * Copyright (c) 2016-2024 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -441,6 +441,7 @@ const std::map<std::string, std::string> ClKernelLibrary::_kernel_program_map =
{"reorg_layer_nhwc", "nhwc/reorg_layer.cl"},
{"scale_nearest_neighbour_nhwc", "nhwc/scale.cl"},
{"scale_bilinear_nhwc", "nhwc/scale.cl"},
+ {"scatter1D", "common/scatter.cl"},
{"space_to_batch_nhwc", "nhwc/space_to_batch.cl"},
{"space_to_batch_static_nhwc", "nhwc/space_to_batch.cl"},
{"space_to_depth_nhwc", "nhwc/space_to_depth.cl"},
@@ -591,6 +592,10 @@ const std::map<std::string, std::string> ClKernelLibrary::_program_source_map =
#include "./cl_kernels/common/gather.clembed"
},
{
+ "common/scatter.cl",
+#include "./cl_kernels/common/scatter.clembed"
+ },
+ {
"common/gemm.cl",
#include "./cl_kernels/common/gemm.clembed"
},
diff --git a/src/gpu/cl/kernels/ClScatterKernel.cpp b/src/gpu/cl/kernels/ClScatterKernel.cpp
index 720164366e..c95e156679 100644
--- a/src/gpu/cl/kernels/ClScatterKernel.cpp
+++ b/src/gpu/cl/kernels/ClScatterKernel.cpp
@@ -26,6 +26,11 @@
#include "arm_compute/core/CL/ICLTensor.h"
#include "arm_compute/core/ITensorPack.h"
#include "arm_compute/core/TensorInfo.h"
+#include "arm_compute/core/Utils.h"
+
+#include "src/common/utils/Log.h"
+#include "src/core/helpers/WindowHelpers.h"
+#include "support/Cast.h"
namespace arm_compute
{
@@ -37,40 +42,101 @@ ClScatterKernel::ClScatterKernel()
{
}
-Status ClScatterKernel::validate(const ITensorInfo *src,
- const ITensorInfo *updates,
+Status ClScatterKernel::validate(const ITensorInfo *updates,
const ITensorInfo *indices,
const ITensorInfo *dst,
const ScatterInfo &info)
{
- ARM_COMPUTE_UNUSED(src);
- ARM_COMPUTE_UNUSED(updates);
- ARM_COMPUTE_UNUSED(indices);
- ARM_COMPUTE_UNUSED(dst);
+ ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(updates, dst);
+ ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_NOT_IN(indices, DataType::S32);
+ ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_NOT_IN(dst, DataType::F32);
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(dst->num_dimensions() > 1, "Only 1D output tensors are currently supported.");
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(indices->num_dimensions() > 2, "Only 2D indices tensors are currently supported.");
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(updates->num_dimensions() > 1, "Only 1D update tensors are currently supported.");
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(
+ indices->tensor_shape().y() != updates->tensor_shape()[updates->num_dimensions() - 1],
+ "Height of indices tensor should match size of highest dimension in updates tensor.");
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(updates->num_dimensions() > dst->num_dimensions(),
+ "Update tensor cannot have more dims than output tensor.");
ARM_COMPUTE_UNUSED(info);
return Status{};
}
void ClScatterKernel::configure(const ClCompileContext &compile_context,
- const ITensorInfo *src,
const ITensorInfo *updates,
const ITensorInfo *indices,
ITensorInfo *dst,
const ScatterInfo &info)
{
- ARM_COMPUTE_UNUSED(compile_context);
- ARM_COMPUTE_UNUSED(src);
- ARM_COMPUTE_UNUSED(updates);
- ARM_COMPUTE_UNUSED(indices);
- ARM_COMPUTE_UNUSED(dst);
- ARM_COMPUTE_UNUSED(info);
+ ARM_COMPUTE_ERROR_ON_NULLPTR(updates, dst, indices);
+ ARM_COMPUTE_LOG_PARAMS(updates, indices, dst, info);
+
+ // Configure kernel window
+ const auto indices_shape = indices->tensor_shape();
+ Window win = calculate_max_window(
+ *indices, Steps(indices_shape.x(), indices_shape.y())); // Ensures single thread for deterministic output.
+
+ // Set build options
+ CLBuildOptions build_opts;
+ build_opts.add_option("-DDATA_TYPE=" + get_cl_type_from_data_type(dst->data_type()));
+ build_opts.add_option("-DINDICES_DIMS=" + support::cpp11::to_string(indices->num_dimensions()));
+ build_opts.add_option("-DINDICES_SHAPE_Y=" + support::cpp11::to_string(indices_shape.y()));
+ build_opts.add_option("-DOUT_SHAPE_X=" + support::cpp11::to_string(dst->tensor_shape().x()));
+
+ switch (info.func)
+ {
+ case ScatterFunction::Update:
+ build_opts.add_option("-DSCATTER_FUNCTION=UPDATE_OP");
+ break;
+ case ScatterFunction::Add:
+ build_opts.add_option("-DSCATTER_FUNCTION=ADD_OP");
+ break;
+ case ScatterFunction::Sub:
+ build_opts.add_option("-DSCATTER_FUNCTION=SUB_OP");
+ break;
+ case ScatterFunction::Max:
+ build_opts.add_option("-DSCATTER_FUNCTION=MAX_OP");
+ break;
+ case ScatterFunction::Min:
+ build_opts.add_option("-DSCATTER_FUNCTION=MIN_OP");
+ break;
+ default:
+ ARM_COMPUTE_ERROR("Not implemented");
+ }
+
+ // Create kernel
+ std::string kernel_name("scatter1D");
+ ICLKernel::configure_internal(win);
+ _kernel = create_kernel(compile_context, kernel_name, build_opts.options());
+ // Set config_id for enabling LWS tuning
+ _config_id = kernel_name;
+ _config_id += "_";
+ _config_id += lower_string(string_from_data_type(updates->data_type()));
+ _config_id += "_";
+ _config_id += support::cpp11::to_string(dst->dimension(1));
+ _config_id += "_";
+ _config_id += support::cpp11::to_string(dst->dimension(0));
+ _config_id += "_";
+ _config_id += support::cpp11::to_string(dst->dimension(2));
+ _config_id += "_";
}
void ClScatterKernel::run_op(ITensorPack &tensors, const Window &window, cl::CommandQueue &queue)
{
- ARM_COMPUTE_UNUSED(tensors);
- ARM_COMPUTE_UNUSED(window);
- ARM_COMPUTE_UNUSED(queue);
+ unsigned int idx = 0;
+
+ Window window_collapsed = window.collapse_if_possible(ICLKernel::window(), Window::DimZ);
+
+ const auto updates =
+ utils::cast::polymorphic_downcast<const ICLTensor *>(tensors.get_const_tensor(TensorType::ACL_SRC_0));
+ const auto indices =
+ utils::cast::polymorphic_downcast<const ICLTensor *>(tensors.get_const_tensor(TensorType::ACL_SRC_1));
+ auto dst = utils::cast::polymorphic_downcast<ICLTensor *>(tensors.get_tensor(TensorType::ACL_DST));
+ add_4D_tensor_argument(idx, updates, window_collapsed);
+ add_4D_tensor_argument(idx, indices, window_collapsed);
+ add_4D_tensor_argument(idx, dst, window_collapsed);
+
+ enqueue(queue, *this, window, lws_hint());
}
} // namespace kernels
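To make configure() above concrete: a ScatterFunction::Max scatter into a 10-element F32 destination with two indices would be compiled with roughly these defines (a sketch; the exact strings are assembled by CLBuildOptions):

    -DDATA_TYPE=float -DINDICES_DIMS=2 -DINDICES_SHAPE_Y=2 -DOUT_SHAPE_X=10 -DSCATTER_FUNCTION=MAX_OP

so SCATTER_FUNCTION(a, b) expands to fmax(a, b) inside the kernel. Note that the window is computed with Steps equal to the full indices shape, which collapses the dispatch to a single work item and keeps the reduction deterministic.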
diff --git a/src/gpu/cl/kernels/ClScatterKernel.h b/src/gpu/cl/kernels/ClScatterKernel.h
index dda614ff3e..d2a41adde9 100644
--- a/src/gpu/cl/kernels/ClScatterKernel.h
+++ b/src/gpu/cl/kernels/ClScatterKernel.h
@@ -44,15 +44,15 @@ public:
ARM_COMPUTE_DISALLOW_COPY_ALLOW_MOVE(ClScatterKernel);
/** Initialise the kernel's input and output.
*
+ * @note Negative indices are treated as out of bounds.
+ *
* @param[in] compile_context The compile context to be used.
- * @param[in] src Input tensor info for the source matrix.
* @param[in] updates Input tensor info for the Update matrix. Data type supported: same as @p src
- * @param[in] indices Input tensor info for the Indices matrix. Data type supported: U32.
+ * @param[in] indices Input tensor info for the Indices matrix. Data type supported: S32.
* @param[out] dst Output tensor info. Data type supported: same as @p src
* @param[in] info Attributes for Scatter Kernel
*/
void configure(const ClCompileContext &compile_context,
- const ITensorInfo *src,
const ITensorInfo *updates,
const ITensorInfo *indices,
ITensorInfo *dst,
@@ -63,11 +63,8 @@ public:
*
* @return a status
*/
- static Status validate(const ITensorInfo *src,
- const ITensorInfo *updates,
- const ITensorInfo *indices,
- const ITensorInfo *dst,
- const ScatterInfo &info);
+ static Status
+ validate(const ITensorInfo *updates, const ITensorInfo *indices, const ITensorInfo *dst, const ScatterInfo &info);
// Inherited methods overridden:
void run_op(ITensorPack &tensors, const Window &window, cl::CommandQueue &queue) override;
diff --git a/src/gpu/cl/operators/ClScatter.cpp b/src/gpu/cl/operators/ClScatter.cpp
index af5fbb86f3..62711ddfe8 100644
--- a/src/gpu/cl/operators/ClScatter.cpp
+++ b/src/gpu/cl/operators/ClScatter.cpp
@@ -27,6 +27,7 @@
#include "arm_compute/runtime/CL/CLScheduler.h"
#include "src/common/utils/Log.h"
+#include "src/gpu/cl/kernels/ClCopyKernel.h"
#include "src/gpu/cl/kernels/ClFillKernel.h"
#include "src/gpu/cl/kernels/ClScatterKernel.h"
@@ -47,9 +48,21 @@ Status ClScatter::validate(const ITensorInfo *src,
const ScatterInfo &info)
{
ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(updates, indices, dst);
- ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(indices, 1, DataType::U32);
+ ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(indices, 1, DataType::S32);
+ ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_NOT_IN(dst, DataType::F32); // Currently, other datatypes are not supported.
+ if (src != nullptr)
+ {
+ // Check dst/src are same shape and datatype.
+ ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DIMENSIONS(src->tensor_shape(), dst->tensor_shape());
+ ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(src, updates, dst);
+ ARM_COMPUTE_RETURN_ON_ERROR(kernels::ClCopyKernel::validate(src, dst)); // Validate Copy kernel
+ }
+ if (src != dst)
+ {
+ ARM_COMPUTE_RETURN_ON_ERROR(kernels::ClFillKernel::validate(dst, PixelValue(0.0f))); // Validate Fill kernel.
+ }
- return kernels::ClScatterKernel::validate(src, updates, indices, dst, info);
+ return kernels::ClScatterKernel::validate(updates, indices, dst, info);
}
void ClScatter::configure(const CLCompileContext &compile_context,
@@ -61,11 +74,6 @@ void ClScatter::configure(const CLCompileContext &compile_context,
{
ARM_COMPUTE_ERROR_ON_NULLPTR(updates, indices, dst);
ARM_COMPUTE_LOG_PARAMS(src, indices, dst, info);
- ARM_COMPUTE_UNUSED(src);
- ARM_COMPUTE_UNUSED(updates);
- ARM_COMPUTE_UNUSED(indices);
- ARM_COMPUTE_UNUSED(dst);
- ARM_COMPUTE_UNUSED(info);
// Perform validation step
ARM_COMPUTE_ERROR_THROW_ON(validate(src, updates, indices, dst, info));
@@ -74,19 +82,50 @@ void ClScatter::configure(const CLCompileContext &compile_context,
// If necessary, create fill kernel to fill dst tensor.
if (_fill_zero)
{
- _fill_kernel = std::make_unique<kernels::ClFillKernel>();
+ auto f = std::make_unique<kernels::ClFillKernel>();
+ f->configure(compile_context, dst, PixelValue(0.0f));
+ _fill_kernel = std::move(f);
+ }
+ else if (src != dst) // Check whether copying is necessary
+ {
+ // Copy src into dst before the scatter op runs.
+ auto c = std::make_unique<kernels::ClCopyKernel>();
+ c->configure(compile_context, src, dst);
+ _copy_kernel = std::move(c);
+ _run_copy = true;
}
// Configure ClScatterKernel
auto k = std::make_unique<kernels::ClScatterKernel>();
k->set_target(CLScheduler::get().target());
- k->configure(compile_context, src, updates, indices, dst, info);
+ k->configure(compile_context, updates, indices, dst, info);
_scatter_kernel = std::move(k);
}
void ClScatter::run(ITensorPack &tensors)
{
- ARM_COMPUTE_UNUSED(tensors);
+ // Get tensors.
+ auto src = tensors.get_const_tensor(ACL_SRC_0);
+ auto updates = tensors.get_const_tensor(ACL_SRC_1);
+ auto indices = tensors.get_const_tensor(ACL_SRC_2);
+ auto dst = tensors.get_tensor(ACL_DST);
+
+ if (_fill_zero)
+ {
+ // Fill destination tensor with 0 values if zero init.
+ ITensorPack fill_pack{{ACL_SRC, dst}};
+ CLScheduler::get().enqueue_op(*_fill_kernel, fill_pack, false);
+ }
+
+ if (_run_copy)
+ {
+ // Copy src to dst before the scatter op.
+ ITensorPack copy_pack{{ACL_SRC, src}, {ACL_DST, dst}};
+ CLScheduler::get().enqueue_op(*_copy_kernel, copy_pack, false);
+ }
+
+ ITensorPack scatter_pack{{ACL_SRC_0, updates}, {ACL_SRC_1, indices}, {ACL_DST, dst}};
+ CLScheduler::get().enqueue_op(*_scatter_kernel, scatter_pack, false);
}
} // namespace opencl
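A sketch of how a caller assembles the pack consumed by ClScatter::run() (normally done inside the CLScatter runtime function rather than by user code; the tensor pointers are assumed to be valid ICLTensors):

    ITensorPack pack;
    pack.add_const_tensor(TensorType::ACL_SRC_0, src); // may be nullptr with zero initialization
    pack.add_const_tensor(TensorType::ACL_SRC_1, updates);
    pack.add_const_tensor(TensorType::ACL_SRC_2, indices);
    pack.add_tensor(TensorType::ACL_DST, dst);
    scatter.run(pack); // fill or copy runs first, then the scatter kernel

The fill and copy branches are mutually exclusive: a zero-initialised destination never needs the source copied in, and an in-place run (src == dst) needs neither.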
diff --git a/src/gpu/cl/operators/ClScatter.h b/src/gpu/cl/operators/ClScatter.h
index 433f7ca3a4..a1b32fed45 100644
--- a/src/gpu/cl/operators/ClScatter.h
+++ b/src/gpu/cl/operators/ClScatter.h
@@ -39,6 +39,7 @@ namespace opencl
// Forward declaration
class ClFillKernel;
class ClScatterKernel;
+class ClCopyKernel;
/** Basic operator to execute Scatter on OpenCL. This operator calls the following OpenCL kernels:
*
@@ -56,13 +57,14 @@ public:
* Valid data layouts:
* - All
*
- * @note indices must always be U32
+ * @note indices must always be S32.
+ * @note Negative indices are treated as out of bounds.
* @note src, updates and dst tensors must be same datatype.
*
* @param[in] compile_context The compile context to be used.
* @param[in] src Source input tensor info. Can be nullptr when using "Add" Scatter Function with zero initialization.
* @param[in] updates Tensor info for tensor storing update values to use for scatter function. Data types supported: same as @p src.
- * @param[in] indices Tensor info for tensor storing indices to use for scatter function. Data types supported: U32 only.
+ * @param[in] indices Tensor info for tensor storing indices to use for scatter function. Data types supported: S32 only.
* @param[out] dst Output tensor to store the result of the Scatter Function. Data types supported: same as @p src and @p updates.
* @param[in] Scatter_info Contains Scatter operation information described in @ref ScatterInfo.
*/
@@ -89,7 +91,9 @@ public:
private:
std::unique_ptr<opencl::IClKernel> _scatter_kernel{nullptr};
std::unique_ptr<opencl::IClKernel> _fill_kernel{nullptr};
+ std::unique_ptr<opencl::IClKernel> _copy_kernel{nullptr};
bool _fill_zero{false};
+ bool _run_copy{false};
};
} // namespace opencl
} // namespace arm_compute
diff --git a/tests/datasets/ScatterDataset.h b/tests/datasets/ScatterDataset.h
index d204d17855..f7547ecc94 100644
--- a/tests/datasets/ScatterDataset.h
+++ b/tests/datasets/ScatterDataset.h
@@ -118,8 +118,8 @@ class Small1DScatterDataset final : public ScatterDataset
public:
Small1DScatterDataset()
{
- add_config(TensorShape(6U), TensorShape(6U), TensorShape(6U), TensorShape(6U));
- add_config(TensorShape(10U), TensorShape(2U), TensorShape(2U), TensorShape(10U));
+ add_config(TensorShape(6U), TensorShape(6U), TensorShape(1U, 6U), TensorShape(6U));
+ add_config(TensorShape(10U), TensorShape(2U), TensorShape(1U, 2U), TensorShape(10U));
}
};
} // namespace datasets
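The four shapes in each config are, in order, src, updates, indices and dst (matching the fixture's compute_target() arguments). The indices shapes change from (N) to (1, N) because the kernel now expects 2D indices with the index count along y, as enforced by the validate() checks above.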
diff --git a/tests/validation/CL/ScatterLayer.cpp b/tests/validation/CL/ScatterLayer.cpp
index 56338f489f..9711671841 100644
--- a/tests/validation/CL/ScatterLayer.cpp
+++ b/tests/validation/CL/ScatterLayer.cpp
@@ -38,6 +38,10 @@ namespace test
{
namespace validation
{
+namespace
+{
+RelativeTolerance<float> tolerance_f32(0.001f); /**< Tolerance value for comparing reference's output against implementation's output for fp32 data type */
+} // namespace
template <typename T>
using CLScatterLayerFixture = ScatterValidationFixture<CLTensor, CLAccessor, CLScatter, T>;
@@ -46,7 +50,7 @@ using framework::dataset::make;
TEST_SUITE(CL)
TEST_SUITE(Scatter)
-DATA_TEST_CASE(Validate, framework::DatasetMode::DISABLED, zip(
+DATA_TEST_CASE(Validate, framework::DatasetMode::PRECOMMIT, zip(
make("InputInfo", { TensorInfo(TensorShape(9U), 1, DataType::F32), // Mismatching data types
TensorInfo(TensorShape(15U), 1, DataType::F32), // Valid
TensorInfo(TensorShape(8U), 1, DataType::F32),
@@ -61,12 +65,12 @@ DATA_TEST_CASE(Validate, framework::DatasetMode::DISABLED, zip(
TensorInfo(TensorShape(217U, 3U), 1, DataType::F32),
TensorInfo(TensorShape(2U), 1, DataType::F32),
}),
- make("IndicesInfo",{ TensorInfo(TensorShape(3U), 1, DataType::U32),
- TensorInfo(TensorShape(15U), 1, DataType::U32),
- TensorInfo(TensorShape(2U), 1, DataType::U32),
- TensorInfo(TensorShape(271U), 1, DataType::U32),
- TensorInfo(TensorShape(271U), 1, DataType::U32),
- TensorInfo(TensorShape(2U), 1 , DataType::S32)
+ make("IndicesInfo",{ TensorInfo(TensorShape(1U, 3U), 1, DataType::S32),
+ TensorInfo(TensorShape(1U, 15U), 1, DataType::S32),
+ TensorInfo(TensorShape(1U, 2U), 1, DataType::S32),
+ TensorInfo(TensorShape(1U, 271U), 1, DataType::S32),
+ TensorInfo(TensorShape(1U, 271U), 1, DataType::S32),
+ TensorInfo(TensorShape(1U, 2U), 1, DataType::F32)
}),
make("OutputInfo",{ TensorInfo(TensorShape(9U), 1, DataType::F16),
TensorInfo(TensorShape(15U), 1, DataType::F32),
@@ -76,27 +80,27 @@ DATA_TEST_CASE(Validate, framework::DatasetMode::DISABLED, zip(
TensorInfo(TensorShape(12U), 1, DataType::F32)
}),
make("ScatterInfo",{ ScatterInfo(ScatterFunction::Add, false),
+ ScatterInfo(ScatterFunction::Max, false),
+ ScatterInfo(ScatterFunction::Min, false),
+ ScatterInfo(ScatterFunction::Add, false),
+ ScatterInfo(ScatterFunction::Update, false),
+ ScatterInfo(ScatterFunction::Sub, false),
}),
make("Expected", { false, true, true, false, false, false })),
input_info, updates_info, indices_info, output_info, scatter_info, expected)
{
- // TODO: Enable validation tests.
- ARM_COMPUTE_UNUSED(input_info);
- ARM_COMPUTE_UNUSED(updates_info);
- ARM_COMPUTE_UNUSED(indices_info);
- ARM_COMPUTE_UNUSED(output_info);
- ARM_COMPUTE_UNUSED(scatter_info);
- ARM_COMPUTE_UNUSED(expected);
+ const Status status = CLScatter::validate(&input_info.clone()->set_is_resizable(true), &updates_info.clone()->set_is_resizable(true), &indices_info.clone()->set_is_resizable(true), &output_info.clone()->set_is_resizable(true), scatter_info);
+ ARM_COMPUTE_EXPECT(bool(status) == expected, framework::LogLevel::ERRORS);
}
TEST_SUITE(Float)
TEST_SUITE(FP32)
FIXTURE_DATA_TEST_CASE(RunSmall, CLScatterLayerFixture<float>, framework::DatasetMode::PRECOMMIT, combine(datasets::Small1DScatterDataset(),
make("DataType", {DataType::F32}),
- make("ScatterFunction", {ScatterFunction::Update, ScatterFunction::Add, ScatterFunction::Sub, ScatterFunction::Min, ScatterFunction::Max}),
+ make("ScatterFunction", {ScatterFunction::Update, ScatterFunction::Add, ScatterFunction::Sub, ScatterFunction::Min, ScatterFunction::Max }),
make("ZeroInit", {false})))
{
- // TODO: Add validate() here.
+ validate(CLAccessor(_target), _reference, tolerance_f32);
}
// With this test, src should be passed as nullptr.
@@ -105,7 +109,7 @@ FIXTURE_DATA_TEST_CASE(RunSmallZeroInit, CLScatterLayerFixture<float>, framework
make("ScatterFunction", {ScatterFunction::Add}),
make("ZeroInit", {true})))
{
- // TODO: Add validate() here
+ validate(CLAccessor(_target), _reference, tolerance_f32);
}
TEST_SUITE_END() // FP32
TEST_SUITE_END() // Float
diff --git a/tests/validation/fixtures/ScatterLayerFixture.h b/tests/validation/fixtures/ScatterLayerFixture.h
index bda5532a51..451a1e1416 100644
--- a/tests/validation/fixtures/ScatterLayerFixture.h
+++ b/tests/validation/fixtures/ScatterLayerFixture.h
@@ -27,7 +27,7 @@
#include "arm_compute/core/Utils.h"
#include "arm_compute/runtime/CL/CLTensorAllocator.h"
#include "tests/Globals.h"
-#include "tests/framework/Asserts.h" // Required for ARM_COMPUTE_ASSERT
+#include "tests/framework/Asserts.h"
#include "tests/framework/Fixture.h"
#include "tests/validation/Validation.h"
#include "tests/validation/reference/ScatterLayer.h"
@@ -71,14 +71,14 @@ protected:
}
}
- // This is used to fill indices tensor with U32 datatype.
+ // This is used to fill indices tensor with S32 datatype.
// Used to prevent ONLY having values that are out of bounds.
template <typename U>
void fill_indices(U &&tensor, int i, const TensorShape &shape)
{
- // Calculate max indices the shape should contain. Add an arbitrary constant to allow testing for some out of bounds values.
- const uint32_t max = std::max({shape[0] , shape[1], shape[2]}) + 5;
- library->fill_tensor_uniform(tensor, i, static_cast<uint32_t>(0), static_cast<uint32_t>(max));
+ // Calculate the largest index the shape can produce. Using a lower bound of -2 ensures some out-of-bounds (negative) indices are also generated.
+ const int32_t max = std::max({shape[0], shape[1], shape[2]});
+ library->fill_tensor_uniform(tensor, i, static_cast<int32_t>(-2), static_cast<int32_t>(max));
}
TensorType compute_target(const TensorShape &shape_a, const TensorShape &shape_b, const TensorShape &shape_c, const TensorShape &out_shape, DataType data_type, const ScatterInfo info, QuantizationInfo a_qinfo, QuantizationInfo o_qinfo)
@@ -88,7 +88,7 @@ protected:
// In order - src, updates, indices, output.
TensorType src = create_tensor<TensorType>(shape_a, data_type, 1, a_qinfo);
TensorType updates = create_tensor<TensorType>(shape_b, data_type, 1, a_qinfo);
- TensorType indices = create_tensor<TensorType>(shape_c, DataType::U32, 1, QuantizationInfo());
+ TensorType indices = create_tensor<TensorType>(shape_c, DataType::S32, 1, QuantizationInfo());
TensorType dst = create_tensor<TensorType>(out_shape, data_type, 1, o_qinfo);
FunctionType scatter;
@@ -127,7 +127,6 @@ protected:
fill_indices(AccessorType(indices), 2, out_shape);
scatter.run();
-
return dst;
}
@@ -140,7 +139,7 @@ protected:
// Create reference tensors
SimpleTensor<T> src{ a_shape, data_type, 1, a_qinfo };
SimpleTensor<T> updates{b_shape, data_type, 1, QuantizationInfo() };
- SimpleTensor<uint32_t> indices{ c_shape, DataType::U32, 1, QuantizationInfo() };
+ SimpleTensor<int32_t> indices{ c_shape, DataType::S32, 1, QuantizationInfo() };
// Fill reference
fill(src, 0);
@@ -148,9 +147,7 @@ protected:
fill_indices(indices, 2, out_shape);
// Calculate individual reference.
- auto result = reference::scatter_layer<T>(src, updates, indices, out_shape, info);
-
- return result;
+ return reference::scatter_layer<T>(src, updates, indices, out_shape, info);
}
TensorType _target{};
diff --git a/tests/validation/reference/ScatterLayer.cpp b/tests/validation/reference/ScatterLayer.cpp
index 920f2b9990..7543b46bb1 100644
--- a/tests/validation/reference/ScatterLayer.cpp
+++ b/tests/validation/reference/ScatterLayer.cpp
@@ -66,7 +66,7 @@ template float reduce_op(const float &current,const float &update,const ScatterF
// Note : This function currently only supports 1D src, 1D updates, 2D indices, 1D output tensors.
template <typename T>
-SimpleTensor<T> scatter_layer_internal(const SimpleTensor<T> &src, const SimpleTensor<T> &updates, const SimpleTensor<uint32_t> &indices, const TensorShape &out_shape, const ScatterInfo &info)
+SimpleTensor<T> scatter_layer_internal(const SimpleTensor<T> &src, const SimpleTensor<T> &updates, const SimpleTensor<int32_t> &indices, const TensorShape &out_shape, const ScatterInfo &info)
{
SimpleTensor<T> dst{ out_shape, src.data_type(), 1 };
@@ -84,14 +84,14 @@ SimpleTensor<T> scatter_layer_internal(const SimpleTensor<T> &src, const SimpleT
}
// 2. Get max index of output tensor, then iterate over index tensor.
- const auto x_bound = dst.shape().x();
+ const int x_bound = static_cast<int>(dst.shape().x());
for(int i = 0; i < indices.num_elements(); ++i)
{
// 3. Check whether index is out of bounds for dst, if not then apply reduce op.
const auto index = indices[i];
- if (index < x_bound) // Note : index is always >= 0 as datatype is unsigned.
+ if (index < x_bound && index >= 0) // Note : we ignore negative index values.
{
dst[index] = reduce_op(dst[index], updates[i], info.func);
}
@@ -100,12 +100,12 @@ SimpleTensor<T> scatter_layer_internal(const SimpleTensor<T> &src, const SimpleT
}
template <typename T>
-SimpleTensor<T> scatter_layer(const SimpleTensor<T> &src, const SimpleTensor<T> &updates, const SimpleTensor<uint32_t> &indices, const TensorShape &out_shape, const ScatterInfo &info)
+SimpleTensor<T> scatter_layer(const SimpleTensor<T> &src, const SimpleTensor<T> &updates, const SimpleTensor<int32_t> &indices, const TensorShape &out_shape, const ScatterInfo &info)
{
return scatter_layer_internal<T>(src, updates, indices, out_shape, info);
}
-template SimpleTensor<float> scatter_layer(const SimpleTensor<float> &src, const SimpleTensor<float> &updates, const SimpleTensor<uint32_t> &indices, const TensorShape &out_shape, const ScatterInfo &info);
+template SimpleTensor<float> scatter_layer(const SimpleTensor<float> &src, const SimpleTensor<float> &updates, const SimpleTensor<int32_t> &indices, const TensorShape &out_shape, const ScatterInfo &info);
} // namespace reference
} // namespace validation
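As a worked example of the reference semantics (assuming zero_initialization is false, so dst starts as a copy of src): with src = {1, 1, 1, 1}, updates = {5, 2, 7}, indices = {1, 3, 1} and ScatterFunction::Add, the result is dst = {1, 13, 1, 3}. Index 1 is reduced twice (1 + 5 + 7), index 3 once (1 + 2), and any index outside [0, x_bound) is skipped.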
diff --git a/tests/validation/reference/ScatterLayer.h b/tests/validation/reference/ScatterLayer.h
index dc441a8894..97d5e70b0d 100644
--- a/tests/validation/reference/ScatterLayer.h
+++ b/tests/validation/reference/ScatterLayer.h
@@ -37,10 +37,10 @@ namespace validation
namespace reference
{
template <typename T>
-SimpleTensor<T> scatter_layer_internal(const SimpleTensor<T> &src, const SimpleTensor<T> &update, const SimpleTensor<uint32_t> &indices, const TensorShape &shape, const ScatterInfo &info);
+SimpleTensor<T> scatter_layer_internal(const SimpleTensor<T> &src, const SimpleTensor<T> &update, const SimpleTensor<int32_t> &indices, const TensorShape &shape, const ScatterInfo &info);
template <typename T>
-SimpleTensor<T> scatter_layer(const SimpleTensor<T> &src, const SimpleTensor<T> &update, const SimpleTensor<uint32_t> &indices, const TensorShape &shape, const ScatterInfo &info);
+SimpleTensor<T> scatter_layer(const SimpleTensor<T> &src, const SimpleTensor<T> &update, const SimpleTensor<int32_t> &indices, const TensorShape &shape, const ScatterInfo &info);
} // namespace reference
} // namespace validation
} // namespace test