diff options
Diffstat (limited to 'src/core/NEON/kernels/arm_gemm/kernels/a64_sgemv_trans/generic.cpp')
-rw-r--r-- | src/core/NEON/kernels/arm_gemm/kernels/a64_sgemv_trans/generic.cpp | 913 |
1 files changed, 913 insertions, 0 deletions
diff --git a/src/core/NEON/kernels/arm_gemm/kernels/a64_sgemv_trans/generic.cpp b/src/core/NEON/kernels/arm_gemm/kernels/a64_sgemv_trans/generic.cpp new file mode 100644 index 0000000000..3309baff3a --- /dev/null +++ b/src/core/NEON/kernels/arm_gemm/kernels/a64_sgemv_trans/generic.cpp @@ -0,0 +1,913 @@ +/* + * Copyright (c) 2017-2018 ARM Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ +#ifdef __aarch64__ + +#include <cstddef> + +#include <arm_neon.h> + +#include "../../asmlib.hpp" +#include "../../utils.hpp" + +// Kernel implementation - transposed GEMV +// +// The kernel will process "M" rows of A (= steps of dot product) and "N" +// columns (= dot products total) +// +// General plan is to do as many columns simultaneously as possible - a +// reasonable limit is half the NEON regfile = 64 total accumulators. +// +// It's possible that messing around with sub-blocking M and N can yield +// higher performance, but that's left to the outer loop. In this kernel we +// process all of M at the same time. + +// How far ahead to prefetch for the first and subsequent prefetches. +// These values work for A72 on JunoR2... + +#define FIRST_PFD 9 +#define PFD 6 + +namespace arm_gemm +{ +void a64_sgemv_trans(const float *Astart, const float *Xstart, float *Ystart, float alpha, int lda, int M, int N) +{ + const float *a_ptr_base = Astart; + float *y_ptr = Ystart; + + register const float32x4_t va asm("v1") = vdupq_n_f32(alpha); + + int firstpfd = FIRST_PFD; + if(firstpfd > M) + { + firstpfd = (M - 1); + } + + int pfd = PFD; + if(pfd > M) + { + pfd = (M - 1); + } + + ptrdiff_t jump = lda * sizeof(int); + + for(; N >= 96; N -= 96) + { + int k = M - 1; + + const float *a_ptr = a_ptr_base; + const float *x_ptr = Xstart; + const float *pf_ptr = a_ptr; + const float *firstpf_ptr = a_ptr; + const float *pf_limit = a_ptr + (M * lda); + + for(int i = 0; i < firstpfd; i++) + { + prefetch_1x(firstpf_ptr); + firstpf_ptr += lda; + } + + for(int i = 0; i < pfd; i++) + { + prefetch_5x(pf_ptr + 16); + pf_ptr += lda; + } + + a_ptr_base += 96; + + __asm __volatile( + "movi v8.4s,#0x0\n" + "ldr w0, [%[x_ptr]]\n" + "movi v9.4s,#0x0\n" + "ldr q2, [%[a_ptr], #0]\n" + "movi v10.4s,#0x0\n" + "ldr q3, [%[a_ptr], #0x10]\n" + "movi v11.4s,#0x0\n" + "ldr q4, [%[a_ptr], #0x20]\n" + "movi v12.4s,#0x0\n" + "ldr q5, [%[a_ptr], #0x30]\n" + "movi v13.4s,#0x0\n" + "ldr q6, [%[a_ptr], #0x40]\n" + "movi v14.4s,#0x0\n" + "ldr q7, [%[a_ptr], #0x50]\n" + "movi v15.4s,#0x0\n" ASM_PREFETCH("[%[firstpf_ptr]]") + "movi v16.4s, #0x0\n" + "movi v17.4s, #0x0\n" ASM_PREFETCH("[%[pf_ptr], #64]") + "movi v18.4s, #0x0\n" + "movi v19.4s, #0x0\n" ASM_PREFETCH("[%[pf_ptr], #128]") + "movi v20.4s, #0x0\n" + "movi v21.4s, #0x0\n" ASM_PREFETCH("[%[pf_ptr], #192]") + "movi v22.4s, #0x0\n" + "movi v23.4s, #0x0\n" ASM_PREFETCH("[%[pf_ptr], #256]") + "movi v24.4s, #0x0\n" + "movi v25.4s, #0x0\n" ASM_PREFETCH("[%[pf_ptr], #320]") + "movi v26.4s, #0x0\n" + "movi v27.4s, #0x0\n" + "add %[pf_ptr], %[pf_ptr], %[jump]\n" + "movi v28.4s, #0x0\n" + "add %[firstpf_ptr], %[firstpf_ptr], %[jump]\n" + "movi v29.4s, #0x0\n" + "movi v30.4s, #0x0\n" + "movi v31.4s, #0x0\n" + + // Skip everything if there are no iterations of the main loop to do. + "cbz %w[k], 10f\n" + + // Loop with all prefetches. Exit this loop when firstpf_ptr + // hits pf_limit. + "1:\n" + "dup v0.4s, w0\n" + "ldr w0, [%[x_ptr], #4]\n" + "add %[x_ptr], %[x_ptr], #0x4\n" + "fmla v8.4s, v2.4s, v0.4s\n" + "ldr q2, [%[a_ptr], #0x60]\n" + "fmla v9.4s, v3.4s, v0.4s\n" + "ldr q3, [%[a_ptr], #0x70]\n" ASM_PREFETCH("[%[firstpf_ptr]]") + "fmla v10.4s, v4.4s, v0.4s\n" + "ldr q4, [%[a_ptr], #0x80]\n" + "add %[firstpf_ptr], %[firstpf_ptr], %[jump]\n" + "fmla v11.4s, v5.4s, v0.4s\n" + "ldr q5, [%[a_ptr], #0x90]\n" + "sub %w[k], %w[k], #1\n" ASM_PREFETCH("[%[x_ptr], #128]") + "fmla v12.4s, v6.4s, v0.4s\n" + "ldr q6, [%[a_ptr], #0xa0]\n" + "fmla v13.4s, v7.4s, v0.4s\n" + "ldr q7, [%[a_ptr], #0xb0]\n" ASM_PREFETCH("[%[pf_ptr], #0x40]") + "fmla v14.4s, v2.4s, v0.4s\n" + "ldr q2, [%[a_ptr], #0xc0]\n" + "fmla v15.4s, v3.4s, v0.4s\n" + "ldr q3, [%[a_ptr], #0xd0]\n" + "fmla v16.4s, v4.4s, v0.4s\n" + "ldr q4, [%[a_ptr], #0xe0]\n" + "fmla v17.4s, v5.4s, v0.4s\n" + "ldr q5, [%[a_ptr], #0xf0]\n" ASM_PREFETCH("[%[pf_ptr], #0x80]") + "fmla v18.4s, v6.4s, v0.4s\n" + "ldr q6, [%[a_ptr], #0x100]\n" + "fmla v19.4s, v7.4s, v0.4s\n" + "ldr q7, [%[a_ptr], #0x110]\n" + "fmla v20.4s, v2.4s, v0.4s\n" + "ldr q2, [%[a_ptr], #0x120]\n" + "fmla v21.4s, v3.4s, v0.4s\n" + "ldr q3, [%[a_ptr], #0x130]\n" ASM_PREFETCH("[%[pf_ptr], #0xc0]") + "fmla v22.4s, v4.4s, v0.4s\n" + "ldr q4, [%[a_ptr], #0x140]\n" + "fmla v23.4s, v5.4s, v0.4s\n" + "ldr q5, [%[a_ptr], #0x150]\n" + "fmla v24.4s, v6.4s, v0.4s\n" + "ldr q6, [%[a_ptr], #0x160]\n" + "fmla v25.4s, v7.4s, v0.4s\n" + "ldr q7, [%[a_ptr], #0x170]\n" ASM_PREFETCH("[%[pf_ptr], #0x100]") + "add %[a_ptr], %[a_ptr], %[jump]\n" + "fmla v26.4s, v2.4s, v0.4s\n" + "ldr q2, [%[a_ptr], #0x00]\n" + "fmla v27.4s, v3.4s, v0.4s\n" + "ldr q3, [%[a_ptr], #0x10]\n" + "fmla v28.4s, v4.4s, v0.4s\n" + "ldr q4, [%[a_ptr], #0x20]\n" + "fmla v29.4s, v5.4s, v0.4s\n" + "ldr q5, [%[a_ptr], #0x30]\n" ASM_PREFETCH("[%[pf_ptr], #0x140]") + "fmla v30.4s, v6.4s, v0.4s\n" + "add %[pf_ptr], %[pf_ptr], %[jump]\n" + "ldr q6, [%[a_ptr], #0x40]\n" + "fmla v31.4s, v7.4s, v0.4s\n" + "cmp %[firstpf_ptr], %[pf_limit]\n" + "ldr q7, [%[a_ptr], #0x50]\n" + "blt 1b\n" + + // Check that there are still "main" prefetches to do. + "cmp %[pf_ptr], %[pf_limit]\n" + "bge 9f\n" + + // Just the main prefetches, exit this loop when pf_ptr hits pf_limit. + "8:\n" + "dup v0.4s, w0\n" + "ldr w0, [%[x_ptr], #4]\n" + "add %[x_ptr], %[x_ptr], #0x4\n" + "fmla v8.4s, v2.4s, v0.4s\n" + "ldr q2, [%[a_ptr], #0x60]\n" + "fmla v9.4s, v3.4s, v0.4s\n" + "ldr q3, [%[a_ptr], #0x70]\n" + "fmla v10.4s, v4.4s, v0.4s\n" + "ldr q4, [%[a_ptr], #0x80]\n" + "fmla v11.4s, v5.4s, v0.4s\n" + "ldr q5, [%[a_ptr], #0x90]\n" + "sub %w[k], %w[k], #1\n" ASM_PREFETCH("[%[x_ptr], #128]") + "fmla v12.4s, v6.4s, v0.4s\n" + "ldr q6, [%[a_ptr], #0xa0]\n" + "fmla v13.4s, v7.4s, v0.4s\n" + "ldr q7, [%[a_ptr], #0xb0]\n" ASM_PREFETCH("[%[pf_ptr], #0x40]") + "fmla v14.4s, v2.4s, v0.4s\n" + "ldr q2, [%[a_ptr], #0xc0]\n" + "fmla v15.4s, v3.4s, v0.4s\n" + "ldr q3, [%[a_ptr], #0xd0]\n" + "fmla v16.4s, v4.4s, v0.4s\n" + "ldr q4, [%[a_ptr], #0xe0]\n" + "fmla v17.4s, v5.4s, v0.4s\n" + "ldr q5, [%[a_ptr], #0xf0]\n" ASM_PREFETCH("[%[pf_ptr], #0x80]") + "fmla v18.4s, v6.4s, v0.4s\n" + "ldr q6, [%[a_ptr], #0x100]\n" + "fmla v19.4s, v7.4s, v0.4s\n" + "ldr q7, [%[a_ptr], #0x110]\n" + "fmla v20.4s, v2.4s, v0.4s\n" + "ldr q2, [%[a_ptr], #0x120]\n" + "fmla v21.4s, v3.4s, v0.4s\n" + "ldr q3, [%[a_ptr], #0x130]\n" ASM_PREFETCH("[%[pf_ptr], #0xc0]") + "fmla v22.4s, v4.4s, v0.4s\n" + "ldr q4, [%[a_ptr], #0x140]\n" + "fmla v23.4s, v5.4s, v0.4s\n" + "ldr q5, [%[a_ptr], #0x150]\n" + "fmla v24.4s, v6.4s, v0.4s\n" + "ldr q6, [%[a_ptr], #0x160]\n" + "fmla v25.4s, v7.4s, v0.4s\n" + "ldr q7, [%[a_ptr], #0x170]\n" ASM_PREFETCH("[%[pf_ptr], #0x100]") + "add %[a_ptr], %[a_ptr], %[jump]\n" + "fmla v26.4s, v2.4s, v0.4s\n" + "ldr q2, [%[a_ptr], #0x00]\n" + "fmla v27.4s, v3.4s, v0.4s\n" + "ldr q3, [%[a_ptr], #0x10]\n" + "fmla v28.4s, v4.4s, v0.4s\n" + "ldr q4, [%[a_ptr], #0x20]\n" + "fmla v29.4s, v5.4s, v0.4s\n" + "ldr q5, [%[a_ptr], #0x30]\n" ASM_PREFETCH("[%[pf_ptr], #0x140]") + "fmla v30.4s, v6.4s, v0.4s\n" + "add %[pf_ptr], %[pf_ptr], %[jump]\n" + "ldr q6, [%[a_ptr], #0x40]\n" + "fmla v31.4s, v7.4s, v0.4s\n" + "cmp %[pf_ptr], %[pf_limit]\n" + "ldr q7, [%[a_ptr], #0x50]\n" + "blt 8b\n" + + // Check that there is still work to do. + "9:\n" + "cmp %w[k], #0\n" + "beq 10f\n" + + // Loop without prefetches, exit when k hits 0. + "2:\n" + "dup v0.4s, w0\n" + "ldr w0, [%[x_ptr], #4]\n" + "add %[x_ptr], %[x_ptr], #0x4\n" + "fmla v8.4s, v2.4s, v0.4s\n" + "ldr q2, [%[a_ptr], #0x60]\n" + "fmla v9.4s, v3.4s, v0.4s\n" + "ldr q3, [%[a_ptr], #0x70]\n" + "fmla v10.4s, v4.4s, v0.4s\n" + "ldr q4, [%[a_ptr], #0x80]\n" + "fmla v11.4s, v5.4s, v0.4s\n" + "ldr q5, [%[a_ptr], #0x90]\n" + "subs %w[k], %w[k], #1\n" + "fmla v12.4s, v6.4s, v0.4s\n" + "ldr q6, [%[a_ptr], #0xa0]\n" + "fmla v13.4s, v7.4s, v0.4s\n" + "ldr q7, [%[a_ptr], #0xb0]\n" + "fmla v14.4s, v2.4s, v0.4s\n" + "ldr q2, [%[a_ptr], #0xc0]\n" + "fmla v15.4s, v3.4s, v0.4s\n" + "ldr q3, [%[a_ptr], #0xd0]\n" + "fmla v16.4s, v4.4s, v0.4s\n" + "ldr q4, [%[a_ptr], #0xe0]\n" + "fmla v17.4s, v5.4s, v0.4s\n" + "ldr q5, [%[a_ptr], #0xf0]\n" + "fmla v18.4s, v6.4s, v0.4s\n" + "ldr q6, [%[a_ptr], #0x100]\n" + "fmla v19.4s, v7.4s, v0.4s\n" + "ldr q7, [%[a_ptr], #0x110]\n" + "fmla v20.4s, v2.4s, v0.4s\n" + "ldr q2, [%[a_ptr], #0x120]\n" + "fmla v21.4s, v3.4s, v0.4s\n" + "ldr q3, [%[a_ptr], #0x130]\n" + "fmla v22.4s, v4.4s, v0.4s\n" + "ldr q4, [%[a_ptr], #0x140]\n" + "fmla v23.4s, v5.4s, v0.4s\n" + "ldr q5, [%[a_ptr], #0x150]\n" + "fmla v24.4s, v6.4s, v0.4s\n" + "ldr q6, [%[a_ptr], #0x160]\n" + "fmla v25.4s, v7.4s, v0.4s\n" + "ldr q7, [%[a_ptr], #0x170]\n" + "add %[a_ptr], %[a_ptr], %[jump]\n" + "fmla v26.4s, v2.4s, v0.4s\n" + "ldr q2, [%[a_ptr], #0x00]\n" + "fmla v27.4s, v3.4s, v0.4s\n" + "ldr q3, [%[a_ptr], #0x10]\n" + "fmla v28.4s, v4.4s, v0.4s\n" + "ldr q4, [%[a_ptr], #0x20]\n" + "fmla v29.4s, v5.4s, v0.4s\n" + "ldr q5, [%[a_ptr], #0x30]\n" + "fmla v30.4s, v6.4s, v0.4s\n" + "ldr q6, [%[a_ptr], #0x40]\n" + "fmla v31.4s, v7.4s, v0.4s\n" + "ldr q7, [%[a_ptr], #0x50]\n" + "bne 2b\n" + + "10:\n" + + // Final iteration + "dup v0.4s, w0\n" + "fmla v8.4s, v2.4s, v0.4s\n" + "ldr q2, [%[a_ptr], #0x60]\n" + "fmla v9.4s, v3.4s, v0.4s\n" + "ldr q3, [%[a_ptr], #0x70]\n" + "fmla v10.4s, v4.4s, v0.4s\n" + "ldr q4, [%[a_ptr], #0x80]\n" + "fmla v11.4s, v5.4s, v0.4s\n" + "ldr q5, [%[a_ptr], #0x90]\n" + "fmla v12.4s, v6.4s, v0.4s\n" + "ldr q6, [%[a_ptr], #0xa0]\n" + "fmla v13.4s, v7.4s, v0.4s\n" + "ldr q7, [%[a_ptr], #0xb0]\n" + "fmla v14.4s, v2.4s, v0.4s\n" + "ldr q2, [%[a_ptr], #0xc0]\n" + "fmla v15.4s, v3.4s, v0.4s\n" + "ldr q3, [%[a_ptr], #0xd0]\n" + "fmla v16.4s, v4.4s, v0.4s\n" + "ldr q4, [%[a_ptr], #0xe0]\n" + "fmla v17.4s, v5.4s, v0.4s\n" + "ldr q5, [%[a_ptr], #0xf0]\n" + "fmla v18.4s, v6.4s, v0.4s\n" + "ldr q6, [%[a_ptr], #0x100]\n" + "fmla v19.4s, v7.4s, v0.4s\n" + "ldr q7, [%[a_ptr], #0x110]\n" + "fmla v20.4s, v2.4s, v0.4s\n" + "ldr q2, [%[a_ptr], #0x120]\n" + "fmla v21.4s, v3.4s, v0.4s\n" + "ldr q3, [%[a_ptr], #0x130]\n" + "fmla v22.4s, v4.4s, v0.4s\n" + "ldr q4, [%[a_ptr], #0x140]\n" + "fmla v23.4s, v5.4s, v0.4s\n" + "ldr q5, [%[a_ptr], #0x150]\n" + "fmla v24.4s, v6.4s, v0.4s\n" + "ldr q6, [%[a_ptr], #0x160]\n" + "fmla v25.4s, v7.4s, v0.4s\n" + "ldr q7, [%[a_ptr], #0x170]\n" + "fmla v26.4s, v2.4s, v0.4s\n" + "ldr q2, [%[y_ptr]]\n" + "fmla v27.4s, v3.4s, v0.4s\n" + "ldr q3, [%[y_ptr], #0x10]\n" + "fmla v28.4s, v4.4s, v0.4s\n" + "ldr q4, [%[y_ptr], #0x20]\n" + "fmla v29.4s, v5.4s, v0.4s\n" + "ldr q5, [%[y_ptr], #0x30]\n" + "fmla v30.4s, v6.4s, v0.4s\n" + "ldr q6, [%[y_ptr], #0x40]\n" + "fmla v31.4s, v7.4s, v0.4s\n" + "ldr q7, [%[y_ptr], #0x50]\n" + + "fmla v2.4s, v8.4s, %[va].4s\n" + "ldr q8, [%[y_ptr], #0x60]\n" + "fmla v3.4s, v9.4s, %[va].4s\n" + "ldr q9, [%[y_ptr], #0x70]\n" + "fmla v4.4s, v10.4s, %[va].4s\n" + "ldr q10, [%[y_ptr], #0x80]\n" + "fmla v5.4s, v11.4s, %[va].4s\n" + "ldr q11, [%[y_ptr], #0x90]\n" + "fmla v6.4s, v12.4s, %[va].4s\n" + "ldr q12, [%[y_ptr], #0xa0]\n" + "str q2, [%[y_ptr], #0x00]\n" + "fmla v7.4s, v13.4s, %[va].4s\n" + "ldr q13, [%[y_ptr], #0xb0]\n" + "str q3, [%[y_ptr], #0x10]\n" + "fmla v8.4s, v14.4s, %[va].4s\n" + "ldr q14, [%[y_ptr], #0xc0]\n" + "str q4, [%[y_ptr], #0x20]\n" + "fmla v9.4s, v15.4s, %[va].4s\n" + "ldr q15, [%[y_ptr], #0xd0]\n" + "str q5, [%[y_ptr], #0x30]\n" + "fmla v10.4s, v16.4s, %[va].4s\n" + "ldr q16, [%[y_ptr], #0xe0]\n" + "str q6, [%[y_ptr], #0x40]\n" + "fmla v11.4s, v17.4s, %[va].4s\n" + "ldr q17, [%[y_ptr], #0xf0]\n" + "str q7, [%[y_ptr], #0x50]\n" + "fmla v12.4s, v18.4s, %[va].4s\n" + "ldr q18, [%[y_ptr], #0x100]\n" + "str q8, [%[y_ptr], #0x60]\n" + "fmla v13.4s, v19.4s, %[va].4s\n" + "ldr q19, [%[y_ptr], #0x110]\n" + "str q9, [%[y_ptr], #0x70]\n" + "fmla v14.4s, v20.4s, %[va].4s\n" + "ldr q20, [%[y_ptr], #0x120]\n" + "str q10, [%[y_ptr], #0x80]\n" + "fmla v15.4s, v21.4s, %[va].4s\n" + "ldr q21, [%[y_ptr], #0x130]\n" + "str q11, [%[y_ptr], #0x90]\n" + "fmla v16.4s, v22.4s, %[va].4s\n" + "ldr q22, [%[y_ptr], #0x140]\n" + "str q12, [%[y_ptr], #0xa0]\n" + "fmla v17.4s, v23.4s, %[va].4s\n" + "ldr q23, [%[y_ptr], #0x150]\n" + "str q13, [%[y_ptr], #0xb0]\n" + "fmla v18.4s, v24.4s, %[va].4s\n" + "ldr q24, [%[y_ptr], #0x160]\n" + "str q14, [%[y_ptr], #0xc0]\n" + "fmla v19.4s, v25.4s, %[va].4s\n" + "ldr q25, [%[y_ptr], #0x170]\n" + "str q15, [%[y_ptr], #0xd0]\n" + "fmla v20.4s, v26.4s, %[va].4s\n" + "str q16, [%[y_ptr], #0xe0]\n" + "fmla v21.4s, v27.4s, %[va].4s\n" + "str q17, [%[y_ptr], #0xf0]\n" + "fmla v22.4s, v28.4s, %[va].4s\n" + "str q18, [%[y_ptr], #0x100]\n" + "fmla v23.4s, v29.4s, %[va].4s\n" + "str q19, [%[y_ptr], #0x110]\n" + "fmla v24.4s, v30.4s, %[va].4s\n" + "str q20, [%[y_ptr], #0x120]\n" + "fmla v25.4s, v31.4s, %[va].4s\n" + "str q21, [%[y_ptr], #0x130]\n" + + "stp q22, q23, [%[y_ptr], #0x140]\n" + "stp q24, q25, [%[y_ptr], #0x160]\n" + "add %[y_ptr], %[y_ptr], #0x180\n" + + : [a_ptr] "+r"(a_ptr), [x_ptr] "+r"(x_ptr), [y_ptr] "+r"(y_ptr), [k] "+r"(k), [pf_ptr] "+r"(pf_ptr), [firstpf_ptr] "+r"(firstpf_ptr) + : [jump] "r"(jump), [va] "w"(va), [pf_limit] "r"(pf_limit) + : "w0", "v0", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", + "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", + "v27", "v28", "v29", "v30", "v31", "cc"); + } + + if(N > 0) + { + // Handle N tail - up to 95 stragglers. + // This is 0-23 vectors, plus optionally an 64-bit vector and/or a + // single value for the remainder. + + // Independent pointers into the matrix for the odd 2 and odd 1. + // Double up as flag to indicate whether they are needed. + const float *odd2_aptr = NULL; + const float *odd1_aptr = NULL; + + // Figure out how much work we need to do. + int numvecs = N / 4; + int rem = N % 4; + int k = M; + + // Set up pointers for the odd 2/1 if needed. + if(rem >= 2) + { + odd2_aptr = a_ptr_base + (numvecs * 4); + } + + if(rem & 1) + { + odd1_aptr = a_ptr_base + (numvecs * 4) + (odd2_aptr == NULL ? 0 : 2); + } + + const float *a_ptr = a_ptr_base; + const float *firstpf_ptr = a_ptr_base; + const float *pf_ptr = a_ptr_base; + const float *pf_limit = a_ptr + (M * lda); + + const float *x_ptr = Xstart; + int vecs = 0; // Working variable to count how many vectors to work on. + int dopf = 1; // Track whether we are doing prefetches. + + // Figure out how many cache lines we need to prefetch each time. + int numpfs = (N + 15) / 16; + + // Do initial prefetches + for(int i = 0; i < firstpfd + 1; i++) + { + prefetch_1x(firstpf_ptr); + firstpf_ptr += lda; + } + + // Do "main" prefetches - adapt number to the number we actually need. + if(numpfs > 1) + { + for(int i = 0; i < pfd + 1; i++) + { + switch(numpfs) + { + case 2: + prefetch_1x(pf_ptr + 16); + break; + + case 3: + prefetch_2x(pf_ptr + 16); + break; + + case 4: + prefetch_3x(pf_ptr + 16); + break; + + case 5: + prefetch_4x(pf_ptr + 16); + break; + + case 6: + prefetch_5x(pf_ptr + 16); + break; + + default: + UNREACHABLE("Impossible."); + } + pf_ptr += lda; + } + } + else + { + // Just disable additional prefetches + dopf = 0; + } + + // Do the real work + __asm __volatile( + // Initialize all the vectors - not worth skipping this if only + // some are needed. + "movi v8.4s,#0x0\n" + "ldr w0, [%[x_ptr]]\n" + "movi v9.4s,#0x0\n" + "movi v10.4s,#0x0\n" + "movi v11.4s,#0x0\n" + "movi v12.4s,#0x0\n" + "movi v13.4s,#0x0\n" + "movi v14.4s,#0x0\n" + "movi v15.4s,#0x0\n" + "movi v16.4s, #0x0\n" + "movi v17.4s, #0x0\n" + "movi v18.4s, #0x0\n" + "movi v19.4s, #0x0\n" + "movi v20.4s, #0x0\n" + "movi v21.4s, #0x0\n" + "movi v22.4s, #0x0\n" + "movi v23.4s, #0x0\n" + "movi v24.4s, #0x0\n" + "movi v25.4s, #0x0\n" + "movi v26.4s, #0x0\n" + "movi v27.4s, #0x0\n" + "movi v28.4s, #0x0\n" + "movi v29.4s, #0x0\n" + "movi v30.4s, #0x0\n" + "movi v6.2s, #0x0\n" + "movi v5.2s, #0x0\n" + + "1:\n" ASM_PREFETCH("[%[firstpf_ptr]]\n") + "11:\n" + "dup v0.4s, w0\n" + "ldr w0, [%[x_ptr], #4]\n" + "add %[x_ptr], %[x_ptr], #4\n" + + "cbz %w[numvecs], 2f\n" + "mov %w[vecs], %w[numvecs]\n" + + // Vector 0 + "subs %w[vecs], %w[vecs], #1\n" + "ldr q7,[%[a_ptr], #0x00]\n" + "fmla v8.4s, v7.4s, v0.4s\n" + "beq 2f\n" + // Vector 1 + "subs %w[vecs], %w[vecs], #1\n" + "ldr q7,[%[a_ptr], #0x10]\n" + "fmla v9.4s, v7.4s, v0.4s\n" + "beq 2f\n" + // Vector 2 + "subs %w[vecs], %w[vecs], #1\n" + "ldr q7,[%[a_ptr], #0x20]\n" + "fmla v10.4s, v7.4s, v0.4s\n" + "beq 2f\n" + // Vector 3 + "subs %w[vecs], %w[vecs], #1\n" + "ldr q7,[%[a_ptr], #0x30]\n" + "fmla v11.4s, v7.4s, v0.4s\n" + // Prefetch + "cbz %w[dopf], 3f\n" ASM_PREFETCH("[%[pf_ptr], #0x40]") + "3:\n" + "beq 2f\n" + + // Vector 4 + "subs %w[vecs], %w[vecs], #1\n" + "ldr q7,[%[a_ptr], #0x40]\n" + "fmla v12.4s, v7.4s, v0.4s\n" + "beq 2f\n" + // Vector 5 + "subs %w[vecs], %w[vecs], #1\n" + "ldr q7,[%[a_ptr], #0x50]\n" + "fmla v13.4s, v7.4s, v0.4s\n" + "beq 2f\n" + // Vector 6 + "subs %w[vecs], %w[vecs], #1\n" + "ldr q7,[%[a_ptr], #0x60]\n" + "fmla v14.4s, v7.4s, v0.4s\n" + "beq 2f\n" + // Vector 7 + "subs %w[vecs], %w[vecs], #1\n" + "ldr q7,[%[a_ptr], #0x70]\n" + "fmla v15.4s, v7.4s, v0.4s\n" + // Prefetch + "cbz %w[dopf], 4f\n" ASM_PREFETCH("[%[pf_ptr], #0x80]") + "4:\n" + "beq 2f\n" + + // Vector 8 + "subs %w[vecs], %w[vecs], #1\n" + "ldr q7,[%[a_ptr], #0x80]\n" + "fmla v16.4s, v7.4s, v0.4s\n" + "beq 2f\n" + // Vector 9 + "subs %w[vecs], %w[vecs], #1\n" + "ldr q7,[%[a_ptr], #0x90]\n" + "fmla v17.4s, v7.4s, v0.4s\n" + "beq 2f\n" + // Vector 10 + "subs %w[vecs], %w[vecs], #1\n" + "ldr q7,[%[a_ptr], #0xa0]\n" + "fmla v18.4s, v7.4s, v0.4s\n" + "beq 2f\n" + // Vector 11 + "subs %w[vecs], %w[vecs], #1\n" + "ldr q7,[%[a_ptr], #0xb0]\n" + "fmla v19.4s, v7.4s, v0.4s\n" + // Prefetch + "cbz %w[dopf], 5f\n" ASM_PREFETCH("[%[pf_ptr], #0xc0]") + "5:\n" + "beq 2f\n" + + // Vector 12 + "subs %w[vecs], %w[vecs], #1\n" + "ldr q7,[%[a_ptr], #0xc0]\n" + "fmla v20.4s, v7.4s, v0.4s\n" + "beq 2f\n" + // Vector 13 + "subs %w[vecs], %w[vecs], #1\n" + "ldr q7,[%[a_ptr], #0xd0]\n" + "fmla v21.4s, v7.4s, v0.4s\n" + "beq 2f\n" + // Vector 14 + "subs %w[vecs], %w[vecs], #1\n" + "ldr q7,[%[a_ptr], #0xe0]\n" + "fmla v22.4s, v7.4s, v0.4s\n" + "beq 2f\n" + // Vector 15 + "subs %w[vecs], %w[vecs], #1\n" + "ldr q7,[%[a_ptr], #0xf0]\n" + "fmla v23.4s, v7.4s, v0.4s\n" + // Prefetch + "cbz %w[dopf], 6f\n" ASM_PREFETCH("[%[pf_ptr], #0x100]") + "6:\n" + "beq 2f\n" + + // Vector 16 + "subs %w[vecs], %w[vecs], #1\n" + "ldr q7,[%[a_ptr], #0x100]\n" + "fmla v24.4s, v7.4s, v0.4s\n" + "beq 2f\n" + // Vector 17 + "subs %w[vecs], %w[vecs], #1\n" + "ldr q7,[%[a_ptr], #0x110]\n" + "fmla v25.4s, v7.4s, v0.4s\n" + "beq 2f\n" + // Vector 18 + "subs %w[vecs], %w[vecs], #1\n" + "ldr q7,[%[a_ptr], #0x120]\n" + "fmla v26.4s, v7.4s, v0.4s\n" + "beq 2f\n" + // Vector 19 + "subs %w[vecs], %w[vecs], #1\n" + "ldr q7,[%[a_ptr], #0x130]\n" + "fmla v27.4s, v7.4s, v0.4s\n" + // Prefetch + "cbz %w[dopf], 7f\n" ASM_PREFETCH("[%[pf_ptr], #0x140]") + "7:\n" + "beq 2f\n" + + // Vector 20 + "subs %w[vecs], %w[vecs], #1\n" + "ldr q7,[%[a_ptr], #0x140]\n" + "fmla v28.4s, v7.4s, v0.4s\n" + "beq 2f\n" + // Vector 21 + "subs %w[vecs], %w[vecs], #1\n" + "ldr q7,[%[a_ptr], #0x150]\n" + "fmla v29.4s, v7.4s, v0.4s\n" + "beq 2f\n" + // Vector 22 + "subs %w[vecs], %w[vecs], #1\n" + "ldr q7,[%[a_ptr], #0x160]\n" + "fmla v30.4s, v7.4s, v0.4s\n" + + "2:\n" + "add %[a_ptr], %[a_ptr], %[jump]\n" + + // Do the odd 2-vector, if needed + "cbz %[odd2_aptr], 8f\n" + "ldr d7, [%[odd2_aptr]]\n" + "fmla v6.2s, v7.2s, v0.2s\n" + "add %[odd2_aptr], %[odd2_aptr], %[jump]\n" + + "8:\n" + // Do the odd 1-vector, if needed + "cbz %[odd1_aptr], 9f\n" + "ldr s7, [%[odd1_aptr]]\n" + "fmla v5.2s, v7.2s, v0.2s\n" + "add %[odd1_aptr], %[odd1_aptr], %[jump]\n" + + // Get out if needed. + "9:\n" + "subs %w[k], %w[k], #1\n" + "beq 10f\n" + + // Update the "main" prefetch pointer, if it strays beyond the limit turn off "dopf" + "add %[pf_ptr], %[pf_ptr], %[jump]\n" + "cmp %[pf_ptr], %[pf_limit]\n" + "csel %w[dopf], %w[dopf], WZR, LT\n" + + // Update the "leading" prefetch pointer, don't do the first + // instruction of the loop if it's over the limit. + "add %[firstpf_ptr], %[firstpf_ptr], %[jump]\n" + "cmp %[firstpf_ptr], %[pf_limit]\n" + "blt 1b\n" + "b 11b\n" + + // Now write out the outputs + "10:\n" + "cbz %w[numvecs], 12f\n" + "mov %w[vecs], %w[numvecs]\n" + + // Vector 0 + "subs %w[vecs], %w[vecs], #1\n" + "ldr q7, [%[y_ptr]]\n" + "fmla v7.4s, v8.4s, %[va].4s\n" + "str q7, [%[y_ptr]], #0x10\n" + "beq 12f\n" + // Vector 1 + "subs %w[vecs], %w[vecs], #1\n" + "ldr q7, [%[y_ptr]]\n" + "fmla v7.4s, v9.4s, %[va].4s\n" + "str q7, [%[y_ptr]], #0x10\n" + "beq 12f\n" + // Vector 2 + "subs %w[vecs], %w[vecs], #1\n" + "ldr q7, [%[y_ptr]]\n" + "fmla v7.4s, v10.4s, %[va].4s\n" + "str q7, [%[y_ptr]], #0x10\n" + "beq 12f\n" + // Vector 3 + "subs %w[vecs], %w[vecs], #1\n" + "ldr q7, [%[y_ptr]]\n" + "fmla v7.4s, v11.4s, %[va].4s\n" + "str q7, [%[y_ptr]], #0x10\n" + "beq 12f\n" + // Vector 4 + "subs %w[vecs], %w[vecs], #1\n" + "ldr q7, [%[y_ptr]]\n" + "fmla v7.4s, v12.4s, %[va].4s\n" + "str q7, [%[y_ptr]], #0x10\n" + "beq 12f\n" + // Vector 5 + "subs %w[vecs], %w[vecs], #1\n" + "ldr q7, [%[y_ptr]]\n" + "fmla v7.4s, v13.4s, %[va].4s\n" + "str q7, [%[y_ptr]], #0x10\n" + "beq 12f\n" + // Vector 6 + "subs %w[vecs], %w[vecs], #1\n" + "ldr q7, [%[y_ptr]]\n" + "fmla v7.4s, v14.4s, %[va].4s\n" + "str q7, [%[y_ptr]], #0x10\n" + "beq 12f\n" + // Vector 7 + "subs %w[vecs], %w[vecs], #1\n" + "ldr q7, [%[y_ptr]]\n" + "fmla v7.4s, v15.4s, %[va].4s\n" + "str q7, [%[y_ptr]], #0x10\n" + "beq 12f\n" + // Vector 8 + "subs %w[vecs], %w[vecs], #1\n" + "ldr q7, [%[y_ptr]]\n" + "fmla v7.4s, v16.4s, %[va].4s\n" + "str q7, [%[y_ptr]], #0x10\n" + "beq 12f\n" + // Vector 9 + "subs %w[vecs], %w[vecs], #1\n" + "ldr q7, [%[y_ptr]]\n" + "fmla v7.4s, v17.4s, %[va].4s\n" + "str q7, [%[y_ptr]], #0x10\n" + "beq 12f\n" + // Vector 10 + "subs %w[vecs], %w[vecs], #1\n" + "ldr q7, [%[y_ptr]]\n" + "fmla v7.4s, v18.4s, %[va].4s\n" + "str q7, [%[y_ptr]], #0x10\n" + "beq 12f\n" + // Vector 11 + "subs %w[vecs], %w[vecs], #1\n" + "ldr q7, [%[y_ptr]]\n" + "fmla v7.4s, v19.4s, %[va].4s\n" + "str q7, [%[y_ptr]], #0x10\n" + "beq 12f\n" + // Vector 12 + "subs %w[vecs], %w[vecs], #1\n" + "ldr q7, [%[y_ptr]]\n" + "fmla v7.4s, v20.4s, %[va].4s\n" + "str q7, [%[y_ptr]], #0x10\n" + "beq 12f\n" + // Vector 13 + "subs %w[vecs], %w[vecs], #1\n" + "ldr q7, [%[y_ptr]]\n" + "fmla v7.4s, v21.4s, %[va].4s\n" + "str q7, [%[y_ptr]], #0x10\n" + "beq 12f\n" + // Vector 14 + "subs %w[vecs], %w[vecs], #1\n" + "ldr q7, [%[y_ptr]]\n" + "fmla v7.4s, v22.4s, %[va].4s\n" + "str q7, [%[y_ptr]], #0x10\n" + "beq 12f\n" + // Vector 15 + "subs %w[vecs], %w[vecs], #1\n" + "ldr q7, [%[y_ptr]]\n" + "fmla v7.4s, v23.4s, %[va].4s\n" + "str q7, [%[y_ptr]], #0x10\n" + "beq 12f\n" + // Vector 16 + "subs %w[vecs], %w[vecs], #1\n" + "ldr q7, [%[y_ptr]]\n" + "fmla v7.4s, v24.4s, %[va].4s\n" + "str q7, [%[y_ptr]], #0x10\n" + "beq 12f\n" + // Vector 17 + "subs %w[vecs], %w[vecs], #1\n" + "ldr q7, [%[y_ptr]]\n" + "fmla v7.4s, v25.4s, %[va].4s\n" + "str q7, [%[y_ptr]], #0x10\n" + "beq 12f\n" + // Vector 18 + "subs %w[vecs], %w[vecs], #1\n" + "ldr q7, [%[y_ptr]]\n" + "fmla v7.4s, v26.4s, %[va].4s\n" + "str q7, [%[y_ptr]], #0x10\n" + "beq 12f\n" + // Vector 19 + "subs %w[vecs], %w[vecs], #1\n" + "ldr q7, [%[y_ptr]]\n" + "fmla v7.4s, v27.4s, %[va].4s\n" + "str q7, [%[y_ptr]], #0x10\n" + "beq 12f\n" + // Vector 20 + "subs %w[vecs], %w[vecs], #1\n" + "ldr q7, [%[y_ptr]]\n" + "fmla v7.4s, v28.4s, %[va].4s\n" + "str q7, [%[y_ptr]], #0x10\n" + "beq 12f\n" + // Vector 21 + "subs %w[vecs], %w[vecs], #1\n" + "ldr q7, [%[y_ptr]]\n" + "fmla v7.4s, v29.4s, %[va].4s\n" + "str q7, [%[y_ptr]], #0x10\n" + "beq 12f\n" + // Vector 22 + "subs %w[vecs], %w[vecs], #1\n" + "ldr q7, [%[y_ptr]]\n" + "fmla v7.4s, v30.4s, %[va].4s\n" + "str q7, [%[y_ptr]], #0x10\n" + + // Odd 2 + "12:\n" + "cbz %[odd2_aptr], 13f\n" + "ldr d7, [%[y_ptr]]\n" + "fmla v7.2s, v6.2s, %[va].2s\n" + "str d7, [%[y_ptr]], #0x8\n" + + // Odd 1 + "13:\n" + "cbz %[odd1_aptr], 14f\n" + "ldr s7, [%[y_ptr]]\n" + "fmla v7.2s, v5.2s, %[va].2s\n" + "str s7, [%[y_ptr]]\n" + + "14:\n" + : [a_ptr] "+r"(a_ptr), [x_ptr] "+r"(x_ptr), [y_ptr] "+r"(y_ptr), [k] "+r"(k), + [pf_ptr] "+r"(pf_ptr), [firstpf_ptr] "+r"(firstpf_ptr), + [odd1_aptr] "+r"(odd1_aptr), [odd2_aptr] "+r"(odd2_aptr), + [dopf] "+r"(dopf), [vecs] "+r"(vecs) + : [jump] "r"(jump), [va] "w"(va), [pf_limit] "r"(pf_limit), [numvecs] "r"(numvecs) + : "w0", "v0", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", + "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", + "v27", "v28", "v29", "v30", "v31", "cc"); + } +} + +} // namespace arm_gemm + +#endif // __aarch64__ |