diff options
Diffstat (limited to 'src/core/CL')
-rw-r--r-- | src/core/CL/cl_kernels/slice_ops.cl | 36 | ||||
-rw-r--r-- | src/core/CL/kernels/CLStridedSliceKernel.cpp | 14 |
2 files changed, 42 insertions, 8 deletions
diff --git a/src/core/CL/cl_kernels/slice_ops.cl b/src/core/CL/cl_kernels/slice_ops.cl index bc3df47345..97decee6fc 100644 --- a/src/core/CL/cl_kernels/slice_ops.cl +++ b/src/core/CL/cl_kernels/slice_ops.cl @@ -64,7 +64,9 @@ __kernel void strided_slice( int offset = 0; // Offset X -#if defined(START_0) && defined(STRIDE_0) && defined(VEC_SIZE) && defined(LAST_ACCESSED_X) +#if defined(SHRINK_0) + input.ptr += (int)START_0 * input_stride_x; +#elif defined(START_0) && defined(STRIDE_0) && defined(VEC_SIZE) && defined(LAST_ACCESSED_X) // Check if access on width gets out of bounds // If it does shift access vector to access elements within bounds const int xi = (int)(get_global_id(0) * VEC_SIZE); @@ -77,20 +79,46 @@ __kernel void strided_slice( #endif // defined(START_0) && defined(STRIDE_0) // Offset Y -#if defined(START_1) && defined(STRIDE_1) +#if defined(SHRINK_1) + input.ptr += (int)START_1 * input_stride_y; +#elif defined(START_1) && defined(STRIDE_1) +#if defined(SHRINK_0) + offset = (int)START_1 + (int)get_global_id(0) * (int)STRIDE_1; +#else // defined(SHRINK_0) offset = (int)START_1 + (int)get_global_id(1) * (int)STRIDE_1; +#endif // defined(SHRINK_0) input.ptr += offset * input_stride_y; #endif // defined(START_1) && defined(STRIDE_1) // Offset Z -#if defined(START_2) && defined(STRIDE_2) +#if defined(SHRINK_2) + input.ptr += (int)START_2 * input_stride_z; +#elif defined(START_2) && defined(STRIDE_2) + +#if defined(SHRINK_1) && defined(SHRINK_0) + offset = (int)START_2 + (int)get_global_id(0) * (int)STRIDE_2; +#elif defined(SHRINK_1) || defined(SHRINK_0) + offset = (int)START_2 + (int)get_global_id(1) * (int)STRIDE_2; +#else // defined(SHRINK_1) && defined(SHRINK_0) offset = (int)START_2 + ((int)get_global_id(2) % (int)DST_DEPTH) * (int)STRIDE_2; +#endif // defined(SHRINK_1) && defined(SHRINK_0) + input.ptr += offset * input_stride_z; #endif // defined(START_2) && defined(STRIDE_2) // Offset depth -#if defined(START_3) && defined(STRIDE_3) +#if defined(SHRINK_3) + input.ptr += (int)START_3 * input_stride_w; +#elif defined(START_3) && defined(STRIDE_3) +#if defined(SHRINK_2) && defined(SHRINK_1) && defined(SHRINK_0) + offset = (int)START_3 + (int)get_global_id(0) * (int)STRIDE_3; +#elif !defined(SHRINK_2) && !defined(SHRINK_1) && !defined(SHRINK_0) offset = (int)START_3 + ((int)get_global_id(2) / (int)DST_DEPTH) * (int)STRIDE_3; +#elif(defined(SHRINK_0) && defined(SHRINK_1)) || (defined(SHRINK_1) && defined(SHRINK_2)) || (defined(SHRINK_0) && defined(SHRINK_2)) + offset = (int)START_3 + (int)get_global_id(1) * (int)STRIDE_3; +#else // defined(SHRINK_2) && defined(SHRINK_1) && defined(SHRINK_0) + offset = (int)START_3 + ((int)get_global_id(2) % (int)DST_DEPTH) * (int)STRIDE_3; +#endif // defined(SHRINK_2) && defined(SHRINK_1) && defined(SHRINK_0) input.ptr += offset * input_stride_w; #endif // defined(START_3) && defined(STRIDE_3) diff --git a/src/core/CL/kernels/CLStridedSliceKernel.cpp b/src/core/CL/kernels/CLStridedSliceKernel.cpp index 3828a48d02..c40f3c9f0b 100644 --- a/src/core/CL/kernels/CLStridedSliceKernel.cpp +++ b/src/core/CL/kernels/CLStridedSliceKernel.cpp @@ -32,6 +32,7 @@ #include "arm_compute/core/Window.h" #include "arm_compute/core/Types.h" +#include "arm_compute/core/utils/helpers/bit_ops.h" #include "arm_compute/core/utils/helpers/tensor_transform.h" #include "arm_compute/core/utils/misc/ShapeCalculator.h" @@ -114,9 +115,11 @@ void CLStridedSliceKernel::configure(const ICLTensor *input, ICLTensor *output, const TensorShape &input_shape = input->info()->tensor_shape(); - const Coordinates final_strides = arm_compute::helpers::tensor_transform::strided_slice_strides(input_shape, strides); - const Coordinates starts_abs = arm_compute::helpers::tensor_transform::strided_slice_absolute_start_coords(input_shape, starts, final_strides, begin_mask); - const Coordinates ends_abs = arm_compute::helpers::tensor_transform::strided_slice_absolute_end_coords(input_shape, starts_abs, ends, final_strides, end_mask, shrink_axis_mask); + Coordinates starts_abs, ends_abs, final_strides; + std::tie(starts_abs, ends_abs, final_strides) = arm_compute::helpers::tensor_transform::calculate_strided_slice_coords( + input_shape, + starts, ends, strides, + begin_mask, end_mask, shrink_axis_mask); // Configure kernel window auto win_config = validate_and_configure_window(input->info(), output->info(), starts, ends, strides, begin_mask, end_mask, shrink_axis_mask); @@ -125,7 +128,8 @@ void CLStridedSliceKernel::configure(const ICLTensor *input, ICLTensor *output, // Enable multiple elements processing along x if stride_x is 1 and output width greater than the access vector size const int vec_size_x = 16 / input->info()->element_size(); const int output_width_x = output->info()->tensor_shape().x(); - const bool multi_access_x = (final_strides.x() == 1) && (output_width_x / vec_size_x > 0); + const bool is_shrink_on_x = arm_compute::helpers::bit_ops::is_bit_set(shrink_axis_mask, 0); + const bool multi_access_x = !is_shrink_on_x && (final_strides.x() == 1) && (output_width_x / vec_size_x > 0); // Update window if needed if(multi_access_x) @@ -141,8 +145,10 @@ void CLStridedSliceKernel::configure(const ICLTensor *input, ICLTensor *output, build_opts.add_option("-DDATA_TYPE=" + get_cl_type_from_data_type(input->info()->data_type())); for(unsigned int i = 0; i < input_shape.num_dimensions(); ++i) { + const bool is_shrink = arm_compute::helpers::bit_ops::is_bit_set(shrink_axis_mask, i); build_opts.add_option("-DSTART_" + support::cpp11::to_string(i) + "=" + support::cpp11::to_string(starts_abs[i])); build_opts.add_option("-DSTRIDE_" + support::cpp11::to_string(i) + "=" + support::cpp11::to_string(final_strides[i])); + build_opts.add_option_if(is_shrink, "-DSHRINK_" + support::cpp11::to_string(i)); } build_opts.add_option_if(multi_access_x, "-DLAST_ACCESSED_X=" + support::cpp11::to_string(std::max<int>(output_width_x - vec_size_x, 0))); build_opts.add_option_if(multi_access_x, "-DVEC_SIZE=" + support::cpp11::to_string(vec_size_x)); |