Diffstat (limited to 'src/core/NEON/kernels/arm_gemm/gemv_pretransposed.hpp')
-rw-r--r-- src/core/NEON/kernels/arm_gemm/gemv_pretransposed.hpp | 21
1 file changed, 11 insertions(+), 10 deletions(-)
diff --git a/src/core/NEON/kernels/arm_gemm/gemv_pretransposed.hpp b/src/core/NEON/kernels/arm_gemm/gemv_pretransposed.hpp
index 842339ef23..f7beb0a34c 100644
--- a/src/core/NEON/kernels/arm_gemm/gemv_pretransposed.hpp
+++ b/src/core/NEON/kernels/arm_gemm/gemv_pretransposed.hpp
@@ -70,7 +70,7 @@ public:
GemvPretransposed(const GemmArgs<Tr> &args)
: _Nsize(args._Nsize), _Ksize(args._Ksize), _nmultis(args._nmulti), _trB(args._trB), _beta(args._beta), _ci(args._ci),
- _buffer_per_multi(_Ksize * iceildiv(_Nsize, strategy::A_interleave) * strategy::A_interleave) {
+ _buffer_per_multi(_Ksize * iceildiv(_Nsize, strategy::A_interleave()) * strategy::A_interleave()) {
/* For now don't do any blocking. TODO: figure out if we should. */
if (args._cfg && args._cfg->inner_block_size) {
m_block = args._cfg->inner_block_size;
@@ -87,7 +87,7 @@ public:
// Window is number of out_width blocks, times number of multis.
unsigned int get_window_size() const override {
- return iceildiv(_Nsize, strategy::out_width) * _nmultis;
+ return iceildiv(_Nsize, strategy::out_width()) * _nmultis;
}
// Actually execute the GEMV.
@@ -98,13 +98,13 @@ public:
strategy strat(_ci);
/* Break the window values down into multis of interest... */
- const unsigned int window_per_multi = iceildiv(_Nsize, strategy::out_width);
+ const unsigned int window_per_multi = iceildiv(_Nsize, strategy::out_width());
const unsigned int multi_0 = start / window_per_multi;
const unsigned int multi_end = end / window_per_multi;
/* ... and figure out where we start and end in the first and last multi. */
- const unsigned int n_0 = (start - (multi_0 * window_per_multi)) * strategy::out_width;
- const unsigned int n_max = (end - (multi_end * window_per_multi)) * strategy::out_width;
+ const unsigned int n_0 = (start - (multi_0 * window_per_multi)) * strategy::out_width();
+ const unsigned int n_max = (end - (multi_end * window_per_multi)) * strategy::out_width();
static_assert(std::is_same<Tr, Tri>::value, "GemvPretransposed: Result types must be the same.");
@@ -124,8 +124,8 @@ public:
auto p = prof.ScopedProfiler(PROFILE_KERNEL, (mmax-m0) * (nmax-n));
#endif
/* This assumes that the underlying call was a GEMM with M=1; for the N=1 case we would have to pick up this->_Bptr below instead */
- strat.kernel(_A_pretransposed + (multi * _buffer_per_multi) + (n * _Ksize) + (m0 * strategy::A_interleave),
- (_Ksize * strategy::A_interleave),
+ strat.kernel(_A_pretransposed + (multi * _buffer_per_multi) + (n * _Ksize) + (m0 * strategy::A_interleave()),
+ (_Ksize * strategy::A_interleave()),
this->_Aptr + (multi * this->_A_multi_stride) + m0,
this->_Cptr + (multi * this->_C_multi_stride) + n,
_beta, (mmax-m0), (nmax-n));
@@ -148,6 +148,7 @@ public:
return _buffer_per_multi * _nmultis * sizeof(To);
}
+ using GemmCommon<To, Tr>::pretranspose_B_array;
void pretranspose_B_array(void *buffer, const To *B, const int ldb, const int B_multi_stride) override {
Toi *A_buffer = reinterpret_cast<Toi *>(buffer);
@@ -155,10 +156,10 @@ public:
/* Reverse sense here as we are dealing with B rather than A. So if
* strategy::A_transpose is false and _trB is false, we still
* transpose. */
- if (_trB ^ strategy::A_transpose) {
- Transform<strategy::A_interleave, strategy::A_block, false>(A_buffer + (multi * _buffer_per_multi), B + (multi * B_multi_stride), ldb, 0, _Nsize, 0, _Ksize);
+ if (_trB ^ strategy::A_transpose()) {
+ Transform<strategy::A_interleave(), strategy::A_block(), false>(A_buffer + (multi * _buffer_per_multi), B + (multi * B_multi_stride), ldb, 0, _Nsize, 0, _Ksize);
} else {
- Transform<strategy::A_interleave, strategy::A_block, true>(A_buffer + (multi * _buffer_per_multi), B + (multi * B_multi_stride), ldb, 0, _Nsize, 0, _Ksize);
+ Transform<strategy::A_interleave(), strategy::A_block(), true>(A_buffer + (multi * _buffer_per_multi), B + (multi * B_multi_stride), ldb, 0, _Nsize, 0, _Ksize);
}
}
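
The mechanical change above, strategy::A_interleave -> strategy::A_interleave(), only works if the strategy classes expose these parameters as constexpr static member functions: the values are used both in ordinary run-time expressions (iceildiv, stride arithmetic) and as template arguments to Transform<>. A minimal sketch of the shape this diff assumes follows; the struct name and the concrete values are illustrative only, not taken from the library.

    // Sketch only: tuning parameters as constexpr static member functions,
    // so a call like strategy::A_interleave() is still a constant expression.
    struct example_gemv_strategy {
        static constexpr unsigned int out_width()    { return 32; }    // N columns produced per kernel call
        static constexpr unsigned int A_interleave() { return 32; }    // interleave factor of pretransposed A
        static constexpr unsigned int A_block()      { return 1; }     // block depth passed to Transform<>
        static constexpr bool         A_transpose()  { return false; } // storage order of pretransposed A
    };

    // Still usable wherever a compile-time constant is required, e.g.:
    static_assert(example_gemv_strategy::A_interleave() == 32,
                  "constexpr member functions remain valid template/constant arguments");

Making the parameters functions rather than static data members keeps every use site a one-character-pair change (append "()"), which is why the diff touches only expression contexts and the Transform<> template arguments.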
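
The added using-declaration (using GemmCommon<To, Tr>::pretranspose_B_array;) addresses a separate C++ rule: declaring an override in a derived class hides every base-class overload of the same name. A hedged sketch with hypothetical classes; GemmCommon's real overload set is not shown in this diff:

    struct Base {
        // Two overloads sharing one name.
        virtual void pretranspose_B_array(void *buffer, const float *B, int ldb, int B_multi_stride) {}
        void pretranspose_B_array(void *buffer, const float *B) { pretranspose_B_array(buffer, B, 0, 0); }
    };

    struct Derived : Base {
        using Base::pretranspose_B_array; // without this, the two-argument form is hidden in Derived
        void pretranspose_B_array(void *buffer, const float *B, int ldb, int B_multi_stride) override {}
    };

Re-exposing the base overload set this way also silences -Woverloaded-virtual style warnings without changing behaviour.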