diff options
Diffstat (limited to 'src/core/NEON/kernels/arm_gemm/gemm_interleaved.hpp')
-rw-r--r-- | src/core/NEON/kernels/arm_gemm/gemm_interleaved.hpp | 71 |
1 file changed, 39 insertions, 32 deletions
diff --git a/src/core/NEON/kernels/arm_gemm/gemm_interleaved.hpp b/src/core/NEON/kernels/arm_gemm/gemm_interleaved.hpp index 0e58a4d01f..436438f351 100644 --- a/src/core/NEON/kernels/arm_gemm/gemm_interleaved.hpp +++ b/src/core/NEON/kernels/arm_gemm/gemm_interleaved.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017-2018 ARM Limited. + * Copyright (c) 2017-2019 ARM Limited. * * SPDX-License-Identifier: MIT * @@ -318,50 +318,57 @@ public: /* Constructor */ GemmInterleaved(const GemmArgs<Tr> &args) - : _ci(args._ci), _Msize(args._Msize), _Nsize(args._Nsize), _Ksize(args._Ksize), - _nbatches(args._nbatches), _nmulti(args._nmulti), _trA(args._trA), _trB(args._trB), - _alpha(args._alpha), _beta(args._beta), _maxthreads(args._maxthreads), _nthreads(args._maxthreads), - _pretransposed(args._pretransposed_hint) { + : _ci(args._ci), _Msize(args._Msize), _Nsize(args._Nsize), _Ksize(args._Ksize), + _nbatches(args._nbatches), _nmulti(args._nmulti), _trA(args._trA), _trB(args._trB), + _alpha(args._alpha), _beta(args._beta), _maxthreads(args._maxthreads), _nthreads(args._maxthreads), + _pretransposed(args._pretransposed_hint) { const unsigned int L1_size = _ci->get_L1_cache_size(); const unsigned int L2_size = _ci->get_L2_cache_size(); assert(_maxthreads > 0); - // Work out blocking parameters - - // k_block: Find out how much of the larger array can be loaded into half the cache. - // This should account for associative caches. - _k_block = (L1_size / 2) / (sizeof(Toi) * (std::max(strategy::out_width(), strategy::out_height()))); + // Work out blocking parameters, or override from provided GemmConfig + if (args._cfg && args._cfg->inner_block_size) { + _k_block = args._cfg->inner_block_size; + } else { + // k_block: Find out how much of the larger array can be loaded into half the cache. + // This should account for associative caches. 
+ _k_block = (L1_size / 2) / (sizeof(Toi) * (std::max(strategy::out_width(), strategy::out_height()))); - // Needs to be (at least a single) multiple of the K unroll level. - _k_block /= strategy::k_unroll(); - _k_block = std::max(_k_block, 1U) * strategy::k_unroll(); + // Needs to be (at least a single) multiple of the K unroll level. + _k_block /= strategy::k_unroll(); + _k_block = std::max(_k_block, 1U) * strategy::k_unroll(); - // Now tune to presented problem size; this is how many blocks we need. - int num_k_blocks = iceildiv(_Ksize, _k_block); + // Now tune to presented problem size; this is how many blocks we need. + int num_k_blocks = iceildiv(_Ksize, _k_block); - // So divide the space equally into that many blocks. - _k_block = iceildiv(_Ksize, num_k_blocks); + // So divide the space equally into that many blocks. + _k_block = iceildiv(_Ksize, num_k_blocks); - // And round UP to the K unroll level required. - _k_block = iceildiv(_k_block, strategy::k_unroll()); - _k_block *= strategy::k_unroll(); + // And round UP to the K unroll level required. + _k_block = iceildiv(_k_block, strategy::k_unroll()); + _k_block *= strategy::k_unroll(); + } - // x_block: Work out how many rows (of length k_block) will fit in the L2 - // Don't allocate more than 90% of the L2 to allow for overheads, and subtract off the L1 contents. - _x_block = (((L2_size * 9) / 10) - (_k_block * sizeof(Toi) * (strategy::out_width() + strategy::out_height()))) / - (sizeof(Toi) * _k_block); + if (args._cfg && args._cfg->outer_block_size) { + _x_block = args._cfg->outer_block_size; + } else { + // x_block: Work out how many rows (of length k_block) will fit in the L2 + // Don't allocate more than 90% of the L2 to allow for overheads, and subtract off the L1 contents. + _x_block = (((L2_size * 9) / 10) - (_k_block * sizeof(Toi) * (strategy::out_width() + strategy::out_height()))) / + (sizeof(Toi) * _k_block); - // Needs to be (at least a single) multiple of the kernel output width. 
- _x_block /= strategy::out_width(); - _x_block = std::max(_x_block, 1U) * strategy::out_width(); + // Needs to be (at least a single) multiple of the kernel output width. + _x_block /= strategy::out_width(); + _x_block = std::max(_x_block, 1U) * strategy::out_width(); - // And tune to the presented problem size. - int num_x_blocks = iceildiv(_Nsize, _x_block); - _x_block = iceildiv(_Nsize, num_x_blocks); + // And tune to the presented problem size. + int num_x_blocks = iceildiv(_Nsize, _x_block); + _x_block = iceildiv(_Nsize, num_x_blocks); - _x_block = iceildiv(_x_block, strategy::out_width()); - _x_block *= strategy::out_width(); + _x_block = iceildiv(_x_block, strategy::out_width()); + _x_block *= strategy::out_width(); + } // Work out the rounded size of M - needed for some buffers. _Mround = iceildiv(_Msize, strategy::out_height()); |