aboutsummaryrefslogtreecommitdiff
path: root/src/core/NEON/kernels/arm_gemm/kernels/sve_native_bf16fp32_dot_4VLx4/generic.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/core/NEON/kernels/arm_gemm/kernels/sve_native_bf16fp32_dot_4VLx4/generic.cpp')
-rw-r--r--src/core/NEON/kernels/arm_gemm/kernels/sve_native_bf16fp32_dot_4VLx4/generic.cpp249
1 files changed, 132 insertions, 117 deletions
diff --git a/src/core/NEON/kernels/arm_gemm/kernels/sve_native_bf16fp32_dot_4VLx4/generic.cpp b/src/core/NEON/kernels/arm_gemm/kernels/sve_native_bf16fp32_dot_4VLx4/generic.cpp
index ce1971b2c5..d3bd89b8c5 100644
--- a/src/core/NEON/kernels/arm_gemm/kernels/sve_native_bf16fp32_dot_4VLx4/generic.cpp
+++ b/src/core/NEON/kernels/arm_gemm/kernels/sve_native_bf16fp32_dot_4VLx4/generic.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2018-2019 Arm Limited.
+ * Copyright (c) 2018-2020 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -62,12 +62,23 @@ void sve_native_bf16fp32_dot_4VLx4(const bfloat16 *A, int lda, const bfloat16 *B
break;
}
- for (int y=0; y<M; y+=4) {
+ int rows_to_compute;
+
+ for (int y=0; y<M; y+=rows_to_compute) {
const bfloat16 * const a_ptr0_base = A + (y * lda);
const unsigned long ldab = lda * sizeof(bfloat16);
float *c_ptr0 = C + (y * ldc);
+ rows_to_compute = M-y;
+ if (rows_to_compute > 4) {
+ if (rows_to_compute % 4) {
+ rows_to_compute = 4 - 1;
+ } else {
+ rows_to_compute = 4;
+ }
+ }
+
for (int x0=0; x0<N; x0+=(4 * get_vector_length<float>())) {
const long width = std::min((unsigned long)N-x0, (4 * get_vector_length<float>()));
long loops = loops_count;
@@ -82,7 +93,7 @@ void sve_native_bf16fp32_dot_4VLx4(const bfloat16 *A, int lda, const bfloat16 *B
const unsigned long ldcb = ldc * sizeof(float);
const float *biasptr = bias ? bias+x0 : nullbias;
- switch(M-y) {
+ switch(rows_to_compute) {
case 1:
__asm __volatile (
"whilelt p6.h, %[temp], %[leftovers]\n"
@@ -235,46 +246,46 @@ void sve_native_bf16fp32_dot_4VLx4(const bfloat16 *A, int lda, const bfloat16 *B
"b.ne 2b\n"
"1:\n"
"zip1 z12.h, z13.h, z14.h\n"
- "ld1h z15.h, p5/z, [%[b_ptr0], #1, MUL VL]\n"
"zip2 z13.h, z13.h, z14.h\n"
"cbz %[regs], 3f\n"
".inst 0x64604110 // bfdot z16.s, z8.h, z0.h[0]\n"
- "ld1rqh z4.h, p7/z, [%[a_ptr0]]\n"
+ "ld1h z15.h, p5/z, [%[b_ptr0], #1, MUL VL]\n"
".inst 0x64604131 // bfdot z17.s, z9.h, z0.h[0]\n"
- "ld1h z8.h, p5/z, [%[b_ptr1], #1, MUL VL]\n"
+ "ld1rqh z4.h, p7/z, [%[a_ptr0]]\n"
".inst 0x64604152 // bfdot z18.s, z10.h, z0.h[0]\n"
- "add %[b_ptr0], %[b_ptr0], %[ldb]\n"
+ "ld1h z8.h, p5/z, [%[b_ptr1], #1, MUL VL]\n"
".inst 0x64604173 // bfdot z19.s, z11.h, z0.h[0]\n"
+ "add %[b_ptr0], %[b_ptr0], %[ldb]\n"
+ ".inst 0x64684190 // bfdot z16.s, z12.h, z0.h[1]\n"
"ld1h z9.h, p4/z, [%[b_ptr0]]\n"
"zip1 z14.h, z15.h, z8.h\n"
"ld1h z11.h, p5/z, [%[b_ptr0], #1, MUL VL]\n"
"zip2 z15.h, z15.h, z8.h\n"
"add %[b_ptr1], %[b_ptr1], %[ldb]\n"
- ".inst 0x64684190 // bfdot z16.s, z12.h, z0.h[1]\n"
- "ld1h z10.h, p4/z, [%[b_ptr1]]\n"
".inst 0x646841b1 // bfdot z17.s, z13.h, z0.h[1]\n"
- "ld1h z12.h, p5/z, [%[b_ptr1], #1, MUL VL]\n"
+ "ld1h z10.h, p4/z, [%[b_ptr1]]\n"
".inst 0x646841d2 // bfdot z18.s, z14.h, z0.h[1]\n"
+ "ld1h z12.h, p5/z, [%[b_ptr1], #1, MUL VL]\n"
+ ".inst 0x646841f3 // bfdot z19.s, z15.h, z0.h[1]\n"
"add %[b_ptr0], %[b_ptr0], %[ldb]\n"
"zip1 z8.h, z9.h, z10.h\n"
"ld1h z13.h, p4/z, [%[b_ptr0]]\n"
"zip2 z9.h, z9.h, z10.h\n"
- "add %[b_ptr1], %[b_ptr1], %[ldb]\n"
+ "ld1h z15.h, p5/z, [%[b_ptr0], #1, MUL VL]\n"
"zip1 z10.h, z11.h, z12.h\n"
- "ld1h z14.h, p4/z, [%[b_ptr1]]\n"
+ "add %[b_ptr1], %[b_ptr1], %[ldb]\n"
"zip2 z11.h, z11.h, z12.h\n"
- ".inst 0x646841f3 // bfdot z19.s, z15.h, z0.h[1]\n"
- "ld1h z15.h, p5/z, [%[b_ptr0], #1, MUL VL]\n"
+ "ld1h z14.h, p4/z, [%[b_ptr1]]\n"
".inst 0x64704110 // bfdot z16.s, z8.h, z0.h[2]\n"
"ld1h z8.h, p5/z, [%[b_ptr1], #1, MUL VL]\n"
- "zip1 z12.h, z13.h, z14.h\n"
+ ".inst 0x64704131 // bfdot z17.s, z9.h, z0.h[2]\n"
"add %[b_ptr0], %[b_ptr0], %[ldb]\n"
+ "zip1 z12.h, z13.h, z14.h\n"
+ "ld1h z9.h, p4/z, [%[b_ptr0]]\n"
"zip2 z13.h, z13.h, z14.h\n"
"add %[b_ptr1], %[b_ptr1], %[ldb]\n"
"zip1 z14.h, z15.h, z8.h\n"
"zip2 z15.h, z15.h, z8.h\n"
- ".inst 0x64704131 // bfdot z17.s, z9.h, z0.h[2]\n"
- "ld1h z9.h, p4/z, [%[b_ptr0]]\n"
".inst 0x64704152 // bfdot z18.s, z10.h, z0.h[2]\n"
"ld1h z10.h, p4/z, [%[b_ptr1]]\n"
".inst 0x64704173 // bfdot z19.s, z11.h, z0.h[2]\n"
@@ -452,42 +463,43 @@ void sve_native_bf16fp32_dot_4VLx4(const bfloat16 *A, int lda, const bfloat16 *B
"b 7f\n"
"3:\n"
".inst 0x64604110 // bfdot z16.s, z8.h, z0.h[0]\n"
- "ld1h z8.h, p5/z, [%[b_ptr1], #1, MUL VL]\n"
+ "ld1h z15.h, p5/z, [%[b_ptr0], #1, MUL VL]\n"
".inst 0x64604131 // bfdot z17.s, z9.h, z0.h[0]\n"
- "ld1rqh z4.h, p6/z, [%[a_ptr0]]\n"
+ "ld1h z8.h, p5/z, [%[b_ptr1], #1, MUL VL]\n"
".inst 0x64604152 // bfdot z18.s, z10.h, z0.h[0]\n"
+ "ld1rqh z4.h, p6/z, [%[a_ptr0]]\n"
+ ".inst 0x64604173 // bfdot z19.s, z11.h, z0.h[0]\n"
"add %[b_ptr0], %[b_ptr0], %[ldb]\n"
"zip1 z14.h, z15.h, z8.h\n"
"ld1h z9.h, p4/z, [%[b_ptr0]]\n"
"zip2 z15.h, z15.h, z8.h\n"
- "add %[b_ptr1], %[b_ptr1], %[ldb]\n"
- ".inst 0x64604173 // bfdot z19.s, z11.h, z0.h[0]\n"
- "ld1h z10.h, p4/z, [%[b_ptr1]]\n"
- ".inst 0x64684190 // bfdot z16.s, z12.h, z0.h[1]\n"
"ld1h z11.h, p5/z, [%[b_ptr0], #1, MUL VL]\n"
+ ".inst 0x64684190 // bfdot z16.s, z12.h, z0.h[1]\n"
+ "add %[b_ptr1], %[b_ptr1], %[ldb]\n"
".inst 0x646841b1 // bfdot z17.s, z13.h, z0.h[1]\n"
- "ld1h z12.h, p5/z, [%[b_ptr1], #1, MUL VL]\n"
+ "ld1h z10.h, p4/z, [%[b_ptr1]]\n"
".inst 0x646841d2 // bfdot z18.s, z14.h, z0.h[1]\n"
+ "ld1h z12.h, p5/z, [%[b_ptr1], #1, MUL VL]\n"
+ ".inst 0x646841f3 // bfdot z19.s, z15.h, z0.h[1]\n"
"add %[b_ptr0], %[b_ptr0], %[ldb]\n"
"zip1 z8.h, z9.h, z10.h\n"
"ld1h z13.h, p4/z, [%[b_ptr0]]\n"
"zip2 z9.h, z9.h, z10.h\n"
- "add %[b_ptr1], %[b_ptr1], %[ldb]\n"
+ "ld1h z15.h, p5/z, [%[b_ptr0], #1, MUL VL]\n"
"zip1 z10.h, z11.h, z12.h\n"
- "ld1h z14.h, p4/z, [%[b_ptr1]]\n"
+ "add %[b_ptr1], %[b_ptr1], %[ldb]\n"
"zip2 z11.h, z11.h, z12.h\n"
- "addvl %[a_ptr0], %[a_ptr0], #1\n"
- ".inst 0x646841f3 // bfdot z19.s, z15.h, z0.h[1]\n"
- "ld1h z15.h, p5/z, [%[b_ptr0], #1, MUL VL]\n"
- "zip1 z12.h, z13.h, z14.h\n"
- "zip2 z13.h, z13.h, z14.h\n"
+ "ld1h z14.h, p4/z, [%[b_ptr1]]\n"
".inst 0x64704110 // bfdot z16.s, z8.h, z0.h[2]\n"
"ld1h z8.h, p5/z, [%[b_ptr1], #1, MUL VL]\n"
".inst 0x64704131 // bfdot z17.s, z9.h, z0.h[2]\n"
- ".inst 0x64704152 // bfdot z18.s, z10.h, z0.h[2]\n"
- ".inst 0x64704173 // bfdot z19.s, z11.h, z0.h[2]\n"
+ "addvl %[a_ptr0], %[a_ptr0], #1\n"
+ "zip1 z12.h, z13.h, z14.h\n"
+ "zip2 z13.h, z13.h, z14.h\n"
"zip1 z14.h, z15.h, z8.h\n"
"zip2 z15.h, z15.h, z8.h\n"
+ ".inst 0x64704152 // bfdot z18.s, z10.h, z0.h[2]\n"
+ ".inst 0x64704173 // bfdot z19.s, z11.h, z0.h[2]\n"
".inst 0x64784190 // bfdot z16.s, z12.h, z0.h[3]\n"
".inst 0x647841b1 // bfdot z17.s, z13.h, z0.h[3]\n"
".inst 0x647841d2 // bfdot z18.s, z14.h, z0.h[3]\n"
@@ -666,37 +678,37 @@ void sve_native_bf16fp32_dot_4VLx4(const bfloat16 *A, int lda, const bfloat16 *B
"zip2 z11.h, z11.h, z12.h\n"
"ld1h z13.h, p4/z, [%[b_ptr0]]\n"
"ld1h z14.h, p4/z, [%[b_ptr1]]\n"
- "ld1h z15.h, p5/z, [%[b_ptr0], #1, MUL VL]\n"
"mov z23.d, z19.d\n"
"cbz %[loops], 1f\n"
"2:\n"
"zip1 z12.h, z13.h, z14.h\n"
- "ld1rqh z4.h, p7/z, [%[a_ptr0]]\n"
+ "ld1h z15.h, p5/z, [%[b_ptr0], #1, MUL VL]\n"
"zip2 z13.h, z13.h, z14.h\n"
- "ld1rqh z5.h, p7/z, [a_ptr1]\n"
+ "ld1rqh z4.h, p7/z, [%[a_ptr0]]\n"
".inst 0x64604110 // bfdot z16.s, z8.h, z0.h[0]\n"
- "subs %[loops], %[loops], #0x1\n"
+ "ld1rqh z5.h, p7/z, [a_ptr1]\n"
".inst 0x64614114 // bfdot z20.s, z8.h, z1.h[0]\n"
"ld1h z8.h, p5/z, [%[b_ptr1], #1, MUL VL]\n"
".inst 0x64604131 // bfdot z17.s, z9.h, z0.h[0]\n"
- "add %[b_ptr0], %[b_ptr0], %[ldb]\n"
+ "subs %[loops], %[loops], #0x1\n"
".inst 0x64614135 // bfdot z21.s, z9.h, z1.h[0]\n"
- "ld1h z9.h, p4/z, [%[b_ptr0]]\n"
+ "add %[b_ptr0], %[b_ptr0], %[ldb]\n"
"zip1 z14.h, z15.h, z8.h\n"
- "add %[b_ptr1], %[b_ptr1], %[ldb]\n"
+ "ld1h z9.h, p4/z, [%[b_ptr0]]\n"
"zip2 z15.h, z15.h, z8.h\n"
- "add %[a_ptr0], %[a_ptr0], #0x20\n"
+ "add %[b_ptr1], %[b_ptr1], %[ldb]\n"
".inst 0x64604152 // bfdot z18.s, z10.h, z0.h[0]\n"
- "add a_ptr1, a_ptr1, #0x20\n"
+ "add %[a_ptr0], %[a_ptr0], #0x20\n"
".inst 0x64614156 // bfdot z22.s, z10.h, z1.h[0]\n"
"ld1h z10.h, p4/z, [%[b_ptr1]]\n"
".inst 0x64604173 // bfdot z19.s, z11.h, z0.h[0]\n"
+ "add a_ptr1, a_ptr1, #0x20\n"
".inst 0x64614177 // bfdot z23.s, z11.h, z1.h[0]\n"
"ld1h z11.h, p5/z, [%[b_ptr0], #1, MUL VL]\n"
- ".inst 0x64684190 // bfdot z16.s, z12.h, z0.h[1]\n"
- "add %[b_ptr0], %[b_ptr0], %[ldb]\n"
"zip1 z8.h, z9.h, z10.h\n"
+ "add %[b_ptr0], %[b_ptr0], %[ldb]\n"
"zip2 z9.h, z9.h, z10.h\n"
+ ".inst 0x64684190 // bfdot z16.s, z12.h, z0.h[1]\n"
".inst 0x64694194 // bfdot z20.s, z12.h, z1.h[1]\n"
"ld1h z12.h, p5/z, [%[b_ptr1], #1, MUL VL]\n"
".inst 0x646841b1 // bfdot z17.s, z13.h, z0.h[1]\n"
@@ -820,26 +832,26 @@ void sve_native_bf16fp32_dot_4VLx4(const bfloat16 *A, int lda, const bfloat16 *B
"ld1h z14.h, p4/z, [%[b_ptr1]]\n"
".inst 0x647c41f3 // bfdot z19.s, z15.h, z4.h[3]\n"
".inst 0x647d41f7 // bfdot z23.s, z15.h, z5.h[3]\n"
- "ld1h z15.h, p5/z, [%[b_ptr0], #1, MUL VL]\n"
"b.ne 2b\n"
"1:\n"
"zip1 z12.h, z13.h, z14.h\n"
"zip2 z13.h, z13.h, z14.h\n"
"cbz %[regs], 3f\n"
".inst 0x64604110 // bfdot z16.s, z8.h, z0.h[0]\n"
- "ld1rqh z4.h, p7/z, [%[a_ptr0]]\n"
+ "ld1h z15.h, p5/z, [%[b_ptr0], #1, MUL VL]\n"
".inst 0x64614114 // bfdot z20.s, z8.h, z1.h[0]\n"
- "ld1rqh z5.h, p7/z, [a_ptr1]\n"
+ "ld1rqh z4.h, p7/z, [%[a_ptr0]]\n"
".inst 0x64604131 // bfdot z17.s, z9.h, z0.h[0]\n"
- "ld1h z8.h, p5/z, [%[b_ptr1], #1, MUL VL]\n"
+ "ld1rqh z5.h, p7/z, [a_ptr1]\n"
".inst 0x64614135 // bfdot z21.s, z9.h, z1.h[0]\n"
- "add %[b_ptr0], %[b_ptr0], %[ldb]\n"
+ "ld1h z8.h, p5/z, [%[b_ptr1], #1, MUL VL]\n"
".inst 0x64604152 // bfdot z18.s, z10.h, z0.h[0]\n"
+ "add %[b_ptr0], %[b_ptr0], %[ldb]\n"
+ ".inst 0x64614156 // bfdot z22.s, z10.h, z1.h[0]\n"
"ld1h z9.h, p4/z, [%[b_ptr0]]\n"
"zip1 z14.h, z15.h, z8.h\n"
"add %[b_ptr1], %[b_ptr1], %[ldb]\n"
"zip2 z15.h, z15.h, z8.h\n"
- ".inst 0x64614156 // bfdot z22.s, z10.h, z1.h[0]\n"
"ld1h z10.h, p4/z, [%[b_ptr1]]\n"
".inst 0x64604173 // bfdot z19.s, z11.h, z0.h[0]\n"
".inst 0x64614177 // bfdot z23.s, z11.h, z1.h[0]\n"
@@ -1103,28 +1115,29 @@ void sve_native_bf16fp32_dot_4VLx4(const bfloat16 *A, int lda, const bfloat16 *B
"b 7f\n"
"3:\n"
".inst 0x64604110 // bfdot z16.s, z8.h, z0.h[0]\n"
- "ld1rqh z4.h, p6/z, [%[a_ptr0]]\n"
+ "ld1h z15.h, p5/z, [%[b_ptr0], #1, MUL VL]\n"
".inst 0x64614114 // bfdot z20.s, z8.h, z1.h[0]\n"
"ld1h z8.h, p5/z, [%[b_ptr1], #1, MUL VL]\n"
".inst 0x64604131 // bfdot z17.s, z9.h, z0.h[0]\n"
- "ld1rqh z5.h, p6/z, [a_ptr1]\n"
+ "ld1rqh z4.h, p6/z, [%[a_ptr0]]\n"
".inst 0x64614135 // bfdot z21.s, z9.h, z1.h[0]\n"
+ "ld1rqh z5.h, p6/z, [a_ptr1]\n"
+ ".inst 0x64604152 // bfdot z18.s, z10.h, z0.h[0]\n"
"add %[b_ptr0], %[b_ptr0], %[ldb]\n"
"zip1 z14.h, z15.h, z8.h\n"
"ld1h z9.h, p4/z, [%[b_ptr0]]\n"
"zip2 z15.h, z15.h, z8.h\n"
"add %[b_ptr1], %[b_ptr1], %[ldb]\n"
- ".inst 0x64604152 // bfdot z18.s, z10.h, z0.h[0]\n"
- "addvl %[a_ptr0], %[a_ptr0], #1\n"
".inst 0x64614156 // bfdot z22.s, z10.h, z1.h[0]\n"
"ld1h z10.h, p4/z, [%[b_ptr1]]\n"
".inst 0x64604173 // bfdot z19.s, z11.h, z0.h[0]\n"
- "addvl a_ptr1, a_ptr1, #1\n"
+ "addvl %[a_ptr0], %[a_ptr0], #1\n"
".inst 0x64614177 // bfdot z23.s, z11.h, z1.h[0]\n"
"ld1h z11.h, p5/z, [%[b_ptr0], #1, MUL VL]\n"
"zip1 z8.h, z9.h, z10.h\n"
"add %[b_ptr0], %[b_ptr0], %[ldb]\n"
"zip2 z9.h, z9.h, z10.h\n"
+ "addvl a_ptr1, a_ptr1, #1\n"
".inst 0x64684190 // bfdot z16.s, z12.h, z0.h[1]\n"
".inst 0x64694194 // bfdot z20.s, z12.h, z1.h[1]\n"
"ld1h z12.h, p5/z, [%[b_ptr1], #1, MUL VL]\n"
@@ -1386,34 +1399,34 @@ void sve_native_bf16fp32_dot_4VLx4(const bfloat16 *A, int lda, const bfloat16 *B
"add %[b_ptr0], %[b_ptr0], %[ldb]\n"
"mov z27.d, z19.d\n"
"ld1h z13.h, p4/z, [%[b_ptr0]]\n"
- "ld1h z15.h, p5/z, [%[b_ptr0], #1, MUL VL]\n"
"add %[b_ptr1], %[b_ptr1], %[ldb]\n"
"ld1h z14.h, p4/z, [%[b_ptr1]]\n"
"cbz %[loops], 1f\n"
"2:\n"
".inst 0x64604110 // bfdot z16.s, z8.h, z0.h[0]\n"
- "ld1rqh z4.h, p7/z, [%[a_ptr0]]\n"
+ "ld1h z15.h, p5/z, [%[b_ptr0], #1, MUL VL]\n"
"zip1 z12.h, z13.h, z14.h\n"
- "ld1rqh z5.h, p7/z, [a_ptr1]\n"
+ "ld1rqh z4.h, p7/z, [%[a_ptr0]]\n"
"zip2 z13.h, z13.h, z14.h\n"
- "ld1rqh z6.h, p7/z, [a_ptr2]\n"
+ "ld1rqh z5.h, p7/z, [a_ptr1]\n"
".inst 0x64614114 // bfdot z20.s, z8.h, z1.h[0]\n"
- "subs %[loops], %[loops], #0x1\n"
+ "ld1rqh z6.h, p7/z, [a_ptr2]\n"
".inst 0x64624118 // bfdot z24.s, z8.h, z2.h[0]\n"
"ld1h z8.h, p5/z, [%[b_ptr1], #1, MUL VL]\n"
".inst 0x64604131 // bfdot z17.s, z9.h, z0.h[0]\n"
- "add %[b_ptr0], %[b_ptr0], %[ldb]\n"
+ "subs %[loops], %[loops], #0x1\n"
".inst 0x64614135 // bfdot z21.s, z9.h, z1.h[0]\n"
- "add %[b_ptr1], %[b_ptr1], %[ldb]\n"
+ "add %[b_ptr0], %[b_ptr0], %[ldb]\n"
"zip1 z14.h, z15.h, z8.h\n"
- "add %[a_ptr0], %[a_ptr0], #0x20\n"
+ "add %[b_ptr1], %[b_ptr1], %[ldb]\n"
"zip2 z15.h, z15.h, z8.h\n"
- "add a_ptr1, a_ptr1, #0x20\n"
+ "add %[a_ptr0], %[a_ptr0], #0x20\n"
".inst 0x64624139 // bfdot z25.s, z9.h, z2.h[0]\n"
"ld1h z9.h, p4/z, [%[b_ptr0]]\n"
".inst 0x64604152 // bfdot z18.s, z10.h, z0.h[0]\n"
- "add a_ptr2, a_ptr2, #0x20\n"
+ "add a_ptr1, a_ptr1, #0x20\n"
".inst 0x64614156 // bfdot z22.s, z10.h, z1.h[0]\n"
+ "add a_ptr2, a_ptr2, #0x20\n"
".inst 0x6462415a // bfdot z26.s, z10.h, z2.h[0]\n"
"ld1h z10.h, p4/z, [%[b_ptr1]]\n"
".inst 0x64604173 // bfdot z19.s, z11.h, z0.h[0]\n"
@@ -1576,28 +1589,28 @@ void sve_native_bf16fp32_dot_4VLx4(const bfloat16 *A, int lda, const bfloat16 *B
".inst 0x647c41f3 // bfdot z19.s, z15.h, z4.h[3]\n"
".inst 0x647d41f7 // bfdot z23.s, z15.h, z5.h[3]\n"
".inst 0x647e41fb // bfdot z27.s, z15.h, z6.h[3]\n"
- "ld1h z15.h, p5/z, [%[b_ptr0], #1, MUL VL]\n"
"b.ne 2b\n"
"1:\n"
"zip1 z12.h, z13.h, z14.h\n"
"zip2 z13.h, z13.h, z14.h\n"
"cbz %[regs], 3f\n"
".inst 0x64604110 // bfdot z16.s, z8.h, z0.h[0]\n"
- "ld1rqh z4.h, p7/z, [%[a_ptr0]]\n"
+ "ld1h z15.h, p5/z, [%[b_ptr0], #1, MUL VL]\n"
".inst 0x64614114 // bfdot z20.s, z8.h, z1.h[0]\n"
- "ld1rqh z5.h, p7/z, [a_ptr1]\n"
+ "ld1rqh z4.h, p7/z, [%[a_ptr0]]\n"
".inst 0x64624118 // bfdot z24.s, z8.h, z2.h[0]\n"
- "ld1rqh z6.h, p7/z, [a_ptr2]\n"
+ "ld1rqh z5.h, p7/z, [a_ptr1]\n"
".inst 0x64604131 // bfdot z17.s, z9.h, z0.h[0]\n"
- "ld1h z8.h, p5/z, [%[b_ptr1], #1, MUL VL]\n"
+ "ld1rqh z6.h, p7/z, [a_ptr2]\n"
".inst 0x64614135 // bfdot z21.s, z9.h, z1.h[0]\n"
- "add %[b_ptr0], %[b_ptr0], %[ldb]\n"
+ "ld1h z8.h, p5/z, [%[b_ptr1], #1, MUL VL]\n"
".inst 0x64624139 // bfdot z25.s, z9.h, z2.h[0]\n"
+ "add %[b_ptr0], %[b_ptr0], %[ldb]\n"
+ ".inst 0x64604152 // bfdot z18.s, z10.h, z0.h[0]\n"
"ld1h z9.h, p4/z, [%[b_ptr0]]\n"
"zip1 z14.h, z15.h, z8.h\n"
"add %[b_ptr1], %[b_ptr1], %[ldb]\n"
"zip2 z15.h, z15.h, z8.h\n"
- ".inst 0x64604152 // bfdot z18.s, z10.h, z0.h[0]\n"
".inst 0x64614156 // bfdot z22.s, z10.h, z1.h[0]\n"
".inst 0x6462415a // bfdot z26.s, z10.h, z2.h[0]\n"
"ld1h z10.h, p4/z, [%[b_ptr1]]\n"
@@ -1922,35 +1935,36 @@ void sve_native_bf16fp32_dot_4VLx4(const bfloat16 *A, int lda, const bfloat16 *B
"b 7f\n"
"3:\n"
".inst 0x64604110 // bfdot z16.s, z8.h, z0.h[0]\n"
- "ld1rqh z4.h, p6/z, [%[a_ptr0]]\n"
+ "ld1h z15.h, p5/z, [%[b_ptr0], #1, MUL VL]\n"
".inst 0x64614114 // bfdot z20.s, z8.h, z1.h[0]\n"
- "ld1rqh z5.h, p6/z, [a_ptr1]\n"
+ "ld1rqh z4.h, p6/z, [%[a_ptr0]]\n"
".inst 0x64624118 // bfdot z24.s, z8.h, z2.h[0]\n"
"ld1h z8.h, p5/z, [%[b_ptr1], #1, MUL VL]\n"
".inst 0x64604131 // bfdot z17.s, z9.h, z0.h[0]\n"
- "ld1rqh z6.h, p6/z, [a_ptr2]\n"
+ "ld1rqh z5.h, p6/z, [a_ptr1]\n"
".inst 0x64614135 // bfdot z21.s, z9.h, z1.h[0]\n"
+ "ld1rqh z6.h, p6/z, [a_ptr2]\n"
+ ".inst 0x64624139 // bfdot z25.s, z9.h, z2.h[0]\n"
"add %[b_ptr0], %[b_ptr0], %[ldb]\n"
"zip1 z14.h, z15.h, z8.h\n"
- "add %[b_ptr1], %[b_ptr1], %[ldb]\n"
- "zip2 z15.h, z15.h, z8.h\n"
- "addvl %[a_ptr0], %[a_ptr0], #1\n"
- ".inst 0x64624139 // bfdot z25.s, z9.h, z2.h[0]\n"
"ld1h z9.h, p4/z, [%[b_ptr0]]\n"
+ "zip2 z15.h, z15.h, z8.h\n"
+ "add %[b_ptr1], %[b_ptr1], %[ldb]\n"
".inst 0x64604152 // bfdot z18.s, z10.h, z0.h[0]\n"
- "addvl a_ptr1, a_ptr1, #1\n"
+ "addvl %[a_ptr0], %[a_ptr0], #1\n"
".inst 0x64614156 // bfdot z22.s, z10.h, z1.h[0]\n"
- "addvl a_ptr2, a_ptr2, #1\n"
+ "addvl a_ptr1, a_ptr1, #1\n"
".inst 0x6462415a // bfdot z26.s, z10.h, z2.h[0]\n"
"ld1h z10.h, p4/z, [%[b_ptr1]]\n"
".inst 0x64604173 // bfdot z19.s, z11.h, z0.h[0]\n"
+ "addvl a_ptr2, a_ptr2, #1\n"
".inst 0x64614177 // bfdot z23.s, z11.h, z1.h[0]\n"
- ".inst 0x6462417b // bfdot z27.s, z11.h, z2.h[0]\n"
- "ld1h z11.h, p5/z, [%[b_ptr0], #1, MUL VL]\n"
"zip1 z8.h, z9.h, z10.h\n"
- "add %[b_ptr0], %[b_ptr0], %[ldb]\n"
"zip2 z9.h, z9.h, z10.h\n"
+ ".inst 0x6462417b // bfdot z27.s, z11.h, z2.h[0]\n"
+ "ld1h z11.h, p5/z, [%[b_ptr0], #1, MUL VL]\n"
".inst 0x64684190 // bfdot z16.s, z12.h, z0.h[1]\n"
+ "add %[b_ptr0], %[b_ptr0], %[ldb]\n"
".inst 0x64694194 // bfdot z20.s, z12.h, z1.h[1]\n"
".inst 0x646a4198 // bfdot z24.s, z12.h, z2.h[1]\n"
"ld1h z12.h, p5/z, [%[b_ptr1], #1, MUL VL]\n"
@@ -2276,7 +2290,6 @@ void sve_native_bf16fp32_dot_4VLx4(const bfloat16 *A, int lda, const bfloat16 *B
"add %[b_ptr0], %[b_ptr0], %[ldb]\n"
"mov z31.d, z19.d\n"
"ld1h z13.h, p4/z, [%[b_ptr0]]\n"
- "ld1h z15.h, p5/z, [%[b_ptr0], #1, MUL VL]\n"
"add %[b_ptr1], %[b_ptr1], %[ldb]\n"
"ld1h z14.h, p4/z, [%[b_ptr1]]\n"
"zip1 z12.h, z13.h, z14.h\n"
@@ -2284,38 +2297,39 @@ void sve_native_bf16fp32_dot_4VLx4(const bfloat16 *A, int lda, const bfloat16 *B
"cbz %[loops], 1f\n"
"2:\n"
".inst 0x64604110 // bfdot z16.s, z8.h, z0.h[0]\n"
- "ld1rqh z4.h, p7/z, [%[a_ptr0]]\n"
+ "ld1h z15.h, p5/z, [%[b_ptr0], #1, MUL VL]\n"
".inst 0x64614114 // bfdot z20.s, z8.h, z1.h[0]\n"
- "ld1rqh z5.h, p7/z, [a_ptr1]\n"
+ "ld1rqh z4.h, p7/z, [%[a_ptr0]]\n"
".inst 0x64624118 // bfdot z24.s, z8.h, z2.h[0]\n"
- "ld1rqh z6.h, p7/z, [a_ptr2]\n"
+ "ld1rqh z5.h, p7/z, [a_ptr1]\n"
".inst 0x6463411c // bfdot z28.s, z8.h, z3.h[0]\n"
- "ld1rqh z7.h, p7/z, [a_ptr3]\n"
+ "ld1rqh z6.h, p7/z, [a_ptr2]\n"
".inst 0x64604131 // bfdot z17.s, z9.h, z0.h[0]\n"
- "ld1h z8.h, p5/z, [%[b_ptr1], #1, MUL VL]\n"
+ "ld1rqh z7.h, p7/z, [a_ptr3]\n"
".inst 0x64614135 // bfdot z21.s, z9.h, z1.h[0]\n"
- "subs %[loops], %[loops], #0x1\n"
+ "ld1h z8.h, p5/z, [%[b_ptr1], #1, MUL VL]\n"
".inst 0x64624139 // bfdot z25.s, z9.h, z2.h[0]\n"
+ "subs %[loops], %[loops], #0x1\n"
+ ".inst 0x6463413d // bfdot z29.s, z9.h, z3.h[0]\n"
"add %[b_ptr0], %[b_ptr0], %[ldb]\n"
"zip1 z14.h, z15.h, z8.h\n"
- "add %[b_ptr1], %[b_ptr1], %[ldb]\n"
- "zip2 z15.h, z15.h, z8.h\n"
- "add %[a_ptr0], %[a_ptr0], #0x20\n"
- ".inst 0x6463413d // bfdot z29.s, z9.h, z3.h[0]\n"
"ld1h z9.h, p4/z, [%[b_ptr0]]\n"
+ "zip2 z15.h, z15.h, z8.h\n"
+ "add %[b_ptr1], %[b_ptr1], %[ldb]\n"
".inst 0x64604152 // bfdot z18.s, z10.h, z0.h[0]\n"
- "add a_ptr1, a_ptr1, #0x20\n"
+ "add %[a_ptr0], %[a_ptr0], #0x20\n"
".inst 0x64614156 // bfdot z22.s, z10.h, z1.h[0]\n"
- "add a_ptr2, a_ptr2, #0x20\n"
+ "add a_ptr1, a_ptr1, #0x20\n"
".inst 0x6462415a // bfdot z26.s, z10.h, z2.h[0]\n"
- "add a_ptr3, a_ptr3, #0x20\n"
+ "add a_ptr2, a_ptr2, #0x20\n"
".inst 0x6463415e // bfdot z30.s, z10.h, z3.h[0]\n"
"ld1h z10.h, p4/z, [%[b_ptr1]]\n"
".inst 0x64604173 // bfdot z19.s, z11.h, z0.h[0]\n"
+ "add a_ptr3, a_ptr3, #0x20\n"
".inst 0x64614177 // bfdot z23.s, z11.h, z1.h[0]\n"
- ".inst 0x6462417b // bfdot z27.s, z11.h, z2.h[0]\n"
"zip1 z8.h, z9.h, z10.h\n"
"zip2 z9.h, z9.h, z10.h\n"
+ ".inst 0x6462417b // bfdot z27.s, z11.h, z2.h[0]\n"
".inst 0x6463417f // bfdot z31.s, z11.h, z3.h[0]\n"
"ld1h z11.h, p5/z, [%[b_ptr0], #1, MUL VL]\n"
".inst 0x64684190 // bfdot z16.s, z12.h, z0.h[1]\n"
@@ -2503,28 +2517,28 @@ void sve_native_bf16fp32_dot_4VLx4(const bfloat16 *A, int lda, const bfloat16 *B
"zip1 z12.h, z13.h, z14.h\n"
"zip2 z13.h, z13.h, z14.h\n"
".inst 0x647f41ff // bfdot z31.s, z15.h, z7.h[3]\n"
- "ld1h z15.h, p5/z, [%[b_ptr0], #1, MUL VL]\n"
"b.ne 2b\n"
"1:\n"
"cbz %[regs], 3f\n"
".inst 0x64604110 // bfdot z16.s, z8.h, z0.h[0]\n"
- "ld1rqh z4.h, p7/z, [%[a_ptr0]]\n"
+ "ld1h z15.h, p5/z, [%[b_ptr0], #1, MUL VL]\n"
".inst 0x64614114 // bfdot z20.s, z8.h, z1.h[0]\n"
- "ld1rqh z5.h, p7/z, [a_ptr1]\n"
+ "ld1rqh z4.h, p7/z, [%[a_ptr0]]\n"
".inst 0x64624118 // bfdot z24.s, z8.h, z2.h[0]\n"
- "ld1rqh z6.h, p7/z, [a_ptr2]\n"
+ "ld1rqh z5.h, p7/z, [a_ptr1]\n"
".inst 0x6463411c // bfdot z28.s, z8.h, z3.h[0]\n"
- "ld1rqh z7.h, p7/z, [a_ptr3]\n"
+ "ld1rqh z6.h, p7/z, [a_ptr2]\n"
".inst 0x64604131 // bfdot z17.s, z9.h, z0.h[0]\n"
- "ld1h z8.h, p5/z, [%[b_ptr1], #1, MUL VL]\n"
+ "ld1rqh z7.h, p7/z, [a_ptr3]\n"
".inst 0x64614135 // bfdot z21.s, z9.h, z1.h[0]\n"
- "add %[b_ptr0], %[b_ptr0], %[ldb]\n"
+ "ld1h z8.h, p5/z, [%[b_ptr1], #1, MUL VL]\n"
".inst 0x64624139 // bfdot z25.s, z9.h, z2.h[0]\n"
- "add %[b_ptr1], %[b_ptr1], %[ldb]\n"
- "zip1 z14.h, z15.h, z8.h\n"
- "zip2 z15.h, z15.h, z8.h\n"
+ "add %[b_ptr0], %[b_ptr0], %[ldb]\n"
".inst 0x6463413d // bfdot z29.s, z9.h, z3.h[0]\n"
"ld1h z9.h, p4/z, [%[b_ptr0]]\n"
+ "zip1 z14.h, z15.h, z8.h\n"
+ "add %[b_ptr1], %[b_ptr1], %[ldb]\n"
+ "zip2 z15.h, z15.h, z8.h\n"
".inst 0x64604152 // bfdot z18.s, z10.h, z0.h[0]\n"
".inst 0x64614156 // bfdot z22.s, z10.h, z1.h[0]\n"
".inst 0x6462415a // bfdot z26.s, z10.h, z2.h[0]\n"
@@ -2910,30 +2924,31 @@ void sve_native_bf16fp32_dot_4VLx4(const bfloat16 *A, int lda, const bfloat16 *B
"b 7f\n"
"3:\n"
".inst 0x64604110 // bfdot z16.s, z8.h, z0.h[0]\n"
- "ld1rqh z4.h, p6/z, [%[a_ptr0]]\n"
+ "ld1h z15.h, p5/z, [%[b_ptr0], #1, MUL VL]\n"
".inst 0x64614114 // bfdot z20.s, z8.h, z1.h[0]\n"
- "ld1rqh z5.h, p6/z, [a_ptr1]\n"
+ "ld1rqh z4.h, p6/z, [%[a_ptr0]]\n"
".inst 0x64624118 // bfdot z24.s, z8.h, z2.h[0]\n"
- "ld1rqh z6.h, p6/z, [a_ptr2]\n"
+ "ld1rqh z5.h, p6/z, [a_ptr1]\n"
".inst 0x6463411c // bfdot z28.s, z8.h, z3.h[0]\n"
"ld1h z8.h, p5/z, [%[b_ptr1], #1, MUL VL]\n"
".inst 0x64604131 // bfdot z17.s, z9.h, z0.h[0]\n"
- "ld1rqh z7.h, p6/z, [a_ptr3]\n"
+ "ld1rqh z6.h, p6/z, [a_ptr2]\n"
".inst 0x64614135 // bfdot z21.s, z9.h, z1.h[0]\n"
+ "ld1rqh z7.h, p6/z, [a_ptr3]\n"
+ ".inst 0x64624139 // bfdot z25.s, z9.h, z2.h[0]\n"
"add %[b_ptr0], %[b_ptr0], %[ldb]\n"
"zip1 z14.h, z15.h, z8.h\n"
"add %[b_ptr1], %[b_ptr1], %[ldb]\n"
"zip2 z15.h, z15.h, z8.h\n"
"addvl %[a_ptr0], %[a_ptr0], #1\n"
- ".inst 0x64624139 // bfdot z25.s, z9.h, z2.h[0]\n"
- "addvl a_ptr1, a_ptr1, #1\n"
".inst 0x6463413d // bfdot z29.s, z9.h, z3.h[0]\n"
"ld1h z9.h, p4/z, [%[b_ptr0]]\n"
".inst 0x64604152 // bfdot z18.s, z10.h, z0.h[0]\n"
- "addvl a_ptr2, a_ptr2, #1\n"
+ "addvl a_ptr1, a_ptr1, #1\n"
".inst 0x64614156 // bfdot z22.s, z10.h, z1.h[0]\n"
- "addvl a_ptr3, a_ptr3, #1\n"
+ "addvl a_ptr2, a_ptr2, #1\n"
".inst 0x6462415a // bfdot z26.s, z10.h, z2.h[0]\n"
+ "addvl a_ptr3, a_ptr3, #1\n"
".inst 0x6463415e // bfdot z30.s, z10.h, z3.h[0]\n"
"ld1h z10.h, p4/z, [%[b_ptr1]]\n"
".inst 0x64604173 // bfdot z19.s, z11.h, z0.h[0]\n"