diff options
Diffstat (limited to 'src/core/NEON/kernels/arm_gemm/gemv_batched.hpp')
-rw-r--r-- | src/core/NEON/kernels/arm_gemm/gemv_batched.hpp | 14 |
1 files changed, 8 insertions, 6 deletions
diff --git a/src/core/NEON/kernels/arm_gemm/gemv_batched.hpp b/src/core/NEON/kernels/arm_gemm/gemv_batched.hpp index 4453ee8243..be2f5614be 100644 --- a/src/core/NEON/kernels/arm_gemm/gemv_batched.hpp +++ b/src/core/NEON/kernels/arm_gemm/gemv_batched.hpp @@ -36,9 +36,9 @@ private: UniqueGemmCommon<To, Tr> _subgemm = nullptr; public: - GemvBatched(const GemmArgs<Tr> &args) { + GemvBatched(const GemmArgs &args) { /* Just create a subgemm with batches->M */ - GemmArgs<Tr> newargs = args; + GemmArgs newargs = args; newargs._Msize = args._nbatches; newargs._nbatches = 1; newargs._cfg = nullptr; @@ -47,13 +47,15 @@ public: void set_arrays(const To *A, const int lda, const int A_batch_stride, const int A_multi_stride, const To *B, const int ldb, const int B_multi_stride, - Tr *C, const int ldc, const int C_batch_stride, const int C_multi_stride) override { - UNUSED(lda); - UNUSED(ldc); + Tr *C, const int ldc, const int C_batch_stride, const int C_multi_stride, + const Tr *bias, const int bias_multi_stride) override { /* A and C's batch stride becomes their new row stride. New batch stride is 0 as nbatches for subgemm is always 1. */ _subgemm->set_arrays(A, A_batch_stride, 0, A_multi_stride, B, ldb, B_multi_stride, - C, C_batch_stride, 0, C_multi_stride); + C, C_batch_stride, 0, C_multi_stride, + bias, bias_multi_stride); + UNUSED(lda); + UNUSED(ldc); } unsigned int get_window_size() const override { |