aboutsummaryrefslogtreecommitdiff
path: root/src/core/NEON/kernels/arm_gemm/gemm_interleaved.hpp
diff options
context:
space:
mode:
author David Mansell <David.Mansell@arm.com> 2018-07-06 17:53:35 +0100
committer Anthony Barbier <anthony.barbier@arm.com> 2018-11-02 16:54:10 +0000
commit e39334c15c7fd141bb8173d5017ea5ca157fca2c (patch)
tree fffa2f7b136525037c4d99586bc194374e5bd3dc /src/core/NEON/kernels/arm_gemm/gemm_interleaved.hpp
parent e8bd2c729546e59aa0adc241976ea91fc6f25b52 (diff)
download ComputeLibrary-e39334c15c7fd141bb8173d5017ea5ca157fca2c.tar.gz
COMPMID-1271: New system for GEMM heuristics
This patch implements a system for separating the "validity" aspect from the "preferred" aspect of the current heuristics in gemm_*.cpp. Now, each gemm_*.cpp defines a list of candidate implementations, each of which supplies an is_valid() function (to check for validity), an is_preferred() function (the "heuristic" part), and an instantiate() function which actually produces the GemmCommon object pointer. The actual gemm() function is now templated and uses this list to select an implementation. This patch also implements a mechanism to identify the preferred implementation and to override it via the GemmConfig structure. Change-Id: Id49ab7af8bf2e3e9fd951a9698883ade234d40e1 Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/139120 Reviewed-by: Anthony Barbier <anthony.barbier@arm.com> Tested-by: Jenkins <bsgcomp@arm.com>
Diffstat (limited to 'src/core/NEON/kernels/arm_gemm/gemm_interleaved.hpp')
-rw-r--r--src/core/NEON/kernels/arm_gemm/gemm_interleaved.hpp27
1 files changed, 13 insertions, 14 deletions
diff --git a/src/core/NEON/kernels/arm_gemm/gemm_interleaved.hpp b/src/core/NEON/kernels/arm_gemm/gemm_interleaved.hpp
index c5a43e6519..0e58a4d01f 100644
--- a/src/core/NEON/kernels/arm_gemm/gemm_interleaved.hpp
+++ b/src/core/NEON/kernels/arm_gemm/gemm_interleaved.hpp
@@ -317,16 +317,15 @@ public:
GemmInterleaved & operator= (GemmInterleaved &) = delete;
/* Constructor */
- GemmInterleaved(const CPUInfo *ci, const unsigned int M, const unsigned int N, const unsigned int K,
- const unsigned int nbatches, const unsigned int nmulti, const bool trA, const bool trB,
- const Tr alpha, const Tr beta, const int maxthreads, const bool pretransposed) :
- _ci(ci), _Msize(M), _Nsize(N), _Ksize(K), _nbatches(nbatches), _nmulti(nmulti),
- _trA(trA), _trB(trB), _alpha(alpha), _beta(beta),
- _maxthreads(maxthreads), _nthreads(maxthreads), _pretransposed(pretransposed) {
- const unsigned int L1_size = ci->get_L1_cache_size();
- const unsigned int L2_size = ci->get_L2_cache_size();
+ GemmInterleaved(const GemmArgs<Tr> &args)
+ : _ci(args._ci), _Msize(args._Msize), _Nsize(args._Nsize), _Ksize(args._Ksize),
+ _nbatches(args._nbatches), _nmulti(args._nmulti), _trA(args._trA), _trB(args._trB),
+ _alpha(args._alpha), _beta(args._beta), _maxthreads(args._maxthreads), _nthreads(args._maxthreads),
+ _pretransposed(args._pretransposed_hint) {
+ const unsigned int L1_size = _ci->get_L1_cache_size();
+ const unsigned int L2_size = _ci->get_L2_cache_size();
- assert(maxthreads > 0);
+ assert(_maxthreads > 0);
// Work out blocking parameters
@@ -339,10 +338,10 @@ public:
_k_block = std::max(_k_block, 1U) * strategy::k_unroll();
// Now tune to presented problem size; this is how many blocks we need.
- int num_k_blocks = iceildiv(K, _k_block);
+ int num_k_blocks = iceildiv(_Ksize, _k_block);
// So divide the space equally into that many blocks.
- _k_block = iceildiv(K, num_k_blocks);
+ _k_block = iceildiv(_Ksize, num_k_blocks);
// And round UP to the K unroll level required.
_k_block = iceildiv(_k_block, strategy::k_unroll());
@@ -358,14 +357,14 @@ public:
_x_block = std::max(_x_block, 1U) * strategy::out_width();
// And tune to the presented problem size.
- int num_x_blocks = iceildiv(N, _x_block);
- _x_block = iceildiv(N, num_x_blocks);
+ int num_x_blocks = iceildiv(_Nsize, _x_block);
+ _x_block = iceildiv(_Nsize, num_x_blocks);
_x_block = iceildiv(_x_block, strategy::out_width());
_x_block *= strategy::out_width();
// Work out the rounded size of M - needed for some buffers.
- _Mround = iceildiv(M, strategy::out_height());
+ _Mround = iceildiv(_Msize, strategy::out_height());
_Mround *= strategy::out_height();
}