aboutsummaryrefslogtreecommitdiff
path: root/src/core/NEON/kernels/arm_gemm/gemm_native.hpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/core/NEON/kernels/arm_gemm/gemm_native.hpp')
-rw-r--r-- src/core/NEON/kernels/arm_gemm/gemm_native.hpp 62
1 files changed, 35 insertions, 27 deletions
diff --git a/src/core/NEON/kernels/arm_gemm/gemm_native.hpp b/src/core/NEON/kernels/arm_gemm/gemm_native.hpp
index 6fed645d82..baa1316745 100644
--- a/src/core/NEON/kernels/arm_gemm/gemm_native.hpp
+++ b/src/core/NEON/kernels/arm_gemm/gemm_native.hpp
@@ -62,6 +62,14 @@ class GemmNative : public GemmCommon<To, Tr> {
unsigned int k_block=0;
unsigned int n_block=0;
+ unsigned int window_per_batch() const { // Window units needed by one batch: one per out_height-row block of _Msize.
+ return iceildiv(_Msize, strategy::out_height()); // NOTE(review): iceildiv presumably rounds up so a partial final block still gets a unit - confirm.
+ }
+
+ unsigned int window_per_multi() const { // Window units needed by one multi: the per-batch count for each of the _nbatches batches.
+ return window_per_batch() * _nbatches;
+ }
+
public:
GemmNative(GemmNative &) = delete;
GemmNative & operator= (GemmNative &) = delete;
@@ -73,9 +81,9 @@ public:
n_block = N;
}
- // Window is number of out_height blocks
+ // Window size is the per-multi window (out_height blocks per batch times batches) multiplied by the number of multis.
unsigned int get_window_size() const override {
- return iceildiv(_Msize, strategy::out_height) * _nbatches * _nmultis;
+ return window_per_multi() * _nmultis;
}
// Actually execute the GEMM.
@@ -85,39 +93,39 @@ public:
#endif
strategy strat(_ci);
- const unsigned int window_per_batch = iceildiv(_Msize, strategy::out_height);
- const unsigned int window_per_multi = window_per_batch * _nbatches;
-
- const unsigned int first_multi = start / window_per_multi;
- const unsigned int last_multi = end / window_per_multi;
-
- const unsigned int first_batch = (start - (first_multi * window_per_multi)) / window_per_batch;
- const unsigned int last_batch = (end - (last_multi * window_per_multi)) / window_per_batch;
-
- const unsigned int first_row = ((start - (first_multi * window_per_multi)) % window_per_batch) * strategy::out_height;
- const unsigned int last_row = ((end - (last_multi * window_per_multi)) % window_per_batch) * strategy::out_height;
-
static_assert(std::is_same<To, Toi>::value, "gemm_native: Operand types must be the same.");
static_assert(std::is_same<Tr, Tri>::value, "gemm_native: Result types must be the same.");
- for (unsigned int multi=first_multi; multi<=last_multi; multi++) {
- const unsigned int batch_0 = (multi == first_multi) ? first_batch : 0;
- const unsigned int batch_max = (multi == last_multi) ? last_batch : (_nbatches-1);
+ /* Compute starting point based on 'start' */
+ unsigned int multi = start / window_per_multi();
+ unsigned int multi_pos = start % window_per_multi();
+
+ unsigned int batch = multi_pos / window_per_batch();
+ unsigned int batch_pos = multi_pos % window_per_batch();
- for (unsigned int batch=batch_0; batch <= batch_max; batch++) {
- const unsigned int m_start = ((multi == first_multi) && (batch==first_batch)) ? first_row : 0;
- const unsigned int m_end = ((multi == last_multi) && (batch==last_batch)) ? last_row : _Msize;
+ unsigned int y0 = batch_pos * strategy::out_height();
- for (unsigned int y0=m_start; y0<m_end; y0+=strategy::out_height) {
- const unsigned int ymax = std::min(y0 + strategy::out_height, m_end);
+ for (unsigned int pos=start; pos<end; pos++) {
+ const unsigned int ymax = std::min(y0 + strategy::out_height(), _Msize);
#ifdef CYCLE_PROFILING
- auto p = prof.ScopedProfiler(PROFILE_KERNEL, (ymax-y0) * _Nsize * _Ksize);
+ auto p = prof.ScopedProfiler(PROFILE_KERNEL, (ymax-y0) * _Nsize * _Ksize);
#endif
- strat.kernel(this->_Aptr + (multi * this->_A_multi_stride) + (batch * this->_A_batch_stride) + (y0 * this->_lda), this->_lda,
- this->_Bptr + (multi * this->_B_multi_stride), this->_ldb,
- this->_Cptr + (multi * this->_C_multi_stride) + (batch * this->_C_batch_stride) + (y0 * this->_ldc), this->_ldc,
- _beta, (ymax-y0), _Nsize, _Ksize);
+ strat.kernel(this->_Aptr + (multi * this->_A_multi_stride) + (batch * this->_A_batch_stride) + (y0 * this->_lda), this->_lda,
+ this->_Bptr + (multi * this->_B_multi_stride), this->_ldb,
+ this->_Cptr + (multi * this->_C_multi_stride) + (batch * this->_C_batch_stride) + (y0 * this->_ldc), this->_ldc,
+ _beta, (ymax-y0), _Nsize, _Ksize);
+
+ /* Advance to next item */
+ y0 += strategy::out_height();
+
+ /* Check for batch/multi overflow */
+ if (y0 >= _Msize) {
+ y0=0;
+ batch++;
+ if (batch == _nbatches) {
+ batch=0;
+ multi++;
}
}
}