aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
author Gunes Bayir <gunes.bayir@arm.com> 2024-05-09 13:24:15 +0100
committer Suhail M <MohammedSuhail.Munshi@arm.com> 2024-05-10 11:01:30 +0000
commit 05269f013cf2b7c4a53f5950cdd6bfea26367769 (patch)
tree bc620cce08a1569c311ce0fd9d74833e5ff63382 /src
parent 48f120c64c21d983318c6e65f6d5609a8f8e92e6 (diff)
download ComputeLibrary-05269f013cf2b7c4a53f5950cdd6bfea26367769.tar.gz
ScatterND fix for scalar cases
- Padding with batched scalar cases is unsupported, adds checks.
- Adds tests for scalar cases, without padding.

Resolves: [COMPMID-7015]
Change-Id: Ib9cf5db990420ff4b442d003ef9424e365bee86d
Signed-off-by: Mohammed Suhail Munshi <MohammedSuhail.Munshi@arm.com>
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/11536
Reviewed-by: Gunes Bayir <gunes.bayir@arm.com>
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Benchmark: Arm Jenkins <bsgcomp@arm.com>
Diffstat (limited to 'src')
-rw-r--r-- src/gpu/cl/kernels/ClScatterKernel.cpp 47
1 files changed, 31 insertions, 16 deletions
diff --git a/src/gpu/cl/kernels/ClScatterKernel.cpp b/src/gpu/cl/kernels/ClScatterKernel.cpp
index f76a674b27..19adc1ef34 100644
--- a/src/gpu/cl/kernels/ClScatterKernel.cpp
+++ b/src/gpu/cl/kernels/ClScatterKernel.cpp
@@ -69,7 +69,10 @@ Status ClScatterKernel::validate(const ITensorInfo *updates,
const int32_t data_dim = upt_dims - (ind_dims - 1); // Number of batch dims is the number of indices dims - 1
const int32_t index_len = ind_shape[0];
+ bool unsupported_padding_config =
+ (dst_dims == index_len) && index_len > 1 && (dst->has_padding() || updates->has_padding());
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(unsupported_padding_config, "Padding is not supported with these shapes.");
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(updates, dst);
ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_NOT_IN(indices, DataType::S32);
ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_NOT_IN(dst, DataType::F32, DataType::F16, DataType::S32, DataType::S16,
@@ -99,9 +102,8 @@ Status ClScatterKernel::validate(const ITensorInfo *updates,
ARM_COMPUTE_RETURN_ERROR_ON_MSG((ind_dims < 2), "Shape of Indices tensor must be at least 2D");
ARM_COMPUTE_RETURN_ERROR_ON_MSG(index_len > max_index_length, "Maximum supported index length is 5!");
- ARM_COMPUTE_RETURN_ERROR_ON_MSG(
- index_len >= dst_dims && dst_dims != 1,
- "Index length should be smaller than number of output dims (or equal to with 1D output)");
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(index_len > dst_dims && dst_dims != 1,
+ "Index length should be smaller than or equal to number of output dims");
return Status{};
}
@@ -116,25 +118,31 @@ void ClScatterKernel::configure(const ClCompileContext &compile_context,
ARM_COMPUTE_LOG_PARAMS(updates, indices, dst, info);
const TensorShape &dst_shape = dst->tensor_shape();
+ const int index_len = indices->dimension(0);
- const bool is_scalar_block = updates->num_dimensions() == 1; // Checks for replacing only a single element.
- const int n0 = adjust_vec_size(16 / updates->element_size(), is_scalar_block ? 1 : updates->dimension(0));
+ // Check for single element data block
+ const bool is_scalar_block = (dst->num_dimensions() == static_cast<uint32_t>(index_len));
+ const int n0 = adjust_vec_size(16 / updates->element_size(), is_scalar_block ? 1 : updates->dimension(0));
const int partial_n0 = updates->dimension(0) % n0;
// The GWS will be 2D [x, y]
// x-dimension refers to the x coordinate of the dst tensor
// y-dimension refers to the collapsed y-coordinate of the data part of the dst tensor
- Window win = calculate_max_window(dst_shape, Steps(n0));
- const int index_len = indices->dimension(0);
+ Window win;
- // Collapse the dimensions corresponding to indices in the execution window
- for (int i = 0; i < index_len; ++i)
+ if (!is_scalar_block)
{
- win.set(dst->num_dimensions() - (i + 1), Window::Dimension(0, 1, 1));
- }
+ win = calculate_max_window(dst_shape, Steps(n0));
+
+ // Collapse the dimensions corresponding to indices in the execution window
+ for (int i = 0; i < index_len; ++i)
+ {
+ win.set(dst->num_dimensions() - (i + 1), Window::Dimension(0, 1, 1));
+ }
- win = win.collapse(win, 1);
+ win = win.collapse(win, 1);
+ }
// Set build options
CLBuildOptions build_opts;
@@ -206,11 +214,18 @@ void ClScatterKernel::run_op(ITensorPack &tensors, const Window &window, cl::Com
utils::cast::polymorphic_downcast<const ICLTensor *>(tensors.get_const_tensor(TensorType::ACL_SRC_1));
auto dst = utils::cast::polymorphic_downcast<ICLTensor *>(tensors.get_tensor(TensorType::ACL_DST));
- const ITensorInfo *dst_info = dst->info();
- const int num_dims = dst_info->num_dimensions();
- const int ind_dims = indices->info()->num_dimensions();
+ const ITensorInfo *dst_info = dst->info();
+ const ITensorInfo *upd_info = updates->info();
+ const int num_dims = dst_info->num_dimensions();
+ const int ind_dims = indices->info()->num_dimensions();
+ const int index_len = indices->info()->dimension(0);
- const int index_len = indices->info()->dimension(0);
+ bool unsupported_padding_config =
+ num_dims == index_len && index_len > 1 && (dst_info->has_padding() || upd_info->has_padding());
+ if (unsupported_padding_config)
+ {
+ ARM_COMPUTE_ERROR("Unsupported Configuration! Padding not supported with these shapes.");
+ }
// calculate m-dimensional data block strides in updates and destination tensors
const int upt_block_stride =