about | summary | refs | log | tree | commit | diff
path: root/src/core/NEON/kernels/arm_gemm/gemv_pretransposed.hpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/core/NEON/kernels/arm_gemm/gemv_pretransposed.hpp')
-rw-r--r-- src/core/NEON/kernels/arm_gemm/gemv_pretransposed.hpp 21
1 files changed, 16 insertions, 5 deletions
diff --git a/src/core/NEON/kernels/arm_gemm/gemv_pretransposed.hpp b/src/core/NEON/kernels/arm_gemm/gemv_pretransposed.hpp
index 92064180a2..26fdfba8ff 100644
--- a/src/core/NEON/kernels/arm_gemm/gemv_pretransposed.hpp
+++ b/src/core/NEON/kernels/arm_gemm/gemv_pretransposed.hpp
@@ -26,7 +26,7 @@
#include <stdio.h>
#include "arm_gemm.hpp"
-
+#include "bias_adder.hpp"
#include "mergeresults.hpp"
#include "transform.hpp"
@@ -53,7 +53,7 @@ class GemvPretransposed : public GemmCommon<To, Tr> {
const bool _trB;
- const Tr _beta;
+ const Activation _act;
const CPUInfo * const _ci;
@@ -68,8 +68,8 @@ public:
GemvPretransposed(GemvPretransposed &) = delete;
GemvPretransposed & operator= (GemvPretransposed &) = delete;
- GemvPretransposed(const GemmArgs<Tr> &args)
- : _Nsize(args._Nsize), _Ksize(args._Ksize), _nmultis(args._nmulti), _trB(args._trB), _beta(args._beta), _ci(args._ci),
+ GemvPretransposed(const GemmArgs &args)
+ : _Nsize(args._Nsize), _Ksize(args._Ksize), _nmultis(args._nmulti), _trB(args._trB), _act(args._act), _ci(args._ci),
_buffer_per_multi(_Ksize * iceildiv(_Nsize, strategy::A_interleave()) * strategy::A_interleave()) {
/* For now don't do any blocking. TODO: figure out if we should. */
if (args._cfg && args._cfg->inner_block_size) {
@@ -128,7 +128,18 @@ public:
(_Ksize * strategy::A_interleave()),
this->_Aptr + (multi * this->_A_multi_stride) + m0,
this->_Cptr + (multi * this->_C_multi_stride) + n,
- _beta, (mmax-m0), (nmax-n));
+ static_cast<Tr>(0), (mmax-m0), (nmax-n));
+
+ // Handle activation separately for now
+ if (this->_bias) {
+ activator<true>(this->_Cptr + (multi * this->_C_multi_stride) + n, 0,
+ this->_bias + (multi * this->_bias_multi_stride) + n,
+ _act, 1, (nmax-n));
+ } else {
+ activator<false>(this->_Cptr + (multi * this->_C_multi_stride) + n, 0,
+ static_cast<const Tr *>(nullptr),
+ _act, 1, (nmax-n));
+ }
}
}
}