aboutsummaryrefslogtreecommitdiff
path: root/src/core/NEON/kernels/arm_gemm/kernels/a64_hybrid_fp32_mla_16x4/a55.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/core/NEON/kernels/arm_gemm/kernels/a64_hybrid_fp32_mla_16x4/a55.cpp')
-rw-r--r--src/core/NEON/kernels/arm_gemm/kernels/a64_hybrid_fp32_mla_16x4/a55.cpp17
1 files changed, 14 insertions, 3 deletions
diff --git a/src/core/NEON/kernels/arm_gemm/kernels/a64_hybrid_fp32_mla_16x4/a55.cpp b/src/core/NEON/kernels/arm_gemm/kernels/a64_hybrid_fp32_mla_16x4/a55.cpp
index 5bce632bc4..1b828ee503 100644
--- a/src/core/NEON/kernels/arm_gemm/kernels/a64_hybrid_fp32_mla_16x4/a55.cpp
+++ b/src/core/NEON/kernels/arm_gemm/kernels/a64_hybrid_fp32_mla_16x4/a55.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2018-2019 Arm Limited.
+ * Copyright (c) 2018-2020 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -61,12 +61,23 @@ void a64_hybrid_fp32_mla_16x4_a55(const float *A, int lda, const float *B, float
break;
}
- for (int y=0; y<M; y+=4) {
+ int rows_to_compute;
+
+ for (int y=0; y<M; y+=rows_to_compute) {
const float * const a_ptr0_base = A + (y * lda);
const unsigned long ldab = lda * sizeof(float);
float *c_ptr0 = C + (y * ldc);
+ rows_to_compute = M-y;
+ if (rows_to_compute > 4) {
+ if (rows_to_compute % 4) {
+ rows_to_compute = 4 - 1;
+ } else {
+ rows_to_compute = 4;
+ }
+ }
+
for (int x0=0; x0<N; x0+=16ul) {
const long width = std::min((unsigned long)N-x0, 16ul);
long loops = loops_count;
@@ -90,7 +101,7 @@ void a64_hybrid_fp32_mla_16x4_a55(const float *A, int lda, const float *B, float
}
const float *biasptr = bias ? bias+x0 : nullbias;
- switch(M-y) {
+ switch(rows_to_compute) {
case 1:
__asm __volatile (
"temploadreg0 .req X0\n"