From 6200fa405b16b4145b926a96de197718ad31bf93 Mon Sep 17 00:00:00 2001
From: Giorgio Arena
Date: Fri, 6 Jul 2018 17:06:36 +0100
Subject: COMPMID-1288 Optimizing CLGEMMLowp using 8 bit dot product instruction

Change-Id: I536174b9381660a94578d6aa1892a6289a820391
Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/139109
Reviewed-by: Georgios Pinitas
Tested-by: Jenkins
---
 src/core/CL/cl_kernels/gemmlowp.cl | 513 ++++++++++++++++++++++++++++++-------
 1 file changed, 415 insertions(+), 98 deletions(-)

(limited to 'src/core/CL/cl_kernels/gemmlowp.cl')
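A note before the diff: the kernels below rely on the Arm-specific OpenCL built-ins arm_dot and arm_dot_acc, provided by the cl_arm_integer_dot_product_int8 and cl_arm_integer_dot_product_accumulate_int8 extensions. As a reference for the arithmetic they perform, here is a minimal emulation sketch in plain OpenCL C; the emu_* helper names are hypothetical and are not part of this patch or of the extension:

// Reference-only emulation of the dot product built-ins used in this patch.
// On devices exposing the cl_arm_integer_dot_product_* extensions, arm_dot
// and arm_dot_acc map to a single hardware instruction; these helpers only
// document the arithmetic they are assumed to perform.
inline uint emu_arm_dot(uchar4 a, uchar4 b)
{
    // 4-way unsigned 8-bit dot product, widened to 32 bits
    return (uint)a.s0 * b.s0 + (uint)a.s1 * b.s1 + (uint)a.s2 * b.s2 + (uint)a.s3 * b.s3;
}

inline uint emu_arm_dot_acc(uchar4 a, uchar4 b, uint acc)
{
    // Same dot product, folded into a running 32-bit accumulator
    return acc + emu_arm_dot(a, b);
}

This also explains the pattern change visible throughout the diff: the old code wrote "c00 += arm_dot(...)" while the new kernels write "c00 = arm_dot_acc(..., c00)", folding the accumulation into the dot product instruction itself.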
diff --git a/src/core/CL/cl_kernels/gemmlowp.cl b/src/core/CL/cl_kernels/gemmlowp.cl
index da915778e7..0ee7c27350 100644
--- a/src/core/CL/cl_kernels/gemmlowp.cl
+++ b/src/core/CL/cl_kernels/gemmlowp.cl
@@ -190,63 +190,6 @@ __kernel void gemmlowp_mm_interleaved_transposed_bifrost(IMAGE_DECLARATION(src0)
 #if MULT_INTERLEAVE4X4_HEIGHT == 1
     for(; src_addr_b <= (src_end_addr_b - (int)(32 * TRANSPOSE1XW_WIDTH_STEP)); src_addr_a += (32 * MULT_INTERLEAVE4X4_HEIGHT), src_addr_b += (32 * TRANSPOSE1XW_WIDTH_STEP))
     {
-#if ARM_COMPUTE_OPENCL_DOT8_ENABLED
-        // Load values from matrix A (interleaved) and matrix B (transposed)
-        uchar16 a0 = vload16(0, src_addr_a);
-        uchar4  b0 = vload4(0, src_addr_b);
-        uchar4  b1 = vload4(0, src_addr_b + 4 * TRANSPOSE1XW_WIDTH_STEP);
-        uchar4  b2 = vload4(0, src_addr_b + 8 * TRANSPOSE1XW_WIDTH_STEP);
-        uchar4  b3 = vload4(0, src_addr_b + 12 * TRANSPOSE1XW_WIDTH_STEP);
-
-        // Accumulate
-        c00 += arm_dot((uchar4)(a0.s0, a0.s4, a0.s8, a0.sC), (uchar4)(b0.s0, b1.s0, b2.s0, b3.s0));
-        c01 += arm_dot((uchar4)(a0.s0, a0.s4, a0.s8, a0.sC), (uchar4)(b0.s1, b1.s1, b2.s1, b3.s1));
-        c02 += arm_dot((uchar4)(a0.s0, a0.s4, a0.s8, a0.sC), (uchar4)(b0.s2, b1.s2, b2.s2, b3.s2));
-        c03 += arm_dot((uchar4)(a0.s0, a0.s4, a0.s8, a0.sC), (uchar4)(b0.s3, b1.s3, b2.s3, b3.s3));
-
-        c10 += arm_dot((uchar4)(a0.s1, a0.s5, a0.s9, a0.sD), (uchar4)(b0.s0, b1.s0, b2.s0, b3.s0));
-        c11 += arm_dot((uchar4)(a0.s1, a0.s5, a0.s9, a0.sD), (uchar4)(b0.s1, b1.s1, b2.s1, b3.s1));
-        c12 += arm_dot((uchar4)(a0.s1, a0.s5, a0.s9, a0.sD), (uchar4)(b0.s2, b1.s2, b2.s2, b3.s2));
-        c13 += arm_dot((uchar4)(a0.s1, a0.s5, a0.s9, a0.sD), (uchar4)(b0.s3, b1.s3, b2.s3, b3.s3));
-
-        c20 += arm_dot((uchar4)(a0.s2, a0.s6, a0.sA, a0.sE), (uchar4)(b0.s0, b1.s0, b2.s0, b3.s0));
-        c21 += arm_dot((uchar4)(a0.s2, a0.s6, a0.sA, a0.sE), (uchar4)(b0.s1, b1.s1, b2.s1, b3.s1));
-        c22 += arm_dot((uchar4)(a0.s2, a0.s6, a0.sA, a0.sE), (uchar4)(b0.s2, b1.s2, b2.s2, b3.s2));
-        c23 += arm_dot((uchar4)(a0.s2, a0.s6, a0.sA, a0.sE), (uchar4)(b0.s3, b1.s3, b2.s3, b3.s3));
-
-        c30 += arm_dot((uchar4)(a0.s3, a0.s7, a0.sB, a0.sF), (uchar4)(b0.s0, b1.s0, b2.s0, b3.s0));
-        c31 += arm_dot((uchar4)(a0.s3, a0.s7, a0.sB, a0.sF), (uchar4)(b0.s1, b1.s1, b2.s1, b3.s1));
-        c32 += arm_dot((uchar4)(a0.s3, a0.s7, a0.sB, a0.sF), (uchar4)(b0.s2, b1.s2, b2.s2, b3.s2));
-        c33 += arm_dot((uchar4)(a0.s3, a0.s7, a0.sB, a0.sF), (uchar4)(b0.s3, b1.s3, b2.s3, b3.s3));
-
-        // Load values from matrix A (interleaved) and matrix B (transposed)
-        a0 = vload16(0, src_addr_a + 16);
-        b0 = vload4(0, src_addr_b + 16 * TRANSPOSE1XW_WIDTH_STEP);
-        b1 = vload4(0, src_addr_b + 20 * TRANSPOSE1XW_WIDTH_STEP);
-        b2 = vload4(0, src_addr_b + 24 * TRANSPOSE1XW_WIDTH_STEP);
-        b3 = vload4(0, src_addr_b + 28 * TRANSPOSE1XW_WIDTH_STEP);
-
-        // Accumulate
-        c00 += arm_dot((uchar4)(a0.s0, a0.s4, a0.s8, a0.sC), (uchar4)(b0.s0, b1.s0, b2.s0, b3.s0));
-        c01 += arm_dot((uchar4)(a0.s0, a0.s4, a0.s8, a0.sC), (uchar4)(b0.s1, b1.s1, b2.s1, b3.s1));
-        c02 += arm_dot((uchar4)(a0.s0, a0.s4, a0.s8, a0.sC), (uchar4)(b0.s2, b1.s2, b2.s2, b3.s2));
-        c03 += arm_dot((uchar4)(a0.s0, a0.s4, a0.s8, a0.sC), (uchar4)(b0.s3, b1.s3, b2.s3, b3.s3));
-
-        c10 += arm_dot((uchar4)(a0.s1, a0.s5, a0.s9, a0.sD), (uchar4)(b0.s0, b1.s0, b2.s0, b3.s0));
-        c11 += arm_dot((uchar4)(a0.s1, a0.s5, a0.s9, a0.sD), (uchar4)(b0.s1, b1.s1, b2.s1, b3.s1));
-        c12 += arm_dot((uchar4)(a0.s1, a0.s5, a0.s9, a0.sD), (uchar4)(b0.s2, b1.s2, b2.s2, b3.s2));
-        c13 += arm_dot((uchar4)(a0.s1, a0.s5, a0.s9, a0.sD), (uchar4)(b0.s3, b1.s3, b2.s3, b3.s3));
-
-        c20 += arm_dot((uchar4)(a0.s2, a0.s6, a0.sA, a0.sE), (uchar4)(b0.s0, b1.s0, b2.s0, b3.s0));
-        c21 += arm_dot((uchar4)(a0.s2, a0.s6, a0.sA, a0.sE), (uchar4)(b0.s1, b1.s1, b2.s1, b3.s1));
-        c22 += arm_dot((uchar4)(a0.s2, a0.s6, a0.sA, a0.sE), (uchar4)(b0.s2, b1.s2, b2.s2, b3.s2));
-        c23 += arm_dot((uchar4)(a0.s2, a0.s6, a0.sA, a0.sE), (uchar4)(b0.s3, b1.s3, b2.s3, b3.s3));
-
-        c30 += arm_dot((uchar4)(a0.s3, a0.s7, a0.sB, a0.sF), (uchar4)(b0.s0, b1.s0, b2.s0, b3.s0));
-        c31 += arm_dot((uchar4)(a0.s3, a0.s7, a0.sB, a0.sF), (uchar4)(b0.s1, b1.s1, b2.s1, b3.s1));
-        c32 += arm_dot((uchar4)(a0.s3, a0.s7, a0.sB, a0.sF), (uchar4)(b0.s2, b1.s2, b2.s2, b3.s2));
-        c33 += arm_dot((uchar4)(a0.s3, a0.s7, a0.sB, a0.sF), (uchar4)(b0.s3, b1.s3, b2.s3, b3.s3));
-#else // ARM_COMPUTE_OPENCL_DOT8_ENABLED
         // Load values from matrix A (interleaved) and matrix B (transposed)
         uchar16 a0 = vload16(0, src_addr_a);
         uchar4  b0 = vload4(0, src_addr_b);
@@ -432,7 +375,6 @@ __kernel void gemmlowp_mm_interleaved_transposed_bifrost(IMAGE_DECLARATION(src0)
         c31 += (ushort)a0.sF * b0.s1;
         c32 += (ushort)a0.sF * b0.s2;
         c33 += (ushort)a0.sF * b0.s3;
-#endif // ARM_COMPUTE_OPENCL_DOT8_ENABLED
     }
 #endif // MULT_INTERLEAVE4X4_HEIGHT == 1
 
@@ -472,6 +414,173 @@ __kernel void gemmlowp_mm_interleaved_transposed_bifrost(IMAGE_DECLARATION(src0)
     vstore4((int4)(c20, c21, c22, c23), 0, (__global int *)(offset(&dst, 0, 2)));
     vstore4((int4)(c30, c31, c32, c33), 0, (__global int *)(offset(&dst, 0, 3)));
 }
+
+#if ARM_COMPUTE_OPENCL_DOT8_ENABLED
+/** This OpenCL kernel is optimized for Bifrost and computes the matrix multiplication between matrix A (src0) and matrix B (src1)
+ *  Matrix A and matrix B must be reshaped respectively with @ref CLGEMMInterleave4x4Kernel and @ref CLGEMMTranspose1xWKernel before running the matrix multiplication
+ *
+ * @attention The number of matrix B columns needs to be passed at compile time using -DCOLS_B
+ * @note The transposition width step (mult_transpose1xW_width * 4) must be passed at compile time using -DTRANSPOSE1XW_WIDTH_STEP (i.e. -DTRANSPOSE1XW_WIDTH_STEP=2)
+ * @note The multiplication factor for the height of the 4x4 interleaved block must be passed at compile time using -DMULT_INTERLEAVE4X4_HEIGHT (i.e. -DMULT_INTERLEAVE4X4_HEIGHT=2)
+ *
+ * @param[in]  src0_ptr                           Pointer to the source matrix. Supported data type: QASYMM8
+ * @param[in]  src0_stride_x                      Stride of the source matrix in X dimension (in bytes)
+ * @param[in]  src0_step_x                        src_stride_x * number of elements along X processed per workitem (in bytes)
+ * @param[in]  src0_stride_y                      Stride of the source matrix in Y dimension (in bytes)
+ * @param[in]  src0_step_y                        src_stride_y * number of elements along Y processed per workitem (in bytes)
+ * @param[in]  src0_offset_first_element_in_bytes The offset of the first element in the source matrix
+ * @param[in]  src1_ptr                           Pointer to the source matrix. Supported data type: same as @p src0_ptr
+ * @param[in]  src1_stride_x                      Stride of the source matrix in X dimension (in bytes)
+ * @param[in]  src1_step_x                        src_stride_x * number of elements along X processed per workitem (in bytes)
+ * @param[in]  src1_stride_y                      Stride of the source matrix in Y dimension (in bytes)
+ * @param[in]  src1_step_y                        src_stride_y * number of elements along Y processed per workitem (in bytes)
+ * @param[in]  src1_offset_first_element_in_bytes The offset of the first element in the source matrix
+ * @param[out] dst_ptr                            Pointer to the destination matrix. Supported data type: S32
+ * @param[in]  dst_stride_x                       Stride of the destination matrix in X dimension (in bytes)
+ * @param[in]  dst_step_x                         dst_gx_stride_x * number of elements along X processed per workitem (in bytes)
+ * @param[in]  dst_stride_y                       Stride of the destination matrix in Y dimension (in bytes)
+ * @param[in]  dst_step_y                         dst_gx_stride_y * number of elements along Y processed per workitem (in bytes)
+ * @param[in]  dst_offset_first_element_in_bytes  The offset of the first element in the destination matrix
+ */
+__kernel void gemmlowp_mm_interleaved_transposed_bifrost_dot8(IMAGE_DECLARATION(src0),
+                                                              IMAGE_DECLARATION(src1),
+                                                              IMAGE_DECLARATION(dst))
+{
+    int x = get_global_id(0) / TRANSPOSE1XW_WIDTH_STEP;
+    int y = get_global_id(1) / MULT_INTERLEAVE4X4_HEIGHT;
+
+    // Offset
+    const int offset_row_a = (get_global_id(1) % MULT_INTERLEAVE4X4_HEIGHT) * 4;
+    const int offset_row_b = (get_global_id(0) % TRANSPOSE1XW_WIDTH_STEP) * 4;
+
+    // src_addr_a = address of matrix A
+    // src_addr_b = address of matrix B
+    __global uchar *src_addr_a = (__global uchar *)(src0_ptr + y * src0_stride_y + src0_offset_first_element_in_bytes);
+    __global uchar *src_addr_b = (__global uchar *)(src1_ptr + x * src1_stride_y + src1_offset_first_element_in_bytes);
+
+    // Compute end row address for matrix B
+    __global uchar *src_end_addr_b = src_addr_b + COLS_B;
+
+    src_addr_a += offset_row_a;
+    src_addr_b += offset_row_b;
+
+    // Reset accumulators
+    uint c00 = 0;
+    uint c01 = 0;
+    uint c02 = 0;
+    uint c03 = 0;
+    uint c10 = 0;
+    uint c11 = 0;
+    uint c12 = 0;
+    uint c13 = 0;
+    uint c20 = 0;
+    uint c21 = 0;
+    uint c22 = 0;
+    uint c23 = 0;
+    uint c30 = 0;
+    uint c31 = 0;
+    uint c32 = 0;
+    uint c33 = 0;
+
+#if MULT_INTERLEAVE4X4_HEIGHT == 1
+    for(; src_addr_b <= (src_end_addr_b - (int)(32 * TRANSPOSE1XW_WIDTH_STEP)); src_addr_a += (32 * MULT_INTERLEAVE4X4_HEIGHT), src_addr_b += (32 * TRANSPOSE1XW_WIDTH_STEP))
+    {
+        // Load values from matrix A (interleaved) and matrix B (transposed)
+        uchar16 a0 = vload16(0, src_addr_a);
+        uchar4  b0 = vload4(0, src_addr_b);
+        uchar4  b1 = vload4(0, src_addr_b + 4 * TRANSPOSE1XW_WIDTH_STEP);
+        uchar4  b2 = vload4(0, src_addr_b + 8 * TRANSPOSE1XW_WIDTH_STEP);
+        uchar4  b3 = vload4(0, src_addr_b + 12 * TRANSPOSE1XW_WIDTH_STEP);
+
+        // Accumulate
+        c00 = arm_dot_acc((uchar4)(a0.s0, a0.s4, a0.s8, a0.sC), (uchar4)(b0.s0, b1.s0, b2.s0, b3.s0), c00);
+        c01 = arm_dot_acc((uchar4)(a0.s0, a0.s4, a0.s8, a0.sC), (uchar4)(b0.s1, b1.s1, b2.s1, b3.s1), c01);
+        c02 = arm_dot_acc((uchar4)(a0.s0, a0.s4, a0.s8, a0.sC), (uchar4)(b0.s2, b1.s2, b2.s2, b3.s2), c02);
+        c03 = arm_dot_acc((uchar4)(a0.s0, a0.s4, a0.s8, a0.sC), (uchar4)(b0.s3, b1.s3, b2.s3, b3.s3), c03);
+
+        c10 = arm_dot_acc((uchar4)(a0.s1, a0.s5, a0.s9, a0.sD), (uchar4)(b0.s0, b1.s0, b2.s0, b3.s0), c10);
+        c11 = arm_dot_acc((uchar4)(a0.s1, a0.s5, a0.s9, a0.sD), (uchar4)(b0.s1, b1.s1, b2.s1, b3.s1), c11);
+        c12 = arm_dot_acc((uchar4)(a0.s1, a0.s5, a0.s9, a0.sD), (uchar4)(b0.s2, b1.s2, b2.s2, b3.s2), c12);
+        c13 = arm_dot_acc((uchar4)(a0.s1, a0.s5, a0.s9, a0.sD), (uchar4)(b0.s3, b1.s3, b2.s3, b3.s3), c13);
+
+        c20 = arm_dot_acc((uchar4)(a0.s2, a0.s6, a0.sA, a0.sE), (uchar4)(b0.s0, b1.s0, b2.s0, b3.s0), c20);
+        c21 = arm_dot_acc((uchar4)(a0.s2, a0.s6, a0.sA, a0.sE), (uchar4)(b0.s1, b1.s1, b2.s1, b3.s1), c21);
+        c22 = arm_dot_acc((uchar4)(a0.s2, a0.s6, a0.sA, a0.sE), (uchar4)(b0.s2, b1.s2, b2.s2, b3.s2), c22);
+        c23 = arm_dot_acc((uchar4)(a0.s2, a0.s6, a0.sA, a0.sE), (uchar4)(b0.s3, b1.s3, b2.s3, b3.s3), c23);
+
+        c30 = arm_dot_acc((uchar4)(a0.s3, a0.s7, a0.sB, a0.sF), (uchar4)(b0.s0, b1.s0, b2.s0, b3.s0), c30);
+        c31 = arm_dot_acc((uchar4)(a0.s3, a0.s7, a0.sB, a0.sF), (uchar4)(b0.s1, b1.s1, b2.s1, b3.s1), c31);
+        c32 = arm_dot_acc((uchar4)(a0.s3, a0.s7, a0.sB, a0.sF), (uchar4)(b0.s2, b1.s2, b2.s2, b3.s2), c32);
+        c33 = arm_dot_acc((uchar4)(a0.s3, a0.s7, a0.sB, a0.sF), (uchar4)(b0.s3, b1.s3, b2.s3, b3.s3), c33);
+
+        // Load values from matrix A (interleaved) and matrix B (transposed)
+        a0 = vload16(0, src_addr_a + 16);
+        b0 = vload4(0, src_addr_b + 16 * TRANSPOSE1XW_WIDTH_STEP);
+        b1 = vload4(0, src_addr_b + 20 * TRANSPOSE1XW_WIDTH_STEP);
+        b2 = vload4(0, src_addr_b + 24 * TRANSPOSE1XW_WIDTH_STEP);
+        b3 = vload4(0, src_addr_b + 28 * TRANSPOSE1XW_WIDTH_STEP);
+
+        // Accumulate
+        c00 = arm_dot_acc((uchar4)(a0.s0, a0.s4, a0.s8, a0.sC), (uchar4)(b0.s0, b1.s0, b2.s0, b3.s0), c00);
+        c01 = arm_dot_acc((uchar4)(a0.s0, a0.s4, a0.s8, a0.sC), (uchar4)(b0.s1, b1.s1, b2.s1, b3.s1), c01);
+        c02 = arm_dot_acc((uchar4)(a0.s0, a0.s4, a0.s8, a0.sC), (uchar4)(b0.s2, b1.s2, b2.s2, b3.s2), c02);
+        c03 = arm_dot_acc((uchar4)(a0.s0, a0.s4, a0.s8, a0.sC), (uchar4)(b0.s3, b1.s3, b2.s3, b3.s3), c03);
+
+        c10 = arm_dot_acc((uchar4)(a0.s1, a0.s5, a0.s9, a0.sD), (uchar4)(b0.s0, b1.s0, b2.s0, b3.s0), c10);
+        c11 = arm_dot_acc((uchar4)(a0.s1, a0.s5, a0.s9, a0.sD), (uchar4)(b0.s1, b1.s1, b2.s1, b3.s1), c11);
+        c12 = arm_dot_acc((uchar4)(a0.s1, a0.s5, a0.s9, a0.sD), (uchar4)(b0.s2, b1.s2, b2.s2, b3.s2), c12);
+        c13 = arm_dot_acc((uchar4)(a0.s1, a0.s5, a0.s9, a0.sD), (uchar4)(b0.s3, b1.s3, b2.s3, b3.s3), c13);
+
+        c20 = arm_dot_acc((uchar4)(a0.s2, a0.s6, a0.sA, a0.sE), (uchar4)(b0.s0, b1.s0, b2.s0, b3.s0), c20);
+        c21 = arm_dot_acc((uchar4)(a0.s2, a0.s6, a0.sA, a0.sE), (uchar4)(b0.s1, b1.s1, b2.s1, b3.s1), c21);
+        c22 = arm_dot_acc((uchar4)(a0.s2, a0.s6, a0.sA, a0.sE), (uchar4)(b0.s2, b1.s2, b2.s2, b3.s2), c22);
+        c23 = arm_dot_acc((uchar4)(a0.s2, a0.s6, a0.sA, a0.sE), (uchar4)(b0.s3, b1.s3, b2.s3, b3.s3), c23);
+
+        c30 = arm_dot_acc((uchar4)(a0.s3, a0.s7, a0.sB, a0.sF), (uchar4)(b0.s0, b1.s0, b2.s0, b3.s0), c30);
+        c31 = arm_dot_acc((uchar4)(a0.s3, a0.s7, a0.sB, a0.sF), (uchar4)(b0.s1, b1.s1, b2.s1, b3.s1), c31);
+        c32 = arm_dot_acc((uchar4)(a0.s3, a0.s7, a0.sB, a0.sF), (uchar4)(b0.s2, b1.s2, b2.s2, b3.s2), c32);
+        c33 = arm_dot_acc((uchar4)(a0.s3, a0.s7, a0.sB, a0.sF), (uchar4)(b0.s3, b1.s3, b2.s3, b3.s3), c33);
+    }
+#endif // MULT_INTERLEAVE4X4_HEIGHT == 1
+
+    for(; src_addr_b < src_end_addr_b; src_addr_a += (4 * MULT_INTERLEAVE4X4_HEIGHT), src_addr_b += (4 * TRANSPOSE1XW_WIDTH_STEP))
+    {
+        // Load values from matrix A (interleaved) and matrix B (transposed)
+        uchar4 a0 = vload4(0, src_addr_a);
+        uchar4 b0 = vload4(0, src_addr_b);
+
+        c00 += (ushort)a0.s0 * b0.s0;
+        c01 += (ushort)a0.s0 * b0.s1;
+        c02 += (ushort)a0.s0 * b0.s2;
+        c03 += (ushort)a0.s0 * b0.s3;
+
+        c10 += (ushort)a0.s1 * b0.s0;
+        c11 += (ushort)a0.s1 * b0.s1;
+        c12 += (ushort)a0.s1 * b0.s2;
+        c13 += (ushort)a0.s1 * b0.s3;
+
+        c20 += (ushort)a0.s2 * b0.s0;
+        c21 += (ushort)a0.s2 * b0.s1;
+        c22 += (ushort)a0.s2 * b0.s2;
+        c23 += (ushort)a0.s2 * b0.s3;
+
+        c30 += (ushort)a0.s3 * b0.s0;
+        c31 += (ushort)a0.s3 * b0.s1;
+        c32 += (ushort)a0.s3 * b0.s2;
+        c33 += (ushort)a0.s3 * b0.s3;
+    }
+
+    // Compute destination address
+    Image dst = CONVERT_TO_IMAGE_STRUCT(dst);
+
+    // Store 4x4 block
+    vstore4((int4)(c00, c01, c02, c03), 0, (__global int *)(offset(&dst, 0, 0)));
+    vstore4((int4)(c10, c11, c12, c13), 0, (__global int *)(offset(&dst, 0, 1)));
+    vstore4((int4)(c20, c21, c22, c23), 0, (__global int *)(offset(&dst, 0, 2)));
+    vstore4((int4)(c30, c31, c32, c33), 0, (__global int *)(offset(&dst, 0, 3)));
+}
+#endif // ARM_COMPUTE_OPENCL_DOT8_ENABLED
+
 #endif // defined(COLS_B) && defined(MULT_INTERLEAVE4X4_HEIGHT) && defined(TRANSPOSE1XW_WIDTH_STEP)
 
 #if defined(NUM_ELEMS_PROCESSED_PER_THREAD_X) && defined(NUM_ELEMS_PROCESSED_PER_THREAD_Y) && defined(COLS_A)
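How a dot8 kernel gets selected is a host-side concern: ARM_COMPUTE_OPENCL_DOT8_ENABLED and the reshape parameters are injected as program build options. The sketch below is a plausible host-side illustration, not code from this patch; the -DCOLS_B, -DTRANSPOSE1XW_WIDTH_STEP and -DMULT_INTERLEAVE4X4_HEIGHT options are the ones documented in the kernel comment above, while the mapping from the device-extension query to the flag is an assumption:

// Hypothetical host-side sketch (plain C, OpenCL host API): build gemmlowp.cl
// with the options the dot8 kernels document. Assumes `device` and `program`
// are already created; none of this is taken from the patch itself.
#include <CL/cl.h>
#include <stdio.h>
#include <string.h>

static int device_has_dot8(cl_device_id device)
{
    char extensions[4096];
    clGetDeviceInfo(device, CL_DEVICE_EXTENSIONS, sizeof(extensions), extensions, NULL);
    // Assumed mapping: enable the flag when the accumulate variant of the
    // Arm 8-bit dot product extension is reported by the device.
    return strstr(extensions, "cl_arm_integer_dot_product_accumulate_int8") != NULL;
}

static cl_int build_gemmlowp(cl_program program, cl_device_id device, int cols_b)
{
    char options[256];
    snprintf(options, sizeof(options),
             "-DCOLS_B=%d -DTRANSPOSE1XW_WIDTH_STEP=2 -DMULT_INTERLEAVE4X4_HEIGHT=2 "
             "-DARM_COMPUTE_OPENCL_DOT8_ENABLED=%d",
             cols_b, device_has_dot8(device));
    return clBuildProgram(program, 1, &device, options, NULL, NULL);
}

With the flag set, the host would create gemmlowp_mm_interleaved_transposed_bifrost_dot8 instead of gemmlowp_mm_interleaved_transposed_bifrost; splitting the two paths into separate kernels (rather than an #if inside one kernel body) keeps each binary specialized for its device.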
@@ -724,13 +833,6 @@ __kernel void gemmlowp_mm_bifrost(IMAGE_DECLARATION(src0),
     uchar4 b3 = vload4(0, src1_ptr + src_addr.s1 + 3 * src1_stride_y);
 
     {
-#if ARM_COMPUTE_OPENCL_DOT8_ENABLED
-        // Accumulate
-        acc00 += arm_dot((uchar4)(b0.s0, b1.s0, b2.s0, b3.s0), a0);
-        acc01 += arm_dot((uchar4)(b0.s1, b1.s1, b2.s1, b3.s1), a0);
-        acc02 += arm_dot((uchar4)(b0.s2, b1.s2, b2.s2, b3.s2), a0);
-        acc03 += arm_dot((uchar4)(b0.s3, b1.s3, b2.s3, b3.s3), a0);
-#else // ARM_COMPUTE_OPENCL_DOT8_ENABLED
         // Accumulate
         ushort tmp0 = (ushort)b0.s0 * (ushort)a0.s0;
         ushort tmp1 = (ushort)b0.s1 * (ushort)a0.s0;
@@ -756,17 +858,9 @@ __kernel void gemmlowp_mm_bifrost(IMAGE_DECLARATION(src0),
         acc01 += ((uint)tmp1 + (uint)tmp5 + (uint)tmp9 + (uint)tmpD);
         acc02 += ((uint)tmp2 + (uint)tmp6 + (uint)tmpA + (uint)tmpE);
         acc03 += ((uint)tmp3 + (uint)tmp7 + (uint)tmpB + (uint)tmpF);
-#endif // ARM_COMPUTE_OPENCL_DOT8_ENABLED
     }
 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
     {
-#if ARM_COMPUTE_OPENCL_DOT8_ENABLED
-        // Accumulate
-        acc10 += arm_dot((uchar4)(b0.s0, b1.s0, b2.s0, b3.s0), a1);
-        acc11 += arm_dot((uchar4)(b0.s1, b1.s1, b2.s1, b3.s1), a1);
-        acc12 += arm_dot((uchar4)(b0.s2, b1.s2, b2.s2, b3.s2), a1);
-        acc13 += arm_dot((uchar4)(b0.s3, b1.s3, b2.s3, b3.s3), a1);
-#else // ARM_COMPUTE_OPENCL_DOT8_ENABLED
         // Accumulate
         ushort tmp0 = (ushort)b0.s0 * (ushort)a1.s0;
         ushort tmp1 = (ushort)b0.s1 * (ushort)a1.s0;
@@ -792,18 +886,10 @@ __kernel void gemmlowp_mm_bifrost(IMAGE_DECLARATION(src0),
         acc11 += ((uint)tmp1 + (uint)tmp5 + (uint)tmp9 + (uint)tmpD);
         acc12 += ((uint)tmp2 + (uint)tmp6 + (uint)tmpA + (uint)tmpE);
         acc13 += ((uint)tmp3 + (uint)tmp7 + (uint)tmpB + (uint)tmpF);
-#endif // ARM_COMPUTE_OPENCL_DOT8_ENABLED
     }
 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
     {
-#if ARM_COMPUTE_OPENCL_DOT8_ENABLED
-        // Accumulate
-        acc20 += arm_dot((uchar4)(b0.s0, b1.s0, b2.s0, b3.s0), a2);
-        acc21 += arm_dot((uchar4)(b0.s1, b1.s1, b2.s1, b3.s1), a2);
-        acc22 += arm_dot((uchar4)(b0.s2, b1.s2, b2.s2, b3.s2), a2);
-        acc23 += arm_dot((uchar4)(b0.s3, b1.s3, b2.s3, b3.s3), a2);
-#else // ARM_COMPUTE_OPENCL_DOT8_ENABLED
         // Accumulate
         ushort tmp0 = (ushort)b0.s0 * (ushort)a2.s0;
         ushort tmp1 = (ushort)b0.s1 * (ushort)a2.s0;
@@ -829,18 +915,10 @@ __kernel void gemmlowp_mm_bifrost(IMAGE_DECLARATION(src0),
         acc21 += ((uint)tmp1 + (uint)tmp5 + (uint)tmp9 + (uint)tmpD);
         acc22 += ((uint)tmp2 + (uint)tmp6 + (uint)tmpA + (uint)tmpE);
         acc23 += ((uint)tmp3 + (uint)tmp7 + (uint)tmpB + (uint)tmpF);
-#endif // ARM_COMPUTE_OPENCL_DOT8_ENABLED
     }
 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
     {
-#if ARM_COMPUTE_OPENCL_DOT8_ENABLED
-        // Accumulate
-        acc30 += arm_dot((uchar4)(b0.s0, b1.s0, b2.s0, b3.s0), a3);
-        acc31 += arm_dot((uchar4)(b0.s1, b1.s1, b2.s1, b3.s1), a3);
-        acc32 += arm_dot((uchar4)(b0.s2, b1.s2, b2.s2, b3.s2), a3);
-        acc33 += arm_dot((uchar4)(b0.s3, b1.s3, b2.s3, b3.s3), a3);
-#else // ARM_COMPUTE_OPENCL_DOT8_ENABLED
         // Accumulate
         ushort tmp0 = (ushort)b0.s0 * (ushort)a3.s0;
         ushort tmp1 = (ushort)b0.s1 * (ushort)a3.s0;
@@ -866,18 +944,10 @@ __kernel void gemmlowp_mm_bifrost(IMAGE_DECLARATION(src0),
         acc31 += ((uint)tmp1 + (uint)tmp5 + (uint)tmp9 + (uint)tmpD);
         acc32 += ((uint)tmp2 + (uint)tmp6 + (uint)tmpA + (uint)tmpE);
         acc33 += ((uint)tmp3 + (uint)tmp7 + (uint)tmpB + (uint)tmpF);
-#endif // ARM_COMPUTE_OPENCL_DOT8_ENABLED
     }
 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 4
     {
-#if ARM_COMPUTE_OPENCL_DOT8_ENABLED
-        // Accumulate
-        acc40 += arm_dot((uchar4)(b0.s0, b1.s0, b2.s0, b3.s0), a4);
-        acc41 += arm_dot((uchar4)(b0.s1, b1.s1, b2.s1, b3.s1), a4);
-        acc42 += arm_dot((uchar4)(b0.s2, b1.s2, b2.s2, b3.s2), a4);
-        acc43 += arm_dot((uchar4)(b0.s3, b1.s3, b2.s3, b3.s3), a4);
-#else // ARM_COMPUTE_OPENCL_DOT8_ENABLED
         // Accumulate
         ushort tmp0 = (ushort)b0.s0 * (ushort)a4.s0;
         ushort tmp1 = (ushort)b0.s1 * (ushort)a4.s0;
@@ -903,7 +973,6 @@ __kernel void gemmlowp_mm_bifrost(IMAGE_DECLARATION(src0),
         acc41 += ((uint)tmp1 + (uint)tmp5 + (uint)tmp9 + (uint)tmpD);
         acc42 += ((uint)tmp2 + (uint)tmp6 + (uint)tmpA + (uint)tmpE);
         acc43 += ((uint)tmp3 + (uint)tmp7 + (uint)tmpB + (uint)tmpF);
-#endif // ARM_COMPUTE_OPENCL_DOT8_ENABLED
     }
 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 4
 }
@@ -1016,6 +1085,254 @@ __kernel void gemmlowp_mm_bifrost(IMAGE_DECLARATION(src0),
     vstore4((int4)(acc40, acc41, acc42, acc43), 0, (__global int *)(offset(&dst, 0, 4)));
 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 4
 }
+
+#if ARM_COMPUTE_OPENCL_DOT8_ENABLED
+/** OpenCL kernel optimized to use dot product that computes the matrix multiplication between matrix A (src0) and matrix B (src1) in case both matrices have not been reshaped
+ *
+ * @attention The number of matrix A columns needs to be passed at compile time using -DCOLS_A
+ *
+ * @param[in]  src0_ptr                           Pointer to the source matrix. Supported data type: QASYMM8
+ * @param[in]  src0_stride_x                      Stride of the source matrix in X dimension (in bytes)
+ * @param[in]  src0_step_x                        src_stride_x * number of elements along X processed per workitem (in bytes)
+ * @param[in]  src0_stride_y                      Stride of the source matrix in Y dimension (in bytes)
+ * @param[in]  src0_step_y                        src_stride_y * number of elements along Y processed per workitem (in bytes)
+ * @param[in]  src0_offset_first_element_in_bytes The offset of the first element in the source matrix
+ * @param[in]  src1_ptr                           Pointer to the source matrix. Supported data type: same as @p src0_ptr
+ * @param[in]  src1_stride_x                      Stride of the source matrix in X dimension (in bytes)
+ * @param[in]  src1_step_x                        src_stride_x * number of elements along X processed per workitem (in bytes)
+ * @param[in]  src1_stride_y                      Stride of the source matrix in Y dimension (in bytes)
+ * @param[in]  src1_step_y                        src_stride_y * number of elements along Y processed per workitem (in bytes)
+ * @param[in]  src1_offset_first_element_in_bytes The offset of the first element in the source matrix
+ * @param[out] dst_ptr                            Pointer to the destination matrix. Supported data type: S32
+ * @param[in]  dst_stride_x                       Stride of the destination matrix in X dimension (in bytes)
+ * @param[in]  dst_step_x                         dst_gx_stride_x * number of elements along X processed per workitem (in bytes)
+ * @param[in]  dst_stride_y                       Stride of the destination matrix in Y dimension (in bytes)
+ * @param[in]  dst_step_y                         dst_gx_stride_y * number of elements along Y processed per workitem (in bytes)
+ * @param[in]  dst_offset_first_element_in_bytes  The offset of the first element in the destination matrix
+ */
+__kernel void gemmlowp_mm_bifrost_dot8(IMAGE_DECLARATION(src0),
+                                       IMAGE_DECLARATION(src1),
+                                       IMAGE_DECLARATION(dst))
+{
+    int idx = get_global_id(0) * NUM_ELEMS_PROCESSED_PER_THREAD_X;
+
+    // Compute starting address for matrix A and Matrix B
+    int2 src_addr = ((int2)(src0_offset_first_element_in_bytes, src1_offset_first_element_in_bytes));
+
+    // Update address for the matrix A
+    src_addr.s0 += get_global_id(1) * src0_stride_y * NUM_ELEMS_PROCESSED_PER_THREAD_Y;
+
+    // Update address for the matrix B
+    src_addr.s1 += idx;
+
+    int end_row_vec_a = src_addr.s0 + COLS_A;
+
+    uint acc00 = 0;
+    uint acc01 = 0;
+    uint acc02 = 0;
+    uint acc03 = 0;
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
+    uint acc10 = 0;
+    uint acc11 = 0;
+    uint acc12 = 0;
+    uint acc13 = 0;
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
+    uint acc20 = 0;
+    uint acc21 = 0;
+    uint acc22 = 0;
+    uint acc23 = 0;
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
+    uint acc30 = 0;
+    uint acc31 = 0;
+    uint acc32 = 0;
+    uint acc33 = 0;
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 4
+    uint acc40 = 0;
+    uint acc41 = 0;
+    uint acc42 = 0;
+    uint acc43 = 0;
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 4
+
+    for(; src_addr.s0 <= (end_row_vec_a - 4); src_addr += (int2)(4, 4 * src1_stride_y))
+    {
+        // Load values from matrix A
+        uchar4 a0 = vload4(0, src0_ptr + src_addr.s0 + 0 * src0_stride_y);
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
+        uchar4 a1 = vload4(0, src0_ptr + src_addr.s0 + 1 * src0_stride_y);
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
+        uchar4 a2 = vload4(0, src0_ptr + src_addr.s0 + 2 * src0_stride_y);
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
+        uchar4 a3 = vload4(0, src0_ptr + src_addr.s0 + 3 * src0_stride_y);
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 4
+        uchar4 a4 = vload4(0, src0_ptr + src_addr.s0 + 4 * src0_stride_y);
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 4
+        // Load values from matrix B
+        uchar4 b0 = vload4(0, src1_ptr + src_addr.s1 + 0 * src1_stride_y);
+        uchar4 b1 = vload4(0, src1_ptr + src_addr.s1 + 1 * src1_stride_y);
+        uchar4 b2 = vload4(0, src1_ptr + src_addr.s1 + 2 * src1_stride_y);
+        uchar4 b3 = vload4(0, src1_ptr + src_addr.s1 + 3 * src1_stride_y);
+
+        {
+            // Accumulate
+            acc00 = arm_dot_acc((uchar4)(b0.s0, b1.s0, b2.s0, b3.s0), a0, acc00);
+            acc01 = arm_dot_acc((uchar4)(b0.s1, b1.s1, b2.s1, b3.s1), a0, acc01);
+            acc02 = arm_dot_acc((uchar4)(b0.s2, b1.s2, b2.s2, b3.s2), a0, acc02);
+            acc03 = arm_dot_acc((uchar4)(b0.s3, b1.s3, b2.s3, b3.s3), a0, acc03);
+        }
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
+        {
+            // Accumulate
+            acc10 = arm_dot_acc((uchar4)(b0.s0, b1.s0, b2.s0, b3.s0), a1, acc10);
+            acc11 = arm_dot_acc((uchar4)(b0.s1, b1.s1, b2.s1, b3.s1), a1, acc11);
+            acc12 = arm_dot_acc((uchar4)(b0.s2, b1.s2, b2.s2, b3.s2), a1, acc12);
+            acc13 = arm_dot_acc((uchar4)(b0.s3, b1.s3, b2.s3, b3.s3), a1, acc13);
+        }
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
+        {
+            // Accumulate
+            acc20 = arm_dot_acc((uchar4)(b0.s0, b1.s0, b2.s0, b3.s0), a2, acc20);
+            acc21 = arm_dot_acc((uchar4)(b0.s1, b1.s1, b2.s1, b3.s1), a2, acc21);
+            acc22 = arm_dot_acc((uchar4)(b0.s2, b1.s2, b2.s2, b3.s2), a2, acc22);
+            acc23 = arm_dot_acc((uchar4)(b0.s3, b1.s3, b2.s3, b3.s3), a2, acc23);
+        }
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
+        {
+            // Accumulate
+            acc30 = arm_dot_acc((uchar4)(b0.s0, b1.s0, b2.s0, b3.s0), a3, acc30);
+            acc31 = arm_dot_acc((uchar4)(b0.s1, b1.s1, b2.s1, b3.s1), a3, acc31);
+            acc32 = arm_dot_acc((uchar4)(b0.s2, b1.s2, b2.s2, b3.s2), a3, acc32);
+            acc33 = arm_dot_acc((uchar4)(b0.s3, b1.s3, b2.s3, b3.s3), a3, acc33);
+        }
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 4
+        {
+            // Accumulate
+            acc40 = arm_dot_acc((uchar4)(b0.s0, b1.s0, b2.s0, b3.s0), a4, acc40);
+            acc41 = arm_dot_acc((uchar4)(b0.s1, b1.s1, b2.s1, b3.s1), a4, acc41);
+            acc42 = arm_dot_acc((uchar4)(b0.s2, b1.s2, b2.s2, b3.s2), a4, acc42);
+            acc43 = arm_dot_acc((uchar4)(b0.s3, b1.s3, b2.s3, b3.s3), a4, acc43);
+        }
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 4
+    }
+
+    for(; src_addr.s0 < end_row_vec_a; src_addr += (int2)(1, src1_stride_y))
+    {
+        // Load values from matrix A
+        uchar a0 = *(src0_ptr + src_addr.s0 + 0 * src0_stride_y);
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
+        uchar a1 = *(src0_ptr + src_addr.s0 + 1 * src0_stride_y);
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
+        uchar a2 = *(src0_ptr + src_addr.s0 + 2 * src0_stride_y);
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
+        uchar a3 = *(src0_ptr + src_addr.s0 + 3 * src0_stride_y);
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 4
+        uchar a4 = *(src0_ptr + src_addr.s0 + 4 * src0_stride_y);
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 4
+        // Load values from matrix B
+        uchar4 b0 = vload4(0, src1_ptr + src_addr.s1);
+
+        // Accumulate
+        {
+            // Accumulate
+            ushort tmp0 = (ushort)b0.s0 * (ushort)a0;
+            ushort tmp1 = (ushort)b0.s1 * (ushort)a0;
+            ushort tmp2 = (ushort)b0.s2 * (ushort)a0;
+            ushort tmp3 = (ushort)b0.s3 * (ushort)a0;
+
+            acc00 += ((uint)tmp0);
+            acc01 += ((uint)tmp1);
+            acc02 += ((uint)tmp2);
+            acc03 += ((uint)tmp3);
+        }
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
+        {
+            // Accumulate
+            ushort tmp0 = (ushort)b0.s0 * (ushort)a1;
+            ushort tmp1 = (ushort)b0.s1 * (ushort)a1;
+            ushort tmp2 = (ushort)b0.s2 * (ushort)a1;
+            ushort tmp3 = (ushort)b0.s3 * (ushort)a1;
+
+            acc10 += ((uint)tmp0);
+            acc11 += ((uint)tmp1);
+            acc12 += ((uint)tmp2);
+            acc13 += ((uint)tmp3);
+        }
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
+        {
+            // Accumulate
+            ushort tmp0 = (ushort)b0.s0 * (ushort)a2;
+            ushort tmp1 = (ushort)b0.s1 * (ushort)a2;
+            ushort tmp2 = (ushort)b0.s2 * (ushort)a2;
+            ushort tmp3 = (ushort)b0.s3 * (ushort)a2;
+
+            acc20 += ((uint)tmp0);
+            acc21 += ((uint)tmp1);
+            acc22 += ((uint)tmp2);
+            acc23 += ((uint)tmp3);
+        }
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
+        {
+            // Accumulate
+            ushort tmp0 = (ushort)b0.s0 * (ushort)a3;
+            ushort tmp1 = (ushort)b0.s1 * (ushort)a3;
+            ushort tmp2 = (ushort)b0.s2 * (ushort)a3;
+            ushort tmp3 = (ushort)b0.s3 * (ushort)a3;
+
+            acc30 += ((uint)tmp0);
+            acc31 += ((uint)tmp1);
+            acc32 += ((uint)tmp2);
+            acc33 += ((uint)tmp3);
+        }
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 4
+        {
+            // Accumulate
+            ushort tmp0 = (ushort)b0.s0 * (ushort)a4;
+            ushort tmp1 = (ushort)b0.s1 * (ushort)a4;
+            ushort tmp2 = (ushort)b0.s2 * (ushort)a4;
+            ushort tmp3 = (ushort)b0.s3 * (ushort)a4;
+
+            acc40 += ((uint)tmp0);
+            acc41 += ((uint)tmp1);
+            acc42 += ((uint)tmp2);
+            acc43 += ((uint)tmp3);
+        }
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 4
+    }
+
+    // Compute destination address
+    Image dst = CONVERT_TO_IMAGE_STRUCT(dst);
+
+    // Store the result
+    vstore4((int4)(acc00, acc01, acc02, acc03), 0, (__global int *)(offset(&dst, 0, 0)));
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
+    vstore4((int4)(acc10, acc11, acc12, acc13), 0, (__global int *)(offset(&dst, 0, 1)));
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
+    vstore4((int4)(acc20, acc21, acc22, acc23), 0, (__global int *)(offset(&dst, 0, 2)));
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
+    vstore4((int4)(acc30, acc31, acc32, acc33), 0, (__global int *)(offset(&dst, 0, 3)));
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 4
+    vstore4((int4)(acc40, acc41, acc42, acc43), 0, (__global int *)(offset(&dst, 0, 4)));
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 4
+}
+#endif // ARM_COMPUTE_OPENCL_DOT8_ENABLED
+
 #endif // defined(NUM_ELEMS_PROCESSED_PER_THREAD_X) && defined(NUM_ELEMS_PROCESSED_PER_THREAD_Y) && defined(COLS_A)
 
 #if defined(COLS_A)
--
cgit v1.2.1
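A closing note on portability: arm_dot_acc is only usable when the driver exposes cl_arm_integer_dot_product_accumulate_int8, and OpenCL extensions generally have to be enabled with a pragma before use. The guard below is an illustrative sketch of how a translation unit could make that explicit; it is not part of the patch, which instead relies on ARM_COMPUTE_OPENCL_DOT8_ENABLED being defined only for capable devices:

// Illustrative guard (not from the patch): enable the Arm dot product
// extension before any kernel that calls arm_dot_acc, and fail the build
// loudly if the flag is set on a device that lacks the extension.
#if ARM_COMPUTE_OPENCL_DOT8_ENABLED
#if defined(cl_arm_integer_dot_product_accumulate_int8)
#pragma OPENCL EXTENSION cl_arm_integer_dot_product_accumulate_int8 : enable
#else // defined(cl_arm_integer_dot_product_accumulate_int8)
#error "ARM_COMPUTE_OPENCL_DOT8_ENABLED requires cl_arm_integer_dot_product_accumulate_int8"
#endif // defined(cl_arm_integer_dot_product_accumulate_int8)
#endif // ARM_COMPUTE_OPENCL_DOT8_ENABLED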