From 48b3ef89de5f21a0169d8416e3d54081f82c7bf8 Mon Sep 17 00:00:00 2001 From: Georgios Pinitas Date: Mon, 14 Oct 2019 19:03:09 +0100 Subject: COMPMID-2577: Fuse bias addition and activation in gemm assembly kernels Change-Id: I7f52112d2d05b1ea3d3f3d4b19b8eafab05d6c44 Signed-off-by: Georgios Pinitas Reviewed-on: https://review.mlplatform.org/c/2141 Comments-Addressed: Arm Jenkins Tested-by: Arm Jenkins Reviewed-by: Pablo Marquez --- src/core/NEON/kernels/arm_gemm/gemm_native.hpp | 21 +++++++++++++++------ 1 file changed, 15 insertions(+), 6 deletions(-) (limited to 'src/core/NEON/kernels/arm_gemm/gemm_native.hpp') diff --git a/src/core/NEON/kernels/arm_gemm/gemm_native.hpp b/src/core/NEON/kernels/arm_gemm/gemm_native.hpp index ba9163b29b..fe6ebef045 100644 --- a/src/core/NEON/kernels/arm_gemm/gemm_native.hpp +++ b/src/core/NEON/kernels/arm_gemm/gemm_native.hpp @@ -54,7 +54,7 @@ class GemmNative : public GemmCommon<To, Tr> { const unsigned int _nbatches; const unsigned int _nmultis; - const Tr _beta; + const Activation _act; const CPUInfo * const _ci; @@ -63,11 +63,11 @@ class GemmNative : public GemmCommon<To, Tr> { const NDRange<4> _window_range; - static unsigned int compute_k_block(const GemmArgs<Tr> &args) { + static unsigned int compute_k_block(const GemmArgs &args) { return args._Ksize; } - static unsigned int compute_n_block(const GemmArgs<Tr> &args) { + static unsigned int compute_n_block(const GemmArgs &args) { if ((args._cfg != nullptr) && args._cfg->outer_block_size > 0) { return args._cfg->outer_block_size; } else { @@ -79,10 +79,10 @@ public: GemmNative(GemmNative &) = delete; GemmNative & operator= (GemmNative &) = delete; - GemmNative(const GemmArgs<Tr> &args) + GemmNative(const GemmArgs &args) : _Msize(args._Msize), _Nsize(args._Nsize), _Ksize(args._Ksize), _nbatches(args._nbatches), _nmultis(args._nmulti), - _beta(args._beta), _ci(args._ci), + _act(args._act), _ci(args._ci), _k_block(compute_k_block(args)), _n_block(compute_n_block(args)), _window_range(iceildiv(_Msize, 
strategy::out_height()), _nbatches, iceildiv(_Nsize, _n_block), _nmultis) { } @@ -127,7 +127,16 @@ public: strat.kernel(this->_Aptr + (multi * this->_A_multi_stride) + (batch * this->_A_batch_stride) + (y0 * this->_lda), this->_lda, this->_Bptr + (multi * this->_B_multi_stride) + n0, this->_ldb, this->_Cptr + (multi * this->_C_multi_stride) + (batch * this->_C_batch_stride) + (y0 * this->_ldc) + n0, this->_ldc, - _beta, (ymax-y0), (nmax - n0), _Ksize); + (ymax-y0), (nmax-n0), _Ksize, + (strategy::supports_bias() && this->_bias) ? this->_bias + (multi * this->_bias_multi_stride) + n0 : nullptr, + _act, false); + + // Add bias externally if needed + if (!strategy::supports_bias() && this->_bias) { + bias_adder(this->_Cptr + (multi * this->_C_multi_stride) + (batch * this->_C_batch_stride) + (y0 * this->_ldc) + n0, this->_ldc, + this->_bias + (multi * this->_bias_multi_stride) + n0, + (ymax - y0), (nmax - n0)); + } } while (p.next_dim1()); } }; -- cgit v1.2.1