author    Georgios Pinitas <georgios.pinitas@arm.com>  2019-10-14 19:03:09 +0100
committer Georgios Pinitas <georgios.pinitas@arm.com>  2019-10-23 12:08:12 +0000
commit    48b3ef89de5f21a0169d8416e3d54081f82c7bf8 (patch)
tree      f857d733ccf446c704823dc7ac796a96eb55095e /src/core/NEON/kernels/arm_gemm/gemm_hybrid.hpp
parent    1dce3101ef8d77c8cf0af7dfd4af6595a0136b91 (diff)
download  ComputeLibrary-48b3ef89de5f21a0169d8416e3d54081f82c7bf8.tar.gz
COMPMID-2577: Fuse bias addition and activation in gemm assembly kernels
Change-Id: I7f52112d2d05b1ea3d3f3d4b19b8eafab05d6c44
Signed-off-by: Georgios Pinitas <georgios.pinitas@arm.com>
Reviewed-on: https://review.mlplatform.org/c/2141
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Pablo Marquez <pablo.tello@arm.com>
Diffstat (limited to 'src/core/NEON/kernels/arm_gemm/gemm_hybrid.hpp')
-rw-r--r--  src/core/NEON/kernels/arm_gemm/gemm_hybrid.hpp | 35
1 file changed, 27 insertions(+), 8 deletions(-)
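For orientation before the diff itself: "fusing bias addition and activation" means applying both while the output tile is still live in the GEMM epilogue, instead of running separate passes over C afterwards. The plain-C++ loop below is purely illustrative (the real kernels do this inside assembly, and all names here are invented); it assumes a ReLU-style activation for concreteness.

#include <algorithm>
#include <cstddef>

// Illustrative fused epilogue: bias and activation are applied to the
// output tile in the same pass that produced it.
void fused_epilogue(float *C, std::size_t ldc,
                    const float *bias,   // one value per column, may be nullptr
                    std::size_t rows, std::size_t cols,
                    bool relu) {
    for (std::size_t r = 0; r < rows; r++) {
        for (std::size_t c = 0; c < cols; c++) {
            float v = C[r * ldc + c];
            if (bias) v += bias[c];           // fused bias addition
            if (relu) v = std::max(v, 0.0f);  // fused activation
            C[r * ldc + c] = v;
        }
    }
}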
diff --git a/src/core/NEON/kernels/arm_gemm/gemm_hybrid.hpp b/src/core/NEON/kernels/arm_gemm/gemm_hybrid.hpp
index 436f55dee2..c3abb04db7 100644
--- a/src/core/NEON/kernels/arm_gemm/gemm_hybrid.hpp
+++ b/src/core/NEON/kernels/arm_gemm/gemm_hybrid.hpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2019 ARM Limited.
+ * Copyright (c) 2017-2019 ARM Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -28,6 +28,7 @@
 
 #include <algorithm>
 
 #include "arm_gemm.hpp"
+#include "bias_adder.hpp"
 #include "ndrange.hpp"
 #include "utils.hpp"
@@ -58,7 +59,7 @@ class GemmHybrid : public GemmCommon<To, Tr> {
     const bool _trB;
 
-    const Tr _beta;
+    const Activation _act;
 
     /* Blocking info */
     const unsigned int _k_block;
@@ -70,7 +71,12 @@ class GemmHybrid : public GemmCommon<To, Tr> {
     const NDRange<4> _window_range;
 
-    static unsigned int compute_k_block(const GemmArgs<Tr> &args) {
+    static unsigned int compute_k_block(const GemmArgs &args) {
+        // Some kernels don't support append mode - these can't do K blocking at all.
+        if (!strategy::supports_append()) {
+            return args._Ksize;
+        }
+
         if (args._cfg && args._cfg->inner_block_size) {
             return args._cfg->inner_block_size;
         }
 
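The early-out above encodes an invariant worth spelling out: K blocking only works if the kernel can accumulate ("append") into partial results already in C, since every K block after the first computes C += A_k * B_k. A strategy without append support must therefore consume all of K in one pass. A hedged sketch of that rule, with invented names:

#include <algorithm>

// If the kernel can only overwrite C (no append), the only legal K block
// is the whole of K; otherwise any block size up to K works.
unsigned int legal_k_block(bool supports_append, unsigned int Ksize,
                           unsigned int preferred) {
    if (!supports_append) {
        return Ksize;   // single pass over all of K, no accumulation needed
    }
    return std::min(preferred, Ksize);
}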
@@ -97,7 +103,7 @@ class GemmHybrid : public GemmCommon<To, Tr> {
         return k_block;
     }
 
-    static unsigned int compute_n_block(const GemmArgs<Tr> &args) {
+    static unsigned int compute_n_block(const GemmArgs &args) {
         if (args._cfg && args._cfg->outer_block_size) {
             return args._cfg->outer_block_size;
         }
@@ -127,9 +133,10 @@ public:
     GemmHybrid & operator= (GemmHybrid &) = delete;
 
     /* Constructor */
-    GemmHybrid(const GemmArgs<Tr> &args)
+    GemmHybrid(const GemmArgs &args)
       : _ci(args._ci), _Msize(args._Msize), _Nsize(args._Nsize), _Ksize(args._Ksize),
-        _nbatches(args._nbatches), _nmulti(args._nmulti), _trB(args._trB), _beta(args._beta),
+        _nbatches(args._nbatches), _nmulti(args._nmulti), _trB(args._trB),
+        _act(args._act),
         _k_block(compute_k_block(args)), _n_block(compute_n_block(args)),
         _Mround(roundup(args._Msize, strategy::out_height())),
         _window_range(iceildiv(args._Msize, strategy::out_height()), _nbatches, iceildiv(_Nsize, _n_block), _nmulti) { }
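Note that the constructor no longer stores a scalar _beta: with the fused interface the first K pass writes C directly and later passes accumulate via the new accumulate flag (see the final hunk), so the beta argument disappears. What gets stored instead is an Activation descriptor. Its real definition lives in arm_gemm.hpp; the struct below is only a guess at the general shape of such a type, for orientation:

// Hypothetical shape of an activation descriptor (illustrative only,
// not the actual arm_gemm Activation definition).
struct ActivationSketch {
    enum class Type { None, ReLU, BoundedReLU };
    Type  type   = Type::None;
    float param1 = 0.0f;   // e.g. the upper bound for a bounded ReLU
};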
@@ -164,6 +171,9 @@ public:
             unsigned int kmax = std::min(k0 + _k_block, _Ksize);
             unsigned int kern_k = roundup(kmax-k0, strategy::k_unroll());
 
+            const bool first_pass = (k0 == 0);
+            const bool last_pass = (kmax == _Ksize);
+
             auto p = _window_range.iterator(start, end);
 
             if (p.done()) {
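first_pass and last_pass drive the per-K-block epilogue choices in the kernel call of the next hunk: bias is added exactly once (on the first block), the activation fires only once the accumulation is complete (on the last block), and the accumulate flag is simply the negation of first_pass. A self-contained sketch of that schedule, all names invented:

#include <algorithm>

// Which epilogue options apply on a given pass of a K-blocked GEMM.
struct PassFlags { bool accumulate, add_bias, apply_activation; };

PassFlags schedule_pass(unsigned int k0, unsigned int k_block, unsigned int Ksize) {
    const unsigned int kmax = std::min(k0 + k_block, Ksize);
    const bool first_pass = (k0 == 0);
    const bool last_pass  = (kmax == Ksize);
    return { !first_pass,   // accumulate into C on every pass but the first
             first_pass,    // add bias exactly once, up front
             last_pass };   // activate only when all of K has been summed
}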
@@ -190,8 +200,17 @@ public:
                 strat.kernel(this->_Aptr + (multi * this->_A_multi_stride) + (batch * this->_A_batch_stride) + (m_start * this->_lda) + k0, this->_lda,
                              b_panel,
                              this->_Cptr + (multi * this->_C_multi_stride) + (batch * this->_C_batch_stride) + (m_start * this->_ldc) + n0, this->_ldc,
-                             (k0 == 0) ? _beta : static_cast<Tr>(1),
-                             (m_end - m_start), (nmax - n0), kmax-k0);
+                             (m_end - m_start), (nmax - n0), kmax-k0,
+                             (strategy::supports_bias() && first_pass && this->_bias) ? this->_bias + (multi * this->_bias_multi_stride) + n0 : nullptr,
+                             last_pass ? _act : Activation(), !first_pass);
+
+                // Add bias externally if needed
+                if (!strategy::supports_bias() && this->_bias && first_pass) {
+                    bias_adder(this->_Cptr + (multi * this->_C_multi_stride) + (batch * this->_C_batch_stride) + (m_start * this->_ldc) + n0, this->_ldc,
+                               this->_bias + (multi * this->_bias_multi_stride) + n0,
+                               (m_end - m_start), (nmax - n0));
+                }
+
             } while (p.next_dim1());
         }
     }
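Finally, for strategies whose kernels cannot take a bias pointer (supports_bias() returns false), the bias is patched in by a separate pass over the tile that was just written, via the newly included bias_adder.hpp. The helper below is a plausible reconstruction of what such a row-broadcast add does, not the library's actual implementation:

// Plausible row-broadcast bias add over an output tile (illustrative).
template <typename T>
void bias_adder_sketch(T *out, unsigned int ldout, const T *bias,
                       unsigned int rows, unsigned int cols) {
    for (unsigned int r = 0; r < rows; r++) {
        for (unsigned int c = 0; c < cols; c++) {
            out[r * ldout + c] += bias[c];   // same bias vector for every row
        }
    }
}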