From fd683111bba15288dc88b7f53486f935ebeccde0 Mon Sep 17 00:00:00 2001 From: Gian Marco Iodice Date: Tue, 17 Apr 2018 09:52:44 +0100 Subject: COMPMID-922 - CLGEMM FP16 optimizations - part1 This patch improves of ~20% GEMM fp16. The results has been reported at the following confluence page: https://confluence.arm.com/display/MLENG/GEMM+FP32+performance%3A+ACL+18.05 I am aware with few cases we have a bit of degradation. However this cases are memory bound anyway (Fully connected layer cases) Change-Id: I183cbb7fba55a0b5eb86532c4dca5efe096096b0 Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/128044 Reviewed-by: Georgios Pinitas Tested-by: Jenkins Reviewed-by: Anthony Barbier --- src/core/CL/cl_kernels/gemm.cl | 205 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 205 insertions(+) (limited to 'src/core/CL/cl_kernels/gemm.cl') diff --git a/src/core/CL/cl_kernels/gemm.cl b/src/core/CL/cl_kernels/gemm.cl index 00b130f5a9..4b1672ce7b 100644 --- a/src/core/CL/cl_kernels/gemm.cl +++ b/src/core/CL/cl_kernels/gemm.cl @@ -1531,6 +1531,211 @@ __kernel void gemm_mm_floating_point_f32_bifrost_1000(IMAGE_DECLARATION(src0), #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 } +/** This OpenCL kernel computes the matrix by matrix multiplication between the matrix A (src0) and matrix B (src1) in case both matrices have not beed reshaped + * + * @note This OpenCL kernel works with the 16-bit floating point data type (half) and uses the fma units. + * @note The number of elements processed along the x and y directions must be passed at compile time using -DNUM_ELEMS_PROCESSED_PER_THREAD_X and -DNUM_ELEMS_PROCESSED_PER_THREAD_Y. + * This kernel optimally uses -DNUM_ELEMS_PROCESSED_PER_THREAD_X=4. + * @note The number of matrix A columns must be passed at compile time using -DCOLS_A. + * @note The optional value of scalar alpha is passed at compile time using -DALPHA=alpha + * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (i.e. -DMATRIX_B_DEPTH=16) + * This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (i.e. a = [K, M, 16, Batches], b = [N, K, 16]) + * + * @param[in] src0_ptr Pointer to the source matrix. Supported data types: F16 + * @param[in] src0_stride_x Stride of the source matrix in X dimension (in bytes) + * @param[in] src0_step_x src_stride_x * number of elements along X processed per workitem(in bytes) + * @param[in] src0_stride_y Stride of the source matrix in Y dimension (in bytes) + * @param[in] src0_step_y src_stride_y * number of elements along Y processed per workitem(in bytes) + * @param[in] src0_offset_first_element_in_bytes The offset of the first element in the source matrix + * @param[in] src1_ptr Pointer to the source matrix. Supported data types: same as @p src0_ptr + * @param[in] src1_stride_x Stride of the source matrix in X dimension (in bytes) + * @param[in] src1_step_x src_stride_x * number of elements along X processed per workitem(in bytes) + * @param[in] src1_stride_y Stride of the source matrix in Y dimension (in bytes) + * @param[in] src1_step_y src_stride_y * number of elements along Y processed per workitem(in bytes) + * @param[in] src1_offset_first_element_in_bytes The offset of the first element in the source matrix + * @param[out] dst_ptr Pointer to the destination matrix Supported data types: same as @p src0_ptr + * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes) + * @param[in] dst_step_x dst_gx_stride_x * number of elements along X processed per workitem(in bytes) + * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes) + * @param[in] dst_step_y dst_gx_stride_y * number of elements along Y processed per workitem(in bytes) + * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix + */ +__kernel void gemm_mm_floating_point_f16_bifrost(IMAGE_DECLARATION(src0), + IMAGE_DECLARATION(src1), + IMAGE_DECLARATION(dst), + uint src0_stride_z, + uint src1_stride_z, + uint dst_stride_z) +{ + int idx = get_global_id(0) * NUM_ELEMS_PROCESSED_PER_THREAD_X; + + // Compute starting address for matrix A and Matrix B + int2 src_addr = ((int2)(src0_offset_first_element_in_bytes, src1_offset_first_element_in_bytes)); + + // Update address for the matrix A + src_addr.s0 += get_global_id(1) * src0_stride_y * NUM_ELEMS_PROCESSED_PER_THREAD_Y; + + // Update address for the matrix B + src_addr.s1 += idx * sizeof(half); + + // Add offset for batched GEMM + src_addr.s0 += get_global_id(2) * src0_stride_z; + +#if defined(MATRIX_B_DEPTH) + // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3 + src_addr.s1 += (get_global_id(2) % MATRIX_B_DEPTH) * src1_stride_z; +#else // defined(MATRIX_B_DEPTH) + src_addr.s1 += get_global_id(2) * src1_stride_z; +#endif // defined(MATRIX_B_DEPTH) + + half8 acc0 = 0.0h; +#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 + half8 acc1 = 0.0h; +#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 +#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 + half8 acc2 = 0.0h; +#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 +#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 + half8 acc3 = 0.0h; +#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 + + int i = 0; + for(; i <= ((int)COLS_A - 4); i += 4) + { + // Load values from matrix A + half4 a0 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y)); +#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 + half4 a1 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y)); +#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 +#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 + half4 a2 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y)); +#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 +#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 + half4 a3 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y)); +#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 + // Load values from matrix B + half8 b0 = vload8(0, (__global half *)(src1_ptr + src_addr.s1)); + src_addr.s1 += src1_stride_y; + + // Accumulate + acc0 = fma(b0, (half8)a0.s0, acc0); +#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 + acc1 = fma(b0, (half8)a1.s0, acc1); +#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 +#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 + acc2 = fma(b0, (half8)a2.s0, acc2); +#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 +#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 + acc3 = fma(b0, (half8)a3.s0, acc3); +#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 + + b0 = vload8(0, (__global half *)(src1_ptr + src_addr.s1)); + src_addr.s1 += src1_stride_y; + acc0 = fma(b0, (half8)a0.s1, acc0); +#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 + acc1 = fma(b0, (half8)a1.s1, acc1); +#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 +#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 + acc2 = fma(b0, (half8)a2.s1, acc2); +#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 +#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 + acc3 = fma(b0, (half8)a3.s1, acc3); +#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 + + b0 = vload8(0, (__global half *)(src1_ptr + src_addr.s1)); + src_addr.s1 += src1_stride_y; + acc0 = fma(b0, (half8)a0.s2, acc0); +#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 + acc1 = fma(b0, (half8)a1.s2, acc1); +#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 +#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 + acc2 = fma(b0, (half8)a2.s2, acc2); +#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 +#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 + acc3 = fma(b0, (half8)a3.s2, acc3); +#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 + + b0 = vload8(0, (__global half *)(src1_ptr + src_addr.s1)); + src_addr.s1 += src1_stride_y; + acc0 = fma(b0, (half8)a0.s3, acc0); +#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 + acc1 = fma(b0, (half8)a1.s3, acc1); +#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 +#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 + acc2 = fma(b0, (half8)a2.s3, acc2); +#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 +#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 + acc3 = fma(b0, (half8)a3.s3, acc3); +#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 + + src_addr.s0 += 4 * sizeof(half); + } + + for(; i < (int)COLS_A; ++i) + { + // Load values from matrix A + half a0 = *((__global half *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y)); +#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 + half a1 = *((__global half *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y)); +#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 +#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 + half a2 = *((__global half *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y)); +#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 +#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 + half a3 = *((__global half *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y)); +#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 + // Load values from matrix B + half8 b0 = vload8(0, (__global half *)(src1_ptr + src_addr.s1)); + + src_addr += (int2)(sizeof(half), src1_stride_y); + + // Accumulate + acc0 = fma(b0, (half8)a0, acc0); // b0 * (half8)a0; +#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 + acc1 = fma(b0, (half8)a1, acc1); // b0 * (half8)a1; +#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 +#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 + acc2 = fma(b0, (half8)a2, acc2); // b0 * (half8)a2; +#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 +#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 + acc3 = fma(b0, (half8)a3, acc3); // b0 * (half8)a3; +#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 + } + + // Compute destination address + Image dst = CONVERT_TO_IMAGE_STRUCT(dst); + + // Compute dst address + __global uchar *dst_addr = offset(&dst, 0, 0); + + // Add offset for batched GEMM + dst_addr += get_global_id(2) * dst_stride_z; + + // Multiply by the weight of matrix-matrix product and store the result +#if defined(ALPHA) + acc0 = acc0 * (half8)ALPHA; +#endif // defined(ALPHA) + vstore8(acc0, 0, (__global half *)(dst_addr + 0 * dst_stride_y)); +#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 +#if defined(ALPHA) + acc1 = acc1 * (half8)ALPHA; +#endif // defined(ALPHA) + vstore8(acc1, 0, (__global half *)(dst_addr + 1 * dst_stride_y)); +#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 +#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 +#if defined(ALPHA) + acc2 = acc2 * (half8)ALPHA; +#endif // defined(ALPHA) + vstore8(acc2, 0, (__global half *)(dst_addr + 2 * dst_stride_y)); +#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 +#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 +#if defined(ALPHA) + acc3 = acc3 * (half8)ALPHA; +#endif // defined(ALPHA) + vstore8(acc3, 0, (__global half *)(dst_addr + 3 * dst_stride_y)); +#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 +} + #if defined(FIXED_POINT_POSITION) /** This OpenCL kernel computes the matrix by matrix multiplication between the matrix A (src0) and matrix B (src1) in case both matrices have not beed reshaped * -- cgit v1.2.1