aboutsummaryrefslogtreecommitdiff
path: root/src/core/NEON/kernels/arm_gemm/gemm_native.hpp
diff options
context:
space:
mode:
authorGeorgios Pinitas <georgios.pinitas@arm.com>2019-01-09 18:35:17 +0000
committerGeorgios Pinitas <georgios.pinitas@arm.com>2019-01-18 13:41:40 +0000
commit7cd26d4a1b14bc4bf7c61496803416ab3d84791f (patch)
tree12cc4a27d7ecebc69a43e96b1f46c7eb05437978 /src/core/NEON/kernels/arm_gemm/gemm_native.hpp
parent3ac2f3a1d9297220d1b0ce920dd13fdd4edcc187 (diff)
downloadComputeLibrary-7cd26d4a1b14bc4bf7c61496803416ab3d84791f.tar.gz
COMPMID-1867: Add NEON/SVE GEMM Hybrid kernels.
Change-Id: Ib40a9921e7f9a6a8be6c38872d6b3a0f24ed0cd3 Reviewed-on: https://review.mlplatform.org/515 Reviewed-by: Anthony Barbier <Anthony.barbier@arm.com> Tested-by: Arm Jenkins <bsgcomp@arm.com>
Diffstat (limited to 'src/core/NEON/kernels/arm_gemm/gemm_native.hpp')
-rw-r--r--src/core/NEON/kernels/arm_gemm/gemm_native.hpp21
1 files changed, 13 insertions, 8 deletions
diff --git a/src/core/NEON/kernels/arm_gemm/gemm_native.hpp b/src/core/NEON/kernels/arm_gemm/gemm_native.hpp
index baa1316745..579533418d 100644
--- a/src/core/NEON/kernels/arm_gemm/gemm_native.hpp
+++ b/src/core/NEON/kernels/arm_gemm/gemm_native.hpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017-2018 ARM Limited.
+ * Copyright (c) 2017-2019 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -74,11 +74,11 @@ public:
GemmNative(GemmNative &) = delete;
GemmNative & operator= (GemmNative &) = delete;
- GemmNative(const CPUInfo *ci, const unsigned int M, const unsigned int N, const unsigned int K, const unsigned int nbatches, const unsigned int nmultis, const Tr beta) :
- _Msize(M), _Nsize(N), _Ksize(K), _nbatches(nbatches), _nmultis(nmultis), _beta(beta), _ci(ci) {
+ GemmNative(const GemmArgs<Tr> &args)
+ : _Msize(args._Msize), _Nsize(args._Nsize), _Ksize(args._Ksize), _nbatches(args._nbatches), _nmultis(args._nmulti), _beta(args._beta), _ci(args._ci) {
/* For now don't do any blocking. TODO: figure out if we should. */
- k_block = K;
- n_block = N;
+ k_block = _Ksize;
+ n_block = _Nsize;
}
// Window is amount per multi multiplied by total number of multis.
@@ -105,8 +105,13 @@ public:
unsigned int y0 = batch_pos * strategy::out_height();
- for (unsigned int pos=start; pos<end; pos++) {
- const unsigned int ymax = std::min(y0 + strategy::out_height(), _Msize);
+ for (unsigned int l=end-start; l>0; ) {
+ // Do work from here to the end of the current batch/multi
+ const unsigned int ymax = std::min(y0 + (l * strategy::out_height()), _Msize);
+
+ // Work out how many units this is and subtract from loop counter.
+ l -= ((ymax - y0) + (strategy::out_height() - 1)) / strategy::out_height();
+
#ifdef CYCLE_PROFILING
auto p = prof.ScopedProfiler(PROFILE_KERNEL, (ymax-y0) * _Nsize * _Ksize);
#endif
@@ -117,7 +122,7 @@ public:
_beta, (ymax-y0), _Nsize, _Ksize);
/* Advance to next item */
- y0 += strategy::out_height();
+ y0 = ymax;
/* Check for batch/multi overflow */
if (y0 >= _Msize) {