aboutsummaryrefslogtreecommitdiff
path: root/src/core/CL/cl_kernels/col2im.cl
diff options
context:
space:
mode:
authorGeorgios Pinitas <georgios.pinitas@arm.com>2018-09-13 17:20:04 +0100
committerAnthony Barbier <anthony.barbier@arm.com>2018-11-02 16:54:54 +0000
commite55b40a4d0cc5a82b8f0fd9ffec203ded9f3c63d (patch)
treee7736258428837e3889108909d58592937fe71fd /src/core/CL/cl_kernels/col2im.cl
parent64f1a908841913049ccc0eb941b5b213290d7bf7 (diff)
downloadComputeLibrary-e55b40a4d0cc5a82b8f0fd9ffec203ded9f3c63d.tar.gz
COMPMID-1581: Collapse windows
Change-Id: Iec56c9a96d9736a63f13b65efa33311950f20661 Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/148572 Reviewed-by: Anthony Barbier <anthony.barbier@arm.com> Tested-by: bsgcomp <bsgcomp@arm.com>
Diffstat (limited to 'src/core/CL/cl_kernels/col2im.cl')
-rw-r--r--src/core/CL/cl_kernels/col2im.cl45
1 files changed, 22 insertions, 23 deletions
diff --git a/src/core/CL/cl_kernels/col2im.cl b/src/core/CL/cl_kernels/col2im.cl
index 5e52127f27..b02d07b332 100644
--- a/src/core/CL/cl_kernels/col2im.cl
+++ b/src/core/CL/cl_kernels/col2im.cl
@@ -23,7 +23,7 @@
*/
#include "helpers.h"
-#if defined(DATA_TYPE) && defined(WIDTH_OUTPUT) && defined(ELEMENT_SIZE) && defined(WIDTH_INPUT)
+#if defined(DATA_TYPE) && defined(WIDTH_OUTPUT) && defined(ELEMENT_SIZE) && defined(WIDTH_INPUT) && defined(NUM_GROUPS)
#if ELEMENT_SIZE == 1
#define COND_DATA_TYPE char
@@ -41,7 +41,7 @@
* @note The width of the input tensor must be passed at compile time using -DWIDTH_INPUT: e.g. -DWIDTH_INPUT=320
* @note The width of the output tensor must be passed at compile time using -DWIDTH_OUTPUT: e.g. -DWIDTH_OUTPUT=600
* @note The element size must be passed at compile time using -DELEMENT_SIZE: e.g. -DELEMENT_SIZE=4
- * @note In case of grouping the GROUPING flag must be passed at compile time using -DGROUPING
+ * @note The number of groups must be passed at compile time using -DNUM_GROUPS: e.g. -DNUM_GROUPS=4
*
* @param[in] src_ptr Pointer to the source tensor. Supported data types: QASYMM8/F16/F32
* @param[in] src_stride_x Stride of the source tensor in X dimension (in bytes)
@@ -58,15 +58,16 @@
* @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
* @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
* @param[in] dst_step_z dst_stride_z * number of elements along Z processed per workitem(in bytes)
- * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
* @param[in] dst_stride_w Stride of the destination tensor in W dimension (in bytes)
+ * @param[in] dst_step_w dst_stride_w * number of elements along W processed per workitem(in bytes)
+ * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
*/
__kernel void col2im(
TENSOR3D_DECLARATION(src),
- TENSOR3D_DECLARATION(dst),
- uint dst_stride_w)
+ TENSOR4D_DECLARATION(dst))
{
Tensor3D src = CONVERT_TO_TENSOR3D_STRUCT(src);
+ Tensor4D dst = CONVERT_TO_TENSOR4D_STRUCT_NO_STEP(dst, 0);
const uint xd = get_global_id(1) % WIDTH_OUTPUT; // x coordinate of the destination tensor
const uint yd = get_global_id(1) / WIDTH_OUTPUT; // y coordinate of the destination tensor
@@ -86,27 +87,25 @@ __kernel void col2im(
// If out-of-bound, overwrite with the first element
data = select((VEC_DATA_TYPE(DATA_TYPE, 8))data.s0, data, cond0);
- __global uchar *output_ptr = dst_ptr + dst_offset_first_element_in_bytes;
-
-#if defined(GROUPING)
- // Compute output offset (batches on 4th dimension, no need to compute manually)
- int idx = yd * dst_stride_y + xd * dst_stride_x;
+#if NUM_GROUPS > 1
+ // Compute output offset (batches on 4th dimension)
+ int idx = yd * dst_stride_y + xd * dst_stride_x + (get_global_id(2) / NUM_GROUPS) * dst.stride_w;
- const uint group = get_global_id(2); // group ID
+ const uint group = get_global_id(2) % NUM_GROUPS; // group ID
x_clamped += group * WIDTH_INPUT;
-#else /* defined(GROUPING) */
+#else /* defined(NUM_GROUPS > 1 ) */
// Compute output offset (batches on 3rd dimension)
- int idx = yd * dst_stride_y + xd * dst_stride_x + get_global_id(2) * dst_stride_w;
-#endif /* GROUPING */
+ int idx = yd * dst.stride_y + xd * dst.stride_x + get_global_id(2) * dst.stride_w;
+#endif /* NUM_GROUPS > 1 */
// Store value
- *((__global DATA_TYPE *)(output_ptr + idx + x_clamped.s0 * dst_stride_z)) = data.s0;
- *((__global DATA_TYPE *)(output_ptr + idx + x_clamped.s1 * dst_stride_z)) = data.s1;
- *((__global DATA_TYPE *)(output_ptr + idx + x_clamped.s2 * dst_stride_z)) = data.s2;
- *((__global DATA_TYPE *)(output_ptr + idx + x_clamped.s3 * dst_stride_z)) = data.s3;
- *((__global DATA_TYPE *)(output_ptr + idx + x_clamped.s4 * dst_stride_z)) = data.s4;
- *((__global DATA_TYPE *)(output_ptr + idx + x_clamped.s5 * dst_stride_z)) = data.s5;
- *((__global DATA_TYPE *)(output_ptr + idx + x_clamped.s6 * dst_stride_z)) = data.s6;
- *((__global DATA_TYPE *)(output_ptr + idx + x_clamped.s7 * dst_stride_z)) = data.s7;
+ *((__global DATA_TYPE *)(dst.ptr + idx + x_clamped.s0 * dst.stride_z)) = data.s0;
+ *((__global DATA_TYPE *)(dst.ptr + idx + x_clamped.s1 * dst.stride_z)) = data.s1;
+ *((__global DATA_TYPE *)(dst.ptr + idx + x_clamped.s2 * dst.stride_z)) = data.s2;
+ *((__global DATA_TYPE *)(dst.ptr + idx + x_clamped.s3 * dst.stride_z)) = data.s3;
+ *((__global DATA_TYPE *)(dst.ptr + idx + x_clamped.s4 * dst.stride_z)) = data.s4;
+ *((__global DATA_TYPE *)(dst.ptr + idx + x_clamped.s5 * dst.stride_z)) = data.s5;
+ *((__global DATA_TYPE *)(dst.ptr + idx + x_clamped.s6 * dst.stride_z)) = data.s6;
+ *((__global DATA_TYPE *)(dst.ptr + idx + x_clamped.s7 * dst.stride_z)) = data.s7;
}
-#endif // defined(DATA_TYPE) && defined(WIDTH_OUTPUT) && defined(ELEMENT_SIZE) && defined(WIDTH_INPUT)
+#endif // defined(DATA_TYPE) && defined(WIDTH_OUTPUT) && defined(ELEMENT_SIZE) && defined(WIDTH_INPUT) && defined(NUM_GROUPS)