From cdd1e039ad598aec10d8c1b81e08de9412324bf2 Mon Sep 17 00:00:00 2001
From: Omar Al Khatib
Date: Wed, 26 Apr 2023 11:31:45 +0100
Subject: Support multi-dimensional indices in the CL Gather Layer up to four-dimensional output tensors

Resolves [COMPMID-5775]

Signed-off-by: Omar Al Khatib
Change-Id: I6f6c12ac08f0b0ad070ca5d715c531c2c3762c30
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/9498
Tested-by: Arm Jenkins
Reviewed-by: Viet-Hoa Do
Comments-Addressed: Arm Jenkins
Benchmark: Arm Jenkins
---
Reviewer notes (free text in the patch section, not part of the commit message):

* gather.cl: the indices tensor is now bound as a 4D tensor. The read
  coordinates are chosen at preprocessing time from AXIS and INDICES_DIMS:
  the index is read at the output coordinates starting at position AXIS, and
  the input coordinates after the axis take the output coordinates that
  follow the indices dimensions. An index >= INDEX_LIMIT reads input element
  0 on the axis and writes 0 to the output.
* gather.cl: branches exist for AXIS 0 with INDICES_DIMS 1-4, AXIS 1 with
  1-3, AXIS 2 with 1-2 and AXIS 3 with 1. There is no #else branch, so any
  other combination leaves `index` and `input_addr` undeclared.
  NOTE(review): this presumably relies on validate_arguments() rejecting
  input dims + indices dims - 1 > 4 -- confirm.
* CLGatherKernel.cpp: validation now limits the combined rank instead of
  requiring 1D indices; INDICES_DIM_Z and INDICES_DIMS are passed as build
  options; indices are bound with add_4D_tensor_argument().
* Tests: new CLSmallGatherMultiDimIndicesDataset, one RunSmallMultiDimIndices
  case per data type (F16, F32, U8, U16), and the fixture now fills
  indices_shape.total_size() index values.
  NOTE(review): several dataset entries use an axis >= the input rank (e.g.
  TensorShape(2U, 6U) with axis 2, TensorShape(9U) with axis 1); presumably
  these wrap onto a lower axis -- confirm they add coverage.
* Restoration: this copy was recovered from a whitespace-collapsed export.
  Line breaks were rebuilt to match the hunk headers and the diffstat.
  Template arguments missing from the export (static_cast<int>,
  CLGatherFixture<...>, std::uniform_int_distribution<uint32_t>) were
  restored from context -- TODO confirm against the upstream commit. Column
  alignment and author e-mail addresses could not be recovered.

 src/core/CL/cl_kernels/common/gather.cl   | 56 +++++++++++++++++++++++-----
 src/core/CL/kernels/CLGatherKernel.cpp    |  8 ++--
 tests/datasets/GatherDataset.h            | 38 +++++++++++++++++++
 tests/validation/CL/Gather.cpp            | 61 ++++++++++++++++++++++++++-----
 tests/validation/fixtures/GatherFixture.h |  5 +--
 5 files changed, 142 insertions(+), 26 deletions(-)

