aboutsummaryrefslogtreecommitdiff
path: root/src/core/NEON/kernels/arm_gemm/gemm_native.hpp
diff options
context:
space:
mode:
authorMichalis Spyrou <michalis.spyrou@arm.com>2018-04-13 13:44:10 +0100
committerAnthony Barbier <anthony.barbier@arm.com>2018-11-02 16:49:37 +0000
commite7e96e09ff0d3e47797adf197aff2bc39671788c (patch)
treeb52ecdd7627bdf51b8b8da9b9553cb900460222f /src/core/NEON/kernels/arm_gemm/gemm_native.hpp
parent1ed1fc6d3b7d8494ce3bbc5f8b46bfde6fc586f9 (diff)
downloadComputeLibrary-e7e96e09ff0d3e47797adf197aff2bc39671788c.tar.gz
COMPMID-1054 Update RSH's GEMM to add batch+multi support
Change-Id: Ib9d91b77f1d51976da4449fa1e6eeeffae307353 Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/127876 Tested-by: Jenkins <bsgcomp@arm.com> Reviewed-by: Pablo Tello <pablo.tello@arm.com> Reviewed-by: Anthony Barbier <anthony.barbier@arm.com>
Diffstat (limited to 'src/core/NEON/kernels/arm_gemm/gemm_native.hpp')
-rw-r--r-- src/core/NEON/kernels/arm_gemm/gemm_native.hpp | 54
1 file changed, 43 insertions(+), 11 deletions(-)
diff --git a/src/core/NEON/kernels/arm_gemm/gemm_native.hpp b/src/core/NEON/kernels/arm_gemm/gemm_native.hpp
index b0192793b9..beecb76f20 100644
--- a/src/core/NEON/kernels/arm_gemm/gemm_native.hpp
+++ b/src/core/NEON/kernels/arm_gemm/gemm_native.hpp
@@ -28,9 +28,12 @@
#include "arm_gemm.hpp"
#include "mergeresults.hpp"
-#include "profiler.hpp"
#include "transform.hpp"
+#ifdef CYCLE_PROFILING
+#include "profiler.hpp"
+#endif
+
namespace arm_gemm
{
// Implementation of the GemmCommon abstract class.
@@ -50,6 +53,9 @@ class GemmNative : public GemmCommon<To, Tr>
const unsigned int _Nsize;
const unsigned int _Ksize;
+ const unsigned int _nbatches;
+ const unsigned int _nmultis;
+
Tr _beta;
const CPUInfo *const _ci;
@@ -61,8 +67,8 @@ public:
GemmNative(GemmNative &) = delete;
GemmNative &operator=(GemmNative &) = delete;
- GemmNative(const CPUInfo *ci, const unsigned int M, const unsigned int N, const unsigned int K, const Tr beta)
- : _Msize(M), _Nsize(N), _Ksize(K), _beta(beta), _ci(ci)
+ GemmNative(const CPUInfo *ci, const unsigned int M, const unsigned int N, const unsigned int K, const unsigned int nbatches, const unsigned int nmultis, const Tr beta)
+ : _Msize(M), _Nsize(N), _Ksize(K), _nbatches(nbatches), _nmultis(nmultis), _beta(beta), _ci(ci)
{
/* For now don't do any blocking. TODO: figure out if we should. */
k_block = K;
@@ -72,29 +78,55 @@ public:
// Window is number of out_height blocks
unsigned int get_window_size() const override
{
- return iceildiv(_Msize, strategy::out_height);
+ return iceildiv(_Msize, strategy::out_height) * _nbatches * _nmultis;
}
// Actually execute the GEMM.
void execute(unsigned int start, unsigned int end, int) override
{
+#ifdef CYCLE_PROFILING
profiler prof;
+#endif
strategy strat(_ci);
- unsigned int M_start = start * strategy::out_height;
- unsigned int M_end = std::min(end * strategy::out_height, _Msize);
+ const unsigned int window_per_batch = iceildiv(_Msize, strategy::out_height);
+ const unsigned int window_per_multi = window_per_batch * _nbatches;
+
+ const unsigned int first_multi = start / window_per_multi;
+ const unsigned int last_multi = end / window_per_multi;
+
+ const unsigned int first_batch = (start - (first_multi * window_per_multi)) / window_per_batch;
+ const unsigned int last_batch = (end - (last_multi * window_per_multi)) / window_per_batch;
+
+ const unsigned int first_row = ((start - (first_multi * window_per_multi)) % window_per_batch) * strategy::out_height;
+ const unsigned int last_row = ((end - (last_multi * window_per_multi)) % window_per_batch) * strategy::out_height;
static_assert(std::is_same<To, Toi>::value, "gemm_native: Operand types must be the same.");
static_assert(std::is_same<Tr, Tri>::value, "gemm_native: Result types must be the same.");
- for(unsigned int y0 = M_start; y0 < M_end; y0 += strategy::out_height)
+ for(unsigned int multi = first_multi; multi <= last_multi; multi++)
{
- unsigned int ymax = std::min(y0 + strategy::out_height, M_end);
+ const unsigned int batch_0 = (multi == first_multi) ? first_batch : 0;
+ const unsigned int batch_max = (multi == last_multi) ? last_batch : _nbatches;
- prof(PROFILE_KERNEL, (ymax - y0) * _Nsize * _Ksize, [&](void)
+ for(unsigned int batch = batch_0; batch < batch_max; batch++)
{
- strat.kernel(this->_Aptr + (y0 * this->_lda), this->_lda, this->_Bptr, this->_ldb, this->_Cptr + (y0 * this->_ldc), this->_ldc, _beta, (ymax - y0), _Nsize, _Ksize);
- });
+ const unsigned int m_start = ((multi == first_multi) && (batch == first_batch)) ? first_row : 0;
+ const unsigned int m_end = ((multi == last_multi) && (batch == last_batch)) ? last_row : _Msize;
+
+ for(unsigned int y0 = m_start; y0 < m_end; y0 += strategy::out_height)
+ {
+ const unsigned int ymax = std::min(y0 + strategy::out_height, m_end);
+#ifdef CYCLE_PROFILING
+ auto p = prof.ScopedProfiler(PROFILE_KERNEL, (ymax - y0) * _Nsize * _Ksize);
+#endif
+
+ strat.kernel(this->_Aptr + (multi * this->_A_multi_stride) + (batch * this->_A_batch_stride) + (y0 * this->_lda), this->_lda,
+ this->_Bptr + (multi * this->_B_multi_stride), this->_ldb,
+ this->_Cptr + (multi * this->_C_multi_stride) + (batch * this->_C_batch_stride) + (y0 * this->_ldc), this->_ldc,
+ _beta, (ymax - y0), _Nsize, _Ksize);
+ }
+ }
}
}
};