From 06be6f8d2a316a307fa623150f8adf8f9c3416c5 Mon Sep 17 00:00:00 2001 From: Gian Marco Iodice Date: Mon, 24 Jun 2019 17:47:51 +0100 Subject: COMPMID-2096: Refactor the CLGEMMLowp function selection (heuristic) Change-Id: I15a8b39e0354d3b6686ed4cc8c361782c0512037 Signed-off-by: Gian Marco Iodice Reviewed-on: https://review.mlplatform.org/c/1410 Comments-Addressed: Arm Jenkins Tested-by: Arm Jenkins Reviewed-by: VidhyaSudhan Loganathan --- src/core/CL/cl_kernels/gemmlowp.cl | 1040 ------------------------------------ 1 file changed, 1040 deletions(-) (limited to 'src/core/CL/cl_kernels/gemmlowp.cl') diff --git a/src/core/CL/cl_kernels/gemmlowp.cl b/src/core/CL/cl_kernels/gemmlowp.cl index d6494fe380..fc90dbd16c 100644 --- a/src/core/CL/cl_kernels/gemmlowp.cl +++ b/src/core/CL/cl_kernels/gemmlowp.cl @@ -193,168 +193,6 @@ (n0, k0, a, b, c); \ }) -#if defined(COLS_B) && defined(MULT_INTERLEAVE4X4_HEIGHT) && defined(TRANSPOSE1XW_WIDTH_STEP) -/** This OpenCL kernel computes the matrix multiplication between matrix A (src0) and matrix B (src1) - * Matrix A and matrix B must be reshaped respectively with @ref CLGEMMReshapeLHSMatrixKernel and @ref CLGEMMReshapeRHSMatrixKernel before running the matrix multiplication - * - * @note The number of matrix B columns needs to be passed at compile time using -DCOLS_B: e.g. -DCOLS_B=1024 - * @note The transposition width step (mult_transpose1xW_width * 4) must be passed at compile time using -DTRANSPOSE1XW_WIDTH_STEP (i.e. -DTRANSPOSE1XW_WIDTH_STEP=2) - * @note The multiplication factor for the height of the 4x4 interleaved block must be passed at compile time using -DMULT_INTERLEAVE4X4_HEIGHT (i.e. -DMULT_INTERLEAVE4X4_HEIGHT=2) - * - * @note In case the output has to be reinterpreted as a 3D tensor (i.e. output of convolution layer), the following information must be passed at compile time: - * -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D - * -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor. - * -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor - * (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns matrix A NOT reshaped - * - * @param[in] src0_ptr Pointer to the source matrix. Supported data type: QASYMM8 - * @param[in] src0_stride_x Stride of the source matrix in X dimension (in bytes) - * @param[in] src0_step_x src_stride_x * number of elements along X processed per workitem(in bytes) - * @param[in] src0_stride_y Stride of the source matrix in Y dimension (in bytes) - * @param[in] src0_step_y src_stride_y * number of elements along Y processed per workitem(in bytes) - * @param[in] src0_offset_first_element_in_bytes The offset of the first element in the source matrix - * @param[in] src1_ptr Pointer to the source matrix. Supported data type: same as @p src0_ptr - * @param[in] src1_stride_x Stride of the source matrix in X dimension (in bytes) - * @param[in] src1_step_x src_stride_x * number of elements along X processed per workitem(in bytes) - * @param[in] src1_stride_y Stride of the source matrix in Y dimension (in bytes) - * @param[in] src1_step_y src_stride_y * number of elements along Y processed per workitem(in bytes) - * @param[in] src1_offset_first_element_in_bytes The offset of the first element in the source matrix - * @param[out] dst_ptr Pointer to the destination matrix Supported data type: S32 - * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes) - * @param[in] dst_step_x dst_gx_stride_x * number of elements along X processed per workitem(in bytes) - * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes) - * @param[in] dst_step_y dst_gx_stride_y * number of elements along Y processed per workitem(in bytes) - * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix - * @param[in] src0_stride_z Stride of the source matrix in Z dimension (in bytes) - * @param[in] src1_stride_z Stride of the source matrix in Z dimension (in bytes) - * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes) - * @param[in] cross_plane_pad (Optional) Bottom paddings in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D) - */ -__kernel void gemmlowp_mm_interleaved_transposed_midgard(IMAGE_DECLARATION(src0), - IMAGE_DECLARATION(src1), - IMAGE_DECLARATION(dst), - uint src0_stride_z, - uint src1_stride_z, - uint dst_stride_z -#if defined(REINTERPRET_OUTPUT_AS_3D) - , - uint cross_plane_pad -#endif // REINTERPRET_OUTPUT_AS_3D - ) -{ - const int x = get_global_id(0) / TRANSPOSE1XW_WIDTH_STEP; - const int y = get_global_id(1) / MULT_INTERLEAVE4X4_HEIGHT; - const int z = get_global_id(2); - - // Offset - const int offset_row_a = (get_global_id(1) % MULT_INTERLEAVE4X4_HEIGHT) * 4; - const int offset_row_b = (get_global_id(0) % TRANSPOSE1XW_WIDTH_STEP) * 4; - - // src_addr_a = address of matrix A - // src_addr_b = address of matrix B - __global uchar *src_addr_a = (__global uchar *)(src0_ptr + z * src0_stride_z + y * src0_stride_y + src0_offset_first_element_in_bytes); - __global uchar *src_addr_b = (__global uchar *)(src1_ptr + x * src1_stride_y + src1_offset_first_element_in_bytes); - -#if defined(MATRIX_B_DEPTH) - // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3 - src_addr_b += (z % MATRIX_B_DEPTH) * src1_stride_z; -#else // defined(MATRIX_B_DEPTH) - src_addr_b += z * src1_stride_z; -#endif // defined(MATRIX_B_DEPTH) - - // Compute end row address for matrix B - __global uchar *src_end_addr_b = src_addr_b + COLS_B; - - src_addr_a += offset_row_a; - src_addr_b += offset_row_b; - - // Reset accumulators - int4 c00 = 0; - int4 c10 = 0; - int4 c20 = 0; - int4 c30 = 0; - - for(; src_addr_b <= (src_end_addr_b - (int)(8 * TRANSPOSE1XW_WIDTH_STEP)); src_addr_a += 8 * MULT_INTERLEAVE4X4_HEIGHT, src_addr_b += 8 * TRANSPOSE1XW_WIDTH_STEP) - { - // Load values from matrix A (interleaved) and matrix B (transposed) - int4 a0 = convert_int4(vload4(0, src_addr_a)); - int4 b0 = convert_int4(vload4(0, src_addr_b)); - - c00 += (int4)a0.s0 * b0; - c10 += (int4)a0.s1 * b0; - c20 += (int4)a0.s2 * b0; - c30 += (int4)a0.s3 * b0; - - a0 = convert_int4(vload4(0, src_addr_a + 4 * MULT_INTERLEAVE4X4_HEIGHT)); - b0 = convert_int4(vload4(0, src_addr_b + 4 * TRANSPOSE1XW_WIDTH_STEP)); - - c00 += (int4)a0.s0 * b0; - c10 += (int4)a0.s1 * b0; - c20 += (int4)a0.s2 * b0; - c30 += (int4)a0.s3 * b0; - } - - for(; src_addr_b < src_end_addr_b; src_addr_a += (4 * MULT_INTERLEAVE4X4_HEIGHT), src_addr_b += (4 * TRANSPOSE1XW_WIDTH_STEP)) - { - // Load values from matrix A (interleaved) and matrix B (transposed) - int4 a0 = convert_int4(vload4(0, src_addr_a)); - int4 b0 = convert_int4(vload4(0, src_addr_b)); - - c00 += (int4)a0.s0 * b0; - c10 += (int4)a0.s1 * b0; - c20 += (int4)a0.s2 * b0; - c30 += (int4)a0.s3 * b0; - } - - // Compute destination address - Image dst = CONVERT_TO_IMAGE_STRUCT(dst); - -#if defined(REINTERPRET_OUTPUT_AS_3D) - // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension - // in order to take into account the presence of possible cross plane paddings - // - // | | - // | plane0 | - // | | - // |__________________| - // |******************| - // | cross_plane_pad | - // |******************| - // | | - // | plane1 | - // | | - // |__________________| - - // The plane (zout) is calculated dividing M (get_global_id(1) * 4) by HEIGHT_GEMM3D - uint4 zout = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * 4)) / (uint4)HEIGHT_GEMM3D; - zout = min(DEPTH_GEMM3D - 1, zout); - - // Add offset due to the cross plane paddings - zout *= (cross_plane_pad * dst_stride_y); - - // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we - // multiply dst_stride_z by DEPTH_GEMM3D - dst.ptr += z * dst_stride_z * DEPTH_GEMM3D; - - // Store 4x4 block - vstore4(c00, 0, (__global int *)(dst.ptr + 0 * dst_stride_y + zout.s0)); - vstore4(c10, 0, (__global int *)(dst.ptr + 1 * dst_stride_y + zout.s1)); - vstore4(c20, 0, (__global int *)(dst.ptr + 2 * dst_stride_y + zout.s2)); - vstore4(c30, 0, (__global int *)(dst.ptr + 3 * dst_stride_y + zout.s3)); - -#else // defined(REINTERPRET_OUTPUT_AS_3D) - // Add offset for batched GEMM - dst.ptr += z * dst_stride_z; - - // Store 4x4 block - vstore4(c00, 0, (__global int *)(dst.ptr + 0 * dst_stride_y)); - vstore4(c10, 0, (__global int *)(dst.ptr + 1 * dst_stride_y)); - vstore4(c20, 0, (__global int *)(dst.ptr + 2 * dst_stride_y)); - vstore4(c30, 0, (__global int *)(dst.ptr + 3 * dst_stride_y)); -#endif // defined(REINTERPRET_OUTPUT_AS_3D) -} -#endif // defined(COLS_B) && defined(MULT_INTERLEAVE4X4_HEIGHT) && defined(TRANSPOSE1XW_WIDTH_STEP) - #if defined(NUM_ELEMS_PROCESSED_PER_THREAD_X) && defined(NUM_ELEMS_PROCESSED_PER_THREAD_Y) && defined(COLS_A) #define VECTOR_UCHAR VEC_DATA_TYPE(uchar, NUM_ELEMS_PROCESSED_PER_THREAD_X) #define VECTOR_UINT VEC_DATA_TYPE(uint, NUM_ELEMS_PROCESSED_PER_THREAD_X) @@ -631,884 +469,6 @@ __kernel void gemmlowp_mm_midgard(IMAGE_DECLARATION(src0), #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 4 #endif // defined(REINTERPRET_OUTPUT_AS_3D) } - -/** OpenCL kernel optimized for Bifrost architectures that computes the matrix multiplication between matrix A (src0) and matrix B (src1) in case both matrices have not beed reshaped - * - * @attention The number of matrix A columns needs to be passed at compile time using -DCOLS_A - * - * @note In case the input or output have to be reinterpreted as a 3D tensor, the following information must be passed at compile time: - * -# REINTERPRET_INPUT_AS_3D: To reinterpret the input as 3D - * -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D - * -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor. - * -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor - * (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns matrix A NOT reshaped - * - * @param[in] src0_ptr Pointer to the source matrix. Supported data type: QASYMM8 - * @param[in] src0_stride_x Stride of the source matrix in X dimension (in bytes) - * @param[in] src0_step_x src_stride_x * number of elements along X processed per workitem(in bytes) - * @param[in] src0_stride_y Stride of the source matrix in Y dimension (in bytes) - * @param[in] src0_step_y src_stride_y * number of elements along Y processed per workitem(in bytes) - * @param[in] src0_offset_first_element_in_bytes The offset of the first element in the source matrix - * @param[in] src1_ptr Pointer to the source matrix. Supported data type: same as @p src0_ptr - * @param[in] src1_stride_x Stride of the source matrix in X dimension (in bytes) - * @param[in] src1_step_x src_stride_x * number of elements along X processed per workitem(in bytes) - * @param[in] src1_stride_y Stride of the source matrix in Y dimension (in bytes) - * @param[in] src1_step_y src_stride_y * number of elements along Y processed per workitem(in bytes) - * @param[in] src1_offset_first_element_in_bytes The offset of the first element in the source matrix - * @param[out] dst_ptr Pointer to the destination matrix Supported data type: S32 - * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes) - * @param[in] dst_step_x dst_gx_stride_x * number of elements along X processed per workitem(in bytes) - * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes) - * @param[in] dst_step_y dst_gx_stride_y * number of elements along Y processed per workitem(in bytes) - * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix - * @param[in] src0_stride_z Stride of the source matrix in Z dimension (in bytes) - * @param[in] src1_stride_z Stride of the source matrix in Z dimension (in bytes) - * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes) - * @param[in] src_cross_plane_pad (Optional) Bottom paddings in unit of elements for the input tensor (only if defined REINTERPRET_INPUT_AS_3D) - * @param[in] dst_cross_plane_pad (Optional) Bottom paddings in unit of elements for the output tensor (only if defined REINTERPRET_OUTPUT_AS_3D) - */ -__kernel void gemmlowp_mm_bifrost(IMAGE_DECLARATION(src0), - IMAGE_DECLARATION(src1), - IMAGE_DECLARATION(dst), - uint src0_stride_z, - uint src1_stride_z, - uint dst_stride_z -#if defined(REINTERPRET_INPUT_AS_3D) - , - uint src_cross_plane_pad -#endif // REINTERPRET_INPUT_AS_3D -#if defined(REINTERPRET_OUTPUT_AS_3D) - , - uint dst_cross_plane_pad -#endif // REINTERPRET_OUTPUT_AS_3D - ) -{ - int idx = get_global_id(0) * NUM_ELEMS_PROCESSED_PER_THREAD_X; - - // Compute starting address for matrix A and Matrix B - int2 src_addr = ((int2)(src0_offset_first_element_in_bytes, src1_offset_first_element_in_bytes)); - - // Update address for the matrix A - src_addr.s0 += get_global_id(1) * src0_stride_y * NUM_ELEMS_PROCESSED_PER_THREAD_Y; - - // Update address for the matrix B - src_addr.s1 += idx; - -#if defined(REINTERPRET_INPUT_AS_3D) - // Since we load a 2D input tile from a 3D tensor, we need to check when the plane changes across the z dimension - // in order to take into account the presence of possible cross plane paddings - // - // | | - // | plane0 | - // | | - // |__________________| - // |******************| - // | cross_plane_pad | - // |******************| - // | | - // | plane1 | - // | | - // |__________________| - - // The plane (zin) is calculated dividing M (get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y) by HEIGHT_GEMM3D - uint4 zin = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y)) / (uint4)HEIGHT_GEMM3D; - zin = min(DEPTH_GEMM3D - 1, zin); - - // Add offset due to the cross plane paddings - zin *= (src_cross_plane_pad * src0_stride_y); - - // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we - // multiply src0_stride_z by DEPTH_GEMM3D - src_addr.s0 += get_global_id(2) * src0_stride_z * DEPTH_GEMM3D; - -#else // defined(REINTERPRET_INPUT_AS_3D) - - // Add offset for batched GEMM - src_addr.s0 += get_global_id(2) * src0_stride_z; - -#endif // defined(REINTERPRET_INPUT_AS_3D) - -#if defined(MATRIX_B_DEPTH) - // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3 - src_addr.s1 += (get_global_id(2) % MATRIX_B_DEPTH) * src1_stride_z; -#else // defined(MATRIX_B_DEPTH) - src_addr.s1 += get_global_id(2) * src1_stride_z; -#endif // defined(MATRIX_B_DEPTH) - - int end_row_vec_a = src_addr.s0 + COLS_A; - - uint acc00 = 0; - uint acc01 = 0; - uint acc02 = 0; - uint acc03 = 0; -#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 - uint acc10 = 0; - uint acc11 = 0; - uint acc12 = 0; - uint acc13 = 0; -#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 -#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 - uint acc20 = 0; - uint acc21 = 0; - uint acc22 = 0; - uint acc23 = 0; -#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 -#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 - uint acc30 = 0; - uint acc31 = 0; - uint acc32 = 0; - uint acc33 = 0; -#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 -#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 4 - uint acc40 = 0; - uint acc41 = 0; - uint acc42 = 0; - uint acc43 = 0; -#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 4 - - for(; src_addr.s0 <= (end_row_vec_a - 4); src_addr += (int2)(4, 4 * src1_stride_y)) - { - // Load values from matrix A - uchar4 a0 = vload4(0, src0_ptr + src_addr.s0 + 0 * src0_stride_y); -#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 - uchar4 a1 = vload4(0, src0_ptr + src_addr.s0 + 1 * src0_stride_y); -#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 -#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 - uchar4 a2 = vload4(0, src0_ptr + src_addr.s0 + 2 * src0_stride_y); -#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 -#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 - uchar4 a3 = vload4(0, src0_ptr + src_addr.s0 + 3 * src0_stride_y); -#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 -#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 4 - uchar4 a4 = vload4(0, src0_ptr + src_addr.s0 + 4 * src0_stride_y); -#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 4 - // Load values from matrix B - uchar4 b0 = vload4(0, src1_ptr + src_addr.s1 + 0 * src1_stride_y); - uchar4 b1 = vload4(0, src1_ptr + src_addr.s1 + 1 * src1_stride_y); - uchar4 b2 = vload4(0, src1_ptr + src_addr.s1 + 2 * src1_stride_y); - uchar4 b3 = vload4(0, src1_ptr + src_addr.s1 + 3 * src1_stride_y); - - { - // Accumulate - ushort tmp0 = (ushort)b0.s0 * (ushort)a0.s0; - ushort tmp1 = (ushort)b0.s1 * (ushort)a0.s0; - ushort tmp2 = (ushort)b0.s2 * (ushort)a0.s0; - ushort tmp3 = (ushort)b0.s3 * (ushort)a0.s0; - - ushort tmp4 = (ushort)b1.s0 * (ushort)a0.s1; - ushort tmp5 = (ushort)b1.s1 * (ushort)a0.s1; - ushort tmp6 = (ushort)b1.s2 * (ushort)a0.s1; - ushort tmp7 = (ushort)b1.s3 * (ushort)a0.s1; - - ushort tmp8 = (ushort)b2.s0 * (ushort)a0.s2; - ushort tmp9 = (ushort)b2.s1 * (ushort)a0.s2; - ushort tmpA = (ushort)b2.s2 * (ushort)a0.s2; - ushort tmpB = (ushort)b2.s3 * (ushort)a0.s2; - - ushort tmpC = (ushort)b3.s0 * (ushort)a0.s3; - ushort tmpD = (ushort)b3.s1 * (ushort)a0.s3; - ushort tmpE = (ushort)b3.s2 * (ushort)a0.s3; - ushort tmpF = (ushort)b3.s3 * (ushort)a0.s3; - - acc00 += ((uint)tmp0 + (uint)tmp4 + (uint)tmp8 + (uint)tmpC); - acc01 += ((uint)tmp1 + (uint)tmp5 + (uint)tmp9 + (uint)tmpD); - acc02 += ((uint)tmp2 + (uint)tmp6 + (uint)tmpA + (uint)tmpE); - acc03 += ((uint)tmp3 + (uint)tmp7 + (uint)tmpB + (uint)tmpF); - } -#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 - { - // Accumulate - ushort tmp0 = (ushort)b0.s0 * (ushort)a1.s0; - ushort tmp1 = (ushort)b0.s1 * (ushort)a1.s0; - ushort tmp2 = (ushort)b0.s2 * (ushort)a1.s0; - ushort tmp3 = (ushort)b0.s3 * (ushort)a1.s0; - - ushort tmp4 = (ushort)b1.s0 * (ushort)a1.s1; - ushort tmp5 = (ushort)b1.s1 * (ushort)a1.s1; - ushort tmp6 = (ushort)b1.s2 * (ushort)a1.s1; - ushort tmp7 = (ushort)b1.s3 * (ushort)a1.s1; - - ushort tmp8 = (ushort)b2.s0 * (ushort)a1.s2; - ushort tmp9 = (ushort)b2.s1 * (ushort)a1.s2; - ushort tmpA = (ushort)b2.s2 * (ushort)a1.s2; - ushort tmpB = (ushort)b2.s3 * (ushort)a1.s2; - - ushort tmpC = (ushort)b3.s0 * (ushort)a1.s3; - ushort tmpD = (ushort)b3.s1 * (ushort)a1.s3; - ushort tmpE = (ushort)b3.s2 * (ushort)a1.s3; - ushort tmpF = (ushort)b3.s3 * (ushort)a1.s3; - - acc10 += ((uint)tmp0 + (uint)tmp4 + (uint)tmp8 + (uint)tmpC); - acc11 += ((uint)tmp1 + (uint)tmp5 + (uint)tmp9 + (uint)tmpD); - acc12 += ((uint)tmp2 + (uint)tmp6 + (uint)tmpA + (uint)tmpE); - acc13 += ((uint)tmp3 + (uint)tmp7 + (uint)tmpB + (uint)tmpF); - } -#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 -#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 - { - // Accumulate - ushort tmp0 = (ushort)b0.s0 * (ushort)a2.s0; - ushort tmp1 = (ushort)b0.s1 * (ushort)a2.s0; - ushort tmp2 = (ushort)b0.s2 * (ushort)a2.s0; - ushort tmp3 = (ushort)b0.s3 * (ushort)a2.s0; - - ushort tmp4 = (ushort)b1.s0 * (ushort)a2.s1; - ushort tmp5 = (ushort)b1.s1 * (ushort)a2.s1; - ushort tmp6 = (ushort)b1.s2 * (ushort)a2.s1; - ushort tmp7 = (ushort)b1.s3 * (ushort)a2.s1; - - ushort tmp8 = (ushort)b2.s0 * (ushort)a2.s2; - ushort tmp9 = (ushort)b2.s1 * (ushort)a2.s2; - ushort tmpA = (ushort)b2.s2 * (ushort)a2.s2; - ushort tmpB = (ushort)b2.s3 * (ushort)a2.s2; - - ushort tmpC = (ushort)b3.s0 * (ushort)a2.s3; - ushort tmpD = (ushort)b3.s1 * (ushort)a2.s3; - ushort tmpE = (ushort)b3.s2 * (ushort)a2.s3; - ushort tmpF = (ushort)b3.s3 * (ushort)a2.s3; - - acc20 += ((uint)tmp0 + (uint)tmp4 + (uint)tmp8 + (uint)tmpC); - acc21 += ((uint)tmp1 + (uint)tmp5 + (uint)tmp9 + (uint)tmpD); - acc22 += ((uint)tmp2 + (uint)tmp6 + (uint)tmpA + (uint)tmpE); - acc23 += ((uint)tmp3 + (uint)tmp7 + (uint)tmpB + (uint)tmpF); - } -#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 -#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 - { - // Accumulate - ushort tmp0 = (ushort)b0.s0 * (ushort)a3.s0; - ushort tmp1 = (ushort)b0.s1 * (ushort)a3.s0; - ushort tmp2 = (ushort)b0.s2 * (ushort)a3.s0; - ushort tmp3 = (ushort)b0.s3 * (ushort)a3.s0; - - ushort tmp4 = (ushort)b1.s0 * (ushort)a3.s1; - ushort tmp5 = (ushort)b1.s1 * (ushort)a3.s1; - ushort tmp6 = (ushort)b1.s2 * (ushort)a3.s1; - ushort tmp7 = (ushort)b1.s3 * (ushort)a3.s1; - - ushort tmp8 = (ushort)b2.s0 * (ushort)a3.s2; - ushort tmp9 = (ushort)b2.s1 * (ushort)a3.s2; - ushort tmpA = (ushort)b2.s2 * (ushort)a3.s2; - ushort tmpB = (ushort)b2.s3 * (ushort)a3.s2; - - ushort tmpC = (ushort)b3.s0 * (ushort)a3.s3; - ushort tmpD = (ushort)b3.s1 * (ushort)a3.s3; - ushort tmpE = (ushort)b3.s2 * (ushort)a3.s3; - ushort tmpF = (ushort)b3.s3 * (ushort)a3.s3; - - acc30 += ((uint)tmp0 + (uint)tmp4 + (uint)tmp8 + (uint)tmpC); - acc31 += ((uint)tmp1 + (uint)tmp5 + (uint)tmp9 + (uint)tmpD); - acc32 += ((uint)tmp2 + (uint)tmp6 + (uint)tmpA + (uint)tmpE); - acc33 += ((uint)tmp3 + (uint)tmp7 + (uint)tmpB + (uint)tmpF); - } -#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 -#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 4 - { - // Accumulate - ushort tmp0 = (ushort)b0.s0 * (ushort)a4.s0; - ushort tmp1 = (ushort)b0.s1 * (ushort)a4.s0; - ushort tmp2 = (ushort)b0.s2 * (ushort)a4.s0; - ushort tmp3 = (ushort)b0.s3 * (ushort)a4.s0; - - ushort tmp4 = (ushort)b1.s0 * (ushort)a4.s1; - ushort tmp5 = (ushort)b1.s1 * (ushort)a4.s1; - ushort tmp6 = (ushort)b1.s2 * (ushort)a4.s1; - ushort tmp7 = (ushort)b1.s3 * (ushort)a4.s1; - - ushort tmp8 = (ushort)b2.s0 * (ushort)a4.s2; - ushort tmp9 = (ushort)b2.s1 * (ushort)a4.s2; - ushort tmpA = (ushort)b2.s2 * (ushort)a4.s2; - ushort tmpB = (ushort)b2.s3 * (ushort)a4.s2; - - ushort tmpC = (ushort)b3.s0 * (ushort)a4.s3; - ushort tmpD = (ushort)b3.s1 * (ushort)a4.s3; - ushort tmpE = (ushort)b3.s2 * (ushort)a4.s3; - ushort tmpF = (ushort)b3.s3 * (ushort)a4.s3; - - acc40 += ((uint)tmp0 + (uint)tmp4 + (uint)tmp8 + (uint)tmpC); - acc41 += ((uint)tmp1 + (uint)tmp5 + (uint)tmp9 + (uint)tmpD); - acc42 += ((uint)tmp2 + (uint)tmp6 + (uint)tmpA + (uint)tmpE); - acc43 += ((uint)tmp3 + (uint)tmp7 + (uint)tmpB + (uint)tmpF); - } -#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 4 - } - - for(; src_addr.s0 < end_row_vec_a; src_addr += (int2)(1, src1_stride_y)) - { - // Load values from matrix A - uchar a0 = *(src0_ptr + src_addr.s0 + 0 * src0_stride_y); -#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 - uchar a1 = *(src0_ptr + src_addr.s0 + 1 * src0_stride_y); -#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 -#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 - uchar a2 = *(src0_ptr + src_addr.s0 + 2 * src0_stride_y); -#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 -#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 - uchar a3 = *(src0_ptr + src_addr.s0 + 3 * src0_stride_y); -#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 -#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 4 - uchar a4 = *(src0_ptr + src_addr.s0 + 4 * src0_stride_y); -#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 4 - // Load values from matrix B - uchar4 b0 = vload4(0, src1_ptr + src_addr.s1); - - // Accumulate - { - // Accumulate - ushort tmp0 = (ushort)b0.s0 * (ushort)a0; - ushort tmp1 = (ushort)b0.s1 * (ushort)a0; - ushort tmp2 = (ushort)b0.s2 * (ushort)a0; - ushort tmp3 = (ushort)b0.s3 * (ushort)a0; - - acc00 += ((uint)tmp0); - acc01 += ((uint)tmp1); - acc02 += ((uint)tmp2); - acc03 += ((uint)tmp3); - } -#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 - { - // Accumulate - ushort tmp0 = (ushort)b0.s0 * (ushort)a1; - ushort tmp1 = (ushort)b0.s1 * (ushort)a1; - ushort tmp2 = (ushort)b0.s2 * (ushort)a1; - ushort tmp3 = (ushort)b0.s3 * (ushort)a1; - - acc10 += ((uint)tmp0); - acc11 += ((uint)tmp1); - acc12 += ((uint)tmp2); - acc13 += ((uint)tmp3); - } -#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 -#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 - { - // Accumulate - ushort tmp0 = (ushort)b0.s0 * (ushort)a2; - ushort tmp1 = (ushort)b0.s1 * (ushort)a2; - ushort tmp2 = (ushort)b0.s2 * (ushort)a2; - ushort tmp3 = (ushort)b0.s3 * (ushort)a2; - - acc20 += ((uint)tmp0); - acc21 += ((uint)tmp1); - acc22 += ((uint)tmp2); - acc23 += ((uint)tmp3); - } -#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 -#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 - { - // Accumulate - ushort tmp0 = (ushort)b0.s0 * (ushort)a3; - ushort tmp1 = (ushort)b0.s1 * (ushort)a3; - ushort tmp2 = (ushort)b0.s2 * (ushort)a3; - ushort tmp3 = (ushort)b0.s3 * (ushort)a3; - - acc30 += ((uint)tmp0); - acc31 += ((uint)tmp1); - acc32 += ((uint)tmp2); - acc33 += ((uint)tmp3); - } -#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 -#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 4 - { - // Accumulate - ushort tmp0 = (ushort)b0.s0 * (ushort)a4; - ushort tmp1 = (ushort)b0.s1 * (ushort)a4; - ushort tmp2 = (ushort)b0.s2 * (ushort)a4; - ushort tmp3 = (ushort)b0.s3 * (ushort)a4; - - acc40 += ((uint)tmp0); - acc41 += ((uint)tmp1); - acc42 += ((uint)tmp2); - acc43 += ((uint)tmp3); - } -#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 4 - } - - const int z = get_global_id(2); - - // Compute destination address - Image dst = CONVERT_TO_IMAGE_STRUCT(dst); - -#if defined(REINTERPRET_OUTPUT_AS_3D) - // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension - // in order to take into account the presence of possible cross plane paddings - // - // | | - // | plane0 | - // | | - // |__________________| - // |******************| - // | cross_plane_pad | - // |******************| - // | | - // | plane1 | - // | | - // |__________________| - - // The plane (zout) is calculated dividing M (get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y) by HEIGHT_GEMM3D - uint8 zout = ((uint8)(0, 1, 2, 3, 4, 5, 6, 7) + (uint8)(get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y)) / (uint8)HEIGHT_GEMM3D; - zout = min(DEPTH_GEMM3D - 1, zout); - - // Add offset due to the cross plane paddings - zout *= (dst_cross_plane_pad * dst_stride_y); - - // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we - // multiply dst_stride_z by DEPTH_GEMM3D - dst.ptr += z * dst_stride_z * DEPTH_GEMM3D; - - // Store the result - vstore4((int4)(acc00, acc01, acc02, acc03), 0, (__global int *)(dst.ptr + 0 * dst_stride_y + zout.s0)); -#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 - vstore4((int4)(acc10, acc11, acc12, acc13), 0, (__global int *)(dst.ptr + 1 * dst_stride_y + zout.s1)); -#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 -#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 - vstore4((int4)(acc20, acc21, acc22, acc23), 0, (__global int *)(dst.ptr + 2 * dst_stride_y + zout.s2)); -#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 -#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 - vstore4((int4)(acc30, acc31, acc32, acc33), 0, (__global int *)(dst.ptr + 3 * dst_stride_y + zout.s3)); -#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 -#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 4 - vstore4((int4)(acc40, acc41, acc42, acc43), 0, (__global int *)(dst.ptr + 4 * dst_stride_y + zout.s4)); -#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 4 - -#else // defined(REINTERPRET_OUTPUT_AS_3D) - // Add offset for batched GEMM - dst.ptr += z * dst_stride_z; - - // Store the result - vstore4((int4)(acc00, acc01, acc02, acc03), 0, (__global int *)(dst.ptr + 0 * dst_stride_y)); -#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 - vstore4((int4)(acc10, acc11, acc12, acc13), 0, (__global int *)(dst.ptr + 1 * dst_stride_y)); -#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 -#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 - vstore4((int4)(acc20, acc21, acc22, acc23), 0, (__global int *)(dst.ptr + 2 * dst_stride_y)); -#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 -#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 - vstore4((int4)(acc30, acc31, acc32, acc33), 0, (__global int *)(dst.ptr + 3 * dst_stride_y)); -#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 -#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 4 - vstore4((int4)(acc40, acc41, acc42, acc43), 0, (__global int *)(dst.ptr + 4 * dst_stride_y)); -#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 4 -#endif // defined(REINTERPRET_OUTPUT_AS_3D) -} - -#if defined(ARM_COMPUTE_OPENCL_DOT8_ENABLED) && defined(cl_arm_integer_dot_product_int8) -/** OpenCL kernel optimized to use dot product that computes the matrix multiplication between matrix A (src0) and matrix B (src1) in case both matrices have not beed reshaped - * - * @attention The number of matrix A columns needs to be passed at compile time using -DCOLS_A - * - * @note In case the input or output have to be reinterpreted as a 3D tensor, the following information must be passed at compile time: - * -# REINTERPRET_INPUT_AS_3D: To reinterpret the input as 3D - * -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D - * -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor. - * -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor - * (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns matrix A NOT reshaped - * - * @param[in] src0_ptr Pointer to the source matrix. Supported data type: QASYMM8 - * @param[in] src0_stride_x Stride of the source matrix in X dimension (in bytes) - * @param[in] src0_step_x src_stride_x * number of elements along X processed per workitem(in bytes) - * @param[in] src0_stride_y Stride of the source matrix in Y dimension (in bytes) - * @param[in] src0_step_y src_stride_y * number of elements along Y processed per workitem(in bytes) - * @param[in] src0_offset_first_element_in_bytes The offset of the first element in the source matrix - * @param[in] src1_ptr Pointer to the source matrix. Supported data type: same as @p src0_ptr - * @param[in] src1_stride_x Stride of the source matrix in X dimension (in bytes) - * @param[in] src1_step_x src_stride_x * number of elements along X processed per workitem(in bytes) - * @param[in] src1_stride_y Stride of the source matrix in Y dimension (in bytes) - * @param[in] src1_step_y src_stride_y * number of elements along Y processed per workitem(in bytes) - * @param[in] src1_offset_first_element_in_bytes The offset of the first element in the source matrix - * @param[out] dst_ptr Pointer to the destination matrix Supported data type: S32 - * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes) - * @param[in] dst_step_x dst_gx_stride_x * number of elements along X processed per workitem(in bytes) - * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes) - * @param[in] dst_step_y dst_gx_stride_y * number of elements along Y processed per workitem(in bytes) - * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix - * @param[in] src0_stride_z Stride of the source matrix in Z dimension (in bytes) - * @param[in] src1_stride_z Stride of the source matrix in Z dimension (in bytes) - * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes) - * @param[in] src_cross_plane_pad (Optional) Bottom paddings in unit of elements for the input tensor (only if defined REINTERPRET_INPUT_AS_3D) - * @param[in] dst_cross_plane_pad (Optional) Bottom paddings in unit of elements for the output tensor (only if defined REINTERPRET_OUTPUT_AS_3D) - */ -__kernel void gemmlowp_mm_bifrost_dot8(IMAGE_DECLARATION(src0), - IMAGE_DECLARATION(src1), - IMAGE_DECLARATION(dst), - uint src0_stride_z, - uint src1_stride_z, - uint dst_stride_z -#if defined(REINTERPRET_INPUT_AS_3D) - , - uint src_cross_plane_pad -#endif // REINTERPRET_INPUT_AS_3D -#if defined(REINTERPRET_OUTPUT_AS_3D) - , - uint dst_cross_plane_pad -#endif // REINTERPRET_OUTPUT_AS_3D) - ) -{ - int idx = get_global_id(0) * NUM_ELEMS_PROCESSED_PER_THREAD_X; - - // Compute starting address for matrix A and Matrix B - int2 src_addr = ((int2)(src0_offset_first_element_in_bytes, src1_offset_first_element_in_bytes)); - - // Update address for the matrix A - src_addr.s0 += get_global_id(1) * src0_stride_y * NUM_ELEMS_PROCESSED_PER_THREAD_Y; - - // Update address for the matrix B - src_addr.s1 += idx; - -#if defined(REINTERPRET_INPUT_AS_3D) - // Since we load a 2D input tile from a 3D tensor, we need to check when the plane changes across the z dimension - // in order to take into account the presence of possible cross plane paddings - // - // | | - // | plane0 | - // | | - // |__________________| - // |******************| - // | cross_plane_pad | - // |******************| - // | | - // | plane1 | - // | | - // |__________________| - - // The plane (zin) is calculated dividing M (get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y) by HEIGHT_GEMM3D - uint4 zin = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y)) / (uint4)HEIGHT_GEMM3D; - zin = min(DEPTH_GEMM3D - 1, zin); - - // Add offset due to the cross plane paddings - zin *= (src_cross_plane_pad * src0_stride_y); - - zin += ((uint4)(0, 1, 2, 3)) * src0_stride_y; - - // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we - // multiply src0_stride_z by DEPTH_GEMM3D - src_addr.s0 += get_global_id(2) * src0_stride_z * DEPTH_GEMM3D; - -#else // defined(REINTERPRET_INPUT_AS_3D) - - // Add offset for batched GEMM - src_addr.s0 += get_global_id(2) * src0_stride_z; - -#endif // defined(REINTERPRET_INPUT_AS_3D) - -#if defined(MATRIX_B_DEPTH) - // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3 - src_addr.s1 += (get_global_id(2) % MATRIX_B_DEPTH) * src1_stride_z; -#else // defined(MATRIX_B_DEPTH) - src_addr.s1 += get_global_id(2) * src1_stride_z; -#endif // defined(MATRIX_B_DEPTH) - - uint acc00 = 0; - uint acc01 = 0; - uint acc02 = 0; - uint acc03 = 0; - uint acc04 = 0; - uint acc05 = 0; - uint acc06 = 0; - uint acc07 = 0; -#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 - uint acc10 = 0; - uint acc11 = 0; - uint acc12 = 0; - uint acc13 = 0; - uint acc14 = 0; - uint acc15 = 0; - uint acc16 = 0; - uint acc17 = 0; -#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 -#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 - uint acc20 = 0; - uint acc21 = 0; - uint acc22 = 0; - uint acc23 = 0; - uint acc24 = 0; - uint acc25 = 0; - uint acc26 = 0; - uint acc27 = 0; -#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 -#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 - uint acc30 = 0; - uint acc31 = 0; - uint acc32 = 0; - uint acc33 = 0; - uint acc34 = 0; - uint acc35 = 0; - uint acc36 = 0; - uint acc37 = 0; -#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 - - // A and B src indices get incremented at the same time. - int i = 0; - for(; i <= ((int)COLS_A - 8); i += 8) - { -#if defined(REINTERPRET_INPUT_AS_3D) - // Load values from matrix A and matrix B - uchar8 a0 = vload8(0, (__global uchar *)(src0_ptr + src_addr.s0 + zin.s0)); -#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 - uchar8 a1 = vload8(0, (__global uchar *)(src0_ptr + src_addr.s0 + zin.s1)); -#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 -#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 - uchar8 a2 = vload8(0, (__global uchar *)(src0_ptr + src_addr.s0 + zin.s2)); -#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 -#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 - uchar8 a3 = vload8(0, (__global uchar *)(src0_ptr + src_addr.s0 + zin.s3)); -#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 -#else // defined(REINTERPRET_INPUT_AS_3D) - // Load values from matrix A and matrix B - uchar8 a0 = vload8(0, (__global uchar *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y)); -#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 - uchar8 a1 = vload8(0, (__global uchar *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y)); -#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 -#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 - uchar8 a2 = vload8(0, (__global uchar *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y)); -#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 -#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 - uchar8 a3 = vload8(0, (__global uchar *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y)); -#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 -#endif // defined(REINTERPRET_INPUT_AS_3D) - - uchar8 b0 = vload8(0, src1_ptr + src_addr.s1 + 0 * src1_stride_y); - uchar8 b1 = vload8(0, src1_ptr + src_addr.s1 + 1 * src1_stride_y); - uchar8 b2 = vload8(0, src1_ptr + src_addr.s1 + 2 * src1_stride_y); - uchar8 b3 = vload8(0, src1_ptr + src_addr.s1 + 3 * src1_stride_y); - src_addr.s1 += 4 * src1_stride_y; - - ARM_DOT(a0.s0123, (uchar4)(b0.s0, b1.s0, b2.s0, b3.s0), acc00); - ARM_DOT(a0.s0123, (uchar4)(b0.s1, b1.s1, b2.s1, b3.s1), acc01); - ARM_DOT(a0.s0123, (uchar4)(b0.s2, b1.s2, b2.s2, b3.s2), acc02); - ARM_DOT(a0.s0123, (uchar4)(b0.s3, b1.s3, b2.s3, b3.s3), acc03); - ARM_DOT(a0.s0123, (uchar4)(b0.s4, b1.s4, b2.s4, b3.s4), acc04); - ARM_DOT(a0.s0123, (uchar4)(b0.s5, b1.s5, b2.s5, b3.s5), acc05); - ARM_DOT(a0.s0123, (uchar4)(b0.s6, b1.s6, b2.s6, b3.s6), acc06); - ARM_DOT(a0.s0123, (uchar4)(b0.s7, b1.s7, b2.s7, b3.s7), acc07); - -#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 - ARM_DOT(a1.s0123, (uchar4)(b0.s0, b1.s0, b2.s0, b3.s0), acc10); - ARM_DOT(a1.s0123, (uchar4)(b0.s1, b1.s1, b2.s1, b3.s1), acc11); - ARM_DOT(a1.s0123, (uchar4)(b0.s2, b1.s2, b2.s2, b3.s2), acc12); - ARM_DOT(a1.s0123, (uchar4)(b0.s3, b1.s3, b2.s3, b3.s3), acc13); - ARM_DOT(a1.s0123, (uchar4)(b0.s4, b1.s4, b2.s4, b3.s4), acc14); - ARM_DOT(a1.s0123, (uchar4)(b0.s5, b1.s5, b2.s5, b3.s5), acc15); - ARM_DOT(a1.s0123, (uchar4)(b0.s6, b1.s6, b2.s6, b3.s6), acc16); - ARM_DOT(a1.s0123, (uchar4)(b0.s7, b1.s7, b2.s7, b3.s7), acc17); -#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 -#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 - ARM_DOT(a2.s0123, (uchar4)(b0.s0, b1.s0, b2.s0, b3.s0), acc20); - ARM_DOT(a2.s0123, (uchar4)(b0.s1, b1.s1, b2.s1, b3.s1), acc21); - ARM_DOT(a2.s0123, (uchar4)(b0.s2, b1.s2, b2.s2, b3.s2), acc22); - ARM_DOT(a2.s0123, (uchar4)(b0.s3, b1.s3, b2.s3, b3.s3), acc23); - ARM_DOT(a2.s0123, (uchar4)(b0.s4, b1.s4, b2.s4, b3.s4), acc24); - ARM_DOT(a2.s0123, (uchar4)(b0.s5, b1.s5, b2.s5, b3.s5), acc25); - ARM_DOT(a2.s0123, (uchar4)(b0.s6, b1.s6, b2.s6, b3.s6), acc26); - ARM_DOT(a2.s0123, (uchar4)(b0.s7, b1.s7, b2.s7, b3.s7), acc27); -#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 -#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 - ARM_DOT(a3.s0123, (uchar4)(b0.s0, b1.s0, b2.s0, b3.s0), acc30); - ARM_DOT(a3.s0123, (uchar4)(b0.s1, b1.s1, b2.s1, b3.s1), acc31); - ARM_DOT(a3.s0123, (uchar4)(b0.s2, b1.s2, b2.s2, b3.s2), acc32); - ARM_DOT(a3.s0123, (uchar4)(b0.s3, b1.s3, b2.s3, b3.s3), acc33); - ARM_DOT(a3.s0123, (uchar4)(b0.s4, b1.s4, b2.s4, b3.s4), acc34); - ARM_DOT(a3.s0123, (uchar4)(b0.s5, b1.s5, b2.s5, b3.s5), acc35); - ARM_DOT(a3.s0123, (uchar4)(b0.s6, b1.s6, b2.s6, b3.s6), acc36); - ARM_DOT(a3.s0123, (uchar4)(b0.s7, b1.s7, b2.s7, b3.s7), acc37); -#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 - - b0 = vload8(0, src1_ptr + src_addr.s1 + 0 * src1_stride_y); - b1 = vload8(0, src1_ptr + src_addr.s1 + 1 * src1_stride_y); - b2 = vload8(0, src1_ptr + src_addr.s1 + 2 * src1_stride_y); - b3 = vload8(0, src1_ptr + src_addr.s1 + 3 * src1_stride_y); - src_addr.s1 += 4 * src1_stride_y; - - ARM_DOT(a0.s4567, (uchar4)(b0.s0, b1.s0, b2.s0, b3.s0), acc00); - ARM_DOT(a0.s4567, (uchar4)(b0.s1, b1.s1, b2.s1, b3.s1), acc01); - ARM_DOT(a0.s4567, (uchar4)(b0.s2, b1.s2, b2.s2, b3.s2), acc02); - ARM_DOT(a0.s4567, (uchar4)(b0.s3, b1.s3, b2.s3, b3.s3), acc03); - ARM_DOT(a0.s4567, (uchar4)(b0.s4, b1.s4, b2.s4, b3.s4), acc04); - ARM_DOT(a0.s4567, (uchar4)(b0.s5, b1.s5, b2.s5, b3.s5), acc05); - ARM_DOT(a0.s4567, (uchar4)(b0.s6, b1.s6, b2.s6, b3.s6), acc06); - ARM_DOT(a0.s4567, (uchar4)(b0.s7, b1.s7, b2.s7, b3.s7), acc07); - -#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 - ARM_DOT(a1.s4567, (uchar4)(b0.s0, b1.s0, b2.s0, b3.s0), acc10); - ARM_DOT(a1.s4567, (uchar4)(b0.s1, b1.s1, b2.s1, b3.s1), acc11); - ARM_DOT(a1.s4567, (uchar4)(b0.s2, b1.s2, b2.s2, b3.s2), acc12); - ARM_DOT(a1.s4567, (uchar4)(b0.s3, b1.s3, b2.s3, b3.s3), acc13); - ARM_DOT(a1.s4567, (uchar4)(b0.s4, b1.s4, b2.s4, b3.s4), acc14); - ARM_DOT(a1.s4567, (uchar4)(b0.s5, b1.s5, b2.s5, b3.s5), acc15); - ARM_DOT(a1.s4567, (uchar4)(b0.s6, b1.s6, b2.s6, b3.s6), acc16); - ARM_DOT(a1.s4567, (uchar4)(b0.s7, b1.s7, b2.s7, b3.s7), acc17); -#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 -#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 - ARM_DOT(a2.s4567, (uchar4)(b0.s0, b1.s0, b2.s0, b3.s0), acc20); - ARM_DOT(a2.s4567, (uchar4)(b0.s1, b1.s1, b2.s1, b3.s1), acc21); - ARM_DOT(a2.s4567, (uchar4)(b0.s2, b1.s2, b2.s2, b3.s2), acc22); - ARM_DOT(a2.s4567, (uchar4)(b0.s3, b1.s3, b2.s3, b3.s3), acc23); - ARM_DOT(a2.s4567, (uchar4)(b0.s4, b1.s4, b2.s4, b3.s4), acc24); - ARM_DOT(a2.s4567, (uchar4)(b0.s5, b1.s5, b2.s5, b3.s5), acc25); - ARM_DOT(a2.s4567, (uchar4)(b0.s6, b1.s6, b2.s6, b3.s6), acc26); - ARM_DOT(a2.s4567, (uchar4)(b0.s7, b1.s7, b2.s7, b3.s7), acc27); -#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 -#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 - ARM_DOT(a3.s4567, (uchar4)(b0.s0, b1.s0, b2.s0, b3.s0), acc30); - ARM_DOT(a3.s4567, (uchar4)(b0.s1, b1.s1, b2.s1, b3.s1), acc31); - ARM_DOT(a3.s4567, (uchar4)(b0.s2, b1.s2, b2.s2, b3.s2), acc32); - ARM_DOT(a3.s4567, (uchar4)(b0.s3, b1.s3, b2.s3, b3.s3), acc33); - ARM_DOT(a3.s4567, (uchar4)(b0.s4, b1.s4, b2.s4, b3.s4), acc34); - ARM_DOT(a3.s4567, (uchar4)(b0.s5, b1.s5, b2.s5, b3.s5), acc35); - ARM_DOT(a3.s4567, (uchar4)(b0.s6, b1.s6, b2.s6, b3.s6), acc36); - ARM_DOT(a3.s4567, (uchar4)(b0.s7, b1.s7, b2.s7, b3.s7), acc37); -#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 - - src_addr.s0 += 8; - } - - for(; i < (int)COLS_A; ++i) - { -#if defined(REINTERPRET_INPUT_AS_3D) - // Load values from matrix A - uchar a0 = *((__global uchar *)(src0_ptr + src_addr.s0 + zin.s0)); -#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 - uchar a1 = *((__global uchar *)(src0_ptr + src_addr.s0 + zin.s1)); -#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 -#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 - uchar a2 = *((__global uchar *)(src0_ptr + src_addr.s0 + zin.s2)); -#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 -#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 - uchar a3 = *((__global uchar *)(src0_ptr + src_addr.s0 + zin.s3)); -#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 -#else // defined(REINTERPRET_INPUT_AS_3D) - // Load values from matrix A - uchar a0 = *((__global uchar *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y)); -#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 - uchar a1 = *((__global uchar *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y)); -#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 -#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 - uchar a2 = *((__global uchar *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y)); -#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 -#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 - uchar a3 = *((__global uchar *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y)); -#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 -#endif // defined(REINTERPRET_INPUT_AS_3D) - - // Load values from matrix B - uchar8 b0 = vload8(0, src1_ptr + src_addr.s1); - src_addr.s1 += src1_stride_y; - - acc00 += (uint)a0 * b0.s0; - acc01 += (uint)a0 * b0.s1; - acc02 += (uint)a0 * b0.s2; - acc03 += (uint)a0 * b0.s3; - acc04 += (uint)a0 * b0.s4; - acc05 += (uint)a0 * b0.s5; - acc06 += (uint)a0 * b0.s6; - acc07 += (uint)a0 * b0.s7; - -#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 - acc10 += (uint)a1 * b0.s0; - acc11 += (uint)a1 * b0.s1; - acc12 += (uint)a1 * b0.s2; - acc13 += (uint)a1 * b0.s3; - acc14 += (uint)a1 * b0.s4; - acc15 += (uint)a1 * b0.s5; - acc16 += (uint)a1 * b0.s6; - acc17 += (uint)a1 * b0.s7; -#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 -#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 - acc20 += (uint)a2 * b0.s0; - acc21 += (uint)a2 * b0.s1; - acc22 += (uint)a2 * b0.s2; - acc23 += (uint)a2 * b0.s3; - acc24 += (uint)a2 * b0.s4; - acc25 += (uint)a2 * b0.s5; - acc26 += (uint)a2 * b0.s6; - acc27 += (uint)a2 * b0.s7; -#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 -#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 - acc30 += (uint)a3 * b0.s0; - acc31 += (uint)a3 * b0.s1; - acc32 += (uint)a3 * b0.s2; - acc33 += (uint)a3 * b0.s3; - acc34 += (uint)a3 * b0.s4; - acc35 += (uint)a3 * b0.s5; - acc36 += (uint)a3 * b0.s6; - acc37 += (uint)a3 * b0.s7; -#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 - - src_addr.s0 += 1; - } - - int z = get_global_id(2); - - // Compute destination address - Image dst = CONVERT_TO_IMAGE_STRUCT(dst); - - // Compute dst address - __global uchar *dst_addr = dst.ptr; - -#if defined(REINTERPRET_OUTPUT_AS_3D) - // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension - // in order to take into account the presence of possible cross plane paddings - // - // | | - // | plane0 | - // | | - // |__________________| - // |******************| - // | cross_plane_pad | - // |******************| - // | | - // | plane1 | - // | | - // |__________________| - - // The plane (zout) is calculated dividing M (get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y) by HEIGHT_GEMM3D - uint4 zout = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y)) / (uint4)HEIGHT_GEMM3D; - zout = min(DEPTH_GEMM3D - 1, zout); - - // Add offset due to the cross plane paddings - zout *= (dst_cross_plane_pad * dst_stride_y); - - // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we - // multiply dst_stride_z by DEPTH_GEMM3D - dst_addr += z * dst_stride_z * DEPTH_GEMM3D; - - // Store the result - vstore4((int4)(acc00, acc01, acc02, acc03), 0, (__global int *)(dst_addr + 0 * dst_stride_y + zout.s0)); - vstore4((int4)(acc04, acc05, acc06, acc07), 1, (__global int *)(dst_addr + 0 * dst_stride_y + zout.s0)); -#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 - vstore4((int4)(acc10, acc11, acc12, acc13), 0, (__global int *)(dst_addr + 1 * dst_stride_y + zout.s1)); - vstore4((int4)(acc14, acc15, acc16, acc17), 1, (__global int *)(dst_addr + 1 * dst_stride_y + zout.s1)); -#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 -#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 - vstore4((int4)(acc20, acc21, acc22, acc23), 0, (__global int *)(dst_addr + 2 * dst_stride_y + zout.s2)); - vstore4((int4)(acc24, acc25, acc26, acc27), 1, (__global int *)(dst_addr + 2 * dst_stride_y + zout.s2)); -#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 -#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 - vstore4((int4)(acc30, acc31, acc32, acc33), 0, (__global int *)(dst_addr + 3 * dst_stride_y + zout.s3)); - vstore4((int4)(acc34, acc35, acc36, acc37), 0, (__global int *)(dst_addr + 3 * dst_stride_y + zout.s3)); -#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 - -#else // defined(REINTERPRET_OUTPUT_AS_3D) - // Add offset for batched GEMM - dst_addr += z * dst_stride_z; - - // Store the result - vstore4((int4)(acc00, acc01, acc02, acc03), 0, (__global int *)(dst_addr + 0 * dst_stride_y)); - vstore4((int4)(acc04, acc05, acc06, acc07), 1, (__global int *)(dst_addr + 0 * dst_stride_y)); -#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 - vstore4((int4)(acc10, acc11, acc12, acc13), 0, (__global int *)(dst_addr + 1 * dst_stride_y)); - vstore4((int4)(acc14, acc15, acc16, acc17), 1, (__global int *)(dst_addr + 1 * dst_stride_y)); -#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 -#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 - vstore4((int4)(acc20, acc21, acc22, acc23), 0, (__global int *)(dst_addr + 2 * dst_stride_y)); - vstore4((int4)(acc24, acc25, acc26, acc27), 1, (__global int *)(dst_addr + 2 * dst_stride_y)); -#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 -#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 - vstore4((int4)(acc30, acc31, acc32, acc33), 0, (__global int *)(dst_addr + 3 * dst_stride_y)); - vstore4((int4)(acc34, acc35, acc36, acc37), 0, (__global int *)(dst_addr + 3 * dst_stride_y)); -#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 -#endif // defined(REINTERPRET_OUTPUT_AS_3D) -} -#endif // defined(ARM_COMPUTE_OPENCL_DOT8_ENABLED) && defined(cl_arm_integer_dot_product_int8) #endif // defined(NUM_ELEMS_PROCESSED_PER_THREAD_X) && defined(NUM_ELEMS_PROCESSED_PER_THREAD_Y) && defined(COLS_A) #if defined(M0) && defined(N0) && defined(K0) && defined(V0) && defined(H0) && defined(M) && defined(N) -- cgit v1.2.1