From bacfec5ecc3bd737c3d4eb2b0c165e0e55cc61f0 Mon Sep 17 00:00:00 2001 From: Gian Marco Iodice Date: Fri, 11 Jan 2019 11:30:55 +0000 Subject: COMPMID-1687: Optimize CLGEMMMatrixMultiplyKernel (part 1) Extended CLGEMMMatrixMultiplyReshapedKernel to support more parameters Change-Id: I4a27f986e3fe2dd071a4ccba5cfa0565f3db39ad Reviewed-on: https://review.mlplatform.org/495 Tested-by: Arm Jenkins Reviewed-by: Michele Di Giorgio --- src/core/CL/cl_kernels/gemm.cl | 218 +++++++++++++++++++++++++++++++---------- 1 file changed, 164 insertions(+), 54 deletions(-) (limited to 'src/core/CL/cl_kernels/gemm.cl') diff --git a/src/core/CL/cl_kernels/gemm.cl b/src/core/CL/cl_kernels/gemm.cl index 9dd072bd6e..3a76b74b2f 100644 --- a/src/core/CL/cl_kernels/gemm.cl +++ b/src/core/CL/cl_kernels/gemm.cl @@ -34,7 +34,7 @@ * @note The number of M0xK0 vertical blocks to store on the same output row must be passed at compile time using -DV0 (i.e. -DV0=2) * @note Only the following values for M0, K0 and V0 are supported: * M0: 2,3,4,5,6,7,8 - * K0: 2,4,8,16 + * K0: 2,3,4,8,16 * V0: greater than 0 * @note In case the input has to be reinterpreted as a 3D tensor (i.e. input of convolution layer 1x1), the following information must be passed at compile time: * -# REINTERPRET_INPUT_AS_3D: To reinterpret the input as 3D @@ -100,7 +100,8 @@ __kernel void gemm_reshape_lhs_matrix_nt(TENSOR3D_DECLARATION(src), __global uchar *output_ptr = dst_ptr + dst_offset_first_element_in_bytes + (x * (uint)BLOCK_SIZE * (uint)V0 * sizeof(DATA_TYPE)) + ((y / (uint)V0) * (uint)dst_stride_y) + ((y % V0) * (uint)OUTPUT_OFFSET_X * sizeof(DATA_TYPE)); - REPEAT_VAR_INIT_TO_CONST(M0, uint, zin, 0); //uint zin0=0, zin1=0, zin2=0...zin(M0-1)=0; + // Create variables: uint zin0=0, zin1=0, zin2=0...zin(M0-1)=0; + REPEAT_VAR_INIT_TO_CONST(M0, uint, zin, 0); #if defined(REINTERPRET_INPUT_AS_3D) // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we @@ -323,7 +324,7 @@ __kernel void gemm_reshape_lhs_matrix_nt(TENSOR3D_DECLARATION(src), * @note The number of M0xK0 vertical blocks to store on the same output row must be passed at compile time using -DV0 (i.e. -DV0=2) * @note Only the following values for M0, K0 and V0 are supported: * M0: 2,3,4,5,6,7,8 - * K0: 2,4,8,16 + * K0: 2,3,4,8,16 * V0: greater than 0 * @note In case the input has to be reinterpreted as a 3D tensor (i.e. input of convolution layer 1x1), the following information must be passed at compile time: * -# REINTERPRET_INPUT_AS_3D: To reinterpret the input as 3D @@ -389,14 +390,8 @@ __kernel void gemm_reshape_lhs_matrix_t(TENSOR3D_DECLARATION(src), __global uchar *output_ptr = dst_ptr + dst_offset_first_element_in_bytes + (x * (uint)BLOCK_SIZE * (uint)V0 * sizeof(DATA_TYPE)) + ((y / (uint)V0) * (uint)dst_stride_y) + ((y % V0) * (uint)OUTPUT_OFFSET_X * sizeof(DATA_TYPE)); - uint zin0 = 0; - uint zin1 = 0; - uint zin2 = 0; - uint zin3 = 0; - uint zin4 = 0; - uint zin5 = 0; - uint zin6 = 0; - uint zin7 = 0; + // Create variables: uint zin0=0, zin1=0, zin2=0...zin(M0-1)=0; + REPEAT_VAR_INIT_TO_CONST(M0, uint, zin, 0); #if defined(REINTERPRET_INPUT_AS_3D) // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we @@ -509,8 +504,10 @@ __kernel void gemm_reshape_lhs_matrix_t(TENSOR3D_DECLARATION(src), TRANSPOSE_COLUMN_AND_STORE(output_ptr, OUTPUT_STEP_X, 1); #if K0 > 2 TRANSPOSE_COLUMN_AND_STORE(output_ptr, OUTPUT_STEP_X, 2); - TRANSPOSE_COLUMN_AND_STORE(output_ptr, OUTPUT_STEP_X, 3); #endif // K0 > 2 +#if K0 > 3 + TRANSPOSE_COLUMN_AND_STORE(output_ptr, OUTPUT_STEP_X, 3); +#endif // K0 > 3 #if K0 > 4 TRANSPOSE_COLUMN_AND_STORE(output_ptr, OUTPUT_STEP_X, 4); TRANSPOSE_COLUMN_AND_STORE(output_ptr, OUTPUT_STEP_X, 5); @@ -544,8 +541,8 @@ __kernel void gemm_reshape_lhs_matrix_t(TENSOR3D_DECLARATION(src), * @note The number of K0xN0 vertical blocks to store on the same output row must be passed at compile time using -DH0 (i.e. -DH0=2) * @note If the K0xN0 blocks have to be interleaved, the option -DINTERLEAVE must passed at compile time. * @note Only the following values for K0, N0 and H0 are supported: - * N0: 2,4,8,16 - * K0: 1,2,4,8,16 + * N0: 2,3,4,8,16 + * K0: 1,2,3,4,8,16 * H0: greater than 0 * * @param[in] src_ptr Pointer to the source RHS tensor. Supported data types: U8/S8/QASYMM8/U16/S16/F16/U32/S32/F32 @@ -618,11 +615,13 @@ __kernel void gemm_reshape_rhs_matrix_nt(TENSOR3D_DECLARATION(src), { a2 = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 2 * src_stride_y)); } +#endif // K0 > 2 +#if K0 > 3 if(y * (uint)K0 + 3 < SRC_HEIGHT) { a3 = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 3 * src_stride_y)); } -#endif // K0 > 2 +#endif // K0 > 3 #if K0 > 4 if(y * (uint)K0 + 4 < SRC_HEIGHT) { @@ -686,9 +685,11 @@ __kernel void gemm_reshape_rhs_matrix_nt(TENSOR3D_DECLARATION(src), #if K0 > 2 VSTORE(N0) (a2, 0, (__global DATA_TYPE *)(output_ptr + 2 * OUTPUT_STEP_X * sizeof(DATA_TYPE))); +#endif // K0 > 2 +#if K0 > 3 VSTORE(N0) (a3, 0, (__global DATA_TYPE *)(output_ptr + 3 * OUTPUT_STEP_X * sizeof(DATA_TYPE))); -#endif // K0 > 2 +#endif // K0 > 3 #if K0 > 4 VSTORE(N0) (a4, 0, (__global DATA_TYPE *)(output_ptr + 4 * OUTPUT_STEP_X * sizeof(DATA_TYPE))); @@ -734,8 +735,8 @@ __kernel void gemm_reshape_rhs_matrix_nt(TENSOR3D_DECLARATION(src), * @note If the K0xN0 blocks have to be interleaved, the option -DINTERLEAVE must passed at compile time. * @note The option -DTRANSPOSE must passed at compile time. * @note Only the following values for K0, N0 and H0 are supported: - * N0: 2,4,8,16 - * K0: 4,8,16 + * N0: 2,3,4,8,16 + * K0: 2,3,4,8,16 * H0: greater than 0 * * @param[in] src_ptr Pointer to the source RHS tensor. Supported data types: U8/S8/QASYMM8/U16/S16/F16/U32/S32/F32 @@ -798,14 +799,18 @@ __kernel void gemm_reshape_rhs_matrix_t(TENSOR3D_DECLARATION(src), { a1 = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 1 * src_stride_y)); } +#if K0 > 2 if(y * (uint)K0 + 2 < SRC_HEIGHT) { a2 = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 2 * src_stride_y)); } +#endif // K0 > 2 +#if K0 > 3 if(y * (uint)K0 + 3 < SRC_HEIGHT) { a3 = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 3 * src_stride_y)); } +#endif // K0 > 3 #if K0 > 4 if(y * (uint)K0 + 4 < SRC_HEIGHT) { @@ -862,7 +867,69 @@ __kernel void gemm_reshape_rhs_matrix_t(TENSOR3D_DECLARATION(src), // ---------------------------Transpose the block ------------------------------ REPEAT_VAR_INIT_TO_CONST(N0, VEC_DATA_TYPE(DATA_TYPE, K0), res, 0); //VEC_DATA_TYPE(DATA_TYPE, K0) res0=0, res1=0, res2=0,... res(N0-1)=0; -#if K0 == 4 +#if K0 == 2 + // This part computes the following transpositions: + // 2x2 -> 2x2 + // 2x4 -> 4x2 + // 2x8 -> 8x2 + // 2x16 -> 16x2 + res0 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s0, a1.s0); + res1 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s1, a1.s1); +#if N0 > 2 + res2 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s2, a1.s2); +#endif // N0 > 2 +#if N0 > 3 + res3 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s3, a1.s3); +#endif // N0 > 3 +#if N0 > 4 + res4 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s4, a1.s4); + res5 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s5, a1.s5); + res6 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s6, a1.s6); + res7 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s7, a1.s7); +#endif // N0 > 4 +#if N0 > 8 + res8 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s8, a1.s8); + res9 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s9, a1.s9); + resA = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sA, a1.sA); + resB = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sB, a1.sB); + resC = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sC, a1.sC); + resD = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sD, a1.sD); + resE = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sE, a1.sE); + resF = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sF, a1.sF); +#endif // N0 > 8 + +#elif K0 == 3 // K0 == 2 + // This part computes the following transpositions: + // 3x2 -> 2x3 + // 3x4 -> 4x3 + // 3x8 -> 8x3 + // 3x16 -> 16x3 + res0 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s0, a1.s0, a2.s0); + res1 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s1, a1.s1, a2.s1); +#if N0 > 2 + res2 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s2, a1.s2, a2.s2); +#endif // N0 > 2 +#if N0 > 3 + res3 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s3, a1.s3, a2.s3); +#endif // N0 > 3 +#if N0 > 4 + res4 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s4, a1.s4, a2.s4); + res5 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s5, a1.s5, a2.s5); + res6 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s6, a1.s6, a2.s6); + res7 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s7, a1.s7, a2.s7); +#endif // N0 > 4 +#if N0 > 8 + res8 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s8, a1.s8, a2.s8); + res9 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s9, a1.s9, a2.s9); + resA = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sA, a1.sA, a2.sA); + resB = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sB, a1.sB, a2.sB); + resC = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sC, a1.sC, a2.sC); + resD = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sD, a1.sD, a2.sD); + resE = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sE, a1.sE, a2.sE); + resF = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sF, a1.sF, a2.sF); +#endif // N0 > 8 + +#elif K0 == 4 // K0 == 4 // This part computes the following transpositions: // 4x2 -> 2x4 // 4x4 -> 4x4 @@ -872,8 +939,10 @@ __kernel void gemm_reshape_rhs_matrix_t(TENSOR3D_DECLARATION(src), res1 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s1, a1.s1, a2.s1, a3.s1); #if N0 > 2 res2 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s2, a1.s2, a2.s2, a3.s2); - res3 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s3, a1.s3, a2.s3, a3.s3); #endif // N0 > 2 +#if N0 > 3 + res3 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s3, a1.s3, a2.s3, a3.s3); +#endif // N0 > 3 #if N0 > 4 res4 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s4, a1.s4, a2.s4, a3.s4); res5 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s5, a1.s5, a2.s5, a3.s5); @@ -891,7 +960,7 @@ __kernel void gemm_reshape_rhs_matrix_t(TENSOR3D_DECLARATION(src), resF = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sF, a1.sF, a2.sF, a3.sF); #endif // N0 > 8 -#elif K0 == 8 // N0 == 3 +#elif K0 == 8 // K0 == 8 // This part computes the following transpositions: // 8x2 -> 2x8 // 8x4 -> 4x8 @@ -901,8 +970,10 @@ __kernel void gemm_reshape_rhs_matrix_t(TENSOR3D_DECLARATION(src), res1 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s1, a1.s1, a2.s1, a3.s1, a4.s1, a5.s1, a6.s1, a7.s1); #if N0 > 2 res2 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s2, a1.s2, a2.s2, a3.s2, a4.s2, a5.s2, a6.s2, a7.s2); - res3 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s3, a1.s3, a2.s3, a3.s3, a4.s3, a5.s3, a6.s3, a7.s3); #endif // N0 > 2 +#if N0 > 3 + res3 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s3, a1.s3, a2.s3, a3.s3, a4.s3, a5.s3, a6.s3, a7.s3); +#endif // N0 > 3 #if N0 > 4 res4 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s4, a1.s4, a2.s4, a3.s4, a4.s4, a5.s4, a6.s4, a7.s4); res5 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s5, a1.s5, a2.s5, a3.s5, a4.s5, a5.s5, a6.s5, a7.s5); @@ -920,7 +991,7 @@ __kernel void gemm_reshape_rhs_matrix_t(TENSOR3D_DECLARATION(src), resF = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sF, a1.sF, a2.sF, a3.sF, a4.sF, a5.sF, a6.sF, a7.sF); #endif // N0 > 8 -#elif K0 == 16 // N0 == 16 +#elif K0 == 16 // K0 == 16 // This part computes the following transpositions: // 16x2 -> 2x16 @@ -934,9 +1005,11 @@ __kernel void gemm_reshape_rhs_matrix_t(TENSOR3D_DECLARATION(src), #if N0 > 2 res2 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s2, a1.s2, a2.s2, a3.s2, a4.s2, a5.s2, a6.s2, a7.s2, a8.s2, a9.s2, aA.s2, aB.s2, aC.s2, aD.s2, aE.s2, aF.s2); +#endif // N0 > 2 +#if N0 > 3 res3 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s3, a1.s3, a2.s3, a3.s3, a4.s3, a5.s3, a6.s3, a7.s3, a8.s3, a9.s3, aA.s3, aB.s3, aC.s3, aD.s3, aE.s3, aF.s3); -#endif // N0 > 2 +#endif // N0 > 3 #if N0 > 4 res4 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s4, a1.s4, a2.s4, a3.s4, a4.s4, a5.s4, a6.s4, a7.s4, a8.s4, a9.s4, aA.s4, aB.s4, aC.s4, aD.s4, aE.s4, aF.s4); @@ -979,9 +1052,11 @@ __kernel void gemm_reshape_rhs_matrix_t(TENSOR3D_DECLARATION(src), #if N0 > 2 VSTORE(K0) (res2, 0, (__global DATA_TYPE *)(output_ptr + 2 * OUTPUT_STEP_X * sizeof(DATA_TYPE))); +#endif // N0 > 2 +#if N0 > 3 VSTORE(K0) (res3, 0, (__global DATA_TYPE *)(output_ptr + 3 * OUTPUT_STEP_X * sizeof(DATA_TYPE))); -#endif // N0 > 2 +#endif // N0 > 3 #if N0 > 4 VSTORE(K0) (res4, 0, (__global DATA_TYPE *)(output_ptr + 4 * OUTPUT_STEP_X * sizeof(DATA_TYPE))); @@ -1018,34 +1093,60 @@ __kernel void gemm_reshape_rhs_matrix_t(TENSOR3D_DECLARATION(src), #endif // defined(TRANSPOSE) #endif // defined(K0) && defined(N0) && defined(H0) && defined(DATA_TYPE) && defined(SRC_HEIGHT) -#if defined(M0) && defined(N0) && defined(K0) && defined(V0) && defined(H0) && defined(K) && defined(DATA_TYPE) +#if defined(M0) && defined(N0) && defined(K0) && defined(V0) && defined(H0) && defined(DATA_TYPE) -#define ARM_DOT(x, y, val) \ - ({ \ - val = fma(x.s0, y.s0, val); \ - val = fma(x.s1, y.s1, val); \ - val = fma(x.s2, y.s2, val); \ - val = fma(x.s3, y.s3, val); \ +#if K0 == 2 +#define ARM_DOT_K0(a, b, c) \ + ({ \ + c = fma(a.s0, b.s0, c); \ + c = fma(a.s1, b.s1, c); \ }) - -#if K0 == 4 -#define ARM_DOT_K0(a, b, c) \ - ({ \ - ARM_DOT(a, b, c); \ +#elif K0 == 3 // K0 == 3 +#define ARM_DOT_K0(a, b, c) \ + ({ \ + c = fma(a.s0, b.s0, c); \ + c = fma(a.s1, b.s1, c); \ + c = fma(a.s2, b.s2, c); \ + }) +#elif K0 == 4 // K0 == 4 +#define ARM_DOT_K0(a, b, c) \ + ({ \ + c = fma(a.s0, b.s0, c); \ + c = fma(a.s1, b.s1, c); \ + c = fma(a.s2, b.s2, c); \ + c = fma(a.s3, b.s3, c); \ }) #elif K0 == 8 // K0 == 8 -#define ARM_DOT_K0(a, b, c) \ - ({ \ - ARM_DOT((a).s0123, (b).s0123, c); \ - ARM_DOT((a).s4567, (b).s4567, c); \ +#define ARM_DOT_K0(a, b, c) \ + ({ \ + c = fma(a.s0, b.s0, c); \ + c = fma(a.s1, b.s1, c); \ + c = fma(a.s2, b.s2, c); \ + c = fma(a.s3, b.s3, c); \ + c = fma(a.s4, b.s4, c); \ + c = fma(a.s5, b.s5, c); \ + c = fma(a.s6, b.s6, c); \ + c = fma(a.s7, b.s7, c); \ }) #elif K0 == 16 // K0 == 16 -#define ARM_DOT_K0(a, b, c) \ - ({ \ - ARM_DOT((a).s0123, (b).s0123, c); \ - ARM_DOT((a).s4567, (b).s4567, c); \ - ARM_DOT((a).s89AB, (b).s89AB, c); \ - ARM_DOT((a).sCDEF, (b).sCDEF, c); \ +#define ARM_DOT_K0(a, b, c) \ + ({ \ + c = fma(a.s0, b.s0, c); \ + c = fma(a.s1, b.s1, c); \ + c = fma(a.s2, b.s2, c); \ + c = fma(a.s3, b.s3, c); \ + c = fma(a.s4, b.s4, c); \ + c = fma(a.s5, b.s5, c); \ + c = fma(a.s6, b.s6, c); \ + c = fma(a.s7, b.s7, c); \ + c = fma(a.s8, b.s8, c); \ + c = fma(a.s9, b.s9, c); \ + c = fma(a.sA, b.sA, c); \ + c = fma(a.sB, b.sB, c); \ + c = fma(a.sC, b.sC, c); \ + c = fma(a.sD, b.sD, c); \ + c = fma(a.sE, b.sE, c); \ + c = fma(a.sF, b.sF, c); \ }) #else // K0 not supported #error "K0 value not supported" @@ -1057,6 +1158,13 @@ __kernel void gemm_reshape_rhs_matrix_t(TENSOR3D_DECLARATION(src), ARM_DOT_K0((a), (b##0), (c.s0)); \ ARM_DOT_K0((a), (b##1), (c.s1)); \ }) +#elif N0 == 3 // N0 == 3 +#define ARM_DOT_K0XN0(a, b, c) \ + ({ \ + ARM_DOT_K0((a), (b##0), (c.s0)); \ + ARM_DOT_K0((a), (b##1), (c.s1)); \ + ARM_DOT_K0((a), (b##2), (c.s2)); \ + }) #elif N0 == 4 // N0 == 4 #define ARM_DOT_K0XN0(a, b, c) \ ({ \ @@ -1105,7 +1213,6 @@ __kernel void gemm_reshape_rhs_matrix_t(TENSOR3D_DECLARATION(src), * The LHS matrix must be reshaped with @ref CLGEMMReshapeLHSMatrixKernel and the M0xK0 must be NOT transposed * The RHS matrix must be reshaped with @ref CLGEMMReshapeRHSMatrixKernel and the K0xN0 must be transposed * - * @note The number of columns in the RHS matrix NOT reshaped needs to be passed at compile time using -DK (i.e. -Dk=128). * @note The block's dimensions used for reshaping the LHS matrix and the RHS matrix (M0, N0 and K0) must be passed at compile time using -DM0, -DN0 and -DK0 (i.e. -DM0=4, -DN0=8, -DK0=4). * @note The number of M0xK0 vertical blocks stored on the same output row of the reshaped LHS matrix must be passed at compile time using -DV0 (i.e. -DV0=2) * @note The number of K0xN0 horizontal blocks stored on the same output row of the reshaped RHS matrix must be passed at compile time using -DH0 (i.e. -DH0=2) @@ -1113,8 +1220,8 @@ __kernel void gemm_reshape_rhs_matrix_t(TENSOR3D_DECLARATION(src), * @note If the K0xN0 blocks in the reshaped RHS matrix have been interleaved, the option -DRHS_INTERLEAVE must passed at compile time. * @note Only the following configurations of M0, N0 and K0 are currently supported: * - M0 = 2, 3, 4, 5, 6, 7, 8 - * - N0 = 2, 4, 8, 16 - * - K0 = 4, 8, 16 + * - N0 = 2, 3, 4, 8, 16 + * - K0 = 2, 3, 4, 8, 16 * * @note In case the output has to be reinterpreted as a 3D tensor (i.e. output of convolution layer), the following information must be passed at compile time: * -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D @@ -1140,6 +1247,7 @@ __kernel void gemm_reshape_rhs_matrix_t(TENSOR3D_DECLARATION(src), * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes) * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes) * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix + * @param[in] k Number of columns in LHS matrix and rows in RHS matrix not reshaped. * @param[in] lhs_stride_z Stride of the LHS reshaped matrix in Z dimension (in bytes) * @param[in] rhs_stride_z Stride of the RHS reshaped matrix in Z dimension (in bytes) * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes) @@ -1148,6 +1256,7 @@ __kernel void gemm_reshape_rhs_matrix_t(TENSOR3D_DECLARATION(src), __kernel void gemm_mm_reshaped_lhs_nt_rhs_t(IMAGE_DECLARATION(lhs), IMAGE_DECLARATION(rhs), IMAGE_DECLARATION(dst), + uint k, uint lhs_stride_z, uint rhs_stride_z, uint dst_stride_z @@ -1201,7 +1310,7 @@ __kernel void gemm_mm_reshaped_lhs_nt_rhs_t(IMAGE_DECLARATION(lhs), // Initialize the accumulators REPEAT_VAR_INIT_TO_CONST(M0, VEC_DATA_TYPE(DATA_TYPE, N0), c, 0); //VEC_DATA_TYPE(DATA_TYPE, N0) c0=0,c1=0,c2=0,... c(M0-1)=0; - for(int i = 0; i < K; i += K0) + for(int i = 0; i < k; i += K0) { // Supported cases (M0, K0): // 2,4 - 2,8 - 2,16 @@ -1249,9 +1358,11 @@ __kernel void gemm_mm_reshaped_lhs_nt_rhs_t(IMAGE_DECLARATION(lhs), #if N0 > 2 VEC_DATA_TYPE(DATA_TYPE, K0) b2 = VLOAD(K0)(0, (__global DATA_TYPE *)(rhs_addr + 2 * RHS_STEP_X * sizeof(DATA_TYPE))); +#endif // N0 > 2 +#if N0 > 3 VEC_DATA_TYPE(DATA_TYPE, K0) b3 = VLOAD(K0)(0, (__global DATA_TYPE *)(rhs_addr + 3 * RHS_STEP_X * sizeof(DATA_TYPE))); -#endif // N0 > 2 +#endif // N0 > 3 #if N0 > 4 VEC_DATA_TYPE(DATA_TYPE, K0) b4 = VLOAD(K0)(0, (__global DATA_TYPE *)(rhs_addr + 4 * RHS_STEP_X * sizeof(DATA_TYPE))); @@ -1363,7 +1474,7 @@ __kernel void gemm_mm_reshaped_lhs_nt_rhs_t(IMAGE_DECLARATION(lhs), zout6 = min((uint)(DEPTH_GEMM3D - 1), zout6); zout6 *= (dst_cross_plane_pad * dst_stride_y); #endif // M0 > 6 -#if M0 > 6 +#if M0 > 7 zout7 = (7 + (uint)(get_global_id(1) * (uint)M0)) / (uint)HEIGHT_GEMM3D; zout7 = min((uint)(DEPTH_GEMM3D - 1), zout7); zout7 *= (dst_cross_plane_pad * dst_stride_y); @@ -1438,7 +1549,6 @@ __kernel void gemm_mm_reshaped_lhs_nt_rhs_t(IMAGE_DECLARATION(lhs), (c7, 0, (__global DATA_TYPE *)(dst_addr + 7 * dst_stride_y + zout7)); #endif // M0 > 7 - #undef LHS_BLOCK_SIZE #undef LHS_OFFSET_X #undef LHS_STEP_X -- cgit v1.2.1