aboutsummaryrefslogtreecommitdiff
path: root/src/core/CL/cl_kernels
diff options
context:
space:
mode:
authorGeorgios Pinitas <georgios.pinitas@arm.com>2019-05-21 13:32:43 +0100
committerGiuseppe Rossini <giuseppe.rossini@arm.com>2019-06-04 15:58:08 +0000
commitb0f342ec315397e4b87d3a9cc3d12f3645c153bc (patch)
tree3bfd95d4196f6c45feb368b0a020f3bb304e79cd /src/core/CL/cl_kernels
parentbbac660f1959ed2ab58b31a8d5db524883da1754 (diff)
downloadComputeLibrary-b0f342ec315397e4b87d3a9cc3d12f3645c153bc.tar.gz
COMPMID-2171: Fuse bias addition with CLGEMMMatrixMultiplyReshapedOnlyRHSKernel
Change-Id: I1d1e1f28fe7022309d72900893e8368820ca0f89 Signed-off-by: giuros01 <giuseppe.rossini@arm.com> Reviewed-on: https://review.mlplatform.org/c/1259 Comments-Addressed: Arm Jenkins <bsgcomp@arm.com> Reviewed-by: Michele Di Giorgio <michele.digiorgio@arm.com> Reviewed-by: Gian Marco Iodice <gianmarco.iodice@arm.com> Tested-by: Arm Jenkins <bsgcomp@arm.com>
Diffstat (limited to 'src/core/CL/cl_kernels')
-rw-r--r--src/core/CL/cl_kernels/gemm.cl219
-rw-r--r--src/core/CL/cl_kernels/gemm_helpers.h174
2 files changed, 310 insertions, 83 deletions
diff --git a/src/core/CL/cl_kernels/gemm.cl b/src/core/CL/cl_kernels/gemm.cl
index 41e5c338b3..2ac2eb7c32 100644
--- a/src/core/CL/cl_kernels/gemm.cl
+++ b/src/core/CL/cl_kernels/gemm.cl
@@ -731,29 +731,29 @@ __kernel void gemm_reshape_rhs_matrix_t(TENSOR3D_DECLARATION(src),
// 3x4 -> 4x3
// 3x8 -> 8x3
// 3x16 -> 16x3
- res0 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s0, a1.s0, a2.s0);
- res1 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s1, a1.s1, a2.s1);
+ res0 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s0, a1.s0, a2.s0);
+ res1 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s1, a1.s1, a2.s1);
#if N0 > 2
- res2 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s2, a1.s2, a2.s2);
+ res2 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s2, a1.s2, a2.s2);
#endif // N0 > 2
#if N0 > 3
- res3 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s3, a1.s3, a2.s3);
+ res3 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s3, a1.s3, a2.s3);
#endif // N0 > 3
#if N0 > 4
- res4 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s4, a1.s4, a2.s4);
- res5 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s5, a1.s5, a2.s5);
- res6 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s6, a1.s6, a2.s6);
- res7 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s7, a1.s7, a2.s7);
+ res4 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s4, a1.s4, a2.s4);
+ res5 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s5, a1.s5, a2.s5);
+ res6 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s6, a1.s6, a2.s6);
+ res7 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s7, a1.s7, a2.s7);
#endif // N0 > 4
#if N0 > 8
- res8 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s8, a1.s8, a2.s8);
- res9 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s9, a1.s9, a2.s9);
- resA = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sA, a1.sA, a2.sA);
- resB = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sB, a1.sB, a2.sB);
- resC = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sC, a1.sC, a2.sC);
- resD = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sD, a1.sD, a2.sD);
- resE = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sE, a1.sE, a2.sE);
- resF = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sF, a1.sF, a2.sF);
+ res8 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s8, a1.s8, a2.s8);
+ res9 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s9, a1.s9, a2.s9);
+ resA = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sA, a1.sA, a2.sA);
+ resB = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sB, a1.sB, a2.sB);
+ resC = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sC, a1.sC, a2.sC);
+ resD = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sD, a1.sD, a2.sD);
+ resE = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sE, a1.sE, a2.sE);
+ resF = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sF, a1.sF, a2.sF);
#endif // N0 > 8
#elif K0 == 4 // K0 == 4
@@ -1029,35 +1029,48 @@ __kernel void gemm_reshape_rhs_matrix_t(TENSOR3D_DECLARATION(src),
* -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
* (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns LHS matrix
*
- * @param[in] lhs_ptr Pointer to the LHS reshaped matrix. Supported data type: F16/F32
- * @param[in] lhs_stride_x Stride of the LHS reshaped matrix in X dimension (in bytes)
- * @param[in] lhs_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
- * @param[in] lhs_stride_y Stride of the LHS reshaped matrix in Y dimension (in bytes)
- * @param[in] lhs_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
- * @param[in] lhs_offset_first_element_in_bytes The offset of the first element in the LHS reshaped matrix
- * @param[in] rhs_ptr Pointer to the RHS reshaped matrix. Supported data type: same as @p lhs_ptr
- * @param[in] rhs_stride_x Stride of the RHS reshaped matrix in X dimension (in bytes)
- * @param[in] rhs_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
- * @param[in] rhs_stride_y Stride of the RHS reshaped matrix in Y dimension (in bytes)
- * @param[in] rhs_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
- * @param[in] rhs_offset_first_element_in_bytes The offset of the first element in the RHS reshaped matrix
- * @param[out] dst_ptr Pointer to the destination matrix Supported data type: same as @p lhs_ptr
- * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
- * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
- * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
- * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
- * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
- * @param[in] lhs_stride_z Stride of the LHS reshaped matrix in Z dimension (in bytes)
- * @param[in] rhs_stride_z Stride of the RHS reshaped matrix in Z dimension (in bytes)
- * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
- * @param[in] lhs_cross_plane_pad (Optional) Bottom paddings for LHS matrix in unit of elements (only if defined REINTERPRET_INPUT_AS_3D)
- * @param[in] dst_cross_plane_pad (Optional) Bottom paddings for the output matrix in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
+ * @param[in] lhs_ptr Pointer to the LHS reshaped matrix. Supported data type: F16/F32
+ * @param[in] lhs_stride_x Stride of the LHS reshaped matrix in X dimension (in bytes)
+ * @param[in] lhs_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
+ * @param[in] lhs_stride_y Stride of the LHS reshaped matrix in Y dimension (in bytes)
+ * @param[in] lhs_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
+ * @param[in] lhs_offset_first_element_in_bytes The offset of the first element in the LHS reshaped matrix
+ * @param[in] rhs_ptr Pointer to the RHS reshaped matrix. Supported data type: same as @p lhs_ptr
+ * @param[in] rhs_stride_x Stride of the RHS reshaped matrix in X dimension (in bytes)
+ * @param[in] rhs_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
+ * @param[in] rhs_stride_y Stride of the RHS reshaped matrix in Y dimension (in bytes)
+ * @param[in] rhs_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
+ * @param[in] rhs_offset_first_element_in_bytes The offset of the first element in the RHS reshaped matrix
+ * @param[in] bias_ptr (Optional)Pointer to the bias reshaped matrix. Supported data type: same as @p lhs_ptr
+ * @param[in] bias_stride_x (Optional)Stride of the bias reshaped matrix in X dimension (in bytes)
+ * @param[in] bias_step_x (Optional)bias_stride_x * number of elements along X processed per workitem(in bytes)
+ * @param[in] bias_stride_y (Optional)Stride of the bias reshaped matrix in Y dimension (in bytes)
+ * @param[in] bias_step_y (Optional)bias_stride_y * number of elements along Y processed per workitem(in bytes)
+ * @param[in] bias_offset_first_element_in_bytes (Optional)The offset of the first element in the bias reshaped matrix
+ * @param[out] dst_ptr Pointer to the destination matrix Supported data type: same as @p lhs_ptr
+ * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
+ * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
+ * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
+ * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
+ * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
+ * @param[in] lhs_stride_z Stride of the LHS reshaped matrix in Z dimension (in bytes)
+ * @param[in] rhs_stride_z Stride of the RHS reshaped matrix in Z dimension (in bytes)
+ * @param[in] bias_stride_z (Optional) Stride of the bias matrix in Z dimension (in bytes)
+ * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
+ * @param[in] lhs_cross_plane_pad (Optional) Bottom paddings for LHS matrix in unit of elements (only if defined REINTERPRET_INPUT_AS_3D)
+ * @param[in] dst_cross_plane_pad (Optional) Bottom paddings for the output matrix in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
*/
__kernel void gemm_mm_reshaped_only_rhs_t(IMAGE_DECLARATION(lhs),
IMAGE_DECLARATION(rhs),
+#if defined(BETA)
+ IMAGE_DECLARATION(bias),
+#endif // defined(BETA)
IMAGE_DECLARATION(dst),
uint lhs_stride_z,
uint rhs_stride_z,
+#if defined(BETA)
+ uint bias_stride_z,
+#endif //defined(BETA)
uint dst_stride_z
#if defined(REINTERPRET_INPUT_AS_3D)
,
@@ -1108,7 +1121,7 @@ __kernel void gemm_mm_reshaped_only_rhs_t(IMAGE_DECLARATION(lhs),
#endif // defined(MATRIX_B_DEPTH)
REPEAT_VAR_INIT_TO_CONST(8, uint, zlhs, 0); //uint zlhs0=0,zlhs1=0,zlhs2=0,... zlhs7=0;
- REPEAT_VAR_INIT_TO_CONST(16, uint, zrhs, 0);
+ REPEAT_VAR_INIT_TO_CONST(16, uint, zero, 0);
#if defined(REINTERPRET_INPUT_AS_3D)
// The plane (zlhs) is calculated dividing M (y * M0) by HEIGHT_GEMM3D
@@ -1144,7 +1157,7 @@ __kernel void gemm_mm_reshaped_only_rhs_t(IMAGE_DECLARATION(lhs),
LOAD_BLOCK(M0, K0, DATA_TYPE, a, lhs_ptr, lhs_offset, lhs_stride_y, zlhs);
// Load values from RHS matrix
- LOAD_BLOCK(N0, K0, DATA_TYPE, b, rhs_ptr, rhs_offset, RHS_STEP_X * sizeof(DATA_TYPE), zrhs);
+ LOAD_BLOCK(N0, K0, DATA_TYPE, b, rhs_ptr, rhs_offset, RHS_STEP_X * sizeof(DATA_TYPE), zero);
// Accumulate
ARM_DOT_K0XN0(K0, a0, b, c0);
@@ -1181,7 +1194,7 @@ __kernel void gemm_mm_reshaped_only_rhs_t(IMAGE_DECLARATION(lhs),
LOAD_BLOCK(M0, 1, DATA_TYPE, a, lhs_ptr, lhs_offset, lhs_stride_y, zlhs);
// Load values from RHS matrix
- LOAD_BLOCK(N0, 1, DATA_TYPE, b, rhs_ptr, rhs_offset, RHS_STEP_X * sizeof(DATA_TYPE), zrhs);
+ LOAD_BLOCK(N0, 1, DATA_TYPE, b, rhs_ptr, rhs_offset, RHS_STEP_X * sizeof(DATA_TYPE), zero);
// Accumulate
ARM_DOT_K0XN0(1, a0, b, c0);
@@ -1236,6 +1249,36 @@ __kernel void gemm_mm_reshaped_only_rhs_t(IMAGE_DECLARATION(lhs),
SCALE_BLOCK(M0, DATA_TYPE, c, ALPHA);
#endif // defined(ALPHA)
+ // Add beta*bias
+#if defined(BETA)
+#if defined(BROADCAST_BIAS)
+ __global uchar *bias_addr = bias_ptr + bias_offset_first_element_in_bytes + (get_global_id(0) * (uint)N0 * sizeof(DATA_TYPE));
+
+ LOAD_BLOCK(1, N0, DATA_TYPE, bias, bias_addr, 0, bias_stride_y, zero);
+
+#ifndef UNIT_BETA
+ SCALE_BLOCK(1, DATA_TYPE, bias, BETA);
+#endif // UNIT_BIAS
+
+ // c = c + bias[broadcasted]
+ ADD_BLOCK_BROADCAST(M0, c, bias0);
+
+#else // defined(BROADCAST_BIAS)
+ __global uchar *bias_addr = bias_ptr + bias_offset_first_element_in_bytes + (get_global_id(0) * (uint)N0 * sizeof(DATA_TYPE)) + (get_global_id(1) * (uint)M0 * bias_stride_y) + get_global_id(
+ 2) * bias_stride_z;
+
+ LOAD_BLOCK(M0, N0, DATA_TYPE, bias, bias_addr, 0, bias_stride_y, zero);
+
+#ifndef UNIT_BETA
+ SCALE_BLOCK(M0, DATA_TYPE, bias, BETA);
+#endif // UNIT_BIAS
+
+ // c = c + bias
+ ADD_BLOCK(M0, c, bias);
+
+#endif // defined(BROADCAST_BIAS)
+#endif // defined(BETA)
+
// Store output block
STORE_BLOCK(M0, N0, DATA_TYPE, c, dst_addr, dst_stride_y, zout);
@@ -1360,35 +1403,48 @@ __kernel void gemm_mm_reshaped_only_rhs_t(IMAGE_DECLARATION(lhs),
* -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
* (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns LHS matrix
*
- * @param[in] lhs_ptr Pointer to the LHS reshaped matrix. Supported data type: F16/F32
- * @param[in] lhs_stride_x Stride of the LHS reshaped matrix in X dimension (in bytes)
- * @param[in] lhs_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
- * @param[in] lhs_stride_y Stride of the LHS reshaped matrix in Y dimension (in bytes)
- * @param[in] lhs_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
- * @param[in] lhs_offset_first_element_in_bytes The offset of the first element in the LHS reshaped matrix
- * @param[in] rhs_ptr Pointer to the RHS reshaped matrix. Supported data type: same as @p lhs_ptr
- * @param[in] rhs_stride_x Stride of the RHS reshaped matrix in X dimension (in bytes)
- * @param[in] rhs_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
- * @param[in] rhs_stride_y Stride of the RHS reshaped matrix in Y dimension (in bytes)
- * @param[in] rhs_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
- * @param[in] rhs_offset_first_element_in_bytes The offset of the first element in the RHS reshaped matrix
- * @param[out] dst_ptr Pointer to the destination matrix Supported data type: same as @p lhs_ptr
- * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
- * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
- * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
- * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
- * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
- * @param[in] lhs_stride_z Stride of the LHS reshaped matrix in Z dimension (in bytes)
- * @param[in] rhs_stride_z Stride of the RHS reshaped matrix in Z dimension (in bytes)
- * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
- * @param[in] lhs_cross_plane_pad (Optional) Bottom paddings for LHS matrix in unit of elements (only if defined REINTERPRET_INPUT_AS_3D)
- * @param[in] dst_cross_plane_pad (Optional) Bottom paddings for the output matrix in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
+ * @param[in] lhs_ptr Pointer to the LHS reshaped matrix. Supported data type: F16/F32
+ * @param[in] lhs_stride_x Stride of the LHS reshaped matrix in X dimension (in bytes)
+ * @param[in] lhs_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
+ * @param[in] lhs_stride_y Stride of the LHS reshaped matrix in Y dimension (in bytes)
+ * @param[in] lhs_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
+ * @param[in] lhs_offset_first_element_in_bytes The offset of the first element in the LHS reshaped matrix
+ * @param[in] rhs_ptr Pointer to the RHS reshaped matrix. Supported data type: same as @p lhs_ptr
+ * @param[in] rhs_stride_x Stride of the RHS reshaped matrix in X dimension (in bytes)
+ * @param[in] rhs_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
+ * @param[in] rhs_stride_y Stride of the RHS reshaped matrix in Y dimension (in bytes)
+ * @param[in] rhs_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
+ * @param[in] rhs_offset_first_element_in_bytes The offset of the first element in the RHS reshaped matrix
+ * @param[in] bias_ptr (Optional) Pointer to the bias reshaped matrix. Supported data type: same as @p lhs_ptr
+ * @param[in] bias_stride_x (Optional) Stride of the bias reshaped matrix in X dimension (in bytes)
+ * @param[in] bias_step_x (Optional) bias_stride_x * number of elements along X processed per workitem(in bytes)
+ * @param[in] bias_stride_y (Optional) Stride of the bias reshaped matrix in Y dimension (in bytes)
+ * @param[in] bias_step_y (Optional) bias_stride_y * number of elements along Y processed per workitem(in bytes)
+ * @param[in] bias_offset_first_element_in_bytes (Optional) The offset of the first element in the bias matrix
+ * @param[out] dst_ptr Pointer to the destination matrix Supported data type: same as @p lhs_ptr
+ * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
+ * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
+ * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
+ * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
+ * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
+ * @param[in] lhs_stride_z Stride of the LHS reshaped matrix in Z dimension (in bytes)
+ * @param[in] rhs_stride_z Stride of the RHS reshaped matrix in Z dimension (in bytes)
+ * @param[in] bias_stride_z (Optional)Stride of the bias reshaped matrix in Z dimension (in bytes)
+ * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
+ * @param[in] lhs_cross_plane_pad (Optional) Bottom paddings for LHS matrix in unit of elements (only if defined REINTERPRET_INPUT_AS_3D)
+ * @param[in] dst_cross_plane_pad (Optional) Bottom paddings for the output matrix in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
*/
__kernel void gemm_mm_reshaped_only_rhs_nt(IMAGE_DECLARATION(lhs),
IMAGE_DECLARATION(rhs),
+#if defined(BETA)
+ IMAGE_DECLARATION(bias),
+#endif // defined(BETA)
IMAGE_DECLARATION(dst),
uint lhs_stride_z,
uint rhs_stride_z,
+#if defined(BETA)
+ uint bias_stride_z,
+#endif //defined(BETA)
uint dst_stride_z
#if defined(REINTERPRET_INPUT_AS_3D)
,
@@ -1438,7 +1494,8 @@ __kernel void gemm_mm_reshaped_only_rhs_nt(IMAGE_DECLARATION(lhs),
rhs_offset += z * rhs_stride_z;
#endif // defined(MATRIX_B_DEPTH)
- REPEAT_VAR_INIT_TO_CONST(8, uint, zin, 0); //uint zout0=0,zout1=0,zout2=0,... zout7=0;
+ REPEAT_VAR_INIT_TO_CONST(8, uint, zin, 0); //uint zin0=0,zin1=0,zin2=0,... zin7=0;
+ REPEAT_VAR_INIT_TO_CONST(16, uint, zero, 0); //uint zero0=0,zero1=0,zero2=0,... zero7=0;
#if defined(REINTERPRET_INPUT_AS_3D)
@@ -1568,6 +1625,36 @@ __kernel void gemm_mm_reshaped_only_rhs_nt(IMAGE_DECLARATION(lhs),
SCALE_BLOCK(M0, DATA_TYPE, c, ALPHA);
#endif // defined(ALPHA)
+ // Add beta*bias
+#if defined(BETA)
+#if defined(BROADCAST_BIAS)
+ __global uchar *bias_addr = bias_ptr + bias_offset_first_element_in_bytes + (get_global_id(0) * (uint)N0 * sizeof(DATA_TYPE));
+
+ LOAD_BLOCK(1, N0, DATA_TYPE, bias, bias_addr, 0, bias_stride_y, zero);
+
+#ifndef UNIT_BETA
+ SCALE_BLOCK(1, DATA_TYPE, bias, BETA);
+#endif // UNIT_BIAS
+
+ // c = c + bias[broadcasted]
+ ADD_BLOCK_BROADCAST(M0, c, bias0);
+
+#else // defined(BROADCAST_BIAS)
+ __global uchar *bias_addr = bias_ptr + bias_offset_first_element_in_bytes + (get_global_id(0) * (uint)N0 * sizeof(DATA_TYPE)) + (get_global_id(1) * (uint)M0 * bias_stride_y) + get_global_id(
+ 2) * bias_stride_z;
+
+ LOAD_BLOCK(M0, N0, DATA_TYPE, bias, bias_addr, 0, bias_stride_y, zero);
+
+#ifndef UNIT_BETA
+ SCALE_BLOCK(M0, DATA_TYPE, bias, BETA);
+#endif // UNIT_BIAS
+
+ // c = c + bias
+ ADD_BLOCK(M0, c, bias);
+
+#endif // defined(BROADCAST_BIAS)
+#endif // defined(BETA)
+
// Store output block
STORE_BLOCK(M0, N0, DATA_TYPE, c, dst_addr, dst_stride_y, zout);
diff --git a/src/core/CL/cl_kernels/gemm_helpers.h b/src/core/CL/cl_kernels/gemm_helpers.h
index 2c76992b31..cd2d39b433 100644
--- a/src/core/CL/cl_kernels/gemm_helpers.h
+++ b/src/core/CL/cl_kernels/gemm_helpers.h
@@ -360,69 +360,69 @@
#define CONVERT_STORE_BLOCK(M0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) CONVERT_STORE_BLOCK_STR(M0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)
#define SCALE_ROW_1(DATA_TYPE, BASENAME, SCALE) \
- BASENAME##0 = BASENAME##0 * (DATA_TYPE)SCALE;
+ BASENAME##0 *= (DATA_TYPE)SCALE;
#define SCALE_ROW_2(DATA_TYPE, BASENAME, SCALE) \
SCALE_ROW_1(DATA_TYPE, BASENAME, SCALE) \
- BASENAME##1 = BASENAME##1 * (DATA_TYPE)SCALE;
+ BASENAME##1 *= (DATA_TYPE)SCALE;
#define SCALE_ROW_3(DATA_TYPE, BASENAME, SCALE) \
SCALE_ROW_2(DATA_TYPE, BASENAME, SCALE) \
- BASENAME##2 = BASENAME##2 * (DATA_TYPE)SCALE;
+ BASENAME##2 *= (DATA_TYPE)SCALE;
#define SCALE_ROW_4(DATA_TYPE, BASENAME, SCALE) \
SCALE_ROW_3(DATA_TYPE, BASENAME, SCALE) \
- BASENAME##3 = BASENAME##3 * (DATA_TYPE)SCALE;
+ BASENAME##3 *= (DATA_TYPE)SCALE;
#define SCALE_ROW_5(DATA_TYPE, BASENAME, SCALE) \
SCALE_ROW_4(DATA_TYPE, BASENAME, SCALE) \
- BASENAME##4 = BASENAME##4 * (DATA_TYPE)SCALE;
+ BASENAME##4 *= (DATA_TYPE)SCALE;
#define SCALE_ROW_6(DATA_TYPE, BASENAME, SCALE) \
SCALE_ROW_5(DATA_TYPE, BASENAME, SCALE) \
- BASENAME##5 = BASENAME##5 * (DATA_TYPE)SCALE;
+ BASENAME##5 *= (DATA_TYPE)SCALE;
#define SCALE_ROW_7(DATA_TYPE, BASENAME, SCALE) \
SCALE_ROW_6(DATA_TYPE, BASENAME, SCALE) \
- BASENAME##6 = BASENAME##6 * (DATA_TYPE)SCALE;
+ BASENAME##6 *= (DATA_TYPE)SCALE;
#define SCALE_ROW_8(DATA_TYPE, BASENAME, SCALE) \
SCALE_ROW_7(DATA_TYPE, BASENAME, SCALE) \
- BASENAME##7 = BASENAME##7 * (DATA_TYPE)SCALE;
+ BASENAME##7 *= (DATA_TYPE)SCALE;
#define SCALE_ROW_9(DATA_TYPE, BASENAME, SCALE) \
SCALE_ROW_8(DATA_TYPE, BASENAME, SCALE) \
- BASENAME##8 = BASENAME##8 * (DATA_TYPE)SCALE;
+ BASENAME##8 *= (DATA_TYPE)SCALE;
#define SCALE_ROW_10(DATA_TYPE, BASENAME, SCALE) \
SCALE_ROW_9(DATA_TYPE, BASENAME, SCALE) \
- BASENAME##9 = BASENAME##9 * (DATA_TYPE)SCALE;
+ BASENAME##9 *= (DATA_TYPE)SCALE;
#define SCALE_ROW_11(DATA_TYPE, BASENAME, SCALE) \
SCALE_ROW_10(DATA_TYPE, BASENAME, SCALE) \
- BASENAME##A = BASENAME##A * (DATA_TYPE)SCALE;
+ BASENAME##A *= (DATA_TYPE)SCALE;
#define SCALE_ROW_12(DATA_TYPE, BASENAME, SCALE) \
SCALE_ROW_11(DATA_TYPE, BASENAME, SCALE) \
- BASENAME##B = BASENAME##B * (DATA_TYPE)SCALE;
+ BASENAME##B *= (DATA_TYPE)SCALE;
#define SCALE_ROW_13(DATA_TYPE, BASENAME, SCALE) \
SCALE_ROW_12(DATA_TYPE, BASENAME, SCALE) \
- BASENAME##C = BASENAME##C * (DATA_TYPE)SCALE;
+ BASENAME##C *= (DATA_TYPE)SCALE;
#define SCALE_ROW_14(DATA_TYPE, BASENAME, SCALE) \
SCALE_ROW_13(DATA_TYPE, BASENAME, SCALE) \
- BASENAME##D = BASENAME##D * (DATA_TYPE)SCALE;
+ BASENAME##D *= (DATA_TYPE)SCALE;
#define SCALE_ROW_15(DATA_TYPE, BASENAME, SCALE) \
SCALE_ROW_14(DATA_TYPE, BASENAME, SCALE) \
- BASENAME##E = BASENAME##E * (DATA_TYPE)SCALE;
+ BASENAME##E *= (DATA_TYPE)SCALE;
#define SCALE_ROW_16(DATA_TYPE, BASENAME, SCALE) \
SCALE_ROW_15(DATA_TYPE, BASENAME, SCALE) \
- BASENAME##F = BASENAME##F * (DATA_TYPE)SCALE;
+ BASENAME##F *= (DATA_TYPE)SCALE;
-// SCALE_ROW_n scales the variables BASENAME##0 to BASENAME##(n-1) by SCALE
+// SCALE_BLOCK_n scales the variables BASENAME##0 to BASENAME##(n-1) by SCALE
#define SCALE_BLOCK_STR(N, DATA_TYPE, BASENAME, SCALE) SCALE_ROW_##N(DATA_TYPE, BASENAME, SCALE)
/** Scale elements stored in variables BASENAME##0 to BASENAME##(N-1) by SCALE
* Supported cases N=1,2,3..16, for variables BASENAME[0..N]
@@ -479,3 +479,143 @@
#define TRANSPOSE_K0XN0(K0, N0, BASENAME, B) \
CONCAT(TRANSPOSE_K0X, N0) \
(K0, BASENAME, B);
+
+#define ADD_ROW_1(BASENAME, BIAS) \
+ BASENAME##0 += BIAS##0;
+
+#define ADD_ROW_2(BASENAME, BIAS) \
+ ADD_ROW_1(BASENAME, BIAS) \
+ BASENAME##1 += BIAS##1;
+
+#define ADD_ROW_3(BASENAME, BIAS) \
+ ADD_ROW_2(BASENAME, BIAS) \
+ BASENAME##2 += BIAS##2;
+
+#define ADD_ROW_4(BASENAME, BIAS) \
+ ADD_ROW_3(BASENAME, BIAS) \
+ BASENAME##3 += BIAS##3;
+
+#define ADD_ROW_5(BASENAME, BIAS) \
+ ADD_ROW_4(BASENAME, BIAS) \
+ BASENAME##4 += BIAS##4;
+
+#define ADD_ROW_6(BASENAME, BIAS) \
+ ADD_ROW_5(BASENAME, BIAS) \
+ BASENAME##5 += BIAS##5;
+
+#define ADD_ROW_7(BASENAME, BIAS) \
+ ADD_ROW_6(BASENAME, BIAS) \
+ BASENAME##6 += BIAS##6;
+
+#define ADD_ROW_8(BASENAME, BIAS) \
+ ADD_ROW_7(BASENAME, BIAS) \
+ BASENAME##7 += BIAS##7;
+
+#define ADD_ROW_9(BASENAME, BIAS) \
+ ADD_ROW_8(BASENAME, BIAS) \
+ BASENAME##8 += BIAS##8;
+
+#define ADD_ROW_10(BASENAME, BIAS) \
+ ADD_ROW_9(BASENAME, BIAS) \
+ BASENAME##9 += BIAS##9;
+
+#define ADD_ROW_11(BASENAME, BIAS) \
+ ADD_ROW_10(BASENAME, BIAS) \
+ BASENAME##A += BIAS##A;
+
+#define ADD_ROW_12(BASENAME, BIAS) \
+ ADD_ROW_11(BASENAME, BIAS) \
+ BASENAME##B += BIAS##B;
+
+#define ADD_ROW_13(BASENAME, BIAS) \
+ ADD_ROW_12(BASENAME, BIAS) \
+ BASENAME##C += BIAS##C;
+
+#define ADD_ROW_14(BASENAME, BIAS) \
+ ADD_ROW_13(BASENAME, BIAS) \
+ BASENAME##D += BIAS##D;
+
+#define ADD_ROW_15(BASENAME, BIAS) \
+ ADD_ROW_14(BASENAME, BIAS) \
+ BASENAME##E += BIAS##E;
+
+#define ADD_ROW_16(BASENAME, BIAS) \
+ ADD_ROW_15(BASENAME, BIAS) \
+ BASENAME##F += BIAS##F;
+
+// ADD_ROW_n add the variables BIAS##0... BIAS##(n-1) to BASENAME##0 to BASENAME##(n-1)
+#define ADD_BLOCK_STR(N, BASENAME, BIAS) ADD_ROW_##N(BASENAME, BIAS)
+/** Add BIAS to BASENAME##0 ... BASENAME##(N-1)
+ * Supported cases N=1,2,3..16, for variables BASENAME[0..N]
+ */
+#define ADD_BLOCK(N, BASENAME, BIAS) ADD_BLOCK_STR(N, BASENAME, BIAS)
+
+#define ADD_ROW_BROADCAST_1(BASENAME, BIAS) \
+ BASENAME##0 += BIAS;
+
+#define ADD_ROW_BROADCAST_2(BASENAME, BIAS) \
+ ADD_ROW_BROADCAST_1(BASENAME, BIAS) \
+ BASENAME##1 += BIAS;
+
+#define ADD_ROW_BROADCAST_3(BASENAME, BIAS) \
+ ADD_ROW_BROADCAST_2(BASENAME, BIAS) \
+ BASENAME##2 += BIAS;
+
+#define ADD_ROW_BROADCAST_4(BASENAME, BIAS) \
+ ADD_ROW_BROADCAST_3(BASENAME, BIAS) \
+ BASENAME##3 += BIAS;
+
+#define ADD_ROW_BROADCAST_5(BASENAME, BIAS) \
+ ADD_ROW_BROADCAST_4(BASENAME, BIAS) \
+ BASENAME##4 += BIAS;
+
+#define ADD_ROW_BROADCAST_6(BASENAME, BIAS) \
+ ADD_ROW_BROADCAST_5(BASENAME, BIAS) \
+ BASENAME##5 += BIAS;
+
+#define ADD_ROW_BROADCAST_7(BASENAME, BIAS) \
+ ADD_ROW_BROADCAST_6(BASENAME, BIAS) \
+ BASENAME##6 += BIAS;
+
+#define ADD_ROW_BROADCAST_8(BASENAME, BIAS) \
+ ADD_ROW_BROADCAST_7(BASENAME, BIAS) \
+ BASENAME##7 += BIAS;
+
+#define ADD_ROW_BROADCAST_9(BASENAME, BIAS) \
+ ADD_ROW_BROADCAST_8(BASENAME, BIAS) \
+ BASENAME##8 += BIAS;
+
+#define ADD_ROW_BROADCAST_10(BASENAME, BIAS) \
+ ADD_ROW_BROADCAST_9(BASENAME, BIAS) \
+ BASENAME##9 += BIAS;
+
+#define ADD_ROW_BROADCAST_11(BASENAME, BIAS) \
+ ADD_ROW_BROADCAST_10(BASENAME, BIAS) \
+ BASENAME##A += BIAS;
+
+#define ADD_ROW_BROADCAST_12(BASENAME, BIAS) \
+ ADD_ROW_BROADCAST_11(BASENAME, BIAS) \
+ BASENAME##B += BIAS;
+
+#define ADD_ROW_BROADCAST_13(BASENAME, BIAS) \
+ ADD_ROW_BROADCAST_12(BASENAME, BIAS) \
+ BASENAME##C += BIAS;
+
+#define ADD_ROW_BROADCAST_14(BASENAME, BIAS) \
+ ADD_ROW_BROADCAST_13(BASENAME, BIAS) \
+ BASENAME##D += BIAS;
+
+#define ADD_ROW_BROADCAST_15(BASENAME, BIAS) \
+ ADD_ROW_BROADCAST_14(BASENAME, BIAS) \
+ BASENAME##E += BIAS;
+
+#define ADD_ROW_BROADCAST_16(BASENAME, BIAS) \
+ ADD_ROW_BROADCAST_15(BASENAME, BIAS) \
+ BASENAME##F += BIAS;
+
+// ADD_ROW_n add the variables BIAS to BASENAME##0 to BASENAME##(n-1)
+#define ADD_BLOCK_BROADCAST_STR(N, BASENAME, BIAS) ADD_ROW_BROADCAST_##N(BASENAME, BIAS)
+/** Add elements stored in variables BIAS##0 ... BIAS##(N-1) to BASENAME##0 ... BASENAME##(N-1)
+ * Supported cases N=1,2,3..16, for variables BASENAME[0..N]
+ */
+#define ADD_BLOCK_BROADCAST(N, BASENAME, BIAS) ADD_BLOCK_BROADCAST_STR(N, BASENAME, BIAS)