From e7e96e09ff0d3e47797adf197aff2bc39671788c Mon Sep 17 00:00:00 2001
From: Michalis Spyrou
Date: Fri, 13 Apr 2018 13:44:10 +0100
Subject: COMPMID-1054 Update RSH's GEMM to add batch+multi support

Change-Id: Ib9d91b77f1d51976da4449fa1e6eeeffae307353
Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/127876
Tested-by: Jenkins
Reviewed-by: Pablo Tello
Reviewed-by: Anthony Barbier
---
 src/core/NEON/kernels/arm_gemm/gemm_native.hpp | 54 ++++++++++++++++++++------
 1 file changed, 43 insertions(+), 11 deletions(-)

(limited to 'src/core/NEON/kernels/arm_gemm/gemm_native.hpp')

diff --git a/src/core/NEON/kernels/arm_gemm/gemm_native.hpp b/src/core/NEON/kernels/arm_gemm/gemm_native.hpp
index b0192793b9..beecb76f20 100644
--- a/src/core/NEON/kernels/arm_gemm/gemm_native.hpp
+++ b/src/core/NEON/kernels/arm_gemm/gemm_native.hpp
@@ -28,9 +28,12 @@
 
 #include "arm_gemm.hpp"
 #include "mergeresults.hpp"
-#include "profiler.hpp"
 #include "transform.hpp"
 
+#ifdef CYCLE_PROFILING
+#include "profiler.hpp"
+#endif
+
 namespace arm_gemm {
 
 // Implementation of the GemmCommon abstract class.
@@ -50,6 +53,9 @@ class GemmNative : public GemmCommon<To, Tr>
     const unsigned int _Nsize;
     const unsigned int _Ksize;
 
+    const unsigned int _nbatches;
+    const unsigned int _nmultis;
+
     Tr _beta;
 
     const CPUInfo *const _ci;
@@ -61,8 +67,8 @@ public:
     GemmNative(GemmNative &) = delete;
     GemmNative &operator=(GemmNative &) = delete;
 
-    GemmNative(const CPUInfo *ci, const unsigned int M, const unsigned int N, const unsigned int K, const Tr beta)
-        : _Msize(M), _Nsize(N), _Ksize(K), _beta(beta), _ci(ci)
+    GemmNative(const CPUInfo *ci, const unsigned int M, const unsigned int N, const unsigned int K, const unsigned int nbatches, const unsigned int nmultis, const Tr beta)
+        : _Msize(M), _Nsize(N), _Ksize(K), _nbatches(nbatches), _nmultis(nmultis), _beta(beta), _ci(ci)
     {
         /* For now don't do any blocking. TODO: figure out if we should. */
         k_block = K;
@@ -72,29 +78,55 @@ public:
     // Window is number of out_height blocks
     unsigned int get_window_size() const override
     {
-        return iceildiv(_Msize, strategy::out_height);
+        return iceildiv(_Msize, strategy::out_height) * _nbatches * _nmultis;
     }
 
     // Actually execute the GEMM.
     void execute(unsigned int start, unsigned int end, int) override
     {
+#ifdef CYCLE_PROFILING
         profiler prof;
+#endif
         strategy strat(_ci);
 
-        unsigned int M_start = start * strategy::out_height;
-        unsigned int M_end   = std::min(end * strategy::out_height, _Msize);
+        const unsigned int window_per_batch = iceildiv(_Msize, strategy::out_height);
+        const unsigned int window_per_multi = window_per_batch * _nbatches;
+
+        const unsigned int first_multi = start / window_per_multi;
+        const unsigned int last_multi  = end / window_per_multi;
+
+        const unsigned int first_batch = (start - (first_multi * window_per_multi)) / window_per_batch;
+        const unsigned int last_batch  = (end - (last_multi * window_per_multi)) / window_per_batch;
+
+        const unsigned int first_row = ((start - (first_multi * window_per_multi)) % window_per_batch) * strategy::out_height;
+        const unsigned int last_row  = ((end - (last_multi * window_per_multi)) % window_per_batch) * strategy::out_height;
 
         static_assert(std::is_same<To, Toi>::value, "gemm_native: Operand types must be the same.");
         static_assert(std::is_same<Tr, Tri>::value, "gemm_native: Result types must be the same.");
 
-        for(unsigned int y0 = M_start; y0 < M_end; y0 += strategy::out_height)
+        for(unsigned int multi = first_multi; multi <= last_multi; multi++)
         {
-            unsigned int ymax = std::min(y0 + strategy::out_height, M_end);
+            const unsigned int batch_0   = (multi == first_multi) ? first_batch : 0;
+            const unsigned int batch_max = (multi == last_multi) ? last_batch : _nbatches;
 
-            prof(PROFILE_KERNEL, (ymax - y0) * _Nsize * _Ksize, [&](void)
+            for(unsigned int batch = batch_0; batch < batch_max; batch++)
             {
-                strat.kernel(this->_Aptr + (y0 * this->_lda), this->_lda, this->_Bptr, this->_ldb, this->_Cptr + (y0 * this->_ldc), this->_ldc, _beta, (ymax - y0), _Nsize, _Ksize);
-            });
+                const unsigned int m_start = ((multi == first_multi) && (batch == first_batch)) ? first_row : 0;
+                const unsigned int m_end   = ((multi == last_multi) && (batch == last_batch)) ? last_row : _Msize;
+
+                for(unsigned int y0 = m_start; y0 < m_end; y0 += strategy::out_height)
+                {
+                    const unsigned int ymax = std::min(y0 + strategy::out_height, m_end);
+#ifdef CYCLE_PROFILING
+                    auto p = prof.ScopedProfiler(PROFILE_KERNEL, (ymax - y0) * _Nsize * _Ksize);
+#endif
+
+                    strat.kernel(this->_Aptr + (multi * this->_A_multi_stride) + (batch * this->_A_batch_stride) + (y0 * this->_lda), this->_lda,
+                                 this->_Bptr + (multi * this->_B_multi_stride), this->_ldb,
+                                 this->_Cptr + (multi * this->_C_multi_stride) + (batch * this->_C_batch_stride) + (y0 * this->_ldc), this->_ldc,
+                                 _beta, (ymax - y0), _Nsize, _Ksize);
+                }
+            }
        }
    }
};
--
cgit v1.2.1
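
Note: the substance of this change is that the scheduling window returned by get_window_size() now spans out_height row blocks across every batch and every multi, and execute() maps its flat [start, end) slice of that window back onto (multi, batch, row) coordinates before calling the kernel with the matching batch/multi strides. The standalone sketch below is not part of arm_gemm; the matrix sizes and the out_height value are invented for illustration, and iceildiv is re-declared locally. It only walks the same index arithmetic the patched execute() uses:

// Illustrative only: decompose a 1-D window index into (multi, batch, row)
// using the same arithmetic as the patched execute().
#include <cstdio>
#include <initializer_list>

static unsigned int iceildiv(unsigned int a, unsigned int b)
{
    return (a + b - 1) / b;
}

int main()
{
    // Made-up problem sizes; out_height stands in for strategy::out_height.
    const unsigned int Msize = 100, nbatches = 2, nmultis = 3, out_height = 8;

    // Window sizing as in get_window_size(): row blocks per batch,
    // times batches, times multis.
    const unsigned int window_per_batch = iceildiv(Msize, out_height); // 13
    const unsigned int window_per_multi = window_per_batch * nbatches; // 26
    const unsigned int window_size      = window_per_multi * nmultis;  // 78

    // Decompose a few window indices the same way execute() derives
    // first_multi/first_batch/first_row from 'start'.
    for(unsigned int w : { 0u, 13u, 27u, window_size - 1 })
    {
        const unsigned int multi = w / window_per_multi;
        const unsigned int batch = (w - (multi * window_per_multi)) / window_per_batch;
        const unsigned int row   = ((w - (multi * window_per_multi)) % window_per_batch) * out_height;
        printf("window %2u -> multi %u, batch %u, rows starting at %u\n", w, multi, batch, row);
    }
    return 0;
}

Keeping the window one-dimensional means the existing threading interface (which only hands each worker a flat [start, end) range) is unchanged; the decomposition above is what recovers the three nested multi/batch/row loops inside execute().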