/*
 * Copyright (c) 2017-2018 ARM Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */
#ifdef __aarch64__

#include <stddef.h>

#include <arm_neon.h>

#include "../../asmlib.hpp"
#include "../../utils.hpp"

// Kernel implementation - transposed GEMV
//
// The kernel will process "M" rows of A (= steps of dot product) and "N"
// columns (= dot products total).
//
// General plan is to do as many columns simultaneously as possible - a
// reasonable limit is half the NEON regfile = 64 total accumulators.
//
// It's possible that messing around with sub-blocking M and N can yield
// higher performance, but that's left to the outer loop.  In this kernel we
// process all of M at the same time.

// How far ahead to prefetch for the first and subsequent prefetches.
// These values work for A72 on JunoR2...
#define FIRST_PFD 9
#define PFD 6

namespace arm_gemm {

void a64_sgemv_trans(const float *Astart, const float *Xstart, float *Ystart, float beta, int lda, int M, int N) {
    const float *a_ptr_base = Astart;
    float *y_ptr = Ystart;
    const bool beta0 = (beta == 0.0f);

    register const float32x4_t vb asm("v1") = vdupq_n_f32(beta);

    int firstpfd=FIRST_PFD;
    if (firstpfd > M) {
        firstpfd = (M-1);
    }

    int pfd = PFD;
    if (pfd > M) {
        pfd = (M-1);
    }

    // Stride (in bytes) from one row of A to the next.
    ptrdiff_t jump = lda * sizeof(int);

    for (;N>=96;N-=96) {
        int k = M-1;

        const float *a_ptr = a_ptr_base;
        const float *x_ptr = Xstart;
        const float *pf_ptr = a_ptr;
        const float *firstpf_ptr = a_ptr;
        const float *pf_limit = a_ptr + (M * lda);

        for (int i=0; i<firstpfd; i++) {
            prefetch_1x(firstpf_ptr);
            firstpf_ptr += lda;
        }

        // ... (NEON inline assembly for this block of 96 columns) ...
    }

    if (N>0) {
        // Handle N tail - up to 95 stragglers.
        // This is 0-23 vectors, plus optionally a 64-bit vector and/or a
        // single value for the remainder.

        // Independent pointers into the matrix for the odd 2 and odd 1.
        // They double up as flags to indicate whether they are needed.
        const float *odd2_aptr=NULL;
        const float *odd1_aptr=NULL;

        // Figure out how much work we need to do.
        int numvecs = N/4;
        int rem = N%4;
        int k=M;

        // Set up pointers for the odd 2/1 if needed.
        if (rem >= 2) {
            odd2_aptr = a_ptr_base + (numvecs * 4);
        }

        if (rem & 1) {
            odd1_aptr = a_ptr_base + (numvecs * 4) + (odd2_aptr==NULL ? 0 : 2);
        }

        const float *a_ptr = a_ptr_base;
        const float *firstpf_ptr = a_ptr_base;
        const float *pf_ptr = a_ptr_base;
        const float *pf_limit = a_ptr + (M * lda);

        const float *x_ptr = Xstart;
        int vecs=0; // Working variable to count how many vectors to work on.
        int dopf=1; // Track whether we are doing prefetches.
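        // Worked example (illustrative, not from the original comments): for a
        // tail of N=95 columns, numvecs = 95/4 = 23 quad-vectors and
        // rem = 95%4 = 3, so odd2_aptr points 92 floats into the block and
        // odd1_aptr 94 floats in, and numpfs below comes out as
        // (95+15)/16 = 6 cache lines per row.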
        // Figure out how many cache lines we need to prefetch each time.
        int numpfs = (N + 15) / 16;

        // Do initial prefetches
        for (int i=0; i<firstpfd+1; i++) {
            prefetch_1x(firstpf_ptr);
            firstpf_ptr += lda;
        }

        // Do "main" prefetches - adapt number to the number we actually need.
        if (numpfs > 1) {
            for (int i=0; i<pfd+1; i++) {
                // ...
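// Illustrative scalar reference of the operation this kernel performs - a
// sketch, not part of the original kernel, and the function name is made up.
// It assumes A is stored row-major as M rows of lda floats, x has M entries
// and y has N entries, matching the comments at the top of this file: each
// y[j] is a dot product down column j of A, accumulated onto beta*y[j] (or
// written directly when beta==0, the beta0 fast path).  The kernel above
// vectorises this, 96 columns per main-loop iteration plus the tail handling.
static void sgemv_trans_reference(const float *A, const float *x, float *y,
                                  float beta, int lda, int M, int N) {
    for (int j=0; j<N; j++) {
        float acc = 0.0f;

        // One dot-product step per row of A ("M rows = steps of dot product").
        for (int i=0; i<M; i++) {
            acc += A[i*lda + j] * x[i];
        }

        // beta==0 means the old contents of y are never read.
        y[j] = (beta == 0.0f) ? acc : (beta * y[j] + acc);
    }
}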