diff options
Diffstat (limited to 'src/core/CL/cl_kernels/common/reverse.cl')
-rw-r--r-- | src/core/CL/cl_kernels/common/reverse.cl | 28 |
1 files changed, 17 insertions, 11 deletions
diff --git a/src/core/CL/cl_kernels/common/reverse.cl b/src/core/CL/cl_kernels/common/reverse.cl index 6b0afb9c2c..f94bfb6640 100644 --- a/src/core/CL/cl_kernels/common/reverse.cl +++ b/src/core/CL/cl_kernels/common/reverse.cl @@ -1,5 +1,5 @@ /* -* Copyright (c) 2018-2021 Arm Limited. +* Copyright (c) 2018-2021, 2023 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -33,6 +33,8 @@ * * @note The data type must be given as a preprocessor argument using -DDATA_TYPE=num. e.g. -DDATA_TYPE=uint * @note The number of dimensions to reverse must be given as a preprocessor argument using -DNUM_REVERSE_DIMS=num, e.g. -DNUM_REVERSE_DIMS=3 + * @note The number of dimensions of the source tensor must be given as a preprocessor argument using -DRANK=num, e.g. -DRANK=3 + * @note The values in axis_tensor must be within [-rank, rank-1]. * * @param[in] src_ptr Pointer to the source tensor. Supported data types: All * @param[in] src_stride_x Stride of the first source tensor in X dimension (in bytes) @@ -78,20 +80,24 @@ __kernel void reverse(TENSOR4D_DECLARATION(src), const uint4 dims = (uint4)(0, 1, 2, 3); int4 to_reverse = (int4)(0, 0, 0, 0); + + VEC_DATA_TYPE(int, NUM_REVERSE_DIMS) indices = VLOAD(NUM_REVERSE_DIMS)(0,(__global int *)axis.ptr); +#if defined(USE_INVERTED_AXIS) + indices = select((VEC_DATA_TYPE(int, NUM_REVERSE_DIMS)) RANK - 1, -1, indices < 0) - indices; +#else /* defined(USE_INVERTED_AXIS) */ + indices = select(indices, indices + RANK, indices < 0); +#endif /* defined(USE_INVERTED_AXIS) */ + #if NUM_REVERSE_DIMS == 1 - const uint index = *((__global uint *)axis.ptr); - to_reverse = (uint4)index == dims; + to_reverse = ((uint4)indices == dims); #elif NUM_REVERSE_DIMS == 2 - const uint2 indices = vload2(0, (__global uint *)axis.ptr); - to_reverse = ((uint4)indices.s0 == dims) || ((uint4)indices.s1 == dims); + to_reverse = ((uint4)indices.s0 == dims) || ((uint4)indices.s1 == dims); #elif NUM_REVERSE_DIMS == 3 - const uint2 indices01 = vload2(0, (__global uint *)axis.ptr); - const uint index2 = *((__global uint *)axis.ptr + 2); - to_reverse = ((uint4)indices01.s0 == dims) || ((uint4)indices01.s1 == dims) || ((uint4)index2 == dims); -#else /* NUM_REVERSE_DIMS == 3 */ - const uint4 indices = vload4(0, (__global uint *)axis.ptr); - to_reverse = ((uint4)indices.s0 == dims) || ((uint4)indices.s1 == dims) || ((uint4)indices.s2 == dims) || ((uint4)indices.s3 == dims); + to_reverse = ((uint4)indices.s0 == dims) || ((uint4)indices.s1 == dims) || ((uint4)indices.s2 == dims); +#else /* NUM_REVERSE_DIMS == 1 */ + to_reverse = ((uint4)indices.s0 == dims) || ((uint4)indices.s1 == dims) || ((uint4)indices.s2 == dims) || ((uint4)indices.s3 == dims); #endif /* NUM_REVERSE_DIMS == 1 */ + const uint x_out = to_reverse.s0 ? width - x_in - 1 : x_in; const uint y_out = to_reverse.s1 ? height - y_in - 1 : y_in; const uint z_out = to_reverse.s2 ? depth - z_in - 1 : z_in; |