From 49b1015f3505b920f9e4495017c99f91eed68965 Mon Sep 17 00:00:00 2001 From: Gian Marco Iodice Date: Fri, 14 Dec 2018 17:13:34 +0000 Subject: COMPMID-1710: Fixing gemm_mm_reshaped_lhs_nt_rhs_t with REINTERPRET_OUTPUT_AS_3D Change-Id: I9af1f7263c6e71e38af97f3112d35044cf60ddf0 Reviewed-on: https://review.mlplatform.org/403 Reviewed-by: Anthony Barbier Tested-by: Arm Jenkins --- src/core/CL/cl_kernels/gemm.cl | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/src/core/CL/cl_kernels/gemm.cl b/src/core/CL/cl_kernels/gemm.cl index d37dd2d2d6..44b50b3caa 100644 --- a/src/core/CL/cl_kernels/gemm.cl +++ b/src/core/CL/cl_kernels/gemm.cl @@ -942,13 +942,13 @@ __kernel void gemm_reshape_rhs_matrix_t(TENSOR3D_DECLARATION(src), * @param[in] lhs_stride_y Stride of the LHS reshaped matrix in Y dimension (in bytes) * @param[in] lhs_step_y src_stride_y * number of elements along Y processed per workitem(in bytes) * @param[in] lhs_offset_first_element_in_bytes The offset of the first element in the LHS reshaped matrix - * @param[in] rhs_ptr Pointer to the RHS reshaped matrix. Supported data type: same as @p src0_ptr + * @param[in] rhs_ptr Pointer to the RHS reshaped matrix. Supported data type: same as @p lhs_ptr * @param[in] rhs_stride_x Stride of the RHS reshaped matrix in X dimension (in bytes) * @param[in] rhs_step_x src_stride_x * number of elements along X processed per workitem(in bytes) * @param[in] rhs_stride_y Stride of the RHS reshaped matrix in Y dimension (in bytes) * @param[in] rhs_step_y src_stride_y * number of elements along Y processed per workitem(in bytes) * @param[in] rhs_offset_first_element_in_bytes The offset of the first element in the RHS reshaped matrix - * @param[out] dst_ptr Pointer to the destination matrix Supported data type: S32 + * @param[out] dst_ptr Pointer to the destination matrix Supported data type: same as @p lhs_ptr * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes) * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes) * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes) @@ -1182,41 +1182,41 @@ __kernel void gemm_mm_reshaped_lhs_nt_rhs_t(IMAGE_DECLARATION(lhs), // The plane (zin) is calculated dividing M (y * M0) by HEIGHT_GEMM3D zout0 = (0 + (uint)(get_global_id(1) * (uint)M0)) / (uint)HEIGHT_GEMM3D; zout0 = min((uint)(DEPTH_GEMM3D - 1), zout0); - zout0 *= (dst_cross_plane_pad * dst_stride_z); + zout0 *= (dst_cross_plane_pad * dst_stride_y); #if M0 > 1 zout1 = (1 + (uint)(get_global_id(1) * (uint)M0)) / (uint)HEIGHT_GEMM3D; zout1 = min((uint)(DEPTH_GEMM3D - 1), zout1); - zout1 *= (dst_cross_plane_pad * dst_stride_z); + zout1 *= (dst_cross_plane_pad * dst_stride_y); #endif // M0 > 1 #if M0 > 2 zout2 = (2 + (uint)(get_global_id(1) * (uint)M0)) / (uint)HEIGHT_GEMM3D; zout2 = min((uint)(DEPTH_GEMM3D - 1), zout2); - zout2 *= (dst_cross_plane_pad * dst_stride_z); + zout2 *= (dst_cross_plane_pad * dst_stride_y); #endif // M0 > 2 #if M0 > 3 zout3 = (3 + (uint)(get_global_id(1) * (uint)M0)) / (uint)HEIGHT_GEMM3D; zout3 = min((uint)(DEPTH_GEMM3D - 1), zout3); - zout3 *= (dst_cross_plane_pad * dst_stride_z); + zout3 *= (dst_cross_plane_pad * dst_stride_y); #endif // M0 > 3 #if M0 > 4 zout4 = (4 + (uint)(get_global_id(1) * (uint)M0)) / (uint)HEIGHT_GEMM3D; zout4 = min((uint)(DEPTH_GEMM3D - 1), zout4); - zout4 *= (dst_cross_plane_pad * dst_stride_z); + zout4 *= (dst_cross_plane_pad * dst_stride_y); #endif // M0 > 4 #if M0 > 5 zout5 = (5 + (uint)(get_global_id(1) * (uint)M0)) / (uint)HEIGHT_GEMM3D; zout5 = min((uint)(DEPTH_GEMM3D - 1), zout5); - zout5 *= (dst_cross_plane_pad * dst_stride_z); + zout5 *= (dst_cross_plane_pad * dst_stride_y); #endif // M0 > 5 #if M0 > 6 zout6 = (6 + (uint)(get_global_id(1) * (uint)M0)) / (uint)HEIGHT_GEMM3D; zout6 = min((uint)(DEPTH_GEMM3D - 1), zout6); - zout6 *= (dst_cross_plane_pad * dst_stride_z); + zout6 *= (dst_cross_plane_pad * dst_stride_y); #endif // M0 > 6 #if M0 > 6 zout7 = (7 + (uint)(get_global_id(1) * (uint)M0)) / (uint)HEIGHT_GEMM3D; zout7 = min((uint)(DEPTH_GEMM3D - 1), zout7); - zout7 *= (dst_cross_plane_pad * dst_stride_z); + zout7 *= (dst_cross_plane_pad * dst_stride_y); #endif // M0 > 7 // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we -- cgit v1.2.1