author     SiCong Li <sicong.li@arm.com>    2020-07-15 12:09:58 +0100
committer  SiCong Li <sicong.li@arm.com>    2020-07-21 09:41:49 +0000
commit     406a13f0b414d5c0375a46beec8dd9363a1cca56 (patch)
tree       5f6fb7cfa1c7683d44de32840ffb541f450c8961
parent     f6f7876e9ee8b58a8a6b335b032d554412fa3983 (diff)
download   ComputeLibrary-406a13f0b414d5c0375a46beec8dd9363a1cca56.tar.gz
COMPMID-3331 Remove y load padding from CLGEMMMatrixMultiplyReshapedOnlyRHSKernel and CLGEMMMatrixMultiplyNativeKernel
Resolves: COMPMID-3333, COMPMID-3334

* Implement an "overlap load, but don't overlap store" strategy:
  - Change STORE_BLOCK_BOUNDARY_AWARE so that the partial block in y
    dimension is placed at the beginning instead of at the end.
  - Implement 3 auxiliary functions to calculate the lhs, bias and dst
    addresses, taking into account the potential partial block in y
    dimension.
* Remove y load padding from Lhs and Bias tensors in
  CLGEMMMatrixMultiplyReshapedOnlyRHSKernel and
  CLGEMMMatrixMultiplyNativeKernel.
* Modify config tests to assert zero-padding in new dimensions.

Change-Id: I8f8585c7c0f543d720c2c91b885417c7dad35af4
Signed-off-by: SiCong Li <sicong.li@arm.com>
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/3576
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Georgios Pinitas <georgios.pinitas@arm.com>
Reviewed-by: Gian Marco Iodice <gianmarco.iodice@arm.com>
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
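
For reference, a minimal C sketch of the start-row arithmetic this strategy relies on, mirroring the COMPUTE_M0_START_ROW macro added to gemm_helpers.h below (the function name and the driver in main are illustrative only, not library code):

#include <stdio.h>

/* Shift every block up by the "missing" part of the partial block, so the
 * partial block lands at row 0 and full blocks overlap their loads instead
 * of reading past the end of the tensor. */
static unsigned int compute_m0_start_row(unsigned int y, unsigned int m0,
                                         unsigned int partial_store_m0)
{
    const int shift = (int)((m0 - partial_store_m0) % m0);
    const int row   = (int)(y * m0) - shift;
    return (unsigned int)(row > 0 ? row : 0);
}

int main(void)
{
    /* M = 9 rows, M0 = 4 rows per block -> PARTIAL_STORE_M0 = 9 % 4 = 1 */
    for(unsigned int y = 0; y < 3; ++y)
    {
        printf("block %u starts at row %u\n", y, compute_m0_start_row(y, 4, 1));
    }
    return 0; /* prints start rows 0, 1, 5 */
}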
-rw-r--r--  src/core/CL/cl_kernels/gemm.cl                                      45
-rw-r--r--  src/core/CL/cl_kernels/gemm_helpers.h                              123
-rw-r--r--  src/core/CL/kernels/CLGEMMMatrixMultiplyNativeKernel.cpp            11
-rw-r--r--  src/core/CL/kernels/CLGEMMMatrixMultiplyReshapedOnlyRHSKernel.cpp   15
-rw-r--r--  tests/validation/CL/GEMMMatrixMultiplyNative.cpp                      3
-rw-r--r--  tests/validation/CL/GEMMMatrixMultiplyReshapedOnlyRHS.cpp             3
6 files changed, 106 insertions, 94 deletions
diff --git a/src/core/CL/cl_kernels/gemm.cl b/src/core/CL/cl_kernels/gemm.cl
index 2360561f8a..adb3a1c25d 100644
--- a/src/core/CL/cl_kernels/gemm.cl
+++ b/src/core/CL/cl_kernels/gemm.cl
@@ -1112,7 +1112,7 @@ __kernel void gemm_mm_reshaped_only_rhs_t(IMAGE_DECLARATION(lhs),
#endif // defined(DUMMY_WORK_ITEMS)
// Compute LHS matrix address
- uint lhs_offset = lhs_offset_first_element_in_bytes + y * M0 * (uint)lhs_stride_y;
+ uint lhs_offset = lhs_offset_first_element_in_bytes + COMPUTE_M0_START_ROW(y, M0, PARTIAL_STORE_M0) * (uint)lhs_stride_y;
// Compute RHS reshaped matrix address
uint rhs_offset = rhs_offset_first_element_in_bytes + (x % H0) * (uint)RHS_OFFSET_X * sizeof(DATA_TYPE) + (x / (uint)H0) * rhs_stride_y;
@@ -1228,7 +1228,7 @@ __kernel void gemm_mm_reshaped_only_rhs_t(IMAGE_DECLARATION(lhs),
rhs_offset += sizeof(DATA_TYPE);
}
- __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + (x * (uint)N0 * sizeof(DATA_TYPE)) + (y * (uint)M0 * dst_stride_y);
+ __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + (x * (uint)N0 * sizeof(DATA_TYPE)) + (COMPUTE_M0_START_ROW(y, M0, PARTIAL_STORE_M0) * dst_stride_y);
REPEAT_VAR_INIT_TO_CONST(8, uint, zout, 0); //uint zout0=0,zout1=0,zout2=0,... zout7=0;
@@ -1268,8 +1268,7 @@ __kernel void gemm_mm_reshaped_only_rhs_t(IMAGE_DECLARATION(lhs),
ADD_BLOCK_BROADCAST(M0, c, bias0);
#else // defined(BROADCAST_BIAS)
- __global uchar *bias_addr = bias_ptr + bias_offset_first_element_in_bytes + (get_global_id(0) * (uint)N0 * sizeof(DATA_TYPE)) + (get_global_id(1) * (uint)M0 * bias_stride_y) + get_global_id(
- 2) * bias_stride_z;
+ __global uchar *bias_addr = bias_ptr + bias_offset_first_element_in_bytes + (x * (uint)N0 * sizeof(DATA_TYPE)) + (COMPUTE_M0_START_ROW(y, M0, PARTIAL_STORE_M0) * bias_stride_y) + z * bias_stride_z;
LOAD_BLOCK(M0, N0, DATA_TYPE, bias, bias_addr, 0, bias_stride_y, zero);
@@ -1288,7 +1287,7 @@ __kernel void gemm_mm_reshaped_only_rhs_t(IMAGE_DECLARATION(lhs),
#endif // defined(ACTIVATION_TYPE)
// Store output block
- STORE_BLOCK_BOUNDARY_AWARE(M0, N0, DATA_TYPE, c, dst_addr, dst_stride_y, zout, PARTIAL_STORE_M0, PARTIAL_STORE_N0, M, N, y, x);
+ STORE_BLOCK_BOUNDARY_AWARE(M0, N0, DATA_TYPE, c, dst_addr, dst_stride_y, zout, PARTIAL_STORE_M0, PARTIAL_STORE_N0, N, y, x);
#undef RHS_BLOCK_SIZE
#undef RHS_OFFSET_X
@@ -1406,7 +1405,7 @@ __kernel void gemm_mm_reshaped_only_rhs_t_texture(IMAGE_DECLARATION(lhs),
#endif // defined(DUMMY_WORK_ITEMS)
// Compute LHS matrix address
- uint lhs_offset = lhs_offset_first_element_in_bytes + y * M0 * (uint)lhs_stride_y;
+ uint lhs_offset = lhs_offset_first_element_in_bytes + COMPUTE_M0_START_ROW(y, M0, PARTIAL_STORE_M0) * (uint)lhs_stride_y;
#if defined(MATRIX_B_DEPTH)
// Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
@@ -1572,7 +1571,7 @@ __kernel void gemm_mm_reshaped_only_rhs_t_texture(IMAGE_DECLARATION(lhs),
#endif // LEFTOVER_K != 0
- __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + (x * (uint)N0 * sizeof(DATA_TYPE)) + (y * (uint)M0 * dst_stride_y);
+ __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + (x * (uint)N0 * sizeof(DATA_TYPE)) + (COMPUTE_M0_START_ROW(y, M0, PARTIAL_STORE_M0) * dst_stride_y);
REPEAT_VAR_INIT_TO_CONST(M0, uint, zout, 0); //uint zout0=0,zout1=0,zout2=0,... zout7=0;
@@ -1612,8 +1611,7 @@ __kernel void gemm_mm_reshaped_only_rhs_t_texture(IMAGE_DECLARATION(lhs),
ADD_BLOCK_BROADCAST(M0, c, bias0);
#else // defined(BROADCAST_BIAS)
- __global uchar *bias_addr = bias_ptr + bias_offset_first_element_in_bytes + (get_global_id(0) * (uint)N0 * sizeof(DATA_TYPE)) + (get_global_id(1) * (uint)M0 * bias_stride_y) + get_global_id(
- 2) * bias_stride_z;
+ __global uchar *bias_addr = bias_ptr + bias_offset_first_element_in_bytes + (x * (uint)N0 * sizeof(DATA_TYPE)) + (COMPUTE_M0_START_ROW(y, M0, PARTIAL_STORE_M0) * bias_stride_y) + z * bias_stride_z;
LOAD_BLOCK(M0, N0, DATA_TYPE, bias, bias_addr, 0, bias_stride_y, zero);
@@ -1632,7 +1630,7 @@ __kernel void gemm_mm_reshaped_only_rhs_t_texture(IMAGE_DECLARATION(lhs),
#endif // defined(ACTIVATION_TYPE)
// Store output block
- STORE_BLOCK_BOUNDARY_AWARE(M0, N0, DATA_TYPE, c, dst_addr, dst_stride_y, zout, PARTIAL_STORE_M0, PARTIAL_STORE_N0, M, N, y, x);
+ STORE_BLOCK_BOUNDARY_AWARE(M0, N0, DATA_TYPE, c, dst_addr, dst_stride_y, zout, PARTIAL_STORE_M0, PARTIAL_STORE_N0, N, y, x);
#undef RHS_BLOCK_SIZE
#undef RHS_OFFSET_X
@@ -1825,7 +1823,7 @@ __kernel void gemm_mm_reshaped_only_rhs_nt(IMAGE_DECLARATION(lhs),
#endif // defined(DUMMY_WORK_ITEMS)
// Compute LHS matrix address
- uint lhs_offset = lhs_offset_first_element_in_bytes + y * M0 * (uint)lhs_stride_y;
+ uint lhs_offset = lhs_offset_first_element_in_bytes + COMPUTE_M0_START_ROW(y, M0, PARTIAL_STORE_M0) * (uint)lhs_stride_y;
// Compute RHS reshaped matrix address
uint rhs_offset = rhs_offset_first_element_in_bytes + (x % H0) * (uint)RHS_OFFSET_X * sizeof(DATA_TYPE) + (x / (uint)H0) * rhs_stride_y;
@@ -1967,7 +1965,7 @@ __kernel void gemm_mm_reshaped_only_rhs_nt(IMAGE_DECLARATION(lhs),
rhs_offset += RHS_STEP_X * sizeof(DATA_TYPE);
}
- __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + (x * (uint)N0 * sizeof(DATA_TYPE)) + (y * (uint)M0 * dst_stride_y);
+ __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + (x * (uint)N0 * sizeof(DATA_TYPE)) + (COMPUTE_M0_START_ROW(y, M0, PARTIAL_STORE_M0) * dst_stride_y);
REPEAT_VAR_INIT_TO_CONST(8, uint, zout, 0); //uint zout0=0,zout1=0,zout2=0,... zout7=0;
@@ -2006,8 +2004,7 @@ __kernel void gemm_mm_reshaped_only_rhs_nt(IMAGE_DECLARATION(lhs),
ADD_BLOCK_BROADCAST(M0, c, bias0);
#else // defined(BROADCAST_BIAS)
- __global uchar *bias_addr = bias_ptr + bias_offset_first_element_in_bytes + (get_global_id(0) * (uint)N0 * sizeof(DATA_TYPE)) + (get_global_id(1) * (uint)M0 * bias_stride_y) + get_global_id(
- 2) * bias_stride_z;
+ __global uchar *bias_addr = bias_ptr + bias_offset_first_element_in_bytes + (x * (uint)N0 * sizeof(DATA_TYPE)) + (COMPUTE_M0_START_ROW(y, M0, PARTIAL_STORE_M0) * bias_stride_y) + z * bias_stride_z;
LOAD_BLOCK(M0, N0, DATA_TYPE, bias, bias_addr, 0, bias_stride_y, zero);
@@ -2026,7 +2023,7 @@ __kernel void gemm_mm_reshaped_only_rhs_nt(IMAGE_DECLARATION(lhs),
#endif // defined(ACTIVATION_TYPE)
// Store output block
- STORE_BLOCK_BOUNDARY_AWARE(M0, N0, DATA_TYPE, c, dst_addr, dst_stride_y, zout, PARTIAL_STORE_M0, PARTIAL_STORE_N0, M, N, y, x);
+ STORE_BLOCK_BOUNDARY_AWARE(M0, N0, DATA_TYPE, c, dst_addr, dst_stride_y, zout, PARTIAL_STORE_M0, PARTIAL_STORE_N0, N, y, x);
#undef RHS_BLOCK_SIZE
#undef RHS_OFFSET_X
@@ -2140,7 +2137,7 @@ __kernel void gemm_mm_reshaped_only_rhs_nt_texture(IMAGE_DECLARATION(lhs),
#endif // defined(DUMMY_WORK_ITEMS)
// Compute LHS matrix address
- uint lhs_offset = lhs_offset_first_element_in_bytes + y * M0 * (uint)lhs_stride_y;
+ uint lhs_offset = lhs_offset_first_element_in_bytes + COMPUTE_M0_START_ROW(y, M0, PARTIAL_STORE_M0) * (uint)lhs_stride_y;
#if defined(MATRIX_B_DEPTH)
// Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
@@ -2274,7 +2271,7 @@ __kernel void gemm_mm_reshaped_only_rhs_nt_texture(IMAGE_DECLARATION(lhs),
x_rhs += RHS_STEP_X;
}
- __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + (x * (uint)N0 * sizeof(DATA_TYPE)) + (y * (uint)M0 * dst_stride_y);
+ __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + (x * (uint)N0 * sizeof(DATA_TYPE)) + (COMPUTE_M0_START_ROW(y, M0, PARTIAL_STORE_M0) * dst_stride_y);
REPEAT_VAR_INIT_TO_CONST(8, uint, zout, 0); //uint zout0=0,zout1=0,zout2=0,... zout7=0;
@@ -2313,8 +2310,7 @@ __kernel void gemm_mm_reshaped_only_rhs_nt_texture(IMAGE_DECLARATION(lhs),
ADD_BLOCK_BROADCAST(M0, c, bias0);
#else // defined(BROADCAST_BIAS)
- __global uchar *bias_addr = bias_ptr + bias_offset_first_element_in_bytes + (get_global_id(0) * (uint)N0 * sizeof(DATA_TYPE)) + (get_global_id(1) * (uint)M0 * bias_stride_y) + get_global_id(
- 2) * bias_stride_z;
+ __global uchar *bias_addr = bias_ptr + bias_offset_first_element_in_bytes + (x * (uint)N0 * sizeof(DATA_TYPE)) + (COMPUTE_M0_START_ROW(y, M0, PARTIAL_STORE_M0) * bias_stride_y) + z * bias_stride_z;
LOAD_BLOCK(M0, N0, DATA_TYPE, bias, bias_addr, 0, bias_stride_y, zero);
@@ -2333,7 +2329,7 @@ __kernel void gemm_mm_reshaped_only_rhs_nt_texture(IMAGE_DECLARATION(lhs),
#endif // defined(ACTIVATION_TYPE)
// Store output block
- STORE_BLOCK_BOUNDARY_AWARE(M0, N0, DATA_TYPE, c, dst_addr, dst_stride_y, zout, PARTIAL_STORE_M0, PARTIAL_STORE_N0, M, N, y, x);
+ STORE_BLOCK_BOUNDARY_AWARE(M0, N0, DATA_TYPE, c, dst_addr, dst_stride_y, zout, PARTIAL_STORE_M0, PARTIAL_STORE_N0, N, y, x);
#undef RHS_BLOCK_SIZE
#undef RHS_OFFSET_X
@@ -4078,7 +4074,7 @@ __kernel void gemm_mm_native(IMAGE_DECLARATION(lhs),
#endif // defined(DUMMY_WORK_ITEMS)
// Compute LHS matrix address
- uint lhs_offset = lhs_offset_first_element_in_bytes + y * M0 * (uint)lhs_stride_y;
+ uint lhs_offset = lhs_offset_first_element_in_bytes + COMPUTE_M0_START_ROW(y, M0, PARTIAL_STORE_M0) * (uint)lhs_stride_y;
// Compute RHS matrix address
uint rhs_offset = rhs_offset_first_element_in_bytes + x * N0 * sizeof(DATA_TYPE);
@@ -4201,7 +4197,7 @@ __kernel void gemm_mm_native(IMAGE_DECLARATION(lhs),
rhs_offset += rhs_stride_y;
}
- __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + (x * (uint)N0 * sizeof(DATA_TYPE)) + (y * (uint)M0 * dst_stride_y);
+ __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + (x * (uint)N0 * sizeof(DATA_TYPE)) + (COMPUTE_M0_START_ROW(y, M0, PARTIAL_STORE_M0) * dst_stride_y);
REPEAT_VAR_INIT_TO_CONST(M0, uint, zout, 0);
@@ -4240,8 +4236,7 @@ __kernel void gemm_mm_native(IMAGE_DECLARATION(lhs),
ADD_BLOCK_BROADCAST(M0, c, bias0);
#else // defined(BROADCAST_BIAS)
- __global uchar *bias_addr = bias_ptr + bias_offset_first_element_in_bytes + (get_global_id(0) * (uint)N0 * sizeof(DATA_TYPE)) + (get_global_id(1) * (uint)M0 * bias_stride_y) + get_global_id(
- 2) * bias_stride_z;
+ __global uchar *bias_addr = bias_ptr + bias_offset_first_element_in_bytes + (x * (uint)N0 * sizeof(DATA_TYPE)) + (COMPUTE_M0_START_ROW(y, M0, PARTIAL_STORE_M0) * bias_stride_y) + z * bias_stride_z;
LOAD_BLOCK(M0, N0, DATA_TYPE, bias, bias_addr, 0, bias_stride_y, zero);
@@ -4260,7 +4255,7 @@ __kernel void gemm_mm_native(IMAGE_DECLARATION(lhs),
#endif // defined(ACTIVATION_TYPE)
// Store output block
- STORE_BLOCK_BOUNDARY_AWARE(M0, N0, DATA_TYPE, c, dst_addr, dst_stride_y, zout, PARTIAL_STORE_M0, PARTIAL_STORE_N0, M, N, y, x);
+ STORE_BLOCK_BOUNDARY_AWARE(M0, N0, DATA_TYPE, c, dst_addr, dst_stride_y, zout, PARTIAL_STORE_M0, PARTIAL_STORE_N0, N, y, x);
#undef RHS_BLOCK_SIZE
#undef RHS_OFFSET_X
diff --git a/src/core/CL/cl_kernels/gemm_helpers.h b/src/core/CL/cl_kernels/gemm_helpers.h
index 5ada788d49..5b6ad17ce0 100644
--- a/src/core/CL/cl_kernels/gemm_helpers.h
+++ b/src/core/CL/cl_kernels/gemm_helpers.h
@@ -791,29 +791,28 @@
* @param[in] Z The offset in z-axis direction
* @param[in] PARTIAL_STORE_M0 The partial size in y, for partial blocks. Supported range: [1, @p M0)
* @param[in] PARTIAL_STORE_N0 The partial size in x, for partial blocks. Supported range: [1, @p N0)
- * @param[in] M Total number of rows. Used to detect if current block is at the boundary in y.
* @param[in] N Total number of columns. Used to detect if current block is at the boundary in x.
* @param[in] y Global id of current block in y. Used to detect if current block is at the boundary in y.
* @param[in] x Global id of current block in x. Used to detect if current block is at the boundary in x.
*/
-#define STORE_BLOCK_PARTIAL_IN_X_AND_Y(M0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z, PARTIAL_STORE_M0, PARTIAL_STORE_N0, M, N, y, x) \
- bool at_y_boundary = (y + 1) * M0 >= M; \
- bool at_x_boundary = (x + 1) * N0 >= N; \
- if(!at_y_boundary && !at_x_boundary) \
- { \
- STORE_BLOCK_PARTIAL(M0, N0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z); \
- } \
- else if(at_y_boundary && !at_x_boundary) \
- { \
- STORE_BLOCK_PARTIAL(PARTIAL_STORE_M0, N0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z); \
- } \
- else if(!at_y_boundary && at_x_boundary) \
- { \
- STORE_BLOCK_PARTIAL(M0, PARTIAL_STORE_N0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z); \
- } \
- else \
- { \
- STORE_BLOCK_PARTIAL(PARTIAL_STORE_M0, PARTIAL_STORE_N0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z); \
+#define STORE_BLOCK_PARTIAL_IN_X_AND_Y(M0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z, PARTIAL_STORE_M0, PARTIAL_STORE_N0, N, y, x) \
+ bool at_y_boundary = y == 0; \
+ bool at_x_boundary = (x + 1) * N0 >= N; \
+ if(!at_y_boundary && !at_x_boundary) \
+ { \
+ STORE_BLOCK_PARTIAL(M0, N0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z); \
+ } \
+ else if(at_y_boundary && !at_x_boundary) \
+ { \
+ STORE_BLOCK_PARTIAL(PARTIAL_STORE_M0, N0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z); \
+ } \
+ else if(!at_y_boundary && at_x_boundary) \
+ { \
+ STORE_BLOCK_PARTIAL(M0, PARTIAL_STORE_N0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z); \
+ } \
+ else \
+ { \
+ STORE_BLOCK_PARTIAL(PARTIAL_STORE_M0, PARTIAL_STORE_N0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z); \
}
/** Store a block that can only be partial in x but not y.
*
@@ -862,18 +861,17 @@
* @param[in] STRIDE_Y The stride value in y-axis direction
* @param[in] Z The offset in z-axis direction
* @param[in] PARTIAL_STORE_M0 The partial size in y, for partial blocks. Supported range: [1, @p M0)
- * @param[in] M Total number of rows. Used to detect if current block is at the boundary in y.
* @param[in] y Global id of current block in y. Used to detect if current block is at the boundary in y.
*/
-#define STORE_BLOCK_PARTIAL_IN_Y(M0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z, PARTIAL_STORE_M0, M, y) \
- bool at_y_boundary = (y + 1) * M0 >= M; \
- if(!at_y_boundary) \
- { \
- STORE_BLOCK_PARTIAL(M0, N0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z); \
- } \
- else \
- { \
- STORE_BLOCK_PARTIAL(PARTIAL_STORE_M0, N0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z); \
+#define STORE_BLOCK_PARTIAL_IN_Y(M0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z, PARTIAL_STORE_M0, y) \
+ bool at_y_boundary = y == 0; \
+ if(!at_y_boundary) \
+ { \
+ STORE_BLOCK_PARTIAL(M0, N0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z); \
+ } \
+ else \
+ { \
+ STORE_BLOCK_PARTIAL(PARTIAL_STORE_M0, N0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z); \
}
/** @} */ // end of group STORE_BLOCK_PARTIAL
@@ -1484,10 +1482,16 @@
#if defined(PARTIAL_STORE_M0) && defined(PARTIAL_STORE_N0)
-/** Store a block in a boundary-aware way that does not require any padding
- * Store a block of the shape M0xN0 in a boundary-aware way that doesn't require any padding for partial blocks
+/** Boundary-aware GEMM block store
* @name STORE_BLOCK_BOUNDARY_AWARE
- *
+ * This macro assumes the following schemes to achieve boundary-awareness:
+ * - Overlapping load in Y axis from lhs tensor. This implies lhs has no padding along y dim.
+ * - Non-overlapping (normal) load from rhs tensor. This implies rhs can have padding.
+ * - Overlapping load in Y axis from bias tensor. This implies bias has no padding along y dim.
+ * The macro then ensures that the dst tensor can be stored without any padding in both x and y dims.
+ *
+ * In the y dimension, we place the partial blocks **at the beginning** while in the x dimension, we place the partial
+ * blocks **at the end**.
* Say, the dst tensor is of shape MxN and we have M0 and N0 as the block size, this is how we define "partial blocks"/
* "boundary block" (we use the 2 terms "partial blocks" and "boundary blocks" interchangeably) and its various parameters:
*
@@ -1495,20 +1499,19 @@
* | |<------------------------------N-------------------------->|
* y |<--------------N0------------->|<----PARTIAL_STORE_N0----->|
* | -------------#############################################################
- * * | | | |...........................|
- * y == 0 | M0 | Non-boundary block |....Boundary block in x....|
- * | | | |...........................|
- * M --#############################################################
- * | | |...............................|...........................|
- * y == 1 | PAR_..._M0 |......Boundary block in y......|.Boundary block in x and y.|
+ * * | | |...............................|...........................|
+ * y == 0 | PAR_..._M0 |......Boundary block in y......|.Boundary block in x and y.|
* | | |...............................|...........................|
+ * M --#############################################################
+ * | | | |...........................|
+ * y == 1 | M0 | Non-boundary block |....Boundary block in x....|
+ * | | | |...........................|
* |------------#############################################################
*
* Then @p PARTIAL_STORE_M0 = M % M0 and @p PARTIAL_STORE_N0 = N % N0
*
* @note in cases @p PARTIAL_STORE_N0 != 1, 2, 3, 4, 8, 16, extra vstore(s) will be invoked, thus incurring small performance penalty.
*
- * This method ensures that in the end the dst tensor is stored without requirements for paddings.
 * It automatically detects if a given M,N,M0,N0 combination can yield partial blocks in either the X or Y dimension,
 * and selects the corresponding store methods such that the boundary detection logic is only added when needed.
*
@@ -1526,7 +1529,6 @@
* @param[in] Z The offset in z-axis direction
* @param[in] PARTIAL_STORE_M0 The partial size in y, for partial blocks. Supported: [0, @p M0)
* @param[in] PARTIAL_STORE_N0 The partial size in x, for partial blocks. Supported: [0, @p N0)
- * @param[in] M Total number of rows. Used to detect if current block is at the boundary in y.
* @param[in] N Total number of columns. Used to detect if current block is at the boundary in x.
* @param[in] y Global id of current block in y. Used to detect if current block is at the boundary in y.
* @param[in] x Global id of current block in x. Used to detect if current block is at the boundary in x.
@@ -1534,30 +1536,55 @@
*/
#if PARTIAL_STORE_M0 == 0 && PARTIAL_STORE_N0 == 0
// Case1: No partial blocks in either x or y
-#define STORE_BLOCK_BOUNDARY_AWARE(M0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z, PARTIAL_STORE_M0, PARTIAL_STORE_N0, M, N, y, x) \
+#define STORE_BLOCK_BOUNDARY_AWARE(M0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z, PARTIAL_STORE_M0, PARTIAL_STORE_N0, N, y, x) \
STORE_BLOCK(M0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)
#elif PARTIAL_STORE_M0 > 0 && PARTIAL_STORE_N0 == 0
// Case2: Partial blocks in y
-#define STORE_BLOCK_BOUNDARY_AWARE(M0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z, PARTIAL_STORE_M0, PARTIAL_STORE_N0, M, N, y, x) \
- STORE_BLOCK_PARTIAL_IN_Y(M0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z, PARTIAL_STORE_M0, M, y)
+#define STORE_BLOCK_BOUNDARY_AWARE(M0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z, PARTIAL_STORE_M0, PARTIAL_STORE_N0, N, y, x) \
+ STORE_BLOCK_PARTIAL_IN_Y(M0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z, PARTIAL_STORE_M0, y)
#elif PARTIAL_STORE_M0 == 0 && PARTIAL_STORE_N0 > 0
// Case3: Partial blocks in x
-#define STORE_BLOCK_BOUNDARY_AWARE(M0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z, PARTIAL_STORE_M0, PARTIAL_STORE_N0, M, N, y, x) \
+#define STORE_BLOCK_BOUNDARY_AWARE(M0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z, PARTIAL_STORE_M0, PARTIAL_STORE_N0, N, y, x) \
STORE_BLOCK_PARTIAL_IN_X(M0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z, PARTIAL_STORE_N0, N, x)
#else // PARTIAL_STORE_M0 == 0 && PARTIAL_STORE_N0 == 0
// Case4: Partial blocks in both x and y
-#define STORE_BLOCK_BOUNDARY_AWARE(M0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z, PARTIAL_STORE_M0, PARTIAL_STORE_N0, M, N, y, x) \
- STORE_BLOCK_PARTIAL_IN_X_AND_Y(M0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z, PARTIAL_STORE_M0, PARTIAL_STORE_N0, M, N, y, x)
+#define STORE_BLOCK_BOUNDARY_AWARE(M0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z, PARTIAL_STORE_M0, PARTIAL_STORE_N0, N, y, x) \
+ STORE_BLOCK_PARTIAL_IN_X_AND_Y(M0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z, PARTIAL_STORE_M0, PARTIAL_STORE_N0, N, y, x)
#endif // PARTIAL_STORE_M0 == 0 && PARTIAL_STORE_N0 == 0
#else // defined(PARTIAL_STORE_M0) && defined(PARTIAL_STORE_N0)
-#define STORE_BLOCK_BOUNDARY_AWARE(M0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z, PARTIAL_STORE_M0, PARTIAL_STORE_N0, M, N, y, x) \
+#define STORE_BLOCK_BOUNDARY_AWARE(M0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z, PARTIAL_STORE_M0, PARTIAL_STORE_N0, N, y, x) \
STORE_BLOCK(M0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)
#endif // defined(PARTIAL_STORE_M0) && defined(PARTIAL_STORE_N0)
-/** @} */ // end of group STORE_BLOCK_BOUNDARY_AWARE
\ No newline at end of file
+/** @} */ // end of group STORE_BLOCK_BOUNDARY_AWARE
+
+#if defined(PARTIAL_STORE_M0)
+/** Compute the start m0 row (LHS, BIAS and DST) in a boundary-aware way so as to avoid padding
+ * @name COMPUTE_M0_START_ROW
+ * If there are any partial blocks in the y dimension, they are placed at the beginning of the rows.
+ * This shift amount is added to all rows such that the partial block (at the beginning) overlaps with the subsequent
+ * blocks in the y dimension to avoid any padding.
+ * EG: M0=4, PARTIAL_STORE_M0=1:
+ * | Non-overlapping | +M0_ROW_SHIFT (Overlapping)
+ * block 0 (partial)| start row = 0 | start row = 0
+ * block 1 (full) | start row = 4 | start row = 1
+ * block 2 (full) | start row = 8 | start row = 5
+ *
+ * @param[in] y Global id of current block in y.
+ * @param[in] M0 The number of rows to store, for non-partial blocks. Supported: 1-16
+ * @param[in] PARTIAL_STORE_M0 The partial size in y, for partial blocks. Supported: [0, @p M0)
+ * @{
+ */
+#define COMPUTE_M0_START_ROW(y, M0, PARTIAL_STORE_M0) \
+ ((uint)(max(0, (int)(y * M0) - (int)((M0 - PARTIAL_STORE_M0) % M0))))
+#else // defined(PARTIAL_STORE_M0)
+#define COMPUTE_M0_START_ROW(y, M0, PARTIAL_STORE_M0) \
+ ((uint)(y * M0))
+#endif // defined(PARTIAL_STORE_M0)
+/** @} */ // end of group COMPUTE_M0_START_ROW
\ No newline at end of file
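
A worked example of why the y-boundary test in the store macros above collapses to y == 0: with the partial block moved to the start, only block 0 can be partial. The sketch below assumes PARTIAL_STORE_M0 > 0 (when it is 0 the macro dispatch falls back to a plain STORE_BLOCK); all names are illustrative, not library code:

#include <stdio.h>

/* Every block loads a full M0 rows from its shifted start row, but block 0
 * stores only PARTIAL_STORE_M0 rows. Loads overlap between neighbouring
 * blocks; stores never do, so neither loads nor stores need y padding. */
int main(void)
{
    const unsigned int M = 9, M0 = 4;
    const unsigned int PARTIAL_STORE_M0 = M % M0;             /* = 1 */
    const unsigned int shift = (M0 - PARTIAL_STORE_M0) % M0;  /* = 3 */
    for(unsigned int y = 0; y * M0 < M; ++y)
    {
        const unsigned int start = (y * M0 > shift) ? (y * M0 - shift) : 0;
        const unsigned int rows  = (y == 0) ? PARTIAL_STORE_M0 : M0;
        printf("block %u: loads rows [%u, %u), stores rows [%u, %u)\n",
               y, start, start + M0, start, start + rows);
    }
    return 0; /* stores cover rows [0,1), [1,5), [5,9): exactly M rows */
}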
diff --git a/src/core/CL/kernels/CLGEMMMatrixMultiplyNativeKernel.cpp b/src/core/CL/kernels/CLGEMMMatrixMultiplyNativeKernel.cpp
index 37fcd10511..c67d3601ad 100644
--- a/src/core/CL/kernels/CLGEMMMatrixMultiplyNativeKernel.cpp
+++ b/src/core/CL/kernels/CLGEMMMatrixMultiplyNativeKernel.cpp
@@ -155,17 +155,12 @@ std::pair<Status, Window> validate_and_configure_window(ITensorInfo *input0, ITe
num_elems_processed_per_iteration_x = rhs_info.n0;
num_elems_processed_per_iteration_y = lhs_info.m0;
- // Note: bottom paddings are calculated manually as the output can be reinterpreted as 3D tensor
- // The only way to set properly the paddings, it is to set those explicitly through the AccessWindowStatic
- const unsigned int m = reinterpret_output_as_3d ? gemm_info.m : output->dimension(1);
- const unsigned int bottom_pad = (num_elems_processed_per_iteration_y - (m % num_elems_processed_per_iteration_y)) % num_elems_processed_per_iteration_y;
-
win = calculate_max_window(tmp_info, Steps(num_elems_processed_per_iteration_x, num_elems_processed_per_iteration_y));
win_out = calculate_max_window(*output, Steps(num_elems_processed_per_iteration_x, num_elems_processed_per_iteration_y));
AccessWindowStatic input0_access(input0, 0, 0,
input0->dimension(0),
- input0->dimension(1) + bottom_pad);
+ input0->dimension(1));
AccessWindowStatic input1_access(input1, 0, 0,
ceil_to_multiple(input1->dimension(0), num_elems_processed_per_iteration_x),
input1->dimension(1));
@@ -177,11 +172,9 @@ std::pair<Status, Window> validate_and_configure_window(ITensorInfo *input0, ITe
{
const int bias_processed_per_iteration_x = num_elems_processed_per_iteration_x;
- const int bias_processed_per_iteration_y = gemm_info.broadcast_bias ? 1 : num_elems_processed_per_iteration_y;
-
AccessWindowStatic input2_access(input2, 0, 0,
ceil_to_multiple(input2->dimension(0), bias_processed_per_iteration_x),
- ceil_to_multiple(input2->dimension(1), bias_processed_per_iteration_y));
+ input2->dimension(1));
window_changed = update_window_and_padding(win, input0_access, input1_access, input2_access) || // window used by the execute_window_loop
update_window_and_padding(win_out, output_access); // window used to update the padding requirements of output tensor
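
For context, the bottom_pad formula removed above requested this much extra y padding on the lhs and bias access windows (a standalone sketch of the deleted arithmetic, not library code):

#include <stdio.h>

int main(void)
{
    /* e.g. m = 9 output rows, processed m0 = 4 rows per iteration */
    const unsigned int m = 9, m0 = 4;
    const unsigned int bottom_pad = (m0 - (m % m0)) % m0;
    printf("old y padding requested: %u rows\n", bottom_pad); /* 3 */
    /* With the overlapping-load scheme the access windows now match the
     * tensor extents exactly, so this padding is no longer requested. */
    return 0;
}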
diff --git a/src/core/CL/kernels/CLGEMMMatrixMultiplyReshapedOnlyRHSKernel.cpp b/src/core/CL/kernels/CLGEMMMatrixMultiplyReshapedOnlyRHSKernel.cpp
index 7d76ffd86c..27520c6072 100644
--- a/src/core/CL/kernels/CLGEMMMatrixMultiplyReshapedOnlyRHSKernel.cpp
+++ b/src/core/CL/kernels/CLGEMMMatrixMultiplyReshapedOnlyRHSKernel.cpp
@@ -155,19 +155,14 @@ std::pair<Status, Window> validate_and_configure_window(ITensorInfo *input0, ITe
num_elems_processed_per_iteration_x = rhs_info.n0;
num_elems_processed_per_iteration_y = lhs_info.m0;
- // Note: bottom paddings are calculated manually as the output can be reinterpreted as 3D tensor
- // The only way to set properly the paddings, it is to set those explicitly through the AccessWindowStatic
- const unsigned int m = reinterpret_output_as_3d ? gemm_info.m : output->dimension(1);
- const unsigned int bottom_pad = (num_elems_processed_per_iteration_y - (m % num_elems_processed_per_iteration_y)) % num_elems_processed_per_iteration_y;
-
win = calculate_max_window(tmp_info, Steps(num_elems_processed_per_iteration_x, num_elems_processed_per_iteration_y));
win_out = calculate_max_window(*output, Steps(num_elems_processed_per_iteration_x, num_elems_processed_per_iteration_y));
AccessWindowStatic input0_access(input0, 0, 0,
input0->dimension(0),
- input0->dimension(1) + bottom_pad);
+ input0->dimension(1));
AccessWindowStatic input1_access(input1, 0, 0,
- input1->dimension(0),
+ ceil_to_multiple(input1->dimension(0), num_elems_processed_per_iteration_x),
input1->dimension(1));
AccessWindowStatic output_access(output, 0, 0,
output->dimension(0),
@@ -175,11 +170,11 @@ std::pair<Status, Window> validate_and_configure_window(ITensorInfo *input0, ITe
if(input2 != nullptr)
{
- const int bias_processed_per_iteration_x = num_elems_processed_per_iteration_x;
- const int bias_processed_per_iteration_y = gemm_info.broadcast_bias ? 1 : num_elems_processed_per_iteration_y;
+ const int bias_processed_per_iteration_x = num_elems_processed_per_iteration_x;
+
AccessWindowStatic input2_access(input2, 0, 0,
ceil_to_multiple(input2->dimension(0), bias_processed_per_iteration_x),
- ceil_to_multiple(input2->dimension(1), bias_processed_per_iteration_y));
+ input2->dimension(1));
window_changed = update_window_and_padding(win, input0_access, input1_access, input2_access) || // window used by the execute_window_loop
update_window_and_padding(win_out, output_access); // window used to update the padding requirements of output tensor
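
Note that x padding is still permitted on the rhs and bias tensors: their access windows are rounded up with ceil_to_multiple. A hedged re-implementation of that rounding for illustration (the real helper lives in the library's utility headers):

#include <stdio.h>

/* Round value up to the next multiple of divisor (divisor > 0 assumed). */
static unsigned int ceil_to_multiple(unsigned int value, unsigned int divisor)
{
    return ((value + divisor - 1) / divisor) * divisor;
}

int main(void)
{
    /* e.g. a 10-element rhs row with n0 = 4 is padded in x to 12 elements,
     * which is why the config tests below only assert zero padding on the
     * lhs and dst tensors and zero top/bottom padding on the bias. */
    printf("%u\n", ceil_to_multiple(10, 4)); /* 12 */
    return 0;
}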
diff --git a/tests/validation/CL/GEMMMatrixMultiplyNative.cpp b/tests/validation/CL/GEMMMatrixMultiplyNative.cpp
index bdf8248bb2..b474ea3262 100644
--- a/tests/validation/CL/GEMMMatrixMultiplyNative.cpp
+++ b/tests/validation/CL/GEMMMatrixMultiplyNative.cpp
@@ -210,7 +210,8 @@ bool validate_zero_padding(unsigned int m_value, unsigned int n_value, unsigned
CLGEMMMatrixMultiplyNative gemm;
gemm.configure(&lhs, &rhs, &bias, &dst, 1.0f, 1.0f, lhs_info, rhs_info, kernel_info);
- return dst.info()->padding().empty();
+ // Padding can be added along the rhs and bias tensors' X dimension
+ return dst.info()->padding().empty() && lhs.info()->padding().empty() && bias.info()->padding().bottom == 0 && bias.info()->padding().top == 0;
}
} // namespace
diff --git a/tests/validation/CL/GEMMMatrixMultiplyReshapedOnlyRHS.cpp b/tests/validation/CL/GEMMMatrixMultiplyReshapedOnlyRHS.cpp
index 0456ca2017..c87d309ada 100644
--- a/tests/validation/CL/GEMMMatrixMultiplyReshapedOnlyRHS.cpp
+++ b/tests/validation/CL/GEMMMatrixMultiplyReshapedOnlyRHS.cpp
@@ -246,7 +246,8 @@ bool validate_zero_padding(unsigned int m_value, unsigned int n_value, unsigned
gemm.configure(&lhs, &rhs_reshaped, &bias, &dst, alpha, beta, lhs_info, rhs_info, kernel_info);
- return dst.info()->padding().empty();
+ // Padding can be added along the rhs and bias tensors' X dimension
+ return dst.info()->padding().empty() && lhs.info()->padding().empty() && bias.info()->padding().bottom == 0 && bias.info()->padding().top == 0;
}
} // namespace