diff options
Diffstat (limited to 'src/core/NEON/kernels/arm_gemm/gemv_native_transposed.hpp')
-rw-r--r-- | src/core/NEON/kernels/arm_gemm/gemv_native_transposed.hpp | 50 |
1 files changed, 35 insertions, 15 deletions
diff --git a/src/core/NEON/kernels/arm_gemm/gemv_native_transposed.hpp b/src/core/NEON/kernels/arm_gemm/gemv_native_transposed.hpp index 29c71f2511..e5cc79eaed 100644 --- a/src/core/NEON/kernels/arm_gemm/gemv_native_transposed.hpp +++ b/src/core/NEON/kernels/arm_gemm/gemv_native_transposed.hpp @@ -28,9 +28,12 @@ #include "arm_gemm.hpp" #include "mergeresults.hpp" -#include "profiler.hpp" #include "transform.hpp" +#ifdef CYCLE_PROFILING +#include "profiler.hpp" +#endif + namespace arm_gemm { // Implementation of the GemmCommon abstract class. @@ -48,6 +51,7 @@ class GemvNativeTransposed : public GemmCommon<To, Tr> const unsigned int _Nsize; const unsigned int _Ksize; + const unsigned int _nmultis; const Tr _beta; @@ -60,45 +64,61 @@ public: GemvNativeTransposed(GemvNativeTransposed &) = delete; GemvNativeTransposed &operator=(GemvNativeTransposed &) = delete; - GemvNativeTransposed(const CPUInfo *ci, const unsigned int N, const unsigned int K, const Tr beta) - : _Nsize(N), _Ksize(K), _beta(beta), _ci(ci) + GemvNativeTransposed(const CPUInfo *ci, const unsigned int N, const unsigned int K, const unsigned int nmultis, const Tr beta) + : _Nsize(N), _Ksize(K), _nmultis(nmultis), _beta(beta), _ci(ci) { /* For now don't do any blocking. TODO: figure out if we should. */ m_block = K; n_block = N; } - // Window is number of out_width blocks. + // Window is number of out_width blocks times number of multis. unsigned int get_window_size() const override { - return iceildiv(_Nsize, strategy::out_width); + return iceildiv(_Nsize, strategy::out_width) * _nmultis; } // Actually execute the GEMV. void execute(unsigned int start, unsigned int end, int) override { +#ifdef CYCLE_PROFILING profiler prof; +#endif + strategy strat(_ci); - unsigned int N_start = start * strategy::out_width; - unsigned int N_end = std::min(end * strategy::out_width, _Nsize); + const unsigned int window_per_multi = iceildiv(_Nsize, strategy::out_width); + const unsigned int multi_0 = start / window_per_multi; + const unsigned int multi_end = end / window_per_multi; + + const unsigned int n_0 = (start - (multi_0 * window_per_multi)) * strategy::out_width; + const unsigned int n_max = (end - (multi_end * window_per_multi)) * strategy::out_width; static_assert(std::is_same<To, Toi>::value, "gemv_transposed: Operand types must be the same."); static_assert(std::is_same<Tr, Tri>::value, "gemv_transposed: Result types must be the same."); - for(unsigned int m0 = 0; m0 < _Ksize; m0 += m_block) + for(unsigned int multi = multi_0; multi <= multi_end; multi++) { - unsigned int mmax = std::min(m0 + m_block, _Ksize); + const unsigned int n_start = (multi == multi_0) ? n_0 : 0; + const unsigned int n_end = (multi == multi_end) ? n_max : _Nsize; - for(unsigned int n0 = N_start; n0 < N_end; n0 += n_block) - { - unsigned int nmax = std::min(n0 + n_block, N_end); + if(n_end <= n_start) + continue; - prof(PROFILE_KERNEL, ((mmax - m0) * (nmax - n0)), [&](void) + for(unsigned int m0 = 0; m0 < _Ksize; m0 += m_block) + { + unsigned int mmax = std::min(m0 + m_block, _Ksize); + for(unsigned int n0 = n_start; n0 < n_end; n0 += n_block) { - strat.kernel(this->_Bptr + (m0 * this->_ldb) + n0, this->_Aptr + m0, this->_Cptr + n0, + unsigned int nmax = std::min(n0 + n_block, n_end); +#ifdef CYCLE_PROFILING + auto p = prof.ScopedProfiler(PROFILE_KERNEL, (mmax - m0) * (nmax - n0)); +#endif + strat.kernel(this->_Bptr + (multi * this->_B_multi_stride) + (m0 * this->_ldb) + n0, + this->_Aptr + (multi * this->_A_multi_stride) + m0, + this->_Cptr + (multi * this->_C_multi_stride) + n0, _beta, this->_ldb, (mmax - m0), (nmax - n0)); - }); + } } } } |