aboutsummaryrefslogtreecommitdiff
path: root/src/core/NEON/kernels/arm_gemm/gemm_interleaved.hpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/core/NEON/kernels/arm_gemm/gemm_interleaved.hpp')
-rw-r--r--src/core/NEON/kernels/arm_gemm/gemm_interleaved.hpp14
1 files changed, 9 insertions, 5 deletions
diff --git a/src/core/NEON/kernels/arm_gemm/gemm_interleaved.hpp b/src/core/NEON/kernels/arm_gemm/gemm_interleaved.hpp
index faff9acd2e..efd984561d 100644
--- a/src/core/NEON/kernels/arm_gemm/gemm_interleaved.hpp
+++ b/src/core/NEON/kernels/arm_gemm/gemm_interleaved.hpp
@@ -68,8 +68,7 @@ class GemmInterleaved : public GemmCommon<To, Tr> {
const bool _trA;
const bool _trB;
- const Tr _alpha;
- const Tr _beta;
+ const Activation _act;
const int _maxthreads;
int _nthreads;
@@ -297,9 +296,14 @@ class GemmInterleaved : public GemmCommon<To, Tr> {
#ifdef CYCLE_PROFILING
auto p=prof.ScopedProfiler(PROFILE_MERGE, (strategy::out_height() * bblocks * strategy::out_width() * sizeof(Tr)));
#endif
+ /* Only activate on last pass, only add bias on first pass, ask for accumulation on any non-first pass */
+ const bool first_pass = current.k0()==0;
+ const bool last_pass = current.kmax()==_Ksize;
+
strat.transforms.Merge(this->_Cptr + (batch * this->_C_batch_stride) + (current.multi() * this->_C_multi_stride),
c_panel, this->_ldc, y, ymax, current.x0(), current.xmax(),
- _alpha, (current.k0()==0 ? _beta : static_cast<Tr>(1)));
+ ((first_pass && this->_bias) ? this->_bias + (current.multi() * this->_bias_multi_stride) : nullptr),
+ (last_pass ? _act : Activation()), !first_pass);
}
}
}
@@ -317,10 +321,10 @@ public:
GemmInterleaved & operator= (GemmInterleaved &) = delete;
/* Constructor */
- GemmInterleaved(const GemmArgs<Tr> &args)
+ GemmInterleaved(const GemmArgs &args)
: _ci(args._ci), _Msize(args._Msize), _Nsize(args._Nsize), _Ksize(args._Ksize),
_nbatches(args._nbatches), _nmulti(args._nmulti), _trA(args._trA), _trB(args._trB),
- _alpha(args._alpha), _beta(args._beta), _maxthreads(args._maxthreads), _nthreads(args._maxthreads),
+ _act(args._act), _maxthreads(args._maxthreads), _nthreads(args._maxthreads),
_pretransposed(args._pretransposed_hint) {
const unsigned int L1_size = _ci->get_L1_cache_size();
const unsigned int L2_size = _ci->get_L2_cache_size();