author      David Mansell <David.Mansell@arm.com>      2018-05-17 18:51:26 +0100
committer   Anthony Barbier <anthony.barbier@arm.com>  2018-11-02 16:52:54 +0000
commit      ce8f60510210efc0cf1c921fac75efc49bc70edc (patch)
tree        c9f6fb303593198d783639cce25e09ed160e2d0b /src/core/NEON/kernels/arm_gemm/kernels/a64_sgemm_native_16x4
parent      2d008a476f5f09f63574990a93e8bf606ae5629e (diff)
COMPMID-1177: Improved native GEMM.
Improve the native GEMM so it can cope with any value for M. Also change the selection code so that the native GEMM is selected if M is small and nmulti is large - Winograd needs GEMMs like this and they don't thread properly with the blocked GEMM. (Also rename gemm_batched.hpp back to gemv_batched.hpp.)

Change-Id: I736c33373ada562cbc0c00540520a58103faa9d5
Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/131739
Tested-by: Jenkins <bsgcomp@arm.com>
Reviewed-by: Pablo Tello <pablo.tello@arm.com>
Reviewed-by: Anthony Barbier <anthony.barbier@arm.com>
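The 16x4 kernel computes four rows of A at a time, so the previous driver loop only worked when M was a multiple of 4. The heart of this change is a driver loop that points any missing tail rows at a small zero-initialised on-stack dummy buffer and passes a per-row pointer increment of 0 into the assembly, so the same 4-row code path runs safely for any M. The following minimal C++ sketch illustrates that idea only; gemm_rows and process_block are illustrative placeholders, not functions from the library.

#include <algorithm>
#include <cstring>

// Stand-in for the real 16x4 assembly block: it just touches one element per
// row so the sketch stays self-contained and runnable.
static void process_block(const float *a0, const float *a1, const float *a2, const float *a3,
                          float *c0, float *c1, float *c2, float *c3) {
    c0[0] += a0[0]; c1[0] += a1[0]; c2[0] += a2[0]; c3[0] += a3[0];
}

// Illustrative driver loop: rows past the end of A/C are redirected to a
// local dummy buffer, so reads and writes for inactive rows are harmless.
static void gemm_rows(const float *A, int lda, float *C, int ldc, int M) {
    float dummy[16];
    std::memset(dummy, 0, sizeof(dummy));

    for (int y = 0; y < M; y += 4) {
        const int activerows = std::min(M - y, 4);

        const float *a0 = A + y * lda;
        const float *a1 = (activerows > 1) ? a0 + lda : dummy;
        const float *a2 = (activerows > 2) ? a1 + lda : dummy;
        const float *a3 = (activerows > 3) ? a2 + lda : dummy;

        float *c0 = C + y * ldc;
        float *c1 = (activerows > 1) ? c0 + ldc : dummy;
        float *c2 = (activerows > 2) ? c1 + ldc : dummy;
        float *c3 = (activerows > 3) ? c2 + ldc : dummy;

        process_block(a0, a1, a2, a3, c0, c1, c2, c3);
    }
}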
Diffstat (limited to 'src/core/NEON/kernels/arm_gemm/kernels/a64_sgemm_native_16x4')
-rw-r--r--  src/core/NEON/kernels/arm_gemm/kernels/a64_sgemm_native_16x4/generic.cpp  185
1 file changed, 103 insertions, 82 deletions
diff --git a/src/core/NEON/kernels/arm_gemm/kernels/a64_sgemm_native_16x4/generic.cpp b/src/core/NEON/kernels/arm_gemm/kernels/a64_sgemm_native_16x4/generic.cpp
index 8d4a38c36d..2b846c7f10 100644
--- a/src/core/NEON/kernels/arm_gemm/kernels/a64_sgemm_native_16x4/generic.cpp
+++ b/src/core/NEON/kernels/arm_gemm/kernels/a64_sgemm_native_16x4/generic.cpp
@@ -23,7 +23,9 @@
*/
#ifdef __aarch64__
+#include <algorithm>
#include <cstddef>
+#include <cstring>
#include <arm_neon.h>
@@ -35,22 +37,35 @@ void a64_sgemm_native_16x4(const float *A, int lda, const float *B, int ldb, flo
const int beta0 = (beta == 0.0f) ? 1 : 0;
const int oddones = (K % 4);
+ float dummy_buffer[16];
+
+ std::memset(dummy_buffer, 0, sizeof(dummy_buffer));
+
/* For now, very naive with no blocking */
- for(int y = 0; y < M; y += 4)
- {
- for(int x0 = 0; x0 < N; x0 += 16)
- {
- const float *a_ptr0 = A + (y * lda);
- const float *a_ptr1 = a_ptr0 + lda;
- const float *a_ptr2 = a_ptr1 + lda;
- const float *a_ptr3 = a_ptr2 + lda;
+ for (int y=0; y<M; y+=4) {
+ const int activerows = std::min(M-y, 4);
- const float *b_ptr = B + x0;
+ const float * const a_ptr0_base = A + (y * lda);
+ const float * const a_ptr1_base = (activerows > 1) ? (a_ptr0_base + lda) : dummy_buffer;
+ const float * const a_ptr2_base = (activerows > 2) ? (a_ptr1_base + lda) : dummy_buffer;
+ const float * const a_ptr3_base = (activerows > 3) ? (a_ptr2_base + lda) : dummy_buffer;
+
+ const unsigned long a_incr1 = (activerows > 1) ? 32 : 0;
+ const unsigned long a_incr2 = (activerows > 2) ? 32 : 0;
+ const unsigned long a_incr3 = (activerows > 3) ? 32 : 0;
- float *c_ptr0 = C + (y * ldc) + x0;
- float *c_ptr1 = c_ptr0 + ldc;
- float *c_ptr2 = c_ptr1 + ldc;
- float *c_ptr3 = c_ptr2 + ldc;
+ float *c_ptr0 = C + (y * ldc);
+ float *c_ptr1 = (activerows > 1) ? c_ptr0 + ldc : dummy_buffer;
+ float *c_ptr2 = (activerows > 2) ? c_ptr1 + ldc : dummy_buffer;
+ float *c_ptr3 = (activerows > 3) ? c_ptr2 + ldc : dummy_buffer;
+
+ for (int x0=0; x0<N; x0+=16) {
+ const float *a_ptr0 = a_ptr0_base;
+ const float *a_ptr1 = a_ptr1_base;
+ const float *a_ptr2 = a_ptr2_base;
+ const float *a_ptr3 = a_ptr3_base;
+
+ const float *b_ptr = B + x0;
int loops = ((K + 4) / 8) - 1;
int odds = oddones;
@@ -228,34 +243,34 @@ void a64_sgemm_native_16x4(const float *A, int lda, const float *B, int ldb, flo
"ldr b3aq, [%[b_ptr], #48]\n"
// Unroll 2
- "fmla v16.4s, bb0.4s, a0.s[2]\n"
- "fmla v20.4s, bb0.4s, a1.s[2]\n"
- "add %[b_ptr], %[b_ptr], %[ldb]\n"
- "fmla v24.4s, bb0.4s, a2.s[2]\n"
- "fmla v28.4s, bb0.4s, a3.s[2]\n"
- "ldr b0q, [%[b_ptr]]\n"
-
- "fmla v17.4s, bb1.4s, a0.s[2]\n"
- "add %[a_ptr0], %[a_ptr0], #32\n"
- "fmla v21.4s, bb1.4s, a1.s[2]\n"
- "add %[a_ptr1], %[a_ptr1], #32\n"
- "fmla v25.4s, bb1.4s, a2.s[2]\n"
- "add %[a_ptr2], %[a_ptr2], #32\n"
- "fmla v29.4s, bb1.4s, a3.s[2]\n"
- "ldr b1q, [%[b_ptr], #16]\n"
-
- "fmla v18.4s, bb2.4s, a0.s[2]\n"
- "add %[a_ptr3], %[a_ptr3], #32\n"
- "fmla v22.4s, bb2.4s, a1.s[2]\n"
- "fmla v26.4s, bb2.4s, a2.s[2]\n"
- "fmla v30.4s, bb2.4s, a3.s[2]\n"
- "ldr b2q, [%[b_ptr], #32]\n"
-
- "fmla v19.4s, bb3.4s, a0.s[2]\n"
- "fmla v23.4s, bb3.4s, a1.s[2]\n"
- "fmla v27.4s, bb3.4s, a2.s[2]\n"
- "fmla v31.4s, bb3.4s, a3.s[2]\n"
- "ldr b3q, [%[b_ptr], #48]\n"
+ "fmla v16.4s, bb0.4s, a0.s[2]\n"
+ "fmla v20.4s, bb0.4s, a1.s[2]\n"
+ "add %[b_ptr], %[b_ptr], %[ldb]\n"
+ "fmla v24.4s, bb0.4s, a2.s[2]\n"
+ "fmla v28.4s, bb0.4s, a3.s[2]\n"
+ "ldr b0q, [%[b_ptr]]\n"
+
+ "fmla v17.4s, bb1.4s, a0.s[2]\n"
+ "add %[a_ptr0], %[a_ptr0], #32\n"
+ "fmla v21.4s, bb1.4s, a1.s[2]\n"
+ "add %[a_ptr1], %[a_ptr1], %[a_incr1]\n"
+ "fmla v25.4s, bb1.4s, a2.s[2]\n"
+ "add %[a_ptr2], %[a_ptr2], %[a_incr2]\n"
+ "fmla v29.4s, bb1.4s, a3.s[2]\n"
+ "ldr b1q, [%[b_ptr], #16]\n"
+
+ "fmla v18.4s, bb2.4s, a0.s[2]\n"
+ "add %[a_ptr3], %[a_ptr3], %[a_incr3]\n"
+ "fmla v22.4s, bb2.4s, a1.s[2]\n"
+ "fmla v26.4s, bb2.4s, a2.s[2]\n"
+ "fmla v30.4s, bb2.4s, a3.s[2]\n"
+ "ldr b2q, [%[b_ptr], #32]\n"
+
+ "fmla v19.4s, bb3.4s, a0.s[2]\n"
+ "fmla v23.4s, bb3.4s, a1.s[2]\n"
+ "fmla v27.4s, bb3.4s, a2.s[2]\n"
+ "fmla v31.4s, bb3.4s, a3.s[2]\n"
+ "ldr b3q, [%[b_ptr], #48]\n"
// Unroll 3
"fmla v16.4s, b0a.4s, a0.s[3]\n"
@@ -427,35 +442,35 @@ void a64_sgemm_native_16x4(const float *A, int lda, const float *B, int ldb, flo
"ldr b3q, [%[b_ptr], #48]\n"
// Unroll 1
- "fmla v16.4s, b0a.4s, a0.s[1]\n"
- "add %[b_ptr], %[b_ptr], %[ldb]\n"
- "fmla v20.4s, b0a.4s, a1.s[1]\n"
- "ldr a3aq, [%[a_ptr3], #16]\n"
- "fmla v24.4s, b0a.4s, a2.s[1]\n"
- "fmla v28.4s, b0a.4s, a3.s[1]\n"
- "ldr b0aq, [%[b_ptr]]\n"
-
- "fmla v17.4s, b1a.4s, a0.s[1]\n"
- "add %[a_ptr0], %[a_ptr0], #32\n"
- "fmla v21.4s, b1a.4s, a1.s[1]\n"
- "add %[a_ptr1], %[a_ptr1], #32\n"
- "fmla v25.4s, b1a.4s, a2.s[1]\n"
- "add %[a_ptr2], %[a_ptr2], #32\n"
- "fmla v29.4s, b1a.4s, a3.s[1]\n"
- "ldr b1aq, [%[b_ptr], #16]\n"
-
- "fmla v18.4s, b2a.4s, a0.s[1]\n"
- "fmla v22.4s, b2a.4s, a1.s[1]\n"
- "add %[a_ptr3], %[a_ptr3], #32\n"
- "fmla v26.4s, b2a.4s, a2.s[1]\n"
- "fmla v30.4s, b2a.4s, a3.s[1]\n"
- "ldr b2aq, [%[b_ptr], #32]\n"
-
- "fmla v19.4s, b3a.4s, a0.s[1]\n"
- "fmla v23.4s, b3a.4s, a1.s[1]\n"
- "fmla v27.4s, b3a.4s, a2.s[1]\n"
- "fmla v31.4s, b3a.4s, a3.s[1]\n"
- "ldr b3aq, [%[b_ptr], #48]\n"
+ "fmla v16.4s, b0a.4s, a0.s[1]\n"
+ "add %[b_ptr], %[b_ptr], %[ldb]\n"
+ "fmla v20.4s, b0a.4s, a1.s[1]\n"
+ "ldr a3aq, [%[a_ptr3], #16]\n"
+ "fmla v24.4s, b0a.4s, a2.s[1]\n"
+ "fmla v28.4s, b0a.4s, a3.s[1]\n"
+ "ldr b0aq, [%[b_ptr]]\n"
+
+ "fmla v17.4s, b1a.4s, a0.s[1]\n"
+ "add %[a_ptr0], %[a_ptr0], #32\n"
+ "fmla v21.4s, b1a.4s, a1.s[1]\n"
+ "add %[a_ptr1], %[a_ptr1], %[a_incr1]\n"
+ "fmla v25.4s, b1a.4s, a2.s[1]\n"
+ "add %[a_ptr2], %[a_ptr2], %[a_incr2]\n"
+ "fmla v29.4s, b1a.4s, a3.s[1]\n"
+ "ldr b1aq, [%[b_ptr], #16]\n"
+
+ "fmla v18.4s, b2a.4s, a0.s[1]\n"
+ "fmla v22.4s, b2a.4s, a1.s[1]\n"
+ "add %[a_ptr3], %[a_ptr3], %[a_incr3]\n"
+ "fmla v26.4s, b2a.4s, a2.s[1]\n"
+ "fmla v30.4s, b2a.4s, a3.s[1]\n"
+ "ldr b2aq, [%[b_ptr], #32]\n"
+
+ "fmla v19.4s, b3a.4s, a0.s[1]\n"
+ "fmla v23.4s, b3a.4s, a1.s[1]\n"
+ "fmla v27.4s, b3a.4s, a2.s[1]\n"
+ "fmla v31.4s, b3a.4s, a3.s[1]\n"
+ "ldr b3aq, [%[b_ptr], #48]\n"
// Unroll 2
"fmla v16.4s, bb0.4s, a0.s[2]\n"
@@ -848,18 +863,24 @@ void a64_sgemm_native_16x4(const float *A, int lda, const float *B, int ldb, flo
"str q27, [%[c_ptr2], #48]\n"
"3:\n"
- "str q28, [%[c_ptr3]]\n"
- "str q29, [%[c_ptr3], #16]\n"
- "str q30, [%[c_ptr3], #32]\n"
- "str q31, [%[c_ptr3], #48]\n"
-
- : [a_ptr0] "+r"(a_ptr0), [a_ptr1] "+r"(a_ptr1), [a_ptr2] "+r"(a_ptr2), [a_ptr3] "+r"(a_ptr3),
- [b_ptr] "+r"(b_ptr), [loops] "+r"(loops), [odds] "+r"(odds)
- : [ldb] "r"(ldbb), [oddk] "r"(oddk), [beta0] "r"(beta0), [betaptr] "r"(&beta),
- [c_ptr0] "r"(c_ptr0), [c_ptr1] "r"(c_ptr1), [c_ptr2] "r"(c_ptr2), [c_ptr3] "r"(c_ptr3)
- : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15",
- "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31",
- "cc", "memory");
+ "str q28, [%[c_ptr3]]\n"
+ "add %[c_ptr0], %[c_ptr0], #64\n"
+ "str q29, [%[c_ptr3], #16]\n"
+ "add %[c_ptr1], %[c_ptr1], %[a_incr1], LSL #1\n"
+ "str q30, [%[c_ptr3], #32]\n"
+ "add %[c_ptr2], %[c_ptr2], %[a_incr2], LSL #1\n"
+ "str q31, [%[c_ptr3], #48]\n"
+ "add %[c_ptr3], %[c_ptr3], %[a_incr3], LSL #1\n"
+
+ : [a_ptr0] "+r" (a_ptr0), [a_ptr1] "+r" (a_ptr1), [a_ptr2] "+r" (a_ptr2), [a_ptr3] "+r" (a_ptr3),
+ [b_ptr] "+r" (b_ptr), [loops] "+r" (loops), [odds] "+r" (odds),
+ [c_ptr0] "+r" (c_ptr0), [c_ptr1] "+r" (c_ptr1), [c_ptr2] "+r" (c_ptr2), [c_ptr3] "+r" (c_ptr3)
+ : [ldb] "r" (ldbb), [oddk] "r" (oddk), [beta0] "r" (beta0), [betaptr] "r" (&beta),
+ [a_incr1] "r" (a_incr1), [a_incr2] "r" (a_incr2), [a_incr3] "r" (a_incr3)
+ : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15",
+ "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31",
+ "cc", "memory"
+ );
}
}
}
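The final hunk also reworks the extended-asm operand lists: because the stores now advance the C pointers inside the asm block ("add %[c_ptr0], %[c_ptr0], #64" and friends), c_ptr0..c_ptr3 move from the input list to the output list with a "+r" (read-write) constraint, and the new per-row increments a_incr1..a_incr3 are passed as plain "r" inputs. Below is a minimal, AArch64-only sketch of that constraint pattern; it is an illustrative standalone program, not code from the library.

#include <cstdio>

int main() {
    float buf[32] = {};                   // plenty of room for one 16-byte store
    float *ptr = buf;
    const unsigned long incr = 64;        // bytes to advance per block, passed as a plain input

    __asm__ __volatile__(
        "movi v0.4s, #0\n"                // zero a vector register
        "str  q0, [%[ptr]]\n"             // store through the pointer...
        "add  %[ptr], %[ptr], %[incr]\n"  // ...then bump it inside the asm
        : [ptr] "+r"(ptr)                 // "+r": the asm both reads and updates ptr
        : [incr] "r"(incr)
        : "v0", "memory");

    std::printf("pointer advanced by %td floats\n", ptr - buf);  // prints 16
    return 0;
}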