diff options
-rw-r--r-- | src/core/NEON/kernels/arm_gemm/gemm_hybrid_indirect.hpp | 31 |
1 files changed, 30 insertions, 1 deletions
diff --git a/src/core/NEON/kernels/arm_gemm/gemm_hybrid_indirect.hpp b/src/core/NEON/kernels/arm_gemm/gemm_hybrid_indirect.hpp index 5d5f21507f..41fecc6bec 100644 --- a/src/core/NEON/kernels/arm_gemm/gemm_hybrid_indirect.hpp +++ b/src/core/NEON/kernels/arm_gemm/gemm_hybrid_indirect.hpp @@ -79,7 +79,36 @@ void run_hybrid_kernel<Nothing, false>::run( #endif UNUSED(kern_k); - strat.kernel(num_strings, string_ptr, A_arg, M, N, b_ptr, output_arg, bias_ptr, act, accumulate); + /* Indirect hybrid kernels read the full width of the bias. So we need to detect the case where we are writing + * a partial block and pad the bias for that block. */ + if (bias_ptr && !accumulate && (N % strategy::out_width() != 0)) { + /* Break N into "N_bulk" (a multiple of output width) and "N_remainder" */ + unsigned int N_remainder = N % strategy::out_width(); + unsigned int N_bulk = N - N_remainder; + + /* Output argument to be used for the tail */ + IndirectOutputArg<Tr> offset_output = output_arg; + + /* If there is a "bulk" to be processed, handle that and update "offset_output" appropriately. */ + if (N_bulk > 0) { + strat.kernel(num_strings, string_ptr, A_arg, M, N_bulk, b_ptr, output_arg, bias_ptr, act, accumulate); + + if (output_arg.is_indirect) { + offset_output = IndirectOutputArg<Tr>(output_arg.indirect.ptr, output_arg.indirect.offset + N_bulk); + } else { + offset_output = IndirectOutputArg<Tr>(output_arg.direct.base + N_bulk, output_arg.direct.stride); + } + } + + /* Pad the bias buffer for the remainder */ + Tr *bias_pad_buffer = reinterpret_cast<Tr *>(alloca(strategy::out_width() * sizeof(Tr))); + memcpy(bias_pad_buffer, bias_ptr + N_bulk, N_remainder * sizeof(Tr)); + + /* Process the remainder, offsetting the B pointer as needed. */ + strat.kernel(num_strings, string_ptr, A_arg, M, N_remainder, b_ptr + (N_bulk * kern_k), offset_output, bias_pad_buffer, act, accumulate); + } else { + strat.kernel(num_strings, string_ptr, A_arg, M, N, b_ptr, output_arg, bias_ptr, act, accumulate); + } } template<> |