aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorGiorgio Arena <giorgio.arena@arm.com>2018-06-08 16:30:00 +0100
committerAnthony Barbier <anthony.barbier@arm.com>2018-11-02 16:53:09 +0000
commit80d65d8f27f3ade2b461517c4fa29938c37590ed (patch)
tree851826cbe6470c52034cbc327d2160a0ccb685a6
parent0fc25454b6ce499b7f89792f91b81a61a42d3182 (diff)
downloadComputeLibrary-80d65d8f27f3ade2b461517c4fa29938c37590ed.tar.gz
COMPMID-1204 Add NHWC data format support to Winograd filter transform 4x4_5x5
Change-Id: I09adb8493fd2c438871c3d734cadf4b950c24d25 Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/134822 Reviewed-by: Gian Marco Iodice <gianmarco.iodice@arm.com> Tested-by: Jenkins <bsgcomp@arm.com>
-rw-r--r--arm_compute/core/CL/kernels/CLWinogradFilterTransformKernel.h4
-rw-r--r--src/core/CL/CLKernelLibrary.cpp1
-rw-r--r--src/core/CL/cl_kernels/winograd.cl253
-rw-r--r--src/core/CL/kernels/CLWinogradFilterTransformKernel.cpp2
-rw-r--r--tests/validation/CL/Winograd.cpp4
5 files changed, 259 insertions, 5 deletions
diff --git a/arm_compute/core/CL/kernels/CLWinogradFilterTransformKernel.h b/arm_compute/core/CL/kernels/CLWinogradFilterTransformKernel.h
index 7360646019..5e3d815d8c 100644
--- a/arm_compute/core/CL/kernels/CLWinogradFilterTransformKernel.h
+++ b/arm_compute/core/CL/kernels/CLWinogradFilterTransformKernel.h
@@ -51,7 +51,7 @@ public:
* @note Winograd filter transform supports the following configurations:
* F(output tile, kernel size):F(2x2, 3x3), F(4x4, 3x3), F(4x4, 5x5)
* Strides: only unit strides
- * Data Layout: NCHW for all configurations, NHWC for F(4x4, 3x3)
+ * Data Layout: NCHW for all configurations, NHWC for F(4x4, 3x3) and F(4x4, 5x5)
*
* @param[in] input Source tensor. The input is a 4D tensor with dimensions [kernel_x, kernel_y, IFM, OFM] (NCHW data layout) or [IFM, kernel_x, kernel_y, OFM] (NHWC data layout). Data types supported: F32.
* @param[out] output The output tensor. The shape for this tensor can be calculated using the utility function @p compute_winograd_filter_transform_shape. Data types supported: Same as @p input
@@ -63,7 +63,7 @@ public:
* @note Winograd filter transform supports the following configurations:
* F(output tile, kernel size):F(2x2, 3x3), F(4x4, 3x3), F(4x4, 5x5)
* Strides: only unit strides
- * Data Layout: NCHW for all configurations, NHWC for F(4x4, 3x3)
+ * Data Layout: NCHW for all configurations, NHWC for F(4x4, 3x3) and F(4x4, 5x5)
*
* @param[in] input Source tensor. The input is a 4D tensor with dimensions [kernel_x, kernel_y, IFM, OFM] (NCHW data layout) or [IFM, kernel_x, kernel_y, OFM] (NHWC data layout). Data types supported: F32.
* @param[out] output The output tensor. The shape for this tensor can be calculated using the utility function @p compute_winograd_filter_transform_shape. Data types supported: Same as @p input
diff --git a/src/core/CL/CLKernelLibrary.cpp b/src/core/CL/CLKernelLibrary.cpp
index 9139048142..4a37b8ae03 100644
--- a/src/core/CL/CLKernelLibrary.cpp
+++ b/src/core/CL/CLKernelLibrary.cpp
@@ -373,6 +373,7 @@ const std::map<std::string, std::string> CLKernelLibrary::_kernel_program_map =
{ "winograd_filter_transform_4x4_3x3_nchw", "winograd.cl" },
{ "winograd_filter_transform_4x4_5x5_nchw", "winograd.cl" },
{ "winograd_filter_transform_4x4_3x3_nhwc", "winograd.cl" },
+ { "winograd_filter_transform_4x4_5x5_nhwc", "winograd.cl" },
{ "winograd_input_transform_4x4_5x5_stepz1_nchw", "winograd.cl" },
{ "winograd_input_transform_2x2_3x3_stepz1_nchw", "winograd.cl" },
{ "winograd_input_transform_2x2_3x3_stepz2_nchw", "winograd.cl" },
diff --git a/src/core/CL/cl_kernels/winograd.cl b/src/core/CL/cl_kernels/winograd.cl
index 485b0a5411..ea499a83f0 100644
--- a/src/core/CL/cl_kernels/winograd.cl
+++ b/src/core/CL/cl_kernels/winograd.cl
@@ -673,6 +673,259 @@ __kernel void winograd_filter_transform_4x4_5x5_nchw(
*(__global float *)(dst_addr + 62 * dst_stride_z) = out7.s6;
*(__global float *)(dst_addr + 63 * dst_stride_z) = out7.s7;
}
+
+/** This OpenCL kernel performs Winograd filter transform 5x5 when the data layout is NHWC and the output tile is 4x4
+ *
+ * @note In order to correctly split the input tensor in batches, its dimension across the Z axis (channels for NCHW, height for NHWC) must be passed at compile time using -DSRC_DIM_Z: e.g. -DSRC_DIM_Z=64
+ *
+ * @param[in] src_ptr Pointer to the source tensor. Supported data types: F32
+ * @param[in] src_stride_x Stride of the source tensor in X dimension (in bytes)
+ * @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
+ * @param[in] src_stride_y Stride of the source tensor in Y dimension (in bytes)
+ * @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
+ * @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes)
+ * @param[in] src_step_z src_stride_z * number of elements along Z processed per workitem(in bytes)
+ * @param[in] src_stride_w Stride of the source tensor in W dimension (in bytes)
+ * @param[in] src_step_w src_stride_w * number of elements along W processed per workitem(in bytes)
+ * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source tensor
+ * @param[out] dst_ptr Pointer to the destination tensor. Supported data types: same as @p src_ptr
+ * @param[in] dst_stride_x Stride of the destination tensor in X dimension (in bytes)
+ * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
+ * @param[in] dst_stride_y Stride of the destination tensor in Y dimension (in bytes)
+ * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
+ * @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes)
+ * @param[in] src_step_z src_stride_z * number of elements along Z processed per workitem(in bytes)
+ * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
+ */
+__kernel void winograd_filter_transform_4x4_5x5_nhwc(
+ TENSOR4D_DECLARATION(src),
+ TENSOR3D_DECLARATION(dst))
+{
+ Tensor4D src = CONVERT_TO_TENSOR4D_STRUCT(src, SRC_DIM_Z);
+
+ const __global uchar *src_addr = src_ptr + src_offset_first_element_in_bytes + get_global_id(0) * sizeof(float) + get_global_id(1) * src_step_y + get_global_id(2) * src_step_w;
+
+ // Load the values from the input tensor
+ float w00 = *((__global float *)(src_addr + 0 * src_stride_z + 0 * src_stride_y));
+ float w01 = *((__global float *)(src_addr + 0 * src_stride_z + 1 * src_stride_y));
+ float w02 = *((__global float *)(src_addr + 0 * src_stride_z + 2 * src_stride_y));
+ float w03 = *((__global float *)(src_addr + 0 * src_stride_z + 3 * src_stride_y));
+ float w04 = *((__global float *)(src_addr + 0 * src_stride_z + 4 * src_stride_y));
+ float w10 = *((__global float *)(src_addr + 1 * src_stride_z + 0 * src_stride_y));
+ float w11 = *((__global float *)(src_addr + 1 * src_stride_z + 1 * src_stride_y));
+ float w12 = *((__global float *)(src_addr + 1 * src_stride_z + 2 * src_stride_y));
+ float w13 = *((__global float *)(src_addr + 1 * src_stride_z + 3 * src_stride_y));
+ float w14 = *((__global float *)(src_addr + 1 * src_stride_z + 4 * src_stride_y));
+ float w20 = *((__global float *)(src_addr + 2 * src_stride_z + 0 * src_stride_y));
+ float w21 = *((__global float *)(src_addr + 2 * src_stride_z + 1 * src_stride_y));
+ float w22 = *((__global float *)(src_addr + 2 * src_stride_z + 2 * src_stride_y));
+ float w23 = *((__global float *)(src_addr + 2 * src_stride_z + 3 * src_stride_y));
+ float w24 = *((__global float *)(src_addr + 2 * src_stride_z + 4 * src_stride_y));
+ float w30 = *((__global float *)(src_addr + 3 * src_stride_z + 0 * src_stride_y));
+ float w31 = *((__global float *)(src_addr + 3 * src_stride_z + 1 * src_stride_y));
+ float w32 = *((__global float *)(src_addr + 3 * src_stride_z + 2 * src_stride_y));
+ float w33 = *((__global float *)(src_addr + 3 * src_stride_z + 3 * src_stride_y));
+ float w34 = *((__global float *)(src_addr + 3 * src_stride_z + 4 * src_stride_y));
+ float w40 = *((__global float *)(src_addr + 4 * src_stride_z + 0 * src_stride_y));
+ float w41 = *((__global float *)(src_addr + 4 * src_stride_z + 1 * src_stride_y));
+ float w42 = *((__global float *)(src_addr + 4 * src_stride_z + 2 * src_stride_y));
+ float w43 = *((__global float *)(src_addr + 4 * src_stride_z + 3 * src_stride_y));
+ float w44 = *((__global float *)(src_addr + 4 * src_stride_z + 4 * src_stride_y));
+
+ // Transform the 3x3 tile in a 8x8 tile
+ float8 out0 = 0.0f;
+ float8 out1 = 0.0f;
+ float8 out2 = 0.0f;
+ float8 out3 = 0.0f;
+ float8 out4 = 0.0f;
+ float8 out5 = 0.0f;
+ float8 out6 = 0.0f;
+ float8 out7 = 0.0f;
+
+ // Row 0
+ out0.s0 = w00;
+ out0.s1 = -2.f * (w00 + w01 + w02 + w03 + w04) / 9.f;
+ out0.s2 = -2.f * (w00 - w01 + w02 - w03 + w04) / 9.f;
+ out0.s3 = (w00 + 2.f * w01 + 4.f * w02 + 8.f * w03 + 16.f * w04) / 90.f;
+ out0.s4 = (w00 - 2.f * w01 + 4.f * w02 - 8.f * w03 + 16.f * w04) / 90.f;
+ out0.s5 = (16.f * w00 + 8.f * w01 + 4.f * w02 + 2.f * w03 + w04) / 180.f;
+ out0.s6 = (16.f * w00 - 8.f * w01 + 4.f * w02 - 2.f * w03 + w04) / 180.f;
+ out0.s7 = w04;
+
+ // Row 1
+ out1.s0 = -2.f * (w00 + w10 + w20 + w30 + w40) / 9.f;
+ out1.s1 = 4.f * ((w00 + w10 + w20 + w30 + w40) + (w01 + w11 + w21 + w31 + w41) + (w02 + w12 + w22 + w32 + w42) + (w03 + w13 + w23 + w33 + w43) + (w04 + w14 + w24 + w34 + w44)) / 81.f;
+ out1.s2 = 4.f * ((w00 + w10 + w20 + w30 + w40) - (w01 + w11 + w21 + w31 + w41) + (w02 + w12 + w22 + w32 + w42) - (w03 + w13 + w23 + w33 + w43) + (w04 + w14 + w24 + w34 + w44)) / 81.f;
+ out1.s3 = -((w00 + w10 + w20 + w30 + w40) + 2.f * (w01 + w11 + w21 + w31 + w41) + 4.f * (w02 + w12 + w22 + w32 + w42) + 8.f * (w03 + w13 + w23 + w33 + w43) + 16.f *
+ (w04 + w14 + w24 + w34 + w44)) / 405.f;
+ out1.s4 = -((w00 + w10 + w20 + w30 + w40) - 2.f * (w01 + w11 + w21 + w31 + w41) + 4.f * (w02 + w12 + w22 + w32 + w42) - 8.f * (w03 + w13 + w23 + w33 + w43) + 16.f *
+ (w04 + w14 + w24 + w34 + w44)) / 405.f;
+ out1.s5 = -(16.f * (w00 + w10 + w20 + w30 + w40) + 8.f * (w01 + w11 + w21 + w31 + w41) + 4.f * (w02 + w12 + w22 + w32 + w42) + 2.f * (w03 + w13 + w23 + w33 + w43) +
+ (w04 + w14 + w24 + w34 + w44)) / 810.f;
+ out1.s6 = -(16.f * (w00 + w10 + w20 + w30 + w40) - 8.f * (w01 + w11 + w21 + w31 + w41) + 4.f * (w02 + w12 + w22 + w32 + w42) - 2.f * (w03 + w13 + w23 + w33 + w43) +
+ (w04 + w14 + w24 + w34 + w44)) / 810.f;
+ out1.s7 = -2.f * (w04 + w14 + w24 + w34 + w44) / 9.f;
+
+ // Row 2
+ out2.s0 = -2.f * (w00 - w10 + w20 - w30 + w40) / 9.f;
+ out2.s1 = 4.f * ((w00 - w10 + w20 - w30 + w40) + (w01 - w11 + w21 - w31 + w41) + (w02 - w12 + w22 - w32 + w42) + (w03 - w13 + w23 - w33 + w43) + (w04 - w14 + w24 - w34 + w44)) / 81.f;
+ out2.s2 = 4.f * ((w00 - w10 + w20 - w30 + w40) - (w01 - w11 + w21 - w31 + w41) + (w02 - w12 + w22 - w32 + w42) - (w03 - w13 + w23 - w33 + w43) + (w04 - w14 + w24 - w34 + w44)) / 81.f;
+ out2.s3 = -((w00 - w10 + w20 - w30 + w40) + 2.f * (w01 - w11 + w21 - w31 + w41) + 4.f * (w02 - w12 + w22 - w32 + w42) + 8.f * (w03 - w13 + w23 - w33 + w43) + 16.f *
+ (w04 - w14 + w24 - w34 + w44)) / 405.f;
+ out2.s4 = -((w00 - w10 + w20 - w30 + w40) - 2.f * (w01 - w11 + w21 - w31 + w41) + 4.f * (w02 - w12 + w22 - w32 + w42) - 8.f * (w03 - w13 + w23 - w33 + w43) + 16.f *
+ (w04 - w14 + w24 - w34 + w44)) / 405.f;
+ out2.s5 = -(16.f * (w00 - w10 + w20 - w30 + w40) + 8.f * (w01 - w11 + w21 - w31 + w41) + 4.f * (w02 - w12 + w22 - w32 + w42) + 2.f * (w03 - w13 + w23 - w33 + w43) +
+ (w04 - w14 + w24 - w34 + w44)) / 810.f;
+ out2.s6 = -(16.f * (w00 - w10 + w20 - w30 + w40) - 8.f * (w01 - w11 + w21 - w31 + w41) + 4.f * (w02 - w12 + w22 - w32 + w42) - 2.f * (w03 - w13 + w23 - w33 + w43) +
+ (w04 - w14 + w24 - w34 + w44)) / 810.f;
+ out2.s7 = -2.f * (w04 - w14 + w24 - w34 + w44) / 9.f;
+
+ // Row 3
+ out3.s0 = (w00 + 2.f * w10 + 4.f * w20 + 8.f * w30 + 16.f * w40) / 90.f;
+ out3.s1 = -((w00 + 2.f * w10 + 4.f * w20 + 8.f * w30 + 16.f * w40) + (w01 + 2.f * w11 + 4.f * w21 + 8.f * w31 + 16.f * w41) + (w02 + 2.f * w12 + 4.f * w22 + 8.f * w32 + 16.f * w42) +
+ (w03 + 2.f * w13 + 4.f * w23 + 8.f * w33 + 16.f * w43) + (w04 + 2.f * w14 + 4.f * w24 + 8.f * w34 + 16.f * w44)) / 405.f;
+ out3.s2 = -((w00 + 2.f * w10 + 4.f * w20 + 8.f * w30 + 16.f * w40) - (w01 + 2.f * w11 + 4.f * w21 + 8.f * w31 + 16.f * w41) + (w02 + 2.f * w12 + 4.f * w22 + 8.f * w32 + 16.f * w42) -
+ (w03 + 2.f * w13 + 4.f * w23 + 8.f * w33 + 16.f * w43) + (w04 + 2.f * w14 + 4.f * w24 + 8.f * w34 + 16.f * w44)) / 405.f;
+ out3.s3 = ((w00 + 2.f * w10 + 4.f * w20 + 8.f * w30 + 16.f * w40) + 2.f * (w01 + 2.f * w11 + 4.f * w21 + 8.f * w31 + 16.f * w41) + 4.f * (w02 + 2.f * w12 + 4.f * w22 + 8.f * w32 + 16.f * w42) + 8.f
+ * (w03 + 2.f * w13 + 4.f * w23 + 8.f * w33 + 16.f * w43) + 16.f * (w04 + 2.f * w14 + 4.f * w24 + 8.f * w34 + 16.f * w44)) / 8100.f;
+ out3.s4 = ((w00 + 2.f * w10 + 4.f * w20 + 8.f * w30 + 16.f * w40) - 2.f * (w01 + 2.f * w11 + 4.f * w21 + 8.f * w31 + 16.f * w41) + 4.f * (w02 + 2.f * w12 + 4.f * w22 + 8.f * w32 + 16.f * w42) - 8.f
+ * (w03 + 2.f * w13 + 4.f * w23 + 8.f * w33 + 16.f * w43) + 16.f * (w04 + 2.f * w14 + 4.f * w24 + 8.f * w34 + 16.f * w44)) / 8100.f;
+ out3.s5 = (16.f * (w00 + 2.f * w10 + 4.f * w20 + 8.f * w30 + 16.f * w40) + 8.f * (w01 + 2.f * w11 + 4.f * w21 + 8.f * w31 + 16.f * w41) + 4.f *
+ (w02 + 2.f * w12 + 4.f * w22 + 8.f * w32 + 16.f * w42) + 2.f * (w03 + 2.f * w13 + 4.f * w23 + 8.f * w33 + 16.f * w43) + (w04 + 2.f * w14 + 4.f * w24 + 8.f * w34 + 16.f * w44)) / 16200.f;
+ out3.s6 = (16.f * (w00 + 2.f * w10 + 4.f * w20 + 8.f * w30 + 16.f * w40) - 8.f * (w01 + 2.f * w11 + 4.f * w21 + 8.f * w31 + 16.f * w41) + 4.f *
+ (w02 + 2.f * w12 + 4.f * w22 + 8.f * w32 + 16.f * w42) - 2.f * (w03 + 2.f * w13 + 4.f * w23 + 8.f * w33 + 16.f * w43) + (w04 + 2.f * w14 + 4.f * w24 + 8.f * w34 + 16.f * w44)) / 16200.f;
+ out3.s7 = (w04 + 2.f * w14 + 4.f * w24 + 8.f * w34 + 16.f * w44) / 90.f;
+
+ // Row 4
+ out4.s0 = (w00 - 2.f * w10 + 4.f * w20 - 8.f * w30 + 16.f * w40) / 90.f;
+ out4.s1 = -((w00 - 2.f * w10 + 4.f * w20 - 8.f * w30 + 16.f * w40) + (w01 - 2.f * w11 + 4.f * w21 - 8.f * w31 + 16.f * w41) + (w02 - 2.f * w12 + 4.f * w22 - 8.f * w32 + 16.f * w42) +
+ (w03 - 2.f * w13 + 4.f * w23 - 8.f * w33 + 16.f * w43) + (w04 - 2.f * w14 + 4.f * w24 - 8.f * w34 + 16.f * w44)) / 405.f;
+ out4.s2 = -((w00 - 2.f * w10 + 4.f * w20 - 8.f * w30 + 16.f * w40) - (w01 - 2.f * w11 + 4.f * w21 - 8.f * w31 + 16.f * w41) + (w02 - 2.f * w12 + 4.f * w22 - 8.f * w32 + 16.f * w42) -
+ (w03 - 2.f * w13 + 4.f * w23 - 8.f * w33 + 16.f * w43) + (w04 - 2.f * w14 + 4.f * w24 - 8.f * w34 + 16.f * w44)) / 405.f;
+ out4.s3 = ((w00 - 2.f * w10 + 4.f * w20 - 8.f * w30 + 16.f * w40) + 2.f * (w01 - 2.f * w11 + 4.f * w21 - 8.f * w31 + 16.f * w41) + 4.f * (w02 - 2.f * w12 + 4.f * w22 - 8.f * w32 + 16.f * w42) + 8.f
+ * (w03 - 2.f * w13 + 4.f * w23 - 8.f * w33 + 16.f * w43) + 16.f * (w04 - 2.f * w14 + 4.f * w24 - 8.f * w34 + 16.f * w44)) / 8100.f;
+ out4.s4 = ((w00 - 2.f * w10 + 4.f * w20 - 8.f * w30 + 16.f * w40) - 2.f * (w01 - 2.f * w11 + 4.f * w21 - 8.f * w31 + 16.f * w41) + 4.f * (w02 - 2.f * w12 + 4.f * w22 - 8.f * w32 + 16.f * w42) - 8.f
+ * (w03 - 2.f * w13 + 4.f * w23 - 8.f * w33 + 16.f * w43) + 16.f * (w04 - 2.f * w14 + 4.f * w24 - 8.f * w34 + 16.f * w44)) / 8100.f;
+ out4.s5 = (16.f * (w00 - 2.f * w10 + 4.f * w20 - 8.f * w30 + 16.f * w40) + 8.f * (w01 - 2.f * w11 + 4.f * w21 - 8.f * w31 + 16.f * w41) + 4.f *
+ (w02 - 2.f * w12 + 4.f * w22 - 8.f * w32 + 16.f * w42) + 2.f * (w03 - 2.f * w13 + 4.f * w23 - 8.f * w33 + 16.f * w43) + (w04 - 2.f * w14 + 4.f * w24 - 8.f * w34 + 16.f * w44)) / 16200.f;
+ out4.s6 = (16.f * (w00 - 2.f * w10 + 4.f * w20 - 8.f * w30 + 16.f * w40) - 8.f * (w01 - 2.f * w11 + 4.f * w21 - 8.f * w31 + 16.f * w41) + 4.f *
+ (w02 - 2.f * w12 + 4.f * w22 - 8.f * w32 + 16.f * w42) - 2.f * (w03 - 2.f * w13 + 4.f * w23 - 8.f * w33 + 16.f * w43) + (w04 - 2.f * w14 + 4.f * w24 - 8.f * w34 + 16.f * w44)) / 16200.f;
+ out4.s7 = (w04 - 2.f * w14 + 4.f * w24 - 8.f * w34 + 16.f * w44) / 90.f;
+
+ // Row 5
+ out5.s0 = (16.f * w00 + 8.f * w10 + 4.f * w20 + 2.f * w30 + w40) / 180.f;
+ out5.s1 = -((16.f * w00 + 8.f * w10 + 4.f * w20 + 2.f * w30 + w40) + (16.f * w01 + 8.f * w11 + 4.f * w21 + 2.f * w31 + w41) + (16.f * w02 + 8.f * w12 + 4.f * w22 + 2.f * w32 + w42) +
+ (16.f * w03 + 8.f * w13 + 4.f * w23 + 2.f * w33 + w43) + (16.f * w04 + 8.f * w14 + 4.f * w24 + 2.f * w34 + w44)) / 810.f;
+ out5.s2 = -((16.f * w00 + 8.f * w10 + 4.f * w20 + 2.f * w30 + w40) - (16.f * w01 + 8.f * w11 + 4.f * w21 + 2.f * w31 + w41) + (16.f * w02 + 8.f * w12 + 4.f * w22 + 2.f * w32 + w42) -
+ (16.f * w03 + 8.f * w13 + 4.f * w23 + 2.f * w33 + w43) + (16.f * w04 + 8.f * w14 + 4.f * w24 + 2.f * w34 + w44)) / 810.f;
+ out5.s3 = ((16.f * w00 + 8.f * w10 + 4.f * w20 + 2.f * w30 + w40) + 2.f * (16.f * w01 + 8.f * w11 + 4.f * w21 + 2.f * w31 + w41) + 4.f * (16.f * w02 + 8.f * w12 + 4.f * w22 + 2.f * w32 + w42) + 8.f
+ * (16.f * w03 + 8.f * w13 + 4.f * w23 + 2.f * w33 + w43) + 16.f * (16.f * w04 + 8.f * w14 + 4.f * w24 + 2.f * w34 + w44)) / 16200.f;
+ out5.s4 = ((16.f * w00 + 8.f * w10 + 4.f * w20 + 2.f * w30 + w40) - 2.f * (16.f * w01 + 8.f * w11 + 4.f * w21 + 2.f * w31 + w41) + 4.f * (16.f * w02 + 8.f * w12 + 4.f * w22 + 2.f * w32 + w42) - 8.f
+ * (16.f * w03 + 8.f * w13 + 4.f * w23 + 2.f * w33 + w43) + 16.f * (16.f * w04 + 8.f * w14 + 4.f * w24 + 2.f * w34 + w44)) / 16200.f;
+ out5.s5 = (16.f * (16.f * w00 + 8.f * w10 + 4.f * w20 + 2.f * w30 + w40) + 8.f * (16.f * w01 + 8.f * w11 + 4.f * w21 + 2.f * w31 + w41) + 4.f *
+ (16.f * w02 + 8.f * w12 + 4.f * w22 + 2.f * w32 + w42) + 2.f * (16.f * w03 + 8.f * w13 + 4.f * w23 + 2.f * w33 + w43) + (16.f * w04 + 8.f * w14 + 4.f * w24 + 2.f * w34 + w44)) / 32400.f;
+ out5.s6 = (16.f * (16.f * w00 + 8.f * w10 + 4.f * w20 + 2.f * w30 + w40) - 8.f * (16.f * w01 + 8.f * w11 + 4.f * w21 + 2.f * w31 + w41) + 4.f *
+ (16.f * w02 + 8.f * w12 + 4.f * w22 + 2.f * w32 + w42) - 2.f * (16.f * w03 + 8.f * w13 + 4.f * w23 + 2.f * w33 + w43) + (16.f * w04 + 8.f * w14 + 4.f * w24 + 2.f * w34 + w44)) / 32400.f;
+ out5.s7 = (16.f * w04 + 8.f * w14 + 4.f * w24 + 2.f * w34 + w44) / 180.f;
+
+ // Row 6
+ out6.s0 = (16.f * w00 - 8.f * w10 + 4.f * w20 - 2.f * w30 + w40) / 180.f;
+ out6.s1 = -((16.f * w00 - 8.f * w10 + 4.f * w20 - 2.f * w30 + w40) + (16.f * w01 - 8.f * w11 + 4.f * w21 - 2.f * w31 + w41) + (16.f * w02 - 8.f * w12 + 4.f * w22 - 2.f * w32 + w42) +
+ (16.f * w03 - 8.f * w13 + 4.f * w23 - 2.f * w33 + w43) + (16.f * w04 - 8.f * w14 + 4.f * w24 - 2.f * w34 + w44)) / 810.f;
+ out6.s2 = -((16.f * w00 - 8.f * w10 + 4.f * w20 - 2.f * w30 + w40) - (16.f * w01 - 8.f * w11 + 4.f * w21 - 2.f * w31 + w41) + (16.f * w02 - 8.f * w12 + 4.f * w22 - 2.f * w32 + w42) -
+ (16.f * w03 - 8.f * w13 + 4.f * w23 - 2.f * w33 + w43) + (16.f * w04 - 8.f * w14 + 4.f * w24 - 2.f * w34 + w44)) / 810.f;
+ out6.s3 = ((16.f * w00 - 8.f * w10 + 4.f * w20 - 2.f * w30 + w40) + 2.f * (16.f * w01 - 8.f * w11 + 4.f * w21 - 2.f * w31 + w41) + 4.f * (16.f * w02 - 8.f * w12 + 4.f * w22 - 2.f * w32 + w42) + 8.f
+ * (16.f * w03 - 8.f * w13 + 4.f * w23 - 2.f * w33 + w43) + 16.f * (16.f * w04 - 8.f * w14 + 4.f * w24 - 2.f * w34 + w44)) / 16200.f;
+ out6.s4 = ((16.f * w00 - 8.f * w10 + 4.f * w20 - 2.f * w30 + w40) - 2.f * (16.f * w01 - 8.f * w11 + 4.f * w21 - 2.f * w31 + w41) + 4.f * (16.f * w02 - 8.f * w12 + 4.f * w22 - 2.f * w32 + w42) - 8.f
+ * (16.f * w03 - 8.f * w13 + 4.f * w23 - 2.f * w33 + w43) + 16.f * (16.f * w04 - 8.f * w14 + 4.f * w24 - 2.f * w34 + w44)) / 16200.f;
+ out6.s5 = (16.f * (16.f * w00 - 8.f * w10 + 4.f * w20 - 2.f * w30 + w40) + 8.f * (16.f * w01 - 8.f * w11 + 4.f * w21 - 2.f * w31 + w41) + 4.f *
+ (16.f * w02 - 8.f * w12 + 4.f * w22 - 2.f * w32 + w42) + 2.f * (16.f * w03 - 8.f * w13 + 4.f * w23 - 2.f * w33 + w43) + (16.f * w04 - 8.f * w14 + 4.f * w24 - 2.f * w34 + w44)) / 32400.f;
+ out6.s6 = (16.f * (16.f * w00 - 8.f * w10 + 4.f * w20 - 2.f * w30 + w40) - 8.f * (16.f * w01 - 8.f * w11 + 4.f * w21 - 2.f * w31 + w41) + 4.f *
+ (16.f * w02 - 8.f * w12 + 4.f * w22 - 2.f * w32 + w42) - 2.f * (16.f * w03 - 8.f * w13 + 4.f * w23 - 2.f * w33 + w43) + (16.f * w04 - 8.f * w14 + 4.f * w24 - 2.f * w34 + w44)) / 32400.f;
+ out6.s7 = (16.f * w04 - 8.f * w14 + 4.f * w24 - 2.f * w34 + w44) / 180.f;
+
+ // Row 7
+ out7.s0 = w40;
+ out7.s1 = -2.f * (w40 + w41 + w42 + w43 + w44) / 9.f;
+ out7.s2 = -2.f * (w40 - w41 + w42 - w43 + w44) / 9.f;
+ out7.s3 = (w40 + 2.f * w41 + 4.f * w42 + 8.f * w43 + 16.f * w44) / 90.f;
+ out7.s4 = (w40 - 2.f * w41 + 4.f * w42 - 8.f * w43 + 16.f * w44) / 90.f;
+ out7.s5 = (16.f * w40 + 8.f * w41 + 4.f * w42 + 2.f * w43 + w44) / 180.f;
+ out7.s6 = (16.f * w40 - 8.f * w41 + 4.f * w42 - 2.f * w43 + w44) / 180.f;
+ out7.s7 = w44;
+
+ int x0 = get_global_id(2); // idx filter
+ int y0 = get_global_id(0); // idx channel
+
+ // Get output address
+ __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + x0 * sizeof(float) + y0 * dst_stride_y;
+
+ // Store the 64 values across the 64 channels
+ *(__global float *)(dst_addr + 0 * dst_stride_z) = out0.s0;
+ *(__global float *)(dst_addr + 1 * dst_stride_z) = out0.s1;
+ *(__global float *)(dst_addr + 2 * dst_stride_z) = out0.s2;
+ *(__global float *)(dst_addr + 3 * dst_stride_z) = out0.s3;
+ *(__global float *)(dst_addr + 4 * dst_stride_z) = out0.s4;
+ *(__global float *)(dst_addr + 5 * dst_stride_z) = out0.s5;
+ *(__global float *)(dst_addr + 6 * dst_stride_z) = out0.s6;
+ *(__global float *)(dst_addr + 7 * dst_stride_z) = out0.s7;
+ *(__global float *)(dst_addr + 8 * dst_stride_z) = out1.s0;
+ *(__global float *)(dst_addr + 9 * dst_stride_z) = out1.s1;
+ *(__global float *)(dst_addr + 10 * dst_stride_z) = out1.s2;
+ *(__global float *)(dst_addr + 11 * dst_stride_z) = out1.s3;
+ *(__global float *)(dst_addr + 12 * dst_stride_z) = out1.s4;
+ *(__global float *)(dst_addr + 13 * dst_stride_z) = out1.s5;
+ *(__global float *)(dst_addr + 14 * dst_stride_z) = out1.s6;
+ *(__global float *)(dst_addr + 15 * dst_stride_z) = out1.s7;
+ *(__global float *)(dst_addr + 16 * dst_stride_z) = out2.s0;
+ *(__global float *)(dst_addr + 17 * dst_stride_z) = out2.s1;
+ *(__global float *)(dst_addr + 18 * dst_stride_z) = out2.s2;
+ *(__global float *)(dst_addr + 19 * dst_stride_z) = out2.s3;
+ *(__global float *)(dst_addr + 20 * dst_stride_z) = out2.s4;
+ *(__global float *)(dst_addr + 21 * dst_stride_z) = out2.s5;
+ *(__global float *)(dst_addr + 22 * dst_stride_z) = out2.s6;
+ *(__global float *)(dst_addr + 23 * dst_stride_z) = out2.s7;
+ *(__global float *)(dst_addr + 24 * dst_stride_z) = out3.s0;
+ *(__global float *)(dst_addr + 25 * dst_stride_z) = out3.s1;
+ *(__global float *)(dst_addr + 26 * dst_stride_z) = out3.s2;
+ *(__global float *)(dst_addr + 27 * dst_stride_z) = out3.s3;
+ *(__global float *)(dst_addr + 28 * dst_stride_z) = out3.s4;
+ *(__global float *)(dst_addr + 29 * dst_stride_z) = out3.s5;
+ *(__global float *)(dst_addr + 30 * dst_stride_z) = out3.s6;
+ *(__global float *)(dst_addr + 31 * dst_stride_z) = out3.s7;
+ *(__global float *)(dst_addr + 32 * dst_stride_z) = out4.s0;
+ *(__global float *)(dst_addr + 33 * dst_stride_z) = out4.s1;
+ *(__global float *)(dst_addr + 34 * dst_stride_z) = out4.s2;
+ *(__global float *)(dst_addr + 35 * dst_stride_z) = out4.s3;
+ *(__global float *)(dst_addr + 36 * dst_stride_z) = out4.s4;
+ *(__global float *)(dst_addr + 37 * dst_stride_z) = out4.s5;
+ *(__global float *)(dst_addr + 38 * dst_stride_z) = out4.s6;
+ *(__global float *)(dst_addr + 39 * dst_stride_z) = out4.s7;
+ *(__global float *)(dst_addr + 40 * dst_stride_z) = out5.s0;
+ *(__global float *)(dst_addr + 41 * dst_stride_z) = out5.s1;
+ *(__global float *)(dst_addr + 42 * dst_stride_z) = out5.s2;
+ *(__global float *)(dst_addr + 43 * dst_stride_z) = out5.s3;
+ *(__global float *)(dst_addr + 44 * dst_stride_z) = out5.s4;
+ *(__global float *)(dst_addr + 45 * dst_stride_z) = out5.s5;
+ *(__global float *)(dst_addr + 46 * dst_stride_z) = out5.s6;
+ *(__global float *)(dst_addr + 47 * dst_stride_z) = out5.s7;
+ *(__global float *)(dst_addr + 48 * dst_stride_z) = out6.s0;
+ *(__global float *)(dst_addr + 49 * dst_stride_z) = out6.s1;
+ *(__global float *)(dst_addr + 50 * dst_stride_z) = out6.s2;
+ *(__global float *)(dst_addr + 51 * dst_stride_z) = out6.s3;
+ *(__global float *)(dst_addr + 52 * dst_stride_z) = out6.s4;
+ *(__global float *)(dst_addr + 53 * dst_stride_z) = out6.s5;
+ *(__global float *)(dst_addr + 54 * dst_stride_z) = out6.s6;
+ *(__global float *)(dst_addr + 55 * dst_stride_z) = out6.s7;
+ *(__global float *)(dst_addr + 56 * dst_stride_z) = out7.s0;
+ *(__global float *)(dst_addr + 57 * dst_stride_z) = out7.s1;
+ *(__global float *)(dst_addr + 58 * dst_stride_z) = out7.s2;
+ *(__global float *)(dst_addr + 59 * dst_stride_z) = out7.s3;
+ *(__global float *)(dst_addr + 60 * dst_stride_z) = out7.s4;
+ *(__global float *)(dst_addr + 61 * dst_stride_z) = out7.s5;
+ *(__global float *)(dst_addr + 62 * dst_stride_z) = out7.s6;
+ *(__global float *)(dst_addr + 63 * dst_stride_z) = out7.s7;
+}
#endif // defined(SRC_DIM_Z)
#if defined(NUM_TILES_X) && defined(PAD_LEFT) && defined(PAD_TOP)
diff --git a/src/core/CL/kernels/CLWinogradFilterTransformKernel.cpp b/src/core/CL/kernels/CLWinogradFilterTransformKernel.cpp
index cf4d73fbc1..779df637f6 100644
--- a/src/core/CL/kernels/CLWinogradFilterTransformKernel.cpp
+++ b/src/core/CL/kernels/CLWinogradFilterTransformKernel.cpp
@@ -55,7 +55,7 @@ Status validate_arguments(const ITensorInfo *input, const ITensorInfo *output, c
const size_t idx_h = get_data_layout_dimension_index(input->data_layout(), DataLayoutDimension::HEIGHT);
ARM_COMPUTE_RETURN_ERROR_ON_MSG(kernel_size != Size2D(3U, 3U) && kernel_size != Size2D(5U, 5U), "Winograd filter transform only supports 3x3 and 5x5 kernels");
- ARM_COMPUTE_RETURN_ERROR_ON(input->data_layout() == DataLayout::NHWC && (output_tile_size != Size2D(4U, 4U) || kernel_size != Size2D(3U, 3U)));
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(input->data_layout() == DataLayout::NHWC && output_tile_size != Size2D(4U, 4U), "Winograd filter transform only supports 4x4 output tile for NHWC data layout");
ARM_COMPUTE_RETURN_ERROR_ON_MSG(kernel_size == Size2D(3U, 3U) && output_tile_size != Size2D(2U, 2U)
&& output_tile_size != Size2D(4U, 4U),
"Winograd filter transform only supports 2x2 or 4x4 output tile for 3x3 kernels");
diff --git a/tests/validation/CL/Winograd.cpp b/tests/validation/CL/Winograd.cpp
index e0b4b5f795..7f866cd11a 100644
--- a/tests/validation/CL/Winograd.cpp
+++ b/tests/validation/CL/Winograd.cpp
@@ -216,7 +216,7 @@ FIXTURE_DATA_TEST_CASE(RunSmall, CLWinogradFilterTransformFixture, framework::Da
framework::dataset::make("OutputTile", Size2D(4U, 4U)))),
combine(datasets::Small5x5Shapes(), framework::dataset::make("OutputTile", Size2D(4U, 4U)))),
framework::dataset::make("DataLayout", { DataLayout::NCHW })),
- combine(combine(datasets::Small3x3Shapes(), framework::dataset::make("OutputTile", Size2D(4U, 4U))), framework::dataset::make("DataLayout", { DataLayout::NHWC }))),
+ combine(combine(framework::dataset::concat(datasets::Small3x3Shapes(), datasets::Small5x5Shapes()), framework::dataset::make("OutputTile", Size2D(4U, 4U))), framework::dataset::make("DataLayout", { DataLayout::NHWC }))),
framework::dataset::make("DataType", { DataType::F32 })))
{
// Validate output
@@ -229,7 +229,7 @@ FIXTURE_DATA_TEST_CASE(RunLarge, CLWinogradFilterTransformFixture, framework::Da
framework::dataset::make("OutputTile", Size2D(4U, 4U)))),
combine(datasets::Large5x5Shapes(), framework::dataset::make("OutputTile", Size2D(4U, 4U)))),
framework::dataset::make("DataLayout", { DataLayout::NCHW })),
- combine(combine(datasets::Large3x3Shapes(), framework::dataset::make("OutputTile", Size2D(4U, 4U))), framework::dataset::make("DataLayout", { DataLayout::NHWC }))),
+ combine(combine(framework::dataset::concat(datasets::Large3x3Shapes(), datasets::Large5x5Shapes()), framework::dataset::make("OutputTile", Size2D(4U, 4U))), framework::dataset::make("DataLayout", { DataLayout::NHWC }))),
framework::dataset::make("DataType", { DataType::F32 })))
{
// Validate output