aboutsummaryrefslogtreecommitdiff
path: root/src/core/NEON/kernels/NEGatherKernel.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/core/NEON/kernels/NEGatherKernel.cpp')
-rw-r--r--src/core/NEON/kernels/NEGatherKernel.cpp183
1 files changed, 99 insertions, 84 deletions
diff --git a/src/core/NEON/kernels/NEGatherKernel.cpp b/src/core/NEON/kernels/NEGatherKernel.cpp
index 7090da8015..f1d457d399 100644
--- a/src/core/NEON/kernels/NEGatherKernel.cpp
+++ b/src/core/NEON/kernels/NEGatherKernel.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2019-2021 Arm Limited.
+ * Copyright (c) 2019-2023 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -27,10 +27,10 @@
#include "arm_compute/core/Error.h"
#include "arm_compute/core/Helpers.h"
#include "arm_compute/core/TensorInfo.h"
+#include "arm_compute/core/utils/misc/ShapeCalculator.h"
#include "arm_compute/core/Validate.h"
#include "arm_compute/core/Window.h"
-#include "arm_compute/core/utils/misc/ShapeCalculator.h"
-#include "src/core/CPP/Validate.h"
+
#include "src/core/helpers/AutoConfiguration.h"
#include "src/core/helpers/WindowHelpers.h"
@@ -38,40 +38,27 @@ namespace arm_compute
{
namespace
{
-/** Validate the indices
- *
- * Validate that indices are not negative
- *
- * @param[in] indices Indices tensor info.
- */
-template <typename U>
-void validate_indices(const ITensor *indices)
-{
- for(size_t i = 0; i < indices->info()->tensor_shape()[0]; ++i)
- {
- ARM_COMPUTE_ERROR_ON(*(reinterpret_cast<U *>(indices->ptr_to_element(Coordinates(i)))) < 0);
- }
-}
-
Status validate_arguments(const ITensorInfo *input, const ITensorInfo *indices, const ITensorInfo *output, int axis)
{
ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input, indices, output);
- ARM_COMPUTE_RETURN_ERROR_ON(indices->num_dimensions() > 1);
ARM_COMPUTE_RETURN_ERROR_ON(input->num_dimensions() > 4);
- if(axis < 0)
+ if (axis < 0)
{
axis += input->num_dimensions();
}
ARM_COMPUTE_RETURN_ERROR_ON(0 > axis || axis >= static_cast<int32_t>(input->num_dimensions()));
+ ARM_COMPUTE_RETURN_ERROR_ON(input->num_dimensions() + indices->num_dimensions() - 1 >
+ Coordinates::num_max_dimensions);
ARM_COMPUTE_RETURN_ERROR_ON(input->data_type() == DataType::UNKNOWN);
- if(output->total_size() != 0)
+ if (output->total_size() != 0)
{
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, output);
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_QUANTIZATION_INFO(input, output);
- TensorShape output_shape = arm_compute::misc::shape_calculator::compute_gather_shape(input->tensor_shape(), indices->tensor_shape(), axis);
+ TensorShape output_shape = arm_compute::misc::shape_calculator::compute_gather_shape(
+ input->tensor_shape(), indices->tensor_shape(), axis);
ARM_COMPUTE_RETURN_ERROR_ON(output_shape.total_size() != output->tensor_shape().total_size());
}
@@ -82,53 +69,70 @@ Status validate_arguments(const ITensorInfo *input, const ITensorInfo *indices,
} // namespace
NEGatherKernel::NEGatherKernel()
- : _input{}, _indices{}, _axis{}, _output{}, _func{}
+ : _input{}, _indices{}, _axis{}, _output{}, _func{}, _src_it_strides{}, _idx_it_strides{}
{
}
-template <typename U>
-inline void NEGatherKernel::gather_0_axis(const Window &window, const ThreadInfo &info)
+template <typename TIndex>
+void NEGatherKernel::gather_common(const Window &window, const ThreadInfo &info)
{
ARM_COMPUTE_UNUSED(info);
- // Validate that the indices are not negative
- validate_indices<U>(_indices);
+ auto dst_win = window;
- Iterator output_it(_output, window);
- execute_window_loop(window, [&](const Coordinates & id)
- {
- Coordinates gather_id(id);
+ const auto src_info = _input->info();
+ const auto idx_info = _indices->info();
+ const auto dst_info = _output->info();
- auto new_index = *(reinterpret_cast<U *>(_indices->ptr_to_element(Coordinates(id[0]))));
- gather_id.set(0, new_index);
+ const auto num_dims = dst_info->num_dimensions();
+ const auto chunk_stride = src_info->strides_in_bytes()[_axis];
- std::copy_n(_input->ptr_to_element(gather_id), _output->info()->element_size(), output_it.ptr());
- },
- output_it);
-}
+ const auto window_start_x = window.x().start();
+ const auto window_end_x = window.x().end();
+ auto window_size_x = src_info->element_size();
-template <typename U>
-void NEGatherKernel::gather_n_axis(const Window &window, const ThreadInfo &info)
-{
- ARM_COMPUTE_UNUSED(info);
+ const auto idx_limit = static_cast<TIndex>(src_info->tensor_shape()[_axis]);
- // Validate that the indices are not negative
- validate_indices<U>(_indices);
+ if (_axis != 0)
+ {
+ dst_win.set(0, Window::Dimension(window_start_x, window_start_x + 1, 1));
+ window_size_x *= window_end_x - window_start_x;
+ }
- Window output_window{ window };
- output_window.set(Window::DimX, Window::Dimension(0, 1, 1));
+ // Compute source and index tensors window based on the output window.
+ auto src_win = dst_win;
+ Window idx_win;
- Iterator output_it(_output, output_window);
- execute_window_loop(output_window, [&](const Coordinates & id)
+ for (size_t i = 0; i < idx_info->num_dimensions(); ++i)
{
- Coordinates gather_id(id);
+ src_win.set(_axis + i, Window::Dimension(0, 1, 1));
+ idx_win.set(_axis + i, window[_axis + i]);
+ }
- auto new_index = *(reinterpret_cast<U *>(_indices->ptr_to_element(Coordinates(id[_axis]))));
- gather_id.set(_axis, new_index);
+ // Use the custom strides to access all three tensors using the same loop.
+ Iterator src_it(num_dims, _src_it_strides, _input->buffer(), src_info->offset_first_element_in_bytes(), src_win);
+ Iterator idx_it(num_dims, _idx_it_strides, _indices->buffer(), idx_info->offset_first_element_in_bytes(), idx_win);
+ Iterator dst_it(num_dims, dst_info->strides_in_bytes(), _output->buffer(),
+ dst_info->offset_first_element_in_bytes(), dst_win);
- std::copy_n(_input->ptr_to_element(gather_id), _input->info()->dimension(0) * _output->info()->element_size(), output_it.ptr());
- },
- output_it);
+ execute_window_loop(
+ dst_win,
+ [&](const Coordinates &)
+ {
+ const auto idx = *reinterpret_cast<const TIndex *>(idx_it.ptr());
+
+ if (idx >= 0 && idx < idx_limit)
+ {
+ const auto src_ptr = src_it.ptr() + idx * chunk_stride;
+
+ std::copy_n(src_ptr, window_size_x, dst_it.ptr());
+ }
+ else
+ {
+ std::fill_n(dst_it.ptr(), window_size_x, 0);
+ }
+ },
+ src_it, idx_it, dst_it);
}
void NEGatherKernel::configure(const ITensor *input, const ITensor *indices, ITensor *output, int axis)
@@ -141,53 +145,64 @@ void NEGatherKernel::configure(const ITensor *input, const ITensor *indices, ITe
_output = output;
_axis = axis;
- if(_axis < 0)
+ if (_axis < 0)
{
_axis += input->info()->num_dimensions();
}
ARM_COMPUTE_ERROR_ON(0 > _axis || _axis >= static_cast<int32_t>(input->info()->num_dimensions()));
- if(0 == _axis)
+ switch (_indices->info()->data_type())
{
- switch(_indices->info()->data_type())
- {
- case DataType::U32:
- _func = &NEGatherKernel::gather_0_axis<uint32_t>;
- break;
- case DataType::S32:
- _func = &NEGatherKernel::gather_0_axis<int32_t>;
- break;
- default:
- ARM_COMPUTE_ERROR("Not supported");
- break;
- }
- }
- else
- {
- switch(_indices->info()->data_type())
- {
- case DataType::U32:
- _func = &NEGatherKernel::gather_n_axis<uint32_t>;
- break;
- case DataType::S32:
- _func = &NEGatherKernel::gather_n_axis<int32_t>;
- break;
- default:
- ARM_COMPUTE_ERROR("Not supported");
- break;
- }
+ case DataType::U32:
+ _func = &NEGatherKernel::gather_common<uint32_t>;
+ break;
+ case DataType::S32:
+ _func = &NEGatherKernel::gather_common<int32_t>;
+ break;
+ default:
+ ARM_COMPUTE_ERROR("Not supported");
+ break;
}
+
// Output auto initialization if not yet initialized
- TensorShape output_shape = arm_compute::misc::shape_calculator::compute_gather_shape(input->info()->tensor_shape(), indices->info()->tensor_shape(), _axis);
+ const TensorShape output_shape = arm_compute::misc::shape_calculator::compute_gather_shape(
+ input->info()->tensor_shape(), indices->info()->tensor_shape(), _axis);
auto_init_if_empty(*output->info(), input->info()->clone()->set_tensor_shape(output_shape));
// Create window
Window win = calculate_max_window(*output->info(), Steps());
INEKernel::configure(win);
+
+ // Create input and indices strides that have the same number of dimensions as the output tensor.
+ // These will be used to iterate lock-step through all tensors (input, indices and output).
+ size_t dim_no = 0;
+
+ const auto input_info = input->info();
+ const auto &input_strides = input_info->strides_in_bytes();
+
+ const auto indices_info = indices->info();
+ const auto &indices_strides = indices_info->strides_in_bytes();
+ const auto indices_num_dims = indices_info->num_dimensions();
+
+ for (; dim_no < static_cast<size_t>(_axis); ++dim_no)
+ {
+ _src_it_strides[dim_no] = input_strides[dim_no];
+ }
+
+ for (; dim_no < static_cast<size_t>(_axis) + indices_num_dims; ++dim_no)
+ {
+ _idx_it_strides[dim_no] = indices_strides[dim_no - _axis];
+ }
+
+ for (; dim_no < Coordinates::num_max_dimensions; ++dim_no)
+ {
+ _src_it_strides[dim_no] = input_strides[dim_no - indices_num_dims + 1];
+ }
}
-Status NEGatherKernel::validate(const ITensorInfo *input, const ITensorInfo *indices, const ITensorInfo *output, int axis)
+Status
+NEGatherKernel::validate(const ITensorInfo *input, const ITensorInfo *indices, const ITensorInfo *output, int axis)
{
ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(input, indices, output, axis));
return Status{};