aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/core/CL/cl_kernels/gemm.cl3
-rw-r--r--src/core/CL/kernels/CLGEMMMatrixMultiplyKernel.cpp9
2 files changed, 6 insertions, 6 deletions
diff --git a/src/core/CL/cl_kernels/gemm.cl b/src/core/CL/cl_kernels/gemm.cl
index a5b0acbe9c..00b130f5a9 100644
--- a/src/core/CL/cl_kernels/gemm.cl
+++ b/src/core/CL/cl_kernels/gemm.cl
@@ -1851,6 +1851,9 @@ __kernel void gemm_mm_qs16(IMAGE_DECLARATION(src0),
// Compute dst address
__global uchar *dst_addr = offset(&dst, 0, 0);
+ // Add offset for batched GEMM
+ dst_addr += get_global_id(2) * dst_stride_z;
+
// Multiply by the weight of matrix product and store the result
short8 acc_qs16;
acc_qs16 = convert_short8_sat(acc0);
diff --git a/src/core/CL/kernels/CLGEMMMatrixMultiplyKernel.cpp b/src/core/CL/kernels/CLGEMMMatrixMultiplyKernel.cpp
index 7b785bb8da..dc9c59d2d0 100644
--- a/src/core/CL/kernels/CLGEMMMatrixMultiplyKernel.cpp
+++ b/src/core/CL/kernels/CLGEMMMatrixMultiplyKernel.cpp
@@ -162,12 +162,9 @@ inline std::pair<Status, Window> validate_and_configure_window(ITensorInfo *inpu
// Collapse along the Z direction
// This collapse needs to be here in order to tune the Z dimension of LWS
- Window collapsed = win;
- if(input1->num_dimensions() > 1)
- {
- const unsigned int dimension_to_collapse = std::min(static_cast<unsigned int>(input1->num_dimensions() - 1), 2u);
- collapsed = win.collapse(win, dimension_to_collapse);
- }
+ Window collapsed = win;
+ const unsigned int dimension_to_collapse = std::min(static_cast<unsigned int>(output->num_dimensions()), 2u);
+ collapsed = win.collapse(win, dimension_to_collapse);
Status err = (window_changed) ? ARM_COMPUTE_CREATE_ERROR(ErrorCode::RUNTIME_ERROR, "Insufficient Padding!") : Status{};
return std::make_pair(err, collapsed);