From d93991e290618a685b67506c78090350e6aee43f Mon Sep 17 00:00:00 2001
From: David Mansell
Date: Fri, 6 Jul 2018 14:52:52 +0100
Subject: COMPMID-1380: Pre-work for SVE support.

This patch makes the needed infrastructure changes to allow SVE
kernels to be added later on.

Change-Id: Ide5bccac2f47278e93fff3d648231aee2d5f8c2e
Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/139070
Reviewed-by: Anthony Barbier
Tested-by: Jenkins
---
 .../NEON/kernels/arm_gemm/gemm_interleaved.hpp | 127 ++++++++-------------
 1 file changed, 50 insertions(+), 77 deletions(-)

(limited to 'src/core/NEON/kernels/arm_gemm/gemm_interleaved.hpp')

diff --git a/src/core/NEON/kernels/arm_gemm/gemm_interleaved.hpp b/src/core/NEON/kernels/arm_gemm/gemm_interleaved.hpp
index c304edd1f9..c5a43e6519 100644
--- a/src/core/NEON/kernels/arm_gemm/gemm_interleaved.hpp
+++ b/src/core/NEON/kernels/arm_gemm/gemm_interleaved.hpp
@@ -158,7 +158,7 @@ class GemmInterleaved : public GemmCommon<To, Tr> {
 
     // C working size: One needed per thread.
     size_t get_c_working_size() const {
-        return ROUND_UP(sizeof(Tri) * _x_block * strategy::out_height);
+        return ROUND_UP(sizeof(Tri) * _x_block * strategy::out_height());
     }
 
     // Internal execute function.
@@ -174,13 +174,13 @@ class GemmInterleaved : public GemmCommon<To, Tr> {
         blockwalker next=current;
 
         /* Translate 'start' and 'end' into a position within the batches and rows. */
-        const unsigned int window_per_batch = _Mround / strategy::out_height;
+        const unsigned int window_per_batch = _Mround / strategy::out_height();
         unsigned int batch_0   = start / window_per_batch;
         unsigned int batch_end = end   / window_per_batch;
 
         /* Compute the M values to operate on */
-        unsigned int m_0   = (start - (batch_0 * window_per_batch)) * strategy::out_height;
-        unsigned int m_max = (end - (batch_end * window_per_batch)) * strategy::out_height;
+        unsigned int m_0   = (start - (batch_0 * window_per_batch)) * strategy::out_height();
+        unsigned int m_max = (end - (batch_end * window_per_batch)) * strategy::out_height();
 
         /* Make sure we've been set up correctly. */
         if (pretransposed) {
@@ -214,7 +214,7 @@ class GemmInterleaved : public GemmCommon<To, Tr> {
         for (;!current.done();current.advance()) {
             if (current.newkblock()) {
 #ifdef CYCLE_PROFILING
-                auto p=prof.ScopedProfiler(PROFILE_PREPA, (end - start) * strategy::out_height * (current.kmax()-current.k0()) * sizeof(Toi));
+                auto p=prof.ScopedProfiler(PROFILE_PREPA, (end - start) * strategy::out_height() * (current.kmax()-current.k0()) * sizeof(Toi));
 #endif
                 for (unsigned int batch = batch_0; batch <= batch_end; batch++) {
                     unsigned int first_m = (batch == batch_0)   ? m_0   : 0;
@@ -223,25 +223,17 @@ class GemmInterleaved : public GemmCommon<To, Tr> {
                     if (first_m >= last_m)
                         continue;
 
-                    if (_trA ^ strategy::A_transpose) {
-                        Transform<strategy::A_interleave, strategy::A_block, true>(
-                            a_panel + ((batch * _Mround + first_m) * _k_block),
-                            this->_Aptr + (batch * this->_A_batch_stride) + (current.multi() * this->_A_multi_stride),
-                            this->_lda, first_m, last_m, current.k0(), current.kmax());
-                    } else {
-                        Transform<strategy::A_interleave, strategy::A_block, false>(
-                            a_panel + ((batch * _Mround + first_m) * _k_block),
-                            this->_Aptr + (batch * this->_A_batch_stride) + (current.multi() * this->_A_multi_stride),
-                            this->_lda, first_m, last_m, current.k0(), current.kmax());
-                    }
+                    strat.transforms.PrepareA(a_panel + ((batch * _Mround + first_m) * _k_block),
+                                              this->_Aptr + (batch * this->_A_batch_stride) + (current.multi() * this->_A_multi_stride),
+                                              this->_lda, first_m, last_m, current.k0(), current.kmax(), _trA);
                 }
 
                 // Figure out how many "K" the kernel will actually process.
-                kern_k = iceildiv(current.kmax() - current.k0(), strategy::k_unroll);
-                kern_k *= strat.k_unroll;
+                kern_k = iceildiv(current.kmax() - current.k0(), strategy::k_unroll());
+                kern_k *= strat.k_unroll();
             }
 
-            int bblocks = iceildiv(current.xmax() - current.x0(), strategy::out_width);
+            int bblocks = iceildiv(current.xmax() - current.x0(), strategy::out_width());
 
             if (!pretransposed) {
                 /* Look ahead to the next block and populate it if necessary.
@@ -259,15 +251,9 @@ class GemmInterleaved : public GemmCommon<To, Tr> {
 #endif
                     Toi *b_panel = reinterpret_cast<Toi *>(buffer);
-                    if (_trB ^ strategy::B_transpose) {
-                        Transform<strategy::B_interleave, strategy::B_block, true>(
-                            b_panel, this->_Bptr + (next.multi() * this->_B_multi_stride), this->_ldb,
-                            next.x0(), next.xmax(), next.k0(), next.kmax());
-                    } else {
-                        Transform<strategy::B_interleave, strategy::B_block, false>(
-                            b_panel, this->_Bptr + (next.multi() * this->_B_multi_stride), this->_ldb,
-                            next.x0(), next.xmax(), next.k0(), next.kmax());
-                    }
+
+                    strat.transforms.PrepareB(b_panel, this->_Bptr + (next.multi() * this->_B_multi_stride), this->_ldb,
+                                              next.x0(), next.xmax(), next.k0(), next.kmax(), _trB);
                 });
             }
 
@@ -278,15 +264,9 @@ class GemmInterleaved : public GemmCommon<To, Tr> {
 #endif
                 Toi *b_panel = reinterpret_cast<Toi *>(bpv);
-                if (_trB ^ strategy::B_transpose) {
-                    Transform<strategy::B_interleave, strategy::B_block, true>(
-                        b_panel, this->_Bptr + (current.multi() * this->_B_multi_stride), this->_ldb,
-                        current.x0(), current.xmax(), current.k0(), current.kmax());
-                } else {
-                    Transform<strategy::B_interleave, strategy::B_block, false>(
-                        b_panel, this->_Bptr + (current.multi() * this->_B_multi_stride), this->_ldb,
-                        current.x0(), current.xmax(), current.k0(), current.kmax());
-                }
+
+                strat.transforms.PrepareB(b_panel, this->_Bptr + (current.multi() * this->_B_multi_stride), this->_ldb,
+                                          current.x0(), current.xmax(), current.k0(), current.kmax(), _trB);
             }));
         }
 
@@ -300,33 +280,32 @@ class GemmInterleaved : public GemmCommon<To, Tr> {
                 if (first_m >= last_m)
                     continue;
 
-                for (unsigned int y=first_m; y<last_m; y+=strategy::out_height) {
-                    unsigned int ymax = std::min(_Msize, y + strategy::out_height);
+                for (unsigned int y=first_m; y<last_m; y+=strategy::out_height()) {
+                    unsigned int ymax = std::min(_Msize, y + strategy::out_height());
 
                     {
 #ifdef CYCLE_PROFILING
-                        auto p=prof.ScopedProfiler(PROFILE_KERNEL, (strategy::out_height * bblocks * strategy::out_width * kern_k));
+                        auto p=prof.ScopedProfiler(PROFILE_KERNEL, (strategy::out_height() * bblocks * strategy::out_width() * kern_k));
 #endif
                         strat.kernel(a_ptr, b_panel, c_panel, 1, bblocks, kern_k);
 
-                        a_ptr += (strategy::out_height * kern_k);
+                        a_ptr += (strategy::out_height() * kern_k);
                     }
 
                     {
 #ifdef CYCLE_PROFILING
-                        auto p=prof.ScopedProfiler(PROFILE_MERGE, (strategy::out_height * bblocks * strategy::out_width * sizeof(Tr)));
+                        auto p=prof.ScopedProfiler(PROFILE_MERGE, (strategy::out_height() * bblocks * strategy::out_width() * sizeof(Tr)));
 #endif
-                        MergeResults<strategy::out_width, strategy::out_height>(
-                            this->_Cptr + (batch * this->_C_batch_stride) + (current.multi() * this->_C_multi_stride),
-                            c_panel, this->_ldc, y, ymax, current.x0(), current.xmax(),
-                            _alpha, (current.k0()==0 ? _beta : static_cast<Tr>(1)));
+                        strat.transforms.Merge(this->_Cptr + (batch * this->_C_batch_stride) + (current.multi() * this->_C_multi_stride),
+                                               c_panel, this->_ldc, y, ymax, current.x0(), current.xmax(),
+                                               _alpha, (current.k0()==0 ? _beta : static_cast<Tr>(1)));
                     }
                 }
             }
 
             if (pretransposed) {
-                b_panel += (bblocks * strat.out_width * kern_k);
+                b_panel += (bblocks * strat.out_width() * kern_k);
             } else {
                 _bm->release(current.index());
             }
@@ -353,11 +332,11 @@ public:
 
         // k_block: Find out how much of the larger array can be loaded into half the cache.
         // This should account for associative caches.
-        _k_block = (L1_size / 2) / (sizeof(Toi) * (std::max(strategy::out_width, strategy::out_height)));
+        _k_block = (L1_size / 2) / (sizeof(Toi) * (std::max(strategy::out_width(), strategy::out_height())));
 
         // Needs to be (at least a single) multiple of the K unroll level.
-        _k_block /= strategy::k_unroll;
-        _k_block = std::max(_k_block, 1U) * strategy::k_unroll;
+        _k_block /= strategy::k_unroll();
+        _k_block = std::max(_k_block, 1U) * strategy::k_unroll();
 
         // Now tune to presented problem size; this is how many blocks we need.
         int num_k_blocks = iceildiv(K, _k_block);
@@ -366,28 +345,28 @@ public:
         _k_block = iceildiv(K, num_k_blocks);
 
         // And round UP to the K unroll level required.
-        _k_block = iceildiv(_k_block, strategy::k_unroll);
-        _k_block *= strategy::k_unroll;
+        _k_block = iceildiv(_k_block, strategy::k_unroll());
+        _k_block *= strategy::k_unroll();
 
         // x_block: Work out how many rows (of length k_block) will fit in the L2
         // Don't allocate more than 90% of the L2 to allow for overheads, and subtract off the L1 contents.
-        _x_block = (((L2_size * 9) / 10) - (_k_block * sizeof(Toi) * (strategy::out_width + strategy::out_height))) /
-                   (sizeof(Toi) * _k_block);
+        _x_block = (((L2_size * 9) / 10) - (_k_block * sizeof(Toi) * (strategy::out_width() + strategy::out_height()))) /
+                   (sizeof(Toi) * _k_block);
 
         // Needs to be (at least a single) multiple of the kernel output width.
-        _x_block /= strategy::out_width;
-        _x_block = std::max(_x_block, 1U) * strategy::out_width;
+        _x_block /= strategy::out_width();
+        _x_block = std::max(_x_block, 1U) * strategy::out_width();
 
         // And tune to the presented problem size.
        int num_x_blocks = iceildiv(N, _x_block);
        _x_block = iceildiv(N, num_x_blocks);
 
-        _x_block = iceildiv(_x_block, strategy::out_width);
-        _x_block *= strategy::out_width;
+        _x_block = iceildiv(_x_block, strategy::out_width());
+        _x_block *= strategy::out_width();
 
         // Work out the rounded size of M - needed for some buffers.
-        _Mround = iceildiv(M, strategy::out_height);
-        _Mround *= strategy::out_height;
+        _Mround = iceildiv(M, strategy::out_height());
+        _Mround *= strategy::out_height();
     }
 
     // Interface implementation - Compulsory functions
@@ -398,7 +377,7 @@ public:
     // manager).
     unsigned int get_window_size() const override {
         // _Mround is a multiple of out_height by definition.
-        return (_Mround / strategy::out_height) * _nbatches;
+        return (_Mround / strategy::out_height()) * _nbatches;
     }
 
     // set_nthreads: pass on to buffer manager to avoid it waiting for non-existant threads.
@@ -483,11 +462,11 @@ public:
             size_t k_size = (current.kmax() - current.k0());
 
             /* Round sizes up as needed. */
-            x_size = iceildiv(x_size, strategy::out_width);
-            x_size *= strategy::out_width;
+            x_size = iceildiv(x_size, strategy::out_width());
+            x_size *= strategy::out_width();
 
-            k_size = iceildiv(k_size, strategy::k_unroll);
-            k_size *= strategy::k_unroll;
+            k_size = iceildiv(k_size, strategy::k_unroll());
+            k_size *= strategy::k_unroll();
 
             total += x_size * k_size * sizeof(Toi);
         } while (current.advance());
@@ -499,6 +478,7 @@ public:
         blockwalker current(*this);
         Toi *buffer = reinterpret_cast<Toi *>(in_buffer);
         _B_transposed = buffer;
+        strategy strat(_ci);
 
         do {
             /* Figure out the size of each block. */
@@ -506,21 +486,14 @@ public:
             size_t k_size = (current.kmax() - current.k0());
 
             /* Round sizes up as needed. */
-            x_size = iceildiv(x_size, strategy::out_width);
-            x_size *= strategy::out_width;
+            x_size = iceildiv(x_size, strategy::out_width());
+            x_size *= strategy::out_width();
 
-            k_size = iceildiv(k_size, strategy::k_unroll);
-            k_size *= strategy::k_unroll;
+            k_size = iceildiv(k_size, strategy::k_unroll());
+            k_size *= strategy::k_unroll();
 
-            if (_trB ^ strategy::B_transpose) {
-                Transform<strategy::B_interleave, strategy::B_block, true>(
-                    buffer, B + (current.multi() * B_multi_stride), ldb,
-                    current.x0(), current.xmax(), current.k0(), current.kmax());
-            } else {
-                Transform<strategy::B_interleave, strategy::B_block, false>(
-                    buffer, B + (current.multi() * B_multi_stride), ldb,
-                    current.x0(), current.xmax(), current.k0(), current.kmax());
-            }
+            strat.transforms.PrepareB(buffer, B + (current.multi() * B_multi_stride), ldb,
+                                      current.x0(), current.xmax(), current.k0(), current.kmax(), _trB);
 
             buffer += (x_size * k_size);
         } while (current.advance());
--
cgit v1.2.1
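A note on the shape of this pre-work: the two mechanical changes running through the patch - every strategy::out_height-style constant becoming a call such as strategy::out_height(), and the hand-picked Transform<interleave, block, transpose> / MergeResults<out_width, out_height> template calls collapsing into strat.transforms.PrepareA/PrepareB/Merge with the _trA/_trB flags passed as arguments - are what clear the way for SVE kernels. A static constant must be fixed at compile time, while a static member function can return a tile size derived from the vector length of the machine the code runs on. The sketch below illustrates the idea only; the names (example_strategy_neon, example_strategy_sve, runtime_vector_length) are hypothetical stand-ins, not Compute Library code.

    // Hypothetical sketch: example_strategy_neon, example_strategy_sve and
    // runtime_vector_length() are illustrative, not from the patch.
    #include <cstdio>

    // Stand-in for a runtime CPU query such as the SVE vector length.
    static unsigned int runtime_vector_length() {
        return 8;   // assumed value for the sake of the example
    }

    // NEON-style strategy: tile sizes are fixed, but exposed as functions so
    // fixed-size and scalable strategies share a single interface.
    struct example_strategy_neon {
        static unsigned int out_width()  { return 12; }
        static unsigned int out_height() { return 8; }
        static unsigned int k_unroll()   { return 1; }
    };

    // SVE-style strategy: the same interface, but the tile width is derived
    // from the vector length discovered at run time - something a static
    // constant could not express.
    struct example_strategy_sve {
        static unsigned int out_width()  { return 3 * runtime_vector_length(); }
        static unsigned int out_height() { return 8; }
        static unsigned int k_unroll()   { return 4; }
    };

    // The blocking arithmetic in the patch (e.g. the _Mround computation)
    // then works unchanged for either strategy type.
    template <typename strategy>
    static unsigned int round_up_m(unsigned int M) {
        unsigned int Mround = (M + strategy::out_height() - 1) / strategy::out_height(); // iceildiv
        return Mround * strategy::out_height();
    }

    int main() {
        std::printf("neon: Mround(100) = %u\n", round_up_m<example_strategy_neon>(100));
        std::printf("sve:  Mround(100) = %u\n", round_up_m<example_strategy_sve>(100));
        return 0;
    }

The same reasoning applies to the transforms object: routing interleaving and merging through strat.transforms lets a later SVE strategy supply vector-length-aware PrepareA/PrepareB/Merge routines, and folding the transpose flag into an argument removes the duplicated true/false template branches seen on the minus side of the diff.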