diff --git a/src/core/CL/cl_kernels/common/gather.cl b/src/core/CL/cl_kernels/common/gather.cl
index a47c8a7bb7..5d180f3781 100644
--- a/src/core/CL/cl_kernels/common/gather.cl
+++ b/src/core/CL/cl_kernels/common/gather.cl
@@ -59,34 +59,70 @@
  */
 __kernel void gather(
     TENSOR4D_DECLARATION(input),
-    VECTOR_DECLARATION(indices),
+    TENSOR4D_DECLARATION(indices),
     TENSOR4D_DECLARATION(output))
 {
     const int px = get_global_id(0);
     const int py = get_global_id(1);
     const int pz = get_global_id(2) % OUTPUT_DIM_Z;
-    const int pw = get_global_id(2) / OUTPUT_DIM_Z;
+    const int pw = (get_global_id(2) / OUTPUT_DIM_Z );
 
     const Tensor4D input = CONVERT_TO_TENSOR4D_STRUCT_NO_STEP(input, INPUT_DIM_Z);
-    const Vector indices = CONVERT_TO_VECTOR_STRUCT_NO_STEP(indices);
+    const Tensor4D indices = CONVERT_TO_TENSOR4D_STRUCT_NO_STEP(indices, INDICES_DIM_Z);
     Tensor4D output = CONVERT_TO_TENSOR4D_STRUCT(output, OUTPUT_DIM_Z);
 
 #if AXIS == 0
-    const uint index = *(__global const uint *)vector_offset(&indices, px);
+#if INDICES_DIMS == 1
+    const uint index = *(__global const uint *)tensor4D_offset(&indices, px, 0, 0, 0);
     const uint safe_index = select((uint)0, index, index < INDEX_LIMIT);
     __global const uchar *input_addr = tensor4D_offset(&input, safe_index, py, pz, pw);
+#elif INDICES_DIMS == 2
+    const uint index = *(__global const uint *)tensor4D_offset(&indices, px, py, 0, 0);
+    const uint safe_index = select((uint)0, index, index < INDEX_LIMIT);
+    __global const uchar *input_addr = tensor4D_offset(&input, safe_index, pz, pw, 0);
+#elif INDICES_DIMS == 3
+    const uint index = *(__global const uint *)tensor4D_offset(&indices, px, py, pz, 0);
+    const uint safe_index = select((uint)0, index, index < INDEX_LIMIT);
+    __global const uchar *input_addr = tensor4D_offset(&input, safe_index, pw, 0, 0);
+#elif INDICES_DIMS == 4
+    const uint index = *(__global const uint *)tensor4D_offset(&indices, px, py, pz, pw);
+    const uint safe_index = select((uint)0, index, index < INDEX_LIMIT);
+    __global const uchar *input_addr = tensor4D_offset(&input, safe_index, 0, 0, 0);
+#endif //INDICES_DIMS
+
 #elif AXIS == 1
-    const uint index = *(__global const uint *)vector_offset(&indices, py);
+#if INDICES_DIMS == 1
+    const uint index = *(__global const uint *)tensor4D_offset(&indices, py, 0, 0, 0);
+    const uint safe_index = select((uint)0, index, index < INDEX_LIMIT);
+    __global const uchar *input_addr = tensor4D_offset(&input, px, safe_index, pz, pw);
+#elif INDICES_DIMS == 2
+    const uint index = *(__global const uint *)tensor4D_offset(&indices, py, pz, 0, 0);
     const uint safe_index = select((uint)0, index, index < INDEX_LIMIT);
-    __global const uchar *input_addr = tensor4D_offset(&input, px, safe_index, pz, pw);
+    __global const uchar *input_addr = tensor4D_offset(&input, px, safe_index, pw, 0);
+#elif INDICES_DIMS == 3
+    const uint index = *(__global const uint *)tensor4D_offset(&indices, py, pz, pw, 0);
+    const uint safe_index = select((uint)0, index, index < INDEX_LIMIT);
+    __global const uchar *input_addr = tensor4D_offset(&input, px, safe_index, 0, 0);
+#endif //INDICES_DIMS
+
 #elif AXIS == 2
-    const uint index = *(__global const uint *)vector_offset(&indices, pz);
+#if INDICES_DIMS == 1
+    const uint index = *(__global const uint *)tensor4D_offset(&indices, pz, 0, 0, 0);
+    const uint safe_index = select((uint)0, index, index < INDEX_LIMIT);
+    __global const uchar *input_addr = tensor4D_offset(&input, px, py, safe_index, pw);
+#elif INDICES_DIMS == 2
+    const uint index = *(__global const uint *)tensor4D_offset(&indices, pz, pw, 0, 0);
     const uint safe_index = select((uint)0, index, index < INDEX_LIMIT);
-    __global const uchar *input_addr = tensor4D_offset(&input, px, py, safe_index, pw);
+    __global const uchar *input_addr = tensor4D_offset(&input, px, py, safe_index, 0);
+#endif //INDICES_DIMS
+
 #elif AXIS == 3
-    const uint index = *(__global const uint *)vector_offset(&indices, pw);
+#if INDICES_DIMS == 1
+    const uint index = *(__global const uint *)tensor4D_offset(&indices, pw, 0, 0, 0);
     const uint safe_index = select((uint)0, index, index < INDEX_LIMIT);
-    __global const uchar *input_addr = tensor4D_offset(&input, px, py, pz, safe_index);
+    __global const uchar *input_addr = tensor4D_offset(&input, px, py, pz, safe_index);
+#endif //INDICES_DIMS
+
 #endif //AXIS
 
     *(__global DATA_TYPE *)output.ptr = select((DATA_TYPE)0, *((__global const DATA_TYPE *)input_addr), (DATA_TYPE)(index < INDEX_LIMIT));
diff --git a/src/core/CL/kernels/CLGatherKernel.cpp b/src/core/CL/kernels/CLGatherKernel.cpp
index 31a9a3bba4..5495023b80 100644
--- a/src/core/CL/kernels/CLGatherKernel.cpp
+++ b/src/core/CL/kernels/CLGatherKernel.cpp
@@ -38,8 +38,8 @@ inline Status validate_arguments(const ITensorInfo *input, const ITensorInfo *in
 {
     ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input, indices, output);
     const uint32_t actual_axis = wrap_around(axis, static_cast<int>(input->num_dimensions()));
-    ARM_COMPUTE_RETURN_ERROR_ON(indices->num_dimensions() > 1);
-    ARM_COMPUTE_RETURN_ERROR_ON(input->num_dimensions() > 4);
+    ARM_COMPUTE_RETURN_ERROR_ON((input->num_dimensions() + indices->num_dimensions() - 1) > 4);
+
     ARM_COMPUTE_RETURN_ERROR_ON(actual_axis >= input->num_dimensions());
     ARM_COMPUTE_RETURN_ERROR_ON(input->data_type() == DataType::UNKNOWN);
 
@@ -102,7 +102,9 @@ void CLGatherKernel::configure(const CLCompileContext &compile_context, const IC
     CLBuildOptions build_opts;
     build_opts.add_option("-DDATA_TYPE=" + get_cl_unsigned_type_from_element_size(data_size_from_type(input->info()->data_type())));
     build_opts.add_option("-DOUTPUT_DIM_Z=" + support::cpp11::to_string(output->info()->dimension(2)));
+    build_opts.add_option("-DINDICES_DIM_Z=" + support::cpp11::to_string(indices->info()->dimension(2)));
     build_opts.add_option("-DINPUT_DIM_Z=" + support::cpp11::to_string(input->info()->dimension(2)));
+    build_opts.add_option("-DINDICES_DIMS=" + support::cpp11::to_string(indices->info()->num_dimensions()));
     build_opts.add_option("-DAXIS=" + support::cpp11::to_string(_axis));
     build_opts.add_option("-DINDEX_LIMIT=" + support::cpp11::to_string(input->info()->tensor_shape()[_axis]));
 
@@ -127,7 +129,7 @@ void CLGatherKernel::run(const Window &window, cl::CommandQueue &queue)
     Window window_collapsed = window.collapse_if_possible(ICLKernel::window(), Window::DimZ);
     unsigned int idx = 0;
     add_4D_tensor_argument(idx, _input, window_collapsed);
-    add_1D_tensor_argument(idx, _indices, window_collapsed);
+    add_4D_tensor_argument(idx, _indices, window_collapsed);
     add_4D_tensor_argument(idx, _output, window_collapsed);
     enqueue(queue, *this, window_collapsed, lws_hint());
 }
diff --git a/tests/datasets/GatherDataset.h b/tests/datasets/GatherDataset.h
index 487ce19bc7..74ea3b4a06 100644
--- a/tests/datasets/GatherDataset.h
+++ b/tests/datasets/GatherDataset.h
@@ -126,6 +126,44 @@ public:
     }
 };
 
