aboutsummaryrefslogtreecommitdiff
path: root/src/core/CL/cl_kernels/gemm.cl
diff options
context:
space:
mode:
Diffstat (limited to 'src/core/CL/cl_kernels/gemm.cl')
-rw-r--r--src/core/CL/cl_kernels/gemm.cl101
1 files changed, 72 insertions, 29 deletions
diff --git a/src/core/CL/cl_kernels/gemm.cl b/src/core/CL/cl_kernels/gemm.cl
index 7ada14c774..854d0092d9 100644
--- a/src/core/CL/cl_kernels/gemm.cl
+++ b/src/core/CL/cl_kernels/gemm.cl
@@ -2122,35 +2122,49 @@ __kernel void gemm_mm_reshaped_lhs_nt_rhs_t(IMAGE_DECLARATION(lhs),
* -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
* (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns LHS matrix
*
- * @param[in] lhs_ptr Pointer to the LHS reshaped matrix. Supported data type: F16/F32
- * @param[in] lhs_stride_x Stride of the LHS reshaped matrix in X dimension (in bytes)
- * @param[in] lhs_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
- * @param[in] lhs_stride_y Stride of the LHS reshaped matrix in Y dimension (in bytes)
- * @param[in] lhs_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
- * @param[in] lhs_offset_first_element_in_bytes The offset of the first element in the LHS reshaped matrix
- * @param[in] rhs_ptr Pointer to the RHS reshaped matrix. Supported data type: same as @p lhs_ptr
- * @param[in] rhs_stride_x Stride of the RHS reshaped matrix in X dimension (in bytes)
- * @param[in] rhs_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
- * @param[in] rhs_stride_y Stride of the RHS reshaped matrix in Y dimension (in bytes)
- * @param[in] rhs_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
- * @param[in] rhs_offset_first_element_in_bytes The offset of the first element in the RHS reshaped matrix
- * @param[out] dst_ptr Pointer to the destination matrix Supported data type: same as @p lhs_ptr
- * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
- * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
- * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
- * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
- * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
- * @param[in] lhs_stride_z Stride of the LHS reshaped matrix in Z dimension (in bytes)
- * @param[in] rhs_stride_z Stride of the RHS reshaped matrix in Z dimension (in bytes)
- * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
- * @param[in] lhs_cross_plane_pad (Optional) Bottom paddings for LHS matrix in unit of elements (only if defined REINTERPRET_INPUT_AS_3D)
- * @param[in] dst_cross_plane_pad (Optional) Bottom paddings for the output matrix in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
+ * @param[in] lhs_ptr Pointer to the LHS matrix. Supported data type: F16/F32
+ * @param[in] lhs_stride_x Stride of the LHS matrix in X dimension (in bytes)
+ * @param[in] lhs_step_x lhs_stride_x * number of elements along X processed per workitem(in bytes)
+ * @param[in] lhs_stride_y Stride of the LHS matrix in Y dimension (in bytes)
+ * @param[in] lhs_step_y lhs_stride_y * number of elements along Y processed per workitem(in bytes)
+ * @param[in] lhs_offset_first_element_in_bytes The offset of the first element in the LHS matrix
+ * @param[in] rhs_ptr Pointer to the RHS matrix. Supported data type: same as @p lhs_ptr
+ * @param[in] rhs_stride_x Stride of the RHS matrix in X dimension (in bytes)
+ * @param[in] rhs_step_x rhs_stride_x * number of elements along X processed per workitem(in bytes)
+ * @param[in] rhs_stride_y Stride of the RHS matrix in Y dimension (in bytes)
+ * @param[in] rhs_step_y rhs_stride_y * number of elements along Y processed per workitem(in bytes)
+ * @param[in] rhs_offset_first_element_in_bytes The offset of the first element in the RHS matrix
+ * @param[in] bias_ptr (Optional)Pointer to the bias reshaped matrix. Supported data type: same as @p lhs_ptr
+ * @param[in] bias_ptr (Optional) Pointer to the bias matrix. Supported data type: same as @p lhs_ptr
+ * @param[in] bias_stride_x (Optional) Stride of the bias matrix in X dimension (in bytes)
+ * @param[in] bias_step_x (Optional) bias_stride_x * number of elements along X processed per workitem(in bytes)
+ * @param[in] bias_stride_y (Optional) Stride of the bias matrix in Y dimension (in bytes)
+ * @param[in] bias_step_y (Optional) bias_stride_y * number of elements along Y processed per workitem(in bytes)
+ * @param[in] bias_offset_first_element_in_bytes (Optional) The offset of the first element in the bias matrix
+ * @param[out] dst_ptr Pointer to the destination matrix Supported data type: same as @p lhs_ptr
+ * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
+ * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
+ * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
+ * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
+ * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
+ * @param[in] lhs_stride_z Stride of the LHS matrix in Z dimension (in bytes)
+ * @param[in] rhs_stride_z Stride of the RHS matrix in Z dimension (in bytes)
+ * @param[in] bias_stride_z (Optional) Stride of the bias matrix in Z dimension (in bytes)
+ * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
+ * @param[in] lhs_cross_plane_pad (Optional) Bottom paddings for LHS matrix in unit of elements (only if defined REINTERPRET_INPUT_AS_3D)
+ * @param[in] dst_cross_plane_pad (Optional) Bottom paddings for the output matrix in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
*/
__kernel void gemm_mm_native(IMAGE_DECLARATION(lhs),
IMAGE_DECLARATION(rhs),
+#if defined(BETA)
+ IMAGE_DECLARATION(bias),
+#endif // defined(BETA)
IMAGE_DECLARATION(dst),
uint lhs_stride_z,
uint rhs_stride_z,
+#if defined(BETA)
+ uint bias_stride_z,
+#endif //defined(BETA)
uint dst_stride_z
#if defined(REINTERPRET_INPUT_AS_3D)
,
@@ -2192,8 +2206,8 @@ __kernel void gemm_mm_native(IMAGE_DECLARATION(lhs),
rhs_offset += z * rhs_stride_z;
#endif // defined(MATRIX_B_DEPTH)
- REPEAT_VAR_INIT_TO_CONST(8, uint, zlhs, 0); //uint zlhs0=0,zlhs1=0,zlhs2=0,... zlhs7=0;
- REPEAT_VAR_INIT_TO_CONST(16, uint, zrhs, 0);
+ REPEAT_VAR_INIT_TO_CONST(M0, uint, zlhs, 0);
+ REPEAT_VAR_INIT_TO_CONST(16, uint, zero, 0);
#if defined(REINTERPRET_INPUT_AS_3D)
// The plane (zlhs) is calculated dividing M (y * M0) by HEIGHT_GEMM3D
@@ -2211,7 +2225,7 @@ __kernel void gemm_mm_native(IMAGE_DECLARATION(lhs),
#endif // defined(REINTERPRET_INPUT_AS_3D)
// Initialize the accumulators
- REPEAT_VAR_INIT_TO_CONST(M0, VEC_DATA_TYPE(DATA_TYPE, N0), c, 0); //VEC_DATA_TYPE(DATA_TYPE, N0) c0=0,c1=0,c2=0,... c(N0-1)=0;
+ REPEAT_VAR_INIT_TO_CONST(M0, VEC_DATA_TYPE(DATA_TYPE, N0), c, 0); //VEC_DATA_TYPE(DATA_TYPE, N0) c0=0,c1=0,c2=0,... c(M0-1)=0;
int i = 0;
for(; i <= (K - K0); i += K0)
@@ -2229,7 +2243,7 @@ __kernel void gemm_mm_native(IMAGE_DECLARATION(lhs),
LOAD_BLOCK(M0, K0, DATA_TYPE, a, lhs_ptr, lhs_offset, lhs_stride_y, zlhs);
// Load values from RHS matrix
- LOAD_BLOCK(K0, N0, DATA_TYPE, b, rhs_ptr, rhs_offset, rhs_stride_y, zrhs);
+ LOAD_BLOCK(K0, N0, DATA_TYPE, b, rhs_ptr, rhs_offset, rhs_stride_y, zero);
RHS_VFMA_M0xN0(0, a, b0, c);
RHS_VFMA_M0xN0(1, a, b1, c);
@@ -2305,7 +2319,7 @@ __kernel void gemm_mm_native(IMAGE_DECLARATION(lhs),
__global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + (x * (uint)N0 * sizeof(DATA_TYPE)) + (y * (uint)M0 * dst_stride_y);
- REPEAT_VAR_INIT_TO_CONST(8, uint, zout, 0); //uint zout0=0,zout1=0,zout2=0,... zout7=0;
+ REPEAT_VAR_INIT_TO_CONST(M0, uint, zout, 0);
#if defined(REINTERPRET_OUTPUT_AS_3D)
// The plane (zout) is calculated dividing M (y * M0) by HEIGHT_GEMM3D
@@ -2323,11 +2337,40 @@ __kernel void gemm_mm_native(IMAGE_DECLARATION(lhs),
#endif // defined(REINTERPRET_OUTPUT_AS_3D)
// Multiply by the weight of matrix-matrix product and store the result
- // Multiply by the weight of matrix-matrix product and store the result
#if defined(ALPHA)
SCALE_BLOCK(M0, DATA_TYPE, c, ALPHA);
#endif // defined(ALPHA)
+ // Add beta*bias
+#if defined(BETA)
+#if defined(BROADCAST_BIAS)
+ __global uchar *bias_addr = bias_ptr + bias_offset_first_element_in_bytes + (get_global_id(0) * (uint)N0 * sizeof(DATA_TYPE));
+
+ LOAD_BLOCK(1, N0, DATA_TYPE, bias, bias_addr, 0, bias_stride_y, zero);
+
+#ifndef UNIT_BETA
+ SCALE_BLOCK(1, DATA_TYPE, bias, BETA);
+#endif // UNIT_BIAS
+
+ // c = c + bias[broadcasted]
+ ADD_BLOCK_BROADCAST(M0, c, bias0);
+
+#else // defined(BROADCAST_BIAS)
+ __global uchar *bias_addr = bias_ptr + bias_offset_first_element_in_bytes + (get_global_id(0) * (uint)N0 * sizeof(DATA_TYPE)) + (get_global_id(1) * (uint)M0 * bias_stride_y) + get_global_id(
+ 2) * bias_stride_z;
+
+ LOAD_BLOCK(M0, N0, DATA_TYPE, bias, bias_addr, 0, bias_stride_y, zero);
+
+#ifndef UNIT_BETA
+ SCALE_BLOCK(M0, DATA_TYPE, bias, BETA);
+#endif // UNIT_BIAS
+
+ // c = c + bias
+ ADD_BLOCK(M0, c, bias);
+
+#endif // defined(BROADCAST_BIAS)
+#endif // defined(BETA)
+
// Store output block
STORE_BLOCK(M0, N0, DATA_TYPE, c, dst_addr, dst_stride_y, zout);