aboutsummaryrefslogtreecommitdiff
path: root/src/core/NEON/kernels/arm_gemm/gemv_batched.hpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/core/NEON/kernels/arm_gemm/gemv_batched.hpp')
-rw-r--r--src/core/NEON/kernels/arm_gemm/gemv_batched.hpp6
1 files changed, 3 insertions, 3 deletions
diff --git a/src/core/NEON/kernels/arm_gemm/gemv_batched.hpp b/src/core/NEON/kernels/arm_gemm/gemv_batched.hpp
index ad504f2664..aa03fb6aa1 100644
--- a/src/core/NEON/kernels/arm_gemm/gemv_batched.hpp
+++ b/src/core/NEON/kernels/arm_gemm/gemv_batched.hpp
@@ -31,9 +31,9 @@ namespace arm_gemm {
* efficiently as a GEMM (with M'=nbatches and nbatches'=1). This wrapper
* implements this. */
template<typename To, typename Tr>
-class GemvBatched : public GemmCommon<To, Tr> {
+class GemvBatched : public GemmCommon<To, To, Tr> {
private:
- UniqueGemmCommon<To, Tr> _subgemm = nullptr;
+ UniqueGemmCommon<To, To, Tr> _subgemm = nullptr;
public:
GemvBatched(const GemmArgs &args) {
@@ -42,7 +42,7 @@ public:
newargs._Msize = args._nbatches;
newargs._nbatches = 1;
newargs._cfg = nullptr;
- _subgemm = gemm<To,Tr>(newargs);
+ _subgemm = gemm<To,To,Tr>(newargs);
}
void set_arrays(const To *A, const int, const int A_batch_stride, const int A_multi_stride,