diff options
Diffstat (limited to 'src/core/CL/cl_kernels/winograd_input_transform.cl')
-rw-r--r-- | src/core/CL/cl_kernels/winograd_input_transform.cl | 196 |
1 files changed, 167 insertions, 29 deletions
diff --git a/src/core/CL/cl_kernels/winograd_input_transform.cl b/src/core/CL/cl_kernels/winograd_input_transform.cl index fe1c0b3c1d..01cbc84ff3 100644 --- a/src/core/CL/cl_kernels/winograd_input_transform.cl +++ b/src/core/CL/cl_kernels/winograd_input_transform.cl @@ -555,12 +555,16 @@ __kernel void winograd_input_transform_4x4_3x3_stepz1_nchw( } #if defined(SRC_DIM_1) && defined(SRC_DIM_2) -/** This OpenCL kernel computes the input transform when the output tile is 4x4, the filter size 3x3 and the data layout is NHWC +/** This OpenCL kernel computes the input transform when the output tile is 4x4, 4x1 or 1x4, the filter size 3x3, 3x1 or 1x3 and the data layout is NHWC * * @note The number of tiles in the x axis must be passed at compile time using -DNUM_TILES_X (i.e.-DNUM_TILES_X=5). * @note The pad left and pad top must be passed at compile time using -DPAD_LEFT and -DPAD_TOP (i.e.-DPAD_LEFT=1 and -DPAD_TOP=0). * @note Dimension one of the input tensor (width for NHWC data layout) must be passed at compile time using -DSRC_DIM1 (e.g. -DSRC_DIM_1=112) * @note Dimension two of the input tensor (height for NHWC data layout) must be passed at compile time using -DSRC_DIM2 (e.g. -DSRC_DIM_2=112) + * @note The width of the output tile must be passed at compile time using -DOUTPUT_TILE_W: e.g. -DOUTPUT_TILE_W=4 + * @note The height of the output tile must be passed at compile time using -DOUTPUT_TILE_H: e.g. -DOUTPUT_TILE_H=4 + * @note If this kernel is used to perform Winograd input transform 3x1, -DWINOGRAD_INPUT_TRANSFORM_HORIZONTAL has to be passed at compile time + * @note If this kernel is used to perform Winograd input transform 1x3, -DWINOGRAD_INPUT_TRANSFORM_VERTICAL has to be passed at compile time * * @param[in] src_ptr Pointer to the source image. Supported data types: F32 * @param[in] src_stride_x Stride of the source image in X dimension (in bytes) @@ -587,20 +591,25 @@ __kernel void winograd_input_transform_4x4_3x3_stepz1_nhwc( int y = get_global_id(1); int z = get_global_id(2); - __global uchar *src_addr = src_ptr + src_offset_first_element_in_bytes + x * src_stride_x; + __global uchar *src_addr = src_ptr + src_offset_first_element_in_bytes + x * sizeof(float); // Clamp coordinates. This clamp is valid for all rows - int4 y_coord0 = (int4)(y * 4) + (int4)(0, 1, 2, 3) - (int4)PAD_LEFT; - int2 y_coord1 = (int2)(y * 4) + (int2)(4, 5) - (int2)PAD_LEFT; + int4 y_coord0 = (int4)(y * OUTPUT_TILE_W) + (int4)(0, 1, 2, 3) - (int4)PAD_LEFT; + int2 y_coord1 = (int2)(y * OUTPUT_TILE_W) + (int2)(4, 5) - (int2)PAD_LEFT; y_coord0 = clamp(y_coord0, -1, SRC_DIM_1); y_coord1 = clamp(y_coord1, -1, SRC_DIM_1); + int z_coord; + int4 valid_y0; + int2 valid_y1; + +#if !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL) // Row4 - int z_coord = (z * 4) - PAD_TOP + 4; + z_coord = (z * 4) - PAD_TOP + 4; // If z < 0, set y to -1 - int4 valid_y0 = select(y_coord0, -1, (int4)z_coord < 0); - int2 valid_y1 = select(y_coord1, -1, (int2)z_coord < 0); + valid_y0 = select(y_coord0, -1, (int4)z_coord < 0); + valid_y1 = select(y_coord1, -1, (int2)z_coord < 0); // If z >= SRC_DIM_2, set y to SRC_DIM_2 valid_y0 = select(valid_y0, SRC_DIM_1, (int4)z_coord >= SRC_DIM_2); valid_y1 = select(valid_y1, SRC_DIM_1, (int2)z_coord >= SRC_DIM_2); @@ -628,9 +637,11 @@ __kernel void winograd_input_transform_4x4_3x3_stepz1_nhwc( k3 += -2.0f * d41 + 2.0f * d43 - d42; k4 += 2.0f * d41 - 2.0f * d43 - d42; k5 += 4.0f * d41 - 5.0f * d43 + d45; +#endif // !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL) +#if !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL) // Row0 - z_coord = (z * 4) - PAD_TOP + 0; + z_coord = (z * OUTPUT_TILE_H) - PAD_TOP + 0; #if PAD_TOP != 0 valid_y0 = select(y_coord0, -1, (int4)z_coord < 0); @@ -649,9 +660,36 @@ __kernel void winograd_input_transform_4x4_3x3_stepz1_nhwc( float d03 = *(__global float *)(src_addr + valid_y0.s3 * (int)src_stride_y + z_coord * src_stride_z); float d04 = *(__global float *)(src_addr + valid_y1.s0 * (int)src_stride_y + z_coord * src_stride_z); float d05 = *(__global float *)(src_addr + valid_y1.s1 * (int)src_stride_y + z_coord * src_stride_z); +#else // !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL) + int4 z_coords0 = (int4)(z * OUTPUT_TILE_H) + (int4)(0, 1, 2, 3) - (int4)PAD_TOP; + int2 z_coords1 = (int2)(z * OUTPUT_TILE_H) + (int2)(4, 5) - (int2)PAD_TOP; + + valid_y0 = select((int4)y_coord0.s0, (int4) - 1, z_coords0 < (int4)0); + valid_y1 = select((int2)y_coord0.s0, (int2) - 1, z_coords1 < (int2)0); + valid_y0 = select(valid_y0, (int4)SRC_DIM_1, z_coords0 >= (int4)SRC_DIM_2); + valid_y1 = select(valid_y1, (int2)SRC_DIM_1, z_coords1 >= (int2)SRC_DIM_2); + + z_coords0 = clamp((int4)z_coords0, (int4)0, (int4)(SRC_DIM_2 - 1)); + z_coords1 = clamp((int2)z_coords1, (int2)0, (int2)(SRC_DIM_2 - 1)); + + float d00 = *(__global float *)(src_addr + valid_y0.s0 * (int)src_stride_y + z_coords0.s0 * src_stride_z); + float d01 = *(__global float *)(src_addr + valid_y0.s1 * (int)src_stride_y + z_coords0.s1 * src_stride_z); + float d02 = *(__global float *)(src_addr + valid_y0.s2 * (int)src_stride_y + z_coords0.s2 * src_stride_z); + float d03 = *(__global float *)(src_addr + valid_y0.s3 * (int)src_stride_y + z_coords0.s3 * src_stride_z); + float d04 = *(__global float *)(src_addr + valid_y1.s0 * (int)src_stride_y + z_coords1.s0 * src_stride_z); + float d05 = *(__global float *)(src_addr + valid_y1.s1 * (int)src_stride_y + z_coords1.s1 * src_stride_z); +#endif // !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL) + + float out0 = 16.0f * d00 - 20.0f * d02 + 4.0f * d04; + float out1 = -16.0f * d01 - 16.0f * d02 + 4.0f * d03 + 4.0f * d04; + float out2 = 16.0f * d01 - 16.0f * d02 - 4.0f * d03 + 4.0f * d04; + float out3 = -8.0f * d01 - 4.0f * d02 + 8.0f * d03 + 4.0f * d04; + float out4 = 8.0f * d01 - 4.0f * d02 - 8.0f * d03 + 4.0f * d04; + float out5 = 16.0f * d01 - 20.0f * d03 + 4.0f * d05; +#if !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL) // Row2 - z_coord = (z * 4) - PAD_TOP + 2; + z_coord = (z * OUTPUT_TILE_H) - PAD_TOP + 2; valid_y0 = select(y_coord0, -1, (int4)z_coord < 0); valid_y1 = select(y_coord1, -1, (int2)z_coord < 0); valid_y0 = select(valid_y0, SRC_DIM_1, (int4)z_coord >= SRC_DIM_2); @@ -665,17 +703,12 @@ __kernel void winograd_input_transform_4x4_3x3_stepz1_nhwc( float d24 = *(__global float *)(src_addr + valid_y1.s0 * (int)src_stride_y + z_coord * src_stride_z); float d25 = *(__global float *)(src_addr + valid_y1.s1 * (int)src_stride_y + z_coord * src_stride_z); - // Compute destination address - __global float *dst_addr = (__global float *)(dst_ptr + dst_offset_first_element_in_bytes + x * dst_stride_x + (y + z * (int)NUM_TILES_X) * dst_stride_y); - - uint dst_plane_stride = dst_stride_z / sizeof(float); - - float out0 = k0; - float out1 = k1; - float out2 = k2; - float out3 = k3; - float out4 = k4; - float out5 = k5; + out0 += k0; + out1 += k1; + out2 += k2; + out3 += k3; + out4 += k4; + out5 += k5; float out6 = k0; float out7 = k1; float out8 = k2; @@ -702,12 +735,17 @@ __kernel void winograd_input_transform_4x4_3x3_stepz1_nhwc( float out29 = k5; // Channels [0, 5]: [out00, out01, out02, out03, out04, out05] - out0 += 16.0f * d00 - 20.0f * d02 - 20.0f * d20 + 25.0f * d22 + 4.0f * d04 - 5.0f * d24; - out1 += -16.0f * d01 - 16.0f * d02 + 4.0f * d03 + 20.0f * d21 + 20.0f * d22 - 5.0f * d23 + 4.0f * d04 - 5.0f * d24; - out2 += 16.0f * d01 - 16.0f * d02 - 4.0f * d03 - 20.0f * d21 + 20.0f * d22 + 5.0f * d23 + 4.0f * d04 - 5.0f * d24; - out3 += -8.0f * d01 - 4.0f * d02 + 8.0f * d03 + 10.0f * d21 + 5.0f * d22 - 10.0f * d23 + 4.0f * d04 - 5.0f * d24; - out4 += 8.0f * d01 - 4.0f * d02 - 8.0f * d03 - 10.0f * d21 + 5.0f * d22 + 10.0f * d23 + 4.0f * d04 - 5.0f * d24; - out5 += 16.0f * d01 - 20.0f * d03 - 20.0f * d21 + 4.0f * d05 + 25.0f * d23 - 5.0f * d25; + out0 += -20.0f * d20 + 25.0f * d22 - 5.0f * d24; + out1 += 20.0f * d21 + 20.0f * d22 - 5.0f * d23 - 5.0f * d24; + out2 += -20.0f * d21 + 20.0f * d22 + 5.0f * d23 - 5.0f * d24; + out3 += 10.0f * d21 + 5.0f * d22 - 10.0f * d23 - 5.0f * d24; + out4 += -10.0f * d21 + 5.0f * d22 + 10.0f * d23 - 5.0f * d24; + out5 += -20.0f * d21 + 25.0f * d23 - 5.0f * d25; +#endif // !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL) + + // Compute destination address + __global float *dst_addr = (__global float *)(dst_ptr + dst_offset_first_element_in_bytes + x * sizeof(float) + (y + z * (int)NUM_TILES_X) * dst_stride_y); + uint dst_plane_stride = dst_stride_z / sizeof(float); *((__global float *)dst_addr) = out0; dst_addr += dst_plane_stride; @@ -722,8 +760,9 @@ __kernel void winograd_input_transform_4x4_3x3_stepz1_nhwc( *((__global float *)dst_addr) = out5; dst_addr += dst_plane_stride; +#if !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL) // Row1 - z_coord = (z * 4) - PAD_TOP + 1; + z_coord = (z * OUTPUT_TILE_H) - PAD_TOP + 1; // Row1 can never be out of bounds valid_y0 = y_coord0; valid_y1 = y_coord1; @@ -736,7 +775,7 @@ __kernel void winograd_input_transform_4x4_3x3_stepz1_nhwc( float d15 = *(__global float *)(src_addr + valid_y1.s1 * (int)src_stride_y + z_coord * src_stride_z); // Row3 - z_coord = (z * 4) - PAD_TOP + 3; + z_coord = (z * OUTPUT_TILE_H) - PAD_TOP + 3; valid_y0 = select(y_coord0, -1, (int4)z_coord < 0); valid_y1 = select(y_coord1, -1, (int2)z_coord < 0); valid_y0 = select(valid_y0, SRC_DIM_1, (int4)z_coord >= SRC_DIM_2); @@ -859,7 +898,7 @@ __kernel void winograd_input_transform_4x4_3x3_stepz1_nhwc( dst_addr += dst_plane_stride; // Row5 - z_coord = (z * 4) - PAD_TOP + 5; + z_coord = (z * OUTPUT_TILE_H) - PAD_TOP + 5; valid_y0 = select(y_coord0, -1, (int4)z_coord < 0); valid_y1 = select(y_coord1, -1, (int2)z_coord < 0); valid_y0 = select(valid_y0, SRC_DIM_1, (int4)z_coord >= SRC_DIM_2); @@ -894,7 +933,106 @@ __kernel void winograd_input_transform_4x4_3x3_stepz1_nhwc( dst_addr += dst_plane_stride; *((__global float *)dst_addr) = out5; dst_addr += dst_plane_stride; +#endif // !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL) +} + +#if defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) +/** This OpenCL kernel computes the input transform when the kernel size is 3x1 and the output tile is 4x1 for data layout NHWC + * + * @note The number of tiles in the x axis must be passed at compile time using -DNUM_TILES_X (i.e.-DNUM_TILES_X=5). + * @note The pad left and pad top must be passed at compile time using -DPAD_LEFT and -DPAD_TOP (i.e.-DPAD_LEFT=1 and -DPAD_TOP=0). + * @note The width of the output tile must be passed at compile time using -DOUTPUT_TILE_W: e.g. -DOUTPUT_TILE_W=4 + * @note The height of the output tile must be passed at compile time using -DOUTPUT_TILE_H: e.g. -DOUTPUT_TILE_H=1 + * @note -DWINOGRAD_INPUT_TRANSFORM_HORIZONTAL has to be passed at compile time + * + * @param[in] src_ptr Pointer to the source image. Supported data types: F32 + * @param[in] src_stride_x Stride of the source image in X dimension (in bytes) + * @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes) + * @param[in] src_stride_y Stride of the source image in Y dimension (in bytes) + * @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem(in bytes) + * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source image + * @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes) + * @param[in] src_step_z src_stride_z * number of elements along Y processed per workitem(in bytes) + * @param[in] dst_ptr Pointer to the destination tensor. Supported data types: as @p src_ptr + * @param[in] dst_stride_x Stride of the destination tensor in X dimension (in bytes) + * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes) + * @param[in] dst_stride_y Stride of the destination tensor in Y dimension (in bytes) + * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes) + * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes) + * @param[in] dst_step_z dst_stride_z * number of elements along Y processed per workitem(in bytes) + * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor + */ +__kernel void winograd_input_transform_4x1_3x1_stepz1_nhwc( + TENSOR3D_DECLARATION(src), + TENSOR3D_DECLARATION(dst)) +{ + winograd_input_transform_4x4_3x3_stepz1_nhwc(src_ptr, + src_stride_x, + src_step_x, + src_stride_y, + src_step_y, + src_stride_z, + src_step_z, + src_offset_first_element_in_bytes, + dst_ptr, + dst_stride_x, + dst_step_x, + dst_stride_y, + dst_step_y, + dst_stride_z, + dst_step_z, + dst_offset_first_element_in_bytes); +} +#endif // defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) + +#if defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL) +/** This OpenCL kernel computes the input transform when the kernel size is 1x3 and the output tile is 1x4 for data layout NHWC + * + * @note The number of tiles in the x axis must be passed at compile time using -DNUM_TILES_X (i.e.-DNUM_TILES_X=5). + * @note The pad left and pad top must be passed at compile time using -DPAD_LEFT and -DPAD_TOP (i.e.-DPAD_LEFT=1 and -DPAD_TOP=0). + * @note The width of the output tile must be passed at compile time using -DOUTPUT_TILE_W: e.g. -DOUTPUT_TILE_W=1 + * @note The height of the output tile must be passed at compile time using -DOUTPUT_TILE_H: e.g. -DOUTPUT_TILE_H=4 + * @note -DWINOGRAD_INPUT_TRANSFORM_VERTICAL has to be passed at compile time + * + * @param[in] src_ptr Pointer to the source image. Supported data types: F32 + * @param[in] src_stride_x Stride of the source image in X dimension (in bytes) + * @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes) + * @param[in] src_stride_y Stride of the source image in Y dimension (in bytes) + * @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem(in bytes) + * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source image + * @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes) + * @param[in] src_step_z src_stride_z * number of elements along Y processed per workitem(in bytes) + * @param[in] dst_ptr Pointer to the destination tensor. Supported data types: as @p src_ptr + * @param[in] dst_stride_x Stride of the destination tensor in X dimension (in bytes) + * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes) + * @param[in] dst_stride_y Stride of the destination tensor in Y dimension (in bytes) + * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes) + * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes) + * @param[in] dst_step_z dst_stride_z * number of elements along Y processed per workitem(in bytes) + * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor + */ +__kernel void winograd_input_transform_1x4_1x3_stepz1_nhwc( + TENSOR3D_DECLARATION(src), + TENSOR3D_DECLARATION(dst)) +{ + winograd_input_transform_4x4_3x3_stepz1_nhwc(src_ptr, + src_stride_x, + src_step_x, + src_stride_y, + src_step_y, + src_stride_z, + src_step_z, + src_offset_first_element_in_bytes, + dst_ptr, + dst_stride_x, + dst_step_x, + dst_stride_y, + dst_step_y, + dst_stride_z, + dst_step_z, + dst_offset_first_element_in_bytes); } +#endif // defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL) #endif // defined(SRC_DIM_1) && defined(SRC_DIM_2) |