diff options
Diffstat (limited to 'src/core/NEON/kernels/NEGatherKernel.cpp')
-rw-r--r-- | src/core/NEON/kernels/NEGatherKernel.cpp | 204 |
1 files changed, 86 insertions, 118 deletions
diff --git a/src/core/NEON/kernels/NEGatherKernel.cpp b/src/core/NEON/kernels/NEGatherKernel.cpp index 085ab7cb18..d361eb93fd 100644 --- a/src/core/NEON/kernels/NEGatherKernel.cpp +++ b/src/core/NEON/kernels/NEGatherKernel.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2022 Arm Limited. + * Copyright (c) 2019-2023 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -30,7 +30,6 @@ #include "arm_compute/core/Validate.h" #include "arm_compute/core/Window.h" #include "arm_compute/core/utils/misc/ShapeCalculator.h" -#include "src/core/CPP/Validate.h" #include "src/core/helpers/AutoConfiguration.h" #include "src/core/helpers/WindowHelpers.h" @@ -69,7 +68,7 @@ Status validate_arguments(const ITensorInfo *input, const ITensorInfo *indices, } ARM_COMPUTE_RETURN_ERROR_ON(0 > axis || axis >= static_cast<int32_t>(input->num_dimensions())); - ARM_COMPUTE_RETURN_ERROR_ON(axis != 1 && indices->num_dimensions() > 1); + ARM_COMPUTE_RETURN_ERROR_ON(input->num_dimensions() + indices->num_dimensions() - 1 > Coordinates::num_max_dimensions); ARM_COMPUTE_RETURN_ERROR_ON(input->data_type() == DataType::UNKNOWN); if(output->total_size() != 0) @@ -87,84 +86,55 @@ Status validate_arguments(const ITensorInfo *input, const ITensorInfo *indices, } // namespace NEGatherKernel::NEGatherKernel() - : _input{}, _indices{}, _axis{}, _output{}, _func{} + : _input{}, _indices{}, _axis{}, _output{}, _func{}, _src_it_strides{}, _idx_it_strides{} { } -template <typename U> -inline void NEGatherKernel::gather_multiindices_1_axis(const Window &window, const ThreadInfo &info) -{ - ARM_COMPUTE_UNUSED(info); - ARM_COMPUTE_ERROR_ON(_indices->info()->num_dimensions() < 2 || _indices->info()->num_dimensions() > 3); - validate_indices<U>(_indices); - Window win = window; - win.set(Window::DimX, Window::Dimension(0, 1, 1)); - execute_window_loop(win, [&](const Coordinates & id) - { - auto *dst_ptr = _output->ptr_to_element(id); - Coordinates index_offset; - for(uint32_t k = 0; k < _indices->info()->num_dimensions(); ++k) - { - index_offset.set(k, id[k + 1]); - } - const uint32_t row = *(reinterpret_cast<uint32_t *>(_indices->ptr_to_element(index_offset))); - Coordinates src_offset; - // Set up input coords to read the row specified by the current index - src_offset.set(0, 0); - src_offset.set(1, row); - for(uint32_t j = 2; j < _input->info()->num_dimensions(); ++j) - { - src_offset.set(j, id[1 + _indices->info()->num_dimensions() + (j - 2)]); - } - const auto in_ptr_row = _input->ptr_to_element(src_offset); - // Copy a row from input to output - memcpy(dst_ptr, in_ptr_row, _input->info()->tensor_shape()[0] * _input->info()->element_size()); - }); -} - -template <typename U> -inline void NEGatherKernel::gather_0_axis(const Window &window, const ThreadInfo &info) +template <typename TIndex> +void NEGatherKernel::gather_common(const Window &window, const ThreadInfo &info) { ARM_COMPUTE_UNUSED(info); - // Validate that the indices are not negative - validate_indices<U>(_indices); - - Iterator output_it(_output, window); - execute_window_loop(window, [&](const Coordinates & id) - { - Coordinates gather_id(id); + auto dst_win = window; - auto new_index = *(reinterpret_cast<U *>(_indices->ptr_to_element(Coordinates(id[0])))); - gather_id.set(0, new_index); + const auto src_info = _input->info(); + const auto idx_info = _indices->info(); + const auto dst_info = _output->info(); - std::copy_n(_input->ptr_to_element(gather_id), _output->info()->element_size(), output_it.ptr()); - }, - output_it); -} + const auto num_dims = dst_info->num_dimensions(); + const auto chunk_stride = src_info->strides_in_bytes()[_axis]; -template <typename U> -void NEGatherKernel::gather_n_axis(const Window &window, const ThreadInfo &info) -{ - ARM_COMPUTE_UNUSED(info); + const auto window_start_x = window.x().start(); + const auto window_end_x = window.x().end(); + auto window_size_x = src_info->element_size(); - // Validate that the indices are not negative - validate_indices<U>(_indices); + if(_axis != 0) + { + dst_win.set(0, Window::Dimension(window_start_x, window_start_x + 1, 1)); + window_size_x *= window_end_x - window_start_x; + } - Window output_window{ window }; - output_window.set(Window::DimX, Window::Dimension(0, 1, 1)); + // Compute source and index tensors window based on the output window. + auto src_win = dst_win; + Window idx_win; - Iterator output_it(_output, output_window); - execute_window_loop(output_window, [&](const Coordinates & id) + for (size_t i = 0; i < idx_info->num_dimensions(); ++i) { - Coordinates gather_id(id); + src_win.set(_axis + i, Window::Dimension(0, 1, 1)); + idx_win.set(_axis + i, window[_axis + i]); + } - auto new_index = *(reinterpret_cast<U *>(_indices->ptr_to_element(Coordinates(id[_axis])))); - gather_id.set(_axis, new_index); + // Use the custom strides to access all three tensors using the same loop. + Iterator src_it(num_dims, _src_it_strides, _input->buffer(), src_info->offset_first_element_in_bytes(), src_win); + Iterator idx_it(num_dims, _idx_it_strides, _indices->buffer(), idx_info->offset_first_element_in_bytes(), idx_win); + Iterator dst_it(num_dims, dst_info->strides_in_bytes(), _output->buffer(), dst_info->offset_first_element_in_bytes(), dst_win); - std::copy_n(_input->ptr_to_element(gather_id), _input->info()->dimension(0) * _output->info()->element_size(), output_it.ptr()); - }, - output_it); + execute_window_loop(dst_win, [&](const Coordinates &) { + const auto idx = *reinterpret_cast<const TIndex *>(idx_it.ptr()); + const auto src_ptr = src_it.ptr() + idx * chunk_stride; + + std::copy_n(src_ptr, window_size_x, dst_it.ptr()); + }, src_it, idx_it, dst_it); } void NEGatherKernel::configure(const ITensor *input, const ITensor *indices, ITensor *output, int axis) @@ -183,60 +153,17 @@ void NEGatherKernel::configure(const ITensor *input, const ITensor *indices, ITe } ARM_COMPUTE_ERROR_ON(0 > _axis || _axis >= static_cast<int32_t>(input->info()->num_dimensions())); - if(indices->info()->num_dimensions() == 1u) + switch(_indices->info()->data_type()) { - if(_axis == 0) - { - switch(_indices->info()->data_type()) - { - case DataType::U32: - _func = &NEGatherKernel::gather_0_axis<uint32_t>; - break; - case DataType::S32: - _func = &NEGatherKernel::gather_0_axis<int32_t>; - break; - default: - ARM_COMPUTE_ERROR("Not supported"); - break; - } - } - else - { - switch(_indices->info()->data_type()) - { - case DataType::U32: - _func = &NEGatherKernel::gather_n_axis<uint32_t>; - break; - case DataType::S32: - _func = &NEGatherKernel::gather_n_axis<int32_t>; - break; - default: - ARM_COMPUTE_ERROR("Not supported"); - break; - } - } - } - else - { - if(_axis == 1) - { - switch(_indices->info()->data_type()) - { - case DataType::U32: - _func = &NEGatherKernel::gather_multiindices_1_axis<uint32_t>; - break; - case DataType::S32: - _func = &NEGatherKernel::gather_multiindices_1_axis<int32_t>; - break; - default: - ARM_COMPUTE_ERROR("Not supported"); - break; - } - } - else - { + case DataType::U32: + _func = &NEGatherKernel::gather_common<uint32_t>; + break; + case DataType::S32: + _func = &NEGatherKernel::gather_common<int32_t>; + break; + default: ARM_COMPUTE_ERROR("Not supported"); - } + break; } // Output auto initialization if not yet initialized @@ -247,6 +174,32 @@ void NEGatherKernel::configure(const ITensor *input, const ITensor *indices, ITe Window win = calculate_max_window(*output->info(), Steps()); INEKernel::configure(win); + + // Create input and indices strides that have the same number of dimensions as the output tensor. + // These will be used to iterate lock-step through all tensors (input, indices and output). + size_t dim_no = 0; + + const auto input_info = input->info(); + const auto &input_strides = input_info->strides_in_bytes(); + + const auto indices_info = indices->info(); + const auto &indices_strides = indices_info->strides_in_bytes(); + const auto indices_num_dims = indices_info->num_dimensions(); + + for(; dim_no < static_cast<size_t>(_axis); ++dim_no) + { + _src_it_strides[dim_no] = input_strides[dim_no]; + } + + for(; dim_no < static_cast<size_t>(_axis) + indices_num_dims; ++dim_no) + { + _idx_it_strides[dim_no] = indices_strides[dim_no - _axis]; + } + + for(; dim_no < Coordinates::num_max_dimensions; ++dim_no) + { + _src_it_strides[dim_no] = input_strides[dim_no - indices_num_dims + 1]; + } } Status NEGatherKernel::validate(const ITensorInfo *input, const ITensorInfo *indices, const ITensorInfo *output, int axis) @@ -261,6 +214,21 @@ void NEGatherKernel::run(const Window &window, const ThreadInfo &info) ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this); ARM_COMPUTE_ERROR_ON(_func == nullptr); + switch(_indices->info()->data_type()) + { + case DataType::U32: + validate_indices<uint32_t>(_indices); + break; + + case DataType::S32: + validate_indices<int32_t>(_indices); + break; + + default: + ARM_COMPUTE_ERROR("Not supported"); + break; + } + (this->*_func)(window, info); } |