about summary refs log tree commit diff
path: root/src/core/NEON/kernels/arm_gemm/gemv_pretransposed.hpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/core/NEON/kernels/arm_gemm/gemv_pretransposed.hpp')
-rw-r--r-- src/core/NEON/kernels/arm_gemm/gemv_pretransposed.hpp 21
1 files changed, 15 insertions, 6 deletions
diff --git a/src/core/NEON/kernels/arm_gemm/gemv_pretransposed.hpp b/src/core/NEON/kernels/arm_gemm/gemv_pretransposed.hpp
index e53ddb26c1..842339ef23 100644
--- a/src/core/NEON/kernels/arm_gemm/gemv_pretransposed.hpp
+++ b/src/core/NEON/kernels/arm_gemm/gemv_pretransposed.hpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017-2018 ARM Limited.
+ * Copyright (c) 2017-2019 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -68,12 +68,21 @@ public:
GemvPretransposed(GemvPretransposed &) = delete;
GemvPretransposed & operator= (GemvPretransposed &) = delete;
- GemvPretransposed(const CPUInfo *ci, const unsigned int N, const unsigned int K, const unsigned int nmultis, const bool trB, const Tr beta) :
- _Nsize(N), _Ksize(K), _nmultis(nmultis), _trB(trB), _beta(beta), _ci(ci),
- _buffer_per_multi(_Ksize * iceildiv(_Nsize, strategy::A_interleave) * strategy::A_interleave) {
+ GemvPretransposed(const GemmArgs<Tr> &args)
+ : _Nsize(args._Nsize), _Ksize(args._Ksize), _nmultis(args._nmulti), _trB(args._trB), _beta(args._beta), _ci(args._ci),
+ _buffer_per_multi(_Ksize * iceildiv(_Nsize, strategy::A_interleave) * strategy::A_interleave) {
/* For now don't do any blocking. TODO: figure out if we should. */
- m_block = K;
- n_block = N;
+ if (args._cfg && args._cfg->inner_block_size) {
+ m_block = args._cfg->inner_block_size;
+ } else {
+ m_block = _Ksize;
+ }
+
+ if (args._cfg && args._cfg->outer_block_size) {
+ n_block = args._cfg->outer_block_size;
+ } else {
+ n_block = _Nsize;
+ }
}
// Window is number of out_width blocks, times number of multis.