diff options
Diffstat (limited to 'src/core/CL/cl_kernels')
-rw-r--r-- | src/core/CL/cl_kernels/channel_shuffle.cl | 147 |
1 files changed, 94 insertions, 53 deletions
diff --git a/src/core/CL/cl_kernels/channel_shuffle.cl b/src/core/CL/cl_kernels/channel_shuffle.cl index 9a87eb4af3..b7272da33a 100644 --- a/src/core/CL/cl_kernels/channel_shuffle.cl +++ b/src/core/CL/cl_kernels/channel_shuffle.cl @@ -1,5 +1,5 @@ /* -* Copyright (c) 2018-2020 Arm Limited. +* Copyright (c) 2018-2021 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -22,15 +22,14 @@ * SOFTWARE. */ #include "helpers.h" +#include "tile_helpers.h" #if defined(DATA_TYPE) && defined(VEC_SIZE) && defined(NUM_GROUPS) && defined(K) && defined(SRC_DIM_Z) // Check valid VEC_SIZES -#if VEC_SIZE != 4 && VEC_SIZE != 8 && VEC_SIZE != 16 -#error "Only vector sizes 4, 8 and 16 are supported" -#endif // VEC_SIZE != 4 && VEC_SIZE != 8 && VEC_SIZE != 16 - -#define TYPE VEC_DATA_TYPE(DATA_TYPE, VEC_SIZE) +#if VEC_SIZE != 1 && VEC_SIZE != 2 && VEC_SIZE != 3 && VEC_SIZE != 4 && VEC_SIZE != 8 && VEC_SIZE != 16 +#error "Only vector sizes 1, 2, 3, 4, 8 and 16 are supported" +#endif // VEC_SIZE != 1 && VEC_SIZE != 2 && VEC_SIZE != 3 && VEC_SIZE != 4 && VEC_SIZE != 8 && VEC_SIZE != 16 #define DIV_MOD_UINT(x, y, div_res, mod_res) \ ({ \ @@ -88,8 +87,10 @@ __kernel void channel_shuffle_nchw(TENSOR4D_DECLARATION(src), // Load the Nx2 block const __global uchar *input_ptr = src_ptr + src_offset_first_element_in_bytes + x * sizeof(DATA_TYPE) + y * src_stride_y + curr_channel * src_stride_z + batch_id * src_stride_w; - TYPE u0 = VLOAD(VEC_SIZE)(0, (__global DATA_TYPE *)(input_ptr + 0 * src_stride_y)); - TYPE u1 = VLOAD(VEC_SIZE)(0, (__global DATA_TYPE *)(input_ptr + 1 * src_stride_y)); + VEC_DATA_TYPE(DATA_TYPE, VEC_SIZE) + u0 = VLOAD(VEC_SIZE)(0, (__global DATA_TYPE *)(input_ptr + 0 * src_stride_y)); + VEC_DATA_TYPE(DATA_TYPE, VEC_SIZE) + u1 = VLOAD(VEC_SIZE)(0, (__global DATA_TYPE *)(input_ptr + 1 * src_stride_y)); // Store blocks __global uchar *output_ptr = dst_ptr + dst_offset_first_element_in_bytes + x * sizeof(DATA_TYPE) + y * dst_stride_y + z * dst_stride_z + batch_id * dst_stride_w; @@ -99,16 +100,17 @@ __kernel void channel_shuffle_nchw(TENSOR4D_DECLARATION(src), (u1, 0, (__global DATA_TYPE *)(output_ptr + 1 * dst_stride_y)); } -#if VEC_SIZE == 4 && defined(LAST_ACCESSED) +#if defined(VEC_SIZE) && defined(VEC_SIZE_LEFTOVER) && defined(SRC_DIM_X) + /** Performs channel shuffle when the data layout is NHWC. See https://arxiv.org/pdf/1707.01083.pdf for details. * - * @note This implementation is only defined for VEC_SIZE = 4 - * @note This last element accessed along the first dimension must be given as a preprocessor argument using -DLAST_ACCESSED=num. e.g. -DLAST_ACCESSED=64 in order to prevent out-of-bound writes. * @note The vector size must be given as a preprocessor argument using -DVEC_SIZE=num. e.g. -DVEC_SIZE=4 - * @note The height of the tensor must be given as a preprocessor argument using -DSRC_DIM_Z=num. e.g. -DSRC_DIM_Z=64 + * @note The third dimension of the tensor must be given as a preprocessor argument using -DSRC_DIM_Z=num. e.g. -DSRC_DIM_Z=64 + * @note The first dimension of the tensor must be given as a preprocessor argument using -DSRC_DIM_X=num. e.g. -DSRC_DIM_X=64 * @note The number of groups must be given as a preprocessor argument using -DNUM_GROUPS=num_groups. e.g. -DNUM_GROUPS=2 * @note The number of channels in each group must be given as a preprocessor argument using -DK=num. e.g. -DK=1 * K is equal to num_channels / num_groups. + * @note The leftover size in the X dimension shoud be given as preprocessor argument using -DVEC_SIZE_LEFTOVER is; x_dimension % VEC_SIZE. e.g. -DVEC_SIZE_LEFTOVER=1 * * @param[in] src_ptr Pointer to the source matrix. Supported data types: All * @param[in] src_stride_x Stride of the first source tensor in X dimension (in bytes) @@ -134,48 +136,87 @@ __kernel void channel_shuffle_nchw(TENSOR4D_DECLARATION(src), __kernel void channel_shuffle_nhwc(TENSOR4D_DECLARATION(src), TENSOR4D_DECLARATION(dst)) { - const uint curr_channel = min((uint)(get_global_id(0) * VEC_SIZE), (uint)LAST_ACCESSED); // input feature map - uint channel_id0 = 0; - uint channel_id1 = 0; - uint channel_id2 = 0; - uint channel_id3 = 0; - uint group_id0 = 0; - uint group_id1 = 0; - uint group_id2 = 0; - uint group_id3 = 0; - uint y = 0; - uint batch_id = 0; + // Offset computation + const uint curr_out_channel = GET_SPATIAL_IDX(0, VEC_SIZE, VEC_SIZE_LEFTOVER); // output feature map + uint z = 0; + uint batch_id = 0; // Compute curr_channel and batch_id - DIV_MOD_UINT(get_global_id(2), (uint)SRC_DIM_Z, batch_id, y); - - // Compute group_id and channel_id - DIV_MOD_UINT(curr_channel + (uint)0, K, group_id0, channel_id0); - DIV_MOD_UINT(curr_channel + (uint)1, K, group_id1, channel_id1); - DIV_MOD_UINT(curr_channel + (uint)2, K, group_id2, channel_id2); - DIV_MOD_UINT(curr_channel + (uint)3, K, group_id3, channel_id3); - - const uint x = get_global_id(1) * 2; - const uint z0 = channel_id0 * (uint)NUM_GROUPS + group_id0; - const uint z1 = channel_id1 * (uint)NUM_GROUPS + group_id1; - const uint z2 = channel_id2 * (uint)NUM_GROUPS + group_id2; - const uint z3 = channel_id3 * (uint)NUM_GROUPS + group_id3; - - // Load the Nx2 block - const __global uchar *input_ptr = src_ptr + src_offset_first_element_in_bytes + curr_channel * sizeof(DATA_TYPE) + x * src_stride_y + y * src_stride_z + batch_id * src_stride_w; - TYPE u0 = VLOAD(VEC_SIZE)(0, (__global DATA_TYPE *)(input_ptr + 0 * src_stride_y)); - TYPE u1 = VLOAD(VEC_SIZE)(0, (__global DATA_TYPE *)(input_ptr + 1 * src_stride_y)); - - // Store blocks - __global uchar *output_ptr = dst_ptr + dst_offset_first_element_in_bytes + x * dst_stride_y + y * dst_stride_z + batch_id * dst_stride_w; - *((__global DATA_TYPE *)(output_ptr + (uint)0 * dst_stride_y + z0 * sizeof(DATA_TYPE))) = u0.s0; - *((__global DATA_TYPE *)(output_ptr + (uint)0 * dst_stride_y + z1 * sizeof(DATA_TYPE))) = u0.s1; - *((__global DATA_TYPE *)(output_ptr + (uint)0 * dst_stride_y + z2 * sizeof(DATA_TYPE))) = u0.s2; - *((__global DATA_TYPE *)(output_ptr + (uint)0 * dst_stride_y + z3 * sizeof(DATA_TYPE))) = u0.s3; - *((__global DATA_TYPE *)(output_ptr + (uint)1 * dst_stride_y + z0 * sizeof(DATA_TYPE))) = u1.s0; - *((__global DATA_TYPE *)(output_ptr + (uint)1 * dst_stride_y + z1 * sizeof(DATA_TYPE))) = u1.s1; - *((__global DATA_TYPE *)(output_ptr + (uint)1 * dst_stride_y + z2 * sizeof(DATA_TYPE))) = u1.s2; - *((__global DATA_TYPE *)(output_ptr + (uint)1 * dst_stride_y + z3 * sizeof(DATA_TYPE))) = u1.s3; + DIV_MOD_UINT(get_global_id(2), (uint)SRC_DIM_Z, batch_id, z); + + VEC_DATA_TYPE(uint, VEC_SIZE) + curr_out_channels = (VEC_DATA_TYPE(uint, VEC_SIZE))(curr_out_channel) + VEC_OFFS(uint, VEC_SIZE); + + VEC_DATA_TYPE(uint, VEC_SIZE) + in_channels = (curr_out_channels * (VEC_DATA_TYPE(uint, VEC_SIZE))(K)) % (VEC_DATA_TYPE(uint, VEC_SIZE))(SRC_DIM_X) + (curr_out_channels / (VEC_DATA_TYPE(uint, VEC_SIZE))(NUM_GROUPS)); + + // Load the values + const __global DATA_TYPE *input_ptr = (const __global DATA_TYPE *)(src_ptr + src_offset_first_element_in_bytes + get_global_id(1) * src_stride_y + z * src_stride_z + batch_id * src_stride_w); + +#if VEC_SIZE == 1 + DATA_TYPE out0 = *((const __global * DATA_TYPE)(input_ptr) + in_channels); +#elif VEC_SIZE == 2 + VEC_DATA_TYPE(DATA_TYPE, 2) + out0 = + { + *(input_ptr + in_channels.s0), + *(input_ptr + in_channels.s1) + }; +#elif VEC_SIZE == 3 + VEC_DATA_TYPE(DATA_TYPE, 3) + out0 = + { + *(input_ptr + in_channels.s0), + *(input_ptr + in_channels.s1), + *(input_ptr + in_channels.s2) + }; +#elif VEC_SIZE == 4 + VEC_DATA_TYPE(DATA_TYPE, 4) + out0 = + { + *(input_ptr + in_channels.s0), + *(input_ptr + in_channels.s1), + *(input_ptr + in_channels.s2), + *(input_ptr + in_channels.s3) + }; +#elif VEC_SIZE == 8 + VEC_DATA_TYPE(DATA_TYPE, 8) + out0 = + { + *(input_ptr + in_channels.s0), + *(input_ptr + in_channels.s1), + *(input_ptr + in_channels.s2), + *(input_ptr + in_channels.s3), + *(input_ptr + in_channels.s4), + *(input_ptr + in_channels.s5), + *(input_ptr + in_channels.s6), + *(input_ptr + in_channels.s7) + }; +#elif VEC_SIZE == 16 + VEC_DATA_TYPE(DATA_TYPE, 8) + out0 = + { + *(input_ptr + in_channels.s0), + *(input_ptr + in_channels.s1), + *(input_ptr + in_channels.s2), + *(input_ptr + in_channels.s3), + *(input_ptr + in_channels.s4), + *(input_ptr + in_channels.s5), + *(input_ptr + in_channels.s6), + *(input_ptr + in_channels.s7), + *(input_ptr + in_channels.s8), + *(input_ptr + in_channels.s9), + *(input_ptr + in_channels.sa), + *(input_ptr + in_channels.sb), + *(input_ptr + in_channels.sc), + *(input_ptr + in_channels.sd), + *(input_ptr + in_channels.se), + *(input_ptr + in_channels.sf) + }; +#endif // VEC_SIZE == 1 + + __global uchar *output_ptr = dst_ptr + curr_out_channel * sizeof(DATA_TYPE) + dst_offset_first_element_in_bytes + get_global_id(1) * dst_stride_y + z * dst_stride_z + batch_id * dst_stride_w; + STORE_VECTOR_SELECT(out, DATA_TYPE, output_ptr, VEC_SIZE, VEC_SIZE_LEFTOVER, VEC_SIZE_LEFTOVER != 0 && get_global_id(0) == 0); } -#endif // VEC_SIZE == 4 && defined(LAST_ACCESSED) +#endif // defined(VEC_SIZE) && defined(VEC_SIZE_LEFTOVER) && defined(SRC_DIM_X) #endif // defined(DATA_TYPE) && defined(VEC_SIZE) && defined(NUM_GROUPS) && defined(K) && defined(SRC_DIM_Z) |