author     Gian Marco Iodice <gianmarco.iodice@arm.com>   2018-04-11 15:59:10 +0100
committer  Anthony Barbier <anthony.barbier@arm.com>      2018-11-02 16:49:37 +0000
commit     e52a3000d2c13bc1b66ca66b3d12b6b836982394 (patch)
tree       70e8ef5ba216762604f84228805aac9bd65747b6 /src
parent     dd03870b63784abe499761da2b26b209b33f2db2 (diff)
COMPMID-1026 - Add support for 4x4 output tile in CLWinogradConvolutionLayer
The performance achieved can be found at the following Confluence page:
https://confluence.arm.com/display/MLENG/GEMM-based+convolution+vs+Winograd-based+convolution+on+OpenCL

Change-Id: I4b690cfdd4eb4ff0cd17b14fdd49ccaa1d1dc85c
Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/127729
Tested-by: Jenkins <bsgcomp@arm.com>
Reviewed-by: Georgios Pinitas <georgios.pinitas@arm.com>
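For reference, the new kernels implement Winograd F(4x4, 3x3): each 6x6 input tile is transformed into 36 channels and later reduced back to a 4x4 output tile. The unrolled coefficients in winograd.cl below match the standard transform matrices (cf. Lavin & Gray, "Fast Algorithms for Convolutional Neural Networks"), with input transform $V = B^T d B$ and output transform $Y = A^T M A$:

    B^T = \begin{pmatrix}
        4 &  0 & -5 &  0 & 1 & 0 \\
        0 & -4 & -4 &  1 & 1 & 0 \\
        0 &  4 & -4 & -1 & 1 & 0 \\
        0 & -2 & -1 &  2 & 1 & 0 \\
        0 &  2 & -1 & -2 & 1 & 0 \\
        0 &  4 &  0 & -5 & 0 & 1
    \end{pmatrix}
    \qquad
    A^T = \begin{pmatrix}
        1 & 1 &  1 & 1 &  1 & 0 \\
        0 & 1 & -1 & 2 & -2 & 0 \\
        0 & 1 &  1 & 4 &  4 & 0 \\
        0 & 1 & -1 & 8 & -8 & 1
    \end{pmatrix}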
Diffstat (limited to 'src')
-rw-r--r--  src/core/CL/CLKernelLibrary.cpp                          |   2
-rw-r--r--  src/core/CL/cl_kernels/winograd.cl                       | 436
-rw-r--r--  src/core/CL/kernels/CLGEMMMatrixMultiplyKernel.cpp       |  19
-rw-r--r--  src/core/CL/kernels/CLWinogradFilterTransformKernel.cpp  |   8
-rw-r--r--  src/core/CL/kernels/CLWinogradInputTransformKernel.cpp   |  19
-rw-r--r--  src/core/CL/kernels/CLWinogradOutputTransformKernel.cpp  |   4
-rw-r--r--  src/runtime/CL/functions/CLConvolutionLayer.cpp          |  37
-rw-r--r--  src/runtime/CL/functions/CLWinogradConvolutionLayer.cpp  |  18
-rw-r--r--  src/runtime/NEON/functions/NEConvolutionLayer.cpp        |   8
9 files changed, 511 insertions, 40 deletions
diff --git a/src/core/CL/CLKernelLibrary.cpp b/src/core/CL/CLKernelLibrary.cpp
index 50f623fffb..59be956ad8 100644
--- a/src/core/CL/CLKernelLibrary.cpp
+++ b/src/core/CL/CLKernelLibrary.cpp
@@ -363,7 +363,9 @@ const std::map<std::string, std::string> CLKernelLibrary::_kernel_program_map =
{ "winograd_input_transform_4x4_5x5_stepz1_nchw", "winograd.cl" },
{ "winograd_input_transform_2x2_3x3_stepz1_nchw", "winograd.cl" },
{ "winograd_input_transform_2x2_3x3_stepz2_nchw", "winograd.cl" },
+ { "winograd_input_transform_4x4_3x3_stepz1_nchw", "winograd.cl" },
{ "winograd_output_transform_2x2_3x3_nchw", "winograd.cl" },
+ { "winograd_output_transform_4x4_3x3_nchw", "winograd.cl" },
{ "winograd_output_transform_4x4_5x5_nchw", "winograd.cl" },
{ "YUYV422_to_IYUV_bt709", "color_convert.cl" },
{ "YUYV422_to_NV12_bt709", "color_convert.cl" },
diff --git a/src/core/CL/cl_kernels/winograd.cl b/src/core/CL/cl_kernels/winograd.cl
index cda23b0155..f40a969ea0 100644
--- a/src/core/CL/cl_kernels/winograd.cl
+++ b/src/core/CL/cl_kernels/winograd.cl
@@ -708,6 +708,265 @@ __kernel void winograd_input_transform_2x2_3x3_stepz2_nchw(
vstore2(out33, 0, (__global float *)(dst_addr + 15 * dst_stride_z));
}
+/** This OpenCL kernel computes the input transform when the output tile is 4x4, the filter size is 3x3 and the data format is NCHW
+ *
+ * @note The number of tiles along the X axis must be passed at compile time using -DNUM_TILES_X (e.g. -DNUM_TILES_X=5).
+ * @note The left and top padding must be passed at compile time using -DPAD_LEFT and -DPAD_TOP (e.g. -DPAD_LEFT=1 and -DPAD_TOP=0).
+ *
+ * @param[in] src_ptr Pointer to the source image. Supported data types: F32
+ * @param[in] src_stride_x Stride of the source image in X dimension (in bytes)
+ * @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
+ * @param[in] src_stride_y Stride of the source image in Y dimension (in bytes)
+ * @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
+ * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source image
+ * @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes)
+ * @param[in] src_step_z src_stride_z * number of elements along Z processed per workitem(in bytes)
+ * @param[in] dst_ptr Pointer to the destination tensor. Supported data types: as @p src_ptr
+ * @param[in] dst_stride_x Stride of the destination tensor in X dimension (in bytes)
+ * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
+ * @param[in] dst_stride_y Stride of the destination tensor in Y dimension (in bytes)
+ * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
+ * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
+ * @param[in] dst_step_z dst_stride_z * number of elements along Z processed per workitem(in bytes)
+ * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
+ */
+__kernel void winograd_input_transform_4x4_3x3_stepz1_nchw(
+ TENSOR3D_DECLARATION(src),
+ TENSOR3D_DECLARATION(dst))
+{
+ int x = get_global_id(0);
+ int y = get_global_id(1);
+ int z = get_global_id(2);
+
+ // Compute input address
+ __global uchar *src_addr = src_ptr + src_offset_first_element_in_bytes + x * 4 * src_stride_x + y * 4 * src_stride_y + z * src_stride_z;
+
+ src_addr = src_addr - ((int)PAD_LEFT * src_stride_x) - ((int)PAD_TOP * src_stride_y);
+
+ // Row4
+ float4 d40 = vload4(0, (__global float *)(src_addr + 4 * src_stride_y));
+ float2 d41 = vload2(2, (__global float *)(src_addr + 4 * src_stride_y));
+
+ float k0 = d41.s0;
+ float k1 = d41.s0;
+ float k2 = d41.s0;
+ float k3 = d41.s0;
+ float k4 = d41.s0;
+ float k5 = 0.0f;
+
+ k0 += 4.0f * d40.s0 - 5.0f * d40.s2;
+ k1 += -4.0f * d40.s1 - 4.0f * d40.s2 + d40.s3;
+ k2 += 4.0f * d40.s1 - 4.0f * d40.s2 - d40.s3;
+ k3 += -2.0f * d40.s1 + 2.0f * d40.s3 - d40.s2;
+ k4 += 2.0f * d40.s1 - 2.0f * d40.s3 - d40.s2;
+ k5 += 4.0f * d40.s1 - 5.0f * d40.s3 + d41.s1;
+
+ // Row0
+ float4 d00 = vload4(0, (__global float *)(src_addr + 0 * src_stride_y));
+ float2 d01 = vload2(2, (__global float *)(src_addr + 0 * src_stride_y));
+
+ // Row2
+ float4 d20 = vload4(0, (__global float *)(src_addr + 2 * src_stride_y));
+ float2 d21 = vload2(2, (__global float *)(src_addr + 2 * src_stride_y));
+
+ // Compute destination address
+ __global float *dst_addr = (__global float *)(dst_ptr + dst_offset_first_element_in_bytes + z * dst_stride_x + (x + y * (int)NUM_TILES_X) * dst_stride_y);
+
+ uint dst_plane_stride = dst_stride_z / sizeof(float);
+
+ float out0 = k0;
+ float out1 = k1;
+ float out2 = k2;
+ float out3 = k3;
+ float out4 = k4;
+ float out5 = k5;
+ float out6 = k0;
+ float out7 = k1;
+ float out8 = k2;
+ float out9 = k3;
+ float out10 = k4;
+ float out11 = k5;
+ float out12 = k0;
+ float out13 = k1;
+ float out14 = k2;
+ float out15 = k3;
+ float out16 = k4;
+ float out17 = k5;
+ float out18 = k0;
+ float out19 = k1;
+ float out20 = k2;
+ float out21 = k3;
+ float out22 = k4;
+ float out23 = k5;
+ float out24 = k0;
+ float out25 = k1;
+ float out26 = k2;
+ float out27 = k3;
+ float out28 = k4;
+ float out29 = k5;
+
+ // Channels [0, 5]: [out00, out01, out02, out03, out04, out05]
+ out0 += 16.0f * d00.s0 - 20.0f * d00.s2 - 20.0f * d20.s0 + 25.0f * d20.s2 + 4.0f * d01.s0 - 5.0f * d21.s0;
+ out1 += -16.0f * d00.s1 - 16.0f * d00.s2 + 4.0f * d00.s3 + 20.0f * d20.s1 + 20.0f * d20.s2 - 5.0f * d20.s3 + 4.0f * d01.s0 - 5.0f * d21.s0;
+ out2 += 16.0f * d00.s1 - 16.0f * d00.s2 - 4.0f * d00.s3 - 20.0f * d20.s1 + 20.0f * d20.s2 + 5.0f * d20.s3 + 4.0f * d01.s0 - 5.0f * d21.s0;
+ out3 += -8.0f * d00.s1 - 4.0f * d00.s2 + 8.0f * d00.s3 + 10.0f * d20.s1 + 5.0f * d20.s2 - 10.0f * d20.s3 + 4.0f * d01.s0 - 5.0f * d21.s0;
+ out4 += 8.0f * d00.s1 - 4.0f * d00.s2 - 8.0f * d00.s3 - 10.0f * d20.s1 + 5.0f * d20.s2 + 10.0f * d20.s3 + 4.0f * d01.s0 - 5.0f * d21.s0;
+ out5 += 16.0f * d00.s1 - 20.0f * d00.s3 - 20.0f * d20.s1 + 4.0f * d01.s1 + 25.0f * d20.s3 - 5.0f * d21.s1;
+
+ *(dst_addr) = out0;
+ dst_addr += dst_plane_stride;
+ *(dst_addr) = out1;
+ dst_addr += dst_plane_stride;
+ *(dst_addr) = out2;
+ dst_addr += dst_plane_stride;
+ *(dst_addr) = out3;
+ dst_addr += dst_plane_stride;
+ *(dst_addr) = out4;
+ dst_addr += dst_plane_stride;
+ *(dst_addr) = out5;
+ dst_addr += dst_plane_stride;
+
+ // Row1
+ float4 d10 = vload4(0, (__global float *)(src_addr + 1 * src_stride_y));
+ float2 d11 = vload2(2, (__global float *)(src_addr + 1 * src_stride_y));
+
+ // Row3
+ float4 d30 = vload4(0, (__global float *)(src_addr + 3 * src_stride_y));
+ float2 d31 = vload2(2, (__global float *)(src_addr + 3 * src_stride_y));
+
+ // Compute common parts for the channels between [6, 29]
+ // Channels [6, 11]: [out10, out11, out12, out13, out14, out15]
+ // Channels [12, 17]: [out20, out21, out22, out23, out24, out25]
+ float part0 = -16.0f * d20.s0 + 20.0f * d20.s2 - 4.0f * d21.s0;
+ float part1 = 16.0f * d10.s0 - 20.0f * d10.s2 + 4.0f * d11.s0 - 4.0f * d30.s0 + 5.0f * d30.s2 - d31.s0;
+ float part2 = 16.0f * d20.s2 - 4.0f * d21.s0;
+ float part3 = 16.0f * d20.s1 - 4.0f * d20.s3;
+ float part4 = 16.0f * d10.s2 - 4.0f * d11.s0 - 4.0f * d30.s2 + d31.s0;
+ float part5 = 16.0f * d10.s1 - 4.0f * d10.s3 - 4.0f * d30.s1 + d30.s3;
+ float part6 = 4.0f * d20.s2 - 4.0f * d21.s0;
+ float part7 = 8.0f * d10.s1 - 8.0f * d10.s3 - 2.0f * d30.s1 + 2.0f * d30.s3;
+ float part8 = 4.0f * d10.s2 - 4.0f * d11.s0 - d30.s2 + d31.s0;
+ float part9 = 8.0f * d20.s1 - 8.0f * d20.s3;
+ float part10 = -16.0f * d20.s1 + 20.0f * d20.s3 - 4.0f * d21.s1;
+ float part11 = -16.0f * d10.s1 + 20.0f * d10.s3 - 4.0f * d11.s1 + 4.0f * d30.s1 - 5.0f * d30.s3 + d31.s1;
+
+ // Channels [18, 23]: [out30, out31, out32, out33, out34, out35]
+ // Channels [24, 29]: [out40, out41, out42, out43, out44, out45]
+ float part12 = 8.0f * d10.s0 - 10.0f * d10.s2 + 2.0f * d11.s0 - 8.0f * d30.s0 + 10.0f * d30.s2 - 2.0f * d31.s0;
+ float part13 = part0 * 0.25f; // -4.0f * d20.s0 + 5.0f * d20.s2 - d21.s0
+ float part14 = part2 * 0.25f; // 4.0f * d20.s2 - d21.s0
+ float part15 = 8.0f * d10.s1 - 2.0f * d10.s3 - 8.0f * d30.s1 + 2.0f * d30.s3;
+ float part16 = 8.0f * d10.s2 - 2.0f * d11.s0 - 8.0f * d30.s2 + 2.0f * d31.s0;
+ float part17 = part3 * 0.25f; // 4.0f * d20.s1 - d20.s3
+ float part18 = part6 * 0.25f; // d20.s2 - d21.s0
+ float part19 = 4.0f * d10.s1 - 4.0f * d10.s3 - 4.0f * d30.s1 + 4.0f * d30.s3;
+ float part20 = 2.0f * d10.s2 - 2.0f * d11.s0 - 2.0f * d30.s2 + 2.0f * d31.s0;
+ float part21 = part9 * 0.25f; // 2.0f * (d20.s1 - d20.s3)
+ float part22 = part10 * 0.25f; // - 4.0f * d20.s1 + 5.0f * d20.s3 - d21.s1
+ float part23 = part11 * 0.5f + 6.0f * d30.s1 - 7.5f * d30.s3 + 1.5f * d31.s1; // - 8.0f * d10.s1 + 10.0f * d10.s3 - 2.0f * d11.s1 + 8.0f * d30.s1 - 10.0f * d30.s3 + 2.0f * d31.s1;
+
+ out6 += part0 - part1;
+ out12 += part0 + part1;
+ out7 += part2 + part3 + part4 + part5;
+ out8 += part2 - part3 + part4 - part5;
+ out13 += part2 + part3 - part4 - part5;
+ out14 += part2 - part3 - part4 + part5;
+ out9 += part6 + part7 + part8 + part9;
+ out10 += part6 - part7 + part8 - part9;
+ out15 += part6 - part7 - part8 + part9;
+ out16 += part6 + part7 - part8 - part9;
+ out11 += part10 + part11;
+ out17 += part10 - part11;
+
+ out18 += part13 - part12;
+ out24 += part13 + part12;
+ out19 += part14 + part15 + part16 + part17;
+ out20 += part14 - part15 + part16 - part17;
+ out25 += part14 - part15 - part16 + part17;
+ out26 += part14 + part15 - part16 - part17;
+ out21 += part18 + part19 + part20 + part21;
+ out22 += part18 - part19 + part20 - part21;
+ out27 += part18 - part19 - part20 + part21;
+ out28 += part18 + part19 - part20 - part21;
+ out23 += part22 + part23;
+ out29 += part22 - part23;
+
+ *(dst_addr) = out6;
+ dst_addr += dst_plane_stride;
+ *(dst_addr) = out7;
+ dst_addr += dst_plane_stride;
+ *(dst_addr) = out8;
+ dst_addr += dst_plane_stride;
+ *(dst_addr) = out9;
+ dst_addr += dst_plane_stride;
+ *(dst_addr) = out10;
+ dst_addr += dst_plane_stride;
+ *(dst_addr) = out11;
+ dst_addr += dst_plane_stride;
+ *(dst_addr) = out12;
+ dst_addr += dst_plane_stride;
+ *(dst_addr) = out13;
+ dst_addr += dst_plane_stride;
+ *(dst_addr) = out14;
+ dst_addr += dst_plane_stride;
+ *(dst_addr) = out15;
+ dst_addr += dst_plane_stride;
+ *(dst_addr) = out16;
+ dst_addr += dst_plane_stride;
+ *(dst_addr) = out17;
+ dst_addr += dst_plane_stride;
+
+ *(dst_addr) = out18;
+ dst_addr += dst_plane_stride;
+ *(dst_addr) = out19;
+ dst_addr += dst_plane_stride;
+ *(dst_addr) = out20;
+ dst_addr += dst_plane_stride;
+ *(dst_addr) = out21;
+ dst_addr += dst_plane_stride;
+ *(dst_addr) = out22;
+ dst_addr += dst_plane_stride;
+ *(dst_addr) = out23;
+ dst_addr += dst_plane_stride;
+ *(dst_addr) = out24;
+ dst_addr += dst_plane_stride;
+ *(dst_addr) = out25;
+ dst_addr += dst_plane_stride;
+ *(dst_addr) = out26;
+ dst_addr += dst_plane_stride;
+ *(dst_addr) = out27;
+ dst_addr += dst_plane_stride;
+ *(dst_addr) = out28;
+ dst_addr += dst_plane_stride;
+ *(dst_addr) = out29;
+ dst_addr += dst_plane_stride;
+
+ // Row5
+ float4 d50 = vload4(0, (__global float *)(src_addr + 5 * src_stride_y));
+ float2 d51 = vload2(2, (__global float *)(src_addr + 5 * src_stride_y));
+
+ // Channels [30, 35]
+ out0 = 16.0f * d10.s0 - 20.0f * d10.s2 - 20.0f * d30.s0 + 25.0f * d30.s2 + 4.0f * d50.s0 - 5.0f * d50.s2 + d51.s0 + 4.0f * d11.s0 - 5.0f * d31.s0;
+ out1 = -16.0f * d10.s1 - 16.0f * d10.s2 + 4.0f * d10.s3 + 20.0f * d30.s1 + 20.0f * d30.s2 - 5.0f * d30.s3 - 4.0f * d50.s1 - 4.0f * d50.s2 + d50.s3 + d51.s0 + 4.0f * d11.s0 - 5.0f * d31.s0;
+ out2 = 16.0f * d10.s1 - 16.0f * d10.s2 - 4.0f * d10.s3 - 20.0f * d30.s1 + 20.0f * d30.s2 + 5.0f * d30.s3 + 4.0f * d50.s1 - 4.0f * d50.s2 - d50.s3 + d51.s0 + 4.0f * d11.s0 - 5.0f * d31.s0;
+ out3 = -8.0f * d10.s1 - 4.0f * d10.s2 + 8.0f * d10.s3 + 10.0f * d30.s1 - 10.0f * d30.s3 + 5.0f * d30.s2 - 2.0f * d50.s1 + 2.0f * d50.s3 - d50.s2 + d51.s0 + 4.0f * d11.s0 - 5.0f * d31.s0;
+ out4 = 8.0f * d10.s1 - 4.0f * d10.s2 - 8.0f * d10.s3 - 10.0f * d30.s1 + 5.0f * d30.s2 + 10.0f * d30.s3 + 2.0f * d50.s1 - 2.0f * d50.s3 - d50.s2 + d51.s0 + 4.0f * d11.s0 - 5.0f * d31.s0;
+ out5 = 16.0f * d10.s1 - 20.0f * d10.s3 + 4.0f * d11.s1 - 20.0f * d30.s1 + 25.0f * d30.s3 - 5.0f * d31.s1 + 4.0f * d50.s1 - 5.0f * d50.s3 + d51.s1;
+
+ *(dst_addr) = out0;
+ dst_addr += dst_plane_stride;
+ *(dst_addr) = out1;
+ dst_addr += dst_plane_stride;
+ *(dst_addr) = out2;
+ dst_addr += dst_plane_stride;
+ *(dst_addr) = out3;
+ dst_addr += dst_plane_stride;
+ *(dst_addr) = out4;
+ dst_addr += dst_plane_stride;
+ *(dst_addr) = out5;
+ dst_addr += dst_plane_stride;
+}
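For clarity, here is a minimal scalar sketch of what the kernel above computes per tile: the textbook $V = B^T d B$ with the matrices listed under the commit message. Each of the 36 resulting values is written to its own output plane (dst_plane_stride apart), one tile per (x, y) work-item. The helper below is a hypothetical reference, not part of the library:

    #include <array>

    using Tile6 = std::array<std::array<float, 6>, 6>;

    // One 1-D pass of B^T for F(4x4, 3x3)
    static void btrans(const float in[6], float out[6])
    {
        out[0] = 4.f * in[0] - 5.f * in[2] + in[4];
        out[1] = -4.f * in[1] - 4.f * in[2] + in[3] + in[4];
        out[2] = 4.f * in[1] - 4.f * in[2] - in[3] + in[4];
        out[3] = -2.f * in[1] - in[2] + 2.f * in[3] + in[4];
        out[4] = 2.f * in[1] - in[2] - 2.f * in[3] + in[4];
        out[5] = 4.f * in[1] - 5.f * in[3] + in[5];
    }

    // V = B^T * d * B: transform the columns of the 6x6 tile, then the rows
    Tile6 winograd_input_transform_4x4_3x3_ref(const Tile6 &d)
    {
        Tile6 tmp{}, v{};
        for(int c = 0; c < 6; ++c)
        {
            float col[6], out[6];
            for(int r = 0; r < 6; ++r) col[r] = d[r][c];
            btrans(col, out);
            for(int r = 0; r < 6; ++r) tmp[r][c] = out[r];
        }
        for(int r = 0; r < 6; ++r)
        {
            btrans(tmp[r].data(), v[r].data());
        }
        return v;
    }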
+
#define OUTPUT_ROW_4x4_5x5(out, tmp, comm_fact) \
({ \
comm_fact.s0 = tmp.s2 - 4.25f * tmp.s4 + tmp.s6; \
@@ -981,6 +1240,183 @@ __kernel void winograd_output_transform_2x2_3x3_nchw(
vstore2((float2)(out10, out11), 0, (__global float *)(dst_addr + 1 * dst_stride_y));
}
+/** This OpenCL kernel performs the Winograd output transform when the output tile is 4x4, the filter size is 3x3 and the data format is NCHW
+ *
+ * @note The number of tiles along the X direction must be passed at compile time using -DNUM_TILES_X: e.g. -DNUM_TILES_X=16
+ *
+ * @param[in] src_ptr Pointer to the source tensor. Supported data types: F32
+ * @param[in] src_stride_x Stride of the source tensor in X dimension (in bytes)
+ * @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
+ * @param[in] src_stride_y Stride of the source tensor in Y dimension (in bytes)
+ * @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
+ * @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes)
+ * @param[in] src_step_z src_stride_z * number of elements along Z processed per workitem(in bytes)
+ * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source tensor
+ * @param[out] dst_ptr Pointer to the destination tensor. Supported data types: same as @p src_ptr
+ * @param[in] dst_stride_x Stride of the destination tensor in X dimension (in bytes)
+ * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
+ * @param[in] dst_stride_y Stride of the destination tensor in Y dimension (in bytes)
+ * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
+ * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
+ * @param[in] dst_step_z dst_stride_z * number of elements along Z processed per workitem(in bytes)
+ * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
+ */
+__kernel void winograd_output_transform_4x4_3x3_nchw(
+ TENSOR3D_DECLARATION(src),
+ TENSOR3D_DECLARATION(dst)
+#if defined(HAS_BIAS)
+ ,
+ VECTOR_DECLARATION(bias)
+#endif // defined(HAS_BIAS)
+)
+{
+ // Each thread stores a 4x4 tile
+ Tensor3D src = CONVERT_TO_TENSOR3D_STRUCT(src);
+
+ const __global uchar *src_addr = tensor3D_offset(&src, 0, 0, 0);
+
+ // Load the values across the 36 channels to compose the 6x6 tile
+ float d00 = *((__global float *)(src_addr + 0 * src_stride_z));
+ float d01 = *((__global float *)(src_addr + 1 * src_stride_z));
+ float d02 = *((__global float *)(src_addr + 2 * src_stride_z));
+ float d03 = *((__global float *)(src_addr + 3 * src_stride_z));
+ float d04 = *((__global float *)(src_addr + 4 * src_stride_z));
+ float d05 = *((__global float *)(src_addr + 5 * src_stride_z));
+
+ float d10 = *((__global float *)(src_addr + 6 * src_stride_z));
+ float d11 = *((__global float *)(src_addr + 7 * src_stride_z));
+ float d12 = *((__global float *)(src_addr + 8 * src_stride_z));
+ float d13 = *((__global float *)(src_addr + 9 * src_stride_z));
+ float d14 = *((__global float *)(src_addr + 10 * src_stride_z));
+ float d15 = *((__global float *)(src_addr + 11 * src_stride_z));
+
+ float d20 = *((__global float *)(src_addr + 12 * src_stride_z));
+ float d21 = *((__global float *)(src_addr + 13 * src_stride_z));
+ float d22 = *((__global float *)(src_addr + 14 * src_stride_z));
+ float d23 = *((__global float *)(src_addr + 15 * src_stride_z));
+ float d24 = *((__global float *)(src_addr + 16 * src_stride_z));
+ float d25 = *((__global float *)(src_addr + 17 * src_stride_z));
+
+ float d30 = *((__global float *)(src_addr + 18 * src_stride_z));
+ float d31 = *((__global float *)(src_addr + 19 * src_stride_z));
+ float d32 = *((__global float *)(src_addr + 20 * src_stride_z));
+ float d33 = *((__global float *)(src_addr + 21 * src_stride_z));
+ float d34 = *((__global float *)(src_addr + 22 * src_stride_z));
+ float d35 = *((__global float *)(src_addr + 23 * src_stride_z));
+
+ float d40 = *((__global float *)(src_addr + 24 * src_stride_z));
+ float d41 = *((__global float *)(src_addr + 25 * src_stride_z));
+ float d42 = *((__global float *)(src_addr + 26 * src_stride_z));
+ float d43 = *((__global float *)(src_addr + 27 * src_stride_z));
+ float d44 = *((__global float *)(src_addr + 28 * src_stride_z));
+ float d45 = *((__global float *)(src_addr + 29 * src_stride_z));
+
+ float d50 = *((__global float *)(src_addr + 30 * src_stride_z));
+ float d51 = *((__global float *)(src_addr + 31 * src_stride_z));
+ float d52 = *((__global float *)(src_addr + 32 * src_stride_z));
+ float d53 = *((__global float *)(src_addr + 33 * src_stride_z));
+ float d54 = *((__global float *)(src_addr + 34 * src_stride_z));
+ float d55 = *((__global float *)(src_addr + 35 * src_stride_z));
+
+ // Compute out00, out01, out02 and out03
+ float out00 = d01 + d21 + d41 + d11 + d31;
+ float out01 = d01 + d21 + d41 + d11 + d31;
+ float out02 = d01 + d21 + d41 + d11 + d31;
+ float out03 = d01 + d21 + d41 + d11 + d31;
+
+ float k0 = d03 + d04 + d13 + d14 + d23 + d24 + d33 + d34 + d43 + d44;
+ float k1 = 2.0f * d03 - 2.0f * d04 + 2.0f * d13 - 2.0f * d14 + 2.0f * d23 - 2.0f * d24 + 2.0f * d33 - 2.0f * d34 + 2.0f * d43 - 2.0f * d44;
+
+ out00 += k0 + d00 + d02 + d10 + d12 + d20 + d22 + d30 + d32 + d40 + d42;
+ out01 += k1 - d02 - d12 - d22 - d32 - d42;
+ out02 += 4.0f * k0 + d02 + d12 + d22 + d32 + d42;
+ out03 += 4.0f * k1 - d02 - d12 - d22 - d32 - d42 + d05 + d15 + d25 + d35 + d45;
+
+ // Compute out10, out11, out12 and out13
+ float out10 = d11 - d21 + 2.0f * d31 - 2.0f * d41;
+ float out11 = d11 - d21 + 2.0f * d31 - 2.0f * d41;
+ float out12 = d11 - d21 + 2.0f * d31 - 2.0f * d41;
+ float out13 = d11 - d21 + 2.0f * d31 - 2.0f * d41;
+
+ k0 = d13 + d14 - d23 - d24 + 2.0f * d33 + 2.0f * d34 - 2.0f * d43 - 2.0f * d44;
+ k1 = 2.0f * d13 - 2.0f * d14 - 2.0f * d23 + 2.0f * d24 + 4.0f * d33 - 4.0f * d34 - 4.0f * d43 + 4.0f * d44;
+
+ out10 += k0 + d10 + d12 - d20 - d22 + 2.0f * d30 + 2.0f * d32 - 2.0f * d40 - 2.0f * d42;
+ out11 += k1 - d12 + d22 - 2.0f * d32 + 2.0f * d42;
+ out12 += 4.0f * k0 + d12 - d22 + 2.0f * d32 - 2.0f * d42;
+ out13 += 4.0f * k1 - d12 + d15 + d22 - d25 - 2.0f * d32 + 2.0f * d35 + 2.0f * d42 - 2.0f * d45;
+
+ // Compute out20, out21, out22 and out23
+ float out20 = d11 + d21 + 4.0f * d31 + 4.0f * d41;
+ float out21 = d11 + d21 + 4.0f * d31 + 4.0f * d41;
+ float out22 = d11 + d21 + 4.0f * d31 + 4.0f * d41;
+ float out23 = d11 + d21 + 4.0f * d31 + 4.0f * d41;
+
+ k0 = d13 + d14 + d23 + d24 + 4.0f * d33 + 4.0f * d34 + 4.0f * d43 + 4.0f * d44;
+ k1 = 2.0f * d13 - 2.0f * d14 + 2.0f * d23 - 2.0f * d24 + 8.0f * d33 - 8.0f * d34 + 8.0f * d43 - 8.0f * d44;
+
+ out20 += k0 + d10 + d12 + d20 + d22 + 4.0f * d30 + 4.0f * d32 + 4.0f * d40 + 4.0f * d42;
+ out21 += k1 - d12 - d22 - 4.0f * d32 - 4.0f * d42;
+ out22 += 4.0f * k0 + d12 + d22 + 4.0f * d32 + 4.0f * d42;
+ out23 += 4.0f * k1 - d12 + d15 - d22 + d25 - 4.0f * d32 + 4.0f * d35 - 4.0f * d42 + 4.0f * d45;
+
+ // Compute out30, out31, out32 and out33
+ float out30 = d11 - d21 + 8.0f * d31 - 8.0f * d41 + d51;
+ float out31 = d11 - d21 + 8.0f * d31 - 8.0f * d41 + d51;
+ float out32 = d11 - d21 + 8.0f * d31 - 8.0f * d41 + d51;
+ float out33 = d11 - d21 + 8.0f * d31 - 8.0f * d41 + d51;
+
+ k0 = d13 + d14 - d23 - d24 + 8.0f * d33 + 8.0f * d34 - 8.0f * d43 - 8.0f * d44 + d53 + d54;
+ k1 = 2.0f * d13 - 2.0f * d14 - 2.0f * d23 + 2.0f * d24 + 16.0f * d33 - 16.0f * d34 - 16.0f * d43 + 16.0f * d44 + 2.0f * d53 - 2.0f * d54;
+
+ out30 += k0 + d10 + d12 - d20 - d22 + 8.0f * d30 + 8.0f * d32 - 8.0f * d40 - 8.0f * d42 + d50 + d52;
+ out31 += k1 - d12 + d22 - 8.0f * d32 + 8.0f * d42 - d52;
+ out32 += 4.0f * k0 + d12 - d22 + 8.0f * d32 - 8.0f * d42 + d52;
+ out33 += 4.0f * k1 - d12 + d15 + d22 - d25 - 8.0f * d32 + 8.0f * d35 + 8.0f * d42 - 8.0f * d45 - d52 + d55;
+
+ int y_in = get_global_id(1);
+ int x_out = (y_in % NUM_TILES_X) * 4;
+ int y_out = (y_in / NUM_TILES_X) * 4;
+ int z_out = get_global_id(0);
+
+#if defined(HAS_BIAS)
+ // Add bias
+ Vector bias = CONVERT_TO_VECTOR_STRUCT_NO_STEP(bias);
+
+ float b = (float) * ((__global float *)(vector_offset(&bias, z_out)));
+
+ out00 += (float)b;
+ out01 += (float)b;
+ out02 += (float)b;
+ out03 += (float)b;
+
+ out10 += (float)b;
+ out11 += (float)b;
+ out12 += (float)b;
+ out13 += (float)b;
+
+ out20 += (float)b;
+ out21 += (float)b;
+ out22 += (float)b;
+ out23 += (float)b;
+
+ out30 += (float)b;
+ out31 += (float)b;
+ out32 += (float)b;
+ out33 += (float)b;
+
+#endif // defined(HAS_BIAS)
+
+ // Get output address
+ __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + x_out * dst_stride_x + y_out * dst_stride_y + z_out * dst_stride_z;
+
+ // Store the 4x4 output tile
+ vstore4((float4)(out00, out01, out02, out03), 0, (__global float *)(dst_addr + 0 * dst_stride_y));
+ vstore4((float4)(out10, out11, out12, out13), 0, (__global float *)(dst_addr + 1 * dst_stride_y));
+ vstore4((float4)(out20, out21, out22, out23), 0, (__global float *)(dst_addr + 2 * dst_stride_y));
+ vstore4((float4)(out30, out31, out32, out33), 0, (__global float *)(dst_addr + 3 * dst_stride_y));
+}
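And the matching scalar sketch for this kernel: $Y = A^T M A$ over the 6x6 tile gathered from the 36 input planes, plus the optional bias. Again a hypothetical reference helper, not part of the library:

    // One 1-D pass of A^T for F(4x4, 3x3)
    static void atrans(const float in[6], float out[4])
    {
        out[0] = in[0] + in[1] + in[2] + in[3] + in[4];
        out[1] = in[1] - in[2] + 2.f * in[3] - 2.f * in[4];
        out[2] = in[1] + in[2] + 4.f * in[3] + 4.f * in[4];
        out[3] = in[1] - in[2] + 8.f * in[3] - 8.f * in[4] + in[5];
    }

    // Y = A^T * M * A: reduce the columns, then the rows, then add the bias
    void winograd_output_transform_4x4_3x3_ref(const float m[6][6], float bias, float y[4][4])
    {
        float tmp[4][6];
        for(int c = 0; c < 6; ++c)
        {
            const float col[6] = { m[0][c], m[1][c], m[2][c], m[3][c], m[4][c], m[5][c] };
            float out[4];
            atrans(col, out);
            for(int r = 0; r < 4; ++r) tmp[r][c] = out[r];
        }
        for(int r = 0; r < 4; ++r)
        {
            float out[4];
            atrans(tmp[r], out);
            for(int c = 0; c < 4; ++c) y[r][c] = out[c] + bias;
        }
    }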
+
#define COMPUTE_TMP_COL(col, d0, d1, d2, d3, d4, d5, d6, d7, comm_fact) \
({ \
comm_fact.s0 = d1 + d2; \
diff --git a/src/core/CL/kernels/CLGEMMMatrixMultiplyKernel.cpp b/src/core/CL/kernels/CLGEMMMatrixMultiplyKernel.cpp
index 0a0de7ad4e..805a594af6 100644
--- a/src/core/CL/kernels/CLGEMMMatrixMultiplyKernel.cpp
+++ b/src/core/CL/kernels/CLGEMMMatrixMultiplyKernel.cpp
@@ -286,17 +286,23 @@ void CLGEMMMatrixMultiplyKernel::configure(const ICLTensor *input0, const ICLTen
else // The input tensors have not been reshaped
{
build_opts.add_option("-DCOLS_A=" + support::cpp11::to_string(input0->info()->dimension(0)));
+ build_opts.add_option("-DDATA_TYPE=" + get_cl_type_from_data_type(data_type));
// Create kernels according to the architecture, data type and input size.
if(gpu_target_is_in(gpu_target, GPUTarget::G71, GPUTarget::G72, GPUTarget::G51, GPUTarget::G51BIG, GPUTarget::G51LIT, GPUTarget::TNOX) && is_data_type_float(data_type))
{
- kernel_name = "gemm_mm_floating_point_" + lower_string(string_from_data_type(data_type)) + "_bifrost";
- // The first kernel is optimized for the case of 1000 or less output elements (e.g. FC8 of AlexNet and VGG-16, and
- // FC1 of Inception v3). The second kernel is optimized for the case of greater than 1000 output elements (e.g.
- // FC6 and FC7 of AlexNet and VGG-16).
- if(input1->info()->dimension(0) <= 1000 && input0->info()->num_dimensions() == 1 && data_type == DataType::F32)
+ kernel_name = "gemm_mm_floating_point";
+
+ if(input0->info()->num_dimensions() != 1)
+ {
+ kernel_name += "_" + lower_string(string_from_data_type(data_type)) + "_bifrost";
+ }
+ else if(input1->info()->dimension(0) <= 1000 && data_type == DataType::F32)
{
- kernel_name += "_1000";
+ // The first kernel is optimized for the case of 1000 or less output elements (e.g. FC8 of AlexNet and VGG-16, and
+ // FC1 of Inception v3). The second kernel is optimized for the case of greater than 1000 output elements (e.g.
+ // FC6 and FC7 of AlexNet and VGG-16).
+ kernel_name += "_" + lower_string(string_from_data_type(data_type)) + "_bifrost_1000";
}
// The work-group size equal to the Bifrost quad size has been proved to be optimal for these kernels
@@ -309,7 +315,6 @@ void CLGEMMMatrixMultiplyKernel::configure(const ICLTensor *input0, const ICLTen
}
else // (MIDGARD and F32) or (F16)
{
- build_opts.add_option("-DDATA_TYPE=" + get_cl_type_from_data_type(data_type));
kernel_name = "gemm_mm_floating_point";
}
build_opts.add_option("-DNUM_ELEMS_PROCESSED_PER_THREAD_Y=" + support::cpp11::to_string(num_elements_processed.y()));
diff --git a/src/core/CL/kernels/CLWinogradFilterTransformKernel.cpp b/src/core/CL/kernels/CLWinogradFilterTransformKernel.cpp
index d3a33c01f9..41b3ac50b5 100644
--- a/src/core/CL/kernels/CLWinogradFilterTransformKernel.cpp
+++ b/src/core/CL/kernels/CLWinogradFilterTransformKernel.cpp
@@ -55,9 +55,11 @@ Status validate_arguments(const ITensorInfo *input, const ITensorInfo *output, c
const size_t idx_w = get_data_layout_dimension_index(input->data_layout(), DataLayoutDimension::WIDTH);
const size_t idx_h = get_data_layout_dimension_index(input->data_layout(), DataLayoutDimension::HEIGHT);
- ARM_COMPUTE_RETURN_ERROR_ON(kernel_size != Size2D(3U, 3U) && kernel_size != Size2D(5U, 5U));
- ARM_COMPUTE_RETURN_ERROR_ON(kernel_size == Size2D(3U, 3U) && output_tile_size != Size2D(2U, 2U) && output_tile_size != Size2D(4U, 4U));
- ARM_COMPUTE_RETURN_ERROR_ON(kernel_size == Size2D(5U, 5U) && output_tile_size != Size2D(4U, 4U));
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(kernel_size != Size2D(3U, 3U) && kernel_size != Size2D(5U, 5U), "Winograd filter transform only supports 3x3 and 5x5 kernels");
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(kernel_size == Size2D(3U, 3U) && output_tile_size != Size2D(2U, 2U)
+ && output_tile_size != Size2D(4U, 4U),
+ "Winograd filter transform only supports 2x2 or 4x4 output tile for 3x3 kernels");
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(kernel_size == Size2D(5U, 5U) && output_tile_size != Size2D(4U, 4U), "Winograd filter transform only supports 4x4 output tile for 5x5 kernels");
ARM_COMPUTE_RETURN_ERROR_ON(input->dimension(idx_w) != kernel_size.width || input->dimension(idx_h) != kernel_size.height);
ARM_COMPUTE_RETURN_ERROR_ON(input->num_dimensions() > 4);
diff --git a/src/core/CL/kernels/CLWinogradInputTransformKernel.cpp b/src/core/CL/kernels/CLWinogradInputTransformKernel.cpp
index a47590d20f..febd22b04e 100644
--- a/src/core/CL/kernels/CLWinogradInputTransformKernel.cpp
+++ b/src/core/CL/kernels/CLWinogradInputTransformKernel.cpp
@@ -47,7 +47,9 @@ Status validate_arguments(const ITensorInfo *input, const ITensorInfo *output, c
const Size2D kernel_size = winograd_info.kernel_size;
ARM_COMPUTE_RETURN_ERROR_ON_MSG(conv_info.stride().first != 1 || conv_info.stride().second != 1, "Winograd input transform only supports unit strides");
ARM_COMPUTE_RETURN_ERROR_ON_MSG(kernel_size != Size2D(3U, 3U) && kernel_size != Size2D(5U, 5U), "Winograd input transform only supports 3x3 and 5x5 kernels");
- ARM_COMPUTE_RETURN_ERROR_ON_MSG(kernel_size == Size2D(3U, 3U) && output_tile_size != Size2D(2U, 2U), "Winograd input transform only supports 2x2 output tile for 3x3 kernels");
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(kernel_size == Size2D(3U, 3U) && output_tile_size != Size2D(2U, 2U)
+ && output_tile_size != Size2D(4U, 4U),
+ "Winograd input transform only supports 2x2 or 4x4 output tile for 3x3 kernels");
ARM_COMPUTE_RETURN_ERROR_ON_MSG(kernel_size == Size2D(5U, 5U) && output_tile_size != Size2D(4U, 4U), "Winograd input transform only supports 4x4 output tile for 5x5 kernels");
ARM_COMPUTE_UNUSED(conv_info);
ARM_COMPUTE_UNUSED(output_tile_size);
@@ -111,7 +113,6 @@ void CLWinogradInputTransformKernel::configure(const ICLTensor *input, ICLTensor
const int num_elements_y = input->info()->dimension(1) - (kernel_size.height - 1) + conv_info.pad_top() + conv_info.pad_bottom();
// Check if we need to extend the right or bottom border
- // FIXME: This actually is not needed. Added just for validating the result;
const unsigned int extra_border_right = ((num_elements_x % output_tile_size.width) == 0) ? 0u : static_cast<unsigned int>(output_tile_size.width - 1);
const unsigned int extra_border_bottom = ((num_elements_y % output_tile_size.height) == 0) ? 0u : static_cast<unsigned int>(output_tile_size.height - 1);
@@ -137,19 +138,13 @@ void CLWinogradInputTransformKernel::configure(const ICLTensor *input, ICLTensor
std::string kernel_name = "winograd_input_transform_" + output_tile_size.to_string() + "_" + kernel_size.to_string();
// Check optimized kernel if output_dims == 2x2
- if(output_tile_size.width == 2 && output_tile_size.height == 2)
+ if(output_tile_size == Size2D(2U, 2U))
{
- if((_input->info()->dimension(2) % 2) != 0)
- {
- _step_z = 1;
- }
- else
- {
- _step_z = 2;
- _lws_hint = cl::NDRange(1, 1, 8);
- }
+ _step_z = (_input->info()->dimension(2) % 2) != 0 ? 1 : 2;
}
+ _lws_hint = cl::NDRange(1, 1, 8);
+
// Append stepz and data layout
kernel_name += "_stepz";
kernel_name += support::cpp11::to_string(_step_z);
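With the new 4x4 tile, the name assembled here resolves to the kernel registered in CLKernelLibrary.cpp above; _step_z keeps its default of 1 because the stepz-2 specialization only exists for the 2x2 tile (the data-layout suffix is appended outside the hunk shown):

    "winograd_input_transform_" + "4x4" + "_" + "3x3" + "_stepz" + "1" + "_nchw"
        -> "winograd_input_transform_4x4_3x3_stepz1_nchw"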
diff --git a/src/core/CL/kernels/CLWinogradOutputTransformKernel.cpp b/src/core/CL/kernels/CLWinogradOutputTransformKernel.cpp
index 8ee1a82209..c5d2528aa2 100644
--- a/src/core/CL/kernels/CLWinogradOutputTransformKernel.cpp
+++ b/src/core/CL/kernels/CLWinogradOutputTransformKernel.cpp
@@ -58,6 +58,7 @@ Status validate_arguments(const ITensorInfo *input, const ITensorInfo *bias, con
ARM_COMPUTE_RETURN_ERROR_ON_MSG(kernel_size != Size2D(3U, 3U) && kernel_size != Size2D(5U, 5U), "Only 3x3 and 5x5 kernels are supported");
ARM_COMPUTE_RETURN_ERROR_ON_MSG(kernel_size == Size2D(3U, 3U) && output_tile_size == Size2D(2U, 2U) && input->dimension(2) != 16, "Wrong number of batches");
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(kernel_size == Size2D(3U, 3U) && output_tile_size == Size2D(4U, 4U) && input->dimension(2) != 36, "Wrong number of batches");
ARM_COMPUTE_RETURN_ERROR_ON_MSG(kernel_size == Size2D(5U, 5U) && output_tile_size == Size2D(4U, 4U) && input->dimension(2) != 64, "Wrong number of batches");
// Compute number of elements to process in the X and Y direction
@@ -67,7 +68,6 @@ Status validate_arguments(const ITensorInfo *input, const ITensorInfo *bias, con
const int num_tiles_y = std::ceil(num_elements_y / static_cast<float>(output_tile_size.height));
ARM_COMPUTE_RETURN_ERROR_ON(input->dimension(1) != static_cast<unsigned int>((num_tiles_x * num_tiles_y)));
- ARM_COMPUTE_UNUSED(output_tile_size);
if(bias != nullptr)
{
@@ -207,4 +207,4 @@ void CLWinogradOutputTransformKernel::run(const Window &window, cl::CommandQueue
enqueue(queue, *this, slice, _lws_hint);
}
while(window.slide_window_slice_3D(slice) && window.slide_window_slice_3D(slice_out));
-} \ No newline at end of file
+}
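The channel counts asserted in validate_arguments follow from the Winograd tile algebra: an n x n output tile with a k x k kernel works on (n + k - 1) x (n + k - 1) input tiles, one value per channel:

    (n + k - 1)^2: \quad (2 + 3 - 1)^2 = 16, \qquad (4 + 3 - 1)^2 = 36, \qquad (4 + 5 - 1)^2 = 64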
diff --git a/src/runtime/CL/functions/CLConvolutionLayer.cpp b/src/runtime/CL/functions/CLConvolutionLayer.cpp
index bcb5424aab..643e24d638 100644
--- a/src/runtime/CL/functions/CLConvolutionLayer.cpp
+++ b/src/runtime/CL/functions/CLConvolutionLayer.cpp
@@ -48,9 +48,16 @@ void CLConvolutionLayer::configure(ICLTensor *input, const ICLTensor *weights, c
ARM_COMPUTE_ERROR_ON_NULLPTR(input, weights, output);
ARM_COMPUTE_ERROR_THROW_ON(CLConvolutionLayer::validate(input->info(), weights->info(), ((biases != nullptr) ? biases->info() : nullptr), output->info(), conv_info, weights_info, dilation, act_info));
- switch(CLConvolutionLayer::get_convolution_method(input->info(), weights->info(), ((biases != nullptr) ? biases->info() : nullptr), output->info(), conv_info,
+ switch(CLConvolutionLayer::get_convolution_method(input->info(), weights->info(), output->info(), conv_info,
weights_info, act_info, CLScheduler::get().target(), dilation))
{
+ case ConvolutionMethod::WINOGRAD:
+ {
+ auto f = arm_compute::support::cpp14::make_unique<CLWinogradConvolutionLayer>();
+ f->configure(input, weights, biases, output, conv_info);
+ _function = std::move(f);
+ break;
+ }
case ConvolutionMethod::DIRECT:
{
auto f = arm_compute::support::cpp14::make_unique<CLDirectConvolutionLayer>();
@@ -79,8 +86,14 @@ Status CLConvolutionLayer::validate(const ITensorInfo *input, const ITensorInfo
//Configure if the parameters match the direct convolution or the gemm-based
const GPUTarget gpu_target = CLScheduler::get().target();
- switch(CLConvolutionLayer::get_convolution_method(input, weights, biases, output, conv_info, weights_info, act_info, gpu_target, dilation))
+ switch(CLConvolutionLayer::get_convolution_method(input, weights, output, conv_info, weights_info, act_info, gpu_target, dilation))
{
+ case ConvolutionMethod::WINOGRAD:
+ {
+ //Validate Winograd
+ CLWinogradConvolutionLayer::validate(input, weights, biases, output, conv_info);
+ break;
+ }
case ConvolutionMethod::DIRECT:
{
// Validate direct convolution layer
@@ -101,19 +114,25 @@ Status CLConvolutionLayer::validate(const ITensorInfo *input, const ITensorInfo
return Status{};
}
-ConvolutionMethod CLConvolutionLayer::get_convolution_method(const ITensorInfo *input, const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *output, const PadStrideInfo &conv_info,
+ConvolutionMethod CLConvolutionLayer::get_convolution_method(const ITensorInfo *input, const ITensorInfo *weights, const ITensorInfo *output, const PadStrideInfo &conv_info,
const WeightsInfo &weights_info, const ActivationLayerInfo &act_info, const GPUTarget gpu_target, const Size2D &dilation)
{
- ARM_COMPUTE_UNUSED(input);
- ARM_COMPUTE_UNUSED(weights);
- ARM_COMPUTE_UNUSED(biases);
+ ARM_COMPUTE_ERROR_ON_NULLPTR(input);
+ ARM_COMPUTE_ERROR_ON_NULLPTR(output);
+ ARM_COMPUTE_ERROR_ON_NULLPTR(weights);
ARM_COMPUTE_UNUSED(output);
- ARM_COMPUTE_UNUSED(conv_info);
ARM_COMPUTE_UNUSED(weights_info);
ARM_COMPUTE_UNUSED(gpu_target);
- ARM_COMPUTE_UNUSED(dilation);
- ARM_COMPUTE_UNUSED(act_info);
+ const size_t idx_w = get_data_layout_dimension_index(input->data_layout(), DataLayoutDimension::WIDTH);
+ const size_t idx_h = get_data_layout_dimension_index(input->data_layout(), DataLayoutDimension::HEIGHT);
+ const size_t idx_c = get_data_layout_dimension_index(input->data_layout(), DataLayoutDimension::CHANNEL);
+
+ if((input->data_type() == DataType::F32) && (input->data_layout() == DataLayout::NCHW) && (input->dimension(idx_c) > 3) && (weights->dimension(idx_w) == 3) && (weights->dimension(idx_h) == 3)
+ && (weights->num_dimensions() <= 4) && (conv_info.stride().first == 1) && (conv_info.stride().second == 1) && (dilation == Size2D(1U, 1U)) && (!act_info.enabled()))
+ {
+ return ConvolutionMethod::WINOGRAD;
+ }
return ConvolutionMethod::GEMM;
}
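A minimal usage sketch (hypothetical shapes, data filling omitted) of a configuration that now dispatches to ConvolutionMethod::WINOGRAD: F32, NCHW, more than three input channels, 3x3 weights, unit strides, no dilation and no fused activation:

    #include "arm_compute/core/Types.h"
    #include "arm_compute/runtime/CL/CLFunctions.h"
    #include "arm_compute/runtime/CL/CLScheduler.h"

    using namespace arm_compute;

    int main()
    {
        CLScheduler::get().default_init();

        // Hypothetical shapes: 56x56x64 input, 64 filters of 3x3x64 (NCHW, F32)
        CLTensor src, weights, biases, dst;
        src.allocator()->init(TensorInfo(TensorShape(56U, 56U, 64U), 1, DataType::F32));
        weights.allocator()->init(TensorInfo(TensorShape(3U, 3U, 64U, 64U), 1, DataType::F32));
        biases.allocator()->init(TensorInfo(TensorShape(64U), 1, DataType::F32));
        dst.allocator()->init(TensorInfo(TensorShape(56U, 56U, 64U), 1, DataType::F32));

        CLConvolutionLayer conv;
        conv.configure(&src, &weights, &biases, &dst, PadStrideInfo(1, 1, 1, 1)); // stride 1, pad 1

        src.allocator()->allocate();
        weights.allocator()->allocate();
        biases.allocator()->allocate();
        dst.allocator()->allocate();

        conv.run();
        CLScheduler::get().sync();
        return 0;
    }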
diff --git a/src/runtime/CL/functions/CLWinogradConvolutionLayer.cpp b/src/runtime/CL/functions/CLWinogradConvolutionLayer.cpp
index 0aa7f8d1b5..86ccddac88 100644
--- a/src/runtime/CL/functions/CLWinogradConvolutionLayer.cpp
+++ b/src/runtime/CL/functions/CLWinogradConvolutionLayer.cpp
@@ -44,13 +44,18 @@ void CLWinogradConvolutionLayer::configure(ICLTensor *input, const ICLTensor *we
const size_t idx_height = get_data_layout_dimension_index(input->info()->data_layout(), DataLayoutDimension::HEIGHT);
// Input shape
- const TensorShape input_shape = input->info()->tensor_shape();
+ const TensorShape input_shape = input->info()->tensor_shape();
+ const unsigned int input_w = input->info()->tensor_shape()[idx_width];
+ const unsigned int input_h = input->info()->tensor_shape()[idx_height];
// Kernel size
const unsigned int kernel_w = weights->info()->tensor_shape()[idx_width];
const unsigned int kernel_h = weights->info()->tensor_shape()[idx_height];
- const WinogradInfo winograd_info = WinogradInfo(Size2D(2, 2),
+ //Winograd output tile
+ const Size2D output_tile = (Size2D(kernel_w, kernel_h) == Size2D(3U, 3U) && input_w <= 4 && input_h <= 4) ? Size2D(2U, 2U) : Size2D(4U, 4U);
+
+ const WinogradInfo winograd_info = WinogradInfo(output_tile,
Size2D(kernel_w, kernel_h),
Size2D(input_shape[idx_width], input_shape[idx_height]),
conv_info,
@@ -95,13 +100,18 @@ Status CLWinogradConvolutionLayer::validate(const ITensorInfo *input, const ITen
const size_t idx_height = get_data_layout_dimension_index(input->data_layout(), DataLayoutDimension::HEIGHT);
// Input shape
- const TensorShape input_shape = input->tensor_shape();
+ const TensorShape input_shape = input->tensor_shape();
+ const unsigned int input_w = input->tensor_shape()[idx_width];
+ const unsigned int input_h = input->tensor_shape()[idx_height];
// Kernel size
const unsigned int kernel_w = weights->tensor_shape()[idx_width];
const unsigned int kernel_h = weights->tensor_shape()[idx_height];
- const WinogradInfo winograd_info = WinogradInfo(Size2D(2, 2),
+ //Winograd output tile
+ const Size2D output_tile = (Size2D(kernel_w, kernel_h) == Size2D(3U, 3U) && input_w <= 4 && input_h <= 4) ? Size2D(2U, 2U) : Size2D(4U, 4U);
+
+ const WinogradInfo winograd_info = WinogradInfo(output_tile,
Size2D(kernel_w, kernel_h),
Size2D(input_shape[idx_width], input_shape[idx_height]),
conv_info,
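The tile-size heuristic introduced above, as a stand-alone sketch (hypothetical free function): F(2x2, 3x3) is kept only for planes no larger than the 4x4 tile itself; everything else, including the 5x5 kernel path, uses a 4x4 output tile:

    Size2D select_output_tile(const Size2D &kernel, unsigned int input_w, unsigned int input_h)
    {
        // 3x3 kernel on a plane <= 4x4 -> F(2x2, 3x3), 16 transformed channels;
        // otherwise -> 4x4 tile: F(4x4, 3x3) = 36 channels, F(4x4, 5x5) = 64 channels
        return (kernel == Size2D(3U, 3U) && input_w <= 4 && input_h <= 4) ? Size2D(2U, 2U) : Size2D(4U, 4U);
    }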
diff --git a/src/runtime/NEON/functions/NEConvolutionLayer.cpp b/src/runtime/NEON/functions/NEConvolutionLayer.cpp
index afc354533d..b0603e92d2 100644
--- a/src/runtime/NEON/functions/NEConvolutionLayer.cpp
+++ b/src/runtime/NEON/functions/NEConvolutionLayer.cpp
@@ -109,10 +109,12 @@ ConvolutionMethod NEConvolutionLayer::get_convolution_method(const ITensorInfo *
ARM_COMPUTE_ERROR_ON_NULLPTR(weights);
ARM_COMPUTE_UNUSED(output);
ARM_COMPUTE_UNUSED(weights_info);
- ARM_COMPUTE_UNUSED(act_info);
- if((input->data_type() == DataType::F32) && (weights->dimension(0) == 3) && (weights->dimension(1) == 3) && (weights->num_dimensions() <= 4) && (conv_info.stride().first == 1)
- && (conv_info.stride().second == 1) && (dilation == Size2D(1U, 1U)))
+ const size_t idx_w = get_data_layout_dimension_index(input->data_layout(), DataLayoutDimension::WIDTH);
+ const size_t idx_h = get_data_layout_dimension_index(input->data_layout(), DataLayoutDimension::HEIGHT);
+
+ if((input->data_type() == DataType::F32) && (input->data_layout() == DataLayout::NCHW) && (weights->dimension(idx_w) == 3) && (weights->dimension(idx_h) == 3) && (weights->num_dimensions() <= 4)
+ && (conv_info.stride().first == 1) && (conv_info.stride().second == 1) && (dilation == Size2D(1U, 1U)) && (!act_info.enabled()))
{
//FIXME Until COMPMID-1041 is implemented Winograd is slower than GEMM on A53.
if(Scheduler::get().cpu_info().get_cpu_model() != CPUModel::A53)