From a23183221e5ba2c02863d3aa673da224ba42e364 Mon Sep 17 00:00:00 2001 From: Omar Al Khatib Date: Wed, 4 Jan 2023 15:43:17 +0000 Subject: Improve the strided_slice layer on all data types Resolves : [COMPMID-5110] Signed-off-by: Omar Al Khatib Change-Id: I3889a79c311b697c56d7369305c862433e856487 Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/8903 Comments-Addressed: Arm Jenkins Tested-by: Arm Jenkins Reviewed-by: Gian Marco Iodice Benchmark: Arm Jenkins --- src/core/NEON/kernels/NEStridedSliceKernel.cpp | 115 ++++++++++++++----------- 1 file changed, 67 insertions(+), 48 deletions(-) diff --git a/src/core/NEON/kernels/NEStridedSliceKernel.cpp b/src/core/NEON/kernels/NEStridedSliceKernel.cpp index 1d71339257..2b406a8b8b 100644 --- a/src/core/NEON/kernels/NEStridedSliceKernel.cpp +++ b/src/core/NEON/kernels/NEStridedSliceKernel.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018-2021 Arm Limited. + * Copyright (c) 2018-2021, 2023 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -86,43 +86,6 @@ std::pair validate_and_configure_window(const ITensorInfo *input return std::make_pair(Status{}, win); } - -void strided_slice_generic(const ITensor *input, ITensor *output, - const Coordinates &starts, const BiStrides &strides, int32_t shrink_axis_mask, - const Window &window) -{ - Iterator output_it(output, window); - const size_t width_size = input->info()->element_size(); - - const bool is_shrink_w = arm_compute::helpers::bit_ops::is_bit_set(shrink_axis_mask, 0); - const bool is_shrink_h = arm_compute::helpers::bit_ops::is_bit_set(shrink_axis_mask, 1); - const bool is_shrink_c = arm_compute::helpers::bit_ops::is_bit_set(shrink_axis_mask, 2); - const bool is_shrink_n = arm_compute::helpers::bit_ops::is_bit_set(shrink_axis_mask, 3); - - unsigned int index = 0; - const int idx_w = is_shrink_w ? 0 : index++; - const int idx_h = is_shrink_h ? 0 : index++; - const int idx_c = is_shrink_c ? 0 : index++; - const int idx_n = is_shrink_n ? 
0 : index; - - BiStrides shrinked_strides; - shrinked_strides.set(0, is_shrink_w ? 0 : strides[0]); - shrinked_strides.set(1, is_shrink_h ? 0 : strides[1]); - shrinked_strides.set(2, is_shrink_c ? 0 : strides[2]); - shrinked_strides.set(3, is_shrink_n ? 0 : strides[3]); - - execute_window_loop(window, [&](const Coordinates & id) - { - const int w_coord = starts[0] + (id[idx_w] * shrinked_strides[0]); - const int h_coord = starts[1] + (id[idx_h] * shrinked_strides[1]); - const int c_coord = starts[2] + (id[idx_c] * shrinked_strides[2]); - const int n_coord = starts[3] + (id[idx_n] * shrinked_strides[3]); - - Coordinates in_coords(w_coord, h_coord, c_coord, n_coord); - std::copy_n(input->ptr_to_element(in_coords), width_size, output_it.ptr()); - }, - output_it); -} } // namespace NEStridedSliceKernel::NEStridedSliceKernel() @@ -136,17 +99,13 @@ void NEStridedSliceKernel::configure(const ITensorInfo *input, ITensorInfo *outp { ARM_COMPUTE_ERROR_ON_NULLPTR(input, output); ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(input, output, starts, ends, strides, begin_mask, end_mask, shrink_axis_mask)); - - _shrink_mask = shrink_axis_mask; - + _shrink_mask = shrink_axis_mask; const TensorShape &input_shape = input->tensor_shape(); - - Coordinates ends_abs; + Coordinates ends_abs; std::tie(_starts_abs, ends_abs, _final_strides) = arm_compute::helpers::tensor_transform::calculate_strided_slice_coords( input_shape, starts, ends, strides, begin_mask, end_mask, shrink_axis_mask); - // Configure kernel window auto win_config = validate_and_configure_window(input, output, starts, ends, strides, begin_mask, end_mask, shrink_axis_mask); ARM_COMPUTE_ERROR_THROW_ON(win_config.first); @@ -171,9 +130,69 @@ void NEStridedSliceKernel::run_op(ITensorPack &tensors, const Window &window, co ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this); ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(INEKernel::window(), window); - // Dispatch kernel - 
strided_slice_generic(tensors.get_const_tensor(TensorType::ACL_SRC_0), - tensors.get_tensor(TensorType::ACL_DST), - _starts_abs, _final_strides, _shrink_mask, window); + const ITensor *input = tensors.get_const_tensor(TensorType::ACL_SRC_0); + const ITensor *output = tensors.get_tensor(TensorType::ACL_DST); + + size_t width_size = input->info()->element_size(); + + const bool is_shrink_x = arm_compute::helpers::bit_ops::is_bit_set(_shrink_mask, 0); + const bool is_shrink_y = arm_compute::helpers::bit_ops::is_bit_set(_shrink_mask, 1); + const bool is_shrink_z = arm_compute::helpers::bit_ops::is_bit_set(_shrink_mask, 2); + const bool is_shrink_w = arm_compute::helpers::bit_ops::is_bit_set(_shrink_mask, 3); + + unsigned int index = 0; + const int idx_x = is_shrink_x ? 0 : index++; + const int idx_y = is_shrink_y ? 0 : index++; + const int idx_z = is_shrink_z ? 0 : index++; + const int idx_w = is_shrink_w ? 0 : index; + + BiStrides shrinked_strides; + shrinked_strides.set(0, is_shrink_x ? 0 : _final_strides[0]); + shrinked_strides.set(1, is_shrink_y ? 0 : _final_strides[1]); + shrinked_strides.set(2, is_shrink_z ? 0 : _final_strides[2]); + shrinked_strides.set(3, is_shrink_w ? 
0 : _final_strides[3]); + + Window win = window; + + size_t length_x = win.shape()[0]; + + if(_final_strides[0] == 1 && !is_shrink_x) + { + win.set(Window::DimX, Window::Dimension(0, 1, 1)); + width_size = width_size * length_x; + } + + Iterator output_it(output, win); + + const int start_0 = _starts_abs[0]; + const int start_1 = _starts_abs[1]; + const int start_2 = _starts_abs[2]; + const int start_3 = _starts_abs[3]; + + const int shrinked_stride_0 = shrinked_strides[0]; + const int shrinked_stride_1 = shrinked_strides[1]; + const int shrinked_stride_2 = shrinked_strides[2]; + const int shrinked_stride_3 = shrinked_strides[3]; + + const int byte_increment_0 = static_cast<int>(input->info()->strides_in_bytes()[0]); + const int byte_increment_1 = static_cast<int>(input->info()->strides_in_bytes()[1]); + const int byte_increment_2 = static_cast<int>(input->info()->strides_in_bytes()[2]); + const int byte_increment_3 = static_cast<int>(input->info()->strides_in_bytes()[3]); + + uint8_t *input_base = input->ptr_to_element(Coordinates(0, 0, 0, 0)); + uint8_t *cur_ptr; + + execute_window_loop( + win, [&](const Coordinates & id) + { + cur_ptr = input_base; + cur_ptr += (start_0 + (id[idx_x] * shrinked_stride_0)) * byte_increment_0; + cur_ptr += (start_1 + (id[idx_y] * shrinked_stride_1)) * byte_increment_1; + cur_ptr += (start_2 + (id[idx_z] * shrinked_stride_2)) * byte_increment_2; + cur_ptr += (start_3 + (id[idx_w] * shrinked_stride_3)) * byte_increment_3; + + std::copy_n(cur_ptr, width_size, output_it.ptr()); + }, + output_it); } } // namespace arm_compute -- cgit v1.2.1