From 71ac9037abce1c6c4af42c485d5395dd6fd79a5a Mon Sep 17 00:00:00 2001 From: Michalis Spyrou Date: Thu, 14 Nov 2019 14:31:44 +0000 Subject: COMPMID-2923 Integrate arm_gemm per channel quantization Signed-off-by: Michalis Spyrou Change-Id: I8667e75843fdd6ac75bd8272a86a348b830da28d Reviewed-on: https://review.mlplatform.org/c/2548 Tested-by: Arm Jenkins Comments-Addressed: Arm Jenkins --- .../kernels/a64_hybrid_s8s32_dot_16x4/generic.cpp | 321 ++++++++++++++------- 1 file changed, 214 insertions(+), 107 deletions(-) (limited to 'src/core/NEON/kernels/arm_gemm/kernels/a64_hybrid_s8s32_dot_16x4/generic.cpp') diff --git a/src/core/NEON/kernels/arm_gemm/kernels/a64_hybrid_s8s32_dot_16x4/generic.cpp b/src/core/NEON/kernels/arm_gemm/kernels/a64_hybrid_s8s32_dot_16x4/generic.cpp index b48b674621..9f06a48ff5 100644 --- a/src/core/NEON/kernels/arm_gemm/kernels/a64_hybrid_s8s32_dot_16x4/generic.cpp +++ b/src/core/NEON/kernels/arm_gemm/kernels/a64_hybrid_s8s32_dot_16x4/generic.cpp @@ -35,7 +35,6 @@ namespace arm_gemm { void a64_hybrid_s8s32_dot_16x4(const int8_t *A, int lda, const int8_t *B, int32_t *C, int ldc, int M, int N, int K, const int32_t *bias, Activation act, bool append) { UNUSED(bias); UNUSED(act); - const int K_stride = ((K + 3) / 4) * 4; const long loops_count = ((K + 16) / 32) - 1; K -= loops_count * 32; @@ -76,6 +75,7 @@ void a64_hybrid_s8s32_dot_16x4(const int8_t *A, int lda, const int8_t *B, int32_ switch(M-y) { case 1: __asm __volatile ( + "cbnz %[append], 1f\n" "movi v16.4s, #0\n" "ldr q0, [%[a_ptr0]]\n" "movi v17.4s, #0\n" @@ -90,8 +90,25 @@ void a64_hybrid_s8s32_dot_16x4(const int8_t *A, int lda, const int8_t *B, int32_ "ldr q13, [%[b_ptr0], #0x50]\n" "ldr q14, [%[b_ptr0], #0x60]\n" "add %[b_ptr0], %[b_ptr0], #0x80\n" - "cbz %[loops], 1f\n" - "2:\n" + "cbz %[loops], 2f\n" + "b 3f\n" + "1:\n" + "ldr q16, [%[c_ptr0]]\n" + "ldr q17, [%[c_ptr0], #0x10]\n" + "ldr q18, [%[c_ptr0], #0x20]\n" + "ldr q19, [%[c_ptr0], #0x30]\n" + "ldr q0, [%[a_ptr0]]\n" + "add %[a_ptr0], %[a_ptr0], #0x10\n" + "ldr q8, [%[b_ptr0]]\n" + "ldr q9, [%[b_ptr0], #0x10]\n" + "ldr q10, [%[b_ptr0], #0x20]\n" + "ldr q11, [%[b_ptr0], #0x30]\n" + "ldr q12, [%[b_ptr0], #0x40]\n" + "ldr q13, [%[b_ptr0], #0x50]\n" + "ldr q14, [%[b_ptr0], #0x60]\n" + "add %[b_ptr0], %[b_ptr0], #0x80\n" + "cbz %[loops], 2f\n" + "3:\n" ".inst 0x4f80e110 // sdot v16.4s, v8.16b, v0.4b[0]\n" "ldr q15, [%[b_ptr0], #-0x10]\n" ".inst 0x4f80e131 // sdot v17.4s, v9.16b, v0.4b[0]\n" @@ -163,11 +180,11 @@ void a64_hybrid_s8s32_dot_16x4(const int8_t *A, int lda, const int8_t *B, int32_ "ldr q12, [%[b_ptr0], #-0x40]\n" "ldr q13, [%[b_ptr0], #-0x30]\n" "ldr q14, [%[b_ptr0], #-0x20]\n" - "b.ne 2b\n" - "1:\n" + "b.ne 3b\n" + "2:\n" "ldr q15, [%[b_ptr0], #-0x10]\n" "prfm PSTL1KEEP, [%[c_ptr0]]\n" - "cbz %[regs], 3f\n" + "cbz %[regs], 4f\n" ".inst 0x4f80e110 // sdot v16.4s, v8.16b, v0.4b[0]\n" "ldr q4, [%[a_ptr0]]\n" ".inst 0x4f80e131 // sdot v17.4s, v9.16b, v0.4b[0]\n" @@ -228,8 +245,8 @@ void a64_hybrid_s8s32_dot_16x4(const int8_t *A, int lda, const int8_t *B, int32_ ".inst 0x4fa4e9b1 // sdot v17.4s, v13.16b, v4.4b[3]\n" ".inst 0x4fa4e9d2 // sdot v18.4s, v14.16b, v4.4b[3]\n" ".inst 0x4fa4e9f3 // sdot v19.4s, v15.16b, v4.4b[3]\n" - "b 4f\n" - "3:\n" + "b 5f\n" + "4:\n" ".inst 0x4f80e110 // sdot v16.4s, v8.16b, v0.4b[0]\n" "ldr q8, [%[b_ptr0]]\n" ".inst 0x4f80e131 // sdot v17.4s, v9.16b, v0.4b[0]\n" @@ -255,9 +272,9 @@ void a64_hybrid_s8s32_dot_16x4(const int8_t *A, int lda, const int8_t *B, int32_ ".inst 0x4fa0e9b1 // sdot v17.4s, v13.16b, v0.4b[3]\n" ".inst 0x4fa0e9d2 // sdot v18.4s, v14.16b, v0.4b[3]\n" ".inst 0x4fa0e9f3 // sdot v19.4s, v15.16b, v0.4b[3]\n" - "4:\n" - "cbz %[blocks], 5f\n" - "6:\n" + "5:\n" + "cbz %[blocks], 6f\n" + "7:\n" "ldr q8, [%[b_ptr0]]\n" "subs %[blocks], %[blocks], #0x1\n" "ldr q9, [%[b_ptr0], #0x10]\n" @@ -270,17 +287,17 @@ void a64_hybrid_s8s32_dot_16x4(const int8_t *A, int lda, const int8_t *B, int32_ ".inst 0x4f80e131 // sdot v17.4s, v9.16b, v0.4b[0]\n" ".inst 0x4f80e152 // sdot v18.4s, v10.16b, v0.4b[0]\n" ".inst 0x4f80e173 // sdot v19.4s, v11.16b, v0.4b[0]\n" - "b.ne 6b\n" - "5:\n" - "cbz %[odds], 7f\n" + "b.ne 7b\n" + "6:\n" + "cbz %[odds], 8f\n" "ld1 {v0.b}[0], [%[a_ptr0]], #1\n" "subs %[odds], %[odds], #0x1\n" - "b.eq 8f\n" + "b.eq 9f\n" "ld1 {v0.b}[1], [%[a_ptr0]], #1\n" "subs %[odds], %[odds], #0x1\n" - "b.eq 8f\n" + "b.eq 9f\n" "ld1 {v0.b}[2], [%[a_ptr0]]\n" - "8:\n" + "9:\n" "ldr q8, [%[b_ptr0]]\n" "ldr q9, [%[b_ptr0], #0x10]\n" "ldr q10, [%[b_ptr0], #0x20]\n" @@ -289,7 +306,7 @@ void a64_hybrid_s8s32_dot_16x4(const int8_t *A, int lda, const int8_t *B, int32_ ".inst 0x4f80e131 // sdot v17.4s, v9.16b, v0.4b[0]\n" ".inst 0x4f80e152 // sdot v18.4s, v10.16b, v0.4b[0]\n" ".inst 0x4f80e173 // sdot v19.4s, v11.16b, v0.4b[0]\n" - "7:\n" + "8:\n" "str q16, [%[c_ptr0]]\n" "str q17, [%[c_ptr0], #0x10]\n" "str q18, [%[c_ptr0], #0x20]\n" @@ -304,30 +321,54 @@ void a64_hybrid_s8s32_dot_16x4(const int8_t *A, int lda, const int8_t *B, int32_ __asm __volatile ( "a_ptr1 .req X0\n" "c_ptr1 .req X1\n" + "add a_ptr1, %[a_ptr0], %[lda]\n" + "add c_ptr1, %[c_ptr0], %[ldc]\n" + "cbnz %[append], 1f\n" "movi v16.4s, #0\n" "ldr q0, [%[a_ptr0]]\n" "movi v17.4s, #0\n" - "ldr q8, [%[b_ptr0]]\n" + "ldr q1, [a_ptr1]\n" "movi v18.4s, #0\n" - "ldr q9, [%[b_ptr0], #0x10]\n" + "ldr q8, [%[b_ptr0]]\n" "movi v19.4s, #0\n" - "ldr q10, [%[b_ptr0], #0x20]\n" + "ldr q9, [%[b_ptr0], #0x10]\n" "movi v20.4s, #0\n" - "ldr q11, [%[b_ptr0], #0x30]\n" + "ldr q10, [%[b_ptr0], #0x20]\n" "movi v21.4s, #0\n" - "ldr q12, [%[b_ptr0], #0x40]\n" + "ldr q11, [%[b_ptr0], #0x30]\n" "movi v22.4s, #0\n" - "ldr q13, [%[b_ptr0], #0x50]\n" + "ldr q12, [%[b_ptr0], #0x40]\n" "movi v23.4s, #0\n" + "ldr q13, [%[b_ptr0], #0x50]\n" "ldr q14, [%[b_ptr0], #0x60]\n" - "add a_ptr1, %[a_ptr0], %[lda]\n" - "add c_ptr1, %[c_ptr0], %[ldc]\n" - "ldr q1, [a_ptr1]\n" "add %[a_ptr0], %[a_ptr0], #0x10\n" "add a_ptr1, a_ptr1, #0x10\n" "add %[b_ptr0], %[b_ptr0], #0x80\n" - "cbz %[loops], 1f\n" - "2:\n" + "cbz %[loops], 2f\n" + "b 3f\n" + "1:\n" + "ldr q16, [%[c_ptr0]]\n" + "ldr q17, [%[c_ptr0], #0x10]\n" + "ldr q18, [%[c_ptr0], #0x20]\n" + "ldr q19, [%[c_ptr0], #0x30]\n" + "ldr q20, [c_ptr1]\n" + "ldr q21, [c_ptr1, #0x10]\n" + "ldr q22, [c_ptr1, #0x20]\n" + "ldr q23, [c_ptr1, #0x30]\n" + "ldr q0, [%[a_ptr0]]\n" + "add %[a_ptr0], %[a_ptr0], #0x10\n" + "ldr q1, [a_ptr1]\n" + "add a_ptr1, a_ptr1, #0x10\n" + "ldr q8, [%[b_ptr0]]\n" + "ldr q9, [%[b_ptr0], #0x10]\n" + "ldr q10, [%[b_ptr0], #0x20]\n" + "ldr q11, [%[b_ptr0], #0x30]\n" + "ldr q12, [%[b_ptr0], #0x40]\n" + "ldr q13, [%[b_ptr0], #0x50]\n" + "ldr q14, [%[b_ptr0], #0x60]\n" + "add %[b_ptr0], %[b_ptr0], #0x80\n" + "cbz %[loops], 2f\n" + "3:\n" ".inst 0x4f80e110 // sdot v16.4s, v8.16b, v0.4b[0]\n" "ldr q15, [%[b_ptr0], #-0x10]\n" ".inst 0x4f81e114 // sdot v20.4s, v8.16b, v1.4b[0]\n" @@ -435,12 +476,12 @@ void a64_hybrid_s8s32_dot_16x4(const int8_t *A, int lda, const int8_t *B, int32_ "ldr q14, [%[b_ptr0], #-0x20]\n" ".inst 0x4fa4e9f3 // sdot v19.4s, v15.16b, v4.4b[3]\n" ".inst 0x4fa5e9f7 // sdot v23.4s, v15.16b, v5.4b[3]\n" - "b.ne 2b\n" - "1:\n" + "b.ne 3b\n" + "2:\n" "ldr q15, [%[b_ptr0], #-0x10]\n" "prfm PSTL1KEEP, [%[c_ptr0]]\n" "prfm PSTL1KEEP, [c_ptr1]\n" - "cbz %[regs], 3f\n" + "cbz %[regs], 4f\n" ".inst 0x4f80e110 // sdot v16.4s, v8.16b, v0.4b[0]\n" "ldr q4, [%[a_ptr0]]\n" ".inst 0x4f81e114 // sdot v20.4s, v8.16b, v1.4b[0]\n" @@ -535,8 +576,8 @@ void a64_hybrid_s8s32_dot_16x4(const int8_t *A, int lda, const int8_t *B, int32_ ".inst 0x4fa5e9d6 // sdot v22.4s, v14.16b, v5.4b[3]\n" ".inst 0x4fa4e9f3 // sdot v19.4s, v15.16b, v4.4b[3]\n" ".inst 0x4fa5e9f7 // sdot v23.4s, v15.16b, v5.4b[3]\n" - "b 4f\n" - "3:\n" + "b 5f\n" + "4:\n" ".inst 0x4f80e110 // sdot v16.4s, v8.16b, v0.4b[0]\n" ".inst 0x4f81e114 // sdot v20.4s, v8.16b, v1.4b[0]\n" "ldr q8, [%[b_ptr0]]\n" @@ -578,9 +619,9 @@ void a64_hybrid_s8s32_dot_16x4(const int8_t *A, int lda, const int8_t *B, int32_ ".inst 0x4fa1e9d6 // sdot v22.4s, v14.16b, v1.4b[3]\n" ".inst 0x4fa0e9f3 // sdot v19.4s, v15.16b, v0.4b[3]\n" ".inst 0x4fa1e9f7 // sdot v23.4s, v15.16b, v1.4b[3]\n" - "4:\n" - "cbz %[blocks], 5f\n" - "6:\n" + "5:\n" + "cbz %[blocks], 6f\n" + "7:\n" "ldr q8, [%[b_ptr0]]\n" "subs %[blocks], %[blocks], #0x1\n" "ldr q9, [%[b_ptr0], #0x10]\n" @@ -599,20 +640,20 @@ void a64_hybrid_s8s32_dot_16x4(const int8_t *A, int lda, const int8_t *B, int32_ ".inst 0x4f81e156 // sdot v22.4s, v10.16b, v1.4b[0]\n" ".inst 0x4f80e173 // sdot v19.4s, v11.16b, v0.4b[0]\n" ".inst 0x4f81e177 // sdot v23.4s, v11.16b, v1.4b[0]\n" - "b.ne 6b\n" - "5:\n" - "cbz %[odds], 7f\n" + "b.ne 7b\n" + "6:\n" + "cbz %[odds], 8f\n" "ld1 {v0.b}[0], [%[a_ptr0]], #1\n" "ld1 {v1.b}[0], [a_ptr1], #1\n" "subs %[odds], %[odds], #0x1\n" - "b.eq 8f\n" + "b.eq 9f\n" "ld1 {v0.b}[1], [%[a_ptr0]], #1\n" "ld1 {v1.b}[1], [a_ptr1], #1\n" "subs %[odds], %[odds], #0x1\n" - "b.eq 8f\n" + "b.eq 9f\n" "ld1 {v0.b}[2], [%[a_ptr0]]\n" "ld1 {v1.b}[2], [a_ptr1]\n" - "8:\n" + "9:\n" "ldr q8, [%[b_ptr0]]\n" "ldr q9, [%[b_ptr0], #0x10]\n" "ldr q10, [%[b_ptr0], #0x20]\n" @@ -625,7 +666,7 @@ void a64_hybrid_s8s32_dot_16x4(const int8_t *A, int lda, const int8_t *B, int32_ ".inst 0x4f81e156 // sdot v22.4s, v10.16b, v1.4b[0]\n" ".inst 0x4f80e173 // sdot v19.4s, v11.16b, v0.4b[0]\n" ".inst 0x4f81e177 // sdot v23.4s, v11.16b, v1.4b[0]\n" - "7:\n" + "8:\n" "str q16, [%[c_ptr0]]\n" "str q17, [%[c_ptr0], #0x10]\n" "str q18, [%[c_ptr0], #0x20]\n" @@ -648,38 +689,68 @@ void a64_hybrid_s8s32_dot_16x4(const int8_t *A, int lda, const int8_t *B, int32_ "a_ptr2 .req X1\n" "c_ptr1 .req X2\n" "c_ptr2 .req X3\n" + "add a_ptr1, %[a_ptr0], %[lda]\n" + "add c_ptr1, %[c_ptr0], %[ldc]\n" + "add a_ptr2, a_ptr1, %[lda]\n" + "add c_ptr2, c_ptr1, %[ldc]\n" + "cbnz %[append], 1f\n" "movi v16.4s, #0\n" "ldr q0, [%[a_ptr0]]\n" "movi v17.4s, #0\n" - "ldr q8, [%[b_ptr0]]\n" + "ldr q1, [a_ptr1]\n" "movi v18.4s, #0\n" - "ldr q9, [%[b_ptr0], #0x10]\n" + "ldr q2, [a_ptr2]\n" "movi v19.4s, #0\n" - "ldr q10, [%[b_ptr0], #0x20]\n" + "ldr q8, [%[b_ptr0]]\n" "movi v20.4s, #0\n" - "ldr q11, [%[b_ptr0], #0x30]\n" + "ldr q9, [%[b_ptr0], #0x10]\n" "movi v21.4s, #0\n" - "ldr q12, [%[b_ptr0], #0x40]\n" + "ldr q10, [%[b_ptr0], #0x20]\n" "movi v22.4s, #0\n" - "ldr q13, [%[b_ptr0], #0x50]\n" + "ldr q11, [%[b_ptr0], #0x30]\n" "movi v23.4s, #0\n" - "ldr q14, [%[b_ptr0], #0x60]\n" + "ldr q12, [%[b_ptr0], #0x40]\n" "movi v24.4s, #0\n" - "add a_ptr1, %[a_ptr0], %[lda]\n" + "ldr q13, [%[b_ptr0], #0x50]\n" "movi v25.4s, #0\n" - "ldr q1, [a_ptr1]\n" + "ldr q14, [%[b_ptr0], #0x60]\n" "movi v26.4s, #0\n" - "add a_ptr2, a_ptr1, %[lda]\n" + "add %[a_ptr0], %[a_ptr0], #0x10\n" "movi v27.4s, #0\n" - "ldr q2, [a_ptr2]\n" - "add c_ptr1, %[c_ptr0], %[ldc]\n" + "add a_ptr1, a_ptr1, #0x10\n" + "add a_ptr2, a_ptr2, #0x10\n" + "add %[b_ptr0], %[b_ptr0], #0x80\n" + "cbz %[loops], 2f\n" + "b 3f\n" + "1:\n" + "ldr q16, [%[c_ptr0]]\n" + "ldr q17, [%[c_ptr0], #0x10]\n" + "ldr q18, [%[c_ptr0], #0x20]\n" + "ldr q19, [%[c_ptr0], #0x30]\n" + "ldr q20, [c_ptr1]\n" + "ldr q21, [c_ptr1, #0x10]\n" + "ldr q22, [c_ptr1, #0x20]\n" + "ldr q23, [c_ptr1, #0x30]\n" + "ldr q24, [c_ptr2]\n" + "ldr q25, [c_ptr2, #0x10]\n" + "ldr q26, [c_ptr2, #0x20]\n" + "ldr q27, [c_ptr2, #0x30]\n" + "ldr q0, [%[a_ptr0]]\n" "add %[a_ptr0], %[a_ptr0], #0x10\n" - "add c_ptr2, c_ptr1, %[ldc]\n" + "ldr q1, [a_ptr1]\n" "add a_ptr1, a_ptr1, #0x10\n" + "ldr q2, [a_ptr2]\n" "add a_ptr2, a_ptr2, #0x10\n" + "ldr q8, [%[b_ptr0]]\n" + "ldr q9, [%[b_ptr0], #0x10]\n" + "ldr q10, [%[b_ptr0], #0x20]\n" + "ldr q11, [%[b_ptr0], #0x30]\n" + "ldr q12, [%[b_ptr0], #0x40]\n" + "ldr q13, [%[b_ptr0], #0x50]\n" + "ldr q14, [%[b_ptr0], #0x60]\n" "add %[b_ptr0], %[b_ptr0], #0x80\n" - "cbz %[loops], 1f\n" - "2:\n" + "cbz %[loops], 2f\n" + "3:\n" ".inst 0x4f80e110 // sdot v16.4s, v8.16b, v0.4b[0]\n" "ldr q15, [%[b_ptr0], #-0x10]\n" ".inst 0x4f81e114 // sdot v20.4s, v8.16b, v1.4b[0]\n" @@ -823,13 +894,13 @@ void a64_hybrid_s8s32_dot_16x4(const int8_t *A, int lda, const int8_t *B, int32_ ".inst 0x4fa4e9f3 // sdot v19.4s, v15.16b, v4.4b[3]\n" ".inst 0x4fa5e9f7 // sdot v23.4s, v15.16b, v5.4b[3]\n" ".inst 0x4fa6e9fb // sdot v27.4s, v15.16b, v6.4b[3]\n" - "b.ne 2b\n" - "1:\n" + "b.ne 3b\n" + "2:\n" "ldr q15, [%[b_ptr0], #-0x10]\n" "prfm PSTL1KEEP, [%[c_ptr0]]\n" "prfm PSTL1KEEP, [c_ptr1]\n" "prfm PSTL1KEEP, [c_ptr2]\n" - "cbz %[regs], 3f\n" + "cbz %[regs], 4f\n" ".inst 0x4f80e110 // sdot v16.4s, v8.16b, v0.4b[0]\n" "ldr q4, [%[a_ptr0]]\n" ".inst 0x4f81e114 // sdot v20.4s, v8.16b, v1.4b[0]\n" @@ -958,8 +1029,8 @@ void a64_hybrid_s8s32_dot_16x4(const int8_t *A, int lda, const int8_t *B, int32_ ".inst 0x4fa4e9f3 // sdot v19.4s, v15.16b, v4.4b[3]\n" ".inst 0x4fa5e9f7 // sdot v23.4s, v15.16b, v5.4b[3]\n" ".inst 0x4fa6e9fb // sdot v27.4s, v15.16b, v6.4b[3]\n" - "b 4f\n" - "3:\n" + "b 5f\n" + "4:\n" ".inst 0x4f80e110 // sdot v16.4s, v8.16b, v0.4b[0]\n" ".inst 0x4f81e114 // sdot v20.4s, v8.16b, v1.4b[0]\n" ".inst 0x4f82e118 // sdot v24.4s, v8.16b, v2.4b[0]\n" @@ -1017,9 +1088,9 @@ void a64_hybrid_s8s32_dot_16x4(const int8_t *A, int lda, const int8_t *B, int32_ ".inst 0x4fa0e9f3 // sdot v19.4s, v15.16b, v0.4b[3]\n" ".inst 0x4fa1e9f7 // sdot v23.4s, v15.16b, v1.4b[3]\n" ".inst 0x4fa2e9fb // sdot v27.4s, v15.16b, v2.4b[3]\n" - "4:\n" - "cbz %[blocks], 5f\n" - "6:\n" + "5:\n" + "cbz %[blocks], 6f\n" + "7:\n" "ldr q8, [%[b_ptr0]]\n" "subs %[blocks], %[blocks], #0x1\n" "ldr q9, [%[b_ptr0], #0x10]\n" @@ -1044,23 +1115,23 @@ void a64_hybrid_s8s32_dot_16x4(const int8_t *A, int lda, const int8_t *B, int32_ ".inst 0x4f80e173 // sdot v19.4s, v11.16b, v0.4b[0]\n" ".inst 0x4f81e177 // sdot v23.4s, v11.16b, v1.4b[0]\n" ".inst 0x4f82e17b // sdot v27.4s, v11.16b, v2.4b[0]\n" - "b.ne 6b\n" - "5:\n" - "cbz %[odds], 7f\n" + "b.ne 7b\n" + "6:\n" + "cbz %[odds], 8f\n" "ld1 {v0.b}[0], [%[a_ptr0]], #1\n" "ld1 {v1.b}[0], [a_ptr1], #1\n" "ld1 {v2.b}[0], [a_ptr2], #1\n" "subs %[odds], %[odds], #0x1\n" - "b.eq 8f\n" + "b.eq 9f\n" "ld1 {v0.b}[1], [%[a_ptr0]], #1\n" "ld1 {v1.b}[1], [a_ptr1], #1\n" "ld1 {v2.b}[1], [a_ptr2], #1\n" "subs %[odds], %[odds], #0x1\n" - "b.eq 8f\n" + "b.eq 9f\n" "ld1 {v0.b}[2], [%[a_ptr0]]\n" "ld1 {v1.b}[2], [a_ptr1]\n" "ld1 {v2.b}[2], [a_ptr2]\n" - "8:\n" + "9:\n" "ldr q8, [%[b_ptr0]]\n" "ldr q9, [%[b_ptr0], #0x10]\n" "ldr q10, [%[b_ptr0], #0x20]\n" @@ -1077,7 +1148,7 @@ void a64_hybrid_s8s32_dot_16x4(const int8_t *A, int lda, const int8_t *B, int32_ ".inst 0x4f80e173 // sdot v19.4s, v11.16b, v0.4b[0]\n" ".inst 0x4f81e177 // sdot v23.4s, v11.16b, v1.4b[0]\n" ".inst 0x4f82e17b // sdot v27.4s, v11.16b, v2.4b[0]\n" - "7:\n" + "8:\n" "str q16, [%[c_ptr0]]\n" "str q17, [%[c_ptr0], #0x10]\n" "str q18, [%[c_ptr0], #0x20]\n" @@ -1109,46 +1180,82 @@ void a64_hybrid_s8s32_dot_16x4(const int8_t *A, int lda, const int8_t *B, int32_ "c_ptr1 .req X3\n" "c_ptr2 .req X4\n" "c_ptr3 .req X5\n" + "add a_ptr1, %[a_ptr0], %[lda]\n" + "add c_ptr1, %[c_ptr0], %[ldc]\n" + "add a_ptr2, a_ptr1, %[lda]\n" + "add c_ptr2, c_ptr1, %[ldc]\n" + "add a_ptr3, a_ptr2, %[lda]\n" + "add c_ptr3, c_ptr2, %[ldc]\n" + "cbnz %[append], 1f\n" "movi v16.4s, #0\n" "ldr q0, [%[a_ptr0]]\n" "movi v17.4s, #0\n" - "ldr q8, [%[b_ptr0]]\n" + "ldr q1, [a_ptr1]\n" "movi v18.4s, #0\n" - "ldr q9, [%[b_ptr0], #0x10]\n" + "ldr q2, [a_ptr2]\n" "movi v19.4s, #0\n" - "ldr q10, [%[b_ptr0], #0x20]\n" + "ldr q3, [a_ptr3]\n" "movi v20.4s, #0\n" - "ldr q11, [%[b_ptr0], #0x30]\n" + "ldr q8, [%[b_ptr0]]\n" "movi v21.4s, #0\n" - "ldr q12, [%[b_ptr0], #0x40]\n" + "ldr q9, [%[b_ptr0], #0x10]\n" "movi v22.4s, #0\n" - "ldr q13, [%[b_ptr0], #0x50]\n" + "ldr q10, [%[b_ptr0], #0x20]\n" "movi v23.4s, #0\n" - "ldr q14, [%[b_ptr0], #0x60]\n" + "ldr q11, [%[b_ptr0], #0x30]\n" "movi v24.4s, #0\n" - "add a_ptr1, %[a_ptr0], %[lda]\n" + "ldr q12, [%[b_ptr0], #0x40]\n" "movi v25.4s, #0\n" - "ldr q1, [a_ptr1]\n" + "ldr q13, [%[b_ptr0], #0x50]\n" "movi v26.4s, #0\n" - "add a_ptr2, a_ptr1, %[lda]\n" + "ldr q14, [%[b_ptr0], #0x60]\n" "movi v27.4s, #0\n" - "ldr q2, [a_ptr2]\n" + "add %[a_ptr0], %[a_ptr0], #0x10\n" "movi v28.4s, #0\n" - "add a_ptr3, a_ptr2, %[lda]\n" + "add a_ptr1, a_ptr1, #0x10\n" "movi v29.4s, #0\n" - "ldr q3, [a_ptr3]\n" + "add a_ptr2, a_ptr2, #0x10\n" "movi v30.4s, #0\n" - "add c_ptr1, %[c_ptr0], %[ldc]\n" + "add a_ptr3, a_ptr3, #0x10\n" "movi v31.4s, #0\n" - "add c_ptr2, c_ptr1, %[ldc]\n" + "add %[b_ptr0], %[b_ptr0], #0x80\n" + "cbz %[loops], 2f\n" + "b 3f\n" + "1:\n" + "ldr q16, [%[c_ptr0]]\n" + "ldr q17, [%[c_ptr0], #0x10]\n" + "ldr q18, [%[c_ptr0], #0x20]\n" + "ldr q19, [%[c_ptr0], #0x30]\n" + "ldr q20, [c_ptr1]\n" + "ldr q21, [c_ptr1, #0x10]\n" + "ldr q22, [c_ptr1, #0x20]\n" + "ldr q23, [c_ptr1, #0x30]\n" + "ldr q24, [c_ptr2]\n" + "ldr q25, [c_ptr2, #0x10]\n" + "ldr q26, [c_ptr2, #0x20]\n" + "ldr q27, [c_ptr2, #0x30]\n" + "ldr q28, [c_ptr3]\n" + "ldr q29, [c_ptr3, #0x10]\n" + "ldr q30, [c_ptr3, #0x20]\n" + "ldr q31, [c_ptr3, #0x30]\n" + "ldr q0, [%[a_ptr0]]\n" "add %[a_ptr0], %[a_ptr0], #0x10\n" - "add c_ptr3, c_ptr2, %[ldc]\n" + "ldr q1, [a_ptr1]\n" "add a_ptr1, a_ptr1, #0x10\n" + "ldr q2, [a_ptr2]\n" "add a_ptr2, a_ptr2, #0x10\n" + "ldr q3, [a_ptr3]\n" "add a_ptr3, a_ptr3, #0x10\n" + "ldr q8, [%[b_ptr0]]\n" + "ldr q9, [%[b_ptr0], #0x10]\n" + "ldr q10, [%[b_ptr0], #0x20]\n" + "ldr q11, [%[b_ptr0], #0x30]\n" + "ldr q12, [%[b_ptr0], #0x40]\n" + "ldr q13, [%[b_ptr0], #0x50]\n" + "ldr q14, [%[b_ptr0], #0x60]\n" "add %[b_ptr0], %[b_ptr0], #0x80\n" - "cbz %[loops], 1f\n" - "2:\n" + "cbz %[loops], 2f\n" + "3:\n" ".inst 0x4f80e110 // sdot v16.4s, v8.16b, v0.4b[0]\n" "ldr q15, [%[b_ptr0], #-0x10]\n" ".inst 0x4f81e114 // sdot v20.4s, v8.16b, v1.4b[0]\n" @@ -1328,14 +1435,14 @@ void a64_hybrid_s8s32_dot_16x4(const int8_t *A, int lda, const int8_t *B, int32_ ".inst 0x4fa5e9f7 // sdot v23.4s, v15.16b, v5.4b[3]\n" ".inst 0x4fa6e9fb // sdot v27.4s, v15.16b, v6.4b[3]\n" ".inst 0x4fa7e9ff // sdot v31.4s, v15.16b, v7.4b[3]\n" - "b.ne 2b\n" - "1:\n" + "b.ne 3b\n" + "2:\n" "ldr q15, [%[b_ptr0], #-0x10]\n" "prfm PSTL1KEEP, [%[c_ptr0]]\n" "prfm PSTL1KEEP, [c_ptr1]\n" "prfm PSTL1KEEP, [c_ptr2]\n" "prfm PSTL1KEEP, [c_ptr3]\n" - "cbz %[regs], 3f\n" + "cbz %[regs], 4f\n" ".inst 0x4f80e110 // sdot v16.4s, v8.16b, v0.4b[0]\n" "ldr q4, [%[a_ptr0]]\n" ".inst 0x4f81e114 // sdot v20.4s, v8.16b, v1.4b[0]\n" @@ -1498,8 +1605,8 @@ void a64_hybrid_s8s32_dot_16x4(const int8_t *A, int lda, const int8_t *B, int32_ ".inst 0x4fa5e9f7 // sdot v23.4s, v15.16b, v5.4b[3]\n" ".inst 0x4fa6e9fb // sdot v27.4s, v15.16b, v6.4b[3]\n" ".inst 0x4fa7e9ff // sdot v31.4s, v15.16b, v7.4b[3]\n" - "b 4f\n" - "3:\n" + "b 5f\n" + "4:\n" ".inst 0x4f80e110 // sdot v16.4s, v8.16b, v0.4b[0]\n" ".inst 0x4f81e114 // sdot v20.4s, v8.16b, v1.4b[0]\n" ".inst 0x4f82e118 // sdot v24.4s, v8.16b, v2.4b[0]\n" @@ -1573,9 +1680,9 @@ void a64_hybrid_s8s32_dot_16x4(const int8_t *A, int lda, const int8_t *B, int32_ ".inst 0x4fa1e9f7 // sdot v23.4s, v15.16b, v1.4b[3]\n" ".inst 0x4fa2e9fb // sdot v27.4s, v15.16b, v2.4b[3]\n" ".inst 0x4fa3e9ff // sdot v31.4s, v15.16b, v3.4b[3]\n" - "4:\n" - "cbz %[blocks], 5f\n" - "6:\n" + "5:\n" + "cbz %[blocks], 6f\n" + "7:\n" "ldr q8, [%[b_ptr0]]\n" "subs %[blocks], %[blocks], #0x1\n" "ldr q9, [%[b_ptr0], #0x10]\n" @@ -1606,26 +1713,26 @@ void a64_hybrid_s8s32_dot_16x4(const int8_t *A, int lda, const int8_t *B, int32_ ".inst 0x4f81e177 // sdot v23.4s, v11.16b, v1.4b[0]\n" ".inst 0x4f82e17b // sdot v27.4s, v11.16b, v2.4b[0]\n" ".inst 0x4f83e17f // sdot v31.4s, v11.16b, v3.4b[0]\n" - "b.ne 6b\n" - "5:\n" - "cbz %[odds], 7f\n" + "b.ne 7b\n" + "6:\n" + "cbz %[odds], 8f\n" "ld1 {v0.b}[0], [%[a_ptr0]], #1\n" "ld1 {v1.b}[0], [a_ptr1], #1\n" "ld1 {v2.b}[0], [a_ptr2], #1\n" "ld1 {v3.b}[0], [a_ptr3], #1\n" "subs %[odds], %[odds], #0x1\n" - "b.eq 8f\n" + "b.eq 9f\n" "ld1 {v0.b}[1], [%[a_ptr0]], #1\n" "ld1 {v1.b}[1], [a_ptr1], #1\n" "ld1 {v2.b}[1], [a_ptr2], #1\n" "ld1 {v3.b}[1], [a_ptr3], #1\n" "subs %[odds], %[odds], #0x1\n" - "b.eq 8f\n" + "b.eq 9f\n" "ld1 {v0.b}[2], [%[a_ptr0]]\n" "ld1 {v1.b}[2], [a_ptr1]\n" "ld1 {v2.b}[2], [a_ptr2]\n" "ld1 {v3.b}[2], [a_ptr3]\n" - "8:\n" + "9:\n" "ldr q8, [%[b_ptr0]]\n" "ldr q9, [%[b_ptr0], #0x10]\n" "ldr q10, [%[b_ptr0], #0x20]\n" @@ -1646,7 +1753,7 @@ void a64_hybrid_s8s32_dot_16x4(const int8_t *A, int lda, const int8_t *B, int32_ ".inst 0x4f81e177 // sdot v23.4s, v11.16b, v1.4b[0]\n" ".inst 0x4f82e17b // sdot v27.4s, v11.16b, v2.4b[0]\n" ".inst 0x4f83e17f // sdot v31.4s, v11.16b, v3.4b[0]\n" - "7:\n" + "8:\n" "str q16, [%[c_ptr0]]\n" "str q17, [%[c_ptr0], #0x10]\n" "str q18, [%[c_ptr0], #0x20]\n" -- cgit v1.2.1