From 43a129e94df41f9ac8bc78b702da5a387ada0494 Mon Sep 17 00:00:00 2001 From: Gian Marco Iodice Date: Tue, 14 May 2019 10:14:08 +0100 Subject: COMPMID-2379: Use the macros available in gemm_helpers.h in GEMMLowp OpenCL kernels Change-Id: I09923a068bff36d42a3f2c1084ffa8bf218187b9 Signed-off-by: Gian Marco Iodice Reviewed-on: https://review.mlplatform.org/c/1260 Tested-by: Arm Jenkins Reviewed-by: Georgios Pinitas Comments-Addressed: Arm Jenkins --- .../CLGEMMLowpMatrixMultiplyReshapedKernel.h | 2 +- ...CLGEMMLowpMatrixMultiplyReshapedOnlyRHSKernel.h | 2 +- src/core/CL/CLKernelLibrary.cpp | 1 - src/core/CL/cl_kernels/gemm_helpers.h | 183 +++- src/core/CL/cl_kernels/gemmlowp.cl | 1033 ++++---------------- src/core/CL/cl_kernels/helpers.h | 2 + .../CLGEMMLowpMatrixMultiplyReshapedKernel.cpp | 3 +- ...GEMMLowpMatrixMultiplyReshapedOnlyRHSKernel.cpp | 2 + 8 files changed, 369 insertions(+), 859 deletions(-) diff --git a/arm_compute/core/CL/kernels/CLGEMMLowpMatrixMultiplyReshapedKernel.h b/arm_compute/core/CL/kernels/CLGEMMLowpMatrixMultiplyReshapedKernel.h index eaadaeff19..f0c8d5cdae 100644 --- a/arm_compute/core/CL/kernels/CLGEMMLowpMatrixMultiplyReshapedKernel.h +++ b/arm_compute/core/CL/kernels/CLGEMMLowpMatrixMultiplyReshapedKernel.h @@ -58,7 +58,7 @@ public: * lhs_info.transpose: false * @param[in] rhs_info RHS matrix information used for reshaping the input1 tensor. 
Only the following values are supported: * rhs_info.n0: 2,3,4,8,16 - * rhs_info.k0: 2,3,4,8,16 + * rhs_info.k0: same as lhs_info.k0 * rhs_info.transpose: true * @param[in] gemm_info GEMM information used to retrieve the original dimensions of the input matrices * diff --git a/arm_compute/core/CL/kernels/CLGEMMLowpMatrixMultiplyReshapedOnlyRHSKernel.h b/arm_compute/core/CL/kernels/CLGEMMLowpMatrixMultiplyReshapedOnlyRHSKernel.h index 6f8f8fead5..5328ee44bc 100644 --- a/arm_compute/core/CL/kernels/CLGEMMLowpMatrixMultiplyReshapedOnlyRHSKernel.h +++ b/arm_compute/core/CL/kernels/CLGEMMLowpMatrixMultiplyReshapedOnlyRHSKernel.h @@ -72,7 +72,7 @@ public: * lhs_info.k0: 2,3,4,8,16 * @param[in] rhs_info RHS matrix information used for reshaping the input1 tensor. Only the following values are supported: * rhs_info.n0: 2,3,4,8,16 - * rhs_info.k0: 2,3,4,8,16 + * rhs_info.k0: same as lhs_info.k0 * rhs_info.transpose: true * @param[in] gemm_info GEMM information used to retrieve the original dimensions of the input matrices * diff --git a/src/core/CL/CLKernelLibrary.cpp b/src/core/CL/CLKernelLibrary.cpp index 9952ed2fff..e426db28c9 100644 --- a/src/core/CL/CLKernelLibrary.cpp +++ b/src/core/CL/CLKernelLibrary.cpp @@ -334,7 +334,6 @@ const std::map CLKernelLibrary::_kernel_program_map = { "gemmlowp_mm_midgard", "gemmlowp.cl" }, { "gemmlowp_mm_interleaved_transposed_midgard", "gemmlowp.cl" }, { "gemmlowp_mm_reshaped_lhs_nt_rhs_t", "gemmlowp.cl" }, - { "gemmlowp_mm_reshaped_lhs_nt_rhs_t_dot8", "gemmlowp.cl" }, { "gemmlowp_mm_reshaped_only_rhs_t", "gemmlowp.cl" }, { "gemmlowp_offset_contribution", "gemmlowp.cl" }, { "gemmlowp_offset_contribution_quantize_down", "gemmlowp.cl" }, diff --git a/src/core/CL/cl_kernels/gemm_helpers.h b/src/core/CL/cl_kernels/gemm_helpers.h index c9e548afb8..2c76992b31 100644 --- a/src/core/CL/cl_kernels/gemm_helpers.h +++ b/src/core/CL/cl_kernels/gemm_helpers.h @@ -112,50 +112,50 @@ #define LOAD_BLOCK(M0, N0, DATA_TYPE, BASENAME, PTR, OFFSET, STRIDE_Y, 
Z) LOAD_BLOCK_STR(M0, N0, DATA_TYPE, BASENAME, PTR, OFFSET, STRIDE_Y, Z) #define CALCULATE_Z_OFFSET_1(M0, DATA_TYPE, Z, Y, HEIGHT_GEMM3D, DEPTH_GEMM3D, CROSS_PLANE_PAD, STRIDE_Y) \ - Z##0 = (0 + (DATA_TYPE)(Y * (DATA_TYPE)M0)) / (DATA_TYPE)HEIGHT_GEMM3D; \ - Z##0 = min((DATA_TYPE)(DEPTH_GEMM3D - 1), Z##0); \ + Z##0 = (0 + (DATA_TYPE)(Y * (DATA_TYPE)M0)) / (DATA_TYPE)HEIGHT_GEMM3D; \ + Z##0 = min((DATA_TYPE)(DEPTH_GEMM3D - 1), Z##0); \ Z##0 *= (CROSS_PLANE_PAD * STRIDE_Y); #define CALCULATE_Z_OFFSET_2(M0, DATA_TYPE, Z, Y, HEIGHT_GEMM3D, DEPTH_GEMM3D, CROSS_PLANE_PAD, STRIDE_Y) \ CALCULATE_Z_OFFSET_1(M0, DATA_TYPE, Z, Y, HEIGHT_GEMM3D, DEPTH_GEMM3D, CROSS_PLANE_PAD, STRIDE_Y) \ - Z##1 = (1 + (DATA_TYPE)(Y * (DATA_TYPE)M0)) / (DATA_TYPE)HEIGHT_GEMM3D; \ - Z##1 = min((DATA_TYPE)(DEPTH_GEMM3D - 1), Z##1); \ + Z##1 = (1 + (DATA_TYPE)(Y * (DATA_TYPE)M0)) / (DATA_TYPE)HEIGHT_GEMM3D; \ + Z##1 = min((DATA_TYPE)(DEPTH_GEMM3D - 1), Z##1); \ Z##1 *= (CROSS_PLANE_PAD * STRIDE_Y); #define CALCULATE_Z_OFFSET_3(M0, DATA_TYPE, Z, Y, HEIGHT_GEMM3D, DEPTH_GEMM3D, CROSS_PLANE_PAD, STRIDE_Y) \ CALCULATE_Z_OFFSET_2(M0, DATA_TYPE, Z, Y, HEIGHT_GEMM3D, DEPTH_GEMM3D, CROSS_PLANE_PAD, STRIDE_Y) \ - Z##2 = (2 + (DATA_TYPE)(Y * (DATA_TYPE)M0)) / (DATA_TYPE)HEIGHT_GEMM3D; \ - Z##2 = min((DATA_TYPE)(DEPTH_GEMM3D - 1), Z##2); \ + Z##2 = (2 + (DATA_TYPE)(Y * (DATA_TYPE)M0)) / (DATA_TYPE)HEIGHT_GEMM3D; \ + Z##2 = min((DATA_TYPE)(DEPTH_GEMM3D - 1), Z##2); \ Z##2 *= (CROSS_PLANE_PAD * STRIDE_Y); #define CALCULATE_Z_OFFSET_4(M0, DATA_TYPE, Z, Y, HEIGHT_GEMM3D, DEPTH_GEMM3D, CROSS_PLANE_PAD, STRIDE_Y) \ CALCULATE_Z_OFFSET_3(M0, DATA_TYPE, Z, Y, HEIGHT_GEMM3D, DEPTH_GEMM3D, CROSS_PLANE_PAD, STRIDE_Y) \ - Z##3 = (3 + (DATA_TYPE)(Y * (DATA_TYPE)M0)) / (DATA_TYPE)HEIGHT_GEMM3D; \ - Z##3 = min((DATA_TYPE)(DEPTH_GEMM3D - 1), Z##3); \ + Z##3 = (3 + (DATA_TYPE)(Y * (DATA_TYPE)M0)) / (DATA_TYPE)HEIGHT_GEMM3D; \ + Z##3 = min((DATA_TYPE)(DEPTH_GEMM3D - 1), Z##3); \ Z##3 *= (CROSS_PLANE_PAD * STRIDE_Y); #define 
CALCULATE_Z_OFFSET_5(M0, DATA_TYPE, Z, Y, HEIGHT_GEMM3D, DEPTH_GEMM3D, CROSS_PLANE_PAD, STRIDE_Y) \ CALCULATE_Z_OFFSET_4(M0, DATA_TYPE, Z, Y, HEIGHT_GEMM3D, DEPTH_GEMM3D, CROSS_PLANE_PAD, STRIDE_Y) \ - Z##4 = (4 + (DATA_TYPE)(Y * (DATA_TYPE)M0)) / (DATA_TYPE)HEIGHT_GEMM3D; \ - Z##4 = min((DATA_TYPE)(DEPTH_GEMM3D - 1), Z##4); \ + Z##4 = (4 + (DATA_TYPE)(Y * (DATA_TYPE)M0)) / (DATA_TYPE)HEIGHT_GEMM3D; \ + Z##4 = min((DATA_TYPE)(DEPTH_GEMM3D - 1), Z##4); \ Z##4 *= (CROSS_PLANE_PAD * STRIDE_Y); #define CALCULATE_Z_OFFSET_6(M0, DATA_TYPE, Z, Y, HEIGHT_GEMM3D, DEPTH_GEMM3D, CROSS_PLANE_PAD, STRIDE_Y) \ CALCULATE_Z_OFFSET_5(M0, DATA_TYPE, Z, Y, HEIGHT_GEMM3D, DEPTH_GEMM3D, CROSS_PLANE_PAD, STRIDE_Y) \ - Z##5 = (5 + (DATA_TYPE)(Y * (DATA_TYPE)M0)) / (DATA_TYPE)HEIGHT_GEMM3D; \ - Z##5 = min((DATA_TYPE)(DEPTH_GEMM3D - 1), Z##5); \ + Z##5 = (5 + (DATA_TYPE)(Y * (DATA_TYPE)M0)) / (DATA_TYPE)HEIGHT_GEMM3D; \ + Z##5 = min((DATA_TYPE)(DEPTH_GEMM3D - 1), Z##5); \ Z##5 *= (CROSS_PLANE_PAD * STRIDE_Y); #define CALCULATE_Z_OFFSET_7(M0, DATA_TYPE, Z, Y, HEIGHT_GEMM3D, DEPTH_GEMM3D, CROSS_PLANE_PAD, STRIDE_Y) \ CALCULATE_Z_OFFSET_6(M0, DATA_TYPE, Z, Y, HEIGHT_GEMM3D, DEPTH_GEMM3D, CROSS_PLANE_PAD, STRIDE_Y) \ - Z##6 = (6 + (DATA_TYPE)(Y * (DATA_TYPE)M0)) / (DATA_TYPE)HEIGHT_GEMM3D; \ - Z##6 = min((DATA_TYPE)(DEPTH_GEMM3D - 1), Z##6); \ + Z##6 = (6 + (DATA_TYPE)(Y * (DATA_TYPE)M0)) / (DATA_TYPE)HEIGHT_GEMM3D; \ + Z##6 = min((DATA_TYPE)(DEPTH_GEMM3D - 1), Z##6); \ Z##6 *= (CROSS_PLANE_PAD * STRIDE_Y); #define CALCULATE_Z_OFFSET_8(M0, DATA_TYPE, Z, Y, HEIGHT_GEMM3D, DEPTH_GEMM3D, CROSS_PLANE_PAD, STRIDE_Y) \ CALCULATE_Z_OFFSET_7(M0, DATA_TYPE, Z, Y, HEIGHT_GEMM3D, DEPTH_GEMM3D, CROSS_PLANE_PAD, STRIDE_Y) \ - Z##7 = (7 + (DATA_TYPE)(Y * (DATA_TYPE)M0)) / (DATA_TYPE)HEIGHT_GEMM3D; \ - Z##7 = min((DATA_TYPE)(DEPTH_GEMM3D - 1), Z##7); \ + Z##7 = (7 + (DATA_TYPE)(Y * (DATA_TYPE)M0)) / (DATA_TYPE)HEIGHT_GEMM3D; \ + Z##7 = min((DATA_TYPE)(DEPTH_GEMM3D - 1), Z##7); \ Z##7 *= (CROSS_PLANE_PAD * 
STRIDE_Y); // CALCULATE_Z_OFFSET_n calculates Z for Z##0 to Z##(n-1) @@ -179,6 +179,7 @@ */ #define CALCULATE_Z_OFFSET(M0, DATA_TYPE, Z, Y, HEIGHT_GEMM3D, DEPTH_GEMM3D, CROSS_PLANE_PAD, STRIDE_Y) CALCULATE_Z_OFFSET_STR(M0, DATA_TYPE, Z, Y, HEIGHT_GEMM3D, DEPTH_GEMM3D, CROSS_PLANE_PAD, STRIDE_Y) +// STORE_ROW_n macros #define STORE_ROW_1(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \ VSTORE(N0) \ (BASENAME##0, 0, (__global DATA_TYPE *)(PTR + 0 * STRIDE_Y + Z##0)); @@ -258,15 +259,106 @@ VSTORE(N0) \ (BASENAME##F, 0, (__global DATA_TYPE *)(PTR + 15 * STRIDE_Y + Z##F)); +// CONVERT_STORE_ROW_n macros +#define CONVERT_STORE_ROW_1(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \ + VSTORE(N0) \ + (CONVERT_SAT((BASENAME##0), VEC_DATA_TYPE(DATA_TYPE, N0)), 0, (__global DATA_TYPE *)(PTR + 0 * STRIDE_Y + Z##0)); + +#define CONVERT_STORE_ROW_2(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \ + CONVERT_STORE_ROW_1(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \ + VSTORE(N0) \ + (CONVERT_SAT((BASENAME##1), VEC_DATA_TYPE(DATA_TYPE, N0)), 0, (__global DATA_TYPE *)(PTR + 1 * STRIDE_Y + Z##1)); + +#define CONVERT_STORE_ROW_3(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \ + CONVERT_STORE_ROW_2(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \ + VSTORE(N0) \ + (CONVERT_SAT((BASENAME##2), VEC_DATA_TYPE(DATA_TYPE, N0)), 0, (__global DATA_TYPE *)(PTR + 2 * STRIDE_Y + Z##2)); + +#define CONVERT_STORE_ROW_4(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \ + CONVERT_STORE_ROW_3(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \ + VSTORE(N0) \ + (CONVERT_SAT((BASENAME##3), VEC_DATA_TYPE(DATA_TYPE, N0)), 0, (__global DATA_TYPE *)(PTR + 3 * STRIDE_Y + Z##3)); + +#define CONVERT_STORE_ROW_5(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \ + CONVERT_STORE_ROW_4(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \ + VSTORE(N0) \ + (CONVERT_SAT((BASENAME##4), VEC_DATA_TYPE(DATA_TYPE, N0)), 0, (__global DATA_TYPE *)(PTR + 4 * STRIDE_Y + Z##4)); + +#define CONVERT_STORE_ROW_6(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \ + 
CONVERT_STORE_ROW_5(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \ + VSTORE(N0) \ + (CONVERT_SAT((BASENAME##5), VEC_DATA_TYPE(DATA_TYPE, N0)), 0, (__global DATA_TYPE *)(PTR + 5 * STRIDE_Y + Z##5)); + +#define CONVERT_STORE_ROW_7(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \ + CONVERT_STORE_ROW_6(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \ + VSTORE(N0) \ + (CONVERT_SAT((BASENAME##6), VEC_DATA_TYPE(DATA_TYPE, N0)), 0, (__global DATA_TYPE *)(PTR + 6 * STRIDE_Y + Z##6)); + +#define CONVERT_STORE_ROW_8(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \ + CONVERT_STORE_ROW_7(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \ + VSTORE(N0) \ + (CONVERT_SAT((BASENAME##7), VEC_DATA_TYPE(DATA_TYPE, N0)), 0, (__global DATA_TYPE *)(PTR + 7 * STRIDE_Y + Z##7)); + +#define CONVERT_STORE_ROW_9(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \ + CONVERT_STORE_ROW_8(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \ + VSTORE(N0) \ + (CONVERT_SAT((BASENAME##8), VEC_DATA_TYPE(DATA_TYPE, N0)), 0, (__global DATA_TYPE *)(PTR + 8 * STRIDE_Y + Z##8)); + +#define CONVERT_STORE_ROW_10(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \ + CONVERT_STORE_ROW_9(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \ + VSTORE(N0) \ + (CONVERT_SAT((BASENAME##9), VEC_DATA_TYPE(DATA_TYPE, N0)), 0, (__global DATA_TYPE *)(PTR + 9 * STRIDE_Y + Z##9)); + +#define CONVERT_STORE_ROW_11(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \ + CONVERT_STORE_ROW_10(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \ + VSTORE(N0) \ + (CONVERT_SAT((BASENAME##A), VEC_DATA_TYPE(DATA_TYPE, N0)), 0, (__global DATA_TYPE *)(PTR + 10 * STRIDE_Y + Z##A)); + +#define CONVERT_STORE_ROW_12(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \ + CONVERT_STORE_ROW_11(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \ + VSTORE(N0) \ + (CONVERT_SAT((BASENAME##B), VEC_DATA_TYPE(DATA_TYPE, N0)), 0, (__global DATA_TYPE *)(PTR + 11 * STRIDE_Y + Z##B)); + +#define CONVERT_STORE_ROW_13(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \ + CONVERT_STORE_ROW_12(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \ + VSTORE(N0)
\ + (CONVERT_SAT((BASENAME##C), VEC_DATA_TYPE(DATA_TYPE, N0)), 0, (__global DATA_TYPE *)(PTR + 12 * STRIDE_Y + Z##C)); + +#define CONVERT_STORE_ROW_14(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \ + CONVERT_STORE_ROW_13(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \ + VSTORE(N0) \ + (CONVERT_SAT((BASENAME##D), VEC_DATA_TYPE(DATA_TYPE, N0)), 0, (__global DATA_TYPE *)(PTR + 13 * STRIDE_Y + Z##D)); + +#define CONVERT_STORE_ROW_15(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \ + CONVERT_STORE_ROW_14(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \ + VSTORE(N0) \ + (CONVERT_SAT((BASENAME##E), VEC_DATA_TYPE(DATA_TYPE, N0)), 0, (__global DATA_TYPE *)(PTR + 14 * STRIDE_Y + Z##E)); + +#define CONVERT_STORE_ROW_16(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \ + CONVERT_STORE_ROW_15(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \ + VSTORE(N0) \ + (CONVERT_SAT((BASENAME##F), VEC_DATA_TYPE(DATA_TYPE, N0)), 0, (__global DATA_TYPE *)(PTR + 15 * STRIDE_Y + Z##F)); + // STORE_ROW_n stores the rows 0..n-1 from variables BASENAME##0 to BASENAME##(n-1) #define STORE_BLOCK_STR(M0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) STORE_ROW_##M0(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) -/** Store Blocks of M0 consecutive rows and N0 consecutive columns when using Z offset as well -* Supported cases M0=1,2,3..16. N0=2,3,4,8,16, for variables BASENAME[0..M] - * The data to store is expected to have consecutive names for each row, For e.g. For M0=3, and basename=c, the expected data is c0, c1 and c2. - * The Z offset is expected to have consecutive names For e.g. For M0=3, and Z=zin, the expected z offsets are zin0, zin1 and zin2. + +// CONVERT_STORE_ROW_n converts and stores the rows 0..n-1 from variables BASENAME##0 to BASENAME##(n-1) +#define CONVERT_STORE_BLOCK_STR(M0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) CONVERT_STORE_ROW_##M0(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) + +/** Store a block of size M0 (rows) x N0 (columns). + * Supported cases M0=1,2,3..16.
N0=2,3,4,8,16, for variables BASENAME[0..M] + * The data to store is expected to have consecutive names for each row, For e.g. For M0=3, and basename=c, the expected data is c0, c1 and c2. + * The Z offset is expected to have consecutive names For e.g. For M0=3, and Z=zin, the expected z offsets are zin0, zin1 and zin2. */ #define STORE_BLOCK(M0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) STORE_BLOCK_STR(M0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) +/** Convert and store a block of size M0 (rows) x N0 (columns). + * Supported cases M0=1,2,3..16. N0=2,3,4,8,16, for variables BASENAME[0..M] + * The data to store is expected to have consecutive names for each row, For e.g. For M0=3, and basename=c, the expected data is c0, c1 and c2. + * The Z offset is expected to have consecutive names For e.g. For M0=3, and Z=zin, the expected z offsets are zin0, zin1 and zin2. + */ +#define CONVERT_STORE_BLOCK(M0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) CONVERT_STORE_BLOCK_STR(M0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) + #define SCALE_ROW_1(DATA_TYPE, BASENAME, SCALE) \ BASENAME##0 = BASENAME##0 * (DATA_TYPE)SCALE; @@ -336,3 +428,54 @@ * Supported cases N=1,2,3..16, for variables BASENAME[0..N] */ #define SCALE_BLOCK(N, DATA_TYPE, BASENAME, SCALE) SCALE_BLOCK_STR(N, DATA_TYPE, BASENAME, SCALE) + +/** Given a set of vectors of size K0, these macros create a new vector to contain the values at index IDX_COL (with IDX_COL < N0) for all input vectors */ +#define COLUMN_VECTOR1(IDX_COL, BASENAME, B) \ + uchar BASENAME##IDX_COL = (uchar)((B##0).s##IDX_COL); +#define COLUMN_VECTOR2(IDX_COL, BASENAME, B) \ + uchar2 BASENAME##IDX_COL = (uchar2)((B##0).s##IDX_COL, (B##1).s##IDX_COL); +#define COLUMN_VECTOR3(IDX_COL, BASENAME, B) \ + uchar3 BASENAME##IDX_COL = (uchar3)((B##0).s##IDX_COL, (B##1).s##IDX_COL, (B##2).s##IDX_COL); +#define COLUMN_VECTOR4(IDX_COL, BASENAME, B) \ + uchar4 BASENAME##IDX_COL = (uchar4)((B##0).s##IDX_COL, (B##1).s##IDX_COL, (B##2).s##IDX_COL,
(B##3).s##IDX_COL); +#define COLUMN_VECTOR8(IDX_COL, BASENAME, B) \ + uchar8 BASENAME##IDX_COL = (uchar8)((B##0).s##IDX_COL, (B##1).s##IDX_COL, (B##2).s##IDX_COL, (B##3).s##IDX_COL, (B##4).s##IDX_COL, (B##5).s##IDX_COL, (B##6).s##IDX_COL, (B##7).s##IDX_COL); +#define COLUMN_VECTOR16(IDX_COL, BASENAME, B) \ + uchar16 BASENAME##IDX_COL = (uchar16)((B##0).s##IDX_COL, (B##1).s##IDX_COL, (B##2).s##IDX_COL, (B##3).s##IDX_COL, (B##4).s##IDX_COL, (B##5).s##IDX_COL, (B##6).s##IDX_COL, (B##7).s##IDX_COL, (B##8).s##IDX_COL, (B##9).s##IDX_COL, (B##A).s##IDX_COL, (B##B).s##IDX_COL, (B##C).s##IDX_COL, (B##D).s##IDX_COL, (B##E).s##IDX_COL, (B##F).s##IDX_COL); + +/** Given N0 vectors of size K0, these macros create K0 vectors of size N0 which are the result of a transposition */ +#define TRANSPOSE_K0X1(K0, BASENAME, B) \ + COLUMN_VECTOR(K0, 0, BASENAME, B); +#define TRANSPOSE_K0X2(K0, BASENAME, B) \ + TRANSPOSE_K0X1(K0, BASENAME, B); \ + COLUMN_VECTOR(K0, 1, BASENAME, B); +#define TRANSPOSE_K0X3(K0, BASENAME, B) \ + TRANSPOSE_K0X2(K0, BASENAME, B); \ + COLUMN_VECTOR(K0, 2, BASENAME, B); +#define TRANSPOSE_K0X4(K0, BASENAME, B) \ + TRANSPOSE_K0X3(K0, BASENAME, B); \ + COLUMN_VECTOR(K0, 3, BASENAME, B); +#define TRANSPOSE_K0X8(K0, BASENAME, B) \ + TRANSPOSE_K0X4(K0, BASENAME, B); \ + COLUMN_VECTOR(K0, 4, BASENAME, B); \ + COLUMN_VECTOR(K0, 5, BASENAME, B); \ + COLUMN_VECTOR(K0, 6, BASENAME, B); \ + COLUMN_VECTOR(K0, 7, BASENAME, B); +#define TRANSPOSE_K0X16(K0, BASENAME, BS) \ + TRANSPOSE_K0X8(K0, BASENAME, BS); \ + COLUMN_VECTOR(K0, 8, BASENAME, BS); \ + COLUMN_VECTOR(K0, 9, BASENAME, BS); \ + COLUMN_VECTOR(K0, A, BASENAME, BS); \ + COLUMN_VECTOR(K0, B, BASENAME, BS); \ + COLUMN_VECTOR(K0, C, BASENAME, BS); \ + COLUMN_VECTOR(K0, D, BASENAME, BS); \ + COLUMN_VECTOR(K0, E, BASENAME, BS); \ + COLUMN_VECTOR(K0, F, BASENAME, BS); + +#define COLUMN_VECTOR(K0, IDX_COL, BASENAME, B) \ + CONCAT(COLUMN_VECTOR, K0) \ + (IDX_COL, BASENAME, B); + +#define TRANSPOSE_K0XN0(K0, N0, BASENAME, B) \ +
CONCAT(TRANSPOSE_K0X, N0) \ + (K0, BASENAME, B); diff --git a/src/core/CL/cl_kernels/gemmlowp.cl b/src/core/CL/cl_kernels/gemmlowp.cl index b1ba8e0377..0080369705 100644 --- a/src/core/CL/cl_kernels/gemmlowp.cl +++ b/src/core/CL/cl_kernels/gemmlowp.cl @@ -21,7 +21,7 @@ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE * SOFTWARE. */ -#include "helpers.h" +#include "gemm_helpers.h" #include "helpers_asymm.h" #include "repeat.h" @@ -33,6 +33,166 @@ #endif // defined(ARM_COMPUTE_OPENCL_DOT8_ACC_ENABLED) && defined(cl_arm_integer_dot_product_accumulate_int8) #endif // defined(ARM_COMPUTE_OPENCL_DOT8_ENABLED) && defined(cl_arm_integer_dot_product_int8) +#if defined(ARM_COMPUTE_OPENCL_DOT8_ENABLED) && defined(cl_arm_integer_dot_product_int8) + +/** Specialized macros to perform the dot product instruction between two vectors of size N [1,16]. These macros use the dot8 instruction */ +#define ARM_DOT1(a, b, c) \ + ({ \ + ARM_DOT((uchar4)(a, (uchar3)0), (uchar4)(b, (uchar3)0), c); \ + }) +#define ARM_DOT2(a, b, c) \ + ({ \ + ARM_DOT((uchar4)(a, (uchar2)0), (uchar4)(b, (uchar2)0), c); \ + }) +#define ARM_DOT3(a, b, c) \ + ({ \ + ARM_DOT((uchar4)(a, (uchar)0), (uchar4)(b, (uchar)0), c); \ + }) +#define ARM_DOT4(a, b, c) \ + ({ \ + ARM_DOT(a, b, c); \ + }) +#define ARM_DOT8(a, b, c) \ + ({ \ + ARM_DOT4((a.lo), (b.lo), c); \ + ARM_DOT4((a.hi), (b.hi), c); \ + }) +#define ARM_DOT16(a, b, c) \ + ({ \ + ARM_DOT8((a.lo), (b.lo), c); \ + ARM_DOT8((a.hi), (b.hi), c); \ + }) + +#else // defined(ARM_COMPUTE_OPENCL_DOT8_ENABLED) && defined(cl_arm_integer_dot_product_int8) + +/** Specialized macros to perform the dot product instruction between two vectors of size K0 [1,16] without using the dot8 instruction. 
*/ +#define ARM_DOT1(a, b, c) \ + ({ \ + c += (uint)a.s0 * b.s0; \ + }) +#define ARM_DOT2(a, b, c) \ + ({ \ + ARM_DOT1(a, b, c); \ + c += (uint)a.s1 * b.s1; \ + }) +#define ARM_DOT3(a, b, c) \ + ({ \ + ARM_DOT2(a, b, c); \ + c += (uint)a.s2 * b.s2; \ + }) +#define ARM_DOT4(a, b, c) \ + ({ \ + ARM_DOT3(a, b, c); \ + c += (uint)a.s3 * b.s3; \ + }) +#define ARM_DOT8(a, b, c) \ + ({ \ + ARM_DOT4((a.lo), (b.lo), c); \ + ARM_DOT4((a.hi), (b.hi), c); \ + }) +#define ARM_DOT16(a, b, c) \ + ({ \ + ARM_DOT8((a.lo), (b.lo), c); \ + ARM_DOT8((a.hi), (b.hi), c); \ + }) +#endif // defined(ARM_COMPUTE_OPENCL_DOT8_ENABLED) && defined(cl_arm_integer_dot_product_int8) + +/** Specialized macros to perform a broadcast dot product operation between one vector "a" and N0 vectors "b" of size K0 [1,16] */ +#define ARM_DOT_K0X2(k0, a, b, c) \ + ({ \ + ARM_DOT_K0(k0, (a), (b##0), (c.s0)); \ + ARM_DOT_K0(k0, (a), (b##1), (c.s1)); \ + }) +#define ARM_DOT_K0X3(k0, a, b, c) \ + ({ \ + ARM_DOT_K0X2(k0, a, b, c); \ + ARM_DOT_K0(k0, (a), (b##2), (c.s2)); \ + }) +#define ARM_DOT_K0X4(k0, a, b, c) \ + ({ \ + ARM_DOT_K0X3(k0, a, b, c); \ + ARM_DOT_K0(k0, (a), (b##3), (c.s3)); \ + }) +#define ARM_DOT_K0X8(k0, a, b, c) \ + ({ \ + ARM_DOT_K0X4(k0, a, b, c); \ + ARM_DOT_K0(k0, (a), (b##4), (c.s4)); \ + ARM_DOT_K0(k0, (a), (b##5), (c.s5)); \ + ARM_DOT_K0(k0, (a), (b##6), (c.s6)); \ + ARM_DOT_K0(k0, (a), (b##7), (c.s7)); \ + }) +#define ARM_DOT_K0X16(k0, a, b, c) \ + ({ \ + ARM_DOT_K0X8(k0, a, b, c); \ + ARM_DOT_K0(k0, (a), (b##8), (c.s8)); \ + ARM_DOT_K0(k0, (a), (b##9), (c.s9)); \ + ARM_DOT_K0(k0, (a), (b##A), (c.sA)); \ + ARM_DOT_K0(k0, (a), (b##B), (c.sB)); \ + ARM_DOT_K0(k0, (a), (b##C), (c.sC)); \ + ARM_DOT_K0(k0, (a), (b##D), (c.sD)); \ + ARM_DOT_K0(k0, (a), (b##E), (c.sE)); \ + ARM_DOT_K0(k0, (a), (b##F), (c.sF)); \ + }) + +/** Specialized macros to perform a partial matrix multiplication with dimensions M0,N0,K0*/ +#define ARM_MM_K0XN0X1(n0, k0, a, b, c) \ + ({ \ + ARM_DOT_K0XN0(n0, k0, (a##0),
b, (c##0)); \ + }) +#define ARM_MM_K0XN0X2(n0, k0, a, b, c) \ + ({ \ + ARM_MM_K0XN0X1(n0, k0, a, b, c); \ + ARM_DOT_K0XN0(n0, k0, (a##1), b, (c##1)); \ + }) +#define ARM_MM_K0XN0X3(n0, k0, a, b, c) \ + ({ \ + ARM_MM_K0XN0X2(n0, k0, a, b, c); \ + ARM_DOT_K0XN0(n0, k0, (a##2), b, (c##2)); \ + }) +#define ARM_MM_K0XN0X4(n0, k0, a, b, c) \ + ({ \ + ARM_MM_K0XN0X3(n0, k0, a, b, c); \ + ARM_DOT_K0XN0(n0, k0, (a##3), b, (c##3)); \ + }) +#define ARM_MM_K0XN0X5(n0, k0, a, b, c) \ + ({ \ + ARM_MM_K0XN0X4(n0, k0, a, b, c); \ + ARM_DOT_K0XN0(n0, k0, (a##4), b, (c##4)); \ + }) +#define ARM_MM_K0XN0X6(n0, k0, a, b, c) \ + ({ \ + ARM_MM_K0XN0X5(n0, k0, a, b, c); \ + ARM_DOT_K0XN0(n0, k0, (a##5), b, (c##5)); \ + }) +#define ARM_MM_K0XN0X7(n0, k0, a, b, c) \ + ({ \ + ARM_MM_K0XN0X6(n0, k0, a, b, c); \ + ARM_DOT_K0XN0(n0, k0, (a##6), b, (c##6)); \ + }) +#define ARM_MM_K0XN0X8(n0, k0, a, b, c) \ + ({ \ + ARM_MM_K0XN0X7(n0, k0, a, b, c); \ + ARM_DOT_K0XN0(n0, k0, (a##7), b, (c##7)); \ + }) + +#define ARM_DOT_K0(k0, a, b, c) \ + ({ \ + CONCAT(ARM_DOT, k0) \ + ((a), (b), (c)); \ + }) + +#define ARM_DOT_K0XN0(n0, k0, a, b, c) \ + ({ \ + CONCAT(ARM_DOT_K0X, n0) \ + (k0, (a), b, (c)); \ + }) + +#define ARM_MM_K0XN0XM0(m0, n0, k0, a, b, c) \ + ({ \ + CONCAT(ARM_MM_K0XN0X, m0) \ + (n0, k0, a, b, c); \ + }) + #if defined(COLS_B) && defined(MULT_INTERLEAVE4X4_HEIGHT) && defined(TRANSPOSE1XW_WIDTH_STEP) /** This OpenCL kernel computes the matrix multiplication between matrix A (src0) and matrix B (src1) * Matrix A and matrix B must be reshaped respectively with @ref CLGEMMReshapeLHSMatrixKernel and @ref CLGEMMReshapeRHSMatrixKernel before running the matrix multiplication @@ -1352,161 +1512,7 @@ __kernel void gemmlowp_mm_bifrost_dot8(IMAGE_DECLARATION(src0), #endif // defined(NUM_ELEMS_PROCESSED_PER_THREAD_X) && defined(NUM_ELEMS_PROCESSED_PER_THREAD_Y) && defined(COLS_A) #if defined(M0) && defined(N0) && defined(K0) && defined(V0) && defined(H0) && defined(M) && defined(N) - -#if 
defined(ARM_COMPUTE_OPENCL_DOT8_ENABLED) && defined(cl_arm_integer_dot_product_int8) - -#if K0 == 2 -#define ARM_DOT_K0(a, b, c) \ - ({ \ - ARM_DOT((uchar4)(a, (uchar2)0), (uchar4)(b, (uchar2)0), c); \ - }) -#elif K0 == 3 // K0 == 3 -#define ARM_DOT_K0(a, b, c) \ - ({ \ - ARM_DOT((uchar4)(a, (uchar)0), (uchar4)(b, (uchar)0), c); \ - }) -#elif K0 == 4 // K0 == 4 -#define ARM_DOT_K0(a, b, c) \ - ({ \ - ARM_DOT(a, b, c); \ - }) -#elif K0 == 8 // K0 == 8 -#define ARM_DOT_K0(a, b, c) \ - ({ \ - ARM_DOT(a.s0123, b.s0123, c); \ - ARM_DOT(a.s4567, b.s4567, c); \ - }) -#elif K0 == 16 // K0 == 16 -#define ARM_DOT_K0(a, b, c) \ - ({ \ - ARM_DOT(a.s0123, b.s0123, c); \ - ARM_DOT(a.s4567, b.s4567, c); \ - ARM_DOT(a.s89AB, b.s89AB, c); \ - ARM_DOT(a.sCDEF, b.sCDEF, c); \ - }) -#else // K0 not supported -#error "K0 value not supported" -#endif // K0 - -#else // defined(ARM_COMPUTE_OPENCL_DOT8_ENABLED) && defined(cl_arm_integer_dot_product_int8) - -#if K0 == 2 -#define ARM_DOT_K0(a, b, c) \ - ({ \ - c += (uint)a.s0 * b.s0; \ - c += (uint)a.s1 * b.s1; \ - }) -#elif K0 == 3 // K0 == 3 -#define ARM_DOT_K0(a, b, c) \ - ({ \ - c += (uint)a.s0 * b.s0; \ - c += (uint)a.s1 * b.s1; \ - c += (uint)a.s2 * b.s2; \ - }) -#elif K0 == 4 // K0 == 4 -#define ARM_DOT_K0(a, b, c) \ - ({ \ - c += (uint)a.s0 * b.s0; \ - c += (uint)a.s1 * b.s1; \ - c += (uint)a.s2 * b.s2; \ - c += (uint)a.s3 * b.s3; \ - }) -#elif K0 == 8 // K0 == 8 -#define ARM_DOT_K0(a, b, c) \ - ({ \ - c += (uint)a.s0 * b.s0; \ - c += (uint)a.s1 * b.s1; \ - c += (uint)a.s2 * b.s2; \ - c += (uint)a.s3 * b.s3; \ - c += (uint)a.s4 * b.s4; \ - c += (uint)a.s5 * b.s5; \ - c += (uint)a.s6 * b.s6; \ - c += (uint)a.s7 * b.s7; \ - }) -#elif K0 == 16 // K0 == 16 -#define ARM_DOT_K0(a, b, c) \ - ({ \ - c += (uint)a.s0 * b.s0; \ - c += (uint)a.s1 * b.s1; \ - c += (uint)a.s2 * b.s2; \ - c += (uint)a.s3 * b.s3; \ - c += (uint)a.s4 * b.s4; \ - c += (uint)a.s5 * b.s5; \ - c += (uint)a.s6 * b.s6; \ - c += (uint)a.s7 * b.s7; \ - c += (uint)a.s8 * 
b.s8; \ - c += (uint)a.s9 * b.s9; \ - c += (uint)a.sA * b.sA; \ - c += (uint)a.sB * b.sB; \ - c += (uint)a.sC * b.sC; \ - c += (uint)a.sD * b.sD; \ - c += (uint)a.sE * b.sE; \ - c += (uint)a.sF * b.sF; \ - }) -#else // K0 not supported -#error "K0 value not supported" -#endif // K0 - -#endif //defined(ARM_COMPUTE_OPENCL_DOT8_ENABLED) && defined(cl_arm_integer_dot_product_int8) - -#if N0 == 2 -#define ARM_DOT_K0XN0(a, b, c) \ - ({ \ - ARM_DOT_K0((a), (b##0), (c.s0)); \ - ARM_DOT_K0((a), (b##1), (c.s1)); \ - }) -#elif N0 == 3 // N0 == 3 -#define ARM_DOT_K0XN0(a, b, c) \ - ({ \ - ARM_DOT_K0((a), (b##0), (c.s0)); \ - ARM_DOT_K0((a), (b##1), (c.s1)); \ - ARM_DOT_K0((a), (b##2), (c.s2)); \ - }) -#elif N0 == 4 // N0 == 4 -#define ARM_DOT_K0XN0(a, b, c) \ - ({ \ - ARM_DOT_K0((a), (b##0), (c.s0)); \ - ARM_DOT_K0((a), (b##1), (c.s1)); \ - ARM_DOT_K0((a), (b##2), (c.s2)); \ - ARM_DOT_K0((a), (b##3), (c.s3)); \ - }) -#elif N0 == 8 // N0 == 8 -#define ARM_DOT_K0XN0(a, b, c) \ - ({ \ - ARM_DOT_K0((a), (b##0), (c.s0)); \ - ARM_DOT_K0((a), (b##1), (c.s1)); \ - ARM_DOT_K0((a), (b##2), (c.s2)); \ - ARM_DOT_K0((a), (b##3), (c.s3)); \ - ARM_DOT_K0((a), (b##4), (c.s4)); \ - ARM_DOT_K0((a), (b##5), (c.s5)); \ - ARM_DOT_K0((a), (b##6), (c.s6)); \ - ARM_DOT_K0((a), (b##7), (c.s7)); \ - }) -#elif N0 == 16 // N0 == 16 -#define ARM_DOT_K0XN0(a, b, c) \ - ({ \ - ARM_DOT_K0((a), (b##0), (c.s0)); \ - ARM_DOT_K0((a), (b##1), (c.s1)); \ - ARM_DOT_K0((a), (b##2), (c.s2)); \ - ARM_DOT_K0((a), (b##3), (c.s3)); \ - ARM_DOT_K0((a), (b##4), (c.s4)); \ - ARM_DOT_K0((a), (b##5), (c.s5)); \ - ARM_DOT_K0((a), (b##6), (c.s6)); \ - ARM_DOT_K0((a), (b##7), (c.s7)); \ - ARM_DOT_K0((a), (b##8), (c.s8)); \ - ARM_DOT_K0((a), (b##9), (c.s9)); \ - ARM_DOT_K0((a), (b##A), (c.sA)); \ - ARM_DOT_K0((a), (b##B), (c.sB)); \ - ARM_DOT_K0((a), (b##C), (c.sC)); \ - ARM_DOT_K0((a), (b##D), (c.sD)); \ - ARM_DOT_K0((a), (b##E), (c.sE)); \ - ARM_DOT_K0((a), (b##F), (c.sF)); \ - }) -#else // N0 not supported -#error "N0 value 
not supported" -#endif // N0 conditions - -/** This OpenCL kernel computes the matrix multiplication between 2 matrices with QASYMM data type . +/** This OpenCL kernel computes the matrix multiplication between 2 matrices with QASYMM data type. * The LHS matrix must be reshaped with @ref CLGEMMReshapeLHSMatrixKernel and the M0xK0 must be NOT transposed * The RHS matrix must be reshaped with @ref CLGEMMReshapeRHSMatrixKernel and the K0xN0 must be transposed * @@ -1594,243 +1600,73 @@ __kernel void gemmlowp_mm_reshaped_lhs_nt_rhs_t(IMAGE_DECLARATION(lhs), #define RHS_STEP_LOOP (H0) #endif // defined(RHS_INTERLEAVE) + uint x = get_global_id(0); + uint y = get_global_id(1); + uint z = get_global_id(2); + #if defined(DUMMY_WORK_ITEMS) - if((get_global_id(0) * N0 >= N) || (get_global_id(1) * M0 >= M)) + if((x * N0 >= N) || (y * M0 >= M)) { return; } #endif // defined(DUMMY_WORK_ITEMS) // Compute LHS matrix address - __global uchar *lhs_addr = lhs_ptr + lhs_offset_first_element_in_bytes + (get_global_id(1) % V0) * (uint)LHS_OFFSET_X + (get_global_id(1) / V0) * (uint)lhs_stride_y + (get_global_id( - 2) - * lhs_stride_z); + __global uchar *lhs_addr = lhs_ptr + lhs_offset_first_element_in_bytes + (y % V0) * (uint)LHS_OFFSET_X + (y / V0) * (uint)lhs_stride_y + (z * lhs_stride_z); // Compute RHS matrix address - __global uchar *rhs_addr = rhs_ptr + rhs_offset_first_element_in_bytes + (get_global_id(0) % H0) * (uint)RHS_OFFSET_X + (get_global_id(0) / (uint)H0) * rhs_stride_y; + __global uchar *rhs_addr = rhs_ptr + rhs_offset_first_element_in_bytes + (x % H0) * (uint)RHS_OFFSET_X + (x / (uint)H0) * rhs_stride_y; #if defined(MATRIX_B_DEPTH) // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3 - rhs_addr += (get_global_id(2) % MATRIX_B_DEPTH) * rhs_stride_z; + rhs_addr += (z % MATRIX_B_DEPTH) * rhs_stride_z; #else // defined(MATRIX_B_DEPTH) - rhs_addr += get_global_id(2) * rhs_stride_z; + rhs_addr += z * rhs_stride_z; #endif // defined(MATRIX_B_DEPTH) 
+ REPEAT_VAR_INIT_TO_CONST(8, uint, zlhs, 0); //uint zlhs0=0,zlhs1=0,zlhs2=0,... zlhs7=0; + REPEAT_VAR_INIT_TO_CONST(16, uint, zrhs, 0); + // Initialize the accumulators REPEAT_VAR_INIT_TO_CONST(M0, VEC_DATA_TYPE(uint, N0), c, 0); //VEC_DATA_TYPE(uint, N0) c0=0,c1=0,c2=0,... c(M0-1)=0; for(int i = 0; i < k; i += K0) { - // Supported cases (M0, K0): - // 2,4 - 2,8 - 2,16 - // 3,4 - 3,8 - 3,16 - // 4,4 - 4,8 - 4,16 - // 5,4 - 5,8 - 5,16 - // 6,4 - 6,8 - 6,16 // Load values from LHS matrix - VEC_DATA_TYPE(uchar, K0) - a0 = VLOAD(K0)(0, lhs_addr + 0 * LHS_STEP_X); -#if M0 > 1 - VEC_DATA_TYPE(uchar, K0) - a1 = VLOAD(K0)(0, lhs_addr + 1 * LHS_STEP_X); -#endif // M0 > 1 -#if M0 > 2 - VEC_DATA_TYPE(uchar, K0) - a2 = VLOAD(K0)(0, lhs_addr + 2 * LHS_STEP_X); -#endif // M0 > 2 -#if M0 > 3 - VEC_DATA_TYPE(uchar, K0) - a3 = VLOAD(K0)(0, lhs_addr + 3 * LHS_STEP_X); -#endif // M0 > 3 -#if M0 > 4 - VEC_DATA_TYPE(uchar, K0) - a4 = VLOAD(K0)(0, lhs_addr + 4 * LHS_STEP_X); -#endif // M0 > 4 -#if M0 > 5 - VEC_DATA_TYPE(uchar, K0) - a5 = VLOAD(K0)(0, lhs_addr + 5 * LHS_STEP_X); -#endif // M0 > 5 -#if M0 > 6 - VEC_DATA_TYPE(uchar, K0) - a6 = VLOAD(K0)(0, lhs_addr + 6 * LHS_STEP_X); -#endif // M0 > 6 -#if M0 > 7 - VEC_DATA_TYPE(uchar, K0) - a7 = VLOAD(K0)(0, lhs_addr + 7 * LHS_STEP_X); -#endif // M0 > 7 + LOAD_BLOCK(M0, K0, uchar, a, lhs_addr, 0, LHS_STEP_X, zlhs); // Load values from RHS matrix - VEC_DATA_TYPE(uchar, K0) - b0 = VLOAD(K0)(0, rhs_addr + 0 * RHS_STEP_X); - VEC_DATA_TYPE(uchar, K0) - b1 = VLOAD(K0)(0, rhs_addr + 1 * RHS_STEP_X); -#if N0 > 2 - VEC_DATA_TYPE(uchar, K0) - b2 = VLOAD(K0)(0, rhs_addr + 2 * RHS_STEP_X); -#endif // N0 > 2 -#if N0 > 3 - VEC_DATA_TYPE(uchar, K0) - b3 = VLOAD(K0)(0, rhs_addr + 3 * RHS_STEP_X); -#endif // N0 > 3 -#if N0 > 4 - VEC_DATA_TYPE(uchar, K0) - b4 = VLOAD(K0)(0, rhs_addr + 4 * RHS_STEP_X); - VEC_DATA_TYPE(uchar, K0) - b5 = VLOAD(K0)(0, rhs_addr + 5 * RHS_STEP_X); - VEC_DATA_TYPE(uchar, K0) - b6 = VLOAD(K0)(0, rhs_addr + 6 * RHS_STEP_X); -
VEC_DATA_TYPE(uchar, K0) - b7 = VLOAD(K0)(0, rhs_addr + 7 * RHS_STEP_X); -#endif // N0 > 4 -#if N0 > 8 - VEC_DATA_TYPE(uchar, K0) - b8 = VLOAD(K0)(0, rhs_addr + 8 * RHS_STEP_X); - VEC_DATA_TYPE(uchar, K0) - b9 = VLOAD(K0)(0, rhs_addr + 9 * RHS_STEP_X); - VEC_DATA_TYPE(uchar, K0) - bA = VLOAD(K0)(0, rhs_addr + 10 * RHS_STEP_X); - VEC_DATA_TYPE(uchar, K0) - bB = VLOAD(K0)(0, rhs_addr + 11 * RHS_STEP_X); - VEC_DATA_TYPE(uchar, K0) - bC = VLOAD(K0)(0, rhs_addr + 12 * RHS_STEP_X); - VEC_DATA_TYPE(uchar, K0) - bD = VLOAD(K0)(0, rhs_addr + 13 * RHS_STEP_X); - VEC_DATA_TYPE(uchar, K0) - bE = VLOAD(K0)(0, rhs_addr + 14 * RHS_STEP_X); - VEC_DATA_TYPE(uchar, K0) - bF = VLOAD(K0)(0, rhs_addr + 15 * RHS_STEP_X); -#endif // N0 > 8 + LOAD_BLOCK(N0, K0, uchar, b, rhs_addr, 0, RHS_STEP_X, zrhs); - // Accumulate - ARM_DOT_K0XN0(a0, b, c0); -#if M0 > 1 - ARM_DOT_K0XN0(a1, b, c1); -#endif // M0 > 1 -#if M0 > 2 - ARM_DOT_K0XN0(a2, b, c2); -#endif // M0 > 2 -#if M0 > 3 - ARM_DOT_K0XN0(a3, b, c3); -#endif // M0 > 3 -#if M0 > 4 - ARM_DOT_K0XN0(a4, b, c4); -#endif // M0 > 4 -#if M0 > 5 - ARM_DOT_K0XN0(a5, b, c5); -#endif // M0 > 5 -#if M0 > 6 - ARM_DOT_K0XN0(a6, b, c6); -#endif // M0 > 6 -#if M0 > 7 - ARM_DOT_K0XN0(a7, b, c7); -#endif // M0 > 7 + // Partial matrix multiplication M0,N0,K0 + ARM_MM_K0XN0XM0(M0, N0, K0, a, b, c); + // Update address lhs_addr += (M0 * LHS_STEP_X * LHS_STEP_LOOP); rhs_addr += (N0 * RHS_STEP_X * RHS_STEP_LOOP); } - __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + (get_global_id(0) * (uint)N0 * sizeof(int)) + (get_global_id(1) * (uint)M0 * dst_stride_y); + __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + (x * (uint)N0 * sizeof(int)) + (y * (uint)M0 * dst_stride_y); REPEAT_VAR_INIT_TO_CONST(8, uint, zout, 0); //uint zout0=0,zout1=0,zout2=0,... 
zout7=0; #if defined(REINTERPRET_OUTPUT_AS_3D) - // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension - // in order to take into account the presence of possible cross plane paddings - // - // | | - // | plane0 | - // | | - // |__________________| - // |******************| - // | cross_plane_pad | - // |******************| - // | | - // | plane1 | - // | | - // |__________________| - - // The plane (zin) is calculated dividing M (y * M0) by HEIGHT_GEMM3D - zout0 = (0 + (uint)(get_global_id(1) * (uint)M0)) / (uint)HEIGHT_GEMM3D; - zout0 = min((uint)(DEPTH_GEMM3D - 1), zout0); - zout0 *= (dst_cross_plane_pad * dst_stride_y); -#if M0 > 1 - zout1 = (1 + (uint)(get_global_id(1) * (uint)M0)) / (uint)HEIGHT_GEMM3D; - zout1 = min((uint)(DEPTH_GEMM3D - 1), zout1); - zout1 *= (dst_cross_plane_pad * dst_stride_y); -#endif // M0 > 1 -#if M0 > 2 - zout2 = (2 + (uint)(get_global_id(1) * (uint)M0)) / (uint)HEIGHT_GEMM3D; - zout2 = min((uint)(DEPTH_GEMM3D - 1), zout2); - zout2 *= (dst_cross_plane_pad * dst_stride_y); -#endif // M0 > 2 -#if M0 > 3 - zout3 = (3 + (uint)(get_global_id(1) * (uint)M0)) / (uint)HEIGHT_GEMM3D; - zout3 = min((uint)(DEPTH_GEMM3D - 1), zout3); - zout3 *= (dst_cross_plane_pad * dst_stride_y); -#endif // M0 > 3 -#if M0 > 4 - zout4 = (4 + (uint)(get_global_id(1) * (uint)M0)) / (uint)HEIGHT_GEMM3D; - zout4 = min((uint)(DEPTH_GEMM3D - 1), zout4); - zout4 *= (dst_cross_plane_pad * dst_stride_y); -#endif // M0 > 4 -#if M0 > 5 - zout5 = (5 + (uint)(get_global_id(1) * (uint)M0)) / (uint)HEIGHT_GEMM3D; - zout5 = min((uint)(DEPTH_GEMM3D - 1), zout5); - zout5 *= (dst_cross_plane_pad * dst_stride_y); -#endif // M0 > 5 -#if M0 > 6 - zout6 = (6 + (uint)(get_global_id(1) * (uint)M0)) / (uint)HEIGHT_GEMM3D; - zout6 = min((uint)(DEPTH_GEMM3D - 1), zout6); - zout6 *= (dst_cross_plane_pad * dst_stride_y); -#endif // M0 > 6 -#if M0 > 7 - zout7 = (7 + (uint)(get_global_id(1) * (uint)M0)) / (uint)HEIGHT_GEMM3D; - zout7 
= min((uint)(DEPTH_GEMM3D - 1), zout7); - zout7 *= (dst_cross_plane_pad * dst_stride_y); -#endif // M0 > 7 + // The plane (zout) is calculated dividing M (y * M0) by HEIGHT_GEMM3D + CALCULATE_Z_OFFSET(M0, uint, zout, y, HEIGHT_GEMM3D, DEPTH_GEMM3D, dst_cross_plane_pad, dst_stride_y); // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we // multiply dst_stride_z by DEPTH_GEMM3D - dst_addr += get_global_id(2) * dst_stride_z * DEPTH_GEMM3D; + dst_addr += z * dst_stride_z * DEPTH_GEMM3D; #else // defined(REINTERPRET_OUTPUT_AS_3D) // Add offset for batched GEMM - dst_addr += get_global_id(2) * dst_stride_z; + dst_addr += z * dst_stride_z; #endif // defined(REINTERPRET_OUTPUT_AS_3D) - // Store output block - VSTORE(N0) - (CONVERT_SAT(c0, VEC_DATA_TYPE(int, N0)), 0, (__global int *)(dst_addr + 0 * dst_stride_y + zout0)); -#if M0 > 1 - VSTORE(N0) - (CONVERT_SAT(c1, VEC_DATA_TYPE(int, N0)), 0, (__global int *)(dst_addr + 1 * dst_stride_y + zout1)); -#endif // M0 > 1 -#if M0 > 2 - VSTORE(N0) - (CONVERT_SAT(c2, VEC_DATA_TYPE(int, N0)), 0, (__global int *)(dst_addr + 2 * dst_stride_y + zout2)); -#endif // M0 > 2 -#if M0 > 3 - VSTORE(N0) - (CONVERT_SAT(c3, VEC_DATA_TYPE(int, N0)), 0, (__global int *)(dst_addr + 3 * dst_stride_y + zout3)); -#endif // M0 > 3 -#if M0 > 4 - VSTORE(N0) - (CONVERT_SAT(c4, VEC_DATA_TYPE(int, N0)), 0, (__global int *)(dst_addr + 4 * dst_stride_y + zout4)); -#endif // M0 > 4 -#if M0 > 5 - VSTORE(N0) - (CONVERT_SAT(c5, VEC_DATA_TYPE(int, N0)), 0, (__global int *)(dst_addr + 5 * dst_stride_y + zout5)); -#endif // M0 > 5 -#if M0 > 6 - VSTORE(N0) - (CONVERT_SAT(c6, VEC_DATA_TYPE(int, N0)), 0, (__global int *)(dst_addr + 6 * dst_stride_y + zout6)); -#endif // M0 > 6 -#if M0 > 7 - VSTORE(N0) - (CONVERT_SAT(c7, VEC_DATA_TYPE(int, N0)), 0, (__global int *)(dst_addr + 7 * dst_stride_y + zout7)); -#endif // M0 > 7 + // Convert and store output block + CONVERT_STORE_BLOCK(M0, N0, int, c, dst_addr, dst_stride_y, zout); 
#undef LHS_BLOCK_SIZE #undef LHS_OFFSET_X @@ -1839,256 +1675,13 @@ __kernel void gemmlowp_mm_reshaped_lhs_nt_rhs_t(IMAGE_DECLARATION(lhs), #undef RHS_OFFSET_X #undef RHS_STEP_X } - -#if defined(ARM_COMPUTE_OPENCL_DOT8_ENABLED) && defined(cl_arm_integer_dot_product_int8) -/** This OpenCL kernel computes the matrix multiplication between 2 matrices with QASYMM8 data type using the dot8 instruction. - * The LHS matrix must be reshaped with @ref CLGEMMReshapeLHSMatrixKernel and the M0xK0 must be NOT transposed - * The RHS matrix must be reshaped with @ref CLGEMMReshapeRHSMatrixKernel and the K0xN0 must be transposed - * - * @note The block's dimensions used for reshaping the LHS matrix and the RHS matrix (M0, N0 and K0) must be passed at compile time using -DM0, -DN0 and -DK0 (i.e. -DM0=4, -DN0=8, -DK0=4). - * @note The number of M0xK0 vertical blocks stored on the same output row of the reshaped LHS matrix must be passed at compile time using -DV0 (i.e. -DV0=2) - * @note The number of K0xN0 horizontal blocks stored on the same output row of the reshaped RHS matrix must be passed at compile time using -DH0 (i.e. -DH0=2) - * @note If the M0xK0 blocks in the reshaped LHS matrix have been interleaved, the option -DLHS_INTERLEAVE must passed at compile time. - * @note If the K0xN0 blocks in the reshaped RHS matrix have been interleaved, the option -DRHS_INTERLEAVE must passed at compile time. - * @note Only the following configurations of M0, N0 and K0 are currently supported: - * - M0 = 2, 3, 4, 5, 6, 7, 8 - * - N0 = 2, 3, 4, 8, 16 - * - K0 = 2, 3, 4, 8, 16 - * - * @note In case the output has to be reinterpreted as a 3D tensor (i.e. output of convolution layer), the following information must be passed at compile time: - * -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D - * -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor. 
- * -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor - * (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns LHS matrix NOT reshaped - * - * @param[in] lhs_ptr Pointer to the LHS reshaped matrix. Supported data type: QASYMM8 - * @param[in] lhs_stride_x Stride of the LHS reshaped matrix in X dimension (in bytes) - * @param[in] lhs_step_x src_stride_x * number of elements along X processed per workitem(in bytes) - * @param[in] lhs_stride_y Stride of the LHS reshaped matrix in Y dimension (in bytes) - * @param[in] lhs_step_y src_stride_y * number of elements along Y processed per workitem(in bytes) - * @param[in] lhs_offset_first_element_in_bytes The offset of the first element in the LHS reshaped matrix - * @param[in] rhs_ptr Pointer to the RHS reshaped matrix. Supported data type: same as @p lhs_ptr - * @param[in] rhs_stride_x Stride of the RHS reshaped matrix in X dimension (in bytes) - * @param[in] rhs_step_x src_stride_x * number of elements along X processed per workitem(in bytes) - * @param[in] rhs_stride_y Stride of the RHS reshaped matrix in Y dimension (in bytes) - * @param[in] rhs_step_y src_stride_y * number of elements along Y processed per workitem(in bytes) - * @param[in] rhs_offset_first_element_in_bytes The offset of the first element in the RHS reshaped matrix - * @param[out] dst_ptr Pointer to the destination matrix Supported data type: same as @p lhs_ptr - * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes) - * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes) - * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes) - * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes) - * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix - * @param[in] k Number of columns in LHS matrix and rows in RHS matrix not reshaped. 
- * @param[in] lhs_stride_z Stride of the LHS reshaped matrix in Z dimension (in bytes) - * @param[in] rhs_stride_z Stride of the RHS reshaped matrix in Z dimension (in bytes) - * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes) - * @param[in] dst_cross_plane_pad (Optional) Bottom paddings in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D) - */ -__kernel void gemmlowp_mm_reshaped_lhs_nt_rhs_t_dot8(IMAGE_DECLARATION(lhs), - IMAGE_DECLARATION(rhs), - IMAGE_DECLARATION(dst), - uint k, - uint lhs_stride_z, - uint rhs_stride_z, - uint dst_stride_z -#if defined(REINTERPRET_OUTPUT_AS_3D) - , - uint dst_cross_plane_pad -#endif // REINTERPRET_OUTPUT_AS_3D - ) -{ - // Note: ARM_DOT_K0XN0 is generated with the dot8 instruction - gemmlowp_mm_reshaped_lhs_nt_rhs_t(lhs_ptr, - lhs_stride_x, - lhs_step_x, - lhs_stride_y, - lhs_step_y, - lhs_offset_first_element_in_bytes, - rhs_ptr, - rhs_stride_x, - rhs_step_x, - rhs_stride_y, - rhs_step_y, - rhs_offset_first_element_in_bytes, - dst_ptr, - dst_stride_x, - dst_step_x, - dst_stride_y, - dst_step_y, - dst_offset_first_element_in_bytes, - k, - lhs_stride_z, - rhs_stride_z, - dst_stride_z -#if defined(REINTERPRET_OUTPUT_AS_3D) - , - dst_cross_plane_pad -#endif // REINTERPRET_OUTPUT_AS_3D - ); -} -#endif // defined(ARM_COMPUTE_OPENCL_DOT8_ENABLED) && defined(cl_arm_integer_dot_product_int8) #endif // defined(M0) && defined(N0) && defined(K0) && defined(V0) && defined(H0) && defined(K) #if defined(M0) && defined(N0) && defined(K0) && defined(H0) && defined(K) -#define CONCAT(a, b) a##b - -#if defined(ARM_COMPUTE_OPENCL_DOT8_ENABLED) && defined(cl_arm_integer_dot_product_int8) - -#define ARM_DOT1(a, b, c) \ - ({ \ - ARM_DOT((uchar4)(a, (uchar3)0), (uchar4)(b, (uchar3)0), c); \ - }) -#define ARM_DOT2(a, b, c) \ - ({ \ - ARM_DOT((uchar4)(a, (uchar2)0), (uchar4)(b, (uchar2)0), c); \ - }) -#define ARM_DOT3(a, b, c) \ - ({ \ - ARM_DOT((uchar4)(a, (uchar)0), (uchar4)(b, (uchar)0), c); \ - }) 
-#define ARM_DOT4(a, b, c) \ - ({ \ - ARM_DOT(a, b, c); \ - }) -#define ARM_DOT8(a, b, c) \ - ({ \ - ARM_DOT4((a.lo), (b.lo), c); \ - ARM_DOT4((a.hi), (b.hi), c); \ - }) -#define ARM_DOT16(a, b, c) \ - ({ \ - ARM_DOT8((a.lo), (b.lo), c); \ - ARM_DOT8((a.hi), (b.hi), c); \ - }) - -#else // defined(ARM_COMPUTE_OPENCL_DOT8_ENABLED) && defined(cl_arm_integer_dot_product_int8) - -#define ARM_DOT1(a, b, c) \ - ({ \ - c += (uint)a.s0 * b.s0; \ - }) -#define ARM_DOT2(a, b, c) \ - ({ \ - ARM_DOT1(a, b, c); \ - c += (uint)a.s1 * b.s1; \ - }) -#define ARM_DOT3(a, b, c) \ - ({ \ - ARM_DOT2(a, b, c); \ - c += (uint)a.s2 * b.s2; \ - }) -#define ARM_DOT4(a, b, c) \ - ({ \ - ARM_DOT3(a, b, c); \ - c += (uint)a.s3 * b.s3; \ - }) -#define ARM_DOT8(a, b, c) \ - ({ \ - ARM_DOT4((a.lo), (b.lo), c); \ - ARM_DOT4((a.hi), (b.hi), c); \ - }) -#define ARM_DOT16(a, b, c) \ - ({ \ - ARM_DOT8((a.lo), (b.lo), c); \ - ARM_DOT8((a.hi), (b.hi), c); \ - }) -#endif // defined(ARM_COMPUTE_OPENCL_DOT8_ENABLED) && defined(cl_arm_integer_dot_product_int8) - -#if N0 == 2 -#define ARM_DOT_K0XN0(k0, a, b, c) \ - ({ \ - CONCAT(ARM_DOT, k0) \ - ((a), (b##0), (c.s0)); \ - CONCAT(ARM_DOT, k0) \ - ((a), (b##1), (c.s1)); \ - }) -#elif N0 == 3 // N0 == 3 -#define ARM_DOT_K0XN0(k0, a, b, c) \ - ({ \ - CONCAT(ARM_DOT, k0) \ - ((a), (b##0), (c.s0)); \ - CONCAT(ARM_DOT, k0) \ - ((a), (b##1), (c.s1)); \ - CONCAT(ARM_DOT, k0) \ - ((a), (b##2), (c.s2)); \ - }) -#elif N0 == 4 // N0 == 4 -#define ARM_DOT_K0XN0(k0, a, b, c) \ - ({ \ - CONCAT(ARM_DOT, k0) \ - ((a), (b##0), (c.s0)); \ - CONCAT(ARM_DOT, k0) \ - ((a), (b##1), (c.s1)); \ - CONCAT(ARM_DOT, k0) \ - ((a), (b##2), (c.s2)); \ - CONCAT(ARM_DOT, k0) \ - ((a), (b##3), (c.s3)); \ - }) -#elif N0 == 8 // N0 == 8 -#define ARM_DOT_K0XN0(k0, a, b, c) \ - ({ \ - CONCAT(ARM_DOT, k0) \ - ((a), (b##0), (c.s0)); \ - CONCAT(ARM_DOT, k0) \ - ((a), (b##1), (c.s1)); \ - CONCAT(ARM_DOT, k0) \ - ((a), (b##2), (c.s2)); \ - CONCAT(ARM_DOT, k0) \ - ((a), (b##3), (c.s3)); \ - 
CONCAT(ARM_DOT, k0) \ - ((a), (b##4), (c.s4)); \ - CONCAT(ARM_DOT, k0) \ - ((a), (b##5), (c.s5)); \ - CONCAT(ARM_DOT, k0) \ - ((a), (b##6), (c.s6)); \ - CONCAT(ARM_DOT, k0) \ - ((a), (b##7), (c.s7)); \ - }) -#elif N0 == 16 // N0 == 16 -#define ARM_DOT_K0XN0(k0, a, b, c) \ - ({ \ - CONCAT(ARM_DOT, k0) \ - ((a), (b##0), (c.s0)); \ - CONCAT(ARM_DOT, k0) \ - ((a), (b##1), (c.s1)); \ - CONCAT(ARM_DOT, k0) \ - ((a), (b##2), (c.s2)); \ - CONCAT(ARM_DOT, k0) \ - ((a), (b##3), (c.s3)); \ - CONCAT(ARM_DOT, k0) \ - ((a), (b##4), (c.s4)); \ - CONCAT(ARM_DOT, k0) \ - ((a), (b##5), (c.s5)); \ - CONCAT(ARM_DOT, k0) \ - ((a), (b##6), (c.s6)); \ - CONCAT(ARM_DOT, k0) \ - ((a), (b##7), (c.s7)); \ - CONCAT(ARM_DOT, k0) \ - ((a), (b##8), (c.s8)); \ - CONCAT(ARM_DOT, k0) \ - ((a), (b##9), (c.s9)); \ - CONCAT(ARM_DOT, k0) \ - ((a), (b##A), (c.sA)); \ - CONCAT(ARM_DOT, k0) \ - ((a), (b##B), (c.sB)); \ - CONCAT(ARM_DOT, k0) \ - ((a), (b##C), (c.sC)); \ - CONCAT(ARM_DOT, k0) \ - ((a), (b##D), (c.sD)); \ - CONCAT(ARM_DOT, k0) \ - ((a), (b##E), (c.sE)); \ - CONCAT(ARM_DOT, k0) \ - ((a), (b##F), (c.sF)); \ - }) -#else // N0 not supported -#error "N0 value not supported" -#endif // N0 conditions - /** This OpenCL kernel computes the matrix multiplication between 2 matrices. * The LHS matrix is NOT reshaped - * The RHS is reshaped with @ref CLGEMMReshapeRHSMatrixKernel and the block K0xN0 is transposed + * The RHS matrix is reshaped with @ref CLGEMMReshapeRHSMatrixKernel and the block K0xN0 is transposed * * @note The number of columns of LHS matrix must be passed at compile time using -DK (i.e. -DK=64) * @note The block's dimensions used for reshaping the RHS matrix (N0 and K0) must be passed at compile time using -DN0 and -DK0 (i.e. -DN0=8, -DK0=4). @@ -2186,63 +1779,12 @@ __kernel void gemmlowp_mm_reshaped_only_rhs_t(IMAGE_DECLARATION(lhs), rhs_offset += z * rhs_stride_z; #endif // defined(MATRIX_B_DEPTH) - REPEAT_VAR_INIT_TO_CONST(8, uint, zin, 0); //uint zout0=0,zout1=0,zout2=0,... 
zout7=0; + REPEAT_VAR_INIT_TO_CONST(8, uint, zlhs, 0); //uint zout0=0,zout1=0,zout2=0,... zout7=0; + REPEAT_VAR_INIT_TO_CONST(16, uint, zrhs, 0); #if defined(REINTERPRET_INPUT_AS_3D) - // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension - // in order to take into account the presence of possible cross plane paddings - // - // | | - // | plane0 | - // | | - // |__________________| - // |******************| - // | cross_plane_pad | - // |******************| - // | | - // | plane1 | - // | | - // |__________________| - - // The plane (zin) is calculated dividing M (y * M0) by HEIGHT_GEMM3D - zin0 = (0 + (uint)(y * (uint)M0)) / (uint)HEIGHT_GEMM3D; - zin0 = min((uint)(DEPTH_GEMM3D - 1), zin0); - zin0 *= (lhs_cross_plane_pad * lhs_stride_y); -#if M0 > 1 - zin1 = (1 + (uint)(y * (uint)M0)) / (uint)HEIGHT_GEMM3D; - zin1 = min((uint)(DEPTH_GEMM3D - 1), zin1); - zin1 *= (lhs_cross_plane_pad * lhs_stride_y); -#endif // M0 > 1 -#if M0 > 2 - zin2 = (2 + (uint)(y * (uint)M0)) / (uint)HEIGHT_GEMM3D; - zin2 = min((uint)(DEPTH_GEMM3D - 1), zin2); - zin2 *= (lhs_cross_plane_pad * lhs_stride_y); -#endif // M0 > 2 -#if M0 > 3 - zin3 = (3 + (uint)(y * (uint)M0)) / (uint)HEIGHT_GEMM3D; - zin3 = min((uint)(DEPTH_GEMM3D - 1), zin3); - zin3 *= (lhs_cross_plane_pad * lhs_stride_y); -#endif // M0 > 3 -#if M0 > 4 - zin4 = (4 + (uint)(y * (uint)M0)) / (uint)HEIGHT_GEMM3D; - zin4 = min((uint)(DEPTH_GEMM3D - 1), zin4); - zin4 *= (lhs_cross_plane_pad * lhs_stride_y); -#endif // M0 > 4 -#if M0 > 5 - zin5 = (5 + (uint)(y * (uint)M0)) / (uint)HEIGHT_GEMM3D; - zin5 = min((uint)(DEPTH_GEMM3D - 1), zin5); - zin5 *= (lhs_cross_plane_pad * lhs_stride_y); -#endif // M0 > 5 -#if M0 > 6 - zin6 = (6 + (uint)(y * (uint)M0)) / (uint)HEIGHT_GEMM3D; - zin6 = min((uint)(DEPTH_GEMM3D - 1), zin6); - zin6 *= (lhs_cross_plane_pad * lhs_stride_y); -#endif // M0 > 6 -#if M0 > 7 - zin7 = (7 + (uint)(y * (uint)M0)) / (uint)HEIGHT_GEMM3D; - zin7 = 
min((uint)(DEPTH_GEMM3D - 1), zout7); - zin7 *= (lhs_cross_plane_pad * lhs_stride_y); -#endif // M0 > 7 + // The plane (zlhs) is calculated dividing M (y * M0) by HEIGHT_GEMM3D + CALCULATE_Z_OFFSET(M0, uint, zlhs, y, HEIGHT_GEMM3D, DEPTH_GEMM3D, lhs_cross_plane_pad, lhs_stride_y); // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we // multiply lhs_stride_z by DEPTH_GEMM3D @@ -2260,112 +1802,14 @@ __kernel void gemmlowp_mm_reshaped_only_rhs_t(IMAGE_DECLARATION(lhs), for(int i = 0; i < K; i += K0) { - // Supported cases (M0, K0): - // 1,2 - 1,3 - 1,4 - 1,8 - 1,16 - // 2,2 - 2,3 - 2,4 - 2,8 - 2,16 - // 3,2 - 3,3 - 3,4 - 3,8 - 3,16 - // 4,2 - 4,3 - 4,4 - 4,8 - 4,16 - // 5,2 - 5,3 - 5,4 - 5,8 - 5,16 - // 6,2 - 6,3 - 6,4 - 6,8 - 6,16 - // 7,2 - 7,3 - 7,4 - 7,8 - 7,16 - // 8,2 - 8,3 - 8,4 - 8,8 - 8,16 // Load values from LHS matrix - VEC_DATA_TYPE(uchar, K0) - a0 = VLOAD(K0)(0, lhs_ptr + lhs_offset + 0 * lhs_stride_y + zin0); -#if M0 > 1 - VEC_DATA_TYPE(uchar, K0) - a1 = VLOAD(K0)(0, lhs_ptr + lhs_offset + 1 * lhs_stride_y + zin1); -#endif // M0 > 1 -#if M0 > 2 - VEC_DATA_TYPE(uchar, K0) - a2 = VLOAD(K0)(0, lhs_ptr + lhs_offset + 2 * lhs_stride_y + zin2); -#endif // M0 > 2 -#if M0 > 3 - VEC_DATA_TYPE(uchar, K0) - a3 = VLOAD(K0)(0, lhs_ptr + lhs_offset + 3 * lhs_stride_y + zin3); -#endif // M0 > 3 -#if M0 > 4 - VEC_DATA_TYPE(uchar, K0) - a4 = VLOAD(K0)(0, lhs_ptr + lhs_offset + 4 * lhs_stride_y + zin4); -#endif // M0 > 4 -#if M0 > 5 - VEC_DATA_TYPE(uchar, K0) - a5 = VLOAD(K0)(0, lhs_ptr + lhs_offset + 5 * lhs_stride_y + zin5); -#endif // M0 > 5 -#if M0 > 6 - VEC_DATA_TYPE(uchar, K0) - a6 = VLOAD(K0)(0, lhs_ptr + lhs_offset + 6 * lhs_stride_y + zin6); -#endif // M0 > 6 -#if M0 > 7 - VEC_DATA_TYPE(uchar, K0) - a7 = VLOAD(K0)(0, lhs_ptr + lhs_offset + 7 * lhs_stride_y + zin7); -#endif // M0 > 7 + LOAD_BLOCK(M0, K0, uchar, a, lhs_ptr, lhs_offset, lhs_stride_y, zlhs); // Load values from RHS matrix - VEC_DATA_TYPE(uchar, K0) - b0 = 
VLOAD(K0)(0, rhs_ptr + rhs_offset + 0 * RHS_STEP_X); - VEC_DATA_TYPE(uchar, K0) - b1 = VLOAD(K0)(0, rhs_ptr + rhs_offset + 1 * RHS_STEP_X); -#if N0 > 2 - VEC_DATA_TYPE(uchar, K0) - b2 = VLOAD(K0)(0, rhs_ptr + rhs_offset + 2 * RHS_STEP_X); -#endif // N0 > 2 -#if N0 > 3 - VEC_DATA_TYPE(uchar, K0) - b3 = VLOAD(K0)(0, rhs_ptr + rhs_offset + 3 * RHS_STEP_X); -#endif // N0 > 3 -#if N0 > 4 - VEC_DATA_TYPE(uchar, K0) - b4 = VLOAD(K0)(0, rhs_ptr + rhs_offset + 4 * RHS_STEP_X); - VEC_DATA_TYPE(uchar, K0) - b5 = VLOAD(K0)(0, rhs_ptr + rhs_offset + 5 * RHS_STEP_X); - VEC_DATA_TYPE(uchar, K0) - b6 = VLOAD(K0)(0, rhs_ptr + rhs_offset + 6 * RHS_STEP_X); - VEC_DATA_TYPE(uchar, K0) - b7 = VLOAD(K0)(0, rhs_ptr + rhs_offset + 7 * RHS_STEP_X); -#endif // N0 > 4 -#if N0 > 8 - VEC_DATA_TYPE(uchar, K0) - b8 = VLOAD(K0)(0, rhs_ptr + rhs_offset + 8 * RHS_STEP_X); - VEC_DATA_TYPE(uchar, K0) - b9 = VLOAD(K0)(0, rhs_ptr + rhs_offset + 9 * RHS_STEP_X); - VEC_DATA_TYPE(uchar, K0) - bA = VLOAD(K0)(0, rhs_ptr + rhs_offset + 10 * RHS_STEP_X); - VEC_DATA_TYPE(uchar, K0) - bB = VLOAD(K0)(0, rhs_ptr + rhs_offset + 11 * RHS_STEP_X); - VEC_DATA_TYPE(uchar, K0) - bC = VLOAD(K0)(0, rhs_ptr + rhs_offset + 12 * RHS_STEP_X); - VEC_DATA_TYPE(uchar, K0) - bD = VLOAD(K0)(0, rhs_ptr + rhs_offset + 13 * RHS_STEP_X); - VEC_DATA_TYPE(uchar, K0) - bE = VLOAD(K0)(0, rhs_ptr + rhs_offset + 14 * RHS_STEP_X); - VEC_DATA_TYPE(uchar, K0) - bF = VLOAD(K0)(0, rhs_ptr + rhs_offset + 15 * RHS_STEP_X); -#endif // N0 > 8 + LOAD_BLOCK(N0, K0, uchar, b, rhs_ptr, rhs_offset, RHS_STEP_X, zrhs); - // Accumulate - ARM_DOT_K0XN0(K0, a0, b, c0); -#if M0 > 1 - ARM_DOT_K0XN0(K0, a1, b, c1); -#endif // M0 > 1 -#if M0 > 2 - ARM_DOT_K0XN0(K0, a2, b, c2); -#endif // M0 > 2 -#if M0 > 3 - ARM_DOT_K0XN0(K0, a3, b, c3); -#endif // M0 > 3 -#if M0 > 4 - ARM_DOT_K0XN0(K0, a4, b, c4); -#endif // M0 > 4 -#if M0 > 5 - ARM_DOT_K0XN0(K0, a5, b, c5); -#endif // M0 > 5 -#if M0 > 6 - ARM_DOT_K0XN0(K0, a6, b, c6); -#endif // M0 > 6 -#if M0 > 7 - 
ARM_DOT_K0XN0(K0, a7, b, c7); -#endif // M0 > 7 + // Partial matrix multiplication M0,N0,K0 + ARM_MM_K0XN0XM0(M0, N0, K0, a, b, c); lhs_offset += K0; rhs_offset += N0 * RHS_STEP_X * RHS_STEP_LOOP; @@ -2376,60 +1820,8 @@ __kernel void gemmlowp_mm_reshaped_only_rhs_t(IMAGE_DECLARATION(lhs), REPEAT_VAR_INIT_TO_CONST(8, uint, zout, 0); //uint zout0=0,zout1=0,zout2=0,... zout7=0; #if defined(REINTERPRET_OUTPUT_AS_3D) - // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension - // in order to take into account the presence of possible cross plane paddings - // - // | | - // | plane0 | - // | | - // |__________________| - // |******************| - // | cross_plane_pad | - // |******************| - // | | - // | plane1 | - // | | - // |__________________| - // The plane (zout) is calculated dividing M (y * M0) by HEIGHT_GEMM3D - zout0 = (0 + (uint)(y * (uint)M0)) / (uint)HEIGHT_GEMM3D; - zout0 = min((uint)(DEPTH_GEMM3D - 1), zout0); - zout0 *= (dst_cross_plane_pad * dst_stride_y); -#if M0 > 1 - zout1 = (1 + (uint)(y * (uint)M0)) / (uint)HEIGHT_GEMM3D; - zout1 = min((uint)(DEPTH_GEMM3D - 1), zout1); - zout1 *= (dst_cross_plane_pad * dst_stride_y); -#endif // M0 > 1 -#if M0 > 2 - zout2 = (2 + (uint)(y * (uint)M0)) / (uint)HEIGHT_GEMM3D; - zout2 = min((uint)(DEPTH_GEMM3D - 1), zout2); - zout2 *= (dst_cross_plane_pad * dst_stride_y); -#endif // M0 > 2 -#if M0 > 3 - zout3 = (3 + (uint)(y * (uint)M0)) / (uint)HEIGHT_GEMM3D; - zout3 = min((uint)(DEPTH_GEMM3D - 1), zout3); - zout3 *= (dst_cross_plane_pad * dst_stride_y); -#endif // M0 > 3 -#if M0 > 4 - zout4 = (4 + (uint)(y * (uint)M0)) / (uint)HEIGHT_GEMM3D; - zout4 = min((uint)(DEPTH_GEMM3D - 1), zout4); - zout4 *= (dst_cross_plane_pad * dst_stride_y); -#endif // M0 > 4 -#if M0 > 5 - zout5 = (5 + (uint)(y * (uint)M0)) / (uint)HEIGHT_GEMM3D; - zout5 = min((uint)(DEPTH_GEMM3D - 1), zout5); - zout5 *= (dst_cross_plane_pad * dst_stride_y); -#endif // M0 > 5 -#if M0 > 6 - zout6 
= (6 + (uint)(y * (uint)M0)) / (uint)HEIGHT_GEMM3D; - zout6 = min((uint)(DEPTH_GEMM3D - 1), zout6); - zout6 *= (dst_cross_plane_pad * dst_stride_y); -#endif // M0 > 6 -#if M0 > 7 - zout7 = (7 + (uint)(y * (uint)M0)) / (uint)HEIGHT_GEMM3D; - zout7 = min((uint)(DEPTH_GEMM3D - 1), zout7); - zout7 *= (dst_cross_plane_pad * dst_stride_y); -#endif // M0 > 7 + CALCULATE_Z_OFFSET(M0, uint, zout, y, HEIGHT_GEMM3D, DEPTH_GEMM3D, dst_cross_plane_pad, dst_stride_y); // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we // multiply dst_stride_z by DEPTH_GEMM3D @@ -2442,37 +1834,8 @@ __kernel void gemmlowp_mm_reshaped_only_rhs_t(IMAGE_DECLARATION(lhs), #endif // defined(REINTERPRET_OUTPUT_AS_3D) - // Store output block - VSTORE(N0) - (CONVERT_SAT(c0, VEC_DATA_TYPE(int, N0)), 0, (__global int *)(dst_addr + 0 * dst_stride_y + zout0)); -#if M0 > 1 - VSTORE(N0) - (CONVERT_SAT(c1, VEC_DATA_TYPE(int, N0)), 0, (__global int *)(dst_addr + 1 * dst_stride_y + zout1)); -#endif // M0 > 1 -#if M0 > 2 - VSTORE(N0) - (CONVERT_SAT(c2, VEC_DATA_TYPE(int, N0)), 0, (__global int *)(dst_addr + 2 * dst_stride_y + zout2)); -#endif // M0 > 2 -#if M0 > 3 - VSTORE(N0) - (CONVERT_SAT(c3, VEC_DATA_TYPE(int, N0)), 0, (__global int *)(dst_addr + 3 * dst_stride_y + zout3)); -#endif // M0 > 3 -#if M0 > 4 - VSTORE(N0) - (CONVERT_SAT(c4, VEC_DATA_TYPE(int, N0)), 0, (__global int *)(dst_addr + 4 * dst_stride_y + zout4)); -#endif // M0 > 4 -#if M0 > 5 - VSTORE(N0) - (CONVERT_SAT(c5, VEC_DATA_TYPE(int, N0)), 0, (__global int *)(dst_addr + 5 * dst_stride_y + zout5)); -#endif // M0 > 5 -#if M0 > 6 - VSTORE(N0) - (CONVERT_SAT(c6, VEC_DATA_TYPE(int, N0)), 0, (__global int *)(dst_addr + 6 * dst_stride_y + zout6)); -#endif // M0 > 6 -#if M0 > 7 - VSTORE(N0) - (CONVERT_SAT(c7, VEC_DATA_TYPE(int, N0)), 0, (__global int *)(dst_addr + 7 * dst_stride_y + zout7)); -#endif // M0 > 7 + // Convert and store output block + CONVERT_STORE_BLOCK(M0, N0, int, c, dst_addr, dst_stride_y, 
zout); #undef RHS_BLOCK_SIZE #undef RHS_OFFSET_X diff --git a/src/core/CL/cl_kernels/helpers.h b/src/core/CL/cl_kernels/helpers.h index e39dbc3e94..756c906e66 100644 --- a/src/core/CL/cl_kernels/helpers.h +++ b/src/core/CL/cl_kernels/helpers.h @@ -43,6 +43,8 @@ #define GPU_ARCH_MIDGARD 0x100 #define GPU_ARCH_BIFROST 0x200 +#define CONCAT(a, b) a##b + #define EXPAND(x) x #define CLAMP(x, min_val, max_val) min(max(x, min_val), max_val) diff --git a/src/core/CL/kernels/CLGEMMLowpMatrixMultiplyReshapedKernel.cpp b/src/core/CL/kernels/CLGEMMLowpMatrixMultiplyReshapedKernel.cpp index a8c1704d91..050b792c4e 100644 --- a/src/core/CL/kernels/CLGEMMLowpMatrixMultiplyReshapedKernel.cpp +++ b/src/core/CL/kernels/CLGEMMLowpMatrixMultiplyReshapedKernel.cpp @@ -214,7 +214,6 @@ void CLGEMMLowpMatrixMultiplyReshapedKernel::configure(const ICLTensor *input0, std::string kernel_name("gemmlowp_mm_reshaped_"); kernel_name += lhs_info.transpose ? "lhs_t_" : "lhs_nt_"; kernel_name += rhs_info.transpose ? "rhs_t" : "rhs_nt"; - kernel_name += dot8_supported(CLKernelLibrary::get().get_device()) ? "_dot8" : ""; // Create kernel _kernel = static_cast(CLKernelLibrary::get().create_kernel(kernel_name, build_opts.options())); @@ -222,6 +221,8 @@ void CLGEMMLowpMatrixMultiplyReshapedKernel::configure(const ICLTensor *input0, // Set config_id for enabling LWS tuning _config_id = kernel_name; _config_id += "_"; + _config_id += dot8_supported(CLKernelLibrary::get().get_device()) ? "_dot8" : ""; + _config_id += "_"; _config_id += (_reinterpret_output_as_3d ? 
"3do_" : ""); _config_id += support::cpp11::to_string(output->info()->dimension(1)); _config_id += "_"; diff --git a/src/core/CL/kernels/CLGEMMLowpMatrixMultiplyReshapedOnlyRHSKernel.cpp b/src/core/CL/kernels/CLGEMMLowpMatrixMultiplyReshapedOnlyRHSKernel.cpp index 923b9529fa..3ddeeaee41 100644 --- a/src/core/CL/kernels/CLGEMMLowpMatrixMultiplyReshapedOnlyRHSKernel.cpp +++ b/src/core/CL/kernels/CLGEMMLowpMatrixMultiplyReshapedOnlyRHSKernel.cpp @@ -229,6 +229,8 @@ void CLGEMMLowpMatrixMultiplyReshapedOnlyRHSKernel::configure(const ICLTensor *i // Set config_id for enabling LWS tuning _config_id = kernel_name; _config_id += "_"; + _config_id += dot8_supported(CLKernelLibrary::get().get_device()) ? "_dot8" : ""; + _config_id += "_"; _config_id += (_reinterpret_input_as_3d ? "3di_" : ""); _config_id += (_reinterpret_output_as_3d ? "3do_" : ""); _config_id += support::cpp11::to_string(output->info()->dimension(1)); -- cgit v1.2.1