aboutsummaryrefslogtreecommitdiff
path: root/src/core/NEON/kernels/arm_gemm/gemv_batched.hpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/core/NEON/kernels/arm_gemm/gemv_batched.hpp')
-rw-r--r--src/core/NEON/kernels/arm_gemm/gemv_batched.hpp14
1 files changed, 8 insertions, 6 deletions
diff --git a/src/core/NEON/kernels/arm_gemm/gemv_batched.hpp b/src/core/NEON/kernels/arm_gemm/gemv_batched.hpp
index 4453ee8243..be2f5614be 100644
--- a/src/core/NEON/kernels/arm_gemm/gemv_batched.hpp
+++ b/src/core/NEON/kernels/arm_gemm/gemv_batched.hpp
@@ -36,9 +36,9 @@ private:
UniqueGemmCommon<To, Tr> _subgemm = nullptr;
public:
- GemvBatched(const GemmArgs<Tr> &args) {
+ GemvBatched(const GemmArgs &args) {
/* Just create a subgemm with batches->M */
- GemmArgs<Tr> newargs = args;
+ GemmArgs newargs = args;
newargs._Msize = args._nbatches;
newargs._nbatches = 1;
newargs._cfg = nullptr;
@@ -47,13 +47,15 @@ public:
void set_arrays(const To *A, const int lda, const int A_batch_stride, const int A_multi_stride,
const To *B, const int ldb, const int B_multi_stride,
- Tr *C, const int ldc, const int C_batch_stride, const int C_multi_stride) override {
- UNUSED(lda);
- UNUSED(ldc);
+ Tr *C, const int ldc, const int C_batch_stride, const int C_multi_stride,
+ const Tr *bias, const int bias_multi_stride) override {
/* A and C's batch stride becomes their new row stride. New batch stride is 0 as nbatches for subgemm is always 1. */
_subgemm->set_arrays(A, A_batch_stride, 0, A_multi_stride,
B, ldb, B_multi_stride,
- C, C_batch_stride, 0, C_multi_stride);
+ C, C_batch_stride, 0, C_multi_stride,
+ bias, bias_multi_stride);
+ UNUSED(lda);
+ UNUSED(ldc);
}
unsigned int get_window_size() const override {