From 48b3ef89de5f21a0169d8416e3d54081f82c7bf8 Mon Sep 17 00:00:00 2001
From: Georgios Pinitas
Date: Mon, 14 Oct 2019 19:03:09 +0100
Subject: COMPMID-2577: Fuse bias addition and activation in gemm assembly kernels

Change-Id: I7f52112d2d05b1ea3d3f3d4b19b8eafab05d6c44
Signed-off-by: Georgios Pinitas
Reviewed-on: https://review.mlplatform.org/c/2141
Comments-Addressed: Arm Jenkins
Tested-by: Arm Jenkins
Reviewed-by: Pablo Marquez
---
 src/core/NEON/kernels/arm_gemm/gemm_interleaved.hpp | 14 +++++++++-----
 1 file changed, 9 insertions(+), 5 deletions(-)

(limited to 'src/core/NEON/kernels/arm_gemm/gemm_interleaved.hpp')

diff --git a/src/core/NEON/kernels/arm_gemm/gemm_interleaved.hpp b/src/core/NEON/kernels/arm_gemm/gemm_interleaved.hpp
index faff9acd2e..efd984561d 100644
--- a/src/core/NEON/kernels/arm_gemm/gemm_interleaved.hpp
+++ b/src/core/NEON/kernels/arm_gemm/gemm_interleaved.hpp
@@ -68,8 +68,7 @@ class GemmInterleaved : public GemmCommon<To, Tr> {
     const bool _trA;
     const bool _trB;
 
-    const Tr _alpha;
-    const Tr _beta;
+    const Activation _act;
 
     const int _maxthreads;
     int _nthreads;
@@ -297,9 +296,14 @@ class GemmInterleaved : public GemmCommon<To, Tr> {
 #ifdef CYCLE_PROFILING
                     auto p=prof.ScopedProfiler(PROFILE_MERGE, (strategy::out_height() * bblocks * strategy::out_width() * sizeof(Tr)));
 #endif
+                    /* Only activate on last pass, only add bias on first pass, ask for accumulation on any non-first pass */
+                    const bool first_pass = current.k0()==0;
+                    const bool last_pass = current.kmax()==_Ksize;
+
                     strat.transforms.Merge(this->_Cptr + (batch * this->_C_batch_stride) + (current.multi() * this->_C_multi_stride),
                                            c_panel, this->_ldc, y, ymax, current.x0(), current.xmax(),
-                                           _alpha, (current.k0()==0 ? _beta : static_cast<Tr>(1)));
+                                           ((first_pass && this->_bias) ? this->_bias + (current.multi() * this->_bias_multi_stride) : nullptr),
+                                           (last_pass ? _act : Activation()), !first_pass);
                 }
             }
         }
@@ -317,10 +321,10 @@ public:
     GemmInterleaved & operator= (GemmInterleaved &) = delete;
 
     /* Constructor */
-    GemmInterleaved(const GemmArgs<Tr> &args)
+    GemmInterleaved(const GemmArgs &args)
             : _ci(args._ci), _Msize(args._Msize), _Nsize(args._Nsize), _Ksize(args._Ksize),
               _nbatches(args._nbatches), _nmulti(args._nmulti), _trA(args._trA), _trB(args._trB),
-              _alpha(args._alpha), _beta(args._beta), _maxthreads(args._maxthreads), _nthreads(args._maxthreads),
+              _act(args._act), _maxthreads(args._maxthreads), _nthreads(args._maxthreads),
               _pretransposed(args._pretransposed_hint) {
         const unsigned int L1_size = _ci->get_L1_cache_size();
         const unsigned int L2_size = _ci->get_L2_cache_size();
-- 
cgit v1.2.1