aboutsummaryrefslogtreecommitdiff
path: root/src/core/NEON/kernels/arm_gemm/kernels/sve_hybrid_fp16_mla_4VLx4/generic.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/core/NEON/kernels/arm_gemm/kernels/sve_hybrid_fp16_mla_4VLx4/generic.cpp')
-rw-r--r--src/core/NEON/kernels/arm_gemm/kernels/sve_hybrid_fp16_mla_4VLx4/generic.cpp17
1 files changed, 14 insertions, 3 deletions
diff --git a/src/core/NEON/kernels/arm_gemm/kernels/sve_hybrid_fp16_mla_4VLx4/generic.cpp b/src/core/NEON/kernels/arm_gemm/kernels/sve_hybrid_fp16_mla_4VLx4/generic.cpp
index 2998f33d87..3aef916ad2 100644
--- a/src/core/NEON/kernels/arm_gemm/kernels/sve_hybrid_fp16_mla_4VLx4/generic.cpp
+++ b/src/core/NEON/kernels/arm_gemm/kernels/sve_hybrid_fp16_mla_4VLx4/generic.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2018-2019 Arm Limited.
+ * Copyright (c) 2018-2020 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -61,12 +61,23 @@ void sve_hybrid_fp16_mla_4VLx4(const __fp16 *A, int lda, const __fp16 *B, __fp16
break;
}
- for (int y=0; y<M; y+=4) {
+ int rows_to_compute;
+
+ for (int y=0; y<M; y+=rows_to_compute) {
const __fp16 * const a_ptr0_base = A + (y * lda);
const unsigned long ldab = lda * sizeof(__fp16);
__fp16 *c_ptr0 = C + (y * ldc);
+ rows_to_compute = M-y;
+ if (rows_to_compute > 4) {
+ if (rows_to_compute % 4) {
+ rows_to_compute = 4 - 1;
+ } else {
+ rows_to_compute = 4;
+ }
+ }
+
for (int x0=0; x0<N; x0+=(4 * get_vector_length<__fp16>())) {
const long width = std::min((unsigned long)N-x0, (4 * get_vector_length<__fp16>()));
long loops = loops_count;
@@ -78,7 +89,7 @@ void sve_hybrid_fp16_mla_4VLx4(const __fp16 *A, int lda, const __fp16 *B, __fp16
const unsigned long ldcb = ldc * sizeof(__fp16);
const __fp16 *biasptr = bias ? bias+x0 : nullbias;
- switch(M-y) {
+ switch(rows_to_compute) {
case 1:
__asm __volatile (
"whilelt p6.h, %[temp], %[leftovers]\n"