diff options
author | David Mansell <David.Mansell@arm.com> | 2018-07-06 17:53:35 +0100 |
---|---|---|
committer | Anthony Barbier <anthony.barbier@arm.com> | 2018-11-02 16:54:10 +0000 |
commit | e39334c15c7fd141bb8173d5017ea5ca157fca2c (patch) | |
tree | fffa2f7b136525037c4d99586bc194374e5bd3dc /src/core/NEON/kernels/arm_gemm/gemm_native.hpp | |
parent | e8bd2c729546e59aa0adc241976ea91fc6f25b52 (diff) | |
download | ComputeLibrary-e39334c15c7fd141bb8173d5017ea5ca157fca2c.tar.gz |
COMPMID-1271: New system for GEMM heuristics
This patch implements a system for separating the "validity" and
"preferred" aspects of the current heuristics in gemm_*.cpp.
Now, each gemm_*.cpp defines a list of candidate implementations,
each of which supplies an is_valid() function (to check for
validity), an is_preferred() function (the "heuristic" part), and an
instantiate() function which actually produces the GemmCommon object
pointer.
The actual gemm() function is now templated and uses this list to
select an implementation. This patch also implements a mechanism to
identify the preferred implementation, and override it via the
GemmConfig structure.
Change-Id: Id49ab7af8bf2e3e9fd951a9698883ade234d40e1
Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/139120
Reviewed-by: Anthony Barbier <anthony.barbier@arm.com>
Tested-by: Jenkins <bsgcomp@arm.com>
Diffstat (limited to 'src/core/NEON/kernels/arm_gemm/gemm_native.hpp')
-rw-r--r-- | src/core/NEON/kernels/arm_gemm/gemm_native.hpp | 62 |
1 file changed, 35 insertions, 27 deletions
diff --git a/src/core/NEON/kernels/arm_gemm/gemm_native.hpp b/src/core/NEON/kernels/arm_gemm/gemm_native.hpp index 6fed645d82..baa1316745 100644 --- a/src/core/NEON/kernels/arm_gemm/gemm_native.hpp +++ b/src/core/NEON/kernels/arm_gemm/gemm_native.hpp @@ -62,6 +62,14 @@ class GemmNative : public GemmCommon<To, Tr> { unsigned int k_block=0; unsigned int n_block=0; + unsigned int window_per_batch() const { + return iceildiv(_Msize, strategy::out_height()); + } + + unsigned int window_per_multi() const { + return window_per_batch() * _nbatches; + } + public: GemmNative(GemmNative &) = delete; GemmNative & operator= (GemmNative &) = delete; @@ -73,9 +81,9 @@ public: n_block = N; } - // Window is number of out_height blocks + // Window is amount per multi multiplied by total number of multis. unsigned int get_window_size() const override { - return iceildiv(_Msize, strategy::out_height) * _nbatches * _nmultis; + return window_per_multi() * _nmultis; } // Actually execute the GEMM. @@ -85,39 +93,39 @@ public: #endif strategy strat(_ci); - const unsigned int window_per_batch = iceildiv(_Msize, strategy::out_height); - const unsigned int window_per_multi = window_per_batch * _nbatches; - - const unsigned int first_multi = start / window_per_multi; - const unsigned int last_multi = end / window_per_multi; - - const unsigned int first_batch = (start - (first_multi * window_per_multi)) / window_per_batch; - const unsigned int last_batch = (end - (last_multi * window_per_multi)) / window_per_batch; - - const unsigned int first_row = ((start - (first_multi * window_per_multi)) % window_per_batch) * strategy::out_height; - const unsigned int last_row = ((end - (last_multi * window_per_multi)) % window_per_batch) * strategy::out_height; - static_assert(std::is_same<To, Toi>::value, "gemm_native: Operand types must be the same."); static_assert(std::is_same<Tr, Tri>::value, "gemm_native: Result types must be the same."); - for (unsigned int multi=first_multi; multi<=last_multi; 
multi++) { - const unsigned int batch_0 = (multi == first_multi) ? first_batch : 0; - const unsigned int batch_max = (multi == last_multi) ? last_batch : (_nbatches-1); + /* Compute starting point based on 'start' */ + unsigned int multi = start / window_per_multi(); + unsigned int multi_pos = start % window_per_multi(); + + unsigned int batch = multi_pos / window_per_batch(); + unsigned int batch_pos = multi_pos % window_per_batch(); - for (unsigned int batch=batch_0; batch <= batch_max; batch++) { - const unsigned int m_start = ((multi == first_multi) && (batch==first_batch)) ? first_row : 0; - const unsigned int m_end = ((multi == last_multi) && (batch==last_batch)) ? last_row : _Msize; + unsigned int y0 = batch_pos * strategy::out_height(); - for (unsigned int y0=m_start; y0<m_end; y0+=strategy::out_height) { - const unsigned int ymax = std::min(y0 + strategy::out_height, m_end); + for (unsigned int pos=start; pos<end; pos++) { + const unsigned int ymax = std::min(y0 + strategy::out_height(), _Msize); #ifdef CYCLE_PROFILING - auto p = prof.ScopedProfiler(PROFILE_KERNEL, (ymax-y0) * _Nsize * _Ksize); + auto p = prof.ScopedProfiler(PROFILE_KERNEL, (ymax-y0) * _Nsize * _Ksize); #endif - strat.kernel(this->_Aptr + (multi * this->_A_multi_stride) + (batch * this->_A_batch_stride) + (y0 * this->_lda), this->_lda, - this->_Bptr + (multi * this->_B_multi_stride), this->_ldb, - this->_Cptr + (multi * this->_C_multi_stride) + (batch * this->_C_batch_stride) + (y0 * this->_ldc), this->_ldc, - _beta, (ymax-y0), _Nsize, _Ksize); + strat.kernel(this->_Aptr + (multi * this->_A_multi_stride) + (batch * this->_A_batch_stride) + (y0 * this->_lda), this->_lda, + this->_Bptr + (multi * this->_B_multi_stride), this->_ldb, + this->_Cptr + (multi * this->_C_multi_stride) + (batch * this->_C_batch_stride) + (y0 * this->_ldc), this->_ldc, + _beta, (ymax-y0), _Nsize, _Ksize); + + /* Advance to next item */ + y0 += strategy::out_height(); + + /* Check for batch/multi overflow */ + if (y0 
>= _Msize) { + y0=0; + batch++; + if (batch == _nbatches) { + batch=0; + multi++; } } } |