Diffstat (limited to 'src/core/NEON/kernels/arm_gemm/gemm_interleaved.hpp')
-rw-r--r--   src/core/NEON/kernels/arm_gemm/gemm_interleaved.hpp   84
1 file changed, 57 insertions, 27 deletions
diff --git a/src/core/NEON/kernels/arm_gemm/gemm_interleaved.hpp b/src/core/NEON/kernels/arm_gemm/gemm_interleaved.hpp
index 3b829491ca..c4dceef922 100644
--- a/src/core/NEON/kernels/arm_gemm/gemm_interleaved.hpp
+++ b/src/core/NEON/kernels/arm_gemm/gemm_interleaved.hpp
@@ -23,15 +23,14 @@
*/
#pragma once
-#include <stdio.h>
-#include <assert.h>
-
#include <algorithm>
+#include <cassert>
#include "arm_gemm.hpp"
#include "utils.hpp"
#include "mergeresults.hpp"
+#include "performance_parameters.hpp"
#include "transform.hpp"
#ifdef CYCLE_PROFILING
@@ -149,6 +148,33 @@ class GemmInterleaved : public GemmCommon<To, Tr> {
return ROUND_UP(sizeof(Tri) * _x_block * strategy::out_height());
}
+ static unsigned int get_k_block_size(const GemmArgs &args) {
+ if (args._cfg && args._cfg->inner_block_size) {
+ return args._cfg->inner_block_size;
+ }
+
+ const unsigned int L1_size = args._ci->get_L1_cache_size();
+ unsigned int k_block;
+
+ // k_block: Find out how much of the larger array can be loaded into half the cache.
+ // This should account for associative caches.
+ k_block = (L1_size / 2) / (sizeof(Toi) * (std::max(strategy::out_width(), strategy::out_height())));
+
+ // Needs to be (at least a single) multiple of the K unroll level.
+ k_block /= strategy::k_unroll();
+ k_block = std::max(k_block, 1U) * strategy::k_unroll();
+
+ // Now tune to presented problem size; this is how many blocks we need.
+ unsigned int num_k_blocks = iceildiv(args._Ksize, k_block);
+
+ // So divide the space equally into that many blocks.
+ k_block = iceildiv(args._Ksize, num_k_blocks);
+
+ // And round UP to the K unroll level required.
+ k_block = roundup(k_block, strategy::k_unroll());
+
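+ // Worked example with illustrative (assumed) values: a 32KB L1 cache,
+ // 4-byte Toi, a 12x8 kernel tile, k_unroll of 1 and Ksize of 1000 give:
+ //   k_block      = (32768/2) / (4 * 12)          = 341
+ //   num_k_blocks = iceildiv(1000, 341)           = 3
+ //   k_block      = roundup(iceildiv(1000, 3), 1) = 334
+ // so K is split into three near-equal blocks, each keeping the working
+ // set within half of L1.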
+ return k_block;
+ }
public:
GemmInterleaved(GemmInterleaved &) = delete;
@@ -158,35 +184,14 @@ public:
GemmInterleaved(const GemmArgs &args)
: _ci(args._ci), _Msize(args._Msize), _Nsize(args._Nsize), _Ksize(args._Ksize),
_nbatches(args._nbatches), _nmulti(args._nmulti),
- _act(args._act), _maxthreads(args._maxthreads), _nthreads(args._maxthreads) {
- const unsigned int L1_size = _ci->get_L1_cache_size();
+ _act(args._act), _maxthreads(args._maxthreads), _nthreads(args._maxthreads),
+ _k_block(get_k_block_size(args)) {
const unsigned int L2_size = _ci->get_L2_cache_size();
assert(_maxthreads > 0);
// Work out blocking parameters, or override from provided GemmConfig
- if (args._cfg && args._cfg->inner_block_size) {
- _k_block = args._cfg->inner_block_size;
- } else {
- // k_block: Find out how much of the larger array can be loaded into half the cache.
- // This should account for associative caches.
- _k_block = (L1_size / 2) / (sizeof(Toi) * (std::max(strategy::out_width(), strategy::out_height())));
-
- // Needs to be (at least a single) multiple of the K unroll level.
- _k_block /= strategy::k_unroll();
- _k_block = std::max(_k_block, 1U) * strategy::k_unroll();
-
- // Now tune to presented problem size; this is how many blocks we need.
- unsigned int num_k_blocks = iceildiv(_Ksize, _k_block);
-
- // So divide the space equally into that many blocks.
- _k_block = iceildiv(_Ksize, num_k_blocks);
-
- // And round UP to the K unroll level required.
- _k_block = iceildiv(_k_block, strategy::k_unroll());
- _k_block *= strategy::k_unroll();
- }
-
+ // TODO: Move outer block into a static function too.
if (args._cfg && args._cfg->outer_block_size) {
_x_block = args._cfg->outer_block_size;
} else {
@@ -422,6 +427,31 @@ public:
void set_pretransposed_B_data(void *in_buffer) override {
_B_transposed = reinterpret_cast<Toi *>(in_buffer);
}
+
+ // Estimate the cycle count for the given problem, using the provided performance parameters.
+ static uint64_t estimate_cycles(const GemmArgs &args, const PerformanceParameters &params) {
+ unsigned int k_blocks = iceildiv(args._Ksize, get_k_block_size(args));
+
+ uint64_t total_macs = static_cast<uint64_t>(args._nbatches) * args._nmulti * roundup(args._Msize, strategy::out_height()) * roundup(args._Nsize, strategy::out_width()) * roundup(args._Ksize, strategy::k_unroll());
+ uint64_t prepare_bytes = static_cast<uint64_t>(args._nbatches) * args._nmulti * roundup(args._Msize, strategy::out_height()) * roundup(args._Ksize, strategy::k_unroll()) * sizeof(Toi);
+ uint64_t merge_bytes = static_cast<uint64_t>(args._nbatches) * args._nmulti * k_blocks * roundup(args._Msize, strategy::out_height()) * roundup(args._Nsize, strategy::out_width()) * sizeof(Tr);
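+
+ // Reading of the terms above: the kernel retires total_macs at
+ // params.kernel_macs_cycle MACs per cycle, prepare_bytes is the size of
+ // the interleaved A panel (transformed once), and merge_bytes is the
+ // output array, which is merged once per K block.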
+
+ float mac_cycles = static_cast<float>(total_macs) / params.kernel_macs_cycle;
+ float prepare_cycles = static_cast<float>(prepare_bytes) / params.prepare_bytes_cycle;
+ float merge_cycles = static_cast<float>(merge_bytes) / params.merge_bytes_cycle;
+
+ float total_cycles = mac_cycles + prepare_cycles + merge_cycles;
+
+ // We can't thread over multis or width, which makes this a poor
+ // choice in many threaded cases. Penalize that here.
+ float parallelism_available = static_cast<float>(iceildiv(args._Msize, strategy::out_height()) * args._nbatches) * 0.9f;
+
+ if (parallelism_available < args._maxthreads) {
+ total_cycles *= (static_cast<float>(args._maxthreads) / parallelism_available);
+ }
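+
+ // For example (illustrative numbers): Msize=100, out_height()=8 and
+ // nbatches=1 give 13 row blocks, so parallelism_available is 11.7; a
+ // request for 16 threads scales the estimate by 16/11.7, as the spare
+ // threads cannot be put to work.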
+
+ return static_cast<uint64_t>(total_cycles);
+ }
};
} // namespace arm_gemm