aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorViet-Hoa Do <viet-hoa.do@arm.com>2023-02-24 15:52:21 +0000
committerViet-Hoa Do <viet-hoa.do@arm.com>2023-03-08 15:09:25 +0000
commit37c989a58a04985dfdc21089c7dacc7e1925a4d0 (patch)
tree6e60ada38ceaf2b651cc44a481004abbb89ceae4
parent98aca0fda7f7c7c16bd2d1cf5386246ad796d9de (diff)
downloadComputeLibrary-37c989a58a04985dfdc21089c7dacc7e1925a4d0.tar.gz
Add support for arbitrary parameters for CPU Gather
* The shape of input and indices tensors, and the gather axis can be any number, as long as these are valid and the output tensor doesn't have more dimensions than the library supports. * Update the reference code to be more generic and straightforward. * Add necessary test cases. Signed-off-by: Viet-Hoa Do <viet-hoa.do@arm.com> Resolves: COMPMID-5919 Change-Id: Ic7e2032777aa97ecc147f61d5388528697508ab1 Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/9199 Tested-by: Arm Jenkins <bsgcomp@arm.com> Reviewed-by: Gunes Bayir <gunes.bayir@arm.com> Comments-Addressed: Arm Jenkins <bsgcomp@arm.com> Benchmark: Arm Jenkins <bsgcomp@arm.com>
-rw-r--r--arm_compute/core/Helpers.h23
-rw-r--r--arm_compute/core/Helpers.inl22
-rw-r--r--arm_compute/core/utils/misc/ShapeCalculator.h53
-rw-r--r--src/core/NEON/kernels/NEGatherKernel.cpp204
-rw-r--r--src/core/NEON/kernels/NEGatherKernel.h28
-rw-r--r--tests/datasets/GatherDataset.h9
-rw-r--r--tests/validation/reference/Gather.cpp74
7 files changed, 189 insertions, 224 deletions
diff --git a/arm_compute/core/Helpers.h b/arm_compute/core/Helpers.h
index fd6e94c079..f19e1e12e0 100644
--- a/arm_compute/core/Helpers.h
+++ b/arm_compute/core/Helpers.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2016-2021 Arm Limited.
+ * Copyright (c) 2016-2021, 2023 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -55,6 +55,16 @@ public:
*/
Iterator(const ITensor *tensor, const Window &window);
+ /** Create a container iterator for the tensor with the specified number of dimensions, stride, buffer pointer and window.
+ *
+ * @param[in] num_dims The number of dimensions.
+ * @param[in] strides The strides in bytes.
+ * @param[in] buffer The data buffer.
+ * @param[in] offset The offset in bytes from the beginning of the buffer to the first element of the tensor.
+ * @param[in] window The window which will be used to iterate over the tensor.
+ */
+ Iterator(size_t num_dims, const Strides &strides, uint8_t *buffer, size_t offset, const Window &window);
+
/** Increment the iterator along the specified dimension of the step value associated to the dimension.
*
* @warning It is the caller's responsibility to call increment(dimension+1) when reaching the end of a dimension, the iterator will not check for overflow.
@@ -86,6 +96,17 @@ public:
void reset(size_t dimension);
private:
+
+ /** Initialize a container iterator for the tensor with the specified number of dimensions, stride, buffer pointer and window.
+ *
+ * @param[in] num_dims The number of dimensions.
+ * @param[in] strides The strides in bytes.
+ * @param[in] buffer The data buffer.
+ * @param[in] offset The offset in bytes from the beginning of the buffer to the first element of the tensor.
+ * @param[in] window The window which will be used to iterate over the tensor.
+ */
+ void initialize(size_t num_dims, const Strides &strides, uint8_t *buffer, size_t offset, const Window &window);
+
uint8_t *_ptr;
class Dimension
diff --git a/arm_compute/core/Helpers.inl b/arm_compute/core/Helpers.inl
index a910521f94..ff902bba20 100644
--- a/arm_compute/core/Helpers.inl
+++ b/arm_compute/core/Helpers.inl
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2016-2021 Arm Limited.
+ * Copyright (c) 2016-2021, 2023 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -98,13 +98,23 @@ inline Iterator::Iterator(const ITensor *tensor, const Window &win)
ARM_COMPUTE_ERROR_ON(tensor == nullptr);
ARM_COMPUTE_ERROR_ON(tensor->info() == nullptr);
- const ITensorInfo *info = tensor->info();
- const Strides &strides = info->strides_in_bytes();
+ initialize(tensor->info()->num_dimensions(), tensor->info()->strides_in_bytes(), tensor->buffer(), tensor->info()->offset_first_element_in_bytes(), win);
+}
+
+inline Iterator::Iterator(size_t num_dims, const Strides &strides, uint8_t *buffer, size_t offset, const Window &win)
+ : Iterator()
+{
+ initialize(num_dims, strides, buffer, offset, win);
+}
+
+inline void Iterator::initialize(size_t num_dims, const Strides &strides, uint8_t *buffer, size_t offset, const Window &win)
+{
+ ARM_COMPUTE_ERROR_ON(buffer == nullptr);
- _ptr = tensor->buffer() + info->offset_first_element_in_bytes();
+ _ptr = buffer + offset;
//Initialize the stride for each dimension and calculate the position of the first element of the iteration:
- for(unsigned int n = 0; n < info->num_dimensions(); ++n)
+ for(unsigned int n = 0; n < num_dims; ++n)
{
_dims[n]._stride = win[n].step() * strides[n];
std::get<0>(_dims)._dim_start += static_cast<size_t>(strides[n]) * win[n].start();
@@ -116,7 +126,7 @@ inline Iterator::Iterator(const ITensor *tensor, const Window &win)
_dims[n]._dim_start = std::get<0>(_dims)._dim_start;
}
- ARM_COMPUTE_ERROR_ON_WINDOW_DIMENSIONS_GTE(win, info->num_dimensions());
+ ARM_COMPUTE_ERROR_ON_WINDOW_DIMENSIONS_GTE(win, num_dims);
}
inline void Iterator::increment(const size_t dimension)
diff --git a/arm_compute/core/utils/misc/ShapeCalculator.h b/arm_compute/core/utils/misc/ShapeCalculator.h
index 9e7c981814..94bd3aca03 100644
--- a/arm_compute/core/utils/misc/ShapeCalculator.h
+++ b/arm_compute/core/utils/misc/ShapeCalculator.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017-2022 Arm Limited.
+ * Copyright (c) 2017-2023 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -1537,39 +1537,32 @@ inline TensorShape compute_pool3d_shape(const TensorShape &src, Pooling3dLayerIn
*/
inline TensorShape compute_gather_shape(const TensorShape &input_shape, const TensorShape &indices_shape, uint32_t actual_axis)
{
- ARM_COMPUTE_ERROR_ON(input_shape.num_dimensions() > 4);
- ARM_COMPUTE_ERROR_ON(actual_axis >= input_shape.num_dimensions());
- ARM_COMPUTE_ERROR_ON(indices_shape.num_dimensions() > 3);
- TensorShape output_shape = input_shape;
- if(indices_shape.num_dimensions() == 1u)
+ const auto input_num_dims = input_shape.num_dimensions();
+ const auto indices_num_dims = indices_shape.num_dimensions();
+
+ ARM_COMPUTE_ERROR_ON(actual_axis >= input_num_dims);
+ ARM_COMPUTE_ERROR_ON(input_num_dims + indices_num_dims - 1 > Coordinates::num_max_dimensions);
+
+ TensorShape output_shape;
+ size_t dim_no = 0;
+
+ for(; dim_no < actual_axis; ++dim_no)
{
- output_shape[actual_axis] = indices_shape[0];
+ output_shape.set(dim_no, input_shape[dim_no]);
}
- else
+
+ for(; dim_no < actual_axis + indices_num_dims; ++dim_no)
{
- const auto ind_num_dims
- {
- indices_shape.num_dimensions()
- };
- output_shape.shift_right(ind_num_dims - 1);
- switch(actual_axis)
- {
- case 1:
- {
- output_shape[0] = input_shape[0];
- for(size_t idx = 0; idx < ind_num_dims; ++idx)
- {
- output_shape.set(actual_axis + idx, indices_shape[idx], false);
- }
- break;
- }
- default:
- {
- // 2d and 3d indices are only supported for axis == 1
- ARM_COMPUTE_ERROR_ON(actual_axis != 1 && indices_shape.num_dimensions() > 1);
- }
- }
+ output_shape.set(dim_no, indices_shape[dim_no - actual_axis]);
+ }
+
+ for(; dim_no < input_num_dims + indices_num_dims - 1; ++dim_no)
+ {
+ output_shape.set(dim_no, input_shape[dim_no + 1 - indices_num_dims]);
}
+
+ ARM_COMPUTE_ERROR_ON(input_shape.total_size() * indices_shape.total_size() != output_shape.total_size() * input_shape[actual_axis]);
+
return output_shape;
}
} // namespace shape_calculator
diff --git a/src/core/NEON/kernels/NEGatherKernel.cpp b/src/core/NEON/kernels/NEGatherKernel.cpp
index 085ab7cb18..d361eb93fd 100644
--- a/src/core/NEON/kernels/NEGatherKernel.cpp
+++ b/src/core/NEON/kernels/NEGatherKernel.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2019-2022 Arm Limited.
+ * Copyright (c) 2019-2023 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -30,7 +30,6 @@
#include "arm_compute/core/Validate.h"
#include "arm_compute/core/Window.h"
#include "arm_compute/core/utils/misc/ShapeCalculator.h"
-#include "src/core/CPP/Validate.h"
#include "src/core/helpers/AutoConfiguration.h"
#include "src/core/helpers/WindowHelpers.h"
@@ -69,7 +68,7 @@ Status validate_arguments(const ITensorInfo *input, const ITensorInfo *indices,
}
ARM_COMPUTE_RETURN_ERROR_ON(0 > axis || axis >= static_cast<int32_t>(input->num_dimensions()));
- ARM_COMPUTE_RETURN_ERROR_ON(axis != 1 && indices->num_dimensions() > 1);
+ ARM_COMPUTE_RETURN_ERROR_ON(input->num_dimensions() + indices->num_dimensions() - 1 > Coordinates::num_max_dimensions);
ARM_COMPUTE_RETURN_ERROR_ON(input->data_type() == DataType::UNKNOWN);
if(output->total_size() != 0)
@@ -87,84 +86,55 @@ Status validate_arguments(const ITensorInfo *input, const ITensorInfo *indices,
} // namespace
NEGatherKernel::NEGatherKernel()
- : _input{}, _indices{}, _axis{}, _output{}, _func{}
+ : _input{}, _indices{}, _axis{}, _output{}, _func{}, _src_it_strides{}, _idx_it_strides{}
{
}
-template <typename U>
-inline void NEGatherKernel::gather_multiindices_1_axis(const Window &window, const ThreadInfo &info)
-{
- ARM_COMPUTE_UNUSED(info);
- ARM_COMPUTE_ERROR_ON(_indices->info()->num_dimensions() < 2 || _indices->info()->num_dimensions() > 3);
- validate_indices<U>(_indices);
- Window win = window;
- win.set(Window::DimX, Window::Dimension(0, 1, 1));
- execute_window_loop(win, [&](const Coordinates & id)
- {
- auto *dst_ptr = _output->ptr_to_element(id);
- Coordinates index_offset;
- for(uint32_t k = 0; k < _indices->info()->num_dimensions(); ++k)
- {
- index_offset.set(k, id[k + 1]);
- }
- const uint32_t row = *(reinterpret_cast<uint32_t *>(_indices->ptr_to_element(index_offset)));
- Coordinates src_offset;
- // Set up input coords to read the row specified by the current index
- src_offset.set(0, 0);
- src_offset.set(1, row);
- for(uint32_t j = 2; j < _input->info()->num_dimensions(); ++j)
- {
- src_offset.set(j, id[1 + _indices->info()->num_dimensions() + (j - 2)]);
- }
- const auto in_ptr_row = _input->ptr_to_element(src_offset);
- // Copy a row from input to output
- memcpy(dst_ptr, in_ptr_row, _input->info()->tensor_shape()[0] * _input->info()->element_size());
- });
-}
-
-template <typename U>
-inline void NEGatherKernel::gather_0_axis(const Window &window, const ThreadInfo &info)
+template <typename TIndex>
+void NEGatherKernel::gather_common(const Window &window, const ThreadInfo &info)
{
ARM_COMPUTE_UNUSED(info);
- // Validate that the indices are not negative
- validate_indices<U>(_indices);
-
- Iterator output_it(_output, window);
- execute_window_loop(window, [&](const Coordinates & id)
- {
- Coordinates gather_id(id);
+ auto dst_win = window;
- auto new_index = *(reinterpret_cast<U *>(_indices->ptr_to_element(Coordinates(id[0]))));
- gather_id.set(0, new_index);
+ const auto src_info = _input->info();
+ const auto idx_info = _indices->info();
+ const auto dst_info = _output->info();
- std::copy_n(_input->ptr_to_element(gather_id), _output->info()->element_size(), output_it.ptr());
- },
- output_it);
-}
+ const auto num_dims = dst_info->num_dimensions();
+ const auto chunk_stride = src_info->strides_in_bytes()[_axis];
-template <typename U>
-void NEGatherKernel::gather_n_axis(const Window &window, const ThreadInfo &info)
-{
- ARM_COMPUTE_UNUSED(info);
+ const auto window_start_x = window.x().start();
+ const auto window_end_x = window.x().end();
+ auto window_size_x = src_info->element_size();
- // Validate that the indices are not negative
- validate_indices<U>(_indices);
+ if(_axis != 0)
+ {
+ dst_win.set(0, Window::Dimension(window_start_x, window_start_x + 1, 1));
+ window_size_x *= window_end_x - window_start_x;
+ }
- Window output_window{ window };
- output_window.set(Window::DimX, Window::Dimension(0, 1, 1));
+ // Compute source and index tensors window based on the output window.
+ auto src_win = dst_win;
+ Window idx_win;
- Iterator output_it(_output, output_window);
- execute_window_loop(output_window, [&](const Coordinates & id)
+ for (size_t i = 0; i < idx_info->num_dimensions(); ++i)
{
- Coordinates gather_id(id);
+ src_win.set(_axis + i, Window::Dimension(0, 1, 1));
+ idx_win.set(_axis + i, window[_axis + i]);
+ }
- auto new_index = *(reinterpret_cast<U *>(_indices->ptr_to_element(Coordinates(id[_axis]))));
- gather_id.set(_axis, new_index);
+ // Use the custom strides to access all three tensors using the same loop.
+ Iterator src_it(num_dims, _src_it_strides, _input->buffer(), src_info->offset_first_element_in_bytes(), src_win);
+ Iterator idx_it(num_dims, _idx_it_strides, _indices->buffer(), idx_info->offset_first_element_in_bytes(), idx_win);
+ Iterator dst_it(num_dims, dst_info->strides_in_bytes(), _output->buffer(), dst_info->offset_first_element_in_bytes(), dst_win);
- std::copy_n(_input->ptr_to_element(gather_id), _input->info()->dimension(0) * _output->info()->element_size(), output_it.ptr());
- },
- output_it);
+ execute_window_loop(dst_win, [&](const Coordinates &) {
+ const auto idx = *reinterpret_cast<const TIndex *>(idx_it.ptr());
+ const auto src_ptr = src_it.ptr() + idx * chunk_stride;
+
+ std::copy_n(src_ptr, window_size_x, dst_it.ptr());
+ }, src_it, idx_it, dst_it);
}
void NEGatherKernel::configure(const ITensor *input, const ITensor *indices, ITensor *output, int axis)
@@ -183,60 +153,17 @@ void NEGatherKernel::configure(const ITensor *input, const ITensor *indices, ITe
}
ARM_COMPUTE_ERROR_ON(0 > _axis || _axis >= static_cast<int32_t>(input->info()->num_dimensions()));
- if(indices->info()->num_dimensions() == 1u)
+ switch(_indices->info()->data_type())
{
- if(_axis == 0)
- {
- switch(_indices->info()->data_type())
- {
- case DataType::U32:
- _func = &NEGatherKernel::gather_0_axis<uint32_t>;
- break;
- case DataType::S32:
- _func = &NEGatherKernel::gather_0_axis<int32_t>;
- break;
- default:
- ARM_COMPUTE_ERROR("Not supported");
- break;
- }
- }
- else
- {
- switch(_indices->info()->data_type())
- {
- case DataType::U32:
- _func = &NEGatherKernel::gather_n_axis<uint32_t>;
- break;
- case DataType::S32:
- _func = &NEGatherKernel::gather_n_axis<int32_t>;
- break;
- default:
- ARM_COMPUTE_ERROR("Not supported");
- break;
- }
- }
- }
- else
- {
- if(_axis == 1)
- {
- switch(_indices->info()->data_type())
- {
- case DataType::U32:
- _func = &NEGatherKernel::gather_multiindices_1_axis<uint32_t>;
- break;
- case DataType::S32:
- _func = &NEGatherKernel::gather_multiindices_1_axis<int32_t>;
- break;
- default:
- ARM_COMPUTE_ERROR("Not supported");
- break;
- }
- }
- else
- {
+ case DataType::U32:
+ _func = &NEGatherKernel::gather_common<uint32_t>;
+ break;
+ case DataType::S32:
+ _func = &NEGatherKernel::gather_common<int32_t>;
+ break;
+ default:
ARM_COMPUTE_ERROR("Not supported");
- }
+ break;
}
// Output auto initialization if not yet initialized
@@ -247,6 +174,32 @@ void NEGatherKernel::configure(const ITensor *input, const ITensor *indices, ITe
Window win = calculate_max_window(*output->info(), Steps());
INEKernel::configure(win);
+
+ // Create input and indices strides that have the same number of dimensions as the output tensor.
+ // These will be used to iterate lock-step through all tensors (input, indices and output).
+ size_t dim_no = 0;
+
+ const auto input_info = input->info();
+ const auto &input_strides = input_info->strides_in_bytes();
+
+ const auto indices_info = indices->info();
+ const auto &indices_strides = indices_info->strides_in_bytes();
+ const auto indices_num_dims = indices_info->num_dimensions();
+
+ for(; dim_no < static_cast<size_t>(_axis); ++dim_no)
+ {
+ _src_it_strides[dim_no] = input_strides[dim_no];
+ }
+
+ for(; dim_no < static_cast<size_t>(_axis) + indices_num_dims; ++dim_no)
+ {
+ _idx_it_strides[dim_no] = indices_strides[dim_no - _axis];
+ }
+
+ for(; dim_no < Coordinates::num_max_dimensions; ++dim_no)
+ {
+ _src_it_strides[dim_no] = input_strides[dim_no - indices_num_dims + 1];
+ }
}
Status NEGatherKernel::validate(const ITensorInfo *input, const ITensorInfo *indices, const ITensorInfo *output, int axis)
@@ -261,6 +214,21 @@ void NEGatherKernel::run(const Window &window, const ThreadInfo &info)
ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
ARM_COMPUTE_ERROR_ON(_func == nullptr);
+ switch(_indices->info()->data_type())
+ {
+ case DataType::U32:
+ validate_indices<uint32_t>(_indices);
+ break;
+
+ case DataType::S32:
+ validate_indices<int32_t>(_indices);
+ break;
+
+ default:
+ ARM_COMPUTE_ERROR("Not supported");
+ break;
+ }
+
(this->*_func)(window, info);
}
diff --git a/src/core/NEON/kernels/NEGatherKernel.h b/src/core/NEON/kernels/NEGatherKernel.h
index 3dc0cad7be..ce69daeda7 100644
--- a/src/core/NEON/kernels/NEGatherKernel.h
+++ b/src/core/NEON/kernels/NEGatherKernel.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2019-2022 Arm Limited.
+ * Copyright (c) 2019-2023 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -81,27 +81,8 @@ public:
void run(const Window &window, const ThreadInfo &info) override;
private:
- /** Implementation of the gather operation for 0 axis.
- *
- * For gather on the 0 axis an element by element copy is performed.
- *
- * @param[in] window Region on which to run the kernel. (Must be a region of the window returned by window())
- * @param[in] info Info about running thread and CPU.
- */
- template <typename U>
- void gather_0_axis(const Window &window, const ThreadInfo &info);
-
- template <typename U>
- void gather_multiindices_1_axis(const Window &window, const ThreadInfo &info);
- /** Implementation of the gather operation.
- *
- * For 1<=axis a row-wise copy is taking place.
- *
- * @param[in] window Region on which to run the kernel. (Must be a region of the window returned by window())
- * @param[in] info Info about running thread and CPU.
- */
- template <typename U>
- void gather_n_axis(const Window &window, const ThreadInfo &info);
+ template <typename TIndex>
+ void gather_common(const Window &window, const ThreadInfo &info);
using kernel_ptr = void (NEGatherKernel::*)(const Window &window, const ThreadInfo &info);
@@ -110,6 +91,9 @@ private:
int _axis;
ITensor *_output;
kernel_ptr _func;
+
+ Strides _src_it_strides;
+ Strides _idx_it_strides;
};
} // namespace arm_compute
#endif /* ARM_COMPUTE_NEGATHERKERNEL_H */
diff --git a/tests/datasets/GatherDataset.h b/tests/datasets/GatherDataset.h
index 8fec5441b1..487ce19bc7 100644
--- a/tests/datasets/GatherDataset.h
+++ b/tests/datasets/GatherDataset.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2018-2019, 2022 Arm Limited.
+ * Copyright (c) 2018-2019, 2022-2023 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -116,6 +116,13 @@ public:
add_config(TensorShape(15U, 15U), TensorShape(2U, 11U), 1);
add_config(TensorShape(5U, 3U, 4U), TensorShape(2U, 7U), 1);
add_config(TensorShape(1U, 5U, 3U), TensorShape(1U, 7U, 3U), 1);
+
+ add_config(TensorShape(3U, 5U), TensorShape(2U, 3U), 0);
+ add_config(TensorShape(9U), TensorShape(3U, 2U, 4U), 0);
+ add_config(TensorShape(5U, 3U, 4U), TensorShape(5U, 6U), 0);
+
+ add_config(TensorShape(7U, 4U, 5U), TensorShape(2U, 3U), 2);
+ add_config(TensorShape(8U, 2U, 3U), TensorShape(4U, 2U, 5U), 2);
}
};
diff --git a/tests/validation/reference/Gather.cpp b/tests/validation/reference/Gather.cpp
index 8de1a473eb..12d1a3cd3c 100644
--- a/tests/validation/reference/Gather.cpp
+++ b/tests/validation/reference/Gather.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2018-2019, 2022 Arm Limited.
+ * Copyright (c) 2018-2019, 2022-2023 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -39,61 +39,43 @@ namespace reference
template <typename T>
SimpleTensor<T> gather(const SimpleTensor<T> &src, const SimpleTensor<uint32_t> &indices, uint32_t actual_axis)
{
- const auto *indices_ptr = static_cast<const uint32_t *>(indices.data());
const TensorShape dst_shape = arm_compute::misc::shape_calculator::compute_gather_shape(src.shape(), indices.shape(), actual_axis);
SimpleTensor<T> dst(dst_shape, src.data_type());
+ const auto src_ptr = static_cast<const T *>(src.data());
+ const auto indices_ptr = static_cast<const uint32_t *>(indices.data());
+ const auto dst_ptr = static_cast<T *>(dst.data());
+
Window win;
win.use_tensor_dimensions(dst_shape);
- if(indices.shape().num_dimensions() == 1u)
- {
- execute_window_loop(win, [&](const Coordinates & id)
+
+ execute_window_loop(win, [&](const Coordinates &dst_coords) {
+ // Calculate the coordinates of the index value.
+ Coordinates idx_coords;
+
+ for(size_t i = 0; i < indices.shape().num_dimensions(); ++i)
{
- Coordinates offset;
- for(unsigned int dim = 0; dim < id.num_dimensions(); ++dim)
- {
- if(dim == actual_axis)
- {
- offset.set(dim, indices_ptr[id[dim]]);
- }
- else
- {
- offset.set(dim, id[dim]);
- }
- }
- *reinterpret_cast<T *>(dst(id)) = *reinterpret_cast<const T *>(src(offset));
- });
- }
- else
- {
- if(actual_axis == 1)
+ idx_coords.set(i, dst_coords[i + actual_axis]);
+ }
+
+ // Calculate the coordinates of the source data.
+ Coordinates src_coords;
+
+ for(size_t i = 0; i < actual_axis; ++i)
{
- win.set(Window::DimX, Window::Dimension(0, 1, 1));
- execute_window_loop(win, [&](const Coordinates & id)
- {
- auto *dst_ptr = dst(id);
- Coordinates index_offset;
- for(uint32_t k = 0; k < indices.shape().num_dimensions(); ++k)
- {
- index_offset.set(k, id[k + 1]);
- }
- const uint32_t row = *reinterpret_cast<const uint32_t *>(indices(index_offset));
- Coordinates src_offset;
- src_offset.set(0, 0);
- src_offset.set(1, row);
- for(uint32_t j = 2; j < src.shape().num_dimensions(); ++j)
- {
- src_offset.set(j, id[1 + indices.shape().num_dimensions() + (j - 2)]);
- }
- const auto in_ptr_row = src(src_offset);
- memcpy(dst_ptr, in_ptr_row, src.shape()[0] * src.element_size());
- });
+ src_coords.set(i, dst_coords[i]);
}
- else
+
+ src_coords.set(actual_axis, indices_ptr[coords2index(indices.shape(), idx_coords)]);
+
+ for(size_t i = actual_axis + 1; i < src.shape().num_dimensions(); ++i)
{
- ARM_COMPUTE_ERROR("Not implemented.");
+ src_coords.set(i, dst_coords[i + indices.shape().num_dimensions() - 1]);
}
- }
+
+ // Copy the data.
+ dst_ptr[coords2index(dst.shape(), dst_coords)] = src_ptr[coords2index(src.shape(), src_coords)];
+ });
return dst;
}