diff options
author | Pablo Tello <pablo.tello@arm.com> | 2018-02-23 13:43:50 +0000 |
---|---|---|
committer | Anthony Barbier <anthony.barbier@arm.com> | 2018-11-02 16:49:16 +0000 |
commit | eb82fd2aa786715c3b6a941dc6d6deac4ce8e2a0 (patch) | |
tree | 42cca378eed97c07348f28e1ec708d9c7ed531ce /src/core/NEON/kernels/arm_gemm/kernels | |
parent | 8df6c452820719d201ee79596cde8445c2071db5 (diff) | |
download | ComputeLibrary-eb82fd2aa786715c3b6a941dc6d6deac4ce8e2a0.tar.gz |
COMPMID-881: RSH new arm_gemm interface.
Change-Id: I1e2a1a77097d8017c274af3f97eba6964f80f5fa
Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/122592
Tested-by: Jenkins <bsgcomp@arm.com>
Reviewed-by: Anthony Barbier <anthony.barbier@arm.com>
Diffstat (limited to 'src/core/NEON/kernels/arm_gemm/kernels')
34 files changed, 9436 insertions, 0 deletions
diff --git a/src/core/NEON/kernels/arm_gemm/kernels/a32_sgemm_8x6.hpp b/src/core/NEON/kernels/arm_gemm/kernels/a32_sgemm_8x6.hpp new file mode 100644 index 0000000000..de11dc582c --- /dev/null +++ b/src/core/NEON/kernels/arm_gemm/kernels/a32_sgemm_8x6.hpp @@ -0,0 +1,88 @@ +/* + * Copyright (c) 2017-2018 ARM Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ +#pragma once + +#ifdef __arm__ + +namespace arm_gemm +{ +// Actual kernel implementations +void a32_sgemm_8x6(const float *, const float *, float *, int, int, int); +void a32_sgemm_8x6_a53(const float *, const float *, float *, int, int, int); +void a32_sgemm_8x6_a55r1(const float *, const float *, float *, int, int, int); + +// 8x6 SGEMM "strategy" class. +// +// This describes the characteristics of a family of kernels, in terms of +// the required interleave properties and the output block size. +// +// All kernels in the family must share these characteristics. The actual +// kernel to be used can be chosen at runtime, based on the CPU_type +// structure. +class sgemm_8x6 +{ +public: + typedef float operand_type; + typedef float result_type; + + typedef void (*kern_type)(const float *, const float *, float *, int, int, int); + + /* Describes the data layout for A input */ + static const int A_interleave = 6; + static const int A_block = 1; + static const int A_transpose = 0; + + /* Same for B input */ + static const int B_interleave = 8; + static const int B_block = 1; + static const int B_transpose = 1; + + /* Kernel blocking parameters */ + static const int out_width = 8; + static const int out_height = 6; + static const int k_unroll = 1; + + kern_type kernel = a32_sgemm_8x6; + + sgemm_8x6(const CPUInfo *ci) + { + switch(ci->get_cpu_model()) + { + case CPUModel::A53: + kernel = a32_sgemm_8x6_a53; + break; + + case CPUModel::A55r1: + kernel = a32_sgemm_8x6_a55r1; + break; + + default: + kernel = a32_sgemm_8x6; + break; + } + } +}; + +} // namespace arm_gemm +#endif // __arm__ diff --git a/src/core/NEON/kernels/arm_gemm/kernels/a32_sgemm_8x6/a53.cpp b/src/core/NEON/kernels/arm_gemm/kernels/a32_sgemm_8x6/a53.cpp new file mode 100644 index 0000000000..428498f79e --- /dev/null +++ b/src/core/NEON/kernels/arm_gemm/kernels/a32_sgemm_8x6/a53.cpp @@ -0,0 +1,400 @@ +/* + * Copyright (c) 2017-2018 ARM Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ +#ifdef __arm__ + +#include <arm_neon.h> + +#include "../../asmlib.hpp" + +// Kernel implementation. +// +// Assume that "Apanel" points to a chunk of A blocks (each size 6xK) in read-order. +// Assume that "Bpanel" points to a chunk of B blocks (each size 8xK) in read-order. +// Assume that "Cpanel" points to a chunk of C output blocks (each size +// 8x6), the chunks being arranged in a row major fashion. +// +// Note that the intent of this is that either ablocks or bblocks will be 1 +// - this construction allows the output loop to proceed in either order. + +namespace arm_gemm +{ +void a32_sgemm_8x6_a53(const float *Apanel, const float *Bpanel, float *Cpanel, int ablocks, int bblocks, int K) +{ + const float *a_ptr = Apanel; + float *c_ptr = Cpanel; + + for(int yb = 0; yb < ablocks; yb++) + { + const float *a_ptr0 = a_ptr; + const float *b_ptr = Bpanel; + + for(int xb = 0; xb < bblocks; xb++) + { + a_ptr = a_ptr0; + int tails = (K & 3); + if(tails == 0) + { + tails = 4; + } + int k = ((K + 3) / 4) - 1; + + __asm __volatile( + "vmov.i32 q4, #0\n" + "vld1.32 {d0-d1}, [%[a_ptr] :64]\n" + "vmov.i32 q5, #0\n" + "vld1.32 {d4-d5}, [%[b_ptr] :128]\n" + "vmov.i32 q6, #0\n" + "ldr r0, [%[a_ptr], #0x10]\n" + "vmov.i32 q7, #0\n" + "ldr r1, [%[a_ptr], #0x14]\n" + "vmov.i32 q8, #0\n" ASM_PREFETCH("[%[a_ptr], #0x40]") "vmov.i32 q9, #0\n" ASM_PREFETCH("[%[b_ptr], #0x40]") "vmov.i32 q10, #0\n" ASM_PREFETCH("[%[a_ptr], #0x80]") "vmov.i32 q11, #0\n" + ASM_PREFETCH("[%[b_ptr], #0x80]") + "vmov.i32 q12, #0\n" + "vmov.i32 q13, #0\n" ASM_PREFETCH("[%[a_ptr], #0xC0]") "vmov.i32 q14, #0\n" ASM_PREFETCH("[%[b_ptr], #0XC0]") + "vmov.i32 q15, #0\n" + "cmp %[k], #0\n" + "beq 6f\n" + + "1:\n" + // Unroll 0 + "vldr d6, [%[b_ptr], #0x10]\n" + "vmov d2, r0, r1\n" + "vmla.f32 q4, q2, d0[0]\n" + "ldr r0, [%[b_ptr], #0x18]\n" + "vmla.f32 q5, q2, d0[1]\n" + "ldr r1, [%[b_ptr], #0x1C]\n" + "vmla.f32 q6, q2, d1[0]\n" + + "vldr d3, [%[a_ptr], #0x18]\n" + "vmov d7, r0, r1\n" + "vmla.f32 q7, q2, d1[1]\n" ASM_PREFETCH("[%[a_ptr], #0x100]") + "vmla.f32 q8, q2, d2[0]\n" + "vmla.f32 q9, q2, d2[1]\n" + + "vldr d4, [%[b_ptr], #0x20]\n" + "vmla.f32 q10, q3, d0[0]\n" + "ldr r0, [%[b_ptr], #0x28]\n" + "vmla.f32 q11, q3, d0[1]\n" + "ldr r1, [%[b_ptr], #0x2C]\n" + "vmla.f32 q12, q3, d1[0]\n" + + "vldr d0, [%[a_ptr], #0x20]\n" + "vmov d5, r0, r1\n" + "vmla.f32 q13, q3, d1[1]\n" + "ldr r0, [%[a_ptr], #0x28]\n" + "vmla.f32 q14, q3, d2[0]\n" + "ldr r1, [%[a_ptr], #0x2C]\n" + "vmla.f32 q15, q3, d2[1]\n" + + // Unroll 1 + "vldr d6, [%[b_ptr], #0x30]\n" + "vmov d1, r0, r1\n" + "vmla.f32 q4, q2, d3[0]\n" + "ldr r0, [%[b_ptr], #0x38]\n" + "vmla.f32 q5, q2, d3[1]\n" + "ldr r1, [%[b_ptr], #0x3C]\n" + "vmla.f32 q6, q2, d0[0]\n" + + "vldr d2, [%[a_ptr], #0x30]\n" + "vmov d7, r0, r1\n" + "vmla.f32 q7, q2, d0[1]\n" ASM_PREFETCH("[%[b_ptr], #0x100]") + "vmla.f32 q8, q2, d1[0]\n" + "vmla.f32 q9, q2, d1[1]\n" + + "vldr d4, [%[b_ptr], #0x40]\n" + "vmla.f32 q10, q3, d3[0]\n" + "ldr r0, [%[b_ptr], #0x48]\n" + "vmla.f32 q11, q3, d3[1]\n" + "ldr r1, [%[b_ptr], #0x4C]\n" + "vmla.f32 q12, q3, d0[0]\n" + + "vldr d3, [%[a_ptr], #0x38]\n" + "vmov d5, r0, r1\n" + "vmla.f32 q13, q3, d0[1]\n" + "ldr r0, [%[a_ptr], #0x40]\n" + "vmla.f32 q14, q3, d1[0]\n" + "ldr r1, [%[a_ptr], #0x44]\n" + "vmla.f32 q15, q3, d1[1]\n" + + // Unroll 2 + "vldr d6, [%[b_ptr], #0x50]\n" + "vmov d0, r0, r1\n" + "vmla.f32 q4, q2, d2[0]\n" + "ldr r0, [%[b_ptr], #0x58]\n" + "vmla.f32 q5, q2, d2[1]\n" + "ldr r1, [%[b_ptr], #0x5C]\n" + "vmla.f32 q6, q2, d3[0]\n" + + "vldr d1, [%[a_ptr], #0x48]\n" + "vmov d7, r0, r1\n" + "vmla.f32 q7, q2, d3[1]\n" ASM_PREFETCH("[%[a_ptr], #0x140]") + "vmla.f32 q8, q2, d0[0]\n" + "vmla.f32 q9, q2, d0[1]\n" + + "vldr d4, [%[b_ptr], #0x60]\n" + "vmla.f32 q10, q3, d2[0]\n" + "ldr r0, [%[b_ptr], #0x68]\n" + "vmla.f32 q11, q3, d2[1]\n" + "ldr r1, [%[b_ptr], #0x6C]\n" + "vmla.f32 q12, q3, d3[0]\n" + + "vldr d2, [%[a_ptr], #0x50]\n" + "vmov d5, r0, r1\n" + "vmla.f32 q13, q3, d3[1]\n" + "ldr r0, [%[a_ptr], #0x58]\n" + "vmla.f32 q14, q3, d0[0]\n" + "ldr r1, [%[a_ptr], #0x5C]\n" + "vmla.f32 q15, q3, d0[1]\n" + "add %[a_ptr], %[a_ptr], #0x60\n" + + // Unroll 3 + "vldr d6, [%[b_ptr], #0x70]\n" + "vmov d3, r0, r1\n" + "vmla.f32 q4, q2, d1[0]\n" + "ldr r0, [%[b_ptr], #0x78]\n" + "vmla.f32 q5, q2, d1[1]\n" + "ldr r1, [%[b_ptr], #0x7C]\n" + "vmla.f32 q6, q2, d2[0]\n" + "add %[b_ptr], %[b_ptr], #0x80\n" + + "vldr d0, [%[a_ptr], #0x00]\n" + "vmov d7, r0, r1\n" + "vmla.f32 q7, q2, d2[1]\n" ASM_PREFETCH("[%[b_ptr], #0xC0]") + "vmla.f32 q8, q2, d3[0]\n" + "vmla.f32 q9, q2, d3[1]\n" + + "vldr d4, [%[b_ptr], #0x00]\n" + "vmla.f32 q10, q3, d1[0]\n" + "ldr r0, [%[b_ptr], #0x08]\n" + "vmla.f32 q11, q3, d1[1]\n" + "ldr r1, [%[b_ptr], #0x0C]\n" + "vmla.f32 q12, q3, d2[0]\n" + "subs %[k], %[k], #1\n" + + "vldr d1, [%[a_ptr], #0x08]\n" + "vmov d5, r0, r1\n" + "vmla.f32 q13, q3, d2[1]\n" + "ldr r0, [%[a_ptr], #0x10]\n" + "vmla.f32 q14, q3, d3[0]\n" + "ldr r1, [%[a_ptr], #0x14]\n" + "vmla.f32 q15, q3, d3[1]\n" + "bne 1b\n" + + // "Tails" shows how many multiply blocks are needed at the + // end, must be 1-4 inclusive. Bail out to alternative tail + // immediately if it's 1. + "6:\n" + "subs %[tails], %[tails], #1\n" + "beq 3f\n" + + // Detached final iteration - for now adapt the generic + // tails rather than reimplementing for A53. + + // Unroll 0 + "vmov d2, r0, r1\n" + "add %[a_ptr], %[a_ptr], #0x18\n" + "vmla.f32 q4, q2, d0[0]\n" + "vld1.32 {d3}, [%[a_ptr] :64]!\n" + "vmla.f32 q5, q2, d0[1]\n" + "add %[b_ptr], %[b_ptr], #0x10\n" + "vmla.f32 q6, q2, d1[0]\n" + "vld1.32 {d6-d7}, [%[b_ptr] :128]!\n" + "vmla.f32 q7, q2, d1[1]\n" + "vmla.f32 q8, q2, d2[0]\n" + "subs %[tails], %[tails], #1\n" + "vmla.f32 q9, q2, d2[1]\n" + "vld1.32 {d4-d5}, [%[b_ptr] :128]!\n" + + "vmla.f32 q10, q3, d0[0]\n" + "vmla.f32 q11, q3, d0[1]\n" + "vmla.f32 q12, q3, d1[0]\n" + "vmla.f32 q13, q3, d1[1]\n" + "vld1.32 {d0-d1}, [%[a_ptr] :64]!\n" + "vmla.f32 q14, q3, d2[0]\n" + "vmla.f32 q15, q3, d2[1]\n" + "vld1.32 {d6-d7}, [%[b_ptr] :128]!\n" + "beq 4f\n" + + // Unroll 1 + "vmla.f32 q4, q2, d3[0]\n" + "vmla.f32 q5, q2, d3[1]\n" + "subs %[tails], %[tails], #1\n" + "vmla.f32 q6, q2, d0[0]\n" + "vmla.f32 q7, q2, d0[1]\n" + "vmla.f32 q8, q2, d1[0]\n" + "vmla.f32 q9, q2, d1[1]\n" + "vld1.32 {d4-d5}, [%[b_ptr] :128]!\n" + + "vmla.f32 q10, q3, d3[0]\n" + "vmla.f32 q11, q3, d3[1]\n" + "vld1.32 {d2-d3}, [%[a_ptr] :64]!\n" + "vmla.f32 q12, q3, d0[0]\n" + "vmla.f32 q13, q3, d0[1]\n" + "vmla.f32 q14, q3, d1[0]\n" + "vmla.f32 q15, q3, d1[1]\n" + "vld1.32 {d6-d7}, [%[b_ptr] :128]!\n" + "beq 5f\n" + + // Unroll 2 + "vld1.32 {d0-d1}, [%[a_ptr] :64]!\n" + "vmla.f32 q4, q2, d2[0]\n" + "vmla.f32 q5, q2, d2[1]\n" + "vmla.f32 q6, q2, d3[0]\n" + "vmla.f32 q7, q2, d3[1]\n" + "vmla.f32 q8, q2, d0[0]\n" + "vmla.f32 q9, q2, d0[1]\n" + "vld1.32 {d4-d5}, [%[b_ptr] :128]!\n" + + "vmla.f32 q10, q3, d2[0]\n" + "vmla.f32 q11, q3, d2[1]\n" + "vmla.f32 q12, q3, d3[0]\n" + "vmla.f32 q13, q3, d3[1]\n" + "vld1.32 {d2-d3}, [%[a_ptr] :64]!\n" + "vmla.f32 q14, q3, d0[0]\n" + "vmla.f32 q15, q3, d0[1]\n" + "vld1.32 {d6-d7}, [%[b_ptr] :128]!\n" + + // Unroll 3 + "vmla.f32 q4, q2, d1[0]\n" + "vmla.f32 q10, q3, d1[0]\n" + "vst1.32 {d8-d9}, [%[c_ptr] :128]!\n" + "vmla.f32 q5, q2, d1[1]\n" + "vst1.32 {d20-d21}, [%[c_ptr] :128]!\n" + "vmla.f32 q11, q3, d1[1]\n" + "vst1.32 {d10-d11}, [%[c_ptr] :128]!\n" + "vmla.f32 q6, q2, d2[0]\n" + "vst1.32 {d22-d23}, [%[c_ptr] :128]!\n" + "vmla.f32 q12, q3, d2[0]\n" + "vst1.32 {d12-d13}, [%[c_ptr] :128]!\n" + "vmla.f32 q7, q2, d2[1]\n" + "vst1.32 {d24-d25}, [%[c_ptr] :128]!\n" + "vmla.f32 q13, q3, d2[1]\n" + "vst1.32 {d14-d15}, [%[c_ptr] :128]!\n" + "vmla.f32 q8, q2, d3[0]\n" + "vst1.32 {d26-d27}, [%[c_ptr] :128]!\n" + "vmla.f32 q14, q3, d3[0]\n" + "vst1.32 {d16-d17}, [%[c_ptr] :128]!\n" + "vmla.f32 q9, q2, d3[1]\n" + "vst1.32 {d28-d29}, [%[c_ptr] :128]!\n" + "vmla.f32 q15, q3, d3[1]\n" + "vst1.32 {d18-d19}, [%[c_ptr] :128]!\n" + "b 2f\n" + + // tails==1 final tail + "3:\n" + "vmov d2, r0, r1\n" + "add %[b_ptr], %[b_ptr], #0x10\n" + "vmla.f32 q4, q2, d0[0]\n" + "add %[a_ptr], %[a_ptr], #0x18\n" + "vmla.f32 q5, q2, d0[1]\n" + "vld1.32 {d6-d7}, [%[b_ptr] :128]!\n" + "vmla.f32 q6, q2, d1[0]\n" + "vst1.32 {d8-d9}, [%[c_ptr] :128]!\n" + "vmla.f32 q10, q3, d0[0]\n" + "vst1.32 {d20-d21}, [%[c_ptr] :128]!\n" + "vmla.f32 q11, q3, d0[1]\n" + "vst1.32 {d10-d11}, [%[c_ptr] :128]!\n" + "vmla.f32 q12, q3, d1[0]\n" + "vst1.32 {d22-d23}, [%[c_ptr] :128]!\n" + "vmla.f32 q7, q2, d1[1]\n" + "vst1.32 {d12-d13}, [%[c_ptr] :128]!\n" + "vmla.f32 q13, q3, d1[1]\n" + "vst1.32 {d24-d25}, [%[c_ptr] :128]!\n" + "vmla.f32 q8, q2, d2[0]\n" + "vst1.32 {d14-d15}, [%[c_ptr] :128]!\n" + "vmla.f32 q14, q3, d2[0]\n" + "vst1.32 {d26-d27}, [%[c_ptr] :128]!\n" + "vmla.f32 q9, q2, d2[1]\n" + "vst1.32 {d16-d17}, [%[c_ptr] :128]!\n" + "vmla.f32 q15, q3, d2[1]\n" + "vst1.32 {d28-d29}, [%[c_ptr] :128]!\n" + "vst1.32 {d18-d19}, [%[c_ptr] :128]!\n" + "b 2f\n" + + // tails==2 final tail + "4:\n" + "vmla.f32 q4, q2, d3[0]\n" + "vmla.f32 q10, q3, d3[0]\n" + "vst1.32 {d8-d9}, [%[c_ptr] :128]!\n" + "vmla.f32 q5, q2, d3[1]\n" + "vst1.32 {d20-d21}, [%[c_ptr] :128]!\n" + "vmla.f32 q11, q3, d3[1]\n" + "vst1.32 {d10-d11}, [%[c_ptr] :128]!\n" + "vmla.f32 q6, q2, d0[0]\n" + "vst1.32 {d22-d23}, [%[c_ptr] :128]!\n" + "vmla.f32 q12, q3, d0[0]\n" + "vst1.32 {d12-d13}, [%[c_ptr] :128]!\n" + "vmla.f32 q7, q2, d0[1]\n" + "vst1.32 {d24-d25}, [%[c_ptr] :128]!\n" + "vmla.f32 q13, q3, d0[1]\n" + "vst1.32 {d14-d15}, [%[c_ptr] :128]!\n" + "vmla.f32 q8, q2, d1[0]\n" + "vst1.32 {d26-d27}, [%[c_ptr] :128]!\n" + "vmla.f32 q14, q3, d1[0]\n" + "vst1.32 {d16-d17}, [%[c_ptr] :128]!\n" + "vmla.f32 q9, q2, d1[1]\n" + "vst1.32 {d28-d29}, [%[c_ptr] :128]!\n" + "vmla.f32 q15, q3, d1[1]\n" + "vst1.32 {d18-d19}, [%[c_ptr] :128]!\n" + "b 2f\n" + + // tails==3 final tail + "5:\n" + "vmla.f32 q4, q2, d2[0]\n" + "vld1.32 {d0}, [%[a_ptr] :64]!\n" + "vmla.f32 q5, q2, d2[1]\n" + "vmla.f32 q6, q2, d3[0]\n" + "vst1.32 {d8-d9}, [%[c_ptr] :128]!\n" + "vmla.f32 q10, q3, d2[0]\n" + "vst1.32 {d20-d21}, [%[c_ptr] :128]!\n" + "vmla.f32 q11, q3, d2[1]\n" + "vst1.32 {d10-d11}, [%[c_ptr] :128]!\n" + "vmla.f32 q12, q3, d3[0]\n" + "vst1.32 {d22-d23}, [%[c_ptr] :128]!\n" + "vmla.f32 q7, q2, d3[1]\n" + "vst1.32 {d12-d13}, [%[c_ptr] :128]!\n" + "vmla.f32 q13, q3, d3[1]\n" + "vst1.32 {d24-d25}, [%[c_ptr] :128]!\n" + "vmla.f32 q8, q2, d0[0]\n" + "vst1.32 {d14-d15}, [%[c_ptr] :128]!\n" + "vmla.f32 q14, q3, d0[0]\n" + "vst1.32 {d26-d27}, [%[c_ptr] :128]!\n" + "vmla.f32 q9, q2, d0[1]\n" + "vst1.32 {d16-d17}, [%[c_ptr] :128]!\n" + "vmla.f32 q15, q3, d0[1]\n" + "vst1.32 {d28-d29}, [%[c_ptr] :128]!\n" + "vst1.32 {d18-d19}, [%[c_ptr] :128]!\n" + + "2:\n" + "vst1.32 {d30-d31}, [%[c_ptr] :128]!\n" + : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [c_ptr] "+r"(c_ptr), [k] "+r"(k), [tails] "+r"(tails) + : + : "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9", "q10", "q11", "q12", "q13", "q14", "q15", "r0", "r1"); + } + } +} + +} // namespace arm_gemm + +#endif diff --git a/src/core/NEON/kernels/arm_gemm/kernels/a32_sgemm_8x6/a55r1.cpp b/src/core/NEON/kernels/arm_gemm/kernels/a32_sgemm_8x6/a55r1.cpp new file mode 100644 index 0000000000..4cfb72a455 --- /dev/null +++ b/src/core/NEON/kernels/arm_gemm/kernels/a32_sgemm_8x6/a55r1.cpp @@ -0,0 +1,398 @@ +/* + * Copyright (c) 2017-2018 ARM Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ +#ifdef __arm__ + +#include <arm_neon.h> + +#include "../../asmlib.hpp" + +// Kernel implementation. +// +// Assume that "Apanel" points to a chunk of A blocks (each size 6xK) in read-order. +// Assume that "Bpanel" points to a chunk of B blocks (each size 8xK) in read-order. +// Assume that "Cpanel" points to a chunk of C output blocks (each size +// 8x6), the chunks being arranged in a row major fashion. +// +// Note that the intent of this is that either ablocks or bblocks will be 1 +// - this construction allows the output loop to proceed in either order. + +namespace arm_gemm +{ +void a32_sgemm_8x6_a55r1(const float *Apanel, const float *Bpanel, float *Cpanel, int ablocks, int bblocks, int K) +{ + const float *a_ptr = Apanel; + float *c_ptr = Cpanel; + + /* Work out starting values for "k" and "tails" in the inner loop. */ + int tails_initial = (K & 3); + if(tails_initial == 0) + { + tails_initial = 4; + } + + int k_initial = ((K + 3) / 4) - 1; + + for(int yb = 0; yb < ablocks; yb++) + { + const float *a_ptr0 = a_ptr; + const float *b_ptr = Bpanel; + + for(int xb = 0; xb < bblocks; xb++) + { + int tails = tails_initial; + int k = k_initial; + + a_ptr = a_ptr0; + + __asm __volatile( + "vldr d0, [%[a_ptr]]\n" + "vmov.i32 q4, #0\n" + "vldr d1, [%[a_ptr], #0x08]\n" + "vmov.i32 q5, #0\n" + "vldr d4, [%[b_ptr]]\n" + "vmov.i32 q6, #0\n" + "vldr d5, [%[b_ptr], #0x08]\n" + "vmov.i32 q7, #0\n" + "vldr d2, [%[a_ptr], #0x10]\n" + "vmov.i32 q8, #0\n" ASM_PREFETCH("[%[b_ptr], #0x40]") "vmov.i32 q9, #0\n" ASM_PREFETCH("[%[a_ptr], #0x40]") "vmov.i32 q10, #0\n" ASM_PREFETCH("[%[b_ptr], #0x80]") "vmov.i32 q11, #0\n" + ASM_PREFETCH("[%[a_ptr], #0x80]") "vmov.i32 q12, #0\n" ASM_PREFETCH("[%[b_ptr], #0XC0]") "vmov.i32 q13, #0\n" ASM_PREFETCH("[%[a_ptr], #0xC0]") "vmov.i32 q14, #0\n" + ASM_PREFETCH("[%[b_ptr], #0x100]") "vmov.i32 q15, #0\n" ASM_PREFETCH("[%[a_ptr], #0x100]") "cmp %[k], #0\n" ASM_PREFETCH("[%[b_ptr], #0x140]") "beq 6f\n" + ASM_PREFETCH("[%[b_ptr], #0x180]") + + "1:\n" + // Unroll 0 + "vmla.f32 q4, q2, d0[0]\n" + "vldr d6, [%[b_ptr], #0x10]\n" + "vmla.f32 q5, q2, d0[1]\n" + "vldr d7, [%[b_ptr], #0x18]\n" + "vmla.f32 q6, q2, d1[0]\n" + "vldr d3, [%[a_ptr], #0x18]\n" + "vmla.f32 q7, q2, d1[1]\n" ASM_PREFETCH("[%[a_ptr], #0x140]") + "vmla.f32 q8, q2, d2[0]\n" + "subs %[k], %[k], #1\n" + "vmla.f32 q9, q2, d2[1]\n" + "vldr d4, [%[b_ptr], #0x20]\n" + "vmla.f32 q10, q3, d0[0]\n" + "vldr d5, [%[b_ptr], #0x28]\n" + "vmla.f32 q11, q3, d0[1]\n" + "vldr d0, [%[a_ptr], #0x20]\n" + "vmla.f32 q12, q3, d1[0]\n" + + "vmla.f32 q13, q3, d1[1]\n" + "vldr d1, [%[a_ptr], #0x28]\n" + "vmla.f32 q14, q3, d2[0]\n" + + "vmla.f32 q15, q3, d2[1]\n" + "vldr d6, [%[b_ptr], #0x30]\n" + + // Unroll 1 + "vmla.f32 q4, q2, d3[0]\n" + "vldr d7, [%[b_ptr], #0x38]\n" + "vmla.f32 q5, q2, d3[1]\n" + "vldr d2, [%[a_ptr], #0x30]\n" + "vmla.f32 q6, q2, d0[0]\n" + + "vmla.f32 q7, q2, d0[1]\n" ASM_PREFETCH("[%[b_ptr], #0x1C0]") + "vmla.f32 q8, q2, d1[0]\n" + + "vmla.f32 q9, q2, d1[1]\n" + "vldr d4, [%[b_ptr], #0x40]\n" + "vmla.f32 q10, q3, d3[0]\n" + "vldr d5, [%[b_ptr], #0x48]\n" + "vmla.f32 q11, q3, d3[1]\n" + "vldr d3, [%[a_ptr], #0x38]\n" + "vmla.f32 q12, q3, d0[0]\n" + + "vmla.f32 q13, q3, d0[1]\n" + "vldr d0, [%[a_ptr], #0x40]\n" + "vmla.f32 q14, q3, d1[0]\n" + + "vmla.f32 q15, q3, d1[1]\n" + "vldr d6, [%[b_ptr], #0x50]\n" + + // Unroll 2 + "vmla.f32 q4, q2, d2[0]\n" + "vldr d7, [%[b_ptr], #0x58]\n" + "vmla.f32 q5, q2, d2[1]\n" + "vldr d1, [%[a_ptr], #0x48]\n" + "vmla.f32 q6, q2, d3[0]\n" + + "vmla.f32 q7, q2, d3[1]\n" ASM_PREFETCH("[%[a_ptr], #0x180]") + "vmla.f32 q8, q2, d0[0]\n" + + "vmla.f32 q9, q2, d0[1]\n" + "vldr d4, [%[b_ptr], #0x60]\n" + "vmla.f32 q10, q3, d2[0]\n" + "vldr d5, [%[b_ptr], #0x68]\n" + "vmla.f32 q11, q3, d2[1]\n" + "vldr d2, [%[a_ptr], #0x50]\n" + "vmla.f32 q12, q3, d3[0]\n" + + "vmla.f32 q13, q3, d3[1]\n" + "vldr d3, [%[a_ptr], #0x58]\n" + "vmla.f32 q14, q3, d0[0]\n" + "add %[a_ptr], %[a_ptr], #0x60\n" + "vmla.f32 q15, q3, d0[1]\n" + "vldr d6, [%[b_ptr], #0x70]\n" + + // Unroll 3 + "vmla.f32 q4, q2, d1[0]\n" + "vldr d7, [%[b_ptr], #0x78]\n" + "vmla.f32 q5, q2, d1[1]\n" + "add %[b_ptr], %[b_ptr], #0x80\n" + "vmla.f32 q6, q2, d2[0]\n" + "vldr d0, [%[a_ptr], #0x00]\n" + "vmla.f32 q7, q2, d2[1]\n" ASM_PREFETCH("[%[b_ptr], #0x180]") + "vmla.f32 q8, q2, d3[0]\n" + + "vmla.f32 q9, q2, d3[1]\n" + "vldr d4, [%[b_ptr], #0x00]\n" + "vmla.f32 q10, q3, d1[0]\n" + "vldr d5, [%[b_ptr], #0x08]\n" + "vmla.f32 q11, q3, d1[1]\n" + "vldr d1, [%[a_ptr], #0x08]\n" + "vmla.f32 q12, q3, d2[0]\n" + + "vmla.f32 q13, q3, d2[1]\n" + "vldr d2, [%[a_ptr], #0x10]\n" + "vmla.f32 q14, q3, d3[0]\n" + + "vmla.f32 q15, q3, d3[1]\n" + "bne 1b\n" + + // "Tails" shows how many multiply blocks are needed at the + // end, must be 1-4 inclusive. Bail out to alternative tail + // immediately if it's 1. + "6:\n" + "subs %[tails], %[tails], #1\n" + "beq 3f\n" + + // Detached final iteration + + // Unroll 0 + "vmla.f32 q4, q2, d0[0]\n" + "vldr d6, [%[b_ptr], #0x10]\n" + "vmla.f32 q5, q2, d0[1]\n" + "vldr d7, [%[b_ptr], #0x18]\n" + "vmla.f32 q6, q2, d1[0]\n" + "vldr d3, [%[a_ptr], #0x18]\n" + "vmla.f32 q7, q2, d1[1]\n" + "subs %[tails], %[tails], #1\n" + "vmla.f32 q8, q2, d2[0]\n" + "vmla.f32 q9, q2, d2[1]\n" + "vldr d4, [%[b_ptr], #0x20]\n" + + "vmla.f32 q10, q3, d0[0]\n" + "vldr d5, [%[b_ptr], #0x28]\n" + "vmla.f32 q11, q3, d0[1]\n" + "vldr d0, [%[a_ptr], #0x20]\n" + "vmla.f32 q12, q3, d1[0]\n" + "add %[b_ptr], %[b_ptr], #0x30\n" + "vmla.f32 q13, q3, d1[1]\n" + "vldr d1, [%[a_ptr], #0x28]\n" + "vmla.f32 q14, q3, d2[0]\n" + "vmla.f32 q15, q3, d2[1]\n" + "beq 4f\n" + + // Unroll 1 + "vmla.f32 q4, q2, d3[0]\n" + "vldr d6, [%[b_ptr], #0x30]\n" + "vmla.f32 q5, q2, d3[1]\n" + "vldr d7, [%[b_ptr], #0x38]\n" + "vmla.f32 q6, q2, d0[0]\n" + "vldr d2, [%[a_ptr], #0x30]\n" + "vmla.f32 q7, q2, d0[1]\n" + "subs %[tails], %[tails], #1\n" + "vmla.f32 q8, q2, d1[0]\n" + + "vmla.f32 q9, q2, d1[1]\n" + + "vmla.f32 q10, q3, d3[0]\n" + "vldr d4, [%[b_ptr], #0x40]\n" + "vmla.f32 q11, q3, d3[1]\n" + "vldr d5, [%[b_ptr], #0x48]\n" + "vmla.f32 q12, q3, d0[0]\n" + "vldr d3, [%[a_ptr], #0x38]\n" + "vmla.f32 q13, q3, d0[1]\n" + "vldr d0, [%[a_ptr], #0x40]\n" + "vmla.f32 q14, q3, d1[0]\n" + "vmla.f32 q15, q3, d1[1]\n" + "beq 5f\n" + + // Unroll 2 + "vmla.f32 q4, q2, d2[0]\n" + "vldr d6, [%[b_ptr], #0x50]\n" + "vmla.f32 q5, q2, d2[1]\n" + "vldr d7, [%[b_ptr], #0x58]\n" + "vmla.f32 q6, q2, d3[0]\n" + "vldr d1, [%[a_ptr], #0x48]\n" + "vmla.f32 q7, q2, d3[1]\n" + "vmla.f32 q8, q2, d0[0]\n" + "vmla.f32 q9, q2, d0[1]\n" + + "vmla.f32 q10, q3, d2[0]\n" + "vldr d4, [%[b_ptr], #0x60]\n" + "vmla.f32 q11, q3, d2[1]\n" + "vldr d5, [%[b_ptr], #0x68]\n" + "vmla.f32 q12, q3, d3[0]\n" + "vldr d2, [%[a_ptr], #0x50]\n" + "vmla.f32 q13, q3, d3[1]\n" + "vldr d3, [%[a_ptr], #0x58]\n" + "vmla.f32 q14, q3, d0[0]\n" + "vmla.f32 q15, q3, d0[1]\n" + + // Unroll 3 + "vmla.f32 q4, q2, d1[0]\n" + "vldr d6, [%[b_ptr], #0x70]\n" + "vmla.f32 q5, q2, d1[1]\n" + "vldr d7, [%[b_ptr], #0x78]\n" + "vmla.f32 q10, q3, d1[0]\n" + "vst1.32 {d8-d9}, [%[c_ptr] :128]!\n" + "vmla.f32 q11, q3, d1[1]\n" + "vst1.32 {d20-d21}, [%[c_ptr] :128]!\n" + "vmla.f32 q6, q2, d2[0]\n" + "vst1.32 {d10-d11}, [%[c_ptr] :128]!\n" + "vmla.f32 q12, q3, d2[0]\n" + "vst1.32 {d22-d23}, [%[c_ptr] :128]!\n" + "vmla.f32 q7, q2, d2[1]\n" + "vst1.32 {d12-d13}, [%[c_ptr] :128]!\n" + "vmla.f32 q13, q3, d2[1]\n" + "vst1.32 {d24-d25}, [%[c_ptr] :128]!\n" + "vmla.f32 q8, q2, d3[0]\n" + "vst1.32 {d14-d15}, [%[c_ptr] :128]!\n" + "vmla.f32 q14, q3, d3[0]\n" + "vst1.32 {d26-d27}, [%[c_ptr] :128]!\n" + "vmla.f32 q9, q2, d3[1]\n" + "vst1.32 {d16-d17}, [%[c_ptr] :128]!\n" + "vmla.f32 q15, q3, d3[1]\n" + "vst1.32 {d28-d29}, [%[c_ptr] :128]!\n" + "add %[a_ptr], %[a_ptr], #0x60\n" + "vst1.32 {d18-d19}, [%[c_ptr] :128]!\n" + "add %[b_ptr], %[b_ptr], #0x80\n" + "b 2f\n" + + // tails==1 final tail + "3:\n" + "vmla.f32 q4, q2, d0[0]\n" + "vldr d6, [%[b_ptr], #0x10]\n" + "vmla.f32 q5, q2, d0[1]\n" + "vldr d7, [%[b_ptr], #0x18]\n" + "vmla.f32 q6, q2, d1[0]\n" + "vst1.32 {d8-d9}, [%[c_ptr] :128]!\n" + "vmla.f32 q10, q3, d0[0]\n" + "vst1.32 {d20-d21}, [%[c_ptr] :128]!\n" + "vmla.f32 q11, q3, d0[1]\n" + "vst1.32 {d10-d11}, [%[c_ptr] :128]!\n" + "vmla.f32 q12, q3, d1[0]\n" + "vst1.32 {d22-d23}, [%[c_ptr] :128]!\n" + "vmla.f32 q7, q2, d1[1]\n" + "vst1.32 {d12-d13}, [%[c_ptr] :128]!\n" + "vmla.f32 q13, q3, d1[1]\n" + "vst1.32 {d24-d25}, [%[c_ptr] :128]!\n" + "vmla.f32 q8, q2, d2[0]\n" + "vst1.32 {d14-d15}, [%[c_ptr] :128]!\n" + "vmla.f32 q14, q3, d2[0]\n" + "vst1.32 {d26-d27}, [%[c_ptr] :128]!\n" + "vmla.f32 q9, q2, d2[1]\n" + "vst1.32 {d16-d17}, [%[c_ptr] :128]!\n" + "vmla.f32 q15, q3, d2[1]\n" + "vst1.32 {d28-d29}, [%[c_ptr] :128]!\n" + "add %[a_ptr], %[a_ptr], #0x18\n" + "vst1.32 {d18-d19}, [%[c_ptr] :128]!\n" + "add %[b_ptr], %[b_ptr], #0x20\n" + "b 2f\n" + + // tails==2 final tail + "4:\n" + "vmla.f32 q4, q2, d3[0]\n" + "vldr d6, [%[b_ptr], #0x30]\n" + "vmla.f32 q5, q2, d3[1]\n" + "vldr d7, [%[b_ptr], #0x38]\n" + "vmla.f32 q10, q3, d3[0]\n" + "vst1.32 {d8-d9}, [%[c_ptr] :128]!\n" + "vmla.f32 q11, q3, d3[1]\n" + "vst1.32 {d20-d21}, [%[c_ptr] :128]!\n" + "vmla.f32 q6, q2, d0[0]\n" + "vst1.32 {d10-d11}, [%[c_ptr] :128]!\n" + "vmla.f32 q12, q3, d0[0]\n" + "vst1.32 {d22-d23}, [%[c_ptr] :128]!\n" + "vmla.f32 q7, q2, d0[1]\n" + "vst1.32 {d12-d13}, [%[c_ptr] :128]!\n" + "vmla.f32 q13, q3, d0[1]\n" + "vst1.32 {d24-d25}, [%[c_ptr] :128]!\n" + "vmla.f32 q8, q2, d1[0]\n" + "vst1.32 {d14-d15}, [%[c_ptr] :128]!\n" + "vmla.f32 q14, q3, d1[0]\n" + "vst1.32 {d26-d27}, [%[c_ptr] :128]!\n" + "vmla.f32 q9, q2, d1[1]\n" + "vst1.32 {d16-d17}, [%[c_ptr] :128]!\n" + "vmla.f32 q15, q3, d1[1]\n" + "vst1.32 {d28-d29}, [%[c_ptr] :128]!\n" + "add %[b_ptr], %[b_ptr], #0x40\n" + "vst1.32 {d18-d19}, [%[c_ptr] :128]!\n" + "add %[a_ptr], %[a_ptr], #0x30\n" + "b 2f\n" + + // tails==3 final tail + "5:\n" + "vmla.f32 q4, q2, d2[0]\n" + "vldr d6, [%[b_ptr], #0x50]\n" + "vmla.f32 q5, q2, d2[1]\n" + "vldr d7, [%[b_ptr], #0x58]\n" + "vmla.f32 q6, q2, d3[0]\n" + "vst1.32 {d8-d9}, [%[c_ptr] :128]!\n" + "vmla.f32 q10, q3, d2[0]\n" + "vst1.32 {d20-d21}, [%[c_ptr] :128]!\n" + "vmla.f32 q11, q3, d2[1]\n" + "vst1.32 {d10-d11}, [%[c_ptr] :128]!\n" + "vmla.f32 q12, q3, d3[0]\n" + "vst1.32 {d22-d23}, [%[c_ptr] :128]!\n" + "vmla.f32 q7, q2, d3[1]\n" + "vst1.32 {d12-d13}, [%[c_ptr] :128]!\n" + "vmla.f32 q13, q3, d3[1]\n" + "vst1.32 {d24-d25}, [%[c_ptr] :128]!\n" + "vmla.f32 q8, q2, d0[0]\n" + "vst1.32 {d14-d15}, [%[c_ptr] :128]!\n" + "vmla.f32 q14, q3, d0[0]\n" + "vst1.32 {d26-d27}, [%[c_ptr] :128]!\n" + "vmla.f32 q9, q2, d0[1]\n" + "vst1.32 {d16-d17}, [%[c_ptr] :128]!\n" + "vmla.f32 q15, q3, d0[1]\n" + "vst1.32 {d28-d29}, [%[c_ptr] :128]!\n" + "add %[a_ptr], %[a_ptr], #0x48\n" + "vst1.32 {d18-d19}, [%[c_ptr] :128]!\n" + "add %[b_ptr], %[b_ptr], #0x60\n" + + "2:\n" + "vst1.32 {d30-d31}, [%[c_ptr] :128]!\n" + : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [c_ptr] "+r"(c_ptr), [k] "+r"(k), [tails] "+r"(tails) + : + : "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9", "q10", "q11", "q12", "q13", "q14", "q15", "r0", "r1"); + } + } +} + +} // namespace arm_gemm + +#endif diff --git a/src/core/NEON/kernels/arm_gemm/kernels/a32_sgemm_8x6/generic.cpp b/src/core/NEON/kernels/arm_gemm/kernels/a32_sgemm_8x6/generic.cpp new file mode 100644 index 0000000000..d7d0484610 --- /dev/null +++ b/src/core/NEON/kernels/arm_gemm/kernels/a32_sgemm_8x6/generic.cpp @@ -0,0 +1,346 @@ +/* + * Copyright (c) 2017-2018 ARM Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ +#ifdef __arm__ + +#include <arm_neon.h> + +#include "../../asmlib.hpp" + +// Kernel implementation. +// +// Assume that "Apanel" points to a chunk of A blocks (each size 6xK) in read-order. +// Assume that "Bpanel" points to a chunk of B blocks (each size 8xK) in read-order. +// Assume that "Cpanel" points to a chunk of C output blocks (each size +// 8x6), the chunks being arranged in a row major fashion. +// +// Note that the intent of this is that either ablocks or bblocks will be 1 +// - this construction allows the output loop to proceed in either order. + +namespace arm_gemm +{ +void a32_sgemm_8x6(const float *Apanel, const float *Bpanel, float *Cpanel, int ablocks, int bblocks, int K) +{ + const float *a_ptr = Apanel; + float *c_ptr = Cpanel; + + for(int yb = 0; yb < ablocks; yb++) + { + const float *a_ptr0 = a_ptr; + const float *b_ptr = Bpanel; + + for(int xb = 0; xb < bblocks; xb++) + { + a_ptr = a_ptr0; + int tails = (K & 3); + if(tails == 0) + { + tails = 4; + } + int k = ((K + 3) / 4) - 1; + + __asm __volatile( + "vmov.i32 q4, #0\n" + "vld1.32 {d0-d1}, [%[a_ptr] :64]!\n" + "vmov.i32 q5, #0\n" + "vld1.32 {d4-d5}, [%[b_ptr] :128]!\n" + "vmov.i32 q6, #0\n" ASM_PREFETCH("[%[a_ptr], #48]") "vmov.i32 q7, #0\n" ASM_PREFETCH("[%[b_ptr], #48]") "vmov.i32 q8, #0\n" ASM_PREFETCH("[%[a_ptr], #112]") "vmov.i32 q9, #0\n" + ASM_PREFETCH("[%[b_ptr], #112]") + "vmov.i32 q10, #0\n" + "vmov.i32 q11, #0\n" + "vmov.i32 q12, #0\n" + "vmov.i32 q13, #0\n" ASM_PREFETCH("[%[a_ptr], #176]") "vmov.i32 q14, #0\n" ASM_PREFETCH("[%[b_ptr], #176]") + "vmov.i32 q15, #0\n" + + "cmp %[k], #0\n" + "beq 6f\n" + + "1:\n" + // Unroll 0 + "vmla.f32 q4, q2, d0[0]\n" + "vld1.32 {d2-d3}, [%[a_ptr] :64]!\n" + "vmla.f32 q5, q2, d0[1]\n" + "vmla.f32 q6, q2, d1[0]\n" + "vld1.32 {d6-d7}, [%[b_ptr] :128]!\n" + "vmla.f32 q7, q2, d1[1]\n" + "vmla.f32 q8, q2, d2[0]\n" + "vmla.f32 q9, q2, d2[1]\n" + "vld1.32 {d4-d5}, [%[b_ptr] :128]!\n" + + "vmla.f32 q10, q3, d0[0]\n" + "vmla.f32 q11, q3, d0[1]\n" + "vmla.f32 q12, q3, d1[0]\n" + "vmla.f32 q13, q3, d1[1]\n" + "vld1.32 {d0-d1}, [%[a_ptr] :64]!\n" + "vmla.f32 q14, q3, d2[0]\n" + "vmla.f32 q15, q3, d2[1]\n" + "vld1.32 {d6-d7}, [%[b_ptr] :128]!\n" + + // Unroll 1 + "vmla.f32 q4, q2, d3[0]\n" + "subs %[k], %[k], #1\n" + "vmla.f32 q5, q2, d3[1]\n" ASM_PREFETCH("[%[a_ptr], #208]") + "vmla.f32 q6, q2, d0[0]\n" + "vmla.f32 q7, q2, d0[1]\n" ASM_PREFETCH("[%[b_ptr], #192]") + "vmla.f32 q8, q2, d1[0]\n" + "vmla.f32 q9, q2, d1[1]\n" + "vld1.32 {d4-d5}, [%[b_ptr] :128]!\n" + + "vmla.f32 q10, q3, d3[0]\n" + "vmla.f32 q11, q3, d3[1]\n" + "vld1.32 {d2-d3}, [%[a_ptr] :64]!\n" + "vmla.f32 q12, q3, d0[0]\n" + "vmla.f32 q13, q3, d0[1]\n" + "vmla.f32 q14, q3, d1[0]\n" + "vmla.f32 q15, q3, d1[1]\n" + "vld1.32 {d0-d1}, [%[a_ptr] :64]!\n" + + // Unroll 2 + "vmla.f32 q4, q2, d2[0]\n" + "vmla.f32 q5, q2, d2[1]\n" + "vld1.32 {d6-d7}, [%[b_ptr] :128]!\n" + "vmla.f32 q6, q2, d3[0]\n" + "vmla.f32 q7, q2, d3[1]\n" ASM_PREFETCH("[%[a_ptr], #240]") + "vmla.f32 q8, q2, d0[0]\n" + "vmla.f32 q9, q2, d0[1]\n" + "vld1.32 {d4-d5}, [%[b_ptr] :128]!\n" + + "vmla.f32 q10, q3, d2[0]\n" + "vmla.f32 q11, q3, d2[1]\n" ASM_PREFETCH("[%[b_ptr], #208]") + "vmla.f32 q12, q3, d3[0]\n" + "vmla.f32 q13, q3, d3[1]\n" + "vld1.32 {d2-d3}, [%[a_ptr] :64]!\n" + "vmla.f32 q14, q3, d0[0]\n" + "vmla.f32 q15, q3, d0[1]\n" + "vld1.32 {d6-d7}, [%[b_ptr] :128]!\n" + + // Unroll 3 + "vmla.f32 q4, q2, d1[0]\n" + "vmla.f32 q5, q2, d1[1]\n" + "vmla.f32 q6, q2, d2[0]\n" + "vmla.f32 q7, q2, d2[1]\n" + "vmla.f32 q8, q2, d3[0]\n" + "vmla.f32 q9, q2, d3[1]\n" + "vld1.32 {d4-d5}, [%[b_ptr] :128]!\n" + + "vmla.f32 q10, q3, d1[0]\n" + "vmla.f32 q11, q3, d1[1]\n" + "vld1.32 {d0-d1}, [%[a_ptr] :64]!\n" + "vmla.f32 q12, q3, d2[0]\n" + "vmla.f32 q13, q3, d2[1]\n" + "vmla.f32 q14, q3, d3[0]\n" + "vmla.f32 q15, q3, d3[1]\n" + "bne 1b\n" + + // Branch here if we never execute main loop. + "6:\n" + + // "Tails" shows how many multiply blocks are needed at the + // end, must be 1-4 inclusive. Bail out to alternative tail + // immediately if it's 1. + "subs %[tails], %[tails], #1\n" + "beq 3f\n" + + // Detached final iteration + // Unroll 0 + "vmla.f32 q4, q2, d0[0]\n" + "vld1.32 {d2-d3}, [%[a_ptr] :64]!\n" + "vmla.f32 q5, q2, d0[1]\n" + "vmla.f32 q6, q2, d1[0]\n" + "vld1.32 {d6-d7}, [%[b_ptr] :128]!\n" + "vmla.f32 q7, q2, d1[1]\n" + "vmla.f32 q8, q2, d2[0]\n" + "subs %[tails], %[tails], #1\n" + "vmla.f32 q9, q2, d2[1]\n" + "vld1.32 {d4-d5}, [%[b_ptr] :128]!\n" + + "vmla.f32 q10, q3, d0[0]\n" + "vmla.f32 q11, q3, d0[1]\n" + "vmla.f32 q12, q3, d1[0]\n" + "vmla.f32 q13, q3, d1[1]\n" + "vld1.32 {d0-d1}, [%[a_ptr] :64]!\n" + "vmla.f32 q14, q3, d2[0]\n" + "vmla.f32 q15, q3, d2[1]\n" + "vld1.32 {d6-d7}, [%[b_ptr] :128]!\n" + "beq 4f\n" + + // Unroll 1 + "vmla.f32 q4, q2, d3[0]\n" + "vmla.f32 q5, q2, d3[1]\n" + "subs %[tails], %[tails], #1\n" + "vmla.f32 q6, q2, d0[0]\n" + "vmla.f32 q7, q2, d0[1]\n" + "vmla.f32 q8, q2, d1[0]\n" + "vmla.f32 q9, q2, d1[1]\n" + "vld1.32 {d4-d5}, [%[b_ptr] :128]!\n" + + "vmla.f32 q10, q3, d3[0]\n" + "vmla.f32 q11, q3, d3[1]\n" + "vld1.32 {d2-d3}, [%[a_ptr] :64]!\n" + "vmla.f32 q12, q3, d0[0]\n" + "vmla.f32 q13, q3, d0[1]\n" + "vmla.f32 q14, q3, d1[0]\n" + "vmla.f32 q15, q3, d1[1]\n" + "vld1.32 {d6-d7}, [%[b_ptr] :128]!\n" + "beq 5f\n" + + // Unroll 2 + "vld1.32 {d0-d1}, [%[a_ptr] :64]!\n" + "vmla.f32 q4, q2, d2[0]\n" + "vmla.f32 q5, q2, d2[1]\n" + "vmla.f32 q6, q2, d3[0]\n" + "vmla.f32 q7, q2, d3[1]\n" + "vmla.f32 q8, q2, d0[0]\n" + "vmla.f32 q9, q2, d0[1]\n" + "vld1.32 {d4-d5}, [%[b_ptr] :128]!\n" + + "vmla.f32 q10, q3, d2[0]\n" + "vmla.f32 q11, q3, d2[1]\n" + "vmla.f32 q12, q3, d3[0]\n" + "vmla.f32 q13, q3, d3[1]\n" + "vld1.32 {d2-d3}, [%[a_ptr] :64]!\n" + "vmla.f32 q14, q3, d0[0]\n" + "vmla.f32 q15, q3, d0[1]\n" + "vld1.32 {d6-d7}, [%[b_ptr] :128]!\n" + + // Unroll 3 + "vmla.f32 q4, q2, d1[0]\n" + "vmla.f32 q10, q3, d1[0]\n" + "vst1.32 {d8-d9}, [%[c_ptr] :128]!\n" + "vmla.f32 q5, q2, d1[1]\n" + "vst1.32 {d20-d21}, [%[c_ptr] :128]!\n" + "vmla.f32 q11, q3, d1[1]\n" + "vst1.32 {d10-d11}, [%[c_ptr] :128]!\n" + "vmla.f32 q6, q2, d2[0]\n" + "vst1.32 {d22-d23}, [%[c_ptr] :128]!\n" + "vmla.f32 q12, q3, d2[0]\n" + "vst1.32 {d12-d13}, [%[c_ptr] :128]!\n" + "vmla.f32 q7, q2, d2[1]\n" + "vst1.32 {d24-d25}, [%[c_ptr] :128]!\n" + "vmla.f32 q13, q3, d2[1]\n" + "vst1.32 {d14-d15}, [%[c_ptr] :128]!\n" + "vmla.f32 q8, q2, d3[0]\n" + "vst1.32 {d26-d27}, [%[c_ptr] :128]!\n" + "vmla.f32 q14, q3, d3[0]\n" + "vst1.32 {d16-d17}, [%[c_ptr] :128]!\n" + "vmla.f32 q9, q2, d3[1]\n" + "vst1.32 {d28-d29}, [%[c_ptr] :128]!\n" + "vmla.f32 q15, q3, d3[1]\n" + "vst1.32 {d18-d19}, [%[c_ptr] :128]!\n" + "b 2f\n" + + // tails==1 final tail + "3:\n" + "vmla.f32 q4, q2, d0[0]\n" + "vld1.32 {d2}, [%[a_ptr] :64]!\n" + "vmla.f32 q5, q2, d0[1]\n" + "vld1.32 {d6-d7}, [%[b_ptr] :128]!\n" + "vmla.f32 q6, q2, d1[0]\n" + "vst1.32 {d8-d9}, [%[c_ptr] :128]!\n" + "vmla.f32 q10, q3, d0[0]\n" + "vst1.32 {d20-d21}, [%[c_ptr] :128]!\n" + "vmla.f32 q11, q3, d0[1]\n" + "vst1.32 {d10-d11}, [%[c_ptr] :128]!\n" + "vmla.f32 q12, q3, d1[0]\n" + "vst1.32 {d22-d23}, [%[c_ptr] :128]!\n" + "vmla.f32 q7, q2, d1[1]\n" + "vst1.32 {d12-d13}, [%[c_ptr] :128]!\n" + "vmla.f32 q13, q3, d1[1]\n" + "vst1.32 {d24-d25}, [%[c_ptr] :128]!\n" + "vmla.f32 q8, q2, d2[0]\n" + "vst1.32 {d14-d15}, [%[c_ptr] :128]!\n" + "vmla.f32 q14, q3, d2[0]\n" + "vst1.32 {d26-d27}, [%[c_ptr] :128]!\n" + "vmla.f32 q9, q2, d2[1]\n" + "vst1.32 {d16-d17}, [%[c_ptr] :128]!\n" + "vmla.f32 q15, q3, d2[1]\n" + "vst1.32 {d28-d29}, [%[c_ptr] :128]!\n" + "vst1.32 {d18-d19}, [%[c_ptr] :128]!\n" + "b 2f\n" + + // tails==2 final tail + "4:\n" + "vmla.f32 q4, q2, d3[0]\n" + "vmla.f32 q10, q3, d3[0]\n" + "vst1.32 {d8-d9}, [%[c_ptr] :128]!\n" + "vmla.f32 q5, q2, d3[1]\n" + "vst1.32 {d20-d21}, [%[c_ptr] :128]!\n" + "vmla.f32 q11, q3, d3[1]\n" + "vst1.32 {d10-d11}, [%[c_ptr] :128]!\n" + "vmla.f32 q6, q2, d0[0]\n" + "vst1.32 {d22-d23}, [%[c_ptr] :128]!\n" + "vmla.f32 q12, q3, d0[0]\n" + "vst1.32 {d12-d13}, [%[c_ptr] :128]!\n" + "vmla.f32 q7, q2, d0[1]\n" + "vst1.32 {d24-d25}, [%[c_ptr] :128]!\n" + "vmla.f32 q13, q3, d0[1]\n" + "vst1.32 {d14-d15}, [%[c_ptr] :128]!\n" + "vmla.f32 q8, q2, d1[0]\n" + "vst1.32 {d26-d27}, [%[c_ptr] :128]!\n" + "vmla.f32 q14, q3, d1[0]\n" + "vst1.32 {d16-d17}, [%[c_ptr] :128]!\n" + "vmla.f32 q9, q2, d1[1]\n" + "vst1.32 {d28-d29}, [%[c_ptr] :128]!\n" + "vmla.f32 q15, q3, d1[1]\n" + "vst1.32 {d18-d19}, [%[c_ptr] :128]!\n" + "b 2f\n" + + // tails==3 final tail + "5:\n" + "vmla.f32 q4, q2, d2[0]\n" + "vld1.32 {d0}, [%[a_ptr] :64]!\n" + "vmla.f32 q5, q2, d2[1]\n" + "vmla.f32 q6, q2, d3[0]\n" + "vst1.32 {d8-d9}, [%[c_ptr] :128]!\n" + "vmla.f32 q10, q3, d2[0]\n" + "vst1.32 {d20-d21}, [%[c_ptr] :128]!\n" + "vmla.f32 q11, q3, d2[1]\n" + "vst1.32 {d10-d11}, [%[c_ptr] :128]!\n" + "vmla.f32 q12, q3, d3[0]\n" + "vst1.32 {d22-d23}, [%[c_ptr] :128]!\n" + "vmla.f32 q7, q2, d3[1]\n" + "vst1.32 {d12-d13}, [%[c_ptr] :128]!\n" + "vmla.f32 q13, q3, d3[1]\n" + "vst1.32 {d24-d25}, [%[c_ptr] :128]!\n" + "vmla.f32 q8, q2, d0[0]\n" + "vst1.32 {d14-d15}, [%[c_ptr] :128]!\n" + "vmla.f32 q14, q3, d0[0]\n" + "vst1.32 {d26-d27}, [%[c_ptr] :128]!\n" + "vmla.f32 q9, q2, d0[1]\n" + "vst1.32 {d16-d17}, [%[c_ptr] :128]!\n" + "vmla.f32 q15, q3, d0[1]\n" + "vst1.32 {d28-d29}, [%[c_ptr] :128]!\n" + "vst1.32 {d18-d19}, [%[c_ptr] :128]!\n" + + "2:\n" + "vst1.32 {d30-d31}, [%[c_ptr] :128]!\n" + : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [c_ptr] "+r"(c_ptr), [k] "+r"(k), [tails] "+r"(tails) + : + : "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9", "q10", "q11", "q12", "q13", "q14", "q15", "cc"); + } + } +} + +} // namespace arm_gemm + +#endif diff --git a/src/core/NEON/kernels/arm_gemm/kernels/a64_gemm_s16_12x8.hpp b/src/core/NEON/kernels/arm_gemm/kernels/a64_gemm_s16_12x8.hpp new file mode 100644 index 0000000000..387f899b20 --- /dev/null +++ b/src/core/NEON/kernels/arm_gemm/kernels/a64_gemm_s16_12x8.hpp @@ -0,0 +1,73 @@ +/* + * Copyright (c) 2017-2018 ARM Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ +#pragma once + +#ifdef __aarch64__ + +namespace arm_gemm +{ +// Actual kernel implementations +void a64_gemm_s16_asimd_12x8(const int16_t *, const int16_t *, int32_t *, int, int, int); + +// 12x8 SGEMM "strategy" class. +// +// This describes the characteristics of a family of kernels, in terms of +// the required interleave properties and the output block size. +// +// All kernels in the family must share these characteristics. The actual +// kernel to be used can be chosen at runtime, based on the CPU_type +// structure. +class gemm_s16_12x8 +{ +public: + typedef int16_t operand_type; + typedef int32_t result_type; + + typedef void (*kern_type)(const int16_t *, const int16_t *, int32_t *, int, int, int); + + /* Describes the data layout for A input */ + static const int A_interleave = 8; + static const int A_block = 1; + static const int A_transpose = 0; + + /* Same for B input */ + static const int B_interleave = 12; + static const int B_block = 1; + static const int B_transpose = 1; + + /* Kernel blocking parameters */ + static const int out_width = 12; + static const int out_height = 8; + static const int k_unroll = 1; + + kern_type kernel = a64_gemm_s16_asimd_12x8; + + gemm_s16_12x8(const CPUInfo *ci) + { + } +}; + +} // namespace arm_gemm + +#endif // __aarch64__ diff --git a/src/core/NEON/kernels/arm_gemm/kernels/a64_gemm_s16_12x8/generic.cpp b/src/core/NEON/kernels/arm_gemm/kernels/a64_gemm_s16_12x8/generic.cpp new file mode 100644 index 0000000000..b217dcf2cf --- /dev/null +++ b/src/core/NEON/kernels/arm_gemm/kernels/a64_gemm_s16_12x8/generic.cpp @@ -0,0 +1,309 @@ +/* + * Copyright (c) 2017-2018 ARM Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ +#ifdef __aarch64__ + +#include <arm_neon.h> + +#include "../../asmlib.hpp" + +namespace arm_gemm +{ +void a64_gemm_s16_asimd_12x8(const int16_t *Apanel, const int16_t *Bpanel, int32_t *Cpanel, int ablocks, int bblocks, int K) +{ + const int16_t *a_ptr = Apanel; + int32_t *c_ptr = Cpanel; + + for(int yb = 0; yb < ablocks; yb++) + { + const int16_t *a_ptr0 = a_ptr; + const int16_t *b_ptr = Bpanel; + + for(int xb = 0; xb < bblocks; xb++) + { + a_ptr = a_ptr0; + const bool odd_k = K & 0x1; + int k = (K + 1) / 2 - 1; + + register int16x8_t aa asm("v0"); + register int16x8_t ab asm("v1"); + register int16x8_t b0 asm("v2"); + register int16x8_t b1 asm("v3"); + register int16x8_t b2 asm("v4"); + + __asm __volatile( + "ldr %d[aa], [%x[a_ptr]]\n" // Load A[A].lower + "movi v5.4s, #0\n" + "ldr x20, [%x[a_ptr], #0x08]\n" // Load A[A].upper + "movi v6.4s, #0\n" + "ldr %d[b0], [%x[b_ptr]]\n" // Load B[0].lower + "ins %[aa].d[1], x20\n" // Merge A[A].lower and upper + "movi v7.4s, #0\n" ASM_PREFETCH("[%[a_ptr], #64]") + "movi v8.4s, #0\n" + "ldr x20, [%x[b_ptr], #0x08]\n" // Load B[0].upper + "movi v9.4s, #0\n" ASM_PREFETCH("[%[b_ptr], #64]") + "movi v10.4s, #0\n" + "ldr %d[b1], [%x[b_ptr], #0x10]\n" // Load B[1].lower + "ins %[b0].d[1], x20\n" // Merge B[0].lower and upper + "movi v11.4s, #0\n" ASM_PREFETCH("[%[a_ptr], #96]") + "movi v12.4s, #0\n" + "movi v13.4s, #0\n" ASM_PREFETCH("[%[b_ptr], #96]") + "movi v14.4s, #0\n" + "movi v15.4s, #0\n" ASM_PREFETCH("[%[a_ptr], #128]") + "movi v16.4s, #0\n" + "movi v17.4s, #0\n" ASM_PREFETCH("[%[b_ptr], #128]") + "movi v18.4s, #0\n" + "movi v19.4s, #0\n" ASM_PREFETCH("[%[a_ptr], #160]") + "movi v20.4s, #0\n" + "movi v21.4s, #0\n" ASM_PREFETCH("[%[b_ptr], #160]") + "movi v22.4s, #0\n" + "movi v23.4s, #0\n" ASM_PREFETCH("[%[a_ptr], #192]") + "movi v24.4s, #0\n" + "add %x[a_ptr], %x[a_ptr], #0x10\n" + "movi v25.4s, #0\n" ASM_PREFETCH("[%[b_ptr], #192]") + "movi v26.4s, #0\n" + "add %x[b_ptr], %x[b_ptr], #0x18\n" + "movi v27.4s, #0\n" + "movi v28.4s, #0\n" + + "cbz %x[k], 2f\n" // Skip the loop if doing zero iterations. + + "1:\n" // Main loop + // First unroll + "smlal v5.4s, %[b0].4h, %[aa].h[0]\n" + "ldr x20, [%x[b_ptr]]\n" // Load B[1].upper + "smlal v6.4s, %[b0].4h, %[aa].h[1]\n" + "smlal v7.4s, %[b0].4h, %[aa].h[2]\n" + "ldr %d[ab], [%x[a_ptr]]\n" // Load A[B].lower + "ins %[b1].d[1], x20\n" // Merge B[1].lower and .upper + "smlal v8.4s, %[b0].4h, %[aa].h[3]\n" + "smlal v9.4s, %[b0].4h, %[aa].h[4]\n" + "ldr x20, [%x[a_ptr], #0x8]\n" // Load A[B].upper + "smlal v10.4s, %[b0].4h, %[aa].h[5]\n" + "smlal v11.4s, %[b0].4h, %[aa].h[6]\n" + "ldr %d[b2], [%x[b_ptr], #0x8]\n" // Load B[2].lower + "ins %[ab].d[1], x20\n" // Merge A[B].lower and .upper + "smlal v12.4s, %[b0].4h, %[aa].h[7]\n" + "smlal2 v13.4s, %[b0].8h, %[aa].h[0]\n" + "ldr x20, [%x[b_ptr], #0x10]\n" // Load B[2].upper + "smlal2 v14.4s, %[b0].8h, %[aa].h[1]\n" + "smlal2 v15.4s, %[b0].8h, %[aa].h[2]\n" + "smlal2 v16.4s, %[b0].8h, %[aa].h[3]\n" + "smlal2 v17.4s, %[b0].8h, %[aa].h[4]\n" + "smlal2 v18.4s, %[b0].8h, %[aa].h[5]\n" + "smlal2 v19.4s, %[b0].8h, %[aa].h[6]\n" + "smlal2 v20.4s, %[b0].8h, %[aa].h[7]\n" + "ldr %d[b0], [%x[b_ptr], #0x18]\n" // Load B[0].lower + "ins %[b2].d[1], x20\n" // Merge B[2].lower and .upper + "smlal v21.4s, %[b1].4h, %[aa].h[0]\n" + "smlal v22.4s, %[b1].4h, %[aa].h[1]\n" + "ldr x20, [%x[b_ptr], #0x20]\n" // Load B[0].upper + "smlal v23.4s, %[b1].4h, %[aa].h[2]\n" + "smlal v24.4s, %[b1].4h, %[aa].h[3]\n" + "smlal v25.4s, %[b1].4h, %[aa].h[4]\n" + "smlal v26.4s, %[b1].4h, %[aa].h[5]\n" + "smlal v27.4s, %[b1].4h, %[aa].h[6]\n" + "smlal v28.4s, %[b1].4h, %[aa].h[7]\n" + + // Second unroll + "smlal2 v5.4s, %[b1].8h, %[ab].h[0]\n" + "ldr %d[aa], [%x[a_ptr], #0x10]\n" // Load A[A].lower + "ins %[b0].d[1], x20\n" // Merge B[0].lower and .upper + "smlal2 v6.4s, %[b1].8h, %[ab].h[1]\n" + "smlal2 v7.4s, %[b1].8h, %[ab].h[2]\n" + "ldr x20, [%x[a_ptr], #0x18]\n" // Load A[A].upper + "smlal2 v8.4s, %[b1].8h, %[ab].h[3]\n" + "smlal2 v9.4s, %[b1].8h, %[ab].h[4]\n" + "smlal2 v10.4s, %[b1].8h, %[ab].h[5]\n" + "smlal2 v11.4s, %[b1].8h, %[ab].h[6]\n" + "add %x[a_ptr], %x[a_ptr], #0x20\n" + "smlal2 v12.4s, %[b1].8h, %[ab].h[7]\n" + "smlal v13.4s, %[b2].4h, %[ab].h[0]\n" ASM_PREFETCH("[%[b_ptr], #320]") + "smlal v14.4s, %[b2].4h, %[ab].h[1]\n" + "smlal v15.4s, %[b2].4h, %[ab].h[2]\n" ASM_PREFETCH("[%[a_ptr], #320]") + "smlal v16.4s, %[b2].4h, %[ab].h[3]\n" + "smlal v17.4s, %[b2].4h, %[ab].h[4]\n" ASM_PREFETCH("[%[b_ptr], #448]") + "smlal v18.4s, %[b2].4h, %[ab].h[5]\n" + "smlal v19.4s, %[b2].4h, %[ab].h[6]\n" + "smlal v20.4s, %[b2].4h, %[ab].h[7]\n" + "smlal2 v21.4s, %[b2].8h, %[ab].h[0]\n" + "smlal2 v22.4s, %[b2].8h, %[ab].h[1]\n" + "subs %x[k], %x[k], #0x1\n" + "smlal2 v23.4s, %[b2].8h, %[ab].h[2]\n" + "smlal2 v24.4s, %[b2].8h, %[ab].h[3]\n" + "ldr %d[b1], [%x[b_ptr], #0x28]\n" // Load B[1].lower + "ins %[aa].d[1], x20\n" // Merge A[A].lower and .upper + "smlal2 v25.4s, %[b2].8h, %[ab].h[4]\n" + "smlal2 v26.4s, %[b2].8h, %[ab].h[5]\n" + "add %x[b_ptr], %x[b_ptr], #0x30\n" + "smlal2 v27.4s, %[b2].8h, %[ab].h[6]\n" + "smlal2 v28.4s, %[b2].8h, %[ab].h[7]\n" + "bne 1b\n" + + "2:\n" // Even tail + "cbnz %x[odd_k], 3f\n" + + "smlal v5.4s, %[b0].4h, %[aa].h[0]\n" + "ldr x20, [%x[b_ptr]]\n" // Load B[1].upper + "smlal v6.4s, %[b0].4h, %[aa].h[1]\n" + "smlal v7.4s, %[b0].4h, %[aa].h[2]\n" + "ldr %d[ab], [%x[a_ptr]]\n" // Load A[B].lower + "ins %[b1].d[1], x20\n" // Merge B[1].lower and .upper + "smlal v8.4s, %[b0].4h, %[aa].h[3]\n" + "smlal v9.4s, %[b0].4h, %[aa].h[4]\n" + "ldr x20, [%x[a_ptr], #0x8]\n" // Load A[B].upper + "smlal v10.4s, %[b0].4h, %[aa].h[5]\n" + "smlal v11.4s, %[b0].4h, %[aa].h[6]\n" + "ldr %d[b2], [%x[b_ptr], #0x8]\n" // Load B[2].lower + "ins %[ab].d[1], x20\n" // Merge A[B].lower and .upper + "smlal v12.4s, %[b0].4h, %[aa].h[7]\n" + "smlal2 v13.4s, %[b0].8h, %[aa].h[0]\n" + "ldr x20, [%x[b_ptr], #0x10]\n" // Load B[2].upper + "smlal2 v14.4s, %[b0].8h, %[aa].h[1]\n" + "smlal2 v15.4s, %[b0].8h, %[aa].h[2]\n" + "smlal2 v16.4s, %[b0].8h, %[aa].h[3]\n" + "add %[a_ptr], %[a_ptr], #0x10\n" + "smlal2 v17.4s, %[b0].8h, %[aa].h[4]\n" + "add %[b_ptr], %[b_ptr], #0x18\n" + "smlal2 v18.4s, %[b0].8h, %[aa].h[5]\n" + "smlal2 v19.4s, %[b0].8h, %[aa].h[6]\n" + "smlal2 v20.4s, %[b0].8h, %[aa].h[7]\n" + "ins %[b2].d[1], x20\n" // Merge B[2].lower and .upper + "smlal v21.4s, %[b1].4h, %[aa].h[0]\n" + "smlal v22.4s, %[b1].4h, %[aa].h[1]\n" + "smlal v23.4s, %[b1].4h, %[aa].h[2]\n" + "smlal v24.4s, %[b1].4h, %[aa].h[3]\n" + "smlal v25.4s, %[b1].4h, %[aa].h[4]\n" + "smlal v26.4s, %[b1].4h, %[aa].h[5]\n" + "smlal v27.4s, %[b1].4h, %[aa].h[6]\n" + "smlal v28.4s, %[b1].4h, %[aa].h[7]\n" + + "smlal2 v5.4s, %[b1].8h, %[ab].h[0]\n" + "smlal v13.4s, %[b2].4h, %[ab].h[0]\n" + "smlal2 v21.4s, %[b2].8h, %[ab].h[0]\n" + "smlal2 v6.4s, %[b1].8h, %[ab].h[1]\n" + "smlal v14.4s, %[b2].4h, %[ab].h[1]\n" + "str q5, [%x[c_ptr]]\n" + "smlal2 v22.4s, %[b2].8h, %[ab].h[1]\n" + "str q13, [%x[c_ptr], #0x10]\n" + "smlal2 v7.4s, %[b1].8h, %[ab].h[2]\n" + "str q21, [%x[c_ptr], #0x20]\n" + "smlal v15.4s, %[b2].4h, %[ab].h[2]\n" + "str q6, [%x[c_ptr], #0x30]\n" + "smlal2 v23.4s, %[b2].8h, %[ab].h[2]\n" + "str q14, [%x[c_ptr], #0x40]\n" + "smlal2 v8.4s, %[b1].8h, %[ab].h[3]\n" + "str q22, [%x[c_ptr], #0x50]\n" + "smlal v16.4s, %[b2].4h, %[ab].h[3]\n" + "str q7, [%x[c_ptr], #0x60]\n" + "smlal2 v24.4s, %[b2].8h, %[ab].h[3]\n" + "str q15, [%x[c_ptr], #0x70]\n" + "smlal2 v9.4s, %[b1].8h, %[ab].h[4]\n" + "str q23, [%x[c_ptr], #0x80]\n" + "smlal v17.4s, %[b2].4h, %[ab].h[4]\n" + "str q8, [%x[c_ptr], #0x90]\n" + "smlal2 v25.4s, %[b2].8h, %[ab].h[4]\n" + "str q16, [%x[c_ptr], #0xa0]\n" + "smlal2 v10.4s, %[b1].8h, %[ab].h[5]\n" + "str q24, [%x[c_ptr], #0xb0]\n" + "smlal v18.4s, %[b2].4h, %[ab].h[5]\n" + "str q9, [%x[c_ptr], #0xc0]\n" + "smlal2 v26.4s, %[b2].8h, %[ab].h[5]\n" + "str q17, [%x[c_ptr], #0xd0]\n" + "smlal2 v11.4s, %[b1].8h, %[ab].h[6]\n" + "str q25, [%x[c_ptr], #0xe0]\n" + "smlal v19.4s, %[b2].4h, %[ab].h[6]\n" + "str q10, [%x[c_ptr], #0xf0]\n" + "smlal2 v27.4s, %[b2].8h, %[ab].h[6]\n" + "str q18, [%x[c_ptr], #0x100]\n" + "smlal2 v12.4s, %[b1].8h, %[ab].h[7]\n" + "str q26, [%x[c_ptr], #0x110]\n" + "smlal v20.4s, %[b2].4h, %[ab].h[7]\n" + "str q11, [%x[c_ptr], #0x120]\n" + "smlal2 v28.4s, %[b2].8h, %[ab].h[7]\n" + "str q19, [%x[c_ptr], #0x130]\n" + "b 4f\n" // Complete write out + + "3:\n" // Odd tail + "smlal v5.4s, %[b0].4h, %[aa].h[0]\n" + "smlal2 v13.4s, %[b0].8h, %[aa].h[0]\n" + "smlal v21.4s, %[b1].4h, %[aa].h[0]\n" + "smlal v6.4s, %[b0].4h, %[aa].h[1]\n" + "smlal2 v14.4s, %[b0].8h, %[aa].h[1]\n" + "smlal v22.4s, %[b1].4h, %[aa].h[1]\n" + "str q5, [%x[c_ptr]]\n" + "smlal v7.4s, %[b0].4h, %[aa].h[2]\n" + "str q13, [%x[c_ptr], #0x10]\n" + "smlal2 v15.4s, %[b0].8h, %[aa].h[2]\n" + "str q21, [%x[c_ptr], #0x20]\n" + "smlal v23.4s, %[b1].4h, %[aa].h[2]\n" + "str q6, [%x[c_ptr], #0x30]\n" + "smlal v8.4s, %[b0].4h, %[aa].h[3]\n" + "str q14, [%x[c_ptr], #0x40]\n" + "smlal2 v16.4s, %[b0].8h, %[aa].h[3]\n" + "str q22, [%x[c_ptr], #0x50]\n" + "smlal v24.4s, %[b1].4h, %[aa].h[3]\n" + "str q7, [%x[c_ptr], #0x60]\n" + "smlal v9.4s, %[b0].4h, %[aa].h[4]\n" + "str q15, [%x[c_ptr], #0x70]\n" + "smlal2 v17.4s, %[b0].8h, %[aa].h[4]\n" + "str q23, [%x[c_ptr], #0x80]\n" + "smlal v25.4s, %[b1].4h, %[aa].h[4]\n" + "str q8, [%x[c_ptr], #0x90]\n" + "smlal v10.4s, %[b0].4h, %[aa].h[5]\n" + "str q16, [%x[c_ptr], #0xa0]\n" + "smlal2 v18.4s, %[b0].8h, %[aa].h[5]\n" + "str q24, [%x[c_ptr], #0xb0]\n" + "smlal v26.4s, %[b1].4h, %[aa].h[5]\n" + "str q9, [%x[c_ptr], #0xc0]\n" + "smlal v11.4s, %[b0].4h, %[aa].h[6]\n" + "str q17, [%x[c_ptr], #0xd0]\n" + "smlal2 v19.4s, %[b0].8h, %[aa].h[6]\n" + "str q25, [%x[c_ptr], #0xe0]\n" + "smlal v27.4s, %[b1].4h, %[aa].h[6]\n" + "str q10, [%x[c_ptr], #0xf0]\n" + "smlal v12.4s, %[b0].4h, %[aa].h[7]\n" + "str q18, [%x[c_ptr], #0x100]\n" + "smlal2 v20.4s, %[b0].8h, %[aa].h[7]\n" + "str q26, [%x[c_ptr], #0x110]\n" + "smlal v28.4s, %[b1].4h, %[aa].h[7]\n" + "str q11, [%x[c_ptr], #0x120]\n" + + "4:\n" // End of function + "str q19, [%x[c_ptr], #0x130]\n" + "str q27, [%x[c_ptr], #0x140]\n" + "str q12, [%x[c_ptr], #0x150]\n" + "str q20, [%x[c_ptr], #0x160]\n" + "str q28, [%x[c_ptr], #0x170]\n" + "add %x[c_ptr], %x[c_ptr], #0x180\n" + : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [c_ptr] "+r"(c_ptr), [k] "+r"(k), + [aa] "+w"(aa), [ab] "+w"(ab), [b0] "+w"(b0), [b1] "+w"(b1), [b2] "+w"(b2) + : [odd_k] "r"(odd_k) + : "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "x20", "cc"); + } + } +} + +} // namespace arm_gemm + +#endif diff --git a/src/core/NEON/kernels/arm_gemm/kernels/a64_gemm_s8_12x8.hpp b/src/core/NEON/kernels/arm_gemm/kernels/a64_gemm_s8_12x8.hpp new file mode 100644 index 0000000000..08f90e16ed --- /dev/null +++ b/src/core/NEON/kernels/arm_gemm/kernels/a64_gemm_s8_12x8.hpp @@ -0,0 +1,72 @@ +/* + * Copyright (c) 2017-2018 ARM Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ +#pragma once + +#ifdef __aarch64__ + +#include "arm_gemm.hpp" + +namespace arm_gemm +{ +// Load the actual kernel +void a64_gemm_s8_12x8(const int8_t *, const int8_t *, int32_t *, int, int, int); +void a64_gemm_s8_12x8_a55r1(const int8_t *, const int8_t *, int32_t *, int, int, int); + +class gemm_s8_12x8 +{ +public: + typedef int8_t operand_type; + typedef int32_t result_type; + + typedef void (*kern_type)(const int8_t *, const int8_t *, int32_t *, int, int, int); + + /* Describes the data layout for A input */ + static const int A_interleave = 8; + static const int A_block = 4; + static const bool A_transpose = false; + + /* Same for B input */ + static const int B_interleave = 12; + static const int B_block = 4; + static const bool B_transpose = true; + + /* Kernel blocking parameters */ + static const int out_width = 12; + static const int out_height = 8; + static const int k_unroll = 4; + + kern_type kernel = a64_gemm_s8_12x8; + + gemm_s8_12x8(const CPUInfo *ci) + { + if(ci->get_cpu_model() == CPUModel::A55r1) + { + kernel = a64_gemm_s8_12x8_a55r1; + } + } +}; + +} // namespace arm_gemm + +#endif // __aarch64__ diff --git a/src/core/NEON/kernels/arm_gemm/kernels/a64_gemm_s8_12x8/a55r1.cpp b/src/core/NEON/kernels/arm_gemm/kernels/a64_gemm_s8_12x8/a55r1.cpp new file mode 100644 index 0000000000..ef2f29183c --- /dev/null +++ b/src/core/NEON/kernels/arm_gemm/kernels/a64_gemm_s8_12x8/a55r1.cpp @@ -0,0 +1,356 @@ +/* + * Copyright (c) 2017-2018 ARM Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ +#ifdef __aarch64__ + +#include <arm_neon.h> + +#include "../../asmlib.hpp" + +#ifdef NO_DOT_IN_TOOLCHAIN +#include "dot_toolchain_support.h" +#endif + +namespace arm_gemm +{ +void a64_gemm_s8_12x8_a55r1(const int8_t *Apanel, const int8_t *Bpanel, int32_t *Cpanel, const int ablocks, const int bblocks, const int K) +{ + const int8_t *a_ptr = Apanel; + int32_t *c_ptr = Cpanel; + + // We divide K by 4 because the sdot instruction processes 4 elements at a time. + const int W = K / 4; + + // Fix up for odd lengths - set a flag if K is odd, but make + // sure we round up the iteration count. + const int oddk = (W & 1); + const int k_iters = ((W + 1) / 2) - 1; + + for(int yb = 0; yb < ablocks; yb++) + { + const int8_t *a_ptr0 = a_ptr; + const int8_t *b_ptr = Bpanel; + + for(int xb = 0; xb < bblocks; xb++) + { + a_ptr = a_ptr0; + int k = k_iters; + + register int32x4_t a0 asm("v0"); + register int32x4_t a1 asm("v1"); + register int32x4_t b0 asm("v2"); + register int32x4_t b1 asm("v3"); + register int32x4_t b2 asm("v4"); + register int32x4_t a0a asm("v5"); + register int32x4_t a1a asm("v6"); + + __asm __volatile( +#ifdef NO_DOT_IN_TOOLCHAIN + _DECLARE_SDOT +#else + ".arch armv8.2-a+dotprod\n" +#endif + // Initialize result registers, load initial operands, prime prefetches. + "movi v8.4s, #0x0\n" + "ldr %q[a0], [%[a_ptr]]\n" + "movi v9.4s, #0x0\n" + "ldr %q[b0], [%[b_ptr]]\n" + "movi v10.4s, #0x0\n" + "ldr %q[a1], [%[a_ptr], #16]\n" + "movi v11.4s, #0x0\n" + "ldr %q[b1], [%[b_ptr], #16]\n" + "movi v12.4s, #0x0\n" ASM_PREFETCH("[%[b_ptr], #64]") "movi v13.4s, #0x0\n" ASM_PREFETCH("[%[a_ptr], #64]") "movi v14.4s, #0x0\n" ASM_PREFETCH("[%[b_ptr], #128]") "movi v15.4s, #0x0\n" + ASM_PREFETCH("[%[a_ptr], #128]") "movi v16.4s, #0x0\n" ASM_PREFETCH("[%[b_ptr], #192]") "movi v17.4s, #0x0\n" ASM_PREFETCH("[%[b_ptr], #256]") + "movi v18.4s, #0x0\n" + "movi v19.4s, #0x0\n" ASM_PREFETCH("[%[a_ptr], #192]") + "movi v20.4s, #0x0\n" + "movi v21.4s, #0x0\n" ASM_PREFETCH("[%[b_ptr], #320]") + "movi v22.4s, #0x0\n" + "movi v23.4s, #0x0\n" ASM_PREFETCH("[%[a_ptr], #256]") + "movi v24.4s, #0x0\n" + "movi v25.4s, #0x0\n" ASM_PREFETCH("[%[b_ptr], #384]") + "movi v26.4s, #0x0\n" + "movi v27.4s, #0x0\n" ASM_PREFETCH("[%[b_ptr], #448]") + "movi v28.4s, #0x0\n" + "movi v29.4s, #0x0\n" ASM_PREFETCH("[%[a_ptr], #384]") + "movi v30.4s, #0x0\n" + "movi v31.4s, #0x0\n" ASM_PREFETCH("[%[b_ptr], #512]") + + // The loop is offset by these two instructions which must + // always be executed. + "sdot v8.4s , %[b0].16b, %[a0].4b[0]\n" + "ldr %d[b2], [%[b_ptr], #32]\n" + + // Skip loop if we are doing zero iterations of it. + "cbz %w[k], 4f\n" + + "1:\n" + "sdot v9.4s , %[b0].16b, %[a0].4b[1]\n" + "ldr x20, [%[b_ptr], #40]\n" + "sdot v10.4s, %[b0].16b, %[a0].4b[2]\n" + "subs %w[k], %w[k], #1\n" + "sdot v11.4s, %[b0].16b, %[a0].4b[3]\n" + "ldr %d[a0a], [%[a_ptr], #32]\n" + + "sdot v12.4s, %[b0].16b, %[a1].4b[0]\n" + "ins %[b2].d[1], x20\n" + "sdot v13.4s, %[b0].16b, %[a1].4b[1]\n" + "ldr x20, [%[a_ptr], #40]\n" + "sdot v14.4s, %[b0].16b, %[a1].4b[2]\n" + "sdot v15.4s, %[b0].16b, %[a1].4b[3]\n" + "ldr %d[a1a], [%[a_ptr], #48]\n" + + "sdot v16.4s, %[b1].16b, %[a0].4b[0]\n" + "ins %[a0a].d[1], x20\n" + "sdot v17.4s, %[b1].16b, %[a0].4b[1]\n" + "ldr x20, [%[a_ptr], #56]\n" + "sdot v18.4s, %[b1].16b, %[a0].4b[2]\n" + "sdot v19.4s, %[b1].16b, %[a0].4b[3]\n" + "ldr %d[b0], [%[b_ptr], #48]\n" + + "sdot v20.4s, %[b1].16b, %[a1].4b[0]\n" + "ins %[a1a].d[1], x20\n" + "sdot v21.4s, %[b1].16b, %[a1].4b[1]\n" + "ldr x20, [%[b_ptr], #56]\n" + "sdot v22.4s, %[b1].16b, %[a1].4b[2]\n" + "sdot v23.4s, %[b1].16b, %[a1].4b[3]\n" + "ldr %d[b1], [%[b_ptr], #64]\n" + + "sdot v24.4s, %[b2].16b, %[a0].4b[0]\n" + "ins %[b0].d[1], x20\n" + "sdot v25.4s, %[b2].16b, %[a0].4b[1]\n" + "ldr x20, [%[b_ptr], #72]\n" + "sdot v26.4s, %[b2].16b, %[a0].4b[2]\n" + "sdot v27.4s, %[b2].16b, %[a0].4b[3]\n" ASM_PREFETCH("[%[a_ptr], #448]") + + "sdot v28.4s, %[b2].16b, %[a1].4b[0]\n" + "sdot v29.4s, %[b2].16b, %[a1].4b[1]\n" ASM_PREFETCH("[%[b_ptr], #576]") + "sdot v30.4s, %[b2].16b, %[a1].4b[2]\n" + "sdot v31.4s, %[b2].16b, %[a1].4b[3]\n" + + // Unroll 1 + "ldr %d[b2], [%[b_ptr], #80]\n" + + "sdot v8.4s , %[b0].16b, %[a0a].4b[0]\n" + "ins %[b1].d[1], x20\n" + "sdot v9.4s , %[b0].16b, %[a0a].4b[1]\n" + "ldr x20, [%[b_ptr], #88]\n" + "sdot v10.4s, %[b0].16b, %[a0a].4b[2]\n" + "sdot v11.4s, %[b0].16b, %[a0a].4b[3]\n" + "ldr %d[a0], [%[a_ptr], #64]\n" + + "sdot v12.4s, %[b0].16b, %[a1a].4b[0]\n" + "ins %[b2].d[1], x20\n" + "sdot v13.4s, %[b0].16b, %[a1a].4b[1]\n" + "ldr x20, [%[a_ptr], #72]\n" + "sdot v14.4s, %[b0].16b, %[a1a].4b[2]\n" + "sdot v15.4s, %[b0].16b, %[a1a].4b[3]\n" + "ldr %d[a1], [%[a_ptr], #80]\n" + + "sdot v16.4s, %[b1].16b, %[a0a].4b[0]\n" + "ins %[a0].d[1], x20\n" + "sdot v17.4s, %[b1].16b, %[a0a].4b[1]\n" + "ldr x20, [%[a_ptr], #88]\n" + "sdot v18.4s, %[b1].16b, %[a0a].4b[2]\n" + "sdot v19.4s, %[b1].16b, %[a0a].4b[3]\n" + "ldr %d[b0], [%[b_ptr], #96]\n" + + "sdot v20.4s, %[b1].16b, %[a1a].4b[0]\n" + "ins %[a1].d[1], x20\n" + "sdot v21.4s, %[b1].16b, %[a1a].4b[1]\n" + "ldr x20, [%[b_ptr], #104]\n" + "sdot v22.4s, %[b1].16b, %[a1a].4b[2]\n" + "sdot v23.4s, %[b1].16b, %[a1a].4b[3]\n" + "ldr %d[b1], [%[b_ptr], #112]\n" + + "sdot v24.4s, %[b2].16b, %[a0a].4b[0]\n" + "ins %[b0].d[1], x20\n" + "sdot v25.4s, %[b2].16b, %[a0a].4b[1]\n" + "ldr x20, [%[b_ptr], #120]\n" + "sdot v26.4s, %[b2].16b, %[a0a].4b[2]\n" + "sdot v27.4s, %[b2].16b, %[a0a].4b[3]\n" + "add %[a_ptr], %[a_ptr], #64\n" + + "sdot v28.4s, %[b2].16b, %[a1a].4b[0]\n" ASM_PREFETCH("[%[b_ptr], #640]") + "sdot v29.4s, %[b2].16b, %[a1a].4b[1]\n" + "add %[b_ptr], %[b_ptr], #96\n" + "sdot v30.4s, %[b2].16b, %[a1a].4b[2]\n" + "ins %[b1].d[1], x20\n" + "sdot v31.4s, %[b2].16b, %[a1a].4b[3]\n" + "ldr %d[b2], [%[b_ptr], #32]\n" + + "sdot v8.4s , %[b0].16b, %[a0].4b[0]\n" + "b.ne 1b\n" + + // Branch here if K=1 or 2. Do the right thing for odd/even at the end. + "4:\n" + + // Start final iteration - branch off to "odd" code before we load a0a. + "sdot v9.4s , %[b0].16b, %[a0].4b[1]\n" + "ldr x20, [%[b_ptr], #40]\n" + "sdot v10.4s, %[b0].16b, %[a0].4b[2]\n" + "cbnz %w[oddk], 2f\n" + + // Even K continuation + "sdot v11.4s, %[b0].16b, %[a0].4b[3]\n" + "ldr %d[a0a], [%[a_ptr], #32]\n" + + "sdot v12.4s, %[b0].16b, %[a1].4b[0]\n" + "ins %[b2].d[1], x20\n" + "sdot v13.4s, %[b0].16b, %[a1].4b[1]\n" + "ldr x20, [%[a_ptr], #40]\n" + "sdot v14.4s, %[b0].16b, %[a1].4b[2]\n" ASM_PREFETCHW("[%[c_ptr]]") + "sdot v15.4s, %[b0].16b, %[a1].4b[3]\n" + "ldr %d[a1a], [%[a_ptr], #48]\n" + + "sdot v16.4s, %[b1].16b, %[a0].4b[0]\n" + "ins %[a0a].d[1], x20\n" + "sdot v17.4s, %[b1].16b, %[a0].4b[1]\n" + "ldr x20, [%[a_ptr], #56]\n" + "sdot v18.4s, %[b1].16b, %[a0].4b[2]\n" + "sdot v19.4s, %[b1].16b, %[a0].4b[3]\n" + "ldr %d[b0], [%[b_ptr], #48]\n" + + "sdot v20.4s, %[b1].16b, %[a1].4b[0]\n" + "ins %[a1a].d[1], x20\n" + "sdot v21.4s, %[b1].16b, %[a1].4b[1]\n" + "ldr x20, [%[b_ptr], #56]\n" + "sdot v22.4s, %[b1].16b, %[a1].4b[2]\n" ASM_PREFETCHW("[%[c_ptr], #64]") + "sdot v23.4s, %[b1].16b, %[a1].4b[3]\n" + + "sdot v24.4s, %[b2].16b, %[a0].4b[0]\n" + "sdot v25.4s, %[b2].16b, %[a0].4b[1]\n" ASM_PREFETCHW("[%[c_ptr], #128]") + "sdot v26.4s, %[b2].16b, %[a0].4b[2]\n" + "sdot v27.4s, %[b2].16b, %[a0].4b[3]\n" + "ldr %d[b1], [%[b_ptr], #64]\n" + + "sdot v28.4s, %[b2].16b, %[a1].4b[0]\n" + "ins %[b0].d[1], x20\n" + "sdot v29.4s, %[b2].16b, %[a1].4b[1]\n" + "ldr x20, [%[b_ptr], #72]\n" + "sdot v30.4s, %[b2].16b, %[a1].4b[2]\n" ASM_PREFETCHW("[%[c_ptr], #192]") + "sdot v31.4s, %[b2].16b, %[a1].4b[3]\n" + "ldr %d[b2], [%[b_ptr], #80]\n" + + "sdot v8.4s , %[b0].16b, %[a0a].4b[0]\n" + "ins %[b1].d[1], x20\n" + "sdot v9.4s , %[b0].16b, %[a0a].4b[1]\n" + "ldr x20, [%[b_ptr], #88]\n" + "sdot v10.4s, %[b0].16b, %[a0a].4b[2]\n" + "ins %[b2].d[1], x20\n" + + "sdot v11.4s, %[b0].16b, %[a0a].4b[3]\n" ASM_PREFETCHW("[%[c_ptr], #256]") + "sdot v12.4s, %[b0].16b, %[a1a].4b[0]\n" + "sdot v13.4s, %[b0].16b, %[a1a].4b[1]\n" + "sdot v14.4s, %[b0].16b, %[a1a].4b[2]\n" ASM_PREFETCHW("[%[c_ptr], #320]") + "sdot v15.4s, %[b0].16b, %[a1a].4b[3]\n" + "sdot v16.4s, %[b1].16b, %[a0a].4b[0]\n" ASM_PREFETCHWL2("[%[c_ptr], #384]") + "sdot v17.4s, %[b1].16b, %[a0a].4b[1]\n" + "sdot v18.4s, %[b1].16b, %[a0a].4b[2]\n" ASM_PREFETCHWL2("[%[c_ptr], #448]") + "sdot v19.4s, %[b1].16b, %[a0a].4b[3]\n" + "sdot v20.4s, %[b1].16b, %[a1a].4b[0]\n" + "sdot v21.4s, %[b1].16b, %[a1a].4b[1]\n" ASM_PREFETCHWL2("[%[c_ptr], #512]") + "sdot v22.4s, %[b1].16b, %[a1a].4b[2]\n" + "sdot v23.4s, %[b1].16b, %[a1a].4b[3]\n" ASM_PREFETCHWL2("[%[c_ptr], #576]") + "sdot v24.4s, %[b2].16b, %[a0a].4b[0]\n" + "sdot v25.4s, %[b2].16b, %[a0a].4b[1]\n" + "sdot v26.4s, %[b2].16b, %[a0a].4b[2]\n" ASM_PREFETCHWL2("[%[c_ptr], #640]") + "sdot v27.4s, %[b2].16b, %[a0a].4b[3]\n" + "sdot v28.4s, %[b2].16b, %[a1a].4b[0]\n" ASM_PREFETCHWL2("[%[c_ptr], #704]") + "sdot v29.4s, %[b2].16b, %[a1a].4b[1]\n" + "add %[a_ptr], %[a_ptr], #64\n" + "sdot v30.4s, %[b2].16b, %[a1a].4b[2]\n" + "add %[b_ptr], %[b_ptr], #96\n" + "sdot v31.4s, %[b2].16b, %[a1a].4b[3]\n" + "b 3f\n" + + // Odd K continuation + "2:\n" + "sdot v11.4s, %[b0].16b, %[a0].4b[3]\n" ASM_PREFETCHW("[%[c_ptr]]") + "sdot v12.4s, %[b0].16b, %[a1].4b[0]\n" + "ins %[b2].d[1], x20\n" + "sdot v13.4s, %[b0].16b, %[a1].4b[1]\n" ASM_PREFETCHW("[%[c_ptr], #64]") + "sdot v14.4s, %[b0].16b, %[a1].4b[2]\n" + "add %[a_ptr], %[a_ptr], #32\n" + "sdot v15.4s, %[b0].16b, %[a1].4b[3]\n" ASM_PREFETCHW("[%[c_ptr], #128]") + "sdot v16.4s, %[b1].16b, %[a0].4b[0]\n" + "add %[b_ptr], %[b_ptr], #48\n" + "sdot v17.4s, %[b1].16b, %[a0].4b[1]\n" ASM_PREFETCHW("[%[c_ptr], #192]") + "sdot v18.4s, %[b1].16b, %[a0].4b[2]\n" + "sdot v19.4s, %[b1].16b, %[a0].4b[3]\n" ASM_PREFETCHW("[%[c_ptr], #256]") + "sdot v20.4s, %[b1].16b, %[a1].4b[0]\n" + "sdot v21.4s, %[b1].16b, %[a1].4b[1]\n" ASM_PREFETCHW("[%[c_ptr], #320]") + "sdot v22.4s, %[b1].16b, %[a1].4b[2]\n" + "sdot v23.4s, %[b1].16b, %[a1].4b[3]\n" ASM_PREFETCHWL2("[%[c_ptr], #384]") + "sdot v24.4s, %[b2].16b, %[a0].4b[0]\n" + "sdot v25.4s, %[b2].16b, %[a0].4b[1]\n" ASM_PREFETCHWL2("[%[c_ptr], #448]") + "sdot v26.4s, %[b2].16b, %[a0].4b[2]\n" + "sdot v27.4s, %[b2].16b, %[a0].4b[3]\n" ASM_PREFETCHWL2("[%[c_ptr], #512]") "sdot v28.4s, %[b2].16b, %[a1].4b[0]\n" ASM_PREFETCHWL2("[%[c_ptr], #576]") "sdot v29.4s, %[b2].16b, %[a1].4b[1]\n" + ASM_PREFETCHWL2("[%[c_ptr], #640]") "sdot v30.4s, %[b2].16b, %[a1].4b[2]\n" ASM_PREFETCHWL2("[%[c_ptr], #704]") + "sdot v31.4s, %[b2].16b, %[a1].4b[3]\n" + + // Common tail + "3:\n" + "str q8, [%[c_ptr]]\n" + "str q16, [%[c_ptr], #16]\n" + "str q24, [%[c_ptr], #32]\n" + "str q9, [%[c_ptr], #48]\n" + "str q17, [%[c_ptr], #64]\n" + "str q25, [%[c_ptr], #80]\n" + "str q10, [%[c_ptr], #96]\n" + "str q18, [%[c_ptr], #112]\n" + "str q26, [%[c_ptr], #128]\n" + "str q11, [%[c_ptr], #144]\n" + "str q19, [%[c_ptr], #160]\n" + "str q27, [%[c_ptr], #176]\n" + "str q12, [%[c_ptr], #192]\n" + "str q20, [%[c_ptr], #208]\n" + "str q28, [%[c_ptr], #224]\n" + "str q13, [%[c_ptr], #240]\n" + "str q21, [%[c_ptr], #256]\n" + "str q29, [%[c_ptr], #272]\n" + "str q14, [%[c_ptr], #288]\n" + "str q22, [%[c_ptr], #304]\n" + "str q30, [%[c_ptr], #320]\n" + "str q15, [%[c_ptr], #336]\n" + "str q23, [%[c_ptr], #352]\n" + "str q31, [%[c_ptr], #368]\n" + "add %[c_ptr], %[c_ptr], #384\n" + +#ifdef NO_DOT_IN_TOOLCHAIN + ".purgem sdot\n" +#endif + : + [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [c_ptr] "+r"(c_ptr), + [a0] "+w"(a0), [a1] "+w"(a1), [a0a] "+w"(a0a), [a1a] "+w"(a1a), + [b0] "+w"(b0), [b1] "+w"(b1), [b2] "+w"(b2), [k] "+r"(k) + : [oddk] "r"(oddk) + : "x20", "x21", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", + "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "cc", "memory"); + } + } +} + +} // namespace arm_gemm + +#endif
\ No newline at end of file diff --git a/src/core/NEON/kernels/arm_gemm/kernels/a64_gemm_s8_12x8/dot_toolchain_support.h b/src/core/NEON/kernels/arm_gemm/kernels/a64_gemm_s8_12x8/dot_toolchain_support.h new file mode 100644 index 0000000000..c76f99d776 --- /dev/null +++ b/src/core/NEON/kernels/arm_gemm/kernels/a64_gemm_s8_12x8/dot_toolchain_support.h @@ -0,0 +1,66 @@ +/* + * Copyright (c) 2017-2018 ARM Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ + +// Define a macro to assemble the UDOT instruction (in the absence of toolchain support) +#define _DECLARE_SDOT \ + ".altmacro\n" \ + ".macro sdot opd:req, opn:req, opm:req\n" \ + "local vd, vn, vm, h, l\n" \ + ".irp reg,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31\n" \ + ".ifeqs \"\\opd\",\"v\\reg\\.4s\"\n" \ + ".set vd,\\reg\n" \ + ".endif\n" \ + ".ifeqs \"\\opn\",\"v\\reg\\.16b\"\n" \ + ".set vn,\\reg\n" \ + ".endif\n" \ + ".irp idx,0,1,2,3\n" \ + ".ifeqs \"\\opm\",\"v\\reg\\.4b[\\idx\\]\"\n" \ + ".set vm,\\reg\n" \ + ".set h,\\idx / 2\n" \ + ".set l,\\idx %% 2\n" \ + ".endif\n" \ + ".endr\n" \ + ".endr\n" \ + ".ifndef vd\n" \ + ".error \"Bad operand \\opd\"\n" \ + ".exitm\n" \ + ".endif\n" \ + ".ifndef vn\n" \ + ".error \"Bad operand \\opn\"\n" \ + ".exitm\n" \ + ".endif\n" \ + ".ifndef vm\n" \ + ".error \"Bad operand \\opm\"\n" \ + ".exitm\n" \ + ".endif\n" \ + ".ifndef h\n" \ + ".error \"Bad operand \\opm\"\n" \ + ".exitm\n" \ + ".endif\n" \ + ".ifndef l\n" \ + ".error \"Bad operand \\opm\"\n" \ + ".exitm\n" \ + ".endif\n" \ + ".int 0x4f80e000 | vd | (vn << 5) | (vm << 16) | (l << 21) | (h << 11)\n" \ + ".endm\n" diff --git a/src/core/NEON/kernels/arm_gemm/kernels/a64_gemm_s8_12x8/generic.cpp b/src/core/NEON/kernels/arm_gemm/kernels/a64_gemm_s8_12x8/generic.cpp new file mode 100644 index 0000000000..258ef5e224 --- /dev/null +++ b/src/core/NEON/kernels/arm_gemm/kernels/a64_gemm_s8_12x8/generic.cpp @@ -0,0 +1,343 @@ +/* + * Copyright (c) 2017-2018 ARM Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ +#ifdef __aarch64__ + +#include <arm_neon.h> + +#include "../../asmlib.hpp" + +#ifdef NO_DOT_IN_TOOLCHAIN +#include "dot_toolchain_support.h" +#endif + +namespace arm_gemm +{ +void a64_gemm_s8_12x8(const int8_t *Apanel, const int8_t *Bpanel, int32_t *Cpanel, int ablocks, int bblocks, int K) +{ + const int8_t *a_ptr = Apanel; + int32_t *c_ptr = Cpanel; + // We divide K by 4 because the sdot instruction processes 4 elements at a time. + const int W = K / 4; + // Fix up for odd lengths - set a flag if K is odd, but make + // sure we round up the iteration count. + const int oddk = (W & 1); + const int init_value_k = ((W + 1) / 2) - 1; + for(int yb = 0; yb < ablocks; yb++) + { + const int8_t *a_ptr0 = a_ptr; + const int8_t *b_ptr = Bpanel; + for(int xb = 0; xb < bblocks; xb++) + { + a_ptr = a_ptr0; + int k = init_value_k; + register int32x4_t a0 asm("v0"); + register int32x4_t a1 asm("v1"); + register int32x4_t b0 asm("v2"); + register int32x4_t b1 asm("v3"); + register int32x4_t b2 asm("v4"); + register int32x4_t a0a asm("v5"); + register int32x4_t a1a asm("v6"); + __asm __volatile( +#ifdef NO_DOT_IN_TOOLCHAIN + _DECLARE_SDOT +#else + ".arch armv8.2-a+dotprod\n" +#endif + // Initialize result registers, load initial operands, prime prefetches. + "movi v8.4s, #0x0\n" + "ldr %q[a0], [%[a_ptr]]\n" + "movi v9.4s, #0x0\n" + "ldr %q[b0], [%[b_ptr]]\n" + "movi v10.4s, #0x0\n" + "ldr %q[a1], [%[a_ptr], #16]\n" + "movi v11.4s, #0x0\n" + "ldr %q[b1], [%[b_ptr], #16]\n" + "movi v12.4s, #0x0\n" ASM_PREFETCH("[%[b_ptr], #64]") "movi v13.4s, #0x0\n" ASM_PREFETCH("[%[a_ptr], #64]") "movi v14.4s, #0x0\n" ASM_PREFETCH("[%[b_ptr], #128]") "movi v15.4s, #0x0\n" + ASM_PREFETCH("[%[a_ptr], #128]") "movi v16.4s, #0x0\n" ASM_PREFETCH("[%[b_ptr], #192]") "movi v17.4s, #0x0\n" ASM_PREFETCH("[%[b_ptr], #256]") "movi v18.4s, #0x0\n" + ASM_PREFETCH("[%[a_ptr], #192]") "movi v19.4s, #0x0\n" ASM_PREFETCH("[%[b_ptr], #320]") "movi v20.4s, #0x0\n" ASM_PREFETCH("[%[a_ptr], #256]") "movi v21.4s, #0x0\n" + ASM_PREFETCH("[%[b_ptr], #384]") + "movi v22.4s, #0x0\n" + "movi v23.4s, #0x0\n" + "movi v24.4s, #0x0\n" + "movi v25.4s, #0x0\n" + "movi v26.4s, #0x0\n" + "movi v27.4s, #0x0\n" + "movi v28.4s, #0x0\n" + "movi v29.4s, #0x0\n" + "movi v30.4s, #0x0\n" + "movi v31.4s, #0x0\n" + + // Skip loop if we are doing zero iterations of it. + "cbz %w[k], 4f\n" + + // Loop proper + "1:\n" + "sdot v8.4s , %[b0].16b, %[a0].4b[0]\n" + "sdot v9.4s , %[b0].16b, %[a0].4b[1]\n" + + "ldr %q[b2], [%[b_ptr], #32]\n" + "sdot v10.4s, %[b0].16b, %[a0].4b[2]\n" + "sdot v11.4s, %[b0].16b, %[a0].4b[3]\n" + "ldr %q[a0a], [%[a_ptr], #32]\n" + "sdot v12.4s, %[b0].16b, %[a1].4b[0]\n" + "sdot v13.4s, %[b0].16b, %[a1].4b[1]\n" + "ldr %q[a1a], [%[a_ptr], #48]\n" + "sdot v14.4s, %[b0].16b, %[a1].4b[2]\n" + "sdot v15.4s, %[b0].16b, %[a1].4b[3]\n" + "ldr %q[b0], [%[b_ptr], #48]\n" + + "sdot v16.4s, %[b1].16b, %[a0].4b[0]\n" + "sdot v17.4s, %[b1].16b, %[a0].4b[1]\n" ASM_PREFETCH("[%[a_ptr], #320]") + "sdot v18.4s, %[b1].16b, %[a0].4b[2]\n" + "sdot v19.4s, %[b1].16b, %[a0].4b[3]\n" + "sdot v20.4s, %[b1].16b, %[a1].4b[0]\n" + "sdot v21.4s, %[b1].16b, %[a1].4b[1]\n" + "sdot v22.4s, %[b1].16b, %[a1].4b[2]\n" + "sdot v23.4s, %[b1].16b, %[a1].4b[3]\n" + "ldr %q[b1], [%[b_ptr], #64]\n" + + "sdot v24.4s, %[b2].16b, %[a0].4b[0]\n" + "sdot v25.4s, %[b2].16b, %[a0].4b[1]\n" ASM_PREFETCH("[%[b_ptr], #448]") + "sdot v26.4s, %[b2].16b, %[a0].4b[2]\n" + "sdot v27.4s, %[b2].16b, %[a0].4b[3]\n" + "sdot v28.4s, %[b2].16b, %[a1].4b[0]\n" + "sdot v29.4s, %[b2].16b, %[a1].4b[1]\n" + "sdot v30.4s, %[b2].16b, %[a1].4b[2]\n" + "sdot v31.4s, %[b2].16b, %[a1].4b[3]\n" + "ldr %q[b2], [%[b_ptr], #80]\n" + + "sdot v8.4s , %[b0].16b, %[a0a].4b[0]\n" + "sdot v9.4s , %[b0].16b, %[a0a].4b[1]\n" + "ldr %q[a0], [%[a_ptr], #64]\n" + "sdot v10.4s, %[b0].16b, %[a0a].4b[2]\n" + "sdot v11.4s, %[b0].16b, %[a0a].4b[3]\n" + "sdot v12.4s, %[b0].16b, %[a1a].4b[0]\n" + "ldr %q[a1], [%[a_ptr], #80]\n" + "sdot v13.4s, %[b0].16b, %[a1a].4b[1]\n" + "sdot v14.4s, %[b0].16b, %[a1a].4b[2]\n" + "sdot v15.4s, %[b0].16b, %[a1a].4b[3]\n" + "ldr %q[b0], [%[b_ptr], #96]\n" + + "sdot v16.4s, %[b1].16b, %[a0a].4b[0]\n" + "sdot v17.4s, %[b1].16b, %[a0a].4b[1]\n" ASM_PREFETCH("[%[b_ptr], #512]") + "sdot v18.4s, %[b1].16b, %[a0a].4b[2]\n" + "sdot v19.4s, %[b1].16b, %[a0a].4b[3]\n" + "sdot v20.4s, %[b1].16b, %[a1a].4b[0]\n" + "sdot v21.4s, %[b1].16b, %[a1a].4b[1]\n" + "sdot v22.4s, %[b1].16b, %[a1a].4b[2]\n" + "sdot v23.4s, %[b1].16b, %[a1a].4b[3]\n" + "ldr %q[b1], [%[b_ptr], #112]\n" + + "sdot v24.4s, %[b2].16b, %[a0a].4b[0]\n" + "sdot v25.4s, %[b2].16b, %[a0a].4b[1]\n" + "add %[a_ptr], %[a_ptr], #64\n" + "sdot v26.4s, %[b2].16b, %[a0a].4b[2]\n" + "sdot v27.4s, %[b2].16b, %[a0a].4b[3]\n" + "add %[b_ptr], %[b_ptr], #96\n" + "sdot v28.4s, %[b2].16b, %[a1a].4b[0]\n" + "sdot v29.4s, %[b2].16b, %[a1a].4b[1]\n" + "subs %w[k], %w[k], #1\n" + "sdot v30.4s, %[b2].16b, %[a1a].4b[2]\n" + "sdot v31.4s, %[b2].16b, %[a1a].4b[3]\n" + "bne 1b\n" + + // Target to use when K is 1 or 2 (i.e. zero iterations of main loop) + "4:\n" + + // Branch to alternative tail for odd K + "cbnz %w[oddk], 2f\n" + + // Detached final iteration (even K) + "sdot v8.4s , %[b0].16b, %[a0].4b[0]\n" + "sdot v9.4s , %[b0].16b, %[a0].4b[1]\n" + "ldr %q[b2], [%[b_ptr], #32]\n" + "sdot v10.4s, %[b0].16b, %[a0].4b[2]\n" + "sdot v11.4s, %[b0].16b, %[a0].4b[3]\n" + "ldr %q[a0a], [%[a_ptr], #32]\n" + "sdot v12.4s, %[b0].16b, %[a1].4b[0]\n" + "sdot v13.4s, %[b0].16b, %[a1].4b[1]\n" + "ldr %q[a1a], [%[a_ptr], #48]\n" + "sdot v14.4s, %[b0].16b, %[a1].4b[2]\n" + "sdot v15.4s, %[b0].16b, %[a1].4b[3]\n" + "ldr %q[b0], [%[b_ptr], #48]\n" + + "sdot v16.4s, %[b1].16b, %[a0].4b[0]\n" + "sdot v17.4s, %[b1].16b, %[a0].4b[1]\n" + "sdot v18.4s, %[b1].16b, %[a0].4b[2]\n" + "sdot v19.4s, %[b1].16b, %[a0].4b[3]\n" + "sdot v20.4s, %[b1].16b, %[a1].4b[0]\n" + "sdot v21.4s, %[b1].16b, %[a1].4b[1]\n" + "sdot v22.4s, %[b1].16b, %[a1].4b[2]\n" + "sdot v23.4s, %[b1].16b, %[a1].4b[3]\n" + "ldr %q[b1], [%[b_ptr], #64]\n" + + "sdot v24.4s, %[b2].16b, %[a0].4b[0]\n" + "sdot v25.4s, %[b2].16b, %[a0].4b[1]\n" + "add %[a_ptr], %[a_ptr], #64\n" + "sdot v26.4s, %[b2].16b, %[a0].4b[2]\n" + "sdot v27.4s, %[b2].16b, %[a0].4b[3]\n" + "sdot v28.4s, %[b2].16b, %[a1].4b[0]\n" + "sdot v29.4s, %[b2].16b, %[a1].4b[1]\n" + "sdot v30.4s, %[b2].16b, %[a1].4b[2]\n" + "sdot v31.4s, %[b2].16b, %[a1].4b[3]\n" + "ldr %q[b2], [%[b_ptr], #80]\n" + + "sdot v8.4s , %[b0].16b, %[a0a].4b[0]\n" + + "sdot v16.4s, %[b1].16b, %[a0a].4b[0]\n" + "add %[b_ptr], %[b_ptr], #96\n" + "sdot v9.4s , %[b0].16b, %[a0a].4b[1]\n" + "str q8, [%[c_ptr], #0]\n" + "sdot v17.4s, %[b1].16b, %[a0a].4b[1]\n" + "str q16, [%[c_ptr], #16]\n" + "sdot v24.4s, %[b2].16b, %[a0a].4b[0]\n" + "str q24, [%[c_ptr], #32]\n" + + "sdot v25.4s, %[b2].16b, %[a0a].4b[1]\n" + "str q9, [%[c_ptr], #48]\n" + "sdot v10.4s, %[b0].16b, %[a0a].4b[2]\n" + "str q17, [%[c_ptr], #64]\n" + "sdot v18.4s, %[b1].16b, %[a0a].4b[2]\n" + "str q25, [%[c_ptr], #80]\n" + "sdot v26.4s, %[b2].16b, %[a0a].4b[2]\n" + "str q10, [%[c_ptr], #96]\n" + + "sdot v11.4s, %[b0].16b, %[a0a].4b[3]\n" + "str q18, [%[c_ptr], #112]\n" + "sdot v19.4s, %[b1].16b, %[a0a].4b[3]\n" + "str q26, [%[c_ptr], #128]\n" + "sdot v27.4s, %[b2].16b, %[a0a].4b[3]\n" + "str q11, [%[c_ptr], #144]\n" + + "sdot v12.4s, %[b0].16b, %[a1a].4b[0]\n" + "str q19, [%[c_ptr], #160]\n" + "sdot v20.4s, %[b1].16b, %[a1a].4b[0]\n" + "str q27, [%[c_ptr], #176]\n" + "sdot v28.4s, %[b2].16b, %[a1a].4b[0]\n" + "str q12, [%[c_ptr], #192]\n" + + "sdot v13.4s, %[b0].16b, %[a1a].4b[1]\n" + "str q20, [%[c_ptr], #208]\n" + "sdot v21.4s, %[b1].16b, %[a1a].4b[1]\n" + "str q28, [%[c_ptr], #224]\n" + "sdot v29.4s, %[b2].16b, %[a1a].4b[1]\n" + "str q13, [%[c_ptr], #240]\n" + + "sdot v14.4s, %[b0].16b, %[a1a].4b[2]\n" + "str q21, [%[c_ptr], #256]\n" + "sdot v22.4s, %[b1].16b, %[a1a].4b[2]\n" + "str q29, [%[c_ptr], #272]\n" + "sdot v30.4s, %[b2].16b, %[a1a].4b[2]\n" + "str q14, [%[c_ptr], #288]\n" + + "sdot v15.4s, %[b0].16b, %[a1a].4b[3]\n" + "str q22, [%[c_ptr], #304]\n" + "sdot v23.4s, %[b1].16b, %[a1a].4b[3]\n" + "str q30, [%[c_ptr], #320]\n" + "sdot v31.4s, %[b2].16b, %[a1a].4b[3]\n" + "str q15, [%[c_ptr], #336]\n" + + "b 3f\n" + + // Detached final iteration (odd K) + "2:\n" + "sdot v8.4s , %[b0].16b, %[a0].4b[0]\n" + "ldr %q[b2], [%[b_ptr], #32]\n" + "sdot v16.4s, %[b1].16b, %[a0].4b[0]\n" + "sdot v9.4s , %[b0].16b, %[a0].4b[1]\n" + "str q8, [%[c_ptr], #0]\n" + "sdot v17.4s, %[b1].16b, %[a0].4b[1]\n" + "str q16, [%[c_ptr], #16]\n" + "sdot v24.4s, %[b2].16b, %[a0].4b[0]\n" + "add %[b_ptr], %[b_ptr], #48\n" + "add %[a_ptr], %[a_ptr], #32\n" + "str q24, [%[c_ptr], #32]\n" + "sdot v25.4s, %[b2].16b, %[a0].4b[1]\n" + "str q9, [%[c_ptr], #48]\n" + + "sdot v10.4s, %[b0].16b, %[a0].4b[2]\n" + "str q17, [%[c_ptr], #64]\n" + "sdot v18.4s, %[b1].16b, %[a0].4b[2]\n" + "str q25, [%[c_ptr], #80]\n" + "sdot v26.4s, %[b2].16b, %[a0].4b[2]\n" + "str q10, [%[c_ptr], #96]\n" + + "sdot v11.4s, %[b0].16b, %[a0].4b[3]\n" + "str q18, [%[c_ptr], #112]\n" + "sdot v19.4s, %[b1].16b, %[a0].4b[3]\n" + "str q26, [%[c_ptr], #128]\n" + "sdot v27.4s, %[b2].16b, %[a0].4b[3]\n" + "str q11, [%[c_ptr], #144]\n" + + "sdot v12.4s, %[b0].16b, %[a1].4b[0]\n" + "str q19, [%[c_ptr], #160]\n" + "sdot v20.4s, %[b1].16b, %[a1].4b[0]\n" + "str q27, [%[c_ptr], #176]\n" + "sdot v28.4s, %[b2].16b, %[a1].4b[0]\n" + "str q12, [%[c_ptr], #192]\n" + + "sdot v13.4s, %[b0].16b, %[a1].4b[1]\n" + "str q20, [%[c_ptr], #208]\n" + "sdot v21.4s, %[b1].16b, %[a1].4b[1]\n" + "str q28, [%[c_ptr], #224]\n" + "sdot v29.4s, %[b2].16b, %[a1].4b[1]\n" + "str q13, [%[c_ptr], #240]\n" + + "sdot v14.4s, %[b0].16b, %[a1].4b[2]\n" + "str q21, [%[c_ptr], #256]\n" + "sdot v22.4s, %[b1].16b, %[a1].4b[2]\n" + "str q29, [%[c_ptr], #272]\n" + "sdot v30.4s, %[b2].16b, %[a1].4b[2]\n" + "str q14, [%[c_ptr], #288]\n" + + "sdot v15.4s, %[b0].16b, %[a1].4b[3]\n" + "str q22, [%[c_ptr], #304]\n" + "sdot v23.4s, %[b1].16b, %[a1].4b[3]\n" + "str q30, [%[c_ptr], #320]\n" + "sdot v31.4s, %[b2].16b, %[a1].4b[3]\n" + "str q15, [%[c_ptr], #336]\n" + + // Common tail + "3:\n" + "str q23, [%[c_ptr], #352]\n" + "str q31, [%[c_ptr], #368]\n" + "add %[c_ptr], %[c_ptr], #384\n" + +#ifdef NO_DOT_IN_TOOLCHAIN + ".purgem sdot\n" +#endif + : + [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [c_ptr] "+r"(c_ptr), + [a0] "+w"(a0), [a1] "+w"(a1), [a0a] "+w"(a0a), [a1a] "+w"(a1a), + [b0] "+w"(b0), [b1] "+w"(b1), [b2] "+w"(b2), [k] "+r"(k) + : [oddk] "r"(oddk) + : "x20", "x21", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", + "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "cc"); + } + } +} + +} // namespace arm_gemm + +#endif diff --git a/src/core/NEON/kernels/arm_gemm/kernels/a64_gemm_s8_4x4.hpp b/src/core/NEON/kernels/arm_gemm/kernels/a64_gemm_s8_4x4.hpp new file mode 100644 index 0000000000..2ec28f480c --- /dev/null +++ b/src/core/NEON/kernels/arm_gemm/kernels/a64_gemm_s8_4x4.hpp @@ -0,0 +1,67 @@ +/* + * Copyright (c) 2017-2018 ARM Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ +#pragma once + +#ifdef __aarch64__ + +namespace arm_gemm +{ +// Load the actual kernel +void a64_gemm_s8_4x4(const int8_t *, const int8_t *, int32_t *, int, int, int); + +#include "arm_gemm.hpp" + +class gemm_s8_4x4 +{ +public: + typedef int8_t operand_type; + typedef int32_t result_type; + + typedef void (*kern_type)(const int8_t *, const int8_t *, int32_t *, int, int, int); + + /* Describes the data layout for A input */ + static const int A_interleave = 4; + static const int A_block = 16; + static const bool A_transpose = false; + + /* Same for B input */ + static const int B_interleave = 4; + static const int B_block = 16; + static const bool B_transpose = true; + + /* Kernel blocking parameters */ + static const int out_width = 4; + static const int out_height = 4; + static const int k_unroll = 16; + + kern_type kernel = a64_gemm_s8_4x4; + + gemm_s8_4x4(const CPUInfo *ci) + { + } +}; + +} // namespace arm_gemm + +#endif // __aarch64__ diff --git a/src/core/NEON/kernels/arm_gemm/kernels/a64_gemm_s8_4x4/generic.cpp b/src/core/NEON/kernels/arm_gemm/kernels/a64_gemm_s8_4x4/generic.cpp new file mode 100644 index 0000000000..243b94e25b --- /dev/null +++ b/src/core/NEON/kernels/arm_gemm/kernels/a64_gemm_s8_4x4/generic.cpp @@ -0,0 +1,456 @@ +/* + * Copyright (c) 2017-2018 ARM Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ +#ifdef __aarch64__ + +#include <arm_neon.h> + +#include "../../asmlib.hpp" + +namespace arm_gemm +{ +void a64_gemm_s8_4x4(const int8_t *Apanel, const int8_t *Bpanel, int32_t *Cpanel, int ablocks, int bblocks, int K) +{ + const int8_t *a_ptr = Apanel; + int32_t *c_ptr = Cpanel; + + K /= 16; + int oddk = (K & 1); + + for(int yb = 0; yb < ablocks; yb++) + { + const int8_t *a_ptr0 = a_ptr; + const int8_t *b_ptr = Bpanel; + + for(int xb = 0; xb < bblocks; xb++) + { + a_ptr = a_ptr0; + + int k = ((K + 1) / 2) - 1; + + register int8x16_t b0 asm("v4"); + register int8x16_t b1 asm("v5"); + register int8x16_t b2 asm("v6"); + register int8x16_t b3 asm("v7"); + register int8x16_t b0a asm("v8"); + register int8x16_t b1a asm("v9"); + register int8x16_t b2a asm("v10"); + register int8x16_t b3a asm("v11"); + + __asm __volatile( + "movi v16.4s, #0x0\n" + "ldr q0, [%[a_ptr]]\n" + "movi v17.4s, #0x0\n" + "ldr %q[b0], [%[b_ptr]]\n" + "movi v18.4s, #0x0\n" + "ldr %q[b1], [%[b_ptr], #16]\n" + "movi v19.4s, #0x0\n" + "ldr %q[b2], [%[b_ptr], #32]\n" + "movi v20.4s, #0x0\n" + "ldr %q[b3], [%[b_ptr], #48]\n" + "movi v21.4s, #0x0\n" + "ldr q1, [%[a_ptr], #16]\n" + "movi v22.4s, #0x0\n" + "ldr q2, [%[a_ptr], #32]\n" + "movi v23.4s, #0x0\n" + "ldr q3, [%[a_ptr], #48]\n" + "movi v24.4s, #0x0\n" ASM_PREFETCH("[%[b_ptr], #64]") "movi v25.4s, #0x0\n" ASM_PREFETCH("[%[a_ptr], #64]") "movi v26.4s, #0x0\n" ASM_PREFETCH("[%[b_ptr], #128]") "movi v27.4s, #0x0\n" + ASM_PREFETCH("[%[a_ptr], #128]") "movi v28.4s, #0x0\n" ASM_PREFETCH("[%[b_ptr], #192]") "movi v29.4s, #0x0\n" ASM_PREFETCH("[%[a_ptr], #192]") "movi v30.4s, #0x0\n" + ASM_PREFETCH("[%[b_ptr], #256]") "movi v31.4s, #0x0\n" ASM_PREFETCH("[%[a_ptr], #256]") + + // Loop structure optimized for A57 (after r0). + + // Unavoidably, the multiply will "dribble" if + // dual issued with an add. + + // Minimize the effect of this by making sure + // there are 2 adds to run under the dribbled + // multiply. + + // Pipeline in blocks of 8 multiplies - combine + // this iteration's multiplies with adds from + // the previous iteration. + + // So the first block doesn't have any adds to + // do - but because all the adds are at the + // start of the block it's only the first couple + // of multiplies that need to be pulled out. + + // Start of unroll 0 (first iteration) + "smull v12.8h, v0.8b, %[b0].8b\n" + "smull v13.8h, v0.8b, %[b1].8b\n" + + // Skip loop if we are doing zero iterations of it. + "cbz %w[k], 4f\n" + + // Unroll 0 continuation (branch target) + "1:\n" + "smull v14.8h, v0.8b, %[b2].8b\n" + "subs %w[k], %w[k], #1\n" + "smull v15.8h, v0.8b, %[b3].8b\n" + "ldr %q[b0a], [%[b_ptr], #64]\n" + "smlal2 v12.8h, v0.16b, %[b0].16b\n" + "smlal2 v13.8h, v0.16b, %[b1].16b\n" + "ldr %q[b1a], [%[b_ptr], #80]\n" + "smlal2 v14.8h, v0.16b, %[b2].16b\n" + "smlal2 v15.8h, v0.16b, %[b3].16b\n" + "ldr q0, [%[a_ptr], #64]\n" + + "sadalp v16.4s, v12.8h\n" + "smull v12.8h, v1.8b, %[b0].8b\n" + "sadalp v17.4s, v13.8h\n" + "sadalp v18.4s, v14.8h\n" + "smull v13.8h, v1.8b, %[b1].8b\n" + "sadalp v19.4s, v15.8h\n" + "smull v14.8h, v1.8b, %[b2].8b\n" + "ldr %q[b2a], [%[b_ptr], #96]\n" + "smull v15.8h, v1.8b, %[b3].8b\n" + "smlal2 v12.8h, v1.16b, %[b0].16b\n" + "ldr %q[b3a], [%[b_ptr], #112]\n" + "smlal2 v13.8h, v1.16b, %[b1].16b\n" + "add %[b_ptr], %[b_ptr], #128\n" + "smlal2 v14.8h, v1.16b, %[b2].16b\n" + "smlal2 v15.8h, v1.16b, %[b3].16b\n" + "ldr q1, [%[a_ptr], #80]\n" + + "sadalp v20.4s, v12.8h\n" + "smull v12.8h, v2.8b, %[b0].8b\n" + "sadalp v21.4s, v13.8h\n" + "sadalp v22.4s, v14.8h\n" + "smull v13.8h, v2.8b, %[b1].8b\n" + "sadalp v23.4s, v15.8h\n" + "smull v14.8h, v2.8b, %[b2].8b\n" + "smull v15.8h, v2.8b, %[b3].8b\n" + "smlal2 v12.8h, v2.16b, %[b0].16b\n" ASM_PREFETCH("[%[b_ptr], #192]") + "smlal2 v13.8h, v2.16b, %[b1].16b\n" + "smlal2 v14.8h, v2.16b, %[b2].16b\n" ASM_PREFETCH("[%[a_ptr], #320]") + "smlal2 v15.8h, v2.16b, %[b3].16b\n" + "ldr q2, [%[a_ptr], #96]\n" + + "sadalp v24.4s, v12.8h\n" + "smull v12.8h, v3.8b, %[b0].8b\n" + "sadalp v25.4s, v13.8h\n" + "sadalp v26.4s, v14.8h\n" + "smull v13.8h, v3.8b, %[b1].8b\n" + "sadalp v27.4s, v15.8h\n" + "smull v14.8h, v3.8b, %[b2].8b\n" + "smull v15.8h, v3.8b, %[b3].8b\n" + "smlal2 v12.8h, v3.16b, %[b0].16b\n" + "ldr %q[b0], [%[b_ptr], #0]\n" + "smlal2 v13.8h, v3.16b, %[b1].16b\n" + "smlal2 v14.8h, v3.16b, %[b2].16b\n" + "smlal2 v15.8h, v3.16b, %[b3].16b\n" + "ldr q3, [%[a_ptr], #112]\n" + + // Unroll 1 + "sadalp v28.4s, v12.8h\n" + "smull v12.8h, v0.8b, %[b0a].8b\n" + "sadalp v29.4s, v13.8h\n" + "sadalp v30.4s, v14.8h\n" + "smull v13.8h, v0.8b, %[b1a].8b\n" + "sadalp v31.4s, v15.8h\n" + "smull v14.8h, v0.8b, %[b2a].8b\n" + "smull v15.8h, v0.8b, %[b3a].8b\n" + "ldr %q[b1], [%[b_ptr], #16]\n" + "smlal2 v12.8h, v0.16b, %[b0a].16b\n" + "smlal2 v13.8h, v0.16b, %[b1a].16b\n" + "ldr %q[b2], [%[b_ptr], #32]\n" + "smlal2 v14.8h, v0.16b, %[b2a].16b\n" + "smlal2 v15.8h, v0.16b, %[b3a].16b\n" + "ldr q0, [%[a_ptr], #128]\n" + + "sadalp v16.4s, v12.8h\n" + "smull v12.8h, v1.8b, %[b0a].8b\n" + "sadalp v17.4s, v13.8h\n" + "sadalp v18.4s, v14.8h\n" + "smull v13.8h, v1.8b, %[b1a].8b\n" + "sadalp v19.4s, v15.8h\n" + "add %[a_ptr], %[a_ptr], #128\n" + "smull v14.8h, v1.8b, %[b2a].8b\n" + "smull v15.8h, v1.8b, %[b3a].8b\n" + "ldr %q[b3], [%[b_ptr], #48]\n" + "smlal2 v12.8h, v1.16b, %[b0a].16b\n" + "smlal2 v13.8h, v1.16b, %[b1a].16b\n" + "smlal2 v14.8h, v1.16b, %[b2a].16b\n" + "smlal2 v15.8h, v1.16b, %[b3a].16b\n" + "ldr q1, [%[a_ptr], #16]\n" + + "sadalp v20.4s, v12.8h\n" + "smull v12.8h, v2.8b, %[b0a].8b\n" + "sadalp v21.4s, v13.8h\n" + "sadalp v22.4s, v14.8h\n" + "smull v13.8h, v2.8b, %[b1a].8b\n" + "sadalp v23.4s, v15.8h\n" + "smull v14.8h, v2.8b, %[b2a].8b\n" + "smull v15.8h, v2.8b, %[b3a].8b\n" + "smlal2 v12.8h, v2.16b, %[b0a].16b\n" ASM_PREFETCH("[%[b_ptr], #256]") + "smlal2 v13.8h, v2.16b, %[b1a].16b\n" + "smlal2 v14.8h, v2.16b, %[b2a].16b\n" ASM_PREFETCH("[%[a_ptr], #256]") + "smlal2 v15.8h, v2.16b, %[b3a].16b\n" + "ldr q2, [%[a_ptr], #32]\n" + + "sadalp v24.4s, v12.8h\n" + "smull v12.8h, v3.8b, %[b0a].8b\n" + "sadalp v25.4s, v13.8h\n" + "sadalp v26.4s, v14.8h\n" + "smull v13.8h, v3.8b, %[b1a].8b\n" + "sadalp v27.4s, v15.8h\n" + "smull v14.8h, v3.8b, %[b2a].8b\n" + "smull v15.8h, v3.8b, %[b3a].8b\n" + "smlal2 v12.8h, v3.16b, %[b0a].16b\n" + "smlal2 v13.8h, v3.16b, %[b1a].16b\n" + "smlal2 v14.8h, v3.16b, %[b2a].16b\n" + "smlal2 v15.8h, v3.16b, %[b3a].16b\n" + "ldr q3, [%[a_ptr], #48]\n" + + // Start of unroll 0 for next iteration. + "sadalp v28.4s, v12.8h\n" + "smull v12.8h, v0.8b, %[b0].8b\n" + "sadalp v29.4s, v13.8h\n" + "sadalp v30.4s, v14.8h\n" + "smull v13.8h, v0.8b, %[b1].8b\n" + "sadalp v31.4s, v15.8h\n" + "bne 1b\n" + + // Target to use when K=1 or 2 (i.e. zero iterations of main loop) + "4:\n" + + // Branch to alternative tail for odd K + "cbnz %w[oddk], 2f\n" + + // Detached final iteration (even K) + "smull v14.8h, v0.8b, %[b2].8b\n" + "smull v15.8h, v0.8b, %[b3].8b\n" + "ldr %q[b0a], [%[b_ptr], #64]\n" + "smlal2 v12.8h, v0.16b, %[b0].16b\n" + "smlal2 v13.8h, v0.16b, %[b1].16b\n" + "ldr %q[b1a], [%[b_ptr], #80]\n" + "smlal2 v14.8h, v0.16b, %[b2].16b\n" + "smlal2 v15.8h, v0.16b, %[b3].16b\n" + "ldr q0, [%[a_ptr], #64]\n" + + "sadalp v16.4s, v12.8h\n" + "smull v12.8h, v1.8b, %[b0].8b\n" + "sadalp v17.4s, v13.8h\n" + "sadalp v18.4s, v14.8h\n" + "smull v13.8h, v1.8b, %[b1].8b\n" + "sadalp v19.4s, v15.8h\n" + "smull v14.8h, v1.8b, %[b2].8b\n" + "ldr %q[b2a], [%[b_ptr], #96]\n" + "smull v15.8h, v1.8b, %[b3].8b\n" + "smlal2 v12.8h, v1.16b, %[b0].16b\n" + "ldr %q[b3a], [%[b_ptr], #112]\n" + "smlal2 v13.8h, v1.16b, %[b1].16b\n" + "add %[b_ptr], %[b_ptr], #128\n" + "smlal2 v14.8h, v1.16b, %[b2].16b\n" + "smlal2 v15.8h, v1.16b, %[b3].16b\n" + "ldr q1, [%[a_ptr], #80]\n" + + "sadalp v20.4s, v12.8h\n" + "smull v12.8h, v2.8b, %[b0].8b\n" + "sadalp v21.4s, v13.8h\n" + "sadalp v22.4s, v14.8h\n" + "smull v13.8h, v2.8b, %[b1].8b\n" + "sadalp v23.4s, v15.8h\n" + "smull v14.8h, v2.8b, %[b2].8b\n" + "smull v15.8h, v2.8b, %[b3].8b\n" + "smlal2 v12.8h, v2.16b, %[b0].16b\n" + "smlal2 v13.8h, v2.16b, %[b1].16b\n" + "smlal2 v14.8h, v2.16b, %[b2].16b\n" + "smlal2 v15.8h, v2.16b, %[b3].16b\n" + "ldr q2, [%[a_ptr], #96]\n" + + "sadalp v24.4s, v12.8h\n" + "smull v12.8h, v3.8b, %[b0].8b\n" + "sadalp v25.4s, v13.8h\n" + "sadalp v26.4s, v14.8h\n" + "smull v13.8h, v3.8b, %[b1].8b\n" + "sadalp v27.4s, v15.8h\n" + "smull v14.8h, v3.8b, %[b2].8b\n" + "smull v15.8h, v3.8b, %[b3].8b\n" + "smlal2 v12.8h, v3.16b, %[b0].16b\n" + "smlal2 v13.8h, v3.16b, %[b1].16b\n" + "smlal2 v14.8h, v3.16b, %[b2].16b\n" + "smlal2 v15.8h, v3.16b, %[b3].16b\n" + "ldr q3, [%[a_ptr], #112]\n" + + // Unroll 1 + "sadalp v28.4s, v12.8h\n" + "smull v12.8h, v0.8b, %[b0a].8b\n" + "sadalp v29.4s, v13.8h\n" + "sadalp v30.4s, v14.8h\n" + "smull v13.8h, v0.8b, %[b1a].8b\n" + "sadalp v31.4s, v15.8h\n" + "smull v14.8h, v0.8b, %[b2a].8b\n" + "add %[a_ptr], %[a_ptr], #128\n" + "smull v15.8h, v0.8b, %[b3a].8b\n" + "smlal2 v12.8h, v0.16b, %[b0a].16b\n" + "smlal2 v13.8h, v0.16b, %[b1a].16b\n" + "smlal2 v14.8h, v0.16b, %[b2a].16b\n" + "smlal2 v15.8h, v0.16b, %[b3a].16b\n" + + "sadalp v16.4s, v12.8h\n" + "smull v12.8h, v1.8b, %[b0a].8b\n" + "sadalp v17.4s, v13.8h\n" + "sadalp v18.4s, v14.8h\n" + "smull v13.8h, v1.8b, %[b1a].8b\n" + "sadalp v19.4s, v15.8h\n" + "smull v14.8h, v1.8b, %[b2a].8b\n" + "smull v15.8h, v1.8b, %[b3a].8b\n" + "smlal2 v12.8h, v1.16b, %[b0a].16b\n" + "addp v16.4s, v16.4s, v17.4s\n" + "smlal2 v13.8h, v1.16b, %[b1a].16b\n" + "addp v17.4s, v18.4s, v19.4s\n" + "smlal2 v14.8h, v1.16b, %[b2a].16b\n" + "smlal2 v15.8h, v1.16b, %[b3a].16b\n" + + "sadalp v20.4s, v12.8h\n" + "smull v12.8h, v2.8b, %[b0a].8b\n" + "sadalp v21.4s, v13.8h\n" + "sadalp v22.4s, v14.8h\n" + "smull v13.8h, v2.8b, %[b1a].8b\n" + "sadalp v23.4s, v15.8h\n" + "addp v16.4s, v16.4s, v17.4s\n" + "smull v14.8h, v2.8b, %[b2a].8b\n" + "addp v18.4s, v20.4s, v21.4s\n" + "addp v19.4s, v22.4s, v23.4s\n" + "smull v15.8h, v2.8b, %[b3a].8b\n" + "smlal2 v12.8h, v2.16b, %[b0a].16b\n" + "str q16, [%[c_ptr]]\n" + "smlal2 v13.8h, v2.16b, %[b1a].16b\n" + "smlal2 v14.8h, v2.16b, %[b2a].16b\n" + "smlal2 v15.8h, v2.16b, %[b3a].16b\n" + + "sadalp v24.4s, v12.8h\n" + "smull v12.8h, v3.8b, %[b0a].8b\n" + "sadalp v25.4s, v13.8h\n" + "sadalp v26.4s, v14.8h\n" + "smull v13.8h, v3.8b, %[b1a].8b\n" + "sadalp v27.4s, v15.8h\n" + "addp v17.4s, v18.4s, v19.4s\n" + "smull v14.8h, v3.8b, %[b2a].8b\n" + "addp v20.4s, v24.4s, v25.4s\n" + "addp v21.4s, v26.4s, v27.4s\n" + "smull v15.8h, v3.8b, %[b3a].8b\n" + "smlal2 v12.8h, v3.16b, %[b0a].16b\n" + "str q17, [%[c_ptr], #16]\n" + "smlal2 v13.8h, v3.16b, %[b1a].16b\n" + "smlal2 v14.8h, v3.16b, %[b2a].16b\n" + "addp v18.4s, v20.4s, v21.4s\n" + "smlal2 v15.8h, v3.16b, %[b3a].16b\n" + "b 3f\n" + + // Detached final iteration (odd K) + "2:\n" + "smull v14.8h, v0.8b, %[b2].8b\n" + "add %[a_ptr], %[a_ptr], #64\n" + "smull v15.8h, v0.8b, %[b3].8b\n" + "add %[b_ptr], %[b_ptr], #64\n" + "smlal2 v12.8h, v0.16b, %[b0].16b\n" + "smlal2 v13.8h, v0.16b, %[b1].16b\n" + "smlal2 v14.8h, v0.16b, %[b2].16b\n" + "smlal2 v15.8h, v0.16b, %[b3].16b\n" + + "sadalp v16.4s, v12.8h\n" + "smull v12.8h, v1.8b, %[b0].8b\n" + "sadalp v17.4s, v13.8h\n" + "sadalp v18.4s, v14.8h\n" + "smull v13.8h, v1.8b, %[b1].8b\n" + "sadalp v19.4s, v15.8h\n" + "smull v14.8h, v1.8b, %[b2].8b\n" + "smull v15.8h, v1.8b, %[b3].8b\n" + "smlal2 v12.8h, v1.16b, %[b0].16b\n" + "addp v16.4s, v16.4s, v17.4s\n" + "smlal2 v13.8h, v1.16b, %[b1].16b\n" + "addp v17.4s, v18.4s, v19.4s\n" + "smlal2 v14.8h, v1.16b, %[b2].16b\n" + "smlal2 v15.8h, v1.16b, %[b3].16b\n" + + "sadalp v20.4s, v12.8h\n" + "smull v12.8h, v2.8b, %[b0].8b\n" + "sadalp v21.4s, v13.8h\n" + "sadalp v22.4s, v14.8h\n" + "smull v13.8h, v2.8b, %[b1].8b\n" + "sadalp v23.4s, v15.8h\n" + "addp v16.4s, v16.4s, v17.4s\n" + "smull v14.8h, v2.8b, %[b2].8b\n" + "addp v18.4s, v20.4s, v21.4s\n" + "addp v19.4s, v22.4s, v23.4s\n" + "smull v15.8h, v2.8b, %[b3].8b\n" + "smlal2 v12.8h, v2.16b, %[b0].16b\n" + "str q16, [%[c_ptr]]\n" + "smlal2 v13.8h, v2.16b, %[b1].16b\n" + "smlal2 v14.8h, v2.16b, %[b2].16b\n" + "smlal2 v15.8h, v2.16b, %[b3].16b\n" + + "sadalp v24.4s, v12.8h\n" + "smull v12.8h, v3.8b, %[b0].8b\n" + "sadalp v25.4s, v13.8h\n" + "sadalp v26.4s, v14.8h\n" + "smull v13.8h, v3.8b, %[b1].8b\n" + "sadalp v27.4s, v15.8h\n" + "addp v17.4s, v18.4s, v19.4s\n" + "smull v14.8h, v3.8b, %[b2].8b\n" + "addp v20.4s, v24.4s, v25.4s\n" + "addp v21.4s, v26.4s, v27.4s\n" + "smull v15.8h, v3.8b, %[b3].8b\n" + "smlal2 v12.8h, v3.16b, %[b0].16b\n" + "str q17, [%[c_ptr], #16]\n" + "smlal2 v13.8h, v3.16b, %[b1].16b\n" + "smlal2 v14.8h, v3.16b, %[b2].16b\n" + "addp v18.4s, v20.4s, v21.4s\n" + "smlal2 v15.8h, v3.16b, %[b3].16b\n" + + "3:\n" + + // Final additions + "sadalp v28.4s, v12.8h\n" + "str q18, [%[c_ptr], #32]\n" + "sadalp v29.4s, v13.8h\n" + "sadalp v30.4s, v14.8h\n" + "sadalp v31.4s, v15.8h\n" + + // Horizontal reduction, phase 1 + "addp v22.4s, v28.4s, v29.4s\n" + "addp v23.4s, v30.4s, v31.4s\n" + + // Horizontal reduction, phase 2 + "addp v19.4s, v22.4s, v23.4s\n" + "str q19, [%[c_ptr], #48]\n" + "add %[c_ptr], %[c_ptr], #64\n" + + : + [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [c_ptr] "+r"(c_ptr), + [b0] "+w"(b0), [b1] "+w"(b1), [b2] "+w"(b2), [b3] "+w"(b3), + [b0a] "+w"(b0a), [b1a] "+w"(b1a), [b2a] "+w"(b2a), [b3a] "+w"(b3a), + [k] "+r"(k) + : [oddk] "r"(oddk) + : "x20", "x21", "v0", "v1", "v2", "v3", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", + "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "cc"); + } + } +} + +} // namespace arm_gemm + +#endif // __aarch64__ diff --git a/src/core/NEON/kernels/arm_gemm/kernels/a64_gemm_u16_12x8.hpp b/src/core/NEON/kernels/arm_gemm/kernels/a64_gemm_u16_12x8.hpp new file mode 100644 index 0000000000..39757326f4 --- /dev/null +++ b/src/core/NEON/kernels/arm_gemm/kernels/a64_gemm_u16_12x8.hpp @@ -0,0 +1,73 @@ +/* + * Copyright (c) 2017-2018 ARM Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ +#pragma once + +#ifdef __aarch64__ + +namespace arm_gemm +{ +// Actual kernel implementations +void a64_gemm_u16_asimd_12x8(const uint16_t *, const uint16_t *, uint32_t *, int, int, int); + +// 12x8 SGEMM "strategy" class. +// +// This describes the characteristics of a family of kernels, in terms of +// the required interleave properties and the output block size. +// +// All kernels in the family must share these characteristics. The actual +// kernel to be used can be chosen at runtime, based on the CPU_type +// structure. +class gemm_u16_12x8 +{ +public: + typedef uint16_t operand_type; + typedef uint32_t result_type; + + typedef void (*kern_type)(const uint16_t *, const uint16_t *, uint32_t *, int, int, int); + + /* Describes the data layout for A input */ + static const int A_interleave = 8; + static const int A_block = 1; + static const int A_transpose = 0; + + /* Same for B input */ + static const int B_interleave = 12; + static const int B_block = 1; + static const int B_transpose = 1; + + /* Kernel blocking parameters */ + static const int out_width = 12; + static const int out_height = 8; + static const int k_unroll = 1; + + kern_type kernel = a64_gemm_u16_asimd_12x8; + + gemm_u16_12x8(const CPUInfo *ci) + { + } +}; + +} // namespace arm_gemm + +#endif // __aarch64__ diff --git a/src/core/NEON/kernels/arm_gemm/kernels/a64_gemm_u16_12x8/generic.cpp b/src/core/NEON/kernels/arm_gemm/kernels/a64_gemm_u16_12x8/generic.cpp new file mode 100644 index 0000000000..7903878301 --- /dev/null +++ b/src/core/NEON/kernels/arm_gemm/kernels/a64_gemm_u16_12x8/generic.cpp @@ -0,0 +1,309 @@ +/* + * Copyright (c) 2017-2018 ARM Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ +#ifdef __aarch64__ + +#include <arm_neon.h> + +#include "../../asmlib.hpp" + +namespace arm_gemm +{ +void a64_gemm_u16_asimd_12x8(const uint16_t *Apanel, const uint16_t *Bpanel, uint32_t *Cpanel, int ablocks, int bblocks, int K) +{ + const uint16_t *a_ptr = Apanel; + uint32_t *c_ptr = Cpanel; + + for(int yb = 0; yb < ablocks; yb++) + { + const uint16_t *a_ptr0 = a_ptr; + const uint16_t *b_ptr = Bpanel; + + for(int xb = 0; xb < bblocks; xb++) + { + a_ptr = a_ptr0; + const bool odd_k = K & 0x1; + int k = (K + 1) / 2 - 1; + + register uint16x8_t aa asm("v0"); + register uint16x8_t ab asm("v1"); + register uint16x8_t b0 asm("v2"); + register uint16x8_t b1 asm("v3"); + register uint16x8_t b2 asm("v4"); + + __asm __volatile( + "ldr %d[aa], [%x[a_ptr]]\n" // Load A[A].lower + "movi v5.4s, #0\n" + "ldr x20, [%x[a_ptr], #0x08]\n" // Load A[A].upper + "movi v6.4s, #0\n" + "ldr %d[b0], [%x[b_ptr]]\n" // Load B[0].lower + "ins %[aa].d[1], x20\n" // Merge A[A].lower and upper + "movi v7.4s, #0\n" ASM_PREFETCH("[%[a_ptr], #64]") + "movi v8.4s, #0\n" + "ldr x20, [%x[b_ptr], #0x08]\n" // Load B[0].upper + "movi v9.4s, #0\n" ASM_PREFETCH("[%[b_ptr], #64]") + "movi v10.4s, #0\n" + "ldr %d[b1], [%x[b_ptr], #0x10]\n" // Load B[1].lower + "ins %[b0].d[1], x20\n" // Merge B[0].lower and upper + "movi v11.4s, #0\n" ASM_PREFETCH("[%[a_ptr], #96]") + "movi v12.4s, #0\n" + "movi v13.4s, #0\n" ASM_PREFETCH("[%[b_ptr], #96]") + "movi v14.4s, #0\n" + "movi v15.4s, #0\n" ASM_PREFETCH("[%[a_ptr], #128]") + "movi v16.4s, #0\n" + "movi v17.4s, #0\n" ASM_PREFETCH("[%[b_ptr], #128]") + "movi v18.4s, #0\n" + "movi v19.4s, #0\n" ASM_PREFETCH("[%[a_ptr], #160]") + "movi v20.4s, #0\n" + "movi v21.4s, #0\n" ASM_PREFETCH("[%[b_ptr], #160]") + "movi v22.4s, #0\n" + "movi v23.4s, #0\n" ASM_PREFETCH("[%[a_ptr], #192]") + "movi v24.4s, #0\n" + "add %x[a_ptr], %x[a_ptr], #0x10\n" + "movi v25.4s, #0\n" ASM_PREFETCH("[%[b_ptr], #192]") + "movi v26.4s, #0\n" + "add %x[b_ptr], %x[b_ptr], #0x18\n" + "movi v27.4s, #0\n" + "movi v28.4s, #0\n" + + "cbz %x[k], 2f\n" // Skip the loop if doing zero iterations. + + "1:\n" // Main loop + // First unroll + "smlal v5.4s, %[b0].4h, %[aa].h[0]\n" + "ldr x20, [%x[b_ptr]]\n" // Load B[1].upper + "umlal v6.4s, %[b0].4h, %[aa].h[1]\n" + "umlal v7.4s, %[b0].4h, %[aa].h[2]\n" + "ldr %d[ab], [%x[a_ptr]]\n" // Load A[B].lower + "ins %[b1].d[1], x20\n" // Merge B[1].lower and .upper + "umlal v8.4s, %[b0].4h, %[aa].h[3]\n" + "umlal v9.4s, %[b0].4h, %[aa].h[4]\n" + "ldr x20, [%x[a_ptr], #0x8]\n" // Load A[B].upper + "umlal v10.4s, %[b0].4h, %[aa].h[5]\n" + "umlal v11.4s, %[b0].4h, %[aa].h[6]\n" + "ldr %d[b2], [%x[b_ptr], #0x8]\n" // Load B[2].lower + "ins %[ab].d[1], x20\n" // Merge A[B].lower and .upper + "umlal v12.4s, %[b0].4h, %[aa].h[7]\n" + "umlal2 v13.4s, %[b0].8h, %[aa].h[0]\n" + "ldr x20, [%x[b_ptr], #0x10]\n" // Load B[2].upper + "umlal2 v14.4s, %[b0].8h, %[aa].h[1]\n" + "umlal2 v15.4s, %[b0].8h, %[aa].h[2]\n" + "umlal2 v16.4s, %[b0].8h, %[aa].h[3]\n" + "umlal2 v17.4s, %[b0].8h, %[aa].h[4]\n" + "umlal2 v18.4s, %[b0].8h, %[aa].h[5]\n" + "umlal2 v19.4s, %[b0].8h, %[aa].h[6]\n" + "umlal2 v20.4s, %[b0].8h, %[aa].h[7]\n" + "ldr %d[b0], [%x[b_ptr], #0x18]\n" // Load B[0].lower + "ins %[b2].d[1], x20\n" // Merge B[2].lower and .upper + "umlal v21.4s, %[b1].4h, %[aa].h[0]\n" + "umlal v22.4s, %[b1].4h, %[aa].h[1]\n" + "ldr x20, [%x[b_ptr], #0x20]\n" // Load B[0].upper + "umlal v23.4s, %[b1].4h, %[aa].h[2]\n" + "umlal v24.4s, %[b1].4h, %[aa].h[3]\n" + "umlal v25.4s, %[b1].4h, %[aa].h[4]\n" + "umlal v26.4s, %[b1].4h, %[aa].h[5]\n" + "umlal v27.4s, %[b1].4h, %[aa].h[6]\n" + "umlal v28.4s, %[b1].4h, %[aa].h[7]\n" + + // Second unroll + "umlal2 v5.4s, %[b1].8h, %[ab].h[0]\n" + "ldr %d[aa], [%x[a_ptr], #0x10]\n" // Load A[A].lower + "ins %[b0].d[1], x20\n" // Merge B[0].lower and .upper + "umlal2 v6.4s, %[b1].8h, %[ab].h[1]\n" + "umlal2 v7.4s, %[b1].8h, %[ab].h[2]\n" + "ldr x20, [%x[a_ptr], #0x18]\n" // Load A[A].upper + "umlal2 v8.4s, %[b1].8h, %[ab].h[3]\n" + "umlal2 v9.4s, %[b1].8h, %[ab].h[4]\n" + "umlal2 v10.4s, %[b1].8h, %[ab].h[5]\n" + "umlal2 v11.4s, %[b1].8h, %[ab].h[6]\n" + "add %x[a_ptr], %x[a_ptr], #0x20\n" + "umlal2 v12.4s, %[b1].8h, %[ab].h[7]\n" + "umlal v13.4s, %[b2].4h, %[ab].h[0]\n" ASM_PREFETCH("[%[b_ptr], #320]") + "umlal v14.4s, %[b2].4h, %[ab].h[1]\n" + "umlal v15.4s, %[b2].4h, %[ab].h[2]\n" ASM_PREFETCH("[%[a_ptr], #320]") + "umlal v16.4s, %[b2].4h, %[ab].h[3]\n" + "umlal v17.4s, %[b2].4h, %[ab].h[4]\n" ASM_PREFETCH("[%[b_ptr], #448]") + "umlal v18.4s, %[b2].4h, %[ab].h[5]\n" + "umlal v19.4s, %[b2].4h, %[ab].h[6]\n" + "umlal v20.4s, %[b2].4h, %[ab].h[7]\n" + "umlal2 v21.4s, %[b2].8h, %[ab].h[0]\n" + "umlal2 v22.4s, %[b2].8h, %[ab].h[1]\n" + "subs %x[k], %x[k], #0x1\n" + "umlal2 v23.4s, %[b2].8h, %[ab].h[2]\n" + "umlal2 v24.4s, %[b2].8h, %[ab].h[3]\n" + "ldr %d[b1], [%x[b_ptr], #0x28]\n" // Load B[1].lower + "ins %[aa].d[1], x20\n" // Merge A[A].lower and .upper + "umlal2 v25.4s, %[b2].8h, %[ab].h[4]\n" + "umlal2 v26.4s, %[b2].8h, %[ab].h[5]\n" + "add %x[b_ptr], %x[b_ptr], #0x30\n" + "umlal2 v27.4s, %[b2].8h, %[ab].h[6]\n" + "umlal2 v28.4s, %[b2].8h, %[ab].h[7]\n" + "bne 1b\n" + + "2:\n" // Even tail + "cbnz %x[odd_k], 3f\n" + + "umlal v5.4s, %[b0].4h, %[aa].h[0]\n" + "ldr x20, [%x[b_ptr]]\n" // Load B[1].upper + "umlal v6.4s, %[b0].4h, %[aa].h[1]\n" + "umlal v7.4s, %[b0].4h, %[aa].h[2]\n" + "ldr %d[ab], [%x[a_ptr]]\n" // Load A[B].lower + "ins %[b1].d[1], x20\n" // Merge B[1].lower and .upper + "umlal v8.4s, %[b0].4h, %[aa].h[3]\n" + "umlal v9.4s, %[b0].4h, %[aa].h[4]\n" + "ldr x20, [%x[a_ptr], #0x8]\n" // Load A[B].upper + "umlal v10.4s, %[b0].4h, %[aa].h[5]\n" + "umlal v11.4s, %[b0].4h, %[aa].h[6]\n" + "ldr %d[b2], [%x[b_ptr], #0x8]\n" // Load B[2].lower + "ins %[ab].d[1], x20\n" // Merge A[B].lower and .upper + "umlal v12.4s, %[b0].4h, %[aa].h[7]\n" + "umlal2 v13.4s, %[b0].8h, %[aa].h[0]\n" + "ldr x20, [%x[b_ptr], #0x10]\n" // Load B[2].upper + "umlal2 v14.4s, %[b0].8h, %[aa].h[1]\n" + "umlal2 v15.4s, %[b0].8h, %[aa].h[2]\n" + "umlal2 v16.4s, %[b0].8h, %[aa].h[3]\n" + "add %[a_ptr], %[a_ptr], #0x10\n" + "umlal2 v17.4s, %[b0].8h, %[aa].h[4]\n" + "add %[b_ptr], %[b_ptr], #0x18\n" + "umlal2 v18.4s, %[b0].8h, %[aa].h[5]\n" + "umlal2 v19.4s, %[b0].8h, %[aa].h[6]\n" + "umlal2 v20.4s, %[b0].8h, %[aa].h[7]\n" + "ins %[b2].d[1], x20\n" // Merge B[2].lower and .upper + "umlal v21.4s, %[b1].4h, %[aa].h[0]\n" + "umlal v22.4s, %[b1].4h, %[aa].h[1]\n" + "umlal v23.4s, %[b1].4h, %[aa].h[2]\n" + "umlal v24.4s, %[b1].4h, %[aa].h[3]\n" + "umlal v25.4s, %[b1].4h, %[aa].h[4]\n" + "umlal v26.4s, %[b1].4h, %[aa].h[5]\n" + "umlal v27.4s, %[b1].4h, %[aa].h[6]\n" + "umlal v28.4s, %[b1].4h, %[aa].h[7]\n" + + "umlal2 v5.4s, %[b1].8h, %[ab].h[0]\n" + "umlal v13.4s, %[b2].4h, %[ab].h[0]\n" + "umlal2 v21.4s, %[b2].8h, %[ab].h[0]\n" + "umlal2 v6.4s, %[b1].8h, %[ab].h[1]\n" + "umlal v14.4s, %[b2].4h, %[ab].h[1]\n" + "str q5, [%x[c_ptr]]\n" + "umlal2 v22.4s, %[b2].8h, %[ab].h[1]\n" + "str q13, [%x[c_ptr], #0x10]\n" + "umlal2 v7.4s, %[b1].8h, %[ab].h[2]\n" + "str q21, [%x[c_ptr], #0x20]\n" + "umlal v15.4s, %[b2].4h, %[ab].h[2]\n" + "str q6, [%x[c_ptr], #0x30]\n" + "umlal2 v23.4s, %[b2].8h, %[ab].h[2]\n" + "str q14, [%x[c_ptr], #0x40]\n" + "umlal2 v8.4s, %[b1].8h, %[ab].h[3]\n" + "str q22, [%x[c_ptr], #0x50]\n" + "umlal v16.4s, %[b2].4h, %[ab].h[3]\n" + "str q7, [%x[c_ptr], #0x60]\n" + "umlal2 v24.4s, %[b2].8h, %[ab].h[3]\n" + "str q15, [%x[c_ptr], #0x70]\n" + "umlal2 v9.4s, %[b1].8h, %[ab].h[4]\n" + "str q23, [%x[c_ptr], #0x80]\n" + "umlal v17.4s, %[b2].4h, %[ab].h[4]\n" + "str q8, [%x[c_ptr], #0x90]\n" + "umlal2 v25.4s, %[b2].8h, %[ab].h[4]\n" + "str q16, [%x[c_ptr], #0xa0]\n" + "umlal2 v10.4s, %[b1].8h, %[ab].h[5]\n" + "str q24, [%x[c_ptr], #0xb0]\n" + "umlal v18.4s, %[b2].4h, %[ab].h[5]\n" + "str q9, [%x[c_ptr], #0xc0]\n" + "umlal2 v26.4s, %[b2].8h, %[ab].h[5]\n" + "str q17, [%x[c_ptr], #0xd0]\n" + "umlal2 v11.4s, %[b1].8h, %[ab].h[6]\n" + "str q25, [%x[c_ptr], #0xe0]\n" + "umlal v19.4s, %[b2].4h, %[ab].h[6]\n" + "str q10, [%x[c_ptr], #0xf0]\n" + "umlal2 v27.4s, %[b2].8h, %[ab].h[6]\n" + "str q18, [%x[c_ptr], #0x100]\n" + "umlal2 v12.4s, %[b1].8h, %[ab].h[7]\n" + "str q26, [%x[c_ptr], #0x110]\n" + "umlal v20.4s, %[b2].4h, %[ab].h[7]\n" + "str q11, [%x[c_ptr], #0x120]\n" + "umlal2 v28.4s, %[b2].8h, %[ab].h[7]\n" + "str q19, [%x[c_ptr], #0x130]\n" + "b 4f\n" // Complete write out + + "3:\n" // Odd tail + "umlal v5.4s, %[b0].4h, %[aa].h[0]\n" + "umlal2 v13.4s, %[b0].8h, %[aa].h[0]\n" + "umlal v21.4s, %[b1].4h, %[aa].h[0]\n" + "umlal v6.4s, %[b0].4h, %[aa].h[1]\n" + "umlal2 v14.4s, %[b0].8h, %[aa].h[1]\n" + "umlal v22.4s, %[b1].4h, %[aa].h[1]\n" + "str q5, [%x[c_ptr]]\n" + "umlal v7.4s, %[b0].4h, %[aa].h[2]\n" + "str q13, [%x[c_ptr], #0x10]\n" + "umlal2 v15.4s, %[b0].8h, %[aa].h[2]\n" + "str q21, [%x[c_ptr], #0x20]\n" + "umlal v23.4s, %[b1].4h, %[aa].h[2]\n" + "str q6, [%x[c_ptr], #0x30]\n" + "umlal v8.4s, %[b0].4h, %[aa].h[3]\n" + "str q14, [%x[c_ptr], #0x40]\n" + "umlal2 v16.4s, %[b0].8h, %[aa].h[3]\n" + "str q22, [%x[c_ptr], #0x50]\n" + "umlal v24.4s, %[b1].4h, %[aa].h[3]\n" + "str q7, [%x[c_ptr], #0x60]\n" + "umlal v9.4s, %[b0].4h, %[aa].h[4]\n" + "str q15, [%x[c_ptr], #0x70]\n" + "umlal2 v17.4s, %[b0].8h, %[aa].h[4]\n" + "str q23, [%x[c_ptr], #0x80]\n" + "umlal v25.4s, %[b1].4h, %[aa].h[4]\n" + "str q8, [%x[c_ptr], #0x90]\n" + "umlal v10.4s, %[b0].4h, %[aa].h[5]\n" + "str q16, [%x[c_ptr], #0xa0]\n" + "umlal2 v18.4s, %[b0].8h, %[aa].h[5]\n" + "str q24, [%x[c_ptr], #0xb0]\n" + "umlal v26.4s, %[b1].4h, %[aa].h[5]\n" + "str q9, [%x[c_ptr], #0xc0]\n" + "umlal v11.4s, %[b0].4h, %[aa].h[6]\n" + "str q17, [%x[c_ptr], #0xd0]\n" + "umlal2 v19.4s, %[b0].8h, %[aa].h[6]\n" + "str q25, [%x[c_ptr], #0xe0]\n" + "umlal v27.4s, %[b1].4h, %[aa].h[6]\n" + "str q10, [%x[c_ptr], #0xf0]\n" + "umlal v12.4s, %[b0].4h, %[aa].h[7]\n" + "str q18, [%x[c_ptr], #0x100]\n" + "umlal2 v20.4s, %[b0].8h, %[aa].h[7]\n" + "str q26, [%x[c_ptr], #0x110]\n" + "umlal v28.4s, %[b1].4h, %[aa].h[7]\n" + "str q11, [%x[c_ptr], #0x120]\n" + + "4:\n" // End of function + "str q19, [%x[c_ptr], #0x130]\n" + "str q27, [%x[c_ptr], #0x140]\n" + "str q12, [%x[c_ptr], #0x150]\n" + "str q20, [%x[c_ptr], #0x160]\n" + "str q28, [%x[c_ptr], #0x170]\n" + "add %x[c_ptr], %x[c_ptr], #0x180\n" + : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [c_ptr] "+r"(c_ptr), [k] "+r"(k), + [aa] "+w"(aa), [ab] "+w"(ab), [b0] "+w"(b0), [b1] "+w"(b1), [b2] "+w"(b2) + : [odd_k] "r"(odd_k) + : "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "x20", "cc"); + } + } +} + +} // namespace arm_gemm + +#endif diff --git a/src/core/NEON/kernels/arm_gemm/kernels/a64_gemm_u8_12x8.hpp b/src/core/NEON/kernels/arm_gemm/kernels/a64_gemm_u8_12x8.hpp new file mode 100644 index 0000000000..26255b14bf --- /dev/null +++ b/src/core/NEON/kernels/arm_gemm/kernels/a64_gemm_u8_12x8.hpp @@ -0,0 +1,72 @@ +/* + * Copyright (c) 2017-2018 ARM Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ +#pragma once + +#ifdef __aarch64__ + +#include "arm_gemm.hpp" + +namespace arm_gemm +{ +// Load the actual kernel +void a64_gemm_u8_12x8(const uint8_t *, const uint8_t *, uint32_t *, int, int, int); +void a64_gemm_u8_12x8_a55r1(const uint8_t *, const uint8_t *, uint32_t *, int, int, int); + +class gemm_u8_12x8 +{ +public: + typedef uint8_t operand_type; + typedef uint32_t result_type; + + typedef void (*kern_type)(const uint8_t *, const uint8_t *, uint32_t *, int, int, int); + + /* Describes the data layout for A input */ + static const int A_interleave = 8; + static const int A_block = 4; + static const bool A_transpose = false; + + /* Same for B input */ + static const int B_interleave = 12; + static const int B_block = 4; + static const bool B_transpose = true; + + /* Kernel blocking parameters */ + static const int out_width = 12; + static const int out_height = 8; + static const int k_unroll = 4; + + kern_type kernel = a64_gemm_u8_12x8; + + gemm_u8_12x8(const CPUInfo *ci) + { + if(ci->get_cpu_model() == CPUModel::A55r1) + { + kernel = a64_gemm_u8_12x8_a55r1; + } + } +}; + +} // namespace arm_gemm + +#endif // __aarch64__ diff --git a/src/core/NEON/kernels/arm_gemm/kernels/a64_gemm_u8_12x8/a55r1.cpp b/src/core/NEON/kernels/arm_gemm/kernels/a64_gemm_u8_12x8/a55r1.cpp new file mode 100644 index 0000000000..f8fafbdf84 --- /dev/null +++ b/src/core/NEON/kernels/arm_gemm/kernels/a64_gemm_u8_12x8/a55r1.cpp @@ -0,0 +1,356 @@ +/* + * Copyright (c) 2017-2018 ARM Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ +#ifdef __aarch64__ + +#include <arm_neon.h> + +#include "../../asmlib.hpp" + +#ifdef NO_DOT_IN_TOOLCHAIN +#include "dot_toolchain_support.h" +#endif + +namespace arm_gemm +{ +void a64_gemm_u8_12x8_a55r1(const uint8_t *Apanel, const uint8_t *Bpanel, uint32_t *Cpanel, const int ablocks, const int bblocks, const int K) +{ + const uint8_t *a_ptr = Apanel; + uint32_t *c_ptr = Cpanel; + + // We divide K by 4 because the udot instruction processes 4 elements at a time. + const int W = K / 4; + + // Fix up for odd lengths - set a flag if K is odd, but make + // sure we round up the iteration count. + const int oddk = (W & 1); + const int k_iters = ((W + 1) / 2) - 1; + + for(int yb = 0; yb < ablocks; yb++) + { + const uint8_t *a_ptr0 = a_ptr; + const uint8_t *b_ptr = Bpanel; + + for(int xb = 0; xb < bblocks; xb++) + { + a_ptr = a_ptr0; + int k = k_iters; + + register int32x4_t a0 asm("v0"); + register int32x4_t a1 asm("v1"); + register int32x4_t b0 asm("v2"); + register int32x4_t b1 asm("v3"); + register int32x4_t b2 asm("v4"); + register int32x4_t a0a asm("v5"); + register int32x4_t a1a asm("v6"); + + __asm __volatile( +#ifdef NO_DOT_IN_TOOLCHAIN + _DECLARE_UDOT +#else + ".arch armv8.2-a+dotprod\n" +#endif + // Initialize result registers, load initial operands, prime prefetches. + "movi v8.4s, #0x0\n" + "ldr %q[a0], [%[a_ptr]]\n" + "movi v9.4s, #0x0\n" + "ldr %q[b0], [%[b_ptr]]\n" + "movi v10.4s, #0x0\n" + "ldr %q[a1], [%[a_ptr], #16]\n" + "movi v11.4s, #0x0\n" + "ldr %q[b1], [%[b_ptr], #16]\n" + "movi v12.4s, #0x0\n" ASM_PREFETCH("[%[b_ptr], #64]") "movi v13.4s, #0x0\n" ASM_PREFETCH("[%[a_ptr], #64]") "movi v14.4s, #0x0\n" ASM_PREFETCH("[%[b_ptr], #128]") "movi v15.4s, #0x0\n" + ASM_PREFETCH("[%[a_ptr], #128]") "movi v16.4s, #0x0\n" ASM_PREFETCH("[%[b_ptr], #192]") "movi v17.4s, #0x0\n" ASM_PREFETCH("[%[b_ptr], #256]") + "movi v18.4s, #0x0\n" + "movi v19.4s, #0x0\n" ASM_PREFETCH("[%[a_ptr], #192]") + "movi v20.4s, #0x0\n" + "movi v21.4s, #0x0\n" ASM_PREFETCH("[%[b_ptr], #320]") + "movi v22.4s, #0x0\n" + "movi v23.4s, #0x0\n" ASM_PREFETCH("[%[a_ptr], #256]") + "movi v24.4s, #0x0\n" + "movi v25.4s, #0x0\n" ASM_PREFETCH("[%[b_ptr], #384]") + "movi v26.4s, #0x0\n" + "movi v27.4s, #0x0\n" ASM_PREFETCH("[%[b_ptr], #448]") + "movi v28.4s, #0x0\n" + "movi v29.4s, #0x0\n" ASM_PREFETCH("[%[a_ptr], #384]") + "movi v30.4s, #0x0\n" + "movi v31.4s, #0x0\n" ASM_PREFETCH("[%[b_ptr], #512]") + + // The loop is offset by these two instructions which must + // always be executed. + "udot v8.4s , %[b0].16b, %[a0].4b[0]\n" + "ldr %d[b2], [%[b_ptr], #32]\n" + + // Skip loop if we are doing zero iterations of it. + "cbz %w[k], 4f\n" + + "1:\n" + "udot v9.4s , %[b0].16b, %[a0].4b[1]\n" + "ldr x20, [%[b_ptr], #40]\n" + "udot v10.4s, %[b0].16b, %[a0].4b[2]\n" + "subs %w[k], %w[k], #1\n" + "udot v11.4s, %[b0].16b, %[a0].4b[3]\n" + "ldr %d[a0a], [%[a_ptr], #32]\n" + + "udot v12.4s, %[b0].16b, %[a1].4b[0]\n" + "ins %[b2].d[1], x20\n" + "udot v13.4s, %[b0].16b, %[a1].4b[1]\n" + "ldr x20, [%[a_ptr], #40]\n" + "udot v14.4s, %[b0].16b, %[a1].4b[2]\n" + "udot v15.4s, %[b0].16b, %[a1].4b[3]\n" + "ldr %d[a1a], [%[a_ptr], #48]\n" + + "udot v16.4s, %[b1].16b, %[a0].4b[0]\n" + "ins %[a0a].d[1], x20\n" + "udot v17.4s, %[b1].16b, %[a0].4b[1]\n" + "ldr x20, [%[a_ptr], #56]\n" + "udot v18.4s, %[b1].16b, %[a0].4b[2]\n" + "udot v19.4s, %[b1].16b, %[a0].4b[3]\n" + "ldr %d[b0], [%[b_ptr], #48]\n" + + "udot v20.4s, %[b1].16b, %[a1].4b[0]\n" + "ins %[a1a].d[1], x20\n" + "udot v21.4s, %[b1].16b, %[a1].4b[1]\n" + "ldr x20, [%[b_ptr], #56]\n" + "udot v22.4s, %[b1].16b, %[a1].4b[2]\n" + "udot v23.4s, %[b1].16b, %[a1].4b[3]\n" + "ldr %d[b1], [%[b_ptr], #64]\n" + + "udot v24.4s, %[b2].16b, %[a0].4b[0]\n" + "ins %[b0].d[1], x20\n" + "udot v25.4s, %[b2].16b, %[a0].4b[1]\n" + "ldr x20, [%[b_ptr], #72]\n" + "udot v26.4s, %[b2].16b, %[a0].4b[2]\n" + "udot v27.4s, %[b2].16b, %[a0].4b[3]\n" ASM_PREFETCH("[%[a_ptr], #448]") + + "udot v28.4s, %[b2].16b, %[a1].4b[0]\n" + "udot v29.4s, %[b2].16b, %[a1].4b[1]\n" ASM_PREFETCH("[%[b_ptr], #576]") + "udot v30.4s, %[b2].16b, %[a1].4b[2]\n" + "udot v31.4s, %[b2].16b, %[a1].4b[3]\n" + + // Unroll 1 + "ldr %d[b2], [%[b_ptr], #80]\n" + + "udot v8.4s , %[b0].16b, %[a0a].4b[0]\n" + "ins %[b1].d[1], x20\n" + "udot v9.4s , %[b0].16b, %[a0a].4b[1]\n" + "ldr x20, [%[b_ptr], #88]\n" + "udot v10.4s, %[b0].16b, %[a0a].4b[2]\n" + "udot v11.4s, %[b0].16b, %[a0a].4b[3]\n" + "ldr %d[a0], [%[a_ptr], #64]\n" + + "udot v12.4s, %[b0].16b, %[a1a].4b[0]\n" + "ins %[b2].d[1], x20\n" + "udot v13.4s, %[b0].16b, %[a1a].4b[1]\n" + "ldr x20, [%[a_ptr], #72]\n" + "udot v14.4s, %[b0].16b, %[a1a].4b[2]\n" + "udot v15.4s, %[b0].16b, %[a1a].4b[3]\n" + "ldr %d[a1], [%[a_ptr], #80]\n" + + "udot v16.4s, %[b1].16b, %[a0a].4b[0]\n" + "ins %[a0].d[1], x20\n" + "udot v17.4s, %[b1].16b, %[a0a].4b[1]\n" + "ldr x20, [%[a_ptr], #88]\n" + "udot v18.4s, %[b1].16b, %[a0a].4b[2]\n" + "udot v19.4s, %[b1].16b, %[a0a].4b[3]\n" + "ldr %d[b0], [%[b_ptr], #96]\n" + + "udot v20.4s, %[b1].16b, %[a1a].4b[0]\n" + "ins %[a1].d[1], x20\n" + "udot v21.4s, %[b1].16b, %[a1a].4b[1]\n" + "ldr x20, [%[b_ptr], #104]\n" + "udot v22.4s, %[b1].16b, %[a1a].4b[2]\n" + "udot v23.4s, %[b1].16b, %[a1a].4b[3]\n" + "ldr %d[b1], [%[b_ptr], #112]\n" + + "udot v24.4s, %[b2].16b, %[a0a].4b[0]\n" + "ins %[b0].d[1], x20\n" + "udot v25.4s, %[b2].16b, %[a0a].4b[1]\n" + "ldr x20, [%[b_ptr], #120]\n" + "udot v26.4s, %[b2].16b, %[a0a].4b[2]\n" + "udot v27.4s, %[b2].16b, %[a0a].4b[3]\n" + "add %[a_ptr], %[a_ptr], #64\n" + + "udot v28.4s, %[b2].16b, %[a1a].4b[0]\n" ASM_PREFETCH("[%[b_ptr], #640]") + "udot v29.4s, %[b2].16b, %[a1a].4b[1]\n" + "add %[b_ptr], %[b_ptr], #96\n" + "udot v30.4s, %[b2].16b, %[a1a].4b[2]\n" + "ins %[b1].d[1], x20\n" + "udot v31.4s, %[b2].16b, %[a1a].4b[3]\n" + "ldr %d[b2], [%[b_ptr], #32]\n" + + "udot v8.4s , %[b0].16b, %[a0].4b[0]\n" + "b.ne 1b\n" + + // Branch here if K=1 or 2. Do the right thing for odd/even at the end. + "4:\n" + + // Start final iteration - branch off to "odd" code before we load a0a. + "udot v9.4s , %[b0].16b, %[a0].4b[1]\n" + "ldr x20, [%[b_ptr], #40]\n" + "udot v10.4s, %[b0].16b, %[a0].4b[2]\n" + "cbnz %w[oddk], 2f\n" + + // Even K continuation + "udot v11.4s, %[b0].16b, %[a0].4b[3]\n" + "ldr %d[a0a], [%[a_ptr], #32]\n" + + "udot v12.4s, %[b0].16b, %[a1].4b[0]\n" + "ins %[b2].d[1], x20\n" + "udot v13.4s, %[b0].16b, %[a1].4b[1]\n" + "ldr x20, [%[a_ptr], #40]\n" + "udot v14.4s, %[b0].16b, %[a1].4b[2]\n" ASM_PREFETCHW("[%[c_ptr]]") + "udot v15.4s, %[b0].16b, %[a1].4b[3]\n" + "ldr %d[a1a], [%[a_ptr], #48]\n" + + "udot v16.4s, %[b1].16b, %[a0].4b[0]\n" + "ins %[a0a].d[1], x20\n" + "udot v17.4s, %[b1].16b, %[a0].4b[1]\n" + "ldr x20, [%[a_ptr], #56]\n" + "udot v18.4s, %[b1].16b, %[a0].4b[2]\n" + "udot v19.4s, %[b1].16b, %[a0].4b[3]\n" + "ldr %d[b0], [%[b_ptr], #48]\n" + + "udot v20.4s, %[b1].16b, %[a1].4b[0]\n" + "ins %[a1a].d[1], x20\n" + "udot v21.4s, %[b1].16b, %[a1].4b[1]\n" + "ldr x20, [%[b_ptr], #56]\n" + "udot v22.4s, %[b1].16b, %[a1].4b[2]\n" ASM_PREFETCHW("[%[c_ptr], #64]") + "udot v23.4s, %[b1].16b, %[a1].4b[3]\n" + + "udot v24.4s, %[b2].16b, %[a0].4b[0]\n" + "udot v25.4s, %[b2].16b, %[a0].4b[1]\n" ASM_PREFETCHW("[%[c_ptr], #128]") + "udot v26.4s, %[b2].16b, %[a0].4b[2]\n" + "udot v27.4s, %[b2].16b, %[a0].4b[3]\n" + "ldr %d[b1], [%[b_ptr], #64]\n" + + "udot v28.4s, %[b2].16b, %[a1].4b[0]\n" + "ins %[b0].d[1], x20\n" + "udot v29.4s, %[b2].16b, %[a1].4b[1]\n" + "ldr x20, [%[b_ptr], #72]\n" + "udot v30.4s, %[b2].16b, %[a1].4b[2]\n" ASM_PREFETCHW("[%[c_ptr], #192]") + "udot v31.4s, %[b2].16b, %[a1].4b[3]\n" + "ldr %d[b2], [%[b_ptr], #80]\n" + + "udot v8.4s , %[b0].16b, %[a0a].4b[0]\n" + "ins %[b1].d[1], x20\n" + "udot v9.4s , %[b0].16b, %[a0a].4b[1]\n" + "ldr x20, [%[b_ptr], #88]\n" + "udot v10.4s, %[b0].16b, %[a0a].4b[2]\n" + "ins %[b2].d[1], x20\n" + + "udot v11.4s, %[b0].16b, %[a0a].4b[3]\n" ASM_PREFETCHW("[%[c_ptr], #256]") + "udot v12.4s, %[b0].16b, %[a1a].4b[0]\n" + "udot v13.4s, %[b0].16b, %[a1a].4b[1]\n" + "udot v14.4s, %[b0].16b, %[a1a].4b[2]\n" ASM_PREFETCHW("[%[c_ptr], #320]") + "udot v15.4s, %[b0].16b, %[a1a].4b[3]\n" + "udot v16.4s, %[b1].16b, %[a0a].4b[0]\n" ASM_PREFETCHWL2("[%[c_ptr], #384]") + "udot v17.4s, %[b1].16b, %[a0a].4b[1]\n" + "udot v18.4s, %[b1].16b, %[a0a].4b[2]\n" ASM_PREFETCHWL2("[%[c_ptr], #448]") + "udot v19.4s, %[b1].16b, %[a0a].4b[3]\n" + "udot v20.4s, %[b1].16b, %[a1a].4b[0]\n" + "udot v21.4s, %[b1].16b, %[a1a].4b[1]\n" ASM_PREFETCHWL2("[%[c_ptr], #512]") + "udot v22.4s, %[b1].16b, %[a1a].4b[2]\n" + "udot v23.4s, %[b1].16b, %[a1a].4b[3]\n" ASM_PREFETCHWL2("[%[c_ptr], #576]") + "udot v24.4s, %[b2].16b, %[a0a].4b[0]\n" + "udot v25.4s, %[b2].16b, %[a0a].4b[1]\n" + "udot v26.4s, %[b2].16b, %[a0a].4b[2]\n" ASM_PREFETCHWL2("[%[c_ptr], #640]") + "udot v27.4s, %[b2].16b, %[a0a].4b[3]\n" + "udot v28.4s, %[b2].16b, %[a1a].4b[0]\n" ASM_PREFETCHWL2("[%[c_ptr], #704]") + "udot v29.4s, %[b2].16b, %[a1a].4b[1]\n" + "add %[a_ptr], %[a_ptr], #64\n" + "udot v30.4s, %[b2].16b, %[a1a].4b[2]\n" + "add %[b_ptr], %[b_ptr], #96\n" + "udot v31.4s, %[b2].16b, %[a1a].4b[3]\n" + "b 3f\n" + + // Odd K continuation + "2:\n" + "udot v11.4s, %[b0].16b, %[a0].4b[3]\n" ASM_PREFETCHW("[%[c_ptr]]") + "udot v12.4s, %[b0].16b, %[a1].4b[0]\n" + "ins %[b2].d[1], x20\n" + "udot v13.4s, %[b0].16b, %[a1].4b[1]\n" ASM_PREFETCHW("[%[c_ptr], #64]") + "udot v14.4s, %[b0].16b, %[a1].4b[2]\n" + "add %[a_ptr], %[a_ptr], #32\n" + "udot v15.4s, %[b0].16b, %[a1].4b[3]\n" ASM_PREFETCHW("[%[c_ptr], #128]") + "udot v16.4s, %[b1].16b, %[a0].4b[0]\n" + "add %[b_ptr], %[b_ptr], #48\n" + "udot v17.4s, %[b1].16b, %[a0].4b[1]\n" ASM_PREFETCHW("[%[c_ptr], #192]") + "udot v18.4s, %[b1].16b, %[a0].4b[2]\n" + "udot v19.4s, %[b1].16b, %[a0].4b[3]\n" ASM_PREFETCHW("[%[c_ptr], #256]") + "udot v20.4s, %[b1].16b, %[a1].4b[0]\n" + "udot v21.4s, %[b1].16b, %[a1].4b[1]\n" ASM_PREFETCHW("[%[c_ptr], #320]") + "udot v22.4s, %[b1].16b, %[a1].4b[2]\n" + "udot v23.4s, %[b1].16b, %[a1].4b[3]\n" ASM_PREFETCHWL2("[%[c_ptr], #384]") + "udot v24.4s, %[b2].16b, %[a0].4b[0]\n" + "udot v25.4s, %[b2].16b, %[a0].4b[1]\n" ASM_PREFETCHWL2("[%[c_ptr], #448]") + "udot v26.4s, %[b2].16b, %[a0].4b[2]\n" + "udot v27.4s, %[b2].16b, %[a0].4b[3]\n" ASM_PREFETCHWL2("[%[c_ptr], #512]") "udot v28.4s, %[b2].16b, %[a1].4b[0]\n" ASM_PREFETCHWL2("[%[c_ptr], #576]") "udot v29.4s, %[b2].16b, %[a1].4b[1]\n" + ASM_PREFETCHWL2("[%[c_ptr], #640]") "udot v30.4s, %[b2].16b, %[a1].4b[2]\n" ASM_PREFETCHWL2("[%[c_ptr], #704]") + "udot v31.4s, %[b2].16b, %[a1].4b[3]\n" + + // Common tail + "3:\n" + "str q8, [%[c_ptr]]\n" + "str q16, [%[c_ptr], #16]\n" + "str q24, [%[c_ptr], #32]\n" + "str q9, [%[c_ptr], #48]\n" + "str q17, [%[c_ptr], #64]\n" + "str q25, [%[c_ptr], #80]\n" + "str q10, [%[c_ptr], #96]\n" + "str q18, [%[c_ptr], #112]\n" + "str q26, [%[c_ptr], #128]\n" + "str q11, [%[c_ptr], #144]\n" + "str q19, [%[c_ptr], #160]\n" + "str q27, [%[c_ptr], #176]\n" + "str q12, [%[c_ptr], #192]\n" + "str q20, [%[c_ptr], #208]\n" + "str q28, [%[c_ptr], #224]\n" + "str q13, [%[c_ptr], #240]\n" + "str q21, [%[c_ptr], #256]\n" + "str q29, [%[c_ptr], #272]\n" + "str q14, [%[c_ptr], #288]\n" + "str q22, [%[c_ptr], #304]\n" + "str q30, [%[c_ptr], #320]\n" + "str q15, [%[c_ptr], #336]\n" + "str q23, [%[c_ptr], #352]\n" + "str q31, [%[c_ptr], #368]\n" + "add %[c_ptr], %[c_ptr], #384\n" + +#ifdef NO_DOT_IN_TOOLCHAIN + ".purgem udot\n" +#endif + : + [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [c_ptr] "+r"(c_ptr), + [a0] "+w"(a0), [a1] "+w"(a1), [a0a] "+w"(a0a), [a1a] "+w"(a1a), + [b0] "+w"(b0), [b1] "+w"(b1), [b2] "+w"(b2), [k] "+r"(k) + : [oddk] "r"(oddk) + : "x20", "x21", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", + "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "cc", "memory"); + } + } +} + +} // namespace arm_gemm + +#endif
\ No newline at end of file diff --git a/src/core/NEON/kernels/arm_gemm/kernels/a64_gemm_u8_12x8/dot_toolchain_support.h b/src/core/NEON/kernels/arm_gemm/kernels/a64_gemm_u8_12x8/dot_toolchain_support.h new file mode 100644 index 0000000000..5ee273bd74 --- /dev/null +++ b/src/core/NEON/kernels/arm_gemm/kernels/a64_gemm_u8_12x8/dot_toolchain_support.h @@ -0,0 +1,66 @@ +/* + * Copyright (c) 2017-2018 ARM Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ + +// Define a macro to assemble the UDOT instruction (in the absence of toolchain support) +#define _DECLARE_UDOT \ + ".altmacro\n" \ + ".macro udot opd:req, opn:req, opm:req\n" \ + "local vd, vn, vm, h, l\n" \ + ".irp reg,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31\n" \ + ".ifeqs \"\\opd\",\"v\\reg\\.4s\"\n" \ + ".set vd,\\reg\n" \ + ".endif\n" \ + ".ifeqs \"\\opn\",\"v\\reg\\.16b\"\n" \ + ".set vn,\\reg\n" \ + ".endif\n" \ + ".irp idx,0,1,2,3\n" \ + ".ifeqs \"\\opm\",\"v\\reg\\.4b[\\idx\\]\"\n" \ + ".set vm,\\reg\n" \ + ".set h,\\idx / 2\n" \ + ".set l,\\idx %% 2\n" \ + ".endif\n" \ + ".endr\n" \ + ".endr\n" \ + ".ifndef vd\n" \ + ".error \"Bad operand \\opd\"\n" \ + ".exitm\n" \ + ".endif\n" \ + ".ifndef vn\n" \ + ".error \"Bad operand \\opn\"\n" \ + ".exitm\n" \ + ".endif\n" \ + ".ifndef vm\n" \ + ".error \"Bad operand \\opm\"\n" \ + ".exitm\n" \ + ".endif\n" \ + ".ifndef h\n" \ + ".error \"Bad operand \\opm\"\n" \ + ".exitm\n" \ + ".endif\n" \ + ".ifndef l\n" \ + ".error \"Bad operand \\opm\"\n" \ + ".exitm\n" \ + ".endif\n" \ + ".int 0x6f80e000 | vd | (vn << 5) | (vm << 16) | (l << 21) | (h << 11)\n" \ + ".endm\n" diff --git a/src/core/NEON/kernels/arm_gemm/kernels/a64_gemm_u8_12x8/generic.cpp b/src/core/NEON/kernels/arm_gemm/kernels/a64_gemm_u8_12x8/generic.cpp new file mode 100644 index 0000000000..d026dc54f3 --- /dev/null +++ b/src/core/NEON/kernels/arm_gemm/kernels/a64_gemm_u8_12x8/generic.cpp @@ -0,0 +1,343 @@ +/* + * Copyright (c) 2017-2018 ARM Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ +#ifdef __aarch64__ + +#include <arm_neon.h> + +#include "../../asmlib.hpp" + +#ifdef NO_DOT_IN_TOOLCHAIN +#include "dot_toolchain_support.h" +#endif + +namespace arm_gemm +{ +void a64_gemm_u8_12x8(const uint8_t *Apanel, const uint8_t *Bpanel, uint32_t *Cpanel, int ablocks, int bblocks, int K) +{ + const uint8_t *a_ptr = Apanel; + uint32_t *c_ptr = Cpanel; + // We divide K by 4 because the udot instruction processes 4 elements at a time. + const int W = K / 4; + // Fix up for odd lengths - set a flag if K is odd, but make + // sure we round up the iteration count. + const int oddk = (W & 1); + const int init_value_k = ((W + 1) / 2) - 1; + for(int yb = 0; yb < ablocks; yb++) + { + const uint8_t *a_ptr0 = a_ptr; + const uint8_t *b_ptr = Bpanel; + for(int xb = 0; xb < bblocks; xb++) + { + a_ptr = a_ptr0; + int k = init_value_k; + register uint8x16_t a0 asm("v0"); + register uint8x16_t a1 asm("v1"); + register uint8x16_t b0 asm("v2"); + register uint8x16_t b1 asm("v3"); + register uint8x16_t b2 asm("v4"); + register uint8x16_t a0a asm("v5"); + register uint8x16_t a1a asm("v6"); + __asm __volatile( +#ifdef NO_DOT_IN_TOOLCHAIN + _DECLARE_UDOT +#else + ".arch armv8.2-a+dotprod\n" +#endif + // Initialize result registers, load initial operands, prime prefetches. + "movi v8.4s, #0x0\n" + "ldr %q[a0], [%[a_ptr]]\n" + "movi v9.4s, #0x0\n" + "ldr %q[b0], [%[b_ptr]]\n" + "movi v10.4s, #0x0\n" + "ldr %q[a1], [%[a_ptr], #16]\n" + "movi v11.4s, #0x0\n" + "ldr %q[b1], [%[b_ptr], #16]\n" + "movi v12.4s, #0x0\n" ASM_PREFETCH("[%[b_ptr], #64]") "movi v13.4s, #0x0\n" ASM_PREFETCH("[%[a_ptr], #64]") "movi v14.4s, #0x0\n" ASM_PREFETCH("[%[b_ptr], #128]") "movi v15.4s, #0x0\n" + ASM_PREFETCH("[%[a_ptr], #128]") "movi v16.4s, #0x0\n" ASM_PREFETCH("[%[b_ptr], #192]") "movi v17.4s, #0x0\n" ASM_PREFETCH("[%[b_ptr], #256]") "movi v18.4s, #0x0\n" + ASM_PREFETCH("[%[a_ptr], #192]") "movi v19.4s, #0x0\n" ASM_PREFETCH("[%[b_ptr], #320]") "movi v20.4s, #0x0\n" ASM_PREFETCH("[%[a_ptr], #256]") "movi v21.4s, #0x0\n" + ASM_PREFETCH("[%[b_ptr], #384]") + "movi v22.4s, #0x0\n" + "movi v23.4s, #0x0\n" + "movi v24.4s, #0x0\n" + "movi v25.4s, #0x0\n" + "movi v26.4s, #0x0\n" + "movi v27.4s, #0x0\n" + "movi v28.4s, #0x0\n" + "movi v29.4s, #0x0\n" + "movi v30.4s, #0x0\n" + "movi v31.4s, #0x0\n" + + // Skip loop if we are doing zero iterations of it. + "cbz %w[k], 4f\n" + + // Loop proper + "1:\n" + "udot v8.4s , %[b0].16b, %[a0].4b[0]\n" + "udot v9.4s , %[b0].16b, %[a0].4b[1]\n" + + "ldr %q[b2], [%[b_ptr], #32]\n" + "udot v10.4s, %[b0].16b, %[a0].4b[2]\n" + "udot v11.4s, %[b0].16b, %[a0].4b[3]\n" + "ldr %q[a0a], [%[a_ptr], #32]\n" + "udot v12.4s, %[b0].16b, %[a1].4b[0]\n" + "udot v13.4s, %[b0].16b, %[a1].4b[1]\n" + "ldr %q[a1a], [%[a_ptr], #48]\n" + "udot v14.4s, %[b0].16b, %[a1].4b[2]\n" + "udot v15.4s, %[b0].16b, %[a1].4b[3]\n" + "ldr %q[b0], [%[b_ptr], #48]\n" + + "udot v16.4s, %[b1].16b, %[a0].4b[0]\n" + "udot v17.4s, %[b1].16b, %[a0].4b[1]\n" ASM_PREFETCH("[%[a_ptr], #320]") + "udot v18.4s, %[b1].16b, %[a0].4b[2]\n" + "udot v19.4s, %[b1].16b, %[a0].4b[3]\n" + "udot v20.4s, %[b1].16b, %[a1].4b[0]\n" + "udot v21.4s, %[b1].16b, %[a1].4b[1]\n" + "udot v22.4s, %[b1].16b, %[a1].4b[2]\n" + "udot v23.4s, %[b1].16b, %[a1].4b[3]\n" + "ldr %q[b1], [%[b_ptr], #64]\n" + + "udot v24.4s, %[b2].16b, %[a0].4b[0]\n" + "udot v25.4s, %[b2].16b, %[a0].4b[1]\n" ASM_PREFETCH("[%[b_ptr], #448]") + "udot v26.4s, %[b2].16b, %[a0].4b[2]\n" + "udot v27.4s, %[b2].16b, %[a0].4b[3]\n" + "udot v28.4s, %[b2].16b, %[a1].4b[0]\n" + "udot v29.4s, %[b2].16b, %[a1].4b[1]\n" + "udot v30.4s, %[b2].16b, %[a1].4b[2]\n" + "udot v31.4s, %[b2].16b, %[a1].4b[3]\n" + "ldr %q[b2], [%[b_ptr], #80]\n" + + "udot v8.4s , %[b0].16b, %[a0a].4b[0]\n" + "udot v9.4s , %[b0].16b, %[a0a].4b[1]\n" + "ldr %q[a0], [%[a_ptr], #64]\n" + "udot v10.4s, %[b0].16b, %[a0a].4b[2]\n" + "udot v11.4s, %[b0].16b, %[a0a].4b[3]\n" + "udot v12.4s, %[b0].16b, %[a1a].4b[0]\n" + "ldr %q[a1], [%[a_ptr], #80]\n" + "udot v13.4s, %[b0].16b, %[a1a].4b[1]\n" + "udot v14.4s, %[b0].16b, %[a1a].4b[2]\n" + "udot v15.4s, %[b0].16b, %[a1a].4b[3]\n" + "ldr %q[b0], [%[b_ptr], #96]\n" + + "udot v16.4s, %[b1].16b, %[a0a].4b[0]\n" + "udot v17.4s, %[b1].16b, %[a0a].4b[1]\n" ASM_PREFETCH("[%[b_ptr], #512]") + "udot v18.4s, %[b1].16b, %[a0a].4b[2]\n" + "udot v19.4s, %[b1].16b, %[a0a].4b[3]\n" + "udot v20.4s, %[b1].16b, %[a1a].4b[0]\n" + "udot v21.4s, %[b1].16b, %[a1a].4b[1]\n" + "udot v22.4s, %[b1].16b, %[a1a].4b[2]\n" + "udot v23.4s, %[b1].16b, %[a1a].4b[3]\n" + "ldr %q[b1], [%[b_ptr], #112]\n" + + "udot v24.4s, %[b2].16b, %[a0a].4b[0]\n" + "udot v25.4s, %[b2].16b, %[a0a].4b[1]\n" + "add %[a_ptr], %[a_ptr], #64\n" + "udot v26.4s, %[b2].16b, %[a0a].4b[2]\n" + "udot v27.4s, %[b2].16b, %[a0a].4b[3]\n" + "add %[b_ptr], %[b_ptr], #96\n" + "udot v28.4s, %[b2].16b, %[a1a].4b[0]\n" + "udot v29.4s, %[b2].16b, %[a1a].4b[1]\n" + "subs %w[k], %w[k], #1\n" + "udot v30.4s, %[b2].16b, %[a1a].4b[2]\n" + "udot v31.4s, %[b2].16b, %[a1a].4b[3]\n" + "bne 1b\n" + + // Target to use when K is 1 or 2 (i.e. zero iterations of main loop) + "4:\n" + + // Branch to alternative tail for odd K + "cbnz %w[oddk], 2f\n" + + // Detached final iteration (even K) + "udot v8.4s , %[b0].16b, %[a0].4b[0]\n" + "udot v9.4s , %[b0].16b, %[a0].4b[1]\n" + "ldr %q[b2], [%[b_ptr], #32]\n" + "udot v10.4s, %[b0].16b, %[a0].4b[2]\n" + "udot v11.4s, %[b0].16b, %[a0].4b[3]\n" + "ldr %q[a0a], [%[a_ptr], #32]\n" + "udot v12.4s, %[b0].16b, %[a1].4b[0]\n" + "udot v13.4s, %[b0].16b, %[a1].4b[1]\n" + "ldr %q[a1a], [%[a_ptr], #48]\n" + "udot v14.4s, %[b0].16b, %[a1].4b[2]\n" + "udot v15.4s, %[b0].16b, %[a1].4b[3]\n" + "ldr %q[b0], [%[b_ptr], #48]\n" + + "udot v16.4s, %[b1].16b, %[a0].4b[0]\n" + "udot v17.4s, %[b1].16b, %[a0].4b[1]\n" + "udot v18.4s, %[b1].16b, %[a0].4b[2]\n" + "udot v19.4s, %[b1].16b, %[a0].4b[3]\n" + "udot v20.4s, %[b1].16b, %[a1].4b[0]\n" + "udot v21.4s, %[b1].16b, %[a1].4b[1]\n" + "udot v22.4s, %[b1].16b, %[a1].4b[2]\n" + "udot v23.4s, %[b1].16b, %[a1].4b[3]\n" + "ldr %q[b1], [%[b_ptr], #64]\n" + + "udot v24.4s, %[b2].16b, %[a0].4b[0]\n" + "udot v25.4s, %[b2].16b, %[a0].4b[1]\n" + "add %[a_ptr], %[a_ptr], #64\n" + "udot v26.4s, %[b2].16b, %[a0].4b[2]\n" + "udot v27.4s, %[b2].16b, %[a0].4b[3]\n" + "udot v28.4s, %[b2].16b, %[a1].4b[0]\n" + "udot v29.4s, %[b2].16b, %[a1].4b[1]\n" + "udot v30.4s, %[b2].16b, %[a1].4b[2]\n" + "udot v31.4s, %[b2].16b, %[a1].4b[3]\n" + "ldr %q[b2], [%[b_ptr], #80]\n" + + "udot v8.4s , %[b0].16b, %[a0a].4b[0]\n" + + "udot v16.4s, %[b1].16b, %[a0a].4b[0]\n" + "add %[b_ptr], %[b_ptr], #96\n" + "udot v9.4s , %[b0].16b, %[a0a].4b[1]\n" + "str q8, [%[c_ptr], #0]\n" + "udot v17.4s, %[b1].16b, %[a0a].4b[1]\n" + "str q16, [%[c_ptr], #16]\n" + "udot v24.4s, %[b2].16b, %[a0a].4b[0]\n" + "str q24, [%[c_ptr], #32]\n" + + "udot v25.4s, %[b2].16b, %[a0a].4b[1]\n" + "str q9, [%[c_ptr], #48]\n" + "udot v10.4s, %[b0].16b, %[a0a].4b[2]\n" + "str q17, [%[c_ptr], #64]\n" + "udot v18.4s, %[b1].16b, %[a0a].4b[2]\n" + "str q25, [%[c_ptr], #80]\n" + "udot v26.4s, %[b2].16b, %[a0a].4b[2]\n" + "str q10, [%[c_ptr], #96]\n" + + "udot v11.4s, %[b0].16b, %[a0a].4b[3]\n" + "str q18, [%[c_ptr], #112]\n" + "udot v19.4s, %[b1].16b, %[a0a].4b[3]\n" + "str q26, [%[c_ptr], #128]\n" + "udot v27.4s, %[b2].16b, %[a0a].4b[3]\n" + "str q11, [%[c_ptr], #144]\n" + + "udot v12.4s, %[b0].16b, %[a1a].4b[0]\n" + "str q19, [%[c_ptr], #160]\n" + "udot v20.4s, %[b1].16b, %[a1a].4b[0]\n" + "str q27, [%[c_ptr], #176]\n" + "udot v28.4s, %[b2].16b, %[a1a].4b[0]\n" + "str q12, [%[c_ptr], #192]\n" + + "udot v13.4s, %[b0].16b, %[a1a].4b[1]\n" + "str q20, [%[c_ptr], #208]\n" + "udot v21.4s, %[b1].16b, %[a1a].4b[1]\n" + "str q28, [%[c_ptr], #224]\n" + "udot v29.4s, %[b2].16b, %[a1a].4b[1]\n" + "str q13, [%[c_ptr], #240]\n" + + "udot v14.4s, %[b0].16b, %[a1a].4b[2]\n" + "str q21, [%[c_ptr], #256]\n" + "udot v22.4s, %[b1].16b, %[a1a].4b[2]\n" + "str q29, [%[c_ptr], #272]\n" + "udot v30.4s, %[b2].16b, %[a1a].4b[2]\n" + "str q14, [%[c_ptr], #288]\n" + + "udot v15.4s, %[b0].16b, %[a1a].4b[3]\n" + "str q22, [%[c_ptr], #304]\n" + "udot v23.4s, %[b1].16b, %[a1a].4b[3]\n" + "str q30, [%[c_ptr], #320]\n" + "udot v31.4s, %[b2].16b, %[a1a].4b[3]\n" + "str q15, [%[c_ptr], #336]\n" + + "b 3f\n" + + // Detached final iteration (odd K) + "2:\n" + "udot v8.4s , %[b0].16b, %[a0].4b[0]\n" + "ldr %q[b2], [%[b_ptr], #32]\n" + "udot v16.4s, %[b1].16b, %[a0].4b[0]\n" + "udot v9.4s , %[b0].16b, %[a0].4b[1]\n" + "str q8, [%[c_ptr], #0]\n" + "udot v17.4s, %[b1].16b, %[a0].4b[1]\n" + "str q16, [%[c_ptr], #16]\n" + "udot v24.4s, %[b2].16b, %[a0].4b[0]\n" + "add %[b_ptr], %[b_ptr], #48\n" + "add %[a_ptr], %[a_ptr], #32\n" + "str q24, [%[c_ptr], #32]\n" + "udot v25.4s, %[b2].16b, %[a0].4b[1]\n" + "str q9, [%[c_ptr], #48]\n" + + "udot v10.4s, %[b0].16b, %[a0].4b[2]\n" + "str q17, [%[c_ptr], #64]\n" + "udot v18.4s, %[b1].16b, %[a0].4b[2]\n" + "str q25, [%[c_ptr], #80]\n" + "udot v26.4s, %[b2].16b, %[a0].4b[2]\n" + "str q10, [%[c_ptr], #96]\n" + + "udot v11.4s, %[b0].16b, %[a0].4b[3]\n" + "str q18, [%[c_ptr], #112]\n" + "udot v19.4s, %[b1].16b, %[a0].4b[3]\n" + "str q26, [%[c_ptr], #128]\n" + "udot v27.4s, %[b2].16b, %[a0].4b[3]\n" + "str q11, [%[c_ptr], #144]\n" + + "udot v12.4s, %[b0].16b, %[a1].4b[0]\n" + "str q19, [%[c_ptr], #160]\n" + "udot v20.4s, %[b1].16b, %[a1].4b[0]\n" + "str q27, [%[c_ptr], #176]\n" + "udot v28.4s, %[b2].16b, %[a1].4b[0]\n" + "str q12, [%[c_ptr], #192]\n" + + "udot v13.4s, %[b0].16b, %[a1].4b[1]\n" + "str q20, [%[c_ptr], #208]\n" + "udot v21.4s, %[b1].16b, %[a1].4b[1]\n" + "str q28, [%[c_ptr], #224]\n" + "udot v29.4s, %[b2].16b, %[a1].4b[1]\n" + "str q13, [%[c_ptr], #240]\n" + + "udot v14.4s, %[b0].16b, %[a1].4b[2]\n" + "str q21, [%[c_ptr], #256]\n" + "udot v22.4s, %[b1].16b, %[a1].4b[2]\n" + "str q29, [%[c_ptr], #272]\n" + "udot v30.4s, %[b2].16b, %[a1].4b[2]\n" + "str q14, [%[c_ptr], #288]\n" + + "udot v15.4s, %[b0].16b, %[a1].4b[3]\n" + "str q22, [%[c_ptr], #304]\n" + "udot v23.4s, %[b1].16b, %[a1].4b[3]\n" + "str q30, [%[c_ptr], #320]\n" + "udot v31.4s, %[b2].16b, %[a1].4b[3]\n" + "str q15, [%[c_ptr], #336]\n" + + // Common tail + "3:\n" + "str q23, [%[c_ptr], #352]\n" + "str q31, [%[c_ptr], #368]\n" + "add %[c_ptr], %[c_ptr], #384\n" + +#ifdef NO_DOT_IN_TOOLCHAIN + ".purgem udot\n" +#endif + : + [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [c_ptr] "+r"(c_ptr), + [a0] "+w"(a0), [a1] "+w"(a1), [a0a] "+w"(a0a), [a1a] "+w"(a1a), + [b0] "+w"(b0), [b1] "+w"(b1), [b2] "+w"(b2), [k] "+r"(k) + : [oddk] "r"(oddk) + : "x20", "x21", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", + "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "cc"); + } + } +} + +} // namespace arm_gemm + +#endif diff --git a/src/core/NEON/kernels/arm_gemm/kernels/a64_gemm_u8_4x4.hpp b/src/core/NEON/kernels/arm_gemm/kernels/a64_gemm_u8_4x4.hpp new file mode 100644 index 0000000000..5aa5291a29 --- /dev/null +++ b/src/core/NEON/kernels/arm_gemm/kernels/a64_gemm_u8_4x4.hpp @@ -0,0 +1,66 @@ +/* + * Copyright (c) 2017-2018 ARM Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ +#pragma once + +#ifdef __aarch64__ + +namespace arm_gemm +{ +// Kernel definition +void a64_gemm_u8_4x4(const uint8_t *Apanel, const uint8_t *Bpanel, uint32_t *Cpanel, int ablocks, int bblocks, int K); + +class gemm_u8_4x4 +{ +public: + typedef uint8_t operand_type; + typedef uint32_t result_type; + + typedef void (*kern_type)(const uint8_t *, const uint8_t *, uint32_t *, int, int, int); + + /* Describes the data layout for A input */ + static const int A_interleave = 4; + static const int A_block = 16; + static const bool A_transpose = false; + + /* Same for B input */ + static const int B_interleave = 4; + static const int B_block = 16; + static const bool B_transpose = true; + + /* Kernel blocking parameters */ + static const int out_width = 4; + static const int out_height = 4; + static const int k_unroll = 16; + + kern_type kernel = nullptr; + + gemm_u8_4x4(const CPUInfo *ci) + { + kernel = a64_gemm_u8_4x4; + } +}; + +} // namespace arm_gemm + +#endif // __aarch64__ diff --git a/src/core/NEON/kernels/arm_gemm/kernels/a64_gemm_u8_4x4/generic.cpp b/src/core/NEON/kernels/arm_gemm/kernels/a64_gemm_u8_4x4/generic.cpp new file mode 100644 index 0000000000..0a881ffde3 --- /dev/null +++ b/src/core/NEON/kernels/arm_gemm/kernels/a64_gemm_u8_4x4/generic.cpp @@ -0,0 +1,273 @@ +/* + * Copyright (c) 2017-2018 ARM Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ +#ifdef __aarch64__ + +#include <arm_neon.h> + +#include "../../asmlib.hpp" + +namespace arm_gemm +{ +void a64_gemm_u8_4x4(const uint8_t *Apanel, const uint8_t *Bpanel, uint32_t *Cpanel, int ablocks, int bblocks, int K) +{ + const uint8_t *a_ptr = Apanel; + uint32_t *c_ptr = Cpanel; + K /= 16; + + for(int yb = 0; yb < ablocks; yb++) + { + const uint8_t *a_ptr0 = a_ptr; + const uint8_t *b_ptr = Bpanel; + + for(int xb = 0; xb < bblocks; xb++) + { + a_ptr = a_ptr0; + + int k = K - 1; + + register uint8x16_t b0 asm("v4"); + register uint8x16_t b1 asm("v5"); + register uint8x16_t b2 asm("v6"); + register uint8x16_t b3 asm("v7"); + + __asm __volatile( + "movi v16.4s, #0x0\n" + "ldr q0, [%[a_ptr]]\n" + "movi v17.4s, #0x0\n" + "ldr %q[b0], [%[b_ptr]]\n" + "movi v18.4s, #0x0\n" + "ldr %q[b1], [%[b_ptr], #16]\n" + "movi v19.4s, #0x0\n" + "ldr %q[b2], [%[b_ptr], #32]\n" + "movi v20.4s, #0x0\n" + "ldr %q[b3], [%[b_ptr], #48]\n" + "movi v21.4s, #0x0\n" + "ldr q1, [%[a_ptr], #16]\n" + "movi v22.4s, #0x0\n" + "ldr q2, [%[a_ptr], #32]\n" + "movi v23.4s, #0x0\n" + "ldr q3, [%[a_ptr], #48]\n" + "movi v24.4s, #0x0\n" ASM_PREFETCH("[%[b_ptr], #64]") "movi v25.4s, #0x0\n" ASM_PREFETCH("[%[a_ptr], #64]") "movi v26.4s, #0x0\n" ASM_PREFETCH("[%[b_ptr], #128]") "movi v27.4s, #0x0\n" + ASM_PREFETCH("[%[a_ptr], #128]") "movi v28.4s, #0x0\n" ASM_PREFETCH("[%[b_ptr], #192]") "movi v29.4s, #0x0\n" ASM_PREFETCH("[%[a_ptr], #192]") "movi v30.4s, #0x0\n" + ASM_PREFETCH("[%[b_ptr], #256]") "movi v31.4s, #0x0\n" ASM_PREFETCH("[%[a_ptr], #256]") + + "umull v12.8h, v0.8b, %[b0].8b\n" + "add %[a_ptr], %[a_ptr], #64\n" + "umull v13.8h, v0.8b, %[b1].8b\n" + "umull v14.8h, v0.8b, %[b2].8b\n" + "add %[b_ptr], %[b_ptr], #64\n" + "umull v15.8h, v0.8b, %[b3].8b\n" + + // Skip loop if we are doing zero iterations of it. + "cbz %w[k], 2f\n" + + "1:\n" + "uadalp v16.4s, v12.8h\n" + "umull2 v12.8h, v0.16b, %[b0].16b\n" + "uadalp v17.4s, v13.8h\n" + "umull2 v13.8h, v0.16b, %[b1].16b\n" + "uadalp v18.4s, v14.8h\n" + "umull2 v14.8h, v0.16b, %[b2].16b\n" + "uadalp v19.4s, v15.8h\n" + "umull2 v15.8h, v0.16b, %[b3].16b\n" + "ldr q0, [%[a_ptr]]\n" + + "uadalp v16.4s, v12.8h\n" + "umull v12.8h, v1.8b, %[b0].8b\n" + "uadalp v17.4s, v13.8h\n" + "umull v13.8h, v1.8b, %[b1].8b\n" + "subs %w[k], %w[k], #1\n" + "uadalp v18.4s, v14.8h\n" + "umull v14.8h, v1.8b, %[b2].8b\n" + "uadalp v19.4s, v15.8h\n" + "umull v15.8h, v1.8b, %[b3].8b\n" + + "uadalp v20.4s, v12.8h\n" + "umull2 v12.8h, v1.16b, %[b0].16b\n" + "uadalp v21.4s, v13.8h\n" + "umull2 v13.8h, v1.16b, %[b1].16b\n" ASM_PREFETCH("[%[a_ptr], #256]") + "uadalp v22.4s, v14.8h\n" + "umull2 v14.8h, v1.16b, %[b2].16b\n" + "uadalp v23.4s, v15.8h\n" + "umull2 v15.8h, v1.16b, %[b3].16b\n" + "ldr q1, [%[a_ptr], #16]\n" + + "uadalp v20.4s, v12.8h\n" + "umull v12.8h, v2.8b, %[b0].8b\n" + "uadalp v21.4s, v13.8h\n" + "umull v13.8h, v2.8b, %[b1].8b\n" ASM_PREFETCH("[%[b_ptr], #256]") + "uadalp v22.4s, v14.8h\n" + "umull v14.8h, v2.8b, %[b2].8b\n" + "uadalp v23.4s, v15.8h\n" + "umull v15.8h, v2.8b, %[b3].8b\n" + + "uadalp v24.4s, v12.8h\n" + "umull2 v12.8h, v2.16b, %[b0].16b\n" + "uadalp v25.4s, v13.8h\n" + "umull2 v13.8h, v2.16b, %[b1].16b\n" + "uadalp v26.4s, v14.8h\n" + "umull2 v14.8h, v2.16b, %[b2].16b\n" + "uadalp v27.4s, v15.8h\n" + "umull2 v15.8h, v2.16b, %[b3].16b\n" + "ldr q2, [%[a_ptr], #32]\n" + + "uadalp v24.4s, v12.8h\n" + "umull v12.8h, v3.8b, %[b0].8b\n" + "uadalp v25.4s, v13.8h\n" + "umull v13.8h, v3.8b, %[b1].8b\n" + "uadalp v26.4s, v14.8h\n" + "umull v14.8h, v3.8b, %[b2].8b\n" + "uadalp v27.4s, v15.8h\n" + "umull v15.8h, v3.8b, %[b3].8b\n" + + "uadalp v28.4s, v12.8h\n" + "umull2 v12.8h, v3.16b, %[b0].16b\n" + "ldr %q[b0], [%[b_ptr]]\n" + "uadalp v29.4s, v13.8h\n" + "umull2 v13.8h, v3.16b, %[b1].16b\n" + "ldr %q[b1], [%[b_ptr], #16]\n" + "uadalp v30.4s, v14.8h\n" + "umull2 v14.8h, v3.16b, %[b2].16b\n" + "ldr %q[b2], [%[b_ptr], #32]\n" + "uadalp v31.4s, v15.8h\n" + "umull2 v15.8h, v3.16b, %[b3].16b\n" + "ldr %q[b3], [%[b_ptr], #48]\n" + + "uadalp v28.4s, v12.8h\n" + "umull v12.8h, v0.8b, %[b0].8b\n" + "add %[b_ptr], %[b_ptr], #64\n" + "uadalp v29.4s, v13.8h\n" + "umull v13.8h, v0.8b, %[b1].8b\n" + "ldr q3, [%[a_ptr], #48]\n" + "uadalp v30.4s, v14.8h\n" + "umull v14.8h, v0.8b, %[b2].8b\n" + "add %[a_ptr], %[a_ptr], #64\n" + "uadalp v31.4s, v15.8h\n" + "umull v15.8h, v0.8b, %[b3].8b\n" + "bne 1b\n" + + // Branch target + "2:\n" + "uadalp v16.4s, v12.8h\n" + "umull2 v12.8h, v0.16b, %[b0].16b\n" + "uadalp v17.4s, v13.8h\n" + "umull2 v13.8h, v0.16b, %[b1].16b\n" + "uadalp v18.4s, v14.8h\n" + "umull2 v14.8h, v0.16b, %[b2].16b\n" + "uadalp v19.4s, v15.8h\n" + "umull2 v15.8h, v0.16b, %[b3].16b\n" + + "uadalp v16.4s, v12.8h\n" + "umull v12.8h, v1.8b, %[b0].8b\n" + "uadalp v17.4s, v13.8h\n" + "umull v13.8h, v1.8b, %[b1].8b\n" + "uadalp v18.4s, v14.8h\n" + "umull v14.8h, v1.8b, %[b2].8b\n" + "uadalp v19.4s, v15.8h\n" + "umull v15.8h, v1.8b, %[b3].8b\n" + + "uadalp v20.4s, v12.8h\n" + "umull2 v12.8h, v1.16b, %[b0].16b\n" + "uadalp v21.4s, v13.8h\n" + "umull2 v13.8h, v1.16b, %[b1].16b\n" + "uadalp v22.4s, v14.8h\n" + "umull2 v14.8h, v1.16b, %[b2].16b\n" + "uadalp v23.4s, v15.8h\n" + "umull2 v15.8h, v1.16b, %[b3].16b\n" + + "uadalp v20.4s, v12.8h\n" + "umull v12.8h, v2.8b, %[b0].8b\n" + "uadalp v21.4s, v13.8h\n" + "umull v13.8h, v2.8b, %[b1].8b\n" + "uadalp v22.4s, v14.8h\n" + "umull v14.8h, v2.8b, %[b2].8b\n" + "uadalp v23.4s, v15.8h\n" + "umull v15.8h, v2.8b, %[b3].8b\n" + + "uadalp v24.4s, v12.8h\n" + "umull2 v12.8h, v2.16b, %[b0].16b\n" + "uadalp v25.4s, v13.8h\n" + "umull2 v13.8h, v2.16b, %[b1].16b\n" + "uadalp v26.4s, v14.8h\n" + "umull2 v14.8h, v2.16b, %[b2].16b\n" + "uadalp v27.4s, v15.8h\n" + "umull2 v15.8h, v2.16b, %[b3].16b\n" + + "uadalp v24.4s, v12.8h\n" + "umull v12.8h, v3.8b, %[b0].8b\n" + "uadalp v25.4s, v13.8h\n" + "umull v13.8h, v3.8b, %[b1].8b\n" + "uadalp v26.4s, v14.8h\n" + "umull v14.8h, v3.8b, %[b2].8b\n" + "uadalp v27.4s, v15.8h\n" + "umull v15.8h, v3.8b, %[b3].8b\n" + + "uadalp v28.4s, v12.8h\n" + "umull2 v12.8h, v3.16b, %[b0].16b\n" + "uadalp v29.4s, v13.8h\n" + "umull2 v13.8h, v3.16b, %[b1].16b\n" + "uadalp v30.4s, v14.8h\n" + "umull2 v14.8h, v3.16b, %[b2].16b\n" + "uadalp v31.4s, v15.8h\n" + "umull2 v15.8h, v3.16b, %[b3].16b\n" + + "uadalp v28.4s, v12.8h\n" + "uadalp v29.4s, v13.8h\n" + "uadalp v30.4s, v14.8h\n" + "uadalp v31.4s, v15.8h\n" + + "addp v16.4s, v16.4s, v17.4s\n" + "addp v17.4s, v18.4s, v19.4s\n" + "addp v18.4s, v20.4s, v21.4s\n" + "addp v19.4s, v22.4s, v23.4s\n" + "addp v20.4s, v24.4s, v25.4s\n" + "addp v21.4s, v26.4s, v27.4s\n" + "addp v22.4s, v28.4s, v29.4s\n" + "addp v23.4s, v30.4s, v31.4s\n" + + "addp v16.4s, v16.4s, v17.4s\n" + "addp v17.4s, v18.4s, v19.4s\n" + "addp v18.4s, v20.4s, v21.4s\n" + "addp v19.4s, v22.4s, v23.4s\n" + + "str q16, [%[c_ptr]]\n" + "str q17, [%[c_ptr], #16]\n" + "str q18, [%[c_ptr], #32]\n" + "str q19, [%[c_ptr], #48]\n" + "add %[c_ptr], %[c_ptr], #64\n" + + : + [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [c_ptr] "+r"(c_ptr), + [b0] "+w"(b0), [b1] "+w"(b1), [b2] "+w"(b2), [b3] "+w"(b3), + [k] "+r"(k) + : + : "x20", "x21", "v0", "v1", "v2", "v3", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", + "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "cc"); + } + } +} + +} // namespace arm_gemm + +#endif // __aarch64__ diff --git a/src/core/NEON/kernels/arm_gemm/kernels/a64_hgemm_24x8.hpp b/src/core/NEON/kernels/arm_gemm/kernels/a64_hgemm_24x8.hpp new file mode 100644 index 0000000000..77ec59aa35 --- /dev/null +++ b/src/core/NEON/kernels/arm_gemm/kernels/a64_hgemm_24x8.hpp @@ -0,0 +1,74 @@ +/* + * Copyright (c) 2017-2018 ARM Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ +#pragma once + +#if defined(__aarch64__) && defined(__ARM_FEATURE_FP16_SCALAR_ARITHMETIC) + +#include "arm_gemm.hpp" + +namespace arm_gemm +{ +// Actual kernel implementations +void a64_hgemm_asimd_24x8(const __fp16 *, const __fp16 *, __fp16 *, int, int, int); +void a64_hgemm_asimd_24x8_a55r1(const __fp16 *, const __fp16 *, __fp16 *, int, int, int); + +// 24x8 HGEMM "strategy" class. Describes the kernel properties. +// +// The generic "gemm_opt" function will instantiate one of these (allowing +// the constructor to pick a kernel implementation). +class hgemm_24x8 +{ +public: + typedef __fp16 operand_type; + typedef __fp16 result_type; + + typedef void (*kern_type)(const __fp16 *, const __fp16 *, __fp16 *, int, int, int); + + static const int A_block = 1; + static const int A_interleave = 8; + static const bool A_transpose = false; + + static const int B_block = 1; + static const int B_interleave = 24; + static const bool B_transpose = true; + + static const int out_width = 24; + static const int out_height = 8; + static const int k_unroll = 1; + + // Default to the generic kernel + kern_type kernel = a64_hgemm_asimd_24x8; + + hgemm_24x8(const CPUInfo *ci) + { + if(ci->get_cpu_model() == CPUModel::A55r1) + { + kernel = a64_hgemm_asimd_24x8_a55r1; + } + } +}; + +} // namespace arm_gemm + +#endif // __aarch64__ diff --git a/src/core/NEON/kernels/arm_gemm/kernels/a64_hgemm_24x8/a55r1.cpp b/src/core/NEON/kernels/arm_gemm/kernels/a64_hgemm_24x8/a55r1.cpp new file mode 100644 index 0000000000..d59618dd54 --- /dev/null +++ b/src/core/NEON/kernels/arm_gemm/kernels/a64_hgemm_24x8/a55r1.cpp @@ -0,0 +1,360 @@ +/* + * Copyright (c) 2017-2018 ARM Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ +#if defined(__aarch64__) && defined(__ARM_FEATURE_FP16_SCALAR_ARITHMETIC) + +#include <arm_neon.h> + +#include "../../asmlib.hpp" + +// Kernel implementation. +// +// Assume that "Apanel" points to a chunk of A blocks (each size 8xK) in read-order. +// Assume that "Bpanel" points to a chunk of B blocks (each size 12xK) in read-order. +// Assume that "Cpanel" points to a chunk of C output blocks (each size +// 12x8), the chunks being arranged in a row major fashion. +// +// Note that the intent of this is that either ablocks or bblocks will be 1 +// - this construction allows the output loop to proceed in either order. + +namespace arm_gemm +{ +void a64_hgemm_asimd_24x8_a55r1(const __fp16 *Apanel, const __fp16 *Bpanel, __fp16 *Cpanel, int ablocks, int bblocks, int K) +{ + const __fp16 *a_ptr = Apanel; + __fp16 *c_ptr = Cpanel; + + // Fix up for odd lengths - set a flag if K is odd, but make + // sure we round up the iteration count. + int oddk = (K & 1); + int k_iters = ((K + 1) / 2) - 1; + + for(int yb = 0; yb < ablocks; yb++) + { + const __fp16 *a_ptr0 = a_ptr; + const __fp16 *b_ptr = Bpanel; + + for(int xb = 0; xb < bblocks; xb++) + { + int k = k_iters; + a_ptr = a_ptr0; + + // As A55 requires 64-bit loads anyway, just use 64 bits of the + // "A" operands to save on "ins" instructions. Since A55 is + // in-order, two sets of "A" operands and one set of "B" is + // sufficient. + register float16x8_t a0 asm("v0"); + register float16x8_t a1 asm("v1"); + register float16x8_t a0a asm("v2"); + register float16x8_t a1a asm("v3"); + register float16x8_t b0 asm("v4"); + register float16x8_t b1 asm("v5"); + register float16x8_t b2 asm("v6"); + + __asm __volatile( + // Enable FP16 extensions + ".arch armv8.2-a+fp16\n" + // Initialize result registers, load initial operands, prime prefetches. + "movi v8.8h, #0x0\n" + "ldr %d[a0], [%[a_ptr]]\n" + "movi v9.8h, #0x0\n" + "ldr %q[b0], [%[b_ptr]]\n" + "movi v10.8h, #0x0\n" + "ldr %d[a1], [%[a_ptr], #8]\n" + "movi v11.8h, #0x0\n" + "ldr %q[b1], [%[b_ptr], #16]\n" + "movi v12.8h, #0x0\n" + "movi v13.8h, #0x0\n" ASM_PREFETCH("[%[b_ptr], #64]") + "movi v14.8h, #0x0\n" + "movi v15.8h, #0x0\n" ASM_PREFETCH("[%[b_ptr], #128]") + "movi v16.8h, #0x0\n" + "movi v17.8h, #0x0\n" ASM_PREFETCH("[%[a_ptr], #64]") + "movi v18.8h, #0x0\n" + "movi v19.8h, #0x0\n" ASM_PREFETCH("[%[b_ptr], #192]") + "movi v20.8h, #0x0\n" + "movi v21.8h, #0x0\n" ASM_PREFETCH("[%[b_ptr], #256]") + "movi v22.8h, #0x0\n" + "movi v23.8h, #0x0\n" ASM_PREFETCH("[%[b_ptr], #320]") + "movi v24.8h, #0x0\n" + "movi v25.8h, #0x0\n" + "movi v26.8h, #0x0\n" + "movi v27.8h, #0x0\n" + "movi v28.8h, #0x0\n" + "movi v29.8h, #0x0\n" + "movi v30.8h, #0x0\n" + "movi v31.8h, #0x0\n" + + // The loop is offset by these two instructions which must + // always be executed. + "fmla v8.8h , %[b0].8h, %[a0].h[0]\n" + "ldr %d[b2], [%[b_ptr], #32]\n" + + // Skip loop if we are doing zero iterations of it. + "cbz %w[k], 4f\n" + + "1:\n" + "fmla v9.8h , %[b0].8h, %[a0].h[1]\n" + "ldr x20, [%[b_ptr], #40]\n" + "fmla v10.8h, %[b0].8h, %[a0].h[2]\n" + "subs %w[k], %w[k], #1\n" + "fmla v11.8h, %[b0].8h, %[a0].h[3]\n" + "ldr %d[a0a], [%[a_ptr], #16]\n" + + "fmla v12.8h, %[b0].8h, %[a1].h[0]\n" + "ins %[b2].d[1], x20\n" + "fmla v13.8h, %[b0].8h, %[a1].h[1]\n" + "fmla v14.8h, %[b0].8h, %[a1].h[2]\n" + "fmla v15.8h, %[b0].8h, %[a1].h[3]\n" + "ldr %d[a1a], [%[a_ptr], #24]\n" + + "fmla v16.8h, %[b1].8h, %[a0].h[0]\n" + "fmla v17.8h, %[b1].8h, %[a0].h[1]\n" + "fmla v18.8h, %[b1].8h, %[a0].h[2]\n" + "fmla v19.8h, %[b1].8h, %[a0].h[3]\n" + "ldr %d[b0], [%[b_ptr], #48]\n" + + "fmla v20.8h, %[b1].8h, %[a1].h[0]\n" + "fmla v21.8h, %[b1].8h, %[a1].h[1]\n" + "ldr x20, [%[b_ptr], #56]\n" + "fmla v22.8h, %[b1].8h, %[a1].h[2]\n" + "fmla v23.8h, %[b1].8h, %[a1].h[3]\n" + "ldr %d[b1], [%[b_ptr], #64]\n" + + "fmla v24.8h, %[b2].8h, %[a0].h[0]\n" + "ins %[b0].d[1], x20\n" + "fmla v25.8h, %[b2].8h, %[a0].h[1]\n" + "ldr x20, [%[b_ptr], #72]\n" + "fmla v26.8h, %[b2].8h, %[a0].h[2]\n" + "fmla v27.8h, %[b2].8h, %[a0].h[3]\n" ASM_PREFETCH("[%[a_ptr], #128]") + + "fmla v28.8h, %[b2].8h, %[a1].h[0]\n" + "fmla v29.8h, %[b2].8h, %[a1].h[1]\n" ASM_PREFETCH("[%[b_ptr], #384]") + "fmla v30.8h, %[b2].8h, %[a1].h[2]\n" + "fmla v31.8h, %[b2].8h, %[a1].h[3]\n" + "ldr %d[b2], [%[b_ptr], #80]\n" + + // Unroll 1 + "fmla v8.8h , %[b0].8h, %[a0a].h[0]\n" + "ins %[b1].d[1], x20\n" + "fmla v9.8h , %[b0].8h, %[a0a].h[1]\n" + "ldr x20, [%[b_ptr], #88]\n" + "fmla v10.8h, %[b0].8h, %[a0a].h[2]\n" + "fmla v11.8h, %[b0].8h, %[a0a].h[3]\n" + "ldr %d[a0], [%[a_ptr], #32]\n" + + "fmla v12.8h, %[b0].8h, %[a1a].h[0]\n" + "ins %[b2].d[1], x20\n" + "fmla v13.8h, %[b0].8h, %[a1a].h[1]\n" + "fmla v14.8h, %[b0].8h, %[a1a].h[2]\n" + "fmla v15.8h, %[b0].8h, %[a1a].h[3]\n" + "ldr %d[a1], [%[a_ptr], #40]\n" + + "fmla v16.8h, %[b1].8h, %[a0a].h[0]\n" + "add %[a_ptr], %[a_ptr], #32\n" + "fmla v17.8h, %[b1].8h, %[a0a].h[1]\n" + "fmla v18.8h, %[b1].8h, %[a0a].h[2]\n" + "fmla v19.8h, %[b1].8h, %[a0a].h[3]\n" + "ldr %d[b0], [%[b_ptr], #96]\n" + + "fmla v20.8h, %[b1].8h, %[a1a].h[0]\n" + "fmla v21.8h, %[b1].8h, %[a1a].h[1]\n" + "ldr x20, [%[b_ptr], #104]\n" + "fmla v22.8h, %[b1].8h, %[a1a].h[2]\n" + "fmla v23.8h, %[b1].8h, %[a1a].h[3]\n" + "ldr %d[b1], [%[b_ptr], #112]\n" + + "fmla v24.8h, %[b2].8h, %[a0a].h[0]\n" + "ins %[b0].d[1], x20\n" + "fmla v25.8h, %[b2].8h, %[a0a].h[1]\n" + "ldr x20, [%[b_ptr], #120]\n" + "fmla v26.8h, %[b2].8h, %[a0a].h[2]\n" + "fmla v27.8h, %[b2].8h, %[a0a].h[3]\n" + + "fmla v28.8h, %[b2].8h, %[a1a].h[0]\n" ASM_PREFETCH("[%[b_ptr], #448]") + "fmla v29.8h, %[b2].8h, %[a1a].h[1]\n" + "add %[b_ptr], %[b_ptr], #96\n" + "fmla v30.8h, %[b2].8h, %[a1a].h[2]\n" + "ins %[b1].d[1], x20\n" + "fmla v31.8h, %[b2].8h, %[a1a].h[3]\n" + "ldr %d[b2], [%[b_ptr], #32]\n" + + "fmla v8.8h , %[b0].8h, %[a0].h[0]\n" + "bne 1b\n" + + "4:\n" + + // Start final iteration - branch off to "odd" code before we load a0a + "fmla v9.8h , %[b0].8h, %[a0].h[1]\n" + "ldr x20, [%[b_ptr], #40]\n" + "fmla v10.8h, %[b0].8h, %[a0].h[2]\n" + "cbnz %w[oddk], 2f\n" + + // Even K continuation + "fmla v11.8h, %[b0].8h, %[a0].h[3]\n" + "ldr %d[a0a], [%[a_ptr], #16]\n" + + "fmla v12.8h, %[b0].8h, %[a1].h[0]\n" + "ins %[b2].d[1], x20\n" + "fmla v13.8h, %[b0].8h, %[a1].h[1]\n" ASM_PREFETCHW("[%[c_ptr]]") + "fmla v14.8h, %[b0].8h, %[a1].h[2]\n" + "fmla v15.8h, %[b0].8h, %[a1].h[3]\n" + "ldr %d[a1a], [%[a_ptr], #24]\n" + + "fmla v16.8h, %[b1].8h, %[a0].h[0]\n" + "fmla v17.8h, %[b1].8h, %[a0].h[1]\n" ASM_PREFETCHW("[%[c_ptr], #64]") + "fmla v18.8h, %[b1].8h, %[a0].h[2]\n" + "fmla v19.8h, %[b1].8h, %[a0].h[3]\n" + "ldr %d[b0], [%[b_ptr], #48]\n" + + "fmla v20.8h, %[b1].8h, %[a1].h[0]\n" + "fmla v21.8h, %[b1].8h, %[a1].h[1]\n" + "ldr x20, [%[b_ptr], #56]\n" + "fmla v22.8h, %[b1].8h, %[a1].h[2]\n" + "fmla v23.8h, %[b1].8h, %[a1].h[3]\n" + "ldr %d[b1], [%[b_ptr], #64]\n" + + "fmla v24.8h, %[b2].8h, %[a0].h[0]\n" + "ins %[b0].d[1], x20\n" + "fmla v25.8h, %[b2].8h, %[a0].h[1]\n" + "ldr x20, [%[b_ptr], #72]\n" + "fmla v26.8h, %[b2].8h, %[a0].h[2]\n" + "fmla v27.8h, %[b2].8h, %[a0].h[3]\n" ASM_PREFETCHW("[%[c_ptr], #128]") + + "fmla v28.8h, %[b2].8h, %[a1].h[0]\n" + "fmla v29.8h, %[b2].8h, %[a1].h[1]\n" ASM_PREFETCHW("[%[c_ptr], #192]") + "fmla v30.8h, %[b2].8h, %[a1].h[2]\n" + "fmla v31.8h, %[b2].8h, %[a1].h[3]\n" + "ldr %d[b2], [%[b_ptr], #80]\n" + + "fmla v8.8h , %[b0].8h, %[a0a].h[0]\n" + "ins %[b1].d[1], x20\n" + "fmla v9.8h , %[b0].8h, %[a0a].h[1]\n" + "ldr x20, [%[b_ptr], #88]\n" + "fmla v10.8h, %[b0].8h, %[a0a].h[2]\n" + "fmla v11.8h, %[b0].8h, %[a0a].h[3]\n" ASM_PREFETCHW("[%[c_ptr], #256]") + + "fmla v12.8h, %[b0].8h, %[a1a].h[0]\n" + "ins %[b2].d[1], x20\n" + "fmla v13.8h, %[b0].8h, %[a1a].h[1]\n" ASM_PREFETCHW("[%[c_ptr], #320]") + "fmla v14.8h, %[b0].8h, %[a1a].h[2]\n" + "fmla v15.8h, %[b0].8h, %[a1a].h[3]\n" + "ldr %d[a1], [%[a_ptr], #40]\n" + + "fmla v16.8h, %[b1].8h, %[a0a].h[0]\n" + "add %[a_ptr], %[a_ptr], #32\n" + "fmla v17.8h, %[b1].8h, %[a0a].h[1]\n" ASM_PREFETCHWL2("[%[c_ptr], #384]") + "fmla v18.8h, %[b1].8h, %[a0a].h[2]\n" + "fmla v19.8h, %[b1].8h, %[a0a].h[3]\n" ASM_PREFETCHWL2("[%[c_ptr], #448]") + + "fmla v20.8h, %[b1].8h, %[a1a].h[0]\n" + "fmla v21.8h, %[b1].8h, %[a1a].h[1]\n" ASM_PREFETCHWL2("[%[c_ptr], #512]") + "fmla v22.8h, %[b1].8h, %[a1a].h[2]\n" + "fmla v23.8h, %[b1].8h, %[a1a].h[3]\n" ASM_PREFETCHWL2("[%[c_ptr], #576]") + + "fmla v24.8h, %[b2].8h, %[a0a].h[0]\n" + "fmla v25.8h, %[b2].8h, %[a0a].h[1]\n" ASM_PREFETCHWL2("[%[c_ptr], #640]") + "fmla v26.8h, %[b2].8h, %[a0a].h[2]\n" + "fmla v27.8h, %[b2].8h, %[a0a].h[3]\n" ASM_PREFETCHWL2("[%[c_ptr], #704]") + + "fmla v28.8h, %[b2].8h, %[a1a].h[0]\n" + "fmla v29.8h, %[b2].8h, %[a1a].h[1]\n" + "add %[b_ptr], %[b_ptr], #96\n" + "fmla v30.8h, %[b2].8h, %[a1a].h[2]\n" + "fmla v31.8h, %[b2].8h, %[a1a].h[3]\n" + "b 3f\n" + + "2:\n" + + // Odd tail + "fmla v11.8h, %[b0].8h, %[a0].h[3]\n" ASM_PREFETCHW("[%[c_ptr]]") + + "fmla v12.8h, %[b0].8h, %[a1].h[0]\n" + "ins %[b2].d[1], x20\n" + "fmla v13.8h, %[b0].8h, %[a1].h[1]\n" ASM_PREFETCHW("[%[c_ptr], #64]") + "fmla v14.8h, %[b0].8h, %[a1].h[2]\n" + "add %[a_ptr], %[a_ptr], #16\n" + "fmla v15.8h, %[b0].8h, %[a1].h[3]\n" ASM_PREFETCHW("[%[c_ptr], #128]") + + "fmla v16.8h, %[b1].8h, %[a0].h[0]\n" + "add %[b_ptr], %[b_ptr], #48\n" + "fmla v17.8h, %[b1].8h, %[a0].h[1]\n" ASM_PREFETCHW("[%[c_ptr], #192]") + "fmla v18.8h, %[b1].8h, %[a0].h[2]\n" + "fmla v19.8h, %[b1].8h, %[a0].h[3]\n" ASM_PREFETCHW("[%[c_ptr], #256]") + + "fmla v20.8h, %[b1].8h, %[a1].h[0]\n" + "fmla v21.8h, %[b1].8h, %[a1].h[1]\n" ASM_PREFETCHW("[%[c_ptr], #320]") + "fmla v22.8h, %[b1].8h, %[a1].h[2]\n" + "fmla v23.8h, %[b1].8h, %[a1].h[3]\n" ASM_PREFETCHWL2("[%[c_ptr], #384]") + + "fmla v24.8h, %[b2].8h, %[a0].h[0]\n" + "fmla v25.8h, %[b2].8h, %[a0].h[1]\n" ASM_PREFETCHWL2("[%[c_ptr], #384]") + "fmla v26.8h, %[b2].8h, %[a0].h[2]\n" + "fmla v27.8h, %[b2].8h, %[a0].h[3]\n" ASM_PREFETCHWL2("[%[c_ptr], #448]") + + "fmla v28.8h, %[b2].8h, %[a1].h[0]\n" ASM_PREFETCHWL2("[%[c_ptr], #512]") "fmla v29.8h, %[b2].8h, %[a1].h[1]\n" ASM_PREFETCHWL2("[%[c_ptr], #576]") "fmla v30.8h, %[b2].8h, %[a1].h[2]\n" + ASM_PREFETCHWL2("[%[c_ptr], #640]") "fmla v31.8h, %[b2].8h, %[a1].h[3]\n" ASM_PREFETCHWL2("[%[c_ptr], #704]") + + // Common tail + // A55 won't dual issue these stores with anything else, so + // simplest to do them all in this common code. + "3:\n" + "str q8, [%[c_ptr]]\n" + "str q16, [%[c_ptr], #16]\n" + "str q24, [%[c_ptr], #32]\n" + "str q9, [%[c_ptr], #48]\n" + "str q17, [%[c_ptr], #64]\n" + "str q25, [%[c_ptr], #80]\n" + "str q10, [%[c_ptr], #96]\n" + "str q18, [%[c_ptr], #112]\n" + "str q26, [%[c_ptr], #128]\n" + "str q11, [%[c_ptr], #144]\n" + "str q19, [%[c_ptr], #160]\n" + "str q27, [%[c_ptr], #176]\n" + "str q12, [%[c_ptr], #192]\n" + "str q20, [%[c_ptr], #208]\n" + "str q28, [%[c_ptr], #224]\n" + "str q13, [%[c_ptr], #240]\n" + "str q21, [%[c_ptr], #256]\n" + "str q29, [%[c_ptr], #272]\n" + "str q14, [%[c_ptr], #288]\n" + "str q22, [%[c_ptr], #304]\n" + "str q30, [%[c_ptr], #320]\n" + "str q15, [%[c_ptr], #336]\n" + "str q23, [%[c_ptr], #352]\n" + "str q31, [%[c_ptr], #368]\n" + "5:\n" + "add %[c_ptr], %[c_ptr], #384\n" + : + [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [c_ptr] "+r"(c_ptr), + [a0] "=w"(a0), [a0a] "=w"(a0a), [a1] "=w"(a1), [a1a] "=w"(a1a), + [b0] "=w"(b0), [b1] "=w"(b1), [b2] "=w"(b2), [k] "+r"(k) + : [oddk] "r"(oddk) + : "x20", "x21", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", + "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "cc", "memory"); + } + } +} + +} // namespace arm_gemm + +#endif // __aarch64__ && __ARM_FEATURE_FP16_SCALAR_ARITHMETIC diff --git a/src/core/NEON/kernels/arm_gemm/kernels/a64_hgemm_24x8/generic.cpp b/src/core/NEON/kernels/arm_gemm/kernels/a64_hgemm_24x8/generic.cpp new file mode 100644 index 0000000000..468d603484 --- /dev/null +++ b/src/core/NEON/kernels/arm_gemm/kernels/a64_hgemm_24x8/generic.cpp @@ -0,0 +1,337 @@ +/* + * Copyright (c) 2017-2018 ARM Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ +#if defined(__aarch64__) && defined(__ARM_FEATURE_FP16_SCALAR_ARITHMETIC) + +#include <arm_neon.h> + +#include "../../asmlib.hpp" + +// Kernel implementation. +// +// Assume that "Apanel" points to a chunk of A blocks (each size 8xK) in read-order. +// Assume that "Bpanel" points to a chunk of B blocks (each size 12xK) in read-order. +// Assume that "Cpanel" points to a chunk of C output blocks (each size +// 12x8), the chunks being arranged in a row major fashion. +// +// Note that the intent of this is that either ablocks or bblocks will be 1 +// - this construction allows the output loop to proceed in either order. + +namespace arm_gemm +{ +void a64_hgemm_asimd_24x8(const __fp16 *Apanel, const __fp16 *Bpanel, __fp16 *Cpanel, int ablocks, int bblocks, int K) +{ + const __fp16 *a_ptr = Apanel; + __fp16 *c_ptr = Cpanel; + + for(int yb = 0; yb < ablocks; yb++) + { + const __fp16 *a_ptr0 = a_ptr; + const __fp16 *b_ptr = Bpanel; + + for(int xb = 0; xb < bblocks; xb++) + { + a_ptr = a_ptr0; + // Fix up for odd lengths - set a flag if K is odd, but make + // sure we round up the iteration count. + int oddk = (K & 1); + int k = ((K + 1) / 2) - 1; + + register float16x8_t a0 asm("v0"); + register float16x8_t a0a asm("v1"); + register float16x8_t b0 asm("v2"); + register float16x8_t b1 asm("v3"); + register float16x8_t b2 asm("v4"); + register float16x8_t b0a asm("v5"); + register float16x8_t b1a asm("v6"); + register float16x8_t b2a asm("v7"); + + __asm __volatile( + ".arch armv8.2-a+fp16\n" + // Initialize result registers, load initial operands, prime prefetches. + "movi v8.8h, #0x0\n" + "ldr %q[a0], [%[a_ptr]]\n" + "movi v9.8h, #0x0\n" + "ldr %q[b0], [%[b_ptr]]\n" + "movi v10.8h, #0x0\n" + "ldr %q[b1], [%[b_ptr], #16]\n" + "movi v11.8h, #0x0\n" + "ldr %q[b2], [%[b_ptr], #32]\n" + "movi v12.8h, #0x0\n" + "ldr %q[b0a], [%[b_ptr], #48]\n" + "movi v13.8h, #0x0\n" + "ldr %q[b1a], [%[b_ptr], #64]\n" + "movi v14.8h, #0x0\n" ASM_PREFETCH("[%[b_ptr], #64]") "movi v15.8h, #0x0\n" ASM_PREFETCH("[%[b_ptr], #128]") "movi v16.8h, #0x0\n" ASM_PREFETCH("[%[a_ptr], #64]") "movi v17.8h, #0x0\n" + ASM_PREFETCH("[%[b_ptr], #192]") "movi v18.8h, #0x0\n" ASM_PREFETCH("[%[b_ptr], #256]") "movi v19.8h, #0x0\n" ASM_PREFETCH("[%[b_ptr], #320]") + "movi v20.8h, #0x0\n" + "movi v21.8h, #0x0\n" + "movi v22.8h, #0x0\n" + "movi v23.8h, #0x0\n" + "movi v24.8h, #0x0\n" + "movi v25.8h, #0x0\n" + "movi v26.8h, #0x0\n" + "movi v27.8h, #0x0\n" + "movi v28.8h, #0x0\n" + "movi v29.8h, #0x0\n" + "movi v30.8h, #0x0\n" + "movi v31.8h, #0x0\n" + + // Skip loop if we are doing zero iterations of it. + "cbz %w[k], 4f\n" + + "1:\n" + "fmla v8.8h , %[b0].8h, %[a0].h[0]\n" + "fmla v9.8h , %[b0].8h, %[a0].h[1]\n" + "ldr %q[a0a], [%[a_ptr], #16]\n" + "fmla v10.8h, %[b0].8h, %[a0].h[2]\n" + "fmla v11.8h, %[b0].8h, %[a0].h[3]\n" + "ldr %q[b2a], [%[b_ptr], #80]\n" + "fmla v12.8h, %[b0].8h, %[a0].h[4]\n" + "fmla v13.8h, %[b0].8h, %[a0].h[5]\n" + "fmla v14.8h, %[b0].8h, %[a0].h[6]\n" + "fmla v15.8h, %[b0].8h, %[a0].h[7]\n" + "ldr %q[b0], [%[b_ptr], #96]\n" + + "fmla v16.8h, %[b1].8h, %[a0].h[0]\n" + "fmla v17.8h, %[b1].8h, %[a0].h[1]\n" ASM_PREFETCH("[%[a_ptr], #128]") + "fmla v18.8h, %[b1].8h, %[a0].h[2]\n" + "fmla v19.8h, %[b1].8h, %[a0].h[3]\n" + "add %[b_ptr], %[b_ptr], #96\n" + "fmla v20.8h, %[b1].8h, %[a0].h[4]\n" + "fmla v21.8h, %[b1].8h, %[a0].h[5]\n" + "fmla v22.8h, %[b1].8h, %[a0].h[6]\n" + "fmla v23.8h, %[b1].8h, %[a0].h[7]\n" + "ldr %q[b1], [%[b_ptr], #16]\n" + + "fmla v24.8h, %[b2].8h, %[a0].h[0]\n" + "fmla v25.8h, %[b2].8h, %[a0].h[1]\n" ASM_PREFETCH("[%[b_ptr], #288]") + "fmla v26.8h, %[b2].8h, %[a0].h[2]\n" + "fmla v27.8h, %[b2].8h, %[a0].h[3]\n" + "fmla v28.8h, %[b2].8h, %[a0].h[4]\n" + "fmla v29.8h, %[b2].8h, %[a0].h[5]\n" + "fmla v30.8h, %[b2].8h, %[a0].h[6]\n" + "fmla v31.8h, %[b2].8h, %[a0].h[7]\n" + "ldr %q[a0], [%[a_ptr], #32]\n" + + "fmla v8.8h , %[b0a].8h, %[a0a].h[0]\n" + "fmla v9.8h , %[b0a].8h, %[a0a].h[1]\n" + "ldr %q[b2], [%[b_ptr], #32]\n" + "fmla v10.8h, %[b0a].8h, %[a0a].h[2]\n" + "fmla v11.8h, %[b0a].8h, %[a0a].h[3]\n" + "fmla v12.8h, %[b0a].8h, %[a0a].h[4]\n" + "fmla v13.8h, %[b0a].8h, %[a0a].h[5]\n" + "fmla v14.8h, %[b0a].8h, %[a0a].h[6]\n" + "fmla v15.8h, %[b0a].8h, %[a0a].h[7]\n" + "ldr %q[b0a], [%[b_ptr], #48]\n" + + "fmla v16.8h, %[b1a].8h, %[a0a].h[0]\n" + "fmla v17.8h, %[b1a].8h, %[a0a].h[1]\n" ASM_PREFETCH("[%[b_ptr], #352]") + "fmla v18.8h, %[b1a].8h, %[a0a].h[2]\n" + "fmla v19.8h, %[b1a].8h, %[a0a].h[3]\n" + "fmla v20.8h, %[b1a].8h, %[a0a].h[4]\n" + "fmla v21.8h, %[b1a].8h, %[a0a].h[5]\n" + "fmla v22.8h, %[b1a].8h, %[a0a].h[6]\n" + "fmla v23.8h, %[b1a].8h, %[a0a].h[7]\n" + "ldr %q[b1a], [%[b_ptr], #64]\n" + + "fmla v24.8h, %[b2a].8h, %[a0a].h[0]\n" + "fmla v25.8h, %[b2a].8h, %[a0a].h[1]\n" + "add %[a_ptr], %[a_ptr], #32\n" + "fmla v26.8h, %[b2a].8h, %[a0a].h[2]\n" + "fmla v27.8h, %[b2a].8h, %[a0a].h[3]\n" + "fmla v28.8h, %[b2a].8h, %[a0a].h[4]\n" + "fmla v29.8h, %[b2a].8h, %[a0a].h[5]\n" + "subs %w[k], %w[k], #1\n" + "fmla v30.8h, %[b2a].8h, %[a0a].h[6]\n" + "fmla v31.8h, %[b2a].8h, %[a0a].h[7]\n" + + "bne 1b\n" + "4:\n" + + // Jump to odd tail if necessary. + "cbnz %w[oddk], 2f\n" + + // Even tail. + "fmla v8.8h , %[b0].8h, %[a0].h[0]\n" + "fmla v9.8h , %[b0].8h, %[a0].h[1]\n" + "ldr %q[a0a], [%[a_ptr], #16]\n" + "fmla v10.8h, %[b0].8h, %[a0].h[2]\n" + "fmla v11.8h, %[b0].8h, %[a0].h[3]\n" + "ldr %q[b2a], [%[b_ptr], #80]\n" + "fmla v12.8h, %[b0].8h, %[a0].h[4]\n" + "fmla v13.8h, %[b0].8h, %[a0].h[5]\n" + "fmla v14.8h, %[b0].8h, %[a0].h[6]\n" + "fmla v15.8h, %[b0].8h, %[a0].h[7]\n" + + "fmla v16.8h, %[b1].8h, %[a0].h[0]\n" + "fmla v17.8h, %[b1].8h, %[a0].h[1]\n" + "add %[b_ptr], %[b_ptr], #96\n" + "fmla v18.8h, %[b1].8h, %[a0].h[2]\n" + "fmla v19.8h, %[b1].8h, %[a0].h[3]\n" + "fmla v20.8h, %[b1].8h, %[a0].h[4]\n" + "fmla v21.8h, %[b1].8h, %[a0].h[5]\n" + "add %[a_ptr], %[a_ptr], #32\n" + "fmla v22.8h, %[b1].8h, %[a0].h[6]\n" + "fmla v23.8h, %[b1].8h, %[a0].h[7]\n" + + "fmla v24.8h, %[b2].8h, %[a0].h[0]\n" + "fmla v25.8h, %[b2].8h, %[a0].h[1]\n" + "fmla v26.8h, %[b2].8h, %[a0].h[2]\n" + "fmla v27.8h, %[b2].8h, %[a0].h[3]\n" + "fmla v28.8h, %[b2].8h, %[a0].h[4]\n" + "fmla v29.8h, %[b2].8h, %[a0].h[5]\n" + "fmla v30.8h, %[b2].8h, %[a0].h[6]\n" + "fmla v31.8h, %[b2].8h, %[a0].h[7]\n" + + "fmla v8.8h , %[b0a].8h, %[a0a].h[0]\n" + "fmla v16.8h, %[b1a].8h, %[a0a].h[0]\n" + "str q8, [%[c_ptr]]\n" + "fmla v24.8h, %[b2a].8h, %[a0a].h[0]\n" + "str q16, [%[c_ptr], #16]\n" + + "fmla v9.8h , %[b0a].8h, %[a0a].h[1]\n" + "str q24, [%[c_ptr], #32]\n" + "fmla v17.8h, %[b1a].8h, %[a0a].h[1]\n" + "str q9, [%[c_ptr], #48]\n" + "fmla v25.8h, %[b2a].8h, %[a0a].h[1]\n" + "str q17, [%[c_ptr], #64]\n" + + "fmla v10.8h, %[b0a].8h, %[a0a].h[2]\n" + "str q25, [%[c_ptr], #80]\n" + "fmla v18.8h, %[b1a].8h, %[a0a].h[2]\n" + "str q10, [%[c_ptr], #96]\n" + "fmla v26.8h, %[b2a].8h, %[a0a].h[2]\n" + "str q18, [%[c_ptr], #112]\n" + + "fmla v11.8h, %[b0a].8h, %[a0a].h[3]\n" + "str q26, [%[c_ptr], #128]\n" + "fmla v19.8h, %[b1a].8h, %[a0a].h[3]\n" + "str q11, [%[c_ptr], #144]\n" + "fmla v27.8h, %[b2a].8h, %[a0a].h[3]\n" + "str q19, [%[c_ptr], #160]\n" + + "fmla v12.8h, %[b0a].8h, %[a0a].h[4]\n" + "str q27, [%[c_ptr], #176]\n" + "fmla v20.8h, %[b1a].8h, %[a0a].h[4]\n" + "str q12, [%[c_ptr], #192]\n" + "fmla v28.8h, %[b2a].8h, %[a0a].h[4]\n" + "str q20, [%[c_ptr], #208]\n" + + "fmla v13.8h, %[b0a].8h, %[a0a].h[5]\n" + "str q28, [%[c_ptr], #224]\n" + "fmla v21.8h, %[b1a].8h, %[a0a].h[5]\n" + "str q13, [%[c_ptr], #240]\n" + "fmla v29.8h, %[b2a].8h, %[a0a].h[5]\n" + "str q21, [%[c_ptr], #256]\n" + + "fmla v14.8h, %[b0a].8h, %[a0a].h[6]\n" + "str q29, [%[c_ptr], #272]\n" + "fmla v22.8h, %[b1a].8h, %[a0a].h[6]\n" + "str q14, [%[c_ptr], #288]\n" + "fmla v30.8h, %[b2a].8h, %[a0a].h[6]\n" + "str q22, [%[c_ptr], #304]\n" + + "fmla v15.8h, %[b0a].8h, %[a0a].h[7]\n" + "str q30, [%[c_ptr], #320]\n" + "fmla v23.8h, %[b1a].8h, %[a0a].h[7]\n" + "str q15, [%[c_ptr], #336]\n" + "fmla v31.8h, %[b2a].8h, %[a0a].h[7]\n" + "b 3f\n" + + // Odd tail + "2:\n" + "fmla v8.8h , %[b0].8h, %[a0].h[0]\n" + "add %[b_ptr], %[b_ptr], #48\n" + "fmla v16.8h, %[b1].8h, %[a0].h[0]\n" + "add %[a_ptr], %[a_ptr], #16\n" + "str q8, [%[c_ptr]]\n" + "fmla v24.8h, %[b2].8h, %[a0].h[0]\n" + "str q16, [%[c_ptr], #16]\n" + + "fmla v9.8h , %[b0].8h, %[a0].h[1]\n" + "str q24, [%[c_ptr], #32]\n" + "fmla v17.8h, %[b1].8h, %[a0].h[1]\n" + "str q9, [%[c_ptr], #48]\n" + "fmla v25.8h, %[b2].8h, %[a0].h[1]\n" + "str q17, [%[c_ptr], #64]\n" + + "fmla v10.8h, %[b0].8h, %[a0].h[2]\n" + "str q25, [%[c_ptr], #80]\n" + "fmla v18.8h, %[b1].8h, %[a0].h[2]\n" + "str q10, [%[c_ptr], #96]\n" + "fmla v26.8h, %[b2].8h, %[a0].h[2]\n" + "str q18, [%[c_ptr], #112]\n" + + "fmla v11.8h, %[b0].8h, %[a0].h[3]\n" + "str q26, [%[c_ptr], #128]\n" + "fmla v19.8h, %[b1].8h, %[a0].h[3]\n" + "str q11, [%[c_ptr], #144]\n" + "fmla v27.8h, %[b2].8h, %[a0].h[3]\n" + "str q19, [%[c_ptr], #160]\n" + + "fmla v12.8h, %[b0].8h, %[a0].h[4]\n" + "str q27, [%[c_ptr], #176]\n" + "fmla v20.8h, %[b1].8h, %[a0].h[4]\n" + "str q12, [%[c_ptr], #192]\n" + "fmla v28.8h, %[b2].8h, %[a0].h[4]\n" + "str q20, [%[c_ptr], #208]\n" + + "fmla v13.8h, %[b0].8h, %[a0].h[5]\n" + "str q28, [%[c_ptr], #224]\n" + "fmla v21.8h, %[b1].8h, %[a0].h[5]\n" + "str q13, [%[c_ptr], #240]\n" + "fmla v29.8h, %[b2].8h, %[a0].h[5]\n" + "str q21, [%[c_ptr], #256]\n" + + "fmla v14.8h, %[b0].8h, %[a0].h[6]\n" + "str q29, [%[c_ptr], #272]\n" + "fmla v22.8h, %[b1].8h, %[a0].h[6]\n" + "str q14, [%[c_ptr], #288]\n" + "fmla v30.8h, %[b2].8h, %[a0].h[6]\n" + "str q22, [%[c_ptr], #304]\n" + + "fmla v15.8h, %[b0].8h, %[a0].h[7]\n" + "str q30, [%[c_ptr], #320]\n" + "fmla v23.8h, %[b1].8h, %[a0].h[7]\n" + "str q15, [%[c_ptr], #336]\n" + "fmla v31.8h, %[b2].8h, %[a0].h[7]\n" + + "3:\n" + "str q23, [%[c_ptr], #352]\n" + "str q31, [%[c_ptr], #368]\n" + "add %[c_ptr], %[c_ptr], #384\n" + : + [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [c_ptr] "+r"(c_ptr), + [a0] "+w"(a0), [a0a] "+w"(a0a), + [b0] "+w"(b0), [b1] "+w"(b1), [b2] "+w"(b2), [k] "+r"(k), + [b0a] "+w"(b0a), [b1a] "+w"(b1a), [b2a] "+w"(b2a) + : [oddk] "r"(oddk) + : "x20", "x21", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", + "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "cc"); + } + } +} + +} // namespace arm_gemm + +#endif // __aarch64__ && __ARM_FEATURE_FP16_SCALAR_ARITHMETIC diff --git a/src/core/NEON/kernels/arm_gemm/kernels/a64_sgemm_12x8.hpp b/src/core/NEON/kernels/arm_gemm/kernels/a64_sgemm_12x8.hpp new file mode 100644 index 0000000000..91a9e8de60 --- /dev/null +++ b/src/core/NEON/kernels/arm_gemm/kernels/a64_sgemm_12x8.hpp @@ -0,0 +1,95 @@ +/* + * Copyright (c) 2017-2018 ARM Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ +#pragma once + +#ifdef __aarch64__ + +namespace arm_gemm +{ +// Actual kernel implementations +void a64_sgemm_asimd_12x8(const float *, const float *, float *, int, int, int); +void a64_sgemm_asimd_12x8_a53(const float *, const float *, float *, int, int, int); +void a64_sgemm_asimd_12x8_a55(const float *, const float *, float *, int, int, int); +void a64_sgemm_asimd_12x8_a55r1(const float *, const float *, float *, int, int, int); + +// 12x8 SGEMM "strategy" class. +// +// This describes the characteristics of a family of kernels, in terms of +// the required interleave properties and the output block size. +// +// All kernels in the family must share these characteristics. The actual +// kernel to be used can be chosen at runtime, based on the CPU_type +// structure. +class sgemm_12x8 +{ +public: + typedef float operand_type; + typedef float result_type; + + typedef void (*kern_type)(const float *, const float *, float *, int, int, int); + + /* Describes the data layout for A input */ + static const int A_interleave = 8; + static const int A_block = 1; + static const int A_transpose = 0; + + /* Same for B input */ + static const int B_interleave = 12; + static const int B_block = 1; + static const int B_transpose = 1; + + /* Kernel blocking parameters */ + static const int out_width = 12; + static const int out_height = 8; + static const int k_unroll = 1; + + kern_type kernel = a64_sgemm_asimd_12x8; + + sgemm_12x8(const CPUInfo *ci) + { + // Select specific kernel if available + switch(ci->get_cpu_model()) + { + case CPUModel::A53: + kernel = a64_sgemm_asimd_12x8_a53; + break; + + case CPUModel::A55r0: + kernel = a64_sgemm_asimd_12x8_a55; + break; + + case CPUModel::A55r1: + kernel = a64_sgemm_asimd_12x8_a55r1; + break; + + default: + /* Generic kernel is initialized by default. */ + break; + } + } +}; + +} // namespace arm_gemm + +#endif // __aarch64__ diff --git a/src/core/NEON/kernels/arm_gemm/kernels/a64_sgemm_12x8/a53.cpp b/src/core/NEON/kernels/arm_gemm/kernels/a64_sgemm_12x8/a53.cpp new file mode 100644 index 0000000000..618ebc733c --- /dev/null +++ b/src/core/NEON/kernels/arm_gemm/kernels/a64_sgemm_12x8/a53.cpp @@ -0,0 +1,363 @@ +/* + * Copyright (c) 2017-2018 ARM Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ +#ifdef __aarch64__ + +#include <arm_neon.h> + +#include "../../asmlib.hpp" + +namespace arm_gemm +{ +void a64_sgemm_asimd_12x8_a53(const float *Apanel, const float *Bpanel, float *Cpanel, int ablocks, int bblocks, int K) +{ + const float *a_ptr = Apanel; + float *c_ptr = Cpanel; + + for(int yb = 0; yb < ablocks; yb++) + { + const float *a_ptr0 = a_ptr; + const float *b_ptr = Bpanel; + + for(int xb = 0; xb < bblocks; xb++) + { + a_ptr = a_ptr0; + // Fix up for odd lengths - set a flag if K is odd, but make + // sure we round up the iteration count. + int oddk = (K & 1); + int k = ((K + 1) / 2) - 1; + + register float32x4_t a0 asm("v0"); + register float32x4_t a1 asm("v1"); + register float32x4_t b0 asm("v2"); + register float32x4_t b1 asm("v3"); + register float32x4_t b2 asm("v4"); + register float32x4_t a0a asm("v5"); + register float32x4_t a1a asm("v6"); + + __asm __volatile( + // Initialize result registers, load initial operands, prime prefetches. + "movi v8.4s, #0x0\n" + "ldr %q[a0], [%[a_ptr]]\n" + "movi v9.4s, #0x0\n" + "ldr %q[b0], [%[b_ptr]]\n" + "movi v10.4s, #0x0\n" + "ldr %q[a1], [%[a_ptr], #16]\n" + "movi v11.4s, #0x0\n" + "ldr %q[b1], [%[b_ptr], #16]\n" + "movi v12.4s, #0x0\n" ASM_PREFETCH("[%[b_ptr], #64]") "movi v13.4s, #0x0\n" ASM_PREFETCH("[%[a_ptr], #64]") "movi v14.4s, #0x0\n" ASM_PREFETCH("[%[b_ptr], #128]") "movi v15.4s, #0x0\n" + ASM_PREFETCH("[%[a_ptr], #128]") "movi v16.4s, #0x0\n" ASM_PREFETCH("[%[b_ptr], #192]") "movi v17.4s, #0x0\n" ASM_PREFETCH("[%[b_ptr], #256]") "movi v18.4s, #0x0\n" + ASM_PREFETCH("[%[a_ptr], #192]") "movi v19.4s, #0x0\n" ASM_PREFETCH("[%[b_ptr], #320]") "movi v20.4s, #0x0\n" ASM_PREFETCH("[%[a_ptr], #256]") "movi v21.4s, #0x0\n" + ASM_PREFETCH("[%[b_ptr], #384]") + "movi v22.4s, #0x0\n" + "movi v23.4s, #0x0\n" + "movi v24.4s, #0x0\n" + "movi v25.4s, #0x0\n" + "movi v26.4s, #0x0\n" + "movi v27.4s, #0x0\n" + "movi v28.4s, #0x0\n" + "movi v29.4s, #0x0\n" + "movi v30.4s, #0x0\n" + "movi v31.4s, #0x0\n" + + // Skip loop if we are doing zero iterations of it. + "cbz %w[k], 4f\n" + + "1:\n" + // Unroll 0 + "ldr %d[b2], [%[b_ptr], #32]\n" + "nop\n" + "fmla v8.4s , %[b0].4s, %[a0].s[0]\n" + "ldr x20, [%[b_ptr], #40]\n" + "fmla v9.4s , %[b0].4s, %[a0].s[1]\n" + "subs %w[k], %w[k], #1\n" + "fmla v10.4s, %[b0].4s, %[a0].s[2]\n" + + "ldr %d[a0a], [%[a_ptr], #32]\n" + "ins %[b2].d[1], x20\n" + "fmla v11.4s, %[b0].4s, %[a0].s[3]\n" + "ldr x20, [%[a_ptr], #40]\n" + "fmla v12.4s, %[b0].4s, %[a1].s[0]\n" + "fmla v13.4s, %[b0].4s, %[a1].s[1]\n" + + "ldr %d[a1a], [%[a_ptr], #48]\n" + "ins %[a0a].d[1], x20\n" + "fmla v14.4s, %[b0].4s, %[a1].s[2]\n" + "ldr x20, [%[a_ptr], #56]\n" + "fmla v15.4s, %[b0].4s, %[a1].s[3]\n" + "fmla v16.4s, %[b1].4s, %[a0].s[0]\n" + + "ldr %d[b0], [%[b_ptr], #48]\n" + "ins %[a1a].d[1], x20\n" + "fmla v17.4s, %[b1].4s, %[a0].s[1]\n" + "ldr x20, [%[b_ptr], #56]\n" + "fmla v18.4s, %[b1].4s, %[a0].s[2]\n" + "fmla v19.4s, %[b1].4s, %[a0].s[3]\n" + + ASM_PREFETCH("[%[a_ptr], #320]") + "ins %[b0].d[1], x20\n" + "fmla v20.4s, %[b1].4s, %[a1].s[0]\n" + "fmla v21.4s, %[b1].4s, %[a1].s[1]\n" + "fmla v22.4s, %[b1].4s, %[a1].s[2]\n" + + ASM_PREFETCH("[%[b_ptr], #448]") + "nop\n" + "fmla v23.4s, %[b1].4s, %[a1].s[3]\n" + "fmla v24.4s, %[b2].4s, %[a0].s[0]\n" + "fmla v25.4s, %[b2].4s, %[a0].s[1]\n" + + "ldr %d[b1], [%[b_ptr], #64]\n" + "nop\n" + "fmla v26.4s, %[b2].4s, %[a0].s[2]\n" + "ldr x20, [%[b_ptr], #72]\n" + "fmla v27.4s, %[b2].4s, %[a0].s[3]\n" + "fmla v28.4s, %[b2].4s, %[a1].s[0]\n" + + ASM_PREFETCH("[%[b_ptr], #512]") + "ins %[b1].d[1], x20\n" + "fmla v29.4s, %[b2].4s, %[a1].s[1]\n" + "fmla v30.4s, %[b2].4s, %[a1].s[2]\n" + "fmla v31.4s, %[b2].4s, %[a1].s[3]\n" + + // Unroll 1 + "ldr %d[b2], [%[b_ptr], #80]\n" + "nop\n" + "fmla v8.4s , %[b0].4s, %[a0a].s[0]\n" + "ldr x20, [%[b_ptr], #88]\n" + "fmla v9.4s , %[b0].4s, %[a0a].s[1]\n" + "fmla v10.4s, %[b0].4s, %[a0a].s[2]\n" + + "ldr %d[a0], [%[a_ptr], #64]\n" + "ins %[b2].d[1], x20\n" + "fmla v11.4s, %[b0].4s, %[a0a].s[3]\n" + "ldr x20, [%[a_ptr], #72]\n" + "fmla v12.4s, %[b0].4s, %[a1a].s[0]\n" + "fmla v13.4s, %[b0].4s, %[a1a].s[1]\n" + + "ldr %d[a1], [%[a_ptr], #80]\n" + "ins %[a0].d[1], x20\n" + "fmla v14.4s, %[b0].4s, %[a1a].s[2]\n" + "ldr x20, [%[a_ptr], #88]\n" + "fmla v15.4s, %[b0].4s, %[a1a].s[3]\n" + "fmla v16.4s, %[b1].4s, %[a0a].s[0]\n" + + "ldr %d[b0], [%[b_ptr], #96]\n" + "ins %[a1].d[1], x20\n" + "fmla v17.4s, %[b1].4s, %[a0a].s[1]\n" + "ldr x20, [%[b_ptr], #104]\n" + "fmla v18.4s, %[b1].4s, %[a0a].s[2]\n" + "fmla v19.4s, %[b1].4s, %[a0a].s[3]\n" + + "nop\n" + "ins %[b0].d[1], x20\n" + "fmla v20.4s, %[b1].4s, %[a1a].s[0]\n" + "fmla v21.4s, %[b1].4s, %[a1a].s[1]\n" + "fmla v22.4s, %[b1].4s, %[a1a].s[2]\n" + + "nop\n" + "nop\n" + "fmla v23.4s, %[b1].4s, %[a1a].s[3]\n" + "fmla v24.4s, %[b2].4s, %[a0a].s[0]\n" + "fmla v25.4s, %[b2].4s, %[a0a].s[1]\n" + + "ldr %d[b1], [%[b_ptr], #112]\n" + "nop\n" + "fmla v26.4s, %[b2].4s, %[a0a].s[2]\n" + "ldr x20, [%[b_ptr], #120]\n" + "fmla v27.4s, %[b2].4s, %[a0a].s[3]\n" + "add %[a_ptr], %[a_ptr], #64\n" + "fmla v28.4s, %[b2].4s, %[a1a].s[0]\n" + "add %[b_ptr], %[b_ptr], #96\n" + + "nop\n" + "ins %[b1].d[1], x20\n" + "fmla v29.4s, %[b2].4s, %[a1a].s[1]\n" + "fmla v30.4s, %[b2].4s, %[a1a].s[2]\n" + "fmla v31.4s, %[b2].4s, %[a1a].s[3]\n" + + "bne 1b\n" + + // Branch here if K=1 or 2. Do the right thing for odd/even at the end. + "4:\n" + "cbnz %w[oddk], 2f\n" + + // Detached final iteration. (even K) + "ldr %d[b2], [%[b_ptr], #32]\n" + "nop\n" + "fmla v8.4s , %[b0].4s, %[a0].s[0]\n" + "ldr x20, [%[b_ptr], #40]\n" + "fmla v9.4s , %[b0].4s, %[a0].s[1]\n" + "subs %w[k], %w[k], #1\n" + "fmla v10.4s, %[b0].4s, %[a0].s[2]\n" + + "ldr %d[a0a], [%[a_ptr], #32]\n" + "ins %[b2].d[1], x20\n" + "fmla v11.4s, %[b0].4s, %[a0].s[3]\n" + "ldr x20, [%[a_ptr], #40]\n" + "fmla v12.4s, %[b0].4s, %[a1].s[0]\n" + "fmla v13.4s, %[b0].4s, %[a1].s[1]\n" + + "ldr %d[a1a], [%[a_ptr], #48]\n" + "ins %[a0a].d[1], x20\n" + "fmla v14.4s, %[b0].4s, %[a1].s[2]\n" + "ldr x20, [%[a_ptr], #56]\n" + "fmla v15.4s, %[b0].4s, %[a1].s[3]\n" + "fmla v16.4s, %[b1].4s, %[a0].s[0]\n" + + "ldr %d[b0], [%[b_ptr], #48]\n" + "ins %[a1a].d[1], x20\n" + "fmla v17.4s, %[b1].4s, %[a0].s[1]\n" + "ldr x20, [%[b_ptr], #56]\n" + "fmla v18.4s, %[b1].4s, %[a0].s[2]\n" + "fmla v19.4s, %[b1].4s, %[a0].s[3]\n" + + "ins %[b0].d[1], x20\n" + "fmla v20.4s, %[b1].4s, %[a1].s[0]\n" + "fmla v21.4s, %[b1].4s, %[a1].s[1]\n" + "fmla v22.4s, %[b1].4s, %[a1].s[2]\n" + + "nop\n" + "fmla v23.4s, %[b1].4s, %[a1].s[3]\n" + "fmla v24.4s, %[b2].4s, %[a0].s[0]\n" + "fmla v25.4s, %[b2].4s, %[a0].s[1]\n" + + "ldr %d[b1], [%[b_ptr], #64]\n" + "nop\n" + "fmla v26.4s, %[b2].4s, %[a0].s[2]\n" + "ldr x20, [%[b_ptr], #72]\n" + "fmla v27.4s, %[b2].4s, %[a0].s[3]\n" + "fmla v28.4s, %[b2].4s, %[a1].s[0]\n" + + "ins %[b1].d[1], x20\n" + "fmla v29.4s, %[b2].4s, %[a1].s[1]\n" + "fmla v30.4s, %[b2].4s, %[a1].s[2]\n" + "fmla v31.4s, %[b2].4s, %[a1].s[3]\n" + + "ldr %d[b2], [%[b_ptr], #80]\n" + "nop\n" + "fmla v8.4s , %[b0].4s, %[a0a].s[0]\n" + "ldr x20, [%[b_ptr], #88]\n" + "fmla v9.4s , %[b0].4s, %[a0a].s[1]\n" + "fmla v10.4s, %[b0].4s, %[a0a].s[2]\n" + + "ins %[b2].d[1], x20\n" + "fmla v11.4s, %[b0].4s, %[a0a].s[3]\n" + "fmla v12.4s, %[b0].4s, %[a1a].s[0]\n" + "fmla v13.4s, %[b0].4s, %[a1a].s[1]\n" + "fmla v14.4s, %[b0].4s, %[a1a].s[2]\n" + "fmla v15.4s, %[b0].4s, %[a1a].s[3]\n" + "fmla v16.4s, %[b1].4s, %[a0a].s[0]\n" + "fmla v17.4s, %[b1].4s, %[a0a].s[1]\n" + "fmla v18.4s, %[b1].4s, %[a0a].s[2]\n" + "fmla v19.4s, %[b1].4s, %[a0a].s[3]\n" + "fmla v20.4s, %[b1].4s, %[a1a].s[0]\n" + "fmla v21.4s, %[b1].4s, %[a1a].s[1]\n" + "fmla v22.4s, %[b1].4s, %[a1a].s[2]\n" + "fmla v23.4s, %[b1].4s, %[a1a].s[3]\n" + "fmla v24.4s, %[b2].4s, %[a0a].s[0]\n" + "fmla v25.4s, %[b2].4s, %[a0a].s[1]\n" + "fmla v26.4s, %[b2].4s, %[a0a].s[2]\n" + "fmla v27.4s, %[b2].4s, %[a0a].s[3]\n" + "fmla v28.4s, %[b2].4s, %[a1a].s[0]\n" + "fmla v29.4s, %[b2].4s, %[a1a].s[1]\n" + "add %[a_ptr], %[a_ptr], #64\n" + "fmla v30.4s, %[b2].4s, %[a1a].s[2]\n" + "add %[b_ptr], %[b_ptr], #96\n" + "fmla v31.4s, %[b2].4s, %[a1a].s[3]\n" + "b 3f\n" + + // Detached final iteration. (odd K) + "2:\n" + "ldr %d[b2], [%[b_ptr], #32]\n" + "nop\n" + "fmla v8.4s , %[b0].4s, %[a0].s[0]\n" + "ldr x20, [%[b_ptr], #40]\n" + "fmla v9.4s , %[b0].4s, %[a0].s[1]\n" + "fmla v10.4s, %[b0].4s, %[a0].s[2]\n" + + "ins %[b2].d[1], x20\n" + "fmla v11.4s, %[b0].4s, %[a0].s[3]\n" + "fmla v12.4s, %[b0].4s, %[a1].s[0]\n" + "fmla v13.4s, %[b0].4s, %[a1].s[1]\n" + "fmla v14.4s, %[b0].4s, %[a1].s[2]\n" + "fmla v15.4s, %[b0].4s, %[a1].s[3]\n" + "fmla v16.4s, %[b1].4s, %[a0].s[0]\n" + "fmla v17.4s, %[b1].4s, %[a0].s[1]\n" + "fmla v18.4s, %[b1].4s, %[a0].s[2]\n" + "fmla v19.4s, %[b1].4s, %[a0].s[3]\n" + "fmla v20.4s, %[b1].4s, %[a1].s[0]\n" + "fmla v21.4s, %[b1].4s, %[a1].s[1]\n" + "fmla v22.4s, %[b1].4s, %[a1].s[2]\n" + "fmla v23.4s, %[b1].4s, %[a1].s[3]\n" + "fmla v24.4s, %[b2].4s, %[a0].s[0]\n" + "fmla v25.4s, %[b2].4s, %[a0].s[1]\n" + "fmla v26.4s, %[b2].4s, %[a0].s[2]\n" + "fmla v27.4s, %[b2].4s, %[a0].s[3]\n" + "fmla v28.4s, %[b2].4s, %[a1].s[0]\n" + "fmla v29.4s, %[b2].4s, %[a1].s[1]\n" + "add %[a_ptr], %[a_ptr], #32\n" + "fmla v30.4s, %[b2].4s, %[a1].s[2]\n" + "add %[b_ptr], %[b_ptr], #48\n" + "fmla v31.4s, %[b2].4s, %[a1].s[3]\n" + + // Common tail + "3:\n" + "str q8, [%[c_ptr]]\n" + "str q16, [%[c_ptr], #16]\n" + "str q24, [%[c_ptr], #32]\n" + "str q9, [%[c_ptr], #48]\n" + "str q17, [%[c_ptr], #64]\n" + "str q25, [%[c_ptr], #80]\n" + "str q10, [%[c_ptr], #96]\n" + "str q18, [%[c_ptr], #112]\n" + "str q26, [%[c_ptr], #128]\n" + "str q11, [%[c_ptr], #144]\n" + "str q19, [%[c_ptr], #160]\n" + "str q27, [%[c_ptr], #176]\n" + "str q12, [%[c_ptr], #192]\n" + "str q20, [%[c_ptr], #208]\n" + "str q28, [%[c_ptr], #224]\n" + "str q13, [%[c_ptr], #240]\n" + "str q21, [%[c_ptr], #256]\n" + "str q29, [%[c_ptr], #272]\n" + "str q14, [%[c_ptr], #288]\n" + "str q22, [%[c_ptr], #304]\n" + "str q30, [%[c_ptr], #320]\n" + "str q15, [%[c_ptr], #336]\n" + "str q23, [%[c_ptr], #352]\n" + "str q31, [%[c_ptr], #368]\n" + "add %[c_ptr], %[c_ptr], #384\n" + : + [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [c_ptr] "+r"(c_ptr), + [a0] "+w"(a0), [a1] "+w"(a1), [a0a] "+w"(a0a), [a1a] "+w"(a1a), + [b0] "+w"(b0), [b1] "+w"(b1), [b2] "+w"(b2), [k] "+r"(k) + : [oddk] "r"(oddk) + : "x20", "x21", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", + "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "cc"); + } + } +} + +} // namespace arm_gemm + +#endif diff --git a/src/core/NEON/kernels/arm_gemm/kernels/a64_sgemm_12x8/a55.cpp b/src/core/NEON/kernels/arm_gemm/kernels/a64_sgemm_12x8/a55.cpp new file mode 100644 index 0000000000..4ca25eb5ba --- /dev/null +++ b/src/core/NEON/kernels/arm_gemm/kernels/a64_sgemm_12x8/a55.cpp @@ -0,0 +1,356 @@ +/* + * Copyright (c) 2017-2018 ARM Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ +#ifdef __aarch64__ + +#include <arm_neon.h> + +#include "../../asmlib.hpp" + +namespace arm_gemm +{ +void a64_sgemm_asimd_12x8_a55(const float *Apanel, const float *Bpanel, float *Cpanel, int ablocks, int bblocks, int K) +{ + const float *a_ptr = Apanel; + float *c_ptr = Cpanel; + + for(int yb = 0; yb < ablocks; yb++) + { + const float *a_ptr0 = a_ptr; + const float *b_ptr = Bpanel; + + for(int xb = 0; xb < bblocks; xb++) + { + a_ptr = a_ptr0; + // Fix up for odd lengths - set a flag if K is odd, but make + // sure we round up the iteration count. + int oddk = (K & 1); + int k = ((K + 1) / 2) - 1; + + register float32x4_t a0 asm("v0"); + register float32x4_t a1 asm("v1"); + register float32x4_t b0 asm("v2"); + register float32x4_t b1 asm("v3"); + register float32x4_t b2 asm("v4"); + register float32x4_t a0a asm("v5"); + register float32x4_t a1a asm("v6"); + + __asm __volatile( + // Initialize result registers, load initial operands, prime prefetches. + "movi v8.4s, #0x0\n" + "ldr %q[a0], [%[a_ptr]]\n" + "movi v9.4s, #0x0\n" + "ldr %q[b0], [%[b_ptr]]\n" + "movi v10.4s, #0x0\n" + "ldr %q[a1], [%[a_ptr], #16]\n" + "movi v11.4s, #0x0\n" + "ldr %q[b1], [%[b_ptr], #16]\n" + "movi v12.4s, #0x0\n" ASM_PREFETCH("[%[b_ptr], #64]") "movi v13.4s, #0x0\n" ASM_PREFETCH("[%[a_ptr], #64]") "movi v14.4s, #0x0\n" ASM_PREFETCH("[%[b_ptr], #128]") "movi v15.4s, #0x0\n" + ASM_PREFETCH("[%[a_ptr], #128]") "movi v16.4s, #0x0\n" ASM_PREFETCH("[%[b_ptr], #192]") "movi v17.4s, #0x0\n" ASM_PREFETCH("[%[b_ptr], #256]") "movi v18.4s, #0x0\n" + ASM_PREFETCH("[%[a_ptr], #192]") "movi v19.4s, #0x0\n" ASM_PREFETCH("[%[b_ptr], #320]") "movi v20.4s, #0x0\n" ASM_PREFETCH("[%[a_ptr], #256]") "movi v21.4s, #0x0\n" + ASM_PREFETCH("[%[b_ptr], #384]") + "movi v22.4s, #0x0\n" + "movi v23.4s, #0x0\n" + "movi v24.4s, #0x0\n" + "movi v25.4s, #0x0\n" + "movi v26.4s, #0x0\n" + "movi v27.4s, #0x0\n" + "movi v28.4s, #0x0\n" + "movi v29.4s, #0x0\n" + "movi v30.4s, #0x0\n" + "movi v31.4s, #0x0\n" + + // Skip loop if we are doing zero iterations of it. + "cbz %w[k], 4f\n" + + "1:\n" + // Unroll 0 + "ldr %d[b2], [%[b_ptr], #32]\n" + "fmla v8.4s , %[b0].4s, %[a0].s[0]\n" + + "fmla v9.4s , %[b0].4s, %[a0].s[1]\n" + "ldr x20, [%[b_ptr], #40]\n" + "fmla v10.4s, %[b0].4s, %[a0].s[2]\n" + "fmla v11.4s, %[b0].4s, %[a0].s[3]\n" + "subs %w[k], %w[k], #1\n" + + "ldr %d[a0a], [%[a_ptr], #32]\n" + "ins %[b2].d[1], x20\n" + + "fmla v12.4s, %[b0].4s, %[a1].s[0]\n" + "fmla v13.4s, %[b0].4s, %[a1].s[1]\n" + "ldr x20, [%[a_ptr], #40]\n" + "fmla v14.4s, %[b0].4s, %[a1].s[2]\n" + "fmla v15.4s, %[b0].4s, %[a1].s[3]\n" + + "ldr %d[a1a], [%[a_ptr], #48]\n" + "ins %[a0a].d[1], x20\n" + + "fmla v16.4s, %[b1].4s, %[a0].s[0]\n" + "fmla v17.4s, %[b1].4s, %[a0].s[1]\n" + "ldr x20, [%[a_ptr], #56]\n" + "fmla v18.4s, %[b1].4s, %[a0].s[2]\n" + "fmla v19.4s, %[b1].4s, %[a0].s[3]\n" + + "ldr %d[b0], [%[b_ptr], #48]\n" + "ins %[a1a].d[1], x20\n" ASM_PREFETCH("[%[a_ptr], #320]") + "fmla v20.4s, %[b1].4s, %[a1].s[0]\n" + "fmla v21.4s, %[b1].4s, %[a1].s[1]\n" + "ldr x20, [%[b_ptr], #56]\n" + "fmla v22.4s, %[b1].4s, %[a1].s[2]\n" + "fmla v23.4s, %[b1].4s, %[a1].s[3]\n" + + "ldr %d[b1], [%[b_ptr], #64]\n" + "ins %[b0].d[1], x20\n" + + "fmla v24.4s, %[b2].4s, %[a0].s[0]\n" + "fmla v25.4s, %[b2].4s, %[a0].s[1]\n" + "ldr x20, [%[b_ptr], #72]\n" + "fmla v26.4s, %[b2].4s, %[a0].s[2]\n" + "fmla v27.4s, %[b2].4s, %[a0].s[3]\n" ASM_PREFETCH("[%[b_ptr], #448]") + + "fmla v28.4s, %[b2].4s, %[a1].s[0]\n" + "fmla v29.4s, %[b2].4s, %[a1].s[1]\n" ASM_PREFETCH("[%[b_ptr], #512]") + "fmla v30.4s, %[b2].4s, %[a1].s[2]\n" + "fmla v31.4s, %[b2].4s, %[a1].s[3]\n" + + // Unroll 1 + "ldr %d[b2], [%[b_ptr], #80]\n" + "ins %[b1].d[1], x20\n" + + "fmla v8.4s , %[b0].4s, %[a0a].s[0]\n" + "fmla v9.4s , %[b0].4s, %[a0a].s[1]\n" + "ldr x20, [%[b_ptr], #88]\n" + "fmla v10.4s, %[b0].4s, %[a0a].s[2]\n" + "fmla v11.4s, %[b0].4s, %[a0a].s[3]\n" + + "ldr %d[a0], [%[a_ptr], #64]\n" + "ins %[b2].d[1], x20\n" + + "fmla v12.4s, %[b0].4s, %[a1a].s[0]\n" + "fmla v13.4s, %[b0].4s, %[a1a].s[1]\n" + "ldr x20, [%[a_ptr], #72]\n" + "fmla v14.4s, %[b0].4s, %[a1a].s[2]\n" + "fmla v15.4s, %[b0].4s, %[a1a].s[3]\n" + + "ldr %d[a1], [%[a_ptr], #80]\n" + "ins %[a0].d[1], x20\n" + + "fmla v16.4s, %[b1].4s, %[a0a].s[0]\n" + "fmla v17.4s, %[b1].4s, %[a0a].s[1]\n" + "ldr x20, [%[a_ptr], #88]\n" + "fmla v18.4s, %[b1].4s, %[a0a].s[2]\n" + "fmla v19.4s, %[b1].4s, %[a0a].s[3]\n" + + "ldr %d[b0], [%[b_ptr], #96]\n" + "ins %[a1].d[1], x20\n" + + "fmla v20.4s, %[b1].4s, %[a1a].s[0]\n" + "fmla v21.4s, %[b1].4s, %[a1a].s[1]\n" + "ldr x20, [%[b_ptr], #104]\n" + "fmla v22.4s, %[b1].4s, %[a1a].s[2]\n" + "fmla v23.4s, %[b1].4s, %[a1a].s[3]\n" + + "ldr %d[b1], [%[b_ptr], #112]\n" + "ins %[b0].d[1], x20\n" + + "fmla v24.4s, %[b2].4s, %[a0a].s[0]\n" + "fmla v25.4s, %[b2].4s, %[a0a].s[1]\n" + "ldr x20, [%[b_ptr], #120]\n" + "fmla v26.4s, %[b2].4s, %[a0a].s[2]\n" + "fmla v27.4s, %[b2].4s, %[a0a].s[3]\n" + "add %[a_ptr], %[a_ptr], #64\n" + + "fmla v28.4s, %[b2].4s, %[a1a].s[0]\n" + "fmla v29.4s, %[b2].4s, %[a1a].s[1]\n" + "fmla v30.4s, %[b2].4s, %[a1a].s[2]\n" + "add %[b_ptr], %[b_ptr], #96\n" + "fmla v31.4s, %[b2].4s, %[a1a].s[3]\n" + + "ldr %d[b2], [%[b_ptr], #32]\n" + "ins %[b1].d[1], x20\n" + + "bne 1b\n" + + // Branch here if K=1 or 2. Do the right thing for odd/even at the end. + "4:\n" + "cbnz %w[oddk], 2f\n" + "fmla v8.4s , %[b0].4s, %[a0].s[0]\n" + + // Detached final iteration. (even K) + "ldr x20, [%[b_ptr], #40]\n" + "fmla v9.4s , %[b0].4s, %[a0].s[1]\n" + "subs %w[k], %w[k], #1\n" + "fmla v10.4s, %[b0].4s, %[a0].s[2]\n" + "fmla v11.4s, %[b0].4s, %[a0].s[3]\n" + + "ldr %d[a0a], [%[a_ptr], #32]\n" + "ins %[b2].d[1], x20\n" + + "fmla v12.4s, %[b0].4s, %[a1].s[0]\n" + "fmla v13.4s, %[b0].4s, %[a1].s[1]\n" + "ldr x20, [%[a_ptr], #40]\n" + "fmla v14.4s, %[b0].4s, %[a1].s[2]\n" + "fmla v15.4s, %[b0].4s, %[a1].s[3]\n" + + "ldr %d[a1a], [%[a_ptr], #48]\n" + "ins %[a0a].d[1], x20\n" + + "fmla v16.4s, %[b1].4s, %[a0].s[0]\n" + "fmla v17.4s, %[b1].4s, %[a0].s[1]\n" + "ldr x20, [%[a_ptr], #56]\n" + "fmla v18.4s, %[b1].4s, %[a0].s[2]\n" + "fmla v19.4s, %[b1].4s, %[a0].s[3]\n" + + "ldr %d[b0], [%[b_ptr], #48]\n" + "ins %[a1a].d[1], x20\n" + + "fmla v20.4s, %[b1].4s, %[a1].s[0]\n" + "fmla v21.4s, %[b1].4s, %[a1].s[1]\n" + "ldr x20, [%[b_ptr], #56]\n" + "fmla v22.4s, %[b1].4s, %[a1].s[2]\n" + "fmla v23.4s, %[b1].4s, %[a1].s[3]\n" + + "fmla v24.4s, %[b2].4s, %[a0].s[0]\n" + "fmla v25.4s, %[b2].4s, %[a0].s[1]\n" + "fmla v26.4s, %[b2].4s, %[a0].s[2]\n" + "fmla v27.4s, %[b2].4s, %[a0].s[3]\n" + + "ldr %d[b1], [%[b_ptr], #64]\n" + "ins %[b0].d[1], x20\n" + + "fmla v28.4s, %[b2].4s, %[a1].s[0]\n" + "fmla v29.4s, %[b2].4s, %[a1].s[1]\n" + "ldr x20, [%[b_ptr], #72]\n" + "fmla v30.4s, %[b2].4s, %[a1].s[2]\n" + "fmla v31.4s, %[b2].4s, %[a1].s[3]\n" + + "ldr %d[b2], [%[b_ptr], #80]\n" + "ins %[b1].d[1], x20\n" + + "fmla v8.4s , %[b0].4s, %[a0a].s[0]\n" + "fmla v9.4s , %[b0].4s, %[a0a].s[1]\n" + "ldr x20, [%[b_ptr], #88]\n" + "fmla v10.4s, %[b0].4s, %[a0a].s[2]\n" + + "ins %[b2].d[1], x20\n" + "fmla v11.4s, %[b0].4s, %[a0a].s[3]\n" + "fmla v12.4s, %[b0].4s, %[a1a].s[0]\n" + "fmla v13.4s, %[b0].4s, %[a1a].s[1]\n" + "fmla v14.4s, %[b0].4s, %[a1a].s[2]\n" + "fmla v15.4s, %[b0].4s, %[a1a].s[3]\n" + "fmla v16.4s, %[b1].4s, %[a0a].s[0]\n" + "fmla v17.4s, %[b1].4s, %[a0a].s[1]\n" + "fmla v18.4s, %[b1].4s, %[a0a].s[2]\n" + "fmla v19.4s, %[b1].4s, %[a0a].s[3]\n" + "fmla v20.4s, %[b1].4s, %[a1a].s[0]\n" + "fmla v21.4s, %[b1].4s, %[a1a].s[1]\n" + "fmla v22.4s, %[b1].4s, %[a1a].s[2]\n" + "fmla v23.4s, %[b1].4s, %[a1a].s[3]\n" + "fmla v24.4s, %[b2].4s, %[a0a].s[0]\n" + "fmla v25.4s, %[b2].4s, %[a0a].s[1]\n" + "fmla v26.4s, %[b2].4s, %[a0a].s[2]\n" + "fmla v27.4s, %[b2].4s, %[a0a].s[3]\n" + "fmla v28.4s, %[b2].4s, %[a1a].s[0]\n" + "fmla v29.4s, %[b2].4s, %[a1a].s[1]\n" + "add %[a_ptr], %[a_ptr], #64\n" + "fmla v30.4s, %[b2].4s, %[a1a].s[2]\n" + "add %[b_ptr], %[b_ptr], #96\n" + "fmla v31.4s, %[b2].4s, %[a1a].s[3]\n" + "b 3f\n" + + // Detached final iteration. (odd K) + "2:\n" + + "ldr %d[b2], [%[b_ptr], #32]\n" + "fmla v8.4s , %[b0].4s, %[a0].s[0]\n" + + "fmla v9.4s , %[b0].4s, %[a0].s[1]\n" + "ldr x20, [%[b_ptr], #40]\n" + "fmla v10.4s, %[b0].4s, %[a0].s[2]\n" + "fmla v11.4s, %[b0].4s, %[a0].s[3]\n" + "ins %[b2].d[1], x20\n" + "fmla v12.4s, %[b0].4s, %[a1].s[0]\n" + "fmla v13.4s, %[b0].4s, %[a1].s[1]\n" + "fmla v14.4s, %[b0].4s, %[a1].s[2]\n" + "fmla v15.4s, %[b0].4s, %[a1].s[3]\n" + "fmla v16.4s, %[b1].4s, %[a0].s[0]\n" + "fmla v17.4s, %[b1].4s, %[a0].s[1]\n" + "fmla v18.4s, %[b1].4s, %[a0].s[2]\n" + "fmla v19.4s, %[b1].4s, %[a0].s[3]\n" + "fmla v20.4s, %[b1].4s, %[a1].s[0]\n" + "fmla v21.4s, %[b1].4s, %[a1].s[1]\n" + "fmla v22.4s, %[b1].4s, %[a1].s[2]\n" + "fmla v23.4s, %[b1].4s, %[a1].s[3]\n" + "fmla v24.4s, %[b2].4s, %[a0].s[0]\n" + "fmla v25.4s, %[b2].4s, %[a0].s[1]\n" + "fmla v26.4s, %[b2].4s, %[a0].s[2]\n" + "fmla v27.4s, %[b2].4s, %[a0].s[3]\n" + "fmla v28.4s, %[b2].4s, %[a1].s[0]\n" + "fmla v29.4s, %[b2].4s, %[a1].s[1]\n" + "add %[a_ptr], %[a_ptr], #32\n" + "fmla v30.4s, %[b2].4s, %[a1].s[2]\n" + "add %[b_ptr], %[b_ptr], #48\n" + "fmla v31.4s, %[b2].4s, %[a1].s[3]\n" + + // Common tail + "3:\n" + "str q8, [%[c_ptr]]\n" + "str q16, [%[c_ptr], #16]\n" + "str q24, [%[c_ptr], #32]\n" + "str q9, [%[c_ptr], #48]\n" + "str q17, [%[c_ptr], #64]\n" + "str q25, [%[c_ptr], #80]\n" + "str q10, [%[c_ptr], #96]\n" + "str q18, [%[c_ptr], #112]\n" + "str q26, [%[c_ptr], #128]\n" + "str q11, [%[c_ptr], #144]\n" + "str q19, [%[c_ptr], #160]\n" + "str q27, [%[c_ptr], #176]\n" + "str q12, [%[c_ptr], #192]\n" + "str q20, [%[c_ptr], #208]\n" + "str q28, [%[c_ptr], #224]\n" + "str q13, [%[c_ptr], #240]\n" + "str q21, [%[c_ptr], #256]\n" + "str q29, [%[c_ptr], #272]\n" + "str q14, [%[c_ptr], #288]\n" + "str q22, [%[c_ptr], #304]\n" + "str q30, [%[c_ptr], #320]\n" + "str q15, [%[c_ptr], #336]\n" + "str q23, [%[c_ptr], #352]\n" + "str q31, [%[c_ptr], #368]\n" + "add %[c_ptr], %[c_ptr], #384\n" + : + [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [c_ptr] "+r"(c_ptr), + [a0] "+w"(a0), [a1] "+w"(a1), [a0a] "+w"(a0a), [a1a] "+w"(a1a), + [b0] "+w"(b0), [b1] "+w"(b1), [b2] "+w"(b2), [k] "+r"(k) + : [oddk] "r"(oddk) + : "x20", "x21", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", + "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "cc"); + } + } +} + +} // namespace arm_gemm + +#endif diff --git a/src/core/NEON/kernels/arm_gemm/kernels/a64_sgemm_12x8/a55r1.cpp b/src/core/NEON/kernels/arm_gemm/kernels/a64_sgemm_12x8/a55r1.cpp new file mode 100644 index 0000000000..89fe6ac7ea --- /dev/null +++ b/src/core/NEON/kernels/arm_gemm/kernels/a64_sgemm_12x8/a55r1.cpp @@ -0,0 +1,342 @@ +/* + * Copyright (c) 2017-2018 ARM Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ +#ifdef __aarch64__ + +#include <arm_neon.h> + +#include "../../asmlib.hpp" + +namespace arm_gemm +{ +void a64_sgemm_asimd_12x8_a55r1(const float *Apanel, const float *Bpanel, float *Cpanel, const int ablocks, const int bblocks, const int K) +{ + const float *a_ptr = Apanel; + float *c_ptr = Cpanel; + + // Fix up for odd lengths - set a flag if K is odd, but make + // sure we round up the iteration count. + int oddk = (K & 1); + int k_iters = ((K + 1) / 2) - 1; + + for(int yb = 0; yb < ablocks; yb++) + { + const float *a_ptr0 = a_ptr; + const float *b_ptr = Bpanel; + + for(int xb = 0; xb < bblocks; xb++) + { + a_ptr = a_ptr0; + int k = k_iters; + + register float32x4_t a0 asm("v0"); + register float32x4_t a1 asm("v1"); + register float32x4_t b0 asm("v2"); + register float32x4_t b1 asm("v3"); + register float32x4_t b2 asm("v4"); + register float32x4_t a0a asm("v5"); + register float32x4_t a1a asm("v6"); + + __asm __volatile( + // Initialize result registers, load initial operands, prime prefetches. + "movi v8.4s, #0x0\n" + "ldr %q[a0], [%[a_ptr]]\n" + "movi v9.4s, #0x0\n" + "ldr %q[b0], [%[b_ptr]]\n" + "movi v10.4s, #0x0\n" + "ldr %q[a1], [%[a_ptr], #16]\n" + "movi v11.4s, #0x0\n" + "ldr %q[b1], [%[b_ptr], #16]\n" + "movi v12.4s, #0x0\n" ASM_PREFETCH("[%[b_ptr], #64]") "movi v13.4s, #0x0\n" ASM_PREFETCH("[%[a_ptr], #64]") "movi v14.4s, #0x0\n" ASM_PREFETCH("[%[b_ptr], #128]") "movi v15.4s, #0x0\n" + ASM_PREFETCH("[%[a_ptr], #128]") "movi v16.4s, #0x0\n" ASM_PREFETCH("[%[b_ptr], #192]") "movi v17.4s, #0x0\n" ASM_PREFETCH("[%[b_ptr], #256]") + "movi v18.4s, #0x0\n" + "movi v19.4s, #0x0\n" ASM_PREFETCH("[%[a_ptr], #192]") + "movi v20.4s, #0x0\n" + "movi v21.4s, #0x0\n" ASM_PREFETCH("[%[b_ptr], #320]") + "movi v22.4s, #0x0\n" + "movi v23.4s, #0x0\n" ASM_PREFETCH("[%[a_ptr], #256]") + "movi v24.4s, #0x0\n" + "movi v25.4s, #0x0\n" ASM_PREFETCH("[%[b_ptr], #384]") + "movi v26.4s, #0x0\n" + "movi v27.4s, #0x0\n" ASM_PREFETCH("[%[b_ptr], #448]") + "movi v28.4s, #0x0\n" + "movi v29.4s, #0x0\n" ASM_PREFETCH("[%[a_ptr], #384]") + "movi v30.4s, #0x0\n" + "movi v31.4s, #0x0\n" ASM_PREFETCH("[%[b_ptr], #512]") + + // The loop is offset by these two instructions which must + // always be executed. + "fmla v8.4s , %[b0].4s, %[a0].s[0]\n" + "ldr %d[b2], [%[b_ptr], #32]\n" + + // Skip loop if we are doing zero iterations of it. + "cbz %w[k], 4f\n" + + "1:\n" + // Unroll 0 + "fmla v9.4s , %[b0].4s, %[a0].s[1]\n" + "ldr x20, [%[b_ptr], #40]\n" + "fmla v10.4s, %[b0].4s, %[a0].s[2]\n" + "subs %w[k], %w[k], #1\n" + "fmla v11.4s, %[b0].4s, %[a0].s[3]\n" + "ldr %d[a0a], [%[a_ptr], #32]\n" + + "fmla v12.4s, %[b0].4s, %[a1].s[0]\n" + "ins %[b2].d[1], x20\n" + "fmla v13.4s, %[b0].4s, %[a1].s[1]\n" + "ldr x20, [%[a_ptr], #40]\n" + "fmla v14.4s, %[b0].4s, %[a1].s[2]\n" + "fmla v15.4s, %[b0].4s, %[a1].s[3]\n" + "ldr %d[a1a], [%[a_ptr], #48]\n" + + "fmla v16.4s, %[b1].4s, %[a0].s[0]\n" + "ins %[a0a].d[1], x20\n" + "fmla v17.4s, %[b1].4s, %[a0].s[1]\n" + "ldr x20, [%[a_ptr], #56]\n" + "fmla v18.4s, %[b1].4s, %[a0].s[2]\n" + "fmla v19.4s, %[b1].4s, %[a0].s[3]\n" + "ldr %d[b0], [%[b_ptr], #48]\n" + + "fmla v20.4s, %[b1].4s, %[a1].s[0]\n" + "ins %[a1a].d[1], x20\n" + "fmla v21.4s, %[b1].4s, %[a1].s[1]\n" + "ldr x20, [%[b_ptr], #56]\n" + "fmla v22.4s, %[b1].4s, %[a1].s[2]\n" + "fmla v23.4s, %[b1].4s, %[a1].s[3]\n" + "ldr %d[b1], [%[b_ptr], #64]\n" + + "fmla v24.4s, %[b2].4s, %[a0].s[0]\n" + "ins %[b0].d[1], x20\n" + "fmla v25.4s, %[b2].4s, %[a0].s[1]\n" + "ldr x20, [%[b_ptr], #72]\n" + "fmla v26.4s, %[b2].4s, %[a0].s[2]\n" + "fmla v27.4s, %[b2].4s, %[a0].s[3]\n" ASM_PREFETCH("[%[a_ptr], #448]") + + "fmla v28.4s, %[b2].4s, %[a1].s[0]\n" + "fmla v29.4s, %[b2].4s, %[a1].s[1]\n" ASM_PREFETCH("[%[b_ptr], #576]") + "fmla v30.4s, %[b2].4s, %[a1].s[2]\n" + "fmla v31.4s, %[b2].4s, %[a1].s[3]\n" + + // Unroll 1 + "ldr %d[b2], [%[b_ptr], #80]\n" + + "fmla v8.4s , %[b0].4s, %[a0a].s[0]\n" + "ins %[b1].d[1], x20\n" + "fmla v9.4s , %[b0].4s, %[a0a].s[1]\n" + "ldr x20, [%[b_ptr], #88]\n" + "fmla v10.4s, %[b0].4s, %[a0a].s[2]\n" + "fmla v11.4s, %[b0].4s, %[a0a].s[3]\n" + "ldr %d[a0], [%[a_ptr], #64]\n" + + "fmla v12.4s, %[b0].4s, %[a1a].s[0]\n" + "ins %[b2].d[1], x20\n" + "fmla v13.4s, %[b0].4s, %[a1a].s[1]\n" + "ldr x20, [%[a_ptr], #72]\n" + "fmla v14.4s, %[b0].4s, %[a1a].s[2]\n" + "fmla v15.4s, %[b0].4s, %[a1a].s[3]\n" + "ldr %d[a1], [%[a_ptr], #80]\n" + + "fmla v16.4s, %[b1].4s, %[a0a].s[0]\n" + "ins %[a0].d[1], x20\n" + "fmla v17.4s, %[b1].4s, %[a0a].s[1]\n" + "ldr x20, [%[a_ptr], #88]\n" + "fmla v18.4s, %[b1].4s, %[a0a].s[2]\n" + "fmla v19.4s, %[b1].4s, %[a0a].s[3]\n" + "ldr %d[b0], [%[b_ptr], #96]\n" + + "fmla v20.4s, %[b1].4s, %[a1a].s[0]\n" + "ins %[a1].d[1], x20\n" + "fmla v21.4s, %[b1].4s, %[a1a].s[1]\n" + "ldr x20, [%[b_ptr], #104]\n" + "fmla v22.4s, %[b1].4s, %[a1a].s[2]\n" + "fmla v23.4s, %[b1].4s, %[a1a].s[3]\n" + "ldr %d[b1], [%[b_ptr], #112]\n" + + "fmla v24.4s, %[b2].4s, %[a0a].s[0]\n" + "ins %[b0].d[1], x20\n" + "fmla v25.4s, %[b2].4s, %[a0a].s[1]\n" + "ldr x20, [%[b_ptr], #120]\n" + "fmla v26.4s, %[b2].4s, %[a0a].s[2]\n" + + "fmla v27.4s, %[b2].4s, %[a0a].s[3]\n" + "add %[a_ptr], %[a_ptr], #64\n" + + "fmla v28.4s, %[b2].4s, %[a1a].s[0]\n" ASM_PREFETCH("[%[b_ptr], #640]") + "fmla v29.4s, %[b2].4s, %[a1a].s[1]\n" + "add %[b_ptr], %[b_ptr], #96\n" + "fmla v30.4s, %[b2].4s, %[a1a].s[2]\n" + "ins %[b1].d[1], x20\n" + "fmla v31.4s, %[b2].4s, %[a1a].s[3]\n" + "ldr %d[b2], [%[b_ptr], #32]\n" + + "fmla v8.4s , %[b0].4s, %[a0].s[0]\n" + "b.ne 1b\n" + + // Branch here if K=1 or 2. Do the right thing for odd/even at the end. + "4:\n" + + // Start final iteration - branch off to "odd" code before we load a0a. + "fmla v9.4s , %[b0].4s, %[a0].s[1]\n" + "ldr x20, [%[b_ptr], #40]\n" + "fmla v10.4s, %[b0].4s, %[a0].s[2]\n" + "cbnz %w[oddk], 2f\n" + + // Even K continuation + "fmla v11.4s, %[b0].4s, %[a0].s[3]\n" + "ldr %d[a0a], [%[a_ptr], #32]\n" + + "fmla v12.4s, %[b0].4s, %[a1].s[0]\n" + "ins %[b2].d[1], x20\n" + "fmla v13.4s, %[b0].4s, %[a1].s[1]\n" + "ldr x20, [%[a_ptr], #40]\n" + "fmla v14.4s, %[b0].4s, %[a1].s[2]\n" ASM_PREFETCHW("[%[c_ptr]]") + "fmla v15.4s, %[b0].4s, %[a1].s[3]\n" + "ldr %d[a1a], [%[a_ptr], #48]\n" + + "fmla v16.4s, %[b1].4s, %[a0].s[0]\n" + "ins %[a0a].d[1], x20\n" + "fmla v17.4s, %[b1].4s, %[a0].s[1]\n" + "ldr x20, [%[a_ptr], #56]\n" + "fmla v18.4s, %[b1].4s, %[a0].s[2]\n" + "fmla v19.4s, %[b1].4s, %[a0].s[3]\n" + "ldr %d[b0], [%[b_ptr], #48]\n" + + "fmla v20.4s, %[b1].4s, %[a1].s[0]\n" + "ins %[a1a].d[1], x20\n" + "fmla v21.4s, %[b1].4s, %[a1].s[1]\n" + "ldr x20, [%[b_ptr], #56]\n" + "fmla v22.4s, %[b1].4s, %[a1].s[2]\n" ASM_PREFETCHW("[%[c_ptr], #64]") + "fmla v23.4s, %[b1].4s, %[a1].s[3]\n" + + "fmla v24.4s, %[b2].4s, %[a0].s[0]\n" + "fmla v25.4s, %[b2].4s, %[a0].s[1]\n" ASM_PREFETCHW("[%[c_ptr], #128]") + "fmla v26.4s, %[b2].4s, %[a0].s[2]\n" + "fmla v27.4s, %[b2].4s, %[a0].s[3]\n" + "ldr %d[b1], [%[b_ptr], #64]\n" + + "fmla v28.4s, %[b2].4s, %[a1].s[0]\n" + "ins %[b0].d[1], x20\n" + "fmla v29.4s, %[b2].4s, %[a1].s[1]\n" + "ldr x20, [%[b_ptr], #72]\n" + "fmla v30.4s, %[b2].4s, %[a1].s[2]\n" ASM_PREFETCHW("[%[c_ptr], #192]") + "fmla v31.4s, %[b2].4s, %[a1].s[3]\n" + "ldr %d[b2], [%[b_ptr], #80]\n" + + "fmla v8.4s , %[b0].4s, %[a0a].s[0]\n" + "ins %[b1].d[1], x20\n" + "fmla v9.4s , %[b0].4s, %[a0a].s[1]\n" + "ldr x20, [%[b_ptr], #88]\n" + "fmla v10.4s, %[b0].4s, %[a0a].s[2]\n" + "ins %[b2].d[1], x20\n" + + "fmla v11.4s, %[b0].4s, %[a0a].s[3]\n" ASM_PREFETCHW("[%[c_ptr], #256]") + "fmla v12.4s, %[b0].4s, %[a1a].s[0]\n" + "fmla v13.4s, %[b0].4s, %[a1a].s[1]\n" + "fmla v14.4s, %[b0].4s, %[a1a].s[2]\n" ASM_PREFETCHW("[%[c_ptr], #320]") + "fmla v15.4s, %[b0].4s, %[a1a].s[3]\n" + "fmla v16.4s, %[b1].4s, %[a0a].s[0]\n" ASM_PREFETCHWL2("[%[c_ptr], #384]") + "fmla v17.4s, %[b1].4s, %[a0a].s[1]\n" + "fmla v18.4s, %[b1].4s, %[a0a].s[2]\n" ASM_PREFETCHWL2("[%[c_ptr], #448]") + "fmla v19.4s, %[b1].4s, %[a0a].s[3]\n" + "fmla v20.4s, %[b1].4s, %[a1a].s[0]\n" + "fmla v21.4s, %[b1].4s, %[a1a].s[1]\n" ASM_PREFETCHWL2("[%[c_ptr], #512]") + "fmla v22.4s, %[b1].4s, %[a1a].s[2]\n" + "fmla v23.4s, %[b1].4s, %[a1a].s[3]\n" ASM_PREFETCHWL2("[%[c_ptr], #576]") + "fmla v24.4s, %[b2].4s, %[a0a].s[0]\n" + "fmla v25.4s, %[b2].4s, %[a0a].s[1]\n" + "fmla v26.4s, %[b2].4s, %[a0a].s[2]\n" ASM_PREFETCHWL2("[%[c_ptr], #640]") + "fmla v27.4s, %[b2].4s, %[a0a].s[3]\n" + "fmla v28.4s, %[b2].4s, %[a1a].s[0]\n" ASM_PREFETCHWL2("[%[c_ptr], #704]") + "fmla v29.4s, %[b2].4s, %[a1a].s[1]\n" + "add %[a_ptr], %[a_ptr], #64\n" + "fmla v30.4s, %[b2].4s, %[a1a].s[2]\n" + "add %[b_ptr], %[b_ptr], #96\n" + "fmla v31.4s, %[b2].4s, %[a1a].s[3]\n" + "b 3f\n" + + // Odd K continuation + "2:\n" + "fmla v11.4s, %[b0].4s, %[a0].s[3]\n" ASM_PREFETCHW("[%[c_ptr]]") + "fmla v12.4s, %[b0].4s, %[a1].s[0]\n" + "ins %[b2].d[1], x20\n" + "fmla v13.4s, %[b0].4s, %[a1].s[1]\n" ASM_PREFETCHW("[%[c_ptr], #64]") + "fmla v14.4s, %[b0].4s, %[a1].s[2]\n" + "add %[a_ptr], %[a_ptr], #32\n" + "fmla v15.4s, %[b0].4s, %[a1].s[3]\n" ASM_PREFETCHW("[%[c_ptr], #128]") + "fmla v16.4s, %[b1].4s, %[a0].s[0]\n" + "add %[b_ptr], %[b_ptr], #48\n" + "fmla v17.4s, %[b1].4s, %[a0].s[1]\n" ASM_PREFETCHW("[%[c_ptr], #192]") + "fmla v18.4s, %[b1].4s, %[a0].s[2]\n" + "fmla v19.4s, %[b1].4s, %[a0].s[3]\n" ASM_PREFETCHW("[%[c_ptr], #256]") + "fmla v20.4s, %[b1].4s, %[a1].s[0]\n" + "fmla v21.4s, %[b1].4s, %[a1].s[1]\n" ASM_PREFETCHW("[%[c_ptr], #320]") + "fmla v22.4s, %[b1].4s, %[a1].s[2]\n" + "fmla v23.4s, %[b1].4s, %[a1].s[3]\n" ASM_PREFETCHWL2("[%[c_ptr], #384]") + "fmla v24.4s, %[b2].4s, %[a0].s[0]\n" + "fmla v25.4s, %[b2].4s, %[a0].s[1]\n" ASM_PREFETCHWL2("[%[c_ptr], #448]") + "fmla v26.4s, %[b2].4s, %[a0].s[2]\n" + "fmla v27.4s, %[b2].4s, %[a0].s[3]\n" ASM_PREFETCHWL2("[%[c_ptr], #512]") "fmla v28.4s, %[b2].4s, %[a1].s[0]\n" ASM_PREFETCHWL2("[%[c_ptr], #576]") "fmla v29.4s, %[b2].4s, %[a1].s[1]\n" + ASM_PREFETCHWL2("[%[c_ptr], #640]") "fmla v30.4s, %[b2].4s, %[a1].s[2]\n" ASM_PREFETCHWL2("[%[c_ptr], #704]") + "fmla v31.4s, %[b2].4s, %[a1].s[3]\n" + + // Common tail + "3:\n" + "str q8, [%[c_ptr]]\n" + "str q16, [%[c_ptr], #16]\n" + "str q24, [%[c_ptr], #32]\n" + "str q9, [%[c_ptr], #48]\n" + "str q17, [%[c_ptr], #64]\n" + "str q25, [%[c_ptr], #80]\n" + "str q10, [%[c_ptr], #96]\n" + "str q18, [%[c_ptr], #112]\n" + "str q26, [%[c_ptr], #128]\n" + "str q11, [%[c_ptr], #144]\n" + "str q19, [%[c_ptr], #160]\n" + "str q27, [%[c_ptr], #176]\n" + "str q12, [%[c_ptr], #192]\n" + "str q20, [%[c_ptr], #208]\n" + "str q28, [%[c_ptr], #224]\n" + "str q13, [%[c_ptr], #240]\n" + "str q21, [%[c_ptr], #256]\n" + "str q29, [%[c_ptr], #272]\n" + "str q14, [%[c_ptr], #288]\n" + "str q22, [%[c_ptr], #304]\n" + "str q30, [%[c_ptr], #320]\n" + "str q15, [%[c_ptr], #336]\n" + "str q23, [%[c_ptr], #352]\n" + "str q31, [%[c_ptr], #368]\n" + "add %[c_ptr], %[c_ptr], #384\n" + : + [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [c_ptr] "+r"(c_ptr), + [a0] "+w"(a0), [a1] "+w"(a1), [a0a] "+w"(a0a), [a1a] "+w"(a1a), + [b0] "+w"(b0), [b1] "+w"(b1), [b2] "+w"(b2), [k] "+r"(k) + : [oddk] "r"(oddk) + : "x20", "x21", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", + "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "cc"); + } + } +} + +} // namespace arm_gemm + +#endif diff --git a/src/core/NEON/kernels/arm_gemm/kernels/a64_sgemm_12x8/generic.cpp b/src/core/NEON/kernels/arm_gemm/kernels/a64_sgemm_12x8/generic.cpp new file mode 100644 index 0000000000..42e870e814 --- /dev/null +++ b/src/core/NEON/kernels/arm_gemm/kernels/a64_sgemm_12x8/generic.cpp @@ -0,0 +1,350 @@ +/* + * Copyright (c) 2017-2018 ARM Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ +#ifdef __aarch64__ + +#include <arm_neon.h> + +#include "../../asmlib.hpp" + +// Kernel implementation. +// +// Assume that "Apanel" points to a chunk of A blocks (each size 8xK) in read-order. +// Assume that "Bpanel" points to a chunk of B blocks (each size 12xK) in read-order. +// Assume that "Cpanel" points to a chunk of C output blocks (each size +// 12x8), the chunks being arranged in a row major fashion. +// +// Note that the intent of this is that either ablocks or bblocks will be 1 +// - this construction allows the output loop to proceed in either order. + +namespace arm_gemm +{ +void a64_sgemm_asimd_12x8_jumps(const float *Apanel, const float *Bpanel, float *Cpanel, int ablocks, int bblocks, int K, long int row_jump = 0, long int block_jump = 0) +{ + const float *a_ptr = Apanel; + float *c_ptr = Cpanel; + + for(int yb = 0; yb < ablocks; yb++) + { + const float *a_ptr0 = a_ptr; + const float *b_ptr = Bpanel; + + for(int xb = 0; xb < bblocks; xb++) + { + a_ptr = a_ptr0; + // Fix up for odd lengths - set a flag if K is odd, but make + // sure we round up the iteration count. + int oddk = (K & 1); + int k = ((K + 1) / 2) - 1; + + register float32x4_t a0 asm("v0"); + register float32x4_t a1 asm("v1"); + register float32x4_t b0 asm("v2"); + register float32x4_t b1 asm("v3"); + register float32x4_t b2 asm("v4"); + register float32x4_t a0a asm("v5"); + register float32x4_t a1a asm("v6"); + + __asm __volatile( + // Initialize result registers, load initial operands, prime prefetches. + "movi v8.4s, #0x0\n" + "ldr %q[a0], [%[a_ptr]]\n" + "movi v9.4s, #0x0\n" + "ldr %q[b0], [%[b_ptr]]\n" + "movi v10.4s, #0x0\n" + "ldr %q[a1], [%[a_ptr], #16]\n" + "movi v11.4s, #0x0\n" + "ldr %q[b1], [%[b_ptr], #16]\n" + "movi v12.4s, #0x0\n" ASM_PREFETCH("[%[b_ptr], #64]") "movi v13.4s, #0x0\n" ASM_PREFETCH("[%[a_ptr], #64]") "movi v14.4s, #0x0\n" ASM_PREFETCH("[%[b_ptr], #128]") "movi v15.4s, #0x0\n" + ASM_PREFETCH("[%[a_ptr], #128]") "movi v16.4s, #0x0\n" ASM_PREFETCH("[%[b_ptr], #192]") "movi v17.4s, #0x0\n" ASM_PREFETCH("[%[b_ptr], #256]") "movi v18.4s, #0x0\n" + ASM_PREFETCH("[%[a_ptr], #192]") "movi v19.4s, #0x0\n" ASM_PREFETCH("[%[b_ptr], #320]") "movi v20.4s, #0x0\n" ASM_PREFETCH("[%[a_ptr], #256]") "movi v21.4s, #0x0\n" + ASM_PREFETCH("[%[b_ptr], #384]") + "movi v22.4s, #0x0\n" + "movi v23.4s, #0x0\n" + "movi v24.4s, #0x0\n" + "movi v25.4s, #0x0\n" + "movi v26.4s, #0x0\n" + "movi v27.4s, #0x0\n" + "movi v28.4s, #0x0\n" + "movi v29.4s, #0x0\n" + "movi v30.4s, #0x0\n" + "movi v31.4s, #0x0\n" + + // Skip loop if we are doing zero iterations of it. + "cbz %w[k], 4f\n" + + // Loop proper + "1:\n" + "fmla v8.4s , %[b0].4s, %[a0].s[0]\n" + "fmla v9.4s , %[b0].4s, %[a0].s[1]\n" + "ldr %q[b2], [%[b_ptr], #32]\n" + "fmla v10.4s, %[b0].4s, %[a0].s[2]\n" + "add %[b_ptr], %[b_ptr], %[row_jump]\n" + "fmla v11.4s, %[b0].4s, %[a0].s[3]\n" + "ldr %q[a0a], [%[a_ptr], #32]\n" + "fmla v12.4s, %[b0].4s, %[a1].s[0]\n" + "fmla v13.4s, %[b0].4s, %[a1].s[1]\n" + "ldr %q[a1a], [%[a_ptr], #48]\n" + "fmla v14.4s, %[b0].4s, %[a1].s[2]\n" + "fmla v15.4s, %[b0].4s, %[a1].s[3]\n" + "ldr %q[b0], [%[b_ptr], #48]\n" + + "fmla v16.4s, %[b1].4s, %[a0].s[0]\n" + "fmla v17.4s, %[b1].4s, %[a0].s[1]\n" ASM_PREFETCH("[%[a_ptr], #320]") + "fmla v18.4s, %[b1].4s, %[a0].s[2]\n" + "fmla v19.4s, %[b1].4s, %[a0].s[3]\n" + "fmla v20.4s, %[b1].4s, %[a1].s[0]\n" + "fmla v21.4s, %[b1].4s, %[a1].s[1]\n" + "fmla v22.4s, %[b1].4s, %[a1].s[2]\n" + "fmla v23.4s, %[b1].4s, %[a1].s[3]\n" + "ldr %q[b1], [%[b_ptr], #64]\n" + + "fmla v24.4s, %[b2].4s, %[a0].s[0]\n" + "fmla v25.4s, %[b2].4s, %[a0].s[1]\n" ASM_PREFETCH("[%[b_ptr], #448]") + "fmla v26.4s, %[b2].4s, %[a0].s[2]\n" + "fmla v27.4s, %[b2].4s, %[a0].s[3]\n" + "fmla v28.4s, %[b2].4s, %[a1].s[0]\n" + "fmla v29.4s, %[b2].4s, %[a1].s[1]\n" + "fmla v30.4s, %[b2].4s, %[a1].s[2]\n" + "fmla v31.4s, %[b2].4s, %[a1].s[3]\n" + "ldr %q[b2], [%[b_ptr], #80]\n" + + "fmla v8.4s , %[b0].4s, %[a0a].s[0]\n" + "fmla v9.4s , %[b0].4s, %[a0a].s[1]\n" + "ldr %q[a0], [%[a_ptr], #64]\n" + "fmla v10.4s, %[b0].4s, %[a0a].s[2]\n" + "add %[b_ptr], %[b_ptr], %[row_jump]\n" + "fmla v11.4s, %[b0].4s, %[a0a].s[3]\n" + "fmla v12.4s, %[b0].4s, %[a1a].s[0]\n" + "ldr %q[a1], [%[a_ptr], #80]\n" + "fmla v13.4s, %[b0].4s, %[a1a].s[1]\n" + "fmla v14.4s, %[b0].4s, %[a1a].s[2]\n" + "fmla v15.4s, %[b0].4s, %[a1a].s[3]\n" + "ldr %q[b0], [%[b_ptr], #96]\n" + + "fmla v16.4s, %[b1].4s, %[a0a].s[0]\n" + "fmla v17.4s, %[b1].4s, %[a0a].s[1]\n" ASM_PREFETCH("[%[b_ptr], #512]") + "fmla v18.4s, %[b1].4s, %[a0a].s[2]\n" + "fmla v19.4s, %[b1].4s, %[a0a].s[3]\n" + "fmla v20.4s, %[b1].4s, %[a1a].s[0]\n" + "fmla v21.4s, %[b1].4s, %[a1a].s[1]\n" + "fmla v22.4s, %[b1].4s, %[a1a].s[2]\n" + "fmla v23.4s, %[b1].4s, %[a1a].s[3]\n" + "ldr %q[b1], [%[b_ptr], #112]\n" + + "fmla v24.4s, %[b2].4s, %[a0a].s[0]\n" + "fmla v25.4s, %[b2].4s, %[a0a].s[1]\n" + "add %[a_ptr], %[a_ptr], #64\n" + "fmla v26.4s, %[b2].4s, %[a0a].s[2]\n" + "fmla v27.4s, %[b2].4s, %[a0a].s[3]\n" + "add %[b_ptr], %[b_ptr], #96\n" + "fmla v28.4s, %[b2].4s, %[a1a].s[0]\n" + "fmla v29.4s, %[b2].4s, %[a1a].s[1]\n" + "subs %w[k], %w[k], #1\n" + "fmla v30.4s, %[b2].4s, %[a1a].s[2]\n" + "fmla v31.4s, %[b2].4s, %[a1a].s[3]\n" + "bne 1b\n" + + // Target to use when K is 1 or 2 (i.e. zero iterations of main loop) + "4:\n" + + // Branch to alternative tail for odd K + "cbnz %w[oddk], 2f\n" + + // Detached final iteration (even K) + "fmla v8.4s , %[b0].4s, %[a0].s[0]\n" + "fmla v9.4s , %[b0].4s, %[a0].s[1]\n" + "ldr %q[b2], [%[b_ptr], #32]\n" + "fmla v10.4s, %[b0].4s, %[a0].s[2]\n" + "add %[b_ptr], %[b_ptr], %[row_jump]\n" + "fmla v11.4s, %[b0].4s, %[a0].s[3]\n" + "ldr %q[a0a], [%[a_ptr], #32]\n" + "fmla v12.4s, %[b0].4s, %[a1].s[0]\n" + "fmla v13.4s, %[b0].4s, %[a1].s[1]\n" + "ldr %q[a1a], [%[a_ptr], #48]\n" + "fmla v14.4s, %[b0].4s, %[a1].s[2]\n" + "fmla v15.4s, %[b0].4s, %[a1].s[3]\n" + "ldr %q[b0], [%[b_ptr], #48]\n" + + "fmla v16.4s, %[b1].4s, %[a0].s[0]\n" + "fmla v17.4s, %[b1].4s, %[a0].s[1]\n" + "fmla v18.4s, %[b1].4s, %[a0].s[2]\n" + "fmla v19.4s, %[b1].4s, %[a0].s[3]\n" + "fmla v20.4s, %[b1].4s, %[a1].s[0]\n" + "fmla v21.4s, %[b1].4s, %[a1].s[1]\n" + "fmla v22.4s, %[b1].4s, %[a1].s[2]\n" + "fmla v23.4s, %[b1].4s, %[a1].s[3]\n" + "ldr %q[b1], [%[b_ptr], #64]\n" + + "fmla v24.4s, %[b2].4s, %[a0].s[0]\n" + "fmla v25.4s, %[b2].4s, %[a0].s[1]\n" + "add %[a_ptr], %[a_ptr], #64\n" + "fmla v26.4s, %[b2].4s, %[a0].s[2]\n" + "fmla v27.4s, %[b2].4s, %[a0].s[3]\n" + "fmla v28.4s, %[b2].4s, %[a1].s[0]\n" + "fmla v29.4s, %[b2].4s, %[a1].s[1]\n" + "fmla v30.4s, %[b2].4s, %[a1].s[2]\n" + "fmla v31.4s, %[b2].4s, %[a1].s[3]\n" + "ldr %q[b2], [%[b_ptr], #80]\n" + + "fmla v8.4s , %[b0].4s, %[a0a].s[0]\n" + "add %[b_ptr], %[b_ptr], %[block_jump]\n" + "fmla v16.4s, %[b1].4s, %[a0a].s[0]\n" + "add %[b_ptr], %[b_ptr], #96\n" + "fmla v9.4s , %[b0].4s, %[a0a].s[1]\n" + "add %[b_ptr], %[b_ptr], %[row_jump]\n" + "str q8, [%[c_ptr], #0]\n" + "fmla v17.4s, %[b1].4s, %[a0a].s[1]\n" + "str q16, [%[c_ptr], #16]\n" + "fmla v24.4s, %[b2].4s, %[a0a].s[0]\n" + "str q24, [%[c_ptr], #32]\n" + + "fmla v25.4s, %[b2].4s, %[a0a].s[1]\n" + "str q9, [%[c_ptr], #48]\n" + "fmla v10.4s, %[b0].4s, %[a0a].s[2]\n" + "str q17, [%[c_ptr], #64]\n" + "fmla v18.4s, %[b1].4s, %[a0a].s[2]\n" + "str q25, [%[c_ptr], #80]\n" + "fmla v26.4s, %[b2].4s, %[a0a].s[2]\n" + "str q10, [%[c_ptr], #96]\n" + + "fmla v11.4s, %[b0].4s, %[a0a].s[3]\n" + "str q18, [%[c_ptr], #112]\n" + "fmla v19.4s, %[b1].4s, %[a0a].s[3]\n" + "str q26, [%[c_ptr], #128]\n" + "fmla v27.4s, %[b2].4s, %[a0a].s[3]\n" + "str q11, [%[c_ptr], #144]\n" + + "fmla v12.4s, %[b0].4s, %[a1a].s[0]\n" + "str q19, [%[c_ptr], #160]\n" + "fmla v20.4s, %[b1].4s, %[a1a].s[0]\n" + "str q27, [%[c_ptr], #176]\n" + "fmla v28.4s, %[b2].4s, %[a1a].s[0]\n" + "str q12, [%[c_ptr], #192]\n" + + "fmla v13.4s, %[b0].4s, %[a1a].s[1]\n" + "str q20, [%[c_ptr], #208]\n" + "fmla v21.4s, %[b1].4s, %[a1a].s[1]\n" + "str q28, [%[c_ptr], #224]\n" + "fmla v29.4s, %[b2].4s, %[a1a].s[1]\n" + "str q13, [%[c_ptr], #240]\n" + + "fmla v14.4s, %[b0].4s, %[a1a].s[2]\n" + "str q21, [%[c_ptr], #256]\n" + "fmla v22.4s, %[b1].4s, %[a1a].s[2]\n" + "str q29, [%[c_ptr], #272]\n" + "fmla v30.4s, %[b2].4s, %[a1a].s[2]\n" + "str q14, [%[c_ptr], #288]\n" + + "fmla v15.4s, %[b0].4s, %[a1a].s[3]\n" + "str q22, [%[c_ptr], #304]\n" + "fmla v23.4s, %[b1].4s, %[a1a].s[3]\n" + "str q30, [%[c_ptr], #320]\n" + "fmla v31.4s, %[b2].4s, %[a1a].s[3]\n" + "str q15, [%[c_ptr], #336]\n" + + "b 3f\n" + + // Detached final iteration (odd K) + "2:\n" + "fmla v8.4s , %[b0].4s, %[a0].s[0]\n" + "ldr %q[b2], [%[b_ptr], #32]\n" + "fmla v16.4s, %[b1].4s, %[a0].s[0]\n" + "add %[b_ptr], %[b_ptr], %[row_jump]\n" + "fmla v9.4s , %[b0].4s, %[a0].s[1]\n" + "str q8, [%[c_ptr], #0]\n" + "fmla v17.4s, %[b1].4s, %[a0].s[1]\n" + "str q16, [%[c_ptr], #16]\n" + "fmla v24.4s, %[b2].4s, %[a0].s[0]\n" + "add %[b_ptr], %[b_ptr], #48\n" + "add %[a_ptr], %[a_ptr], #32\n" + "str q24, [%[c_ptr], #32]\n" + "fmla v25.4s, %[b2].4s, %[a0].s[1]\n" + "str q9, [%[c_ptr], #48]\n" + + "fmla v10.4s, %[b0].4s, %[a0].s[2]\n" + "str q17, [%[c_ptr], #64]\n" + "fmla v18.4s, %[b1].4s, %[a0].s[2]\n" + "str q25, [%[c_ptr], #80]\n" + "fmla v26.4s, %[b2].4s, %[a0].s[2]\n" + "str q10, [%[c_ptr], #96]\n" + + "fmla v11.4s, %[b0].4s, %[a0].s[3]\n" + "str q18, [%[c_ptr], #112]\n" + "fmla v19.4s, %[b1].4s, %[a0].s[3]\n" + "str q26, [%[c_ptr], #128]\n" + "fmla v27.4s, %[b2].4s, %[a0].s[3]\n" + "str q11, [%[c_ptr], #144]\n" + + "fmla v12.4s, %[b0].4s, %[a1].s[0]\n" + "str q19, [%[c_ptr], #160]\n" + "fmla v20.4s, %[b1].4s, %[a1].s[0]\n" + "str q27, [%[c_ptr], #176]\n" + "fmla v28.4s, %[b2].4s, %[a1].s[0]\n" + "str q12, [%[c_ptr], #192]\n" + + "fmla v13.4s, %[b0].4s, %[a1].s[1]\n" + "str q20, [%[c_ptr], #208]\n" + "fmla v21.4s, %[b1].4s, %[a1].s[1]\n" + "str q28, [%[c_ptr], #224]\n" + "fmla v29.4s, %[b2].4s, %[a1].s[1]\n" + "str q13, [%[c_ptr], #240]\n" + + "fmla v14.4s, %[b0].4s, %[a1].s[2]\n" + "str q21, [%[c_ptr], #256]\n" + "fmla v22.4s, %[b1].4s, %[a1].s[2]\n" + "str q29, [%[c_ptr], #272]\n" + "fmla v30.4s, %[b2].4s, %[a1].s[2]\n" + "str q14, [%[c_ptr], #288]\n" + + "fmla v15.4s, %[b0].4s, %[a1].s[3]\n" + "str q22, [%[c_ptr], #304]\n" + "fmla v23.4s, %[b1].4s, %[a1].s[3]\n" + "str q30, [%[c_ptr], #320]\n" + "fmla v31.4s, %[b2].4s, %[a1].s[3]\n" + "str q15, [%[c_ptr], #336]\n" + + // Common tail + "3:\n" + "str q23, [%[c_ptr], #352]\n" + "str q31, [%[c_ptr], #368]\n" + "add %[c_ptr], %[c_ptr], #384\n" + : + [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [c_ptr] "+r"(c_ptr), + [a0] "+w"(a0), [a1] "+w"(a1), [a0a] "+w"(a0a), [a1a] "+w"(a1a), + [b0] "+w"(b0), [b1] "+w"(b1), [b2] "+w"(b2), [k] "+r"(k) + : [oddk] "r"(oddk), [row_jump] "r"(row_jump), [block_jump] "r"(block_jump) + : "x20", "x21", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", + "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "cc"); + } + } +} + +void a64_sgemm_asimd_12x8(const float *Apanel, const float *Bpanel, float *Cpanel, int ablocks, int bblocks, int K) +{ + a64_sgemm_asimd_12x8_jumps(Apanel, Bpanel, Cpanel, ablocks, bblocks, K, 0, 0); +} + +} // namespace arm_gemm + +#endif
\ No newline at end of file diff --git a/src/core/NEON/kernels/arm_gemm/kernels/a64_sgemm_native_16x4.hpp b/src/core/NEON/kernels/arm_gemm/kernels/a64_sgemm_native_16x4.hpp new file mode 100644 index 0000000000..eceacc9031 --- /dev/null +++ b/src/core/NEON/kernels/arm_gemm/kernels/a64_sgemm_native_16x4.hpp @@ -0,0 +1,64 @@ +/* + * Copyright (c) 2017-2018 ARM Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ +#pragma once + +#ifdef __aarch64__ + +namespace arm_gemm +{ +// Actual kernel implementations +void a64_sgemm_native_16x4(const float *, int, const float *, int, float *, int, float, int, int, int); + +// 12x8 SGEMM "strategy" class. +// +// This describes the characteristics of a family of kernels, in terms of +// the required interleave properties and the output block size. +// +// All kernels in the family must share these characteristics. The actual +// kernel to be used can be chosen at runtime, based on the CPU_type +// structure. +class sgemm_native_16x4 +{ +public: + typedef float operand_type; + typedef float result_type; + + typedef void (*kern_type)(const float *, int, const float *, int, float *, int, float, int, int, int); + + /* Kernel blocking parameters */ + static const int out_width = 16; + static const int out_height = 4; + static const int k_unroll = 1; + + // Default to the generic kernel + kern_type kernel = a64_sgemm_native_16x4; + + sgemm_native_16x4(const CPUInfo *ci) + { + } +}; + +} // namespace arm_gemm + +#endif // __aarch64__ diff --git a/src/core/NEON/kernels/arm_gemm/kernels/a64_sgemm_native_16x4/generic.cpp b/src/core/NEON/kernels/arm_gemm/kernels/a64_sgemm_native_16x4/generic.cpp new file mode 100644 index 0000000000..1b5787ce7c --- /dev/null +++ b/src/core/NEON/kernels/arm_gemm/kernels/a64_sgemm_native_16x4/generic.cpp @@ -0,0 +1,734 @@ +/* + * Copyright (c) 2017-2018 ARM Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ +#ifdef __aarch64__ + +#include <cstddef> + +#include <arm_neon.h> + +namespace arm_gemm +{ +void a64_sgemm_native_16x4(const float *A, int lda, const float *B, int ldb, float *C, int ldc, float beta, int M, int N, int K) +{ + int oddk = (K % 8) ? 1 : 0; + int beta0 = (beta == 0.0f) ? 1 : 0; + + /* For now, very naive with no blocking */ + for(int y = 0; y < M; y += 4) + { + for(int x0 = 0; x0 < N; x0 += 16) + { + const float *a_ptr0 = A + (y * lda); + const float *a_ptr1 = a_ptr0 + lda; + const float *a_ptr2 = a_ptr1 + lda; + const float *a_ptr3 = a_ptr2 + lda; + + const float *b_ptr = B + x0; + + float *c_ptr0 = C + (y * ldc) + x0; + float *c_ptr1 = c_ptr0 + ldc; + float *c_ptr2 = c_ptr1 + ldc; + float *c_ptr3 = c_ptr2 + ldc; + + int loops = ((K + 4) / 8) - 1; + + size_t ldbb = ldb * sizeof(float); + + __asm __volatile( + "a0 .req v0\n" + "a1 .req v1\n" + "a2 .req v2\n" + "a3 .req v3\n" + "a0a .req v4\n" + "a1a .req v5\n" + "a2a .req v6\n" + "a3a .req v7\n" + "bb0 .req v8\n" + "bb1 .req v9\n" + "bb2 .req v10\n" + "bb3 .req v11\n" + "b0a .req v12\n" + "b1a .req v13\n" + "b2a .req v14\n" + "b3a .req v15\n" + + "a0q .req q0\n" + "a1q .req q1\n" + "a2q .req q2\n" + "a3q .req q3\n" + "a0aq .req q4\n" + "a1aq .req q5\n" + "a2aq .req q6\n" + "a3aq .req q7\n" + "b0q .req q8\n" + "b1q .req q9\n" + "b2q .req q10\n" + "b3q .req q11\n" + "b0aq .req q12\n" + "b1aq .req q13\n" + "b2aq .req q14\n" + "b3aq .req q15\n" + + "movi v16.4s, #0x0\n" + "ldr a0q, [%[a_ptr0]]\n" + "movi v17.4s, #0x0\n" + "ldr b0q, [%[b_ptr]]\n" + "movi v18.4s, #0x0\n" + "ldr b1q, [%[b_ptr], #16]\n" + "movi v19.4s, #0x0\n" + "ldr b2q, [%[b_ptr], #32]\n" + "movi v20.4s, #0x0\n" + "ldr b3q, [%[b_ptr], #48]\n" + "movi v21.4s, #0x0\n" + "add %[b_ptr], %[b_ptr], %[ldb]\n" + "ldr a1q, [%[a_ptr1]]\n" + "movi v22.4s, #0x0\n" + "ldr a2q, [%[a_ptr2]]\n" + "movi v23.4s, #0x0\n" + "ldr a3q, [%[a_ptr3]]\n" + "movi v24.4s, #0x0\n" + "ldr b0aq, [%[b_ptr]]\n" + "movi v25.4s, #0x0\n" + "ldr b1aq, [%[b_ptr], #16]\n" + "movi v26.4s, #0x0\n" + "ldr b2aq, [%[b_ptr], #32]\n" + "cbz %w[beta0], 5f\n" + "movi v27.4s, #0x0\n" + "movi v28.4s, #0x0\n" + "movi v29.4s, #0x0\n" + "movi v30.4s, #0x0\n" + "movi v31.4s, #0x0\n" + + // Skip if no complete loops. + "cbz %w[loops], 4f\n" + "b 1f\n" + + // If beta is non-zero, need to load and multiply by beta + "5:\n" + "ld1r {v4.4s}, [%[betaptr]]\n" + "ldr q16, [%[c_ptr0]]\n" + "ldr q17, [%[c_ptr0], #16]\n" + "ldr q18, [%[c_ptr0], #32]\n" + "ldr q19, [%[c_ptr0], #48]\n" + + "ldr q20, [%[c_ptr1]]\n" + "fmul v16.4s, v16.4s, v4.4s\n" + "ldr q21, [%[c_ptr1], #16]\n" + "fmul v17.4s, v17.4s, v4.4s\n" + "ldr q22, [%[c_ptr1], #32]\n" + "fmul v18.4s, v18.4s, v4.4s\n" + "ldr q23, [%[c_ptr1], #48]\n" + "fmul v19.4s, v19.4s, v4.4s\n" + + "ldr q24, [%[c_ptr2]]\n" + "fmul v20.4s, v20.4s, v4.4s\n" + "ldr q25, [%[c_ptr2], #16]\n" + "fmul v21.4s, v21.4s, v4.4s\n" + "ldr q26, [%[c_ptr2], #32]\n" + "fmul v22.4s, v22.4s, v4.4s\n" + "ldr q27, [%[c_ptr2], #48]\n" + "fmul v23.4s, v23.4s, v4.4s\n" + + "ldr q28, [%[c_ptr3]]\n" + "fmul v24.4s, v24.4s, v4.4s\n" + "ldr q29, [%[c_ptr3], #16]\n" + "fmul v25.4s, v25.4s, v4.4s\n" + "ldr q30, [%[c_ptr3], #32]\n" + "fmul v26.4s, v26.4s, v4.4s\n" + "ldr q31, [%[c_ptr3], #48]\n" + "fmul v27.4s, v27.4s, v4.4s\n" + + "fmul v28.4s, v28.4s, v4.4s\n" + "fmul v29.4s, v29.4s, v4.4s\n" + "fmul v30.4s, v30.4s, v4.4s\n" + "fmul v31.4s, v31.4s, v4.4s\n" + + "cbz %w[loops], 4f\n" + + "1:\n" + // Unroll 0 + "fmla v16.4s, bb0.4s, a0.s[0]\n" + "fmla v20.4s, bb0.4s, a1.s[0]\n" + "ldr b3aq, [%[b_ptr], #48]\n" + "fmla v24.4s, bb0.4s, a2.s[0]\n" + "add %[b_ptr], %[b_ptr], %[ldb]\n" + "fmla v28.4s, bb0.4s, a3.s[0]\n" + "ldr b0q, [%[b_ptr]]\n" + + "fmla v17.4s, bb1.4s, a0.s[0]\n" + "fmla v21.4s, bb1.4s, a1.s[0]\n" + "ldr a0aq, [%[a_ptr0], #16]\n" + "fmla v25.4s, bb1.4s, a2.s[0]\n" + "fmla v29.4s, bb1.4s, a3.s[0]\n" + "ldr b1q, [%[b_ptr], #16]\n" + + "fmla v18.4s, bb2.4s, a0.s[0]\n" + "fmla v22.4s, bb2.4s, a1.s[0]\n" + "ldr a1aq, [%[a_ptr1], #16]\n" + "fmla v26.4s, bb2.4s, a2.s[0]\n" + "fmla v30.4s, bb2.4s, a3.s[0]\n" + "ldr b2q, [%[b_ptr], #32]\n" + + "fmla v19.4s, bb3.4s, a0.s[0]\n" + "fmla v23.4s, bb3.4s, a1.s[0]\n" + "ldr a2aq, [%[a_ptr2], #16]\n" + "fmla v27.4s, bb3.4s, a2.s[0]\n" + "fmla v31.4s, bb3.4s, a3.s[0]\n" + "ldr b3q, [%[b_ptr], #48]\n" + + // Unroll 1 + "fmla v16.4s, b0a.4s, a0.s[1]\n" + "add %[b_ptr], %[b_ptr], %[ldb]\n" + "fmla v20.4s, b0a.4s, a1.s[1]\n" + "ldr a3aq, [%[a_ptr3], #16]\n" + "fmla v24.4s, b0a.4s, a2.s[1]\n" + "fmla v28.4s, b0a.4s, a3.s[1]\n" + "ldr b0aq, [%[b_ptr]]\n" + + "fmla v17.4s, b1a.4s, a0.s[1]\n" + "fmla v21.4s, b1a.4s, a1.s[1]\n" + "subs %w[loops], %w[loops], #1\n" + "fmla v25.4s, b1a.4s, a2.s[1]\n" + "fmla v29.4s, b1a.4s, a3.s[1]\n" + "ldr b1aq, [%[b_ptr], #16]\n" + + "fmla v18.4s, b2a.4s, a0.s[1]\n" + "fmla v22.4s, b2a.4s, a1.s[1]\n" + "fmla v26.4s, b2a.4s, a2.s[1]\n" + "fmla v30.4s, b2a.4s, a3.s[1]\n" + "ldr b2aq, [%[b_ptr], #32]\n" + + "fmla v19.4s, b3a.4s, a0.s[1]\n" + "fmla v23.4s, b3a.4s, a1.s[1]\n" + "fmla v27.4s, b3a.4s, a2.s[1]\n" + "fmla v31.4s, b3a.4s, a3.s[1]\n" + "ldr b3aq, [%[b_ptr], #48]\n" + + // Unroll 2 + "fmla v16.4s, bb0.4s, a0.s[2]\n" + "fmla v20.4s, bb0.4s, a1.s[2]\n" + "add %[b_ptr], %[b_ptr], %[ldb]\n" + "fmla v24.4s, bb0.4s, a2.s[2]\n" + "fmla v28.4s, bb0.4s, a3.s[2]\n" + "ldr b0q, [%[b_ptr]]\n" + + "fmla v17.4s, bb1.4s, a0.s[2]\n" + "add %[a_ptr0], %[a_ptr0], #32\n" + "fmla v21.4s, bb1.4s, a1.s[2]\n" + "add %[a_ptr1], %[a_ptr1], #32\n" + "fmla v25.4s, bb1.4s, a2.s[2]\n" + "add %[a_ptr2], %[a_ptr2], #32\n" + "fmla v29.4s, bb1.4s, a3.s[2]\n" + "ldr b1q, [%[b_ptr], #16]\n" + + "fmla v18.4s, bb2.4s, a0.s[2]\n" + "add %[a_ptr3], %[a_ptr3], #32\n" + "fmla v22.4s, bb2.4s, a1.s[2]\n" + "fmla v26.4s, bb2.4s, a2.s[2]\n" + "fmla v30.4s, bb2.4s, a3.s[2]\n" + "ldr b2q, [%[b_ptr], #32]\n" + + "fmla v19.4s, bb3.4s, a0.s[2]\n" + "fmla v23.4s, bb3.4s, a1.s[2]\n" + "fmla v27.4s, bb3.4s, a2.s[2]\n" + "fmla v31.4s, bb3.4s, a3.s[2]\n" + "ldr b3q, [%[b_ptr], #48]\n" + + // Unroll 3 + "fmla v16.4s, b0a.4s, a0.s[3]\n" + "fmla v20.4s, b0a.4s, a1.s[3]\n" + "add %[b_ptr], %[b_ptr], %[ldb]\n" + "fmla v24.4s, b0a.4s, a2.s[3]\n" + "fmla v28.4s, b0a.4s, a3.s[3]\n" + "ldr b0aq, [%[b_ptr]]\n" + + "fmla v17.4s, b1a.4s, a0.s[3]\n" + "fmla v21.4s, b1a.4s, a1.s[3]\n" + "fmla v25.4s, b1a.4s, a2.s[3]\n" + "fmla v29.4s, b1a.4s, a3.s[3]\n" + "ldr b1aq, [%[b_ptr], #16]\n" + + "fmla v18.4s, b2a.4s, a0.s[3]\n" + "fmla v22.4s, b2a.4s, a1.s[3]\n" + "fmla v26.4s, b2a.4s, a2.s[3]\n" + "fmla v30.4s, b2a.4s, a3.s[3]\n" + "ldr b2aq, [%[b_ptr], #32]\n" + + "fmla v19.4s, b3a.4s, a0.s[3]\n" + "fmla v23.4s, b3a.4s, a1.s[3]\n" + "ldr a0q, [%[a_ptr0]]\n" + "fmla v27.4s, b3a.4s, a2.s[3]\n" + "fmla v31.4s, b3a.4s, a3.s[3]\n" + "ldr b3aq, [%[b_ptr], #48]\n" + + // Unroll 4 + "fmla v16.4s, bb0.4s, a0a.s[0]\n" + "fmla v20.4s, bb0.4s, a1a.s[0]\n" + "add %[b_ptr], %[b_ptr], %[ldb]\n" + "fmla v24.4s, bb0.4s, a2a.s[0]\n" + "fmla v28.4s, bb0.4s, a3a.s[0]\n" + "ldr b0q, [%[b_ptr]]\n" + + "fmla v17.4s, bb1.4s, a0a.s[0]\n" + "fmla v21.4s, bb1.4s, a1a.s[0]\n" + "ldr a1q, [%[a_ptr1]]\n" + "fmla v25.4s, bb1.4s, a2a.s[0]\n" + "fmla v29.4s, bb1.4s, a3a.s[0]\n" + "ldr b1q, [%[b_ptr], #16]\n" + + "fmla v18.4s, bb2.4s, a0a.s[0]\n" + "fmla v22.4s, bb2.4s, a1a.s[0]\n" + "ldr a2q, [%[a_ptr2]]\n" + "fmla v26.4s, bb2.4s, a2a.s[0]\n" + "fmla v30.4s, bb2.4s, a3a.s[0]\n" + "ldr b2q, [%[b_ptr], #32]\n" + + "fmla v19.4s, bb3.4s, a0a.s[0]\n" + "fmla v23.4s, bb3.4s, a1a.s[0]\n" + "ldr a3q, [%[a_ptr3]]\n" + "fmla v27.4s, bb3.4s, a2a.s[0]\n" + "fmla v31.4s, bb3.4s, a3a.s[0]\n" + "ldr b3q, [%[b_ptr], #48]\n" + + // Unroll 5 + "fmla v16.4s, b0a.4s, a0a.s[1]\n" + "fmla v20.4s, b0a.4s, a1a.s[1]\n" + "add %[b_ptr], %[b_ptr], %[ldb]\n" + "fmla v24.4s, b0a.4s, a2a.s[1]\n" + "fmla v28.4s, b0a.4s, a3a.s[1]\n" + "ldr b0aq, [%[b_ptr]]\n" + + "fmla v17.4s, b1a.4s, a0a.s[1]\n" + "fmla v21.4s, b1a.4s, a1a.s[1]\n" + "fmla v25.4s, b1a.4s, a2a.s[1]\n" + "fmla v29.4s, b1a.4s, a3a.s[1]\n" + "ldr b1aq, [%[b_ptr], #16]\n" + + "fmla v18.4s, b2a.4s, a0a.s[1]\n" + "fmla v22.4s, b2a.4s, a1a.s[1]\n" + "fmla v26.4s, b2a.4s, a2a.s[1]\n" + "fmla v30.4s, b2a.4s, a3a.s[1]\n" + "ldr b2aq, [%[b_ptr], #32]\n" + + "fmla v19.4s, b3a.4s, a0a.s[1]\n" + "fmla v23.4s, b3a.4s, a1a.s[1]\n" + "fmla v27.4s, b3a.4s, a2a.s[1]\n" + "fmla v31.4s, b3a.4s, a3a.s[1]\n" + "ldr b3aq, [%[b_ptr], #48]\n" + + // Unroll 6 + "fmla v16.4s, bb0.4s, a0a.s[2]\n" + "fmla v20.4s, bb0.4s, a1a.s[2]\n" + "add %[b_ptr], %[b_ptr], %[ldb]\n" + "fmla v24.4s, bb0.4s, a2a.s[2]\n" + "fmla v28.4s, bb0.4s, a3a.s[2]\n" + "ldr b0q, [%[b_ptr]]\n" + + "fmla v17.4s, bb1.4s, a0a.s[2]\n" + "fmla v21.4s, bb1.4s, a1a.s[2]\n" + "fmla v25.4s, bb1.4s, a2a.s[2]\n" + "fmla v29.4s, bb1.4s, a3a.s[2]\n" + "ldr b1q, [%[b_ptr], #16]\n" + + "fmla v18.4s, bb2.4s, a0a.s[2]\n" + "fmla v22.4s, bb2.4s, a1a.s[2]\n" + "fmla v26.4s, bb2.4s, a2a.s[2]\n" + "fmla v30.4s, bb2.4s, a3a.s[2]\n" + "ldr b2q, [%[b_ptr], #32]\n" + + "fmla v19.4s, bb3.4s, a0a.s[2]\n" + "fmla v23.4s, bb3.4s, a1a.s[2]\n" + "fmla v27.4s, bb3.4s, a2a.s[2]\n" + "fmla v31.4s, bb3.4s, a3a.s[2]\n" + "ldr b3q, [%[b_ptr], #48]\n" + + // Unroll 7 + "fmla v16.4s, b0a.4s, a0a.s[3]\n" + "fmla v20.4s, b0a.4s, a1a.s[3]\n" + "add %[b_ptr], %[b_ptr], %[ldb]\n" + "fmla v24.4s, b0a.4s, a2a.s[3]\n" + "fmla v28.4s, b0a.4s, a3a.s[3]\n" + "ldr b0aq, [%[b_ptr]]\n" + + "fmla v17.4s, b1a.4s, a0a.s[3]\n" + "fmla v21.4s, b1a.4s, a1a.s[3]\n" + "fmla v25.4s, b1a.4s, a2a.s[3]\n" + "fmla v29.4s, b1a.4s, a3a.s[3]\n" + "ldr b1aq, [%[b_ptr], #16]\n" + + "fmla v18.4s, b2a.4s, a0a.s[3]\n" + "fmla v22.4s, b2a.4s, a1a.s[3]\n" + "fmla v26.4s, b2a.4s, a2a.s[3]\n" + "fmla v30.4s, b2a.4s, a3a.s[3]\n" + "ldr b2aq, [%[b_ptr], #32]\n" + + "fmla v19.4s, b3a.4s, a0a.s[3]\n" + "fmla v23.4s, b3a.4s, a1a.s[3]\n" + "fmla v27.4s, b3a.4s, a2a.s[3]\n" + "fmla v31.4s, b3a.4s, a3a.s[3]\n" + "bne 1b\n" + + // Skip to here + "4:\n" + + // Detached final iteration + // Unroll 0 + "fmla v16.4s, bb0.4s, a0.s[0]\n" + "fmla v20.4s, bb0.4s, a1.s[0]\n" + "ldr b3aq, [%[b_ptr], #48]\n" + "fmla v24.4s, bb0.4s, a2.s[0]\n" + "add %[b_ptr], %[b_ptr], %[ldb]\n" + "fmla v28.4s, bb0.4s, a3.s[0]\n" + "ldr b0q, [%[b_ptr]]\n" + + "fmla v17.4s, bb1.4s, a0.s[0]\n" + "cbnz %w[oddk], 2f\n" // Deal with odd K before we load a0a + "fmla v21.4s, bb1.4s, a1.s[0]\n" + "ldr a0aq, [%[a_ptr0], #16]\n" + "fmla v25.4s, bb1.4s, a2.s[0]\n" + "fmla v29.4s, bb1.4s, a3.s[0]\n" + "ldr b1q, [%[b_ptr], #16]\n" + + "fmla v18.4s, bb2.4s, a0.s[0]\n" + "fmla v22.4s, bb2.4s, a1.s[0]\n" + "ldr a1aq, [%[a_ptr1], #16]\n" + "fmla v26.4s, bb2.4s, a2.s[0]\n" + "fmla v30.4s, bb2.4s, a3.s[0]\n" + "ldr b2q, [%[b_ptr], #32]\n" + + "fmla v19.4s, bb3.4s, a0.s[0]\n" + "fmla v23.4s, bb3.4s, a1.s[0]\n" + "ldr a2aq, [%[a_ptr2], #16]\n" + "fmla v27.4s, bb3.4s, a2.s[0]\n" + "fmla v31.4s, bb3.4s, a3.s[0]\n" + "ldr b3q, [%[b_ptr], #48]\n" + + // Unroll 1 + "fmla v16.4s, b0a.4s, a0.s[1]\n" + "add %[b_ptr], %[b_ptr], %[ldb]\n" + "fmla v20.4s, b0a.4s, a1.s[1]\n" + "ldr a3aq, [%[a_ptr3], #16]\n" + "fmla v24.4s, b0a.4s, a2.s[1]\n" + "fmla v28.4s, b0a.4s, a3.s[1]\n" + "ldr b0aq, [%[b_ptr]]\n" + + "fmla v17.4s, b1a.4s, a0.s[1]\n" + "fmla v21.4s, b1a.4s, a1.s[1]\n" + "subs %w[loops], %w[loops], #1\n" + "fmla v25.4s, b1a.4s, a2.s[1]\n" + "fmla v29.4s, b1a.4s, a3.s[1]\n" + "ldr b1aq, [%[b_ptr], #16]\n" + + "fmla v18.4s, b2a.4s, a0.s[1]\n" + "fmla v22.4s, b2a.4s, a1.s[1]\n" + "fmla v26.4s, b2a.4s, a2.s[1]\n" + "fmla v30.4s, b2a.4s, a3.s[1]\n" + "ldr b2aq, [%[b_ptr], #32]\n" + + "fmla v19.4s, b3a.4s, a0.s[1]\n" + "fmla v23.4s, b3a.4s, a1.s[1]\n" + "fmla v27.4s, b3a.4s, a2.s[1]\n" + "fmla v31.4s, b3a.4s, a3.s[1]\n" + "ldr b3aq, [%[b_ptr], #48]\n" + + // Unroll 2 + "fmla v16.4s, bb0.4s, a0.s[2]\n" + "fmla v20.4s, bb0.4s, a1.s[2]\n" + "add %[b_ptr], %[b_ptr], %[ldb]\n" + "fmla v24.4s, bb0.4s, a2.s[2]\n" + "fmla v28.4s, bb0.4s, a3.s[2]\n" + "ldr b0q, [%[b_ptr]]\n" + + "fmla v17.4s, bb1.4s, a0.s[2]\n" + "fmla v21.4s, bb1.4s, a1.s[2]\n" + "fmla v25.4s, bb1.4s, a2.s[2]\n" + "fmla v29.4s, bb1.4s, a3.s[2]\n" + "ldr b1q, [%[b_ptr], #16]\n" + + "fmla v18.4s, bb2.4s, a0.s[2]\n" + "fmla v22.4s, bb2.4s, a1.s[2]\n" + "fmla v26.4s, bb2.4s, a2.s[2]\n" + "fmla v30.4s, bb2.4s, a3.s[2]\n" + "ldr b2q, [%[b_ptr], #32]\n" + + "fmla v19.4s, bb3.4s, a0.s[2]\n" + "fmla v23.4s, bb3.4s, a1.s[2]\n" + "fmla v27.4s, bb3.4s, a2.s[2]\n" + "fmla v31.4s, bb3.4s, a3.s[2]\n" + "ldr b3q, [%[b_ptr], #48]\n" + + // Unroll 3 + "fmla v16.4s, b0a.4s, a0.s[3]\n" + "fmla v20.4s, b0a.4s, a1.s[3]\n" + "add %[b_ptr], %[b_ptr], %[ldb]\n" + "fmla v24.4s, b0a.4s, a2.s[3]\n" + "fmla v28.4s, b0a.4s, a3.s[3]\n" + "ldr b0aq, [%[b_ptr]]\n" + + "fmla v17.4s, b1a.4s, a0.s[3]\n" + "fmla v21.4s, b1a.4s, a1.s[3]\n" + "ldr a3aq, [%[a_ptr3], #16]\n" + "fmla v25.4s, b1a.4s, a2.s[3]\n" + "fmla v29.4s, b1a.4s, a3.s[3]\n" + "ldr b1aq, [%[b_ptr], #16]\n" + + "fmla v18.4s, b2a.4s, a0.s[3]\n" + "fmla v22.4s, b2a.4s, a1.s[3]\n" + "fmla v26.4s, b2a.4s, a2.s[3]\n" + "fmla v30.4s, b2a.4s, a3.s[3]\n" + "ldr b2aq, [%[b_ptr], #32]\n" + + "fmla v19.4s, b3a.4s, a0.s[3]\n" + "fmla v23.4s, b3a.4s, a1.s[3]\n" + "fmla v27.4s, b3a.4s, a2.s[3]\n" + "fmla v31.4s, b3a.4s, a3.s[3]\n" + "ldr b3aq, [%[b_ptr], #48]\n" + + // Unroll 4 + "fmla v16.4s, bb0.4s, a0a.s[0]\n" + "fmla v20.4s, bb0.4s, a1a.s[0]\n" + "add %[b_ptr], %[b_ptr], %[ldb]\n" + "fmla v24.4s, bb0.4s, a2a.s[0]\n" + "fmla v28.4s, bb0.4s, a3a.s[0]\n" + "ldr b0q, [%[b_ptr]]\n" + + "fmla v17.4s, bb1.4s, a0a.s[0]\n" + "fmla v21.4s, bb1.4s, a1a.s[0]\n" + "fmla v25.4s, bb1.4s, a2a.s[0]\n" + "fmla v29.4s, bb1.4s, a3a.s[0]\n" + "ldr b1q, [%[b_ptr], #16]\n" + + "fmla v18.4s, bb2.4s, a0a.s[0]\n" + "fmla v22.4s, bb2.4s, a1a.s[0]\n" + "fmla v26.4s, bb2.4s, a2a.s[0]\n" + "fmla v30.4s, bb2.4s, a3a.s[0]\n" + "ldr b2q, [%[b_ptr], #32]\n" + + "fmla v19.4s, bb3.4s, a0a.s[0]\n" + "fmla v23.4s, bb3.4s, a1a.s[0]\n" + "fmla v27.4s, bb3.4s, a2a.s[0]\n" + "fmla v31.4s, bb3.4s, a3a.s[0]\n" + "ldr b3q, [%[b_ptr], #48]\n" + + // Unroll 5 + "fmla v16.4s, b0a.4s, a0a.s[1]\n" + "fmla v20.4s, b0a.4s, a1a.s[1]\n" + "add %[b_ptr], %[b_ptr], %[ldb]\n" + "fmla v24.4s, b0a.4s, a2a.s[1]\n" + "fmla v28.4s, b0a.4s, a3a.s[1]\n" + "ldr b0aq, [%[b_ptr]]\n" + + "fmla v17.4s, b1a.4s, a0a.s[1]\n" + "fmla v21.4s, b1a.4s, a1a.s[1]\n" + "fmla v25.4s, b1a.4s, a2a.s[1]\n" + "fmla v29.4s, b1a.4s, a3a.s[1]\n" + "ldr b1aq, [%[b_ptr], #16]\n" + + "fmla v18.4s, b2a.4s, a0a.s[1]\n" + "fmla v22.4s, b2a.4s, a1a.s[1]\n" + "fmla v26.4s, b2a.4s, a2a.s[1]\n" + "fmla v30.4s, b2a.4s, a3a.s[1]\n" + "ldr b2aq, [%[b_ptr], #32]\n" + + "fmla v19.4s, b3a.4s, a0a.s[1]\n" + "fmla v23.4s, b3a.4s, a1a.s[1]\n" + "fmla v27.4s, b3a.4s, a2a.s[1]\n" + "fmla v31.4s, b3a.4s, a3a.s[1]\n" + "ldr b3aq, [%[b_ptr], #48]\n" + + // Unroll 6 + "fmla v16.4s, bb0.4s, a0a.s[2]\n" + "fmla v20.4s, bb0.4s, a1a.s[2]\n" + "fmla v24.4s, bb0.4s, a2a.s[2]\n" + "fmla v28.4s, bb0.4s, a3a.s[2]\n" + + "fmla v17.4s, bb1.4s, a0a.s[2]\n" + "fmla v21.4s, bb1.4s, a1a.s[2]\n" + "fmla v25.4s, bb1.4s, a2a.s[2]\n" + "fmla v29.4s, bb1.4s, a3a.s[2]\n" + + "fmla v18.4s, bb2.4s, a0a.s[2]\n" + "fmla v22.4s, bb2.4s, a1a.s[2]\n" + "fmla v26.4s, bb2.4s, a2a.s[2]\n" + "fmla v30.4s, bb2.4s, a3a.s[2]\n" + + "fmla v19.4s, bb3.4s, a0a.s[2]\n" + "fmla v23.4s, bb3.4s, a1a.s[2]\n" + "fmla v27.4s, bb3.4s, a2a.s[2]\n" + "fmla v31.4s, bb3.4s, a3a.s[2]\n" + + // Unroll 7 + "fmla v16.4s, b0a.4s, a0a.s[3]\n" + "fmla v17.4s, b1a.4s, a0a.s[3]\n" + "fmla v18.4s, b2a.4s, a0a.s[3]\n" + "fmla v19.4s, b3a.4s, a0a.s[3]\n" + + "fmla v20.4s, b0a.4s, a1a.s[3]\n" + "str q16, [%[c_ptr0]]\n" + "fmla v21.4s, b1a.4s, a1a.s[3]\n" + "str q17, [%[c_ptr0], #16]\n" + "fmla v22.4s, b2a.4s, a1a.s[3]\n" + "str q18, [%[c_ptr0], #32]\n" + "fmla v23.4s, b3a.4s, a1a.s[3]\n" + "str q19, [%[c_ptr0], #48]\n" + + "fmla v24.4s, b0a.4s, a2a.s[3]\n" + "str q20, [%[c_ptr1]]\n" + "fmla v25.4s, b1a.4s, a2a.s[3]\n" + "str q21, [%[c_ptr1], #16]\n" + "fmla v26.4s, b2a.4s, a2a.s[3]\n" + "str q22, [%[c_ptr1], #32]\n" + "fmla v27.4s, b3a.4s, a2a.s[3]\n" + "str q23, [%[c_ptr1], #48]\n" + + "fmla v28.4s, b0a.4s, a3a.s[3]\n" + "str q24, [%[c_ptr2]]\n" + "fmla v29.4s, b1a.4s, a3a.s[3]\n" + "str q25, [%[c_ptr2], #16]\n" + "fmla v30.4s, b2a.4s, a3a.s[3]\n" + "str q26, [%[c_ptr2], #32]\n" + "fmla v31.4s, b3a.4s, a3a.s[3]\n" + "str q27, [%[c_ptr2], #48]\n" + "b 3f\n" + + // Odd K case: Just do 4 more. + "2:\n" + "fmla v21.4s, bb1.4s, a1.s[0]\n" + "fmla v25.4s, bb1.4s, a2.s[0]\n" + "fmla v29.4s, bb1.4s, a3.s[0]\n" + "ldr b1q, [%[b_ptr], #16]\n" + + "fmla v18.4s, bb2.4s, a0.s[0]\n" + "fmla v22.4s, bb2.4s, a1.s[0]\n" + "fmla v26.4s, bb2.4s, a2.s[0]\n" + "fmla v30.4s, bb2.4s, a3.s[0]\n" + "ldr b2q, [%[b_ptr], #32]\n" + + "fmla v19.4s, bb3.4s, a0.s[0]\n" + "fmla v23.4s, bb3.4s, a1.s[0]\n" + "fmla v27.4s, bb3.4s, a2.s[0]\n" + "fmla v31.4s, bb3.4s, a3.s[0]\n" + "ldr b3q, [%[b_ptr], #48]\n" + + // Unroll 1 + "fmla v16.4s, b0a.4s, a0.s[1]\n" + "add %[b_ptr], %[b_ptr], %[ldb]\n" + "fmla v20.4s, b0a.4s, a1.s[1]\n" + "fmla v24.4s, b0a.4s, a2.s[1]\n" + "fmla v28.4s, b0a.4s, a3.s[1]\n" + "ldr b0aq, [%[b_ptr]]\n" + + "fmla v17.4s, b1a.4s, a0.s[1]\n" + "fmla v21.4s, b1a.4s, a1.s[1]\n" + "subs %w[loops], %w[loops], #1\n" + "fmla v25.4s, b1a.4s, a2.s[1]\n" + "fmla v29.4s, b1a.4s, a3.s[1]\n" + "ldr b1aq, [%[b_ptr], #16]\n" + + "fmla v18.4s, b2a.4s, a0.s[1]\n" + "fmla v22.4s, b2a.4s, a1.s[1]\n" + "fmla v26.4s, b2a.4s, a2.s[1]\n" + "fmla v30.4s, b2a.4s, a3.s[1]\n" + "ldr b2aq, [%[b_ptr], #32]\n" + + "fmla v19.4s, b3a.4s, a0.s[1]\n" + "fmla v23.4s, b3a.4s, a1.s[1]\n" + "fmla v27.4s, b3a.4s, a2.s[1]\n" + "fmla v31.4s, b3a.4s, a3.s[1]\n" + "ldr b3aq, [%[b_ptr], #48]\n" + + // Unroll 2 + "fmla v16.4s, bb0.4s, a0.s[2]\n" + "fmla v20.4s, bb0.4s, a1.s[2]\n" + "fmla v24.4s, bb0.4s, a2.s[2]\n" + "fmla v28.4s, bb0.4s, a3.s[2]\n" + + "fmla v17.4s, bb1.4s, a0.s[2]\n" + "fmla v21.4s, bb1.4s, a1.s[2]\n" + "fmla v25.4s, bb1.4s, a2.s[2]\n" + "fmla v29.4s, bb1.4s, a3.s[2]\n" + + "fmla v18.4s, bb2.4s, a0.s[2]\n" + "fmla v22.4s, bb2.4s, a1.s[2]\n" + "fmla v26.4s, bb2.4s, a2.s[2]\n" + "fmla v30.4s, bb2.4s, a3.s[2]\n" + + "fmla v19.4s, bb3.4s, a0.s[2]\n" + "fmla v23.4s, bb3.4s, a1.s[2]\n" + "fmla v27.4s, bb3.4s, a2.s[2]\n" + "fmla v31.4s, bb3.4s, a3.s[2]\n" + + // Unroll 3 + "fmla v16.4s, b0a.4s, a0.s[3]\n" + "fmla v17.4s, b1a.4s, a0.s[3]\n" + "fmla v18.4s, b2a.4s, a0.s[3]\n" + "fmla v19.4s, b3a.4s, a0.s[3]\n" + + "fmla v20.4s, b0a.4s, a1.s[3]\n" + "str q16, [%[c_ptr0]]\n" + "fmla v21.4s, b1a.4s, a1.s[3]\n" + "str q17, [%[c_ptr0], #16]\n" + "fmla v22.4s, b2a.4s, a1.s[3]\n" + "str q18, [%[c_ptr0], #32]\n" + "fmla v23.4s, b3a.4s, a1.s[3]\n" + "str q19, [%[c_ptr0], #48]\n" + + "fmla v24.4s, b0a.4s, a2.s[3]\n" + "str q20, [%[c_ptr1]]\n" + "fmla v25.4s, b1a.4s, a2.s[3]\n" + "str q21, [%[c_ptr1], #16]\n" + "fmla v26.4s, b2a.4s, a2.s[3]\n" + "str q22, [%[c_ptr1], #32]\n" + "fmla v27.4s, b3a.4s, a2.s[3]\n" + "str q23, [%[c_ptr1], #48]\n" + + "fmla v28.4s, b0a.4s, a3.s[3]\n" + "str q24, [%[c_ptr2]]\n" + "fmla v29.4s, b1a.4s, a3.s[3]\n" + "str q25, [%[c_ptr2], #16]\n" + "fmla v30.4s, b2a.4s, a3.s[3]\n" + "str q26, [%[c_ptr2], #32]\n" + "fmla v31.4s, b3a.4s, a3.s[3]\n" + "str q27, [%[c_ptr2], #48]\n" + + "3:\n" + "str q28, [%[c_ptr3]]\n" + "str q29, [%[c_ptr3], #16]\n" + "str q30, [%[c_ptr3], #32]\n" + "str q31, [%[c_ptr3], #48]\n" + + : [a_ptr0] "+r"(a_ptr0), [a_ptr1] "+r"(a_ptr1), [a_ptr2] "+r"(a_ptr2), [a_ptr3] "+r"(a_ptr3), + [b_ptr] "+r"(b_ptr), [loops] "+r"(loops) + : [ldb] "r"(ldbb), [oddk] "r"(oddk), [beta0] "r"(beta0), [betaptr] "r"(&beta), + [c_ptr0] "r"(c_ptr0), [c_ptr1] "r"(c_ptr1), [c_ptr2] "r"(c_ptr2), [c_ptr3] "r"(c_ptr3) + : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", + "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", + "cc", "memory"); + } + } +} + +} // namespace arm_gemm + +#endif // __aarch64__ diff --git a/src/core/NEON/kernels/arm_gemm/kernels/a64_sgemv_pretransposed.hpp b/src/core/NEON/kernels/arm_gemm/kernels/a64_sgemv_pretransposed.hpp new file mode 100644 index 0000000000..c89514f98e --- /dev/null +++ b/src/core/NEON/kernels/arm_gemm/kernels/a64_sgemv_pretransposed.hpp @@ -0,0 +1,67 @@ +/* + * Copyright (c) 2017-2018 ARM Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ +#pragma once + +#ifdef __aarch64__ + +namespace arm_gemm +{ +// Actual kernel implementations +void a64_sgemv_pretransposed(const float *, int, const float *, float *, float, int, int); + +// Pretransposed SGEMV strategy class. +class sgemv_pretransposed +{ +public: + typedef float operand_type; + typedef float result_type; + + typedef void (*kern_type)(const float *, int, const float *, float *, float, int, int); + + /* Describes the data layout for matrix (A) input */ + + /* Note that often GEMV is expressed as a GEMM with M=1, i.e. A is the + * (row) vector and B is the matrix, but the standard GEMV arrangement + * is matrix A times (column) vector X. "A_transpose" is expressed in + * terms of this standard arrangement, so if the A matrix is in fact the + * B matrix from a GEMM call, the sense of the transpose needs to be + * reversed. */ + static const int A_interleave = 32; + static const int A_block = 1; + static const bool A_transpose = false; + + /* Kernel blocking parameters */ + static const int out_width = 32; + static const int k_unroll = 1; + + kern_type kernel = a64_sgemv_pretransposed; + + sgemv_pretransposed(const CPUInfo *ci) + { + } +}; + +} // namespace arm_gemm + +#endif // __aarch64__ diff --git a/src/core/NEON/kernels/arm_gemm/kernels/a64_sgemv_pretransposed/generic.cpp b/src/core/NEON/kernels/arm_gemm/kernels/a64_sgemv_pretransposed/generic.cpp new file mode 100644 index 0000000000..290759822a --- /dev/null +++ b/src/core/NEON/kernels/arm_gemm/kernels/a64_sgemv_pretransposed/generic.cpp @@ -0,0 +1,794 @@ +/* + * Copyright (c) 2017-2018 ARM Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ +#ifdef __aarch64__ + +#include <algorithm> + +#include <arm_neon.h> + +#include "../../asmlib.hpp" +#include "../../utils.hpp" + +namespace arm_gemm +{ +void a64_sgemv_pretransposed(const float *A, int lda, const float *X, float *Y, float beta, int M, int N) +{ + const bool beta0 = (beta == 0.0f); + const bool beta1 = (beta == 1.0f); + + for(int x = 0; x < N; x += 32) + { + float *y_ptr = Y + x; + + // How many elements are we processing in this loop? + int l = std::min(N - x, 32); + + register float32x4_t r0 asm("v24"); + register float32x4_t r1 asm("v25"); + register float32x4_t r2 asm("v26"); + register float32x4_t r3 asm("v27"); + register float32x4_t r4 asm("v28"); + register float32x4_t r5 asm("v29"); + register float32x4_t r6 asm("v30"); + register float32x4_t r7 asm("v31"); + + register float32x4_t x0 asm("v0"); + register float32x4_t x0a asm("v1"); + + const float *x_ptr = X; + const float *a_ptr = A + ((x / 32) * lda); + + if(beta0) + { + r0 = r1 = r2 = r3 = r4 = r5 = r6 = r7 = vdupq_n_f32(0.0f); + } + else + { + if(l == 32) + { + // Fastest path - load all 8 vectors + r0 = vld1q_f32(y_ptr); + r1 = vld1q_f32(y_ptr + 4); + r2 = vld1q_f32(y_ptr + 8); + r3 = vld1q_f32(y_ptr + 12); + r4 = vld1q_f32(y_ptr + 16); + r5 = vld1q_f32(y_ptr + 20); + r6 = vld1q_f32(y_ptr + 24); + r7 = vld1q_f32(y_ptr + 28); + } + else + { + // Slow case - leftovers. Note that we don't care about + // out-of-range vectors and lanes as we will throw them away at + // the end. + int vecs = l / 4; // How many leftover vectors? + int oddbits = l % 4; // And how many odd single values? + + if(oddbits) + { + // Load the outstanding odd values into a vector first + float32x4_t oddvec = vdupq_n_f32(0.0f); // This does not really need to be initialized, but the compiler has a hard time with that. + float *oddbase = y_ptr + l - oddbits; + + switch(oddbits) + { + case 3: + oddvec = vld1q_lane_f32(oddbase + 2, oddvec, 2); + // fall through + case 2: + oddvec = vld1q_lane_f32(oddbase + 1, oddvec, 1); + // fall through + case 1: + oddvec = vld1q_lane_f32(oddbase, oddvec, 0); + break; + + default: + UNREACHABLE("Impossible case in switch."); + } + + // Now load the whole vectors, putting the oddments in when we run out. + do + { + if(vecs == 0) + { + r0 = oddvec; + break; + } + + r0 = vld1q_f32(y_ptr); + if(--vecs == 0) + { + r1 = oddvec; + break; + } + + r1 = vld1q_f32(y_ptr + 4); + if(--vecs == 0) + { + r2 = oddvec; + break; + } + + r2 = vld1q_f32(y_ptr + 8); + if(--vecs == 0) + { + r3 = oddvec; + break; + } + + r3 = vld1q_f32(y_ptr + 12); + if(--vecs == 0) + { + r4 = oddvec; + break; + } + + r4 = vld1q_f32(y_ptr + 16); + if(--vecs == 0) + { + r5 = oddvec; + break; + } + + r5 = vld1q_f32(y_ptr + 20); + if(--vecs == 0) + { + r6 = oddvec; + break; + } + + r6 = vld1q_f32(y_ptr + 24); + r7 = oddvec; + } + while(0); + } + else + { + // Slightly less slow path - just load the whole vectors + do + { + // It can't be the case that oddbits==0 AND vecs==0 or we wouldn't be here. + if(vecs == 0) + { + UNREACHABLE("Impossible lack of work to do"); + } + + r0 = vld1q_f32(y_ptr); + if(--vecs == 0) + { + break; + } + + r1 = vld1q_f32(y_ptr + 4); + if(--vecs == 0) + { + break; + } + + r2 = vld1q_f32(y_ptr + 8); + if(--vecs == 0) + { + break; + } + + r3 = vld1q_f32(y_ptr + 12); + if(--vecs == 0) + { + break; + } + + r4 = vld1q_f32(y_ptr + 16); + if(--vecs == 0) + { + break; + } + + r5 = vld1q_f32(y_ptr + 20); + if(--vecs == 0) + { + break; + } + + r6 = vld1q_f32(y_ptr + 24); + } + while(0); + } + } + + if(!beta1) + { + const float32x4_t vb = vdupq_n_f32(beta); + + r0 = vmulq_f32(r0, vb); + r1 = vmulq_f32(r1, vb); + r2 = vmulq_f32(r2, vb); + r3 = vmulq_f32(r3, vb); + r4 = vmulq_f32(r4, vb); + r5 = vmulq_f32(r5, vb); + r6 = vmulq_f32(r6, vb); + r7 = vmulq_f32(r7, vb); + } + } + + if(M >= 8) + { + int k = (M / 8) - 1; + x0 = vld1q_f32(x_ptr); + + __asm __volatile( + "ldr q2, [%[a_ptr], #0]\n" + "ldr q3, [%[a_ptr], #16]\n" + "ldr q4, [%[a_ptr], #32]\n" + "ldr q5, [%[a_ptr], #48]\n" + "ldr q6, [%[a_ptr], #64]\n" + "ldr q7, [%[a_ptr], #80]\n" + "ldr q8, [%[a_ptr], #96]\n" + "ldr q9, [%[a_ptr], #112]\n" + "ldr q10, [%[a_ptr], #128]\n" + "ldr q11, [%[a_ptr], #144]\n" + "ldr q12, [%[a_ptr], #160]\n" + "ldr q13, [%[a_ptr], #176]\n" + "ldr q14, [%[a_ptr], #192]\n" + "ldr q15, [%[a_ptr], #208]\n" + "ldr q16, [%[a_ptr], #224]\n" + "ldr q17, [%[a_ptr], #240]\n" + "ldr q18, [%[a_ptr], #256]\n" + "ldr q19, [%[a_ptr], #272]\n" + "ldr q20, [%[a_ptr], #288]\n" + "ldr q21, [%[a_ptr], #304]\n" + "ldr q22, [%[a_ptr], #320]\n" + "ldr q23, [%[a_ptr], #336]\n" ASM_PREFETCH("[%[a_ptr], #384]") + ASM_PREFETCH("[%[a_ptr], #448]") + ASM_PREFETCH("[%[a_ptr], #512]") + ASM_PREFETCH("[%[a_ptr], #576]") + ASM_PREFETCH("[%[a_ptr], #640]") + ASM_PREFETCH("[%[a_ptr], #704]") + ASM_PREFETCH("[%[a_ptr], #768]") + ASM_PREFETCH("[%[a_ptr], #832]") + ASM_PREFETCH("[%[a_ptr], #896]") + ASM_PREFETCH("[%[a_ptr], #960]") + ASM_PREFETCH("[%[a_ptr], #1024]") + ASM_PREFETCH("[%[a_ptr], #1088]") + ASM_PREFETCH("[%[a_ptr], #1152]") + ASM_PREFETCH("[%[a_ptr], #1216]") + ASM_PREFETCH("[%[a_ptr], #1280]") + ASM_PREFETCH("[%[a_ptr], #1344]") + ASM_PREFETCH("[%[a_ptr], #1408]") + ASM_PREFETCH("[%[a_ptr], #1472]") + ASM_PREFETCH("[%[a_ptr], #1536]") + ASM_PREFETCH("[%[a_ptr], #1600]") + ASM_PREFETCH("[%[a_ptr], #1664]") + ASM_PREFETCH("[%[a_ptr], #1728]") + ASM_PREFETCH("[%[a_ptr], #1792]") + ASM_PREFETCH("[%[a_ptr], #1856]") + ASM_PREFETCH("[%[a_ptr], #1920]") + ASM_PREFETCH("[%[a_ptr], #1984]") + "add %[a_ptr], %[a_ptr], #352\n" + + "cbz %w[k], 2f\n" + + "1:\n" + // Unroll 0 + "fmla %[r0].4s, v2.4s, %[x0].s[0]\n" + "ldr %q[x0a], [%[x_ptr], #16]\n" + "fmla %[r1].4s, v3.4s, %[x0].s[0]\n" + "ldr q3, [%[a_ptr], #0]\n" + "subs %w[k], %w[k], #1\n" + "fmla %[r2].4s, v4.4s, %[x0].s[0]\n" + "ldr q4, [%[a_ptr], #16]\n" + "fmla %[r3].4s, v5.4s, %[x0].s[0]\n" + "ldr q5, [%[a_ptr], #32]\n" + "add %[x_ptr], %[x_ptr], #32\n" ASM_PREFETCH("[%[a_ptr], #1664]") + "fmla %[r4].4s, v6.4s, %[x0].s[0]\n" + "ldr q6, [%[a_ptr], #48]\n" + "fmla %[r5].4s, v7.4s, %[x0].s[0]\n" + "ldr q7, [%[a_ptr], #64]\n" + "fmla %[r6].4s, v8.4s, %[x0].s[0]\n" + "ldr q8, [%[a_ptr], #80]\n" + "fmla %[r7].4s, v9.4s, %[x0].s[0]\n" + "ldr q9, [%[a_ptr], #96]\n" ASM_PREFETCH("[%[a_ptr], #1728]") + + // Unroll 1 + "fmla %[r0].4s, v10.4s, %[x0].s[1]\n" + "ldr q10, [%[a_ptr], #112]\n" + "fmla %[r1].4s, v11.4s, %[x0].s[1]\n" + "ldr q11, [%[a_ptr], #128]\n" + "fmla %[r2].4s, v12.4s, %[x0].s[1]\n" + "ldr q12, [%[a_ptr], #144]\n" + "fmla %[r3].4s, v13.4s, %[x0].s[1]\n" + "ldr q13, [%[a_ptr], #160]\n" ASM_PREFETCH("[%[a_ptr], #1792]") + "fmla %[r4].4s, v14.4s, %[x0].s[1]\n" + "ldr q14, [%[a_ptr], #176]\n" + "fmla %[r5].4s, v15.4s, %[x0].s[1]\n" + "ldr q15, [%[a_ptr], #192]\n" + "fmla %[r6].4s, v16.4s, %[x0].s[1]\n" + "ldr q16, [%[a_ptr], #208]\n" + "fmla %[r7].4s, v17.4s, %[x0].s[1]\n" + "ldr q17, [%[a_ptr], #224]\n" ASM_PREFETCH("[%[a_ptr], #1856]") + + // Unroll 2 + "fmla %[r0].4s, v18.4s, %[x0].s[2]\n" + "ldr q18, [%[a_ptr], #240]\n" + "fmla %[r1].4s, v19.4s, %[x0].s[2]\n" + "ldr q19, [%[a_ptr], #256]\n" + "fmla %[r2].4s, v20.4s, %[x0].s[2]\n" + "ldr q20, [%[a_ptr], #272]\n" + "fmla %[r3].4s, v21.4s, %[x0].s[2]\n" + "ldr q21, [%[a_ptr], #288]\n" ASM_PREFETCH("[%[a_ptr], #1920]") + "fmla %[r4].4s, v22.4s, %[x0].s[2]\n" + "ldr q22, [%[a_ptr], #304]\n" + "fmla %[r5].4s, v23.4s, %[x0].s[2]\n" + "ldr q23, [%[a_ptr], #320]\n" + "fmla %[r6].4s, v3.4s, %[x0].s[2]\n" + "ldr q2, [%[a_ptr], #336]\n" + "ldr q3, [%[a_ptr], #352]\n" + "fmla %[r7].4s, v4.4s, %[x0].s[2]\n" + "ldr q4, [%[a_ptr], #368]\n" ASM_PREFETCH("[%[a_ptr], #1984]") + + // Unroll 3 + "fmla %[r0].4s, v5.4s, %[x0].s[3]\n" + "ldr q5, [%[a_ptr], #384]\n" + "fmla %[r1].4s, v6.4s, %[x0].s[3]\n" + "ldr q6, [%[a_ptr], #400]\n" + "fmla %[r2].4s, v7.4s, %[x0].s[3]\n" + "ldr q7, [%[a_ptr], #416]\n" + "fmla %[r3].4s, v8.4s, %[x0].s[3]\n" ASM_PREFETCH("[%[a_ptr], #2048]") + "ldr q8, [%[a_ptr], #432]\n" + "fmla %[r4].4s, v9.4s, %[x0].s[3]\n" + "ldr q9, [%[a_ptr], #448]\n" + "fmla %[r5].4s, v10.4s, %[x0].s[3]\n" + "ldr q10, [%[a_ptr], #464]\n" + "fmla %[r6].4s, v11.4s, %[x0].s[3]\n" + "ldr q11, [%[a_ptr], #480]\n" + "fmla %[r7].4s, v12.4s, %[x0].s[3]\n" + "ldr q12, [%[a_ptr], #496]\n" ASM_PREFETCH("[%[a_ptr], #2112]") + + // Unroll 4 + "fmla %[r0].4s, v13.4s, %[x0a].s[0]\n" + "ldr %q[x0], [%[x_ptr]]\n" + "fmla %[r1].4s, v14.4s, %[x0a].s[0]\n" + "ldr q14, [%[a_ptr], #512]\n" + "fmla %[r2].4s, v15.4s, %[x0a].s[0]\n" + "ldr q15, [%[a_ptr], #528]\n" + "fmla %[r3].4s, v16.4s, %[x0a].s[0]\n" ASM_PREFETCH("[%[a_ptr], #2176]") + "ldr q16, [%[a_ptr], #544]\n" + "fmla %[r4].4s, v17.4s, %[x0a].s[0]\n" + "ldr q17, [%[a_ptr], #560]\n" + "fmla %[r5].4s, v18.4s, %[x0a].s[0]\n" + "ldr q18, [%[a_ptr], #576]\n" + "fmla %[r6].4s, v19.4s, %[x0a].s[0]\n" + "ldr q19, [%[a_ptr], #592]\n" + "fmla %[r7].4s, v20.4s, %[x0a].s[0]\n" + "ldr q20, [%[a_ptr], #608]\n" ASM_PREFETCH("[%[a_ptr], #2240]") + + // Unroll 5 + "fmla %[r0].4s, v21.4s, %[x0a].s[1]\n" + "ldr q21, [%[a_ptr], #624]\n" + "fmla %[r1].4s, v22.4s, %[x0a].s[1]\n" + "ldr q22, [%[a_ptr], #640]\n" + "fmla %[r2].4s, v23.4s, %[x0a].s[1]\n" + "ldr q23, [%[a_ptr], #656]\n" + "fmla %[r3].4s, v2.4s, %[x0a].s[1]\n" + "ldr q2, [%[a_ptr], #672]\n" ASM_PREFETCH("[%[a_ptr], #2304]") + "fmla %[r4].4s, v3.4s, %[x0a].s[1]\n" + "ldr q3, [%[a_ptr], #688]\n" + "fmla %[r5].4s, v4.4s, %[x0a].s[1]\n" + "ldr q4, [%[a_ptr], #704]\n" + "fmla %[r6].4s, v5.4s, %[x0a].s[1]\n" + "ldr q5, [%[a_ptr], #720]\n" + "fmla %[r7].4s, v6.4s, %[x0a].s[1]\n" + "ldr q6, [%[a_ptr], #736]\n" ASM_PREFETCH("[%[a_ptr], #2368]") + + // Unroll 6 + "fmla %[r0].4s, v7.4s, %[x0a].s[2]\n" + "ldr q7, [%[a_ptr], #752]\n" + "fmla %[r1].4s, v8.4s, %[x0a].s[2]\n" + "ldr q8, [%[a_ptr], #768]\n" + "fmla %[r2].4s, v9.4s, %[x0a].s[2]\n" + "ldr q9, [%[a_ptr], #784]\n" + "fmla %[r3].4s, v10.4s, %[x0a].s[2]\n" + "ldr q10, [%[a_ptr], #800]\n" ASM_PREFETCH("[%[a_ptr], #2432]") + "fmla %[r4].4s, v11.4s, %[x0a].s[2]\n" + "ldr q11, [%[a_ptr], #816]\n" + "fmla %[r5].4s, v12.4s, %[x0a].s[2]\n" + "ldr q12, [%[a_ptr], #832]\n" + "fmla %[r6].4s, v14.4s, %[x0a].s[2]\n" + "ldr q13, [%[a_ptr], #848]\n" + "ldr q14, [%[a_ptr], #864]\n" + "fmla %[r7].4s, v15.4s, %[x0a].s[2]\n" + "ldr q15, [%[a_ptr], #880]\n" ASM_PREFETCH("[%[a_ptr], #2496]") + + // Unroll 7 + "fmla %[r0].4s, v16.4s, %[x0a].s[3]\n" + "ldr q16, [%[a_ptr], #896]\n" + "fmla %[r1].4s, v17.4s, %[x0a].s[3]\n" + "ldr q17, [%[a_ptr], #912]\n" + "fmla %[r2].4s, v18.4s, %[x0a].s[3]\n" + "ldr q18, [%[a_ptr], #928]\n" + "fmla %[r3].4s, v19.4s, %[x0a].s[3]\n" ASM_PREFETCH("[%[a_ptr], #2560]") + "ldr q19, [%[a_ptr], #944]\n" + "fmla %[r4].4s, v20.4s, %[x0a].s[3]\n" + "ldr q20, [%[a_ptr], #960]\n" + "fmla %[r5].4s, v21.4s, %[x0a].s[3]\n" + "ldr q21, [%[a_ptr], #976]\n" + "add %[a_ptr], %[a_ptr], #1024\n" + "fmla %[r6].4s, v22.4s, %[x0a].s[3]\n" + "ldr q22, [%[a_ptr], #-32]\n" + "fmla %[r7].4s, v23.4s, %[x0a].s[3]\n" + "ldr q23, [%[a_ptr], #-16]\n" ASM_PREFETCH("[%[a_ptr], #1600]") + "bne 1b\n" + + // Detached final iteration + "2:\n" + + // Unroll 0 + "fmla %[r0].4s, v2.4s, %[x0].s[0]\n" + "ldr %q[x0a], [%[x_ptr], #16]\n" + "fmla %[r1].4s, v3.4s, %[x0].s[0]\n" + "ldr q3, [%[a_ptr], #0]\n" + "subs %w[k], %w[k], #1\n" + "fmla %[r2].4s, v4.4s, %[x0].s[0]\n" + "ldr q4, [%[a_ptr], #16]\n" + "fmla %[r3].4s, v5.4s, %[x0].s[0]\n" + "ldr q5, [%[a_ptr], #32]\n" + "add %[x_ptr], %[x_ptr], #32\n" + "fmla %[r4].4s, v6.4s, %[x0].s[0]\n" + "ldr q6, [%[a_ptr], #48]\n" + "fmla %[r5].4s, v7.4s, %[x0].s[0]\n" + "ldr q7, [%[a_ptr], #64]\n" + "fmla %[r6].4s, v8.4s, %[x0].s[0]\n" + "ldr q8, [%[a_ptr], #80]\n" + "fmla %[r7].4s, v9.4s, %[x0].s[0]\n" + "ldr q9, [%[a_ptr], #96]\n" + + // Unroll 1 + "fmla %[r0].4s, v10.4s, %[x0].s[1]\n" + "ldr q10, [%[a_ptr], #112]\n" + "fmla %[r1].4s, v11.4s, %[x0].s[1]\n" + "ldr q11, [%[a_ptr], #128]\n" + "fmla %[r2].4s, v12.4s, %[x0].s[1]\n" + "ldr q12, [%[a_ptr], #144]\n" + "fmla %[r3].4s, v13.4s, %[x0].s[1]\n" + "ldr q13, [%[a_ptr], #160]\n" + "fmla %[r4].4s, v14.4s, %[x0].s[1]\n" + "ldr q14, [%[a_ptr], #176]\n" + "fmla %[r5].4s, v15.4s, %[x0].s[1]\n" + "ldr q15, [%[a_ptr], #192]\n" + "fmla %[r6].4s, v16.4s, %[x0].s[1]\n" + "ldr q16, [%[a_ptr], #208]\n" + "fmla %[r7].4s, v17.4s, %[x0].s[1]\n" + "ldr q17, [%[a_ptr], #224]\n" + + // Unroll 2 + "fmla %[r0].4s, v18.4s, %[x0].s[2]\n" + "ldr q18, [%[a_ptr], #240]\n" + "fmla %[r1].4s, v19.4s, %[x0].s[2]\n" + "ldr q19, [%[a_ptr], #256]\n" + "fmla %[r2].4s, v20.4s, %[x0].s[2]\n" + "ldr q20, [%[a_ptr], #272]\n" + "fmla %[r3].4s, v21.4s, %[x0].s[2]\n" + "ldr q21, [%[a_ptr], #288]\n" + "fmla %[r4].4s, v22.4s, %[x0].s[2]\n" + "ldr q22, [%[a_ptr], #304]\n" + "fmla %[r5].4s, v23.4s, %[x0].s[2]\n" + "ldr q23, [%[a_ptr], #320]\n" + "fmla %[r6].4s, v3.4s, %[x0].s[2]\n" + "ldr q2, [%[a_ptr], #336]\n" + "ldr q3, [%[a_ptr], #352]\n" + "fmla %[r7].4s, v4.4s, %[x0].s[2]\n" + "ldr q4, [%[a_ptr], #368]\n" + + // Unroll 3 + "fmla %[r0].4s, v5.4s, %[x0].s[3]\n" + "ldr q5, [%[a_ptr], #384]\n" + "fmla %[r1].4s, v6.4s, %[x0].s[3]\n" + "ldr q6, [%[a_ptr], #400]\n" + "fmla %[r2].4s, v7.4s, %[x0].s[3]\n" + "ldr q7, [%[a_ptr], #416]\n" + "fmla %[r3].4s, v8.4s, %[x0].s[3]\n" + "ldr q8, [%[a_ptr], #432]\n" + "fmla %[r4].4s, v9.4s, %[x0].s[3]\n" + "ldr q9, [%[a_ptr], #448]\n" + "fmla %[r5].4s, v10.4s, %[x0].s[3]\n" + "ldr q10, [%[a_ptr], #464]\n" + "fmla %[r6].4s, v11.4s, %[x0].s[3]\n" + "ldr q11, [%[a_ptr], #480]\n" + "fmla %[r7].4s, v12.4s, %[x0].s[3]\n" + "ldr q12, [%[a_ptr], #496]\n" + + // Unroll 4 + "fmla %[r0].4s, v13.4s, %[x0a].s[0]\n" + "fmla %[r1].4s, v14.4s, %[x0a].s[0]\n" + "ldr q14, [%[a_ptr], #512]\n" + "fmla %[r2].4s, v15.4s, %[x0a].s[0]\n" + "ldr q15, [%[a_ptr], #528]\n" + "fmla %[r3].4s, v16.4s, %[x0a].s[0]\n" + "ldr q16, [%[a_ptr], #544]\n" + "fmla %[r4].4s, v17.4s, %[x0a].s[0]\n" + "ldr q17, [%[a_ptr], #560]\n" + "fmla %[r5].4s, v18.4s, %[x0a].s[0]\n" + "ldr q18, [%[a_ptr], #576]\n" + "fmla %[r6].4s, v19.4s, %[x0a].s[0]\n" + "ldr q19, [%[a_ptr], #592]\n" + "fmla %[r7].4s, v20.4s, %[x0a].s[0]\n" + "ldr q20, [%[a_ptr], #608]\n" + + // Unroll 5 + "fmla %[r0].4s, v21.4s, %[x0a].s[1]\n" + "ldr q21, [%[a_ptr], #624]\n" + "fmla %[r1].4s, v22.4s, %[x0a].s[1]\n" + "ldr q22, [%[a_ptr], #640]\n" + "fmla %[r2].4s, v23.4s, %[x0a].s[1]\n" + "ldr q23, [%[a_ptr], #656]\n" + "fmla %[r3].4s, v2.4s, %[x0a].s[1]\n" + "add %[a_ptr], %[a_ptr], #672\n" + "fmla %[r4].4s, v3.4s, %[x0a].s[1]\n" + "fmla %[r5].4s, v4.4s, %[x0a].s[1]\n" + "fmla %[r6].4s, v5.4s, %[x0a].s[1]\n" + "fmla %[r7].4s, v6.4s, %[x0a].s[1]\n" + + // Unroll 6 + "fmla %[r0].4s, v7.4s, %[x0a].s[2]\n" + "fmla %[r1].4s, v8.4s, %[x0a].s[2]\n" + "fmla %[r2].4s, v9.4s, %[x0a].s[2]\n" + "fmla %[r3].4s, v10.4s, %[x0a].s[2]\n" + "fmla %[r4].4s, v11.4s, %[x0a].s[2]\n" + "fmla %[r5].4s, v12.4s, %[x0a].s[2]\n" + "fmla %[r6].4s, v14.4s, %[x0a].s[2]\n" + "fmla %[r7].4s, v15.4s, %[x0a].s[2]\n" + + // Unroll 7 + "fmla %[r0].4s, v16.4s, %[x0a].s[3]\n" + "fmla %[r1].4s, v17.4s, %[x0a].s[3]\n" + "fmla %[r2].4s, v18.4s, %[x0a].s[3]\n" + "fmla %[r3].4s, v19.4s, %[x0a].s[3]\n" + "fmla %[r4].4s, v20.4s, %[x0a].s[3]\n" + "fmla %[r5].4s, v21.4s, %[x0a].s[3]\n" + "fmla %[r6].4s, v22.4s, %[x0a].s[3]\n" + "fmla %[r7].4s, v23.4s, %[x0a].s[3]\n" + : + [a_ptr] "+r"(a_ptr), [x_ptr] "+r"(x_ptr), + [x0] "+w"(x0), [x0a] "+w"(x0a), [k] "+r"(k), + [r0] "+w"(r0), [r1] "+w"(r1), [r2] "+w"(r2), [r3] "+w"(r3), + [r4] "+w"(r4), [r5] "+w"(r5), [r6] "+w"(r6), [r7] "+w"(r7) + : + : "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", + "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "x20", "x21", "cc", "memory"); + } + + // Deal with ragged M + if(M % 8) + { + int l = (M % 8) - 1; + + __asm __volatile( + "ldr q2, [%[a_ptr], #0]\n" + "ldr q3, [%[a_ptr], #16]\n" + "ldr q4, [%[a_ptr], #32]\n" + "ldr q5, [%[a_ptr], #48]\n" + "ldr q6, [%[a_ptr], #64]\n" + "ldr q7, [%[a_ptr], #80]\n" + "ldr q8, [%[a_ptr], #96]\n" + "ldr q9, [%[a_ptr], #112]\n" + "ldr %s[x0], [%[x_ptr]]\n" + "add %[a_ptr], %[a_ptr], #128\n" + "add %[x_ptr], %[x_ptr], #4\n" + + "cbz %w[l], 2f\n" + + "1:\n" + "fmla %[r0].4s, v2.4s, %[x0].s[0]\n" + "ldr q2, [%[a_ptr], #0]\n" + "subs %w[l], %w[l], #1\n" + "fmla %[r1].4s, v3.4s, %[x0].s[0]\n" + "ldr q3, [%[a_ptr], #16]\n" + "fmla %[r2].4s, v4.4s, %[x0].s[0]\n" + "ldr q4, [%[a_ptr], #32]\n" + "fmla %[r3].4s, v5.4s, %[x0].s[0]\n" + "ldr q5, [%[a_ptr], #48]\n" + "fmla %[r4].4s, v6.4s, %[x0].s[0]\n" + "ldr q6, [%[a_ptr], #64]\n" + "fmla %[r5].4s, v7.4s, %[x0].s[0]\n" + "ldr q7, [%[a_ptr], #80]\n" + "fmla %[r6].4s, v8.4s, %[x0].s[0]\n" + "ldr q8, [%[a_ptr], #96]\n" + "fmla %[r7].4s, v9.4s, %[x0].s[0]\n" + "ldr q9, [%[a_ptr], #112]\n" + "ldr %s[x0], [%[x_ptr]]\n" + "add %[a_ptr], %[a_ptr], #128\n" + "add %[x_ptr], %[x_ptr], #4\n" + "bne 1b\n" + + "2:\n" + + "fmla %[r0].4s, v2.4s, %[x0].s[0]\n" + "fmla %[r1].4s, v3.4s, %[x0].s[0]\n" + "fmla %[r2].4s, v4.4s, %[x0].s[0]\n" + "fmla %[r3].4s, v5.4s, %[x0].s[0]\n" + "fmla %[r4].4s, v6.4s, %[x0].s[0]\n" + "fmla %[r5].4s, v7.4s, %[x0].s[0]\n" + "fmla %[r6].4s, v8.4s, %[x0].s[0]\n" + "fmla %[r7].4s, v9.4s, %[x0].s[0]\n" + : + [a_ptr] "+r"(a_ptr), [x_ptr] "+r"(x_ptr), + [x0] "+w"(x0), [l] "+r"(l), + [r0] "+w"(r0), [r1] "+w"(r1), [r2] "+w"(r2), [r3] "+w"(r3), + [r4] "+w"(r4), [r5] "+w"(r5), [r6] "+w"(r6), [r7] "+w"(r7) + : + : "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "cc", "memory"); + } + + if(l == 32) + { + // Fast path + vst1q_f32(y_ptr, r0); + vst1q_f32(y_ptr + 4, r1); + vst1q_f32(y_ptr + 8, r2); + vst1q_f32(y_ptr + 12, r3); + vst1q_f32(y_ptr + 16, r4); + vst1q_f32(y_ptr + 20, r5); + vst1q_f32(y_ptr + 24, r6); + vst1q_f32(y_ptr + 28, r7); + } + else + { + int vecs = l / 4; + int oddbits = l % 4; + + if(oddbits) + { + // As above - slowest path deals with vectors plus odd bits + float32x4_t oddvec; + + do + { + if(vecs == 0) + { + oddvec = r0; + break; + } + + vst1q_f32(y_ptr, r0); + if(--vecs == 0) + { + oddvec = r1; + break; + } + + vst1q_f32(y_ptr + 4, r1); + if(--vecs == 0) + { + oddvec = r2; + break; + } + + vst1q_f32(y_ptr + 8, r2); + if(--vecs == 0) + { + oddvec = r3; + break; + } + + vst1q_f32(y_ptr + 12, r3); + if(--vecs == 0) + { + oddvec = r4; + break; + } + + vst1q_f32(y_ptr + 16, r4); + if(--vecs == 0) + { + oddvec = r5; + break; + } + + vst1q_f32(y_ptr + 20, r5); + if(--vecs == 0) + { + oddvec = r6; + break; + } + + vst1q_f32(y_ptr + 24, r6); + oddvec = r7; + } + while(0); + + float *oddbase = y_ptr + l - oddbits; + + switch(oddbits) + { + case 3: + vst1q_lane_f32(oddbase + 2, oddvec, 2); + // fall through + case 2: + vst1q_lane_f32(oddbase + 1, oddvec, 1); + // fall through + case 1: + vst1q_lane_f32(oddbase, oddvec, 0); + break; + + default: + // oddbits must be 1, 2 or 3. + UNREACHABLE("Impossible case in switch."); + } + } + else + { + // As above - medium path deals with vectors only + do + { + if(vecs == 0) + { + UNREACHABLE("vecs and oddbits can't both be 0"); + } + + vst1q_f32(y_ptr, r0); + if(--vecs == 0) + { + break; + } + + vst1q_f32(y_ptr + 4, r1); + if(--vecs == 0) + { + break; + } + + vst1q_f32(y_ptr + 8, r2); + if(--vecs == 0) + { + break; + } + + vst1q_f32(y_ptr + 12, r3); + if(--vecs == 0) + { + break; + } + + vst1q_f32(y_ptr + 16, r4); + if(--vecs == 0) + { + break; + } + + vst1q_f32(y_ptr + 20, r5); + if(--vecs == 0) + { + break; + } + + vst1q_f32(y_ptr + 24, r6); + } + while(0); + } + } + } +} + +} // namespace arm_gemm + +#endif // aarch64 diff --git a/src/core/NEON/kernels/arm_gemm/kernels/a64_sgemv_trans.hpp b/src/core/NEON/kernels/arm_gemm/kernels/a64_sgemv_trans.hpp new file mode 100644 index 0000000000..5b9bd72c89 --- /dev/null +++ b/src/core/NEON/kernels/arm_gemm/kernels/a64_sgemv_trans.hpp @@ -0,0 +1,55 @@ +/* + * Copyright (c) 2017-2018 ARM Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ +#pragma once + +#ifdef __aarch64__ + +namespace arm_gemm +{ +// Actual kernel implementations +void a64_sgemv_trans(const float *, const float *, float *, float, int, int, int); + +// Transposed SGEMV strategy class. +class sgemv_trans +{ +public: + typedef float operand_type; + typedef float result_type; + + typedef void (*kern_type)(const float *, const float *, float *, float, int, int, int); + + /* Kernel blocking parameters */ + static const int out_width = 96; + static const int k_unroll = 1; + + kern_type kernel = a64_sgemv_trans; + + sgemv_trans(const CPUInfo *ci) + { + } +}; + +} // namespace arm_gemm + +#endif // __aarch64__ diff --git a/src/core/NEON/kernels/arm_gemm/kernels/a64_sgemv_trans/generic.cpp b/src/core/NEON/kernels/arm_gemm/kernels/a64_sgemv_trans/generic.cpp new file mode 100644 index 0000000000..3309baff3a --- /dev/null +++ b/src/core/NEON/kernels/arm_gemm/kernels/a64_sgemv_trans/generic.cpp @@ -0,0 +1,913 @@ +/* + * Copyright (c) 2017-2018 ARM Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ +#ifdef __aarch64__ + +#include <cstddef> + +#include <arm_neon.h> + +#include "../../asmlib.hpp" +#include "../../utils.hpp" + +// Kernel implementation - transposed GEMV +// +// The kernel will process "M" rows of A (= steps of dot product) and "N" +// columns (= dot products total) +// +// General plan is to do as many columns simultaneously as possible - a +// reasonable limit is half the NEON regfile = 64 total accumulators. +// +// It's possible that messing around with sub-blocking M and N can yield +// higher performance, but that's left to the outer loop. In this kernel we +// process all of M at the same time. + +// How far ahead to prefetch for the first and subsequent prefetches. +// These values work for A72 on JunoR2... + +#define FIRST_PFD 9 +#define PFD 6 + +namespace arm_gemm +{ +void a64_sgemv_trans(const float *Astart, const float *Xstart, float *Ystart, float alpha, int lda, int M, int N) +{ + const float *a_ptr_base = Astart; + float *y_ptr = Ystart; + + register const float32x4_t va asm("v1") = vdupq_n_f32(alpha); + + int firstpfd = FIRST_PFD; + if(firstpfd > M) + { + firstpfd = (M - 1); + } + + int pfd = PFD; + if(pfd > M) + { + pfd = (M - 1); + } + + ptrdiff_t jump = lda * sizeof(int); + + for(; N >= 96; N -= 96) + { + int k = M - 1; + + const float *a_ptr = a_ptr_base; + const float *x_ptr = Xstart; + const float *pf_ptr = a_ptr; + const float *firstpf_ptr = a_ptr; + const float *pf_limit = a_ptr + (M * lda); + + for(int i = 0; i < firstpfd; i++) + { + prefetch_1x(firstpf_ptr); + firstpf_ptr += lda; + } + + for(int i = 0; i < pfd; i++) + { + prefetch_5x(pf_ptr + 16); + pf_ptr += lda; + } + + a_ptr_base += 96; + + __asm __volatile( + "movi v8.4s,#0x0\n" + "ldr w0, [%[x_ptr]]\n" + "movi v9.4s,#0x0\n" + "ldr q2, [%[a_ptr], #0]\n" + "movi v10.4s,#0x0\n" + "ldr q3, [%[a_ptr], #0x10]\n" + "movi v11.4s,#0x0\n" + "ldr q4, [%[a_ptr], #0x20]\n" + "movi v12.4s,#0x0\n" + "ldr q5, [%[a_ptr], #0x30]\n" + "movi v13.4s,#0x0\n" + "ldr q6, [%[a_ptr], #0x40]\n" + "movi v14.4s,#0x0\n" + "ldr q7, [%[a_ptr], #0x50]\n" + "movi v15.4s,#0x0\n" ASM_PREFETCH("[%[firstpf_ptr]]") + "movi v16.4s, #0x0\n" + "movi v17.4s, #0x0\n" ASM_PREFETCH("[%[pf_ptr], #64]") + "movi v18.4s, #0x0\n" + "movi v19.4s, #0x0\n" ASM_PREFETCH("[%[pf_ptr], #128]") + "movi v20.4s, #0x0\n" + "movi v21.4s, #0x0\n" ASM_PREFETCH("[%[pf_ptr], #192]") + "movi v22.4s, #0x0\n" + "movi v23.4s, #0x0\n" ASM_PREFETCH("[%[pf_ptr], #256]") + "movi v24.4s, #0x0\n" + "movi v25.4s, #0x0\n" ASM_PREFETCH("[%[pf_ptr], #320]") + "movi v26.4s, #0x0\n" + "movi v27.4s, #0x0\n" + "add %[pf_ptr], %[pf_ptr], %[jump]\n" + "movi v28.4s, #0x0\n" + "add %[firstpf_ptr], %[firstpf_ptr], %[jump]\n" + "movi v29.4s, #0x0\n" + "movi v30.4s, #0x0\n" + "movi v31.4s, #0x0\n" + + // Skip everything if there are no iterations of the main loop to do. + "cbz %w[k], 10f\n" + + // Loop with all prefetches. Exit this loop when firstpf_ptr + // hits pf_limit. + "1:\n" + "dup v0.4s, w0\n" + "ldr w0, [%[x_ptr], #4]\n" + "add %[x_ptr], %[x_ptr], #0x4\n" + "fmla v8.4s, v2.4s, v0.4s\n" + "ldr q2, [%[a_ptr], #0x60]\n" + "fmla v9.4s, v3.4s, v0.4s\n" + "ldr q3, [%[a_ptr], #0x70]\n" ASM_PREFETCH("[%[firstpf_ptr]]") + "fmla v10.4s, v4.4s, v0.4s\n" + "ldr q4, [%[a_ptr], #0x80]\n" + "add %[firstpf_ptr], %[firstpf_ptr], %[jump]\n" + "fmla v11.4s, v5.4s, v0.4s\n" + "ldr q5, [%[a_ptr], #0x90]\n" + "sub %w[k], %w[k], #1\n" ASM_PREFETCH("[%[x_ptr], #128]") + "fmla v12.4s, v6.4s, v0.4s\n" + "ldr q6, [%[a_ptr], #0xa0]\n" + "fmla v13.4s, v7.4s, v0.4s\n" + "ldr q7, [%[a_ptr], #0xb0]\n" ASM_PREFETCH("[%[pf_ptr], #0x40]") + "fmla v14.4s, v2.4s, v0.4s\n" + "ldr q2, [%[a_ptr], #0xc0]\n" + "fmla v15.4s, v3.4s, v0.4s\n" + "ldr q3, [%[a_ptr], #0xd0]\n" + "fmla v16.4s, v4.4s, v0.4s\n" + "ldr q4, [%[a_ptr], #0xe0]\n" + "fmla v17.4s, v5.4s, v0.4s\n" + "ldr q5, [%[a_ptr], #0xf0]\n" ASM_PREFETCH("[%[pf_ptr], #0x80]") + "fmla v18.4s, v6.4s, v0.4s\n" + "ldr q6, [%[a_ptr], #0x100]\n" + "fmla v19.4s, v7.4s, v0.4s\n" + "ldr q7, [%[a_ptr], #0x110]\n" + "fmla v20.4s, v2.4s, v0.4s\n" + "ldr q2, [%[a_ptr], #0x120]\n" + "fmla v21.4s, v3.4s, v0.4s\n" + "ldr q3, [%[a_ptr], #0x130]\n" ASM_PREFETCH("[%[pf_ptr], #0xc0]") + "fmla v22.4s, v4.4s, v0.4s\n" + "ldr q4, [%[a_ptr], #0x140]\n" + "fmla v23.4s, v5.4s, v0.4s\n" + "ldr q5, [%[a_ptr], #0x150]\n" + "fmla v24.4s, v6.4s, v0.4s\n" + "ldr q6, [%[a_ptr], #0x160]\n" + "fmla v25.4s, v7.4s, v0.4s\n" + "ldr q7, [%[a_ptr], #0x170]\n" ASM_PREFETCH("[%[pf_ptr], #0x100]") + "add %[a_ptr], %[a_ptr], %[jump]\n" + "fmla v26.4s, v2.4s, v0.4s\n" + "ldr q2, [%[a_ptr], #0x00]\n" + "fmla v27.4s, v3.4s, v0.4s\n" + "ldr q3, [%[a_ptr], #0x10]\n" + "fmla v28.4s, v4.4s, v0.4s\n" + "ldr q4, [%[a_ptr], #0x20]\n" + "fmla v29.4s, v5.4s, v0.4s\n" + "ldr q5, [%[a_ptr], #0x30]\n" ASM_PREFETCH("[%[pf_ptr], #0x140]") + "fmla v30.4s, v6.4s, v0.4s\n" + "add %[pf_ptr], %[pf_ptr], %[jump]\n" + "ldr q6, [%[a_ptr], #0x40]\n" + "fmla v31.4s, v7.4s, v0.4s\n" + "cmp %[firstpf_ptr], %[pf_limit]\n" + "ldr q7, [%[a_ptr], #0x50]\n" + "blt 1b\n" + + // Check that there are still "main" prefetches to do. + "cmp %[pf_ptr], %[pf_limit]\n" + "bge 9f\n" + + // Just the main prefetches, exit this loop when pf_ptr hits pf_limit. + "8:\n" + "dup v0.4s, w0\n" + "ldr w0, [%[x_ptr], #4]\n" + "add %[x_ptr], %[x_ptr], #0x4\n" + "fmla v8.4s, v2.4s, v0.4s\n" + "ldr q2, [%[a_ptr], #0x60]\n" + "fmla v9.4s, v3.4s, v0.4s\n" + "ldr q3, [%[a_ptr], #0x70]\n" + "fmla v10.4s, v4.4s, v0.4s\n" + "ldr q4, [%[a_ptr], #0x80]\n" + "fmla v11.4s, v5.4s, v0.4s\n" + "ldr q5, [%[a_ptr], #0x90]\n" + "sub %w[k], %w[k], #1\n" ASM_PREFETCH("[%[x_ptr], #128]") + "fmla v12.4s, v6.4s, v0.4s\n" + "ldr q6, [%[a_ptr], #0xa0]\n" + "fmla v13.4s, v7.4s, v0.4s\n" + "ldr q7, [%[a_ptr], #0xb0]\n" ASM_PREFETCH("[%[pf_ptr], #0x40]") + "fmla v14.4s, v2.4s, v0.4s\n" + "ldr q2, [%[a_ptr], #0xc0]\n" + "fmla v15.4s, v3.4s, v0.4s\n" + "ldr q3, [%[a_ptr], #0xd0]\n" + "fmla v16.4s, v4.4s, v0.4s\n" + "ldr q4, [%[a_ptr], #0xe0]\n" + "fmla v17.4s, v5.4s, v0.4s\n" + "ldr q5, [%[a_ptr], #0xf0]\n" ASM_PREFETCH("[%[pf_ptr], #0x80]") + "fmla v18.4s, v6.4s, v0.4s\n" + "ldr q6, [%[a_ptr], #0x100]\n" + "fmla v19.4s, v7.4s, v0.4s\n" + "ldr q7, [%[a_ptr], #0x110]\n" + "fmla v20.4s, v2.4s, v0.4s\n" + "ldr q2, [%[a_ptr], #0x120]\n" + "fmla v21.4s, v3.4s, v0.4s\n" + "ldr q3, [%[a_ptr], #0x130]\n" ASM_PREFETCH("[%[pf_ptr], #0xc0]") + "fmla v22.4s, v4.4s, v0.4s\n" + "ldr q4, [%[a_ptr], #0x140]\n" + "fmla v23.4s, v5.4s, v0.4s\n" + "ldr q5, [%[a_ptr], #0x150]\n" + "fmla v24.4s, v6.4s, v0.4s\n" + "ldr q6, [%[a_ptr], #0x160]\n" + "fmla v25.4s, v7.4s, v0.4s\n" + "ldr q7, [%[a_ptr], #0x170]\n" ASM_PREFETCH("[%[pf_ptr], #0x100]") + "add %[a_ptr], %[a_ptr], %[jump]\n" + "fmla v26.4s, v2.4s, v0.4s\n" + "ldr q2, [%[a_ptr], #0x00]\n" + "fmla v27.4s, v3.4s, v0.4s\n" + "ldr q3, [%[a_ptr], #0x10]\n" + "fmla v28.4s, v4.4s, v0.4s\n" + "ldr q4, [%[a_ptr], #0x20]\n" + "fmla v29.4s, v5.4s, v0.4s\n" + "ldr q5, [%[a_ptr], #0x30]\n" ASM_PREFETCH("[%[pf_ptr], #0x140]") + "fmla v30.4s, v6.4s, v0.4s\n" + "add %[pf_ptr], %[pf_ptr], %[jump]\n" + "ldr q6, [%[a_ptr], #0x40]\n" + "fmla v31.4s, v7.4s, v0.4s\n" + "cmp %[pf_ptr], %[pf_limit]\n" + "ldr q7, [%[a_ptr], #0x50]\n" + "blt 8b\n" + + // Check that there is still work to do. + "9:\n" + "cmp %w[k], #0\n" + "beq 10f\n" + + // Loop without prefetches, exit when k hits 0. + "2:\n" + "dup v0.4s, w0\n" + "ldr w0, [%[x_ptr], #4]\n" + "add %[x_ptr], %[x_ptr], #0x4\n" + "fmla v8.4s, v2.4s, v0.4s\n" + "ldr q2, [%[a_ptr], #0x60]\n" + "fmla v9.4s, v3.4s, v0.4s\n" + "ldr q3, [%[a_ptr], #0x70]\n" + "fmla v10.4s, v4.4s, v0.4s\n" + "ldr q4, [%[a_ptr], #0x80]\n" + "fmla v11.4s, v5.4s, v0.4s\n" + "ldr q5, [%[a_ptr], #0x90]\n" + "subs %w[k], %w[k], #1\n" + "fmla v12.4s, v6.4s, v0.4s\n" + "ldr q6, [%[a_ptr], #0xa0]\n" + "fmla v13.4s, v7.4s, v0.4s\n" + "ldr q7, [%[a_ptr], #0xb0]\n" + "fmla v14.4s, v2.4s, v0.4s\n" + "ldr q2, [%[a_ptr], #0xc0]\n" + "fmla v15.4s, v3.4s, v0.4s\n" + "ldr q3, [%[a_ptr], #0xd0]\n" + "fmla v16.4s, v4.4s, v0.4s\n" + "ldr q4, [%[a_ptr], #0xe0]\n" + "fmla v17.4s, v5.4s, v0.4s\n" + "ldr q5, [%[a_ptr], #0xf0]\n" + "fmla v18.4s, v6.4s, v0.4s\n" + "ldr q6, [%[a_ptr], #0x100]\n" + "fmla v19.4s, v7.4s, v0.4s\n" + "ldr q7, [%[a_ptr], #0x110]\n" + "fmla v20.4s, v2.4s, v0.4s\n" + "ldr q2, [%[a_ptr], #0x120]\n" + "fmla v21.4s, v3.4s, v0.4s\n" + "ldr q3, [%[a_ptr], #0x130]\n" + "fmla v22.4s, v4.4s, v0.4s\n" + "ldr q4, [%[a_ptr], #0x140]\n" + "fmla v23.4s, v5.4s, v0.4s\n" + "ldr q5, [%[a_ptr], #0x150]\n" + "fmla v24.4s, v6.4s, v0.4s\n" + "ldr q6, [%[a_ptr], #0x160]\n" + "fmla v25.4s, v7.4s, v0.4s\n" + "ldr q7, [%[a_ptr], #0x170]\n" + "add %[a_ptr], %[a_ptr], %[jump]\n" + "fmla v26.4s, v2.4s, v0.4s\n" + "ldr q2, [%[a_ptr], #0x00]\n" + "fmla v27.4s, v3.4s, v0.4s\n" + "ldr q3, [%[a_ptr], #0x10]\n" + "fmla v28.4s, v4.4s, v0.4s\n" + "ldr q4, [%[a_ptr], #0x20]\n" + "fmla v29.4s, v5.4s, v0.4s\n" + "ldr q5, [%[a_ptr], #0x30]\n" + "fmla v30.4s, v6.4s, v0.4s\n" + "ldr q6, [%[a_ptr], #0x40]\n" + "fmla v31.4s, v7.4s, v0.4s\n" + "ldr q7, [%[a_ptr], #0x50]\n" + "bne 2b\n" + + "10:\n" + + // Final iteration + "dup v0.4s, w0\n" + "fmla v8.4s, v2.4s, v0.4s\n" + "ldr q2, [%[a_ptr], #0x60]\n" + "fmla v9.4s, v3.4s, v0.4s\n" + "ldr q3, [%[a_ptr], #0x70]\n" + "fmla v10.4s, v4.4s, v0.4s\n" + "ldr q4, [%[a_ptr], #0x80]\n" + "fmla v11.4s, v5.4s, v0.4s\n" + "ldr q5, [%[a_ptr], #0x90]\n" + "fmla v12.4s, v6.4s, v0.4s\n" + "ldr q6, [%[a_ptr], #0xa0]\n" + "fmla v13.4s, v7.4s, v0.4s\n" + "ldr q7, [%[a_ptr], #0xb0]\n" + "fmla v14.4s, v2.4s, v0.4s\n" + "ldr q2, [%[a_ptr], #0xc0]\n" + "fmla v15.4s, v3.4s, v0.4s\n" + "ldr q3, [%[a_ptr], #0xd0]\n" + "fmla v16.4s, v4.4s, v0.4s\n" + "ldr q4, [%[a_ptr], #0xe0]\n" + "fmla v17.4s, v5.4s, v0.4s\n" + "ldr q5, [%[a_ptr], #0xf0]\n" + "fmla v18.4s, v6.4s, v0.4s\n" + "ldr q6, [%[a_ptr], #0x100]\n" + "fmla v19.4s, v7.4s, v0.4s\n" + "ldr q7, [%[a_ptr], #0x110]\n" + "fmla v20.4s, v2.4s, v0.4s\n" + "ldr q2, [%[a_ptr], #0x120]\n" + "fmla v21.4s, v3.4s, v0.4s\n" + "ldr q3, [%[a_ptr], #0x130]\n" + "fmla v22.4s, v4.4s, v0.4s\n" + "ldr q4, [%[a_ptr], #0x140]\n" + "fmla v23.4s, v5.4s, v0.4s\n" + "ldr q5, [%[a_ptr], #0x150]\n" + "fmla v24.4s, v6.4s, v0.4s\n" + "ldr q6, [%[a_ptr], #0x160]\n" + "fmla v25.4s, v7.4s, v0.4s\n" + "ldr q7, [%[a_ptr], #0x170]\n" + "fmla v26.4s, v2.4s, v0.4s\n" + "ldr q2, [%[y_ptr]]\n" + "fmla v27.4s, v3.4s, v0.4s\n" + "ldr q3, [%[y_ptr], #0x10]\n" + "fmla v28.4s, v4.4s, v0.4s\n" + "ldr q4, [%[y_ptr], #0x20]\n" + "fmla v29.4s, v5.4s, v0.4s\n" + "ldr q5, [%[y_ptr], #0x30]\n" + "fmla v30.4s, v6.4s, v0.4s\n" + "ldr q6, [%[y_ptr], #0x40]\n" + "fmla v31.4s, v7.4s, v0.4s\n" + "ldr q7, [%[y_ptr], #0x50]\n" + + "fmla v2.4s, v8.4s, %[va].4s\n" + "ldr q8, [%[y_ptr], #0x60]\n" + "fmla v3.4s, v9.4s, %[va].4s\n" + "ldr q9, [%[y_ptr], #0x70]\n" + "fmla v4.4s, v10.4s, %[va].4s\n" + "ldr q10, [%[y_ptr], #0x80]\n" + "fmla v5.4s, v11.4s, %[va].4s\n" + "ldr q11, [%[y_ptr], #0x90]\n" + "fmla v6.4s, v12.4s, %[va].4s\n" + "ldr q12, [%[y_ptr], #0xa0]\n" + "str q2, [%[y_ptr], #0x00]\n" + "fmla v7.4s, v13.4s, %[va].4s\n" + "ldr q13, [%[y_ptr], #0xb0]\n" + "str q3, [%[y_ptr], #0x10]\n" + "fmla v8.4s, v14.4s, %[va].4s\n" + "ldr q14, [%[y_ptr], #0xc0]\n" + "str q4, [%[y_ptr], #0x20]\n" + "fmla v9.4s, v15.4s, %[va].4s\n" + "ldr q15, [%[y_ptr], #0xd0]\n" + "str q5, [%[y_ptr], #0x30]\n" + "fmla v10.4s, v16.4s, %[va].4s\n" + "ldr q16, [%[y_ptr], #0xe0]\n" + "str q6, [%[y_ptr], #0x40]\n" + "fmla v11.4s, v17.4s, %[va].4s\n" + "ldr q17, [%[y_ptr], #0xf0]\n" + "str q7, [%[y_ptr], #0x50]\n" + "fmla v12.4s, v18.4s, %[va].4s\n" + "ldr q18, [%[y_ptr], #0x100]\n" + "str q8, [%[y_ptr], #0x60]\n" + "fmla v13.4s, v19.4s, %[va].4s\n" + "ldr q19, [%[y_ptr], #0x110]\n" + "str q9, [%[y_ptr], #0x70]\n" + "fmla v14.4s, v20.4s, %[va].4s\n" + "ldr q20, [%[y_ptr], #0x120]\n" + "str q10, [%[y_ptr], #0x80]\n" + "fmla v15.4s, v21.4s, %[va].4s\n" + "ldr q21, [%[y_ptr], #0x130]\n" + "str q11, [%[y_ptr], #0x90]\n" + "fmla v16.4s, v22.4s, %[va].4s\n" + "ldr q22, [%[y_ptr], #0x140]\n" + "str q12, [%[y_ptr], #0xa0]\n" + "fmla v17.4s, v23.4s, %[va].4s\n" + "ldr q23, [%[y_ptr], #0x150]\n" + "str q13, [%[y_ptr], #0xb0]\n" + "fmla v18.4s, v24.4s, %[va].4s\n" + "ldr q24, [%[y_ptr], #0x160]\n" + "str q14, [%[y_ptr], #0xc0]\n" + "fmla v19.4s, v25.4s, %[va].4s\n" + "ldr q25, [%[y_ptr], #0x170]\n" + "str q15, [%[y_ptr], #0xd0]\n" + "fmla v20.4s, v26.4s, %[va].4s\n" + "str q16, [%[y_ptr], #0xe0]\n" + "fmla v21.4s, v27.4s, %[va].4s\n" + "str q17, [%[y_ptr], #0xf0]\n" + "fmla v22.4s, v28.4s, %[va].4s\n" + "str q18, [%[y_ptr], #0x100]\n" + "fmla v23.4s, v29.4s, %[va].4s\n" + "str q19, [%[y_ptr], #0x110]\n" + "fmla v24.4s, v30.4s, %[va].4s\n" + "str q20, [%[y_ptr], #0x120]\n" + "fmla v25.4s, v31.4s, %[va].4s\n" + "str q21, [%[y_ptr], #0x130]\n" + + "stp q22, q23, [%[y_ptr], #0x140]\n" + "stp q24, q25, [%[y_ptr], #0x160]\n" + "add %[y_ptr], %[y_ptr], #0x180\n" + + : [a_ptr] "+r"(a_ptr), [x_ptr] "+r"(x_ptr), [y_ptr] "+r"(y_ptr), [k] "+r"(k), [pf_ptr] "+r"(pf_ptr), [firstpf_ptr] "+r"(firstpf_ptr) + : [jump] "r"(jump), [va] "w"(va), [pf_limit] "r"(pf_limit) + : "w0", "v0", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", + "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", + "v27", "v28", "v29", "v30", "v31", "cc"); + } + + if(N > 0) + { + // Handle N tail - up to 95 stragglers. + // This is 0-23 vectors, plus optionally an 64-bit vector and/or a + // single value for the remainder. + + // Independent pointers into the matrix for the odd 2 and odd 1. + // Double up as flag to indicate whether they are needed. + const float *odd2_aptr = NULL; + const float *odd1_aptr = NULL; + + // Figure out how much work we need to do. + int numvecs = N / 4; + int rem = N % 4; + int k = M; + + // Set up pointers for the odd 2/1 if needed. + if(rem >= 2) + { + odd2_aptr = a_ptr_base + (numvecs * 4); + } + + if(rem & 1) + { + odd1_aptr = a_ptr_base + (numvecs * 4) + (odd2_aptr == NULL ? 0 : 2); + } + + const float *a_ptr = a_ptr_base; + const float *firstpf_ptr = a_ptr_base; + const float *pf_ptr = a_ptr_base; + const float *pf_limit = a_ptr + (M * lda); + + const float *x_ptr = Xstart; + int vecs = 0; // Working variable to count how many vectors to work on. + int dopf = 1; // Track whether we are doing prefetches. + + // Figure out how many cache lines we need to prefetch each time. + int numpfs = (N + 15) / 16; + + // Do initial prefetches + for(int i = 0; i < firstpfd + 1; i++) + { + prefetch_1x(firstpf_ptr); + firstpf_ptr += lda; + } + + // Do "main" prefetches - adapt number to the number we actually need. + if(numpfs > 1) + { + for(int i = 0; i < pfd + 1; i++) + { + switch(numpfs) + { + case 2: + prefetch_1x(pf_ptr + 16); + break; + + case 3: + prefetch_2x(pf_ptr + 16); + break; + + case 4: + prefetch_3x(pf_ptr + 16); + break; + + case 5: + prefetch_4x(pf_ptr + 16); + break; + + case 6: + prefetch_5x(pf_ptr + 16); + break; + + default: + UNREACHABLE("Impossible."); + } + pf_ptr += lda; + } + } + else + { + // Just disable additional prefetches + dopf = 0; + } + + // Do the real work + __asm __volatile( + // Initialize all the vectors - not worth skipping this if only + // some are needed. + "movi v8.4s,#0x0\n" + "ldr w0, [%[x_ptr]]\n" + "movi v9.4s,#0x0\n" + "movi v10.4s,#0x0\n" + "movi v11.4s,#0x0\n" + "movi v12.4s,#0x0\n" + "movi v13.4s,#0x0\n" + "movi v14.4s,#0x0\n" + "movi v15.4s,#0x0\n" + "movi v16.4s, #0x0\n" + "movi v17.4s, #0x0\n" + "movi v18.4s, #0x0\n" + "movi v19.4s, #0x0\n" + "movi v20.4s, #0x0\n" + "movi v21.4s, #0x0\n" + "movi v22.4s, #0x0\n" + "movi v23.4s, #0x0\n" + "movi v24.4s, #0x0\n" + "movi v25.4s, #0x0\n" + "movi v26.4s, #0x0\n" + "movi v27.4s, #0x0\n" + "movi v28.4s, #0x0\n" + "movi v29.4s, #0x0\n" + "movi v30.4s, #0x0\n" + "movi v6.2s, #0x0\n" + "movi v5.2s, #0x0\n" + + "1:\n" ASM_PREFETCH("[%[firstpf_ptr]]\n") + "11:\n" + "dup v0.4s, w0\n" + "ldr w0, [%[x_ptr], #4]\n" + "add %[x_ptr], %[x_ptr], #4\n" + + "cbz %w[numvecs], 2f\n" + "mov %w[vecs], %w[numvecs]\n" + + // Vector 0 + "subs %w[vecs], %w[vecs], #1\n" + "ldr q7,[%[a_ptr], #0x00]\n" + "fmla v8.4s, v7.4s, v0.4s\n" + "beq 2f\n" + // Vector 1 + "subs %w[vecs], %w[vecs], #1\n" + "ldr q7,[%[a_ptr], #0x10]\n" + "fmla v9.4s, v7.4s, v0.4s\n" + "beq 2f\n" + // Vector 2 + "subs %w[vecs], %w[vecs], #1\n" + "ldr q7,[%[a_ptr], #0x20]\n" + "fmla v10.4s, v7.4s, v0.4s\n" + "beq 2f\n" + // Vector 3 + "subs %w[vecs], %w[vecs], #1\n" + "ldr q7,[%[a_ptr], #0x30]\n" + "fmla v11.4s, v7.4s, v0.4s\n" + // Prefetch + "cbz %w[dopf], 3f\n" ASM_PREFETCH("[%[pf_ptr], #0x40]") + "3:\n" + "beq 2f\n" + + // Vector 4 + "subs %w[vecs], %w[vecs], #1\n" + "ldr q7,[%[a_ptr], #0x40]\n" + "fmla v12.4s, v7.4s, v0.4s\n" + "beq 2f\n" + // Vector 5 + "subs %w[vecs], %w[vecs], #1\n" + "ldr q7,[%[a_ptr], #0x50]\n" + "fmla v13.4s, v7.4s, v0.4s\n" + "beq 2f\n" + // Vector 6 + "subs %w[vecs], %w[vecs], #1\n" + "ldr q7,[%[a_ptr], #0x60]\n" + "fmla v14.4s, v7.4s, v0.4s\n" + "beq 2f\n" + // Vector 7 + "subs %w[vecs], %w[vecs], #1\n" + "ldr q7,[%[a_ptr], #0x70]\n" + "fmla v15.4s, v7.4s, v0.4s\n" + // Prefetch + "cbz %w[dopf], 4f\n" ASM_PREFETCH("[%[pf_ptr], #0x80]") + "4:\n" + "beq 2f\n" + + // Vector 8 + "subs %w[vecs], %w[vecs], #1\n" + "ldr q7,[%[a_ptr], #0x80]\n" + "fmla v16.4s, v7.4s, v0.4s\n" + "beq 2f\n" + // Vector 9 + "subs %w[vecs], %w[vecs], #1\n" + "ldr q7,[%[a_ptr], #0x90]\n" + "fmla v17.4s, v7.4s, v0.4s\n" + "beq 2f\n" + // Vector 10 + "subs %w[vecs], %w[vecs], #1\n" + "ldr q7,[%[a_ptr], #0xa0]\n" + "fmla v18.4s, v7.4s, v0.4s\n" + "beq 2f\n" + // Vector 11 + "subs %w[vecs], %w[vecs], #1\n" + "ldr q7,[%[a_ptr], #0xb0]\n" + "fmla v19.4s, v7.4s, v0.4s\n" + // Prefetch + "cbz %w[dopf], 5f\n" ASM_PREFETCH("[%[pf_ptr], #0xc0]") + "5:\n" + "beq 2f\n" + + // Vector 12 + "subs %w[vecs], %w[vecs], #1\n" + "ldr q7,[%[a_ptr], #0xc0]\n" + "fmla v20.4s, v7.4s, v0.4s\n" + "beq 2f\n" + // Vector 13 + "subs %w[vecs], %w[vecs], #1\n" + "ldr q7,[%[a_ptr], #0xd0]\n" + "fmla v21.4s, v7.4s, v0.4s\n" + "beq 2f\n" + // Vector 14 + "subs %w[vecs], %w[vecs], #1\n" + "ldr q7,[%[a_ptr], #0xe0]\n" + "fmla v22.4s, v7.4s, v0.4s\n" + "beq 2f\n" + // Vector 15 + "subs %w[vecs], %w[vecs], #1\n" + "ldr q7,[%[a_ptr], #0xf0]\n" + "fmla v23.4s, v7.4s, v0.4s\n" + // Prefetch + "cbz %w[dopf], 6f\n" ASM_PREFETCH("[%[pf_ptr], #0x100]") + "6:\n" + "beq 2f\n" + + // Vector 16 + "subs %w[vecs], %w[vecs], #1\n" + "ldr q7,[%[a_ptr], #0x100]\n" + "fmla v24.4s, v7.4s, v0.4s\n" + "beq 2f\n" + // Vector 17 + "subs %w[vecs], %w[vecs], #1\n" + "ldr q7,[%[a_ptr], #0x110]\n" + "fmla v25.4s, v7.4s, v0.4s\n" + "beq 2f\n" + // Vector 18 + "subs %w[vecs], %w[vecs], #1\n" + "ldr q7,[%[a_ptr], #0x120]\n" + "fmla v26.4s, v7.4s, v0.4s\n" + "beq 2f\n" + // Vector 19 + "subs %w[vecs], %w[vecs], #1\n" + "ldr q7,[%[a_ptr], #0x130]\n" + "fmla v27.4s, v7.4s, v0.4s\n" + // Prefetch + "cbz %w[dopf], 7f\n" ASM_PREFETCH("[%[pf_ptr], #0x140]") + "7:\n" + "beq 2f\n" + + // Vector 20 + "subs %w[vecs], %w[vecs], #1\n" + "ldr q7,[%[a_ptr], #0x140]\n" + "fmla v28.4s, v7.4s, v0.4s\n" + "beq 2f\n" + // Vector 21 + "subs %w[vecs], %w[vecs], #1\n" + "ldr q7,[%[a_ptr], #0x150]\n" + "fmla v29.4s, v7.4s, v0.4s\n" + "beq 2f\n" + // Vector 22 + "subs %w[vecs], %w[vecs], #1\n" + "ldr q7,[%[a_ptr], #0x160]\n" + "fmla v30.4s, v7.4s, v0.4s\n" + + "2:\n" + "add %[a_ptr], %[a_ptr], %[jump]\n" + + // Do the odd 2-vector, if needed + "cbz %[odd2_aptr], 8f\n" + "ldr d7, [%[odd2_aptr]]\n" + "fmla v6.2s, v7.2s, v0.2s\n" + "add %[odd2_aptr], %[odd2_aptr], %[jump]\n" + + "8:\n" + // Do the odd 1-vector, if needed + "cbz %[odd1_aptr], 9f\n" + "ldr s7, [%[odd1_aptr]]\n" + "fmla v5.2s, v7.2s, v0.2s\n" + "add %[odd1_aptr], %[odd1_aptr], %[jump]\n" + + // Get out if needed. + "9:\n" + "subs %w[k], %w[k], #1\n" + "beq 10f\n" + + // Update the "main" prefetch pointer, if it strays beyond the limit turn off "dopf" + "add %[pf_ptr], %[pf_ptr], %[jump]\n" + "cmp %[pf_ptr], %[pf_limit]\n" + "csel %w[dopf], %w[dopf], WZR, LT\n" + + // Update the "leading" prefetch pointer, don't do the first + // instruction of the loop if it's over the limit. + "add %[firstpf_ptr], %[firstpf_ptr], %[jump]\n" + "cmp %[firstpf_ptr], %[pf_limit]\n" + "blt 1b\n" + "b 11b\n" + + // Now write out the outputs + "10:\n" + "cbz %w[numvecs], 12f\n" + "mov %w[vecs], %w[numvecs]\n" + + // Vector 0 + "subs %w[vecs], %w[vecs], #1\n" + "ldr q7, [%[y_ptr]]\n" + "fmla v7.4s, v8.4s, %[va].4s\n" + "str q7, [%[y_ptr]], #0x10\n" + "beq 12f\n" + // Vector 1 + "subs %w[vecs], %w[vecs], #1\n" + "ldr q7, [%[y_ptr]]\n" + "fmla v7.4s, v9.4s, %[va].4s\n" + "str q7, [%[y_ptr]], #0x10\n" + "beq 12f\n" + // Vector 2 + "subs %w[vecs], %w[vecs], #1\n" + "ldr q7, [%[y_ptr]]\n" + "fmla v7.4s, v10.4s, %[va].4s\n" + "str q7, [%[y_ptr]], #0x10\n" + "beq 12f\n" + // Vector 3 + "subs %w[vecs], %w[vecs], #1\n" + "ldr q7, [%[y_ptr]]\n" + "fmla v7.4s, v11.4s, %[va].4s\n" + "str q7, [%[y_ptr]], #0x10\n" + "beq 12f\n" + // Vector 4 + "subs %w[vecs], %w[vecs], #1\n" + "ldr q7, [%[y_ptr]]\n" + "fmla v7.4s, v12.4s, %[va].4s\n" + "str q7, [%[y_ptr]], #0x10\n" + "beq 12f\n" + // Vector 5 + "subs %w[vecs], %w[vecs], #1\n" + "ldr q7, [%[y_ptr]]\n" + "fmla v7.4s, v13.4s, %[va].4s\n" + "str q7, [%[y_ptr]], #0x10\n" + "beq 12f\n" + // Vector 6 + "subs %w[vecs], %w[vecs], #1\n" + "ldr q7, [%[y_ptr]]\n" + "fmla v7.4s, v14.4s, %[va].4s\n" + "str q7, [%[y_ptr]], #0x10\n" + "beq 12f\n" + // Vector 7 + "subs %w[vecs], %w[vecs], #1\n" + "ldr q7, [%[y_ptr]]\n" + "fmla v7.4s, v15.4s, %[va].4s\n" + "str q7, [%[y_ptr]], #0x10\n" + "beq 12f\n" + // Vector 8 + "subs %w[vecs], %w[vecs], #1\n" + "ldr q7, [%[y_ptr]]\n" + "fmla v7.4s, v16.4s, %[va].4s\n" + "str q7, [%[y_ptr]], #0x10\n" + "beq 12f\n" + // Vector 9 + "subs %w[vecs], %w[vecs], #1\n" + "ldr q7, [%[y_ptr]]\n" + "fmla v7.4s, v17.4s, %[va].4s\n" + "str q7, [%[y_ptr]], #0x10\n" + "beq 12f\n" + // Vector 10 + "subs %w[vecs], %w[vecs], #1\n" + "ldr q7, [%[y_ptr]]\n" + "fmla v7.4s, v18.4s, %[va].4s\n" + "str q7, [%[y_ptr]], #0x10\n" + "beq 12f\n" + // Vector 11 + "subs %w[vecs], %w[vecs], #1\n" + "ldr q7, [%[y_ptr]]\n" + "fmla v7.4s, v19.4s, %[va].4s\n" + "str q7, [%[y_ptr]], #0x10\n" + "beq 12f\n" + // Vector 12 + "subs %w[vecs], %w[vecs], #1\n" + "ldr q7, [%[y_ptr]]\n" + "fmla v7.4s, v20.4s, %[va].4s\n" + "str q7, [%[y_ptr]], #0x10\n" + "beq 12f\n" + // Vector 13 + "subs %w[vecs], %w[vecs], #1\n" + "ldr q7, [%[y_ptr]]\n" + "fmla v7.4s, v21.4s, %[va].4s\n" + "str q7, [%[y_ptr]], #0x10\n" + "beq 12f\n" + // Vector 14 + "subs %w[vecs], %w[vecs], #1\n" + "ldr q7, [%[y_ptr]]\n" + "fmla v7.4s, v22.4s, %[va].4s\n" + "str q7, [%[y_ptr]], #0x10\n" + "beq 12f\n" + // Vector 15 + "subs %w[vecs], %w[vecs], #1\n" + "ldr q7, [%[y_ptr]]\n" + "fmla v7.4s, v23.4s, %[va].4s\n" + "str q7, [%[y_ptr]], #0x10\n" + "beq 12f\n" + // Vector 16 + "subs %w[vecs], %w[vecs], #1\n" + "ldr q7, [%[y_ptr]]\n" + "fmla v7.4s, v24.4s, %[va].4s\n" + "str q7, [%[y_ptr]], #0x10\n" + "beq 12f\n" + // Vector 17 + "subs %w[vecs], %w[vecs], #1\n" + "ldr q7, [%[y_ptr]]\n" + "fmla v7.4s, v25.4s, %[va].4s\n" + "str q7, [%[y_ptr]], #0x10\n" + "beq 12f\n" + // Vector 18 + "subs %w[vecs], %w[vecs], #1\n" + "ldr q7, [%[y_ptr]]\n" + "fmla v7.4s, v26.4s, %[va].4s\n" + "str q7, [%[y_ptr]], #0x10\n" + "beq 12f\n" + // Vector 19 + "subs %w[vecs], %w[vecs], #1\n" + "ldr q7, [%[y_ptr]]\n" + "fmla v7.4s, v27.4s, %[va].4s\n" + "str q7, [%[y_ptr]], #0x10\n" + "beq 12f\n" + // Vector 20 + "subs %w[vecs], %w[vecs], #1\n" + "ldr q7, [%[y_ptr]]\n" + "fmla v7.4s, v28.4s, %[va].4s\n" + "str q7, [%[y_ptr]], #0x10\n" + "beq 12f\n" + // Vector 21 + "subs %w[vecs], %w[vecs], #1\n" + "ldr q7, [%[y_ptr]]\n" + "fmla v7.4s, v29.4s, %[va].4s\n" + "str q7, [%[y_ptr]], #0x10\n" + "beq 12f\n" + // Vector 22 + "subs %w[vecs], %w[vecs], #1\n" + "ldr q7, [%[y_ptr]]\n" + "fmla v7.4s, v30.4s, %[va].4s\n" + "str q7, [%[y_ptr]], #0x10\n" + + // Odd 2 + "12:\n" + "cbz %[odd2_aptr], 13f\n" + "ldr d7, [%[y_ptr]]\n" + "fmla v7.2s, v6.2s, %[va].2s\n" + "str d7, [%[y_ptr]], #0x8\n" + + // Odd 1 + "13:\n" + "cbz %[odd1_aptr], 14f\n" + "ldr s7, [%[y_ptr]]\n" + "fmla v7.2s, v5.2s, %[va].2s\n" + "str s7, [%[y_ptr]]\n" + + "14:\n" + : [a_ptr] "+r"(a_ptr), [x_ptr] "+r"(x_ptr), [y_ptr] "+r"(y_ptr), [k] "+r"(k), + [pf_ptr] "+r"(pf_ptr), [firstpf_ptr] "+r"(firstpf_ptr), + [odd1_aptr] "+r"(odd1_aptr), [odd2_aptr] "+r"(odd2_aptr), + [dopf] "+r"(dopf), [vecs] "+r"(vecs) + : [jump] "r"(jump), [va] "w"(va), [pf_limit] "r"(pf_limit), [numvecs] "r"(numvecs) + : "w0", "v0", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", + "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", + "v27", "v28", "v29", "v30", "v31", "cc"); + } +} + +} // namespace arm_gemm + +#endif // __aarch64__ |