aboutsummaryrefslogtreecommitdiff
path: root/src/core/CL/cl_kernels
diff options
context:
space:
mode:
authorGian Marco Iodice <gianmarco.iodice@arm.com>2020-10-22 16:37:12 +0100
committerGian Marco Iodice <gianmarco.iodice@arm.com>2020-10-26 14:46:44 +0000
commit9ae06d4986bc3055f7786c1097b465bd321cf8eb (patch)
treeadb50e965f860893fe83e3937026056bf1f054c9 /src/core/CL/cl_kernels
parent5f91041aef3eb7373d5d2cebcbe60f279da85904 (diff)
downloadComputeLibrary-9ae06d4986bc3055f7786c1097b465bd321cf8eb.tar.gz
COMPMID-3925: Dispatch CLGEMM with no padding y requirement
- Add has_pad_y flag in GEMMKernelInfo - Skip reinterpret as 3D in CLGEMMMatrixMultiplyReshapedOnlyRHSKernel if has_pad_y = false - Add test to validate CLGEMMMatrixMultiplyReshapedOnlyRHSkernel with had_pad_y = false/true - Configure two variants of CLGEMMMatrixMultiplyReshapedOnlyRHSKernel to run with has_pad_y = false/true in CLGEMM - Check if the lhs/dst tensors have pad y. If not, run CLGEMMMatrixMultiplyReshapedOnlyRHSKernel without padding requirement Change-Id: I68bb43389789736d676b899ac7c77fd9138babaf Signed-off-by: Gian Marco Iodice <gianmarco.iodice@arm.com> Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/4248 Tested-by: Arm Jenkins <bsgcomp@arm.com> Reviewed-by: Georgios Pinitas <georgios.pinitas@arm.com> Reviewed-by: Michele Di Giorgio <michele.digiorgio@arm.com> Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Diffstat (limited to 'src/core/CL/cl_kernels')
-rw-r--r--src/core/CL/cl_kernels/gemm.cl20
-rw-r--r--src/core/CL/cl_kernels/gemm_helpers.h16
-rw-r--r--src/core/CL/cl_kernels/gemmlowp.cl14
3 files changed, 25 insertions, 25 deletions
diff --git a/src/core/CL/cl_kernels/gemm.cl b/src/core/CL/cl_kernels/gemm.cl
index 653aa5591c..fa93760847 100644
--- a/src/core/CL/cl_kernels/gemm.cl
+++ b/src/core/CL/cl_kernels/gemm.cl
@@ -1121,7 +1121,7 @@ __kernel void gemm_mm_reshaped_only_rhs_t(IMAGE_DECLARATION(lhs),
#if defined(REINTERPRET_INPUT_AS_3D)
// The plane (zlhs) is calculated dividing M (y * M0) by HEIGHT_GEMM3D
- CALCULATE_Z_OFFSET(M0, uint, zlhs, y, HEIGHT_GEMM3D, DEPTH_GEMM3D, lhs_cross_plane_pad, lhs_stride_y);
+ CALCULATE_Z_OFFSET(M0, uint, zlhs, COMPUTE_M0_START_ROW(y, M0, PARTIAL_STORE_M0), HEIGHT_GEMM3D, DEPTH_GEMM3D, lhs_cross_plane_pad, lhs_stride_y);
// Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
// multiply lhs_stride_z by DEPTH_GEMM3D
@@ -1227,7 +1227,7 @@ __kernel void gemm_mm_reshaped_only_rhs_t(IMAGE_DECLARATION(lhs),
#if defined(REINTERPRET_OUTPUT_AS_3D)
// The plane (zout) is calculated dividing M (y * M0) by HEIGHT_GEMM3D
- CALCULATE_Z_OFFSET(M0, uint, zout, y, HEIGHT_GEMM3D, DEPTH_GEMM3D, dst_cross_plane_pad, dst_stride_y);
+ CALCULATE_Z_OFFSET(M0, uint, zout, COMPUTE_M0_START_ROW(y, M0, PARTIAL_STORE_M0), HEIGHT_GEMM3D, DEPTH_GEMM3D, dst_cross_plane_pad, dst_stride_y);
// Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
// multiply dst_stride_z by DEPTH_GEMM3D
@@ -1418,7 +1418,7 @@ __kernel void gemm_mm_reshaped_only_rhs_t_texture(IMAGE_DECLARATION(lhs),
#if defined(REINTERPRET_INPUT_AS_3D)
// The plane (zlhs) is calculated dividing M (y * M0) by HEIGHT_GEMM3D
- CALCULATE_Z_OFFSET(M0, uint, zlhs, y, HEIGHT_GEMM3D, DEPTH_GEMM3D, lhs_cross_plane_pad, lhs_stride_y);
+ CALCULATE_Z_OFFSET(M0, uint, zlhs, COMPUTE_M0_START_ROW(y, M0, PARTIAL_STORE_M0), HEIGHT_GEMM3D, DEPTH_GEMM3D, lhs_cross_plane_pad, lhs_stride_y);
// Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
// multiply lhs_stride_z by DEPTH_GEMM3D
@@ -1573,7 +1573,7 @@ __kernel void gemm_mm_reshaped_only_rhs_t_texture(IMAGE_DECLARATION(lhs),
#if defined(REINTERPRET_OUTPUT_AS_3D)
// The plane (zout) is calculated dividing M (y * M0) by HEIGHT_GEMM3D
- CALCULATE_Z_OFFSET(M0, uint, zout, y, HEIGHT_GEMM3D, DEPTH_GEMM3D, dst_cross_plane_pad, dst_stride_y);
+ CALCULATE_Z_OFFSET(M0, uint, zout, COMPUTE_M0_START_ROW(y, M0, PARTIAL_STORE_M0), HEIGHT_GEMM3D, DEPTH_GEMM3D, dst_cross_plane_pad, dst_stride_y);
// Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
// multiply dst_stride_z by DEPTH_GEMM3D
@@ -1839,7 +1839,7 @@ __kernel void gemm_mm_reshaped_only_rhs_nt(IMAGE_DECLARATION(lhs),
#if defined(REINTERPRET_INPUT_AS_3D)
// The plane (zin) is calculated dividing M (y * M0) by HEIGHT_GEMM3D
- CALCULATE_Z_OFFSET(M0, uint, zin, y, HEIGHT_GEMM3D, DEPTH_GEMM3D, lhs_cross_plane_pad, lhs_stride_y);
+ CALCULATE_Z_OFFSET(M0, uint, zin, COMPUTE_M0_START_ROW(y, M0, PARTIAL_STORE_M0), HEIGHT_GEMM3D, DEPTH_GEMM3D, lhs_cross_plane_pad, lhs_stride_y);
// Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
// multiply lhs_stride_z by DEPTH_GEMM3D
@@ -1969,7 +1969,7 @@ __kernel void gemm_mm_reshaped_only_rhs_nt(IMAGE_DECLARATION(lhs),
#if defined(REINTERPRET_OUTPUT_AS_3D)
// The plane (zout) is calculated dividing M (y * M0) by HEIGHT_GEMM3D
- CALCULATE_Z_OFFSET(M0, uint, zout, y, HEIGHT_GEMM3D, DEPTH_GEMM3D, dst_cross_plane_pad, dst_stride_y);
+ CALCULATE_Z_OFFSET(M0, uint, zout, COMPUTE_M0_START_ROW(y, M0, PARTIAL_STORE_M0), HEIGHT_GEMM3D, DEPTH_GEMM3D, dst_cross_plane_pad, dst_stride_y);
// Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
// multiply dst_stride_z by DEPTH_GEMM3D
@@ -2157,7 +2157,7 @@ __kernel void gemm_mm_reshaped_only_rhs_nt_texture(IMAGE_DECLARATION(lhs),
#if defined(REINTERPRET_INPUT_AS_3D)
// The plane (zin) is calculated dividing M (y * M0) by HEIGHT_GEMM3D
- CALCULATE_Z_OFFSET(M0, uint, zin, y, HEIGHT_GEMM3D, DEPTH_GEMM3D, lhs_cross_plane_pad, lhs_stride_y);
+ CALCULATE_Z_OFFSET(M0, uint, zin, COMPUTE_M0_START_ROW(y, M0, PARTIAL_STORE_M0), HEIGHT_GEMM3D, DEPTH_GEMM3D, lhs_cross_plane_pad, lhs_stride_y);
// Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
// multiply lhs_stride_z by DEPTH_GEMM3D
@@ -2278,7 +2278,7 @@ __kernel void gemm_mm_reshaped_only_rhs_nt_texture(IMAGE_DECLARATION(lhs),
#if defined(REINTERPRET_OUTPUT_AS_3D)
// The plane (zout) is calculated dividing M (y * M0) by HEIGHT_GEMM3D
- CALCULATE_Z_OFFSET(M0, uint, zout, y, HEIGHT_GEMM3D, DEPTH_GEMM3D, dst_cross_plane_pad, dst_stride_y);
+ CALCULATE_Z_OFFSET(M0, uint, zout, COMPUTE_M0_START_ROW(y, M0, PARTIAL_STORE_M0), HEIGHT_GEMM3D, DEPTH_GEMM3D, dst_cross_plane_pad, dst_stride_y);
// Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
// multiply dst_stride_z by DEPTH_GEMM3D
@@ -4120,7 +4120,7 @@ __kernel void gemm_mm_native(IMAGE_DECLARATION(lhs),
#if defined(REINTERPRET_INPUT_AS_3D)
// The plane (zlhs) is calculated dividing M (y * M0) by HEIGHT_GEMM3D
- CALCULATE_Z_OFFSET(M0, uint, zlhs, y, HEIGHT_GEMM3D, DEPTH_GEMM3D, lhs_cross_plane_pad, lhs_stride_y);
+ CALCULATE_Z_OFFSET(M0, uint, zlhs, COMPUTE_M0_START_ROW(y, M0, PARTIAL_STORE_M0), HEIGHT_GEMM3D, DEPTH_GEMM3D, lhs_cross_plane_pad, lhs_stride_y);
// Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
// multiply lhs_stride_z by DEPTH_GEMM3D
@@ -4232,7 +4232,7 @@ __kernel void gemm_mm_native(IMAGE_DECLARATION(lhs),
#if defined(REINTERPRET_OUTPUT_AS_3D)
// The plane (zout) is calculated dividing M (y * M0) by HEIGHT_GEMM3D
- CALCULATE_Z_OFFSET(M0, uint, zout, y, HEIGHT_GEMM3D, DEPTH_GEMM3D, dst_cross_plane_pad, dst_stride_y);
+ CALCULATE_Z_OFFSET(M0, uint, zout, COMPUTE_M0_START_ROW(y, M0, PARTIAL_STORE_M0), HEIGHT_GEMM3D, DEPTH_GEMM3D, dst_cross_plane_pad, dst_stride_y);
// Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
// multiply dst_stride_z by DEPTH_GEMM3D
diff --git a/src/core/CL/cl_kernels/gemm_helpers.h b/src/core/CL/cl_kernels/gemm_helpers.h
index 2534204f2f..54d38655a4 100644
--- a/src/core/CL/cl_kernels/gemm_helpers.h
+++ b/src/core/CL/cl_kernels/gemm_helpers.h
@@ -624,49 +624,49 @@
* @{
*/
#define CALCULATE_Z_OFFSET_1(M0, DATA_TYPE, Z, Y, HEIGHT_GEMM3D, DEPTH_GEMM3D, CROSS_PLANE_PAD, STRIDE_Y) \
- Z##0 = (0 + (DATA_TYPE)(Y * (DATA_TYPE)M0)) / (DATA_TYPE)HEIGHT_GEMM3D; \
+ Z##0 = (0 + (DATA_TYPE)(Y)) / (DATA_TYPE)HEIGHT_GEMM3D; \
Z##0 = min((DATA_TYPE)(DEPTH_GEMM3D - 1), Z##0); \
Z##0 *= (CROSS_PLANE_PAD * STRIDE_Y);
#define CALCULATE_Z_OFFSET_2(M0, DATA_TYPE, Z, Y, HEIGHT_GEMM3D, DEPTH_GEMM3D, CROSS_PLANE_PAD, STRIDE_Y) \
CALCULATE_Z_OFFSET_1(M0, DATA_TYPE, Z, Y, HEIGHT_GEMM3D, DEPTH_GEMM3D, CROSS_PLANE_PAD, STRIDE_Y) \
- Z##1 = (1 + (DATA_TYPE)(Y * (DATA_TYPE)M0)) / (DATA_TYPE)HEIGHT_GEMM3D; \
+ Z##1 = (1 + (DATA_TYPE)(Y)) / (DATA_TYPE)HEIGHT_GEMM3D; \
Z##1 = min((DATA_TYPE)(DEPTH_GEMM3D - 1), Z##1); \
Z##1 *= (CROSS_PLANE_PAD * STRIDE_Y);
#define CALCULATE_Z_OFFSET_3(M0, DATA_TYPE, Z, Y, HEIGHT_GEMM3D, DEPTH_GEMM3D, CROSS_PLANE_PAD, STRIDE_Y) \
CALCULATE_Z_OFFSET_2(M0, DATA_TYPE, Z, Y, HEIGHT_GEMM3D, DEPTH_GEMM3D, CROSS_PLANE_PAD, STRIDE_Y) \
- Z##2 = (2 + (DATA_TYPE)(Y * (DATA_TYPE)M0)) / (DATA_TYPE)HEIGHT_GEMM3D; \
+ Z##2 = (2 + (DATA_TYPE)(Y)) / (DATA_TYPE)HEIGHT_GEMM3D; \
Z##2 = min((DATA_TYPE)(DEPTH_GEMM3D - 1), Z##2); \
Z##2 *= (CROSS_PLANE_PAD * STRIDE_Y);
#define CALCULATE_Z_OFFSET_4(M0, DATA_TYPE, Z, Y, HEIGHT_GEMM3D, DEPTH_GEMM3D, CROSS_PLANE_PAD, STRIDE_Y) \
CALCULATE_Z_OFFSET_3(M0, DATA_TYPE, Z, Y, HEIGHT_GEMM3D, DEPTH_GEMM3D, CROSS_PLANE_PAD, STRIDE_Y) \
- Z##3 = (3 + (DATA_TYPE)(Y * (DATA_TYPE)M0)) / (DATA_TYPE)HEIGHT_GEMM3D; \
+ Z##3 = (3 + (DATA_TYPE)(Y)) / (DATA_TYPE)HEIGHT_GEMM3D; \
Z##3 = min((DATA_TYPE)(DEPTH_GEMM3D - 1), Z##3); \
Z##3 *= (CROSS_PLANE_PAD * STRIDE_Y);
#define CALCULATE_Z_OFFSET_5(M0, DATA_TYPE, Z, Y, HEIGHT_GEMM3D, DEPTH_GEMM3D, CROSS_PLANE_PAD, STRIDE_Y) \
CALCULATE_Z_OFFSET_4(M0, DATA_TYPE, Z, Y, HEIGHT_GEMM3D, DEPTH_GEMM3D, CROSS_PLANE_PAD, STRIDE_Y) \
- Z##4 = (4 + (DATA_TYPE)(Y * (DATA_TYPE)M0)) / (DATA_TYPE)HEIGHT_GEMM3D; \
+ Z##4 = (4 + (DATA_TYPE)(Y)) / (DATA_TYPE)HEIGHT_GEMM3D; \
Z##4 = min((DATA_TYPE)(DEPTH_GEMM3D - 1), Z##4); \
Z##4 *= (CROSS_PLANE_PAD * STRIDE_Y);
#define CALCULATE_Z_OFFSET_6(M0, DATA_TYPE, Z, Y, HEIGHT_GEMM3D, DEPTH_GEMM3D, CROSS_PLANE_PAD, STRIDE_Y) \
CALCULATE_Z_OFFSET_5(M0, DATA_TYPE, Z, Y, HEIGHT_GEMM3D, DEPTH_GEMM3D, CROSS_PLANE_PAD, STRIDE_Y) \
- Z##5 = (5 + (DATA_TYPE)(Y * (DATA_TYPE)M0)) / (DATA_TYPE)HEIGHT_GEMM3D; \
+ Z##5 = (5 + (DATA_TYPE)(Y)) / (DATA_TYPE)HEIGHT_GEMM3D; \
Z##5 = min((DATA_TYPE)(DEPTH_GEMM3D - 1), Z##5); \
Z##5 *= (CROSS_PLANE_PAD * STRIDE_Y);
#define CALCULATE_Z_OFFSET_7(M0, DATA_TYPE, Z, Y, HEIGHT_GEMM3D, DEPTH_GEMM3D, CROSS_PLANE_PAD, STRIDE_Y) \
CALCULATE_Z_OFFSET_6(M0, DATA_TYPE, Z, Y, HEIGHT_GEMM3D, DEPTH_GEMM3D, CROSS_PLANE_PAD, STRIDE_Y) \
- Z##6 = (6 + (DATA_TYPE)(Y * (DATA_TYPE)M0)) / (DATA_TYPE)HEIGHT_GEMM3D; \
+ Z##6 = (6 + (DATA_TYPE)(Y)) / (DATA_TYPE)HEIGHT_GEMM3D; \
Z##6 = min((DATA_TYPE)(DEPTH_GEMM3D - 1), Z##6); \
Z##6 *= (CROSS_PLANE_PAD * STRIDE_Y);
#define CALCULATE_Z_OFFSET_8(M0, DATA_TYPE, Z, Y, HEIGHT_GEMM3D, DEPTH_GEMM3D, CROSS_PLANE_PAD, STRIDE_Y) \
CALCULATE_Z_OFFSET_7(M0, DATA_TYPE, Z, Y, HEIGHT_GEMM3D, DEPTH_GEMM3D, CROSS_PLANE_PAD, STRIDE_Y) \
- Z##7 = (7 + (DATA_TYPE)(Y * (DATA_TYPE)M0)) / (DATA_TYPE)HEIGHT_GEMM3D; \
+ Z##7 = (7 + (DATA_TYPE)(Y)) / (DATA_TYPE)HEIGHT_GEMM3D; \
Z##7 = min((DATA_TYPE)(DEPTH_GEMM3D - 1), Z##7); \
Z##7 *= (CROSS_PLANE_PAD * STRIDE_Y);
diff --git a/src/core/CL/cl_kernels/gemmlowp.cl b/src/core/CL/cl_kernels/gemmlowp.cl
index 950faeca0b..4a05635669 100644
--- a/src/core/CL/cl_kernels/gemmlowp.cl
+++ b/src/core/CL/cl_kernels/gemmlowp.cl
@@ -433,7 +433,7 @@ __kernel void gemmlowp_mm_reshaped_lhs_nt_rhs_t(IMAGE_DECLARATION(lhs),
#if defined(REINTERPRET_OUTPUT_AS_3D)
// The plane (zout) is calculated dividing M (y * M0) by HEIGHT_GEMM3D
- CALCULATE_Z_OFFSET(M0, uint, zout, y, HEIGHT_GEMM3D, DEPTH_GEMM3D, dst_cross_plane_pad, dst_stride_y);
+ CALCULATE_Z_OFFSET(M0, uint, zout, y * M0, HEIGHT_GEMM3D, DEPTH_GEMM3D, dst_cross_plane_pad, dst_stride_y);
// Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
// multiply dst_stride_z by DEPTH_GEMM3D
@@ -567,7 +567,7 @@ __kernel void gemmlowp_mm_reshaped_only_rhs_t(IMAGE_DECLARATION(lhs),
#if defined(REINTERPRET_INPUT_AS_3D)
// The plane (zlhs) is calculated dividing M (y * M0) by HEIGHT_GEMM3D
- CALCULATE_Z_OFFSET(M0, uint, zlhs, y, HEIGHT_GEMM3D, DEPTH_GEMM3D, lhs_cross_plane_pad, lhs_stride_y);
+ CALCULATE_Z_OFFSET(M0, uint, zlhs, y * M0, HEIGHT_GEMM3D, DEPTH_GEMM3D, lhs_cross_plane_pad, lhs_stride_y);
// Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
// multiply lhs_stride_z by DEPTH_GEMM3D
@@ -604,7 +604,7 @@ __kernel void gemmlowp_mm_reshaped_only_rhs_t(IMAGE_DECLARATION(lhs),
#if defined(REINTERPRET_OUTPUT_AS_3D)
// The plane (zout) is calculated dividing M (y * M0) by HEIGHT_GEMM3D
- CALCULATE_Z_OFFSET(M0, uint, zout, y, HEIGHT_GEMM3D, DEPTH_GEMM3D, dst_cross_plane_pad, dst_stride_y);
+ CALCULATE_Z_OFFSET(M0, uint, zout, y * M0, HEIGHT_GEMM3D, DEPTH_GEMM3D, dst_cross_plane_pad, dst_stride_y);
// Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
// multiply dst_stride_z by DEPTH_GEMM3D
@@ -781,7 +781,7 @@ __kernel void gemmlowp_mm_reshaped_only_rhs_t_fused_output_stage_fixedpoint(IMAG
#if defined(REINTERPRET_INPUT_AS_3D)
// The plane (zlhs) is calculated dividing M (y * M0) by HEIGHT_GEMM3D
- CALCULATE_Z_OFFSET(M0, uint, zlhs, y, HEIGHT_GEMM3D, DEPTH_GEMM3D, lhs_cross_plane_pad, lhs_stride_y);
+ CALCULATE_Z_OFFSET(M0, uint, zlhs, y * M0, HEIGHT_GEMM3D, DEPTH_GEMM3D, lhs_cross_plane_pad, lhs_stride_y);
// Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
// multiply lhs_stride_z by DEPTH_GEMM3D
@@ -819,7 +819,7 @@ __kernel void gemmlowp_mm_reshaped_only_rhs_t_fused_output_stage_fixedpoint(IMAG
#if defined(REINTERPRET_OUTPUT_AS_3D)
// The plane (zout) is calculated dividing M (y * M0) by HEIGHT_GEMM3D
- CALCULATE_Z_OFFSET(M0, uint, zout, y, HEIGHT_GEMM3D, DEPTH_GEMM3D, dst_cross_plane_pad, dst_stride_y);
+ CALCULATE_Z_OFFSET(M0, uint, zout, y * M0, HEIGHT_GEMM3D, DEPTH_GEMM3D, dst_cross_plane_pad, dst_stride_y);
// Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
// multiply dst_stride_z by DEPTH_GEMM3D
@@ -1009,7 +1009,7 @@ __kernel void gemmlowp_mm_native(IMAGE_DECLARATION(lhs),
#if defined(REINTERPRET_INPUT_AS_3D)
// The plane (zlhs) is calculated dividing M (y * M0) by HEIGHT_GEMM3D
- CALCULATE_Z_OFFSET(M0, uint, zlhs, y, HEIGHT_GEMM3D, DEPTH_GEMM3D, lhs_cross_plane_pad, lhs_stride_y);
+ CALCULATE_Z_OFFSET(M0, uint, zlhs, COMPUTE_M0_START_ROW(y, M0, PARTIAL_STORE_M0), HEIGHT_GEMM3D, DEPTH_GEMM3D, lhs_cross_plane_pad, lhs_stride_y);
// Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
// multiply lhs_stride_z by DEPTH_GEMM3D
@@ -1080,7 +1080,7 @@ __kernel void gemmlowp_mm_native(IMAGE_DECLARATION(lhs),
#if defined(REINTERPRET_OUTPUT_AS_3D)
// The plane (zout) is calculated dividing M (y * M0) by HEIGHT_GEMM3D
- CALCULATE_Z_OFFSET(M0, uint, zout, y, HEIGHT_GEMM3D, DEPTH_GEMM3D, dst_cross_plane_pad, dst_stride_y);
+ CALCULATE_Z_OFFSET(M0, uint, zout, COMPUTE_M0_START_ROW(y, M0, PARTIAL_STORE_M0), HEIGHT_GEMM3D, DEPTH_GEMM3D, dst_cross_plane_pad, dst_stride_y);
// Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
// multiply dst_stride_z by DEPTH_GEMM3D