author     Giorgio Arena <giorgio.arena@arm.com>    2018-07-06 17:06:36 +0100
committer  Anthony Barbier <anthony.barbier@arm.com>    2018-11-02 16:54:54 +0000
commit     6200fa405b16b4145b926a96de197718ad31bf93 (patch)
tree       4ecaa3a29d79371c6439acf5bb580bc7ba99af09 /src/core/CL/cl_kernels/gemmlowp.cl
parent     ea55f91e5dd4e5bc766fabbac6df6ce3ab984d0e (diff)
download   ComputeLibrary-6200fa405b16b4145b926a96de197718ad31bf93.tar.gz
COMPMID-1288 Optimizing CLGEMMLowp using 8 bit dot product instruction
Change-Id: I536174b9381660a94578d6aa1892a6289a820391
Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/139109
Reviewed-by: Georgios Pinitas <georgios.pinitas@arm.com>
Tested-by: Jenkins <bsgcomp@arm.com>
Diffstat (limited to 'src/core/CL/cl_kernels/gemmlowp.cl')
-rw-r--r--   src/core/CL/cl_kernels/gemmlowp.cl   513
1 file changed, 415 insertions, 98 deletions
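
The change moves the dot-product path out of the generic kernels and into dedicated *_dot8 kernels guarded by ARM_COMPUTE_OPENCL_DOT8_ENABLED, replacing the accumulate-after-dot pattern (c += arm_dot(a, b)) with the fused arm_dot_acc(a, b, c) form. As a reading aid only, the scalar fallback below sketches the semantics the new kernels rely on; arm_dot_acc_reference is a hypothetical name for illustration, not part of the library.

uint arm_dot_acc_reference(uchar4 a, uchar4 b, uint acc)
{
    // acc + dot(a, b) on four unsigned 8-bit lanes, widened to 32 bits
    return acc + (uint)a.s0 * b.s0
               + (uint)a.s1 * b.s1
               + (uint)a.s2 * b.s2
               + (uint)a.s3 * b.s3;
}

On devices exposing the 8-bit dot product instruction this whole expression is expected to map to a single operation, which is what the optimization targets.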
diff --git a/src/core/CL/cl_kernels/gemmlowp.cl b/src/core/CL/cl_kernels/gemmlowp.cl
index da915778e7..0ee7c27350 100644
--- a/src/core/CL/cl_kernels/gemmlowp.cl
+++ b/src/core/CL/cl_kernels/gemmlowp.cl
@@ -190,63 +190,6 @@ __kernel void gemmlowp_mm_interleaved_transposed_bifrost(IMAGE_DECLARATION(src0)
#if MULT_INTERLEAVE4X4_HEIGHT == 1
for(; src_addr_b <= (src_end_addr_b - (int)(32 * TRANSPOSE1XW_WIDTH_STEP)); src_addr_a += (32 * MULT_INTERLEAVE4X4_HEIGHT), src_addr_b += (32 * TRANSPOSE1XW_WIDTH_STEP))
{
-#if ARM_COMPUTE_OPENCL_DOT8_ENABLED
- // Load values from matrix A (interleaved) and matrix B (transposed)
- uchar16 a0 = vload16(0, src_addr_a);
- uchar4 b0 = vload4(0, src_addr_b);
- uchar4 b1 = vload4(0, src_addr_b + 4 * TRANSPOSE1XW_WIDTH_STEP);
- uchar4 b2 = vload4(0, src_addr_b + 8 * TRANSPOSE1XW_WIDTH_STEP);
- uchar4 b3 = vload4(0, src_addr_b + 12 * TRANSPOSE1XW_WIDTH_STEP);
-
- // Accumulate
- c00 += arm_dot((uchar4)(a0.s0, a0.s4, a0.s8, a0.sC), (uchar4)(b0.s0, b1.s0, b2.s0, b3.s0));
- c01 += arm_dot((uchar4)(a0.s0, a0.s4, a0.s8, a0.sC), (uchar4)(b0.s1, b1.s1, b2.s1, b3.s1));
- c02 += arm_dot((uchar4)(a0.s0, a0.s4, a0.s8, a0.sC), (uchar4)(b0.s2, b1.s2, b2.s2, b3.s2));
- c03 += arm_dot((uchar4)(a0.s0, a0.s4, a0.s8, a0.sC), (uchar4)(b0.s3, b1.s3, b2.s3, b3.s3));
-
- c10 += arm_dot((uchar4)(a0.s1, a0.s5, a0.s9, a0.sD), (uchar4)(b0.s0, b1.s0, b2.s0, b3.s0));
- c11 += arm_dot((uchar4)(a0.s1, a0.s5, a0.s9, a0.sD), (uchar4)(b0.s1, b1.s1, b2.s1, b3.s1));
- c12 += arm_dot((uchar4)(a0.s1, a0.s5, a0.s9, a0.sD), (uchar4)(b0.s2, b1.s2, b2.s2, b3.s2));
- c13 += arm_dot((uchar4)(a0.s1, a0.s5, a0.s9, a0.sD), (uchar4)(b0.s3, b1.s3, b2.s3, b3.s3));
-
- c20 += arm_dot((uchar4)(a0.s2, a0.s6, a0.sA, a0.sE), (uchar4)(b0.s0, b1.s0, b2.s0, b3.s0));
- c21 += arm_dot((uchar4)(a0.s2, a0.s6, a0.sA, a0.sE), (uchar4)(b0.s1, b1.s1, b2.s1, b3.s1));
- c22 += arm_dot((uchar4)(a0.s2, a0.s6, a0.sA, a0.sE), (uchar4)(b0.s2, b1.s2, b2.s2, b3.s2));
- c23 += arm_dot((uchar4)(a0.s2, a0.s6, a0.sA, a0.sE), (uchar4)(b0.s3, b1.s3, b2.s3, b3.s3));
-
- c30 += arm_dot((uchar4)(a0.s3, a0.s7, a0.sB, a0.sF), (uchar4)(b0.s0, b1.s0, b2.s0, b3.s0));
- c31 += arm_dot((uchar4)(a0.s3, a0.s7, a0.sB, a0.sF), (uchar4)(b0.s1, b1.s1, b2.s1, b3.s1));
- c32 += arm_dot((uchar4)(a0.s3, a0.s7, a0.sB, a0.sF), (uchar4)(b0.s2, b1.s2, b2.s2, b3.s2));
- c33 += arm_dot((uchar4)(a0.s3, a0.s7, a0.sB, a0.sF), (uchar4)(b0.s3, b1.s3, b2.s3, b3.s3));
-
- // Load values from matrix A (interleaved) and matrix B (transposed)
- a0 = vload16(0, src_addr_a + 16);
- b0 = vload4(0, src_addr_b + 16 * TRANSPOSE1XW_WIDTH_STEP);
- b1 = vload4(0, src_addr_b + 20 * TRANSPOSE1XW_WIDTH_STEP);
- b2 = vload4(0, src_addr_b + 24 * TRANSPOSE1XW_WIDTH_STEP);
- b3 = vload4(0, src_addr_b + 28 * TRANSPOSE1XW_WIDTH_STEP);
-
- // Accumulate
- c00 += arm_dot((uchar4)(a0.s0, a0.s4, a0.s8, a0.sC), (uchar4)(b0.s0, b1.s0, b2.s0, b3.s0));
- c01 += arm_dot((uchar4)(a0.s0, a0.s4, a0.s8, a0.sC), (uchar4)(b0.s1, b1.s1, b2.s1, b3.s1));
- c02 += arm_dot((uchar4)(a0.s0, a0.s4, a0.s8, a0.sC), (uchar4)(b0.s2, b1.s2, b2.s2, b3.s2));
- c03 += arm_dot((uchar4)(a0.s0, a0.s4, a0.s8, a0.sC), (uchar4)(b0.s3, b1.s3, b2.s3, b3.s3));
-
- c10 += arm_dot((uchar4)(a0.s1, a0.s5, a0.s9, a0.sD), (uchar4)(b0.s0, b1.s0, b2.s0, b3.s0));
- c11 += arm_dot((uchar4)(a0.s1, a0.s5, a0.s9, a0.sD), (uchar4)(b0.s1, b1.s1, b2.s1, b3.s1));
- c12 += arm_dot((uchar4)(a0.s1, a0.s5, a0.s9, a0.sD), (uchar4)(b0.s2, b1.s2, b2.s2, b3.s2));
- c13 += arm_dot((uchar4)(a0.s1, a0.s5, a0.s9, a0.sD), (uchar4)(b0.s3, b1.s3, b2.s3, b3.s3));
-
- c20 += arm_dot((uchar4)(a0.s2, a0.s6, a0.sA, a0.sE), (uchar4)(b0.s0, b1.s0, b2.s0, b3.s0));
- c21 += arm_dot((uchar4)(a0.s2, a0.s6, a0.sA, a0.sE), (uchar4)(b0.s1, b1.s1, b2.s1, b3.s1));
- c22 += arm_dot((uchar4)(a0.s2, a0.s6, a0.sA, a0.sE), (uchar4)(b0.s2, b1.s2, b2.s2, b3.s2));
- c23 += arm_dot((uchar4)(a0.s2, a0.s6, a0.sA, a0.sE), (uchar4)(b0.s3, b1.s3, b2.s3, b3.s3));
-
- c30 += arm_dot((uchar4)(a0.s3, a0.s7, a0.sB, a0.sF), (uchar4)(b0.s0, b1.s0, b2.s0, b3.s0));
- c31 += arm_dot((uchar4)(a0.s3, a0.s7, a0.sB, a0.sF), (uchar4)(b0.s1, b1.s1, b2.s1, b3.s1));
- c32 += arm_dot((uchar4)(a0.s3, a0.s7, a0.sB, a0.sF), (uchar4)(b0.s2, b1.s2, b2.s2, b3.s2));
- c33 += arm_dot((uchar4)(a0.s3, a0.s7, a0.sB, a0.sF), (uchar4)(b0.s3, b1.s3, b2.s3, b3.s3));
-#else // ARM_COMPUTE_OPENCL_DOT8_ENABLED
// Load values from matrix A (interleaved) and matrix B (transposed)
uchar16 a0 = vload16(0, src_addr_a);
uchar4 b0 = vload4(0, src_addr_b);
@@ -432,7 +375,6 @@ __kernel void gemmlowp_mm_interleaved_transposed_bifrost(IMAGE_DECLARATION(src0)
c31 += (ushort)a0.sF * b0.s1;
c32 += (ushort)a0.sF * b0.s2;
c33 += (ushort)a0.sF * b0.s3;
-#endif // ARM_COMPUTE_OPENCL_DOT8_ENABLED
}
#endif // MULT_INTERLEAVE4X4_HEIGHT == 1
@@ -472,6 +414,173 @@ __kernel void gemmlowp_mm_interleaved_transposed_bifrost(IMAGE_DECLARATION(src0)
vstore4((int4)(c20, c21, c22, c23), 0, (__global int *)(offset(&dst, 0, 2)));
vstore4((int4)(c30, c31, c32, c33), 0, (__global int *)(offset(&dst, 0, 3)));
}
+
+#if ARM_COMPUTE_OPENCL_DOT8_ENABLED
+/** This OpenCL kernel is optimized for Bifrost and computes the matrix multiplication between matrix A (src0) and matrix B (src1) using the 8-bit dot product instruction.
+ * Matrix A and matrix B must be reshaped with @ref CLGEMMInterleave4x4Kernel and @ref CLGEMMTranspose1xWKernel, respectively, before running the matrix multiplication.
+ *
+ * @attention The number of matrix B columns needs to be passed at compile time using -DCOLS_B
+ * @note The transposition width step (mult_transpose1xW_width * 4) must be passed at compile time using -DTRANSPOSE1XW_WIDTH_STEP (e.g. -DTRANSPOSE1XW_WIDTH_STEP=2)
+ * @note The multiplication factor for the height of the 4x4 interleaved block must be passed at compile time using -DMULT_INTERLEAVE4X4_HEIGHT (e.g. -DMULT_INTERLEAVE4X4_HEIGHT=2)
+ *
+ * @param[in] src0_ptr Pointer to the source matrix. Supported data type: QASYMM8
+ * @param[in] src0_stride_x Stride of the source matrix in X dimension (in bytes)
+ * @param[in] src0_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
+ * @param[in] src0_stride_y Stride of the source matrix in Y dimension (in bytes)
+ * @param[in] src0_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
+ * @param[in] src0_offset_first_element_in_bytes The offset of the first element in the source matrix
+ * @param[in] src1_ptr Pointer to the source matrix. Supported data type: same as @p src0_ptr
+ * @param[in] src1_stride_x Stride of the source matrix in X dimension (in bytes)
+ * @param[in] src1_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
+ * @param[in] src1_stride_y Stride of the source matrix in Y dimension (in bytes)
+ * @param[in] src1_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
+ * @param[in] src1_offset_first_element_in_bytes The offset of the first element in the source matrix
+ * @param[out] dst_ptr                           Pointer to the destination matrix. Supported data type: S32
+ * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
+ * @param[in] dst_step_x dst_gx_stride_x * number of elements along X processed per workitem(in bytes)
+ * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
+ * @param[in] dst_step_y dst_gx_stride_y * number of elements along Y processed per workitem(in bytes)
+ * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
+ */
+__kernel void gemmlowp_mm_interleaved_transposed_bifrost_dot8(IMAGE_DECLARATION(src0),
+ IMAGE_DECLARATION(src1),
+ IMAGE_DECLARATION(dst))
+{
+ int x = get_global_id(0) / TRANSPOSE1XW_WIDTH_STEP;
+ int y = get_global_id(1) / MULT_INTERLEAVE4X4_HEIGHT;
+
+ // Offset
+ const int offset_row_a = (get_global_id(1) % MULT_INTERLEAVE4X4_HEIGHT) * 4;
+ const int offset_row_b = (get_global_id(0) % TRANSPOSE1XW_WIDTH_STEP) * 4;
+
+ // src_addr_a = address of matrix A
+ // src_addr_b = address of matrix B
+ __global uchar *src_addr_a = (__global uchar *)(src0_ptr + y * src0_stride_y + src0_offset_first_element_in_bytes);
+ __global uchar *src_addr_b = (__global uchar *)(src1_ptr + x * src1_stride_y + src1_offset_first_element_in_bytes);
+
+ // Compute end row address for matrix B
+ __global uchar *src_end_addr_b = src_addr_b + COLS_B;
+
+ src_addr_a += offset_row_a;
+ src_addr_b += offset_row_b;
+
+ // Reset accumulators
+ uint c00 = 0;
+ uint c01 = 0;
+ uint c02 = 0;
+ uint c03 = 0;
+ uint c10 = 0;
+ uint c11 = 0;
+ uint c12 = 0;
+ uint c13 = 0;
+ uint c20 = 0;
+ uint c21 = 0;
+ uint c22 = 0;
+ uint c23 = 0;
+ uint c30 = 0;
+ uint c31 = 0;
+ uint c32 = 0;
+ uint c33 = 0;
+
+#if MULT_INTERLEAVE4X4_HEIGHT == 1
+ for(; src_addr_b <= (src_end_addr_b - (int)(32 * TRANSPOSE1XW_WIDTH_STEP)); src_addr_a += (32 * MULT_INTERLEAVE4X4_HEIGHT), src_addr_b += (32 * TRANSPOSE1XW_WIDTH_STEP))
+ {
+ // Load values from matrix A (interleaved) and matrix B (transposed)
+ uchar16 a0 = vload16(0, src_addr_a);
+ uchar4 b0 = vload4(0, src_addr_b);
+ uchar4 b1 = vload4(0, src_addr_b + 4 * TRANSPOSE1XW_WIDTH_STEP);
+ uchar4 b2 = vload4(0, src_addr_b + 8 * TRANSPOSE1XW_WIDTH_STEP);
+ uchar4 b3 = vload4(0, src_addr_b + 12 * TRANSPOSE1XW_WIDTH_STEP);
+
+ // Accumulate
+ c00 = arm_dot_acc((uchar4)(a0.s0, a0.s4, a0.s8, a0.sC), (uchar4)(b0.s0, b1.s0, b2.s0, b3.s0), c00);
+ c01 = arm_dot_acc((uchar4)(a0.s0, a0.s4, a0.s8, a0.sC), (uchar4)(b0.s1, b1.s1, b2.s1, b3.s1), c01);
+ c02 = arm_dot_acc((uchar4)(a0.s0, a0.s4, a0.s8, a0.sC), (uchar4)(b0.s2, b1.s2, b2.s2, b3.s2), c02);
+ c03 = arm_dot_acc((uchar4)(a0.s0, a0.s4, a0.s8, a0.sC), (uchar4)(b0.s3, b1.s3, b2.s3, b3.s3), c03);
+
+ c10 = arm_dot_acc((uchar4)(a0.s1, a0.s5, a0.s9, a0.sD), (uchar4)(b0.s0, b1.s0, b2.s0, b3.s0), c10);
+ c11 = arm_dot_acc((uchar4)(a0.s1, a0.s5, a0.s9, a0.sD), (uchar4)(b0.s1, b1.s1, b2.s1, b3.s1), c11);
+ c12 = arm_dot_acc((uchar4)(a0.s1, a0.s5, a0.s9, a0.sD), (uchar4)(b0.s2, b1.s2, b2.s2, b3.s2), c12);
+ c13 = arm_dot_acc((uchar4)(a0.s1, a0.s5, a0.s9, a0.sD), (uchar4)(b0.s3, b1.s3, b2.s3, b3.s3), c13);
+
+ c20 = arm_dot_acc((uchar4)(a0.s2, a0.s6, a0.sA, a0.sE), (uchar4)(b0.s0, b1.s0, b2.s0, b3.s0), c20);
+ c21 = arm_dot_acc((uchar4)(a0.s2, a0.s6, a0.sA, a0.sE), (uchar4)(b0.s1, b1.s1, b2.s1, b3.s1), c21);
+ c22 = arm_dot_acc((uchar4)(a0.s2, a0.s6, a0.sA, a0.sE), (uchar4)(b0.s2, b1.s2, b2.s2, b3.s2), c22);
+ c23 = arm_dot_acc((uchar4)(a0.s2, a0.s6, a0.sA, a0.sE), (uchar4)(b0.s3, b1.s3, b2.s3, b3.s3), c23);
+
+ c30 = arm_dot_acc((uchar4)(a0.s3, a0.s7, a0.sB, a0.sF), (uchar4)(b0.s0, b1.s0, b2.s0, b3.s0), c30);
+ c31 = arm_dot_acc((uchar4)(a0.s3, a0.s7, a0.sB, a0.sF), (uchar4)(b0.s1, b1.s1, b2.s1, b3.s1), c31);
+ c32 = arm_dot_acc((uchar4)(a0.s3, a0.s7, a0.sB, a0.sF), (uchar4)(b0.s2, b1.s2, b2.s2, b3.s2), c32);
+ c33 = arm_dot_acc((uchar4)(a0.s3, a0.s7, a0.sB, a0.sF), (uchar4)(b0.s3, b1.s3, b2.s3, b3.s3), c33);
+
+ // Load values from matrix A (interleaved) and matrix B (transposed)
+ a0 = vload16(0, src_addr_a + 16);
+ b0 = vload4(0, src_addr_b + 16 * TRANSPOSE1XW_WIDTH_STEP);
+ b1 = vload4(0, src_addr_b + 20 * TRANSPOSE1XW_WIDTH_STEP);
+ b2 = vload4(0, src_addr_b + 24 * TRANSPOSE1XW_WIDTH_STEP);
+ b3 = vload4(0, src_addr_b + 28 * TRANSPOSE1XW_WIDTH_STEP);
+
+ // Accumulate
+ c00 = arm_dot_acc((uchar4)(a0.s0, a0.s4, a0.s8, a0.sC), (uchar4)(b0.s0, b1.s0, b2.s0, b3.s0), c00);
+ c01 = arm_dot_acc((uchar4)(a0.s0, a0.s4, a0.s8, a0.sC), (uchar4)(b0.s1, b1.s1, b2.s1, b3.s1), c01);
+ c02 = arm_dot_acc((uchar4)(a0.s0, a0.s4, a0.s8, a0.sC), (uchar4)(b0.s2, b1.s2, b2.s2, b3.s2), c02);
+ c03 = arm_dot_acc((uchar4)(a0.s0, a0.s4, a0.s8, a0.sC), (uchar4)(b0.s3, b1.s3, b2.s3, b3.s3), c03);
+
+ c10 = arm_dot_acc((uchar4)(a0.s1, a0.s5, a0.s9, a0.sD), (uchar4)(b0.s0, b1.s0, b2.s0, b3.s0), c10);
+ c11 = arm_dot_acc((uchar4)(a0.s1, a0.s5, a0.s9, a0.sD), (uchar4)(b0.s1, b1.s1, b2.s1, b3.s1), c11);
+ c12 = arm_dot_acc((uchar4)(a0.s1, a0.s5, a0.s9, a0.sD), (uchar4)(b0.s2, b1.s2, b2.s2, b3.s2), c12);
+ c13 = arm_dot_acc((uchar4)(a0.s1, a0.s5, a0.s9, a0.sD), (uchar4)(b0.s3, b1.s3, b2.s3, b3.s3), c13);
+
+ c20 = arm_dot_acc((uchar4)(a0.s2, a0.s6, a0.sA, a0.sE), (uchar4)(b0.s0, b1.s0, b2.s0, b3.s0), c20);
+ c21 = arm_dot_acc((uchar4)(a0.s2, a0.s6, a0.sA, a0.sE), (uchar4)(b0.s1, b1.s1, b2.s1, b3.s1), c21);
+ c22 = arm_dot_acc((uchar4)(a0.s2, a0.s6, a0.sA, a0.sE), (uchar4)(b0.s2, b1.s2, b2.s2, b3.s2), c22);
+ c23 = arm_dot_acc((uchar4)(a0.s2, a0.s6, a0.sA, a0.sE), (uchar4)(b0.s3, b1.s3, b2.s3, b3.s3), c23);
+
+ c30 = arm_dot_acc((uchar4)(a0.s3, a0.s7, a0.sB, a0.sF), (uchar4)(b0.s0, b1.s0, b2.s0, b3.s0), c30);
+ c31 = arm_dot_acc((uchar4)(a0.s3, a0.s7, a0.sB, a0.sF), (uchar4)(b0.s1, b1.s1, b2.s1, b3.s1), c31);
+ c32 = arm_dot_acc((uchar4)(a0.s3, a0.s7, a0.sB, a0.sF), (uchar4)(b0.s2, b1.s2, b2.s2, b3.s2), c32);
+ c33 = arm_dot_acc((uchar4)(a0.s3, a0.s7, a0.sB, a0.sF), (uchar4)(b0.s3, b1.s3, b2.s3, b3.s3), c33);
+ }
+#endif // MULT_INTERLEAVE4X4_HEIGHT == 1
+
+ for(; src_addr_b < src_end_addr_b; src_addr_a += (4 * MULT_INTERLEAVE4X4_HEIGHT), src_addr_b += (4 * TRANSPOSE1XW_WIDTH_STEP))
+ {
+ // Load values from matrix A (interleaved) and matrix B (transposed)
+ uchar4 a0 = vload4(0, src_addr_a);
+ uchar4 b0 = vload4(0, src_addr_b);
+
+ c00 += (ushort)a0.s0 * b0.s0;
+ c01 += (ushort)a0.s0 * b0.s1;
+ c02 += (ushort)a0.s0 * b0.s2;
+ c03 += (ushort)a0.s0 * b0.s3;
+
+ c10 += (ushort)a0.s1 * b0.s0;
+ c11 += (ushort)a0.s1 * b0.s1;
+ c12 += (ushort)a0.s1 * b0.s2;
+ c13 += (ushort)a0.s1 * b0.s3;
+
+ c20 += (ushort)a0.s2 * b0.s0;
+ c21 += (ushort)a0.s2 * b0.s1;
+ c22 += (ushort)a0.s2 * b0.s2;
+ c23 += (ushort)a0.s2 * b0.s3;
+
+ c30 += (ushort)a0.s3 * b0.s0;
+ c31 += (ushort)a0.s3 * b0.s1;
+ c32 += (ushort)a0.s3 * b0.s2;
+ c33 += (ushort)a0.s3 * b0.s3;
+ }
+
+ // Compute destination address
+ Image dst = CONVERT_TO_IMAGE_STRUCT(dst);
+
+ // Store 4x4 block
+ vstore4((int4)(c00, c01, c02, c03), 0, (__global int *)(offset(&dst, 0, 0)));
+ vstore4((int4)(c10, c11, c12, c13), 0, (__global int *)(offset(&dst, 0, 1)));
+ vstore4((int4)(c20, c21, c22, c23), 0, (__global int *)(offset(&dst, 0, 2)));
+ vstore4((int4)(c30, c31, c32, c33), 0, (__global int *)(offset(&dst, 0, 3)));
+}
+#endif // ARM_COMPUTE_OPENCL_DOT8_ENABLED
+
#endif // defined(COLS_B) && defined(MULT_INTERLEAVE4X4_HEIGHT) && defined(TRANSPOSE1XW_WIDTH_STEP)
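
As a reading aid, the sketch below is a scalar equivalent of one group of four K values accumulated by the arm_dot_acc calls in gemmlowp_mm_interleaved_transposed_bifrost_dot8 above (each main-loop iteration performs two such groups). The function name and array-based tile layouts are assumptions for illustration; they follow the vload and swizzle patterns in the kernel, where the interleaved A tile stores element (row, k) at index k * 4 + row and the transposed B tile stores row k of the 4x4 block contiguously.

void gemmlowp_4x4_block_reference(const uchar a_tile[16],   // interleaved A: a_tile[k * 4 + row]
                                  const uchar b_tile[4][4], // transposed B:  b_tile[k][col]
                                  uint        c[4][4])      // running accumulators c[row][col]
{
    for(int row = 0; row < 4; ++row)
    {
        for(int col = 0; col < 4; ++col)
        {
            for(int k = 0; k < 4; ++k)
            {
                // One arm_dot_acc call in the kernel folds these four k iterations into a single instruction
                c[row][col] += (uint)a_tile[k * 4 + row] * (uint)b_tile[k][col];
            }
        }
    }
}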
#if defined(NUM_ELEMS_PROCESSED_PER_THREAD_X) && defined(NUM_ELEMS_PROCESSED_PER_THREAD_Y) && defined(COLS_A)
@@ -724,13 +833,6 @@ __kernel void gemmlowp_mm_bifrost(IMAGE_DECLARATION(src0),
uchar4 b3 = vload4(0, src1_ptr + src_addr.s1 + 3 * src1_stride_y);
{
-#if ARM_COMPUTE_OPENCL_DOT8_ENABLED
- // Accumulate
- acc00 += arm_dot((uchar4)(b0.s0, b1.s0, b2.s0, b3.s0), a0);
- acc01 += arm_dot((uchar4)(b0.s1, b1.s1, b2.s1, b3.s1), a0);
- acc02 += arm_dot((uchar4)(b0.s2, b1.s2, b2.s2, b3.s2), a0);
- acc03 += arm_dot((uchar4)(b0.s3, b1.s3, b2.s3, b3.s3), a0);
-#else // ARM_COMPUTE_OPENCL_DOT8_ENABLED
// Accumulate
ushort tmp0 = (ushort)b0.s0 * (ushort)a0.s0;
ushort tmp1 = (ushort)b0.s1 * (ushort)a0.s0;
@@ -756,17 +858,9 @@ __kernel void gemmlowp_mm_bifrost(IMAGE_DECLARATION(src0),
acc01 += ((uint)tmp1 + (uint)tmp5 + (uint)tmp9 + (uint)tmpD);
acc02 += ((uint)tmp2 + (uint)tmp6 + (uint)tmpA + (uint)tmpE);
acc03 += ((uint)tmp3 + (uint)tmp7 + (uint)tmpB + (uint)tmpF);
-#endif // ARM_COMPUTE_OPENCL_DOT8_ENABLED
}
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
{
-#if ARM_COMPUTE_OPENCL_DOT8_ENABLED
- // Accumulate
- acc10 += arm_dot((uchar4)(b0.s0, b1.s0, b2.s0, b3.s0), a1);
- acc11 += arm_dot((uchar4)(b0.s1, b1.s1, b2.s1, b3.s1), a1);
- acc12 += arm_dot((uchar4)(b0.s2, b1.s2, b2.s2, b3.s2), a1);
- acc13 += arm_dot((uchar4)(b0.s3, b1.s3, b2.s3, b3.s3), a1);
-#else // ARM_COMPUTE_OPENCL_DOT8_ENABLED
// Accumulate
ushort tmp0 = (ushort)b0.s0 * (ushort)a1.s0;
ushort tmp1 = (ushort)b0.s1 * (ushort)a1.s0;
@@ -792,18 +886,10 @@ __kernel void gemmlowp_mm_bifrost(IMAGE_DECLARATION(src0),
acc11 += ((uint)tmp1 + (uint)tmp5 + (uint)tmp9 + (uint)tmpD);
acc12 += ((uint)tmp2 + (uint)tmp6 + (uint)tmpA + (uint)tmpE);
acc13 += ((uint)tmp3 + (uint)tmp7 + (uint)tmpB + (uint)tmpF);
-#endif // ARM_COMPUTE_OPENCL_DOT8_ENABLED
}
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
{
-#if ARM_COMPUTE_OPENCL_DOT8_ENABLED
- // Accumulate
- acc20 += arm_dot((uchar4)(b0.s0, b1.s0, b2.s0, b3.s0), a2);
- acc21 += arm_dot((uchar4)(b0.s1, b1.s1, b2.s1, b3.s1), a2);
- acc22 += arm_dot((uchar4)(b0.s2, b1.s2, b2.s2, b3.s2), a2);
- acc23 += arm_dot((uchar4)(b0.s3, b1.s3, b2.s3, b3.s3), a2);
-#else // ARM_COMPUTE_OPENCL_DOT8_ENABLED
// Accumulate
ushort tmp0 = (ushort)b0.s0 * (ushort)a2.s0;
ushort tmp1 = (ushort)b0.s1 * (ushort)a2.s0;
@@ -829,18 +915,10 @@ __kernel void gemmlowp_mm_bifrost(IMAGE_DECLARATION(src0),
acc21 += ((uint)tmp1 + (uint)tmp5 + (uint)tmp9 + (uint)tmpD);
acc22 += ((uint)tmp2 + (uint)tmp6 + (uint)tmpA + (uint)tmpE);
acc23 += ((uint)tmp3 + (uint)tmp7 + (uint)tmpB + (uint)tmpF);
-#endif // ARM_COMPUTE_OPENCL_DOT8_ENABLED
}
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
{
-#if ARM_COMPUTE_OPENCL_DOT8_ENABLED
- // Accumulate
- acc30 += arm_dot((uchar4)(b0.s0, b1.s0, b2.s0, b3.s0), a3);
- acc31 += arm_dot((uchar4)(b0.s1, b1.s1, b2.s1, b3.s1), a3);
- acc32 += arm_dot((uchar4)(b0.s2, b1.s2, b2.s2, b3.s2), a3);
- acc33 += arm_dot((uchar4)(b0.s3, b1.s3, b2.s3, b3.s3), a3);
-#else // ARM_COMPUTE_OPENCL_DOT8_ENABLED
// Accumulate
ushort tmp0 = (ushort)b0.s0 * (ushort)a3.s0;
ushort tmp1 = (ushort)b0.s1 * (ushort)a3.s0;
@@ -866,18 +944,10 @@ __kernel void gemmlowp_mm_bifrost(IMAGE_DECLARATION(src0),
acc31 += ((uint)tmp1 + (uint)tmp5 + (uint)tmp9 + (uint)tmpD);
acc32 += ((uint)tmp2 + (uint)tmp6 + (uint)tmpA + (uint)tmpE);
acc33 += ((uint)tmp3 + (uint)tmp7 + (uint)tmpB + (uint)tmpF);
-#endif // ARM_COMPUTE_OPENCL_DOT8_ENABLED
}
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 4
{
-#if ARM_COMPUTE_OPENCL_DOT8_ENABLED
- // Accumulate
- acc40 += arm_dot((uchar4)(b0.s0, b1.s0, b2.s0, b3.s0), a4);
- acc41 += arm_dot((uchar4)(b0.s1, b1.s1, b2.s1, b3.s1), a4);
- acc42 += arm_dot((uchar4)(b0.s2, b1.s2, b2.s2, b3.s2), a4);
- acc43 += arm_dot((uchar4)(b0.s3, b1.s3, b2.s3, b3.s3), a4);
-#else // ARM_COMPUTE_OPENCL_DOT8_ENABLED
// Accumulate
ushort tmp0 = (ushort)b0.s0 * (ushort)a4.s0;
ushort tmp1 = (ushort)b0.s1 * (ushort)a4.s0;
@@ -903,7 +973,6 @@ __kernel void gemmlowp_mm_bifrost(IMAGE_DECLARATION(src0),
acc41 += ((uint)tmp1 + (uint)tmp5 + (uint)tmp9 + (uint)tmpD);
acc42 += ((uint)tmp2 + (uint)tmp6 + (uint)tmpA + (uint)tmpE);
acc43 += ((uint)tmp3 + (uint)tmp7 + (uint)tmpB + (uint)tmpF);
-#endif // ARM_COMPUTE_OPENCL_DOT8_ENABLED
}
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 4
}
@@ -1016,6 +1085,254 @@ __kernel void gemmlowp_mm_bifrost(IMAGE_DECLARATION(src0),
vstore4((int4)(acc40, acc41, acc42, acc43), 0, (__global int *)(offset(&dst, 0, 4)));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 4
}
+
+#if ARM_COMPUTE_OPENCL_DOT8_ENABLED
+/** This OpenCL kernel is optimized for Bifrost and uses the 8-bit dot product instruction to compute the matrix multiplication between matrix A (src0) and matrix B (src1) when neither matrix has been reshaped
+ *
+ * @attention The number of matrix A columns needs to be passed at compile time using -DCOLS_A
+ *
+ * @param[in] src0_ptr Pointer to the source matrix. Supported data type: QASYMM8
+ * @param[in] src0_stride_x Stride of the source matrix in X dimension (in bytes)
+ * @param[in] src0_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
+ * @param[in] src0_stride_y Stride of the source matrix in Y dimension (in bytes)
+ * @param[in] src0_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
+ * @param[in] src0_offset_first_element_in_bytes The offset of the first element in the source matrix
+ * @param[in] src1_ptr Pointer to the source matrix. Supported data type: same as @p src0_ptr
+ * @param[in] src1_stride_x Stride of the source matrix in X dimension (in bytes)
+ * @param[in] src1_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
+ * @param[in] src1_stride_y Stride of the source matrix in Y dimension (in bytes)
+ * @param[in] src1_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
+ * @param[in] src1_offset_first_element_in_bytes The offset of the first element in the source matrix
+ * @param[out] dst_ptr                           Pointer to the destination matrix. Supported data type: S32
+ * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
+ * @param[in] dst_step_x dst_gx_stride_x * number of elements along X processed per workitem(in bytes)
+ * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
+ * @param[in] dst_step_y dst_gx_stride_y * number of elements along Y processed per workitem(in bytes)
+ * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
+ */
+__kernel void gemmlowp_mm_bifrost_dot8(IMAGE_DECLARATION(src0),
+ IMAGE_DECLARATION(src1),
+ IMAGE_DECLARATION(dst))
+{
+ int idx = get_global_id(0) * NUM_ELEMS_PROCESSED_PER_THREAD_X;
+
+ // Compute starting address for matrix A and Matrix B
+ int2 src_addr = ((int2)(src0_offset_first_element_in_bytes, src1_offset_first_element_in_bytes));
+
+ // Update address for the matrix A
+ src_addr.s0 += get_global_id(1) * src0_stride_y * NUM_ELEMS_PROCESSED_PER_THREAD_Y;
+
+ // Update address for the matrix B
+ src_addr.s1 += idx;
+
+ int end_row_vec_a = src_addr.s0 + COLS_A;
+
+ uint acc00 = 0;
+ uint acc01 = 0;
+ uint acc02 = 0;
+ uint acc03 = 0;
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
+ uint acc10 = 0;
+ uint acc11 = 0;
+ uint acc12 = 0;
+ uint acc13 = 0;
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
+ uint acc20 = 0;
+ uint acc21 = 0;
+ uint acc22 = 0;
+ uint acc23 = 0;
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
+ uint acc30 = 0;
+ uint acc31 = 0;
+ uint acc32 = 0;
+ uint acc33 = 0;
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 4
+ uint acc40 = 0;
+ uint acc41 = 0;
+ uint acc42 = 0;
+ uint acc43 = 0;
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 4
+
+ for(; src_addr.s0 <= (end_row_vec_a - 4); src_addr += (int2)(4, 4 * src1_stride_y))
+ {
+ // Load values from matrix A
+ uchar4 a0 = vload4(0, src0_ptr + src_addr.s0 + 0 * src0_stride_y);
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
+ uchar4 a1 = vload4(0, src0_ptr + src_addr.s0 + 1 * src0_stride_y);
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
+ uchar4 a2 = vload4(0, src0_ptr + src_addr.s0 + 2 * src0_stride_y);
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
+ uchar4 a3 = vload4(0, src0_ptr + src_addr.s0 + 3 * src0_stride_y);
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 4
+ uchar4 a4 = vload4(0, src0_ptr + src_addr.s0 + 4 * src0_stride_y);
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 4
+ // Load values from matrix B
+ uchar4 b0 = vload4(0, src1_ptr + src_addr.s1 + 0 * src1_stride_y);
+ uchar4 b1 = vload4(0, src1_ptr + src_addr.s1 + 1 * src1_stride_y);
+ uchar4 b2 = vload4(0, src1_ptr + src_addr.s1 + 2 * src1_stride_y);
+ uchar4 b3 = vload4(0, src1_ptr + src_addr.s1 + 3 * src1_stride_y);
+
+ {
+ // Accumulate
+ acc00 = arm_dot_acc((uchar4)(b0.s0, b1.s0, b2.s0, b3.s0), a0, acc00);
+ acc01 = arm_dot_acc((uchar4)(b0.s1, b1.s1, b2.s1, b3.s1), a0, acc01);
+ acc02 = arm_dot_acc((uchar4)(b0.s2, b1.s2, b2.s2, b3.s2), a0, acc02);
+ acc03 = arm_dot_acc((uchar4)(b0.s3, b1.s3, b2.s3, b3.s3), a0, acc03);
+ }
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
+ {
+ // Accumulate
+ acc10 = arm_dot_acc((uchar4)(b0.s0, b1.s0, b2.s0, b3.s0), a1, acc10);
+ acc11 = arm_dot_acc((uchar4)(b0.s1, b1.s1, b2.s1, b3.s1), a1, acc11);
+ acc12 = arm_dot_acc((uchar4)(b0.s2, b1.s2, b2.s2, b3.s2), a1, acc12);
+ acc13 = arm_dot_acc((uchar4)(b0.s3, b1.s3, b2.s3, b3.s3), a1, acc13);
+ }
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
+ {
+ // Accumulate
+ acc20 = arm_dot_acc((uchar4)(b0.s0, b1.s0, b2.s0, b3.s0), a2, acc20);
+ acc21 = arm_dot_acc((uchar4)(b0.s1, b1.s1, b2.s1, b3.s1), a2, acc21);
+ acc22 = arm_dot_acc((uchar4)(b0.s2, b1.s2, b2.s2, b3.s2), a2, acc22);
+ acc23 = arm_dot_acc((uchar4)(b0.s3, b1.s3, b2.s3, b3.s3), a2, acc23);
+ }
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
+ {
+ // Accumulate
+ acc30 = arm_dot_acc((uchar4)(b0.s0, b1.s0, b2.s0, b3.s0), a3, acc30);
+ acc31 = arm_dot_acc((uchar4)(b0.s1, b1.s1, b2.s1, b3.s1), a3, acc31);
+ acc32 = arm_dot_acc((uchar4)(b0.s2, b1.s2, b2.s2, b3.s2), a3, acc32);
+ acc33 = arm_dot_acc((uchar4)(b0.s3, b1.s3, b2.s3, b3.s3), a3, acc33);
+ }
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 4
+ {
+ // Accumulate
+ acc40 = arm_dot_acc((uchar4)(b0.s0, b1.s0, b2.s0, b3.s0), a4, acc40);
+ acc41 = arm_dot_acc((uchar4)(b0.s1, b1.s1, b2.s1, b3.s1), a4, acc41);
+ acc42 = arm_dot_acc((uchar4)(b0.s2, b1.s2, b2.s2, b3.s2), a4, acc42);
+ acc43 = arm_dot_acc((uchar4)(b0.s3, b1.s3, b2.s3, b3.s3), a4, acc43);
+ }
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 4
+ }
+
+ for(; src_addr.s0 < end_row_vec_a; src_addr += (int2)(1, src1_stride_y))
+ {
+ // Load values from matrix A
+ uchar a0 = *(src0_ptr + src_addr.s0 + 0 * src0_stride_y);
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
+ uchar a1 = *(src0_ptr + src_addr.s0 + 1 * src0_stride_y);
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
+ uchar a2 = *(src0_ptr + src_addr.s0 + 2 * src0_stride_y);
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
+ uchar a3 = *(src0_ptr + src_addr.s0 + 3 * src0_stride_y);
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 4
+ uchar a4 = *(src0_ptr + src_addr.s0 + 4 * src0_stride_y);
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 4
+ // Load values from matrix B
+ uchar4 b0 = vload4(0, src1_ptr + src_addr.s1);
+
+ // Accumulate
+ {
+ // Accumulate
+ ushort tmp0 = (ushort)b0.s0 * (ushort)a0;
+ ushort tmp1 = (ushort)b0.s1 * (ushort)a0;
+ ushort tmp2 = (ushort)b0.s2 * (ushort)a0;
+ ushort tmp3 = (ushort)b0.s3 * (ushort)a0;
+
+ acc00 += ((uint)tmp0);
+ acc01 += ((uint)tmp1);
+ acc02 += ((uint)tmp2);
+ acc03 += ((uint)tmp3);
+ }
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
+ {
+ // Accumulate
+ ushort tmp0 = (ushort)b0.s0 * (ushort)a1;
+ ushort tmp1 = (ushort)b0.s1 * (ushort)a1;
+ ushort tmp2 = (ushort)b0.s2 * (ushort)a1;
+ ushort tmp3 = (ushort)b0.s3 * (ushort)a1;
+
+ acc10 += ((uint)tmp0);
+ acc11 += ((uint)tmp1);
+ acc12 += ((uint)tmp2);
+ acc13 += ((uint)tmp3);
+ }
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
+ {
+ // Accumulate
+ ushort tmp0 = (ushort)b0.s0 * (ushort)a2;
+ ushort tmp1 = (ushort)b0.s1 * (ushort)a2;
+ ushort tmp2 = (ushort)b0.s2 * (ushort)a2;
+ ushort tmp3 = (ushort)b0.s3 * (ushort)a2;
+
+ acc20 += ((uint)tmp0);
+ acc21 += ((uint)tmp1);
+ acc22 += ((uint)tmp2);
+ acc23 += ((uint)tmp3);
+ }
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
+ {
+ // Accumulate
+ ushort tmp0 = (ushort)b0.s0 * (ushort)a3;
+ ushort tmp1 = (ushort)b0.s1 * (ushort)a3;
+ ushort tmp2 = (ushort)b0.s2 * (ushort)a3;
+ ushort tmp3 = (ushort)b0.s3 * (ushort)a3;
+
+ acc30 += ((uint)tmp0);
+ acc31 += ((uint)tmp1);
+ acc32 += ((uint)tmp2);
+ acc33 += ((uint)tmp3);
+ }
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 4
+ {
+ // Accumulate
+ ushort tmp0 = (ushort)b0.s0 * (ushort)a4;
+ ushort tmp1 = (ushort)b0.s1 * (ushort)a4;
+ ushort tmp2 = (ushort)b0.s2 * (ushort)a4;
+ ushort tmp3 = (ushort)b0.s3 * (ushort)a4;
+
+ acc40 += ((uint)tmp0);
+ acc41 += ((uint)tmp1);
+ acc42 += ((uint)tmp2);
+ acc43 += ((uint)tmp3);
+ }
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 4
+ }
+
+ // Compute destination address
+ Image dst = CONVERT_TO_IMAGE_STRUCT(dst);
+
+ // Store the result
+ vstore4((int4)(acc00, acc01, acc02, acc03), 0, (__global int *)(offset(&dst, 0, 0)));
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
+ vstore4((int4)(acc10, acc11, acc12, acc13), 0, (__global int *)(offset(&dst, 0, 1)));
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
+ vstore4((int4)(acc20, acc21, acc22, acc23), 0, (__global int *)(offset(&dst, 0, 2)));
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
+ vstore4((int4)(acc30, acc31, acc32, acc33), 0, (__global int *)(offset(&dst, 0, 3)));
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 4
+ vstore4((int4)(acc40, acc41, acc42, acc43), 0, (__global int *)(offset(&dst, 0, 4)));
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 4
+}
+#endif // ARM_COMPUTE_OPENCL_DOT8_ENABLED
+
#endif // defined(NUM_ELEMS_PROCESSED_PER_THREAD_X) && defined(NUM_ELEMS_PROCESSED_PER_THREAD_Y) && defined(COLS_A)
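
Similarly, each work-item of the unreshaped kernel gemmlowp_mm_bifrost_dot8 above accumulates a tile of NUM_ELEMS_PROCESSED_PER_THREAD_Y rows by 4 columns of the output, consuming four K values per arm_dot_acc call and finishing any remainder in the scalar tail loop. A minimal scalar sketch of one such tile, with hypothetical names and a flat row-major accumulator layout assumed for illustration:

void gemmlowp_thread_tile_reference(const __global uchar *a, uint a_stride_y, // tile rows of A, cols_a elements each
                                    const __global uchar *b, uint b_stride_y, // cols_a rows of B, 4 columns each
                                    uint *acc, int rows, int cols_a)          // acc[row * 4 + col], zero-initialised
{
    for(int row = 0; row < rows; ++row)
    {
        for(int col = 0; col < 4; ++col)
        {
            for(int k = 0; k < cols_a; ++k)
            {
                // Matches the kernel's acc<row><col> updates; the kernel handles four k values per arm_dot_acc
                acc[row * 4 + col] += (uint)a[row * a_stride_y + k] * (uint)b[k * b_stride_y + col];
            }
        }
    }
}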
#if defined(COLS_A)