diff options
Diffstat (limited to 'src/core/NEON/kernels/arm_gemm/gemv_batched.hpp')
-rw-r--r-- | src/core/NEON/kernels/arm_gemm/gemv_batched.hpp | 9 |
1 files changed, 5 insertions, 4 deletions
diff --git a/src/core/NEON/kernels/arm_gemm/gemv_batched.hpp b/src/core/NEON/kernels/arm_gemm/gemv_batched.hpp index d91b44b9a8..d65971e47d 100644 --- a/src/core/NEON/kernels/arm_gemm/gemv_batched.hpp +++ b/src/core/NEON/kernels/arm_gemm/gemv_batched.hpp @@ -36,11 +36,12 @@ private: UniqueGemmCommon<To, Tr> _subgemm = nullptr; public: - GemvBatched(const CPUInfo &ci, const unsigned int M, const unsigned int N, const unsigned int K, - const unsigned int nbatches, const unsigned int nmulti, const bool trA, const bool trB, - const To alpha, const To beta, const int maxthreads, const bool pretransposed_hint) { + GemvBatched(const GemmArgs<Tr> &args) { /* Just create a subgemm with batches->M */ - _subgemm = gemm<To,Tr>(ci, nbatches, N, K, 1, nmulti, trA, trB, alpha, beta, maxthreads, pretransposed_hint); + GemmArgs<Tr> newargs = args; + newargs._Msize = args._nbatches; + newargs._nbatches = 1; + _subgemm = gemm<To,Tr>(newargs, nullptr); } void set_arrays(const To *A, const int lda, const int A_batch_stride, const int A_multi_stride, |