+class CLSmallGatherMultiDimIndicesDataset final : public GatherDataset
+{
+public:
+    CLSmallGatherMultiDimIndicesDataset()
+    {
+        add_config(TensorShape(2U, 6U), TensorShape(4U, 9U), 0);
+        add_config(TensorShape(15U, 15U), TensorShape(3U, 2U, 2U), 0);
+        add_config(TensorShape(15U, 15U), TensorShape(2U, 11U), 0);
+        add_config(TensorShape(5U, 3U, 4U), TensorShape(2U, 7U), 0);
+
+        add_config(TensorShape(3U, 5U), TensorShape(2U, 3U), 0);
+        add_config(TensorShape(9U), TensorShape(3U, 2U, 4U), 0);
+        add_config(TensorShape(5U, 3U, 4U), TensorShape(5U, 6U), 0);
+
+        add_config(TensorShape(7U, 4U, 5U), TensorShape(2U, 3U),0);
+
+        add_config(TensorShape(2U, 6U), TensorShape(4U, 9U), 1);
+        add_config(TensorShape(15U, 15U), TensorShape(3U, 2U, 2U), 1);
+        add_config(TensorShape(15U, 15U), TensorShape(2U, 11U), 1);
+        add_config(TensorShape(5U, 3U, 4U), TensorShape(2U, 7U), 1);
+
+        add_config(TensorShape(3U, 5U), TensorShape(2U, 3U), 1);
+        add_config(TensorShape(9U), TensorShape(3U, 2U, 4U), 1);
+        add_config(TensorShape(5U, 3U, 4U), TensorShape(5U, 6U), 1);
+
+        add_config(TensorShape(7U, 4U, 5U), TensorShape(2U, 3U),1);
+
+        add_config(TensorShape(2U, 6U), TensorShape(4U, 9U), 2);
+        add_config(TensorShape(15U, 15U), TensorShape(2U, 11U), 2);
+        add_config(TensorShape(5U, 3U, 4U), TensorShape(2U, 7U), 2);
+
+        add_config(TensorShape(3U, 5U), TensorShape(2U, 3U), 2);
+        add_config(TensorShape(5U, 3U, 4U), TensorShape(5U, 6U), 2);
+
+        add_config(TensorShape(7U, 4U, 5U), TensorShape(2U, 3U),2);
+    }
+};
+
 class SmallGatherDataset final : public GatherDataset
 {
 public:
diff --git a/tests/validation/CL/Gather.cpp b/tests/validation/CL/Gather.cpp
index f0b87d7d9f..7619baae1e 100644
--- a/tests/validation/CL/Gather.cpp
+++ b/tests/validation/CL/Gather.cpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2018-2020 Arm Limited.
+ * Copyright (c) 2018-2020, 2023 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -48,19 +48,21 @@ DATA_TEST_CASE(Validate, framework::DatasetMode::ALL, zip(zip(zip(zip(
         framework::dataset::make("InputInfo", { TensorInfo(TensorShape(27U, 27U), 1, DataType::F16),
                                                 TensorInfo(TensorShape(27U, 27U), 1, DataType::F32),
                                                 TensorInfo(TensorShape(27U, 27U), 1, DataType::F32),
-                                                TensorInfo(TensorShape(27U, 27U), 1, DataType::F32), // Invalid Indices data type
-                                                TensorInfo(TensorShape(27U, 27U), 1, DataType::F32), // Invalid Indices dimensionality
-                                                TensorInfo(TensorShape(5U, 5U, 5U, 5U, 5U), 1, DataType::F32), // Invalid Input dimensionality
-                                                TensorInfo(TensorShape(27U, 27U), 1, DataType::F16), // Mismatching data type input/output
-                                                TensorInfo(TensorShape(27U, 27U), 1, DataType::F32), // Invalid positive axis value
-                                                TensorInfo(TensorShape(27U, 27U), 1, DataType::F16), // Invalid negative axis value
+                                                TensorInfo(TensorShape(27U, 27U), 1, DataType::F32), // Invalid Output shape
+                                                TensorInfo(TensorShape(27U, 27U), 1, DataType::F32), // Invalid Indices data type
+                                                TensorInfo(TensorShape(27U, 27U), 1, DataType::F32), // Invalid Indices dimensionality
+                                                TensorInfo(TensorShape(5U, 5U, 5U, 5U, 5U), 1, DataType::F32), // Invalid Input dimensionality
+                                                TensorInfo(TensorShape(27U, 27U), 1, DataType::F16), // Mismatching data type input/output
+                                                TensorInfo(TensorShape(27U, 27U), 1, DataType::F32), // Invalid positive axis value
+                                                TensorInfo(TensorShape(27U, 27U), 1, DataType::F16), // Invalid negative axis value
         }),
         framework::dataset::make("IndicesInfo", {
                                                 TensorInfo(TensorShape(10U), 1, DataType::U32),
                                                 TensorInfo(TensorShape(10U), 1, DataType::U32),
                                                 TensorInfo(TensorShape(10U), 1, DataType::U32),
-                                                TensorInfo(TensorShape(10U), 1, DataType::U8),
                                                 TensorInfo(TensorShape(10U, 10U), 1, DataType::U32),
+                                                TensorInfo(TensorShape(10U), 1, DataType::U8),
+                                                TensorInfo(TensorShape(10U, 10U, 10U, 10U), 1, DataType::U32),
                                                 TensorInfo(TensorShape(10U), 1, DataType::U32),
                                                 TensorInfo(TensorShape(10U), 1, DataType::U32),
                                                 TensorInfo(TensorShape(10U), 1, DataType::U32),
@@ -71,7 +73,8 @@ DATA_TEST_CASE(Validate, framework::DatasetMode::ALL, zip(zip(zip(zip(
                                                 TensorInfo(TensorShape(27U, 10U), 1, DataType::F32),
                                                 TensorInfo(TensorShape(10U, 27U), 1, DataType::F32),
                                                 TensorInfo(TensorShape(10U, 27U), 1, DataType::F32),
-                                                TensorInfo(TensorShape(27U, 10U), 1, DataType::F32),
+                                                TensorInfo(TensorShape(10U, 27U), 1, DataType::F32),
+                                                TensorInfo(TensorShape(27U, 10U, 10U, 10U, 10U), 1, DataType::F32),
                                                 TensorInfo(TensorShape(10U, 5U, 5U, 5U, 5U), 1, DataType::F32),
                                                 TensorInfo(TensorShape(27U, 10U), 1, DataType::F32),
                                                 TensorInfo(TensorShape(27U, 27U), 1, DataType::F32),
@@ -82,13 +85,14 @@ DATA_TEST_CASE(Validate, framework::DatasetMode::ALL, zip(zip(zip(zip(
                                                 1,
                                                 -2,
                                                 0,
+                                                0,
                                                 1,
                                                 0,
                                                 1,
                                                 2,
                                                 -3,
         })),
-        framework::dataset::make("Expected", { true, true, true, false, false, false, false, false, false })),
+        framework::dataset::make("Expected", { true, true, true, false, false, false, false, false, false, false })),
         input_info, indices_info, output_info, axis, expected)
 {
     const Status status = CLGather::validate(&input_info.clone()->set_is_resizable(true), &indices_info.clone()->set_is_resizable(true), &output_info.clone()->set_is_resizable(true), axis);
@@ -111,6 +115,15 @@ FIXTURE_DATA_TEST_CASE(RunSmall,
     validate(CLAccessor(_target), _reference);
 }
 
+FIXTURE_DATA_TEST_CASE(RunSmallMultiDimIndices,
+                       CLGatherFixture<half>,
+                       framework::DatasetMode::PRECOMMIT,
+                       combine(datasets::CLSmallGatherMultiDimIndicesDataset(), framework::dataset::make("DataType", DataType::F16)))
+{
+    // Validate output
+    validate(CLAccessor(_target), _reference);
+}
+
 FIXTURE_DATA_TEST_CASE(RunLarge,
                        CLGatherFixture<half>,
                        framework::DatasetMode::NIGHTLY,
@@ -131,6 +144,15 @@ FIXTURE_DATA_TEST_CASE(RunSmall,
     validate(CLAccessor(_target), _reference);
 }
 
+FIXTURE_DATA_TEST_CASE(RunSmallMultiDimIndices,
+                       CLGatherFixture<float>,
+                       framework::DatasetMode::PRECOMMIT,
+                       combine(datasets::CLSmallGatherMultiDimIndicesDataset(), framework::dataset::make("DataType", DataType::F32)))
+{
+    // Validate output
+    validate(CLAccessor(_target), _reference);
+}
+
 FIXTURE_DATA_TEST_CASE(RunLarge,
                        CLGatherFixture<float>,
                        framework::DatasetMode::NIGHTLY,
@@ -152,6 +174,15 @@ FIXTURE_DATA_TEST_CASE(RunSmall,
     validate(CLAccessor(_target), _reference);
 }
 
+FIXTURE_DATA_TEST_CASE(RunSmallMultiDimIndices,
+                       CLGatherFixture<uint8_t>,
+                       framework::DatasetMode::PRECOMMIT,
+                       combine(datasets::CLSmallGatherMultiDimIndicesDataset(), framework::dataset::make("DataType", DataType::U8)))
+{
+    // Validate output
+    validate(CLAccessor(_target), _reference);
+}
+
 FIXTURE_DATA_TEST_CASE(RunLarge,
                        CLGatherFixture<uint8_t>,
                        framework::DatasetMode::NIGHTLY,
@@ -172,6 +203,16 @@ FIXTURE_DATA_TEST_CASE(RunSmall,
     validate(CLAccessor(_target), _reference);
 }
 
+FIXTURE_DATA_TEST_CASE(RunSmallMultiDimIndices,
+                       CLGatherFixture<uint16_t>,
+                       framework::DatasetMode::PRECOMMIT,
+                       combine(datasets::CLSmallGatherMultiDimIndicesDataset(), framework::dataset::make("DataType", DataType::U16)))
+{
+    // Validate output
+    validate(CLAccessor(_target), _reference);
+}
+
+
 FIXTURE_DATA_TEST_CASE(RunLarge,
                        CLGatherFixture<uint16_t>,
                        framework::DatasetMode::NIGHTLY,
diff --git a/tests/validation/fixtures/GatherFixture.h b/tests/validation/fixtures/GatherFixture.h
index f6f70023b9..b28f93d850 100644
--- a/tests/validation/fixtures/GatherFixture.h
+++ b/tests/validation/fixtures/GatherFixture.h
@@ -69,10 +69,9 @@ protected:
 
         // 10% of the time the index is out-of-range.
         uint32_t max_index = input_shape[actual_axis] + input_shape[actual_axis] / 9 + 1;
-
         std::uniform_int_distribution<uint32_t> dist_index(0, max_index - 1);
-        //Let's consider 1D indices
-        for(unsigned int ind = 0; ind < indices_shape[0]; ind++)
+
+        for(unsigned int ind = 0; ind < indices_shape.total_size(); ind++)
         {
             indices_ptr[ind] = dist_index(gen);
         }
-- 
cgit v1.2.1