diff options
Diffstat (limited to 'src/core/CL/cl_kernels/flatten.cl')
-rw-r--r-- | src/core/CL/cl_kernels/flatten.cl | 16 |
1 files changed, 7 insertions, 9 deletions
diff --git a/src/core/CL/cl_kernels/flatten.cl b/src/core/CL/cl_kernels/flatten.cl index 02694f709e..6418edc517 100644 --- a/src/core/CL/cl_kernels/flatten.cl +++ b/src/core/CL/cl_kernels/flatten.cl @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018 ARM Limited. + * Copyright (c) 2018-2020 ARM Limited. * * SPDX-License-Identifier: MIT * @@ -31,7 +31,7 @@ * @note The width, height and depth of the input tensor must be passed at compile time using -DSRC_WIDTH, -DSRC_HEIGHT and -DSRC_DEPTH. e.g. -DSRC_WIDTH=24, -DSRC_HEIGHT=24, -DSRC_DEPTH=16 * @note If the output has 3 dimensions, the 2nd dimension of the output tensor must be passed at compile time using -DDST_DIM1. e.g -DDST_DIM1=3 * - * @param[in] src_ptr Pointer to the source tensor. Supported data types: U8/S8/QASYMM8/U16/S16/F16/U32/S32/F32 + * @param[in] src_ptr Pointer to the source tensor. Supported data types: All * @param[in] src_stride_x Stride of the source tensor in X dimension (in bytes) * @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes) * @param[in] src_stride_y Stride of the source tensor in Y dimension (in bytes) @@ -62,14 +62,12 @@ __kernel void flatten( #if defined(DST_DIM1) uint b_tmp = b0; - b0 = b_tmp % DST_DIM1; // batch id0 - b1 = b_tmp / DST_DIM1; // batch id1 -#endif // defined(DST_DIM1) + b0 = b_tmp % DST_DIM1; // batch id0 + b1 = b_tmp / DST_DIM1; // batch id1 +#endif // defined(DST_DIM1) - __global uchar *output_ptr = dst_ptr + dst_offset_first_element_in_bytes + - (get_global_id(0) + get_global_id(1) * (uint)SRC_WIDTH + c * (uint)(SRC_WIDTH * SRC_HEIGHT)) * sizeof(DATA_TYPE) + - b0 * dst_stride_y + - b1 * dst_stride_z; + __global uchar *output_ptr = dst_ptr + dst_offset_first_element_in_bytes + (get_global_id(0) + get_global_id(1) * (uint)SRC_WIDTH + c * (uint)(SRC_WIDTH * SRC_HEIGHT)) * sizeof( + DATA_TYPE) + b0 * dst_stride_y + b1 * dst_stride_z; *((__global DATA_TYPE *)output_ptr) = *((__global DATA_TYPE *)src.ptr); } |