diff options
author | Usama Arif <usama.arif@arm.com> | 2019-04-25 14:28:07 +0100 |
---|---|---|
committer | Usama Arif <usama.arif@arm.com> | 2019-05-16 10:13:33 +0000 |
commit | 0681e3bf3b2abf9a0704c3243859a60204d3565c (patch) | |
tree | b4f7abc3094acb00a8c2021071b7d670244dc37a /src/core/CL/cl_kernels/gemm.cl | |
parent | 52c54f61b97bcedab309bfa761e193939e12e739 (diff) | |
download | ComputeLibrary-0681e3bf3b2abf9a0704c3243859a60204d3565c.tar.gz |
COMPMID-2041: Create GEMM helper file for OpenCL.
Change-Id: I7203d7e4d5540536b5e6638c81b26a955aa70f5c
Signed-off-by: Usama Arif <usama.arif@arm.com>
Reviewed-on: https://review.mlplatform.org/c/1144
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Tested-by: Georgios Pinitas <georgios.pinitas@arm.com>
Reviewed-by: Gian Marco Iodice <gianmarco.iodice@arm.com>
Diffstat (limited to 'src/core/CL/cl_kernels/gemm.cl')
-rw-r--r-- | src/core/CL/cl_kernels/gemm.cl | 1038 |
1 files changed, 48 insertions, 990 deletions
diff --git a/src/core/CL/cl_kernels/gemm.cl b/src/core/CL/cl_kernels/gemm.cl index da940082ae..c3107a20f2 100644 --- a/src/core/CL/cl_kernels/gemm.cl +++ b/src/core/CL/cl_kernels/gemm.cl @@ -21,7 +21,7 @@ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE * SOFTWARE. */ -#include "helpers.h" +#include "gemm_helpers.h" #include "repeat.h" #if defined(M0) && defined(K0) && defined(V0) && defined(DATA_TYPE) && defined(SRC_WIDTH) @@ -125,63 +125,10 @@ __kernel void gemm_reshape_lhs_matrix_nt(TENSOR3D_DECLARATION(src), // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we // multiply src_stride_z by DEPTH_GEMM3D - // Note for the REINTERPRET_INPUT_AS_3D case - // Since we load a 2D input tile from a 3D tensor, we need to check when the plane changes across the z dimension - // in order to take into account the presence of possible cross plane paddings - // - // | | - // | plane0 | - // | | - // |__________________| - // |******************| - // | cross_plane_pad | - // |******************| - // | | - // | plane1 | - // | | - // |__________________| - input_ptr += z * (uint)src_stride_z * DEPTH_GEMM3D; // The plane (zin) is calculated dividing M (y * M0) by HEIGHT_GEMM3D - zin0 = (0 + (uint)(y * M0)) / (uint)HEIGHT_GEMM3D; - zin0 = min((uint)(DEPTH_GEMM3D - 1), zin0); - zin0 *= (cross_plane_pad * src_stride_y); -#if M0 > 1 - zin1 = (1 + (uint)(y * M0)) / (uint)HEIGHT_GEMM3D; - zin1 = min((uint)(DEPTH_GEMM3D - 1), zin1); - zin1 *= (cross_plane_pad * src_stride_y); -#endif // M0 > 1 -#if M0 > 2 - zin2 = (2 + (uint)(y * M0)) / (uint)HEIGHT_GEMM3D; - zin2 = min((uint)(DEPTH_GEMM3D - 1), zin2); - zin2 *= (cross_plane_pad * src_stride_y); -#endif // M0 > 2 -#if M0 > 3 - zin3 = (3 + (uint)(y * M0)) / (uint)HEIGHT_GEMM3D; - zin3 = min((uint)(DEPTH_GEMM3D - 1), zin3); - zin3 *= (cross_plane_pad * src_stride_y); -#endif // M0 > 3 -#if M0 > 4 - zin4 = (4 + (uint)(y * M0)) / (uint)HEIGHT_GEMM3D; - zin4 = min((uint)(DEPTH_GEMM3D - 1), zin4); - zin4 *= (cross_plane_pad * src_stride_y); -#endif // M0 > 4 -#if M0 > 5 - zin5 = (5 + (uint)(y * M0)) / (uint)HEIGHT_GEMM3D; - zin5 = min((uint)(DEPTH_GEMM3D - 1), zin5); - zin5 *= (cross_plane_pad * src_stride_y); -#endif // M0 > 5 -#if M0 > 6 - zin6 = (6 + (uint)(y * M0)) / (uint)HEIGHT_GEMM3D; - zin6 = min((uint)(DEPTH_GEMM3D - 1), zin6); - zin6 *= (cross_plane_pad * src_stride_y); -#endif // M0 > 6 -#if M0 > 7 - zin7 = (7 + (uint)(y * M0)) / (uint)HEIGHT_GEMM3D; - zin7 = min((uint)(DEPTH_GEMM3D - 1), zin7); - zin7 *= (cross_plane_pad * src_stride_y); -#endif // M0 > 7 + CALCULATE_Z_OFFSET(M0, uint, zin, y, HEIGHT_GEMM3D, DEPTH_GEMM3D, cross_plane_pad, src_stride_y); #else // defined(REINTERPRET_INPUT_AS_3D) @@ -193,79 +140,33 @@ __kernel void gemm_reshape_lhs_matrix_nt(TENSOR3D_DECLARATION(src), output_ptr += z * (uint)dst_stride_z; // ---------------------------Load input values -------------------------------- - // Load values from the LHS matrix - VEC_DATA_TYPE(DATA_TYPE, K0) - a0 = VLOAD(K0)(0, (__global DATA_TYPE *)(input_ptr + 0 * src_stride_y + zin0)); + LOAD_BLOCK(M0, K0, DATA_TYPE, a, input_ptr, 0, src_stride_y, zin); BOUNDARY_CONDITION_X(x, a0); #if M0 > 1 - VEC_DATA_TYPE(DATA_TYPE, K0) - a1 = VLOAD(K0)(0, (__global DATA_TYPE *)(input_ptr + 1 * src_stride_y + zin1)); BOUNDARY_CONDITION_X(x, a1); #endif // M0 > 1 #if M0 > 2 - VEC_DATA_TYPE(DATA_TYPE, K0) - a2 = VLOAD(K0)(0, (__global DATA_TYPE *)(input_ptr + 2 * src_stride_y + zin2)); BOUNDARY_CONDITION_X(x, a2); #endif // M0 > 2 #if M0 > 3 - VEC_DATA_TYPE(DATA_TYPE, K0) - a3 = VLOAD(K0)(0, (__global DATA_TYPE *)(input_ptr + 3 * src_stride_y + zin3)); BOUNDARY_CONDITION_X(x, a3); #endif // M0 > 3 #if M0 > 4 - VEC_DATA_TYPE(DATA_TYPE, K0) - a4 = VLOAD(K0)(0, (__global DATA_TYPE *)(input_ptr + 4 * src_stride_y + zin4)); BOUNDARY_CONDITION_X(x, a4); #endif // M0 > 4 #if M0 > 5 - VEC_DATA_TYPE(DATA_TYPE, K0) - a5 = VLOAD(K0)(0, (__global DATA_TYPE *)(input_ptr + 5 * src_stride_y + zin5)); BOUNDARY_CONDITION_X(x, a5); #endif // M0 > 5 #if M0 > 6 - VEC_DATA_TYPE(DATA_TYPE, K0) - a6 = VLOAD(K0)(0, (__global DATA_TYPE *)(input_ptr + 6 * src_stride_y + zin6)); BOUNDARY_CONDITION_X(x, a6); #endif // M0 > 6 #if M0 > 7 - VEC_DATA_TYPE(DATA_TYPE, K0) - a7 = VLOAD(K0)(0, (__global DATA_TYPE *)(input_ptr + 7 * src_stride_y + zin7)); BOUNDARY_CONDITION_X(x, a7); #endif // M0 > 7 - // ---------------------------Store output values ------------------------------ - - VSTORE(K0) - (a0, 0, (__global DATA_TYPE *)(output_ptr + 0 * OUTPUT_STEP_X * sizeof(DATA_TYPE))); -#if M0 > 1 - VSTORE(K0) - (a1, 0, (__global DATA_TYPE *)(output_ptr + 1 * OUTPUT_STEP_X * sizeof(DATA_TYPE))); -#endif // M0 > 1 -#if M0 > 2 - VSTORE(K0) - (a2, 0, (__global DATA_TYPE *)(output_ptr + 2 * OUTPUT_STEP_X * sizeof(DATA_TYPE))); -#endif // M0 > 2 -#if M0 > 3 - VSTORE(K0) - (a3, 0, (__global DATA_TYPE *)(output_ptr + 3 * OUTPUT_STEP_X * sizeof(DATA_TYPE))); -#endif // M0 > 3 -#if M0 > 4 - VSTORE(K0) - (a4, 0, (__global DATA_TYPE *)(output_ptr + 4 * OUTPUT_STEP_X * sizeof(DATA_TYPE))); -#endif // M0 > 4 -#if M0 > 5 - VSTORE(K0) - (a5, 0, (__global DATA_TYPE *)(output_ptr + 5 * OUTPUT_STEP_X * sizeof(DATA_TYPE))); -#endif // M0 > 5 -#if M0 > 6 - VSTORE(K0) - (a6, 0, (__global DATA_TYPE *)(output_ptr + 6 * OUTPUT_STEP_X * sizeof(DATA_TYPE))); -#endif // M0 > 6 -#if M0 > 7 - VSTORE(K0) - (a7, 0, (__global DATA_TYPE *)(output_ptr + 7 * OUTPUT_STEP_X * sizeof(DATA_TYPE))); -#endif // M0 > 7 + REPEAT_VAR_INIT_TO_CONST(16, uint, zout, 0); + STORE_BLOCK(M0, K0, DATA_TYPE, a, output_ptr, OUTPUT_STEP_X * sizeof(DATA_TYPE), zout); #undef BLOCK_SIZE #undef OUTPUT_OFFSET_X @@ -424,63 +325,10 @@ __kernel void gemm_reshape_lhs_matrix_t(TENSOR3D_DECLARATION(src), // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we // multiply src_stride_z by DEPTH_GEMM3D - // Note for the REINTERPRET_INPUT_AS_3D case - // Since we load a 2D input tile from a 3D tensor, we need to check when the plane changes across the z dimension - // in order to take into account the presence of possible cross plane paddings - // - // | | - // | plane0 | - // | | - // |__________________| - // |******************| - // | cross_plane_pad | - // |******************| - // | | - // | plane1 | - // | | - // |__________________| - input_ptr += z * (uint)src_stride_z * DEPTH_GEMM3D; // The plane (zin) is calculated dividing M (y * M0) by HEIGHT_GEMM3D - zin0 = (0 + (uint)(y * M0)) / (uint)HEIGHT_GEMM3D; - zin0 = min((uint)(DEPTH_GEMM3D - 1), zin0); - zin0 *= (cross_plane_pad * src_stride_y); -#if M0 > 1 - zin1 = (1 + (uint)(y * M0)) / (uint)HEIGHT_GEMM3D; - zin1 = min((uint)(DEPTH_GEMM3D - 1), zin1); - zin1 *= (cross_plane_pad * src_stride_y); -#endif // M0 > 1 -#if M0 > 2 - zin2 = (2 + (uint)(y * M0)) / (uint)HEIGHT_GEMM3D; - zin2 = min((uint)(DEPTH_GEMM3D - 1), zin2); - zin2 *= (cross_plane_pad * src_stride_y); -#endif // M0 > 2 -#if M0 > 3 - zin3 = (3 + (uint)(y * M0)) / (uint)HEIGHT_GEMM3D; - zin3 = min((uint)(DEPTH_GEMM3D - 1), zin3); - zin3 *= (cross_plane_pad * src_stride_y); -#endif // M0 > 3 -#if M0 > 4 - zin4 = (4 + (uint)(y * M0)) / (uint)HEIGHT_GEMM3D; - zin4 = min((uint)(DEPTH_GEMM3D - 1), zin4); - zin4 *= (cross_plane_pad * src_stride_y); -#endif // M0 > 4 -#if M0 > 5 - zin5 = (5 + (uint)(y * M0)) / (uint)HEIGHT_GEMM3D; - zin5 = min((uint)(DEPTH_GEMM3D - 1), zin5); - zin5 *= (cross_plane_pad * src_stride_y); -#endif // M0 > 5 -#if M0 > 6 - zin6 = (6 + (uint)(y * M0)) / (uint)HEIGHT_GEMM3D; - zin6 = min((uint)(DEPTH_GEMM3D - 1), zin6); - zin6 *= (cross_plane_pad * src_stride_y); -#endif // M0 > 6 -#if M0 > 7 - zin7 = (7 + (uint)(y * M0)) / (uint)HEIGHT_GEMM3D; - zin7 = min((uint)(DEPTH_GEMM3D - 1), zin7); - zin7 *= (cross_plane_pad * src_stride_y); -#endif // M0 > 7 + CALCULATE_Z_OFFSET(M0, uint, zin, y, HEIGHT_GEMM3D, DEPTH_GEMM3D, cross_plane_pad, src_stride_y); #else // defined(REINTERPRET_INPUT_AS_3D) @@ -494,45 +342,29 @@ __kernel void gemm_reshape_lhs_matrix_t(TENSOR3D_DECLARATION(src), // ---------------------------Load input values -------------------------------- // Load values from the LHS matrix - VEC_DATA_TYPE(DATA_TYPE, K0) - a0 = VLOAD(K0)(0, (__global DATA_TYPE *)(input_ptr + 0 * src_stride_y + zin0)); + LOAD_BLOCK(M0, K0, DATA_TYPE, a, input_ptr, 0, src_stride_y, zin); BOUNDARY_CONDITION_X(x, a0); #if M0 > 1 - VEC_DATA_TYPE(DATA_TYPE, K0) - a1 = VLOAD(K0)(0, (__global DATA_TYPE *)(input_ptr + 1 * src_stride_y + zin1)); BOUNDARY_CONDITION_X(x, a1); #endif // M0 > 1 #if M0 > 2 - VEC_DATA_TYPE(DATA_TYPE, K0) - a2 = VLOAD(K0)(0, (__global DATA_TYPE *)(input_ptr + 2 * src_stride_y + zin2)); BOUNDARY_CONDITION_X(x, a2); #endif // M0 > 2 #if M0 > 3 - VEC_DATA_TYPE(DATA_TYPE, K0) - a3 = VLOAD(K0)(0, (__global DATA_TYPE *)(input_ptr + 3 * src_stride_y + zin3)); BOUNDARY_CONDITION_X(x, a3); #endif // M0 > 3 #if M0 > 4 - VEC_DATA_TYPE(DATA_TYPE, K0) - a4 = VLOAD(K0)(0, (__global DATA_TYPE *)(input_ptr + 4 * src_stride_y + zin4)); BOUNDARY_CONDITION_X(x, a4); #endif // M0 > 4 #if M0 > 5 - VEC_DATA_TYPE(DATA_TYPE, K0) - a5 = VLOAD(K0)(0, (__global DATA_TYPE *)(input_ptr + 5 * src_stride_y + zin5)); BOUNDARY_CONDITION_X(x, a5); #endif // M0 > 5 #if M0 > 6 - VEC_DATA_TYPE(DATA_TYPE, K0) - a6 = VLOAD(K0)(0, (__global DATA_TYPE *)(input_ptr + 6 * src_stride_y + zin6)); BOUNDARY_CONDITION_X(x, a6); #endif // M0 > 6 #if M0 > 7 - VEC_DATA_TYPE(DATA_TYPE, K0) - a7 = VLOAD(K0)(0, (__global DATA_TYPE *)(input_ptr + 7 * src_stride_y + zin7)); BOUNDARY_CONDITION_X(x, a7); #endif // M0 > 7 - // ---------------------------Transpose and store block ----------------------- TRANSPOSE_COLUMN_AND_STORE(output_ptr, OUTPUT_STEP_X, 0); @@ -711,48 +543,8 @@ __kernel void gemm_reshape_rhs_matrix_nt(TENSOR3D_DECLARATION(src), #endif // K0 > 8 // ---------------------------Store output values ------------------------------ - VSTORE(N0) - (a0, 0, (__global DATA_TYPE *)(output_ptr + 0 * OUTPUT_STEP_X * sizeof(DATA_TYPE))); -#if K0 > 1 - VSTORE(N0) - (a1, 0, (__global DATA_TYPE *)(output_ptr + 1 * OUTPUT_STEP_X * sizeof(DATA_TYPE))); -#endif // K0 > 1 -#if K0 > 2 - VSTORE(N0) - (a2, 0, (__global DATA_TYPE *)(output_ptr + 2 * OUTPUT_STEP_X * sizeof(DATA_TYPE))); -#endif // K0 > 2 -#if K0 > 3 - VSTORE(N0) - (a3, 0, (__global DATA_TYPE *)(output_ptr + 3 * OUTPUT_STEP_X * sizeof(DATA_TYPE))); -#endif // K0 > 3 -#if K0 > 4 - VSTORE(N0) - (a4, 0, (__global DATA_TYPE *)(output_ptr + 4 * OUTPUT_STEP_X * sizeof(DATA_TYPE))); - VSTORE(N0) - (a5, 0, (__global DATA_TYPE *)(output_ptr + 5 * OUTPUT_STEP_X * sizeof(DATA_TYPE))); - VSTORE(N0) - (a6, 0, (__global DATA_TYPE *)(output_ptr + 6 * OUTPUT_STEP_X * sizeof(DATA_TYPE))); - VSTORE(N0) - (a7, 0, (__global DATA_TYPE *)(output_ptr + 7 * OUTPUT_STEP_X * sizeof(DATA_TYPE))); -#endif // N0 > 4 -#if K0 > 8 - VSTORE(N0) - (a8, 0, (__global DATA_TYPE *)(output_ptr + 8 * OUTPUT_STEP_X * sizeof(DATA_TYPE))); - VSTORE(N0) - (a9, 0, (__global DATA_TYPE *)(output_ptr + 9 * OUTPUT_STEP_X * sizeof(DATA_TYPE))); - VSTORE(N0) - (aA, 0, (__global DATA_TYPE *)(output_ptr + 10 * OUTPUT_STEP_X * sizeof(DATA_TYPE))); - VSTORE(N0) - (aB, 0, (__global DATA_TYPE *)(output_ptr + 11 * OUTPUT_STEP_X * sizeof(DATA_TYPE))); - VSTORE(N0) - (aC, 0, (__global DATA_TYPE *)(output_ptr + 12 * OUTPUT_STEP_X * sizeof(DATA_TYPE))); - VSTORE(N0) - (aD, 0, (__global DATA_TYPE *)(output_ptr + 13 * OUTPUT_STEP_X * sizeof(DATA_TYPE))); - VSTORE(N0) - (aE, 0, (__global DATA_TYPE *)(output_ptr + 14 * OUTPUT_STEP_X * sizeof(DATA_TYPE))); - VSTORE(N0) - (aF, 0, (__global DATA_TYPE *)(output_ptr + 15 * OUTPUT_STEP_X * sizeof(DATA_TYPE))); -#endif // N0 > 8 + REPEAT_VAR_INIT_TO_CONST(16, uint, zout, 0); + STORE_BLOCK(K0, N0, DATA_TYPE, a, output_ptr, OUTPUT_STEP_X * sizeof(DATA_TYPE), zout); #undef BLOCK_SIZE #undef OUTPUT_OFFSET_X @@ -1079,47 +871,8 @@ __kernel void gemm_reshape_rhs_matrix_t(TENSOR3D_DECLARATION(src), #endif // N0 > 2 // ---------------------------Store the output values ------------------------------ - - VSTORE(K0) - (res0, 0, (__global DATA_TYPE *)(output_ptr + 0 * OUTPUT_STEP_X * sizeof(DATA_TYPE))); - VSTORE(K0) - (res1, 0, (__global DATA_TYPE *)(output_ptr + 1 * OUTPUT_STEP_X * sizeof(DATA_TYPE))); -#if N0 > 2 - VSTORE(K0) - (res2, 0, (__global DATA_TYPE *)(output_ptr + 2 * OUTPUT_STEP_X * sizeof(DATA_TYPE))); -#endif // N0 > 2 -#if N0 > 3 - VSTORE(K0) - (res3, 0, (__global DATA_TYPE *)(output_ptr + 3 * OUTPUT_STEP_X * sizeof(DATA_TYPE))); -#endif // N0 > 3 -#if N0 > 4 - VSTORE(K0) - (res4, 0, (__global DATA_TYPE *)(output_ptr + 4 * OUTPUT_STEP_X * sizeof(DATA_TYPE))); - VSTORE(K0) - (res5, 0, (__global DATA_TYPE *)(output_ptr + 5 * OUTPUT_STEP_X * sizeof(DATA_TYPE))); - VSTORE(K0) - (res6, 0, (__global DATA_TYPE *)(output_ptr + 6 * OUTPUT_STEP_X * sizeof(DATA_TYPE))); - VSTORE(K0) - (res7, 0, (__global DATA_TYPE *)(output_ptr + 7 * OUTPUT_STEP_X * sizeof(DATA_TYPE))); -#endif // N0 > 4 -#if N0 > 8 - VSTORE(K0) - (res8, 0, (__global DATA_TYPE *)(output_ptr + 8 * OUTPUT_STEP_X * sizeof(DATA_TYPE))); - VSTORE(K0) - (res9, 0, (__global DATA_TYPE *)(output_ptr + 9 * OUTPUT_STEP_X * sizeof(DATA_TYPE))); - VSTORE(K0) - (resA, 0, (__global DATA_TYPE *)(output_ptr + 10 * OUTPUT_STEP_X * sizeof(DATA_TYPE))); - VSTORE(K0) - (resB, 0, (__global DATA_TYPE *)(output_ptr + 11 * OUTPUT_STEP_X * sizeof(DATA_TYPE))); - VSTORE(K0) - (resC, 0, (__global DATA_TYPE *)(output_ptr + 12 * OUTPUT_STEP_X * sizeof(DATA_TYPE))); - VSTORE(K0) - (resD, 0, (__global DATA_TYPE *)(output_ptr + 13 * OUTPUT_STEP_X * sizeof(DATA_TYPE))); - VSTORE(K0) - (resE, 0, (__global DATA_TYPE *)(output_ptr + 14 * OUTPUT_STEP_X * sizeof(DATA_TYPE))); - VSTORE(K0) - (resF, 0, (__global DATA_TYPE *)(output_ptr + 15 * OUTPUT_STEP_X * sizeof(DATA_TYPE))); -#endif // N0 > 8 + REPEAT_VAR_INIT_TO_CONST(16, uint, zout, 0); + STORE_BLOCK(N0, K0, DATA_TYPE, res, output_ptr, OUTPUT_STEP_X * sizeof(DATA_TYPE), zout); #undef BLOCK_SIZE #undef OUTPUT_OFFSET_X @@ -1354,63 +1107,12 @@ __kernel void gemm_mm_reshaped_only_rhs_t(IMAGE_DECLARATION(lhs), rhs_offset += z * rhs_stride_z; #endif // defined(MATRIX_B_DEPTH) - REPEAT_VAR_INIT_TO_CONST(8, uint, zin, 0); //uint zout0=0,zout1=0,zout2=0,... zout7=0; + REPEAT_VAR_INIT_TO_CONST(8, uint, zlhs, 0); //uint zlhs0=0,zlhs1=0,zlhs2=0,... zlhs7=0; + REPEAT_VAR_INIT_TO_CONST(16, uint, zrhs, 0); #if defined(REINTERPRET_INPUT_AS_3D) - // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension - // in order to take into account the presence of possible cross plane paddings - // - // | | - // | plane0 | - // | | - // |__________________| - // |******************| - // | cross_plane_pad | - // |******************| - // | | - // | plane1 | - // | | - // |__________________| - - // The plane (zin) is calculated dividing M (y * M0) by HEIGHT_GEMM3D - zin0 = (0 + (uint)(y * (uint)M0)) / (uint)HEIGHT_GEMM3D; - zin0 = min((uint)(DEPTH_GEMM3D - 1), zin0); - zin0 *= (lhs_cross_plane_pad * lhs_stride_y); -#if M0 > 1 - zin1 = (1 + (uint)(y * (uint)M0)) / (uint)HEIGHT_GEMM3D; - zin1 = min((uint)(DEPTH_GEMM3D - 1), zin1); - zin1 *= (lhs_cross_plane_pad * lhs_stride_y); -#endif // M0 > 1 -#if M0 > 2 - zin2 = (2 + (uint)(y * (uint)M0)) / (uint)HEIGHT_GEMM3D; - zin2 = min((uint)(DEPTH_GEMM3D - 1), zin2); - zin2 *= (lhs_cross_plane_pad * lhs_stride_y); -#endif // M0 > 2 -#if M0 > 3 - zin3 = (3 + (uint)(y * (uint)M0)) / (uint)HEIGHT_GEMM3D; - zin3 = min((uint)(DEPTH_GEMM3D - 1), zin3); - zin3 *= (lhs_cross_plane_pad * lhs_stride_y); -#endif // M0 > 3 -#if M0 > 4 - zin4 = (4 + (uint)(y * (uint)M0)) / (uint)HEIGHT_GEMM3D; - zin4 = min((uint)(DEPTH_GEMM3D - 1), zin4); - zin4 *= (lhs_cross_plane_pad * lhs_stride_y); -#endif // M0 > 4 -#if M0 > 5 - zin5 = (5 + (uint)(y * (uint)M0)) / (uint)HEIGHT_GEMM3D; - zin5 = min((uint)(DEPTH_GEMM3D - 1), zin5); - zin5 *= (lhs_cross_plane_pad * lhs_stride_y); -#endif // M0 > 5 -#if M0 > 6 - zin6 = (6 + (uint)(y * (uint)M0)) / (uint)HEIGHT_GEMM3D; - zin6 = min((uint)(DEPTH_GEMM3D - 1), zin6); - zin6 *= (lhs_cross_plane_pad * lhs_stride_y); -#endif // M0 > 6 -#if M0 > 7 - zin7 = (7 + (uint)(y * (uint)M0)) / (uint)HEIGHT_GEMM3D; - zin7 = min((uint)(DEPTH_GEMM3D - 1), zout7); - zin7 *= (lhs_cross_plane_pad * lhs_stride_y); -#endif // M0 > 7 + // The plane (zlhs) is calculated dividing M (y * M0) by HEIGHT_GEMM3D + CALCULATE_Z_OFFSET(M0, uint, zlhs, y, HEIGHT_GEMM3D, DEPTH_GEMM3D, lhs_cross_plane_pad, lhs_stride_y); // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we // multiply lhs_stride_z by DEPTH_GEMM3D @@ -1439,78 +1141,10 @@ __kernel void gemm_mm_reshaped_only_rhs_t(IMAGE_DECLARATION(lhs), // 7,2 - 7,3 - 7,4 - 7,8 - 7,16 // 8,2 - 8,3 - 8,4 - 8,8 - 8,16 // Load values from LHS matrix - VEC_DATA_TYPE(DATA_TYPE, K0) - a0 = VLOAD(K0)(0, (__global DATA_TYPE *)(lhs_ptr + lhs_offset + 0 * lhs_stride_y + zin0)); -#if M0 > 1 - VEC_DATA_TYPE(DATA_TYPE, K0) - a1 = VLOAD(K0)(0, (__global DATA_TYPE *)(lhs_ptr + lhs_offset + 1 * lhs_stride_y + zin1)); -#endif // M0 > 1 -#if M0 > 2 - VEC_DATA_TYPE(DATA_TYPE, K0) - a2 = VLOAD(K0)(0, (__global DATA_TYPE *)(lhs_ptr + lhs_offset + 2 * lhs_stride_y + zin2)); -#endif // M0 > 2 -#if M0 > 3 - VEC_DATA_TYPE(DATA_TYPE, K0) - a3 = VLOAD(K0)(0, (__global DATA_TYPE *)(lhs_ptr + lhs_offset + 3 * lhs_stride_y + zin3)); -#endif // M0 > 3 -#if M0 > 4 - VEC_DATA_TYPE(DATA_TYPE, K0) - a4 = VLOAD(K0)(0, (__global DATA_TYPE *)(lhs_ptr + lhs_offset + 4 * lhs_stride_y + zin4)); -#endif // M0 > 4 -#if M0 > 5 - VEC_DATA_TYPE(DATA_TYPE, K0) - a5 = VLOAD(K0)(0, (__global DATA_TYPE *)(lhs_ptr + lhs_offset + 5 * lhs_stride_y + zin5)); -#endif // M0 > 5 -#if M0 > 6 - VEC_DATA_TYPE(DATA_TYPE, K0) - a6 = VLOAD(K0)(0, (__global DATA_TYPE *)(lhs_ptr + lhs_offset + 6 * lhs_stride_y + zin6)); -#endif // M0 > 6 -#if M0 > 7 - VEC_DATA_TYPE(DATA_TYPE, K0) - a7 = VLOAD(K0)(0, (__global DATA_TYPE *)(lhs_ptr + lhs_offset + 7 * lhs_stride_y + zin7)); -#endif // M0 > 7 + LOAD_BLOCK(M0, K0, DATA_TYPE, a, lhs_ptr, lhs_offset, lhs_stride_y, zlhs); // Load values from RHS matrix - VEC_DATA_TYPE(DATA_TYPE, K0) - b0 = VLOAD(K0)(0, (__global DATA_TYPE *)(rhs_ptr + rhs_offset + 0 * RHS_STEP_X * sizeof(DATA_TYPE))); - VEC_DATA_TYPE(DATA_TYPE, K0) - b1 = VLOAD(K0)(0, (__global DATA_TYPE *)(rhs_ptr + rhs_offset + 1 * RHS_STEP_X * sizeof(DATA_TYPE))); -#if N0 > 2 - VEC_DATA_TYPE(DATA_TYPE, K0) - b2 = VLOAD(K0)(0, (__global DATA_TYPE *)(rhs_ptr + rhs_offset + 2 * RHS_STEP_X * sizeof(DATA_TYPE))); -#endif // N0 > 2 -#if N0 > 3 - VEC_DATA_TYPE(DATA_TYPE, K0) - b3 = VLOAD(K0)(0, (__global DATA_TYPE *)(rhs_ptr + rhs_offset + 3 * RHS_STEP_X * sizeof(DATA_TYPE))); -#endif // N0 > 3 -#if N0 > 4 - VEC_DATA_TYPE(DATA_TYPE, K0) - b4 = VLOAD(K0)(0, (__global DATA_TYPE *)(rhs_ptr + rhs_offset + 4 * RHS_STEP_X * sizeof(DATA_TYPE))); - VEC_DATA_TYPE(DATA_TYPE, K0) - b5 = VLOAD(K0)(0, (__global DATA_TYPE *)(rhs_ptr + rhs_offset + 5 * RHS_STEP_X * sizeof(DATA_TYPE))); - VEC_DATA_TYPE(DATA_TYPE, K0) - b6 = VLOAD(K0)(0, (__global DATA_TYPE *)(rhs_ptr + rhs_offset + 6 * RHS_STEP_X * sizeof(DATA_TYPE))); - VEC_DATA_TYPE(DATA_TYPE, K0) - b7 = VLOAD(K0)(0, (__global DATA_TYPE *)(rhs_ptr + rhs_offset + 7 * RHS_STEP_X * sizeof(DATA_TYPE))); -#endif // N0 > 4 -#if N0 > 8 - VEC_DATA_TYPE(DATA_TYPE, K0) - b8 = VLOAD(K0)(0, (__global DATA_TYPE *)(rhs_ptr + rhs_offset + 8 * RHS_STEP_X * sizeof(DATA_TYPE))); - VEC_DATA_TYPE(DATA_TYPE, K0) - b9 = VLOAD(K0)(0, (__global DATA_TYPE *)(rhs_ptr + rhs_offset + 9 * RHS_STEP_X * sizeof(DATA_TYPE))); - VEC_DATA_TYPE(DATA_TYPE, K0) - bA = VLOAD(K0)(0, (__global DATA_TYPE *)(rhs_ptr + rhs_offset + 10 * RHS_STEP_X * sizeof(DATA_TYPE))); - VEC_DATA_TYPE(DATA_TYPE, K0) - bB = VLOAD(K0)(0, (__global DATA_TYPE *)(rhs_ptr + rhs_offset + 11 * RHS_STEP_X * sizeof(DATA_TYPE))); - VEC_DATA_TYPE(DATA_TYPE, K0) - bC = VLOAD(K0)(0, (__global DATA_TYPE *)(rhs_ptr + rhs_offset + 12 * RHS_STEP_X * sizeof(DATA_TYPE))); - VEC_DATA_TYPE(DATA_TYPE, K0) - bD = VLOAD(K0)(0, (__global DATA_TYPE *)(rhs_ptr + rhs_offset + 13 * RHS_STEP_X * sizeof(DATA_TYPE))); - VEC_DATA_TYPE(DATA_TYPE, K0) - bE = VLOAD(K0)(0, (__global DATA_TYPE *)(rhs_ptr + rhs_offset + 14 * RHS_STEP_X * sizeof(DATA_TYPE))); - VEC_DATA_TYPE(DATA_TYPE, K0) - bF = VLOAD(K0)(0, (__global DATA_TYPE *)(rhs_ptr + rhs_offset + 15 * RHS_STEP_X * sizeof(DATA_TYPE))); -#endif // N0 > 8 + LOAD_BLOCK(N0, K0, DATA_TYPE, b, rhs_ptr, rhs_offset, RHS_STEP_X * sizeof(DATA_TYPE), zrhs); // Accumulate ARM_DOT_K0XN0(K0, a0, b, c0); @@ -1544,54 +1178,10 @@ __kernel void gemm_mm_reshaped_only_rhs_t(IMAGE_DECLARATION(lhs), for(; i < K; ++i) { // Load values from LHS matrix - DATA_TYPE a0 = *((__global DATA_TYPE *)(lhs_ptr + lhs_offset + 0 * lhs_stride_y + zin0)); -#if M0 > 1 - DATA_TYPE a1 = *((__global DATA_TYPE *)(lhs_ptr + lhs_offset + 1 * lhs_stride_y + zin1)); -#endif // M0 > 1 -#if M0 > 2 - DATA_TYPE a2 = *((__global DATA_TYPE *)(lhs_ptr + lhs_offset + 2 * lhs_stride_y + zin2)); -#endif // M0 > 2 -#if M0 > 3 - DATA_TYPE a3 = *((__global DATA_TYPE *)(lhs_ptr + lhs_offset + 3 * lhs_stride_y + zin3)); -#endif // M0 > 3 -#if M0 > 4 - DATA_TYPE a4 = *((__global DATA_TYPE *)(lhs_ptr + lhs_offset + 4 * lhs_stride_y + zin4)); -#endif // M0 > 4 -#if M0 > 5 - DATA_TYPE a5 = *((__global DATA_TYPE *)(lhs_ptr + lhs_offset + 5 * lhs_stride_y + zin5)); -#endif // M0 > 5 -#if M0 > 6 - DATA_TYPE a6 = *((__global DATA_TYPE *)(lhs_ptr + lhs_offset + 6 * lhs_stride_y + zin6)); -#endif // M0 > 6 -#if M0 > 7 - DATA_TYPE a7 = *((__global DATA_TYPE *)(lhs_ptr + lhs_offset + 7 * lhs_stride_y + zin7)); -#endif // M0 > 7 + LOAD_BLOCK(M0, 1, DATA_TYPE, a, lhs_ptr, lhs_offset, lhs_stride_y, zlhs); // Load values from RHS matrix - DATA_TYPE b0 = *((__global DATA_TYPE *)(rhs_ptr + rhs_offset + 0 * RHS_STEP_X * sizeof(DATA_TYPE))); - DATA_TYPE b1 = *((__global DATA_TYPE *)(rhs_ptr + rhs_offset + 1 * RHS_STEP_X * sizeof(DATA_TYPE))); -#if N0 > 2 - DATA_TYPE b2 = *((__global DATA_TYPE *)(rhs_ptr + rhs_offset + 2 * RHS_STEP_X * sizeof(DATA_TYPE))); -#endif // N0 > 2 -#if N0 > 3 - DATA_TYPE b3 = *((__global DATA_TYPE *)(rhs_ptr + rhs_offset + 3 * RHS_STEP_X * sizeof(DATA_TYPE))); -#endif // N0 > 3 -#if N0 > 4 - DATA_TYPE b4 = *((__global DATA_TYPE *)(rhs_ptr + rhs_offset + 4 * RHS_STEP_X * sizeof(DATA_TYPE))); - DATA_TYPE b5 = *((__global DATA_TYPE *)(rhs_ptr + rhs_offset + 5 * RHS_STEP_X * sizeof(DATA_TYPE))); - DATA_TYPE b6 = *((__global DATA_TYPE *)(rhs_ptr + rhs_offset + 6 * RHS_STEP_X * sizeof(DATA_TYPE))); - DATA_TYPE b7 = *((__global DATA_TYPE *)(rhs_ptr + rhs_offset + 7 * RHS_STEP_X * sizeof(DATA_TYPE))); -#endif // N0 > 4 -#if N0 > 8 - DATA_TYPE b8 = *((__global DATA_TYPE *)(rhs_ptr + rhs_offset + 8 * RHS_STEP_X * sizeof(DATA_TYPE))); - DATA_TYPE b9 = *((__global DATA_TYPE *)(rhs_ptr + rhs_offset + 9 * RHS_STEP_X * sizeof(DATA_TYPE))); - DATA_TYPE bA = *((__global DATA_TYPE *)(rhs_ptr + rhs_offset + 10 * RHS_STEP_X * sizeof(DATA_TYPE))); - DATA_TYPE bB = *((__global DATA_TYPE *)(rhs_ptr + rhs_offset + 11 * RHS_STEP_X * sizeof(DATA_TYPE))); - DATA_TYPE bC = *((__global DATA_TYPE *)(rhs_ptr + rhs_offset + 12 * RHS_STEP_X * sizeof(DATA_TYPE))); - DATA_TYPE bD = *((__global DATA_TYPE *)(rhs_ptr + rhs_offset + 13 * RHS_STEP_X * sizeof(DATA_TYPE))); - DATA_TYPE bE = *((__global DATA_TYPE *)(rhs_ptr + rhs_offset + 14 * RHS_STEP_X * sizeof(DATA_TYPE))); - DATA_TYPE bF = *((__global DATA_TYPE *)(rhs_ptr + rhs_offset + 15 * RHS_STEP_X * sizeof(DATA_TYPE))); -#endif // N0 > 8 + LOAD_BLOCK(N0, 1, DATA_TYPE, b, rhs_ptr, rhs_offset, RHS_STEP_X * sizeof(DATA_TYPE), zrhs); // Accumulate ARM_DOT_K0XN0(1, a0, b, c0); @@ -1626,60 +1216,9 @@ __kernel void gemm_mm_reshaped_only_rhs_t(IMAGE_DECLARATION(lhs), REPEAT_VAR_INIT_TO_CONST(8, uint, zout, 0); //uint zout0=0,zout1=0,zout2=0,... zout7=0; #if defined(REINTERPRET_OUTPUT_AS_3D) - // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension - // in order to take into account the presence of possible cross plane paddings - // - // | | - // | plane0 | - // | | - // |__________________| - // |******************| - // | cross_plane_pad | - // |******************| - // | | - // | plane1 | - // | | - // |__________________| // The plane (zout) is calculated dividing M (y * M0) by HEIGHT_GEMM3D - zout0 = (0 + (uint)(y * (uint)M0)) / (uint)HEIGHT_GEMM3D; - zout0 = min((uint)(DEPTH_GEMM3D - 1), zout0); - zout0 *= (dst_cross_plane_pad * dst_stride_y); -#if M0 > 1 - zout1 = (1 + (uint)(y * (uint)M0)) / (uint)HEIGHT_GEMM3D; - zout1 = min((uint)(DEPTH_GEMM3D - 1), zout1); - zout1 *= (dst_cross_plane_pad * dst_stride_y); -#endif // M0 > 1 -#if M0 > 2 - zout2 = (2 + (uint)(y * (uint)M0)) / (uint)HEIGHT_GEMM3D; - zout2 = min((uint)(DEPTH_GEMM3D - 1), zout2); - zout2 *= (dst_cross_plane_pad * dst_stride_y); -#endif // M0 > 2 -#if M0 > 3 - zout3 = (3 + (uint)(y * (uint)M0)) / (uint)HEIGHT_GEMM3D; - zout3 = min((uint)(DEPTH_GEMM3D - 1), zout3); - zout3 *= (dst_cross_plane_pad * dst_stride_y); -#endif // M0 > 3 -#if M0 > 4 - zout4 = (4 + (uint)(y * (uint)M0)) / (uint)HEIGHT_GEMM3D; - zout4 = min((uint)(DEPTH_GEMM3D - 1), zout4); - zout4 *= (dst_cross_plane_pad * dst_stride_y); -#endif // M0 > 4 -#if M0 > 5 - zout5 = (5 + (uint)(y * (uint)M0)) / (uint)HEIGHT_GEMM3D; - zout5 = min((uint)(DEPTH_GEMM3D - 1), zout5); - zout5 *= (dst_cross_plane_pad * dst_stride_y); -#endif // M0 > 5 -#if M0 > 6 - zout6 = (6 + (uint)(y * (uint)M0)) / (uint)HEIGHT_GEMM3D; - zout6 = min((uint)(DEPTH_GEMM3D - 1), zout6); - zout6 *= (dst_cross_plane_pad * dst_stride_y); -#endif // M0 > 6 -#if M0 > 7 - zout7 = (7 + (uint)(y * (uint)M0)) / (uint)HEIGHT_GEMM3D; - zout7 = min((uint)(DEPTH_GEMM3D - 1), zout7); - zout7 *= (dst_cross_plane_pad * dst_stride_y); -#endif // M0 > 7 + CALCULATE_Z_OFFSET(M0, uint, zout, y, HEIGHT_GEMM3D, DEPTH_GEMM3D, dst_cross_plane_pad, dst_stride_y); // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we // multiply dst_stride_z by DEPTH_GEMM3D @@ -1694,61 +1233,11 @@ __kernel void gemm_mm_reshaped_only_rhs_t(IMAGE_DECLARATION(lhs), // Multiply by the weight of matrix-matrix product and store the result #if defined(ALPHA) - c0 = c0 * (DATA_TYPE)ALPHA; -#if M0 > 1 - c1 = c1 * (DATA_TYPE)ALPHA; -#endif // M0 > 1 -#if M0 > 2 - c2 = c2 * (DATA_TYPE)ALPHA; -#endif // M0 > 2 -#if M0 > 3 - c3 = c3 * (DATA_TYPE)ALPHA; -#endif // M0 > 3 -#if M0 > 4 - c4 = c4 * (DATA_TYPE)ALPHA; -#endif // M0 > 4 -#if M0 > 5 - c5 = c5 * (DATA_TYPE)ALPHA; -#endif // M0 > 5 -#if M0 > 6 - c6 = c6 * (DATA_TYPE)ALPHA; -#endif // M0 > 5 -#if M0 > 7 - c7 = c7 * (DATA_TYPE)ALPHA; -#endif // M0 > 7 + SCALE_BLOCK(M0, DATA_TYPE, c, ALPHA); #endif // defined(ALPHA) // Store output block - VSTORE(N0) - (c0, 0, (__global DATA_TYPE *)(dst_addr + 0 * dst_stride_y + zout0)); -#if M0 > 1 - VSTORE(N0) - (c1, 0, (__global DATA_TYPE *)(dst_addr + 1 * dst_stride_y + zout1)); -#endif // M0 > 1 -#if M0 > 2 - VSTORE(N0) - (c2, 0, (__global DATA_TYPE *)(dst_addr + 2 * dst_stride_y + zout2)); -#endif // M0 > 2 -#if M0 > 3 - VSTORE(N0) - (c3, 0, (__global DATA_TYPE *)(dst_addr + 3 * dst_stride_y + zout3)); -#endif // M0 > 3 -#if M0 > 4 - VSTORE(N0) - (c4, 0, (__global DATA_TYPE *)(dst_addr + 4 * dst_stride_y + zout4)); -#endif // M0 > 4 -#if M0 > 5 - VSTORE(N0) - (c5, 0, (__global DATA_TYPE *)(dst_addr + 5 * dst_stride_y + zout5)); -#endif // M0 > 5 -#if M0 > 6 - VSTORE(N0) - (c6, 0, (__global DATA_TYPE *)(dst_addr + 6 * dst_stride_y + zout6)); -#endif // M0 > 6 -#if M0 > 7 - VSTORE(N0) - (c7, 0, (__global DATA_TYPE *)(dst_addr + 7 * dst_stride_y + zout7)); -#endif // M0 > 7 + STORE_BLOCK(M0, N0, DATA_TYPE, c, dst_addr, dst_stride_y, zout); #undef RHS_BLOCK_SIZE #undef RHS_OFFSET_X @@ -1952,60 +1441,9 @@ __kernel void gemm_mm_reshaped_only_rhs_nt(IMAGE_DECLARATION(lhs), REPEAT_VAR_INIT_TO_CONST(8, uint, zin, 0); //uint zout0=0,zout1=0,zout2=0,... zout7=0; #if defined(REINTERPRET_INPUT_AS_3D) - // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension - // in order to take into account the presence of possible cross plane paddings - // - // | | - // | plane0 | - // | | - // |__________________| - // |******************| - // | cross_plane_pad | - // |******************| - // | | - // | plane1 | - // | | - // |__________________| // The plane (zin) is calculated dividing M (y * M0) by HEIGHT_GEMM3D - zin0 = (0 + (uint)(y * (uint)M0)) / (uint)HEIGHT_GEMM3D; - zin0 = min((uint)(DEPTH_GEMM3D - 1), zin0); - zin0 *= (lhs_cross_plane_pad * lhs_stride_y); -#if M0 > 1 - zin1 = (1 + (uint)(y * (uint)M0)) / (uint)HEIGHT_GEMM3D; - zin1 = min((uint)(DEPTH_GEMM3D - 1), zin1); - zin1 *= (lhs_cross_plane_pad * lhs_stride_y); -#endif // M0 > 1 -#if M0 > 2 - zin2 = (2 + (uint)(y * (uint)M0)) / (uint)HEIGHT_GEMM3D; - zin2 = min((uint)(DEPTH_GEMM3D - 1), zin2); - zin2 *= (lhs_cross_plane_pad * lhs_stride_y); -#endif // M0 > 2 -#if M0 > 3 - zin3 = (3 + (uint)(y * (uint)M0)) / (uint)HEIGHT_GEMM3D; - zin3 = min((uint)(DEPTH_GEMM3D - 1), zin3); - zin3 *= (lhs_cross_plane_pad * lhs_stride_y); -#endif // M0 > 3 -#if M0 > 4 - zin4 = (4 + (uint)(y * (uint)M0)) / (uint)HEIGHT_GEMM3D; - zin4 = min((uint)(DEPTH_GEMM3D - 1), zin4); - zin4 *= (lhs_cross_plane_pad * lhs_stride_y); -#endif // M0 > 4 -#if M0 > 5 - zin5 = (5 + (uint)(y * (uint)M0)) / (uint)HEIGHT_GEMM3D; - zin5 = min((uint)(DEPTH_GEMM3D - 1), zin5); - zin5 *= (lhs_cross_plane_pad * lhs_stride_y); -#endif // M0 > 5 -#if M0 > 6 - zin6 = (6 + (uint)(y * (uint)M0)) / (uint)HEIGHT_GEMM3D; - zin6 = min((uint)(DEPTH_GEMM3D - 1), zin6); - zin6 *= (lhs_cross_plane_pad * lhs_stride_y); -#endif // M0 > 6 -#if M0 > 7 - zin7 = (7 + (uint)(y * (uint)M0)) / (uint)HEIGHT_GEMM3D; - zin7 = min((uint)(DEPTH_GEMM3D - 1), zout7); - zin7 *= (lhs_cross_plane_pad * lhs_stride_y); -#endif // M0 > 7 + CALCULATE_Z_OFFSET(M0, uint, zin, y, HEIGHT_GEMM3D, DEPTH_GEMM3D, lhs_cross_plane_pad, lhs_stride_y); // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we // multiply lhs_stride_z by DEPTH_GEMM3D @@ -2034,36 +1472,7 @@ __kernel void gemm_mm_reshaped_only_rhs_nt(IMAGE_DECLARATION(lhs), // 7,2 - 7,3 - 7,4 - 7,8 - 7,16 // 8,2 - 8,3 - 8,4 - 8,8 - 8,16 // Load values from LHS matrix - VEC_DATA_TYPE(DATA_TYPE, K0) - a0 = VLOAD(K0)(0, (__global DATA_TYPE *)(lhs_ptr + lhs_offset + 0 * lhs_stride_y + zin0)); -#if M0 > 1 - VEC_DATA_TYPE(DATA_TYPE, K0) - a1 = VLOAD(K0)(0, (__global DATA_TYPE *)(lhs_ptr + lhs_offset + 1 * lhs_stride_y + zin1)); -#endif // M0 > 1 -#if M0 > 2 - VEC_DATA_TYPE(DATA_TYPE, K0) - a2 = VLOAD(K0)(0, (__global DATA_TYPE *)(lhs_ptr + lhs_offset + 2 * lhs_stride_y + zin2)); -#endif // M0 > 2 -#if M0 > 3 - VEC_DATA_TYPE(DATA_TYPE, K0) - a3 = VLOAD(K0)(0, (__global DATA_TYPE *)(lhs_ptr + lhs_offset + 3 * lhs_stride_y + zin3)); -#endif // M0 > 3 -#if M0 > 4 - VEC_DATA_TYPE(DATA_TYPE, K0) - a4 = VLOAD(K0)(0, (__global DATA_TYPE *)(lhs_ptr + lhs_offset + 4 * lhs_stride_y + zin4)); -#endif // M0 > 4 -#if M0 > 5 - VEC_DATA_TYPE(DATA_TYPE, K0) - a5 = VLOAD(K0)(0, (__global DATA_TYPE *)(lhs_ptr + lhs_offset + 5 * lhs_stride_y + zin5)); -#endif // M0 > 5 -#if M0 > 6 - VEC_DATA_TYPE(DATA_TYPE, K0) - a6 = VLOAD(K0)(0, (__global DATA_TYPE *)(lhs_ptr + lhs_offset + 6 * lhs_stride_y + zin6)); -#endif // M0 > 6 -#if M0 > 7 - VEC_DATA_TYPE(DATA_TYPE, K0) - a7 = VLOAD(K0)(0, (__global DATA_TYPE *)(lhs_ptr + lhs_offset + 7 * lhs_stride_y + zin7)); -#endif // M0 > 7 + LOAD_BLOCK(M0, K0, DATA_TYPE, a, lhs_ptr, lhs_offset, lhs_stride_y, zin); LD_RHS_VFMA_M0xN0(0, a, c); LD_RHS_VFMA_M0xN0(1, a, c); @@ -2140,60 +1549,8 @@ __kernel void gemm_mm_reshaped_only_rhs_nt(IMAGE_DECLARATION(lhs), REPEAT_VAR_INIT_TO_CONST(8, uint, zout, 0); //uint zout0=0,zout1=0,zout2=0,... zout7=0; #if defined(REINTERPRET_OUTPUT_AS_3D) - // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension - // in order to take into account the presence of possible cross plane paddings - // - // | | - // | plane0 | - // | | - // |__________________| - // |******************| - // | cross_plane_pad | - // |******************| - // | | - // | plane1 | - // | | - // |__________________| - // The plane (zout) is calculated dividing M (y * M0) by HEIGHT_GEMM3D - zout0 = (0 + (uint)(y * (uint)M0)) / (uint)HEIGHT_GEMM3D; - zout0 = min((uint)(DEPTH_GEMM3D - 1), zout0); - zout0 *= (dst_cross_plane_pad * dst_stride_y); -#if M0 > 1 - zout1 = (1 + (uint)(y * (uint)M0)) / (uint)HEIGHT_GEMM3D; - zout1 = min((uint)(DEPTH_GEMM3D - 1), zout1); - zout1 *= (dst_cross_plane_pad * dst_stride_y); -#endif // M0 > 1 -#if M0 > 2 - zout2 = (2 + (uint)(y * (uint)M0)) / (uint)HEIGHT_GEMM3D; - zout2 = min((uint)(DEPTH_GEMM3D - 1), zout2); - zout2 *= (dst_cross_plane_pad * dst_stride_y); -#endif // M0 > 2 -#if M0 > 3 - zout3 = (3 + (uint)(y * (uint)M0)) / (uint)HEIGHT_GEMM3D; - zout3 = min((uint)(DEPTH_GEMM3D - 1), zout3); - zout3 *= (dst_cross_plane_pad * dst_stride_y); -#endif // M0 > 3 -#if M0 > 4 - zout4 = (4 + (uint)(y * (uint)M0)) / (uint)HEIGHT_GEMM3D; - zout4 = min((uint)(DEPTH_GEMM3D - 1), zout4); - zout4 *= (dst_cross_plane_pad * dst_stride_y); -#endif // M0 > 4 -#if M0 > 5 - zout5 = (5 + (uint)(y * (uint)M0)) / (uint)HEIGHT_GEMM3D; - zout5 = min((uint)(DEPTH_GEMM3D - 1), zout5); - zout5 *= (dst_cross_plane_pad * dst_stride_y); -#endif // M0 > 5 -#if M0 > 6 - zout6 = (6 + (uint)(y * (uint)M0)) / (uint)HEIGHT_GEMM3D; - zout6 = min((uint)(DEPTH_GEMM3D - 1), zout6); - zout6 *= (dst_cross_plane_pad * dst_stride_y); -#endif // M0 > 6 -#if M0 > 7 - zout7 = (7 + (uint)(y * (uint)M0)) / (uint)HEIGHT_GEMM3D; - zout7 = min((uint)(DEPTH_GEMM3D - 1), zout7); - zout7 *= (dst_cross_plane_pad * dst_stride_y); -#endif // M0 > 7 + CALCULATE_Z_OFFSET(M0, uint, zout, y, HEIGHT_GEMM3D, DEPTH_GEMM3D, dst_cross_plane_pad, dst_stride_y); // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we // multiply dst_stride_z by DEPTH_GEMM3D @@ -2208,61 +1565,11 @@ __kernel void gemm_mm_reshaped_only_rhs_nt(IMAGE_DECLARATION(lhs), // Multiply by the weight of matrix-matrix product and store the result #if defined(ALPHA) - c0 = c0 * (DATA_TYPE)ALPHA; -#if M0 > 1 - c1 = c1 * (DATA_TYPE)ALPHA; -#endif // M0 > 1 -#if M0 > 2 - c2 = c2 * (DATA_TYPE)ALPHA; -#endif // M0 > 2 -#if M0 > 3 - c3 = c3 * (DATA_TYPE)ALPHA; -#endif // M0 > 3 -#if M0 > 4 - c4 = c4 * (DATA_TYPE)ALPHA; -#endif // M0 > 4 -#if M0 > 5 - c5 = c5 * (DATA_TYPE)ALPHA; -#endif // M0 > 5 -#if M0 > 6 - c6 = c6 * (DATA_TYPE)ALPHA; -#endif // M0 > 5 -#if M0 > 7 - c7 = c7 * (DATA_TYPE)ALPHA; -#endif // M0 > 7 + SCALE_BLOCK(M0, DATA_TYPE, c, ALPHA); #endif // defined(ALPHA) // Store output block - VSTORE(N0) - (c0, 0, (__global DATA_TYPE *)(dst_addr + 0 * dst_stride_y + zout0)); -#if M0 > 1 - VSTORE(N0) - (c1, 0, (__global DATA_TYPE *)(dst_addr + 1 * dst_stride_y + zout1)); -#endif // M0 > 1 -#if M0 > 2 - VSTORE(N0) - (c2, 0, (__global DATA_TYPE *)(dst_addr + 2 * dst_stride_y + zout2)); -#endif // M0 > 2 -#if M0 > 3 - VSTORE(N0) - (c3, 0, (__global DATA_TYPE *)(dst_addr + 3 * dst_stride_y + zout3)); -#endif // M0 > 3 -#if M0 > 4 - VSTORE(N0) - (c4, 0, (__global DATA_TYPE *)(dst_addr + 4 * dst_stride_y + zout4)); -#endif // M0 > 4 -#if M0 > 5 - VSTORE(N0) - (c5, 0, (__global DATA_TYPE *)(dst_addr + 5 * dst_stride_y + zout5)); -#endif // M0 > 5 -#if M0 > 6 - VSTORE(N0) - (c6, 0, (__global DATA_TYPE *)(dst_addr + 6 * dst_stride_y + zout6)); -#endif // M0 > 6 -#if M0 > 7 - VSTORE(N0) - (c7, 0, (__global DATA_TYPE *)(dst_addr + 7 * dst_stride_y + zout7)); -#endif // M0 > 7 + STORE_BLOCK(M0, N0, DATA_TYPE, c, dst_addr, dst_stride_y, zout); #undef RHS_BLOCK_SIZE #undef RHS_OFFSET_X @@ -2498,6 +1805,9 @@ __kernel void gemm_mm_reshaped_lhs_nt_rhs_t(IMAGE_DECLARATION(lhs), // Initialize the accumulators REPEAT_VAR_INIT_TO_CONST(M0, VEC_DATA_TYPE(DATA_TYPE, N0), c, 0); //VEC_DATA_TYPE(DATA_TYPE, N0) c0=0,c1=0,c2=0,... c(M0-1)=0; + REPEAT_VAR_INIT_TO_CONST(8, uint, zlhs, 0); //uint zlhs0=0,zlhs1=0,zlhs2=0,... zlhs7=0; + REPEAT_VAR_INIT_TO_CONST(16, uint, zrhs, 0); + for(int i = 0; i < k; i += K0) { // Supported cases (M0, K0): @@ -2510,78 +1820,10 @@ __kernel void gemm_mm_reshaped_lhs_nt_rhs_t(IMAGE_DECLARATION(lhs), // 7,2 - 7,3 - 7,4 - 7,8 - 7,16 // 8,2 - 8,3 - 8,4 - 8,8 - 8,16 // Load values from LHS matrix - VEC_DATA_TYPE(DATA_TYPE, K0) - a0 = VLOAD(K0)(0, (__global DATA_TYPE *)(lhs_addr + 0 * LHS_STEP_X * sizeof(DATA_TYPE))); -#if M0 > 1 - VEC_DATA_TYPE(DATA_TYPE, K0) - a1 = VLOAD(K0)(0, (__global DATA_TYPE *)(lhs_addr + 1 * LHS_STEP_X * sizeof(DATA_TYPE))); -#endif // M0 > 1 -#if M0 > 2 - VEC_DATA_TYPE(DATA_TYPE, K0) - a2 = VLOAD(K0)(0, (__global DATA_TYPE *)(lhs_addr + 2 * LHS_STEP_X * sizeof(DATA_TYPE))); -#endif // M0 > 2 -#if M0 > 3 - VEC_DATA_TYPE(DATA_TYPE, K0) - a3 = VLOAD(K0)(0, (__global DATA_TYPE *)(lhs_addr + 3 * LHS_STEP_X * sizeof(DATA_TYPE))); -#endif // M0 > 3 -#if M0 > 4 - VEC_DATA_TYPE(DATA_TYPE, K0) - a4 = VLOAD(K0)(0, (__global DATA_TYPE *)(lhs_addr + 4 * LHS_STEP_X * sizeof(DATA_TYPE))); -#endif // M0 > 4 -#if M0 > 5 - VEC_DATA_TYPE(DATA_TYPE, K0) - a5 = VLOAD(K0)(0, (__global DATA_TYPE *)(lhs_addr + 5 * LHS_STEP_X * sizeof(DATA_TYPE))); -#endif // M0 > 5 -#if M0 > 6 - VEC_DATA_TYPE(DATA_TYPE, K0) - a6 = VLOAD(K0)(0, (__global DATA_TYPE *)(lhs_addr + 6 * LHS_STEP_X * sizeof(DATA_TYPE))); -#endif // M0 > 6 -#if M0 > 7 - VEC_DATA_TYPE(DATA_TYPE, K0) - a7 = VLOAD(K0)(0, (__global DATA_TYPE *)(lhs_addr + 7 * LHS_STEP_X * sizeof(DATA_TYPE))); -#endif // M0 > 7 + LOAD_BLOCK(M0, K0, DATA_TYPE, a, lhs_addr, 0, LHS_STEP_X * sizeof(DATA_TYPE), zlhs); // Load values from RHS matrix - VEC_DATA_TYPE(DATA_TYPE, K0) - b0 = VLOAD(K0)(0, (__global DATA_TYPE *)(rhs_addr + 0 * RHS_STEP_X * sizeof(DATA_TYPE))); - VEC_DATA_TYPE(DATA_TYPE, K0) - b1 = VLOAD(K0)(0, (__global DATA_TYPE *)(rhs_addr + 1 * RHS_STEP_X * sizeof(DATA_TYPE))); -#if N0 > 2 - VEC_DATA_TYPE(DATA_TYPE, K0) - b2 = VLOAD(K0)(0, (__global DATA_TYPE *)(rhs_addr + 2 * RHS_STEP_X * sizeof(DATA_TYPE))); -#endif // N0 > 2 -#if N0 > 3 - VEC_DATA_TYPE(DATA_TYPE, K0) - b3 = VLOAD(K0)(0, (__global DATA_TYPE *)(rhs_addr + 3 * RHS_STEP_X * sizeof(DATA_TYPE))); -#endif // N0 > 3 -#if N0 > 4 - VEC_DATA_TYPE(DATA_TYPE, K0) - b4 = VLOAD(K0)(0, (__global DATA_TYPE *)(rhs_addr + 4 * RHS_STEP_X * sizeof(DATA_TYPE))); - VEC_DATA_TYPE(DATA_TYPE, K0) - b5 = VLOAD(K0)(0, (__global DATA_TYPE *)(rhs_addr + 5 * RHS_STEP_X * sizeof(DATA_TYPE))); - VEC_DATA_TYPE(DATA_TYPE, K0) - b6 = VLOAD(K0)(0, (__global DATA_TYPE *)(rhs_addr + 6 * RHS_STEP_X * sizeof(DATA_TYPE))); - VEC_DATA_TYPE(DATA_TYPE, K0) - b7 = VLOAD(K0)(0, (__global DATA_TYPE *)(rhs_addr + 7 * RHS_STEP_X * sizeof(DATA_TYPE))); -#endif // N0 > 4 -#if N0 > 8 - VEC_DATA_TYPE(DATA_TYPE, K0) - b8 = VLOAD(K0)(0, (__global DATA_TYPE *)(rhs_addr + 8 * RHS_STEP_X * sizeof(DATA_TYPE))); - VEC_DATA_TYPE(DATA_TYPE, K0) - b9 = VLOAD(K0)(0, (__global DATA_TYPE *)(rhs_addr + 9 * RHS_STEP_X * sizeof(DATA_TYPE))); - VEC_DATA_TYPE(DATA_TYPE, K0) - bA = VLOAD(K0)(0, (__global DATA_TYPE *)(rhs_addr + 10 * RHS_STEP_X * sizeof(DATA_TYPE))); - VEC_DATA_TYPE(DATA_TYPE, K0) - bB = VLOAD(K0)(0, (__global DATA_TYPE *)(rhs_addr + 11 * RHS_STEP_X * sizeof(DATA_TYPE))); - VEC_DATA_TYPE(DATA_TYPE, K0) - bC = VLOAD(K0)(0, (__global DATA_TYPE *)(rhs_addr + 12 * RHS_STEP_X * sizeof(DATA_TYPE))); - VEC_DATA_TYPE(DATA_TYPE, K0) - bD = VLOAD(K0)(0, (__global DATA_TYPE *)(rhs_addr + 13 * RHS_STEP_X * sizeof(DATA_TYPE))); - VEC_DATA_TYPE(DATA_TYPE, K0) - bE = VLOAD(K0)(0, (__global DATA_TYPE *)(rhs_addr + 14 * RHS_STEP_X * sizeof(DATA_TYPE))); - VEC_DATA_TYPE(DATA_TYPE, K0) - bF = VLOAD(K0)(0, (__global DATA_TYPE *)(rhs_addr + 15 * RHS_STEP_X * sizeof(DATA_TYPE))); -#endif // N0 > 8 + LOAD_BLOCK(N0, K0, DATA_TYPE, b, rhs_addr, 0, RHS_STEP_X * sizeof(DATA_TYPE), zrhs); // Accumulate ARM_DOT_K0XN0(a0, b, c0); @@ -2616,61 +1858,9 @@ __kernel void gemm_mm_reshaped_lhs_nt_rhs_t(IMAGE_DECLARATION(lhs), REPEAT_VAR_INIT_TO_CONST(8, uint, zout, 0); //uint zout0=0,zout1=0,zout2=0,... zout7=0; #if defined(REINTERPRET_OUTPUT_AS_3D) - // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension - // in order to take into account the presence of possible cross plane paddings - // - // | | - // | plane0 | - // | | - // |__________________| - // |******************| - // | cross_plane_pad | - // |******************| - // | | - // | plane1 | - // | | - // |__________________| // The plane (zin) is calculated dividing M (y * M0) by HEIGHT_GEMM3D - zout0 = (0 + (uint)(get_global_id(1) * (uint)M0)) / (uint)HEIGHT_GEMM3D; - zout0 = min((uint)(DEPTH_GEMM3D - 1), zout0); - zout0 *= (dst_cross_plane_pad * dst_stride_y); -#if M0 > 1 - zout1 = (1 + (uint)(get_global_id(1) * (uint)M0)) / (uint)HEIGHT_GEMM3D; - zout1 = min((uint)(DEPTH_GEMM3D - 1), zout1); - zout1 *= (dst_cross_plane_pad * dst_stride_y); -#endif // M0 > 1 -#if M0 > 2 - zout2 = (2 + (uint)(get_global_id(1) * (uint)M0)) / (uint)HEIGHT_GEMM3D; - zout2 = min((uint)(DEPTH_GEMM3D - 1), zout2); - zout2 *= (dst_cross_plane_pad * dst_stride_y); -#endif // M0 > 2 -#if M0 > 3 - zout3 = (3 + (uint)(get_global_id(1) * (uint)M0)) / (uint)HEIGHT_GEMM3D; - zout3 = min((uint)(DEPTH_GEMM3D - 1), zout3); - zout3 *= (dst_cross_plane_pad * dst_stride_y); -#endif // M0 > 3 -#if M0 > 4 - zout4 = (4 + (uint)(get_global_id(1) * (uint)M0)) / (uint)HEIGHT_GEMM3D; - zout4 = min((uint)(DEPTH_GEMM3D - 1), zout4); - zout4 *= (dst_cross_plane_pad * dst_stride_y); -#endif // M0 > 4 -#if M0 > 5 - zout5 = (5 + (uint)(get_global_id(1) * (uint)M0)) / (uint)HEIGHT_GEMM3D; - zout5 = min((uint)(DEPTH_GEMM3D - 1), zout5); - zout5 *= (dst_cross_plane_pad * dst_stride_y); -#endif // M0 > 5 -#if M0 > 6 - zout6 = (6 + (uint)(get_global_id(1) * (uint)M0)) / (uint)HEIGHT_GEMM3D; - zout6 = min((uint)(DEPTH_GEMM3D - 1), zout6); - zout6 *= (dst_cross_plane_pad * dst_stride_y); -#endif // M0 > 6 -#if M0 > 7 - zout7 = (7 + (uint)(get_global_id(1) * (uint)M0)) / (uint)HEIGHT_GEMM3D; - zout7 = min((uint)(DEPTH_GEMM3D - 1), zout7); - zout7 *= (dst_cross_plane_pad * dst_stride_y); -#endif // M0 > 7 - + CALCULATE_Z_OFFSET(M0, uint, zout, get_global_id(1), HEIGHT_GEMM3D, DEPTH_GEMM3D, dst_cross_plane_pad, dst_stride_y); // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we // multiply dst_stride_z by DEPTH_GEMM3D dst_addr += get_global_id(2) * dst_stride_z * DEPTH_GEMM3D; @@ -2684,62 +1874,11 @@ __kernel void gemm_mm_reshaped_lhs_nt_rhs_t(IMAGE_DECLARATION(lhs), // Multiply by the weight of matrix-matrix product and store the result #if defined(ALPHA) - c0 = c0 * (DATA_TYPE)ALPHA; -#if M0 > 1 - c1 = c1 * (DATA_TYPE)ALPHA; -#endif // M0 > 1 -#if M0 > 2 - c2 = c2 * (DATA_TYPE)ALPHA; -#endif // M0 > 2 -#if M0 > 3 - c3 = c3 * (DATA_TYPE)ALPHA; -#endif // M0 > 3 -#if M0 > 4 - c4 = c4 * (DATA_TYPE)ALPHA; -#endif // M0 > 4 -#if M0 > 5 - c5 = c5 * (DATA_TYPE)ALPHA; -#endif // M0 > 5 -#if M0 > 6 - c6 = c6 * (DATA_TYPE)ALPHA; -#endif // M0 > 5 -#if M0 > 7 - c7 = c7 * (DATA_TYPE)ALPHA; -#endif // M0 > 7 + SCALE_BLOCK(M0, DATA_TYPE, c, ALPHA); #endif // defined(ALPHA) // Store output block - VSTORE(N0) - (c0, 0, (__global DATA_TYPE *)(dst_addr + 0 * dst_stride_y + zout0)); -#if M0 > 1 - VSTORE(N0) - (c1, 0, (__global DATA_TYPE *)(dst_addr + 1 * dst_stride_y + zout1)); -#endif // M0 > 1 -#if M0 > 2 - VSTORE(N0) - (c2, 0, (__global DATA_TYPE *)(dst_addr + 2 * dst_stride_y + zout2)); -#endif // M0 > 2 -#if M0 > 3 - VSTORE(N0) - (c3, 0, (__global DATA_TYPE *)(dst_addr + 3 * dst_stride_y + zout3)); -#endif // M0 > 3 -#if M0 > 4 - VSTORE(N0) - (c4, 0, (__global DATA_TYPE *)(dst_addr + 4 * dst_stride_y + zout4)); -#endif // M0 > 4 -#if M0 > 5 - VSTORE(N0) - (c5, 0, (__global DATA_TYPE *)(dst_addr + 5 * dst_stride_y + zout5)); -#endif // M0 > 5 -#if M0 > 6 - VSTORE(N0) - (c6, 0, (__global DATA_TYPE *)(dst_addr + 6 * dst_stride_y + zout6)); -#endif // M0 > 6 -#if M0 > 7 - VSTORE(N0) - (c7, 0, (__global DATA_TYPE *)(dst_addr + 7 * dst_stride_y + zout7)); -#endif // M0 > 7 - + STORE_BLOCK(M0, N0, DATA_TYPE, c, dst_addr, dst_stride_y, zout); #undef LHS_BLOCK_SIZE #undef LHS_OFFSET_X #undef LHS_STEP_X @@ -2892,14 +2031,8 @@ __kernel void gemm_interleave4x4(TENSOR3D_DECLARATION(src), input_ptr += z * src_stride_z * DEPTH_GEMM3D; // Load values from Matrix A - VEC_DATA_TYPE(DATA_TYPE, 4) - a0 = vload4(0, (__global DATA_TYPE *)(input_ptr + 0 * src_stride_y + zin.s0)); - VEC_DATA_TYPE(DATA_TYPE, 4) - a1 = vload4(0, (__global DATA_TYPE *)(input_ptr + 1 * src_stride_y + zin.s1)); - VEC_DATA_TYPE(DATA_TYPE, 4) - a2 = vload4(0, (__global DATA_TYPE *)(input_ptr + 2 * src_stride_y + zin.s2)); - VEC_DATA_TYPE(DATA_TYPE, 4) - a3 = vload4(0, (__global DATA_TYPE *)(input_ptr + 3 * src_stride_y + zin.s3)); + LOAD_BLOCK(4, 4, DATA_TYPE, a, input_ptr, 0, src_stride_y, zin.s); + #else // defined(REINTERPRET_INPUT_AS_3D) __global uchar *input_ptr = src.ptr; @@ -4313,21 +3446,8 @@ __kernel void gemm_mm_floating_point(IMAGE_DECLARATION(src0), { #if defined(REINTERPRET_INPUT_AS_3D) // Load values from matrix A - VEC_DATA_TYPE(DATA_TYPE, 2) - a0 = vload2(0, (__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y + zin.s0)); -#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 - VEC_DATA_TYPE(DATA_TYPE, 2) - a1 = vload2(0, (__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y + zin.s1)); -#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 -#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 - VEC_DATA_TYPE(DATA_TYPE, 2) - a2 = vload2(0, (__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y + zin.s2)); -#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 -#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 - VEC_DATA_TYPE(DATA_TYPE, 2) - a3 = vload2(0, (__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y + zin.s3)); -#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 -#else // defined(REINTERPRET_INPUT_AS_3D) + LOAD_BLOCK(NUM_ELEMS_PROCESSED_PER_THREAD_Y, 2, DATA_TYPE, a, src0_ptr, src_addr.s0, src0_stride_y, zin.s); +#else // defined(REINTERPRET_INPUT_AS_3D) // Load values from matrix A VEC_DATA_TYPE(DATA_TYPE, 2) a0 = vload2(0, (__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y)); @@ -4480,21 +3600,7 @@ __kernel void gemm_mm_floating_point(IMAGE_DECLARATION(src0), dst_addr += z * dst_stride_z * DEPTH_GEMM3D; // Store output block - VSTORE(NUM_ELEMS_PROCESSED_PER_THREAD_X) - (acc0, 0, (__global DATA_TYPE *)(dst_addr + 0 * dst_stride_y + zout.s0)); -#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 - VSTORE(NUM_ELEMS_PROCESSED_PER_THREAD_X) - (acc1, 0, (__global DATA_TYPE *)(dst_addr + 1 * dst_stride_y + zout.s1)); -#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 -#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 - VSTORE(NUM_ELEMS_PROCESSED_PER_THREAD_X) - (acc2, 0, (__global DATA_TYPE *)(dst_addr + 2 * dst_stride_y + zout.s2)); -#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 -#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 - VSTORE(NUM_ELEMS_PROCESSED_PER_THREAD_X) - (acc3, 0, (__global DATA_TYPE *)(dst_addr + 3 * dst_stride_y + zout.s3)); -#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 - + STORE_BLOCK(NUM_ELEMS_PROCESSED_PER_THREAD_Y, NUM_ELEMS_PROCESSED_PER_THREAD_X, DATA_TYPE, acc, dst_addr, dst_stride_y, zout.s); #else // defined(REINTERPRET_OUTPUT_AS_3D) // Add offset for batched GEMM dst_addr += z * dst_stride_z; @@ -4671,17 +3777,8 @@ __kernel void gemm_mm_floating_point_f32_bifrost(IMAGE_DECLARATION(src0), { #if defined(REINTERPRET_INPUT_AS_3D) // Load values from matrix A and matrix B - float4 a0 = vload4(0, (__global float *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y + zin.s0)); -#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 - float4 a1 = vload4(0, (__global float *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y + zin.s1)); -#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 -#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 - float4 a2 = vload4(0, (__global float *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y + zin.s2)); -#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 -#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 - float4 a3 = vload4(0, (__global float *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y + zin.s3)); -#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 -#else // defined(REINTERPRET_INPUT_AS_3D) + LOAD_BLOCK(NUM_ELEMS_PROCESSED_PER_THREAD_Y, 4, float, a, src0_ptr, src_addr.s0, src0_stride_y, zin.s); +#else // defined(REINTERPRET_INPUT_AS_3D) // Load values from matrix A and matrix B float4 a0 = vload4(0, (__global float *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y)); #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 @@ -5566,17 +4663,8 @@ __kernel void gemm_mm_floating_point_f16_bifrost_acc32(IMAGE_DECLARATION(src0), { #if defined(REINTERPRET_INPUT_AS_3D) // Load values from matrix A - half4 a0 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y + zin.s0)); -#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 - half4 a1 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y + zin.s1)); -#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 -#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 - half4 a2 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y + zin.s2)); -#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 -#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 - half4 a3 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y + zin.s3)); -#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 -#else // defined(REINTERPRET_INPUT_AS_3D) + LOAD_BLOCK(NUM_ELEMS_PROCESSED_PER_THREAD_Y, 4, half, a, src0_ptr, src_addr.s0, src0_stride_y, zin.s); +#else // defined(REINTERPRET_INPUT_AS_3D) // Load values from matrix A half4 a0 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y)); #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 @@ -5778,19 +4866,8 @@ __kernel void gemm_mm_floating_point_f16_bifrost_acc32(IMAGE_DECLARATION(src0), // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we // multiply dst_stride_z by DEPTH_GEMM3D dst_addr += z * dst_stride_z * DEPTH_GEMM3D; - // Store the output block - vstore8(hacc0, 0, (__global half *)(dst_addr + 0 * dst_stride_y + zout.s0)); -#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 - vstore8(hacc1, 0, (__global half *)(dst_addr + 1 * dst_stride_y + zout.s1)); -#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 -#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 - vstore8(hacc2, 0, (__global half *)(dst_addr + 2 * dst_stride_y + zout.s2)); -#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 -#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 - vstore8(hacc3, 0, (__global half *)(dst_addr + 3 * dst_stride_y + zout.s3)); -#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 - + STORE_BLOCK(NUM_ELEMS_PROCESSED_PER_THREAD_Y, 8, half, hacc, dst_addr, dst_stride_y, zout.s); #else // defined(REINTERPRET_OUTPUT_AS_3D) // Add offset for batched GEMM dst_addr += z * dst_stride_z; @@ -5945,17 +5022,8 @@ __kernel void gemm_mm_floating_point_f16_bifrost(IMAGE_DECLARATION(src0), { #if defined(REINTERPRET_INPUT_AS_3D) // Load values from matrix A - half4 a0 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y + zin.s0)); -#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 - half4 a1 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y + zin.s1)); -#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 -#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 - half4 a2 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y + zin.s2)); -#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 -#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 - half4 a3 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y + zin.s3)); -#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 -#else // defined(REINTERPRET_INPUT_AS_3D) + LOAD_BLOCK(NUM_ELEMS_PROCESSED_PER_THREAD_Y, 4, half, a, src0_ptr, src_addr.s0, src0_stride_y, zin.s); +#else // defined(REINTERPRET_INPUT_AS_3D) // Load values from matrix A half4 a0 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y)); #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 @@ -6143,17 +5211,7 @@ __kernel void gemm_mm_floating_point_f16_bifrost(IMAGE_DECLARATION(src0), dst_addr += z * dst_stride_z * DEPTH_GEMM3D; // Store the output block - vstore8(acc0, 0, (__global half *)(dst_addr + 0 * dst_stride_y + zout.s0)); -#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 - vstore8(acc1, 0, (__global half *)(dst_addr + 1 * dst_stride_y + zout.s1)); -#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 -#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 - vstore8(acc2, 0, (__global half *)(dst_addr + 2 * dst_stride_y + zout.s2)); -#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 -#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 - vstore8(acc3, 0, (__global half *)(dst_addr + 3 * dst_stride_y + zout.s3)); -#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 - + STORE_BLOCK(NUM_ELEMS_PROCESSED_PER_THREAD_Y, 8, half, acc, dst_addr, dst_stride_y, zout.s); #else // defined(REINTERPRET_OUTPUT_AS_3D) // Add offset for batched GEMM dst_addr += z * dst_stride_z; |