path: root/src/core/CL/cl_kernels/gemm.cl
author     Gian Marco Iodice <gianmarco.iodice@arm.com>	2018-03-02 11:18:12 +0000
committer  Anthony Barbier <anthony.barbier@arm.com>	2018-11-02 16:49:16 +0000
commit     d2fab7315bac3a586f2f1b1c8d64f2441f89ca64 (patch)
tree       33572f0fea29d24546850f3835703f9869726122 /src/core/CL/cl_kernels/gemm.cl
parent     27c08abe6947b1ee5b266799f2bb2bf0a05d0def (diff)
download   ComputeLibrary-d2fab7315bac3a586f2f1b1c8d64f2441f89ca64.tar.gz
COMPMID-935 - Implementing Convolution with Winograd on OpenCL (part 4)
Implemented Winograd Output Transform (2x2,3x3) on OpenCL
Implemented CLWinogradConvolutionLayer on OpenCL

Change-Id: I6a113fc5f052ca07f878d2b800d2ab003f84af65
Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/125148
Reviewed-by: Georgios Pinitas <georgios.pinitas@arm.com>
Tested-by: Jenkins <bsgcomp@arm.com>
Diffstat (limited to 'src/core/CL/cl_kernels/gemm.cl')
-rw-r--r--	src/core/CL/cl_kernels/gemm.cl	127
1 file changed, 111 insertions, 16 deletions
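
Every hunk below makes the same functional change: when matrix B has only 3 dimensions but matrix A carries an extra batch dimension, matrix B must not slide along Z together with A, or the kernel would read past the end of B. This is exactly the batched matrix multiplication used by 2D Winograd (a = [K, M, 16, Batches], b = [N, K, 16]). The sketch below is a minimal OpenCL-C-compatible illustration of the addressing rule the patch introduces, not code from gemm.cl; the helper name matrix_b_z_offset is invented for this example.

// Minimal sketch (illustrative only) of the Z-offset rule added by this patch.
// MATRIX_B_DEPTH is expected at build time (e.g. -DMATRIX_B_DEPTH=16) whenever
// matrix B is 3D and matrix A has more than 3 dimensions.
inline unsigned int matrix_b_z_offset(unsigned int z, unsigned int src1_stride_z)
{
#if defined(MATRIX_B_DEPTH)
    // Matrix B only has MATRIX_B_DEPTH planes: wrap the batch index onto them
    return (z % MATRIX_B_DEPTH) * src1_stride_z;
#else  // defined(MATRIX_B_DEPTH)
    // Default batched GEMM: matrices A and B slide along Z together
    return z * src1_stride_z;
#endif // defined(MATRIX_B_DEPTH)
}
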
diff --git a/src/core/CL/cl_kernels/gemm.cl b/src/core/CL/cl_kernels/gemm.cl
index cba5eea437..a5b0acbe9c 100644
--- a/src/core/CL/cl_kernels/gemm.cl
+++ b/src/core/CL/cl_kernels/gemm.cl
@@ -162,6 +162,8 @@ __kernel void gemm_interleave4x4(TENSOR3D_DECLARATION(src),
* @note The number of columns of matrix B and the optional alpha's value need to be passed at compile time using -DCOLS_B and -DALPHA
* @note The multiplication factor for the transposition width (mult_transpose1xW_width) must be passed at compile time using -DMULT_TRANSPOSE1XW_WIDTH (i.e. -DMULT_TRANSPOSE1XW_WIDTH=2)
* @note The multiplication factor for the height of the 4x4 interleaved block must be passed at compile time using -DMULT_INTERLEAVE4X4_HEIGHT (i.e. -DMULT_INTERLEAVE4X4_HEIGHT=2)
+ * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (i.e. -DMATRIX_B_DEPTH=16)
+ * This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (i.e. a = [K, M, 16, Batches], b = [N, K, 16])
*
* @param[in] src0_ptr Pointer to the source matrix. Supported data types: F32
* @param[in] src0_stride_x Stride of the source matrix in X dimension (in bytes)
@@ -199,8 +201,18 @@ __kernel void gemm_mm_interleaved_transposed_f32_midgard(IMAGE_DECLARATION(src0)
// src_addr_a = address of matrix A
// src_addr_b = address of matrix B
- __global float *src_addr_a = (__global float *)(src0_ptr + z * src0_stride_z + y * src0_stride_y + src0_offset_first_element_in_bytes);
- __global float *src_addr_b = (__global float *)(src1_ptr + z * src1_stride_z + x * src1_stride_y + src1_offset_first_element_in_bytes);
+ int src0_addr_in_bytes = z * src0_stride_z + y * src0_stride_y + src0_offset_first_element_in_bytes;
+ int src1_addr_in_bytes = x * src1_stride_y + src1_offset_first_element_in_bytes;
+
+#if defined(MATRIX_B_DEPTH)
+ // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
+ src1_addr_in_bytes += (z % MATRIX_B_DEPTH) * src1_stride_z;
+#else // defined(MATRIX_B_DEPTH)
+ src1_addr_in_bytes += z * src1_stride_z;
+#endif // defined(MATRIX_B_DEPTH)
+
+ __global float *src_addr_a = (__global float *)(src0_ptr + src0_addr_in_bytes);
+ __global float *src_addr_b = (__global float *)(src1_ptr + src1_addr_in_bytes);
// Compute end row address for matrix B
__global float *src_end_addr_b = src_addr_b + COLS_B;
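
As a worked example of the offsets just computed (hypothetical sizes, not taken from the library): in the 2D Winograd case a = [K, M, 16, Batches] and b = [N, K, 16], so z = get_global_id(2) runs over 16 * Batches planes of A while B has only 16.

// Hypothetical sizes: Batches = 4, MATRIX_B_DEPTH = 16, so z is in [0, 63]
// For z = 35:
//   matrix A plane: 35            (A always slides along Z)
//   matrix B plane: 35 % 16 = 3   (the same 16 B planes are reused by every batch)
//   src1_addr_in_bytes = x * src1_stride_y + src1_offset_first_element_in_bytes + 3 * src1_stride_z
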
@@ -277,6 +289,9 @@ __kernel void gemm_mm_interleaved_transposed_f32_midgard(IMAGE_DECLARATION(src0)
* @note The number of columns of matrix B and the optional alpha's value need to be passed at compile time using -DCOLS_B and -DALPHA
* @note The multiplication factor for the transposition width (mult_transpose1xW_width) must be passed at compile time using -DMULT_TRANSPOSE1XW_WIDTH (i.e. -DMULT_TRANSPOSE1XW_WIDTH=2)
* @note The multiplication factor for the height of the 4x4 interleaved block must be passed at compile time using -DMULT_INTERLEAVE4X4_HEIGHT (i.e. -DMULT_INTERLEAVE4X4_HEIGHT=2)
+ * @note The multiplication factor for the height of the 4x4 interleaved block must be passed at compile time using -DMULT_INTERLEAVE4X4_HEIGHT (i.e. -DMULT_INTERLEAVE4X4_HEIGHT=2)
+ * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (i.e. -DMATRIX_B_DEPTH=16)
+ * This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (i.e. a = [K, M, 16, Batches], b = [N, K, 16])
*
* @param[in] src0_ptr Pointer to the source matrix. Supported data types: F32
* @param[in] src0_stride_x Stride of the source matrix in X dimension (in bytes)
@@ -314,8 +329,18 @@ __kernel void gemm_mm_interleaved_transposed_f32_bifrost(IMAGE_DECLARATION(src0)
// src_addr_a = address of matrix A
// src_addr_b = address of matrix B
- __global float *src_addr_a = (__global float *)(src0_ptr + z * src0_stride_z + y * src0_stride_y + src0_offset_first_element_in_bytes);
- __global float *src_addr_b = (__global float *)(src1_ptr + z * src1_stride_z + x * src1_stride_y + src1_offset_first_element_in_bytes);
+ int src0_addr_in_bytes = z * src0_stride_z + y * src0_stride_y + src0_offset_first_element_in_bytes;
+ int src1_addr_in_bytes = x * src1_stride_y + src1_offset_first_element_in_bytes;
+
+#if defined(MATRIX_B_DEPTH)
+ // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
+ src1_addr_in_bytes += (z % MATRIX_B_DEPTH) * src1_stride_z;
+#else // defined(MATRIX_B_DEPTH)
+ src1_addr_in_bytes += z * src1_stride_z;
+#endif // defined(MATRIX_B_DEPTH)
+
+ __global float *src_addr_a = (__global float *)(src0_ptr + src0_addr_in_bytes);
+ __global float *src_addr_b = (__global float *)(src1_ptr + src1_addr_in_bytes);
// Compute end row address for matrix B
__global float *src_end_addr_b = src_addr_b + COLS_B;
@@ -510,6 +535,8 @@ __kernel void gemm_mm_interleaved_transposed_f32_bifrost(IMAGE_DECLARATION(src0)
* @note The number of columns of matrix B and the optional alpha's value need to be passed at compile time using -DCOLS_B and -DALPHA
* @note The multiplication factor for the transposition width (mult_transpose1xW_width) must be passed at compile time using -DMULT_TRANSPOSE1XW_WIDTH (i.e. -DMULT_TRANSPOSE1XW_WIDTH=2)
* @note The multiplication factor for the height of the 4x4 interleaved block must be passed at compile time using -DMULT_INTERLEAVE4X4_HEIGHT (i.e. -DMULT_INTERLEAVE4X4_HEIGHT=2)
+ * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (i.e. -DMATRIX_B_DEPTH=16)
+ * This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (i.e. a = [K, M, 16, Batches], b = [N, K, 16])
*
* @param[in] src0_ptr Pointer to the source matrix. Supported data types: F16
* @param[in] src0_stride_x Stride of the source matrix in X dimension (in bytes)
@@ -547,8 +574,18 @@ __kernel void gemm_mm_interleaved_transposed_f16(IMAGE_DECLARATION(src0),
// src_addr_a = address of matrix A
// src_addr_b = address of matrix B
- __global half *src_addr_a = (__global half *)(src0_ptr + z * src0_stride_z + y * src0_stride_y + src0_offset_first_element_in_bytes);
- __global half *src_addr_b = (__global half *)(src1_ptr + z * src1_stride_z + x * src1_stride_y + src1_offset_first_element_in_bytes);
+ int src0_addr_in_bytes = z * src0_stride_z + y * src0_stride_y + src0_offset_first_element_in_bytes;
+ int src1_addr_in_bytes = x * src1_stride_y + src1_offset_first_element_in_bytes;
+
+#if defined(MATRIX_B_DEPTH)
+ // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
+ src1_addr_in_bytes += (z % MATRIX_B_DEPTH) * src1_stride_z;
+#else // defined(MATRIX_B_DEPTH)
+ src1_addr_in_bytes += z * src1_stride_z;
+#endif // defined(MATRIX_B_DEPTH)
+
+ __global half *src_addr_a = (__global half *)(src0_ptr + src0_addr_in_bytes);
+ __global half *src_addr_b = (__global half *)(src1_ptr + src1_addr_in_bytes);
// Compute end row address for matrix B
__global half *src_end_addr_b = src_addr_b + COLS_B;
@@ -627,8 +664,9 @@ __kernel void gemm_mm_interleaved_transposed_f16(IMAGE_DECLARATION(src0),
* @note The number of columns of matrix B and the optional alpha's value need to be passed at compile time using -DCOLS_B and -DALPHA
* @note The multiplication factor for the transposition width (mult_transpose1xW_width) must be passed at compile time using -DMULT_TRANSPOSE1XW_WIDTH (i.e. -DMULT_TRANSPOSE1XW_WIDTH=2)
* @note The multiplication factor for the height of the 4x4 interleaved block must be passed at compile time using -DMULT_INTERLEAVE4X4_HEIGHT (i.e. -DMULT_INTERLEAVE4X4_HEIGHT=2)
- *
- * @note: ALPHA must be passed in 8 bit fixed point format
+ * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (i.e. -DMATRIX_B_DEPTH=16)
+ * This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (i.e. a = [K, M, 16, Batches], b = [N, K, 16])
+ * @note: ALPHA must be passed in 8 bit fixed point format
*
* @param[in] src0_ptr Pointer to the source matrix. Supported data types: QS8
* @param[in] src0_stride_x Stride of the source matrix in X dimension (in bytes)
@@ -666,8 +704,18 @@ __kernel void gemm_mm_interleaved_transposed_qs8(IMAGE_DECLARATION(src0),
// src_addr_a = address of matrix A
// src_addr_b = address of matrix B
- __global char *src_addr_a = src0_ptr + z * src0_stride_z + y * src0_stride_y + src0_offset_first_element_in_bytes;
- __global char *src_addr_b = src1_ptr + z * src1_stride_z + x * src1_stride_y + src1_offset_first_element_in_bytes;
+ int src0_addr_in_bytes = z * src0_stride_z + y * src0_stride_y + src0_offset_first_element_in_bytes;
+ int src1_addr_in_bytes = x * src1_stride_y + src1_offset_first_element_in_bytes;
+
+#if defined(MATRIX_B_DEPTH)
+ // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
+ src1_addr_in_bytes += (z % MATRIX_B_DEPTH) * src1_stride_z;
+#else // defined(MATRIX_B_DEPTH)
+ src1_addr_in_bytes += z * src1_stride_z;
+#endif // defined(MATRIX_B_DEPTH)
+
+ __global char *src_addr_a = (__global char *)(src0_ptr + src0_addr_in_bytes);
+ __global char *src_addr_b = (__global char *)(src1_ptr + src1_addr_in_bytes);
// Compute end row address for matrix B
__global char *src_end_addr_b = src_addr_b + COLS_B;
@@ -738,8 +786,9 @@ __kernel void gemm_mm_interleaved_transposed_qs8(IMAGE_DECLARATION(src0),
* @note The number of columns of matrix B and the optional alpha's value need to be passed at compile time using -DCOLS_B and -DALPHA
* @note The multiplication factor for the transposition width (mult_transpose1xW_width) must be passed at compile time using -DMULT_TRANSPOSE1XW_WIDTH (i.e. -DMULT_TRANSPOSE1XW_WIDTH=2)
* @note The multiplication factor for the height of the 4x4 interleaved block must be passed at compile time using -DMULT_INTERLEAVE4X4_HEIGHT (i.e. -DMULT_INTERLEAVE4X4_HEIGHT=2)
- *
- * @note: ALPHA must be passed in 16 bit fixed point format
+ * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (i.e. -DMATRIX_B_DEPTH=16)
+ * This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (i.e. a = [K, M, 16, Batches], b = [N, K, 16])
+ * @note: ALPHA must be passed in 16 bit fixed point format
*
* @param[in] src0_ptr Pointer to the source matrix. Supported data types: QS16
* @param[in] src0_stride_x Stride of the source matrix in X dimension (in bytes)
@@ -777,8 +826,18 @@ __kernel void gemm_mm_interleaved_transposed_qs16(IMAGE_DECLARATION(src0),
// src_addr_a = address of matrix A
// src_addr_b = address of matrix B
- __global short *src_addr_a = (__global short *)(src0_ptr + z * src0_stride_z + y * src0_stride_y + src0_offset_first_element_in_bytes);
- __global short *src_addr_b = (__global short *)(src1_ptr + z * src1_stride_z + x * src1_stride_y + src1_offset_first_element_in_bytes);
+ int src0_addr_in_bytes = z * src0_stride_z + y * src0_stride_y + src0_offset_first_element_in_bytes;
+ int src1_addr_in_bytes = x * src1_stride_y + src1_offset_first_element_in_bytes;
+
+#if defined(MATRIX_B_DEPTH)
+ // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
+ src1_addr_in_bytes += (z % MATRIX_B_DEPTH) * src1_stride_z;
+#else // defined(MATRIX_B_DEPTH)
+ src1_addr_in_bytes += z * src1_stride_z;
+#endif // defined(MATRIX_B_DEPTH)
+
+ __global short *src_addr_a = (__global short *)(src0_ptr + src0_addr_in_bytes);
+ __global short *src_addr_b = (__global short *)(src1_ptr + src1_addr_in_bytes);
// Compute end row address for matrix B
__global short *src_end_addr_b = src_addr_b + COLS_B;
@@ -845,6 +904,8 @@ __kernel void gemm_mm_interleaved_transposed_qs16(IMAGE_DECLARATION(src0),
* @note The floating point data type must be passed at compile time using -DDATA_TYPE (e.g. -DDATA_TYPE=float)
* @note The number of elements processed along the x and y directions must be passed at compile time using -DNUM_ELEMS_PROCESSED_PER_THREAD_X and -DNUM_ELEMS_PROCESSED_PER_THREAD_Y
* @note The number of matrix A columns and the optional alpha's value need to be passed at compile time using -DCOLS_A and -DALPHA
+ * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (i.e. -DMATRIX_B_DEPTH=16)
+ * This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (i.e. a = [K, M, 16, Batches], b = [N, K, 16])
*
* @param[in] src0_ptr Pointer to the source matrix. Supported data types: F16/F32
* @param[in] src0_stride_x Stride of the source matrix in X dimension (in bytes)
@@ -885,7 +946,13 @@ __kernel void gemm_mm_floating_point(IMAGE_DECLARATION(src0),
// Add offset for batched GEMM
src_addr.s0 += get_global_id(2) * src0_stride_z;
+
+#if defined(MATRIX_B_DEPTH)
+ // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
+ src_addr.s1 += (get_global_id(2) % MATRIX_B_DEPTH) * src1_stride_z;
+#else // defined(MATRIX_B_DEPTH)
src_addr.s1 += get_global_id(2) * src1_stride_z;
+#endif // defined(MATRIX_B_DEPTH)
int end_row_vec_a = src_addr.s0 + (COLS_A * sizeof(DATA_TYPE));
@@ -1013,6 +1080,8 @@ __kernel void gemm_mm_floating_point(IMAGE_DECLARATION(src0),
* This kernel optimally uses -DNUM_ELEMS_PROCESSED_PER_THREAD_X=4.
* @note The number of matrix A columns must be passed at compile time using -DCOLS_A.
* @note The optional value of scalar alpha is passed at compile time using -DALPHA=alpha
+ * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (i.e. -DMATRIX_B_DEPTH=16)
+ * This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (i.e. a = [K, M, 16, Batches], b = [N, K, 16])
*
* @param[in] src0_ptr Pointer to the source matrix. Supported data types: F16/F32
* @param[in] src0_stride_x Stride of the source matrix in X dimension (in bytes)
@@ -1054,8 +1123,12 @@ __kernel void gemm_mm_floating_point_f32_bifrost(IMAGE_DECLARATION(src0),
// Add offset for batched GEMM
src_addr.s0 += get_global_id(2) * src0_stride_z;
- // For convolution layer we do not want to slide the matrix B along Z
+#if defined(MATRIX_B_DEPTH)
+ // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
+ src_addr.s1 += (get_global_id(2) % MATRIX_B_DEPTH) * src1_stride_z;
+#else // defined(MATRIX_B_DEPTH)
src_addr.s1 += get_global_id(2) * src1_stride_z;
+#endif // defined(MATRIX_B_DEPTH)
// Address boundary for matrix A
int end_row_vec_a = src_addr.s0 + (COLS_A * sizeof(float));
@@ -1251,6 +1324,8 @@ __kernel void gemm_mm_floating_point_f32_bifrost(IMAGE_DECLARATION(src0),
* This kernel optimally uses -DNUM_ELEMS_PROCESSED_PER_THREAD_X=2.
* @note The number of matrix A columns must be passed at compile time using -DCOLS_A.
* @note The optional value of scalar alpha is passed at compile time using -DALPHA=alpha if alpha!=1.0f.
+ * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (i.e. -DMATRIX_B_DEPTH=16)
+ * This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (i.e. a = [K, M, 16, Batches], b = [N, K, 16])
*
* @param[in] src0_ptr Pointer to the source matrix. Supported data types: F16/F32
* @param[in] src0_stride_x Stride of the source matrix in X dimension (in bytes)
@@ -1293,8 +1368,12 @@ __kernel void gemm_mm_floating_point_f32_bifrost_1000(IMAGE_DECLARATION(src0),
// Add offset for batched GEMM
src_addr.s0 += get_global_id(2) * src0_stride_z;
- // For convolution layer we do not want to slide the matrix B along Z
+#if defined(MATRIX_B_DEPTH)
+ // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
+ src_addr.s1 += (get_global_id(2) % MATRIX_B_DEPTH) * src1_stride_z;
+#else // defined(MATRIX_B_DEPTH)
src_addr.s1 += get_global_id(2) * src1_stride_z;
+#endif // defined(MATRIX_B_DEPTH)
// Address boundary for the matrix A
int end_row_vec_a = src_addr.s0 + (COLS_A * sizeof(float));
@@ -1460,6 +1539,8 @@ __kernel void gemm_mm_floating_point_f32_bifrost_1000(IMAGE_DECLARATION(src0),
* @note The number matrix A columns, the number of elements processed per thread along the Y direction and the alpha's value need to be passed at compile time using -DCOLS_A, -DNUM_ELEMS_PROCESSED_PER_THREAD_Y and -DALPHA
* @note The fixed point position need to be passed at compile time using -DFIXED_POINT_POSITION
* @note The optional alpha value must be passed in 8 bit fixed point format using -DALPHA
+ * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (i.e. -DMATRIX_B_DEPTH=16)
+ * This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (i.e. a = [K, M, 16, Batches], b = [N, K, 16])
*
* @param[in] src0_ptr Pointer to the source matrix. Supported data types: QS8/QS16
* @param[in] src0_stride_x Stride of the source matrix in X dimension (in bytes)
@@ -1500,7 +1581,13 @@ __kernel void gemm_mm_qs8(IMAGE_DECLARATION(src0),
// Add offset for batched GEMM
src_addr.s0 += get_global_id(2) * src0_stride_z;
+
+#if defined(MATRIX_B_DEPTH)
+ // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
+ src_addr.s1 += (get_global_id(2) % MATRIX_B_DEPTH) * src1_stride_z;
+#else // defined(MATRIX_B_DEPTH)
src_addr.s1 += get_global_id(2) * src1_stride_z;
+#endif // defined(MATRIX_B_DEPTH)
int end_row_vec_a = src_addr.s0 + (COLS_A * sizeof(char));
@@ -1636,6 +1723,8 @@ __kernel void gemm_mm_qs8(IMAGE_DECLARATION(src0),
* @note The number of matrix A columns, the number of elements processed per thread along the Y direction and the alpha's value need to be passed at compile time using -DCOLS_A, -DNUM_ELEMS_PROCESSED_PER_THREAD_Y and -DALPHA
* @note The fixed point position need to be passed at compile time using -DFIXED_POINT_POSITION
* @note The optional alpha value must be passed in 16 bit fixed point format using -DALPHA
+ * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (i.e. -DMATRIX_B_DEPTH=16)
+ * This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (i.e. a = [K, M, 16, Batches], b = [N, K, 16])
*
* @param[in] src0_ptr Pointer to the source matrix. Supported data types: QS8/QS16
* @param[in] src0_stride_x Stride of the source matrix in X dimension (in bytes)
@@ -1676,7 +1765,13 @@ __kernel void gemm_mm_qs16(IMAGE_DECLARATION(src0),
// Add offset for batched GEMM
src_addr.s0 += get_global_id(2) * src0_stride_z;
+
+#if defined(MATRIX_B_DEPTH)
+ // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
+ src_addr.s1 += (get_global_id(2) % MATRIX_B_DEPTH) * src1_stride_z;
+#else // defined(MATRIX_B_DEPTH)
src_addr.s1 += get_global_id(2) * src1_stride_z;
+#endif // defined(MATRIX_B_DEPTH)
int end_row_vec_a = src_addr.s0 + (COLS_A * sizeof(short));
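
Taken together, the kernels patched above perform one small GEMM per Winograd channel per batch, reusing the 16 planes of matrix B across the batch. A plain C sketch of that outer loop, with hypothetical names (gemm_2d and the stride parameters are not ComputeLibrary APIs; strides here are counted in float elements), looks like this:

/* Hypothetical 2D GEMM on a single plane of A and B; declared only so the sketch compiles. */
void gemm_2d(const float *A, const float *B, float *C);

/* Illustrative view of the batched multiplication for 2D Winograd:
   a = [K, M, 16, Batches], b = [N, K, 16], c = [N, M, 16, Batches]. */
static void batched_gemm_winograd(const float *a, const float *b, float *c,
                                  int batches, int matrix_b_depth,
                                  int a_stride_z, int b_stride_z, int c_stride_z)
{
    for (int z = 0; z < matrix_b_depth * batches; ++z)
    {
        const float *A = a + z * a_stride_z;                    /* A slides with z              */
        const float *B = b + (z % matrix_b_depth) * b_stride_z; /* B wraps onto its 16 planes   */
        gemm_2d(A, B, c + z * c_stride_z);                      /* C[z] = A[z] x B[z % 16]      */
    }
}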