diff options
Diffstat (limited to 'src/core/NEON/kernels/arm_gemm/gemm_interleaved.hpp')
-rw-r--r-- | src/core/NEON/kernels/arm_gemm/gemm_interleaved.hpp | 27 |
1 files changed, 13 insertions, 14 deletions
diff --git a/src/core/NEON/kernels/arm_gemm/gemm_interleaved.hpp b/src/core/NEON/kernels/arm_gemm/gemm_interleaved.hpp index c5a43e6519..0e58a4d01f 100644 --- a/src/core/NEON/kernels/arm_gemm/gemm_interleaved.hpp +++ b/src/core/NEON/kernels/arm_gemm/gemm_interleaved.hpp @@ -317,16 +317,15 @@ public: GemmInterleaved & operator= (GemmInterleaved &) = delete; /* Constructor */ - GemmInterleaved(const CPUInfo *ci, const unsigned int M, const unsigned int N, const unsigned int K, - const unsigned int nbatches, const unsigned int nmulti, const bool trA, const bool trB, - const Tr alpha, const Tr beta, const int maxthreads, const bool pretransposed) : - _ci(ci), _Msize(M), _Nsize(N), _Ksize(K), _nbatches(nbatches), _nmulti(nmulti), - _trA(trA), _trB(trB), _alpha(alpha), _beta(beta), - _maxthreads(maxthreads), _nthreads(maxthreads), _pretransposed(pretransposed) { - const unsigned int L1_size = ci->get_L1_cache_size(); - const unsigned int L2_size = ci->get_L2_cache_size(); + GemmInterleaved(const GemmArgs<Tr> &args) + : _ci(args._ci), _Msize(args._Msize), _Nsize(args._Nsize), _Ksize(args._Ksize), + _nbatches(args._nbatches), _nmulti(args._nmulti), _trA(args._trA), _trB(args._trB), + _alpha(args._alpha), _beta(args._beta), _maxthreads(args._maxthreads), _nthreads(args._maxthreads), + _pretransposed(args._pretransposed_hint) { + const unsigned int L1_size = _ci->get_L1_cache_size(); + const unsigned int L2_size = _ci->get_L2_cache_size(); - assert(maxthreads > 0); + assert(_maxthreads > 0); // Work out blocking parameters @@ -339,10 +338,10 @@ public: _k_block = std::max(_k_block, 1U) * strategy::k_unroll(); // Now tune to presented problem size; this is how many blocks we need. - int num_k_blocks = iceildiv(K, _k_block); + int num_k_blocks = iceildiv(_Ksize, _k_block); // So divide the space equally into that many blocks. - _k_block = iceildiv(K, num_k_blocks); + _k_block = iceildiv(_Ksize, num_k_blocks); // And round UP to the K unroll level required. _k_block = iceildiv(_k_block, strategy::k_unroll()); @@ -358,14 +357,14 @@ public: _x_block = std::max(_x_block, 1U) * strategy::out_width(); // And tune to the presented problem size. - int num_x_blocks = iceildiv(N, _x_block); - _x_block = iceildiv(N, num_x_blocks); + int num_x_blocks = iceildiv(_Nsize, _x_block); + _x_block = iceildiv(_Nsize, num_x_blocks); _x_block = iceildiv(_x_block, strategy::out_width()); _x_block *= strategy::out_width(); // Work out the rounded size of M - needed for some buffers. - _Mround = iceildiv(M, strategy::out_height()); + _Mround = iceildiv(_Msize, strategy::out_height()); _Mround *= strategy::out_height(); } |