diff options
author | Georgios Pinitas <georgios.pinitas@arm.com> | 2019-06-27 17:00:52 +0100 |
---|---|---|
committer | Georgios Pinitas <georgios.pinitas@arm.com> | 2019-07-26 11:55:15 +0000 |
commit | cfa2bba98169cb5ab1945462514be1b6badf7d98 (patch) | |
tree | 1635e6e9463e9798c7195f0aa71b5df3f2650df1 /src/core/NEON/kernels/arm_gemm/kernels/a64_sgemv_trans/generic.cpp | |
parent | f59b16f42ef68bde877b70816ffb953d64c8baa3 (diff) | |
download | ComputeLibrary-cfa2bba98169cb5ab1945462514be1b6badf7d98.tar.gz |
COMPMID-2178: Update GEMM assembly code.
Perform offset reduction and requantization within the assembly wrapper.
Change-Id: I5d5b3e1f6f9ef4c71805362c57f88ff199c027a3
Signed-off-by: Georgios Pinitas <georgios.pinitas@arm.com>
Reviewed-on: https://review.mlplatform.org/c/1541
Comments-Addressed: Pablo Marquez <pablo.tello@arm.com>
Reviewed-by: Gian Marco Iodice <gianmarco.iodice@arm.com>
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Diffstat (limited to 'src/core/NEON/kernels/arm_gemm/kernels/a64_sgemv_trans/generic.cpp')
-rw-r--r-- | src/core/NEON/kernels/arm_gemm/kernels/a64_sgemv_trans/generic.cpp | 151 |
1 file changed, 148 insertions, 3 deletions
diff --git a/src/core/NEON/kernels/arm_gemm/kernels/a64_sgemv_trans/generic.cpp b/src/core/NEON/kernels/arm_gemm/kernels/a64_sgemv_trans/generic.cpp index 64ef9d89a4..e61dbd82ea 100644 --- a/src/core/NEON/kernels/arm_gemm/kernels/a64_sgemv_trans/generic.cpp +++ b/src/core/NEON/kernels/arm_gemm/kernels/a64_sgemv_trans/generic.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017-2018 ARM Limited. + * Copyright (c) 2017-2019 ARM Limited. * * SPDX-License-Identifier: MIT * @@ -54,6 +54,7 @@ namespace arm_gemm { void a64_sgemv_trans(const float *Astart, const float *Xstart, float *Ystart, float beta, int lda, int M, int N) { const float *a_ptr_base = Astart; float *y_ptr = Ystart; + const bool beta0 = (beta == 0.0f); register const float32x4_t vb asm("v1") = vdupq_n_f32(beta); @@ -375,6 +376,7 @@ void a64_sgemv_trans(const float *Astart, const float *Xstart, float *Ystart, fl "fmla v25.4s, v7.4s, v0.4s\n" "ldr q7, [%[a_ptr], #0x170]\n" "fmla v26.4s, v2.4s, v0.4s\n" + "cbnz %w[beta0], 11f\n" "ldr q2, [%[y_ptr]]\n" "fmla v27.4s, v3.4s, v0.4s\n" "ldr q3, [%[y_ptr], #0x10]\n" @@ -449,13 +451,46 @@ void a64_sgemv_trans(const float *Astart, const float *Xstart, float *Ystart, fl "str q26, [%[y_ptr], #0x120]\n" "fmla v31.4s, v7.4s, %[vb].4s\n" "str q27, [%[y_ptr], #0x130]\n" + "b 12f\n" + // beta 0 code - don't read. 
+ "11:\n" + "str q8, [%[y_ptr], #0x00]\n" + "fmla v27.4s, v3.4s, v0.4s\n" + "str q9, [%[y_ptr], #0x10]\n" + "fmla v28.4s, v4.4s, v0.4s\n" + "str q10, [%[y_ptr], #0x20]\n" + "fmla v29.4s, v5.4s, v0.4s\n" + "str q11, [%[y_ptr], #0x30]\n" + "fmla v30.4s, v6.4s, v0.4s\n" + "str q12, [%[y_ptr], #0x40]\n" + "fmla v31.4s, v7.4s, v0.4s\n" + + "str q13, [%[y_ptr], #0x50]\n" + "str q14, [%[y_ptr], #0x60]\n" + "str q15, [%[y_ptr], #0x70]\n" + "str q16, [%[y_ptr], #0x80]\n" + "str q17, [%[y_ptr], #0x90]\n" + "str q18, [%[y_ptr], #0xa0]\n" + "str q19, [%[y_ptr], #0xb0]\n" + "str q20, [%[y_ptr], #0xc0]\n" + "str q21, [%[y_ptr], #0xd0]\n" + "str q22, [%[y_ptr], #0xe0]\n" + "str q23, [%[y_ptr], #0xf0]\n" + "str q24, [%[y_ptr], #0x100]\n" + "str q25, [%[y_ptr], #0x110]\n" + "str q26, [%[y_ptr], #0x120]\n" + "str q27, [%[y_ptr], #0x130]\n" + + "12:\n" "stp q28, q29, [%[y_ptr], #0x140]\n" "stp q30, q31, [%[y_ptr], #0x160]\n" "add %[y_ptr], %[y_ptr], #0x180\n" + + : [a_ptr] "+r" (a_ptr), [x_ptr] "+r" (x_ptr), [y_ptr] "+r" (y_ptr), [k] "+r" (k), [pf_ptr] "+r" (pf_ptr), [firstpf_ptr] "+r" (firstpf_ptr) - : [jump] "r" (jump), [vb] "w" (vb), [pf_limit] "r" (pf_limit) + : [jump] "r" (jump), [vb] "w" (vb), [pf_limit] "r" (pf_limit), [beta0] "r" (beta0) : "w0", "v0", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "cc" @@ -754,6 +789,8 @@ void a64_sgemv_trans(const float *Astart, const float *Xstart, float *Ystart, fl // Now write out the outputs "10:\n" + "cbnz %w[beta0], 15f\n" + "cbz %w[numvecs], 12f\n" "mov %w[vecs], %w[numvecs]\n" @@ -908,13 +945,121 @@ void a64_sgemv_trans(const float *Astart, const float *Xstart, float *Ystart, fl "ldr s7, [%[y_ptr]]\n" "fmla v5.2s, v7.2s, %[vb].2s\n" "str s5, [%[y_ptr]]\n" + "b 14f\n" + + "15:\n" + // beta0 code + "cbz %w[numvecs], 16f\n" + "mov %w[vecs], %w[numvecs]\n" + + // Vector 0 + "subs 
%w[vecs], %w[vecs], #1\n" + "str q8, [%[y_ptr]], #0x10\n" + "beq 16f\n" + // Vector 1 + "subs %w[vecs], %w[vecs], #1\n" + "str q9, [%[y_ptr]], #0x10\n" + "beq 16f\n" + // Vector 2 + "subs %w[vecs], %w[vecs], #1\n" + "str q10, [%[y_ptr]], #0x10\n" + "beq 16f\n" + // Vector 3 + "subs %w[vecs], %w[vecs], #1\n" + "str q11, [%[y_ptr]], #0x10\n" + "beq 16f\n" + // Vector 4 + "subs %w[vecs], %w[vecs], #1\n" + "str q12, [%[y_ptr]], #0x10\n" + "beq 16f\n" + // Vector 5 + "subs %w[vecs], %w[vecs], #1\n" + "str q13, [%[y_ptr]], #0x10\n" + "beq 16f\n" + // Vector 6 + "subs %w[vecs], %w[vecs], #1\n" + "str q14, [%[y_ptr]], #0x10\n" + "beq 16f\n" + // Vector 7 + "subs %w[vecs], %w[vecs], #1\n" + "str q15, [%[y_ptr]], #0x10\n" + "beq 16f\n" + // Vector 8 + "subs %w[vecs], %w[vecs], #1\n" + "str q16, [%[y_ptr]], #0x10\n" + "beq 16f\n" + // Vector 9 + "subs %w[vecs], %w[vecs], #1\n" + "str q17, [%[y_ptr]], #0x10\n" + "beq 16f\n" + // Vector 10 + "subs %w[vecs], %w[vecs], #1\n" + "str q18, [%[y_ptr]], #0x10\n" + "beq 16f\n" + // Vector 11 + "subs %w[vecs], %w[vecs], #1\n" + "str q19, [%[y_ptr]], #0x10\n" + "beq 16f\n" + // Vector 12 + "subs %w[vecs], %w[vecs], #1\n" + "str q20, [%[y_ptr]], #0x10\n" + "beq 16f\n" + // Vector 13 + "subs %w[vecs], %w[vecs], #1\n" + "str q21, [%[y_ptr]], #0x10\n" + "beq 16f\n" + // Vector 14 + "subs %w[vecs], %w[vecs], #1\n" + "str q22, [%[y_ptr]], #0x10\n" + "beq 16f\n" + // Vector 15 + "subs %w[vecs], %w[vecs], #1\n" + "str q23, [%[y_ptr]], #0x10\n" + "beq 16f\n" + // Vector 16 + "subs %w[vecs], %w[vecs], #1\n" + "str q24, [%[y_ptr]], #0x10\n" + "beq 16f\n" + // Vector 17 + "subs %w[vecs], %w[vecs], #1\n" + "str q25, [%[y_ptr]], #0x10\n" + "beq 16f\n" + // Vector 18 + "subs %w[vecs], %w[vecs], #1\n" + "str q26, [%[y_ptr]], #0x10\n" + "beq 16f\n" + // Vector 19 + "subs %w[vecs], %w[vecs], #1\n" + "str q27, [%[y_ptr]], #0x10\n" + "beq 16f\n" + // Vector 20 + "subs %w[vecs], %w[vecs], #1\n" + "str q28, [%[y_ptr]], #0x10\n" + "beq 16f\n" + // Vector 21 + 
"subs %w[vecs], %w[vecs], #1\n" + "str q29, [%[y_ptr]], #0x10\n" + "beq 16f\n" + // Vector 22 + "subs %w[vecs], %w[vecs], #1\n" + "str q30, [%[y_ptr]], #0x10\n" + + // Odd 2 + "16:\n" + "cbz %[odd2_aptr], 17f\n" + "str d6, [%[y_ptr]], #0x8\n" + + // Odd 1 + "17:\n" + "cbz %[odd1_aptr], 14f\n" + "str s5, [%[y_ptr]]\n" "14:\n" : [a_ptr] "+r" (a_ptr), [x_ptr] "+r" (x_ptr), [y_ptr] "+r" (y_ptr), [k] "+r" (k), [pf_ptr] "+r" (pf_ptr), [firstpf_ptr] "+r" (firstpf_ptr), [odd1_aptr] "+r" (odd1_aptr), [odd2_aptr] "+r" (odd2_aptr), [dopf] "+r" (dopf), [vecs] "+r" (vecs) - : [jump] "r" (jump), [vb] "w" (vb), [pf_limit] "r" (pf_limit), [numvecs] "r" (numvecs) + : [jump] "r" (jump), [vb] "w" (vb), [pf_limit] "r" (pf_limit), [numvecs] "r" (numvecs), [beta0] "r" (beta0) : "w0", "v0", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "cc" |