diff options
Diffstat (limited to 'src/core/NEON/kernels/arm_gemm/gemm_hybrid.hpp')
-rw-r--r-- | src/core/NEON/kernels/arm_gemm/gemm_hybrid.hpp | 14 |
1 files changed, 7 insertions, 7 deletions
diff --git a/src/core/NEON/kernels/arm_gemm/gemm_hybrid.hpp b/src/core/NEON/kernels/arm_gemm/gemm_hybrid.hpp index 82e0625b68..a744376393 100644 --- a/src/core/NEON/kernels/arm_gemm/gemm_hybrid.hpp +++ b/src/core/NEON/kernels/arm_gemm/gemm_hybrid.hpp @@ -108,7 +108,7 @@ class GemmHybrid : public GemmCommon<To, Tr> { // n_block: Work out how many rows (of length k_block) will fit in the L2 // Don't allocate more than 90% of the L2 to allow for overheads, and subtract off the L1 contents. unsigned int n_block = (((L2_size * 9) / 10) - (k_block * sizeof(Toi) * (strategy::out_width() + strategy::out_height()))) / - (sizeof(Toi) * k_block); + (sizeof(Toi) * k_block); // Needs to be (at least a single) multiple of the kernel output width. n_block /= strategy::out_width(); @@ -128,11 +128,11 @@ public: /* Constructor */ GemmHybrid(const GemmArgs<Tr> &args) - : _ci(args._ci), _Msize(args._Msize), _Nsize(args._Nsize), _Ksize(args._Ksize), - _nbatches(args._nbatches), _nmulti(args._nmulti), _trB(args._trB), _beta(args._beta), - _k_block(compute_k_block(args)), _n_block(compute_n_block(args)), - _Mround(roundup(args._Msize, strategy::out_height())), - _window_range(iceildiv(args._Msize, strategy::out_height()), _nbatches, iceildiv(_Nsize, _n_block), _nmulti) { } + : _ci(args._ci), _Msize(args._Msize), _Nsize(args._Nsize), _Ksize(args._Ksize), + _nbatches(args._nbatches), _nmulti(args._nmulti), _trB(args._trB), _beta(args._beta), + _k_block(compute_k_block(args)), _n_block(compute_n_block(args)), + _Mround(roundup(args._Msize, strategy::out_height())), + _window_range(iceildiv(args._Msize, strategy::out_height()), _nbatches, iceildiv(_Nsize, _n_block), _nmulti) { } // Interface implementation - Compulsory functions unsigned int get_window_size() const override { @@ -190,7 +190,7 @@ public: b_panel, this->_Cptr + (multi * this->_C_multi_stride) + (batch * this->_C_batch_stride) + (m_start * this->_ldc) + n0, this->_ldc, (k0 == 0) ? _beta : static_cast<Tr>(1), - (m_end - m_start), (nmax - n0), kern_k); + (m_end - m_start), (nmax - n0), kmax-k0); } while (p.next_dim1()); } } |