diff options
Diffstat (limited to 'src/core/NEON/kernels/arm_gemm/gemm_native.hpp')
-rw-r--r-- | src/core/NEON/kernels/arm_gemm/gemm_native.hpp | 52 |
1 files changed, 23 insertions, 29 deletions
diff --git a/src/core/NEON/kernels/arm_gemm/gemm_native.hpp b/src/core/NEON/kernels/arm_gemm/gemm_native.hpp index 695236bdc4..6fed645d82 100644 --- a/src/core/NEON/kernels/arm_gemm/gemm_native.hpp +++ b/src/core/NEON/kernels/arm_gemm/gemm_native.hpp @@ -34,8 +34,8 @@ #include "profiler.hpp" #endif -namespace arm_gemm -{ +namespace arm_gemm { + // Implementation of the GemmCommon abstract class. // // This is implementation is for native GEMM with no transposition. @@ -43,11 +43,10 @@ namespace arm_gemm // By default the source data is used in-place, but if type conversion is // needed we need to allocate working space (CURRENTLY NOT IMPLEMENTED). -template <typename strategy, typename To, typename Tr> -class GemmNative : public GemmCommon<To, Tr> -{ +template<typename strategy, typename To, typename Tr> +class GemmNative : public GemmCommon<To, Tr> { typedef typename strategy::operand_type Toi; - typedef typename strategy::result_type Tri; + typedef typename strategy::result_type Tri; const unsigned int _Msize; const unsigned int _Nsize; @@ -58,36 +57,34 @@ class GemmNative : public GemmCommon<To, Tr> Tr _beta; - const CPUInfo *const _ci; + const CPUInfo * const _ci; - unsigned int k_block = 0; - unsigned int n_block = 0; + unsigned int k_block=0; + unsigned int n_block=0; public: GemmNative(GemmNative &) = delete; - GemmNative &operator=(GemmNative &) = delete; + GemmNative & operator= (GemmNative &) = delete; - GemmNative(const CPUInfo *ci, const unsigned int M, const unsigned int N, const unsigned int K, const unsigned int nbatches, const unsigned int nmultis, const Tr beta) - : _Msize(M), _Nsize(N), _Ksize(K), _nbatches(nbatches), _nmultis(nmultis), _beta(beta), _ci(ci) - { + GemmNative(const CPUInfo *ci, const unsigned int M, const unsigned int N, const unsigned int K, const unsigned int nbatches, const unsigned int nmultis, const Tr beta) : + _Msize(M), _Nsize(N), _Ksize(K), _nbatches(nbatches), _nmultis(nmultis), _beta(beta), _ci(ci) { /* For now don't do any blocking. TODO: figure out if we should. */ k_block = K; n_block = N; } // Window is number of out_height blocks - unsigned int get_window_size() const override - { + unsigned int get_window_size() const override { return iceildiv(_Msize, strategy::out_height) * _nbatches * _nmultis; } // Actually execute the GEMM. - void execute(unsigned int start, unsigned int end, int) override - { + void execute(unsigned int start, unsigned int end, int) override { #ifdef CYCLE_PROFILING profiler prof; #endif - strategy strat(_ci); + strategy strat(_ci); + const unsigned int window_per_batch = iceildiv(_Msize, strategy::out_height); const unsigned int window_per_multi = window_per_batch * _nbatches; @@ -103,27 +100,24 @@ public: static_assert(std::is_same<To, Toi>::value, "gemm_native: Operand types must be the same."); static_assert(std::is_same<Tr, Tri>::value, "gemm_native: Result types must be the same."); - for(unsigned int multi = first_multi; multi <= last_multi; multi++) - { + for (unsigned int multi=first_multi; multi<=last_multi; multi++) { const unsigned int batch_0 = (multi == first_multi) ? first_batch : 0; - const unsigned int batch_max = (multi == last_multi) ? last_batch : _nbatches - 1; + const unsigned int batch_max = (multi == last_multi) ? last_batch : (_nbatches-1); - for(unsigned int batch = batch_0; batch <= batch_max; batch++) - { - const unsigned int m_start = ((multi == first_multi) && (batch == first_batch)) ? first_row : 0; - const unsigned int m_end = ((multi == last_multi) && (batch == last_batch)) ? last_row : _Msize; + for (unsigned int batch=batch_0; batch <= batch_max; batch++) { + const unsigned int m_start = ((multi == first_multi) && (batch==first_batch)) ? first_row : 0; + const unsigned int m_end = ((multi == last_multi) && (batch==last_batch)) ? last_row : _Msize; - for(unsigned int y0 = m_start; y0 < m_end; y0 += strategy::out_height) - { + for (unsigned int y0=m_start; y0<m_end; y0+=strategy::out_height) { const unsigned int ymax = std::min(y0 + strategy::out_height, m_end); #ifdef CYCLE_PROFILING - auto p = prof.ScopedProfiler(PROFILE_KERNEL, (ymax - y0) * _Nsize * _Ksize); + auto p = prof.ScopedProfiler(PROFILE_KERNEL, (ymax-y0) * _Nsize * _Ksize); #endif strat.kernel(this->_Aptr + (multi * this->_A_multi_stride) + (batch * this->_A_batch_stride) + (y0 * this->_lda), this->_lda, this->_Bptr + (multi * this->_B_multi_stride), this->_ldb, this->_Cptr + (multi * this->_C_multi_stride) + (batch * this->_C_batch_stride) + (y0 * this->_ldc), this->_ldc, - _beta, (ymax - y0), _Nsize, _Ksize); + _beta, (ymax-y0), _Nsize, _Ksize); } } } |