From 1d480652b820317fc97ccbc3cb517e3b9e8be197 Mon Sep 17 00:00:00 2001 From: Georgios Pinitas Date: Wed, 23 Jan 2019 11:24:50 +0000 Subject: COMPMID-1867: Add u8 and s8 hybrid assembly kernels. Change-Id: Ifeb005f9d18d19feff11949474cce84d9e03749c Reviewed-on: https://review.mlplatform.org/565 Reviewed-by: Michalis Spyrou Tested-by: Arm Jenkins --- .../NEON/kernels/arm_gemm/gemv_pretransposed.hpp | 21 +++++++++++---------- 1 file changed, 11 insertions(+), 10 deletions(-) (limited to 'src/core/NEON/kernels/arm_gemm/gemv_pretransposed.hpp') diff --git a/src/core/NEON/kernels/arm_gemm/gemv_pretransposed.hpp b/src/core/NEON/kernels/arm_gemm/gemv_pretransposed.hpp index 842339ef23..f7beb0a34c 100644 --- a/src/core/NEON/kernels/arm_gemm/gemv_pretransposed.hpp +++ b/src/core/NEON/kernels/arm_gemm/gemv_pretransposed.hpp @@ -70,7 +70,7 @@ public: GemvPretransposed(const GemmArgs &args) : _Nsize(args._Nsize), _Ksize(args._Ksize), _nmultis(args._nmulti), _trB(args._trB), _beta(args._beta), _ci(args._ci), - _buffer_per_multi(_Ksize * iceildiv(_Nsize, strategy::A_interleave) * strategy::A_interleave) { + _buffer_per_multi(_Ksize * iceildiv(_Nsize, strategy::A_interleave()) * strategy::A_interleave()) { /* For now don't do any blocking. TODO: figure out if we should. */ if (args._cfg && args._cfg->inner_block_size) { m_block = args._cfg->inner_block_size; @@ -87,7 +87,7 @@ public: // Window is number of out_width blocks, times number of multis. unsigned int get_window_size() const override { - return iceildiv(_Nsize, strategy::out_width) * _nmultis; + return iceildiv(_Nsize, strategy::out_width()) * _nmultis; } // Actually execute the GEMV. @@ -98,13 +98,13 @@ public: strategy strat(_ci); /* Break the window values down into multis of interest... 
*/ - const unsigned int window_per_multi = iceildiv(_Nsize, strategy::out_width); + const unsigned int window_per_multi = iceildiv(_Nsize, strategy::out_width()); const unsigned int multi_0 = start / window_per_multi; const unsigned int multi_end = end / window_per_multi; /* ... and figure out where we start and end in the first and last multi. */ - const unsigned int n_0 = (start - (multi_0 * window_per_multi)) * strategy::out_width; - const unsigned int n_max = (end - (multi_end * window_per_multi)) * strategy::out_width; + const unsigned int n_0 = (start - (multi_0 * window_per_multi)) * strategy::out_width(); + const unsigned int n_max = (end - (multi_end * window_per_multi)) * strategy::out_width(); static_assert(std::is_same<Tr, Tri>::value, "GemvPretransposed: Result types must be the same."); @@ -124,8 +124,8 @@ public: auto p = prof.ScopedProfiler(PROFILE_KERNEL, (mmax-m0) * (nmax-n)); #endif /* This assumes that the underlying call was a GEMM with M=1; for the N=1 case we would have to pick up this->_Bptr below instead */ - strat.kernel(_A_pretransposed + (multi * _buffer_per_multi) + (n * _Ksize) + (m0 * strategy::A_interleave), - (_Ksize * strategy::A_interleave), + strat.kernel(_A_pretransposed + (multi * _buffer_per_multi) + (n * _Ksize) + (m0 * strategy::A_interleave()), + (_Ksize * strategy::A_interleave()), this->_Aptr + (multi * this->_A_multi_stride) + m0, this->_Cptr + (multi * this->_C_multi_stride) + n, _beta, (mmax-m0), (nmax-n)); @@ -148,6 +148,7 @@ public: return _buffer_per_multi * _nmultis * sizeof(To); } + using GemmCommon<To, Tr>::pretranspose_B_array; void pretranspose_B_array(void *buffer, const To *B, const int ldb, const int B_multi_stride) override { Toi *A_buffer = reinterpret_cast<Toi *>(buffer); @@ -155,10 +156,10 @@ public: /* Reverse sense here as we are dealing with B rather than A. So if * strategy::A_transpose is false and _trB is false, we still * transpose. 
*/ - if (_trB ^ strategy::A_transpose) { - Transform<strategy::A_interleave, strategy::A_block, false>(A_buffer + (multi * _buffer_per_multi), B + (multi * B_multi_stride), ldb, 0, _Nsize, 0, _Ksize); + if (_trB ^ strategy::A_transpose()) { + Transform<strategy::A_interleave(), strategy::A_block(), false>(A_buffer + (multi * _buffer_per_multi), B + (multi * B_multi_stride), ldb, 0, _Nsize, 0, _Ksize); } else { - Transform<strategy::A_interleave, strategy::A_block, true>(A_buffer + (multi * _buffer_per_multi), B + (multi * B_multi_stride), ldb, 0, _Nsize, 0, _Ksize); + Transform<strategy::A_interleave(), strategy::A_block(), true>(A_buffer + (multi * _buffer_per_multi), B + (multi * B_multi_stride), ldb, 0, _Nsize, 0, _Ksize); } } -- cgit v1.2.1