aboutsummaryrefslogtreecommitdiff
path: root/src/core/CL/cl_kernels/common/reverse.cl
diff options
context:
space:
mode:
Diffstat (limited to 'src/core/CL/cl_kernels/common/reverse.cl')
-rw-r--r--src/core/CL/cl_kernels/common/reverse.cl28
1 files changed, 17 insertions, 11 deletions
diff --git a/src/core/CL/cl_kernels/common/reverse.cl b/src/core/CL/cl_kernels/common/reverse.cl
index 6b0afb9c2c..f94bfb6640 100644
--- a/src/core/CL/cl_kernels/common/reverse.cl
+++ b/src/core/CL/cl_kernels/common/reverse.cl
@@ -1,5 +1,5 @@
/*
-* Copyright (c) 2018-2021 Arm Limited.
+* Copyright (c) 2018-2021, 2023 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -33,6 +33,8 @@
*
* @note The data type must be given as a preprocessor argument using -DDATA_TYPE=num. e.g. -DDATA_TYPE=uint
* @note The number of dimensions to reverse must be given as a preprocessor argument using -DNUM_REVERSE_DIMS=num, e.g. -DNUM_REVERSE_DIMS=3
+ * @note The number of dimensions of the source tensor must be given as a preprocessor argument using -DRANK=num, e.g. -DRANK=3
+ * @note The values in axis_tensor must be within [-rank, rank-1].
*
* @param[in] src_ptr Pointer to the source tensor. Supported data types: All
* @param[in] src_stride_x Stride of the first source tensor in X dimension (in bytes)
@@ -78,20 +80,24 @@ __kernel void reverse(TENSOR4D_DECLARATION(src),
const uint4 dims = (uint4)(0, 1, 2, 3);
int4 to_reverse = (int4)(0, 0, 0, 0);
+
+ VEC_DATA_TYPE(int, NUM_REVERSE_DIMS) indices = VLOAD(NUM_REVERSE_DIMS)(0,(__global int *)axis.ptr);
+#if defined(USE_INVERTED_AXIS)
+ indices = select((VEC_DATA_TYPE(int, NUM_REVERSE_DIMS)) RANK - 1, -1, indices < 0) - indices;
+#else /* defined(USE_INVERTED_AXIS) */
+ indices = select(indices, indices + RANK, indices < 0);
+#endif /* defined(USE_INVERTED_AXIS) */
+
#if NUM_REVERSE_DIMS == 1
- const uint index = *((__global uint *)axis.ptr);
- to_reverse = (uint4)index == dims;
+ to_reverse = ((uint4)indices == dims);
#elif NUM_REVERSE_DIMS == 2
- const uint2 indices = vload2(0, (__global uint *)axis.ptr);
- to_reverse = ((uint4)indices.s0 == dims) || ((uint4)indices.s1 == dims);
+ to_reverse = ((uint4)indices.s0 == dims) || ((uint4)indices.s1 == dims);
#elif NUM_REVERSE_DIMS == 3
- const uint2 indices01 = vload2(0, (__global uint *)axis.ptr);
- const uint index2 = *((__global uint *)axis.ptr + 2);
- to_reverse = ((uint4)indices01.s0 == dims) || ((uint4)indices01.s1 == dims) || ((uint4)index2 == dims);
-#else /* NUM_REVERSE_DIMS == 3 */
- const uint4 indices = vload4(0, (__global uint *)axis.ptr);
- to_reverse = ((uint4)indices.s0 == dims) || ((uint4)indices.s1 == dims) || ((uint4)indices.s2 == dims) || ((uint4)indices.s3 == dims);
+ to_reverse = ((uint4)indices.s0 == dims) || ((uint4)indices.s1 == dims) || ((uint4)indices.s2 == dims);
+#else /* NUM_REVERSE_DIMS == 1 */
+ to_reverse = ((uint4)indices.s0 == dims) || ((uint4)indices.s1 == dims) || ((uint4)indices.s2 == dims) || ((uint4)indices.s3 == dims);
#endif /* NUM_REVERSE_DIMS == 1 */
+
const uint x_out = to_reverse.s0 ? width - x_in - 1 : x_in;
const uint y_out = to_reverse.s1 ? height - y_in - 1 : y_in;
const uint z_out = to_reverse.s2 ? depth - z_in - 1 : z_in;