aboutsummaryrefslogtreecommitdiff
path: root/src/core/NEON/kernels/arm_gemm/kernels/a64_hybrid_s8s32_dot_16x4/generic.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/core/NEON/kernels/arm_gemm/kernels/a64_hybrid_s8s32_dot_16x4/generic.cpp')
-rw-r--r--src/core/NEON/kernels/arm_gemm/kernels/a64_hybrid_s8s32_dot_16x4/generic.cpp321
1 files changed, 214 insertions, 107 deletions
diff --git a/src/core/NEON/kernels/arm_gemm/kernels/a64_hybrid_s8s32_dot_16x4/generic.cpp b/src/core/NEON/kernels/arm_gemm/kernels/a64_hybrid_s8s32_dot_16x4/generic.cpp
index b48b674621..9f06a48ff5 100644
--- a/src/core/NEON/kernels/arm_gemm/kernels/a64_hybrid_s8s32_dot_16x4/generic.cpp
+++ b/src/core/NEON/kernels/arm_gemm/kernels/a64_hybrid_s8s32_dot_16x4/generic.cpp
@@ -35,7 +35,6 @@ namespace arm_gemm {
void a64_hybrid_s8s32_dot_16x4(const int8_t *A, int lda, const int8_t *B, int32_t *C, int ldc, int M, int N, int K, const int32_t *bias, Activation act, bool append) {
UNUSED(bias);
UNUSED(act);
-
const int K_stride = ((K + 3) / 4) * 4;
const long loops_count = ((K + 16) / 32) - 1;
K -= loops_count * 32;
@@ -76,6 +75,7 @@ void a64_hybrid_s8s32_dot_16x4(const int8_t *A, int lda, const int8_t *B, int32_
switch(M-y) {
case 1:
__asm __volatile (
+ "cbnz %[append], 1f\n"
"movi v16.4s, #0\n"
"ldr q0, [%[a_ptr0]]\n"
"movi v17.4s, #0\n"
@@ -90,8 +90,25 @@ void a64_hybrid_s8s32_dot_16x4(const int8_t *A, int lda, const int8_t *B, int32_
"ldr q13, [%[b_ptr0], #0x50]\n"
"ldr q14, [%[b_ptr0], #0x60]\n"
"add %[b_ptr0], %[b_ptr0], #0x80\n"
- "cbz %[loops], 1f\n"
- "2:\n"
+ "cbz %[loops], 2f\n"
+ "b 3f\n"
+ "1:\n"
+ "ldr q16, [%[c_ptr0]]\n"
+ "ldr q17, [%[c_ptr0], #0x10]\n"
+ "ldr q18, [%[c_ptr0], #0x20]\n"
+ "ldr q19, [%[c_ptr0], #0x30]\n"
+ "ldr q0, [%[a_ptr0]]\n"
+ "add %[a_ptr0], %[a_ptr0], #0x10\n"
+ "ldr q8, [%[b_ptr0]]\n"
+ "ldr q9, [%[b_ptr0], #0x10]\n"
+ "ldr q10, [%[b_ptr0], #0x20]\n"
+ "ldr q11, [%[b_ptr0], #0x30]\n"
+ "ldr q12, [%[b_ptr0], #0x40]\n"
+ "ldr q13, [%[b_ptr0], #0x50]\n"
+ "ldr q14, [%[b_ptr0], #0x60]\n"
+ "add %[b_ptr0], %[b_ptr0], #0x80\n"
+ "cbz %[loops], 2f\n"
+ "3:\n"
".inst 0x4f80e110 // sdot v16.4s, v8.16b, v0.4b[0]\n"
"ldr q15, [%[b_ptr0], #-0x10]\n"
".inst 0x4f80e131 // sdot v17.4s, v9.16b, v0.4b[0]\n"
@@ -163,11 +180,11 @@ void a64_hybrid_s8s32_dot_16x4(const int8_t *A, int lda, const int8_t *B, int32_
"ldr q12, [%[b_ptr0], #-0x40]\n"
"ldr q13, [%[b_ptr0], #-0x30]\n"
"ldr q14, [%[b_ptr0], #-0x20]\n"
- "b.ne 2b\n"
- "1:\n"
+ "b.ne 3b\n"
+ "2:\n"
"ldr q15, [%[b_ptr0], #-0x10]\n"
"prfm PSTL1KEEP, [%[c_ptr0]]\n"
- "cbz %[regs], 3f\n"
+ "cbz %[regs], 4f\n"
".inst 0x4f80e110 // sdot v16.4s, v8.16b, v0.4b[0]\n"
"ldr q4, [%[a_ptr0]]\n"
".inst 0x4f80e131 // sdot v17.4s, v9.16b, v0.4b[0]\n"
@@ -228,8 +245,8 @@ void a64_hybrid_s8s32_dot_16x4(const int8_t *A, int lda, const int8_t *B, int32_
".inst 0x4fa4e9b1 // sdot v17.4s, v13.16b, v4.4b[3]\n"
".inst 0x4fa4e9d2 // sdot v18.4s, v14.16b, v4.4b[3]\n"
".inst 0x4fa4e9f3 // sdot v19.4s, v15.16b, v4.4b[3]\n"
- "b 4f\n"
- "3:\n"
+ "b 5f\n"
+ "4:\n"
".inst 0x4f80e110 // sdot v16.4s, v8.16b, v0.4b[0]\n"
"ldr q8, [%[b_ptr0]]\n"
".inst 0x4f80e131 // sdot v17.4s, v9.16b, v0.4b[0]\n"
@@ -255,9 +272,9 @@ void a64_hybrid_s8s32_dot_16x4(const int8_t *A, int lda, const int8_t *B, int32_
".inst 0x4fa0e9b1 // sdot v17.4s, v13.16b, v0.4b[3]\n"
".inst 0x4fa0e9d2 // sdot v18.4s, v14.16b, v0.4b[3]\n"
".inst 0x4fa0e9f3 // sdot v19.4s, v15.16b, v0.4b[3]\n"
- "4:\n"
- "cbz %[blocks], 5f\n"
- "6:\n"
+ "5:\n"
+ "cbz %[blocks], 6f\n"
+ "7:\n"
"ldr q8, [%[b_ptr0]]\n"
"subs %[blocks], %[blocks], #0x1\n"
"ldr q9, [%[b_ptr0], #0x10]\n"
@@ -270,17 +287,17 @@ void a64_hybrid_s8s32_dot_16x4(const int8_t *A, int lda, const int8_t *B, int32_
".inst 0x4f80e131 // sdot v17.4s, v9.16b, v0.4b[0]\n"
".inst 0x4f80e152 // sdot v18.4s, v10.16b, v0.4b[0]\n"
".inst 0x4f80e173 // sdot v19.4s, v11.16b, v0.4b[0]\n"
- "b.ne 6b\n"
- "5:\n"
- "cbz %[odds], 7f\n"
+ "b.ne 7b\n"
+ "6:\n"
+ "cbz %[odds], 8f\n"
"ld1 {v0.b}[0], [%[a_ptr0]], #1\n"
"subs %[odds], %[odds], #0x1\n"
- "b.eq 8f\n"
+ "b.eq 9f\n"
"ld1 {v0.b}[1], [%[a_ptr0]], #1\n"
"subs %[odds], %[odds], #0x1\n"
- "b.eq 8f\n"
+ "b.eq 9f\n"
"ld1 {v0.b}[2], [%[a_ptr0]]\n"
- "8:\n"
+ "9:\n"
"ldr q8, [%[b_ptr0]]\n"
"ldr q9, [%[b_ptr0], #0x10]\n"
"ldr q10, [%[b_ptr0], #0x20]\n"
@@ -289,7 +306,7 @@ void a64_hybrid_s8s32_dot_16x4(const int8_t *A, int lda, const int8_t *B, int32_
".inst 0x4f80e131 // sdot v17.4s, v9.16b, v0.4b[0]\n"
".inst 0x4f80e152 // sdot v18.4s, v10.16b, v0.4b[0]\n"
".inst 0x4f80e173 // sdot v19.4s, v11.16b, v0.4b[0]\n"
- "7:\n"
+ "8:\n"
"str q16, [%[c_ptr0]]\n"
"str q17, [%[c_ptr0], #0x10]\n"
"str q18, [%[c_ptr0], #0x20]\n"
@@ -304,30 +321,54 @@ void a64_hybrid_s8s32_dot_16x4(const int8_t *A, int lda, const int8_t *B, int32_
__asm __volatile (
"a_ptr1 .req X0\n"
"c_ptr1 .req X1\n"
+ "add a_ptr1, %[a_ptr0], %[lda]\n"
+ "add c_ptr1, %[c_ptr0], %[ldc]\n"
+ "cbnz %[append], 1f\n"
"movi v16.4s, #0\n"
"ldr q0, [%[a_ptr0]]\n"
"movi v17.4s, #0\n"
- "ldr q8, [%[b_ptr0]]\n"
+ "ldr q1, [a_ptr1]\n"
"movi v18.4s, #0\n"
- "ldr q9, [%[b_ptr0], #0x10]\n"
+ "ldr q8, [%[b_ptr0]]\n"
"movi v19.4s, #0\n"
- "ldr q10, [%[b_ptr0], #0x20]\n"
+ "ldr q9, [%[b_ptr0], #0x10]\n"
"movi v20.4s, #0\n"
- "ldr q11, [%[b_ptr0], #0x30]\n"
+ "ldr q10, [%[b_ptr0], #0x20]\n"
"movi v21.4s, #0\n"
- "ldr q12, [%[b_ptr0], #0x40]\n"
+ "ldr q11, [%[b_ptr0], #0x30]\n"
"movi v22.4s, #0\n"
- "ldr q13, [%[b_ptr0], #0x50]\n"
+ "ldr q12, [%[b_ptr0], #0x40]\n"
"movi v23.4s, #0\n"
+ "ldr q13, [%[b_ptr0], #0x50]\n"
"ldr q14, [%[b_ptr0], #0x60]\n"
- "add a_ptr1, %[a_ptr0], %[lda]\n"
- "add c_ptr1, %[c_ptr0], %[ldc]\n"
- "ldr q1, [a_ptr1]\n"
"add %[a_ptr0], %[a_ptr0], #0x10\n"
"add a_ptr1, a_ptr1, #0x10\n"
"add %[b_ptr0], %[b_ptr0], #0x80\n"
- "cbz %[loops], 1f\n"
- "2:\n"
+ "cbz %[loops], 2f\n"
+ "b 3f\n"
+ "1:\n"
+ "ldr q16, [%[c_ptr0]]\n"
+ "ldr q17, [%[c_ptr0], #0x10]\n"
+ "ldr q18, [%[c_ptr0], #0x20]\n"
+ "ldr q19, [%[c_ptr0], #0x30]\n"
+ "ldr q20, [c_ptr1]\n"
+ "ldr q21, [c_ptr1, #0x10]\n"
+ "ldr q22, [c_ptr1, #0x20]\n"
+ "ldr q23, [c_ptr1, #0x30]\n"
+ "ldr q0, [%[a_ptr0]]\n"
+ "add %[a_ptr0], %[a_ptr0], #0x10\n"
+ "ldr q1, [a_ptr1]\n"
+ "add a_ptr1, a_ptr1, #0x10\n"
+ "ldr q8, [%[b_ptr0]]\n"
+ "ldr q9, [%[b_ptr0], #0x10]\n"
+ "ldr q10, [%[b_ptr0], #0x20]\n"
+ "ldr q11, [%[b_ptr0], #0x30]\n"
+ "ldr q12, [%[b_ptr0], #0x40]\n"
+ "ldr q13, [%[b_ptr0], #0x50]\n"
+ "ldr q14, [%[b_ptr0], #0x60]\n"
+ "add %[b_ptr0], %[b_ptr0], #0x80\n"
+ "cbz %[loops], 2f\n"
+ "3:\n"
".inst 0x4f80e110 // sdot v16.4s, v8.16b, v0.4b[0]\n"
"ldr q15, [%[b_ptr0], #-0x10]\n"
".inst 0x4f81e114 // sdot v20.4s, v8.16b, v1.4b[0]\n"
@@ -435,12 +476,12 @@ void a64_hybrid_s8s32_dot_16x4(const int8_t *A, int lda, const int8_t *B, int32_
"ldr q14, [%[b_ptr0], #-0x20]\n"
".inst 0x4fa4e9f3 // sdot v19.4s, v15.16b, v4.4b[3]\n"
".inst 0x4fa5e9f7 // sdot v23.4s, v15.16b, v5.4b[3]\n"
- "b.ne 2b\n"
- "1:\n"
+ "b.ne 3b\n"
+ "2:\n"
"ldr q15, [%[b_ptr0], #-0x10]\n"
"prfm PSTL1KEEP, [%[c_ptr0]]\n"
"prfm PSTL1KEEP, [c_ptr1]\n"
- "cbz %[regs], 3f\n"
+ "cbz %[regs], 4f\n"
".inst 0x4f80e110 // sdot v16.4s, v8.16b, v0.4b[0]\n"
"ldr q4, [%[a_ptr0]]\n"
".inst 0x4f81e114 // sdot v20.4s, v8.16b, v1.4b[0]\n"
@@ -535,8 +576,8 @@ void a64_hybrid_s8s32_dot_16x4(const int8_t *A, int lda, const int8_t *B, int32_
".inst 0x4fa5e9d6 // sdot v22.4s, v14.16b, v5.4b[3]\n"
".inst 0x4fa4e9f3 // sdot v19.4s, v15.16b, v4.4b[3]\n"
".inst 0x4fa5e9f7 // sdot v23.4s, v15.16b, v5.4b[3]\n"
- "b 4f\n"
- "3:\n"
+ "b 5f\n"
+ "4:\n"
".inst 0x4f80e110 // sdot v16.4s, v8.16b, v0.4b[0]\n"
".inst 0x4f81e114 // sdot v20.4s, v8.16b, v1.4b[0]\n"
"ldr q8, [%[b_ptr0]]\n"
@@ -578,9 +619,9 @@ void a64_hybrid_s8s32_dot_16x4(const int8_t *A, int lda, const int8_t *B, int32_
".inst 0x4fa1e9d6 // sdot v22.4s, v14.16b, v1.4b[3]\n"
".inst 0x4fa0e9f3 // sdot v19.4s, v15.16b, v0.4b[3]\n"
".inst 0x4fa1e9f7 // sdot v23.4s, v15.16b, v1.4b[3]\n"
- "4:\n"
- "cbz %[blocks], 5f\n"
- "6:\n"
+ "5:\n"
+ "cbz %[blocks], 6f\n"
+ "7:\n"
"ldr q8, [%[b_ptr0]]\n"
"subs %[blocks], %[blocks], #0x1\n"
"ldr q9, [%[b_ptr0], #0x10]\n"
@@ -599,20 +640,20 @@ void a64_hybrid_s8s32_dot_16x4(const int8_t *A, int lda, const int8_t *B, int32_
".inst 0x4f81e156 // sdot v22.4s, v10.16b, v1.4b[0]\n"
".inst 0x4f80e173 // sdot v19.4s, v11.16b, v0.4b[0]\n"
".inst 0x4f81e177 // sdot v23.4s, v11.16b, v1.4b[0]\n"
- "b.ne 6b\n"
- "5:\n"
- "cbz %[odds], 7f\n"
+ "b.ne 7b\n"
+ "6:\n"
+ "cbz %[odds], 8f\n"
"ld1 {v0.b}[0], [%[a_ptr0]], #1\n"
"ld1 {v1.b}[0], [a_ptr1], #1\n"
"subs %[odds], %[odds], #0x1\n"
- "b.eq 8f\n"
+ "b.eq 9f\n"
"ld1 {v0.b}[1], [%[a_ptr0]], #1\n"
"ld1 {v1.b}[1], [a_ptr1], #1\n"
"subs %[odds], %[odds], #0x1\n"
- "b.eq 8f\n"
+ "b.eq 9f\n"
"ld1 {v0.b}[2], [%[a_ptr0]]\n"
"ld1 {v1.b}[2], [a_ptr1]\n"
- "8:\n"
+ "9:\n"
"ldr q8, [%[b_ptr0]]\n"
"ldr q9, [%[b_ptr0], #0x10]\n"
"ldr q10, [%[b_ptr0], #0x20]\n"
@@ -625,7 +666,7 @@ void a64_hybrid_s8s32_dot_16x4(const int8_t *A, int lda, const int8_t *B, int32_
".inst 0x4f81e156 // sdot v22.4s, v10.16b, v1.4b[0]\n"
".inst 0x4f80e173 // sdot v19.4s, v11.16b, v0.4b[0]\n"
".inst 0x4f81e177 // sdot v23.4s, v11.16b, v1.4b[0]\n"
- "7:\n"
+ "8:\n"
"str q16, [%[c_ptr0]]\n"
"str q17, [%[c_ptr0], #0x10]\n"
"str q18, [%[c_ptr0], #0x20]\n"
@@ -648,38 +689,68 @@ void a64_hybrid_s8s32_dot_16x4(const int8_t *A, int lda, const int8_t *B, int32_
"a_ptr2 .req X1\n"
"c_ptr1 .req X2\n"
"c_ptr2 .req X3\n"
+ "add a_ptr1, %[a_ptr0], %[lda]\n"
+ "add c_ptr1, %[c_ptr0], %[ldc]\n"
+ "add a_ptr2, a_ptr1, %[lda]\n"
+ "add c_ptr2, c_ptr1, %[ldc]\n"
+ "cbnz %[append], 1f\n"
"movi v16.4s, #0\n"
"ldr q0, [%[a_ptr0]]\n"
"movi v17.4s, #0\n"
- "ldr q8, [%[b_ptr0]]\n"
+ "ldr q1, [a_ptr1]\n"
"movi v18.4s, #0\n"
- "ldr q9, [%[b_ptr0], #0x10]\n"
+ "ldr q2, [a_ptr2]\n"
"movi v19.4s, #0\n"
- "ldr q10, [%[b_ptr0], #0x20]\n"
+ "ldr q8, [%[b_ptr0]]\n"
"movi v20.4s, #0\n"
- "ldr q11, [%[b_ptr0], #0x30]\n"
+ "ldr q9, [%[b_ptr0], #0x10]\n"
"movi v21.4s, #0\n"
- "ldr q12, [%[b_ptr0], #0x40]\n"
+ "ldr q10, [%[b_ptr0], #0x20]\n"
"movi v22.4s, #0\n"
- "ldr q13, [%[b_ptr0], #0x50]\n"
+ "ldr q11, [%[b_ptr0], #0x30]\n"
"movi v23.4s, #0\n"
- "ldr q14, [%[b_ptr0], #0x60]\n"
+ "ldr q12, [%[b_ptr0], #0x40]\n"
"movi v24.4s, #0\n"
- "add a_ptr1, %[a_ptr0], %[lda]\n"
+ "ldr q13, [%[b_ptr0], #0x50]\n"
"movi v25.4s, #0\n"
- "ldr q1, [a_ptr1]\n"
+ "ldr q14, [%[b_ptr0], #0x60]\n"
"movi v26.4s, #0\n"
- "add a_ptr2, a_ptr1, %[lda]\n"
+ "add %[a_ptr0], %[a_ptr0], #0x10\n"
"movi v27.4s, #0\n"
- "ldr q2, [a_ptr2]\n"
- "add c_ptr1, %[c_ptr0], %[ldc]\n"
+ "add a_ptr1, a_ptr1, #0x10\n"
+ "add a_ptr2, a_ptr2, #0x10\n"
+ "add %[b_ptr0], %[b_ptr0], #0x80\n"
+ "cbz %[loops], 2f\n"
+ "b 3f\n"
+ "1:\n"
+ "ldr q16, [%[c_ptr0]]\n"
+ "ldr q17, [%[c_ptr0], #0x10]\n"
+ "ldr q18, [%[c_ptr0], #0x20]\n"
+ "ldr q19, [%[c_ptr0], #0x30]\n"
+ "ldr q20, [c_ptr1]\n"
+ "ldr q21, [c_ptr1, #0x10]\n"
+ "ldr q22, [c_ptr1, #0x20]\n"
+ "ldr q23, [c_ptr1, #0x30]\n"
+ "ldr q24, [c_ptr2]\n"
+ "ldr q25, [c_ptr2, #0x10]\n"
+ "ldr q26, [c_ptr2, #0x20]\n"
+ "ldr q27, [c_ptr2, #0x30]\n"
+ "ldr q0, [%[a_ptr0]]\n"
"add %[a_ptr0], %[a_ptr0], #0x10\n"
- "add c_ptr2, c_ptr1, %[ldc]\n"
+ "ldr q1, [a_ptr1]\n"
"add a_ptr1, a_ptr1, #0x10\n"
+ "ldr q2, [a_ptr2]\n"
"add a_ptr2, a_ptr2, #0x10\n"
+ "ldr q8, [%[b_ptr0]]\n"
+ "ldr q9, [%[b_ptr0], #0x10]\n"
+ "ldr q10, [%[b_ptr0], #0x20]\n"
+ "ldr q11, [%[b_ptr0], #0x30]\n"
+ "ldr q12, [%[b_ptr0], #0x40]\n"
+ "ldr q13, [%[b_ptr0], #0x50]\n"
+ "ldr q14, [%[b_ptr0], #0x60]\n"
"add %[b_ptr0], %[b_ptr0], #0x80\n"
- "cbz %[loops], 1f\n"
- "2:\n"
+ "cbz %[loops], 2f\n"
+ "3:\n"
".inst 0x4f80e110 // sdot v16.4s, v8.16b, v0.4b[0]\n"
"ldr q15, [%[b_ptr0], #-0x10]\n"
".inst 0x4f81e114 // sdot v20.4s, v8.16b, v1.4b[0]\n"
@@ -823,13 +894,13 @@ void a64_hybrid_s8s32_dot_16x4(const int8_t *A, int lda, const int8_t *B, int32_
".inst 0x4fa4e9f3 // sdot v19.4s, v15.16b, v4.4b[3]\n"
".inst 0x4fa5e9f7 // sdot v23.4s, v15.16b, v5.4b[3]\n"
".inst 0x4fa6e9fb // sdot v27.4s, v15.16b, v6.4b[3]\n"
- "b.ne 2b\n"
- "1:\n"
+ "b.ne 3b\n"
+ "2:\n"
"ldr q15, [%[b_ptr0], #-0x10]\n"
"prfm PSTL1KEEP, [%[c_ptr0]]\n"
"prfm PSTL1KEEP, [c_ptr1]\n"
"prfm PSTL1KEEP, [c_ptr2]\n"
- "cbz %[regs], 3f\n"
+ "cbz %[regs], 4f\n"
".inst 0x4f80e110 // sdot v16.4s, v8.16b, v0.4b[0]\n"
"ldr q4, [%[a_ptr0]]\n"
".inst 0x4f81e114 // sdot v20.4s, v8.16b, v1.4b[0]\n"
@@ -958,8 +1029,8 @@ void a64_hybrid_s8s32_dot_16x4(const int8_t *A, int lda, const int8_t *B, int32_
".inst 0x4fa4e9f3 // sdot v19.4s, v15.16b, v4.4b[3]\n"
".inst 0x4fa5e9f7 // sdot v23.4s, v15.16b, v5.4b[3]\n"
".inst 0x4fa6e9fb // sdot v27.4s, v15.16b, v6.4b[3]\n"
- "b 4f\n"
- "3:\n"
+ "b 5f\n"
+ "4:\n"
".inst 0x4f80e110 // sdot v16.4s, v8.16b, v0.4b[0]\n"
".inst 0x4f81e114 // sdot v20.4s, v8.16b, v1.4b[0]\n"
".inst 0x4f82e118 // sdot v24.4s, v8.16b, v2.4b[0]\n"
@@ -1017,9 +1088,9 @@ void a64_hybrid_s8s32_dot_16x4(const int8_t *A, int lda, const int8_t *B, int32_
".inst 0x4fa0e9f3 // sdot v19.4s, v15.16b, v0.4b[3]\n"
".inst 0x4fa1e9f7 // sdot v23.4s, v15.16b, v1.4b[3]\n"
".inst 0x4fa2e9fb // sdot v27.4s, v15.16b, v2.4b[3]\n"
- "4:\n"
- "cbz %[blocks], 5f\n"
- "6:\n"
+ "5:\n"
+ "cbz %[blocks], 6f\n"
+ "7:\n"
"ldr q8, [%[b_ptr0]]\n"
"subs %[blocks], %[blocks], #0x1\n"
"ldr q9, [%[b_ptr0], #0x10]\n"
@@ -1044,23 +1115,23 @@ void a64_hybrid_s8s32_dot_16x4(const int8_t *A, int lda, const int8_t *B, int32_
".inst 0x4f80e173 // sdot v19.4s, v11.16b, v0.4b[0]\n"
".inst 0x4f81e177 // sdot v23.4s, v11.16b, v1.4b[0]\n"
".inst 0x4f82e17b // sdot v27.4s, v11.16b, v2.4b[0]\n"
- "b.ne 6b\n"
- "5:\n"
- "cbz %[odds], 7f\n"
+ "b.ne 7b\n"
+ "6:\n"
+ "cbz %[odds], 8f\n"
"ld1 {v0.b}[0], [%[a_ptr0]], #1\n"
"ld1 {v1.b}[0], [a_ptr1], #1\n"
"ld1 {v2.b}[0], [a_ptr2], #1\n"
"subs %[odds], %[odds], #0x1\n"
- "b.eq 8f\n"
+ "b.eq 9f\n"
"ld1 {v0.b}[1], [%[a_ptr0]], #1\n"
"ld1 {v1.b}[1], [a_ptr1], #1\n"
"ld1 {v2.b}[1], [a_ptr2], #1\n"
"subs %[odds], %[odds], #0x1\n"
- "b.eq 8f\n"
+ "b.eq 9f\n"
"ld1 {v0.b}[2], [%[a_ptr0]]\n"
"ld1 {v1.b}[2], [a_ptr1]\n"
"ld1 {v2.b}[2], [a_ptr2]\n"
- "8:\n"
+ "9:\n"
"ldr q8, [%[b_ptr0]]\n"
"ldr q9, [%[b_ptr0], #0x10]\n"
"ldr q10, [%[b_ptr0], #0x20]\n"
@@ -1077,7 +1148,7 @@ void a64_hybrid_s8s32_dot_16x4(const int8_t *A, int lda, const int8_t *B, int32_
".inst 0x4f80e173 // sdot v19.4s, v11.16b, v0.4b[0]\n"
".inst 0x4f81e177 // sdot v23.4s, v11.16b, v1.4b[0]\n"
".inst 0x4f82e17b // sdot v27.4s, v11.16b, v2.4b[0]\n"
- "7:\n"
+ "8:\n"
"str q16, [%[c_ptr0]]\n"
"str q17, [%[c_ptr0], #0x10]\n"
"str q18, [%[c_ptr0], #0x20]\n"
@@ -1109,46 +1180,82 @@ void a64_hybrid_s8s32_dot_16x4(const int8_t *A, int lda, const int8_t *B, int32_
"c_ptr1 .req X3\n"
"c_ptr2 .req X4\n"
"c_ptr3 .req X5\n"
+ "add a_ptr1, %[a_ptr0], %[lda]\n"
+ "add c_ptr1, %[c_ptr0], %[ldc]\n"
+ "add a_ptr2, a_ptr1, %[lda]\n"
+ "add c_ptr2, c_ptr1, %[ldc]\n"
+ "add a_ptr3, a_ptr2, %[lda]\n"
+ "add c_ptr3, c_ptr2, %[ldc]\n"
+ "cbnz %[append], 1f\n"
"movi v16.4s, #0\n"
"ldr q0, [%[a_ptr0]]\n"
"movi v17.4s, #0\n"
- "ldr q8, [%[b_ptr0]]\n"
+ "ldr q1, [a_ptr1]\n"
"movi v18.4s, #0\n"
- "ldr q9, [%[b_ptr0], #0x10]\n"
+ "ldr q2, [a_ptr2]\n"
"movi v19.4s, #0\n"
- "ldr q10, [%[b_ptr0], #0x20]\n"
+ "ldr q3, [a_ptr3]\n"
"movi v20.4s, #0\n"
- "ldr q11, [%[b_ptr0], #0x30]\n"
+ "ldr q8, [%[b_ptr0]]\n"
"movi v21.4s, #0\n"
- "ldr q12, [%[b_ptr0], #0x40]\n"
+ "ldr q9, [%[b_ptr0], #0x10]\n"
"movi v22.4s, #0\n"
- "ldr q13, [%[b_ptr0], #0x50]\n"
+ "ldr q10, [%[b_ptr0], #0x20]\n"
"movi v23.4s, #0\n"
- "ldr q14, [%[b_ptr0], #0x60]\n"
+ "ldr q11, [%[b_ptr0], #0x30]\n"
"movi v24.4s, #0\n"
- "add a_ptr1, %[a_ptr0], %[lda]\n"
+ "ldr q12, [%[b_ptr0], #0x40]\n"
"movi v25.4s, #0\n"
- "ldr q1, [a_ptr1]\n"
+ "ldr q13, [%[b_ptr0], #0x50]\n"
"movi v26.4s, #0\n"
- "add a_ptr2, a_ptr1, %[lda]\n"
+ "ldr q14, [%[b_ptr0], #0x60]\n"
"movi v27.4s, #0\n"
- "ldr q2, [a_ptr2]\n"
+ "add %[a_ptr0], %[a_ptr0], #0x10\n"
"movi v28.4s, #0\n"
- "add a_ptr3, a_ptr2, %[lda]\n"
+ "add a_ptr1, a_ptr1, #0x10\n"
"movi v29.4s, #0\n"
- "ldr q3, [a_ptr3]\n"
+ "add a_ptr2, a_ptr2, #0x10\n"
"movi v30.4s, #0\n"
- "add c_ptr1, %[c_ptr0], %[ldc]\n"
+ "add a_ptr3, a_ptr3, #0x10\n"
"movi v31.4s, #0\n"
- "add c_ptr2, c_ptr1, %[ldc]\n"
+ "add %[b_ptr0], %[b_ptr0], #0x80\n"
+ "cbz %[loops], 2f\n"
+ "b 3f\n"
+ "1:\n"
+ "ldr q16, [%[c_ptr0]]\n"
+ "ldr q17, [%[c_ptr0], #0x10]\n"
+ "ldr q18, [%[c_ptr0], #0x20]\n"
+ "ldr q19, [%[c_ptr0], #0x30]\n"
+ "ldr q20, [c_ptr1]\n"
+ "ldr q21, [c_ptr1, #0x10]\n"
+ "ldr q22, [c_ptr1, #0x20]\n"
+ "ldr q23, [c_ptr1, #0x30]\n"
+ "ldr q24, [c_ptr2]\n"
+ "ldr q25, [c_ptr2, #0x10]\n"
+ "ldr q26, [c_ptr2, #0x20]\n"
+ "ldr q27, [c_ptr2, #0x30]\n"
+ "ldr q28, [c_ptr3]\n"
+ "ldr q29, [c_ptr3, #0x10]\n"
+ "ldr q30, [c_ptr3, #0x20]\n"
+ "ldr q31, [c_ptr3, #0x30]\n"
+ "ldr q0, [%[a_ptr0]]\n"
"add %[a_ptr0], %[a_ptr0], #0x10\n"
- "add c_ptr3, c_ptr2, %[ldc]\n"
+ "ldr q1, [a_ptr1]\n"
"add a_ptr1, a_ptr1, #0x10\n"
+ "ldr q2, [a_ptr2]\n"
"add a_ptr2, a_ptr2, #0x10\n"
+ "ldr q3, [a_ptr3]\n"
"add a_ptr3, a_ptr3, #0x10\n"
+ "ldr q8, [%[b_ptr0]]\n"
+ "ldr q9, [%[b_ptr0], #0x10]\n"
+ "ldr q10, [%[b_ptr0], #0x20]\n"
+ "ldr q11, [%[b_ptr0], #0x30]\n"
+ "ldr q12, [%[b_ptr0], #0x40]\n"
+ "ldr q13, [%[b_ptr0], #0x50]\n"
+ "ldr q14, [%[b_ptr0], #0x60]\n"
"add %[b_ptr0], %[b_ptr0], #0x80\n"
- "cbz %[loops], 1f\n"
- "2:\n"
+ "cbz %[loops], 2f\n"
+ "3:\n"
".inst 0x4f80e110 // sdot v16.4s, v8.16b, v0.4b[0]\n"
"ldr q15, [%[b_ptr0], #-0x10]\n"
".inst 0x4f81e114 // sdot v20.4s, v8.16b, v1.4b[0]\n"
@@ -1328,14 +1435,14 @@ void a64_hybrid_s8s32_dot_16x4(const int8_t *A, int lda, const int8_t *B, int32_
".inst 0x4fa5e9f7 // sdot v23.4s, v15.16b, v5.4b[3]\n"
".inst 0x4fa6e9fb // sdot v27.4s, v15.16b, v6.4b[3]\n"
".inst 0x4fa7e9ff // sdot v31.4s, v15.16b, v7.4b[3]\n"
- "b.ne 2b\n"
- "1:\n"
+ "b.ne 3b\n"
+ "2:\n"
"ldr q15, [%[b_ptr0], #-0x10]\n"
"prfm PSTL1KEEP, [%[c_ptr0]]\n"
"prfm PSTL1KEEP, [c_ptr1]\n"
"prfm PSTL1KEEP, [c_ptr2]\n"
"prfm PSTL1KEEP, [c_ptr3]\n"
- "cbz %[regs], 3f\n"
+ "cbz %[regs], 4f\n"
".inst 0x4f80e110 // sdot v16.4s, v8.16b, v0.4b[0]\n"
"ldr q4, [%[a_ptr0]]\n"
".inst 0x4f81e114 // sdot v20.4s, v8.16b, v1.4b[0]\n"
@@ -1498,8 +1605,8 @@ void a64_hybrid_s8s32_dot_16x4(const int8_t *A, int lda, const int8_t *B, int32_
".inst 0x4fa5e9f7 // sdot v23.4s, v15.16b, v5.4b[3]\n"
".inst 0x4fa6e9fb // sdot v27.4s, v15.16b, v6.4b[3]\n"
".inst 0x4fa7e9ff // sdot v31.4s, v15.16b, v7.4b[3]\n"
- "b 4f\n"
- "3:\n"
+ "b 5f\n"
+ "4:\n"
".inst 0x4f80e110 // sdot v16.4s, v8.16b, v0.4b[0]\n"
".inst 0x4f81e114 // sdot v20.4s, v8.16b, v1.4b[0]\n"
".inst 0x4f82e118 // sdot v24.4s, v8.16b, v2.4b[0]\n"
@@ -1573,9 +1680,9 @@ void a64_hybrid_s8s32_dot_16x4(const int8_t *A, int lda, const int8_t *B, int32_
".inst 0x4fa1e9f7 // sdot v23.4s, v15.16b, v1.4b[3]\n"
".inst 0x4fa2e9fb // sdot v27.4s, v15.16b, v2.4b[3]\n"
".inst 0x4fa3e9ff // sdot v31.4s, v15.16b, v3.4b[3]\n"
- "4:\n"
- "cbz %[blocks], 5f\n"
- "6:\n"
+ "5:\n"
+ "cbz %[blocks], 6f\n"
+ "7:\n"
"ldr q8, [%[b_ptr0]]\n"
"subs %[blocks], %[blocks], #0x1\n"
"ldr q9, [%[b_ptr0], #0x10]\n"
@@ -1606,26 +1713,26 @@ void a64_hybrid_s8s32_dot_16x4(const int8_t *A, int lda, const int8_t *B, int32_
".inst 0x4f81e177 // sdot v23.4s, v11.16b, v1.4b[0]\n"
".inst 0x4f82e17b // sdot v27.4s, v11.16b, v2.4b[0]\n"
".inst 0x4f83e17f // sdot v31.4s, v11.16b, v3.4b[0]\n"
- "b.ne 6b\n"
- "5:\n"
- "cbz %[odds], 7f\n"
+ "b.ne 7b\n"
+ "6:\n"
+ "cbz %[odds], 8f\n"
"ld1 {v0.b}[0], [%[a_ptr0]], #1\n"
"ld1 {v1.b}[0], [a_ptr1], #1\n"
"ld1 {v2.b}[0], [a_ptr2], #1\n"
"ld1 {v3.b}[0], [a_ptr3], #1\n"
"subs %[odds], %[odds], #0x1\n"
- "b.eq 8f\n"
+ "b.eq 9f\n"
"ld1 {v0.b}[1], [%[a_ptr0]], #1\n"
"ld1 {v1.b}[1], [a_ptr1], #1\n"
"ld1 {v2.b}[1], [a_ptr2], #1\n"
"ld1 {v3.b}[1], [a_ptr3], #1\n"
"subs %[odds], %[odds], #0x1\n"
- "b.eq 8f\n"
+ "b.eq 9f\n"
"ld1 {v0.b}[2], [%[a_ptr0]]\n"
"ld1 {v1.b}[2], [a_ptr1]\n"
"ld1 {v2.b}[2], [a_ptr2]\n"
"ld1 {v3.b}[2], [a_ptr3]\n"
- "8:\n"
+ "9:\n"
"ldr q8, [%[b_ptr0]]\n"
"ldr q9, [%[b_ptr0], #0x10]\n"
"ldr q10, [%[b_ptr0], #0x20]\n"
@@ -1646,7 +1753,7 @@ void a64_hybrid_s8s32_dot_16x4(const int8_t *A, int lda, const int8_t *B, int32_
".inst 0x4f81e177 // sdot v23.4s, v11.16b, v1.4b[0]\n"
".inst 0x4f82e17b // sdot v27.4s, v11.16b, v2.4b[0]\n"
".inst 0x4f83e17f // sdot v31.4s, v11.16b, v3.4b[0]\n"
- "7:\n"
+ "8:\n"
"str q16, [%[c_ptr0]]\n"
"str q17, [%[c_ptr0], #0x10]\n"
"str q18, [%[c_ptr0], #0x20]\n"