path: root/src/core/CL/cl_kernels/gemmlowp.cl
author    Michalis Spyrou <michalis.spyrou@arm.com>    2018-01-15 14:39:13 +0000
committer Anthony Barbier <anthony.barbier@arm.com>    2018-11-02 16:52:54 +0000
commit    e03342e3ba78ecf5b9128339dd47c30e00cb8565 (patch)
tree      49a5456b056086385585149704725f4fdf516f32 /src/core/CL/cl_kernels/gemmlowp.cl
parent    513fe2e80512091c22af3204053dbd53f8ccf12b (diff)
download  ComputeLibrary-e03342e3ba78ecf5b9128339dd47c30e00cb8565.tar.gz
COMPMID-799 - Use new OpenCL 8-bit dot product instruction
Change-Id: I03d6c6db13bcb565f117725bdab2b68c89a49e21
Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/122185
Reviewed-by: Anthony Barbier <anthony.barbier@arm.com>
Tested-by: Jenkins <bsgcomp@arm.com>
Reviewed-by: Gian Marco Iodice <gianmarco.iodice@arm.com>
Diffstat (limited to 'src/core/CL/cl_kernels/gemmlowp.cl')
-rw-r--r--  src/core/CL/cl_kernels/gemmlowp.cl  98
1 file changed, 98 insertions(+), 0 deletions(-)
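
The patch guards a new fast path on ARM_COMPUTE_OPENCL_DOT8_ENABLED: when the macro is defined, the per-byte ushort multiply-accumulate chains are replaced with calls to arm_dot, the 8-bit dot product built-in from Arm's OpenCL extension (assumed here to be the one reported as cl_arm_integer_dot_product_int8). As a functional sketch only, not the actual instruction, each call behaves like the following OpenCL C reference, accumulating four unsigned 8-bit products into a 32-bit value:

// Sketch: reference semantics assumed for arm_dot(uchar4, uchar4) -> uint.
// On supporting devices the built-in maps to a single 8-bit dot product
// instruction instead of four multiplies and three adds.
inline uint arm_dot_ref(uchar4 a, uchar4 b)
{
    return (uint)a.s0 * (uint)b.s0
         + (uint)a.s1 * (uint)b.s1
         + (uint)a.s2 * (uint)b.s2
         + (uint)a.s3 * (uint)b.s3;
}

// Example: the c00 element of the 4x4 tile in the interleaved/transposed
// kernel gathers row 0 of A (every 4th byte of a0) and column 0 of B:
//   c00 += arm_dot_ref((uchar4)(a0.s0, a0.s4, a0.s8, a0.sC),
//                      (uchar4)(b0.s0, b1.s0, b2.s0, b3.s0));

In the fallback path kept below, the same work is done with explicit ushort partial products that are then widened and summed into the uint accumulators.
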
diff --git a/src/core/CL/cl_kernels/gemmlowp.cl b/src/core/CL/cl_kernels/gemmlowp.cl
index 5e144d73af..da915778e7 100644
--- a/src/core/CL/cl_kernels/gemmlowp.cl
+++ b/src/core/CL/cl_kernels/gemmlowp.cl
@@ -190,6 +190,63 @@ __kernel void gemmlowp_mm_interleaved_transposed_bifrost(IMAGE_DECLARATION(src0)
#if MULT_INTERLEAVE4X4_HEIGHT == 1
for(; src_addr_b <= (src_end_addr_b - (int)(32 * TRANSPOSE1XW_WIDTH_STEP)); src_addr_a += (32 * MULT_INTERLEAVE4X4_HEIGHT), src_addr_b += (32 * TRANSPOSE1XW_WIDTH_STEP))
{
+#if ARM_COMPUTE_OPENCL_DOT8_ENABLED
+ // Load values from matrix A (interleaved) and matrix B (transposed)
+ uchar16 a0 = vload16(0, src_addr_a);
+ uchar4 b0 = vload4(0, src_addr_b);
+ uchar4 b1 = vload4(0, src_addr_b + 4 * TRANSPOSE1XW_WIDTH_STEP);
+ uchar4 b2 = vload4(0, src_addr_b + 8 * TRANSPOSE1XW_WIDTH_STEP);
+ uchar4 b3 = vload4(0, src_addr_b + 12 * TRANSPOSE1XW_WIDTH_STEP);
+
+ // Accumulate
+ c00 += arm_dot((uchar4)(a0.s0, a0.s4, a0.s8, a0.sC), (uchar4)(b0.s0, b1.s0, b2.s0, b3.s0));
+ c01 += arm_dot((uchar4)(a0.s0, a0.s4, a0.s8, a0.sC), (uchar4)(b0.s1, b1.s1, b2.s1, b3.s1));
+ c02 += arm_dot((uchar4)(a0.s0, a0.s4, a0.s8, a0.sC), (uchar4)(b0.s2, b1.s2, b2.s2, b3.s2));
+ c03 += arm_dot((uchar4)(a0.s0, a0.s4, a0.s8, a0.sC), (uchar4)(b0.s3, b1.s3, b2.s3, b3.s3));
+
+ c10 += arm_dot((uchar4)(a0.s1, a0.s5, a0.s9, a0.sD), (uchar4)(b0.s0, b1.s0, b2.s0, b3.s0));
+ c11 += arm_dot((uchar4)(a0.s1, a0.s5, a0.s9, a0.sD), (uchar4)(b0.s1, b1.s1, b2.s1, b3.s1));
+ c12 += arm_dot((uchar4)(a0.s1, a0.s5, a0.s9, a0.sD), (uchar4)(b0.s2, b1.s2, b2.s2, b3.s2));
+ c13 += arm_dot((uchar4)(a0.s1, a0.s5, a0.s9, a0.sD), (uchar4)(b0.s3, b1.s3, b2.s3, b3.s3));
+
+ c20 += arm_dot((uchar4)(a0.s2, a0.s6, a0.sA, a0.sE), (uchar4)(b0.s0, b1.s0, b2.s0, b3.s0));
+ c21 += arm_dot((uchar4)(a0.s2, a0.s6, a0.sA, a0.sE), (uchar4)(b0.s1, b1.s1, b2.s1, b3.s1));
+ c22 += arm_dot((uchar4)(a0.s2, a0.s6, a0.sA, a0.sE), (uchar4)(b0.s2, b1.s2, b2.s2, b3.s2));
+ c23 += arm_dot((uchar4)(a0.s2, a0.s6, a0.sA, a0.sE), (uchar4)(b0.s3, b1.s3, b2.s3, b3.s3));
+
+ c30 += arm_dot((uchar4)(a0.s3, a0.s7, a0.sB, a0.sF), (uchar4)(b0.s0, b1.s0, b2.s0, b3.s0));
+ c31 += arm_dot((uchar4)(a0.s3, a0.s7, a0.sB, a0.sF), (uchar4)(b0.s1, b1.s1, b2.s1, b3.s1));
+ c32 += arm_dot((uchar4)(a0.s3, a0.s7, a0.sB, a0.sF), (uchar4)(b0.s2, b1.s2, b2.s2, b3.s2));
+ c33 += arm_dot((uchar4)(a0.s3, a0.s7, a0.sB, a0.sF), (uchar4)(b0.s3, b1.s3, b2.s3, b3.s3));
+
+ // Load values from matrix A (interleaved) and matrix B (transposed)
+ a0 = vload16(0, src_addr_a + 16);
+ b0 = vload4(0, src_addr_b + 16 * TRANSPOSE1XW_WIDTH_STEP);
+ b1 = vload4(0, src_addr_b + 20 * TRANSPOSE1XW_WIDTH_STEP);
+ b2 = vload4(0, src_addr_b + 24 * TRANSPOSE1XW_WIDTH_STEP);
+ b3 = vload4(0, src_addr_b + 28 * TRANSPOSE1XW_WIDTH_STEP);
+
+ // Accumulate
+ c00 += arm_dot((uchar4)(a0.s0, a0.s4, a0.s8, a0.sC), (uchar4)(b0.s0, b1.s0, b2.s0, b3.s0));
+ c01 += arm_dot((uchar4)(a0.s0, a0.s4, a0.s8, a0.sC), (uchar4)(b0.s1, b1.s1, b2.s1, b3.s1));
+ c02 += arm_dot((uchar4)(a0.s0, a0.s4, a0.s8, a0.sC), (uchar4)(b0.s2, b1.s2, b2.s2, b3.s2));
+ c03 += arm_dot((uchar4)(a0.s0, a0.s4, a0.s8, a0.sC), (uchar4)(b0.s3, b1.s3, b2.s3, b3.s3));
+
+ c10 += arm_dot((uchar4)(a0.s1, a0.s5, a0.s9, a0.sD), (uchar4)(b0.s0, b1.s0, b2.s0, b3.s0));
+ c11 += arm_dot((uchar4)(a0.s1, a0.s5, a0.s9, a0.sD), (uchar4)(b0.s1, b1.s1, b2.s1, b3.s1));
+ c12 += arm_dot((uchar4)(a0.s1, a0.s5, a0.s9, a0.sD), (uchar4)(b0.s2, b1.s2, b2.s2, b3.s2));
+ c13 += arm_dot((uchar4)(a0.s1, a0.s5, a0.s9, a0.sD), (uchar4)(b0.s3, b1.s3, b2.s3, b3.s3));
+
+ c20 += arm_dot((uchar4)(a0.s2, a0.s6, a0.sA, a0.sE), (uchar4)(b0.s0, b1.s0, b2.s0, b3.s0));
+ c21 += arm_dot((uchar4)(a0.s2, a0.s6, a0.sA, a0.sE), (uchar4)(b0.s1, b1.s1, b2.s1, b3.s1));
+ c22 += arm_dot((uchar4)(a0.s2, a0.s6, a0.sA, a0.sE), (uchar4)(b0.s2, b1.s2, b2.s2, b3.s2));
+ c23 += arm_dot((uchar4)(a0.s2, a0.s6, a0.sA, a0.sE), (uchar4)(b0.s3, b1.s3, b2.s3, b3.s3));
+
+ c30 += arm_dot((uchar4)(a0.s3, a0.s7, a0.sB, a0.sF), (uchar4)(b0.s0, b1.s0, b2.s0, b3.s0));
+ c31 += arm_dot((uchar4)(a0.s3, a0.s7, a0.sB, a0.sF), (uchar4)(b0.s1, b1.s1, b2.s1, b3.s1));
+ c32 += arm_dot((uchar4)(a0.s3, a0.s7, a0.sB, a0.sF), (uchar4)(b0.s2, b1.s2, b2.s2, b3.s2));
+ c33 += arm_dot((uchar4)(a0.s3, a0.s7, a0.sB, a0.sF), (uchar4)(b0.s3, b1.s3, b2.s3, b3.s3));
+#else // ARM_COMPUTE_OPENCL_DOT8_ENABLED
// Load values from matrix A (interleaved) and matrix B (transposed)
uchar16 a0 = vload16(0, src_addr_a);
uchar4 b0 = vload4(0, src_addr_b);
@@ -375,6 +432,7 @@ __kernel void gemmlowp_mm_interleaved_transposed_bifrost(IMAGE_DECLARATION(src0)
c31 += (ushort)a0.sF * b0.s1;
c32 += (ushort)a0.sF * b0.s2;
c33 += (ushort)a0.sF * b0.s3;
+#endif // ARM_COMPUTE_OPENCL_DOT8_ENABLED
}
#endif // MULT_INTERLEAVE4X4_HEIGHT == 1
@@ -666,6 +724,13 @@ __kernel void gemmlowp_mm_bifrost(IMAGE_DECLARATION(src0),
uchar4 b3 = vload4(0, src1_ptr + src_addr.s1 + 3 * src1_stride_y);
{
+#if ARM_COMPUTE_OPENCL_DOT8_ENABLED
+ // Accumulate
+ acc00 += arm_dot((uchar4)(b0.s0, b1.s0, b2.s0, b3.s0), a0);
+ acc01 += arm_dot((uchar4)(b0.s1, b1.s1, b2.s1, b3.s1), a0);
+ acc02 += arm_dot((uchar4)(b0.s2, b1.s2, b2.s2, b3.s2), a0);
+ acc03 += arm_dot((uchar4)(b0.s3, b1.s3, b2.s3, b3.s3), a0);
+#else // ARM_COMPUTE_OPENCL_DOT8_ENABLED
// Accumulate
ushort tmp0 = (ushort)b0.s0 * (ushort)a0.s0;
ushort tmp1 = (ushort)b0.s1 * (ushort)a0.s0;
@@ -691,9 +756,17 @@ __kernel void gemmlowp_mm_bifrost(IMAGE_DECLARATION(src0),
acc01 += ((uint)tmp1 + (uint)tmp5 + (uint)tmp9 + (uint)tmpD);
acc02 += ((uint)tmp2 + (uint)tmp6 + (uint)tmpA + (uint)tmpE);
acc03 += ((uint)tmp3 + (uint)tmp7 + (uint)tmpB + (uint)tmpF);
+#endif // ARM_COMPUTE_OPENCL_DOT8_ENABLED
}
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
{
+#if ARM_COMPUTE_OPENCL_DOT8_ENABLED
+ // Accumulate
+ acc10 += arm_dot((uchar4)(b0.s0, b1.s0, b2.s0, b3.s0), a1);
+ acc11 += arm_dot((uchar4)(b0.s1, b1.s1, b2.s1, b3.s1), a1);
+ acc12 += arm_dot((uchar4)(b0.s2, b1.s2, b2.s2, b3.s2), a1);
+ acc13 += arm_dot((uchar4)(b0.s3, b1.s3, b2.s3, b3.s3), a1);
+#else // ARM_COMPUTE_OPENCL_DOT8_ENABLED
// Accumulate
ushort tmp0 = (ushort)b0.s0 * (ushort)a1.s0;
ushort tmp1 = (ushort)b0.s1 * (ushort)a1.s0;
@@ -719,10 +792,18 @@ __kernel void gemmlowp_mm_bifrost(IMAGE_DECLARATION(src0),
acc11 += ((uint)tmp1 + (uint)tmp5 + (uint)tmp9 + (uint)tmpD);
acc12 += ((uint)tmp2 + (uint)tmp6 + (uint)tmpA + (uint)tmpE);
acc13 += ((uint)tmp3 + (uint)tmp7 + (uint)tmpB + (uint)tmpF);
+#endif // ARM_COMPUTE_OPENCL_DOT8_ENABLED
}
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
{
+#if ARM_COMPUTE_OPENCL_DOT8_ENABLED
+ // Accumulate
+ acc20 += arm_dot((uchar4)(b0.s0, b1.s0, b2.s0, b3.s0), a2);
+ acc21 += arm_dot((uchar4)(b0.s1, b1.s1, b2.s1, b3.s1), a2);
+ acc22 += arm_dot((uchar4)(b0.s2, b1.s2, b2.s2, b3.s2), a2);
+ acc23 += arm_dot((uchar4)(b0.s3, b1.s3, b2.s3, b3.s3), a2);
+#else // ARM_COMPUTE_OPENCL_DOT8_ENABLED
// Accumulate
ushort tmp0 = (ushort)b0.s0 * (ushort)a2.s0;
ushort tmp1 = (ushort)b0.s1 * (ushort)a2.s0;
@@ -748,10 +829,18 @@ __kernel void gemmlowp_mm_bifrost(IMAGE_DECLARATION(src0),
acc21 += ((uint)tmp1 + (uint)tmp5 + (uint)tmp9 + (uint)tmpD);
acc22 += ((uint)tmp2 + (uint)tmp6 + (uint)tmpA + (uint)tmpE);
acc23 += ((uint)tmp3 + (uint)tmp7 + (uint)tmpB + (uint)tmpF);
+#endif // ARM_COMPUTE_OPENCL_DOT8_ENABLED
}
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
{
+#if ARM_COMPUTE_OPENCL_DOT8_ENABLED
+ // Accumulate
+ acc30 += arm_dot((uchar4)(b0.s0, b1.s0, b2.s0, b3.s0), a3);
+ acc31 += arm_dot((uchar4)(b0.s1, b1.s1, b2.s1, b3.s1), a3);
+ acc32 += arm_dot((uchar4)(b0.s2, b1.s2, b2.s2, b3.s2), a3);
+ acc33 += arm_dot((uchar4)(b0.s3, b1.s3, b2.s3, b3.s3), a3);
+#else // ARM_COMPUTE_OPENCL_DOT8_ENABLED
// Accumulate
ushort tmp0 = (ushort)b0.s0 * (ushort)a3.s0;
ushort tmp1 = (ushort)b0.s1 * (ushort)a3.s0;
@@ -777,10 +866,18 @@ __kernel void gemmlowp_mm_bifrost(IMAGE_DECLARATION(src0),
acc31 += ((uint)tmp1 + (uint)tmp5 + (uint)tmp9 + (uint)tmpD);
acc32 += ((uint)tmp2 + (uint)tmp6 + (uint)tmpA + (uint)tmpE);
acc33 += ((uint)tmp3 + (uint)tmp7 + (uint)tmpB + (uint)tmpF);
+#endif // ARM_COMPUTE_OPENCL_DOT8_ENABLED
}
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 4
{
+#if ARM_COMPUTE_OPENCL_DOT8_ENABLED
+ // Accumulate
+ acc40 += arm_dot((uchar4)(b0.s0, b1.s0, b2.s0, b3.s0), a4);
+ acc41 += arm_dot((uchar4)(b0.s1, b1.s1, b2.s1, b3.s1), a4);
+ acc42 += arm_dot((uchar4)(b0.s2, b1.s2, b2.s2, b3.s2), a4);
+ acc43 += arm_dot((uchar4)(b0.s3, b1.s3, b2.s3, b3.s3), a4);
+#else // ARM_COMPUTE_OPENCL_DOT8_ENABLED
// Accumulate
ushort tmp0 = (ushort)b0.s0 * (ushort)a4.s0;
ushort tmp1 = (ushort)b0.s1 * (ushort)a4.s0;
@@ -806,6 +903,7 @@ __kernel void gemmlowp_mm_bifrost(IMAGE_DECLARATION(src0),
acc41 += ((uint)tmp1 + (uint)tmp5 + (uint)tmp9 + (uint)tmpD);
acc42 += ((uint)tmp2 + (uint)tmp6 + (uint)tmpA + (uint)tmpE);
acc43 += ((uint)tmp3 + (uint)tmp7 + (uint)tmpB + (uint)tmpF);
+#endif // ARM_COMPUTE_OPENCL_DOT8_ENABLED
}
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 4
}
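
All of the new branches are inactive unless the kernel is compiled with ARM_COMPUTE_OPENCL_DOT8_ENABLED defined; how that define is set is outside this diff. As an illustration only, a host-side C sketch of the usual pattern is to check the device extension string and pass the define through the program build options (the extension name and helper names here are assumptions, not part of the patch):

/* Sketch: enable the dot8 path only when the device advertises the Arm
 * 8-bit dot product extension. The library's own build-option handling
 * is not shown in this diff. */
#include <CL/cl.h>
#include <stdlib.h>
#include <string.h>

static int device_has_dot8(cl_device_id dev)
{
    size_t sz = 0;
    if (clGetDeviceInfo(dev, CL_DEVICE_EXTENSIONS, 0, NULL, &sz) != CL_SUCCESS || sz == 0)
    {
        return 0;
    }
    char *ext = (char *)malloc(sz);
    if (ext == NULL)
    {
        return 0;
    }
    clGetDeviceInfo(dev, CL_DEVICE_EXTENSIONS, sz, ext, NULL);
    int found = strstr(ext, "cl_arm_integer_dot_product_int8") != NULL;
    free(ext);
    return found;
}

static cl_int build_gemmlowp(cl_program prog, cl_device_id dev)
{
    const char *opts = device_has_dot8(dev)
                           ? "-DARM_COMPUTE_OPENCL_DOT8_ENABLED=1"
                           : "";
    return clBuildProgram(prog, 1, &dev, opts, NULL, NULL);
}

With such a guard in place, devices without the extension keep compiling the existing scalar fallback unchanged.
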