about | summary | refs | log | tree | commit | diff
path: root/src/core/CL/kernels/CLStridedSliceKernel.cpp
diff options
context:
space:
mode:
author Georgios Pinitas <georgios.pinitas@arm.com> 2018-12-10 18:45:35 +0000
committer Pablo Marquez <pablo.tello@arm.com> 2018-12-14 15:27:18 +0000
commit b4af2c6738614850aaca3754904f0e8e3b17f0b2 (patch)
tree a2d234a99d0599c325311c73a4e4f2df019eb3ee /src/core/CL/kernels/CLStridedSliceKernel.cpp
parent bf9731edfa0439cad4d70efc3065e71e199c62b8 (diff)
download ComputeLibrary-b4af2c6738614850aaca3754904f0e8e3b17f0b2.tar.gz
COMPMID-1710: Fixes in StridedSlice calculations.
Change-Id: I66eb922f1ff15142de278bf4439a61c979f98ba7
Reviewed-on: https://review.mlplatform.org/382
Reviewed-by: Matthew Bentham <matthew.bentham@arm.com>
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Pablo Marquez <pablo.tello@arm.com>
Diffstat (limited to 'src/core/CL/kernels/CLStridedSliceKernel.cpp')
-rw-r--r-- src/core/CL/kernels/CLStridedSliceKernel.cpp 14
1 file changed, 10 insertions, 4 deletions
diff --git a/src/core/CL/kernels/CLStridedSliceKernel.cpp b/src/core/CL/kernels/CLStridedSliceKernel.cpp
index 3828a48d02..c40f3c9f0b 100644
--- a/src/core/CL/kernels/CLStridedSliceKernel.cpp
+++ b/src/core/CL/kernels/CLStridedSliceKernel.cpp
@@ -32,6 +32,7 @@
#include "arm_compute/core/Window.h"
#include "arm_compute/core/Types.h"
+#include "arm_compute/core/utils/helpers/bit_ops.h"
#include "arm_compute/core/utils/helpers/tensor_transform.h"
#include "arm_compute/core/utils/misc/ShapeCalculator.h"
@@ -114,9 +115,11 @@ void CLStridedSliceKernel::configure(const ICLTensor *input, ICLTensor *output,
const TensorShape &input_shape = input->info()->tensor_shape();
- const Coordinates final_strides = arm_compute::helpers::tensor_transform::strided_slice_strides(input_shape, strides);
- const Coordinates starts_abs = arm_compute::helpers::tensor_transform::strided_slice_absolute_start_coords(input_shape, starts, final_strides, begin_mask);
- const Coordinates ends_abs = arm_compute::helpers::tensor_transform::strided_slice_absolute_end_coords(input_shape, starts_abs, ends, final_strides, end_mask, shrink_axis_mask);
+ Coordinates starts_abs, ends_abs, final_strides;
+ std::tie(starts_abs, ends_abs, final_strides) = arm_compute::helpers::tensor_transform::calculate_strided_slice_coords(
+ input_shape,
+ starts, ends, strides,
+ begin_mask, end_mask, shrink_axis_mask);
// Configure kernel window
auto win_config = validate_and_configure_window(input->info(), output->info(), starts, ends, strides, begin_mask, end_mask, shrink_axis_mask);
@@ -125,7 +128,8 @@ void CLStridedSliceKernel::configure(const ICLTensor *input, ICLTensor *output,
// Enable multiple elements processing along x if stride_x is 1 and output width greater than the access vector size
const int vec_size_x = 16 / input->info()->element_size();
const int output_width_x = output->info()->tensor_shape().x();
- const bool multi_access_x = (final_strides.x() == 1) && (output_width_x / vec_size_x > 0);
+ const bool is_shrink_on_x = arm_compute::helpers::bit_ops::is_bit_set(shrink_axis_mask, 0);
+ const bool multi_access_x = !is_shrink_on_x && (final_strides.x() == 1) && (output_width_x / vec_size_x > 0);
// Update window if needed
if(multi_access_x)
@@ -141,8 +145,10 @@ void CLStridedSliceKernel::configure(const ICLTensor *input, ICLTensor *output,
build_opts.add_option("-DDATA_TYPE=" + get_cl_type_from_data_type(input->info()->data_type()));
for(unsigned int i = 0; i < input_shape.num_dimensions(); ++i)
{
+ const bool is_shrink = arm_compute::helpers::bit_ops::is_bit_set(shrink_axis_mask, i);
build_opts.add_option("-DSTART_" + support::cpp11::to_string(i) + "=" + support::cpp11::to_string(starts_abs[i]));
build_opts.add_option("-DSTRIDE_" + support::cpp11::to_string(i) + "=" + support::cpp11::to_string(final_strides[i]));
+ build_opts.add_option_if(is_shrink, "-DSHRINK_" + support::cpp11::to_string(i));
}
build_opts.add_option_if(multi_access_x, "-DLAST_ACCESSED_X=" + support::cpp11::to_string(std::max<int>(output_width_x - vec_size_x, 0)));
build_opts.add_option_if(multi_access_x, "-DVEC_SIZE=" + support::cpp11::to_string(vec_size_x));