From 2ca5b0801d49169f464a5e501e2691f0d346b93b Mon Sep 17 00:00:00 2001 From: Aleksandr Nikolaev Date: Thu, 18 Mar 2021 14:03:48 +0000 Subject: New variant of OpenCL Winograd (4x4,5x5) input transformation Resolves: COMPMID-4141 Signed-off-by: Aleksandr Nikolaev Change-Id: I1437680029ff25a3a5d4f6f258f30960545056a9 Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/5299 Tested-by: Arm Jenkins Reviewed-by: Giorgio Arena Reviewed-by: Gian Marco Iodice --- src/core/CL/cl_kernels/winograd_input_transform.cl | 176 ++++++++++----------- 1 file changed, 84 insertions(+), 92 deletions(-) diff --git a/src/core/CL/cl_kernels/winograd_input_transform.cl b/src/core/CL/cl_kernels/winograd_input_transform.cl index 8a27a7ecad..94f3772495 100644 --- a/src/core/CL/cl_kernels/winograd_input_transform.cl +++ b/src/core/CL/cl_kernels/winograd_input_transform.cl @@ -67,24 +67,45 @@ basename##7 = select((datatype)0, basename##7, (SELECT_DATA_TYPE(datatype))((y_cond) && ((z_cond##0).s7))); \ }) -#define OUTPUT_ROW_4x4_5x5(out, tmp, comm_fact) \ +// out = B^T * in, B^T is defined as for F(4x4,5x5) input transformation +#define BT_MULTIPLY_4x4_5x5(out, in, comm_fact0, comm_fact1, DATA_TYPE) \ + ({ \ + comm_fact0 = in##2 + in##6 - (DATA_TYPE)4.25f * in##4; \ + comm_fact1 = in##1 + in##5 - (DATA_TYPE)4.25f * in##3; \ + out##0 += (DATA_TYPE)5.25f * (in##4 - in##2) - in##6; \ + out##7 += (DATA_TYPE)5.25f * (in##3 - in##5) - in##1; \ + out##1 = comm_fact0 + comm_fact1; \ + out##2 = comm_fact0 - comm_fact1; \ + \ + comm_fact0 = (DATA_TYPE)0.25f * in##2 - (DATA_TYPE)1.25f * in##4 + in##6; \ + comm_fact1 = (DATA_TYPE)0.5f * in##1 - (DATA_TYPE)2.5f * in##3 + (DATA_TYPE)2.f * in##5; \ + out##3 = comm_fact0 + comm_fact1; \ + out##4 = comm_fact0 - comm_fact1; \ + \ + comm_fact0 = (DATA_TYPE)4.f * in##2 - (DATA_TYPE)5.f * in##4 + in##6; \ + comm_fact1 = (DATA_TYPE)2.f * in##1 - (DATA_TYPE)2.5f * in##3 + (DATA_TYPE)0.5f * in##5; \ + out##5 = comm_fact0 + comm_fact1; \ + out##6 = comm_fact0 - comm_fact1; \ + }) + +#define OUTPUT_ROW_4x4_5x5(out, comm_fact) \ ({ \ - comm_fact.s0 = tmp.s2 - 4.25f * tmp.s4 + tmp.s6; \ - comm_fact.s1 = tmp.s1 - 4.25f * tmp.s3 + tmp.s5; \ - comm_fact.s2 = 2.5f * tmp.s3; \ - comm_fact.s3 = 0.5f * tmp.s1 + 2.f * tmp.s5 - comm_fact.s2; \ - comm_fact.s4 = 0.25f * tmp.s2 - 1.25f * tmp.s4 + tmp.s6; \ - comm_fact.s5 = 4.f * tmp.s2 + tmp.s6 - 5.f * tmp.s4; \ - comm_fact.s6 = 2.f * tmp.s1 + 0.5f * tmp.s5 - comm_fact.s2; \ + comm_fact.s2 = 2.5f * out.s3; \ + comm_fact.s1 = out.s1 - 4.25f * out.s3 + out.s5; \ + comm_fact.s0 = out.s2 - 4.25f * out.s4 + out.s6; \ + comm_fact.s4 = 0.25f * out.s2 - 1.25f * out.s4 + out.s6; \ + comm_fact.s5 = 4.f * out.s2 - 5.f * out.s4 + out.s6; \ + comm_fact.s3 = 0.5f * out.s1 + 2.f * out.s5 - comm_fact.s2; \ + comm_fact.s6 = 2.f * out.s1 + 0.5f * out.s5 - comm_fact.s2; \ \ - out.s0 = tmp.s0 - tmp.s6 + 5.25f * tmp.s4 - 5.25f * tmp.s2; \ + out.s0 += 5.25f * (out.s4 - out.s2) - out.s6; \ + out.s7 += 5.25f * (out.s3 - out.s5) - out.s1; \ out.s1 = comm_fact.s0 + comm_fact.s1; \ out.s2 = comm_fact.s0 - comm_fact.s1; \ out.s3 = comm_fact.s3 + comm_fact.s4; \ out.s4 = comm_fact.s4 - comm_fact.s3; \ out.s5 = comm_fact.s5 + comm_fact.s6; \ out.s6 = comm_fact.s5 - comm_fact.s6; \ - out.s7 = tmp.s7 - tmp.s1 + 5.25f * tmp.s3 - 5.25f * tmp.s5; \ }) #define OUTPUT_ROW_2x2_7x7(out, tmp, comm_fact) \ @@ -826,53 +847,49 @@ __kernel void winograd_input_transform_4x4_5x5_stepz1_nchw( // Calculate common factors for intermediate tensor VEC_DATA_TYPE(DATA_TYPE, 8) - tmp0 = in_row0; + out0 = in_row0; VEC_DATA_TYPE(DATA_TYPE, 8) comm_fact0 = 0.0f; #if !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL) + VEC_DATA_TYPE(DATA_TYPE, 8) + out1, out2, out3, out4, out5, out6, out7; comm_fact0 += in_row2 + in_row6 - (DATA_TYPE)4.25f * in_row4; - tmp0 += -in_row6 + (DATA_TYPE)5.25f * in_row4 - (DATA_TYPE)5.25f * in_row2; + out0 += -in_row6 + (DATA_TYPE)5.25f * (in_row4 - in_row2); VEC_DATA_TYPE(DATA_TYPE, 8) comm_fact1 = in_row1 + in_row5 - (DATA_TYPE)4.25f * in_row3; VEC_DATA_TYPE(DATA_TYPE, 8) comm_fact2 = (DATA_TYPE)0.25f * in_row2 - (DATA_TYPE)1.25f * in_row4 + in_row6; - const VEC_DATA_TYPE(DATA_TYPE, 8) tmp1 = comm_fact0 + comm_fact1; - const VEC_DATA_TYPE(DATA_TYPE, 8) tmp2 = comm_fact0 - comm_fact1; + out1 = comm_fact0 + comm_fact1; + out2 = comm_fact0 - comm_fact1; comm_fact0 = (DATA_TYPE)2.5f * in_row3; comm_fact1 = (DATA_TYPE)0.5f * in_row1 - comm_fact0 + (DATA_TYPE)2.0f * in_row5; - const VEC_DATA_TYPE(DATA_TYPE, 8) tmp3 = comm_fact1 + comm_fact2; - const VEC_DATA_TYPE(DATA_TYPE, 8) tmp4 = comm_fact2 - comm_fact1; + out3 = comm_fact1 + comm_fact2; + out4 = comm_fact2 - comm_fact1; comm_fact1 = (DATA_TYPE)2.0f * in_row1 - comm_fact0 + (DATA_TYPE)0.5f * in_row5; comm_fact2 = (DATA_TYPE)4.0f * in_row2 - (DATA_TYPE)5.0f * in_row4 + in_row6; - const VEC_DATA_TYPE(DATA_TYPE, 8) tmp5 = comm_fact1 + comm_fact2; - const VEC_DATA_TYPE(DATA_TYPE, 8) tmp6 = comm_fact2 - comm_fact1; - const VEC_DATA_TYPE(DATA_TYPE, 8) tmp7 = in_row7 - in_row1 + (DATA_TYPE)5.25f * in_row3 - (DATA_TYPE)5.25f * in_row5; + out5 = comm_fact1 + comm_fact2; + out6 = comm_fact2 - comm_fact1; + out7 = in_row7 - in_row1 + (DATA_TYPE)5.25f * (in_row3 - in_row5); #endif // !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL) // Calculate output rows (reuse comm_fact0 vector) - VEC_DATA_TYPE(DATA_TYPE, 8) - out0; - - OUTPUT_ROW_4x4_5x5(out0, tmp0, comm_fact0); + OUTPUT_ROW_4x4_5x5(out0, comm_fact0); #if !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL) - VEC_DATA_TYPE(DATA_TYPE, 8) - out1, out2, out3, out4, out5, out6, out7; - - OUTPUT_ROW_4x4_5x5(out1, tmp1, comm_fact0); - OUTPUT_ROW_4x4_5x5(out2, tmp2, comm_fact0); - OUTPUT_ROW_4x4_5x5(out3, tmp3, comm_fact0); - OUTPUT_ROW_4x4_5x5(out4, tmp4, comm_fact0); - OUTPUT_ROW_4x4_5x5(out5, tmp5, comm_fact0); - OUTPUT_ROW_4x4_5x5(out6, tmp6, comm_fact0); - OUTPUT_ROW_4x4_5x5(out7, tmp7, comm_fact0); + OUTPUT_ROW_4x4_5x5(out1, comm_fact0); + OUTPUT_ROW_4x4_5x5(out2, comm_fact0); + OUTPUT_ROW_4x4_5x5(out3, comm_fact0); + OUTPUT_ROW_4x4_5x5(out4, comm_fact0); + OUTPUT_ROW_4x4_5x5(out5, comm_fact0); + OUTPUT_ROW_4x4_5x5(out6, comm_fact0); + OUTPUT_ROW_4x4_5x5(out7, comm_fact0); #endif // !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL) // Store values across the channels @@ -1330,7 +1347,6 @@ __kernel void winograd_input_transform_4x4_5x5_stepz1_nhwc( int8 z_cond0 = z_coord_valid0 == z_coord0; #if defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) - // Load the input tile VEC_DATA_TYPE(DATA_TYPE, 8) in_row0; @@ -1348,13 +1364,11 @@ __kernel void winograd_input_transform_4x4_5x5_stepz1_nhwc( // Calculate common factors for intermediate tensor VEC_DATA_TYPE(DATA_TYPE, 8) comm_fact0 = 0.0f; - VEC_DATA_TYPE(DATA_TYPE, 8) - tmp0 = in_row0; VEC_DATA_TYPE(DATA_TYPE, 8) - out0 = (VEC_DATA_TYPE(DATA_TYPE, 8))0.0f; + out0 = in_row0; - OUTPUT_ROW_4x4_5x5(out0, tmp0, comm_fact0); + OUTPUT_ROW_4x4_5x5(out0, comm_fact0); #elif defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL) // defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) @@ -1375,28 +1389,26 @@ __kernel void winograd_input_transform_4x4_5x5_stepz1_nhwc( // Calculate common factors for intermediate tensor VEC_DATA_TYPE(DATA_TYPE, 8) comm_fact0 = 0.0f; - VEC_DATA_TYPE(DATA_TYPE, 8) - tmp0 = in_row0; VEC_DATA_TYPE(DATA_TYPE, 8) - out0 = (VEC_DATA_TYPE(DATA_TYPE, 8))0.0f; + out0 = in_row0; - OUTPUT_ROW_4x4_5x5(out0, tmp0, comm_fact0); + OUTPUT_ROW_4x4_5x5(out0, comm_fact0); #else // defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) VEC_DATA_TYPE(DATA_TYPE, 8) - in_row0, in_row1, in_row2, in_row3, in_row4, in_row5, in_row6, in_row7; + out0, in_row1, in_row2, in_row3, in_row4, in_row5, in_row6, out7; // Row0 - in_row0.s0 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s0 * (int)src_stride_y + z_coord_valid0.s0 * src_stride_z); - in_row0.s1 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s1 * (int)src_stride_y + z_coord_valid0.s0 * src_stride_z); - in_row0.s2 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s2 * (int)src_stride_y + z_coord_valid0.s0 * src_stride_z); - in_row0.s3 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s3 * (int)src_stride_y + z_coord_valid0.s0 * src_stride_z); - in_row0.s4 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s4 * (int)src_stride_y + z_coord_valid0.s0 * src_stride_z); - in_row0.s5 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s5 * (int)src_stride_y + z_coord_valid0.s0 * src_stride_z); - in_row0.s6 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s6 * (int)src_stride_y + z_coord_valid0.s0 * src_stride_z); - in_row0.s7 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s7 * (int)src_stride_y + z_coord_valid0.s0 * src_stride_z); + out0.s0 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s0 * (int)src_stride_y + z_coord_valid0.s0 * src_stride_z); + out0.s1 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s1 * (int)src_stride_y + z_coord_valid0.s0 * src_stride_z); + out0.s2 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s2 * (int)src_stride_y + z_coord_valid0.s0 * src_stride_z); + out0.s3 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s3 * (int)src_stride_y + z_coord_valid0.s0 * src_stride_z); + out0.s4 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s4 * (int)src_stride_y + z_coord_valid0.s0 * src_stride_z); + out0.s5 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s5 * (int)src_stride_y + z_coord_valid0.s0 * src_stride_z); + out0.s6 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s6 * (int)src_stride_y + z_coord_valid0.s0 * src_stride_z); + out0.s7 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s7 * (int)src_stride_y + z_coord_valid0.s0 * src_stride_z); - FILL_ZERO_OUT_OF_BOUND_8_NHWC_H(DATA_TYPE, in_row0.s, y_cond, z_cond0.s0); + FILL_ZERO_OUT_OF_BOUND_8_NHWC_H(DATA_TYPE, out0.s, y_cond, z_cond0.s0); // Row1 in_row1.s0 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s0 * (int)src_stride_y + z_coord_valid0.s1 * src_stride_z); @@ -1471,53 +1483,33 @@ __kernel void winograd_input_transform_4x4_5x5_stepz1_nhwc( FILL_ZERO_OUT_OF_BOUND_8_NHWC_H(DATA_TYPE, in_row6.s, y_cond, z_cond0.s6); // Row7 - in_row7.s0 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s0 * (int)src_stride_y + z_coord_valid0.s7 * src_stride_z); - in_row7.s1 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s1 * (int)src_stride_y + z_coord_valid0.s7 * src_stride_z); - in_row7.s2 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s2 * (int)src_stride_y + z_coord_valid0.s7 * src_stride_z); - in_row7.s3 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s3 * (int)src_stride_y + z_coord_valid0.s7 * src_stride_z); - in_row7.s4 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s4 * (int)src_stride_y + z_coord_valid0.s7 * src_stride_z); - in_row7.s5 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s5 * (int)src_stride_y + z_coord_valid0.s7 * src_stride_z); - in_row7.s6 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s6 * (int)src_stride_y + z_coord_valid0.s7 * src_stride_z); - in_row7.s7 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s7 * (int)src_stride_y + z_coord_valid0.s7 * src_stride_z); + out7.s0 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s0 * (int)src_stride_y + z_coord_valid0.s7 * src_stride_z); + out7.s1 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s1 * (int)src_stride_y + z_coord_valid0.s7 * src_stride_z); + out7.s2 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s2 * (int)src_stride_y + z_coord_valid0.s7 * src_stride_z); + out7.s3 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s3 * (int)src_stride_y + z_coord_valid0.s7 * src_stride_z); + out7.s4 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s4 * (int)src_stride_y + z_coord_valid0.s7 * src_stride_z); + out7.s5 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s5 * (int)src_stride_y + z_coord_valid0.s7 * src_stride_z); + out7.s6 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s6 * (int)src_stride_y + z_coord_valid0.s7 * src_stride_z); + out7.s7 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s7 * (int)src_stride_y + z_coord_valid0.s7 * src_stride_z); - FILL_ZERO_OUT_OF_BOUND_8_NHWC_H(DATA_TYPE, in_row7.s, y_cond, z_cond0.s7); + FILL_ZERO_OUT_OF_BOUND_8_NHWC_H(DATA_TYPE, out7.s, y_cond, z_cond0.s7); VEC_DATA_TYPE(DATA_TYPE, 8) - comm_fact0 = in_row2 + in_row6 - (DATA_TYPE)4.25f * in_row4; - VEC_DATA_TYPE(DATA_TYPE, 8) - comm_fact1 = in_row1 + in_row5 - (DATA_TYPE)4.25f * in_row3; + out1, out2, out3, out4, out5, out6; VEC_DATA_TYPE(DATA_TYPE, 8) - comm_fact2 = (DATA_TYPE)0.25f * in_row2 - (DATA_TYPE)1.25f * in_row4 + in_row6; - - // Calculate intermediate tensor and reuse common factor vectors - const VEC_DATA_TYPE(DATA_TYPE, 8) tmp0 = in_row0 - in_row6 + (DATA_TYPE)5.25f * in_row4 - (DATA_TYPE)5.25f * in_row2; - const VEC_DATA_TYPE(DATA_TYPE, 8) tmp1 = comm_fact0 + comm_fact1; - const VEC_DATA_TYPE(DATA_TYPE, 8) tmp2 = comm_fact0 - comm_fact1; + comm_fact0, comm_fact1; - comm_fact0 = (DATA_TYPE)2.5f * in_row3; - comm_fact1 = (DATA_TYPE)0.5f * in_row1 - comm_fact0 + (DATA_TYPE)2.f * in_row5; - - const VEC_DATA_TYPE(DATA_TYPE, 8) tmp3 = comm_fact1 + comm_fact2; - const VEC_DATA_TYPE(DATA_TYPE, 8) tmp4 = comm_fact2 - comm_fact1; - - comm_fact1 = (DATA_TYPE)2.f * in_row1 - comm_fact0 + (DATA_TYPE)0.5f * in_row5; - comm_fact2 = (DATA_TYPE)4.f * in_row2 - (DATA_TYPE)5.f * in_row4 + in_row6; - - const VEC_DATA_TYPE(DATA_TYPE, 8) tmp5 = comm_fact1 + comm_fact2; - const VEC_DATA_TYPE(DATA_TYPE, 8) tmp6 = comm_fact2 - comm_fact1; - const VEC_DATA_TYPE(DATA_TYPE, 8) tmp7 = in_row7 - in_row1 + (DATA_TYPE)5.25f * in_row3 - (DATA_TYPE)5.25f * in_row5; + BT_MULTIPLY_4x4_5x5(out, in_row, comm_fact0, comm_fact1, DATA_TYPE); // Calculate output rows (reuse comm_fact0 vector) - VEC_DATA_TYPE(DATA_TYPE, 8) - out0, out1, out2, out3, out4, out5, out6, out7; - OUTPUT_ROW_4x4_5x5(out0, tmp0, comm_fact0); - OUTPUT_ROW_4x4_5x5(out1, tmp1, comm_fact0); - OUTPUT_ROW_4x4_5x5(out2, tmp2, comm_fact0); - OUTPUT_ROW_4x4_5x5(out3, tmp3, comm_fact0); - OUTPUT_ROW_4x4_5x5(out4, tmp4, comm_fact0); - OUTPUT_ROW_4x4_5x5(out5, tmp5, comm_fact0); - OUTPUT_ROW_4x4_5x5(out6, tmp6, comm_fact0); - OUTPUT_ROW_4x4_5x5(out7, tmp7, comm_fact0); + OUTPUT_ROW_4x4_5x5(out0, comm_fact0); + OUTPUT_ROW_4x4_5x5(out1, comm_fact0); + OUTPUT_ROW_4x4_5x5(out2, comm_fact0); + OUTPUT_ROW_4x4_5x5(out3, comm_fact0); + OUTPUT_ROW_4x4_5x5(out4, comm_fact0); + OUTPUT_ROW_4x4_5x5(out5, comm_fact0); + OUTPUT_ROW_4x4_5x5(out6, comm_fact0); + OUTPUT_ROW_4x4_5x5(out7, comm_fact0); #endif // !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL) // Store values across the channels -- cgit v1.2.1