aboutsummaryrefslogtreecommitdiff
path: root/src/core/CL/cl_kernels/winograd_input_transform.cl
diff options
context:
space:
mode:
authorAleksandr Nikolaev <aleksandr.nikolaev@arm.com>2021-03-18 14:03:48 +0000
committerGian Marco Iodice <gianmarco.iodice@arm.com>2021-03-29 12:35:59 +0000
commit2ca5b0801d49169f464a5e501e2691f0d346b93b (patch)
treefcae46b28aebd1c45d3e33ca591533a48ce87691 /src/core/CL/cl_kernels/winograd_input_transform.cl
parent1e3ab4264fb0455abe8a3903abab40c59b9be91e (diff)
downloadComputeLibrary-2ca5b0801d49169f464a5e501e2691f0d346b93b.tar.gz
New variant of OpenCL Winograd (4x4,5x5) input transformation
Resolves: COMPMID-4141 Signed-off-by: Aleksandr Nikolaev <aleksandr.nikolaev@arm.com> Change-Id: I1437680029ff25a3a5d4f6f258f30960545056a9 Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/5299 Tested-by: Arm Jenkins <bsgcomp@arm.com> Reviewed-by: Giorgio Arena <giorgio.arena@arm.com> Reviewed-by: Gian Marco Iodice <gianmarco.iodice@arm.com>
Diffstat (limited to 'src/core/CL/cl_kernels/winograd_input_transform.cl')
-rw-r--r--src/core/CL/cl_kernels/winograd_input_transform.cl176
1 files changed, 84 insertions, 92 deletions
diff --git a/src/core/CL/cl_kernels/winograd_input_transform.cl b/src/core/CL/cl_kernels/winograd_input_transform.cl
index 8a27a7ecad..94f3772495 100644
--- a/src/core/CL/cl_kernels/winograd_input_transform.cl
+++ b/src/core/CL/cl_kernels/winograd_input_transform.cl
@@ -67,24 +67,45 @@
basename##7 = select((datatype)0, basename##7, (SELECT_DATA_TYPE(datatype))((y_cond) && ((z_cond##0).s7))); \
})
-#define OUTPUT_ROW_4x4_5x5(out, tmp, comm_fact) \
+// out = B^T * in, B^T is defined as for F(4x4,5x5) input transformation
+#define BT_MULTIPLY_4x4_5x5(out, in, comm_fact0, comm_fact1, DATA_TYPE) \
+ ({ \
+ comm_fact0 = in##2 + in##6 - (DATA_TYPE)4.25f * in##4; \
+ comm_fact1 = in##1 + in##5 - (DATA_TYPE)4.25f * in##3; \
+ out##0 += (DATA_TYPE)5.25f * (in##4 - in##2) - in##6; \
+ out##7 += (DATA_TYPE)5.25f * (in##3 - in##5) - in##1; \
+ out##1 = comm_fact0 + comm_fact1; \
+ out##2 = comm_fact0 - comm_fact1; \
+ \
+ comm_fact0 = (DATA_TYPE)0.25f * in##2 - (DATA_TYPE)1.25f * in##4 + in##6; \
+ comm_fact1 = (DATA_TYPE)0.5f * in##1 - (DATA_TYPE)2.5f * in##3 + (DATA_TYPE)2.f * in##5; \
+ out##3 = comm_fact0 + comm_fact1; \
+ out##4 = comm_fact0 - comm_fact1; \
+ \
+ comm_fact0 = (DATA_TYPE)4.f * in##2 - (DATA_TYPE)5.f * in##4 + in##6; \
+ comm_fact1 = (DATA_TYPE)2.f * in##1 - (DATA_TYPE)2.5f * in##3 + (DATA_TYPE)0.5f * in##5; \
+ out##5 = comm_fact0 + comm_fact1; \
+ out##6 = comm_fact0 - comm_fact1; \
+ })
+
+#define OUTPUT_ROW_4x4_5x5(out, comm_fact) \
({ \
- comm_fact.s0 = tmp.s2 - 4.25f * tmp.s4 + tmp.s6; \
- comm_fact.s1 = tmp.s1 - 4.25f * tmp.s3 + tmp.s5; \
- comm_fact.s2 = 2.5f * tmp.s3; \
- comm_fact.s3 = 0.5f * tmp.s1 + 2.f * tmp.s5 - comm_fact.s2; \
- comm_fact.s4 = 0.25f * tmp.s2 - 1.25f * tmp.s4 + tmp.s6; \
- comm_fact.s5 = 4.f * tmp.s2 + tmp.s6 - 5.f * tmp.s4; \
- comm_fact.s6 = 2.f * tmp.s1 + 0.5f * tmp.s5 - comm_fact.s2; \
+ comm_fact.s2 = 2.5f * out.s3; \
+ comm_fact.s1 = out.s1 - 4.25f * out.s3 + out.s5; \
+ comm_fact.s0 = out.s2 - 4.25f * out.s4 + out.s6; \
+ comm_fact.s4 = 0.25f * out.s2 - 1.25f * out.s4 + out.s6; \
+ comm_fact.s5 = 4.f * out.s2 - 5.f * out.s4 + out.s6; \
+ comm_fact.s3 = 0.5f * out.s1 + 2.f * out.s5 - comm_fact.s2; \
+ comm_fact.s6 = 2.f * out.s1 + 0.5f * out.s5 - comm_fact.s2; \
\
- out.s0 = tmp.s0 - tmp.s6 + 5.25f * tmp.s4 - 5.25f * tmp.s2; \
+ out.s0 += 5.25f * (out.s4 - out.s2) - out.s6; \
+ out.s7 += 5.25f * (out.s3 - out.s5) - out.s1; \
out.s1 = comm_fact.s0 + comm_fact.s1; \
out.s2 = comm_fact.s0 - comm_fact.s1; \
out.s3 = comm_fact.s3 + comm_fact.s4; \
out.s4 = comm_fact.s4 - comm_fact.s3; \
out.s5 = comm_fact.s5 + comm_fact.s6; \
out.s6 = comm_fact.s5 - comm_fact.s6; \
- out.s7 = tmp.s7 - tmp.s1 + 5.25f * tmp.s3 - 5.25f * tmp.s5; \
})
#define OUTPUT_ROW_2x2_7x7(out, tmp, comm_fact) \
@@ -826,53 +847,49 @@ __kernel void winograd_input_transform_4x4_5x5_stepz1_nchw(
// Calculate common factors for intermediate tensor
VEC_DATA_TYPE(DATA_TYPE, 8)
- tmp0 = in_row0;
+ out0 = in_row0;
VEC_DATA_TYPE(DATA_TYPE, 8)
comm_fact0 = 0.0f;
#if !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
+ VEC_DATA_TYPE(DATA_TYPE, 8)
+ out1, out2, out3, out4, out5, out6, out7;
comm_fact0 += in_row2 + in_row6 - (DATA_TYPE)4.25f * in_row4;
- tmp0 += -in_row6 + (DATA_TYPE)5.25f * in_row4 - (DATA_TYPE)5.25f * in_row2;
+ out0 += -in_row6 + (DATA_TYPE)5.25f * (in_row4 - in_row2);
VEC_DATA_TYPE(DATA_TYPE, 8)
comm_fact1 = in_row1 + in_row5 - (DATA_TYPE)4.25f * in_row3;
VEC_DATA_TYPE(DATA_TYPE, 8)
comm_fact2 = (DATA_TYPE)0.25f * in_row2 - (DATA_TYPE)1.25f * in_row4 + in_row6;
- const VEC_DATA_TYPE(DATA_TYPE, 8) tmp1 = comm_fact0 + comm_fact1;
- const VEC_DATA_TYPE(DATA_TYPE, 8) tmp2 = comm_fact0 - comm_fact1;
+ out1 = comm_fact0 + comm_fact1;
+ out2 = comm_fact0 - comm_fact1;
comm_fact0 = (DATA_TYPE)2.5f * in_row3;
comm_fact1 = (DATA_TYPE)0.5f * in_row1 - comm_fact0 + (DATA_TYPE)2.0f * in_row5;
- const VEC_DATA_TYPE(DATA_TYPE, 8) tmp3 = comm_fact1 + comm_fact2;
- const VEC_DATA_TYPE(DATA_TYPE, 8) tmp4 = comm_fact2 - comm_fact1;
+ out3 = comm_fact1 + comm_fact2;
+ out4 = comm_fact2 - comm_fact1;
comm_fact1 = (DATA_TYPE)2.0f * in_row1 - comm_fact0 + (DATA_TYPE)0.5f * in_row5;
comm_fact2 = (DATA_TYPE)4.0f * in_row2 - (DATA_TYPE)5.0f * in_row4 + in_row6;
- const VEC_DATA_TYPE(DATA_TYPE, 8) tmp5 = comm_fact1 + comm_fact2;
- const VEC_DATA_TYPE(DATA_TYPE, 8) tmp6 = comm_fact2 - comm_fact1;
- const VEC_DATA_TYPE(DATA_TYPE, 8) tmp7 = in_row7 - in_row1 + (DATA_TYPE)5.25f * in_row3 - (DATA_TYPE)5.25f * in_row5;
+ out5 = comm_fact1 + comm_fact2;
+ out6 = comm_fact2 - comm_fact1;
+ out7 = in_row7 - in_row1 + (DATA_TYPE)5.25f * (in_row3 - in_row5);
#endif // !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
// Calculate output rows (reuse comm_fact0 vector)
- VEC_DATA_TYPE(DATA_TYPE, 8)
- out0;
-
- OUTPUT_ROW_4x4_5x5(out0, tmp0, comm_fact0);
+ OUTPUT_ROW_4x4_5x5(out0, comm_fact0);
#if !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
- VEC_DATA_TYPE(DATA_TYPE, 8)
- out1, out2, out3, out4, out5, out6, out7;
-
- OUTPUT_ROW_4x4_5x5(out1, tmp1, comm_fact0);
- OUTPUT_ROW_4x4_5x5(out2, tmp2, comm_fact0);
- OUTPUT_ROW_4x4_5x5(out3, tmp3, comm_fact0);
- OUTPUT_ROW_4x4_5x5(out4, tmp4, comm_fact0);
- OUTPUT_ROW_4x4_5x5(out5, tmp5, comm_fact0);
- OUTPUT_ROW_4x4_5x5(out6, tmp6, comm_fact0);
- OUTPUT_ROW_4x4_5x5(out7, tmp7, comm_fact0);
+ OUTPUT_ROW_4x4_5x5(out1, comm_fact0);
+ OUTPUT_ROW_4x4_5x5(out2, comm_fact0);
+ OUTPUT_ROW_4x4_5x5(out3, comm_fact0);
+ OUTPUT_ROW_4x4_5x5(out4, comm_fact0);
+ OUTPUT_ROW_4x4_5x5(out5, comm_fact0);
+ OUTPUT_ROW_4x4_5x5(out6, comm_fact0);
+ OUTPUT_ROW_4x4_5x5(out7, comm_fact0);
#endif // !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
// Store values across the channels
@@ -1330,7 +1347,6 @@ __kernel void winograd_input_transform_4x4_5x5_stepz1_nhwc(
int8 z_cond0 = z_coord_valid0 == z_coord0;
#if defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL)
-
// Load the input tile
VEC_DATA_TYPE(DATA_TYPE, 8)
in_row0;
@@ -1348,13 +1364,11 @@ __kernel void winograd_input_transform_4x4_5x5_stepz1_nhwc(
// Calculate common factors for intermediate tensor
VEC_DATA_TYPE(DATA_TYPE, 8)
comm_fact0 = 0.0f;
- VEC_DATA_TYPE(DATA_TYPE, 8)
- tmp0 = in_row0;
VEC_DATA_TYPE(DATA_TYPE, 8)
- out0 = (VEC_DATA_TYPE(DATA_TYPE, 8))0.0f;
+ out0 = in_row0;
- OUTPUT_ROW_4x4_5x5(out0, tmp0, comm_fact0);
+ OUTPUT_ROW_4x4_5x5(out0, comm_fact0);
#elif defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL) // defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL)
@@ -1375,28 +1389,26 @@ __kernel void winograd_input_transform_4x4_5x5_stepz1_nhwc(
// Calculate common factors for intermediate tensor
VEC_DATA_TYPE(DATA_TYPE, 8)
comm_fact0 = 0.0f;
- VEC_DATA_TYPE(DATA_TYPE, 8)
- tmp0 = in_row0;
VEC_DATA_TYPE(DATA_TYPE, 8)
- out0 = (VEC_DATA_TYPE(DATA_TYPE, 8))0.0f;
+ out0 = in_row0;
- OUTPUT_ROW_4x4_5x5(out0, tmp0, comm_fact0);
+ OUTPUT_ROW_4x4_5x5(out0, comm_fact0);
#else // defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL)
VEC_DATA_TYPE(DATA_TYPE, 8)
- in_row0, in_row1, in_row2, in_row3, in_row4, in_row5, in_row6, in_row7;
+ out0, in_row1, in_row2, in_row3, in_row4, in_row5, in_row6, out7;
// Row0
- in_row0.s0 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s0 * (int)src_stride_y + z_coord_valid0.s0 * src_stride_z);
- in_row0.s1 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s1 * (int)src_stride_y + z_coord_valid0.s0 * src_stride_z);
- in_row0.s2 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s2 * (int)src_stride_y + z_coord_valid0.s0 * src_stride_z);
- in_row0.s3 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s3 * (int)src_stride_y + z_coord_valid0.s0 * src_stride_z);
- in_row0.s4 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s4 * (int)src_stride_y + z_coord_valid0.s0 * src_stride_z);
- in_row0.s5 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s5 * (int)src_stride_y + z_coord_valid0.s0 * src_stride_z);
- in_row0.s6 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s6 * (int)src_stride_y + z_coord_valid0.s0 * src_stride_z);
- in_row0.s7 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s7 * (int)src_stride_y + z_coord_valid0.s0 * src_stride_z);
+ out0.s0 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s0 * (int)src_stride_y + z_coord_valid0.s0 * src_stride_z);
+ out0.s1 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s1 * (int)src_stride_y + z_coord_valid0.s0 * src_stride_z);
+ out0.s2 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s2 * (int)src_stride_y + z_coord_valid0.s0 * src_stride_z);
+ out0.s3 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s3 * (int)src_stride_y + z_coord_valid0.s0 * src_stride_z);
+ out0.s4 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s4 * (int)src_stride_y + z_coord_valid0.s0 * src_stride_z);
+ out0.s5 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s5 * (int)src_stride_y + z_coord_valid0.s0 * src_stride_z);
+ out0.s6 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s6 * (int)src_stride_y + z_coord_valid0.s0 * src_stride_z);
+ out0.s7 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s7 * (int)src_stride_y + z_coord_valid0.s0 * src_stride_z);
- FILL_ZERO_OUT_OF_BOUND_8_NHWC_H(DATA_TYPE, in_row0.s, y_cond, z_cond0.s0);
+ FILL_ZERO_OUT_OF_BOUND_8_NHWC_H(DATA_TYPE, out0.s, y_cond, z_cond0.s0);
// Row1
in_row1.s0 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s0 * (int)src_stride_y + z_coord_valid0.s1 * src_stride_z);
@@ -1471,53 +1483,33 @@ __kernel void winograd_input_transform_4x4_5x5_stepz1_nhwc(
FILL_ZERO_OUT_OF_BOUND_8_NHWC_H(DATA_TYPE, in_row6.s, y_cond, z_cond0.s6);
// Row7
- in_row7.s0 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s0 * (int)src_stride_y + z_coord_valid0.s7 * src_stride_z);
- in_row7.s1 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s1 * (int)src_stride_y + z_coord_valid0.s7 * src_stride_z);
- in_row7.s2 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s2 * (int)src_stride_y + z_coord_valid0.s7 * src_stride_z);
- in_row7.s3 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s3 * (int)src_stride_y + z_coord_valid0.s7 * src_stride_z);
- in_row7.s4 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s4 * (int)src_stride_y + z_coord_valid0.s7 * src_stride_z);
- in_row7.s5 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s5 * (int)src_stride_y + z_coord_valid0.s7 * src_stride_z);
- in_row7.s6 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s6 * (int)src_stride_y + z_coord_valid0.s7 * src_stride_z);
- in_row7.s7 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s7 * (int)src_stride_y + z_coord_valid0.s7 * src_stride_z);
+ out7.s0 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s0 * (int)src_stride_y + z_coord_valid0.s7 * src_stride_z);
+ out7.s1 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s1 * (int)src_stride_y + z_coord_valid0.s7 * src_stride_z);
+ out7.s2 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s2 * (int)src_stride_y + z_coord_valid0.s7 * src_stride_z);
+ out7.s3 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s3 * (int)src_stride_y + z_coord_valid0.s7 * src_stride_z);
+ out7.s4 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s4 * (int)src_stride_y + z_coord_valid0.s7 * src_stride_z);
+ out7.s5 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s5 * (int)src_stride_y + z_coord_valid0.s7 * src_stride_z);
+ out7.s6 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s6 * (int)src_stride_y + z_coord_valid0.s7 * src_stride_z);
+ out7.s7 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s7 * (int)src_stride_y + z_coord_valid0.s7 * src_stride_z);
- FILL_ZERO_OUT_OF_BOUND_8_NHWC_H(DATA_TYPE, in_row7.s, y_cond, z_cond0.s7);
+ FILL_ZERO_OUT_OF_BOUND_8_NHWC_H(DATA_TYPE, out7.s, y_cond, z_cond0.s7);
VEC_DATA_TYPE(DATA_TYPE, 8)
- comm_fact0 = in_row2 + in_row6 - (DATA_TYPE)4.25f * in_row4;
- VEC_DATA_TYPE(DATA_TYPE, 8)
- comm_fact1 = in_row1 + in_row5 - (DATA_TYPE)4.25f * in_row3;
+ out1, out2, out3, out4, out5, out6;
VEC_DATA_TYPE(DATA_TYPE, 8)
- comm_fact2 = (DATA_TYPE)0.25f * in_row2 - (DATA_TYPE)1.25f * in_row4 + in_row6;
-
- // Calculate intermediate tensor and reuse common factor vectors
- const VEC_DATA_TYPE(DATA_TYPE, 8) tmp0 = in_row0 - in_row6 + (DATA_TYPE)5.25f * in_row4 - (DATA_TYPE)5.25f * in_row2;
- const VEC_DATA_TYPE(DATA_TYPE, 8) tmp1 = comm_fact0 + comm_fact1;
- const VEC_DATA_TYPE(DATA_TYPE, 8) tmp2 = comm_fact0 - comm_fact1;
+ comm_fact0, comm_fact1;
- comm_fact0 = (DATA_TYPE)2.5f * in_row3;
- comm_fact1 = (DATA_TYPE)0.5f * in_row1 - comm_fact0 + (DATA_TYPE)2.f * in_row5;
-
- const VEC_DATA_TYPE(DATA_TYPE, 8) tmp3 = comm_fact1 + comm_fact2;
- const VEC_DATA_TYPE(DATA_TYPE, 8) tmp4 = comm_fact2 - comm_fact1;
-
- comm_fact1 = (DATA_TYPE)2.f * in_row1 - comm_fact0 + (DATA_TYPE)0.5f * in_row5;
- comm_fact2 = (DATA_TYPE)4.f * in_row2 - (DATA_TYPE)5.f * in_row4 + in_row6;
-
- const VEC_DATA_TYPE(DATA_TYPE, 8) tmp5 = comm_fact1 + comm_fact2;
- const VEC_DATA_TYPE(DATA_TYPE, 8) tmp6 = comm_fact2 - comm_fact1;
- const VEC_DATA_TYPE(DATA_TYPE, 8) tmp7 = in_row7 - in_row1 + (DATA_TYPE)5.25f * in_row3 - (DATA_TYPE)5.25f * in_row5;
+ BT_MULTIPLY_4x4_5x5(out, in_row, comm_fact0, comm_fact1, DATA_TYPE);
// Calculate output rows (reuse comm_fact0 vector)
- VEC_DATA_TYPE(DATA_TYPE, 8)
- out0, out1, out2, out3, out4, out5, out6, out7;
- OUTPUT_ROW_4x4_5x5(out0, tmp0, comm_fact0);
- OUTPUT_ROW_4x4_5x5(out1, tmp1, comm_fact0);
- OUTPUT_ROW_4x4_5x5(out2, tmp2, comm_fact0);
- OUTPUT_ROW_4x4_5x5(out3, tmp3, comm_fact0);
- OUTPUT_ROW_4x4_5x5(out4, tmp4, comm_fact0);
- OUTPUT_ROW_4x4_5x5(out5, tmp5, comm_fact0);
- OUTPUT_ROW_4x4_5x5(out6, tmp6, comm_fact0);
- OUTPUT_ROW_4x4_5x5(out7, tmp7, comm_fact0);
+ OUTPUT_ROW_4x4_5x5(out0, comm_fact0);
+ OUTPUT_ROW_4x4_5x5(out1, comm_fact0);
+ OUTPUT_ROW_4x4_5x5(out2, comm_fact0);
+ OUTPUT_ROW_4x4_5x5(out3, comm_fact0);
+ OUTPUT_ROW_4x4_5x5(out4, comm_fact0);
+ OUTPUT_ROW_4x4_5x5(out5, comm_fact0);
+ OUTPUT_ROW_4x4_5x5(out6, comm_fact0);
+ OUTPUT_ROW_4x4_5x5(out7, comm_fact0);
#endif // !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
// Store values across the channels