author    Giorgio Arena <giorgio.arena@arm.com>      2018-06-08 17:50:38 +0100
committer Anthony Barbier <anthony.barbier@arm.com>  2018-11-02 16:53:20 +0000
commit    be39f1281ee4ad4e83b92ad5a09f6bdc40b5718f
tree      303b1250013301f915a43fe7cb86a5deedcaefe1
parent    a8fdbbb072ed3a36da590d39168d676d6f5de26a
COMPMID-1204 Add NHWC data format support to Winograd input transform 4x4_5x5
Change-Id: I3dffdd1772b78db27a4374f074a24a15a9552189
Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/134859
Reviewed-by: Gian Marco Iodice <gianmarco.iodice@arm.com>
Tested-by: Jenkins <bsgcomp@arm.com>
-rw-r--r--  arm_compute/core/CL/kernels/CLWinogradInputTransformKernel.h    4
-rw-r--r--  src/core/CL/CLKernelLibrary.cpp                                  1
-rw-r--r--  src/core/CL/cl_kernels/winograd.cl                             271
-rw-r--r--  src/core/CL/kernels/CLWinogradInputTransformKernel.cpp           2
-rw-r--r--  tests/validation/CL/Winograd.cpp                                 4
5 files changed, 277 insertions, 5 deletions
diff --git a/arm_compute/core/CL/kernels/CLWinogradInputTransformKernel.h b/arm_compute/core/CL/kernels/CLWinogradInputTransformKernel.h
index 58e8291161..ddf07200d8 100644
--- a/arm_compute/core/CL/kernels/CLWinogradInputTransformKernel.h
+++ b/arm_compute/core/CL/kernels/CLWinogradInputTransformKernel.h
@@ -49,7 +49,7 @@ public:
* @note Winograd input transform supports the following configurations:
* F(output tile, kernel size):F(2x2, 3x3), F(4x4, 3x3), F(4x4, 5x5)
* Strides: only unit strides
- * Data Layout: NCHW for all configurations, NHWC for F(4x4, 3x3)
+ * Data Layout: NCHW for all configurations, NHWC for F(4x4, 3x3), F(4x4, 5x5)
*
* @param[in] input The input tensor to transform. Data types supported: F32
* @param[in] output The output tensor. The shape for this tensor can be calculated using the utility function @p compute_winograd_input_transform_shape. Data types supported: Same as @p input
@@ -61,7 +61,7 @@ public:
* @note Winograd input transform supports the following configurations:
* F(output tile, kernel size):F(2x2, 3x3), F(4x4, 3x3), F(4x4, 5x5)
* Strides: only unit strides
- * Data Layout: NCHW for all configurations, NHWC for F(4x4, 3x3)
+ * Data Layout: NCHW for all configurations, NHWC for F(4x4, 3x3), F(4x4, 5x5)
*
* @param[in] input The input tensor to transform. Data types supported: F32
* @param[in] output The output tensor. The shape for this tensor can be calculated using the utility function @p compute_winograd_input_transform_shape. Data types supported: Same as @p input
diff --git a/src/core/CL/CLKernelLibrary.cpp b/src/core/CL/CLKernelLibrary.cpp
index 4a37b8ae03..97e9e1057b 100644
--- a/src/core/CL/CLKernelLibrary.cpp
+++ b/src/core/CL/CLKernelLibrary.cpp
@@ -379,6 +379,7 @@ const std::map<std::string, std::string> CLKernelLibrary::_kernel_program_map =
{ "winograd_input_transform_2x2_3x3_stepz2_nchw", "winograd.cl" },
{ "winograd_input_transform_4x4_3x3_stepz1_nchw", "winograd.cl" },
{ "winograd_input_transform_4x4_3x3_stepz1_nhwc", "winograd.cl" },
+ { "winograd_input_transform_4x4_5x5_stepz1_nhwc", "winograd.cl" },
{ "winograd_output_transform_2x2_3x3_nchw", "winograd.cl" },
{ "winograd_output_transform_4x4_3x3_nchw", "winograd.cl" },
{ "winograd_output_transform_4x4_5x5_nchw", "winograd.cl" },
diff --git a/src/core/CL/cl_kernels/winograd.cl b/src/core/CL/cl_kernels/winograd.cl
index ea499a83f0..93e038fff9 100644
--- a/src/core/CL/cl_kernels/winograd.cl
+++ b/src/core/CL/cl_kernels/winograd.cl
@@ -1374,6 +1374,8 @@ __kernel void winograd_input_transform_4x4_3x3_stepz1_nchw(
*
* @note The number of tiles in the x axis must be passed at compile time using -DNUM_TILES_X (i.e.-DNUM_TILES_X=5).
* @note The pad left and pad top must be passed at compile time using -DPAD_LEFT and -DPAD_TOP (i.e.-DPAD_LEFT=1 and -DPAD_TOP=0).
+ * @note Dimension one of the input tensor (width for NHWC data layout) must be passed at compile time using -DSRC_DIM_1 (e.g. -DSRC_DIM_1=112)
+ * @note Dimension two of the input tensor (height for NHWC data layout) must be passed at compile time using -DSRC_DIM_2 (e.g. -DSRC_DIM_2=112)
*
* @param[in] src_ptr Pointer to the source image. Supported data types: F32
* @param[in] src_stride_x Stride of the source image in X dimension (in bytes)
@@ -1879,6 +1881,275 @@ __kernel void winograd_input_transform_4x4_5x5_stepz1_nchw(
*((__global float *)(dst_addr + 62 * dst_stride_z)) = out7.s6;
*((__global float *)(dst_addr + 63 * dst_stride_z)) = out7.s7;
}
+
+#if defined(SRC_DIM_1) && defined(SRC_DIM_2)
+
+/** This OpenCL kernel computes the input transform when the kernel size is 5x5, the output tile is 4x4 and data layout is NHWC
+ *
+ * @note The number of tiles in the x axis must be passed at compile time using -DNUM_TILES_X (i.e.-DNUM_TILES_X=5).
+ * @note The pad left and pad top must be passed at compile time using -DPAD_LEFT and -DPAD_TOP (i.e.-DPAD_LEFT=1 and -DPAD_TOP=0).
+ * @note Dimension one of the input tensor (width for NHWC data layout) must be passed at compile time using -DSRC_DIM_1 (e.g. -DSRC_DIM_1=112)
+ * @note Dimension two of the input tensor (height for NHWC data layout) must be passed at compile time using -DSRC_DIM_2 (e.g. -DSRC_DIM_2=112)
+ *
+ * @param[in] src_ptr Pointer to the source image. Supported data types: F32
+ * @param[in] src_stride_x Stride of the source image in X dimension (in bytes)
+ * @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
+ * @param[in] src_stride_y Stride of the source image in Y dimension (in bytes)
+ * @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
+ * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source image
+ * @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes)
+ * @param[in] src_step_z src_stride_z * number of elements along Z processed per workitem(in bytes)
+ * @param[out] dst_ptr Pointer to the destination tensor. Supported data types: same as @p src_ptr
+ * @param[in] dst_stride_x Stride of the destination tensor in X dimension (in bytes)
+ * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
+ * @param[in] dst_stride_y Stride of the destination tensor in Y dimension (in bytes)
+ * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
+ * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
+ * @param[in] dst_step_z dst_stride_z * number of elements along Z processed per workitem(in bytes)
+ * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
+ */
+__kernel void winograd_input_transform_4x4_5x5_stepz1_nhwc(
+ TENSOR3D_DECLARATION(src),
+ TENSOR3D_DECLARATION(dst))
+{
+ int x = get_global_id(0);
+ int y = get_global_id(1);
+ int z = get_global_id(2);
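+ // The global IDs map to the NHWC dimensions as: x -> channel, y -> tile index along the width, z -> tile index along the height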
+
+ // Compute input address
+ __global uchar *src_addr = src_ptr + src_offset_first_element_in_bytes + x * sizeof(float);
+
+ // Clamp coordinates. This clamp is valid for all rows
+ int8 y_coord = (int8)(y * 4) + (int8)(0, 1, 2, 3, 4, 5, 6, 7) - (int8)PAD_LEFT;
+ y_coord = clamp(y_coord, -1, SRC_DIM_1);
+
+ // Load 8x8 input tile
+ float8 in_row0, in_row1, in_row2, in_row3, in_row4, in_row5, in_row6, in_row7;
+
+ // Row0
+ int z_coord = (z * 4) - PAD_TOP + 0;
+ int8 valid_y = select(y_coord, -1, (int8)z_coord < 0); // If z < 0, set y to -1
+ valid_y = select(valid_y, SRC_DIM_1, (int8)z_coord >= SRC_DIM_2); // If z >= SRC_DIM_2, set y to SRC_DIM_1 (out of bounds)
+ z_coord = clamp(z_coord, 0, SRC_DIM_2 - 1); // Clamp z coordinate
+
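+ // Gather the 8 pixels of this tile row one element at a time: channel x, width indices valid_y.s0..s7, height z_coord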
+ in_row0.s0 = *(__global float *)(src_addr + valid_y.s0 * (int)src_stride_y + z_coord * src_stride_z);
+ in_row0.s1 = *(__global float *)(src_addr + valid_y.s1 * (int)src_stride_y + z_coord * src_stride_z);
+ in_row0.s2 = *(__global float *)(src_addr + valid_y.s2 * (int)src_stride_y + z_coord * src_stride_z);
+ in_row0.s3 = *(__global float *)(src_addr + valid_y.s3 * (int)src_stride_y + z_coord * src_stride_z);
+ in_row0.s4 = *(__global float *)(src_addr + valid_y.s4 * (int)src_stride_y + z_coord * src_stride_z);
+ in_row0.s5 = *(__global float *)(src_addr + valid_y.s5 * (int)src_stride_y + z_coord * src_stride_z);
+ in_row0.s6 = *(__global float *)(src_addr + valid_y.s6 * (int)src_stride_y + z_coord * src_stride_z);
+ in_row0.s7 = *(__global float *)(src_addr + valid_y.s7 * (int)src_stride_y + z_coord * src_stride_z);
+
+ // Row1
+ z_coord = (z * 4) - PAD_TOP + 1;
+ valid_y = select(y_coord, -1, (int8)z_coord < 0);
+ valid_y = select(valid_y, SRC_DIM_1, (int8)z_coord >= SRC_DIM_2);
+ z_coord = clamp(z_coord, 0, SRC_DIM_2 - 1);
+
+ in_row1.s0 = *(__global float *)(src_addr + valid_y.s0 * (int)src_stride_y + z_coord * src_stride_z);
+ in_row1.s1 = *(__global float *)(src_addr + valid_y.s1 * (int)src_stride_y + z_coord * src_stride_z);
+ in_row1.s2 = *(__global float *)(src_addr + valid_y.s2 * (int)src_stride_y + z_coord * src_stride_z);
+ in_row1.s3 = *(__global float *)(src_addr + valid_y.s3 * (int)src_stride_y + z_coord * src_stride_z);
+ in_row1.s4 = *(__global float *)(src_addr + valid_y.s4 * (int)src_stride_y + z_coord * src_stride_z);
+ in_row1.s5 = *(__global float *)(src_addr + valid_y.s5 * (int)src_stride_y + z_coord * src_stride_z);
+ in_row1.s6 = *(__global float *)(src_addr + valid_y.s6 * (int)src_stride_y + z_coord * src_stride_z);
+ in_row1.s7 = *(__global float *)(src_addr + valid_y.s7 * (int)src_stride_y + z_coord * src_stride_z);
+
+ // Row2
+ z_coord = (z * 4) - PAD_TOP + 2;
+ valid_y = select(y_coord, -1, (int8)z_coord < 0);
+ valid_y = select(valid_y, SRC_DIM_1, (int8)z_coord >= SRC_DIM_2);
+ z_coord = clamp(z_coord, 0, SRC_DIM_2 - 1);
+
+ in_row2.s0 = *(__global float *)(src_addr + valid_y.s0 * (int)src_stride_y + z_coord * src_stride_z);
+ in_row2.s1 = *(__global float *)(src_addr + valid_y.s1 * (int)src_stride_y + z_coord * src_stride_z);
+ in_row2.s2 = *(__global float *)(src_addr + valid_y.s2 * (int)src_stride_y + z_coord * src_stride_z);
+ in_row2.s3 = *(__global float *)(src_addr + valid_y.s3 * (int)src_stride_y + z_coord * src_stride_z);
+ in_row2.s4 = *(__global float *)(src_addr + valid_y.s4 * (int)src_stride_y + z_coord * src_stride_z);
+ in_row2.s5 = *(__global float *)(src_addr + valid_y.s5 * (int)src_stride_y + z_coord * src_stride_z);
+ in_row2.s6 = *(__global float *)(src_addr + valid_y.s6 * (int)src_stride_y + z_coord * src_stride_z);
+ in_row2.s7 = *(__global float *)(src_addr + valid_y.s7 * (int)src_stride_y + z_coord * src_stride_z);
+
+ // Row3
+ z_coord = (z * 4) - PAD_TOP + 3;
+ valid_y = select(y_coord, -1, (int8)z_coord < 0);
+ valid_y = select(valid_y, SRC_DIM_1, (int8)z_coord >= SRC_DIM_2);
+ z_coord = clamp(z_coord, 0, SRC_DIM_2 - 1);
+
+ in_row3.s0 = *(__global float *)(src_addr + valid_y.s0 * (int)src_stride_y + z_coord * src_stride_z);
+ in_row3.s1 = *(__global float *)(src_addr + valid_y.s1 * (int)src_stride_y + z_coord * src_stride_z);
+ in_row3.s2 = *(__global float *)(src_addr + valid_y.s2 * (int)src_stride_y + z_coord * src_stride_z);
+ in_row3.s3 = *(__global float *)(src_addr + valid_y.s3 * (int)src_stride_y + z_coord * src_stride_z);
+ in_row3.s4 = *(__global float *)(src_addr + valid_y.s4 * (int)src_stride_y + z_coord * src_stride_z);
+ in_row3.s5 = *(__global float *)(src_addr + valid_y.s5 * (int)src_stride_y + z_coord * src_stride_z);
+ in_row3.s6 = *(__global float *)(src_addr + valid_y.s6 * (int)src_stride_y + z_coord * src_stride_z);
+ in_row3.s7 = *(__global float *)(src_addr + valid_y.s7 * (int)src_stride_y + z_coord * src_stride_z);
+
+ // Row4
+ z_coord = (z * 4) - PAD_TOP + 4;
+ valid_y = select(y_coord, -1, (int8)z_coord < 0);
+ valid_y = select(valid_y, SRC_DIM_1, (int8)z_coord >= SRC_DIM_2);
+ z_coord = clamp(z_coord, 0, SRC_DIM_2 - 1);
+
+ in_row4.s0 = *(__global float *)(src_addr + valid_y.s0 * (int)src_stride_y + z_coord * src_stride_z);
+ in_row4.s1 = *(__global float *)(src_addr + valid_y.s1 * (int)src_stride_y + z_coord * src_stride_z);
+ in_row4.s2 = *(__global float *)(src_addr + valid_y.s2 * (int)src_stride_y + z_coord * src_stride_z);
+ in_row4.s3 = *(__global float *)(src_addr + valid_y.s3 * (int)src_stride_y + z_coord * src_stride_z);
+ in_row4.s4 = *(__global float *)(src_addr + valid_y.s4 * (int)src_stride_y + z_coord * src_stride_z);
+ in_row4.s5 = *(__global float *)(src_addr + valid_y.s5 * (int)src_stride_y + z_coord * src_stride_z);
+ in_row4.s6 = *(__global float *)(src_addr + valid_y.s6 * (int)src_stride_y + z_coord * src_stride_z);
+ in_row4.s7 = *(__global float *)(src_addr + valid_y.s7 * (int)src_stride_y + z_coord * src_stride_z);
+
+ // Row5
+ z_coord = (z * 4) - PAD_TOP + 5;
+ valid_y = select(y_coord, -1, (int8)z_coord < 0);
+ valid_y = select(valid_y, SRC_DIM_1, (int8)z_coord >= SRC_DIM_2);
+ z_coord = clamp(z_coord, 0, SRC_DIM_2 - 1);
+
+ in_row5.s0 = *(__global float *)(src_addr + valid_y.s0 * (int)src_stride_y + z_coord * src_stride_z);
+ in_row5.s1 = *(__global float *)(src_addr + valid_y.s1 * (int)src_stride_y + z_coord * src_stride_z);
+ in_row5.s2 = *(__global float *)(src_addr + valid_y.s2 * (int)src_stride_y + z_coord * src_stride_z);
+ in_row5.s3 = *(__global float *)(src_addr + valid_y.s3 * (int)src_stride_y + z_coord * src_stride_z);
+ in_row5.s4 = *(__global float *)(src_addr + valid_y.s4 * (int)src_stride_y + z_coord * src_stride_z);
+ in_row5.s5 = *(__global float *)(src_addr + valid_y.s5 * (int)src_stride_y + z_coord * src_stride_z);
+ in_row5.s6 = *(__global float *)(src_addr + valid_y.s6 * (int)src_stride_y + z_coord * src_stride_z);
+ in_row5.s7 = *(__global float *)(src_addr + valid_y.s7 * (int)src_stride_y + z_coord * src_stride_z);
+
+ // Row6
+ z_coord = (z * 4) - PAD_TOP + 6;
+ valid_y = select(y_coord, -1, (int8)z_coord < 0);
+ valid_y = select(valid_y, SRC_DIM_1, (int8)z_coord >= SRC_DIM_2);
+ z_coord = clamp(z_coord, 0, SRC_DIM_2 - 1);
+
+ in_row6.s0 = *(__global float *)(src_addr + valid_y.s0 * (int)src_stride_y + z_coord * src_stride_z);
+ in_row6.s1 = *(__global float *)(src_addr + valid_y.s1 * (int)src_stride_y + z_coord * src_stride_z);
+ in_row6.s2 = *(__global float *)(src_addr + valid_y.s2 * (int)src_stride_y + z_coord * src_stride_z);
+ in_row6.s3 = *(__global float *)(src_addr + valid_y.s3 * (int)src_stride_y + z_coord * src_stride_z);
+ in_row6.s4 = *(__global float *)(src_addr + valid_y.s4 * (int)src_stride_y + z_coord * src_stride_z);
+ in_row6.s5 = *(__global float *)(src_addr + valid_y.s5 * (int)src_stride_y + z_coord * src_stride_z);
+ in_row6.s6 = *(__global float *)(src_addr + valid_y.s6 * (int)src_stride_y + z_coord * src_stride_z);
+ in_row6.s7 = *(__global float *)(src_addr + valid_y.s7 * (int)src_stride_y + z_coord * src_stride_z);
+
+ // Row7
+ z_coord = (z * 4) - PAD_TOP + 7;
+ valid_y = select(y_coord, -1, (int8)z_coord < 0);
+ valid_y = select(valid_y, SRC_DIM_1, (int8)z_coord >= SRC_DIM_2);
+ z_coord = clamp(z_coord, 0, SRC_DIM_2 - 1);
+
+ in_row7.s0 = *(__global float *)(src_addr + valid_y.s0 * (int)src_stride_y + z_coord * src_stride_z);
+ in_row7.s1 = *(__global float *)(src_addr + valid_y.s1 * (int)src_stride_y + z_coord * src_stride_z);
+ in_row7.s2 = *(__global float *)(src_addr + valid_y.s2 * (int)src_stride_y + z_coord * src_stride_z);
+ in_row7.s3 = *(__global float *)(src_addr + valid_y.s3 * (int)src_stride_y + z_coord * src_stride_z);
+ in_row7.s4 = *(__global float *)(src_addr + valid_y.s4 * (int)src_stride_y + z_coord * src_stride_z);
+ in_row7.s5 = *(__global float *)(src_addr + valid_y.s5 * (int)src_stride_y + z_coord * src_stride_z);
+ in_row7.s6 = *(__global float *)(src_addr + valid_y.s6 * (int)src_stride_y + z_coord * src_stride_z);
+ in_row7.s7 = *(__global float *)(src_addr + valid_y.s7 * (int)src_stride_y + z_coord * src_stride_z);
+
+ // Calculate common factors for intermediate tensor
+ float8 comm_fact0 = in_row2 + in_row6 - 4.25f * in_row4;
+ float8 comm_fact1 = in_row1 + in_row5 - 4.25f * in_row3;
+ float8 comm_fact2 = 0.25f * in_row2 - 1.25f * in_row4 + in_row6;
+
+ // Calculate intermediate tensor and reuse common factor vectors
+ const float8 tmp0 = in_row0 - in_row6 + 5.25f * in_row4 - 5.25f * in_row2;
+ const float8 tmp1 = comm_fact0 + comm_fact1;
+ const float8 tmp2 = comm_fact0 - comm_fact1;
+
+ comm_fact0 = 2.5f * in_row3;
+ comm_fact1 = 0.5f * in_row1 - comm_fact0 + 2.f * in_row5;
+
+ const float8 tmp3 = comm_fact1 + comm_fact2;
+ const float8 tmp4 = comm_fact2 - comm_fact1;
+
+ comm_fact1 = 2.f * in_row1 - comm_fact0 + 0.5f * in_row5;
+ comm_fact2 = 4.f * in_row2 - 5.f * in_row4 + in_row6;
+
+ const float8 tmp5 = comm_fact1 + comm_fact2;
+ const float8 tmp6 = comm_fact2 - comm_fact1;
+ const float8 tmp7 = in_row7 - in_row1 + 5.25f * in_row3 - 5.25f * in_row5;
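+
+ // tmp0..tmp7 are the rows of the 8x8 Winograd F(4x4, 5x5) input transform applied along the vertical axis (B^T * d);
+ // OUTPUT_ROW_4x4_5x5 below applies the same 1D transform along the horizontal axis to complete B^T * d * B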
+
+ // Calculate output rows (reuse comm_fact0 vector)
+ float8 out0, out1, out2, out3, out4, out5, out6, out7;
+
+ OUTPUT_ROW_4x4_5x5(out0, tmp0, comm_fact0);
+ OUTPUT_ROW_4x4_5x5(out1, tmp1, comm_fact0);
+ OUTPUT_ROW_4x4_5x5(out2, tmp2, comm_fact0);
+ OUTPUT_ROW_4x4_5x5(out3, tmp3, comm_fact0);
+ OUTPUT_ROW_4x4_5x5(out4, tmp4, comm_fact0);
+ OUTPUT_ROW_4x4_5x5(out5, tmp5, comm_fact0);
+ OUTPUT_ROW_4x4_5x5(out6, tmp6, comm_fact0);
+ OUTPUT_ROW_4x4_5x5(out7, tmp7, comm_fact0);
+
+ // Store values across the 64 channels
+ __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + x * sizeof(float) + (y + z * (int)NUM_TILES_X) * dst_stride_y;
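+ // (y + z * NUM_TILES_X) is the linear tile index; each of the 64 transformed values is written to a separate Z-plane of the destination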
+
+ *((__global float *)(dst_addr + 0 * dst_stride_z)) = out0.s0;
+ *((__global float *)(dst_addr + 1 * dst_stride_z)) = out0.s1;
+ *((__global float *)(dst_addr + 2 * dst_stride_z)) = out0.s2;
+ *((__global float *)(dst_addr + 3 * dst_stride_z)) = out0.s3;
+ *((__global float *)(dst_addr + 4 * dst_stride_z)) = out0.s4;
+ *((__global float *)(dst_addr + 5 * dst_stride_z)) = out0.s5;
+ *((__global float *)(dst_addr + 6 * dst_stride_z)) = out0.s6;
+ *((__global float *)(dst_addr + 7 * dst_stride_z)) = out0.s7;
+ *((__global float *)(dst_addr + 8 * dst_stride_z)) = out1.s0;
+ *((__global float *)(dst_addr + 9 * dst_stride_z)) = out1.s1;
+ *((__global float *)(dst_addr + 10 * dst_stride_z)) = out1.s2;
+ *((__global float *)(dst_addr + 11 * dst_stride_z)) = out1.s3;
+ *((__global float *)(dst_addr + 12 * dst_stride_z)) = out1.s4;
+ *((__global float *)(dst_addr + 13 * dst_stride_z)) = out1.s5;
+ *((__global float *)(dst_addr + 14 * dst_stride_z)) = out1.s6;
+ *((__global float *)(dst_addr + 15 * dst_stride_z)) = out1.s7;
+ *((__global float *)(dst_addr + 16 * dst_stride_z)) = out2.s0;
+ *((__global float *)(dst_addr + 17 * dst_stride_z)) = out2.s1;
+ *((__global float *)(dst_addr + 18 * dst_stride_z)) = out2.s2;
+ *((__global float *)(dst_addr + 19 * dst_stride_z)) = out2.s3;
+ *((__global float *)(dst_addr + 20 * dst_stride_z)) = out2.s4;
+ *((__global float *)(dst_addr + 21 * dst_stride_z)) = out2.s5;
+ *((__global float *)(dst_addr + 22 * dst_stride_z)) = out2.s6;
+ *((__global float *)(dst_addr + 23 * dst_stride_z)) = out2.s7;
+ *((__global float *)(dst_addr + 24 * dst_stride_z)) = out3.s0;
+ *((__global float *)(dst_addr + 25 * dst_stride_z)) = out3.s1;
+ *((__global float *)(dst_addr + 26 * dst_stride_z)) = out3.s2;
+ *((__global float *)(dst_addr + 27 * dst_stride_z)) = out3.s3;
+ *((__global float *)(dst_addr + 28 * dst_stride_z)) = out3.s4;
+ *((__global float *)(dst_addr + 29 * dst_stride_z)) = out3.s5;
+ *((__global float *)(dst_addr + 30 * dst_stride_z)) = out3.s6;
+ *((__global float *)(dst_addr + 31 * dst_stride_z)) = out3.s7;
+ *((__global float *)(dst_addr + 32 * dst_stride_z)) = out4.s0;
+ *((__global float *)(dst_addr + 33 * dst_stride_z)) = out4.s1;
+ *((__global float *)(dst_addr + 34 * dst_stride_z)) = out4.s2;
+ *((__global float *)(dst_addr + 35 * dst_stride_z)) = out4.s3;
+ *((__global float *)(dst_addr + 36 * dst_stride_z)) = out4.s4;
+ *((__global float *)(dst_addr + 37 * dst_stride_z)) = out4.s5;
+ *((__global float *)(dst_addr + 38 * dst_stride_z)) = out4.s6;
+ *((__global float *)(dst_addr + 39 * dst_stride_z)) = out4.s7;
+ *((__global float *)(dst_addr + 40 * dst_stride_z)) = out5.s0;
+ *((__global float *)(dst_addr + 41 * dst_stride_z)) = out5.s1;
+ *((__global float *)(dst_addr + 42 * dst_stride_z)) = out5.s2;
+ *((__global float *)(dst_addr + 43 * dst_stride_z)) = out5.s3;
+ *((__global float *)(dst_addr + 44 * dst_stride_z)) = out5.s4;
+ *((__global float *)(dst_addr + 45 * dst_stride_z)) = out5.s5;
+ *((__global float *)(dst_addr + 46 * dst_stride_z)) = out5.s6;
+ *((__global float *)(dst_addr + 47 * dst_stride_z)) = out5.s7;
+ *((__global float *)(dst_addr + 48 * dst_stride_z)) = out6.s0;
+ *((__global float *)(dst_addr + 49 * dst_stride_z)) = out6.s1;
+ *((__global float *)(dst_addr + 50 * dst_stride_z)) = out6.s2;
+ *((__global float *)(dst_addr + 51 * dst_stride_z)) = out6.s3;
+ *((__global float *)(dst_addr + 52 * dst_stride_z)) = out6.s4;
+ *((__global float *)(dst_addr + 53 * dst_stride_z)) = out6.s5;
+ *((__global float *)(dst_addr + 54 * dst_stride_z)) = out6.s6;
+ *((__global float *)(dst_addr + 55 * dst_stride_z)) = out6.s7;
+ *((__global float *)(dst_addr + 56 * dst_stride_z)) = out7.s0;
+ *((__global float *)(dst_addr + 57 * dst_stride_z)) = out7.s1;
+ *((__global float *)(dst_addr + 58 * dst_stride_z)) = out7.s2;
+ *((__global float *)(dst_addr + 59 * dst_stride_z)) = out7.s3;
+ *((__global float *)(dst_addr + 60 * dst_stride_z)) = out7.s4;
+ *((__global float *)(dst_addr + 61 * dst_stride_z)) = out7.s5;
+ *((__global float *)(dst_addr + 62 * dst_stride_z)) = out7.s6;
+ *((__global float *)(dst_addr + 63 * dst_stride_z)) = out7.s7;
+}
+#endif // defined(SRC_DIM_1) && defined(SRC_DIM_2)
#endif // defined(NUM_TILES_X) && defined(PAD_LEFT) && defined(PAD_TOP)
#if defined(NUM_TILES_X)
diff --git a/src/core/CL/kernels/CLWinogradInputTransformKernel.cpp b/src/core/CL/kernels/CLWinogradInputTransformKernel.cpp
index e73ac7df76..274c9e7c3d 100644
--- a/src/core/CL/kernels/CLWinogradInputTransformKernel.cpp
+++ b/src/core/CL/kernels/CLWinogradInputTransformKernel.cpp
@@ -46,7 +46,7 @@ Status validate_arguments(const ITensorInfo *input, const ITensorInfo *output, c
const Size2D kernel_size = winograd_info.kernel_size;
ARM_COMPUTE_RETURN_ERROR_ON_MSG(conv_info.stride().first != 1 || conv_info.stride().second != 1, "Winograd input transform only supports unit strides");
ARM_COMPUTE_RETURN_ERROR_ON_MSG(kernel_size != Size2D(3U, 3U) && kernel_size != Size2D(5U, 5U), "Winograd input transform only supports 3x3 and 5x5 kernels");
- ARM_COMPUTE_RETURN_ERROR_ON(input->data_layout() == DataLayout::NHWC && (output_tile_size != Size2D(4U, 4U) || kernel_size != Size2D(3U, 3U)));
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(input->data_layout() == DataLayout::NHWC && output_tile_size != Size2D(4U, 4U), "Winograd input transform only supports 4x4 output tile for NHWC data layout");
ARM_COMPUTE_RETURN_ERROR_ON_MSG(kernel_size == Size2D(3U, 3U) && output_tile_size != Size2D(2U, 2U)
&& output_tile_size != Size2D(4U, 4U),
"Winograd input transform only supports 2x2 or 4x4 output tile for 3x3 kernels");
diff --git a/tests/validation/CL/Winograd.cpp b/tests/validation/CL/Winograd.cpp
index 7f866cd11a..b869f4c314 100644
--- a/tests/validation/CL/Winograd.cpp
+++ b/tests/validation/CL/Winograd.cpp
@@ -129,7 +129,7 @@ DATA_TEST_CASE(Configuration, framework::DatasetMode::ALL, combine(combine(frame
FIXTURE_DATA_TEST_CASE(RunSmall, CLWinogradInputTransformFixture, framework::DatasetMode::PRECOMMIT, combine(framework::dataset::concat(combine(SmallWinogradInputTransformDataset,
framework::dataset::make("DataLayout", { DataLayout::NCHW })),
- combine(datasets::SmallWinogradInputTransformDataset4x4_3x3(),
+ combine(framework::dataset::concat(datasets::SmallWinogradInputTransformDataset4x4_3x3(), datasets::SmallWinogradInputTransformDataset4x4_5x5()),
framework::dataset::make("DataLayout", { DataLayout::NHWC }))),
framework::dataset::make("DataType", { DataType::F32 })))
{
@@ -138,7 +138,7 @@ FIXTURE_DATA_TEST_CASE(RunSmall, CLWinogradInputTransformFixture, framework::Dat
FIXTURE_DATA_TEST_CASE(RunLarge, CLWinogradInputTransformFixture, framework::DatasetMode::NIGHTLY, combine(framework::dataset::concat(combine(LargeWinogradInputTransformDataset,
framework::dataset::make("DataLayout", { DataLayout::NCHW })),
- combine(datasets::LargeWinogradInputTransformDataset4x4_3x3(),
+ combine(framework::dataset::concat(datasets::LargeWinogradInputTransformDataset4x4_3x3(), datasets::LargeWinogradInputTransformDataset4x4_5x5()),
framework::dataset::make("DataLayout", { DataLayout::NHWC }))),
framework::dataset::make("DataType", { DataType::F32 })))
{