diff options
Diffstat (limited to 'src/gpu/cl')
-rw-r--r-- | src/gpu/cl/kernels/ClScatterKernel.cpp | 47 |
1 files changed, 31 insertions, 16 deletions
diff --git a/src/gpu/cl/kernels/ClScatterKernel.cpp b/src/gpu/cl/kernels/ClScatterKernel.cpp index f76a674b27..19adc1ef34 100644 --- a/src/gpu/cl/kernels/ClScatterKernel.cpp +++ b/src/gpu/cl/kernels/ClScatterKernel.cpp @@ -69,7 +69,10 @@ Status ClScatterKernel::validate(const ITensorInfo *updates, const int32_t data_dim = upt_dims - (ind_dims - 1); // Number of batch dims is the number of indices dims - 1 const int32_t index_len = ind_shape[0]; + bool unsupported_padding_config = + (dst_dims == index_len) && index_len > 1 && (dst->has_padding() || updates->has_padding()); + ARM_COMPUTE_RETURN_ERROR_ON_MSG(unsupported_padding_config, "Padding is not supported with these shapes."); ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(updates, dst); ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_NOT_IN(indices, DataType::S32); ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_NOT_IN(dst, DataType::F32, DataType::F16, DataType::S32, DataType::S16, @@ -99,9 +102,8 @@ Status ClScatterKernel::validate(const ITensorInfo *updates, ARM_COMPUTE_RETURN_ERROR_ON_MSG((ind_dims < 2), "Shape of Indices tensor must be at least 2D"); ARM_COMPUTE_RETURN_ERROR_ON_MSG(index_len > max_index_length, "Maximum supported index length is 5!"); - ARM_COMPUTE_RETURN_ERROR_ON_MSG( - index_len >= dst_dims && dst_dims != 1, - "Index length should be smaller than number of output dims (or equal to with 1D output)"); + ARM_COMPUTE_RETURN_ERROR_ON_MSG(index_len > dst_dims && dst_dims != 1, + "Index length should be smaller than or equal to number of output dims"); return Status{}; } @@ -116,25 +118,31 @@ void ClScatterKernel::configure(const ClCompileContext &compile_context, ARM_COMPUTE_LOG_PARAMS(updates, indices, dst, info); const TensorShape &dst_shape = dst->tensor_shape(); + const int index_len = indices->dimension(0); - const bool is_scalar_block = updates->num_dimensions() == 1; // Checks for replacing only a single element. - const int n0 = adjust_vec_size(16 / updates->element_size(), is_scalar_block ? 1 : updates->dimension(0)); + // Check for single element data block + const bool is_scalar_block = (dst->num_dimensions() == static_cast<uint32_t>(index_len)); + const int n0 = adjust_vec_size(16 / updates->element_size(), is_scalar_block ? 1 : updates->dimension(0)); const int partial_n0 = updates->dimension(0) % n0; // The GWS will be 2D [x, y] // x-dimension refers to the x coordinate of the dst tensor // y-dimension refers to the collapsed y-coordinate of the data part of the dst tensor - Window win = calculate_max_window(dst_shape, Steps(n0)); - const int index_len = indices->dimension(0); + Window win; - // Collapse the dimensions corresponding to indices in the execution window - for (int i = 0; i < index_len; ++i) + if (!is_scalar_block) { - win.set(dst->num_dimensions() - (i + 1), Window::Dimension(0, 1, 1)); - } + win = calculate_max_window(dst_shape, Steps(n0)); + + // Collapse the dimensions corresponding to indices in the execution window + for (int i = 0; i < index_len; ++i) + { + win.set(dst->num_dimensions() - (i + 1), Window::Dimension(0, 1, 1)); + } - win = win.collapse(win, 1); + win = win.collapse(win, 1); + } // Set build options CLBuildOptions build_opts; @@ -206,11 +214,18 @@ void ClScatterKernel::run_op(ITensorPack &tensors, const Window &window, cl::Com utils::cast::polymorphic_downcast<const ICLTensor *>(tensors.get_const_tensor(TensorType::ACL_SRC_1)); auto dst = utils::cast::polymorphic_downcast<ICLTensor *>(tensors.get_tensor(TensorType::ACL_DST)); - const ITensorInfo *dst_info = dst->info(); - const int num_dims = dst_info->num_dimensions(); - const int ind_dims = indices->info()->num_dimensions(); + const ITensorInfo *dst_info = dst->info(); + const ITensorInfo *upd_info = updates->info(); + const int num_dims = dst_info->num_dimensions(); + const int ind_dims = indices->info()->num_dimensions(); + const int index_len = indices->info()->dimension(0); - const int index_len = indices->info()->dimension(0); + bool unsupported_padding_config = + num_dims == index_len && index_len > 1 && (dst_info->has_padding() || upd_info->has_padding()); + if (unsupported_padding_config) + { + ARM_COMPUTE_ERROR("Unsupported Configuration! Padding not supported with these shapes."); + } // calculate m-dimensional data block strides in updates and destination tensors const int upt_block_stride = |