aboutsummaryrefslogtreecommitdiff
path: root/tests
diff options
context:
space:
mode:
authorGeorgios Pinitas <georgios.pinitas@arm.com>2018-12-10 18:45:35 +0000
committerPablo Marquez <pablo.tello@arm.com>2018-12-14 15:27:18 +0000
commitb4af2c6738614850aaca3754904f0e8e3b17f0b2 (patch)
treea2d234a99d0599c325311c73a4e4f2df019eb3ee /tests
parentbf9731edfa0439cad4d70efc3065e71e199c62b8 (diff)
downloadComputeLibrary-b4af2c6738614850aaca3754904f0e8e3b17f0b2.tar.gz
COMPMID-1710: Fixes in StrideSlice calculations.
Change-Id: I66eb922f1ff15142de278bf4439a61c979f98ba7 Reviewed-on: https://review.mlplatform.org/382 Reviewed-by: Matthew Bentham <matthew.bentham@arm.com> Tested-by: Arm Jenkins <bsgcomp@arm.com> Reviewed-by: Pablo Marquez <pablo.tello@arm.com>
Diffstat (limited to 'tests')
-rw-r--r--tests/datasets/SliceOperationsDataset.h6
-rw-r--r--tests/validation/reference/SliceOperations.cpp28
2 files changed, 21 insertions, 13 deletions
diff --git a/tests/datasets/SliceOperationsDataset.h b/tests/datasets/SliceOperationsDataset.h
index b6df4040fd..e891419e9b 100644
--- a/tests/datasets/SliceOperationsDataset.h
+++ b/tests/datasets/SliceOperationsDataset.h
@@ -262,6 +262,12 @@ public:
add_config(TensorShape(15U, 16U, 4U), Coordinates(0, 1, 2), Coordinates(5, -1, 4), BiStrides(2, 1, 2), 0, 1);
// 4D
add_config(TensorShape(15U, 16U, 4U, 12U), Coordinates(0, 1, 2, 2), Coordinates(5, -1, 4, 5), BiStrides(2, 1, 2, 3));
+
+ // Shrink axis
+ add_config(TensorShape(1U, 3U, 2U, 3U), Coordinates(0, 1, 0, 0), Coordinates(1, 1, 1, 1), BiStrides(1, 1, 1, 1), 0, 15, 6);
+ add_config(TensorShape(3U, 2U), Coordinates(0, 0), Coordinates(3U, 1U), BiStrides(1, 1), 0, 0, 2);
+ add_config(TensorShape(4U, 7U, 7U), Coordinates(0, 0, 0), Coordinates(1U, 1U, 1U), BiStrides(1, 1, 1), 0, 6, 1);
+ add_config(TensorShape(4U, 7U, 7U), Coordinates(0, 1, 0), Coordinates(1U, 1U, 1U), BiStrides(1, 1, 1), 0, 5, 3);
}
};
diff --git a/tests/validation/reference/SliceOperations.cpp b/tests/validation/reference/SliceOperations.cpp
index 04b5b98453..40ca9de927 100644
--- a/tests/validation/reference/SliceOperations.cpp
+++ b/tests/validation/reference/SliceOperations.cpp
@@ -24,6 +24,7 @@
#include "SliceOperations.h"
#include "arm_compute/core/utils/helpers/tensor_transform.h"
+#include "arm_compute/core/utils/misc/ShapeCalculator.h"
namespace arm_compute
{
@@ -50,11 +51,8 @@ SimpleTensor<T> slice(const SimpleTensor<T> &src, Coordinates starts, Coordinate
// Get source shape
const TensorShape &src_shape = src.shape();
- // Get actual end
- Coordinates ends_abs = slice_absolute_end_coords(src_shape, ends);
-
// Get destination shape
- TensorShape dst_shape = compute_slice_output_shape(src_shape, starts, ends_abs);
+ TensorShape dst_shape = arm_compute::misc::shape_calculator::compute_slice_shape(src_shape, starts, ends);
// Create destination tensor
SimpleTensor<T> dst{ dst_shape, src.data_type(), 1 };
@@ -98,20 +96,24 @@ SimpleTensor<T> strided_slice(const SimpleTensor<T> &src,
// Get source shape
const TensorShape &src_shape = src.shape();
- // Get actual start, end coordinates and strides
- const Coordinates final_strides = strided_slice_strides(src_shape, strides);
- const Coordinates starts_abs = strided_slice_absolute_start_coords(src_shape, starts, final_strides, begin_mask);
- const Coordinates ends_abs = strided_slice_absolute_end_coords(src_shape, starts_abs, ends, final_strides, end_mask, shrink_axis_mask);
-
// Get destination shape
- const TensorShape dst_shape = compute_strided_slice_output_shape(src_shape, starts_abs, ends_abs, final_strides);
+ const TensorShape dst_shape = compute_strided_slice_output_shape(src_shape, starts, ends, strides, begin_mask, end_mask, shrink_axis_mask);
// Create destination tensor
SimpleTensor<T> dst{ dst_shape, src.data_type(), 1 };
+ // Get coordinates
+ Coordinates starts_abs, ends_abs, final_strides;
+ std::tie(starts_abs, ends_abs, final_strides) = calculate_strided_slice_coords(src_shape,
+ starts, ends, strides,
+ begin_mask, end_mask, shrink_axis_mask);
+
// Perform strided slice
- Window win;
- win.use_tensor_dimensions(dst_shape);
+ unsigned int idx = 0;
+ Window win;
+ win.use_tensor_dimensions(compute_strided_slice_output_shape(src_shape,
+ starts, ends, strides,
+ begin_mask, end_mask, shrink_axis_mask, true));
execute_window_loop(win, [&](const Coordinates & id)
{
Coordinates offset;
@@ -119,7 +121,7 @@ SimpleTensor<T> strided_slice(const SimpleTensor<T> &src,
{
offset.set(i, starts_abs[i] + id[i] * final_strides[i]);
}
- *reinterpret_cast<T *>(dst(id)) = *reinterpret_cast<const T *>(src(offset));
+ dst.data()[idx++] = *reinterpret_cast<const T *>(src(offset));
});
return dst;