aboutsummaryrefslogtreecommitdiff
path: root/src/core/NEON/kernels/arm_gemm/gemm_interleaved.hpp
diff options
context:
space:
mode:
authorGeorgios Pinitas <georgios.pinitas@arm.com>2019-10-14 19:03:09 +0100
committerGeorgios Pinitas <georgios.pinitas@arm.com>2019-10-23 12:08:12 +0000
commit48b3ef89de5f21a0169d8416e3d54081f82c7bf8 (patch)
treef857d733ccf446c704823dc7ac796a96eb55095e /src/core/NEON/kernels/arm_gemm/gemm_interleaved.hpp
parent1dce3101ef8d77c8cf0af7dfd4af6595a0136b91 (diff)
downloadComputeLibrary-48b3ef89de5f21a0169d8416e3d54081f82c7bf8.tar.gz
COMPMID-2577: Fuse bias addition and activation in gemm assembly kernels
Change-Id: I7f52112d2d05b1ea3d3f3d4b19b8eafab05d6c44 Signed-off-by: Georgios Pinitas <georgios.pinitas@arm.com> Reviewed-on: https://review.mlplatform.org/c/2141 Comments-Addressed: Arm Jenkins <bsgcomp@arm.com> Tested-by: Arm Jenkins <bsgcomp@arm.com> Reviewed-by: Pablo Marquez <pablo.tello@arm.com>
Diffstat (limited to 'src/core/NEON/kernels/arm_gemm/gemm_interleaved.hpp')
-rw-r--r--src/core/NEON/kernels/arm_gemm/gemm_interleaved.hpp14
1 files changed, 9 insertions, 5 deletions
diff --git a/src/core/NEON/kernels/arm_gemm/gemm_interleaved.hpp b/src/core/NEON/kernels/arm_gemm/gemm_interleaved.hpp
index faff9acd2e..efd984561d 100644
--- a/src/core/NEON/kernels/arm_gemm/gemm_interleaved.hpp
+++ b/src/core/NEON/kernels/arm_gemm/gemm_interleaved.hpp
@@ -68,8 +68,7 @@ class GemmInterleaved : public GemmCommon<To, Tr> {
const bool _trA;
const bool _trB;
- const Tr _alpha;
- const Tr _beta;
+ const Activation _act;
const int _maxthreads;
int _nthreads;
@@ -297,9 +296,14 @@ class GemmInterleaved : public GemmCommon<To, Tr> {
#ifdef CYCLE_PROFILING
auto p=prof.ScopedProfiler(PROFILE_MERGE, (strategy::out_height() * bblocks * strategy::out_width() * sizeof(Tr)));
#endif
+ /* Only activate on last pass, only add bias on first pass, ask for accumulation on any non-first pass */
+ const bool first_pass = current.k0()==0;
+ const bool last_pass = current.kmax()==_Ksize;
+
strat.transforms.Merge(this->_Cptr + (batch * this->_C_batch_stride) + (current.multi() * this->_C_multi_stride),
c_panel, this->_ldc, y, ymax, current.x0(), current.xmax(),
- _alpha, (current.k0()==0 ? _beta : static_cast<Tr>(1)));
+ ((first_pass && this->_bias) ? this->_bias + (current.multi() * this->_bias_multi_stride) : nullptr),
+ (last_pass ? _act : Activation()), !first_pass);
}
}
}
@@ -317,10 +321,10 @@ public:
GemmInterleaved & operator= (GemmInterleaved &) = delete;
/* Constructor */
- GemmInterleaved(const GemmArgs<Tr> &args)
+ GemmInterleaved(const GemmArgs &args)
: _ci(args._ci), _Msize(args._Msize), _Nsize(args._Nsize), _Ksize(args._Ksize),
_nbatches(args._nbatches), _nmulti(args._nmulti), _trA(args._trA), _trB(args._trB),
- _alpha(args._alpha), _beta(args._beta), _maxthreads(args._maxthreads), _nthreads(args._maxthreads),
+ _act(args._act), _maxthreads(args._maxthreads), _nthreads(args._maxthreads),
_pretransposed(args._pretransposed_hint) {
const unsigned int L1_size = _ci->get_L1_cache_size();
const unsigned int L2_size = _ci->get_L2_cache_size();