path: root/src/core/NEON/kernels/arm_gemm/gemm_native.hpp
author    David Mansell <David.Mansell@arm.com>  2018-07-06 17:53:35 +0100
committer Anthony Barbier <anthony.barbier@arm.com>  2018-11-02 16:54:10 +0000
commit    e39334c15c7fd141bb8173d5017ea5ca157fca2c (patch)
tree      fffa2f7b136525037c4d99586bc194374e5bd3dc /src/core/NEON/kernels/arm_gemm/gemm_native.hpp
parent    e8bd2c729546e59aa0adc241976ea91fc6f25b52 (diff)
download  ComputeLibrary-e39334c15c7fd141bb8173d5017ea5ca157fca2c.tar.gz
COMPMID-1271: New system for GEMM heuristics
This patch implements a system for separating the "validity" aspect of the current heuristics in gemm_*.cpp from the "preferred" aspect. Each gemm_*.cpp now defines a list of candidate implementations, each of which supplies an is_valid() function (to check for validity), an is_preferred() function (the "heuristic" part), and an instantiate() function which actually produces the GemmCommon object pointer. The actual gemm() function is now templated and uses this list to select an implementation.

This patch also implements a mechanism to identify the preferred implementation, and to override it via the GemmConfig structure.

Change-Id: Id49ab7af8bf2e3e9fd951a9698883ade234d40e1
Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/139120
Reviewed-by: Anthony Barbier <anthony.barbier@arm.com>
Tested-by: Jenkins <bsgcomp@arm.com>
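To make the list-of-candidates mechanism concrete, here is a minimal self-contained sketch of how such a selection could look. Only the names is_valid(), is_preferred() and instantiate() come from the commit message; GemmArgs, GemmImpl, Candidate and select_gemm() are illustrative stand-ins, not the types this patch actually introduces.

#include <vector>

// Stand-in for the argument bundle the real heuristics inspect.
struct GemmArgs {
    unsigned int M, N, K;
};

// Stand-in for GemmCommon<To, Tr>.
struct GemmImpl { };

// One candidate implementation, as described in the commit message:
// a validity check, a preference heuristic, and a factory function.
struct Candidate {
    bool      (*is_valid)(const GemmArgs &);     // hard constraint
    bool      (*is_preferred)(const GemmArgs &); // the heuristic part
    GemmImpl *(*instantiate)(const GemmArgs &);  // builds the implementation
};

// First candidate that is both valid and preferred wins; otherwise fall
// back to the first candidate that is merely valid.
inline GemmImpl *select_gemm(const std::vector<Candidate> &candidates,
                             const GemmArgs &args) {
    for (const auto &c : candidates) {
        if (c.is_valid(args) && c.is_preferred(args)) {
            return c.instantiate(args);
        }
    }
    for (const auto &c : candidates) {
        if (c.is_valid(args)) {
            return c.instantiate(args);
        }
    }
    return nullptr;
}

In a sketch like this, separating the hard is_valid() constraint from the is_preferred() heuristic is what makes an override (such as the GemmConfig structure mentioned above) safe: any candidate passing is_valid() is correct, so overriding bypasses only the heuristic, never the validity check.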
Diffstat (limited to 'src/core/NEON/kernels/arm_gemm/gemm_native.hpp')
-rw-r--r--  src/core/NEON/kernels/arm_gemm/gemm_native.hpp | 62
1 file changed, 35 insertions(+), 27 deletions(-)
diff --git a/src/core/NEON/kernels/arm_gemm/gemm_native.hpp b/src/core/NEON/kernels/arm_gemm/gemm_native.hpp
index 6fed645d82..baa1316745 100644
--- a/src/core/NEON/kernels/arm_gemm/gemm_native.hpp
+++ b/src/core/NEON/kernels/arm_gemm/gemm_native.hpp
@@ -62,6 +62,14 @@ class GemmNative : public GemmCommon<To, Tr> {
unsigned int k_block=0;
unsigned int n_block=0;
+ unsigned int window_per_batch() const {
+ return iceildiv(_Msize, strategy::out_height());
+ }
+
+ unsigned int window_per_multi() const {
+ return window_per_batch() * _nbatches;
+ }
+
public:
GemmNative(GemmNative &) = delete;
GemmNative & operator= (GemmNative &) = delete;
@@ -73,9 +81,9 @@ public:
n_block = N;
}
- // Window is number of out_height blocks
+ // Window is amount per multi multiplied by total number of multis.
unsigned int get_window_size() const override {
- return iceildiv(_Msize, strategy::out_height) * _nbatches * _nmultis;
+ return window_per_multi() * _nmultis;
}
// Actually execute the GEMM.
@@ -85,39 +93,39 @@ public:
#endif
strategy strat(_ci);
- const unsigned int window_per_batch = iceildiv(_Msize, strategy::out_height);
- const unsigned int window_per_multi = window_per_batch * _nbatches;
-
- const unsigned int first_multi = start / window_per_multi;
- const unsigned int last_multi = end / window_per_multi;
-
- const unsigned int first_batch = (start - (first_multi * window_per_multi)) / window_per_batch;
- const unsigned int last_batch = (end - (last_multi * window_per_multi)) / window_per_batch;
-
- const unsigned int first_row = ((start - (first_multi * window_per_multi)) % window_per_batch) * strategy::out_height;
- const unsigned int last_row = ((end - (last_multi * window_per_multi)) % window_per_batch) * strategy::out_height;
-
static_assert(std::is_same<To, Toi>::value, "gemm_native: Operand types must be the same.");
static_assert(std::is_same<Tr, Tri>::value, "gemm_native: Result types must be the same.");
- for (unsigned int multi=first_multi; multi<=last_multi; multi++) {
- const unsigned int batch_0 = (multi == first_multi) ? first_batch : 0;
- const unsigned int batch_max = (multi == last_multi) ? last_batch : (_nbatches-1);
+ /* Compute starting point based on 'start' */
+ unsigned int multi = start / window_per_multi();
+ unsigned int multi_pos = start % window_per_multi();
+
+ unsigned int batch = multi_pos / window_per_batch();
+ unsigned int batch_pos = multi_pos % window_per_batch();
- for (unsigned int batch=batch_0; batch <= batch_max; batch++) {
- const unsigned int m_start = ((multi == first_multi) && (batch==first_batch)) ? first_row : 0;
- const unsigned int m_end = ((multi == last_multi) && (batch==last_batch)) ? last_row : _Msize;
+ unsigned int y0 = batch_pos * strategy::out_height();
- for (unsigned int y0=m_start; y0<m_end; y0+=strategy::out_height) {
- const unsigned int ymax = std::min(y0 + strategy::out_height, m_end);
+ for (unsigned int pos=start; pos<end; pos++) {
+ const unsigned int ymax = std::min(y0 + strategy::out_height(), _Msize);
#ifdef CYCLE_PROFILING
- auto p = prof.ScopedProfiler(PROFILE_KERNEL, (ymax-y0) * _Nsize * _Ksize);
+ auto p = prof.ScopedProfiler(PROFILE_KERNEL, (ymax-y0) * _Nsize * _Ksize);
#endif
- strat.kernel(this->_Aptr + (multi * this->_A_multi_stride) + (batch * this->_A_batch_stride) + (y0 * this->_lda), this->_lda,
- this->_Bptr + (multi * this->_B_multi_stride), this->_ldb,
- this->_Cptr + (multi * this->_C_multi_stride) + (batch * this->_C_batch_stride) + (y0 * this->_ldc), this->_ldc,
- _beta, (ymax-y0), _Nsize, _Ksize);
+ strat.kernel(this->_Aptr + (multi * this->_A_multi_stride) + (batch * this->_A_batch_stride) + (y0 * this->_lda), this->_lda,
+ this->_Bptr + (multi * this->_B_multi_stride), this->_ldb,
+ this->_Cptr + (multi * this->_C_multi_stride) + (batch * this->_C_batch_stride) + (y0 * this->_ldc), this->_ldc,
+ _beta, (ymax-y0), _Nsize, _Ksize);
+
+ /* Advance to next item */
+ y0 += strategy::out_height();
+
+ /* Check for batch/multi overflow */
+ if (y0 >= _Msize) {
+ y0=0;
+ batch++;
+ if (batch == _nbatches) {
+ batch=0;
+ multi++;
+ }
+ }
}
}
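For reference, the iteration scheme on the '+' side above can be exercised in isolation: the flat scheduling window [start, end) is decomposed once into (multi, batch, row-block) coordinates, which are then advanced with carry, replacing the old triple nested loop with its per-level first/last bounds. A small self-contained sketch follows; the concrete sizes and out_height are made-up stand-ins for the class members _Msize, _nbatches, _nmultis and strategy::out_height(), and the printf stands in for the strat.kernel() call.

#include <algorithm>
#include <cstdio>

// Made-up stand-ins for the GemmNative members used in the patch.
constexpr unsigned int out_height = 8;  // strategy::out_height()
constexpr unsigned int Msize      = 20; // _Msize
constexpr unsigned int nbatches   = 2;  // _nbatches
constexpr unsigned int nmultis    = 2;  // _nmultis

constexpr unsigned int iceildiv(unsigned int a, unsigned int b) {
    return (a + b - 1) / b;
}

int main() {
    const unsigned int window_per_batch = iceildiv(Msize, out_height);
    const unsigned int window_per_multi = window_per_batch * nbatches;
    const unsigned int window_size      = window_per_multi * nmultis;

    // Pretend this thread was handed the second half of the window.
    const unsigned int start = window_size / 2;
    const unsigned int end   = window_size;

    // Decompose the flat start index once, as the new execute() does...
    unsigned int multi     = start / window_per_multi;
    unsigned int multi_pos = start % window_per_multi;
    unsigned int batch     = multi_pos / window_per_batch;
    unsigned int y0        = (multi_pos % window_per_batch) * out_height;

    // ...then advance with carry: one "kernel call" per window position.
    for (unsigned int pos = start; pos < end; pos++) {
        const unsigned int ymax = std::min(y0 + out_height, Msize);
        std::printf("multi=%u batch=%u rows [%u, %u)\n", multi, batch, y0, ymax);

        y0 += out_height;
        if (y0 >= Msize) {            // row overflow: move to next batch
            y0 = 0;
            batch++;
            if (batch == nbatches) {  // batch overflow: move to next multi
                batch = 0;
                multi++;
            }
        }
    }
    return 0;
}

With these numbers window_per_batch is 3 and the full window holds 12 positions, so start=6 lands exactly on multi=1, batch=0, row 0, and the carry logic walks the remaining six row-blocks in order, spilling from batch 0 to batch 1 after the [16, 20) block.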