From 48b3ef89de5f21a0169d8416e3d54081f82c7bf8 Mon Sep 17 00:00:00 2001
From: Georgios Pinitas
Date: Mon, 14 Oct 2019 19:03:09 +0100
Subject: COMPMID-2577: Fuse bias addition and activation in gemm assembly kernels

Change-Id: I7f52112d2d05b1ea3d3f3d4b19b8eafab05d6c44
Signed-off-by: Georgios Pinitas
Reviewed-on: https://review.mlplatform.org/c/2141
Comments-Addressed: Arm Jenkins
Tested-by: Arm Jenkins
Reviewed-by: Pablo Marquez
---
 src/core/NEON/kernels/arm_gemm/gemm_hybrid.hpp | 35 ++++++++++++++++++++------
 1 file changed, 27 insertions(+), 8 deletions(-)

(limited to 'src/core/NEON/kernels/arm_gemm/gemm_hybrid.hpp')

diff --git a/src/core/NEON/kernels/arm_gemm/gemm_hybrid.hpp b/src/core/NEON/kernels/arm_gemm/gemm_hybrid.hpp
index 436f55dee2..c3abb04db7 100644
--- a/src/core/NEON/kernels/arm_gemm/gemm_hybrid.hpp
+++ b/src/core/NEON/kernels/arm_gemm/gemm_hybrid.hpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2019 ARM Limited.
+ * Copyright (c) 2017-2019 ARM Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -28,6 +28,7 @@
 #include <assert.h>
 
 #include "arm_gemm.hpp"
+#include "bias_adder.hpp"
 #include "ndrange.hpp"
 #include "utils.hpp"
 
@@ -58,7 +59,7 @@ class GemmHybrid : public GemmCommon<To, Tr> {
 
     const bool _trB;
 
-    const Tr _beta;
+    const Activation _act;
 
     /* Blocking info */
     const unsigned int _k_block;
@@ -70,7 +71,12 @@ class GemmHybrid : public GemmCommon<To, Tr> {
 
     const NDRange<4> _window_range;
 
-    static unsigned int compute_k_block(const GemmArgs<Tr> &args) {
+    static unsigned int compute_k_block(const GemmArgs &args) {
+        // Some kernels don't support append mode - these can't do K blocking at all.
+        if (!strategy::supports_append()) {
+            return args._Ksize;
+        }
+
         if (args._cfg && args._cfg->inner_block_size) {
             return args._cfg->inner_block_size;
         }
@@ -97,7 +103,7 @@ class GemmHybrid : public GemmCommon<To, Tr> {
         return k_block;
     }
 
-    static unsigned int compute_n_block(const GemmArgs<Tr> &args) {
+    static unsigned int compute_n_block(const GemmArgs &args) {
         if (args._cfg && args._cfg->outer_block_size) {
             return args._cfg->outer_block_size;
         }
@@ -127,9 +133,10 @@ public:
    GemmHybrid & operator= (GemmHybrid &) = delete;
 
    /* Constructor */
-    GemmHybrid(const GemmArgs<Tr> &args)
+    GemmHybrid(const GemmArgs &args)
               : _ci(args._ci), _Msize(args._Msize), _Nsize(args._Nsize), _Ksize(args._Ksize),
-                _nbatches(args._nbatches), _nmulti(args._nmulti), _trB(args._trB), _beta(args._beta),
+                _nbatches(args._nbatches), _nmulti(args._nmulti), _trB(args._trB),
+                _act(args._act),
                 _k_block(compute_k_block(args)), _n_block(compute_n_block(args)),
                 _Mround(roundup(args._Msize, strategy::out_height())),
                 _window_range(iceildiv(args._Msize, strategy::out_height()), _nbatches, iceildiv(_Nsize, _n_block), _nmulti) { }
@@ -164,6 +171,9 @@
             unsigned int kmax = std::min(k0 + _k_block, _Ksize);
             unsigned int kern_k = roundup(kmax-k0, strategy::k_unroll());
 
+            const bool first_pass = (k0 == 0);
+            const bool last_pass = (kmax == _Ksize);
+
             auto p = _window_range.iterator(start, end);
 
             if (p.done()) {
@@ -190,8 +200,17 @@ #endif
                 strat.kernel(this->_Aptr + (multi * this->_A_multi_stride) + (batch * this->_A_batch_stride) + (m_start * this->_lda) + k0, this->_lda,
                              b_panel,
                              this->_Cptr + (multi * this->_C_multi_stride) + (batch * this->_C_batch_stride) + (m_start * this->_ldc) + n0, this->_ldc,
-                             (k0 == 0) ? _beta : static_cast<Tr>(1),
-                             (m_end - m_start), (nmax - n0), kmax-k0);
+                             (m_end - m_start), (nmax - n0), kmax-k0,
+                             (strategy::supports_bias() && first_pass && this->_bias) ? this->_bias + (multi * this->_bias_multi_stride) + n0 : nullptr,
+                             last_pass ? _act : Activation(), !first_pass);
+
+                // Add bias externally if needed
+                if (!strategy::supports_bias() && this->_bias && first_pass) {
+                    bias_adder(this->_Cptr + (multi * this->_C_multi_stride) + (batch * this->_C_batch_stride) + (m_start * this->_ldc) + n0, this->_ldc,
+                               this->_bias + (multi * this->_bias_multi_stride) + n0,
+                               (m_end - m_start), (nmax - n0));
+                }
+
             } while (p.next_dim1());
         }
     }
--
cgit v1.2.1
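
For reference, the kernel-call change above replaces the old beta scaling argument with a bias pointer and an Activation: bias is supplied only on the first K block, later blocks accumulate (append) into C, and the activation is applied only on the last block, with bias_adder() as a fallback for strategies without native bias support. Below is a minimal, self-contained sketch of that first_pass/last_pass scheme; the Activation struct and the kernel_block/gemm_fused helpers are simplified stand-ins assumed for illustration, not the arm_gemm API.

#include <algorithm>
#include <cstddef>

// Simplified stand-in for arm_gemm's activation descriptor.
struct Activation {
    enum class Type { None, ReLU };
    Type type = Type::None;
};

// One K block of a naive row-major GEMM: C (+)= A[:, k0:kmax) * B[k0:kmax, :),
// with bias added on entry (first pass only) and the activation applied on
// exit (last pass only), mirroring the fused epilogue in the patch.
static void kernel_block(const float *A, std::size_t lda,
                         const float *B, std::size_t ldb,
                         float *C, std::size_t ldc,
                         std::size_t M, std::size_t N, std::size_t K,
                         const float *bias, Activation act, bool accumulate) {
    for (std::size_t m = 0; m < M; m++) {
        for (std::size_t n = 0; n < N; n++) {
            // "Append" mode: keep what earlier K blocks already accumulated.
            float acc = accumulate ? C[m * ldc + n] : 0.0f;
            if (bias != nullptr) {
                acc += bias[n];
            }
            for (std::size_t k = 0; k < K; k++) {
                acc += A[m * lda + k] * B[k * ldb + n];
            }
            if (act.type == Activation::Type::ReLU) {
                acc = std::max(acc, 0.0f);
            }
            C[m * ldc + n] = acc;
        }
    }
}

// Drive the K blocking the way the patched GemmHybrid::execute() does:
// bias once up front, activation once at the end, accumulation in between.
void gemm_fused(const float *A, const float *B, const float *bias, float *C,
                std::size_t M, std::size_t N, std::size_t K,
                std::size_t k_block, Activation act) {
    for (std::size_t k0 = 0; k0 < K; k0 += k_block) {
        const std::size_t kmax = std::min(K, k0 + k_block);
        const bool first_pass = (k0 == 0);
        const bool last_pass  = (kmax == K);

        kernel_block(A + k0, K, B + k0 * N, N, C, N,
                     M, N, kmax - k0,
                     first_pass ? bias : nullptr,     // bias only on the first pass
                     last_pass ? act : Activation(),  // activation only on the last pass
                     !first_pass);                    // otherwise append to C
    }
}

The same scheme also explains the compute_k_block() change: a strategy without supports_append() has no way to accumulate partial K results into C, so it must process the whole K range in a single block.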