aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorGeorgios Pinitas <georgios.pinitas@arm.com>2018-12-10 18:45:35 +0000
committerPablo Marquez <pablo.tello@arm.com>2018-12-14 15:27:18 +0000
commitb4af2c6738614850aaca3754904f0e8e3b17f0b2 (patch)
treea2d234a99d0599c325311c73a4e4f2df019eb3ee
parentbf9731edfa0439cad4d70efc3065e71e199c62b8 (diff)
downloadComputeLibrary-b4af2c6738614850aaca3754904f0e8e3b17f0b2.tar.gz
COMPMID-1710: Fixes in StrideSlice calculations.
Change-Id: I66eb922f1ff15142de278bf4439a61c979f98ba7 Reviewed-on: https://review.mlplatform.org/382 Reviewed-by: Matthew Bentham <matthew.bentham@arm.com> Tested-by: Arm Jenkins <bsgcomp@arm.com> Reviewed-by: Pablo Marquez <pablo.tello@arm.com>
-rw-r--r--arm_compute/core/utils/helpers/bit_ops.h52
-rw-r--r--arm_compute/core/utils/helpers/tensor_transform.h92
-rw-r--r--arm_compute/core/utils/misc/ShapeCalculator.h15
-rw-r--r--src/core/CL/cl_kernels/slice_ops.cl36
-rw-r--r--src/core/CL/kernels/CLStridedSliceKernel.cpp14
-rw-r--r--src/core/utils/helpers/tensor_transform.cpp192
-rw-r--r--src/graph/nodes/SliceLayerNode.cpp14
-rw-r--r--src/runtime/CL/functions/CLSlice.cpp8
-rw-r--r--tests/datasets/SliceOperationsDataset.h6
-rw-r--r--tests/validation/reference/SliceOperations.cpp28
10 files changed, 289 insertions, 168 deletions
diff --git a/arm_compute/core/utils/helpers/bit_ops.h b/arm_compute/core/utils/helpers/bit_ops.h
new file mode 100644
index 0000000000..fd27014a46
--- /dev/null
+++ b/arm_compute/core/utils/helpers/bit_ops.h
@@ -0,0 +1,52 @@
+/*
+ * Copyright (c) 2018 ARM Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+#ifndef __ARM_COMPUTE_UTILS_HELPERS_BIT_OPS_H__
+#define __ARM_COMPUTE_UTILS_HELPERS_BIT_OPS_H__
+
+#include "arm_compute/core/utils/misc/Requires.h"
+
+#include <type_traits>
+
+namespace arm_compute
+{
+namespace helpers
+{
+namespace bit_ops
+{
+/** Checks if the idx-th bit is set in an integral type
+ *
+ * @param[in] v Integral input
+ * @param[in] idx Index of the bit to check
+ *
+ * @return True if the idx-th bit is set else false
+ */
+template <typename T, REQUIRES_TA(std::is_integral<T>::value)>
+bool is_bit_set(T v, unsigned int idx)
+{
+ return (v & 1 << idx) != 0;
+}
+} // namespace bit_ops
+} // namespace helpers
+} // namespace arm_compute
+#endif /* __ARM_COMPUTE_UTILS_HELPERS_BIT_OPS_H__ */
diff --git a/arm_compute/core/utils/helpers/tensor_transform.h b/arm_compute/core/utils/helpers/tensor_transform.h
index 966c1f1fdf..aa359ad119 100644
--- a/arm_compute/core/utils/helpers/tensor_transform.h
+++ b/arm_compute/core/utils/helpers/tensor_transform.h
@@ -32,45 +32,33 @@ namespace helpers
{
namespace tensor_transform
{
-/** Returns the absolute ends coordinates of slice
+/** Computes stride of a given index
*
- * @param[in] input_shape Input tensor shape
- * @param[in] ends End coordinates
+ * @param[in] index Index of tensor to calculate absolute start position
+ * @param[in] strides Slice strides
*
- * @return Absolute end coordinate
+ * @return Stride at a given index
*/
-Coordinates slice_absolute_end_coords(TensorShape input_shape, Coordinates ends);
+int calculate_stride_on_index(int index, Coordinates strides);
-/** Computes output shape of slice
- *
- * @warning Ends must be non-negative
- *
- * @param[in] input_shape Input tensor shape
- * @param[in] starts Start coordinates
- * @param[in] ends_abs Absolute end coordinates
- *
- * @return The output tensor shape
- */
-TensorShape compute_slice_output_shape(TensorShape input_shape, Coordinates starts, Coordinates ends_abs);
-
-/** Returns the absolute start coordinates of strided slice
+/** Computes absolute start position of a given index for a strided slice operation
*
* @param[in] input_shape Input tensor shape
+ * @param[in] index Index of tensor to calculate absolute start position
* @param[in] starts Start coordinates
* @param[in] strides Slice strides
* @param[in] begin_mask (Optional) If the ith bit of begin_mask is set, starts[i] is ignored and
* the fullest possible range in that dimension is used instead.
*
- * @return Absolute start coordinates
+ * @return Absolute start position of a given index
*/
-Coordinates strided_slice_absolute_start_coords(TensorShape input_shape, Coordinates starts, Coordinates strides, int32_t begin_mask = 0);
+int calculate_start_on_index(TensorShape input_shape, int index, Coordinates starts, Coordinates strides, int32_t begin_mask);
-/** Returns the absolute ends coordinates of strided slice
- *
- * @warning Starts must be non-negative
+/** Returns the absolute end position of a given index for a strided slice operation
*
* @param[in] input_shape Input tensor shape
- * @param[in] starts_abs Absolute start coordinates
+ * @param[in] index Index of tensor to calculate absolute start position
+ * @param[in] start_on_index Absolute start coordinate for given index
* @param[in] ends End coordinates
* @param[in] strides Slice strides
* @param[in] end_mask (Optional) If the ith bit of end_mask is set, end[i] is ignored and
@@ -78,32 +66,62 @@ Coordinates strided_slice_absolute_start_coords(TensorShape input_shape, Coordin
* @param[in] shrink_axis_mask (Optional) If the ith bit of shrink_axis_mask is set, it implies that the ith specification shrinks the dimensionality by 1.
* A slice of size 1 starting from starts[i] in the dimension must be preserved.
*
- * @return Absolute end coordinates
+ * @return Absolute end position of a given index
*/
-Coordinates strided_slice_absolute_end_coords(TensorShape input_shape, Coordinates starts_abs, Coordinates ends, Coordinates strides,
- int32_t end_mask = 0, int32_t shrink_axis_mask = 0);
-/** Returns the final strides of strided slice
+int calculate_end_on_index(TensorShape input_shape, int index, int start_on_index, Coordinates ends, Coordinates strides,
+ int32_t end_mask = 0, int32_t shrink_axis_mask = 0);
+
+/** Calculate start, end and stride coordinates for a strided slice
*
- * @param[in] input_shape Input tensor shape
- * @param[in] strides Slice strides
+ * @param[in] input_shape Input tensor shape
+ * @param[in] starts Start coordinates
+ * @param[in] ends End coordinates
+ * @param[in] strides Slice strides
+ * @param[in] begin_mask (Optional) If the ith bit of begin_mask is set, starts[i] is ignored and
+ * the fullest possible range in that dimension is used instead.
+ * @param[in] end_mask (Optional) If the ith bit of end_mask is set, end[i] is ignored and
+ * the fullest possible range in that dimension is used instead.
+ * @param[in] shrink_axis_mask (Optional) If the ith bit of shrink_axis_mask is set, it implies that the ith specification shrinks the dimensionality by 1.
+ * A slice of size 1 starting from starts[i] in the dimension must be preserved.
*
- * @return The final strides need by strided slice
+ * @return A tuple with <Start,End,Strides>
*/
-Coordinates strided_slice_strides(TensorShape input_shape, Coordinates strides);
+std::tuple<Coordinates, Coordinates, Coordinates> calculate_strided_slice_coords(TensorShape input_shape,
+ Coordinates starts, Coordinates ends, Coordinates strides,
+ int32_t begin_mask = 0, int32_t end_mask = 0, int32_t shrink_axis_mask = 0);
/** Computes output shape of strided slice
*
* @warning Starts and ends must be non-negative
* @warning Starts, ends and final strides should have the same dimensions as the input shape
*
- * @param[in] input_shape Input tensor shape
- * @param[in] starts_abs Absolute start coordinates
- * @param[in] ends_abs Absolute end coordinates
- * @param[in] final_strides Slice strides
+ * @param[in] input_shape Input tensor shape
+ * @param[in] starts Absolute start coordinates
+ * @param[in] ends Absolute end coordinates
+ * @param[in] strides Slice strides
+ * @param[in] begin_mask (Optional) If the ith bit of begin_mask is set, starts[i] is ignored and
+ * the fullest possible range in that dimension is used instead.
+ * @param[in] end_mask (Optional) If the ith bit of end_mask is set, end[i] is ignored and
+ * the fullest possible range in that dimension is used instead.
+ * @param[in] shrink_axis_mask (Optional) If the ith bit of shrink_axis_mask is set, it implies that the ith specification shrinks the dimensionality by 1.
+ * A slice of size 1 starting from starts[i] in the dimension must be preserved.
+ * @param[in] return_unshrinked (Optional) Returns un-shrinked shape
*
* @return The output tensor shape
*/
-TensorShape compute_strided_slice_output_shape(TensorShape input_shape, Coordinates starts_abs, Coordinates ends_abs, Coordinates final_strides);
+TensorShape compute_strided_slice_output_shape(TensorShape input_shape, Coordinates starts, Coordinates ends, Coordinates strides,
+ int32_t begin_mask = 0, int32_t end_mask = 0, int32_t shrink_axis_mask = 0,
+ bool return_unshrinked = false);
+
+/** Constructs end mask in case we want to perform a slice operation using the strided slice interface
+ *
+ * @note Ends are inclusive in slice operations that is why construction an end mask is needed
+ *
+ * @param[in] ends End coordinates
+ *
+ * @return End mask
+ */
+int32_t construct_slice_end_mask(Coordinates ends);
} // namespace tensor_tranform
} // namespace helpers
} // namespace arm_compute
diff --git a/arm_compute/core/utils/misc/ShapeCalculator.h b/arm_compute/core/utils/misc/ShapeCalculator.h
index f41d00f54d..adf5309ea5 100644
--- a/arm_compute/core/utils/misc/ShapeCalculator.h
+++ b/arm_compute/core/utils/misc/ShapeCalculator.h
@@ -668,15 +668,16 @@ inline TensorShape compute_strided_slice_shape(const ITensorInfo &input,
int32_t begin_mask, int32_t end_mask, int32_t shrink_axis_mask)
{
using namespace arm_compute::helpers::tensor_transform;
+ return compute_strided_slice_output_shape(input.tensor_shape(), starts, ends, strides, begin_mask, end_mask, shrink_axis_mask);
+}
- const TensorShape &input_shape = input.tensor_shape();
-
- // Get actual start, end coordinates and strides
- const Coordinates final_strides = strided_slice_strides(input_shape, strides);
- const Coordinates starts_abs = strided_slice_absolute_start_coords(input_shape, starts, final_strides, begin_mask);
- const Coordinates ends_abs = strided_slice_absolute_end_coords(input_shape, starts_abs, ends, final_strides, end_mask, shrink_axis_mask);
+inline TensorShape compute_slice_shape(const TensorShape &input_shape, const Coordinates &starts, const Coordinates &ends)
+{
+ using namespace arm_compute::helpers::tensor_transform;
- return compute_strided_slice_output_shape(input_shape, starts_abs, ends_abs, final_strides);
+ return compute_strided_slice_output_shape(input_shape,
+ starts, ends, BiStrides(),
+ 0, construct_slice_end_mask(ends), 0);
}
inline TensorShape compute_batch_to_space_shape(const ITensorInfo *input, const int block_x, const int block_y)
diff --git a/src/core/CL/cl_kernels/slice_ops.cl b/src/core/CL/cl_kernels/slice_ops.cl
index bc3df47345..97decee6fc 100644
--- a/src/core/CL/cl_kernels/slice_ops.cl
+++ b/src/core/CL/cl_kernels/slice_ops.cl
@@ -64,7 +64,9 @@ __kernel void strided_slice(
int offset = 0;
// Offset X
-#if defined(START_0) && defined(STRIDE_0) && defined(VEC_SIZE) && defined(LAST_ACCESSED_X)
+#if defined(SHRINK_0)
+ input.ptr += (int)START_0 * input_stride_x;
+#elif defined(START_0) && defined(STRIDE_0) && defined(VEC_SIZE) && defined(LAST_ACCESSED_X)
// Check if access on width gets out of bounds
// If it does shift access vector to access elements within bounds
const int xi = (int)(get_global_id(0) * VEC_SIZE);
@@ -77,20 +79,46 @@ __kernel void strided_slice(
#endif // defined(START_0) && defined(STRIDE_0)
// Offset Y
-#if defined(START_1) && defined(STRIDE_1)
+#if defined(SHRINK_1)
+ input.ptr += (int)START_1 * input_stride_y;
+#elif defined(START_1) && defined(STRIDE_1)
+#if defined(SHRINK_0)
+ offset = (int)START_1 + (int)get_global_id(0) * (int)STRIDE_1;
+#else // defined(SHRINK_0)
offset = (int)START_1 + (int)get_global_id(1) * (int)STRIDE_1;
+#endif // defined(SHRINK_0)
input.ptr += offset * input_stride_y;
#endif // defined(START_1) && defined(STRIDE_1)
// Offset Z
-#if defined(START_2) && defined(STRIDE_2)
+#if defined(SHRINK_2)
+ input.ptr += (int)START_2 * input_stride_z;
+#elif defined(START_2) && defined(STRIDE_2)
+
+#if defined(SHRINK_1) && defined(SHRINK_0)
+ offset = (int)START_2 + (int)get_global_id(0) * (int)STRIDE_2;
+#elif defined(SHRINK_1) || defined(SHRINK_0)
+ offset = (int)START_2 + (int)get_global_id(1) * (int)STRIDE_2;
+#else // defined(SHRINK_1) && defined(SHRINK_0)
offset = (int)START_2 + ((int)get_global_id(2) % (int)DST_DEPTH) * (int)STRIDE_2;
+#endif // defined(SHRINK_1) && defined(SHRINK_0)
+
input.ptr += offset * input_stride_z;
#endif // defined(START_2) && defined(STRIDE_2)
// Offset depth
-#if defined(START_3) && defined(STRIDE_3)
+#if defined(SHRINK_3)
+ input.ptr += (int)START_3 * input_stride_w;
+#elif defined(START_3) && defined(STRIDE_3)
+#if defined(SHRINK_2) && defined(SHRINK_1) && defined(SHRINK_0)
+ offset = (int)START_3 + (int)get_global_id(0) * (int)STRIDE_3;
+#elif !defined(SHRINK_2) && !defined(SHRINK_1) && !defined(SHRINK_0)
offset = (int)START_3 + ((int)get_global_id(2) / (int)DST_DEPTH) * (int)STRIDE_3;
+#elif(defined(SHRINK_0) && defined(SHRINK_1)) || (defined(SHRINK_1) && defined(SHRINK_2)) || (defined(SHRINK_0) && defined(SHRINK_2))
+ offset = (int)START_3 + (int)get_global_id(1) * (int)STRIDE_3;
+#else // defined(SHRINK_2) && defined(SHRINK_1) && defined(SHRINK_0)
+ offset = (int)START_3 + ((int)get_global_id(2) % (int)DST_DEPTH) * (int)STRIDE_3;
+#endif // defined(SHRINK_2) && defined(SHRINK_1) && defined(SHRINK_0)
input.ptr += offset * input_stride_w;
#endif // defined(START_3) && defined(STRIDE_3)
diff --git a/src/core/CL/kernels/CLStridedSliceKernel.cpp b/src/core/CL/kernels/CLStridedSliceKernel.cpp
index 3828a48d02..c40f3c9f0b 100644
--- a/src/core/CL/kernels/CLStridedSliceKernel.cpp
+++ b/src/core/CL/kernels/CLStridedSliceKernel.cpp
@@ -32,6 +32,7 @@
#include "arm_compute/core/Window.h"
#include "arm_compute/core/Types.h"
+#include "arm_compute/core/utils/helpers/bit_ops.h"
#include "arm_compute/core/utils/helpers/tensor_transform.h"
#include "arm_compute/core/utils/misc/ShapeCalculator.h"
@@ -114,9 +115,11 @@ void CLStridedSliceKernel::configure(const ICLTensor *input, ICLTensor *output,
const TensorShape &input_shape = input->info()->tensor_shape();
- const Coordinates final_strides = arm_compute::helpers::tensor_transform::strided_slice_strides(input_shape, strides);
- const Coordinates starts_abs = arm_compute::helpers::tensor_transform::strided_slice_absolute_start_coords(input_shape, starts, final_strides, begin_mask);
- const Coordinates ends_abs = arm_compute::helpers::tensor_transform::strided_slice_absolute_end_coords(input_shape, starts_abs, ends, final_strides, end_mask, shrink_axis_mask);
+ Coordinates starts_abs, ends_abs, final_strides;
+ std::tie(starts_abs, ends_abs, final_strides) = arm_compute::helpers::tensor_transform::calculate_strided_slice_coords(
+ input_shape,
+ starts, ends, strides,
+ begin_mask, end_mask, shrink_axis_mask);
// Configure kernel window
auto win_config = validate_and_configure_window(input->info(), output->info(), starts, ends, strides, begin_mask, end_mask, shrink_axis_mask);
@@ -125,7 +128,8 @@ void CLStridedSliceKernel::configure(const ICLTensor *input, ICLTensor *output,
// Enable multiple elements processing along x if stride_x is 1 and output width greater than the access vector size
const int vec_size_x = 16 / input->info()->element_size();
const int output_width_x = output->info()->tensor_shape().x();
- const bool multi_access_x = (final_strides.x() == 1) && (output_width_x / vec_size_x > 0);
+ const bool is_shrink_on_x = arm_compute::helpers::bit_ops::is_bit_set(shrink_axis_mask, 0);
+ const bool multi_access_x = !is_shrink_on_x && (final_strides.x() == 1) && (output_width_x / vec_size_x > 0);
// Update window if needed
if(multi_access_x)
@@ -141,8 +145,10 @@ void CLStridedSliceKernel::configure(const ICLTensor *input, ICLTensor *output,
build_opts.add_option("-DDATA_TYPE=" + get_cl_type_from_data_type(input->info()->data_type()));
for(unsigned int i = 0; i < input_shape.num_dimensions(); ++i)
{
+ const bool is_shrink = arm_compute::helpers::bit_ops::is_bit_set(shrink_axis_mask, i);
build_opts.add_option("-DSTART_" + support::cpp11::to_string(i) + "=" + support::cpp11::to_string(starts_abs[i]));
build_opts.add_option("-DSTRIDE_" + support::cpp11::to_string(i) + "=" + support::cpp11::to_string(final_strides[i]));
+ build_opts.add_option_if(is_shrink, "-DSHRINK_" + support::cpp11::to_string(i));
}
build_opts.add_option_if(multi_access_x, "-DLAST_ACCESSED_X=" + support::cpp11::to_string(std::max<int>(output_width_x - vec_size_x, 0)));
build_opts.add_option_if(multi_access_x, "-DVEC_SIZE=" + support::cpp11::to_string(vec_size_x));
diff --git a/src/core/utils/helpers/tensor_transform.cpp b/src/core/utils/helpers/tensor_transform.cpp
index a4bce5da5a..08803c7fb0 100644
--- a/src/core/utils/helpers/tensor_transform.cpp
+++ b/src/core/utils/helpers/tensor_transform.cpp
@@ -23,143 +23,155 @@
*/
#include "arm_compute/core/utils/helpers/tensor_transform.h"
+#include "arm_compute/core/utils/helpers/bit_ops.h"
+
namespace arm_compute
{
namespace helpers
{
namespace tensor_transform
{
-Coordinates slice_absolute_end_coords(TensorShape input_shape, Coordinates ends)
-{
- // Create end mask
- int32_t end_mask = 0;
- for(unsigned int i = 0; i < ends.num_dimensions(); ++i)
- {
- if(ends[i] < 0)
- {
- end_mask |= 1 << i;
- }
- }
- // Get unit strides
- const BiStrides unit_strides = strided_slice_strides(input_shape, BiStrides());
-
- return strided_slice_absolute_end_coords(input_shape, Coordinates(), ends, unit_strides, end_mask);
-}
-
-TensorShape compute_slice_output_shape(TensorShape input_shape, Coordinates starts, Coordinates ends_abs)
+int calculate_stride_on_index(int index, Coordinates strides)
{
- // Get unit strides
- const BiStrides unit_strides = strided_slice_strides(input_shape, BiStrides());
- return compute_strided_slice_output_shape(input_shape, starts, ends_abs, unit_strides);
+ return index >= static_cast<int>(strides.num_dimensions()) ? 1 : strides[index];
}
-Coordinates strided_slice_absolute_start_coords(TensorShape input_shape, Coordinates starts, Coordinates strides, int32_t begin_mask)
+int calculate_start_on_index(TensorShape input_shape, int index, Coordinates starts, Coordinates strides, int32_t begin_mask)
{
- Coordinates starts_abs;
- for(unsigned int i = 0; i < starts.num_dimensions(); ++i)
+ // Early exit
+ if(index >= static_cast<int>(starts.num_dimensions()))
{
- // Get start index
- int start_i = starts[i];
+ return 0;
+ }
- // Reset in case of begin mask present
- if((begin_mask & 1 << i) != 0)
- {
- start_i = strides[i] > 0 ? std::numeric_limits<int>::lowest() : std::numeric_limits<int>::max();
- }
+ // Get stride
+ const int stride = calculate_stride_on_index(index, strides);
- // Account negative start points
- const int dim_size = input_shape[i];
- if(start_i < 0)
- {
- start_i += dim_size;
- }
+ // Calculate start
+ int start = starts[index];
- // Final clamp
- start_i = utility::clamp(start_i, 0, dim_size - 1);
- starts_abs.set(i, start_i);
+ // Reset in case of begin mask present
+ if(arm_compute::helpers::bit_ops::is_bit_set(begin_mask, index))
+ {
+ start = stride > 0 ? std::numeric_limits<int>::lowest() : std::numeric_limits<int>::max();
}
- // Fill remaining
- for(unsigned int i = starts_abs.num_dimensions(); i < input_shape.num_dimensions(); ++i)
+ // Account negative start points
+ const int dim_size = input_shape[index];
+ if(start < 0)
{
- starts_abs.set(i, 0);
+ start += dim_size;
}
- return starts_abs;
+ // Final clamp
+ start = utility::clamp(start, 0, dim_size - 1);
+
+ return start;
}
-Coordinates strided_slice_absolute_end_coords(TensorShape input_shape, Coordinates starts_abs, Coordinates ends, Coordinates strides,
- int32_t end_mask, int32_t shrink_axis_mask)
+int calculate_end_on_index(TensorShape input_shape, int index, int start_on_index,
+ Coordinates ends, Coordinates strides,
+ int32_t end_mask, int32_t shrink_axis_mask)
{
- Coordinates ends_abs;
- for(unsigned int i = 0; i < ends.num_dimensions(); ++i)
+ // Early exit
+ if(index >= static_cast<int>(ends.num_dimensions()))
{
- // Get end index
- int stop_i = ends[i];
+ return input_shape[index];
+ }
- // Shrink dimension
- if((shrink_axis_mask & (1 << i)) != 0)
- {
- stop_i = starts_abs[i] + 1;
- }
+ const int stride = calculate_stride_on_index(index, strides);
+ const bool shrink_axis = arm_compute::helpers::bit_ops::is_bit_set(shrink_axis_mask, index);
- // Reset in case of begin mask present
- if((end_mask & 1 << i) != 0)
- {
- stop_i = (strides[i] > 0) ? std::numeric_limits<int>::max() : std::numeric_limits<int>::lowest();
- }
+ // Calculate start
+ int stop = ends[index];
- // Account negative end points
- const int dim_size = input_shape[i];
- if(stop_i < 0)
- {
- stop_i += dim_size;
- }
+ // Shrink dimension
+ if(shrink_axis)
+ {
+ stop = start_on_index + 1;
+ }
- // Final clamp
- stop_i = (strides[i] > 0) ? utility::clamp(stop_i, 0, dim_size) : utility::clamp(stop_i, -1, dim_size - 1);
- ends_abs.set(i, stop_i);
+ // Reset in case of begin mask present
+ if(arm_compute::helpers::bit_ops::is_bit_set(end_mask, index) && !shrink_axis)
+ {
+ stop = (stride > 0) ? std::numeric_limits<int>::max() : std::numeric_limits<int>::lowest();
}
- // Fill remaining ends
- for(unsigned int i = ends_abs.num_dimensions(); i < input_shape.num_dimensions(); ++i)
+ // Account negative end points
+ const int dim_size = input_shape[index];
+ if(stop < 0)
{
- ends_abs.set(i, input_shape[i]);
+ stop += dim_size;
}
- return ends_abs;
+ // Final clamp
+ stop = (stride > 0) ? utility::clamp(stop, 0, dim_size) : utility::clamp(stop, -1, dim_size - 1);
+
+ return stop;
}
-Coordinates strided_slice_strides(TensorShape input_shape, Coordinates strides)
+std::tuple<Coordinates, Coordinates, Coordinates> calculate_strided_slice_coords(TensorShape input_shape,
+ Coordinates starts, Coordinates ends, Coordinates strides,
+ int32_t begin_mask, int32_t end_mask, int32_t shrink_axis_mask)
{
- for(unsigned int i = strides.num_dimensions(); i < input_shape.num_dimensions(); ++i)
+ Coordinates starts_abs, ends_abs, final_strides;
+ for(unsigned int i = 0; i < input_shape.num_dimensions(); ++i)
{
- strides.set(i, 1);
+ const int start_i = calculate_start_on_index(input_shape, i, starts, strides, begin_mask);
+ starts_abs.set(i, start_i);
+ ends_abs.set(i, calculate_end_on_index(input_shape, i, start_i, ends, strides, end_mask, shrink_axis_mask));
+ final_strides.set(i, calculate_stride_on_index(i, strides));
}
- return strides;
+
+ return std::make_tuple(starts_abs, ends_abs, final_strides);
}
-TensorShape compute_strided_slice_output_shape(TensorShape input_shape, Coordinates starts_abs, Coordinates ends_abs, Coordinates final_strides)
+TensorShape compute_strided_slice_output_shape(TensorShape input_shape, Coordinates starts, Coordinates ends, Coordinates strides,
+ int32_t begin_mask, int32_t end_mask, int32_t shrink_axis_mask, bool return_unshrinked)
{
- TensorShape output_shape = input_shape;
+ unsigned int index = 0;
+
+ TensorShape output_shape;
for(unsigned int i = 0; i < input_shape.num_dimensions(); ++i)
{
- const int stride_i = final_strides[i];
- const int range = ends_abs[i] - starts_abs[i];
- if((range == 0) || // Zero range
- (range < 0 && stride_i >= 0) || // Negative range with positive stride
- (range > 0 && stride_i <= 0)) // Positive range with negative stride
+ const int stride = calculate_stride_on_index(index, strides);
+ const int start = calculate_start_on_index(input_shape, i, starts, strides, begin_mask);
+ const int end = calculate_end_on_index(input_shape, i, start, ends, strides, end_mask, shrink_axis_mask);
+ const int range = end - start;
+
+ const bool is_shrink = arm_compute::helpers::bit_ops::is_bit_set(shrink_axis_mask, i);
+ if(return_unshrinked || !is_shrink)
{
- output_shape.set(i, 0);
- return output_shape;
+ if((range == 0) || // Zero range
+ (range < 0 && stride >= 0) || // Negative range with positive stride
+ (range > 0 && stride <= 0)) // Positive range with negative stride
+ {
+ output_shape.set(index, 0);
+ return output_shape;
+ }
+ else
+ {
+ int dim = range / stride + (range % stride != 0 ? 1 : 0);
+ output_shape.set(index++, dim);
+ }
}
- else
+ }
+ return output_shape;
+}
+
+int32_t construct_slice_end_mask(Coordinates ends)
+{
+ // Create end mask
+ int32_t end_mask = 0;
+ for(unsigned int i = 0; i < ends.num_dimensions(); ++i)
+ {
+ if(ends[i] < 0)
{
- int dim = range / stride_i + (range % stride_i != 0 ? 1 : 0);
- output_shape.set(i, dim);
+ end_mask |= 1 << i;
}
}
- return output_shape;
+
+ return end_mask;
}
} // namespace tensor_transform
} // namespace helpers
diff --git a/src/graph/nodes/SliceLayerNode.cpp b/src/graph/nodes/SliceLayerNode.cpp
index 3a29e4c9ad..bfc009d3eb 100644
--- a/src/graph/nodes/SliceLayerNode.cpp
+++ b/src/graph/nodes/SliceLayerNode.cpp
@@ -24,7 +24,7 @@
#include "arm_compute/graph/nodes/SliceLayerNode.h"
#include "arm_compute/core/Utils.h"
-#include "arm_compute/core/utils/helpers/tensor_transform.h"
+#include "arm_compute/core/utils/misc/ShapeCalculator.h"
#include "arm_compute/graph/Graph.h"
#include "arm_compute/graph/INodeVisitor.h"
@@ -52,16 +52,12 @@ Coordinates SliceLayerNode::ends() const
TensorDescriptor SliceLayerNode::compute_output_descriptor(const TensorDescriptor &input_descriptor,
const Coordinates &starts, const Coordinates &ends)
{
- // Get absolute end coordinates
- const Coordinates ends_abs = arm_compute::helpers::tensor_transform::slice_absolute_end_coords(input_descriptor.shape, ends);
+ using namespace arm_compute::helpers::tensor_transform;
- TensorDescriptor output_descriptor = input_descriptor;
- for(unsigned int i = 0; i < starts.num_dimensions(); ++i)
- {
- output_descriptor.shape.set(i, ends_abs[i] - starts[i]);
- }
+ TensorDescriptor output_desc = input_descriptor;
+ output_desc.shape = arm_compute::misc::shape_calculator::compute_slice_shape(input_descriptor.shape, starts, ends);
- return output_descriptor;
+ return output_desc;
}
bool SliceLayerNode::forward_descriptors()
diff --git a/src/runtime/CL/functions/CLSlice.cpp b/src/runtime/CL/functions/CLSlice.cpp
index bef7eca71c..f630853fe3 100644
--- a/src/runtime/CL/functions/CLSlice.cpp
+++ b/src/runtime/CL/functions/CLSlice.cpp
@@ -36,10 +36,10 @@ void CLSlice::configure(const ICLTensor *input, ICLTensor *output, const Coordin
ARM_COMPUTE_ERROR_ON_NULLPTR(input);
// Get absolute end coordinates
- const Coordinates ends_abs = arm_compute::helpers::tensor_transform::slice_absolute_end_coords(input->info()->tensor_shape(), ends);
+ const int32_t slice_end_mask = arm_compute::helpers::tensor_transform::construct_slice_end_mask(ends);
auto k = arm_compute::support::cpp14::make_unique<CLStridedSliceKernel>();
- k->configure(input, output, starts, ends_abs, BiStrides(), 0, 0, 0);
+ k->configure(input, output, starts, ends, BiStrides(), 0, slice_end_mask, 0);
_kernel = std::move(k);
}
@@ -54,8 +54,8 @@ Status CLSlice::validate(const ITensorInfo *input, const ITensorInfo *output, co
}));
// Get absolute end coordinates
- const Coordinates ends_abs = arm_compute::helpers::tensor_transform::slice_absolute_end_coords(input->tensor_shape(), ends);
+ const int32_t slice_end_mask = arm_compute::helpers::tensor_transform::construct_slice_end_mask(ends);
- return CLStridedSliceKernel::validate(input, output, starts, ends_abs, BiStrides(), 0, 0, 0);
+ return CLStridedSliceKernel::validate(input, output, starts, ends, BiStrides(), 0, slice_end_mask, 0);
}
} // namespace arm_compute
diff --git a/tests/datasets/SliceOperationsDataset.h b/tests/datasets/SliceOperationsDataset.h
index b6df4040fd..e891419e9b 100644
--- a/tests/datasets/SliceOperationsDataset.h
+++ b/tests/datasets/SliceOperationsDataset.h
@@ -262,6 +262,12 @@ public:
add_config(TensorShape(15U, 16U, 4U), Coordinates(0, 1, 2), Coordinates(5, -1, 4), BiStrides(2, 1, 2), 0, 1);
// 4D
add_config(TensorShape(15U, 16U, 4U, 12U), Coordinates(0, 1, 2, 2), Coordinates(5, -1, 4, 5), BiStrides(2, 1, 2, 3));
+
+ // Shrink axis
+ add_config(TensorShape(1U, 3U, 2U, 3U), Coordinates(0, 1, 0, 0), Coordinates(1, 1, 1, 1), BiStrides(1, 1, 1, 1), 0, 15, 6);
+ add_config(TensorShape(3U, 2U), Coordinates(0, 0), Coordinates(3U, 1U), BiStrides(1, 1), 0, 0, 2);
+ add_config(TensorShape(4U, 7U, 7U), Coordinates(0, 0, 0), Coordinates(1U, 1U, 1U), BiStrides(1, 1, 1), 0, 6, 1);
+ add_config(TensorShape(4U, 7U, 7U), Coordinates(0, 1, 0), Coordinates(1U, 1U, 1U), BiStrides(1, 1, 1), 0, 5, 3);
}
};
diff --git a/tests/validation/reference/SliceOperations.cpp b/tests/validation/reference/SliceOperations.cpp
index 04b5b98453..40ca9de927 100644
--- a/tests/validation/reference/SliceOperations.cpp
+++ b/tests/validation/reference/SliceOperations.cpp
@@ -24,6 +24,7 @@
#include "SliceOperations.h"
#include "arm_compute/core/utils/helpers/tensor_transform.h"
+#include "arm_compute/core/utils/misc/ShapeCalculator.h"
namespace arm_compute
{
@@ -50,11 +51,8 @@ SimpleTensor<T> slice(const SimpleTensor<T> &src, Coordinates starts, Coordinate
// Get source shape
const TensorShape &src_shape = src.shape();
- // Get actual end
- Coordinates ends_abs = slice_absolute_end_coords(src_shape, ends);
-
// Get destination shape
- TensorShape dst_shape = compute_slice_output_shape(src_shape, starts, ends_abs);
+ TensorShape dst_shape = arm_compute::misc::shape_calculator::compute_slice_shape(src_shape, starts, ends);
// Create destination tensor
SimpleTensor<T> dst{ dst_shape, src.data_type(), 1 };
@@ -98,20 +96,24 @@ SimpleTensor<T> strided_slice(const SimpleTensor<T> &src,
// Get source shape
const TensorShape &src_shape = src.shape();
- // Get actual start, end coordinates and strides
- const Coordinates final_strides = strided_slice_strides(src_shape, strides);
- const Coordinates starts_abs = strided_slice_absolute_start_coords(src_shape, starts, final_strides, begin_mask);
- const Coordinates ends_abs = strided_slice_absolute_end_coords(src_shape, starts_abs, ends, final_strides, end_mask, shrink_axis_mask);
-
// Get destination shape
- const TensorShape dst_shape = compute_strided_slice_output_shape(src_shape, starts_abs, ends_abs, final_strides);
+ const TensorShape dst_shape = compute_strided_slice_output_shape(src_shape, starts, ends, strides, begin_mask, end_mask, shrink_axis_mask);
// Create destination tensor
SimpleTensor<T> dst{ dst_shape, src.data_type(), 1 };
+ // Get coordinates
+ Coordinates starts_abs, ends_abs, final_strides;
+ std::tie(starts_abs, ends_abs, final_strides) = calculate_strided_slice_coords(src_shape,
+ starts, ends, strides,
+ begin_mask, end_mask, shrink_axis_mask);
+
// Perform strided slice
- Window win;
- win.use_tensor_dimensions(dst_shape);
+ unsigned int idx = 0;
+ Window win;
+ win.use_tensor_dimensions(compute_strided_slice_output_shape(src_shape,
+ starts, ends, strides,
+ begin_mask, end_mask, shrink_axis_mask, true));
execute_window_loop(win, [&](const Coordinates & id)
{
Coordinates offset;
@@ -119,7 +121,7 @@ SimpleTensor<T> strided_slice(const SimpleTensor<T> &src,
{
offset.set(i, starts_abs[i] + id[i] * final_strides[i]);
}
- *reinterpret_cast<T *>(dst(id)) = *reinterpret_cast<const T *>(src(offset));
+ dst.data()[idx++] = *reinterpret_cast<const T *>(src(offset));
});
return dst;