aboutsummaryrefslogtreecommitdiff
path: root/src/core/NEON/kernels/arm_gemm/gemv_batched.hpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/core/NEON/kernels/arm_gemm/gemv_batched.hpp')
-rw-r--r--src/core/NEON/kernels/arm_gemm/gemv_batched.hpp9
1 files changed, 5 insertions, 4 deletions
diff --git a/src/core/NEON/kernels/arm_gemm/gemv_batched.hpp b/src/core/NEON/kernels/arm_gemm/gemv_batched.hpp
index d91b44b9a8..d65971e47d 100644
--- a/src/core/NEON/kernels/arm_gemm/gemv_batched.hpp
+++ b/src/core/NEON/kernels/arm_gemm/gemv_batched.hpp
@@ -36,11 +36,12 @@ private:
UniqueGemmCommon<To, Tr> _subgemm = nullptr;
public:
- GemvBatched(const CPUInfo &ci, const unsigned int M, const unsigned int N, const unsigned int K,
- const unsigned int nbatches, const unsigned int nmulti, const bool trA, const bool trB,
- const To alpha, const To beta, const int maxthreads, const bool pretransposed_hint) {
+ GemvBatched(const GemmArgs<Tr> &args) {
/* Just create a subgemm with batches->M */
- _subgemm = gemm<To,Tr>(ci, nbatches, N, K, 1, nmulti, trA, trB, alpha, beta, maxthreads, pretransposed_hint);
+ GemmArgs<Tr> newargs = args;
+ newargs._Msize = args._nbatches;
+ newargs._nbatches = 1;
+ _subgemm = gemm<To,Tr>(newargs, nullptr);
}
void set_arrays(const To *A, const int lda, const int A_batch_stride, const int A_multi_stride,