aboutsummaryrefslogtreecommitdiff
path: root/src/core/CL/cl_kernels/gemm.cl
diff options
context:
space:
mode:
authorUsama Arif <usama.arif@arm.com>2019-04-25 14:28:07 +0100
committerUsama Arif <usama.arif@arm.com>2019-05-16 10:13:33 +0000
commit0681e3bf3b2abf9a0704c3243859a60204d3565c (patch)
treeb4f7abc3094acb00a8c2021071b7d670244dc37a /src/core/CL/cl_kernels/gemm.cl
parent52c54f61b97bcedab309bfa761e193939e12e739 (diff)
downloadComputeLibrary-0681e3bf3b2abf9a0704c3243859a60204d3565c.tar.gz
COMPMID-2041: Create GEMM helper file for OpenCL.
Change-Id: I7203d7e4d5540536b5e6638c81b26a955aa70f5c Signed-off-by: Usama Arif <usama.arif@arm.com> Reviewed-on: https://review.mlplatform.org/c/1144 Comments-Addressed: Arm Jenkins <bsgcomp@arm.com> Tested-by: Georgios Pinitas <georgios.pinitas@arm.com> Reviewed-by: Gian Marco Iodice <gianmarco.iodice@arm.com>
Diffstat (limited to 'src/core/CL/cl_kernels/gemm.cl')
-rw-r--r--src/core/CL/cl_kernels/gemm.cl1038
1 files changed, 48 insertions, 990 deletions
diff --git a/src/core/CL/cl_kernels/gemm.cl b/src/core/CL/cl_kernels/gemm.cl
index da940082ae..c3107a20f2 100644
--- a/src/core/CL/cl_kernels/gemm.cl
+++ b/src/core/CL/cl_kernels/gemm.cl
@@ -21,7 +21,7 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*/
-#include "helpers.h"
+#include "gemm_helpers.h"
#include "repeat.h"
#if defined(M0) && defined(K0) && defined(V0) && defined(DATA_TYPE) && defined(SRC_WIDTH)
@@ -125,63 +125,10 @@ __kernel void gemm_reshape_lhs_matrix_nt(TENSOR3D_DECLARATION(src),
// Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
// multiply src_stride_z by DEPTH_GEMM3D
- // Note for the REINTERPRET_INPUT_AS_3D case
- // Since we load a 2D input tile from a 3D tensor, we need to check when the plane changes across the z dimension
- // in order to take into account the presence of possible cross plane paddings
- //
- // | |
- // | plane0 |
- // | |
- // |__________________|
- // |******************|
- // | cross_plane_pad |
- // |******************|
- // | |
- // | plane1 |
- // | |
- // |__________________|
-
input_ptr += z * (uint)src_stride_z * DEPTH_GEMM3D;
// The plane (zin) is calculated dividing M (y * M0) by HEIGHT_GEMM3D
- zin0 = (0 + (uint)(y * M0)) / (uint)HEIGHT_GEMM3D;
- zin0 = min((uint)(DEPTH_GEMM3D - 1), zin0);
- zin0 *= (cross_plane_pad * src_stride_y);
-#if M0 > 1
- zin1 = (1 + (uint)(y * M0)) / (uint)HEIGHT_GEMM3D;
- zin1 = min((uint)(DEPTH_GEMM3D - 1), zin1);
- zin1 *= (cross_plane_pad * src_stride_y);
-#endif // M0 > 1
-#if M0 > 2
- zin2 = (2 + (uint)(y * M0)) / (uint)HEIGHT_GEMM3D;
- zin2 = min((uint)(DEPTH_GEMM3D - 1), zin2);
- zin2 *= (cross_plane_pad * src_stride_y);
-#endif // M0 > 2
-#if M0 > 3
- zin3 = (3 + (uint)(y * M0)) / (uint)HEIGHT_GEMM3D;
- zin3 = min((uint)(DEPTH_GEMM3D - 1), zin3);
- zin3 *= (cross_plane_pad * src_stride_y);
-#endif // M0 > 3
-#if M0 > 4
- zin4 = (4 + (uint)(y * M0)) / (uint)HEIGHT_GEMM3D;
- zin4 = min((uint)(DEPTH_GEMM3D - 1), zin4);
- zin4 *= (cross_plane_pad * src_stride_y);
-#endif // M0 > 4
-#if M0 > 5
- zin5 = (5 + (uint)(y * M0)) / (uint)HEIGHT_GEMM3D;
- zin5 = min((uint)(DEPTH_GEMM3D - 1), zin5);
- zin5 *= (cross_plane_pad * src_stride_y);
-#endif // M0 > 5
-#if M0 > 6
- zin6 = (6 + (uint)(y * M0)) / (uint)HEIGHT_GEMM3D;
- zin6 = min((uint)(DEPTH_GEMM3D - 1), zin6);
- zin6 *= (cross_plane_pad * src_stride_y);
-#endif // M0 > 6
-#if M0 > 7
- zin7 = (7 + (uint)(y * M0)) / (uint)HEIGHT_GEMM3D;
- zin7 = min((uint)(DEPTH_GEMM3D - 1), zin7);
- zin7 *= (cross_plane_pad * src_stride_y);
-#endif // M0 > 7
+ CALCULATE_Z_OFFSET(M0, uint, zin, y, HEIGHT_GEMM3D, DEPTH_GEMM3D, cross_plane_pad, src_stride_y);
#else // defined(REINTERPRET_INPUT_AS_3D)
@@ -193,79 +140,33 @@ __kernel void gemm_reshape_lhs_matrix_nt(TENSOR3D_DECLARATION(src),
output_ptr += z * (uint)dst_stride_z;
// ---------------------------Load input values --------------------------------
-
// Load values from the LHS matrix
- VEC_DATA_TYPE(DATA_TYPE, K0)
- a0 = VLOAD(K0)(0, (__global DATA_TYPE *)(input_ptr + 0 * src_stride_y + zin0));
+ LOAD_BLOCK(M0, K0, DATA_TYPE, a, input_ptr, 0, src_stride_y, zin);
BOUNDARY_CONDITION_X(x, a0);
#if M0 > 1
- VEC_DATA_TYPE(DATA_TYPE, K0)
- a1 = VLOAD(K0)(0, (__global DATA_TYPE *)(input_ptr + 1 * src_stride_y + zin1));
BOUNDARY_CONDITION_X(x, a1);
#endif // M0 > 1
#if M0 > 2
- VEC_DATA_TYPE(DATA_TYPE, K0)
- a2 = VLOAD(K0)(0, (__global DATA_TYPE *)(input_ptr + 2 * src_stride_y + zin2));
BOUNDARY_CONDITION_X(x, a2);
#endif // M0 > 2
#if M0 > 3
- VEC_DATA_TYPE(DATA_TYPE, K0)
- a3 = VLOAD(K0)(0, (__global DATA_TYPE *)(input_ptr + 3 * src_stride_y + zin3));
BOUNDARY_CONDITION_X(x, a3);
#endif // M0 > 3
#if M0 > 4
- VEC_DATA_TYPE(DATA_TYPE, K0)
- a4 = VLOAD(K0)(0, (__global DATA_TYPE *)(input_ptr + 4 * src_stride_y + zin4));
BOUNDARY_CONDITION_X(x, a4);
#endif // M0 > 4
#if M0 > 5
- VEC_DATA_TYPE(DATA_TYPE, K0)
- a5 = VLOAD(K0)(0, (__global DATA_TYPE *)(input_ptr + 5 * src_stride_y + zin5));
BOUNDARY_CONDITION_X(x, a5);
#endif // M0 > 5
#if M0 > 6
- VEC_DATA_TYPE(DATA_TYPE, K0)
- a6 = VLOAD(K0)(0, (__global DATA_TYPE *)(input_ptr + 6 * src_stride_y + zin6));
BOUNDARY_CONDITION_X(x, a6);
#endif // M0 > 6
#if M0 > 7
- VEC_DATA_TYPE(DATA_TYPE, K0)
- a7 = VLOAD(K0)(0, (__global DATA_TYPE *)(input_ptr + 7 * src_stride_y + zin7));
BOUNDARY_CONDITION_X(x, a7);
#endif // M0 > 7
-
// ---------------------------Store output values ------------------------------
-
- VSTORE(K0)
- (a0, 0, (__global DATA_TYPE *)(output_ptr + 0 * OUTPUT_STEP_X * sizeof(DATA_TYPE)));
-#if M0 > 1
- VSTORE(K0)
- (a1, 0, (__global DATA_TYPE *)(output_ptr + 1 * OUTPUT_STEP_X * sizeof(DATA_TYPE)));
-#endif // M0 > 1
-#if M0 > 2
- VSTORE(K0)
- (a2, 0, (__global DATA_TYPE *)(output_ptr + 2 * OUTPUT_STEP_X * sizeof(DATA_TYPE)));
-#endif // M0 > 2
-#if M0 > 3
- VSTORE(K0)
- (a3, 0, (__global DATA_TYPE *)(output_ptr + 3 * OUTPUT_STEP_X * sizeof(DATA_TYPE)));
-#endif // M0 > 3
-#if M0 > 4
- VSTORE(K0)
- (a4, 0, (__global DATA_TYPE *)(output_ptr + 4 * OUTPUT_STEP_X * sizeof(DATA_TYPE)));
-#endif // M0 > 4
-#if M0 > 5
- VSTORE(K0)
- (a5, 0, (__global DATA_TYPE *)(output_ptr + 5 * OUTPUT_STEP_X * sizeof(DATA_TYPE)));
-#endif // M0 > 5
-#if M0 > 6
- VSTORE(K0)
- (a6, 0, (__global DATA_TYPE *)(output_ptr + 6 * OUTPUT_STEP_X * sizeof(DATA_TYPE)));
-#endif // M0 > 6
-#if M0 > 7
- VSTORE(K0)
- (a7, 0, (__global DATA_TYPE *)(output_ptr + 7 * OUTPUT_STEP_X * sizeof(DATA_TYPE)));
-#endif // M0 > 7
+ REPEAT_VAR_INIT_TO_CONST(16, uint, zout, 0);
+ STORE_BLOCK(M0, K0, DATA_TYPE, a, output_ptr, OUTPUT_STEP_X * sizeof(DATA_TYPE), zout);
#undef BLOCK_SIZE
#undef OUTPUT_OFFSET_X
@@ -424,63 +325,10 @@ __kernel void gemm_reshape_lhs_matrix_t(TENSOR3D_DECLARATION(src),
// Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
// multiply src_stride_z by DEPTH_GEMM3D
- // Note for the REINTERPRET_INPUT_AS_3D case
- // Since we load a 2D input tile from a 3D tensor, we need to check when the plane changes across the z dimension
- // in order to take into account the presence of possible cross plane paddings
- //
- // | |
- // | plane0 |
- // | |
- // |__________________|
- // |******************|
- // | cross_plane_pad |
- // |******************|
- // | |
- // | plane1 |
- // | |
- // |__________________|
-
input_ptr += z * (uint)src_stride_z * DEPTH_GEMM3D;
// The plane (zin) is calculated dividing M (y * M0) by HEIGHT_GEMM3D
- zin0 = (0 + (uint)(y * M0)) / (uint)HEIGHT_GEMM3D;
- zin0 = min((uint)(DEPTH_GEMM3D - 1), zin0);
- zin0 *= (cross_plane_pad * src_stride_y);
-#if M0 > 1
- zin1 = (1 + (uint)(y * M0)) / (uint)HEIGHT_GEMM3D;
- zin1 = min((uint)(DEPTH_GEMM3D - 1), zin1);
- zin1 *= (cross_plane_pad * src_stride_y);
-#endif // M0 > 1
-#if M0 > 2
- zin2 = (2 + (uint)(y * M0)) / (uint)HEIGHT_GEMM3D;
- zin2 = min((uint)(DEPTH_GEMM3D - 1), zin2);
- zin2 *= (cross_plane_pad * src_stride_y);
-#endif // M0 > 2
-#if M0 > 3
- zin3 = (3 + (uint)(y * M0)) / (uint)HEIGHT_GEMM3D;
- zin3 = min((uint)(DEPTH_GEMM3D - 1), zin3);
- zin3 *= (cross_plane_pad * src_stride_y);
-#endif // M0 > 3
-#if M0 > 4
- zin4 = (4 + (uint)(y * M0)) / (uint)HEIGHT_GEMM3D;
- zin4 = min((uint)(DEPTH_GEMM3D - 1), zin4);
- zin4 *= (cross_plane_pad * src_stride_y);
-#endif // M0 > 4
-#if M0 > 5
- zin5 = (5 + (uint)(y * M0)) / (uint)HEIGHT_GEMM3D;
- zin5 = min((uint)(DEPTH_GEMM3D - 1), zin5);
- zin5 *= (cross_plane_pad * src_stride_y);
-#endif // M0 > 5
-#if M0 > 6
- zin6 = (6 + (uint)(y * M0)) / (uint)HEIGHT_GEMM3D;
- zin6 = min((uint)(DEPTH_GEMM3D - 1), zin6);
- zin6 *= (cross_plane_pad * src_stride_y);
-#endif // M0 > 6
-#if M0 > 7
- zin7 = (7 + (uint)(y * M0)) / (uint)HEIGHT_GEMM3D;
- zin7 = min((uint)(DEPTH_GEMM3D - 1), zin7);
- zin7 *= (cross_plane_pad * src_stride_y);
-#endif // M0 > 7
+ CALCULATE_Z_OFFSET(M0, uint, zin, y, HEIGHT_GEMM3D, DEPTH_GEMM3D, cross_plane_pad, src_stride_y);
#else // defined(REINTERPRET_INPUT_AS_3D)
@@ -494,45 +342,29 @@ __kernel void gemm_reshape_lhs_matrix_t(TENSOR3D_DECLARATION(src),
// ---------------------------Load input values --------------------------------
// Load values from the LHS matrix
- VEC_DATA_TYPE(DATA_TYPE, K0)
- a0 = VLOAD(K0)(0, (__global DATA_TYPE *)(input_ptr + 0 * src_stride_y + zin0));
+ LOAD_BLOCK(M0, K0, DATA_TYPE, a, input_ptr, 0, src_stride_y, zin);
BOUNDARY_CONDITION_X(x, a0);
#if M0 > 1
- VEC_DATA_TYPE(DATA_TYPE, K0)
- a1 = VLOAD(K0)(0, (__global DATA_TYPE *)(input_ptr + 1 * src_stride_y + zin1));
BOUNDARY_CONDITION_X(x, a1);
#endif // M0 > 1
#if M0 > 2
- VEC_DATA_TYPE(DATA_TYPE, K0)
- a2 = VLOAD(K0)(0, (__global DATA_TYPE *)(input_ptr + 2 * src_stride_y + zin2));
BOUNDARY_CONDITION_X(x, a2);
#endif // M0 > 2
#if M0 > 3
- VEC_DATA_TYPE(DATA_TYPE, K0)
- a3 = VLOAD(K0)(0, (__global DATA_TYPE *)(input_ptr + 3 * src_stride_y + zin3));
BOUNDARY_CONDITION_X(x, a3);
#endif // M0 > 3
#if M0 > 4
- VEC_DATA_TYPE(DATA_TYPE, K0)
- a4 = VLOAD(K0)(0, (__global DATA_TYPE *)(input_ptr + 4 * src_stride_y + zin4));
BOUNDARY_CONDITION_X(x, a4);
#endif // M0 > 4
#if M0 > 5
- VEC_DATA_TYPE(DATA_TYPE, K0)
- a5 = VLOAD(K0)(0, (__global DATA_TYPE *)(input_ptr + 5 * src_stride_y + zin5));
BOUNDARY_CONDITION_X(x, a5);
#endif // M0 > 5
#if M0 > 6
- VEC_DATA_TYPE(DATA_TYPE, K0)
- a6 = VLOAD(K0)(0, (__global DATA_TYPE *)(input_ptr + 6 * src_stride_y + zin6));
BOUNDARY_CONDITION_X(x, a6);
#endif // M0 > 6
#if M0 > 7
- VEC_DATA_TYPE(DATA_TYPE, K0)
- a7 = VLOAD(K0)(0, (__global DATA_TYPE *)(input_ptr + 7 * src_stride_y + zin7));
BOUNDARY_CONDITION_X(x, a7);
#endif // M0 > 7
-
// ---------------------------Transpose and store block -----------------------
TRANSPOSE_COLUMN_AND_STORE(output_ptr, OUTPUT_STEP_X, 0);
@@ -711,48 +543,8 @@ __kernel void gemm_reshape_rhs_matrix_nt(TENSOR3D_DECLARATION(src),
#endif // K0 > 8
// ---------------------------Store output values ------------------------------
- VSTORE(N0)
- (a0, 0, (__global DATA_TYPE *)(output_ptr + 0 * OUTPUT_STEP_X * sizeof(DATA_TYPE)));
-#if K0 > 1
- VSTORE(N0)
- (a1, 0, (__global DATA_TYPE *)(output_ptr + 1 * OUTPUT_STEP_X * sizeof(DATA_TYPE)));
-#endif // K0 > 1
-#if K0 > 2
- VSTORE(N0)
- (a2, 0, (__global DATA_TYPE *)(output_ptr + 2 * OUTPUT_STEP_X * sizeof(DATA_TYPE)));
-#endif // K0 > 2
-#if K0 > 3
- VSTORE(N0)
- (a3, 0, (__global DATA_TYPE *)(output_ptr + 3 * OUTPUT_STEP_X * sizeof(DATA_TYPE)));
-#endif // K0 > 3
-#if K0 > 4
- VSTORE(N0)
- (a4, 0, (__global DATA_TYPE *)(output_ptr + 4 * OUTPUT_STEP_X * sizeof(DATA_TYPE)));
- VSTORE(N0)
- (a5, 0, (__global DATA_TYPE *)(output_ptr + 5 * OUTPUT_STEP_X * sizeof(DATA_TYPE)));
- VSTORE(N0)
- (a6, 0, (__global DATA_TYPE *)(output_ptr + 6 * OUTPUT_STEP_X * sizeof(DATA_TYPE)));
- VSTORE(N0)
- (a7, 0, (__global DATA_TYPE *)(output_ptr + 7 * OUTPUT_STEP_X * sizeof(DATA_TYPE)));
-#endif // N0 > 4
-#if K0 > 8
- VSTORE(N0)
- (a8, 0, (__global DATA_TYPE *)(output_ptr + 8 * OUTPUT_STEP_X * sizeof(DATA_TYPE)));
- VSTORE(N0)
- (a9, 0, (__global DATA_TYPE *)(output_ptr + 9 * OUTPUT_STEP_X * sizeof(DATA_TYPE)));
- VSTORE(N0)
- (aA, 0, (__global DATA_TYPE *)(output_ptr + 10 * OUTPUT_STEP_X * sizeof(DATA_TYPE)));
- VSTORE(N0)
- (aB, 0, (__global DATA_TYPE *)(output_ptr + 11 * OUTPUT_STEP_X * sizeof(DATA_TYPE)));
- VSTORE(N0)
- (aC, 0, (__global DATA_TYPE *)(output_ptr + 12 * OUTPUT_STEP_X * sizeof(DATA_TYPE)));
- VSTORE(N0)
- (aD, 0, (__global DATA_TYPE *)(output_ptr + 13 * OUTPUT_STEP_X * sizeof(DATA_TYPE)));
- VSTORE(N0)
- (aE, 0, (__global DATA_TYPE *)(output_ptr + 14 * OUTPUT_STEP_X * sizeof(DATA_TYPE)));
- VSTORE(N0)
- (aF, 0, (__global DATA_TYPE *)(output_ptr + 15 * OUTPUT_STEP_X * sizeof(DATA_TYPE)));
-#endif // N0 > 8
+ REPEAT_VAR_INIT_TO_CONST(16, uint, zout, 0);
+ STORE_BLOCK(K0, N0, DATA_TYPE, a, output_ptr, OUTPUT_STEP_X * sizeof(DATA_TYPE), zout);
#undef BLOCK_SIZE
#undef OUTPUT_OFFSET_X
@@ -1079,47 +871,8 @@ __kernel void gemm_reshape_rhs_matrix_t(TENSOR3D_DECLARATION(src),
#endif // N0 > 2
// ---------------------------Store the output values ------------------------------
-
- VSTORE(K0)
- (res0, 0, (__global DATA_TYPE *)(output_ptr + 0 * OUTPUT_STEP_X * sizeof(DATA_TYPE)));
- VSTORE(K0)
- (res1, 0, (__global DATA_TYPE *)(output_ptr + 1 * OUTPUT_STEP_X * sizeof(DATA_TYPE)));
-#if N0 > 2
- VSTORE(K0)
- (res2, 0, (__global DATA_TYPE *)(output_ptr + 2 * OUTPUT_STEP_X * sizeof(DATA_TYPE)));
-#endif // N0 > 2
-#if N0 > 3
- VSTORE(K0)
- (res3, 0, (__global DATA_TYPE *)(output_ptr + 3 * OUTPUT_STEP_X * sizeof(DATA_TYPE)));
-#endif // N0 > 3
-#if N0 > 4
- VSTORE(K0)
- (res4, 0, (__global DATA_TYPE *)(output_ptr + 4 * OUTPUT_STEP_X * sizeof(DATA_TYPE)));
- VSTORE(K0)
- (res5, 0, (__global DATA_TYPE *)(output_ptr + 5 * OUTPUT_STEP_X * sizeof(DATA_TYPE)));
- VSTORE(K0)
- (res6, 0, (__global DATA_TYPE *)(output_ptr + 6 * OUTPUT_STEP_X * sizeof(DATA_TYPE)));
- VSTORE(K0)
- (res7, 0, (__global DATA_TYPE *)(output_ptr + 7 * OUTPUT_STEP_X * sizeof(DATA_TYPE)));
-#endif // N0 > 4
-#if N0 > 8
- VSTORE(K0)
- (res8, 0, (__global DATA_TYPE *)(output_ptr + 8 * OUTPUT_STEP_X * sizeof(DATA_TYPE)));
- VSTORE(K0)
- (res9, 0, (__global DATA_TYPE *)(output_ptr + 9 * OUTPUT_STEP_X * sizeof(DATA_TYPE)));
- VSTORE(K0)
- (resA, 0, (__global DATA_TYPE *)(output_ptr + 10 * OUTPUT_STEP_X * sizeof(DATA_TYPE)));
- VSTORE(K0)
- (resB, 0, (__global DATA_TYPE *)(output_ptr + 11 * OUTPUT_STEP_X * sizeof(DATA_TYPE)));
- VSTORE(K0)
- (resC, 0, (__global DATA_TYPE *)(output_ptr + 12 * OUTPUT_STEP_X * sizeof(DATA_TYPE)));
- VSTORE(K0)
- (resD, 0, (__global DATA_TYPE *)(output_ptr + 13 * OUTPUT_STEP_X * sizeof(DATA_TYPE)));
- VSTORE(K0)
- (resE, 0, (__global DATA_TYPE *)(output_ptr + 14 * OUTPUT_STEP_X * sizeof(DATA_TYPE)));
- VSTORE(K0)
- (resF, 0, (__global DATA_TYPE *)(output_ptr + 15 * OUTPUT_STEP_X * sizeof(DATA_TYPE)));
-#endif // N0 > 8
+ REPEAT_VAR_INIT_TO_CONST(16, uint, zout, 0);
+ STORE_BLOCK(N0, K0, DATA_TYPE, res, output_ptr, OUTPUT_STEP_X * sizeof(DATA_TYPE), zout);
#undef BLOCK_SIZE
#undef OUTPUT_OFFSET_X
@@ -1354,63 +1107,12 @@ __kernel void gemm_mm_reshaped_only_rhs_t(IMAGE_DECLARATION(lhs),
rhs_offset += z * rhs_stride_z;
#endif // defined(MATRIX_B_DEPTH)
- REPEAT_VAR_INIT_TO_CONST(8, uint, zin, 0); //uint zout0=0,zout1=0,zout2=0,... zout7=0;
+ REPEAT_VAR_INIT_TO_CONST(8, uint, zlhs, 0); //uint zlhs0=0,zlhs1=0,zlhs2=0,... zlhs7=0;
+ REPEAT_VAR_INIT_TO_CONST(16, uint, zrhs, 0);
#if defined(REINTERPRET_INPUT_AS_3D)
- // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension
- // in order to take into account the presence of possible cross plane paddings
- //
- // | |
- // | plane0 |
- // | |
- // |__________________|
- // |******************|
- // | cross_plane_pad |
- // |******************|
- // | |
- // | plane1 |
- // | |
- // |__________________|
-
- // The plane (zin) is calculated dividing M (y * M0) by HEIGHT_GEMM3D
- zin0 = (0 + (uint)(y * (uint)M0)) / (uint)HEIGHT_GEMM3D;
- zin0 = min((uint)(DEPTH_GEMM3D - 1), zin0);
- zin0 *= (lhs_cross_plane_pad * lhs_stride_y);
-#if M0 > 1
- zin1 = (1 + (uint)(y * (uint)M0)) / (uint)HEIGHT_GEMM3D;
- zin1 = min((uint)(DEPTH_GEMM3D - 1), zin1);
- zin1 *= (lhs_cross_plane_pad * lhs_stride_y);
-#endif // M0 > 1
-#if M0 > 2
- zin2 = (2 + (uint)(y * (uint)M0)) / (uint)HEIGHT_GEMM3D;
- zin2 = min((uint)(DEPTH_GEMM3D - 1), zin2);
- zin2 *= (lhs_cross_plane_pad * lhs_stride_y);
-#endif // M0 > 2
-#if M0 > 3
- zin3 = (3 + (uint)(y * (uint)M0)) / (uint)HEIGHT_GEMM3D;
- zin3 = min((uint)(DEPTH_GEMM3D - 1), zin3);
- zin3 *= (lhs_cross_plane_pad * lhs_stride_y);
-#endif // M0 > 3
-#if M0 > 4
- zin4 = (4 + (uint)(y * (uint)M0)) / (uint)HEIGHT_GEMM3D;
- zin4 = min((uint)(DEPTH_GEMM3D - 1), zin4);
- zin4 *= (lhs_cross_plane_pad * lhs_stride_y);
-#endif // M0 > 4
-#if M0 > 5
- zin5 = (5 + (uint)(y * (uint)M0)) / (uint)HEIGHT_GEMM3D;
- zin5 = min((uint)(DEPTH_GEMM3D - 1), zin5);
- zin5 *= (lhs_cross_plane_pad * lhs_stride_y);
-#endif // M0 > 5
-#if M0 > 6
- zin6 = (6 + (uint)(y * (uint)M0)) / (uint)HEIGHT_GEMM3D;
- zin6 = min((uint)(DEPTH_GEMM3D - 1), zin6);
- zin6 *= (lhs_cross_plane_pad * lhs_stride_y);
-#endif // M0 > 6
-#if M0 > 7
- zin7 = (7 + (uint)(y * (uint)M0)) / (uint)HEIGHT_GEMM3D;
- zin7 = min((uint)(DEPTH_GEMM3D - 1), zout7);
- zin7 *= (lhs_cross_plane_pad * lhs_stride_y);
-#endif // M0 > 7
+ // The plane (zlhs) is calculated dividing M (y * M0) by HEIGHT_GEMM3D
+ CALCULATE_Z_OFFSET(M0, uint, zlhs, y, HEIGHT_GEMM3D, DEPTH_GEMM3D, lhs_cross_plane_pad, lhs_stride_y);
// Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
// multiply lhs_stride_z by DEPTH_GEMM3D
@@ -1439,78 +1141,10 @@ __kernel void gemm_mm_reshaped_only_rhs_t(IMAGE_DECLARATION(lhs),
// 7,2 - 7,3 - 7,4 - 7,8 - 7,16
// 8,2 - 8,3 - 8,4 - 8,8 - 8,16
// Load values from LHS matrix
- VEC_DATA_TYPE(DATA_TYPE, K0)
- a0 = VLOAD(K0)(0, (__global DATA_TYPE *)(lhs_ptr + lhs_offset + 0 * lhs_stride_y + zin0));
-#if M0 > 1
- VEC_DATA_TYPE(DATA_TYPE, K0)
- a1 = VLOAD(K0)(0, (__global DATA_TYPE *)(lhs_ptr + lhs_offset + 1 * lhs_stride_y + zin1));
-#endif // M0 > 1
-#if M0 > 2
- VEC_DATA_TYPE(DATA_TYPE, K0)
- a2 = VLOAD(K0)(0, (__global DATA_TYPE *)(lhs_ptr + lhs_offset + 2 * lhs_stride_y + zin2));
-#endif // M0 > 2
-#if M0 > 3
- VEC_DATA_TYPE(DATA_TYPE, K0)
- a3 = VLOAD(K0)(0, (__global DATA_TYPE *)(lhs_ptr + lhs_offset + 3 * lhs_stride_y + zin3));
-#endif // M0 > 3
-#if M0 > 4
- VEC_DATA_TYPE(DATA_TYPE, K0)
- a4 = VLOAD(K0)(0, (__global DATA_TYPE *)(lhs_ptr + lhs_offset + 4 * lhs_stride_y + zin4));
-#endif // M0 > 4
-#if M0 > 5
- VEC_DATA_TYPE(DATA_TYPE, K0)
- a5 = VLOAD(K0)(0, (__global DATA_TYPE *)(lhs_ptr + lhs_offset + 5 * lhs_stride_y + zin5));
-#endif // M0 > 5
-#if M0 > 6
- VEC_DATA_TYPE(DATA_TYPE, K0)
- a6 = VLOAD(K0)(0, (__global DATA_TYPE *)(lhs_ptr + lhs_offset + 6 * lhs_stride_y + zin6));
-#endif // M0 > 6
-#if M0 > 7
- VEC_DATA_TYPE(DATA_TYPE, K0)
- a7 = VLOAD(K0)(0, (__global DATA_TYPE *)(lhs_ptr + lhs_offset + 7 * lhs_stride_y + zin7));
-#endif // M0 > 7
+ LOAD_BLOCK(M0, K0, DATA_TYPE, a, lhs_ptr, lhs_offset, lhs_stride_y, zlhs);
// Load values from RHS matrix
- VEC_DATA_TYPE(DATA_TYPE, K0)
- b0 = VLOAD(K0)(0, (__global DATA_TYPE *)(rhs_ptr + rhs_offset + 0 * RHS_STEP_X * sizeof(DATA_TYPE)));
- VEC_DATA_TYPE(DATA_TYPE, K0)
- b1 = VLOAD(K0)(0, (__global DATA_TYPE *)(rhs_ptr + rhs_offset + 1 * RHS_STEP_X * sizeof(DATA_TYPE)));
-#if N0 > 2
- VEC_DATA_TYPE(DATA_TYPE, K0)
- b2 = VLOAD(K0)(0, (__global DATA_TYPE *)(rhs_ptr + rhs_offset + 2 * RHS_STEP_X * sizeof(DATA_TYPE)));
-#endif // N0 > 2
-#if N0 > 3
- VEC_DATA_TYPE(DATA_TYPE, K0)
- b3 = VLOAD(K0)(0, (__global DATA_TYPE *)(rhs_ptr + rhs_offset + 3 * RHS_STEP_X * sizeof(DATA_TYPE)));
-#endif // N0 > 3
-#if N0 > 4
- VEC_DATA_TYPE(DATA_TYPE, K0)
- b4 = VLOAD(K0)(0, (__global DATA_TYPE *)(rhs_ptr + rhs_offset + 4 * RHS_STEP_X * sizeof(DATA_TYPE)));
- VEC_DATA_TYPE(DATA_TYPE, K0)
- b5 = VLOAD(K0)(0, (__global DATA_TYPE *)(rhs_ptr + rhs_offset + 5 * RHS_STEP_X * sizeof(DATA_TYPE)));
- VEC_DATA_TYPE(DATA_TYPE, K0)
- b6 = VLOAD(K0)(0, (__global DATA_TYPE *)(rhs_ptr + rhs_offset + 6 * RHS_STEP_X * sizeof(DATA_TYPE)));
- VEC_DATA_TYPE(DATA_TYPE, K0)
- b7 = VLOAD(K0)(0, (__global DATA_TYPE *)(rhs_ptr + rhs_offset + 7 * RHS_STEP_X * sizeof(DATA_TYPE)));
-#endif // N0 > 4
-#if N0 > 8
- VEC_DATA_TYPE(DATA_TYPE, K0)
- b8 = VLOAD(K0)(0, (__global DATA_TYPE *)(rhs_ptr + rhs_offset + 8 * RHS_STEP_X * sizeof(DATA_TYPE)));
- VEC_DATA_TYPE(DATA_TYPE, K0)
- b9 = VLOAD(K0)(0, (__global DATA_TYPE *)(rhs_ptr + rhs_offset + 9 * RHS_STEP_X * sizeof(DATA_TYPE)));
- VEC_DATA_TYPE(DATA_TYPE, K0)
- bA = VLOAD(K0)(0, (__global DATA_TYPE *)(rhs_ptr + rhs_offset + 10 * RHS_STEP_X * sizeof(DATA_TYPE)));
- VEC_DATA_TYPE(DATA_TYPE, K0)
- bB = VLOAD(K0)(0, (__global DATA_TYPE *)(rhs_ptr + rhs_offset + 11 * RHS_STEP_X * sizeof(DATA_TYPE)));
- VEC_DATA_TYPE(DATA_TYPE, K0)
- bC = VLOAD(K0)(0, (__global DATA_TYPE *)(rhs_ptr + rhs_offset + 12 * RHS_STEP_X * sizeof(DATA_TYPE)));
- VEC_DATA_TYPE(DATA_TYPE, K0)
- bD = VLOAD(K0)(0, (__global DATA_TYPE *)(rhs_ptr + rhs_offset + 13 * RHS_STEP_X * sizeof(DATA_TYPE)));
- VEC_DATA_TYPE(DATA_TYPE, K0)
- bE = VLOAD(K0)(0, (__global DATA_TYPE *)(rhs_ptr + rhs_offset + 14 * RHS_STEP_X * sizeof(DATA_TYPE)));
- VEC_DATA_TYPE(DATA_TYPE, K0)
- bF = VLOAD(K0)(0, (__global DATA_TYPE *)(rhs_ptr + rhs_offset + 15 * RHS_STEP_X * sizeof(DATA_TYPE)));
-#endif // N0 > 8
+ LOAD_BLOCK(N0, K0, DATA_TYPE, b, rhs_ptr, rhs_offset, RHS_STEP_X * sizeof(DATA_TYPE), zrhs);
// Accumulate
ARM_DOT_K0XN0(K0, a0, b, c0);
@@ -1544,54 +1178,10 @@ __kernel void gemm_mm_reshaped_only_rhs_t(IMAGE_DECLARATION(lhs),
for(; i < K; ++i)
{
// Load values from LHS matrix
- DATA_TYPE a0 = *((__global DATA_TYPE *)(lhs_ptr + lhs_offset + 0 * lhs_stride_y + zin0));
-#if M0 > 1
- DATA_TYPE a1 = *((__global DATA_TYPE *)(lhs_ptr + lhs_offset + 1 * lhs_stride_y + zin1));
-#endif // M0 > 1
-#if M0 > 2
- DATA_TYPE a2 = *((__global DATA_TYPE *)(lhs_ptr + lhs_offset + 2 * lhs_stride_y + zin2));
-#endif // M0 > 2
-#if M0 > 3
- DATA_TYPE a3 = *((__global DATA_TYPE *)(lhs_ptr + lhs_offset + 3 * lhs_stride_y + zin3));
-#endif // M0 > 3
-#if M0 > 4
- DATA_TYPE a4 = *((__global DATA_TYPE *)(lhs_ptr + lhs_offset + 4 * lhs_stride_y + zin4));
-#endif // M0 > 4
-#if M0 > 5
- DATA_TYPE a5 = *((__global DATA_TYPE *)(lhs_ptr + lhs_offset + 5 * lhs_stride_y + zin5));
-#endif // M0 > 5
-#if M0 > 6
- DATA_TYPE a6 = *((__global DATA_TYPE *)(lhs_ptr + lhs_offset + 6 * lhs_stride_y + zin6));
-#endif // M0 > 6
-#if M0 > 7
- DATA_TYPE a7 = *((__global DATA_TYPE *)(lhs_ptr + lhs_offset + 7 * lhs_stride_y + zin7));
-#endif // M0 > 7
+ LOAD_BLOCK(M0, 1, DATA_TYPE, a, lhs_ptr, lhs_offset, lhs_stride_y, zlhs);
// Load values from RHS matrix
- DATA_TYPE b0 = *((__global DATA_TYPE *)(rhs_ptr + rhs_offset + 0 * RHS_STEP_X * sizeof(DATA_TYPE)));
- DATA_TYPE b1 = *((__global DATA_TYPE *)(rhs_ptr + rhs_offset + 1 * RHS_STEP_X * sizeof(DATA_TYPE)));
-#if N0 > 2
- DATA_TYPE b2 = *((__global DATA_TYPE *)(rhs_ptr + rhs_offset + 2 * RHS_STEP_X * sizeof(DATA_TYPE)));
-#endif // N0 > 2
-#if N0 > 3
- DATA_TYPE b3 = *((__global DATA_TYPE *)(rhs_ptr + rhs_offset + 3 * RHS_STEP_X * sizeof(DATA_TYPE)));
-#endif // N0 > 3
-#if N0 > 4
- DATA_TYPE b4 = *((__global DATA_TYPE *)(rhs_ptr + rhs_offset + 4 * RHS_STEP_X * sizeof(DATA_TYPE)));
- DATA_TYPE b5 = *((__global DATA_TYPE *)(rhs_ptr + rhs_offset + 5 * RHS_STEP_X * sizeof(DATA_TYPE)));
- DATA_TYPE b6 = *((__global DATA_TYPE *)(rhs_ptr + rhs_offset + 6 * RHS_STEP_X * sizeof(DATA_TYPE)));
- DATA_TYPE b7 = *((__global DATA_TYPE *)(rhs_ptr + rhs_offset + 7 * RHS_STEP_X * sizeof(DATA_TYPE)));
-#endif // N0 > 4
-#if N0 > 8
- DATA_TYPE b8 = *((__global DATA_TYPE *)(rhs_ptr + rhs_offset + 8 * RHS_STEP_X * sizeof(DATA_TYPE)));
- DATA_TYPE b9 = *((__global DATA_TYPE *)(rhs_ptr + rhs_offset + 9 * RHS_STEP_X * sizeof(DATA_TYPE)));
- DATA_TYPE bA = *((__global DATA_TYPE *)(rhs_ptr + rhs_offset + 10 * RHS_STEP_X * sizeof(DATA_TYPE)));
- DATA_TYPE bB = *((__global DATA_TYPE *)(rhs_ptr + rhs_offset + 11 * RHS_STEP_X * sizeof(DATA_TYPE)));
- DATA_TYPE bC = *((__global DATA_TYPE *)(rhs_ptr + rhs_offset + 12 * RHS_STEP_X * sizeof(DATA_TYPE)));
- DATA_TYPE bD = *((__global DATA_TYPE *)(rhs_ptr + rhs_offset + 13 * RHS_STEP_X * sizeof(DATA_TYPE)));
- DATA_TYPE bE = *((__global DATA_TYPE *)(rhs_ptr + rhs_offset + 14 * RHS_STEP_X * sizeof(DATA_TYPE)));
- DATA_TYPE bF = *((__global DATA_TYPE *)(rhs_ptr + rhs_offset + 15 * RHS_STEP_X * sizeof(DATA_TYPE)));
-#endif // N0 > 8
+ LOAD_BLOCK(N0, 1, DATA_TYPE, b, rhs_ptr, rhs_offset, RHS_STEP_X * sizeof(DATA_TYPE), zrhs);
// Accumulate
ARM_DOT_K0XN0(1, a0, b, c0);
@@ -1626,60 +1216,9 @@ __kernel void gemm_mm_reshaped_only_rhs_t(IMAGE_DECLARATION(lhs),
REPEAT_VAR_INIT_TO_CONST(8, uint, zout, 0); //uint zout0=0,zout1=0,zout2=0,... zout7=0;
#if defined(REINTERPRET_OUTPUT_AS_3D)
- // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension
- // in order to take into account the presence of possible cross plane paddings
- //
- // | |
- // | plane0 |
- // | |
- // |__________________|
- // |******************|
- // | cross_plane_pad |
- // |******************|
- // | |
- // | plane1 |
- // | |
- // |__________________|
// The plane (zout) is calculated dividing M (y * M0) by HEIGHT_GEMM3D
- zout0 = (0 + (uint)(y * (uint)M0)) / (uint)HEIGHT_GEMM3D;
- zout0 = min((uint)(DEPTH_GEMM3D - 1), zout0);
- zout0 *= (dst_cross_plane_pad * dst_stride_y);
-#if M0 > 1
- zout1 = (1 + (uint)(y * (uint)M0)) / (uint)HEIGHT_GEMM3D;
- zout1 = min((uint)(DEPTH_GEMM3D - 1), zout1);
- zout1 *= (dst_cross_plane_pad * dst_stride_y);
-#endif // M0 > 1
-#if M0 > 2
- zout2 = (2 + (uint)(y * (uint)M0)) / (uint)HEIGHT_GEMM3D;
- zout2 = min((uint)(DEPTH_GEMM3D - 1), zout2);
- zout2 *= (dst_cross_plane_pad * dst_stride_y);
-#endif // M0 > 2
-#if M0 > 3
- zout3 = (3 + (uint)(y * (uint)M0)) / (uint)HEIGHT_GEMM3D;
- zout3 = min((uint)(DEPTH_GEMM3D - 1), zout3);
- zout3 *= (dst_cross_plane_pad * dst_stride_y);
-#endif // M0 > 3
-#if M0 > 4
- zout4 = (4 + (uint)(y * (uint)M0)) / (uint)HEIGHT_GEMM3D;
- zout4 = min((uint)(DEPTH_GEMM3D - 1), zout4);
- zout4 *= (dst_cross_plane_pad * dst_stride_y);
-#endif // M0 > 4
-#if M0 > 5
- zout5 = (5 + (uint)(y * (uint)M0)) / (uint)HEIGHT_GEMM3D;
- zout5 = min((uint)(DEPTH_GEMM3D - 1), zout5);
- zout5 *= (dst_cross_plane_pad * dst_stride_y);
-#endif // M0 > 5
-#if M0 > 6
- zout6 = (6 + (uint)(y * (uint)M0)) / (uint)HEIGHT_GEMM3D;
- zout6 = min((uint)(DEPTH_GEMM3D - 1), zout6);
- zout6 *= (dst_cross_plane_pad * dst_stride_y);
-#endif // M0 > 6
-#if M0 > 7
- zout7 = (7 + (uint)(y * (uint)M0)) / (uint)HEIGHT_GEMM3D;
- zout7 = min((uint)(DEPTH_GEMM3D - 1), zout7);
- zout7 *= (dst_cross_plane_pad * dst_stride_y);
-#endif // M0 > 7
+ CALCULATE_Z_OFFSET(M0, uint, zout, y, HEIGHT_GEMM3D, DEPTH_GEMM3D, dst_cross_plane_pad, dst_stride_y);
// Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
// multiply dst_stride_z by DEPTH_GEMM3D
@@ -1694,61 +1233,11 @@ __kernel void gemm_mm_reshaped_only_rhs_t(IMAGE_DECLARATION(lhs),
// Multiply by the weight of matrix-matrix product and store the result
#if defined(ALPHA)
- c0 = c0 * (DATA_TYPE)ALPHA;
-#if M0 > 1
- c1 = c1 * (DATA_TYPE)ALPHA;
-#endif // M0 > 1
-#if M0 > 2
- c2 = c2 * (DATA_TYPE)ALPHA;
-#endif // M0 > 2
-#if M0 > 3
- c3 = c3 * (DATA_TYPE)ALPHA;
-#endif // M0 > 3
-#if M0 > 4
- c4 = c4 * (DATA_TYPE)ALPHA;
-#endif // M0 > 4
-#if M0 > 5
- c5 = c5 * (DATA_TYPE)ALPHA;
-#endif // M0 > 5
-#if M0 > 6
- c6 = c6 * (DATA_TYPE)ALPHA;
-#endif // M0 > 5
-#if M0 > 7
- c7 = c7 * (DATA_TYPE)ALPHA;
-#endif // M0 > 7
+ SCALE_BLOCK(M0, DATA_TYPE, c, ALPHA);
#endif // defined(ALPHA)
// Store output block
- VSTORE(N0)
- (c0, 0, (__global DATA_TYPE *)(dst_addr + 0 * dst_stride_y + zout0));
-#if M0 > 1
- VSTORE(N0)
- (c1, 0, (__global DATA_TYPE *)(dst_addr + 1 * dst_stride_y + zout1));
-#endif // M0 > 1
-#if M0 > 2
- VSTORE(N0)
- (c2, 0, (__global DATA_TYPE *)(dst_addr + 2 * dst_stride_y + zout2));
-#endif // M0 > 2
-#if M0 > 3
- VSTORE(N0)
- (c3, 0, (__global DATA_TYPE *)(dst_addr + 3 * dst_stride_y + zout3));
-#endif // M0 > 3
-#if M0 > 4
- VSTORE(N0)
- (c4, 0, (__global DATA_TYPE *)(dst_addr + 4 * dst_stride_y + zout4));
-#endif // M0 > 4
-#if M0 > 5
- VSTORE(N0)
- (c5, 0, (__global DATA_TYPE *)(dst_addr + 5 * dst_stride_y + zout5));
-#endif // M0 > 5
-#if M0 > 6
- VSTORE(N0)
- (c6, 0, (__global DATA_TYPE *)(dst_addr + 6 * dst_stride_y + zout6));
-#endif // M0 > 6
-#if M0 > 7
- VSTORE(N0)
- (c7, 0, (__global DATA_TYPE *)(dst_addr + 7 * dst_stride_y + zout7));
-#endif // M0 > 7
+ STORE_BLOCK(M0, N0, DATA_TYPE, c, dst_addr, dst_stride_y, zout);
#undef RHS_BLOCK_SIZE
#undef RHS_OFFSET_X
@@ -1952,60 +1441,9 @@ __kernel void gemm_mm_reshaped_only_rhs_nt(IMAGE_DECLARATION(lhs),
REPEAT_VAR_INIT_TO_CONST(8, uint, zin, 0); //uint zout0=0,zout1=0,zout2=0,... zout7=0;
#if defined(REINTERPRET_INPUT_AS_3D)
- // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension
- // in order to take into account the presence of possible cross plane paddings
- //
- // | |
- // | plane0 |
- // | |
- // |__________________|
- // |******************|
- // | cross_plane_pad |
- // |******************|
- // | |
- // | plane1 |
- // | |
- // |__________________|
// The plane (zin) is calculated dividing M (y * M0) by HEIGHT_GEMM3D
- zin0 = (0 + (uint)(y * (uint)M0)) / (uint)HEIGHT_GEMM3D;
- zin0 = min((uint)(DEPTH_GEMM3D - 1), zin0);
- zin0 *= (lhs_cross_plane_pad * lhs_stride_y);
-#if M0 > 1
- zin1 = (1 + (uint)(y * (uint)M0)) / (uint)HEIGHT_GEMM3D;
- zin1 = min((uint)(DEPTH_GEMM3D - 1), zin1);
- zin1 *= (lhs_cross_plane_pad * lhs_stride_y);
-#endif // M0 > 1
-#if M0 > 2
- zin2 = (2 + (uint)(y * (uint)M0)) / (uint)HEIGHT_GEMM3D;
- zin2 = min((uint)(DEPTH_GEMM3D - 1), zin2);
- zin2 *= (lhs_cross_plane_pad * lhs_stride_y);
-#endif // M0 > 2
-#if M0 > 3
- zin3 = (3 + (uint)(y * (uint)M0)) / (uint)HEIGHT_GEMM3D;
- zin3 = min((uint)(DEPTH_GEMM3D - 1), zin3);
- zin3 *= (lhs_cross_plane_pad * lhs_stride_y);
-#endif // M0 > 3
-#if M0 > 4
- zin4 = (4 + (uint)(y * (uint)M0)) / (uint)HEIGHT_GEMM3D;
- zin4 = min((uint)(DEPTH_GEMM3D - 1), zin4);
- zin4 *= (lhs_cross_plane_pad * lhs_stride_y);
-#endif // M0 > 4
-#if M0 > 5
- zin5 = (5 + (uint)(y * (uint)M0)) / (uint)HEIGHT_GEMM3D;
- zin5 = min((uint)(DEPTH_GEMM3D - 1), zin5);
- zin5 *= (lhs_cross_plane_pad * lhs_stride_y);
-#endif // M0 > 5
-#if M0 > 6
- zin6 = (6 + (uint)(y * (uint)M0)) / (uint)HEIGHT_GEMM3D;
- zin6 = min((uint)(DEPTH_GEMM3D - 1), zin6);
- zin6 *= (lhs_cross_plane_pad * lhs_stride_y);
-#endif // M0 > 6
-#if M0 > 7
- zin7 = (7 + (uint)(y * (uint)M0)) / (uint)HEIGHT_GEMM3D;
- zin7 = min((uint)(DEPTH_GEMM3D - 1), zout7);
- zin7 *= (lhs_cross_plane_pad * lhs_stride_y);
-#endif // M0 > 7
+ CALCULATE_Z_OFFSET(M0, uint, zin, y, HEIGHT_GEMM3D, DEPTH_GEMM3D, lhs_cross_plane_pad, lhs_stride_y);
// Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
// multiply lhs_stride_z by DEPTH_GEMM3D
@@ -2034,36 +1472,7 @@ __kernel void gemm_mm_reshaped_only_rhs_nt(IMAGE_DECLARATION(lhs),
// 7,2 - 7,3 - 7,4 - 7,8 - 7,16
// 8,2 - 8,3 - 8,4 - 8,8 - 8,16
// Load values from LHS matrix
- VEC_DATA_TYPE(DATA_TYPE, K0)
- a0 = VLOAD(K0)(0, (__global DATA_TYPE *)(lhs_ptr + lhs_offset + 0 * lhs_stride_y + zin0));
-#if M0 > 1
- VEC_DATA_TYPE(DATA_TYPE, K0)
- a1 = VLOAD(K0)(0, (__global DATA_TYPE *)(lhs_ptr + lhs_offset + 1 * lhs_stride_y + zin1));
-#endif // M0 > 1
-#if M0 > 2
- VEC_DATA_TYPE(DATA_TYPE, K0)
- a2 = VLOAD(K0)(0, (__global DATA_TYPE *)(lhs_ptr + lhs_offset + 2 * lhs_stride_y + zin2));
-#endif // M0 > 2
-#if M0 > 3
- VEC_DATA_TYPE(DATA_TYPE, K0)
- a3 = VLOAD(K0)(0, (__global DATA_TYPE *)(lhs_ptr + lhs_offset + 3 * lhs_stride_y + zin3));
-#endif // M0 > 3
-#if M0 > 4
- VEC_DATA_TYPE(DATA_TYPE, K0)
- a4 = VLOAD(K0)(0, (__global DATA_TYPE *)(lhs_ptr + lhs_offset + 4 * lhs_stride_y + zin4));
-#endif // M0 > 4
-#if M0 > 5
- VEC_DATA_TYPE(DATA_TYPE, K0)
- a5 = VLOAD(K0)(0, (__global DATA_TYPE *)(lhs_ptr + lhs_offset + 5 * lhs_stride_y + zin5));
-#endif // M0 > 5
-#if M0 > 6
- VEC_DATA_TYPE(DATA_TYPE, K0)
- a6 = VLOAD(K0)(0, (__global DATA_TYPE *)(lhs_ptr + lhs_offset + 6 * lhs_stride_y + zin6));
-#endif // M0 > 6
-#if M0 > 7
- VEC_DATA_TYPE(DATA_TYPE, K0)
- a7 = VLOAD(K0)(0, (__global DATA_TYPE *)(lhs_ptr + lhs_offset + 7 * lhs_stride_y + zin7));
-#endif // M0 > 7
+ LOAD_BLOCK(M0, K0, DATA_TYPE, a, lhs_ptr, lhs_offset, lhs_stride_y, zin);
LD_RHS_VFMA_M0xN0(0, a, c);
LD_RHS_VFMA_M0xN0(1, a, c);
@@ -2140,60 +1549,8 @@ __kernel void gemm_mm_reshaped_only_rhs_nt(IMAGE_DECLARATION(lhs),
REPEAT_VAR_INIT_TO_CONST(8, uint, zout, 0); //uint zout0=0,zout1=0,zout2=0,... zout7=0;
#if defined(REINTERPRET_OUTPUT_AS_3D)
- // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension
- // in order to take into account the presence of possible cross plane paddings
- //
- // | |
- // | plane0 |
- // | |
- // |__________________|
- // |******************|
- // | cross_plane_pad |
- // |******************|
- // | |
- // | plane1 |
- // | |
- // |__________________|
-
// The plane (zout) is calculated dividing M (y * M0) by HEIGHT_GEMM3D
- zout0 = (0 + (uint)(y * (uint)M0)) / (uint)HEIGHT_GEMM3D;
- zout0 = min((uint)(DEPTH_GEMM3D - 1), zout0);
- zout0 *= (dst_cross_plane_pad * dst_stride_y);
-#if M0 > 1
- zout1 = (1 + (uint)(y * (uint)M0)) / (uint)HEIGHT_GEMM3D;
- zout1 = min((uint)(DEPTH_GEMM3D - 1), zout1);
- zout1 *= (dst_cross_plane_pad * dst_stride_y);
-#endif // M0 > 1
-#if M0 > 2
- zout2 = (2 + (uint)(y * (uint)M0)) / (uint)HEIGHT_GEMM3D;
- zout2 = min((uint)(DEPTH_GEMM3D - 1), zout2);
- zout2 *= (dst_cross_plane_pad * dst_stride_y);
-#endif // M0 > 2
-#if M0 > 3
- zout3 = (3 + (uint)(y * (uint)M0)) / (uint)HEIGHT_GEMM3D;
- zout3 = min((uint)(DEPTH_GEMM3D - 1), zout3);
- zout3 *= (dst_cross_plane_pad * dst_stride_y);
-#endif // M0 > 3
-#if M0 > 4
- zout4 = (4 + (uint)(y * (uint)M0)) / (uint)HEIGHT_GEMM3D;
- zout4 = min((uint)(DEPTH_GEMM3D - 1), zout4);
- zout4 *= (dst_cross_plane_pad * dst_stride_y);
-#endif // M0 > 4
-#if M0 > 5
- zout5 = (5 + (uint)(y * (uint)M0)) / (uint)HEIGHT_GEMM3D;
- zout5 = min((uint)(DEPTH_GEMM3D - 1), zout5);
- zout5 *= (dst_cross_plane_pad * dst_stride_y);
-#endif // M0 > 5
-#if M0 > 6
- zout6 = (6 + (uint)(y * (uint)M0)) / (uint)HEIGHT_GEMM3D;
- zout6 = min((uint)(DEPTH_GEMM3D - 1), zout6);
- zout6 *= (dst_cross_plane_pad * dst_stride_y);
-#endif // M0 > 6
-#if M0 > 7
- zout7 = (7 + (uint)(y * (uint)M0)) / (uint)HEIGHT_GEMM3D;
- zout7 = min((uint)(DEPTH_GEMM3D - 1), zout7);
- zout7 *= (dst_cross_plane_pad * dst_stride_y);
-#endif // M0 > 7
+ CALCULATE_Z_OFFSET(M0, uint, zout, y, HEIGHT_GEMM3D, DEPTH_GEMM3D, dst_cross_plane_pad, dst_stride_y);
// Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
// multiply dst_stride_z by DEPTH_GEMM3D
@@ -2208,61 +1565,11 @@ __kernel void gemm_mm_reshaped_only_rhs_nt(IMAGE_DECLARATION(lhs),
// Multiply by the weight of matrix-matrix product and store the result
#if defined(ALPHA)
- c0 = c0 * (DATA_TYPE)ALPHA;
-#if M0 > 1
- c1 = c1 * (DATA_TYPE)ALPHA;
-#endif // M0 > 1
-#if M0 > 2
- c2 = c2 * (DATA_TYPE)ALPHA;
-#endif // M0 > 2
-#if M0 > 3
- c3 = c3 * (DATA_TYPE)ALPHA;
-#endif // M0 > 3
-#if M0 > 4
- c4 = c4 * (DATA_TYPE)ALPHA;
-#endif // M0 > 4
-#if M0 > 5
- c5 = c5 * (DATA_TYPE)ALPHA;
-#endif // M0 > 5
-#if M0 > 6
- c6 = c6 * (DATA_TYPE)ALPHA;
-#endif // M0 > 5
-#if M0 > 7
- c7 = c7 * (DATA_TYPE)ALPHA;
-#endif // M0 > 7
+ SCALE_BLOCK(M0, DATA_TYPE, c, ALPHA);
#endif // defined(ALPHA)
// Store output block
- VSTORE(N0)
- (c0, 0, (__global DATA_TYPE *)(dst_addr + 0 * dst_stride_y + zout0));
-#if M0 > 1
- VSTORE(N0)
- (c1, 0, (__global DATA_TYPE *)(dst_addr + 1 * dst_stride_y + zout1));
-#endif // M0 > 1
-#if M0 > 2
- VSTORE(N0)
- (c2, 0, (__global DATA_TYPE *)(dst_addr + 2 * dst_stride_y + zout2));
-#endif // M0 > 2
-#if M0 > 3
- VSTORE(N0)
- (c3, 0, (__global DATA_TYPE *)(dst_addr + 3 * dst_stride_y + zout3));
-#endif // M0 > 3
-#if M0 > 4
- VSTORE(N0)
- (c4, 0, (__global DATA_TYPE *)(dst_addr + 4 * dst_stride_y + zout4));
-#endif // M0 > 4
-#if M0 > 5
- VSTORE(N0)
- (c5, 0, (__global DATA_TYPE *)(dst_addr + 5 * dst_stride_y + zout5));
-#endif // M0 > 5
-#if M0 > 6
- VSTORE(N0)
- (c6, 0, (__global DATA_TYPE *)(dst_addr + 6 * dst_stride_y + zout6));
-#endif // M0 > 6
-#if M0 > 7
- VSTORE(N0)
- (c7, 0, (__global DATA_TYPE *)(dst_addr + 7 * dst_stride_y + zout7));
-#endif // M0 > 7
+ STORE_BLOCK(M0, N0, DATA_TYPE, c, dst_addr, dst_stride_y, zout);
#undef RHS_BLOCK_SIZE
#undef RHS_OFFSET_X
@@ -2498,6 +1805,9 @@ __kernel void gemm_mm_reshaped_lhs_nt_rhs_t(IMAGE_DECLARATION(lhs),
// Initialize the accumulators
REPEAT_VAR_INIT_TO_CONST(M0, VEC_DATA_TYPE(DATA_TYPE, N0), c, 0); //VEC_DATA_TYPE(DATA_TYPE, N0) c0=0,c1=0,c2=0,... c(M0-1)=0;
+ REPEAT_VAR_INIT_TO_CONST(8, uint, zlhs, 0); //uint zlhs0=0,zlhs1=0,zlhs2=0,... zlhs7=0;
+ REPEAT_VAR_INIT_TO_CONST(16, uint, zrhs, 0);
+
for(int i = 0; i < k; i += K0)
{
// Supported cases (M0, K0):
@@ -2510,78 +1820,10 @@ __kernel void gemm_mm_reshaped_lhs_nt_rhs_t(IMAGE_DECLARATION(lhs),
// 7,2 - 7,3 - 7,4 - 7,8 - 7,16
// 8,2 - 8,3 - 8,4 - 8,8 - 8,16
// Load values from LHS matrix
- VEC_DATA_TYPE(DATA_TYPE, K0)
- a0 = VLOAD(K0)(0, (__global DATA_TYPE *)(lhs_addr + 0 * LHS_STEP_X * sizeof(DATA_TYPE)));
-#if M0 > 1
- VEC_DATA_TYPE(DATA_TYPE, K0)
- a1 = VLOAD(K0)(0, (__global DATA_TYPE *)(lhs_addr + 1 * LHS_STEP_X * sizeof(DATA_TYPE)));
-#endif // M0 > 1
-#if M0 > 2
- VEC_DATA_TYPE(DATA_TYPE, K0)
- a2 = VLOAD(K0)(0, (__global DATA_TYPE *)(lhs_addr + 2 * LHS_STEP_X * sizeof(DATA_TYPE)));
-#endif // M0 > 2
-#if M0 > 3
- VEC_DATA_TYPE(DATA_TYPE, K0)
- a3 = VLOAD(K0)(0, (__global DATA_TYPE *)(lhs_addr + 3 * LHS_STEP_X * sizeof(DATA_TYPE)));
-#endif // M0 > 3
-#if M0 > 4
- VEC_DATA_TYPE(DATA_TYPE, K0)
- a4 = VLOAD(K0)(0, (__global DATA_TYPE *)(lhs_addr + 4 * LHS_STEP_X * sizeof(DATA_TYPE)));
-#endif // M0 > 4
-#if M0 > 5
- VEC_DATA_TYPE(DATA_TYPE, K0)
- a5 = VLOAD(K0)(0, (__global DATA_TYPE *)(lhs_addr + 5 * LHS_STEP_X * sizeof(DATA_TYPE)));
-#endif // M0 > 5
-#if M0 > 6
- VEC_DATA_TYPE(DATA_TYPE, K0)
- a6 = VLOAD(K0)(0, (__global DATA_TYPE *)(lhs_addr + 6 * LHS_STEP_X * sizeof(DATA_TYPE)));
-#endif // M0 > 6
-#if M0 > 7
- VEC_DATA_TYPE(DATA_TYPE, K0)
- a7 = VLOAD(K0)(0, (__global DATA_TYPE *)(lhs_addr + 7 * LHS_STEP_X * sizeof(DATA_TYPE)));
-#endif // M0 > 7
+ LOAD_BLOCK(M0, K0, DATA_TYPE, a, lhs_addr, 0, LHS_STEP_X * sizeof(DATA_TYPE), zlhs);
// Load values from RHS matrix
- VEC_DATA_TYPE(DATA_TYPE, K0)
- b0 = VLOAD(K0)(0, (__global DATA_TYPE *)(rhs_addr + 0 * RHS_STEP_X * sizeof(DATA_TYPE)));
- VEC_DATA_TYPE(DATA_TYPE, K0)
- b1 = VLOAD(K0)(0, (__global DATA_TYPE *)(rhs_addr + 1 * RHS_STEP_X * sizeof(DATA_TYPE)));
-#if N0 > 2
- VEC_DATA_TYPE(DATA_TYPE, K0)
- b2 = VLOAD(K0)(0, (__global DATA_TYPE *)(rhs_addr + 2 * RHS_STEP_X * sizeof(DATA_TYPE)));
-#endif // N0 > 2
-#if N0 > 3
- VEC_DATA_TYPE(DATA_TYPE, K0)
- b3 = VLOAD(K0)(0, (__global DATA_TYPE *)(rhs_addr + 3 * RHS_STEP_X * sizeof(DATA_TYPE)));
-#endif // N0 > 3
-#if N0 > 4
- VEC_DATA_TYPE(DATA_TYPE, K0)
- b4 = VLOAD(K0)(0, (__global DATA_TYPE *)(rhs_addr + 4 * RHS_STEP_X * sizeof(DATA_TYPE)));
- VEC_DATA_TYPE(DATA_TYPE, K0)
- b5 = VLOAD(K0)(0, (__global DATA_TYPE *)(rhs_addr + 5 * RHS_STEP_X * sizeof(DATA_TYPE)));
- VEC_DATA_TYPE(DATA_TYPE, K0)
- b6 = VLOAD(K0)(0, (__global DATA_TYPE *)(rhs_addr + 6 * RHS_STEP_X * sizeof(DATA_TYPE)));
- VEC_DATA_TYPE(DATA_TYPE, K0)
- b7 = VLOAD(K0)(0, (__global DATA_TYPE *)(rhs_addr + 7 * RHS_STEP_X * sizeof(DATA_TYPE)));
-#endif // N0 > 4
-#if N0 > 8
- VEC_DATA_TYPE(DATA_TYPE, K0)
- b8 = VLOAD(K0)(0, (__global DATA_TYPE *)(rhs_addr + 8 * RHS_STEP_X * sizeof(DATA_TYPE)));
- VEC_DATA_TYPE(DATA_TYPE, K0)
- b9 = VLOAD(K0)(0, (__global DATA_TYPE *)(rhs_addr + 9 * RHS_STEP_X * sizeof(DATA_TYPE)));
- VEC_DATA_TYPE(DATA_TYPE, K0)
- bA = VLOAD(K0)(0, (__global DATA_TYPE *)(rhs_addr + 10 * RHS_STEP_X * sizeof(DATA_TYPE)));
- VEC_DATA_TYPE(DATA_TYPE, K0)
- bB = VLOAD(K0)(0, (__global DATA_TYPE *)(rhs_addr + 11 * RHS_STEP_X * sizeof(DATA_TYPE)));
- VEC_DATA_TYPE(DATA_TYPE, K0)
- bC = VLOAD(K0)(0, (__global DATA_TYPE *)(rhs_addr + 12 * RHS_STEP_X * sizeof(DATA_TYPE)));
- VEC_DATA_TYPE(DATA_TYPE, K0)
- bD = VLOAD(K0)(0, (__global DATA_TYPE *)(rhs_addr + 13 * RHS_STEP_X * sizeof(DATA_TYPE)));
- VEC_DATA_TYPE(DATA_TYPE, K0)
- bE = VLOAD(K0)(0, (__global DATA_TYPE *)(rhs_addr + 14 * RHS_STEP_X * sizeof(DATA_TYPE)));
- VEC_DATA_TYPE(DATA_TYPE, K0)
- bF = VLOAD(K0)(0, (__global DATA_TYPE *)(rhs_addr + 15 * RHS_STEP_X * sizeof(DATA_TYPE)));
-#endif // N0 > 8
+ LOAD_BLOCK(N0, K0, DATA_TYPE, b, rhs_addr, 0, RHS_STEP_X * sizeof(DATA_TYPE), zrhs);
// Accumulate
ARM_DOT_K0XN0(a0, b, c0);
@@ -2616,61 +1858,9 @@ __kernel void gemm_mm_reshaped_lhs_nt_rhs_t(IMAGE_DECLARATION(lhs),
REPEAT_VAR_INIT_TO_CONST(8, uint, zout, 0); //uint zout0=0,zout1=0,zout2=0,... zout7=0;
#if defined(REINTERPRET_OUTPUT_AS_3D)
- // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension
- // in order to take into account the presence of possible cross plane paddings
- //
- // | |
- // | plane0 |
- // | |
- // |__________________|
- // |******************|
- // | cross_plane_pad |
- // |******************|
- // | |
- // | plane1 |
- // | |
- // |__________________|
// The plane (zin) is calculated dividing M (y * M0) by HEIGHT_GEMM3D
- zout0 = (0 + (uint)(get_global_id(1) * (uint)M0)) / (uint)HEIGHT_GEMM3D;
- zout0 = min((uint)(DEPTH_GEMM3D - 1), zout0);
- zout0 *= (dst_cross_plane_pad * dst_stride_y);
-#if M0 > 1
- zout1 = (1 + (uint)(get_global_id(1) * (uint)M0)) / (uint)HEIGHT_GEMM3D;
- zout1 = min((uint)(DEPTH_GEMM3D - 1), zout1);
- zout1 *= (dst_cross_plane_pad * dst_stride_y);
-#endif // M0 > 1
-#if M0 > 2
- zout2 = (2 + (uint)(get_global_id(1) * (uint)M0)) / (uint)HEIGHT_GEMM3D;
- zout2 = min((uint)(DEPTH_GEMM3D - 1), zout2);
- zout2 *= (dst_cross_plane_pad * dst_stride_y);
-#endif // M0 > 2
-#if M0 > 3
- zout3 = (3 + (uint)(get_global_id(1) * (uint)M0)) / (uint)HEIGHT_GEMM3D;
- zout3 = min((uint)(DEPTH_GEMM3D - 1), zout3);
- zout3 *= (dst_cross_plane_pad * dst_stride_y);
-#endif // M0 > 3
-#if M0 > 4
- zout4 = (4 + (uint)(get_global_id(1) * (uint)M0)) / (uint)HEIGHT_GEMM3D;
- zout4 = min((uint)(DEPTH_GEMM3D - 1), zout4);
- zout4 *= (dst_cross_plane_pad * dst_stride_y);
-#endif // M0 > 4
-#if M0 > 5
- zout5 = (5 + (uint)(get_global_id(1) * (uint)M0)) / (uint)HEIGHT_GEMM3D;
- zout5 = min((uint)(DEPTH_GEMM3D - 1), zout5);
- zout5 *= (dst_cross_plane_pad * dst_stride_y);
-#endif // M0 > 5
-#if M0 > 6
- zout6 = (6 + (uint)(get_global_id(1) * (uint)M0)) / (uint)HEIGHT_GEMM3D;
- zout6 = min((uint)(DEPTH_GEMM3D - 1), zout6);
- zout6 *= (dst_cross_plane_pad * dst_stride_y);
-#endif // M0 > 6
-#if M0 > 7
- zout7 = (7 + (uint)(get_global_id(1) * (uint)M0)) / (uint)HEIGHT_GEMM3D;
- zout7 = min((uint)(DEPTH_GEMM3D - 1), zout7);
- zout7 *= (dst_cross_plane_pad * dst_stride_y);
-#endif // M0 > 7
-
+ CALCULATE_Z_OFFSET(M0, uint, zout, get_global_id(1), HEIGHT_GEMM3D, DEPTH_GEMM3D, dst_cross_plane_pad, dst_stride_y);
// Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
// multiply dst_stride_z by DEPTH_GEMM3D
dst_addr += get_global_id(2) * dst_stride_z * DEPTH_GEMM3D;
@@ -2684,62 +1874,11 @@ __kernel void gemm_mm_reshaped_lhs_nt_rhs_t(IMAGE_DECLARATION(lhs),
// Multiply by the weight of matrix-matrix product and store the result
#if defined(ALPHA)
- c0 = c0 * (DATA_TYPE)ALPHA;
-#if M0 > 1
- c1 = c1 * (DATA_TYPE)ALPHA;
-#endif // M0 > 1
-#if M0 > 2
- c2 = c2 * (DATA_TYPE)ALPHA;
-#endif // M0 > 2
-#if M0 > 3
- c3 = c3 * (DATA_TYPE)ALPHA;
-#endif // M0 > 3
-#if M0 > 4
- c4 = c4 * (DATA_TYPE)ALPHA;
-#endif // M0 > 4
-#if M0 > 5
- c5 = c5 * (DATA_TYPE)ALPHA;
-#endif // M0 > 5
-#if M0 > 6
- c6 = c6 * (DATA_TYPE)ALPHA;
-#endif // M0 > 5
-#if M0 > 7
- c7 = c7 * (DATA_TYPE)ALPHA;
-#endif // M0 > 7
+ SCALE_BLOCK(M0, DATA_TYPE, c, ALPHA);
#endif // defined(ALPHA)
// Store output block
- VSTORE(N0)
- (c0, 0, (__global DATA_TYPE *)(dst_addr + 0 * dst_stride_y + zout0));
-#if M0 > 1
- VSTORE(N0)
- (c1, 0, (__global DATA_TYPE *)(dst_addr + 1 * dst_stride_y + zout1));
-#endif // M0 > 1
-#if M0 > 2
- VSTORE(N0)
- (c2, 0, (__global DATA_TYPE *)(dst_addr + 2 * dst_stride_y + zout2));
-#endif // M0 > 2
-#if M0 > 3
- VSTORE(N0)
- (c3, 0, (__global DATA_TYPE *)(dst_addr + 3 * dst_stride_y + zout3));
-#endif // M0 > 3
-#if M0 > 4
- VSTORE(N0)
- (c4, 0, (__global DATA_TYPE *)(dst_addr + 4 * dst_stride_y + zout4));
-#endif // M0 > 4
-#if M0 > 5
- VSTORE(N0)
- (c5, 0, (__global DATA_TYPE *)(dst_addr + 5 * dst_stride_y + zout5));
-#endif // M0 > 5
-#if M0 > 6
- VSTORE(N0)
- (c6, 0, (__global DATA_TYPE *)(dst_addr + 6 * dst_stride_y + zout6));
-#endif // M0 > 6
-#if M0 > 7
- VSTORE(N0)
- (c7, 0, (__global DATA_TYPE *)(dst_addr + 7 * dst_stride_y + zout7));
-#endif // M0 > 7
-
+ STORE_BLOCK(M0, N0, DATA_TYPE, c, dst_addr, dst_stride_y, zout);
#undef LHS_BLOCK_SIZE
#undef LHS_OFFSET_X
#undef LHS_STEP_X
@@ -2892,14 +2031,8 @@ __kernel void gemm_interleave4x4(TENSOR3D_DECLARATION(src),
input_ptr += z * src_stride_z * DEPTH_GEMM3D;
// Load values from Matrix A
- VEC_DATA_TYPE(DATA_TYPE, 4)
- a0 = vload4(0, (__global DATA_TYPE *)(input_ptr + 0 * src_stride_y + zin.s0));
- VEC_DATA_TYPE(DATA_TYPE, 4)
- a1 = vload4(0, (__global DATA_TYPE *)(input_ptr + 1 * src_stride_y + zin.s1));
- VEC_DATA_TYPE(DATA_TYPE, 4)
- a2 = vload4(0, (__global DATA_TYPE *)(input_ptr + 2 * src_stride_y + zin.s2));
- VEC_DATA_TYPE(DATA_TYPE, 4)
- a3 = vload4(0, (__global DATA_TYPE *)(input_ptr + 3 * src_stride_y + zin.s3));
+ LOAD_BLOCK(4, 4, DATA_TYPE, a, input_ptr, 0, src_stride_y, zin.s);
+
#else // defined(REINTERPRET_INPUT_AS_3D)
__global uchar *input_ptr = src.ptr;
@@ -4313,21 +3446,8 @@ __kernel void gemm_mm_floating_point(IMAGE_DECLARATION(src0),
{
#if defined(REINTERPRET_INPUT_AS_3D)
// Load values from matrix A
- VEC_DATA_TYPE(DATA_TYPE, 2)
- a0 = vload2(0, (__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y + zin.s0));
-#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
- VEC_DATA_TYPE(DATA_TYPE, 2)
- a1 = vload2(0, (__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y + zin.s1));
-#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
-#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
- VEC_DATA_TYPE(DATA_TYPE, 2)
- a2 = vload2(0, (__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y + zin.s2));
-#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
-#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
- VEC_DATA_TYPE(DATA_TYPE, 2)
- a3 = vload2(0, (__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y + zin.s3));
-#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
-#else // defined(REINTERPRET_INPUT_AS_3D)
+ LOAD_BLOCK(NUM_ELEMS_PROCESSED_PER_THREAD_Y, 2, DATA_TYPE, a, src0_ptr, src_addr.s0, src0_stride_y, zin.s);
+#else // defined(REINTERPRET_INPUT_AS_3D)
// Load values from matrix A
VEC_DATA_TYPE(DATA_TYPE, 2)
a0 = vload2(0, (__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
@@ -4480,21 +3600,7 @@ __kernel void gemm_mm_floating_point(IMAGE_DECLARATION(src0),
dst_addr += z * dst_stride_z * DEPTH_GEMM3D;
// Store output block
- VSTORE(NUM_ELEMS_PROCESSED_PER_THREAD_X)
- (acc0, 0, (__global DATA_TYPE *)(dst_addr + 0 * dst_stride_y + zout.s0));
-#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
- VSTORE(NUM_ELEMS_PROCESSED_PER_THREAD_X)
- (acc1, 0, (__global DATA_TYPE *)(dst_addr + 1 * dst_stride_y + zout.s1));
-#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
-#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
- VSTORE(NUM_ELEMS_PROCESSED_PER_THREAD_X)
- (acc2, 0, (__global DATA_TYPE *)(dst_addr + 2 * dst_stride_y + zout.s2));
-#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
-#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
- VSTORE(NUM_ELEMS_PROCESSED_PER_THREAD_X)
- (acc3, 0, (__global DATA_TYPE *)(dst_addr + 3 * dst_stride_y + zout.s3));
-#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
-
+ STORE_BLOCK(NUM_ELEMS_PROCESSED_PER_THREAD_Y, NUM_ELEMS_PROCESSED_PER_THREAD_X, DATA_TYPE, acc, dst_addr, dst_stride_y, zout.s);
#else // defined(REINTERPRET_OUTPUT_AS_3D)
// Add offset for batched GEMM
dst_addr += z * dst_stride_z;
@@ -4671,17 +3777,8 @@ __kernel void gemm_mm_floating_point_f32_bifrost(IMAGE_DECLARATION(src0),
{
#if defined(REINTERPRET_INPUT_AS_3D)
// Load values from matrix A and matrix B
- float4 a0 = vload4(0, (__global float *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y + zin.s0));
-#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
- float4 a1 = vload4(0, (__global float *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y + zin.s1));
-#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
-#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
- float4 a2 = vload4(0, (__global float *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y + zin.s2));
-#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
-#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
- float4 a3 = vload4(0, (__global float *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y + zin.s3));
-#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
-#else // defined(REINTERPRET_INPUT_AS_3D)
+ LOAD_BLOCK(NUM_ELEMS_PROCESSED_PER_THREAD_Y, 4, float, a, src0_ptr, src_addr.s0, src0_stride_y, zin.s);
+#else // defined(REINTERPRET_INPUT_AS_3D)
// Load values from matrix A and matrix B
float4 a0 = vload4(0, (__global float *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
@@ -5566,17 +4663,8 @@ __kernel void gemm_mm_floating_point_f16_bifrost_acc32(IMAGE_DECLARATION(src0),
{
#if defined(REINTERPRET_INPUT_AS_3D)
// Load values from matrix A
- half4 a0 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y + zin.s0));
-#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
- half4 a1 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y + zin.s1));
-#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
-#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
- half4 a2 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y + zin.s2));
-#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
-#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
- half4 a3 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y + zin.s3));
-#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
-#else // defined(REINTERPRET_INPUT_AS_3D)
+ LOAD_BLOCK(NUM_ELEMS_PROCESSED_PER_THREAD_Y, 4, half, a, src0_ptr, src_addr.s0, src0_stride_y, zin.s);
+#else // defined(REINTERPRET_INPUT_AS_3D)
// Load values from matrix A
half4 a0 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
@@ -5778,19 +4866,8 @@ __kernel void gemm_mm_floating_point_f16_bifrost_acc32(IMAGE_DECLARATION(src0),
// Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
// multiply dst_stride_z by DEPTH_GEMM3D
dst_addr += z * dst_stride_z * DEPTH_GEMM3D;
-
// Store the output block
- vstore8(hacc0, 0, (__global half *)(dst_addr + 0 * dst_stride_y + zout.s0));
-#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
- vstore8(hacc1, 0, (__global half *)(dst_addr + 1 * dst_stride_y + zout.s1));
-#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
-#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
- vstore8(hacc2, 0, (__global half *)(dst_addr + 2 * dst_stride_y + zout.s2));
-#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
-#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
- vstore8(hacc3, 0, (__global half *)(dst_addr + 3 * dst_stride_y + zout.s3));
-#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
-
+ STORE_BLOCK(NUM_ELEMS_PROCESSED_PER_THREAD_Y, 8, half, hacc, dst_addr, dst_stride_y, zout.s);
#else // defined(REINTERPRET_OUTPUT_AS_3D)
// Add offset for batched GEMM
dst_addr += z * dst_stride_z;
@@ -5945,17 +5022,8 @@ __kernel void gemm_mm_floating_point_f16_bifrost(IMAGE_DECLARATION(src0),
{
#if defined(REINTERPRET_INPUT_AS_3D)
// Load values from matrix A
- half4 a0 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y + zin.s0));
-#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
- half4 a1 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y + zin.s1));
-#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
-#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
- half4 a2 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y + zin.s2));
-#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
-#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
- half4 a3 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y + zin.s3));
-#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
-#else // defined(REINTERPRET_INPUT_AS_3D)
+ LOAD_BLOCK(NUM_ELEMS_PROCESSED_PER_THREAD_Y, 4, half, a, src0_ptr, src_addr.s0, src0_stride_y, zin.s);
+#else // defined(REINTERPRET_INPUT_AS_3D)
// Load values from matrix A
half4 a0 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
@@ -6143,17 +5211,7 @@ __kernel void gemm_mm_floating_point_f16_bifrost(IMAGE_DECLARATION(src0),
dst_addr += z * dst_stride_z * DEPTH_GEMM3D;
// Store the output block
- vstore8(acc0, 0, (__global half *)(dst_addr + 0 * dst_stride_y + zout.s0));
-#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
- vstore8(acc1, 0, (__global half *)(dst_addr + 1 * dst_stride_y + zout.s1));
-#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
-#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
- vstore8(acc2, 0, (__global half *)(dst_addr + 2 * dst_stride_y + zout.s2));
-#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
-#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
- vstore8(acc3, 0, (__global half *)(dst_addr + 3 * dst_stride_y + zout.s3));
-#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
-
+ STORE_BLOCK(NUM_ELEMS_PROCESSED_PER_THREAD_Y, 8, half, acc, dst_addr, dst_stride_y, zout.s);
#else // defined(REINTERPRET_OUTPUT_AS_3D)
// Add offset for batched GEMM
dst_addr += z * dst_stride_z;