diff options
Diffstat (limited to 'src/core/NEON/kernels/NEGatherKernel.cpp')
-rw-r--r-- | src/core/NEON/kernels/NEGatherKernel.cpp | 80 |
1 files changed, 45 insertions, 35 deletions
diff --git a/src/core/NEON/kernels/NEGatherKernel.cpp b/src/core/NEON/kernels/NEGatherKernel.cpp index 11332ffac8..f1d457d399 100644 --- a/src/core/NEON/kernels/NEGatherKernel.cpp +++ b/src/core/NEON/kernels/NEGatherKernel.cpp @@ -27,9 +27,10 @@ #include "arm_compute/core/Error.h" #include "arm_compute/core/Helpers.h" #include "arm_compute/core/TensorInfo.h" +#include "arm_compute/core/utils/misc/ShapeCalculator.h" #include "arm_compute/core/Validate.h" #include "arm_compute/core/Window.h" -#include "arm_compute/core/utils/misc/ShapeCalculator.h" + #include "src/core/helpers/AutoConfiguration.h" #include "src/core/helpers/WindowHelpers.h" @@ -42,20 +43,22 @@ Status validate_arguments(const ITensorInfo *input, const ITensorInfo *indices, ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input, indices, output); ARM_COMPUTE_RETURN_ERROR_ON(input->num_dimensions() > 4); - if(axis < 0) + if (axis < 0) { axis += input->num_dimensions(); } ARM_COMPUTE_RETURN_ERROR_ON(0 > axis || axis >= static_cast<int32_t>(input->num_dimensions())); - ARM_COMPUTE_RETURN_ERROR_ON(input->num_dimensions() + indices->num_dimensions() - 1 > Coordinates::num_max_dimensions); + ARM_COMPUTE_RETURN_ERROR_ON(input->num_dimensions() + indices->num_dimensions() - 1 > + Coordinates::num_max_dimensions); ARM_COMPUTE_RETURN_ERROR_ON(input->data_type() == DataType::UNKNOWN); - if(output->total_size() != 0) + if (output->total_size() != 0) { ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, output); ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_QUANTIZATION_INFO(input, output); - TensorShape output_shape = arm_compute::misc::shape_calculator::compute_gather_shape(input->tensor_shape(), indices->tensor_shape(), axis); + TensorShape output_shape = arm_compute::misc::shape_calculator::compute_gather_shape( + input->tensor_shape(), indices->tensor_shape(), axis); ARM_COMPUTE_RETURN_ERROR_ON(output_shape.total_size() != output->tensor_shape().total_size()); } @@ -81,23 +84,23 @@ void NEGatherKernel::gather_common(const Window &window, const ThreadInfo &info) const auto idx_info = _indices->info(); const auto dst_info = _output->info(); - const auto num_dims = dst_info->num_dimensions(); + const auto num_dims = dst_info->num_dimensions(); const auto chunk_stride = src_info->strides_in_bytes()[_axis]; const auto window_start_x = window.x().start(); - const auto window_end_x = window.x().end(); - auto window_size_x = src_info->element_size(); + const auto window_end_x = window.x().end(); + auto window_size_x = src_info->element_size(); const auto idx_limit = static_cast<TIndex>(src_info->tensor_shape()[_axis]); - if(_axis != 0) + if (_axis != 0) { dst_win.set(0, Window::Dimension(window_start_x, window_start_x + 1, 1)); window_size_x *= window_end_x - window_start_x; } // Compute source and index tensors window based on the output window. - auto src_win = dst_win; + auto src_win = dst_win; Window idx_win; for (size_t i = 0; i < idx_info->num_dimensions(); ++i) @@ -109,22 +112,27 @@ void NEGatherKernel::gather_common(const Window &window, const ThreadInfo &info) // Use the custom strides to access all three tensors using the same loop. Iterator src_it(num_dims, _src_it_strides, _input->buffer(), src_info->offset_first_element_in_bytes(), src_win); Iterator idx_it(num_dims, _idx_it_strides, _indices->buffer(), idx_info->offset_first_element_in_bytes(), idx_win); - Iterator dst_it(num_dims, dst_info->strides_in_bytes(), _output->buffer(), dst_info->offset_first_element_in_bytes(), dst_win); - - execute_window_loop(dst_win, [&](const Coordinates &) { - const auto idx = *reinterpret_cast<const TIndex *>(idx_it.ptr()); - - if(idx >= 0 && idx < idx_limit) - { - const auto src_ptr = src_it.ptr() + idx * chunk_stride; + Iterator dst_it(num_dims, dst_info->strides_in_bytes(), _output->buffer(), + dst_info->offset_first_element_in_bytes(), dst_win); - std::copy_n(src_ptr, window_size_x, dst_it.ptr()); - } - else + execute_window_loop( + dst_win, + [&](const Coordinates &) { - std::fill_n(dst_it.ptr(), window_size_x, 0); - } - }, src_it, idx_it, dst_it); + const auto idx = *reinterpret_cast<const TIndex *>(idx_it.ptr()); + + if (idx >= 0 && idx < idx_limit) + { + const auto src_ptr = src_it.ptr() + idx * chunk_stride; + + std::copy_n(src_ptr, window_size_x, dst_it.ptr()); + } + else + { + std::fill_n(dst_it.ptr(), window_size_x, 0); + } + }, + src_it, idx_it, dst_it); } void NEGatherKernel::configure(const ITensor *input, const ITensor *indices, ITensor *output, int axis) @@ -137,13 +145,13 @@ void NEGatherKernel::configure(const ITensor *input, const ITensor *indices, ITe _output = output; _axis = axis; - if(_axis < 0) + if (_axis < 0) { _axis += input->info()->num_dimensions(); } ARM_COMPUTE_ERROR_ON(0 > _axis || _axis >= static_cast<int32_t>(input->info()->num_dimensions())); - switch(_indices->info()->data_type()) + switch (_indices->info()->data_type()) { case DataType::U32: _func = &NEGatherKernel::gather_common<uint32_t>; @@ -157,7 +165,8 @@ void NEGatherKernel::configure(const ITensor *input, const ITensor *indices, ITe } // Output auto initialization if not yet initialized - const TensorShape output_shape = arm_compute::misc::shape_calculator::compute_gather_shape(input->info()->tensor_shape(), indices->info()->tensor_shape(), _axis); + const TensorShape output_shape = arm_compute::misc::shape_calculator::compute_gather_shape( + input->info()->tensor_shape(), indices->info()->tensor_shape(), _axis); auto_init_if_empty(*output->info(), input->info()->clone()->set_tensor_shape(output_shape)); // Create window @@ -169,30 +178,31 @@ void NEGatherKernel::configure(const ITensor *input, const ITensor *indices, ITe // These will be used to iterate lock-step through all tensors (input, indices and output). size_t dim_no = 0; - const auto input_info = input->info(); + const auto input_info = input->info(); const auto &input_strides = input_info->strides_in_bytes(); - const auto indices_info = indices->info(); - const auto &indices_strides = indices_info->strides_in_bytes(); - const auto indices_num_dims = indices_info->num_dimensions(); + const auto indices_info = indices->info(); + const auto &indices_strides = indices_info->strides_in_bytes(); + const auto indices_num_dims = indices_info->num_dimensions(); - for(; dim_no < static_cast<size_t>(_axis); ++dim_no) + for (; dim_no < static_cast<size_t>(_axis); ++dim_no) { _src_it_strides[dim_no] = input_strides[dim_no]; } - for(; dim_no < static_cast<size_t>(_axis) + indices_num_dims; ++dim_no) + for (; dim_no < static_cast<size_t>(_axis) + indices_num_dims; ++dim_no) { _idx_it_strides[dim_no] = indices_strides[dim_no - _axis]; } - for(; dim_no < Coordinates::num_max_dimensions; ++dim_no) + for (; dim_no < Coordinates::num_max_dimensions; ++dim_no) { _src_it_strides[dim_no] = input_strides[dim_no - indices_num_dims + 1]; } } -Status NEGatherKernel::validate(const ITensorInfo *input, const ITensorInfo *indices, const ITensorInfo *output, int axis) +Status +NEGatherKernel::validate(const ITensorInfo *input, const ITensorInfo *indices, const ITensorInfo *output, int axis) { ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(input, indices, output, axis)); return Status{}; |