diff options
Diffstat (limited to 'src/core/NEON/kernels/arm_gemm/gemm_interleaved.hpp')
-rw-r--r-- | src/core/NEON/kernels/arm_gemm/gemm_interleaved.hpp | 14 |
1 files changed, 9 insertions, 5 deletions
diff --git a/src/core/NEON/kernels/arm_gemm/gemm_interleaved.hpp b/src/core/NEON/kernels/arm_gemm/gemm_interleaved.hpp index faff9acd2e..efd984561d 100644 --- a/src/core/NEON/kernels/arm_gemm/gemm_interleaved.hpp +++ b/src/core/NEON/kernels/arm_gemm/gemm_interleaved.hpp @@ -68,8 +68,7 @@ class GemmInterleaved : public GemmCommon<To, Tr> { const bool _trA; const bool _trB; - const Tr _alpha; - const Tr _beta; + const Activation _act; const int _maxthreads; int _nthreads; @@ -297,9 +296,14 @@ class GemmInterleaved : public GemmCommon<To, Tr> { #ifdef CYCLE_PROFILING auto p=prof.ScopedProfiler(PROFILE_MERGE, (strategy::out_height() * bblocks * strategy::out_width() * sizeof(Tr))); #endif + /* Only activate on last pass, only add bias on first pass, ask for accumulation on any non-first pass */ + const bool first_pass = current.k0()==0; + const bool last_pass = current.kmax()==_Ksize; + strat.transforms.Merge(this->_Cptr + (batch * this->_C_batch_stride) + (current.multi() * this->_C_multi_stride), c_panel, this->_ldc, y, ymax, current.x0(), current.xmax(), - _alpha, (current.k0()==0 ? _beta : static_cast<Tr>(1))); + ((first_pass && this->_bias) ? this->_bias + (current.multi() * this->_bias_multi_stride) : nullptr), + (last_pass ? _act : Activation()), !first_pass); } } } @@ -317,10 +321,10 @@ public: GemmInterleaved & operator= (GemmInterleaved &) = delete; /* Constructor */ - GemmInterleaved(const GemmArgs<Tr> &args) + GemmInterleaved(const GemmArgs &args) : _ci(args._ci), _Msize(args._Msize), _Nsize(args._Nsize), _Ksize(args._Ksize), _nbatches(args._nbatches), _nmulti(args._nmulti), _trA(args._trA), _trB(args._trB), - _alpha(args._alpha), _beta(args._beta), _maxthreads(args._maxthreads), _nthreads(args._maxthreads), + _act(args._act), _maxthreads(args._maxthreads), _nthreads(args._maxthreads), _pretransposed(args._pretransposed_hint) { const unsigned int L1_size = _ci->get_L1_cache_size(); const unsigned int L2_size = _ci->get_L2_cache_size(); |