/* * Copyright (c) 2017-2018 ARM Limited. * * SPDX-License-Identifier: MIT * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to * deal in the Software without restriction, including without limitation the * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or * sell copies of the Software, and to permit persons to whom the Software is * furnished to do so, subject to the following conditions: * * The above copyright notice and this permission notice shall be included in all * copies or substantial portions of the Software. * * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE * SOFTWARE. */ #ifdef __aarch64__ #include #include #include "../../asmlib.hpp" #include "../../utils.hpp" // Kernel implementation - transposed GEMV // // The kernel will process "M" rows of A (= steps of dot product) and "N" // columns (= dot products total) // // General plan is to do as many columns simultaneously as possible - a // reasonable limit is half the NEON regfile = 64 total accumulators. // // It's possible that messing around with sub-blocking M and N can yield // higher performance, but that's left to the outer loop. In this kernel we // process all of M at the same time. // How far ahead to prefetch for the first and subsequent prefetches. // These values work for A72 on JunoR2... #define FIRST_PFD 9 #define PFD 6 namespace arm_gemm { void a64_sgemv_trans(const float *Astart, const float *Xstart, float *Ystart, float beta, int lda, int M, int N) { const float *a_ptr_base = Astart; float *y_ptr = Ystart; register const float32x4_t vb asm("v1") = vdupq_n_f32(beta); int firstpfd = FIRST_PFD; if(firstpfd > M) { firstpfd = (M - 1); } int pfd = PFD; if(pfd > M) { pfd = (M - 1); } ptrdiff_t jump = lda * sizeof(int); for(; N >= 96; N -= 96) { int k = M - 1; const float *a_ptr = a_ptr_base; const float *x_ptr = Xstart; const float *pf_ptr = a_ptr; const float *firstpf_ptr = a_ptr; const float *pf_limit = a_ptr + (M * lda); for(int i = 0; i < firstpfd; i++) { prefetch_1x(firstpf_ptr); firstpf_ptr += lda; } for(int i = 0; i < pfd; i++) { prefetch_5x(pf_ptr + 16); pf_ptr += lda; } a_ptr_base += 96; __asm __volatile( "movi v8.4s,#0x0\n" "ldr w0, [%[x_ptr]]\n" "movi v9.4s,#0x0\n" "ldr q2, [%[a_ptr], #0]\n" "movi v10.4s,#0x0\n" "ldr q3, [%[a_ptr], #0x10]\n" "movi v11.4s,#0x0\n" "ldr q4, [%[a_ptr], #0x20]\n" "movi v12.4s,#0x0\n" "ldr q5, [%[a_ptr], #0x30]\n" "movi v13.4s,#0x0\n" "ldr q6, [%[a_ptr], #0x40]\n" "movi v14.4s,#0x0\n" "ldr q7, [%[a_ptr], #0x50]\n" "movi v15.4s,#0x0\n" ASM_PREFETCH("[%[firstpf_ptr]]") "movi v16.4s, #0x0\n" "movi v17.4s, #0x0\n" ASM_PREFETCH("[%[pf_ptr], #64]") "movi v18.4s, #0x0\n" "movi v19.4s, #0x0\n" ASM_PREFETCH("[%[pf_ptr], #128]") "movi v20.4s, #0x0\n" "movi v21.4s, #0x0\n" ASM_PREFETCH("[%[pf_ptr], #192]") "movi v22.4s, #0x0\n" "movi v23.4s, #0x0\n" ASM_PREFETCH("[%[pf_ptr], #256]") "movi v24.4s, #0x0\n" "movi v25.4s, #0x0\n" ASM_PREFETCH("[%[pf_ptr], #320]") "movi v26.4s, #0x0\n" "movi v27.4s, #0x0\n" "add %[pf_ptr], %[pf_ptr], %[jump]\n" "movi v28.4s, #0x0\n" "add %[firstpf_ptr], %[firstpf_ptr], %[jump]\n" "movi v29.4s, #0x0\n" "movi v30.4s, #0x0\n" "movi v31.4s, #0x0\n" // Skip everything if there are no iterations of the main loop to do. "cbz %w[k], 10f\n" // Loop with all prefetches. Exit this loop when firstpf_ptr // hits pf_limit. "1:\n" "dup v0.4s, w0\n" "ldr w0, [%[x_ptr], #4]\n" "add %[x_ptr], %[x_ptr], #0x4\n" "fmla v8.4s, v2.4s, v0.4s\n" "ldr q2, [%[a_ptr], #0x60]\n" "fmla v9.4s, v3.4s, v0.4s\n" "ldr q3, [%[a_ptr], #0x70]\n" ASM_PREFETCH("[%[firstpf_ptr]]") "fmla v10.4s, v4.4s, v0.4s\n" "ldr q4, [%[a_ptr], #0x80]\n" "add %[firstpf_ptr], %[firstpf_ptr], %[jump]\n" "fmla v11.4s, v5.4s, v0.4s\n" "ldr q5, [%[a_ptr], #0x90]\n" "sub %w[k], %w[k], #1\n" ASM_PREFETCH("[%[x_ptr], #128]") "fmla v12.4s, v6.4s, v0.4s\n" "ldr q6, [%[a_ptr], #0xa0]\n" "fmla v13.4s, v7.4s, v0.4s\n" "ldr q7, [%[a_ptr], #0xb0]\n" ASM_PREFETCH("[%[pf_ptr], #0x40]") "fmla v14.4s, v2.4s, v0.4s\n" "ldr q2, [%[a_ptr], #0xc0]\n" "fmla v15.4s, v3.4s, v0.4s\n" "ldr q3, [%[a_ptr], #0xd0]\n" "fmla v16.4s, v4.4s, v0.4s\n" "ldr q4, [%[a_ptr], #0xe0]\n" "fmla v17.4s, v5.4s, v0.4s\n" "ldr q5, [%[a_ptr], #0xf0]\n" ASM_PREFETCH("[%[pf_ptr], #0x80]") "fmla v18.4s, v6.4s, v0.4s\n" "ldr q6, [%[a_ptr], #0x100]\n" "fmla v19.4s, v7.4s, v0.4s\n" "ldr q7, [%[a_ptr], #0x110]\n" "fmla v20.4s, v2.4s, v0.4s\n" "ldr q2, [%[a_ptr], #0x120]\n" "fmla v21.4s, v3.4s, v0.4s\n" "ldr q3, [%[a_ptr], #0x130]\n" ASM_PREFETCH("[%[pf_ptr], #0xc0]") "fmla v22.4s, v4.4s, v0.4s\n" "ldr q4, [%[a_ptr], #0x140]\n" "fmla v23.4s, v5.4s, v0.4s\n" "ldr q5, [%[a_ptr], #0x150]\n" "fmla v24.4s, v6.4s, v0.4s\n" "ldr q6, [%[a_ptr], #0x160]\n" "fmla v25.4s, v7.4s, v0.4s\n" "ldr q7, [%[a_ptr], #0x170]\n" ASM_PREFETCH("[%[pf_ptr], #0x100]") "add %[a_ptr], %[a_ptr], %[jump]\n" "fmla v26.4s, v2.4s, v0.4s\n" "ldr q2, [%[a_ptr], #0x00]\n" "fmla v27.4s, v3.4s, v0.4s\n" "ldr q3, [%[a_ptr], #0x10]\n" "fmla v28.4s, v4.4s, v0.4s\n" "ldr q4, [%[a_ptr], #0x20]\n" "fmla v29.4s, v5.4s, v0.4s\n" "ldr q5, [%[a_ptr], #0x30]\n" ASM_PREFETCH("[%[pf_ptr], #0x140]") "fmla v30.4s, v6.4s, v0.4s\n" "add %[pf_ptr], %[pf_ptr], %[jump]\n" "ldr q6, [%[a_ptr], #0x40]\n" "fmla v31.4s, v7.4s, v0.4s\n" "cmp %[firstpf_ptr], %[pf_limit]\n" "ldr q7, [%[a_ptr], #0x50]\n" "blt 1b\n" // Check that there are still "main" prefetches to do. "cmp %[pf_ptr], %[pf_limit]\n" "bge 9f\n" // Just the main prefetches, exit this loop when pf_ptr hits pf_limit. "8:\n" "dup v0.4s, w0\n" "ldr w0, [%[x_ptr], #4]\n" "add %[x_ptr], %[x_ptr], #0x4\n" "fmla v8.4s, v2.4s, v0.4s\n" "ldr q2, [%[a_ptr], #0x60]\n" "fmla v9.4s, v3.4s, v0.4s\n" "ldr q3, [%[a_ptr], #0x70]\n" "fmla v10.4s, v4.4s, v0.4s\n" "ldr q4, [%[a_ptr], #0x80]\n" "fmla v11.4s, v5.4s, v0.4s\n" "ldr q5, [%[a_ptr], #0x90]\n" "sub %w[k], %w[k], #1\n" ASM_PREFETCH("[%[x_ptr], #128]") "fmla v12.4s, v6.4s, v0.4s\n" "ldr q6, [%[a_ptr], #0xa0]\n" "fmla v13.4s, v7.4s, v0.4s\n" "ldr q7, [%[a_ptr], #0xb0]\n" ASM_PREFETCH("[%[pf_ptr], #0x40]") "fmla v14.4s, v2.4s, v0.4s\n" "ldr q2, [%[a_ptr], #0xc0]\n" "fmla v15.4s, v3.4s, v0.4s\n" "ldr q3, [%[a_ptr], #0xd0]\n" "fmla v16.4s, v4.4s, v0.4s\n" "ldr q4, [%[a_ptr], #0xe0]\n" "fmla v17.4s, v5.4s, v0.4s\n" "ldr q5, [%[a_ptr], #0xf0]\n" ASM_PREFETCH("[%[pf_ptr], #0x80]") "fmla v18.4s, v6.4s, v0.4s\n" "ldr q6, [%[a_ptr], #0x100]\n" "fmla v19.4s, v7.4s, v0.4s\n" "ldr q7, [%[a_ptr], #0x110]\n" "fmla v20.4s, v2.4s, v0.4s\n" "ldr q2, [%[a_ptr], #0x120]\n" "fmla v21.4s, v3.4s, v0.4s\n" "ldr q3, [%[a_ptr], #0x130]\n" ASM_PREFETCH("[%[pf_ptr], #0xc0]") "fmla v22.4s, v4.4s, v0.4s\n" "ldr q4, [%[a_ptr], #0x140]\n" "fmla v23.4s, v5.4s, v0.4s\n" "ldr q5, [%[a_ptr], #0x150]\n" "fmla v24.4s, v6.4s, v0.4s\n" "ldr q6, [%[a_ptr], #0x160]\n" "fmla v25.4s, v7.4s, v0.4s\n" "ldr q7, [%[a_ptr], #0x170]\n" ASM_PREFETCH("[%[pf_ptr], #0x100]") "add %[a_ptr], %[a_ptr], %[jump]\n" "fmla v26.4s, v2.4s, v0.4s\n" "ldr q2, [%[a_ptr], #0x00]\n" "fmla v27.4s, v3.4s, v0.4s\n" "ldr q3, [%[a_ptr], #0x10]\n" "fmla v28.4s, v4.4s, v0.4s\n" "ldr q4, [%[a_ptr], #0x20]\n" "fmla v29.4s, v5.4s, v0.4s\n" "ldr q5, [%[a_ptr], #0x30]\n" ASM_PREFETCH("[%[pf_ptr], #0x140]") "fmla v30.4s, v6.4s, v0.4s\n" "add %[pf_ptr], %[pf_ptr], %[jump]\n" "ldr q6, [%[a_ptr], #0x40]\n" "fmla v31.4s, v7.4s, v0.4s\n" "cmp %[pf_ptr], %[pf_limit]\n" "ldr q7, [%[a_ptr], #0x50]\n" "blt 8b\n" // Check that there is still work to do. "9:\n" "cmp %w[k], #0\n" "beq 10f\n" // Loop without prefetches, exit when k hits 0. "2:\n" "dup v0.4s, w0\n" "ldr w0, [%[x_ptr], #4]\n" "add %[x_ptr], %[x_ptr], #0x4\n" "fmla v8.4s, v2.4s, v0.4s\n" "ldr q2, [%[a_ptr], #0x60]\n" "fmla v9.4s, v3.4s, v0.4s\n" "ldr q3, [%[a_ptr], #0x70]\n" "fmla v10.4s, v4.4s, v0.4s\n" "ldr q4, [%[a_ptr], #0x80]\n" "fmla v11.4s, v5.4s, v0.4s\n" "ldr q5, [%[a_ptr], #0x90]\n" "subs %w[k], %w[k], #1\n" "fmla v12.4s, v6.4s, v0.4s\n" "ldr q6, [%[a_ptr], #0xa0]\n" "fmla v13.4s, v7.4s, v0.4s\n" "ldr q7, [%[a_ptr], #0xb0]\n" "fmla v14.4s, v2.4s, v0.4s\n" "ldr q2, [%[a_ptr], #0xc0]\n" "fmla v15.4s, v3.4s, v0.4s\n" "ldr q3, [%[a_ptr], #0xd0]\n" "fmla v16.4s, v4.4s, v0.4s\n" "ldr q4, [%[a_ptr], #0xe0]\n" "fmla v17.4s, v5.4s, v0.4s\n" "ldr q5, [%[a_ptr], #0xf0]\n" "fmla v18.4s, v6.4s, v0.4s\n" "ldr q6, [%[a_ptr], #0x100]\n" "fmla v19.4s, v7.4s, v0.4s\n" "ldr q7, [%[a_ptr], #0x110]\n" "fmla v20.4s, v2.4s, v0.4s\n" "ldr q2, [%[a_ptr], #0x120]\n" "fmla v21.4s, v3.4s, v0.4s\n" "ldr q3, [%[a_ptr], #0x130]\n" "fmla v22.4s, v4.4s, v0.4s\n" "ldr q4, [%[a_ptr], #0x140]\n" "fmla v23.4s, v5.4s, v0.4s\n" "ldr q5, [%[a_ptr], #0x150]\n" "fmla v24.4s, v6.4s, v0.4s\n" "ldr q6, [%[a_ptr], #0x160]\n" "fmla v25.4s, v7.4s, v0.4s\n" "ldr q7, [%[a_ptr], #0x170]\n" "add %[a_ptr], %[a_ptr], %[jump]\n" "fmla v26.4s, v2.4s, v0.4s\n" "ldr q2, [%[a_ptr], #0x00]\n" "fmla v27.4s, v3.4s, v0.4s\n" "ldr q3, [%[a_ptr], #0x10]\n" "fmla v28.4s, v4.4s, v0.4s\n" "ldr q4, [%[a_ptr], #0x20]\n" "fmla v29.4s, v5.4s, v0.4s\n" "ldr q5, [%[a_ptr], #0x30]\n" "fmla v30.4s, v6.4s, v0.4s\n" "ldr q6, [%[a_ptr], #0x40]\n" "fmla v31.4s, v7.4s, v0.4s\n" "ldr q7, [%[a_ptr], #0x50]\n" "bne 2b\n" "10:\n" // Final iteration "dup v0.4s, w0\n" "fmla v8.4s, v2.4s, v0.4s\n" "ldr q2, [%[a_ptr], #0x60]\n" "fmla v9.4s, v3.4s, v0.4s\n" "ldr q3, [%[a_ptr], #0x70]\n" "fmla v10.4s, v4.4s, v0.4s\n" "ldr q4, [%[a_ptr], #0x80]\n" "fmla v11.4s, v5.4s, v0.4s\n" "ldr q5, [%[a_ptr], #0x90]\n" "fmla v12.4s, v6.4s, v0.4s\n" "ldr q6, [%[a_ptr], #0xa0]\n" "fmla v13.4s, v7.4s, v0.4s\n" "ldr q7, [%[a_ptr], #0xb0]\n" "fmla v14.4s, v2.4s, v0.4s\n" "ldr q2, [%[a_ptr], #0xc0]\n" "fmla v15.4s, v3.4s, v0.4s\n" "ldr q3, [%[a_ptr], #0xd0]\n" "fmla v16.4s, v4.4s, v0.4s\n" "ldr q4, [%[a_ptr], #0xe0]\n" "fmla v17.4s, v5.4s, v0.4s\n" "ldr q5, [%[a_ptr], #0xf0]\n" "fmla v18.4s, v6.4s, v0.4s\n" "ldr q6, [%[a_ptr], #0x100]\n" "fmla v19.4s, v7.4s, v0.4s\n" "ldr q7, [%[a_ptr], #0x110]\n" "fmla v20.4s, v2.4s, v0.4s\n" "ldr q2, [%[a_ptr], #0x120]\n" "fmla v21.4s, v3.4s, v0.4s\n" "ldr q3, [%[a_ptr], #0x130]\n" "fmla v22.4s, v4.4s, v0.4s\n" "ldr q4, [%[a_ptr], #0x140]\n" "fmla v23.4s, v5.4s, v0.4s\n" "ldr q5, [%[a_ptr], #0x150]\n" "fmla v24.4s, v6.4s, v0.4s\n" "ldr q6, [%[a_ptr], #0x160]\n" "fmla v25.4s, v7.4s, v0.4s\n" "ldr q7, [%[a_ptr], #0x170]\n" "fmla v26.4s, v2.4s, v0.4s\n" "ldr q2, [%[y_ptr]]\n" "fmla v27.4s, v3.4s, v0.4s\n" "ldr q3, [%[y_ptr], #0x10]\n" "fmla v28.4s, v4.4s, v0.4s\n" "ldr q4, [%[y_ptr], #0x20]\n" "fmla v29.4s, v5.4s, v0.4s\n" "ldr q5, [%[y_ptr], #0x30]\n" "fmla v30.4s, v6.4s, v0.4s\n" "ldr q6, [%[y_ptr], #0x40]\n" "fmla v31.4s, v7.4s, v0.4s\n" "ldr q7, [%[y_ptr], #0x50]\n" "fmla v8.4s, v2.4s, %[vb].4s\n" "ldr q2, [%[y_ptr], #0x60]\n" "fmla v9.4s, v3.4s, %[vb].4s\n" "ldr q3, [%[y_ptr], #0x70]\n" "fmla v10.4s, v4.4s, %[vb].4s\n" "ldr q4, [%[y_ptr], #0x80]\n" "fmla v11.4s, v5.4s, %[vb].4s\n" "ldr q5, [%[y_ptr], #0x90]\n" "fmla v12.4s, v6.4s, %[vb].4s\n" "ldr q6, [%[y_ptr], #0xa0]\n" "str q8, [%[y_ptr], #0x00]\n" "fmla v13.4s, v7.4s, %[vb].4s\n" "ldr q7, [%[y_ptr], #0xb0]\n" "str q9, [%[y_ptr], #0x10]\n" "fmla v14.4s, v2.4s, %[vb].4s\n" "ldr q2, [%[y_ptr], #0xc0]\n" "str q10, [%[y_ptr], #0x20]\n" "fmla v15.4s, v3.4s, %[vb].4s\n" "ldr q3, [%[y_ptr], #0xd0]\n" "str q11, [%[y_ptr], #0x30]\n" "fmla v16.4s, v4.4s, %[vb].4s\n" "ldr q4, [%[y_ptr], #0xe0]\n" "str q12, [%[y_ptr], #0x40]\n" "fmla v17.4s, v5.4s, %[vb].4s\n" "ldr q5, [%[y_ptr], #0xf0]\n" "str q13, [%[y_ptr], #0x50]\n" "fmla v18.4s, v6.4s, %[vb].4s\n" "ldr q6, [%[y_ptr], #0x100]\n" "str q14, [%[y_ptr], #0x60]\n" "fmla v19.4s, v7.4s, %[vb].4s\n" "ldr q7, [%[y_ptr], #0x110]\n" "str q15, [%[y_ptr], #0x70]\n" "fmla v20.4s, v2.4s, %[vb].4s\n" "ldr q2, [%[y_ptr], #0x120]\n" "str q16, [%[y_ptr], #0x80]\n" "fmla v21.4s, v3.4s, %[vb].4s\n" "ldr q3, [%[y_ptr], #0x130]\n" "str q17, [%[y_ptr], #0x90]\n" "fmla v22.4s, v4.4s, %[vb].4s\n" "ldr q4, [%[y_ptr], #0x140]\n" "str q18, [%[y_ptr], #0xa0]\n" "fmla v23.4s, v5.4s, %[vb].4s\n" "ldr q5, [%[y_ptr], #0x150]\n" "str q19, [%[y_ptr], #0xb0]\n" "fmla v24.4s, v6.4s, %[vb].4s\n" "ldr q6, [%[y_ptr], #0x160]\n" "str q20, [%[y_ptr], #0xc0]\n" "fmla v25.4s, v7.4s, %[vb].4s\n" "ldr q7, [%[y_ptr], #0x170]\n" "str q21, [%[y_ptr], #0xd0]\n" "fmla v26.4s, v2.4s, %[vb].4s\n" "str q22, [%[y_ptr], #0xe0]\n" "fmla v27.4s, v3.4s, %[vb].4s\n" "str q23, [%[y_ptr], #0xf0]\n" "fmla v28.4s, v4.4s, %[vb].4s\n" "str q24, [%[y_ptr], #0x100]\n" "fmla v29.4s, v5.4s, %[vb].4s\n" "str q25, [%[y_ptr], #0x110]\n" "fmla v30.4s, v6.4s, %[vb].4s\n" "str q26, [%[y_ptr], #0x120]\n" "fmla v31.4s, v7.4s, %[vb].4s\n" "str q27, [%[y_ptr], #0x130]\n" "stp q28, q29, [%[y_ptr], #0x140]\n" "stp q30, q31, [%[y_ptr], #0x160]\n" "add %[y_ptr], %[y_ptr], #0x180\n" : [a_ptr] "+r"(a_ptr), [x_ptr] "+r"(x_ptr), [y_ptr] "+r"(y_ptr), [k] "+r"(k), [pf_ptr] "+r"(pf_ptr), [firstpf_ptr] "+r"(firstpf_ptr) : [jump] "r"(jump), [vb] "w"(vb), [pf_limit] "r"(pf_limit) : "w0", "v0", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "cc"); } if(N > 0) { // Handle N tail - up to 95 stragglers. // This is 0-23 vectors, plus optionally an 64-bit vector and/or a // single value for the remainder. // Independent pointers into the matrix for the odd 2 and odd 1. // Double up as flag to indicate whether they are needed. const float *odd2_aptr = NULL; const float *odd1_aptr = NULL; // Figure out how much work we need to do. int numvecs = N / 4; int rem = N % 4; int k = M; // Set up pointers for the odd 2/1 if needed. if(rem >= 2) { odd2_aptr = a_ptr_base + (numvecs * 4); } if(rem & 1) { odd1_aptr = a_ptr_base + (numvecs * 4) + (odd2_aptr == NULL ? 0 : 2); } const float *a_ptr = a_ptr_base; const float *firstpf_ptr = a_ptr_base; const float *pf_ptr = a_ptr_base; const float *pf_limit = a_ptr + (M * lda); const float *x_ptr = Xstart; int vecs = 0; // Working variable to count how many vectors to work on. int dopf = 1; // Track whether we are doing prefetches. // Figure out how many cache lines we need to prefetch each time. int numpfs = (N + 15) / 16; // Do initial prefetches for(int i = 0; i < firstpfd + 1; i++) { prefetch_1x(firstpf_ptr); firstpf_ptr += lda; } // Do "main" prefetches - adapt number to the number we actually need. if(numpfs > 1) { for(int i = 0; i < pfd + 1; i++) { switch(numpfs) { case 2: prefetch_1x(pf_ptr + 16); break; case 3: prefetch_2x(pf_ptr + 16); break; case 4: prefetch_3x(pf_ptr + 16); break; case 5: prefetch_4x(pf_ptr + 16); break; case 6: prefetch_5x(pf_ptr + 16); break; default: UNREACHABLE("Impossible."); } pf_ptr += lda; } } else { // Just disable additional prefetches dopf = 0; } // Do the real work __asm __volatile( // Initialize all the vectors - not worth skipping this if only // some are needed. "movi v8.4s,#0x0\n" "ldr w0, [%[x_ptr]]\n" "movi v9.4s,#0x0\n" "movi v10.4s,#0x0\n" "movi v11.4s,#0x0\n" "movi v12.4s,#0x0\n" "movi v13.4s,#0x0\n" "movi v14.4s,#0x0\n" "movi v15.4s,#0x0\n" "movi v16.4s, #0x0\n" "movi v17.4s, #0x0\n" "movi v18.4s, #0x0\n" "movi v19.4s, #0x0\n" "movi v20.4s, #0x0\n" "movi v21.4s, #0x0\n" "movi v22.4s, #0x0\n" "movi v23.4s, #0x0\n" "movi v24.4s, #0x0\n" "movi v25.4s, #0x0\n" "movi v26.4s, #0x0\n" "movi v27.4s, #0x0\n" "movi v28.4s, #0x0\n" "movi v29.4s, #0x0\n" "movi v30.4s, #0x0\n" "movi v6.2s, #0x0\n" "movi v5.2s, #0x0\n" "1:\n" ASM_PREFETCH("[%[firstpf_ptr]]\n") "11:\n" "dup v0.4s, w0\n" "ldr w0, [%[x_ptr], #4]\n" "add %[x_ptr], %[x_ptr], #4\n" "cbz %w[numvecs], 2f\n" "mov %w[vecs], %w[numvecs]\n" // Vector 0 "subs %w[vecs], %w[vecs], #1\n" "ldr q7,[%[a_ptr], #0x00]\n" "fmla v8.4s, v7.4s, v0.4s\n" "beq 2f\n" // Vector 1 "subs %w[vecs], %w[vecs], #1\n" "ldr q7,[%[a_ptr], #0x10]\n" "fmla v9.4s, v7.4s, v0.4s\n" "beq 2f\n" // Vector 2 "subs %w[vecs], %w[vecs], #1\n" "ldr q7,[%[a_ptr], #0x20]\n" "fmla v10.4s, v7.4s, v0.4s\n" "beq 2f\n" // Vector 3 "subs %w[vecs], %w[vecs], #1\n" "ldr q7,[%[a_ptr], #0x30]\n" "fmla v11.4s, v7.4s, v0.4s\n" // Prefetch "cbz %w[dopf], 3f\n" ASM_PREFETCH("[%[pf_ptr], #0x40]") "3:\n" "beq 2f\n" // Vector 4 "subs %w[vecs], %w[vecs], #1\n" "ldr q7,[%[a_ptr], #0x40]\n" "fmla v12.4s, v7.4s, v0.4s\n" "beq 2f\n" // Vector 5 "subs %w[vecs], %w[vecs], #1\n" "ldr q7,[%[a_ptr], #0x50]\n" "fmla v13.4s, v7.4s, v0.4s\n" "beq 2f\n" // Vector 6 "subs %w[vecs], %w[vecs], #1\n" "ldr q7,[%[a_ptr], #0x60]\n" "fmla v14.4s, v7.4s, v0.4s\n" "beq 2f\n" // Vector 7 "subs %w[vecs], %w[vecs], #1\n" "ldr q7,[%[a_ptr], #0x70]\n" "fmla v15.4s, v7.4s, v0.4s\n" // Prefetch "cbz %w[dopf], 4f\n" ASM_PREFETCH("[%[pf_ptr], #0x80]") "4:\n" "beq 2f\n" // Vector 8 "subs %w[vecs], %w[vecs], #1\n" "ldr q7,[%[a_ptr], #0x80]\n" "fmla v16.4s, v7.4s, v0.4s\n" "beq 2f\n" // Vector 9 "subs %w[vecs], %w[vecs], #1\n" "ldr q7,[%[a_ptr], #0x90]\n" "fmla v17.4s, v7.4s, v0.4s\n" "beq 2f\n" // Vector 10 "subs %w[vecs], %w[vecs], #1\n" "ldr q7,[%[a_ptr], #0xa0]\n" "fmla v18.4s, v7.4s, v0.4s\n" "beq 2f\n" // Vector 11 "subs %w[vecs], %w[vecs], #1\n" "ldr q7,[%[a_ptr], #0xb0]\n" "fmla v19.4s, v7.4s, v0.4s\n" // Prefetch "cbz %w[dopf], 5f\n" ASM_PREFETCH("[%[pf_ptr], #0xc0]") "5:\n" "beq 2f\n" // Vector 12 "subs %w[vecs], %w[vecs], #1\n" "ldr q7,[%[a_ptr], #0xc0]\n" "fmla v20.4s, v7.4s, v0.4s\n" "beq 2f\n" // Vector 13 "subs %w[vecs], %w[vecs], #1\n" "ldr q7,[%[a_ptr], #0xd0]\n" "fmla v21.4s, v7.4s, v0.4s\n" "beq 2f\n" // Vector 14 "subs %w[vecs], %w[vecs], #1\n" "ldr q7,[%[a_ptr], #0xe0]\n" "fmla v22.4s, v7.4s, v0.4s\n" "beq 2f\n" // Vector 15 "subs %w[vecs], %w[vecs], #1\n" "ldr q7,[%[a_ptr], #0xf0]\n" "fmla v23.4s, v7.4s, v0.4s\n" // Prefetch "cbz %w[dopf], 6f\n" ASM_PREFETCH("[%[pf_ptr], #0x100]") "6:\n" "beq 2f\n" // Vector 16 "subs %w[vecs], %w[vecs], #1\n" "ldr q7,[%[a_ptr], #0x100]\n" "fmla v24.4s, v7.4s, v0.4s\n" "beq 2f\n" // Vector 17 "subs %w[vecs], %w[vecs], #1\n" "ldr q7,[%[a_ptr], #0x110]\n" "fmla v25.4s, v7.4s, v0.4s\n" "beq 2f\n" // Vector 18 "subs %w[vecs], %w[vecs], #1\n" "ldr q7,[%[a_ptr], #0x120]\n" "fmla v26.4s, v7.4s, v0.4s\n" "beq 2f\n" // Vector 19 "subs %w[vecs], %w[vecs], #1\n" "ldr q7,[%[a_ptr], #0x130]\n" "fmla v27.4s, v7.4s, v0.4s\n" // Prefetch "cbz %w[dopf], 7f\n" ASM_PREFETCH("[%[pf_ptr], #0x140]") "7:\n" "beq 2f\n" // Vector 20 "subs %w[vecs], %w[vecs], #1\n" "ldr q7,[%[a_ptr], #0x140]\n" "fmla v28.4s, v7.4s, v0.4s\n" "beq 2f\n" // Vector 21 "subs %w[vecs], %w[vecs], #1\n" "ldr q7,[%[a_ptr], #0x150]\n" "fmla v29.4s, v7.4s, v0.4s\n" "beq 2f\n" // Vector 22 "subs %w[vecs], %w[vecs], #1\n" "ldr q7,[%[a_ptr], #0x160]\n" "fmla v30.4s, v7.4s, v0.4s\n" "2:\n" "add %[a_ptr], %[a_ptr], %[jump]\n" // Do the odd 2-vector, if needed "cbz %[odd2_aptr], 8f\n" "ldr d7, [%[odd2_aptr]]\n" "fmla v6.2s, v7.2s, v0.2s\n" "add %[odd2_aptr], %[odd2_aptr], %[jump]\n" "8:\n" // Do the odd 1-vector, if needed "cbz %[odd1_aptr], 9f\n" "ldr s7, [%[odd1_aptr]]\n" "fmla v5.2s, v7.2s, v0.2s\n" "add %[odd1_aptr], %[odd1_aptr], %[jump]\n" // Get out if needed. "9:\n" "subs %w[k], %w[k], #1\n" "beq 10f\n" // Update the "main" prefetch pointer, if it strays beyond the limit turn off "dopf" "add %[pf_ptr], %[pf_ptr], %[jump]\n" "cmp %[pf_ptr], %[pf_limit]\n" "csel %w[dopf], %w[dopf], WZR, LT\n" // Update the "leading" prefetch pointer, don't do the first // instruction of the loop if it's over the limit. "add %[firstpf_ptr], %[firstpf_ptr], %[jump]\n" "cmp %[firstpf_ptr], %[pf_limit]\n" "blt 1b\n" "b 11b\n" // Now write out the outputs "10:\n" "cbz %w[numvecs], 12f\n" "mov %w[vecs], %w[numvecs]\n" // Vector 0 "subs %w[vecs], %w[vecs], #1\n" "ldr q7, [%[y_ptr]]\n" "fmla v8.4s, v7.4s, %[vb].4s\n" "str q8, [%[y_ptr]], #0x10\n" "beq 12f\n" // Vector 1 "subs %w[vecs], %w[vecs], #1\n" "ldr q7, [%[y_ptr]]\n" "fmla v9.4s, v7.4s, %[vb].4s\n" "str q9, [%[y_ptr]], #0x10\n" "beq 12f\n" // Vector 2 "subs %w[vecs], %w[vecs], #1\n" "ldr q7, [%[y_ptr]]\n" "fmla v10.4s, v7.4s, %[vb].4s\n" "str q10, [%[y_ptr]], #0x10\n" "beq 12f\n" // Vector 3 "subs %w[vecs], %w[vecs], #1\n" "ldr q7, [%[y_ptr]]\n" "fmla v11.4s, v7.4s, %[vb].4s\n" "str q11, [%[y_ptr]], #0x10\n" "beq 12f\n" // Vector 4 "subs %w[vecs], %w[vecs], #1\n" "ldr q7, [%[y_ptr]]\n" "fmla v12.4s, v7.4s, %[vb].4s\n" "str q12, [%[y_ptr]], #0x10\n" "beq 12f\n" // Vector 5 "subs %w[vecs], %w[vecs], #1\n" "ldr q7, [%[y_ptr]]\n" "fmla v13.4s, v7.4s, %[vb].4s\n" "str q13, [%[y_ptr]], #0x10\n" "beq 12f\n" // Vector 6 "subs %w[vecs], %w[vecs], #1\n" "ldr q7, [%[y_ptr]]\n" "fmla v14.4s, v7.4s, %[vb].4s\n" "str q14, [%[y_ptr]], #0x10\n" "beq 12f\n" // Vector 7 "subs %w[vecs], %w[vecs], #1\n" "ldr q7, [%[y_ptr]]\n" "fmla v15.4s, v7.4s, %[vb].4s\n" "str q15, [%[y_ptr]], #0x10\n" "beq 12f\n" // Vector 8 "subs %w[vecs], %w[vecs], #1\n" "ldr q7, [%[y_ptr]]\n" "fmla v16.4s, v7.4s, %[vb].4s\n" "str q16, [%[y_ptr]], #0x10\n" "beq 12f\n" // Vector 9 "subs %w[vecs], %w[vecs], #1\n" "ldr q7, [%[y_ptr]]\n" "fmla v17.4s, v7.4s, %[vb].4s\n" "str q17, [%[y_ptr]], #0x10\n" "beq 12f\n" // Vector 10 "subs %w[vecs], %w[vecs], #1\n" "ldr q7, [%[y_ptr]]\n" "fmla v18.4s, v7.4s, %[vb].4s\n" "str q18, [%[y_ptr]], #0x10\n" "beq 12f\n" // Vector 11 "subs %w[vecs], %w[vecs], #1\n" "ldr q7, [%[y_ptr]]\n" "fmla v19.4s, v7.4s, %[vb].4s\n" "str q19, [%[y_ptr]], #0x10\n" "beq 12f\n" // Vector 12 "subs %w[vecs], %w[vecs], #1\n" "ldr q7, [%[y_ptr]]\n" "fmla v20.4s, v7.4s, %[vb].4s\n" "str q20, [%[y_ptr]], #0x10\n" "beq 12f\n" // Vector 13 "subs %w[vecs], %w[vecs], #1\n" "ldr q7, [%[y_ptr]]\n" "fmla v21.4s, v7.4s, %[vb].4s\n" "str q21, [%[y_ptr]], #0x10\n" "beq 12f\n" // Vector 14 "subs %w[vecs], %w[vecs], #1\n" "ldr q7, [%[y_ptr]]\n" "fmla v22.4s, v7.4s, %[vb].4s\n" "str q22, [%[y_ptr]], #0x10\n" "beq 12f\n" // Vector 15 "subs %w[vecs], %w[vecs], #1\n" "ldr q7, [%[y_ptr]]\n" "fmla v23.4s, v7.4s, %[vb].4s\n" "str q23, [%[y_ptr]], #0x10\n" "beq 12f\n" // Vector 16 "subs %w[vecs], %w[vecs], #1\n" "ldr q7, [%[y_ptr]]\n" "fmla v24.4s, v7.4s, %[vb].4s\n" "str q24, [%[y_ptr]], #0x10\n" "beq 12f\n" // Vector 17 "subs %w[vecs], %w[vecs], #1\n" "ldr q7, [%[y_ptr]]\n" "fmla v25.4s, v7.4s, %[vb].4s\n" "str q25, [%[y_ptr]], #0x10\n" "beq 12f\n" // Vector 18 "subs %w[vecs], %w[vecs], #1\n" "ldr q7, [%[y_ptr]]\n" "fmla v26.4s, v7.4s, %[vb].4s\n" "str q26, [%[y_ptr]], #0x10\n" "beq 12f\n" // Vector 19 "subs %w[vecs], %w[vecs], #1\n" "ldr q7, [%[y_ptr]]\n" "fmla v27.4s, v7.4s, %[vb].4s\n" "str q27, [%[y_ptr]], #0x10\n" "beq 12f\n" // Vector 20 "subs %w[vecs], %w[vecs], #1\n" "ldr q7, [%[y_ptr]]\n" "fmla v28.4s, v7.4s, %[vb].4s\n" "str q28, [%[y_ptr]], #0x10\n" "beq 12f\n" // Vector 21 "subs %w[vecs], %w[vecs], #1\n" "ldr q7, [%[y_ptr]]\n" "fmla v29.4s, v7.4s, %[vb].4s\n" "str q29, [%[y_ptr]], #0x10\n" "beq 12f\n" // Vector 22 "subs %w[vecs], %w[vecs], #1\n" "ldr q7, [%[y_ptr]]\n" "fmla v30.4s, v7.4s, %[vb].4s\n" "str q30, [%[y_ptr]], #0x10\n" // Odd 2 "12:\n" "cbz %[odd2_aptr], 13f\n" "ldr d7, [%[y_ptr]]\n" "fmla v6.2s, v7.2s, %[vb].2s\n" "str d6, [%[y_ptr]], #0x8\n" // Odd 1 "13:\n" "cbz %[odd1_aptr], 14f\n" "ldr s7, [%[y_ptr]]\n" "fmla v5.2s, v7.2s, %[vb].2s\n" "str s5, [%[y_ptr]]\n" "14:\n" : [a_ptr] "+r"(a_ptr), [x_ptr] "+r"(x_ptr), [y_ptr] "+r"(y_ptr), [k] "+r"(k), [pf_ptr] "+r"(pf_ptr), [firstpf_ptr] "+r"(firstpf_ptr), [odd1_aptr] "+r"(odd1_aptr), [odd2_aptr] "+r"(odd2_aptr), [dopf] "+r"(dopf), [vecs] "+r"(vecs) : [jump] "r"(jump), [vb] "w"(vb), [pf_limit] "r"(pf_limit), [numvecs] "r"(numvecs) : "w0", "v0", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "cc"); } } } // namespace arm_gemm #endif // __aarch64__