author     Pablo Tello <pablo.tello@arm.com>    2018-02-23 13:43:50 +0000
committer  Anthony Barbier <anthony.barbier@arm.com>    2018-11-02 16:49:16 +0000
commit     eb82fd2aa786715c3b6a941dc6d6deac4ce8e2a0 (patch)
tree       42cca378eed97c07348f28e1ec708d9c7ed531ce /src/core/NEON/kernels/arm_gemm/kernels/a64_sgemv_trans
parent     8df6c452820719d201ee79596cde8445c2071db5 (diff)
download   ComputeLibrary-eb82fd2aa786715c3b6a941dc6d6deac4ce8e2a0.tar.gz
COMPMID-881: RSH new arm_gemm interface.
Change-Id: I1e2a1a77097d8017c274af3f97eba6964f80f5fa
Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/122592
Tested-by: Jenkins <bsgcomp@arm.com>
Reviewed-by: Anthony Barbier <anthony.barbier@arm.com>
Diffstat (limited to 'src/core/NEON/kernels/arm_gemm/kernels/a64_sgemv_trans')
-rw-r--r--  src/core/NEON/kernels/arm_gemm/kernels/a64_sgemv_trans/generic.cpp  913
1 file changed, 913 insertions(+), 0 deletions(-)
diff --git a/src/core/NEON/kernels/arm_gemm/kernels/a64_sgemv_trans/generic.cpp b/src/core/NEON/kernels/arm_gemm/kernels/a64_sgemv_trans/generic.cpp
new file mode 100644
index 0000000000..3309baff3a
--- /dev/null
+++ b/src/core/NEON/kernels/arm_gemm/kernels/a64_sgemv_trans/generic.cpp
@@ -0,0 +1,913 @@
+/*
+ * Copyright (c) 2017-2018 ARM Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+#ifdef __aarch64__
+
+#include <cstddef>
+
+#include <arm_neon.h>
+
+#include "../../asmlib.hpp"
+#include "../../utils.hpp"
+
+// Kernel implementation - transposed GEMV
+//
+// The kernel will process "M" rows of A (= steps of dot product) and "N"
+// columns (= dot products total)
+//
+// General plan is to do as many columns simultaneously as possible - a
+// reasonable limit is most of the NEON regfile; this kernel uses 24 of the
+// 32 vector registers as accumulators, i.e. 96 columns per pass.
+//
+// It's possible that messing around with sub-blocking M and N can yield
+// higher performance, but that's left to the outer loop. In this kernel we
+// process all of M at the same time.
+
+// How far ahead (in rows of A) to prefetch for the leading ("first") and
+// main ("subsequent") prefetch streams.
+// These values work well for Cortex-A72 on Juno R2...
+
+#define FIRST_PFD 9
+#define PFD 6
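+
+// For reference, a plain scalar model of what this kernel computes.  This is
+// a sketch only - it is not called by the optimised path below, and the
+// SGEMV_TRANS_REFERENCE_MODEL guard is a hypothetical switch that keeps it
+// compiled out.  A is row-major with leading dimension lda (in floats), and
+// y accumulates alpha * (A^T * x).
+#ifdef SGEMV_TRANS_REFERENCE_MODEL
+static void sgemv_trans_reference(const float *A, const float *x, float *y, float alpha, int lda, int M, int N)
+{
+ // One dot product per output column j.
+ for(int j = 0; j < N; j++)
+ {
+ float acc = 0.0f;
+
+ for(int i = 0; i < M; i++)
+ {
+ acc += A[(ptrdiff_t)i * lda + j] * x[i];
+ }
+
+ y[j] += alpha * acc;
+ }
+}
+#endif // SGEMV_TRANS_REFERENCE_MODEL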
+
+namespace arm_gemm
+{
+void a64_sgemv_trans(const float *Astart, const float *Xstart, float *Ystart, float alpha, int lda, int M, int N)
+{
+ const float *a_ptr_base = Astart;
+ float *y_ptr = Ystart;
+
+ register const float32x4_t va asm("v1") = vdupq_n_f32(alpha);
+
+ int firstpfd = FIRST_PFD;
+ if(firstpfd > M)
+ {
+ firstpfd = (M - 1);
+ }
+
+ int pfd = PFD;
+ if(pfd > M)
+ {
+ pfd = (M - 1);
+ }
+
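+ // Row stride of A in bytes; sizeof(int) == sizeof(float) == 4 on AArch64,
+ // so this matches the float element size.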
+ ptrdiff_t jump = lda * sizeof(int);
+
+ for(; N >= 96; N -= 96)
+ {
+ int k = M - 1;
+
+ const float *a_ptr = a_ptr_base;
+ const float *x_ptr = Xstart;
+ const float *pf_ptr = a_ptr;
+ const float *firstpf_ptr = a_ptr;
+ const float *pf_limit = a_ptr + (M * lda);
+
+ for(int i = 0; i < firstpfd; i++)
+ {
+ prefetch_1x(firstpf_ptr);
+ firstpf_ptr += lda;
+ }
+
+ for(int i = 0; i < pfd; i++)
+ {
+ prefetch_5x(pf_ptr + 16);
+ pf_ptr += lda;
+ }
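+
+ // At this point the leading stream has touched the start of rows
+ // 0..firstpfd-1, and the block prefetches have pulled in the rest of the
+ // 96-column (384-byte) row blocks for rows 0..pfd-1; the inline assembly
+ // below keeps both streams running ahead of the compute.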
+
+ a_ptr_base += 96;
+
+ __asm __volatile(
+ "movi v8.4s,#0x0\n"
+ "ldr w0, [%[x_ptr]]\n"
+ "movi v9.4s,#0x0\n"
+ "ldr q2, [%[a_ptr], #0]\n"
+ "movi v10.4s,#0x0\n"
+ "ldr q3, [%[a_ptr], #0x10]\n"
+ "movi v11.4s,#0x0\n"
+ "ldr q4, [%[a_ptr], #0x20]\n"
+ "movi v12.4s,#0x0\n"
+ "ldr q5, [%[a_ptr], #0x30]\n"
+ "movi v13.4s,#0x0\n"
+ "ldr q6, [%[a_ptr], #0x40]\n"
+ "movi v14.4s,#0x0\n"
+ "ldr q7, [%[a_ptr], #0x50]\n"
+ "movi v15.4s,#0x0\n" ASM_PREFETCH("[%[firstpf_ptr]]")
+ "movi v16.4s, #0x0\n"
+ "movi v17.4s, #0x0\n" ASM_PREFETCH("[%[pf_ptr], #64]")
+ "movi v18.4s, #0x0\n"
+ "movi v19.4s, #0x0\n" ASM_PREFETCH("[%[pf_ptr], #128]")
+ "movi v20.4s, #0x0\n"
+ "movi v21.4s, #0x0\n" ASM_PREFETCH("[%[pf_ptr], #192]")
+ "movi v22.4s, #0x0\n"
+ "movi v23.4s, #0x0\n" ASM_PREFETCH("[%[pf_ptr], #256]")
+ "movi v24.4s, #0x0\n"
+ "movi v25.4s, #0x0\n" ASM_PREFETCH("[%[pf_ptr], #320]")
+ "movi v26.4s, #0x0\n"
+ "movi v27.4s, #0x0\n"
+ "add %[pf_ptr], %[pf_ptr], %[jump]\n"
+ "movi v28.4s, #0x0\n"
+ "add %[firstpf_ptr], %[firstpf_ptr], %[jump]\n"
+ "movi v29.4s, #0x0\n"
+ "movi v30.4s, #0x0\n"
+ "movi v31.4s, #0x0\n"
+
+ // Skip everything if there are no iterations of the main loop to do.
+ "cbz %w[k], 10f\n"
+
+ // Loop with all prefetches. Exit this loop when firstpf_ptr
+ // hits pf_limit.
+ "1:\n"
+ "dup v0.4s, w0\n"
+ "ldr w0, [%[x_ptr], #4]\n"
+ "add %[x_ptr], %[x_ptr], #0x4\n"
+ "fmla v8.4s, v2.4s, v0.4s\n"
+ "ldr q2, [%[a_ptr], #0x60]\n"
+ "fmla v9.4s, v3.4s, v0.4s\n"
+ "ldr q3, [%[a_ptr], #0x70]\n" ASM_PREFETCH("[%[firstpf_ptr]]")
+ "fmla v10.4s, v4.4s, v0.4s\n"
+ "ldr q4, [%[a_ptr], #0x80]\n"
+ "add %[firstpf_ptr], %[firstpf_ptr], %[jump]\n"
+ "fmla v11.4s, v5.4s, v0.4s\n"
+ "ldr q5, [%[a_ptr], #0x90]\n"
+ "sub %w[k], %w[k], #1\n" ASM_PREFETCH("[%[x_ptr], #128]")
+ "fmla v12.4s, v6.4s, v0.4s\n"
+ "ldr q6, [%[a_ptr], #0xa0]\n"
+ "fmla v13.4s, v7.4s, v0.4s\n"
+ "ldr q7, [%[a_ptr], #0xb0]\n" ASM_PREFETCH("[%[pf_ptr], #0x40]")
+ "fmla v14.4s, v2.4s, v0.4s\n"
+ "ldr q2, [%[a_ptr], #0xc0]\n"
+ "fmla v15.4s, v3.4s, v0.4s\n"
+ "ldr q3, [%[a_ptr], #0xd0]\n"
+ "fmla v16.4s, v4.4s, v0.4s\n"
+ "ldr q4, [%[a_ptr], #0xe0]\n"
+ "fmla v17.4s, v5.4s, v0.4s\n"
+ "ldr q5, [%[a_ptr], #0xf0]\n" ASM_PREFETCH("[%[pf_ptr], #0x80]")
+ "fmla v18.4s, v6.4s, v0.4s\n"
+ "ldr q6, [%[a_ptr], #0x100]\n"
+ "fmla v19.4s, v7.4s, v0.4s\n"
+ "ldr q7, [%[a_ptr], #0x110]\n"
+ "fmla v20.4s, v2.4s, v0.4s\n"
+ "ldr q2, [%[a_ptr], #0x120]\n"
+ "fmla v21.4s, v3.4s, v0.4s\n"
+ "ldr q3, [%[a_ptr], #0x130]\n" ASM_PREFETCH("[%[pf_ptr], #0xc0]")
+ "fmla v22.4s, v4.4s, v0.4s\n"
+ "ldr q4, [%[a_ptr], #0x140]\n"
+ "fmla v23.4s, v5.4s, v0.4s\n"
+ "ldr q5, [%[a_ptr], #0x150]\n"
+ "fmla v24.4s, v6.4s, v0.4s\n"
+ "ldr q6, [%[a_ptr], #0x160]\n"
+ "fmla v25.4s, v7.4s, v0.4s\n"
+ "ldr q7, [%[a_ptr], #0x170]\n" ASM_PREFETCH("[%[pf_ptr], #0x100]")
+ "add %[a_ptr], %[a_ptr], %[jump]\n"
+ "fmla v26.4s, v2.4s, v0.4s\n"
+ "ldr q2, [%[a_ptr], #0x00]\n"
+ "fmla v27.4s, v3.4s, v0.4s\n"
+ "ldr q3, [%[a_ptr], #0x10]\n"
+ "fmla v28.4s, v4.4s, v0.4s\n"
+ "ldr q4, [%[a_ptr], #0x20]\n"
+ "fmla v29.4s, v5.4s, v0.4s\n"
+ "ldr q5, [%[a_ptr], #0x30]\n" ASM_PREFETCH("[%[pf_ptr], #0x140]")
+ "fmla v30.4s, v6.4s, v0.4s\n"
+ "add %[pf_ptr], %[pf_ptr], %[jump]\n"
+ "ldr q6, [%[a_ptr], #0x40]\n"
+ "fmla v31.4s, v7.4s, v0.4s\n"
+ "cmp %[firstpf_ptr], %[pf_limit]\n"
+ "ldr q7, [%[a_ptr], #0x50]\n"
+ "blt 1b\n"
+
+ // Check that there are still "main" prefetches to do.
+ "cmp %[pf_ptr], %[pf_limit]\n"
+ "bge 9f\n"
+
+ // Just the main prefetches, exit this loop when pf_ptr hits pf_limit.
+ "8:\n"
+ "dup v0.4s, w0\n"
+ "ldr w0, [%[x_ptr], #4]\n"
+ "add %[x_ptr], %[x_ptr], #0x4\n"
+ "fmla v8.4s, v2.4s, v0.4s\n"
+ "ldr q2, [%[a_ptr], #0x60]\n"
+ "fmla v9.4s, v3.4s, v0.4s\n"
+ "ldr q3, [%[a_ptr], #0x70]\n"
+ "fmla v10.4s, v4.4s, v0.4s\n"
+ "ldr q4, [%[a_ptr], #0x80]\n"
+ "fmla v11.4s, v5.4s, v0.4s\n"
+ "ldr q5, [%[a_ptr], #0x90]\n"
+ "sub %w[k], %w[k], #1\n" ASM_PREFETCH("[%[x_ptr], #128]")
+ "fmla v12.4s, v6.4s, v0.4s\n"
+ "ldr q6, [%[a_ptr], #0xa0]\n"
+ "fmla v13.4s, v7.4s, v0.4s\n"
+ "ldr q7, [%[a_ptr], #0xb0]\n" ASM_PREFETCH("[%[pf_ptr], #0x40]")
+ "fmla v14.4s, v2.4s, v0.4s\n"
+ "ldr q2, [%[a_ptr], #0xc0]\n"
+ "fmla v15.4s, v3.4s, v0.4s\n"
+ "ldr q3, [%[a_ptr], #0xd0]\n"
+ "fmla v16.4s, v4.4s, v0.4s\n"
+ "ldr q4, [%[a_ptr], #0xe0]\n"
+ "fmla v17.4s, v5.4s, v0.4s\n"
+ "ldr q5, [%[a_ptr], #0xf0]\n" ASM_PREFETCH("[%[pf_ptr], #0x80]")
+ "fmla v18.4s, v6.4s, v0.4s\n"
+ "ldr q6, [%[a_ptr], #0x100]\n"
+ "fmla v19.4s, v7.4s, v0.4s\n"
+ "ldr q7, [%[a_ptr], #0x110]\n"
+ "fmla v20.4s, v2.4s, v0.4s\n"
+ "ldr q2, [%[a_ptr], #0x120]\n"
+ "fmla v21.4s, v3.4s, v0.4s\n"
+ "ldr q3, [%[a_ptr], #0x130]\n" ASM_PREFETCH("[%[pf_ptr], #0xc0]")
+ "fmla v22.4s, v4.4s, v0.4s\n"
+ "ldr q4, [%[a_ptr], #0x140]\n"
+ "fmla v23.4s, v5.4s, v0.4s\n"
+ "ldr q5, [%[a_ptr], #0x150]\n"
+ "fmla v24.4s, v6.4s, v0.4s\n"
+ "ldr q6, [%[a_ptr], #0x160]\n"
+ "fmla v25.4s, v7.4s, v0.4s\n"
+ "ldr q7, [%[a_ptr], #0x170]\n" ASM_PREFETCH("[%[pf_ptr], #0x100]")
+ "add %[a_ptr], %[a_ptr], %[jump]\n"
+ "fmla v26.4s, v2.4s, v0.4s\n"
+ "ldr q2, [%[a_ptr], #0x00]\n"
+ "fmla v27.4s, v3.4s, v0.4s\n"
+ "ldr q3, [%[a_ptr], #0x10]\n"
+ "fmla v28.4s, v4.4s, v0.4s\n"
+ "ldr q4, [%[a_ptr], #0x20]\n"
+ "fmla v29.4s, v5.4s, v0.4s\n"
+ "ldr q5, [%[a_ptr], #0x30]\n" ASM_PREFETCH("[%[pf_ptr], #0x140]")
+ "fmla v30.4s, v6.4s, v0.4s\n"
+ "add %[pf_ptr], %[pf_ptr], %[jump]\n"
+ "ldr q6, [%[a_ptr], #0x40]\n"
+ "fmla v31.4s, v7.4s, v0.4s\n"
+ "cmp %[pf_ptr], %[pf_limit]\n"
+ "ldr q7, [%[a_ptr], #0x50]\n"
+ "blt 8b\n"
+
+ // Check that there is still work to do.
+ "9:\n"
+ "cmp %w[k], #0\n"
+ "beq 10f\n"
+
+ // Loop without prefetches, exit when k hits 0.
+ "2:\n"
+ "dup v0.4s, w0\n"
+ "ldr w0, [%[x_ptr], #4]\n"
+ "add %[x_ptr], %[x_ptr], #0x4\n"
+ "fmla v8.4s, v2.4s, v0.4s\n"
+ "ldr q2, [%[a_ptr], #0x60]\n"
+ "fmla v9.4s, v3.4s, v0.4s\n"
+ "ldr q3, [%[a_ptr], #0x70]\n"
+ "fmla v10.4s, v4.4s, v0.4s\n"
+ "ldr q4, [%[a_ptr], #0x80]\n"
+ "fmla v11.4s, v5.4s, v0.4s\n"
+ "ldr q5, [%[a_ptr], #0x90]\n"
+ "subs %w[k], %w[k], #1\n"
+ "fmla v12.4s, v6.4s, v0.4s\n"
+ "ldr q6, [%[a_ptr], #0xa0]\n"
+ "fmla v13.4s, v7.4s, v0.4s\n"
+ "ldr q7, [%[a_ptr], #0xb0]\n"
+ "fmla v14.4s, v2.4s, v0.4s\n"
+ "ldr q2, [%[a_ptr], #0xc0]\n"
+ "fmla v15.4s, v3.4s, v0.4s\n"
+ "ldr q3, [%[a_ptr], #0xd0]\n"
+ "fmla v16.4s, v4.4s, v0.4s\n"
+ "ldr q4, [%[a_ptr], #0xe0]\n"
+ "fmla v17.4s, v5.4s, v0.4s\n"
+ "ldr q5, [%[a_ptr], #0xf0]\n"
+ "fmla v18.4s, v6.4s, v0.4s\n"
+ "ldr q6, [%[a_ptr], #0x100]\n"
+ "fmla v19.4s, v7.4s, v0.4s\n"
+ "ldr q7, [%[a_ptr], #0x110]\n"
+ "fmla v20.4s, v2.4s, v0.4s\n"
+ "ldr q2, [%[a_ptr], #0x120]\n"
+ "fmla v21.4s, v3.4s, v0.4s\n"
+ "ldr q3, [%[a_ptr], #0x130]\n"
+ "fmla v22.4s, v4.4s, v0.4s\n"
+ "ldr q4, [%[a_ptr], #0x140]\n"
+ "fmla v23.4s, v5.4s, v0.4s\n"
+ "ldr q5, [%[a_ptr], #0x150]\n"
+ "fmla v24.4s, v6.4s, v0.4s\n"
+ "ldr q6, [%[a_ptr], #0x160]\n"
+ "fmla v25.4s, v7.4s, v0.4s\n"
+ "ldr q7, [%[a_ptr], #0x170]\n"
+ "add %[a_ptr], %[a_ptr], %[jump]\n"
+ "fmla v26.4s, v2.4s, v0.4s\n"
+ "ldr q2, [%[a_ptr], #0x00]\n"
+ "fmla v27.4s, v3.4s, v0.4s\n"
+ "ldr q3, [%[a_ptr], #0x10]\n"
+ "fmla v28.4s, v4.4s, v0.4s\n"
+ "ldr q4, [%[a_ptr], #0x20]\n"
+ "fmla v29.4s, v5.4s, v0.4s\n"
+ "ldr q5, [%[a_ptr], #0x30]\n"
+ "fmla v30.4s, v6.4s, v0.4s\n"
+ "ldr q6, [%[a_ptr], #0x40]\n"
+ "fmla v31.4s, v7.4s, v0.4s\n"
+ "ldr q7, [%[a_ptr], #0x50]\n"
+ "bne 2b\n"
+
+ "10:\n"
+
+ // Final iteration
+ "dup v0.4s, w0\n"
+ "fmla v8.4s, v2.4s, v0.4s\n"
+ "ldr q2, [%[a_ptr], #0x60]\n"
+ "fmla v9.4s, v3.4s, v0.4s\n"
+ "ldr q3, [%[a_ptr], #0x70]\n"
+ "fmla v10.4s, v4.4s, v0.4s\n"
+ "ldr q4, [%[a_ptr], #0x80]\n"
+ "fmla v11.4s, v5.4s, v0.4s\n"
+ "ldr q5, [%[a_ptr], #0x90]\n"
+ "fmla v12.4s, v6.4s, v0.4s\n"
+ "ldr q6, [%[a_ptr], #0xa0]\n"
+ "fmla v13.4s, v7.4s, v0.4s\n"
+ "ldr q7, [%[a_ptr], #0xb0]\n"
+ "fmla v14.4s, v2.4s, v0.4s\n"
+ "ldr q2, [%[a_ptr], #0xc0]\n"
+ "fmla v15.4s, v3.4s, v0.4s\n"
+ "ldr q3, [%[a_ptr], #0xd0]\n"
+ "fmla v16.4s, v4.4s, v0.4s\n"
+ "ldr q4, [%[a_ptr], #0xe0]\n"
+ "fmla v17.4s, v5.4s, v0.4s\n"
+ "ldr q5, [%[a_ptr], #0xf0]\n"
+ "fmla v18.4s, v6.4s, v0.4s\n"
+ "ldr q6, [%[a_ptr], #0x100]\n"
+ "fmla v19.4s, v7.4s, v0.4s\n"
+ "ldr q7, [%[a_ptr], #0x110]\n"
+ "fmla v20.4s, v2.4s, v0.4s\n"
+ "ldr q2, [%[a_ptr], #0x120]\n"
+ "fmla v21.4s, v3.4s, v0.4s\n"
+ "ldr q3, [%[a_ptr], #0x130]\n"
+ "fmla v22.4s, v4.4s, v0.4s\n"
+ "ldr q4, [%[a_ptr], #0x140]\n"
+ "fmla v23.4s, v5.4s, v0.4s\n"
+ "ldr q5, [%[a_ptr], #0x150]\n"
+ "fmla v24.4s, v6.4s, v0.4s\n"
+ "ldr q6, [%[a_ptr], #0x160]\n"
+ "fmla v25.4s, v7.4s, v0.4s\n"
+ "ldr q7, [%[a_ptr], #0x170]\n"
+ "fmla v26.4s, v2.4s, v0.4s\n"
+ "ldr q2, [%[y_ptr]]\n"
+ "fmla v27.4s, v3.4s, v0.4s\n"
+ "ldr q3, [%[y_ptr], #0x10]\n"
+ "fmla v28.4s, v4.4s, v0.4s\n"
+ "ldr q4, [%[y_ptr], #0x20]\n"
+ "fmla v29.4s, v5.4s, v0.4s\n"
+ "ldr q5, [%[y_ptr], #0x30]\n"
+ "fmla v30.4s, v6.4s, v0.4s\n"
+ "ldr q6, [%[y_ptr], #0x40]\n"
+ "fmla v31.4s, v7.4s, v0.4s\n"
+ "ldr q7, [%[y_ptr], #0x50]\n"
+
+ "fmla v2.4s, v8.4s, %[va].4s\n"
+ "ldr q8, [%[y_ptr], #0x60]\n"
+ "fmla v3.4s, v9.4s, %[va].4s\n"
+ "ldr q9, [%[y_ptr], #0x70]\n"
+ "fmla v4.4s, v10.4s, %[va].4s\n"
+ "ldr q10, [%[y_ptr], #0x80]\n"
+ "fmla v5.4s, v11.4s, %[va].4s\n"
+ "ldr q11, [%[y_ptr], #0x90]\n"
+ "fmla v6.4s, v12.4s, %[va].4s\n"
+ "ldr q12, [%[y_ptr], #0xa0]\n"
+ "str q2, [%[y_ptr], #0x00]\n"
+ "fmla v7.4s, v13.4s, %[va].4s\n"
+ "ldr q13, [%[y_ptr], #0xb0]\n"
+ "str q3, [%[y_ptr], #0x10]\n"
+ "fmla v8.4s, v14.4s, %[va].4s\n"
+ "ldr q14, [%[y_ptr], #0xc0]\n"
+ "str q4, [%[y_ptr], #0x20]\n"
+ "fmla v9.4s, v15.4s, %[va].4s\n"
+ "ldr q15, [%[y_ptr], #0xd0]\n"
+ "str q5, [%[y_ptr], #0x30]\n"
+ "fmla v10.4s, v16.4s, %[va].4s\n"
+ "ldr q16, [%[y_ptr], #0xe0]\n"
+ "str q6, [%[y_ptr], #0x40]\n"
+ "fmla v11.4s, v17.4s, %[va].4s\n"
+ "ldr q17, [%[y_ptr], #0xf0]\n"
+ "str q7, [%[y_ptr], #0x50]\n"
+ "fmla v12.4s, v18.4s, %[va].4s\n"
+ "ldr q18, [%[y_ptr], #0x100]\n"
+ "str q8, [%[y_ptr], #0x60]\n"
+ "fmla v13.4s, v19.4s, %[va].4s\n"
+ "ldr q19, [%[y_ptr], #0x110]\n"
+ "str q9, [%[y_ptr], #0x70]\n"
+ "fmla v14.4s, v20.4s, %[va].4s\n"
+ "ldr q20, [%[y_ptr], #0x120]\n"
+ "str q10, [%[y_ptr], #0x80]\n"
+ "fmla v15.4s, v21.4s, %[va].4s\n"
+ "ldr q21, [%[y_ptr], #0x130]\n"
+ "str q11, [%[y_ptr], #0x90]\n"
+ "fmla v16.4s, v22.4s, %[va].4s\n"
+ "ldr q22, [%[y_ptr], #0x140]\n"
+ "str q12, [%[y_ptr], #0xa0]\n"
+ "fmla v17.4s, v23.4s, %[va].4s\n"
+ "ldr q23, [%[y_ptr], #0x150]\n"
+ "str q13, [%[y_ptr], #0xb0]\n"
+ "fmla v18.4s, v24.4s, %[va].4s\n"
+ "ldr q24, [%[y_ptr], #0x160]\n"
+ "str q14, [%[y_ptr], #0xc0]\n"
+ "fmla v19.4s, v25.4s, %[va].4s\n"
+ "ldr q25, [%[y_ptr], #0x170]\n"
+ "str q15, [%[y_ptr], #0xd0]\n"
+ "fmla v20.4s, v26.4s, %[va].4s\n"
+ "str q16, [%[y_ptr], #0xe0]\n"
+ "fmla v21.4s, v27.4s, %[va].4s\n"
+ "str q17, [%[y_ptr], #0xf0]\n"
+ "fmla v22.4s, v28.4s, %[va].4s\n"
+ "str q18, [%[y_ptr], #0x100]\n"
+ "fmla v23.4s, v29.4s, %[va].4s\n"
+ "str q19, [%[y_ptr], #0x110]\n"
+ "fmla v24.4s, v30.4s, %[va].4s\n"
+ "str q20, [%[y_ptr], #0x120]\n"
+ "fmla v25.4s, v31.4s, %[va].4s\n"
+ "str q21, [%[y_ptr], #0x130]\n"
+
+ "stp q22, q23, [%[y_ptr], #0x140]\n"
+ "stp q24, q25, [%[y_ptr], #0x160]\n"
+ "add %[y_ptr], %[y_ptr], #0x180\n"
+
+ : [a_ptr] "+r"(a_ptr), [x_ptr] "+r"(x_ptr), [y_ptr] "+r"(y_ptr), [k] "+r"(k), [pf_ptr] "+r"(pf_ptr), [firstpf_ptr] "+r"(firstpf_ptr)
+ : [jump] "r"(jump), [va] "w"(va), [pf_limit] "r"(pf_limit)
+ : "w0", "v0", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13",
+ "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26",
+ "v27", "v28", "v29", "v30", "v31", "cc");
+ }
+
+ if(N > 0)
+ {
+ // Handle the N tail - up to 95 stragglers.
+ // This is 0-23 full vectors, plus optionally a 64-bit (2-element) chunk
+ // and/or a single value for the remainder.
+
+ // Independent pointers into the matrix for the odd 2-element and odd
+ // 1-element cases; these double up as flags to indicate whether each is
+ // needed.
+ const float *odd2_aptr = NULL;
+ const float *odd1_aptr = NULL;
+
+ // Figure out how much work we need to do.
+ int numvecs = N / 4;
+ int rem = N % 4;
+ int k = M;
+
+ // Set up pointers for the odd 2/1 if needed.
+ if(rem >= 2)
+ {
+ odd2_aptr = a_ptr_base + (numvecs * 4);
+ }
+
+ if(rem & 1)
+ {
+ odd1_aptr = a_ptr_base + (numvecs * 4) + (odd2_aptr == NULL ? 0 : 2);
+ }
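+
+ // Worked example: N = 95 gives numvecs = 23 (columns 0-91) and rem = 3,
+ // so odd2_aptr covers columns 92-93 and odd1_aptr covers column 94.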
+
+ const float *a_ptr = a_ptr_base;
+ const float *firstpf_ptr = a_ptr_base;
+ const float *pf_ptr = a_ptr_base;
+ const float *pf_limit = a_ptr + (M * lda);
+
+ const float *x_ptr = Xstart;
+ int vecs = 0; // Working variable to count how many vectors to work on.
+ int dopf = 1; // Track whether we are doing prefetches.
+
+ // Figure out how many cache lines we need to prefetch each time.
+ int numpfs = (N + 15) / 16;
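+ // (16 floats = one 64-byte cache line; e.g. N = 95 -> numpfs = 6.)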
+
+ // Do initial prefetches
+ for(int i = 0; i < firstpfd + 1; i++)
+ {
+ prefetch_1x(firstpf_ptr);
+ firstpf_ptr += lda;
+ }
+
+ // Do "main" prefetches - adapt number to the number we actually need.
+ if(numpfs > 1)
+ {
+ for(int i = 0; i < pfd + 1; i++)
+ {
+ switch(numpfs)
+ {
+ case 2:
+ prefetch_1x(pf_ptr + 16);
+ break;
+
+ case 3:
+ prefetch_2x(pf_ptr + 16);
+ break;
+
+ case 4:
+ prefetch_3x(pf_ptr + 16);
+ break;
+
+ case 5:
+ prefetch_4x(pf_ptr + 16);
+ break;
+
+ case 6:
+ prefetch_5x(pf_ptr + 16);
+ break;
+
+ default:
+ UNREACHABLE("Impossible.");
+ }
+ pf_ptr += lda;
+ }
+ }
+ else
+ {
+ // Just disable additional prefetches
+ dopf = 0;
+ }
+
+ // Do the real work
+ __asm __volatile(
+ // Initialize all the vectors - not worth skipping this if only
+ // some are needed.
+ "movi v8.4s,#0x0\n"
+ "ldr w0, [%[x_ptr]]\n"
+ "movi v9.4s,#0x0\n"
+ "movi v10.4s,#0x0\n"
+ "movi v11.4s,#0x0\n"
+ "movi v12.4s,#0x0\n"
+ "movi v13.4s,#0x0\n"
+ "movi v14.4s,#0x0\n"
+ "movi v15.4s,#0x0\n"
+ "movi v16.4s, #0x0\n"
+ "movi v17.4s, #0x0\n"
+ "movi v18.4s, #0x0\n"
+ "movi v19.4s, #0x0\n"
+ "movi v20.4s, #0x0\n"
+ "movi v21.4s, #0x0\n"
+ "movi v22.4s, #0x0\n"
+ "movi v23.4s, #0x0\n"
+ "movi v24.4s, #0x0\n"
+ "movi v25.4s, #0x0\n"
+ "movi v26.4s, #0x0\n"
+ "movi v27.4s, #0x0\n"
+ "movi v28.4s, #0x0\n"
+ "movi v29.4s, #0x0\n"
+ "movi v30.4s, #0x0\n"
+ "movi v6.2s, #0x0\n"
+ "movi v5.2s, #0x0\n"
+
+ "1:\n" ASM_PREFETCH("[%[firstpf_ptr]]\n")
+ "11:\n"
+ "dup v0.4s, w0\n"
+ "ldr w0, [%[x_ptr], #4]\n"
+ "add %[x_ptr], %[x_ptr], #4\n"
+
+ "cbz %w[numvecs], 2f\n"
+ "mov %w[vecs], %w[numvecs]\n"
+
+ // Vector 0
+ "subs %w[vecs], %w[vecs], #1\n"
+ "ldr q7,[%[a_ptr], #0x00]\n"
+ "fmla v8.4s, v7.4s, v0.4s\n"
+ "beq 2f\n"
+ // Vector 1
+ "subs %w[vecs], %w[vecs], #1\n"
+ "ldr q7,[%[a_ptr], #0x10]\n"
+ "fmla v9.4s, v7.4s, v0.4s\n"
+ "beq 2f\n"
+ // Vector 2
+ "subs %w[vecs], %w[vecs], #1\n"
+ "ldr q7,[%[a_ptr], #0x20]\n"
+ "fmla v10.4s, v7.4s, v0.4s\n"
+ "beq 2f\n"
+ // Vector 3
+ "subs %w[vecs], %w[vecs], #1\n"
+ "ldr q7,[%[a_ptr], #0x30]\n"
+ "fmla v11.4s, v7.4s, v0.4s\n"
+ // Prefetch
+ "cbz %w[dopf], 3f\n" ASM_PREFETCH("[%[pf_ptr], #0x40]")
+ "3:\n"
+ "beq 2f\n"
+
+ // Vector 4
+ "subs %w[vecs], %w[vecs], #1\n"
+ "ldr q7,[%[a_ptr], #0x40]\n"
+ "fmla v12.4s, v7.4s, v0.4s\n"
+ "beq 2f\n"
+ // Vector 5
+ "subs %w[vecs], %w[vecs], #1\n"
+ "ldr q7,[%[a_ptr], #0x50]\n"
+ "fmla v13.4s, v7.4s, v0.4s\n"
+ "beq 2f\n"
+ // Vector 6
+ "subs %w[vecs], %w[vecs], #1\n"
+ "ldr q7,[%[a_ptr], #0x60]\n"
+ "fmla v14.4s, v7.4s, v0.4s\n"
+ "beq 2f\n"
+ // Vector 7
+ "subs %w[vecs], %w[vecs], #1\n"
+ "ldr q7,[%[a_ptr], #0x70]\n"
+ "fmla v15.4s, v7.4s, v0.4s\n"
+ // Prefetch
+ "cbz %w[dopf], 4f\n" ASM_PREFETCH("[%[pf_ptr], #0x80]")
+ "4:\n"
+ "beq 2f\n"
+
+ // Vector 8
+ "subs %w[vecs], %w[vecs], #1\n"
+ "ldr q7,[%[a_ptr], #0x80]\n"
+ "fmla v16.4s, v7.4s, v0.4s\n"
+ "beq 2f\n"
+ // Vector 9
+ "subs %w[vecs], %w[vecs], #1\n"
+ "ldr q7,[%[a_ptr], #0x90]\n"
+ "fmla v17.4s, v7.4s, v0.4s\n"
+ "beq 2f\n"
+ // Vector 10
+ "subs %w[vecs], %w[vecs], #1\n"
+ "ldr q7,[%[a_ptr], #0xa0]\n"
+ "fmla v18.4s, v7.4s, v0.4s\n"
+ "beq 2f\n"
+ // Vector 11
+ "subs %w[vecs], %w[vecs], #1\n"
+ "ldr q7,[%[a_ptr], #0xb0]\n"
+ "fmla v19.4s, v7.4s, v0.4s\n"
+ // Prefetch
+ "cbz %w[dopf], 5f\n" ASM_PREFETCH("[%[pf_ptr], #0xc0]")
+ "5:\n"
+ "beq 2f\n"
+
+ // Vector 12
+ "subs %w[vecs], %w[vecs], #1\n"
+ "ldr q7,[%[a_ptr], #0xc0]\n"
+ "fmla v20.4s, v7.4s, v0.4s\n"
+ "beq 2f\n"
+ // Vector 13
+ "subs %w[vecs], %w[vecs], #1\n"
+ "ldr q7,[%[a_ptr], #0xd0]\n"
+ "fmla v21.4s, v7.4s, v0.4s\n"
+ "beq 2f\n"
+ // Vector 14
+ "subs %w[vecs], %w[vecs], #1\n"
+ "ldr q7,[%[a_ptr], #0xe0]\n"
+ "fmla v22.4s, v7.4s, v0.4s\n"
+ "beq 2f\n"
+ // Vector 15
+ "subs %w[vecs], %w[vecs], #1\n"
+ "ldr q7,[%[a_ptr], #0xf0]\n"
+ "fmla v23.4s, v7.4s, v0.4s\n"
+ // Prefetch
+ "cbz %w[dopf], 6f\n" ASM_PREFETCH("[%[pf_ptr], #0x100]")
+ "6:\n"
+ "beq 2f\n"
+
+ // Vector 16
+ "subs %w[vecs], %w[vecs], #1\n"
+ "ldr q7,[%[a_ptr], #0x100]\n"
+ "fmla v24.4s, v7.4s, v0.4s\n"
+ "beq 2f\n"
+ // Vector 17
+ "subs %w[vecs], %w[vecs], #1\n"
+ "ldr q7,[%[a_ptr], #0x110]\n"
+ "fmla v25.4s, v7.4s, v0.4s\n"
+ "beq 2f\n"
+ // Vector 18
+ "subs %w[vecs], %w[vecs], #1\n"
+ "ldr q7,[%[a_ptr], #0x120]\n"
+ "fmla v26.4s, v7.4s, v0.4s\n"
+ "beq 2f\n"
+ // Vector 19
+ "subs %w[vecs], %w[vecs], #1\n"
+ "ldr q7,[%[a_ptr], #0x130]\n"
+ "fmla v27.4s, v7.4s, v0.4s\n"
+ // Prefetch
+ "cbz %w[dopf], 7f\n" ASM_PREFETCH("[%[pf_ptr], #0x140]")
+ "7:\n"
+ "beq 2f\n"
+
+ // Vector 20
+ "subs %w[vecs], %w[vecs], #1\n"
+ "ldr q7,[%[a_ptr], #0x140]\n"
+ "fmla v28.4s, v7.4s, v0.4s\n"
+ "beq 2f\n"
+ // Vector 21
+ "subs %w[vecs], %w[vecs], #1\n"
+ "ldr q7,[%[a_ptr], #0x150]\n"
+ "fmla v29.4s, v7.4s, v0.4s\n"
+ "beq 2f\n"
+ // Vector 22
+ "subs %w[vecs], %w[vecs], #1\n"
+ "ldr q7,[%[a_ptr], #0x160]\n"
+ "fmla v30.4s, v7.4s, v0.4s\n"
+
+ "2:\n"
+ "add %[a_ptr], %[a_ptr], %[jump]\n"
+
+ // Do the odd 2-vector, if needed
+ "cbz %[odd2_aptr], 8f\n"
+ "ldr d7, [%[odd2_aptr]]\n"
+ "fmla v6.2s, v7.2s, v0.2s\n"
+ "add %[odd2_aptr], %[odd2_aptr], %[jump]\n"
+
+ "8:\n"
+ // Do the odd 1-vector, if needed
+ "cbz %[odd1_aptr], 9f\n"
+ "ldr s7, [%[odd1_aptr]]\n"
+ "fmla v5.2s, v7.2s, v0.2s\n"
+ "add %[odd1_aptr], %[odd1_aptr], %[jump]\n"
+
+ // Get out if needed.
+ "9:\n"
+ "subs %w[k], %w[k], #1\n"
+ "beq 10f\n"
+
+ // Update the "main" prefetch pointer, if it strays beyond the limit turn off "dopf"
+ "add %[pf_ptr], %[pf_ptr], %[jump]\n"
+ "cmp %[pf_ptr], %[pf_limit]\n"
+ "csel %w[dopf], %w[dopf], WZR, LT\n"
+
+ // Update the "leading" prefetch pointer, don't do the first
+ // instruction of the loop if it's over the limit.
+ "add %[firstpf_ptr], %[firstpf_ptr], %[jump]\n"
+ "cmp %[firstpf_ptr], %[pf_limit]\n"
+ "blt 1b\n"
+ "b 11b\n"
+
+ // Now write out the outputs
+ "10:\n"
+ "cbz %w[numvecs], 12f\n"
+ "mov %w[vecs], %w[numvecs]\n"
+
+ // Vector 0
+ "subs %w[vecs], %w[vecs], #1\n"
+ "ldr q7, [%[y_ptr]]\n"
+ "fmla v7.4s, v8.4s, %[va].4s\n"
+ "str q7, [%[y_ptr]], #0x10\n"
+ "beq 12f\n"
+ // Vector 1
+ "subs %w[vecs], %w[vecs], #1\n"
+ "ldr q7, [%[y_ptr]]\n"
+ "fmla v7.4s, v9.4s, %[va].4s\n"
+ "str q7, [%[y_ptr]], #0x10\n"
+ "beq 12f\n"
+ // Vector 2
+ "subs %w[vecs], %w[vecs], #1\n"
+ "ldr q7, [%[y_ptr]]\n"
+ "fmla v7.4s, v10.4s, %[va].4s\n"
+ "str q7, [%[y_ptr]], #0x10\n"
+ "beq 12f\n"
+ // Vector 3
+ "subs %w[vecs], %w[vecs], #1\n"
+ "ldr q7, [%[y_ptr]]\n"
+ "fmla v7.4s, v11.4s, %[va].4s\n"
+ "str q7, [%[y_ptr]], #0x10\n"
+ "beq 12f\n"
+ // Vector 4
+ "subs %w[vecs], %w[vecs], #1\n"
+ "ldr q7, [%[y_ptr]]\n"
+ "fmla v7.4s, v12.4s, %[va].4s\n"
+ "str q7, [%[y_ptr]], #0x10\n"
+ "beq 12f\n"
+ // Vector 5
+ "subs %w[vecs], %w[vecs], #1\n"
+ "ldr q7, [%[y_ptr]]\n"
+ "fmla v7.4s, v13.4s, %[va].4s\n"
+ "str q7, [%[y_ptr]], #0x10\n"
+ "beq 12f\n"
+ // Vector 6
+ "subs %w[vecs], %w[vecs], #1\n"
+ "ldr q7, [%[y_ptr]]\n"
+ "fmla v7.4s, v14.4s, %[va].4s\n"
+ "str q7, [%[y_ptr]], #0x10\n"
+ "beq 12f\n"
+ // Vector 7
+ "subs %w[vecs], %w[vecs], #1\n"
+ "ldr q7, [%[y_ptr]]\n"
+ "fmla v7.4s, v15.4s, %[va].4s\n"
+ "str q7, [%[y_ptr]], #0x10\n"
+ "beq 12f\n"
+ // Vector 8
+ "subs %w[vecs], %w[vecs], #1\n"
+ "ldr q7, [%[y_ptr]]\n"
+ "fmla v7.4s, v16.4s, %[va].4s\n"
+ "str q7, [%[y_ptr]], #0x10\n"
+ "beq 12f\n"
+ // Vector 9
+ "subs %w[vecs], %w[vecs], #1\n"
+ "ldr q7, [%[y_ptr]]\n"
+ "fmla v7.4s, v17.4s, %[va].4s\n"
+ "str q7, [%[y_ptr]], #0x10\n"
+ "beq 12f\n"
+ // Vector 10
+ "subs %w[vecs], %w[vecs], #1\n"
+ "ldr q7, [%[y_ptr]]\n"
+ "fmla v7.4s, v18.4s, %[va].4s\n"
+ "str q7, [%[y_ptr]], #0x10\n"
+ "beq 12f\n"
+ // Vector 11
+ "subs %w[vecs], %w[vecs], #1\n"
+ "ldr q7, [%[y_ptr]]\n"
+ "fmla v7.4s, v19.4s, %[va].4s\n"
+ "str q7, [%[y_ptr]], #0x10\n"
+ "beq 12f\n"
+ // Vector 12
+ "subs %w[vecs], %w[vecs], #1\n"
+ "ldr q7, [%[y_ptr]]\n"
+ "fmla v7.4s, v20.4s, %[va].4s\n"
+ "str q7, [%[y_ptr]], #0x10\n"
+ "beq 12f\n"
+ // Vector 13
+ "subs %w[vecs], %w[vecs], #1\n"
+ "ldr q7, [%[y_ptr]]\n"
+ "fmla v7.4s, v21.4s, %[va].4s\n"
+ "str q7, [%[y_ptr]], #0x10\n"
+ "beq 12f\n"
+ // Vector 14
+ "subs %w[vecs], %w[vecs], #1\n"
+ "ldr q7, [%[y_ptr]]\n"
+ "fmla v7.4s, v22.4s, %[va].4s\n"
+ "str q7, [%[y_ptr]], #0x10\n"
+ "beq 12f\n"
+ // Vector 15
+ "subs %w[vecs], %w[vecs], #1\n"
+ "ldr q7, [%[y_ptr]]\n"
+ "fmla v7.4s, v23.4s, %[va].4s\n"
+ "str q7, [%[y_ptr]], #0x10\n"
+ "beq 12f\n"
+ // Vector 16
+ "subs %w[vecs], %w[vecs], #1\n"
+ "ldr q7, [%[y_ptr]]\n"
+ "fmla v7.4s, v24.4s, %[va].4s\n"
+ "str q7, [%[y_ptr]], #0x10\n"
+ "beq 12f\n"
+ // Vector 17
+ "subs %w[vecs], %w[vecs], #1\n"
+ "ldr q7, [%[y_ptr]]\n"
+ "fmla v7.4s, v25.4s, %[va].4s\n"
+ "str q7, [%[y_ptr]], #0x10\n"
+ "beq 12f\n"
+ // Vector 18
+ "subs %w[vecs], %w[vecs], #1\n"
+ "ldr q7, [%[y_ptr]]\n"
+ "fmla v7.4s, v26.4s, %[va].4s\n"
+ "str q7, [%[y_ptr]], #0x10\n"
+ "beq 12f\n"
+ // Vector 19
+ "subs %w[vecs], %w[vecs], #1\n"
+ "ldr q7, [%[y_ptr]]\n"
+ "fmla v7.4s, v27.4s, %[va].4s\n"
+ "str q7, [%[y_ptr]], #0x10\n"
+ "beq 12f\n"
+ // Vector 20
+ "subs %w[vecs], %w[vecs], #1\n"
+ "ldr q7, [%[y_ptr]]\n"
+ "fmla v7.4s, v28.4s, %[va].4s\n"
+ "str q7, [%[y_ptr]], #0x10\n"
+ "beq 12f\n"
+ // Vector 21
+ "subs %w[vecs], %w[vecs], #1\n"
+ "ldr q7, [%[y_ptr]]\n"
+ "fmla v7.4s, v29.4s, %[va].4s\n"
+ "str q7, [%[y_ptr]], #0x10\n"
+ "beq 12f\n"
+ // Vector 22
+ "subs %w[vecs], %w[vecs], #1\n"
+ "ldr q7, [%[y_ptr]]\n"
+ "fmla v7.4s, v30.4s, %[va].4s\n"
+ "str q7, [%[y_ptr]], #0x10\n"
+
+ // Odd 2
+ "12:\n"
+ "cbz %[odd2_aptr], 13f\n"
+ "ldr d7, [%[y_ptr]]\n"
+ "fmla v7.2s, v6.2s, %[va].2s\n"
+ "str d7, [%[y_ptr]], #0x8\n"
+
+ // Odd 1
+ "13:\n"
+ "cbz %[odd1_aptr], 14f\n"
+ "ldr s7, [%[y_ptr]]\n"
+ "fmla v7.2s, v5.2s, %[va].2s\n"
+ "str s7, [%[y_ptr]]\n"
+
+ "14:\n"
+ : [a_ptr] "+r"(a_ptr), [x_ptr] "+r"(x_ptr), [y_ptr] "+r"(y_ptr), [k] "+r"(k),
+ [pf_ptr] "+r"(pf_ptr), [firstpf_ptr] "+r"(firstpf_ptr),
+ [odd1_aptr] "+r"(odd1_aptr), [odd2_aptr] "+r"(odd2_aptr),
+ [dopf] "+r"(dopf), [vecs] "+r"(vecs)
+ : [jump] "r"(jump), [va] "w"(va), [pf_limit] "r"(pf_limit), [numvecs] "r"(numvecs)
+ : "w0", "v0", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13",
+ "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26",
+ "v27", "v28", "v29", "v30", "v31", "cc");
+ }
+}
+
+} // namespace arm_gemm
+
+#endif // __aarch64__