aboutsummaryrefslogtreecommitdiff
path: root/src/core/CL/cl_kernels/gemm.cl
diff options
context:
space:
mode:
author: Gian Marco Iodice <gianmarco.iodice@arm.com> 2018-12-14 17:13:34 +0000
committer: Michele Di Giorgio <michele.digiorgio@arm.com> 2018-12-17 11:22:04 +0000
commit: 49b1015f3505b920f9e4495017c99f91eed68965 (patch)
tree: 1678add8293a4077f3f49f86074be5365054c28e /src/core/CL/cl_kernels/gemm.cl
parent: 555f1c2241d6fa8c84926a72a0c54e4158817df4 (diff)
download: ComputeLibrary-49b1015f3505b920f9e4495017c99f91eed68965.tar.gz
COMPMID-1710: Fixing gemm_mm_reshaped_lhs_nt_rhs_t with REINTERPRET_OUTPUT_AS_3D
Change-Id: I9af1f7263c6e71e38af97f3112d35044cf60ddf0
Reviewed-on: https://review.mlplatform.org/403
Reviewed-by: Anthony Barbier <Anthony.barbier@arm.com>
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Diffstat (limited to 'src/core/CL/cl_kernels/gemm.cl')
-rw-r--r-- src/core/CL/cl_kernels/gemm.cl | 20
1 file changed, 10 insertions, 10 deletions
diff --git a/src/core/CL/cl_kernels/gemm.cl b/src/core/CL/cl_kernels/gemm.cl
index d37dd2d2d6..44b50b3caa 100644
--- a/src/core/CL/cl_kernels/gemm.cl
+++ b/src/core/CL/cl_kernels/gemm.cl
@@ -942,13 +942,13 @@ __kernel void gemm_reshape_rhs_matrix_t(TENSOR3D_DECLARATION(src),
* @param[in] lhs_stride_y Stride of the LHS reshaped matrix in Y dimension (in bytes)
* @param[in] lhs_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
* @param[in] lhs_offset_first_element_in_bytes The offset of the first element in the LHS reshaped matrix
- * @param[in] rhs_ptr Pointer to the RHS reshaped matrix. Supported data type: same as @p src0_ptr
+ * @param[in] rhs_ptr Pointer to the RHS reshaped matrix. Supported data type: same as @p lhs_ptr
* @param[in] rhs_stride_x Stride of the RHS reshaped matrix in X dimension (in bytes)
* @param[in] rhs_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
* @param[in] rhs_stride_y Stride of the RHS reshaped matrix in Y dimension (in bytes)
* @param[in] rhs_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
* @param[in] rhs_offset_first_element_in_bytes The offset of the first element in the RHS reshaped matrix
- * @param[out] dst_ptr Pointer to the destination matrix Supported data type: S32
+ * @param[out] dst_ptr Pointer to the destination matrix Supported data type: same as @p lhs_ptr
* @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
* @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
* @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
@@ -1182,41 +1182,41 @@ __kernel void gemm_mm_reshaped_lhs_nt_rhs_t(IMAGE_DECLARATION(lhs),
// The plane (zin) is calculated dividing M (y * M0) by HEIGHT_GEMM3D
zout0 = (0 + (uint)(get_global_id(1) * (uint)M0)) / (uint)HEIGHT_GEMM3D;
zout0 = min((uint)(DEPTH_GEMM3D - 1), zout0);
- zout0 *= (dst_cross_plane_pad * dst_stride_z);
+ zout0 *= (dst_cross_plane_pad * dst_stride_y);
#if M0 > 1
zout1 = (1 + (uint)(get_global_id(1) * (uint)M0)) / (uint)HEIGHT_GEMM3D;
zout1 = min((uint)(DEPTH_GEMM3D - 1), zout1);
- zout1 *= (dst_cross_plane_pad * dst_stride_z);
+ zout1 *= (dst_cross_plane_pad * dst_stride_y);
#endif // M0 > 1
#if M0 > 2
zout2 = (2 + (uint)(get_global_id(1) * (uint)M0)) / (uint)HEIGHT_GEMM3D;
zout2 = min((uint)(DEPTH_GEMM3D - 1), zout2);
- zout2 *= (dst_cross_plane_pad * dst_stride_z);
+ zout2 *= (dst_cross_plane_pad * dst_stride_y);
#endif // M0 > 2
#if M0 > 3
zout3 = (3 + (uint)(get_global_id(1) * (uint)M0)) / (uint)HEIGHT_GEMM3D;
zout3 = min((uint)(DEPTH_GEMM3D - 1), zout3);
- zout3 *= (dst_cross_plane_pad * dst_stride_z);
+ zout3 *= (dst_cross_plane_pad * dst_stride_y);
#endif // M0 > 3
#if M0 > 4
zout4 = (4 + (uint)(get_global_id(1) * (uint)M0)) / (uint)HEIGHT_GEMM3D;
zout4 = min((uint)(DEPTH_GEMM3D - 1), zout4);
- zout4 *= (dst_cross_plane_pad * dst_stride_z);
+ zout4 *= (dst_cross_plane_pad * dst_stride_y);
#endif // M0 > 4
#if M0 > 5
zout5 = (5 + (uint)(get_global_id(1) * (uint)M0)) / (uint)HEIGHT_GEMM3D;
zout5 = min((uint)(DEPTH_GEMM3D - 1), zout5);
- zout5 *= (dst_cross_plane_pad * dst_stride_z);
+ zout5 *= (dst_cross_plane_pad * dst_stride_y);
#endif // M0 > 5
#if M0 > 6
zout6 = (6 + (uint)(get_global_id(1) * (uint)M0)) / (uint)HEIGHT_GEMM3D;
zout6 = min((uint)(DEPTH_GEMM3D - 1), zout6);
- zout6 *= (dst_cross_plane_pad * dst_stride_z);
+ zout6 *= (dst_cross_plane_pad * dst_stride_y);
#endif // M0 > 6
#if M0 > 6
zout7 = (7 + (uint)(get_global_id(1) * (uint)M0)) / (uint)HEIGHT_GEMM3D;
zout7 = min((uint)(DEPTH_GEMM3D - 1), zout7);
- zout7 *= (dst_cross_plane_pad * dst_stride_z);
+ zout7 *= (dst_cross_plane_pad * dst_stride_y);
#endif // M0 > 7
// Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we