From c0b6f76561580414f08633a804fc548ccad65659 Mon Sep 17 00:00:00 2001
From: Georgios Pinitas
Date: Mon, 2 Nov 2020 01:37:17 +0000
Subject: COMPMID-3776: Indirect GEMM

Signed-off-by: Georgios Pinitas
Change-Id: I51a1b0f098bc3a8c408c50c92221e4df3061e12c
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/4343
Tested-by: Arm Jenkins
Reviewed-by: Sang-Hoon Park
Reviewed-by: Michele Di Giorgio
Comments-Addressed: Arm Jenkins
---
 .../NEON/kernels/arm_gemm/gemv_pretransposed.hpp | 86 ++++++++--------
 1 file changed, 31 insertions(+), 55 deletions(-)

(limited to 'src/core/NEON/kernels/arm_gemm/gemv_pretransposed.hpp')

diff --git a/src/core/NEON/kernels/arm_gemm/gemv_pretransposed.hpp b/src/core/NEON/kernels/arm_gemm/gemv_pretransposed.hpp
index 47909cdaeb..9de44fcb73 100644
--- a/src/core/NEON/kernels/arm_gemm/gemv_pretransposed.hpp
+++ b/src/core/NEON/kernels/arm_gemm/gemv_pretransposed.hpp
@@ -46,46 +46,39 @@ class GemvPretransposed : public GemmCommon<To, Tr> {
     typedef typename strategy::operand_type Toi;
     typedef typename strategy::result_type Tri;

-    const unsigned int _Nsize;
-    const unsigned int _Ksize;
-
-    const unsigned int _nmultis;
-
-    const Activation _act;
-
-    const CPUInfo * const _ci;
+    const GemmArgs _args;

     const unsigned int _buffer_per_multi;

-    unsigned int m_block=0;
+    unsigned int k_block=0;
     unsigned int n_block=0;

-    const Toi *_A_pretransposed = nullptr;
+    const Toi *_B_pretransposed = nullptr;

 public:
     GemvPretransposed(GemvPretransposed &) = delete;
     GemvPretransposed & operator= (GemvPretransposed &) = delete;

     GemvPretransposed(const GemmArgs &args)
-        : _Nsize(args._Nsize), _Ksize(args._Ksize), _nmultis(args._nmulti), _act(args._act), _ci(args._ci),
-          _buffer_per_multi(_Ksize * iceildiv(_Nsize, strategy::A_interleave()) * strategy::A_interleave()) {
+        : _args(args),
+          _buffer_per_multi(args._Ksize * roundup(args._Nsize, strategy::out_width())) {
         /* For now don't do any blocking. TODO: figure out if we should. */
-        if (args._cfg && args._cfg->inner_block_size) {
-            m_block = args._cfg->inner_block_size;
+        if (strategy::supports_accumulate() && args._cfg && args._cfg->inner_block_size) {
+            k_block = args._cfg->inner_block_size;
         } else {
-            m_block = _Ksize;
+            k_block = args._Ksize;
         }

         if (args._cfg && args._cfg->outer_block_size) {
             n_block = args._cfg->outer_block_size;
         } else {
-            n_block = _Nsize;
+            n_block = args._Nsize;
         }
     }

     // Window is number of out_width blocks, times number of multis.
     ndrange_t get_window_size() const override {
-        return { iceildiv(_Nsize, strategy::out_width()) * _nmultis };
+        return { iceildiv(_args._Nsize, strategy::out_width()) * _args._nmulti };
     }

     // Actually execute the GEMV.
@@ -93,13 +86,13 @@ public:
 #ifdef CYCLE_PROFILING
         profiler prof;
 #endif
-        strategy strat(_ci);
+        strategy strat(_args._ci);

         const auto start = work_range.get_position(0);
         const auto end = work_range.get_position_end(0);

         /* Break the window values down into multis of interest... */
-        const unsigned int window_per_multi = iceildiv(_Nsize, strategy::out_width());
+        const unsigned int window_per_multi = iceildiv(_args._Nsize, strategy::out_width());
         const unsigned int multi_0   = start / window_per_multi;
         const unsigned int multi_end = end / window_per_multi;
@@ -111,36 +104,25 @@ public:
         for (unsigned int multi=multi_0; multi<=multi_end; multi++) {
             const unsigned int n_start = (multi==multi_0) ? n_0 : 0;
-            const unsigned int n_end = (multi==multi_end) ? n_max : _Nsize;
+            const unsigned int n_end = (multi==multi_end) ? n_max : _args._Nsize;

             if (n_end <= n_start)
                 continue;

-            for (unsigned int m0=0; m0<_Ksize; m0+=m_block) {
-                unsigned int mmax = std::min(m0 + m_block, _Ksize);
+            for (unsigned int k0=0; k0<_args._Ksize; k0+=k_block) {
+                unsigned int kmax = std::min(k0 + k_block, _args._Ksize);

                 for (unsigned int n=n_start; n<n_end; n+=n_block) {
                     unsigned int nmax = std::min(n + n_block, n_end);

-                    /* ... this->_Bptr below instead */
-                    strat.kernel(_A_pretransposed + (multi * _buffer_per_multi) + (n * _Ksize) + (m0 * strategy::A_interleave()),
-                                 (_Ksize * strategy::A_interleave()),
-                                 this->_Aptr + (multi * this->_A_multi_stride) + m0,
+                    strat.kernel(this->_Aptr + (multi * this->_A_multi_stride) + k0,
+                                 _B_pretransposed + (multi * _buffer_per_multi) + (n * roundup(_args._Ksize, strategy::k_unroll())) + (k0 * strategy::out_width()),
                                  this->_Cptr + (multi * this->_C_multi_stride) + n,
-                                 static_cast<Tr>(0), (mmax-m0), (nmax-n));
-
-                    // Handle activation separately for now
-                    if (this->_bias) {
-                        activator(this->_Cptr + (multi * this->_C_multi_stride) + n, 0,
-                                  this->_bias + (multi * this->_bias_multi_stride) + n,
-                                  _act, 1, (nmax-n));
-                    } else {
-                        activator(this->_Cptr + (multi * this->_C_multi_stride) + n, 0,
-                                  static_cast<const Tr *>(nullptr),
-                                  _act, 1, (nmax-n));
-                    }
+                                 (nmax - n), (kmax-k0),
+                                 this->_bias ? this->_bias + (multi * this->_bias_multi_stride) + n : nullptr,
+                                 _args._act, (k0 != 0));
                 }
             }
         }
@@ -152,33 +134,27 @@ public:
     }

     bool B_pretranspose_required() const override {
-        /* Transpose is required if _A_pretransposed is still nullptr */
-        return (_A_pretransposed == nullptr);
+        /* Transpose is required if _B_pretransposed is still nullptr */
+        return (_B_pretransposed == nullptr);
     }

     size_t get_B_pretransposed_array_size() const override {
-        return _buffer_per_multi * _nmultis * sizeof(To);
+        return _buffer_per_multi * _args._nmulti * sizeof(To);
     }

     void pretranspose_B_array(void *buffer, const To *B, const int ldb, const int B_multi_stride) override {
-        Toi *A_buffer = reinterpret_cast<Toi *>(buffer);
-
-        for (unsigned int multi=0; multi<_nmultis; multi++) {
-            /* Reverse sense here as we are dealing with B rather than A. So if
-             * strategy::A_transpose is false and _trB is false, we still
-             * transpose. */
-            if (strategy::A_transpose()) {
-                Transform(A_buffer + (multi * _buffer_per_multi), B + (multi * B_multi_stride), ldb, 0, _Nsize, 0, _Ksize);
-            } else {
-                Transform(A_buffer + (multi * _buffer_per_multi), B + (multi * B_multi_stride), ldb, 0, _Nsize, 0, _Ksize);
-            }
+        Toi *B_buffer = reinterpret_cast<Toi *>(buffer);
+        strategy strat(_args._ci);
+
+        for (unsigned int multi=0; multi<_args._nmulti; multi++) {
+            strat.transforms.PrepareB(B_buffer + (multi * _buffer_per_multi), B + (multi * B_multi_stride), ldb, 0, _args._Nsize, 0, _args._Ksize);
         }

-        _A_pretransposed = A_buffer;
+        _B_pretransposed = B_buffer;
     }

     void set_pretransposed_B_data(void *buffer) override {
-        _A_pretransposed = reinterpret_cast<const Toi *>(buffer);
+        _B_pretransposed = reinterpret_cast<const Toi *>(buffer);
     }
 };
--
cgit v1.2.1
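
Note (not part of the commit): the standalone C++ sketch below replays the two pieces of index arithmetic this patch introduces, namely the per-multi size of the pretransposed B buffer (args._Ksize * roundup(args._Nsize, strategy::out_width())) and the offset of the (n, k0) panel passed to strat.kernel() (n * roundup(_args._Ksize, strategy::k_unroll()) + k0 * strategy::out_width()). The values of out_width, k_unroll, Nsize and Ksize are assumed examples standing in for the strategy traits and GemmArgs fields used above, not taken from any real kernel.

// Illustrative sketch only -- not part of the patch. The trait values below
// (out_width, k_unroll) are assumptions standing in for strategy::out_width()
// and strategy::k_unroll(); real kernels define their own.
#include <cstddef>
#include <iostream>

// Round x up to a multiple of step, mirroring arm_gemm's roundup() helper.
static std::size_t roundup(std::size_t x, std::size_t step) {
    return ((x + step - 1) / step) * step;
}

int main() {
    const std::size_t out_width = 8;   // assumed: output columns per kernel call
    const std::size_t k_unroll  = 4;   // assumed: K elements consumed per inner step
    const std::size_t Nsize     = 100; // example problem width
    const std::size_t Ksize     = 57;  // example depth

    // Per-multi buffer size, as in the new constructor:
    //   _buffer_per_multi = args._Ksize * roundup(args._Nsize, strategy::out_width())
    const std::size_t buffer_per_multi = Ksize * roundup(Nsize, out_width);

    // Offset of the (n, k0) panel inside one multi's buffer, as in the new
    // strat.kernel() call:
    //   n * roundup(_args._Ksize, strategy::k_unroll()) + k0 * strategy::out_width()
    const std::size_t n  = 16; // start column of the current out_width panel
    const std::size_t k0 = 8;  // start of the current K block
    const std::size_t panel_offset =
        n * roundup(Ksize, k_unroll) + k0 * out_width;

    std::cout << "buffer_per_multi = " << buffer_per_multi << " elements\n"
              << "panel offset (n=" << n << ", k0=" << k0 << ") = "
              << panel_offset << " elements\n";
    return 0;
}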