aboutsummaryrefslogtreecommitdiff
path: root/src/core/NEON/kernels/arm_gemm/kernels/a64_sgemv_trans/generic.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/core/NEON/kernels/arm_gemm/kernels/a64_sgemv_trans/generic.cpp')
-rw-r--r--src/core/NEON/kernels/arm_gemm/kernels/a64_sgemv_trans/generic.cpp151
1 files changed, 148 insertions, 3 deletions
diff --git a/src/core/NEON/kernels/arm_gemm/kernels/a64_sgemv_trans/generic.cpp b/src/core/NEON/kernels/arm_gemm/kernels/a64_sgemv_trans/generic.cpp
index 64ef9d89a4..e61dbd82ea 100644
--- a/src/core/NEON/kernels/arm_gemm/kernels/a64_sgemv_trans/generic.cpp
+++ b/src/core/NEON/kernels/arm_gemm/kernels/a64_sgemv_trans/generic.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017-2018 ARM Limited.
+ * Copyright (c) 2017-2019 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -54,6 +54,7 @@ namespace arm_gemm {
void a64_sgemv_trans(const float *Astart, const float *Xstart, float *Ystart, float beta, int lda, int M, int N) {
const float *a_ptr_base = Astart;
float *y_ptr = Ystart;
+ const bool beta0 = (beta == 0.0f);
register const float32x4_t vb asm("v1") = vdupq_n_f32(beta);
@@ -375,6 +376,7 @@ void a64_sgemv_trans(const float *Astart, const float *Xstart, float *Ystart, fl
"fmla v25.4s, v7.4s, v0.4s\n"
"ldr q7, [%[a_ptr], #0x170]\n"
"fmla v26.4s, v2.4s, v0.4s\n"
+ "cbnz %w[beta0], 11f\n"
"ldr q2, [%[y_ptr]]\n"
"fmla v27.4s, v3.4s, v0.4s\n"
"ldr q3, [%[y_ptr], #0x10]\n"
@@ -449,13 +451,46 @@ void a64_sgemv_trans(const float *Astart, const float *Xstart, float *Ystart, fl
"str q26, [%[y_ptr], #0x120]\n"
"fmla v31.4s, v7.4s, %[vb].4s\n"
"str q27, [%[y_ptr], #0x130]\n"
+ "b 12f\n"
+ // beta 0 code - don't read.
+ "11:\n"
+ "str q8, [%[y_ptr], #0x00]\n"
+ "fmla v27.4s, v3.4s, v0.4s\n"
+ "str q9, [%[y_ptr], #0x10]\n"
+ "fmla v28.4s, v4.4s, v0.4s\n"
+ "str q10, [%[y_ptr], #0x20]\n"
+ "fmla v29.4s, v5.4s, v0.4s\n"
+ "str q11, [%[y_ptr], #0x30]\n"
+ "fmla v30.4s, v6.4s, v0.4s\n"
+ "str q12, [%[y_ptr], #0x40]\n"
+ "fmla v31.4s, v7.4s, v0.4s\n"
+
+ "str q13, [%[y_ptr], #0x50]\n"
+ "str q14, [%[y_ptr], #0x60]\n"
+ "str q15, [%[y_ptr], #0x70]\n"
+ "str q16, [%[y_ptr], #0x80]\n"
+ "str q17, [%[y_ptr], #0x90]\n"
+ "str q18, [%[y_ptr], #0xa0]\n"
+ "str q19, [%[y_ptr], #0xb0]\n"
+ "str q20, [%[y_ptr], #0xc0]\n"
+ "str q21, [%[y_ptr], #0xd0]\n"
+ "str q22, [%[y_ptr], #0xe0]\n"
+ "str q23, [%[y_ptr], #0xf0]\n"
+ "str q24, [%[y_ptr], #0x100]\n"
+ "str q25, [%[y_ptr], #0x110]\n"
+ "str q26, [%[y_ptr], #0x120]\n"
+ "str q27, [%[y_ptr], #0x130]\n"
+
+ "12:\n"
"stp q28, q29, [%[y_ptr], #0x140]\n"
"stp q30, q31, [%[y_ptr], #0x160]\n"
"add %[y_ptr], %[y_ptr], #0x180\n"
+
+
: [a_ptr] "+r" (a_ptr), [x_ptr] "+r" (x_ptr), [y_ptr] "+r" (y_ptr), [k] "+r" (k), [pf_ptr] "+r" (pf_ptr), [firstpf_ptr] "+r" (firstpf_ptr)
- : [jump] "r" (jump), [vb] "w" (vb), [pf_limit] "r" (pf_limit)
+ : [jump] "r" (jump), [vb] "w" (vb), [pf_limit] "r" (pf_limit), [beta0] "r" (beta0)
: "w0", "v0", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13",
"v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26",
"v27", "v28", "v29", "v30", "v31", "cc"
@@ -754,6 +789,8 @@ void a64_sgemv_trans(const float *Astart, const float *Xstart, float *Ystart, fl
// Now write out the outputs
"10:\n"
+ "cbnz %w[beta0], 15f\n"
+
"cbz %w[numvecs], 12f\n"
"mov %w[vecs], %w[numvecs]\n"
@@ -908,13 +945,121 @@ void a64_sgemv_trans(const float *Astart, const float *Xstart, float *Ystart, fl
"ldr s7, [%[y_ptr]]\n"
"fmla v5.2s, v7.2s, %[vb].2s\n"
"str s5, [%[y_ptr]]\n"
+ "b 14f\n"
+
+ "15:\n"
+ // beta0 code
+ "cbz %w[numvecs], 16f\n"
+ "mov %w[vecs], %w[numvecs]\n"
+
+ // Vector 0
+ "subs %w[vecs], %w[vecs], #1\n"
+ "str q8, [%[y_ptr]], #0x10\n"
+ "beq 16f\n"
+ // Vector 1
+ "subs %w[vecs], %w[vecs], #1\n"
+ "str q9, [%[y_ptr]], #0x10\n"
+ "beq 16f\n"
+ // Vector 2
+ "subs %w[vecs], %w[vecs], #1\n"
+ "str q10, [%[y_ptr]], #0x10\n"
+ "beq 16f\n"
+ // Vector 3
+ "subs %w[vecs], %w[vecs], #1\n"
+ "str q11, [%[y_ptr]], #0x10\n"
+ "beq 16f\n"
+ // Vector 4
+ "subs %w[vecs], %w[vecs], #1\n"
+ "str q12, [%[y_ptr]], #0x10\n"
+ "beq 16f\n"
+ // Vector 5
+ "subs %w[vecs], %w[vecs], #1\n"
+ "str q13, [%[y_ptr]], #0x10\n"
+ "beq 16f\n"
+ // Vector 6
+ "subs %w[vecs], %w[vecs], #1\n"
+ "str q14, [%[y_ptr]], #0x10\n"
+ "beq 16f\n"
+ // Vector 7
+ "subs %w[vecs], %w[vecs], #1\n"
+ "str q15, [%[y_ptr]], #0x10\n"
+ "beq 16f\n"
+ // Vector 8
+ "subs %w[vecs], %w[vecs], #1\n"
+ "str q16, [%[y_ptr]], #0x10\n"
+ "beq 16f\n"
+ // Vector 9
+ "subs %w[vecs], %w[vecs], #1\n"
+ "str q17, [%[y_ptr]], #0x10\n"
+ "beq 16f\n"
+ // Vector 10
+ "subs %w[vecs], %w[vecs], #1\n"
+ "str q18, [%[y_ptr]], #0x10\n"
+ "beq 16f\n"
+ // Vector 11
+ "subs %w[vecs], %w[vecs], #1\n"
+ "str q19, [%[y_ptr]], #0x10\n"
+ "beq 16f\n"
+ // Vector 12
+ "subs %w[vecs], %w[vecs], #1\n"
+ "str q20, [%[y_ptr]], #0x10\n"
+ "beq 16f\n"
+ // Vector 13
+ "subs %w[vecs], %w[vecs], #1\n"
+ "str q21, [%[y_ptr]], #0x10\n"
+ "beq 16f\n"
+ // Vector 14
+ "subs %w[vecs], %w[vecs], #1\n"
+ "str q22, [%[y_ptr]], #0x10\n"
+ "beq 16f\n"
+ // Vector 15
+ "subs %w[vecs], %w[vecs], #1\n"
+ "str q23, [%[y_ptr]], #0x10\n"
+ "beq 16f\n"
+ // Vector 16
+ "subs %w[vecs], %w[vecs], #1\n"
+ "str q24, [%[y_ptr]], #0x10\n"
+ "beq 16f\n"
+ // Vector 17
+ "subs %w[vecs], %w[vecs], #1\n"
+ "str q25, [%[y_ptr]], #0x10\n"
+ "beq 16f\n"
+ // Vector 18
+ "subs %w[vecs], %w[vecs], #1\n"
+ "str q26, [%[y_ptr]], #0x10\n"
+ "beq 16f\n"
+ // Vector 19
+ "subs %w[vecs], %w[vecs], #1\n"
+ "str q27, [%[y_ptr]], #0x10\n"
+ "beq 16f\n"
+ // Vector 20
+ "subs %w[vecs], %w[vecs], #1\n"
+ "str q28, [%[y_ptr]], #0x10\n"
+ "beq 16f\n"
+ // Vector 21
+ "subs %w[vecs], %w[vecs], #1\n"
+ "str q29, [%[y_ptr]], #0x10\n"
+ "beq 16f\n"
+ // Vector 22
+ "subs %w[vecs], %w[vecs], #1\n"
+ "str q30, [%[y_ptr]], #0x10\n"
+
+ // Odd 2
+ "16:\n"
+ "cbz %[odd2_aptr], 17f\n"
+ "str d6, [%[y_ptr]], #0x8\n"
+
+ // Odd 1
+ "17:\n"
+ "cbz %[odd1_aptr], 14f\n"
+ "str s5, [%[y_ptr]]\n"
"14:\n"
: [a_ptr] "+r" (a_ptr), [x_ptr] "+r" (x_ptr), [y_ptr] "+r" (y_ptr), [k] "+r" (k),
[pf_ptr] "+r" (pf_ptr), [firstpf_ptr] "+r" (firstpf_ptr),
[odd1_aptr] "+r" (odd1_aptr), [odd2_aptr] "+r" (odd2_aptr),
[dopf] "+r" (dopf), [vecs] "+r" (vecs)
- : [jump] "r" (jump), [vb] "w" (vb), [pf_limit] "r" (pf_limit), [numvecs] "r" (numvecs)
+ : [jump] "r" (jump), [vb] "w" (vb), [pf_limit] "r" (pf_limit), [numvecs] "r" (numvecs), [beta0] "r" (beta0)
: "w0", "v0", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13",
"v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26",
"v27", "v28", "v29", "v30", "v31", "cc"