aboutsummaryrefslogtreecommitdiff
path: root/src/core/NEON/kernels/arm_gemm/gemv_native_transposed.hpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/core/NEON/kernels/arm_gemm/gemv_native_transposed.hpp')
-rw-r--r-- src/core/NEON/kernels/arm_gemm/gemv_native_transposed.hpp 19
1 file changed, 15 insertions, 4 deletions
diff --git a/src/core/NEON/kernels/arm_gemm/gemv_native_transposed.hpp b/src/core/NEON/kernels/arm_gemm/gemv_native_transposed.hpp
index 55b1f9bbe6..49681ec404 100644
--- a/src/core/NEON/kernels/arm_gemm/gemv_native_transposed.hpp
+++ b/src/core/NEON/kernels/arm_gemm/gemv_native_transposed.hpp
@@ -53,7 +53,7 @@ class GemvNativeTransposed : public GemmCommon<To, Tr> {
const unsigned int _nmultis;
- const Tr _beta;
+ const Activation _act;
const CPUInfo * const _ci;
@@ -64,8 +64,8 @@ public:
GemvNativeTransposed(GemvNativeTransposed &) = delete;
GemvNativeTransposed & operator= (GemvNativeTransposed &) = delete;
- GemvNativeTransposed(const GemmArgs<Tr> &args)
- : _Nsize(args._Nsize), _Ksize(args._Ksize), _nmultis(args._nmulti), _beta(args._beta), _ci(args._ci) {
+ GemvNativeTransposed(const GemmArgs &args)
+ : _Nsize(args._Nsize), _Ksize(args._Ksize), _nmultis(args._nmulti), _act(args._act), _ci(args._ci) {
/* For now don't do any blocking. TODO: figure out if we should. */
m_block = _Ksize;
n_block = _Nsize;
@@ -111,7 +111,18 @@ public:
strat.kernel(this->_Bptr + (multi * this->_B_multi_stride) + (m0 * this->_ldb) + n0,
this->_Aptr + (multi * this->_A_multi_stride) + m0,
this->_Cptr + (multi * this->_C_multi_stride) + n0,
- _beta, this->_ldb, (mmax-m0), (nmax-n0));
+ static_cast<Tr>(0), this->_ldb, (mmax-m0), (nmax-n0));
+
+ // Handle activation separately for now
+ if (this->_bias) {
+ activator<true>(this->_Cptr + (multi * this->_C_multi_stride) + n0, 0,
+ this->_bias + (multi * this->_bias_multi_stride) + n0,
+ _act, 1, (nmax-n0));
+ } else {
+ activator<false>(this->_Cptr + (multi * this->_C_multi_stride) + n0, 0,
+ static_cast<const Tr *>(nullptr),
+ _act, 1, (nmax-n0));
+ }
}
}
}