aboutsummaryrefslogtreecommitdiff
path: root/src/core/CL/cl_kernels/gemm.cl
diff options
context:
space:
mode:
authorGian Marco Iodice <gianmarco.iodice@arm.com>2019-06-14 16:11:10 +0100
committerGeorgios Pinitas <georgios.pinitas@arm.com>2019-06-20 16:02:39 +0000
commite16c8906a2aedf00e910754a01fca8bc4189cfc7 (patch)
treede9b88917bb00a76a9df68c9e92f05e38c5de817 /src/core/CL/cl_kernels/gemm.cl
parent0cbfda629dd8f684e625173341bab972f004222c (diff)
downloadComputeLibrary-e16c8906a2aedf00e910754a01fca8bc4189cfc7.tar.gz
COMPMID-2053: Fuse bias addition with CLGEMMMatrixMultiplyReshapedKernel
Change-Id: I5bfd38c94a6fd18a1cba2104f7e1b04e7bef6ec2 Signed-off-by: Gian Marco Iodice <gianmarco.iodice@arm.com> Reviewed-on: https://review.mlplatform.org/c/1359 Comments-Addressed: Arm Jenkins <bsgcomp@arm.com> Reviewed-by: Georgios Pinitas <georgios.pinitas@arm.com> Tested-by: Arm Jenkins <bsgcomp@arm.com>
Diffstat (limited to 'src/core/CL/cl_kernels/gemm.cl')
-rw-r--r--src/core/CL/cl_kernels/gemm.cl119
1 files changed, 82 insertions, 37 deletions
diff --git a/src/core/CL/cl_kernels/gemm.cl b/src/core/CL/cl_kernels/gemm.cl
index 2ac2eb7c32..7ada14c774 100644
--- a/src/core/CL/cl_kernels/gemm.cl
+++ b/src/core/CL/cl_kernels/gemm.cl
@@ -1042,11 +1042,12 @@ __kernel void gemm_reshape_rhs_matrix_t(TENSOR3D_DECLARATION(src),
* @param[in] rhs_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
* @param[in] rhs_offset_first_element_in_bytes The offset of the first element in the RHS reshaped matrix
* @param[in] bias_ptr (Optional)Pointer to the bias reshaped matrix. Supported data type: same as @p lhs_ptr
- * @param[in] bias_stride_x (Optional)Stride of the bias reshaped matrix in X dimension (in bytes)
- * @param[in] bias_step_x (Optional)bias_stride_x * number of elements along X processed per workitem(in bytes)
- * @param[in] bias_stride_y (Optional)Stride of the bias reshaped matrix in Y dimension (in bytes)
- * @param[in] bias_step_y (Optional)bias_stride_y * number of elements along Y processed per workitem(in bytes)
- * @param[in] bias_offset_first_element_in_bytes (Optional)The offset of the first element in the bias reshaped matrix
+ * @param[in] bias_ptr (Optional) Pointer to the bias matrix. Supported data type: same as @p lhs_ptr
+ * @param[in] bias_stride_x (Optional) Stride of the bias matrix in X dimension (in bytes)
+ * @param[in] bias_step_x (Optional) bias_stride_x * number of elements along X processed per workitem(in bytes)
+ * @param[in] bias_stride_y (Optional) Stride of the bias matrix in Y dimension (in bytes)
+ * @param[in] bias_step_y (Optional) bias_stride_y * number of elements along Y processed per workitem(in bytes)
+ * @param[in] bias_offset_first_element_in_bytes (Optional) The offset of the first element in the bias matrix
* @param[out] dst_ptr Pointer to the destination matrix Supported data type: same as @p lhs_ptr
* @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
* @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
@@ -1055,7 +1056,7 @@ __kernel void gemm_reshape_rhs_matrix_t(TENSOR3D_DECLARATION(src),
* @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
* @param[in] lhs_stride_z Stride of the LHS reshaped matrix in Z dimension (in bytes)
* @param[in] rhs_stride_z Stride of the RHS reshaped matrix in Z dimension (in bytes)
- * @param[in] bias_stride_z (Optional) Stride of the bias matrix in Z dimension (in bytes)
+ * @param[in] bias_stride_z (Optional) Stride of the bias matrix in Z dimension (in bytes)
* @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
* @param[in] lhs_cross_plane_pad (Optional) Bottom paddings for LHS matrix in unit of elements (only if defined REINTERPRET_INPUT_AS_3D)
* @param[in] dst_cross_plane_pad (Optional) Bottom paddings for the output matrix in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
@@ -1415,10 +1416,10 @@ __kernel void gemm_mm_reshaped_only_rhs_t(IMAGE_DECLARATION(lhs),
* @param[in] rhs_stride_y Stride of the RHS reshaped matrix in Y dimension (in bytes)
* @param[in] rhs_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
* @param[in] rhs_offset_first_element_in_bytes The offset of the first element in the RHS reshaped matrix
- * @param[in] bias_ptr (Optional) Pointer to the bias reshaped matrix. Supported data type: same as @p lhs_ptr
- * @param[in] bias_stride_x (Optional) Stride of the bias reshaped matrix in X dimension (in bytes)
+ * @param[in] bias_ptr (Optional) Pointer to the bias matrix. Supported data type: same as @p lhs_ptr
+ * @param[in] bias_stride_x (Optional) Stride of the bias matrix in X dimension (in bytes)
* @param[in] bias_step_x (Optional) bias_stride_x * number of elements along X processed per workitem(in bytes)
- * @param[in] bias_stride_y (Optional) Stride of the bias reshaped matrix in Y dimension (in bytes)
+ * @param[in] bias_stride_y (Optional) Stride of the bias matrix in Y dimension (in bytes)
* @param[in] bias_step_y (Optional) bias_stride_y * number of elements along Y processed per workitem(in bytes)
* @param[in] bias_offset_first_element_in_bytes (Optional) The offset of the first element in the bias matrix
* @param[out] dst_ptr Pointer to the destination matrix Supported data type: same as @p lhs_ptr
@@ -1429,7 +1430,7 @@ __kernel void gemm_mm_reshaped_only_rhs_t(IMAGE_DECLARATION(lhs),
* @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
* @param[in] lhs_stride_z Stride of the LHS reshaped matrix in Z dimension (in bytes)
* @param[in] rhs_stride_z Stride of the RHS reshaped matrix in Z dimension (in bytes)
- * @param[in] bias_stride_z (Optional)Stride of the bias reshaped matrix in Z dimension (in bytes)
+ * @param[in] bias_stride_z (Optional) Stride of the bias matrix in Z dimension (in bytes)
* @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
* @param[in] lhs_cross_plane_pad (Optional) Bottom paddings for LHS matrix in unit of elements (only if defined REINTERPRET_INPUT_AS_3D)
* @param[in] dst_cross_plane_pad (Optional) Bottom paddings for the output matrix in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
@@ -1804,36 +1805,49 @@ __kernel void gemm_mm_reshaped_only_rhs_nt(IMAGE_DECLARATION(lhs),
* -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
* (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns LHS matrix NOT reshaped
*
- * @param[in] lhs_ptr Pointer to the LHS reshaped matrix. Supported data type: F16/F32
- * @param[in] lhs_stride_x Stride of the LHS reshaped matrix in X dimension (in bytes)
- * @param[in] lhs_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
- * @param[in] lhs_stride_y Stride of the LHS reshaped matrix in Y dimension (in bytes)
- * @param[in] lhs_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
- * @param[in] lhs_offset_first_element_in_bytes The offset of the first element in the LHS reshaped matrix
- * @param[in] rhs_ptr Pointer to the RHS reshaped matrix. Supported data type: same as @p lhs_ptr
- * @param[in] rhs_stride_x Stride of the RHS reshaped matrix in X dimension (in bytes)
- * @param[in] rhs_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
- * @param[in] rhs_stride_y Stride of the RHS reshaped matrix in Y dimension (in bytes)
- * @param[in] rhs_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
- * @param[in] rhs_offset_first_element_in_bytes The offset of the first element in the RHS reshaped matrix
- * @param[out] dst_ptr Pointer to the destination matrix Supported data type: same as @p lhs_ptr
- * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
- * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
- * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
- * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
- * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
- * @param[in] k Number of columns in LHS matrix and rows in RHS matrix not reshaped.
- * @param[in] lhs_stride_z Stride of the LHS reshaped matrix in Z dimension (in bytes)
- * @param[in] rhs_stride_z Stride of the RHS reshaped matrix in Z dimension (in bytes)
- * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
- * @param[in] dst_cross_plane_pad (Optional) Bottom paddings in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
+ * @param[in] lhs_ptr Pointer to the LHS reshaped matrix. Supported data type: F16/F32
+ * @param[in] lhs_stride_x Stride of the LHS reshaped matrix in X dimension (in bytes)
+ * @param[in] lhs_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
+ * @param[in] lhs_stride_y Stride of the LHS reshaped matrix in Y dimension (in bytes)
+ * @param[in] lhs_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
+ * @param[in] lhs_offset_first_element_in_bytes The offset of the first element in the LHS reshaped matrix
+ * @param[in] rhs_ptr Pointer to the RHS reshaped matrix. Supported data type: same as @p lhs_ptr
+ * @param[in] rhs_stride_x Stride of the RHS reshaped matrix in X dimension (in bytes)
+ * @param[in] rhs_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
+ * @param[in] rhs_stride_y Stride of the RHS reshaped matrix in Y dimension (in bytes)
+ * @param[in] rhs_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
+ * @param[in] rhs_offset_first_element_in_bytes The offset of the first element in the RHS reshaped matrix
+ * @param[in] bias_ptr (Optional) Pointer to the bias matrix. Supported data type: same as @p lhs_ptr
+ * @param[in] bias_stride_x (Optional) Stride of the bias matrix in X dimension (in bytes)
+ * @param[in] bias_step_x (Optional) bias_stride_x * number of elements along X processed per workitem(in bytes)
+ * @param[in] bias_stride_y (Optional) Stride of the bias matrix in Y dimension (in bytes)
+ * @param[in] bias_step_y (Optional) bias_stride_y * number of elements along Y processed per workitem(in bytes)
+ * @param[in] bias_offset_first_element_in_bytes (Optional) The offset of the first element in the bias matrix
+ * @param[out] dst_ptr Pointer to the destination matrix Supported data type: same as @p lhs_ptr
+ * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
+ * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
+ * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
+ * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
+ * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
+ * @param[in] k Number of columns in LHS matrix and rows in RHS matrix not reshaped.
+ * @param[in] lhs_stride_z Stride of the LHS reshaped matrix in Z dimension (in bytes)
+ * @param[in] rhs_stride_z Stride of the RHS reshaped matrix in Z dimension (in bytes)
+ * @param[in] bias_stride_z (Optional) Stride of the bias matrix in Z dimension (in bytes)
+ * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
+ * @param[in] dst_cross_plane_pad (Optional) Bottom paddings in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
*/
__kernel void gemm_mm_reshaped_lhs_nt_rhs_t(IMAGE_DECLARATION(lhs),
IMAGE_DECLARATION(rhs),
+#if defined(BETA)
+ IMAGE_DECLARATION(bias),
+#endif // defined(BETA)
IMAGE_DECLARATION(dst),
uint k,
uint lhs_stride_z,
uint rhs_stride_z,
+#if defined(BETA)
+ uint bias_stride_z,
+#endif //defined(BETA)
uint dst_stride_z
#if defined(REINTERPRET_OUTPUT_AS_3D)
,
@@ -1892,8 +1906,8 @@ __kernel void gemm_mm_reshaped_lhs_nt_rhs_t(IMAGE_DECLARATION(lhs),
// Initialize the accumulators
REPEAT_VAR_INIT_TO_CONST(M0, VEC_DATA_TYPE(DATA_TYPE, N0), c, 0); //VEC_DATA_TYPE(DATA_TYPE, N0) c0=0,c1=0,c2=0,... c(M0-1)=0;
- REPEAT_VAR_INIT_TO_CONST(8, uint, zlhs, 0); //uint zlhs0=0,zlhs1=0,zlhs2=0,... zlhs7=0;
- REPEAT_VAR_INIT_TO_CONST(16, uint, zrhs, 0);
+ REPEAT_VAR_INIT_TO_CONST(M0, uint, zlhs, 0); //uint zlhs0=0,zlhs1=0,zlhs2=0,... zlhs7=0;
+ REPEAT_VAR_INIT_TO_CONST(16, uint, zero, 0);
for(int i = 0; i < k; i += K0)
{
@@ -1910,7 +1924,7 @@ __kernel void gemm_mm_reshaped_lhs_nt_rhs_t(IMAGE_DECLARATION(lhs),
LOAD_BLOCK(M0, K0, DATA_TYPE, a, lhs_addr, 0, LHS_STEP_X * sizeof(DATA_TYPE), zlhs);
// Load values from RHS matrix
- LOAD_BLOCK(N0, K0, DATA_TYPE, b, rhs_addr, 0, RHS_STEP_X * sizeof(DATA_TYPE), zrhs);
+ LOAD_BLOCK(N0, K0, DATA_TYPE, b, rhs_addr, 0, RHS_STEP_X * sizeof(DATA_TYPE), zero);
// Accumulate
ARM_DOT_K0XN0(a0, b, c0);
@@ -1942,7 +1956,7 @@ __kernel void gemm_mm_reshaped_lhs_nt_rhs_t(IMAGE_DECLARATION(lhs),
__global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + (get_global_id(0) * (uint)N0 * sizeof(DATA_TYPE)) + (get_global_id(1) * (uint)M0 * dst_stride_y);
- REPEAT_VAR_INIT_TO_CONST(8, uint, zout, 0); //uint zout0=0,zout1=0,zout2=0,... zout7=0;
+ REPEAT_VAR_INIT_TO_CONST(M0, uint, zout, 0);
#if defined(REINTERPRET_OUTPUT_AS_3D)
@@ -1964,8 +1978,39 @@ __kernel void gemm_mm_reshaped_lhs_nt_rhs_t(IMAGE_DECLARATION(lhs),
SCALE_BLOCK(M0, DATA_TYPE, c, ALPHA);
#endif // defined(ALPHA)
+ // Add beta*bias
+#if defined(BETA)
+#if defined(BROADCAST_BIAS)
+ __global uchar *bias_addr = bias_ptr + bias_offset_first_element_in_bytes + (get_global_id(0) * (uint)N0 * sizeof(DATA_TYPE));
+
+ LOAD_BLOCK(1, N0, DATA_TYPE, bias, bias_addr, 0, bias_stride_y, zero);
+
+#ifndef UNIT_BETA
+ SCALE_BLOCK(1, DATA_TYPE, bias, BETA);
+#endif // UNIT_BIAS
+
+ // c = c + bias[broadcasted]
+ ADD_BLOCK_BROADCAST(M0, c, bias0);
+
+#else // defined(BROADCAST_BIAS)
+ __global uchar *bias_addr = bias_ptr + bias_offset_first_element_in_bytes + (get_global_id(0) * (uint)N0 * sizeof(DATA_TYPE)) + (get_global_id(1) * (uint)M0 * bias_stride_y) + get_global_id(
+ 2) * bias_stride_z;
+
+ LOAD_BLOCK(M0, N0, DATA_TYPE, bias, bias_addr, 0, bias_stride_y, zero);
+
+#ifndef UNIT_BETA
+ SCALE_BLOCK(M0, DATA_TYPE, bias, BETA);
+#endif // UNIT_BIAS
+
+ // c = c + bias
+ ADD_BLOCK(M0, c, bias);
+
+#endif // defined(BROADCAST_BIAS)
+#endif // defined(BETA)
+
// Store output block
STORE_BLOCK(M0, N0, DATA_TYPE, c, dst_addr, dst_stride_y, zout);
+
#undef LHS_BLOCK_SIZE
#undef LHS_OFFSET_X
#undef LHS_STEP_X