aboutsummaryrefslogtreecommitdiff
path: root/src/core/CL/cl_kernels/gemm_helpers.h
diff options
context:
space:
mode:
Diffstat (limited to 'src/core/CL/cl_kernels/gemm_helpers.h')
-rw-r--r--src/core/CL/cl_kernels/gemm_helpers.h197
1 files changed, 188 insertions, 9 deletions
diff --git a/src/core/CL/cl_kernels/gemm_helpers.h b/src/core/CL/cl_kernels/gemm_helpers.h
index 54d38655a4..be72efa3b4 100644
--- a/src/core/CL/cl_kernels/gemm_helpers.h
+++ b/src/core/CL/cl_kernels/gemm_helpers.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2019-2020 Arm Limited.
+ * Copyright (c) 2019-2021 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -497,6 +497,185 @@
#define LOAD_TEXTURE2D(M0, N0, DATA_TYPE, BASENAME, IMG, X_COORD, Y_COORD, X_STEP_ROW, Y_STEP_ROW) LOAD_TEXTURE2D_STR(M0, N0, DATA_TYPE, BASENAME, IMG, X_COORD, Y_COORD, X_STEP_ROW, Y_STEP_ROW)
/** @} */ // end of group LOAD_TEXTURE2D
+/** Loads the rows from 0 to n-1 in the given variables (BASENAME0 to BASENAMEn-1) passing the Y index for each row to be loaded.
+ * @name LOAD_ROW_INDIRECT_n
+ *
+ * @param[in] N0 The number of columns to load
+ * @param[in] DATA_TYPE The data type of variables
+ * @param[in] BASENAME The basename of the destination variables for the loaded rows
+ * @param[in] PTR The base pointer
+ * @param[in] OFFSET The offset within a row
+ * @param[in] STRIDE_Y The stride value in y-axis direction
+ * @param[in] Y The y-axis offset vector
+ * @param[in] Y_MASK The y-axis mask vector. If 0, forces BASENAMEn to 0
+ * @{
+ */
+/* Row 0: declares BASENAME0 and, when Y_MASK0 != 0, loads it through VLOAD(N0) from (PTR) + (OFFSET) + Y0 * (STRIDE_Y); otherwise sets it to 0. PTR, OFFSET and STRIDE_Y are parenthesized so compound arguments keep their grouping. */
+#define LOAD_ROW_INDIRECT_1(N0, DATA_TYPE, BASENAME, PTR, OFFSET, STRIDE_Y, Y, Y_MASK) \
+ VEC_DATA_TYPE(DATA_TYPE, N0) \
+ BASENAME##0; \
+ if(Y_MASK##0 != 0) \
+ BASENAME##0 = VLOAD(N0)(0, (__global DATA_TYPE *)((PTR) + (OFFSET) + Y##0 * (STRIDE_Y))); \
+ else \
+ BASENAME##0 = 0;
+
+/* Rows 0-1: expands LOAD_ROW_INDIRECT_1 for the earlier row, then declares BASENAME1 and, when Y_MASK1 != 0, loads it through VLOAD(N0) from (PTR) + (OFFSET) + Y1 * (STRIDE_Y); otherwise sets it to 0. PTR, OFFSET and STRIDE_Y are parenthesized so compound arguments keep their grouping. */
+#define LOAD_ROW_INDIRECT_2(N0, DATA_TYPE, BASENAME, PTR, OFFSET, STRIDE_Y, Y, Y_MASK) \
+ LOAD_ROW_INDIRECT_1(N0, DATA_TYPE, BASENAME, PTR, OFFSET, STRIDE_Y, Y, Y_MASK) \
+ VEC_DATA_TYPE(DATA_TYPE, N0) \
+ BASENAME##1; \
+ if(Y_MASK##1 != 0) \
+ BASENAME##1 = VLOAD(N0)(0, (__global DATA_TYPE *)((PTR) + (OFFSET) + Y##1 * (STRIDE_Y))); \
+ else \
+ BASENAME##1 = 0;
+
+/* Rows 0-2: expands LOAD_ROW_INDIRECT_2 for the earlier rows, then declares BASENAME2 and, when Y_MASK2 != 0, loads it through VLOAD(N0) from (PTR) + (OFFSET) + Y2 * (STRIDE_Y); otherwise sets it to 0. PTR, OFFSET and STRIDE_Y are parenthesized so compound arguments keep their grouping. */
+#define LOAD_ROW_INDIRECT_3(N0, DATA_TYPE, BASENAME, PTR, OFFSET, STRIDE_Y, Y, Y_MASK) \
+ LOAD_ROW_INDIRECT_2(N0, DATA_TYPE, BASENAME, PTR, OFFSET, STRIDE_Y, Y, Y_MASK) \
+ VEC_DATA_TYPE(DATA_TYPE, N0) \
+ BASENAME##2; \
+ if(Y_MASK##2 != 0) \
+ BASENAME##2 = VLOAD(N0)(0, (__global DATA_TYPE *)((PTR) + (OFFSET) + Y##2 * (STRIDE_Y))); \
+ else \
+ BASENAME##2 = 0;
+
+/* Rows 0-3: expands LOAD_ROW_INDIRECT_3 for the earlier rows, then declares BASENAME3 and, when Y_MASK3 != 0, loads it through VLOAD(N0) from (PTR) + (OFFSET) + Y3 * (STRIDE_Y); otherwise sets it to 0. PTR, OFFSET and STRIDE_Y are parenthesized so compound arguments keep their grouping. */
+#define LOAD_ROW_INDIRECT_4(N0, DATA_TYPE, BASENAME, PTR, OFFSET, STRIDE_Y, Y, Y_MASK) \
+ LOAD_ROW_INDIRECT_3(N0, DATA_TYPE, BASENAME, PTR, OFFSET, STRIDE_Y, Y, Y_MASK) \
+ VEC_DATA_TYPE(DATA_TYPE, N0) \
+ BASENAME##3; \
+ if(Y_MASK##3 != 0) \
+ BASENAME##3 = VLOAD(N0)(0, (__global DATA_TYPE *)((PTR) + (OFFSET) + Y##3 * (STRIDE_Y))); \
+ else \
+ BASENAME##3 = 0;
+
+/* Rows 0-4: expands LOAD_ROW_INDIRECT_4 for the earlier rows, then declares BASENAME4 and, when Y_MASK4 != 0, loads it through VLOAD(N0) from (PTR) + (OFFSET) + Y4 * (STRIDE_Y); otherwise sets it to 0. PTR, OFFSET and STRIDE_Y are parenthesized so compound arguments keep their grouping. */
+#define LOAD_ROW_INDIRECT_5(N0, DATA_TYPE, BASENAME, PTR, OFFSET, STRIDE_Y, Y, Y_MASK) \
+ LOAD_ROW_INDIRECT_4(N0, DATA_TYPE, BASENAME, PTR, OFFSET, STRIDE_Y, Y, Y_MASK) \
+ VEC_DATA_TYPE(DATA_TYPE, N0) \
+ BASENAME##4; \
+ if(Y_MASK##4 != 0) \
+ BASENAME##4 = VLOAD(N0)(0, (__global DATA_TYPE *)((PTR) + (OFFSET) + Y##4 * (STRIDE_Y))); \
+ else \
+ BASENAME##4 = 0;
+
+/* Rows 0-5: expands LOAD_ROW_INDIRECT_5 for the earlier rows, then declares BASENAME5 and, when Y_MASK5 != 0, loads it through VLOAD(N0) from (PTR) + (OFFSET) + Y5 * (STRIDE_Y); otherwise sets it to 0. PTR, OFFSET and STRIDE_Y are parenthesized so compound arguments keep their grouping. */
+#define LOAD_ROW_INDIRECT_6(N0, DATA_TYPE, BASENAME, PTR, OFFSET, STRIDE_Y, Y, Y_MASK) \
+ LOAD_ROW_INDIRECT_5(N0, DATA_TYPE, BASENAME, PTR, OFFSET, STRIDE_Y, Y, Y_MASK) \
+ VEC_DATA_TYPE(DATA_TYPE, N0) \
+ BASENAME##5; \
+ if(Y_MASK##5 != 0) \
+ BASENAME##5 = VLOAD(N0)(0, (__global DATA_TYPE *)((PTR) + (OFFSET) + Y##5 * (STRIDE_Y))); \
+ else \
+ BASENAME##5 = 0;
+
+/* Rows 0-6: expands LOAD_ROW_INDIRECT_6 for the earlier rows, then declares BASENAME6 and, when Y_MASK6 != 0, loads it through VLOAD(N0) from (PTR) + (OFFSET) + Y6 * (STRIDE_Y); otherwise sets it to 0. PTR, OFFSET and STRIDE_Y are parenthesized so compound arguments keep their grouping. */
+#define LOAD_ROW_INDIRECT_7(N0, DATA_TYPE, BASENAME, PTR, OFFSET, STRIDE_Y, Y, Y_MASK) \
+ LOAD_ROW_INDIRECT_6(N0, DATA_TYPE, BASENAME, PTR, OFFSET, STRIDE_Y, Y, Y_MASK) \
+ VEC_DATA_TYPE(DATA_TYPE, N0) \
+ BASENAME##6; \
+ if(Y_MASK##6 != 0) \
+ BASENAME##6 = VLOAD(N0)(0, (__global DATA_TYPE *)((PTR) + (OFFSET) + Y##6 * (STRIDE_Y))); \
+ else \
+ BASENAME##6 = 0;
+
+/* Rows 0-7: expands LOAD_ROW_INDIRECT_7 for the earlier rows, then declares BASENAME7 and, when Y_MASK7 != 0, loads it through VLOAD(N0) from (PTR) + (OFFSET) + Y7 * (STRIDE_Y); otherwise sets it to 0. PTR, OFFSET and STRIDE_Y are parenthesized so compound arguments keep their grouping. */
+#define LOAD_ROW_INDIRECT_8(N0, DATA_TYPE, BASENAME, PTR, OFFSET, STRIDE_Y, Y, Y_MASK) \
+ LOAD_ROW_INDIRECT_7(N0, DATA_TYPE, BASENAME, PTR, OFFSET, STRIDE_Y, Y, Y_MASK) \
+ VEC_DATA_TYPE(DATA_TYPE, N0) \
+ BASENAME##7; \
+ if(Y_MASK##7 != 0) \
+ BASENAME##7 = VLOAD(N0)(0, (__global DATA_TYPE *)((PTR) + (OFFSET) + Y##7 * (STRIDE_Y))); \
+ else \
+ BASENAME##7 = 0;
+
+/* Rows 0-8: expands LOAD_ROW_INDIRECT_8 for the earlier rows, then declares BASENAME8 and, when Y_MASK8 != 0, loads it through VLOAD(N0) from (PTR) + (OFFSET) + Y8 * (STRIDE_Y); otherwise sets it to 0. PTR, OFFSET and STRIDE_Y are parenthesized so compound arguments keep their grouping. */
+#define LOAD_ROW_INDIRECT_9(N0, DATA_TYPE, BASENAME, PTR, OFFSET, STRIDE_Y, Y, Y_MASK) \
+ LOAD_ROW_INDIRECT_8(N0, DATA_TYPE, BASENAME, PTR, OFFSET, STRIDE_Y, Y, Y_MASK) \
+ VEC_DATA_TYPE(DATA_TYPE, N0) \
+ BASENAME##8; \
+ if(Y_MASK##8 != 0) \
+ BASENAME##8 = VLOAD(N0)(0, (__global DATA_TYPE *)((PTR) + (OFFSET) + Y##8 * (STRIDE_Y))); \
+ else \
+ BASENAME##8 = 0;
+
+/* Rows 0-9: expands LOAD_ROW_INDIRECT_9 for the earlier rows, then declares BASENAME9 and, when Y_MASK9 != 0, loads it through VLOAD(N0) from (PTR) + (OFFSET) + Y9 * (STRIDE_Y); otherwise sets it to 0. PTR, OFFSET and STRIDE_Y are parenthesized so compound arguments keep their grouping. */
+#define LOAD_ROW_INDIRECT_10(N0, DATA_TYPE, BASENAME, PTR, OFFSET, STRIDE_Y, Y, Y_MASK) \
+ LOAD_ROW_INDIRECT_9(N0, DATA_TYPE, BASENAME, PTR, OFFSET, STRIDE_Y, Y, Y_MASK) \
+ VEC_DATA_TYPE(DATA_TYPE, N0) \
+ BASENAME##9; \
+ if(Y_MASK##9 != 0) \
+ BASENAME##9 = VLOAD(N0)(0, (__global DATA_TYPE *)((PTR) + (OFFSET) + Y##9 * (STRIDE_Y))); \
+ else \
+ BASENAME##9 = 0;
+
+/* Rows 0-10 (row 10 uses suffix A): expands LOAD_ROW_INDIRECT_10 for the earlier rows, then declares BASENAMEA and, when Y_MASKA != 0, loads it through VLOAD(N0) from (PTR) + (OFFSET) + YA * (STRIDE_Y); otherwise sets it to 0. PTR, OFFSET and STRIDE_Y are parenthesized so compound arguments keep their grouping. */
+#define LOAD_ROW_INDIRECT_11(N0, DATA_TYPE, BASENAME, PTR, OFFSET, STRIDE_Y, Y, Y_MASK) \
+ LOAD_ROW_INDIRECT_10(N0, DATA_TYPE, BASENAME, PTR, OFFSET, STRIDE_Y, Y, Y_MASK) \
+ VEC_DATA_TYPE(DATA_TYPE, N0) \
+ BASENAME##A; \
+ if(Y_MASK##A != 0) \
+ BASENAME##A = VLOAD(N0)(0, (__global DATA_TYPE *)((PTR) + (OFFSET) + Y##A * (STRIDE_Y))); \
+ else \
+ BASENAME##A = 0;
+
+/* Rows 0-11 (row 11 uses suffix B): expands LOAD_ROW_INDIRECT_11 for the earlier rows, then declares BASENAMEB and, when Y_MASKB != 0, loads it through VLOAD(N0) from (PTR) + (OFFSET) + YB * (STRIDE_Y); otherwise sets it to 0. PTR, OFFSET and STRIDE_Y are parenthesized so compound arguments keep their grouping. */
+#define LOAD_ROW_INDIRECT_12(N0, DATA_TYPE, BASENAME, PTR, OFFSET, STRIDE_Y, Y, Y_MASK) \
+ LOAD_ROW_INDIRECT_11(N0, DATA_TYPE, BASENAME, PTR, OFFSET, STRIDE_Y, Y, Y_MASK) \
+ VEC_DATA_TYPE(DATA_TYPE, N0) \
+ BASENAME##B; \
+ if(Y_MASK##B != 0) \
+ BASENAME##B = VLOAD(N0)(0, (__global DATA_TYPE *)((PTR) + (OFFSET) + Y##B * (STRIDE_Y))); \
+ else \
+ BASENAME##B = 0;
+
+/* Rows 0-12 (row 12 uses suffix C): expands LOAD_ROW_INDIRECT_12 for the earlier rows, then declares BASENAMEC and, when Y_MASKC != 0, loads it through VLOAD(N0) from (PTR) + (OFFSET) + YC * (STRIDE_Y); otherwise sets it to 0. PTR, OFFSET and STRIDE_Y are parenthesized so compound arguments keep their grouping. */
+#define LOAD_ROW_INDIRECT_13(N0, DATA_TYPE, BASENAME, PTR, OFFSET, STRIDE_Y, Y, Y_MASK) \
+ LOAD_ROW_INDIRECT_12(N0, DATA_TYPE, BASENAME, PTR, OFFSET, STRIDE_Y, Y, Y_MASK) \
+ VEC_DATA_TYPE(DATA_TYPE, N0) \
+ BASENAME##C; \
+ if(Y_MASK##C != 0) \
+ BASENAME##C = VLOAD(N0)(0, (__global DATA_TYPE *)((PTR) + (OFFSET) + Y##C * (STRIDE_Y))); \
+ else \
+ BASENAME##C = 0;
+
+/* Rows 0-13 (row 13 uses suffix D): expands LOAD_ROW_INDIRECT_13 for the earlier rows, then declares BASENAMED and, when Y_MASKD != 0, loads it through VLOAD(N0) from (PTR) + (OFFSET) + YD * (STRIDE_Y); otherwise sets it to 0. PTR, OFFSET and STRIDE_Y are parenthesized so compound arguments keep their grouping. */
+#define LOAD_ROW_INDIRECT_14(N0, DATA_TYPE, BASENAME, PTR, OFFSET, STRIDE_Y, Y, Y_MASK) \
+ LOAD_ROW_INDIRECT_13(N0, DATA_TYPE, BASENAME, PTR, OFFSET, STRIDE_Y, Y, Y_MASK) \
+ VEC_DATA_TYPE(DATA_TYPE, N0) \
+ BASENAME##D; \
+ if(Y_MASK##D != 0) \
+ BASENAME##D = VLOAD(N0)(0, (__global DATA_TYPE *)((PTR) + (OFFSET) + Y##D * (STRIDE_Y))); \
+ else \
+ BASENAME##D = 0;
+
+/* Rows 0-14 (row 14 uses suffix E): expands LOAD_ROW_INDIRECT_14 for the earlier rows, then declares BASENAMEE and, when Y_MASKE != 0, loads it through VLOAD(N0) from (PTR) + (OFFSET) + YE * (STRIDE_Y); otherwise sets it to 0. PTR, OFFSET and STRIDE_Y are parenthesized so compound arguments keep their grouping. */
+#define LOAD_ROW_INDIRECT_15(N0, DATA_TYPE, BASENAME, PTR, OFFSET, STRIDE_Y, Y, Y_MASK) \
+ LOAD_ROW_INDIRECT_14(N0, DATA_TYPE, BASENAME, PTR, OFFSET, STRIDE_Y, Y, Y_MASK) \
+ VEC_DATA_TYPE(DATA_TYPE, N0) \
+ BASENAME##E; \
+ if(Y_MASK##E != 0) \
+ BASENAME##E = VLOAD(N0)(0, (__global DATA_TYPE *)((PTR) + (OFFSET) + Y##E * (STRIDE_Y))); \
+ else \
+ BASENAME##E = 0;
+
+/* Rows 0-15 (row 15 uses suffix F): expands LOAD_ROW_INDIRECT_15 for the earlier rows, then declares BASENAMEF and, when Y_MASKF != 0, loads it through VLOAD(N0) from (PTR) + (OFFSET) + YF * (STRIDE_Y); otherwise sets it to 0. PTR, OFFSET and STRIDE_Y are parenthesized so compound arguments keep their grouping. */
+#define LOAD_ROW_INDIRECT_16(N0, DATA_TYPE, BASENAME, PTR, OFFSET, STRIDE_Y, Y, Y_MASK) \
+ LOAD_ROW_INDIRECT_15(N0, DATA_TYPE, BASENAME, PTR, OFFSET, STRIDE_Y, Y, Y_MASK) \
+ VEC_DATA_TYPE(DATA_TYPE, N0) \
+ BASENAME##F; \
+ if(Y_MASK##F != 0) \
+ BASENAME##F = VLOAD(N0)(0, (__global DATA_TYPE *)((PTR) + (OFFSET) + Y##F * (STRIDE_Y))); \
+ else \
+ BASENAME##F = 0;
+/** @} */ // end of group LOAD_ROW_INDIRECT_n
+
+/** Load blocks (consecutive rows and columns) with Y offset.
+ * @name LOAD_BLOCK_INDIRECT
+ *
+ * Supported cases are M0=1,2,3,...,16 and N0=1,2,3,4,8,16
+ * The data to load is expected to have consecutive names for each row.
+ * E.g., for M0=3, and BASENAME=c, the expected data is c0, c1 and c2.
+ * The Y offsets and Y masks are expected to have consecutive names.
+ * E.g., for M0=3, Y=yi and Y_MASK=ym, the expected Y offsets are yi0, yi1 and yi2 and the expected Y masks are ym0, ym1 and ym2.
+ *
+ * @param[in] M0 The number of consecutive rows
+ * @param[in] N0 The number of consecutive columns
+ * @param[in] DATA_TYPE The data type of the target
+ * @param[in] BASENAME The basename of the result variables
+ * @param[in] PTR The base pointer for the data
+ * @param[in] OFFSET The offset within a row
+ * @param[in] STRIDE_Y The stride in y-axis direction
+ * @param[in] Y The y-axis offset vector
+ * @param[in] Y_MASK The y-axis mask vector. If 0, forces BASENAMEn to 0
+ * @{
+ */
+/* Selects the row-loading macro by pasting M0 onto LOAD_ROW_INDIRECT_. M0 is an operand of ## here, so it is not macro-expanded at this level. */
+#define LOAD_BLOCK_INDIRECT_STR(M0, N0, DATA_TYPE, BASENAME, PTR, OFFSET, STRIDE_Y, Y, Y_MASK) LOAD_ROW_INDIRECT_##M0(N0, DATA_TYPE, BASENAME, PTR, OFFSET, STRIDE_Y, Y, Y_MASK)
+/* Entry point: forwards every argument to LOAD_BLOCK_INDIRECT_STR, so an M0 supplied as a macro is expanded before the token pasting happens. */
+#define LOAD_BLOCK_INDIRECT(M0, N0, DATA_TYPE, BASENAME, PTR, OFFSET, STRIDE_Y, Y, Y_MASK) LOAD_BLOCK_INDIRECT_STR(M0, N0, DATA_TYPE, BASENAME, PTR, OFFSET, STRIDE_Y, Y, Y_MASK)
+/** @} */ // end of group LOAD_BLOCK_INDIRECT
+
/** Loads the elements from 0 to n-1 in the given variables (BASENAME0 to BASENAMEn-1).
* @name LOAD_ELEMENT_n
*
@@ -624,49 +803,49 @@
* @{
*/
#define CALCULATE_Z_OFFSET_1(M0, DATA_TYPE, Z, Y, HEIGHT_GEMM3D, DEPTH_GEMM3D, CROSS_PLANE_PAD, STRIDE_Y) \
- Z##0 = (0 + (DATA_TYPE)(Y)) / (DATA_TYPE)HEIGHT_GEMM3D; \
+ Z##0 = (0 + (DATA_TYPE)(Y)) / (DATA_TYPE)HEIGHT_GEMM3D; \
Z##0 = min((DATA_TYPE)(DEPTH_GEMM3D - 1), Z##0); \
Z##0 *= (CROSS_PLANE_PAD * STRIDE_Y);
#define CALCULATE_Z_OFFSET_2(M0, DATA_TYPE, Z, Y, HEIGHT_GEMM3D, DEPTH_GEMM3D, CROSS_PLANE_PAD, STRIDE_Y) \
CALCULATE_Z_OFFSET_1(M0, DATA_TYPE, Z, Y, HEIGHT_GEMM3D, DEPTH_GEMM3D, CROSS_PLANE_PAD, STRIDE_Y) \
- Z##1 = (1 + (DATA_TYPE)(Y)) / (DATA_TYPE)HEIGHT_GEMM3D; \
+ Z##1 = (1 + (DATA_TYPE)(Y)) / (DATA_TYPE)HEIGHT_GEMM3D; \
Z##1 = min((DATA_TYPE)(DEPTH_GEMM3D - 1), Z##1); \
Z##1 *= (CROSS_PLANE_PAD * STRIDE_Y);
#define CALCULATE_Z_OFFSET_3(M0, DATA_TYPE, Z, Y, HEIGHT_GEMM3D, DEPTH_GEMM3D, CROSS_PLANE_PAD, STRIDE_Y) \
CALCULATE_Z_OFFSET_2(M0, DATA_TYPE, Z, Y, HEIGHT_GEMM3D, DEPTH_GEMM3D, CROSS_PLANE_PAD, STRIDE_Y) \
- Z##2 = (2 + (DATA_TYPE)(Y)) / (DATA_TYPE)HEIGHT_GEMM3D; \
+ Z##2 = (2 + (DATA_TYPE)(Y)) / (DATA_TYPE)HEIGHT_GEMM3D; \
Z##2 = min((DATA_TYPE)(DEPTH_GEMM3D - 1), Z##2); \
Z##2 *= (CROSS_PLANE_PAD * STRIDE_Y);
#define CALCULATE_Z_OFFSET_4(M0, DATA_TYPE, Z, Y, HEIGHT_GEMM3D, DEPTH_GEMM3D, CROSS_PLANE_PAD, STRIDE_Y) \
CALCULATE_Z_OFFSET_3(M0, DATA_TYPE, Z, Y, HEIGHT_GEMM3D, DEPTH_GEMM3D, CROSS_PLANE_PAD, STRIDE_Y) \
- Z##3 = (3 + (DATA_TYPE)(Y)) / (DATA_TYPE)HEIGHT_GEMM3D; \
+ Z##3 = (3 + (DATA_TYPE)(Y)) / (DATA_TYPE)HEIGHT_GEMM3D; \
Z##3 = min((DATA_TYPE)(DEPTH_GEMM3D - 1), Z##3); \
Z##3 *= (CROSS_PLANE_PAD * STRIDE_Y);
#define CALCULATE_Z_OFFSET_5(M0, DATA_TYPE, Z, Y, HEIGHT_GEMM3D, DEPTH_GEMM3D, CROSS_PLANE_PAD, STRIDE_Y) \
CALCULATE_Z_OFFSET_4(M0, DATA_TYPE, Z, Y, HEIGHT_GEMM3D, DEPTH_GEMM3D, CROSS_PLANE_PAD, STRIDE_Y) \
- Z##4 = (4 + (DATA_TYPE)(Y)) / (DATA_TYPE)HEIGHT_GEMM3D; \
+ Z##4 = (4 + (DATA_TYPE)(Y)) / (DATA_TYPE)HEIGHT_GEMM3D; \
Z##4 = min((DATA_TYPE)(DEPTH_GEMM3D - 1), Z##4); \
Z##4 *= (CROSS_PLANE_PAD * STRIDE_Y);
#define CALCULATE_Z_OFFSET_6(M0, DATA_TYPE, Z, Y, HEIGHT_GEMM3D, DEPTH_GEMM3D, CROSS_PLANE_PAD, STRIDE_Y) \
CALCULATE_Z_OFFSET_5(M0, DATA_TYPE, Z, Y, HEIGHT_GEMM3D, DEPTH_GEMM3D, CROSS_PLANE_PAD, STRIDE_Y) \
- Z##5 = (5 + (DATA_TYPE)(Y)) / (DATA_TYPE)HEIGHT_GEMM3D; \
+ Z##5 = (5 + (DATA_TYPE)(Y)) / (DATA_TYPE)HEIGHT_GEMM3D; \
Z##5 = min((DATA_TYPE)(DEPTH_GEMM3D - 1), Z##5); \
Z##5 *= (CROSS_PLANE_PAD * STRIDE_Y);
#define CALCULATE_Z_OFFSET_7(M0, DATA_TYPE, Z, Y, HEIGHT_GEMM3D, DEPTH_GEMM3D, CROSS_PLANE_PAD, STRIDE_Y) \
CALCULATE_Z_OFFSET_6(M0, DATA_TYPE, Z, Y, HEIGHT_GEMM3D, DEPTH_GEMM3D, CROSS_PLANE_PAD, STRIDE_Y) \
- Z##6 = (6 + (DATA_TYPE)(Y)) / (DATA_TYPE)HEIGHT_GEMM3D; \
+ Z##6 = (6 + (DATA_TYPE)(Y)) / (DATA_TYPE)HEIGHT_GEMM3D; \
Z##6 = min((DATA_TYPE)(DEPTH_GEMM3D - 1), Z##6); \
Z##6 *= (CROSS_PLANE_PAD * STRIDE_Y);
#define CALCULATE_Z_OFFSET_8(M0, DATA_TYPE, Z, Y, HEIGHT_GEMM3D, DEPTH_GEMM3D, CROSS_PLANE_PAD, STRIDE_Y) \
CALCULATE_Z_OFFSET_7(M0, DATA_TYPE, Z, Y, HEIGHT_GEMM3D, DEPTH_GEMM3D, CROSS_PLANE_PAD, STRIDE_Y) \
- Z##7 = (7 + (DATA_TYPE)(Y)) / (DATA_TYPE)HEIGHT_GEMM3D; \
+ Z##7 = (7 + (DATA_TYPE)(Y)) / (DATA_TYPE)HEIGHT_GEMM3D; \
Z##7 = min((DATA_TYPE)(DEPTH_GEMM3D - 1), Z##7); \
Z##7 *= (CROSS_PLANE_PAD * STRIDE_Y);