diff options
Diffstat (limited to 'src/core/NEON/kernels/arm_gemm/kernels/a64_sgemm_native_16x4')
-rw-r--r-- | src/core/NEON/kernels/arm_gemm/kernels/a64_sgemm_native_16x4/generic.cpp | 185 |
1 files changed, 103 insertions, 82 deletions
diff --git a/src/core/NEON/kernels/arm_gemm/kernels/a64_sgemm_native_16x4/generic.cpp b/src/core/NEON/kernels/arm_gemm/kernels/a64_sgemm_native_16x4/generic.cpp index 8d4a38c36d..2b846c7f10 100644 --- a/src/core/NEON/kernels/arm_gemm/kernels/a64_sgemm_native_16x4/generic.cpp +++ b/src/core/NEON/kernels/arm_gemm/kernels/a64_sgemm_native_16x4/generic.cpp @@ -23,7 +23,9 @@ */ #ifdef __aarch64__ +#include <algorithm> #include <cstddef> +#include <cstring> #include <arm_neon.h> @@ -35,22 +37,35 @@ void a64_sgemm_native_16x4(const float *A, int lda, const float *B, int ldb, flo const int beta0 = (beta == 0.0f) ? 1 : 0; const int oddones = (K % 4); + float dummy_buffer[16]; + + std::memset(dummy_buffer, 0, sizeof(dummy_buffer)); + /* For now, very naive with no blocking */ - for(int y = 0; y < M; y += 4) - { - for(int x0 = 0; x0 < N; x0 += 16) - { - const float *a_ptr0 = A + (y * lda); - const float *a_ptr1 = a_ptr0 + lda; - const float *a_ptr2 = a_ptr1 + lda; - const float *a_ptr3 = a_ptr2 + lda; + for (int y=0; y<M; y+=4) { + const int activerows = std::min(M-y, 4); - const float *b_ptr = B + x0; + const float * const a_ptr0_base = A + (y * lda); + const float * const a_ptr1_base = (activerows > 1) ? (a_ptr0_base + lda) : dummy_buffer; + const float * const a_ptr2_base = (activerows > 2) ? (a_ptr1_base + lda) : dummy_buffer; + const float * const a_ptr3_base = (activerows > 3) ? (a_ptr2_base + lda) : dummy_buffer; + + const unsigned long a_incr1 = (activerows > 1) ? 32 : 0; + const unsigned long a_incr2 = (activerows > 2) ? 32 : 0; + const unsigned long a_incr3 = (activerows > 3) ? 32 : 0; - float *c_ptr0 = C + (y * ldc) + x0; - float *c_ptr1 = c_ptr0 + ldc; - float *c_ptr2 = c_ptr1 + ldc; - float *c_ptr3 = c_ptr2 + ldc; + float *c_ptr0 = C + (y * ldc); + float *c_ptr1 = (activerows > 1) ? c_ptr0 + ldc : dummy_buffer; + float *c_ptr2 = (activerows > 2) ? c_ptr1 + ldc : dummy_buffer; + float *c_ptr3 = (activerows > 3) ? 
c_ptr2 + ldc : dummy_buffer; + + for (int x0=0; x0<N; x0+=16) { + const float *a_ptr0 = a_ptr0_base; + const float *a_ptr1 = a_ptr1_base; + const float *a_ptr2 = a_ptr2_base; + const float *a_ptr3 = a_ptr3_base; + + const float *b_ptr = B + x0; int loops = ((K + 4) / 8) - 1; int odds = oddones; @@ -228,34 +243,34 @@ void a64_sgemm_native_16x4(const float *A, int lda, const float *B, int ldb, flo "ldr b3aq, [%[b_ptr], #48]\n" // Unroll 2 - "fmla v16.4s, bb0.4s, a0.s[2]\n" - "fmla v20.4s, bb0.4s, a1.s[2]\n" - "add %[b_ptr], %[b_ptr], %[ldb]\n" - "fmla v24.4s, bb0.4s, a2.s[2]\n" - "fmla v28.4s, bb0.4s, a3.s[2]\n" - "ldr b0q, [%[b_ptr]]\n" - - "fmla v17.4s, bb1.4s, a0.s[2]\n" - "add %[a_ptr0], %[a_ptr0], #32\n" - "fmla v21.4s, bb1.4s, a1.s[2]\n" - "add %[a_ptr1], %[a_ptr1], #32\n" - "fmla v25.4s, bb1.4s, a2.s[2]\n" - "add %[a_ptr2], %[a_ptr2], #32\n" - "fmla v29.4s, bb1.4s, a3.s[2]\n" - "ldr b1q, [%[b_ptr], #16]\n" - - "fmla v18.4s, bb2.4s, a0.s[2]\n" - "add %[a_ptr3], %[a_ptr3], #32\n" - "fmla v22.4s, bb2.4s, a1.s[2]\n" - "fmla v26.4s, bb2.4s, a2.s[2]\n" - "fmla v30.4s, bb2.4s, a3.s[2]\n" - "ldr b2q, [%[b_ptr], #32]\n" - - "fmla v19.4s, bb3.4s, a0.s[2]\n" - "fmla v23.4s, bb3.4s, a1.s[2]\n" - "fmla v27.4s, bb3.4s, a2.s[2]\n" - "fmla v31.4s, bb3.4s, a3.s[2]\n" - "ldr b3q, [%[b_ptr], #48]\n" + "fmla v16.4s, bb0.4s, a0.s[2]\n" + "fmla v20.4s, bb0.4s, a1.s[2]\n" + "add %[b_ptr], %[b_ptr], %[ldb]\n" + "fmla v24.4s, bb0.4s, a2.s[2]\n" + "fmla v28.4s, bb0.4s, a3.s[2]\n" + "ldr b0q, [%[b_ptr]]\n" + + "fmla v17.4s, bb1.4s, a0.s[2]\n" + "add %[a_ptr0], %[a_ptr0], #32\n" + "fmla v21.4s, bb1.4s, a1.s[2]\n" + "add %[a_ptr1], %[a_ptr1], %[a_incr1]\n" + "fmla v25.4s, bb1.4s, a2.s[2]\n" + "add %[a_ptr2], %[a_ptr2], %[a_incr2]\n" + "fmla v29.4s, bb1.4s, a3.s[2]\n" + "ldr b1q, [%[b_ptr], #16]\n" + + "fmla v18.4s, bb2.4s, a0.s[2]\n" + "add %[a_ptr3], %[a_ptr3], %[a_incr3]\n" + "fmla v22.4s, bb2.4s, a1.s[2]\n" + "fmla v26.4s, bb2.4s, a2.s[2]\n" + "fmla v30.4s, bb2.4s, a3.s[2]\n" + "ldr 
b2q, [%[b_ptr], #32]\n" + + "fmla v19.4s, bb3.4s, a0.s[2]\n" + "fmla v23.4s, bb3.4s, a1.s[2]\n" + "fmla v27.4s, bb3.4s, a2.s[2]\n" + "fmla v31.4s, bb3.4s, a3.s[2]\n" + "ldr b3q, [%[b_ptr], #48]\n" // Unroll 3 "fmla v16.4s, b0a.4s, a0.s[3]\n" @@ -427,35 +442,35 @@ void a64_sgemm_native_16x4(const float *A, int lda, const float *B, int ldb, flo "ldr b3q, [%[b_ptr], #48]\n" // Unroll 1 - "fmla v16.4s, b0a.4s, a0.s[1]\n" - "add %[b_ptr], %[b_ptr], %[ldb]\n" - "fmla v20.4s, b0a.4s, a1.s[1]\n" - "ldr a3aq, [%[a_ptr3], #16]\n" - "fmla v24.4s, b0a.4s, a2.s[1]\n" - "fmla v28.4s, b0a.4s, a3.s[1]\n" - "ldr b0aq, [%[b_ptr]]\n" - - "fmla v17.4s, b1a.4s, a0.s[1]\n" - "add %[a_ptr0], %[a_ptr0], #32\n" - "fmla v21.4s, b1a.4s, a1.s[1]\n" - "add %[a_ptr1], %[a_ptr1], #32\n" - "fmla v25.4s, b1a.4s, a2.s[1]\n" - "add %[a_ptr2], %[a_ptr2], #32\n" - "fmla v29.4s, b1a.4s, a3.s[1]\n" - "ldr b1aq, [%[b_ptr], #16]\n" - - "fmla v18.4s, b2a.4s, a0.s[1]\n" - "fmla v22.4s, b2a.4s, a1.s[1]\n" - "add %[a_ptr3], %[a_ptr3], #32\n" - "fmla v26.4s, b2a.4s, a2.s[1]\n" - "fmla v30.4s, b2a.4s, a3.s[1]\n" - "ldr b2aq, [%[b_ptr], #32]\n" - - "fmla v19.4s, b3a.4s, a0.s[1]\n" - "fmla v23.4s, b3a.4s, a1.s[1]\n" - "fmla v27.4s, b3a.4s, a2.s[1]\n" - "fmla v31.4s, b3a.4s, a3.s[1]\n" - "ldr b3aq, [%[b_ptr], #48]\n" + "fmla v16.4s, b0a.4s, a0.s[1]\n" + "add %[b_ptr], %[b_ptr], %[ldb]\n" + "fmla v20.4s, b0a.4s, a1.s[1]\n" + "ldr a3aq, [%[a_ptr3], #16]\n" + "fmla v24.4s, b0a.4s, a2.s[1]\n" + "fmla v28.4s, b0a.4s, a3.s[1]\n" + "ldr b0aq, [%[b_ptr]]\n" + + "fmla v17.4s, b1a.4s, a0.s[1]\n" + "add %[a_ptr0], %[a_ptr0], #32\n" + "fmla v21.4s, b1a.4s, a1.s[1]\n" + "add %[a_ptr1], %[a_ptr1], %[a_incr1]\n" + "fmla v25.4s, b1a.4s, a2.s[1]\n" + "add %[a_ptr2], %[a_ptr2], %[a_incr2]\n" + "fmla v29.4s, b1a.4s, a3.s[1]\n" + "ldr b1aq, [%[b_ptr], #16]\n" + + "fmla v18.4s, b2a.4s, a0.s[1]\n" + "fmla v22.4s, b2a.4s, a1.s[1]\n" + "add %[a_ptr3], %[a_ptr3], %[a_incr3]\n" + "fmla v26.4s, b2a.4s, a2.s[1]\n" + "fmla v30.4s, b2a.4s, 
a3.s[1]\n" + "ldr b2aq, [%[b_ptr], #32]\n" + + "fmla v19.4s, b3a.4s, a0.s[1]\n" + "fmla v23.4s, b3a.4s, a1.s[1]\n" + "fmla v27.4s, b3a.4s, a2.s[1]\n" + "fmla v31.4s, b3a.4s, a3.s[1]\n" + "ldr b3aq, [%[b_ptr], #48]\n" // Unroll 2 "fmla v16.4s, bb0.4s, a0.s[2]\n" @@ -848,18 +863,24 @@ void a64_sgemm_native_16x4(const float *A, int lda, const float *B, int ldb, flo "str q27, [%[c_ptr2], #48]\n" "3:\n" - "str q28, [%[c_ptr3]]\n" - "str q29, [%[c_ptr3], #16]\n" - "str q30, [%[c_ptr3], #32]\n" - "str q31, [%[c_ptr3], #48]\n" - - : [a_ptr0] "+r"(a_ptr0), [a_ptr1] "+r"(a_ptr1), [a_ptr2] "+r"(a_ptr2), [a_ptr3] "+r"(a_ptr3), - [b_ptr] "+r"(b_ptr), [loops] "+r"(loops), [odds] "+r"(odds) - : [ldb] "r"(ldbb), [oddk] "r"(oddk), [beta0] "r"(beta0), [betaptr] "r"(&beta), - [c_ptr0] "r"(c_ptr0), [c_ptr1] "r"(c_ptr1), [c_ptr2] "r"(c_ptr2), [c_ptr3] "r"(c_ptr3) - : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", - "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", - "cc", "memory"); + "str q28, [%[c_ptr3]]\n" + "add %[c_ptr0], %[c_ptr0], #64\n" + "str q29, [%[c_ptr3], #16]\n" + "add %[c_ptr1], %[c_ptr1], %[a_incr1], LSL #1\n" + "str q30, [%[c_ptr3], #32]\n" + "add %[c_ptr2], %[c_ptr2], %[a_incr2], LSL #1\n" + "str q31, [%[c_ptr3], #48]\n" + "add %[c_ptr3], %[c_ptr3], %[a_incr3], LSL #1\n" + + : [a_ptr0] "+r" (a_ptr0), [a_ptr1] "+r" (a_ptr1), [a_ptr2] "+r" (a_ptr2), [a_ptr3] "+r" (a_ptr3), + [b_ptr] "+r" (b_ptr), [loops] "+r" (loops), [odds] "+r" (odds), + [c_ptr0] "+r" (c_ptr0), [c_ptr1] "+r" (c_ptr1), [c_ptr2] "+r" (c_ptr2), [c_ptr3] "+r" (c_ptr3) + : [ldb] "r" (ldbb), [oddk] "r" (oddk), [beta0] "r" (beta0), [betaptr] "r" (&beta), + [a_incr1] "r" (a_incr1), [a_incr2] "r" (a_incr2), [a_incr3] "r" (a_incr3) + : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", + "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", 
"v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", + "cc", "memory" + ); } } } |