diff options
Diffstat (limited to 'src/core/NEON/kernels/arm_gemm/gemv_native_transposed.hpp')
-rw-r--r-- | src/core/NEON/kernels/arm_gemm/gemv_native_transposed.hpp | 19 |
1 files changed, 15 insertions, 4 deletions
diff --git a/src/core/NEON/kernels/arm_gemm/gemv_native_transposed.hpp b/src/core/NEON/kernels/arm_gemm/gemv_native_transposed.hpp index 55b1f9bbe6..49681ec404 100644 --- a/src/core/NEON/kernels/arm_gemm/gemv_native_transposed.hpp +++ b/src/core/NEON/kernels/arm_gemm/gemv_native_transposed.hpp @@ -53,7 +53,7 @@ class GemvNativeTransposed : public GemmCommon<To, Tr> { const unsigned int _nmultis; - const Tr _beta; + const Activation _act; const CPUInfo * const _ci; @@ -64,8 +64,8 @@ public: GemvNativeTransposed(GemvNativeTransposed &) = delete; GemvNativeTransposed & operator= (GemvNativeTransposed &) = delete; - GemvNativeTransposed(const GemmArgs<Tr> &args) - : _Nsize(args._Nsize), _Ksize(args._Ksize), _nmultis(args._nmulti), _beta(args._beta), _ci(args._ci) { + GemvNativeTransposed(const GemmArgs &args) + : _Nsize(args._Nsize), _Ksize(args._Ksize), _nmultis(args._nmulti), _act(args._act), _ci(args._ci) { /* For now don't do any blocking. TODO: figure out if we should. */ m_block = _Ksize; n_block = _Nsize; @@ -111,7 +111,18 @@ public: strat.kernel(this->_Bptr + (multi * this->_B_multi_stride) + (m0 * this->_ldb) + n0, this->_Aptr + (multi * this->_A_multi_stride) + m0, this->_Cptr + (multi * this->_C_multi_stride) + n0, - _beta, this->_ldb, (mmax-m0), (nmax-n0)); + static_cast<Tr>(0), this->_ldb, (mmax-m0), (nmax-n0)); + + // Handle activation separately for now + if (this->_bias) { + activator<true>(this->_Cptr + (multi * this->_C_multi_stride) + n0, 0, + this->_bias + (multi * this->_bias_multi_stride) + n0, + _act, 1, (nmax-n0)); + } else { + activator<false>(this->_Cptr + (multi * this->_C_multi_stride) + n0, 0, + static_cast<const Tr *>(nullptr), + _act, 1, (nmax-n0)); + } } } } |