Diffstat (limited to 'src/core/NEON/kernels/arm_gemm/kernels/sve_native_fp32_mla_4VLx4/generic.cpp')
-rw-r--r--  src/core/NEON/kernels/arm_gemm/kernels/sve_native_fp32_mla_4VLx4/generic.cpp  199
1 file changed, 107 insertions(+), 92 deletions(-)
diff --git a/src/core/NEON/kernels/arm_gemm/kernels/sve_native_fp32_mla_4VLx4/generic.cpp b/src/core/NEON/kernels/arm_gemm/kernels/sve_native_fp32_mla_4VLx4/generic.cpp
index 3fc0e5fa36..b05906e199 100644
--- a/src/core/NEON/kernels/arm_gemm/kernels/sve_native_fp32_mla_4VLx4/generic.cpp
+++ b/src/core/NEON/kernels/arm_gemm/kernels/sve_native_fp32_mla_4VLx4/generic.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2018-2019 Arm Limited.
+ * Copyright (c) 2018-2020 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -60,12 +60,23 @@ void sve_native_fp32_mla_4VLx4(const float *A, int lda, const float *B, int ldb,
break;
}
- for (int y=0; y<M; y+=4) {
+ int rows_to_compute;
+
+ for (int y=0; y<M; y+=rows_to_compute) {
const float * const a_ptr0_base = A + (y * lda);
const unsigned long ldab = lda * sizeof(float);
float *c_ptr0 = C + (y * ldc);
+ rows_to_compute = M-y;
+ if (rows_to_compute > 4) {
+ if (rows_to_compute % 4) {
+ rows_to_compute = 4 - 1;
+ } else {
+ rows_to_compute = 4;
+ }
+ }
+
for (int x0=0; x0<N; x0+=(4 * get_vector_length<float>())) {
const long width = std::min((unsigned long)N-x0, (4 * get_vector_length<float>()));
long loops = loops_count;
@@ -78,7 +89,7 @@ void sve_native_fp32_mla_4VLx4(const float *A, int lda, const float *B, int ldb,
const unsigned long ldcb = ldc * sizeof(float);
const float *biasptr = bias ? bias+x0 : nullbias;
- switch(M-y) {
+ switch(rows_to_compute) {
case 1:
__asm __volatile (
"whilelt p6.s, %[temp], %[leftovers]\n"
@@ -184,52 +195,51 @@ void sve_native_fp32_mla_4VLx4(const float *A, int lda, const float *B, int ldb,
"ld1w z12.s, p0/z, [%[b_ptr0]]\n"
"ld1w z13.s, p1/z, [%[b_ptr0], #1, MUL VL]\n"
"ld1w z14.s, p2/z, [%[b_ptr0], #2, MUL VL]\n"
- "ld1w z15.s, p3/z, [%[b_ptr0], #3, MUL VL]\n"
"cbz %[regs], 3f\n"
"fmla z16.s, z8.s, z0.s[0]\n"
- "ld1rqw z4.s, p7/z, [%[a_ptr0]]\n"
+ "ld1w z15.s, p3/z, [%[b_ptr0], #3, MUL VL]\n"
"fmla z17.s, z9.s, z0.s[0]\n"
- "add %[b_ptr0], %[b_ptr0], %[ldb]\n"
+ "ld1rqw z4.s, p7/z, [%[a_ptr0]]\n"
"fmla z18.s, z10.s, z0.s[0]\n"
- "ld1w z8.s, p0/z, [%[b_ptr0]]\n"
+ "add %[b_ptr0], %[b_ptr0], %[ldb]\n"
"fmla z19.s, z11.s, z0.s[0]\n"
- "ld1w z9.s, p1/z, [%[b_ptr0], #1, MUL VL]\n"
+ "ld1w z8.s, p0/z, [%[b_ptr0]]\n"
"fmla z16.s, z12.s, z0.s[1]\n"
- "ld1w z10.s, p2/z, [%[b_ptr0], #2, MUL VL]\n"
+ "ld1w z9.s, p1/z, [%[b_ptr0], #1, MUL VL]\n"
"fmla z17.s, z13.s, z0.s[1]\n"
- "ld1w z11.s, p3/z, [%[b_ptr0], #3, MUL VL]\n"
+ "ld1w z10.s, p2/z, [%[b_ptr0], #2, MUL VL]\n"
"fmla z18.s, z14.s, z0.s[1]\n"
- "add %[b_ptr0], %[b_ptr0], %[ldb]\n"
+ "ld1w z11.s, p3/z, [%[b_ptr0], #3, MUL VL]\n"
"fmla z19.s, z15.s, z0.s[1]\n"
- "ld1w z12.s, p0/z, [%[b_ptr0]]\n"
+ "add %[b_ptr0], %[b_ptr0], %[ldb]\n"
"fmla z16.s, z8.s, z0.s[2]\n"
- "ld1w z13.s, p1/z, [%[b_ptr0], #1, MUL VL]\n"
+ "ld1w z12.s, p0/z, [%[b_ptr0]]\n"
"fmla z17.s, z9.s, z0.s[2]\n"
- "ld1w z14.s, p2/z, [%[b_ptr0], #2, MUL VL]\n"
+ "ld1w z13.s, p1/z, [%[b_ptr0], #1, MUL VL]\n"
"fmla z18.s, z10.s, z0.s[2]\n"
- "ld1w z15.s, p3/z, [%[b_ptr0], #3, MUL VL]\n"
+ "ld1w z14.s, p2/z, [%[b_ptr0], #2, MUL VL]\n"
"fmla z19.s, z11.s, z0.s[2]\n"
- "add %[b_ptr0], %[b_ptr0], %[ldb]\n"
+ "ld1w z15.s, p3/z, [%[b_ptr0], #3, MUL VL]\n"
"fmla z16.s, z12.s, z0.s[3]\n"
- "ld1w z8.s, p0/z, [%[b_ptr0]]\n"
+ "add %[b_ptr0], %[b_ptr0], %[ldb]\n"
"fmla z17.s, z13.s, z0.s[3]\n"
- "ld1w z9.s, p1/z, [%[b_ptr0], #1, MUL VL]\n"
+ "ld1w z8.s, p0/z, [%[b_ptr0]]\n"
"fmla z18.s, z14.s, z0.s[3]\n"
- "ld1w z10.s, p2/z, [%[b_ptr0], #2, MUL VL]\n"
+ "ld1w z9.s, p1/z, [%[b_ptr0], #1, MUL VL]\n"
"fmla z19.s, z15.s, z0.s[3]\n"
+ "ld1w z10.s, p2/z, [%[b_ptr0], #2, MUL VL]\n"
"ld1w z11.s, p3/z, [%[b_ptr0], #3, MUL VL]\n"
- "fmla z16.s, z8.s, z4.s[0]\n"
- "ld1rqw z0.s, p6/z, [%[a_ptr0], #0x10]\n"
- "fmla z17.s, z9.s, z4.s[0]\n"
"add %[b_ptr0], %[b_ptr0], %[ldb]\n"
- "fmla z18.s, z10.s, z4.s[0]\n"
+ "fmla z16.s, z8.s, z4.s[0]\n"
"ld1w z12.s, p0/z, [%[b_ptr0]]\n"
- "fmla z19.s, z11.s, z4.s[0]\n"
+ "fmla z17.s, z9.s, z4.s[0]\n"
"ld1w z13.s, p1/z, [%[b_ptr0], #1, MUL VL]\n"
+ "fmla z18.s, z10.s, z4.s[0]\n"
"ld1w z14.s, p2/z, [%[b_ptr0], #2, MUL VL]\n"
- "addvl %[a_ptr0], %[a_ptr0], #2\n"
- "fmla z16.s, z12.s, z4.s[1]\n"
+ "fmla z19.s, z11.s, z4.s[0]\n"
"ld1w z15.s, p3/z, [%[b_ptr0], #3, MUL VL]\n"
+ "fmla z16.s, z12.s, z4.s[1]\n"
+ "ld1rqw z0.s, p6/z, [%[a_ptr0], #0x10]\n"
"fmla z17.s, z13.s, z4.s[1]\n"
"add %[b_ptr0], %[b_ptr0], %[ldb]\n"
"fmla z18.s, z14.s, z4.s[1]\n"
@@ -237,15 +247,16 @@ void sve_native_fp32_mla_4VLx4(const float *A, int lda, const float *B, int ldb,
"fmla z19.s, z15.s, z4.s[1]\n"
"ld1w z9.s, p1/z, [%[b_ptr0], #1, MUL VL]\n"
"ld1w z10.s, p2/z, [%[b_ptr0], #2, MUL VL]\n"
+ "addvl %[a_ptr0], %[a_ptr0], #2\n"
+ "fmla z16.s, z8.s, z4.s[2]\n"
"ld1w z11.s, p3/z, [%[b_ptr0], #3, MUL VL]\n"
+ "fmla z17.s, z9.s, z4.s[2]\n"
"add %[b_ptr0], %[b_ptr0], %[ldb]\n"
- "fmla z16.s, z8.s, z4.s[2]\n"
+ "fmla z18.s, z10.s, z4.s[2]\n"
"ld1w z12.s, p0/z, [%[b_ptr0]]\n"
- "fmla z17.s, z9.s, z4.s[2]\n"
+ "fmla z19.s, z11.s, z4.s[2]\n"
"ld1w z13.s, p1/z, [%[b_ptr0], #1, MUL VL]\n"
- "fmla z18.s, z10.s, z4.s[2]\n"
"ld1w z14.s, p2/z, [%[b_ptr0], #2, MUL VL]\n"
- "fmla z19.s, z11.s, z4.s[2]\n"
"ld1w z15.s, p3/z, [%[b_ptr0], #3, MUL VL]\n"
"fmla z16.s, z12.s, z4.s[3]\n"
"fmla z17.s, z13.s, z4.s[3]\n"
@@ -286,30 +297,31 @@ void sve_native_fp32_mla_4VLx4(const float *A, int lda, const float *B, int ldb,
"b 4f\n"
"3:\n"
"fmla z16.s, z8.s, z0.s[0]\n"
- "ld1rqw z4.s, p6/z, [%[a_ptr0]]\n"
+ "ld1w z15.s, p3/z, [%[b_ptr0], #3, MUL VL]\n"
"fmla z17.s, z9.s, z0.s[0]\n"
- "add %[b_ptr0], %[b_ptr0], %[ldb]\n"
+ "ld1rqw z4.s, p6/z, [%[a_ptr0]]\n"
"fmla z18.s, z10.s, z0.s[0]\n"
- "ld1w z8.s, p0/z, [%[b_ptr0]]\n"
+ "add %[b_ptr0], %[b_ptr0], %[ldb]\n"
"fmla z19.s, z11.s, z0.s[0]\n"
- "ld1w z9.s, p1/z, [%[b_ptr0], #1, MUL VL]\n"
+ "ld1w z8.s, p0/z, [%[b_ptr0]]\n"
"fmla z16.s, z12.s, z0.s[1]\n"
- "ld1w z10.s, p2/z, [%[b_ptr0], #2, MUL VL]\n"
+ "ld1w z9.s, p1/z, [%[b_ptr0], #1, MUL VL]\n"
"fmla z17.s, z13.s, z0.s[1]\n"
- "ld1w z11.s, p3/z, [%[b_ptr0], #3, MUL VL]\n"
+ "ld1w z10.s, p2/z, [%[b_ptr0], #2, MUL VL]\n"
"fmla z18.s, z14.s, z0.s[1]\n"
- "add %[b_ptr0], %[b_ptr0], %[ldb]\n"
+ "ld1w z11.s, p3/z, [%[b_ptr0], #3, MUL VL]\n"
"fmla z19.s, z15.s, z0.s[1]\n"
- "ld1w z12.s, p0/z, [%[b_ptr0]]\n"
+ "add %[b_ptr0], %[b_ptr0], %[ldb]\n"
"fmla z16.s, z8.s, z0.s[2]\n"
- "ld1w z13.s, p1/z, [%[b_ptr0], #1, MUL VL]\n"
+ "ld1w z12.s, p0/z, [%[b_ptr0]]\n"
"fmla z17.s, z9.s, z0.s[2]\n"
- "ld1w z14.s, p2/z, [%[b_ptr0], #2, MUL VL]\n"
+ "ld1w z13.s, p1/z, [%[b_ptr0], #1, MUL VL]\n"
"fmla z18.s, z10.s, z0.s[2]\n"
- "ld1w z15.s, p3/z, [%[b_ptr0], #3, MUL VL]\n"
+ "ld1w z14.s, p2/z, [%[b_ptr0], #2, MUL VL]\n"
"fmla z19.s, z11.s, z0.s[2]\n"
- "addvl %[a_ptr0], %[a_ptr0], #1\n"
+ "ld1w z15.s, p3/z, [%[b_ptr0], #3, MUL VL]\n"
"fmla z16.s, z12.s, z0.s[3]\n"
+ "addvl %[a_ptr0], %[a_ptr0], #1\n"
"fmla z17.s, z13.s, z0.s[3]\n"
"fmla z18.s, z14.s, z0.s[3]\n"
"fmla z19.s, z15.s, z0.s[3]\n"
@@ -516,21 +528,21 @@ void sve_native_fp32_mla_4VLx4(const float *A, int lda, const float *B, int ldb,
"fmla z23.s, z15.s, z5.s[3]\n"
"b.ne 2b\n"
"1:\n"
- "ld1w z15.s, p3/z, [%[b_ptr0], #3, MUL VL]\n"
"cbz %[regs], 3f\n"
"fmla z16.s, z8.s, z0.s[0]\n"
- "ld1rqw z4.s, p7/z, [%[a_ptr0]]\n"
+ "ld1w z15.s, p3/z, [%[b_ptr0], #3, MUL VL]\n"
"fmla z20.s, z8.s, z1.s[0]\n"
- "ld1rqw z5.s, p7/z, [a_ptr1]\n"
+ "ld1rqw z4.s, p7/z, [%[a_ptr0]]\n"
"fmla z17.s, z9.s, z0.s[0]\n"
- "add %[b_ptr0], %[b_ptr0], %[ldb]\n"
+ "ld1rqw z5.s, p7/z, [a_ptr1]\n"
"fmla z21.s, z9.s, z1.s[0]\n"
- "ld1w z8.s, p0/z, [%[b_ptr0]]\n"
+ "add %[b_ptr0], %[b_ptr0], %[ldb]\n"
"fmla z18.s, z10.s, z0.s[0]\n"
- "ld1w z9.s, p1/z, [%[b_ptr0], #1, MUL VL]\n"
+ "ld1w z8.s, p0/z, [%[b_ptr0]]\n"
"fmla z22.s, z10.s, z1.s[0]\n"
- "ld1w z10.s, p2/z, [%[b_ptr0], #2, MUL VL]\n"
+ "ld1w z9.s, p1/z, [%[b_ptr0], #1, MUL VL]\n"
"fmla z19.s, z11.s, z0.s[0]\n"
+ "ld1w z10.s, p2/z, [%[b_ptr0], #2, MUL VL]\n"
"fmla z23.s, z11.s, z1.s[0]\n"
"ld1w z11.s, p3/z, [%[b_ptr0], #3, MUL VL]\n"
"fmla z16.s, z12.s, z0.s[1]\n"
@@ -665,19 +677,19 @@ void sve_native_fp32_mla_4VLx4(const float *A, int lda, const float *B, int ldb,
"b 4f\n"
"3:\n"
"fmla z16.s, z8.s, z0.s[0]\n"
- "ld1rqw z4.s, p6/z, [%[a_ptr0]]\n"
+ "ld1w z15.s, p3/z, [%[b_ptr0], #3, MUL VL]\n"
"fmla z20.s, z8.s, z1.s[0]\n"
- "ld1rqw z5.s, p6/z, [a_ptr1]\n"
+ "ld1rqw z4.s, p6/z, [%[a_ptr0]]\n"
"fmla z17.s, z9.s, z0.s[0]\n"
- "add %[b_ptr0], %[b_ptr0], %[ldb]\n"
+ "ld1rqw z5.s, p6/z, [a_ptr1]\n"
"fmla z21.s, z9.s, z1.s[0]\n"
- "ld1w z8.s, p0/z, [%[b_ptr0]]\n"
+ "add %[b_ptr0], %[b_ptr0], %[ldb]\n"
"fmla z18.s, z10.s, z0.s[0]\n"
- "ld1w z9.s, p1/z, [%[b_ptr0], #1, MUL VL]\n"
+ "ld1w z8.s, p0/z, [%[b_ptr0]]\n"
"fmla z22.s, z10.s, z1.s[0]\n"
- "ld1w z10.s, p2/z, [%[b_ptr0], #2, MUL VL]\n"
+ "ld1w z9.s, p1/z, [%[b_ptr0], #1, MUL VL]\n"
"fmla z19.s, z11.s, z0.s[0]\n"
- "addvl %[a_ptr0], %[a_ptr0], #1\n"
+ "ld1w z10.s, p2/z, [%[b_ptr0], #2, MUL VL]\n"
"fmla z23.s, z11.s, z1.s[0]\n"
"ld1w z11.s, p3/z, [%[b_ptr0], #3, MUL VL]\n"
"fmla z16.s, z12.s, z0.s[1]\n"
@@ -685,10 +697,11 @@ void sve_native_fp32_mla_4VLx4(const float *A, int lda, const float *B, int ldb,
"fmla z20.s, z12.s, z1.s[1]\n"
"ld1w z12.s, p0/z, [%[b_ptr0]]\n"
"fmla z17.s, z13.s, z0.s[1]\n"
- "addvl a_ptr1, a_ptr1, #1\n"
+ "addvl %[a_ptr0], %[a_ptr0], #1\n"
"fmla z21.s, z13.s, z1.s[1]\n"
"ld1w z13.s, p1/z, [%[b_ptr0], #1, MUL VL]\n"
"fmla z18.s, z14.s, z0.s[1]\n"
+ "addvl a_ptr1, a_ptr1, #1\n"
"fmla z22.s, z14.s, z1.s[1]\n"
"ld1w z14.s, p2/z, [%[b_ptr0], #2, MUL VL]\n"
"fmla z19.s, z15.s, z0.s[1]\n"
@@ -861,9 +874,9 @@ void sve_native_fp32_mla_4VLx4(const float *A, int lda, const float *B, int ldb,
"fmla z27.s, z11.s, z2.s[0]\n"
"ld1w z11.s, p3/z, [%[b_ptr0], #3, MUL VL]\n"
"fmla z16.s, z12.s, z0.s[1]\n"
- "add a_ptr2, a_ptr2, #0x20\n"
- "fmla z20.s, z12.s, z1.s[1]\n"
"add %[b_ptr0], %[b_ptr0], %[ldb]\n"
+ "fmla z20.s, z12.s, z1.s[1]\n"
+ "add a_ptr2, a_ptr2, #0x20\n"
"fmla z24.s, z12.s, z2.s[1]\n"
"ld1w z12.s, p0/z, [%[b_ptr0]]\n"
"fmla z17.s, z13.s, z0.s[1]\n"
@@ -984,21 +997,21 @@ void sve_native_fp32_mla_4VLx4(const float *A, int lda, const float *B, int ldb,
"fmla z27.s, z15.s, z6.s[3]\n"
"b.ne 2b\n"
"1:\n"
- "ld1w z15.s, p3/z, [%[b_ptr0], #3, MUL VL]\n"
"cbz %[regs], 3f\n"
"fmla z16.s, z8.s, z0.s[0]\n"
- "ld1rqw z4.s, p7/z, [%[a_ptr0]]\n"
+ "ld1w z15.s, p3/z, [%[b_ptr0], #3, MUL VL]\n"
"fmla z20.s, z8.s, z1.s[0]\n"
- "ld1rqw z5.s, p7/z, [a_ptr1]\n"
+ "ld1rqw z4.s, p7/z, [%[a_ptr0]]\n"
"fmla z24.s, z8.s, z2.s[0]\n"
- "ld1rqw z6.s, p7/z, [a_ptr2]\n"
+ "ld1rqw z5.s, p7/z, [a_ptr1]\n"
"fmla z17.s, z9.s, z0.s[0]\n"
- "add %[b_ptr0], %[b_ptr0], %[ldb]\n"
+ "ld1rqw z6.s, p7/z, [a_ptr2]\n"
"fmla z21.s, z9.s, z1.s[0]\n"
- "ld1w z8.s, p0/z, [%[b_ptr0]]\n"
+ "add %[b_ptr0], %[b_ptr0], %[ldb]\n"
"fmla z25.s, z9.s, z2.s[0]\n"
- "ld1w z9.s, p1/z, [%[b_ptr0], #1, MUL VL]\n"
+ "ld1w z8.s, p0/z, [%[b_ptr0]]\n"
"fmla z18.s, z10.s, z0.s[0]\n"
+ "ld1w z9.s, p1/z, [%[b_ptr0], #1, MUL VL]\n"
"fmla z22.s, z10.s, z1.s[0]\n"
"fmla z26.s, z10.s, z2.s[0]\n"
"ld1w z10.s, p2/z, [%[b_ptr0], #2, MUL VL]\n"
@@ -1180,26 +1193,27 @@ void sve_native_fp32_mla_4VLx4(const float *A, int lda, const float *B, int ldb,
"b 4f\n"
"3:\n"
"fmla z16.s, z8.s, z0.s[0]\n"
- "ld1rqw z4.s, p6/z, [%[a_ptr0]]\n"
+ "ld1w z15.s, p3/z, [%[b_ptr0], #3, MUL VL]\n"
"fmla z20.s, z8.s, z1.s[0]\n"
- "ld1rqw z5.s, p6/z, [a_ptr1]\n"
+ "ld1rqw z4.s, p6/z, [%[a_ptr0]]\n"
"fmla z24.s, z8.s, z2.s[0]\n"
- "ld1rqw z6.s, p6/z, [a_ptr2]\n"
+ "ld1rqw z5.s, p6/z, [a_ptr1]\n"
"fmla z17.s, z9.s, z0.s[0]\n"
- "add %[b_ptr0], %[b_ptr0], %[ldb]\n"
+ "ld1rqw z6.s, p6/z, [a_ptr2]\n"
"fmla z21.s, z9.s, z1.s[0]\n"
- "ld1w z8.s, p0/z, [%[b_ptr0]]\n"
+ "add %[b_ptr0], %[b_ptr0], %[ldb]\n"
"fmla z25.s, z9.s, z2.s[0]\n"
- "ld1w z9.s, p1/z, [%[b_ptr0], #1, MUL VL]\n"
+ "ld1w z8.s, p0/z, [%[b_ptr0]]\n"
"fmla z18.s, z10.s, z0.s[0]\n"
- "addvl %[a_ptr0], %[a_ptr0], #1\n"
+ "ld1w z9.s, p1/z, [%[b_ptr0], #1, MUL VL]\n"
"fmla z22.s, z10.s, z1.s[0]\n"
- "addvl a_ptr1, a_ptr1, #1\n"
+ "addvl %[a_ptr0], %[a_ptr0], #1\n"
"fmla z26.s, z10.s, z2.s[0]\n"
"ld1w z10.s, p2/z, [%[b_ptr0], #2, MUL VL]\n"
"fmla z19.s, z11.s, z0.s[0]\n"
- "addvl a_ptr2, a_ptr2, #1\n"
+ "addvl a_ptr1, a_ptr1, #1\n"
"fmla z23.s, z11.s, z1.s[0]\n"
+ "addvl a_ptr2, a_ptr2, #1\n"
"fmla z27.s, z11.s, z2.s[0]\n"
"ld1w z11.s, p3/z, [%[b_ptr0], #3, MUL VL]\n"
"fmla z16.s, z12.s, z0.s[1]\n"
@@ -1589,21 +1603,21 @@ void sve_native_fp32_mla_4VLx4(const float *A, int lda, const float *B, int ldb,
"fmla z31.s, z15.s, z7.s[3]\n"
"b.ne 2b\n"
"1:\n"
- "ld1w z15.s, p3/z, [%[b_ptr0], #3, MUL VL]\n"
"cbz %[regs], 3f\n"
"fmla z16.s, z8.s, z0.s[0]\n"
- "ld1rqw z4.s, p7/z, [%[a_ptr0]]\n"
+ "ld1w z15.s, p3/z, [%[b_ptr0], #3, MUL VL]\n"
"fmla z20.s, z8.s, z1.s[0]\n"
- "ld1rqw z5.s, p7/z, [a_ptr1]\n"
+ "ld1rqw z4.s, p7/z, [%[a_ptr0]]\n"
"fmla z24.s, z8.s, z2.s[0]\n"
- "ld1rqw z6.s, p7/z, [a_ptr2]\n"
+ "ld1rqw z5.s, p7/z, [a_ptr1]\n"
"fmla z28.s, z8.s, z3.s[0]\n"
- "ld1rqw z7.s, p7/z, [a_ptr3]\n"
+ "ld1rqw z6.s, p7/z, [a_ptr2]\n"
"fmla z17.s, z9.s, z0.s[0]\n"
- "add %[b_ptr0], %[b_ptr0], %[ldb]\n"
+ "ld1rqw z7.s, p7/z, [a_ptr3]\n"
"fmla z21.s, z9.s, z1.s[0]\n"
- "ld1w z8.s, p0/z, [%[b_ptr0]]\n"
+ "add %[b_ptr0], %[b_ptr0], %[ldb]\n"
"fmla z25.s, z9.s, z2.s[0]\n"
+ "ld1w z8.s, p0/z, [%[b_ptr0]]\n"
"fmla z29.s, z9.s, z3.s[0]\n"
"ld1w z9.s, p1/z, [%[b_ptr0], #1, MUL VL]\n"
"fmla z18.s, z10.s, z0.s[0]\n"
@@ -1832,30 +1846,31 @@ void sve_native_fp32_mla_4VLx4(const float *A, int lda, const float *B, int ldb,
"b 4f\n"
"3:\n"
"fmla z16.s, z8.s, z0.s[0]\n"
- "ld1rqw z4.s, p6/z, [%[a_ptr0]]\n"
+ "ld1w z15.s, p3/z, [%[b_ptr0], #3, MUL VL]\n"
"fmla z20.s, z8.s, z1.s[0]\n"
- "ld1rqw z5.s, p6/z, [a_ptr1]\n"
+ "ld1rqw z4.s, p6/z, [%[a_ptr0]]\n"
"fmla z24.s, z8.s, z2.s[0]\n"
- "ld1rqw z6.s, p6/z, [a_ptr2]\n"
+ "ld1rqw z5.s, p6/z, [a_ptr1]\n"
"fmla z28.s, z8.s, z3.s[0]\n"
- "ld1rqw z7.s, p6/z, [a_ptr3]\n"
+ "ld1rqw z6.s, p6/z, [a_ptr2]\n"
"fmla z17.s, z9.s, z0.s[0]\n"
- "add %[b_ptr0], %[b_ptr0], %[ldb]\n"
+ "ld1rqw z7.s, p6/z, [a_ptr3]\n"
"fmla z21.s, z9.s, z1.s[0]\n"
- "ld1w z8.s, p0/z, [%[b_ptr0]]\n"
+ "add %[b_ptr0], %[b_ptr0], %[ldb]\n"
"fmla z25.s, z9.s, z2.s[0]\n"
- "addvl %[a_ptr0], %[a_ptr0], #1\n"
+ "ld1w z8.s, p0/z, [%[b_ptr0]]\n"
"fmla z29.s, z9.s, z3.s[0]\n"
"ld1w z9.s, p1/z, [%[b_ptr0], #1, MUL VL]\n"
"fmla z18.s, z10.s, z0.s[0]\n"
- "addvl a_ptr1, a_ptr1, #1\n"
+ "addvl %[a_ptr0], %[a_ptr0], #1\n"
"fmla z22.s, z10.s, z1.s[0]\n"
- "addvl a_ptr2, a_ptr2, #1\n"
+ "addvl a_ptr1, a_ptr1, #1\n"
"fmla z26.s, z10.s, z2.s[0]\n"
- "addvl a_ptr3, a_ptr3, #1\n"
+ "addvl a_ptr2, a_ptr2, #1\n"
"fmla z30.s, z10.s, z3.s[0]\n"
"ld1w z10.s, p2/z, [%[b_ptr0], #2, MUL VL]\n"
"fmla z19.s, z11.s, z0.s[0]\n"
+ "addvl a_ptr3, a_ptr3, #1\n"
"fmla z23.s, z11.s, z1.s[0]\n"
"fmla z27.s, z11.s, z2.s[0]\n"
"fmla z31.s, z11.s, z3.s[0]\n"
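
Aside from the assembly-level reordering of loads and FMLAs, the one C++ change in this patch replaces the fixed four-row stride of the outer loop with an adaptive block size: when more than four rows remain and the remaining count is not a multiple of four, a three-row block is peeled off instead, which avoids most of the 1- or 2-row leftover blocks the old scheme produced (e.g. M = 10 is processed as 3 + 3 + 4 rows instead of 4 + 4 + 2). Below is a minimal standalone sketch of that scheduling rule, for illustration only and not part of the patch; the value of M and the printout are assumptions chosen to demonstrate a non-multiple-of-4 row count.

#include <cstdio>

int main() {
    const int M = 10;  // example row count with a ragged tail under the old 4-row stride
    for (int y = 0; y < M; ) {
        int rows_to_compute = M - y;
        if (rows_to_compute > 4) {
            // Same rule as the patch: peel a 3-row block when the remaining
            // rows are not a multiple of 4, otherwise take a full 4-row block.
            rows_to_compute = (rows_to_compute % 4) ? 3 : 4;
        }
        std::printf("block at row %d: %d rows\n", y, rows_to_compute);
        y += rows_to_compute;
    }
    return 0;  // prints blocks of 3, 3, 4 instead of the old 4, 4, 2
}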