diff options
Diffstat (limited to 'src/core/NEON/kernels/arm_gemm/gemv_pretransposed.hpp')
-rw-r--r-- | src/core/NEON/kernels/arm_gemm/gemv_pretransposed.hpp | 21 |
1 files changed, 11 insertions, 10 deletions
diff --git a/src/core/NEON/kernels/arm_gemm/gemv_pretransposed.hpp b/src/core/NEON/kernels/arm_gemm/gemv_pretransposed.hpp index 842339ef23..f7beb0a34c 100644 --- a/src/core/NEON/kernels/arm_gemm/gemv_pretransposed.hpp +++ b/src/core/NEON/kernels/arm_gemm/gemv_pretransposed.hpp @@ -70,7 +70,7 @@ public: GemvPretransposed(const GemmArgs<Tr> &args) : _Nsize(args._Nsize), _Ksize(args._Ksize), _nmultis(args._nmulti), _trB(args._trB), _beta(args._beta), _ci(args._ci), - _buffer_per_multi(_Ksize * iceildiv(_Nsize, strategy::A_interleave) * strategy::A_interleave) { + _buffer_per_multi(_Ksize * iceildiv(_Nsize, strategy::A_interleave()) * strategy::A_interleave()) { /* For now don't do any blocking. TODO: figure out if we should. */ if (args._cfg && args._cfg->inner_block_size) { m_block = args._cfg->inner_block_size; @@ -87,7 +87,7 @@ public: // Window is number of out_width blocks, times number of multis. unsigned int get_window_size() const override { - return iceildiv(_Nsize, strategy::out_width) * _nmultis; + return iceildiv(_Nsize, strategy::out_width()) * _nmultis; } // Actually execute the GEMV. @@ -98,13 +98,13 @@ public: strategy strat(_ci); /* Break the window values down into multis of interest... */ - const unsigned int window_per_multi = iceildiv(_Nsize, strategy::out_width); + const unsigned int window_per_multi = iceildiv(_Nsize, strategy::out_width()); const unsigned int multi_0 = start / window_per_multi; const unsigned int multi_end = end / window_per_multi; /* ... and figure out where we start and end in the first and last multi. */ - const unsigned int n_0 = (start - (multi_0 * window_per_multi)) * strategy::out_width; - const unsigned int n_max = (end - (multi_end * window_per_multi)) * strategy::out_width; + const unsigned int n_0 = (start - (multi_0 * window_per_multi)) * strategy::out_width(); + const unsigned int n_max = (end - (multi_end * window_per_multi)) * strategy::out_width(); static_assert(std::is_same<Tr, Tri>::value, "GemvPretransposed: Result types must be the same."); @@ -124,8 +124,8 @@ public: auto p = prof.ScopedProfiler(PROFILE_KERNEL, (mmax-m0) * (nmax-n)); #endif /* This assumes that the underlying call was a GEMM with M=1; for the N=1 case we would have to pick up this->_Bptr below instead */ - strat.kernel(_A_pretransposed + (multi * _buffer_per_multi) + (n * _Ksize) + (m0 * strategy::A_interleave), - (_Ksize * strategy::A_interleave), + strat.kernel(_A_pretransposed + (multi * _buffer_per_multi) + (n * _Ksize) + (m0 * strategy::A_interleave()), + (_Ksize * strategy::A_interleave()), this->_Aptr + (multi * this->_A_multi_stride) + m0, this->_Cptr + (multi * this->_C_multi_stride) + n, _beta, (mmax-m0), (nmax-n)); @@ -148,6 +148,7 @@ public: return _buffer_per_multi * _nmultis * sizeof(To); } + using GemmCommon<To, Tr>::pretranspose_B_array; void pretranspose_B_array(void *buffer, const To *B, const int ldb, const int B_multi_stride) override { Toi *A_buffer = reinterpret_cast<Toi *>(buffer); @@ -155,10 +156,10 @@ public: /* Reverse sense here as we are dealing with B rather than A. So if * strategy::A_transpose is false and _trB is false, we still * transpose. */ - if (_trB ^ strategy::A_transpose) { - Transform<strategy::A_interleave, strategy::A_block, false>(A_buffer + (multi * _buffer_per_multi), B + (multi * B_multi_stride), ldb, 0, _Nsize, 0, _Ksize); + if (_trB ^ strategy::A_transpose()) { + Transform<strategy::A_interleave(), strategy::A_block(), false>(A_buffer + (multi * _buffer_per_multi), B + (multi * B_multi_stride), ldb, 0, _Nsize, 0, _Ksize); } else { - Transform<strategy::A_interleave, strategy::A_block, true>(A_buffer + (multi * _buffer_per_multi), B + (multi * B_multi_stride), ldb, 0, _Nsize, 0, _Ksize); + Transform<strategy::A_interleave(), strategy::A_block(), true>(A_buffer + (multi * _buffer_per_multi), B + (multi * B_multi_stride), ldb, 0, _Nsize, 0, _Ksize); } } |