aboutsummaryrefslogtreecommitdiff
path: root/src/core/CL/cl_kernels/gemm_helpers.h
diff options
context:
space:
mode:
Diffstat (limited to 'src/core/CL/cl_kernels/gemm_helpers.h')
-rw-r--r--src/core/CL/cl_kernels/gemm_helpers.h123
1 files changed, 75 insertions, 48 deletions
diff --git a/src/core/CL/cl_kernels/gemm_helpers.h b/src/core/CL/cl_kernels/gemm_helpers.h
index 5ada788d49..5b6ad17ce0 100644
--- a/src/core/CL/cl_kernels/gemm_helpers.h
+++ b/src/core/CL/cl_kernels/gemm_helpers.h
@@ -791,29 +791,28 @@
* @param[in] Z The offset in z-axis direction
* @param[in] PARTIAL_STORE_M0 The partial size in y, for partial blocks. Supported range: [1, @p M0)
* @param[in] PARTIAL_STORE_N0 The partial size in x, for partial blocks. Supported range: [1, @p N0)
- * @param[in] M Total number of rows. Used to detect if current block is at the boundary in y.
* @param[in] N Total number of columns. Used to detect if current block is at the boundary in x.
* @param[in] y Global id of current block in y. Used to detect if current block is at the boundary in y.
* @param[in] x Global id of current block in x. Used to detect if current block is at the boundary in x.
*/
-#define STORE_BLOCK_PARTIAL_IN_X_AND_Y(M0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z, PARTIAL_STORE_M0, PARTIAL_STORE_N0, M, N, y, x) \
- bool at_y_boundary = (y + 1) * M0 >= M; \
- bool at_x_boundary = (x + 1) * N0 >= N; \
- if(!at_y_boundary && !at_x_boundary) \
- { \
- STORE_BLOCK_PARTIAL(M0, N0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z); \
- } \
- else if(at_y_boundary && !at_x_boundary) \
- { \
- STORE_BLOCK_PARTIAL(PARTIAL_STORE_M0, N0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z); \
- } \
- else if(!at_y_boundary && at_x_boundary) \
- { \
- STORE_BLOCK_PARTIAL(M0, PARTIAL_STORE_N0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z); \
- } \
- else \
- { \
- STORE_BLOCK_PARTIAL(PARTIAL_STORE_M0, PARTIAL_STORE_N0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z); \
+#define STORE_BLOCK_PARTIAL_IN_X_AND_Y(M0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z, PARTIAL_STORE_M0, PARTIAL_STORE_N0, N, y, x) \
+ bool at_y_boundary = y == 0; \
+ bool at_x_boundary = (x + 1) * N0 >= N; \
+ if(!at_y_boundary && !at_x_boundary) \
+ { \
+ STORE_BLOCK_PARTIAL(M0, N0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z); \
+ } \
+ else if(at_y_boundary && !at_x_boundary) \
+ { \
+ STORE_BLOCK_PARTIAL(PARTIAL_STORE_M0, N0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z); \
+ } \
+ else if(!at_y_boundary && at_x_boundary) \
+ { \
+ STORE_BLOCK_PARTIAL(M0, PARTIAL_STORE_N0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z); \
+ } \
+ else \
+ { \
+ STORE_BLOCK_PARTIAL(PARTIAL_STORE_M0, PARTIAL_STORE_N0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z); \
}
/** Store a block that can only be partial in x but not y.
*
@@ -862,18 +861,17 @@
* @param[in] STRIDE_Y The stride value in y-axis direction
* @param[in] Z The offset in z-axis direction
* @param[in] PARTIAL_STORE_M0 The partial size in y, for partial blocks. Supported range: [1, @p M0)
- * @param[in] M Total number of rows. Used to detect if current block is at the boundary in y.
* @param[in] y Global id of current block in y. Used to detect if current block is at the boundary in y.
*/
-#define STORE_BLOCK_PARTIAL_IN_Y(M0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z, PARTIAL_STORE_M0, M, y) \
- bool at_y_boundary = (y + 1) * M0 >= M; \
- if(!at_y_boundary) \
- { \
- STORE_BLOCK_PARTIAL(M0, N0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z); \
- } \
- else \
- { \
- STORE_BLOCK_PARTIAL(PARTIAL_STORE_M0, N0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z); \
+#define STORE_BLOCK_PARTIAL_IN_Y(M0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z, PARTIAL_STORE_M0, y) \
+ bool at_y_boundary = y == 0; \
+ if(!at_y_boundary) \
+ { \
+ STORE_BLOCK_PARTIAL(M0, N0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z); \
+ } \
+ else \
+ { \
+ STORE_BLOCK_PARTIAL(PARTIAL_STORE_M0, N0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z); \
}
/** @} */ // end of group STORE_BLOCK_PARTIAL
@@ -1484,10 +1482,16 @@
#if defined(PARTIAL_STORE_M0) && defined(PARTIAL_STORE_N0)
-/** Store a block in a boundary-aware way that does not require any padding
- * Store a block of the shape M0xN0 in a boundary-aware way that doesn't require any padding for partial blocks
+/** Boundary-aware GEMM block store
* @name STORE_BLOCK_BOUNDARY_AWARE
- *
+ * This macro assumes the following schemes to achieve boundary-awareness:
+ * - Overlapping load in Y axis from lhs tensor. This implies lhs has no padding along y dim.
+ * - Non-Overlapping(normal) load from rhs tensor. This imples rhs can have paddings.
+ * - Overlapping load in Y axis from bias tensor. This implies rhs has no padding along y dim.
+ * The macro then ensures that the dst tensor can be stored without any paddings in both x and y dim.
+ *
+ * In the y dimension, we place the partial blocks **at the beginning** while in the x dimension, we place the partial
+ * blocks **at the end**.
* Say, the dst tensor is of shape MxN and we have M0 and N0 as the block size, this is how we define "partial blocks"/
* "boundary block" (we use the 2 terms "partial blocks" and "boundary blocks" interchangeably) and its various parameters:
*
@@ -1495,20 +1499,19 @@
* | |<------------------------------N-------------------------->|
* y |<--------------N0------------->|<----PARTIAL_STORE_N0----->|
* | -------------#############################################################
- * * | | | |...........................|
- * y == 0 | M0 | Non-boundary block |....Boundary block in x....|
- * | | | |...........................|
- * M --#############################################################
- * | | |...............................|...........................|
- * y == 1 | PAR_..._M0 |......Boundary block in y......|.Boundary block in x and y.|
+ * * | | |...............................|...........................|
+ * y == 0 | PAR_..._M0 |......Boundary block in y......|.Boundary block in x and y.|
* | | |...............................|...........................|
+ * M --#############################################################
+ * | | | |...........................|
+ * y == 1 | M0 | Non-boundary block |....Boundary block in x....|
+ * | | | |...........................|
* |------------#############################################################
*
* Then @p PARTIAL_STORE_M0 = M % M0 and @p PARTIAL_STORE_N0 = N % N0
*
* @note in cases @p PARTIAL_STORE_N0 != 1, 2, 3, 4, 8, 16, extra vstore(s) will be invoked, thus incurring small performance penalty.
*
- * This method ensures that in the end the dst tensor is stored without requirements for paddings.
* It automatically detects if a giving M,N,M0,N0 combination can yield partial blocks in either X and Y dimension,
* and select corresponding store methods such that the boundary detection logic is only added when needed.
*
@@ -1526,7 +1529,6 @@
* @param[in] Z The offset in z-axis direction
* @param[in] PARTIAL_STORE_M0 The partial size in y, for partial blocks. Supported: [0, @p M0)
* @param[in] PARTIAL_STORE_N0 The partial size in x, for partial blocks. Supported: [0, @p N0)
- * @param[in] M Total number of rows. Used to detect if current block is at the boundary in y.
* @param[in] N Total number of columns. Used to detect if current block is at the boundary in x.
* @param[in] y Global id of current block in y. Used to detect if current block is at the boundary in y.
* @param[in] x Global id of current block in x. Used to detect if current block is at the boundary in x.
@@ -1534,30 +1536,55 @@
*/
#if PARTIAL_STORE_M0 == 0 && PARTIAL_STORE_N0 == 0
// Case1: No partial blocks in either x or y
-#define STORE_BLOCK_BOUNDARY_AWARE(M0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z, PARTIAL_STORE_M0, PARTIAL_STORE_N0, M, N, y, x) \
+#define STORE_BLOCK_BOUNDARY_AWARE(M0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z, PARTIAL_STORE_M0, PARTIAL_STORE_N0, N, y, x) \
STORE_BLOCK(M0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)
#elif PARTIAL_STORE_M0 > 0 && PARTIAL_STORE_N0 == 0
// Case2: Partial blocks in y
-#define STORE_BLOCK_BOUNDARY_AWARE(M0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z, PARTIAL_STORE_M0, PARTIAL_STORE_N0, M, N, y, x) \
- STORE_BLOCK_PARTIAL_IN_Y(M0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z, PARTIAL_STORE_M0, M, y)
+#define STORE_BLOCK_BOUNDARY_AWARE(M0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z, PARTIAL_STORE_M0, PARTIAL_STORE_N0, N, y, x) \
+ STORE_BLOCK_PARTIAL_IN_Y(M0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z, PARTIAL_STORE_M0, y)
#elif PARTIAL_STORE_M0 == 0 && PARTIAL_STORE_N0 > 0
// Case3: Partial blocks in x
-#define STORE_BLOCK_BOUNDARY_AWARE(M0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z, PARTIAL_STORE_M0, PARTIAL_STORE_N0, M, N, y, x) \
+#define STORE_BLOCK_BOUNDARY_AWARE(M0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z, PARTIAL_STORE_M0, PARTIAL_STORE_N0, N, y, x) \
STORE_BLOCK_PARTIAL_IN_X(M0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z, PARTIAL_STORE_N0, N, x)
#else // PARTIAL_STORE_M0 == 0 && PARTIAL_STORE_N0 == 0
// Case4: Partial blocks in both x and y
-#define STORE_BLOCK_BOUNDARY_AWARE(M0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z, PARTIAL_STORE_M0, PARTIAL_STORE_N0, M, N, y, x) \
- STORE_BLOCK_PARTIAL_IN_X_AND_Y(M0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z, PARTIAL_STORE_M0, PARTIAL_STORE_N0, M, N, y, x)
+#define STORE_BLOCK_BOUNDARY_AWARE(M0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z, PARTIAL_STORE_M0, PARTIAL_STORE_N0, N, y, x) \
+ STORE_BLOCK_PARTIAL_IN_X_AND_Y(M0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z, PARTIAL_STORE_M0, PARTIAL_STORE_N0, N, y, x)
#endif // PARTIAL_STORE_M0 == 0 && PARTIAL_STORE_N0 == 0
#else // defined(PARTIAL_STORE_M0) && defined(PARTIAL_STORE_N0)
-#define STORE_BLOCK_BOUNDARY_AWARE(M0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z, PARTIAL_STORE_M0, PARTIAL_STORE_N0, M, N, y, x) \
+#define STORE_BLOCK_BOUNDARY_AWARE(M0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z, PARTIAL_STORE_M0, PARTIAL_STORE_N0, N, y, x) \
STORE_BLOCK(M0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)
#endif // defined(PARTIAL_STORE_M0) && defined(PARTIAL_STORE_N0)
-/** @} */ // end of group STORE_BLOCK_BOUNDARY_AWARE \ No newline at end of file
+/** @} */ // end of group STORE_BLOCK_BOUNDARY_AWARE
+
+#if defined(PARTIAL_STORE_M0)
+/** Compute the start m0 row (LHS, BIAS and DST) in a boundary-aware way so as to avoid padding
+ * @name COMPUTE_M0_START_ROW
+ * If there're any partial blocks in y dimension, they are placed at the beginning of the rows.
+ * This shift amount is added to all rows such that the partial block (at the beginning) overlaps with the subsequent
+ * blocks in the y dimension to avoid any padding.
+ * EG: M0=4, PARTIAL_STORE_M0=1:
+ * | Non-overlapping | +M0_ROW_SHIFT (Overlapping)
+ * block 0 (partial)| start row = 0 | start row = 0
+ * block 1 (full) | start row = 4 | start row = 1
+ * block 2 (full) | start row = 8 | start row = 5
+ *
+ * @param[in] y Global id of current block in y.
+ * @param[in] M0 The number of rows to store, for non-partial blocks. Supported: 1-16
+ * @param[in] PARTIAL_STORE_M0 The partial size in y, for partial blocks. Supported: [0, @p M0)
+ * @{
+ */
+#define COMPUTE_M0_START_ROW(y, M0, PARTIAL_STORE_M0) \
+ ((uint)(max(0, (int)(y * M0) - (int)((M0 - PARTIAL_STORE_M0) % M0))))
+#else // defined(PARTIAL_STORE_M0)
+#define COMPUTE_M0_START_ROW(y, M0, PARTIAL_STORE_M0) \
+ ((uint)(y * M0))
+#endif // defined(PARTIAL_STORE_M0)
+/** @} */ // end of group COMPUTE_M0_START_ROW \ No newline at end of file