aboutsummaryrefslogtreecommitdiff
path: root/src/core/CL/cl_kernels/gemm.cl
diff options
context:
space:
mode:
Diffstat (limited to 'src/core/CL/cl_kernels/gemm.cl')
-rw-r--r--src/core/CL/cl_kernels/gemm.cl218
1 files changed, 164 insertions, 54 deletions
diff --git a/src/core/CL/cl_kernels/gemm.cl b/src/core/CL/cl_kernels/gemm.cl
index 9dd072bd6e..3a76b74b2f 100644
--- a/src/core/CL/cl_kernels/gemm.cl
+++ b/src/core/CL/cl_kernels/gemm.cl
@@ -34,7 +34,7 @@
* @note The number of M0xK0 vertical blocks to store on the same output row must be passed at compile time using -DV0 (i.e. -DV0=2)
* @note Only the following values for M0, K0 and V0 are supported:
* M0: 2,3,4,5,6,7,8
- * K0: 2,4,8,16
+ * K0: 2,3,4,8,16
* V0: greater than 0
* @note In case the input has to be reinterpreted as a 3D tensor (i.e. input of convolution layer 1x1), the following information must be passed at compile time:
* -# REINTERPRET_INPUT_AS_3D: To reinterpret the input as 3D
@@ -100,7 +100,8 @@ __kernel void gemm_reshape_lhs_matrix_nt(TENSOR3D_DECLARATION(src),
__global uchar *output_ptr = dst_ptr + dst_offset_first_element_in_bytes + (x * (uint)BLOCK_SIZE * (uint)V0 * sizeof(DATA_TYPE)) + ((y / (uint)V0) * (uint)dst_stride_y) + ((y % V0) *
(uint)OUTPUT_OFFSET_X * sizeof(DATA_TYPE));
- REPEAT_VAR_INIT_TO_CONST(M0, uint, zin, 0); //uint zin0=0, zin1=0, zin2=0...zin(M0-1)=0;
+ // Create variables: uint zin0=0, zin1=0, zin2=0...zin(M0-1)=0;
+ REPEAT_VAR_INIT_TO_CONST(M0, uint, zin, 0);
#if defined(REINTERPRET_INPUT_AS_3D)
// Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
@@ -323,7 +324,7 @@ __kernel void gemm_reshape_lhs_matrix_nt(TENSOR3D_DECLARATION(src),
* @note The number of M0xK0 vertical blocks to store on the same output row must be passed at compile time using -DV0 (i.e. -DV0=2)
* @note Only the following values for M0, K0 and V0 are supported:
* M0: 2,3,4,5,6,7,8
- * K0: 2,4,8,16
+ * K0: 2,3,4,8,16
* V0: greater than 0
* @note In case the input has to be reinterpreted as a 3D tensor (i.e. input of convolution layer 1x1), the following information must be passed at compile time:
* -# REINTERPRET_INPUT_AS_3D: To reinterpret the input as 3D
@@ -389,14 +390,8 @@ __kernel void gemm_reshape_lhs_matrix_t(TENSOR3D_DECLARATION(src),
__global uchar *output_ptr = dst_ptr + dst_offset_first_element_in_bytes + (x * (uint)BLOCK_SIZE * (uint)V0 * sizeof(DATA_TYPE)) + ((y / (uint)V0) * (uint)dst_stride_y) + ((y % V0) *
(uint)OUTPUT_OFFSET_X * sizeof(DATA_TYPE));
- uint zin0 = 0;
- uint zin1 = 0;
- uint zin2 = 0;
- uint zin3 = 0;
- uint zin4 = 0;
- uint zin5 = 0;
- uint zin6 = 0;
- uint zin7 = 0;
+ // Create variables: uint zin0=0, zin1=0, zin2=0...zin(M0-1)=0;
+ REPEAT_VAR_INIT_TO_CONST(M0, uint, zin, 0);
#if defined(REINTERPRET_INPUT_AS_3D)
// Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
@@ -509,8 +504,10 @@ __kernel void gemm_reshape_lhs_matrix_t(TENSOR3D_DECLARATION(src),
TRANSPOSE_COLUMN_AND_STORE(output_ptr, OUTPUT_STEP_X, 1);
#if K0 > 2
TRANSPOSE_COLUMN_AND_STORE(output_ptr, OUTPUT_STEP_X, 2);
- TRANSPOSE_COLUMN_AND_STORE(output_ptr, OUTPUT_STEP_X, 3);
#endif // K0 > 2
+#if K0 > 3
+ TRANSPOSE_COLUMN_AND_STORE(output_ptr, OUTPUT_STEP_X, 3);
+#endif // K0 > 3
#if K0 > 4
TRANSPOSE_COLUMN_AND_STORE(output_ptr, OUTPUT_STEP_X, 4);
TRANSPOSE_COLUMN_AND_STORE(output_ptr, OUTPUT_STEP_X, 5);
@@ -544,8 +541,8 @@ __kernel void gemm_reshape_lhs_matrix_t(TENSOR3D_DECLARATION(src),
* @note The number of K0xN0 vertical blocks to store on the same output row must be passed at compile time using -DH0 (i.e. -DH0=2)
* @note If the K0xN0 blocks have to be interleaved, the option -DINTERLEAVE must passed at compile time.
* @note Only the following values for K0, N0 and H0 are supported:
- * N0: 2,4,8,16
- * K0: 1,2,4,8,16
+ * N0: 2,3,4,8,16
+ * K0: 1,2,3,4,8,16
* H0: greater than 0
*
* @param[in] src_ptr Pointer to the source RHS tensor. Supported data types: U8/S8/QASYMM8/U16/S16/F16/U32/S32/F32
@@ -618,11 +615,13 @@ __kernel void gemm_reshape_rhs_matrix_nt(TENSOR3D_DECLARATION(src),
{
a2 = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 2 * src_stride_y));
}
+#endif // K0 > 2
+#if K0 > 3
if(y * (uint)K0 + 3 < SRC_HEIGHT)
{
a3 = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 3 * src_stride_y));
}
-#endif // K0 > 2
+#endif // K0 > 3
#if K0 > 4
if(y * (uint)K0 + 4 < SRC_HEIGHT)
{
@@ -686,9 +685,11 @@ __kernel void gemm_reshape_rhs_matrix_nt(TENSOR3D_DECLARATION(src),
#if K0 > 2
VSTORE(N0)
(a2, 0, (__global DATA_TYPE *)(output_ptr + 2 * OUTPUT_STEP_X * sizeof(DATA_TYPE)));
+#endif // K0 > 2
+#if K0 > 3
VSTORE(N0)
(a3, 0, (__global DATA_TYPE *)(output_ptr + 3 * OUTPUT_STEP_X * sizeof(DATA_TYPE)));
-#endif // K0 > 2
+#endif // K0 > 3
#if K0 > 4
VSTORE(N0)
(a4, 0, (__global DATA_TYPE *)(output_ptr + 4 * OUTPUT_STEP_X * sizeof(DATA_TYPE)));
@@ -734,8 +735,8 @@ __kernel void gemm_reshape_rhs_matrix_nt(TENSOR3D_DECLARATION(src),
* @note If the K0xN0 blocks have to be interleaved, the option -DINTERLEAVE must passed at compile time.
* @note The option -DTRANSPOSE must passed at compile time.
* @note Only the following values for K0, N0 and H0 are supported:
- * N0: 2,4,8,16
- * K0: 4,8,16
+ * N0: 2,3,4,8,16
+ * K0: 2,3,4,8,16
* H0: greater than 0
*
* @param[in] src_ptr Pointer to the source RHS tensor. Supported data types: U8/S8/QASYMM8/U16/S16/F16/U32/S32/F32
@@ -798,14 +799,18 @@ __kernel void gemm_reshape_rhs_matrix_t(TENSOR3D_DECLARATION(src),
{
a1 = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 1 * src_stride_y));
}
+#if K0 > 2
if(y * (uint)K0 + 2 < SRC_HEIGHT)
{
a2 = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 2 * src_stride_y));
}
+#endif // K0 > 2
+#if K0 > 3
if(y * (uint)K0 + 3 < SRC_HEIGHT)
{
a3 = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 3 * src_stride_y));
}
+#endif // K0 > 3
#if K0 > 4
if(y * (uint)K0 + 4 < SRC_HEIGHT)
{
@@ -862,7 +867,69 @@ __kernel void gemm_reshape_rhs_matrix_t(TENSOR3D_DECLARATION(src),
// ---------------------------Transpose the block ------------------------------
REPEAT_VAR_INIT_TO_CONST(N0, VEC_DATA_TYPE(DATA_TYPE, K0), res, 0); //VEC_DATA_TYPE(DATA_TYPE, K0) res0=0, res1=0, res2=0,... res(N0-1)=0;
-#if K0 == 4
+#if K0 == 2
+ // This part computes the following transpositions:
+ // 2x2 -> 2x2
+ // 2x4 -> 4x2
+ // 2x8 -> 8x2
+ // 2x16 -> 16x2
+ res0 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s0, a1.s0);
+ res1 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s1, a1.s1);
+#if N0 > 2
+ res2 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s2, a1.s2);
+#endif // N0 > 2
+#if N0 > 3
+ res3 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s3, a1.s3);
+#endif // N0 > 3
+#if N0 > 4
+ res4 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s4, a1.s4);
+ res5 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s5, a1.s5);
+ res6 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s6, a1.s6);
+ res7 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s7, a1.s7);
+#endif // N0 > 4
+#if N0 > 8
+ res8 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s8, a1.s8);
+ res9 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s9, a1.s9);
+ resA = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sA, a1.sA);
+ resB = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sB, a1.sB);
+ resC = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sC, a1.sC);
+ resD = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sD, a1.sD);
+ resE = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sE, a1.sE);
+ resF = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sF, a1.sF);
+#endif // N0 > 8
+
+#elif K0 == 3 // K0 == 2
+ // This part computes the following transpositions:
+ // 3x2 -> 2x3
+ // 3x4 -> 4x3
+ // 3x8 -> 8x3
+ // 3x16 -> 16x3
+ res0 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s0, a1.s0, a2.s0);
+ res1 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s1, a1.s1, a2.s1);
+#if N0 > 2
+ res2 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s2, a1.s2, a2.s2);
+#endif // N0 > 2
+#if N0 > 3
+ res3 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s3, a1.s3, a2.s3);
+#endif // N0 > 3
+#if N0 > 4
+ res4 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s4, a1.s4, a2.s4);
+ res5 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s5, a1.s5, a2.s5);
+ res6 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s6, a1.s6, a2.s6);
+ res7 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s7, a1.s7, a2.s7);
+#endif // N0 > 4
+#if N0 > 8
+ res8 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s8, a1.s8, a2.s8);
+ res9 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s9, a1.s9, a2.s9);
+ resA = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sA, a1.sA, a2.sA);
+ resB = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sB, a1.sB, a2.sB);
+ resC = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sC, a1.sC, a2.sC);
+ resD = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sD, a1.sD, a2.sD);
+ resE = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sE, a1.sE, a2.sE);
+ resF = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sF, a1.sF, a2.sF);
+#endif // N0 > 8
+
+#elif K0 == 4 // K0 == 4
// This part computes the following transpositions:
// 4x2 -> 2x4
// 4x4 -> 4x4
@@ -872,8 +939,10 @@ __kernel void gemm_reshape_rhs_matrix_t(TENSOR3D_DECLARATION(src),
res1 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s1, a1.s1, a2.s1, a3.s1);
#if N0 > 2
res2 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s2, a1.s2, a2.s2, a3.s2);
- res3 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s3, a1.s3, a2.s3, a3.s3);
#endif // N0 > 2
+#if N0 > 3
+ res3 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s3, a1.s3, a2.s3, a3.s3);
+#endif // N0 > 3
#if N0 > 4
res4 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s4, a1.s4, a2.s4, a3.s4);
res5 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s5, a1.s5, a2.s5, a3.s5);
@@ -891,7 +960,7 @@ __kernel void gemm_reshape_rhs_matrix_t(TENSOR3D_DECLARATION(src),
resF = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sF, a1.sF, a2.sF, a3.sF);
#endif // N0 > 8
-#elif K0 == 8 // N0 == 3
+#elif K0 == 8 // K0 == 8
// This part computes the following transpositions:
// 8x2 -> 2x8
// 8x4 -> 4x8
@@ -901,8 +970,10 @@ __kernel void gemm_reshape_rhs_matrix_t(TENSOR3D_DECLARATION(src),
res1 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s1, a1.s1, a2.s1, a3.s1, a4.s1, a5.s1, a6.s1, a7.s1);
#if N0 > 2
res2 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s2, a1.s2, a2.s2, a3.s2, a4.s2, a5.s2, a6.s2, a7.s2);
- res3 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s3, a1.s3, a2.s3, a3.s3, a4.s3, a5.s3, a6.s3, a7.s3);
#endif // N0 > 2
+#if N0 > 3
+ res3 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s3, a1.s3, a2.s3, a3.s3, a4.s3, a5.s3, a6.s3, a7.s3);
+#endif // N0 > 3
#if N0 > 4
res4 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s4, a1.s4, a2.s4, a3.s4, a4.s4, a5.s4, a6.s4, a7.s4);
res5 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s5, a1.s5, a2.s5, a3.s5, a4.s5, a5.s5, a6.s5, a7.s5);
@@ -920,7 +991,7 @@ __kernel void gemm_reshape_rhs_matrix_t(TENSOR3D_DECLARATION(src),
resF = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sF, a1.sF, a2.sF, a3.sF, a4.sF, a5.sF, a6.sF, a7.sF);
#endif // N0 > 8
-#elif K0 == 16 // N0 == 16
+#elif K0 == 16 // K0 == 16
// This part computes the following transpositions:
// 16x2 -> 2x16
@@ -934,9 +1005,11 @@ __kernel void gemm_reshape_rhs_matrix_t(TENSOR3D_DECLARATION(src),
#if N0 > 2
res2 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s2, a1.s2, a2.s2, a3.s2, a4.s2, a5.s2, a6.s2, a7.s2,
a8.s2, a9.s2, aA.s2, aB.s2, aC.s2, aD.s2, aE.s2, aF.s2);
+#endif // N0 > 2
+#if N0 > 3
res3 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s3, a1.s3, a2.s3, a3.s3, a4.s3, a5.s3, a6.s3, a7.s3,
a8.s3, a9.s3, aA.s3, aB.s3, aC.s3, aD.s3, aE.s3, aF.s3);
-#endif // N0 > 2
+#endif // N0 > 3
#if N0 > 4
res4 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s4, a1.s4, a2.s4, a3.s4, a4.s4, a5.s4, a6.s4, a7.s4,
a8.s4, a9.s4, aA.s4, aB.s4, aC.s4, aD.s4, aE.s4, aF.s4);
@@ -979,9 +1052,11 @@ __kernel void gemm_reshape_rhs_matrix_t(TENSOR3D_DECLARATION(src),
#if N0 > 2
VSTORE(K0)
(res2, 0, (__global DATA_TYPE *)(output_ptr + 2 * OUTPUT_STEP_X * sizeof(DATA_TYPE)));
+#endif // N0 > 2
+#if N0 > 3
VSTORE(K0)
(res3, 0, (__global DATA_TYPE *)(output_ptr + 3 * OUTPUT_STEP_X * sizeof(DATA_TYPE)));
-#endif // N0 > 2
+#endif // N0 > 3
#if N0 > 4
VSTORE(K0)
(res4, 0, (__global DATA_TYPE *)(output_ptr + 4 * OUTPUT_STEP_X * sizeof(DATA_TYPE)));
@@ -1018,34 +1093,60 @@ __kernel void gemm_reshape_rhs_matrix_t(TENSOR3D_DECLARATION(src),
#endif // defined(TRANSPOSE)
#endif // defined(K0) && defined(N0) && defined(H0) && defined(DATA_TYPE) && defined(SRC_HEIGHT)
-#if defined(M0) && defined(N0) && defined(K0) && defined(V0) && defined(H0) && defined(K) && defined(DATA_TYPE)
+#if defined(M0) && defined(N0) && defined(K0) && defined(V0) && defined(H0) && defined(DATA_TYPE)
-#define ARM_DOT(x, y, val) \
- ({ \
- val = fma(x.s0, y.s0, val); \
- val = fma(x.s1, y.s1, val); \
- val = fma(x.s2, y.s2, val); \
- val = fma(x.s3, y.s3, val); \
+#if K0 == 2
+#define ARM_DOT_K0(a, b, c) \
+ ({ \
+ c = fma(a.s0, b.s0, c); \
+ c = fma(a.s1, b.s1, c); \
})
-
-#if K0 == 4
-#define ARM_DOT_K0(a, b, c) \
- ({ \
- ARM_DOT(a, b, c); \
+#elif K0 == 3 // K0 == 3
+#define ARM_DOT_K0(a, b, c) \
+ ({ \
+ c = fma(a.s0, b.s0, c); \
+ c = fma(a.s1, b.s1, c); \
+ c = fma(a.s2, b.s2, c); \
+ })
+#elif K0 == 4 // K0 == 4
+#define ARM_DOT_K0(a, b, c) \
+ ({ \
+ c = fma(a.s0, b.s0, c); \
+ c = fma(a.s1, b.s1, c); \
+ c = fma(a.s2, b.s2, c); \
+ c = fma(a.s3, b.s3, c); \
})
#elif K0 == 8 // K0 == 8
-#define ARM_DOT_K0(a, b, c) \
- ({ \
- ARM_DOT((a).s0123, (b).s0123, c); \
- ARM_DOT((a).s4567, (b).s4567, c); \
+#define ARM_DOT_K0(a, b, c) \
+ ({ \
+ c = fma(a.s0, b.s0, c); \
+ c = fma(a.s1, b.s1, c); \
+ c = fma(a.s2, b.s2, c); \
+ c = fma(a.s3, b.s3, c); \
+ c = fma(a.s4, b.s4, c); \
+ c = fma(a.s5, b.s5, c); \
+ c = fma(a.s6, b.s6, c); \
+ c = fma(a.s7, b.s7, c); \
})
#elif K0 == 16 // K0 == 16
-#define ARM_DOT_K0(a, b, c) \
- ({ \
- ARM_DOT((a).s0123, (b).s0123, c); \
- ARM_DOT((a).s4567, (b).s4567, c); \
- ARM_DOT((a).s89AB, (b).s89AB, c); \
- ARM_DOT((a).sCDEF, (b).sCDEF, c); \
+#define ARM_DOT_K0(a, b, c) \
+ ({ \
+ c = fma(a.s0, b.s0, c); \
+ c = fma(a.s1, b.s1, c); \
+ c = fma(a.s2, b.s2, c); \
+ c = fma(a.s3, b.s3, c); \
+ c = fma(a.s4, b.s4, c); \
+ c = fma(a.s5, b.s5, c); \
+ c = fma(a.s6, b.s6, c); \
+ c = fma(a.s7, b.s7, c); \
+ c = fma(a.s8, b.s8, c); \
+ c = fma(a.s9, b.s9, c); \
+ c = fma(a.sA, b.sA, c); \
+ c = fma(a.sB, b.sB, c); \
+ c = fma(a.sC, b.sC, c); \
+ c = fma(a.sD, b.sD, c); \
+ c = fma(a.sE, b.sE, c); \
+ c = fma(a.sF, b.sF, c); \
})
#else // K0 not supported
#error "K0 value not supported"
@@ -1057,6 +1158,13 @@ __kernel void gemm_reshape_rhs_matrix_t(TENSOR3D_DECLARATION(src),
ARM_DOT_K0((a), (b##0), (c.s0)); \
ARM_DOT_K0((a), (b##1), (c.s1)); \
})
+#elif N0 == 3 // N0 == 3
+#define ARM_DOT_K0XN0(a, b, c) \
+ ({ \
+ ARM_DOT_K0((a), (b##0), (c.s0)); \
+ ARM_DOT_K0((a), (b##1), (c.s1)); \
+ ARM_DOT_K0((a), (b##2), (c.s2)); \
+ })
#elif N0 == 4 // N0 == 4
#define ARM_DOT_K0XN0(a, b, c) \
({ \
@@ -1105,7 +1213,6 @@ __kernel void gemm_reshape_rhs_matrix_t(TENSOR3D_DECLARATION(src),
* The LHS matrix must be reshaped with @ref CLGEMMReshapeLHSMatrixKernel and the M0xK0 must be NOT transposed
* The RHS matrix must be reshaped with @ref CLGEMMReshapeRHSMatrixKernel and the K0xN0 must be transposed
*
- * @note The number of columns in the RHS matrix NOT reshaped needs to be passed at compile time using -DK (i.e. -Dk=128).
* @note The block's dimensions used for reshaping the LHS matrix and the RHS matrix (M0, N0 and K0) must be passed at compile time using -DM0, -DN0 and -DK0 (i.e. -DM0=4, -DN0=8, -DK0=4).
* @note The number of M0xK0 vertical blocks stored on the same output row of the reshaped LHS matrix must be passed at compile time using -DV0 (i.e. -DV0=2)
* @note The number of K0xN0 horizontal blocks stored on the same output row of the reshaped RHS matrix must be passed at compile time using -DH0 (i.e. -DH0=2)
@@ -1113,8 +1220,8 @@ __kernel void gemm_reshape_rhs_matrix_t(TENSOR3D_DECLARATION(src),
* @note If the K0xN0 blocks in the reshaped RHS matrix have been interleaved, the option -DRHS_INTERLEAVE must passed at compile time.
* @note Only the following configurations of M0, N0 and K0 are currently supported:
* - M0 = 2, 3, 4, 5, 6, 7, 8
- * - N0 = 2, 4, 8, 16
- * - K0 = 4, 8, 16
+ * - N0 = 2, 3, 4, 8, 16
+ * - K0 = 2, 3, 4, 8, 16
*
* @note In case the output has to be reinterpreted as a 3D tensor (i.e. output of convolution layer), the following information must be passed at compile time:
* -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
@@ -1140,6 +1247,7 @@ __kernel void gemm_reshape_rhs_matrix_t(TENSOR3D_DECLARATION(src),
* @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
* @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
* @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
+ * @param[in] k Number of columns in LHS matrix and rows in RHS matrix not reshaped.
* @param[in] lhs_stride_z Stride of the LHS reshaped matrix in Z dimension (in bytes)
* @param[in] rhs_stride_z Stride of the RHS reshaped matrix in Z dimension (in bytes)
* @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
@@ -1148,6 +1256,7 @@ __kernel void gemm_reshape_rhs_matrix_t(TENSOR3D_DECLARATION(src),
__kernel void gemm_mm_reshaped_lhs_nt_rhs_t(IMAGE_DECLARATION(lhs),
IMAGE_DECLARATION(rhs),
IMAGE_DECLARATION(dst),
+ uint k,
uint lhs_stride_z,
uint rhs_stride_z,
uint dst_stride_z
@@ -1201,7 +1310,7 @@ __kernel void gemm_mm_reshaped_lhs_nt_rhs_t(IMAGE_DECLARATION(lhs),
// Initialize the accumulators
REPEAT_VAR_INIT_TO_CONST(M0, VEC_DATA_TYPE(DATA_TYPE, N0), c, 0); //VEC_DATA_TYPE(DATA_TYPE, N0) c0=0,c1=0,c2=0,... c(M0-1)=0;
- for(int i = 0; i < K; i += K0)
+ for(int i = 0; i < k; i += K0)
{
// Supported cases (M0, K0):
// 2,4 - 2,8 - 2,16
@@ -1249,9 +1358,11 @@ __kernel void gemm_mm_reshaped_lhs_nt_rhs_t(IMAGE_DECLARATION(lhs),
#if N0 > 2
VEC_DATA_TYPE(DATA_TYPE, K0)
b2 = VLOAD(K0)(0, (__global DATA_TYPE *)(rhs_addr + 2 * RHS_STEP_X * sizeof(DATA_TYPE)));
+#endif // N0 > 2
+#if N0 > 3
VEC_DATA_TYPE(DATA_TYPE, K0)
b3 = VLOAD(K0)(0, (__global DATA_TYPE *)(rhs_addr + 3 * RHS_STEP_X * sizeof(DATA_TYPE)));
-#endif // N0 > 2
+#endif // N0 > 3
#if N0 > 4
VEC_DATA_TYPE(DATA_TYPE, K0)
b4 = VLOAD(K0)(0, (__global DATA_TYPE *)(rhs_addr + 4 * RHS_STEP_X * sizeof(DATA_TYPE)));
@@ -1363,7 +1474,7 @@ __kernel void gemm_mm_reshaped_lhs_nt_rhs_t(IMAGE_DECLARATION(lhs),
zout6 = min((uint)(DEPTH_GEMM3D - 1), zout6);
zout6 *= (dst_cross_plane_pad * dst_stride_y);
#endif // M0 > 6
-#if M0 > 6
+#if M0 > 7
zout7 = (7 + (uint)(get_global_id(1) * (uint)M0)) / (uint)HEIGHT_GEMM3D;
zout7 = min((uint)(DEPTH_GEMM3D - 1), zout7);
zout7 *= (dst_cross_plane_pad * dst_stride_y);
@@ -1438,7 +1549,6 @@ __kernel void gemm_mm_reshaped_lhs_nt_rhs_t(IMAGE_DECLARATION(lhs),
(c7, 0, (__global DATA_TYPE *)(dst_addr + 7 * dst_stride_y + zout7));
#endif // M0 > 7
-
#undef LHS_BLOCK_SIZE
#undef LHS_OFFSET_X
#undef LHS_STEP_X