diff options
Diffstat (limited to 'src/core/NEON/kernels/arm_gemm/gemm_native.hpp')
-rw-r--r-- | src/core/NEON/kernels/arm_gemm/gemm_native.hpp | 62 |
1 files changed, 35 insertions, 27 deletions
diff --git a/src/core/NEON/kernels/arm_gemm/gemm_native.hpp b/src/core/NEON/kernels/arm_gemm/gemm_native.hpp index 6fed645d82..baa1316745 100644 --- a/src/core/NEON/kernels/arm_gemm/gemm_native.hpp +++ b/src/core/NEON/kernels/arm_gemm/gemm_native.hpp @@ -62,6 +62,14 @@ class GemmNative : public GemmCommon<To, Tr> { unsigned int k_block=0; unsigned int n_block=0; + unsigned int window_per_batch() const { + return iceildiv(_Msize, strategy::out_height()); + } + + unsigned int window_per_multi() const { + return window_per_batch() * _nbatches; + } + public: GemmNative(GemmNative &) = delete; GemmNative & operator= (GemmNative &) = delete; @@ -73,9 +81,9 @@ public: n_block = N; } - // Window is number of out_height blocks + // Window is amount per multi multiplied by total number of multis. unsigned int get_window_size() const override { - return iceildiv(_Msize, strategy::out_height) * _nbatches * _nmultis; + return window_per_multi() * _nmultis; } // Actually execute the GEMM. @@ -85,39 +93,39 @@ public: #endif strategy strat(_ci); - const unsigned int window_per_batch = iceildiv(_Msize, strategy::out_height); - const unsigned int window_per_multi = window_per_batch * _nbatches; - - const unsigned int first_multi = start / window_per_multi; - const unsigned int last_multi = end / window_per_multi; - - const unsigned int first_batch = (start - (first_multi * window_per_multi)) / window_per_batch; - const unsigned int last_batch = (end - (last_multi * window_per_multi)) / window_per_batch; - - const unsigned int first_row = ((start - (first_multi * window_per_multi)) % window_per_batch) * strategy::out_height; - const unsigned int last_row = ((end - (last_multi * window_per_multi)) % window_per_batch) * strategy::out_height; - static_assert(std::is_same<To, Toi>::value, "gemm_native: Operand types must be the same."); static_assert(std::is_same<Tr, Tri>::value, "gemm_native: Result types must be the same."); - for (unsigned int multi=first_multi; multi<=last_multi; multi++) { - const unsigned int batch_0 = (multi == first_multi) ? first_batch : 0; - const unsigned int batch_max = (multi == last_multi) ? last_batch : (_nbatches-1); + /* Compute starting point based on 'start' */ + unsigned int multi = start / window_per_multi(); + unsigned int multi_pos = start % window_per_multi(); + + unsigned int batch = multi_pos / window_per_batch(); + unsigned int batch_pos = multi_pos % window_per_batch(); - for (unsigned int batch=batch_0; batch <= batch_max; batch++) { - const unsigned int m_start = ((multi == first_multi) && (batch==first_batch)) ? first_row : 0; - const unsigned int m_end = ((multi == last_multi) && (batch==last_batch)) ? last_row : _Msize; + unsigned int y0 = batch_pos * strategy::out_height(); - for (unsigned int y0=m_start; y0<m_end; y0+=strategy::out_height) { - const unsigned int ymax = std::min(y0 + strategy::out_height, m_end); + for (unsigned int pos=start; pos<end; pos++) { + const unsigned int ymax = std::min(y0 + strategy::out_height(), _Msize); #ifdef CYCLE_PROFILING - auto p = prof.ScopedProfiler(PROFILE_KERNEL, (ymax-y0) * _Nsize * _Ksize); + auto p = prof.ScopedProfiler(PROFILE_KERNEL, (ymax-y0) * _Nsize * _Ksize); #endif - strat.kernel(this->_Aptr + (multi * this->_A_multi_stride) + (batch * this->_A_batch_stride) + (y0 * this->_lda), this->_lda, - this->_Bptr + (multi * this->_B_multi_stride), this->_ldb, - this->_Cptr + (multi * this->_C_multi_stride) + (batch * this->_C_batch_stride) + (y0 * this->_ldc), this->_ldc, - _beta, (ymax-y0), _Nsize, _Ksize); + strat.kernel(this->_Aptr + (multi * this->_A_multi_stride) + (batch * this->_A_batch_stride) + (y0 * this->_lda), this->_lda, + this->_Bptr + (multi * this->_B_multi_stride), this->_ldb, + this->_Cptr + (multi * this->_C_multi_stride) + (batch * this->_C_batch_stride) + (y0 * this->_ldc), this->_ldc, + _beta, (ymax-y0), _Nsize, _Ksize); + + /* Advance to next item */ + y0 += strategy::out_height(); + + /* Check for batch/multi overflow */ + if (y0 >= _Msize) { + y0=0; + batch++; + if (batch == _nbatches) { + batch=0; + multi++; } } } |