aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
author: Omar Al Khatib <omar.alkhatib@arm.com> 2023-04-26 11:31:45 +0100
committer: Omar Al Khatib <omar.alkhatib@arm.com> 2023-05-03 13:22:48 +0000
commit: cdd1e039ad598aec10d8c1b81e08de9412324bf2 (patch)
tree: 344bfa6dc1e30604c6e67533eccb08a71e235fde
parent: 911d5728fccdabbdf41549c58f0266e49c2aeaf0 (diff)
download: ComputeLibrary-cdd1e039ad598aec10d8c1b81e08de9412324bf2.tar.gz
Support multi-dimensional indices in the CL Gather Layer up to four-dimensional output tensors
Resolves [COMPMID-5775]
Signed-off-by: Omar Al Khatib <omar.alkhatib@arm.com>
Change-Id: I6f6c12ac08f0b0ad070ca5d715c531c2c3762c30
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/9498
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Viet-Hoa Do <viet-hoa.do@arm.com>
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Benchmark: Arm Jenkins <bsgcomp@arm.com>
-rw-r--r-- src/core/CL/cl_kernels/common/gather.cl | 56
-rw-r--r-- src/core/CL/kernels/CLGatherKernel.cpp | 8
-rw-r--r-- tests/datasets/GatherDataset.h | 38
-rw-r--r-- tests/validation/CL/Gather.cpp | 61
-rw-r--r-- tests/validation/fixtures/GatherFixture.h | 5
5 files changed, 142 insertions, 26 deletions
diff --git a/src/core/CL/cl_kernels/common/gather.cl b/src/core/CL/cl_kernels/common/gather.cl
index a47c8a7bb7..5d180f3781 100644
--- a/src/core/CL/cl_kernels/common/gather.cl
+++ b/src/core/CL/cl_kernels/common/gather.cl
@@ -59,34 +59,70 @@
*/
__kernel void gather(
TENSOR4D_DECLARATION(input),
- VECTOR_DECLARATION(indices),
+ TENSOR4D_DECLARATION(indices),
TENSOR4D_DECLARATION(output))
{
const int px = get_global_id(0);
const int py = get_global_id(1);
const int pz = get_global_id(2) % OUTPUT_DIM_Z;
- const int pw = get_global_id(2) / OUTPUT_DIM_Z;
+ const int pw = (get_global_id(2) / OUTPUT_DIM_Z );
const Tensor4D input = CONVERT_TO_TENSOR4D_STRUCT_NO_STEP(input, INPUT_DIM_Z);
- const Vector indices = CONVERT_TO_VECTOR_STRUCT_NO_STEP(indices);
+ const Tensor4D indices = CONVERT_TO_TENSOR4D_STRUCT_NO_STEP(indices, INDICES_DIM_Z);
Tensor4D output = CONVERT_TO_TENSOR4D_STRUCT(output, OUTPUT_DIM_Z);
#if AXIS == 0
- const uint index = *(__global const uint *)vector_offset(&indices, px);
+#if INDICES_DIMS == 1
+ const uint index = *(__global const uint *)tensor4D_offset(&indices, px, 0, 0, 0);
const uint safe_index = select((uint)0, index, index < INDEX_LIMIT);
__global const uchar *input_addr = tensor4D_offset(&input, safe_index, py, pz, pw);
+#elif INDICES_DIMS == 2
+ const uint index = *(__global const uint *)tensor4D_offset(&indices, px, py, 0, 0);
+ const uint safe_index = select((uint)0, index, index < INDEX_LIMIT);
+ __global const uchar *input_addr = tensor4D_offset(&input, safe_index, pz, pw, 0);
+#elif INDICES_DIMS == 3
+ const uint index = *(__global const uint *)tensor4D_offset(&indices, px, py, pz, 0);
+ const uint safe_index = select((uint)0, index, index < INDEX_LIMIT);
+ __global const uchar *input_addr = tensor4D_offset(&input, safe_index, pw, 0, 0);
+#elif INDICES_DIMS == 4
+ const uint index = *(__global const uint *)tensor4D_offset(&indices, px, py, pz, pw);
+ const uint safe_index = select((uint)0, index, index < INDEX_LIMIT);
+ __global const uchar *input_addr = tensor4D_offset(&input, safe_index, 0, 0, 0);
+#endif //INDICES_DIMS
+
#elif AXIS == 1
- const uint index = *(__global const uint *)vector_offset(&indices, py);
+#if INDICES_DIMS == 1
+ const uint index = *(__global const uint *)tensor4D_offset(&indices, py, 0, 0, 0);
+ const uint safe_index = select((uint)0, index, index < INDEX_LIMIT);
+ __global const uchar *input_addr = tensor4D_offset(&input, px, safe_index, pz, pw);
+#elif INDICES_DIMS == 2
+ const uint index = *(__global const uint *)tensor4D_offset(&indices, py, pz, 0, 0);
const uint safe_index = select((uint)0, index, index < INDEX_LIMIT);
- __global const uchar *input_addr = tensor4D_offset(&input, px, safe_index, pz, pw);
+ __global const uchar *input_addr = tensor4D_offset(&input, px, safe_index, pw, 0);
+#elif INDICES_DIMS == 3
+ const uint index = *(__global const uint *)tensor4D_offset(&indices, py, pz, pw, 0);
+ const uint safe_index = select((uint)0, index, index < INDEX_LIMIT);
+ __global const uchar *input_addr = tensor4D_offset(&input, px, safe_index, 0, 0);
+#endif //INDICES_DIMS
+
#elif AXIS == 2
- const uint index = *(__global const uint *)vector_offset(&indices, pz);
+#if INDICES_DIMS == 1
+ const uint index = *(__global const uint *)tensor4D_offset(&indices, pz, 0, 0, 0);
+ const uint safe_index = select((uint)0, index, index < INDEX_LIMIT);
+ __global const uchar *input_addr = tensor4D_offset(&input, px, py, safe_index, pw);
+#elif INDICES_DIMS == 2
+ const uint index = *(__global const uint *)tensor4D_offset(&indices, pz, pw, 0, 0);
const uint safe_index = select((uint)0, index, index < INDEX_LIMIT);
- __global const uchar *input_addr = tensor4D_offset(&input, px, py, safe_index, pw);
+ __global const uchar *input_addr = tensor4D_offset(&input, px, py, safe_index, 0);
+#endif //INDICES_DIMS
+
#elif AXIS == 3
- const uint index = *(__global const uint *)vector_offset(&indices, pw);
+#if INDICES_DIMS == 1
+ const uint index = *(__global const uint *)tensor4D_offset(&indices, pw, 0, 0, 0);
const uint safe_index = select((uint)0, index, index < INDEX_LIMIT);
- __global const uchar *input_addr = tensor4D_offset(&input, px, py, pz, safe_index);
+ __global const uchar *input_addr = tensor4D_offset(&input, px, py, pz, safe_index);
+#endif //INDICES_DIMS
+
#endif //AXIS
*(__global DATA_TYPE *)output.ptr = select((DATA_TYPE)0, *((__global const DATA_TYPE *)input_addr), (DATA_TYPE)(index < INDEX_LIMIT));
diff --git a/src/core/CL/kernels/CLGatherKernel.cpp b/src/core/CL/kernels/CLGatherKernel.cpp
index 31a9a3bba4..5495023b80 100644
--- a/src/core/CL/kernels/CLGatherKernel.cpp
+++ b/src/core/CL/kernels/CLGatherKernel.cpp
@@ -38,8 +38,8 @@ inline Status validate_arguments(const ITensorInfo *input, const ITensorInfo *in
{
ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input, indices, output);
const uint32_t actual_axis = wrap_around(axis, static_cast<int>(input->num_dimensions()));
- ARM_COMPUTE_RETURN_ERROR_ON(indices->num_dimensions() > 1);
- ARM_COMPUTE_RETURN_ERROR_ON(input->num_dimensions() > 4);
+ ARM_COMPUTE_RETURN_ERROR_ON((input->num_dimensions() + indices->num_dimensions() - 1) > 4);
+
ARM_COMPUTE_RETURN_ERROR_ON(actual_axis >= input->num_dimensions());
ARM_COMPUTE_RETURN_ERROR_ON(input->data_type() == DataType::UNKNOWN);
@@ -102,7 +102,9 @@ void CLGatherKernel::configure(const CLCompileContext &compile_context, const IC
CLBuildOptions build_opts;
build_opts.add_option("-DDATA_TYPE=" + get_cl_unsigned_type_from_element_size(data_size_from_type(input->info()->data_type())));
build_opts.add_option("-DOUTPUT_DIM_Z=" + support::cpp11::to_string(output->info()->dimension(2)));
+ build_opts.add_option("-DINDICES_DIM_Z=" + support::cpp11::to_string(indices->info()->dimension(2)));
build_opts.add_option("-DINPUT_DIM_Z=" + support::cpp11::to_string(input->info()->dimension(2)));
+ build_opts.add_option("-DINDICES_DIMS=" + support::cpp11::to_string(indices->info()->num_dimensions()));
build_opts.add_option("-DAXIS=" + support::cpp11::to_string(_axis));
build_opts.add_option("-DINDEX_LIMIT=" + support::cpp11::to_string(input->info()->tensor_shape()[_axis]));
@@ -127,7 +129,7 @@ void CLGatherKernel::run(const Window &window, cl::CommandQueue &queue)
Window window_collapsed = window.collapse_if_possible(ICLKernel::window(), Window::DimZ);
unsigned int idx = 0;
add_4D_tensor_argument(idx, _input, window_collapsed);
- add_1D_tensor_argument(idx, _indices, window_collapsed);
+ add_4D_tensor_argument(idx, _indices, window_collapsed);
add_4D_tensor_argument(idx, _output, window_collapsed);
enqueue(queue, *this, window_collapsed, lws_hint());
}
diff --git a/tests/datasets/GatherDataset.h b/tests/datasets/GatherDataset.h
index 487ce19bc7..74ea3b4a06 100644
--- a/tests/datasets/GatherDataset.h
+++ b/tests/datasets/GatherDataset.h
@@ -126,6 +126,44 @@ public:
}
};
+class CLSmallGatherMultiDimIndicesDataset final : public GatherDataset
+{
+public:
+ CLSmallGatherMultiDimIndicesDataset()
+ {
+ add_config(TensorShape(2U, 6U), TensorShape(4U, 9U), 0);
+ add_config(TensorShape(15U, 15U), TensorShape(3U, 2U, 2U), 0);
+ add_config(TensorShape(15U, 15U), TensorShape(2U, 11U), 0);
+ add_config(TensorShape(5U, 3U, 4U), TensorShape(2U, 7U), 0);
+
+ add_config(TensorShape(3U, 5U), TensorShape(2U, 3U), 0);
+ add_config(TensorShape(9U), TensorShape(3U, 2U, 4U), 0);
+ add_config(TensorShape(5U, 3U, 4U), TensorShape(5U, 6U), 0);
+
+ add_config(TensorShape(7U, 4U, 5U), TensorShape(2U, 3U),0);
+
+ add_config(TensorShape(2U, 6U), TensorShape(4U, 9U), 1);
+ add_config(TensorShape(15U, 15U), TensorShape(3U, 2U, 2U), 1);
+ add_config(TensorShape(15U, 15U), TensorShape(2U, 11U), 1);
+ add_config(TensorShape(5U, 3U, 4U), TensorShape(2U, 7U), 1);
+
+ add_config(TensorShape(3U, 5U), TensorShape(2U, 3U), 1);
+ add_config(TensorShape(9U), TensorShape(3U, 2U, 4U), 1);
+ add_config(TensorShape(5U, 3U, 4U), TensorShape(5U, 6U), 1);
+
+ add_config(TensorShape(7U, 4U, 5U), TensorShape(2U, 3U),1);
+
+ add_config(TensorShape(2U, 6U), TensorShape(4U, 9U), 2);
+ add_config(TensorShape(15U, 15U), TensorShape(2U, 11U), 2);
+ add_config(TensorShape(5U, 3U, 4U), TensorShape(2U, 7U), 2);
+
+ add_config(TensorShape(3U, 5U), TensorShape(2U, 3U), 2);
+ add_config(TensorShape(5U, 3U, 4U), TensorShape(5U, 6U), 2);
+
+ add_config(TensorShape(7U, 4U, 5U), TensorShape(2U, 3U),2);
+ }
+};
+
class SmallGatherDataset final : public GatherDataset
{
public:
diff --git a/tests/validation/CL/Gather.cpp b/tests/validation/CL/Gather.cpp
index f0b87d7d9f..7619baae1e 100644
--- a/tests/validation/CL/Gather.cpp
+++ b/tests/validation/CL/Gather.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2018-2020 Arm Limited.
+ * Copyright (c) 2018-2020, 2023 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -48,19 +48,21 @@ DATA_TEST_CASE(Validate, framework::DatasetMode::ALL, zip(zip(zip(zip(
framework::dataset::make("InputInfo", { TensorInfo(TensorShape(27U, 27U), 1, DataType::F16),
TensorInfo(TensorShape(27U, 27U), 1, DataType::F32),
TensorInfo(TensorShape(27U, 27U), 1, DataType::F32),
- TensorInfo(TensorShape(27U, 27U), 1, DataType::F32), // Invalid Indices data type
- TensorInfo(TensorShape(27U, 27U), 1, DataType::F32), // Invalid Indices dimensionality
- TensorInfo(TensorShape(5U, 5U, 5U, 5U, 5U), 1, DataType::F32), // Invalid Input dimensionality
- TensorInfo(TensorShape(27U, 27U), 1, DataType::F16), // Mismatching data type input/output
- TensorInfo(TensorShape(27U, 27U), 1, DataType::F32), // Invalid positive axis value
- TensorInfo(TensorShape(27U, 27U), 1, DataType::F16), // Invalid negative axis value
+ TensorInfo(TensorShape(27U, 27U), 1, DataType::F32), // Invalid Output shape
+ TensorInfo(TensorShape(27U, 27U), 1, DataType::F32), // Invalid Indices data type
+ TensorInfo(TensorShape(27U, 27U), 1, DataType::F32), // Invalid Indices dimensionality
+ TensorInfo(TensorShape(5U, 5U, 5U, 5U, 5U), 1, DataType::F32), // Invalid Input dimensionality
+ TensorInfo(TensorShape(27U, 27U), 1, DataType::F16), // Mismatching data type input/output
+ TensorInfo(TensorShape(27U, 27U), 1, DataType::F32), // Invalid positive axis value
+ TensorInfo(TensorShape(27U, 27U), 1, DataType::F16), // Invalid negative axis value
}),
framework::dataset::make("IndicesInfo", {
TensorInfo(TensorShape(10U), 1, DataType::U32),
TensorInfo(TensorShape(10U), 1, DataType::U32),
TensorInfo(TensorShape(10U), 1, DataType::U32),
- TensorInfo(TensorShape(10U), 1, DataType::U8),
TensorInfo(TensorShape(10U, 10U), 1, DataType::U32),
+ TensorInfo(TensorShape(10U), 1, DataType::U8),
+ TensorInfo(TensorShape(10U, 10U, 10U, 10U), 1, DataType::U32),
TensorInfo(TensorShape(10U), 1, DataType::U32),
TensorInfo(TensorShape(10U), 1, DataType::U32),
TensorInfo(TensorShape(10U), 1, DataType::U32),
@@ -71,7 +73,8 @@ DATA_TEST_CASE(Validate, framework::DatasetMode::ALL, zip(zip(zip(zip(
TensorInfo(TensorShape(27U, 10U), 1, DataType::F32),
TensorInfo(TensorShape(10U, 27U), 1, DataType::F32),
TensorInfo(TensorShape(10U, 27U), 1, DataType::F32),
- TensorInfo(TensorShape(27U, 10U), 1, DataType::F32),
+ TensorInfo(TensorShape(10U, 27U), 1, DataType::F32),
+ TensorInfo(TensorShape(27U, 10U, 10U, 10U, 10U), 1, DataType::F32),
TensorInfo(TensorShape(10U, 5U, 5U, 5U, 5U), 1, DataType::F32),
TensorInfo(TensorShape(27U, 10U), 1, DataType::F32),
TensorInfo(TensorShape(27U, 27U), 1, DataType::F32),
@@ -82,13 +85,14 @@ DATA_TEST_CASE(Validate, framework::DatasetMode::ALL, zip(zip(zip(zip(
1,
-2,
0,
+ 0,
1,
0,
1,
2,
-3,
})),
- framework::dataset::make("Expected", { true, true, true, false, false, false, false, false, false })),
+ framework::dataset::make("Expected", { true, true, true, false, false, false, false, false, false, false })),
input_info, indices_info, output_info, axis, expected)
{
const Status status = CLGather::validate(&input_info.clone()->set_is_resizable(true), &indices_info.clone()->set_is_resizable(true), &output_info.clone()->set_is_resizable(true), axis);
@@ -111,6 +115,15 @@ FIXTURE_DATA_TEST_CASE(RunSmall,
validate(CLAccessor(_target), _reference);
}
+FIXTURE_DATA_TEST_CASE(RunSmallMultiDimIndices,
+ CLGatherFixture<half>,
+ framework::DatasetMode::PRECOMMIT,
+ combine(datasets::CLSmallGatherMultiDimIndicesDataset(), framework::dataset::make("DataType", DataType::F16)))
+{
+ // Validate output
+ validate(CLAccessor(_target), _reference);
+}
+
FIXTURE_DATA_TEST_CASE(RunLarge,
CLGatherFixture<half>,
framework::DatasetMode::NIGHTLY,
@@ -131,6 +144,15 @@ FIXTURE_DATA_TEST_CASE(RunSmall,
validate(CLAccessor(_target), _reference);
}
+FIXTURE_DATA_TEST_CASE(RunSmallMultiDimIndices,
+ CLGatherFixture<float>,
+ framework::DatasetMode::PRECOMMIT,
+ combine(datasets::CLSmallGatherMultiDimIndicesDataset(), framework::dataset::make("DataType", DataType::F32)))
+{
+ // Validate output
+ validate(CLAccessor(_target), _reference);
+}
+
FIXTURE_DATA_TEST_CASE(RunLarge,
CLGatherFixture<float>,
framework::DatasetMode::NIGHTLY,
@@ -152,6 +174,15 @@ FIXTURE_DATA_TEST_CASE(RunSmall,
validate(CLAccessor(_target), _reference);
}
+FIXTURE_DATA_TEST_CASE(RunSmallMultiDimIndices,
+ CLGatherFixture<uint8_t>,
+ framework::DatasetMode::PRECOMMIT,
+ combine(datasets::CLSmallGatherMultiDimIndicesDataset(), framework::dataset::make("DataType", DataType::U8)))
+{
+ // Validate output
+ validate(CLAccessor(_target), _reference);
+}
+
FIXTURE_DATA_TEST_CASE(RunLarge,
CLGatherFixture<uint8_t>,
framework::DatasetMode::NIGHTLY,
@@ -172,6 +203,16 @@ FIXTURE_DATA_TEST_CASE(RunSmall,
validate(CLAccessor(_target), _reference);
}
+FIXTURE_DATA_TEST_CASE(RunSmallMultiDimIndices,
+ CLGatherFixture<uint16_t>,
+ framework::DatasetMode::PRECOMMIT,
+ combine(datasets::CLSmallGatherMultiDimIndicesDataset(), framework::dataset::make("DataType", DataType::U16)))
+{
+ // Validate output
+ validate(CLAccessor(_target), _reference);
+}
+
+
FIXTURE_DATA_TEST_CASE(RunLarge,
CLGatherFixture<uint16_t>,
framework::DatasetMode::NIGHTLY,
diff --git a/tests/validation/fixtures/GatherFixture.h b/tests/validation/fixtures/GatherFixture.h
index f6f70023b9..b28f93d850 100644
--- a/tests/validation/fixtures/GatherFixture.h
+++ b/tests/validation/fixtures/GatherFixture.h
@@ -69,10 +69,9 @@ protected:
// 10% of the time the index is out-of-range.
uint32_t max_index = input_shape[actual_axis] + input_shape[actual_axis] / 9 + 1;
-
std::uniform_int_distribution<uint32_t> dist_index(0, max_index - 1);
- //Let's consider 1D indices
- for(unsigned int ind = 0; ind < indices_shape[0]; ind++)
+
+ for(unsigned int ind = 0; ind < indices_shape.total_size(); ind++)
{
indices_ptr[ind] = dist_index(gen);
}