From 0de45d0a8009e19331c4e29d617fa183167c513a Mon Sep 17 00:00:00 2001 From: Sheri Zhang Date: Fri, 17 Apr 2020 14:59:13 +0100 Subject: COMPMID-3394: Replace get_cl_type_from_data_type in All Signed-off-by: Sheri Zhang Change-Id: I978050182817c964779c775cdefd88d2c7df0ca5 Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/3069 Tested-by: Arm Jenkins Comments-Addressed: Arm Jenkins Reviewed-by: Georgios Pinitas --- src/core/CL/cl_kernels/flatten.cl | 16 +++++++--------- 1 file changed, 7 insertions(+), 9 deletions(-) (limited to 'src/core/CL/cl_kernels/flatten.cl') diff --git a/src/core/CL/cl_kernels/flatten.cl b/src/core/CL/cl_kernels/flatten.cl index 02694f709e..6418edc517 100644 --- a/src/core/CL/cl_kernels/flatten.cl +++ b/src/core/CL/cl_kernels/flatten.cl @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018 ARM Limited. + * Copyright (c) 2018-2020 ARM Limited. * * SPDX-License-Identifier: MIT * @@ -31,7 +31,7 @@ * @note The width, height and depth of the input tensor must be passed at compile time using -DSRC_WIDTH, -DSRC_HEIGHT and -DSRC_DEPTH. e.g. -DSRC_WIDTH=24, -DSRC_HEIGHT=24, -DSRC_DEPTH=16 * @note If the output has 3 dimensions, the 2nd dimension of the output tensor must be passed at compile time using -DDST_DIM1. e.g -DDST_DIM1=3 * - * @param[in] src_ptr Pointer to the source tensor. Supported data types: U8/S8/QASYMM8/U16/S16/F16/U32/S32/F32 + * @param[in] src_ptr Pointer to the source tensor. Supported data types: All * @param[in] src_stride_x Stride of the source tensor in X dimension (in bytes) * @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes) * @param[in] src_stride_y Stride of the source tensor in Y dimension (in bytes) @@ -62,14 +62,12 @@ __kernel void flatten( #if defined(DST_DIM1) uint b_tmp = b0; - b0 = b_tmp % DST_DIM1; // batch id0 - b1 = b_tmp / DST_DIM1; // batch id1 -#endif // defined(DST_DIM1) + b0 = b_tmp % DST_DIM1; // batch id0 + b1 = b_tmp / DST_DIM1; // batch id1 +#endif // defined(DST_DIM1) - __global uchar *output_ptr = dst_ptr + dst_offset_first_element_in_bytes + - (get_global_id(0) + get_global_id(1) * (uint)SRC_WIDTH + c * (uint)(SRC_WIDTH * SRC_HEIGHT)) * sizeof(DATA_TYPE) + - b0 * dst_stride_y + - b1 * dst_stride_z; + __global uchar *output_ptr = dst_ptr + dst_offset_first_element_in_bytes + (get_global_id(0) + get_global_id(1) * (uint)SRC_WIDTH + c * (uint)(SRC_WIDTH * SRC_HEIGHT)) * sizeof( + DATA_TYPE) + b0 * dst_stride_y + b1 * dst_stride_z; *((__global DATA_TYPE *)output_ptr) = *((__global DATA_TYPE *)src.ptr); } -- cgit v1.2.1