From 1d480652b820317fc97ccbc3cb517e3b9e8be197 Mon Sep 17 00:00:00 2001 From: Georgios Pinitas Date: Wed, 23 Jan 2019 11:24:50 +0000 Subject: COMPMID-1867: Add u8 and s8 hybrid assembly kernels. Change-Id: Ifeb005f9d18d19feff11949474cce84d9e03749c Reviewed-on: https://review.mlplatform.org/565 Reviewed-by: Michalis Spyrou Tested-by: Arm Jenkins --- src/core/NEON/kernels/arm_gemm/gemm_native.hpp | 91 ++++++++++++-------------- 1 file changed, 43 insertions(+), 48 deletions(-) (limited to 'src/core/NEON/kernels/arm_gemm/gemm_native.hpp') diff --git a/src/core/NEON/kernels/arm_gemm/gemm_native.hpp b/src/core/NEON/kernels/arm_gemm/gemm_native.hpp index 579533418d..98516b1ca6 100644 --- a/src/core/NEON/kernels/arm_gemm/gemm_native.hpp +++ b/src/core/NEON/kernels/arm_gemm/gemm_native.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017-2019 ARM Limited. + * Copyright (c) 2017-2019 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -27,8 +27,7 @@ #include "arm_gemm.hpp" -#include "mergeresults.hpp" -#include "transform.hpp" +#include "ndrange.hpp" #ifdef CYCLE_PROFILING #include "profiler.hpp" @@ -55,19 +54,25 @@ class GemmNative : public GemmCommon { const unsigned int _nbatches; const unsigned int _nmultis; - Tr _beta; + const Tr _beta; const CPUInfo * const _ci; - unsigned int k_block=0; - unsigned int n_block=0; + const unsigned int _k_block; + const unsigned int _n_block; - unsigned int window_per_batch() const { - return iceildiv(_Msize, strategy::out_height()); + const NDRange<4> _window_range; + + static unsigned int compute_k_block(const GemmArgs &args) { + return args._Ksize; } - unsigned int window_per_multi() const { - return window_per_batch() * _nbatches; + static unsigned int compute_n_block(const GemmArgs &args) { + if ((args._cfg != nullptr) && args._cfg->outer_block_size > 0) { + return args._cfg->outer_block_size; + } else { + return args._Nsize; + } } public: @@ -75,15 +80,20 @@ public: GemmNative & operator= (GemmNative &) = delete; GemmNative(const GemmArgs &args) - : _Msize(args._Msize), _Nsize(args._Nsize), _Ksize(args._Ksize), _nbatches(args._nbatches), _nmultis(args._nmulti), _beta(args._beta), _ci(args._ci) { - /* For now don't do any blocking. TODO: figure out if we should. */ - k_block = _Ksize; - n_block = _Nsize; - } + : _Msize(args._Msize), _Nsize(args._Nsize), _Ksize(args._Ksize), + _nbatches(args._nbatches), _nmultis(args._nmulti), + _beta(args._beta), _ci(args._ci), + _k_block(compute_k_block(args)), _n_block(compute_n_block(args)), + _window_range(iceildiv(_Msize, strategy::out_height()), _nbatches, iceildiv(_Nsize, _n_block), _nmultis) { } // Window is amount per multi multiplied by total number of multis. unsigned int get_window_size() const override { - return window_per_multi() * _nmultis; + return _window_range.total_size(); + } + + // Native GEMMs can always be dynamically scheduled (whether requested or not) + bool supports_dynamic_scheduling() const override { + return true; } // Actually execute the GEMM. @@ -96,45 +106,30 @@ public: static_assert(std::is_same::value, "gemm_native: Operand types must be the same."); static_assert(std::is_same::value, "gemm_native: Result types must be the same."); - /* Compute starting point based on 'start' */ - unsigned int multi = start / window_per_multi(); - unsigned int multi_pos = start % window_per_multi(); + auto p = _window_range.iterator(start, end); - unsigned int batch = multi_pos / window_per_batch(); - unsigned int batch_pos = multi_pos % window_per_batch(); - - unsigned int y0 = batch_pos * strategy::out_height(); - - for (unsigned int l=end-start; l>0; ) { - // Do work from here to the end of the current batch/multi - const unsigned int ymax = std::min(y0 + (l * strategy::out_height()), _Msize); + if (p.done()) { + return; + } - // Work out how many units this is and subtract from loop counter. - l -= ((ymax - y0) + (strategy::out_height() - 1)) / strategy::out_height(); + do { + unsigned int y0 = p.dim(0) * strategy::out_height(); + unsigned int ymax = std::min(p.dim0_max() * strategy::out_height(), _Msize); + unsigned int batch = p.dim(1); + unsigned int n0 = p.dim(2) * _n_block; + unsigned int nmax = std::min(n0 + _n_block, _Nsize); + unsigned int multi = p.dim(3); #ifdef CYCLE_PROFILING - auto p = prof.ScopedProfiler(PROFILE_KERNEL, (ymax-y0) * _Nsize * _Ksize); + auto p = prof.ScopedProfiler(PROFILE_KERNEL, (ymax-y0) * (nmax - n0) * _Ksize); #endif strat.kernel(this->_Aptr + (multi * this->_A_multi_stride) + (batch * this->_A_batch_stride) + (y0 * this->_lda), this->_lda, - this->_Bptr + (multi * this->_B_multi_stride), this->_ldb, - this->_Cptr + (multi * this->_C_multi_stride) + (batch * this->_C_batch_stride) + (y0 * this->_ldc), this->_ldc, - _beta, (ymax-y0), _Nsize, _Ksize); - - /* Advance to next item */ - y0 = ymax; - - /* Check for batch/multi overflow */ - if (y0 >= _Msize) { - y0=0; - batch++; - if (batch == _nbatches) { - batch=0; - multi++; - } - } - } + this->_Bptr + (multi * this->_B_multi_stride) + n0, this->_ldb, + this->_Cptr + (multi * this->_C_multi_stride) + (batch * this->_C_batch_stride) + (y0 * this->_ldc) + n0, this->_ldc, + _beta, (ymax-y0), (nmax - n0), _Ksize); + } while (p.next_dim1()); } }; -} // namespace arm_gemm +} // namespace arm_gemm \ No newline at end of file -- cgit v1.2.1