diff options
Diffstat (limited to 'src/core/NEON/kernels/arm_gemm/gemm_native.hpp')
-rw-r--r-- | src/core/NEON/kernels/arm_gemm/gemm_native.hpp | 21 |
1 file changed, 13 insertions, 8 deletions
diff --git a/src/core/NEON/kernels/arm_gemm/gemm_native.hpp b/src/core/NEON/kernels/arm_gemm/gemm_native.hpp index baa1316745..579533418d 100644 --- a/src/core/NEON/kernels/arm_gemm/gemm_native.hpp +++ b/src/core/NEON/kernels/arm_gemm/gemm_native.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017-2018 ARM Limited. + * Copyright (c) 2017-2019 ARM Limited. * * SPDX-License-Identifier: MIT * @@ -74,11 +74,11 @@ public: GemmNative(GemmNative &) = delete; GemmNative & operator= (GemmNative &) = delete; - GemmNative(const CPUInfo *ci, const unsigned int M, const unsigned int N, const unsigned int K, const unsigned int nbatches, const unsigned int nmultis, const Tr beta) : - _Msize(M), _Nsize(N), _Ksize(K), _nbatches(nbatches), _nmultis(nmultis), _beta(beta), _ci(ci) { + GemmNative(const GemmArgs<Tr> &args) + : _Msize(args._Msize), _Nsize(args._Nsize), _Ksize(args._Ksize), _nbatches(args._nbatches), _nmultis(args._nmulti), _beta(args._beta), _ci(args._ci) { /* For now don't do any blocking. TODO: figure out if we should. */ - k_block = K; - n_block = N; + k_block = _Ksize; + n_block = _Nsize; } // Window is amount per multi multiplied by total number of multis. @@ -105,8 +105,13 @@ public: unsigned int y0 = batch_pos * strategy::out_height(); - for (unsigned int pos=start; pos<end; pos++) { - const unsigned int ymax = std::min(y0 + strategy::out_height(), _Msize); + for (unsigned int l=end-start; l>0; ) { + // Do work from here to the end of the current batch/multi + const unsigned int ymax = std::min(y0 + (l * strategy::out_height()), _Msize); + + // Work out how many units this is and subtract from loop counter. 
+ l -= ((ymax - y0) + (strategy::out_height() - 1)) / strategy::out_height(); + #ifdef CYCLE_PROFILING auto p = prof.ScopedProfiler(PROFILE_KERNEL, (ymax-y0) * _Nsize * _Ksize); #endif @@ -117,7 +122,7 @@ public: _beta, (ymax-y0), _Nsize, _Ksize); /* Advance to next item */ - y0 += strategy::out_height(); + y0 = ymax; /* Check for batch/multi overflow */ if (y0 >= _Msize) { |