From b4af2c6738614850aaca3754904f0e8e3b17f0b2 Mon Sep 17 00:00:00 2001 From: Georgios Pinitas Date: Mon, 10 Dec 2018 18:45:35 +0000 Subject: COMPMID-1710: Fixes in StrideSlice calculations. Change-Id: I66eb922f1ff15142de278bf4439a61c979f98ba7 Reviewed-on: https://review.mlplatform.org/382 Reviewed-by: Matthew Bentham Tested-by: Arm Jenkins Reviewed-by: Pablo Marquez --- tests/datasets/SliceOperationsDataset.h | 6 ++++++ tests/validation/reference/SliceOperations.cpp | 28 ++++++++++++++------------ 2 files changed, 21 insertions(+), 13 deletions(-) (limited to 'tests') diff --git a/tests/datasets/SliceOperationsDataset.h b/tests/datasets/SliceOperationsDataset.h index b6df4040fd..e891419e9b 100644 --- a/tests/datasets/SliceOperationsDataset.h +++ b/tests/datasets/SliceOperationsDataset.h @@ -262,6 +262,12 @@ public: add_config(TensorShape(15U, 16U, 4U), Coordinates(0, 1, 2), Coordinates(5, -1, 4), BiStrides(2, 1, 2), 0, 1); // 4D add_config(TensorShape(15U, 16U, 4U, 12U), Coordinates(0, 1, 2, 2), Coordinates(5, -1, 4, 5), BiStrides(2, 1, 2, 3)); + + // Shrink axis + add_config(TensorShape(1U, 3U, 2U, 3U), Coordinates(0, 1, 0, 0), Coordinates(1, 1, 1, 1), BiStrides(1, 1, 1, 1), 0, 15, 6); + add_config(TensorShape(3U, 2U), Coordinates(0, 0), Coordinates(3U, 1U), BiStrides(1, 1), 0, 0, 2); + add_config(TensorShape(4U, 7U, 7U), Coordinates(0, 0, 0), Coordinates(1U, 1U, 1U), BiStrides(1, 1, 1), 0, 6, 1); + add_config(TensorShape(4U, 7U, 7U), Coordinates(0, 1, 0), Coordinates(1U, 1U, 1U), BiStrides(1, 1, 1), 0, 5, 3); } }; diff --git a/tests/validation/reference/SliceOperations.cpp b/tests/validation/reference/SliceOperations.cpp index 04b5b98453..40ca9de927 100644 --- a/tests/validation/reference/SliceOperations.cpp +++ b/tests/validation/reference/SliceOperations.cpp @@ -24,6 +24,7 @@ #include "SliceOperations.h" #include "arm_compute/core/utils/helpers/tensor_transform.h" +#include "arm_compute/core/utils/misc/ShapeCalculator.h" namespace arm_compute { @@ -50,11 +51,8 @@ SimpleTensor slice(const SimpleTensor &src, Coordinates starts, Coordinate // Get source shape const TensorShape &src_shape = src.shape(); - // Get actual end - Coordinates ends_abs = slice_absolute_end_coords(src_shape, ends); - // Get destination shape - TensorShape dst_shape = compute_slice_output_shape(src_shape, starts, ends_abs); + TensorShape dst_shape = arm_compute::misc::shape_calculator::compute_slice_shape(src_shape, starts, ends); // Create destination tensor SimpleTensor dst{ dst_shape, src.data_type(), 1 }; @@ -98,20 +96,24 @@ SimpleTensor strided_slice(const SimpleTensor &src, // Get source shape const TensorShape &src_shape = src.shape(); - // Get actual start, end coordinates and strides - const Coordinates final_strides = strided_slice_strides(src_shape, strides); - const Coordinates starts_abs = strided_slice_absolute_start_coords(src_shape, starts, final_strides, begin_mask); - const Coordinates ends_abs = strided_slice_absolute_end_coords(src_shape, starts_abs, ends, final_strides, end_mask, shrink_axis_mask); - // Get destination shape - const TensorShape dst_shape = compute_strided_slice_output_shape(src_shape, starts_abs, ends_abs, final_strides); + const TensorShape dst_shape = compute_strided_slice_output_shape(src_shape, starts, ends, strides, begin_mask, end_mask, shrink_axis_mask); // Create destination tensor SimpleTensor dst{ dst_shape, src.data_type(), 1 }; + // Get coordinates + Coordinates starts_abs, ends_abs, final_strides; + std::tie(starts_abs, ends_abs, final_strides) = calculate_strided_slice_coords(src_shape, + starts, ends, strides, + begin_mask, end_mask, shrink_axis_mask); + // Perform strided slice - Window win; - win.use_tensor_dimensions(dst_shape); + unsigned int idx = 0; + Window win; + win.use_tensor_dimensions(compute_strided_slice_output_shape(src_shape, + starts, ends, strides, + begin_mask, end_mask, shrink_axis_mask, true)); execute_window_loop(win, [&](const Coordinates & id) { Coordinates offset; @@ -119,7 +121,7 @@ SimpleTensor strided_slice(const SimpleTensor &src, { offset.set(i, starts_abs[i] + id[i] * final_strides[i]); } - *reinterpret_cast(dst(id)) = *reinterpret_cast(src(offset)); + dst.data()[idx++] = *reinterpret_cast(src(offset)); }); return dst; -- cgit v1.2.1