aboutsummaryrefslogtreecommitdiff
path: root/src/core/NEON/kernels/arm_gemm/kernels/a64_sgemm_native_16x4
diff options
context:
space:
mode:
authorPablo Tello <pablo.tello@arm.com>2018-03-20 16:46:55 +0000
committerAnthony Barbier <anthony.barbier@arm.com>2018-11-02 16:49:16 +0000
commit99ef8407cd5b27fdec6f8dfaf8b55f820b6dea71 (patch)
tree7d7448ebc71d20c15611076375eb0cbb22f83f5a /src/core/NEON/kernels/arm_gemm/kernels/a64_sgemm_native_16x4
parent2d9de0a3fa6ad858e70040124f362799a962bb6a (diff)
downloadComputeLibrary-99ef8407cd5b27fdec6f8dfaf8b55f820b6dea71.tar.gz
COMPMID-881: Updated arm_gemm to the lastest
Change-Id: Iba2664f33320e79bd15ca9c1399e65e4cc165be6 Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/125265 Tested-by: Jenkins <bsgcomp@arm.com> Reviewed-by: Georgios Pinitas <georgios.pinitas@arm.com>
Diffstat (limited to 'src/core/NEON/kernels/arm_gemm/kernels/a64_sgemm_native_16x4')
-rw-r--r--src/core/NEON/kernels/arm_gemm/kernels/a64_sgemm_native_16x4/generic.cpp147
1 files changed, 141 insertions, 6 deletions
diff --git a/src/core/NEON/kernels/arm_gemm/kernels/a64_sgemm_native_16x4/generic.cpp b/src/core/NEON/kernels/arm_gemm/kernels/a64_sgemm_native_16x4/generic.cpp
index 1b5787ce7c..8d4a38c36d 100644
--- a/src/core/NEON/kernels/arm_gemm/kernels/a64_sgemm_native_16x4/generic.cpp
+++ b/src/core/NEON/kernels/arm_gemm/kernels/a64_sgemm_native_16x4/generic.cpp
@@ -31,8 +31,9 @@ namespace arm_gemm
{
void a64_sgemm_native_16x4(const float *A, int lda, const float *B, int ldb, float *C, int ldc, float beta, int M, int N, int K)
{
- int oddk = (K % 8) ? 1 : 0;
- int beta0 = (beta == 0.0f) ? 1 : 0;
+ const int oddk = ((K % 8) >= 4) ? 1 : 0;
+ const int beta0 = (beta == 0.0f) ? 1 : 0;
+ const int oddones = (K % 4);
/* For now, very naive with no blocking */
for(int y = 0; y < M; y += 4)
@@ -52,6 +53,7 @@ void a64_sgemm_native_16x4(const float *A, int lda, const float *B, int ldb, flo
float *c_ptr3 = c_ptr2 + ldc;
int loops = ((K + 4) / 8) - 1;
+ int odds = oddones;
size_t ldbb = ldb * sizeof(float);
@@ -434,14 +436,17 @@ void a64_sgemm_native_16x4(const float *A, int lda, const float *B, int ldb, flo
"ldr b0aq, [%[b_ptr]]\n"
"fmla v17.4s, b1a.4s, a0.s[1]\n"
+ "add %[a_ptr0], %[a_ptr0], #32\n"
"fmla v21.4s, b1a.4s, a1.s[1]\n"
- "subs %w[loops], %w[loops], #1\n"
+ "add %[a_ptr1], %[a_ptr1], #32\n"
"fmla v25.4s, b1a.4s, a2.s[1]\n"
+ "add %[a_ptr2], %[a_ptr2], #32\n"
"fmla v29.4s, b1a.4s, a3.s[1]\n"
"ldr b1aq, [%[b_ptr], #16]\n"
"fmla v18.4s, b2a.4s, a0.s[1]\n"
"fmla v22.4s, b2a.4s, a1.s[1]\n"
+ "add %[a_ptr3], %[a_ptr3], #32\n"
"fmla v26.4s, b2a.4s, a2.s[1]\n"
"fmla v30.4s, b2a.4s, a3.s[1]\n"
"ldr b2aq, [%[b_ptr], #32]\n"
@@ -488,7 +493,6 @@ void a64_sgemm_native_16x4(const float *A, int lda, const float *B, int ldb, flo
"fmla v17.4s, b1a.4s, a0.s[3]\n"
"fmla v21.4s, b1a.4s, a1.s[3]\n"
- "ldr a3aq, [%[a_ptr3], #16]\n"
"fmla v25.4s, b1a.4s, a2.s[3]\n"
"fmla v29.4s, b1a.4s, a3.s[3]\n"
"ldr b1aq, [%[b_ptr], #16]\n"
@@ -560,6 +564,7 @@ void a64_sgemm_native_16x4(const float *A, int lda, const float *B, int ldb, flo
// Unroll 6
"fmla v16.4s, bb0.4s, a0a.s[2]\n"
"fmla v20.4s, bb0.4s, a1a.s[2]\n"
+ "add %[b_ptr], %[b_ptr], %[ldb]\n"
"fmla v24.4s, bb0.4s, a2a.s[2]\n"
"fmla v28.4s, bb0.4s, a3a.s[2]\n"
@@ -583,6 +588,7 @@ void a64_sgemm_native_16x4(const float *A, int lda, const float *B, int ldb, flo
"fmla v17.4s, b1a.4s, a0a.s[3]\n"
"fmla v18.4s, b2a.4s, a0a.s[3]\n"
"fmla v19.4s, b3a.4s, a0a.s[3]\n"
+ "cbnz %w[odds], 6f\n"
"fmla v20.4s, b0a.4s, a1a.s[3]\n"
"str q16, [%[c_ptr0]]\n"
@@ -615,12 +621,16 @@ void a64_sgemm_native_16x4(const float *A, int lda, const float *B, int ldb, flo
// Odd K case: Just do 4 more.
"2:\n"
"fmla v21.4s, bb1.4s, a1.s[0]\n"
+ "add %[a_ptr0], %[a_ptr0], #16\n"
"fmla v25.4s, bb1.4s, a2.s[0]\n"
+ "add %[a_ptr1], %[a_ptr1], #16\n"
"fmla v29.4s, bb1.4s, a3.s[0]\n"
"ldr b1q, [%[b_ptr], #16]\n"
"fmla v18.4s, bb2.4s, a0.s[0]\n"
+ "add %[a_ptr2], %[a_ptr2], #16\n"
"fmla v22.4s, bb2.4s, a1.s[0]\n"
+ "add %[a_ptr3], %[a_ptr3], #16\n"
"fmla v26.4s, bb2.4s, a2.s[0]\n"
"fmla v30.4s, bb2.4s, a3.s[0]\n"
"ldr b2q, [%[b_ptr], #32]\n"
@@ -641,7 +651,6 @@ void a64_sgemm_native_16x4(const float *A, int lda, const float *B, int ldb, flo
"fmla v17.4s, b1a.4s, a0.s[1]\n"
"fmla v21.4s, b1a.4s, a1.s[1]\n"
- "subs %w[loops], %w[loops], #1\n"
"fmla v25.4s, b1a.4s, a2.s[1]\n"
"fmla v29.4s, b1a.4s, a3.s[1]\n"
"ldr b1aq, [%[b_ptr], #16]\n"
@@ -660,6 +669,7 @@ void a64_sgemm_native_16x4(const float *A, int lda, const float *B, int ldb, flo
// Unroll 2
"fmla v16.4s, bb0.4s, a0.s[2]\n"
+ "add %[b_ptr], %[b_ptr], %[ldb]\n"
"fmla v20.4s, bb0.4s, a1.s[2]\n"
"fmla v24.4s, bb0.4s, a2.s[2]\n"
"fmla v28.4s, bb0.4s, a3.s[2]\n"
@@ -684,6 +694,7 @@ void a64_sgemm_native_16x4(const float *A, int lda, const float *B, int ldb, flo
"fmla v17.4s, b1a.4s, a0.s[3]\n"
"fmla v18.4s, b2a.4s, a0.s[3]\n"
"fmla v19.4s, b3a.4s, a0.s[3]\n"
+ "cbnz %w[odds], 7f\n"
"fmla v20.4s, b0a.4s, a1.s[3]\n"
"str q16, [%[c_ptr0]]\n"
@@ -711,6 +722,130 @@ void a64_sgemm_native_16x4(const float *A, int lda, const float *B, int ldb, flo
"str q26, [%[c_ptr2], #32]\n"
"fmla v31.4s, b3a.4s, a3.s[3]\n"
"str q27, [%[c_ptr2], #48]\n"
+ "b 3f\n"
+
+ // "Odd ones" - lead in from even
+ "6:\n"
+ "fmla v20.4s, b0a.4s, a1a.s[3]\n"
+ "fmla v21.4s, b1a.4s, a1a.s[3]\n"
+ "ldr b0q, [%[b_ptr]]\n"
+ "fmla v22.4s, b2a.4s, a1a.s[3]\n"
+ "subs %w[odds], %w[odds], #1\n"
+ "fmla v23.4s, b3a.4s, a1a.s[3]\n"
+ "ldr b1q, [%[b_ptr], #16]\n"
+
+ "fmla v24.4s, b0a.4s, a2a.s[3]\n"
+ "fmla v25.4s, b1a.4s, a2a.s[3]\n"
+ "ldr b2q, [%[b_ptr], #32]\n"
+ "fmla v26.4s, b2a.4s, a2a.s[3]\n"
+ "fmla v27.4s, b3a.4s, a2a.s[3]\n"
+ "ldr b3q, [%[b_ptr], #48]\n"
+
+ "fmla v28.4s, b0a.4s, a3a.s[3]\n"
+ "ld1r {a0.4s}, [%[a_ptr0]], #4\n"
+ "fmla v29.4s, b1a.4s, a3a.s[3]\n"
+ "fmla v30.4s, b2a.4s, a3a.s[3]\n"
+ "ld1r {a1.4s}, [%[a_ptr1]], #4\n"
+ "fmla v31.4s, b3a.4s, a3a.s[3]\n"
+
+ "fmla v16.4s, bb0.4s, a0.4s\n"
+ "beq 9f\n"
+ "b 8f\n"
+
+ // "Odd ones" - lead in from odd
+ "7:\n"
+ "fmla v20.4s, b0a.4s, a1.s[3]\n"
+ "subs %w[odds], %w[odds], #1\n"
+ "fmla v21.4s, b1a.4s, a1.s[3]\n"
+ "ldr b0q, [%[b_ptr]]\n"
+ "fmla v22.4s, b2a.4s, a1.s[3]\n"
+ "fmla v23.4s, b3a.4s, a1.s[3]\n"
+ "ldr b1q, [%[b_ptr], #16]\n"
+
+ "fmla v24.4s, b0a.4s, a2.s[3]\n"
+ "fmla v25.4s, b1a.4s, a2.s[3]\n"
+ "ldr b2q, [%[b_ptr], #32]\n"
+ "fmla v26.4s, b2a.4s, a2.s[3]\n"
+ "fmla v27.4s, b3a.4s, a2.s[3]\n"
+ "ldr b3q, [%[b_ptr], #48]\n"
+
+ "fmla v28.4s, b0a.4s, a3.s[3]\n"
+ "ld1r {a0.4s}, [%[a_ptr0]], #4\n"
+ "fmla v29.4s, b1a.4s, a3.s[3]\n"
+ "fmla v30.4s, b2a.4s, a3.s[3]\n"
+ "ld1r {a1.4s}, [%[a_ptr1]], #4\n"
+ "fmla v31.4s, b3a.4s, a3.s[3]\n"
+
+ "fmla v16.4s, bb0.4s, a0.4s\n"
+ "beq 9f\n"
+
+ // "Odd ones" - loop
+ "8:\n"
+ "fmla v17.4s, bb1.4s, a0.4s\n"
+ "ld1r {a2.4s}, [%[a_ptr2]], #4\n"
+ "fmla v18.4s, bb2.4s, a0.4s\n"
+ "add %[b_ptr], %[b_ptr], %[ldb]\n"
+ "fmla v19.4s, bb3.4s, a0.4s\n"
+ "ld1r {a3.4s}, [%[a_ptr3]], #4\n"
+
+ "fmla v20.4s, bb0.4s, a1.4s\n"
+ "subs %w[odds], %w[odds], #1\n"
+ "fmla v21.4s, bb1.4s, a1.4s\n"
+ "ld1r {a0.4s}, [%[a_ptr0]], #4\n"
+ "fmla v22.4s, bb2.4s, a1.4s\n"
+ "fmla v23.4s, bb3.4s, a1.4s\n"
+ "ld1r {a1.4s}, [%[a_ptr1]], #4\n"
+
+ "fmla v24.4s, bb0.4s, a2.4s\n"
+ "fmla v28.4s, bb0.4s, a3.4s\n"
+ "ldr b0q, [%[b_ptr]]\n"
+ "fmla v25.4s, bb1.4s, a2.4s\n"
+ "fmla v29.4s, bb1.4s, a3.4s\n"
+ "ldr b1q, [%[b_ptr], #16]\n"
+
+ "fmla v26.4s, bb2.4s, a2.4s\n"
+ "fmla v30.4s, bb2.4s, a3.4s\n"
+ "ldr b2q, [%[b_ptr], #32]\n"
+ "fmla v27.4s, bb3.4s, a2.4s\n"
+ "fmla v31.4s, bb3.4s, a3.4s\n"
+ "ldr b3q, [%[b_ptr], #48]\n"
+ "fmla v16.4s, bb0.4s, a0.4s\n"
+ "bne 8b\n"
+
+ // "Odd ones" - detached final iteration
+ "9:\n"
+ "fmla v17.4s, bb1.4s, a0.4s\n"
+ "ld1r {a2.4s}, [%[a_ptr2]], #4\n"
+ "fmla v18.4s, bb2.4s, a0.4s\n"
+ "fmla v19.4s, bb3.4s, a0.4s\n"
+ "ld1r {a3.4s}, [%[a_ptr3]], #4\n"
+
+ "fmla v20.4s, bb0.4s, a1.4s\n"
+ "str q16, [%[c_ptr0]]\n"
+ "fmla v21.4s, bb1.4s, a1.4s\n"
+ "str q17, [%[c_ptr0], #16]\n"
+ "fmla v22.4s, bb2.4s, a1.4s\n"
+ "str q18, [%[c_ptr0], #32]\n"
+ "fmla v23.4s, bb3.4s, a1.4s\n"
+ "str q19, [%[c_ptr0], #48]\n"
+
+ "fmla v24.4s, bb0.4s, a2.4s\n"
+ "str q20, [%[c_ptr1]]\n"
+ "fmla v25.4s, bb1.4s, a2.4s\n"
+ "str q21, [%[c_ptr1], #16]\n"
+ "fmla v26.4s, bb2.4s, a2.4s\n"
+ "str q22, [%[c_ptr1], #32]\n"
+ "fmla v27.4s, bb3.4s, a2.4s\n"
+ "str q23, [%[c_ptr1], #48]\n"
+
+ "fmla v28.4s, bb0.4s, a3.4s\n"
+ "str q24, [%[c_ptr2]]\n"
+ "fmla v29.4s, bb1.4s, a3.4s\n"
+ "str q25, [%[c_ptr2], #16]\n"
+ "fmla v30.4s, bb2.4s, a3.4s\n"
+ "str q26, [%[c_ptr2], #32]\n"
+ "fmla v31.4s, bb3.4s, a3.4s\n"
+ "str q27, [%[c_ptr2], #48]\n"
"3:\n"
"str q28, [%[c_ptr3]]\n"
@@ -719,7 +854,7 @@ void a64_sgemm_native_16x4(const float *A, int lda, const float *B, int ldb, flo
"str q31, [%[c_ptr3], #48]\n"
: [a_ptr0] "+r"(a_ptr0), [a_ptr1] "+r"(a_ptr1), [a_ptr2] "+r"(a_ptr2), [a_ptr3] "+r"(a_ptr3),
- [b_ptr] "+r"(b_ptr), [loops] "+r"(loops)
+ [b_ptr] "+r"(b_ptr), [loops] "+r"(loops), [odds] "+r"(odds)
: [ldb] "r"(ldbb), [oddk] "r"(oddk), [beta0] "r"(beta0), [betaptr] "r"(&beta),
[c_ptr0] "r"(c_ptr0), [c_ptr1] "r"(c_ptr1), [c_ptr2] "r"(c_ptr2), [c_ptr3] "r"(c_ptr3)
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15",