aboutsummaryrefslogtreecommitdiff
path: root/src/core/CL/cl_kernels/winograd.cl
diff options
context:
space:
mode:
Diffstat (limited to 'src/core/CL/cl_kernels/winograd.cl')
-rw-r--r--src/core/CL/cl_kernels/winograd.cl177
1 files changed, 163 insertions, 14 deletions
diff --git a/src/core/CL/cl_kernels/winograd.cl b/src/core/CL/cl_kernels/winograd.cl
index 14bebb4b0b..6a570277ab 100644
--- a/src/core/CL/cl_kernels/winograd.cl
+++ b/src/core/CL/cl_kernels/winograd.cl
@@ -23,11 +23,11 @@
*/
#include "helpers.h"
-#if defined(NUM_CHANNELS)
+#if defined(SRC_DIM_Z)
/** This OpenCL kernel performs Winograd filter transform 3x3 when the data format is NCHW and the output tile is 2x2
*
- * @note The number of channels must be passed at compile time using -DNUM_CHANNELS: e.g. -DNUM_CHANNELS=64
+ * @note In order to correctly split the input tensor in batches, its dimension across the Z axis (channels for NCHW, height for NHWC) must be passed at compile time using -DSRC_DIM_Z: e.g. -DSRC_DIM_Z=64
*
* @param[in] src_ptr Pointer to the source tensor. Supported data types: F32
* @param[in] src_stride_x Stride of the source tensor in X dimension (in bytes)
@@ -52,7 +52,7 @@ __kernel void winograd_filter_transform_2x2_3x3_nchw(
TENSOR4D_DECLARATION(src),
TENSOR3D_DECLARATION(dst))
{
- Tensor4D src = CONVERT_TO_TENSOR4D_STRUCT(src, NUM_CHANNELS);
+ Tensor4D src = CONVERT_TO_TENSOR4D_STRUCT(src, SRC_DIM_Z);
const __global uchar *src_addr = tensor4D_offset(&src, 0, 0, 0, 0);
@@ -92,8 +92,8 @@ __kernel void winograd_filter_transform_2x2_3x3_nchw(
out3.s3 = (w2.s2);
int z = get_global_id(2);
- int x0 = z / NUM_CHANNELS; // idx filter
- int y0 = z % NUM_CHANNELS; // idx channel
+ int x0 = z / SRC_DIM_Z; // idx filter
+ int y0 = z % SRC_DIM_Z; // idx channel
// Get output address
__global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + x0 * dst_stride_x + y0 * dst_stride_y;
@@ -119,7 +119,7 @@ __kernel void winograd_filter_transform_2x2_3x3_nchw(
/** This OpenCL kernel performs Winograd filter transform 3x3 when the data format is NCHW and the output tile is 4x4
*
- * @note The number of channels must be passed at compile time using -DNUM_CHANNELS: e.g. -DNUM_CHANNELS=64
+ * @note In order to correctly split the input tensor in batches, its dimension across the Z axis (channels for NCHW, height for NHWC) must be passed at compile time using -DSRC_DIM_Z: e.g. -DSRC_DIM_Z=64
*
* @param[in] src_ptr Pointer to the source tensor. Supported data types: F32
* @param[in] src_stride_x Stride of the source tensor in X dimension (in bytes)
@@ -144,7 +144,7 @@ __kernel void winograd_filter_transform_4x4_3x3_nchw(
TENSOR4D_DECLARATION(src),
TENSOR3D_DECLARATION(dst))
{
- Tensor4D src = CONVERT_TO_TENSOR4D_STRUCT(src, NUM_CHANNELS);
+ Tensor4D src = CONVERT_TO_TENSOR4D_STRUCT(src, SRC_DIM_Z);
const __global uchar *src_addr = tensor4D_offset(&src, 0, 0, 0, 0);
@@ -210,8 +210,8 @@ __kernel void winograd_filter_transform_4x4_3x3_nchw(
out5.s5 = (w2.s2);
int z = get_global_id(2);
- int x0 = z / NUM_CHANNELS; // idx filter
- int y0 = z % NUM_CHANNELS; // idx channel
+ int x0 = z / SRC_DIM_Z; // idx filter
+ int y0 = z % SRC_DIM_Z; // idx channel
// Get output address
__global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + x0 * dst_stride_x + y0 * dst_stride_y;
@@ -255,9 +255,158 @@ __kernel void winograd_filter_transform_4x4_3x3_nchw(
*(__global float *)(dst_addr + 35 * dst_stride_z) = out5.s5;
}
+/** This OpenCL kernel performs Winograd filter transform 3x3 when the data format is NHWC and the output tile is 4x4
+ *
+ * @note In order to correctly split the input tensor in batches, its dimension across the Z axis (channels for NCHW, height for NHWC) must be passed at compile time using -DSRC_DIM_Z: e.g. -DSRC_DIM_Z=64
+ *
+ * @param[in] src_ptr Pointer to the source tensor. Supported data types: F32
+ * @param[in] src_stride_x Stride of the source tensor in X dimension (in bytes)
+ * @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
+ * @param[in] src_stride_y Stride of the source tensor in Y dimension (in bytes)
+ * @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
+ * @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes)
+ * @param[in] src_step_z src_stride_z * number of elements along Z processed per workitem(in bytes)
+ * @param[in] src_stride_w Stride of the source tensor in W dimension (in bytes)
+ * @param[in] src_step_w src_stride_w * number of elements along W processed per workitem(in bytes)
+ * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source tensor
+ * @param[out] dst_ptr Pointer to the destination tensor. Supported data types: same as @p src_ptr
+ * @param[in] dst_stride_x Stride of the destination tensor in X dimension (in bytes)
+ * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
+ * @param[in] dst_stride_y Stride of the destination tensor in Y dimension (in bytes)
+ * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
+ * @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes)
+ * @param[in] src_step_z src_stride_z * number of elements along Z processed per workitem(in bytes)
+ * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
+ */
+__kernel void winograd_filter_transform_4x4_3x3_nhwc(
+ TENSOR4D_DECLARATION(src),
+ TENSOR3D_DECLARATION(dst))
+{
+ Tensor4D src = CONVERT_TO_TENSOR4D_STRUCT(src, SRC_DIM_Z);
+
+ const __global uchar *src_addr = src_ptr + src_offset_first_element_in_bytes + get_global_id(0) * src_step_x + get_global_id(1) * src_step_y + get_global_id(2) * src_step_w;
+
+ // Load the values from the input tensor
+ float w00 = *((__global float *)(src_addr + 0 * src_stride_z + 0 * src_stride_y));
+ float w01 = *((__global float *)(src_addr + 0 * src_stride_z + 1 * src_stride_y));
+ float w02 = *((__global float *)(src_addr + 0 * src_stride_z + 2 * src_stride_y));
+ float w10 = *((__global float *)(src_addr + 1 * src_stride_z + 0 * src_stride_y));
+ float w11 = *((__global float *)(src_addr + 1 * src_stride_z + 1 * src_stride_y));
+ float w12 = *((__global float *)(src_addr + 1 * src_stride_z + 2 * src_stride_y));
+ float w20 = *((__global float *)(src_addr + 2 * src_stride_z + 0 * src_stride_y));
+ float w21 = *((__global float *)(src_addr + 2 * src_stride_z + 1 * src_stride_y));
+ float w22 = *((__global float *)(src_addr + 2 * src_stride_z + 2 * src_stride_y));
+
+ // Transform the 3x3 tile in a 6x6 tile
+ float out00, out01, out02, out03, out04, out05;
+ float out10, out11, out12, out13, out14, out15;
+ float out20, out21, out22, out23, out24, out25;
+ float out30, out31, out32, out33, out34, out35;
+ float out40, out41, out42, out43, out44, out45;
+ float out50, out51, out52, out53, out54, out55;
+
+ out00 = out01 = out02 = out03 = out04 = out05 = 0.f;
+ out10 = out11 = out12 = out13 = out14 = out15 = 0.f;
+ out20 = out21 = out22 = out23 = out24 = out25 = 0.f;
+ out30 = out31 = out32 = out33 = out34 = out35 = 0.f;
+ out40 = out41 = out42 = out43 = out44 = out45 = 0.f;
+ out50 = out51 = out52 = out53 = out54 = out55 = 0.f;
+
+ // Row 0
+ out00 = (w00) / 16.f;
+ out01 = (-w00 - w01 - w02) / 24.f;
+ out02 = (-w00 + w01 - w02) / 24.f;
+ out03 = (w00 + 2.f * w01 + 4.f * w02) / 96.f;
+ out04 = (w00 - 2.f * w01 + 4.f * w02) / 96.f;
+ out05 = (w02) / 4.f;
+
+ // Row 1
+ out10 = (-w00 - w10 - w20) / 24.f;
+ out11 = (w00 + w10 + w20 + w01 + w11 + w21 + w02 + w12 + w22) / 36.f;
+ out12 = (w00 + w10 + w20 - w01 - w11 - w21 + w02 + w12 + w22) / 36.f;
+ out13 = (-w00 - w10 - w20 + 2.f * (-w01 - w11 - w21) + 4.f * (-w02 - w12 - w22)) / 144.f;
+ out14 = (-w00 - w10 - w20 + 2.f * (w01 + w11 + w21) + 4.f * (-w02 - w12 - w22)) / 144.f;
+ out15 = (-w02 - w12 - w22) / 6.f;
+
+ // Row 2
+ out20 = (-w00 + w10 - w20) / 24.f;
+ out21 = (w00 - w10 + w20 + w01 - w11 + w21 + w02 - w12 + w22) / 36.f;
+ out22 = (w00 - w10 + w20 - w01 + w11 - w21 + w02 - w12 + w22) / 36.f;
+ out23 = (-w00 + w10 - w20 + 2.f * (-w01 + w11 - w21) + 4.f * (-w02 + w12 - w22)) / 144.f;
+ out24 = (-w00 + w10 - w20 + 2.f * (w01 - w11 + w21) + 4.f * (-w02 + w12 - w22)) / 144.f;
+ out25 = (-w02 + w12 - w22) / 6.f;
+
+ // Row 3
+ out30 = (w00 + 2.f * w10 + 4.f * w20) / 96.f;
+ out31 = (-w00 - 2.f * w10 - 4.f * w20 - w01 - 2.f * w11 - 4.f * w21 - w02 - 2.f * w12 - 4.f * w22) / 144.f;
+ out32 = (-w00 - 2.f * w10 - 4.f * w20 + w01 + 2.f * w11 + 4.f * w21 - w02 - 2.f * w12 - 4.f * w22) / 144.f;
+ out33 = ((w00 + 2.f * w10 + 4.f * w20) + 2.f * (w01 + 2.f * w11 + 4.f * w21) + 4.f * (w02 + 2.f * w12 + 4.f * w22)) / 576.f;
+ out34 = ((w00 + 2.f * w10 + 4.f * w20) + 2.f * (-w01 - 2.f * w11 - 4.f * w21) + 4.f * (w02 + 2.f * w12 + 4.f * w22)) / 576.f;
+ out35 = (w02 + 2.f * w12 + 4.f * w22) / 24.f;
+
+ // Row 4
+ out40 = (w00 - 2.f * w10 + 4.f * w20) / 96.f;
+ out41 = (-w00 + 2.f * w10 - 4.f * w20 - w01 + 2.f * w11 - 4.f * w21 - w02 + 2.f * w12 - 4.f * w22) / 144.f;
+ out42 = (-w00 + 2.f * w10 - 4.f * w20 + w01 - 2.f * w11 + 4.f * w21 - w02 + 2.f * w12 - 4.f * w22) / 144.f;
+ out43 = ((w00 - 2.f * w10 + 4.f * w20) + 2.f * (w01 - 2.f * w11 + 4.f * w21) + 4.f * (w02 - 2.f * w12 + 4.f * w22)) / 576.f;
+ out44 = ((w00 - 2.f * w10 + 4.f * w20) + 2.f * (-w01 + 2.f * w11 - 4.f * w21) + 4.f * (w02 - 2.f * w12 + 4.f * w22)) / 576.f;
+ out45 = (w02 - 2.f * w12 + 4.f * w22) / 24.f;
+
+ // Row 5
+ out50 = (w20) / 4.f;
+ out51 = (-w20 - w21 - w22) / 6.f;
+ out52 = (-w20 + w21 - w22) / 6.f;
+ out53 = (w20 + 2.f * w21 + 4.f * w22) / 24.f;
+ out54 = (w20 - 2.f * w21 + 4.f * w22) / 24.f;
+ out55 = (w22);
+
+ int x0 = get_global_id(2); // idx filter
+ int y0 = get_global_id(0); // idx channel
+
+ // Get output address
+ __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + x0 * dst_stride_x + y0 * dst_stride_y;
+
+ // Store the values across the channels
+ *(__global float *)(dst_addr + 0 * dst_stride_z) = out00;
+ *(__global float *)(dst_addr + 1 * dst_stride_z) = out01;
+ *(__global float *)(dst_addr + 2 * dst_stride_z) = out02;
+ *(__global float *)(dst_addr + 3 * dst_stride_z) = out03;
+ *(__global float *)(dst_addr + 4 * dst_stride_z) = out04;
+ *(__global float *)(dst_addr + 5 * dst_stride_z) = out05;
+ *(__global float *)(dst_addr + 6 * dst_stride_z) = out10;
+ *(__global float *)(dst_addr + 7 * dst_stride_z) = out11;
+ *(__global float *)(dst_addr + 8 * dst_stride_z) = out12;
+ *(__global float *)(dst_addr + 9 * dst_stride_z) = out13;
+ *(__global float *)(dst_addr + 10 * dst_stride_z) = out14;
+ *(__global float *)(dst_addr + 11 * dst_stride_z) = out15;
+ *(__global float *)(dst_addr + 12 * dst_stride_z) = out20;
+ *(__global float *)(dst_addr + 13 * dst_stride_z) = out21;
+ *(__global float *)(dst_addr + 14 * dst_stride_z) = out22;
+ *(__global float *)(dst_addr + 15 * dst_stride_z) = out23;
+ *(__global float *)(dst_addr + 16 * dst_stride_z) = out24;
+ *(__global float *)(dst_addr + 17 * dst_stride_z) = out25;
+ *(__global float *)(dst_addr + 18 * dst_stride_z) = out30;
+ *(__global float *)(dst_addr + 19 * dst_stride_z) = out31;
+ *(__global float *)(dst_addr + 20 * dst_stride_z) = out32;
+ *(__global float *)(dst_addr + 21 * dst_stride_z) = out33;
+ *(__global float *)(dst_addr + 22 * dst_stride_z) = out34;
+ *(__global float *)(dst_addr + 23 * dst_stride_z) = out35;
+ *(__global float *)(dst_addr + 24 * dst_stride_z) = out40;
+ *(__global float *)(dst_addr + 25 * dst_stride_z) = out41;
+ *(__global float *)(dst_addr + 26 * dst_stride_z) = out42;
+ *(__global float *)(dst_addr + 27 * dst_stride_z) = out43;
+ *(__global float *)(dst_addr + 28 * dst_stride_z) = out44;
+ *(__global float *)(dst_addr + 29 * dst_stride_z) = out45;
+ *(__global float *)(dst_addr + 30 * dst_stride_z) = out50;
+ *(__global float *)(dst_addr + 31 * dst_stride_z) = out51;
+ *(__global float *)(dst_addr + 32 * dst_stride_z) = out52;
+ *(__global float *)(dst_addr + 33 * dst_stride_z) = out53;
+ *(__global float *)(dst_addr + 34 * dst_stride_z) = out54;
+ *(__global float *)(dst_addr + 35 * dst_stride_z) = out55;
+}
/** This OpenCL kernel performs Winograd filter transform 5x5 when the data format is NCHW and the output tile is 4x4
*
- * @note The number of channels must be passed at compile time using -DNUM_CHANNELS: e.g. -DNUM_CHANNELS=64
+ * @note In order to correctly split the input tensor in batches, its dimension across the Z axis (channels for NCHW, height for NHWC) must be passed at compile time using -DSRC_DIM_Z: e.g. -DSRC_DIM_Z=64
*
* @param[in] src_ptr Pointer to the source tensor. Supported data types: F32
* @param[in] src_stride_x Stride of the source tensor in X dimension (in bytes)
@@ -282,7 +431,7 @@ __kernel void winograd_filter_transform_4x4_5x5_nchw(
TENSOR4D_DECLARATION(src),
TENSOR3D_DECLARATION(dst))
{
- Tensor4D src = CONVERT_TO_TENSOR4D_STRUCT(src, NUM_CHANNELS);
+ Tensor4D src = CONVERT_TO_TENSOR4D_STRUCT(src, SRC_DIM_Z);
const __global uchar *src_addr = tensor4D_offset(&src, 0, 0, 0, 0);
@@ -452,8 +601,8 @@ __kernel void winograd_filter_transform_4x4_5x5_nchw(
out7.s7 = w41;
int z = get_global_id(2);
- int x0 = z / NUM_CHANNELS; // idx filter
- int y0 = z % NUM_CHANNELS; // idx channel
+ int x0 = z / SRC_DIM_Z; // idx filter
+ int y0 = z % SRC_DIM_Z; // idx channel
// Get output address
__global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + x0 * dst_stride_x + y0 * dst_stride_y;
@@ -524,7 +673,7 @@ __kernel void winograd_filter_transform_4x4_5x5_nchw(
*(__global float *)(dst_addr + 62 * dst_stride_z) = out7.s6;
*(__global float *)(dst_addr + 63 * dst_stride_z) = out7.s7;
}
-#endif // defined(NUM_CHANNELS)
+#endif // defined(SRC_DIM_Z)
#if defined(NUM_TILES_X) && defined(PAD_LEFT) && defined(PAD_TOP)
/** This OpenCL kernel computes the input transform when the kernel size is 3x3 and the output tile is 2x2