author    Michalis Spyrou <michalis.spyrou@arm.com>    2018-04-13 13:44:10 +0100
committer Anthony Barbier <anthony.barbier@arm.com>    2018-11-02 16:49:37 +0000
commit    e7e96e09ff0d3e47797adf197aff2bc39671788c (patch)
tree      b52ecdd7627bdf51b8b8da9b9553cb900460222f /src/core/NEON/kernels/arm_gemm/gemm_interleaved.hpp
parent    1ed1fc6d3b7d8494ce3bbc5f8b46bfde6fc586f9 (diff)
COMPMID-1054 Update RSH's GEMM to add batch+multi support
Change-Id: Ib9d91b77f1d51976da4449fa1e6eeeffae307353
Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/127876
Tested-by: Jenkins <bsgcomp@arm.com>
Reviewed-by: Pablo Tello <pablo.tello@arm.com>
Reviewed-by: Anthony Barbier <anthony.barbier@arm.com>
Diffstat (limited to 'src/core/NEON/kernels/arm_gemm/gemm_interleaved.hpp')
-rw-r--r--   src/core/NEON/kernels/arm_gemm/gemm_interleaved.hpp | 216
1 file changed, 150 insertions(+), 66 deletions(-)
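
The patch teaches the interleaved GEMM about two extra problem dimensions: "batches" (independent A and C matrices that share one B within a problem) and "multi" (fully independent GEMMs, each with its own A, B and C). Operands are then addressed through per-batch and per-multi strides. A minimal sketch of that addressing, assuming the stride naming used in the diff below (_A_batch_stride, _A_multi_stride, _B_multi_stride, _C_batch_stride, _C_multi_stride); the helper itself is illustrative, not part of the patch:

    #include <cstddef>

    // Illustrative helper only: locate the start of an operand for (multi, batch).
    // B has no batch stride - all batches within one multi share the same B matrix.
    template <typename T>
    const T *operand_base(const T *ptr, unsigned int multi, unsigned int batch,
                          size_t multi_stride, size_t batch_stride)
    {
        return ptr + (multi * multi_stride) + (batch * batch_stride);
    }
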
diff --git a/src/core/NEON/kernels/arm_gemm/gemm_interleaved.hpp b/src/core/NEON/kernels/arm_gemm/gemm_interleaved.hpp
index 27e4e8d411..32c65cd3fb 100644
--- a/src/core/NEON/kernels/arm_gemm/gemm_interleaved.hpp
+++ b/src/core/NEON/kernels/arm_gemm/gemm_interleaved.hpp
@@ -33,9 +33,12 @@
#include "buffer_manager.hpp"
#include "mergeresults.hpp"
-#include "profiler.hpp"
#include "transform.hpp"
+#ifdef CYCLE_PROFILING
+#include "profiler.hpp"
+#endif
+
// Some macros used to decide how much working space to allocate.
// Round allocations up to the next cache line.
#define ALLOC_ROUND 64
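
ROUND_UP itself is not shown in this hunk; a minimal equivalent of rounding a working-space size up to the next ALLOC_ROUND (cache line) boundary would be:

    #include <cstddef>

    // Sketch only - the real ROUND_UP macro in this file may differ in detail.
    static inline size_t round_up_to_cache_line(size_t bytes)
    {
        const size_t alloc_round = 64;   // matches ALLOC_ROUND above
        return ((bytes + alloc_round - 1) / alloc_round) * alloc_round;
    }
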
@@ -60,6 +63,9 @@ class GemmInterleaved : public GemmCommon<To, Tr>
const unsigned int _Nsize;
const unsigned int _Ksize;
+ const unsigned int _nbatches;
+ const unsigned int _nmulti;
+
const bool _trA;
const bool _trB;
@@ -84,30 +90,31 @@ class GemmInterleaved : public GemmCommon<To, Tr>
class blockwalker
{
private:
- /* Loop parameters, we only block up N and K so don't worry about M. */
- const unsigned int _Nsize, _Ksize, _x_block, _k_block;
+ /* Size loops, etc. based on our parent's configuration */
+ const GemmInterleaved<strategy, To, Tr> &_parent;
- /* K and X parameters for current iteration. */
- unsigned int _k0 = 0, _x0 = 0;
+ /* K and X and multi parameters for current iteration. */
+ unsigned int _k0 = 0, _x0 = 0, _multi = 0;
unsigned int _index = 0;
bool _done = false;
bool _newkblock = true;
+ bool _newmulti = true;
public:
- blockwalker(const unsigned int K, const unsigned int k_block, const unsigned int N, const unsigned int x_block)
- : _Nsize(N), _Ksize(K), _x_block(x_block), _k_block(k_block)
+ blockwalker(const GemmInterleaved<strategy, To, Tr> &parent)
+ : _parent(parent)
{
}
unsigned int xmax()
{
- return std::min(_x0 + _x_block, _Nsize);
+ return std::min(_x0 + _parent._x_block, _parent._Nsize);
}
unsigned int kmax()
{
- return std::min(_k0 + _k_block, _Ksize);
+ return std::min(_k0 + _parent._k_block, _parent._Ksize);
}
/* Advance to the next block, return false at the end. */
@@ -119,15 +126,21 @@ class GemmInterleaved : public GemmCommon<To, Tr>
}
_newkblock = false;
- _x0 += _x_block;
- if(_x0 >= _Nsize)
+ _x0 += _parent._x_block;
+ if(_x0 >= _parent._Nsize)
{
_x0 = 0;
- _k0 += _k_block;
- if(_k0 >= _Ksize)
+ _k0 += _parent._k_block;
+ if(_k0 >= _parent._Ksize)
{
- _done = true;
- return false;
+ _k0 = 0;
+ _multi++;
+ if(_multi >= _parent._nmulti)
+ {
+ _done = true;
+ return false;
+ }
+ _newmulti = true;
}
_newkblock = true;
}
@@ -144,6 +157,10 @@ class GemmInterleaved : public GemmCommon<To, Tr>
{
return _x0;
}
+ unsigned int multi(void)
+ {
+ return _multi;
+ }
unsigned int index(void)
{
return _index;
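
Taken together, the reworked blockwalker visits blocks with x (N blocks) innermost, then k, then multi outermost, flagging a new K block (and hence a fresh A pack) whenever k or multi advances. A standalone sketch of that traversal order, with the loop bounds written out explicitly (not part of the patch):

    // Sketch of the blockwalker traversal order after this change.
    void walk_blocks(unsigned int Nsize, unsigned int x_block,
                     unsigned int Ksize, unsigned int k_block, unsigned int nmulti)
    {
        for(unsigned int multi = 0; multi < nmulti; multi++)          // outermost: multi
        {
            for(unsigned int k0 = 0; k0 < Ksize; k0 += k_block)       // new K block => newkblock()
            {
                for(unsigned int x0 = 0; x0 < Nsize; x0 += x_block)   // innermost: one B panel per step
                {
                    // visit block: K range [k0, min(k0 + k_block, Ksize)),
                    //              N range [x0, min(x0 + x_block, Nsize))
                }
            }
        }
    }
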
@@ -161,7 +178,7 @@ class GemmInterleaved : public GemmCommon<To, Tr>
// A working size: One of these needed, regardless of thread count. Divided according to window.
size_t get_a_working_size() const
{
- return ROUND_UP(sizeof(Toi) * _k_block * _Mround);
+ return ROUND_UP(sizeof(Toi) * _k_block * _Mround * _nbatches);
}
// B working size: 0, 1 or 3 of these needed depending on pretransposed and threading settings.
@@ -181,15 +198,23 @@ class GemmInterleaved : public GemmCommon<To, Tr>
template <bool pretransposed>
void execute_internal(unsigned int start, unsigned int end, int threadid)
{
+#ifdef CYCLE_PROFILING
profiler prof;
+#endif
+
strategy strat(_ci);
- blockwalker current(_Ksize, _k_block, _Nsize, _x_block);
+ blockwalker current(*this);
blockwalker next = current;
+ /* Translate 'start' and 'end' into a position within the batches and rows. */
+ const unsigned int window_per_batch = _Mround / strategy::out_height;
+ unsigned int batch_0 = start / window_per_batch;
+ unsigned int batch_end = end / window_per_batch;
+
/* Compute the M values to operate on */
- unsigned int m_0 = start * strategy::out_height;
- unsigned int m_max = std::min(end * strategy::out_height, _Msize);
+ unsigned int m_0 = (start - (batch_0 * window_per_batch)) * strategy::out_height;
+ unsigned int m_max = (end - (batch_end * window_per_batch)) * strategy::out_height;
/* Make sure we've been set up correctly. */
if(pretransposed)
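
A worked example of the new window decomposition (numbers invented for illustration): with _Mround = 48, strategy::out_height = 8 and _nbatches = 3, each batch contributes window_per_batch = 6 units. A thread handed start = 7, end = 10 therefore works entirely inside batch 1:

    unsigned int window_per_batch = 48 / 8;                          // 6
    unsigned int start = 7, end = 10;
    unsigned int batch_0   = start / window_per_batch;               // 1
    unsigned int batch_end = end / window_per_batch;                 // 1
    unsigned int m_0   = (start - batch_0 * window_per_batch) * 8;   // 8
    unsigned int m_max = (end - batch_end * window_per_batch) * 8;   // 32
    // => rows [8, 32) of batch 1 (later clamped per batch against _Msize).
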
@@ -205,7 +230,8 @@ class GemmInterleaved : public GemmCommon<To, Tr>
int8_t *working_space_bytes = reinterpret_cast<int8_t *>(_working_space);
// Private buffers. Treat working_space as an array of C buffers (one per thread) first, followed by the (window-divided) A buffer.
- Toi *const a_panel = reinterpret_cast<Toi *>(working_space_bytes + (_maxthreads * get_c_working_size()) + (m_0 * _k_block * sizeof(Toi)));
+ // Set a_panel to the base of the A buffers - compute offsets into it based on M/batches later.
+ Toi *const a_panel = reinterpret_cast<Toi *>(working_space_bytes + (_maxthreads * get_c_working_size()));
Tri *const c_panel = reinterpret_cast<Tri *>(working_space_bytes + (threadid * get_c_working_size()));
// Shared buffers - these come either from BufferManager or _B_transposed.
@@ -225,17 +251,31 @@ class GemmInterleaved : public GemmCommon<To, Tr>
{
if(current.newkblock())
{
- prof(PROFILE_PREPA, ((m_max - m_0) * (current.kmax() - current.k0()) * sizeof(Toi)), [&](void)
+#ifdef CYCLE_PROFILING
+ auto p = prof.ScopedProfiler(PROFILE_PREPA, (end - start) * strategy::out_height * (current.kmax() - current.k0()) * sizeof(Toi));
+#endif
+ for(unsigned int batch = batch_0; batch <= batch_end; batch++)
{
+ unsigned int first_m = (batch == batch_0) ? m_0 : 0;
+ unsigned int last_m = (batch == batch_end) ? m_max : _Msize;
+
+ if(first_m >= last_m)
+ continue;
if(_trA ^ strategy::A_transpose)
{
- Transform<strategy::A_interleave, strategy::A_block, true>(a_panel, this->_Aptr, this->_lda, m_0, m_max, current.k0(), current.kmax());
+ Transform<strategy::A_interleave, strategy::A_block, true>(
+ a_panel + ((batch * _Mround + first_m) * _k_block),
+ this->_Aptr + (batch * this->_A_batch_stride) + (current.multi() * this->_A_multi_stride),
+ this->_lda, first_m, last_m, current.k0(), current.kmax());
}
else
{
- Transform<strategy::A_interleave, strategy::A_block, false>(a_panel, this->_Aptr, this->_lda, m_0, m_max, current.k0(), current.kmax());
+ Transform<strategy::A_interleave, strategy::A_block, false>(
+ a_panel + ((batch * _Mround + first_m) * _k_block),
+ this->_Aptr + (batch * this->_A_batch_stride) + (current.multi() * this->_A_multi_stride),
+ this->_lda, first_m, last_m, current.k0(), current.kmax());
}
- });
+ }
// Figure out how many "K" the kernel will actually process.
kern_k = iceildiv(current.kmax() - current.k0(), strategy::k_unroll);
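
The per-batch A panel offsets above follow from the buffer layout implied by get_a_working_size(): the packed A buffer holds _nbatches consecutive regions of _Mround * _k_block elements. The offset expression used in both Transform calls, written as a small helper for clarity (sketch only, not part of the patch):

    // Element offset (not bytes) of the packed A panel for (batch, first_m).
    inline size_t a_panel_offset(unsigned int batch, unsigned int first_m,
                                 unsigned int Mround, unsigned int k_block)
    {
        return (static_cast<size_t>(batch) * Mround + first_m) * static_cast<size_t>(k_block);
    }
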
@@ -258,53 +298,84 @@ class GemmInterleaved : public GemmCommon<To, Tr>
{
_bm->try_populate(next.index(), [&](void *buffer)
{
- prof(PROFILE_PREPB, (next.xmax() - next.x0()) * (next.kmax() - next.k0()) * sizeof(Toi), [&](void)
- {
- Toi *b_panel = reinterpret_cast<Toi *>(buffer);
- if(_trB ^ strategy::B_transpose)
- {
- Transform<strategy::B_interleave, strategy::B_block, true>(b_panel, this->_Bptr, this->_ldb, next.x0(), next.xmax(), next.k0(), next.kmax());
- }
- else
- {
- Transform<strategy::B_interleave, strategy::B_block, false>(b_panel, this->_Bptr, this->_ldb, next.x0(), next.xmax(), next.k0(), next.kmax());
- }
- });
- });
- }
+#ifdef CYCLE_PROFILING
+ auto p = prof.ScopedProfiler(PROFILE_PREPB, (next.xmax() - next.x0()) * (next.kmax() - next.k0()) * sizeof(Toi));
+#endif
- /* Get the buffer for this iteration from the BufferManager. */
- b_panel = reinterpret_cast<Toi *>(_bm->get(current.index(), [&](void *bpv)
- {
- prof(PROFILE_PREPB, (current.xmax() - current.x0()) * (current.kmax() - current.k0()) * sizeof(Toi), [&](void)
- {
- Toi *b_panel = reinterpret_cast<Toi *>(bpv);
+ Toi *b_panel = reinterpret_cast<Toi *>(buffer);
if(_trB ^ strategy::B_transpose)
{
- Transform<strategy::B_interleave, strategy::B_block, true>(b_panel, this->_Bptr, this->_ldb, current.x0(), current.xmax(), current.k0(), current.kmax());
+ Transform<strategy::B_interleave, strategy::B_block, true>(
+ b_panel, this->_Bptr + (next.multi() * this->_B_multi_stride), this->_ldb,
+ next.x0(), next.xmax(), next.k0(), next.kmax());
}
else
{
- Transform<strategy::B_interleave, strategy::B_block, false>(b_panel, this->_Bptr, this->_ldb, current.x0(), current.xmax(), current.k0(), current.kmax());
+ Transform<strategy::B_interleave, strategy::B_block, false>(
+ b_panel, this->_Bptr + (next.multi() * this->_B_multi_stride), this->_ldb,
+ next.x0(), next.xmax(), next.k0(), next.kmax());
}
});
+ }
+ /* Get the buffer for this iteration from the BufferManager. */
+ b_panel = reinterpret_cast<Toi *>(_bm->get(current.index(), [&](void *bpv)
+ {
+#ifdef CYCLE_PROFILING
+ auto p = prof.ScopedProfiler(PROFILE_PREPB, (current.xmax() - current.x0()) * (current.kmax() - current.k0()) * sizeof(Toi));
+#endif
+
+ Toi *b_panel = reinterpret_cast<Toi *>(bpv);
+ if(_trB ^ strategy::B_transpose)
+ {
+ Transform<strategy::B_interleave, strategy::B_block, true>(
+ b_panel, this->_Bptr + (current.multi() * this->_B_multi_stride), this->_ldb,
+ current.x0(), current.xmax(), current.k0(), current.kmax());
+ }
+ else
+ {
+ Transform<strategy::B_interleave, strategy::B_block, false>(
+ b_panel, this->_Bptr + (current.multi() * this->_B_multi_stride), this->_ldb,
+ current.x0(), current.xmax(), current.k0(), current.kmax());
+ }
+
}));
}
/* Do the actual work. */
- for(unsigned int y = m_0; y < m_max; y += strategy::out_height)
+ for(unsigned int batch = batch_0; batch <= batch_end; batch++)
{
- unsigned int ymax = std::min(_Msize, y + strategy::out_height);
+ unsigned int first_m = (batch == batch_0) ? m_0 : 0;
+ unsigned int last_m = (batch == batch_end) ? m_max : _Msize;
- prof(PROFILE_KERNEL, (strategy::out_height * bblocks * strategy::out_width * kern_k), [&](void)
- {
- strat.kernel(a_panel + ((y - m_0) * kern_k), b_panel, c_panel, 1, bblocks, kern_k);
- });
- prof(PROFILE_MERGE, (strategy::out_height * bblocks * strategy::out_width * sizeof(Tr)), [&](void)
+ const Toi *a_ptr = a_panel + (batch * _Mround + first_m) * _k_block;
+
+ if(first_m >= last_m)
+ continue;
+
+ for(unsigned int y = first_m; y < last_m; y += strategy::out_height)
{
- MergeResults<strategy::out_width, strategy::out_height>(this->_Cptr, c_panel, this->_ldc, y, ymax,
- current.x0(), current.xmax(), _alpha, (current.k0() == 0 ? _beta : static_cast<Tr>(1)));
- });
+ unsigned int ymax = std::min(_Msize, y + strategy::out_height);
+
+ {
+#ifdef CYCLE_PROFILING
+ auto p = prof.ScopedProfiler(PROFILE_KERNEL, (strategy::out_height * bblocks * strategy::out_width * kern_k));
+#endif
+
+ strat.kernel(a_ptr, b_panel, c_panel, 1, bblocks, kern_k);
+
+ a_ptr += (strategy::out_height * kern_k);
+ }
+
+ {
+#ifdef CYCLE_PROFILING
+ auto p = prof.ScopedProfiler(PROFILE_MERGE, (strategy::out_height * bblocks * strategy::out_width * sizeof(Tr)));
+#endif
+ MergeResults<strategy::out_width, strategy::out_height>(
+ this->_Cptr + (batch * this->_C_batch_stride) + (current.multi() * this->_C_multi_stride),
+ c_panel, this->_ldc, y, ymax, current.x0(), current.xmax(),
+ _alpha, (current.k0() == 0 ? _beta : static_cast<Tr>(1)));
+ }
+ }
}
if(pretransposed)
@@ -324,9 +395,9 @@ public:
/* Constructor */
GemmInterleaved(const CPUInfo *ci, const unsigned int M, const unsigned int N, const unsigned int K,
- const bool trA, const bool trB, const Tr alpha, const Tr beta, const int maxthreads,
- const bool pretransposed)
- : _ci(ci), _Msize(M), _Nsize(N), _Ksize(K), _trA(trA), _trB(trB), _alpha(alpha), _beta(beta), _maxthreads(maxthreads), _pretransposed(pretransposed)
+ const unsigned int nbatches, const unsigned int nmulti, const bool trA, const bool trB,
+ const Tr alpha, const Tr beta, const int maxthreads, const bool pretransposed)
+ : _ci(ci), _Msize(M), _Nsize(N), _Ksize(K), _nbatches(nbatches), _nmulti(nmulti), _trA(trA), _trB(trB), _alpha(alpha), _beta(beta), _maxthreads(maxthreads), _pretransposed(pretransposed)
{
const unsigned int L1_size = ci->get_L1_cache_size();
const unsigned int L2_size = ci->get_L2_cache_size();
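
With the two new parameters, a construction call would look like the following (the strategy type, CPUInfo object and the concrete values are illustrative only):

    // M x N x K GEMM: 4 batches, 2 multis, no transposes, alpha = 1, beta = 0,
    // 4 threads, pretransposed B.
    GemmInterleaved<my_strategy, float, float> gemm(&ci, M, N, K,
                                                    /* nbatches */ 4, /* nmulti */ 2,
                                                    false, false, 1.0f, 0.0f, 4, true);
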
@@ -375,11 +446,15 @@ public:
// Interface implementation - Compulsory functions
- // Window size: Only the last thread should do a ragged block, so dole out work in units of out_height */
+ // Window size: Only the last thread should do a ragged block, so dole
+ // out work in units of out_height. Factor batches into the window, but
+ // not multi for now (as this would cause problems with the buffer
+ // manager).
+
unsigned int get_window_size() const override
{
// _Mround is a multiple of out_height by definition.
- return _Mround / strategy::out_height;
+ return (_Mround / strategy::out_height) * _nbatches;
}
// set_nthreads: pass on to buffer manager to avoid it waiting for non-existant threads.
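
Continuing the earlier example (_Mround = 48, out_height = 8, _nbatches = 3), get_window_size() grows from 6 to 18 units, one unit per out_height-tall strip of one batch:

    unsigned int window = (48 / 8) * 3;   // 18 units; multi is deliberately not folded in
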
@@ -471,7 +546,7 @@ public:
size_t get_B_pretransposed_array_size() const override
{
size_t total = 0;
- blockwalker current(_Ksize, _k_block, _Nsize, _x_block);
+ blockwalker current(*this);
do
{
@@ -493,9 +568,9 @@ public:
return total;
}
- void pretranspose_B_array(void *in_buffer, const To *B, const int ldb) override
+ void pretranspose_B_array(void *in_buffer, const To *B, const int ldb, const int B_multi_stride) override
{
- blockwalker current(_Ksize, _k_block, _Nsize, _x_block);
+ blockwalker current(*this);
Toi *buffer = reinterpret_cast<Toi *>(in_buffer);
_B_transposed = buffer;
@@ -514,11 +589,15 @@ public:
if(_trB ^ strategy::B_transpose)
{
- Transform<strategy::B_interleave, strategy::B_block, true>(buffer, B, ldb, current.x0(), current.xmax(), current.k0(), current.kmax());
+ Transform<strategy::B_interleave, strategy::B_block, true>(
+ buffer, B + (current.multi() * B_multi_stride), ldb,
+ current.x0(), current.xmax(), current.k0(), current.kmax());
}
else
{
- Transform<strategy::B_interleave, strategy::B_block, false>(buffer, B, ldb, current.x0(), current.xmax(), current.k0(), current.kmax());
+ Transform<strategy::B_interleave, strategy::B_block, false>(
+ buffer, B + (current.multi() * B_multi_stride), ldb,
+ current.x0(), current.xmax(), current.k0(), current.kmax());
}
buffer += (x_size * k_size);
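
A hedged usage sketch of the reworked pretransposition interface (gemm, B, ldb and B_multi_stride are caller-side names invented for the example; buffer alignment handling is omitted):

    #include <cstdint>
    #include <vector>

    std::vector<uint8_t> storage(gemm.get_B_pretransposed_array_size());
    gemm.pretranspose_B_array(storage.data(), B, ldb, B_multi_stride);
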
@@ -526,6 +605,11 @@ public:
while(current.advance());
}
+ void set_pretransposed_B_data(void *in_buffer) override
+ {
+ _B_transposed = reinterpret_cast<Toi *>(in_buffer);
+ }
+
~GemmInterleaved() override
{
delete _bm;
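
The new set_pretransposed_B_data() hook lets a caller hand back a buffer that was filled by pretranspose_B_array() earlier, skipping the re-transposition; continuing the sketch above:

    // Reuse the previously filled pretransposed-B buffer on a later run.
    gemm.set_pretransposed_B_data(storage.data());
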