From cdd1e039ad598aec10d8c1b81e08de9412324bf2 Mon Sep 17 00:00:00 2001 From: Omar Al Khatib Date: Wed, 26 Apr 2023 11:31:45 +0100 Subject: Support multi-dimensional indices in the CL Gather Layer up to four-dimensional output tensors Resolves [COMPMID-5775] Signed-off-by: Omar Al Khatib Change-Id: I6f6c12ac08f0b0ad070ca5d715c531c2c3762c30 Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/9498 Tested-by: Arm Jenkins Reviewed-by: Viet-Hoa Do Comments-Addressed: Arm Jenkins Benchmark: Arm Jenkins --- src/core/CL/kernels/CLGatherKernel.cpp | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) (limited to 'src/core/CL/kernels/CLGatherKernel.cpp') diff --git a/src/core/CL/kernels/CLGatherKernel.cpp b/src/core/CL/kernels/CLGatherKernel.cpp index 31a9a3bba4..5495023b80 100644 --- a/src/core/CL/kernels/CLGatherKernel.cpp +++ b/src/core/CL/kernels/CLGatherKernel.cpp @@ -38,8 +38,8 @@ inline Status validate_arguments(const ITensorInfo *input, const ITensorInfo *in { ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input, indices, output); const uint32_t actual_axis = wrap_around(axis, static_cast(input->num_dimensions())); - ARM_COMPUTE_RETURN_ERROR_ON(indices->num_dimensions() > 1); - ARM_COMPUTE_RETURN_ERROR_ON(input->num_dimensions() > 4); + ARM_COMPUTE_RETURN_ERROR_ON((input->num_dimensions() + indices->num_dimensions() - 1) > 4); + ARM_COMPUTE_RETURN_ERROR_ON(actual_axis >= input->num_dimensions()); ARM_COMPUTE_RETURN_ERROR_ON(input->data_type() == DataType::UNKNOWN); @@ -102,7 +102,9 @@ void CLGatherKernel::configure(const CLCompileContext &compile_context, const IC CLBuildOptions build_opts; build_opts.add_option("-DDATA_TYPE=" + get_cl_unsigned_type_from_element_size(data_size_from_type(input->info()->data_type()))); build_opts.add_option("-DOUTPUT_DIM_Z=" + support::cpp11::to_string(output->info()->dimension(2))); + build_opts.add_option("-DINDICES_DIM_Z=" + support::cpp11::to_string(indices->info()->dimension(2))); build_opts.add_option("-DINPUT_DIM_Z=" + support::cpp11::to_string(input->info()->dimension(2))); + build_opts.add_option("-DINDICES_DIMS=" + support::cpp11::to_string(indices->info()->num_dimensions())); build_opts.add_option("-DAXIS=" + support::cpp11::to_string(_axis)); build_opts.add_option("-DINDEX_LIMIT=" + support::cpp11::to_string(input->info()->tensor_shape()[_axis])); @@ -127,7 +129,7 @@ void CLGatherKernel::run(const Window &window, cl::CommandQueue &queue) Window window_collapsed = window.collapse_if_possible(ICLKernel::window(), Window::DimZ); unsigned int idx = 0; add_4D_tensor_argument(idx, _input, window_collapsed); - add_1D_tensor_argument(idx, _indices, window_collapsed); + add_4D_tensor_argument(idx, _indices, window_collapsed); add_4D_tensor_argument(idx, _output, window_collapsed); enqueue(queue, *this, window_collapsed, lws_hint()); } -- cgit v1.2.1