diff options
Diffstat (limited to 'src/core/CL/kernels')
-rw-r--r-- | src/core/CL/kernels/CLStridedSliceKernel.cpp | 21 |
1 files changed, 18 insertions, 3 deletions
diff --git a/src/core/CL/kernels/CLStridedSliceKernel.cpp b/src/core/CL/kernels/CLStridedSliceKernel.cpp index f07436ac60..2d2ba103e5 100644 --- a/src/core/CL/kernels/CLStridedSliceKernel.cpp +++ b/src/core/CL/kernels/CLStridedSliceKernel.cpp @@ -55,10 +55,10 @@ Status validate_arguments(const ITensorInfo *input, const ITensorInfo *output, ARM_COMPUTE_RETURN_ERROR_ON(starts.num_dimensions() > input->num_dimensions()); ARM_COMPUTE_RETURN_ERROR_ON(ends.num_dimensions() > input->num_dimensions()); ARM_COMPUTE_RETURN_ERROR_ON(strides.num_dimensions() > input->num_dimensions()); - for(unsigned int i = 0; i < strides.num_dimensions(); ++i) + ARM_COMPUTE_RETURN_ERROR_ON(std::any_of(strides.cbegin(), strides.cbegin() + strides.num_dimensions(), [](int i) { - ARM_COMPUTE_RETURN_ERROR_ON(strides[i] == 0); - } + return i == 0; + })); // Get expected output shape const TensorShape exp_output_shape = arm_compute::misc::shape_calculator::compute_strided_slice_shape(*input, @@ -120,6 +120,19 @@ void CLStridedSliceKernel::configure(const ICLTensor *input, ICLTensor *output, // Configure kernel window auto win_config = validate_and_configure_window(input->info(), output->info(), starts, ends, strides, begin_mask, end_mask, shrink_axis_mask); ARM_COMPUTE_ERROR_THROW_ON(win_config.first); + + // Enable multiple elements processing along x if stride_x is 1 and output width greater than the access vector size + const int vec_size_x = 16 / input->info()->element_size(); + const int output_width_x = output->info()->tensor_shape().x(); + const bool multi_access_x = (final_strides.x() == 1) && (output_width_x / vec_size_x > 0); + + // Update window if needed + if(multi_access_x) + { + Window &updated_window = std::get<1>(win_config); + updated_window.set(Window::DimX, + Window::Dimension(updated_window.x().start(), ceil_to_multiple(updated_window.x().end(), vec_size_x), vec_size_x)); + } ICLKernel::configure_internal(win_config.second); // Create build options @@ -130,6 +143,8 @@ void CLStridedSliceKernel::configure(const ICLTensor *input, ICLTensor *output, build_opts.add_option("-DSTART_" + support::cpp11::to_string(i) + "=" + support::cpp11::to_string(starts_abs[i])); build_opts.add_option("-DSTRIDE_" + support::cpp11::to_string(i) + "=" + support::cpp11::to_string(final_strides[i])); } + build_opts.add_option_if(multi_access_x, "-DLAST_ACCESSED_X=" + support::cpp11::to_string(std::max<int>(output_width_x - vec_size_x, 0))); + build_opts.add_option_if(multi_access_x, "-DVEC_SIZE=" + support::cpp11::to_string(vec_size_x)); build_opts.add_option_if_else(input_shape.num_dimensions() > 2, "-DSRC_DEPTH=" + support::cpp11::to_string(input_shape.z()), "-DSRC_DEPTH=1"); |