Diffstat (limited to 'src/core/NEON/kernels/arm_gemm/gemm_interleaved.hpp')
-rw-r--r--  src/core/NEON/kernels/arm_gemm/gemm_interleaved.hpp  281
1 file changed, 92 insertions(+), 189 deletions(-)
diff --git a/src/core/NEON/kernels/arm_gemm/gemm_interleaved.hpp b/src/core/NEON/kernels/arm_gemm/gemm_interleaved.hpp
index f572f7940b..3b829491ca 100644
--- a/src/core/NEON/kernels/arm_gemm/gemm_interleaved.hpp
+++ b/src/core/NEON/kernels/arm_gemm/gemm_interleaved.hpp
@@ -31,7 +31,6 @@
#include "arm_gemm.hpp"
#include "utils.hpp"
-#include "buffer_manager.hpp"
#include "mergeresults.hpp"
#include "transform.hpp"
@@ -65,14 +64,10 @@ class GemmInterleaved : public GemmCommon<To, Tr> {
const unsigned int _nbatches;
const unsigned int _nmulti;
- const bool _trA;
- const bool _trB;
-
const Activation _act;
const int _maxthreads;
int _nthreads;
- const bool _pretransposed;
/* Blocking info */
unsigned int _k_block=0;
@@ -81,7 +76,6 @@ class GemmInterleaved : public GemmCommon<To, Tr> {
/* Working space, pretransposed buffer, buffer manager */
const Toi *_B_transposed=nullptr;
- BufferManager *_bm=nullptr;
void *_working_space=nullptr;
/* We will need to walk through the blocks of B in a few contexts, so
@@ -150,27 +144,100 @@ class GemmInterleaved : public GemmCommon<To, Tr> {
return ROUND_UP(sizeof(Toi) * _k_block * _Mround * _nbatches);
}
- // B working size: 0, 1 or 3 of these needed depending on pretransposed and threading settings.
- size_t get_b_working_size() const {
- return ROUND_UP(sizeof(Toi) * _x_block * _k_block);
- }
-
// C working size: One needed per thread.
size_t get_c_working_size() const {
return ROUND_UP(sizeof(Tri) * _x_block * strategy::out_height());
}
- // Internal execute function.
- // This supports both the "pretransposed" and "standard" interfaces via the template parameter.
- template<bool pretransposed>
- void execute_internal(unsigned int start, unsigned int end, int threadid) {
+
+public:
+ GemmInterleaved(GemmInterleaved &) = delete;
+ GemmInterleaved & operator= (GemmInterleaved &) = delete;
+
+ /* Constructor */
+ GemmInterleaved(const GemmArgs &args)
+ : _ci(args._ci), _Msize(args._Msize), _Nsize(args._Nsize), _Ksize(args._Ksize),
+ _nbatches(args._nbatches), _nmulti(args._nmulti),
+ _act(args._act), _maxthreads(args._maxthreads), _nthreads(args._maxthreads) {
+ const unsigned int L1_size = _ci->get_L1_cache_size();
+ const unsigned int L2_size = _ci->get_L2_cache_size();
+
+ assert(_maxthreads > 0);
+
+ // Work out blocking parameters, or override from provided GemmConfig
+ if (args._cfg && args._cfg->inner_block_size) {
+ _k_block = args._cfg->inner_block_size;
+ } else {
+ // k_block: Find out how much of the larger array can be loaded into half the cache.
+ // This should account for associative caches.
+ _k_block = (L1_size / 2) / (sizeof(Toi) * (std::max(strategy::out_width(), strategy::out_height())));
+
+ // Needs to be (at least a single) multiple of the K unroll level.
+ _k_block /= strategy::k_unroll();
+ _k_block = std::max(_k_block, 1U) * strategy::k_unroll();
+
+ // Now tune to the presented problem size: work out how many blocks we need.
+ unsigned int num_k_blocks = iceildiv(_Ksize, _k_block);
+
+ // So divide the space equally into that many blocks.
+ _k_block = iceildiv(_Ksize, num_k_blocks);
+
+ // And round UP to the K unroll level required.
+ _k_block = iceildiv(_k_block, strategy::k_unroll());
+ _k_block *= strategy::k_unroll();
+ }
+
+ if (args._cfg && args._cfg->outer_block_size) {
+ _x_block = args._cfg->outer_block_size;
+ } else {
+ // x_block: Work out how many rows (of length k_block) will fit in the L2
+ // Don't allocate more than 90% of the L2 to allow for overheads, and subtract off the L1 contents.
+ _x_block = (((L2_size * 9) / 10) - (_k_block * sizeof(Toi) * (strategy::out_width() + strategy::out_height()))) /
+ (sizeof(Toi) * _k_block);
+
+ // Needs to be (at least a single) multiple of the kernel output width.
+ _x_block /= strategy::out_width();
+ _x_block = std::max(_x_block, 1U) * strategy::out_width();
+
+ // And tune to the presented problem size.
+ unsigned int num_x_blocks = iceildiv(_Nsize, _x_block);
+ _x_block = iceildiv(_Nsize, num_x_blocks);
+
+ _x_block = iceildiv(_x_block, strategy::out_width());
+ _x_block *= strategy::out_width();
+ }
+
+ // Work out the rounded size of M - needed for some buffers.
+ _Mround = iceildiv(_Msize, strategy::out_height());
+ _Mround *= strategy::out_height();
+ }
+
+ // Interface implementation - Compulsory functions
+
+ // Window size: Only the last thread should do a ragged block, so dole
+ // out work in units of out_height. Factor batches into the window, but
+ // not multi for now.
+ ndrange_t get_window_size() const override {
+ // _Mround is a multiple of out_height by definition.
+ return { (_Mround / strategy::out_height()) * _nbatches };
+ }
+
+ // set_nthreads: limit the number of threads that will be used (never more than the configured maximum).
+ void set_nthreads(int nthreads) override {
+ _nthreads = std::min(nthreads, _maxthreads);
+ }
+
+ // Execute
+ void execute(const ndcoord_t &work_range, const ndcoord_t &, int threadid) override {
+ const auto start = work_range.get_position(0);
+ const auto end = work_range.get_position_end(0);
#ifdef CYCLE_PROFILING
profiler prof;
#endif
strategy strat(_ci);
blockwalker current(*this);
- blockwalker next=current;
/* Translate 'start' and 'end' into a position within the batches and rows. */
const unsigned int window_per_batch = _Mround / strategy::out_height();
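The blocking logic hoisted into the constructor above is self-contained enough to check numerically. The sketch below reproduces the same arithmetic with assumed values (32KB L1, 256KB L2, a hypothetical 12x8 float kernel with k_unroll = 1, Ksize = 1000, Nsize = 500); `iceildiv` is a stand-in for the helper from utils.hpp:

```cpp
#include <algorithm>
#include <cstdio>

// Stand-in for the iceildiv() helper from utils.hpp.
static unsigned int iceildiv(unsigned int a, unsigned int b) { return (a + b - 1) / b; }

int main() {
    // Assumed values; the real class reads these from CPUInfo and the strategy.
    const unsigned int L1_size = 32768, L2_size = 262144;
    const unsigned int out_width = 12, out_height = 8, k_unroll = 1;
    const unsigned int Ksize = 1000, Nsize = 500;
    const unsigned int elt = sizeof(float); // sizeof(Toi) stand-in

    // k_block: fill half the L1 with a strip of the larger operand...
    unsigned int k_block = (L1_size / 2) / (elt * std::max(out_width, out_height));
    // ...round to a non-zero multiple of the K unroll...
    k_block = std::max(k_block / k_unroll, 1U) * k_unroll;
    // ...then re-balance so all K blocks end up nearly the same size.
    unsigned int num_k_blocks = iceildiv(Ksize, k_block);
    k_block = iceildiv(iceildiv(Ksize, num_k_blocks), k_unroll) * k_unroll;

    // x_block: rows of length k_block that fit in 90% of L2, after
    // subtracting the panels assumed resident in L1, balanced across N.
    unsigned int x_block = (((L2_size * 9) / 10) - (k_block * elt * (out_width + out_height))) /
                           (elt * k_block);
    x_block = std::max(x_block / out_width, 1U) * out_width;
    unsigned int num_x_blocks = iceildiv(Nsize, x_block);
    x_block = iceildiv(iceildiv(Nsize, num_x_blocks), out_width) * out_width;

    printf("k_block=%u x_block=%u\n", k_block, x_block); // prints k_block=334 x_block=132
}
```

Balancing with iceildiv twice keeps the final K and N blocks from being much smaller than the rest while preserving the unroll and output-width multiples.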
@@ -182,12 +249,7 @@ class GemmInterleaved : public GemmCommon<To, Tr> {
unsigned int m_max = (end - (batch_end * window_per_batch)) * strategy::out_height();
/* Make sure we've been set up correctly. */
- if (pretransposed) {
- assert(_B_transposed);
- } else {
- assert(_bm);
- }
-
+ assert(_B_transposed);
assert(_working_space);
int8_t *working_space_bytes = reinterpret_cast<int8_t *>(_working_space);
@@ -198,12 +260,8 @@ class GemmInterleaved : public GemmCommon<To, Tr> {
Toi * const a_panel = reinterpret_cast<Toi *>(working_space_bytes + (_maxthreads * get_c_working_size()));
Tri * const c_panel = reinterpret_cast<Tri *>(working_space_bytes + (threadid * get_c_working_size()));
- // Shared buffers - these come either from BufferManager or _B_transposed.
const Toi *b_panel;
-
- if (pretransposed) {
- b_panel = _B_transposed;
- }
+ b_panel = _B_transposed;
//printf("Starting GEMM loop, x_block=%d, k_block=%d\n", _x_block, _k_block);
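The pointer arithmetic in this hunk fixes the working-space layout: `_maxthreads` C panels first, then the single shared A panel. A sketch of those offsets, assuming a hypothetical cache-line ROUND_UP and the block sizes from the example above:

```cpp
#include <cstdio>

// Hypothetical stand-in for the ROUND_UP used by get_[ac]_working_size().
static size_t round_up(size_t v) { return (v + 63) & ~size_t(63); }

int main() {
    const int    maxthreads = 4;                                  // assumed
    const size_t c_working  = round_up(sizeof(float) * 132 * 8);  // x_block * out_height
    const size_t a_working  = round_up(sizeof(float) * 334 * 96); // k_block * Mround * nbatches(=1)

    // Layout: [ c_panel(0) | ... | c_panel(maxthreads-1) | a_panel ]
    for (int t = 0; t < maxthreads; t++)
        printf("c_panel(%d) at offset %zu\n", t, (size_t)t * c_working);
    printf("a_panel at offset %zu (a_working=%zu)\n", (size_t)maxthreads * c_working, a_working);
}
```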
@@ -224,7 +282,7 @@ class GemmInterleaved : public GemmCommon<To, Tr> {
strat.transforms.PrepareA(a_panel + ((batch * _Mround + first_m) * _k_block),
this->_Aptr + (batch * this->_A_batch_stride) + (current.multi() * this->_A_multi_stride),
- this->_lda, first_m, last_m, current.k0(), current.kmax(), _trA);
+ this->_lda, first_m, last_m, current.k0(), current.kmax());
}
// Figure out how many "K" the kernel will actually process.
@@ -234,41 +292,6 @@ class GemmInterleaved : public GemmCommon<To, Tr> {
int bblocks = iceildiv(current.xmax() - current.x0(), strategy::out_width());
- if (!pretransposed) {
- /* Look ahead to the next block and populate it if necessary.
- * This avoids the populate operation becoming a bottleneck, and
- * helps keep the threads synchronized (the first thread to get
- * here will populate while the rest will advance).
- *
- * If we are running single threaded, bm->try_populate() will do
- * nothing.
- */
- if (next.advance()) {
- _bm->try_populate(next.index(), [&](void *buffer) {
-#ifdef CYCLE_PROFILING
- auto p=prof.ScopedProfiler(PROFILE_PREPB, (next.xmax()-next.x0()) * (next.kmax()-next.k0()) * sizeof(Toi));
-#endif
-
- Toi *b_panel = reinterpret_cast<Toi *>(buffer);
-
- strat.transforms.PrepareB(b_panel, this->_Bptr + (next.multi() * this->_B_multi_stride), this->_ldb,
- next.x0(), next.xmax(), next.k0(), next.kmax(), _trB);
- });
- }
-
- /* Get the buffer for this iteration from the BufferManager. */
- b_panel = reinterpret_cast<Toi *>(_bm->get(current.index(), [&](void *bpv) {
-#ifdef CYCLE_PROFILING
- auto p=prof.ScopedProfiler(PROFILE_PREPB, (current.xmax()-current.x0()) * (current.kmax()-current.k0()) * sizeof(Toi));
-#endif
-
- Toi *b_panel = reinterpret_cast<Toi *>(bpv);
-
- strat.transforms.PrepareB(b_panel, this->_Bptr + (current.multi() * this->_B_multi_stride), this->_ldb,
- current.x0(), current.xmax(), current.k0(), current.kmax(), _trB);
- }));
- }
-
/* Do the actual work. */
for (unsigned int batch = batch_0; batch <= batch_end; batch++) {
unsigned int first_m = (batch == batch_0) ? m_0 : 0;
@@ -308,105 +331,7 @@ class GemmInterleaved : public GemmCommon<To, Tr> {
}
}
- if (pretransposed) {
- b_panel += (bblocks * strat.out_width() * kern_k);
- } else {
- _bm->release(current.index());
- }
- }
- }
-
-public:
- GemmInterleaved(GemmInterleaved &) = delete;
- GemmInterleaved & operator= (GemmInterleaved &) = delete;
-
- /* Constructor */
- GemmInterleaved(const GemmArgs &args)
- : _ci(args._ci), _Msize(args._Msize), _Nsize(args._Nsize), _Ksize(args._Ksize),
- _nbatches(args._nbatches), _nmulti(args._nmulti), _trA(args._trA), _trB(args._trB),
- _act(args._act), _maxthreads(args._maxthreads), _nthreads(args._maxthreads),
- _pretransposed(args._pretransposed_hint) {
- const unsigned int L1_size = _ci->get_L1_cache_size();
- const unsigned int L2_size = _ci->get_L2_cache_size();
-
- assert(_maxthreads > 0);
-
- // Work out blocking parameters, or override from provided GemmConfig
- if (args._cfg && args._cfg->inner_block_size) {
- _k_block = args._cfg->inner_block_size;
- } else {
- // k_block: Find out how much of the larger array can be loaded into half the cache.
- // This should account for associative caches.
- _k_block = (L1_size / 2) / (sizeof(Toi) * (std::max(strategy::out_width(), strategy::out_height())));
-
- // Needs to be (at least a single) multiple of the K unroll level.
- _k_block /= strategy::k_unroll();
- _k_block = std::max(_k_block, 1U) * strategy::k_unroll();
-
- // Now tune to presented problem size; this is how many blocks we need.
- unsigned int num_k_blocks = iceildiv(_Ksize, _k_block);
-
- // So divide the space equally into that many blocks.
- _k_block = iceildiv(_Ksize, num_k_blocks);
-
- // And round UP to the K unroll level required.
- _k_block = iceildiv(_k_block, strategy::k_unroll());
- _k_block *= strategy::k_unroll();
- }
-
- if (args._cfg && args._cfg->outer_block_size) {
- _x_block = args._cfg->outer_block_size;
- } else {
- // x_block: Work out how many rows (of length k_block) will fit in the L2
- // Don't allocate more than 90% of the L2 to allow for overheads, and subtract off the L1 contents.
- _x_block = (((L2_size * 9) / 10) - (_k_block * sizeof(Toi) * (strategy::out_width() + strategy::out_height()))) /
- (sizeof(Toi) * _k_block);
-
- // Needs to be (at least a single) multiple of the kernel output width.
- _x_block /= strategy::out_width();
- _x_block = std::max(_x_block, 1U) * strategy::out_width();
-
- // And tune to the presented problem size.
- unsigned int num_x_blocks = iceildiv(_Nsize, _x_block);
- _x_block = iceildiv(_Nsize, num_x_blocks);
-
- _x_block = iceildiv(_x_block, strategy::out_width());
- _x_block *= strategy::out_width();
- }
-
- // Work out the rounded size of M - needed for some buffers.
- _Mround = iceildiv(_Msize, strategy::out_height());
- _Mround *= strategy::out_height();
- }
-
- // Interface implementation - Compulsory functions
-
- // Window size: Only the last thread should do a ragged block, so dole
- // out work in units of out_height. Factor batches into the window, but
- // not multi for now (as this would cause problems with the buffer
- // manager).
- ndrange_t get_window_size() const override {
- // _Mround is a multiple of out_height by definition.
- return { (_Mround / strategy::out_height()) * _nbatches };
- }
-
- // set_nthreads: pass on to buffer manager to avoid it waiting for non-existant threads.
- void set_nthreads(int nthreads) override {
- _nthreads = std::min(nthreads, _maxthreads);
- if (_bm) {
- _bm->set_nthreads(_nthreads);
- }
- }
-
- // Execute
- void execute(const ndcoord_t &work_range, const ndcoord_t &, int threadid) override {
- const auto start = work_range.get_position(0);
- const auto end = work_range.get_position_end(0);
-
- if (_pretransposed) {
- execute_internal<true>(start, end, threadid);
- } else {
- execute_internal<false>(start, end, threadid);
+ b_panel += (bblocks * strat.out_width() * kern_k);
}
}
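With the BufferManager branch gone, stepping through B is plain pointer arithmetic over back-to-back interleaved panels. A hedged sketch of the stride applied at the end of each block (names are stand-ins; `iceildiv` as before):

```cpp
#include <cstddef>

static unsigned int iceildiv(unsigned int a, unsigned int b) { return (a + b - 1) / b; }

// Each (x, k) block of pretransposed B holds bblocks panels, each out_width
// wide and kern_k deep (the K extent rounded up to the unroll), stored
// contiguously - so advancing to the next block is a single pointer bump.
const float *advance_b_panel(const float *b_panel,
                             unsigned int x0, unsigned int xmax,
                             unsigned int k0, unsigned int kmax,
                             unsigned int out_width, unsigned int k_unroll) {
    const unsigned int bblocks = iceildiv(xmax - x0, out_width);
    const unsigned int kern_k  = iceildiv(kmax - k0, k_unroll) * k_unroll;
    return b_panel + (size_t)bblocks * out_width * kern_k;
}
```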
@@ -415,12 +340,6 @@ public:
// In all cases, we need one A buffer plus a C buffer per thread.
size_t size = get_a_working_size() + (get_c_working_size() * _maxthreads);
- // For pretransposed case, there is no working space needed for B.
- // Otherwise, we need a BufferManager.
- if (!_pretransposed) {
- size += BufferManager::get_storage_requirement(_maxthreads, get_b_working_size());
- }
-
size += 64; // Add on a cache line extra for alignment.
return size;
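With the BufferManager storage dropped, the budget reduces to the A panel, one C panel per thread, and a cache line of alignment slack. A toy check with the assumed sizes from the layout sketch earlier:

```cpp
#include <cstdio>

int main() {
    // Values carried over from the layout sketch above, not real members.
    const size_t a_working = 128256, c_working = 4224;
    const int    maxthreads = 4;
    const size_t size = a_working + c_working * maxthreads + 64; // extra cache line for alignment
    printf("working space: %zu bytes\n", size);                  // 145216
}
```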
@@ -439,29 +358,17 @@ public:
working_space_bytes += diff;
- if (_pretransposed) {
- // Pretransposed case: just set internal pointer to parameter value.
- _working_space = reinterpret_cast<void *>(working_space_bytes);
- } else {
- // Otherwise, use the first part of the working space for the buffer manager.
- // It's legal to call this again so don't leak a buffer manager if it already existed.
- delete _bm;
-
- _bm = new BufferManager(_nthreads, get_b_working_size(), reinterpret_cast<void *>(working_space_bytes));
-
- working_space_bytes += BufferManager::get_storage_requirement(_maxthreads, get_b_working_size());
-
- _working_space = reinterpret_cast<void *>(working_space_bytes);
- }
+ // Pretransposed case: just set internal pointer to parameter value.
+ _working_space = reinterpret_cast<void *>(working_space_bytes);
}
// Interface implementation - pretransposed
bool B_is_pretransposed() const override {
- return _pretransposed;
+ return true;
}
bool B_pretranspose_required() const override {
- return _pretransposed && (_B_transposed==nullptr);
+ return (_B_transposed==nullptr);
}
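The `diff` consumed at the top of this hunk is computed above the visible context, but it is presumably the usual round-up to a 64-byte boundary that the extra cache line in get_working_size() leaves room for. A sketch under that assumption:

```cpp
#include <cstdint>

// Assumed shape of the alignment step in set_working_space(): shift the
// caller's pointer up to the next 64-byte boundary (the cache line added
// by get_working_size() guarantees the shifted buffer still fits).
static int8_t *align_to_cache_line(void *working_space) {
    const uintptr_t addr = reinterpret_cast<uintptr_t>(working_space);
    const uintptr_t diff = (addr & 63) ? (64 - (addr & 63)) : 0;
    return reinterpret_cast<int8_t *>(working_space) + diff;
}
```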
// TODO: this could almost certainly be considerably simpler.
@@ -506,7 +413,7 @@ public:
k_size *= strategy::k_unroll();
strat.transforms.PrepareB(buffer, B + (current.multi() * B_multi_stride), ldb,
- current.x0(), current.xmax(), current.k0(), current.kmax(), _trB);
+ current.x0(), current.xmax(), current.k0(), current.kmax());
buffer += (x_size * k_size);
} while (current.advance());
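The do/while above advances `buffer` by the rounded-up block footprint each iteration, so the total pretransposed allocation is the sum of those footprints over the same walk. A sketch of that sum, assuming x_size is rounded up to out_width the same way k_size is rounded to the unroll:

```cpp
#include <algorithm>
#include <cstddef>

static unsigned int iceildiv(unsigned int a, unsigned int b) { return (a + b - 1) / b; }
static unsigned int roundup(unsigned int a, unsigned int b)  { return iceildiv(a, b) * b; }

// Element count of the pretransposed B buffer, walking blocks in the same
// (multi, k-block, x-block) order as blockwalker. Parameters stand in for
// the class members of the same names.
size_t pretransposed_b_elements(unsigned int Nsize, unsigned int Ksize, unsigned int nmulti,
                                unsigned int x_block, unsigned int k_block,
                                unsigned int out_width, unsigned int k_unroll) {
    size_t total = 0;
    for (unsigned int m = 0; m < nmulti; m++) {
        for (unsigned int k0 = 0; k0 < Ksize; k0 += k_block) {
            const unsigned int k_size = roundup(std::min(k0 + k_block, Ksize) - k0, k_unroll);
            for (unsigned int x0 = 0; x0 < Nsize; x0 += x_block) {
                const unsigned int x_size = roundup(std::min(x0 + x_block, Nsize) - x0, out_width);
                total += (size_t)x_size * k_size;
            }
        }
    }
    return total; // multiply by sizeof(Toi) for bytes
}
```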
@@ -515,10 +422,6 @@ public:
void set_pretransposed_B_data(void *in_buffer) override {
_B_transposed = reinterpret_cast<Toi *>(in_buffer);
}
-
- ~GemmInterleaved() override {
- delete _bm;
- }
};
} // namespace arm_gemm