diff options
Diffstat (limited to 'src')
-rw-r--r-- | src/core/CL/cl_kernels/gemm.cl | 3 | ||||
-rw-r--r-- | src/core/CL/kernels/CLGEMMMatrixMultiplyKernel.cpp | 9 |
2 files changed, 6 insertions, 6 deletions
diff --git a/src/core/CL/cl_kernels/gemm.cl b/src/core/CL/cl_kernels/gemm.cl
index a5b0acbe9c..00b130f5a9 100644
--- a/src/core/CL/cl_kernels/gemm.cl
+++ b/src/core/CL/cl_kernels/gemm.cl
@@ -1851,6 +1851,9 @@ __kernel void gemm_mm_qs16(IMAGE_DECLARATION(src0),
     // Compute dst address
     __global uchar *dst_addr = offset(&dst, 0, 0);
 
+    // Add offset for batched GEMM
+    dst_addr += get_global_id(2) * dst_stride_z;
+
     // Multiply by the weight of matrix product and store the result
     short8 acc_qs16;
     acc_qs16 = convert_short8_sat(acc0);
diff --git a/src/core/CL/kernels/CLGEMMMatrixMultiplyKernel.cpp b/src/core/CL/kernels/CLGEMMMatrixMultiplyKernel.cpp
index 7b785bb8da..dc9c59d2d0 100644
--- a/src/core/CL/kernels/CLGEMMMatrixMultiplyKernel.cpp
+++ b/src/core/CL/kernels/CLGEMMMatrixMultiplyKernel.cpp
@@ -162,12 +162,9 @@ inline std::pair<Status, Window> validate_and_configure_window(ITensorInfo *inpu
 
     // Collapse along the Z direction
     // This collapse needs to be here in order to tune the Z dimension of LWS
-    Window collapsed = win;
-    if(input1->num_dimensions() > 1)
-    {
-        const unsigned int dimension_to_collapse = std::min(static_cast<unsigned int>(input1->num_dimensions() - 1), 2u);
-        collapsed = win.collapse(win, dimension_to_collapse);
-    }
+    Window collapsed = win;
+    const unsigned int dimension_to_collapse = std::min(static_cast<unsigned int>(output->num_dimensions()), 2u);
+    collapsed = win.collapse(win, dimension_to_collapse);
 
     Status err = (window_changed) ? ARM_COMPUTE_CREATE_ERROR(ErrorCode::RUNTIME_ERROR, "Insufficient Padding!") : Status{};
     return std::make_pair(err, collapsed);