diff options
Diffstat (limited to 'src/core/NEON/kernels/arm_gemm/transforms')
85 files changed, 18955 insertions, 5291 deletions
diff --git a/src/core/NEON/kernels/arm_gemm/transforms/a32_interleave_6way_32bit.hpp b/src/core/NEON/kernels/arm_gemm/transforms/a32_interleave_6way_32bit.hpp deleted file mode 100644 index 543664bb0e..0000000000 --- a/src/core/NEON/kernels/arm_gemm/transforms/a32_interleave_6way_32bit.hpp +++ /dev/null @@ -1,168 +0,0 @@ -/* - * Copyright (c) 2017-2019 ARM Limited. - * - * SPDX-License-Identifier: MIT - * - * Permission is hereby granted, free of charge, to any person obtaining a copy - * of this software and associated documentation files (the "Software"), to - * deal in the Software without restriction, including without limitation the - * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or - * sell copies of the Software, and to permit persons to whom the Software is - * furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in all - * copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE - * SOFTWARE. - */ -#pragma once - -#ifdef __arm__ - -#include <arm_neon.h> - -#include "../asmlib.hpp" - -template<> -template<typename T> -inline void TransformImpl<6, 1, false, 4, 4, false>::Transform(T *out, const T *in, int ldin, int y0, int ymax, int k0, int kmax) { - uint32_t *outptr = reinterpret_cast<uint32_t *>(out); - const uint32_t *inptr = reinterpret_cast<const uint32_t *>(in); - bool first = true; - - uint32_t zerobuff[16] = { 0 }; // 8 for asm loop plus up to 7 for overflow loop - - for (int y=y0; y<ymax; y+=6) { - const uint32_t *inptr0 = inptr + y * ldin + k0; - const uint32_t *inptr1 = inptr0 + ldin; - const uint32_t *inptr2 = inptr1 + ldin; - const uint32_t *inptr3 = inptr2 + ldin; - const uint32_t *inptr4 = inptr3 + ldin; - const uint32_t *inptr5 = inptr4 + ldin; - - //prefetch_2x(inptr0); - //prefetch_2x(inptr1); - //prefetch_2x(inptr2); - //prefetch_2x(inptr3); - //prefetch_2x(inptr4); - //prefetch_2x(inptr5); - - int x=(kmax-k0); - for (;(x>7) || first;x-=8) { - /* Cope with ragged cases by copying from a buffer of zeroes instead */ - /* 'first' forces this to always run at least once, needed if the total size is <=7. */ - if ((y + 5) >= ymax) { - switch ((y + 5) - ymax) { - /* Everything falls through in here */ - case 4: - inptr1 = zerobuff; - // fall through - case 3: - inptr2 = zerobuff; - // fall through - case 2: - inptr3 = zerobuff; - // fall through - case 1: - inptr4 = zerobuff; - // fall through - case 0: - inptr5 = zerobuff; - break; - - default: - UNREACHABLE("Impossible."); - } - } - - if (first) { - if (x<=7) { - break; - } - - first = false; - } - - __asm __volatile ( - // Load up 8 elements (2 vectors) from each of 8 sources. - "VLD1.32 {d0-d3}, [%[inptr0]]!\n" // q0=A0A1A2A3 - "VLD1.32 {d4-d7}, [%[inptr1]]!\n" // q2=B0B1B2B3 - "VLD1.32 {d8-d11}, [%[inptr2]]!\n" // q4=C0C1C2C3 - "VZIP.32 q0, q4\n" // q0=A0C0A1C1, q4 = A2C2A3C3 - "VLD1.32 {d12-d15}, [%[inptr3]]!\n" // q6=D0D1D2D3 - "VZIP.32 q2, q6\n" // q2=B0D0B1D1, q6 = B2D2B3D3 - "VLD1.32 {d16-d19}, [%[inptr4]]!\n" - "VLD1.32 {d20-d23}, [%[inptr5]]!\n" - "VZIP.32 q8, q10\n" // q8=E0F0E1F1, q10 = E2F2E3F3 - ASM_PREFETCH("[%[inptr0], #128]") - "VZIP.32 q0, q2\n" // q0 = A0B0C0D0, q2 = A1B1C1D1 - - // Store first elements - "VST1.32 {d0-d1}, [%[outptr]]!\n" - "VST1.32 {d16}, [%[outptr]]!\n" - - "VZIP.32 q4, q6\n" // q4 = A2B2C2D2, q6 = A3B3C3D3 - - // Store second elements - "VST1.32 {d4-d5}, [%[outptr]]!\n" - "VZIP.32 q1, q5\n" - ASM_PREFETCH("[%[inptr1], #128]") - "VST1.32 {d17}, [%[outptr]]!\n" - "VZIP.32 q3, q7\n" - - // Store third elements - "VZIP.32 q9, q11\n" - "VST1.32 {d8-d9}, [%[outptr]]!\n" - "VZIP.32 q1, q3\n" - ASM_PREFETCH("[%[inptr2], #128]") - "VST1.32 {d20}, [%[outptr]]!\n" - - // Store fourth elements - "VZIP.32 q5, q7\n" - "VST1.32 {d12-d13}, [%[outptr]]!\n" - ASM_PREFETCH("[%[inptr3], #128]") - "VST1.32 {d21}, [%[outptr]]!\n" - - // Fifth - "VST1.32 {d2-d3}, [%[outptr]]!\n" - ASM_PREFETCH("[%[inptr4], #128]") - "VST1.32 {d18}, [%[outptr]]!\n" - - // Sixth - "VST1.32 {d6-d7}, [%[outptr]]!\n" - ASM_PREFETCH("[%[inptr5], #128]") - "VST1.32 {d19}, [%[outptr]]!\n" - - // Seventh - "VST1.32 {d10-d11}, [%[outptr]]!\n" - "VST1.32 {d22}, [%[outptr]]!\n" - - // Eighth - "VST1.32 {d14-d15}, [%[outptr]]!\n" - "VST1.32 {d23}, [%[outptr]]!\n" - - : [inptr0] "+r" (inptr0), [inptr1] "+r" (inptr1), [inptr2] "+r" (inptr2), [inptr3] "+r" (inptr3), - [inptr4] "+r" (inptr4), [inptr5] "+r" (inptr5), [outptr] "+r" (outptr) - : - : "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9", "q10", "q11", "q12", "memory" - ); - } - - for (;x>0;x--) { - *outptr++ = *inptr0++; - *outptr++ = *inptr1++; - *outptr++ = *inptr2++; - *outptr++ = *inptr3++; - *outptr++ = *inptr4++; - *outptr++ = *inptr5++; - } - } -} - -#endif // __arm__ diff --git a/src/core/NEON/kernels/arm_gemm/transforms/a32_transpose_interleave_8way_32bit.hpp b/src/core/NEON/kernels/arm_gemm/transforms/a32_transpose_interleave_8way_32bit.hpp index 587bec366a..b50c240a3a 100644 --- a/src/core/NEON/kernels/arm_gemm/transforms/a32_transpose_interleave_8way_32bit.hpp +++ b/src/core/NEON/kernels/arm_gemm/transforms/a32_transpose_interleave_8way_32bit.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017-2018 ARM Limited. + * Copyright (c) 2017-2021 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -30,22 +30,22 @@ // Generic unblocked transposed 8x32-bit sized specialisation template <> template <typename T> -inline void TransformImpl<8, 1, true, 4, 4, false>::Transform( +void TransformImpl<8, 1, true, 4, 4, VLType::None>::Transform( T* out, const T* const in, const int stride, const int x0, const int xmax, const int k0, const int kmax ) { // Redirect to a 16x uint16_t specialisation - TransformImpl<16, 1, true, 2, 2, false>::Transform( + TransformImpl<16, 1, true, 2, 2, VLType::None>::Transform( reinterpret_cast<uint16_t *>(out), reinterpret_cast<const uint16_t *>(in), stride*2, x0*2, xmax*2, k0, kmax ); } -// Generic 12x16-bit sized specialisation +// Generic 16x16-bit sized specialisation template <> template <typename T> -inline void TransformImpl<16, 1, true, 2, 2, false>::Transform( +void TransformImpl<16, 1, true, 2, 2, VLType::None>::Transform( T* out, const T* const in, const int stride, const int x0, const int xmax, const int k0, const int kmax ) { @@ -59,7 +59,7 @@ inline void TransformImpl<16, 1, true, 2, 2, false>::Transform( // Specialised 16 x uint16_t version template <> -inline void TransposeInterleaveCommon<16, uint16_t, uint16_t>::moveblock_1x1(const uint16_t *&in0, uint16_t *out) { +void TransposeInterleaveCommon<16, uint16_t, uint16_t>::moveblock_1x1(const uint16_t *&in0, uint16_t *out) { __asm volatile ( "VLD1.32 {d0-d3}, [%[in0]]!\n" "VST1.32 {d0-d3}, [%[out]]\n" @@ -72,7 +72,7 @@ inline void TransposeInterleaveCommon<16, uint16_t, uint16_t>::moveblock_1x1(con } template <> -inline void TransposeInterleaveCommon<16, uint16_t, uint16_t>::moveblock_1x2(const uint16_t *&in0, const uint16_t *&in1, uint16_t *out) { +void TransposeInterleaveCommon<16, uint16_t, uint16_t>::moveblock_1x2(const uint16_t *&in0, const uint16_t *&in1, uint16_t *out) { __asm volatile ( "VLD1.32 {d0-d3}, [%[in0]]!\n" "VST1.32 {d0-d3}, [%[out]]!\n" @@ -90,7 +90,7 @@ inline void TransposeInterleaveCommon<16, uint16_t, uint16_t>::moveblock_1x2(con } template <> -inline void TransposeInterleaveCommon<16, uint16_t, uint16_t>::moveblock_1x4(const uint16_t *&in0, const uint16_t *&in1, const uint16_t *&in2, const uint16_t *&in3, uint16_t *out) { +void TransposeInterleaveCommon<16, uint16_t, uint16_t>::moveblock_1x4(const uint16_t *&in0, const uint16_t *&in1, const uint16_t *&in2, const uint16_t *&in3, uint16_t *out) { __asm __volatile ( "VLD1.32 {d0-d3}, [%[in0]]!\n" "VST1.32 {d0-d3}, [%[out]]!\n" @@ -117,7 +117,7 @@ inline void TransposeInterleaveCommon<16, uint16_t, uint16_t>::moveblock_1x4(con template <> template <> -inline void TransformImpl<16, 1, true, 2, 2, false>::Transform( +void TransformImpl<16, 1, true, 2, 2, VLType::None>::Transform( uint16_t* out, const uint16_t* const in, const int stride, const int x0, const int xmax, const int k0, const int kmax ) { diff --git a/src/core/NEON/kernels/arm_gemm/transforms/a64_block16_interleave4_8bit.hpp b/src/core/NEON/kernels/arm_gemm/transforms/a64_block16_interleave4_8bit.hpp deleted file mode 100644 index 6b742c8776..0000000000 --- a/src/core/NEON/kernels/arm_gemm/transforms/a64_block16_interleave4_8bit.hpp +++ /dev/null @@ -1,130 +0,0 @@ -/* - * Copyright (c) 2017-2018 ARM Limited. - * - * SPDX-License-Identifier: MIT - * - * Permission is hereby granted, free of charge, to any person obtaining a copy - * of this software and associated documentation files (the "Software"), to - * deal in the Software without restriction, including without limitation the - * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or - * sell copies of the Software, and to permit persons to whom the Software is - * furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in all - * copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE - * SOFTWARE. - */ -#pragma once - -#ifdef __aarch64__ - -#include <arm_neon.h> - -#include "../asmlib.hpp" -#include "../utils.hpp" - -template<> -template<typename T> -void TransformImpl<4, 16, false, 1, 1, false>::Transform(T *out, const T *in, int ldin, int y0, int ymax, int k0, int kmax) { - uint8_t *outptr = (uint8_t *)out; - const uint8_t *inptr = (uint8_t *)in; - - uint8_t zerobuff[16] = { 0 }; - - for (uint64_t y = y0 ; y < static_cast<uint64_t>(ymax) ; y+=4) { - const uint8_t *inptr0 = inptr + y * ldin + k0; - const uint8_t *inptr1 = inptr0 + ldin; - const uint8_t *inptr2 = inptr1 + ldin; - const uint8_t *inptr3 = inptr2 + ldin; - - prefetch_2x(inptr0); - prefetch_2x(inptr1); - prefetch_2x(inptr2); - prefetch_2x(inptr3); - - int x=(kmax-k0); - for (;x>15;x-=16) { - /* Cope with ragged cases by copying from a buffer of zeroes instead */ - if ((y + 3) >= static_cast<uint64_t>(ymax)) { - switch ((y + 3) - ymax) { - /* Everything falls through in here */ - case 2: - inptr1 = zerobuff; - // fall through - case 1: - inptr2 = zerobuff; - // fall through - case 0: - inptr3 = zerobuff; - break; - - default: - UNREACHABLE("Impossible."); - } - } - - __asm __volatile ( - "LDR q0, [%[inptr0]], #16\n" - ASM_PREFETCH("[%[inptr0], #176]") - "LDR q1, [%[inptr1]], #16\n" - ASM_PREFETCH("[%[inptr1], #176]") - "STP q0, q1, [%[outptr]], #32\n" - "LDR q0, [%[inptr2]], #16\n" - ASM_PREFETCH("[%[inptr2], #176]") - "LDR q1, [%[inptr3]], #16\n" - ASM_PREFETCH("[%[inptr3], #176]") - "STP q0, q1, [%[outptr]], #32\n" - : [inptr0] "+r" (inptr0), [inptr1] "+r" (inptr1), [inptr2] "+r" (inptr2), [inptr3] "+r" (inptr3), - [outptr] "+r" (outptr) - : - : "v0", "v1" - ); - } - - if (x>0) { - /* Need to duplicate this here, in case we didn't run the main loop. */ - if ((y + 3) >= static_cast<uint64_t>(ymax)) { - switch ((y + 3) - ymax) { - /* Everything falls through in here */ - case 2: - inptr1 = zerobuff; - // fall through - case 1: - inptr2 = zerobuff; - // fall through - case 0: - inptr3 = zerobuff; - break; - - default: - UNREACHABLE("Impossible."); - } - } - - /* We have to write out 16 values, copy as many legal values as there are and pad with 0 */ - auto f = [&outptr, x](const uint8_t *&p) { - for (int i=0; i<16; i++) { - if (i < x) { - *outptr++ = *p++; - } else { - *outptr++ = 0; - } - } - }; - - f(inptr0); - f(inptr1); - f(inptr2); - f(inptr3); - } - } -} - -#endif // __aarch64__
\ No newline at end of file diff --git a/src/core/NEON/kernels/arm_gemm/transforms/a64_interleave_8way_16bit.hpp b/src/core/NEON/kernels/arm_gemm/transforms/a64_interleave_8way_16bit.hpp deleted file mode 100644 index 80dd6c5e25..0000000000 --- a/src/core/NEON/kernels/arm_gemm/transforms/a64_interleave_8way_16bit.hpp +++ /dev/null @@ -1,183 +0,0 @@ -/* - * Copyright (c) 2017-2019 ARM Limited. - * - * SPDX-License-Identifier: MIT - * - * Permission is hereby granted, free of charge, to any person obtaining a copy - * of this software and associated documentation files (the "Software"), to - * deal in the Software without restriction, including without limitation the - * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or - * sell copies of the Software, and to permit persons to whom the Software is - * furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in all - * copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE - * SOFTWARE. - */ -#pragma once - -#ifdef __aarch64__ - -#include <arm_neon.h> - -#include "../asmlib.hpp" - -template<> -template<typename T> -void TransformImpl<8, 1, false, 2, 2, false>::Transform(T *out, const T *in, int ldin, int y0, int ymax, int k0, int kmax) { - uint16_t *outptr = (uint16_t *)out; - const uint16_t *inptr = (const uint16_t *)in; - bool first=true; - - uint16_t zerobuff[16] = { 0 }; // 8 for asm loop plus up to 7 for overflow loop - - for (int y=y0; y<ymax; y+=8) { - const uint16_t *inptr0 = inptr + y * ldin + k0; - const uint16_t *inptr1 = inptr0 + ldin; - const uint16_t *inptr2 = inptr1 + ldin; - const uint16_t *inptr3 = inptr2 + ldin; - const uint16_t *inptr4 = inptr3 + ldin; - const uint16_t *inptr5 = inptr4 + ldin; - const uint16_t *inptr6 = inptr5 + ldin; - const uint16_t *inptr7 = inptr6 + ldin; - - prefetch_2x(inptr0); - prefetch_2x(inptr1); - prefetch_2x(inptr2); - prefetch_2x(inptr3); - prefetch_2x(inptr4); - prefetch_2x(inptr5); - prefetch_2x(inptr6); - prefetch_2x(inptr7); - - int x=(kmax-k0); - for (;(x>7) || first;x-=8) { - /* Cope with ragged cases by copying from a buffer of zeroes instead */ - /* 'first' forces this to always run at least once, needed if the total size is <=7. */ - if ((y + 7) >= ymax) { - switch ((y + 7) - ymax) { - /* Everything falls through in here */ - case 6: - inptr1 = zerobuff; - // fall through - case 5: - inptr2 = zerobuff; - // fall through - case 4: - inptr3 = zerobuff; - // fall through - case 3: - inptr4 = zerobuff; - // fall through - case 2: - inptr5 = zerobuff; - // fall through - case 1: - inptr6 = zerobuff; - // fall through - case 0: - inptr7 = zerobuff; - break; - - default: - UNREACHABLE("Impossible."); - } - } - - if (first) { - if (x <= 7) { - break; - } - - first = false; - } - - int skippf = (x & 31); - __asm __volatile ( - // Load up 8 elements (1 vector) from each of 8 sources. - "CBNZ %w[skippf], 1f\n" - ASM_PREFETCH("[%[inptr0], #128]") - ASM_PREFETCH("[%[inptr1], #128]") - ASM_PREFETCH("[%[inptr2], #128]") - ASM_PREFETCH("[%[inptr3], #128]") - "1:\n" - - "LDR q0, [%[inptr0]], #16\n" // q0=A0A1A2A3A4A5A6A7 - "LDR q4, [%[inptr4]], #16\n" // q8=E0E1E2E3E4E5E6E7 - "LDR q2, [%[inptr2]], #16\n" // q4=C0C1C2C3... - "LDR q6, [%[inptr6]], #16\n" - "ZIP1 v8.8h, v0.8h, v4.8h\n" // q8=A0E0A1E1A2E2A3E3 - "ZIP2 v16.8h, v0.8h, v4.8h\n" // q16=A4E4A5E5A6E6A7E7 - "ZIP1 v9.8h, v2.8h, v6.8h\n" // q9=C0G0C1G1C2G2C3G3 - "ZIP2 v17.8h, v2.8h, v6.8h\n" // q17=C4G4C5G5C6G6C7G7 - "LDR q1, [%[inptr1]], #16\n" // q1=B0B1B2B3B4B5B6B7 - "LDR q5, [%[inptr5]], #16\n" - "LDR q3, [%[inptr3]], #16\n" // q3=D0D1D2D3.... - "LDR q7, [%[inptr7]], #16\n" - "ZIP1 v10.8h, v1.8h, v5.8h\n" // q18=B0F0B1F1B2F2B3F3 - "ZIP2 v18.8h, v1.8h, v5.8h\n" // q18=B4F4B5F5B6F6B7F7 - "ZIP1 v11.8h, v3.8h, v7.8h\n" // q19=D0H0D1H1D2H2D3H3 - "ZIP2 v19.8h, v3.8h, v7.8h\n" // q19=D4H4D5H5D6H6D7H7 - - "ZIP1 v12.8h, v8.8h, v9.8h\n" // q20=A0C0E0G0A1C1E1G1 - "ZIP2 v20.8h, v8.8h, v9.8h\n" - "ZIP1 v13.8h, v10.8h, v11.8h\n" // q21=B0D0F0H0B1I1F1H1 - "ZIP2 v21.8h, v10.8h, v11.8h\n" - - "CBNZ %w[skippf], 2f\n" - ASM_PREFETCH("[%[inptr4], #112]") - ASM_PREFETCH("[%[inptr5], #112]") - ASM_PREFETCH("[%[inptr6], #112]") - ASM_PREFETCH("[%[inptr7], #112]") - "2:\n" - - "ZIP1 v22.8h, v16.8h, v17.8h\n" - "ZIP2 v30.8h, v16.8h, v17.8h\n" - "ZIP1 v23.8h, v18.8h, v19.8h\n" - "ZIP2 v31.8h, v18.8h, v19.8h\n" - - "ZIP1 v14.8h, v12.8h, v13.8h\n" // q22=A0B0C0D0E0F0G0H0 - "ZIP2 v15.8h, v12.8h, v13.8h\n" // q23=A1B1C1D1E1F1G1H1 - "STP q14, q15, [%[outptr]], #32\n" // Write back first two elements - - "ZIP1 v0.8h, v20.8h, v21.8h\n" - "ZIP2 v1.8h, v20.8h, v21.8h\n" - "STP q0, q1, [%[outptr]], #32\n" // Write back next two elements - - "ZIP1 v2.8h, v22.8h, v23.8h\n" - "ZIP2 v3.8h, v22.8h, v23.8h\n" - "STP q2, q3, [%[outptr]], #32\n" // Write back next two elements - - "ZIP1 v4.8h, v30.8h, v31.8h\n" - "ZIP2 v5.8h, v30.8h, v31.8h\n" - "STP q4, q5, [%[outptr]], #32\n" // Write back last two elements - : [inptr0] "+r" (inptr0), [inptr1] "+r" (inptr1), [inptr2] "+r" (inptr2), [inptr3] "+r" (inptr3), - [inptr4] "+r" (inptr4), [inptr5] "+r" (inptr5), [inptr6] "+r" (inptr6), [inptr7] "+r" (inptr7), [outptr] "+r" (outptr) - : [skippf] "r" (skippf) - : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", - "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", - "v25", "v26", "v27", "v28", "v29", "v30", "v31", "memory" - ); - } - - for (;x>0;x--) { - *outptr++ = *inptr0++; - *outptr++ = *inptr1++; - *outptr++ = *inptr2++; - *outptr++ = *inptr3++; - *outptr++ = *inptr4++; - *outptr++ = *inptr5++; - *outptr++ = *inptr6++; - *outptr++ = *inptr7++; - } - } -} - -#endif // __aarch64__ diff --git a/src/core/NEON/kernels/arm_gemm/transforms/a64_interleave_8way_32bit.hpp b/src/core/NEON/kernels/arm_gemm/transforms/a64_interleave_8way_32bit.hpp deleted file mode 100644 index 9dfc1346e6..0000000000 --- a/src/core/NEON/kernels/arm_gemm/transforms/a64_interleave_8way_32bit.hpp +++ /dev/null @@ -1,192 +0,0 @@ -/* - * Copyright (c) 2017-2019 ARM Limited. - * - * SPDX-License-Identifier: MIT - * - * Permission is hereby granted, free of charge, to any person obtaining a copy - * of this software and associated documentation files (the "Software"), to - * deal in the Software without restriction, including without limitation the - * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or - * sell copies of the Software, and to permit persons to whom the Software is - * furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in all - * copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE - * SOFTWARE. - */ -#pragma once - -#if defined(__aarch64__) && !defined(__ARM_FEATURE_SVE) - -#include <arm_neon.h> - -#include "../asmlib.hpp" - -template<> -template<typename T> -inline void TransformImpl<8, 1, false, 4, 4, false>::Transform(T *out, const T *in, int ldin, int y0, int ymax, int k0, int kmax) { - uint32_t *outptr = (uint32_t *)out; - const uint32_t *inptr = (uint32_t *)in; - bool first = true; - - uint32_t zerobuff[16] = { 0 }; // 8 for asm loop plus up to 7 for overflow loop - - for (int y=y0; y<ymax; y+=8) { - const uint32_t *inptr0 = inptr + y * ldin + k0; - const uint32_t *inptr1 = inptr0 + ldin; - const uint32_t *inptr2 = inptr1 + ldin; - const uint32_t *inptr3 = inptr2 + ldin; - const uint32_t *inptr4 = inptr3 + ldin; - const uint32_t *inptr5 = inptr4 + ldin; - const uint32_t *inptr6 = inptr5 + ldin; - const uint32_t *inptr7 = inptr6 + ldin; - - prefetch_2x(inptr0); - prefetch_2x(inptr1); - prefetch_2x(inptr2); - prefetch_2x(inptr3); - prefetch_2x(inptr4); - prefetch_2x(inptr5); - prefetch_2x(inptr6); - prefetch_2x(inptr7); - - int x=(kmax-k0); - for (;(x>7) || first;x-=8) { - /* Cope with ragged cases by copying from a buffer of zeroes instead */ - /* 'first' forces this to always run at least once, needed if the total size is <=7. */ - if ((y + 7) >= ymax) { - switch ((y + 7) - ymax) { - /* Everything falls through in here */ - case 6: - inptr1 = zerobuff; - // fall through - case 5: - inptr2 = zerobuff; - // fall through - case 4: - inptr3 = zerobuff; - // fall through - case 3: - inptr4 = zerobuff; - // fall through - case 2: - inptr5 = zerobuff; - // fall through - case 1: - inptr6 = zerobuff; - // fall through - case 0: - inptr7 = zerobuff; - break; - - default: - UNREACHABLE("Impossible."); - } - } - - if (first) { - if (x<=7) { - break; - } - - first = false; - } - - __asm __volatile ( - // Load up 8 elements (2 vectors) from each of 8 sources. - "LDP q0, q1, [%[inptr0]], #32\n" // q0=A0A1A2A3 - "LDP q2, q3, [%[inptr1]], #32\n" // q2=B0B1B2B3 - "LDP q4, q5, [%[inptr2]], #32\n" // q4=C0C1C2C3 - "ZIP1 v16.4s, v0.4s, v4.4s\n" // q16=A0C0A1C1 - ASM_PREFETCH("[%[inptr0], #128]") - "LDP q6, q7, [%[inptr3]], #32\n" // q6=D0D1D2D3 - "ZIP1 v17.4s, v2.4s, v6.4s\n" // q17=B0D0B1D1 - "LDP q8, q9, [%[inptr4]], #32\n" - "LDP q10, q11, [%[inptr5]], #32\n" - "LDP q12, q13, [%[inptr6]], #32\n" - "ZIP1 v18.4s, v8.4s, v12.4s\n" - ASM_PREFETCH("[%[inptr1], #128]") - "LDP q14, q15, [%[inptr7]], #32\n" - "ZIP1 v19.4s, v10.4s, v14.4s\n" - - "ZIP1 v20.4s, v16.4s, v17.4s\n" // q20=A0B0C0D0 - ASM_PREFETCH("[%[inptr2], #128]") - "ZIP1 v21.4s, v18.4s, v19.4s\n" - "ZIP2 v22.4s, v16.4s, v17.4s\n" - "ZIP2 v23.4s, v18.4s, v19.4s\n" - - "ZIP2 v16.4s, v0.4s, v4.4s\n" - ASM_PREFETCH("[%[inptr3], #128]") - "ZIP2 v17.4s, v2.4s, v6.4s\n" - "STP q20, q21, [%[outptr]], #32\n" // Write back the first element of each source - - "ZIP2 v18.4s, v8.4s, v12.4s\n" - "ZIP2 v19.4s, v10.4s, v14.4s\n" - "STP q22, q23, [%[outptr]], #32\n" // Write back the second element of each source - - "ZIP1 v20.4s, v16.4s, v17.4s\n" - ASM_PREFETCH("[%[inptr4], #128]") - "ZIP1 v21.4s, v18.4s, v19.4s\n" - "ZIP2 v22.4s, v16.4s, v17.4s\n" - "ZIP2 v23.4s, v18.4s, v19.4s\n" - - "ZIP1 v16.4s, v1.4s, v5.4s\n" - ASM_PREFETCH("[%[inptr5], #128]") - "ZIP1 v17.4s, v3.4s, v7.4s\n" - "STP q20, q21, [%[outptr]], #32\n" // Third element - - "ZIP1 v18.4s, v9.4s, v13.4s\n" - "ZIP1 v19.4s, v11.4s, v15.4s\n" - "STP q22, q23, [%[outptr]], #32\n" // Fourth element - - "ZIP1 v20.4s, v16.4s, v17.4s\n" - "ZIP1 v21.4s, v18.4s, v19.4s\n" - "ZIP2 v22.4s, v16.4s, v17.4s\n" - ASM_PREFETCH("[%[inptr6], #128]") - "ZIP2 v23.4s, v18.4s, v19.4s\n" - - "ZIP2 v16.4s, v1.4s, v5.4s\n" - "ZIP2 v17.4s, v3.4s, v7.4s\n" - "STP q20, q21, [%[outptr]], #32\n" // Fifth element - - "ZIP2 v18.4s, v9.4s, v13.4s\n" - ASM_PREFETCH("[%[inptr7], #128]") - "ZIP2 v19.4s, v11.4s, v15.4s\n" - "STP q22, q23, [%[outptr]], #32\n" // Sixth element - - "ZIP1 v20.4s, v16.4s, v17.4s\n" - "ZIP1 v21.4s, v18.4s, v19.4s\n" - "STP q20, q21, [%[outptr]], #32\n" // Seventh element - - "ZIP2 v22.4s, v16.4s, v17.4s\n" - "ZIP2 v23.4s, v18.4s, v19.4s\n" - "STP q22, q23, [%[outptr]], #32\n" // Eighth element - : [inptr0] "+r" (inptr0), [inptr1] "+r" (inptr1), [inptr2] "+r" (inptr2), [inptr3] "+r" (inptr3), - [inptr4] "+r" (inptr4), [inptr5] "+r" (inptr5), [inptr6] "+r" (inptr6), [inptr7] "+r" (inptr7), [outptr] "+r" (outptr) - : - : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", - "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "memory" - ); - } - - for (;x>0;x--) { - *outptr++ = *inptr0++; - *outptr++ = *inptr1++; - *outptr++ = *inptr2++; - *outptr++ = *inptr3++; - *outptr++ = *inptr4++; - *outptr++ = *inptr5++; - *outptr++ = *inptr6++; - *outptr++ = *inptr7++; - } - } -} - -#endif // __aarch64__ && !__ARM_FEATURE_SVE diff --git a/src/core/NEON/kernels/arm_gemm/transforms/a64_interleave_8way_block4_8bit.hpp b/src/core/NEON/kernels/arm_gemm/transforms/a64_interleave_8way_block4_8bit.hpp deleted file mode 100644 index 2bc7801b15..0000000000 --- a/src/core/NEON/kernels/arm_gemm/transforms/a64_interleave_8way_block4_8bit.hpp +++ /dev/null @@ -1,229 +0,0 @@ -/* - * Copyright (c) 2017-2019 Arm Limited. - * - * SPDX-License-Identifier: MIT - * - * Permission is hereby granted, free of charge, to any person obtaining a copy - * of this software and associated documentation files (the "Software"), to - * deal in the Software without restriction, including without limitation the - * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or - * sell copies of the Software, and to permit persons to whom the Software is - * furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in all - * copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE - * SOFTWARE. - */ -#pragma once - -#if defined(__aarch64__) && !defined(__ARM_FEATURE_SVE) - -#include <arm_neon.h> - -#include "../asmlib.hpp" - -template<> -template<typename T> -inline void TransformImpl<8, 4, false, 1, 1, false>::Transform(T *out, const T *in, int ldin, int y0, int ymax, int k0, int kmax) { - uint8_t *outptr = reinterpret_cast<uint8_t *>(out); - const uint8_t *inptr = reinterpret_cast<const uint8_t *>(in); - bool first = true; - - /* Helper functions to copy blocks about used for odd case. */ - class t { - public: - static inline void copy_4_inc(uint8_t *&out, const uint8_t *&in) { - uint32_t *out_word = reinterpret_cast<uint32_t *>(out); - const uint32_t *in_word = reinterpret_cast<const uint32_t *>(in); - - *out_word++ = *in_word++; - - out = reinterpret_cast<uint8_t *>(out_word); - in = reinterpret_cast<const uint8_t *>(in_word); - } - - static inline void copy_pad(uint8_t *&out, const uint8_t *&in, size_t count) { - for (unsigned int i=0; i<4; i++) { - if (i < count) { - *out++ = *in++; - } else { - *out++ = 0; - } - } - } - }; - - uint8_t zerobuff[64]; // 32 for asm loop plus up to 31 for overflow loop - - for (int y=y0; y<ymax; y+=8) { - const uint8_t *inptr0 = inptr + y * ldin + k0; - const uint8_t *inptr1 = inptr0 + ldin; - const uint8_t *inptr2 = inptr1 + ldin; - const uint8_t *inptr3 = inptr2 + ldin; - const uint8_t *inptr4 = inptr3 + ldin; - const uint8_t *inptr5 = inptr4 + ldin; - const uint8_t *inptr6 = inptr5 + ldin; - const uint8_t *inptr7 = inptr6 + ldin; - - prefetch_2x(inptr0); - prefetch_2x(inptr1); - prefetch_2x(inptr2); - prefetch_2x(inptr3); - prefetch_2x(inptr4); - prefetch_2x(inptr5); - prefetch_2x(inptr6); - prefetch_2x(inptr7); - - int x=(kmax-k0); - for (;(x>31) || first;x-=32) { - /* Cope with ragged cases by copying from a buffer of zeroes instead */ - /* 'first' forces this to always run at least once, needed if the total size is <=32. */ - if ((y + 7) >= ymax) { - switch ((y + 7) - ymax) { - /* Everything falls through in here */ - case 6: - inptr1 = zerobuff; - // fall through - case 5: - inptr2 = zerobuff; - // fall through - case 4: - inptr3 = zerobuff; - // fall through - case 3: - inptr4 = zerobuff; - // fall through - case 2: - inptr5 = zerobuff; - // fall through - case 1: - inptr6 = zerobuff; - // fall through - case 0: - inptr7 = zerobuff; - break; - - default: - UNREACHABLE("Impossible."); - } - } - - if (first) { - if (x<=31) { - break; - } - - first = false; - } - - __asm __volatile ( - // Load up 8 elements (2 vectors) from each of 8 sources. - "LDP q0, q1, [%[inptr0]], #32\n" // q0=A0A1A2A3 - "LDP q2, q3, [%[inptr1]], #32\n" // q2=B0B1B2B3 - "LDP q4, q5, [%[inptr2]], #32\n" // q4=C0C1C2C3 - "ZIP1 v16.4s, v0.4s, v4.4s\n" // q16=A0C0A1C1 - ASM_PREFETCH("[%[inptr0], #128]") - "LDP q6, q7, [%[inptr3]], #32\n" // q6=D0D1D2D3 - "ZIP1 v17.4s, v2.4s, v6.4s\n" // q17=B0D0B1D1 - "LDP q8, q9, [%[inptr4]], #32\n" - "LDP q10, q11, [%[inptr5]], #32\n" - "LDP q12, q13, [%[inptr6]], #32\n" - "ZIP1 v18.4s, v8.4s, v12.4s\n" - ASM_PREFETCH("[%[inptr1], #128]") - "LDP q14, q15, [%[inptr7]], #32\n" - "ZIP1 v19.4s, v10.4s, v14.4s\n" - - "ZIP1 v20.4s, v16.4s, v17.4s\n" // q20=A0B0C0D0 - ASM_PREFETCH("[%[inptr2], #128]") - "ZIP1 v21.4s, v18.4s, v19.4s\n" - "ZIP2 v22.4s, v16.4s, v17.4s\n" - "ZIP2 v23.4s, v18.4s, v19.4s\n" - - "ZIP2 v16.4s, v0.4s, v4.4s\n" - ASM_PREFETCH("[%[inptr3], #128]") - "ZIP2 v17.4s, v2.4s, v6.4s\n" - "STP q20, q21, [%[outptr]], #32\n" // Write back the first element of each source - - "ZIP2 v18.4s, v8.4s, v12.4s\n" - "ZIP2 v19.4s, v10.4s, v14.4s\n" - "STP q22, q23, [%[outptr]], #32\n" // Write back the second element of each source - - "ZIP1 v20.4s, v16.4s, v17.4s\n" - ASM_PREFETCH("[%[inptr4], #128]") - "ZIP1 v21.4s, v18.4s, v19.4s\n" - "ZIP2 v22.4s, v16.4s, v17.4s\n" - "ZIP2 v23.4s, v18.4s, v19.4s\n" - - "ZIP1 v16.4s, v1.4s, v5.4s\n" - ASM_PREFETCH("[%[inptr5], #128]") - "ZIP1 v17.4s, v3.4s, v7.4s\n" - "STP q20, q21, [%[outptr]], #32\n" // Third element - - "ZIP1 v18.4s, v9.4s, v13.4s\n" - "ZIP1 v19.4s, v11.4s, v15.4s\n" - "STP q22, q23, [%[outptr]], #32\n" // Fourth element - - "ZIP1 v20.4s, v16.4s, v17.4s\n" - "ZIP1 v21.4s, v18.4s, v19.4s\n" - "ZIP2 v22.4s, v16.4s, v17.4s\n" - ASM_PREFETCH("[%[inptr6], #128]") - "ZIP2 v23.4s, v18.4s, v19.4s\n" - - "ZIP2 v16.4s, v1.4s, v5.4s\n" - "ZIP2 v17.4s, v3.4s, v7.4s\n" - "STP q20, q21, [%[outptr]], #32\n" // Fifth element - - "ZIP2 v18.4s, v9.4s, v13.4s\n" - ASM_PREFETCH("[%[inptr7], #128]") - "ZIP2 v19.4s, v11.4s, v15.4s\n" - "STP q22, q23, [%[outptr]], #32\n" // Sixth element - - "ZIP1 v20.4s, v16.4s, v17.4s\n" - "ZIP1 v21.4s, v18.4s, v19.4s\n" - "STP q20, q21, [%[outptr]], #32\n" // Seventh element - - "ZIP2 v22.4s, v16.4s, v17.4s\n" - "ZIP2 v23.4s, v18.4s, v19.4s\n" - "STP q22, q23, [%[outptr]], #32\n" // Eighth element - : [inptr0] "+r" (inptr0), [inptr1] "+r" (inptr1), [inptr2] "+r" (inptr2), [inptr3] "+r" (inptr3), - [inptr4] "+r" (inptr4), [inptr5] "+r" (inptr5), [inptr6] "+r" (inptr6), [inptr7] "+r" (inptr7), [outptr] "+r" (outptr) - : - : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", - "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "memory" - ); - } - - // Copy any leftover blocks of 4 a complete block at a time. - for (;x>4;x-=4) { - t::copy_4_inc(outptr, inptr0); - t::copy_4_inc(outptr, inptr1); - t::copy_4_inc(outptr, inptr2); - t::copy_4_inc(outptr, inptr3); - t::copy_4_inc(outptr, inptr4); - t::copy_4_inc(outptr, inptr5); - t::copy_4_inc(outptr, inptr6); - t::copy_4_inc(outptr, inptr7); - } - - // Final block with padding, if any. - if (x > 0) { - t::copy_pad(outptr, inptr0, x); - t::copy_pad(outptr, inptr1, x); - t::copy_pad(outptr, inptr2, x); - t::copy_pad(outptr, inptr3, x); - t::copy_pad(outptr, inptr4, x); - t::copy_pad(outptr, inptr5, x); - t::copy_pad(outptr, inptr6, x); - t::copy_pad(outptr, inptr7, x); - } - } -} - -#endif // __aarch64__ && !__ARM_FEATURE_SVE diff --git a/src/core/NEON/kernels/arm_gemm/transforms/a64_interleave_8way_half_to_float.hpp b/src/core/NEON/kernels/arm_gemm/transforms/a64_interleave_8way_half_to_float.hpp deleted file mode 100644 index bde3274926..0000000000 --- a/src/core/NEON/kernels/arm_gemm/transforms/a64_interleave_8way_half_to_float.hpp +++ /dev/null @@ -1,208 +0,0 @@ -/* - * Copyright (c) 2017-2019 ARM Limited. - * - * SPDX-License-Identifier: MIT - * - * Permission is hereby granted, free of charge, to any person obtaining a copy - * of this software and associated documentation files (the "Software"), to - * deal in the Software without restriction, including without limitation the - * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or - * sell copies of the Software, and to permit persons to whom the Software is - * furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in all - * copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE - * SOFTWARE. - */ -#pragma once - -#if defined(__aarch64__) && defined(__ARM_FP16_ARGS) - -#include <arm_neon.h> - -#include "../asmlib.hpp" - -template<> -template<> -inline void TransformImpl<8, 1, false, 4, 2, false>::Transform(float *out, const __fp16 *in, int ldin, int y0, int ymax, int k0, int kmax) { - float *outptr = out; - const __fp16 *inptr = in; - bool first = true; - - __fp16 zerobuff[16] = { 0 }; // 8 for asm loop plus up to 7 for overflow loop - - for (int y=y0; y<ymax; y+=8) { - const __fp16 *inptr0 = inptr + y * ldin + k0; - const __fp16 *inptr1 = inptr0 + ldin; - const __fp16 *inptr2 = inptr1 + ldin; - const __fp16 *inptr3 = inptr2 + ldin; - const __fp16 *inptr4 = inptr3 + ldin; - const __fp16 *inptr5 = inptr4 + ldin; - const __fp16 *inptr6 = inptr5 + ldin; - const __fp16 *inptr7 = inptr6 + ldin; - - prefetch_2x(inptr0); - prefetch_2x(inptr1); - prefetch_2x(inptr2); - prefetch_2x(inptr3); - prefetch_2x(inptr4); - prefetch_2x(inptr5); - prefetch_2x(inptr6); - prefetch_2x(inptr7); - - int x=(kmax-k0); - for (;(x>7) || first;x-=8) { - /* Cope with ragged cases by copying from a buffer of zeroes instead */ - /* 'first' forces this to always run at least once, needed if the total size is <=7. */ - if ((y + 7) >= ymax) { - switch ((y + 7) - ymax) { - /* Everything falls through in here */ - case 6: - inptr1 = zerobuff; - // fall through - case 5: - inptr2 = zerobuff; - // fall through - case 4: - inptr3 = zerobuff; - // fall through - case 3: - inptr4 = zerobuff; - // fall through - case 2: - inptr5 = zerobuff; - // fall through - case 1: - inptr6 = zerobuff; - // fall through - case 0: - inptr7 = zerobuff; - break; - - default: - UNREACHABLE("Impossible."); - } - } - - if (first) { - if (x<=7) { - break; - } - - first = false; - } - - __asm __volatile ( - // Load up 8 elements (2 vectors) from each of 8 sources. - "LDR q0, [%[inptr0]], #16\n" - "LDR q2, [%[inptr1]], #16\n" - "FCVTL2 v1.4s, v0.8h\n" - "FCVTL v0.4s, v0.4h\n" - "LDR q4, [%[inptr2]], #16\n" // q4=C0C1C2C3 - "FCVTL2 v3.4s, v2.8h\n" - "FCVTL v2.4s, v2.4h\n" - "FCVTL2 v5.4s, v4.8h\n" - "FCVTL v4.4s, v4.4h\n" - "ZIP1 v16.4s, v0.4s, v4.4s\n" // q16=A0C0A1C1 - ASM_PREFETCH("[%[inptr0], #128]") - "LDR q6, [%[inptr3]], #16\n" // q6=D0D1D2D3 - "FCVTL2 v7.4s, v6.8h\n" - "FCVTL v6.4s, v6.4h\n" - "ZIP1 v17.4s, v2.4s, v6.4s\n" // q17=B0D0B1D1 - "LDR q8, [%[inptr4]], #16\n" - "LDR q10, [%[inptr5]], #16\n" - "FCVTL2 v9.4s, v8.8h\n" - "FCVTL v8.4s, v8.4h\n" - ASM_PREFETCH("[%[inptr1], #128]") - "LDR q12, [%[inptr6]], #16\n" - "FCVTL2 v11.4s, v10.8h\n" - "FCVTL v10.4s, v10.4h\n" - "FCVTL2 v13.4s, v12.8h\n" - "FCVTL v12.4s, v12.4h\n" - "ZIP1 v18.4s, v8.4s, v12.4s\n" - "LDR q14, [%[inptr7]], #16\n" - "FCVTL2 v15.4s, v14.8h\n" - "FCVTL v14.4s, v14.4h\n" - "ZIP1 v19.4s, v10.4s, v14.4s\n" - - ASM_PREFETCH("[%[inptr2], #128]") - "ZIP1 v20.4s, v16.4s, v17.4s\n" // q20=A0B0C0D0 - "ZIP1 v21.4s, v18.4s, v19.4s\n" - "ZIP2 v22.4s, v16.4s, v17.4s\n" - "ZIP2 v23.4s, v18.4s, v19.4s\n" - ASM_PREFETCH("[%[inptr3], #128]") - - "ZIP2 v16.4s, v0.4s, v4.4s\n" - "ZIP2 v17.4s, v2.4s, v6.4s\n" - "STP q20, q21, [%[outptr]], #32\n" // Write back the first element of each source - - "ZIP2 v18.4s, v8.4s, v12.4s\n" - ASM_PREFETCH("[%[inptr4], #128]") - "ZIP2 v19.4s, v10.4s, v14.4s\n" - "STP q22, q23, [%[outptr]], #32\n" // Write back the second element of each source - - "ZIP1 v20.4s, v16.4s, v17.4s\n" - "ZIP1 v21.4s, v18.4s, v19.4s\n" - ASM_PREFETCH("[%[inptr5], #128]") - "ZIP2 v22.4s, v16.4s, v17.4s\n" - "ZIP2 v23.4s, v18.4s, v19.4s\n" - - "ZIP1 v16.4s, v1.4s, v5.4s\n" - "ZIP1 v17.4s, v3.4s, v7.4s\n" - ASM_PREFETCH("[%[inptr6], #128]") - "STP q20, q21, [%[outptr]], #32\n" // Third element - - "ZIP1 v18.4s, v9.4s, v13.4s\n" - "ZIP1 v19.4s, v11.4s, v15.4s\n" - "STP q22, q23, [%[outptr]], #32\n" // Fourth element - ASM_PREFETCH("[%[inptr7], #128]") - - "ZIP1 v20.4s, v16.4s, v17.4s\n" - "ZIP1 v21.4s, v18.4s, v19.4s\n" - "ZIP2 v22.4s, v16.4s, v17.4s\n" - "ZIP2 v23.4s, v18.4s, v19.4s\n" - - "ZIP2 v16.4s, v1.4s, v5.4s\n" - "ZIP2 v17.4s, v3.4s, v7.4s\n" - "STP q20, q21, [%[outptr]], #32\n" // Fifth element - - "ZIP2 v18.4s, v9.4s, v13.4s\n" - "ZIP2 v19.4s, v11.4s, v15.4s\n" - "STP q22, q23, [%[outptr]], #32\n" // Sixth element - - "ZIP1 v20.4s, v16.4s, v17.4s\n" - "ZIP1 v21.4s, v18.4s, v19.4s\n" - "STP q20, q21, [%[outptr]], #32\n" // Seventh element - - "ZIP2 v22.4s, v16.4s, v17.4s\n" - "ZIP2 v23.4s, v18.4s, v19.4s\n" - "STP q22, q23, [%[outptr]], #32\n" // Eighth element - : [inptr0] "+r" (inptr0), [inptr1] "+r" (inptr1), [inptr2] "+r" (inptr2), [inptr3] "+r" (inptr3), - [inptr4] "+r" (inptr4), [inptr5] "+r" (inptr5), [inptr6] "+r" (inptr6), [inptr7] "+r" (inptr7), [outptr] "+r" (outptr) - : - : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", - "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "memory" - ); - } - - for (;x>0;x--) { - *outptr++ = *inptr0++; - *outptr++ = *inptr1++; - *outptr++ = *inptr2++; - *outptr++ = *inptr3++; - *outptr++ = *inptr4++; - *outptr++ = *inptr5++; - *outptr++ = *inptr6++; - *outptr++ = *inptr7++; - } - } -} - -#endif // __aarch64__ && __ARM_FP16_ARGS diff --git a/src/core/NEON/kernels/arm_gemm/transforms/a64_transpose_interleave_128.hpp b/src/core/NEON/kernels/arm_gemm/transforms/a64_transpose_interleave_128.hpp new file mode 100644 index 0000000000..8574d89226 --- /dev/null +++ b/src/core/NEON/kernels/arm_gemm/transforms/a64_transpose_interleave_128.hpp @@ -0,0 +1,288 @@ +/* + * Copyright (c) 2021, 2023 Arm Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ + +#pragma once + +#if defined(__aarch64__) + +namespace { + +void a64_transpose_interleave_128(uint32_t *out, const uint32_t *in, size_t width, size_t in_stride, size_t height) +{ + size_t out_stride = 32 * height * sizeof(uint32_t); + + __asm__ __volatile__( + "cmp %x[height], #0x4\n" + "blt 10f\n" + "1:" // Main row loop: Head + "mov x25, %x[in]\n" + "mov x24, %x[width]\n" + "add x23, x25, %x[in_stride]\n" + "add x22, x23, %x[in_stride]\n" + "add x20, x22, %x[in_stride]\n" + "cmp x24, #0x20\n" + "add %x[in], x20, %x[in_stride]\n" + "mov x21, %x[out]\n" + "sub %x[height], %x[height], #0x4\n" + "blt 3f\n" + "2:" // Main row loop: Column loop + "ldr q15, [x25], #0x10\n" + "ldr q14, [x23], #0x10\n" + "sub x24, x24, #0x20\n" + "cmp x24, #0x20\n" + "ldr q13, [x22], #0x10\n" + "ldr q12, [x20], #0x10\n" + "ldr q11, [x25], #0x10\n" + "ldr q10, [x23], #0x10\n" + "ldr q9, [x22], #0x10\n" + "ldr q8, [x20], #0x10\n" + "ldr q7, [x25], #0x10\n" + "ldr q6, [x23], #0x10\n" + "ldr q5, [x22], #0x10\n" + "ldr q4, [x20], #0x10\n" + "ldr q3, [x25], #0x10\n" + "ldr q2, [x23], #0x10\n" + "ldr q1, [x22], #0x10\n" + "ldr q0, [x20], #0x10\n" + "ldr q31, [x25], #0x10\n" + "ldr q30, [x23], #0x10\n" + "ldr q29, [x22], #0x10\n" + "ldr q28, [x20], #0x10\n" + "ldr q27, [x25], #0x10\n" + "ldr q26, [x23], #0x10\n" + "ldr q25, [x22], #0x10\n" + "ldr q24, [x20], #0x10\n" + "ldr q23, [x25], #0x10\n" + "ldr q22, [x23], #0x10\n" + "ldr q21, [x22], #0x10\n" + "ldr q20, [x20], #0x10\n" + "ldr q19, [x25], #0x10\n" + "ldr q18, [x23], #0x10\n" + "ldr q17, [x22], #0x10\n" + "ldr q16, [x20], #0x10\n" + "str q15, [x21, #0x0]\n" + "str q11, [x21, #0x10]\n" + "str q7, [x21, #0x20]\n" + "str q3, [x21, #0x30]\n" + "str q31, [x21, #0x40]\n" + "str q27, [x21, #0x50]\n" + "str q23, [x21, #0x60]\n" + "str q19, [x21, #0x70]\n" + "str q14, [x21, #0x80]\n" + "str q10, [x21, #0x90]\n" + "str q6, [x21, #0xa0]\n" + "str q2, [x21, #0xb0]\n" + "str q30, [x21, #0xc0]\n" + "str q26, [x21, #0xd0]\n" + "str q22, [x21, #0xe0]\n" + "str q18, [x21, #0xf0]\n" + "str q13, [x21, #0x100]\n" + "str q9, [x21, #0x110]\n" + "str q5, [x21, #0x120]\n" + "str q1, [x21, #0x130]\n" + "str q29, [x21, #0x140]\n" + "str q25, [x21, #0x150]\n" + "str q21, [x21, #0x160]\n" + "str q17, [x21, #0x170]\n" + "str q12, [x21, #0x180]\n" + "str q8, [x21, #0x190]\n" + "str q4, [x21, #0x1a0]\n" + "str q0, [x21, #0x1b0]\n" + "str q28, [x21, #0x1c0]\n" + "str q24, [x21, #0x1d0]\n" + "str q20, [x21, #0x1e0]\n" + "str q16, [x21, #0x1f0]\n" + "add x21, x21, %x[out_stride]\n" + "bge 2b\n" + "3:" // Main row loop: Column loop skip + "cmp x24, #0x10\n" + "blt 5f\n" + "4:" // Main row loop: width 16 loop: loop + "ldr q31, [x25], #0x10\n" + "ldr q30, [x23], #0x10\n" + "sub x24, x24, #0x10\n" + "cmp x24, #0x10\n" + "ldr q29, [x22], #0x10\n" + "ldr q28, [x20], #0x10\n" + "ldr q27, [x25], #0x10\n" + "ldr q26, [x23], #0x10\n" + "ldr q25, [x22], #0x10\n" + "ldr q24, [x20], #0x10\n" + "ldr q23, [x25], #0x10\n" + "ldr q22, [x23], #0x10\n" + "ldr q21, [x22], #0x10\n" + "ldr q20, [x20], #0x10\n" + "ldr q19, [x25], #0x10\n" + "ldr q18, [x23], #0x10\n" + "ldr q17, [x22], #0x10\n" + "ldr q16, [x20], #0x10\n" + "str q31, [x21, #0x0]\n" + "str q27, [x21, #0x10]\n" + "str q23, [x21, #0x20]\n" + "str q19, [x21, #0x30]\n" + "str q30, [x21, #0x80]\n" + "str q26, [x21, #0x90]\n" + "str q22, [x21, #0xa0]\n" + "str q18, [x21, #0xb0]\n" + "str q29, [x21, #0x100]\n" + "str q25, [x21, #0x110]\n" + "str q21, [x21, #0x120]\n" + "str q17, [x21, #0x130]\n" + "str q28, [x21, #0x180]\n" + "str q24, [x21, #0x190]\n" + "str q20, [x21, #0x1a0]\n" + "str q16, [x21, #0x1b0]\n" + "add x21, x21, #0x40\n" + "bge 4b\n" + "5:" // Main row loop: width 16 loop: skip + "cmp x24, #0x4\n" + "blt 7f\n" + "6:" // Main row loop: width 4 loop: loop + "ldr q19, [x25], #0x10\n" + "ldr q18, [x23], #0x10\n" + "sub x24, x24, #0x4\n" + "cmp x24, #0x4\n" + "ldr q17, [x22], #0x10\n" + "ldr q16, [x20], #0x10\n" + "str q19, [x21, #0x0]\n" + "str q18, [x21, #0x80]\n" + "str q17, [x21, #0x100]\n" + "str q16, [x21, #0x180]\n" + "add x21, x21, #0x10\n" + "bge 6b\n" + "7:" // Main row loop: width 4 loop: skip + "cmp x24, #0x1\n" + "blt 9f\n" + "8:" // Main row loop: width 1 loop: loop + "ldr s19, [x25], #0x4\n" + "ldr s18, [x23], #0x4\n" + "sub x24, x24, #0x1\n" + "cmp x24, #0x1\n" + "ldr s17, [x22], #0x4\n" + "ldr s16, [x20], #0x4\n" + "str s19, [x21, #0x0]\n" + "str s18, [x21, #0x80]\n" + "str s17, [x21, #0x100]\n" + "str s16, [x21, #0x180]\n" + "add x21, x21, #0x4\n" + "bge 8b\n" + "9:" // Main row loop: width 1 loop: skip + "cmp %x[height], #0x4\n" + "add %x[out], %x[out], #0x200\n" + "bge 1b\n" + "cbz %x[height], 20f\n" + "10:" // Main loop skip + "11:" // Tail row loop: Head + "mov x20, %x[width]\n" + "mov x25, %x[in]\n" + "cmp x20, #0x20\n" + "add %x[in], x25, %x[in_stride]\n" + "mov x21, %x[out]\n" + "sub %x[height], %x[height], #0x1\n" + "blt 13f\n" + "12:" // Tail row loop: Column loop + "ldr q23, [x25], #0x10\n" + "ldr q22, [x25], #0x10\n" + "sub x20, x20, #0x20\n" + "cmp x20, #0x20\n" + "ldr q21, [x25], #0x10\n" + "ldr q20, [x25], #0x10\n" + "ldr q19, [x25], #0x10\n" + "ldr q18, [x25], #0x10\n" + "ldr q17, [x25], #0x10\n" + "ldr q16, [x25], #0x10\n" + "str q23, [x21, #0x0]\n" + "str q22, [x21, #0x10]\n" + "str q21, [x21, #0x20]\n" + "str q20, [x21, #0x30]\n" + "str q19, [x21, #0x40]\n" + "str q18, [x21, #0x50]\n" + "str q17, [x21, #0x60]\n" + "str q16, [x21, #0x70]\n" + "add x21, x21, %x[out_stride]\n" + "bge 12b\n" + "13:" // Tail row loop: Column loop skip + "cmp x20, #0x10\n" + "blt 15f\n" + "14:" // Tail row loop: width 16 loop: loop + "ldr q19, [x25], #0x10\n" + "ldr q18, [x25], #0x10\n" + "sub x20, x20, #0x10\n" + "cmp x20, #0x10\n" + "ldr q17, [x25], #0x10\n" + "ldr q16, [x25], #0x10\n" + "str q19, [x21, #0x0]\n" + "str q18, [x21, #0x10]\n" + "str q17, [x21, #0x20]\n" + "str q16, [x21, #0x30]\n" + "add x21, x21, #0x40\n" + "bge 14b\n" + "15:" // Tail row loop: width 16 loop: skip + "cmp x20, #0x4\n" + "blt 17f\n" + "16:" // Tail row loop: width 4 loop: loop + "ldr q16, [x25], #0x10\n" + "sub x20, x20, #0x4\n" + "cmp x20, #0x4\n" + "str q16, [x21, #0x0]\n" + "add x21, x21, #0x10\n" + "bge 16b\n" + "17:" // Tail row loop: width 4 loop: skip + "cmp x20, #0x1\n" + "blt 19f\n" + "18:" // Tail row loop: width 1 loop: loop + "ldr s16, [x25], #0x4\n" + "sub x20, x20, #0x1\n" + "cmp x20, #0x1\n" + "str s16, [x21, #0x0]\n" + "add x21, x21, #0x4\n" + "bge 18b\n" + "19:" // Tail row loop: width 1 loop: skip + "cmp %x[height], #0x1\n" + "add %x[out], %x[out], #0x80\n" + "bge 11b\n" + "20:" // Done + : [height] "+&r" (height), [in] "+&r" (in), [out] "+&r" (out) + : [in_stride] "r" (in_stride), [out_stride] "r" (out_stride), [width] "r" (width) + : "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "x20", "x21", "x22", "x23", "x24", "x25" + ); +} + +} // anonymous namespace + +template<> +void Transform<32, 1, true, VLType::None>( + float *out, const float *in, int stride, int x0, int xmax, int k0, int kmax) +{ + a64_transpose_interleave_128( + reinterpret_cast<uint32_t *>(out), + reinterpret_cast<const uint32_t *>(in + k0 * stride + x0), + (xmax-x0) * sizeof(float) / 4, + stride * sizeof(float), + (kmax-k0) + ); +} + + +#endif // defined(__aarch64__) diff --git a/src/core/NEON/kernels/arm_gemm/transforms/a64_transpose_interleave_12_1x4.hpp b/src/core/NEON/kernels/arm_gemm/transforms/a64_transpose_interleave_12_1x4.hpp new file mode 100644 index 0000000000..cdf1f98608 --- /dev/null +++ b/src/core/NEON/kernels/arm_gemm/transforms/a64_transpose_interleave_12_1x4.hpp @@ -0,0 +1,431 @@ +/* + * Copyright (c) 2021, 2023 Arm Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ + +#pragma once + +#if defined(__aarch64__) + +namespace { + +void a64_transpose_interleave_12_1x4(uint8_t *out, const uint8_t *in, size_t width, size_t in_stride, size_t height) +{ + uint8_t *pad_row = reinterpret_cast<uint8_t *>(alloca(width * sizeof(uint8_t))); + + if (height % 4) { + memset(pad_row, 0, width * sizeof(uint8_t)); + } + + size_t out_stride = 12 * roundup<size_t>(height, 4) * sizeof(uint8_t); + + __asm__ __volatile__( + "cmp %x[height], #0x8\n" + "blt 10f\n" + "1:" // Main row loop: Head + "mov x9, %x[in]\n" + "add x28, x9, %x[in_stride]\n" + "add x27, x28, %x[in_stride]\n" + "add x26, x27, %x[in_stride]\n" + "add x25, x26, %x[in_stride]\n" + "mov x24, %x[width]\n" + "add x23, x25, %x[in_stride]\n" + "add x22, x23, %x[in_stride]\n" + "add x20, x22, %x[in_stride]\n" + "cmp x24, #0x30\n" + "add %x[in], x20, %x[in_stride]\n" + "mov x21, %x[out]\n" + "sub %x[height], %x[height], #0x8\n" + "blt 3f\n" + "2:" // Main row loop: Unroll column loop + "ldr q21, [x9], #0x10\n" + "ldr q20, [x28], #0x10\n" + "sub x24, x24, #0x30\n" + "cmp x24, #0x30\n" + "ldr q17, [x27], #0x10\n" + "ldr q16, [x26], #0x10\n" + "zip1 v31.16b, v21.16b, v17.16b\n" + "zip1 v22.16b, v20.16b, v16.16b\n" + "ldr q19, [x25], #0x10\n" + "ldr q18, [x23], #0x10\n" + "zip2 v14.16b, v21.16b, v17.16b\n" + "zip2 v13.16b, v20.16b, v16.16b\n" + "ldr q17, [x22], #0x10\n" + "ldr q16, [x20], #0x10\n" + "zip1 v30.16b, v19.16b, v17.16b\n" + "zip1 v29.16b, v18.16b, v16.16b\n" + "ldr q21, [x9], #0x10\n" + "ldr q20, [x28], #0x10\n" + "zip2 v12.16b, v19.16b, v17.16b\n" + "zip2 v11.16b, v18.16b, v16.16b\n" + "ldr q17, [x27], #0x10\n" + "ldr q16, [x26], #0x10\n" + "zip1 v10.16b, v21.16b, v17.16b\n" + "zip1 v9.16b, v20.16b, v16.16b\n" + "ldr q19, [x25], #0x10\n" + "ldr q18, [x23], #0x10\n" + "zip2 v8.16b, v21.16b, v17.16b\n" + "zip2 v7.16b, v20.16b, v16.16b\n" + "ldr q17, [x22], #0x10\n" + "ldr q16, [x20], #0x10\n" + "zip1 v6.16b, v19.16b, v17.16b\n" + "zip1 v5.16b, v18.16b, v16.16b\n" + "ldr q28, [x9], #0x10\n" + "ldr q27, [x28], #0x10\n" + "zip2 v4.16b, v19.16b, v17.16b\n" + "zip2 v3.16b, v18.16b, v16.16b\n" + "ldr q26, [x27], #0x10\n" + "ldr q25, [x26], #0x10\n" + "zip1 v2.16b, v28.16b, v26.16b\n" + "zip1 v1.16b, v27.16b, v25.16b\n" + "ldr q24, [x25], #0x10\n" + "ldr q23, [x23], #0x10\n" + "zip1 v16.16b, v31.16b, v22.16b\n" + "zip2 v22.16b, v31.16b, v22.16b\n" + "ldr q21, [x22], #0x10\n" + "ldr q20, [x20], #0x10\n" + "zip1 v0.16b, v24.16b, v21.16b\n" + "zip1 v31.16b, v23.16b, v20.16b\n" + "zip1 v19.16b, v14.16b, v13.16b\n" + "zip1 v18.16b, v30.16b, v29.16b\n" + "str q16, [x21, #0x0]\n" + "zip2 v16.16b, v30.16b, v29.16b\n" + "zip1 v17.16b, v12.16b, v11.16b\n" + "str q22, [x21, #0x10]\n" + "str q19, [x21, #0x20]\n" + "zip2 v30.16b, v28.16b, v26.16b\n" + "zip2 v29.16b, v27.16b, v25.16b\n" + "str q18, [x21, #0x30]\n" + "zip2 v28.16b, v24.16b, v21.16b\n" + "zip2 v27.16b, v23.16b, v20.16b\n" + "str q16, [x21, #0x40]\n" + "zip2 v21.16b, v14.16b, v13.16b\n" + "zip1 v16.16b, v10.16b, v9.16b\n" + "str q17, [x21, #0x50]\n" + "add x21, x21, %x[out_stride]\n" + "zip2 v20.16b, v10.16b, v9.16b\n" + "zip2 v19.16b, v12.16b, v11.16b\n" + "zip1 v18.16b, v6.16b, v5.16b\n" + "zip2 v17.16b, v6.16b, v5.16b\n" + "str q21, [x21, #0x0]\n" + "str q16, [x21, #0x10]\n" + "zip1 v16.16b, v8.16b, v7.16b\n" + "zip2 v26.16b, v8.16b, v7.16b\n" + "str q20, [x21, #0x20]\n" + "zip1 v25.16b, v2.16b, v1.16b\n" + "zip1 v24.16b, v4.16b, v3.16b\n" + "str q19, [x21, #0x30]\n" + "zip2 v23.16b, v4.16b, v3.16b\n" + "zip1 v22.16b, v0.16b, v31.16b\n" + "str q18, [x21, #0x40]\n" + "zip2 v21.16b, v2.16b, v1.16b\n" + "zip1 v20.16b, v30.16b, v29.16b\n" + "str q17, [x21, #0x50]\n" + "add x21, x21, %x[out_stride]\n" + "zip2 v19.16b, v30.16b, v29.16b\n" + "zip2 v18.16b, v0.16b, v31.16b\n" + "str q16, [x21, #0x0]\n" + "zip1 v17.16b, v28.16b, v27.16b\n" + "zip2 v16.16b, v28.16b, v27.16b\n" + "str q26, [x21, #0x10]\n" + "str q25, [x21, #0x20]\n" + "str q24, [x21, #0x30]\n" + "str q23, [x21, #0x40]\n" + "str q22, [x21, #0x50]\n" + "add x21, x21, %x[out_stride]\n" + "str q21, [x21, #0x0]\n" + "str q20, [x21, #0x10]\n" + "str q19, [x21, #0x20]\n" + "str q18, [x21, #0x30]\n" + "str q17, [x21, #0x40]\n" + "str q16, [x21, #0x50]\n" + "add x21, x21, %x[out_stride]\n" + "bge 2b\n" + "3:" // Main row loop: Unroll column loop skip + "cmp x24, #0xc\n" + "blt 5f\n" + "4:" // Main row loop: Column loop + "ldr d23, [x9], #0x8\n" + "ldr d22, [x28], #0x8\n" + "sub x24, x24, #0xc\n" + "cmp x24, #0xc\n" + "ldr d19, [x27], #0x8\n" + "ldr d18, [x26], #0x8\n" + "ldr d21, [x25], #0x8\n" + "ldr d25, [x23], #0x8\n" + "ldr d20, [x22], #0x8\n" + "ldr d17, [x20], #0x8\n" + "ld1 { v23.s }[2], [x9], #0x4\n" + "ld1 { v22.s }[2], [x28], #0x4\n" + "ld1 { v19.s }[2], [x27], #0x4\n" + "ld1 { v18.s }[2], [x26], #0x4\n" + "zip1 v24.16b, v23.16b, v19.16b\n" + "zip1 v16.16b, v22.16b, v18.16b\n" + "ld1 { v21.s }[2], [x25], #0x4\n" + "ld1 { v25.s }[2], [x23], #0x4\n" + "zip2 v19.16b, v23.16b, v19.16b\n" + "zip2 v18.16b, v22.16b, v18.16b\n" + "ld1 { v20.s }[2], [x22], #0x4\n" + "ld1 { v17.s }[2], [x20], #0x4\n" + "zip1 v23.16b, v21.16b, v20.16b\n" + "zip1 v22.16b, v25.16b, v17.16b\n" + "zip2 v21.16b, v21.16b, v20.16b\n" + "zip2 v20.16b, v25.16b, v17.16b\n" + "zip1 v17.16b, v24.16b, v16.16b\n" + "zip2 v16.16b, v24.16b, v16.16b\n" + "str q17, [x21, #0x0]\n" + "zip1 v19.16b, v19.16b, v18.16b\n" + "zip1 v18.16b, v23.16b, v22.16b\n" + "str q16, [x21, #0x10]\n" + "zip2 v17.16b, v23.16b, v22.16b\n" + "zip1 v16.16b, v21.16b, v20.16b\n" + "str q19, [x21, #0x20]\n" + "str q18, [x21, #0x30]\n" + "str q17, [x21, #0x40]\n" + "str q16, [x21, #0x50]\n" + "add x21, x21, %x[out_stride]\n" + "bge 4b\n" + "5:" // Main row loop: Column loop skip + "cmp x24, #0x4\n" + "blt 7f\n" + "6:" // Main row loop: width 4 loop: loop + "ldr s19, [x9], #0x4\n" + "ldr s18, [x28], #0x4\n" + "sub x24, x24, #0x4\n" + "cmp x24, #0x4\n" + "ldr s17, [x27], #0x4\n" + "ldr s16, [x26], #0x4\n" + "zip1 v17.16b, v19.16b, v17.16b\n" + "zip1 v16.16b, v18.16b, v16.16b\n" + "ldr s20, [x25], #0x4\n" + "ldr s19, [x23], #0x4\n" + "zip1 v18.16b, v17.16b, v16.16b\n" + "ldr s17, [x22], #0x4\n" + "ldr s16, [x20], #0x4\n" + "zip1 v17.16b, v20.16b, v17.16b\n" + "zip1 v16.16b, v19.16b, v16.16b\n" + "str q18, [x21, #0x0]\n" + "zip1 v16.16b, v17.16b, v16.16b\n" + "str q16, [x21, #0x30]\n" + "add x21, x21, #0x10\n" + "bge 6b\n" + "7:" // Main row loop: width 4 loop: skip + "cmp x24, #0x1\n" + "blt 9f\n" + "8:" // Main row loop: width 1 loop: loop + "ldr b19, [x9], #0x1\n" + "ldr b18, [x28], #0x1\n" + "sub x24, x24, #0x1\n" + "cmp x24, #0x1\n" + "ldr b17, [x27], #0x1\n" + "ldr b16, [x26], #0x1\n" + "zip1 v17.16b, v19.16b, v17.16b\n" + "zip1 v16.16b, v18.16b, v16.16b\n" + "ldr b20, [x25], #0x1\n" + "ldr b19, [x23], #0x1\n" + "zip1 v18.16b, v17.16b, v16.16b\n" + "ldr b17, [x22], #0x1\n" + "ldr b16, [x20], #0x1\n" + "zip1 v17.16b, v20.16b, v17.16b\n" + "zip1 v16.16b, v19.16b, v16.16b\n" + "str s18, [x21, #0x0]\n" + "zip1 v16.16b, v17.16b, v16.16b\n" + "str s16, [x21, #0x30]\n" + "add x21, x21, #0x4\n" + "bge 8b\n" + "9:" // Main row loop: width 1 loop: skip + "cmp %x[height], #0x8\n" + "add %x[out], %x[out], #0x60\n" + "bge 1b\n" + "cbz %x[height], 20f\n" + "10:" // Main loop skip + "11:" // Tail row loop: Head + "mov x9, %x[in]\n" + "add x28, x9, %x[in_stride]\n" + "add x27, x28, %x[in_stride]\n" + "mov x20, %x[width]\n" + "add x26, x27, %x[in_stride]\n" + "cmp %x[height], #0x3\n" + "add %x[in], x26, %x[in_stride]\n" + "csel x26, x26, %x[pad_row], GT\n" + "csel x27, x27, %x[pad_row], GE\n" + "cmp %x[height], #0x1\n" + "csel x28, x28, %x[pad_row], GT\n" + "cmp x20, #0x30\n" + "mov x21, %x[out]\n" + "sub %x[height], %x[height], #0x4\n" + "blt 13f\n" + "12:" // Tail row loop: Unroll column loop + "ldr q21, [x9], #0x10\n" + "ldr q20, [x28], #0x10\n" + "sub x20, x20, #0x30\n" + "cmp x20, #0x30\n" + "ldr q17, [x27], #0x10\n" + "ldr q16, [x26], #0x10\n" + "zip1 v31.16b, v21.16b, v17.16b\n" + "zip1 v30.16b, v20.16b, v16.16b\n" + "ldr q19, [x9], #0x10\n" + "ldr q18, [x28], #0x10\n" + "zip2 v29.16b, v21.16b, v17.16b\n" + "zip2 v28.16b, v20.16b, v16.16b\n" + "ldr q17, [x27], #0x10\n" + "ldr q16, [x26], #0x10\n" + "zip1 v27.16b, v19.16b, v17.16b\n" + "zip1 v26.16b, v18.16b, v16.16b\n" + "ldr q22, [x9], #0x10\n" + "ldr q21, [x28], #0x10\n" + "zip2 v25.16b, v19.16b, v17.16b\n" + "zip2 v20.16b, v18.16b, v16.16b\n" + "ldr q19, [x27], #0x10\n" + "ldr q18, [x26], #0x10\n" + "zip1 v24.16b, v22.16b, v19.16b\n" + "zip1 v23.16b, v21.16b, v18.16b\n" + "zip1 v16.16b, v31.16b, v30.16b\n" + "zip2 v17.16b, v31.16b, v30.16b\n" + "str q16, [x21, #0x0]\n" + "zip1 v16.16b, v29.16b, v28.16b\n" + "str q17, [x21, #0x10]\n" + "zip2 v22.16b, v22.16b, v19.16b\n" + "str q16, [x21, #0x20]\n" + "add x21, x21, %x[out_stride]\n" + "zip2 v21.16b, v21.16b, v18.16b\n" + "zip2 v18.16b, v29.16b, v28.16b\n" + "zip1 v16.16b, v27.16b, v26.16b\n" + "zip2 v17.16b, v27.16b, v26.16b\n" + "str q18, [x21, #0x0]\n" + "str q16, [x21, #0x10]\n" + "zip1 v16.16b, v25.16b, v20.16b\n" + "zip2 v20.16b, v25.16b, v20.16b\n" + "str q17, [x21, #0x20]\n" + "add x21, x21, %x[out_stride]\n" + "zip1 v19.16b, v24.16b, v23.16b\n" + "zip2 v18.16b, v24.16b, v23.16b\n" + "str q16, [x21, #0x0]\n" + "zip1 v17.16b, v22.16b, v21.16b\n" + "zip2 v16.16b, v22.16b, v21.16b\n" + "str q20, [x21, #0x10]\n" + "str q19, [x21, #0x20]\n" + "add x21, x21, %x[out_stride]\n" + "str q18, [x21, #0x0]\n" + "str q17, [x21, #0x10]\n" + "str q16, [x21, #0x20]\n" + "add x21, x21, %x[out_stride]\n" + "bge 12b\n" + "13:" // Tail row loop: Unroll column loop skip + "cmp x20, #0xc\n" + "blt 15f\n" + "14:" // Tail row loop: Column loop + "ldr d19, [x9], #0x8\n" + "ldr d21, [x28], #0x8\n" + "sub x20, x20, #0xc\n" + "cmp x20, #0xc\n" + "ldr d18, [x27], #0x8\n" + "ldr d16, [x26], #0x8\n" + "ld1 { v19.s }[2], [x9], #0x4\n" + "ld1 { v21.s }[2], [x28], #0x4\n" + "ld1 { v18.s }[2], [x27], #0x4\n" + "ld1 { v16.s }[2], [x26], #0x4\n" + "zip1 v20.16b, v19.16b, v18.16b\n" + "zip1 v17.16b, v21.16b, v16.16b\n" + "zip2 v19.16b, v19.16b, v18.16b\n" + "zip2 v18.16b, v21.16b, v16.16b\n" + "zip1 v16.16b, v20.16b, v17.16b\n" + "zip2 v17.16b, v20.16b, v17.16b\n" + "str q16, [x21, #0x0]\n" + "zip1 v16.16b, v19.16b, v18.16b\n" + "str q17, [x21, #0x10]\n" + "str q16, [x21, #0x20]\n" + "add x21, x21, %x[out_stride]\n" + "bge 14b\n" + "15:" // Tail row loop: Column loop skip + "cmp x20, #0x4\n" + "blt 17f\n" + "16:" // Tail row loop: width 4 loop: loop + "ldr s19, [x9], #0x4\n" + "ldr s18, [x28], #0x4\n" + "sub x20, x20, #0x4\n" + "cmp x20, #0x4\n" + "ldr s17, [x27], #0x4\n" + "ldr s16, [x26], #0x4\n" + "zip1 v17.16b, v19.16b, v17.16b\n" + "zip1 v16.16b, v18.16b, v16.16b\n" + "zip1 v16.16b, v17.16b, v16.16b\n" + "str q16, [x21, #0x0]\n" + "add x21, x21, #0x10\n" + "bge 16b\n" + "17:" // Tail row loop: width 4 loop: skip + "cmp x20, #0x1\n" + "blt 19f\n" + "18:" // Tail row loop: width 1 loop: loop + "ldr b19, [x9], #0x1\n" + "ldr b18, [x28], #0x1\n" + "sub x20, x20, #0x1\n" + "cmp x20, #0x1\n" + "ldr b17, [x27], #0x1\n" + "ldr b16, [x26], #0x1\n" + "zip1 v17.16b, v19.16b, v17.16b\n" + "zip1 v16.16b, v18.16b, v16.16b\n" + "zip1 v16.16b, v17.16b, v16.16b\n" + "str s16, [x21, #0x0]\n" + "add x21, x21, #0x4\n" + "bge 18b\n" + "19:" // Tail row loop: width 1 loop: skip + "cmp %x[height], #0x1\n" + "add %x[out], %x[out], #0x30\n" + "bge 11b\n" + "20:" // Done + : [height] "+&r" (height), [in] "+&r" (in), [out] "+&r" (out) + : [in_stride] "r" (in_stride), [out_stride] "r" (out_stride), [pad_row] "r" (pad_row), [width] "r" (width) + : "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "x9", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27", "x28" + ); +} + +} // anonymous namespace + +template<> +void Transform<12, 4, true, VLType::None>( + uint8_t *out, const uint8_t *in, int stride, int x0, int xmax, int k0, int kmax) +{ + a64_transpose_interleave_12_1x4( + reinterpret_cast<uint8_t *>(out), + reinterpret_cast<const uint8_t *>(in + k0 * stride + x0), + (xmax-x0) * sizeof(uint8_t) / 1, + stride * sizeof(uint8_t), + (kmax-k0) + ); +} + +template<> +void Transform<12, 4, true, VLType::None>( + int8_t *out, const int8_t *in, int stride, int x0, int xmax, int k0, int kmax) +{ + a64_transpose_interleave_12_1x4( + reinterpret_cast<uint8_t *>(out), + reinterpret_cast<const uint8_t *>(in + k0 * stride + x0), + (xmax-x0) * sizeof(int8_t) / 1, + stride * sizeof(int8_t), + (kmax-k0) + ); +} + + +#endif // defined(__aarch64__) diff --git a/src/core/NEON/kernels/arm_gemm/transforms/a64_transpose_interleave_12_1x8.hpp b/src/core/NEON/kernels/arm_gemm/transforms/a64_transpose_interleave_12_1x8.hpp new file mode 100644 index 0000000000..da0809d4d6 --- /dev/null +++ b/src/core/NEON/kernels/arm_gemm/transforms/a64_transpose_interleave_12_1x8.hpp @@ -0,0 +1,335 @@ +/* + * Copyright (c) 2021, 2023 Arm Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ + +#pragma once + +#if defined(__aarch64__) + +namespace { + +void a64_transpose_interleave_12_1x8(uint8_t *out, const uint8_t *in, size_t width, size_t in_stride, size_t height) +{ + uint8_t *pad_row = reinterpret_cast<uint8_t *>(alloca(width * sizeof(uint8_t))); + + if (height % 8) { + memset(pad_row, 0, width * sizeof(uint8_t)); + } + + size_t out_stride = 12 * roundup<size_t>(height, 8) * sizeof(uint8_t); + + __asm__ __volatile__( + "1:" // Main row loop: Head + "mov x9, %x[in]\n" + "add x28, x9, %x[in_stride]\n" + "add x27, x28, %x[in_stride]\n" + "add x26, x27, %x[in_stride]\n" + "add x25, x26, %x[in_stride]\n" + "add x24, x25, %x[in_stride]\n" + "add x23, x24, %x[in_stride]\n" + "add x22, x23, %x[in_stride]\n" + "cmp %x[height], #0x7\n" + "add %x[in], x22, %x[in_stride]\n" + "csel x22, x22, %x[pad_row], GT\n" + "csel x23, x23, %x[pad_row], GE\n" + "cmp %x[height], #0x5\n" + "mov x21, %x[width]\n" + "csel x24, x24, %x[pad_row], GT\n" + "csel x25, x25, %x[pad_row], GE\n" + "cmp %x[height], #0x3\n" + "csel x26, x26, %x[pad_row], GT\n" + "csel x27, x27, %x[pad_row], GE\n" + "cmp %x[height], #0x1\n" + "csel x28, x28, %x[pad_row], GT\n" + "cmp x21, #0x30\n" + "mov x20, %x[out]\n" + "sub %x[height], %x[height], #0x8\n" + "blt 3f\n" + "2:" // Main row loop: Unroll column loop + "ldr q21, [x9], #0x10\n" + "ldr q25, [x28], #0x10\n" + "sub x21, x21, #0x30\n" + "cmp x21, #0x30\n" + "ldr q20, [x27], #0x10\n" + "ldr q24, [x26], #0x10\n" + "ldr q19, [x25], #0x10\n" + "ldr q18, [x24], #0x10\n" + "zip1 v7.16b, v21.16b, v19.16b\n" + "zip1 v6.16b, v25.16b, v18.16b\n" + "ldr q17, [x23], #0x10\n" + "ldr q16, [x22], #0x10\n" + "zip1 v28.16b, v20.16b, v17.16b\n" + "zip1 v27.16b, v24.16b, v16.16b\n" + "ldr q23, [x9], #0x10\n" + "ldr q22, [x28], #0x10\n" + "zip2 v5.16b, v21.16b, v19.16b\n" + "zip2 v4.16b, v20.16b, v17.16b\n" + "ldr q21, [x27], #0x10\n" + "ldr q20, [x26], #0x10\n" + "zip2 v3.16b, v25.16b, v18.16b\n" + "zip2 v2.16b, v24.16b, v16.16b\n" + "ldr q19, [x25], #0x10\n" + "ldr q18, [x24], #0x10\n" + "zip1 v1.16b, v23.16b, v19.16b\n" + "zip1 v15.16b, v22.16b, v18.16b\n" + "ldr q17, [x23], #0x10\n" + "ldr q16, [x22], #0x10\n" + "zip1 v0.16b, v21.16b, v17.16b\n" + "zip1 v31.16b, v20.16b, v16.16b\n" + "ldr q26, [x9], #0x10\n" + "ldr q30, [x28], #0x10\n" + "zip2 v14.16b, v23.16b, v19.16b\n" + "zip2 v13.16b, v21.16b, v17.16b\n" + "ldr q25, [x27], #0x10\n" + "ldr q24, [x26], #0x10\n" + "zip2 v12.16b, v22.16b, v18.16b\n" + "zip2 v11.16b, v20.16b, v16.16b\n" + "ldr q23, [x25], #0x10\n" + "ldr q22, [x24], #0x10\n" + "zip1 v10.16b, v26.16b, v23.16b\n" + "zip1 v9.16b, v30.16b, v22.16b\n" + "ldr q21, [x23], #0x10\n" + "ldr q17, [x22], #0x10\n" + "zip1 v29.16b, v25.16b, v21.16b\n" + "zip1 v8.16b, v24.16b, v17.16b\n" + "zip1 v19.16b, v7.16b, v28.16b\n" + "zip1 v16.16b, v6.16b, v27.16b\n" + "zip2 v28.16b, v7.16b, v28.16b\n" + "zip2 v18.16b, v6.16b, v27.16b\n" + "zip1 v27.16b, v5.16b, v4.16b\n" + "zip1 v20.16b, v3.16b, v2.16b\n" + "zip2 v7.16b, v26.16b, v23.16b\n" + "zip2 v26.16b, v25.16b, v21.16b\n" + "zip2 v6.16b, v30.16b, v22.16b\n" + "zip2 v25.16b, v24.16b, v17.16b\n" + "zip2 v5.16b, v5.16b, v4.16b\n" + "zip2 v4.16b, v3.16b, v2.16b\n" + "zip1 v3.16b, v1.16b, v0.16b\n" + "zip1 v2.16b, v15.16b, v31.16b\n" + "zip2 v1.16b, v1.16b, v0.16b\n" + "zip2 v0.16b, v15.16b, v31.16b\n" + "zip1 v31.16b, v14.16b, v13.16b\n" + "zip1 v30.16b, v12.16b, v11.16b\n" + "zip2 v24.16b, v14.16b, v13.16b\n" + "zip2 v23.16b, v12.16b, v11.16b\n" + "zip1 v22.16b, v10.16b, v29.16b\n" + "zip1 v21.16b, v9.16b, v8.16b\n" + "zip1 v17.16b, v19.16b, v16.16b\n" + "zip2 v16.16b, v19.16b, v16.16b\n" + "str q17, [x20, #0x0]\n" + "zip1 v19.16b, v28.16b, v18.16b\n" + "zip2 v18.16b, v28.16b, v18.16b\n" + "str q16, [x20, #0x10]\n" + "zip1 v17.16b, v27.16b, v20.16b\n" + "zip2 v16.16b, v27.16b, v20.16b\n" + "str q19, [x20, #0x20]\n" + "str q18, [x20, #0x30]\n" + "zip2 v29.16b, v10.16b, v29.16b\n" + "zip2 v20.16b, v9.16b, v8.16b\n" + "str q17, [x20, #0x40]\n" + "zip1 v28.16b, v7.16b, v26.16b\n" + "zip1 v27.16b, v6.16b, v25.16b\n" + "str q16, [x20, #0x50]\n" + "add x20, x20, %x[out_stride]\n" + "zip2 v26.16b, v7.16b, v26.16b\n" + "zip2 v25.16b, v6.16b, v25.16b\n" + "zip1 v17.16b, v5.16b, v4.16b\n" + "zip2 v16.16b, v5.16b, v4.16b\n" + "str q17, [x20, #0x0]\n" + "zip1 v18.16b, v3.16b, v2.16b\n" + "zip2 v17.16b, v3.16b, v2.16b\n" + "str q16, [x20, #0x10]\n" + "zip1 v16.16b, v1.16b, v0.16b\n" + "zip2 v19.16b, v1.16b, v0.16b\n" + "str q18, [x20, #0x20]\n" + "str q17, [x20, #0x30]\n" + "zip1 v18.16b, v31.16b, v30.16b\n" + "zip2 v17.16b, v31.16b, v30.16b\n" + "str q16, [x20, #0x40]\n" + "zip1 v16.16b, v24.16b, v23.16b\n" + "zip2 v24.16b, v24.16b, v23.16b\n" + "str q19, [x20, #0x50]\n" + "add x20, x20, %x[out_stride]\n" + "zip1 v23.16b, v22.16b, v21.16b\n" + "zip2 v22.16b, v22.16b, v21.16b\n" + "str q18, [x20, #0x0]\n" + "zip1 v21.16b, v29.16b, v20.16b\n" + "zip2 v20.16b, v29.16b, v20.16b\n" + "str q17, [x20, #0x10]\n" + "zip1 v19.16b, v28.16b, v27.16b\n" + "zip2 v18.16b, v28.16b, v27.16b\n" + "str q16, [x20, #0x20]\n" + "zip1 v17.16b, v26.16b, v25.16b\n" + "zip2 v16.16b, v26.16b, v25.16b\n" + "str q24, [x20, #0x30]\n" + "str q23, [x20, #0x40]\n" + "str q22, [x20, #0x50]\n" + "add x20, x20, %x[out_stride]\n" + "str q21, [x20, #0x0]\n" + "str q20, [x20, #0x10]\n" + "str q19, [x20, #0x20]\n" + "str q18, [x20, #0x30]\n" + "str q17, [x20, #0x40]\n" + "str q16, [x20, #0x50]\n" + "add x20, x20, %x[out_stride]\n" + "bge 2b\n" + "3:" // Main row loop: Unroll column loop skip + "cmp x21, #0xc\n" + "blt 5f\n" + "4:" // Main row loop: Column loop + "ldr d23, [x9], #0x8\n" + "ldr d27, [x28], #0x8\n" + "sub x21, x21, #0xc\n" + "cmp x21, #0xc\n" + "ldr d21, [x27], #0x8\n" + "ldr d26, [x26], #0x8\n" + "ldr d20, [x25], #0x8\n" + "ldr d19, [x24], #0x8\n" + "ldr d17, [x23], #0x8\n" + "ldr d16, [x22], #0x8\n" + "ld1 { v23.s }[2], [x9], #0x4\n" + "ld1 { v27.s }[2], [x28], #0x4\n" + "ld1 { v21.s }[2], [x27], #0x4\n" + "ld1 { v26.s }[2], [x26], #0x4\n" + "ld1 { v20.s }[2], [x25], #0x4\n" + "ld1 { v19.s }[2], [x24], #0x4\n" + "zip1 v25.16b, v23.16b, v20.16b\n" + "zip1 v24.16b, v27.16b, v19.16b\n" + "ld1 { v17.s }[2], [x23], #0x4\n" + "ld1 { v16.s }[2], [x22], #0x4\n" + "zip1 v22.16b, v21.16b, v17.16b\n" + "zip1 v18.16b, v26.16b, v16.16b\n" + "zip2 v23.16b, v23.16b, v20.16b\n" + "zip2 v21.16b, v21.16b, v17.16b\n" + "zip2 v20.16b, v27.16b, v19.16b\n" + "zip2 v17.16b, v26.16b, v16.16b\n" + "zip1 v19.16b, v25.16b, v22.16b\n" + "zip1 v16.16b, v24.16b, v18.16b\n" + "zip2 v22.16b, v25.16b, v22.16b\n" + "zip2 v18.16b, v24.16b, v18.16b\n" + "zip1 v21.16b, v23.16b, v21.16b\n" + "zip1 v20.16b, v20.16b, v17.16b\n" + "zip1 v17.16b, v19.16b, v16.16b\n" + "zip2 v16.16b, v19.16b, v16.16b\n" + "str q17, [x20, #0x0]\n" + "zip1 v19.16b, v22.16b, v18.16b\n" + "zip2 v18.16b, v22.16b, v18.16b\n" + "str q16, [x20, #0x10]\n" + "zip1 v17.16b, v21.16b, v20.16b\n" + "zip2 v16.16b, v21.16b, v20.16b\n" + "str q19, [x20, #0x20]\n" + "str q18, [x20, #0x30]\n" + "str q17, [x20, #0x40]\n" + "str q16, [x20, #0x50]\n" + "add x20, x20, %x[out_stride]\n" + "bge 4b\n" + "5:" // Main row loop: Column loop skip + "cmp x21, #0x4\n" + "blt 7f\n" + "6:" // Main row loop: width 4 loop: loop + "ldr s18, [x9], #0x4\n" + "ldr s19, [x28], #0x4\n" + "sub x21, x21, #0x4\n" + "cmp x21, #0x4\n" + "ldr s21, [x27], #0x4\n" + "ldr s20, [x26], #0x4\n" + "ldr s17, [x25], #0x4\n" + "ldr s16, [x24], #0x4\n" + "zip1 v18.16b, v18.16b, v17.16b\n" + "zip1 v19.16b, v19.16b, v16.16b\n" + "ldr s17, [x23], #0x4\n" + "ldr s16, [x22], #0x4\n" + "zip1 v17.16b, v21.16b, v17.16b\n" + "zip1 v16.16b, v20.16b, v16.16b\n" + "zip1 v18.16b, v18.16b, v17.16b\n" + "zip1 v16.16b, v19.16b, v16.16b\n" + "zip1 v17.16b, v18.16b, v16.16b\n" + "zip2 v16.16b, v18.16b, v16.16b\n" + "str q17, [x20, #0x0]\n" + "str q16, [x20, #0x10]\n" + "add x20, x20, #0x20\n" + "bge 6b\n" + "7:" // Main row loop: width 4 loop: skip + "cmp x21, #0x1\n" + "blt 9f\n" + "8:" // Main row loop: width 1 loop: loop + "ldr b19, [x9], #0x1\n" + "ldr b18, [x28], #0x1\n" + "sub x21, x21, #0x1\n" + "cmp x21, #0x1\n" + "ldr b21, [x27], #0x1\n" + "ldr b20, [x26], #0x1\n" + "ldr b17, [x25], #0x1\n" + "ldr b16, [x24], #0x1\n" + "zip1 v19.16b, v19.16b, v17.16b\n" + "zip1 v18.16b, v18.16b, v16.16b\n" + "ldr b17, [x23], #0x1\n" + "ldr b16, [x22], #0x1\n" + "zip1 v17.16b, v21.16b, v17.16b\n" + "zip1 v16.16b, v20.16b, v16.16b\n" + "zip1 v17.16b, v19.16b, v17.16b\n" + "zip1 v16.16b, v18.16b, v16.16b\n" + "zip1 v16.16b, v17.16b, v16.16b\n" + "str d16, [x20, #0x0]\n" + "add x20, x20, #0x8\n" + "bge 8b\n" + "9:" // Main row loop: width 1 loop: skip + "cmp %x[height], #0x1\n" + "add %x[out], %x[out], #0x60\n" + "bge 1b\n" + : [height] "+&r" (height), [in] "+&r" (in), [out] "+&r" (out) + : [in_stride] "r" (in_stride), [out_stride] "r" (out_stride), [pad_row] "r" (pad_row), [width] "r" (width) + : "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "x9", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27", "x28" + ); +} + +} // anonymous namespace + +template<> +void Transform<12, 8, true, VLType::None>( + uint8_t *out, const uint8_t *in, int stride, int x0, int xmax, int k0, int kmax) +{ + a64_transpose_interleave_12_1x8( + reinterpret_cast<uint8_t *>(out), + reinterpret_cast<const uint8_t *>(in + k0 * stride + x0), + (xmax-x0) * sizeof(uint8_t) / 1, + stride * sizeof(uint8_t), + (kmax-k0) + ); +} + +template<> +void Transform<12, 8, true, VLType::None>( + int8_t *out, const int8_t *in, int stride, int x0, int xmax, int k0, int kmax) +{ + a64_transpose_interleave_12_1x8( + reinterpret_cast<uint8_t *>(out), + reinterpret_cast<const uint8_t *>(in + k0 * stride + x0), + (xmax-x0) * sizeof(int8_t) / 1, + stride * sizeof(int8_t), + (kmax-k0) + ); +} + + +#endif // defined(__aarch64__) diff --git a/src/core/NEON/kernels/arm_gemm/transforms/a64_transpose_interleave_12_2x2.hpp b/src/core/NEON/kernels/arm_gemm/transforms/a64_transpose_interleave_12_2x2.hpp new file mode 100644 index 0000000000..cef468e9cc --- /dev/null +++ b/src/core/NEON/kernels/arm_gemm/transforms/a64_transpose_interleave_12_2x2.hpp @@ -0,0 +1,343 @@ +/* + * Copyright (c) 2021, 2023 Arm Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ + +#pragma once + +#if defined(__aarch64__) + +namespace { + +void a64_transpose_interleave_12_2x2(uint16_t *out, const uint16_t *in, size_t width, size_t in_stride, size_t height) +{ + uint16_t *pad_row = reinterpret_cast<uint16_t *>(alloca(width * sizeof(uint16_t))); + + if (height % 2) { + memset(pad_row, 0, width * sizeof(uint16_t)); + } + + size_t out_stride = 12 * roundup<size_t>(height, 2) * sizeof(uint16_t); + + __asm__ __volatile__( + "cmp %x[height], #0x8\n" + "blt 10f\n" + "1:" // Main row loop: Head + "mov x9, %x[in]\n" + "add x28, x9, %x[in_stride]\n" + "add x27, x28, %x[in_stride]\n" + "add x26, x27, %x[in_stride]\n" + "add x25, x26, %x[in_stride]\n" + "mov x24, %x[width]\n" + "add x23, x25, %x[in_stride]\n" + "add x22, x23, %x[in_stride]\n" + "add x20, x22, %x[in_stride]\n" + "cmp x24, #0x18\n" + "add %x[in], x20, %x[in_stride]\n" + "mov x21, %x[out]\n" + "sub %x[height], %x[height], #0x8\n" + "blt 3f\n" + "2:" // Main row loop: Unroll column loop + "ldr q19, [x9], #0x10\n" + "ldr q18, [x28], #0x10\n" + "sub x24, x24, #0x18\n" + "zip1 v10.8h, v19.8h, v18.8h\n" + "ldr q17, [x27], #0x10\n" + "ldr q16, [x26], #0x10\n" + "zip2 v9.8h, v19.8h, v18.8h\n" + "zip1 v8.8h, v17.8h, v16.8h\n" + "ldr q19, [x25], #0x10\n" + "ldr q18, [x23], #0x10\n" + "zip2 v7.8h, v17.8h, v16.8h\n" + "zip1 v6.8h, v19.8h, v18.8h\n" + "ldr q17, [x22], #0x10\n" + "ldr q16, [x20], #0x10\n" + "zip2 v5.8h, v19.8h, v18.8h\n" + "zip1 v4.8h, v17.8h, v16.8h\n" + "ldr q21, [x9], #0x10\n" + "ldr q18, [x28], #0x10\n" + "zip1 v3.8h, v21.8h, v18.8h\n" + "zip2 v2.8h, v17.8h, v16.8h\n" + "ldr q17, [x27], #0x10\n" + "ldr q16, [x26], #0x10\n" + "zip1 v1.8h, v17.8h, v16.8h\n" + "cmp x24, #0x18\n" + "ldr q20, [x25], #0x10\n" + "ldr q19, [x23], #0x10\n" + "zip1 v0.8h, v20.8h, v19.8h\n" + "zip2 v31.8h, v21.8h, v18.8h\n" + "ldr q30, [x22], #0x10\n" + "ldr q29, [x20], #0x10\n" + "zip1 v28.8h, v30.8h, v29.8h\n" + "zip2 v27.8h, v17.8h, v16.8h\n" + "ldr q17, [x9], #0x10\n" + "ldr q16, [x28], #0x10\n" + "zip1 v26.8h, v17.8h, v16.8h\n" + "zip2 v25.8h, v17.8h, v16.8h\n" + "ldr q17, [x27], #0x10\n" + "ldr q16, [x26], #0x10\n" + "zip1 v24.8h, v17.8h, v16.8h\n" + "zip2 v23.8h, v17.8h, v16.8h\n" + "ldr q18, [x25], #0x10\n" + "ldr q17, [x23], #0x10\n" + "zip2 v22.8h, v20.8h, v19.8h\n" + "zip1 v21.8h, v18.8h, v17.8h\n" + "ldr q20, [x22], #0x10\n" + "ldr q16, [x20], #0x10\n" + "str q10, [x21, #0x0]\n" + "zip2 v19.8h, v18.8h, v17.8h\n" + "str q9, [x21, #0x10]\n" + "zip2 v18.8h, v30.8h, v29.8h\n" + "zip1 v17.8h, v20.8h, v16.8h\n" + "str q3, [x21, #0x20]\n" + "zip2 v16.8h, v20.8h, v16.8h\n" + "str q8, [x21, #0x30]\n" + "str q7, [x21, #0x40]\n" + "str q1, [x21, #0x50]\n" + "str q6, [x21, #0x60]\n" + "str q5, [x21, #0x70]\n" + "str q0, [x21, #0x80]\n" + "str q4, [x21, #0x90]\n" + "str q2, [x21, #0xa0]\n" + "str q28, [x21, #0xb0]\n" + "add x21, x21, %x[out_stride]\n" + "str q31, [x21, #0x0]\n" + "str q26, [x21, #0x10]\n" + "str q25, [x21, #0x20]\n" + "str q27, [x21, #0x30]\n" + "str q24, [x21, #0x40]\n" + "str q23, [x21, #0x50]\n" + "str q22, [x21, #0x60]\n" + "str q21, [x21, #0x70]\n" + "str q19, [x21, #0x80]\n" + "str q18, [x21, #0x90]\n" + "str q17, [x21, #0xa0]\n" + "str q16, [x21, #0xb0]\n" + "add x21, x21, %x[out_stride]\n" + "bge 2b\n" + "3:" // Main row loop: Unroll column loop skip + "cmp x24, #0xc\n" + "blt 5f\n" + "4:" // Main row loop: Column loop + "ldr q17, [x9], #0x10\n" + "ldr q16, [x28], #0x10\n" + "sub x24, x24, #0xc\n" + "cmp x24, #0xc\n" + "ldr q19, [x27], #0x10\n" + "ldr q18, [x26], #0x10\n" + "zip1 v28.8h, v17.8h, v16.8h\n" + "zip2 v27.8h, v17.8h, v16.8h\n" + "ldr q17, [x25], #0x10\n" + "ldr q16, [x23], #0x10\n" + "zip1 v26.8h, v19.8h, v18.8h\n" + "zip2 v25.8h, v19.8h, v18.8h\n" + "ldr q19, [x22], #0x10\n" + "ldr q18, [x20], #0x10\n" + "zip1 v24.8h, v17.8h, v16.8h\n" + "zip2 v23.8h, v17.8h, v16.8h\n" + "ldr d17, [x9], #0x8\n" + "ldr d16, [x28], #0x8\n" + "zip1 v22.8h, v17.8h, v16.8h\n" + "zip1 v21.8h, v19.8h, v18.8h\n" + "ldr d17, [x27], #0x8\n" + "ldr d16, [x26], #0x8\n" + "zip1 v20.8h, v17.8h, v16.8h\n" + "zip2 v19.8h, v19.8h, v18.8h\n" + "ldr d17, [x25], #0x8\n" + "ldr d16, [x23], #0x8\n" + "zip1 v18.8h, v17.8h, v16.8h\n" + "ldr d17, [x22], #0x8\n" + "ldr d16, [x20], #0x8\n" + "zip1 v16.8h, v17.8h, v16.8h\n" + "str q28, [x21, #0x0]\n" + "str q27, [x21, #0x10]\n" + "str q22, [x21, #0x20]\n" + "str q26, [x21, #0x30]\n" + "str q25, [x21, #0x40]\n" + "str q20, [x21, #0x50]\n" + "str q24, [x21, #0x60]\n" + "str q23, [x21, #0x70]\n" + "str q18, [x21, #0x80]\n" + "str q21, [x21, #0x90]\n" + "str q19, [x21, #0xa0]\n" + "str q16, [x21, #0xb0]\n" + "add x21, x21, %x[out_stride]\n" + "bge 4b\n" + "5:" // Main row loop: Column loop skip + "cmp x24, #0x4\n" + "blt 7f\n" + "6:" // Main row loop: width 4 loop: loop + "ldr d19, [x9], #0x8\n" + "ldr d18, [x28], #0x8\n" + "sub x24, x24, #0x4\n" + "cmp x24, #0x4\n" + "ldr d17, [x27], #0x8\n" + "ldr d16, [x26], #0x8\n" + "zip1 v20.8h, v19.8h, v18.8h\n" + "zip1 v19.8h, v17.8h, v16.8h\n" + "ldr d17, [x25], #0x8\n" + "ldr d16, [x23], #0x8\n" + "zip1 v18.8h, v17.8h, v16.8h\n" + "ldr d17, [x22], #0x8\n" + "ldr d16, [x20], #0x8\n" + "str q20, [x21, #0x0]\n" + "zip1 v16.8h, v17.8h, v16.8h\n" + "str q19, [x21, #0x30]\n" + "str q18, [x21, #0x60]\n" + "str q16, [x21, #0x90]\n" + "add x21, x21, #0x10\n" + "bge 6b\n" + "7:" // Main row loop: width 4 loop: skip + "cmp x24, #0x1\n" + "blt 9f\n" + "8:" // Main row loop: width 1 loop: loop + "ldr h19, [x9], #0x2\n" + "ldr h18, [x28], #0x2\n" + "sub x24, x24, #0x1\n" + "cmp x24, #0x1\n" + "ldr h17, [x27], #0x2\n" + "ldr h16, [x26], #0x2\n" + "zip1 v20.8h, v19.8h, v18.8h\n" + "zip1 v19.8h, v17.8h, v16.8h\n" + "ldr h17, [x25], #0x2\n" + "ldr h16, [x23], #0x2\n" + "zip1 v18.8h, v17.8h, v16.8h\n" + "ldr h17, [x22], #0x2\n" + "ldr h16, [x20], #0x2\n" + "str s20, [x21, #0x0]\n" + "zip1 v16.8h, v17.8h, v16.8h\n" + "str s19, [x21, #0x30]\n" + "str s18, [x21, #0x60]\n" + "str s16, [x21, #0x90]\n" + "add x21, x21, #0x4\n" + "bge 8b\n" + "9:" // Main row loop: width 1 loop: skip + "cmp %x[height], #0x8\n" + "add %x[out], %x[out], #0xc0\n" + "bge 1b\n" + "cbz %x[height], 20f\n" + "10:" // Main loop skip + "11:" // Tail row loop: Head + "mov x9, %x[in]\n" + "mov x20, %x[width]\n" + "add x28, x9, %x[in_stride]\n" + "cmp %x[height], #0x1\n" + "add %x[in], x28, %x[in_stride]\n" + "csel x28, x28, %x[pad_row], GT\n" + "cmp x20, #0x18\n" + "mov x21, %x[out]\n" + "sub %x[height], %x[height], #0x2\n" + "blt 13f\n" + "12:" // Tail row loop: Unroll column loop + "ldr q17, [x9], #0x10\n" + "ldr q16, [x28], #0x10\n" + "sub x20, x20, #0x18\n" + "zip1 v22.8h, v17.8h, v16.8h\n" + "ldr q21, [x9], #0x10\n" + "ldr q18, [x28], #0x10\n" + "zip2 v17.8h, v17.8h, v16.8h\n" + "zip1 v20.8h, v21.8h, v18.8h\n" + "ldr q19, [x9], #0x10\n" + "ldr q16, [x28], #0x10\n" + "str q22, [x21, #0x0]\n" + "cmp x20, #0x18\n" + "str q17, [x21, #0x10]\n" + "zip2 v18.8h, v21.8h, v18.8h\n" + "zip1 v17.8h, v19.8h, v16.8h\n" + "str q20, [x21, #0x20]\n" + "add x21, x21, %x[out_stride]\n" + "zip2 v16.8h, v19.8h, v16.8h\n" + "str q18, [x21, #0x0]\n" + "str q17, [x21, #0x10]\n" + "str q16, [x21, #0x20]\n" + "add x21, x21, %x[out_stride]\n" + "bge 12b\n" + "13:" // Tail row loop: Unroll column loop skip + "cmp x20, #0xc\n" + "blt 15f\n" + "14:" // Tail row loop: Column loop + "ldr q20, [x9], #0x10\n" + "ldr q17, [x28], #0x10\n" + "sub x20, x20, #0xc\n" + "cmp x20, #0xc\n" + "ldr d19, [x9], #0x8\n" + "ldr d16, [x28], #0x8\n" + "zip1 v18.8h, v20.8h, v17.8h\n" + "zip2 v17.8h, v20.8h, v17.8h\n" + "zip1 v16.8h, v19.8h, v16.8h\n" + "str q18, [x21, #0x0]\n" + "str q17, [x21, #0x10]\n" + "str q16, [x21, #0x20]\n" + "add x21, x21, %x[out_stride]\n" + "bge 14b\n" + "15:" // Tail row loop: Column loop skip + "cmp x20, #0x4\n" + "blt 17f\n" + "16:" // Tail row loop: width 4 loop: loop + "ldr d17, [x9], #0x8\n" + "ldr d16, [x28], #0x8\n" + "sub x20, x20, #0x4\n" + "cmp x20, #0x4\n" + "zip1 v16.8h, v17.8h, v16.8h\n" + "str q16, [x21, #0x0]\n" + "add x21, x21, #0x10\n" + "bge 16b\n" + "17:" // Tail row loop: width 4 loop: skip + "cmp x20, #0x1\n" + "blt 19f\n" + "18:" // Tail row loop: width 1 loop: loop + "ldr h17, [x9], #0x2\n" + "ldr h16, [x28], #0x2\n" + "sub x20, x20, #0x1\n" + "cmp x20, #0x1\n" + "zip1 v16.8h, v17.8h, v16.8h\n" + "str s16, [x21, #0x0]\n" + "add x21, x21, #0x4\n" + "bge 18b\n" + "19:" // Tail row loop: width 1 loop: skip + "cmp %x[height], #0x1\n" + "add %x[out], %x[out], #0x30\n" + "bge 11b\n" + "20:" // Done + : [height] "+&r" (height), [in] "+&r" (in), [out] "+&r" (out) + : [in_stride] "r" (in_stride), [out_stride] "r" (out_stride), [pad_row] "r" (pad_row), [width] "r" (width) + : "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "x9", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27", "x28" + ); +} + +} // anonymous namespace + +template<> +void Transform<12, 2, true, VLType::None>( + bfloat16 *out, const bfloat16 *in, int stride, int x0, int xmax, int k0, int kmax) +{ + a64_transpose_interleave_12_2x2( + reinterpret_cast<uint16_t *>(out), + reinterpret_cast<const uint16_t *>(in + k0 * stride + x0), + (xmax-x0) * sizeof(bfloat16) / 2, + stride * sizeof(bfloat16), + (kmax-k0) + ); +} + + +#endif // defined(__aarch64__) diff --git a/src/core/NEON/kernels/arm_gemm/transforms/a64_transpose_interleave_12_2x4.hpp b/src/core/NEON/kernels/arm_gemm/transforms/a64_transpose_interleave_12_2x4.hpp new file mode 100644 index 0000000000..4c02d0534d --- /dev/null +++ b/src/core/NEON/kernels/arm_gemm/transforms/a64_transpose_interleave_12_2x4.hpp @@ -0,0 +1,444 @@ +/* + * Copyright (c) 2021, 2023 Arm Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ + +#pragma once + +#if defined(__aarch64__) + +namespace { + +void a64_transpose_interleave_12_2x4(uint16_t *out, const uint16_t *in, size_t width, size_t in_stride, size_t height) +{ + uint16_t *pad_row = reinterpret_cast<uint16_t *>(alloca(width * sizeof(uint16_t))); + + if (height % 4) { + memset(pad_row, 0, width * sizeof(uint16_t)); + } + + size_t out_stride = 12 * roundup<size_t>(height, 4) * sizeof(uint16_t); + + __asm__ __volatile__( + "cmp %x[height], #0x8\n" + "blt 10f\n" + "1:" // Main row loop: Head + "mov x9, %x[in]\n" + "add x28, x9, %x[in_stride]\n" + "add x27, x28, %x[in_stride]\n" + "add x26, x27, %x[in_stride]\n" + "add x25, x26, %x[in_stride]\n" + "mov x24, %x[width]\n" + "add x23, x25, %x[in_stride]\n" + "add x22, x23, %x[in_stride]\n" + "add x20, x22, %x[in_stride]\n" + "cmp x24, #0x18\n" + "add %x[in], x20, %x[in_stride]\n" + "mov x21, %x[out]\n" + "sub %x[height], %x[height], #0x8\n" + "blt 3f\n" + "2:" // Main row loop: Unroll column loop + "ldr q21, [x9], #0x10\n" + "ldr q20, [x28], #0x10\n" + "sub x24, x24, #0x18\n" + "cmp x24, #0x18\n" + "ldr q17, [x27], #0x10\n" + "ldr q16, [x26], #0x10\n" + "zip1 v13.8h, v21.8h, v17.8h\n" + "zip1 v12.8h, v20.8h, v16.8h\n" + "ldr q19, [x25], #0x10\n" + "ldr q18, [x23], #0x10\n" + "zip2 v11.8h, v21.8h, v17.8h\n" + "zip2 v10.8h, v20.8h, v16.8h\n" + "ldr q17, [x22], #0x10\n" + "ldr q16, [x20], #0x10\n" + "zip1 v9.8h, v19.8h, v17.8h\n" + "zip1 v8.8h, v18.8h, v16.8h\n" + "ldr q21, [x9], #0x10\n" + "ldr q20, [x28], #0x10\n" + "zip2 v7.8h, v19.8h, v17.8h\n" + "zip2 v6.8h, v18.8h, v16.8h\n" + "ldr q17, [x27], #0x10\n" + "ldr q16, [x26], #0x10\n" + "zip1 v27.8h, v21.8h, v17.8h\n" + "zip1 v22.8h, v20.8h, v16.8h\n" + "ldr q19, [x25], #0x10\n" + "ldr q18, [x23], #0x10\n" + "zip2 v5.8h, v21.8h, v17.8h\n" + "zip2 v4.8h, v20.8h, v16.8h\n" + "ldr q17, [x22], #0x10\n" + "ldr q16, [x20], #0x10\n" + "zip1 v26.8h, v19.8h, v17.8h\n" + "zip1 v25.8h, v18.8h, v16.8h\n" + "ldr q21, [x9], #0x10\n" + "ldr q20, [x28], #0x10\n" + "zip2 v3.8h, v19.8h, v17.8h\n" + "zip2 v2.8h, v18.8h, v16.8h\n" + "ldr q17, [x27], #0x10\n" + "ldr q16, [x26], #0x10\n" + "zip1 v24.8h, v21.8h, v17.8h\n" + "zip1 v23.8h, v20.8h, v16.8h\n" + "ldr q19, [x25], #0x10\n" + "ldr q18, [x23], #0x10\n" + "zip2 v1.8h, v21.8h, v17.8h\n" + "zip2 v0.8h, v20.8h, v16.8h\n" + "ldr q17, [x22], #0x10\n" + "ldr q16, [x20], #0x10\n" + "zip1 v31.8h, v19.8h, v17.8h\n" + "zip1 v30.8h, v18.8h, v16.8h\n" + "zip2 v29.8h, v19.8h, v17.8h\n" + "zip2 v28.8h, v18.8h, v16.8h\n" + "zip1 v17.8h, v13.8h, v12.8h\n" + "zip2 v16.8h, v13.8h, v12.8h\n" + "str q17, [x21, #0x0]\n" + "zip1 v18.8h, v11.8h, v10.8h\n" + "zip2 v17.8h, v11.8h, v10.8h\n" + "str q16, [x21, #0x10]\n" + "zip1 v16.8h, v27.8h, v22.8h\n" + "zip2 v22.8h, v27.8h, v22.8h\n" + "str q18, [x21, #0x20]\n" + "zip1 v21.8h, v9.8h, v8.8h\n" + "zip2 v20.8h, v9.8h, v8.8h\n" + "str q17, [x21, #0x30]\n" + "zip1 v19.8h, v7.8h, v6.8h\n" + "zip2 v18.8h, v7.8h, v6.8h\n" + "str q16, [x21, #0x40]\n" + "zip1 v17.8h, v26.8h, v25.8h\n" + "zip2 v16.8h, v26.8h, v25.8h\n" + "str q22, [x21, #0x50]\n" + "str q21, [x21, #0x60]\n" + "zip1 v27.8h, v5.8h, v4.8h\n" + "zip2 v26.8h, v5.8h, v4.8h\n" + "str q20, [x21, #0x70]\n" + "zip1 v25.8h, v24.8h, v23.8h\n" + "zip2 v24.8h, v24.8h, v23.8h\n" + "str q19, [x21, #0x80]\n" + "zip1 v23.8h, v1.8h, v0.8h\n" + "zip2 v22.8h, v1.8h, v0.8h\n" + "str q18, [x21, #0x90]\n" + "zip1 v21.8h, v3.8h, v2.8h\n" + "zip2 v20.8h, v3.8h, v2.8h\n" + "str q17, [x21, #0xa0]\n" + "zip1 v19.8h, v31.8h, v30.8h\n" + "zip2 v18.8h, v31.8h, v30.8h\n" + "str q16, [x21, #0xb0]\n" + "add x21, x21, %x[out_stride]\n" + "zip1 v17.8h, v29.8h, v28.8h\n" + "zip2 v16.8h, v29.8h, v28.8h\n" + "str q27, [x21, #0x0]\n" + "str q26, [x21, #0x10]\n" + "str q25, [x21, #0x20]\n" + "str q24, [x21, #0x30]\n" + "str q23, [x21, #0x40]\n" + "str q22, [x21, #0x50]\n" + "str q21, [x21, #0x60]\n" + "str q20, [x21, #0x70]\n" + "str q19, [x21, #0x80]\n" + "str q18, [x21, #0x90]\n" + "str q17, [x21, #0xa0]\n" + "str q16, [x21, #0xb0]\n" + "add x21, x21, %x[out_stride]\n" + "bge 2b\n" + "3:" // Main row loop: Unroll column loop skip + "cmp x24, #0xc\n" + "blt 5f\n" + "4:" // Main row loop: Column loop + "ldr q19, [x9], #0x10\n" + "ldr q18, [x28], #0x10\n" + "sub x24, x24, #0xc\n" + "cmp x24, #0xc\n" + "ldr q17, [x27], #0x10\n" + "ldr q16, [x26], #0x10\n" + "zip1 v25.8h, v19.8h, v17.8h\n" + "zip1 v24.8h, v18.8h, v16.8h\n" + "ldr q21, [x25], #0x10\n" + "ldr q20, [x23], #0x10\n" + "zip2 v31.8h, v19.8h, v17.8h\n" + "zip2 v23.8h, v18.8h, v16.8h\n" + "ldr q17, [x22], #0x10\n" + "ldr q16, [x20], #0x10\n" + "zip1 v30.8h, v21.8h, v17.8h\n" + "zip1 v29.8h, v20.8h, v16.8h\n" + "ldr d19, [x9], #0x8\n" + "ldr d18, [x28], #0x8\n" + "zip2 v28.8h, v21.8h, v17.8h\n" + "zip2 v27.8h, v20.8h, v16.8h\n" + "ldr d17, [x27], #0x8\n" + "ldr d16, [x26], #0x8\n" + "zip1 v26.8h, v19.8h, v17.8h\n" + "zip1 v22.8h, v18.8h, v16.8h\n" + "ldr d21, [x25], #0x8\n" + "ldr d20, [x23], #0x8\n" + "zip1 v19.8h, v25.8h, v24.8h\n" + "zip2 v18.8h, v25.8h, v24.8h\n" + "ldr d17, [x22], #0x8\n" + "ldr d16, [x20], #0x8\n" + "zip1 v25.8h, v21.8h, v17.8h\n" + "zip1 v24.8h, v20.8h, v16.8h\n" + "zip1 v17.8h, v31.8h, v23.8h\n" + "zip2 v16.8h, v31.8h, v23.8h\n" + "str q19, [x21, #0x0]\n" + "zip1 v23.8h, v26.8h, v22.8h\n" + "zip2 v22.8h, v26.8h, v22.8h\n" + "str q18, [x21, #0x10]\n" + "zip1 v21.8h, v30.8h, v29.8h\n" + "zip2 v20.8h, v30.8h, v29.8h\n" + "str q17, [x21, #0x20]\n" + "zip1 v19.8h, v28.8h, v27.8h\n" + "zip2 v18.8h, v28.8h, v27.8h\n" + "str q16, [x21, #0x30]\n" + "zip1 v17.8h, v25.8h, v24.8h\n" + "zip2 v16.8h, v25.8h, v24.8h\n" + "str q23, [x21, #0x40]\n" + "str q22, [x21, #0x50]\n" + "str q21, [x21, #0x60]\n" + "str q20, [x21, #0x70]\n" + "str q19, [x21, #0x80]\n" + "str q18, [x21, #0x90]\n" + "str q17, [x21, #0xa0]\n" + "str q16, [x21, #0xb0]\n" + "add x21, x21, %x[out_stride]\n" + "bge 4b\n" + "5:" // Main row loop: Column loop skip + "cmp x24, #0x4\n" + "blt 7f\n" + "6:" // Main row loop: width 4 loop: loop + "ldr d19, [x9], #0x8\n" + "ldr d18, [x28], #0x8\n" + "sub x24, x24, #0x4\n" + "cmp x24, #0x4\n" + "ldr d17, [x27], #0x8\n" + "ldr d16, [x26], #0x8\n" + "zip1 v17.8h, v19.8h, v17.8h\n" + "zip1 v16.8h, v18.8h, v16.8h\n" + "ldr d18, [x25], #0x8\n" + "ldr d21, [x23], #0x8\n" + "zip1 v20.8h, v17.8h, v16.8h\n" + "zip2 v19.8h, v17.8h, v16.8h\n" + "ldr d17, [x22], #0x8\n" + "ldr d16, [x20], #0x8\n" + "zip1 v18.8h, v18.8h, v17.8h\n" + "zip1 v16.8h, v21.8h, v16.8h\n" + "str q20, [x21, #0x0]\n" + "zip1 v17.8h, v18.8h, v16.8h\n" + "zip2 v16.8h, v18.8h, v16.8h\n" + "str q19, [x21, #0x10]\n" + "str q17, [x21, #0x60]\n" + "str q16, [x21, #0x70]\n" + "add x21, x21, #0x20\n" + "bge 6b\n" + "7:" // Main row loop: width 4 loop: skip + "cmp x24, #0x1\n" + "blt 9f\n" + "8:" // Main row loop: width 1 loop: loop + "ldr h19, [x9], #0x2\n" + "ldr h18, [x28], #0x2\n" + "sub x24, x24, #0x1\n" + "cmp x24, #0x1\n" + "ldr h17, [x27], #0x2\n" + "ldr h16, [x26], #0x2\n" + "zip1 v17.8h, v19.8h, v17.8h\n" + "zip1 v16.8h, v18.8h, v16.8h\n" + "ldr h20, [x25], #0x2\n" + "ldr h19, [x23], #0x2\n" + "zip1 v18.8h, v17.8h, v16.8h\n" + "ldr h17, [x22], #0x2\n" + "ldr h16, [x20], #0x2\n" + "zip1 v17.8h, v20.8h, v17.8h\n" + "zip1 v16.8h, v19.8h, v16.8h\n" + "str d18, [x21, #0x0]\n" + "zip1 v16.8h, v17.8h, v16.8h\n" + "str d16, [x21, #0x60]\n" + "add x21, x21, #0x8\n" + "bge 8b\n" + "9:" // Main row loop: width 1 loop: skip + "cmp %x[height], #0x8\n" + "add %x[out], %x[out], #0xc0\n" + "bge 1b\n" + "cbz %x[height], 20f\n" + "10:" // Main loop skip + "11:" // Tail row loop: Head + "mov x9, %x[in]\n" + "add x28, x9, %x[in_stride]\n" + "add x27, x28, %x[in_stride]\n" + "mov x20, %x[width]\n" + "add x26, x27, %x[in_stride]\n" + "cmp %x[height], #0x3\n" + "add %x[in], x26, %x[in_stride]\n" + "csel x26, x26, %x[pad_row], GT\n" + "csel x27, x27, %x[pad_row], GE\n" + "cmp %x[height], #0x1\n" + "csel x28, x28, %x[pad_row], GT\n" + "cmp x20, #0x18\n" + "mov x21, %x[out]\n" + "sub %x[height], %x[height], #0x4\n" + "blt 13f\n" + "12:" // Tail row loop: Unroll column loop + "ldr q19, [x9], #0x10\n" + "ldr q18, [x28], #0x10\n" + "sub x20, x20, #0x18\n" + "cmp x20, #0x18\n" + "ldr q17, [x27], #0x10\n" + "ldr q16, [x26], #0x10\n" + "zip1 v31.8h, v19.8h, v17.8h\n" + "zip1 v30.8h, v18.8h, v16.8h\n" + "ldr q22, [x9], #0x10\n" + "ldr q20, [x28], #0x10\n" + "zip2 v29.8h, v19.8h, v17.8h\n" + "zip2 v28.8h, v18.8h, v16.8h\n" + "ldr q17, [x27], #0x10\n" + "ldr q16, [x26], #0x10\n" + "zip1 v27.8h, v22.8h, v17.8h\n" + "zip1 v21.8h, v20.8h, v16.8h\n" + "ldr q19, [x9], #0x10\n" + "ldr q18, [x28], #0x10\n" + "zip2 v26.8h, v22.8h, v17.8h\n" + "zip2 v20.8h, v20.8h, v16.8h\n" + "ldr q17, [x27], #0x10\n" + "ldr q16, [x26], #0x10\n" + "zip1 v25.8h, v19.8h, v17.8h\n" + "zip1 v24.8h, v18.8h, v16.8h\n" + "zip2 v23.8h, v19.8h, v17.8h\n" + "zip2 v22.8h, v18.8h, v16.8h\n" + "zip1 v17.8h, v31.8h, v30.8h\n" + "zip2 v16.8h, v31.8h, v30.8h\n" + "str q17, [x21, #0x0]\n" + "zip1 v19.8h, v29.8h, v28.8h\n" + "zip2 v18.8h, v29.8h, v28.8h\n" + "str q16, [x21, #0x10]\n" + "zip1 v17.8h, v27.8h, v21.8h\n" + "zip2 v16.8h, v27.8h, v21.8h\n" + "str q19, [x21, #0x20]\n" + "str q18, [x21, #0x30]\n" + "zip1 v21.8h, v26.8h, v20.8h\n" + "zip2 v20.8h, v26.8h, v20.8h\n" + "str q17, [x21, #0x40]\n" + "zip1 v19.8h, v25.8h, v24.8h\n" + "zip2 v18.8h, v25.8h, v24.8h\n" + "str q16, [x21, #0x50]\n" + "add x21, x21, %x[out_stride]\n" + "zip1 v17.8h, v23.8h, v22.8h\n" + "zip2 v16.8h, v23.8h, v22.8h\n" + "str q21, [x21, #0x0]\n" + "str q20, [x21, #0x10]\n" + "str q19, [x21, #0x20]\n" + "str q18, [x21, #0x30]\n" + "str q17, [x21, #0x40]\n" + "str q16, [x21, #0x50]\n" + "add x21, x21, %x[out_stride]\n" + "bge 12b\n" + "13:" // Tail row loop: Unroll column loop skip + "cmp x20, #0xc\n" + "blt 15f\n" + "14:" // Tail row loop: Column loop + "ldr q21, [x9], #0x10\n" + "ldr q18, [x28], #0x10\n" + "sub x20, x20, #0xc\n" + "cmp x20, #0xc\n" + "ldr q17, [x27], #0x10\n" + "ldr q16, [x26], #0x10\n" + "zip1 v24.8h, v21.8h, v17.8h\n" + "zip1 v23.8h, v18.8h, v16.8h\n" + "ldr d20, [x9], #0x8\n" + "ldr d19, [x28], #0x8\n" + "zip2 v22.8h, v21.8h, v17.8h\n" + "zip2 v18.8h, v18.8h, v16.8h\n" + "ldr d17, [x27], #0x8\n" + "ldr d16, [x26], #0x8\n" + "zip1 v21.8h, v20.8h, v17.8h\n" + "zip1 v20.8h, v19.8h, v16.8h\n" + "zip1 v17.8h, v24.8h, v23.8h\n" + "zip2 v16.8h, v24.8h, v23.8h\n" + "str q17, [x21, #0x0]\n" + "zip1 v19.8h, v22.8h, v18.8h\n" + "zip2 v18.8h, v22.8h, v18.8h\n" + "str q16, [x21, #0x10]\n" + "zip1 v17.8h, v21.8h, v20.8h\n" + "zip2 v16.8h, v21.8h, v20.8h\n" + "str q19, [x21, #0x20]\n" + "str q18, [x21, #0x30]\n" + "str q17, [x21, #0x40]\n" + "str q16, [x21, #0x50]\n" + "add x21, x21, %x[out_stride]\n" + "bge 14b\n" + "15:" // Tail row loop: Column loop skip + "cmp x20, #0x4\n" + "blt 17f\n" + "16:" // Tail row loop: width 4 loop: loop + "ldr d18, [x9], #0x8\n" + "ldr d19, [x28], #0x8\n" + "sub x20, x20, #0x4\n" + "cmp x20, #0x4\n" + "ldr d17, [x27], #0x8\n" + "ldr d16, [x26], #0x8\n" + "zip1 v18.8h, v18.8h, v17.8h\n" + "zip1 v16.8h, v19.8h, v16.8h\n" + "zip1 v17.8h, v18.8h, v16.8h\n" + "zip2 v16.8h, v18.8h, v16.8h\n" + "str q17, [x21, #0x0]\n" + "str q16, [x21, #0x10]\n" + "add x21, x21, #0x20\n" + "bge 16b\n" + "17:" // Tail row loop: width 4 loop: skip + "cmp x20, #0x1\n" + "blt 19f\n" + "18:" // Tail row loop: width 1 loop: loop + "ldr h19, [x9], #0x2\n" + "ldr h18, [x28], #0x2\n" + "sub x20, x20, #0x1\n" + "cmp x20, #0x1\n" + "ldr h17, [x27], #0x2\n" + "ldr h16, [x26], #0x2\n" + "zip1 v17.8h, v19.8h, v17.8h\n" + "zip1 v16.8h, v18.8h, v16.8h\n" + "zip1 v16.8h, v17.8h, v16.8h\n" + "str d16, [x21, #0x0]\n" + "add x21, x21, #0x8\n" + "bge 18b\n" + "19:" // Tail row loop: width 1 loop: skip + "cmp %x[height], #0x1\n" + "add %x[out], %x[out], #0x60\n" + "bge 11b\n" + "20:" // Done + : [height] "+&r" (height), [in] "+&r" (in), [out] "+&r" (out) + : [in_stride] "r" (in_stride), [out_stride] "r" (out_stride), [pad_row] "r" (pad_row), [width] "r" (width) + : "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "x9", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27", "x28" + ); +} + +} // anonymous namespace + +template<> +void Transform<12, 4, true, VLType::None>( + bfloat16 *out, const bfloat16 *in, int stride, int x0, int xmax, int k0, int kmax) +{ + a64_transpose_interleave_12_2x4( + reinterpret_cast<uint16_t *>(out), + reinterpret_cast<const uint16_t *>(in + k0 * stride + x0), + (xmax-x0) * sizeof(bfloat16) / 2, + stride * sizeof(bfloat16), + (kmax-k0) + ); +} + + +#endif // defined(__aarch64__) diff --git a/src/core/NEON/kernels/arm_gemm/transforms/a64_transpose_interleave_12_2x4_fp32bf16.hpp b/src/core/NEON/kernels/arm_gemm/transforms/a64_transpose_interleave_12_2x4_fp32bf16.hpp new file mode 100644 index 0000000000..2a3208d18d --- /dev/null +++ b/src/core/NEON/kernels/arm_gemm/transforms/a64_transpose_interleave_12_2x4_fp32bf16.hpp @@ -0,0 +1,734 @@ +/* + * Copyright (c) 2021, 2023 Arm Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ + +#pragma once + +#if defined(__aarch64__) + +namespace { + +void a64_transpose_interleave_12_2x4_fp32bf16(bfloat16 *out, const float *in, size_t width, size_t in_stride, size_t height) +{ + float *pad_row = reinterpret_cast<float *>(alloca(width * sizeof(float))); + + if (height % 4) { + memset(pad_row, 0, width * sizeof(float)); + } + + size_t out_stride = 12 * roundup<size_t>(height, 4) * sizeof(bfloat16); + + __asm__ __volatile__( + "cmp %x[height], #0x8\n" + "blt 10f\n" + "1:" // Main row loop: Head + "mov x9, %x[in]\n" + "add x28, x9, %x[in_stride]\n" + "add x27, x28, %x[in_stride]\n" + "add x26, x27, %x[in_stride]\n" + "add x25, x26, %x[in_stride]\n" + "mov x24, %x[width]\n" + "add x23, x25, %x[in_stride]\n" + "add x22, x23, %x[in_stride]\n" + "add x20, x22, %x[in_stride]\n" + "cmp x24, #0x18\n" + "add %x[in], x20, %x[in_stride]\n" + "mov x21, %x[out]\n" + "sub %x[height], %x[height], #0x8\n" + "blt 3f\n" + "2:" // Main row loop: Unroll column loop + "ldr q15, [x9], #0x10\n" + "ldr q17, [x28], #0x10\n" + "sub x24, x24, #0x18\n" + "cmp x24, #0x18\n" + "ldr q16, [x27], #0x10\n" + "ldr q20, [x26], #0x10\n" + "zip1 v6.4s, v15.4s, v16.4s\n" + "zip1 v11.4s, v17.4s, v20.4s\n" + "ldr q2, [x25], #0x10\n" + "ldr q4, [x23], #0x10\n" + "zip2 v22.4s, v15.4s, v16.4s\n" + "zip2 v18.4s, v17.4s, v20.4s\n" + "ldr q17, [x22], #0x10\n" + "ldr q26, [x20], #0x10\n" + "zip1 v9.4s, v2.4s, v17.4s\n" + "zip1 v10.4s, v4.4s, v26.4s\n" + "ldr q16, [x9], #0x10\n" + "ldr q27, [x28], #0x10\n" + "zip2 v3.4s, v2.4s, v17.4s\n" + "zip2 v30.4s, v4.4s, v26.4s\n" + "ldr q13, [x27], #0x10\n" + "ldr q1, [x26], #0x10\n" + "zip1 v23.4s, v16.4s, v13.4s\n" + "zip1 v5.4s, v27.4s, v1.4s\n" + "ldr q26, [x25], #0x10\n" + "ldr q14, [x23], #0x10\n" + "zip2 v0.4s, v16.4s, v13.4s\n" + "zip2 v2.4s, v27.4s, v1.4s\n" + "ldr q15, [x22], #0x10\n" + "ldr q8, [x20], #0x10\n" + "zip1 v31.4s, v26.4s, v15.4s\n" + "zip1 v4.4s, v14.4s, v8.4s\n" + "ldr q28, [x9], #0x10\n" + "ldr q19, [x28], #0x10\n" + "zip2 v21.4s, v26.4s, v15.4s\n" + "zip2 v16.4s, v14.4s, v8.4s\n" + "ldr q15, [x27], #0x10\n" + "ldr q1, [x26], #0x10\n" + "zip1 v17.4s, v28.4s, v15.4s\n" + "zip1 v8.4s, v19.4s, v1.4s\n" + "ldr q27, [x25], #0x10\n" + "ldr q20, [x23], #0x10\n" + "zip2 v7.4s, v28.4s, v15.4s\n" + "zip2 v15.4s, v19.4s, v1.4s\n" + "ldr q12, [x22], #0x10\n" + "ldr q25, [x20], #0x10\n" + "zip1 v14.4s, v27.4s, v12.4s\n" + "zip1 v26.4s, v20.4s, v25.4s\n" + "ldr q13, [x9], #0x10\n" + "ldr q29, [x28], #0x10\n" + "zip2 v28.4s, v27.4s, v12.4s\n" + "zip2 v12.4s, v20.4s, v25.4s\n" + "ldr q27, [x27], #0x10\n" + "ldr q20, [x26], #0x10\n" + "zip1 v19.4s, v13.4s, v27.4s\n" + "zip1 v25.4s, v29.4s, v20.4s\n" + "ldr q24, [x25], #0x10\n" + "ldr q1, [x23], #0x10\n" + "zip2 v27.4s, v13.4s, v27.4s\n" + "zip2 v13.4s, v29.4s, v20.4s\n" + "ldr q20, [x22], #0x10\n" + "zip1 v29.4s, v24.4s, v20.4s\n" + "zip2 v20.4s, v24.4s, v20.4s\n" + "zip1 v24.4s, v6.4s, v11.4s\n" + ".inst 0x0ea16b18 // bfcvtn v24.4h, v24.4s\n" + "zip2 v11.4s, v6.4s, v11.4s\n" + "ldr q6, [x20], #0x10\n" + ".inst 0x4ea16978 // bfcvtn2 v24.8h, v11.4s\n" + "zip1 v11.4s, v1.4s, v6.4s\n" + "zip2 v6.4s, v1.4s, v6.4s\n" + "zip1 v1.4s, v22.4s, v18.4s\n" + ".inst 0x0ea16821 // bfcvtn v1.4h, v1.4s\n" + "zip2 v18.4s, v22.4s, v18.4s\n" + "ldr q22, [x9], #0x10\n" + ".inst 0x4ea16a41 // bfcvtn2 v1.8h, v18.4s\n" + "zip1 v18.4s, v23.4s, v5.4s\n" + ".inst 0x0ea16a52 // bfcvtn v18.4h, v18.4s\n" + "zip2 v5.4s, v23.4s, v5.4s\n" + "ldr q23, [x28], #0x10\n" + ".inst 0x4ea168b2 // bfcvtn2 v18.8h, v5.4s\n" + "zip1 v5.4s, v0.4s, v2.4s\n" + ".inst 0x0ea168a5 // bfcvtn v5.4h, v5.4s\n" + "zip2 v0.4s, v0.4s, v2.4s\n" + "ldr q2, [x27], #0x10\n" + ".inst 0x4ea16805 // bfcvtn2 v5.8h, v0.4s\n" + "zip1 v0.4s, v22.4s, v2.4s\n" + "zip2 v2.4s, v22.4s, v2.4s\n" + "zip1 v22.4s, v17.4s, v8.4s\n" + ".inst 0x0ea16ad6 // bfcvtn v22.4h, v22.4s\n" + "zip2 v8.4s, v17.4s, v8.4s\n" + "ldr q17, [x26], #0x10\n" + ".inst 0x4ea16916 // bfcvtn2 v22.8h, v8.4s\n" + "zip1 v8.4s, v23.4s, v17.4s\n" + "zip2 v23.4s, v23.4s, v17.4s\n" + "zip1 v17.4s, v7.4s, v15.4s\n" + ".inst 0x0ea16a31 // bfcvtn v17.4h, v17.4s\n" + "zip2 v7.4s, v7.4s, v15.4s\n" + "ldr q15, [x25], #0x10\n" + ".inst 0x4ea168f1 // bfcvtn2 v17.8h, v7.4s\n" + "zip1 v7.4s, v9.4s, v10.4s\n" + ".inst 0x0ea168e7 // bfcvtn v7.4h, v7.4s\n" + "zip2 v10.4s, v9.4s, v10.4s\n" + "ldr q9, [x23], #0x10\n" + ".inst 0x4ea16947 // bfcvtn2 v7.8h, v10.4s\n" + "zip1 v10.4s, v3.4s, v30.4s\n" + ".inst 0x0ea1694a // bfcvtn v10.4h, v10.4s\n" + "zip2 v30.4s, v3.4s, v30.4s\n" + "ldr q3, [x22], #0x10\n" + ".inst 0x4ea16bca // bfcvtn2 v10.8h, v30.4s\n" + "zip1 v30.4s, v15.4s, v3.4s\n" + "zip2 v15.4s, v15.4s, v3.4s\n" + "zip1 v3.4s, v31.4s, v4.4s\n" + ".inst 0x0ea16863 // bfcvtn v3.4h, v3.4s\n" + "zip2 v31.4s, v31.4s, v4.4s\n" + "ldr q4, [x20], #0x10\n" + ".inst 0x4ea16be3 // bfcvtn2 v3.8h, v31.4s\n" + "zip1 v31.4s, v9.4s, v4.4s\n" + "zip2 v4.4s, v9.4s, v4.4s\n" + "zip1 v9.4s, v21.4s, v16.4s\n" + ".inst 0x0ea16929 // bfcvtn v9.4h, v9.4s\n" + "zip2 v16.4s, v21.4s, v16.4s\n" + "ldr q21, [x9], #0x10\n" + ".inst 0x4ea16a09 // bfcvtn2 v9.8h, v16.4s\n" + "zip1 v16.4s, v14.4s, v26.4s\n" + ".inst 0x0ea16a10 // bfcvtn v16.4h, v16.4s\n" + "zip2 v14.4s, v14.4s, v26.4s\n" + "ldr q26, [x28], #0x10\n" + ".inst 0x4ea169d0 // bfcvtn2 v16.8h, v14.4s\n" + "zip1 v14.4s, v28.4s, v12.4s\n" + ".inst 0x0ea169ce // bfcvtn v14.4h, v14.4s\n" + "zip2 v12.4s, v28.4s, v12.4s\n" + "ldr q28, [x27], #0x10\n" + ".inst 0x4ea1698e // bfcvtn2 v14.8h, v12.4s\n" + "zip1 v12.4s, v21.4s, v28.4s\n" + "zip2 v28.4s, v21.4s, v28.4s\n" + "zip1 v21.4s, v19.4s, v25.4s\n" + ".inst 0x0ea16ab5 // bfcvtn v21.4h, v21.4s\n" + "zip2 v19.4s, v19.4s, v25.4s\n" + "ldr q25, [x26], #0x10\n" + ".inst 0x4ea16a75 // bfcvtn2 v21.8h, v19.4s\n" + "zip1 v19.4s, v26.4s, v25.4s\n" + "zip2 v25.4s, v26.4s, v25.4s\n" + "zip1 v26.4s, v27.4s, v13.4s\n" + ".inst 0x0ea16b5a // bfcvtn v26.4h, v26.4s\n" + "zip2 v13.4s, v27.4s, v13.4s\n" + "ldr q27, [x25], #0x10\n" + ".inst 0x4ea169ba // bfcvtn2 v26.8h, v13.4s\n" + "zip1 v13.4s, v0.4s, v8.4s\n" + ".inst 0x0ea169ad // bfcvtn v13.4h, v13.4s\n" + "zip2 v8.4s, v0.4s, v8.4s\n" + "ldr q0, [x23], #0x10\n" + ".inst 0x4ea1690d // bfcvtn2 v13.8h, v8.4s\n" + "zip1 v8.4s, v2.4s, v23.4s\n" + ".inst 0x0ea16908 // bfcvtn v8.4h, v8.4s\n" + "zip2 v23.4s, v2.4s, v23.4s\n" + "ldr q2, [x22], #0x10\n" + ".inst 0x4ea16ae8 // bfcvtn2 v8.8h, v23.4s\n" + "ldr q23, [x20], #0x10\n" + "str q24, [x21, #0x0]\n" + "zip1 v24.4s, v27.4s, v2.4s\n" + "zip2 v27.4s, v27.4s, v2.4s\n" + "zip1 v2.4s, v0.4s, v23.4s\n" + "zip2 v23.4s, v0.4s, v23.4s\n" + "str q1, [x21, #0x10]\n" + "zip1 v0.4s, v12.4s, v19.4s\n" + "zip1 v1.4s, v28.4s, v25.4s\n" + "str q18, [x21, #0x20]\n" + "zip1 v18.4s, v29.4s, v11.4s\n" + ".inst 0x0ea16800 // bfcvtn v0.4h, v0.4s\n" + "str q5, [x21, #0x30]\n" + "zip1 v5.4s, v20.4s, v6.4s\n" + "zip2 v19.4s, v12.4s, v19.4s\n" + "str q22, [x21, #0x40]\n" + "zip1 v12.4s, v30.4s, v31.4s\n" + "zip1 v22.4s, v15.4s, v4.4s\n" + "str q17, [x21, #0x50]\n" + "zip1 v17.4s, v24.4s, v2.4s\n" + ".inst 0x0ea16821 // bfcvtn v1.4h, v1.4s\n" + "str q7, [x21, #0x60]\n" + "zip1 v7.4s, v27.4s, v23.4s\n" + "zip2 v25.4s, v28.4s, v25.4s\n" + "str q10, [x21, #0x70]\n" + ".inst 0x0ea16a52 // bfcvtn v18.4h, v18.4s\n" + "zip2 v29.4s, v29.4s, v11.4s\n" + "str q3, [x21, #0x80]\n" + ".inst 0x0ea168ab // bfcvtn v11.4h, v5.4s\n" + "zip2 v10.4s, v20.4s, v6.4s\n" + "str q9, [x21, #0x90]\n" + ".inst 0x0ea16986 // bfcvtn v6.4h, v12.4s\n" + "zip2 v12.4s, v30.4s, v31.4s\n" + "str q16, [x21, #0xa0]\n" + ".inst 0x0ea16ac5 // bfcvtn v5.4h, v22.4s\n" + "zip2 v4.4s, v15.4s, v4.4s\n" + "str q14, [x21, #0xb0]\n" + "add x21, x21, %x[out_stride]\n" + ".inst 0x0ea16a2f // bfcvtn v15.4h, v17.4s\n" + "zip2 v20.4s, v24.4s, v2.4s\n" + "str q21, [x21, #0x0]\n" + ".inst 0x0ea168fc // bfcvtn v28.4h, v7.4s\n" + "zip2 v30.4s, v27.4s, v23.4s\n" + "str q26, [x21, #0x10]\n" + ".inst 0x4ea16a60 // bfcvtn2 v0.8h, v19.4s\n" + ".inst 0x4ea16b21 // bfcvtn2 v1.8h, v25.4s\n" + "str q13, [x21, #0x20]\n" + ".inst 0x4ea16bb2 // bfcvtn2 v18.8h, v29.4s\n" + ".inst 0x4ea1694b // bfcvtn2 v11.8h, v10.4s\n" + "str q8, [x21, #0x30]\n" + ".inst 0x4ea16986 // bfcvtn2 v6.8h, v12.4s\n" + ".inst 0x4ea16885 // bfcvtn2 v5.8h, v4.4s\n" + "str q0, [x21, #0x40]\n" + ".inst 0x4ea16a8f // bfcvtn2 v15.8h, v20.4s\n" + ".inst 0x4ea16bdc // bfcvtn2 v28.8h, v30.4s\n" + "str q1, [x21, #0x50]\n" + "str q18, [x21, #0x60]\n" + "str q11, [x21, #0x70]\n" + "str q6, [x21, #0x80]\n" + "str q5, [x21, #0x90]\n" + "str q15, [x21, #0xa0]\n" + "str q28, [x21, #0xb0]\n" + "add x21, x21, %x[out_stride]\n" + "bge 2b\n" + "3:" // Main row loop: Unroll column loop skip + "cmp x24, #0xc\n" + "blt 5f\n" + "4:" // Main row loop: Column loop + "ldr q20, [x9], #0x10\n" + "ldr q9, [x28], #0x10\n" + "sub x24, x24, #0xc\n" + "cmp x24, #0xc\n" + "ldr q8, [x27], #0x10\n" + "ldr q1, [x26], #0x10\n" + "zip1 v7.4s, v20.4s, v8.4s\n" + "zip1 v19.4s, v9.4s, v1.4s\n" + "ldr q6, [x25], #0x10\n" + "ldr q16, [x23], #0x10\n" + "zip2 v5.4s, v20.4s, v8.4s\n" + "zip2 v18.4s, v9.4s, v1.4s\n" + "ldr q27, [x22], #0x10\n" + "ldr q14, [x20], #0x10\n" + "zip1 v26.4s, v6.4s, v27.4s\n" + "zip1 v15.4s, v16.4s, v14.4s\n" + "ldr q1, [x9], #0x10\n" + "ldr q30, [x28], #0x10\n" + "zip2 v24.4s, v6.4s, v27.4s\n" + "zip2 v25.4s, v16.4s, v14.4s\n" + "ldr q13, [x27], #0x10\n" + "ldr q17, [x26], #0x10\n" + "zip1 v10.4s, v1.4s, v13.4s\n" + "zip1 v16.4s, v30.4s, v17.4s\n" + "ldr q4, [x25], #0x10\n" + "ldr q11, [x23], #0x10\n" + "zip2 v0.4s, v1.4s, v13.4s\n" + "zip2 v27.4s, v30.4s, v17.4s\n" + "ldr q28, [x22], #0x10\n" + "ldr q12, [x20], #0x10\n" + "zip1 v22.4s, v4.4s, v28.4s\n" + "zip1 v13.4s, v11.4s, v12.4s\n" + "ldr q31, [x9], #0x10\n" + "ldr q17, [x28], #0x10\n" + "zip2 v14.4s, v4.4s, v28.4s\n" + "zip2 v12.4s, v11.4s, v12.4s\n" + "ldr q2, [x27], #0x10\n" + "ldr q3, [x26], #0x10\n" + "zip1 v8.4s, v31.4s, v2.4s\n" + "zip1 v4.4s, v17.4s, v3.4s\n" + "ldr q23, [x25], #0x10\n" + "ldr q1, [x23], #0x10\n" + "zip2 v28.4s, v31.4s, v2.4s\n" + "zip2 v29.4s, v17.4s, v3.4s\n" + "ldr q11, [x22], #0x10\n" + "ldr q17, [x20], #0x10\n" + "zip1 v9.4s, v23.4s, v11.4s\n" + "zip1 v21.4s, v1.4s, v17.4s\n" + "zip2 v11.4s, v23.4s, v11.4s\n" + "zip2 v17.4s, v1.4s, v17.4s\n" + "zip1 v2.4s, v7.4s, v19.4s\n" + "zip1 v31.4s, v5.4s, v18.4s\n" + "zip1 v3.4s, v10.4s, v16.4s\n" + "zip1 v6.4s, v0.4s, v27.4s\n" + "zip1 v1.4s, v8.4s, v4.4s\n" + "zip1 v30.4s, v28.4s, v29.4s\n" + "zip1 v20.4s, v26.4s, v15.4s\n" + "zip1 v23.4s, v24.4s, v25.4s\n" + ".inst 0x0ea16842 // bfcvtn v2.4h, v2.4s\n" + "zip2 v7.4s, v7.4s, v19.4s\n" + "zip1 v19.4s, v22.4s, v13.4s\n" + ".inst 0x0ea16bff // bfcvtn v31.4h, v31.4s\n" + "zip2 v18.4s, v5.4s, v18.4s\n" + "zip1 v5.4s, v14.4s, v12.4s\n" + ".inst 0x0ea16863 // bfcvtn v3.4h, v3.4s\n" + "zip2 v16.4s, v10.4s, v16.4s\n" + "zip1 v10.4s, v9.4s, v21.4s\n" + ".inst 0x0ea168c6 // bfcvtn v6.4h, v6.4s\n" + "zip2 v0.4s, v0.4s, v27.4s\n" + "zip1 v27.4s, v11.4s, v17.4s\n" + ".inst 0x0ea16821 // bfcvtn v1.4h, v1.4s\n" + "zip2 v4.4s, v8.4s, v4.4s\n" + ".inst 0x0ea16bde // bfcvtn v30.4h, v30.4s\n" + "zip2 v29.4s, v28.4s, v29.4s\n" + ".inst 0x0ea16a9c // bfcvtn v28.4h, v20.4s\n" + "zip2 v15.4s, v26.4s, v15.4s\n" + ".inst 0x0ea16ae8 // bfcvtn v8.4h, v23.4s\n" + "zip2 v26.4s, v24.4s, v25.4s\n" + ".inst 0x0ea16a79 // bfcvtn v25.4h, v19.4s\n" + "zip2 v24.4s, v22.4s, v13.4s\n" + ".inst 0x0ea168b7 // bfcvtn v23.4h, v5.4s\n" + "zip2 v22.4s, v14.4s, v12.4s\n" + ".inst 0x0ea16945 // bfcvtn v5.4h, v10.4s\n" + "zip2 v20.4s, v9.4s, v21.4s\n" + ".inst 0x0ea16b73 // bfcvtn v19.4h, v27.4s\n" + "zip2 v17.4s, v11.4s, v17.4s\n" + ".inst 0x4ea168e2 // bfcvtn2 v2.8h, v7.4s\n" + ".inst 0x4ea16a5f // bfcvtn2 v31.8h, v18.4s\n" + "str q2, [x21, #0x0]\n" + ".inst 0x4ea16a03 // bfcvtn2 v3.8h, v16.4s\n" + ".inst 0x4ea16806 // bfcvtn2 v6.8h, v0.4s\n" + "str q31, [x21, #0x10]\n" + ".inst 0x4ea16881 // bfcvtn2 v1.8h, v4.4s\n" + ".inst 0x4ea16bbe // bfcvtn2 v30.8h, v29.4s\n" + "str q3, [x21, #0x20]\n" + ".inst 0x4ea169fc // bfcvtn2 v28.8h, v15.4s\n" + ".inst 0x4ea16b48 // bfcvtn2 v8.8h, v26.4s\n" + "str q6, [x21, #0x30]\n" + ".inst 0x4ea16b19 // bfcvtn2 v25.8h, v24.4s\n" + ".inst 0x4ea16ad7 // bfcvtn2 v23.8h, v22.4s\n" + "str q1, [x21, #0x40]\n" + ".inst 0x4ea16a85 // bfcvtn2 v5.8h, v20.4s\n" + ".inst 0x4ea16a33 // bfcvtn2 v19.8h, v17.4s\n" + "str q30, [x21, #0x50]\n" + "str q28, [x21, #0x60]\n" + "str q8, [x21, #0x70]\n" + "str q25, [x21, #0x80]\n" + "str q23, [x21, #0x90]\n" + "str q5, [x21, #0xa0]\n" + "str q19, [x21, #0xb0]\n" + "add x21, x21, %x[out_stride]\n" + "bge 4b\n" + "5:" // Main row loop: Column loop skip + "cmp x24, #0x4\n" + "blt 7f\n" + "6:" // Main row loop: width 4 loop: loop + "ldr q23, [x9], #0x10\n" + "ldr q20, [x28], #0x10\n" + "sub x24, x24, #0x4\n" + "cmp x24, #0x4\n" + "ldr q17, [x27], #0x10\n" + "ldr q16, [x26], #0x10\n" + "zip1 v22.4s, v23.4s, v17.4s\n" + "zip1 v21.4s, v20.4s, v16.4s\n" + "ldr q19, [x25], #0x10\n" + "ldr q18, [x23], #0x10\n" + "zip2 v28.4s, v23.4s, v17.4s\n" + "zip2 v20.4s, v20.4s, v16.4s\n" + "ldr q17, [x22], #0x10\n" + "ldr q16, [x20], #0x10\n" + "zip1 v27.4s, v19.4s, v17.4s\n" + "zip1 v26.4s, v18.4s, v16.4s\n" + "zip2 v25.4s, v19.4s, v17.4s\n" + "zip2 v24.4s, v18.4s, v16.4s\n" + "zip1 v19.4s, v22.4s, v21.4s\n" + "zip1 v18.4s, v28.4s, v20.4s\n" + "zip1 v17.4s, v27.4s, v26.4s\n" + "zip1 v16.4s, v25.4s, v24.4s\n" + ".inst 0x0ea16a77 // bfcvtn v23.4h, v19.4s\n" + "zip2 v22.4s, v22.4s, v21.4s\n" + ".inst 0x0ea16a55 // bfcvtn v21.4h, v18.4s\n" + "zip2 v20.4s, v28.4s, v20.4s\n" + ".inst 0x0ea16a33 // bfcvtn v19.4h, v17.4s\n" + "zip2 v18.4s, v27.4s, v26.4s\n" + ".inst 0x0ea16a11 // bfcvtn v17.4h, v16.4s\n" + "zip2 v16.4s, v25.4s, v24.4s\n" + ".inst 0x4ea16ad7 // bfcvtn2 v23.8h, v22.4s\n" + ".inst 0x4ea16a95 // bfcvtn2 v21.8h, v20.4s\n" + "str q23, [x21, #0x0]\n" + ".inst 0x4ea16a53 // bfcvtn2 v19.8h, v18.4s\n" + ".inst 0x4ea16a11 // bfcvtn2 v17.8h, v16.4s\n" + "str q21, [x21, #0x10]\n" + "str q19, [x21, #0x60]\n" + "str q17, [x21, #0x70]\n" + "add x21, x21, #0x20\n" + "bge 6b\n" + "7:" // Main row loop: width 4 loop: skip + "cmp x24, #0x1\n" + "blt 9f\n" + "8:" // Main row loop: width 1 loop: loop + "ldr s19, [x9], #0x4\n" + "ldr s18, [x28], #0x4\n" + "sub x24, x24, #0x1\n" + "cmp x24, #0x1\n" + "ldr s17, [x27], #0x4\n" + "ldr s16, [x26], #0x4\n" + "zip1 v17.4s, v19.4s, v17.4s\n" + "zip1 v16.4s, v18.4s, v16.4s\n" + "ldr s20, [x25], #0x4\n" + "ldr s19, [x23], #0x4\n" + "zip1 v16.4s, v17.4s, v16.4s\n" + ".inst 0x0ea16a12 // bfcvtn v18.4h, v16.4s\n" + "ldr s17, [x22], #0x4\n" + "ldr s16, [x20], #0x4\n" + "zip1 v17.4s, v20.4s, v17.4s\n" + "zip1 v16.4s, v19.4s, v16.4s\n" + "zip1 v16.4s, v17.4s, v16.4s\n" + ".inst 0x0ea16a10 // bfcvtn v16.4h, v16.4s\n" + "str d18, [x21, #0x0]\n" + "str d16, [x21, #0x60]\n" + "add x21, x21, #0x8\n" + "bge 8b\n" + "9:" // Main row loop: width 1 loop: skip + "cmp %x[height], #0x8\n" + "add %x[out], %x[out], #0xc0\n" + "bge 1b\n" + "cbz %x[height], 20f\n" + "10:" // Main loop skip + "11:" // Tail row loop: Head + "mov x9, %x[in]\n" + "add x28, x9, %x[in_stride]\n" + "add x27, x28, %x[in_stride]\n" + "mov x20, %x[width]\n" + "add x26, x27, %x[in_stride]\n" + "cmp %x[height], #0x3\n" + "add %x[in], x26, %x[in_stride]\n" + "csel x26, x26, %x[pad_row], GT\n" + "csel x27, x27, %x[pad_row], GE\n" + "cmp %x[height], #0x1\n" + "csel x28, x28, %x[pad_row], GT\n" + "cmp x20, #0x18\n" + "mov x21, %x[out]\n" + "sub %x[height], %x[height], #0x4\n" + "blt 13f\n" + "12:" // Tail row loop: Unroll column loop + "ldr q22, [x9], #0x10\n" + "ldr q18, [x28], #0x10\n" + "sub x20, x20, #0x18\n" + "cmp x20, #0x18\n" + "ldr q17, [x27], #0x10\n" + "ldr q16, [x26], #0x10\n" + "zip1 v19.4s, v22.4s, v17.4s\n" + "zip1 v21.4s, v18.4s, v16.4s\n" + "ldr q24, [x9], #0x10\n" + "ldr q20, [x28], #0x10\n" + "zip2 v10.4s, v22.4s, v17.4s\n" + "zip2 v2.4s, v18.4s, v16.4s\n" + "ldr q17, [x27], #0x10\n" + "ldr q16, [x26], #0x10\n" + "zip1 v22.4s, v24.4s, v17.4s\n" + "zip1 v4.4s, v20.4s, v16.4s\n" + "ldr q23, [x9], #0x10\n" + "ldr q18, [x28], #0x10\n" + "zip2 v29.4s, v24.4s, v17.4s\n" + "zip2 v1.4s, v20.4s, v16.4s\n" + "ldr q17, [x27], #0x10\n" + "ldr q16, [x26], #0x10\n" + "zip1 v30.4s, v23.4s, v17.4s\n" + "zip1 v31.4s, v18.4s, v16.4s\n" + "ldr q24, [x9], #0x10\n" + "ldr q20, [x28], #0x10\n" + "zip2 v23.4s, v23.4s, v17.4s\n" + "zip2 v28.4s, v18.4s, v16.4s\n" + "ldr q17, [x27], #0x10\n" + "ldr q16, [x26], #0x10\n" + "zip1 v25.4s, v24.4s, v17.4s\n" + "zip1 v26.4s, v20.4s, v16.4s\n" + "ldr q14, [x9], #0x10\n" + "ldr q18, [x28], #0x10\n" + "zip2 v24.4s, v24.4s, v17.4s\n" + "zip2 v15.4s, v20.4s, v16.4s\n" + "ldr q17, [x27], #0x10\n" + "ldr q16, [x26], #0x10\n" + "zip1 v12.4s, v14.4s, v17.4s\n" + "zip1 v13.4s, v18.4s, v16.4s\n" + "ldr q7, [x9], #0x10\n" + "ldr q3, [x28], #0x10\n" + "zip2 v0.4s, v14.4s, v17.4s\n" + "zip2 v9.4s, v18.4s, v16.4s\n" + "ldr q17, [x27], #0x10\n" + "ldr q16, [x26], #0x10\n" + "zip1 v14.4s, v7.4s, v17.4s\n" + "zip1 v8.4s, v3.4s, v16.4s\n" + "zip2 v7.4s, v7.4s, v17.4s\n" + "zip2 v11.4s, v3.4s, v16.4s\n" + "zip1 v18.4s, v19.4s, v21.4s\n" + "zip1 v6.4s, v10.4s, v2.4s\n" + "zip1 v5.4s, v22.4s, v4.4s\n" + "zip1 v16.4s, v29.4s, v1.4s\n" + "zip1 v27.4s, v30.4s, v31.4s\n" + "zip1 v3.4s, v23.4s, v28.4s\n" + "zip1 v17.4s, v25.4s, v26.4s\n" + "zip1 v20.4s, v24.4s, v15.4s\n" + ".inst 0x0ea16a52 // bfcvtn v18.4h, v18.4s\n" + "zip2 v19.4s, v19.4s, v21.4s\n" + "zip1 v21.4s, v12.4s, v13.4s\n" + ".inst 0x0ea168c6 // bfcvtn v6.4h, v6.4s\n" + "zip2 v10.4s, v10.4s, v2.4s\n" + "zip1 v2.4s, v0.4s, v9.4s\n" + ".inst 0x0ea168a5 // bfcvtn v5.4h, v5.4s\n" + "zip2 v4.4s, v22.4s, v4.4s\n" + "zip1 v22.4s, v14.4s, v8.4s\n" + ".inst 0x0ea16a10 // bfcvtn v16.4h, v16.4s\n" + "zip2 v1.4s, v29.4s, v1.4s\n" + "zip1 v29.4s, v7.4s, v11.4s\n" + ".inst 0x0ea16b7b // bfcvtn v27.4h, v27.4s\n" + "zip2 v30.4s, v30.4s, v31.4s\n" + ".inst 0x0ea1687f // bfcvtn v31.4h, v3.4s\n" + "zip2 v23.4s, v23.4s, v28.4s\n" + ".inst 0x0ea16a23 // bfcvtn v3.4h, v17.4s\n" + "zip2 v28.4s, v25.4s, v26.4s\n" + ".inst 0x0ea16a9a // bfcvtn v26.4h, v20.4s\n" + "zip2 v25.4s, v24.4s, v15.4s\n" + ".inst 0x0ea16ab8 // bfcvtn v24.4h, v21.4s\n" + "zip2 v12.4s, v12.4s, v13.4s\n" + ".inst 0x0ea16855 // bfcvtn v21.4h, v2.4s\n" + "zip2 v13.4s, v0.4s, v9.4s\n" + ".inst 0x0ea16ac2 // bfcvtn v2.4h, v22.4s\n" + "zip2 v0.4s, v14.4s, v8.4s\n" + ".inst 0x0ea16ba9 // bfcvtn v9.4h, v29.4s\n" + "zip2 v17.4s, v7.4s, v11.4s\n" + ".inst 0x4ea16a72 // bfcvtn2 v18.8h, v19.4s\n" + ".inst 0x4ea16946 // bfcvtn2 v6.8h, v10.4s\n" + "str q18, [x21, #0x0]\n" + ".inst 0x4ea16885 // bfcvtn2 v5.8h, v4.4s\n" + ".inst 0x4ea16830 // bfcvtn2 v16.8h, v1.4s\n" + "str q6, [x21, #0x10]\n" + ".inst 0x4ea16bdb // bfcvtn2 v27.8h, v30.4s\n" + ".inst 0x4ea16aff // bfcvtn2 v31.8h, v23.4s\n" + "str q5, [x21, #0x20]\n" + "str q16, [x21, #0x30]\n" + ".inst 0x4ea16b83 // bfcvtn2 v3.8h, v28.4s\n" + ".inst 0x4ea16b3a // bfcvtn2 v26.8h, v25.4s\n" + "str q27, [x21, #0x40]\n" + ".inst 0x4ea16998 // bfcvtn2 v24.8h, v12.4s\n" + ".inst 0x4ea169b5 // bfcvtn2 v21.8h, v13.4s\n" + "str q31, [x21, #0x50]\n" + "add x21, x21, %x[out_stride]\n" + ".inst 0x4ea16802 // bfcvtn2 v2.8h, v0.4s\n" + ".inst 0x4ea16a29 // bfcvtn2 v9.8h, v17.4s\n" + "str q3, [x21, #0x0]\n" + "str q26, [x21, #0x10]\n" + "str q24, [x21, #0x20]\n" + "str q21, [x21, #0x30]\n" + "str q2, [x21, #0x40]\n" + "str q9, [x21, #0x50]\n" + "add x21, x21, %x[out_stride]\n" + "bge 12b\n" + "13:" // Tail row loop: Unroll column loop skip + "cmp x20, #0xc\n" + "blt 15f\n" + "14:" // Tail row loop: Column loop + "ldr q19, [x9], #0x10\n" + "ldr q18, [x28], #0x10\n" + "sub x20, x20, #0xc\n" + "cmp x20, #0xc\n" + "ldr q17, [x27], #0x10\n" + "ldr q16, [x26], #0x10\n" + "zip1 v26.4s, v19.4s, v17.4s\n" + "zip1 v25.4s, v18.4s, v16.4s\n" + "ldr q21, [x9], #0x10\n" + "ldr q20, [x28], #0x10\n" + "zip2 v24.4s, v19.4s, v17.4s\n" + "zip2 v23.4s, v18.4s, v16.4s\n" + "ldr q17, [x27], #0x10\n" + "ldr q16, [x26], #0x10\n" + "zip1 v2.4s, v21.4s, v17.4s\n" + "zip1 v22.4s, v20.4s, v16.4s\n" + "ldr q19, [x9], #0x10\n" + "ldr q18, [x28], #0x10\n" + "zip2 v1.4s, v21.4s, v17.4s\n" + "zip2 v0.4s, v20.4s, v16.4s\n" + "ldr q17, [x27], #0x10\n" + "ldr q16, [x26], #0x10\n" + "zip1 v31.4s, v19.4s, v17.4s\n" + "zip1 v30.4s, v18.4s, v16.4s\n" + "zip2 v29.4s, v19.4s, v17.4s\n" + "zip2 v28.4s, v18.4s, v16.4s\n" + "zip1 v21.4s, v26.4s, v25.4s\n" + "zip1 v20.4s, v24.4s, v23.4s\n" + "zip1 v19.4s, v2.4s, v22.4s\n" + "zip1 v18.4s, v1.4s, v0.4s\n" + "zip1 v17.4s, v31.4s, v30.4s\n" + "zip1 v16.4s, v29.4s, v28.4s\n" + ".inst 0x0ea16abb // bfcvtn v27.4h, v21.4s\n" + "zip2 v26.4s, v26.4s, v25.4s\n" + ".inst 0x0ea16a99 // bfcvtn v25.4h, v20.4s\n" + "zip2 v24.4s, v24.4s, v23.4s\n" + ".inst 0x0ea16a77 // bfcvtn v23.4h, v19.4s\n" + "zip2 v22.4s, v2.4s, v22.4s\n" + ".inst 0x0ea16a55 // bfcvtn v21.4h, v18.4s\n" + "zip2 v20.4s, v1.4s, v0.4s\n" + ".inst 0x0ea16a33 // bfcvtn v19.4h, v17.4s\n" + "zip2 v18.4s, v31.4s, v30.4s\n" + ".inst 0x0ea16a11 // bfcvtn v17.4h, v16.4s\n" + "zip2 v16.4s, v29.4s, v28.4s\n" + ".inst 0x4ea16b5b // bfcvtn2 v27.8h, v26.4s\n" + ".inst 0x4ea16b19 // bfcvtn2 v25.8h, v24.4s\n" + "str q27, [x21, #0x0]\n" + ".inst 0x4ea16ad7 // bfcvtn2 v23.8h, v22.4s\n" + ".inst 0x4ea16a95 // bfcvtn2 v21.8h, v20.4s\n" + "str q25, [x21, #0x10]\n" + ".inst 0x4ea16a53 // bfcvtn2 v19.8h, v18.4s\n" + ".inst 0x4ea16a11 // bfcvtn2 v17.8h, v16.4s\n" + "str q23, [x21, #0x20]\n" + "str q21, [x21, #0x30]\n" + "str q19, [x21, #0x40]\n" + "str q17, [x21, #0x50]\n" + "add x21, x21, %x[out_stride]\n" + "bge 14b\n" + "15:" // Tail row loop: Column loop skip + "cmp x20, #0x4\n" + "blt 17f\n" + "16:" // Tail row loop: width 4 loop: loop + "ldr q20, [x9], #0x10\n" + "ldr q19, [x28], #0x10\n" + "sub x20, x20, #0x4\n" + "cmp x20, #0x4\n" + "ldr q17, [x27], #0x10\n" + "ldr q16, [x26], #0x10\n" + "zip1 v22.4s, v20.4s, v17.4s\n" + "zip1 v18.4s, v19.4s, v16.4s\n" + "zip2 v21.4s, v20.4s, v17.4s\n" + "zip2 v20.4s, v19.4s, v16.4s\n" + "zip1 v17.4s, v22.4s, v18.4s\n" + "zip1 v16.4s, v21.4s, v20.4s\n" + ".inst 0x0ea16a33 // bfcvtn v19.4h, v17.4s\n" + "zip2 v18.4s, v22.4s, v18.4s\n" + ".inst 0x0ea16a11 // bfcvtn v17.4h, v16.4s\n" + "zip2 v16.4s, v21.4s, v20.4s\n" + ".inst 0x4ea16a53 // bfcvtn2 v19.8h, v18.4s\n" + ".inst 0x4ea16a11 // bfcvtn2 v17.8h, v16.4s\n" + "str q19, [x21, #0x0]\n" + "str q17, [x21, #0x10]\n" + "add x21, x21, #0x20\n" + "bge 16b\n" + "17:" // Tail row loop: width 4 loop: skip + "cmp x20, #0x1\n" + "blt 19f\n" + "18:" // Tail row loop: width 1 loop: loop + "ldr s19, [x9], #0x4\n" + "ldr s18, [x28], #0x4\n" + "sub x20, x20, #0x1\n" + "cmp x20, #0x1\n" + "ldr s17, [x27], #0x4\n" + "ldr s16, [x26], #0x4\n" + "zip1 v17.4s, v19.4s, v17.4s\n" + "zip1 v16.4s, v18.4s, v16.4s\n" + "zip1 v16.4s, v17.4s, v16.4s\n" + ".inst 0x0ea16a10 // bfcvtn v16.4h, v16.4s\n" + "str d16, [x21, #0x0]\n" + "add x21, x21, #0x8\n" + "bge 18b\n" + "19:" // Tail row loop: width 1 loop: skip + "cmp %x[height], #0x1\n" + "add %x[out], %x[out], #0x60\n" + "bge 11b\n" + "20:" // Done + : [height] "+&r" (height), [in] "+&r" (in), [out] "+&r" (out) + : [in_stride] "r" (in_stride), [out_stride] "r" (out_stride), [pad_row] "r" (pad_row), [width] "r" (width) + : "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "x9", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27", "x28" + ); +} + +} // anonymous namespace +template<> +void Transform<12, 4, true, VLType::None>( + bfloat16 *out, const float *in, int stride, int x0, int xmax, int k0, int kmax) +{ + a64_transpose_interleave_12_2x4_fp32bf16( + out, + in + k0 * stride + x0, + (xmax-x0), + stride * sizeof(float), + (kmax-k0) + ); +} + + +#endif // defined(__aarch64__) diff --git a/src/core/NEON/kernels/arm_gemm/transforms/a64_transpose_interleave_12_s8s16.hpp b/src/core/NEON/kernels/arm_gemm/transforms/a64_transpose_interleave_12_s8s16.hpp new file mode 100644 index 0000000000..4d9d5e7f43 --- /dev/null +++ b/src/core/NEON/kernels/arm_gemm/transforms/a64_transpose_interleave_12_s8s16.hpp @@ -0,0 +1,274 @@ +/* + * Copyright (c) 2021, 2023 Arm Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ + +#pragma once + +#if defined(__aarch64__) + +namespace { + +void a64_transpose_interleave_12_s8s16(int16_t *out, const int8_t *in, size_t width, size_t in_stride, size_t height) +{ + size_t out_stride = 12 * height * sizeof(int16_t); + + __asm__ __volatile__( + "cmp %x[height], #0x4\n" + "blt 10f\n" + "1:" // Main row loop: Head + "mov x25, %x[in]\n" + "mov x24, %x[width]\n" + "add x23, x25, %x[in_stride]\n" + "add x22, x23, %x[in_stride]\n" + "add x20, x22, %x[in_stride]\n" + "cmp x24, #0x18\n" + "add %x[in], x20, %x[in_stride]\n" + "mov x21, %x[out]\n" + "sub %x[height], %x[height], #0x4\n" + "blt 3f\n" + "2:" // Main row loop: Unroll column loop + "ldr q3, [x25], #0x10\n" + "ldr q21, [x23], #0x10\n" + "sshll2 v20.8h, v3.16b, #0x0\n" + "sshll v2.8h, v21.8b, #0x0\n" + "ldr q1, [x22], #0x10\n" + "ldr q19, [x20], #0x10\n" + "sshll2 v18.8h, v1.16b, #0x0\n" + "sshll v0.8h, v19.8b, #0x0\n" + "ldr d17, [x25], #0x8\n" + "ldr d16, [x22], #0x8\n" + "sshll v31.8h, v17.8b, #0x0\n" + "sshll v30.8h, v16.8b, #0x0\n" + "ldr d29, [x23], #0x8\n" + "ldr d28, [x20], #0x8\n" + "sshll2 v27.8h, v21.16b, #0x0\n" + "sshll2 v26.8h, v19.16b, #0x0\n" + "dup v25.2d, v20.d[0]\n" + "dup v24.2d, v2.d[1]\n" + "sub x24, x24, #0x18\n" + "cmp x24, #0x18\n" + "dup v23.2d, v18.d[0]\n" + "dup v22.2d, v0.d[1]\n" + "dup v21.2d, v20.d[1]\n" + "dup v20.2d, v31.d[1]\n" + "dup v19.2d, v18.d[1]\n" + "dup v18.2d, v30.d[1]\n" + "sshll v17.8h, v3.8b, #0x0\n" + "sshll v16.8h, v1.8b, #0x0\n" + "str q17, [x21, #0x0]\n" + "mov v25.d[1], v2.d[0]\n" + "mov v24.d[1], v27.d[0]\n" + "str q25, [x21, #0x10]\n" + "mov v23.d[1], v0.d[0]\n" + "mov v22.d[1], v26.d[0]\n" + "str q24, [x21, #0x20]\n" + "str q16, [x21, #0x30]\n" + "sshll v17.8h, v29.8b, #0x0\n" + "sshll v16.8h, v28.8b, #0x0\n" + "str q23, [x21, #0x40]\n" + "mov v21.d[1], v31.d[0]\n" + "mov v20.d[1], v27.d[1]\n" + "str q22, [x21, #0x50]\n" + "add x21, x21, %x[out_stride]\n" + "mov v19.d[1], v30.d[0]\n" + "mov v18.d[1], v26.d[1]\n" + "str q21, [x21, #0x0]\n" + "str q20, [x21, #0x10]\n" + "str q17, [x21, #0x20]\n" + "str q19, [x21, #0x30]\n" + "str q18, [x21, #0x40]\n" + "str q16, [x21, #0x50]\n" + "add x21, x21, %x[out_stride]\n" + "bge 2b\n" + "3:" // Main row loop: Unroll column loop skip + "cmp x24, #0xc\n" + "blt 5f\n" + "4:" // Main row loop: Column loop + "ldr d19, [x23], #0x8\n" + "ldr d18, [x20], #0x8\n" + "sub x24, x24, #0xc\n" + "cmp x24, #0xc\n" + "ld1 { v19.s }[2], [x23], #0x4\n" + "ld1 { v18.s }[2], [x20], #0x4\n" + "sshll v25.8h, v19.8b, #0x0\n" + "sshll v24.8h, v18.8b, #0x0\n" + "ldr d17, [x25], #0x8\n" + "ldr d16, [x22], #0x8\n" + "sshll2 v23.8h, v19.16b, #0x0\n" + "sshll2 v22.8h, v18.16b, #0x0\n" + "ld1 { v17.s }[2], [x25], #0x4\n" + "ld1 { v16.s }[2], [x22], #0x4\n" + "sshll2 v21.8h, v17.16b, #0x0\n" + "sshll2 v20.8h, v16.16b, #0x0\n" + "dup v19.2d, v25.d[1]\n" + "dup v18.2d, v24.d[1]\n" + "sshll v17.8h, v17.8b, #0x0\n" + "sshll v16.8h, v16.8b, #0x0\n" + "str q17, [x21, #0x0]\n" + "mov v21.d[1], v25.d[0]\n" + "mov v19.d[1], v23.d[0]\n" + "str q21, [x21, #0x10]\n" + "mov v20.d[1], v24.d[0]\n" + "mov v18.d[1], v22.d[0]\n" + "str q19, [x21, #0x20]\n" + "str q16, [x21, #0x30]\n" + "str q20, [x21, #0x40]\n" + "str q18, [x21, #0x50]\n" + "add x21, x21, %x[out_stride]\n" + "bge 4b\n" + "5:" // Main row loop: Column loop skip + "cmp x24, #0x4\n" + "blt 7f\n" + "6:" // Main row loop: width 4 loop: loop + "ldr s19, [x25], #0x4\n" + "ldr s18, [x23], #0x4\n" + "sub x24, x24, #0x4\n" + "cmp x24, #0x4\n" + "ldr s17, [x22], #0x4\n" + "ldr s16, [x20], #0x4\n" + "sshll v19.8h, v19.8b, #0x0\n" + "sshll v18.8h, v18.8b, #0x0\n" + "sshll v17.8h, v17.8b, #0x0\n" + "sshll v16.8h, v16.8b, #0x0\n" + "str d19, [x21, #0x0]\n" + "str d18, [x21, #0x18]\n" + "str d17, [x21, #0x30]\n" + "str d16, [x21, #0x48]\n" + "add x21, x21, #0x8\n" + "bge 6b\n" + "7:" // Main row loop: width 4 loop: skip + "cmp x24, #0x1\n" + "blt 9f\n" + "8:" // Main row loop: width 1 loop: loop + "ldr b19, [x25], #0x1\n" + "ldr b18, [x23], #0x1\n" + "sub x24, x24, #0x1\n" + "cmp x24, #0x1\n" + "ldr b17, [x22], #0x1\n" + "ldr b16, [x20], #0x1\n" + "sshll v19.8h, v19.8b, #0x0\n" + "sshll v18.8h, v18.8b, #0x0\n" + "sshll v17.8h, v17.8b, #0x0\n" + "sshll v16.8h, v16.8b, #0x0\n" + "str h19, [x21, #0x0]\n" + "str h18, [x21, #0x18]\n" + "str h17, [x21, #0x30]\n" + "str h16, [x21, #0x48]\n" + "add x21, x21, #0x2\n" + "bge 8b\n" + "9:" // Main row loop: width 1 loop: skip + "cmp %x[height], #0x4\n" + "add %x[out], %x[out], #0x60\n" + "bge 1b\n" + "cbz %x[height], 20f\n" + "10:" // Main loop skip + "11:" // Tail row loop: Head + "mov x20, %x[width]\n" + "mov x25, %x[in]\n" + "cmp x20, #0x18\n" + "add %x[in], x25, %x[in_stride]\n" + "mov x21, %x[out]\n" + "sub %x[height], %x[height], #0x1\n" + "blt 13f\n" + "12:" // Tail row loop: Unroll column loop + "ldr q20, [x25], #0x10\n" + "ldr d16, [x25], #0x8\n" + "sshll2 v19.8h, v20.16b, #0x0\n" + "sshll v18.8h, v16.8b, #0x0\n" + "dup v17.2d, v19.d[1]\n" + "sub x20, x20, #0x18\n" + "sshll v16.8h, v20.8b, #0x0\n" + "str q16, [x21, #0x0]\n" + "dup v16.2d, v19.d[0]\n" + "str d16, [x21, #0x10]\n" + "add x21, x21, %x[out_stride]\n" + "cmp x20, #0x18\n" + "mov v17.d[1], v18.d[0]\n" + "dup v16.2d, v18.d[1]\n" + "str q17, [x21, #0x0]\n" + "str d16, [x21, #0x10]\n" + "add x21, x21, %x[out_stride]\n" + "bge 12b\n" + "13:" // Tail row loop: Unroll column loop skip + "cmp x20, #0xc\n" + "blt 15f\n" + "14:" // Tail row loop: Column loop + "ldr d16, [x25], #0x8\n" + "ld1 { v16.s }[2], [x25], #0x4\n" + "sub x20, x20, #0xc\n" + "cmp x20, #0xc\n" + "sshll v17.8h, v16.8b, #0x0\n" + "sshll2 v16.8h, v16.16b, #0x0\n" + "str q17, [x21, #0x0]\n" + "str d16, [x21, #0x10]\n" + "add x21, x21, %x[out_stride]\n" + "bge 14b\n" + "15:" // Tail row loop: Column loop skip + "cmp x20, #0x4\n" + "blt 17f\n" + "16:" // Tail row loop: width 4 loop: loop + "ldr s16, [x25], #0x4\n" + "sub x20, x20, #0x4\n" + "cmp x20, #0x4\n" + "sshll v16.8h, v16.8b, #0x0\n" + "str d16, [x21, #0x0]\n" + "add x21, x21, #0x8\n" + "bge 16b\n" + "17:" // Tail row loop: width 4 loop: skip + "cmp x20, #0x1\n" + "blt 19f\n" + "18:" // Tail row loop: width 1 loop: loop + "ldr b16, [x25], #0x1\n" + "sub x20, x20, #0x1\n" + "cmp x20, #0x1\n" + "sshll v16.8h, v16.8b, #0x0\n" + "str h16, [x21, #0x0]\n" + "add x21, x21, #0x2\n" + "bge 18b\n" + "19:" // Tail row loop: width 1 loop: skip + "cmp %x[height], #0x1\n" + "add %x[out], %x[out], #0x18\n" + "bge 11b\n" + "20:" // Done + : [height] "+&r" (height), [in] "+&r" (in), [out] "+&r" (out) + : [in_stride] "r" (in_stride), [out_stride] "r" (out_stride), [width] "r" (width) + : "cc", "memory", "v0", "v1", "v2", "v3", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "x20", "x21", "x22", "x23", "x24", "x25" + ); +} + +} // anonymous namespace +template<> +void Transform<12, 1, true, VLType::None>( + int16_t *out, const int8_t *in, int stride, int x0, int xmax, int k0, int kmax) +{ + a64_transpose_interleave_12_s8s16( + out, + in + k0 * stride + x0, + (xmax-x0), + stride * sizeof(int8_t), + (kmax-k0) + ); +} + + +#endif // defined(__aarch64__) diff --git a/src/core/NEON/kernels/arm_gemm/transforms/a64_transpose_interleave_12_u8u16.hpp b/src/core/NEON/kernels/arm_gemm/transforms/a64_transpose_interleave_12_u8u16.hpp new file mode 100644 index 0000000000..b0cd7e4ef7 --- /dev/null +++ b/src/core/NEON/kernels/arm_gemm/transforms/a64_transpose_interleave_12_u8u16.hpp @@ -0,0 +1,274 @@ +/* + * Copyright (c) 2021, 2023 Arm Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ + +#pragma once + +#if defined(__aarch64__) + +namespace { + +void a64_transpose_interleave_12_u8u16(uint16_t *out, const uint8_t *in, size_t width, size_t in_stride, size_t height) +{ + size_t out_stride = 12 * height * sizeof(uint16_t); + + __asm__ __volatile__( + "cmp %x[height], #0x4\n" + "blt 10f\n" + "1:" // Main row loop: Head + "mov x25, %x[in]\n" + "mov x24, %x[width]\n" + "add x23, x25, %x[in_stride]\n" + "add x22, x23, %x[in_stride]\n" + "add x20, x22, %x[in_stride]\n" + "cmp x24, #0x18\n" + "add %x[in], x20, %x[in_stride]\n" + "mov x21, %x[out]\n" + "sub %x[height], %x[height], #0x4\n" + "blt 3f\n" + "2:" // Main row loop: Unroll column loop + "ldr q3, [x25], #0x10\n" + "ldr q21, [x23], #0x10\n" + "ushll2 v20.8h, v3.16b, #0x0\n" + "ushll v2.8h, v21.8b, #0x0\n" + "ldr q1, [x22], #0x10\n" + "ldr q19, [x20], #0x10\n" + "ushll2 v18.8h, v1.16b, #0x0\n" + "ushll v0.8h, v19.8b, #0x0\n" + "ldr d17, [x25], #0x8\n" + "ldr d16, [x22], #0x8\n" + "ushll v31.8h, v17.8b, #0x0\n" + "ushll v30.8h, v16.8b, #0x0\n" + "ldr d29, [x23], #0x8\n" + "ldr d28, [x20], #0x8\n" + "ushll2 v27.8h, v21.16b, #0x0\n" + "ushll2 v26.8h, v19.16b, #0x0\n" + "dup v25.2d, v20.d[0]\n" + "dup v24.2d, v2.d[1]\n" + "sub x24, x24, #0x18\n" + "cmp x24, #0x18\n" + "dup v23.2d, v18.d[0]\n" + "dup v22.2d, v0.d[1]\n" + "dup v21.2d, v20.d[1]\n" + "dup v20.2d, v31.d[1]\n" + "dup v19.2d, v18.d[1]\n" + "dup v18.2d, v30.d[1]\n" + "ushll v17.8h, v3.8b, #0x0\n" + "ushll v16.8h, v1.8b, #0x0\n" + "str q17, [x21, #0x0]\n" + "mov v25.d[1], v2.d[0]\n" + "mov v24.d[1], v27.d[0]\n" + "str q25, [x21, #0x10]\n" + "mov v23.d[1], v0.d[0]\n" + "mov v22.d[1], v26.d[0]\n" + "str q24, [x21, #0x20]\n" + "str q16, [x21, #0x30]\n" + "ushll v17.8h, v29.8b, #0x0\n" + "ushll v16.8h, v28.8b, #0x0\n" + "str q23, [x21, #0x40]\n" + "mov v21.d[1], v31.d[0]\n" + "mov v20.d[1], v27.d[1]\n" + "str q22, [x21, #0x50]\n" + "add x21, x21, %x[out_stride]\n" + "mov v19.d[1], v30.d[0]\n" + "mov v18.d[1], v26.d[1]\n" + "str q21, [x21, #0x0]\n" + "str q20, [x21, #0x10]\n" + "str q17, [x21, #0x20]\n" + "str q19, [x21, #0x30]\n" + "str q18, [x21, #0x40]\n" + "str q16, [x21, #0x50]\n" + "add x21, x21, %x[out_stride]\n" + "bge 2b\n" + "3:" // Main row loop: Unroll column loop skip + "cmp x24, #0xc\n" + "blt 5f\n" + "4:" // Main row loop: Column loop + "ldr d19, [x23], #0x8\n" + "ldr d18, [x20], #0x8\n" + "sub x24, x24, #0xc\n" + "cmp x24, #0xc\n" + "ld1 { v19.s }[2], [x23], #0x4\n" + "ld1 { v18.s }[2], [x20], #0x4\n" + "ushll v25.8h, v19.8b, #0x0\n" + "ushll v24.8h, v18.8b, #0x0\n" + "ldr d17, [x25], #0x8\n" + "ldr d16, [x22], #0x8\n" + "ushll2 v23.8h, v19.16b, #0x0\n" + "ushll2 v22.8h, v18.16b, #0x0\n" + "ld1 { v17.s }[2], [x25], #0x4\n" + "ld1 { v16.s }[2], [x22], #0x4\n" + "ushll2 v21.8h, v17.16b, #0x0\n" + "ushll2 v20.8h, v16.16b, #0x0\n" + "dup v19.2d, v25.d[1]\n" + "dup v18.2d, v24.d[1]\n" + "ushll v17.8h, v17.8b, #0x0\n" + "ushll v16.8h, v16.8b, #0x0\n" + "str q17, [x21, #0x0]\n" + "mov v21.d[1], v25.d[0]\n" + "mov v19.d[1], v23.d[0]\n" + "str q21, [x21, #0x10]\n" + "mov v20.d[1], v24.d[0]\n" + "mov v18.d[1], v22.d[0]\n" + "str q19, [x21, #0x20]\n" + "str q16, [x21, #0x30]\n" + "str q20, [x21, #0x40]\n" + "str q18, [x21, #0x50]\n" + "add x21, x21, %x[out_stride]\n" + "bge 4b\n" + "5:" // Main row loop: Column loop skip + "cmp x24, #0x4\n" + "blt 7f\n" + "6:" // Main row loop: width 4 loop: loop + "ldr s19, [x25], #0x4\n" + "ldr s18, [x23], #0x4\n" + "sub x24, x24, #0x4\n" + "cmp x24, #0x4\n" + "ldr s17, [x22], #0x4\n" + "ldr s16, [x20], #0x4\n" + "ushll v19.8h, v19.8b, #0x0\n" + "ushll v18.8h, v18.8b, #0x0\n" + "ushll v17.8h, v17.8b, #0x0\n" + "ushll v16.8h, v16.8b, #0x0\n" + "str d19, [x21, #0x0]\n" + "str d18, [x21, #0x18]\n" + "str d17, [x21, #0x30]\n" + "str d16, [x21, #0x48]\n" + "add x21, x21, #0x8\n" + "bge 6b\n" + "7:" // Main row loop: width 4 loop: skip + "cmp x24, #0x1\n" + "blt 9f\n" + "8:" // Main row loop: width 1 loop: loop + "ldr b19, [x25], #0x1\n" + "ldr b18, [x23], #0x1\n" + "sub x24, x24, #0x1\n" + "cmp x24, #0x1\n" + "ldr b17, [x22], #0x1\n" + "ldr b16, [x20], #0x1\n" + "ushll v19.8h, v19.8b, #0x0\n" + "ushll v18.8h, v18.8b, #0x0\n" + "ushll v17.8h, v17.8b, #0x0\n" + "ushll v16.8h, v16.8b, #0x0\n" + "str h19, [x21, #0x0]\n" + "str h18, [x21, #0x18]\n" + "str h17, [x21, #0x30]\n" + "str h16, [x21, #0x48]\n" + "add x21, x21, #0x2\n" + "bge 8b\n" + "9:" // Main row loop: width 1 loop: skip + "cmp %x[height], #0x4\n" + "add %x[out], %x[out], #0x60\n" + "bge 1b\n" + "cbz %x[height], 20f\n" + "10:" // Main loop skip + "11:" // Tail row loop: Head + "mov x20, %x[width]\n" + "mov x25, %x[in]\n" + "cmp x20, #0x18\n" + "add %x[in], x25, %x[in_stride]\n" + "mov x21, %x[out]\n" + "sub %x[height], %x[height], #0x1\n" + "blt 13f\n" + "12:" // Tail row loop: Unroll column loop + "ldr q20, [x25], #0x10\n" + "ldr d16, [x25], #0x8\n" + "ushll2 v19.8h, v20.16b, #0x0\n" + "ushll v18.8h, v16.8b, #0x0\n" + "dup v17.2d, v19.d[1]\n" + "sub x20, x20, #0x18\n" + "ushll v16.8h, v20.8b, #0x0\n" + "str q16, [x21, #0x0]\n" + "dup v16.2d, v19.d[0]\n" + "str d16, [x21, #0x10]\n" + "add x21, x21, %x[out_stride]\n" + "cmp x20, #0x18\n" + "mov v17.d[1], v18.d[0]\n" + "dup v16.2d, v18.d[1]\n" + "str q17, [x21, #0x0]\n" + "str d16, [x21, #0x10]\n" + "add x21, x21, %x[out_stride]\n" + "bge 12b\n" + "13:" // Tail row loop: Unroll column loop skip + "cmp x20, #0xc\n" + "blt 15f\n" + "14:" // Tail row loop: Column loop + "ldr d16, [x25], #0x8\n" + "ld1 { v16.s }[2], [x25], #0x4\n" + "sub x20, x20, #0xc\n" + "cmp x20, #0xc\n" + "ushll v17.8h, v16.8b, #0x0\n" + "ushll2 v16.8h, v16.16b, #0x0\n" + "str q17, [x21, #0x0]\n" + "str d16, [x21, #0x10]\n" + "add x21, x21, %x[out_stride]\n" + "bge 14b\n" + "15:" // Tail row loop: Column loop skip + "cmp x20, #0x4\n" + "blt 17f\n" + "16:" // Tail row loop: width 4 loop: loop + "ldr s16, [x25], #0x4\n" + "sub x20, x20, #0x4\n" + "cmp x20, #0x4\n" + "ushll v16.8h, v16.8b, #0x0\n" + "str d16, [x21, #0x0]\n" + "add x21, x21, #0x8\n" + "bge 16b\n" + "17:" // Tail row loop: width 4 loop: skip + "cmp x20, #0x1\n" + "blt 19f\n" + "18:" // Tail row loop: width 1 loop: loop + "ldr b16, [x25], #0x1\n" + "sub x20, x20, #0x1\n" + "cmp x20, #0x1\n" + "ushll v16.8h, v16.8b, #0x0\n" + "str h16, [x21, #0x0]\n" + "add x21, x21, #0x2\n" + "bge 18b\n" + "19:" // Tail row loop: width 1 loop: skip + "cmp %x[height], #0x1\n" + "add %x[out], %x[out], #0x18\n" + "bge 11b\n" + "20:" // Done + : [height] "+&r" (height), [in] "+&r" (in), [out] "+&r" (out) + : [in_stride] "r" (in_stride), [out_stride] "r" (out_stride), [width] "r" (width) + : "cc", "memory", "v0", "v1", "v2", "v3", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "x20", "x21", "x22", "x23", "x24", "x25" + ); +} + +} // anonymous namespace +template<> +void Transform<12, 1, true, VLType::None>( + uint16_t *out, const uint8_t *in, int stride, int x0, int xmax, int k0, int kmax) +{ + a64_transpose_interleave_12_u8u16( + out, + in + k0 * stride + x0, + (xmax-x0), + stride * sizeof(uint8_t), + (kmax-k0) + ); +} + + +#endif // defined(__aarch64__) diff --git a/src/core/NEON/kernels/arm_gemm/transforms/a64_transpose_interleave_12way_16bit.hpp b/src/core/NEON/kernels/arm_gemm/transforms/a64_transpose_interleave_12way_16bit.hpp deleted file mode 100644 index ec54ce00ea..0000000000 --- a/src/core/NEON/kernels/arm_gemm/transforms/a64_transpose_interleave_12way_16bit.hpp +++ /dev/null @@ -1,145 +0,0 @@ -/* - * Copyright (c) 2017-2018 ARM Limited. - * - * SPDX-License-Identifier: MIT - * - * Permission is hereby granted, free of charge, to any person obtaining a copy - * of this software and associated documentation files (the "Software"), to - * deal in the Software without restriction, including without limitation the - * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or - * sell copies of the Software, and to permit persons to whom the Software is - * furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in all - * copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE - * SOFTWARE. - */ -#pragma once - -#ifdef __aarch64__ - -#include "transpose_interleave_common.hpp" - -// Generic unblocked transposed 6x32-bit sized specialisation -template <> -template <typename T> -inline void TransformImpl<6, 1, true, 4, 4, false>::Transform( - T* out, const T* const in, const int stride, - const int x0, const int xmax, const int k0, const int kmax -) { - // Redirect to a 12 x uint16_t specialisation - TransformImpl<12, 1, true, 2, 2, false>::Transform( - reinterpret_cast<uint16_t *>(out), - reinterpret_cast<const uint16_t *>(in), - stride*2, x0*2, xmax*2, k0, kmax - ); -} - -// Generic 12x16-bit sized specialisation -template <> -template <typename T> -inline void TransformImpl<12, 1, true, 2, 2, false>::Transform( - T* out, const T* const in, const int stride, - const int x0, const int xmax, const int k0, const int kmax -) { - // Redirect to a uint16_t specialisation - Transform( - reinterpret_cast<uint16_t *>(out), - reinterpret_cast<const uint16_t *>(in), - stride, x0, xmax, k0, kmax - ); -} - -// Specialised 12 x uint16_t version -template <> -inline void TransposeInterleaveCommon<12, uint16_t, uint16_t>::moveblock_1x1(const uint16_t *&in0, uint16_t *out) { - __asm volatile ( - "LDR q0, [%[in0]]\n" - "STR q0, [%[out]]\n" - "LDR d1, [%[in0], #0x10]\n" - "STR d1, [%[out], #0x10]\n" - "ADD %x[in0], %x[in0], #0x18\n" - ASM_PREFETCH("[%[in0], #192]") - : [in0] "+r" (in0), - [out] "+r" (out) - : - : "v0", "v1", "memory" - ); -} - -template <> -inline void TransposeInterleaveCommon<12, uint16_t, uint16_t>::moveblock_1x2(const uint16_t *&in0, const uint16_t *&in1, uint16_t *out) { - __asm volatile ( - "LDR q0, [%[in0]]\n" - "LDR d1, [%[in0], #0x10]\n" - "ADD %x[in0], %x[in0], #0x18\n" - ASM_PREFETCH("[%[in0], #192]") - - "LDR x21, [%[in1]]\n" - "LDR q2, [%[in1], #0x08]\n" - "INS v1.d[1], x21\n" - "ADD %x[in1], %x[in1], #0x18\n" - "STP q0, q1, [%[out]]\n" - "STR q2, [%x[out], #0x20]\n" - ASM_PREFETCH("[%[in1], #192]") - : [in0] "+r" (in0), - [in1] "+r" (in1), - [out] "+r" (out) - : - : "x21", "v0", "v1", "v2", "memory" - ); -} - -template <> -inline void TransposeInterleaveCommon<12, uint16_t, uint16_t>::moveblock_1x4(const uint16_t *&in0, const uint16_t *&in1, const uint16_t *&in2, const uint16_t *&in3, uint16_t *out) { - __asm __volatile ( - "LDR q0, [%x[in0]], #0x10\n" - "STR q0, [%x[out]]\n" - "LDR d1, [%x[in0]], #0x08\n" - ASM_PREFETCH("[%[in0], #192]") - "STR d1, [%x[out], #0x10]\n" - - "LDR q0, [%x[in1]], #0x10\n" - "STR q0, [%x[out], #0x18]\n" - "LDR d1, [%x[in1]], #0x08\n" - ASM_PREFETCH("[%[in1], #192]") - "STR d1, [%x[out], #0x28]\n" - - "LDR q0, [%x[in2]], #0x10\n" - "STR q0, [%x[out], #0x30]\n" - "LDR d1, [%x[in2]], #0x08\n" - ASM_PREFETCH("[%[in2], #192]") - "STR d1, [%x[out], #0x40]\n" - - "LDR q0, [%x[in3]], #0x10\n" - "STR q0, [%x[out], #0x48]\n" - "LDR d1, [%x[in3]], #0x08\n" - ASM_PREFETCH("[%[in3], #192]") - "STR d1, [%x[out], #0x58]\n" - : [in0] "+r" (in0), - [in1] "+r" (in1), - [in2] "+r" (in2), - [in3] "+r" (in3), - [out] "+r" (out) - : - : "v0", "v1", "memory" - ); -} - -template <> -template <> -inline void TransformImpl<12, 1, true, 2, 2, false>::Transform( - uint16_t* out, const uint16_t* const in, const int stride, - const int x0, const int xmax, const int k0, const int kmax -) { - TransposeInterleaveCommon<12, uint16_t, uint16_t>::Transform(out, in, stride, x0, xmax, k0, kmax); -} - -#endif // __aarch64__ diff --git a/src/core/NEON/kernels/arm_gemm/transforms/a64_transpose_interleave_12way_half_to_float.hpp b/src/core/NEON/kernels/arm_gemm/transforms/a64_transpose_interleave_12way_half_to_float.hpp deleted file mode 100644 index 8992c1010d..0000000000 --- a/src/core/NEON/kernels/arm_gemm/transforms/a64_transpose_interleave_12way_half_to_float.hpp +++ /dev/null @@ -1,120 +0,0 @@ -/* - * Copyright (c) 2017-2019 ARM Limited. - * - * SPDX-License-Identifier: MIT - * - * Permission is hereby granted, free of charge, to any person obtaining a copy - * of this software and associated documentation files (the "Software"), to - * deal in the Software without restriction, including without limitation the - * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or - * sell copies of the Software, and to permit persons to whom the Software is - * furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in all - * copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE - * SOFTWARE. - */ -#pragma once - -#if defined(__aarch64__) && defined(__ARM_FP16_ARGS) - -#include "transpose_interleave_common.hpp" - -template <> -inline void TransposeInterleaveCommon<12, __fp16, float>::moveblock_1x1(const __fp16 *&in0, float *out) { - __asm __volatile ( - "LDR q0, [%[in0]], #16\n" - "FCVTL2 v1.4s, v0.8h\n" - "FCVTL v0.4s, v0.4h\n" - "STP q0, q1, [%[out]]\n" - ASM_PREFETCH("[%[in0], #192]") - "LDR d2, [%[in0]], #8\n" - "FCVTL v2.4s, v2.4h\n" - "STR q2, [%[out], #32]\n" - : [in0] "+r" (in0), [out] "+r" (out) - : - : "v0", "v1", "v2", "memory" - ); -} - -template <> -inline void TransposeInterleaveCommon<12, __fp16, float>::moveblock_1x2(const __fp16 *&in0, const __fp16 *&in1, float *out) { - __asm __volatile ( - "LDR q0, [%[in0]], #16\n" - "FCVTL2 v1.4s, v0.8h\n" - "FCVTL v0.4s, v0.4h\n" - "STP q0, q1, [%[out]]\n" - ASM_PREFETCH("[%[in0], #192]") - "LDR d2, [%[in0]], #8\n" - "FCVTL v2.4s, v2.4h\n" - "LDR q3, [%[in1]], #16\n" - "FCVTL2 v4.4s, v3.8h\n" - "FCVTL v3.4s, v3.4h\n" - "STP q2, q3, [%[out], #32]\n" - ASM_PREFETCH("[%[in1], #192]") - "LDR d5, [%[in1]], #8\n" - "FCVTL v5.4s, v5.4h\n" - "STP q4, q5, [%[out], #64]\n" - : [in0] "+r" (in0), [in1] "+r" (in1), [out] "+r" (out) - : - : "v0", "v1", "v2", "v3", "v4", "v5", "memory" - ); -} - -template <> -inline void TransposeInterleaveCommon<12, __fp16, float>::moveblock_1x4(const __fp16 *&in0, const __fp16 *&in1, const __fp16 *&in2, const __fp16 *&in3, float *out) { - __asm __volatile ( - "LDR q0, [%[in0]], #16\n" - "FCVTL2 v1.4s, v0.8h\n" - "FCVTL v0.4s, v0.4h\n" - "STP q0, q1, [%[out]]\n" - "LDR d2, [%[in0]], #8\n" - ASM_PREFETCH("[%[in0], #192]") - "FCVTL v2.4s, v2.4h\n" - "LDR q3, [%[in1]], #16\n" - "FCVTL2 v4.4s, v3.8h\n" - "FCVTL v3.4s, v3.4h\n" - "STP q2, q3, [%[out], #32]\n" - "LDR d5, [%[in1]], #8\n" - "FCVTL v5.4s, v5.4h\n" - ASM_PREFETCH("[%[in1], #192]") - "STP q4, q5, [%[out], #64]\n" - "LDR q6, [%[in2]], #16\n" - "FCVTL2 v7.4s, v6.8h\n" - "FCVTL v6.4s, v6.4h\n" - "STP q6, q7, [%[out], #96]\n" - "LDR d8, [%[in2]], #8\n" - "FCVTL v8.4s, v8.4h\n" - ASM_PREFETCH("[%[in2], #192]") - "LDR q9, [%[in3]], #16\n" - "FCVTL2 v10.4s, v9.8h\n" - "FCVTL v9.4s, v9.4h\n" - "STP q8, q9, [%[out], #128]\n" - "LDR d11, [%[in3]], #8\n" - "FCVTL v11.4s, v11.4h\n" - "STP q10, q11, [%[out], #160]\n" - ASM_PREFETCH("[%[in3], #192]") - - : [in0] "+r" (in0), [in1] "+r" (in1), [in2] "+r" (in2), [in3] "+r" (in3), [out] "+r" (out) - : - : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "memory" - ); -} - -template <> -template <> -inline void TransformImpl<12, 1, true, 4, 2, false>::Transform( - float* out, const __fp16* const in, const int stride, - const int x0, const int xmax, const int k0, const int kmax -) { - TransposeInterleaveCommon<12, __fp16, float>::Transform(out, in, stride, x0, xmax, k0, kmax); -} - -#endif // __aarch64__ && __ARM_FP16_ARGS diff --git a/src/core/NEON/kernels/arm_gemm/transforms/a64_transpose_interleave_16.hpp b/src/core/NEON/kernels/arm_gemm/transforms/a64_transpose_interleave_16.hpp new file mode 100644 index 0000000000..0399f8becc --- /dev/null +++ b/src/core/NEON/kernels/arm_gemm/transforms/a64_transpose_interleave_16.hpp @@ -0,0 +1,141 @@ +/* + * Copyright (c) 2021-2023 Arm Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS + * IN THE SOFTWARE. + */ + +#pragma once + +#if defined(__aarch64__) + +namespace { + +void a64_transpose_interleave_16(uint32_t *out, const uint32_t *in, size_t width, size_t in_stride, size_t height) +{ + size_t out_stride = 4 * height * sizeof(uint32_t); + + __asm__ __volatile__( + "cmp %x[height], #0x4\n" + "blt 6f\n" + "1:" // Main row loop: Head + "mov x25, %x[in]\n" + "mov x24, %x[width]\n" + "mov x23, %x[out]\n" + "sub %x[height], %x[height], #0x4\n" + "add x22, x25, %x[in_stride]\n" + "add x21, x22, %x[in_stride]\n" + "add x20, x21, %x[in_stride]\n" + "cmp x24, #0x4\n" + "add %x[in], x20, %x[in_stride]\n" + "blt 3f\n" + "2:" // Main row loop: Column loop + "ldr q19, [x25], #0x10\n" + "ldr q18, [x22], #0x10\n" + "sub x24, x24, #0x4\n" + "ldr q17, [x21], #0x10\n" + "ldr q16, [x20], #0x10\n" + "cmp x24, #0x4\n" + "str q19, [x23, #0x0]\n" + "str q18, [x23, #0x10]\n" + "str q17, [x23, #0x20]\n" + "str q16, [x23, #0x30]\n" + "add x23, x23, %x[out_stride]\n" + "bge 2b\n" + "3:" // Main row loop: Column loop skip + "cbz x24, 5f\n" + "movi v16.4s, #0x0\n" + "str q16, [x23, #0x0]\n" + "str q16, [x23, #0x10]\n" + "str q16, [x23, #0x20]\n" + "str q16, [x23, #0x30]\n" + "4:" // Main row loop: width 1 loop: loop + "ldr s19, [x25], #0x4\n" + "ldr s18, [x22], #0x4\n" + "sub x24, x24, #0x1\n" + "ldr s17, [x21], #0x4\n" + "ldr s16, [x20], #0x4\n" + "cmp x24, #0x1\n" + "str s19, [x23, #0x0]\n" + "str s18, [x23, #0x10]\n" + "str s17, [x23, #0x20]\n" + "str s16, [x23, #0x30]\n" + "add x23, x23, #0x4\n" + "bge 4b\n" + "5:" // Main row loop: odd col skip + "cmp %x[height], #0x4\n" + "add %x[out], %x[out], #0x40\n" + "bge 1b\n" + "cbz %x[height], 12f\n" + "6:" // Main loop skip + "7:" // Tail row loop: Head + "mov x20, %x[width]\n" + "mov x25, %x[in]\n" + "mov x23, %x[out]\n" + "sub %x[height], %x[height], #0x1\n" + "cmp x20, #0x4\n" + "add %x[in], x25, %x[in_stride]\n" + "blt 9f\n" + "8:" // Tail row loop: Column loop + "ldr q16, [x25], #0x10\n" + "sub x20, x20, #0x4\n" + "cmp x20, #0x4\n" + "str q16, [x23, #0x0]\n" + "add x23, x23, %x[out_stride]\n" + "bge 8b\n" + "9:" // Tail row loop: Column loop skip + "cbz x20, 11f\n" + "movi v16.4s, #0x0\n" + "str q16, [x23, #0x0]\n" + "10:" // Tail row loop: width 1 loop: loop + "ldr s16, [x25], #0x4\n" + "sub x20, x20, #0x1\n" + "cmp x20, #0x1\n" + "str s16, [x23, #0x0]\n" + "add x23, x23, #0x4\n" + "bge 10b\n" + "11:" // Tail row loop: odd col skip + "cmp %x[height], #0x1\n" + "add %x[out], %x[out], #0x10\n" + "bge 7b\n" + "12:" // Done + : [height] "+&r" (height), [in] "+&r" (in), [out] "+&r" (out) + : [in_stride] "r" (in_stride), [out_stride] "r" (out_stride), [width] "r" (width) + : "cc", "memory", "v16", "v17", "v18", "v19", "x20", "x21", "x22", "x23", "x24", "x25" + ); +} + +} // anonymous namespace + +template<> +void Transform<4, 1, true, VLType::None>( + float *out, const float *in, int stride, int x0, int xmax, int k0, int kmax) +{ + a64_transpose_interleave_16( + reinterpret_cast<uint32_t *>(out), + reinterpret_cast<const uint32_t *>(in + k0 * stride + x0), + (xmax-x0) * sizeof(float) / 4, + stride * sizeof(float), + (kmax-k0) + ); +} + + +#endif // defined(__aarch64__) diff --git a/src/core/NEON/kernels/arm_gemm/transforms/a64_transpose_interleave_16_1x4.hpp b/src/core/NEON/kernels/arm_gemm/transforms/a64_transpose_interleave_16_1x4.hpp new file mode 100644 index 0000000000..f3a1dde73f --- /dev/null +++ b/src/core/NEON/kernels/arm_gemm/transforms/a64_transpose_interleave_16_1x4.hpp @@ -0,0 +1,331 @@ +/* + * Copyright (c) 2021, 2023 Arm Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ + +#pragma once + +#if defined(__aarch64__) + +namespace { + +void a64_transpose_interleave_16_1x4(uint8_t *out, const uint8_t *in, size_t width, size_t in_stride, size_t height) +{ + uint8_t *pad_row = reinterpret_cast<uint8_t *>(alloca(width * sizeof(uint8_t))); + + if (height % 4) { + memset(pad_row, 0, width * sizeof(uint8_t)); + } + + size_t out_stride = 16 * roundup<size_t>(height, 4) * sizeof(uint8_t); + + __asm__ __volatile__( + "cmp %x[height], #0x10\n" + "blt 8f\n" + "1:" // Main row loop: Head + "mov x17, %x[in]\n" + "add x16, x17, %x[in_stride]\n" + "add x15, x16, %x[in_stride]\n" + "add x14, x15, %x[in_stride]\n" + "add x13, x14, %x[in_stride]\n" + "add x12, x13, %x[in_stride]\n" + "add x11, x12, %x[in_stride]\n" + "add x10, x11, %x[in_stride]\n" + "add x9, x10, %x[in_stride]\n" + "add x28, x9, %x[in_stride]\n" + "add x27, x28, %x[in_stride]\n" + "add x26, x27, %x[in_stride]\n" + "add x25, x26, %x[in_stride]\n" + "mov x24, %x[width]\n" + "add x23, x25, %x[in_stride]\n" + "add x22, x23, %x[in_stride]\n" + "add x20, x22, %x[in_stride]\n" + "cmp x24, #0x10\n" + "add %x[in], x20, %x[in_stride]\n" + "mov x21, %x[out]\n" + "sub %x[height], %x[height], #0x10\n" + "blt 3f\n" + "2:" // Main row loop: Column loop + "ldr q21, [x17], #0x10\n" + "ldr q20, [x16], #0x10\n" + "sub x24, x24, #0x10\n" + "cmp x24, #0x10\n" + "ldr q17, [x15], #0x10\n" + "ldr q16, [x14], #0x10\n" + "zip1 v3.16b, v21.16b, v17.16b\n" + "zip1 v2.16b, v20.16b, v16.16b\n" + "ldr q19, [x13], #0x10\n" + "ldr q18, [x12], #0x10\n" + "zip2 v1.16b, v21.16b, v17.16b\n" + "zip2 v0.16b, v20.16b, v16.16b\n" + "ldr q17, [x11], #0x10\n" + "ldr q16, [x10], #0x10\n" + "zip1 v31.16b, v19.16b, v17.16b\n" + "zip1 v30.16b, v18.16b, v16.16b\n" + "ldr q25, [x9], #0x10\n" + "ldr q20, [x28], #0x10\n" + "zip2 v24.16b, v19.16b, v17.16b\n" + "zip2 v23.16b, v18.16b, v16.16b\n" + "ldr q17, [x27], #0x10\n" + "ldr q16, [x26], #0x10\n" + "zip1 v22.16b, v25.16b, v17.16b\n" + "zip1 v21.16b, v20.16b, v16.16b\n" + "ldr q19, [x25], #0x10\n" + "ldr q18, [x23], #0x10\n" + "zip2 v29.16b, v25.16b, v17.16b\n" + "zip2 v20.16b, v20.16b, v16.16b\n" + "ldr q17, [x22], #0x10\n" + "ldr q16, [x20], #0x10\n" + "zip1 v28.16b, v19.16b, v17.16b\n" + "zip1 v27.16b, v18.16b, v16.16b\n" + "zip2 v26.16b, v19.16b, v17.16b\n" + "zip2 v25.16b, v18.16b, v16.16b\n" + "zip1 v16.16b, v3.16b, v2.16b\n" + "zip2 v17.16b, v3.16b, v2.16b\n" + "str q16, [x21, #0x0]\n" + "zip1 v16.16b, v1.16b, v0.16b\n" + "zip2 v19.16b, v1.16b, v0.16b\n" + "str q17, [x21, #0x10]\n" + "zip1 v18.16b, v31.16b, v30.16b\n" + "zip2 v17.16b, v31.16b, v30.16b\n" + "str q16, [x21, #0x20]\n" + "zip1 v16.16b, v24.16b, v23.16b\n" + "zip2 v24.16b, v24.16b, v23.16b\n" + "str q19, [x21, #0x30]\n" + "zip1 v23.16b, v22.16b, v21.16b\n" + "zip2 v22.16b, v22.16b, v21.16b\n" + "str q18, [x21, #0x40]\n" + "zip1 v21.16b, v29.16b, v20.16b\n" + "zip2 v20.16b, v29.16b, v20.16b\n" + "str q17, [x21, #0x50]\n" + "zip1 v19.16b, v28.16b, v27.16b\n" + "zip2 v18.16b, v28.16b, v27.16b\n" + "str q16, [x21, #0x60]\n" + "zip1 v17.16b, v26.16b, v25.16b\n" + "zip2 v16.16b, v26.16b, v25.16b\n" + "str q24, [x21, #0x70]\n" + "str q23, [x21, #0x80]\n" + "str q22, [x21, #0x90]\n" + "str q21, [x21, #0xa0]\n" + "str q20, [x21, #0xb0]\n" + "str q19, [x21, #0xc0]\n" + "str q18, [x21, #0xd0]\n" + "str q17, [x21, #0xe0]\n" + "str q16, [x21, #0xf0]\n" + "add x21, x21, %x[out_stride]\n" + "bge 2b\n" + "3:" // Main row loop: Column loop skip + "cmp x24, #0x4\n" + "blt 5f\n" + "4:" // Main row loop: width 4 loop: loop + "ldr s19, [x17], #0x4\n" + "ldr s18, [x16], #0x4\n" + "sub x24, x24, #0x4\n" + "cmp x24, #0x4\n" + "ldr s17, [x15], #0x4\n" + "ldr s16, [x14], #0x4\n" + "zip1 v17.16b, v19.16b, v17.16b\n" + "zip1 v16.16b, v18.16b, v16.16b\n" + "ldr s19, [x13], #0x4\n" + "ldr s18, [x12], #0x4\n" + "zip1 v22.16b, v17.16b, v16.16b\n" + "ldr s17, [x11], #0x4\n" + "ldr s16, [x10], #0x4\n" + "zip1 v17.16b, v19.16b, v17.16b\n" + "zip1 v16.16b, v18.16b, v16.16b\n" + "ldr s19, [x9], #0x4\n" + "ldr s18, [x28], #0x4\n" + "zip1 v21.16b, v17.16b, v16.16b\n" + "ldr s17, [x27], #0x4\n" + "ldr s16, [x26], #0x4\n" + "zip1 v17.16b, v19.16b, v17.16b\n" + "zip1 v16.16b, v18.16b, v16.16b\n" + "ldr s20, [x25], #0x4\n" + "ldr s19, [x23], #0x4\n" + "zip1 v18.16b, v17.16b, v16.16b\n" + "ldr s17, [x22], #0x4\n" + "ldr s16, [x20], #0x4\n" + "zip1 v17.16b, v20.16b, v17.16b\n" + "zip1 v16.16b, v19.16b, v16.16b\n" + "str q22, [x21, #0x0]\n" + "zip1 v16.16b, v17.16b, v16.16b\n" + "str q21, [x21, #0x40]\n" + "str q18, [x21, #0x80]\n" + "str q16, [x21, #0xc0]\n" + "add x21, x21, #0x10\n" + "bge 4b\n" + "5:" // Main row loop: width 4 loop: skip + "cmp x24, #0x1\n" + "blt 7f\n" + "6:" // Main row loop: width 1 loop: loop + "ldr b19, [x17], #0x1\n" + "ldr b18, [x16], #0x1\n" + "sub x24, x24, #0x1\n" + "cmp x24, #0x1\n" + "ldr b17, [x15], #0x1\n" + "ldr b16, [x14], #0x1\n" + "zip1 v17.16b, v19.16b, v17.16b\n" + "zip1 v16.16b, v18.16b, v16.16b\n" + "ldr b19, [x13], #0x1\n" + "ldr b18, [x12], #0x1\n" + "zip1 v22.16b, v17.16b, v16.16b\n" + "ldr b17, [x11], #0x1\n" + "ldr b16, [x10], #0x1\n" + "zip1 v17.16b, v19.16b, v17.16b\n" + "zip1 v16.16b, v18.16b, v16.16b\n" + "ldr b19, [x9], #0x1\n" + "ldr b18, [x28], #0x1\n" + "zip1 v21.16b, v17.16b, v16.16b\n" + "ldr b17, [x27], #0x1\n" + "ldr b16, [x26], #0x1\n" + "zip1 v17.16b, v19.16b, v17.16b\n" + "zip1 v16.16b, v18.16b, v16.16b\n" + "ldr b20, [x25], #0x1\n" + "ldr b19, [x23], #0x1\n" + "zip1 v18.16b, v17.16b, v16.16b\n" + "ldr b17, [x22], #0x1\n" + "ldr b16, [x20], #0x1\n" + "zip1 v17.16b, v20.16b, v17.16b\n" + "zip1 v16.16b, v19.16b, v16.16b\n" + "str s22, [x21, #0x0]\n" + "zip1 v16.16b, v17.16b, v16.16b\n" + "str s21, [x21, #0x40]\n" + "str s18, [x21, #0x80]\n" + "str s16, [x21, #0xc0]\n" + "add x21, x21, #0x4\n" + "bge 6b\n" + "7:" // Main row loop: width 1 loop: skip + "cmp %x[height], #0x10\n" + "add %x[out], %x[out], #0x100\n" + "bge 1b\n" + "cbz %x[height], 16f\n" + "8:" // Main loop skip + "9:" // Tail row loop: Head + "mov x17, %x[in]\n" + "add x16, x17, %x[in_stride]\n" + "add x15, x16, %x[in_stride]\n" + "mov x20, %x[width]\n" + "add x14, x15, %x[in_stride]\n" + "cmp %x[height], #0x3\n" + "add %x[in], x14, %x[in_stride]\n" + "csel x14, x14, %x[pad_row], GT\n" + "csel x15, x15, %x[pad_row], GE\n" + "cmp %x[height], #0x1\n" + "csel x16, x16, %x[pad_row], GT\n" + "cmp x20, #0x10\n" + "mov x21, %x[out]\n" + "sub %x[height], %x[height], #0x4\n" + "blt 11f\n" + "10:" // Tail row loop: Column loop + "ldr q20, [x17], #0x10\n" + "ldr q21, [x16], #0x10\n" + "sub x20, x20, #0x10\n" + "cmp x20, #0x10\n" + "ldr q19, [x15], #0x10\n" + "ldr q16, [x14], #0x10\n" + "zip1 v18.16b, v20.16b, v19.16b\n" + "zip1 v17.16b, v21.16b, v16.16b\n" + "zip2 v20.16b, v20.16b, v19.16b\n" + "zip2 v19.16b, v21.16b, v16.16b\n" + "zip1 v16.16b, v18.16b, v17.16b\n" + "zip2 v18.16b, v18.16b, v17.16b\n" + "str q16, [x21, #0x0]\n" + "zip1 v17.16b, v20.16b, v19.16b\n" + "zip2 v16.16b, v20.16b, v19.16b\n" + "str q18, [x21, #0x10]\n" + "str q17, [x21, #0x20]\n" + "str q16, [x21, #0x30]\n" + "add x21, x21, %x[out_stride]\n" + "bge 10b\n" + "11:" // Tail row loop: Column loop skip + "cmp x20, #0x4\n" + "blt 13f\n" + "12:" // Tail row loop: width 4 loop: loop + "ldr s19, [x17], #0x4\n" + "ldr s18, [x16], #0x4\n" + "sub x20, x20, #0x4\n" + "cmp x20, #0x4\n" + "ldr s17, [x15], #0x4\n" + "ldr s16, [x14], #0x4\n" + "zip1 v17.16b, v19.16b, v17.16b\n" + "zip1 v16.16b, v18.16b, v16.16b\n" + "zip1 v16.16b, v17.16b, v16.16b\n" + "str q16, [x21, #0x0]\n" + "add x21, x21, #0x10\n" + "bge 12b\n" + "13:" // Tail row loop: width 4 loop: skip + "cmp x20, #0x1\n" + "blt 15f\n" + "14:" // Tail row loop: width 1 loop: loop + "ldr b19, [x17], #0x1\n" + "ldr b18, [x16], #0x1\n" + "sub x20, x20, #0x1\n" + "cmp x20, #0x1\n" + "ldr b17, [x15], #0x1\n" + "ldr b16, [x14], #0x1\n" + "zip1 v17.16b, v19.16b, v17.16b\n" + "zip1 v16.16b, v18.16b, v16.16b\n" + "zip1 v16.16b, v17.16b, v16.16b\n" + "str s16, [x21, #0x0]\n" + "add x21, x21, #0x4\n" + "bge 14b\n" + "15:" // Tail row loop: width 1 loop: skip + "cmp %x[height], #0x1\n" + "add %x[out], %x[out], #0x40\n" + "bge 9b\n" + "16:" // Done + : [height] "+&r" (height), [in] "+&r" (in), [out] "+&r" (out) + : [in_stride] "r" (in_stride), [out_stride] "r" (out_stride), [pad_row] "r" (pad_row), [width] "r" (width) + : "cc", "memory", "v0", "v1", "v2", "v3", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "x9", "x10", "x11", "x12", "x13", "x14", "x15", "x16", "x17", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27", "x28" + ); +} + +} // anonymous namespace + +template<> +void Transform<16, 4, true, VLType::None>( + uint8_t *out, const uint8_t *in, int stride, int x0, int xmax, int k0, int kmax) +{ + a64_transpose_interleave_16_1x4( + reinterpret_cast<uint8_t *>(out), + reinterpret_cast<const uint8_t *>(in + k0 * stride + x0), + (xmax-x0) * sizeof(uint8_t) / 1, + stride * sizeof(uint8_t), + (kmax-k0) + ); +} + +template<> +void Transform<16, 4, true, VLType::None>( + int8_t *out, const int8_t *in, int stride, int x0, int xmax, int k0, int kmax) +{ + a64_transpose_interleave_16_1x4( + reinterpret_cast<uint8_t *>(out), + reinterpret_cast<const uint8_t *>(in + k0 * stride + x0), + (xmax-x0) * sizeof(int8_t) / 1, + stride * sizeof(int8_t), + (kmax-k0) + ); +} + + +#endif // defined(__aarch64__) diff --git a/src/core/NEON/kernels/arm_gemm/transforms/a64_transpose_interleave_16_1x8.hpp b/src/core/NEON/kernels/arm_gemm/transforms/a64_transpose_interleave_16_1x8.hpp new file mode 100644 index 0000000000..7c7e91e666 --- /dev/null +++ b/src/core/NEON/kernels/arm_gemm/transforms/a64_transpose_interleave_16_1x8.hpp @@ -0,0 +1,291 @@ +/* + * Copyright (c) 2021, 2023 Arm Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ + +#pragma once + +#if defined(__aarch64__) + +namespace { + +void a64_transpose_interleave_16_1x8(uint8_t *out, const uint8_t *in, size_t width, size_t in_stride, size_t height) +{ + uint8_t *pad_row = reinterpret_cast<uint8_t *>(alloca(width * sizeof(uint8_t))); + + if (height % 8) { + memset(pad_row, 0, width * sizeof(uint8_t)); + } + + size_t out_stride = 16 * roundup<size_t>(height, 8) * sizeof(uint8_t); + + __asm__ __volatile__( + "1:" // Main row loop: Head + "mov x9, %x[in]\n" + "add x28, x9, %x[in_stride]\n" + "add x27, x28, %x[in_stride]\n" + "add x26, x27, %x[in_stride]\n" + "add x25, x26, %x[in_stride]\n" + "add x24, x25, %x[in_stride]\n" + "add x23, x24, %x[in_stride]\n" + "add x22, x23, %x[in_stride]\n" + "cmp %x[height], #0x7\n" + "add %x[in], x22, %x[in_stride]\n" + "csel x22, x22, %x[pad_row], GT\n" + "csel x23, x23, %x[pad_row], GE\n" + "cmp %x[height], #0x5\n" + "mov x21, %x[width]\n" + "csel x24, x24, %x[pad_row], GT\n" + "csel x25, x25, %x[pad_row], GE\n" + "cmp %x[height], #0x3\n" + "csel x26, x26, %x[pad_row], GT\n" + "csel x27, x27, %x[pad_row], GE\n" + "cmp %x[height], #0x1\n" + "csel x28, x28, %x[pad_row], GT\n" + "cmp x21, #0x20\n" + "mov x20, %x[out]\n" + "sub %x[height], %x[height], #0x8\n" + "blt 3f\n" + "2:" // Main row loop: Unroll column loop + "ldr q23, [x9], #0x10\n" + "ldr q22, [x28], #0x10\n" + "sub x21, x21, #0x20\n" + "cmp x21, #0x20\n" + "ldr q20, [x27], #0x10\n" + "ldr q21, [x26], #0x10\n" + "ldr q19, [x25], #0x10\n" + "ldr q18, [x24], #0x10\n" + "zip1 v5.16b, v23.16b, v19.16b\n" + "zip1 v4.16b, v22.16b, v18.16b\n" + "ldr q17, [x23], #0x10\n" + "ldr q16, [x22], #0x10\n" + "zip1 v3.16b, v20.16b, v17.16b\n" + "zip1 v31.16b, v21.16b, v16.16b\n" + "ldr q25, [x9], #0x10\n" + "ldr q24, [x28], #0x10\n" + "zip2 v2.16b, v23.16b, v19.16b\n" + "zip2 v30.16b, v20.16b, v17.16b\n" + "ldr q23, [x27], #0x10\n" + "ldr q20, [x26], #0x10\n" + "zip2 v22.16b, v22.16b, v18.16b\n" + "zip2 v21.16b, v21.16b, v16.16b\n" + "ldr q19, [x25], #0x10\n" + "ldr q18, [x24], #0x10\n" + "zip1 v29.16b, v25.16b, v19.16b\n" + "zip1 v28.16b, v24.16b, v18.16b\n" + "ldr q17, [x23], #0x10\n" + "ldr q16, [x22], #0x10\n" + "zip1 v27.16b, v23.16b, v17.16b\n" + "zip1 v26.16b, v20.16b, v16.16b\n" + "zip2 v1.16b, v25.16b, v19.16b\n" + "zip2 v25.16b, v23.16b, v17.16b\n" + "zip2 v24.16b, v24.16b, v18.16b\n" + "zip2 v16.16b, v20.16b, v16.16b\n" + "zip1 v0.16b, v5.16b, v3.16b\n" + "zip1 v17.16b, v4.16b, v31.16b\n" + "zip2 v20.16b, v5.16b, v3.16b\n" + "zip2 v19.16b, v4.16b, v31.16b\n" + "zip1 v31.16b, v2.16b, v30.16b\n" + "zip1 v18.16b, v22.16b, v21.16b\n" + "zip2 v30.16b, v2.16b, v30.16b\n" + "zip2 v23.16b, v22.16b, v21.16b\n" + "zip1 v22.16b, v29.16b, v27.16b\n" + "zip1 v21.16b, v28.16b, v26.16b\n" + "zip2 v29.16b, v29.16b, v27.16b\n" + "zip2 v28.16b, v28.16b, v26.16b\n" + "zip1 v27.16b, v1.16b, v25.16b\n" + "zip1 v26.16b, v24.16b, v16.16b\n" + "zip2 v25.16b, v1.16b, v25.16b\n" + "zip2 v24.16b, v24.16b, v16.16b\n" + "zip1 v16.16b, v0.16b, v17.16b\n" + "zip2 v17.16b, v0.16b, v17.16b\n" + "str q16, [x20, #0x0]\n" + "zip1 v16.16b, v20.16b, v19.16b\n" + "zip2 v20.16b, v20.16b, v19.16b\n" + "str q17, [x20, #0x10]\n" + "zip1 v19.16b, v31.16b, v18.16b\n" + "zip2 v18.16b, v31.16b, v18.16b\n" + "str q16, [x20, #0x20]\n" + "zip1 v17.16b, v30.16b, v23.16b\n" + "zip2 v16.16b, v30.16b, v23.16b\n" + "str q20, [x20, #0x30]\n" + "str q19, [x20, #0x40]\n" + "zip1 v23.16b, v22.16b, v21.16b\n" + "zip2 v22.16b, v22.16b, v21.16b\n" + "str q18, [x20, #0x50]\n" + "zip1 v21.16b, v29.16b, v28.16b\n" + "zip2 v20.16b, v29.16b, v28.16b\n" + "str q17, [x20, #0x60]\n" + "zip1 v19.16b, v27.16b, v26.16b\n" + "zip2 v18.16b, v27.16b, v26.16b\n" + "str q16, [x20, #0x70]\n" + "add x20, x20, %x[out_stride]\n" + "zip1 v17.16b, v25.16b, v24.16b\n" + "zip2 v16.16b, v25.16b, v24.16b\n" + "str q23, [x20, #0x0]\n" + "str q22, [x20, #0x10]\n" + "str q21, [x20, #0x20]\n" + "str q20, [x20, #0x30]\n" + "str q19, [x20, #0x40]\n" + "str q18, [x20, #0x50]\n" + "str q17, [x20, #0x60]\n" + "str q16, [x20, #0x70]\n" + "add x20, x20, %x[out_stride]\n" + "bge 2b\n" + "3:" // Main row loop: Unroll column loop skip + "cmp x21, #0x10\n" + "blt 5f\n" + "4:" // Main row loop: Column loop + "ldr q25, [x9], #0x10\n" + "ldr q27, [x28], #0x10\n" + "sub x21, x21, #0x10\n" + "cmp x21, #0x10\n" + "ldr q26, [x27], #0x10\n" + "ldr q24, [x26], #0x10\n" + "ldr q22, [x25], #0x10\n" + "ldr q21, [x24], #0x10\n" + "zip1 v20.16b, v25.16b, v22.16b\n" + "zip1 v23.16b, v27.16b, v21.16b\n" + "ldr q17, [x23], #0x10\n" + "ldr q16, [x22], #0x10\n" + "zip1 v19.16b, v26.16b, v17.16b\n" + "zip1 v18.16b, v24.16b, v16.16b\n" + "zip2 v25.16b, v25.16b, v22.16b\n" + "zip2 v22.16b, v26.16b, v17.16b\n" + "zip2 v21.16b, v27.16b, v21.16b\n" + "zip2 v16.16b, v24.16b, v16.16b\n" + "zip1 v24.16b, v20.16b, v19.16b\n" + "zip1 v17.16b, v23.16b, v18.16b\n" + "zip2 v20.16b, v20.16b, v19.16b\n" + "zip2 v19.16b, v23.16b, v18.16b\n" + "zip1 v23.16b, v25.16b, v22.16b\n" + "zip1 v18.16b, v21.16b, v16.16b\n" + "zip2 v22.16b, v25.16b, v22.16b\n" + "zip2 v21.16b, v21.16b, v16.16b\n" + "zip1 v16.16b, v24.16b, v17.16b\n" + "zip2 v17.16b, v24.16b, v17.16b\n" + "str q16, [x20, #0x0]\n" + "zip1 v16.16b, v20.16b, v19.16b\n" + "zip2 v20.16b, v20.16b, v19.16b\n" + "str q17, [x20, #0x10]\n" + "zip1 v19.16b, v23.16b, v18.16b\n" + "zip2 v18.16b, v23.16b, v18.16b\n" + "str q16, [x20, #0x20]\n" + "zip1 v17.16b, v22.16b, v21.16b\n" + "zip2 v16.16b, v22.16b, v21.16b\n" + "str q20, [x20, #0x30]\n" + "str q19, [x20, #0x40]\n" + "str q18, [x20, #0x50]\n" + "str q17, [x20, #0x60]\n" + "str q16, [x20, #0x70]\n" + "add x20, x20, %x[out_stride]\n" + "bge 4b\n" + "5:" // Main row loop: Column loop skip + "cmp x21, #0x4\n" + "blt 7f\n" + "6:" // Main row loop: width 4 loop: loop + "ldr s18, [x9], #0x4\n" + "ldr s19, [x28], #0x4\n" + "sub x21, x21, #0x4\n" + "cmp x21, #0x4\n" + "ldr s21, [x27], #0x4\n" + "ldr s20, [x26], #0x4\n" + "ldr s17, [x25], #0x4\n" + "ldr s16, [x24], #0x4\n" + "zip1 v18.16b, v18.16b, v17.16b\n" + "zip1 v19.16b, v19.16b, v16.16b\n" + "ldr s17, [x23], #0x4\n" + "ldr s16, [x22], #0x4\n" + "zip1 v17.16b, v21.16b, v17.16b\n" + "zip1 v16.16b, v20.16b, v16.16b\n" + "zip1 v18.16b, v18.16b, v17.16b\n" + "zip1 v16.16b, v19.16b, v16.16b\n" + "zip1 v17.16b, v18.16b, v16.16b\n" + "zip2 v16.16b, v18.16b, v16.16b\n" + "str q17, [x20, #0x0]\n" + "str q16, [x20, #0x10]\n" + "add x20, x20, #0x20\n" + "bge 6b\n" + "7:" // Main row loop: width 4 loop: skip + "cmp x21, #0x1\n" + "blt 9f\n" + "8:" // Main row loop: width 1 loop: loop + "ldr b19, [x9], #0x1\n" + "ldr b18, [x28], #0x1\n" + "sub x21, x21, #0x1\n" + "cmp x21, #0x1\n" + "ldr b21, [x27], #0x1\n" + "ldr b20, [x26], #0x1\n" + "ldr b17, [x25], #0x1\n" + "ldr b16, [x24], #0x1\n" + "zip1 v19.16b, v19.16b, v17.16b\n" + "zip1 v18.16b, v18.16b, v16.16b\n" + "ldr b17, [x23], #0x1\n" + "ldr b16, [x22], #0x1\n" + "zip1 v17.16b, v21.16b, v17.16b\n" + "zip1 v16.16b, v20.16b, v16.16b\n" + "zip1 v17.16b, v19.16b, v17.16b\n" + "zip1 v16.16b, v18.16b, v16.16b\n" + "zip1 v16.16b, v17.16b, v16.16b\n" + "str d16, [x20, #0x0]\n" + "add x20, x20, #0x8\n" + "bge 8b\n" + "9:" // Main row loop: width 1 loop: skip + "cmp %x[height], #0x1\n" + "add %x[out], %x[out], #0x80\n" + "bge 1b\n" + : [height] "+&r" (height), [in] "+&r" (in), [out] "+&r" (out) + : [in_stride] "r" (in_stride), [out_stride] "r" (out_stride), [pad_row] "r" (pad_row), [width] "r" (width) + : "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "x9", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27", "x28" + ); +} + +} // anonymous namespace + +template<> +void Transform<16, 8, true, VLType::None>( + uint8_t *out, const uint8_t *in, int stride, int x0, int xmax, int k0, int kmax) +{ + a64_transpose_interleave_16_1x8( + reinterpret_cast<uint8_t *>(out), + reinterpret_cast<const uint8_t *>(in + k0 * stride + x0), + (xmax-x0) * sizeof(uint8_t) / 1, + stride * sizeof(uint8_t), + (kmax-k0) + ); +} + +template<> +void Transform<16, 8, true, VLType::None>( + int8_t *out, const int8_t *in, int stride, int x0, int xmax, int k0, int kmax) +{ + a64_transpose_interleave_16_1x8( + reinterpret_cast<uint8_t *>(out), + reinterpret_cast<const uint8_t *>(in + k0 * stride + x0), + (xmax-x0) * sizeof(int8_t) / 1, + stride * sizeof(int8_t), + (kmax-k0) + ); +} + + +#endif // defined(__aarch64__) diff --git a/src/core/NEON/kernels/arm_gemm/transforms/a64_transpose_interleave_16_2x2.hpp b/src/core/NEON/kernels/arm_gemm/transforms/a64_transpose_interleave_16_2x2.hpp new file mode 100644 index 0000000000..b4515cbfd4 --- /dev/null +++ b/src/core/NEON/kernels/arm_gemm/transforms/a64_transpose_interleave_16_2x2.hpp @@ -0,0 +1,245 @@ +/* + * Copyright (c) 2021, 2023 Arm Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ + +#pragma once + +#if defined(__aarch64__) + +namespace { + +void a64_transpose_interleave_16_2x2(uint16_t *out, const uint16_t *in, size_t width, size_t in_stride, size_t height) +{ + uint16_t *pad_row = reinterpret_cast<uint16_t *>(alloca(width * sizeof(uint16_t))); + + if (height % 2) { + memset(pad_row, 0, width * sizeof(uint16_t)); + } + + size_t out_stride = 16 * roundup<size_t>(height, 2) * sizeof(uint16_t); + + __asm__ __volatile__( + "cmp %x[height], #0x8\n" + "blt 8f\n" + "1:" // Main row loop: Head + "mov x9, %x[in]\n" + "add x28, x9, %x[in_stride]\n" + "add x27, x28, %x[in_stride]\n" + "add x26, x27, %x[in_stride]\n" + "add x25, x26, %x[in_stride]\n" + "mov x24, %x[width]\n" + "add x23, x25, %x[in_stride]\n" + "add x22, x23, %x[in_stride]\n" + "add x20, x22, %x[in_stride]\n" + "cmp x24, #0x10\n" + "add %x[in], x20, %x[in_stride]\n" + "mov x21, %x[out]\n" + "sub %x[height], %x[height], #0x8\n" + "blt 3f\n" + "2:" // Main row loop: Column loop + "ldr q17, [x9], #0x10\n" + "ldr q16, [x28], #0x10\n" + "sub x24, x24, #0x10\n" + "cmp x24, #0x10\n" + "ldr q19, [x27], #0x10\n" + "ldr q18, [x26], #0x10\n" + "zip1 v1.8h, v17.8h, v16.8h\n" + "zip2 v0.8h, v17.8h, v16.8h\n" + "ldr q17, [x25], #0x10\n" + "ldr q16, [x23], #0x10\n" + "zip1 v31.8h, v19.8h, v18.8h\n" + "zip2 v30.8h, v19.8h, v18.8h\n" + "ldr q29, [x22], #0x10\n" + "ldr q18, [x20], #0x10\n" + "zip1 v28.8h, v17.8h, v16.8h\n" + "zip2 v27.8h, v17.8h, v16.8h\n" + "ldr q17, [x9], #0x10\n" + "ldr q16, [x28], #0x10\n" + "zip1 v26.8h, v17.8h, v16.8h\n" + "zip2 v25.8h, v17.8h, v16.8h\n" + "ldr q17, [x27], #0x10\n" + "ldr q16, [x26], #0x10\n" + "zip1 v24.8h, v17.8h, v16.8h\n" + "zip2 v23.8h, v17.8h, v16.8h\n" + "ldr q17, [x25], #0x10\n" + "ldr q16, [x23], #0x10\n" + "zip1 v22.8h, v17.8h, v16.8h\n" + "zip2 v21.8h, v17.8h, v16.8h\n" + "ldr q20, [x22], #0x10\n" + "ldr q16, [x20], #0x10\n" + "zip1 v19.8h, v29.8h, v18.8h\n" + "zip2 v18.8h, v29.8h, v18.8h\n" + "zip1 v17.8h, v20.8h, v16.8h\n" + "zip2 v16.8h, v20.8h, v16.8h\n" + "str q1, [x21, #0x0]\n" + "str q0, [x21, #0x10]\n" + "str q26, [x21, #0x20]\n" + "str q25, [x21, #0x30]\n" + "str q31, [x21, #0x40]\n" + "str q30, [x21, #0x50]\n" + "str q24, [x21, #0x60]\n" + "str q23, [x21, #0x70]\n" + "str q28, [x21, #0x80]\n" + "str q27, [x21, #0x90]\n" + "str q22, [x21, #0xa0]\n" + "str q21, [x21, #0xb0]\n" + "str q19, [x21, #0xc0]\n" + "str q18, [x21, #0xd0]\n" + "str q17, [x21, #0xe0]\n" + "str q16, [x21, #0xf0]\n" + "add x21, x21, %x[out_stride]\n" + "bge 2b\n" + "3:" // Main row loop: Column loop skip + "cmp x24, #0x4\n" + "blt 5f\n" + "4:" // Main row loop: width 4 loop: loop + "ldr d19, [x9], #0x8\n" + "ldr d18, [x28], #0x8\n" + "sub x24, x24, #0x4\n" + "cmp x24, #0x4\n" + "ldr d17, [x27], #0x8\n" + "ldr d16, [x26], #0x8\n" + "zip1 v20.8h, v19.8h, v18.8h\n" + "zip1 v19.8h, v17.8h, v16.8h\n" + "ldr d17, [x25], #0x8\n" + "ldr d16, [x23], #0x8\n" + "zip1 v18.8h, v17.8h, v16.8h\n" + "ldr d17, [x22], #0x8\n" + "ldr d16, [x20], #0x8\n" + "str q20, [x21, #0x0]\n" + "zip1 v16.8h, v17.8h, v16.8h\n" + "str q19, [x21, #0x40]\n" + "str q18, [x21, #0x80]\n" + "str q16, [x21, #0xc0]\n" + "add x21, x21, #0x10\n" + "bge 4b\n" + "5:" // Main row loop: width 4 loop: skip + "cmp x24, #0x1\n" + "blt 7f\n" + "6:" // Main row loop: width 1 loop: loop + "ldr h19, [x9], #0x2\n" + "ldr h18, [x28], #0x2\n" + "sub x24, x24, #0x1\n" + "cmp x24, #0x1\n" + "ldr h17, [x27], #0x2\n" + "ldr h16, [x26], #0x2\n" + "zip1 v20.8h, v19.8h, v18.8h\n" + "zip1 v19.8h, v17.8h, v16.8h\n" + "ldr h17, [x25], #0x2\n" + "ldr h16, [x23], #0x2\n" + "zip1 v18.8h, v17.8h, v16.8h\n" + "ldr h17, [x22], #0x2\n" + "ldr h16, [x20], #0x2\n" + "str s20, [x21, #0x0]\n" + "zip1 v16.8h, v17.8h, v16.8h\n" + "str s19, [x21, #0x40]\n" + "str s18, [x21, #0x80]\n" + "str s16, [x21, #0xc0]\n" + "add x21, x21, #0x4\n" + "bge 6b\n" + "7:" // Main row loop: width 1 loop: skip + "cmp %x[height], #0x8\n" + "add %x[out], %x[out], #0x100\n" + "bge 1b\n" + "cbz %x[height], 16f\n" + "8:" // Main loop skip + "9:" // Tail row loop: Head + "mov x9, %x[in]\n" + "mov x20, %x[width]\n" + "add x28, x9, %x[in_stride]\n" + "cmp %x[height], #0x1\n" + "add %x[in], x28, %x[in_stride]\n" + "csel x28, x28, %x[pad_row], GT\n" + "cmp x20, #0x10\n" + "mov x21, %x[out]\n" + "sub %x[height], %x[height], #0x2\n" + "blt 11f\n" + "10:" // Tail row loop: Column loop + "ldr q18, [x9], #0x10\n" + "ldr q17, [x28], #0x10\n" + "sub x20, x20, #0x10\n" + "cmp x20, #0x10\n" + "ldr q20, [x9], #0x10\n" + "ldr q16, [x28], #0x10\n" + "zip1 v19.8h, v18.8h, v17.8h\n" + "zip2 v18.8h, v18.8h, v17.8h\n" + "zip1 v17.8h, v20.8h, v16.8h\n" + "zip2 v16.8h, v20.8h, v16.8h\n" + "str q19, [x21, #0x0]\n" + "str q18, [x21, #0x10]\n" + "str q17, [x21, #0x20]\n" + "str q16, [x21, #0x30]\n" + "add x21, x21, %x[out_stride]\n" + "bge 10b\n" + "11:" // Tail row loop: Column loop skip + "cmp x20, #0x4\n" + "blt 13f\n" + "12:" // Tail row loop: width 4 loop: loop + "ldr d17, [x9], #0x8\n" + "ldr d16, [x28], #0x8\n" + "sub x20, x20, #0x4\n" + "cmp x20, #0x4\n" + "zip1 v16.8h, v17.8h, v16.8h\n" + "str q16, [x21, #0x0]\n" + "add x21, x21, #0x10\n" + "bge 12b\n" + "13:" // Tail row loop: width 4 loop: skip + "cmp x20, #0x1\n" + "blt 15f\n" + "14:" // Tail row loop: width 1 loop: loop + "ldr h17, [x9], #0x2\n" + "ldr h16, [x28], #0x2\n" + "sub x20, x20, #0x1\n" + "cmp x20, #0x1\n" + "zip1 v16.8h, v17.8h, v16.8h\n" + "str s16, [x21, #0x0]\n" + "add x21, x21, #0x4\n" + "bge 14b\n" + "15:" // Tail row loop: width 1 loop: skip + "cmp %x[height], #0x1\n" + "add %x[out], %x[out], #0x40\n" + "bge 9b\n" + "16:" // Done + : [height] "+&r" (height), [in] "+&r" (in), [out] "+&r" (out) + : [in_stride] "r" (in_stride), [out_stride] "r" (out_stride), [pad_row] "r" (pad_row), [width] "r" (width) + : "cc", "memory", "v0", "v1", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "x9", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27", "x28" + ); +} + +} // anonymous namespace + +template<> +void Transform<16, 2, true, VLType::None>( + bfloat16 *out, const bfloat16 *in, int stride, int x0, int xmax, int k0, int kmax) +{ + a64_transpose_interleave_16_2x2( + reinterpret_cast<uint16_t *>(out), + reinterpret_cast<const uint16_t *>(in + k0 * stride + x0), + (xmax-x0) * sizeof(bfloat16) / 2, + stride * sizeof(bfloat16), + (kmax-k0) + ); +} + + +#endif // defined(__aarch64__) diff --git a/src/core/NEON/kernels/arm_gemm/transforms/a64_transpose_interleave_16_2x4.hpp b/src/core/NEON/kernels/arm_gemm/transforms/a64_transpose_interleave_16_2x4.hpp new file mode 100644 index 0000000000..ac67467240 --- /dev/null +++ b/src/core/NEON/kernels/arm_gemm/transforms/a64_transpose_interleave_16_2x4.hpp @@ -0,0 +1,510 @@ +/* + * Copyright (c) 2021, 2023 Arm Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ + +#pragma once + +#if defined(__aarch64__) + +namespace { + +void a64_transpose_interleave_16_2x4(uint16_t *out, const uint16_t *in, size_t width, size_t in_stride, size_t height) +{ + uint16_t *pad_row = reinterpret_cast<uint16_t *>(alloca(width * sizeof(uint16_t))); + + if (height % 4) { + memset(pad_row, 0, width * sizeof(uint16_t)); + } + + size_t out_stride = 16 * roundup<size_t>(height, 4) * sizeof(uint16_t); + + __asm__ __volatile__( + "cmp %x[height], #0x8\n" + "blt 10f\n" + "1:" // Main row loop: Head + "mov x9, %x[in]\n" + "add x28, x9, %x[in_stride]\n" + "add x27, x28, %x[in_stride]\n" + "add x26, x27, %x[in_stride]\n" + "add x25, x26, %x[in_stride]\n" + "mov x24, %x[width]\n" + "add x23, x25, %x[in_stride]\n" + "add x22, x23, %x[in_stride]\n" + "add x20, x22, %x[in_stride]\n" + "cmp x24, #0x20\n" + "add %x[in], x20, %x[in_stride]\n" + "mov x21, %x[out]\n" + "sub %x[height], %x[height], #0x8\n" + "blt 3f\n" + "2:" // Main row loop: Unroll column loop + "ldr q23, [x9], #0x10\n" + "ldr q29, [x28], #0x10\n" + "sub x24, x24, #0x20\n" + "cmp x24, #0x20\n" + "ldr q13, [x27], #0x10\n" + "ldr q12, [x26], #0x10\n" + "zip1 v20.8h, v23.8h, v13.8h\n" + "zip1 v28.8h, v29.8h, v12.8h\n" + "ldr q18, [x25], #0x10\n" + "ldr q9, [x23], #0x10\n" + "zip2 v22.8h, v23.8h, v13.8h\n" + "zip2 v1.8h, v29.8h, v12.8h\n" + "ldr q27, [x22], #0x10\n" + "ldr q3, [x20], #0x10\n" + "zip1 v4.8h, v18.8h, v27.8h\n" + "zip1 v26.8h, v9.8h, v3.8h\n" + "ldr q17, [x9], #0x10\n" + "ldr q2, [x28], #0x10\n" + "zip2 v15.8h, v18.8h, v27.8h\n" + "zip2 v12.8h, v9.8h, v3.8h\n" + "ldr q23, [x27], #0x10\n" + "ldr q14, [x26], #0x10\n" + "zip1 v19.8h, v17.8h, v23.8h\n" + "zip1 v21.8h, v2.8h, v14.8h\n" + "ldr q6, [x25], #0x10\n" + "ldr q18, [x23], #0x10\n" + "zip2 v27.8h, v17.8h, v23.8h\n" + "zip2 v17.8h, v2.8h, v14.8h\n" + "ldr q0, [x22], #0x10\n" + "ldr q3, [x20], #0x10\n" + "zip1 v16.8h, v6.8h, v0.8h\n" + "zip1 v30.8h, v18.8h, v3.8h\n" + "ldr q2, [x9], #0x10\n" + "ldr q13, [x28], #0x10\n" + "zip2 v31.8h, v6.8h, v0.8h\n" + "zip2 v8.8h, v18.8h, v3.8h\n" + "ldr q14, [x27], #0x10\n" + "ldr q3, [x26], #0x10\n" + "zip1 v11.8h, v2.8h, v14.8h\n" + "zip1 v29.8h, v13.8h, v3.8h\n" + "ldr q25, [x25], #0x10\n" + "ldr q18, [x23], #0x10\n" + "zip2 v23.8h, v2.8h, v14.8h\n" + "zip2 v10.8h, v13.8h, v3.8h\n" + "ldr q7, [x22], #0x10\n" + "ldr q6, [x20], #0x10\n" + "zip1 v14.8h, v25.8h, v7.8h\n" + "zip1 v13.8h, v18.8h, v6.8h\n" + "ldr q2, [x9], #0x10\n" + "ldr q5, [x28], #0x10\n" + "zip2 v9.8h, v25.8h, v7.8h\n" + "zip2 v7.8h, v18.8h, v6.8h\n" + "ldr q6, [x27], #0x10\n" + "ldr q24, [x26], #0x10\n" + "zip1 v25.8h, v2.8h, v6.8h\n" + "zip1 v3.8h, v5.8h, v24.8h\n" + "ldr q0, [x25], #0x10\n" + "ldr q18, [x23], #0x10\n" + "zip2 v2.8h, v2.8h, v6.8h\n" + "zip2 v24.8h, v5.8h, v24.8h\n" + "ldr q5, [x22], #0x10\n" + "zip1 v6.8h, v0.8h, v5.8h\n" + "zip2 v5.8h, v0.8h, v5.8h\n" + "zip1 v0.8h, v20.8h, v28.8h\n" + "zip2 v28.8h, v20.8h, v28.8h\n" + "ldr q20, [x20], #0x10\n" + "str q0, [x21, #0x0]\n" + "zip1 v0.8h, v18.8h, v20.8h\n" + "zip2 v20.8h, v18.8h, v20.8h\n" + "str q28, [x21, #0x10]\n" + "zip1 v18.8h, v22.8h, v1.8h\n" + "zip2 v28.8h, v22.8h, v1.8h\n" + "str q18, [x21, #0x20]\n" + "zip1 v22.8h, v19.8h, v21.8h\n" + "zip2 v19.8h, v19.8h, v21.8h\n" + "str q28, [x21, #0x30]\n" + "zip1 v18.8h, v27.8h, v17.8h\n" + "zip2 v17.8h, v27.8h, v17.8h\n" + "str q22, [x21, #0x40]\n" + "zip1 v27.8h, v4.8h, v26.8h\n" + "zip2 v26.8h, v4.8h, v26.8h\n" + "str q19, [x21, #0x50]\n" + "zip1 v22.8h, v15.8h, v12.8h\n" + "zip2 v21.8h, v15.8h, v12.8h\n" + "str q18, [x21, #0x60]\n" + "zip1 v19.8h, v16.8h, v30.8h\n" + "zip2 v18.8h, v16.8h, v30.8h\n" + "str q17, [x21, #0x70]\n" + "zip1 v17.8h, v31.8h, v8.8h\n" + "zip2 v16.8h, v31.8h, v8.8h\n" + "str q27, [x21, #0x80]\n" + "str q26, [x21, #0x90]\n" + "zip1 v31.8h, v11.8h, v29.8h\n" + "zip2 v30.8h, v11.8h, v29.8h\n" + "str q22, [x21, #0xa0]\n" + "zip1 v29.8h, v23.8h, v10.8h\n" + "zip2 v28.8h, v23.8h, v10.8h\n" + "str q21, [x21, #0xb0]\n" + "zip1 v27.8h, v25.8h, v3.8h\n" + "zip2 v26.8h, v25.8h, v3.8h\n" + "str q19, [x21, #0xc0]\n" + "zip1 v25.8h, v2.8h, v24.8h\n" + "zip2 v24.8h, v2.8h, v24.8h\n" + "str q18, [x21, #0xd0]\n" + "zip1 v23.8h, v14.8h, v13.8h\n" + "zip2 v22.8h, v14.8h, v13.8h\n" + "str q17, [x21, #0xe0]\n" + "zip1 v21.8h, v9.8h, v7.8h\n" + "zip2 v19.8h, v9.8h, v7.8h\n" + "str q16, [x21, #0xf0]\n" + "add x21, x21, %x[out_stride]\n" + "zip1 v2.8h, v6.8h, v0.8h\n" + "zip2 v18.8h, v6.8h, v0.8h\n" + "zip1 v17.8h, v5.8h, v20.8h\n" + "zip2 v16.8h, v5.8h, v20.8h\n" + "str q31, [x21, #0x0]\n" + "str q30, [x21, #0x10]\n" + "str q29, [x21, #0x20]\n" + "str q28, [x21, #0x30]\n" + "str q27, [x21, #0x40]\n" + "str q26, [x21, #0x50]\n" + "str q25, [x21, #0x60]\n" + "str q24, [x21, #0x70]\n" + "str q23, [x21, #0x80]\n" + "str q22, [x21, #0x90]\n" + "str q21, [x21, #0xa0]\n" + "str q19, [x21, #0xb0]\n" + "str q2, [x21, #0xc0]\n" + "str q18, [x21, #0xd0]\n" + "str q17, [x21, #0xe0]\n" + "str q16, [x21, #0xf0]\n" + "add x21, x21, %x[out_stride]\n" + "bge 2b\n" + "3:" // Main row loop: Unroll column loop skip + "cmp x24, #0x10\n" + "blt 5f\n" + "4:" // Main row loop: Column loop + "ldr q21, [x9], #0x10\n" + "ldr q20, [x28], #0x10\n" + "sub x24, x24, #0x10\n" + "cmp x24, #0x10\n" + "ldr q17, [x27], #0x10\n" + "ldr q16, [x26], #0x10\n" + "zip1 v3.8h, v21.8h, v17.8h\n" + "zip1 v2.8h, v20.8h, v16.8h\n" + "ldr q19, [x25], #0x10\n" + "ldr q18, [x23], #0x10\n" + "zip2 v1.8h, v21.8h, v17.8h\n" + "zip2 v24.8h, v20.8h, v16.8h\n" + "ldr q17, [x22], #0x10\n" + "ldr q16, [x20], #0x10\n" + "zip1 v0.8h, v19.8h, v17.8h\n" + "zip1 v31.8h, v18.8h, v16.8h\n" + "ldr q21, [x9], #0x10\n" + "ldr q20, [x28], #0x10\n" + "zip2 v30.8h, v19.8h, v17.8h\n" + "zip2 v29.8h, v18.8h, v16.8h\n" + "ldr q17, [x27], #0x10\n" + "ldr q16, [x26], #0x10\n" + "zip1 v23.8h, v21.8h, v17.8h\n" + "zip1 v22.8h, v20.8h, v16.8h\n" + "ldr q19, [x25], #0x10\n" + "ldr q18, [x23], #0x10\n" + "zip2 v21.8h, v21.8h, v17.8h\n" + "zip2 v20.8h, v20.8h, v16.8h\n" + "ldr q17, [x22], #0x10\n" + "ldr q16, [x20], #0x10\n" + "zip1 v28.8h, v19.8h, v17.8h\n" + "zip1 v27.8h, v18.8h, v16.8h\n" + "zip2 v26.8h, v19.8h, v17.8h\n" + "zip2 v25.8h, v18.8h, v16.8h\n" + "zip1 v16.8h, v3.8h, v2.8h\n" + "zip2 v17.8h, v3.8h, v2.8h\n" + "str q16, [x21, #0x0]\n" + "zip1 v16.8h, v1.8h, v24.8h\n" + "zip2 v19.8h, v1.8h, v24.8h\n" + "str q17, [x21, #0x10]\n" + "zip1 v18.8h, v23.8h, v22.8h\n" + "zip2 v17.8h, v23.8h, v22.8h\n" + "str q16, [x21, #0x20]\n" + "zip1 v16.8h, v21.8h, v20.8h\n" + "zip2 v24.8h, v21.8h, v20.8h\n" + "str q19, [x21, #0x30]\n" + "zip1 v23.8h, v0.8h, v31.8h\n" + "zip2 v22.8h, v0.8h, v31.8h\n" + "str q18, [x21, #0x40]\n" + "zip1 v21.8h, v30.8h, v29.8h\n" + "zip2 v20.8h, v30.8h, v29.8h\n" + "str q17, [x21, #0x50]\n" + "zip1 v19.8h, v28.8h, v27.8h\n" + "zip2 v18.8h, v28.8h, v27.8h\n" + "str q16, [x21, #0x60]\n" + "zip1 v17.8h, v26.8h, v25.8h\n" + "zip2 v16.8h, v26.8h, v25.8h\n" + "str q24, [x21, #0x70]\n" + "str q23, [x21, #0x80]\n" + "str q22, [x21, #0x90]\n" + "str q21, [x21, #0xa0]\n" + "str q20, [x21, #0xb0]\n" + "str q19, [x21, #0xc0]\n" + "str q18, [x21, #0xd0]\n" + "str q17, [x21, #0xe0]\n" + "str q16, [x21, #0xf0]\n" + "add x21, x21, %x[out_stride]\n" + "bge 4b\n" + "5:" // Main row loop: Column loop skip + "cmp x24, #0x4\n" + "blt 7f\n" + "6:" // Main row loop: width 4 loop: loop + "ldr d19, [x9], #0x8\n" + "ldr d18, [x28], #0x8\n" + "sub x24, x24, #0x4\n" + "cmp x24, #0x4\n" + "ldr d17, [x27], #0x8\n" + "ldr d16, [x26], #0x8\n" + "zip1 v17.8h, v19.8h, v17.8h\n" + "zip1 v16.8h, v18.8h, v16.8h\n" + "ldr d18, [x25], #0x8\n" + "ldr d21, [x23], #0x8\n" + "zip1 v20.8h, v17.8h, v16.8h\n" + "zip2 v19.8h, v17.8h, v16.8h\n" + "ldr d17, [x22], #0x8\n" + "ldr d16, [x20], #0x8\n" + "zip1 v18.8h, v18.8h, v17.8h\n" + "zip1 v16.8h, v21.8h, v16.8h\n" + "str q20, [x21, #0x0]\n" + "zip1 v17.8h, v18.8h, v16.8h\n" + "zip2 v16.8h, v18.8h, v16.8h\n" + "str q19, [x21, #0x10]\n" + "str q17, [x21, #0x80]\n" + "str q16, [x21, #0x90]\n" + "add x21, x21, #0x20\n" + "bge 6b\n" + "7:" // Main row loop: width 4 loop: skip + "cmp x24, #0x1\n" + "blt 9f\n" + "8:" // Main row loop: width 1 loop: loop + "ldr h19, [x9], #0x2\n" + "ldr h18, [x28], #0x2\n" + "sub x24, x24, #0x1\n" + "cmp x24, #0x1\n" + "ldr h17, [x27], #0x2\n" + "ldr h16, [x26], #0x2\n" + "zip1 v17.8h, v19.8h, v17.8h\n" + "zip1 v16.8h, v18.8h, v16.8h\n" + "ldr h20, [x25], #0x2\n" + "ldr h19, [x23], #0x2\n" + "zip1 v18.8h, v17.8h, v16.8h\n" + "ldr h17, [x22], #0x2\n" + "ldr h16, [x20], #0x2\n" + "zip1 v17.8h, v20.8h, v17.8h\n" + "zip1 v16.8h, v19.8h, v16.8h\n" + "str d18, [x21, #0x0]\n" + "zip1 v16.8h, v17.8h, v16.8h\n" + "str d16, [x21, #0x80]\n" + "add x21, x21, #0x8\n" + "bge 8b\n" + "9:" // Main row loop: width 1 loop: skip + "cmp %x[height], #0x8\n" + "add %x[out], %x[out], #0x100\n" + "bge 1b\n" + "cbz %x[height], 20f\n" + "10:" // Main loop skip + "11:" // Tail row loop: Head + "mov x9, %x[in]\n" + "add x28, x9, %x[in_stride]\n" + "add x27, x28, %x[in_stride]\n" + "mov x20, %x[width]\n" + "add x26, x27, %x[in_stride]\n" + "cmp %x[height], #0x3\n" + "add %x[in], x26, %x[in_stride]\n" + "csel x26, x26, %x[pad_row], GT\n" + "csel x27, x27, %x[pad_row], GE\n" + "cmp %x[height], #0x1\n" + "csel x28, x28, %x[pad_row], GT\n" + "cmp x20, #0x20\n" + "mov x21, %x[out]\n" + "sub %x[height], %x[height], #0x4\n" + "blt 13f\n" + "12:" // Tail row loop: Unroll column loop + "ldr q21, [x9], #0x10\n" + "ldr q20, [x28], #0x10\n" + "sub x20, x20, #0x20\n" + "cmp x20, #0x20\n" + "ldr q17, [x27], #0x10\n" + "ldr q16, [x26], #0x10\n" + "zip1 v4.8h, v21.8h, v17.8h\n" + "zip1 v3.8h, v20.8h, v16.8h\n" + "ldr q19, [x9], #0x10\n" + "ldr q18, [x28], #0x10\n" + "zip2 v2.8h, v21.8h, v17.8h\n" + "zip2 v1.8h, v20.8h, v16.8h\n" + "ldr q17, [x27], #0x10\n" + "ldr q16, [x26], #0x10\n" + "zip1 v0.8h, v19.8h, v17.8h\n" + "zip1 v31.8h, v18.8h, v16.8h\n" + "ldr q24, [x9], #0x10\n" + "ldr q20, [x28], #0x10\n" + "zip2 v30.8h, v19.8h, v17.8h\n" + "zip2 v23.8h, v18.8h, v16.8h\n" + "ldr q17, [x27], #0x10\n" + "ldr q16, [x26], #0x10\n" + "zip1 v22.8h, v24.8h, v17.8h\n" + "zip1 v21.8h, v20.8h, v16.8h\n" + "ldr q19, [x9], #0x10\n" + "ldr q18, [x28], #0x10\n" + "zip2 v29.8h, v24.8h, v17.8h\n" + "zip2 v28.8h, v20.8h, v16.8h\n" + "ldr q17, [x27], #0x10\n" + "ldr q16, [x26], #0x10\n" + "zip1 v27.8h, v19.8h, v17.8h\n" + "zip1 v26.8h, v18.8h, v16.8h\n" + "zip2 v25.8h, v19.8h, v17.8h\n" + "zip2 v24.8h, v18.8h, v16.8h\n" + "zip1 v16.8h, v4.8h, v3.8h\n" + "zip2 v17.8h, v4.8h, v3.8h\n" + "str q16, [x21, #0x0]\n" + "zip1 v16.8h, v2.8h, v1.8h\n" + "zip2 v20.8h, v2.8h, v1.8h\n" + "str q17, [x21, #0x10]\n" + "zip1 v19.8h, v0.8h, v31.8h\n" + "zip2 v18.8h, v0.8h, v31.8h\n" + "str q16, [x21, #0x20]\n" + "zip1 v17.8h, v30.8h, v23.8h\n" + "zip2 v16.8h, v30.8h, v23.8h\n" + "str q20, [x21, #0x30]\n" + "str q19, [x21, #0x40]\n" + "zip1 v23.8h, v22.8h, v21.8h\n" + "zip2 v22.8h, v22.8h, v21.8h\n" + "str q18, [x21, #0x50]\n" + "zip1 v21.8h, v29.8h, v28.8h\n" + "zip2 v20.8h, v29.8h, v28.8h\n" + "str q17, [x21, #0x60]\n" + "zip1 v19.8h, v27.8h, v26.8h\n" + "zip2 v18.8h, v27.8h, v26.8h\n" + "str q16, [x21, #0x70]\n" + "add x21, x21, %x[out_stride]\n" + "zip1 v17.8h, v25.8h, v24.8h\n" + "zip2 v16.8h, v25.8h, v24.8h\n" + "str q23, [x21, #0x0]\n" + "str q22, [x21, #0x10]\n" + "str q21, [x21, #0x20]\n" + "str q20, [x21, #0x30]\n" + "str q19, [x21, #0x40]\n" + "str q18, [x21, #0x50]\n" + "str q17, [x21, #0x60]\n" + "str q16, [x21, #0x70]\n" + "add x21, x21, %x[out_stride]\n" + "bge 12b\n" + "13:" // Tail row loop: Unroll column loop skip + "cmp x20, #0x10\n" + "blt 15f\n" + "14:" // Tail row loop: Column loop + "ldr q19, [x9], #0x10\n" + "ldr q18, [x28], #0x10\n" + "sub x20, x20, #0x10\n" + "cmp x20, #0x10\n" + "ldr q17, [x27], #0x10\n" + "ldr q16, [x26], #0x10\n" + "zip1 v25.8h, v19.8h, v17.8h\n" + "zip1 v24.8h, v18.8h, v16.8h\n" + "ldr q22, [x9], #0x10\n" + "ldr q21, [x28], #0x10\n" + "zip2 v20.8h, v19.8h, v17.8h\n" + "zip2 v19.8h, v18.8h, v16.8h\n" + "ldr q17, [x27], #0x10\n" + "ldr q16, [x26], #0x10\n" + "zip1 v23.8h, v22.8h, v17.8h\n" + "zip1 v18.8h, v21.8h, v16.8h\n" + "zip2 v22.8h, v22.8h, v17.8h\n" + "zip2 v21.8h, v21.8h, v16.8h\n" + "zip1 v16.8h, v25.8h, v24.8h\n" + "zip2 v17.8h, v25.8h, v24.8h\n" + "str q16, [x21, #0x0]\n" + "zip1 v16.8h, v20.8h, v19.8h\n" + "zip2 v20.8h, v20.8h, v19.8h\n" + "str q17, [x21, #0x10]\n" + "zip1 v19.8h, v23.8h, v18.8h\n" + "zip2 v18.8h, v23.8h, v18.8h\n" + "str q16, [x21, #0x20]\n" + "zip1 v17.8h, v22.8h, v21.8h\n" + "zip2 v16.8h, v22.8h, v21.8h\n" + "str q20, [x21, #0x30]\n" + "str q19, [x21, #0x40]\n" + "str q18, [x21, #0x50]\n" + "str q17, [x21, #0x60]\n" + "str q16, [x21, #0x70]\n" + "add x21, x21, %x[out_stride]\n" + "bge 14b\n" + "15:" // Tail row loop: Column loop skip + "cmp x20, #0x4\n" + "blt 17f\n" + "16:" // Tail row loop: width 4 loop: loop + "ldr d18, [x9], #0x8\n" + "ldr d19, [x28], #0x8\n" + "sub x20, x20, #0x4\n" + "cmp x20, #0x4\n" + "ldr d17, [x27], #0x8\n" + "ldr d16, [x26], #0x8\n" + "zip1 v18.8h, v18.8h, v17.8h\n" + "zip1 v16.8h, v19.8h, v16.8h\n" + "zip1 v17.8h, v18.8h, v16.8h\n" + "zip2 v16.8h, v18.8h, v16.8h\n" + "str q17, [x21, #0x0]\n" + "str q16, [x21, #0x10]\n" + "add x21, x21, #0x20\n" + "bge 16b\n" + "17:" // Tail row loop: width 4 loop: skip + "cmp x20, #0x1\n" + "blt 19f\n" + "18:" // Tail row loop: width 1 loop: loop + "ldr h19, [x9], #0x2\n" + "ldr h18, [x28], #0x2\n" + "sub x20, x20, #0x1\n" + "cmp x20, #0x1\n" + "ldr h17, [x27], #0x2\n" + "ldr h16, [x26], #0x2\n" + "zip1 v17.8h, v19.8h, v17.8h\n" + "zip1 v16.8h, v18.8h, v16.8h\n" + "zip1 v16.8h, v17.8h, v16.8h\n" + "str d16, [x21, #0x0]\n" + "add x21, x21, #0x8\n" + "bge 18b\n" + "19:" // Tail row loop: width 1 loop: skip + "cmp %x[height], #0x1\n" + "add %x[out], %x[out], #0x80\n" + "bge 11b\n" + "20:" // Done + : [height] "+&r" (height), [in] "+&r" (in), [out] "+&r" (out) + : [in_stride] "r" (in_stride), [out_stride] "r" (out_stride), [pad_row] "r" (pad_row), [width] "r" (width) + : "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "x9", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27", "x28" + ); +} + +} // anonymous namespace + +template<> +void Transform<16, 4, true, VLType::None>( + bfloat16 *out, const bfloat16 *in, int stride, int x0, int xmax, int k0, int kmax) +{ + a64_transpose_interleave_16_2x4( + reinterpret_cast<uint16_t *>(out), + reinterpret_cast<const uint16_t *>(in + k0 * stride + x0), + (xmax-x0) * sizeof(bfloat16) / 2, + stride * sizeof(bfloat16), + (kmax-k0) + ); +} + + +#endif // defined(__aarch64__) diff --git a/src/core/NEON/kernels/arm_gemm/transforms/a64_transpose_interleave_16_2x4_fp32bf16.hpp b/src/core/NEON/kernels/arm_gemm/transforms/a64_transpose_interleave_16_2x4_fp32bf16.hpp new file mode 100644 index 0000000000..b9fe8b126a --- /dev/null +++ b/src/core/NEON/kernels/arm_gemm/transforms/a64_transpose_interleave_16_2x4_fp32bf16.hpp @@ -0,0 +1,446 @@ +/* + * Copyright (c) 2021, 2023 Arm Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ + +#pragma once + +#if defined(__aarch64__) + +namespace { + +void a64_transpose_interleave_16_2x4_fp32bf16(bfloat16 *out, const float *in, size_t width, size_t in_stride, size_t height) +{ + float *pad_row = reinterpret_cast<float *>(alloca(width * sizeof(float))); + + if (height % 4) { + memset(pad_row, 0, width * sizeof(float)); + } + + size_t out_stride = 16 * roundup<size_t>(height, 4) * sizeof(bfloat16); + + __asm__ __volatile__( + "cmp %x[height], #0x8\n" + "blt 8f\n" + "1:" // Main row loop: Head + "mov x9, %x[in]\n" + "add x28, x9, %x[in_stride]\n" + "add x27, x28, %x[in_stride]\n" + "add x26, x27, %x[in_stride]\n" + "add x25, x26, %x[in_stride]\n" + "mov x24, %x[width]\n" + "add x23, x25, %x[in_stride]\n" + "add x22, x23, %x[in_stride]\n" + "add x20, x22, %x[in_stride]\n" + "cmp x24, #0x10\n" + "add %x[in], x20, %x[in_stride]\n" + "mov x21, %x[out]\n" + "sub %x[height], %x[height], #0x8\n" + "blt 3f\n" + "2:" // Main row loop: Column loop + "ldr q13, [x9], #0x10\n" + "ldr q12, [x28], #0x10\n" + "sub x24, x24, #0x10\n" + "cmp x24, #0x10\n" + "ldr q1, [x27], #0x10\n" + "ldr q9, [x26], #0x10\n" + "zip1 v19.4s, v13.4s, v1.4s\n" + "zip1 v14.4s, v12.4s, v9.4s\n" + "ldr q15, [x25], #0x10\n" + "ldr q4, [x23], #0x10\n" + "zip2 v8.4s, v13.4s, v1.4s\n" + "zip2 v28.4s, v12.4s, v9.4s\n" + "ldr q0, [x22], #0x10\n" + "ldr q1, [x20], #0x10\n" + "zip1 v16.4s, v15.4s, v0.4s\n" + "zip1 v5.4s, v4.4s, v1.4s\n" + "ldr q25, [x9], #0x10\n" + "ldr q24, [x28], #0x10\n" + "zip2 v3.4s, v15.4s, v0.4s\n" + "zip2 v2.4s, v4.4s, v1.4s\n" + "ldr q21, [x27], #0x10\n" + "ldr q30, [x26], #0x10\n" + "zip1 v18.4s, v25.4s, v21.4s\n" + "zip1 v27.4s, v24.4s, v30.4s\n" + "ldr q22, [x25], #0x10\n" + "ldr q20, [x23], #0x10\n" + "zip2 v9.4s, v25.4s, v21.4s\n" + "zip2 v10.4s, v24.4s, v30.4s\n" + "ldr q1, [x22], #0x10\n" + "ldr q21, [x20], #0x10\n" + "zip1 v25.4s, v22.4s, v1.4s\n" + "zip1 v7.4s, v20.4s, v21.4s\n" + "ldr q31, [x9], #0x10\n" + "ldr q17, [x28], #0x10\n" + "zip2 v30.4s, v22.4s, v1.4s\n" + "zip2 v20.4s, v20.4s, v21.4s\n" + "ldr q15, [x27], #0x10\n" + "ldr q24, [x26], #0x10\n" + "zip1 v6.4s, v31.4s, v15.4s\n" + "zip1 v4.4s, v17.4s, v24.4s\n" + "ldr q12, [x25], #0x10\n" + "ldr q29, [x23], #0x10\n" + "zip2 v22.4s, v31.4s, v15.4s\n" + "zip2 v26.4s, v17.4s, v24.4s\n" + "ldr q0, [x22], #0x10\n" + "ldr q24, [x20], #0x10\n" + "zip1 v17.4s, v12.4s, v0.4s\n" + "zip1 v31.4s, v29.4s, v24.4s\n" + "ldr q21, [x9], #0x10\n" + "ldr q1, [x28], #0x10\n" + "zip2 v23.4s, v12.4s, v0.4s\n" + "zip2 v24.4s, v29.4s, v24.4s\n" + "ldr q11, [x27], #0x10\n" + "ldr q29, [x26], #0x10\n" + "zip1 v0.4s, v21.4s, v11.4s\n" + "zip1 v13.4s, v1.4s, v29.4s\n" + "ldr q15, [x25], #0x10\n" + "ldr q12, [x23], #0x10\n" + "zip2 v21.4s, v21.4s, v11.4s\n" + "zip2 v29.4s, v1.4s, v29.4s\n" + "ldr q1, [x22], #0x10\n" + "zip1 v11.4s, v15.4s, v1.4s\n" + "zip2 v1.4s, v15.4s, v1.4s\n" + "zip1 v15.4s, v19.4s, v14.4s\n" + ".inst 0x0ea169ef // bfcvtn v15.4h, v15.4s\n" + "zip2 v14.4s, v19.4s, v14.4s\n" + "ldr q19, [x20], #0x10\n" + ".inst 0x4ea169cf // bfcvtn2 v15.8h, v14.4s\n" + "str q15, [x21, #0x0]\n" + "zip1 v14.4s, v12.4s, v19.4s\n" + "zip2 v15.4s, v12.4s, v19.4s\n" + "zip1 v12.4s, v8.4s, v28.4s\n" + "zip1 v19.4s, v18.4s, v27.4s\n" + ".inst 0x0ea1698c // bfcvtn v12.4h, v12.4s\n" + "zip2 v28.4s, v8.4s, v28.4s\n" + "zip1 v8.4s, v9.4s, v10.4s\n" + ".inst 0x0ea16a73 // bfcvtn v19.4h, v19.4s\n" + "zip2 v18.4s, v18.4s, v27.4s\n" + "zip1 v27.4s, v6.4s, v4.4s\n" + ".inst 0x0ea16908 // bfcvtn v8.4h, v8.4s\n" + "zip2 v10.4s, v9.4s, v10.4s\n" + "zip1 v9.4s, v22.4s, v26.4s\n" + ".inst 0x0ea16b7b // bfcvtn v27.4h, v27.4s\n" + "zip2 v6.4s, v6.4s, v4.4s\n" + "zip1 v4.4s, v0.4s, v13.4s\n" + ".inst 0x0ea16929 // bfcvtn v9.4h, v9.4s\n" + "zip2 v22.4s, v22.4s, v26.4s\n" + "zip1 v26.4s, v21.4s, v29.4s\n" + ".inst 0x0ea16884 // bfcvtn v4.4h, v4.4s\n" + "zip2 v13.4s, v0.4s, v13.4s\n" + "zip1 v0.4s, v16.4s, v5.4s\n" + ".inst 0x0ea16b5a // bfcvtn v26.4h, v26.4s\n" + "zip2 v21.4s, v21.4s, v29.4s\n" + "zip1 v29.4s, v3.4s, v2.4s\n" + ".inst 0x0ea16800 // bfcvtn v0.4h, v0.4s\n" + "zip2 v5.4s, v16.4s, v5.4s\n" + "zip1 v16.4s, v25.4s, v7.4s\n" + ".inst 0x0ea16bbd // bfcvtn v29.4h, v29.4s\n" + "zip2 v2.4s, v3.4s, v2.4s\n" + "zip1 v3.4s, v30.4s, v20.4s\n" + ".inst 0x0ea16a10 // bfcvtn v16.4h, v16.4s\n" + "zip2 v7.4s, v25.4s, v7.4s\n" + "zip1 v25.4s, v17.4s, v31.4s\n" + ".inst 0x0ea16863 // bfcvtn v3.4h, v3.4s\n" + "zip2 v30.4s, v30.4s, v20.4s\n" + "zip1 v20.4s, v23.4s, v24.4s\n" + ".inst 0x0ea16b39 // bfcvtn v25.4h, v25.4s\n" + "zip2 v17.4s, v17.4s, v31.4s\n" + "zip1 v31.4s, v11.4s, v14.4s\n" + ".inst 0x0ea16a94 // bfcvtn v20.4h, v20.4s\n" + "zip2 v24.4s, v23.4s, v24.4s\n" + "zip1 v23.4s, v1.4s, v15.4s\n" + ".inst 0x0ea16bff // bfcvtn v31.4h, v31.4s\n" + "zip2 v14.4s, v11.4s, v14.4s\n" + ".inst 0x0ea16af7 // bfcvtn v23.4h, v23.4s\n" + "zip2 v1.4s, v1.4s, v15.4s\n" + ".inst 0x4ea16b8c // bfcvtn2 v12.8h, v28.4s\n" + "str q12, [x21, #0x10]\n" + ".inst 0x4ea16a53 // bfcvtn2 v19.8h, v18.4s\n" + ".inst 0x4ea16948 // bfcvtn2 v8.8h, v10.4s\n" + "str q19, [x21, #0x20]\n" + ".inst 0x4ea168db // bfcvtn2 v27.8h, v6.4s\n" + ".inst 0x4ea16ac9 // bfcvtn2 v9.8h, v22.4s\n" + "str q8, [x21, #0x30]\n" + ".inst 0x4ea169a4 // bfcvtn2 v4.8h, v13.4s\n" + ".inst 0x4ea16aba // bfcvtn2 v26.8h, v21.4s\n" + "str q27, [x21, #0x40]\n" + ".inst 0x4ea168a0 // bfcvtn2 v0.8h, v5.4s\n" + ".inst 0x4ea1685d // bfcvtn2 v29.8h, v2.4s\n" + "str q9, [x21, #0x50]\n" + ".inst 0x4ea168f0 // bfcvtn2 v16.8h, v7.4s\n" + ".inst 0x4ea16bc3 // bfcvtn2 v3.8h, v30.4s\n" + "str q4, [x21, #0x60]\n" + ".inst 0x4ea16a39 // bfcvtn2 v25.8h, v17.4s\n" + ".inst 0x4ea16b14 // bfcvtn2 v20.8h, v24.4s\n" + "str q26, [x21, #0x70]\n" + ".inst 0x4ea169df // bfcvtn2 v31.8h, v14.4s\n" + ".inst 0x4ea16837 // bfcvtn2 v23.8h, v1.4s\n" + "str q0, [x21, #0x80]\n" + "str q29, [x21, #0x90]\n" + "str q16, [x21, #0xa0]\n" + "str q3, [x21, #0xb0]\n" + "str q25, [x21, #0xc0]\n" + "str q20, [x21, #0xd0]\n" + "str q31, [x21, #0xe0]\n" + "str q23, [x21, #0xf0]\n" + "add x21, x21, %x[out_stride]\n" + "bge 2b\n" + "3:" // Main row loop: Column loop skip + "cmp x24, #0x4\n" + "blt 5f\n" + "4:" // Main row loop: width 4 loop: loop + "ldr q23, [x9], #0x10\n" + "ldr q20, [x28], #0x10\n" + "sub x24, x24, #0x4\n" + "cmp x24, #0x4\n" + "ldr q17, [x27], #0x10\n" + "ldr q16, [x26], #0x10\n" + "zip1 v22.4s, v23.4s, v17.4s\n" + "zip1 v21.4s, v20.4s, v16.4s\n" + "ldr q19, [x25], #0x10\n" + "ldr q18, [x23], #0x10\n" + "zip2 v28.4s, v23.4s, v17.4s\n" + "zip2 v20.4s, v20.4s, v16.4s\n" + "ldr q17, [x22], #0x10\n" + "ldr q16, [x20], #0x10\n" + "zip1 v27.4s, v19.4s, v17.4s\n" + "zip1 v26.4s, v18.4s, v16.4s\n" + "zip2 v25.4s, v19.4s, v17.4s\n" + "zip2 v24.4s, v18.4s, v16.4s\n" + "zip1 v19.4s, v22.4s, v21.4s\n" + "zip1 v18.4s, v28.4s, v20.4s\n" + "zip1 v17.4s, v27.4s, v26.4s\n" + "zip1 v16.4s, v25.4s, v24.4s\n" + ".inst 0x0ea16a77 // bfcvtn v23.4h, v19.4s\n" + "zip2 v22.4s, v22.4s, v21.4s\n" + ".inst 0x0ea16a55 // bfcvtn v21.4h, v18.4s\n" + "zip2 v20.4s, v28.4s, v20.4s\n" + ".inst 0x0ea16a33 // bfcvtn v19.4h, v17.4s\n" + "zip2 v18.4s, v27.4s, v26.4s\n" + ".inst 0x0ea16a11 // bfcvtn v17.4h, v16.4s\n" + "zip2 v16.4s, v25.4s, v24.4s\n" + ".inst 0x4ea16ad7 // bfcvtn2 v23.8h, v22.4s\n" + ".inst 0x4ea16a95 // bfcvtn2 v21.8h, v20.4s\n" + "str q23, [x21, #0x0]\n" + ".inst 0x4ea16a53 // bfcvtn2 v19.8h, v18.4s\n" + ".inst 0x4ea16a11 // bfcvtn2 v17.8h, v16.4s\n" + "str q21, [x21, #0x10]\n" + "str q19, [x21, #0x80]\n" + "str q17, [x21, #0x90]\n" + "add x21, x21, #0x20\n" + "bge 4b\n" + "5:" // Main row loop: width 4 loop: skip + "cmp x24, #0x1\n" + "blt 7f\n" + "6:" // Main row loop: width 1 loop: loop + "ldr s19, [x9], #0x4\n" + "ldr s18, [x28], #0x4\n" + "sub x24, x24, #0x1\n" + "cmp x24, #0x1\n" + "ldr s17, [x27], #0x4\n" + "ldr s16, [x26], #0x4\n" + "zip1 v17.4s, v19.4s, v17.4s\n" + "zip1 v16.4s, v18.4s, v16.4s\n" + "ldr s20, [x25], #0x4\n" + "ldr s19, [x23], #0x4\n" + "zip1 v16.4s, v17.4s, v16.4s\n" + ".inst 0x0ea16a12 // bfcvtn v18.4h, v16.4s\n" + "ldr s17, [x22], #0x4\n" + "ldr s16, [x20], #0x4\n" + "zip1 v17.4s, v20.4s, v17.4s\n" + "zip1 v16.4s, v19.4s, v16.4s\n" + "zip1 v16.4s, v17.4s, v16.4s\n" + ".inst 0x0ea16a10 // bfcvtn v16.4h, v16.4s\n" + "str d18, [x21, #0x0]\n" + "str d16, [x21, #0x80]\n" + "add x21, x21, #0x8\n" + "bge 6b\n" + "7:" // Main row loop: width 1 loop: skip + "cmp %x[height], #0x8\n" + "add %x[out], %x[out], #0x100\n" + "bge 1b\n" + "cbz %x[height], 16f\n" + "8:" // Main loop skip + "9:" // Tail row loop: Head + "mov x9, %x[in]\n" + "add x28, x9, %x[in_stride]\n" + "add x27, x28, %x[in_stride]\n" + "mov x20, %x[width]\n" + "add x26, x27, %x[in_stride]\n" + "cmp %x[height], #0x3\n" + "add %x[in], x26, %x[in_stride]\n" + "csel x26, x26, %x[pad_row], GT\n" + "csel x27, x27, %x[pad_row], GE\n" + "cmp %x[height], #0x1\n" + "csel x28, x28, %x[pad_row], GT\n" + "cmp x20, #0x10\n" + "mov x21, %x[out]\n" + "sub %x[height], %x[height], #0x4\n" + "blt 11f\n" + "10:" // Tail row loop: Column loop + "ldr q21, [x9], #0x10\n" + "ldr q20, [x28], #0x10\n" + "sub x20, x20, #0x10\n" + "cmp x20, #0x10\n" + "ldr q17, [x27], #0x10\n" + "ldr q16, [x26], #0x10\n" + "zip1 v30.4s, v21.4s, v17.4s\n" + "zip1 v29.4s, v20.4s, v16.4s\n" + "ldr q19, [x9], #0x10\n" + "ldr q18, [x28], #0x10\n" + "zip2 v28.4s, v21.4s, v17.4s\n" + "zip2 v27.4s, v20.4s, v16.4s\n" + "ldr q17, [x27], #0x10\n" + "ldr q16, [x26], #0x10\n" + "zip1 v26.4s, v19.4s, v17.4s\n" + "zip1 v25.4s, v18.4s, v16.4s\n" + "ldr q21, [x9], #0x10\n" + "ldr q20, [x28], #0x10\n" + "zip2 v8.4s, v19.4s, v17.4s\n" + "zip2 v24.4s, v18.4s, v16.4s\n" + "ldr q17, [x27], #0x10\n" + "ldr q16, [x26], #0x10\n" + "zip1 v7.4s, v21.4s, v17.4s\n" + "zip1 v6.4s, v20.4s, v16.4s\n" + "ldr q19, [x9], #0x10\n" + "ldr q18, [x28], #0x10\n" + "zip2 v5.4s, v21.4s, v17.4s\n" + "zip2 v4.4s, v20.4s, v16.4s\n" + "ldr q17, [x27], #0x10\n" + "ldr q16, [x26], #0x10\n" + "zip1 v3.4s, v19.4s, v17.4s\n" + "zip1 v2.4s, v18.4s, v16.4s\n" + "zip2 v1.4s, v19.4s, v17.4s\n" + "zip2 v0.4s, v18.4s, v16.4s\n" + "zip1 v23.4s, v30.4s, v29.4s\n" + "zip1 v22.4s, v28.4s, v27.4s\n" + "zip1 v21.4s, v26.4s, v25.4s\n" + "zip1 v20.4s, v8.4s, v24.4s\n" + "zip1 v19.4s, v7.4s, v6.4s\n" + "zip1 v18.4s, v5.4s, v4.4s\n" + "zip1 v17.4s, v3.4s, v2.4s\n" + "zip1 v16.4s, v1.4s, v0.4s\n" + ".inst 0x0ea16aff // bfcvtn v31.4h, v23.4s\n" + "zip2 v30.4s, v30.4s, v29.4s\n" + ".inst 0x0ea16add // bfcvtn v29.4h, v22.4s\n" + "zip2 v28.4s, v28.4s, v27.4s\n" + ".inst 0x0ea16abb // bfcvtn v27.4h, v21.4s\n" + "zip2 v26.4s, v26.4s, v25.4s\n" + ".inst 0x0ea16a99 // bfcvtn v25.4h, v20.4s\n" + "zip2 v24.4s, v8.4s, v24.4s\n" + ".inst 0x0ea16a77 // bfcvtn v23.4h, v19.4s\n" + "zip2 v22.4s, v7.4s, v6.4s\n" + ".inst 0x0ea16a55 // bfcvtn v21.4h, v18.4s\n" + "zip2 v20.4s, v5.4s, v4.4s\n" + ".inst 0x0ea16a33 // bfcvtn v19.4h, v17.4s\n" + "zip2 v18.4s, v3.4s, v2.4s\n" + ".inst 0x0ea16a11 // bfcvtn v17.4h, v16.4s\n" + "zip2 v16.4s, v1.4s, v0.4s\n" + ".inst 0x4ea16bdf // bfcvtn2 v31.8h, v30.4s\n" + ".inst 0x4ea16b9d // bfcvtn2 v29.8h, v28.4s\n" + "str q31, [x21, #0x0]\n" + ".inst 0x4ea16b5b // bfcvtn2 v27.8h, v26.4s\n" + ".inst 0x4ea16b19 // bfcvtn2 v25.8h, v24.4s\n" + "str q29, [x21, #0x10]\n" + ".inst 0x4ea16ad7 // bfcvtn2 v23.8h, v22.4s\n" + ".inst 0x4ea16a95 // bfcvtn2 v21.8h, v20.4s\n" + "str q27, [x21, #0x20]\n" + ".inst 0x4ea16a53 // bfcvtn2 v19.8h, v18.4s\n" + ".inst 0x4ea16a11 // bfcvtn2 v17.8h, v16.4s\n" + "str q25, [x21, #0x30]\n" + "str q23, [x21, #0x40]\n" + "str q21, [x21, #0x50]\n" + "str q19, [x21, #0x60]\n" + "str q17, [x21, #0x70]\n" + "add x21, x21, %x[out_stride]\n" + "bge 10b\n" + "11:" // Tail row loop: Column loop skip + "cmp x20, #0x4\n" + "blt 13f\n" + "12:" // Tail row loop: width 4 loop: loop + "ldr q20, [x9], #0x10\n" + "ldr q19, [x28], #0x10\n" + "sub x20, x20, #0x4\n" + "cmp x20, #0x4\n" + "ldr q17, [x27], #0x10\n" + "ldr q16, [x26], #0x10\n" + "zip1 v22.4s, v20.4s, v17.4s\n" + "zip1 v18.4s, v19.4s, v16.4s\n" + "zip2 v21.4s, v20.4s, v17.4s\n" + "zip2 v20.4s, v19.4s, v16.4s\n" + "zip1 v17.4s, v22.4s, v18.4s\n" + "zip1 v16.4s, v21.4s, v20.4s\n" + ".inst 0x0ea16a33 // bfcvtn v19.4h, v17.4s\n" + "zip2 v18.4s, v22.4s, v18.4s\n" + ".inst 0x0ea16a11 // bfcvtn v17.4h, v16.4s\n" + "zip2 v16.4s, v21.4s, v20.4s\n" + ".inst 0x4ea16a53 // bfcvtn2 v19.8h, v18.4s\n" + ".inst 0x4ea16a11 // bfcvtn2 v17.8h, v16.4s\n" + "str q19, [x21, #0x0]\n" + "str q17, [x21, #0x10]\n" + "add x21, x21, #0x20\n" + "bge 12b\n" + "13:" // Tail row loop: width 4 loop: skip + "cmp x20, #0x1\n" + "blt 15f\n" + "14:" // Tail row loop: width 1 loop: loop + "ldr s19, [x9], #0x4\n" + "ldr s18, [x28], #0x4\n" + "sub x20, x20, #0x1\n" + "cmp x20, #0x1\n" + "ldr s17, [x27], #0x4\n" + "ldr s16, [x26], #0x4\n" + "zip1 v17.4s, v19.4s, v17.4s\n" + "zip1 v16.4s, v18.4s, v16.4s\n" + "zip1 v16.4s, v17.4s, v16.4s\n" + ".inst 0x0ea16a10 // bfcvtn v16.4h, v16.4s\n" + "str d16, [x21, #0x0]\n" + "add x21, x21, #0x8\n" + "bge 14b\n" + "15:" // Tail row loop: width 1 loop: skip + "cmp %x[height], #0x1\n" + "add %x[out], %x[out], #0x80\n" + "bge 9b\n" + "16:" // Done + : [height] "+&r" (height), [in] "+&r" (in), [out] "+&r" (out) + : [in_stride] "r" (in_stride), [out_stride] "r" (out_stride), [pad_row] "r" (pad_row), [width] "r" (width) + : "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "x9", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27", "x28" + ); +} + +} // anonymous namespace +template<> +void Transform<16, 4, true, VLType::None>( + bfloat16 *out, const float *in, int stride, int x0, int xmax, int k0, int kmax) +{ + a64_transpose_interleave_16_2x4_fp32bf16( + out, + in + k0 * stride + x0, + (xmax-x0), + stride * sizeof(float), + (kmax-k0) + ); +} + + +#endif // defined(__aarch64__) diff --git a/src/core/NEON/kernels/arm_gemm/transforms/a64_transpose_interleave_24.hpp b/src/core/NEON/kernels/arm_gemm/transforms/a64_transpose_interleave_24.hpp new file mode 100644 index 0000000000..46211ad4e4 --- /dev/null +++ b/src/core/NEON/kernels/arm_gemm/transforms/a64_transpose_interleave_24.hpp @@ -0,0 +1,271 @@ +/* + * Copyright (c) 2021, 2023 Arm Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ + +#pragma once + +#if defined(__aarch64__) + +namespace { + +void a64_transpose_interleave_24(uint16_t *out, const uint16_t *in, size_t width, size_t in_stride, size_t height) +{ + size_t out_stride = 12 * height * sizeof(uint16_t); + + __asm__ __volatile__( + "cmp %x[height], #0x4\n" + "blt 10f\n" + "1:" // Main row loop: Head + "mov x25, %x[in]\n" + "mov x24, %x[width]\n" + "add x23, x25, %x[in_stride]\n" + "add x22, x23, %x[in_stride]\n" + "add x20, x22, %x[in_stride]\n" + "cmp x24, #0x18\n" + "add %x[in], x20, %x[in_stride]\n" + "mov x21, %x[out]\n" + "sub %x[height], %x[height], #0x4\n" + "blt 3f\n" + "2:" // Main row loop: Unroll column loop + "ldr q1, [x25], #0x10\n" + "ldr q0, [x22], #0x10\n" + "sub x24, x24, #0x18\n" + "cmp x24, #0x18\n" + "ldr q17, [x25], #0x10\n" + "ldr q31, [x23], #0x10\n" + "dup v30.2d, v17.d[0]\n" + "dup v29.2d, v31.d[1]\n" + "ldr q16, [x22], #0x10\n" + "ldr q28, [x20], #0x10\n" + "dup v27.2d, v16.d[0]\n" + "dup v26.2d, v28.d[1]\n" + "ldr q25, [x25], #0x10\n" + "ldr q24, [x22], #0x10\n" + "dup v23.2d, v17.d[1]\n" + "dup v22.2d, v25.d[1]\n" + "ldr q21, [x23], #0x10\n" + "ldr q20, [x20], #0x10\n" + "dup v19.2d, v16.d[1]\n" + "dup v18.2d, v24.d[1]\n" + "ldr q17, [x23], #0x10\n" + "ldr q16, [x20], #0x10\n" + "mov v30.d[1], v31.d[0]\n" + "mov v29.d[1], v21.d[0]\n" + "mov v27.d[1], v28.d[0]\n" + "mov v26.d[1], v20.d[0]\n" + "str q1, [x21, #0x0]\n" + "str q30, [x21, #0x10]\n" + "mov v23.d[1], v25.d[0]\n" + "mov v22.d[1], v21.d[1]\n" + "str q29, [x21, #0x20]\n" + "mov v19.d[1], v24.d[0]\n" + "mov v18.d[1], v20.d[1]\n" + "str q0, [x21, #0x30]\n" + "str q27, [x21, #0x40]\n" + "str q26, [x21, #0x50]\n" + "add x21, x21, %x[out_stride]\n" + "str q23, [x21, #0x0]\n" + "str q22, [x21, #0x10]\n" + "str q17, [x21, #0x20]\n" + "str q19, [x21, #0x30]\n" + "str q18, [x21, #0x40]\n" + "str q16, [x21, #0x50]\n" + "add x21, x21, %x[out_stride]\n" + "bge 2b\n" + "3:" // Main row loop: Unroll column loop skip + "cmp x24, #0xc\n" + "blt 5f\n" + "4:" // Main row loop: Column loop + "ldr q17, [x23], #0x10\n" + "ldr q23, [x20], #0x10\n" + "dup v22.2d, v17.d[1]\n" + "dup v21.2d, v23.d[1]\n" + "ldr q20, [x25], #0x10\n" + "ldr q19, [x22], #0x10\n" + "sub x24, x24, #0xc\n" + "cmp x24, #0xc\n" + "ldr d18, [x25], #0x8\n" + "ldr d16, [x23], #0x8\n" + "mov v18.d[1], v17.d[0]\n" + "mov v22.d[1], v16.d[0]\n" + "ldr d17, [x22], #0x8\n" + "ldr d16, [x20], #0x8\n" + "mov v17.d[1], v23.d[0]\n" + "mov v21.d[1], v16.d[0]\n" + "str q20, [x21, #0x0]\n" + "str q18, [x21, #0x10]\n" + "str q22, [x21, #0x20]\n" + "str q19, [x21, #0x30]\n" + "str q17, [x21, #0x40]\n" + "str q21, [x21, #0x50]\n" + "add x21, x21, %x[out_stride]\n" + "bge 4b\n" + "5:" // Main row loop: Column loop skip + "cmp x24, #0x4\n" + "blt 7f\n" + "6:" // Main row loop: width 4 loop: loop + "ldr d19, [x25], #0x8\n" + "ldr d18, [x23], #0x8\n" + "sub x24, x24, #0x4\n" + "cmp x24, #0x4\n" + "ldr d17, [x22], #0x8\n" + "ldr d16, [x20], #0x8\n" + "str d19, [x21, #0x0]\n" + "str d18, [x21, #0x18]\n" + "str d17, [x21, #0x30]\n" + "str d16, [x21, #0x48]\n" + "add x21, x21, #0x8\n" + "bge 6b\n" + "7:" // Main row loop: width 4 loop: skip + "cmp x24, #0x1\n" + "blt 9f\n" + "8:" // Main row loop: width 1 loop: loop + "ldr h19, [x25], #0x2\n" + "ldr h18, [x23], #0x2\n" + "sub x24, x24, #0x1\n" + "cmp x24, #0x1\n" + "ldr h17, [x22], #0x2\n" + "ldr h16, [x20], #0x2\n" + "str h19, [x21, #0x0]\n" + "str h18, [x21, #0x18]\n" + "str h17, [x21, #0x30]\n" + "str h16, [x21, #0x48]\n" + "add x21, x21, #0x2\n" + "bge 8b\n" + "9:" // Main row loop: width 1 loop: skip + "cmp %x[height], #0x4\n" + "add %x[out], %x[out], #0x60\n" + "bge 1b\n" + "cbz %x[height], 20f\n" + "10:" // Main loop skip + "11:" // Tail row loop: Head + "mov x20, %x[width]\n" + "mov x25, %x[in]\n" + "cmp x20, #0x18\n" + "add %x[in], x25, %x[in_stride]\n" + "mov x21, %x[out]\n" + "sub %x[height], %x[height], #0x1\n" + "blt 13f\n" + "12:" // Tail row loop: Unroll column loop + "ldr q19, [x25], #0x10\n" + "ldr q16, [x25], #0x10\n" + "dup v18.2d, v16.d[1]\n" + "sub x20, x20, #0x18\n" + "ldr q17, [x25], #0x10\n" + "dup v16.2d, v16.d[0]\n" + "str q19, [x21, #0x0]\n" + "cmp x20, #0x18\n" + "str d16, [x21, #0x10]\n" + "add x21, x21, %x[out_stride]\n" + "mov v18.d[1], v17.d[0]\n" + "dup v16.2d, v17.d[1]\n" + "str q18, [x21, #0x0]\n" + "str d16, [x21, #0x10]\n" + "add x21, x21, %x[out_stride]\n" + "bge 12b\n" + "13:" // Tail row loop: Unroll column loop skip + "cmp x20, #0xc\n" + "blt 15f\n" + "14:" // Tail row loop: Column loop + "ldr q17, [x25], #0x10\n" + "ldr d16, [x25], #0x8\n" + "sub x20, x20, #0xc\n" + "cmp x20, #0xc\n" + "str q17, [x21, #0x0]\n" + "str d16, [x21, #0x10]\n" + "add x21, x21, %x[out_stride]\n" + "bge 14b\n" + "15:" // Tail row loop: Column loop skip + "cmp x20, #0x4\n" + "blt 17f\n" + "16:" // Tail row loop: width 4 loop: loop + "ldr d16, [x25], #0x8\n" + "sub x20, x20, #0x4\n" + "cmp x20, #0x4\n" + "str d16, [x21, #0x0]\n" + "add x21, x21, #0x8\n" + "bge 16b\n" + "17:" // Tail row loop: width 4 loop: skip + "cmp x20, #0x1\n" + "blt 19f\n" + "18:" // Tail row loop: width 1 loop: loop + "ldr h16, [x25], #0x2\n" + "sub x20, x20, #0x1\n" + "cmp x20, #0x1\n" + "str h16, [x21, #0x0]\n" + "add x21, x21, #0x2\n" + "bge 18b\n" + "19:" // Tail row loop: width 1 loop: skip + "cmp %x[height], #0x1\n" + "add %x[out], %x[out], #0x18\n" + "bge 11b\n" + "20:" // Done + : [height] "+&r" (height), [in] "+&r" (in), [out] "+&r" (out) + : [in_stride] "r" (in_stride), [out_stride] "r" (out_stride), [width] "r" (width) + : "cc", "memory", "v0", "v1", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "x20", "x21", "x22", "x23", "x24", "x25" + ); +} + +} // anonymous namespace + +template<> +void Transform<6, 1, true, VLType::None>( + float *out, const float *in, int stride, int x0, int xmax, int k0, int kmax) +{ + a64_transpose_interleave_24( + reinterpret_cast<uint16_t *>(out), + reinterpret_cast<const uint16_t *>(in + k0 * stride + x0), + (xmax-x0) * sizeof(float) / 2, + stride * sizeof(float), + (kmax-k0) + ); +} + +template<> +void Transform<12, 1, true, VLType::None>( + int16_t *out, const int16_t *in, int stride, int x0, int xmax, int k0, int kmax) +{ + a64_transpose_interleave_24( + reinterpret_cast<uint16_t *>(out), + reinterpret_cast<const uint16_t *>(in + k0 * stride + x0), + (xmax-x0) * sizeof(int16_t) / 2, + stride * sizeof(int16_t), + (kmax-k0) + ); +} + +template<> +void Transform<12, 1, true, VLType::None>( + uint16_t *out, const uint16_t *in, int stride, int x0, int xmax, int k0, int kmax) +{ + a64_transpose_interleave_24( + reinterpret_cast<uint16_t *>(out), + reinterpret_cast<const uint16_t *>(in + k0 * stride + x0), + (xmax-x0) * sizeof(uint16_t) / 2, + stride * sizeof(uint16_t), + (kmax-k0) + ); +} + + +#endif // defined(__aarch64__) diff --git a/src/core/NEON/kernels/arm_gemm/transforms/a64_transpose_interleave_24_2x4_fp32bf16.hpp b/src/core/NEON/kernels/arm_gemm/transforms/a64_transpose_interleave_24_2x4_fp32bf16.hpp new file mode 100644 index 0000000000..1cb7bc4445 --- /dev/null +++ b/src/core/NEON/kernels/arm_gemm/transforms/a64_transpose_interleave_24_2x4_fp32bf16.hpp @@ -0,0 +1,786 @@ +/* + * Copyright (c) 2021, 2023 Arm Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ + +#pragma once + +#if defined(__aarch64__) + +namespace { + +void a64_transpose_interleave_24_2x4_fp32bf16(bfloat16 *out, const float *in, size_t width, size_t in_stride, size_t height) +{ + float *pad_row = reinterpret_cast<float *>(alloca(width * sizeof(float))); + + if (height % 4) { + memset(pad_row, 0, width * sizeof(float)); + } + + size_t out_stride = 24 * roundup<size_t>(height, 4) * sizeof(bfloat16); + + __asm__ __volatile__( + "cmp %x[height], #0x8\n" + "blt 10f\n" + "1:" // Main row loop: Head + "mov x9, %x[in]\n" + "add x28, x9, %x[in_stride]\n" + "add x27, x28, %x[in_stride]\n" + "add x26, x27, %x[in_stride]\n" + "add x25, x26, %x[in_stride]\n" + "mov x24, %x[width]\n" + "add x23, x25, %x[in_stride]\n" + "add x22, x23, %x[in_stride]\n" + "add x20, x22, %x[in_stride]\n" + "cmp x24, #0x18\n" + "add %x[in], x20, %x[in_stride]\n" + "mov x21, %x[out]\n" + "sub %x[height], %x[height], #0x8\n" + "blt 3f\n" + "2:" // Main row loop: Column loop + "ldr q15, [x9], #0x10\n" + "ldr q1, [x28], #0x10\n" + "sub x24, x24, #0x18\n" + "cmp x24, #0x18\n" + "ldr q0, [x27], #0x10\n" + "ldr q27, [x26], #0x10\n" + "zip1 v18.4s, v15.4s, v0.4s\n" + "zip1 v20.4s, v1.4s, v27.4s\n" + "ldr q13, [x25], #0x10\n" + "ldr q14, [x23], #0x10\n" + "zip2 v16.4s, v15.4s, v0.4s\n" + "zip2 v3.4s, v1.4s, v27.4s\n" + "ldr q12, [x22], #0x10\n" + "ldr q11, [x20], #0x10\n" + "zip1 v4.4s, v13.4s, v12.4s\n" + "zip1 v28.4s, v14.4s, v11.4s\n" + "ldr q5, [x9], #0x10\n" + "ldr q30, [x28], #0x10\n" + "zip2 v23.4s, v13.4s, v12.4s\n" + "zip2 v19.4s, v14.4s, v11.4s\n" + "ldr q25, [x27], #0x10\n" + "ldr q11, [x26], #0x10\n" + "zip1 v21.4s, v5.4s, v25.4s\n" + "zip1 v14.4s, v30.4s, v11.4s\n" + "ldr q6, [x25], #0x10\n" + "ldr q27, [x23], #0x10\n" + "zip2 v29.4s, v5.4s, v25.4s\n" + "zip2 v17.4s, v30.4s, v11.4s\n" + "ldr q2, [x22], #0x10\n" + "ldr q10, [x20], #0x10\n" + "zip1 v11.4s, v6.4s, v2.4s\n" + "zip1 v1.4s, v27.4s, v10.4s\n" + "ldr q8, [x9], #0x10\n" + "ldr q5, [x28], #0x10\n" + "zip2 v24.4s, v6.4s, v2.4s\n" + "zip2 v0.4s, v27.4s, v10.4s\n" + "ldr q6, [x27], #0x10\n" + "ldr q31, [x26], #0x10\n" + "zip1 v12.4s, v8.4s, v6.4s\n" + "zip1 v10.4s, v5.4s, v31.4s\n" + "ldr q30, [x25], #0x10\n" + "ldr q2, [x23], #0x10\n" + "zip2 v9.4s, v8.4s, v6.4s\n" + "zip2 v13.4s, v5.4s, v31.4s\n" + "ldr q7, [x22], #0x10\n" + "ldr q8, [x20], #0x10\n" + "zip1 v27.4s, v30.4s, v7.4s\n" + "zip1 v31.4s, v2.4s, v8.4s\n" + "ldr q5, [x9], #0x10\n" + "ldr q26, [x28], #0x10\n" + "zip2 v22.4s, v30.4s, v7.4s\n" + "zip2 v8.4s, v2.4s, v8.4s\n" + "ldr q2, [x27], #0x10\n" + "ldr q6, [x26], #0x10\n" + "zip1 v25.4s, v5.4s, v2.4s\n" + "zip1 v15.4s, v26.4s, v6.4s\n" + "ldr q7, [x25], #0x10\n" + "ldr q30, [x23], #0x10\n" + "zip2 v5.4s, v5.4s, v2.4s\n" + "zip2 v26.4s, v26.4s, v6.4s\n" + "ldr q2, [x22], #0x10\n" + "zip1 v6.4s, v7.4s, v2.4s\n" + "zip2 v7.4s, v7.4s, v2.4s\n" + "zip1 v2.4s, v18.4s, v20.4s\n" + ".inst 0x0ea16842 // bfcvtn v2.4h, v2.4s\n" + "zip2 v20.4s, v18.4s, v20.4s\n" + "ldr q18, [x20], #0x10\n" + ".inst 0x4ea16a82 // bfcvtn2 v2.8h, v20.4s\n" + "zip1 v20.4s, v30.4s, v18.4s\n" + "zip2 v18.4s, v30.4s, v18.4s\n" + "zip1 v30.4s, v16.4s, v3.4s\n" + ".inst 0x0ea16bde // bfcvtn v30.4h, v30.4s\n" + "zip2 v3.4s, v16.4s, v3.4s\n" + "ldr q16, [x9], #0x10\n" + ".inst 0x4ea1687e // bfcvtn2 v30.8h, v3.4s\n" + "zip1 v3.4s, v21.4s, v14.4s\n" + ".inst 0x0ea16863 // bfcvtn v3.4h, v3.4s\n" + "zip2 v21.4s, v21.4s, v14.4s\n" + "ldr q14, [x28], #0x10\n" + ".inst 0x4ea16aa3 // bfcvtn2 v3.8h, v21.4s\n" + "zip1 v21.4s, v29.4s, v17.4s\n" + ".inst 0x0ea16ab5 // bfcvtn v21.4h, v21.4s\n" + "zip2 v29.4s, v29.4s, v17.4s\n" + "ldr q17, [x27], #0x10\n" + ".inst 0x4ea16bb5 // bfcvtn2 v21.8h, v29.4s\n" + "zip1 v29.4s, v16.4s, v17.4s\n" + "zip2 v16.4s, v16.4s, v17.4s\n" + "zip1 v17.4s, v12.4s, v10.4s\n" + ".inst 0x0ea16a31 // bfcvtn v17.4h, v17.4s\n" + "zip2 v10.4s, v12.4s, v10.4s\n" + "ldr q12, [x26], #0x10\n" + ".inst 0x4ea16951 // bfcvtn2 v17.8h, v10.4s\n" + "zip1 v10.4s, v14.4s, v12.4s\n" + "zip2 v14.4s, v14.4s, v12.4s\n" + "zip1 v12.4s, v9.4s, v13.4s\n" + ".inst 0x0ea1698c // bfcvtn v12.4h, v12.4s\n" + "zip2 v13.4s, v9.4s, v13.4s\n" + "ldr q9, [x25], #0x10\n" + ".inst 0x4ea169ac // bfcvtn2 v12.8h, v13.4s\n" + "zip1 v13.4s, v25.4s, v15.4s\n" + ".inst 0x0ea169ad // bfcvtn v13.4h, v13.4s\n" + "zip2 v25.4s, v25.4s, v15.4s\n" + "ldr q15, [x23], #0x10\n" + ".inst 0x4ea16b2d // bfcvtn2 v13.8h, v25.4s\n" + "zip1 v25.4s, v5.4s, v26.4s\n" + ".inst 0x0ea16b39 // bfcvtn v25.4h, v25.4s\n" + "zip2 v5.4s, v5.4s, v26.4s\n" + "ldr q26, [x22], #0x10\n" + ".inst 0x4ea168b9 // bfcvtn2 v25.8h, v5.4s\n" + "zip1 v5.4s, v9.4s, v26.4s\n" + "zip2 v9.4s, v9.4s, v26.4s\n" + "zip1 v26.4s, v29.4s, v10.4s\n" + ".inst 0x0ea16b5a // bfcvtn v26.4h, v26.4s\n" + "zip2 v10.4s, v29.4s, v10.4s\n" + "ldr q29, [x20], #0x10\n" + ".inst 0x4ea1695a // bfcvtn2 v26.8h, v10.4s\n" + "zip1 v10.4s, v15.4s, v29.4s\n" + "zip2 v15.4s, v15.4s, v29.4s\n" + "zip1 v29.4s, v16.4s, v14.4s\n" + ".inst 0x0ea16bbd // bfcvtn v29.4h, v29.4s\n" + "zip2 v14.4s, v16.4s, v14.4s\n" + "ldr q16, [x9], #0x10\n" + ".inst 0x4ea169dd // bfcvtn2 v29.8h, v14.4s\n" + "zip1 v14.4s, v4.4s, v28.4s\n" + ".inst 0x0ea169ce // bfcvtn v14.4h, v14.4s\n" + "zip2 v4.4s, v4.4s, v28.4s\n" + "ldr q28, [x28], #0x10\n" + ".inst 0x4ea1688e // bfcvtn2 v14.8h, v4.4s\n" + "zip1 v4.4s, v23.4s, v19.4s\n" + ".inst 0x0ea16884 // bfcvtn v4.4h, v4.4s\n" + "zip2 v19.4s, v23.4s, v19.4s\n" + "ldr q23, [x27], #0x10\n" + ".inst 0x4ea16a64 // bfcvtn2 v4.8h, v19.4s\n" + "zip1 v19.4s, v16.4s, v23.4s\n" + "zip2 v16.4s, v16.4s, v23.4s\n" + "zip1 v23.4s, v11.4s, v1.4s\n" + ".inst 0x0ea16af7 // bfcvtn v23.4h, v23.4s\n" + "zip2 v1.4s, v11.4s, v1.4s\n" + "ldr q11, [x26], #0x10\n" + ".inst 0x4ea16837 // bfcvtn2 v23.8h, v1.4s\n" + "zip1 v1.4s, v28.4s, v11.4s\n" + "zip2 v28.4s, v28.4s, v11.4s\n" + "zip1 v11.4s, v19.4s, v1.4s\n" + ".inst 0x0ea1696b // bfcvtn v11.4h, v11.4s\n" + "zip2 v19.4s, v19.4s, v1.4s\n" + "ldr q1, [x25], #0x10\n" + ".inst 0x4ea16a6b // bfcvtn2 v11.8h, v19.4s\n" + "zip1 v19.4s, v16.4s, v28.4s\n" + ".inst 0x0ea16a73 // bfcvtn v19.4h, v19.4s\n" + "zip2 v16.4s, v16.4s, v28.4s\n" + "ldr q28, [x23], #0x10\n" + ".inst 0x4ea16a13 // bfcvtn2 v19.8h, v16.4s\n" + "zip1 v16.4s, v24.4s, v0.4s\n" + ".inst 0x0ea16a10 // bfcvtn v16.4h, v16.4s\n" + "zip2 v24.4s, v24.4s, v0.4s\n" + "ldr q0, [x22], #0x10\n" + ".inst 0x4ea16b10 // bfcvtn2 v16.8h, v24.4s\n" + "ldr q24, [x20], #0x10\n" + "str q2, [x21, #0x0]\n" + "zip1 v2.4s, v1.4s, v0.4s\n" + "zip2 v0.4s, v1.4s, v0.4s\n" + "zip1 v1.4s, v28.4s, v24.4s\n" + "zip2 v28.4s, v28.4s, v24.4s\n" + "str q30, [x21, #0x10]\n" + "zip1 v24.4s, v27.4s, v31.4s\n" + "zip1 v30.4s, v22.4s, v8.4s\n" + "str q3, [x21, #0x20]\n" + "zip1 v3.4s, v6.4s, v20.4s\n" + ".inst 0x0ea16b18 // bfcvtn v24.4h, v24.4s\n" + "str q21, [x21, #0x30]\n" + "zip1 v21.4s, v7.4s, v18.4s\n" + "zip2 v31.4s, v27.4s, v31.4s\n" + "str q17, [x21, #0x40]\n" + "zip1 v17.4s, v5.4s, v10.4s\n" + "zip1 v27.4s, v9.4s, v15.4s\n" + "str q12, [x21, #0x50]\n" + "zip1 v12.4s, v2.4s, v1.4s\n" + ".inst 0x0ea16bde // bfcvtn v30.4h, v30.4s\n" + "str q13, [x21, #0x60]\n" + "zip1 v13.4s, v0.4s, v28.4s\n" + "zip2 v22.4s, v22.4s, v8.4s\n" + "str q25, [x21, #0x70]\n" + ".inst 0x0ea16879 // bfcvtn v25.4h, v3.4s\n" + "zip2 v8.4s, v6.4s, v20.4s\n" + "str q26, [x21, #0x80]\n" + ".inst 0x0ea16aa3 // bfcvtn v3.4h, v21.4s\n" + "zip2 v18.4s, v7.4s, v18.4s\n" + "str q29, [x21, #0x90]\n" + ".inst 0x0ea16a27 // bfcvtn v7.4h, v17.4s\n" + "zip2 v21.4s, v5.4s, v10.4s\n" + "str q11, [x21, #0xa0]\n" + ".inst 0x0ea16b65 // bfcvtn v5.4h, v27.4s\n" + "zip2 v15.4s, v9.4s, v15.4s\n" + "str q19, [x21, #0xb0]\n" + ".inst 0x0ea16991 // bfcvtn v17.4h, v12.4s\n" + "zip2 v20.4s, v2.4s, v1.4s\n" + "str q14, [x21, #0xc0]\n" + ".inst 0x0ea169bb // bfcvtn v27.4h, v13.4s\n" + "zip2 v29.4s, v0.4s, v28.4s\n" + "str q4, [x21, #0xd0]\n" + ".inst 0x4ea16bf8 // bfcvtn2 v24.8h, v31.4s\n" + ".inst 0x4ea16ade // bfcvtn2 v30.8h, v22.4s\n" + "str q23, [x21, #0xe0]\n" + ".inst 0x4ea16919 // bfcvtn2 v25.8h, v8.4s\n" + ".inst 0x4ea16a43 // bfcvtn2 v3.8h, v18.4s\n" + "str q16, [x21, #0xf0]\n" + ".inst 0x4ea16aa7 // bfcvtn2 v7.8h, v21.4s\n" + ".inst 0x4ea169e5 // bfcvtn2 v5.8h, v15.4s\n" + "str q24, [x21, #0x100]\n" + ".inst 0x4ea16a91 // bfcvtn2 v17.8h, v20.4s\n" + ".inst 0x4ea16bbb // bfcvtn2 v27.8h, v29.4s\n" + "str q30, [x21, #0x110]\n" + "str q25, [x21, #0x120]\n" + "str q3, [x21, #0x130]\n" + "str q7, [x21, #0x140]\n" + "str q5, [x21, #0x150]\n" + "str q17, [x21, #0x160]\n" + "str q27, [x21, #0x170]\n" + "add x21, x21, %x[out_stride]\n" + "bge 2b\n" + "3:" // Main row loop: Column loop skip + "cmp x24, #0x10\n" + "blt 5f\n" + "4:" // Main row loop: width 16 loop: loop + "ldr q9, [x9], #0x10\n" + "ldr q18, [x28], #0x10\n" + "sub x24, x24, #0x10\n" + "cmp x24, #0x10\n" + "ldr q15, [x27], #0x10\n" + "ldr q17, [x26], #0x10\n" + "zip1 v14.4s, v9.4s, v15.4s\n" + "zip1 v11.4s, v18.4s, v17.4s\n" + "ldr q7, [x25], #0x10\n" + "ldr q16, [x23], #0x10\n" + "zip2 v12.4s, v9.4s, v15.4s\n" + "zip2 v6.4s, v18.4s, v17.4s\n" + "ldr q15, [x22], #0x10\n" + "ldr q3, [x20], #0x10\n" + "zip1 v30.4s, v7.4s, v15.4s\n" + "zip1 v20.4s, v16.4s, v3.4s\n" + "ldr q17, [x9], #0x10\n" + "ldr q9, [x28], #0x10\n" + "zip2 v1.4s, v7.4s, v15.4s\n" + "zip2 v24.4s, v16.4s, v3.4s\n" + "ldr q10, [x27], #0x10\n" + "ldr q16, [x26], #0x10\n" + "zip1 v0.4s, v17.4s, v10.4s\n" + "zip1 v8.4s, v9.4s, v16.4s\n" + "ldr q7, [x25], #0x10\n" + "ldr q2, [x23], #0x10\n" + "zip2 v17.4s, v17.4s, v10.4s\n" + "zip2 v3.4s, v9.4s, v16.4s\n" + "ldr q9, [x22], #0x10\n" + "ldr q10, [x20], #0x10\n" + "zip1 v25.4s, v7.4s, v9.4s\n" + "zip1 v23.4s, v2.4s, v10.4s\n" + "ldr q31, [x9], #0x10\n" + "ldr q21, [x28], #0x10\n" + "zip2 v16.4s, v7.4s, v9.4s\n" + "zip2 v27.4s, v2.4s, v10.4s\n" + "ldr q26, [x27], #0x10\n" + "ldr q19, [x26], #0x10\n" + "zip1 v2.4s, v31.4s, v26.4s\n" + "zip1 v7.4s, v21.4s, v19.4s\n" + "ldr q29, [x25], #0x10\n" + "ldr q13, [x23], #0x10\n" + "zip2 v31.4s, v31.4s, v26.4s\n" + "zip2 v19.4s, v21.4s, v19.4s\n" + "ldr q4, [x22], #0x10\n" + "ldr q18, [x20], #0x10\n" + "zip1 v26.4s, v29.4s, v4.4s\n" + "zip1 v15.4s, v13.4s, v18.4s\n" + "ldr q9, [x9], #0x10\n" + "ldr q22, [x28], #0x10\n" + "zip2 v4.4s, v29.4s, v4.4s\n" + "zip2 v18.4s, v13.4s, v18.4s\n" + "ldr q29, [x27], #0x10\n" + "ldr q10, [x26], #0x10\n" + "zip1 v21.4s, v9.4s, v29.4s\n" + "zip1 v5.4s, v22.4s, v10.4s\n" + "ldr q28, [x25], #0x10\n" + "ldr q13, [x23], #0x10\n" + "zip2 v29.4s, v9.4s, v29.4s\n" + "zip2 v9.4s, v22.4s, v10.4s\n" + "ldr q22, [x22], #0x10\n" + "zip1 v10.4s, v28.4s, v22.4s\n" + "zip2 v28.4s, v28.4s, v22.4s\n" + "zip1 v22.4s, v14.4s, v11.4s\n" + ".inst 0x0ea16ad6 // bfcvtn v22.4h, v22.4s\n" + "zip2 v11.4s, v14.4s, v11.4s\n" + "ldr q14, [x20], #0x10\n" + ".inst 0x4ea16976 // bfcvtn2 v22.8h, v11.4s\n" + "str q22, [x21, #0x0]\n" + "zip1 v22.4s, v13.4s, v14.4s\n" + "zip2 v14.4s, v13.4s, v14.4s\n" + "zip1 v13.4s, v12.4s, v6.4s\n" + "zip1 v11.4s, v0.4s, v8.4s\n" + ".inst 0x0ea169ad // bfcvtn v13.4h, v13.4s\n" + "zip2 v12.4s, v12.4s, v6.4s\n" + "zip1 v6.4s, v17.4s, v3.4s\n" + ".inst 0x0ea1696b // bfcvtn v11.4h, v11.4s\n" + "zip2 v0.4s, v0.4s, v8.4s\n" + "zip1 v8.4s, v2.4s, v7.4s\n" + ".inst 0x0ea168c6 // bfcvtn v6.4h, v6.4s\n" + "zip2 v3.4s, v17.4s, v3.4s\n" + "zip1 v17.4s, v31.4s, v19.4s\n" + ".inst 0x0ea16908 // bfcvtn v8.4h, v8.4s\n" + "zip2 v2.4s, v2.4s, v7.4s\n" + "zip1 v7.4s, v21.4s, v5.4s\n" + ".inst 0x0ea16a31 // bfcvtn v17.4h, v17.4s\n" + "zip2 v31.4s, v31.4s, v19.4s\n" + "zip1 v19.4s, v29.4s, v9.4s\n" + ".inst 0x0ea168e7 // bfcvtn v7.4h, v7.4s\n" + "zip2 v21.4s, v21.4s, v5.4s\n" + "zip1 v5.4s, v30.4s, v20.4s\n" + ".inst 0x0ea16a73 // bfcvtn v19.4h, v19.4s\n" + "zip2 v29.4s, v29.4s, v9.4s\n" + "zip1 v9.4s, v1.4s, v24.4s\n" + ".inst 0x0ea168a5 // bfcvtn v5.4h, v5.4s\n" + "zip2 v20.4s, v30.4s, v20.4s\n" + "zip1 v30.4s, v25.4s, v23.4s\n" + ".inst 0x0ea16929 // bfcvtn v9.4h, v9.4s\n" + "zip2 v1.4s, v1.4s, v24.4s\n" + "zip1 v24.4s, v16.4s, v27.4s\n" + ".inst 0x0ea16bde // bfcvtn v30.4h, v30.4s\n" + "zip2 v23.4s, v25.4s, v23.4s\n" + "zip1 v25.4s, v26.4s, v15.4s\n" + ".inst 0x0ea16b18 // bfcvtn v24.4h, v24.4s\n" + "zip2 v27.4s, v16.4s, v27.4s\n" + "zip1 v16.4s, v4.4s, v18.4s\n" + ".inst 0x0ea16b39 // bfcvtn v25.4h, v25.4s\n" + "zip2 v15.4s, v26.4s, v15.4s\n" + "zip1 v26.4s, v10.4s, v22.4s\n" + ".inst 0x0ea16a10 // bfcvtn v16.4h, v16.4s\n" + "zip2 v18.4s, v4.4s, v18.4s\n" + "zip1 v4.4s, v28.4s, v14.4s\n" + ".inst 0x0ea16b5a // bfcvtn v26.4h, v26.4s\n" + "zip2 v10.4s, v10.4s, v22.4s\n" + ".inst 0x0ea16896 // bfcvtn v22.4h, v4.4s\n" + "zip2 v4.4s, v28.4s, v14.4s\n" + ".inst 0x4ea1698d // bfcvtn2 v13.8h, v12.4s\n" + "str q13, [x21, #0x10]\n" + ".inst 0x4ea1680b // bfcvtn2 v11.8h, v0.4s\n" + ".inst 0x4ea16866 // bfcvtn2 v6.8h, v3.4s\n" + "str q11, [x21, #0x20]\n" + ".inst 0x4ea16848 // bfcvtn2 v8.8h, v2.4s\n" + ".inst 0x4ea16bf1 // bfcvtn2 v17.8h, v31.4s\n" + "str q6, [x21, #0x30]\n" + ".inst 0x4ea16aa7 // bfcvtn2 v7.8h, v21.4s\n" + ".inst 0x4ea16bb3 // bfcvtn2 v19.8h, v29.4s\n" + "str q8, [x21, #0x40]\n" + ".inst 0x4ea16a85 // bfcvtn2 v5.8h, v20.4s\n" + ".inst 0x4ea16829 // bfcvtn2 v9.8h, v1.4s\n" + "str q17, [x21, #0x50]\n" + ".inst 0x4ea16afe // bfcvtn2 v30.8h, v23.4s\n" + ".inst 0x4ea16b78 // bfcvtn2 v24.8h, v27.4s\n" + "str q7, [x21, #0x60]\n" + ".inst 0x4ea169f9 // bfcvtn2 v25.8h, v15.4s\n" + ".inst 0x4ea16a50 // bfcvtn2 v16.8h, v18.4s\n" + "str q19, [x21, #0x70]\n" + ".inst 0x4ea1695a // bfcvtn2 v26.8h, v10.4s\n" + ".inst 0x4ea16896 // bfcvtn2 v22.8h, v4.4s\n" + "str q5, [x21, #0xc0]\n" + "str q9, [x21, #0xd0]\n" + "str q30, [x21, #0xe0]\n" + "str q24, [x21, #0xf0]\n" + "str q25, [x21, #0x100]\n" + "str q16, [x21, #0x110]\n" + "str q26, [x21, #0x120]\n" + "str q22, [x21, #0x130]\n" + "add x21, x21, #0x80\n" + "bge 4b\n" + "5:" // Main row loop: width 16 loop: skip + "cmp x24, #0x4\n" + "blt 7f\n" + "6:" // Main row loop: width 4 loop: loop + "ldr q23, [x9], #0x10\n" + "ldr q20, [x28], #0x10\n" + "sub x24, x24, #0x4\n" + "cmp x24, #0x4\n" + "ldr q17, [x27], #0x10\n" + "ldr q16, [x26], #0x10\n" + "zip1 v22.4s, v23.4s, v17.4s\n" + "zip1 v21.4s, v20.4s, v16.4s\n" + "ldr q19, [x25], #0x10\n" + "ldr q18, [x23], #0x10\n" + "zip2 v28.4s, v23.4s, v17.4s\n" + "zip2 v20.4s, v20.4s, v16.4s\n" + "ldr q17, [x22], #0x10\n" + "ldr q16, [x20], #0x10\n" + "zip1 v27.4s, v19.4s, v17.4s\n" + "zip1 v26.4s, v18.4s, v16.4s\n" + "zip2 v25.4s, v19.4s, v17.4s\n" + "zip2 v24.4s, v18.4s, v16.4s\n" + "zip1 v19.4s, v22.4s, v21.4s\n" + "zip1 v18.4s, v28.4s, v20.4s\n" + "zip1 v17.4s, v27.4s, v26.4s\n" + "zip1 v16.4s, v25.4s, v24.4s\n" + ".inst 0x0ea16a77 // bfcvtn v23.4h, v19.4s\n" + "zip2 v22.4s, v22.4s, v21.4s\n" + ".inst 0x0ea16a55 // bfcvtn v21.4h, v18.4s\n" + "zip2 v20.4s, v28.4s, v20.4s\n" + ".inst 0x0ea16a33 // bfcvtn v19.4h, v17.4s\n" + "zip2 v18.4s, v27.4s, v26.4s\n" + ".inst 0x0ea16a11 // bfcvtn v17.4h, v16.4s\n" + "zip2 v16.4s, v25.4s, v24.4s\n" + ".inst 0x4ea16ad7 // bfcvtn2 v23.8h, v22.4s\n" + ".inst 0x4ea16a95 // bfcvtn2 v21.8h, v20.4s\n" + "str q23, [x21, #0x0]\n" + ".inst 0x4ea16a53 // bfcvtn2 v19.8h, v18.4s\n" + ".inst 0x4ea16a11 // bfcvtn2 v17.8h, v16.4s\n" + "str q21, [x21, #0x10]\n" + "str q19, [x21, #0xc0]\n" + "str q17, [x21, #0xd0]\n" + "add x21, x21, #0x20\n" + "bge 6b\n" + "7:" // Main row loop: width 4 loop: skip + "cmp x24, #0x1\n" + "blt 9f\n" + "8:" // Main row loop: width 1 loop: loop + "ldr s19, [x9], #0x4\n" + "ldr s18, [x28], #0x4\n" + "sub x24, x24, #0x1\n" + "cmp x24, #0x1\n" + "ldr s17, [x27], #0x4\n" + "ldr s16, [x26], #0x4\n" + "zip1 v17.4s, v19.4s, v17.4s\n" + "zip1 v16.4s, v18.4s, v16.4s\n" + "ldr s20, [x25], #0x4\n" + "ldr s19, [x23], #0x4\n" + "zip1 v16.4s, v17.4s, v16.4s\n" + ".inst 0x0ea16a12 // bfcvtn v18.4h, v16.4s\n" + "ldr s17, [x22], #0x4\n" + "ldr s16, [x20], #0x4\n" + "zip1 v17.4s, v20.4s, v17.4s\n" + "zip1 v16.4s, v19.4s, v16.4s\n" + "zip1 v16.4s, v17.4s, v16.4s\n" + ".inst 0x0ea16a10 // bfcvtn v16.4h, v16.4s\n" + "str d18, [x21, #0x0]\n" + "str d16, [x21, #0xc0]\n" + "add x21, x21, #0x8\n" + "bge 8b\n" + "9:" // Main row loop: width 1 loop: skip + "cmp %x[height], #0x8\n" + "add %x[out], %x[out], #0x180\n" + "bge 1b\n" + "cbz %x[height], 20f\n" + "10:" // Main loop skip + "11:" // Tail row loop: Head + "mov x9, %x[in]\n" + "add x28, x9, %x[in_stride]\n" + "add x27, x28, %x[in_stride]\n" + "mov x20, %x[width]\n" + "add x26, x27, %x[in_stride]\n" + "cmp %x[height], #0x3\n" + "add %x[in], x26, %x[in_stride]\n" + "csel x26, x26, %x[pad_row], GT\n" + "csel x27, x27, %x[pad_row], GE\n" + "cmp %x[height], #0x1\n" + "csel x28, x28, %x[pad_row], GT\n" + "cmp x20, #0x18\n" + "mov x21, %x[out]\n" + "sub %x[height], %x[height], #0x4\n" + "blt 13f\n" + "12:" // Tail row loop: Column loop + "ldr q20, [x9], #0x10\n" + "ldr q18, [x28], #0x10\n" + "sub x20, x20, #0x18\n" + "cmp x20, #0x18\n" + "ldr q17, [x27], #0x10\n" + "ldr q16, [x26], #0x10\n" + "zip1 v9.4s, v20.4s, v17.4s\n" + "zip1 v30.4s, v18.4s, v16.4s\n" + "ldr q21, [x9], #0x10\n" + "ldr q19, [x28], #0x10\n" + "zip2 v17.4s, v20.4s, v17.4s\n" + "zip2 v5.4s, v18.4s, v16.4s\n" + "ldr q18, [x27], #0x10\n" + "ldr q16, [x26], #0x10\n" + "zip1 v0.4s, v21.4s, v18.4s\n" + "zip1 v3.4s, v19.4s, v16.4s\n" + "ldr q23, [x9], #0x10\n" + "ldr q20, [x28], #0x10\n" + "zip2 v1.4s, v21.4s, v18.4s\n" + "zip2 v16.4s, v19.4s, v16.4s\n" + "ldr q19, [x27], #0x10\n" + "ldr q18, [x26], #0x10\n" + "zip1 v4.4s, v23.4s, v19.4s\n" + "zip1 v2.4s, v20.4s, v18.4s\n" + "ldr q22, [x9], #0x10\n" + "ldr q21, [x28], #0x10\n" + "zip2 v27.4s, v23.4s, v19.4s\n" + "zip2 v28.4s, v20.4s, v18.4s\n" + "ldr q20, [x27], #0x10\n" + "ldr q24, [x26], #0x10\n" + "zip1 v25.4s, v22.4s, v20.4s\n" + "zip1 v26.4s, v21.4s, v24.4s\n" + "ldr q19, [x9], #0x10\n" + "ldr q18, [x28], #0x10\n" + "zip2 v14.4s, v22.4s, v20.4s\n" + "zip2 v12.4s, v21.4s, v24.4s\n" + "ldr q31, [x27], #0x10\n" + "ldr q24, [x26], #0x10\n" + "zip1 v15.4s, v19.4s, v31.4s\n" + "zip1 v13.4s, v18.4s, v24.4s\n" + "ldr q21, [x9], #0x10\n" + "ldr q11, [x28], #0x10\n" + "zip2 v20.4s, v19.4s, v31.4s\n" + "zip2 v10.4s, v18.4s, v24.4s\n" + "ldr q22, [x27], #0x10\n" + "ldr q23, [x26], #0x10\n" + "zip1 v19.4s, v21.4s, v22.4s\n" + "zip1 v18.4s, v11.4s, v23.4s\n" + "zip2 v6.4s, v21.4s, v22.4s\n" + "zip2 v11.4s, v11.4s, v23.4s\n" + "zip1 v8.4s, v9.4s, v30.4s\n" + "zip1 v21.4s, v17.4s, v5.4s\n" + "zip1 v7.4s, v0.4s, v3.4s\n" + "zip1 v31.4s, v1.4s, v16.4s\n" + "zip1 v29.4s, v4.4s, v2.4s\n" + "zip1 v22.4s, v27.4s, v28.4s\n" + "zip1 v24.4s, v25.4s, v26.4s\n" + "zip1 v23.4s, v14.4s, v12.4s\n" + ".inst 0x0ea16908 // bfcvtn v8.4h, v8.4s\n" + "zip2 v9.4s, v9.4s, v30.4s\n" + "zip1 v30.4s, v15.4s, v13.4s\n" + ".inst 0x0ea16ab5 // bfcvtn v21.4h, v21.4s\n" + "zip2 v5.4s, v17.4s, v5.4s\n" + "zip1 v17.4s, v20.4s, v10.4s\n" + ".inst 0x0ea168e7 // bfcvtn v7.4h, v7.4s\n" + "zip2 v0.4s, v0.4s, v3.4s\n" + "zip1 v3.4s, v19.4s, v18.4s\n" + ".inst 0x0ea16bff // bfcvtn v31.4h, v31.4s\n" + "zip2 v16.4s, v1.4s, v16.4s\n" + "zip1 v1.4s, v6.4s, v11.4s\n" + ".inst 0x0ea16bbd // bfcvtn v29.4h, v29.4s\n" + "zip2 v2.4s, v4.4s, v2.4s\n" + ".inst 0x0ea16ac4 // bfcvtn v4.4h, v22.4s\n" + "zip2 v27.4s, v27.4s, v28.4s\n" + ".inst 0x0ea16b1c // bfcvtn v28.4h, v24.4s\n" + "zip2 v25.4s, v25.4s, v26.4s\n" + ".inst 0x0ea16afa // bfcvtn v26.4h, v23.4s\n" + "zip2 v14.4s, v14.4s, v12.4s\n" + ".inst 0x0ea16bd8 // bfcvtn v24.4h, v30.4s\n" + "zip2 v13.4s, v15.4s, v13.4s\n" + ".inst 0x0ea16a2f // bfcvtn v15.4h, v17.4s\n" + "zip2 v12.4s, v20.4s, v10.4s\n" + ".inst 0x0ea16874 // bfcvtn v20.4h, v3.4s\n" + "zip2 v10.4s, v19.4s, v18.4s\n" + ".inst 0x0ea16831 // bfcvtn v17.4h, v1.4s\n" + "zip2 v18.4s, v6.4s, v11.4s\n" + ".inst 0x4ea16928 // bfcvtn2 v8.8h, v9.4s\n" + ".inst 0x4ea168b5 // bfcvtn2 v21.8h, v5.4s\n" + "str q8, [x21, #0x0]\n" + ".inst 0x4ea16807 // bfcvtn2 v7.8h, v0.4s\n" + ".inst 0x4ea16a1f // bfcvtn2 v31.8h, v16.4s\n" + "str q21, [x21, #0x10]\n" + ".inst 0x4ea1685d // bfcvtn2 v29.8h, v2.4s\n" + ".inst 0x4ea16b64 // bfcvtn2 v4.8h, v27.4s\n" + "str q7, [x21, #0x20]\n" + ".inst 0x4ea16b3c // bfcvtn2 v28.8h, v25.4s\n" + ".inst 0x4ea169da // bfcvtn2 v26.8h, v14.4s\n" + "str q31, [x21, #0x30]\n" + ".inst 0x4ea169b8 // bfcvtn2 v24.8h, v13.4s\n" + ".inst 0x4ea1698f // bfcvtn2 v15.8h, v12.4s\n" + "str q29, [x21, #0x40]\n" + ".inst 0x4ea16954 // bfcvtn2 v20.8h, v10.4s\n" + ".inst 0x4ea16a51 // bfcvtn2 v17.8h, v18.4s\n" + "str q4, [x21, #0x50]\n" + "str q28, [x21, #0x60]\n" + "str q26, [x21, #0x70]\n" + "str q24, [x21, #0x80]\n" + "str q15, [x21, #0x90]\n" + "str q20, [x21, #0xa0]\n" + "str q17, [x21, #0xb0]\n" + "add x21, x21, %x[out_stride]\n" + "bge 12b\n" + "13:" // Tail row loop: Column loop skip + "cmp x20, #0x10\n" + "blt 15f\n" + "14:" // Tail row loop: width 16 loop: loop + "ldr q21, [x9], #0x10\n" + "ldr q20, [x28], #0x10\n" + "sub x20, x20, #0x10\n" + "cmp x20, #0x10\n" + "ldr q17, [x27], #0x10\n" + "ldr q16, [x26], #0x10\n" + "zip1 v30.4s, v21.4s, v17.4s\n" + "zip1 v29.4s, v20.4s, v16.4s\n" + "ldr q19, [x9], #0x10\n" + "ldr q18, [x28], #0x10\n" + "zip2 v28.4s, v21.4s, v17.4s\n" + "zip2 v27.4s, v20.4s, v16.4s\n" + "ldr q17, [x27], #0x10\n" + "ldr q16, [x26], #0x10\n" + "zip1 v26.4s, v19.4s, v17.4s\n" + "zip1 v25.4s, v18.4s, v16.4s\n" + "ldr q21, [x9], #0x10\n" + "ldr q20, [x28], #0x10\n" + "zip2 v8.4s, v19.4s, v17.4s\n" + "zip2 v24.4s, v18.4s, v16.4s\n" + "ldr q17, [x27], #0x10\n" + "ldr q16, [x26], #0x10\n" + "zip1 v7.4s, v21.4s, v17.4s\n" + "zip1 v6.4s, v20.4s, v16.4s\n" + "ldr q19, [x9], #0x10\n" + "ldr q18, [x28], #0x10\n" + "zip2 v5.4s, v21.4s, v17.4s\n" + "zip2 v4.4s, v20.4s, v16.4s\n" + "ldr q17, [x27], #0x10\n" + "ldr q16, [x26], #0x10\n" + "zip1 v3.4s, v19.4s, v17.4s\n" + "zip1 v2.4s, v18.4s, v16.4s\n" + "zip2 v1.4s, v19.4s, v17.4s\n" + "zip2 v0.4s, v18.4s, v16.4s\n" + "zip1 v23.4s, v30.4s, v29.4s\n" + "zip1 v22.4s, v28.4s, v27.4s\n" + "zip1 v21.4s, v26.4s, v25.4s\n" + "zip1 v20.4s, v8.4s, v24.4s\n" + "zip1 v19.4s, v7.4s, v6.4s\n" + "zip1 v18.4s, v5.4s, v4.4s\n" + "zip1 v17.4s, v3.4s, v2.4s\n" + "zip1 v16.4s, v1.4s, v0.4s\n" + ".inst 0x0ea16aff // bfcvtn v31.4h, v23.4s\n" + "zip2 v30.4s, v30.4s, v29.4s\n" + ".inst 0x0ea16add // bfcvtn v29.4h, v22.4s\n" + "zip2 v28.4s, v28.4s, v27.4s\n" + ".inst 0x0ea16abb // bfcvtn v27.4h, v21.4s\n" + "zip2 v26.4s, v26.4s, v25.4s\n" + ".inst 0x0ea16a99 // bfcvtn v25.4h, v20.4s\n" + "zip2 v24.4s, v8.4s, v24.4s\n" + ".inst 0x0ea16a77 // bfcvtn v23.4h, v19.4s\n" + "zip2 v22.4s, v7.4s, v6.4s\n" + ".inst 0x0ea16a55 // bfcvtn v21.4h, v18.4s\n" + "zip2 v20.4s, v5.4s, v4.4s\n" + ".inst 0x0ea16a33 // bfcvtn v19.4h, v17.4s\n" + "zip2 v18.4s, v3.4s, v2.4s\n" + ".inst 0x0ea16a11 // bfcvtn v17.4h, v16.4s\n" + "zip2 v16.4s, v1.4s, v0.4s\n" + ".inst 0x4ea16bdf // bfcvtn2 v31.8h, v30.4s\n" + ".inst 0x4ea16b9d // bfcvtn2 v29.8h, v28.4s\n" + "str q31, [x21, #0x0]\n" + ".inst 0x4ea16b5b // bfcvtn2 v27.8h, v26.4s\n" + ".inst 0x4ea16b19 // bfcvtn2 v25.8h, v24.4s\n" + "str q29, [x21, #0x10]\n" + ".inst 0x4ea16ad7 // bfcvtn2 v23.8h, v22.4s\n" + ".inst 0x4ea16a95 // bfcvtn2 v21.8h, v20.4s\n" + "str q27, [x21, #0x20]\n" + ".inst 0x4ea16a53 // bfcvtn2 v19.8h, v18.4s\n" + ".inst 0x4ea16a11 // bfcvtn2 v17.8h, v16.4s\n" + "str q25, [x21, #0x30]\n" + "str q23, [x21, #0x40]\n" + "str q21, [x21, #0x50]\n" + "str q19, [x21, #0x60]\n" + "str q17, [x21, #0x70]\n" + "add x21, x21, #0x80\n" + "bge 14b\n" + "15:" // Tail row loop: width 16 loop: skip + "cmp x20, #0x4\n" + "blt 17f\n" + "16:" // Tail row loop: width 4 loop: loop + "ldr q20, [x9], #0x10\n" + "ldr q19, [x28], #0x10\n" + "sub x20, x20, #0x4\n" + "cmp x20, #0x4\n" + "ldr q17, [x27], #0x10\n" + "ldr q16, [x26], #0x10\n" + "zip1 v22.4s, v20.4s, v17.4s\n" + "zip1 v18.4s, v19.4s, v16.4s\n" + "zip2 v21.4s, v20.4s, v17.4s\n" + "zip2 v20.4s, v19.4s, v16.4s\n" + "zip1 v17.4s, v22.4s, v18.4s\n" + "zip1 v16.4s, v21.4s, v20.4s\n" + ".inst 0x0ea16a33 // bfcvtn v19.4h, v17.4s\n" + "zip2 v18.4s, v22.4s, v18.4s\n" + ".inst 0x0ea16a11 // bfcvtn v17.4h, v16.4s\n" + "zip2 v16.4s, v21.4s, v20.4s\n" + ".inst 0x4ea16a53 // bfcvtn2 v19.8h, v18.4s\n" + ".inst 0x4ea16a11 // bfcvtn2 v17.8h, v16.4s\n" + "str q19, [x21, #0x0]\n" + "str q17, [x21, #0x10]\n" + "add x21, x21, #0x20\n" + "bge 16b\n" + "17:" // Tail row loop: width 4 loop: skip + "cmp x20, #0x1\n" + "blt 19f\n" + "18:" // Tail row loop: width 1 loop: loop + "ldr s19, [x9], #0x4\n" + "ldr s18, [x28], #0x4\n" + "sub x20, x20, #0x1\n" + "cmp x20, #0x1\n" + "ldr s17, [x27], #0x4\n" + "ldr s16, [x26], #0x4\n" + "zip1 v17.4s, v19.4s, v17.4s\n" + "zip1 v16.4s, v18.4s, v16.4s\n" + "zip1 v16.4s, v17.4s, v16.4s\n" + ".inst 0x0ea16a10 // bfcvtn v16.4h, v16.4s\n" + "str d16, [x21, #0x0]\n" + "add x21, x21, #0x8\n" + "bge 18b\n" + "19:" // Tail row loop: width 1 loop: skip + "cmp %x[height], #0x1\n" + "add %x[out], %x[out], #0xc0\n" + "bge 11b\n" + "20:" // Done + : [height] "+&r" (height), [in] "+&r" (in), [out] "+&r" (out) + : [in_stride] "r" (in_stride), [out_stride] "r" (out_stride), [pad_row] "r" (pad_row), [width] "r" (width) + : "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "x9", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27", "x28" + ); +} + +} // anonymous namespace +template<> +void Transform<24, 4, true, VLType::None>( + bfloat16 *out, const float *in, int stride, int x0, int xmax, int k0, int kmax) +{ + a64_transpose_interleave_24_2x4_fp32bf16( + out, + in + k0 * stride + x0, + (xmax-x0), + stride * sizeof(float), + (kmax-k0) + ); +} + + +#endif // defined(__aarch64__) diff --git a/src/core/NEON/kernels/arm_gemm/transforms/a64_transpose_interleave_24_bf16fp32.hpp b/src/core/NEON/kernels/arm_gemm/transforms/a64_transpose_interleave_24_bf16fp32.hpp new file mode 100644 index 0000000000..dcaf69d2a8 --- /dev/null +++ b/src/core/NEON/kernels/arm_gemm/transforms/a64_transpose_interleave_24_bf16fp32.hpp @@ -0,0 +1,294 @@ +/* + * Copyright (c) 2021, 2023 Arm Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ + +#pragma once + +#if defined(__aarch64__) + +namespace { + +void a64_transpose_interleave_24_bf16fp32(float *out, const bfloat16 *in, size_t width, size_t in_stride, size_t height) +{ + size_t out_stride = 12 * height * sizeof(float); + + __asm__ __volatile__( + "cmp %x[height], #0x4\n" + "blt 10f\n" + "1:" // Main row loop: Head + "mov x25, %x[in]\n" + "mov x24, %x[width]\n" + "add x23, x25, %x[in_stride]\n" + "add x22, x23, %x[in_stride]\n" + "add x20, x22, %x[in_stride]\n" + "cmp x24, #0x18\n" + "add %x[in], x20, %x[in_stride]\n" + "mov x21, %x[out]\n" + "sub %x[height], %x[height], #0x4\n" + "blt 3f\n" + "2:" // Main row loop: Unroll column loop + "ldr q18, [x25], #0x10\n" + "ldr q17, [x23], #0x10\n" + "sub x24, x24, #0x18\n" + "shll v26.4s, v18.4h, #0x10\n" + "ldr q16, [x22], #0x10\n" + "ldr q25, [x20], #0x10\n" + "shll2 v24.4s, v18.8h, #0x10\n" + "shll v5.4s, v17.4h, #0x10\n" + "ldr q23, [x25], #0x10\n" + "ldr q22, [x23], #0x10\n" + "shll v21.4s, v23.4h, #0x10\n" + "shll2 v4.4s, v17.8h, #0x10\n" + "ldr q20, [x22], #0x10\n" + "ldr q3, [x20], #0x10\n" + "shll v2.4s, v22.4h, #0x10\n" + "shll v1.4s, v16.4h, #0x10\n" + "ldr q19, [x25], #0x10\n" + "ldr q18, [x23], #0x10\n" + "shll2 v0.4s, v16.8h, #0x10\n" + "shll v31.4s, v20.4h, #0x10\n" + "ldr q17, [x22], #0x10\n" + "ldr q16, [x20], #0x10\n" + "shll v30.4s, v25.4h, #0x10\n" + "shll2 v29.4s, v25.8h, #0x10\n" + "shll v28.4s, v3.4h, #0x10\n" + "str q26, [x21, #0x0]\n" + "cmp x24, #0x18\n" + "shll2 v27.4s, v23.8h, #0x10\n" + "str q24, [x21, #0x10]\n" + "shll v26.4s, v19.4h, #0x10\n" + "shll2 v25.4s, v19.8h, #0x10\n" + "str q21, [x21, #0x20]\n" + "shll2 v24.4s, v22.8h, #0x10\n" + "shll v23.4s, v18.4h, #0x10\n" + "str q5, [x21, #0x30]\n" + "shll2 v22.4s, v18.8h, #0x10\n" + "shll2 v21.4s, v20.8h, #0x10\n" + "str q4, [x21, #0x40]\n" + "shll v20.4s, v17.4h, #0x10\n" + "shll2 v19.4s, v17.8h, #0x10\n" + "str q2, [x21, #0x50]\n" + "shll2 v18.4s, v3.8h, #0x10\n" + "shll v17.4s, v16.4h, #0x10\n" + "str q1, [x21, #0x60]\n" + "shll2 v16.4s, v16.8h, #0x10\n" + "str q0, [x21, #0x70]\n" + "str q31, [x21, #0x80]\n" + "str q30, [x21, #0x90]\n" + "str q29, [x21, #0xa0]\n" + "str q28, [x21, #0xb0]\n" + "add x21, x21, %x[out_stride]\n" + "str q27, [x21, #0x0]\n" + "str q26, [x21, #0x10]\n" + "str q25, [x21, #0x20]\n" + "str q24, [x21, #0x30]\n" + "str q23, [x21, #0x40]\n" + "str q22, [x21, #0x50]\n" + "str q21, [x21, #0x60]\n" + "str q20, [x21, #0x70]\n" + "str q19, [x21, #0x80]\n" + "str q18, [x21, #0x90]\n" + "str q17, [x21, #0xa0]\n" + "str q16, [x21, #0xb0]\n" + "add x21, x21, %x[out_stride]\n" + "bge 2b\n" + "3:" // Main row loop: Unroll column loop skip + "cmp x24, #0xc\n" + "blt 5f\n" + "4:" // Main row loop: Column loop + "ldr q16, [x25], #0x10\n" + "ldr q21, [x23], #0x10\n" + "sub x24, x24, #0xc\n" + "cmp x24, #0xc\n" + "ldr q20, [x22], #0x10\n" + "ldr q27, [x20], #0x10\n" + "shll v19.4s, v16.4h, #0x10\n" + "shll2 v26.4s, v16.8h, #0x10\n" + "ldr d16, [x25], #0x8\n" + "ldr d18, [x23], #0x8\n" + "shll v25.4s, v16.4h, #0x10\n" + "shll v24.4s, v21.4h, #0x10\n" + "ldr d17, [x22], #0x8\n" + "ldr d16, [x20], #0x8\n" + "shll2 v23.4s, v21.8h, #0x10\n" + "shll v22.4s, v18.4h, #0x10\n" + "shll v21.4s, v20.4h, #0x10\n" + "shll2 v20.4s, v20.8h, #0x10\n" + "str q19, [x21, #0x0]\n" + "shll v19.4s, v17.4h, #0x10\n" + "shll v18.4s, v27.4h, #0x10\n" + "str q26, [x21, #0x10]\n" + "shll2 v17.4s, v27.8h, #0x10\n" + "shll v16.4s, v16.4h, #0x10\n" + "str q25, [x21, #0x20]\n" + "str q24, [x21, #0x30]\n" + "str q23, [x21, #0x40]\n" + "str q22, [x21, #0x50]\n" + "str q21, [x21, #0x60]\n" + "str q20, [x21, #0x70]\n" + "str q19, [x21, #0x80]\n" + "str q18, [x21, #0x90]\n" + "str q17, [x21, #0xa0]\n" + "str q16, [x21, #0xb0]\n" + "add x21, x21, %x[out_stride]\n" + "bge 4b\n" + "5:" // Main row loop: Column loop skip + "cmp x24, #0x4\n" + "blt 7f\n" + "6:" // Main row loop: width 4 loop: loop + "ldr d19, [x25], #0x8\n" + "ldr d18, [x23], #0x8\n" + "sub x24, x24, #0x4\n" + "cmp x24, #0x4\n" + "ldr d17, [x22], #0x8\n" + "ldr d16, [x20], #0x8\n" + "shll v19.4s, v19.4h, #0x10\n" + "shll v18.4s, v18.4h, #0x10\n" + "shll v17.4s, v17.4h, #0x10\n" + "shll v16.4s, v16.4h, #0x10\n" + "str q19, [x21, #0x0]\n" + "str q18, [x21, #0x30]\n" + "str q17, [x21, #0x60]\n" + "str q16, [x21, #0x90]\n" + "add x21, x21, #0x10\n" + "bge 6b\n" + "7:" // Main row loop: width 4 loop: skip + "cmp x24, #0x1\n" + "blt 9f\n" + "8:" // Main row loop: width 1 loop: loop + "ldr h19, [x25], #0x2\n" + "ldr h18, [x23], #0x2\n" + "sub x24, x24, #0x1\n" + "cmp x24, #0x1\n" + "ldr h17, [x22], #0x2\n" + "ldr h16, [x20], #0x2\n" + "shll v19.4s, v19.4h, #0x10\n" + "shll v18.4s, v18.4h, #0x10\n" + "shll v17.4s, v17.4h, #0x10\n" + "shll v16.4s, v16.4h, #0x10\n" + "str s19, [x21, #0x0]\n" + "str s18, [x21, #0x30]\n" + "str s17, [x21, #0x60]\n" + "str s16, [x21, #0x90]\n" + "add x21, x21, #0x4\n" + "bge 8b\n" + "9:" // Main row loop: width 1 loop: skip + "cmp %x[height], #0x4\n" + "add %x[out], %x[out], #0xc0\n" + "bge 1b\n" + "cbz %x[height], 20f\n" + "10:" // Main loop skip + "11:" // Tail row loop: Head + "mov x20, %x[width]\n" + "mov x25, %x[in]\n" + "cmp x20, #0x18\n" + "add %x[in], x25, %x[in_stride]\n" + "mov x21, %x[out]\n" + "sub %x[height], %x[height], #0x1\n" + "blt 13f\n" + "12:" // Tail row loop: Unroll column loop + "ldr q16, [x25], #0x10\n" + "ldr q20, [x25], #0x10\n" + "sub x20, x20, #0x18\n" + "shll v18.4s, v16.4h, #0x10\n" + "ldr q19, [x25], #0x10\n" + "shll2 v17.4s, v16.8h, #0x10\n" + "shll v16.4s, v20.4h, #0x10\n" + "str q18, [x21, #0x0]\n" + "str q17, [x21, #0x10]\n" + "cmp x20, #0x18\n" + "shll2 v18.4s, v20.8h, #0x10\n" + "shll v17.4s, v19.4h, #0x10\n" + "str q16, [x21, #0x20]\n" + "add x21, x21, %x[out_stride]\n" + "shll2 v16.4s, v19.8h, #0x10\n" + "str q18, [x21, #0x0]\n" + "str q17, [x21, #0x10]\n" + "str q16, [x21, #0x20]\n" + "add x21, x21, %x[out_stride]\n" + "bge 12b\n" + "13:" // Tail row loop: Unroll column loop skip + "cmp x20, #0xc\n" + "blt 15f\n" + "14:" // Tail row loop: Column loop + "ldr q17, [x25], #0x10\n" + "ldr d18, [x25], #0x8\n" + "sub x20, x20, #0xc\n" + "cmp x20, #0xc\n" + "shll v16.4s, v17.4h, #0x10\n" + "shll2 v17.4s, v17.8h, #0x10\n" + "str q16, [x21, #0x0]\n" + "shll v16.4s, v18.4h, #0x10\n" + "str q17, [x21, #0x10]\n" + "str q16, [x21, #0x20]\n" + "add x21, x21, %x[out_stride]\n" + "bge 14b\n" + "15:" // Tail row loop: Column loop skip + "cmp x20, #0x4\n" + "blt 17f\n" + "16:" // Tail row loop: width 4 loop: loop + "ldr d16, [x25], #0x8\n" + "sub x20, x20, #0x4\n" + "cmp x20, #0x4\n" + "shll v16.4s, v16.4h, #0x10\n" + "str q16, [x21, #0x0]\n" + "add x21, x21, #0x10\n" + "bge 16b\n" + "17:" // Tail row loop: width 4 loop: skip + "cmp x20, #0x1\n" + "blt 19f\n" + "18:" // Tail row loop: width 1 loop: loop + "ldr h16, [x25], #0x2\n" + "sub x20, x20, #0x1\n" + "cmp x20, #0x1\n" + "shll v16.4s, v16.4h, #0x10\n" + "str s16, [x21, #0x0]\n" + "add x21, x21, #0x4\n" + "bge 18b\n" + "19:" // Tail row loop: width 1 loop: skip + "cmp %x[height], #0x1\n" + "add %x[out], %x[out], #0x30\n" + "bge 11b\n" + "20:" // Done + : [height] "+&r" (height), [in] "+&r" (in), [out] "+&r" (out) + : [in_stride] "r" (in_stride), [out_stride] "r" (out_stride), [width] "r" (width) + : "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "x20", "x21", "x22", "x23", "x24", "x25" + ); +} + +} // anonymous namespace +template<> +void Transform<12, 1, true, VLType::None>( + float *out, const bfloat16 *in, int stride, int x0, int xmax, int k0, int kmax) +{ + a64_transpose_interleave_24_bf16fp32( + out, + in + k0 * stride + x0, + (xmax-x0), + stride * sizeof(bfloat16), + (kmax-k0) + ); +} + + +#endif // defined(__aarch64__) diff --git a/src/core/NEON/kernels/arm_gemm/transforms/a64_transpose_interleave_24_fp16fp32.hpp b/src/core/NEON/kernels/arm_gemm/transforms/a64_transpose_interleave_24_fp16fp32.hpp new file mode 100644 index 0000000000..966b75664e --- /dev/null +++ b/src/core/NEON/kernels/arm_gemm/transforms/a64_transpose_interleave_24_fp16fp32.hpp @@ -0,0 +1,294 @@ +/* + * Copyright (c) 2021, 2023 Arm Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ + +#pragma once + +#if defined(__aarch64__) + +namespace { + +void a64_transpose_interleave_24_fp16fp32(float *out, const __fp16 *in, size_t width, size_t in_stride, size_t height) +{ + size_t out_stride = 12 * height * sizeof(float); + + __asm__ __volatile__( + "cmp %x[height], #0x4\n" + "blt 10f\n" + "1:" // Main row loop: Head + "mov x25, %x[in]\n" + "mov x24, %x[width]\n" + "add x23, x25, %x[in_stride]\n" + "add x22, x23, %x[in_stride]\n" + "add x20, x22, %x[in_stride]\n" + "cmp x24, #0x18\n" + "add %x[in], x20, %x[in_stride]\n" + "mov x21, %x[out]\n" + "sub %x[height], %x[height], #0x4\n" + "blt 3f\n" + "2:" // Main row loop: Unroll column loop + "ldr q18, [x25], #0x10\n" + "ldr q17, [x23], #0x10\n" + "sub x24, x24, #0x18\n" + "fcvtl v26.4s, v18.4h\n" + "ldr q16, [x22], #0x10\n" + "ldr q25, [x20], #0x10\n" + "fcvtl2 v24.4s, v18.8h\n" + "fcvtl v5.4s, v17.4h\n" + "ldr q23, [x25], #0x10\n" + "ldr q22, [x23], #0x10\n" + "fcvtl v21.4s, v23.4h\n" + "fcvtl2 v4.4s, v17.8h\n" + "ldr q20, [x22], #0x10\n" + "ldr q3, [x20], #0x10\n" + "fcvtl v2.4s, v22.4h\n" + "fcvtl v1.4s, v16.4h\n" + "ldr q19, [x25], #0x10\n" + "ldr q18, [x23], #0x10\n" + "fcvtl2 v0.4s, v16.8h\n" + "fcvtl v31.4s, v20.4h\n" + "ldr q17, [x22], #0x10\n" + "ldr q16, [x20], #0x10\n" + "fcvtl v30.4s, v25.4h\n" + "fcvtl2 v29.4s, v25.8h\n" + "fcvtl v28.4s, v3.4h\n" + "str q26, [x21, #0x0]\n" + "cmp x24, #0x18\n" + "fcvtl2 v27.4s, v23.8h\n" + "str q24, [x21, #0x10]\n" + "fcvtl v26.4s, v19.4h\n" + "fcvtl2 v25.4s, v19.8h\n" + "str q21, [x21, #0x20]\n" + "fcvtl2 v24.4s, v22.8h\n" + "fcvtl v23.4s, v18.4h\n" + "str q5, [x21, #0x30]\n" + "fcvtl2 v22.4s, v18.8h\n" + "fcvtl2 v21.4s, v20.8h\n" + "str q4, [x21, #0x40]\n" + "fcvtl v20.4s, v17.4h\n" + "fcvtl2 v19.4s, v17.8h\n" + "str q2, [x21, #0x50]\n" + "fcvtl2 v18.4s, v3.8h\n" + "fcvtl v17.4s, v16.4h\n" + "str q1, [x21, #0x60]\n" + "fcvtl2 v16.4s, v16.8h\n" + "str q0, [x21, #0x70]\n" + "str q31, [x21, #0x80]\n" + "str q30, [x21, #0x90]\n" + "str q29, [x21, #0xa0]\n" + "str q28, [x21, #0xb0]\n" + "add x21, x21, %x[out_stride]\n" + "str q27, [x21, #0x0]\n" + "str q26, [x21, #0x10]\n" + "str q25, [x21, #0x20]\n" + "str q24, [x21, #0x30]\n" + "str q23, [x21, #0x40]\n" + "str q22, [x21, #0x50]\n" + "str q21, [x21, #0x60]\n" + "str q20, [x21, #0x70]\n" + "str q19, [x21, #0x80]\n" + "str q18, [x21, #0x90]\n" + "str q17, [x21, #0xa0]\n" + "str q16, [x21, #0xb0]\n" + "add x21, x21, %x[out_stride]\n" + "bge 2b\n" + "3:" // Main row loop: Unroll column loop skip + "cmp x24, #0xc\n" + "blt 5f\n" + "4:" // Main row loop: Column loop + "ldr q16, [x25], #0x10\n" + "ldr q21, [x23], #0x10\n" + "sub x24, x24, #0xc\n" + "cmp x24, #0xc\n" + "ldr q20, [x22], #0x10\n" + "ldr q27, [x20], #0x10\n" + "fcvtl v19.4s, v16.4h\n" + "fcvtl2 v26.4s, v16.8h\n" + "ldr d16, [x25], #0x8\n" + "ldr d18, [x23], #0x8\n" + "fcvtl v25.4s, v16.4h\n" + "fcvtl v24.4s, v21.4h\n" + "ldr d17, [x22], #0x8\n" + "ldr d16, [x20], #0x8\n" + "fcvtl2 v23.4s, v21.8h\n" + "fcvtl v22.4s, v18.4h\n" + "fcvtl v21.4s, v20.4h\n" + "fcvtl2 v20.4s, v20.8h\n" + "str q19, [x21, #0x0]\n" + "fcvtl v19.4s, v17.4h\n" + "fcvtl v18.4s, v27.4h\n" + "str q26, [x21, #0x10]\n" + "fcvtl2 v17.4s, v27.8h\n" + "fcvtl v16.4s, v16.4h\n" + "str q25, [x21, #0x20]\n" + "str q24, [x21, #0x30]\n" + "str q23, [x21, #0x40]\n" + "str q22, [x21, #0x50]\n" + "str q21, [x21, #0x60]\n" + "str q20, [x21, #0x70]\n" + "str q19, [x21, #0x80]\n" + "str q18, [x21, #0x90]\n" + "str q17, [x21, #0xa0]\n" + "str q16, [x21, #0xb0]\n" + "add x21, x21, %x[out_stride]\n" + "bge 4b\n" + "5:" // Main row loop: Column loop skip + "cmp x24, #0x4\n" + "blt 7f\n" + "6:" // Main row loop: width 4 loop: loop + "ldr d19, [x25], #0x8\n" + "ldr d18, [x23], #0x8\n" + "sub x24, x24, #0x4\n" + "cmp x24, #0x4\n" + "ldr d17, [x22], #0x8\n" + "ldr d16, [x20], #0x8\n" + "fcvtl v19.4s, v19.4h\n" + "fcvtl v18.4s, v18.4h\n" + "fcvtl v17.4s, v17.4h\n" + "fcvtl v16.4s, v16.4h\n" + "str q19, [x21, #0x0]\n" + "str q18, [x21, #0x30]\n" + "str q17, [x21, #0x60]\n" + "str q16, [x21, #0x90]\n" + "add x21, x21, #0x10\n" + "bge 6b\n" + "7:" // Main row loop: width 4 loop: skip + "cmp x24, #0x1\n" + "blt 9f\n" + "8:" // Main row loop: width 1 loop: loop + "ldr h19, [x25], #0x2\n" + "ldr h18, [x23], #0x2\n" + "sub x24, x24, #0x1\n" + "cmp x24, #0x1\n" + "ldr h17, [x22], #0x2\n" + "ldr h16, [x20], #0x2\n" + "fcvtl v19.4s, v19.4h\n" + "fcvtl v18.4s, v18.4h\n" + "fcvtl v17.4s, v17.4h\n" + "fcvtl v16.4s, v16.4h\n" + "str s19, [x21, #0x0]\n" + "str s18, [x21, #0x30]\n" + "str s17, [x21, #0x60]\n" + "str s16, [x21, #0x90]\n" + "add x21, x21, #0x4\n" + "bge 8b\n" + "9:" // Main row loop: width 1 loop: skip + "cmp %x[height], #0x4\n" + "add %x[out], %x[out], #0xc0\n" + "bge 1b\n" + "cbz %x[height], 20f\n" + "10:" // Main loop skip + "11:" // Tail row loop: Head + "mov x20, %x[width]\n" + "mov x25, %x[in]\n" + "cmp x20, #0x18\n" + "add %x[in], x25, %x[in_stride]\n" + "mov x21, %x[out]\n" + "sub %x[height], %x[height], #0x1\n" + "blt 13f\n" + "12:" // Tail row loop: Unroll column loop + "ldr q16, [x25], #0x10\n" + "ldr q20, [x25], #0x10\n" + "sub x20, x20, #0x18\n" + "fcvtl v18.4s, v16.4h\n" + "ldr q19, [x25], #0x10\n" + "fcvtl2 v17.4s, v16.8h\n" + "fcvtl v16.4s, v20.4h\n" + "str q18, [x21, #0x0]\n" + "str q17, [x21, #0x10]\n" + "cmp x20, #0x18\n" + "fcvtl2 v18.4s, v20.8h\n" + "fcvtl v17.4s, v19.4h\n" + "str q16, [x21, #0x20]\n" + "add x21, x21, %x[out_stride]\n" + "fcvtl2 v16.4s, v19.8h\n" + "str q18, [x21, #0x0]\n" + "str q17, [x21, #0x10]\n" + "str q16, [x21, #0x20]\n" + "add x21, x21, %x[out_stride]\n" + "bge 12b\n" + "13:" // Tail row loop: Unroll column loop skip + "cmp x20, #0xc\n" + "blt 15f\n" + "14:" // Tail row loop: Column loop + "ldr q17, [x25], #0x10\n" + "ldr d18, [x25], #0x8\n" + "sub x20, x20, #0xc\n" + "cmp x20, #0xc\n" + "fcvtl v16.4s, v17.4h\n" + "fcvtl2 v17.4s, v17.8h\n" + "str q16, [x21, #0x0]\n" + "fcvtl v16.4s, v18.4h\n" + "str q17, [x21, #0x10]\n" + "str q16, [x21, #0x20]\n" + "add x21, x21, %x[out_stride]\n" + "bge 14b\n" + "15:" // Tail row loop: Column loop skip + "cmp x20, #0x4\n" + "blt 17f\n" + "16:" // Tail row loop: width 4 loop: loop + "ldr d16, [x25], #0x8\n" + "sub x20, x20, #0x4\n" + "cmp x20, #0x4\n" + "fcvtl v16.4s, v16.4h\n" + "str q16, [x21, #0x0]\n" + "add x21, x21, #0x10\n" + "bge 16b\n" + "17:" // Tail row loop: width 4 loop: skip + "cmp x20, #0x1\n" + "blt 19f\n" + "18:" // Tail row loop: width 1 loop: loop + "ldr h16, [x25], #0x2\n" + "sub x20, x20, #0x1\n" + "cmp x20, #0x1\n" + "fcvtl v16.4s, v16.4h\n" + "str s16, [x21, #0x0]\n" + "add x21, x21, #0x4\n" + "bge 18b\n" + "19:" // Tail row loop: width 1 loop: skip + "cmp %x[height], #0x1\n" + "add %x[out], %x[out], #0x30\n" + "bge 11b\n" + "20:" // Done + : [height] "+&r" (height), [in] "+&r" (in), [out] "+&r" (out) + : [in_stride] "r" (in_stride), [out_stride] "r" (out_stride), [width] "r" (width) + : "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "x20", "x21", "x22", "x23", "x24", "x25" + ); +} + +} // anonymous namespace +template<> +void Transform<12, 1, true, VLType::None>( + float *out, const __fp16 *in, int stride, int x0, int xmax, int k0, int kmax) +{ + a64_transpose_interleave_24_fp16fp32( + out, + in + k0 * stride + x0, + (xmax-x0), + stride * sizeof(__fp16), + (kmax-k0) + ); +} + + +#endif // defined(__aarch64__) diff --git a/src/core/NEON/kernels/arm_gemm/transforms/a64_transpose_interleave_24way_16bit.hpp b/src/core/NEON/kernels/arm_gemm/transforms/a64_transpose_interleave_24way_16bit.hpp deleted file mode 100644 index 6d627334cd..0000000000 --- a/src/core/NEON/kernels/arm_gemm/transforms/a64_transpose_interleave_24way_16bit.hpp +++ /dev/null @@ -1,130 +0,0 @@ -/* - * Copyright (c) 2017-2019 ARM Limited. - * - * SPDX-License-Identifier: MIT - * - * Permission is hereby granted, free of charge, to any person obtaining a copy - * of this software and associated documentation files (the "Software"), to - * deal in the Software without restriction, including without limitation the - * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or - * sell copies of the Software, and to permit persons to whom the Software is - * furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in all - * copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE - * SOFTWARE. - */ -#pragma once - -#ifdef __aarch64__ - -#include "transpose_interleave_common.hpp" - -// Generic unblocked transposed 12x32-bit sized specialisation -template <> -template <typename T> -inline void TransformImpl<12, 1, true, 4, 4, false>::Transform( - T* out, const T* const in, const int stride, - const int x0, const int xmax, const int k0, const int kmax -) { - // Redirect to a 24 x uint16_t specialisation - TransformImpl<24, 1, true, 2, 2, false>::Transform( - reinterpret_cast<uint16_t *>(out), - reinterpret_cast<const uint16_t * const>(in), - stride*2, x0*2, xmax*2, k0, kmax - ); -} - -// Generic 24x16-bit sized specialisation -template <> -template <typename T> -inline void TransformImpl<24, 1, true, 2, 2, false>::Transform( - T* out, const T* const in, const int stride, - const int x0, const int xmax, const int k0, const int kmax -) { - // Redirect to a uint16_t specialisation - Transform( - reinterpret_cast<uint16_t *>(out), - reinterpret_cast<const uint16_t * const>(in), - stride, x0, xmax, k0, kmax - ); -} - -// Specialised 24 x uint16_t version -template <> -inline void TransposeInterleaveCommon<24, uint16_t, uint16_t>::moveblock_1x1(const uint16_t *&in0, uint16_t *out) { - __asm __volatile ( - "LDP q0, q1, [%[in0]], #32\n" - "STP q0, q1, [%[out]]\n" - ASM_PREFETCH("[%[in0], #192]") - "LDR q2, [%[in0]], #16\n" - "STR q2, [%[out], #32]\n" - : [in0] "+r" (in0), [out] "+r" (out) - : - : "v0", "v1", "v2", "memory" - ); -} - -template <> -inline void TransposeInterleaveCommon<24, uint16_t, uint16_t>::moveblock_1x2(const uint16_t *&in0, const uint16_t *&in1,uint16_t *out) { - __asm __volatile ( - "LDP q0, q1, [%[in0]], #32\n" - "STP q0, q1, [%[out]]\n" - ASM_PREFETCH("[%[in0], #192]") - "LDR q2, [%[in0]], #16\n" - "LDP q3, q4, [%[in1]], #32\n" - "STP q2, q3, [%[out], #32]\n" - ASM_PREFETCH("[%[in1], #192]") - "LDR q5, [%[in1]], #16\n" - "STP q4, q5, [%[out], #64]\n" - : [in0] "+r" (in0), [in1] "+r" (in1), [out] "+r" (out) - : - : "v0", "v1", "v2", "v3", "v4", "v5", "memory" - ); -} - -template <> -inline void TransposeInterleaveCommon<24, uint16_t, uint16_t>::moveblock_1x4(const uint16_t *&in0, const uint16_t *&in1, const uint16_t *&in2, const uint16_t *&in3, uint16_t *out) { - __asm __volatile ( - "LDP q0, q1, [%[in0]], #32\n" - "STP q0, q1, [%[out]]\n" - "LDR q2, [%[in0]], #16\n" - ASM_PREFETCH("[%[in0], #192]") - "LDP q3, q4, [%[in1]], #32\n" - "STP q2, q3, [%[out], #32]\n" - "LDR q5, [%[in1]], #16\n" - ASM_PREFETCH("[%[in1], #192]") - "STP q4, q5, [%[out], #64]\n" - "LDP q6, q7, [%[in2]], #32\n" - "STP q6, q7, [%[out], #96]\n" - "LDR q8, [%[in2]], #16\n" - ASM_PREFETCH("[%[in2], #192]") - "LDP q9, q10, [%[in3]], #32\n" - "STP q8, q9, [%[out], #128]\n" - "LDR q11, [%[in3]], #16\n" - "STP q10, q11, [%[out], #160]\n" - ASM_PREFETCH("[%[in3], #192]") - - : [in0] "+r" (in0), [in1] "+r" (in1), [in2] "+r" (in2), [in3] "+r" (in3), [out] "+r" (out) - : - : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "memory" - ); -} - -template <> -template <> -inline void TransformImpl<24, 1, true, 2, 2, false>::Transform( - uint16_t* out, const uint16_t* const in, const int stride, - const int x0, const int xmax, const int k0, const int kmax -) { - TransposeInterleaveCommon<24, uint16_t, uint16_t>::Transform(out, in, stride, x0, xmax, k0, kmax); -} - -#endif // __arch64__ diff --git a/src/core/NEON/kernels/arm_gemm/transforms/a64_transpose_interleave_32_1x4.hpp b/src/core/NEON/kernels/arm_gemm/transforms/a64_transpose_interleave_32_1x4.hpp new file mode 100644 index 0000000000..4a22675028 --- /dev/null +++ b/src/core/NEON/kernels/arm_gemm/transforms/a64_transpose_interleave_32_1x4.hpp @@ -0,0 +1,507 @@ +/* + * Copyright (c) 2021, 2023 Arm Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ + +#pragma once + +#if defined(__aarch64__) + +namespace { + +void a64_transpose_interleave_32_1x4(uint8_t *out, const uint8_t *in, size_t width, size_t in_stride, size_t height) +{ + uint8_t *pad_row = reinterpret_cast<uint8_t *>(alloca(width * sizeof(uint8_t))); + + if (height % 4) { + memset(pad_row, 0, width * sizeof(uint8_t)); + } + + size_t out_stride = 32 * roundup<size_t>(height, 4) * sizeof(uint8_t); + + __asm__ __volatile__( + "cmp %x[height], #0x10\n" + "blt 10f\n" + "1:" // Main row loop: Head + "mov x17, %x[in]\n" + "add x16, x17, %x[in_stride]\n" + "add x15, x16, %x[in_stride]\n" + "add x14, x15, %x[in_stride]\n" + "add x13, x14, %x[in_stride]\n" + "add x12, x13, %x[in_stride]\n" + "add x11, x12, %x[in_stride]\n" + "add x10, x11, %x[in_stride]\n" + "add x9, x10, %x[in_stride]\n" + "add x28, x9, %x[in_stride]\n" + "add x27, x28, %x[in_stride]\n" + "add x26, x27, %x[in_stride]\n" + "add x25, x26, %x[in_stride]\n" + "mov x24, %x[width]\n" + "add x23, x25, %x[in_stride]\n" + "add x22, x23, %x[in_stride]\n" + "add x20, x22, %x[in_stride]\n" + "cmp x24, #0x20\n" + "add %x[in], x20, %x[in_stride]\n" + "mov x21, %x[out]\n" + "sub %x[height], %x[height], #0x10\n" + "blt 3f\n" + "2:" // Main row loop: Column loop + "ldr q6, [x17], #0x10\n" + "ldr q31, [x16], #0x10\n" + "sub x24, x24, #0x20\n" + "cmp x24, #0x20\n" + "ldr q7, [x15], #0x10\n" + "ldr q0, [x14], #0x10\n" + "zip1 v9.16b, v6.16b, v7.16b\n" + "zip1 v20.16b, v31.16b, v0.16b\n" + "ldr q24, [x13], #0x10\n" + "ldr q19, [x12], #0x10\n" + "zip2 v30.16b, v6.16b, v7.16b\n" + "zip2 v12.16b, v31.16b, v0.16b\n" + "ldr q23, [x11], #0x10\n" + "ldr q17, [x10], #0x10\n" + "zip1 v13.16b, v24.16b, v23.16b\n" + "zip1 v16.16b, v19.16b, v17.16b\n" + "ldr q0, [x9], #0x10\n" + "ldr q31, [x28], #0x10\n" + "zip2 v15.16b, v24.16b, v23.16b\n" + "zip2 v11.16b, v19.16b, v17.16b\n" + "ldr q17, [x27], #0x10\n" + "ldr q4, [x26], #0x10\n" + "zip1 v1.16b, v0.16b, v17.16b\n" + "zip1 v21.16b, v31.16b, v4.16b\n" + "ldr q28, [x25], #0x10\n" + "ldr q18, [x23], #0x10\n" + "zip2 v0.16b, v0.16b, v17.16b\n" + "zip2 v26.16b, v31.16b, v4.16b\n" + "ldr q17, [x22], #0x10\n" + "ldr q19, [x20], #0x10\n" + "zip1 v23.16b, v28.16b, v17.16b\n" + "zip1 v25.16b, v18.16b, v19.16b\n" + "ldr q2, [x17], #0x10\n" + "ldr q3, [x16], #0x10\n" + "zip2 v7.16b, v28.16b, v17.16b\n" + "zip2 v8.16b, v18.16b, v19.16b\n" + "ldr q22, [x15], #0x10\n" + "ldr q27, [x14], #0x10\n" + "zip1 v19.16b, v2.16b, v22.16b\n" + "zip1 v17.16b, v3.16b, v27.16b\n" + "ldr q6, [x13], #0x10\n" + "ldr q4, [x12], #0x10\n" + "zip2 v24.16b, v2.16b, v22.16b\n" + "zip2 v22.16b, v3.16b, v27.16b\n" + "ldr q14, [x11], #0x10\n" + "ldr q18, [x10], #0x10\n" + "zip1 v29.16b, v6.16b, v14.16b\n" + "zip1 v31.16b, v4.16b, v18.16b\n" + "ldr q2, [x9], #0x10\n" + "ldr q10, [x28], #0x10\n" + "zip2 v28.16b, v6.16b, v14.16b\n" + "zip2 v27.16b, v4.16b, v18.16b\n" + "ldr q6, [x27], #0x10\n" + "ldr q5, [x26], #0x10\n" + "zip1 v14.16b, v2.16b, v6.16b\n" + "zip1 v4.16b, v10.16b, v5.16b\n" + "ldr q3, [x25], #0x10\n" + "ldr q18, [x23], #0x10\n" + "zip2 v6.16b, v2.16b, v6.16b\n" + "zip2 v10.16b, v10.16b, v5.16b\n" + "ldr q5, [x22], #0x10\n" + "zip1 v2.16b, v3.16b, v5.16b\n" + "zip2 v3.16b, v3.16b, v5.16b\n" + "zip1 v5.16b, v9.16b, v20.16b\n" + "zip2 v20.16b, v9.16b, v20.16b\n" + "ldr q9, [x20], #0x10\n" + "str q5, [x21, #0x0]\n" + "zip1 v5.16b, v18.16b, v9.16b\n" + "zip2 v9.16b, v18.16b, v9.16b\n" + "str q20, [x21, #0x10]\n" + "zip1 v18.16b, v30.16b, v12.16b\n" + "zip2 v30.16b, v30.16b, v12.16b\n" + "str q18, [x21, #0x20]\n" + "zip1 v20.16b, v19.16b, v17.16b\n" + "zip2 v12.16b, v19.16b, v17.16b\n" + "str q30, [x21, #0x30]\n" + "zip1 v18.16b, v24.16b, v22.16b\n" + "zip2 v17.16b, v24.16b, v22.16b\n" + "str q20, [x21, #0x40]\n" + "zip1 v30.16b, v13.16b, v16.16b\n" + "zip2 v24.16b, v13.16b, v16.16b\n" + "str q12, [x21, #0x50]\n" + "zip1 v22.16b, v15.16b, v11.16b\n" + "zip2 v20.16b, v15.16b, v11.16b\n" + "str q18, [x21, #0x60]\n" + "zip1 v19.16b, v29.16b, v31.16b\n" + "zip2 v18.16b, v29.16b, v31.16b\n" + "str q17, [x21, #0x70]\n" + "zip1 v17.16b, v28.16b, v27.16b\n" + "zip2 v16.16b, v28.16b, v27.16b\n" + "str q30, [x21, #0x80]\n" + "zip1 v31.16b, v1.16b, v21.16b\n" + "zip2 v1.16b, v1.16b, v21.16b\n" + "str q24, [x21, #0x90]\n" + "zip1 v30.16b, v0.16b, v26.16b\n" + "zip2 v29.16b, v0.16b, v26.16b\n" + "str q22, [x21, #0xa0]\n" + "zip1 v28.16b, v14.16b, v4.16b\n" + "zip2 v27.16b, v14.16b, v4.16b\n" + "str q20, [x21, #0xb0]\n" + "zip1 v26.16b, v6.16b, v10.16b\n" + "zip2 v24.16b, v6.16b, v10.16b\n" + "str q19, [x21, #0xc0]\n" + "zip1 v14.16b, v23.16b, v25.16b\n" + "zip2 v22.16b, v23.16b, v25.16b\n" + "str q18, [x21, #0xd0]\n" + "zip1 v21.16b, v7.16b, v8.16b\n" + "zip2 v20.16b, v7.16b, v8.16b\n" + "str q17, [x21, #0xe0]\n" + "zip1 v19.16b, v2.16b, v5.16b\n" + "zip2 v18.16b, v2.16b, v5.16b\n" + "str q16, [x21, #0xf0]\n" + "zip1 v17.16b, v3.16b, v9.16b\n" + "zip2 v16.16b, v3.16b, v9.16b\n" + "str q31, [x21, #0x100]\n" + "str q1, [x21, #0x110]\n" + "str q30, [x21, #0x120]\n" + "str q29, [x21, #0x130]\n" + "str q28, [x21, #0x140]\n" + "str q27, [x21, #0x150]\n" + "str q26, [x21, #0x160]\n" + "str q24, [x21, #0x170]\n" + "str q14, [x21, #0x180]\n" + "str q22, [x21, #0x190]\n" + "str q21, [x21, #0x1a0]\n" + "str q20, [x21, #0x1b0]\n" + "str q19, [x21, #0x1c0]\n" + "str q18, [x21, #0x1d0]\n" + "str q17, [x21, #0x1e0]\n" + "str q16, [x21, #0x1f0]\n" + "add x21, x21, %x[out_stride]\n" + "bge 2b\n" + "3:" // Main row loop: Column loop skip + "cmp x24, #0x10\n" + "blt 5f\n" + "4:" // Main row loop: width 16 loop: loop + "ldr q21, [x17], #0x10\n" + "ldr q20, [x16], #0x10\n" + "sub x24, x24, #0x10\n" + "cmp x24, #0x10\n" + "ldr q17, [x15], #0x10\n" + "ldr q16, [x14], #0x10\n" + "zip1 v3.16b, v21.16b, v17.16b\n" + "zip1 v2.16b, v20.16b, v16.16b\n" + "ldr q19, [x13], #0x10\n" + "ldr q18, [x12], #0x10\n" + "zip2 v1.16b, v21.16b, v17.16b\n" + "zip2 v0.16b, v20.16b, v16.16b\n" + "ldr q17, [x11], #0x10\n" + "ldr q16, [x10], #0x10\n" + "zip1 v31.16b, v19.16b, v17.16b\n" + "zip1 v30.16b, v18.16b, v16.16b\n" + "ldr q24, [x9], #0x10\n" + "ldr q20, [x28], #0x10\n" + "zip2 v29.16b, v19.16b, v17.16b\n" + "zip2 v23.16b, v18.16b, v16.16b\n" + "ldr q17, [x27], #0x10\n" + "ldr q16, [x26], #0x10\n" + "zip1 v22.16b, v24.16b, v17.16b\n" + "zip1 v21.16b, v20.16b, v16.16b\n" + "ldr q19, [x25], #0x10\n" + "ldr q18, [x23], #0x10\n" + "zip2 v28.16b, v24.16b, v17.16b\n" + "zip2 v20.16b, v20.16b, v16.16b\n" + "ldr q17, [x22], #0x10\n" + "ldr q16, [x20], #0x10\n" + "zip1 v27.16b, v19.16b, v17.16b\n" + "zip1 v26.16b, v18.16b, v16.16b\n" + "zip2 v25.16b, v19.16b, v17.16b\n" + "zip2 v24.16b, v18.16b, v16.16b\n" + "zip1 v16.16b, v3.16b, v2.16b\n" + "zip2 v18.16b, v3.16b, v2.16b\n" + "str q16, [x21, #0x0]\n" + "zip1 v17.16b, v1.16b, v0.16b\n" + "zip2 v16.16b, v1.16b, v0.16b\n" + "str q18, [x21, #0x10]\n" + "str q17, [x21, #0x20]\n" + "zip1 v19.16b, v31.16b, v30.16b\n" + "zip2 v18.16b, v31.16b, v30.16b\n" + "str q16, [x21, #0x30]\n" + "zip1 v17.16b, v29.16b, v23.16b\n" + "zip2 v16.16b, v29.16b, v23.16b\n" + "str q19, [x21, #0x80]\n" + "zip1 v23.16b, v22.16b, v21.16b\n" + "zip2 v22.16b, v22.16b, v21.16b\n" + "str q18, [x21, #0x90]\n" + "zip1 v21.16b, v28.16b, v20.16b\n" + "zip2 v20.16b, v28.16b, v20.16b\n" + "str q17, [x21, #0xa0]\n" + "zip1 v19.16b, v27.16b, v26.16b\n" + "zip2 v18.16b, v27.16b, v26.16b\n" + "str q16, [x21, #0xb0]\n" + "zip1 v17.16b, v25.16b, v24.16b\n" + "zip2 v16.16b, v25.16b, v24.16b\n" + "str q23, [x21, #0x100]\n" + "str q22, [x21, #0x110]\n" + "str q21, [x21, #0x120]\n" + "str q20, [x21, #0x130]\n" + "str q19, [x21, #0x180]\n" + "str q18, [x21, #0x190]\n" + "str q17, [x21, #0x1a0]\n" + "str q16, [x21, #0x1b0]\n" + "add x21, x21, #0x40\n" + "bge 4b\n" + "5:" // Main row loop: width 16 loop: skip + "cmp x24, #0x4\n" + "blt 7f\n" + "6:" // Main row loop: width 4 loop: loop + "ldr s19, [x17], #0x4\n" + "ldr s18, [x16], #0x4\n" + "sub x24, x24, #0x4\n" + "cmp x24, #0x4\n" + "ldr s17, [x15], #0x4\n" + "ldr s16, [x14], #0x4\n" + "zip1 v17.16b, v19.16b, v17.16b\n" + "zip1 v16.16b, v18.16b, v16.16b\n" + "ldr s19, [x13], #0x4\n" + "ldr s18, [x12], #0x4\n" + "zip1 v22.16b, v17.16b, v16.16b\n" + "ldr s17, [x11], #0x4\n" + "ldr s16, [x10], #0x4\n" + "zip1 v17.16b, v19.16b, v17.16b\n" + "zip1 v16.16b, v18.16b, v16.16b\n" + "ldr s19, [x9], #0x4\n" + "ldr s18, [x28], #0x4\n" + "zip1 v21.16b, v17.16b, v16.16b\n" + "ldr s17, [x27], #0x4\n" + "ldr s16, [x26], #0x4\n" + "zip1 v17.16b, v19.16b, v17.16b\n" + "zip1 v16.16b, v18.16b, v16.16b\n" + "ldr s20, [x25], #0x4\n" + "ldr s19, [x23], #0x4\n" + "zip1 v18.16b, v17.16b, v16.16b\n" + "ldr s17, [x22], #0x4\n" + "ldr s16, [x20], #0x4\n" + "zip1 v17.16b, v20.16b, v17.16b\n" + "zip1 v16.16b, v19.16b, v16.16b\n" + "str q22, [x21, #0x0]\n" + "zip1 v16.16b, v17.16b, v16.16b\n" + "str q21, [x21, #0x80]\n" + "str q18, [x21, #0x100]\n" + "str q16, [x21, #0x180]\n" + "add x21, x21, #0x10\n" + "bge 6b\n" + "7:" // Main row loop: width 4 loop: skip + "cmp x24, #0x1\n" + "blt 9f\n" + "8:" // Main row loop: width 1 loop: loop + "ldr b19, [x17], #0x1\n" + "ldr b18, [x16], #0x1\n" + "sub x24, x24, #0x1\n" + "cmp x24, #0x1\n" + "ldr b17, [x15], #0x1\n" + "ldr b16, [x14], #0x1\n" + "zip1 v17.16b, v19.16b, v17.16b\n" + "zip1 v16.16b, v18.16b, v16.16b\n" + "ldr b19, [x13], #0x1\n" + "ldr b18, [x12], #0x1\n" + "zip1 v22.16b, v17.16b, v16.16b\n" + "ldr b17, [x11], #0x1\n" + "ldr b16, [x10], #0x1\n" + "zip1 v17.16b, v19.16b, v17.16b\n" + "zip1 v16.16b, v18.16b, v16.16b\n" + "ldr b19, [x9], #0x1\n" + "ldr b18, [x28], #0x1\n" + "zip1 v21.16b, v17.16b, v16.16b\n" + "ldr b17, [x27], #0x1\n" + "ldr b16, [x26], #0x1\n" + "zip1 v17.16b, v19.16b, v17.16b\n" + "zip1 v16.16b, v18.16b, v16.16b\n" + "ldr b20, [x25], #0x1\n" + "ldr b19, [x23], #0x1\n" + "zip1 v18.16b, v17.16b, v16.16b\n" + "ldr b17, [x22], #0x1\n" + "ldr b16, [x20], #0x1\n" + "zip1 v17.16b, v20.16b, v17.16b\n" + "zip1 v16.16b, v19.16b, v16.16b\n" + "str s22, [x21, #0x0]\n" + "zip1 v16.16b, v17.16b, v16.16b\n" + "str s21, [x21, #0x80]\n" + "str s18, [x21, #0x100]\n" + "str s16, [x21, #0x180]\n" + "add x21, x21, #0x4\n" + "bge 8b\n" + "9:" // Main row loop: width 1 loop: skip + "cmp %x[height], #0x10\n" + "add %x[out], %x[out], #0x200\n" + "bge 1b\n" + "cbz %x[height], 20f\n" + "10:" // Main loop skip + "11:" // Tail row loop: Head + "mov x17, %x[in]\n" + "add x16, x17, %x[in_stride]\n" + "add x15, x16, %x[in_stride]\n" + "mov x20, %x[width]\n" + "add x14, x15, %x[in_stride]\n" + "cmp %x[height], #0x3\n" + "add %x[in], x14, %x[in_stride]\n" + "csel x14, x14, %x[pad_row], GT\n" + "csel x15, x15, %x[pad_row], GE\n" + "cmp %x[height], #0x1\n" + "csel x16, x16, %x[pad_row], GT\n" + "cmp x20, #0x20\n" + "mov x21, %x[out]\n" + "sub %x[height], %x[height], #0x4\n" + "blt 13f\n" + "12:" // Tail row loop: Column loop + "ldr q19, [x17], #0x10\n" + "ldr q18, [x16], #0x10\n" + "sub x20, x20, #0x20\n" + "cmp x20, #0x20\n" + "ldr q17, [x15], #0x10\n" + "ldr q16, [x14], #0x10\n" + "zip1 v25.16b, v19.16b, v17.16b\n" + "zip1 v24.16b, v18.16b, v16.16b\n" + "ldr q22, [x17], #0x10\n" + "ldr q21, [x16], #0x10\n" + "zip2 v20.16b, v19.16b, v17.16b\n" + "zip2 v19.16b, v18.16b, v16.16b\n" + "ldr q17, [x15], #0x10\n" + "ldr q16, [x14], #0x10\n" + "zip1 v23.16b, v22.16b, v17.16b\n" + "zip1 v18.16b, v21.16b, v16.16b\n" + "zip2 v22.16b, v22.16b, v17.16b\n" + "zip2 v21.16b, v21.16b, v16.16b\n" + "zip1 v16.16b, v25.16b, v24.16b\n" + "zip2 v17.16b, v25.16b, v24.16b\n" + "str q16, [x21, #0x0]\n" + "zip1 v16.16b, v20.16b, v19.16b\n" + "zip2 v20.16b, v20.16b, v19.16b\n" + "str q17, [x21, #0x10]\n" + "zip1 v19.16b, v23.16b, v18.16b\n" + "zip2 v18.16b, v23.16b, v18.16b\n" + "str q16, [x21, #0x20]\n" + "zip1 v17.16b, v22.16b, v21.16b\n" + "zip2 v16.16b, v22.16b, v21.16b\n" + "str q20, [x21, #0x30]\n" + "str q19, [x21, #0x40]\n" + "str q18, [x21, #0x50]\n" + "str q17, [x21, #0x60]\n" + "str q16, [x21, #0x70]\n" + "add x21, x21, %x[out_stride]\n" + "bge 12b\n" + "13:" // Tail row loop: Column loop skip + "cmp x20, #0x10\n" + "blt 15f\n" + "14:" // Tail row loop: width 16 loop: loop + "ldr q20, [x17], #0x10\n" + "ldr q21, [x16], #0x10\n" + "sub x20, x20, #0x10\n" + "cmp x20, #0x10\n" + "ldr q19, [x15], #0x10\n" + "ldr q16, [x14], #0x10\n" + "zip1 v18.16b, v20.16b, v19.16b\n" + "zip1 v17.16b, v21.16b, v16.16b\n" + "zip2 v20.16b, v20.16b, v19.16b\n" + "zip2 v19.16b, v21.16b, v16.16b\n" + "zip1 v16.16b, v18.16b, v17.16b\n" + "zip2 v18.16b, v18.16b, v17.16b\n" + "str q16, [x21, #0x0]\n" + "zip1 v17.16b, v20.16b, v19.16b\n" + "zip2 v16.16b, v20.16b, v19.16b\n" + "str q18, [x21, #0x10]\n" + "str q17, [x21, #0x20]\n" + "str q16, [x21, #0x30]\n" + "add x21, x21, #0x40\n" + "bge 14b\n" + "15:" // Tail row loop: width 16 loop: skip + "cmp x20, #0x4\n" + "blt 17f\n" + "16:" // Tail row loop: width 4 loop: loop + "ldr s19, [x17], #0x4\n" + "ldr s18, [x16], #0x4\n" + "sub x20, x20, #0x4\n" + "cmp x20, #0x4\n" + "ldr s17, [x15], #0x4\n" + "ldr s16, [x14], #0x4\n" + "zip1 v17.16b, v19.16b, v17.16b\n" + "zip1 v16.16b, v18.16b, v16.16b\n" + "zip1 v16.16b, v17.16b, v16.16b\n" + "str q16, [x21, #0x0]\n" + "add x21, x21, #0x10\n" + "bge 16b\n" + "17:" // Tail row loop: width 4 loop: skip + "cmp x20, #0x1\n" + "blt 19f\n" + "18:" // Tail row loop: width 1 loop: loop + "ldr b19, [x17], #0x1\n" + "ldr b18, [x16], #0x1\n" + "sub x20, x20, #0x1\n" + "cmp x20, #0x1\n" + "ldr b17, [x15], #0x1\n" + "ldr b16, [x14], #0x1\n" + "zip1 v17.16b, v19.16b, v17.16b\n" + "zip1 v16.16b, v18.16b, v16.16b\n" + "zip1 v16.16b, v17.16b, v16.16b\n" + "str s16, [x21, #0x0]\n" + "add x21, x21, #0x4\n" + "bge 18b\n" + "19:" // Tail row loop: width 1 loop: skip + "cmp %x[height], #0x1\n" + "add %x[out], %x[out], #0x80\n" + "bge 11b\n" + "20:" // Done + : [height] "+&r" (height), [in] "+&r" (in), [out] "+&r" (out) + : [in_stride] "r" (in_stride), [out_stride] "r" (out_stride), [pad_row] "r" (pad_row), [width] "r" (width) + : "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "x9", "x10", "x11", "x12", "x13", "x14", "x15", "x16", "x17", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27", "x28" + ); +} + +} // anonymous namespace + +template<> +void Transform<32, 4, true, VLType::None>( + uint8_t *out, const uint8_t *in, int stride, int x0, int xmax, int k0, int kmax) +{ + a64_transpose_interleave_32_1x4( + reinterpret_cast<uint8_t *>(out), + reinterpret_cast<const uint8_t *>(in + k0 * stride + x0), + (xmax-x0) * sizeof(uint8_t) / 1, + stride * sizeof(uint8_t), + (kmax-k0) + ); +} + +template<> +void Transform<32, 4, true, VLType::None>( + int8_t *out, const int8_t *in, int stride, int x0, int xmax, int k0, int kmax) +{ + a64_transpose_interleave_32_1x4( + reinterpret_cast<uint8_t *>(out), + reinterpret_cast<const uint8_t *>(in + k0 * stride + x0), + (xmax-x0) * sizeof(int8_t) / 1, + stride * sizeof(int8_t), + (kmax-k0) + ); +} + + +#endif // defined(__aarch64__) diff --git a/src/core/NEON/kernels/arm_gemm/transforms/a64_transpose_interleave_32_2x2.hpp b/src/core/NEON/kernels/arm_gemm/transforms/a64_transpose_interleave_32_2x2.hpp new file mode 100644 index 0000000000..237536697c --- /dev/null +++ b/src/core/NEON/kernels/arm_gemm/transforms/a64_transpose_interleave_32_2x2.hpp @@ -0,0 +1,451 @@ +/* + * Copyright (c) 2021, 2023 Arm Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ + +#pragma once + +#if defined(__aarch64__) + +namespace { + +void a64_transpose_interleave_32_2x2(uint16_t *out, const uint16_t *in, size_t width, size_t in_stride, size_t height) +{ + uint16_t *pad_row = reinterpret_cast<uint16_t *>(alloca(width * sizeof(uint16_t))); + + if (height % 2) { + memset(pad_row, 0, width * sizeof(uint16_t)); + } + + size_t out_stride = 32 * roundup<size_t>(height, 2) * sizeof(uint16_t); + + __asm__ __volatile__( + "cmp %x[height], #0x4\n" + "blt 12f\n" + "1:" // Main row loop: Head + "mov x25, %x[in]\n" + "mov x24, %x[width]\n" + "add x23, x25, %x[in_stride]\n" + "add x22, x23, %x[in_stride]\n" + "add x20, x22, %x[in_stride]\n" + "cmp x24, #0x40\n" + "add %x[in], x20, %x[in_stride]\n" + "mov x21, %x[out]\n" + "sub %x[height], %x[height], #0x4\n" + "blt 3f\n" + "2:" // Main row loop: Unroll column loop + "ldr q14, [x25], #0x10\n" + "ldr q10, [x23], #0x10\n" + "sub x24, x24, #0x40\n" + "zip1 v12.8h, v14.8h, v10.8h\n" + "ldr q5, [x22], #0x10\n" + "ldr q3, [x20], #0x10\n" + "zip2 v31.8h, v14.8h, v10.8h\n" + "zip1 v19.8h, v5.8h, v3.8h\n" + "ldr q27, [x25], #0x10\n" + "ldr q25, [x23], #0x10\n" + "zip1 v11.8h, v27.8h, v25.8h\n" + "zip2 v24.8h, v27.8h, v25.8h\n" + "ldr q6, [x22], #0x10\n" + "ldr q29, [x20], #0x10\n" + "zip2 v15.8h, v5.8h, v3.8h\n" + "zip1 v18.8h, v6.8h, v29.8h\n" + "ldr q17, [x25], #0x10\n" + "ldr q9, [x23], #0x10\n" + "zip1 v0.8h, v17.8h, v9.8h\n" + "zip2 v9.8h, v17.8h, v9.8h\n" + "ldr q21, [x22], #0x10\n" + "ldr q20, [x20], #0x10\n" + "zip2 v8.8h, v6.8h, v29.8h\n" + "zip1 v30.8h, v21.8h, v20.8h\n" + "ldr q17, [x25], #0x10\n" + "ldr q5, [x23], #0x10\n" + "zip1 v13.8h, v17.8h, v5.8h\n" + "zip2 v25.8h, v17.8h, v5.8h\n" + "ldr q7, [x22], #0x10\n" + "ldr q29, [x20], #0x10\n" + "zip2 v27.8h, v21.8h, v20.8h\n" + "zip1 v14.8h, v7.8h, v29.8h\n" + "ldr q28, [x25], #0x10\n" + "ldr q17, [x23], #0x10\n" + "zip2 v1.8h, v7.8h, v29.8h\n" + "cmp x24, #0x40\n" + "ldr q10, [x22], #0x10\n" + "ldr q21, [x20], #0x10\n" + "zip1 v16.8h, v28.8h, v17.8h\n" + "zip2 v17.8h, v28.8h, v17.8h\n" + "ldr q5, [x25], #0x10\n" + "ldr q20, [x23], #0x10\n" + "zip1 v3.8h, v5.8h, v20.8h\n" + "zip2 v7.8h, v5.8h, v20.8h\n" + "ldr q22, [x22], #0x10\n" + "ldr q29, [x20], #0x10\n" + "zip1 v2.8h, v10.8h, v21.8h\n" + "zip2 v5.8h, v10.8h, v21.8h\n" + "ldr q21, [x25], #0x10\n" + "ldr q20, [x23], #0x10\n" + "zip1 v4.8h, v21.8h, v20.8h\n" + "zip2 v28.8h, v21.8h, v20.8h\n" + "ldr q6, [x22], #0x10\n" + "ldr q10, [x20], #0x10\n" + "zip1 v26.8h, v22.8h, v29.8h\n" + "zip2 v20.8h, v22.8h, v29.8h\n" + "ldr q29, [x25], #0x10\n" + "ldr q23, [x23], #0x10\n" + "zip1 v21.8h, v29.8h, v23.8h\n" + "zip2 v23.8h, v29.8h, v23.8h\n" + "ldr q22, [x22], #0x10\n" + "ldr q29, [x20], #0x10\n" + "str q12, [x21, #0x0]\n" + "zip1 v12.8h, v6.8h, v10.8h\n" + "str q31, [x21, #0x10]\n" + "zip2 v6.8h, v6.8h, v10.8h\n" + "zip1 v31.8h, v22.8h, v29.8h\n" + "str q11, [x21, #0x20]\n" + "zip2 v11.8h, v22.8h, v29.8h\n" + "str q24, [x21, #0x30]\n" + "str q0, [x21, #0x40]\n" + "str q9, [x21, #0x50]\n" + "str q13, [x21, #0x60]\n" + "str q25, [x21, #0x70]\n" + "str q19, [x21, #0x80]\n" + "str q15, [x21, #0x90]\n" + "str q18, [x21, #0xa0]\n" + "str q8, [x21, #0xb0]\n" + "str q30, [x21, #0xc0]\n" + "str q27, [x21, #0xd0]\n" + "str q14, [x21, #0xe0]\n" + "str q1, [x21, #0xf0]\n" + "add x21, x21, %x[out_stride]\n" + "str q16, [x21, #0x0]\n" + "str q17, [x21, #0x10]\n" + "str q3, [x21, #0x20]\n" + "str q7, [x21, #0x30]\n" + "str q4, [x21, #0x40]\n" + "str q28, [x21, #0x50]\n" + "str q21, [x21, #0x60]\n" + "str q23, [x21, #0x70]\n" + "str q2, [x21, #0x80]\n" + "str q5, [x21, #0x90]\n" + "str q26, [x21, #0xa0]\n" + "str q20, [x21, #0xb0]\n" + "str q12, [x21, #0xc0]\n" + "str q6, [x21, #0xd0]\n" + "str q31, [x21, #0xe0]\n" + "str q11, [x21, #0xf0]\n" + "add x21, x21, %x[out_stride]\n" + "bge 2b\n" + "3:" // Main row loop: Unroll column loop skip + "cmp x24, #0x20\n" + "blt 5f\n" + "4:" // Main row loop: Column loop + "ldr q17, [x25], #0x10\n" + "ldr q16, [x23], #0x10\n" + "sub x24, x24, #0x20\n" + "cmp x24, #0x20\n" + "ldr q21, [x22], #0x10\n" + "ldr q18, [x20], #0x10\n" + "zip1 v1.8h, v17.8h, v16.8h\n" + "zip2 v0.8h, v17.8h, v16.8h\n" + "ldr q17, [x25], #0x10\n" + "ldr q16, [x23], #0x10\n" + "zip1 v31.8h, v17.8h, v16.8h\n" + "zip2 v30.8h, v17.8h, v16.8h\n" + "ldr q20, [x22], #0x10\n" + "ldr q19, [x20], #0x10\n" + "zip1 v29.8h, v21.8h, v18.8h\n" + "zip2 v28.8h, v21.8h, v18.8h\n" + "ldr q17, [x25], #0x10\n" + "ldr q16, [x23], #0x10\n" + "zip1 v27.8h, v17.8h, v16.8h\n" + "zip2 v26.8h, v17.8h, v16.8h\n" + "ldr q25, [x22], #0x10\n" + "ldr q18, [x20], #0x10\n" + "zip1 v24.8h, v20.8h, v19.8h\n" + "zip2 v23.8h, v20.8h, v19.8h\n" + "ldr q17, [x25], #0x10\n" + "ldr q16, [x23], #0x10\n" + "zip1 v22.8h, v17.8h, v16.8h\n" + "zip2 v21.8h, v17.8h, v16.8h\n" + "ldr q20, [x22], #0x10\n" + "ldr q16, [x20], #0x10\n" + "zip1 v19.8h, v25.8h, v18.8h\n" + "zip2 v18.8h, v25.8h, v18.8h\n" + "zip1 v17.8h, v20.8h, v16.8h\n" + "zip2 v16.8h, v20.8h, v16.8h\n" + "str q1, [x21, #0x0]\n" + "str q0, [x21, #0x10]\n" + "str q31, [x21, #0x20]\n" + "str q30, [x21, #0x30]\n" + "str q27, [x21, #0x40]\n" + "str q26, [x21, #0x50]\n" + "str q22, [x21, #0x60]\n" + "str q21, [x21, #0x70]\n" + "str q29, [x21, #0x80]\n" + "str q28, [x21, #0x90]\n" + "str q24, [x21, #0xa0]\n" + "str q23, [x21, #0xb0]\n" + "str q19, [x21, #0xc0]\n" + "str q18, [x21, #0xd0]\n" + "str q17, [x21, #0xe0]\n" + "str q16, [x21, #0xf0]\n" + "add x21, x21, %x[out_stride]\n" + "bge 4b\n" + "5:" // Main row loop: Column loop skip + "cmp x24, #0x10\n" + "blt 7f\n" + "6:" // Main row loop: width 16 loop: loop + "ldr q17, [x25], #0x10\n" + "ldr q16, [x23], #0x10\n" + "sub x24, x24, #0x10\n" + "cmp x24, #0x10\n" + "ldr q24, [x22], #0x10\n" + "ldr q23, [x20], #0x10\n" + "zip1 v19.8h, v17.8h, v16.8h\n" + "zip2 v18.8h, v17.8h, v16.8h\n" + "ldr q17, [x25], #0x10\n" + "ldr q16, [x23], #0x10\n" + "zip1 v22.8h, v17.8h, v16.8h\n" + "zip2 v21.8h, v17.8h, v16.8h\n" + "ldr q20, [x22], #0x10\n" + "ldr q16, [x20], #0x10\n" + "str q19, [x21, #0x0]\n" + "zip1 v19.8h, v24.8h, v23.8h\n" + "str q18, [x21, #0x10]\n" + "zip2 v18.8h, v24.8h, v23.8h\n" + "zip1 v17.8h, v20.8h, v16.8h\n" + "str q22, [x21, #0x20]\n" + "zip2 v16.8h, v20.8h, v16.8h\n" + "str q21, [x21, #0x30]\n" + "str q19, [x21, #0x80]\n" + "str q18, [x21, #0x90]\n" + "str q17, [x21, #0xa0]\n" + "str q16, [x21, #0xb0]\n" + "add x21, x21, #0x40\n" + "bge 6b\n" + "7:" // Main row loop: width 16 loop: skip + "cmp x24, #0x4\n" + "blt 9f\n" + "8:" // Main row loop: width 4 loop: loop + "ldr d19, [x25], #0x8\n" + "ldr d16, [x23], #0x8\n" + "sub x24, x24, #0x4\n" + "cmp x24, #0x4\n" + "ldr d18, [x22], #0x8\n" + "ldr d17, [x20], #0x8\n" + "zip1 v16.8h, v19.8h, v16.8h\n" + "str q16, [x21, #0x0]\n" + "zip1 v16.8h, v18.8h, v17.8h\n" + "str q16, [x21, #0x80]\n" + "add x21, x21, #0x10\n" + "bge 8b\n" + "9:" // Main row loop: width 4 loop: skip + "cmp x24, #0x1\n" + "blt 11f\n" + "10:" // Main row loop: width 1 loop: loop + "ldr h19, [x25], #0x2\n" + "ldr h16, [x23], #0x2\n" + "sub x24, x24, #0x1\n" + "cmp x24, #0x1\n" + "ldr h18, [x22], #0x2\n" + "ldr h17, [x20], #0x2\n" + "zip1 v16.8h, v19.8h, v16.8h\n" + "str s16, [x21, #0x0]\n" + "zip1 v16.8h, v18.8h, v17.8h\n" + "str s16, [x21, #0x80]\n" + "add x21, x21, #0x4\n" + "bge 10b\n" + "11:" // Main row loop: width 1 loop: skip + "cmp %x[height], #0x4\n" + "add %x[out], %x[out], #0x100\n" + "bge 1b\n" + "cbz %x[height], 24f\n" + "12:" // Main loop skip + "13:" // Tail row loop: Head + "mov x25, %x[in]\n" + "mov x20, %x[width]\n" + "add x23, x25, %x[in_stride]\n" + "cmp %x[height], #0x1\n" + "add %x[in], x23, %x[in_stride]\n" + "csel x23, x23, %x[pad_row], GT\n" + "cmp x20, #0x40\n" + "mov x21, %x[out]\n" + "sub %x[height], %x[height], #0x2\n" + "blt 15f\n" + "14:" // Tail row loop: Unroll column loop + "ldr q18, [x25], #0x10\n" + "ldr q17, [x23], #0x10\n" + "sub x20, x20, #0x40\n" + "zip1 v0.8h, v18.8h, v17.8h\n" + "ldr q19, [x25], #0x10\n" + "ldr q16, [x23], #0x10\n" + "zip2 v31.8h, v18.8h, v17.8h\n" + "zip1 v30.8h, v19.8h, v16.8h\n" + "ldr q18, [x25], #0x10\n" + "ldr q17, [x23], #0x10\n" + "zip2 v29.8h, v19.8h, v16.8h\n" + "zip1 v28.8h, v18.8h, v17.8h\n" + "ldr q19, [x25], #0x10\n" + "ldr q16, [x23], #0x10\n" + "zip2 v27.8h, v18.8h, v17.8h\n" + "zip1 v26.8h, v19.8h, v16.8h\n" + "ldr q18, [x25], #0x10\n" + "ldr q17, [x23], #0x10\n" + "zip2 v25.8h, v19.8h, v16.8h\n" + "cmp x20, #0x40\n" + "ldr q19, [x25], #0x10\n" + "ldr q16, [x23], #0x10\n" + "zip1 v24.8h, v18.8h, v17.8h\n" + "zip2 v23.8h, v18.8h, v17.8h\n" + "ldr q18, [x25], #0x10\n" + "ldr q17, [x23], #0x10\n" + "zip1 v22.8h, v19.8h, v16.8h\n" + "zip2 v21.8h, v19.8h, v16.8h\n" + "ldr q20, [x25], #0x10\n" + "ldr q16, [x23], #0x10\n" + "str q0, [x21, #0x0]\n" + "zip1 v19.8h, v18.8h, v17.8h\n" + "str q31, [x21, #0x10]\n" + "zip2 v18.8h, v18.8h, v17.8h\n" + "zip1 v17.8h, v20.8h, v16.8h\n" + "str q30, [x21, #0x20]\n" + "zip2 v16.8h, v20.8h, v16.8h\n" + "str q29, [x21, #0x30]\n" + "str q28, [x21, #0x40]\n" + "str q27, [x21, #0x50]\n" + "str q26, [x21, #0x60]\n" + "str q25, [x21, #0x70]\n" + "add x21, x21, %x[out_stride]\n" + "str q24, [x21, #0x0]\n" + "str q23, [x21, #0x10]\n" + "str q22, [x21, #0x20]\n" + "str q21, [x21, #0x30]\n" + "str q19, [x21, #0x40]\n" + "str q18, [x21, #0x50]\n" + "str q17, [x21, #0x60]\n" + "str q16, [x21, #0x70]\n" + "add x21, x21, %x[out_stride]\n" + "bge 14b\n" + "15:" // Tail row loop: Unroll column loop skip + "cmp x20, #0x20\n" + "blt 17f\n" + "16:" // Tail row loop: Column loop + "ldr q18, [x25], #0x10\n" + "ldr q17, [x23], #0x10\n" + "sub x20, x20, #0x20\n" + "cmp x20, #0x20\n" + "ldr q19, [x25], #0x10\n" + "ldr q16, [x23], #0x10\n" + "zip1 v24.8h, v18.8h, v17.8h\n" + "zip2 v23.8h, v18.8h, v17.8h\n" + "ldr q18, [x25], #0x10\n" + "ldr q17, [x23], #0x10\n" + "zip1 v22.8h, v19.8h, v16.8h\n" + "zip2 v21.8h, v19.8h, v16.8h\n" + "ldr q20, [x25], #0x10\n" + "ldr q16, [x23], #0x10\n" + "zip1 v19.8h, v18.8h, v17.8h\n" + "zip2 v18.8h, v18.8h, v17.8h\n" + "zip1 v17.8h, v20.8h, v16.8h\n" + "zip2 v16.8h, v20.8h, v16.8h\n" + "str q24, [x21, #0x0]\n" + "str q23, [x21, #0x10]\n" + "str q22, [x21, #0x20]\n" + "str q21, [x21, #0x30]\n" + "str q19, [x21, #0x40]\n" + "str q18, [x21, #0x50]\n" + "str q17, [x21, #0x60]\n" + "str q16, [x21, #0x70]\n" + "add x21, x21, %x[out_stride]\n" + "bge 16b\n" + "17:" // Tail row loop: Column loop skip + "cmp x20, #0x10\n" + "blt 19f\n" + "18:" // Tail row loop: width 16 loop: loop + "ldr q18, [x25], #0x10\n" + "ldr q17, [x23], #0x10\n" + "sub x20, x20, #0x10\n" + "cmp x20, #0x10\n" + "ldr q20, [x25], #0x10\n" + "ldr q16, [x23], #0x10\n" + "zip1 v19.8h, v18.8h, v17.8h\n" + "zip2 v18.8h, v18.8h, v17.8h\n" + "zip1 v17.8h, v20.8h, v16.8h\n" + "zip2 v16.8h, v20.8h, v16.8h\n" + "str q19, [x21, #0x0]\n" + "str q18, [x21, #0x10]\n" + "str q17, [x21, #0x20]\n" + "str q16, [x21, #0x30]\n" + "add x21, x21, #0x40\n" + "bge 18b\n" + "19:" // Tail row loop: width 16 loop: skip + "cmp x20, #0x4\n" + "blt 21f\n" + "20:" // Tail row loop: width 4 loop: loop + "ldr d17, [x25], #0x8\n" + "ldr d16, [x23], #0x8\n" + "sub x20, x20, #0x4\n" + "cmp x20, #0x4\n" + "zip1 v16.8h, v17.8h, v16.8h\n" + "str q16, [x21, #0x0]\n" + "add x21, x21, #0x10\n" + "bge 20b\n" + "21:" // Tail row loop: width 4 loop: skip + "cmp x20, #0x1\n" + "blt 23f\n" + "22:" // Tail row loop: width 1 loop: loop + "ldr h17, [x25], #0x2\n" + "ldr h16, [x23], #0x2\n" + "sub x20, x20, #0x1\n" + "cmp x20, #0x1\n" + "zip1 v16.8h, v17.8h, v16.8h\n" + "str s16, [x21, #0x0]\n" + "add x21, x21, #0x4\n" + "bge 22b\n" + "23:" // Tail row loop: width 1 loop: skip + "cmp %x[height], #0x1\n" + "add %x[out], %x[out], #0x80\n" + "bge 13b\n" + "24:" // Done + : [height] "+&r" (height), [in] "+&r" (in), [out] "+&r" (out) + : [in_stride] "r" (in_stride), [out_stride] "r" (out_stride), [pad_row] "r" (pad_row), [width] "r" (width) + : "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "x20", "x21", "x22", "x23", "x24", "x25" + ); +} + +} // anonymous namespace + +template<> +void Transform<32, 2, true, VLType::None>( + bfloat16 *out, const bfloat16 *in, int stride, int x0, int xmax, int k0, int kmax) +{ + a64_transpose_interleave_32_2x2( + reinterpret_cast<uint16_t *>(out), + reinterpret_cast<const uint16_t *>(in + k0 * stride + x0), + (xmax-x0) * sizeof(bfloat16) / 2, + stride * sizeof(bfloat16), + (kmax-k0) + ); +} + + +#endif // defined(__aarch64__) diff --git a/src/core/NEON/kernels/arm_gemm/transforms/a64_transpose_interleave_48.hpp b/src/core/NEON/kernels/arm_gemm/transforms/a64_transpose_interleave_48.hpp new file mode 100644 index 0000000000..f35752d5a8 --- /dev/null +++ b/src/core/NEON/kernels/arm_gemm/transforms/a64_transpose_interleave_48.hpp @@ -0,0 +1,244 @@ +/* + * Copyright (c) 2021, 2023 Arm Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ + +#pragma once + +#if defined(__aarch64__) + +namespace { + +void a64_transpose_interleave_48(uint16_t *out, const uint16_t *in, size_t width, size_t in_stride, size_t height) +{ + size_t out_stride = 24 * height * sizeof(uint16_t); + + __asm__ __volatile__( + "cmp %x[height], #0x4\n" + "blt 10f\n" + "1:" // Main row loop: Head + "mov x25, %x[in]\n" + "mov x24, %x[width]\n" + "add x23, x25, %x[in_stride]\n" + "add x22, x23, %x[in_stride]\n" + "add x20, x22, %x[in_stride]\n" + "cmp x24, #0x18\n" + "add %x[in], x20, %x[in_stride]\n" + "mov x21, %x[out]\n" + "sub %x[height], %x[height], #0x4\n" + "blt 3f\n" + "2:" // Main row loop: Column loop + "ldr q27, [x25], #0x10\n" + "ldr q26, [x23], #0x10\n" + "sub x24, x24, #0x18\n" + "cmp x24, #0x18\n" + "ldr q25, [x22], #0x10\n" + "ldr q24, [x20], #0x10\n" + "ldr q23, [x25], #0x10\n" + "ldr q22, [x23], #0x10\n" + "ldr q21, [x22], #0x10\n" + "ldr q20, [x20], #0x10\n" + "ldr q19, [x25], #0x10\n" + "ldr q18, [x23], #0x10\n" + "ldr q17, [x22], #0x10\n" + "ldr q16, [x20], #0x10\n" + "str q27, [x21, #0x0]\n" + "str q23, [x21, #0x10]\n" + "str q19, [x21, #0x20]\n" + "str q26, [x21, #0x30]\n" + "str q22, [x21, #0x40]\n" + "str q18, [x21, #0x50]\n" + "str q25, [x21, #0x60]\n" + "str q21, [x21, #0x70]\n" + "str q17, [x21, #0x80]\n" + "str q24, [x21, #0x90]\n" + "str q20, [x21, #0xa0]\n" + "str q16, [x21, #0xb0]\n" + "add x21, x21, %x[out_stride]\n" + "bge 2b\n" + "3:" // Main row loop: Column loop skip + "cmp x24, #0x10\n" + "blt 5f\n" + "4:" // Main row loop: width 16 loop: loop + "ldr q23, [x25], #0x10\n" + "ldr q22, [x23], #0x10\n" + "sub x24, x24, #0x10\n" + "cmp x24, #0x10\n" + "ldr q21, [x22], #0x10\n" + "ldr q20, [x20], #0x10\n" + "ldr q19, [x25], #0x10\n" + "ldr q18, [x23], #0x10\n" + "ldr q17, [x22], #0x10\n" + "ldr q16, [x20], #0x10\n" + "str q23, [x21, #0x0]\n" + "str q19, [x21, #0x10]\n" + "str q22, [x21, #0x30]\n" + "str q18, [x21, #0x40]\n" + "str q21, [x21, #0x60]\n" + "str q17, [x21, #0x70]\n" + "str q20, [x21, #0x90]\n" + "str q16, [x21, #0xa0]\n" + "add x21, x21, #0x20\n" + "bge 4b\n" + "5:" // Main row loop: width 16 loop: skip + "cmp x24, #0x4\n" + "blt 7f\n" + "6:" // Main row loop: width 4 loop: loop + "ldr d19, [x25], #0x8\n" + "ldr d18, [x23], #0x8\n" + "sub x24, x24, #0x4\n" + "cmp x24, #0x4\n" + "ldr d17, [x22], #0x8\n" + "ldr d16, [x20], #0x8\n" + "str d19, [x21, #0x0]\n" + "str d18, [x21, #0x30]\n" + "str d17, [x21, #0x60]\n" + "str d16, [x21, #0x90]\n" + "add x21, x21, #0x8\n" + "bge 6b\n" + "7:" // Main row loop: width 4 loop: skip + "cmp x24, #0x1\n" + "blt 9f\n" + "8:" // Main row loop: width 1 loop: loop + "ldr h19, [x25], #0x2\n" + "ldr h18, [x23], #0x2\n" + "sub x24, x24, #0x1\n" + "cmp x24, #0x1\n" + "ldr h17, [x22], #0x2\n" + "ldr h16, [x20], #0x2\n" + "str h19, [x21, #0x0]\n" + "str h18, [x21, #0x30]\n" + "str h17, [x21, #0x60]\n" + "str h16, [x21, #0x90]\n" + "add x21, x21, #0x2\n" + "bge 8b\n" + "9:" // Main row loop: width 1 loop: skip + "cmp %x[height], #0x4\n" + "add %x[out], %x[out], #0xc0\n" + "bge 1b\n" + "cbz %x[height], 20f\n" + "10:" // Main loop skip + "11:" // Tail row loop: Head + "mov x20, %x[width]\n" + "mov x25, %x[in]\n" + "cmp x20, #0x18\n" + "add %x[in], x25, %x[in_stride]\n" + "mov x21, %x[out]\n" + "sub %x[height], %x[height], #0x1\n" + "blt 13f\n" + "12:" // Tail row loop: Column loop + "ldr q18, [x25], #0x10\n" + "ldr q17, [x25], #0x10\n" + "sub x20, x20, #0x18\n" + "cmp x20, #0x18\n" + "ldr q16, [x25], #0x10\n" + "str q18, [x21, #0x0]\n" + "str q17, [x21, #0x10]\n" + "str q16, [x21, #0x20]\n" + "add x21, x21, %x[out_stride]\n" + "bge 12b\n" + "13:" // Tail row loop: Column loop skip + "cmp x20, #0x10\n" + "blt 15f\n" + "14:" // Tail row loop: width 16 loop: loop + "ldr q17, [x25], #0x10\n" + "ldr q16, [x25], #0x10\n" + "sub x20, x20, #0x10\n" + "cmp x20, #0x10\n" + "str q17, [x21, #0x0]\n" + "str q16, [x21, #0x10]\n" + "add x21, x21, #0x20\n" + "bge 14b\n" + "15:" // Tail row loop: width 16 loop: skip + "cmp x20, #0x4\n" + "blt 17f\n" + "16:" // Tail row loop: width 4 loop: loop + "ldr d16, [x25], #0x8\n" + "sub x20, x20, #0x4\n" + "cmp x20, #0x4\n" + "str d16, [x21, #0x0]\n" + "add x21, x21, #0x8\n" + "bge 16b\n" + "17:" // Tail row loop: width 4 loop: skip + "cmp x20, #0x1\n" + "blt 19f\n" + "18:" // Tail row loop: width 1 loop: loop + "ldr h16, [x25], #0x2\n" + "sub x20, x20, #0x1\n" + "cmp x20, #0x1\n" + "str h16, [x21, #0x0]\n" + "add x21, x21, #0x2\n" + "bge 18b\n" + "19:" // Tail row loop: width 1 loop: skip + "cmp %x[height], #0x1\n" + "add %x[out], %x[out], #0x30\n" + "bge 11b\n" + "20:" // Done + : [height] "+&r" (height), [in] "+&r" (in), [out] "+&r" (out) + : [in_stride] "r" (in_stride), [out_stride] "r" (out_stride), [width] "r" (width) + : "cc", "memory", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "x20", "x21", "x22", "x23", "x24", "x25" + ); +} + +} // anonymous namespace + +template<> +void Transform<12, 1, true, VLType::None>( + float *out, const float *in, int stride, int x0, int xmax, int k0, int kmax) +{ + a64_transpose_interleave_48( + reinterpret_cast<uint16_t *>(out), + reinterpret_cast<const uint16_t *>(in + k0 * stride + x0), + (xmax-x0) * sizeof(float) / 2, + stride * sizeof(float), + (kmax-k0) + ); +} + +template<> +void Transform<24, 1, true, VLType::None>( + __fp16 *out, const __fp16 *in, int stride, int x0, int xmax, int k0, int kmax) +{ + a64_transpose_interleave_48( + reinterpret_cast<uint16_t *>(out), + reinterpret_cast<const uint16_t *>(in + k0 * stride + x0), + (xmax-x0) * sizeof(__fp16) / 2, + stride * sizeof(__fp16), + (kmax-k0) + ); +} + +template<> +void Transform<6, 1, true, VLType::None>( + double *out, const double *in, int stride, int x0, int xmax, int k0, int kmax) +{ + a64_transpose_interleave_48( + reinterpret_cast<uint16_t *>(out), + reinterpret_cast<const uint16_t *>(in + k0 * stride + x0), + (xmax-x0) * sizeof(double) / 2, + stride * sizeof(double), + (kmax-k0) + ); +} + + +#endif // defined(__aarch64__) diff --git a/src/core/NEON/kernels/arm_gemm/transforms/a64_transpose_interleave_4_1x16.hpp b/src/core/NEON/kernels/arm_gemm/transforms/a64_transpose_interleave_4_1x16.hpp new file mode 100644 index 0000000000..6ef02ac044 --- /dev/null +++ b/src/core/NEON/kernels/arm_gemm/transforms/a64_transpose_interleave_4_1x16.hpp @@ -0,0 +1,319 @@ +/* + * Copyright (c) 2021, 2023 Arm Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ + +#pragma once + +#if defined(__aarch64__) + +namespace { + +void a64_transpose_interleave_4_1x16(uint8_t *out, const uint8_t *in, size_t width, size_t in_stride, size_t height) +{ + uint8_t *pad_row = reinterpret_cast<uint8_t *>(alloca(width * sizeof(uint8_t))); + + if (height % 16) { + memset(pad_row, 0, width * sizeof(uint8_t)); + } + + size_t out_stride = 4 * roundup<size_t>(height, 16) * sizeof(uint8_t); + + __asm__ __volatile__( + "1:" // Main row loop: Head + "mov x17, %x[in]\n" + "add x16, x17, %x[in_stride]\n" + "add x15, x16, %x[in_stride]\n" + "add x14, x15, %x[in_stride]\n" + "add x13, x14, %x[in_stride]\n" + "add x12, x13, %x[in_stride]\n" + "add x11, x12, %x[in_stride]\n" + "add x10, x11, %x[in_stride]\n" + "add x9, x10, %x[in_stride]\n" + "add x28, x9, %x[in_stride]\n" + "add x27, x28, %x[in_stride]\n" + "add x26, x27, %x[in_stride]\n" + "add x25, x26, %x[in_stride]\n" + "add x24, x25, %x[in_stride]\n" + "add x23, x24, %x[in_stride]\n" + "add x22, x23, %x[in_stride]\n" + "cmp %x[height], #0xf\n" + "add %x[in], x22, %x[in_stride]\n" + "csel x22, x22, %x[pad_row], GT\n" + "csel x23, x23, %x[pad_row], GE\n" + "cmp %x[height], #0xd\n" + "csel x24, x24, %x[pad_row], GT\n" + "csel x25, x25, %x[pad_row], GE\n" + "cmp %x[height], #0xb\n" + "csel x26, x26, %x[pad_row], GT\n" + "csel x27, x27, %x[pad_row], GE\n" + "cmp %x[height], #0x9\n" + "csel x28, x28, %x[pad_row], GT\n" + "csel x9, x9, %x[pad_row], GE\n" + "cmp %x[height], #0x7\n" + "csel x10, x10, %x[pad_row], GT\n" + "csel x11, x11, %x[pad_row], GE\n" + "cmp %x[height], #0x5\n" + "mov x21, %x[width]\n" + "csel x12, x12, %x[pad_row], GT\n" + "csel x13, x13, %x[pad_row], GE\n" + "cmp %x[height], #0x3\n" + "csel x14, x14, %x[pad_row], GT\n" + "csel x15, x15, %x[pad_row], GE\n" + "cmp %x[height], #0x1\n" + "csel x16, x16, %x[pad_row], GT\n" + "cmp x21, #0x10\n" + "mov x20, %x[out]\n" + "sub %x[height], %x[height], #0x10\n" + "blt 3f\n" + "2:" // Main row loop: Unroll column loop + "ldr q3, [x17], #0x10\n" + "ldr q9, [x16], #0x10\n" + "sub x21, x21, #0x10\n" + "cmp x21, #0x10\n" + "ldr q2, [x15], #0x10\n" + "ldr q8, [x14], #0x10\n" + "ldr q0, [x13], #0x10\n" + "ldr q31, [x12], #0x10\n" + "ldr q30, [x11], #0x10\n" + "ldr q7, [x10], #0x10\n" + "ldr q29, [x9], #0x10\n" + "ldr q28, [x28], #0x10\n" + "zip1 v27.16b, v3.16b, v29.16b\n" + "zip1 v6.16b, v9.16b, v28.16b\n" + "ldr q25, [x27], #0x10\n" + "ldr q24, [x26], #0x10\n" + "zip1 v26.16b, v2.16b, v25.16b\n" + "zip1 v1.16b, v8.16b, v24.16b\n" + "ldr q23, [x25], #0x10\n" + "ldr q22, [x24], #0x10\n" + "zip1 v21.16b, v0.16b, v23.16b\n" + "zip1 v20.16b, v31.16b, v22.16b\n" + "ldr q19, [x23], #0x10\n" + "ldr q18, [x22], #0x10\n" + "zip1 v17.16b, v30.16b, v19.16b\n" + "zip1 v16.16b, v7.16b, v18.16b\n" + "zip2 v5.16b, v3.16b, v29.16b\n" + "zip2 v0.16b, v0.16b, v23.16b\n" + "zip2 v4.16b, v2.16b, v25.16b\n" + "zip2 v3.16b, v30.16b, v19.16b\n" + "zip2 v2.16b, v9.16b, v28.16b\n" + "zip2 v31.16b, v31.16b, v22.16b\n" + "zip2 v30.16b, v8.16b, v24.16b\n" + "zip2 v29.16b, v7.16b, v18.16b\n" + "zip1 v25.16b, v27.16b, v21.16b\n" + "zip1 v24.16b, v26.16b, v17.16b\n" + "zip1 v23.16b, v6.16b, v20.16b\n" + "zip1 v22.16b, v1.16b, v16.16b\n" + "zip2 v28.16b, v27.16b, v21.16b\n" + "zip2 v27.16b, v26.16b, v17.16b\n" + "zip2 v26.16b, v6.16b, v20.16b\n" + "zip2 v21.16b, v1.16b, v16.16b\n" + "zip1 v1.16b, v5.16b, v0.16b\n" + "zip1 v20.16b, v4.16b, v3.16b\n" + "zip1 v19.16b, v2.16b, v31.16b\n" + "zip1 v16.16b, v30.16b, v29.16b\n" + "zip1 v18.16b, v25.16b, v24.16b\n" + "zip1 v17.16b, v23.16b, v22.16b\n" + "zip2 v25.16b, v25.16b, v24.16b\n" + "zip2 v24.16b, v23.16b, v22.16b\n" + "zip2 v0.16b, v5.16b, v0.16b\n" + "zip2 v23.16b, v4.16b, v3.16b\n" + "zip2 v31.16b, v2.16b, v31.16b\n" + "zip2 v22.16b, v30.16b, v29.16b\n" + "zip1 v30.16b, v28.16b, v27.16b\n" + "zip1 v29.16b, v26.16b, v21.16b\n" + "zip2 v28.16b, v28.16b, v27.16b\n" + "zip2 v27.16b, v26.16b, v21.16b\n" + "zip1 v26.16b, v1.16b, v20.16b\n" + "zip1 v21.16b, v19.16b, v16.16b\n" + "zip2 v20.16b, v1.16b, v20.16b\n" + "zip2 v19.16b, v19.16b, v16.16b\n" + "zip1 v16.16b, v18.16b, v17.16b\n" + "zip2 v18.16b, v18.16b, v17.16b\n" + "str q16, [x20, #0x0]\n" + "zip1 v17.16b, v25.16b, v24.16b\n" + "zip2 v16.16b, v25.16b, v24.16b\n" + "str q18, [x20, #0x10]\n" + "str q17, [x20, #0x20]\n" + "zip1 v25.16b, v0.16b, v23.16b\n" + "zip1 v24.16b, v31.16b, v22.16b\n" + "str q16, [x20, #0x30]\n" + "add x20, x20, %x[out_stride]\n" + "zip2 v23.16b, v0.16b, v23.16b\n" + "zip2 v22.16b, v31.16b, v22.16b\n" + "zip1 v16.16b, v30.16b, v29.16b\n" + "zip2 v17.16b, v30.16b, v29.16b\n" + "str q16, [x20, #0x0]\n" + "zip1 v16.16b, v28.16b, v27.16b\n" + "zip2 v18.16b, v28.16b, v27.16b\n" + "str q17, [x20, #0x10]\n" + "str q16, [x20, #0x20]\n" + "zip1 v17.16b, v26.16b, v21.16b\n" + "zip2 v16.16b, v26.16b, v21.16b\n" + "str q18, [x20, #0x30]\n" + "add x20, x20, %x[out_stride]\n" + "zip1 v21.16b, v20.16b, v19.16b\n" + "zip2 v20.16b, v20.16b, v19.16b\n" + "str q17, [x20, #0x0]\n" + "zip1 v19.16b, v25.16b, v24.16b\n" + "zip2 v18.16b, v25.16b, v24.16b\n" + "str q16, [x20, #0x10]\n" + "zip1 v17.16b, v23.16b, v22.16b\n" + "zip2 v16.16b, v23.16b, v22.16b\n" + "str q21, [x20, #0x20]\n" + "str q20, [x20, #0x30]\n" + "add x20, x20, %x[out_stride]\n" + "str q19, [x20, #0x0]\n" + "str q18, [x20, #0x10]\n" + "str q17, [x20, #0x20]\n" + "str q16, [x20, #0x30]\n" + "add x20, x20, %x[out_stride]\n" + "bge 2b\n" + "3:" // Main row loop: Unroll column loop skip + "cmp x21, #0x4\n" + "blt 5f\n" + "4:" // Main row loop: Column loop + "ldr s21, [x17], #0x4\n" + "ldr s23, [x16], #0x4\n" + "sub x21, x21, #0x4\n" + "cmp x21, #0x4\n" + "ldr s20, [x15], #0x4\n" + "ldr s22, [x14], #0x4\n" + "ldr s19, [x13], #0x4\n" + "ldr s18, [x12], #0x4\n" + "ldr s25, [x11], #0x4\n" + "ldr s24, [x10], #0x4\n" + "ldr s17, [x9], #0x4\n" + "ldr s16, [x28], #0x4\n" + "zip1 v21.16b, v21.16b, v17.16b\n" + "zip1 v23.16b, v23.16b, v16.16b\n" + "ldr s17, [x27], #0x4\n" + "ldr s16, [x26], #0x4\n" + "zip1 v20.16b, v20.16b, v17.16b\n" + "zip1 v22.16b, v22.16b, v16.16b\n" + "ldr s17, [x25], #0x4\n" + "ldr s16, [x24], #0x4\n" + "zip1 v19.16b, v19.16b, v17.16b\n" + "zip1 v18.16b, v18.16b, v16.16b\n" + "ldr s17, [x23], #0x4\n" + "ldr s16, [x22], #0x4\n" + "zip1 v17.16b, v25.16b, v17.16b\n" + "zip1 v16.16b, v24.16b, v16.16b\n" + "zip1 v21.16b, v21.16b, v19.16b\n" + "zip1 v20.16b, v20.16b, v17.16b\n" + "zip1 v19.16b, v23.16b, v18.16b\n" + "zip1 v16.16b, v22.16b, v16.16b\n" + "zip1 v18.16b, v21.16b, v20.16b\n" + "zip1 v17.16b, v19.16b, v16.16b\n" + "zip2 v20.16b, v21.16b, v20.16b\n" + "zip2 v19.16b, v19.16b, v16.16b\n" + "zip1 v16.16b, v18.16b, v17.16b\n" + "zip2 v18.16b, v18.16b, v17.16b\n" + "str q16, [x20, #0x0]\n" + "zip1 v17.16b, v20.16b, v19.16b\n" + "zip2 v16.16b, v20.16b, v19.16b\n" + "str q18, [x20, #0x10]\n" + "str q17, [x20, #0x20]\n" + "str q16, [x20, #0x30]\n" + "add x20, x20, %x[out_stride]\n" + "bge 4b\n" + "5:" // Main row loop: Column loop skip + "cmp x21, #0x1\n" + "blt 7f\n" + "6:" // Main row loop: width 1 loop: loop + "ldr b23, [x17], #0x1\n" + "ldr b22, [x16], #0x1\n" + "sub x21, x21, #0x1\n" + "cmp x21, #0x1\n" + "ldr b21, [x15], #0x1\n" + "ldr b20, [x14], #0x1\n" + "ldr b19, [x13], #0x1\n" + "ldr b18, [x12], #0x1\n" + "ldr b25, [x11], #0x1\n" + "ldr b24, [x10], #0x1\n" + "ldr b17, [x9], #0x1\n" + "ldr b16, [x28], #0x1\n" + "zip1 v23.16b, v23.16b, v17.16b\n" + "zip1 v22.16b, v22.16b, v16.16b\n" + "ldr b17, [x27], #0x1\n" + "ldr b16, [x26], #0x1\n" + "zip1 v21.16b, v21.16b, v17.16b\n" + "zip1 v20.16b, v20.16b, v16.16b\n" + "ldr b17, [x25], #0x1\n" + "ldr b16, [x24], #0x1\n" + "zip1 v19.16b, v19.16b, v17.16b\n" + "zip1 v18.16b, v18.16b, v16.16b\n" + "ldr b17, [x23], #0x1\n" + "ldr b16, [x22], #0x1\n" + "zip1 v17.16b, v25.16b, v17.16b\n" + "zip1 v16.16b, v24.16b, v16.16b\n" + "zip1 v19.16b, v23.16b, v19.16b\n" + "zip1 v17.16b, v21.16b, v17.16b\n" + "zip1 v18.16b, v22.16b, v18.16b\n" + "zip1 v16.16b, v20.16b, v16.16b\n" + "zip1 v17.16b, v19.16b, v17.16b\n" + "zip1 v16.16b, v18.16b, v16.16b\n" + "zip1 v16.16b, v17.16b, v16.16b\n" + "str q16, [x20, #0x0]\n" + "add x20, x20, #0x10\n" + "bge 6b\n" + "7:" // Main row loop: width 1 loop: skip + "cmp %x[height], #0x1\n" + "add %x[out], %x[out], #0x40\n" + "bge 1b\n" + : [height] "+&r" (height), [in] "+&r" (in), [out] "+&r" (out) + : [in_stride] "r" (in_stride), [out_stride] "r" (out_stride), [pad_row] "r" (pad_row), [width] "r" (width) + : "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "x9", "x10", "x11", "x12", "x13", "x14", "x15", "x16", "x17", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27", "x28" + ); +} + +} // anonymous namespace + +template<> +void Transform<4, 16, true, VLType::None>( + uint8_t *out, const uint8_t *in, int stride, int x0, int xmax, int k0, int kmax) +{ + a64_transpose_interleave_4_1x16( + reinterpret_cast<uint8_t *>(out), + reinterpret_cast<const uint8_t *>(in + k0 * stride + x0), + (xmax-x0) * sizeof(uint8_t) / 1, + stride * sizeof(uint8_t), + (kmax-k0) + ); +} + +template<> +void Transform<4, 16, true, VLType::None>( + int8_t *out, const int8_t *in, int stride, int x0, int xmax, int k0, int kmax) +{ + a64_transpose_interleave_4_1x16( + reinterpret_cast<uint8_t *>(out), + reinterpret_cast<const uint8_t *>(in + k0 * stride + x0), + (xmax-x0) * sizeof(int8_t) / 1, + stride * sizeof(int8_t), + (kmax-k0) + ); +} + + +#endif // defined(__aarch64__) diff --git a/src/core/NEON/kernels/arm_gemm/transforms/a64_transpose_interleave_4_1x4.hpp b/src/core/NEON/kernels/arm_gemm/transforms/a64_transpose_interleave_4_1x4.hpp new file mode 100644 index 0000000000..5667820865 --- /dev/null +++ b/src/core/NEON/kernels/arm_gemm/transforms/a64_transpose_interleave_4_1x4.hpp @@ -0,0 +1,337 @@ +/* + * Copyright (c) 2021, 2023 Arm Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ + +#pragma once + +#if defined(__aarch64__) + +namespace { + +void a64_transpose_interleave_4_1x4(uint8_t *out, const uint8_t *in, size_t width, size_t in_stride, size_t height) +{ + uint8_t *pad_row = reinterpret_cast<uint8_t *>(alloca(width * sizeof(uint8_t))); + + if (height % 4) { + memset(pad_row, 0, width * sizeof(uint8_t)); + } + + size_t out_stride = 4 * roundup<size_t>(height, 4) * sizeof(uint8_t); + + __asm__ __volatile__( + "cmp %x[height], #0x10\n" + "blt 8f\n" + "1:" // Main row loop: Head + "mov x17, %x[in]\n" + "add x16, x17, %x[in_stride]\n" + "add x15, x16, %x[in_stride]\n" + "add x14, x15, %x[in_stride]\n" + "add x13, x14, %x[in_stride]\n" + "add x12, x13, %x[in_stride]\n" + "add x11, x12, %x[in_stride]\n" + "add x10, x11, %x[in_stride]\n" + "add x9, x10, %x[in_stride]\n" + "add x28, x9, %x[in_stride]\n" + "add x27, x28, %x[in_stride]\n" + "add x26, x27, %x[in_stride]\n" + "add x25, x26, %x[in_stride]\n" + "mov x24, %x[width]\n" + "add x23, x25, %x[in_stride]\n" + "add x22, x23, %x[in_stride]\n" + "add x20, x22, %x[in_stride]\n" + "cmp x24, #0x10\n" + "add %x[in], x20, %x[in_stride]\n" + "mov x21, %x[out]\n" + "sub %x[height], %x[height], #0x10\n" + "blt 3f\n" + "2:" // Main row loop: Unroll column loop + "ldr q21, [x17], #0x10\n" + "ldr q20, [x16], #0x10\n" + "sub x24, x24, #0x10\n" + "cmp x24, #0x10\n" + "ldr q17, [x15], #0x10\n" + "ldr q16, [x14], #0x10\n" + "zip1 v3.16b, v21.16b, v17.16b\n" + "zip1 v2.16b, v20.16b, v16.16b\n" + "ldr q19, [x13], #0x10\n" + "ldr q18, [x12], #0x10\n" + "zip2 v1.16b, v21.16b, v17.16b\n" + "zip2 v0.16b, v20.16b, v16.16b\n" + "ldr q17, [x11], #0x10\n" + "ldr q16, [x10], #0x10\n" + "zip1 v31.16b, v19.16b, v17.16b\n" + "zip1 v30.16b, v18.16b, v16.16b\n" + "ldr q21, [x9], #0x10\n" + "ldr q20, [x28], #0x10\n" + "zip2 v29.16b, v19.16b, v17.16b\n" + "zip2 v28.16b, v18.16b, v16.16b\n" + "ldr q17, [x27], #0x10\n" + "ldr q16, [x26], #0x10\n" + "zip1 v23.16b, v21.16b, v17.16b\n" + "zip1 v22.16b, v20.16b, v16.16b\n" + "ldr q19, [x25], #0x10\n" + "ldr q18, [x23], #0x10\n" + "zip2 v27.16b, v21.16b, v17.16b\n" + "zip2 v26.16b, v20.16b, v16.16b\n" + "ldr q17, [x22], #0x10\n" + "ldr q16, [x20], #0x10\n" + "zip1 v21.16b, v19.16b, v17.16b\n" + "zip1 v20.16b, v18.16b, v16.16b\n" + "zip2 v25.16b, v19.16b, v17.16b\n" + "zip2 v24.16b, v18.16b, v16.16b\n" + "zip1 v16.16b, v3.16b, v2.16b\n" + "zip1 v18.16b, v31.16b, v30.16b\n" + "str q16, [x21, #0x0]\n" + "zip1 v17.16b, v23.16b, v22.16b\n" + "zip1 v16.16b, v21.16b, v20.16b\n" + "str q18, [x21, #0x10]\n" + "str q17, [x21, #0x20]\n" + "zip2 v19.16b, v3.16b, v2.16b\n" + "zip2 v18.16b, v31.16b, v30.16b\n" + "str q16, [x21, #0x30]\n" + "add x21, x21, %x[out_stride]\n" + "zip2 v17.16b, v23.16b, v22.16b\n" + "zip2 v16.16b, v21.16b, v20.16b\n" + "str q19, [x21, #0x0]\n" + "zip1 v23.16b, v1.16b, v0.16b\n" + "zip1 v22.16b, v29.16b, v28.16b\n" + "str q18, [x21, #0x10]\n" + "zip1 v21.16b, v27.16b, v26.16b\n" + "zip1 v20.16b, v25.16b, v24.16b\n" + "str q17, [x21, #0x20]\n" + "zip2 v19.16b, v1.16b, v0.16b\n" + "zip2 v18.16b, v29.16b, v28.16b\n" + "str q16, [x21, #0x30]\n" + "add x21, x21, %x[out_stride]\n" + "zip2 v17.16b, v27.16b, v26.16b\n" + "zip2 v16.16b, v25.16b, v24.16b\n" + "str q23, [x21, #0x0]\n" + "str q22, [x21, #0x10]\n" + "str q21, [x21, #0x20]\n" + "str q20, [x21, #0x30]\n" + "add x21, x21, %x[out_stride]\n" + "str q19, [x21, #0x0]\n" + "str q18, [x21, #0x10]\n" + "str q17, [x21, #0x20]\n" + "str q16, [x21, #0x30]\n" + "add x21, x21, %x[out_stride]\n" + "bge 2b\n" + "3:" // Main row loop: Unroll column loop skip + "cmp x24, #0x4\n" + "blt 5f\n" + "4:" // Main row loop: Column loop + "ldr s19, [x17], #0x4\n" + "ldr s18, [x16], #0x4\n" + "sub x24, x24, #0x4\n" + "cmp x24, #0x4\n" + "ldr s17, [x15], #0x4\n" + "ldr s16, [x14], #0x4\n" + "zip1 v17.16b, v19.16b, v17.16b\n" + "zip1 v16.16b, v18.16b, v16.16b\n" + "ldr s19, [x13], #0x4\n" + "ldr s18, [x12], #0x4\n" + "zip1 v22.16b, v17.16b, v16.16b\n" + "ldr s17, [x11], #0x4\n" + "ldr s16, [x10], #0x4\n" + "zip1 v17.16b, v19.16b, v17.16b\n" + "zip1 v16.16b, v18.16b, v16.16b\n" + "ldr s19, [x9], #0x4\n" + "ldr s18, [x28], #0x4\n" + "zip1 v21.16b, v17.16b, v16.16b\n" + "ldr s17, [x27], #0x4\n" + "ldr s16, [x26], #0x4\n" + "zip1 v17.16b, v19.16b, v17.16b\n" + "zip1 v16.16b, v18.16b, v16.16b\n" + "ldr s20, [x25], #0x4\n" + "ldr s19, [x23], #0x4\n" + "zip1 v18.16b, v17.16b, v16.16b\n" + "ldr s17, [x22], #0x4\n" + "ldr s16, [x20], #0x4\n" + "zip1 v17.16b, v20.16b, v17.16b\n" + "zip1 v16.16b, v19.16b, v16.16b\n" + "zip1 v16.16b, v17.16b, v16.16b\n" + "str q22, [x21, #0x0]\n" + "str q21, [x21, #0x10]\n" + "str q18, [x21, #0x20]\n" + "str q16, [x21, #0x30]\n" + "add x21, x21, %x[out_stride]\n" + "bge 4b\n" + "5:" // Main row loop: Column loop skip + "cmp x24, #0x1\n" + "blt 7f\n" + "6:" // Main row loop: width 1 loop: loop + "ldr b19, [x17], #0x1\n" + "ldr b18, [x16], #0x1\n" + "sub x24, x24, #0x1\n" + "cmp x24, #0x1\n" + "ldr b17, [x15], #0x1\n" + "ldr b16, [x14], #0x1\n" + "zip1 v17.16b, v19.16b, v17.16b\n" + "zip1 v16.16b, v18.16b, v16.16b\n" + "ldr b19, [x13], #0x1\n" + "ldr b18, [x12], #0x1\n" + "zip1 v22.16b, v17.16b, v16.16b\n" + "ldr b17, [x11], #0x1\n" + "ldr b16, [x10], #0x1\n" + "zip1 v17.16b, v19.16b, v17.16b\n" + "zip1 v16.16b, v18.16b, v16.16b\n" + "ldr b19, [x9], #0x1\n" + "ldr b18, [x28], #0x1\n" + "zip1 v21.16b, v17.16b, v16.16b\n" + "ldr b17, [x27], #0x1\n" + "ldr b16, [x26], #0x1\n" + "zip1 v17.16b, v19.16b, v17.16b\n" + "zip1 v16.16b, v18.16b, v16.16b\n" + "ldr b20, [x25], #0x1\n" + "ldr b19, [x23], #0x1\n" + "zip1 v18.16b, v17.16b, v16.16b\n" + "ldr b17, [x22], #0x1\n" + "ldr b16, [x20], #0x1\n" + "zip1 v17.16b, v20.16b, v17.16b\n" + "zip1 v16.16b, v19.16b, v16.16b\n" + "str s22, [x21, #0x0]\n" + "zip1 v16.16b, v17.16b, v16.16b\n" + "str s21, [x21, #0x10]\n" + "str s18, [x21, #0x20]\n" + "str s16, [x21, #0x30]\n" + "add x21, x21, #0x4\n" + "bge 6b\n" + "7:" // Main row loop: width 1 loop: skip + "cmp %x[height], #0x10\n" + "add %x[out], %x[out], #0x40\n" + "bge 1b\n" + "cbz %x[height], 16f\n" + "8:" // Main loop skip + "9:" // Tail row loop: Head + "mov x17, %x[in]\n" + "add x16, x17, %x[in_stride]\n" + "add x15, x16, %x[in_stride]\n" + "mov x20, %x[width]\n" + "add x14, x15, %x[in_stride]\n" + "cmp %x[height], #0x3\n" + "add %x[in], x14, %x[in_stride]\n" + "csel x14, x14, %x[pad_row], GT\n" + "csel x15, x15, %x[pad_row], GE\n" + "cmp %x[height], #0x1\n" + "csel x16, x16, %x[pad_row], GT\n" + "cmp x20, #0x10\n" + "mov x21, %x[out]\n" + "sub %x[height], %x[height], #0x4\n" + "blt 11f\n" + "10:" // Tail row loop: Unroll column loop + "ldr q19, [x17], #0x10\n" + "ldr q21, [x16], #0x10\n" + "sub x20, x20, #0x10\n" + "cmp x20, #0x10\n" + "ldr q18, [x15], #0x10\n" + "ldr q16, [x14], #0x10\n" + "zip1 v20.16b, v19.16b, v18.16b\n" + "zip1 v17.16b, v21.16b, v16.16b\n" + "zip2 v19.16b, v19.16b, v18.16b\n" + "zip2 v18.16b, v21.16b, v16.16b\n" + "zip1 v16.16b, v20.16b, v17.16b\n" + "str q16, [x21, #0x0]\n" + "add x21, x21, %x[out_stride]\n" + "zip2 v16.16b, v20.16b, v17.16b\n" + "str q16, [x21, #0x0]\n" + "add x21, x21, %x[out_stride]\n" + "zip1 v17.16b, v19.16b, v18.16b\n" + "zip2 v16.16b, v19.16b, v18.16b\n" + "str q17, [x21, #0x0]\n" + "add x21, x21, %x[out_stride]\n" + "str q16, [x21, #0x0]\n" + "add x21, x21, %x[out_stride]\n" + "bge 10b\n" + "11:" // Tail row loop: Unroll column loop skip + "cmp x20, #0x4\n" + "blt 13f\n" + "12:" // Tail row loop: Column loop + "ldr s19, [x17], #0x4\n" + "ldr s18, [x16], #0x4\n" + "sub x20, x20, #0x4\n" + "cmp x20, #0x4\n" + "ldr s17, [x15], #0x4\n" + "ldr s16, [x14], #0x4\n" + "zip1 v17.16b, v19.16b, v17.16b\n" + "zip1 v16.16b, v18.16b, v16.16b\n" + "zip1 v16.16b, v17.16b, v16.16b\n" + "str q16, [x21, #0x0]\n" + "add x21, x21, %x[out_stride]\n" + "bge 12b\n" + "13:" // Tail row loop: Column loop skip + "cmp x20, #0x1\n" + "blt 15f\n" + "14:" // Tail row loop: width 1 loop: loop + "ldr b19, [x17], #0x1\n" + "ldr b18, [x16], #0x1\n" + "sub x20, x20, #0x1\n" + "cmp x20, #0x1\n" + "ldr b17, [x15], #0x1\n" + "ldr b16, [x14], #0x1\n" + "zip1 v17.16b, v19.16b, v17.16b\n" + "zip1 v16.16b, v18.16b, v16.16b\n" + "zip1 v16.16b, v17.16b, v16.16b\n" + "str s16, [x21, #0x0]\n" + "add x21, x21, #0x4\n" + "bge 14b\n" + "15:" // Tail row loop: width 1 loop: skip + "cmp %x[height], #0x1\n" + "add %x[out], %x[out], #0x10\n" + "bge 9b\n" + "16:" // Done + : [height] "+&r" (height), [in] "+&r" (in), [out] "+&r" (out) + : [in_stride] "r" (in_stride), [out_stride] "r" (out_stride), [pad_row] "r" (pad_row), [width] "r" (width) + : "cc", "memory", "v0", "v1", "v2", "v3", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "x9", "x10", "x11", "x12", "x13", "x14", "x15", "x16", "x17", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27", "x28" + ); +} + +} // anonymous namespace + +template<> +void Transform<4, 4, true, VLType::None>( + uint8_t *out, const uint8_t *in, int stride, int x0, int xmax, int k0, int kmax) +{ + a64_transpose_interleave_4_1x4( + reinterpret_cast<uint8_t *>(out), + reinterpret_cast<const uint8_t *>(in + k0 * stride + x0), + (xmax-x0) * sizeof(uint8_t) / 1, + stride * sizeof(uint8_t), + (kmax-k0) + ); +} + +template<> +void Transform<4, 4, true, VLType::None>( + int8_t *out, const int8_t *in, int stride, int x0, int xmax, int k0, int kmax) +{ + a64_transpose_interleave_4_1x4( + reinterpret_cast<uint8_t *>(out), + reinterpret_cast<const uint8_t *>(in + k0 * stride + x0), + (xmax-x0) * sizeof(int8_t) / 1, + stride * sizeof(int8_t), + (kmax-k0) + ); +} + + +#endif // defined(__aarch64__) diff --git a/src/core/NEON/kernels/arm_gemm/transforms/a64_transpose_interleave_4_2x4_fp32bf16.hpp b/src/core/NEON/kernels/arm_gemm/transforms/a64_transpose_interleave_4_2x4_fp32bf16.hpp new file mode 100644 index 0000000000..98200c50c5 --- /dev/null +++ b/src/core/NEON/kernels/arm_gemm/transforms/a64_transpose_interleave_4_2x4_fp32bf16.hpp @@ -0,0 +1,346 @@ +/* + * Copyright (c) 2024 Arm Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ + +#pragma once + +#if defined(__aarch64__) + +namespace { + +void a64_transpose_interleave_4_2x4_fp32bf16(bfloat16 *out, const float *in, size_t width, size_t in_stride, size_t height) +{ + float *pad_row = reinterpret_cast<float *>(alloca(width * sizeof(float))); + + if (height % 4) { + memset(pad_row, 0, width * sizeof(float)); + } + + size_t out_stride = 4 * roundup<size_t>(height, 4) * sizeof(bfloat16); + + __asm__ __volatile__( + "cmp %x[height], #0x8\n" + "blt 8f\n" + "1:" // Main row loop: Head + "mov x9, %x[in]\n" + "mov x28, %x[width]\n" + "mov x27, %x[out]\n" + "sub %x[height], %x[height], #0x8\n" + "add x26, x9, %x[in_stride]\n" + "add x25, x26, %x[in_stride]\n" + "add x24, x25, %x[in_stride]\n" + "cmp x28, #0x8\n" + "add x23, x24, %x[in_stride]\n" + "add x22, x23, %x[in_stride]\n" + "add x21, x22, %x[in_stride]\n" + "add x20, x21, %x[in_stride]\n" + "add %x[in], x20, %x[in_stride]\n" + "blt 3f\n" + "2:" // Main row loop: Unroll column loop + "ldr q19, [x9], #0x10\n" + "ldr q18, [x26], #0x10\n" + "sub x28, x28, #0x8\n" + "ldr q17, [x25], #0x10\n" + "ldr q16, [x24], #0x10\n" + "cmp x28, #0x8\n" + "ldr q1, [x23], #0x10\n" + "ldr q0, [x22], #0x10\n" + "ldr q31, [x21], #0x10\n" + "ldr q24, [x20], #0x10\n" + "ldr q23, [x9], #0x10\n" + "ldr q22, [x26], #0x10\n" + "zip1 v30.4s, v19.4s, v17.4s\n" + "zip1 v29.4s, v18.4s, v16.4s\n" + "ldr q21, [x25], #0x10\n" + "ldr q20, [x24], #0x10\n" + "zip2 v28.4s, v19.4s, v17.4s\n" + "zip2 v27.4s, v18.4s, v16.4s\n" + "ldr q19, [x23], #0x10\n" + "ldr q18, [x22], #0x10\n" + "zip1 v26.4s, v1.4s, v31.4s\n" + "zip1 v25.4s, v0.4s, v24.4s\n" + "ldr q17, [x21], #0x10\n" + "ldr q16, [x20], #0x10\n" + "zip2 v8.4s, v1.4s, v31.4s\n" + "zip2 v24.4s, v0.4s, v24.4s\n" + "zip1 v7.4s, v23.4s, v21.4s\n" + "zip1 v6.4s, v22.4s, v20.4s\n" + "zip2 v5.4s, v23.4s, v21.4s\n" + "zip2 v4.4s, v22.4s, v20.4s\n" + "zip1 v3.4s, v19.4s, v17.4s\n" + "zip1 v2.4s, v18.4s, v16.4s\n" + "zip2 v1.4s, v19.4s, v17.4s\n" + "zip2 v0.4s, v18.4s, v16.4s\n" + "zip1 v23.4s, v30.4s, v29.4s\n" + "zip1 v22.4s, v28.4s, v27.4s\n" + "zip1 v21.4s, v26.4s, v25.4s\n" + "zip1 v20.4s, v8.4s, v24.4s\n" + "zip1 v19.4s, v7.4s, v6.4s\n" + "zip1 v18.4s, v5.4s, v4.4s\n" + "zip1 v17.4s, v3.4s, v2.4s\n" + "zip1 v16.4s, v1.4s, v0.4s\n" + ".inst 0x0ea16aff // bfcvtn v31.4h, v23.4s\n" + "zip2 v30.4s, v30.4s, v29.4s\n" + ".inst 0x0ea16add // bfcvtn v29.4h, v22.4s\n" + "zip2 v28.4s, v28.4s, v27.4s\n" + ".inst 0x0ea16abb // bfcvtn v27.4h, v21.4s\n" + "zip2 v26.4s, v26.4s, v25.4s\n" + ".inst 0x0ea16a99 // bfcvtn v25.4h, v20.4s\n" + "zip2 v24.4s, v8.4s, v24.4s\n" + ".inst 0x0ea16a77 // bfcvtn v23.4h, v19.4s\n" + "zip2 v22.4s, v7.4s, v6.4s\n" + ".inst 0x0ea16a55 // bfcvtn v21.4h, v18.4s\n" + "zip2 v20.4s, v5.4s, v4.4s\n" + ".inst 0x0ea16a33 // bfcvtn v19.4h, v17.4s\n" + "zip2 v18.4s, v3.4s, v2.4s\n" + ".inst 0x0ea16a11 // bfcvtn v17.4h, v16.4s\n" + "zip2 v16.4s, v1.4s, v0.4s\n" + ".inst 0x4ea16bdf // bfcvtn2 v31.8h, v30.4s\n" + ".inst 0x4ea16b9d // bfcvtn2 v29.8h, v28.4s\n" + ".inst 0x4ea16b5b // bfcvtn2 v27.8h, v26.4s\n" + ".inst 0x4ea16b19 // bfcvtn2 v25.8h, v24.4s\n" + ".inst 0x4ea16ad7 // bfcvtn2 v23.8h, v22.4s\n" + ".inst 0x4ea16a95 // bfcvtn2 v21.8h, v20.4s\n" + "str q31, [x27, #0x0]\n" + "str q29, [x27, #0x10]\n" + ".inst 0x4ea16a53 // bfcvtn2 v19.8h, v18.4s\n" + ".inst 0x4ea16a11 // bfcvtn2 v17.8h, v16.4s\n" + "str q27, [x27, #0x20]\n" + "str q25, [x27, #0x30]\n" + "add x27, x27, %x[out_stride]\n" + "str q23, [x27, #0x0]\n" + "str q21, [x27, #0x10]\n" + "str q19, [x27, #0x20]\n" + "str q17, [x27, #0x30]\n" + "add x27, x27, %x[out_stride]\n" + "bge 2b\n" + "3:" // Main row loop: Unroll column loop skip + "cmp x28, #0x4\n" + "blt 5f\n" + "4:" // Main row loop: Column loop + "ldr q25, [x9], #0x10\n" + "ldr q24, [x26], #0x10\n" + "sub x28, x28, #0x4\n" + "ldr q21, [x25], #0x10\n" + "ldr q20, [x24], #0x10\n" + "cmp x28, #0x4\n" + "ldr q23, [x23], #0x10\n" + "ldr q19, [x22], #0x10\n" + "ldr q18, [x21], #0x10\n" + "ldr q17, [x20], #0x10\n" + "zip1 v22.4s, v25.4s, v21.4s\n" + "zip1 v16.4s, v24.4s, v20.4s\n" + "zip2 v21.4s, v25.4s, v21.4s\n" + "zip2 v20.4s, v24.4s, v20.4s\n" + "zip1 v27.4s, v23.4s, v18.4s\n" + "zip1 v26.4s, v19.4s, v17.4s\n" + "zip2 v25.4s, v23.4s, v18.4s\n" + "zip2 v24.4s, v19.4s, v17.4s\n" + "zip1 v19.4s, v22.4s, v16.4s\n" + "zip1 v18.4s, v21.4s, v20.4s\n" + "zip1 v17.4s, v27.4s, v26.4s\n" + "zip2 v23.4s, v22.4s, v16.4s\n" + "zip1 v16.4s, v25.4s, v24.4s\n" + "zip2 v22.4s, v21.4s, v20.4s\n" + ".inst 0x0ea16a75 // bfcvtn v21.4h, v19.4s\n" + ".inst 0x0ea16a54 // bfcvtn v20.4h, v18.4s\n" + ".inst 0x0ea16a33 // bfcvtn v19.4h, v17.4s\n" + "zip2 v18.4s, v27.4s, v26.4s\n" + ".inst 0x0ea16a11 // bfcvtn v17.4h, v16.4s\n" + "zip2 v16.4s, v25.4s, v24.4s\n" + ".inst 0x4ea16af5 // bfcvtn2 v21.8h, v23.4s\n" + ".inst 0x4ea16ad4 // bfcvtn2 v20.8h, v22.4s\n" + ".inst 0x4ea16a53 // bfcvtn2 v19.8h, v18.4s\n" + ".inst 0x4ea16a11 // bfcvtn2 v17.8h, v16.4s\n" + "str q21, [x27, #0x0]\n" + "str q20, [x27, #0x10]\n" + "str q19, [x27, #0x20]\n" + "str q17, [x27, #0x30]\n" + "add x27, x27, %x[out_stride]\n" + "bge 4b\n" + "5:" // Main row loop: Column loop skip + "cbz x28, 7f\n" + "movi v16.16b, #0x0\n" + "str q16, [x27, #0x0]\n" + "str q16, [x27, #0x10]\n" + "str q16, [x27, #0x20]\n" + "str q16, [x27, #0x30]\n" + "6:" // Main row loop: width 1 loop: loop + "ldr s23, [x9], #0x4\n" + "ldr s22, [x26], #0x4\n" + "sub x28, x28, #0x1\n" + "ldr s19, [x25], #0x4\n" + "ldr s17, [x24], #0x4\n" + "cmp x28, #0x1\n" + "ldr s21, [x23], #0x4\n" + "ldr s20, [x22], #0x4\n" + "ldr s18, [x21], #0x4\n" + "ldr s16, [x20], #0x4\n" + "zip1 v19.4s, v23.4s, v19.4s\n" + "zip1 v17.4s, v22.4s, v17.4s\n" + "zip1 v18.4s, v21.4s, v18.4s\n" + "zip1 v16.4s, v20.4s, v16.4s\n" + "zip1 v17.4s, v19.4s, v17.4s\n" + "zip1 v16.4s, v18.4s, v16.4s\n" + ".inst 0x0ea16a31 // bfcvtn v17.4h, v17.4s\n" + ".inst 0x0ea16a10 // bfcvtn v16.4h, v16.4s\n" + "str d17, [x27, #0x0]\n" + "str d16, [x27, #0x20]\n" + "add x27, x27, #0x8\n" + "bge 6b\n" + "7:" // Main row loop: odd col skip + "cmp %x[height], #0x8\n" + "add %x[out], %x[out], #0x40\n" + "bge 1b\n" + "cbz %x[height], 16f\n" + "8:" // Main loop skip + "9:" // Tail row loop: Head + "mov x9, %x[in]\n" + "mov x20, %x[width]\n" + "cmp %x[height], #0x3\n" + "mov x27, %x[out]\n" + "add x26, x9, %x[in_stride]\n" + "add x25, x26, %x[in_stride]\n" + "add x24, x25, %x[in_stride]\n" + "csel x25, x25, %x[pad_row], GE\n" + "add %x[in], x24, %x[in_stride]\n" + "csel x24, x24, %x[pad_row], GT\n" + "cmp %x[height], #0x1\n" + "sub %x[height], %x[height], #0x4\n" + "csel x26, x26, %x[pad_row], GT\n" + "cmp x20, #0x8\n" + "blt 11f\n" + "10:" // Tail row loop: Unroll column loop + "ldr q25, [x9], #0x10\n" + "ldr q24, [x26], #0x10\n" + "sub x20, x20, #0x8\n" + "ldr q21, [x25], #0x10\n" + "ldr q20, [x24], #0x10\n" + "cmp x20, #0x8\n" + "ldr q23, [x9], #0x10\n" + "ldr q19, [x26], #0x10\n" + "ldr q18, [x25], #0x10\n" + "ldr q17, [x24], #0x10\n" + "zip1 v22.4s, v25.4s, v21.4s\n" + "zip1 v16.4s, v24.4s, v20.4s\n" + "zip2 v21.4s, v25.4s, v21.4s\n" + "zip2 v20.4s, v24.4s, v20.4s\n" + "zip1 v27.4s, v23.4s, v18.4s\n" + "zip1 v26.4s, v19.4s, v17.4s\n" + "zip2 v25.4s, v23.4s, v18.4s\n" + "zip2 v24.4s, v19.4s, v17.4s\n" + "zip1 v19.4s, v22.4s, v16.4s\n" + "zip1 v18.4s, v21.4s, v20.4s\n" + "zip1 v17.4s, v27.4s, v26.4s\n" + "zip2 v23.4s, v22.4s, v16.4s\n" + "zip1 v16.4s, v25.4s, v24.4s\n" + "zip2 v22.4s, v21.4s, v20.4s\n" + ".inst 0x0ea16a75 // bfcvtn v21.4h, v19.4s\n" + ".inst 0x0ea16a54 // bfcvtn v20.4h, v18.4s\n" + ".inst 0x0ea16a33 // bfcvtn v19.4h, v17.4s\n" + "zip2 v18.4s, v27.4s, v26.4s\n" + ".inst 0x0ea16a11 // bfcvtn v17.4h, v16.4s\n" + "zip2 v16.4s, v25.4s, v24.4s\n" + ".inst 0x4ea16af5 // bfcvtn2 v21.8h, v23.4s\n" + ".inst 0x4ea16ad4 // bfcvtn2 v20.8h, v22.4s\n" + ".inst 0x4ea16a53 // bfcvtn2 v19.8h, v18.4s\n" + ".inst 0x4ea16a11 // bfcvtn2 v17.8h, v16.4s\n" + "str q21, [x27, #0x0]\n" + "str q20, [x27, #0x10]\n" + "add x27, x27, %x[out_stride]\n" + "str q19, [x27, #0x0]\n" + "str q17, [x27, #0x10]\n" + "add x27, x27, %x[out_stride]\n" + "bge 10b\n" + "11:" // Tail row loop: Unroll column loop skip + "cmp x20, #0x4\n" + "blt 13f\n" + "12:" // Tail row loop: Column loop + "ldr q21, [x9], #0x10\n" + "ldr q20, [x26], #0x10\n" + "sub x20, x20, #0x4\n" + "ldr q19, [x25], #0x10\n" + "ldr q17, [x24], #0x10\n" + "cmp x20, #0x4\n" + "zip1 v18.4s, v21.4s, v19.4s\n" + "zip1 v16.4s, v20.4s, v17.4s\n" + "zip2 v21.4s, v21.4s, v19.4s\n" + "zip2 v20.4s, v20.4s, v17.4s\n" + "zip1 v17.4s, v18.4s, v16.4s\n" + "zip2 v19.4s, v18.4s, v16.4s\n" + "zip1 v16.4s, v21.4s, v20.4s\n" + ".inst 0x0ea16a32 // bfcvtn v18.4h, v17.4s\n" + "zip2 v17.4s, v21.4s, v20.4s\n" + ".inst 0x0ea16a10 // bfcvtn v16.4h, v16.4s\n" + ".inst 0x4ea16a72 // bfcvtn2 v18.8h, v19.4s\n" + ".inst 0x4ea16a30 // bfcvtn2 v16.8h, v17.4s\n" + "str q18, [x27, #0x0]\n" + "str q16, [x27, #0x10]\n" + "add x27, x27, %x[out_stride]\n" + "bge 12b\n" + "13:" // Tail row loop: Column loop skip + "cbz x20, 15f\n" + "movi v16.16b, #0x0\n" + "str q16, [x27, #0x0]\n" + "str q16, [x27, #0x10]\n" + "14:" // Tail row loop: width 1 loop: loop + "ldr s19, [x9], #0x4\n" + "ldr s18, [x26], #0x4\n" + "sub x20, x20, #0x1\n" + "ldr s17, [x25], #0x4\n" + "ldr s16, [x24], #0x4\n" + "cmp x20, #0x1\n" + "zip1 v17.4s, v19.4s, v17.4s\n" + "zip1 v16.4s, v18.4s, v16.4s\n" + "zip1 v16.4s, v17.4s, v16.4s\n" + ".inst 0x0ea16a10 // bfcvtn v16.4h, v16.4s\n" + "str d16, [x27, #0x0]\n" + "add x27, x27, #0x8\n" + "bge 14b\n" + "15:" // Tail row loop: odd col skip + "cmp %x[height], #0x1\n" + "add %x[out], %x[out], #0x20\n" + "bge 9b\n" + "16:" // Done + : [height] "+&r" (height), [in] "+&r" (in), [out] "+&r" (out) + : [in_stride] "r" (in_stride), [out_stride] "r" (out_stride), [pad_row] "r" (pad_row), [width] "r" (width) + : "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "x9", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27", "x28" + ); +} + +} // anonymous namespace +template<> +void Transform<4, 4, true, VLType::None>( + bfloat16 *out, const float *in, int stride, int x0, int xmax, int k0, int kmax) +{ + a64_transpose_interleave_4_2x4_fp32bf16( + out, + in + k0 * stride + x0, + (xmax-x0), + stride * sizeof(float), + (kmax-k0) + ); +} + + +#endif // defined(__aarch64__) diff --git a/src/core/NEON/kernels/arm_gemm/transforms/a64_transpose_interleave_64.hpp b/src/core/NEON/kernels/arm_gemm/transforms/a64_transpose_interleave_64.hpp new file mode 100644 index 0000000000..328274a488 --- /dev/null +++ b/src/core/NEON/kernels/arm_gemm/transforms/a64_transpose_interleave_64.hpp @@ -0,0 +1,254 @@ +/* + * Copyright (c) 2021, 2023 Arm Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ + +#pragma once + +#if defined(__aarch64__) + +namespace { + +void a64_transpose_interleave_64(uint16_t *out, const uint16_t *in, size_t width, size_t in_stride, size_t height) +{ + size_t out_stride = 32 * height * sizeof(uint16_t); + + __asm__ __volatile__( + "cmp %x[height], #0x4\n" + "blt 10f\n" + "1:" // Main row loop: Head + "mov x25, %x[in]\n" + "mov x24, %x[width]\n" + "add x23, x25, %x[in_stride]\n" + "add x22, x23, %x[in_stride]\n" + "add x20, x22, %x[in_stride]\n" + "cmp x24, #0x20\n" + "add %x[in], x20, %x[in_stride]\n" + "mov x21, %x[out]\n" + "sub %x[height], %x[height], #0x4\n" + "blt 3f\n" + "2:" // Main row loop: Column loop + "ldr q31, [x25], #0x10\n" + "ldr q30, [x23], #0x10\n" + "sub x24, x24, #0x20\n" + "cmp x24, #0x20\n" + "ldr q29, [x22], #0x10\n" + "ldr q28, [x20], #0x10\n" + "ldr q27, [x25], #0x10\n" + "ldr q26, [x23], #0x10\n" + "ldr q25, [x22], #0x10\n" + "ldr q24, [x20], #0x10\n" + "ldr q23, [x25], #0x10\n" + "ldr q22, [x23], #0x10\n" + "ldr q21, [x22], #0x10\n" + "ldr q20, [x20], #0x10\n" + "ldr q19, [x25], #0x10\n" + "ldr q18, [x23], #0x10\n" + "ldr q17, [x22], #0x10\n" + "ldr q16, [x20], #0x10\n" + "str q31, [x21, #0x0]\n" + "str q27, [x21, #0x10]\n" + "str q23, [x21, #0x20]\n" + "str q19, [x21, #0x30]\n" + "str q30, [x21, #0x40]\n" + "str q26, [x21, #0x50]\n" + "str q22, [x21, #0x60]\n" + "str q18, [x21, #0x70]\n" + "str q29, [x21, #0x80]\n" + "str q25, [x21, #0x90]\n" + "str q21, [x21, #0xa0]\n" + "str q17, [x21, #0xb0]\n" + "str q28, [x21, #0xc0]\n" + "str q24, [x21, #0xd0]\n" + "str q20, [x21, #0xe0]\n" + "str q16, [x21, #0xf0]\n" + "add x21, x21, %x[out_stride]\n" + "bge 2b\n" + "3:" // Main row loop: Column loop skip + "cmp x24, #0x10\n" + "blt 5f\n" + "4:" // Main row loop: width 16 loop: loop + "ldr q23, [x25], #0x10\n" + "ldr q22, [x23], #0x10\n" + "sub x24, x24, #0x10\n" + "cmp x24, #0x10\n" + "ldr q21, [x22], #0x10\n" + "ldr q20, [x20], #0x10\n" + "ldr q19, [x25], #0x10\n" + "ldr q18, [x23], #0x10\n" + "ldr q17, [x22], #0x10\n" + "ldr q16, [x20], #0x10\n" + "str q23, [x21, #0x0]\n" + "str q19, [x21, #0x10]\n" + "str q22, [x21, #0x40]\n" + "str q18, [x21, #0x50]\n" + "str q21, [x21, #0x80]\n" + "str q17, [x21, #0x90]\n" + "str q20, [x21, #0xc0]\n" + "str q16, [x21, #0xd0]\n" + "add x21, x21, #0x20\n" + "bge 4b\n" + "5:" // Main row loop: width 16 loop: skip + "cmp x24, #0x4\n" + "blt 7f\n" + "6:" // Main row loop: width 4 loop: loop + "ldr d19, [x25], #0x8\n" + "ldr d18, [x23], #0x8\n" + "sub x24, x24, #0x4\n" + "cmp x24, #0x4\n" + "ldr d17, [x22], #0x8\n" + "ldr d16, [x20], #0x8\n" + "str d19, [x21, #0x0]\n" + "str d18, [x21, #0x40]\n" + "str d17, [x21, #0x80]\n" + "str d16, [x21, #0xc0]\n" + "add x21, x21, #0x8\n" + "bge 6b\n" + "7:" // Main row loop: width 4 loop: skip + "cmp x24, #0x1\n" + "blt 9f\n" + "8:" // Main row loop: width 1 loop: loop + "ldr h19, [x25], #0x2\n" + "ldr h18, [x23], #0x2\n" + "sub x24, x24, #0x1\n" + "cmp x24, #0x1\n" + "ldr h17, [x22], #0x2\n" + "ldr h16, [x20], #0x2\n" + "str h19, [x21, #0x0]\n" + "str h18, [x21, #0x40]\n" + "str h17, [x21, #0x80]\n" + "str h16, [x21, #0xc0]\n" + "add x21, x21, #0x2\n" + "bge 8b\n" + "9:" // Main row loop: width 1 loop: skip + "cmp %x[height], #0x4\n" + "add %x[out], %x[out], #0x100\n" + "bge 1b\n" + "cbz %x[height], 20f\n" + "10:" // Main loop skip + "11:" // Tail row loop: Head + "mov x20, %x[width]\n" + "mov x25, %x[in]\n" + "cmp x20, #0x20\n" + "add %x[in], x25, %x[in_stride]\n" + "mov x21, %x[out]\n" + "sub %x[height], %x[height], #0x1\n" + "blt 13f\n" + "12:" // Tail row loop: Column loop + "ldr q19, [x25], #0x10\n" + "ldr q18, [x25], #0x10\n" + "sub x20, x20, #0x20\n" + "cmp x20, #0x20\n" + "ldr q17, [x25], #0x10\n" + "ldr q16, [x25], #0x10\n" + "str q19, [x21, #0x0]\n" + "str q18, [x21, #0x10]\n" + "str q17, [x21, #0x20]\n" + "str q16, [x21, #0x30]\n" + "add x21, x21, %x[out_stride]\n" + "bge 12b\n" + "13:" // Tail row loop: Column loop skip + "cmp x20, #0x10\n" + "blt 15f\n" + "14:" // Tail row loop: width 16 loop: loop + "ldr q17, [x25], #0x10\n" + "ldr q16, [x25], #0x10\n" + "sub x20, x20, #0x10\n" + "cmp x20, #0x10\n" + "str q17, [x21, #0x0]\n" + "str q16, [x21, #0x10]\n" + "add x21, x21, #0x20\n" + "bge 14b\n" + "15:" // Tail row loop: width 16 loop: skip + "cmp x20, #0x4\n" + "blt 17f\n" + "16:" // Tail row loop: width 4 loop: loop + "ldr d16, [x25], #0x8\n" + "sub x20, x20, #0x4\n" + "cmp x20, #0x4\n" + "str d16, [x21, #0x0]\n" + "add x21, x21, #0x8\n" + "bge 16b\n" + "17:" // Tail row loop: width 4 loop: skip + "cmp x20, #0x1\n" + "blt 19f\n" + "18:" // Tail row loop: width 1 loop: loop + "ldr h16, [x25], #0x2\n" + "sub x20, x20, #0x1\n" + "cmp x20, #0x1\n" + "str h16, [x21, #0x0]\n" + "add x21, x21, #0x2\n" + "bge 18b\n" + "19:" // Tail row loop: width 1 loop: skip + "cmp %x[height], #0x1\n" + "add %x[out], %x[out], #0x40\n" + "bge 11b\n" + "20:" // Done + : [height] "+&r" (height), [in] "+&r" (in), [out] "+&r" (out) + : [in_stride] "r" (in_stride), [out_stride] "r" (out_stride), [width] "r" (width) + : "cc", "memory", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "x20", "x21", "x22", "x23", "x24", "x25" + ); +} + +} // anonymous namespace + +template<> +void Transform<16, 1, true, VLType::None>( + float *out, const float *in, int stride, int x0, int xmax, int k0, int kmax) +{ + a64_transpose_interleave_64( + reinterpret_cast<uint16_t *>(out), + reinterpret_cast<const uint16_t *>(in + k0 * stride + x0), + (xmax-x0) * sizeof(float) / 2, + stride * sizeof(float), + (kmax-k0) + ); +} + +template<> +void Transform<32, 1, true, VLType::None>( + __fp16 *out, const __fp16 *in, int stride, int x0, int xmax, int k0, int kmax) +{ + a64_transpose_interleave_64( + reinterpret_cast<uint16_t *>(out), + reinterpret_cast<const uint16_t *>(in + k0 * stride + x0), + (xmax-x0) * sizeof(__fp16) / 2, + stride * sizeof(__fp16), + (kmax-k0) + ); +} + +template<> +void Transform<32, 1, true, VLType::None>( + uint16_t *out, const uint16_t *in, int stride, int x0, int xmax, int k0, int kmax) +{ + a64_transpose_interleave_64( + reinterpret_cast<uint16_t *>(out), + reinterpret_cast<const uint16_t *>(in + k0 * stride + x0), + (xmax-x0) * sizeof(uint16_t) / 2, + stride * sizeof(uint16_t), + (kmax-k0) + ); +} + + +#endif // defined(__aarch64__) diff --git a/src/core/NEON/kernels/arm_gemm/transforms/a64_transpose_interleave_8way_32bit.hpp b/src/core/NEON/kernels/arm_gemm/transforms/a64_transpose_interleave_8way_32bit.hpp deleted file mode 100644 index 0080c91b18..0000000000 --- a/src/core/NEON/kernels/arm_gemm/transforms/a64_transpose_interleave_8way_32bit.hpp +++ /dev/null @@ -1,147 +0,0 @@ -/* - * Copyright (c) 2017-2019 ARM Limited. - * - * SPDX-License-Identifier: MIT - * - * Permission is hereby granted, free of charge, to any person obtaining a copy - * of this software and associated documentation files (the "Software"), to - * deal in the Software without restriction, including without limitation the - * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or - * sell copies of the Software, and to permit persons to whom the Software is - * furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in all - * copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE - * SOFTWARE. - */ -#pragma once - -#ifdef __aarch64__ - -#include "transpose_interleave_common.hpp" - -// Generic unblocked transposed 8x32-bit sized specialisation -template <> -template <typename T> -inline void TransformImpl<8, 1, true, 4, 4, false>::Transform( - T* out, const T* const in, const int stride, - const int x0, const int xmax, const int k0, const int kmax -) { - // Redirect to a 16 x uint16_t specialisation - TransformImpl<16, 1, true, 2, 2, false>::Transform( - reinterpret_cast<uint16_t *>(out), - reinterpret_cast<const uint16_t *>(in), - stride*2, x0*2, xmax*2, k0, kmax - ); -} - -// Generic 16x16-bit sized specialisation -template <> -template <typename T> -inline void TransformImpl<16, 1, true, 2, 2, false>::Transform( - T* out, const T* const in, const int stride, - const int x0, const int xmax, const int k0, const int kmax -) { - // Redirect to a uint16_t specialisation - Transform( - reinterpret_cast<uint16_t *>(out), - reinterpret_cast<const uint16_t *>(in), - stride, x0, xmax, k0, kmax - ); -} - -// Specialised 16 x uint16_t version -template <> -inline void TransposeInterleaveCommon<16, uint16_t, uint16_t>::moveblock_1x1(const uint16_t *&in0, uint16_t *const out) { - __asm volatile ( - "LDR q0, [%[in0]]\n" - "STR q0, [%[out]]\n" - "LDR q1, [%[in0], #0x10]\n" - "STR q1, [%[out], #0x10]\n" - "ADD %x[in0], %x[in0], #0x20\n" - ASM_PREFETCH("[%[in0], #192]") - : [in0] "+r" (in0) - : [out] "r" (out) - : "v0", "v1", "memory" - ); -} - -template <> -inline void TransposeInterleaveCommon<16, uint16_t, uint16_t>::moveblock_1x2(const uint16_t *&in0, const uint16_t *&in1, uint16_t *const out) { - __asm volatile ( - "LDR q0, [%[in0]]\n" - "STR q0, [%[out]]\n" - "LDR q1, [%[in0], #0x10]\n" - "STR q1, [%[out], #0x10]\n" - "ADD %x[in0], %x[in0], #0x20\n" - ASM_PREFETCH("[%[in0], #192]") - - "LDR q2, [%[in1]]\n" - "STR q2, [%[out], #0x20]\n" - "LDR q3, [%[in1], #0x10]\n" - "STR q3, [%[out], #0x30]\n" - "ADD %x[in1], %x[in1], #0x20\n" - ASM_PREFETCH("[%[in1], #192]") - : [in0] "+r" (in0), - [in1] "+r" (in1) - : [out] "r" (out) - : "v0", "v1", "v2", "v3", "memory" - ); -} - -template <> -inline void TransposeInterleaveCommon<16, uint16_t, uint16_t>::moveblock_1x4(const uint16_t *&in0, const uint16_t *&in1, const uint16_t *&in2, const uint16_t *&in3, uint16_t *const out) { - __asm __volatile ( - "LDR q0, [%[in0]]\n" - "STR q0, [%[out]]\n" - "LDR q1, [%[in0], #0x10]\n" - "STR q1, [%[out], #0x10]\n" - "ADD %x[in0], %x[in0], #0x20\n" - ASM_PREFETCH("[%[in0], #192]") - - "LDR q2, [%[in1]]\n" - "STR q2, [%[out], #0x20]\n" - "LDR q3, [%[in1], #0x10]\n" - "STR q3, [%[out], #0x30]\n" - "ADD %x[in1], %x[in1], #0x20\n" - ASM_PREFETCH("[%[in1], #192]") - - "LDR q0, [%[in2]]\n" - "STR q0, [%[out], #0x40]\n" - "LDR q1, [%[in2], #0x10]\n" - "STR q1, [%[out], #0x50]\n" - "ADD %x[in2], %x[in2], #0x20\n" - ASM_PREFETCH("[%[in2], #192]") - - "LDR q2, [%[in3]]\n" - "STR q2, [%[out], #0x60]\n" - "LDR q3, [%[in3], #0x10]\n" - "STR q3, [%[out], #0x70]\n" - "ADD %x[in3], %x[in3], #0x20\n" - ASM_PREFETCH("[%[in3], #192]") - : [in0] "+r" (in0), - [in1] "+r" (in1), - [in2] "+r" (in2), - [in3] "+r" (in3) - : [out] "r" (out) - : "v0", "v1", "v2", "v3", "memory" - ); -} - -template <> -template <> -inline void TransformImpl<16, 1, true, 2, 2, false>::Transform( - uint16_t* out, const uint16_t* const in, const int stride, - const int x0, const int xmax, const int k0, const int kmax -) { - TransposeInterleaveCommon<16, uint16_t, uint16_t>::Transform(out, in, stride, x0, xmax, k0, kmax); -} - -#endif // __aarch64__ diff --git a/src/core/NEON/kernels/arm_gemm/transforms/a64_transpose_interleave_96.hpp b/src/core/NEON/kernels/arm_gemm/transforms/a64_transpose_interleave_96.hpp new file mode 100644 index 0000000000..feb469ab0e --- /dev/null +++ b/src/core/NEON/kernels/arm_gemm/transforms/a64_transpose_interleave_96.hpp @@ -0,0 +1,268 @@ +/* + * Copyright (c) 2021, 2023 Arm Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ + +#pragma once + +#if defined(__aarch64__) + +namespace { + +void a64_transpose_interleave_96(uint32_t *out, const uint32_t *in, size_t width, size_t in_stride, size_t height) +{ + size_t out_stride = 24 * height * sizeof(uint32_t); + + __asm__ __volatile__( + "cmp %x[height], #0x4\n" + "blt 10f\n" + "1:" // Main row loop: Head + "mov x25, %x[in]\n" + "mov x24, %x[width]\n" + "add x23, x25, %x[in_stride]\n" + "add x22, x23, %x[in_stride]\n" + "add x20, x22, %x[in_stride]\n" + "cmp x24, #0x18\n" + "add %x[in], x20, %x[in_stride]\n" + "mov x21, %x[out]\n" + "sub %x[height], %x[height], #0x4\n" + "blt 3f\n" + "2:" // Main row loop: Column loop + "ldr q7, [x25], #0x10\n" + "ldr q6, [x23], #0x10\n" + "sub x24, x24, #0x18\n" + "cmp x24, #0x18\n" + "ldr q5, [x22], #0x10\n" + "ldr q4, [x20], #0x10\n" + "ldr q3, [x25], #0x10\n" + "ldr q2, [x23], #0x10\n" + "ldr q1, [x22], #0x10\n" + "ldr q0, [x20], #0x10\n" + "ldr q31, [x25], #0x10\n" + "ldr q30, [x23], #0x10\n" + "ldr q29, [x22], #0x10\n" + "ldr q28, [x20], #0x10\n" + "ldr q27, [x25], #0x10\n" + "ldr q26, [x23], #0x10\n" + "ldr q25, [x22], #0x10\n" + "ldr q24, [x20], #0x10\n" + "ldr q23, [x25], #0x10\n" + "ldr q22, [x23], #0x10\n" + "ldr q21, [x22], #0x10\n" + "ldr q20, [x20], #0x10\n" + "ldr q19, [x25], #0x10\n" + "ldr q18, [x23], #0x10\n" + "ldr q17, [x22], #0x10\n" + "ldr q16, [x20], #0x10\n" + "str q7, [x21, #0x0]\n" + "str q3, [x21, #0x10]\n" + "str q31, [x21, #0x20]\n" + "str q27, [x21, #0x30]\n" + "str q23, [x21, #0x40]\n" + "str q19, [x21, #0x50]\n" + "str q6, [x21, #0x60]\n" + "str q2, [x21, #0x70]\n" + "str q30, [x21, #0x80]\n" + "str q26, [x21, #0x90]\n" + "str q22, [x21, #0xa0]\n" + "str q18, [x21, #0xb0]\n" + "str q5, [x21, #0xc0]\n" + "str q1, [x21, #0xd0]\n" + "str q29, [x21, #0xe0]\n" + "str q25, [x21, #0xf0]\n" + "str q21, [x21, #0x100]\n" + "str q17, [x21, #0x110]\n" + "str q4, [x21, #0x120]\n" + "str q0, [x21, #0x130]\n" + "str q28, [x21, #0x140]\n" + "str q24, [x21, #0x150]\n" + "str q20, [x21, #0x160]\n" + "str q16, [x21, #0x170]\n" + "add x21, x21, %x[out_stride]\n" + "bge 2b\n" + "3:" // Main row loop: Column loop skip + "cmp x24, #0x10\n" + "blt 5f\n" + "4:" // Main row loop: width 16 loop: loop + "ldr q31, [x25], #0x10\n" + "ldr q30, [x23], #0x10\n" + "sub x24, x24, #0x10\n" + "cmp x24, #0x10\n" + "ldr q29, [x22], #0x10\n" + "ldr q28, [x20], #0x10\n" + "ldr q27, [x25], #0x10\n" + "ldr q26, [x23], #0x10\n" + "ldr q25, [x22], #0x10\n" + "ldr q24, [x20], #0x10\n" + "ldr q23, [x25], #0x10\n" + "ldr q22, [x23], #0x10\n" + "ldr q21, [x22], #0x10\n" + "ldr q20, [x20], #0x10\n" + "ldr q19, [x25], #0x10\n" + "ldr q18, [x23], #0x10\n" + "ldr q17, [x22], #0x10\n" + "ldr q16, [x20], #0x10\n" + "str q31, [x21, #0x0]\n" + "str q27, [x21, #0x10]\n" + "str q23, [x21, #0x20]\n" + "str q19, [x21, #0x30]\n" + "str q30, [x21, #0x60]\n" + "str q26, [x21, #0x70]\n" + "str q22, [x21, #0x80]\n" + "str q18, [x21, #0x90]\n" + "str q29, [x21, #0xc0]\n" + "str q25, [x21, #0xd0]\n" + "str q21, [x21, #0xe0]\n" + "str q17, [x21, #0xf0]\n" + "str q28, [x21, #0x120]\n" + "str q24, [x21, #0x130]\n" + "str q20, [x21, #0x140]\n" + "str q16, [x21, #0x150]\n" + "add x21, x21, #0x40\n" + "bge 4b\n" + "5:" // Main row loop: width 16 loop: skip + "cmp x24, #0x4\n" + "blt 7f\n" + "6:" // Main row loop: width 4 loop: loop + "ldr q19, [x25], #0x10\n" + "ldr q18, [x23], #0x10\n" + "sub x24, x24, #0x4\n" + "cmp x24, #0x4\n" + "ldr q17, [x22], #0x10\n" + "ldr q16, [x20], #0x10\n" + "str q19, [x21, #0x0]\n" + "str q18, [x21, #0x60]\n" + "str q17, [x21, #0xc0]\n" + "str q16, [x21, #0x120]\n" + "add x21, x21, #0x10\n" + "bge 6b\n" + "7:" // Main row loop: width 4 loop: skip + "cmp x24, #0x1\n" + "blt 9f\n" + "8:" // Main row loop: width 1 loop: loop + "ldr s19, [x25], #0x4\n" + "ldr s18, [x23], #0x4\n" + "sub x24, x24, #0x1\n" + "cmp x24, #0x1\n" + "ldr s17, [x22], #0x4\n" + "ldr s16, [x20], #0x4\n" + "str s19, [x21, #0x0]\n" + "str s18, [x21, #0x60]\n" + "str s17, [x21, #0xc0]\n" + "str s16, [x21, #0x120]\n" + "add x21, x21, #0x4\n" + "bge 8b\n" + "9:" // Main row loop: width 1 loop: skip + "cmp %x[height], #0x4\n" + "add %x[out], %x[out], #0x180\n" + "bge 1b\n" + "cbz %x[height], 20f\n" + "10:" // Main loop skip + "11:" // Tail row loop: Head + "mov x20, %x[width]\n" + "mov x25, %x[in]\n" + "cmp x20, #0x18\n" + "add %x[in], x25, %x[in_stride]\n" + "mov x21, %x[out]\n" + "sub %x[height], %x[height], #0x1\n" + "blt 13f\n" + "12:" // Tail row loop: Column loop + "ldr q21, [x25], #0x10\n" + "ldr q20, [x25], #0x10\n" + "sub x20, x20, #0x18\n" + "cmp x20, #0x18\n" + "ldr q19, [x25], #0x10\n" + "ldr q18, [x25], #0x10\n" + "ldr q17, [x25], #0x10\n" + "ldr q16, [x25], #0x10\n" + "str q21, [x21, #0x0]\n" + "str q20, [x21, #0x10]\n" + "str q19, [x21, #0x20]\n" + "str q18, [x21, #0x30]\n" + "str q17, [x21, #0x40]\n" + "str q16, [x21, #0x50]\n" + "add x21, x21, %x[out_stride]\n" + "bge 12b\n" + "13:" // Tail row loop: Column loop skip + "cmp x20, #0x10\n" + "blt 15f\n" + "14:" // Tail row loop: width 16 loop: loop + "ldr q19, [x25], #0x10\n" + "ldr q18, [x25], #0x10\n" + "sub x20, x20, #0x10\n" + "cmp x20, #0x10\n" + "ldr q17, [x25], #0x10\n" + "ldr q16, [x25], #0x10\n" + "str q19, [x21, #0x0]\n" + "str q18, [x21, #0x10]\n" + "str q17, [x21, #0x20]\n" + "str q16, [x21, #0x30]\n" + "add x21, x21, #0x40\n" + "bge 14b\n" + "15:" // Tail row loop: width 16 loop: skip + "cmp x20, #0x4\n" + "blt 17f\n" + "16:" // Tail row loop: width 4 loop: loop + "ldr q16, [x25], #0x10\n" + "sub x20, x20, #0x4\n" + "cmp x20, #0x4\n" + "str q16, [x21, #0x0]\n" + "add x21, x21, #0x10\n" + "bge 16b\n" + "17:" // Tail row loop: width 4 loop: skip + "cmp x20, #0x1\n" + "blt 19f\n" + "18:" // Tail row loop: width 1 loop: loop + "ldr s16, [x25], #0x4\n" + "sub x20, x20, #0x1\n" + "cmp x20, #0x1\n" + "str s16, [x21, #0x0]\n" + "add x21, x21, #0x4\n" + "bge 18b\n" + "19:" // Tail row loop: width 1 loop: skip + "cmp %x[height], #0x1\n" + "add %x[out], %x[out], #0x60\n" + "bge 11b\n" + "20:" // Done + : [height] "+&r" (height), [in] "+&r" (in), [out] "+&r" (out) + : [in_stride] "r" (in_stride), [out_stride] "r" (out_stride), [width] "r" (width) + : "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "x20", "x21", "x22", "x23", "x24", "x25" + ); +} + +} // anonymous namespace + +template<> +void Transform<24, 1, true, VLType::None>( + float *out, const float *in, int stride, int x0, int xmax, int k0, int kmax) +{ + a64_transpose_interleave_96( + reinterpret_cast<uint32_t *>(out), + reinterpret_cast<const uint32_t *>(in + k0 * stride + x0), + (xmax-x0) * sizeof(float) / 4, + stride * sizeof(float), + (kmax-k0) + ); +} + + +#endif // defined(__aarch64__) diff --git a/src/core/NEON/kernels/arm_gemm/transforms/list-sve.hpp b/src/core/NEON/kernels/arm_gemm/transforms/list-sve.hpp new file mode 100644 index 0000000000..1e6c3d35f4 --- /dev/null +++ b/src/core/NEON/kernels/arm_gemm/transforms/list-sve.hpp @@ -0,0 +1,61 @@ +/* + * Copyright (c) 2021-2023,2024 Arm Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ +#ifdef ARM_COMPUTE_ENABLE_SME +#include "sme_transpose_interleave_16VL_1x4.hpp" +#include "sme_transpose_interleave_16VL_2x2_fp32bf16.hpp" +#include "sme_transpose_interleave_16VL_2x2.hpp" +#include "sme_transpose_interleave_16VL.hpp" +#include "sme_transpose_interleave_1VL_1x4.hpp" +#include "sme_transpose_interleave_1VL_2x2_fp32bf16.hpp" +#include "sme_transpose_interleave_1VL_2x2.hpp" +#include "sme_transpose_interleave_1VL.hpp" +#include "sme_transpose_interleave_2VL_1x4.hpp" +#include "sme_transpose_interleave_2VL_2x2.hpp" +#include "sme_transpose_interleave_2VL_2x2_fp32bf16.hpp" +#include "sme_transpose_interleave_2VL.hpp" +#include "sme_transpose_interleave_4VL_1x4.hpp" +#include "sme_transpose_interleave_4VL_2x2.hpp" +#include "sme_transpose_interleave_4VL_2x2_fp32bf16.hpp" +#include "sme_transpose_interleave_4VL.hpp" +#endif // ARM_COMPUTE_ENABLE_SME +#include "sve_transpose_interleave_12VL_2x4_fp32bf16.hpp" +#include "sve_transpose_interleave_1VL_1x4.hpp" +#include "sve_transpose_interleave_1VL.hpp" +#include "sve_transpose_interleave_2VL_2x4_fp32bf16.hpp" +#include "sve_transpose_interleave_3VL_1x4.hpp" +#include "sve_transpose_interleave_3VL_2x2.hpp" +#include "sve_transpose_interleave_3VL.hpp" +#include "sve_transpose_interleave_4VL_1x4.hpp" +#include "sve_transpose_interleave_4VL_2x2.hpp" +#include "sve_transpose_interleave_4VL.hpp" +#include "sve_transpose_interleave_6VL_1x8.hpp" +#include "sve_transpose_interleave_6VL_2x4_fp32bf16.hpp" +#include "sve_transpose_interleave_6VL_2x4.hpp" +#include "sve_transpose_interleave_6VL_4x2.hpp" +#include "sve_transpose_interleave_8VL_1x4.hpp" +#include "sve_transpose_interleave_8VL_1x8.hpp" +#include "sve_transpose_interleave_8VL_2x2.hpp" +#include "sve_transpose_interleave_8VL_2x4.hpp" +#include "sve_transpose_interleave_8VL_2x4_fp32bf16.hpp" +#include "sve_transpose_interleave_8VL.hpp" diff --git a/src/core/NEON/kernels/arm_gemm/transforms/list.hpp b/src/core/NEON/kernels/arm_gemm/transforms/list.hpp index be66cd42ff..1ce319efee 100644 --- a/src/core/NEON/kernels/arm_gemm/transforms/list.hpp +++ b/src/core/NEON/kernels/arm_gemm/transforms/list.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017-2019 ARM Limited. + * Copyright (c) 2020,2024 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -21,19 +21,30 @@ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE * SOFTWARE. */ -#include "a32_interleave_6way_32bit.hpp" #include "a32_transpose_interleave_8way_32bit.hpp" -#include "a64_block16_interleave4_8bit.hpp" -#include "a64_interleave_8way_16bit.hpp" -#include "a64_interleave_8way_32bit.hpp" -#include "a64_interleave_8way_block4_8bit.hpp" -#include "a64_interleave_8way_half_to_float.hpp" -#include "a64_transpose_interleave_12way_16bit.hpp" -#include "a64_transpose_interleave_12way_half_to_float.hpp" -#include "a64_transpose_interleave_24way_16bit.hpp" -#include "a64_transpose_interleave_8way_32bit.hpp" -#include "sve_interleave_8way_32bit.hpp" -#include "sve_interleave_8way_block2_16bit.hpp" -#include "sve_interleave_8way_block4_16bit.hpp" -#include "sve_interleave_8way_block4_8bit.hpp" -#include "sve_interleave_8way_block8_8bit.hpp" +#include "a64_transpose_interleave_12_1x4.hpp" +#include "a64_transpose_interleave_12_1x8.hpp" +#include "a64_transpose_interleave_12_2x2.hpp" +#include "a64_transpose_interleave_12_2x4_fp32bf16.hpp" +#include "a64_transpose_interleave_12_2x4.hpp" +#include "a64_transpose_interleave_128.hpp" +#include "a64_transpose_interleave_12_s8s16.hpp" +#include "a64_transpose_interleave_12_u8u16.hpp" +#include "a64_transpose_interleave_16_1x4.hpp" +#include "a64_transpose_interleave_16_1x8.hpp" +#include "a64_transpose_interleave_16_2x2.hpp" +#include "a64_transpose_interleave_16_2x4.hpp" +#include "a64_transpose_interleave_16_2x4_fp32bf16.hpp" +#include "a64_transpose_interleave_16.hpp" +#include "a64_transpose_interleave_24_bf16fp32.hpp" +#include "a64_transpose_interleave_24_fp16fp32.hpp" +#include "a64_transpose_interleave_24_2x4_fp32bf16.hpp" +#include "a64_transpose_interleave_24.hpp" +#include "a64_transpose_interleave_32_1x4.hpp" +#include "a64_transpose_interleave_32_2x2.hpp" +#include "a64_transpose_interleave_4_1x16.hpp" +#include "a64_transpose_interleave_4_1x4.hpp" +#include "a64_transpose_interleave_4_2x4_fp32bf16.hpp" +#include "a64_transpose_interleave_48.hpp" +#include "a64_transpose_interleave_64.hpp" +#include "a64_transpose_interleave_96.hpp" diff --git a/src/core/NEON/kernels/arm_gemm/transforms/sme_transpose_interleave_16VL.hpp b/src/core/NEON/kernels/arm_gemm/transforms/sme_transpose_interleave_16VL.hpp new file mode 100644 index 0000000000..a4d480c405 --- /dev/null +++ b/src/core/NEON/kernels/arm_gemm/transforms/sme_transpose_interleave_16VL.hpp @@ -0,0 +1,143 @@ +/* + * Copyright (c) 2022-2023 Arm Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ + +#pragma once + +#if defined(ARM_COMPUTE_ENABLE_SME) + +namespace { + +void sme_transpose_interleave_16VL(uint32_t *out, const uint32_t *in, size_t width, size_t in_stride, size_t height) +{ + size_t out_stride = 16 * height * sme::get_vector_length<uint8_t>(); + + __asm__ __volatile__( + ".inst 0xd503477f // SMSTART ZA\n" + "ptrue p7.b\n" + "1:" // Main row loop: Head + "mov x23, %x[in]\n" + "add %x[in], x23, %x[in_stride]\n" + "mov x22, %x[out]\n" + "sub %x[height], %x[height], #0x1\n" + "mov x21, %x[width]\n" + "2:" // Main row loop: Column loop + "mov x20, x21\n" + "whilelt p0.s, XZR, x20\n" + "ld1w { z31.s }, p0/Z, [x23]\n" + "decw x20\n" + "whilelt p0.s, XZR, x20\n" + "ld1w { z30.s }, p0/Z, [x23, #1, MUL VL]\n" + "decw x20\n" + "whilelt p0.s, XZR, x20\n" + "ld1w { z29.s }, p0/Z, [x23, #2, MUL VL]\n" + "decw x20\n" + "whilelt p0.s, XZR, x20\n" + "ld1w { z28.s }, p0/Z, [x23, #3, MUL VL]\n" + "decw x20\n" + "whilelt p0.s, XZR, x20\n" + "ld1w { z27.s }, p0/Z, [x23, #4, MUL VL]\n" + "decw x20\n" + "whilelt p0.s, XZR, x20\n" + "ld1w { z26.s }, p0/Z, [x23, #5, MUL VL]\n" + "decw x20\n" + "whilelt p0.s, XZR, x20\n" + "ld1w { z25.s }, p0/Z, [x23, #6, MUL VL]\n" + "decw x20\n" + "whilelt p0.s, XZR, x20\n" + "ld1w { z24.s }, p0/Z, [x23, #7, MUL VL]\n" + "decw x20\n" + "whilelt p0.s, XZR, x20\n" + "decw x20\n" + "whilelt p6.s, XZR, x20\n" + "decw x20\n" + "whilelt p5.s, XZR, x20\n" + "decw x20\n" + "whilelt p4.s, XZR, x20\n" + "decw x20\n" + "whilelt p3.s, XZR, x20\n" + "decw x20\n" + "whilelt p2.s, XZR, x20\n" + "decw x20\n" + "whilelt p1.s, XZR, x20\n" + "decw x20\n" + "addvl x23, x23, #16\n" + "ld1w { z23.s }, p0/Z, [x23, #-8, MUL VL]\n" + "whilelt p0.s, XZR, x20\n" + "mov x20, x22\n" + "ld1w { z22.s }, p6/Z, [x23, #-7, MUL VL]\n" + "decw x21, ALL, MUL #16\n" + "ld1w { z21.s }, p5/Z, [x23, #-6, MUL VL]\n" + "cmp x21, #0x0\n" + "ld1w { z20.s }, p4/Z, [x23, #-5, MUL VL]\n" + "add x22, x22, %x[out_stride]\n" + "ld1w { z19.s }, p3/Z, [x23, #-4, MUL VL]\n" + "ld1w { z18.s }, p2/Z, [x23, #-3, MUL VL]\n" + "ld1w { z17.s }, p1/Z, [x23, #-2, MUL VL]\n" + "ld1w { z16.s }, p0/Z, [x23, #-1, MUL VL]\n" + "st1w { z31.s }, p7, [x20]\n" + "st1w { z30.s }, p7, [x20, #1, MUL VL]\n" + "st1w { z29.s }, p7, [x20, #2, MUL VL]\n" + "st1w { z28.s }, p7, [x20, #3, MUL VL]\n" + "st1w { z27.s }, p7, [x20, #4, MUL VL]\n" + "st1w { z26.s }, p7, [x20, #5, MUL VL]\n" + "st1w { z25.s }, p7, [x20, #6, MUL VL]\n" + "st1w { z24.s }, p7, [x20, #7, MUL VL]\n" + "addvl x20, x20, #16\n" + "st1w { z23.s }, p7, [x20, #-8, MUL VL]\n" + "st1w { z22.s }, p7, [x20, #-7, MUL VL]\n" + "st1w { z21.s }, p7, [x20, #-6, MUL VL]\n" + "st1w { z20.s }, p7, [x20, #-5, MUL VL]\n" + "st1w { z19.s }, p7, [x20, #-4, MUL VL]\n" + "st1w { z18.s }, p7, [x20, #-3, MUL VL]\n" + "st1w { z17.s }, p7, [x20, #-2, MUL VL]\n" + "st1w { z16.s }, p7, [x20, #-1, MUL VL]\n" + "bgt 2b\n" + "3:" // Main row loop: Column loop skip + "cmp %x[height], #0x1\n" + "addvl %x[out], %x[out], #16\n" + "bge 1b\n" + ".inst 0xd503467f // SMSTOP\n" + : [height] "+&r" (height), [in] "+&r" (in), [out] "+&r" (out) + : [in_stride] "r" (in_stride), [out_stride] "r" (out_stride), [width] "r" (width) + : "cc", "memory", "p0", "p1", "p2", "p3", "p4", "p5", "p6", "p7", "p8", "p9", "p10", "p11", "p12", "p13", "p14", "p15", "x20", "x21", "x22", "x23", "z0", "z1", "z2", "z3", "z4", "z5", "z6", "z7", "z8", "z9", "z10", "z11", "z12", "z13", "z14", "z15", "z16", "z17", "z18", "z19", "z20", "z21", "z22", "z23", "z24", "z25", "z26", "z27", "z28", "z29", "z30", "z31" + ); +} + +} // anonymous namespace + +template<> +void Transform<16, 1, true, VLType::SME>( + float *out, const float *in, int stride, int x0, int xmax, int k0, int kmax) +{ + sme_transpose_interleave_16VL( + reinterpret_cast<uint32_t *>(out), + reinterpret_cast<const uint32_t *>(in + k0 * stride + x0), + (xmax-x0) * sizeof(float) / 4, + stride * sizeof(float), + (kmax-k0) + ); +} + + +#endif // defined(ARM_COMPUTE_ENABLE_SME) diff --git a/src/core/NEON/kernels/arm_gemm/transforms/sme_transpose_interleave_16VL_1x4.hpp b/src/core/NEON/kernels/arm_gemm/transforms/sme_transpose_interleave_16VL_1x4.hpp new file mode 100644 index 0000000000..552abfc1c6 --- /dev/null +++ b/src/core/NEON/kernels/arm_gemm/transforms/sme_transpose_interleave_16VL_1x4.hpp @@ -0,0 +1,181 @@ +/* + * Copyright (c) 2022-2023 Arm Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ + +#pragma once + +#if defined(ARM_COMPUTE_ENABLE_SME) + +namespace { + +void sme_transpose_interleave_16VL_1x4(uint8_t *out, const uint8_t *in, size_t width, size_t in_stride, size_t height) +{ + uint8_t *pad_row = reinterpret_cast<uint8_t *>(alloca(width * sizeof(uint8_t))); + + if (height % 4) { + memset(pad_row, 0, width * sizeof(uint8_t)); + } + + size_t out_stride = 16 * roundup<size_t>(height, 4) * sme::get_vector_length<uint32_t>(); + + __asm__ __volatile__( + ".inst 0xd503477f // SMSTART ZA\n" + "ptrue p4.b\n" + "1:" // Main row loop: Head + "mov x26, %x[in]\n" + "add x25, x26, %x[in_stride]\n" + "add x24, x25, %x[in_stride]\n" + "add x23, x24, %x[in_stride]\n" + "cmp %x[height], #0x3\n" + "add %x[in], x23, %x[in_stride]\n" + "csel x23, x23, %x[pad_row], GT\n" + "csel x24, x24, %x[pad_row], GE\n" + "cmp %x[height], #0x1\n" + "mov x22, %x[out]\n" + "csel x25, x25, %x[pad_row], GT\n" + "sub %x[height], %x[height], #0x4\n" + "mov x21, %x[width]\n" + "2:" // Main row loop: Column loop + "mov x20, x21\n" + "whilelt p3.b, XZR, x20\n" + "ld1b { z20.b }, p3/Z, [x26]\n" + "decb x20\n" + "whilelt p2.b, XZR, x20\n" + "ld1b { z18.b }, p2/Z, [x26, #1, MUL VL]\n" + "decb x20\n" + "whilelt p1.b, XZR, x20\n" + "ld1b { z17.b }, p3/Z, [x25]\n" + "decb x20\n" + "whilelt p0.b, XZR, x20\n" + "ld1b { z19.b }, p2/Z, [x25, #1, MUL VL]\n" + "ld1b { z16.b }, p3/Z, [x24]\n" + "zip1 z25.b, z20.b, z16.b\n" + "zip2 z24.b, z20.b, z16.b\n" + "mov x20, x22\n" + "ld1b { z16.b }, p2/Z, [x24, #1, MUL VL]\n" + "zip1 z22.b, z18.b, z16.b\n" + "zip2 z21.b, z18.b, z16.b\n" + "decw x21, ALL, MUL #16\n" + "ld1b { z16.b }, p3/Z, [x23]\n" + "zip1 z18.b, z17.b, z16.b\n" + "zip2 z17.b, z17.b, z16.b\n" + "cmp x21, #0x0\n" + "ld1b { z16.b }, p2/Z, [x23, #1, MUL VL]\n" + "zip1 z20.b, z19.b, z16.b\n" + "zip2 z16.b, z19.b, z16.b\n" + "add x22, x22, %x[out_stride]\n" + "ld1b { z19.b }, p1/Z, [x26, #2, MUL VL]\n" + "zip1 z23.b, z25.b, z18.b\n" + "zip2 z0.b, z25.b, z18.b\n" + "ld1b { z18.b }, p0/Z, [x26, #3, MUL VL]\n" + "zip1 z31.b, z24.b, z17.b\n" + "zip2 z30.b, z24.b, z17.b\n" + "addvl x26, x26, #4\n" + "ld1b { z17.b }, p1/Z, [x25, #2, MUL VL]\n" + "zip1 z29.b, z22.b, z20.b\n" + "zip2 z28.b, z22.b, z20.b\n" + "ld1b { z22.b }, p0/Z, [x25, #3, MUL VL]\n" + "zip1 z27.b, z21.b, z16.b\n" + "zip2 z26.b, z21.b, z16.b\n" + "addvl x25, x25, #4\n" + "ld1b { z16.b }, p1/Z, [x24, #2, MUL VL]\n" + "zip1 z21.b, z19.b, z16.b\n" + "zip2 z20.b, z19.b, z16.b\n" + "ld1b { z16.b }, p0/Z, [x24, #3, MUL VL]\n" + "zip1 z25.b, z18.b, z16.b\n" + "zip2 z24.b, z18.b, z16.b\n" + "addvl x24, x24, #4\n" + "ld1b { z16.b }, p1/Z, [x23, #2, MUL VL]\n" + "zip1 z19.b, z17.b, z16.b\n" + "zip2 z18.b, z17.b, z16.b\n" + "ld1b { z16.b }, p0/Z, [x23, #3, MUL VL]\n" + "zip1 z17.b, z22.b, z16.b\n" + "zip2 z16.b, z22.b, z16.b\n" + "addvl x23, x23, #4\n" + "st1b { z23.b }, p4, [x20]\n" + "zip1 z23.b, z21.b, z19.b\n" + "zip2 z22.b, z21.b, z19.b\n" + "st1b { z0.b }, p4, [x20, #1, MUL VL]\n" + "zip1 z21.b, z20.b, z18.b\n" + "zip2 z20.b, z20.b, z18.b\n" + "st1b { z31.b }, p4, [x20, #2, MUL VL]\n" + "zip1 z19.b, z25.b, z17.b\n" + "zip2 z18.b, z25.b, z17.b\n" + "st1b { z30.b }, p4, [x20, #3, MUL VL]\n" + "zip1 z17.b, z24.b, z16.b\n" + "zip2 z16.b, z24.b, z16.b\n" + "st1b { z29.b }, p4, [x20, #4, MUL VL]\n" + "st1b { z28.b }, p4, [x20, #5, MUL VL]\n" + "st1b { z27.b }, p4, [x20, #6, MUL VL]\n" + "st1b { z26.b }, p4, [x20, #7, MUL VL]\n" + "addvl x20, x20, #16\n" + "st1b { z23.b }, p4, [x20, #-8, MUL VL]\n" + "st1b { z22.b }, p4, [x20, #-7, MUL VL]\n" + "st1b { z21.b }, p4, [x20, #-6, MUL VL]\n" + "st1b { z20.b }, p4, [x20, #-5, MUL VL]\n" + "st1b { z19.b }, p4, [x20, #-4, MUL VL]\n" + "st1b { z18.b }, p4, [x20, #-3, MUL VL]\n" + "st1b { z17.b }, p4, [x20, #-2, MUL VL]\n" + "st1b { z16.b }, p4, [x20, #-1, MUL VL]\n" + "bgt 2b\n" + "3:" // Main row loop: Column loop skip + "cmp %x[height], #0x1\n" + "addvl %x[out], %x[out], #16\n" + "bge 1b\n" + ".inst 0xd503467f // SMSTOP\n" + : [height] "+&r" (height), [in] "+&r" (in), [out] "+&r" (out) + : [in_stride] "r" (in_stride), [out_stride] "r" (out_stride), [pad_row] "r" (pad_row), [width] "r" (width) + : "cc", "memory", "p0", "p1", "p2", "p3", "p4", "p5", "p6", "p7", "p8", "p9", "p10", "p11", "p12", "p13", "p14", "p15", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "z0", "z1", "z2", "z3", "z4", "z5", "z6", "z7", "z8", "z9", "z10", "z11", "z12", "z13", "z14", "z15", "z16", "z17", "z18", "z19", "z20", "z21", "z22", "z23", "z24", "z25", "z26", "z27", "z28", "z29", "z30", "z31" + ); +} + +} // anonymous namespace + +template<> +void Transform<16, 4, true, VLType::SME>( + uint8_t *out, const uint8_t *in, int stride, int x0, int xmax, int k0, int kmax) +{ + sme_transpose_interleave_16VL_1x4( + reinterpret_cast<uint8_t *>(out), + reinterpret_cast<const uint8_t *>(in + k0 * stride + x0), + (xmax-x0) * sizeof(uint8_t) / 1, + stride * sizeof(uint8_t), + (kmax-k0) + ); +} + +template<> +void Transform<16, 4, true, VLType::SME>( + int8_t *out, const int8_t *in, int stride, int x0, int xmax, int k0, int kmax) +{ + sme_transpose_interleave_16VL_1x4( + reinterpret_cast<uint8_t *>(out), + reinterpret_cast<const uint8_t *>(in + k0 * stride + x0), + (xmax-x0) * sizeof(int8_t) / 1, + stride * sizeof(int8_t), + (kmax-k0) + ); +} + + +#endif // defined(ARM_COMPUTE_ENABLE_SME) diff --git a/src/core/NEON/kernels/arm_gemm/transforms/sme_transpose_interleave_16VL_2x2.hpp b/src/core/NEON/kernels/arm_gemm/transforms/sme_transpose_interleave_16VL_2x2.hpp new file mode 100644 index 0000000000..dac6b06f1e --- /dev/null +++ b/src/core/NEON/kernels/arm_gemm/transforms/sme_transpose_interleave_16VL_2x2.hpp @@ -0,0 +1,166 @@ +/* + * Copyright (c) 2022-2023 Arm Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ + +#pragma once + +#if defined(ARM_COMPUTE_ENABLE_SME) + +namespace { + +void sme_transpose_interleave_16VL_2x2(uint16_t *out, const uint16_t *in, size_t width, size_t in_stride, size_t height) +{ + uint16_t *pad_row = reinterpret_cast<uint16_t *>(alloca(width * sizeof(uint16_t))); + + if (height % 2) { + memset(pad_row, 0, width * sizeof(uint16_t)); + } + + size_t out_stride = 16 * roundup<size_t>(height, 2) * sme::get_vector_length<uint16_t>(); + + __asm__ __volatile__( + ".inst 0xd503477f // SMSTART ZA\n" + "ptrue p6.b\n" + "1:" // Main row loop: Head + "mov x25, %x[in]\n" + "cmp %x[height], #0x1\n" + "add x24, x25, %x[in_stride]\n" + "mov x23, %x[out]\n" + "add %x[in], x24, %x[in_stride]\n" + "csel x24, x24, %x[pad_row], GT\n" + "sub %x[height], %x[height], #0x2\n" + "mov x22, %x[width]\n" + "2:" // Main row loop: Column loop + "mov x21, x22\n" + "mov x20, x23\n" + "whilelt p1.h, XZR, x21\n" + "dech x21\n" + "whilelt p0.h, XZR, x21\n" + "dech x21\n" + "ld1h { z21.h }, p1/Z, [x25]\n" + "whilelt p5.h, XZR, x21\n" + "dech x21\n" + "ld1h { z20.h }, p0/Z, [x25, #1, MUL VL]\n" + "whilelt p4.h, XZR, x21\n" + "dech x21\n" + "ld1h { z25.h }, p5/Z, [x25, #2, MUL VL]\n" + "whilelt p3.h, XZR, x21\n" + "dech x21\n" + "ld1h { z24.h }, p4/Z, [x25, #3, MUL VL]\n" + "whilelt p2.h, XZR, x21\n" + "dech x21\n" + "ld1h { z19.h }, p1/Z, [x24]\n" + "whilelt p1.h, XZR, x21\n" + "dech x21\n" + "ld1h { z18.h }, p0/Z, [x24, #1, MUL VL]\n" + "whilelt p0.h, XZR, x21\n" + "ld1h { z17.h }, p5/Z, [x24, #2, MUL VL]\n" + "decw x22, ALL, MUL #16\n" + "ld1h { z16.h }, p4/Z, [x24, #3, MUL VL]\n" + "zip1 z23.h, z21.h, z19.h\n" + "zip2 z22.h, z21.h, z19.h\n" + "cmp x22, #0x0\n" + "ld1h { z21.h }, p3/Z, [x25, #4, MUL VL]\n" + "zip1 z31.h, z20.h, z18.h\n" + "zip2 z30.h, z20.h, z18.h\n" + "add x23, x23, %x[out_stride]\n" + "ld1h { z20.h }, p2/Z, [x25, #5, MUL VL]\n" + "zip1 z29.h, z25.h, z17.h\n" + "zip2 z28.h, z25.h, z17.h\n" + "ld1h { z27.h }, p1/Z, [x25, #6, MUL VL]\n" + "zip1 z26.h, z24.h, z16.h\n" + "zip2 z25.h, z24.h, z16.h\n" + "ld1h { z24.h }, p0/Z, [x25, #7, MUL VL]\n" + "addvl x25, x25, #8\n" + "ld1h { z19.h }, p3/Z, [x24, #4, MUL VL]\n" + "ld1h { z18.h }, p2/Z, [x24, #5, MUL VL]\n" + "ld1h { z17.h }, p1/Z, [x24, #6, MUL VL]\n" + "ld1h { z16.h }, p0/Z, [x24, #7, MUL VL]\n" + "st1h { z23.h }, p6, [x20]\n" + "addvl x24, x24, #8\n" + "zip1 z23.h, z21.h, z19.h\n" + "st1h { z22.h }, p6, [x20, #1, MUL VL]\n" + "zip2 z22.h, z21.h, z19.h\n" + "zip1 z21.h, z20.h, z18.h\n" + "st1h { z31.h }, p6, [x20, #2, MUL VL]\n" + "zip2 z20.h, z20.h, z18.h\n" + "zip1 z19.h, z27.h, z17.h\n" + "st1h { z30.h }, p6, [x20, #3, MUL VL]\n" + "zip2 z18.h, z27.h, z17.h\n" + "zip1 z17.h, z24.h, z16.h\n" + "st1h { z29.h }, p6, [x20, #4, MUL VL]\n" + "zip2 z16.h, z24.h, z16.h\n" + "st1h { z28.h }, p6, [x20, #5, MUL VL]\n" + "st1h { z26.h }, p6, [x20, #6, MUL VL]\n" + "st1h { z25.h }, p6, [x20, #7, MUL VL]\n" + "addvl x20, x20, #16\n" + "st1h { z23.h }, p6, [x20, #-8, MUL VL]\n" + "st1h { z22.h }, p6, [x20, #-7, MUL VL]\n" + "st1h { z21.h }, p6, [x20, #-6, MUL VL]\n" + "st1h { z20.h }, p6, [x20, #-5, MUL VL]\n" + "st1h { z19.h }, p6, [x20, #-4, MUL VL]\n" + "st1h { z18.h }, p6, [x20, #-3, MUL VL]\n" + "st1h { z17.h }, p6, [x20, #-2, MUL VL]\n" + "st1h { z16.h }, p6, [x20, #-1, MUL VL]\n" + "bgt 2b\n" + "3:" // Main row loop: Column loop skip + "cmp %x[height], #0x1\n" + "addvl %x[out], %x[out], #16\n" + "bge 1b\n" + ".inst 0xd503467f // SMSTOP\n" + : [height] "+&r" (height), [in] "+&r" (in), [out] "+&r" (out) + : [in_stride] "r" (in_stride), [out_stride] "r" (out_stride), [pad_row] "r" (pad_row), [width] "r" (width) + : "cc", "memory", "p0", "p1", "p2", "p3", "p4", "p5", "p6", "p7", "p8", "p9", "p10", "p11", "p12", "p13", "p14", "p15", "x20", "x21", "x22", "x23", "x24", "x25", "z0", "z1", "z2", "z3", "z4", "z5", "z6", "z7", "z8", "z9", "z10", "z11", "z12", "z13", "z14", "z15", "z16", "z17", "z18", "z19", "z20", "z21", "z22", "z23", "z24", "z25", "z26", "z27", "z28", "z29", "z30", "z31" + ); +} + +} // anonymous namespace + +template<> +void Transform<16, 2, true, VLType::SME>( + bfloat16 *out, const bfloat16 *in, int stride, int x0, int xmax, int k0, int kmax) +{ + sme_transpose_interleave_16VL_2x2( + reinterpret_cast<uint16_t *>(out), + reinterpret_cast<const uint16_t *>(in + k0 * stride + x0), + (xmax-x0) * sizeof(bfloat16) / 2, + stride * sizeof(bfloat16), + (kmax-k0) + ); +} + +template<> +void Transform<16, 2, true, VLType::SME>( + __fp16 *out, const __fp16 *in, int stride, int x0, int xmax, int k0, int kmax) +{ + sme_transpose_interleave_16VL_2x2( + reinterpret_cast<uint16_t *>(out), + reinterpret_cast<const uint16_t *>(in + k0 * stride + x0), + (xmax-x0) * sizeof(__fp16) / 2, + stride * sizeof(__fp16), + (kmax-k0) + ); +} + + +#endif // defined(ARM_COMPUTE_ENABLE_SME) diff --git a/src/core/NEON/kernels/arm_gemm/transforms/sme_transpose_interleave_16VL_2x2_fp32bf16.hpp b/src/core/NEON/kernels/arm_gemm/transforms/sme_transpose_interleave_16VL_2x2_fp32bf16.hpp new file mode 100644 index 0000000000..2756327815 --- /dev/null +++ b/src/core/NEON/kernels/arm_gemm/transforms/sme_transpose_interleave_16VL_2x2_fp32bf16.hpp @@ -0,0 +1,200 @@ +/* + * Copyright (c) 2022-2023 Arm Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ + +#pragma once + +#if defined(ARM_COMPUTE_ENABLE_SME) + +namespace { + +void sme_transpose_interleave_16VL_2x2_fp32bf16(bfloat16 *out, const float *in, size_t width, size_t in_stride, size_t height) +{ + float *pad_row = reinterpret_cast<float *>(alloca(width * sizeof(float))); + + if (height % 2) { + memset(pad_row, 0, width * sizeof(float)); + } + + size_t out_stride = 16 * roundup<size_t>(height, 2) * sme::get_vector_length<uint16_t>(); + + __asm__ __volatile__( + ".inst 0xd503477f // SMSTART ZA\n" + "ptrue p7.b\n" + "1:" // Main row loop: Head + "mov x25, %x[in]\n" + "add x24, x25, %x[in_stride]\n" + "cmp %x[height], #0x1\n" + "add %x[in], x24, %x[in_stride]\n" + "mov x23, %x[out]\n" + "csel x24, x24, %x[pad_row], GT\n" + "sub %x[height], %x[height], #0x2\n" + "mov x22, %x[width]\n" + "2:" // Main row loop: Column loop + "mov x21, x22\n" + "whilelt p1.s, XZR, x21\n" + "ld1w { z16.s }, p1/Z, [x25]\n" + ".inst 0x658abe00 // bfcvt z0.h, p7/M, z16.s\n" + "decw x21\n" + "whilelt p0.s, XZR, x21\n" + "ld1w { z16.s }, p0/Z, [x25, #1, MUL VL]\n" + ".inst 0x658abe1f // bfcvt z31.h, p7/M, z16.s\n" + "decw x21\n" + "whilelt p6.s, XZR, x21\n" + "ld1w { z16.s }, p6/Z, [x25, #2, MUL VL]\n" + ".inst 0x658abe1e // bfcvt z30.h, p7/M, z16.s\n" + "decw x21\n" + "whilelt p5.s, XZR, x21\n" + "ld1w { z16.s }, p5/Z, [x25, #3, MUL VL]\n" + ".inst 0x658abe1d // bfcvt z29.h, p7/M, z16.s\n" + "decw x21\n" + "whilelt p4.s, XZR, x21\n" + "ld1w { z16.s }, p4/Z, [x25, #4, MUL VL]\n" + ".inst 0x658abe1c // bfcvt z28.h, p7/M, z16.s\n" + "decw x21\n" + "whilelt p3.s, XZR, x21\n" + "ld1w { z16.s }, p3/Z, [x25, #5, MUL VL]\n" + ".inst 0x658abe1b // bfcvt z27.h, p7/M, z16.s\n" + "decw x21\n" + "whilelt p2.s, XZR, x21\n" + "ld1w { z16.s }, p2/Z, [x25, #6, MUL VL]\n" + ".inst 0x658abe1a // bfcvt z26.h, p7/M, z16.s\n" + "decw x21\n" + "ld1w { z16.s }, p1/Z, [x24]\n" + "whilelt p1.s, XZR, x21\n" + ".inst 0x648abe00 // bfcvtnt z0.h, p7/M, z16.s\n" + "decw x21\n" + "ld1w { z16.s }, p1/Z, [x25, #7, MUL VL]\n" + "addvl x25, x25, #16\n" + ".inst 0x658abe19 // bfcvt z25.h, p7/M, z16.s\n" + "ld1w { z16.s }, p0/Z, [x24, #1, MUL VL]\n" + "whilelt p0.s, XZR, x21\n" + "decw x21\n" + ".inst 0x648abe1f // bfcvtnt z31.h, p7/M, z16.s\n" + "ld1w { z16.s }, p0/Z, [x25, #-8, MUL VL]\n" + ".inst 0x658abe18 // bfcvt z24.h, p7/M, z16.s\n" + "mov x20, x23\n" + "decw x22, ALL, MUL #16\n" + "ld1w { z16.s }, p6/Z, [x24, #2, MUL VL]\n" + "whilelt p6.s, XZR, x21\n" + "decw x21\n" + ".inst 0x648abe1e // bfcvtnt z30.h, p7/M, z16.s\n" + "ld1w { z16.s }, p6/Z, [x25, #-7, MUL VL]\n" + ".inst 0x658abe17 // bfcvt z23.h, p7/M, z16.s\n" + "add x23, x23, %x[out_stride]\n" + "ld1w { z16.s }, p5/Z, [x24, #3, MUL VL]\n" + "whilelt p5.s, XZR, x21\n" + "decw x21\n" + ".inst 0x648abe1d // bfcvtnt z29.h, p7/M, z16.s\n" + "ld1w { z16.s }, p5/Z, [x25, #-6, MUL VL]\n" + ".inst 0x658abe16 // bfcvt z22.h, p7/M, z16.s\n" + "ld1w { z16.s }, p4/Z, [x24, #4, MUL VL]\n" + "whilelt p4.s, XZR, x21\n" + "decw x21\n" + ".inst 0x648abe1c // bfcvtnt z28.h, p7/M, z16.s\n" + "ld1w { z16.s }, p4/Z, [x25, #-5, MUL VL]\n" + ".inst 0x658abe15 // bfcvt z21.h, p7/M, z16.s\n" + "ld1w { z16.s }, p3/Z, [x24, #5, MUL VL]\n" + "whilelt p3.s, XZR, x21\n" + "decw x21\n" + ".inst 0x648abe1b // bfcvtnt z27.h, p7/M, z16.s\n" + "ld1w { z16.s }, p3/Z, [x25, #-4, MUL VL]\n" + ".inst 0x658abe14 // bfcvt z20.h, p7/M, z16.s\n" + "ld1w { z16.s }, p2/Z, [x24, #6, MUL VL]\n" + "whilelt p2.s, XZR, x21\n" + "decw x21\n" + ".inst 0x648abe1a // bfcvtnt z26.h, p7/M, z16.s\n" + "ld1w { z16.s }, p2/Z, [x25, #-3, MUL VL]\n" + ".inst 0x658abe13 // bfcvt z19.h, p7/M, z16.s\n" + "ld1w { z16.s }, p1/Z, [x24, #7, MUL VL]\n" + "whilelt p1.s, XZR, x21\n" + "decw x21\n" + ".inst 0x648abe19 // bfcvtnt z25.h, p7/M, z16.s\n" + "ld1w { z16.s }, p1/Z, [x25, #-2, MUL VL]\n" + "addvl x24, x24, #16\n" + ".inst 0x658abe12 // bfcvt z18.h, p7/M, z16.s\n" + "ld1w { z16.s }, p0/Z, [x24, #-8, MUL VL]\n" + "whilelt p0.s, XZR, x21\n" + "cmp x22, #0x0\n" + ".inst 0x648abe18 // bfcvtnt z24.h, p7/M, z16.s\n" + "ld1w { z16.s }, p0/Z, [x25, #-1, MUL VL]\n" + ".inst 0x658abe11 // bfcvt z17.h, p7/M, z16.s\n" + "ld1w { z16.s }, p6/Z, [x24, #-7, MUL VL]\n" + ".inst 0x648abe17 // bfcvtnt z23.h, p7/M, z16.s\n" + "ld1w { z16.s }, p5/Z, [x24, #-6, MUL VL]\n" + ".inst 0x648abe16 // bfcvtnt z22.h, p7/M, z16.s\n" + "ld1w { z16.s }, p4/Z, [x24, #-5, MUL VL]\n" + ".inst 0x648abe15 // bfcvtnt z21.h, p7/M, z16.s\n" + "ld1w { z16.s }, p3/Z, [x24, #-4, MUL VL]\n" + ".inst 0x648abe14 // bfcvtnt z20.h, p7/M, z16.s\n" + "ld1w { z16.s }, p2/Z, [x24, #-3, MUL VL]\n" + ".inst 0x648abe13 // bfcvtnt z19.h, p7/M, z16.s\n" + "ld1w { z16.s }, p1/Z, [x24, #-2, MUL VL]\n" + ".inst 0x648abe12 // bfcvtnt z18.h, p7/M, z16.s\n" + "ld1w { z16.s }, p0/Z, [x24, #-1, MUL VL]\n" + "st1h { z0.h }, p7, [x20]\n" + ".inst 0x648abe11 // bfcvtnt z17.h, p7/M, z16.s\n" + "st1h { z31.h }, p7, [x20, #1, MUL VL]\n" + "st1h { z30.h }, p7, [x20, #2, MUL VL]\n" + "st1h { z29.h }, p7, [x20, #3, MUL VL]\n" + "st1h { z28.h }, p7, [x20, #4, MUL VL]\n" + "st1h { z27.h }, p7, [x20, #5, MUL VL]\n" + "st1h { z26.h }, p7, [x20, #6, MUL VL]\n" + "st1h { z25.h }, p7, [x20, #7, MUL VL]\n" + "addvl x20, x20, #16\n" + "st1h { z24.h }, p7, [x20, #-8, MUL VL]\n" + "st1h { z23.h }, p7, [x20, #-7, MUL VL]\n" + "st1h { z22.h }, p7, [x20, #-6, MUL VL]\n" + "st1h { z21.h }, p7, [x20, #-5, MUL VL]\n" + "st1h { z20.h }, p7, [x20, #-4, MUL VL]\n" + "st1h { z19.h }, p7, [x20, #-3, MUL VL]\n" + "st1h { z18.h }, p7, [x20, #-2, MUL VL]\n" + "st1h { z17.h }, p7, [x20, #-1, MUL VL]\n" + "bgt 2b\n" + "3:" // Main row loop: Column loop skip + "cmp %x[height], #0x1\n" + "addvl %x[out], %x[out], #16\n" + "bge 1b\n" + ".inst 0xd503467f // SMSTOP\n" + : [height] "+&r" (height), [in] "+&r" (in), [out] "+&r" (out) + : [in_stride] "r" (in_stride), [out_stride] "r" (out_stride), [pad_row] "r" (pad_row), [width] "r" (width) + : "cc", "memory", "p0", "p1", "p2", "p3", "p4", "p5", "p6", "p7", "p8", "p9", "p10", "p11", "p12", "p13", "p14", "p15", "x20", "x21", "x22", "x23", "x24", "x25", "z0", "z1", "z2", "z3", "z4", "z5", "z6", "z7", "z8", "z9", "z10", "z11", "z12", "z13", "z14", "z15", "z16", "z17", "z18", "z19", "z20", "z21", "z22", "z23", "z24", "z25", "z26", "z27", "z28", "z29", "z30", "z31" + ); +} + +} // anonymous namespace +template<> +void Transform<16, 2, true, VLType::SME>( + bfloat16 *out, const float *in, int stride, int x0, int xmax, int k0, int kmax) +{ + sme_transpose_interleave_16VL_2x2_fp32bf16( + out, + in + k0 * stride + x0, + (xmax-x0), + stride * sizeof(float), + (kmax-k0) + ); +} + + +#endif // defined(ARM_COMPUTE_ENABLE_SME) diff --git a/src/core/NEON/kernels/arm_gemm/transforms/sme_transpose_interleave_1VL.hpp b/src/core/NEON/kernels/arm_gemm/transforms/sme_transpose_interleave_1VL.hpp new file mode 100644 index 0000000000..a6ddb8fec0 --- /dev/null +++ b/src/core/NEON/kernels/arm_gemm/transforms/sme_transpose_interleave_1VL.hpp @@ -0,0 +1,213 @@ +/* + * Copyright (c) 2022-2023 Arm Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ + +#pragma once + +#if defined(ARM_COMPUTE_ENABLE_SME) + +namespace { + +void sme_transpose_interleave_1VL(uint16_t *out, const uint16_t *in, size_t width, size_t in_stride, size_t height) +{ + size_t out_stride = 1 * height * sme::get_vector_length<uint8_t>(); + + __asm__ __volatile__( + ".inst 0xd503477f // SMSTART ZA\n" + "cmp %x[height], #0x4\n" + "ptrue p1.b\n" + "blt 6f\n" + "1:" // Main row loop: Head + "mov x26, %x[in]\n" + "add x25, x26, %x[in_stride]\n" + "add x24, x25, %x[in_stride]\n" + "mov x23, %x[width]\n" + "cnth x21, ALL, MUL #4\n" + "add x20, x24, %x[in_stride]\n" + "cmp x23, x21\n" + "add %x[in], x20, %x[in_stride]\n" + "mov x22, %x[out]\n" + "sub %x[height], %x[height], #0x4\n" + "blt 3f\n" + "2:" // Main row loop: Unroll column loop + "ld1h { z31.h }, p1/Z, [x26]\n" + "sub x23, x23, x21\n" + "cmp x23, x21\n" + "ld1h { z30.h }, p1/Z, [x26, #1, MUL VL]\n" + "ld1h { z29.h }, p1/Z, [x26, #2, MUL VL]\n" + "ld1h { z28.h }, p1/Z, [x26, #3, MUL VL]\n" + "addvl x26, x26, #4\n" + "ld1h { z27.h }, p1/Z, [x25]\n" + "ld1h { z26.h }, p1/Z, [x25, #1, MUL VL]\n" + "ld1h { z25.h }, p1/Z, [x25, #2, MUL VL]\n" + "ld1h { z24.h }, p1/Z, [x25, #3, MUL VL]\n" + "addvl x25, x25, #4\n" + "ld1h { z23.h }, p1/Z, [x24]\n" + "ld1h { z22.h }, p1/Z, [x24, #1, MUL VL]\n" + "ld1h { z21.h }, p1/Z, [x24, #2, MUL VL]\n" + "ld1h { z20.h }, p1/Z, [x24, #3, MUL VL]\n" + "addvl x24, x24, #4\n" + "ld1h { z19.h }, p1/Z, [x20]\n" + "ld1h { z18.h }, p1/Z, [x20, #1, MUL VL]\n" + "ld1h { z17.h }, p1/Z, [x20, #2, MUL VL]\n" + "ld1h { z16.h }, p1/Z, [x20, #3, MUL VL]\n" + "st1h { z31.h }, p1, [x22]\n" + "addvl x20, x20, #4\n" + "st1h { z27.h }, p1, [x22, #1, MUL VL]\n" + "st1h { z23.h }, p1, [x22, #2, MUL VL]\n" + "st1h { z19.h }, p1, [x22, #3, MUL VL]\n" + "add x22, x22, %x[out_stride]\n" + "st1h { z30.h }, p1, [x22]\n" + "st1h { z26.h }, p1, [x22, #1, MUL VL]\n" + "st1h { z22.h }, p1, [x22, #2, MUL VL]\n" + "st1h { z18.h }, p1, [x22, #3, MUL VL]\n" + "add x22, x22, %x[out_stride]\n" + "st1h { z29.h }, p1, [x22]\n" + "st1h { z25.h }, p1, [x22, #1, MUL VL]\n" + "st1h { z21.h }, p1, [x22, #2, MUL VL]\n" + "st1h { z17.h }, p1, [x22, #3, MUL VL]\n" + "add x22, x22, %x[out_stride]\n" + "st1h { z28.h }, p1, [x22]\n" + "st1h { z24.h }, p1, [x22, #1, MUL VL]\n" + "st1h { z20.h }, p1, [x22, #2, MUL VL]\n" + "st1h { z16.h }, p1, [x22, #3, MUL VL]\n" + "add x22, x22, %x[out_stride]\n" + "bge 2b\n" + "3:" // Main row loop: Unroll column loop skip + "cbz x23, 5f\n" + "4:" // Main row loop: Column loop + "whilelt p0.h, XZR, x23\n" + "dech x23\n" + "ld1h { z19.h }, p0/Z, [x26]\n" + "cmp x23, #0x0\n" + "addvl x26, x26, #1\n" + "ld1h { z18.h }, p0/Z, [x25]\n" + "addvl x25, x25, #1\n" + "ld1h { z17.h }, p0/Z, [x24]\n" + "addvl x24, x24, #1\n" + "ld1h { z16.h }, p0/Z, [x20]\n" + "addvl x20, x20, #1\n" + "st1h { z19.h }, p1, [x22]\n" + "st1h { z18.h }, p1, [x22, #1, MUL VL]\n" + "st1h { z17.h }, p1, [x22, #2, MUL VL]\n" + "st1h { z16.h }, p1, [x22, #3, MUL VL]\n" + "add x22, x22, %x[out_stride]\n" + "bgt 4b\n" + "5:" // Main row loop: Column loop skip + "cmp %x[height], #0x4\n" + "addvl %x[out], %x[out], #4\n" + "bge 1b\n" + "cbz %x[height], 12f\n" + "6:" // Main loop skip + "7:" // Tail row loop: Head + "mov x21, %x[width]\n" + "cnth x20, ALL, MUL #4\n" + "mov x26, %x[in]\n" + "cmp x21, x20\n" + "add %x[in], x26, %x[in_stride]\n" + "mov x22, %x[out]\n" + "sub %x[height], %x[height], #0x1\n" + "blt 9f\n" + "8:" // Tail row loop: Unroll column loop + "ld1h { z19.h }, p1/Z, [x26]\n" + "sub x21, x21, x20\n" + "cmp x21, x20\n" + "ld1h { z18.h }, p1/Z, [x26, #1, MUL VL]\n" + "ld1h { z17.h }, p1/Z, [x26, #2, MUL VL]\n" + "ld1h { z16.h }, p1/Z, [x26, #3, MUL VL]\n" + "st1h { z19.h }, p1, [x22]\n" + "add x22, x22, %x[out_stride]\n" + "addvl x26, x26, #4\n" + "st1h { z18.h }, p1, [x22]\n" + "add x22, x22, %x[out_stride]\n" + "st1h { z17.h }, p1, [x22]\n" + "add x22, x22, %x[out_stride]\n" + "st1h { z16.h }, p1, [x22]\n" + "add x22, x22, %x[out_stride]\n" + "bge 8b\n" + "9:" // Tail row loop: Unroll column loop skip + "cbz x21, 11f\n" + "10:" // Tail row loop: Column loop + "whilelt p0.h, XZR, x21\n" + "dech x21\n" + "ld1h { z16.h }, p0/Z, [x26]\n" + "st1h { z16.h }, p1, [x22]\n" + "cmp x21, #0x0\n" + "addvl x26, x26, #1\n" + "add x22, x22, %x[out_stride]\n" + "bgt 10b\n" + "11:" // Tail row loop: Column loop skip + "cmp %x[height], #0x1\n" + "addvl %x[out], %x[out], #1\n" + "bge 7b\n" + "12:" // Done + ".inst 0xd503467f // SMSTOP\n" + : [height] "+&r" (height), [in] "+&r" (in), [out] "+&r" (out) + : [in_stride] "r" (in_stride), [out_stride] "r" (out_stride), [width] "r" (width) + : "cc", "memory", "p0", "p1", "p2", "p3", "p4", "p5", "p6", "p7", "p8", "p9", "p10", "p11", "p12", "p13", "p14", "p15", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "z0", "z1", "z2", "z3", "z4", "z5", "z6", "z7", "z8", "z9", "z10", "z11", "z12", "z13", "z14", "z15", "z16", "z17", "z18", "z19", "z20", "z21", "z22", "z23", "z24", "z25", "z26", "z27", "z28", "z29", "z30", "z31" + ); +} + +} // anonymous namespace + +template<> +void Transform<1, 1, true, VLType::SME>( + float *out, const float *in, int stride, int x0, int xmax, int k0, int kmax) +{ + sme_transpose_interleave_1VL( + reinterpret_cast<uint16_t *>(out), + reinterpret_cast<const uint16_t *>(in + k0 * stride + x0), + (xmax-x0) * sizeof(float) / 2, + stride * sizeof(float), + (kmax-k0) + ); +} + +template<> +void Transform<1, 1, true, VLType::SME>( + bfloat16 *out, const bfloat16 *in, int stride, int x0, int xmax, int k0, int kmax) +{ + sme_transpose_interleave_1VL( + reinterpret_cast<uint16_t *>(out), + reinterpret_cast<const uint16_t *>(in + k0 * stride + x0), + (xmax-x0) * sizeof(bfloat16) / 2, + stride * sizeof(bfloat16), + (kmax-k0) + ); +} + +template<> +void Transform<1, 1, true, VLType::SME>( + __fp16 *out, const __fp16 *in, int stride, int x0, int xmax, int k0, int kmax) +{ + sme_transpose_interleave_1VL( + reinterpret_cast<uint16_t *>(out), + reinterpret_cast<const uint16_t *>(in + k0 * stride + x0), + (xmax-x0) * sizeof(__fp16) / 2, + stride * sizeof(__fp16), + (kmax-k0) + ); +} + + +#endif // defined(ARM_COMPUTE_ENABLE_SME) diff --git a/src/core/NEON/kernels/arm_gemm/transforms/sme_transpose_interleave_1VL_1x4.hpp b/src/core/NEON/kernels/arm_gemm/transforms/sme_transpose_interleave_1VL_1x4.hpp new file mode 100644 index 0000000000..399a52e233 --- /dev/null +++ b/src/core/NEON/kernels/arm_gemm/transforms/sme_transpose_interleave_1VL_1x4.hpp @@ -0,0 +1,149 @@ +/* + * Copyright (c) 2022-2023 Arm Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ + +#pragma once + +#if defined(ARM_COMPUTE_ENABLE_SME) + +namespace { + +void sme_transpose_interleave_1VL_1x4(uint8_t *out, const uint8_t *in, size_t width, size_t in_stride, size_t height) +{ + uint8_t *pad_row = reinterpret_cast<uint8_t *>(alloca(width * sizeof(uint8_t))); + + if (height % 4) { + memset(pad_row, 0, width * sizeof(uint8_t)); + } + + size_t out_stride = 1 * roundup<size_t>(height, 4) * sme::get_vector_length<uint32_t>(); + + __asm__ __volatile__( + ".inst 0xd503477f // SMSTART ZA\n" + "ptrue p1.b\n" + "1:" // Main row loop: Head + "mov x26, %x[in]\n" + "add x25, x26, %x[in_stride]\n" + "add x24, x25, %x[in_stride]\n" + "add x23, x24, %x[in_stride]\n" + "cmp %x[height], #0x3\n" + "add %x[in], x23, %x[in_stride]\n" + "csel x23, x23, %x[pad_row], GT\n" + "csel x24, x24, %x[pad_row], GE\n" + "cmp %x[height], #0x1\n" + "mov x22, %x[width]\n" + "cntb x21\n" + "csel x25, x25, %x[pad_row], GT\n" + "cmp x22, x21\n" + "mov x20, %x[out]\n" + "sub %x[height], %x[height], #0x4\n" + "blt 3f\n" + "2:" // Main row loop: Unroll column loop + "ld1b { z17.b }, p1/Z, [x26]\n" + "sub x22, x22, x21\n" + "cmp x22, x21\n" + "ld1b { z18.b }, p1/Z, [x25]\n" + "addvl x26, x26, #1\n" + "addvl x25, x25, #1\n" + "ld1b { z16.b }, p1/Z, [x24]\n" + "zip1 z20.b, z17.b, z16.b\n" + "zip2 z19.b, z17.b, z16.b\n" + "addvl x24, x24, #1\n" + "ld1b { z16.b }, p1/Z, [x23]\n" + "zip1 z17.b, z18.b, z16.b\n" + "zip2 z18.b, z18.b, z16.b\n" + "addvl x23, x23, #1\n" + "zip1 z16.b, z20.b, z17.b\n" + "st1b { z16.b }, p1, [x20]\n" + "add x20, x20, %x[out_stride]\n" + "zip2 z16.b, z20.b, z17.b\n" + "st1b { z16.b }, p1, [x20]\n" + "add x20, x20, %x[out_stride]\n" + "zip1 z17.b, z19.b, z18.b\n" + "zip2 z16.b, z19.b, z18.b\n" + "st1b { z17.b }, p1, [x20]\n" + "add x20, x20, %x[out_stride]\n" + "st1b { z16.b }, p1, [x20]\n" + "add x20, x20, %x[out_stride]\n" + "bge 2b\n" + "3:" // Main row loop: Unroll column loop skip + "cbz x22, 5f\n" + "4:" // Main row loop: Column loop + "whilelt p0.b, XZR, x22\n" + "ld1b { z17.b }, p0/Z, [x26]\n" + "decw x22\n" + "ld1b { z18.b }, p0/Z, [x25]\n" + "cmp x22, #0x0\n" + "incd x26, ALL, MUL #2\n" + "ld1b { z16.b }, p0/Z, [x24]\n" + "zip1 z17.b, z17.b, z16.b\n" + "incd x25, ALL, MUL #2\n" + "incd x24, ALL, MUL #2\n" + "ld1b { z16.b }, p0/Z, [x23]\n" + "zip1 z16.b, z18.b, z16.b\n" + "incd x23, ALL, MUL #2\n" + "zip1 z16.b, z17.b, z16.b\n" + "st1b { z16.b }, p1, [x20]\n" + "add x20, x20, %x[out_stride]\n" + "bgt 4b\n" + "5:" // Main row loop: Column loop skip + "cmp %x[height], #0x1\n" + "addvl %x[out], %x[out], #1\n" + "bge 1b\n" + ".inst 0xd503467f // SMSTOP\n" + : [height] "+&r" (height), [in] "+&r" (in), [out] "+&r" (out) + : [in_stride] "r" (in_stride), [out_stride] "r" (out_stride), [pad_row] "r" (pad_row), [width] "r" (width) + : "cc", "memory", "p0", "p1", "p2", "p3", "p4", "p5", "p6", "p7", "p8", "p9", "p10", "p11", "p12", "p13", "p14", "p15", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "z0", "z1", "z2", "z3", "z4", "z5", "z6", "z7", "z8", "z9", "z10", "z11", "z12", "z13", "z14", "z15", "z16", "z17", "z18", "z19", "z20", "z21", "z22", "z23", "z24", "z25", "z26", "z27", "z28", "z29", "z30", "z31" + ); +} + +} // anonymous namespace + +template<> +void Transform<1, 4, true, VLType::SME>( + uint8_t *out, const uint8_t *in, int stride, int x0, int xmax, int k0, int kmax) +{ + sme_transpose_interleave_1VL_1x4( + reinterpret_cast<uint8_t *>(out), + reinterpret_cast<const uint8_t *>(in + k0 * stride + x0), + (xmax-x0) * sizeof(uint8_t) / 1, + stride * sizeof(uint8_t), + (kmax-k0) + ); +} + +template<> +void Transform<1, 4, true, VLType::SME>( + int8_t *out, const int8_t *in, int stride, int x0, int xmax, int k0, int kmax) +{ + sme_transpose_interleave_1VL_1x4( + reinterpret_cast<uint8_t *>(out), + reinterpret_cast<const uint8_t *>(in + k0 * stride + x0), + (xmax-x0) * sizeof(int8_t) / 1, + stride * sizeof(int8_t), + (kmax-k0) + ); +} + + +#endif // defined(ARM_COMPUTE_ENABLE_SME) diff --git a/src/core/NEON/kernels/arm_gemm/transforms/sme_transpose_interleave_1VL_2x2.hpp b/src/core/NEON/kernels/arm_gemm/transforms/sme_transpose_interleave_1VL_2x2.hpp new file mode 100644 index 0000000000..6318e29a79 --- /dev/null +++ b/src/core/NEON/kernels/arm_gemm/transforms/sme_transpose_interleave_1VL_2x2.hpp @@ -0,0 +1,208 @@ +/* + * Copyright (c) 2022-2023 Arm Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ + +#pragma once + +#if defined(ARM_COMPUTE_ENABLE_SME) + +namespace { + +void sme_transpose_interleave_1VL_2x2(uint16_t *out, const uint16_t *in, size_t width, size_t in_stride, size_t height) +{ + uint16_t *pad_row = reinterpret_cast<uint16_t *>(alloca(width * sizeof(uint16_t))); + + if (height % 2) { + memset(pad_row, 0, width * sizeof(uint16_t)); + } + + size_t out_stride = 1 * roundup<size_t>(height, 2) * sme::get_vector_length<uint16_t>(); + + __asm__ __volatile__( + ".inst 0xd503477f // SMSTART ZA\n" + "cmp %x[height], #0x4\n" + "ptrue p1.b\n" + "blt 6f\n" + "1:" // Main row loop: Head + "mov x26, %x[in]\n" + "add x25, x26, %x[in_stride]\n" + "add x24, x25, %x[in_stride]\n" + "mov x23, %x[width]\n" + "cnth x21, ALL, MUL #2\n" + "add x20, x24, %x[in_stride]\n" + "cmp x23, x21\n" + "add %x[in], x20, %x[in_stride]\n" + "mov x22, %x[out]\n" + "sub %x[height], %x[height], #0x4\n" + "blt 3f\n" + "2:" // Main row loop: Unroll column loop + "ld1h { z17.h }, p1/Z, [x26]\n" + "sub x23, x23, x21\n" + "cmp x23, x21\n" + "ld1h { z16.h }, p1/Z, [x25]\n" + "zip1 z24.h, z17.h, z16.h\n" + "zip2 z23.h, z17.h, z16.h\n" + "ld1h { z17.h }, p1/Z, [x24]\n" + "ld1h { z16.h }, p1/Z, [x20]\n" + "zip1 z22.h, z17.h, z16.h\n" + "zip2 z21.h, z17.h, z16.h\n" + "ld1h { z17.h }, p1/Z, [x26, #1, MUL VL]\n" + "addvl x26, x26, #2\n" + "ld1h { z16.h }, p1/Z, [x25, #1, MUL VL]\n" + "zip1 z20.h, z17.h, z16.h\n" + "addvl x25, x25, #2\n" + "zip2 z19.h, z17.h, z16.h\n" + "ld1h { z18.h }, p1/Z, [x24, #1, MUL VL]\n" + "addvl x24, x24, #2\n" + "ld1h { z16.h }, p1/Z, [x20, #1, MUL VL]\n" + "st1h { z24.h }, p1, [x22]\n" + "zip1 z17.h, z18.h, z16.h\n" + "addvl x20, x20, #2\n" + "st1h { z22.h }, p1, [x22, #1, MUL VL]\n" + "add x22, x22, %x[out_stride]\n" + "zip2 z16.h, z18.h, z16.h\n" + "st1h { z23.h }, p1, [x22]\n" + "st1h { z21.h }, p1, [x22, #1, MUL VL]\n" + "add x22, x22, %x[out_stride]\n" + "st1h { z20.h }, p1, [x22]\n" + "st1h { z17.h }, p1, [x22, #1, MUL VL]\n" + "add x22, x22, %x[out_stride]\n" + "st1h { z19.h }, p1, [x22]\n" + "st1h { z16.h }, p1, [x22, #1, MUL VL]\n" + "add x22, x22, %x[out_stride]\n" + "bge 2b\n" + "3:" // Main row loop: Unroll column loop skip + "cbz x23, 5f\n" + "4:" // Main row loop: Column loop + "whilelt p0.h, XZR, x23\n" + "ld1h { z17.h }, p0/Z, [x26]\n" + "decw x23\n" + "ld1h { z16.h }, p0/Z, [x25]\n" + "cmp x23, #0x0\n" + "incd x26, ALL, MUL #4\n" + "zip1 z18.h, z17.h, z16.h\n" + "ld1h { z17.h }, p0/Z, [x24]\n" + "incd x25, ALL, MUL #4\n" + "incd x24, ALL, MUL #4\n" + "ld1h { z16.h }, p0/Z, [x20]\n" + "incd x20, ALL, MUL #4\n" + "zip1 z16.h, z17.h, z16.h\n" + "st1h { z18.h }, p1, [x22]\n" + "st1h { z16.h }, p1, [x22, #1, MUL VL]\n" + "add x22, x22, %x[out_stride]\n" + "bgt 4b\n" + "5:" // Main row loop: Column loop skip + "cmp %x[height], #0x4\n" + "addvl %x[out], %x[out], #2\n" + "bge 1b\n" + "cbz %x[height], 12f\n" + "6:" // Main loop skip + "7:" // Tail row loop: Head + "mov x26, %x[in]\n" + "add x25, x26, %x[in_stride]\n" + "cmp %x[height], #0x1\n" + "mov x21, %x[width]\n" + "cnth x20, ALL, MUL #2\n" + "add %x[in], x25, %x[in_stride]\n" + "csel x25, x25, %x[pad_row], GT\n" + "cmp x21, x20\n" + "mov x22, %x[out]\n" + "sub %x[height], %x[height], #0x2\n" + "blt 9f\n" + "8:" // Tail row loop: Unroll column loop + "ld1h { z18.h }, p1/Z, [x26]\n" + "sub x21, x21, x20\n" + "cmp x21, x20\n" + "ld1h { z16.h }, p1/Z, [x25]\n" + "zip1 z17.h, z18.h, z16.h\n" + "zip2 z19.h, z18.h, z16.h\n" + "ld1h { z18.h }, p1/Z, [x26, #1, MUL VL]\n" + "addvl x26, x26, #2\n" + "ld1h { z16.h }, p1/Z, [x25, #1, MUL VL]\n" + "st1h { z17.h }, p1, [x22]\n" + "add x22, x22, %x[out_stride]\n" + "zip1 z17.h, z18.h, z16.h\n" + "st1h { z19.h }, p1, [x22]\n" + "add x22, x22, %x[out_stride]\n" + "addvl x25, x25, #2\n" + "zip2 z16.h, z18.h, z16.h\n" + "st1h { z17.h }, p1, [x22]\n" + "add x22, x22, %x[out_stride]\n" + "st1h { z16.h }, p1, [x22]\n" + "add x22, x22, %x[out_stride]\n" + "bge 8b\n" + "9:" // Tail row loop: Unroll column loop skip + "cbz x21, 11f\n" + "10:" // Tail row loop: Column loop + "whilelt p0.h, XZR, x21\n" + "ld1h { z17.h }, p0/Z, [x26]\n" + "decw x21\n" + "ld1h { z16.h }, p0/Z, [x25]\n" + "cmp x21, #0x0\n" + "incd x26, ALL, MUL #4\n" + "zip1 z16.h, z17.h, z16.h\n" + "incd x25, ALL, MUL #4\n" + "st1h { z16.h }, p1, [x22]\n" + "add x22, x22, %x[out_stride]\n" + "bgt 10b\n" + "11:" // Tail row loop: Column loop skip + "cmp %x[height], #0x1\n" + "addvl %x[out], %x[out], #1\n" + "bge 7b\n" + "12:" // Done + ".inst 0xd503467f // SMSTOP\n" + : [height] "+&r" (height), [in] "+&r" (in), [out] "+&r" (out) + : [in_stride] "r" (in_stride), [out_stride] "r" (out_stride), [pad_row] "r" (pad_row), [width] "r" (width) + : "cc", "memory", "p0", "p1", "p2", "p3", "p4", "p5", "p6", "p7", "p8", "p9", "p10", "p11", "p12", "p13", "p14", "p15", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "z0", "z1", "z2", "z3", "z4", "z5", "z6", "z7", "z8", "z9", "z10", "z11", "z12", "z13", "z14", "z15", "z16", "z17", "z18", "z19", "z20", "z21", "z22", "z23", "z24", "z25", "z26", "z27", "z28", "z29", "z30", "z31" + ); +} + +} // anonymous namespace + +template<> +void Transform<1, 2, true, VLType::SME>( + bfloat16 *out, const bfloat16 *in, int stride, int x0, int xmax, int k0, int kmax) +{ + sme_transpose_interleave_1VL_2x2( + reinterpret_cast<uint16_t *>(out), + reinterpret_cast<const uint16_t *>(in + k0 * stride + x0), + (xmax-x0) * sizeof(bfloat16) / 2, + stride * sizeof(bfloat16), + (kmax-k0) + ); +} + +template<> +void Transform<1, 2, true, VLType::SME>( + __fp16 *out, const __fp16 *in, int stride, int x0, int xmax, int k0, int kmax) +{ + sme_transpose_interleave_1VL_2x2( + reinterpret_cast<uint16_t *>(out), + reinterpret_cast<const uint16_t *>(in + k0 * stride + x0), + (xmax-x0) * sizeof(__fp16) / 2, + stride * sizeof(__fp16), + (kmax-k0) + ); +} + +#endif // defined(ARM_COMPUTE_ENABLE_SME) diff --git a/src/core/NEON/kernels/arm_gemm/transforms/sme_transpose_interleave_1VL_2x2_fp32bf16.hpp b/src/core/NEON/kernels/arm_gemm/transforms/sme_transpose_interleave_1VL_2x2_fp32bf16.hpp new file mode 100644 index 0000000000..b90063028d --- /dev/null +++ b/src/core/NEON/kernels/arm_gemm/transforms/sme_transpose_interleave_1VL_2x2_fp32bf16.hpp @@ -0,0 +1,222 @@ +/* + * Copyright (c) 2022-2023 Arm Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ + +#pragma once + +#if defined(ARM_COMPUTE_ENABLE_SME) + +namespace { + +void sme_transpose_interleave_1VL_2x2_fp32bf16(bfloat16 *out, const float *in, size_t width, size_t in_stride, size_t height) +{ + float *pad_row = reinterpret_cast<float *>(alloca(width * sizeof(float))); + + if (height % 2) { + memset(pad_row, 0, width * sizeof(float)); + } + + size_t out_stride = 1 * roundup<size_t>(height, 2) * sme::get_vector_length<uint16_t>(); + + __asm__ __volatile__( + ".inst 0xd503477f // SMSTART ZA\n" + "cmp %x[height], #0x4\n" + "ptrue p1.b\n" + "blt 6f\n" + "1:" // Main row loop: Head + "mov x26, %x[in]\n" + "add x25, x26, %x[in_stride]\n" + "add x24, x25, %x[in_stride]\n" + "mov x23, %x[width]\n" + "cnth x21, ALL, MUL #2\n" + "add x20, x24, %x[in_stride]\n" + "cmp x23, x21\n" + "add %x[in], x20, %x[in_stride]\n" + "mov x22, %x[out]\n" + "sub %x[height], %x[height], #0x4\n" + "blt 3f\n" + "2:" // Main row loop: Unroll column loop + "ld1w { z16.s }, p1/Z, [x26]\n" + ".inst 0x658aa618 // bfcvt z24.h, p1/M, z16.s\n" + "sub x23, x23, x21\n" + "cmp x23, x21\n" + "ld1w { z16.s }, p1/Z, [x24]\n" + ".inst 0x658aa617 // bfcvt z23.h, p1/M, z16.s\n" + "ld1w { z16.s }, p1/Z, [x26, #1, MUL VL]\n" + ".inst 0x658aa616 // bfcvt z22.h, p1/M, z16.s\n" + "ld1w { z16.s }, p1/Z, [x24, #1, MUL VL]\n" + ".inst 0x658aa615 // bfcvt z21.h, p1/M, z16.s\n" + "ld1w { z16.s }, p1/Z, [x26, #2, MUL VL]\n" + ".inst 0x658aa614 // bfcvt z20.h, p1/M, z16.s\n" + "ld1w { z16.s }, p1/Z, [x24, #2, MUL VL]\n" + ".inst 0x658aa613 // bfcvt z19.h, p1/M, z16.s\n" + "ld1w { z16.s }, p1/Z, [x26, #3, MUL VL]\n" + ".inst 0x658aa612 // bfcvt z18.h, p1/M, z16.s\n" + "addvl x26, x26, #4\n" + "ld1w { z16.s }, p1/Z, [x24, #3, MUL VL]\n" + ".inst 0x658aa611 // bfcvt z17.h, p1/M, z16.s\n" + "addvl x24, x24, #4\n" + "ld1w { z16.s }, p1/Z, [x25]\n" + ".inst 0x648aa618 // bfcvtnt z24.h, p1/M, z16.s\n" + "ld1w { z16.s }, p1/Z, [x20]\n" + ".inst 0x648aa617 // bfcvtnt z23.h, p1/M, z16.s\n" + "ld1w { z16.s }, p1/Z, [x25, #1, MUL VL]\n" + ".inst 0x648aa616 // bfcvtnt z22.h, p1/M, z16.s\n" + "ld1w { z16.s }, p1/Z, [x20, #1, MUL VL]\n" + ".inst 0x648aa615 // bfcvtnt z21.h, p1/M, z16.s\n" + "ld1w { z16.s }, p1/Z, [x25, #2, MUL VL]\n" + ".inst 0x648aa614 // bfcvtnt z20.h, p1/M, z16.s\n" + "ld1w { z16.s }, p1/Z, [x20, #2, MUL VL]\n" + ".inst 0x648aa613 // bfcvtnt z19.h, p1/M, z16.s\n" + "ld1w { z16.s }, p1/Z, [x25, #3, MUL VL]\n" + "addvl x25, x25, #4\n" + ".inst 0x648aa612 // bfcvtnt z18.h, p1/M, z16.s\n" + "ld1w { z16.s }, p1/Z, [x20, #3, MUL VL]\n" + "st1h { z24.h }, p1, [x22]\n" + "addvl x20, x20, #4\n" + ".inst 0x648aa611 // bfcvtnt z17.h, p1/M, z16.s\n" + "st1h { z23.h }, p1, [x22, #1, MUL VL]\n" + "add x22, x22, %x[out_stride]\n" + "st1h { z22.h }, p1, [x22]\n" + "st1h { z21.h }, p1, [x22, #1, MUL VL]\n" + "add x22, x22, %x[out_stride]\n" + "st1h { z20.h }, p1, [x22]\n" + "st1h { z19.h }, p1, [x22, #1, MUL VL]\n" + "add x22, x22, %x[out_stride]\n" + "st1h { z18.h }, p1, [x22]\n" + "st1h { z17.h }, p1, [x22, #1, MUL VL]\n" + "add x22, x22, %x[out_stride]\n" + "bge 2b\n" + "3:" // Main row loop: Unroll column loop skip + "cbz x23, 5f\n" + "4:" // Main row loop: Column loop + "whilelt p0.s, XZR, x23\n" + "ld1w { z16.s }, p0/Z, [x26]\n" + ".inst 0x658aa612 // bfcvt z18.h, p1/M, z16.s\n" + "decw x23\n" + "ld1w { z16.s }, p0/Z, [x24]\n" + ".inst 0x658aa611 // bfcvt z17.h, p1/M, z16.s\n" + "cmp x23, #0x0\n" + "addvl x26, x26, #1\n" + "ld1w { z16.s }, p0/Z, [x25]\n" + "addvl x25, x25, #1\n" + "addvl x24, x24, #1\n" + ".inst 0x648aa612 // bfcvtnt z18.h, p1/M, z16.s\n" + "ld1w { z16.s }, p0/Z, [x20]\n" + "addvl x20, x20, #1\n" + ".inst 0x648aa611 // bfcvtnt z17.h, p1/M, z16.s\n" + "st1h { z18.h }, p1, [x22]\n" + "st1h { z17.h }, p1, [x22, #1, MUL VL]\n" + "add x22, x22, %x[out_stride]\n" + "bgt 4b\n" + "5:" // Main row loop: Column loop skip + "cmp %x[height], #0x4\n" + "addvl %x[out], %x[out], #2\n" + "bge 1b\n" + "cbz %x[height], 12f\n" + "6:" // Main loop skip + "7:" // Tail row loop: Head + "mov x26, %x[in]\n" + "add x25, x26, %x[in_stride]\n" + "cmp %x[height], #0x1\n" + "mov x21, %x[width]\n" + "cnth x20, ALL, MUL #2\n" + "add %x[in], x25, %x[in_stride]\n" + "csel x25, x25, %x[pad_row], GT\n" + "cmp x21, x20\n" + "mov x22, %x[out]\n" + "sub %x[height], %x[height], #0x2\n" + "blt 9f\n" + "8:" // Tail row loop: Unroll column loop + "ld1w { z16.s }, p1/Z, [x26]\n" + ".inst 0x658aa614 // bfcvt z20.h, p1/M, z16.s\n" + "sub x21, x21, x20\n" + "cmp x21, x20\n" + "ld1w { z16.s }, p1/Z, [x26, #1, MUL VL]\n" + ".inst 0x658aa613 // bfcvt z19.h, p1/M, z16.s\n" + "ld1w { z16.s }, p1/Z, [x26, #2, MUL VL]\n" + ".inst 0x658aa612 // bfcvt z18.h, p1/M, z16.s\n" + "ld1w { z16.s }, p1/Z, [x26, #3, MUL VL]\n" + ".inst 0x658aa611 // bfcvt z17.h, p1/M, z16.s\n" + "addvl x26, x26, #4\n" + "ld1w { z16.s }, p1/Z, [x25]\n" + ".inst 0x648aa614 // bfcvtnt z20.h, p1/M, z16.s\n" + "ld1w { z16.s }, p1/Z, [x25, #1, MUL VL]\n" + ".inst 0x648aa613 // bfcvtnt z19.h, p1/M, z16.s\n" + "ld1w { z16.s }, p1/Z, [x25, #2, MUL VL]\n" + ".inst 0x648aa612 // bfcvtnt z18.h, p1/M, z16.s\n" + "ld1w { z16.s }, p1/Z, [x25, #3, MUL VL]\n" + "st1h { z20.h }, p1, [x22]\n" + "add x22, x22, %x[out_stride]\n" + "addvl x25, x25, #4\n" + "st1h { z19.h }, p1, [x22]\n" + "add x22, x22, %x[out_stride]\n" + ".inst 0x648aa611 // bfcvtnt z17.h, p1/M, z16.s\n" + "st1h { z18.h }, p1, [x22]\n" + "add x22, x22, %x[out_stride]\n" + "st1h { z17.h }, p1, [x22]\n" + "add x22, x22, %x[out_stride]\n" + "bge 8b\n" + "9:" // Tail row loop: Unroll column loop skip + "cbz x21, 11f\n" + "10:" // Tail row loop: Column loop + "whilelt p0.s, XZR, x21\n" + "ld1w { z16.s }, p0/Z, [x26]\n" + ".inst 0x658aa611 // bfcvt z17.h, p1/M, z16.s\n" + "decw x21\n" + "ld1w { z16.s }, p0/Z, [x25]\n" + "cmp x21, #0x0\n" + "addvl x26, x26, #1\n" + ".inst 0x648aa611 // bfcvtnt z17.h, p1/M, z16.s\n" + "addvl x25, x25, #1\n" + "st1h { z17.h }, p1, [x22]\n" + "add x22, x22, %x[out_stride]\n" + "bgt 10b\n" + "11:" // Tail row loop: Column loop skip + "cmp %x[height], #0x1\n" + "addvl %x[out], %x[out], #1\n" + "bge 7b\n" + "12:" // Done + ".inst 0xd503467f // SMSTOP\n" + : [height] "+&r" (height), [in] "+&r" (in), [out] "+&r" (out) + : [in_stride] "r" (in_stride), [out_stride] "r" (out_stride), [pad_row] "r" (pad_row), [width] "r" (width) + : "cc", "memory", "p0", "p1", "p2", "p3", "p4", "p5", "p6", "p7", "p8", "p9", "p10", "p11", "p12", "p13", "p14", "p15", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "z0", "z1", "z2", "z3", "z4", "z5", "z6", "z7", "z8", "z9", "z10", "z11", "z12", "z13", "z14", "z15", "z16", "z17", "z18", "z19", "z20", "z21", "z22", "z23", "z24", "z25", "z26", "z27", "z28", "z29", "z30", "z31" + ); +} + +} // anonymous namespace +template<> +void Transform<1, 2, true, VLType::SME>( + bfloat16 *out, const float *in, int stride, int x0, int xmax, int k0, int kmax) +{ + sme_transpose_interleave_1VL_2x2_fp32bf16( + out, + in + k0 * stride + x0, + (xmax-x0), + stride * sizeof(float), + (kmax-k0) + ); +} + + +#endif // defined(ARM_COMPUTE_ENABLE_SME) diff --git a/src/core/NEON/kernels/arm_gemm/transforms/sme_transpose_interleave_2VL.hpp b/src/core/NEON/kernels/arm_gemm/transforms/sme_transpose_interleave_2VL.hpp new file mode 100644 index 0000000000..f827197ab7 --- /dev/null +++ b/src/core/NEON/kernels/arm_gemm/transforms/sme_transpose_interleave_2VL.hpp @@ -0,0 +1,225 @@ +/* + * Copyright (c) 2022-2023 Arm Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ + +#pragma once + +#if defined(ARM_COMPUTE_ENABLE_SME) + +namespace { + +void sme_transpose_interleave_2VL(uint16_t *out, const uint16_t *in, size_t width, size_t in_stride, size_t height) +{ + size_t out_stride = 2 * height * sme::get_vector_length<uint8_t>(); + + __asm__ __volatile__( + ".inst 0xd503477f // SMSTART ZA\n" + "cmp %x[height], #0x4\n" + "ptrue p2.b\n" + "blt 6f\n" + "1:" // Main row loop: Head + "mov x26, %x[in]\n" + "add x25, x26, %x[in_stride]\n" + "add x24, x25, %x[in_stride]\n" + "mov x23, %x[width]\n" + "cnth x20, ALL, MUL #4\n" + "add x21, x24, %x[in_stride]\n" + "cmp x23, x20\n" + "add %x[in], x21, %x[in_stride]\n" + "mov x22, %x[out]\n" + "sub %x[height], %x[height], #0x4\n" + "blt 3f\n" + "2:" // Main row loop: Unroll column loop + "sub x23, x23, x20\n" + "ld1h { z31.h }, p2/Z, [x26]\n" + "cmp x23, x20\n" + "ld1h { z30.h }, p2/Z, [x26, #1, MUL VL]\n" + "ld1h { z29.h }, p2/Z, [x26, #2, MUL VL]\n" + "ld1h { z28.h }, p2/Z, [x26, #3, MUL VL]\n" + "addvl x26, x26, #4\n" + "ld1h { z27.h }, p2/Z, [x25]\n" + "ld1h { z26.h }, p2/Z, [x25, #1, MUL VL]\n" + "ld1h { z25.h }, p2/Z, [x25, #2, MUL VL]\n" + "ld1h { z24.h }, p2/Z, [x25, #3, MUL VL]\n" + "addvl x25, x25, #4\n" + "ld1h { z23.h }, p2/Z, [x24]\n" + "ld1h { z22.h }, p2/Z, [x24, #1, MUL VL]\n" + "ld1h { z21.h }, p2/Z, [x24, #2, MUL VL]\n" + "ld1h { z20.h }, p2/Z, [x24, #3, MUL VL]\n" + "addvl x24, x24, #4\n" + "ld1h { z19.h }, p2/Z, [x21]\n" + "ld1h { z18.h }, p2/Z, [x21, #1, MUL VL]\n" + "ld1h { z17.h }, p2/Z, [x21, #2, MUL VL]\n" + "ld1h { z16.h }, p2/Z, [x21, #3, MUL VL]\n" + "st1h { z31.h }, p2, [x22]\n" + "addvl x21, x21, #4\n" + "st1h { z30.h }, p2, [x22, #1, MUL VL]\n" + "st1h { z27.h }, p2, [x22, #2, MUL VL]\n" + "st1h { z26.h }, p2, [x22, #3, MUL VL]\n" + "st1h { z23.h }, p2, [x22, #4, MUL VL]\n" + "st1h { z22.h }, p2, [x22, #5, MUL VL]\n" + "st1h { z19.h }, p2, [x22, #6, MUL VL]\n" + "st1h { z18.h }, p2, [x22, #7, MUL VL]\n" + "add x22, x22, %x[out_stride]\n" + "st1h { z29.h }, p2, [x22]\n" + "st1h { z28.h }, p2, [x22, #1, MUL VL]\n" + "st1h { z25.h }, p2, [x22, #2, MUL VL]\n" + "st1h { z24.h }, p2, [x22, #3, MUL VL]\n" + "st1h { z21.h }, p2, [x22, #4, MUL VL]\n" + "st1h { z20.h }, p2, [x22, #5, MUL VL]\n" + "st1h { z17.h }, p2, [x22, #6, MUL VL]\n" + "st1h { z16.h }, p2, [x22, #7, MUL VL]\n" + "add x22, x22, %x[out_stride]\n" + "bge 2b\n" + "3:" // Main row loop: Unroll column loop skip + "cbz x23, 5f\n" + "4:" // Main row loop: Column loop + "mov x20, x23\n" + "whilelt p1.h, XZR, x20\n" + "ld1h { z23.h }, p1/Z, [x26]\n" + "dech x20\n" + "dech x23, ALL, MUL #2\n" + "ld1h { z22.h }, p1/Z, [x25]\n" + "whilelt p0.h, XZR, x20\n" + "cmp x23, #0x0\n" + "ld1h { z21.h }, p0/Z, [x26, #1, MUL VL]\n" + "addvl x26, x26, #2\n" + "ld1h { z20.h }, p0/Z, [x25, #1, MUL VL]\n" + "addvl x25, x25, #2\n" + "ld1h { z19.h }, p1/Z, [x24]\n" + "ld1h { z18.h }, p0/Z, [x24, #1, MUL VL]\n" + "addvl x24, x24, #2\n" + "ld1h { z17.h }, p1/Z, [x21]\n" + "ld1h { z16.h }, p0/Z, [x21, #1, MUL VL]\n" + "addvl x21, x21, #2\n" + "st1h { z23.h }, p2, [x22]\n" + "st1h { z21.h }, p2, [x22, #1, MUL VL]\n" + "st1h { z22.h }, p2, [x22, #2, MUL VL]\n" + "st1h { z20.h }, p2, [x22, #3, MUL VL]\n" + "st1h { z19.h }, p2, [x22, #4, MUL VL]\n" + "st1h { z18.h }, p2, [x22, #5, MUL VL]\n" + "st1h { z17.h }, p2, [x22, #6, MUL VL]\n" + "st1h { z16.h }, p2, [x22, #7, MUL VL]\n" + "add x22, x22, %x[out_stride]\n" + "bgt 4b\n" + "5:" // Main row loop: Column loop skip + "cmp %x[height], #0x4\n" + "addvl %x[out], %x[out], #8\n" + "bge 1b\n" + "cbz %x[height], 12f\n" + "6:" // Main loop skip + "7:" // Tail row loop: Head + "mov x21, %x[width]\n" + "cnth x20, ALL, MUL #4\n" + "mov x26, %x[in]\n" + "cmp x21, x20\n" + "add %x[in], x26, %x[in_stride]\n" + "mov x22, %x[out]\n" + "sub %x[height], %x[height], #0x1\n" + "blt 9f\n" + "8:" // Tail row loop: Unroll column loop + "sub x21, x21, x20\n" + "ld1h { z19.h }, p2/Z, [x26]\n" + "cmp x21, x20\n" + "ld1h { z18.h }, p2/Z, [x26, #1, MUL VL]\n" + "ld1h { z17.h }, p2/Z, [x26, #2, MUL VL]\n" + "ld1h { z16.h }, p2/Z, [x26, #3, MUL VL]\n" + "st1h { z19.h }, p2, [x22]\n" + "addvl x26, x26, #4\n" + "st1h { z18.h }, p2, [x22, #1, MUL VL]\n" + "add x22, x22, %x[out_stride]\n" + "st1h { z17.h }, p2, [x22]\n" + "st1h { z16.h }, p2, [x22, #1, MUL VL]\n" + "add x22, x22, %x[out_stride]\n" + "bge 8b\n" + "9:" // Tail row loop: Unroll column loop skip + "cbz x21, 11f\n" + "10:" // Tail row loop: Column loop + "mov x20, x21\n" + "whilelt p0.h, XZR, x20\n" + "ld1h { z17.h }, p0/Z, [x26]\n" + "dech x20\n" + "dech x21, ALL, MUL #2\n" + "whilelt p0.h, XZR, x20\n" + "cmp x21, #0x0\n" + "ld1h { z16.h }, p0/Z, [x26, #1, MUL VL]\n" + "st1h { z17.h }, p2, [x22]\n" + "addvl x26, x26, #2\n" + "st1h { z16.h }, p2, [x22, #1, MUL VL]\n" + "add x22, x22, %x[out_stride]\n" + "bgt 10b\n" + "11:" // Tail row loop: Column loop skip + "cmp %x[height], #0x1\n" + "addvl %x[out], %x[out], #2\n" + "bge 7b\n" + "12:" // Done + ".inst 0xd503467f // SMSTOP\n" + : [height] "+&r" (height), [in] "+&r" (in), [out] "+&r" (out) + : [in_stride] "r" (in_stride), [out_stride] "r" (out_stride), [width] "r" (width) + : "cc", "memory", "p0", "p1", "p2", "p3", "p4", "p5", "p6", "p7", "p8", "p9", "p10", "p11", "p12", "p13", "p14", "p15", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "z0", "z1", "z2", "z3", "z4", "z5", "z6", "z7", "z8", "z9", "z10", "z11", "z12", "z13", "z14", "z15", "z16", "z17", "z18", "z19", "z20", "z21", "z22", "z23", "z24", "z25", "z26", "z27", "z28", "z29", "z30", "z31" + ); +} + +} // anonymous namespace + +template<> +void Transform<2, 1, true, VLType::SME>( + float *out, const float *in, int stride, int x0, int xmax, int k0, int kmax) +{ + sme_transpose_interleave_2VL( + reinterpret_cast<uint16_t *>(out), + reinterpret_cast<const uint16_t *>(in + k0 * stride + x0), + (xmax-x0) * sizeof(float) / 2, + stride * sizeof(float), + (kmax-k0) + ); +} + +template<> +void Transform<2, 1, true, VLType::SME>( + bfloat16 *out, const bfloat16 *in, int stride, int x0, int xmax, int k0, int kmax) +{ + sme_transpose_interleave_2VL( + reinterpret_cast<uint16_t *>(out), + reinterpret_cast<const uint16_t *>(in + k0 * stride + x0), + (xmax-x0) * sizeof(bfloat16) / 2, + stride * sizeof(bfloat16), + (kmax-k0) + ); +} + +template<> +void Transform<2, 1, true, VLType::SME>( + __fp16 *out, const __fp16 *in, int stride, int x0, int xmax, int k0, int kmax) +{ + sme_transpose_interleave_2VL( + reinterpret_cast<uint16_t *>(out), + reinterpret_cast<const uint16_t *>(in + k0 * stride + x0), + (xmax-x0) * sizeof(__fp16) / 2, + stride * sizeof(__fp16), + (kmax-k0) + ); +} + + +#endif // defined(ARM_COMPUTE_ENABLE_SME) diff --git a/src/core/NEON/kernels/arm_gemm/transforms/sme_transpose_interleave_2VL_1x4.hpp b/src/core/NEON/kernels/arm_gemm/transforms/sme_transpose_interleave_2VL_1x4.hpp new file mode 100644 index 0000000000..c471d66e17 --- /dev/null +++ b/src/core/NEON/kernels/arm_gemm/transforms/sme_transpose_interleave_2VL_1x4.hpp @@ -0,0 +1,149 @@ +/* + * Copyright (c) 2022-2023 Arm Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ + +#pragma once + +#if defined(ARM_COMPUTE_ENABLE_SME) + +namespace { + +void sme_transpose_interleave_2VL_1x4(uint8_t *out, const uint8_t *in, size_t width, size_t in_stride, size_t height) +{ + uint8_t *pad_row = reinterpret_cast<uint8_t *>(alloca(width * sizeof(uint8_t))); + + if (height % 4) { + memset(pad_row, 0, width * sizeof(uint8_t)); + } + + size_t out_stride = 2 * roundup<size_t>(height, 4) * sme::get_vector_length<uint32_t>(); + + __asm__ __volatile__( + ".inst 0xd503477f // SMSTART ZA\n" + "ptrue p1.b\n" + "1:" // Main row loop: Head + "mov x26, %x[in]\n" + "add x25, x26, %x[in_stride]\n" + "add x24, x25, %x[in_stride]\n" + "add x23, x24, %x[in_stride]\n" + "cmp %x[height], #0x3\n" + "add %x[in], x23, %x[in_stride]\n" + "csel x23, x23, %x[pad_row], GT\n" + "csel x24, x24, %x[pad_row], GE\n" + "cmp %x[height], #0x1\n" + "mov x22, %x[width]\n" + "cntb x21\n" + "csel x25, x25, %x[pad_row], GT\n" + "cmp x22, x21\n" + "mov x20, %x[out]\n" + "sub %x[height], %x[height], #0x4\n" + "blt 3f\n" + "2:" // Main row loop: Unroll column loop + "ld1b { z17.b }, p1/Z, [x26]\n" + "sub x22, x22, x21\n" + "cmp x22, x21\n" + "ld1b { z18.b }, p1/Z, [x25]\n" + "addvl x26, x26, #1\n" + "addvl x25, x25, #1\n" + "ld1b { z16.b }, p1/Z, [x24]\n" + "zip1 z20.b, z17.b, z16.b\n" + "zip2 z19.b, z17.b, z16.b\n" + "addvl x24, x24, #1\n" + "ld1b { z17.b }, p1/Z, [x23]\n" + "zip1 z16.b, z18.b, z17.b\n" + "zip2 z18.b, z18.b, z17.b\n" + "addvl x23, x23, #1\n" + "zip1 z17.b, z20.b, z16.b\n" + "zip2 z16.b, z20.b, z16.b\n" + "st1b { z17.b }, p1, [x20]\n" + "st1b { z16.b }, p1, [x20, #1, MUL VL]\n" + "add x20, x20, %x[out_stride]\n" + "zip1 z17.b, z19.b, z18.b\n" + "zip2 z16.b, z19.b, z18.b\n" + "st1b { z17.b }, p1, [x20]\n" + "st1b { z16.b }, p1, [x20, #1, MUL VL]\n" + "add x20, x20, %x[out_stride]\n" + "bge 2b\n" + "3:" // Main row loop: Unroll column loop skip + "cbz x22, 5f\n" + "4:" // Main row loop: Column loop + "whilelt p0.b, XZR, x22\n" + "ld1b { z18.b }, p0/Z, [x26]\n" + "decw x22, ALL, MUL #2\n" + "ld1b { z17.b }, p0/Z, [x25]\n" + "cmp x22, #0x0\n" + "incd x26, ALL, MUL #4\n" + "ld1b { z16.b }, p0/Z, [x24]\n" + "zip1 z18.b, z18.b, z16.b\n" + "incd x25, ALL, MUL #4\n" + "incd x24, ALL, MUL #4\n" + "ld1b { z16.b }, p0/Z, [x23]\n" + "zip1 z16.b, z17.b, z16.b\n" + "incd x23, ALL, MUL #4\n" + "zip1 z17.b, z18.b, z16.b\n" + "zip2 z16.b, z18.b, z16.b\n" + "st1b { z17.b }, p1, [x20]\n" + "st1b { z16.b }, p1, [x20, #1, MUL VL]\n" + "add x20, x20, %x[out_stride]\n" + "bgt 4b\n" + "5:" // Main row loop: Column loop skip + "cmp %x[height], #0x1\n" + "addvl %x[out], %x[out], #2\n" + "bge 1b\n" + ".inst 0xd503467f // SMSTOP\n" + : [height] "+&r" (height), [in] "+&r" (in), [out] "+&r" (out) + : [in_stride] "r" (in_stride), [out_stride] "r" (out_stride), [pad_row] "r" (pad_row), [width] "r" (width) + : "cc", "memory", "p0", "p1", "p2", "p3", "p4", "p5", "p6", "p7", "p8", "p9", "p10", "p11", "p12", "p13", "p14", "p15", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "z0", "z1", "z2", "z3", "z4", "z5", "z6", "z7", "z8", "z9", "z10", "z11", "z12", "z13", "z14", "z15", "z16", "z17", "z18", "z19", "z20", "z21", "z22", "z23", "z24", "z25", "z26", "z27", "z28", "z29", "z30", "z31" + ); +} + +} // anonymous namespace + +template<> +void Transform<2, 4, true, VLType::SME>( + uint8_t *out, const uint8_t *in, int stride, int x0, int xmax, int k0, int kmax) +{ + sme_transpose_interleave_2VL_1x4( + reinterpret_cast<uint8_t *>(out), + reinterpret_cast<const uint8_t *>(in + k0 * stride + x0), + (xmax-x0) * sizeof(uint8_t) / 1, + stride * sizeof(uint8_t), + (kmax-k0) + ); +} + +template<> +void Transform<2, 4, true, VLType::SME>( + int8_t *out, const int8_t *in, int stride, int x0, int xmax, int k0, int kmax) +{ + sme_transpose_interleave_2VL_1x4( + reinterpret_cast<uint8_t *>(out), + reinterpret_cast<const uint8_t *>(in + k0 * stride + x0), + (xmax-x0) * sizeof(int8_t) / 1, + stride * sizeof(int8_t), + (kmax-k0) + ); +} + + +#endif // defined(ARM_COMPUTE_ENABLE_SME) diff --git a/src/core/NEON/kernels/arm_gemm/transforms/sme_transpose_interleave_2VL_2x2.hpp b/src/core/NEON/kernels/arm_gemm/transforms/sme_transpose_interleave_2VL_2x2.hpp new file mode 100644 index 0000000000..5f967fa615 --- /dev/null +++ b/src/core/NEON/kernels/arm_gemm/transforms/sme_transpose_interleave_2VL_2x2.hpp @@ -0,0 +1,211 @@ +/* + * Copyright (c) 2022-2023 Arm Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ + +#pragma once + +#if defined(ARM_COMPUTE_ENABLE_SME) + +namespace { + +void sme_transpose_interleave_2VL_2x2(uint16_t *out, const uint16_t *in, size_t width, size_t in_stride, size_t height) +{ + uint16_t *pad_row = reinterpret_cast<uint16_t *>(alloca(width * sizeof(uint16_t))); + + if (height % 2) { + memset(pad_row, 0, width * sizeof(uint16_t)); + } + + size_t out_stride = 2 * roundup<size_t>(height, 2) * sme::get_vector_length<uint16_t>(); + + __asm__ __volatile__( + ".inst 0xd503477f // SMSTART ZA\n" + "cmp %x[height], #0x4\n" + "ptrue p1.b\n" + "blt 6f\n" + "1:" // Main row loop: Head + "mov x26, %x[in]\n" + "add x25, x26, %x[in_stride]\n" + "add x24, x25, %x[in_stride]\n" + "mov x23, %x[width]\n" + "cnth x21, ALL, MUL #2\n" + "add x20, x24, %x[in_stride]\n" + "cmp x23, x21\n" + "add %x[in], x20, %x[in_stride]\n" + "mov x22, %x[out]\n" + "sub %x[height], %x[height], #0x4\n" + "blt 3f\n" + "2:" // Main row loop: Unroll column loop + "ld1h { z17.h }, p1/Z, [x26]\n" + "sub x23, x23, x21\n" + "cmp x23, x21\n" + "ld1h { z16.h }, p1/Z, [x25]\n" + "zip1 z24.h, z17.h, z16.h\n" + "zip2 z23.h, z17.h, z16.h\n" + "ld1h { z17.h }, p1/Z, [x24]\n" + "ld1h { z16.h }, p1/Z, [x20]\n" + "zip1 z22.h, z17.h, z16.h\n" + "zip2 z21.h, z17.h, z16.h\n" + "ld1h { z17.h }, p1/Z, [x26, #1, MUL VL]\n" + "addvl x26, x26, #2\n" + "ld1h { z16.h }, p1/Z, [x25, #1, MUL VL]\n" + "addvl x25, x25, #2\n" + "zip1 z20.h, z17.h, z16.h\n" + "zip2 z19.h, z17.h, z16.h\n" + "ld1h { z18.h }, p1/Z, [x24, #1, MUL VL]\n" + "addvl x24, x24, #2\n" + "ld1h { z16.h }, p1/Z, [x20, #1, MUL VL]\n" + "st1h { z24.h }, p1, [x22]\n" + "addvl x20, x20, #2\n" + "zip1 z17.h, z18.h, z16.h\n" + "st1h { z23.h }, p1, [x22, #1, MUL VL]\n" + "zip2 z16.h, z18.h, z16.h\n" + "st1h { z22.h }, p1, [x22, #2, MUL VL]\n" + "st1h { z21.h }, p1, [x22, #3, MUL VL]\n" + "add x22, x22, %x[out_stride]\n" + "st1h { z20.h }, p1, [x22]\n" + "st1h { z19.h }, p1, [x22, #1, MUL VL]\n" + "st1h { z17.h }, p1, [x22, #2, MUL VL]\n" + "st1h { z16.h }, p1, [x22, #3, MUL VL]\n" + "add x22, x22, %x[out_stride]\n" + "bge 2b\n" + "3:" // Main row loop: Unroll column loop skip + "cbz x23, 5f\n" + "4:" // Main row loop: Column loop + "whilelt p0.h, XZR, x23\n" + "ld1h { z17.h }, p0/Z, [x26]\n" + "decw x23, ALL, MUL #2\n" + "ld1h { z16.h }, p0/Z, [x25]\n" + "cmp x23, #0x0\n" + "addvl x26, x26, #1\n" + "zip1 z20.h, z17.h, z16.h\n" + "ld1h { z19.h }, p0/Z, [x24]\n" + "addvl x25, x25, #1\n" + "addvl x24, x24, #1\n" + "zip2 z18.h, z17.h, z16.h\n" + "ld1h { z16.h }, p0/Z, [x20]\n" + "addvl x20, x20, #1\n" + "zip1 z17.h, z19.h, z16.h\n" + "zip2 z16.h, z19.h, z16.h\n" + "st1h { z20.h }, p1, [x22]\n" + "st1h { z18.h }, p1, [x22, #1, MUL VL]\n" + "st1h { z17.h }, p1, [x22, #2, MUL VL]\n" + "st1h { z16.h }, p1, [x22, #3, MUL VL]\n" + "add x22, x22, %x[out_stride]\n" + "bgt 4b\n" + "5:" // Main row loop: Column loop skip + "cmp %x[height], #0x4\n" + "addvl %x[out], %x[out], #4\n" + "bge 1b\n" + "cbz %x[height], 12f\n" + "6:" // Main loop skip + "7:" // Tail row loop: Head + "mov x26, %x[in]\n" + "add x25, x26, %x[in_stride]\n" + "cmp %x[height], #0x1\n" + "mov x21, %x[width]\n" + "cnth x20, ALL, MUL #2\n" + "add %x[in], x25, %x[in_stride]\n" + "csel x25, x25, %x[pad_row], GT\n" + "cmp x21, x20\n" + "mov x22, %x[out]\n" + "sub %x[height], %x[height], #0x2\n" + "blt 9f\n" + "8:" // Tail row loop: Unroll column loop + "ld1h { z18.h }, p1/Z, [x26]\n" + "sub x21, x21, x20\n" + "cmp x21, x20\n" + "ld1h { z16.h }, p1/Z, [x25]\n" + "zip1 z17.h, z18.h, z16.h\n" + "zip2 z19.h, z18.h, z16.h\n" + "ld1h { z18.h }, p1/Z, [x26, #1, MUL VL]\n" + "addvl x26, x26, #2\n" + "ld1h { z16.h }, p1/Z, [x25, #1, MUL VL]\n" + "st1h { z17.h }, p1, [x22]\n" + "addvl x25, x25, #2\n" + "zip1 z17.h, z18.h, z16.h\n" + "st1h { z19.h }, p1, [x22, #1, MUL VL]\n" + "add x22, x22, %x[out_stride]\n" + "zip2 z16.h, z18.h, z16.h\n" + "st1h { z17.h }, p1, [x22]\n" + "st1h { z16.h }, p1, [x22, #1, MUL VL]\n" + "add x22, x22, %x[out_stride]\n" + "bge 8b\n" + "9:" // Tail row loop: Unroll column loop skip + "cbz x21, 11f\n" + "10:" // Tail row loop: Column loop + "whilelt p0.h, XZR, x21\n" + "ld1h { z18.h }, p0/Z, [x26]\n" + "decw x21, ALL, MUL #2\n" + "ld1h { z16.h }, p0/Z, [x25]\n" + "cmp x21, #0x0\n" + "addvl x26, x26, #1\n" + "zip1 z17.h, z18.h, z16.h\n" + "addvl x25, x25, #1\n" + "zip2 z16.h, z18.h, z16.h\n" + "st1h { z17.h }, p1, [x22]\n" + "st1h { z16.h }, p1, [x22, #1, MUL VL]\n" + "add x22, x22, %x[out_stride]\n" + "bgt 10b\n" + "11:" // Tail row loop: Column loop skip + "cmp %x[height], #0x1\n" + "addvl %x[out], %x[out], #2\n" + "bge 7b\n" + "12:" // Done + ".inst 0xd503467f // SMSTOP\n" + : [height] "+&r" (height), [in] "+&r" (in), [out] "+&r" (out) + : [in_stride] "r" (in_stride), [out_stride] "r" (out_stride), [pad_row] "r" (pad_row), [width] "r" (width) + : "cc", "memory", "p0", "p1", "p2", "p3", "p4", "p5", "p6", "p7", "p8", "p9", "p10", "p11", "p12", "p13", "p14", "p15", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "z0", "z1", "z2", "z3", "z4", "z5", "z6", "z7", "z8", "z9", "z10", "z11", "z12", "z13", "z14", "z15", "z16", "z17", "z18", "z19", "z20", "z21", "z22", "z23", "z24", "z25", "z26", "z27", "z28", "z29", "z30", "z31" + ); +} + +} // anonymous namespace + +template<> +void Transform<2, 2, true, VLType::SME>( + bfloat16 *out, const bfloat16 *in, int stride, int x0, int xmax, int k0, int kmax) +{ + sme_transpose_interleave_2VL_2x2( + reinterpret_cast<uint16_t *>(out), + reinterpret_cast<const uint16_t *>(in + k0 * stride + x0), + (xmax-x0) * sizeof(bfloat16) / 2, + stride * sizeof(bfloat16), + (kmax-k0) + ); +} + +template<> +void Transform<2, 2, true, VLType::SME>( + __fp16 *out, const __fp16 *in, int stride, int x0, int xmax, int k0, int kmax) +{ + sme_transpose_interleave_2VL_2x2( + reinterpret_cast<uint16_t *>(out), + reinterpret_cast<const uint16_t *>(in + k0 * stride + x0), + (xmax-x0) * sizeof(__fp16) / 2, + stride * sizeof(__fp16), + (kmax-k0) + ); +} + + +#endif // defined(ARM_COMPUTE_ENABLE_SME) diff --git a/src/core/NEON/kernels/arm_gemm/transforms/sme_transpose_interleave_2VL_2x2_fp32bf16.hpp b/src/core/NEON/kernels/arm_gemm/transforms/sme_transpose_interleave_2VL_2x2_fp32bf16.hpp new file mode 100644 index 0000000000..f22b833821 --- /dev/null +++ b/src/core/NEON/kernels/arm_gemm/transforms/sme_transpose_interleave_2VL_2x2_fp32bf16.hpp @@ -0,0 +1,239 @@ +/* + * Copyright (c) 2022-2023 Arm Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ + +#pragma once + +#if defined(ARM_COMPUTE_ENABLE_SME) + +namespace { + +void sme_transpose_interleave_2VL_2x2_fp32bf16(bfloat16 *out, const float *in, size_t width, size_t in_stride, size_t height) +{ + float *pad_row = reinterpret_cast<float *>(alloca(width * sizeof(float))); + + if (height % 2) { + memset(pad_row, 0, width * sizeof(float)); + } + + size_t out_stride = 2 * roundup<size_t>(height, 2) * sme::get_vector_length<uint16_t>(); + + __asm__ __volatile__( + ".inst 0xd503477f // SMSTART ZA\n" + "cmp %x[height], #0x4\n" + "ptrue p2.b\n" + "blt 6f\n" + "1:" // Main row loop: Head + "mov x26, %x[in]\n" + "add x25, x26, %x[in_stride]\n" + "add x24, x25, %x[in_stride]\n" + "mov x23, %x[width]\n" + "cnth x20, ALL, MUL #2\n" + "add x21, x24, %x[in_stride]\n" + "cmp x23, x20\n" + "add %x[in], x21, %x[in_stride]\n" + "mov x22, %x[out]\n" + "sub %x[height], %x[height], #0x4\n" + "blt 3f\n" + "2:" // Main row loop: Unroll column loop + "ld1w { z16.s }, p2/Z, [x26]\n" + ".inst 0x658aaa18 // bfcvt z24.h, p2/M, z16.s\n" + "sub x23, x23, x20\n" + "cmp x23, x20\n" + "ld1w { z16.s }, p2/Z, [x26, #1, MUL VL]\n" + ".inst 0x658aaa17 // bfcvt z23.h, p2/M, z16.s\n" + "ld1w { z16.s }, p2/Z, [x24]\n" + ".inst 0x658aaa16 // bfcvt z22.h, p2/M, z16.s\n" + "ld1w { z16.s }, p2/Z, [x24, #1, MUL VL]\n" + ".inst 0x658aaa15 // bfcvt z21.h, p2/M, z16.s\n" + "ld1w { z16.s }, p2/Z, [x26, #2, MUL VL]\n" + ".inst 0x658aaa14 // bfcvt z20.h, p2/M, z16.s\n" + "ld1w { z16.s }, p2/Z, [x26, #3, MUL VL]\n" + ".inst 0x658aaa13 // bfcvt z19.h, p2/M, z16.s\n" + "addvl x26, x26, #4\n" + "ld1w { z16.s }, p2/Z, [x24, #2, MUL VL]\n" + ".inst 0x658aaa12 // bfcvt z18.h, p2/M, z16.s\n" + "ld1w { z16.s }, p2/Z, [x24, #3, MUL VL]\n" + ".inst 0x658aaa11 // bfcvt z17.h, p2/M, z16.s\n" + "addvl x24, x24, #4\n" + "ld1w { z16.s }, p2/Z, [x25]\n" + ".inst 0x648aaa18 // bfcvtnt z24.h, p2/M, z16.s\n" + "ld1w { z16.s }, p2/Z, [x25, #1, MUL VL]\n" + ".inst 0x648aaa17 // bfcvtnt z23.h, p2/M, z16.s\n" + "ld1w { z16.s }, p2/Z, [x21]\n" + ".inst 0x648aaa16 // bfcvtnt z22.h, p2/M, z16.s\n" + "ld1w { z16.s }, p2/Z, [x21, #1, MUL VL]\n" + ".inst 0x648aaa15 // bfcvtnt z21.h, p2/M, z16.s\n" + "ld1w { z16.s }, p2/Z, [x25, #2, MUL VL]\n" + ".inst 0x648aaa14 // bfcvtnt z20.h, p2/M, z16.s\n" + "ld1w { z16.s }, p2/Z, [x25, #3, MUL VL]\n" + "addvl x25, x25, #4\n" + ".inst 0x648aaa13 // bfcvtnt z19.h, p2/M, z16.s\n" + "ld1w { z16.s }, p2/Z, [x21, #2, MUL VL]\n" + ".inst 0x648aaa12 // bfcvtnt z18.h, p2/M, z16.s\n" + "ld1w { z16.s }, p2/Z, [x21, #3, MUL VL]\n" + "st1h { z24.h }, p2, [x22]\n" + "addvl x21, x21, #4\n" + ".inst 0x648aaa11 // bfcvtnt z17.h, p2/M, z16.s\n" + "st1h { z23.h }, p2, [x22, #1, MUL VL]\n" + "st1h { z22.h }, p2, [x22, #2, MUL VL]\n" + "st1h { z21.h }, p2, [x22, #3, MUL VL]\n" + "add x22, x22, %x[out_stride]\n" + "st1h { z20.h }, p2, [x22]\n" + "st1h { z19.h }, p2, [x22, #1, MUL VL]\n" + "st1h { z18.h }, p2, [x22, #2, MUL VL]\n" + "st1h { z17.h }, p2, [x22, #3, MUL VL]\n" + "add x22, x22, %x[out_stride]\n" + "bge 2b\n" + "3:" // Main row loop: Unroll column loop skip + "cbz x23, 5f\n" + "4:" // Main row loop: Column loop + "mov x20, x23\n" + "whilelt p1.s, XZR, x20\n" + "ld1w { z16.s }, p1/Z, [x26]\n" + ".inst 0x658aaa14 // bfcvt z20.h, p2/M, z16.s\n" + "decw x20\n" + "whilelt p0.s, XZR, x20\n" + "ld1w { z16.s }, p0/Z, [x26, #1, MUL VL]\n" + ".inst 0x658aaa13 // bfcvt z19.h, p2/M, z16.s\n" + "ld1w { z16.s }, p1/Z, [x24]\n" + ".inst 0x658aaa12 // bfcvt z18.h, p2/M, z16.s\n" + "decw x23, ALL, MUL #2\n" + "cmp x23, #0x0\n" + "ld1w { z16.s }, p0/Z, [x24, #1, MUL VL]\n" + ".inst 0x658aaa11 // bfcvt z17.h, p2/M, z16.s\n" + "addvl x26, x26, #2\n" + "addvl x24, x24, #2\n" + "ld1w { z16.s }, p1/Z, [x25]\n" + ".inst 0x648aaa14 // bfcvtnt z20.h, p2/M, z16.s\n" + "ld1w { z16.s }, p0/Z, [x25, #1, MUL VL]\n" + "addvl x25, x25, #2\n" + ".inst 0x648aaa13 // bfcvtnt z19.h, p2/M, z16.s\n" + "ld1w { z16.s }, p1/Z, [x21]\n" + ".inst 0x648aaa12 // bfcvtnt z18.h, p2/M, z16.s\n" + "ld1w { z16.s }, p0/Z, [x21, #1, MUL VL]\n" + "addvl x21, x21, #2\n" + ".inst 0x648aaa11 // bfcvtnt z17.h, p2/M, z16.s\n" + "st1h { z20.h }, p2, [x22]\n" + "st1h { z19.h }, p2, [x22, #1, MUL VL]\n" + "st1h { z18.h }, p2, [x22, #2, MUL VL]\n" + "st1h { z17.h }, p2, [x22, #3, MUL VL]\n" + "add x22, x22, %x[out_stride]\n" + "bgt 4b\n" + "5:" // Main row loop: Column loop skip + "cmp %x[height], #0x4\n" + "addvl %x[out], %x[out], #4\n" + "bge 1b\n" + "cbz %x[height], 12f\n" + "6:" // Main loop skip + "7:" // Tail row loop: Head + "mov x26, %x[in]\n" + "add x25, x26, %x[in_stride]\n" + "cmp %x[height], #0x1\n" + "mov x21, %x[width]\n" + "cnth x20, ALL, MUL #2\n" + "add %x[in], x25, %x[in_stride]\n" + "csel x25, x25, %x[pad_row], GT\n" + "cmp x21, x20\n" + "mov x22, %x[out]\n" + "sub %x[height], %x[height], #0x2\n" + "blt 9f\n" + "8:" // Tail row loop: Unroll column loop + "ld1w { z16.s }, p2/Z, [x26]\n" + ".inst 0x658aaa14 // bfcvt z20.h, p2/M, z16.s\n" + "sub x21, x21, x20\n" + "cmp x21, x20\n" + "ld1w { z16.s }, p2/Z, [x26, #1, MUL VL]\n" + ".inst 0x658aaa13 // bfcvt z19.h, p2/M, z16.s\n" + "ld1w { z16.s }, p2/Z, [x26, #2, MUL VL]\n" + ".inst 0x658aaa12 // bfcvt z18.h, p2/M, z16.s\n" + "ld1w { z16.s }, p2/Z, [x26, #3, MUL VL]\n" + ".inst 0x658aaa11 // bfcvt z17.h, p2/M, z16.s\n" + "addvl x26, x26, #4\n" + "ld1w { z16.s }, p2/Z, [x25]\n" + ".inst 0x648aaa14 // bfcvtnt z20.h, p2/M, z16.s\n" + "ld1w { z16.s }, p2/Z, [x25, #1, MUL VL]\n" + ".inst 0x648aaa13 // bfcvtnt z19.h, p2/M, z16.s\n" + "ld1w { z16.s }, p2/Z, [x25, #2, MUL VL]\n" + ".inst 0x648aaa12 // bfcvtnt z18.h, p2/M, z16.s\n" + "ld1w { z16.s }, p2/Z, [x25, #3, MUL VL]\n" + "st1h { z20.h }, p2, [x22]\n" + "addvl x25, x25, #4\n" + ".inst 0x648aaa11 // bfcvtnt z17.h, p2/M, z16.s\n" + "st1h { z19.h }, p2, [x22, #1, MUL VL]\n" + "add x22, x22, %x[out_stride]\n" + "st1h { z18.h }, p2, [x22]\n" + "st1h { z17.h }, p2, [x22, #1, MUL VL]\n" + "add x22, x22, %x[out_stride]\n" + "bge 8b\n" + "9:" // Tail row loop: Unroll column loop skip + "cbz x21, 11f\n" + "10:" // Tail row loop: Column loop + "mov x20, x21\n" + "whilelt p1.s, XZR, x20\n" + "ld1w { z16.s }, p1/Z, [x26]\n" + ".inst 0x658aaa12 // bfcvt z18.h, p2/M, z16.s\n" + "decw x20\n" + "whilelt p0.s, XZR, x20\n" + "ld1w { z16.s }, p0/Z, [x26, #1, MUL VL]\n" + ".inst 0x658aaa11 // bfcvt z17.h, p2/M, z16.s\n" + "ld1w { z16.s }, p1/Z, [x25]\n" + "decw x21, ALL, MUL #2\n" + "cmp x21, #0x0\n" + ".inst 0x648aaa12 // bfcvtnt z18.h, p2/M, z16.s\n" + "ld1w { z16.s }, p0/Z, [x25, #1, MUL VL]\n" + "addvl x26, x26, #2\n" + "addvl x25, x25, #2\n" + ".inst 0x648aaa11 // bfcvtnt z17.h, p2/M, z16.s\n" + "st1h { z18.h }, p2, [x22]\n" + "st1h { z17.h }, p2, [x22, #1, MUL VL]\n" + "add x22, x22, %x[out_stride]\n" + "bgt 10b\n" + "11:" // Tail row loop: Column loop skip + "cmp %x[height], #0x1\n" + "addvl %x[out], %x[out], #2\n" + "bge 7b\n" + "12:" // Done + ".inst 0xd503467f // SMSTOP\n" + : [height] "+&r" (height), [in] "+&r" (in), [out] "+&r" (out) + : [in_stride] "r" (in_stride), [out_stride] "r" (out_stride), [pad_row] "r" (pad_row), [width] "r" (width) + : "cc", "memory", "p0", "p1", "p2", "p3", "p4", "p5", "p6", "p7", "p8", "p9", "p10", "p11", "p12", "p13", "p14", "p15", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "z0", "z1", "z2", "z3", "z4", "z5", "z6", "z7", "z8", "z9", "z10", "z11", "z12", "z13", "z14", "z15", "z16", "z17", "z18", "z19", "z20", "z21", "z22", "z23", "z24", "z25", "z26", "z27", "z28", "z29", "z30", "z31" + ); +} + +} // anonymous namespace +template<> +void Transform<2, 2, true, VLType::SME>( + bfloat16 *out, const float *in, int stride, int x0, int xmax, int k0, int kmax) +{ + sme_transpose_interleave_2VL_2x2_fp32bf16( + out, + in + k0 * stride + x0, + (xmax-x0), + stride * sizeof(float), + (kmax-k0) + ); +} + + +#endif // defined(ARM_COMPUTE_ENABLE_SME) diff --git a/src/core/NEON/kernels/arm_gemm/transforms/sme_transpose_interleave_4VL.hpp b/src/core/NEON/kernels/arm_gemm/transforms/sme_transpose_interleave_4VL.hpp new file mode 100644 index 0000000000..14636e3218 --- /dev/null +++ b/src/core/NEON/kernels/arm_gemm/transforms/sme_transpose_interleave_4VL.hpp @@ -0,0 +1,188 @@ +/* + * Copyright (c) 2022-2023 Arm Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ + +#pragma once + +#if defined(ARM_COMPUTE_ENABLE_SME) + +namespace { + +void sme_transpose_interleave_4VL(uint16_t *out, const uint16_t *in, size_t width, size_t in_stride, size_t height) +{ + size_t out_stride = 4 * height * sme::get_vector_length<uint8_t>(); + + __asm__ __volatile__( + ".inst 0xd503477f // SMSTART ZA\n" + "cmp %x[height], #0x4\n" + "ptrue p4.b\n" + "blt 4f\n" + "1:" // Main row loop: Head + "mov x26, %x[in]\n" + "add x25, x26, %x[in_stride]\n" + "add x24, x25, %x[in_stride]\n" + "add x23, x24, %x[in_stride]\n" + "add %x[in], x23, %x[in_stride]\n" + "mov x22, %x[out]\n" + "sub %x[height], %x[height], #0x4\n" + "mov x21, %x[width]\n" + "2:" // Main row loop: Column loop + "mov x20, x21\n" + "whilelt p3.h, XZR, x20\n" + "ld1h { z31.h }, p3/Z, [x26]\n" + "dech x20\n" + "whilelt p2.h, XZR, x20\n" + "ld1h { z30.h }, p2/Z, [x26, #1, MUL VL]\n" + "dech x20\n" + "whilelt p1.h, XZR, x20\n" + "ld1h { z29.h }, p1/Z, [x26, #2, MUL VL]\n" + "dech x20\n" + "whilelt p0.h, XZR, x20\n" + "ld1h { z28.h }, p0/Z, [x26, #3, MUL VL]\n" + "mov x20, x22\n" + "dech x21, ALL, MUL #4\n" + "ld1h { z27.h }, p3/Z, [x25]\n" + "ld1h { z26.h }, p2/Z, [x25, #1, MUL VL]\n" + "cmp x21, #0x0\n" + "addvl x26, x26, #4\n" + "ld1h { z25.h }, p1/Z, [x25, #2, MUL VL]\n" + "add x22, x22, %x[out_stride]\n" + "ld1h { z24.h }, p0/Z, [x25, #3, MUL VL]\n" + "addvl x25, x25, #4\n" + "ld1h { z23.h }, p3/Z, [x24]\n" + "ld1h { z22.h }, p2/Z, [x24, #1, MUL VL]\n" + "ld1h { z21.h }, p1/Z, [x24, #2, MUL VL]\n" + "ld1h { z20.h }, p0/Z, [x24, #3, MUL VL]\n" + "addvl x24, x24, #4\n" + "ld1h { z19.h }, p3/Z, [x23]\n" + "ld1h { z18.h }, p2/Z, [x23, #1, MUL VL]\n" + "ld1h { z17.h }, p1/Z, [x23, #2, MUL VL]\n" + "ld1h { z16.h }, p0/Z, [x23, #3, MUL VL]\n" + "st1h { z31.h }, p4, [x20]\n" + "addvl x23, x23, #4\n" + "st1h { z30.h }, p4, [x20, #1, MUL VL]\n" + "st1h { z29.h }, p4, [x20, #2, MUL VL]\n" + "st1h { z28.h }, p4, [x20, #3, MUL VL]\n" + "st1h { z27.h }, p4, [x20, #4, MUL VL]\n" + "st1h { z26.h }, p4, [x20, #5, MUL VL]\n" + "st1h { z25.h }, p4, [x20, #6, MUL VL]\n" + "st1h { z24.h }, p4, [x20, #7, MUL VL]\n" + "addvl x20, x20, #16\n" + "st1h { z23.h }, p4, [x20, #-8, MUL VL]\n" + "st1h { z22.h }, p4, [x20, #-7, MUL VL]\n" + "st1h { z21.h }, p4, [x20, #-6, MUL VL]\n" + "st1h { z20.h }, p4, [x20, #-5, MUL VL]\n" + "st1h { z19.h }, p4, [x20, #-4, MUL VL]\n" + "st1h { z18.h }, p4, [x20, #-3, MUL VL]\n" + "st1h { z17.h }, p4, [x20, #-2, MUL VL]\n" + "st1h { z16.h }, p4, [x20, #-1, MUL VL]\n" + "bgt 2b\n" + "3:" // Main row loop: Column loop skip + "cmp %x[height], #0x4\n" + "addvl %x[out], %x[out], #16\n" + "bge 1b\n" + "cbz %x[height], 8f\n" + "4:" // Main loop skip + "5:" // Tail row loop: Head + "mov x26, %x[in]\n" + "add %x[in], x26, %x[in_stride]\n" + "mov x22, %x[out]\n" + "sub %x[height], %x[height], #0x1\n" + "mov x21, %x[width]\n" + "6:" // Tail row loop: Column loop + "mov x20, x21\n" + "whilelt p0.h, XZR, x20\n" + "ld1h { z19.h }, p0/Z, [x26]\n" + "dech x20\n" + "whilelt p0.h, XZR, x20\n" + "ld1h { z18.h }, p0/Z, [x26, #1, MUL VL]\n" + "dech x20\n" + "whilelt p0.h, XZR, x20\n" + "ld1h { z17.h }, p0/Z, [x26, #2, MUL VL]\n" + "dech x20\n" + "dech x21, ALL, MUL #4\n" + "whilelt p0.h, XZR, x20\n" + "cmp x21, #0x0\n" + "ld1h { z16.h }, p0/Z, [x26, #3, MUL VL]\n" + "st1h { z19.h }, p4, [x22]\n" + "addvl x26, x26, #4\n" + "st1h { z18.h }, p4, [x22, #1, MUL VL]\n" + "st1h { z17.h }, p4, [x22, #2, MUL VL]\n" + "st1h { z16.h }, p4, [x22, #3, MUL VL]\n" + "add x22, x22, %x[out_stride]\n" + "bgt 6b\n" + "7:" // Tail row loop: Column loop skip + "cmp %x[height], #0x1\n" + "addvl %x[out], %x[out], #4\n" + "bge 5b\n" + "8:" // Done + ".inst 0xd503467f // SMSTOP\n" + : [height] "+&r" (height), [in] "+&r" (in), [out] "+&r" (out) + : [in_stride] "r" (in_stride), [out_stride] "r" (out_stride), [width] "r" (width) + : "cc", "memory", "p0", "p1", "p2", "p3", "p4", "p5", "p6", "p7", "p8", "p9", "p10", "p11", "p12", "p13", "p14", "p15", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "z0", "z1", "z2", "z3", "z4", "z5", "z6", "z7", "z8", "z9", "z10", "z11", "z12", "z13", "z14", "z15", "z16", "z17", "z18", "z19", "z20", "z21", "z22", "z23", "z24", "z25", "z26", "z27", "z28", "z29", "z30", "z31" + ); +} + +} // anonymous namespace + +template<> +void Transform<4, 1, true, VLType::SME>( + float *out, const float *in, int stride, int x0, int xmax, int k0, int kmax) +{ + sme_transpose_interleave_4VL( + reinterpret_cast<uint16_t *>(out), + reinterpret_cast<const uint16_t *>(in + k0 * stride + x0), + (xmax-x0) * sizeof(float) / 2, + stride * sizeof(float), + (kmax-k0) + ); +} + +template<> +void Transform<4, 1, true, VLType::SME>( + bfloat16 *out, const bfloat16 *in, int stride, int x0, int xmax, int k0, int kmax) +{ + sme_transpose_interleave_4VL( + reinterpret_cast<uint16_t *>(out), + reinterpret_cast<const uint16_t *>(in + k0 * stride + x0), + (xmax-x0) * sizeof(bfloat16) / 2, + stride * sizeof(bfloat16), + (kmax-k0) + ); +} + +template<> +void Transform<4, 1, true, VLType::SME>( + __fp16 *out, const __fp16 *in, int stride, int x0, int xmax, int k0, int kmax) +{ + sme_transpose_interleave_4VL( + reinterpret_cast<uint16_t *>(out), + reinterpret_cast<const uint16_t *>(in + k0 * stride + x0), + (xmax-x0) * sizeof(__fp16) / 2, + stride * sizeof(__fp16), + (kmax-k0) + ); +} + + +#endif // defined(ARM_COMPUTE_ENABLE_SME) diff --git a/src/core/NEON/kernels/arm_gemm/transforms/sme_transpose_interleave_4VL_1x4.hpp b/src/core/NEON/kernels/arm_gemm/transforms/sme_transpose_interleave_4VL_1x4.hpp new file mode 100644 index 0000000000..2d46a481f3 --- /dev/null +++ b/src/core/NEON/kernels/arm_gemm/transforms/sme_transpose_interleave_4VL_1x4.hpp @@ -0,0 +1,124 @@ +/* + * Copyright (c) 2022-2023 Arm Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ + +#pragma once + +#if defined(ARM_COMPUTE_ENABLE_SME) + +namespace { + +void sme_transpose_interleave_4VL_1x4(uint8_t *out, const uint8_t *in, size_t width, size_t in_stride, size_t height) +{ + uint8_t *pad_row = reinterpret_cast<uint8_t *>(alloca(width * sizeof(uint8_t))); + + if (height % 4) { + memset(pad_row, 0, width * sizeof(uint8_t)); + } + + size_t out_stride = 4 * roundup<size_t>(height, 4) * sme::get_vector_length<uint32_t>(); + + __asm__ __volatile__( + ".inst 0xd503477f // SMSTART ZA\n" + "ptrue p1.b\n" + "1:" // Main row loop: Head + "mov x25, %x[in]\n" + "add x24, x25, %x[in_stride]\n" + "add x23, x24, %x[in_stride]\n" + "add x22, x23, %x[in_stride]\n" + "cmp %x[height], #0x3\n" + "add %x[in], x22, %x[in_stride]\n" + "csel x22, x22, %x[pad_row], GT\n" + "csel x23, x23, %x[pad_row], GE\n" + "cmp %x[height], #0x1\n" + "mov x21, %x[out]\n" + "csel x24, x24, %x[pad_row], GT\n" + "sub %x[height], %x[height], #0x4\n" + "mov x20, %x[width]\n" + "2:" // Main row loop: Column loop + "whilelt p0.b, XZR, x20\n" + "ld1b { z17.b }, p0/Z, [x25]\n" + "decw x20, ALL, MUL #4\n" + "ld1b { z19.b }, p0/Z, [x24]\n" + "cmp x20, #0x0\n" + "addvl x25, x25, #1\n" + "ld1b { z16.b }, p0/Z, [x23]\n" + "zip1 z18.b, z17.b, z16.b\n" + "zip2 z20.b, z17.b, z16.b\n" + "addvl x24, x24, #1\n" + "ld1b { z16.b }, p0/Z, [x22]\n" + "zip1 z17.b, z19.b, z16.b\n" + "zip2 z19.b, z19.b, z16.b\n" + "addvl x23, x23, #1\n" + "addvl x22, x22, #1\n" + "zip1 z16.b, z18.b, z17.b\n" + "zip2 z18.b, z18.b, z17.b\n" + "st1b { z16.b }, p1, [x21]\n" + "zip1 z17.b, z20.b, z19.b\n" + "zip2 z16.b, z20.b, z19.b\n" + "st1b { z18.b }, p1, [x21, #1, MUL VL]\n" + "st1b { z17.b }, p1, [x21, #2, MUL VL]\n" + "st1b { z16.b }, p1, [x21, #3, MUL VL]\n" + "add x21, x21, %x[out_stride]\n" + "bgt 2b\n" + "3:" // Main row loop: Column loop skip + "cmp %x[height], #0x1\n" + "addvl %x[out], %x[out], #4\n" + "bge 1b\n" + ".inst 0xd503467f // SMSTOP\n" + : [height] "+&r" (height), [in] "+&r" (in), [out] "+&r" (out) + : [in_stride] "r" (in_stride), [out_stride] "r" (out_stride), [pad_row] "r" (pad_row), [width] "r" (width) + : "cc", "memory", "p0", "p1", "p2", "p3", "p4", "p5", "p6", "p7", "p8", "p9", "p10", "p11", "p12", "p13", "p14", "p15", "x20", "x21", "x22", "x23", "x24", "x25", "z0", "z1", "z2", "z3", "z4", "z5", "z6", "z7", "z8", "z9", "z10", "z11", "z12", "z13", "z14", "z15", "z16", "z17", "z18", "z19", "z20", "z21", "z22", "z23", "z24", "z25", "z26", "z27", "z28", "z29", "z30", "z31" + ); +} + +} // anonymous namespace + +template<> +void Transform<4, 4, true, VLType::SME>( + uint8_t *out, const uint8_t *in, int stride, int x0, int xmax, int k0, int kmax) +{ + sme_transpose_interleave_4VL_1x4( + reinterpret_cast<uint8_t *>(out), + reinterpret_cast<const uint8_t *>(in + k0 * stride + x0), + (xmax-x0) * sizeof(uint8_t) / 1, + stride * sizeof(uint8_t), + (kmax-k0) + ); +} + +template<> +void Transform<4, 4, true, VLType::SME>( + int8_t *out, const int8_t *in, int stride, int x0, int xmax, int k0, int kmax) +{ + sme_transpose_interleave_4VL_1x4( + reinterpret_cast<uint8_t *>(out), + reinterpret_cast<const uint8_t *>(in + k0 * stride + x0), + (xmax-x0) * sizeof(int8_t) / 1, + stride * sizeof(int8_t), + (kmax-k0) + ); +} + + +#endif // defined(ARM_COMPUTE_ENABLE_SME) diff --git a/src/core/NEON/kernels/arm_gemm/transforms/sme_transpose_interleave_4VL_2x2.hpp b/src/core/NEON/kernels/arm_gemm/transforms/sme_transpose_interleave_4VL_2x2.hpp new file mode 100644 index 0000000000..002a12479a --- /dev/null +++ b/src/core/NEON/kernels/arm_gemm/transforms/sme_transpose_interleave_4VL_2x2.hpp @@ -0,0 +1,171 @@ +/* + * Copyright (c) 2022-2023 Arm Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ + +#pragma once + +#if defined(ARM_COMPUTE_ENABLE_SME) + +namespace { + +void sme_transpose_interleave_4VL_2x2(uint16_t *out, const uint16_t *in, size_t width, size_t in_stride, size_t height) +{ + uint16_t *pad_row = reinterpret_cast<uint16_t *>(alloca(width * sizeof(uint16_t))); + + if (height % 2) { + memset(pad_row, 0, width * sizeof(uint16_t)); + } + + size_t out_stride = 4 * roundup<size_t>(height, 2) * sme::get_vector_length<uint16_t>(); + + __asm__ __volatile__( + ".inst 0xd503477f // SMSTART ZA\n" + "cmp %x[height], #0x4\n" + "ptrue p2.b\n" + "blt 4f\n" + "1:" // Main row loop: Head + "mov x26, %x[in]\n" + "add x25, x26, %x[in_stride]\n" + "add x24, x25, %x[in_stride]\n" + "add x23, x24, %x[in_stride]\n" + "add %x[in], x23, %x[in_stride]\n" + "mov x22, %x[out]\n" + "sub %x[height], %x[height], #0x4\n" + "mov x21, %x[width]\n" + "2:" // Main row loop: Column loop + "mov x20, x21\n" + "whilelt p1.h, XZR, x20\n" + "ld1h { z19.h }, p1/Z, [x26]\n" + "dech x20\n" + "whilelt p0.h, XZR, x20\n" + "ld1h { z18.h }, p0/Z, [x26, #1, MUL VL]\n" + "ld1h { z17.h }, p1/Z, [x25]\n" + "decw x21, ALL, MUL #4\n" + "cmp x21, #0x0\n" + "zip1 z24.h, z19.h, z17.h\n" + "ld1h { z16.h }, p0/Z, [x25, #1, MUL VL]\n" + "addvl x26, x26, #2\n" + "addvl x25, x25, #2\n" + "zip2 z23.h, z19.h, z17.h\n" + "ld1h { z17.h }, p1/Z, [x24]\n" + "zip1 z22.h, z18.h, z16.h\n" + "zip2 z21.h, z18.h, z16.h\n" + "ld1h { z20.h }, p0/Z, [x24, #1, MUL VL]\n" + "addvl x24, x24, #2\n" + "ld1h { z16.h }, p1/Z, [x23]\n" + "zip1 z19.h, z17.h, z16.h\n" + "zip2 z18.h, z17.h, z16.h\n" + "ld1h { z16.h }, p0/Z, [x23, #1, MUL VL]\n" + "addvl x23, x23, #2\n" + "zip1 z17.h, z20.h, z16.h\n" + "zip2 z16.h, z20.h, z16.h\n" + "st1h { z24.h }, p2, [x22]\n" + "st1h { z23.h }, p2, [x22, #1, MUL VL]\n" + "st1h { z22.h }, p2, [x22, #2, MUL VL]\n" + "st1h { z21.h }, p2, [x22, #3, MUL VL]\n" + "st1h { z19.h }, p2, [x22, #4, MUL VL]\n" + "st1h { z18.h }, p2, [x22, #5, MUL VL]\n" + "st1h { z17.h }, p2, [x22, #6, MUL VL]\n" + "st1h { z16.h }, p2, [x22, #7, MUL VL]\n" + "add x22, x22, %x[out_stride]\n" + "bgt 2b\n" + "3:" // Main row loop: Column loop skip + "cmp %x[height], #0x4\n" + "addvl %x[out], %x[out], #8\n" + "bge 1b\n" + "cbz %x[height], 8f\n" + "4:" // Main loop skip + "5:" // Tail row loop: Head + "mov x26, %x[in]\n" + "add x25, x26, %x[in_stride]\n" + "cmp %x[height], #0x1\n" + "add %x[in], x25, %x[in_stride]\n" + "mov x22, %x[out]\n" + "csel x25, x25, %x[pad_row], GT\n" + "sub %x[height], %x[height], #0x2\n" + "mov x21, %x[width]\n" + "6:" // Tail row loop: Column loop + "mov x20, x21\n" + "whilelt p1.h, XZR, x20\n" + "ld1h { z18.h }, p1/Z, [x26]\n" + "dech x20\n" + "whilelt p0.h, XZR, x20\n" + "ld1h { z20.h }, p0/Z, [x26, #1, MUL VL]\n" + "ld1h { z17.h }, p1/Z, [x25]\n" + "decw x21, ALL, MUL #4\n" + "cmp x21, #0x0\n" + "zip1 z19.h, z18.h, z17.h\n" + "ld1h { z16.h }, p0/Z, [x25, #1, MUL VL]\n" + "addvl x26, x26, #2\n" + "addvl x25, x25, #2\n" + "zip2 z18.h, z18.h, z17.h\n" + "zip1 z17.h, z20.h, z16.h\n" + "zip2 z16.h, z20.h, z16.h\n" + "st1h { z19.h }, p2, [x22]\n" + "st1h { z18.h }, p2, [x22, #1, MUL VL]\n" + "st1h { z17.h }, p2, [x22, #2, MUL VL]\n" + "st1h { z16.h }, p2, [x22, #3, MUL VL]\n" + "add x22, x22, %x[out_stride]\n" + "bgt 6b\n" + "7:" // Tail row loop: Column loop skip + "cmp %x[height], #0x1\n" + "addvl %x[out], %x[out], #4\n" + "bge 5b\n" + "8:" // Done + ".inst 0xd503467f // SMSTOP\n" + : [height] "+&r" (height), [in] "+&r" (in), [out] "+&r" (out) + : [in_stride] "r" (in_stride), [out_stride] "r" (out_stride), [pad_row] "r" (pad_row), [width] "r" (width) + : "cc", "memory", "p0", "p1", "p2", "p3", "p4", "p5", "p6", "p7", "p8", "p9", "p10", "p11", "p12", "p13", "p14", "p15", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "z0", "z1", "z2", "z3", "z4", "z5", "z6", "z7", "z8", "z9", "z10", "z11", "z12", "z13", "z14", "z15", "z16", "z17", "z18", "z19", "z20", "z21", "z22", "z23", "z24", "z25", "z26", "z27", "z28", "z29", "z30", "z31" + ); +} + +} // anonymous namespace + +template<> +void Transform<4, 2, true, VLType::SME>( + bfloat16 *out, const bfloat16 *in, int stride, int x0, int xmax, int k0, int kmax) +{ + sme_transpose_interleave_4VL_2x2( + reinterpret_cast<uint16_t *>(out), + reinterpret_cast<const uint16_t *>(in + k0 * stride + x0), + (xmax-x0) * sizeof(bfloat16) / 2, + stride * sizeof(bfloat16), + (kmax-k0) + ); +} + +template<> +void Transform<4, 2, true, VLType::SME>( + __fp16 *out, const __fp16 *in, int stride, int x0, int xmax, int k0, int kmax) +{ + sme_transpose_interleave_4VL_2x2( + reinterpret_cast<uint16_t *>(out), + reinterpret_cast<const uint16_t *>(in + k0 * stride + x0), + (xmax-x0) * sizeof(__fp16) / 2, + stride * sizeof(__fp16), + (kmax-k0) + ); +} + + +#endif // defined(ARM_COMPUTE_ENABLE_SME) diff --git a/src/core/NEON/kernels/arm_gemm/transforms/sme_transpose_interleave_4VL_2x2_fp32bf16.hpp b/src/core/NEON/kernels/arm_gemm/transforms/sme_transpose_interleave_4VL_2x2_fp32bf16.hpp new file mode 100644 index 0000000000..2a43f34f71 --- /dev/null +++ b/src/core/NEON/kernels/arm_gemm/transforms/sme_transpose_interleave_4VL_2x2_fp32bf16.hpp @@ -0,0 +1,189 @@ +/* + * Copyright (c) 2022-2023 Arm Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ + +#pragma once + +#if defined(ARM_COMPUTE_ENABLE_SME) + +namespace { + +void sme_transpose_interleave_4VL_2x2_fp32bf16(bfloat16 *out, const float *in, size_t width, size_t in_stride, size_t height) +{ + float *pad_row = reinterpret_cast<float *>(alloca(width * sizeof(float))); + + if (height % 2) { + memset(pad_row, 0, width * sizeof(float)); + } + + size_t out_stride = 4 * roundup<size_t>(height, 2) * sme::get_vector_length<uint16_t>(); + + __asm__ __volatile__( + ".inst 0xd503477f // SMSTART ZA\n" + "cmp %x[height], #0x4\n" + "ptrue p4.b\n" + "blt 4f\n" + "1:" // Main row loop: Head + "mov x26, %x[in]\n" + "add x25, x26, %x[in_stride]\n" + "add x24, x25, %x[in_stride]\n" + "add x23, x24, %x[in_stride]\n" + "add %x[in], x23, %x[in_stride]\n" + "mov x22, %x[out]\n" + "sub %x[height], %x[height], #0x4\n" + "mov x21, %x[width]\n" + "2:" // Main row loop: Column loop + "mov x20, x21\n" + "whilelt p3.s, XZR, x20\n" + "ld1w { z16.s }, p3/Z, [x26]\n" + ".inst 0x658ab218 // bfcvt z24.h, p4/M, z16.s\n" + "decw x20\n" + "whilelt p2.s, XZR, x20\n" + "ld1w { z16.s }, p2/Z, [x26, #1, MUL VL]\n" + ".inst 0x658ab217 // bfcvt z23.h, p4/M, z16.s\n" + "decw x20\n" + "whilelt p1.s, XZR, x20\n" + "ld1w { z16.s }, p1/Z, [x26, #2, MUL VL]\n" + ".inst 0x658ab216 // bfcvt z22.h, p4/M, z16.s\n" + "decw x20\n" + "whilelt p0.s, XZR, x20\n" + "ld1w { z16.s }, p0/Z, [x26, #3, MUL VL]\n" + ".inst 0x658ab215 // bfcvt z21.h, p4/M, z16.s\n" + "ld1w { z16.s }, p3/Z, [x24]\n" + ".inst 0x658ab214 // bfcvt z20.h, p4/M, z16.s\n" + "decw x21, ALL, MUL #4\n" + "cmp x21, #0x0\n" + "ld1w { z16.s }, p2/Z, [x24, #1, MUL VL]\n" + ".inst 0x658ab213 // bfcvt z19.h, p4/M, z16.s\n" + "addvl x26, x26, #4\n" + "ld1w { z16.s }, p1/Z, [x24, #2, MUL VL]\n" + ".inst 0x658ab212 // bfcvt z18.h, p4/M, z16.s\n" + "ld1w { z16.s }, p0/Z, [x24, #3, MUL VL]\n" + ".inst 0x658ab211 // bfcvt z17.h, p4/M, z16.s\n" + "addvl x24, x24, #4\n" + "ld1w { z16.s }, p3/Z, [x25]\n" + ".inst 0x648ab218 // bfcvtnt z24.h, p4/M, z16.s\n" + "ld1w { z16.s }, p2/Z, [x25, #1, MUL VL]\n" + ".inst 0x648ab217 // bfcvtnt z23.h, p4/M, z16.s\n" + "ld1w { z16.s }, p1/Z, [x25, #2, MUL VL]\n" + ".inst 0x648ab216 // bfcvtnt z22.h, p4/M, z16.s\n" + "ld1w { z16.s }, p0/Z, [x25, #3, MUL VL]\n" + "addvl x25, x25, #4\n" + ".inst 0x648ab215 // bfcvtnt z21.h, p4/M, z16.s\n" + "ld1w { z16.s }, p3/Z, [x23]\n" + ".inst 0x648ab214 // bfcvtnt z20.h, p4/M, z16.s\n" + "ld1w { z16.s }, p2/Z, [x23, #1, MUL VL]\n" + ".inst 0x648ab213 // bfcvtnt z19.h, p4/M, z16.s\n" + "ld1w { z16.s }, p1/Z, [x23, #2, MUL VL]\n" + ".inst 0x648ab212 // bfcvtnt z18.h, p4/M, z16.s\n" + "ld1w { z16.s }, p0/Z, [x23, #3, MUL VL]\n" + "addvl x23, x23, #4\n" + ".inst 0x648ab211 // bfcvtnt z17.h, p4/M, z16.s\n" + "st1h { z24.h }, p4, [x22]\n" + "st1h { z23.h }, p4, [x22, #1, MUL VL]\n" + "st1h { z22.h }, p4, [x22, #2, MUL VL]\n" + "st1h { z21.h }, p4, [x22, #3, MUL VL]\n" + "st1h { z20.h }, p4, [x22, #4, MUL VL]\n" + "st1h { z19.h }, p4, [x22, #5, MUL VL]\n" + "st1h { z18.h }, p4, [x22, #6, MUL VL]\n" + "st1h { z17.h }, p4, [x22, #7, MUL VL]\n" + "add x22, x22, %x[out_stride]\n" + "bgt 2b\n" + "3:" // Main row loop: Column loop skip + "cmp %x[height], #0x4\n" + "addvl %x[out], %x[out], #8\n" + "bge 1b\n" + "cbz %x[height], 8f\n" + "4:" // Main loop skip + "5:" // Tail row loop: Head + "mov x26, %x[in]\n" + "add x25, x26, %x[in_stride]\n" + "cmp %x[height], #0x1\n" + "add %x[in], x25, %x[in_stride]\n" + "mov x22, %x[out]\n" + "csel x25, x25, %x[pad_row], GT\n" + "sub %x[height], %x[height], #0x2\n" + "mov x21, %x[width]\n" + "6:" // Tail row loop: Column loop + "mov x20, x21\n" + "whilelt p3.s, XZR, x20\n" + "ld1w { z16.s }, p3/Z, [x26]\n" + ".inst 0x658ab214 // bfcvt z20.h, p4/M, z16.s\n" + "decw x20\n" + "whilelt p2.s, XZR, x20\n" + "ld1w { z16.s }, p2/Z, [x26, #1, MUL VL]\n" + ".inst 0x658ab213 // bfcvt z19.h, p4/M, z16.s\n" + "decw x20\n" + "whilelt p1.s, XZR, x20\n" + "ld1w { z16.s }, p1/Z, [x26, #2, MUL VL]\n" + ".inst 0x658ab212 // bfcvt z18.h, p4/M, z16.s\n" + "decw x20\n" + "whilelt p0.s, XZR, x20\n" + "ld1w { z16.s }, p0/Z, [x26, #3, MUL VL]\n" + ".inst 0x658ab211 // bfcvt z17.h, p4/M, z16.s\n" + "ld1w { z16.s }, p3/Z, [x25]\n" + "decw x21, ALL, MUL #4\n" + "cmp x21, #0x0\n" + ".inst 0x648ab214 // bfcvtnt z20.h, p4/M, z16.s\n" + "ld1w { z16.s }, p2/Z, [x25, #1, MUL VL]\n" + "addvl x26, x26, #4\n" + ".inst 0x648ab213 // bfcvtnt z19.h, p4/M, z16.s\n" + "ld1w { z16.s }, p1/Z, [x25, #2, MUL VL]\n" + ".inst 0x648ab212 // bfcvtnt z18.h, p4/M, z16.s\n" + "ld1w { z16.s }, p0/Z, [x25, #3, MUL VL]\n" + "addvl x25, x25, #4\n" + ".inst 0x648ab211 // bfcvtnt z17.h, p4/M, z16.s\n" + "st1h { z20.h }, p4, [x22]\n" + "st1h { z19.h }, p4, [x22, #1, MUL VL]\n" + "st1h { z18.h }, p4, [x22, #2, MUL VL]\n" + "st1h { z17.h }, p4, [x22, #3, MUL VL]\n" + "add x22, x22, %x[out_stride]\n" + "bgt 6b\n" + "7:" // Tail row loop: Column loop skip + "cmp %x[height], #0x1\n" + "addvl %x[out], %x[out], #4\n" + "bge 5b\n" + "8:" // Done + ".inst 0xd503467f // SMSTOP\n" + : [height] "+&r" (height), [in] "+&r" (in), [out] "+&r" (out) + : [in_stride] "r" (in_stride), [out_stride] "r" (out_stride), [pad_row] "r" (pad_row), [width] "r" (width) + : "cc", "memory", "p0", "p1", "p2", "p3", "p4", "p5", "p6", "p7", "p8", "p9", "p10", "p11", "p12", "p13", "p14", "p15", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "z0", "z1", "z2", "z3", "z4", "z5", "z6", "z7", "z8", "z9", "z10", "z11", "z12", "z13", "z14", "z15", "z16", "z17", "z18", "z19", "z20", "z21", "z22", "z23", "z24", "z25", "z26", "z27", "z28", "z29", "z30", "z31" + ); +} + +} // anonymous namespace +template<> +void Transform<4, 2, true, VLType::SME>( + bfloat16 *out, const float *in, int stride, int x0, int xmax, int k0, int kmax) +{ + sme_transpose_interleave_4VL_2x2_fp32bf16( + out, + in + k0 * stride + x0, + (xmax-x0), + stride * sizeof(float), + (kmax-k0) + ); +} + + +#endif // defined(ARM_COMPUTE_ENABLE_SME) diff --git a/src/core/NEON/kernels/arm_gemm/transforms/sme_transpose_interleave_8VL.hpp b/src/core/NEON/kernels/arm_gemm/transforms/sme_transpose_interleave_8VL.hpp new file mode 100644 index 0000000000..be9ad666a9 --- /dev/null +++ b/src/core/NEON/kernels/arm_gemm/transforms/sme_transpose_interleave_8VL.hpp @@ -0,0 +1,208 @@ +/* + * Copyright (c) 2023 Arm Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ + +#pragma once + +#if defined(ARM_COMPUTE_ENABLE_SME) + +namespace { + +void sme_transpose_interleave_8VL(uint16_t *out, const uint16_t *in, size_t width, size_t in_stride, size_t height) +{ + size_t out_stride = 8 * height * sme::get_vector_length<uint8_t>(); + + __asm__ __volatile__( + ".inst 0xd503477f // SMSTART ZA\n" + "cmp %x[height], #0x2\n" + "ptrue p7.b\n" + "blt 4f\n" + "1:" // Main row loop: Head + "mov x25, %x[in]\n" + "add x24, x25, %x[in_stride]\n" + "add %x[in], x24, %x[in_stride]\n" + "mov x23, %x[out]\n" + "sub %x[height], %x[height], #0x2\n" + "mov x22, %x[width]\n" + "2:" // Main row loop: Column loop + "mov x21, x22\n" + "whilelt p0.h, XZR, x21\n" + "ld1h { z31.h }, p0/Z, [x25]\n" + "dech x21\n" + "whilelt p6.h, XZR, x21\n" + "ld1h { z30.h }, p6/Z, [x25, #1, MUL VL]\n" + "dech x21\n" + "whilelt p5.h, XZR, x21\n" + "ld1h { z29.h }, p5/Z, [x25, #2, MUL VL]\n" + "dech x21\n" + "whilelt p4.h, XZR, x21\n" + "ld1h { z28.h }, p4/Z, [x25, #3, MUL VL]\n" + "dech x21\n" + "whilelt p3.h, XZR, x21\n" + "ld1h { z27.h }, p3/Z, [x25, #4, MUL VL]\n" + "dech x21\n" + "whilelt p2.h, XZR, x21\n" + "ld1h { z26.h }, p2/Z, [x25, #5, MUL VL]\n" + "dech x21\n" + "whilelt p1.h, XZR, x21\n" + "ld1h { z25.h }, p1/Z, [x25, #6, MUL VL]\n" + "dech x21\n" + "mov x20, x23\n" + "ld1h { z24.h }, p0/Z, [x24]\n" + "whilelt p0.h, XZR, x21\n" + "dech x22, ALL, MUL #8\n" + "ld1h { z23.h }, p0/Z, [x25, #7, MUL VL]\n" + "ld1h { z22.h }, p6/Z, [x24, #1, MUL VL]\n" + "cmp x22, #0x0\n" + "addvl x25, x25, #8\n" + "ld1h { z21.h }, p5/Z, [x24, #2, MUL VL]\n" + "add x23, x23, %x[out_stride]\n" + "ld1h { z20.h }, p4/Z, [x24, #3, MUL VL]\n" + "ld1h { z19.h }, p3/Z, [x24, #4, MUL VL]\n" + "ld1h { z18.h }, p2/Z, [x24, #5, MUL VL]\n" + "ld1h { z17.h }, p1/Z, [x24, #6, MUL VL]\n" + "ld1h { z16.h }, p0/Z, [x24, #7, MUL VL]\n" + "st1h { z31.h }, p7, [x20]\n" + "addvl x24, x24, #8\n" + "st1h { z30.h }, p7, [x20, #1, MUL VL]\n" + "st1h { z29.h }, p7, [x20, #2, MUL VL]\n" + "st1h { z28.h }, p7, [x20, #3, MUL VL]\n" + "st1h { z27.h }, p7, [x20, #4, MUL VL]\n" + "st1h { z26.h }, p7, [x20, #5, MUL VL]\n" + "st1h { z25.h }, p7, [x20, #6, MUL VL]\n" + "st1h { z23.h }, p7, [x20, #7, MUL VL]\n" + "addvl x20, x20, #16\n" + "st1h { z24.h }, p7, [x20, #-8, MUL VL]\n" + "st1h { z22.h }, p7, [x20, #-7, MUL VL]\n" + "st1h { z21.h }, p7, [x20, #-6, MUL VL]\n" + "st1h { z20.h }, p7, [x20, #-5, MUL VL]\n" + "st1h { z19.h }, p7, [x20, #-4, MUL VL]\n" + "st1h { z18.h }, p7, [x20, #-3, MUL VL]\n" + "st1h { z17.h }, p7, [x20, #-2, MUL VL]\n" + "st1h { z16.h }, p7, [x20, #-1, MUL VL]\n" + "bgt 2b\n" + "3:" // Main row loop: Column loop skip + "cmp %x[height], #0x2\n" + "addvl %x[out], %x[out], #16\n" + "bge 1b\n" + "cbz %x[height], 8f\n" + "4:" // Main loop skip + "5:" // Tail row loop: Head + "mov x25, %x[in]\n" + "add %x[in], x25, %x[in_stride]\n" + "mov x23, %x[out]\n" + "sub %x[height], %x[height], #0x1\n" + "mov x21, %x[width]\n" + "6:" // Tail row loop: Column loop + "mov x20, x21\n" + "whilelt p0.h, XZR, x20\n" + "ld1h { z23.h }, p0/Z, [x25]\n" + "dech x20\n" + "whilelt p0.h, XZR, x20\n" + "ld1h { z22.h }, p0/Z, [x25, #1, MUL VL]\n" + "dech x20\n" + "whilelt p0.h, XZR, x20\n" + "ld1h { z21.h }, p0/Z, [x25, #2, MUL VL]\n" + "dech x20\n" + "whilelt p0.h, XZR, x20\n" + "ld1h { z20.h }, p0/Z, [x25, #3, MUL VL]\n" + "dech x20\n" + "whilelt p0.h, XZR, x20\n" + "ld1h { z19.h }, p0/Z, [x25, #4, MUL VL]\n" + "dech x20\n" + "whilelt p0.h, XZR, x20\n" + "ld1h { z18.h }, p0/Z, [x25, #5, MUL VL]\n" + "dech x20\n" + "whilelt p0.h, XZR, x20\n" + "ld1h { z17.h }, p0/Z, [x25, #6, MUL VL]\n" + "dech x20\n" + "dech x21, ALL, MUL #8\n" + "whilelt p0.h, XZR, x20\n" + "cmp x21, #0x0\n" + "ld1h { z16.h }, p0/Z, [x25, #7, MUL VL]\n" + "st1h { z23.h }, p7, [x23]\n" + "addvl x25, x25, #8\n" + "st1h { z22.h }, p7, [x23, #1, MUL VL]\n" + "st1h { z21.h }, p7, [x23, #2, MUL VL]\n" + "st1h { z20.h }, p7, [x23, #3, MUL VL]\n" + "st1h { z19.h }, p7, [x23, #4, MUL VL]\n" + "st1h { z18.h }, p7, [x23, #5, MUL VL]\n" + "st1h { z17.h }, p7, [x23, #6, MUL VL]\n" + "st1h { z16.h }, p7, [x23, #7, MUL VL]\n" + "add x23, x23, %x[out_stride]\n" + "bgt 6b\n" + "7:" // Tail row loop: Column loop skip + "cmp %x[height], #0x1\n" + "addvl %x[out], %x[out], #8\n" + "bge 5b\n" + "8:" // Done + ".inst 0xd503467f // SMSTOP\n" + : [height] "+&r" (height), [in] "+&r" (in), [out] "+&r" (out) + : [in_stride] "r" (in_stride), [out_stride] "r" (out_stride), [width] "r" (width) + : "cc", "memory", "p0", "p1", "p2", "p3", "p4", "p5", "p6", "p7", "p8", "p9", "p10", "p11", "p12", "p13", "p14", "p15", "x20", "x21", "x22", "x23", "x24", "x25", "z0", "z1", "z2", "z3", "z4", "z5", "z6", "z7", "z8", "z9", "z10", "z11", "z12", "z13", "z14", "z15", "z16", "z17", "z18", "z19", "z20", "z21", "z22", "z23", "z24", "z25", "z26", "z27", "z28", "z29", "z30", "z31" + ); +} + +} // anonymous namespace + +template<> +void Transform<8, 1, true, VLType::SME>( + float *out, const float *in, int stride, int x0, int xmax, int k0, int kmax) +{ + sme_transpose_interleave_8VL( + reinterpret_cast<uint16_t *>(out), + reinterpret_cast<const uint16_t *>(in + k0 * stride + x0), + (xmax-x0) * sizeof(float) / 2, + stride * sizeof(float), + (kmax-k0) + ); +} + +template<> +void Transform<8, 1, true, VLType::SME>( + bfloat16 *out, const bfloat16 *in, int stride, int x0, int xmax, int k0, int kmax) +{ + sme_transpose_interleave_8VL( + reinterpret_cast<uint16_t *>(out), + reinterpret_cast<const uint16_t *>(in + k0 * stride + x0), + (xmax-x0) * sizeof(bfloat16) / 2, + stride * sizeof(bfloat16), + (kmax-k0) + ); +} + +template<> +void Transform<8, 1, true, VLType::SME>( + __fp16 *out, const __fp16 *in, int stride, int x0, int xmax, int k0, int kmax) +{ + sme_transpose_interleave_8VL( + reinterpret_cast<uint16_t *>(out), + reinterpret_cast<const uint16_t *>(in + k0 * stride + x0), + (xmax-x0) * sizeof(__fp16) / 2, + stride * sizeof(__fp16), + (kmax-k0) + ); +} + + +#endif // defined(ARM_COMPUTE_ENABLE_SME) diff --git a/src/core/NEON/kernels/arm_gemm/transforms/sme_transpose_interleave_8VL_1x4.hpp b/src/core/NEON/kernels/arm_gemm/transforms/sme_transpose_interleave_8VL_1x4.hpp new file mode 100644 index 0000000000..45d2e24258 --- /dev/null +++ b/src/core/NEON/kernels/arm_gemm/transforms/sme_transpose_interleave_8VL_1x4.hpp @@ -0,0 +1,143 @@ +/* + * Copyright (c) 2023 Arm Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ + +#pragma once + +#if defined(ARM_COMPUTE_ENABLE_SME) + +namespace { + +void sme_transpose_interleave_8VL_1x4(uint8_t *out, const uint8_t *in, size_t width, size_t in_stride, size_t height) +{ + uint8_t *pad_row = reinterpret_cast<uint8_t *>(alloca(width * sizeof(uint8_t))); + + if (height % 4) { + memset(pad_row, 0, width * sizeof(uint8_t)); + } + + size_t out_stride = 8 * roundup<size_t>(height, 4) * sme::get_vector_length<uint32_t>(); + + __asm__ __volatile__( + ".inst 0xd503477f // SMSTART ZA\n" + "ptrue p2.b\n" + "1:" // Main row loop: Head + "mov x26, %x[in]\n" + "add x25, x26, %x[in_stride]\n" + "add x24, x25, %x[in_stride]\n" + "add x23, x24, %x[in_stride]\n" + "cmp %x[height], #0x3\n" + "add %x[in], x23, %x[in_stride]\n" + "csel x23, x23, %x[pad_row], GT\n" + "csel x24, x24, %x[pad_row], GE\n" + "cmp %x[height], #0x1\n" + "mov x22, %x[out]\n" + "csel x25, x25, %x[pad_row], GT\n" + "sub %x[height], %x[height], #0x4\n" + "mov x21, %x[width]\n" + "2:" // Main row loop: Column loop + "mov x20, x21\n" + "whilelt p1.b, XZR, x20\n" + "ld1b { z19.b }, p1/Z, [x26]\n" + "decb x20\n" + "whilelt p0.b, XZR, x20\n" + "ld1b { z17.b }, p0/Z, [x26, #1, MUL VL]\n" + "ld1b { z18.b }, p1/Z, [x25]\n" + "decw x21, ALL, MUL #8\n" + "cmp x21, #0x0\n" + "ld1b { z21.b }, p0/Z, [x25, #1, MUL VL]\n" + "addvl x26, x26, #2\n" + "addvl x25, x25, #2\n" + "ld1b { z16.b }, p1/Z, [x24]\n" + "zip1 z24.b, z19.b, z16.b\n" + "zip2 z20.b, z19.b, z16.b\n" + "ld1b { z16.b }, p0/Z, [x24, #1, MUL VL]\n" + "zip1 z23.b, z17.b, z16.b\n" + "zip2 z22.b, z17.b, z16.b\n" + "addvl x24, x24, #2\n" + "ld1b { z16.b }, p1/Z, [x23]\n" + "zip1 z17.b, z18.b, z16.b\n" + "zip2 z19.b, z18.b, z16.b\n" + "ld1b { z16.b }, p0/Z, [x23, #1, MUL VL]\n" + "zip1 z18.b, z21.b, z16.b\n" + "zip2 z21.b, z21.b, z16.b\n" + "addvl x23, x23, #2\n" + "zip1 z16.b, z24.b, z17.b\n" + "zip2 z17.b, z24.b, z17.b\n" + "st1b { z16.b }, p2, [x22]\n" + "zip1 z16.b, z20.b, z19.b\n" + "zip2 z20.b, z20.b, z19.b\n" + "st1b { z17.b }, p2, [x22, #1, MUL VL]\n" + "zip1 z19.b, z23.b, z18.b\n" + "zip2 z18.b, z23.b, z18.b\n" + "st1b { z16.b }, p2, [x22, #2, MUL VL]\n" + "zip1 z17.b, z22.b, z21.b\n" + "zip2 z16.b, z22.b, z21.b\n" + "st1b { z20.b }, p2, [x22, #3, MUL VL]\n" + "st1b { z19.b }, p2, [x22, #4, MUL VL]\n" + "st1b { z18.b }, p2, [x22, #5, MUL VL]\n" + "st1b { z17.b }, p2, [x22, #6, MUL VL]\n" + "st1b { z16.b }, p2, [x22, #7, MUL VL]\n" + "add x22, x22, %x[out_stride]\n" + "bgt 2b\n" + "3:" // Main row loop: Column loop skip + "cmp %x[height], #0x1\n" + "addvl %x[out], %x[out], #8\n" + "bge 1b\n" + ".inst 0xd503467f // SMSTOP\n" + : [height] "+&r" (height), [in] "+&r" (in), [out] "+&r" (out) + : [in_stride] "r" (in_stride), [out_stride] "r" (out_stride), [pad_row] "r" (pad_row), [width] "r" (width) + : "cc", "memory", "p0", "p1", "p2", "p3", "p4", "p5", "p6", "p7", "p8", "p9", "p10", "p11", "p12", "p13", "p14", "p15", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "z0", "z1", "z2", "z3", "z4", "z5", "z6", "z7", "z8", "z9", "z10", "z11", "z12", "z13", "z14", "z15", "z16", "z17", "z18", "z19", "z20", "z21", "z22", "z23", "z24", "z25", "z26", "z27", "z28", "z29", "z30", "z31" + ); +} + +} // anonymous namespace + +template<> +void Transform<8, 4, true, VLType::SME>( + uint8_t *out, const uint8_t *in, int stride, int x0, int xmax, int k0, int kmax) +{ + sme_transpose_interleave_8VL_1x4( + reinterpret_cast<uint8_t *>(out), + reinterpret_cast<const uint8_t *>(in + k0 * stride + x0), + (xmax-x0) * sizeof(uint8_t) / 1, + stride * sizeof(uint8_t), + (kmax-k0) + ); +} + +template<> +void Transform<8, 4, true, VLType::SME>( + int8_t *out, const int8_t *in, int stride, int x0, int xmax, int k0, int kmax) +{ + sme_transpose_interleave_8VL_1x4( + reinterpret_cast<uint8_t *>(out), + reinterpret_cast<const uint8_t *>(in + k0 * stride + x0), + (xmax-x0) * sizeof(int8_t) / 1, + stride * sizeof(int8_t), + (kmax-k0) + ); +} + + +#endif // defined(ARM_COMPUTE_ENABLE_SME) diff --git a/src/core/NEON/kernels/arm_gemm/transforms/sme_transpose_interleave_8VL_2x2.hpp b/src/core/NEON/kernels/arm_gemm/transforms/sme_transpose_interleave_8VL_2x2.hpp new file mode 100644 index 0000000000..ec7c415e27 --- /dev/null +++ b/src/core/NEON/kernels/arm_gemm/transforms/sme_transpose_interleave_8VL_2x2.hpp @@ -0,0 +1,132 @@ +/* + * Copyright (c) 2023 Arm Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ + +#pragma once + +#if defined(ARM_COMPUTE_ENABLE_SME) + +namespace { + +void sme_transpose_interleave_8VL_2x2(uint16_t *out, const uint16_t *in, size_t width, size_t in_stride, size_t height) +{ + uint16_t *pad_row = reinterpret_cast<uint16_t *>(alloca(width * sizeof(uint16_t))); + + if (height % 2) { + memset(pad_row, 0, width * sizeof(uint16_t)); + } + + size_t out_stride = 8 * roundup<size_t>(height, 2) * sme::get_vector_length<uint16_t>(); + + __asm__ __volatile__( + ".inst 0xd503477f // SMSTART ZA\n" + "ptrue p4.b\n" + "1:" // Main row loop: Head + "mov x24, %x[in]\n" + "add x23, x24, %x[in_stride]\n" + "cmp %x[height], #0x1\n" + "add %x[in], x23, %x[in_stride]\n" + "mov x22, %x[out]\n" + "csel x23, x23, %x[pad_row], GT\n" + "sub %x[height], %x[height], #0x2\n" + "mov x21, %x[width]\n" + "2:" // Main row loop: Column loop + "mov x20, x21\n" + "whilelt p3.h, XZR, x20\n" + "ld1h { z20.h }, p3/Z, [x24]\n" + "dech x20\n" + "whilelt p2.h, XZR, x20\n" + "ld1h { z19.h }, p2/Z, [x24, #1, MUL VL]\n" + "dech x20\n" + "whilelt p1.h, XZR, x20\n" + "ld1h { z18.h }, p1/Z, [x24, #2, MUL VL]\n" + "dech x20\n" + "whilelt p0.h, XZR, x20\n" + "ld1h { z24.h }, p0/Z, [x24, #3, MUL VL]\n" + "ld1h { z17.h }, p3/Z, [x23]\n" + "decw x21, ALL, MUL #8\n" + "cmp x21, #0x0\n" + "zip1 z23.h, z20.h, z17.h\n" + "ld1h { z16.h }, p2/Z, [x23, #1, MUL VL]\n" + "addvl x24, x24, #4\n" + "zip2 z22.h, z20.h, z17.h\n" + "zip1 z21.h, z19.h, z16.h\n" + "ld1h { z17.h }, p1/Z, [x23, #2, MUL VL]\n" + "zip2 z20.h, z19.h, z16.h\n" + "zip1 z19.h, z18.h, z17.h\n" + "ld1h { z16.h }, p0/Z, [x23, #3, MUL VL]\n" + "addvl x23, x23, #4\n" + "zip2 z18.h, z18.h, z17.h\n" + "zip1 z17.h, z24.h, z16.h\n" + "zip2 z16.h, z24.h, z16.h\n" + "st1h { z23.h }, p4, [x22]\n" + "st1h { z22.h }, p4, [x22, #1, MUL VL]\n" + "st1h { z21.h }, p4, [x22, #2, MUL VL]\n" + "st1h { z20.h }, p4, [x22, #3, MUL VL]\n" + "st1h { z19.h }, p4, [x22, #4, MUL VL]\n" + "st1h { z18.h }, p4, [x22, #5, MUL VL]\n" + "st1h { z17.h }, p4, [x22, #6, MUL VL]\n" + "st1h { z16.h }, p4, [x22, #7, MUL VL]\n" + "add x22, x22, %x[out_stride]\n" + "bgt 2b\n" + "3:" // Main row loop: Column loop skip + "cmp %x[height], #0x1\n" + "addvl %x[out], %x[out], #8\n" + "bge 1b\n" + ".inst 0xd503467f // SMSTOP\n" + : [height] "+&r" (height), [in] "+&r" (in), [out] "+&r" (out) + : [in_stride] "r" (in_stride), [out_stride] "r" (out_stride), [pad_row] "r" (pad_row), [width] "r" (width) + : "cc", "memory", "p0", "p1", "p2", "p3", "p4", "p5", "p6", "p7", "p8", "p9", "p10", "p11", "p12", "p13", "p14", "p15", "x20", "x21", "x22", "x23", "x24", "z0", "z1", "z2", "z3", "z4", "z5", "z6", "z7", "z8", "z9", "z10", "z11", "z12", "z13", "z14", "z15", "z16", "z17", "z18", "z19", "z20", "z21", "z22", "z23", "z24", "z25", "z26", "z27", "z28", "z29", "z30", "z31" + ); +} + +} // anonymous namespace + +template<> +void Transform<8, 2, true, VLType::SME>( + bfloat16 *out, const bfloat16 *in, int stride, int x0, int xmax, int k0, int kmax) +{ + sme_transpose_interleave_8VL_2x2( + reinterpret_cast<uint16_t *>(out), + reinterpret_cast<const uint16_t *>(in + k0 * stride + x0), + (xmax-x0) * sizeof(bfloat16) / 2, + stride * sizeof(bfloat16), + (kmax-k0) + ); +} + +template<> +void Transform<8, 2, true, VLType::SME>( + __fp16 *out, const __fp16 *in, int stride, int x0, int xmax, int k0, int kmax) +{ + sme_transpose_interleave_8VL_2x2( + reinterpret_cast<uint16_t *>(out), + reinterpret_cast<const uint16_t *>(in + k0 * stride + x0), + (xmax-x0) * sizeof(__fp16) / 2, + stride * sizeof(__fp16), + (kmax-k0) + ); +} + + +#endif // defined(ARM_COMPUTE_ENABLE_SME) diff --git a/src/core/NEON/kernels/arm_gemm/transforms/sve_interleave_8way_32bit.hpp b/src/core/NEON/kernels/arm_gemm/transforms/sve_interleave_8way_32bit.hpp deleted file mode 100644 index 881dc7bb72..0000000000 --- a/src/core/NEON/kernels/arm_gemm/transforms/sve_interleave_8way_32bit.hpp +++ /dev/null @@ -1,596 +0,0 @@ -/* - * Copyright (c) 2018-2019 Arm Limited. - * - * SPDX-License-Identifier: MIT - * - * Permission is hereby granted, free of charge, to any person obtaining a copy - * of this software and associated documentation files (the "Software"), to - * deal in the Software without restriction, including without limitation the - * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or - * sell copies of the Software, and to permit persons to whom the Software is - * furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in all - * copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE - * SOFTWARE. - */ -#pragma once - -#ifdef __ARM_FEATURE_SVE - -template<> -template<typename T> -inline void TransformImpl<8, 1, false, 4, 4, false>::Transform(T *out, const T *in, int ldin, int y0, int ymax, int k0, int kmax) -{ - uint32_t *master_outptr = reinterpret_cast<uint32_t *>(out); - const uint32_t *inptr = reinterpret_cast<const uint32_t *>(in); - - for (int y=y0; y<ymax; y+=8) - { - const int height = ymax-y; - const long inwidth = (kmax - k0); - const long outwidth = inwidth * 8; - long inpos = 0; - long outpos = 0; - - uint32_t *outptr = master_outptr; - master_outptr += outwidth; - - const uint32_t *inptr0 = inptr + y * ldin + k0; - const uint32_t *inptr1 = inptr0 + ldin; - const uint32_t *inptr2 = inptr1 + ldin; - const uint32_t *inptr3 = inptr2 + ldin; - const uint32_t *inptr4 = inptr3 + ldin; - const uint32_t *inptr5 = inptr4 + ldin; - const uint32_t *inptr6 = inptr5 + ldin; - const uint32_t *inptr7 = inptr6 + ldin; - - switch(height) - { - case 1: - __asm __volatile( - "1:\n" - "whilelt p0.s, %[inpos], %[inwidth]\n" - "b.none 2f\n" - "mov z4.s, #0\n" - "ld1w z0.s, p0/z, [%[inptr0], %[inpos], LSL #2]\n" - "incw %[inpos], all, mul #1\n" - "whilelt p0.s, %[outpos], %[outwidth]\n" - "incw %[outpos], all, mul #1\n" - "zip1 z8.s, z0.s, z4.s\n" - "zip2 z9.s, z0.s, z4.s\n" - "whilelt p1.s, %[outpos], %[outwidth]\n" - "zip1 z0.s, z8.s, z4.s\n" - "incw %[outpos], all, mul #1\n" - "zip2 z1.s, z8.s, z4.s\n" - "zip1 z2.s, z9.s, z4.s\n" - "zip2 z3.s, z9.s, z4.s\n" - "whilelt p2.s, %[outpos], %[outwidth]\n" - "zip1 z8.s, z0.s, z4.s\n" - "incw %[outpos], all, mul #1\n" - "zip2 z9.s, z0.s, z4.s\n" - "zip1 z10.s, z1.s, z4.s\n" - "st1w z8.s, p0, [%[outptr]]\n" - "zip2 z11.s, z1.s, z4.s\n" - "whilelt p3.s, %[outpos], %[outwidth]\n" - "zip1 z12.s, z2.s, z4.s\n" - "st1w z9.s, p1, [%[outptr], #1, MUL VL]\n" - "zip2 z13.s, z2.s, z4.s\n" - "incw %[outpos], all, mul #1\n" - "zip1 z14.s, z3.s, z4.s\n" - "st1w z10.s, p2, [%[outptr], #2, MUL VL]\n" - "zip2 z15.s, z3.s, z4.s\n" - "whilelt p4.s, %[outpos], %[outwidth]\n" - "st1w z11.s, p3, [%[outptr], #3, MUL VL]\n" - "incw %[outpos], all, mul #1\n" - "st1w z12.s, p4, [%[outptr], #4, MUL VL]\n" - "whilelt p5.s, %[outpos], %[outwidth]\n" - "incw %[outpos], all, mul #1\n" - "st1w z13.s, p5, [%[outptr], #5, MUL VL]\n" - "whilelt p6.s, %[outpos], %[outwidth]\n" - "incw %[outpos], all, mul #1\n" - "st1w z14.s, p6, [%[outptr], #6, MUL VL]\n" - "whilelt p7.s, %[outpos], %[outwidth]\n" - "incw %[outpos], all, mul #1\n" - "st1w z15.s, p7, [%[outptr], #7, MUL VL]\n" - "addvl %[outptr], %[outptr], #8\n" - "b 1b\n" - "2:\n" - : [inpos] "+r" (inpos), [outpos] "+r" (outpos), [outptr] "+r" (outptr), [inptr0] "+r" (inptr0) - : [outwidth] "r" (outwidth), [inwidth] "r" (inwidth) - : "p0", "p1", "p2", "p3", "p4", "p5", "p6", "p7", "z0", "z1", "z2", "z3", "z4", "z5", "z6", "z7", "z8", "z9", "z10", "z11", "z12", "z13", "z14", "z15", "cc", "memory" - ); - break; - - case 2: - __asm __volatile( - "1:\n" - "whilelt p0.s, %[inpos], %[inwidth]\n" - "b.none 2f\n" - "mov z4.s, #0\n" - "mov z14.s, #0\n" - "ld1w z0.s, p0/z, [%[inptr0], %[inpos], LSL #2]\n" - "ld1w z1.s, p0/z, [%[inptr1], %[inpos], LSL #2]\n" - "incw %[inpos], all, mul #1\n" - "whilelt p0.s, %[outpos], %[outwidth]\n" - "zip1 z8.s, z0.s, z4.s\n" - "incw %[outpos], all, mul #1\n" - "zip2 z9.s, z0.s, z4.s\n" - "zip1 z10.s, z1.s, z4.s\n" - "zip2 z11.s, z1.s, z4.s\n" - "whilelt p1.s, %[outpos], %[outwidth]\n" - "zip1 z0.s, z8.s, z4.s\n" - "incw %[outpos], all, mul #1\n" - "zip2 z1.s, z8.s, z4.s\n" - "zip1 z2.s, z9.s, z4.s\n" - "zip2 z3.s, z9.s, z4.s\n" - "whilelt p2.s, %[outpos], %[outwidth]\n" - "zip1 z4.s, z10.s, z14.s\n" - "incw %[outpos], all, mul #1\n" - "zip2 z5.s, z10.s, z14.s\n" - "zip1 z6.s, z11.s, z14.s\n" - "zip2 z7.s, z11.s, z14.s\n" - "whilelt p3.s, %[outpos], %[outwidth]\n" - "zip1 z8.s, z0.s, z4.s\n" - "incw %[outpos], all, mul #1\n" - "zip2 z9.s, z0.s, z4.s\n" - "zip1 z10.s, z1.s, z5.s\n" - "st1w z8.s, p0, [%[outptr]]\n" - "zip2 z11.s, z1.s, z5.s\n" - "whilelt p4.s, %[outpos], %[outwidth]\n" - "zip1 z12.s, z2.s, z6.s\n" - "st1w z9.s, p1, [%[outptr], #1, MUL VL]\n" - "zip2 z13.s, z2.s, z6.s\n" - "incw %[outpos], all, mul #1\n" - "zip1 z14.s, z3.s, z7.s\n" - "st1w z10.s, p2, [%[outptr], #2, MUL VL]\n" - "zip2 z15.s, z3.s, z7.s\n" - "whilelt p5.s, %[outpos], %[outwidth]\n" - "st1w z11.s, p3, [%[outptr], #3, MUL VL]\n" - "incw %[outpos], all, mul #1\n" - "st1w z12.s, p4, [%[outptr], #4, MUL VL]\n" - "whilelt p6.s, %[outpos], %[outwidth]\n" - "incw %[outpos], all, mul #1\n" - "st1w z13.s, p5, [%[outptr], #5, MUL VL]\n" - "whilelt p7.s, %[outpos], %[outwidth]\n" - "incw %[outpos], all, mul #1\n" - "st1w z14.s, p6, [%[outptr], #6, MUL VL]\n" - "st1w z15.s, p7, [%[outptr], #7, MUL VL]\n" - "addvl %[outptr], %[outptr], #8\n" - "b 1b\n" - "2:\n" - : [inpos] "+r" (inpos), [outpos] "+r" (outpos), [outptr] "+r" (outptr), [inptr0] "+r" (inptr0), [inptr1] "+r" (inptr1) - : [outwidth] "r" (outwidth), [inwidth] "r" (inwidth) - : "p0", "p1", "p2", "p3", "p4", "p5", "p6", "p7", "z0", "z1", "z2", "z3", "z4", "z5", "z6", "z7", "z8", "z9", "z10", "z11", "z12", "z13", "z14", "z15", "cc", "memory" - ); - break; - - case 3: - __asm __volatile( - "1:\n" - "whilelt p0.s, %[inpos], %[inwidth]\n" - "b.none 2f\n" - "mov z4.s, #0\n" - "mov z14.s, #0\n" - "ld1w z0.s, p0/z, [%[inptr0], %[inpos], LSL #2]\n" - "ld1w z1.s, p0/z, [%[inptr1], %[inpos], LSL #2]\n" - "ld1w z2.s, p0/z, [%[inptr2], %[inpos], LSL #2]\n" - "incw %[inpos], all, mul #1\n" - "zip1 z8.s, z0.s, z4.s\n" - "whilelt p0.s, %[outpos], %[outwidth]\n" - "zip2 z9.s, z0.s, z4.s\n" - "incw %[outpos], all, mul #1\n" - "zip1 z10.s, z1.s, z4.s\n" - "zip2 z11.s, z1.s, z4.s\n" - "zip1 z12.s, z2.s, z4.s\n" - "whilelt p1.s, %[outpos], %[outwidth]\n" - "zip2 z13.s, z2.s, z4.s\n" - "incw %[outpos], all, mul #1\n" - "zip1 z0.s, z8.s, z12.s\n" - "zip2 z1.s, z8.s, z12.s\n" - "zip1 z2.s, z9.s, z13.s\n" - "whilelt p2.s, %[outpos], %[outwidth]\n" - "zip2 z3.s, z9.s, z13.s\n" - "incw %[outpos], all, mul #1\n" - "zip1 z4.s, z10.s, z14.s\n" - "zip2 z5.s, z10.s, z14.s\n" - "zip1 z6.s, z11.s, z14.s\n" - "whilelt p3.s, %[outpos], %[outwidth]\n" - "zip2 z7.s, z11.s, z14.s\n" - "incw %[outpos], all, mul #1\n" - "zip1 z8.s, z0.s, z4.s\n" - "zip2 z9.s, z0.s, z4.s\n" - "zip1 z10.s, z1.s, z5.s\n" - "whilelt p4.s, %[outpos], %[outwidth]\n" - "zip2 z11.s, z1.s, z5.s\n" - "st1w z8.s, p0, [%[outptr]]\n" - "zip1 z12.s, z2.s, z6.s\n" - "incw %[outpos], all, mul #1\n" - "zip2 z13.s, z2.s, z6.s\n" - "st1w z9.s, p1, [%[outptr], #1, MUL VL]\n" - "zip1 z14.s, z3.s, z7.s\n" - "zip2 z15.s, z3.s, z7.s\n" - "whilelt p5.s, %[outpos], %[outwidth]\n" - "st1w z10.s, p2, [%[outptr], #2, MUL VL]\n" - "incw %[outpos], all, mul #1\n" - "st1w z11.s, p3, [%[outptr], #3, MUL VL]\n" - "whilelt p6.s, %[outpos], %[outwidth]\n" - "incw %[outpos], all, mul #1\n" - "st1w z12.s, p4, [%[outptr], #4, MUL VL]\n" - "whilelt p7.s, %[outpos], %[outwidth]\n" - "incw %[outpos], all, mul #1\n" - "st1w z13.s, p5, [%[outptr], #5, MUL VL]\n" - "st1w z14.s, p6, [%[outptr], #6, MUL VL]\n" - "st1w z15.s, p7, [%[outptr], #7, MUL VL]\n" - "addvl %[outptr], %[outptr], #8\n" - "b 1b\n" - "2:\n" - : [inpos] "+r" (inpos), [outpos] "+r" (outpos), [outptr] "+r" (outptr), [inptr0] "+r" (inptr0), [inptr1] "+r" (inptr1), [inptr2] "+r" (inptr2) - : [outwidth] "r" (outwidth), [inwidth] "r" (inwidth) - : "p0", "p1", "p2", "p3", "p4", "p5", "p6", "p7", "z0", "z1", "z2", "z3", "z4", "z5", "z6", "z7", "z8", "z9", "z10", "z11", "z12", "z13", "z14", "z15", "cc", "memory" - ); - break; - - case 4: - __asm __volatile( - "1:\n" - "whilelt p0.s, %[inpos], %[inwidth]\n" - "b.none 2f\n" - "mov z4.s, #0\n" - "ld1w z0.s, p0/z, [%[inptr0], %[inpos], LSL #2]\n" - "ld1w z1.s, p0/z, [%[inptr1], %[inpos], LSL #2]\n" - "ld1w z2.s, p0/z, [%[inptr2], %[inpos], LSL #2]\n" - "ld1w z3.s, p0/z, [%[inptr3], %[inpos], LSL #2]\n" - "incw %[inpos], all, mul #1\n" - "zip1 z8.s, z0.s, z4.s\n" - "whilelt p0.s, %[outpos], %[outwidth]\n" - "zip2 z9.s, z0.s, z4.s\n" - "incw %[outpos], all, mul #1\n" - "zip1 z10.s, z1.s, z4.s\n" - "zip2 z11.s, z1.s, z4.s\n" - "zip1 z12.s, z2.s, z4.s\n" - "whilelt p1.s, %[outpos], %[outwidth]\n" - "zip2 z13.s, z2.s, z4.s\n" - "incw %[outpos], all, mul #1\n" - "zip1 z14.s, z3.s, z4.s\n" - "zip2 z15.s, z3.s, z4.s\n" - "zip1 z0.s, z8.s, z12.s\n" - "whilelt p2.s, %[outpos], %[outwidth]\n" - "zip2 z1.s, z8.s, z12.s\n" - "incw %[outpos], all, mul #1\n" - "zip1 z2.s, z9.s, z13.s\n" - "zip2 z3.s, z9.s, z13.s\n" - "zip1 z4.s, z10.s, z14.s\n" - "whilelt p3.s, %[outpos], %[outwidth]\n" - "zip2 z5.s, z10.s, z14.s\n" - "incw %[outpos], all, mul #1\n" - "zip1 z6.s, z11.s, z15.s\n" - "zip2 z7.s, z11.s, z15.s\n" - "zip1 z8.s, z0.s, z4.s\n" - "whilelt p4.s, %[outpos], %[outwidth]\n" - "zip2 z9.s, z0.s, z4.s\n" - "incw %[outpos], all, mul #1\n" - "zip1 z10.s, z1.s, z5.s\n" - "st1w z8.s, p0, [%[outptr]]\n" - "zip2 z11.s, z1.s, z5.s\n" - "zip1 z12.s, z2.s, z6.s\n" - "whilelt p5.s, %[outpos], %[outwidth]\n" - "zip2 z13.s, z2.s, z6.s\n" - "st1w z9.s, p1, [%[outptr], #1, MUL VL]\n" - "zip1 z14.s, z3.s, z7.s\n" - "incw %[outpos], all, mul #1\n" - "zip2 z15.s, z3.s, z7.s\n" - "st1w z10.s, p2, [%[outptr], #2, MUL VL]\n" - "whilelt p6.s, %[outpos], %[outwidth]\n" - "st1w z11.s, p3, [%[outptr], #3, MUL VL]\n" - "incw %[outpos], all, mul #1\n" - "st1w z12.s, p4, [%[outptr], #4, MUL VL]\n" - "whilelt p7.s, %[outpos], %[outwidth]\n" - "incw %[outpos], all, mul #1\n" - "st1w z13.s, p5, [%[outptr], #5, MUL VL]\n" - "st1w z14.s, p6, [%[outptr], #6, MUL VL]\n" - "st1w z15.s, p7, [%[outptr], #7, MUL VL]\n" - "addvl %[outptr], %[outptr], #8\n" - "b 1b\n" - "2:\n" - : [inpos] "+r" (inpos), [outpos] "+r" (outpos), [outptr] "+r" (outptr), [inptr0] "+r" (inptr0), [inptr1] "+r" (inptr1), [inptr2] "+r" (inptr2), [inptr3] "+r" (inptr3) - : [outwidth] "r" (outwidth), [inwidth] "r" (inwidth) - : "p0", "p1", "p2", "p3", "p4", "p5", "p6", "p7", "z0", "z1", "z2", "z3", "z4", "z5", "z6", "z7", "z8", "z9", "z10", "z11", "z12", "z13", "z14", "z15", "cc", "memory" - ); - break; - - case 5: - __asm __volatile( - "1:\n" - "whilelt p0.s, %[inpos], %[inwidth]\n" - "b.none 2f\n" - "mov z5.s, #0\n" - "ld1w z0.s, p0/z, [%[inptr0], %[inpos], LSL #2]\n" - "ld1w z1.s, p0/z, [%[inptr1], %[inpos], LSL #2]\n" - "ld1w z2.s, p0/z, [%[inptr2], %[inpos], LSL #2]\n" - "ld1w z3.s, p0/z, [%[inptr3], %[inpos], LSL #2]\n" - "ld1w z4.s, p0/z, [%[inptr4], %[inpos], LSL #2]\n" - "incw %[inpos], all, mul #1\n" - "zip1 z10.s, z1.s, z5.s\n" - "whilelt p0.s, %[outpos], %[outwidth]\n" - "zip1 z8.s, z0.s, z4.s\n" - "incw %[outpos], all, mul #1\n" - "zip2 z9.s, z0.s, z4.s\n" - "zip2 z11.s, z1.s, z5.s\n" - "zip1 z12.s, z2.s, z5.s\n" - "whilelt p1.s, %[outpos], %[outwidth]\n" - "zip2 z13.s, z2.s, z5.s\n" - "incw %[outpos], all, mul #1\n" - "zip1 z14.s, z3.s, z5.s\n" - "zip2 z15.s, z3.s, z5.s\n" - "zip1 z0.s, z8.s, z12.s\n" - "whilelt p2.s, %[outpos], %[outwidth]\n" - "zip2 z1.s, z8.s, z12.s\n" - "incw %[outpos], all, mul #1\n" - "zip1 z2.s, z9.s, z13.s\n" - "zip2 z3.s, z9.s, z13.s\n" - "zip1 z4.s, z10.s, z14.s\n" - "whilelt p3.s, %[outpos], %[outwidth]\n" - "zip2 z5.s, z10.s, z14.s\n" - "incw %[outpos], all, mul #1\n" - "zip1 z6.s, z11.s, z15.s\n" - "zip2 z7.s, z11.s, z15.s\n" - "zip1 z8.s, z0.s, z4.s\n" - "whilelt p4.s, %[outpos], %[outwidth]\n" - "zip2 z9.s, z0.s, z4.s\n" - "incw %[outpos], all, mul #1\n" - "zip1 z10.s, z1.s, z5.s\n" - "st1w z8.s, p0, [%[outptr]]\n" - "zip2 z11.s, z1.s, z5.s\n" - "zip1 z12.s, z2.s, z6.s\n" - "whilelt p5.s, %[outpos], %[outwidth]\n" - "zip2 z13.s, z2.s, z6.s\n" - "st1w z9.s, p1, [%[outptr], #1, MUL VL]\n" - "zip1 z14.s, z3.s, z7.s\n" - "incw %[outpos], all, mul #1\n" - "zip2 z15.s, z3.s, z7.s\n" - "st1w z10.s, p2, [%[outptr], #2, MUL VL]\n" - "whilelt p6.s, %[outpos], %[outwidth]\n" - "st1w z11.s, p3, [%[outptr], #3, MUL VL]\n" - "incw %[outpos], all, mul #1\n" - "st1w z12.s, p4, [%[outptr], #4, MUL VL]\n" - "whilelt p7.s, %[outpos], %[outwidth]\n" - "incw %[outpos], all, mul #1\n" - "st1w z13.s, p5, [%[outptr], #5, MUL VL]\n" - "st1w z14.s, p6, [%[outptr], #6, MUL VL]\n" - "st1w z15.s, p7, [%[outptr], #7, MUL VL]\n" - "addvl %[outptr], %[outptr], #8\n" - "b 1b\n" - "2:\n" - : [inpos] "+r" (inpos), [outpos] "+r" (outpos), [outptr] "+r" (outptr), [inptr0] "+r" (inptr0), [inptr1] "+r" (inptr1), [inptr2] "+r" (inptr2), [inptr3] "+r" (inptr3), [inptr4] "+r" (inptr4) - : [outwidth] "r" (outwidth), [inwidth] "r" (inwidth) - : "p0", "p1", "p2", "p3", "p4", "p5", "p6", "p7", "z0", "z1", "z2", "z3", "z4", "z5", "z6", "z7", "z8", "z9", "z10", "z11", "z12", "z13", "z14", "z15", "cc", "memory" - ); - break; - - case 6: - __asm __volatile( - "1:\n" - "whilelt p0.s, %[inpos], %[inwidth]\n" - "b.none 2f\n" - "mov z6.s, #0\n" - "ld1w z0.s, p0/z, [%[inptr0], %[inpos], LSL #2]\n" - "ld1w z1.s, p0/z, [%[inptr1], %[inpos], LSL #2]\n" - "ld1w z2.s, p0/z, [%[inptr2], %[inpos], LSL #2]\n" - "ld1w z3.s, p0/z, [%[inptr3], %[inpos], LSL #2]\n" - "ld1w z4.s, p0/z, [%[inptr4], %[inpos], LSL #2]\n" - "ld1w z5.s, p0/z, [%[inptr5], %[inpos], LSL #2]\n" - "incw %[inpos], all, mul #1\n" - "zip1 z12.s, z2.s, z6.s\n" - "whilelt p0.s, %[outpos], %[outwidth]\n" - "zip1 z8.s, z0.s, z4.s\n" - "incw %[outpos], all, mul #1\n" - "zip2 z9.s, z0.s, z4.s\n" - "zip1 z10.s, z1.s, z5.s\n" - "zip2 z11.s, z1.s, z5.s\n" - "whilelt p1.s, %[outpos], %[outwidth]\n" - "zip2 z13.s, z2.s, z6.s\n" - "incw %[outpos], all, mul #1\n" - "zip1 z14.s, z3.s, z6.s\n" - "zip2 z15.s, z3.s, z6.s\n" - "zip1 z0.s, z8.s, z12.s\n" - "whilelt p2.s, %[outpos], %[outwidth]\n" - "zip2 z1.s, z8.s, z12.s\n" - "incw %[outpos], all, mul #1\n" - "zip1 z2.s, z9.s, z13.s\n" - "zip2 z3.s, z9.s, z13.s\n" - "zip1 z4.s, z10.s, z14.s\n" - "whilelt p3.s, %[outpos], %[outwidth]\n" - "zip2 z5.s, z10.s, z14.s\n" - "incw %[outpos], all, mul #1\n" - "zip1 z6.s, z11.s, z15.s\n" - "zip2 z7.s, z11.s, z15.s\n" - "zip1 z8.s, z0.s, z4.s\n" - "whilelt p4.s, %[outpos], %[outwidth]\n" - "zip2 z9.s, z0.s, z4.s\n" - "incw %[outpos], all, mul #1\n" - "zip1 z10.s, z1.s, z5.s\n" - "st1w z8.s, p0, [%[outptr]]\n" - "zip2 z11.s, z1.s, z5.s\n" - "zip1 z12.s, z2.s, z6.s\n" - "whilelt p5.s, %[outpos], %[outwidth]\n" - "zip2 z13.s, z2.s, z6.s\n" - "st1w z9.s, p1, [%[outptr], #1, MUL VL]\n" - "zip1 z14.s, z3.s, z7.s\n" - "incw %[outpos], all, mul #1\n" - "zip2 z15.s, z3.s, z7.s\n" - "st1w z10.s, p2, [%[outptr], #2, MUL VL]\n" - "whilelt p6.s, %[outpos], %[outwidth]\n" - "st1w z11.s, p3, [%[outptr], #3, MUL VL]\n" - "incw %[outpos], all, mul #1\n" - "st1w z12.s, p4, [%[outptr], #4, MUL VL]\n" - "whilelt p7.s, %[outpos], %[outwidth]\n" - "incw %[outpos], all, mul #1\n" - "st1w z13.s, p5, [%[outptr], #5, MUL VL]\n" - "st1w z14.s, p6, [%[outptr], #6, MUL VL]\n" - "st1w z15.s, p7, [%[outptr], #7, MUL VL]\n" - "addvl %[outptr], %[outptr], #8\n" - "b 1b\n" - "2:\n" - : [inpos] "+r" (inpos), [outpos] "+r" (outpos), [outptr] "+r" (outptr), [inptr0] "+r" (inptr0), [inptr1] "+r" (inptr1), [inptr2] "+r" (inptr2), [inptr3] "+r" (inptr3), [inptr4] "+r" (inptr4), [inptr5] "+r" (inptr5) - : [outwidth] "r" (outwidth), [inwidth] "r" (inwidth) - : "p0", "p1", "p2", "p3", "p4", "p5", "p6", "p7", "z0", "z1", "z2", "z3", "z4", "z5", "z6", "z7", "z8", "z9", "z10", "z11", "z12", "z13", "z14", "z15", "cc", "memory" - ); - break; - - case 7: - __asm __volatile( - "1:\n" - "whilelt p0.s, %[inpos], %[inwidth]\n" - "b.none 2f\n" - "mov z7.s, #0\n" - "ld1w z0.s, p0/z, [%[inptr0], %[inpos], LSL #2]\n" - "ld1w z1.s, p0/z, [%[inptr1], %[inpos], LSL #2]\n" - "ld1w z2.s, p0/z, [%[inptr2], %[inpos], LSL #2]\n" - "ld1w z3.s, p0/z, [%[inptr3], %[inpos], LSL #2]\n" - "ld1w z4.s, p0/z, [%[inptr4], %[inpos], LSL #2]\n" - "ld1w z5.s, p0/z, [%[inptr5], %[inpos], LSL #2]\n" - "ld1w z6.s, p0/z, [%[inptr6], %[inpos], LSL #2]\n" - "incw %[inpos], all, mul #1\n" - "zip1 z8.s, z0.s, z4.s\n" - "whilelt p0.s, %[outpos], %[outwidth]\n" - "zip2 z9.s, z0.s, z4.s\n" - "incw %[outpos], all, mul #1\n" - "zip1 z10.s, z1.s, z5.s\n" - "zip2 z11.s, z1.s, z5.s\n" - "zip1 z12.s, z2.s, z6.s\n" - "whilelt p1.s, %[outpos], %[outwidth]\n" - "zip2 z13.s, z2.s, z6.s\n" - "incw %[outpos], all, mul #1\n" - "zip1 z14.s, z3.s, z7.s\n" - "zip2 z15.s, z3.s, z7.s\n" - "zip1 z0.s, z8.s, z12.s\n" - "whilelt p2.s, %[outpos], %[outwidth]\n" - "zip2 z1.s, z8.s, z12.s\n" - "incw %[outpos], all, mul #1\n" - "zip1 z2.s, z9.s, z13.s\n" - "zip2 z3.s, z9.s, z13.s\n" - "zip1 z4.s, z10.s, z14.s\n" - "whilelt p3.s, %[outpos], %[outwidth]\n" - "zip2 z5.s, z10.s, z14.s\n" - "incw %[outpos], all, mul #1\n" - "zip1 z6.s, z11.s, z15.s\n" - "zip2 z7.s, z11.s, z15.s\n" - "zip1 z8.s, z0.s, z4.s\n" - "whilelt p4.s, %[outpos], %[outwidth]\n" - "zip2 z9.s, z0.s, z4.s\n" - "incw %[outpos], all, mul #1\n" - "zip1 z10.s, z1.s, z5.s\n" - "st1w z8.s, p0, [%[outptr]]\n" - "zip2 z11.s, z1.s, z5.s\n" - "zip1 z12.s, z2.s, z6.s\n" - "whilelt p5.s, %[outpos], %[outwidth]\n" - "zip2 z13.s, z2.s, z6.s\n" - "st1w z9.s, p1, [%[outptr], #1, MUL VL]\n" - "zip1 z14.s, z3.s, z7.s\n" - "incw %[outpos], all, mul #1\n" - "zip2 z15.s, z3.s, z7.s\n" - "st1w z10.s, p2, [%[outptr], #2, MUL VL]\n" - "whilelt p6.s, %[outpos], %[outwidth]\n" - "st1w z11.s, p3, [%[outptr], #3, MUL VL]\n" - "incw %[outpos], all, mul #1\n" - "st1w z12.s, p4, [%[outptr], #4, MUL VL]\n" - "whilelt p7.s, %[outpos], %[outwidth]\n" - "incw %[outpos], all, mul #1\n" - "st1w z13.s, p5, [%[outptr], #5, MUL VL]\n" - "st1w z14.s, p6, [%[outptr], #6, MUL VL]\n" - "st1w z15.s, p7, [%[outptr], #7, MUL VL]\n" - "addvl %[outptr], %[outptr], #8\n" - "b 1b\n" - "2:\n" - : [inpos] "+r" (inpos), [outpos] "+r" (outpos), [outptr] "+r" (outptr), [inptr0] "+r" (inptr0), [inptr1] "+r" (inptr1), [inptr2] "+r" (inptr2), [inptr3] "+r" (inptr3), [inptr4] "+r" (inptr4), [inptr5] "+r" (inptr5), [inptr6] "+r" (inptr6) - : [outwidth] "r" (outwidth), [inwidth] "r" (inwidth) - : "p0", "p1", "p2", "p3", "p4", "p5", "p6", "p7", "z0", "z1", "z2", "z3", "z4", "z5", "z6", "z7", "z8", "z9", "z10", "z11", "z12", "z13", "z14", "z15", "cc", "memory" - ); - break; - - default: - case 8: - __asm __volatile( - "1:\n" - "whilelt p0.s, %[inpos], %[inwidth]\n" - "b.none 2f\n" - "ld1w z0.s, p0/z, [%[inptr0], %[inpos], LSL #2]\n" - "ld1w z1.s, p0/z, [%[inptr1], %[inpos], LSL #2]\n" - "ld1w z2.s, p0/z, [%[inptr2], %[inpos], LSL #2]\n" - "ld1w z3.s, p0/z, [%[inptr3], %[inpos], LSL #2]\n" - "ld1w z4.s, p0/z, [%[inptr4], %[inpos], LSL #2]\n" - "ld1w z5.s, p0/z, [%[inptr5], %[inpos], LSL #2]\n" - "ld1w z6.s, p0/z, [%[inptr6], %[inpos], LSL #2]\n" - "ld1w z7.s, p0/z, [%[inptr7], %[inpos], LSL #2]\n" - "incw %[inpos], all, mul #1\n" - "zip1 z8.s, z0.s, z4.s\n" - "whilelt p0.s, %[outpos], %[outwidth]\n" - "zip2 z9.s, z0.s, z4.s\n" - "incw %[outpos], all, mul #1\n" - "zip1 z10.s, z1.s, z5.s\n" - "zip2 z11.s, z1.s, z5.s\n" - "zip1 z12.s, z2.s, z6.s\n" - "whilelt p1.s, %[outpos], %[outwidth]\n" - "zip2 z13.s, z2.s, z6.s\n" - "incw %[outpos], all, mul #1\n" - "zip1 z14.s, z3.s, z7.s\n" - "zip2 z15.s, z3.s, z7.s\n" - "zip1 z0.s, z8.s, z12.s\n" - "whilelt p2.s, %[outpos], %[outwidth]\n" - "zip2 z1.s, z8.s, z12.s\n" - "incw %[outpos], all, mul #1\n" - "zip1 z2.s, z9.s, z13.s\n" - "zip2 z3.s, z9.s, z13.s\n" - "zip1 z4.s, z10.s, z14.s\n" - "whilelt p3.s, %[outpos], %[outwidth]\n" - "zip2 z5.s, z10.s, z14.s\n" - "incw %[outpos], all, mul #1\n" - "zip1 z6.s, z11.s, z15.s\n" - "zip2 z7.s, z11.s, z15.s\n" - "zip1 z8.s, z0.s, z4.s\n" - "whilelt p4.s, %[outpos], %[outwidth]\n" - "zip2 z9.s, z0.s, z4.s\n" - "incw %[outpos], all, mul #1\n" - "zip1 z10.s, z1.s, z5.s\n" - "st1w z8.s, p0, [%[outptr]]\n" - "zip2 z11.s, z1.s, z5.s\n" - "zip1 z12.s, z2.s, z6.s\n" - "whilelt p5.s, %[outpos], %[outwidth]\n" - "zip2 z13.s, z2.s, z6.s\n" - "st1w z9.s, p1, [%[outptr], #1, MUL VL]\n" - "zip1 z14.s, z3.s, z7.s\n" - "incw %[outpos], all, mul #1\n" - "zip2 z15.s, z3.s, z7.s\n" - "st1w z10.s, p2, [%[outptr], #2, MUL VL]\n" - "whilelt p6.s, %[outpos], %[outwidth]\n" - "st1w z11.s, p3, [%[outptr], #3, MUL VL]\n" - "incw %[outpos], all, mul #1\n" - "st1w z12.s, p4, [%[outptr], #4, MUL VL]\n" - "whilelt p7.s, %[outpos], %[outwidth]\n" - "incw %[outpos], all, mul #1\n" - "st1w z13.s, p5, [%[outptr], #5, MUL VL]\n" - "st1w z14.s, p6, [%[outptr], #6, MUL VL]\n" - "st1w z15.s, p7, [%[outptr], #7, MUL VL]\n" - "addvl %[outptr], %[outptr], #8\n" - "b 1b\n" - "2:\n" - : [inpos] "+r" (inpos), [outpos] "+r" (outpos), [outptr] "+r" (outptr), [inptr0] "+r" (inptr0), [inptr1] "+r" (inptr1), [inptr2] "+r" (inptr2), [inptr3] "+r" (inptr3), [inptr4] "+r" (inptr4), [inptr5] "+r" (inptr5), [inptr6] "+r" (inptr6), [inptr7] "+r" (inptr7) - : [outwidth] "r" (outwidth), [inwidth] "r" (inwidth) - : "p0", "p1", "p2", "p3", "p4", "p5", "p6", "p7", "z0", "z1", "z2", "z3", "z4", "z5", "z6", "z7", "z8", "z9", "z10", "z11", "z12", "z13", "z14", "z15", "cc", "memory" - ); - break; - - - } - } -} - -#endif // __ARM_FEATURE_SVE diff --git a/src/core/NEON/kernels/arm_gemm/transforms/sve_interleave_8way_block2_16bit.hpp b/src/core/NEON/kernels/arm_gemm/transforms/sve_interleave_8way_block2_16bit.hpp deleted file mode 100644 index 234433a0f1..0000000000 --- a/src/core/NEON/kernels/arm_gemm/transforms/sve_interleave_8way_block2_16bit.hpp +++ /dev/null @@ -1,596 +0,0 @@ -/* - * Copyright (c) 2019 Arm Limited. - * - * SPDX-License-Identifier: MIT - * - * Permission is hereby granted, free of charge, to any person obtaining a copy - * of this software and associated documentation files (the "Software"), to - * deal in the Software without restriction, including without limitation the - * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or - * sell copies of the Software, and to permit persons to whom the Software is - * furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in all - * copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE - * SOFTWARE. - */ -#pragma once - -#ifdef __ARM_FEATURE_SVE - -template<> -template<typename T> -inline void TransformImpl<8, 2, false, 2, 2, false>::Transform(T *out, const T *in, int ldin, int y0, int ymax, int k0, int kmax) -{ - uint16_t *master_outptr = reinterpret_cast<uint16_t *>(out); - const uint16_t *inptr = reinterpret_cast<const uint16_t *>(in); - - for (int y=y0; y<ymax; y+=8) - { - const int height = ymax-y; - const long inwidth = (kmax - k0); - const long outwidth = ((inwidth + 1) / 2) * 16; - long inpos = 0; - long outpos = 0; - - uint16_t *outptr = master_outptr; - master_outptr += outwidth; - - const uint16_t *inptr0 = inptr + y * ldin + k0; - const uint16_t *inptr1 = inptr0 + ldin; - const uint16_t *inptr2 = inptr1 + ldin; - const uint16_t *inptr3 = inptr2 + ldin; - const uint16_t *inptr4 = inptr3 + ldin; - const uint16_t *inptr5 = inptr4 + ldin; - const uint16_t *inptr6 = inptr5 + ldin; - const uint16_t *inptr7 = inptr6 + ldin; - - switch(height) - { - case 1: - __asm __volatile( - "1:\n" - "whilelt p0.h, %[inpos], %[inwidth]\n" - "b.none 2f\n" - "mov z4.h, #0\n" - "ld1h z0.h, p0/z, [%[inptr0], %[inpos], LSL #1]\n" - "inch %[inpos], all, mul #1\n" - "whilelt p0.h, %[outpos], %[outwidth]\n" - "inch %[outpos], all, mul #1\n" - "zip1 z8.s, z0.s, z4.s\n" - "zip2 z9.s, z0.s, z4.s\n" - "whilelt p1.h, %[outpos], %[outwidth]\n" - "zip1 z0.s, z8.s, z4.s\n" - "inch %[outpos], all, mul #1\n" - "zip2 z1.s, z8.s, z4.s\n" - "zip1 z2.s, z9.s, z4.s\n" - "zip2 z3.s, z9.s, z4.s\n" - "whilelt p2.h, %[outpos], %[outwidth]\n" - "zip1 z8.s, z0.s, z4.s\n" - "inch %[outpos], all, mul #1\n" - "zip2 z9.s, z0.s, z4.s\n" - "zip1 z10.s, z1.s, z4.s\n" - "st1h z8.h, p0, [%[outptr]]\n" - "zip2 z11.s, z1.s, z4.s\n" - "whilelt p3.h, %[outpos], %[outwidth]\n" - "zip1 z12.s, z2.s, z4.s\n" - "st1h z9.h, p1, [%[outptr], #1, MUL VL]\n" - "zip2 z13.s, z2.s, z4.s\n" - "inch %[outpos], all, mul #1\n" - "zip1 z14.s, z3.s, z4.s\n" - "st1h z10.h, p2, [%[outptr], #2, MUL VL]\n" - "zip2 z15.s, z3.s, z4.s\n" - "whilelt p4.h, %[outpos], %[outwidth]\n" - "st1h z11.h, p3, [%[outptr], #3, MUL VL]\n" - "inch %[outpos], all, mul #1\n" - "st1h z12.h, p4, [%[outptr], #4, MUL VL]\n" - "whilelt p5.h, %[outpos], %[outwidth]\n" - "inch %[outpos], all, mul #1\n" - "st1h z13.h, p5, [%[outptr], #5, MUL VL]\n" - "whilelt p6.h, %[outpos], %[outwidth]\n" - "inch %[outpos], all, mul #1\n" - "st1h z14.h, p6, [%[outptr], #6, MUL VL]\n" - "whilelt p7.h, %[outpos], %[outwidth]\n" - "inch %[outpos], all, mul #1\n" - "st1h z15.h, p7, [%[outptr], #7, MUL VL]\n" - "addvl %[outptr], %[outptr], #8\n" - "b 1b\n" - "2:\n" - : [inpos] "+r" (inpos), [outpos] "+r" (outpos), [outptr] "+r" (outptr), [inptr0] "+r" (inptr0) - : [outwidth] "r" (outwidth), [inwidth] "r" (inwidth) - : "p0", "p1", "p2", "p3", "p4", "p5", "p6", "p7", "z0", "z1", "z2", "z3", "z4", "z5", "z6", "z7", "z8", "z9", "z10", "z11", "z12", "z13", "z14", "z15", "cc", "memory" - ); - break; - - case 2: - __asm __volatile( - "1:\n" - "whilelt p0.h, %[inpos], %[inwidth]\n" - "b.none 2f\n" - "mov z4.h, #0\n" - "mov z14.h, #0\n" - "ld1h z0.h, p0/z, [%[inptr0], %[inpos], LSL #1]\n" - "ld1h z1.h, p0/z, [%[inptr1], %[inpos], LSL #1]\n" - "inch %[inpos], all, mul #1\n" - "whilelt p0.h, %[outpos], %[outwidth]\n" - "zip1 z8.s, z0.s, z4.s\n" - "inch %[outpos], all, mul #1\n" - "zip2 z9.s, z0.s, z4.s\n" - "zip1 z10.s, z1.s, z4.s\n" - "zip2 z11.s, z1.s, z4.s\n" - "whilelt p1.h, %[outpos], %[outwidth]\n" - "zip1 z0.s, z8.s, z4.s\n" - "inch %[outpos], all, mul #1\n" - "zip2 z1.s, z8.s, z4.s\n" - "zip1 z2.s, z9.s, z4.s\n" - "zip2 z3.s, z9.s, z4.s\n" - "whilelt p2.h, %[outpos], %[outwidth]\n" - "zip1 z4.s, z10.s, z14.s\n" - "inch %[outpos], all, mul #1\n" - "zip2 z5.s, z10.s, z14.s\n" - "zip1 z6.s, z11.s, z14.s\n" - "zip2 z7.s, z11.s, z14.s\n" - "whilelt p3.h, %[outpos], %[outwidth]\n" - "zip1 z8.s, z0.s, z4.s\n" - "inch %[outpos], all, mul #1\n" - "zip2 z9.s, z0.s, z4.s\n" - "zip1 z10.s, z1.s, z5.s\n" - "st1h z8.h, p0, [%[outptr]]\n" - "zip2 z11.s, z1.s, z5.s\n" - "whilelt p4.h, %[outpos], %[outwidth]\n" - "zip1 z12.s, z2.s, z6.s\n" - "st1h z9.h, p1, [%[outptr], #1, MUL VL]\n" - "zip2 z13.s, z2.s, z6.s\n" - "inch %[outpos], all, mul #1\n" - "zip1 z14.s, z3.s, z7.s\n" - "st1h z10.h, p2, [%[outptr], #2, MUL VL]\n" - "zip2 z15.s, z3.s, z7.s\n" - "whilelt p5.h, %[outpos], %[outwidth]\n" - "st1h z11.h, p3, [%[outptr], #3, MUL VL]\n" - "inch %[outpos], all, mul #1\n" - "st1h z12.h, p4, [%[outptr], #4, MUL VL]\n" - "whilelt p6.h, %[outpos], %[outwidth]\n" - "inch %[outpos], all, mul #1\n" - "st1h z13.h, p5, [%[outptr], #5, MUL VL]\n" - "whilelt p7.h, %[outpos], %[outwidth]\n" - "inch %[outpos], all, mul #1\n" - "st1h z14.h, p6, [%[outptr], #6, MUL VL]\n" - "st1h z15.h, p7, [%[outptr], #7, MUL VL]\n" - "addvl %[outptr], %[outptr], #8\n" - "b 1b\n" - "2:\n" - : [inpos] "+r" (inpos), [outpos] "+r" (outpos), [outptr] "+r" (outptr), [inptr0] "+r" (inptr0), [inptr1] "+r" (inptr1) - : [outwidth] "r" (outwidth), [inwidth] "r" (inwidth) - : "p0", "p1", "p2", "p3", "p4", "p5", "p6", "p7", "z0", "z1", "z2", "z3", "z4", "z5", "z6", "z7", "z8", "z9", "z10", "z11", "z12", "z13", "z14", "z15", "cc", "memory" - ); - break; - - case 3: - __asm __volatile( - "1:\n" - "whilelt p0.h, %[inpos], %[inwidth]\n" - "b.none 2f\n" - "mov z4.h, #0\n" - "mov z14.h, #0\n" - "ld1h z0.h, p0/z, [%[inptr0], %[inpos], LSL #1]\n" - "ld1h z1.h, p0/z, [%[inptr1], %[inpos], LSL #1]\n" - "ld1h z2.h, p0/z, [%[inptr2], %[inpos], LSL #1]\n" - "inch %[inpos], all, mul #1\n" - "zip1 z8.s, z0.s, z4.s\n" - "whilelt p0.h, %[outpos], %[outwidth]\n" - "zip2 z9.s, z0.s, z4.s\n" - "inch %[outpos], all, mul #1\n" - "zip1 z10.s, z1.s, z4.s\n" - "zip2 z11.s, z1.s, z4.s\n" - "zip1 z12.s, z2.s, z4.s\n" - "whilelt p1.h, %[outpos], %[outwidth]\n" - "zip2 z13.s, z2.s, z4.s\n" - "inch %[outpos], all, mul #1\n" - "zip1 z0.s, z8.s, z12.s\n" - "zip2 z1.s, z8.s, z12.s\n" - "zip1 z2.s, z9.s, z13.s\n" - "whilelt p2.h, %[outpos], %[outwidth]\n" - "zip2 z3.s, z9.s, z13.s\n" - "inch %[outpos], all, mul #1\n" - "zip1 z4.s, z10.s, z14.s\n" - "zip2 z5.s, z10.s, z14.s\n" - "zip1 z6.s, z11.s, z14.s\n" - "whilelt p3.h, %[outpos], %[outwidth]\n" - "zip2 z7.s, z11.s, z14.s\n" - "inch %[outpos], all, mul #1\n" - "zip1 z8.s, z0.s, z4.s\n" - "zip2 z9.s, z0.s, z4.s\n" - "zip1 z10.s, z1.s, z5.s\n" - "whilelt p4.h, %[outpos], %[outwidth]\n" - "zip2 z11.s, z1.s, z5.s\n" - "st1h z8.h, p0, [%[outptr]]\n" - "zip1 z12.s, z2.s, z6.s\n" - "inch %[outpos], all, mul #1\n" - "zip2 z13.s, z2.s, z6.s\n" - "st1h z9.h, p1, [%[outptr], #1, MUL VL]\n" - "zip1 z14.s, z3.s, z7.s\n" - "zip2 z15.s, z3.s, z7.s\n" - "whilelt p5.h, %[outpos], %[outwidth]\n" - "st1h z10.h, p2, [%[outptr], #2, MUL VL]\n" - "inch %[outpos], all, mul #1\n" - "st1h z11.h, p3, [%[outptr], #3, MUL VL]\n" - "whilelt p6.h, %[outpos], %[outwidth]\n" - "inch %[outpos], all, mul #1\n" - "st1h z12.h, p4, [%[outptr], #4, MUL VL]\n" - "whilelt p7.h, %[outpos], %[outwidth]\n" - "inch %[outpos], all, mul #1\n" - "st1h z13.h, p5, [%[outptr], #5, MUL VL]\n" - "st1h z14.h, p6, [%[outptr], #6, MUL VL]\n" - "st1h z15.h, p7, [%[outptr], #7, MUL VL]\n" - "addvl %[outptr], %[outptr], #8\n" - "b 1b\n" - "2:\n" - : [inpos] "+r" (inpos), [outpos] "+r" (outpos), [outptr] "+r" (outptr), [inptr0] "+r" (inptr0), [inptr1] "+r" (inptr1), [inptr2] "+r" (inptr2) - : [outwidth] "r" (outwidth), [inwidth] "r" (inwidth) - : "p0", "p1", "p2", "p3", "p4", "p5", "p6", "p7", "z0", "z1", "z2", "z3", "z4", "z5", "z6", "z7", "z8", "z9", "z10", "z11", "z12", "z13", "z14", "z15", "cc", "memory" - ); - break; - - case 4: - __asm __volatile( - "1:\n" - "whilelt p0.h, %[inpos], %[inwidth]\n" - "b.none 2f\n" - "mov z4.h, #0\n" - "ld1h z0.h, p0/z, [%[inptr0], %[inpos], LSL #1]\n" - "ld1h z1.h, p0/z, [%[inptr1], %[inpos], LSL #1]\n" - "ld1h z2.h, p0/z, [%[inptr2], %[inpos], LSL #1]\n" - "ld1h z3.h, p0/z, [%[inptr3], %[inpos], LSL #1]\n" - "inch %[inpos], all, mul #1\n" - "zip1 z8.s, z0.s, z4.s\n" - "whilelt p0.h, %[outpos], %[outwidth]\n" - "zip2 z9.s, z0.s, z4.s\n" - "inch %[outpos], all, mul #1\n" - "zip1 z10.s, z1.s, z4.s\n" - "zip2 z11.s, z1.s, z4.s\n" - "zip1 z12.s, z2.s, z4.s\n" - "whilelt p1.h, %[outpos], %[outwidth]\n" - "zip2 z13.s, z2.s, z4.s\n" - "inch %[outpos], all, mul #1\n" - "zip1 z14.s, z3.s, z4.s\n" - "zip2 z15.s, z3.s, z4.s\n" - "zip1 z0.s, z8.s, z12.s\n" - "whilelt p2.h, %[outpos], %[outwidth]\n" - "zip2 z1.s, z8.s, z12.s\n" - "inch %[outpos], all, mul #1\n" - "zip1 z2.s, z9.s, z13.s\n" - "zip2 z3.s, z9.s, z13.s\n" - "zip1 z4.s, z10.s, z14.s\n" - "whilelt p3.h, %[outpos], %[outwidth]\n" - "zip2 z5.s, z10.s, z14.s\n" - "inch %[outpos], all, mul #1\n" - "zip1 z6.s, z11.s, z15.s\n" - "zip2 z7.s, z11.s, z15.s\n" - "zip1 z8.s, z0.s, z4.s\n" - "whilelt p4.h, %[outpos], %[outwidth]\n" - "zip2 z9.s, z0.s, z4.s\n" - "inch %[outpos], all, mul #1\n" - "zip1 z10.s, z1.s, z5.s\n" - "st1h z8.h, p0, [%[outptr]]\n" - "zip2 z11.s, z1.s, z5.s\n" - "zip1 z12.s, z2.s, z6.s\n" - "whilelt p5.h, %[outpos], %[outwidth]\n" - "zip2 z13.s, z2.s, z6.s\n" - "st1h z9.h, p1, [%[outptr], #1, MUL VL]\n" - "zip1 z14.s, z3.s, z7.s\n" - "inch %[outpos], all, mul #1\n" - "zip2 z15.s, z3.s, z7.s\n" - "st1h z10.h, p2, [%[outptr], #2, MUL VL]\n" - "whilelt p6.h, %[outpos], %[outwidth]\n" - "st1h z11.h, p3, [%[outptr], #3, MUL VL]\n" - "inch %[outpos], all, mul #1\n" - "st1h z12.h, p4, [%[outptr], #4, MUL VL]\n" - "whilelt p7.h, %[outpos], %[outwidth]\n" - "inch %[outpos], all, mul #1\n" - "st1h z13.h, p5, [%[outptr], #5, MUL VL]\n" - "st1h z14.h, p6, [%[outptr], #6, MUL VL]\n" - "st1h z15.h, p7, [%[outptr], #7, MUL VL]\n" - "addvl %[outptr], %[outptr], #8\n" - "b 1b\n" - "2:\n" - : [inpos] "+r" (inpos), [outpos] "+r" (outpos), [outptr] "+r" (outptr), [inptr0] "+r" (inptr0), [inptr1] "+r" (inptr1), [inptr2] "+r" (inptr2), [inptr3] "+r" (inptr3) - : [outwidth] "r" (outwidth), [inwidth] "r" (inwidth) - : "p0", "p1", "p2", "p3", "p4", "p5", "p6", "p7", "z0", "z1", "z2", "z3", "z4", "z5", "z6", "z7", "z8", "z9", "z10", "z11", "z12", "z13", "z14", "z15", "cc", "memory" - ); - break; - - case 5: - __asm __volatile( - "1:\n" - "whilelt p0.h, %[inpos], %[inwidth]\n" - "b.none 2f\n" - "mov z5.h, #0\n" - "ld1h z0.h, p0/z, [%[inptr0], %[inpos], LSL #1]\n" - "ld1h z1.h, p0/z, [%[inptr1], %[inpos], LSL #1]\n" - "ld1h z2.h, p0/z, [%[inptr2], %[inpos], LSL #1]\n" - "ld1h z3.h, p0/z, [%[inptr3], %[inpos], LSL #1]\n" - "ld1h z4.h, p0/z, [%[inptr4], %[inpos], LSL #1]\n" - "inch %[inpos], all, mul #1\n" - "zip1 z10.s, z1.s, z5.s\n" - "whilelt p0.h, %[outpos], %[outwidth]\n" - "zip1 z8.s, z0.s, z4.s\n" - "inch %[outpos], all, mul #1\n" - "zip2 z9.s, z0.s, z4.s\n" - "zip2 z11.s, z1.s, z5.s\n" - "zip1 z12.s, z2.s, z5.s\n" - "whilelt p1.h, %[outpos], %[outwidth]\n" - "zip2 z13.s, z2.s, z5.s\n" - "inch %[outpos], all, mul #1\n" - "zip1 z14.s, z3.s, z5.s\n" - "zip2 z15.s, z3.s, z5.s\n" - "zip1 z0.s, z8.s, z12.s\n" - "whilelt p2.h, %[outpos], %[outwidth]\n" - "zip2 z1.s, z8.s, z12.s\n" - "inch %[outpos], all, mul #1\n" - "zip1 z2.s, z9.s, z13.s\n" - "zip2 z3.s, z9.s, z13.s\n" - "zip1 z4.s, z10.s, z14.s\n" - "whilelt p3.h, %[outpos], %[outwidth]\n" - "zip2 z5.s, z10.s, z14.s\n" - "inch %[outpos], all, mul #1\n" - "zip1 z6.s, z11.s, z15.s\n" - "zip2 z7.s, z11.s, z15.s\n" - "zip1 z8.s, z0.s, z4.s\n" - "whilelt p4.h, %[outpos], %[outwidth]\n" - "zip2 z9.s, z0.s, z4.s\n" - "inch %[outpos], all, mul #1\n" - "zip1 z10.s, z1.s, z5.s\n" - "st1h z8.h, p0, [%[outptr]]\n" - "zip2 z11.s, z1.s, z5.s\n" - "zip1 z12.s, z2.s, z6.s\n" - "whilelt p5.h, %[outpos], %[outwidth]\n" - "zip2 z13.s, z2.s, z6.s\n" - "st1h z9.h, p1, [%[outptr], #1, MUL VL]\n" - "zip1 z14.s, z3.s, z7.s\n" - "inch %[outpos], all, mul #1\n" - "zip2 z15.s, z3.s, z7.s\n" - "st1h z10.h, p2, [%[outptr], #2, MUL VL]\n" - "whilelt p6.h, %[outpos], %[outwidth]\n" - "st1h z11.h, p3, [%[outptr], #3, MUL VL]\n" - "inch %[outpos], all, mul #1\n" - "st1h z12.h, p4, [%[outptr], #4, MUL VL]\n" - "whilelt p7.h, %[outpos], %[outwidth]\n" - "inch %[outpos], all, mul #1\n" - "st1h z13.h, p5, [%[outptr], #5, MUL VL]\n" - "st1h z14.h, p6, [%[outptr], #6, MUL VL]\n" - "st1h z15.h, p7, [%[outptr], #7, MUL VL]\n" - "addvl %[outptr], %[outptr], #8\n" - "b 1b\n" - "2:\n" - : [inpos] "+r" (inpos), [outpos] "+r" (outpos), [outptr] "+r" (outptr), [inptr0] "+r" (inptr0), [inptr1] "+r" (inptr1), [inptr2] "+r" (inptr2), [inptr3] "+r" (inptr3), [inptr4] "+r" (inptr4) - : [outwidth] "r" (outwidth), [inwidth] "r" (inwidth) - : "p0", "p1", "p2", "p3", "p4", "p5", "p6", "p7", "z0", "z1", "z2", "z3", "z4", "z5", "z6", "z7", "z8", "z9", "z10", "z11", "z12", "z13", "z14", "z15", "cc", "memory" - ); - break; - - case 6: - __asm __volatile( - "1:\n" - "whilelt p0.h, %[inpos], %[inwidth]\n" - "b.none 2f\n" - "mov z6.h, #0\n" - "ld1h z0.h, p0/z, [%[inptr0], %[inpos], LSL #1]\n" - "ld1h z1.h, p0/z, [%[inptr1], %[inpos], LSL #1]\n" - "ld1h z2.h, p0/z, [%[inptr2], %[inpos], LSL #1]\n" - "ld1h z3.h, p0/z, [%[inptr3], %[inpos], LSL #1]\n" - "ld1h z4.h, p0/z, [%[inptr4], %[inpos], LSL #1]\n" - "ld1h z5.h, p0/z, [%[inptr5], %[inpos], LSL #1]\n" - "inch %[inpos], all, mul #1\n" - "zip1 z12.s, z2.s, z6.s\n" - "whilelt p0.h, %[outpos], %[outwidth]\n" - "zip1 z8.s, z0.s, z4.s\n" - "inch %[outpos], all, mul #1\n" - "zip2 z9.s, z0.s, z4.s\n" - "zip1 z10.s, z1.s, z5.s\n" - "zip2 z11.s, z1.s, z5.s\n" - "whilelt p1.h, %[outpos], %[outwidth]\n" - "zip2 z13.s, z2.s, z6.s\n" - "inch %[outpos], all, mul #1\n" - "zip1 z14.s, z3.s, z6.s\n" - "zip2 z15.s, z3.s, z6.s\n" - "zip1 z0.s, z8.s, z12.s\n" - "whilelt p2.h, %[outpos], %[outwidth]\n" - "zip2 z1.s, z8.s, z12.s\n" - "inch %[outpos], all, mul #1\n" - "zip1 z2.s, z9.s, z13.s\n" - "zip2 z3.s, z9.s, z13.s\n" - "zip1 z4.s, z10.s, z14.s\n" - "whilelt p3.h, %[outpos], %[outwidth]\n" - "zip2 z5.s, z10.s, z14.s\n" - "inch %[outpos], all, mul #1\n" - "zip1 z6.s, z11.s, z15.s\n" - "zip2 z7.s, z11.s, z15.s\n" - "zip1 z8.s, z0.s, z4.s\n" - "whilelt p4.h, %[outpos], %[outwidth]\n" - "zip2 z9.s, z0.s, z4.s\n" - "inch %[outpos], all, mul #1\n" - "zip1 z10.s, z1.s, z5.s\n" - "st1h z8.h, p0, [%[outptr]]\n" - "zip2 z11.s, z1.s, z5.s\n" - "zip1 z12.s, z2.s, z6.s\n" - "whilelt p5.h, %[outpos], %[outwidth]\n" - "zip2 z13.s, z2.s, z6.s\n" - "st1h z9.h, p1, [%[outptr], #1, MUL VL]\n" - "zip1 z14.s, z3.s, z7.s\n" - "inch %[outpos], all, mul #1\n" - "zip2 z15.s, z3.s, z7.s\n" - "st1h z10.h, p2, [%[outptr], #2, MUL VL]\n" - "whilelt p6.h, %[outpos], %[outwidth]\n" - "st1h z11.h, p3, [%[outptr], #3, MUL VL]\n" - "inch %[outpos], all, mul #1\n" - "st1h z12.h, p4, [%[outptr], #4, MUL VL]\n" - "whilelt p7.h, %[outpos], %[outwidth]\n" - "inch %[outpos], all, mul #1\n" - "st1h z13.h, p5, [%[outptr], #5, MUL VL]\n" - "st1h z14.h, p6, [%[outptr], #6, MUL VL]\n" - "st1h z15.h, p7, [%[outptr], #7, MUL VL]\n" - "addvl %[outptr], %[outptr], #8\n" - "b 1b\n" - "2:\n" - : [inpos] "+r" (inpos), [outpos] "+r" (outpos), [outptr] "+r" (outptr), [inptr0] "+r" (inptr0), [inptr1] "+r" (inptr1), [inptr2] "+r" (inptr2), [inptr3] "+r" (inptr3), [inptr4] "+r" (inptr4), [inptr5] "+r" (inptr5) - : [outwidth] "r" (outwidth), [inwidth] "r" (inwidth) - : "p0", "p1", "p2", "p3", "p4", "p5", "p6", "p7", "z0", "z1", "z2", "z3", "z4", "z5", "z6", "z7", "z8", "z9", "z10", "z11", "z12", "z13", "z14", "z15", "cc", "memory" - ); - break; - - case 7: - __asm __volatile( - "1:\n" - "whilelt p0.h, %[inpos], %[inwidth]\n" - "b.none 2f\n" - "mov z7.h, #0\n" - "ld1h z0.h, p0/z, [%[inptr0], %[inpos], LSL #1]\n" - "ld1h z1.h, p0/z, [%[inptr1], %[inpos], LSL #1]\n" - "ld1h z2.h, p0/z, [%[inptr2], %[inpos], LSL #1]\n" - "ld1h z3.h, p0/z, [%[inptr3], %[inpos], LSL #1]\n" - "ld1h z4.h, p0/z, [%[inptr4], %[inpos], LSL #1]\n" - "ld1h z5.h, p0/z, [%[inptr5], %[inpos], LSL #1]\n" - "ld1h z6.h, p0/z, [%[inptr6], %[inpos], LSL #1]\n" - "inch %[inpos], all, mul #1\n" - "zip1 z8.s, z0.s, z4.s\n" - "whilelt p0.h, %[outpos], %[outwidth]\n" - "zip2 z9.s, z0.s, z4.s\n" - "inch %[outpos], all, mul #1\n" - "zip1 z10.s, z1.s, z5.s\n" - "zip2 z11.s, z1.s, z5.s\n" - "zip1 z12.s, z2.s, z6.s\n" - "whilelt p1.h, %[outpos], %[outwidth]\n" - "zip2 z13.s, z2.s, z6.s\n" - "inch %[outpos], all, mul #1\n" - "zip1 z14.s, z3.s, z7.s\n" - "zip2 z15.s, z3.s, z7.s\n" - "zip1 z0.s, z8.s, z12.s\n" - "whilelt p2.h, %[outpos], %[outwidth]\n" - "zip2 z1.s, z8.s, z12.s\n" - "inch %[outpos], all, mul #1\n" - "zip1 z2.s, z9.s, z13.s\n" - "zip2 z3.s, z9.s, z13.s\n" - "zip1 z4.s, z10.s, z14.s\n" - "whilelt p3.h, %[outpos], %[outwidth]\n" - "zip2 z5.s, z10.s, z14.s\n" - "inch %[outpos], all, mul #1\n" - "zip1 z6.s, z11.s, z15.s\n" - "zip2 z7.s, z11.s, z15.s\n" - "zip1 z8.s, z0.s, z4.s\n" - "whilelt p4.h, %[outpos], %[outwidth]\n" - "zip2 z9.s, z0.s, z4.s\n" - "inch %[outpos], all, mul #1\n" - "zip1 z10.s, z1.s, z5.s\n" - "st1h z8.h, p0, [%[outptr]]\n" - "zip2 z11.s, z1.s, z5.s\n" - "zip1 z12.s, z2.s, z6.s\n" - "whilelt p5.h, %[outpos], %[outwidth]\n" - "zip2 z13.s, z2.s, z6.s\n" - "st1h z9.h, p1, [%[outptr], #1, MUL VL]\n" - "zip1 z14.s, z3.s, z7.s\n" - "inch %[outpos], all, mul #1\n" - "zip2 z15.s, z3.s, z7.s\n" - "st1h z10.h, p2, [%[outptr], #2, MUL VL]\n" - "whilelt p6.h, %[outpos], %[outwidth]\n" - "st1h z11.h, p3, [%[outptr], #3, MUL VL]\n" - "inch %[outpos], all, mul #1\n" - "st1h z12.h, p4, [%[outptr], #4, MUL VL]\n" - "whilelt p7.h, %[outpos], %[outwidth]\n" - "inch %[outpos], all, mul #1\n" - "st1h z13.h, p5, [%[outptr], #5, MUL VL]\n" - "st1h z14.h, p6, [%[outptr], #6, MUL VL]\n" - "st1h z15.h, p7, [%[outptr], #7, MUL VL]\n" - "addvl %[outptr], %[outptr], #8\n" - "b 1b\n" - "2:\n" - : [inpos] "+r" (inpos), [outpos] "+r" (outpos), [outptr] "+r" (outptr), [inptr0] "+r" (inptr0), [inptr1] "+r" (inptr1), [inptr2] "+r" (inptr2), [inptr3] "+r" (inptr3), [inptr4] "+r" (inptr4), [inptr5] "+r" (inptr5), [inptr6] "+r" (inptr6) - : [outwidth] "r" (outwidth), [inwidth] "r" (inwidth) - : "p0", "p1", "p2", "p3", "p4", "p5", "p6", "p7", "z0", "z1", "z2", "z3", "z4", "z5", "z6", "z7", "z8", "z9", "z10", "z11", "z12", "z13", "z14", "z15", "cc", "memory" - ); - break; - - default: - case 8: - __asm __volatile( - "1:\n" - "whilelt p0.h, %[inpos], %[inwidth]\n" - "b.none 2f\n" - "ld1h z0.h, p0/z, [%[inptr0], %[inpos], LSL #1]\n" - "ld1h z1.h, p0/z, [%[inptr1], %[inpos], LSL #1]\n" - "ld1h z2.h, p0/z, [%[inptr2], %[inpos], LSL #1]\n" - "ld1h z3.h, p0/z, [%[inptr3], %[inpos], LSL #1]\n" - "ld1h z4.h, p0/z, [%[inptr4], %[inpos], LSL #1]\n" - "ld1h z5.h, p0/z, [%[inptr5], %[inpos], LSL #1]\n" - "ld1h z6.h, p0/z, [%[inptr6], %[inpos], LSL #1]\n" - "ld1h z7.h, p0/z, [%[inptr7], %[inpos], LSL #1]\n" - "inch %[inpos], all, mul #1\n" - "zip1 z8.s, z0.s, z4.s\n" - "whilelt p0.h, %[outpos], %[outwidth]\n" - "zip2 z9.s, z0.s, z4.s\n" - "inch %[outpos], all, mul #1\n" - "zip1 z10.s, z1.s, z5.s\n" - "zip2 z11.s, z1.s, z5.s\n" - "zip1 z12.s, z2.s, z6.s\n" - "whilelt p1.h, %[outpos], %[outwidth]\n" - "zip2 z13.s, z2.s, z6.s\n" - "inch %[outpos], all, mul #1\n" - "zip1 z14.s, z3.s, z7.s\n" - "zip2 z15.s, z3.s, z7.s\n" - "zip1 z0.s, z8.s, z12.s\n" - "whilelt p2.h, %[outpos], %[outwidth]\n" - "zip2 z1.s, z8.s, z12.s\n" - "inch %[outpos], all, mul #1\n" - "zip1 z2.s, z9.s, z13.s\n" - "zip2 z3.s, z9.s, z13.s\n" - "zip1 z4.s, z10.s, z14.s\n" - "whilelt p3.h, %[outpos], %[outwidth]\n" - "zip2 z5.s, z10.s, z14.s\n" - "inch %[outpos], all, mul #1\n" - "zip1 z6.s, z11.s, z15.s\n" - "zip2 z7.s, z11.s, z15.s\n" - "zip1 z8.s, z0.s, z4.s\n" - "whilelt p4.h, %[outpos], %[outwidth]\n" - "zip2 z9.s, z0.s, z4.s\n" - "inch %[outpos], all, mul #1\n" - "zip1 z10.s, z1.s, z5.s\n" - "st1h z8.h, p0, [%[outptr]]\n" - "zip2 z11.s, z1.s, z5.s\n" - "zip1 z12.s, z2.s, z6.s\n" - "whilelt p5.h, %[outpos], %[outwidth]\n" - "zip2 z13.s, z2.s, z6.s\n" - "st1h z9.h, p1, [%[outptr], #1, MUL VL]\n" - "zip1 z14.s, z3.s, z7.s\n" - "inch %[outpos], all, mul #1\n" - "zip2 z15.s, z3.s, z7.s\n" - "st1h z10.h, p2, [%[outptr], #2, MUL VL]\n" - "whilelt p6.h, %[outpos], %[outwidth]\n" - "st1h z11.h, p3, [%[outptr], #3, MUL VL]\n" - "inch %[outpos], all, mul #1\n" - "st1h z12.h, p4, [%[outptr], #4, MUL VL]\n" - "whilelt p7.h, %[outpos], %[outwidth]\n" - "inch %[outpos], all, mul #1\n" - "st1h z13.h, p5, [%[outptr], #5, MUL VL]\n" - "st1h z14.h, p6, [%[outptr], #6, MUL VL]\n" - "st1h z15.h, p7, [%[outptr], #7, MUL VL]\n" - "addvl %[outptr], %[outptr], #8\n" - "b 1b\n" - "2:\n" - : [inpos] "+r" (inpos), [outpos] "+r" (outpos), [outptr] "+r" (outptr), [inptr0] "+r" (inptr0), [inptr1] "+r" (inptr1), [inptr2] "+r" (inptr2), [inptr3] "+r" (inptr3), [inptr4] "+r" (inptr4), [inptr5] "+r" (inptr5), [inptr6] "+r" (inptr6), [inptr7] "+r" (inptr7) - : [outwidth] "r" (outwidth), [inwidth] "r" (inwidth) - : "p0", "p1", "p2", "p3", "p4", "p5", "p6", "p7", "z0", "z1", "z2", "z3", "z4", "z5", "z6", "z7", "z8", "z9", "z10", "z11", "z12", "z13", "z14", "z15", "cc", "memory" - ); - break; - - - } - } -} - -#endif // __ARM_FEATURE_SVE diff --git a/src/core/NEON/kernels/arm_gemm/transforms/sve_interleave_8way_block2_32bit.hpp b/src/core/NEON/kernels/arm_gemm/transforms/sve_interleave_8way_block2_32bit.hpp deleted file mode 100644 index 4cc4311cee..0000000000 --- a/src/core/NEON/kernels/arm_gemm/transforms/sve_interleave_8way_block2_32bit.hpp +++ /dev/null @@ -1,632 +0,0 @@ -/* - * Copyright (c) 2018 Arm Limited. - * - * SPDX-License-Identifier: MIT - * - * Permission is hereby granted, free of charge, to any person obtaining a copy - * of this software and associated documentation files (the "Software"), to - * deal in the Software without restriction, including without limitation the - * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or - * sell copies of the Software, and to permit persons to whom the Software is - * furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in all - * copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE - * SOFTWARE. - */ -#pragma once - -#ifdef __ARM_FEATURE_SVE - -template<> -template<typename T> -inline void TransformImpl<8, 2, false, 4, 4, false>::Transform(T *out, const T *in, int ldin, int y0, int ymax, int k0, int kmax) -{ - uint32_t *master_outptr = reinterpret_cast<uint32_t *>(out); - const uint32_t *inptr = reinterpret_cast<const uint32_t *>(in); - - for (int y=y0; y<ymax; y+=8) - { - const int height = ymax-y; - const long inwidth = (kmax - k0); - const long outwidth = (inwidth * 8 + 1) / 2; - long inpos = 0; - long outpos = 0; - - uint32_t *outptr = master_outptr; - master_outptr += (outwidth * 2); - - const uint32_t *inptr0 = inptr + y * ldin + k0; - const uint32_t *inptr1 = inptr0 + ldin; - const uint32_t *inptr2 = inptr1 + ldin; - const uint32_t *inptr3 = inptr2 + ldin; - const uint32_t *inptr4 = inptr3 + ldin; - const uint32_t *inptr5 = inptr4 + ldin; - const uint32_t *inptr6 = inptr5 + ldin; - const uint32_t *inptr7 = inptr6 + ldin; - - switch(height) - { - case 1: - __asm __volatile( - "1:\n" - "whilelt p0.s, %[inpos], %[inwidth]\n" - "b.none 2f\n" - "mov z4.s, #0\n" - "ld1w z0.s, p0/z, [%[inptr0]]\n" - "zip1 z8.d, z0.d, z4.d\n" - "incw %[inpos], all, mul #1\n" - "zip2 z9.d, z0.d, z4.d\n" - "addvl %[inptr0], %[inptr0], #1\n" - "zip1 z0.d, z8.d, z4.d\n" - "whilelt p0.d, %[outpos], %[outwidth]\n" - "zip2 z1.d, z8.d, z4.d\n" - "incd %[outpos], all, mul #1\n" - "zip1 z2.d, z9.d, z4.d\n" - "whilelt p1.d, %[outpos], %[outwidth]\n" - "zip2 z3.d, z9.d, z4.d\n" - "incd %[outpos], all, mul #1\n" - "zip1 z8.d, z0.d, z4.d\n" - "st1d z8.d, p0, [%[outptr]]\n" - "zip2 z9.d, z0.d, z4.d\n" - "st1d z9.d, p1, [%[outptr], #1, MUL VL]\n" - "zip1 z10.d, z1.d, z4.d\n" - "whilelt p2.d, %[outpos], %[outwidth]\n" - "zip2 z11.d, z1.d, z4.d\n" - "st1d z10.d, p2, [%[outptr], #2, MUL VL]\n" - "zip1 z12.d, z2.d, z4.d\n" - "incd %[outpos], all, mul #1\n" - "zip2 z13.d, z2.d, z4.d\n" - "whilelt p3.d, %[outpos], %[outwidth]\n" - "zip1 z14.d, z3.d, z4.d\n" - "st1d z11.d, p3, [%[outptr], #3, MUL VL]\n" - "zip2 z15.d, z3.d, z4.d\n" - "incd %[outpos], all, mul #1\n" - "whilelt p0.d, %[outpos], %[outwidth]\n" - "st1d z12.d, p0, [%[outptr], #4, MUL VL]\n" - "incd %[outpos], all, mul #1\n" - "whilelt p1.d, %[outpos], %[outwidth]\n" - "st1d z13.d, p1, [%[outptr], #5, MUL VL]\n" - "incd %[outpos], all, mul #1\n" - "whilelt p2.d, %[outpos], %[outwidth]\n" - "st1d z14.d, p2, [%[outptr], #6, MUL VL]\n" - "incd %[outpos], all, mul #1\n" - "whilelt p3.d, %[outpos], %[outwidth]\n" - "st1d z15.d, p3, [%[outptr], #7, MUL VL]\n" - "incd %[outpos], all, mul #1\n" - "addvl %[outptr], %[outptr], #8\n" - "b 1b\n" - "2:\n" - : [inpos] "+r" (inpos), [outpos] "+r" (outpos), [outptr] "+r" (outptr), [inptr0] "+r" (inptr0) - : [outwidth] "r" (outwidth), [inwidth] "r" (inwidth) - : "p0", "p1", "p2", "p3", "z0", "z1", "z2", "z3", "z4", "z5", "z6", "z7", "z8", "z9", "z10", "z11", "z12", "z13", "z14", "z15", "cc", "memory" - ); - break; - - case 2: - __asm __volatile( - "1:\n" - "whilelt p0.s, %[inpos], %[inwidth]\n" - "b.none 2f\n" - "mov z4.s, #0\n" - "ld1w z0.s, p0/z, [%[inptr0]]\n" - "zip1 z8.d, z0.d, z4.d\n" - "ld1w z1.s, p0/z, [%[inptr1]]\n" - "zip2 z9.d, z0.d, z4.d\n" - "incw %[inpos], all, mul #1\n" - "zip1 z10.d, z1.d, z4.d\n" - "addvl %[inptr0], %[inptr0], #1\n" - "zip2 z11.d, z1.d, z4.d\n" - "addvl %[inptr1], %[inptr1], #1\n" - "zip1 z0.d, z8.d, z4.d\n" - "whilelt p0.d, %[outpos], %[outwidth]\n" - "zip2 z1.d, z8.d, z4.d\n" - "incd %[outpos], all, mul #1\n" - "zip1 z2.d, z9.d, z4.d\n" - "whilelt p1.d, %[outpos], %[outwidth]\n" - "zip2 z3.d, z9.d, z4.d\n" - "incd %[outpos], all, mul #1\n" - "mov z14.s, #0\n" - "whilelt p2.d, %[outpos], %[outwidth]\n" - "zip1 z4.d, z10.d, z14.d\n" - "incd %[outpos], all, mul #1\n" - "zip2 z5.d, z10.d, z14.d\n" - "whilelt p3.d, %[outpos], %[outwidth]\n" - "zip1 z6.d, z11.d, z14.d\n" - "incd %[outpos], all, mul #1\n" - "zip2 z7.d, z11.d, z14.d\n" - "zip1 z8.d, z0.d, z4.d\n" - "st1d z8.d, p0, [%[outptr]]\n" - "zip2 z9.d, z0.d, z4.d\n" - "st1d z9.d, p1, [%[outptr], #1, MUL VL]\n" - "zip1 z10.d, z1.d, z5.d\n" - "st1d z10.d, p2, [%[outptr], #2, MUL VL]\n" - "zip2 z11.d, z1.d, z5.d\n" - "st1d z11.d, p3, [%[outptr], #3, MUL VL]\n" - "zip1 z12.d, z2.d, z6.d\n" - "whilelt p0.d, %[outpos], %[outwidth]\n" - "zip2 z13.d, z2.d, z6.d\n" - "st1d z12.d, p0, [%[outptr], #4, MUL VL]\n" - "zip1 z14.d, z3.d, z7.d\n" - "incd %[outpos], all, mul #1\n" - "zip2 z15.d, z3.d, z7.d\n" - "whilelt p1.d, %[outpos], %[outwidth]\n" - "st1d z13.d, p1, [%[outptr], #5, MUL VL]\n" - "incd %[outpos], all, mul #1\n" - "whilelt p2.d, %[outpos], %[outwidth]\n" - "st1d z14.d, p2, [%[outptr], #6, MUL VL]\n" - "incd %[outpos], all, mul #1\n" - "whilelt p3.d, %[outpos], %[outwidth]\n" - "st1d z15.d, p3, [%[outptr], #7, MUL VL]\n" - "incd %[outpos], all, mul #1\n" - "addvl %[outptr], %[outptr], #8\n" - "b 1b\n" - "2:\n" - : [inpos] "+r" (inpos), [outpos] "+r" (outpos), [outptr] "+r" (outptr), [inptr0] "+r" (inptr0), [inptr1] "+r" (inptr1) - : [outwidth] "r" (outwidth), [inwidth] "r" (inwidth) - : "p0", "p1", "p2", "p3", "z0", "z1", "z2", "z3", "z4", "z5", "z6", "z7", "z8", "z9", "z10", "z11", "z12", "z13", "z14", "z15", "cc", "memory" - ); - break; - - case 3: - __asm __volatile( - "1:\n" - "whilelt p0.s, %[inpos], %[inwidth]\n" - "b.none 2f\n" - "mov z4.s, #0\n" - "ld1w z0.s, p0/z, [%[inptr0]]\n" - "zip1 z8.d, z0.d, z4.d\n" - "ld1w z1.s, p0/z, [%[inptr1]]\n" - "zip2 z9.d, z0.d, z4.d\n" - "ld1w z2.s, p0/z, [%[inptr2]]\n" - "zip1 z10.d, z1.d, z4.d\n" - "incw %[inpos], all, mul #1\n" - "zip2 z11.d, z1.d, z4.d\n" - "addvl %[inptr0], %[inptr0], #1\n" - "zip1 z12.d, z2.d, z4.d\n" - "addvl %[inptr1], %[inptr1], #1\n" - "zip2 z13.d, z2.d, z4.d\n" - "addvl %[inptr2], %[inptr2], #1\n" - "zip1 z0.d, z8.d, z12.d\n" - "whilelt p0.d, %[outpos], %[outwidth]\n" - "zip2 z1.d, z8.d, z12.d\n" - "incd %[outpos], all, mul #1\n" - "zip1 z2.d, z9.d, z13.d\n" - "whilelt p1.d, %[outpos], %[outwidth]\n" - "zip2 z3.d, z9.d, z13.d\n" - "incd %[outpos], all, mul #1\n" - "mov z14.s, #0\n" - "whilelt p2.d, %[outpos], %[outwidth]\n" - "zip1 z4.d, z10.d, z14.d\n" - "incd %[outpos], all, mul #1\n" - "zip2 z5.d, z10.d, z14.d\n" - "whilelt p3.d, %[outpos], %[outwidth]\n" - "zip1 z6.d, z11.d, z14.d\n" - "incd %[outpos], all, mul #1\n" - "zip2 z7.d, z11.d, z14.d\n" - "zip1 z8.d, z0.d, z4.d\n" - "st1d z8.d, p0, [%[outptr]]\n" - "zip2 z9.d, z0.d, z4.d\n" - "st1d z9.d, p1, [%[outptr], #1, MUL VL]\n" - "zip1 z10.d, z1.d, z5.d\n" - "st1d z10.d, p2, [%[outptr], #2, MUL VL]\n" - "zip2 z11.d, z1.d, z5.d\n" - "st1d z11.d, p3, [%[outptr], #3, MUL VL]\n" - "zip1 z12.d, z2.d, z6.d\n" - "whilelt p0.d, %[outpos], %[outwidth]\n" - "zip2 z13.d, z2.d, z6.d\n" - "st1d z12.d, p0, [%[outptr], #4, MUL VL]\n" - "zip1 z14.d, z3.d, z7.d\n" - "incd %[outpos], all, mul #1\n" - "zip2 z15.d, z3.d, z7.d\n" - "whilelt p1.d, %[outpos], %[outwidth]\n" - "st1d z13.d, p1, [%[outptr], #5, MUL VL]\n" - "incd %[outpos], all, mul #1\n" - "whilelt p2.d, %[outpos], %[outwidth]\n" - "st1d z14.d, p2, [%[outptr], #6, MUL VL]\n" - "incd %[outpos], all, mul #1\n" - "whilelt p3.d, %[outpos], %[outwidth]\n" - "st1d z15.d, p3, [%[outptr], #7, MUL VL]\n" - "incd %[outpos], all, mul #1\n" - "addvl %[outptr], %[outptr], #8\n" - "b 1b\n" - "2:\n" - : [inpos] "+r" (inpos), [outpos] "+r" (outpos), [outptr] "+r" (outptr), [inptr0] "+r" (inptr0), [inptr1] "+r" (inptr1), [inptr2] "+r" (inptr2) - : [outwidth] "r" (outwidth), [inwidth] "r" (inwidth) - : "p0", "p1", "p2", "p3", "z0", "z1", "z2", "z3", "z4", "z5", "z6", "z7", "z8", "z9", "z10", "z11", "z12", "z13", "z14", "z15", "cc", "memory" - ); - break; - - case 4: - __asm __volatile( - "1:\n" - "whilelt p0.s, %[inpos], %[inwidth]\n" - "b.none 2f\n" - "mov z4.s, #0\n" - "ld1w z0.s, p0/z, [%[inptr0]]\n" - "zip1 z8.d, z0.d, z4.d\n" - "ld1w z1.s, p0/z, [%[inptr1]]\n" - "zip2 z9.d, z0.d, z4.d\n" - "ld1w z2.s, p0/z, [%[inptr2]]\n" - "zip1 z10.d, z1.d, z4.d\n" - "ld1w z3.s, p0/z, [%[inptr3]]\n" - "zip2 z11.d, z1.d, z4.d\n" - "incw %[inpos], all, mul #1\n" - "zip1 z12.d, z2.d, z4.d\n" - "addvl %[inptr0], %[inptr0], #1\n" - "zip2 z13.d, z2.d, z4.d\n" - "addvl %[inptr1], %[inptr1], #1\n" - "zip1 z14.d, z3.d, z4.d\n" - "addvl %[inptr2], %[inptr2], #1\n" - "zip2 z15.d, z3.d, z4.d\n" - "addvl %[inptr3], %[inptr3], #1\n" - "zip1 z0.d, z8.d, z12.d\n" - "whilelt p0.d, %[outpos], %[outwidth]\n" - "zip2 z1.d, z8.d, z12.d\n" - "incd %[outpos], all, mul #1\n" - "zip1 z2.d, z9.d, z13.d\n" - "whilelt p1.d, %[outpos], %[outwidth]\n" - "zip2 z3.d, z9.d, z13.d\n" - "incd %[outpos], all, mul #1\n" - "zip1 z4.d, z10.d, z14.d\n" - "whilelt p2.d, %[outpos], %[outwidth]\n" - "zip2 z5.d, z10.d, z14.d\n" - "incd %[outpos], all, mul #1\n" - "zip1 z6.d, z11.d, z15.d\n" - "whilelt p3.d, %[outpos], %[outwidth]\n" - "zip2 z7.d, z11.d, z15.d\n" - "incd %[outpos], all, mul #1\n" - "zip1 z8.d, z0.d, z4.d\n" - "st1d z8.d, p0, [%[outptr]]\n" - "zip2 z9.d, z0.d, z4.d\n" - "st1d z9.d, p1, [%[outptr], #1, MUL VL]\n" - "zip1 z10.d, z1.d, z5.d\n" - "st1d z10.d, p2, [%[outptr], #2, MUL VL]\n" - "zip2 z11.d, z1.d, z5.d\n" - "st1d z11.d, p3, [%[outptr], #3, MUL VL]\n" - "zip1 z12.d, z2.d, z6.d\n" - "whilelt p0.d, %[outpos], %[outwidth]\n" - "zip2 z13.d, z2.d, z6.d\n" - "st1d z12.d, p0, [%[outptr], #4, MUL VL]\n" - "zip1 z14.d, z3.d, z7.d\n" - "incd %[outpos], all, mul #1\n" - "zip2 z15.d, z3.d, z7.d\n" - "whilelt p1.d, %[outpos], %[outwidth]\n" - "st1d z13.d, p1, [%[outptr], #5, MUL VL]\n" - "incd %[outpos], all, mul #1\n" - "whilelt p2.d, %[outpos], %[outwidth]\n" - "st1d z14.d, p2, [%[outptr], #6, MUL VL]\n" - "incd %[outpos], all, mul #1\n" - "whilelt p3.d, %[outpos], %[outwidth]\n" - "st1d z15.d, p3, [%[outptr], #7, MUL VL]\n" - "incd %[outpos], all, mul #1\n" - "addvl %[outptr], %[outptr], #8\n" - "b 1b\n" - "2:\n" - : [inpos] "+r" (inpos), [outpos] "+r" (outpos), [outptr] "+r" (outptr), [inptr0] "+r" (inptr0), [inptr1] "+r" (inptr1), [inptr2] "+r" (inptr2), [inptr3] "+r" (inptr3) - : [outwidth] "r" (outwidth), [inwidth] "r" (inwidth) - : "p0", "p1", "p2", "p3", "z0", "z1", "z2", "z3", "z4", "z5", "z6", "z7", "z8", "z9", "z10", "z11", "z12", "z13", "z14", "z15", "cc", "memory" - ); - break; - - case 5: - __asm __volatile( - "1:\n" - "whilelt p0.s, %[inpos], %[inwidth]\n" - "b.none 2f\n" - "mov z5.s, #0\n" - "ld1w z0.s, p0/z, [%[inptr0]]\n" - "ld1w z1.s, p0/z, [%[inptr1]]\n" - "incw %[inpos], all, mul #1\n" - "zip1 z10.d, z1.d, z5.d\n" - "ld1w z2.s, p0/z, [%[inptr2]]\n" - "zip2 z11.d, z1.d, z5.d\n" - "ld1w z3.s, p0/z, [%[inptr3]]\n" - "zip1 z12.d, z2.d, z5.d\n" - "ld1w z4.s, p0/z, [%[inptr4]]\n" - "zip1 z8.d, z0.d, z4.d\n" - "addvl %[inptr0], %[inptr0], #1\n" - "zip2 z9.d, z0.d, z4.d\n" - "addvl %[inptr1], %[inptr1], #1\n" - "zip2 z13.d, z2.d, z5.d\n" - "addvl %[inptr2], %[inptr2], #1\n" - "zip1 z14.d, z3.d, z5.d\n" - "addvl %[inptr3], %[inptr3], #1\n" - "zip2 z15.d, z3.d, z5.d\n" - "addvl %[inptr4], %[inptr4], #1\n" - "zip1 z0.d, z8.d, z12.d\n" - "whilelt p0.d, %[outpos], %[outwidth]\n" - "zip2 z1.d, z8.d, z12.d\n" - "incd %[outpos], all, mul #1\n" - "zip1 z2.d, z9.d, z13.d\n" - "whilelt p1.d, %[outpos], %[outwidth]\n" - "zip2 z3.d, z9.d, z13.d\n" - "incd %[outpos], all, mul #1\n" - "zip1 z4.d, z10.d, z14.d\n" - "whilelt p2.d, %[outpos], %[outwidth]\n" - "zip2 z5.d, z10.d, z14.d\n" - "incd %[outpos], all, mul #1\n" - "zip1 z6.d, z11.d, z15.d\n" - "whilelt p3.d, %[outpos], %[outwidth]\n" - "zip2 z7.d, z11.d, z15.d\n" - "incd %[outpos], all, mul #1\n" - "zip1 z8.d, z0.d, z4.d\n" - "st1d z8.d, p0, [%[outptr]]\n" - "zip2 z9.d, z0.d, z4.d\n" - "st1d z9.d, p1, [%[outptr], #1, MUL VL]\n" - "zip1 z10.d, z1.d, z5.d\n" - "st1d z10.d, p2, [%[outptr], #2, MUL VL]\n" - "zip2 z11.d, z1.d, z5.d\n" - "st1d z11.d, p3, [%[outptr], #3, MUL VL]\n" - "zip1 z12.d, z2.d, z6.d\n" - "whilelt p0.d, %[outpos], %[outwidth]\n" - "zip2 z13.d, z2.d, z6.d\n" - "st1d z12.d, p0, [%[outptr], #4, MUL VL]\n" - "zip1 z14.d, z3.d, z7.d\n" - "incd %[outpos], all, mul #1\n" - "zip2 z15.d, z3.d, z7.d\n" - "whilelt p1.d, %[outpos], %[outwidth]\n" - "st1d z13.d, p1, [%[outptr], #5, MUL VL]\n" - "incd %[outpos], all, mul #1\n" - "whilelt p2.d, %[outpos], %[outwidth]\n" - "st1d z14.d, p2, [%[outptr], #6, MUL VL]\n" - "incd %[outpos], all, mul #1\n" - "whilelt p3.d, %[outpos], %[outwidth]\n" - "st1d z15.d, p3, [%[outptr], #7, MUL VL]\n" - "incd %[outpos], all, mul #1\n" - "addvl %[outptr], %[outptr], #8\n" - "b 1b\n" - "2:\n" - : [inpos] "+r" (inpos), [outpos] "+r" (outpos), [outptr] "+r" (outptr), [inptr0] "+r" (inptr0), [inptr1] "+r" (inptr1), [inptr2] "+r" (inptr2), [inptr3] "+r" (inptr3), [inptr4] "+r" (inptr4) - : [outwidth] "r" (outwidth), [inwidth] "r" (inwidth) - : "p0", "p1", "p2", "p3", "z0", "z1", "z2", "z3", "z4", "z5", "z6", "z7", "z8", "z9", "z10", "z11", "z12", "z13", "z14", "z15", "cc", "memory" - ); - break; - - case 6: - __asm __volatile( - "1:\n" - "whilelt p0.s, %[inpos], %[inwidth]\n" - "b.none 2f\n" - "mov z6.s, #0\n" - "ld1w z0.s, p0/z, [%[inptr0]]\n" - "ld1w z1.s, p0/z, [%[inptr1]]\n" - "incw %[inpos], all, mul #1\n" - "ld1w z2.s, p0/z, [%[inptr2]]\n" - "addvl %[inptr0], %[inptr0], #1\n" - "zip1 z12.d, z2.d, z6.d\n" - "ld1w z3.s, p0/z, [%[inptr3]]\n" - "zip2 z13.d, z2.d, z6.d\n" - "ld1w z4.s, p0/z, [%[inptr4]]\n" - "zip1 z8.d, z0.d, z4.d\n" - "ld1w z5.s, p0/z, [%[inptr5]]\n" - "zip2 z9.d, z0.d, z4.d\n" - "addvl %[inptr1], %[inptr1], #1\n" - "zip1 z10.d, z1.d, z5.d\n" - "addvl %[inptr2], %[inptr2], #1\n" - "zip2 z11.d, z1.d, z5.d\n" - "addvl %[inptr3], %[inptr3], #1\n" - "zip1 z14.d, z3.d, z6.d\n" - "addvl %[inptr4], %[inptr4], #1\n" - "zip2 z15.d, z3.d, z6.d\n" - "addvl %[inptr5], %[inptr5], #1\n" - "zip1 z0.d, z8.d, z12.d\n" - "whilelt p0.d, %[outpos], %[outwidth]\n" - "zip2 z1.d, z8.d, z12.d\n" - "incd %[outpos], all, mul #1\n" - "zip1 z2.d, z9.d, z13.d\n" - "whilelt p1.d, %[outpos], %[outwidth]\n" - "zip2 z3.d, z9.d, z13.d\n" - "incd %[outpos], all, mul #1\n" - "zip1 z4.d, z10.d, z14.d\n" - "whilelt p2.d, %[outpos], %[outwidth]\n" - "zip2 z5.d, z10.d, z14.d\n" - "incd %[outpos], all, mul #1\n" - "zip1 z6.d, z11.d, z15.d\n" - "whilelt p3.d, %[outpos], %[outwidth]\n" - "zip2 z7.d, z11.d, z15.d\n" - "incd %[outpos], all, mul #1\n" - "zip1 z8.d, z0.d, z4.d\n" - "st1d z8.d, p0, [%[outptr]]\n" - "zip2 z9.d, z0.d, z4.d\n" - "st1d z9.d, p1, [%[outptr], #1, MUL VL]\n" - "zip1 z10.d, z1.d, z5.d\n" - "st1d z10.d, p2, [%[outptr], #2, MUL VL]\n" - "zip2 z11.d, z1.d, z5.d\n" - "st1d z11.d, p3, [%[outptr], #3, MUL VL]\n" - "zip1 z12.d, z2.d, z6.d\n" - "whilelt p0.d, %[outpos], %[outwidth]\n" - "zip2 z13.d, z2.d, z6.d\n" - "st1d z12.d, p0, [%[outptr], #4, MUL VL]\n" - "zip1 z14.d, z3.d, z7.d\n" - "incd %[outpos], all, mul #1\n" - "zip2 z15.d, z3.d, z7.d\n" - "whilelt p1.d, %[outpos], %[outwidth]\n" - "st1d z13.d, p1, [%[outptr], #5, MUL VL]\n" - "incd %[outpos], all, mul #1\n" - "whilelt p2.d, %[outpos], %[outwidth]\n" - "st1d z14.d, p2, [%[outptr], #6, MUL VL]\n" - "incd %[outpos], all, mul #1\n" - "whilelt p3.d, %[outpos], %[outwidth]\n" - "st1d z15.d, p3, [%[outptr], #7, MUL VL]\n" - "incd %[outpos], all, mul #1\n" - "addvl %[outptr], %[outptr], #8\n" - "b 1b\n" - "2:\n" - : [inpos] "+r" (inpos), [outpos] "+r" (outpos), [outptr] "+r" (outptr), [inptr0] "+r" (inptr0), [inptr1] "+r" (inptr1), [inptr2] "+r" (inptr2), [inptr3] "+r" (inptr3), [inptr4] "+r" (inptr4), [inptr5] "+r" (inptr5) - : [outwidth] "r" (outwidth), [inwidth] "r" (inwidth) - : "p0", "p1", "p2", "p3", "z0", "z1", "z2", "z3", "z4", "z5", "z6", "z7", "z8", "z9", "z10", "z11", "z12", "z13", "z14", "z15", "cc", "memory" - ); - break; - - case 7: - __asm __volatile( - "1:\n" - "whilelt p0.s, %[inpos], %[inwidth]\n" - "b.none 2f\n" - "mov z7.s, #0\n" - "ld1w z0.s, p0/z, [%[inptr0]]\n" - "ld1w z1.s, p0/z, [%[inptr1]]\n" - "incw %[inpos], all, mul #1\n" - "ld1w z2.s, p0/z, [%[inptr2]]\n" - "addvl %[inptr0], %[inptr0], #1\n" - "ld1w z3.s, p0/z, [%[inptr3]]\n" - "addvl %[inptr1], %[inptr1], #1\n" - "zip1 z14.d, z3.d, z7.d\n" - "ld1w z4.s, p0/z, [%[inptr4]]\n" - "zip1 z8.d, z0.d, z4.d\n" - "ld1w z5.s, p0/z, [%[inptr5]]\n" - "zip2 z9.d, z0.d, z4.d\n" - "ld1w z6.s, p0/z, [%[inptr6]]\n" - "zip1 z10.d, z1.d, z5.d\n" - "addvl %[inptr2], %[inptr2], #1\n" - "zip2 z11.d, z1.d, z5.d\n" - "addvl %[inptr3], %[inptr3], #1\n" - "zip1 z12.d, z2.d, z6.d\n" - "addvl %[inptr4], %[inptr4], #1\n" - "zip2 z13.d, z2.d, z6.d\n" - "addvl %[inptr5], %[inptr5], #1\n" - "zip2 z15.d, z3.d, z7.d\n" - "addvl %[inptr6], %[inptr6], #1\n" - "zip1 z0.d, z8.d, z12.d\n" - "whilelt p0.d, %[outpos], %[outwidth]\n" - "zip2 z1.d, z8.d, z12.d\n" - "incd %[outpos], all, mul #1\n" - "zip1 z2.d, z9.d, z13.d\n" - "whilelt p1.d, %[outpos], %[outwidth]\n" - "zip2 z3.d, z9.d, z13.d\n" - "incd %[outpos], all, mul #1\n" - "zip1 z4.d, z10.d, z14.d\n" - "whilelt p2.d, %[outpos], %[outwidth]\n" - "zip2 z5.d, z10.d, z14.d\n" - "incd %[outpos], all, mul #1\n" - "zip1 z6.d, z11.d, z15.d\n" - "whilelt p3.d, %[outpos], %[outwidth]\n" - "zip2 z7.d, z11.d, z15.d\n" - "incd %[outpos], all, mul #1\n" - "zip1 z8.d, z0.d, z4.d\n" - "st1d z8.d, p0, [%[outptr]]\n" - "zip2 z9.d, z0.d, z4.d\n" - "st1d z9.d, p1, [%[outptr], #1, MUL VL]\n" - "zip1 z10.d, z1.d, z5.d\n" - "st1d z10.d, p2, [%[outptr], #2, MUL VL]\n" - "zip2 z11.d, z1.d, z5.d\n" - "st1d z11.d, p3, [%[outptr], #3, MUL VL]\n" - "zip1 z12.d, z2.d, z6.d\n" - "whilelt p0.d, %[outpos], %[outwidth]\n" - "zip2 z13.d, z2.d, z6.d\n" - "st1d z12.d, p0, [%[outptr], #4, MUL VL]\n" - "zip1 z14.d, z3.d, z7.d\n" - "incd %[outpos], all, mul #1\n" - "zip2 z15.d, z3.d, z7.d\n" - "whilelt p1.d, %[outpos], %[outwidth]\n" - "st1d z13.d, p1, [%[outptr], #5, MUL VL]\n" - "incd %[outpos], all, mul #1\n" - "whilelt p2.d, %[outpos], %[outwidth]\n" - "st1d z14.d, p2, [%[outptr], #6, MUL VL]\n" - "incd %[outpos], all, mul #1\n" - "whilelt p3.d, %[outpos], %[outwidth]\n" - "st1d z15.d, p3, [%[outptr], #7, MUL VL]\n" - "incd %[outpos], all, mul #1\n" - "addvl %[outptr], %[outptr], #8\n" - "b 1b\n" - "2:\n" - : [inpos] "+r" (inpos), [outpos] "+r" (outpos), [outptr] "+r" (outptr), [inptr0] "+r" (inptr0), [inptr1] "+r" (inptr1), [inptr2] "+r" (inptr2), [inptr3] "+r" (inptr3), [inptr4] "+r" (inptr4), [inptr5] "+r" (inptr5), [inptr6] "+r" (inptr6) - : [outwidth] "r" (outwidth), [inwidth] "r" (inwidth) - : "p0", "p1", "p2", "p3", "z0", "z1", "z2", "z3", "z4", "z5", "z6", "z7", "z8", "z9", "z10", "z11", "z12", "z13", "z14", "z15", "cc", "memory" - ); - break; - - default: - case 8: - __asm __volatile( - "1:\n" - "whilelt p0.s, %[inpos], %[inwidth]\n" - "b.none 2f\n" - "ld1w z0.s, p0/z, [%[inptr0]]\n" - "incw %[inpos], all, mul #1\n" - "ld1w z1.s, p0/z, [%[inptr1]]\n" - "addvl %[inptr0], %[inptr0], #1\n" - "ld1w z2.s, p0/z, [%[inptr2]]\n" - "addvl %[inptr1], %[inptr1], #1\n" - "ld1w z3.s, p0/z, [%[inptr3]]\n" - "addvl %[inptr2], %[inptr2], #1\n" - "ld1w z4.s, p0/z, [%[inptr4]]\n" - "addvl %[inptr3], %[inptr3], #1\n" - "zip1 z8.d, z0.d, z4.d\n" - "ld1w z5.s, p0/z, [%[inptr5]]\n" - "zip2 z9.d, z0.d, z4.d\n" - "ld1w z6.s, p0/z, [%[inptr6]]\n" - "zip1 z10.d, z1.d, z5.d\n" - "ld1w z7.s, p0/z, [%[inptr7]]\n" - "zip2 z11.d, z1.d, z5.d\n" - "addvl %[inptr4], %[inptr4], #1\n" - "zip1 z12.d, z2.d, z6.d\n" - "addvl %[inptr5], %[inptr5], #1\n" - "zip2 z13.d, z2.d, z6.d\n" - "addvl %[inptr6], %[inptr6], #1\n" - "zip1 z14.d, z3.d, z7.d\n" - "addvl %[inptr7], %[inptr7], #1\n" - "zip2 z15.d, z3.d, z7.d\n" - "whilelt p0.d, %[outpos], %[outwidth]\n" - "zip1 z0.d, z8.d, z12.d\n" - "incd %[outpos], all, mul #1\n" - "zip2 z1.d, z8.d, z12.d\n" - "whilelt p1.d, %[outpos], %[outwidth]\n" - "zip1 z2.d, z9.d, z13.d\n" - "incd %[outpos], all, mul #1\n" - "zip2 z3.d, z9.d, z13.d\n" - "whilelt p2.d, %[outpos], %[outwidth]\n" - "zip1 z4.d, z10.d, z14.d\n" - "incd %[outpos], all, mul #1\n" - "zip2 z5.d, z10.d, z14.d\n" - "whilelt p3.d, %[outpos], %[outwidth]\n" - "zip1 z6.d, z11.d, z15.d\n" - "incd %[outpos], all, mul #1\n" - "zip2 z7.d, z11.d, z15.d\n" - "zip1 z8.d, z0.d, z4.d\n" - "st1d z8.d, p0, [%[outptr]]\n" - "zip2 z9.d, z0.d, z4.d\n" - "st1d z9.d, p1, [%[outptr], #1, MUL VL]\n" - "zip1 z10.d, z1.d, z5.d\n" - "st1d z10.d, p2, [%[outptr], #2, MUL VL]\n" - "zip2 z11.d, z1.d, z5.d\n" - "st1d z11.d, p3, [%[outptr], #3, MUL VL]\n" - "zip1 z12.d, z2.d, z6.d\n" - "whilelt p0.d, %[outpos], %[outwidth]\n" - "zip2 z13.d, z2.d, z6.d\n" - "st1d z12.d, p0, [%[outptr], #4, MUL VL]\n" - "zip1 z14.d, z3.d, z7.d\n" - "incd %[outpos], all, mul #1\n" - "zip2 z15.d, z3.d, z7.d\n" - "whilelt p1.d, %[outpos], %[outwidth]\n" - "st1d z13.d, p1, [%[outptr], #5, MUL VL]\n" - "incd %[outpos], all, mul #1\n" - "whilelt p2.d, %[outpos], %[outwidth]\n" - "st1d z14.d, p2, [%[outptr], #6, MUL VL]\n" - "incd %[outpos], all, mul #1\n" - "whilelt p3.d, %[outpos], %[outwidth]\n" - "st1d z15.d, p3, [%[outptr], #7, MUL VL]\n" - "incd %[outpos], all, mul #1\n" - "addvl %[outptr], %[outptr], #8\n" - "b 1b\n" - "2:\n" - : [inpos] "+r" (inpos), [outpos] "+r" (outpos), [outptr] "+r" (outptr), [inptr0] "+r" (inptr0), [inptr1] "+r" (inptr1), [inptr2] "+r" (inptr2), [inptr3] "+r" (inptr3), [inptr4] "+r" (inptr4), [inptr5] "+r" (inptr5), [inptr6] "+r" (inptr6), [inptr7] "+r" (inptr7) - : [outwidth] "r" (outwidth), [inwidth] "r" (inwidth) - : "p0", "p1", "p2", "p3", "z0", "z1", "z2", "z3", "z4", "z5", "z6", "z7", "z8", "z9", "z10", "z11", "z12", "z13", "z14", "z15", "cc", "memory" - ); - break; - - - } - } -} - -#endif // __ARM_FEATURE_SVE diff --git a/src/core/NEON/kernels/arm_gemm/transforms/sve_interleave_8way_block4_16bit.hpp b/src/core/NEON/kernels/arm_gemm/transforms/sve_interleave_8way_block4_16bit.hpp deleted file mode 100644 index 26e10511a6..0000000000 --- a/src/core/NEON/kernels/arm_gemm/transforms/sve_interleave_8way_block4_16bit.hpp +++ /dev/null @@ -1,596 +0,0 @@ -/* - * Copyright (c) 2019 Arm Limited. - * - * SPDX-License-Identifier: MIT - * - * Permission is hereby granted, free of charge, to any person obtaining a copy - * of this software and associated documentation files (the "Software"), to - * deal in the Software without restriction, including without limitation the - * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or - * sell copies of the Software, and to permit persons to whom the Software is - * furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in all - * copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE - * SOFTWARE. - */ -#pragma once - -#ifdef __ARM_FEATURE_SVE - -template<> -template<typename T> -inline void TransformImpl<8, 4, false, 2, 2, false>::Transform(T *out, const T *in, int ldin, int y0, int ymax, int k0, int kmax) -{ - uint16_t *master_outptr = reinterpret_cast<uint16_t *>(out); - const uint16_t *inptr = reinterpret_cast<const uint16_t *>(in); - - for (int y=y0; y<ymax; y+=8) - { - const int height = ymax-y; - const long inwidth = (kmax - k0); - const long outwidth = ((inwidth + 3) / 4) * 32; - long inpos = 0; - long outpos = 0; - - uint16_t *outptr = master_outptr; - master_outptr += outwidth; - - const uint16_t *inptr0 = inptr + y * ldin + k0; - const uint16_t *inptr1 = inptr0 + ldin; - const uint16_t *inptr2 = inptr1 + ldin; - const uint16_t *inptr3 = inptr2 + ldin; - const uint16_t *inptr4 = inptr3 + ldin; - const uint16_t *inptr5 = inptr4 + ldin; - const uint16_t *inptr6 = inptr5 + ldin; - const uint16_t *inptr7 = inptr6 + ldin; - - switch(height) - { - case 1: - __asm __volatile( - "1:\n" - "whilelt p0.h, %[inpos], %[inwidth]\n" - "b.none 2f\n" - "mov z4.h, #0\n" - "ld1h z0.h, p0/z, [%[inptr0], %[inpos], LSL #1]\n" - "inch %[inpos], all, mul #1\n" - "whilelt p0.h, %[outpos], %[outwidth]\n" - "inch %[outpos], all, mul #1\n" - "zip1 z8.d, z0.d, z4.d\n" - "zip2 z9.d, z0.d, z4.d\n" - "whilelt p1.h, %[outpos], %[outwidth]\n" - "zip1 z0.d, z8.d, z4.d\n" - "inch %[outpos], all, mul #1\n" - "zip2 z1.d, z8.d, z4.d\n" - "zip1 z2.d, z9.d, z4.d\n" - "zip2 z3.d, z9.d, z4.d\n" - "whilelt p2.h, %[outpos], %[outwidth]\n" - "zip1 z8.d, z0.d, z4.d\n" - "inch %[outpos], all, mul #1\n" - "zip2 z9.d, z0.d, z4.d\n" - "zip1 z10.d, z1.d, z4.d\n" - "st1h z8.h, p0, [%[outptr]]\n" - "zip2 z11.d, z1.d, z4.d\n" - "whilelt p3.h, %[outpos], %[outwidth]\n" - "zip1 z12.d, z2.d, z4.d\n" - "st1h z9.h, p1, [%[outptr], #1, MUL VL]\n" - "zip2 z13.d, z2.d, z4.d\n" - "inch %[outpos], all, mul #1\n" - "zip1 z14.d, z3.d, z4.d\n" - "st1h z10.h, p2, [%[outptr], #2, MUL VL]\n" - "zip2 z15.d, z3.d, z4.d\n" - "whilelt p4.h, %[outpos], %[outwidth]\n" - "st1h z11.h, p3, [%[outptr], #3, MUL VL]\n" - "inch %[outpos], all, mul #1\n" - "st1h z12.h, p4, [%[outptr], #4, MUL VL]\n" - "whilelt p5.h, %[outpos], %[outwidth]\n" - "inch %[outpos], all, mul #1\n" - "st1h z13.h, p5, [%[outptr], #5, MUL VL]\n" - "whilelt p6.h, %[outpos], %[outwidth]\n" - "inch %[outpos], all, mul #1\n" - "st1h z14.h, p6, [%[outptr], #6, MUL VL]\n" - "whilelt p7.h, %[outpos], %[outwidth]\n" - "inch %[outpos], all, mul #1\n" - "st1h z15.h, p7, [%[outptr], #7, MUL VL]\n" - "addvl %[outptr], %[outptr], #8\n" - "b 1b\n" - "2:\n" - : [inpos] "+r" (inpos), [outpos] "+r" (outpos), [outptr] "+r" (outptr), [inptr0] "+r" (inptr0) - : [outwidth] "r" (outwidth), [inwidth] "r" (inwidth) - : "p0", "p1", "p2", "p3", "p4", "p5", "p6", "p7", "z0", "z1", "z2", "z3", "z4", "z5", "z6", "z7", "z8", "z9", "z10", "z11", "z12", "z13", "z14", "z15", "cc", "memory" - ); - break; - - case 2: - __asm __volatile( - "1:\n" - "whilelt p0.h, %[inpos], %[inwidth]\n" - "b.none 2f\n" - "mov z4.h, #0\n" - "mov z14.h, #0\n" - "ld1h z0.h, p0/z, [%[inptr0], %[inpos], LSL #1]\n" - "ld1h z1.h, p0/z, [%[inptr1], %[inpos], LSL #1]\n" - "inch %[inpos], all, mul #1\n" - "whilelt p0.h, %[outpos], %[outwidth]\n" - "zip1 z8.d, z0.d, z4.d\n" - "inch %[outpos], all, mul #1\n" - "zip2 z9.d, z0.d, z4.d\n" - "zip1 z10.d, z1.d, z4.d\n" - "zip2 z11.d, z1.d, z4.d\n" - "whilelt p1.h, %[outpos], %[outwidth]\n" - "zip1 z0.d, z8.d, z4.d\n" - "inch %[outpos], all, mul #1\n" - "zip2 z1.d, z8.d, z4.d\n" - "zip1 z2.d, z9.d, z4.d\n" - "zip2 z3.d, z9.d, z4.d\n" - "whilelt p2.h, %[outpos], %[outwidth]\n" - "zip1 z4.d, z10.d, z14.d\n" - "inch %[outpos], all, mul #1\n" - "zip2 z5.d, z10.d, z14.d\n" - "zip1 z6.d, z11.d, z14.d\n" - "zip2 z7.d, z11.d, z14.d\n" - "whilelt p3.h, %[outpos], %[outwidth]\n" - "zip1 z8.d, z0.d, z4.d\n" - "inch %[outpos], all, mul #1\n" - "zip2 z9.d, z0.d, z4.d\n" - "zip1 z10.d, z1.d, z5.d\n" - "st1h z8.h, p0, [%[outptr]]\n" - "zip2 z11.d, z1.d, z5.d\n" - "whilelt p4.h, %[outpos], %[outwidth]\n" - "zip1 z12.d, z2.d, z6.d\n" - "st1h z9.h, p1, [%[outptr], #1, MUL VL]\n" - "zip2 z13.d, z2.d, z6.d\n" - "inch %[outpos], all, mul #1\n" - "zip1 z14.d, z3.d, z7.d\n" - "st1h z10.h, p2, [%[outptr], #2, MUL VL]\n" - "zip2 z15.d, z3.d, z7.d\n" - "whilelt p5.h, %[outpos], %[outwidth]\n" - "st1h z11.h, p3, [%[outptr], #3, MUL VL]\n" - "inch %[outpos], all, mul #1\n" - "st1h z12.h, p4, [%[outptr], #4, MUL VL]\n" - "whilelt p6.h, %[outpos], %[outwidth]\n" - "inch %[outpos], all, mul #1\n" - "st1h z13.h, p5, [%[outptr], #5, MUL VL]\n" - "whilelt p7.h, %[outpos], %[outwidth]\n" - "inch %[outpos], all, mul #1\n" - "st1h z14.h, p6, [%[outptr], #6, MUL VL]\n" - "st1h z15.h, p7, [%[outptr], #7, MUL VL]\n" - "addvl %[outptr], %[outptr], #8\n" - "b 1b\n" - "2:\n" - : [inpos] "+r" (inpos), [outpos] "+r" (outpos), [outptr] "+r" (outptr), [inptr0] "+r" (inptr0), [inptr1] "+r" (inptr1) - : [outwidth] "r" (outwidth), [inwidth] "r" (inwidth) - : "p0", "p1", "p2", "p3", "p4", "p5", "p6", "p7", "z0", "z1", "z2", "z3", "z4", "z5", "z6", "z7", "z8", "z9", "z10", "z11", "z12", "z13", "z14", "z15", "cc", "memory" - ); - break; - - case 3: - __asm __volatile( - "1:\n" - "whilelt p0.h, %[inpos], %[inwidth]\n" - "b.none 2f\n" - "mov z4.h, #0\n" - "mov z14.h, #0\n" - "ld1h z0.h, p0/z, [%[inptr0], %[inpos], LSL #1]\n" - "ld1h z1.h, p0/z, [%[inptr1], %[inpos], LSL #1]\n" - "ld1h z2.h, p0/z, [%[inptr2], %[inpos], LSL #1]\n" - "inch %[inpos], all, mul #1\n" - "zip1 z8.d, z0.d, z4.d\n" - "whilelt p0.h, %[outpos], %[outwidth]\n" - "zip2 z9.d, z0.d, z4.d\n" - "inch %[outpos], all, mul #1\n" - "zip1 z10.d, z1.d, z4.d\n" - "zip2 z11.d, z1.d, z4.d\n" - "zip1 z12.d, z2.d, z4.d\n" - "whilelt p1.h, %[outpos], %[outwidth]\n" - "zip2 z13.d, z2.d, z4.d\n" - "inch %[outpos], all, mul #1\n" - "zip1 z0.d, z8.d, z12.d\n" - "zip2 z1.d, z8.d, z12.d\n" - "zip1 z2.d, z9.d, z13.d\n" - "whilelt p2.h, %[outpos], %[outwidth]\n" - "zip2 z3.d, z9.d, z13.d\n" - "inch %[outpos], all, mul #1\n" - "zip1 z4.d, z10.d, z14.d\n" - "zip2 z5.d, z10.d, z14.d\n" - "zip1 z6.d, z11.d, z14.d\n" - "whilelt p3.h, %[outpos], %[outwidth]\n" - "zip2 z7.d, z11.d, z14.d\n" - "inch %[outpos], all, mul #1\n" - "zip1 z8.d, z0.d, z4.d\n" - "zip2 z9.d, z0.d, z4.d\n" - "zip1 z10.d, z1.d, z5.d\n" - "whilelt p4.h, %[outpos], %[outwidth]\n" - "zip2 z11.d, z1.d, z5.d\n" - "st1h z8.h, p0, [%[outptr]]\n" - "zip1 z12.d, z2.d, z6.d\n" - "inch %[outpos], all, mul #1\n" - "zip2 z13.d, z2.d, z6.d\n" - "st1h z9.h, p1, [%[outptr], #1, MUL VL]\n" - "zip1 z14.d, z3.d, z7.d\n" - "zip2 z15.d, z3.d, z7.d\n" - "whilelt p5.h, %[outpos], %[outwidth]\n" - "st1h z10.h, p2, [%[outptr], #2, MUL VL]\n" - "inch %[outpos], all, mul #1\n" - "st1h z11.h, p3, [%[outptr], #3, MUL VL]\n" - "whilelt p6.h, %[outpos], %[outwidth]\n" - "inch %[outpos], all, mul #1\n" - "st1h z12.h, p4, [%[outptr], #4, MUL VL]\n" - "whilelt p7.h, %[outpos], %[outwidth]\n" - "inch %[outpos], all, mul #1\n" - "st1h z13.h, p5, [%[outptr], #5, MUL VL]\n" - "st1h z14.h, p6, [%[outptr], #6, MUL VL]\n" - "st1h z15.h, p7, [%[outptr], #7, MUL VL]\n" - "addvl %[outptr], %[outptr], #8\n" - "b 1b\n" - "2:\n" - : [inpos] "+r" (inpos), [outpos] "+r" (outpos), [outptr] "+r" (outptr), [inptr0] "+r" (inptr0), [inptr1] "+r" (inptr1), [inptr2] "+r" (inptr2) - : [outwidth] "r" (outwidth), [inwidth] "r" (inwidth) - : "p0", "p1", "p2", "p3", "p4", "p5", "p6", "p7", "z0", "z1", "z2", "z3", "z4", "z5", "z6", "z7", "z8", "z9", "z10", "z11", "z12", "z13", "z14", "z15", "cc", "memory" - ); - break; - - case 4: - __asm __volatile( - "1:\n" - "whilelt p0.h, %[inpos], %[inwidth]\n" - "b.none 2f\n" - "mov z4.h, #0\n" - "ld1h z0.h, p0/z, [%[inptr0], %[inpos], LSL #1]\n" - "ld1h z1.h, p0/z, [%[inptr1], %[inpos], LSL #1]\n" - "ld1h z2.h, p0/z, [%[inptr2], %[inpos], LSL #1]\n" - "ld1h z3.h, p0/z, [%[inptr3], %[inpos], LSL #1]\n" - "inch %[inpos], all, mul #1\n" - "zip1 z8.d, z0.d, z4.d\n" - "whilelt p0.h, %[outpos], %[outwidth]\n" - "zip2 z9.d, z0.d, z4.d\n" - "inch %[outpos], all, mul #1\n" - "zip1 z10.d, z1.d, z4.d\n" - "zip2 z11.d, z1.d, z4.d\n" - "zip1 z12.d, z2.d, z4.d\n" - "whilelt p1.h, %[outpos], %[outwidth]\n" - "zip2 z13.d, z2.d, z4.d\n" - "inch %[outpos], all, mul #1\n" - "zip1 z14.d, z3.d, z4.d\n" - "zip2 z15.d, z3.d, z4.d\n" - "zip1 z0.d, z8.d, z12.d\n" - "whilelt p2.h, %[outpos], %[outwidth]\n" - "zip2 z1.d, z8.d, z12.d\n" - "inch %[outpos], all, mul #1\n" - "zip1 z2.d, z9.d, z13.d\n" - "zip2 z3.d, z9.d, z13.d\n" - "zip1 z4.d, z10.d, z14.d\n" - "whilelt p3.h, %[outpos], %[outwidth]\n" - "zip2 z5.d, z10.d, z14.d\n" - "inch %[outpos], all, mul #1\n" - "zip1 z6.d, z11.d, z15.d\n" - "zip2 z7.d, z11.d, z15.d\n" - "zip1 z8.d, z0.d, z4.d\n" - "whilelt p4.h, %[outpos], %[outwidth]\n" - "zip2 z9.d, z0.d, z4.d\n" - "inch %[outpos], all, mul #1\n" - "zip1 z10.d, z1.d, z5.d\n" - "st1h z8.h, p0, [%[outptr]]\n" - "zip2 z11.d, z1.d, z5.d\n" - "zip1 z12.d, z2.d, z6.d\n" - "whilelt p5.h, %[outpos], %[outwidth]\n" - "zip2 z13.d, z2.d, z6.d\n" - "st1h z9.h, p1, [%[outptr], #1, MUL VL]\n" - "zip1 z14.d, z3.d, z7.d\n" - "inch %[outpos], all, mul #1\n" - "zip2 z15.d, z3.d, z7.d\n" - "st1h z10.h, p2, [%[outptr], #2, MUL VL]\n" - "whilelt p6.h, %[outpos], %[outwidth]\n" - "st1h z11.h, p3, [%[outptr], #3, MUL VL]\n" - "inch %[outpos], all, mul #1\n" - "st1h z12.h, p4, [%[outptr], #4, MUL VL]\n" - "whilelt p7.h, %[outpos], %[outwidth]\n" - "inch %[outpos], all, mul #1\n" - "st1h z13.h, p5, [%[outptr], #5, MUL VL]\n" - "st1h z14.h, p6, [%[outptr], #6, MUL VL]\n" - "st1h z15.h, p7, [%[outptr], #7, MUL VL]\n" - "addvl %[outptr], %[outptr], #8\n" - "b 1b\n" - "2:\n" - : [inpos] "+r" (inpos), [outpos] "+r" (outpos), [outptr] "+r" (outptr), [inptr0] "+r" (inptr0), [inptr1] "+r" (inptr1), [inptr2] "+r" (inptr2), [inptr3] "+r" (inptr3) - : [outwidth] "r" (outwidth), [inwidth] "r" (inwidth) - : "p0", "p1", "p2", "p3", "p4", "p5", "p6", "p7", "z0", "z1", "z2", "z3", "z4", "z5", "z6", "z7", "z8", "z9", "z10", "z11", "z12", "z13", "z14", "z15", "cc", "memory" - ); - break; - - case 5: - __asm __volatile( - "1:\n" - "whilelt p0.h, %[inpos], %[inwidth]\n" - "b.none 2f\n" - "mov z5.h, #0\n" - "ld1h z0.h, p0/z, [%[inptr0], %[inpos], LSL #1]\n" - "ld1h z1.h, p0/z, [%[inptr1], %[inpos], LSL #1]\n" - "ld1h z2.h, p0/z, [%[inptr2], %[inpos], LSL #1]\n" - "ld1h z3.h, p0/z, [%[inptr3], %[inpos], LSL #1]\n" - "ld1h z4.h, p0/z, [%[inptr4], %[inpos], LSL #1]\n" - "inch %[inpos], all, mul #1\n" - "zip1 z10.d, z1.d, z5.d\n" - "whilelt p0.h, %[outpos], %[outwidth]\n" - "zip1 z8.d, z0.d, z4.d\n" - "inch %[outpos], all, mul #1\n" - "zip2 z9.d, z0.d, z4.d\n" - "zip2 z11.d, z1.d, z5.d\n" - "zip1 z12.d, z2.d, z5.d\n" - "whilelt p1.h, %[outpos], %[outwidth]\n" - "zip2 z13.d, z2.d, z5.d\n" - "inch %[outpos], all, mul #1\n" - "zip1 z14.d, z3.d, z5.d\n" - "zip2 z15.d, z3.d, z5.d\n" - "zip1 z0.d, z8.d, z12.d\n" - "whilelt p2.h, %[outpos], %[outwidth]\n" - "zip2 z1.d, z8.d, z12.d\n" - "inch %[outpos], all, mul #1\n" - "zip1 z2.d, z9.d, z13.d\n" - "zip2 z3.d, z9.d, z13.d\n" - "zip1 z4.d, z10.d, z14.d\n" - "whilelt p3.h, %[outpos], %[outwidth]\n" - "zip2 z5.d, z10.d, z14.d\n" - "inch %[outpos], all, mul #1\n" - "zip1 z6.d, z11.d, z15.d\n" - "zip2 z7.d, z11.d, z15.d\n" - "zip1 z8.d, z0.d, z4.d\n" - "whilelt p4.h, %[outpos], %[outwidth]\n" - "zip2 z9.d, z0.d, z4.d\n" - "inch %[outpos], all, mul #1\n" - "zip1 z10.d, z1.d, z5.d\n" - "st1h z8.h, p0, [%[outptr]]\n" - "zip2 z11.d, z1.d, z5.d\n" - "zip1 z12.d, z2.d, z6.d\n" - "whilelt p5.h, %[outpos], %[outwidth]\n" - "zip2 z13.d, z2.d, z6.d\n" - "st1h z9.h, p1, [%[outptr], #1, MUL VL]\n" - "zip1 z14.d, z3.d, z7.d\n" - "inch %[outpos], all, mul #1\n" - "zip2 z15.d, z3.d, z7.d\n" - "st1h z10.h, p2, [%[outptr], #2, MUL VL]\n" - "whilelt p6.h, %[outpos], %[outwidth]\n" - "st1h z11.h, p3, [%[outptr], #3, MUL VL]\n" - "inch %[outpos], all, mul #1\n" - "st1h z12.h, p4, [%[outptr], #4, MUL VL]\n" - "whilelt p7.h, %[outpos], %[outwidth]\n" - "inch %[outpos], all, mul #1\n" - "st1h z13.h, p5, [%[outptr], #5, MUL VL]\n" - "st1h z14.h, p6, [%[outptr], #6, MUL VL]\n" - "st1h z15.h, p7, [%[outptr], #7, MUL VL]\n" - "addvl %[outptr], %[outptr], #8\n" - "b 1b\n" - "2:\n" - : [inpos] "+r" (inpos), [outpos] "+r" (outpos), [outptr] "+r" (outptr), [inptr0] "+r" (inptr0), [inptr1] "+r" (inptr1), [inptr2] "+r" (inptr2), [inptr3] "+r" (inptr3), [inptr4] "+r" (inptr4) - : [outwidth] "r" (outwidth), [inwidth] "r" (inwidth) - : "p0", "p1", "p2", "p3", "p4", "p5", "p6", "p7", "z0", "z1", "z2", "z3", "z4", "z5", "z6", "z7", "z8", "z9", "z10", "z11", "z12", "z13", "z14", "z15", "cc", "memory" - ); - break; - - case 6: - __asm __volatile( - "1:\n" - "whilelt p0.h, %[inpos], %[inwidth]\n" - "b.none 2f\n" - "mov z6.h, #0\n" - "ld1h z0.h, p0/z, [%[inptr0], %[inpos], LSL #1]\n" - "ld1h z1.h, p0/z, [%[inptr1], %[inpos], LSL #1]\n" - "ld1h z2.h, p0/z, [%[inptr2], %[inpos], LSL #1]\n" - "ld1h z3.h, p0/z, [%[inptr3], %[inpos], LSL #1]\n" - "ld1h z4.h, p0/z, [%[inptr4], %[inpos], LSL #1]\n" - "ld1h z5.h, p0/z, [%[inptr5], %[inpos], LSL #1]\n" - "inch %[inpos], all, mul #1\n" - "zip1 z12.d, z2.d, z6.d\n" - "whilelt p0.h, %[outpos], %[outwidth]\n" - "zip1 z8.d, z0.d, z4.d\n" - "inch %[outpos], all, mul #1\n" - "zip2 z9.d, z0.d, z4.d\n" - "zip1 z10.d, z1.d, z5.d\n" - "zip2 z11.d, z1.d, z5.d\n" - "whilelt p1.h, %[outpos], %[outwidth]\n" - "zip2 z13.d, z2.d, z6.d\n" - "inch %[outpos], all, mul #1\n" - "zip1 z14.d, z3.d, z6.d\n" - "zip2 z15.d, z3.d, z6.d\n" - "zip1 z0.d, z8.d, z12.d\n" - "whilelt p2.h, %[outpos], %[outwidth]\n" - "zip2 z1.d, z8.d, z12.d\n" - "inch %[outpos], all, mul #1\n" - "zip1 z2.d, z9.d, z13.d\n" - "zip2 z3.d, z9.d, z13.d\n" - "zip1 z4.d, z10.d, z14.d\n" - "whilelt p3.h, %[outpos], %[outwidth]\n" - "zip2 z5.d, z10.d, z14.d\n" - "inch %[outpos], all, mul #1\n" - "zip1 z6.d, z11.d, z15.d\n" - "zip2 z7.d, z11.d, z15.d\n" - "zip1 z8.d, z0.d, z4.d\n" - "whilelt p4.h, %[outpos], %[outwidth]\n" - "zip2 z9.d, z0.d, z4.d\n" - "inch %[outpos], all, mul #1\n" - "zip1 z10.d, z1.d, z5.d\n" - "st1h z8.h, p0, [%[outptr]]\n" - "zip2 z11.d, z1.d, z5.d\n" - "zip1 z12.d, z2.d, z6.d\n" - "whilelt p5.h, %[outpos], %[outwidth]\n" - "zip2 z13.d, z2.d, z6.d\n" - "st1h z9.h, p1, [%[outptr], #1, MUL VL]\n" - "zip1 z14.d, z3.d, z7.d\n" - "inch %[outpos], all, mul #1\n" - "zip2 z15.d, z3.d, z7.d\n" - "st1h z10.h, p2, [%[outptr], #2, MUL VL]\n" - "whilelt p6.h, %[outpos], %[outwidth]\n" - "st1h z11.h, p3, [%[outptr], #3, MUL VL]\n" - "inch %[outpos], all, mul #1\n" - "st1h z12.h, p4, [%[outptr], #4, MUL VL]\n" - "whilelt p7.h, %[outpos], %[outwidth]\n" - "inch %[outpos], all, mul #1\n" - "st1h z13.h, p5, [%[outptr], #5, MUL VL]\n" - "st1h z14.h, p6, [%[outptr], #6, MUL VL]\n" - "st1h z15.h, p7, [%[outptr], #7, MUL VL]\n" - "addvl %[outptr], %[outptr], #8\n" - "b 1b\n" - "2:\n" - : [inpos] "+r" (inpos), [outpos] "+r" (outpos), [outptr] "+r" (outptr), [inptr0] "+r" (inptr0), [inptr1] "+r" (inptr1), [inptr2] "+r" (inptr2), [inptr3] "+r" (inptr3), [inptr4] "+r" (inptr4), [inptr5] "+r" (inptr5) - : [outwidth] "r" (outwidth), [inwidth] "r" (inwidth) - : "p0", "p1", "p2", "p3", "p4", "p5", "p6", "p7", "z0", "z1", "z2", "z3", "z4", "z5", "z6", "z7", "z8", "z9", "z10", "z11", "z12", "z13", "z14", "z15", "cc", "memory" - ); - break; - - case 7: - __asm __volatile( - "1:\n" - "whilelt p0.h, %[inpos], %[inwidth]\n" - "b.none 2f\n" - "mov z7.h, #0\n" - "ld1h z0.h, p0/z, [%[inptr0], %[inpos], LSL #1]\n" - "ld1h z1.h, p0/z, [%[inptr1], %[inpos], LSL #1]\n" - "ld1h z2.h, p0/z, [%[inptr2], %[inpos], LSL #1]\n" - "ld1h z3.h, p0/z, [%[inptr3], %[inpos], LSL #1]\n" - "ld1h z4.h, p0/z, [%[inptr4], %[inpos], LSL #1]\n" - "ld1h z5.h, p0/z, [%[inptr5], %[inpos], LSL #1]\n" - "ld1h z6.h, p0/z, [%[inptr6], %[inpos], LSL #1]\n" - "inch %[inpos], all, mul #1\n" - "zip1 z8.d, z0.d, z4.d\n" - "whilelt p0.h, %[outpos], %[outwidth]\n" - "zip2 z9.d, z0.d, z4.d\n" - "inch %[outpos], all, mul #1\n" - "zip1 z10.d, z1.d, z5.d\n" - "zip2 z11.d, z1.d, z5.d\n" - "zip1 z12.d, z2.d, z6.d\n" - "whilelt p1.h, %[outpos], %[outwidth]\n" - "zip2 z13.d, z2.d, z6.d\n" - "inch %[outpos], all, mul #1\n" - "zip1 z14.d, z3.d, z7.d\n" - "zip2 z15.d, z3.d, z7.d\n" - "zip1 z0.d, z8.d, z12.d\n" - "whilelt p2.h, %[outpos], %[outwidth]\n" - "zip2 z1.d, z8.d, z12.d\n" - "inch %[outpos], all, mul #1\n" - "zip1 z2.d, z9.d, z13.d\n" - "zip2 z3.d, z9.d, z13.d\n" - "zip1 z4.d, z10.d, z14.d\n" - "whilelt p3.h, %[outpos], %[outwidth]\n" - "zip2 z5.d, z10.d, z14.d\n" - "inch %[outpos], all, mul #1\n" - "zip1 z6.d, z11.d, z15.d\n" - "zip2 z7.d, z11.d, z15.d\n" - "zip1 z8.d, z0.d, z4.d\n" - "whilelt p4.h, %[outpos], %[outwidth]\n" - "zip2 z9.d, z0.d, z4.d\n" - "inch %[outpos], all, mul #1\n" - "zip1 z10.d, z1.d, z5.d\n" - "st1h z8.h, p0, [%[outptr]]\n" - "zip2 z11.d, z1.d, z5.d\n" - "zip1 z12.d, z2.d, z6.d\n" - "whilelt p5.h, %[outpos], %[outwidth]\n" - "zip2 z13.d, z2.d, z6.d\n" - "st1h z9.h, p1, [%[outptr], #1, MUL VL]\n" - "zip1 z14.d, z3.d, z7.d\n" - "inch %[outpos], all, mul #1\n" - "zip2 z15.d, z3.d, z7.d\n" - "st1h z10.h, p2, [%[outptr], #2, MUL VL]\n" - "whilelt p6.h, %[outpos], %[outwidth]\n" - "st1h z11.h, p3, [%[outptr], #3, MUL VL]\n" - "inch %[outpos], all, mul #1\n" - "st1h z12.h, p4, [%[outptr], #4, MUL VL]\n" - "whilelt p7.h, %[outpos], %[outwidth]\n" - "inch %[outpos], all, mul #1\n" - "st1h z13.h, p5, [%[outptr], #5, MUL VL]\n" - "st1h z14.h, p6, [%[outptr], #6, MUL VL]\n" - "st1h z15.h, p7, [%[outptr], #7, MUL VL]\n" - "addvl %[outptr], %[outptr], #8\n" - "b 1b\n" - "2:\n" - : [inpos] "+r" (inpos), [outpos] "+r" (outpos), [outptr] "+r" (outptr), [inptr0] "+r" (inptr0), [inptr1] "+r" (inptr1), [inptr2] "+r" (inptr2), [inptr3] "+r" (inptr3), [inptr4] "+r" (inptr4), [inptr5] "+r" (inptr5), [inptr6] "+r" (inptr6) - : [outwidth] "r" (outwidth), [inwidth] "r" (inwidth) - : "p0", "p1", "p2", "p3", "p4", "p5", "p6", "p7", "z0", "z1", "z2", "z3", "z4", "z5", "z6", "z7", "z8", "z9", "z10", "z11", "z12", "z13", "z14", "z15", "cc", "memory" - ); - break; - - default: - case 8: - __asm __volatile( - "1:\n" - "whilelt p0.h, %[inpos], %[inwidth]\n" - "b.none 2f\n" - "ld1h z0.h, p0/z, [%[inptr0], %[inpos], LSL #1]\n" - "ld1h z1.h, p0/z, [%[inptr1], %[inpos], LSL #1]\n" - "ld1h z2.h, p0/z, [%[inptr2], %[inpos], LSL #1]\n" - "ld1h z3.h, p0/z, [%[inptr3], %[inpos], LSL #1]\n" - "ld1h z4.h, p0/z, [%[inptr4], %[inpos], LSL #1]\n" - "ld1h z5.h, p0/z, [%[inptr5], %[inpos], LSL #1]\n" - "ld1h z6.h, p0/z, [%[inptr6], %[inpos], LSL #1]\n" - "ld1h z7.h, p0/z, [%[inptr7], %[inpos], LSL #1]\n" - "inch %[inpos], all, mul #1\n" - "zip1 z8.d, z0.d, z4.d\n" - "whilelt p0.h, %[outpos], %[outwidth]\n" - "zip2 z9.d, z0.d, z4.d\n" - "inch %[outpos], all, mul #1\n" - "zip1 z10.d, z1.d, z5.d\n" - "zip2 z11.d, z1.d, z5.d\n" - "zip1 z12.d, z2.d, z6.d\n" - "whilelt p1.h, %[outpos], %[outwidth]\n" - "zip2 z13.d, z2.d, z6.d\n" - "inch %[outpos], all, mul #1\n" - "zip1 z14.d, z3.d, z7.d\n" - "zip2 z15.d, z3.d, z7.d\n" - "zip1 z0.d, z8.d, z12.d\n" - "whilelt p2.h, %[outpos], %[outwidth]\n" - "zip2 z1.d, z8.d, z12.d\n" - "inch %[outpos], all, mul #1\n" - "zip1 z2.d, z9.d, z13.d\n" - "zip2 z3.d, z9.d, z13.d\n" - "zip1 z4.d, z10.d, z14.d\n" - "whilelt p3.h, %[outpos], %[outwidth]\n" - "zip2 z5.d, z10.d, z14.d\n" - "inch %[outpos], all, mul #1\n" - "zip1 z6.d, z11.d, z15.d\n" - "zip2 z7.d, z11.d, z15.d\n" - "zip1 z8.d, z0.d, z4.d\n" - "whilelt p4.h, %[outpos], %[outwidth]\n" - "zip2 z9.d, z0.d, z4.d\n" - "inch %[outpos], all, mul #1\n" - "zip1 z10.d, z1.d, z5.d\n" - "st1h z8.h, p0, [%[outptr]]\n" - "zip2 z11.d, z1.d, z5.d\n" - "zip1 z12.d, z2.d, z6.d\n" - "whilelt p5.h, %[outpos], %[outwidth]\n" - "zip2 z13.d, z2.d, z6.d\n" - "st1h z9.h, p1, [%[outptr], #1, MUL VL]\n" - "zip1 z14.d, z3.d, z7.d\n" - "inch %[outpos], all, mul #1\n" - "zip2 z15.d, z3.d, z7.d\n" - "st1h z10.h, p2, [%[outptr], #2, MUL VL]\n" - "whilelt p6.h, %[outpos], %[outwidth]\n" - "st1h z11.h, p3, [%[outptr], #3, MUL VL]\n" - "inch %[outpos], all, mul #1\n" - "st1h z12.h, p4, [%[outptr], #4, MUL VL]\n" - "whilelt p7.h, %[outpos], %[outwidth]\n" - "inch %[outpos], all, mul #1\n" - "st1h z13.h, p5, [%[outptr], #5, MUL VL]\n" - "st1h z14.h, p6, [%[outptr], #6, MUL VL]\n" - "st1h z15.h, p7, [%[outptr], #7, MUL VL]\n" - "addvl %[outptr], %[outptr], #8\n" - "b 1b\n" - "2:\n" - : [inpos] "+r" (inpos), [outpos] "+r" (outpos), [outptr] "+r" (outptr), [inptr0] "+r" (inptr0), [inptr1] "+r" (inptr1), [inptr2] "+r" (inptr2), [inptr3] "+r" (inptr3), [inptr4] "+r" (inptr4), [inptr5] "+r" (inptr5), [inptr6] "+r" (inptr6), [inptr7] "+r" (inptr7) - : [outwidth] "r" (outwidth), [inwidth] "r" (inwidth) - : "p0", "p1", "p2", "p3", "p4", "p5", "p6", "p7", "z0", "z1", "z2", "z3", "z4", "z5", "z6", "z7", "z8", "z9", "z10", "z11", "z12", "z13", "z14", "z15", "cc", "memory" - ); - break; - - - } - } -} - -#endif // __ARM_FEATURE_SVE diff --git a/src/core/NEON/kernels/arm_gemm/transforms/sve_interleave_8way_block4_8bit.hpp b/src/core/NEON/kernels/arm_gemm/transforms/sve_interleave_8way_block4_8bit.hpp deleted file mode 100644 index a96a43cbeb..0000000000 --- a/src/core/NEON/kernels/arm_gemm/transforms/sve_interleave_8way_block4_8bit.hpp +++ /dev/null @@ -1,596 +0,0 @@ -/* - * Copyright (c) 2018 - 2019 Arm Limited. - * - * SPDX-License-Identifier: MIT - * - * Permission is hereby granted, free of charge, to any person obtaining a copy - * of this software and associated documentation files (the "Software"), to - * deal in the Software without restriction, including without limitation the - * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or - * sell copies of the Software, and to permit persons to whom the Software is - * furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in all - * copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE - * SOFTWARE. - */ -#pragma once - -#ifdef __ARM_FEATURE_SVE - -template<> -template<typename T> -inline void TransformImpl<8, 4, false, 1, 1, false>::Transform(T *out, const T *in, int ldin, int y0, int ymax, int k0, int kmax) -{ - uint8_t *master_outptr = reinterpret_cast<uint8_t *>(out); - const uint8_t *inptr = reinterpret_cast<const uint8_t *>(in); - - for (int y=y0; y<ymax; y+=8) - { - const int height = ymax-y; - const long inwidth = (kmax - k0); - const long outwidth = ((inwidth + 3) / 4) * 32; - long inpos = 0; - long outpos = 0; - - uint8_t *outptr = master_outptr; - master_outptr += outwidth; - - const uint8_t *inptr0 = inptr + y * ldin + k0; - const uint8_t *inptr1 = inptr0 + ldin; - const uint8_t *inptr2 = inptr1 + ldin; - const uint8_t *inptr3 = inptr2 + ldin; - const uint8_t *inptr4 = inptr3 + ldin; - const uint8_t *inptr5 = inptr4 + ldin; - const uint8_t *inptr6 = inptr5 + ldin; - const uint8_t *inptr7 = inptr6 + ldin; - - switch(height) - { - case 1: - __asm __volatile( - "1:\n" - "whilelt p0.b, %[inpos], %[inwidth]\n" - "b.none 2f\n" - "mov z4.b, #0\n" - "ld1b z0.b, p0/z, [%[inptr0], %[inpos]]\n" - "incb %[inpos], all, mul #1\n" - "whilelt p0.b, %[outpos], %[outwidth]\n" - "incb %[outpos], all, mul #1\n" - "zip1 z8.s, z0.s, z4.s\n" - "zip2 z9.s, z0.s, z4.s\n" - "whilelt p1.b, %[outpos], %[outwidth]\n" - "zip1 z0.s, z8.s, z4.s\n" - "incb %[outpos], all, mul #1\n" - "zip2 z1.s, z8.s, z4.s\n" - "zip1 z2.s, z9.s, z4.s\n" - "zip2 z3.s, z9.s, z4.s\n" - "whilelt p2.b, %[outpos], %[outwidth]\n" - "zip1 z8.s, z0.s, z4.s\n" - "incb %[outpos], all, mul #1\n" - "zip2 z9.s, z0.s, z4.s\n" - "zip1 z10.s, z1.s, z4.s\n" - "st1b z8.b, p0, [%[outptr]]\n" - "zip2 z11.s, z1.s, z4.s\n" - "whilelt p3.b, %[outpos], %[outwidth]\n" - "zip1 z12.s, z2.s, z4.s\n" - "st1b z9.b, p1, [%[outptr], #1, MUL VL]\n" - "zip2 z13.s, z2.s, z4.s\n" - "incb %[outpos], all, mul #1\n" - "zip1 z14.s, z3.s, z4.s\n" - "st1b z10.b, p2, [%[outptr], #2, MUL VL]\n" - "zip2 z15.s, z3.s, z4.s\n" - "whilelt p4.b, %[outpos], %[outwidth]\n" - "st1b z11.b, p3, [%[outptr], #3, MUL VL]\n" - "incb %[outpos], all, mul #1\n" - "st1b z12.b, p4, [%[outptr], #4, MUL VL]\n" - "whilelt p5.b, %[outpos], %[outwidth]\n" - "incb %[outpos], all, mul #1\n" - "st1b z13.b, p5, [%[outptr], #5, MUL VL]\n" - "whilelt p6.b, %[outpos], %[outwidth]\n" - "incb %[outpos], all, mul #1\n" - "st1b z14.b, p6, [%[outptr], #6, MUL VL]\n" - "whilelt p7.b, %[outpos], %[outwidth]\n" - "incb %[outpos], all, mul #1\n" - "st1b z15.b, p7, [%[outptr], #7, MUL VL]\n" - "addvl %[outptr], %[outptr], #8\n" - "b 1b\n" - "2:\n" - : [inpos] "+r" (inpos), [outpos] "+r" (outpos), [outptr] "+r" (outptr), [inptr0] "+r" (inptr0) - : [outwidth] "r" (outwidth), [inwidth] "r" (inwidth) - : "p0", "p1", "p2", "p3", "p4", "p5", "p6", "p7", "z0", "z1", "z2", "z3", "z4", "z5", "z6", "z7", "z8", "z9", "z10", "z11", "z12", "z13", "z14", "z15", "cc", "memory" - ); - break; - - case 2: - __asm __volatile( - "1:\n" - "whilelt p0.b, %[inpos], %[inwidth]\n" - "b.none 2f\n" - "mov z4.b, #0\n" - "mov z14.b, #0\n" - "ld1b z0.b, p0/z, [%[inptr0], %[inpos]]\n" - "ld1b z1.b, p0/z, [%[inptr1], %[inpos]]\n" - "incb %[inpos], all, mul #1\n" - "whilelt p0.b, %[outpos], %[outwidth]\n" - "zip1 z8.s, z0.s, z4.s\n" - "incb %[outpos], all, mul #1\n" - "zip2 z9.s, z0.s, z4.s\n" - "zip1 z10.s, z1.s, z4.s\n" - "zip2 z11.s, z1.s, z4.s\n" - "whilelt p1.b, %[outpos], %[outwidth]\n" - "zip1 z0.s, z8.s, z4.s\n" - "incb %[outpos], all, mul #1\n" - "zip2 z1.s, z8.s, z4.s\n" - "zip1 z2.s, z9.s, z4.s\n" - "zip2 z3.s, z9.s, z4.s\n" - "whilelt p2.b, %[outpos], %[outwidth]\n" - "zip1 z4.s, z10.s, z14.s\n" - "incb %[outpos], all, mul #1\n" - "zip2 z5.s, z10.s, z14.s\n" - "zip1 z6.s, z11.s, z14.s\n" - "zip2 z7.s, z11.s, z14.s\n" - "whilelt p3.b, %[outpos], %[outwidth]\n" - "zip1 z8.s, z0.s, z4.s\n" - "incb %[outpos], all, mul #1\n" - "zip2 z9.s, z0.s, z4.s\n" - "zip1 z10.s, z1.s, z5.s\n" - "st1b z8.b, p0, [%[outptr]]\n" - "zip2 z11.s, z1.s, z5.s\n" - "whilelt p4.b, %[outpos], %[outwidth]\n" - "zip1 z12.s, z2.s, z6.s\n" - "st1b z9.b, p1, [%[outptr], #1, MUL VL]\n" - "zip2 z13.s, z2.s, z6.s\n" - "incb %[outpos], all, mul #1\n" - "zip1 z14.s, z3.s, z7.s\n" - "st1b z10.b, p2, [%[outptr], #2, MUL VL]\n" - "zip2 z15.s, z3.s, z7.s\n" - "whilelt p5.b, %[outpos], %[outwidth]\n" - "st1b z11.b, p3, [%[outptr], #3, MUL VL]\n" - "incb %[outpos], all, mul #1\n" - "st1b z12.b, p4, [%[outptr], #4, MUL VL]\n" - "whilelt p6.b, %[outpos], %[outwidth]\n" - "incb %[outpos], all, mul #1\n" - "st1b z13.b, p5, [%[outptr], #5, MUL VL]\n" - "whilelt p7.b, %[outpos], %[outwidth]\n" - "incb %[outpos], all, mul #1\n" - "st1b z14.b, p6, [%[outptr], #6, MUL VL]\n" - "st1b z15.b, p7, [%[outptr], #7, MUL VL]\n" - "addvl %[outptr], %[outptr], #8\n" - "b 1b\n" - "2:\n" - : [inpos] "+r" (inpos), [outpos] "+r" (outpos), [outptr] "+r" (outptr), [inptr0] "+r" (inptr0), [inptr1] "+r" (inptr1) - : [outwidth] "r" (outwidth), [inwidth] "r" (inwidth) - : "p0", "p1", "p2", "p3", "p4", "p5", "p6", "p7", "z0", "z1", "z2", "z3", "z4", "z5", "z6", "z7", "z8", "z9", "z10", "z11", "z12", "z13", "z14", "z15", "cc", "memory" - ); - break; - - case 3: - __asm __volatile( - "1:\n" - "whilelt p0.b, %[inpos], %[inwidth]\n" - "b.none 2f\n" - "mov z4.b, #0\n" - "mov z14.b, #0\n" - "ld1b z0.b, p0/z, [%[inptr0], %[inpos]]\n" - "ld1b z1.b, p0/z, [%[inptr1], %[inpos]]\n" - "ld1b z2.b, p0/z, [%[inptr2], %[inpos]]\n" - "incb %[inpos], all, mul #1\n" - "zip1 z8.s, z0.s, z4.s\n" - "whilelt p0.b, %[outpos], %[outwidth]\n" - "zip2 z9.s, z0.s, z4.s\n" - "incb %[outpos], all, mul #1\n" - "zip1 z10.s, z1.s, z4.s\n" - "zip2 z11.s, z1.s, z4.s\n" - "zip1 z12.s, z2.s, z4.s\n" - "whilelt p1.b, %[outpos], %[outwidth]\n" - "zip2 z13.s, z2.s, z4.s\n" - "incb %[outpos], all, mul #1\n" - "zip1 z0.s, z8.s, z12.s\n" - "zip2 z1.s, z8.s, z12.s\n" - "zip1 z2.s, z9.s, z13.s\n" - "whilelt p2.b, %[outpos], %[outwidth]\n" - "zip2 z3.s, z9.s, z13.s\n" - "incb %[outpos], all, mul #1\n" - "zip1 z4.s, z10.s, z14.s\n" - "zip2 z5.s, z10.s, z14.s\n" - "zip1 z6.s, z11.s, z14.s\n" - "whilelt p3.b, %[outpos], %[outwidth]\n" - "zip2 z7.s, z11.s, z14.s\n" - "incb %[outpos], all, mul #1\n" - "zip1 z8.s, z0.s, z4.s\n" - "zip2 z9.s, z0.s, z4.s\n" - "zip1 z10.s, z1.s, z5.s\n" - "whilelt p4.b, %[outpos], %[outwidth]\n" - "zip2 z11.s, z1.s, z5.s\n" - "st1b z8.b, p0, [%[outptr]]\n" - "zip1 z12.s, z2.s, z6.s\n" - "incb %[outpos], all, mul #1\n" - "zip2 z13.s, z2.s, z6.s\n" - "st1b z9.b, p1, [%[outptr], #1, MUL VL]\n" - "zip1 z14.s, z3.s, z7.s\n" - "zip2 z15.s, z3.s, z7.s\n" - "whilelt p5.b, %[outpos], %[outwidth]\n" - "st1b z10.b, p2, [%[outptr], #2, MUL VL]\n" - "incb %[outpos], all, mul #1\n" - "st1b z11.b, p3, [%[outptr], #3, MUL VL]\n" - "whilelt p6.b, %[outpos], %[outwidth]\n" - "incb %[outpos], all, mul #1\n" - "st1b z12.b, p4, [%[outptr], #4, MUL VL]\n" - "whilelt p7.b, %[outpos], %[outwidth]\n" - "incb %[outpos], all, mul #1\n" - "st1b z13.b, p5, [%[outptr], #5, MUL VL]\n" - "st1b z14.b, p6, [%[outptr], #6, MUL VL]\n" - "st1b z15.b, p7, [%[outptr], #7, MUL VL]\n" - "addvl %[outptr], %[outptr], #8\n" - "b 1b\n" - "2:\n" - : [inpos] "+r" (inpos), [outpos] "+r" (outpos), [outptr] "+r" (outptr), [inptr0] "+r" (inptr0), [inptr1] "+r" (inptr1), [inptr2] "+r" (inptr2) - : [outwidth] "r" (outwidth), [inwidth] "r" (inwidth) - : "p0", "p1", "p2", "p3", "p4", "p5", "p6", "p7", "z0", "z1", "z2", "z3", "z4", "z5", "z6", "z7", "z8", "z9", "z10", "z11", "z12", "z13", "z14", "z15", "cc", "memory" - ); - break; - - case 4: - __asm __volatile( - "1:\n" - "whilelt p0.b, %[inpos], %[inwidth]\n" - "b.none 2f\n" - "mov z4.b, #0\n" - "ld1b z0.b, p0/z, [%[inptr0], %[inpos]]\n" - "ld1b z1.b, p0/z, [%[inptr1], %[inpos]]\n" - "ld1b z2.b, p0/z, [%[inptr2], %[inpos]]\n" - "ld1b z3.b, p0/z, [%[inptr3], %[inpos]]\n" - "incb %[inpos], all, mul #1\n" - "zip1 z8.s, z0.s, z4.s\n" - "whilelt p0.b, %[outpos], %[outwidth]\n" - "zip2 z9.s, z0.s, z4.s\n" - "incb %[outpos], all, mul #1\n" - "zip1 z10.s, z1.s, z4.s\n" - "zip2 z11.s, z1.s, z4.s\n" - "zip1 z12.s, z2.s, z4.s\n" - "whilelt p1.b, %[outpos], %[outwidth]\n" - "zip2 z13.s, z2.s, z4.s\n" - "incb %[outpos], all, mul #1\n" - "zip1 z14.s, z3.s, z4.s\n" - "zip2 z15.s, z3.s, z4.s\n" - "zip1 z0.s, z8.s, z12.s\n" - "whilelt p2.b, %[outpos], %[outwidth]\n" - "zip2 z1.s, z8.s, z12.s\n" - "incb %[outpos], all, mul #1\n" - "zip1 z2.s, z9.s, z13.s\n" - "zip2 z3.s, z9.s, z13.s\n" - "zip1 z4.s, z10.s, z14.s\n" - "whilelt p3.b, %[outpos], %[outwidth]\n" - "zip2 z5.s, z10.s, z14.s\n" - "incb %[outpos], all, mul #1\n" - "zip1 z6.s, z11.s, z15.s\n" - "zip2 z7.s, z11.s, z15.s\n" - "zip1 z8.s, z0.s, z4.s\n" - "whilelt p4.b, %[outpos], %[outwidth]\n" - "zip2 z9.s, z0.s, z4.s\n" - "incb %[outpos], all, mul #1\n" - "zip1 z10.s, z1.s, z5.s\n" - "st1b z8.b, p0, [%[outptr]]\n" - "zip2 z11.s, z1.s, z5.s\n" - "zip1 z12.s, z2.s, z6.s\n" - "whilelt p5.b, %[outpos], %[outwidth]\n" - "zip2 z13.s, z2.s, z6.s\n" - "st1b z9.b, p1, [%[outptr], #1, MUL VL]\n" - "zip1 z14.s, z3.s, z7.s\n" - "incb %[outpos], all, mul #1\n" - "zip2 z15.s, z3.s, z7.s\n" - "st1b z10.b, p2, [%[outptr], #2, MUL VL]\n" - "whilelt p6.b, %[outpos], %[outwidth]\n" - "st1b z11.b, p3, [%[outptr], #3, MUL VL]\n" - "incb %[outpos], all, mul #1\n" - "st1b z12.b, p4, [%[outptr], #4, MUL VL]\n" - "whilelt p7.b, %[outpos], %[outwidth]\n" - "incb %[outpos], all, mul #1\n" - "st1b z13.b, p5, [%[outptr], #5, MUL VL]\n" - "st1b z14.b, p6, [%[outptr], #6, MUL VL]\n" - "st1b z15.b, p7, [%[outptr], #7, MUL VL]\n" - "addvl %[outptr], %[outptr], #8\n" - "b 1b\n" - "2:\n" - : [inpos] "+r" (inpos), [outpos] "+r" (outpos), [outptr] "+r" (outptr), [inptr0] "+r" (inptr0), [inptr1] "+r" (inptr1), [inptr2] "+r" (inptr2), [inptr3] "+r" (inptr3) - : [outwidth] "r" (outwidth), [inwidth] "r" (inwidth) - : "p0", "p1", "p2", "p3", "p4", "p5", "p6", "p7", "z0", "z1", "z2", "z3", "z4", "z5", "z6", "z7", "z8", "z9", "z10", "z11", "z12", "z13", "z14", "z15", "cc", "memory" - ); - break; - - case 5: - __asm __volatile( - "1:\n" - "whilelt p0.b, %[inpos], %[inwidth]\n" - "b.none 2f\n" - "mov z5.b, #0\n" - "ld1b z0.b, p0/z, [%[inptr0], %[inpos]]\n" - "ld1b z1.b, p0/z, [%[inptr1], %[inpos]]\n" - "ld1b z2.b, p0/z, [%[inptr2], %[inpos]]\n" - "ld1b z3.b, p0/z, [%[inptr3], %[inpos]]\n" - "ld1b z4.b, p0/z, [%[inptr4], %[inpos]]\n" - "incb %[inpos], all, mul #1\n" - "zip1 z10.s, z1.s, z5.s\n" - "whilelt p0.b, %[outpos], %[outwidth]\n" - "zip1 z8.s, z0.s, z4.s\n" - "incb %[outpos], all, mul #1\n" - "zip2 z9.s, z0.s, z4.s\n" - "zip2 z11.s, z1.s, z5.s\n" - "zip1 z12.s, z2.s, z5.s\n" - "whilelt p1.b, %[outpos], %[outwidth]\n" - "zip2 z13.s, z2.s, z5.s\n" - "incb %[outpos], all, mul #1\n" - "zip1 z14.s, z3.s, z5.s\n" - "zip2 z15.s, z3.s, z5.s\n" - "zip1 z0.s, z8.s, z12.s\n" - "whilelt p2.b, %[outpos], %[outwidth]\n" - "zip2 z1.s, z8.s, z12.s\n" - "incb %[outpos], all, mul #1\n" - "zip1 z2.s, z9.s, z13.s\n" - "zip2 z3.s, z9.s, z13.s\n" - "zip1 z4.s, z10.s, z14.s\n" - "whilelt p3.b, %[outpos], %[outwidth]\n" - "zip2 z5.s, z10.s, z14.s\n" - "incb %[outpos], all, mul #1\n" - "zip1 z6.s, z11.s, z15.s\n" - "zip2 z7.s, z11.s, z15.s\n" - "zip1 z8.s, z0.s, z4.s\n" - "whilelt p4.b, %[outpos], %[outwidth]\n" - "zip2 z9.s, z0.s, z4.s\n" - "incb %[outpos], all, mul #1\n" - "zip1 z10.s, z1.s, z5.s\n" - "st1b z8.b, p0, [%[outptr]]\n" - "zip2 z11.s, z1.s, z5.s\n" - "zip1 z12.s, z2.s, z6.s\n" - "whilelt p5.b, %[outpos], %[outwidth]\n" - "zip2 z13.s, z2.s, z6.s\n" - "st1b z9.b, p1, [%[outptr], #1, MUL VL]\n" - "zip1 z14.s, z3.s, z7.s\n" - "incb %[outpos], all, mul #1\n" - "zip2 z15.s, z3.s, z7.s\n" - "st1b z10.b, p2, [%[outptr], #2, MUL VL]\n" - "whilelt p6.b, %[outpos], %[outwidth]\n" - "st1b z11.b, p3, [%[outptr], #3, MUL VL]\n" - "incb %[outpos], all, mul #1\n" - "st1b z12.b, p4, [%[outptr], #4, MUL VL]\n" - "whilelt p7.b, %[outpos], %[outwidth]\n" - "incb %[outpos], all, mul #1\n" - "st1b z13.b, p5, [%[outptr], #5, MUL VL]\n" - "st1b z14.b, p6, [%[outptr], #6, MUL VL]\n" - "st1b z15.b, p7, [%[outptr], #7, MUL VL]\n" - "addvl %[outptr], %[outptr], #8\n" - "b 1b\n" - "2:\n" - : [inpos] "+r" (inpos), [outpos] "+r" (outpos), [outptr] "+r" (outptr), [inptr0] "+r" (inptr0), [inptr1] "+r" (inptr1), [inptr2] "+r" (inptr2), [inptr3] "+r" (inptr3), [inptr4] "+r" (inptr4) - : [outwidth] "r" (outwidth), [inwidth] "r" (inwidth) - : "p0", "p1", "p2", "p3", "p4", "p5", "p6", "p7", "z0", "z1", "z2", "z3", "z4", "z5", "z6", "z7", "z8", "z9", "z10", "z11", "z12", "z13", "z14", "z15", "cc", "memory" - ); - break; - - case 6: - __asm __volatile( - "1:\n" - "whilelt p0.b, %[inpos], %[inwidth]\n" - "b.none 2f\n" - "mov z6.b, #0\n" - "ld1b z0.b, p0/z, [%[inptr0], %[inpos]]\n" - "ld1b z1.b, p0/z, [%[inptr1], %[inpos]]\n" - "ld1b z2.b, p0/z, [%[inptr2], %[inpos]]\n" - "ld1b z3.b, p0/z, [%[inptr3], %[inpos]]\n" - "ld1b z4.b, p0/z, [%[inptr4], %[inpos]]\n" - "ld1b z5.b, p0/z, [%[inptr5], %[inpos]]\n" - "incb %[inpos], all, mul #1\n" - "zip1 z12.s, z2.s, z6.s\n" - "whilelt p0.b, %[outpos], %[outwidth]\n" - "zip1 z8.s, z0.s, z4.s\n" - "incb %[outpos], all, mul #1\n" - "zip2 z9.s, z0.s, z4.s\n" - "zip1 z10.s, z1.s, z5.s\n" - "zip2 z11.s, z1.s, z5.s\n" - "whilelt p1.b, %[outpos], %[outwidth]\n" - "zip2 z13.s, z2.s, z6.s\n" - "incb %[outpos], all, mul #1\n" - "zip1 z14.s, z3.s, z6.s\n" - "zip2 z15.s, z3.s, z6.s\n" - "zip1 z0.s, z8.s, z12.s\n" - "whilelt p2.b, %[outpos], %[outwidth]\n" - "zip2 z1.s, z8.s, z12.s\n" - "incb %[outpos], all, mul #1\n" - "zip1 z2.s, z9.s, z13.s\n" - "zip2 z3.s, z9.s, z13.s\n" - "zip1 z4.s, z10.s, z14.s\n" - "whilelt p3.b, %[outpos], %[outwidth]\n" - "zip2 z5.s, z10.s, z14.s\n" - "incb %[outpos], all, mul #1\n" - "zip1 z6.s, z11.s, z15.s\n" - "zip2 z7.s, z11.s, z15.s\n" - "zip1 z8.s, z0.s, z4.s\n" - "whilelt p4.b, %[outpos], %[outwidth]\n" - "zip2 z9.s, z0.s, z4.s\n" - "incb %[outpos], all, mul #1\n" - "zip1 z10.s, z1.s, z5.s\n" - "st1b z8.b, p0, [%[outptr]]\n" - "zip2 z11.s, z1.s, z5.s\n" - "zip1 z12.s, z2.s, z6.s\n" - "whilelt p5.b, %[outpos], %[outwidth]\n" - "zip2 z13.s, z2.s, z6.s\n" - "st1b z9.b, p1, [%[outptr], #1, MUL VL]\n" - "zip1 z14.s, z3.s, z7.s\n" - "incb %[outpos], all, mul #1\n" - "zip2 z15.s, z3.s, z7.s\n" - "st1b z10.b, p2, [%[outptr], #2, MUL VL]\n" - "whilelt p6.b, %[outpos], %[outwidth]\n" - "st1b z11.b, p3, [%[outptr], #3, MUL VL]\n" - "incb %[outpos], all, mul #1\n" - "st1b z12.b, p4, [%[outptr], #4, MUL VL]\n" - "whilelt p7.b, %[outpos], %[outwidth]\n" - "incb %[outpos], all, mul #1\n" - "st1b z13.b, p5, [%[outptr], #5, MUL VL]\n" - "st1b z14.b, p6, [%[outptr], #6, MUL VL]\n" - "st1b z15.b, p7, [%[outptr], #7, MUL VL]\n" - "addvl %[outptr], %[outptr], #8\n" - "b 1b\n" - "2:\n" - : [inpos] "+r" (inpos), [outpos] "+r" (outpos), [outptr] "+r" (outptr), [inptr0] "+r" (inptr0), [inptr1] "+r" (inptr1), [inptr2] "+r" (inptr2), [inptr3] "+r" (inptr3), [inptr4] "+r" (inptr4), [inptr5] "+r" (inptr5) - : [outwidth] "r" (outwidth), [inwidth] "r" (inwidth) - : "p0", "p1", "p2", "p3", "p4", "p5", "p6", "p7", "z0", "z1", "z2", "z3", "z4", "z5", "z6", "z7", "z8", "z9", "z10", "z11", "z12", "z13", "z14", "z15", "cc", "memory" - ); - break; - - case 7: - __asm __volatile( - "1:\n" - "whilelt p0.b, %[inpos], %[inwidth]\n" - "b.none 2f\n" - "mov z7.b, #0\n" - "ld1b z0.b, p0/z, [%[inptr0], %[inpos]]\n" - "ld1b z1.b, p0/z, [%[inptr1], %[inpos]]\n" - "ld1b z2.b, p0/z, [%[inptr2], %[inpos]]\n" - "ld1b z3.b, p0/z, [%[inptr3], %[inpos]]\n" - "ld1b z4.b, p0/z, [%[inptr4], %[inpos]]\n" - "ld1b z5.b, p0/z, [%[inptr5], %[inpos]]\n" - "ld1b z6.b, p0/z, [%[inptr6], %[inpos]]\n" - "incb %[inpos], all, mul #1\n" - "zip1 z8.s, z0.s, z4.s\n" - "whilelt p0.b, %[outpos], %[outwidth]\n" - "zip2 z9.s, z0.s, z4.s\n" - "incb %[outpos], all, mul #1\n" - "zip1 z10.s, z1.s, z5.s\n" - "zip2 z11.s, z1.s, z5.s\n" - "zip1 z12.s, z2.s, z6.s\n" - "whilelt p1.b, %[outpos], %[outwidth]\n" - "zip2 z13.s, z2.s, z6.s\n" - "incb %[outpos], all, mul #1\n" - "zip1 z14.s, z3.s, z7.s\n" - "zip2 z15.s, z3.s, z7.s\n" - "zip1 z0.s, z8.s, z12.s\n" - "whilelt p2.b, %[outpos], %[outwidth]\n" - "zip2 z1.s, z8.s, z12.s\n" - "incb %[outpos], all, mul #1\n" - "zip1 z2.s, z9.s, z13.s\n" - "zip2 z3.s, z9.s, z13.s\n" - "zip1 z4.s, z10.s, z14.s\n" - "whilelt p3.b, %[outpos], %[outwidth]\n" - "zip2 z5.s, z10.s, z14.s\n" - "incb %[outpos], all, mul #1\n" - "zip1 z6.s, z11.s, z15.s\n" - "zip2 z7.s, z11.s, z15.s\n" - "zip1 z8.s, z0.s, z4.s\n" - "whilelt p4.b, %[outpos], %[outwidth]\n" - "zip2 z9.s, z0.s, z4.s\n" - "incb %[outpos], all, mul #1\n" - "zip1 z10.s, z1.s, z5.s\n" - "st1b z8.b, p0, [%[outptr]]\n" - "zip2 z11.s, z1.s, z5.s\n" - "zip1 z12.s, z2.s, z6.s\n" - "whilelt p5.b, %[outpos], %[outwidth]\n" - "zip2 z13.s, z2.s, z6.s\n" - "st1b z9.b, p1, [%[outptr], #1, MUL VL]\n" - "zip1 z14.s, z3.s, z7.s\n" - "incb %[outpos], all, mul #1\n" - "zip2 z15.s, z3.s, z7.s\n" - "st1b z10.b, p2, [%[outptr], #2, MUL VL]\n" - "whilelt p6.b, %[outpos], %[outwidth]\n" - "st1b z11.b, p3, [%[outptr], #3, MUL VL]\n" - "incb %[outpos], all, mul #1\n" - "st1b z12.b, p4, [%[outptr], #4, MUL VL]\n" - "whilelt p7.b, %[outpos], %[outwidth]\n" - "incb %[outpos], all, mul #1\n" - "st1b z13.b, p5, [%[outptr], #5, MUL VL]\n" - "st1b z14.b, p6, [%[outptr], #6, MUL VL]\n" - "st1b z15.b, p7, [%[outptr], #7, MUL VL]\n" - "addvl %[outptr], %[outptr], #8\n" - "b 1b\n" - "2:\n" - : [inpos] "+r" (inpos), [outpos] "+r" (outpos), [outptr] "+r" (outptr), [inptr0] "+r" (inptr0), [inptr1] "+r" (inptr1), [inptr2] "+r" (inptr2), [inptr3] "+r" (inptr3), [inptr4] "+r" (inptr4), [inptr5] "+r" (inptr5), [inptr6] "+r" (inptr6) - : [outwidth] "r" (outwidth), [inwidth] "r" (inwidth) - : "p0", "p1", "p2", "p3", "p4", "p5", "p6", "p7", "z0", "z1", "z2", "z3", "z4", "z5", "z6", "z7", "z8", "z9", "z10", "z11", "z12", "z13", "z14", "z15", "cc", "memory" - ); - break; - - default: - case 8: - __asm __volatile( - "1:\n" - "whilelt p0.b, %[inpos], %[inwidth]\n" - "b.none 2f\n" - "ld1b z0.b, p0/z, [%[inptr0], %[inpos]]\n" - "ld1b z1.b, p0/z, [%[inptr1], %[inpos]]\n" - "ld1b z2.b, p0/z, [%[inptr2], %[inpos]]\n" - "ld1b z3.b, p0/z, [%[inptr3], %[inpos]]\n" - "ld1b z4.b, p0/z, [%[inptr4], %[inpos]]\n" - "ld1b z5.b, p0/z, [%[inptr5], %[inpos]]\n" - "ld1b z6.b, p0/z, [%[inptr6], %[inpos]]\n" - "ld1b z7.b, p0/z, [%[inptr7], %[inpos]]\n" - "incb %[inpos], all, mul #1\n" - "zip1 z8.s, z0.s, z4.s\n" - "whilelt p0.b, %[outpos], %[outwidth]\n" - "zip2 z9.s, z0.s, z4.s\n" - "incb %[outpos], all, mul #1\n" - "zip1 z10.s, z1.s, z5.s\n" - "zip2 z11.s, z1.s, z5.s\n" - "zip1 z12.s, z2.s, z6.s\n" - "whilelt p1.b, %[outpos], %[outwidth]\n" - "zip2 z13.s, z2.s, z6.s\n" - "incb %[outpos], all, mul #1\n" - "zip1 z14.s, z3.s, z7.s\n" - "zip2 z15.s, z3.s, z7.s\n" - "zip1 z0.s, z8.s, z12.s\n" - "whilelt p2.b, %[outpos], %[outwidth]\n" - "zip2 z1.s, z8.s, z12.s\n" - "incb %[outpos], all, mul #1\n" - "zip1 z2.s, z9.s, z13.s\n" - "zip2 z3.s, z9.s, z13.s\n" - "zip1 z4.s, z10.s, z14.s\n" - "whilelt p3.b, %[outpos], %[outwidth]\n" - "zip2 z5.s, z10.s, z14.s\n" - "incb %[outpos], all, mul #1\n" - "zip1 z6.s, z11.s, z15.s\n" - "zip2 z7.s, z11.s, z15.s\n" - "zip1 z8.s, z0.s, z4.s\n" - "whilelt p4.b, %[outpos], %[outwidth]\n" - "zip2 z9.s, z0.s, z4.s\n" - "incb %[outpos], all, mul #1\n" - "zip1 z10.s, z1.s, z5.s\n" - "st1b z8.b, p0, [%[outptr]]\n" - "zip2 z11.s, z1.s, z5.s\n" - "zip1 z12.s, z2.s, z6.s\n" - "whilelt p5.b, %[outpos], %[outwidth]\n" - "zip2 z13.s, z2.s, z6.s\n" - "st1b z9.b, p1, [%[outptr], #1, MUL VL]\n" - "zip1 z14.s, z3.s, z7.s\n" - "incb %[outpos], all, mul #1\n" - "zip2 z15.s, z3.s, z7.s\n" - "st1b z10.b, p2, [%[outptr], #2, MUL VL]\n" - "whilelt p6.b, %[outpos], %[outwidth]\n" - "st1b z11.b, p3, [%[outptr], #3, MUL VL]\n" - "incb %[outpos], all, mul #1\n" - "st1b z12.b, p4, [%[outptr], #4, MUL VL]\n" - "whilelt p7.b, %[outpos], %[outwidth]\n" - "incb %[outpos], all, mul #1\n" - "st1b z13.b, p5, [%[outptr], #5, MUL VL]\n" - "st1b z14.b, p6, [%[outptr], #6, MUL VL]\n" - "st1b z15.b, p7, [%[outptr], #7, MUL VL]\n" - "addvl %[outptr], %[outptr], #8\n" - "b 1b\n" - "2:\n" - : [inpos] "+r" (inpos), [outpos] "+r" (outpos), [outptr] "+r" (outptr), [inptr0] "+r" (inptr0), [inptr1] "+r" (inptr1), [inptr2] "+r" (inptr2), [inptr3] "+r" (inptr3), [inptr4] "+r" (inptr4), [inptr5] "+r" (inptr5), [inptr6] "+r" (inptr6), [inptr7] "+r" (inptr7) - : [outwidth] "r" (outwidth), [inwidth] "r" (inwidth) - : "p0", "p1", "p2", "p3", "p4", "p5", "p6", "p7", "z0", "z1", "z2", "z3", "z4", "z5", "z6", "z7", "z8", "z9", "z10", "z11", "z12", "z13", "z14", "z15", "cc", "memory" - ); - break; - - - } - } -} - -#endif // __ARM_FEATURE_SVE diff --git a/src/core/NEON/kernels/arm_gemm/transforms/sve_interleave_8way_block8_8bit.hpp b/src/core/NEON/kernels/arm_gemm/transforms/sve_interleave_8way_block8_8bit.hpp deleted file mode 100644 index b4935e6417..0000000000 --- a/src/core/NEON/kernels/arm_gemm/transforms/sve_interleave_8way_block8_8bit.hpp +++ /dev/null @@ -1,596 +0,0 @@ -/* - * Copyright (c) 2019 Arm Limited. - * - * SPDX-License-Identifier: MIT - * - * Permission is hereby granted, free of charge, to any person obtaining a copy - * of this software and associated documentation files (the "Software"), to - * deal in the Software without restriction, including without limitation the - * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or - * sell copies of the Software, and to permit persons to whom the Software is - * furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in all - * copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE - * SOFTWARE. - */ -#pragma once - -#ifdef __ARM_FEATURE_SVE - -template<> -template<typename T> -inline void TransformImpl<8, 8, false, 1, 1, false>::Transform(T *out, const T *in, int ldin, int y0, int ymax, int k0, int kmax) -{ - uint8_t *master_outptr = reinterpret_cast<uint8_t *>(out); - const uint8_t *inptr = reinterpret_cast<const uint8_t *>(in); - - for (int y=y0; y<ymax; y+=8) - { - const int height = ymax-y; - const long inwidth = (kmax - k0); - const long outwidth = ((inwidth + 7) / 8) * 64; - long inpos = 0; - long outpos = 0; - - uint8_t *outptr = master_outptr; - master_outptr += outwidth; - - const uint8_t *inptr0 = inptr + y * ldin + k0; - const uint8_t *inptr1 = inptr0 + ldin; - const uint8_t *inptr2 = inptr1 + ldin; - const uint8_t *inptr3 = inptr2 + ldin; - const uint8_t *inptr4 = inptr3 + ldin; - const uint8_t *inptr5 = inptr4 + ldin; - const uint8_t *inptr6 = inptr5 + ldin; - const uint8_t *inptr7 = inptr6 + ldin; - - switch(height) - { - case 1: - __asm __volatile( - "1:\n" - "whilelt p0.b, %[inpos], %[inwidth]\n" - "b.none 2f\n" - "mov z4.b, #0\n" - "ld1b z0.b, p0/z, [%[inptr0], %[inpos]]\n" - "incb %[inpos], all, mul #1\n" - "whilelt p0.b, %[outpos], %[outwidth]\n" - "incb %[outpos], all, mul #1\n" - "zip1 z8.d, z0.d, z4.d\n" - "zip2 z9.d, z0.d, z4.d\n" - "whilelt p1.b, %[outpos], %[outwidth]\n" - "zip1 z0.d, z8.d, z4.d\n" - "incb %[outpos], all, mul #1\n" - "zip2 z1.d, z8.d, z4.d\n" - "zip1 z2.d, z9.d, z4.d\n" - "zip2 z3.d, z9.d, z4.d\n" - "whilelt p2.b, %[outpos], %[outwidth]\n" - "zip1 z8.d, z0.d, z4.d\n" - "incb %[outpos], all, mul #1\n" - "zip2 z9.d, z0.d, z4.d\n" - "zip1 z10.d, z1.d, z4.d\n" - "st1b z8.b, p0, [%[outptr]]\n" - "zip2 z11.d, z1.d, z4.d\n" - "whilelt p3.b, %[outpos], %[outwidth]\n" - "zip1 z12.d, z2.d, z4.d\n" - "st1b z9.b, p1, [%[outptr], #1, MUL VL]\n" - "zip2 z13.d, z2.d, z4.d\n" - "incb %[outpos], all, mul #1\n" - "zip1 z14.d, z3.d, z4.d\n" - "st1b z10.b, p2, [%[outptr], #2, MUL VL]\n" - "zip2 z15.d, z3.d, z4.d\n" - "whilelt p4.b, %[outpos], %[outwidth]\n" - "st1b z11.b, p3, [%[outptr], #3, MUL VL]\n" - "incb %[outpos], all, mul #1\n" - "st1b z12.b, p4, [%[outptr], #4, MUL VL]\n" - "whilelt p5.b, %[outpos], %[outwidth]\n" - "incb %[outpos], all, mul #1\n" - "st1b z13.b, p5, [%[outptr], #5, MUL VL]\n" - "whilelt p6.b, %[outpos], %[outwidth]\n" - "incb %[outpos], all, mul #1\n" - "st1b z14.b, p6, [%[outptr], #6, MUL VL]\n" - "whilelt p7.b, %[outpos], %[outwidth]\n" - "incb %[outpos], all, mul #1\n" - "st1b z15.b, p7, [%[outptr], #7, MUL VL]\n" - "addvl %[outptr], %[outptr], #8\n" - "b 1b\n" - "2:\n" - : [inpos] "+r" (inpos), [outpos] "+r" (outpos), [outptr] "+r" (outptr), [inptr0] "+r" (inptr0) - : [outwidth] "r" (outwidth), [inwidth] "r" (inwidth) - : "p0", "p1", "p2", "p3", "p4", "p5", "p6", "p7", "z0", "z1", "z2", "z3", "z4", "z5", "z6", "z7", "z8", "z9", "z10", "z11", "z12", "z13", "z14", "z15", "cc", "memory" - ); - break; - - case 2: - __asm __volatile( - "1:\n" - "whilelt p0.b, %[inpos], %[inwidth]\n" - "b.none 2f\n" - "mov z4.b, #0\n" - "mov z14.b, #0\n" - "ld1b z0.b, p0/z, [%[inptr0], %[inpos]]\n" - "ld1b z1.b, p0/z, [%[inptr1], %[inpos]]\n" - "incb %[inpos], all, mul #1\n" - "whilelt p0.b, %[outpos], %[outwidth]\n" - "zip1 z8.d, z0.d, z4.d\n" - "incb %[outpos], all, mul #1\n" - "zip2 z9.d, z0.d, z4.d\n" - "zip1 z10.d, z1.d, z4.d\n" - "zip2 z11.d, z1.d, z4.d\n" - "whilelt p1.b, %[outpos], %[outwidth]\n" - "zip1 z0.d, z8.d, z4.d\n" - "incb %[outpos], all, mul #1\n" - "zip2 z1.d, z8.d, z4.d\n" - "zip1 z2.d, z9.d, z4.d\n" - "zip2 z3.d, z9.d, z4.d\n" - "whilelt p2.b, %[outpos], %[outwidth]\n" - "zip1 z4.d, z10.d, z14.d\n" - "incb %[outpos], all, mul #1\n" - "zip2 z5.d, z10.d, z14.d\n" - "zip1 z6.d, z11.d, z14.d\n" - "zip2 z7.d, z11.d, z14.d\n" - "whilelt p3.b, %[outpos], %[outwidth]\n" - "zip1 z8.d, z0.d, z4.d\n" - "incb %[outpos], all, mul #1\n" - "zip2 z9.d, z0.d, z4.d\n" - "zip1 z10.d, z1.d, z5.d\n" - "st1b z8.b, p0, [%[outptr]]\n" - "zip2 z11.d, z1.d, z5.d\n" - "whilelt p4.b, %[outpos], %[outwidth]\n" - "zip1 z12.d, z2.d, z6.d\n" - "st1b z9.b, p1, [%[outptr], #1, MUL VL]\n" - "zip2 z13.d, z2.d, z6.d\n" - "incb %[outpos], all, mul #1\n" - "zip1 z14.d, z3.d, z7.d\n" - "st1b z10.b, p2, [%[outptr], #2, MUL VL]\n" - "zip2 z15.d, z3.d, z7.d\n" - "whilelt p5.b, %[outpos], %[outwidth]\n" - "st1b z11.b, p3, [%[outptr], #3, MUL VL]\n" - "incb %[outpos], all, mul #1\n" - "st1b z12.b, p4, [%[outptr], #4, MUL VL]\n" - "whilelt p6.b, %[outpos], %[outwidth]\n" - "incb %[outpos], all, mul #1\n" - "st1b z13.b, p5, [%[outptr], #5, MUL VL]\n" - "whilelt p7.b, %[outpos], %[outwidth]\n" - "incb %[outpos], all, mul #1\n" - "st1b z14.b, p6, [%[outptr], #6, MUL VL]\n" - "st1b z15.b, p7, [%[outptr], #7, MUL VL]\n" - "addvl %[outptr], %[outptr], #8\n" - "b 1b\n" - "2:\n" - : [inpos] "+r" (inpos), [outpos] "+r" (outpos), [outptr] "+r" (outptr), [inptr0] "+r" (inptr0), [inptr1] "+r" (inptr1) - : [outwidth] "r" (outwidth), [inwidth] "r" (inwidth) - : "p0", "p1", "p2", "p3", "p4", "p5", "p6", "p7", "z0", "z1", "z2", "z3", "z4", "z5", "z6", "z7", "z8", "z9", "z10", "z11", "z12", "z13", "z14", "z15", "cc", "memory" - ); - break; - - case 3: - __asm __volatile( - "1:\n" - "whilelt p0.b, %[inpos], %[inwidth]\n" - "b.none 2f\n" - "mov z4.b, #0\n" - "mov z14.b, #0\n" - "ld1b z0.b, p0/z, [%[inptr0], %[inpos]]\n" - "ld1b z1.b, p0/z, [%[inptr1], %[inpos]]\n" - "ld1b z2.b, p0/z, [%[inptr2], %[inpos]]\n" - "incb %[inpos], all, mul #1\n" - "zip1 z8.d, z0.d, z4.d\n" - "whilelt p0.b, %[outpos], %[outwidth]\n" - "zip2 z9.d, z0.d, z4.d\n" - "incb %[outpos], all, mul #1\n" - "zip1 z10.d, z1.d, z4.d\n" - "zip2 z11.d, z1.d, z4.d\n" - "zip1 z12.d, z2.d, z4.d\n" - "whilelt p1.b, %[outpos], %[outwidth]\n" - "zip2 z13.d, z2.d, z4.d\n" - "incb %[outpos], all, mul #1\n" - "zip1 z0.d, z8.d, z12.d\n" - "zip2 z1.d, z8.d, z12.d\n" - "zip1 z2.d, z9.d, z13.d\n" - "whilelt p2.b, %[outpos], %[outwidth]\n" - "zip2 z3.d, z9.d, z13.d\n" - "incb %[outpos], all, mul #1\n" - "zip1 z4.d, z10.d, z14.d\n" - "zip2 z5.d, z10.d, z14.d\n" - "zip1 z6.d, z11.d, z14.d\n" - "whilelt p3.b, %[outpos], %[outwidth]\n" - "zip2 z7.d, z11.d, z14.d\n" - "incb %[outpos], all, mul #1\n" - "zip1 z8.d, z0.d, z4.d\n" - "zip2 z9.d, z0.d, z4.d\n" - "zip1 z10.d, z1.d, z5.d\n" - "whilelt p4.b, %[outpos], %[outwidth]\n" - "zip2 z11.d, z1.d, z5.d\n" - "st1b z8.b, p0, [%[outptr]]\n" - "zip1 z12.d, z2.d, z6.d\n" - "incb %[outpos], all, mul #1\n" - "zip2 z13.d, z2.d, z6.d\n" - "st1b z9.b, p1, [%[outptr], #1, MUL VL]\n" - "zip1 z14.d, z3.d, z7.d\n" - "zip2 z15.d, z3.d, z7.d\n" - "whilelt p5.b, %[outpos], %[outwidth]\n" - "st1b z10.b, p2, [%[outptr], #2, MUL VL]\n" - "incb %[outpos], all, mul #1\n" - "st1b z11.b, p3, [%[outptr], #3, MUL VL]\n" - "whilelt p6.b, %[outpos], %[outwidth]\n" - "incb %[outpos], all, mul #1\n" - "st1b z12.b, p4, [%[outptr], #4, MUL VL]\n" - "whilelt p7.b, %[outpos], %[outwidth]\n" - "incb %[outpos], all, mul #1\n" - "st1b z13.b, p5, [%[outptr], #5, MUL VL]\n" - "st1b z14.b, p6, [%[outptr], #6, MUL VL]\n" - "st1b z15.b, p7, [%[outptr], #7, MUL VL]\n" - "addvl %[outptr], %[outptr], #8\n" - "b 1b\n" - "2:\n" - : [inpos] "+r" (inpos), [outpos] "+r" (outpos), [outptr] "+r" (outptr), [inptr0] "+r" (inptr0), [inptr1] "+r" (inptr1), [inptr2] "+r" (inptr2) - : [outwidth] "r" (outwidth), [inwidth] "r" (inwidth) - : "p0", "p1", "p2", "p3", "p4", "p5", "p6", "p7", "z0", "z1", "z2", "z3", "z4", "z5", "z6", "z7", "z8", "z9", "z10", "z11", "z12", "z13", "z14", "z15", "cc", "memory" - ); - break; - - case 4: - __asm __volatile( - "1:\n" - "whilelt p0.b, %[inpos], %[inwidth]\n" - "b.none 2f\n" - "mov z4.b, #0\n" - "ld1b z0.b, p0/z, [%[inptr0], %[inpos]]\n" - "ld1b z1.b, p0/z, [%[inptr1], %[inpos]]\n" - "ld1b z2.b, p0/z, [%[inptr2], %[inpos]]\n" - "ld1b z3.b, p0/z, [%[inptr3], %[inpos]]\n" - "incb %[inpos], all, mul #1\n" - "zip1 z8.d, z0.d, z4.d\n" - "whilelt p0.b, %[outpos], %[outwidth]\n" - "zip2 z9.d, z0.d, z4.d\n" - "incb %[outpos], all, mul #1\n" - "zip1 z10.d, z1.d, z4.d\n" - "zip2 z11.d, z1.d, z4.d\n" - "zip1 z12.d, z2.d, z4.d\n" - "whilelt p1.b, %[outpos], %[outwidth]\n" - "zip2 z13.d, z2.d, z4.d\n" - "incb %[outpos], all, mul #1\n" - "zip1 z14.d, z3.d, z4.d\n" - "zip2 z15.d, z3.d, z4.d\n" - "zip1 z0.d, z8.d, z12.d\n" - "whilelt p2.b, %[outpos], %[outwidth]\n" - "zip2 z1.d, z8.d, z12.d\n" - "incb %[outpos], all, mul #1\n" - "zip1 z2.d, z9.d, z13.d\n" - "zip2 z3.d, z9.d, z13.d\n" - "zip1 z4.d, z10.d, z14.d\n" - "whilelt p3.b, %[outpos], %[outwidth]\n" - "zip2 z5.d, z10.d, z14.d\n" - "incb %[outpos], all, mul #1\n" - "zip1 z6.d, z11.d, z15.d\n" - "zip2 z7.d, z11.d, z15.d\n" - "zip1 z8.d, z0.d, z4.d\n" - "whilelt p4.b, %[outpos], %[outwidth]\n" - "zip2 z9.d, z0.d, z4.d\n" - "incb %[outpos], all, mul #1\n" - "zip1 z10.d, z1.d, z5.d\n" - "st1b z8.b, p0, [%[outptr]]\n" - "zip2 z11.d, z1.d, z5.d\n" - "zip1 z12.d, z2.d, z6.d\n" - "whilelt p5.b, %[outpos], %[outwidth]\n" - "zip2 z13.d, z2.d, z6.d\n" - "st1b z9.b, p1, [%[outptr], #1, MUL VL]\n" - "zip1 z14.d, z3.d, z7.d\n" - "incb %[outpos], all, mul #1\n" - "zip2 z15.d, z3.d, z7.d\n" - "st1b z10.b, p2, [%[outptr], #2, MUL VL]\n" - "whilelt p6.b, %[outpos], %[outwidth]\n" - "st1b z11.b, p3, [%[outptr], #3, MUL VL]\n" - "incb %[outpos], all, mul #1\n" - "st1b z12.b, p4, [%[outptr], #4, MUL VL]\n" - "whilelt p7.b, %[outpos], %[outwidth]\n" - "incb %[outpos], all, mul #1\n" - "st1b z13.b, p5, [%[outptr], #5, MUL VL]\n" - "st1b z14.b, p6, [%[outptr], #6, MUL VL]\n" - "st1b z15.b, p7, [%[outptr], #7, MUL VL]\n" - "addvl %[outptr], %[outptr], #8\n" - "b 1b\n" - "2:\n" - : [inpos] "+r" (inpos), [outpos] "+r" (outpos), [outptr] "+r" (outptr), [inptr0] "+r" (inptr0), [inptr1] "+r" (inptr1), [inptr2] "+r" (inptr2), [inptr3] "+r" (inptr3) - : [outwidth] "r" (outwidth), [inwidth] "r" (inwidth) - : "p0", "p1", "p2", "p3", "p4", "p5", "p6", "p7", "z0", "z1", "z2", "z3", "z4", "z5", "z6", "z7", "z8", "z9", "z10", "z11", "z12", "z13", "z14", "z15", "cc", "memory" - ); - break; - - case 5: - __asm __volatile( - "1:\n" - "whilelt p0.b, %[inpos], %[inwidth]\n" - "b.none 2f\n" - "mov z5.b, #0\n" - "ld1b z0.b, p0/z, [%[inptr0], %[inpos]]\n" - "ld1b z1.b, p0/z, [%[inptr1], %[inpos]]\n" - "ld1b z2.b, p0/z, [%[inptr2], %[inpos]]\n" - "ld1b z3.b, p0/z, [%[inptr3], %[inpos]]\n" - "ld1b z4.b, p0/z, [%[inptr4], %[inpos]]\n" - "incb %[inpos], all, mul #1\n" - "zip1 z10.d, z1.d, z5.d\n" - "whilelt p0.b, %[outpos], %[outwidth]\n" - "zip1 z8.d, z0.d, z4.d\n" - "incb %[outpos], all, mul #1\n" - "zip2 z9.d, z0.d, z4.d\n" - "zip2 z11.d, z1.d, z5.d\n" - "zip1 z12.d, z2.d, z5.d\n" - "whilelt p1.b, %[outpos], %[outwidth]\n" - "zip2 z13.d, z2.d, z5.d\n" - "incb %[outpos], all, mul #1\n" - "zip1 z14.d, z3.d, z5.d\n" - "zip2 z15.d, z3.d, z5.d\n" - "zip1 z0.d, z8.d, z12.d\n" - "whilelt p2.b, %[outpos], %[outwidth]\n" - "zip2 z1.d, z8.d, z12.d\n" - "incb %[outpos], all, mul #1\n" - "zip1 z2.d, z9.d, z13.d\n" - "zip2 z3.d, z9.d, z13.d\n" - "zip1 z4.d, z10.d, z14.d\n" - "whilelt p3.b, %[outpos], %[outwidth]\n" - "zip2 z5.d, z10.d, z14.d\n" - "incb %[outpos], all, mul #1\n" - "zip1 z6.d, z11.d, z15.d\n" - "zip2 z7.d, z11.d, z15.d\n" - "zip1 z8.d, z0.d, z4.d\n" - "whilelt p4.b, %[outpos], %[outwidth]\n" - "zip2 z9.d, z0.d, z4.d\n" - "incb %[outpos], all, mul #1\n" - "zip1 z10.d, z1.d, z5.d\n" - "st1b z8.b, p0, [%[outptr]]\n" - "zip2 z11.d, z1.d, z5.d\n" - "zip1 z12.d, z2.d, z6.d\n" - "whilelt p5.b, %[outpos], %[outwidth]\n" - "zip2 z13.d, z2.d, z6.d\n" - "st1b z9.b, p1, [%[outptr], #1, MUL VL]\n" - "zip1 z14.d, z3.d, z7.d\n" - "incb %[outpos], all, mul #1\n" - "zip2 z15.d, z3.d, z7.d\n" - "st1b z10.b, p2, [%[outptr], #2, MUL VL]\n" - "whilelt p6.b, %[outpos], %[outwidth]\n" - "st1b z11.b, p3, [%[outptr], #3, MUL VL]\n" - "incb %[outpos], all, mul #1\n" - "st1b z12.b, p4, [%[outptr], #4, MUL VL]\n" - "whilelt p7.b, %[outpos], %[outwidth]\n" - "incb %[outpos], all, mul #1\n" - "st1b z13.b, p5, [%[outptr], #5, MUL VL]\n" - "st1b z14.b, p6, [%[outptr], #6, MUL VL]\n" - "st1b z15.b, p7, [%[outptr], #7, MUL VL]\n" - "addvl %[outptr], %[outptr], #8\n" - "b 1b\n" - "2:\n" - : [inpos] "+r" (inpos), [outpos] "+r" (outpos), [outptr] "+r" (outptr), [inptr0] "+r" (inptr0), [inptr1] "+r" (inptr1), [inptr2] "+r" (inptr2), [inptr3] "+r" (inptr3), [inptr4] "+r" (inptr4) - : [outwidth] "r" (outwidth), [inwidth] "r" (inwidth) - : "p0", "p1", "p2", "p3", "p4", "p5", "p6", "p7", "z0", "z1", "z2", "z3", "z4", "z5", "z6", "z7", "z8", "z9", "z10", "z11", "z12", "z13", "z14", "z15", "cc", "memory" - ); - break; - - case 6: - __asm __volatile( - "1:\n" - "whilelt p0.b, %[inpos], %[inwidth]\n" - "b.none 2f\n" - "mov z6.b, #0\n" - "ld1b z0.b, p0/z, [%[inptr0], %[inpos]]\n" - "ld1b z1.b, p0/z, [%[inptr1], %[inpos]]\n" - "ld1b z2.b, p0/z, [%[inptr2], %[inpos]]\n" - "ld1b z3.b, p0/z, [%[inptr3], %[inpos]]\n" - "ld1b z4.b, p0/z, [%[inptr4], %[inpos]]\n" - "ld1b z5.b, p0/z, [%[inptr5], %[inpos]]\n" - "incb %[inpos], all, mul #1\n" - "zip1 z12.d, z2.d, z6.d\n" - "whilelt p0.b, %[outpos], %[outwidth]\n" - "zip1 z8.d, z0.d, z4.d\n" - "incb %[outpos], all, mul #1\n" - "zip2 z9.d, z0.d, z4.d\n" - "zip1 z10.d, z1.d, z5.d\n" - "zip2 z11.d, z1.d, z5.d\n" - "whilelt p1.b, %[outpos], %[outwidth]\n" - "zip2 z13.d, z2.d, z6.d\n" - "incb %[outpos], all, mul #1\n" - "zip1 z14.d, z3.d, z6.d\n" - "zip2 z15.d, z3.d, z6.d\n" - "zip1 z0.d, z8.d, z12.d\n" - "whilelt p2.b, %[outpos], %[outwidth]\n" - "zip2 z1.d, z8.d, z12.d\n" - "incb %[outpos], all, mul #1\n" - "zip1 z2.d, z9.d, z13.d\n" - "zip2 z3.d, z9.d, z13.d\n" - "zip1 z4.d, z10.d, z14.d\n" - "whilelt p3.b, %[outpos], %[outwidth]\n" - "zip2 z5.d, z10.d, z14.d\n" - "incb %[outpos], all, mul #1\n" - "zip1 z6.d, z11.d, z15.d\n" - "zip2 z7.d, z11.d, z15.d\n" - "zip1 z8.d, z0.d, z4.d\n" - "whilelt p4.b, %[outpos], %[outwidth]\n" - "zip2 z9.d, z0.d, z4.d\n" - "incb %[outpos], all, mul #1\n" - "zip1 z10.d, z1.d, z5.d\n" - "st1b z8.b, p0, [%[outptr]]\n" - "zip2 z11.d, z1.d, z5.d\n" - "zip1 z12.d, z2.d, z6.d\n" - "whilelt p5.b, %[outpos], %[outwidth]\n" - "zip2 z13.d, z2.d, z6.d\n" - "st1b z9.b, p1, [%[outptr], #1, MUL VL]\n" - "zip1 z14.d, z3.d, z7.d\n" - "incb %[outpos], all, mul #1\n" - "zip2 z15.d, z3.d, z7.d\n" - "st1b z10.b, p2, [%[outptr], #2, MUL VL]\n" - "whilelt p6.b, %[outpos], %[outwidth]\n" - "st1b z11.b, p3, [%[outptr], #3, MUL VL]\n" - "incb %[outpos], all, mul #1\n" - "st1b z12.b, p4, [%[outptr], #4, MUL VL]\n" - "whilelt p7.b, %[outpos], %[outwidth]\n" - "incb %[outpos], all, mul #1\n" - "st1b z13.b, p5, [%[outptr], #5, MUL VL]\n" - "st1b z14.b, p6, [%[outptr], #6, MUL VL]\n" - "st1b z15.b, p7, [%[outptr], #7, MUL VL]\n" - "addvl %[outptr], %[outptr], #8\n" - "b 1b\n" - "2:\n" - : [inpos] "+r" (inpos), [outpos] "+r" (outpos), [outptr] "+r" (outptr), [inptr0] "+r" (inptr0), [inptr1] "+r" (inptr1), [inptr2] "+r" (inptr2), [inptr3] "+r" (inptr3), [inptr4] "+r" (inptr4), [inptr5] "+r" (inptr5) - : [outwidth] "r" (outwidth), [inwidth] "r" (inwidth) - : "p0", "p1", "p2", "p3", "p4", "p5", "p6", "p7", "z0", "z1", "z2", "z3", "z4", "z5", "z6", "z7", "z8", "z9", "z10", "z11", "z12", "z13", "z14", "z15", "cc", "memory" - ); - break; - - case 7: - __asm __volatile( - "1:\n" - "whilelt p0.b, %[inpos], %[inwidth]\n" - "b.none 2f\n" - "mov z7.b, #0\n" - "ld1b z0.b, p0/z, [%[inptr0], %[inpos]]\n" - "ld1b z1.b, p0/z, [%[inptr1], %[inpos]]\n" - "ld1b z2.b, p0/z, [%[inptr2], %[inpos]]\n" - "ld1b z3.b, p0/z, [%[inptr3], %[inpos]]\n" - "ld1b z4.b, p0/z, [%[inptr4], %[inpos]]\n" - "ld1b z5.b, p0/z, [%[inptr5], %[inpos]]\n" - "ld1b z6.b, p0/z, [%[inptr6], %[inpos]]\n" - "incb %[inpos], all, mul #1\n" - "zip1 z8.d, z0.d, z4.d\n" - "whilelt p0.b, %[outpos], %[outwidth]\n" - "zip2 z9.d, z0.d, z4.d\n" - "incb %[outpos], all, mul #1\n" - "zip1 z10.d, z1.d, z5.d\n" - "zip2 z11.d, z1.d, z5.d\n" - "zip1 z12.d, z2.d, z6.d\n" - "whilelt p1.b, %[outpos], %[outwidth]\n" - "zip2 z13.d, z2.d, z6.d\n" - "incb %[outpos], all, mul #1\n" - "zip1 z14.d, z3.d, z7.d\n" - "zip2 z15.d, z3.d, z7.d\n" - "zip1 z0.d, z8.d, z12.d\n" - "whilelt p2.b, %[outpos], %[outwidth]\n" - "zip2 z1.d, z8.d, z12.d\n" - "incb %[outpos], all, mul #1\n" - "zip1 z2.d, z9.d, z13.d\n" - "zip2 z3.d, z9.d, z13.d\n" - "zip1 z4.d, z10.d, z14.d\n" - "whilelt p3.b, %[outpos], %[outwidth]\n" - "zip2 z5.d, z10.d, z14.d\n" - "incb %[outpos], all, mul #1\n" - "zip1 z6.d, z11.d, z15.d\n" - "zip2 z7.d, z11.d, z15.d\n" - "zip1 z8.d, z0.d, z4.d\n" - "whilelt p4.b, %[outpos], %[outwidth]\n" - "zip2 z9.d, z0.d, z4.d\n" - "incb %[outpos], all, mul #1\n" - "zip1 z10.d, z1.d, z5.d\n" - "st1b z8.b, p0, [%[outptr]]\n" - "zip2 z11.d, z1.d, z5.d\n" - "zip1 z12.d, z2.d, z6.d\n" - "whilelt p5.b, %[outpos], %[outwidth]\n" - "zip2 z13.d, z2.d, z6.d\n" - "st1b z9.b, p1, [%[outptr], #1, MUL VL]\n" - "zip1 z14.d, z3.d, z7.d\n" - "incb %[outpos], all, mul #1\n" - "zip2 z15.d, z3.d, z7.d\n" - "st1b z10.b, p2, [%[outptr], #2, MUL VL]\n" - "whilelt p6.b, %[outpos], %[outwidth]\n" - "st1b z11.b, p3, [%[outptr], #3, MUL VL]\n" - "incb %[outpos], all, mul #1\n" - "st1b z12.b, p4, [%[outptr], #4, MUL VL]\n" - "whilelt p7.b, %[outpos], %[outwidth]\n" - "incb %[outpos], all, mul #1\n" - "st1b z13.b, p5, [%[outptr], #5, MUL VL]\n" - "st1b z14.b, p6, [%[outptr], #6, MUL VL]\n" - "st1b z15.b, p7, [%[outptr], #7, MUL VL]\n" - "addvl %[outptr], %[outptr], #8\n" - "b 1b\n" - "2:\n" - : [inpos] "+r" (inpos), [outpos] "+r" (outpos), [outptr] "+r" (outptr), [inptr0] "+r" (inptr0), [inptr1] "+r" (inptr1), [inptr2] "+r" (inptr2), [inptr3] "+r" (inptr3), [inptr4] "+r" (inptr4), [inptr5] "+r" (inptr5), [inptr6] "+r" (inptr6) - : [outwidth] "r" (outwidth), [inwidth] "r" (inwidth) - : "p0", "p1", "p2", "p3", "p4", "p5", "p6", "p7", "z0", "z1", "z2", "z3", "z4", "z5", "z6", "z7", "z8", "z9", "z10", "z11", "z12", "z13", "z14", "z15", "cc", "memory" - ); - break; - - default: - case 8: - __asm __volatile( - "1:\n" - "whilelt p0.b, %[inpos], %[inwidth]\n" - "b.none 2f\n" - "ld1b z0.b, p0/z, [%[inptr0], %[inpos]]\n" - "ld1b z1.b, p0/z, [%[inptr1], %[inpos]]\n" - "ld1b z2.b, p0/z, [%[inptr2], %[inpos]]\n" - "ld1b z3.b, p0/z, [%[inptr3], %[inpos]]\n" - "ld1b z4.b, p0/z, [%[inptr4], %[inpos]]\n" - "ld1b z5.b, p0/z, [%[inptr5], %[inpos]]\n" - "ld1b z6.b, p0/z, [%[inptr6], %[inpos]]\n" - "ld1b z7.b, p0/z, [%[inptr7], %[inpos]]\n" - "incb %[inpos], all, mul #1\n" - "zip1 z8.d, z0.d, z4.d\n" - "whilelt p0.b, %[outpos], %[outwidth]\n" - "zip2 z9.d, z0.d, z4.d\n" - "incb %[outpos], all, mul #1\n" - "zip1 z10.d, z1.d, z5.d\n" - "zip2 z11.d, z1.d, z5.d\n" - "zip1 z12.d, z2.d, z6.d\n" - "whilelt p1.b, %[outpos], %[outwidth]\n" - "zip2 z13.d, z2.d, z6.d\n" - "incb %[outpos], all, mul #1\n" - "zip1 z14.d, z3.d, z7.d\n" - "zip2 z15.d, z3.d, z7.d\n" - "zip1 z0.d, z8.d, z12.d\n" - "whilelt p2.b, %[outpos], %[outwidth]\n" - "zip2 z1.d, z8.d, z12.d\n" - "incb %[outpos], all, mul #1\n" - "zip1 z2.d, z9.d, z13.d\n" - "zip2 z3.d, z9.d, z13.d\n" - "zip1 z4.d, z10.d, z14.d\n" - "whilelt p3.b, %[outpos], %[outwidth]\n" - "zip2 z5.d, z10.d, z14.d\n" - "incb %[outpos], all, mul #1\n" - "zip1 z6.d, z11.d, z15.d\n" - "zip2 z7.d, z11.d, z15.d\n" - "zip1 z8.d, z0.d, z4.d\n" - "whilelt p4.b, %[outpos], %[outwidth]\n" - "zip2 z9.d, z0.d, z4.d\n" - "incb %[outpos], all, mul #1\n" - "zip1 z10.d, z1.d, z5.d\n" - "st1b z8.b, p0, [%[outptr]]\n" - "zip2 z11.d, z1.d, z5.d\n" - "zip1 z12.d, z2.d, z6.d\n" - "whilelt p5.b, %[outpos], %[outwidth]\n" - "zip2 z13.d, z2.d, z6.d\n" - "st1b z9.b, p1, [%[outptr], #1, MUL VL]\n" - "zip1 z14.d, z3.d, z7.d\n" - "incb %[outpos], all, mul #1\n" - "zip2 z15.d, z3.d, z7.d\n" - "st1b z10.b, p2, [%[outptr], #2, MUL VL]\n" - "whilelt p6.b, %[outpos], %[outwidth]\n" - "st1b z11.b, p3, [%[outptr], #3, MUL VL]\n" - "incb %[outpos], all, mul #1\n" - "st1b z12.b, p4, [%[outptr], #4, MUL VL]\n" - "whilelt p7.b, %[outpos], %[outwidth]\n" - "incb %[outpos], all, mul #1\n" - "st1b z13.b, p5, [%[outptr], #5, MUL VL]\n" - "st1b z14.b, p6, [%[outptr], #6, MUL VL]\n" - "st1b z15.b, p7, [%[outptr], #7, MUL VL]\n" - "addvl %[outptr], %[outptr], #8\n" - "b 1b\n" - "2:\n" - : [inpos] "+r" (inpos), [outpos] "+r" (outpos), [outptr] "+r" (outptr), [inptr0] "+r" (inptr0), [inptr1] "+r" (inptr1), [inptr2] "+r" (inptr2), [inptr3] "+r" (inptr3), [inptr4] "+r" (inptr4), [inptr5] "+r" (inptr5), [inptr6] "+r" (inptr6), [inptr7] "+r" (inptr7) - : [outwidth] "r" (outwidth), [inwidth] "r" (inwidth) - : "p0", "p1", "p2", "p3", "p4", "p5", "p6", "p7", "z0", "z1", "z2", "z3", "z4", "z5", "z6", "z7", "z8", "z9", "z10", "z11", "z12", "z13", "z14", "z15", "cc", "memory" - ); - break; - - - } - } -} - -#endif // __ARM_FEATURE_SVE diff --git a/src/core/NEON/kernels/arm_gemm/transforms/sve_transpose_interleave_12VL_2x4_fp32bf16.hpp b/src/core/NEON/kernels/arm_gemm/transforms/sve_transpose_interleave_12VL_2x4_fp32bf16.hpp new file mode 100644 index 0000000000..f627fe575f --- /dev/null +++ b/src/core/NEON/kernels/arm_gemm/transforms/sve_transpose_interleave_12VL_2x4_fp32bf16.hpp @@ -0,0 +1,376 @@ +/* + * Copyright (c) 2021, 2023 Arm Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ + +#pragma once + +#if defined(ARM_COMPUTE_ENABLE_SVE) + +namespace { + +void sve_transpose_interleave_12VL_2x4_fp32bf16(bfloat16 *out, const float *in, size_t width, size_t in_stride, size_t height) +{ + float *pad_row = reinterpret_cast<float *>(alloca(width * sizeof(float))); + + if (height % 4) { + memset(pad_row, 0, width * sizeof(float)); + } + + size_t out_stride = 12 * roundup<size_t>(height, 4) * get_vector_length<uint32_t>(); + + __asm__ __volatile__( + "ptrue p6.b\n" + "1:" // Main row loop: Head + "mov x28, %x[in]\n" + "add x27, x28, %x[in_stride]\n" + "add x26, x27, %x[in_stride]\n" + "mov x25, %x[width]\n" + "cnth x24, ALL, MUL #6\n" + "add x23, x26, %x[in_stride]\n" + "cmp %x[height], #0x3\n" + "add %x[in], x23, %x[in_stride]\n" + "csel x23, x23, %x[pad_row], GT\n" + "csel x26, x26, %x[pad_row], GE\n" + "cmp %x[height], #0x1\n" + "csel x27, x27, %x[pad_row], GT\n" + "cmp x25, x24\n" + "mov x22, %x[out]\n" + "sub %x[height], %x[height], #0x4\n" + "blt 3f\n" + "2:" // Main row loop: Unroll column loop + "ld1w { z22.s }, p6/Z, [x28]\n" + "ld1w { z7.s }, p6/Z, [x28, #1, MUL VL]\n" + "mov x21, x22\n" + "add x22, x22, %x[out_stride]\n" + "ld1w { z19.s }, p6/Z, [x28, #2, MUL VL]\n" + "ld1w { z18.s }, p6/Z, [x28, #3, MUL VL]\n" + "mov x20, x22\n" + "sub x25, x25, x24\n" + "ld1w { z5.s }, p6/Z, [x28, #4, MUL VL]\n" + "ld1w { z25.s }, p6/Z, [x28, #5, MUL VL]\n" + "cmp x25, x24\n" + "add x22, x22, %x[out_stride]\n" + "ld1w { z20.s }, p6/Z, [x28, #6, MUL VL]\n" + "ld1w { z23.s }, p6/Z, [x28, #7, MUL VL]\n" + "addvl x28, x28, #12\n" + "ld1w { z4.s }, p6/Z, [x26]\n" + "ld1w { z10.s }, p6/Z, [x26, #1, MUL VL]\n" + "zip1 z14.s, z22.s, z4.s\n" + "zip2 z22.s, z22.s, z4.s\n" + "ld1w { z28.s }, p6/Z, [x26, #2, MUL VL]\n" + "ld1w { z27.s }, p6/Z, [x26, #3, MUL VL]\n" + "zip1 z24.s, z7.s, z10.s\n" + "zip2 z15.s, z7.s, z10.s\n" + "ld1w { z7.s }, p6/Z, [x26, #4, MUL VL]\n" + "ld1w { z2.s }, p6/Z, [x26, #5, MUL VL]\n" + "zip1 z9.s, z19.s, z28.s\n" + "zip2 z0.s, z19.s, z28.s\n" + "ld1w { z19.s }, p6/Z, [x26, #6, MUL VL]\n" + "ld1w { z16.s }, p6/Z, [x26, #7, MUL VL]\n" + "addvl x26, x26, #12\n" + "zip1 z1.s, z18.s, z27.s\n" + "ld1w { z30.s }, p6/Z, [x28, #-4, MUL VL]\n" + "ld1w { z29.s }, p6/Z, [x28, #-3, MUL VL]\n" + "zip2 z17.s, z18.s, z27.s\n" + ".inst 0x658ab9d5 // bfcvt z21.h, p6/M, z14.s\n" + "ld1w { z31.s }, p6/Z, [x27]\n" + "ld1w { z8.s }, p6/Z, [x27, #1, MUL VL]\n" + ".inst 0x658abacc // bfcvt z12.h, p6/M, z22.s\n" + ".inst 0x658abb0e // bfcvt z14.h, p6/M, z24.s\n" + "ld1w { z22.s }, p6/Z, [x27, #2, MUL VL]\n" + "ld1w { z28.s }, p6/Z, [x27, #3, MUL VL]\n" + ".inst 0x658ab9ea // bfcvt z10.h, p6/M, z15.s\n" + ".inst 0x658ab92f // bfcvt z15.h, p6/M, z9.s\n" + "ld1w { z27.s }, p6/Z, [x27, #4, MUL VL]\n" + "ld1w { z13.s }, p6/Z, [x27, #5, MUL VL]\n" + ".inst 0x658ab803 // bfcvt z3.h, p6/M, z0.s\n" + ".inst 0x658ab832 // bfcvt z18.h, p6/M, z1.s\n" + "ld1w { z26.s }, p6/Z, [x27, #6, MUL VL]\n" + "ld1w { z9.s }, p6/Z, [x27, #7, MUL VL]\n" + "addvl x27, x27, #12\n" + ".inst 0x658aba26 // bfcvt z6.h, p6/M, z17.s\n" + "ld1w { z1.s }, p6/Z, [x26, #-4, MUL VL]\n" + "ld1w { z0.s }, p6/Z, [x26, #-3, MUL VL]\n" + "zip1 z17.s, z5.s, z7.s\n" + "zip2 z5.s, z5.s, z7.s\n" + "ld1w { z24.s }, p6/Z, [x23]\n" + "ld1w { z11.s }, p6/Z, [x23, #1, MUL VL]\n" + "zip1 z7.s, z31.s, z24.s\n" + "zip2 z31.s, z31.s, z24.s\n" + "ld1w { z4.s }, p6/Z, [x23, #2, MUL VL]\n" + "ld1w { z24.s }, p6/Z, [x23, #3, MUL VL]\n" + ".inst 0x648ab8f5 // bfcvtnt z21.h, p6/M, z7.s\n" + "zip1 z7.s, z8.s, z11.s\n" + "zip2 z11.s, z8.s, z11.s\n" + "ld1w { z8.s }, p6/Z, [x23, #4, MUL VL]\n" + ".inst 0x648abbec // bfcvtnt z12.h, p6/M, z31.s\n" + "ld1w { z31.s }, p6/Z, [x23, #5, MUL VL]\n" + ".inst 0x648ab8ee // bfcvtnt z14.h, p6/M, z7.s\n" + "ld1w { z7.s }, p6/Z, [x23, #6, MUL VL]\n" + ".inst 0x648ab96a // bfcvtnt z10.h, p6/M, z11.s\n" + "zip1 z11.s, z22.s, z4.s\n" + "zip2 z4.s, z22.s, z4.s\n" + "ld1w { z22.s }, p6/Z, [x23, #7, MUL VL]\n" + "addvl x23, x23, #12\n" + ".inst 0x648ab96f // bfcvtnt z15.h, p6/M, z11.s\n" + "ld1w { z11.s }, p6/Z, [x28, #-2, MUL VL]\n" + ".inst 0x648ab883 // bfcvtnt z3.h, p6/M, z4.s\n" + "zip1 z4.s, z28.s, z24.s\n" + "zip2 z24.s, z28.s, z24.s\n" + "ld1w { z28.s }, p6/Z, [x28, #-1, MUL VL]\n" + ".inst 0x648ab892 // bfcvtnt z18.h, p6/M, z4.s\n" + "ld1w { z4.s }, p6/Z, [x27, #-4, MUL VL]\n" + ".inst 0x648abb06 // bfcvtnt z6.h, p6/M, z24.s\n" + "zip1 z24.s, z25.s, z2.s\n" + "zip2 z25.s, z25.s, z2.s\n" + "zip1 z2.s, z20.s, z19.s\n" + "zip2 z20.s, z20.s, z19.s\n" + "zip1 z19.s, z23.s, z16.s\n" + "zip2 z16.s, z23.s, z16.s\n" + "zip1 z23.s, z30.s, z1.s\n" + "zip2 z30.s, z30.s, z1.s\n" + "zip1 z1.s, z29.s, z0.s\n" + "zip2 z0.s, z29.s, z0.s\n" + ".inst 0x658aba31 // bfcvt z17.h, p6/M, z17.s\n" + "zip1 z29.s, z27.s, z8.s\n" + ".inst 0x658ab8a5 // bfcvt z5.h, p6/M, z5.s\n" + "zip2 z27.s, z27.s, z8.s\n" + "ld1w { z8.s }, p6/Z, [x27, #-3, MUL VL]\n" + ".inst 0x658abb18 // bfcvt z24.h, p6/M, z24.s\n" + ".inst 0x658abb39 // bfcvt z25.h, p6/M, z25.s\n" + ".inst 0x658ab842 // bfcvt z2.h, p6/M, z2.s\n" + ".inst 0x658aba94 // bfcvt z20.h, p6/M, z20.s\n" + ".inst 0x658aba73 // bfcvt z19.h, p6/M, z19.s\n" + ".inst 0x658aba10 // bfcvt z16.h, p6/M, z16.s\n" + ".inst 0x658abaf7 // bfcvt z23.h, p6/M, z23.s\n" + ".inst 0x658abbde // bfcvt z30.h, p6/M, z30.s\n" + ".inst 0x658ab821 // bfcvt z1.h, p6/M, z1.s\n" + ".inst 0x658ab800 // bfcvt z0.h, p6/M, z0.s\n" + ".inst 0x648abbb1 // bfcvtnt z17.h, p6/M, z29.s\n" + "ld1w { z29.s }, p6/Z, [x26, #-2, MUL VL]\n" + ".inst 0x648abb65 // bfcvtnt z5.h, p6/M, z27.s\n" + "zip1 z27.s, z13.s, z31.s\n" + "zip2 z31.s, z13.s, z31.s\n" + "ld1w { z13.s }, p6/Z, [x26, #-1, MUL VL]\n" + ".inst 0x648abb78 // bfcvtnt z24.h, p6/M, z27.s\n" + "ld1w { z27.s }, p6/Z, [x23, #-4, MUL VL]\n" + ".inst 0x648abbf9 // bfcvtnt z25.h, p6/M, z31.s\n" + "zip1 z31.s, z26.s, z7.s\n" + "zip2 z26.s, z26.s, z7.s\n" + "ld1w { z7.s }, p6/Z, [x23, #-3, MUL VL]\n" + ".inst 0x648abbe2 // bfcvtnt z2.h, p6/M, z31.s\n" + "ld1w { z31.s }, p6/Z, [x27, #-2, MUL VL]\n" + ".inst 0x648abb54 // bfcvtnt z20.h, p6/M, z26.s\n" + "zip1 z26.s, z9.s, z22.s\n" + "zip2 z9.s, z9.s, z22.s\n" + "ld1w { z22.s }, p6/Z, [x27, #-1, MUL VL]\n" + ".inst 0x648abb53 // bfcvtnt z19.h, p6/M, z26.s\n" + "ld1w { z26.s }, p6/Z, [x23, #-2, MUL VL]\n" + ".inst 0x648ab930 // bfcvtnt z16.h, p6/M, z9.s\n" + "ld1w { z9.s }, p6/Z, [x23, #-1, MUL VL]\n" + "st1h { z21.h }, p6, [x21]\n" + "zip1 z21.s, z4.s, z27.s\n" + "zip2 z27.s, z4.s, z27.s\n" + "zip1 z4.s, z8.s, z7.s\n" + "zip2 z8.s, z8.s, z7.s\n" + "st1h { z12.h }, p6, [x21, #1, MUL VL]\n" + "zip1 z7.s, z11.s, z29.s\n" + "zip2 z11.s, z11.s, z29.s\n" + "st1h { z14.h }, p6, [x21, #2, MUL VL]\n" + "zip1 z29.s, z28.s, z13.s\n" + "zip2 z12.s, z28.s, z13.s\n" + "st1h { z10.h }, p6, [x21, #3, MUL VL]\n" + "st1h { z15.h }, p6, [x21, #4, MUL VL]\n" + ".inst 0x648abab7 // bfcvtnt z23.h, p6/M, z21.s\n" + ".inst 0x648abb7e // bfcvtnt z30.h, p6/M, z27.s\n" + "st1h { z3.h }, p6, [x21, #5, MUL VL]\n" + ".inst 0x648ab881 // bfcvtnt z1.h, p6/M, z4.s\n" + ".inst 0x648ab900 // bfcvtnt z0.h, p6/M, z8.s\n" + "st1h { z18.h }, p6, [x21, #6, MUL VL]\n" + ".inst 0x658ab8e8 // bfcvt z8.h, p6/M, z7.s\n" + "zip1 z27.s, z31.s, z26.s\n" + "st1h { z6.h }, p6, [x21, #7, MUL VL]\n" + "addvl x21, x21, #12\n" + ".inst 0x658ab96e // bfcvt z14.h, p6/M, z11.s\n" + "zip2 z28.s, z31.s, z26.s\n" + ".inst 0x658abbbd // bfcvt z29.h, p6/M, z29.s\n" + "zip1 z21.s, z22.s, z9.s\n" + "st1h { z17.h }, p6, [x21, #-4, MUL VL]\n" + ".inst 0x658ab992 // bfcvt z18.h, p6/M, z12.s\n" + "zip2 z17.s, z22.s, z9.s\n" + "st1h { z5.h }, p6, [x21, #-3, MUL VL]\n" + "st1h { z24.h }, p6, [x21, #-2, MUL VL]\n" + ".inst 0x648abb68 // bfcvtnt z8.h, p6/M, z27.s\n" + ".inst 0x648abb8e // bfcvtnt z14.h, p6/M, z28.s\n" + "st1h { z25.h }, p6, [x21, #-1, MUL VL]\n" + ".inst 0x648ababd // bfcvtnt z29.h, p6/M, z21.s\n" + ".inst 0x648aba32 // bfcvtnt z18.h, p6/M, z17.s\n" + "st1h { z2.h }, p6, [x20]\n" + "st1h { z20.h }, p6, [x20, #1, MUL VL]\n" + "st1h { z19.h }, p6, [x20, #2, MUL VL]\n" + "st1h { z16.h }, p6, [x20, #3, MUL VL]\n" + "st1h { z23.h }, p6, [x20, #4, MUL VL]\n" + "st1h { z30.h }, p6, [x20, #5, MUL VL]\n" + "st1h { z1.h }, p6, [x20, #6, MUL VL]\n" + "st1h { z0.h }, p6, [x20, #7, MUL VL]\n" + "addvl x20, x20, #12\n" + "st1h { z8.h }, p6, [x20, #-4, MUL VL]\n" + "st1h { z14.h }, p6, [x20, #-3, MUL VL]\n" + "st1h { z29.h }, p6, [x20, #-2, MUL VL]\n" + "st1h { z18.h }, p6, [x20, #-1, MUL VL]\n" + "bge 2b\n" + "3:" // Main row loop: Unroll column loop skip + "cbz x25, 5f\n" + "4:" // Main row loop: Column loop + "mov x20, x25\n" + "whilelt p5.s, XZR, x20\n" + "ld1w { z22.s }, p5/Z, [x28]\n" + "ld1w { z21.s }, p5/Z, [x26]\n" + "decw x20\n" + "whilelt p4.s, XZR, x20\n" + "ld1w { z20.s }, p4/Z, [x28, #1, MUL VL]\n" + "ld1w { z19.s }, p4/Z, [x26, #1, MUL VL]\n" + "decw x20\n" + "whilelt p3.s, XZR, x20\n" + "ld1w { z18.s }, p3/Z, [x28, #2, MUL VL]\n" + "ld1w { z17.s }, p3/Z, [x26, #2, MUL VL]\n" + "decw x20\n" + "whilelt p2.s, XZR, x20\n" + "ld1w { z30.s }, p2/Z, [x28, #3, MUL VL]\n" + "ld1w { z16.s }, p2/Z, [x26, #3, MUL VL]\n" + "decw x20\n" + "whilelt p1.s, XZR, x20\n" + "ld1w { z13.s }, p1/Z, [x28, #4, MUL VL]\n" + "ld1w { z29.s }, p5/Z, [x27]\n" + "decw x20\n" + "whilelt p0.s, XZR, x20\n" + "ld1w { z12.s }, p0/Z, [x28, #5, MUL VL]\n" + "ld1w { z28.s }, p4/Z, [x27, #1, MUL VL]\n" + "ld1w { z11.s }, p3/Z, [x27, #2, MUL VL]\n" + "ld1w { z10.s }, p2/Z, [x27, #3, MUL VL]\n" + "zip1 z27.s, z22.s, z21.s\n" + "zip2 z26.s, z22.s, z21.s\n" + "ld1w { z9.s }, p1/Z, [x26, #4, MUL VL]\n" + "ld1w { z8.s }, p0/Z, [x26, #5, MUL VL]\n" + "zip1 z25.s, z20.s, z19.s\n" + "zip2 z24.s, z20.s, z19.s\n" + "ld1w { z23.s }, p5/Z, [x23]\n" + "ld1w { z22.s }, p4/Z, [x23, #1, MUL VL]\n" + "zip1 z21.s, z18.s, z17.s\n" + "zip2 z20.s, z18.s, z17.s\n" + "ld1w { z19.s }, p3/Z, [x23, #2, MUL VL]\n" + "ld1w { z18.s }, p2/Z, [x23, #3, MUL VL]\n" + "zip1 z17.s, z30.s, z16.s\n" + "zip2 z16.s, z30.s, z16.s\n" + "ld1w { z7.s }, p1/Z, [x27, #4, MUL VL]\n" + "ld1w { z6.s }, p0/Z, [x27, #5, MUL VL]\n" + ".inst 0x658abb65 // bfcvt z5.h, p6/M, z27.s\n" + "zip1 z4.s, z29.s, z23.s\n" + "ld1w { z3.s }, p1/Z, [x23, #4, MUL VL]\n" + "ld1w { z2.s }, p0/Z, [x23, #5, MUL VL]\n" + ".inst 0x658abb41 // bfcvt z1.h, p6/M, z26.s\n" + "zip2 z0.s, z29.s, z23.s\n" + ".inst 0x658abb3f // bfcvt z31.h, p6/M, z25.s\n" + "zip1 z30.s, z28.s, z22.s\n" + "mov x20, x22\n" + "decd x25, ALL, MUL #12\n" + ".inst 0x658abb1d // bfcvt z29.h, p6/M, z24.s\n" + "zip2 z28.s, z28.s, z22.s\n" + "cmp x25, #0x0\n" + "addvl x28, x28, #6\n" + ".inst 0x658ababb // bfcvt z27.h, p6/M, z21.s\n" + "zip1 z23.s, z11.s, z19.s\n" + "addvl x27, x27, #6\n" + "addvl x26, x26, #6\n" + ".inst 0x658aba9a // bfcvt z26.h, p6/M, z20.s\n" + "zip2 z22.s, z11.s, z19.s\n" + "addvl x23, x23, #6\n" + "add x22, x22, %x[out_stride]\n" + ".inst 0x658aba39 // bfcvt z25.h, p6/M, z17.s\n" + "zip1 z21.s, z10.s, z18.s\n" + ".inst 0x658aba18 // bfcvt z24.h, p6/M, z16.s\n" + "zip2 z20.s, z10.s, z18.s\n" + "zip1 z19.s, z13.s, z9.s\n" + "zip2 z18.s, z13.s, z9.s\n" + "zip1 z17.s, z12.s, z8.s\n" + "zip2 z16.s, z12.s, z8.s\n" + ".inst 0x648ab885 // bfcvtnt z5.h, p6/M, z4.s\n" + ".inst 0x648ab801 // bfcvtnt z1.h, p6/M, z0.s\n" + "st1h { z5.h }, p6, [x20]\n" + ".inst 0x648abbdf // bfcvtnt z31.h, p6/M, z30.s\n" + ".inst 0x648abb9d // bfcvtnt z29.h, p6/M, z28.s\n" + "st1h { z1.h }, p6, [x20, #1, MUL VL]\n" + ".inst 0x648abafb // bfcvtnt z27.h, p6/M, z23.s\n" + ".inst 0x648abada // bfcvtnt z26.h, p6/M, z22.s\n" + "st1h { z31.h }, p6, [x20, #2, MUL VL]\n" + ".inst 0x648abab9 // bfcvtnt z25.h, p6/M, z21.s\n" + ".inst 0x648aba98 // bfcvtnt z24.h, p6/M, z20.s\n" + "st1h { z29.h }, p6, [x20, #3, MUL VL]\n" + ".inst 0x658aba77 // bfcvt z23.h, p6/M, z19.s\n" + "zip1 z22.s, z7.s, z3.s\n" + "st1h { z27.h }, p6, [x20, #4, MUL VL]\n" + ".inst 0x658aba55 // bfcvt z21.h, p6/M, z18.s\n" + "zip2 z20.s, z7.s, z3.s\n" + "st1h { z26.h }, p6, [x20, #5, MUL VL]\n" + ".inst 0x658aba33 // bfcvt z19.h, p6/M, z17.s\n" + "zip1 z18.s, z6.s, z2.s\n" + "st1h { z25.h }, p6, [x20, #6, MUL VL]\n" + ".inst 0x658aba11 // bfcvt z17.h, p6/M, z16.s\n" + "zip2 z16.s, z6.s, z2.s\n" + "st1h { z24.h }, p6, [x20, #7, MUL VL]\n" + "addvl x20, x20, #12\n" + ".inst 0x648abad7 // bfcvtnt z23.h, p6/M, z22.s\n" + ".inst 0x648aba95 // bfcvtnt z21.h, p6/M, z20.s\n" + "st1h { z23.h }, p6, [x20, #-4, MUL VL]\n" + ".inst 0x648aba53 // bfcvtnt z19.h, p6/M, z18.s\n" + ".inst 0x648aba11 // bfcvtnt z17.h, p6/M, z16.s\n" + "st1h { z21.h }, p6, [x20, #-3, MUL VL]\n" + "st1h { z19.h }, p6, [x20, #-2, MUL VL]\n" + "st1h { z17.h }, p6, [x20, #-1, MUL VL]\n" + "bgt 4b\n" + "5:" // Main row loop: Column loop skip + "cmp %x[height], #0x1\n" + "addvl %x[out], %x[out], #12\n" + "bge 1b\n" + : [height] "+&r" (height), [in] "+&r" (in), [out] "+&r" (out) + : [in_stride] "r" (in_stride), [out_stride] "r" (out_stride), [pad_row] "r" (pad_row), [width] "r" (width) + : "cc", "memory", "p0", "p1", "p2", "p3", "p4", "p5", "p6", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27", "x28", "z0", "z1", "z2", "z3", "z4", "z5", "z6", "z7", "z8", "z9", "z10", "z11", "z12", "z13", "z14", "z15", "z16", "z17", "z18", "z19", "z20", "z21", "z22", "z23", "z24", "z25", "z26", "z27", "z28", "z29", "z30", "z31" + ); +} + +} // anonymous namespace +template<> +void Transform<12, 4, true, VLType::SVE>( + bfloat16 *out, const float *in, int stride, int x0, int xmax, int k0, int kmax) +{ + sve_transpose_interleave_12VL_2x4_fp32bf16( + out, + in + k0 * stride + x0, + (xmax-x0), + stride * sizeof(float), + (kmax-k0) + ); +} + + +#endif // defined(ARM_COMPUTE_ENABLE_SVE) diff --git a/src/core/NEON/kernels/arm_gemm/transforms/sve_transpose_interleave_1VL.hpp b/src/core/NEON/kernels/arm_gemm/transforms/sve_transpose_interleave_1VL.hpp new file mode 100644 index 0000000000..b33c4f6c2d --- /dev/null +++ b/src/core/NEON/kernels/arm_gemm/transforms/sve_transpose_interleave_1VL.hpp @@ -0,0 +1,161 @@ +/* + * Copyright (c) 2021, 2023 Arm Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ + +#pragma once + +#if defined(ARM_COMPUTE_ENABLE_SVE) + +namespace { + +void sve_transpose_interleave_1VL(uint32_t *out, const uint32_t *in, size_t width, size_t in_stride, size_t height) +{ + size_t out_stride = 1 * height * get_vector_length<uint8_t>(); + + __asm__ __volatile__( + "cmp %x[height], #0x4\n" + "ptrue p1.b\n" + "blt 6f\n" + "1:" // Main row loop: Head + "mov x26, %x[in]\n" + "mov x25, %x[width]\n" + "cntw x24, ALL, MUL #2\n" + "add x23, x26, %x[in_stride]\n" + "add x21, x23, %x[in_stride]\n" + "add x20, x21, %x[in_stride]\n" + "cmp x25, x24\n" + "add %x[in], x20, %x[in_stride]\n" + "mov x22, %x[out]\n" + "sub %x[height], %x[height], #0x4\n" + "blt 3f\n" + "2:" // Main row loop: Unroll column loop + "sub x25, x25, x24\n" + "ld1w { z23.s }, p1/Z, [x26]\n" + "ld1w { z22.s }, p1/Z, [x26, #1, MUL VL]\n" + "cmp x25, x24\n" + "ld1w { z21.s }, p1/Z, [x23]\n" + "ld1w { z20.s }, p1/Z, [x23, #1, MUL VL]\n" + "addvl x26, x26, #2\n" + "addvl x23, x23, #2\n" + "ld1w { z19.s }, p1/Z, [x21]\n" + "ld1w { z18.s }, p1/Z, [x21, #1, MUL VL]\n" + "addvl x21, x21, #2\n" + "ld1w { z17.s }, p1/Z, [x20]\n" + "ld1w { z16.s }, p1/Z, [x20, #1, MUL VL]\n" + "st1w { z23.s }, p1, [x22]\n" + "addvl x20, x20, #2\n" + "st1w { z21.s }, p1, [x22, #1, MUL VL]\n" + "st1w { z19.s }, p1, [x22, #2, MUL VL]\n" + "st1w { z17.s }, p1, [x22, #3, MUL VL]\n" + "add x22, x22, %x[out_stride]\n" + "st1w { z22.s }, p1, [x22]\n" + "st1w { z20.s }, p1, [x22, #1, MUL VL]\n" + "st1w { z18.s }, p1, [x22, #2, MUL VL]\n" + "st1w { z16.s }, p1, [x22, #3, MUL VL]\n" + "add x22, x22, %x[out_stride]\n" + "bge 2b\n" + "3:" // Main row loop: Unroll column loop skip + "cbz x25, 5f\n" + "4:" // Main row loop: Column loop + "whilelt p0.s, XZR, x25\n" + "decw x25\n" + "ld1w { z19.s }, p0/Z, [x26]\n" + "ld1w { z18.s }, p0/Z, [x23]\n" + "cmp x25, #0x0\n" + "addvl x26, x26, #1\n" + "ld1w { z17.s }, p0/Z, [x21]\n" + "ld1w { z16.s }, p0/Z, [x20]\n" + "addvl x23, x23, #1\n" + "addvl x21, x21, #1\n" + "st1w { z19.s }, p1, [x22]\n" + "addvl x20, x20, #1\n" + "st1w { z18.s }, p1, [x22, #1, MUL VL]\n" + "st1w { z17.s }, p1, [x22, #2, MUL VL]\n" + "st1w { z16.s }, p1, [x22, #3, MUL VL]\n" + "add x22, x22, %x[out_stride]\n" + "bgt 4b\n" + "5:" // Main row loop: Column loop skip + "cmp %x[height], #0x4\n" + "addvl %x[out], %x[out], #4\n" + "bge 1b\n" + "cbz %x[height], 12f\n" + "6:" // Main loop skip + "7:" // Tail row loop: Head + "mov x21, %x[width]\n" + "cntw x20, ALL, MUL #2\n" + "mov x26, %x[in]\n" + "cmp x21, x20\n" + "add %x[in], x26, %x[in_stride]\n" + "mov x22, %x[out]\n" + "sub %x[height], %x[height], #0x1\n" + "blt 9f\n" + "8:" // Tail row loop: Unroll column loop + "sub x21, x21, x20\n" + "ld1w { z17.s }, p1/Z, [x26]\n" + "ld1w { z16.s }, p1/Z, [x26, #1, MUL VL]\n" + "st1w { z17.s }, p1, [x22]\n" + "add x22, x22, %x[out_stride]\n" + "cmp x21, x20\n" + "st1w { z16.s }, p1, [x22]\n" + "addvl x26, x26, #2\n" + "add x22, x22, %x[out_stride]\n" + "bge 8b\n" + "9:" // Tail row loop: Unroll column loop skip + "cbz x21, 11f\n" + "10:" // Tail row loop: Column loop + "whilelt p0.s, XZR, x21\n" + "decw x21\n" + "ld1w { z16.s }, p0/Z, [x26]\n" + "st1w { z16.s }, p1, [x22]\n" + "cmp x21, #0x0\n" + "addvl x26, x26, #1\n" + "add x22, x22, %x[out_stride]\n" + "bgt 10b\n" + "11:" // Tail row loop: Column loop skip + "cmp %x[height], #0x1\n" + "addvl %x[out], %x[out], #1\n" + "bge 7b\n" + "12:" // Done + : [height] "+&r" (height), [in] "+&r" (in), [out] "+&r" (out) + : [in_stride] "r" (in_stride), [out_stride] "r" (out_stride), [width] "r" (width) + : "cc", "memory", "p0", "p1", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "z16", "z17", "z18", "z19", "z20", "z21", "z22", "z23" + ); +} + +} // anonymous namespace + +template<> +void Transform<1, 1, true, VLType::SVE>( + float *out, const float *in, int stride, int x0, int xmax, int k0, int kmax) +{ + sve_transpose_interleave_1VL( + reinterpret_cast<uint32_t *>(out), + reinterpret_cast<const uint32_t *>(in + k0 * stride + x0), + (xmax-x0) * sizeof(float) / 4, + stride * sizeof(float), + (kmax-k0) + ); +} + + +#endif // defined(ARM_COMPUTE_ENABLE_SVE) diff --git a/src/core/NEON/kernels/arm_gemm/transforms/sve_transpose_interleave_1VL_1x4.hpp b/src/core/NEON/kernels/arm_gemm/transforms/sve_transpose_interleave_1VL_1x4.hpp new file mode 100644 index 0000000000..e468787815 --- /dev/null +++ b/src/core/NEON/kernels/arm_gemm/transforms/sve_transpose_interleave_1VL_1x4.hpp @@ -0,0 +1,308 @@ +/* + * Copyright (c) 2021, 2023 Arm Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ + +#pragma once + +#if defined(ARM_COMPUTE_ENABLE_SVE) + +namespace { + +void sve_transpose_interleave_1VL_1x4(uint8_t *out, const uint8_t *in, size_t width, size_t in_stride, size_t height) +{ + uint8_t *pad_row = reinterpret_cast<uint8_t *>(alloca(width * sizeof(uint8_t))); + + if (height % 4) { + memset(pad_row, 0, width * sizeof(uint8_t)); + } + + size_t out_stride = 1 * roundup<size_t>(height, 4) * get_vector_length<uint32_t>(); + + __asm__ __volatile__( + "cmp %x[height], #0x8\n" + "ptrue p1.b\n" + "blt 6f\n" + "1:" // Main row loop: Head + "mov x10, %x[in]\n" + "add x9, x10, %x[in_stride]\n" + "add x28, x9, %x[in_stride]\n" + "add x27, x28, %x[in_stride]\n" + "add x26, x27, %x[in_stride]\n" + "mov x25, %x[width]\n" + "cntb x24, ALL, MUL #2\n" + "add x23, x26, %x[in_stride]\n" + "add x21, x23, %x[in_stride]\n" + "add x20, x21, %x[in_stride]\n" + "cmp x25, x24\n" + "add %x[in], x20, %x[in_stride]\n" + "mov x22, %x[out]\n" + "sub %x[height], %x[height], #0x8\n" + "blt 3f\n" + "2:" // Main row loop: Unroll column loop + "ld1b { z20.b }, p1/Z, [x10]\n" + "ld1b { z18.b }, p1/Z, [x9]\n" + "sub x25, x25, x24\n" + "cmp x25, x24\n" + "ld1b { z17.b }, p1/Z, [x28]\n" + "ld1b { z16.b }, p1/Z, [x27]\n" + "zip1 z25.b, z20.b, z17.b\n" + "zip1 z24.b, z18.b, z16.b\n" + "ld1b { z21.b }, p1/Z, [x26]\n" + "ld1b { z19.b }, p1/Z, [x23]\n" + "zip2 z2.b, z20.b, z17.b\n" + "zip2 z1.b, z18.b, z16.b\n" + "ld1b { z18.b }, p1/Z, [x21]\n" + "ld1b { z17.b }, p1/Z, [x20]\n" + "zip1 z20.b, z21.b, z18.b\n" + "zip1 z16.b, z19.b, z17.b\n" + "ld1b { z0.b }, p1/Z, [x10, #1, MUL VL]\n" + "ld1b { z31.b }, p1/Z, [x9, #1, MUL VL]\n" + "zip2 z30.b, z21.b, z18.b\n" + "zip2 z29.b, z19.b, z17.b\n" + "ld1b { z23.b }, p1/Z, [x28, #1, MUL VL]\n" + "ld1b { z22.b }, p1/Z, [x27, #1, MUL VL]\n" + "zip1 z19.b, z25.b, z24.b\n" + "zip1 z18.b, z20.b, z16.b\n" + "ld1b { z28.b }, p1/Z, [x26, #1, MUL VL]\n" + "ld1b { z27.b }, p1/Z, [x23, #1, MUL VL]\n" + "zip2 z17.b, z25.b, z24.b\n" + "zip2 z16.b, z20.b, z16.b\n" + "ld1b { z21.b }, p1/Z, [x21, #1, MUL VL]\n" + "ld1b { z20.b }, p1/Z, [x20, #1, MUL VL]\n" + "st1b { z19.b }, p1, [x22]\n" + "zip1 z26.b, z0.b, z23.b\n" + "st1b { z18.b }, p1, [x22, #1, MUL VL]\n" + "add x22, x22, %x[out_stride]\n" + "zip1 z25.b, z31.b, z22.b\n" + "zip1 z24.b, z28.b, z21.b\n" + "st1b { z17.b }, p1, [x22]\n" + "zip1 z19.b, z27.b, z20.b\n" + "zip1 z17.b, z2.b, z1.b\n" + "addvl x10, x10, #2\n" + "st1b { z16.b }, p1, [x22, #1, MUL VL]\n" + "add x22, x22, %x[out_stride]\n" + "zip1 z18.b, z30.b, z29.b\n" + "zip2 z16.b, z2.b, z1.b\n" + "st1b { z17.b }, p1, [x22]\n" + "zip2 z17.b, z30.b, z29.b\n" + "zip2 z23.b, z0.b, z23.b\n" + "addvl x9, x9, #2\n" + "st1b { z18.b }, p1, [x22, #1, MUL VL]\n" + "add x22, x22, %x[out_stride]\n" + "zip2 z22.b, z31.b, z22.b\n" + "zip2 z21.b, z28.b, z21.b\n" + "st1b { z16.b }, p1, [x22]\n" + "zip2 z20.b, z27.b, z20.b\n" + "zip1 z16.b, z26.b, z25.b\n" + "addvl x28, x28, #2\n" + "st1b { z17.b }, p1, [x22, #1, MUL VL]\n" + "add x22, x22, %x[out_stride]\n" + "zip1 z18.b, z24.b, z19.b\n" + "zip2 z17.b, z26.b, z25.b\n" + "st1b { z16.b }, p1, [x22]\n" + "zip2 z16.b, z24.b, z19.b\n" + "zip1 z19.b, z23.b, z22.b\n" + "addvl x27, x27, #2\n" + "st1b { z18.b }, p1, [x22, #1, MUL VL]\n" + "add x22, x22, %x[out_stride]\n" + "zip1 z18.b, z21.b, z20.b\n" + "addvl x26, x26, #2\n" + "st1b { z17.b }, p1, [x22]\n" + "addvl x23, x23, #2\n" + "addvl x21, x21, #2\n" + "zip2 z17.b, z23.b, z22.b\n" + "st1b { z16.b }, p1, [x22, #1, MUL VL]\n" + "add x22, x22, %x[out_stride]\n" + "addvl x20, x20, #2\n" + "zip2 z16.b, z21.b, z20.b\n" + "st1b { z19.b }, p1, [x22]\n" + "st1b { z18.b }, p1, [x22, #1, MUL VL]\n" + "add x22, x22, %x[out_stride]\n" + "st1b { z17.b }, p1, [x22]\n" + "st1b { z16.b }, p1, [x22, #1, MUL VL]\n" + "add x22, x22, %x[out_stride]\n" + "bge 2b\n" + "3:" // Main row loop: Unroll column loop skip + "cbz x25, 5f\n" + "4:" // Main row loop: Column loop + "whilelt p0.b, XZR, x25\n" + "ld1b { z19.b }, p0/Z, [x10]\n" + "ld1b { z18.b }, p0/Z, [x9]\n" + "decw x25\n" + "ld1b { z17.b }, p0/Z, [x28]\n" + "ld1b { z16.b }, p0/Z, [x27]\n" + "zip1 z21.b, z19.b, z17.b\n" + "zip1 z20.b, z18.b, z16.b\n" + "ld1b { z18.b }, p0/Z, [x26]\n" + "ld1b { z19.b }, p0/Z, [x23]\n" + "cmp x25, #0x0\n" + "incd x10, ALL, MUL #2\n" + "ld1b { z17.b }, p0/Z, [x21]\n" + "ld1b { z16.b }, p0/Z, [x20]\n" + "zip1 z18.b, z18.b, z17.b\n" + "zip1 z16.b, z19.b, z16.b\n" + "incd x9, ALL, MUL #2\n" + "incd x28, ALL, MUL #2\n" + "zip1 z17.b, z21.b, z20.b\n" + "zip1 z16.b, z18.b, z16.b\n" + "incd x27, ALL, MUL #2\n" + "incd x26, ALL, MUL #2\n" + "st1b { z17.b }, p1, [x22]\n" + "incd x23, ALL, MUL #2\n" + "incd x21, ALL, MUL #2\n" + "st1b { z16.b }, p1, [x22, #1, MUL VL]\n" + "incd x20, ALL, MUL #2\n" + "add x22, x22, %x[out_stride]\n" + "bgt 4b\n" + "5:" // Main row loop: Column loop skip + "cmp %x[height], #0x8\n" + "addvl %x[out], %x[out], #2\n" + "bge 1b\n" + "cbz %x[height], 12f\n" + "6:" // Main loop skip + "7:" // Tail row loop: Head + "mov x10, %x[in]\n" + "add x9, x10, %x[in_stride]\n" + "add x28, x9, %x[in_stride]\n" + "mov x21, %x[width]\n" + "cntb x20, ALL, MUL #2\n" + "add x27, x28, %x[in_stride]\n" + "cmp %x[height], #0x3\n" + "add %x[in], x27, %x[in_stride]\n" + "csel x27, x27, %x[pad_row], GT\n" + "csel x28, x28, %x[pad_row], GE\n" + "cmp %x[height], #0x1\n" + "csel x9, x9, %x[pad_row], GT\n" + "cmp x21, x20\n" + "mov x22, %x[out]\n" + "sub %x[height], %x[height], #0x4\n" + "blt 9f\n" + "8:" // Tail row loop: Unroll column loop + "ld1b { z21.b }, p1/Z, [x10]\n" + "ld1b { z18.b }, p1/Z, [x9]\n" + "sub x21, x21, x20\n" + "cmp x21, x20\n" + "ld1b { z17.b }, p1/Z, [x28]\n" + "ld1b { z16.b }, p1/Z, [x27]\n" + "zip1 z20.b, z21.b, z17.b\n" + "zip1 z19.b, z18.b, z16.b\n" + "ld1b { z24.b }, p1/Z, [x10, #1, MUL VL]\n" + "ld1b { z23.b }, p1/Z, [x9, #1, MUL VL]\n" + "zip2 z22.b, z21.b, z17.b\n" + "zip2 z21.b, z18.b, z16.b\n" + "ld1b { z18.b }, p1/Z, [x28, #1, MUL VL]\n" + "ld1b { z17.b }, p1/Z, [x27, #1, MUL VL]\n" + "zip1 z16.b, z20.b, z19.b\n" + "st1b { z16.b }, p1, [x22]\n" + "add x22, x22, %x[out_stride]\n" + "zip2 z16.b, z20.b, z19.b\n" + "st1b { z16.b }, p1, [x22]\n" + "add x22, x22, %x[out_stride]\n" + "zip1 z20.b, z24.b, z18.b\n" + "zip1 z19.b, z23.b, z17.b\n" + "addvl x10, x10, #2\n" + "addvl x9, x9, #2\n" + "zip1 z16.b, z22.b, z21.b\n" + "st1b { z16.b }, p1, [x22]\n" + "add x22, x22, %x[out_stride]\n" + "zip2 z16.b, z22.b, z21.b\n" + "st1b { z16.b }, p1, [x22]\n" + "add x22, x22, %x[out_stride]\n" + "zip2 z18.b, z24.b, z18.b\n" + "zip2 z17.b, z23.b, z17.b\n" + "zip1 z16.b, z20.b, z19.b\n" + "st1b { z16.b }, p1, [x22]\n" + "add x22, x22, %x[out_stride]\n" + "zip2 z16.b, z20.b, z19.b\n" + "st1b { z16.b }, p1, [x22]\n" + "add x22, x22, %x[out_stride]\n" + "zip1 z16.b, z18.b, z17.b\n" + "addvl x28, x28, #2\n" + "st1b { z16.b }, p1, [x22]\n" + "add x22, x22, %x[out_stride]\n" + "addvl x27, x27, #2\n" + "zip2 z16.b, z18.b, z17.b\n" + "st1b { z16.b }, p1, [x22]\n" + "add x22, x22, %x[out_stride]\n" + "bge 8b\n" + "9:" // Tail row loop: Unroll column loop skip + "cbz x21, 11f\n" + "10:" // Tail row loop: Column loop + "whilelt p0.b, XZR, x21\n" + "ld1b { z19.b }, p0/Z, [x10]\n" + "ld1b { z18.b }, p0/Z, [x9]\n" + "decw x21\n" + "ld1b { z17.b }, p0/Z, [x28]\n" + "ld1b { z16.b }, p0/Z, [x27]\n" + "zip1 z17.b, z19.b, z17.b\n" + "zip1 z16.b, z18.b, z16.b\n" + "cmp x21, #0x0\n" + "incd x10, ALL, MUL #2\n" + "zip1 z16.b, z17.b, z16.b\n" + "st1b { z16.b }, p1, [x22]\n" + "incd x9, ALL, MUL #2\n" + "incd x28, ALL, MUL #2\n" + "incd x27, ALL, MUL #2\n" + "add x22, x22, %x[out_stride]\n" + "bgt 10b\n" + "11:" // Tail row loop: Column loop skip + "cmp %x[height], #0x1\n" + "addvl %x[out], %x[out], #1\n" + "bge 7b\n" + "12:" // Done + : [height] "+&r" (height), [in] "+&r" (in), [out] "+&r" (out) + : [in_stride] "r" (in_stride), [out_stride] "r" (out_stride), [pad_row] "r" (pad_row), [width] "r" (width) + : "cc", "memory", "p0", "p1", "x9", "x10", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27", "x28", "z0", "z1", "z2", "z16", "z17", "z18", "z19", "z20", "z21", "z22", "z23", "z24", "z25", "z26", "z27", "z28", "z29", "z30", "z31" + ); +} + +} // anonymous namespace + +template<> +void Transform<1, 4, true, VLType::SVE>( + uint8_t *out, const uint8_t *in, int stride, int x0, int xmax, int k0, int kmax) +{ + sve_transpose_interleave_1VL_1x4( + reinterpret_cast<uint8_t *>(out), + reinterpret_cast<const uint8_t *>(in + k0 * stride + x0), + (xmax-x0) * sizeof(uint8_t) / 1, + stride * sizeof(uint8_t), + (kmax-k0) + ); +} + +template<> +void Transform<1, 4, true, VLType::SVE>( + int8_t *out, const int8_t *in, int stride, int x0, int xmax, int k0, int kmax) +{ + sve_transpose_interleave_1VL_1x4( + reinterpret_cast<uint8_t *>(out), + reinterpret_cast<const uint8_t *>(in + k0 * stride + x0), + (xmax-x0) * sizeof(int8_t) / 1, + stride * sizeof(int8_t), + (kmax-k0) + ); +} + + +#endif // defined(ARM_COMPUTE_ENABLE_SVE) diff --git a/src/core/NEON/kernels/arm_gemm/transforms/sve_transpose_interleave_2VL_2x4_fp32bf16.hpp b/src/core/NEON/kernels/arm_gemm/transforms/sve_transpose_interleave_2VL_2x4_fp32bf16.hpp new file mode 100644 index 0000000000..f66fcdc994 --- /dev/null +++ b/src/core/NEON/kernels/arm_gemm/transforms/sve_transpose_interleave_2VL_2x4_fp32bf16.hpp @@ -0,0 +1,149 @@ +/* + * Copyright (c) 2024 Arm Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ + +#pragma once + +#if defined(ARM_COMPUTE_ENABLE_SVE) + +namespace { + +void sve_transpose_interleave_2VL_2x4_fp32bf16(bfloat16 *out, const float *in, size_t width, size_t in_stride, size_t height) +{ + float *pad_row = reinterpret_cast<float *>(alloca(width * sizeof(float))); + + if (height % 4) { + memset(pad_row, 0, width * sizeof(float)); + } + + size_t out_stride = 2 * roundup<size_t>(height, 4) * get_vector_length<uint32_t>(); + + __asm__ __volatile__( + "ptrue p1.b\n" + "1:" // Main row loop: Head + "mov x26, %x[in]\n" + "mov x25, %x[width]\n" + "cnth x24\n" + "cmp %x[height], #0x3\n" + "mov x23, %x[out]\n" + "add x22, x26, %x[in_stride]\n" + "add x21, x22, %x[in_stride]\n" + "add x20, x21, %x[in_stride]\n" + "add %x[in], x20, %x[in_stride]\n" + "csel x20, x20, %x[pad_row], GT\n" + "csel x21, x21, %x[pad_row], GE\n" + "cmp %x[height], #0x1\n" + "csel x22, x22, %x[pad_row], GT\n" + "cmp x25, x24\n" + "sub %x[height], %x[height], #0x4\n" + "blt 3f\n" + "2:" // Main row loop: Unroll column loop + "ld1w { z18.s }, p1/Z, [x26]\n" + "ld1w { z17.s }, p1/Z, [x21]\n" + "sub x25, x25, x24\n" + "ld1w { z21.s }, p1/Z, [x26, #1, MUL VL]\n" + "ld1w { z16.s }, p1/Z, [x21, #1, MUL VL]\n" + "cmp x25, x24\n" + "addvl x26, x26, #2\n" + "ld1w { z26.s }, p1/Z, [x22]\n" + "ld1w { z20.s }, p1/Z, [x20]\n" + "addvl x21, x21, #2\n" + "zip1 z19.s, z18.s, z17.s\n" + "zip2 z18.s, z18.s, z17.s\n" + "ld1w { z25.s }, p1/Z, [x22, #1, MUL VL]\n" + "ld1w { z24.s }, p1/Z, [x20, #1, MUL VL]\n" + "addvl x22, x22, #2\n" + "zip1 z17.s, z21.s, z16.s\n" + "zip2 z16.s, z21.s, z16.s\n" + "addvl x20, x20, #2\n" + ".inst 0x658aa677 // bfcvt z23.h, p1/M, z19.s\n" + "zip1 z22.s, z26.s, z20.s\n" + ".inst 0x658aa655 // bfcvt z21.h, p1/M, z18.s\n" + "zip2 z20.s, z26.s, z20.s\n" + ".inst 0x658aa633 // bfcvt z19.h, p1/M, z17.s\n" + "zip1 z18.s, z25.s, z24.s\n" + ".inst 0x658aa611 // bfcvt z17.h, p1/M, z16.s\n" + "zip2 z16.s, z25.s, z24.s\n" + ".inst 0x648aa6d7 // bfcvtnt z23.h, p1/M, z22.s\n" + ".inst 0x648aa695 // bfcvtnt z21.h, p1/M, z20.s\n" + ".inst 0x648aa653 // bfcvtnt z19.h, p1/M, z18.s\n" + ".inst 0x648aa611 // bfcvtnt z17.h, p1/M, z16.s\n" + "st1h { z23.h }, p1, [x23]\n" + "st1h { z21.h }, p1, [x23, #1, MUL VL]\n" + "add x23, x23, %x[out_stride]\n" + "st1h { z19.h }, p1, [x23]\n" + "st1h { z17.h }, p1, [x23, #1, MUL VL]\n" + "add x23, x23, %x[out_stride]\n" + "bge 2b\n" + "3:" // Main row loop: Unroll column loop skip + "cbz x25, 5f\n" + "4:" // Main row loop: Column loop + "whilelt p0.s, XZR, x25\n" + "decd x25, ALL, MUL #2\n" + "ld1w { z19.s }, p0/Z, [x26]\n" + "addvl x26, x26, #1\n" + "ld1w { z16.s }, p0/Z, [x21]\n" + "addvl x21, x21, #1\n" + "ld1w { z20.s }, p0/Z, [x22]\n" + "addvl x22, x22, #1\n" + "ld1w { z18.s }, p0/Z, [x20]\n" + "addvl x20, x20, #1\n" + "cmp x25, #0x0\n" + "zip1 z17.s, z19.s, z16.s\n" + "zip2 z16.s, z19.s, z16.s\n" + "zip1 z19.s, z20.s, z18.s\n" + "zip2 z18.s, z20.s, z18.s\n" + ".inst 0x658aa631 // bfcvt z17.h, p1/M, z17.s\n" + ".inst 0x658aa610 // bfcvt z16.h, p1/M, z16.s\n" + ".inst 0x648aa671 // bfcvtnt z17.h, p1/M, z19.s\n" + ".inst 0x648aa650 // bfcvtnt z16.h, p1/M, z18.s\n" + "st1h { z17.h }, p1, [x23]\n" + "st1h { z16.h }, p1, [x23, #1, MUL VL]\n" + "add x23, x23, %x[out_stride]\n" + "bgt 4b\n" + "5:" // Main row loop: Column loop skip + "cmp %x[height], #0x1\n" + "addvl %x[out], %x[out], #2\n" + "bge 1b\n" + : [height] "+&r" (height), [in] "+&r" (in), [out] "+&r" (out) + : [in_stride] "r" (in_stride), [out_stride] "r" (out_stride), [pad_row] "r" (pad_row), [width] "r" (width) + : "cc", "memory", "p0", "p1", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "z16", "z17", "z18", "z19", "z20", "z21", "z22", "z23", "z24", "z25", "z26" + ); +} + +} // anonymous namespace +template<> +void Transform<2, 4, true, VLType::SVE>( + bfloat16 *out, const float *in, int stride, int x0, int xmax, int k0, int kmax) +{ + sve_transpose_interleave_2VL_2x4_fp32bf16( + out, + in + k0 * stride + x0, + (xmax-x0), + stride * sizeof(float), + (kmax-k0) + ); +} + + +#endif // defined(ARM_COMPUTE_ENABLE_SVE) diff --git a/src/core/NEON/kernels/arm_gemm/transforms/sve_transpose_interleave_3VL.hpp b/src/core/NEON/kernels/arm_gemm/transforms/sve_transpose_interleave_3VL.hpp new file mode 100644 index 0000000000..546800fa69 --- /dev/null +++ b/src/core/NEON/kernels/arm_gemm/transforms/sve_transpose_interleave_3VL.hpp @@ -0,0 +1,172 @@ +/* + * Copyright (c) 2021, 2023 Arm Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ + +#pragma once + +#if defined(ARM_COMPUTE_ENABLE_SVE) + +namespace { + +void sve_transpose_interleave_3VL(uint16_t *out, const uint16_t *in, size_t width, size_t in_stride, size_t height) +{ + size_t out_stride = 3 * height * get_vector_length<uint8_t>(); + + __asm__ __volatile__( + "cmp %x[height], #0x4\n" + "ptrue p3.b\n" + "blt 4f\n" + "1:" // Main row loop: Head + "mov x26, %x[in]\n" + "add x25, x26, %x[in_stride]\n" + "add x24, x25, %x[in_stride]\n" + "add x23, x24, %x[in_stride]\n" + "add %x[in], x23, %x[in_stride]\n" + "mov x22, %x[out]\n" + "sub %x[height], %x[height], #0x4\n" + "mov x21, %x[width]\n" + "2:" // Main row loop: Column loop + "mov x20, x21\n" + "whilelt p2.h, XZR, x20\n" + "ld1h { z27.h }, p2/Z, [x26]\n" + "ld1h { z26.h }, p2/Z, [x25]\n" + "dech x20\n" + "whilelt p1.h, XZR, x20\n" + "ld1h { z25.h }, p1/Z, [x26, #1, MUL VL]\n" + "ld1h { z24.h }, p1/Z, [x25, #1, MUL VL]\n" + "dech x20\n" + "whilelt p0.h, XZR, x20\n" + "ld1h { z23.h }, p0/Z, [x26, #2, MUL VL]\n" + "ld1h { z22.h }, p0/Z, [x25, #2, MUL VL]\n" + "mov x20, x22\n" + "dech x21, ALL, MUL #3\n" + "ld1h { z21.h }, p2/Z, [x24]\n" + "ld1h { z20.h }, p1/Z, [x24, #1, MUL VL]\n" + "ld1h { z19.h }, p0/Z, [x24, #2, MUL VL]\n" + "ld1h { z18.h }, p2/Z, [x23]\n" + "cmp x21, #0x0\n" + "addvl x26, x26, #3\n" + "ld1h { z17.h }, p1/Z, [x23, #1, MUL VL]\n" + "ld1h { z16.h }, p0/Z, [x23, #2, MUL VL]\n" + "st1h { z27.h }, p3, [x20]\n" + "addvl x25, x25, #3\n" + "st1h { z25.h }, p3, [x20, #1, MUL VL]\n" + "addvl x24, x24, #3\n" + "addvl x23, x23, #3\n" + "st1h { z23.h }, p3, [x20, #2, MUL VL]\n" + "add x22, x22, %x[out_stride]\n" + "st1h { z26.h }, p3, [x20, #3, MUL VL]\n" + "st1h { z24.h }, p3, [x20, #4, MUL VL]\n" + "st1h { z22.h }, p3, [x20, #5, MUL VL]\n" + "st1h { z21.h }, p3, [x20, #6, MUL VL]\n" + "st1h { z20.h }, p3, [x20, #7, MUL VL]\n" + "addvl x20, x20, #12\n" + "st1h { z19.h }, p3, [x20, #-4, MUL VL]\n" + "st1h { z18.h }, p3, [x20, #-3, MUL VL]\n" + "st1h { z17.h }, p3, [x20, #-2, MUL VL]\n" + "st1h { z16.h }, p3, [x20, #-1, MUL VL]\n" + "bgt 2b\n" + "3:" // Main row loop: Column loop skip + "cmp %x[height], #0x4\n" + "addvl %x[out], %x[out], #12\n" + "bge 1b\n" + "cbz %x[height], 8f\n" + "4:" // Main loop skip + "5:" // Tail row loop: Head + "mov x26, %x[in]\n" + "add %x[in], x26, %x[in_stride]\n" + "mov x22, %x[out]\n" + "sub %x[height], %x[height], #0x1\n" + "mov x21, %x[width]\n" + "6:" // Tail row loop: Column loop + "mov x20, x21\n" + "whilelt p0.h, XZR, x20\n" + "ld1h { z18.h }, p0/Z, [x26]\n" + "dech x20\n" + "whilelt p0.h, XZR, x20\n" + "ld1h { z17.h }, p0/Z, [x26, #1, MUL VL]\n" + "dech x20\n" + "dech x21, ALL, MUL #3\n" + "whilelt p0.h, XZR, x20\n" + "cmp x21, #0x0\n" + "ld1h { z16.h }, p0/Z, [x26, #2, MUL VL]\n" + "st1h { z18.h }, p3, [x22]\n" + "addvl x26, x26, #3\n" + "st1h { z17.h }, p3, [x22, #1, MUL VL]\n" + "st1h { z16.h }, p3, [x22, #2, MUL VL]\n" + "add x22, x22, %x[out_stride]\n" + "bgt 6b\n" + "7:" // Tail row loop: Column loop skip + "cmp %x[height], #0x1\n" + "addvl %x[out], %x[out], #3\n" + "bge 5b\n" + "8:" // Done + : [height] "+&r" (height), [in] "+&r" (in), [out] "+&r" (out) + : [in_stride] "r" (in_stride), [out_stride] "r" (out_stride), [width] "r" (width) + : "cc", "memory", "p0", "p1", "p2", "p3", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "z16", "z17", "z18", "z19", "z20", "z21", "z22", "z23", "z24", "z25", "z26", "z27" + ); +} + +} // anonymous namespace + +template<> +void Transform<3, 1, true, VLType::SVE>( + float *out, const float *in, int stride, int x0, int xmax, int k0, int kmax) +{ + sve_transpose_interleave_3VL( + reinterpret_cast<uint16_t *>(out), + reinterpret_cast<const uint16_t *>(in + k0 * stride + x0), + (xmax-x0) * sizeof(float) / 2, + stride * sizeof(float), + (kmax-k0) + ); +} + +template<> +void Transform<3, 1, true, VLType::SVE>( + __fp16 *out, const __fp16 *in, int stride, int x0, int xmax, int k0, int kmax) +{ + sve_transpose_interleave_3VL( + reinterpret_cast<uint16_t *>(out), + reinterpret_cast<const uint16_t *>(in + k0 * stride + x0), + (xmax-x0) * sizeof(__fp16) / 2, + stride * sizeof(__fp16), + (kmax-k0) + ); +} + +template<> +void Transform<3, 1, true, VLType::SVE>( + double *out, const double *in, int stride, int x0, int xmax, int k0, int kmax) +{ + sve_transpose_interleave_3VL( + reinterpret_cast<uint16_t *>(out), + reinterpret_cast<const uint16_t *>(in + k0 * stride + x0), + (xmax-x0) * sizeof(double) / 2, + stride * sizeof(double), + (kmax-k0) + ); +} + + +#endif // defined(ARM_COMPUTE_ENABLE_SVE) diff --git a/src/core/NEON/kernels/arm_gemm/transforms/sve_transpose_interleave_3VL_1x4.hpp b/src/core/NEON/kernels/arm_gemm/transforms/sve_transpose_interleave_3VL_1x4.hpp new file mode 100644 index 0000000000..a44141c109 --- /dev/null +++ b/src/core/NEON/kernels/arm_gemm/transforms/sve_transpose_interleave_3VL_1x4.hpp @@ -0,0 +1,366 @@ +/* + * Copyright (c) 2021, 2023 Arm Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ + +#pragma once + +#if defined(ARM_COMPUTE_ENABLE_SVE) + +namespace { + +void sve_transpose_interleave_3VL_1x4(uint8_t *out, const uint8_t *in, size_t width, size_t in_stride, size_t height) +{ + uint8_t *pad_row = reinterpret_cast<uint8_t *>(alloca(width * sizeof(uint8_t))); + + if (height % 4) { + memset(pad_row, 0, width * sizeof(uint8_t)); + } + + size_t out_stride = 3 * roundup<size_t>(height, 4) * get_vector_length<uint32_t>(); + + __asm__ __volatile__( + "cmp %x[height], #0x8\n" + "ptrue p1.b\n" + "blt 6f\n" + "1:" // Main row loop: Head + "mov x10, %x[in]\n" + "add x9, x10, %x[in_stride]\n" + "add x28, x9, %x[in_stride]\n" + "add x27, x28, %x[in_stride]\n" + "add x26, x27, %x[in_stride]\n" + "mov x25, %x[width]\n" + "cntb x24, ALL, MUL #3\n" + "add x23, x26, %x[in_stride]\n" + "add x21, x23, %x[in_stride]\n" + "add x20, x21, %x[in_stride]\n" + "cmp x25, x24\n" + "add %x[in], x20, %x[in_stride]\n" + "mov x22, %x[out]\n" + "sub %x[height], %x[height], #0x8\n" + "blt 3f\n" + "2:" // Main row loop: Unroll column loop + "ld1b { z21.b }, p1/Z, [x10]\n" + "ld1b { z20.b }, p1/Z, [x9]\n" + "sub x25, x25, x24\n" + "cmp x25, x24\n" + "ld1b { z17.b }, p1/Z, [x28]\n" + "ld1b { z16.b }, p1/Z, [x27]\n" + "zip1 z31.b, z21.b, z17.b\n" + "zip1 z22.b, z20.b, z16.b\n" + "ld1b { z19.b }, p1/Z, [x26]\n" + "ld1b { z18.b }, p1/Z, [x23]\n" + "zip2 z14.b, z21.b, z17.b\n" + "zip2 z13.b, z20.b, z16.b\n" + "ld1b { z17.b }, p1/Z, [x21]\n" + "ld1b { z16.b }, p1/Z, [x20]\n" + "zip1 z30.b, z19.b, z17.b\n" + "zip1 z29.b, z18.b, z16.b\n" + "ld1b { z21.b }, p1/Z, [x10, #1, MUL VL]\n" + "ld1b { z20.b }, p1/Z, [x9, #1, MUL VL]\n" + "zip2 z12.b, z19.b, z17.b\n" + "zip2 z11.b, z18.b, z16.b\n" + "ld1b { z17.b }, p1/Z, [x28, #1, MUL VL]\n" + "ld1b { z16.b }, p1/Z, [x27, #1, MUL VL]\n" + "zip1 z10.b, z21.b, z17.b\n" + "zip1 z9.b, z20.b, z16.b\n" + "ld1b { z19.b }, p1/Z, [x26, #1, MUL VL]\n" + "ld1b { z18.b }, p1/Z, [x23, #1, MUL VL]\n" + "zip2 z8.b, z21.b, z17.b\n" + "zip2 z7.b, z20.b, z16.b\n" + "ld1b { z17.b }, p1/Z, [x21, #1, MUL VL]\n" + "ld1b { z16.b }, p1/Z, [x20, #1, MUL VL]\n" + "zip1 z6.b, z19.b, z17.b\n" + "zip1 z5.b, z18.b, z16.b\n" + "ld1b { z28.b }, p1/Z, [x10, #2, MUL VL]\n" + "ld1b { z27.b }, p1/Z, [x9, #2, MUL VL]\n" + "zip2 z4.b, z19.b, z17.b\n" + "zip2 z3.b, z18.b, z16.b\n" + "ld1b { z26.b }, p1/Z, [x28, #2, MUL VL]\n" + "ld1b { z25.b }, p1/Z, [x27, #2, MUL VL]\n" + "zip1 z2.b, z28.b, z26.b\n" + "zip1 z1.b, z27.b, z25.b\n" + "ld1b { z24.b }, p1/Z, [x26, #2, MUL VL]\n" + "ld1b { z23.b }, p1/Z, [x23, #2, MUL VL]\n" + "zip1 z16.b, z31.b, z22.b\n" + "zip2 z22.b, z31.b, z22.b\n" + "ld1b { z21.b }, p1/Z, [x21, #2, MUL VL]\n" + "ld1b { z20.b }, p1/Z, [x20, #2, MUL VL]\n" + "zip1 z0.b, z24.b, z21.b\n" + "zip1 z31.b, z23.b, z20.b\n" + "zip1 z19.b, z14.b, z13.b\n" + "zip1 z18.b, z30.b, z29.b\n" + "st1b { z16.b }, p1, [x22]\n" + "addvl x10, x10, #3\n" + "zip2 z16.b, z30.b, z29.b\n" + "zip1 z17.b, z12.b, z11.b\n" + "st1b { z22.b }, p1, [x22, #1, MUL VL]\n" + "addvl x9, x9, #3\n" + "st1b { z19.b }, p1, [x22, #2, MUL VL]\n" + "zip2 z30.b, z28.b, z26.b\n" + "zip2 z29.b, z27.b, z25.b\n" + "addvl x28, x28, #3\n" + "st1b { z18.b }, p1, [x22, #3, MUL VL]\n" + "zip2 z28.b, z24.b, z21.b\n" + "zip2 z27.b, z23.b, z20.b\n" + "addvl x27, x27, #3\n" + "st1b { z16.b }, p1, [x22, #4, MUL VL]\n" + "zip2 z21.b, z14.b, z13.b\n" + "zip1 z16.b, z10.b, z9.b\n" + "addvl x26, x26, #3\n" + "st1b { z17.b }, p1, [x22, #5, MUL VL]\n" + "add x22, x22, %x[out_stride]\n" + "zip2 z20.b, z10.b, z9.b\n" + "zip2 z19.b, z12.b, z11.b\n" + "zip1 z18.b, z6.b, z5.b\n" + "zip2 z17.b, z6.b, z5.b\n" + "st1b { z21.b }, p1, [x22]\n" + "addvl x23, x23, #3\n" + "st1b { z16.b }, p1, [x22, #1, MUL VL]\n" + "zip1 z16.b, z8.b, z7.b\n" + "zip2 z26.b, z8.b, z7.b\n" + "addvl x21, x21, #3\n" + "st1b { z20.b }, p1, [x22, #2, MUL VL]\n" + "zip1 z25.b, z2.b, z1.b\n" + "zip1 z24.b, z4.b, z3.b\n" + "addvl x20, x20, #3\n" + "st1b { z19.b }, p1, [x22, #3, MUL VL]\n" + "zip2 z23.b, z4.b, z3.b\n" + "zip1 z22.b, z0.b, z31.b\n" + "st1b { z18.b }, p1, [x22, #4, MUL VL]\n" + "zip2 z21.b, z2.b, z1.b\n" + "zip1 z20.b, z30.b, z29.b\n" + "st1b { z17.b }, p1, [x22, #5, MUL VL]\n" + "add x22, x22, %x[out_stride]\n" + "zip2 z19.b, z30.b, z29.b\n" + "zip2 z18.b, z0.b, z31.b\n" + "st1b { z16.b }, p1, [x22]\n" + "zip1 z17.b, z28.b, z27.b\n" + "zip2 z16.b, z28.b, z27.b\n" + "st1b { z26.b }, p1, [x22, #1, MUL VL]\n" + "st1b { z25.b }, p1, [x22, #2, MUL VL]\n" + "st1b { z24.b }, p1, [x22, #3, MUL VL]\n" + "st1b { z23.b }, p1, [x22, #4, MUL VL]\n" + "st1b { z22.b }, p1, [x22, #5, MUL VL]\n" + "add x22, x22, %x[out_stride]\n" + "st1b { z21.b }, p1, [x22]\n" + "st1b { z20.b }, p1, [x22, #1, MUL VL]\n" + "st1b { z19.b }, p1, [x22, #2, MUL VL]\n" + "st1b { z18.b }, p1, [x22, #3, MUL VL]\n" + "st1b { z17.b }, p1, [x22, #4, MUL VL]\n" + "st1b { z16.b }, p1, [x22, #5, MUL VL]\n" + "add x22, x22, %x[out_stride]\n" + "bge 2b\n" + "3:" // Main row loop: Unroll column loop skip + "cbz x25, 5f\n" + "4:" // Main row loop: Column loop + "whilelt p0.b, XZR, x25\n" + "ld1b { z19.b }, p0/Z, [x10]\n" + "ld1b { z18.b }, p0/Z, [x9]\n" + "decw x25, ALL, MUL #3\n" + "ld1b { z17.b }, p0/Z, [x28]\n" + "ld1b { z16.b }, p0/Z, [x27]\n" + "zip1 z26.b, z19.b, z17.b\n" + "zip1 z25.b, z18.b, z16.b\n" + "ld1b { z21.b }, p0/Z, [x26]\n" + "ld1b { z20.b }, p0/Z, [x23]\n" + "zip2 z24.b, z19.b, z17.b\n" + "zip2 z19.b, z18.b, z16.b\n" + "ld1b { z18.b }, p0/Z, [x21]\n" + "ld1b { z16.b }, p0/Z, [x20]\n" + "zip1 z23.b, z21.b, z18.b\n" + "zip1 z17.b, z20.b, z16.b\n" + "zip2 z22.b, z21.b, z18.b\n" + "zip2 z16.b, z20.b, z16.b\n" + "cmp x25, #0x0\n" + "incd x10, ALL, MUL #6\n" + "incd x9, ALL, MUL #6\n" + "incd x28, ALL, MUL #6\n" + "zip1 z21.b, z26.b, z25.b\n" + "zip2 z20.b, z26.b, z25.b\n" + "incd x27, ALL, MUL #6\n" + "incd x26, ALL, MUL #6\n" + "zip1 z19.b, z24.b, z19.b\n" + "zip1 z18.b, z23.b, z17.b\n" + "incd x23, ALL, MUL #6\n" + "incd x21, ALL, MUL #6\n" + "zip2 z17.b, z23.b, z17.b\n" + "zip1 z16.b, z22.b, z16.b\n" + "incd x20, ALL, MUL #6\n" + "st1b { z21.b }, p1, [x22]\n" + "st1b { z20.b }, p1, [x22, #1, MUL VL]\n" + "st1b { z19.b }, p1, [x22, #2, MUL VL]\n" + "st1b { z18.b }, p1, [x22, #3, MUL VL]\n" + "st1b { z17.b }, p1, [x22, #4, MUL VL]\n" + "st1b { z16.b }, p1, [x22, #5, MUL VL]\n" + "add x22, x22, %x[out_stride]\n" + "bgt 4b\n" + "5:" // Main row loop: Column loop skip + "cmp %x[height], #0x8\n" + "addvl %x[out], %x[out], #6\n" + "bge 1b\n" + "cbz %x[height], 12f\n" + "6:" // Main loop skip + "7:" // Tail row loop: Head + "mov x10, %x[in]\n" + "add x9, x10, %x[in_stride]\n" + "add x28, x9, %x[in_stride]\n" + "mov x21, %x[width]\n" + "cntb x20, ALL, MUL #3\n" + "add x27, x28, %x[in_stride]\n" + "cmp %x[height], #0x3\n" + "add %x[in], x27, %x[in_stride]\n" + "csel x27, x27, %x[pad_row], GT\n" + "csel x28, x28, %x[pad_row], GE\n" + "cmp %x[height], #0x1\n" + "csel x9, x9, %x[pad_row], GT\n" + "cmp x21, x20\n" + "mov x22, %x[out]\n" + "sub %x[height], %x[height], #0x4\n" + "blt 9f\n" + "8:" // Tail row loop: Unroll column loop + "ld1b { z21.b }, p1/Z, [x10]\n" + "ld1b { z20.b }, p1/Z, [x9]\n" + "sub x21, x21, x20\n" + "cmp x21, x20\n" + "ld1b { z17.b }, p1/Z, [x28]\n" + "ld1b { z16.b }, p1/Z, [x27]\n" + "zip1 z31.b, z21.b, z17.b\n" + "zip1 z30.b, z20.b, z16.b\n" + "ld1b { z19.b }, p1/Z, [x10, #1, MUL VL]\n" + "ld1b { z18.b }, p1/Z, [x9, #1, MUL VL]\n" + "zip2 z29.b, z21.b, z17.b\n" + "zip2 z28.b, z20.b, z16.b\n" + "ld1b { z17.b }, p1/Z, [x28, #1, MUL VL]\n" + "ld1b { z16.b }, p1/Z, [x27, #1, MUL VL]\n" + "zip1 z27.b, z19.b, z17.b\n" + "zip1 z26.b, z18.b, z16.b\n" + "ld1b { z22.b }, p1/Z, [x10, #2, MUL VL]\n" + "ld1b { z21.b }, p1/Z, [x9, #2, MUL VL]\n" + "zip2 z25.b, z19.b, z17.b\n" + "zip2 z20.b, z18.b, z16.b\n" + "ld1b { z19.b }, p1/Z, [x28, #2, MUL VL]\n" + "ld1b { z18.b }, p1/Z, [x27, #2, MUL VL]\n" + "zip1 z24.b, z22.b, z19.b\n" + "zip1 z23.b, z21.b, z18.b\n" + "zip1 z16.b, z31.b, z30.b\n" + "zip2 z17.b, z31.b, z30.b\n" + "st1b { z16.b }, p1, [x22]\n" + "addvl x10, x10, #3\n" + "zip1 z16.b, z29.b, z28.b\n" + "st1b { z17.b }, p1, [x22, #1, MUL VL]\n" + "zip2 z22.b, z22.b, z19.b\n" + "addvl x9, x9, #3\n" + "st1b { z16.b }, p1, [x22, #2, MUL VL]\n" + "add x22, x22, %x[out_stride]\n" + "zip2 z21.b, z21.b, z18.b\n" + "zip2 z18.b, z29.b, z28.b\n" + "zip1 z16.b, z27.b, z26.b\n" + "zip2 z17.b, z27.b, z26.b\n" + "st1b { z18.b }, p1, [x22]\n" + "addvl x28, x28, #3\n" + "st1b { z16.b }, p1, [x22, #1, MUL VL]\n" + "zip1 z16.b, z25.b, z20.b\n" + "zip2 z20.b, z25.b, z20.b\n" + "addvl x27, x27, #3\n" + "st1b { z17.b }, p1, [x22, #2, MUL VL]\n" + "add x22, x22, %x[out_stride]\n" + "zip1 z19.b, z24.b, z23.b\n" + "zip2 z18.b, z24.b, z23.b\n" + "st1b { z16.b }, p1, [x22]\n" + "zip1 z17.b, z22.b, z21.b\n" + "zip2 z16.b, z22.b, z21.b\n" + "st1b { z20.b }, p1, [x22, #1, MUL VL]\n" + "st1b { z19.b }, p1, [x22, #2, MUL VL]\n" + "add x22, x22, %x[out_stride]\n" + "st1b { z18.b }, p1, [x22]\n" + "st1b { z17.b }, p1, [x22, #1, MUL VL]\n" + "st1b { z16.b }, p1, [x22, #2, MUL VL]\n" + "add x22, x22, %x[out_stride]\n" + "bge 8b\n" + "9:" // Tail row loop: Unroll column loop skip + "cbz x21, 11f\n" + "10:" // Tail row loop: Column loop + "whilelt p0.b, XZR, x21\n" + "ld1b { z19.b }, p0/Z, [x10]\n" + "ld1b { z21.b }, p0/Z, [x9]\n" + "decw x21, ALL, MUL #3\n" + "ld1b { z18.b }, p0/Z, [x28]\n" + "ld1b { z16.b }, p0/Z, [x27]\n" + "zip1 z20.b, z19.b, z18.b\n" + "zip1 z17.b, z21.b, z16.b\n" + "zip2 z19.b, z19.b, z18.b\n" + "zip2 z16.b, z21.b, z16.b\n" + "cmp x21, #0x0\n" + "incd x10, ALL, MUL #6\n" + "incd x9, ALL, MUL #6\n" + "incd x28, ALL, MUL #6\n" + "zip1 z18.b, z20.b, z17.b\n" + "zip2 z17.b, z20.b, z17.b\n" + "incd x27, ALL, MUL #6\n" + "zip1 z16.b, z19.b, z16.b\n" + "st1b { z18.b }, p1, [x22]\n" + "st1b { z17.b }, p1, [x22, #1, MUL VL]\n" + "st1b { z16.b }, p1, [x22, #2, MUL VL]\n" + "add x22, x22, %x[out_stride]\n" + "bgt 10b\n" + "11:" // Tail row loop: Column loop skip + "cmp %x[height], #0x1\n" + "addvl %x[out], %x[out], #3\n" + "bge 7b\n" + "12:" // Done + : [height] "+&r" (height), [in] "+&r" (in), [out] "+&r" (out) + : [in_stride] "r" (in_stride), [out_stride] "r" (out_stride), [pad_row] "r" (pad_row), [width] "r" (width) + : "cc", "memory", "p0", "p1", "x9", "x10", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27", "x28", "z0", "z1", "z2", "z3", "z4", "z5", "z6", "z7", "z8", "z9", "z10", "z11", "z12", "z13", "z14", "z16", "z17", "z18", "z19", "z20", "z21", "z22", "z23", "z24", "z25", "z26", "z27", "z28", "z29", "z30", "z31" + ); +} + +} // anonymous namespace + +template<> +void Transform<3, 4, true, VLType::SVE>( + uint8_t *out, const uint8_t *in, int stride, int x0, int xmax, int k0, int kmax) +{ + sve_transpose_interleave_3VL_1x4( + reinterpret_cast<uint8_t *>(out), + reinterpret_cast<const uint8_t *>(in + k0 * stride + x0), + (xmax-x0) * sizeof(uint8_t) / 1, + stride * sizeof(uint8_t), + (kmax-k0) + ); +} + +template<> +void Transform<3, 4, true, VLType::SVE>( + int8_t *out, const int8_t *in, int stride, int x0, int xmax, int k0, int kmax) +{ + sve_transpose_interleave_3VL_1x4( + reinterpret_cast<uint8_t *>(out), + reinterpret_cast<const uint8_t *>(in + k0 * stride + x0), + (xmax-x0) * sizeof(int8_t) / 1, + stride * sizeof(int8_t), + (kmax-k0) + ); +} + + +#endif // defined(ARM_COMPUTE_ENABLE_SVE) diff --git a/src/core/NEON/kernels/arm_gemm/transforms/sve_transpose_interleave_3VL_2x2.hpp b/src/core/NEON/kernels/arm_gemm/transforms/sve_transpose_interleave_3VL_2x2.hpp new file mode 100644 index 0000000000..36a15a16b3 --- /dev/null +++ b/src/core/NEON/kernels/arm_gemm/transforms/sve_transpose_interleave_3VL_2x2.hpp @@ -0,0 +1,316 @@ +/* + * Copyright (c) 2021, 2023 Arm Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ + +#pragma once + +#if defined(ARM_COMPUTE_ENABLE_SVE) + +namespace { + +void sve_transpose_interleave_3VL_2x2(uint16_t *out, const uint16_t *in, size_t width, size_t in_stride, size_t height) +{ + uint16_t *pad_row = reinterpret_cast<uint16_t *>(alloca(width * sizeof(uint16_t))); + + if (height % 2) { + memset(pad_row, 0, width * sizeof(uint16_t)); + } + + size_t out_stride = 3 * roundup<size_t>(height, 2) * get_vector_length<uint16_t>(); + + __asm__ __volatile__( + "cmp %x[height], #0x8\n" + "ptrue p2.b\n" + "blt 6f\n" + "1:" // Main row loop: Head + "mov x12, %x[in]\n" + "add x11, x12, %x[in_stride]\n" + "add x10, x11, %x[in_stride]\n" + "add x9, x10, %x[in_stride]\n" + "add x28, x9, %x[in_stride]\n" + "mov x27, %x[width]\n" + "cnth x26, ALL, MUL #3\n" + "add x25, x28, %x[in_stride]\n" + "add x24, x25, %x[in_stride]\n" + "add x23, x24, %x[in_stride]\n" + "cmp x27, x26\n" + "add %x[in], x23, %x[in_stride]\n" + "mov x22, %x[out]\n" + "sub %x[height], %x[height], #0x8\n" + "blt 3f\n" + "2:" // Main row loop: Unroll column loop + "ld1h { z17.h }, p2/Z, [x12]\n" + "ld1h { z23.h }, p2/Z, [x12, #1, MUL VL]\n" + "mov x21, x22\n" + "add x22, x22, %x[out_stride]\n" + "ld1h { z16.h }, p2/Z, [x11]\n" + "ld1h { z20.h }, p2/Z, [x11, #1, MUL VL]\n" + "zip1 z9.h, z17.h, z16.h\n" + "zip2 z8.h, z17.h, z16.h\n" + "ld1h { z17.h }, p2/Z, [x10]\n" + "ld1h { z22.h }, p2/Z, [x10, #1, MUL VL]\n" + "zip1 z7.h, z23.h, z20.h\n" + "mov x20, x22\n" + "ld1h { z16.h }, p2/Z, [x9]\n" + "ld1h { z21.h }, p2/Z, [x9, #1, MUL VL]\n" + "zip1 z6.h, z17.h, z16.h\n" + "zip2 z5.h, z17.h, z16.h\n" + "ld1h { z18.h }, p2/Z, [x28]\n" + "ld1h { z17.h }, p2/Z, [x25]\n" + "zip1 z4.h, z22.h, z21.h\n" + "zip1 z3.h, z18.h, z17.h\n" + "ld1h { z19.h }, p2/Z, [x12, #2, MUL VL]\n" + "ld1h { z16.h }, p2/Z, [x11, #2, MUL VL]\n" + "zip2 z2.h, z18.h, z17.h\n" + "zip2 z1.h, z23.h, z20.h\n" + "ld1h { z18.h }, p2/Z, [x10, #2, MUL VL]\n" + "ld1h { z17.h }, p2/Z, [x9, #2, MUL VL]\n" + "zip1 z0.h, z19.h, z16.h\n" + "zip2 z31.h, z19.h, z16.h\n" + "ld1h { z20.h }, p2/Z, [x28, #1, MUL VL]\n" + "ld1h { z30.h }, p2/Z, [x28, #2, MUL VL]\n" + "zip2 z29.h, z22.h, z21.h\n" + "zip1 z28.h, z18.h, z17.h\n" + "ld1h { z16.h }, p2/Z, [x25, #1, MUL VL]\n" + "ld1h { z19.h }, p2/Z, [x25, #2, MUL VL]\n" + "zip1 z27.h, z20.h, z16.h\n" + "zip2 z26.h, z18.h, z17.h\n" + "ld1h { z17.h }, p2/Z, [x24]\n" + "ld1h { z18.h }, p2/Z, [x24, #1, MUL VL]\n" + "zip2 z25.h, z20.h, z16.h\n" + "zip1 z24.h, z30.h, z19.h\n" + "ld1h { z23.h }, p2/Z, [x24, #2, MUL VL]\n" + "ld1h { z16.h }, p2/Z, [x23]\n" + "zip1 z22.h, z17.h, z16.h\n" + "zip2 z21.h, z17.h, z16.h\n" + "ld1h { z17.h }, p2/Z, [x23, #1, MUL VL]\n" + "ld1h { z16.h }, p2/Z, [x23, #2, MUL VL]\n" + "st1h { z9.h }, p2, [x21]\n" + "zip1 z20.h, z18.h, z17.h\n" + "st1h { z8.h }, p2, [x21, #1, MUL VL]\n" + "sub x27, x27, x26\n" + "cmp x27, x26\n" + "zip2 z19.h, z30.h, z19.h\n" + "st1h { z7.h }, p2, [x21, #2, MUL VL]\n" + "addvl x12, x12, #3\n" + "addvl x11, x11, #3\n" + "zip2 z18.h, z18.h, z17.h\n" + "st1h { z6.h }, p2, [x21, #3, MUL VL]\n" + "addvl x10, x10, #3\n" + "addvl x9, x9, #3\n" + "zip1 z17.h, z23.h, z16.h\n" + "st1h { z5.h }, p2, [x21, #4, MUL VL]\n" + "addvl x28, x28, #3\n" + "addvl x25, x25, #3\n" + "zip2 z16.h, z23.h, z16.h\n" + "st1h { z4.h }, p2, [x21, #5, MUL VL]\n" + "addvl x24, x24, #3\n" + "addvl x23, x23, #3\n" + "st1h { z3.h }, p2, [x21, #6, MUL VL]\n" + "add x22, x22, %x[out_stride]\n" + "st1h { z2.h }, p2, [x21, #7, MUL VL]\n" + "addvl x21, x21, #12\n" + "st1h { z27.h }, p2, [x21, #-4, MUL VL]\n" + "st1h { z22.h }, p2, [x21, #-3, MUL VL]\n" + "st1h { z21.h }, p2, [x21, #-2, MUL VL]\n" + "st1h { z20.h }, p2, [x21, #-1, MUL VL]\n" + "st1h { z1.h }, p2, [x20]\n" + "st1h { z0.h }, p2, [x20, #1, MUL VL]\n" + "st1h { z31.h }, p2, [x20, #2, MUL VL]\n" + "st1h { z29.h }, p2, [x20, #3, MUL VL]\n" + "st1h { z28.h }, p2, [x20, #4, MUL VL]\n" + "st1h { z26.h }, p2, [x20, #5, MUL VL]\n" + "st1h { z25.h }, p2, [x20, #6, MUL VL]\n" + "st1h { z24.h }, p2, [x20, #7, MUL VL]\n" + "addvl x20, x20, #12\n" + "st1h { z19.h }, p2, [x20, #-4, MUL VL]\n" + "st1h { z18.h }, p2, [x20, #-3, MUL VL]\n" + "st1h { z17.h }, p2, [x20, #-2, MUL VL]\n" + "st1h { z16.h }, p2, [x20, #-1, MUL VL]\n" + "bge 2b\n" + "3:" // Main row loop: Unroll column loop skip + "cbz x27, 5f\n" + "4:" // Main row loop: Column loop + "mov x20, x27\n" + "whilelt p1.h, XZR, x20\n" + "ld1h { z0.h }, p1/Z, [x12]\n" + "ld1h { z16.h }, p1/Z, [x11]\n" + "dech x20\n" + "whilelt p0.h, XZR, x20\n" + "ld1h { z21.h }, p0/Z, [x12, #1, MUL VL]\n" + "ld1h { z19.h }, p0/Z, [x11, #1, MUL VL]\n" + "ld1h { z31.h }, p1/Z, [x10]\n" + "ld1h { z30.h }, p0/Z, [x10, #1, MUL VL]\n" + "mov x20, x22\n" + "decw x27, ALL, MUL #3\n" + "ld1h { z18.h }, p1/Z, [x9]\n" + "ld1h { z29.h }, p0/Z, [x9, #1, MUL VL]\n" + "addvl x12, x12, #1\n" + "addvl x11, x11, #1\n" + "ld1h { z28.h }, p1/Z, [x28]\n" + "ld1h { z20.h }, p1/Z, [x25]\n" + "addvl x10, x10, #1\n" + "addvl x9, x9, #1\n" + "ld1h { z27.h }, p0/Z, [x28, #1, MUL VL]\n" + "addvl x28, x28, #1\n" + "ld1h { z26.h }, p0/Z, [x25, #1, MUL VL]\n" + "addvl x25, x25, #1\n" + "ld1h { z25.h }, p1/Z, [x24]\n" + "ld1h { z24.h }, p0/Z, [x24, #1, MUL VL]\n" + "addvl x24, x24, #1\n" + "zip1 z17.h, z0.h, z16.h\n" + "ld1h { z23.h }, p1/Z, [x23]\n" + "ld1h { z22.h }, p0/Z, [x23, #1, MUL VL]\n" + "addvl x23, x23, #1\n" + "zip2 z16.h, z0.h, z16.h\n" + "zip1 z21.h, z21.h, z19.h\n" + "zip1 z19.h, z31.h, z18.h\n" + "st1h { z17.h }, p2, [x20]\n" + "cmp x27, #0x0\n" + "zip2 z18.h, z31.h, z18.h\n" + "zip1 z17.h, z30.h, z29.h\n" + "st1h { z16.h }, p2, [x20, #1, MUL VL]\n" + "incd x12, ALL, MUL #4\n" + "zip1 z16.h, z28.h, z20.h\n" + "zip2 z20.h, z28.h, z20.h\n" + "st1h { z21.h }, p2, [x20, #2, MUL VL]\n" + "incd x11, ALL, MUL #4\n" + "st1h { z19.h }, p2, [x20, #3, MUL VL]\n" + "incd x10, ALL, MUL #4\n" + "incd x9, ALL, MUL #4\n" + "zip1 z19.h, z27.h, z26.h\n" + "st1h { z18.h }, p2, [x20, #4, MUL VL]\n" + "incd x28, ALL, MUL #4\n" + "incd x25, ALL, MUL #4\n" + "zip1 z18.h, z25.h, z23.h\n" + "st1h { z17.h }, p2, [x20, #5, MUL VL]\n" + "incd x24, ALL, MUL #4\n" + "incd x23, ALL, MUL #4\n" + "zip2 z17.h, z25.h, z23.h\n" + "st1h { z16.h }, p2, [x20, #6, MUL VL]\n" + "zip1 z16.h, z24.h, z22.h\n" + "add x22, x22, %x[out_stride]\n" + "st1h { z20.h }, p2, [x20, #7, MUL VL]\n" + "addvl x20, x20, #12\n" + "st1h { z19.h }, p2, [x20, #-4, MUL VL]\n" + "st1h { z18.h }, p2, [x20, #-3, MUL VL]\n" + "st1h { z17.h }, p2, [x20, #-2, MUL VL]\n" + "st1h { z16.h }, p2, [x20, #-1, MUL VL]\n" + "bgt 4b\n" + "5:" // Main row loop: Column loop skip + "cmp %x[height], #0x8\n" + "addvl %x[out], %x[out], #12\n" + "bge 1b\n" + "cbz %x[height], 12f\n" + "6:" // Main loop skip + "7:" // Tail row loop: Head + "mov x12, %x[in]\n" + "mov x21, %x[width]\n" + "cnth x20, ALL, MUL #3\n" + "add x11, x12, %x[in_stride]\n" + "cmp %x[height], #0x1\n" + "add %x[in], x11, %x[in_stride]\n" + "csel x11, x11, %x[pad_row], GT\n" + "cmp x21, x20\n" + "mov x22, %x[out]\n" + "sub %x[height], %x[height], #0x2\n" + "blt 9f\n" + "8:" // Tail row loop: Unroll column loop + "ld1h { z17.h }, p2/Z, [x12]\n" + "ld1h { z22.h }, p2/Z, [x12, #1, MUL VL]\n" + "sub x21, x21, x20\n" + "cmp x21, x20\n" + "ld1h { z16.h }, p2/Z, [x11]\n" + "ld1h { z21.h }, p2/Z, [x11, #1, MUL VL]\n" + "zip1 z18.h, z17.h, z16.h\n" + "zip2 z17.h, z17.h, z16.h\n" + "ld1h { z20.h }, p2/Z, [x12, #2, MUL VL]\n" + "ld1h { z19.h }, p2/Z, [x11, #2, MUL VL]\n" + "zip1 z16.h, z22.h, z21.h\n" + "st1h { z18.h }, p2, [x22]\n" + "st1h { z17.h }, p2, [x22, #1, MUL VL]\n" + "addvl x12, x12, #3\n" + "addvl x11, x11, #3\n" + "zip2 z18.h, z22.h, z21.h\n" + "st1h { z16.h }, p2, [x22, #2, MUL VL]\n" + "add x22, x22, %x[out_stride]\n" + "zip1 z17.h, z20.h, z19.h\n" + "zip2 z16.h, z20.h, z19.h\n" + "st1h { z18.h }, p2, [x22]\n" + "st1h { z17.h }, p2, [x22, #1, MUL VL]\n" + "st1h { z16.h }, p2, [x22, #2, MUL VL]\n" + "add x22, x22, %x[out_stride]\n" + "bge 8b\n" + "9:" // Tail row loop: Unroll column loop skip + "cbz x21, 11f\n" + "10:" // Tail row loop: Column loop + "mov x20, x21\n" + "whilelt p0.h, XZR, x20\n" + "ld1h { z20.h }, p0/Z, [x12]\n" + "ld1h { z17.h }, p0/Z, [x11]\n" + "dech x20\n" + "whilelt p0.h, XZR, x20\n" + "ld1h { z19.h }, p0/Z, [x12, #1, MUL VL]\n" + "ld1h { z16.h }, p0/Z, [x11, #1, MUL VL]\n" + "decw x21, ALL, MUL #3\n" + "addvl x12, x12, #1\n" + "zip1 z18.h, z20.h, z17.h\n" + "zip2 z17.h, z20.h, z17.h\n" + "addvl x11, x11, #1\n" + "cmp x21, #0x0\n" + "zip1 z16.h, z19.h, z16.h\n" + "st1h { z18.h }, p2, [x22]\n" + "incd x12, ALL, MUL #4\n" + "incd x11, ALL, MUL #4\n" + "st1h { z17.h }, p2, [x22, #1, MUL VL]\n" + "st1h { z16.h }, p2, [x22, #2, MUL VL]\n" + "add x22, x22, %x[out_stride]\n" + "bgt 10b\n" + "11:" // Tail row loop: Column loop skip + "cmp %x[height], #0x1\n" + "addvl %x[out], %x[out], #3\n" + "bge 7b\n" + "12:" // Done + : [height] "+&r" (height), [in] "+&r" (in), [out] "+&r" (out) + : [in_stride] "r" (in_stride), [out_stride] "r" (out_stride), [pad_row] "r" (pad_row), [width] "r" (width) + : "cc", "memory", "p0", "p1", "p2", "x9", "x10", "x11", "x12", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27", "x28", "z0", "z1", "z2", "z3", "z4", "z5", "z6", "z7", "z8", "z9", "z16", "z17", "z18", "z19", "z20", "z21", "z22", "z23", "z24", "z25", "z26", "z27", "z28", "z29", "z30", "z31" + ); +} + +} // anonymous namespace + +template<> +void Transform<3, 2, true, VLType::SVE>( + bfloat16 *out, const bfloat16 *in, int stride, int x0, int xmax, int k0, int kmax) +{ + sve_transpose_interleave_3VL_2x2( + reinterpret_cast<uint16_t *>(out), + reinterpret_cast<const uint16_t *>(in + k0 * stride + x0), + (xmax-x0) * sizeof(bfloat16) / 2, + stride * sizeof(bfloat16), + (kmax-k0) + ); +} + + +#endif // defined(ARM_COMPUTE_ENABLE_SVE) diff --git a/src/core/NEON/kernels/arm_gemm/transforms/sve_transpose_interleave_4VL.hpp b/src/core/NEON/kernels/arm_gemm/transforms/sve_transpose_interleave_4VL.hpp new file mode 100644 index 0000000000..e661e2698a --- /dev/null +++ b/src/core/NEON/kernels/arm_gemm/transforms/sve_transpose_interleave_4VL.hpp @@ -0,0 +1,186 @@ +/* + * Copyright (c) 2021, 2023 Arm Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ + +#pragma once + +#if defined(ARM_COMPUTE_ENABLE_SVE) + +namespace { + +void sve_transpose_interleave_4VL(uint16_t *out, const uint16_t *in, size_t width, size_t in_stride, size_t height) +{ + size_t out_stride = 4 * height * get_vector_length<uint8_t>(); + + __asm__ __volatile__( + "cmp %x[height], #0x4\n" + "ptrue p4.b\n" + "blt 4f\n" + "1:" // Main row loop: Head + "mov x26, %x[in]\n" + "add x25, x26, %x[in_stride]\n" + "add x24, x25, %x[in_stride]\n" + "add x23, x24, %x[in_stride]\n" + "add %x[in], x23, %x[in_stride]\n" + "mov x22, %x[out]\n" + "sub %x[height], %x[height], #0x4\n" + "mov x21, %x[width]\n" + "2:" // Main row loop: Column loop + "mov x20, x21\n" + "whilelt p3.h, XZR, x20\n" + "ld1h { z31.h }, p3/Z, [x26]\n" + "ld1h { z30.h }, p3/Z, [x25]\n" + "dech x20\n" + "whilelt p2.h, XZR, x20\n" + "ld1h { z29.h }, p2/Z, [x26, #1, MUL VL]\n" + "ld1h { z28.h }, p2/Z, [x25, #1, MUL VL]\n" + "dech x20\n" + "whilelt p1.h, XZR, x20\n" + "ld1h { z27.h }, p1/Z, [x26, #2, MUL VL]\n" + "ld1h { z26.h }, p1/Z, [x25, #2, MUL VL]\n" + "dech x20\n" + "whilelt p0.h, XZR, x20\n" + "ld1h { z25.h }, p0/Z, [x26, #3, MUL VL]\n" + "ld1h { z24.h }, p0/Z, [x25, #3, MUL VL]\n" + "mov x20, x22\n" + "dech x21, ALL, MUL #4\n" + "ld1h { z23.h }, p3/Z, [x24]\n" + "ld1h { z22.h }, p2/Z, [x24, #1, MUL VL]\n" + "ld1h { z21.h }, p1/Z, [x24, #2, MUL VL]\n" + "ld1h { z20.h }, p0/Z, [x24, #3, MUL VL]\n" + "cmp x21, #0x0\n" + "addvl x26, x26, #4\n" + "ld1h { z19.h }, p3/Z, [x23]\n" + "ld1h { z18.h }, p2/Z, [x23, #1, MUL VL]\n" + "addvl x25, x25, #4\n" + "addvl x24, x24, #4\n" + "ld1h { z17.h }, p1/Z, [x23, #2, MUL VL]\n" + "ld1h { z16.h }, p0/Z, [x23, #3, MUL VL]\n" + "st1h { z31.h }, p4, [x20]\n" + "addvl x23, x23, #4\n" + "st1h { z29.h }, p4, [x20, #1, MUL VL]\n" + "add x22, x22, %x[out_stride]\n" + "st1h { z27.h }, p4, [x20, #2, MUL VL]\n" + "st1h { z25.h }, p4, [x20, #3, MUL VL]\n" + "st1h { z30.h }, p4, [x20, #4, MUL VL]\n" + "st1h { z28.h }, p4, [x20, #5, MUL VL]\n" + "st1h { z26.h }, p4, [x20, #6, MUL VL]\n" + "st1h { z24.h }, p4, [x20, #7, MUL VL]\n" + "addvl x20, x20, #16\n" + "st1h { z23.h }, p4, [x20, #-8, MUL VL]\n" + "st1h { z22.h }, p4, [x20, #-7, MUL VL]\n" + "st1h { z21.h }, p4, [x20, #-6, MUL VL]\n" + "st1h { z20.h }, p4, [x20, #-5, MUL VL]\n" + "st1h { z19.h }, p4, [x20, #-4, MUL VL]\n" + "st1h { z18.h }, p4, [x20, #-3, MUL VL]\n" + "st1h { z17.h }, p4, [x20, #-2, MUL VL]\n" + "st1h { z16.h }, p4, [x20, #-1, MUL VL]\n" + "bgt 2b\n" + "3:" // Main row loop: Column loop skip + "cmp %x[height], #0x4\n" + "addvl %x[out], %x[out], #16\n" + "bge 1b\n" + "cbz %x[height], 8f\n" + "4:" // Main loop skip + "5:" // Tail row loop: Head + "mov x26, %x[in]\n" + "add %x[in], x26, %x[in_stride]\n" + "mov x22, %x[out]\n" + "sub %x[height], %x[height], #0x1\n" + "mov x21, %x[width]\n" + "6:" // Tail row loop: Column loop + "mov x20, x21\n" + "whilelt p0.h, XZR, x20\n" + "ld1h { z19.h }, p0/Z, [x26]\n" + "dech x20\n" + "whilelt p0.h, XZR, x20\n" + "ld1h { z18.h }, p0/Z, [x26, #1, MUL VL]\n" + "dech x20\n" + "whilelt p0.h, XZR, x20\n" + "ld1h { z17.h }, p0/Z, [x26, #2, MUL VL]\n" + "dech x20\n" + "dech x21, ALL, MUL #4\n" + "whilelt p0.h, XZR, x20\n" + "cmp x21, #0x0\n" + "ld1h { z16.h }, p0/Z, [x26, #3, MUL VL]\n" + "st1h { z19.h }, p4, [x22]\n" + "addvl x26, x26, #4\n" + "st1h { z18.h }, p4, [x22, #1, MUL VL]\n" + "st1h { z17.h }, p4, [x22, #2, MUL VL]\n" + "st1h { z16.h }, p4, [x22, #3, MUL VL]\n" + "add x22, x22, %x[out_stride]\n" + "bgt 6b\n" + "7:" // Tail row loop: Column loop skip + "cmp %x[height], #0x1\n" + "addvl %x[out], %x[out], #4\n" + "bge 5b\n" + "8:" // Done + : [height] "+&r" (height), [in] "+&r" (in), [out] "+&r" (out) + : [in_stride] "r" (in_stride), [out_stride] "r" (out_stride), [width] "r" (width) + : "cc", "memory", "p0", "p1", "p2", "p3", "p4", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "z16", "z17", "z18", "z19", "z20", "z21", "z22", "z23", "z24", "z25", "z26", "z27", "z28", "z29", "z30", "z31" + ); +} + +} // anonymous namespace + +template<> +void Transform<4, 1, true, VLType::SVE>( + float *out, const float *in, int stride, int x0, int xmax, int k0, int kmax) +{ + sve_transpose_interleave_4VL( + reinterpret_cast<uint16_t *>(out), + reinterpret_cast<const uint16_t *>(in + k0 * stride + x0), + (xmax-x0) * sizeof(float) / 2, + stride * sizeof(float), + (kmax-k0) + ); +} + +template<> +void Transform<4, 1, true, VLType::SVE>( + __fp16 *out, const __fp16 *in, int stride, int x0, int xmax, int k0, int kmax) +{ + sve_transpose_interleave_4VL( + reinterpret_cast<uint16_t *>(out), + reinterpret_cast<const uint16_t *>(in + k0 * stride + x0), + (xmax-x0) * sizeof(__fp16) / 2, + stride * sizeof(__fp16), + (kmax-k0) + ); +} + +template<> +void Transform<4, 1, true, VLType::SVE>( + double *out, const double *in, int stride, int x0, int xmax, int k0, int kmax) +{ + sve_transpose_interleave_4VL( + reinterpret_cast<uint16_t *>(out), + reinterpret_cast<const uint16_t *>(in + k0 * stride + x0), + (xmax-x0) * sizeof(double) / 2, + stride * sizeof(double), + (kmax-k0) + ); +} + + +#endif // defined(ARM_COMPUTE_ENABLE_SVE) diff --git a/src/core/NEON/kernels/arm_gemm/transforms/sve_transpose_interleave_4VL_1x4.hpp b/src/core/NEON/kernels/arm_gemm/transforms/sve_transpose_interleave_4VL_1x4.hpp new file mode 100644 index 0000000000..03a78f72f1 --- /dev/null +++ b/src/core/NEON/kernels/arm_gemm/transforms/sve_transpose_interleave_4VL_1x4.hpp @@ -0,0 +1,320 @@ +/* + * Copyright (c) 2021, 2023 Arm Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ + +#pragma once + +#if defined(ARM_COMPUTE_ENABLE_SVE) + +namespace { + +void sve_transpose_interleave_4VL_1x4(uint8_t *out, const uint8_t *in, size_t width, size_t in_stride, size_t height) +{ + uint8_t *pad_row = reinterpret_cast<uint8_t *>(alloca(width * sizeof(uint8_t))); + + if (height % 4) { + memset(pad_row, 0, width * sizeof(uint8_t)); + } + + size_t out_stride = 4 * roundup<size_t>(height, 4) * get_vector_length<uint32_t>(); + + __asm__ __volatile__( + "cmp %x[height], #0x8\n" + "ptrue p1.b\n" + "blt 6f\n" + "1:" // Main row loop: Head + "mov x10, %x[in]\n" + "add x9, x10, %x[in_stride]\n" + "add x28, x9, %x[in_stride]\n" + "add x27, x28, %x[in_stride]\n" + "add x26, x27, %x[in_stride]\n" + "mov x25, %x[width]\n" + "cntb x24, ALL, MUL #2\n" + "add x23, x26, %x[in_stride]\n" + "add x21, x23, %x[in_stride]\n" + "add x20, x21, %x[in_stride]\n" + "cmp x25, x24\n" + "add %x[in], x20, %x[in_stride]\n" + "mov x22, %x[out]\n" + "sub %x[height], %x[height], #0x8\n" + "blt 3f\n" + "2:" // Main row loop: Unroll column loop + "ld1b { z21.b }, p1/Z, [x10]\n" + "ld1b { z20.b }, p1/Z, [x9]\n" + "sub x25, x25, x24\n" + "cmp x25, x24\n" + "ld1b { z17.b }, p1/Z, [x28]\n" + "ld1b { z16.b }, p1/Z, [x27]\n" + "zip1 z4.b, z21.b, z17.b\n" + "zip1 z3.b, z20.b, z16.b\n" + "ld1b { z19.b }, p1/Z, [x26]\n" + "ld1b { z18.b }, p1/Z, [x23]\n" + "zip2 z2.b, z21.b, z17.b\n" + "zip2 z1.b, z20.b, z16.b\n" + "ld1b { z17.b }, p1/Z, [x21]\n" + "ld1b { z16.b }, p1/Z, [x20]\n" + "zip1 z0.b, z19.b, z17.b\n" + "zip1 z31.b, z18.b, z16.b\n" + "ld1b { z24.b }, p1/Z, [x10, #1, MUL VL]\n" + "ld1b { z20.b }, p1/Z, [x9, #1, MUL VL]\n" + "zip2 z30.b, z19.b, z17.b\n" + "zip2 z23.b, z18.b, z16.b\n" + "ld1b { z17.b }, p1/Z, [x28, #1, MUL VL]\n" + "ld1b { z16.b }, p1/Z, [x27, #1, MUL VL]\n" + "zip1 z22.b, z24.b, z17.b\n" + "zip1 z21.b, z20.b, z16.b\n" + "ld1b { z19.b }, p1/Z, [x26, #1, MUL VL]\n" + "ld1b { z18.b }, p1/Z, [x23, #1, MUL VL]\n" + "zip2 z29.b, z24.b, z17.b\n" + "zip2 z28.b, z20.b, z16.b\n" + "ld1b { z17.b }, p1/Z, [x21, #1, MUL VL]\n" + "ld1b { z16.b }, p1/Z, [x20, #1, MUL VL]\n" + "zip1 z27.b, z19.b, z17.b\n" + "zip1 z26.b, z18.b, z16.b\n" + "zip2 z25.b, z19.b, z17.b\n" + "zip2 z24.b, z18.b, z16.b\n" + "addvl x10, x10, #2\n" + "addvl x9, x9, #2\n" + "zip1 z16.b, z4.b, z3.b\n" + "zip2 z17.b, z4.b, z3.b\n" + "st1b { z16.b }, p1, [x22]\n" + "addvl x28, x28, #2\n" + "zip1 z16.b, z2.b, z1.b\n" + "zip2 z20.b, z2.b, z1.b\n" + "st1b { z17.b }, p1, [x22, #1, MUL VL]\n" + "addvl x27, x27, #2\n" + "zip1 z19.b, z0.b, z31.b\n" + "zip2 z18.b, z0.b, z31.b\n" + "st1b { z16.b }, p1, [x22, #2, MUL VL]\n" + "addvl x26, x26, #2\n" + "zip1 z17.b, z30.b, z23.b\n" + "zip2 z16.b, z30.b, z23.b\n" + "st1b { z20.b }, p1, [x22, #3, MUL VL]\n" + "addvl x23, x23, #2\n" + "st1b { z19.b }, p1, [x22, #4, MUL VL]\n" + "addvl x21, x21, #2\n" + "addvl x20, x20, #2\n" + "zip1 z23.b, z22.b, z21.b\n" + "st1b { z18.b }, p1, [x22, #5, MUL VL]\n" + "zip2 z22.b, z22.b, z21.b\n" + "zip1 z21.b, z29.b, z28.b\n" + "st1b { z17.b }, p1, [x22, #6, MUL VL]\n" + "zip2 z20.b, z29.b, z28.b\n" + "zip1 z19.b, z27.b, z26.b\n" + "st1b { z16.b }, p1, [x22, #7, MUL VL]\n" + "add x22, x22, %x[out_stride]\n" + "zip2 z18.b, z27.b, z26.b\n" + "zip1 z17.b, z25.b, z24.b\n" + "zip2 z16.b, z25.b, z24.b\n" + "st1b { z23.b }, p1, [x22]\n" + "st1b { z22.b }, p1, [x22, #1, MUL VL]\n" + "st1b { z21.b }, p1, [x22, #2, MUL VL]\n" + "st1b { z20.b }, p1, [x22, #3, MUL VL]\n" + "st1b { z19.b }, p1, [x22, #4, MUL VL]\n" + "st1b { z18.b }, p1, [x22, #5, MUL VL]\n" + "st1b { z17.b }, p1, [x22, #6, MUL VL]\n" + "st1b { z16.b }, p1, [x22, #7, MUL VL]\n" + "add x22, x22, %x[out_stride]\n" + "bge 2b\n" + "3:" // Main row loop: Unroll column loop skip + "cbz x25, 5f\n" + "4:" // Main row loop: Column loop + "whilelt p0.b, XZR, x25\n" + "ld1b { z19.b }, p0/Z, [x10]\n" + "ld1b { z18.b }, p0/Z, [x9]\n" + "decw x25, ALL, MUL #4\n" + "ld1b { z17.b }, p0/Z, [x28]\n" + "ld1b { z16.b }, p0/Z, [x27]\n" + "zip1 z27.b, z19.b, z17.b\n" + "zip1 z26.b, z18.b, z16.b\n" + "ld1b { z22.b }, p0/Z, [x26]\n" + "ld1b { z21.b }, p0/Z, [x23]\n" + "zip2 z25.b, z19.b, z17.b\n" + "zip2 z20.b, z18.b, z16.b\n" + "ld1b { z19.b }, p0/Z, [x21]\n" + "ld1b { z16.b }, p0/Z, [x20]\n" + "zip1 z18.b, z22.b, z19.b\n" + "zip1 z17.b, z21.b, z16.b\n" + "zip2 z24.b, z22.b, z19.b\n" + "zip2 z16.b, z21.b, z16.b\n" + "cmp x25, #0x0\n" + "addvl x10, x10, #1\n" + "addvl x9, x9, #1\n" + "addvl x28, x28, #1\n" + "zip1 z23.b, z27.b, z26.b\n" + "zip2 z22.b, z27.b, z26.b\n" + "addvl x27, x27, #1\n" + "addvl x26, x26, #1\n" + "zip1 z21.b, z25.b, z20.b\n" + "zip2 z20.b, z25.b, z20.b\n" + "addvl x23, x23, #1\n" + "addvl x21, x21, #1\n" + "zip1 z19.b, z18.b, z17.b\n" + "zip2 z18.b, z18.b, z17.b\n" + "addvl x20, x20, #1\n" + "zip1 z17.b, z24.b, z16.b\n" + "zip2 z16.b, z24.b, z16.b\n" + "st1b { z23.b }, p1, [x22]\n" + "st1b { z22.b }, p1, [x22, #1, MUL VL]\n" + "st1b { z21.b }, p1, [x22, #2, MUL VL]\n" + "st1b { z20.b }, p1, [x22, #3, MUL VL]\n" + "st1b { z19.b }, p1, [x22, #4, MUL VL]\n" + "st1b { z18.b }, p1, [x22, #5, MUL VL]\n" + "st1b { z17.b }, p1, [x22, #6, MUL VL]\n" + "st1b { z16.b }, p1, [x22, #7, MUL VL]\n" + "add x22, x22, %x[out_stride]\n" + "bgt 4b\n" + "5:" // Main row loop: Column loop skip + "cmp %x[height], #0x8\n" + "addvl %x[out], %x[out], #8\n" + "bge 1b\n" + "cbz %x[height], 12f\n" + "6:" // Main loop skip + "7:" // Tail row loop: Head + "mov x10, %x[in]\n" + "add x9, x10, %x[in_stride]\n" + "add x28, x9, %x[in_stride]\n" + "mov x21, %x[width]\n" + "cntb x20, ALL, MUL #2\n" + "add x27, x28, %x[in_stride]\n" + "cmp %x[height], #0x3\n" + "add %x[in], x27, %x[in_stride]\n" + "csel x27, x27, %x[pad_row], GT\n" + "csel x28, x28, %x[pad_row], GE\n" + "cmp %x[height], #0x1\n" + "csel x9, x9, %x[pad_row], GT\n" + "cmp x21, x20\n" + "mov x22, %x[out]\n" + "sub %x[height], %x[height], #0x4\n" + "blt 9f\n" + "8:" // Tail row loop: Unroll column loop + "ld1b { z21.b }, p1/Z, [x10]\n" + "ld1b { z19.b }, p1/Z, [x9]\n" + "sub x21, x21, x20\n" + "cmp x21, x20\n" + "ld1b { z17.b }, p1/Z, [x28]\n" + "ld1b { z16.b }, p1/Z, [x27]\n" + "zip1 z26.b, z21.b, z17.b\n" + "zip1 z25.b, z19.b, z16.b\n" + "ld1b { z20.b }, p1/Z, [x10, #1, MUL VL]\n" + "ld1b { z18.b }, p1/Z, [x9, #1, MUL VL]\n" + "zip2 z24.b, z21.b, z17.b\n" + "zip2 z19.b, z19.b, z16.b\n" + "ld1b { z17.b }, p1/Z, [x28, #1, MUL VL]\n" + "ld1b { z16.b }, p1/Z, [x27, #1, MUL VL]\n" + "zip1 z23.b, z20.b, z17.b\n" + "zip1 z22.b, z18.b, z16.b\n" + "zip2 z21.b, z20.b, z17.b\n" + "zip2 z20.b, z18.b, z16.b\n" + "addvl x10, x10, #2\n" + "addvl x9, x9, #2\n" + "zip1 z16.b, z26.b, z25.b\n" + "zip2 z18.b, z26.b, z25.b\n" + "st1b { z16.b }, p1, [x22]\n" + "addvl x28, x28, #2\n" + "zip1 z17.b, z24.b, z19.b\n" + "zip2 z16.b, z24.b, z19.b\n" + "st1b { z18.b }, p1, [x22, #1, MUL VL]\n" + "addvl x27, x27, #2\n" + "st1b { z17.b }, p1, [x22, #2, MUL VL]\n" + "zip1 z19.b, z23.b, z22.b\n" + "zip2 z18.b, z23.b, z22.b\n" + "st1b { z16.b }, p1, [x22, #3, MUL VL]\n" + "add x22, x22, %x[out_stride]\n" + "zip1 z17.b, z21.b, z20.b\n" + "zip2 z16.b, z21.b, z20.b\n" + "st1b { z19.b }, p1, [x22]\n" + "st1b { z18.b }, p1, [x22, #1, MUL VL]\n" + "st1b { z17.b }, p1, [x22, #2, MUL VL]\n" + "st1b { z16.b }, p1, [x22, #3, MUL VL]\n" + "add x22, x22, %x[out_stride]\n" + "bge 8b\n" + "9:" // Tail row loop: Unroll column loop skip + "cbz x21, 11f\n" + "10:" // Tail row loop: Column loop + "whilelt p0.b, XZR, x21\n" + "ld1b { z20.b }, p0/Z, [x10]\n" + "ld1b { z21.b }, p0/Z, [x9]\n" + "decw x21, ALL, MUL #4\n" + "ld1b { z19.b }, p0/Z, [x28]\n" + "ld1b { z16.b }, p0/Z, [x27]\n" + "zip1 z18.b, z20.b, z19.b\n" + "zip1 z17.b, z21.b, z16.b\n" + "zip2 z20.b, z20.b, z19.b\n" + "zip2 z16.b, z21.b, z16.b\n" + "cmp x21, #0x0\n" + "addvl x10, x10, #1\n" + "addvl x9, x9, #1\n" + "addvl x28, x28, #1\n" + "zip1 z19.b, z18.b, z17.b\n" + "zip2 z18.b, z18.b, z17.b\n" + "addvl x27, x27, #1\n" + "zip1 z17.b, z20.b, z16.b\n" + "zip2 z16.b, z20.b, z16.b\n" + "st1b { z19.b }, p1, [x22]\n" + "st1b { z18.b }, p1, [x22, #1, MUL VL]\n" + "st1b { z17.b }, p1, [x22, #2, MUL VL]\n" + "st1b { z16.b }, p1, [x22, #3, MUL VL]\n" + "add x22, x22, %x[out_stride]\n" + "bgt 10b\n" + "11:" // Tail row loop: Column loop skip + "cmp %x[height], #0x1\n" + "addvl %x[out], %x[out], #4\n" + "bge 7b\n" + "12:" // Done + : [height] "+&r" (height), [in] "+&r" (in), [out] "+&r" (out) + : [in_stride] "r" (in_stride), [out_stride] "r" (out_stride), [pad_row] "r" (pad_row), [width] "r" (width) + : "cc", "memory", "p0", "p1", "x9", "x10", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27", "x28", "z0", "z1", "z2", "z3", "z4", "z16", "z17", "z18", "z19", "z20", "z21", "z22", "z23", "z24", "z25", "z26", "z27", "z28", "z29", "z30", "z31" + ); +} + +} // anonymous namespace + +template<> +void Transform<4, 4, true, VLType::SVE>( + uint8_t *out, const uint8_t *in, int stride, int x0, int xmax, int k0, int kmax) +{ + sve_transpose_interleave_4VL_1x4( + reinterpret_cast<uint8_t *>(out), + reinterpret_cast<const uint8_t *>(in + k0 * stride + x0), + (xmax-x0) * sizeof(uint8_t) / 1, + stride * sizeof(uint8_t), + (kmax-k0) + ); +} + +template<> +void Transform<4, 4, true, VLType::SVE>( + int8_t *out, const int8_t *in, int stride, int x0, int xmax, int k0, int kmax) +{ + sve_transpose_interleave_4VL_1x4( + reinterpret_cast<uint8_t *>(out), + reinterpret_cast<const uint8_t *>(in + k0 * stride + x0), + (xmax-x0) * sizeof(int8_t) / 1, + stride * sizeof(int8_t), + (kmax-k0) + ); +} + + +#endif // defined(ARM_COMPUTE_ENABLE_SVE) diff --git a/src/core/NEON/kernels/arm_gemm/transforms/sve_transpose_interleave_4VL_2x2.hpp b/src/core/NEON/kernels/arm_gemm/transforms/sve_transpose_interleave_4VL_2x2.hpp new file mode 100644 index 0000000000..b196799cfe --- /dev/null +++ b/src/core/NEON/kernels/arm_gemm/transforms/sve_transpose_interleave_4VL_2x2.hpp @@ -0,0 +1,346 @@ +/* + * Copyright (c) 2021, 2023 Arm Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ + +#pragma once + +#if defined(ARM_COMPUTE_ENABLE_SVE) + +namespace { + +void sve_transpose_interleave_4VL_2x2(uint16_t *out, const uint16_t *in, size_t width, size_t in_stride, size_t height) +{ + uint16_t *pad_row = reinterpret_cast<uint16_t *>(alloca(width * sizeof(uint16_t))); + + if (height % 2) { + memset(pad_row, 0, width * sizeof(uint16_t)); + } + + size_t out_stride = 4 * roundup<size_t>(height, 2) * get_vector_length<uint16_t>(); + + __asm__ __volatile__( + "cmp %x[height], #0x8\n" + "ptrue p2.b\n" + "blt 6f\n" + "1:" // Main row loop: Head + "mov x12, %x[in]\n" + "add x11, x12, %x[in_stride]\n" + "add x10, x11, %x[in_stride]\n" + "add x9, x10, %x[in_stride]\n" + "add x28, x9, %x[in_stride]\n" + "mov x27, %x[width]\n" + "cnth x26, ALL, MUL #4\n" + "add x25, x28, %x[in_stride]\n" + "add x24, x25, %x[in_stride]\n" + "add x23, x24, %x[in_stride]\n" + "cmp x27, x26\n" + "add %x[in], x23, %x[in_stride]\n" + "mov x22, %x[out]\n" + "sub %x[height], %x[height], #0x8\n" + "blt 3f\n" + "2:" // Main row loop: Unroll column loop + "ld1h { z18.h }, p2/Z, [x12]\n" + "ld1h { z20.h }, p2/Z, [x12, #1, MUL VL]\n" + "mov x21, x22\n" + "add x22, x22, %x[out_stride]\n" + "ld1h { z17.h }, p2/Z, [x11]\n" + "ld1h { z16.h }, p2/Z, [x11, #1, MUL VL]\n" + "zip1 z25.h, z18.h, z17.h\n" + "zip2 z24.h, z18.h, z17.h\n" + "ld1h { z19.h }, p2/Z, [x10]\n" + "ld1h { z18.h }, p2/Z, [x10, #1, MUL VL]\n" + "zip1 z23.h, z20.h, z16.h\n" + "zip2 z15.h, z20.h, z16.h\n" + "ld1h { z17.h }, p2/Z, [x9]\n" + "ld1h { z16.h }, p2/Z, [x9, #1, MUL VL]\n" + "zip1 z14.h, z19.h, z17.h\n" + "zip2 z13.h, z19.h, z17.h\n" + "ld1h { z17.h }, p2/Z, [x12, #2, MUL VL]\n" + "ld1h { z19.h }, p2/Z, [x12, #3, MUL VL]\n" + "zip1 z12.h, z18.h, z16.h\n" + "zip2 z11.h, z18.h, z16.h\n" + "ld1h { z16.h }, p2/Z, [x11, #2, MUL VL]\n" + "ld1h { z18.h }, p2/Z, [x11, #3, MUL VL]\n" + "mov x20, x22\n" + "zip1 z10.h, z17.h, z16.h\n" + "ld1h { z21.h }, p2/Z, [x10, #2, MUL VL]\n" + "ld1h { z20.h }, p2/Z, [x10, #3, MUL VL]\n" + "zip2 z9.h, z17.h, z16.h\n" + "zip1 z8.h, z19.h, z18.h\n" + "ld1h { z17.h }, p2/Z, [x9, #2, MUL VL]\n" + "ld1h { z16.h }, p2/Z, [x9, #3, MUL VL]\n" + "zip2 z7.h, z19.h, z18.h\n" + "zip1 z6.h, z21.h, z17.h\n" + "ld1h { z19.h }, p2/Z, [x28]\n" + "ld1h { z18.h }, p2/Z, [x28, #1, MUL VL]\n" + "zip2 z5.h, z21.h, z17.h\n" + "zip1 z4.h, z20.h, z16.h\n" + "ld1h { z22.h }, p2/Z, [x28, #2, MUL VL]\n" + "ld1h { z3.h }, p2/Z, [x28, #3, MUL VL]\n" + "zip2 z2.h, z20.h, z16.h\n" + "sub x27, x27, x26\n" + "ld1h { z17.h }, p2/Z, [x25]\n" + "ld1h { z16.h }, p2/Z, [x25, #1, MUL VL]\n" + "zip1 z1.h, z19.h, z17.h\n" + "zip2 z0.h, z19.h, z17.h\n" + "ld1h { z21.h }, p2/Z, [x25, #2, MUL VL]\n" + "ld1h { z20.h }, p2/Z, [x25, #3, MUL VL]\n" + "zip1 z31.h, z18.h, z16.h\n" + "zip2 z30.h, z18.h, z16.h\n" + "ld1h { z17.h }, p2/Z, [x24]\n" + "ld1h { z19.h }, p2/Z, [x24, #1, MUL VL]\n" + "cmp x27, x26\n" + "addvl x12, x12, #4\n" + "ld1h { z29.h }, p2/Z, [x24, #2, MUL VL]\n" + "ld1h { z28.h }, p2/Z, [x24, #3, MUL VL]\n" + "addvl x11, x11, #4\n" + "addvl x10, x10, #4\n" + "ld1h { z16.h }, p2/Z, [x23]\n" + "ld1h { z18.h }, p2/Z, [x23, #1, MUL VL]\n" + "zip1 z27.h, z17.h, z16.h\n" + "zip2 z26.h, z17.h, z16.h\n" + "ld1h { z17.h }, p2/Z, [x23, #2, MUL VL]\n" + "ld1h { z16.h }, p2/Z, [x23, #3, MUL VL]\n" + "st1h { z25.h }, p2, [x21]\n" + "zip1 z25.h, z19.h, z18.h\n" + "st1h { z24.h }, p2, [x21, #1, MUL VL]\n" + "zip2 z24.h, z19.h, z18.h\n" + "addvl x9, x9, #4\n" + "addvl x28, x28, #4\n" + "st1h { z23.h }, p2, [x21, #2, MUL VL]\n" + "addvl x25, x25, #4\n" + "addvl x24, x24, #4\n" + "zip1 z23.h, z22.h, z21.h\n" + "st1h { z15.h }, p2, [x21, #3, MUL VL]\n" + "addvl x23, x23, #4\n" + "zip2 z22.h, z22.h, z21.h\n" + "zip1 z21.h, z3.h, z20.h\n" + "st1h { z14.h }, p2, [x21, #4, MUL VL]\n" + "zip2 z20.h, z3.h, z20.h\n" + "zip1 z19.h, z29.h, z17.h\n" + "add x22, x22, %x[out_stride]\n" + "st1h { z13.h }, p2, [x21, #5, MUL VL]\n" + "zip2 z18.h, z29.h, z17.h\n" + "zip1 z17.h, z28.h, z16.h\n" + "st1h { z12.h }, p2, [x21, #6, MUL VL]\n" + "zip2 z16.h, z28.h, z16.h\n" + "st1h { z11.h }, p2, [x21, #7, MUL VL]\n" + "addvl x21, x21, #16\n" + "st1h { z1.h }, p2, [x21, #-8, MUL VL]\n" + "st1h { z0.h }, p2, [x21, #-7, MUL VL]\n" + "st1h { z31.h }, p2, [x21, #-6, MUL VL]\n" + "st1h { z30.h }, p2, [x21, #-5, MUL VL]\n" + "st1h { z27.h }, p2, [x21, #-4, MUL VL]\n" + "st1h { z26.h }, p2, [x21, #-3, MUL VL]\n" + "st1h { z25.h }, p2, [x21, #-2, MUL VL]\n" + "st1h { z24.h }, p2, [x21, #-1, MUL VL]\n" + "st1h { z10.h }, p2, [x20]\n" + "st1h { z9.h }, p2, [x20, #1, MUL VL]\n" + "st1h { z8.h }, p2, [x20, #2, MUL VL]\n" + "st1h { z7.h }, p2, [x20, #3, MUL VL]\n" + "st1h { z6.h }, p2, [x20, #4, MUL VL]\n" + "st1h { z5.h }, p2, [x20, #5, MUL VL]\n" + "st1h { z4.h }, p2, [x20, #6, MUL VL]\n" + "st1h { z2.h }, p2, [x20, #7, MUL VL]\n" + "addvl x20, x20, #16\n" + "st1h { z23.h }, p2, [x20, #-8, MUL VL]\n" + "st1h { z22.h }, p2, [x20, #-7, MUL VL]\n" + "st1h { z21.h }, p2, [x20, #-6, MUL VL]\n" + "st1h { z20.h }, p2, [x20, #-5, MUL VL]\n" + "st1h { z19.h }, p2, [x20, #-4, MUL VL]\n" + "st1h { z18.h }, p2, [x20, #-3, MUL VL]\n" + "st1h { z17.h }, p2, [x20, #-2, MUL VL]\n" + "st1h { z16.h }, p2, [x20, #-1, MUL VL]\n" + "bge 2b\n" + "3:" // Main row loop: Unroll column loop skip + "cbz x27, 5f\n" + "4:" // Main row loop: Column loop + "mov x20, x27\n" + "whilelt p1.h, XZR, x20\n" + "ld1h { z22.h }, p1/Z, [x12]\n" + "ld1h { z21.h }, p1/Z, [x11]\n" + "dech x20\n" + "whilelt p0.h, XZR, x20\n" + "ld1h { z20.h }, p0/Z, [x12, #1, MUL VL]\n" + "ld1h { z19.h }, p0/Z, [x11, #1, MUL VL]\n" + "ld1h { z18.h }, p1/Z, [x10]\n" + "ld1h { z24.h }, p0/Z, [x10, #1, MUL VL]\n" + "mov x20, x22\n" + "decw x27, ALL, MUL #4\n" + "ld1h { z17.h }, p1/Z, [x9]\n" + "ld1h { z16.h }, p0/Z, [x9, #1, MUL VL]\n" + "zip1 z31.h, z22.h, z21.h\n" + "zip2 z23.h, z22.h, z21.h\n" + "ld1h { z30.h }, p1/Z, [x28]\n" + "ld1h { z29.h }, p0/Z, [x28, #1, MUL VL]\n" + "zip1 z22.h, z20.h, z19.h\n" + "zip2 z28.h, z20.h, z19.h\n" + "ld1h { z21.h }, p1/Z, [x25]\n" + "ld1h { z27.h }, p0/Z, [x25, #1, MUL VL]\n" + "zip1 z20.h, z18.h, z17.h\n" + "zip2 z19.h, z18.h, z17.h\n" + "ld1h { z18.h }, p1/Z, [x24]\n" + "ld1h { z26.h }, p0/Z, [x24, #1, MUL VL]\n" + "zip1 z25.h, z24.h, z16.h\n" + "zip2 z24.h, z24.h, z16.h\n" + "ld1h { z17.h }, p1/Z, [x23]\n" + "ld1h { z16.h }, p0/Z, [x23, #1, MUL VL]\n" + "st1h { z31.h }, p2, [x20]\n" + "cmp x27, #0x0\n" + "st1h { z23.h }, p2, [x20, #1, MUL VL]\n" + "addvl x12, x12, #2\n" + "addvl x11, x11, #2\n" + "zip1 z23.h, z30.h, z21.h\n" + "st1h { z22.h }, p2, [x20, #2, MUL VL]\n" + "addvl x10, x10, #2\n" + "addvl x9, x9, #2\n" + "zip2 z22.h, z30.h, z21.h\n" + "st1h { z28.h }, p2, [x20, #3, MUL VL]\n" + "addvl x28, x28, #2\n" + "addvl x25, x25, #2\n" + "zip1 z21.h, z29.h, z27.h\n" + "st1h { z20.h }, p2, [x20, #4, MUL VL]\n" + "addvl x24, x24, #2\n" + "addvl x23, x23, #2\n" + "zip2 z20.h, z29.h, z27.h\n" + "st1h { z19.h }, p2, [x20, #5, MUL VL]\n" + "zip1 z19.h, z18.h, z17.h\n" + "zip2 z18.h, z18.h, z17.h\n" + "add x22, x22, %x[out_stride]\n" + "st1h { z25.h }, p2, [x20, #6, MUL VL]\n" + "zip1 z17.h, z26.h, z16.h\n" + "zip2 z16.h, z26.h, z16.h\n" + "st1h { z24.h }, p2, [x20, #7, MUL VL]\n" + "addvl x20, x20, #16\n" + "st1h { z23.h }, p2, [x20, #-8, MUL VL]\n" + "st1h { z22.h }, p2, [x20, #-7, MUL VL]\n" + "st1h { z21.h }, p2, [x20, #-6, MUL VL]\n" + "st1h { z20.h }, p2, [x20, #-5, MUL VL]\n" + "st1h { z19.h }, p2, [x20, #-4, MUL VL]\n" + "st1h { z18.h }, p2, [x20, #-3, MUL VL]\n" + "st1h { z17.h }, p2, [x20, #-2, MUL VL]\n" + "st1h { z16.h }, p2, [x20, #-1, MUL VL]\n" + "bgt 4b\n" + "5:" // Main row loop: Column loop skip + "cmp %x[height], #0x8\n" + "addvl %x[out], %x[out], #16\n" + "bge 1b\n" + "cbz %x[height], 12f\n" + "6:" // Main loop skip + "7:" // Tail row loop: Head + "mov x12, %x[in]\n" + "mov x21, %x[width]\n" + "cnth x20, ALL, MUL #4\n" + "add x11, x12, %x[in_stride]\n" + "cmp %x[height], #0x1\n" + "add %x[in], x11, %x[in_stride]\n" + "csel x11, x11, %x[pad_row], GT\n" + "cmp x21, x20\n" + "mov x22, %x[out]\n" + "sub %x[height], %x[height], #0x2\n" + "blt 9f\n" + "8:" // Tail row loop: Unroll column loop + "ld1h { z18.h }, p2/Z, [x12]\n" + "ld1h { z20.h }, p2/Z, [x12, #1, MUL VL]\n" + "sub x21, x21, x20\n" + "cmp x21, x20\n" + "ld1h { z17.h }, p2/Z, [x11]\n" + "ld1h { z16.h }, p2/Z, [x11, #1, MUL VL]\n" + "zip1 z23.h, z18.h, z17.h\n" + "zip2 z19.h, z18.h, z17.h\n" + "ld1h { z18.h }, p2/Z, [x12, #2, MUL VL]\n" + "ld1h { z22.h }, p2/Z, [x12, #3, MUL VL]\n" + "zip1 z21.h, z20.h, z16.h\n" + "zip2 z20.h, z20.h, z16.h\n" + "ld1h { z17.h }, p2/Z, [x11, #2, MUL VL]\n" + "ld1h { z16.h }, p2/Z, [x11, #3, MUL VL]\n" + "st1h { z23.h }, p2, [x22]\n" + "addvl x12, x12, #4\n" + "st1h { z19.h }, p2, [x22, #1, MUL VL]\n" + "addvl x11, x11, #4\n" + "zip1 z19.h, z18.h, z17.h\n" + "zip2 z18.h, z18.h, z17.h\n" + "st1h { z21.h }, p2, [x22, #2, MUL VL]\n" + "zip1 z17.h, z22.h, z16.h\n" + "zip2 z16.h, z22.h, z16.h\n" + "st1h { z20.h }, p2, [x22, #3, MUL VL]\n" + "add x22, x22, %x[out_stride]\n" + "st1h { z19.h }, p2, [x22]\n" + "st1h { z18.h }, p2, [x22, #1, MUL VL]\n" + "st1h { z17.h }, p2, [x22, #2, MUL VL]\n" + "st1h { z16.h }, p2, [x22, #3, MUL VL]\n" + "add x22, x22, %x[out_stride]\n" + "bge 8b\n" + "9:" // Tail row loop: Unroll column loop skip + "cbz x21, 11f\n" + "10:" // Tail row loop: Column loop + "mov x20, x21\n" + "whilelt p0.h, XZR, x20\n" + "ld1h { z18.h }, p0/Z, [x12]\n" + "ld1h { z17.h }, p0/Z, [x11]\n" + "dech x20\n" + "whilelt p0.h, XZR, x20\n" + "ld1h { z20.h }, p0/Z, [x12, #1, MUL VL]\n" + "ld1h { z16.h }, p0/Z, [x11, #1, MUL VL]\n" + "decw x21, ALL, MUL #4\n" + "cmp x21, #0x0\n" + "zip1 z19.h, z18.h, z17.h\n" + "zip2 z18.h, z18.h, z17.h\n" + "addvl x12, x12, #2\n" + "addvl x11, x11, #2\n" + "zip1 z17.h, z20.h, z16.h\n" + "zip2 z16.h, z20.h, z16.h\n" + "st1h { z19.h }, p2, [x22]\n" + "st1h { z18.h }, p2, [x22, #1, MUL VL]\n" + "st1h { z17.h }, p2, [x22, #2, MUL VL]\n" + "st1h { z16.h }, p2, [x22, #3, MUL VL]\n" + "add x22, x22, %x[out_stride]\n" + "bgt 10b\n" + "11:" // Tail row loop: Column loop skip + "cmp %x[height], #0x1\n" + "addvl %x[out], %x[out], #4\n" + "bge 7b\n" + "12:" // Done + : [height] "+&r" (height), [in] "+&r" (in), [out] "+&r" (out) + : [in_stride] "r" (in_stride), [out_stride] "r" (out_stride), [pad_row] "r" (pad_row), [width] "r" (width) + : "cc", "memory", "p0", "p1", "p2", "x9", "x10", "x11", "x12", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27", "x28", "z0", "z1", "z2", "z3", "z4", "z5", "z6", "z7", "z8", "z9", "z10", "z11", "z12", "z13", "z14", "z15", "z16", "z17", "z18", "z19", "z20", "z21", "z22", "z23", "z24", "z25", "z26", "z27", "z28", "z29", "z30", "z31" + ); +} + +} // anonymous namespace + +template<> +void Transform<4, 2, true, VLType::SVE>( + bfloat16 *out, const bfloat16 *in, int stride, int x0, int xmax, int k0, int kmax) +{ + sve_transpose_interleave_4VL_2x2( + reinterpret_cast<uint16_t *>(out), + reinterpret_cast<const uint16_t *>(in + k0 * stride + x0), + (xmax-x0) * sizeof(bfloat16) / 2, + stride * sizeof(bfloat16), + (kmax-k0) + ); +} + + +#endif // defined(ARM_COMPUTE_ENABLE_SVE) diff --git a/src/core/NEON/kernels/arm_gemm/transforms/sve_transpose_interleave_6VL_1x8.hpp b/src/core/NEON/kernels/arm_gemm/transforms/sve_transpose_interleave_6VL_1x8.hpp new file mode 100644 index 0000000000..68fe2d0cbe --- /dev/null +++ b/src/core/NEON/kernels/arm_gemm/transforms/sve_transpose_interleave_6VL_1x8.hpp @@ -0,0 +1,295 @@ +/* + * Copyright (c) 2021, 2023 Arm Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ + +#pragma once + +#if defined(ARM_COMPUTE_ENABLE_SVE) + +namespace { + +void sve_transpose_interleave_6VL_1x8(uint8_t *out, const uint8_t *in, size_t width, size_t in_stride, size_t height) +{ + uint8_t *pad_row = reinterpret_cast<uint8_t *>(alloca(width * sizeof(uint8_t))); + + if (height % 8) { + memset(pad_row, 0, width * sizeof(uint8_t)); + } + + size_t out_stride = 6 * roundup<size_t>(height, 8) * get_vector_length<uint64_t>(); + + __asm__ __volatile__( + "ptrue p1.b\n" + "1:" // Main row loop: Head + "mov x10, %x[in]\n" + "add x9, x10, %x[in_stride]\n" + "add x28, x9, %x[in_stride]\n" + "add x27, x28, %x[in_stride]\n" + "add x26, x27, %x[in_stride]\n" + "add x25, x26, %x[in_stride]\n" + "add x24, x25, %x[in_stride]\n" + "add x23, x24, %x[in_stride]\n" + "cmp %x[height], #0x7\n" + "add %x[in], x23, %x[in_stride]\n" + "csel x23, x23, %x[pad_row], GT\n" + "csel x24, x24, %x[pad_row], GE\n" + "cmp %x[height], #0x5\n" + "mov x22, %x[width]\n" + "cntb x21, ALL, MUL #3\n" + "csel x25, x25, %x[pad_row], GT\n" + "csel x26, x26, %x[pad_row], GE\n" + "cmp %x[height], #0x3\n" + "csel x27, x27, %x[pad_row], GT\n" + "csel x28, x28, %x[pad_row], GE\n" + "cmp %x[height], #0x1\n" + "csel x9, x9, %x[pad_row], GT\n" + "cmp x22, x21\n" + "mov x20, %x[out]\n" + "sub %x[height], %x[height], #0x8\n" + "blt 3f\n" + "2:" // Main row loop: Unroll column loop + "ld1b { z21.b }, p1/Z, [x10]\n" + "ld1b { z25.b }, p1/Z, [x9]\n" + "sub x22, x22, x21\n" + "cmp x22, x21\n" + "ld1b { z20.b }, p1/Z, [x28]\n" + "ld1b { z24.b }, p1/Z, [x27]\n" + "ld1b { z19.b }, p1/Z, [x26]\n" + "ld1b { z18.b }, p1/Z, [x25]\n" + "zip1 z7.b, z21.b, z19.b\n" + "zip1 z6.b, z25.b, z18.b\n" + "ld1b { z17.b }, p1/Z, [x24]\n" + "ld1b { z16.b }, p1/Z, [x23]\n" + "zip1 z28.b, z20.b, z17.b\n" + "zip1 z27.b, z24.b, z16.b\n" + "ld1b { z23.b }, p1/Z, [x10, #1, MUL VL]\n" + "ld1b { z22.b }, p1/Z, [x9, #1, MUL VL]\n" + "zip2 z5.b, z21.b, z19.b\n" + "zip2 z4.b, z20.b, z17.b\n" + "ld1b { z21.b }, p1/Z, [x28, #1, MUL VL]\n" + "ld1b { z20.b }, p1/Z, [x27, #1, MUL VL]\n" + "zip2 z3.b, z25.b, z18.b\n" + "zip2 z2.b, z24.b, z16.b\n" + "ld1b { z19.b }, p1/Z, [x26, #1, MUL VL]\n" + "ld1b { z18.b }, p1/Z, [x25, #1, MUL VL]\n" + "zip1 z1.b, z23.b, z19.b\n" + "zip1 z15.b, z22.b, z18.b\n" + "ld1b { z17.b }, p1/Z, [x24, #1, MUL VL]\n" + "ld1b { z16.b }, p1/Z, [x23, #1, MUL VL]\n" + "zip1 z0.b, z21.b, z17.b\n" + "zip1 z31.b, z20.b, z16.b\n" + "ld1b { z26.b }, p1/Z, [x10, #2, MUL VL]\n" + "ld1b { z30.b }, p1/Z, [x9, #2, MUL VL]\n" + "zip2 z14.b, z23.b, z19.b\n" + "zip2 z13.b, z21.b, z17.b\n" + "ld1b { z25.b }, p1/Z, [x28, #2, MUL VL]\n" + "ld1b { z24.b }, p1/Z, [x27, #2, MUL VL]\n" + "zip2 z12.b, z22.b, z18.b\n" + "zip2 z11.b, z20.b, z16.b\n" + "ld1b { z23.b }, p1/Z, [x26, #2, MUL VL]\n" + "ld1b { z22.b }, p1/Z, [x25, #2, MUL VL]\n" + "zip1 z10.b, z26.b, z23.b\n" + "zip1 z9.b, z30.b, z22.b\n" + "ld1b { z21.b }, p1/Z, [x24, #2, MUL VL]\n" + "ld1b { z17.b }, p1/Z, [x23, #2, MUL VL]\n" + "zip1 z29.b, z25.b, z21.b\n" + "zip1 z8.b, z24.b, z17.b\n" + "zip1 z19.b, z7.b, z28.b\n" + "zip1 z16.b, z6.b, z27.b\n" + "addvl x10, x10, #3\n" + "addvl x9, x9, #3\n" + "zip2 z28.b, z7.b, z28.b\n" + "zip2 z18.b, z6.b, z27.b\n" + "addvl x28, x28, #3\n" + "addvl x27, x27, #3\n" + "zip1 z27.b, z5.b, z4.b\n" + "zip1 z20.b, z3.b, z2.b\n" + "addvl x26, x26, #3\n" + "addvl x25, x25, #3\n" + "zip2 z7.b, z26.b, z23.b\n" + "zip2 z26.b, z25.b, z21.b\n" + "addvl x24, x24, #3\n" + "addvl x23, x23, #3\n" + "zip2 z6.b, z30.b, z22.b\n" + "zip2 z25.b, z24.b, z17.b\n" + "zip2 z5.b, z5.b, z4.b\n" + "zip2 z4.b, z3.b, z2.b\n" + "zip1 z3.b, z1.b, z0.b\n" + "zip1 z2.b, z15.b, z31.b\n" + "zip2 z1.b, z1.b, z0.b\n" + "zip2 z0.b, z15.b, z31.b\n" + "zip1 z31.b, z14.b, z13.b\n" + "zip1 z30.b, z12.b, z11.b\n" + "zip2 z24.b, z14.b, z13.b\n" + "zip2 z23.b, z12.b, z11.b\n" + "zip1 z22.b, z10.b, z29.b\n" + "zip1 z21.b, z9.b, z8.b\n" + "zip1 z17.b, z19.b, z16.b\n" + "zip2 z16.b, z19.b, z16.b\n" + "st1b { z17.b }, p1, [x20]\n" + "zip1 z19.b, z28.b, z18.b\n" + "zip2 z18.b, z28.b, z18.b\n" + "st1b { z16.b }, p1, [x20, #1, MUL VL]\n" + "zip1 z17.b, z27.b, z20.b\n" + "zip2 z16.b, z27.b, z20.b\n" + "st1b { z19.b }, p1, [x20, #2, MUL VL]\n" + "st1b { z18.b }, p1, [x20, #3, MUL VL]\n" + "zip2 z29.b, z10.b, z29.b\n" + "zip2 z20.b, z9.b, z8.b\n" + "st1b { z17.b }, p1, [x20, #4, MUL VL]\n" + "zip1 z28.b, z7.b, z26.b\n" + "zip1 z27.b, z6.b, z25.b\n" + "st1b { z16.b }, p1, [x20, #5, MUL VL]\n" + "add x20, x20, %x[out_stride]\n" + "zip2 z26.b, z7.b, z26.b\n" + "zip2 z25.b, z6.b, z25.b\n" + "zip1 z17.b, z5.b, z4.b\n" + "zip2 z16.b, z5.b, z4.b\n" + "st1b { z17.b }, p1, [x20]\n" + "zip1 z18.b, z3.b, z2.b\n" + "zip2 z17.b, z3.b, z2.b\n" + "st1b { z16.b }, p1, [x20, #1, MUL VL]\n" + "zip1 z16.b, z1.b, z0.b\n" + "zip2 z19.b, z1.b, z0.b\n" + "st1b { z18.b }, p1, [x20, #2, MUL VL]\n" + "st1b { z17.b }, p1, [x20, #3, MUL VL]\n" + "zip1 z18.b, z31.b, z30.b\n" + "zip2 z17.b, z31.b, z30.b\n" + "st1b { z16.b }, p1, [x20, #4, MUL VL]\n" + "zip1 z16.b, z24.b, z23.b\n" + "zip2 z24.b, z24.b, z23.b\n" + "st1b { z19.b }, p1, [x20, #5, MUL VL]\n" + "add x20, x20, %x[out_stride]\n" + "zip1 z23.b, z22.b, z21.b\n" + "zip2 z22.b, z22.b, z21.b\n" + "st1b { z18.b }, p1, [x20]\n" + "zip1 z21.b, z29.b, z20.b\n" + "zip2 z20.b, z29.b, z20.b\n" + "st1b { z17.b }, p1, [x20, #1, MUL VL]\n" + "zip1 z19.b, z28.b, z27.b\n" + "zip2 z18.b, z28.b, z27.b\n" + "st1b { z16.b }, p1, [x20, #2, MUL VL]\n" + "zip1 z17.b, z26.b, z25.b\n" + "zip2 z16.b, z26.b, z25.b\n" + "st1b { z24.b }, p1, [x20, #3, MUL VL]\n" + "st1b { z23.b }, p1, [x20, #4, MUL VL]\n" + "st1b { z22.b }, p1, [x20, #5, MUL VL]\n" + "add x20, x20, %x[out_stride]\n" + "st1b { z21.b }, p1, [x20]\n" + "st1b { z20.b }, p1, [x20, #1, MUL VL]\n" + "st1b { z19.b }, p1, [x20, #2, MUL VL]\n" + "st1b { z18.b }, p1, [x20, #3, MUL VL]\n" + "st1b { z17.b }, p1, [x20, #4, MUL VL]\n" + "st1b { z16.b }, p1, [x20, #5, MUL VL]\n" + "add x20, x20, %x[out_stride]\n" + "bge 2b\n" + "3:" // Main row loop: Unroll column loop skip + "cbz x22, 5f\n" + "4:" // Main row loop: Column loop + "whilelt p0.b, XZR, x22\n" + "ld1b { z23.b }, p0/Z, [x10]\n" + "ld1b { z27.b }, p0/Z, [x9]\n" + "decd x22, ALL, MUL #6\n" + "ld1b { z21.b }, p0/Z, [x28]\n" + "ld1b { z26.b }, p0/Z, [x27]\n" + "cmp x22, #0x0\n" + "incd x10, ALL, MUL #6\n" + "ld1b { z20.b }, p0/Z, [x26]\n" + "ld1b { z19.b }, p0/Z, [x25]\n" + "zip1 z25.b, z23.b, z20.b\n" + "zip1 z24.b, z27.b, z19.b\n" + "ld1b { z17.b }, p0/Z, [x24]\n" + "ld1b { z16.b }, p0/Z, [x23]\n" + "zip1 z22.b, z21.b, z17.b\n" + "zip1 z18.b, z26.b, z16.b\n" + "zip2 z23.b, z23.b, z20.b\n" + "zip2 z21.b, z21.b, z17.b\n" + "incd x9, ALL, MUL #6\n" + "incd x28, ALL, MUL #6\n" + "zip2 z20.b, z27.b, z19.b\n" + "zip2 z17.b, z26.b, z16.b\n" + "incd x27, ALL, MUL #6\n" + "incd x26, ALL, MUL #6\n" + "zip1 z19.b, z25.b, z22.b\n" + "zip1 z16.b, z24.b, z18.b\n" + "incd x25, ALL, MUL #6\n" + "incd x24, ALL, MUL #6\n" + "zip2 z22.b, z25.b, z22.b\n" + "zip2 z18.b, z24.b, z18.b\n" + "incd x23, ALL, MUL #6\n" + "zip1 z21.b, z23.b, z21.b\n" + "zip1 z20.b, z20.b, z17.b\n" + "zip1 z17.b, z19.b, z16.b\n" + "zip2 z16.b, z19.b, z16.b\n" + "st1b { z17.b }, p1, [x20]\n" + "zip1 z19.b, z22.b, z18.b\n" + "zip2 z18.b, z22.b, z18.b\n" + "st1b { z16.b }, p1, [x20, #1, MUL VL]\n" + "zip1 z17.b, z21.b, z20.b\n" + "zip2 z16.b, z21.b, z20.b\n" + "st1b { z19.b }, p1, [x20, #2, MUL VL]\n" + "st1b { z18.b }, p1, [x20, #3, MUL VL]\n" + "st1b { z17.b }, p1, [x20, #4, MUL VL]\n" + "st1b { z16.b }, p1, [x20, #5, MUL VL]\n" + "add x20, x20, %x[out_stride]\n" + "bgt 4b\n" + "5:" // Main row loop: Column loop skip + "cmp %x[height], #0x1\n" + "addvl %x[out], %x[out], #6\n" + "bge 1b\n" + : [height] "+&r" (height), [in] "+&r" (in), [out] "+&r" (out) + : [in_stride] "r" (in_stride), [out_stride] "r" (out_stride), [pad_row] "r" (pad_row), [width] "r" (width) + : "cc", "memory", "p0", "p1", "x9", "x10", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27", "x28", "z0", "z1", "z2", "z3", "z4", "z5", "z6", "z7", "z8", "z9", "z10", "z11", "z12", "z13", "z14", "z15", "z16", "z17", "z18", "z19", "z20", "z21", "z22", "z23", "z24", "z25", "z26", "z27", "z28", "z29", "z30", "z31" + ); +} + +} // anonymous namespace + +template<> +void Transform<6, 8, true, VLType::SVE>( + uint8_t *out, const uint8_t *in, int stride, int x0, int xmax, int k0, int kmax) +{ + sve_transpose_interleave_6VL_1x8( + reinterpret_cast<uint8_t *>(out), + reinterpret_cast<const uint8_t *>(in + k0 * stride + x0), + (xmax-x0) * sizeof(uint8_t) / 1, + stride * sizeof(uint8_t), + (kmax-k0) + ); +} + +template<> +void Transform<6, 8, true, VLType::SVE>( + int8_t *out, const int8_t *in, int stride, int x0, int xmax, int k0, int kmax) +{ + sve_transpose_interleave_6VL_1x8( + reinterpret_cast<uint8_t *>(out), + reinterpret_cast<const uint8_t *>(in + k0 * stride + x0), + (xmax-x0) * sizeof(int8_t) / 1, + stride * sizeof(int8_t), + (kmax-k0) + ); +} + + +#endif // defined(ARM_COMPUTE_ENABLE_SVE) diff --git a/src/core/NEON/kernels/arm_gemm/transforms/sve_transpose_interleave_6VL_2x4.hpp b/src/core/NEON/kernels/arm_gemm/transforms/sve_transpose_interleave_6VL_2x4.hpp new file mode 100644 index 0000000000..910fc6cb02 --- /dev/null +++ b/src/core/NEON/kernels/arm_gemm/transforms/sve_transpose_interleave_6VL_2x4.hpp @@ -0,0 +1,409 @@ +/* + * Copyright (c) 2021, 2023 Arm Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ + +#pragma once + +#if defined(ARM_COMPUTE_ENABLE_SVE) + +namespace { + +void sve_transpose_interleave_6VL_2x4(uint16_t *out, const uint16_t *in, size_t width, size_t in_stride, size_t height) +{ + uint16_t *pad_row = reinterpret_cast<uint16_t *>(alloca(width * sizeof(uint16_t))); + + if (height % 4) { + memset(pad_row, 0, width * sizeof(uint16_t)); + } + + size_t out_stride = 6 * roundup<size_t>(height, 4) * get_vector_length<uint32_t>(); + + __asm__ __volatile__( + "cmp %x[height], #0x8\n" + "ptrue p2.b\n" + "blt 6f\n" + "1:" // Main row loop: Head + "mov x12, %x[in]\n" + "add x11, x12, %x[in_stride]\n" + "add x10, x11, %x[in_stride]\n" + "add x9, x10, %x[in_stride]\n" + "add x28, x9, %x[in_stride]\n" + "mov x27, %x[width]\n" + "cnth x26, ALL, MUL #3\n" + "add x25, x28, %x[in_stride]\n" + "add x24, x25, %x[in_stride]\n" + "add x23, x24, %x[in_stride]\n" + "cmp x27, x26\n" + "add %x[in], x23, %x[in_stride]\n" + "mov x22, %x[out]\n" + "sub %x[height], %x[height], #0x8\n" + "blt 3f\n" + "2:" // Main row loop: Unroll column loop + "ld1h { z18.h }, p2/Z, [x12]\n" + "ld1h { z13.h }, p2/Z, [x12, #1, MUL VL]\n" + "mov x21, x22\n" + "add x22, x22, %x[out_stride]\n" + "ld1h { z17.h }, p2/Z, [x11]\n" + "ld1h { z12.h }, p2/Z, [x11, #1, MUL VL]\n" + "mov x20, x22\n" + "sub x27, x27, x26\n" + "ld1h { z16.h }, p2/Z, [x10]\n" + "ld1h { z11.h }, p2/Z, [x10, #1, MUL VL]\n" + "zip1 z23.h, z18.h, z16.h\n" + "zip2 z29.h, z18.h, z16.h\n" + "ld1h { z16.h }, p2/Z, [x9]\n" + "ld1h { z10.h }, p2/Z, [x9, #1, MUL VL]\n" + "zip1 z22.h, z17.h, z16.h\n" + "zip2 z28.h, z17.h, z16.h\n" + "ld1h { z27.h }, p2/Z, [x28]\n" + "ld1h { z26.h }, p2/Z, [x25]\n" + "zip1 z21.h, z13.h, z11.h\n" + "zip1 z20.h, z12.h, z10.h\n" + "ld1h { z18.h }, p2/Z, [x24]\n" + "ld1h { z19.h }, p2/Z, [x23]\n" + "zip1 z17.h, z27.h, z18.h\n" + "zip1 z16.h, z26.h, z19.h\n" + "ld1h { z9.h }, p2/Z, [x12, #2, MUL VL]\n" + "ld1h { z8.h }, p2/Z, [x11, #2, MUL VL]\n" + "zip1 z25.h, z23.h, z22.h\n" + "zip2 z24.h, z23.h, z22.h\n" + "ld1h { z23.h }, p2/Z, [x10, #2, MUL VL]\n" + "ld1h { z7.h }, p2/Z, [x9, #2, MUL VL]\n" + "zip1 z22.h, z29.h, z28.h\n" + "zip2 z6.h, z29.h, z28.h\n" + "ld1h { z28.h }, p2/Z, [x28, #1, MUL VL]\n" + "ld1h { z5.h }, p2/Z, [x25, #1, MUL VL]\n" + "zip1 z4.h, z21.h, z20.h\n" + "zip2 z3.h, z21.h, z20.h\n" + "ld1h { z21.h }, p2/Z, [x24, #1, MUL VL]\n" + "ld1h { z20.h }, p2/Z, [x23, #1, MUL VL]\n" + "zip1 z2.h, z17.h, z16.h\n" + "zip2 z1.h, z17.h, z16.h\n" + "ld1h { z0.h }, p2/Z, [x28, #2, MUL VL]\n" + "ld1h { z31.h }, p2/Z, [x25, #2, MUL VL]\n" + "zip2 z18.h, z27.h, z18.h\n" + "zip2 z17.h, z26.h, z19.h\n" + "ld1h { z30.h }, p2/Z, [x24, #2, MUL VL]\n" + "ld1h { z29.h }, p2/Z, [x23, #2, MUL VL]\n" + "zip1 z19.h, z28.h, z21.h\n" + "zip1 z16.h, z5.h, z20.h\n" + "st1h { z25.h }, p2, [x21]\n" + "zip2 z27.h, z13.h, z11.h\n" + "zip2 z26.h, z12.h, z10.h\n" + "cmp x27, x26\n" + "st1h { z24.h }, p2, [x21, #1, MUL VL]\n" + "zip1 z25.h, z9.h, z23.h\n" + "zip1 z24.h, z8.h, z7.h\n" + "addvl x12, x12, #3\n" + "st1h { z22.h }, p2, [x21, #2, MUL VL]\n" + "zip2 z23.h, z9.h, z23.h\n" + "zip2 z22.h, z8.h, z7.h\n" + "addvl x11, x11, #3\n" + "st1h { z6.h }, p2, [x21, #3, MUL VL]\n" + "zip2 z28.h, z28.h, z21.h\n" + "zip2 z21.h, z5.h, z20.h\n" + "addvl x10, x10, #3\n" + "st1h { z4.h }, p2, [x21, #4, MUL VL]\n" + "zip1 z20.h, z18.h, z17.h\n" + "zip2 z18.h, z18.h, z17.h\n" + "addvl x9, x9, #3\n" + "st1h { z3.h }, p2, [x21, #5, MUL VL]\n" + "zip1 z17.h, z19.h, z16.h\n" + "zip2 z16.h, z19.h, z16.h\n" + "addvl x28, x28, #3\n" + "st1h { z2.h }, p2, [x21, #6, MUL VL]\n" + "zip1 z19.h, z27.h, z26.h\n" + "zip2 z27.h, z27.h, z26.h\n" + "addvl x25, x25, #3\n" + "st1h { z1.h }, p2, [x21, #7, MUL VL]\n" + "addvl x21, x21, #12\n" + "zip1 z26.h, z25.h, z24.h\n" + "zip2 z25.h, z25.h, z24.h\n" + "st1h { z20.h }, p2, [x21, #-4, MUL VL]\n" + "zip1 z24.h, z23.h, z22.h\n" + "zip2 z23.h, z23.h, z22.h\n" + "addvl x24, x24, #3\n" + "st1h { z18.h }, p2, [x21, #-3, MUL VL]\n" + "zip1 z22.h, z28.h, z21.h\n" + "zip2 z21.h, z28.h, z21.h\n" + "addvl x23, x23, #3\n" + "st1h { z17.h }, p2, [x21, #-2, MUL VL]\n" + "zip1 z18.h, z0.h, z30.h\n" + "zip1 z17.h, z31.h, z29.h\n" + "add x22, x22, %x[out_stride]\n" + "st1h { z16.h }, p2, [x21, #-1, MUL VL]\n" + "zip2 z20.h, z0.h, z30.h\n" + "zip2 z16.h, z31.h, z29.h\n" + "st1h { z19.h }, p2, [x20]\n" + "zip1 z19.h, z18.h, z17.h\n" + "zip2 z18.h, z18.h, z17.h\n" + "st1h { z27.h }, p2, [x20, #1, MUL VL]\n" + "zip1 z17.h, z20.h, z16.h\n" + "zip2 z16.h, z20.h, z16.h\n" + "st1h { z26.h }, p2, [x20, #2, MUL VL]\n" + "st1h { z25.h }, p2, [x20, #3, MUL VL]\n" + "st1h { z24.h }, p2, [x20, #4, MUL VL]\n" + "st1h { z23.h }, p2, [x20, #5, MUL VL]\n" + "st1h { z22.h }, p2, [x20, #6, MUL VL]\n" + "st1h { z21.h }, p2, [x20, #7, MUL VL]\n" + "addvl x20, x20, #12\n" + "st1h { z19.h }, p2, [x20, #-4, MUL VL]\n" + "st1h { z18.h }, p2, [x20, #-3, MUL VL]\n" + "st1h { z17.h }, p2, [x20, #-2, MUL VL]\n" + "st1h { z16.h }, p2, [x20, #-1, MUL VL]\n" + "bge 2b\n" + "3:" // Main row loop: Unroll column loop skip + "cbz x27, 5f\n" + "4:" // Main row loop: Column loop + "mov x20, x27\n" + "whilelt p1.h, XZR, x20\n" + "ld1h { z17.h }, p1/Z, [x12]\n" + "ld1h { z19.h }, p1/Z, [x11]\n" + "dech x20\n" + "whilelt p0.h, XZR, x20\n" + "ld1h { z22.h }, p0/Z, [x12, #1, MUL VL]\n" + "ld1h { z21.h }, p0/Z, [x11, #1, MUL VL]\n" + "ld1h { z16.h }, p1/Z, [x10]\n" + "ld1h { z20.h }, p0/Z, [x10, #1, MUL VL]\n" + "zip1 z25.h, z17.h, z16.h\n" + "zip2 z24.h, z17.h, z16.h\n" + "ld1h { z18.h }, p1/Z, [x9]\n" + "ld1h { z17.h }, p0/Z, [x9, #1, MUL VL]\n" + "zip1 z16.h, z19.h, z18.h\n" + "zip2 z19.h, z19.h, z18.h\n" + "ld1h { z0.h }, p1/Z, [x28]\n" + "ld1h { z31.h }, p1/Z, [x25]\n" + "zip1 z23.h, z22.h, z20.h\n" + "zip1 z22.h, z21.h, z17.h\n" + "ld1h { z30.h }, p1/Z, [x24]\n" + "ld1h { z29.h }, p1/Z, [x23]\n" + "zip1 z21.h, z0.h, z30.h\n" + "zip1 z18.h, z31.h, z29.h\n" + "ld1h { z28.h }, p0/Z, [x28, #1, MUL VL]\n" + "ld1h { z27.h }, p0/Z, [x25, #1, MUL VL]\n" + "mov x20, x22\n" + "decd x27, ALL, MUL #6\n" + "ld1h { z20.h }, p0/Z, [x24, #1, MUL VL]\n" + "ld1h { z26.h }, p0/Z, [x23, #1, MUL VL]\n" + "addvl x12, x12, #1\n" + "addvl x11, x11, #1\n" + "addvl x10, x10, #1\n" + "addvl x9, x9, #1\n" + "zip1 z17.h, z25.h, z16.h\n" + "zip2 z16.h, z25.h, z16.h\n" + "addvl x28, x28, #1\n" + "addvl x25, x25, #1\n" + "zip1 z25.h, z24.h, z19.h\n" + "zip2 z19.h, z24.h, z19.h\n" + "addvl x24, x24, #1\n" + "addvl x23, x23, #1\n" + "zip1 z24.h, z23.h, z22.h\n" + "zip2 z23.h, z23.h, z22.h\n" + "zip1 z22.h, z21.h, z18.h\n" + "zip2 z21.h, z21.h, z18.h\n" + "st1h { z17.h }, p2, [x20]\n" + "cmp x27, #0x0\n" + "zip2 z18.h, z0.h, z30.h\n" + "zip2 z17.h, z31.h, z29.h\n" + "st1h { z16.h }, p2, [x20, #1, MUL VL]\n" + "incd x12, ALL, MUL #4\n" + "zip1 z20.h, z28.h, z20.h\n" + "zip1 z16.h, z27.h, z26.h\n" + "st1h { z25.h }, p2, [x20, #2, MUL VL]\n" + "incd x11, ALL, MUL #4\n" + "st1h { z19.h }, p2, [x20, #3, MUL VL]\n" + "incd x10, ALL, MUL #4\n" + "incd x9, ALL, MUL #4\n" + "zip1 z19.h, z18.h, z17.h\n" + "st1h { z24.h }, p2, [x20, #4, MUL VL]\n" + "incd x28, ALL, MUL #4\n" + "incd x25, ALL, MUL #4\n" + "zip2 z18.h, z18.h, z17.h\n" + "st1h { z23.h }, p2, [x20, #5, MUL VL]\n" + "incd x24, ALL, MUL #4\n" + "incd x23, ALL, MUL #4\n" + "zip1 z17.h, z20.h, z16.h\n" + "st1h { z22.h }, p2, [x20, #6, MUL VL]\n" + "zip2 z16.h, z20.h, z16.h\n" + "add x22, x22, %x[out_stride]\n" + "st1h { z21.h }, p2, [x20, #7, MUL VL]\n" + "addvl x20, x20, #12\n" + "st1h { z19.h }, p2, [x20, #-4, MUL VL]\n" + "st1h { z18.h }, p2, [x20, #-3, MUL VL]\n" + "st1h { z17.h }, p2, [x20, #-2, MUL VL]\n" + "st1h { z16.h }, p2, [x20, #-1, MUL VL]\n" + "bgt 4b\n" + "5:" // Main row loop: Column loop skip + "cmp %x[height], #0x8\n" + "addvl %x[out], %x[out], #12\n" + "bge 1b\n" + "cbz %x[height], 12f\n" + "6:" // Main loop skip + "7:" // Tail row loop: Head + "mov x12, %x[in]\n" + "add x11, x12, %x[in_stride]\n" + "add x10, x11, %x[in_stride]\n" + "mov x21, %x[width]\n" + "cnth x20, ALL, MUL #3\n" + "add x9, x10, %x[in_stride]\n" + "cmp %x[height], #0x3\n" + "add %x[in], x9, %x[in_stride]\n" + "csel x9, x9, %x[pad_row], GT\n" + "csel x10, x10, %x[pad_row], GE\n" + "cmp %x[height], #0x1\n" + "csel x11, x11, %x[pad_row], GT\n" + "cmp x21, x20\n" + "mov x22, %x[out]\n" + "sub %x[height], %x[height], #0x4\n" + "blt 9f\n" + "8:" // Tail row loop: Unroll column loop + "ld1h { z18.h }, p2/Z, [x12]\n" + "ld1h { z24.h }, p2/Z, [x12, #1, MUL VL]\n" + "sub x21, x21, x20\n" + "cmp x21, x20\n" + "ld1h { z17.h }, p2/Z, [x11]\n" + "ld1h { z23.h }, p2/Z, [x11, #1, MUL VL]\n" + "ld1h { z16.h }, p2/Z, [x10]\n" + "ld1h { z22.h }, p2/Z, [x10, #1, MUL VL]\n" + "zip1 z31.h, z18.h, z16.h\n" + "zip2 z30.h, z18.h, z16.h\n" + "ld1h { z16.h }, p2/Z, [x9]\n" + "ld1h { z20.h }, p2/Z, [x9, #1, MUL VL]\n" + "zip1 z29.h, z17.h, z16.h\n" + "zip2 z28.h, z17.h, z16.h\n" + "ld1h { z19.h }, p2/Z, [x12, #2, MUL VL]\n" + "ld1h { z18.h }, p2/Z, [x11, #2, MUL VL]\n" + "zip1 z27.h, z24.h, z22.h\n" + "zip1 z21.h, z23.h, z20.h\n" + "ld1h { z17.h }, p2/Z, [x10, #2, MUL VL]\n" + "ld1h { z16.h }, p2/Z, [x9, #2, MUL VL]\n" + "zip2 z26.h, z24.h, z22.h\n" + "zip2 z20.h, z23.h, z20.h\n" + "zip1 z25.h, z19.h, z17.h\n" + "zip1 z24.h, z18.h, z16.h\n" + "addvl x12, x12, #3\n" + "addvl x11, x11, #3\n" + "zip2 z23.h, z19.h, z17.h\n" + "zip2 z22.h, z18.h, z16.h\n" + "addvl x10, x10, #3\n" + "addvl x9, x9, #3\n" + "zip1 z17.h, z31.h, z29.h\n" + "zip2 z16.h, z31.h, z29.h\n" + "st1h { z17.h }, p2, [x22]\n" + "zip1 z19.h, z30.h, z28.h\n" + "zip2 z18.h, z30.h, z28.h\n" + "st1h { z16.h }, p2, [x22, #1, MUL VL]\n" + "zip1 z17.h, z27.h, z21.h\n" + "zip2 z16.h, z27.h, z21.h\n" + "st1h { z19.h }, p2, [x22, #2, MUL VL]\n" + "st1h { z18.h }, p2, [x22, #3, MUL VL]\n" + "zip1 z21.h, z26.h, z20.h\n" + "zip2 z20.h, z26.h, z20.h\n" + "st1h { z17.h }, p2, [x22, #4, MUL VL]\n" + "zip1 z19.h, z25.h, z24.h\n" + "zip2 z18.h, z25.h, z24.h\n" + "st1h { z16.h }, p2, [x22, #5, MUL VL]\n" + "add x22, x22, %x[out_stride]\n" + "zip1 z17.h, z23.h, z22.h\n" + "zip2 z16.h, z23.h, z22.h\n" + "st1h { z21.h }, p2, [x22]\n" + "st1h { z20.h }, p2, [x22, #1, MUL VL]\n" + "st1h { z19.h }, p2, [x22, #2, MUL VL]\n" + "st1h { z18.h }, p2, [x22, #3, MUL VL]\n" + "st1h { z17.h }, p2, [x22, #4, MUL VL]\n" + "st1h { z16.h }, p2, [x22, #5, MUL VL]\n" + "add x22, x22, %x[out_stride]\n" + "bge 8b\n" + "9:" // Tail row loop: Unroll column loop skip + "cbz x21, 11f\n" + "10:" // Tail row loop: Column loop + "mov x20, x21\n" + "whilelt p1.h, XZR, x20\n" + "ld1h { z22.h }, p1/Z, [x12]\n" + "ld1h { z25.h }, p1/Z, [x11]\n" + "dech x20\n" + "whilelt p0.h, XZR, x20\n" + "ld1h { z24.h }, p0/Z, [x12, #1, MUL VL]\n" + "ld1h { z23.h }, p0/Z, [x11, #1, MUL VL]\n" + "ld1h { z21.h }, p1/Z, [x10]\n" + "ld1h { z20.h }, p0/Z, [x10, #1, MUL VL]\n" + "decd x21, ALL, MUL #6\n" + "addvl x12, x12, #1\n" + "ld1h { z18.h }, p1/Z, [x9]\n" + "ld1h { z17.h }, p0/Z, [x9, #1, MUL VL]\n" + "addvl x11, x11, #1\n" + "addvl x10, x10, #1\n" + "addvl x9, x9, #1\n" + "zip1 z19.h, z22.h, z21.h\n" + "zip1 z16.h, z25.h, z18.h\n" + "cmp x21, #0x0\n" + "zip2 z22.h, z22.h, z21.h\n" + "zip2 z18.h, z25.h, z18.h\n" + "incd x12, ALL, MUL #4\n" + "incd x11, ALL, MUL #4\n" + "zip1 z21.h, z24.h, z20.h\n" + "zip1 z20.h, z23.h, z17.h\n" + "incd x10, ALL, MUL #4\n" + "incd x9, ALL, MUL #4\n" + "zip1 z17.h, z19.h, z16.h\n" + "zip2 z16.h, z19.h, z16.h\n" + "st1h { z17.h }, p2, [x22]\n" + "zip1 z19.h, z22.h, z18.h\n" + "zip2 z18.h, z22.h, z18.h\n" + "st1h { z16.h }, p2, [x22, #1, MUL VL]\n" + "zip1 z17.h, z21.h, z20.h\n" + "zip2 z16.h, z21.h, z20.h\n" + "st1h { z19.h }, p2, [x22, #2, MUL VL]\n" + "st1h { z18.h }, p2, [x22, #3, MUL VL]\n" + "st1h { z17.h }, p2, [x22, #4, MUL VL]\n" + "st1h { z16.h }, p2, [x22, #5, MUL VL]\n" + "add x22, x22, %x[out_stride]\n" + "bgt 10b\n" + "11:" // Tail row loop: Column loop skip + "cmp %x[height], #0x1\n" + "addvl %x[out], %x[out], #6\n" + "bge 7b\n" + "12:" // Done + : [height] "+&r" (height), [in] "+&r" (in), [out] "+&r" (out) + : [in_stride] "r" (in_stride), [out_stride] "r" (out_stride), [pad_row] "r" (pad_row), [width] "r" (width) + : "cc", "memory", "p0", "p1", "p2", "x9", "x10", "x11", "x12", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27", "x28", "z0", "z1", "z2", "z3", "z4", "z5", "z6", "z7", "z8", "z9", "z10", "z11", "z12", "z13", "z16", "z17", "z18", "z19", "z20", "z21", "z22", "z23", "z24", "z25", "z26", "z27", "z28", "z29", "z30", "z31" + ); +} + +} // anonymous namespace + +template<> +void Transform<6, 4, true, VLType::SVE>( + bfloat16 *out, const bfloat16 *in, int stride, int x0, int xmax, int k0, int kmax) +{ + sve_transpose_interleave_6VL_2x4( + reinterpret_cast<uint16_t *>(out), + reinterpret_cast<const uint16_t *>(in + k0 * stride + x0), + (xmax-x0) * sizeof(bfloat16) / 2, + stride * sizeof(bfloat16), + (kmax-k0) + ); +} + + +#endif // defined(ARM_COMPUTE_ENABLE_SVE) diff --git a/src/core/NEON/kernels/arm_gemm/transforms/sve_transpose_interleave_6VL_2x4_fp32bf16.hpp b/src/core/NEON/kernels/arm_gemm/transforms/sve_transpose_interleave_6VL_2x4_fp32bf16.hpp new file mode 100644 index 0000000000..f0f10d2f43 --- /dev/null +++ b/src/core/NEON/kernels/arm_gemm/transforms/sve_transpose_interleave_6VL_2x4_fp32bf16.hpp @@ -0,0 +1,238 @@ +/* + * Copyright (c) 2021, 2023 Arm Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ + +#pragma once + +#if defined(ARM_COMPUTE_ENABLE_SVE) + +namespace { + +void sve_transpose_interleave_6VL_2x4_fp32bf16(bfloat16 *out, const float *in, size_t width, size_t in_stride, size_t height) +{ + float *pad_row = reinterpret_cast<float *>(alloca(width * sizeof(float))); + + if (height % 4) { + memset(pad_row, 0, width * sizeof(float)); + } + + size_t out_stride = 6 * roundup<size_t>(height, 4) * get_vector_length<uint32_t>(); + + __asm__ __volatile__( + "ptrue p3.b\n" + "1:" // Main row loop: Head + "mov x26, %x[in]\n" + "add x25, x26, %x[in_stride]\n" + "add x24, x25, %x[in_stride]\n" + "mov x23, %x[width]\n" + "cnth x20, ALL, MUL #3\n" + "add x22, x24, %x[in_stride]\n" + "cmp %x[height], #0x3\n" + "add %x[in], x22, %x[in_stride]\n" + "csel x22, x22, %x[pad_row], GT\n" + "csel x24, x24, %x[pad_row], GE\n" + "cmp %x[height], #0x1\n" + "csel x25, x25, %x[pad_row], GT\n" + "cmp x23, x20\n" + "mov x21, %x[out]\n" + "sub %x[height], %x[height], #0x4\n" + "blt 3f\n" + "2:" // Main row loop: Unroll column loop + "ld1w { z17.s }, p3/Z, [x26]\n" + "ld1w { z18.s }, p3/Z, [x26, #1, MUL VL]\n" + "sub x23, x23, x20\n" + "cmp x23, x20\n" + "ld1w { z19.s }, p3/Z, [x26, #2, MUL VL]\n" + "ld1w { z16.s }, p3/Z, [x24]\n" + "zip1 z21.s, z17.s, z16.s\n" + "zip2 z20.s, z17.s, z16.s\n" + "ld1w { z17.s }, p3/Z, [x24, #1, MUL VL]\n" + "ld1w { z16.s }, p3/Z, [x24, #2, MUL VL]\n" + "zip1 z29.s, z18.s, z17.s\n" + "zip2 z28.s, z18.s, z17.s\n" + "ld1w { z17.s }, p3/Z, [x26, #3, MUL VL]\n" + "ld1w { z18.s }, p3/Z, [x26, #4, MUL VL]\n" + "zip1 z27.s, z19.s, z16.s\n" + "zip2 z26.s, z19.s, z16.s\n" + "ld1w { z19.s }, p3/Z, [x26, #5, MUL VL]\n" + "ld1w { z16.s }, p3/Z, [x24, #3, MUL VL]\n" + "zip1 z25.s, z17.s, z16.s\n" + "zip2 z24.s, z17.s, z16.s\n" + "ld1w { z17.s }, p3/Z, [x24, #4, MUL VL]\n" + "ld1w { z16.s }, p3/Z, [x24, #5, MUL VL]\n" + "zip1 z12.s, z18.s, z17.s\n" + "zip2 z11.s, z18.s, z17.s\n" + "ld1w { z18.s }, p3/Z, [x25]\n" + "ld1w { z23.s }, p3/Z, [x25, #1, MUL VL]\n" + "zip1 z10.s, z19.s, z16.s\n" + "zip2 z9.s, z19.s, z16.s\n" + "ld1w { z22.s }, p3/Z, [x25, #2, MUL VL]\n" + "ld1w { z17.s }, p3/Z, [x22]\n" + ".inst 0x658aaea8 // bfcvt z8.h, p3/M, z21.s\n" + "zip1 z7.s, z18.s, z17.s\n" + "ld1w { z16.s }, p3/Z, [x22, #1, MUL VL]\n" + "ld1w { z21.s }, p3/Z, [x22, #2, MUL VL]\n" + ".inst 0x658aae86 // bfcvt z6.h, p3/M, z20.s\n" + "zip2 z5.s, z18.s, z17.s\n" + "ld1w { z20.s }, p3/Z, [x25, #3, MUL VL]\n" + "ld1w { z19.s }, p3/Z, [x25, #4, MUL VL]\n" + ".inst 0x658aafa4 // bfcvt z4.h, p3/M, z29.s\n" + "zip1 z3.s, z23.s, z16.s\n" + "ld1w { z2.s }, p3/Z, [x25, #5, MUL VL]\n" + "ld1w { z18.s }, p3/Z, [x22, #3, MUL VL]\n" + ".inst 0x658aaf81 // bfcvt z1.h, p3/M, z28.s\n" + "zip2 z0.s, z23.s, z16.s\n" + "ld1w { z17.s }, p3/Z, [x22, #4, MUL VL]\n" + "ld1w { z16.s }, p3/Z, [x22, #5, MUL VL]\n" + ".inst 0x658aaf7f // bfcvt z31.h, p3/M, z27.s\n" + "zip1 z30.s, z22.s, z21.s\n" + ".inst 0x658aaf5d // bfcvt z29.h, p3/M, z26.s\n" + "zip2 z28.s, z22.s, z21.s\n" + "addvl x26, x26, #6\n" + "addvl x25, x25, #6\n" + ".inst 0x658aaf3b // bfcvt z27.h, p3/M, z25.s\n" + "zip1 z26.s, z20.s, z18.s\n" + "addvl x24, x24, #6\n" + "addvl x22, x22, #6\n" + ".inst 0x658aaf19 // bfcvt z25.h, p3/M, z24.s\n" + "zip2 z24.s, z20.s, z18.s\n" + ".inst 0x658aad97 // bfcvt z23.h, p3/M, z12.s\n" + "zip1 z22.s, z19.s, z17.s\n" + ".inst 0x658aad75 // bfcvt z21.h, p3/M, z11.s\n" + "zip2 z20.s, z19.s, z17.s\n" + ".inst 0x658aad53 // bfcvt z19.h, p3/M, z10.s\n" + "zip1 z18.s, z2.s, z16.s\n" + ".inst 0x658aad31 // bfcvt z17.h, p3/M, z9.s\n" + "zip2 z16.s, z2.s, z16.s\n" + ".inst 0x648aace8 // bfcvtnt z8.h, p3/M, z7.s\n" + ".inst 0x648aaca6 // bfcvtnt z6.h, p3/M, z5.s\n" + "st1h { z8.h }, p3, [x21]\n" + ".inst 0x648aac64 // bfcvtnt z4.h, p3/M, z3.s\n" + ".inst 0x648aac01 // bfcvtnt z1.h, p3/M, z0.s\n" + "st1h { z6.h }, p3, [x21, #1, MUL VL]\n" + ".inst 0x648aafdf // bfcvtnt z31.h, p3/M, z30.s\n" + ".inst 0x648aaf9d // bfcvtnt z29.h, p3/M, z28.s\n" + "st1h { z4.h }, p3, [x21, #2, MUL VL]\n" + "st1h { z1.h }, p3, [x21, #3, MUL VL]\n" + ".inst 0x648aaf5b // bfcvtnt z27.h, p3/M, z26.s\n" + ".inst 0x648aaf19 // bfcvtnt z25.h, p3/M, z24.s\n" + "st1h { z31.h }, p3, [x21, #4, MUL VL]\n" + ".inst 0x648aaed7 // bfcvtnt z23.h, p3/M, z22.s\n" + ".inst 0x648aae95 // bfcvtnt z21.h, p3/M, z20.s\n" + "st1h { z29.h }, p3, [x21, #5, MUL VL]\n" + "add x21, x21, %x[out_stride]\n" + ".inst 0x648aae53 // bfcvtnt z19.h, p3/M, z18.s\n" + ".inst 0x648aae11 // bfcvtnt z17.h, p3/M, z16.s\n" + "st1h { z27.h }, p3, [x21]\n" + "st1h { z25.h }, p3, [x21, #1, MUL VL]\n" + "st1h { z23.h }, p3, [x21, #2, MUL VL]\n" + "st1h { z21.h }, p3, [x21, #3, MUL VL]\n" + "st1h { z19.h }, p3, [x21, #4, MUL VL]\n" + "st1h { z17.h }, p3, [x21, #5, MUL VL]\n" + "add x21, x21, %x[out_stride]\n" + "bge 2b\n" + "3:" // Main row loop: Unroll column loop skip + "cbz x23, 5f\n" + "4:" // Main row loop: Column loop + "mov x20, x23\n" + "whilelt p2.s, XZR, x20\n" + "ld1w { z20.s }, p2/Z, [x26]\n" + "ld1w { z19.s }, p2/Z, [x24]\n" + "decw x20\n" + "whilelt p1.s, XZR, x20\n" + "ld1w { z18.s }, p1/Z, [x26, #1, MUL VL]\n" + "ld1w { z17.s }, p1/Z, [x24, #1, MUL VL]\n" + "decw x20\n" + "whilelt p0.s, XZR, x20\n" + "ld1w { z25.s }, p0/Z, [x26, #2, MUL VL]\n" + "ld1w { z16.s }, p0/Z, [x24, #2, MUL VL]\n" + "ld1w { z24.s }, p2/Z, [x25]\n" + "ld1w { z30.s }, p1/Z, [x25, #1, MUL VL]\n" + "zip1 z23.s, z20.s, z19.s\n" + "zip2 z22.s, z20.s, z19.s\n" + "ld1w { z29.s }, p0/Z, [x25, #2, MUL VL]\n" + "ld1w { z21.s }, p2/Z, [x22]\n" + "zip1 z20.s, z18.s, z17.s\n" + "zip2 z19.s, z18.s, z17.s\n" + "ld1w { z18.s }, p1/Z, [x22, #1, MUL VL]\n" + "ld1w { z28.s }, p0/Z, [x22, #2, MUL VL]\n" + "zip1 z17.s, z25.s, z16.s\n" + "zip2 z16.s, z25.s, z16.s\n" + "decd x23, ALL, MUL #6\n" + ".inst 0x658aaefb // bfcvt z27.h, p3/M, z23.s\n" + "zip1 z26.s, z24.s, z21.s\n" + "cmp x23, #0x0\n" + ".inst 0x658aaed9 // bfcvt z25.h, p3/M, z22.s\n" + "zip2 z24.s, z24.s, z21.s\n" + "addvl x26, x26, #3\n" + "addvl x25, x25, #3\n" + ".inst 0x658aae97 // bfcvt z23.h, p3/M, z20.s\n" + "zip1 z22.s, z30.s, z18.s\n" + "addvl x24, x24, #3\n" + "addvl x22, x22, #3\n" + ".inst 0x658aae75 // bfcvt z21.h, p3/M, z19.s\n" + "zip2 z20.s, z30.s, z18.s\n" + ".inst 0x658aae33 // bfcvt z19.h, p3/M, z17.s\n" + "zip1 z18.s, z29.s, z28.s\n" + ".inst 0x658aae11 // bfcvt z17.h, p3/M, z16.s\n" + "zip2 z16.s, z29.s, z28.s\n" + ".inst 0x648aaf5b // bfcvtnt z27.h, p3/M, z26.s\n" + ".inst 0x648aaf19 // bfcvtnt z25.h, p3/M, z24.s\n" + "st1h { z27.h }, p3, [x21]\n" + ".inst 0x648aaed7 // bfcvtnt z23.h, p3/M, z22.s\n" + ".inst 0x648aae95 // bfcvtnt z21.h, p3/M, z20.s\n" + "st1h { z25.h }, p3, [x21, #1, MUL VL]\n" + ".inst 0x648aae53 // bfcvtnt z19.h, p3/M, z18.s\n" + ".inst 0x648aae11 // bfcvtnt z17.h, p3/M, z16.s\n" + "st1h { z23.h }, p3, [x21, #2, MUL VL]\n" + "st1h { z21.h }, p3, [x21, #3, MUL VL]\n" + "st1h { z19.h }, p3, [x21, #4, MUL VL]\n" + "st1h { z17.h }, p3, [x21, #5, MUL VL]\n" + "add x21, x21, %x[out_stride]\n" + "bgt 4b\n" + "5:" // Main row loop: Column loop skip + "cmp %x[height], #0x1\n" + "addvl %x[out], %x[out], #6\n" + "bge 1b\n" + : [height] "+&r" (height), [in] "+&r" (in), [out] "+&r" (out) + : [in_stride] "r" (in_stride), [out_stride] "r" (out_stride), [pad_row] "r" (pad_row), [width] "r" (width) + : "cc", "memory", "p0", "p1", "p2", "p3", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "z0", "z1", "z2", "z3", "z4", "z5", "z6", "z7", "z8", "z9", "z10", "z11", "z12", "z16", "z17", "z18", "z19", "z20", "z21", "z22", "z23", "z24", "z25", "z26", "z27", "z28", "z29", "z30", "z31" + ); +} + +} // anonymous namespace +template<> +void Transform<6, 4, true, VLType::SVE>( + bfloat16 *out, const float *in, int stride, int x0, int xmax, int k0, int kmax) +{ + sve_transpose_interleave_6VL_2x4_fp32bf16( + out, + in + k0 * stride + x0, + (xmax-x0), + stride * sizeof(float), + (kmax-k0) + ); +} + + +#endif // defined(ARM_COMPUTE_ENABLE_SVE) diff --git a/src/core/NEON/kernels/arm_gemm/transforms/sve_transpose_interleave_6VL_4x2.hpp b/src/core/NEON/kernels/arm_gemm/transforms/sve_transpose_interleave_6VL_4x2.hpp new file mode 100644 index 0000000000..c638eaacde --- /dev/null +++ b/src/core/NEON/kernels/arm_gemm/transforms/sve_transpose_interleave_6VL_4x2.hpp @@ -0,0 +1,320 @@ +/* + * Copyright (c) 2021, 2023 Arm Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ + +#pragma once + +#if defined(ARM_COMPUTE_ENABLE_SVE) + +namespace { + +void sve_transpose_interleave_6VL_4x2(uint32_t *out, const uint32_t *in, size_t width, size_t in_stride, size_t height) +{ + uint32_t *pad_row = reinterpret_cast<uint32_t *>(alloca(width * sizeof(uint32_t))); + + if (height % 2) { + memset(pad_row, 0, width * sizeof(uint32_t)); + } + + size_t out_stride = 6 * roundup<size_t>(height, 2) * get_vector_length<uint16_t>(); + + __asm__ __volatile__( + "cmp %x[height], #0x4\n" + "ptrue p3.b\n" + "blt 6f\n" + "1:" // Main row loop: Head + "mov x28, %x[in]\n" + "mov x27, %x[width]\n" + "cntw x26, ALL, MUL #6\n" + "add x25, x28, %x[in_stride]\n" + "add x24, x25, %x[in_stride]\n" + "add x23, x24, %x[in_stride]\n" + "cmp x27, x26\n" + "add %x[in], x23, %x[in_stride]\n" + "mov x22, %x[out]\n" + "sub %x[height], %x[height], #0x4\n" + "blt 3f\n" + "2:" // Main row loop: Unroll column loop + "ld1w { z18.s }, p3/Z, [x28]\n" + "ld1w { z17.s }, p3/Z, [x28, #1, MUL VL]\n" + "mov x21, x22\n" + "add x22, x22, %x[out_stride]\n" + "ld1w { z19.s }, p3/Z, [x28, #2, MUL VL]\n" + "ld1w { z16.s }, p3/Z, [x25]\n" + "zip1 z9.s, z18.s, z16.s\n" + "zip2 z8.s, z18.s, z16.s\n" + "ld1w { z16.s }, p3/Z, [x25, #1, MUL VL]\n" + "ld1w { z18.s }, p3/Z, [x25, #2, MUL VL]\n" + "zip1 z7.s, z17.s, z16.s\n" + "zip2 z6.s, z17.s, z16.s\n" + "ld1w { z17.s }, p3/Z, [x24]\n" + "ld1w { z16.s }, p3/Z, [x23]\n" + "zip1 z5.s, z19.s, z18.s\n" + "zip2 z4.s, z19.s, z18.s\n" + "ld1w { z18.s }, p3/Z, [x28, #3, MUL VL]\n" + "ld1w { z21.s }, p3/Z, [x28, #4, MUL VL]\n" + "zip1 z3.s, z17.s, z16.s\n" + "zip2 z2.s, z17.s, z16.s\n" + "ld1w { z20.s }, p3/Z, [x28, #5, MUL VL]\n" + "ld1w { z17.s }, p3/Z, [x25, #3, MUL VL]\n" + "mov x20, x22\n" + "zip1 z1.s, z18.s, z17.s\n" + "ld1w { z19.s }, p3/Z, [x25, #4, MUL VL]\n" + "ld1w { z16.s }, p3/Z, [x25, #5, MUL VL]\n" + "zip2 z0.s, z18.s, z17.s\n" + "zip1 z31.s, z21.s, z19.s\n" + "ld1w { z18.s }, p3/Z, [x24, #1, MUL VL]\n" + "ld1w { z17.s }, p3/Z, [x24, #2, MUL VL]\n" + "zip2 z30.s, z21.s, z19.s\n" + "zip1 z29.s, z20.s, z16.s\n" + "ld1w { z19.s }, p3/Z, [x24, #3, MUL VL]\n" + "ld1w { z28.s }, p3/Z, [x24, #4, MUL VL]\n" + "zip2 z27.s, z20.s, z16.s\n" + "sub x27, x27, x26\n" + "ld1w { z26.s }, p3/Z, [x24, #5, MUL VL]\n" + "ld1w { z16.s }, p3/Z, [x23, #1, MUL VL]\n" + "zip1 z25.s, z18.s, z16.s\n" + "zip2 z24.s, z18.s, z16.s\n" + "ld1w { z16.s }, p3/Z, [x23, #2, MUL VL]\n" + "ld1w { z18.s }, p3/Z, [x23, #3, MUL VL]\n" + "zip1 z23.s, z17.s, z16.s\n" + "zip2 z22.s, z17.s, z16.s\n" + "ld1w { z17.s }, p3/Z, [x23, #4, MUL VL]\n" + "ld1w { z16.s }, p3/Z, [x23, #5, MUL VL]\n" + "st1w { z9.s }, p3, [x21]\n" + "zip1 z21.s, z19.s, z18.s\n" + "st1w { z8.s }, p3, [x21, #1, MUL VL]\n" + "zip2 z20.s, z19.s, z18.s\n" + "cmp x27, x26\n" + "addvl x28, x28, #6\n" + "st1w { z7.s }, p3, [x21, #2, MUL VL]\n" + "addvl x25, x25, #6\n" + "addvl x24, x24, #6\n" + "zip1 z19.s, z28.s, z17.s\n" + "st1w { z6.s }, p3, [x21, #3, MUL VL]\n" + "addvl x23, x23, #6\n" + "zip2 z18.s, z28.s, z17.s\n" + "zip1 z17.s, z26.s, z16.s\n" + "st1w { z5.s }, p3, [x21, #4, MUL VL]\n" + "zip2 z16.s, z26.s, z16.s\n" + "add x22, x22, %x[out_stride]\n" + "st1w { z4.s }, p3, [x21, #5, MUL VL]\n" + "st1w { z3.s }, p3, [x21, #6, MUL VL]\n" + "st1w { z2.s }, p3, [x21, #7, MUL VL]\n" + "addvl x21, x21, #12\n" + "st1w { z25.s }, p3, [x21, #-4, MUL VL]\n" + "st1w { z24.s }, p3, [x21, #-3, MUL VL]\n" + "st1w { z23.s }, p3, [x21, #-2, MUL VL]\n" + "st1w { z22.s }, p3, [x21, #-1, MUL VL]\n" + "st1w { z1.s }, p3, [x20]\n" + "st1w { z0.s }, p3, [x20, #1, MUL VL]\n" + "st1w { z31.s }, p3, [x20, #2, MUL VL]\n" + "st1w { z30.s }, p3, [x20, #3, MUL VL]\n" + "st1w { z29.s }, p3, [x20, #4, MUL VL]\n" + "st1w { z27.s }, p3, [x20, #5, MUL VL]\n" + "st1w { z21.s }, p3, [x20, #6, MUL VL]\n" + "st1w { z20.s }, p3, [x20, #7, MUL VL]\n" + "addvl x20, x20, #12\n" + "st1w { z19.s }, p3, [x20, #-4, MUL VL]\n" + "st1w { z18.s }, p3, [x20, #-3, MUL VL]\n" + "st1w { z17.s }, p3, [x20, #-2, MUL VL]\n" + "st1w { z16.s }, p3, [x20, #-1, MUL VL]\n" + "bge 2b\n" + "3:" // Main row loop: Unroll column loop skip + "cbz x27, 5f\n" + "4:" // Main row loop: Column loop + "mov x20, x27\n" + "whilelt p2.s, XZR, x20\n" + "ld1w { z19.s }, p2/Z, [x28]\n" + "ld1w { z18.s }, p2/Z, [x25]\n" + "decw x20\n" + "whilelt p1.s, XZR, x20\n" + "ld1w { z17.s }, p1/Z, [x28, #1, MUL VL]\n" + "ld1w { z16.s }, p1/Z, [x25, #1, MUL VL]\n" + "decw x20\n" + "whilelt p0.s, XZR, x20\n" + "ld1w { z22.s }, p0/Z, [x28, #2, MUL VL]\n" + "ld1w { z21.s }, p0/Z, [x25, #2, MUL VL]\n" + "ld1w { z28.s }, p2/Z, [x24]\n" + "ld1w { z27.s }, p2/Z, [x23]\n" + "mov x20, x22\n" + "decd x27, ALL, MUL #6\n" + "ld1w { z26.s }, p1/Z, [x24, #1, MUL VL]\n" + "ld1w { z25.s }, p0/Z, [x24, #2, MUL VL]\n" + "zip1 z20.s, z19.s, z18.s\n" + "zip2 z19.s, z19.s, z18.s\n" + "ld1w { z24.s }, p1/Z, [x23, #1, MUL VL]\n" + "ld1w { z23.s }, p0/Z, [x23, #2, MUL VL]\n" + "zip1 z18.s, z17.s, z16.s\n" + "zip2 z17.s, z17.s, z16.s\n" + "zip1 z16.s, z22.s, z21.s\n" + "zip2 z22.s, z22.s, z21.s\n" + "st1w { z20.s }, p3, [x20]\n" + "cmp x27, #0x0\n" + "zip1 z21.s, z28.s, z27.s\n" + "zip2 z20.s, z28.s, z27.s\n" + "st1w { z19.s }, p3, [x20, #1, MUL VL]\n" + "addvl x28, x28, #3\n" + "st1w { z18.s }, p3, [x20, #2, MUL VL]\n" + "addvl x25, x25, #3\n" + "addvl x24, x24, #3\n" + "zip1 z19.s, z26.s, z24.s\n" + "st1w { z17.s }, p3, [x20, #3, MUL VL]\n" + "addvl x23, x23, #3\n" + "zip2 z18.s, z26.s, z24.s\n" + "zip1 z17.s, z25.s, z23.s\n" + "st1w { z16.s }, p3, [x20, #4, MUL VL]\n" + "zip2 z16.s, z25.s, z23.s\n" + "add x22, x22, %x[out_stride]\n" + "st1w { z22.s }, p3, [x20, #5, MUL VL]\n" + "st1w { z21.s }, p3, [x20, #6, MUL VL]\n" + "st1w { z20.s }, p3, [x20, #7, MUL VL]\n" + "addvl x20, x20, #12\n" + "st1w { z19.s }, p3, [x20, #-4, MUL VL]\n" + "st1w { z18.s }, p3, [x20, #-3, MUL VL]\n" + "st1w { z17.s }, p3, [x20, #-2, MUL VL]\n" + "st1w { z16.s }, p3, [x20, #-1, MUL VL]\n" + "bgt 4b\n" + "5:" // Main row loop: Column loop skip + "cmp %x[height], #0x4\n" + "addvl %x[out], %x[out], #12\n" + "bge 1b\n" + "cbz %x[height], 12f\n" + "6:" // Main loop skip + "7:" // Tail row loop: Head + "mov x28, %x[in]\n" + "mov x21, %x[width]\n" + "cntw x20, ALL, MUL #6\n" + "add x25, x28, %x[in_stride]\n" + "cmp %x[height], #0x1\n" + "add %x[in], x25, %x[in_stride]\n" + "csel x25, x25, %x[pad_row], GT\n" + "cmp x21, x20\n" + "mov x22, %x[out]\n" + "sub %x[height], %x[height], #0x2\n" + "blt 9f\n" + "8:" // Tail row loop: Unroll column loop + "ld1w { z17.s }, p3/Z, [x28]\n" + "ld1w { z19.s }, p3/Z, [x28, #1, MUL VL]\n" + "sub x21, x21, x20\n" + "cmp x21, x20\n" + "ld1w { z18.s }, p3/Z, [x28, #2, MUL VL]\n" + "ld1w { z16.s }, p3/Z, [x25]\n" + "zip1 z28.s, z17.s, z16.s\n" + "zip2 z20.s, z17.s, z16.s\n" + "ld1w { z17.s }, p3/Z, [x25, #1, MUL VL]\n" + "ld1w { z16.s }, p3/Z, [x25, #2, MUL VL]\n" + "zip1 z27.s, z19.s, z17.s\n" + "zip2 z26.s, z19.s, z17.s\n" + "ld1w { z19.s }, p3/Z, [x28, #3, MUL VL]\n" + "ld1w { z25.s }, p3/Z, [x28, #4, MUL VL]\n" + "zip1 z24.s, z18.s, z16.s\n" + "zip2 z23.s, z18.s, z16.s\n" + "ld1w { z22.s }, p3/Z, [x28, #5, MUL VL]\n" + "ld1w { z18.s }, p3/Z, [x25, #3, MUL VL]\n" + "addvl x28, x28, #6\n" + "zip1 z21.s, z19.s, z18.s\n" + "ld1w { z17.s }, p3/Z, [x25, #4, MUL VL]\n" + "ld1w { z16.s }, p3/Z, [x25, #5, MUL VL]\n" + "st1w { z28.s }, p3, [x22]\n" + "addvl x25, x25, #6\n" + "st1w { z20.s }, p3, [x22, #1, MUL VL]\n" + "zip2 z20.s, z19.s, z18.s\n" + "zip1 z19.s, z25.s, z17.s\n" + "st1w { z27.s }, p3, [x22, #2, MUL VL]\n" + "zip2 z18.s, z25.s, z17.s\n" + "zip1 z17.s, z22.s, z16.s\n" + "st1w { z26.s }, p3, [x22, #3, MUL VL]\n" + "zip2 z16.s, z22.s, z16.s\n" + "st1w { z24.s }, p3, [x22, #4, MUL VL]\n" + "st1w { z23.s }, p3, [x22, #5, MUL VL]\n" + "add x22, x22, %x[out_stride]\n" + "st1w { z21.s }, p3, [x22]\n" + "st1w { z20.s }, p3, [x22, #1, MUL VL]\n" + "st1w { z19.s }, p3, [x22, #2, MUL VL]\n" + "st1w { z18.s }, p3, [x22, #3, MUL VL]\n" + "st1w { z17.s }, p3, [x22, #4, MUL VL]\n" + "st1w { z16.s }, p3, [x22, #5, MUL VL]\n" + "add x22, x22, %x[out_stride]\n" + "bge 8b\n" + "9:" // Tail row loop: Unroll column loop skip + "cbz x21, 11f\n" + "10:" // Tail row loop: Column loop + "mov x20, x21\n" + "whilelt p0.s, XZR, x20\n" + "ld1w { z20.s }, p0/Z, [x28]\n" + "ld1w { z19.s }, p0/Z, [x25]\n" + "decw x20\n" + "whilelt p0.s, XZR, x20\n" + "ld1w { z18.s }, p0/Z, [x28, #1, MUL VL]\n" + "ld1w { z17.s }, p0/Z, [x25, #1, MUL VL]\n" + "decw x20\n" + "whilelt p0.s, XZR, x20\n" + "ld1w { z22.s }, p0/Z, [x28, #2, MUL VL]\n" + "ld1w { z16.s }, p0/Z, [x25, #2, MUL VL]\n" + "decd x21, ALL, MUL #6\n" + "cmp x21, #0x0\n" + "zip1 z21.s, z20.s, z19.s\n" + "zip2 z20.s, z20.s, z19.s\n" + "addvl x28, x28, #3\n" + "addvl x25, x25, #3\n" + "zip1 z19.s, z18.s, z17.s\n" + "zip2 z18.s, z18.s, z17.s\n" + "zip1 z17.s, z22.s, z16.s\n" + "zip2 z16.s, z22.s, z16.s\n" + "st1w { z21.s }, p3, [x22]\n" + "st1w { z20.s }, p3, [x22, #1, MUL VL]\n" + "st1w { z19.s }, p3, [x22, #2, MUL VL]\n" + "st1w { z18.s }, p3, [x22, #3, MUL VL]\n" + "st1w { z17.s }, p3, [x22, #4, MUL VL]\n" + "st1w { z16.s }, p3, [x22, #5, MUL VL]\n" + "add x22, x22, %x[out_stride]\n" + "bgt 10b\n" + "11:" // Tail row loop: Column loop skip + "cmp %x[height], #0x1\n" + "addvl %x[out], %x[out], #6\n" + "bge 7b\n" + "12:" // Done + : [height] "+&r" (height), [in] "+&r" (in), [out] "+&r" (out) + : [in_stride] "r" (in_stride), [out_stride] "r" (out_stride), [pad_row] "r" (pad_row), [width] "r" (width) + : "cc", "memory", "p0", "p1", "p2", "p3", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27", "x28", "z0", "z1", "z2", "z3", "z4", "z5", "z6", "z7", "z8", "z9", "z16", "z17", "z18", "z19", "z20", "z21", "z22", "z23", "z24", "z25", "z26", "z27", "z28", "z29", "z30", "z31" + ); +} + +} // anonymous namespace + +template<> +void Transform<6, 2, true, VLType::SVE>( + float *out, const float *in, int stride, int x0, int xmax, int k0, int kmax) +{ + sve_transpose_interleave_6VL_4x2( + reinterpret_cast<uint32_t *>(out), + reinterpret_cast<const uint32_t *>(in + k0 * stride + x0), + (xmax-x0) * sizeof(float) / 4, + stride * sizeof(float), + (kmax-k0) + ); +} + + +#endif // defined(ARM_COMPUTE_ENABLE_SVE) diff --git a/src/core/NEON/kernels/arm_gemm/transforms/sve_transpose_interleave_8VL.hpp b/src/core/NEON/kernels/arm_gemm/transforms/sve_transpose_interleave_8VL.hpp new file mode 100644 index 0000000000..0526bd0596 --- /dev/null +++ b/src/core/NEON/kernels/arm_gemm/transforms/sve_transpose_interleave_8VL.hpp @@ -0,0 +1,305 @@ +/* + * Copyright (c) 2021, 2023 Arm Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ + +#pragma once + +#if defined(ARM_COMPUTE_ENABLE_SVE) + +namespace { + +void sve_transpose_interleave_8VL(uint32_t *out, const uint32_t *in, size_t width, size_t in_stride, size_t height) +{ + size_t out_stride = 8 * height * get_vector_length<uint8_t>(); + + __asm__ __volatile__( + "cmp %x[height], #0x2\n" + "ptrue p1.b\n" + "blt 6f\n" + "1:" // Main row loop: Head + "mov x26, %x[in]\n" + "mov x25, %x[width]\n" + "cntw x24, ALL, MUL #16\n" + "add x23, x26, %x[in_stride]\n" + "cmp x25, x24\n" + "add %x[in], x23, %x[in_stride]\n" + "mov x22, %x[out]\n" + "sub %x[height], %x[height], #0x2\n" + "blt 3f\n" + "2:" // Main row loop: Unroll column loop + "ld1w { z15.s }, p1/Z, [x26]\n" + "ld1w { z14.s }, p1/Z, [x26, #1, MUL VL]\n" + "mov x21, x22\n" + "add x22, x22, %x[out_stride]\n" + "ld1w { z13.s }, p1/Z, [x26, #2, MUL VL]\n" + "ld1w { z12.s }, p1/Z, [x26, #3, MUL VL]\n" + "mov x20, x22\n" + "sub x25, x25, x24\n" + "ld1w { z11.s }, p1/Z, [x26, #4, MUL VL]\n" + "ld1w { z10.s }, p1/Z, [x26, #5, MUL VL]\n" + "cmp x25, x24\n" + "add x22, x22, %x[out_stride]\n" + "ld1w { z9.s }, p1/Z, [x26, #6, MUL VL]\n" + "ld1w { z8.s }, p1/Z, [x26, #7, MUL VL]\n" + "addvl x26, x26, #16\n" + "ld1w { z7.s }, p1/Z, [x23]\n" + "ld1w { z6.s }, p1/Z, [x23, #1, MUL VL]\n" + "ld1w { z5.s }, p1/Z, [x23, #2, MUL VL]\n" + "ld1w { z4.s }, p1/Z, [x23, #3, MUL VL]\n" + "ld1w { z3.s }, p1/Z, [x23, #4, MUL VL]\n" + "ld1w { z2.s }, p1/Z, [x23, #5, MUL VL]\n" + "ld1w { z1.s }, p1/Z, [x23, #6, MUL VL]\n" + "ld1w { z0.s }, p1/Z, [x23, #7, MUL VL]\n" + "addvl x23, x23, #16\n" + "ld1w { z31.s }, p1/Z, [x26, #-8, MUL VL]\n" + "ld1w { z30.s }, p1/Z, [x26, #-7, MUL VL]\n" + "ld1w { z29.s }, p1/Z, [x26, #-6, MUL VL]\n" + "ld1w { z28.s }, p1/Z, [x26, #-5, MUL VL]\n" + "ld1w { z27.s }, p1/Z, [x26, #-4, MUL VL]\n" + "ld1w { z26.s }, p1/Z, [x26, #-3, MUL VL]\n" + "ld1w { z25.s }, p1/Z, [x26, #-2, MUL VL]\n" + "ld1w { z24.s }, p1/Z, [x26, #-1, MUL VL]\n" + "ld1w { z23.s }, p1/Z, [x23, #-8, MUL VL]\n" + "ld1w { z22.s }, p1/Z, [x23, #-7, MUL VL]\n" + "ld1w { z21.s }, p1/Z, [x23, #-6, MUL VL]\n" + "ld1w { z20.s }, p1/Z, [x23, #-5, MUL VL]\n" + "ld1w { z19.s }, p1/Z, [x23, #-4, MUL VL]\n" + "ld1w { z18.s }, p1/Z, [x23, #-3, MUL VL]\n" + "ld1w { z17.s }, p1/Z, [x23, #-2, MUL VL]\n" + "ld1w { z16.s }, p1/Z, [x23, #-1, MUL VL]\n" + "st1w { z15.s }, p1, [x21]\n" + "st1w { z14.s }, p1, [x21, #1, MUL VL]\n" + "st1w { z13.s }, p1, [x21, #2, MUL VL]\n" + "st1w { z12.s }, p1, [x21, #3, MUL VL]\n" + "st1w { z11.s }, p1, [x21, #4, MUL VL]\n" + "st1w { z10.s }, p1, [x21, #5, MUL VL]\n" + "st1w { z9.s }, p1, [x21, #6, MUL VL]\n" + "st1w { z8.s }, p1, [x21, #7, MUL VL]\n" + "addvl x21, x21, #16\n" + "st1w { z7.s }, p1, [x21, #-8, MUL VL]\n" + "st1w { z6.s }, p1, [x21, #-7, MUL VL]\n" + "st1w { z5.s }, p1, [x21, #-6, MUL VL]\n" + "st1w { z4.s }, p1, [x21, #-5, MUL VL]\n" + "st1w { z3.s }, p1, [x21, #-4, MUL VL]\n" + "st1w { z2.s }, p1, [x21, #-3, MUL VL]\n" + "st1w { z1.s }, p1, [x21, #-2, MUL VL]\n" + "st1w { z0.s }, p1, [x21, #-1, MUL VL]\n" + "st1w { z31.s }, p1, [x20]\n" + "st1w { z30.s }, p1, [x20, #1, MUL VL]\n" + "st1w { z29.s }, p1, [x20, #2, MUL VL]\n" + "st1w { z28.s }, p1, [x20, #3, MUL VL]\n" + "st1w { z27.s }, p1, [x20, #4, MUL VL]\n" + "st1w { z26.s }, p1, [x20, #5, MUL VL]\n" + "st1w { z25.s }, p1, [x20, #6, MUL VL]\n" + "st1w { z24.s }, p1, [x20, #7, MUL VL]\n" + "addvl x20, x20, #16\n" + "st1w { z23.s }, p1, [x20, #-8, MUL VL]\n" + "st1w { z22.s }, p1, [x20, #-7, MUL VL]\n" + "st1w { z21.s }, p1, [x20, #-6, MUL VL]\n" + "st1w { z20.s }, p1, [x20, #-5, MUL VL]\n" + "st1w { z19.s }, p1, [x20, #-4, MUL VL]\n" + "st1w { z18.s }, p1, [x20, #-3, MUL VL]\n" + "st1w { z17.s }, p1, [x20, #-2, MUL VL]\n" + "st1w { z16.s }, p1, [x20, #-1, MUL VL]\n" + "bge 2b\n" + "3:" // Main row loop: Unroll column loop skip + "cbz x25, 5f\n" + "4:" // Main row loop: Column loop + "mov x20, x25\n" + "whilelt p0.s, XZR, x20\n" + "ld1w { z31.s }, p0/Z, [x26]\n" + "ld1w { z30.s }, p0/Z, [x23]\n" + "decw x20\n" + "whilelt p0.s, XZR, x20\n" + "ld1w { z29.s }, p0/Z, [x26, #1, MUL VL]\n" + "ld1w { z28.s }, p0/Z, [x23, #1, MUL VL]\n" + "decw x20\n" + "whilelt p0.s, XZR, x20\n" + "ld1w { z27.s }, p0/Z, [x26, #2, MUL VL]\n" + "ld1w { z26.s }, p0/Z, [x23, #2, MUL VL]\n" + "decw x20\n" + "whilelt p0.s, XZR, x20\n" + "ld1w { z25.s }, p0/Z, [x26, #3, MUL VL]\n" + "ld1w { z24.s }, p0/Z, [x23, #3, MUL VL]\n" + "decw x20\n" + "whilelt p0.s, XZR, x20\n" + "ld1w { z23.s }, p0/Z, [x26, #4, MUL VL]\n" + "ld1w { z22.s }, p0/Z, [x23, #4, MUL VL]\n" + "decw x20\n" + "whilelt p0.s, XZR, x20\n" + "ld1w { z21.s }, p0/Z, [x26, #5, MUL VL]\n" + "ld1w { z20.s }, p0/Z, [x23, #5, MUL VL]\n" + "decw x20\n" + "whilelt p0.s, XZR, x20\n" + "ld1w { z19.s }, p0/Z, [x26, #6, MUL VL]\n" + "ld1w { z18.s }, p0/Z, [x23, #6, MUL VL]\n" + "decw x20\n" + "whilelt p0.s, XZR, x20\n" + "ld1w { z17.s }, p0/Z, [x26, #7, MUL VL]\n" + "ld1w { z16.s }, p0/Z, [x23, #7, MUL VL]\n" + "mov x20, x22\n" + "decw x25, ALL, MUL #8\n" + "st1w { z31.s }, p1, [x20]\n" + "st1w { z29.s }, p1, [x20, #1, MUL VL]\n" + "cmp x25, #0x0\n" + "addvl x26, x26, #8\n" + "st1w { z27.s }, p1, [x20, #2, MUL VL]\n" + "addvl x23, x23, #8\n" + "add x22, x22, %x[out_stride]\n" + "st1w { z25.s }, p1, [x20, #3, MUL VL]\n" + "st1w { z23.s }, p1, [x20, #4, MUL VL]\n" + "st1w { z21.s }, p1, [x20, #5, MUL VL]\n" + "st1w { z19.s }, p1, [x20, #6, MUL VL]\n" + "st1w { z17.s }, p1, [x20, #7, MUL VL]\n" + "addvl x20, x20, #16\n" + "st1w { z30.s }, p1, [x20, #-8, MUL VL]\n" + "st1w { z28.s }, p1, [x20, #-7, MUL VL]\n" + "st1w { z26.s }, p1, [x20, #-6, MUL VL]\n" + "st1w { z24.s }, p1, [x20, #-5, MUL VL]\n" + "st1w { z22.s }, p1, [x20, #-4, MUL VL]\n" + "st1w { z20.s }, p1, [x20, #-3, MUL VL]\n" + "st1w { z18.s }, p1, [x20, #-2, MUL VL]\n" + "st1w { z16.s }, p1, [x20, #-1, MUL VL]\n" + "bgt 4b\n" + "5:" // Main row loop: Column loop skip + "cmp %x[height], #0x2\n" + "addvl %x[out], %x[out], #16\n" + "bge 1b\n" + "cbz %x[height], 12f\n" + "6:" // Main loop skip + "7:" // Tail row loop: Head + "mov x21, %x[width]\n" + "cntw x20, ALL, MUL #16\n" + "mov x26, %x[in]\n" + "cmp x21, x20\n" + "add %x[in], x26, %x[in_stride]\n" + "mov x22, %x[out]\n" + "sub %x[height], %x[height], #0x1\n" + "blt 9f\n" + "8:" // Tail row loop: Unroll column loop + "ld1w { z31.s }, p1/Z, [x26]\n" + "ld1w { z30.s }, p1/Z, [x26, #1, MUL VL]\n" + "sub x21, x21, x20\n" + "cmp x21, x20\n" + "ld1w { z29.s }, p1/Z, [x26, #2, MUL VL]\n" + "ld1w { z28.s }, p1/Z, [x26, #3, MUL VL]\n" + "ld1w { z27.s }, p1/Z, [x26, #4, MUL VL]\n" + "ld1w { z26.s }, p1/Z, [x26, #5, MUL VL]\n" + "ld1w { z25.s }, p1/Z, [x26, #6, MUL VL]\n" + "ld1w { z24.s }, p1/Z, [x26, #7, MUL VL]\n" + "addvl x26, x26, #16\n" + "ld1w { z23.s }, p1/Z, [x26, #-8, MUL VL]\n" + "ld1w { z22.s }, p1/Z, [x26, #-7, MUL VL]\n" + "ld1w { z21.s }, p1/Z, [x26, #-6, MUL VL]\n" + "ld1w { z20.s }, p1/Z, [x26, #-5, MUL VL]\n" + "ld1w { z19.s }, p1/Z, [x26, #-4, MUL VL]\n" + "ld1w { z18.s }, p1/Z, [x26, #-3, MUL VL]\n" + "ld1w { z17.s }, p1/Z, [x26, #-2, MUL VL]\n" + "ld1w { z16.s }, p1/Z, [x26, #-1, MUL VL]\n" + "st1w { z31.s }, p1, [x22]\n" + "st1w { z30.s }, p1, [x22, #1, MUL VL]\n" + "st1w { z29.s }, p1, [x22, #2, MUL VL]\n" + "st1w { z28.s }, p1, [x22, #3, MUL VL]\n" + "st1w { z27.s }, p1, [x22, #4, MUL VL]\n" + "st1w { z26.s }, p1, [x22, #5, MUL VL]\n" + "st1w { z25.s }, p1, [x22, #6, MUL VL]\n" + "st1w { z24.s }, p1, [x22, #7, MUL VL]\n" + "add x22, x22, %x[out_stride]\n" + "st1w { z23.s }, p1, [x22]\n" + "st1w { z22.s }, p1, [x22, #1, MUL VL]\n" + "st1w { z21.s }, p1, [x22, #2, MUL VL]\n" + "st1w { z20.s }, p1, [x22, #3, MUL VL]\n" + "st1w { z19.s }, p1, [x22, #4, MUL VL]\n" + "st1w { z18.s }, p1, [x22, #5, MUL VL]\n" + "st1w { z17.s }, p1, [x22, #6, MUL VL]\n" + "st1w { z16.s }, p1, [x22, #7, MUL VL]\n" + "add x22, x22, %x[out_stride]\n" + "bge 8b\n" + "9:" // Tail row loop: Unroll column loop skip + "cbz x21, 11f\n" + "10:" // Tail row loop: Column loop + "mov x20, x21\n" + "whilelt p0.s, XZR, x20\n" + "ld1w { z23.s }, p0/Z, [x26]\n" + "decw x20\n" + "whilelt p0.s, XZR, x20\n" + "ld1w { z22.s }, p0/Z, [x26, #1, MUL VL]\n" + "decw x20\n" + "whilelt p0.s, XZR, x20\n" + "ld1w { z21.s }, p0/Z, [x26, #2, MUL VL]\n" + "decw x20\n" + "whilelt p0.s, XZR, x20\n" + "ld1w { z20.s }, p0/Z, [x26, #3, MUL VL]\n" + "decw x20\n" + "whilelt p0.s, XZR, x20\n" + "ld1w { z19.s }, p0/Z, [x26, #4, MUL VL]\n" + "decw x20\n" + "whilelt p0.s, XZR, x20\n" + "ld1w { z18.s }, p0/Z, [x26, #5, MUL VL]\n" + "decw x20\n" + "whilelt p0.s, XZR, x20\n" + "ld1w { z17.s }, p0/Z, [x26, #6, MUL VL]\n" + "decw x20\n" + "decw x21, ALL, MUL #8\n" + "whilelt p0.s, XZR, x20\n" + "cmp x21, #0x0\n" + "ld1w { z16.s }, p0/Z, [x26, #7, MUL VL]\n" + "st1w { z23.s }, p1, [x22]\n" + "addvl x26, x26, #8\n" + "st1w { z22.s }, p1, [x22, #1, MUL VL]\n" + "st1w { z21.s }, p1, [x22, #2, MUL VL]\n" + "st1w { z20.s }, p1, [x22, #3, MUL VL]\n" + "st1w { z19.s }, p1, [x22, #4, MUL VL]\n" + "st1w { z18.s }, p1, [x22, #5, MUL VL]\n" + "st1w { z17.s }, p1, [x22, #6, MUL VL]\n" + "st1w { z16.s }, p1, [x22, #7, MUL VL]\n" + "add x22, x22, %x[out_stride]\n" + "bgt 10b\n" + "11:" // Tail row loop: Column loop skip + "cmp %x[height], #0x1\n" + "addvl %x[out], %x[out], #8\n" + "bge 7b\n" + "12:" // Done + : [height] "+&r" (height), [in] "+&r" (in), [out] "+&r" (out) + : [in_stride] "r" (in_stride), [out_stride] "r" (out_stride), [width] "r" (width) + : "cc", "memory", "p0", "p1", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "z0", "z1", "z2", "z3", "z4", "z5", "z6", "z7", "z8", "z9", "z10", "z11", "z12", "z13", "z14", "z15", "z16", "z17", "z18", "z19", "z20", "z21", "z22", "z23", "z24", "z25", "z26", "z27", "z28", "z29", "z30", "z31" + ); +} + +} // anonymous namespace + +template<> +void Transform<8, 1, true, VLType::SVE>( + float *out, const float *in, int stride, int x0, int xmax, int k0, int kmax) +{ + sve_transpose_interleave_8VL( + reinterpret_cast<uint32_t *>(out), + reinterpret_cast<const uint32_t *>(in + k0 * stride + x0), + (xmax-x0) * sizeof(float) / 4, + stride * sizeof(float), + (kmax-k0) + ); +} + + +#endif // defined(ARM_COMPUTE_ENABLE_SVE) diff --git a/src/core/NEON/kernels/arm_gemm/transforms/sve_transpose_interleave_8VL_1x4.hpp b/src/core/NEON/kernels/arm_gemm/transforms/sve_transpose_interleave_8VL_1x4.hpp new file mode 100644 index 0000000000..98f0770d77 --- /dev/null +++ b/src/core/NEON/kernels/arm_gemm/transforms/sve_transpose_interleave_8VL_1x4.hpp @@ -0,0 +1,286 @@ +/* + * Copyright (c) 2021, 2023 Arm Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ + +#pragma once + +#if defined(ARM_COMPUTE_ENABLE_SVE) + +namespace { + +void sve_transpose_interleave_8VL_1x4(uint8_t *out, const uint8_t *in, size_t width, size_t in_stride, size_t height) +{ + uint8_t *pad_row = reinterpret_cast<uint8_t *>(alloca(width * sizeof(uint8_t))); + + if (height % 4) { + memset(pad_row, 0, width * sizeof(uint8_t)); + } + + size_t out_stride = 8 * roundup<size_t>(height, 4) * get_vector_length<uint32_t>(); + + __asm__ __volatile__( + "ptrue p2.b\n" + "1:" // Main row loop: Head + "mov x26, %x[in]\n" + "add x25, x26, %x[in_stride]\n" + "add x24, x25, %x[in_stride]\n" + "mov x23, %x[width]\n" + "cntb x20, ALL, MUL #8\n" + "add x22, x24, %x[in_stride]\n" + "cmp %x[height], #0x3\n" + "add %x[in], x22, %x[in_stride]\n" + "csel x22, x22, %x[pad_row], GT\n" + "csel x24, x24, %x[pad_row], GE\n" + "cmp %x[height], #0x1\n" + "csel x25, x25, %x[pad_row], GT\n" + "cmp x23, x20\n" + "mov x21, %x[out]\n" + "sub %x[height], %x[height], #0x4\n" + "blt 3f\n" + "2:" // Main row loop: Unroll column loop + "ld1b { z7.b }, p2/Z, [x26]\n" + "ld1b { z24.b }, p2/Z, [x26, #1, MUL VL]\n" + "sub x23, x23, x20\n" + "cmp x23, x20\n" + "ld1b { z31.b }, p2/Z, [x25]\n" + "ld1b { z18.b }, p2/Z, [x25, #1, MUL VL]\n" + "ld1b { z19.b }, p2/Z, [x24]\n" + "ld1b { z25.b }, p2/Z, [x24, #1, MUL VL]\n" + "zip1 z23.b, z7.b, z19.b\n" + "zip2 z20.b, z7.b, z19.b\n" + "ld1b { z30.b }, p2/Z, [x22]\n" + "ld1b { z3.b }, p2/Z, [x22, #1, MUL VL]\n" + "zip1 z21.b, z31.b, z30.b\n" + "zip2 z19.b, z31.b, z30.b\n" + "ld1b { z16.b }, p2/Z, [x26, #2, MUL VL]\n" + "ld1b { z30.b }, p2/Z, [x26, #3, MUL VL]\n" + "zip1 z2.b, z24.b, z25.b\n" + "zip1 z17.b, z18.b, z3.b\n" + "ld1b { z29.b }, p2/Z, [x25, #2, MUL VL]\n" + "ld1b { z8.b }, p2/Z, [x25, #3, MUL VL]\n" + "zip2 z22.b, z24.b, z25.b\n" + "zip2 z4.b, z18.b, z3.b\n" + "ld1b { z0.b }, p2/Z, [x24, #2, MUL VL]\n" + "ld1b { z3.b }, p2/Z, [x24, #3, MUL VL]\n" + "zip1 z9.b, z16.b, z0.b\n" + "zip2 z14.b, z16.b, z0.b\n" + "ld1b { z18.b }, p2/Z, [x22, #2, MUL VL]\n" + "ld1b { z16.b }, p2/Z, [x22, #3, MUL VL]\n" + "zip1 z24.b, z29.b, z18.b\n" + "zip2 z11.b, z29.b, z18.b\n" + "ld1b { z1.b }, p2/Z, [x26, #4, MUL VL]\n" + "ld1b { z12.b }, p2/Z, [x26, #5, MUL VL]\n" + "zip1 z13.b, z30.b, z3.b\n" + "zip1 z15.b, z8.b, z16.b\n" + "ld1b { z5.b }, p2/Z, [x25, #4, MUL VL]\n" + "ld1b { z29.b }, p2/Z, [x25, #5, MUL VL]\n" + "zip2 z31.b, z30.b, z3.b\n" + "zip2 z30.b, z8.b, z16.b\n" + "ld1b { z16.b }, p2/Z, [x24, #4, MUL VL]\n" + "ld1b { z18.b }, p2/Z, [x24, #5, MUL VL]\n" + "zip1 z27.b, z1.b, z16.b\n" + "zip2 z10.b, z1.b, z16.b\n" + "ld1b { z7.b }, p2/Z, [x22, #4, MUL VL]\n" + "ld1b { z16.b }, p2/Z, [x22, #5, MUL VL]\n" + "zip1 z8.b, z5.b, z7.b\n" + "zip2 z26.b, z5.b, z7.b\n" + "ld1b { z3.b }, p2/Z, [x26, #6, MUL VL]\n" + "ld1b { z25.b }, p2/Z, [x26, #7, MUL VL]\n" + "zip1 z6.b, z12.b, z18.b\n" + "zip1 z5.b, z29.b, z16.b\n" + "ld1b { z0.b }, p2/Z, [x25, #6, MUL VL]\n" + "ld1b { z28.b }, p2/Z, [x25, #7, MUL VL]\n" + "zip2 z12.b, z12.b, z18.b\n" + "zip2 z7.b, z29.b, z16.b\n" + "ld1b { z1.b }, p2/Z, [x24, #6, MUL VL]\n" + "ld1b { z29.b }, p2/Z, [x24, #7, MUL VL]\n" + "zip1 z16.b, z23.b, z21.b\n" + "zip2 z18.b, z23.b, z21.b\n" + "ld1b { z23.b }, p2/Z, [x22, #6, MUL VL]\n" + "ld1b { z21.b }, p2/Z, [x22, #7, MUL VL]\n" + "st1b { z16.b }, p2, [x21]\n" + "zip1 z16.b, z20.b, z19.b\n" + "zip2 z20.b, z20.b, z19.b\n" + "zip1 z19.b, z2.b, z17.b\n" + "st1b { z18.b }, p2, [x21, #1, MUL VL]\n" + "addvl x26, x26, #8\n" + "zip2 z18.b, z2.b, z17.b\n" + "zip1 z17.b, z22.b, z4.b\n" + "st1b { z16.b }, p2, [x21, #2, MUL VL]\n" + "addvl x25, x25, #8\n" + "zip2 z16.b, z22.b, z4.b\n" + "st1b { z20.b }, p2, [x21, #3, MUL VL]\n" + "zip1 z4.b, z3.b, z1.b\n" + "addvl x24, x24, #8\n" + "st1b { z19.b }, p2, [x21, #4, MUL VL]\n" + "zip1 z22.b, z0.b, z23.b\n" + "zip2 z3.b, z3.b, z1.b\n" + "addvl x22, x22, #8\n" + "st1b { z18.b }, p2, [x21, #5, MUL VL]\n" + "zip2 z2.b, z0.b, z23.b\n" + "zip1 z1.b, z25.b, z29.b\n" + "st1b { z17.b }, p2, [x21, #6, MUL VL]\n" + "zip1 z0.b, z28.b, z21.b\n" + "zip2 z29.b, z25.b, z29.b\n" + "st1b { z16.b }, p2, [x21, #7, MUL VL]\n" + "add x21, x21, %x[out_stride]\n" + "zip2 z28.b, z28.b, z21.b\n" + "zip1 z17.b, z9.b, z24.b\n" + "zip2 z16.b, z9.b, z24.b\n" + "zip1 z19.b, z14.b, z11.b\n" + "st1b { z17.b }, p2, [x21]\n" + "zip2 z18.b, z14.b, z11.b\n" + "zip1 z17.b, z13.b, z15.b\n" + "st1b { z16.b }, p2, [x21, #1, MUL VL]\n" + "zip2 z16.b, z13.b, z15.b\n" + "zip1 z21.b, z31.b, z30.b\n" + "st1b { z19.b }, p2, [x21, #2, MUL VL]\n" + "zip2 z20.b, z31.b, z30.b\n" + "st1b { z18.b }, p2, [x21, #3, MUL VL]\n" + "zip1 z19.b, z27.b, z8.b\n" + "st1b { z17.b }, p2, [x21, #4, MUL VL]\n" + "zip2 z18.b, z27.b, z8.b\n" + "zip1 z17.b, z10.b, z26.b\n" + "st1b { z16.b }, p2, [x21, #5, MUL VL]\n" + "zip2 z16.b, z10.b, z26.b\n" + "zip1 z27.b, z6.b, z5.b\n" + "st1b { z21.b }, p2, [x21, #6, MUL VL]\n" + "zip2 z26.b, z6.b, z5.b\n" + "zip1 z25.b, z12.b, z7.b\n" + "st1b { z20.b }, p2, [x21, #7, MUL VL]\n" + "add x21, x21, %x[out_stride]\n" + "zip2 z24.b, z12.b, z7.b\n" + "zip1 z23.b, z4.b, z22.b\n" + "st1b { z19.b }, p2, [x21]\n" + "zip2 z22.b, z4.b, z22.b\n" + "zip1 z21.b, z3.b, z2.b\n" + "st1b { z18.b }, p2, [x21, #1, MUL VL]\n" + "zip2 z20.b, z3.b, z2.b\n" + "zip1 z19.b, z1.b, z0.b\n" + "st1b { z17.b }, p2, [x21, #2, MUL VL]\n" + "zip2 z18.b, z1.b, z0.b\n" + "zip1 z17.b, z29.b, z28.b\n" + "st1b { z16.b }, p2, [x21, #3, MUL VL]\n" + "zip2 z16.b, z29.b, z28.b\n" + "st1b { z27.b }, p2, [x21, #4, MUL VL]\n" + "st1b { z26.b }, p2, [x21, #5, MUL VL]\n" + "st1b { z25.b }, p2, [x21, #6, MUL VL]\n" + "st1b { z24.b }, p2, [x21, #7, MUL VL]\n" + "add x21, x21, %x[out_stride]\n" + "st1b { z23.b }, p2, [x21]\n" + "st1b { z22.b }, p2, [x21, #1, MUL VL]\n" + "st1b { z21.b }, p2, [x21, #2, MUL VL]\n" + "st1b { z20.b }, p2, [x21, #3, MUL VL]\n" + "st1b { z19.b }, p2, [x21, #4, MUL VL]\n" + "st1b { z18.b }, p2, [x21, #5, MUL VL]\n" + "st1b { z17.b }, p2, [x21, #6, MUL VL]\n" + "st1b { z16.b }, p2, [x21, #7, MUL VL]\n" + "add x21, x21, %x[out_stride]\n" + "bge 2b\n" + "3:" // Main row loop: Unroll column loop skip + "cbz x23, 5f\n" + "4:" // Main row loop: Column loop + "mov x20, x23\n" + "whilelt p1.b, XZR, x20\n" + "ld1b { z23.b }, p1/Z, [x26]\n" + "ld1b { z22.b }, p1/Z, [x25]\n" + "decb x20\n" + "whilelt p0.b, XZR, x20\n" + "ld1b { z21.b }, p0/Z, [x26, #1, MUL VL]\n" + "ld1b { z25.b }, p0/Z, [x25, #1, MUL VL]\n" + "ld1b { z19.b }, p1/Z, [x24]\n" + "ld1b { z20.b }, p0/Z, [x24, #1, MUL VL]\n" + "decw x23, ALL, MUL #8\n" + "zip1 z24.b, z23.b, z19.b\n" + "ld1b { z18.b }, p1/Z, [x22]\n" + "ld1b { z16.b }, p0/Z, [x22, #1, MUL VL]\n" + "zip1 z17.b, z22.b, z18.b\n" + "zip2 z23.b, z23.b, z19.b\n" + "zip2 z19.b, z22.b, z18.b\n" + "zip1 z22.b, z21.b, z20.b\n" + "cmp x23, #0x0\n" + "addvl x26, x26, #2\n" + "zip1 z18.b, z25.b, z16.b\n" + "zip2 z21.b, z21.b, z20.b\n" + "addvl x25, x25, #2\n" + "addvl x24, x24, #2\n" + "zip2 z20.b, z25.b, z16.b\n" + "addvl x22, x22, #2\n" + "zip1 z16.b, z24.b, z17.b\n" + "st1b { z16.b }, p2, [x21]\n" + "zip2 z16.b, z24.b, z17.b\n" + "zip1 z17.b, z23.b, z19.b\n" + "st1b { z16.b }, p2, [x21, #1, MUL VL]\n" + "zip2 z16.b, z23.b, z19.b\n" + "zip1 z19.b, z22.b, z18.b\n" + "st1b { z17.b }, p2, [x21, #2, MUL VL]\n" + "zip2 z18.b, z22.b, z18.b\n" + "zip1 z17.b, z21.b, z20.b\n" + "st1b { z16.b }, p2, [x21, #3, MUL VL]\n" + "zip2 z16.b, z21.b, z20.b\n" + "st1b { z19.b }, p2, [x21, #4, MUL VL]\n" + "st1b { z18.b }, p2, [x21, #5, MUL VL]\n" + "st1b { z17.b }, p2, [x21, #6, MUL VL]\n" + "st1b { z16.b }, p2, [x21, #7, MUL VL]\n" + "add x21, x21, %x[out_stride]\n" + "bgt 4b\n" + "5:" // Main row loop: Column loop skip + "cmp %x[height], #0x1\n" + "addvl %x[out], %x[out], #8\n" + "bge 1b\n" + : [height] "+&r" (height), [in] "+&r" (in), [out] "+&r" (out) + : [in_stride] "r" (in_stride), [out_stride] "r" (out_stride), [pad_row] "r" (pad_row), [width] "r" (width) + : "cc", "memory", "p0", "p1", "p2", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "z0", "z1", "z2", "z3", "z4", "z5", "z6", "z7", "z8", "z9", "z10", "z11", "z12", "z13", "z14", "z15", "z16", "z17", "z18", "z19", "z20", "z21", "z22", "z23", "z24", "z25", "z26", "z27", "z28", "z29", "z30", "z31" + ); +} + +} // anonymous namespace + +template<> +void Transform<8, 4, true, VLType::SVE>( + uint8_t *out, const uint8_t *in, int stride, int x0, int xmax, int k0, int kmax) +{ + sve_transpose_interleave_8VL_1x4( + reinterpret_cast<uint8_t *>(out), + reinterpret_cast<const uint8_t *>(in + k0 * stride + x0), + (xmax-x0) * sizeof(uint8_t) / 1, + stride * sizeof(uint8_t), + (kmax-k0) + ); +} + +template<> +void Transform<8, 4, true, VLType::SVE>( + int8_t *out, const int8_t *in, int stride, int x0, int xmax, int k0, int kmax) +{ + sve_transpose_interleave_8VL_1x4( + reinterpret_cast<uint8_t *>(out), + reinterpret_cast<const uint8_t *>(in + k0 * stride + x0), + (xmax-x0) * sizeof(int8_t) / 1, + stride * sizeof(int8_t), + (kmax-k0) + ); +} + + +#endif // defined(ARM_COMPUTE_ENABLE_SVE) diff --git a/src/core/NEON/kernels/arm_gemm/transforms/sve_transpose_interleave_8VL_1x8.hpp b/src/core/NEON/kernels/arm_gemm/transforms/sve_transpose_interleave_8VL_1x8.hpp new file mode 100644 index 0000000000..3fa5292143 --- /dev/null +++ b/src/core/NEON/kernels/arm_gemm/transforms/sve_transpose_interleave_8VL_1x8.hpp @@ -0,0 +1,259 @@ +/* + * Copyright (c) 2021, 2023 Arm Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ + +#pragma once + +#if defined(ARM_COMPUTE_ENABLE_SVE) + +namespace { + +void sve_transpose_interleave_8VL_1x8(uint8_t *out, const uint8_t *in, size_t width, size_t in_stride, size_t height) +{ + uint8_t *pad_row = reinterpret_cast<uint8_t *>(alloca(width * sizeof(uint8_t))); + + if (height % 8) { + memset(pad_row, 0, width * sizeof(uint8_t)); + } + + size_t out_stride = 8 * roundup<size_t>(height, 8) * get_vector_length<uint64_t>(); + + __asm__ __volatile__( + "ptrue p1.b\n" + "1:" // Main row loop: Head + "mov x10, %x[in]\n" + "add x9, x10, %x[in_stride]\n" + "add x28, x9, %x[in_stride]\n" + "add x27, x28, %x[in_stride]\n" + "add x26, x27, %x[in_stride]\n" + "add x25, x26, %x[in_stride]\n" + "add x24, x25, %x[in_stride]\n" + "add x23, x24, %x[in_stride]\n" + "cmp %x[height], #0x7\n" + "add %x[in], x23, %x[in_stride]\n" + "csel x23, x23, %x[pad_row], GT\n" + "csel x24, x24, %x[pad_row], GE\n" + "cmp %x[height], #0x5\n" + "mov x22, %x[width]\n" + "cntb x21, ALL, MUL #2\n" + "csel x25, x25, %x[pad_row], GT\n" + "csel x26, x26, %x[pad_row], GE\n" + "cmp %x[height], #0x3\n" + "csel x27, x27, %x[pad_row], GT\n" + "csel x28, x28, %x[pad_row], GE\n" + "cmp %x[height], #0x1\n" + "csel x9, x9, %x[pad_row], GT\n" + "cmp x22, x21\n" + "mov x20, %x[out]\n" + "sub %x[height], %x[height], #0x8\n" + "blt 3f\n" + "2:" // Main row loop: Unroll column loop + "ld1b { z23.b }, p1/Z, [x10]\n" + "ld1b { z22.b }, p1/Z, [x9]\n" + "sub x22, x22, x21\n" + "cmp x22, x21\n" + "ld1b { z20.b }, p1/Z, [x28]\n" + "ld1b { z21.b }, p1/Z, [x27]\n" + "ld1b { z19.b }, p1/Z, [x26]\n" + "ld1b { z18.b }, p1/Z, [x25]\n" + "zip1 z5.b, z23.b, z19.b\n" + "zip1 z4.b, z22.b, z18.b\n" + "ld1b { z17.b }, p1/Z, [x24]\n" + "ld1b { z16.b }, p1/Z, [x23]\n" + "zip1 z3.b, z20.b, z17.b\n" + "zip1 z31.b, z21.b, z16.b\n" + "ld1b { z25.b }, p1/Z, [x10, #1, MUL VL]\n" + "ld1b { z24.b }, p1/Z, [x9, #1, MUL VL]\n" + "zip2 z2.b, z23.b, z19.b\n" + "zip2 z30.b, z20.b, z17.b\n" + "ld1b { z23.b }, p1/Z, [x28, #1, MUL VL]\n" + "ld1b { z20.b }, p1/Z, [x27, #1, MUL VL]\n" + "zip2 z22.b, z22.b, z18.b\n" + "zip2 z21.b, z21.b, z16.b\n" + "ld1b { z19.b }, p1/Z, [x26, #1, MUL VL]\n" + "ld1b { z18.b }, p1/Z, [x25, #1, MUL VL]\n" + "zip1 z29.b, z25.b, z19.b\n" + "zip1 z28.b, z24.b, z18.b\n" + "ld1b { z17.b }, p1/Z, [x24, #1, MUL VL]\n" + "ld1b { z16.b }, p1/Z, [x23, #1, MUL VL]\n" + "zip1 z27.b, z23.b, z17.b\n" + "zip1 z26.b, z20.b, z16.b\n" + "zip2 z1.b, z25.b, z19.b\n" + "zip2 z25.b, z23.b, z17.b\n" + "addvl x10, x10, #2\n" + "addvl x9, x9, #2\n" + "zip2 z24.b, z24.b, z18.b\n" + "zip2 z16.b, z20.b, z16.b\n" + "addvl x28, x28, #2\n" + "addvl x27, x27, #2\n" + "zip1 z0.b, z5.b, z3.b\n" + "zip1 z17.b, z4.b, z31.b\n" + "addvl x26, x26, #2\n" + "addvl x25, x25, #2\n" + "zip2 z20.b, z5.b, z3.b\n" + "zip2 z19.b, z4.b, z31.b\n" + "addvl x24, x24, #2\n" + "addvl x23, x23, #2\n" + "zip1 z31.b, z2.b, z30.b\n" + "zip1 z18.b, z22.b, z21.b\n" + "zip2 z30.b, z2.b, z30.b\n" + "zip2 z23.b, z22.b, z21.b\n" + "zip1 z22.b, z29.b, z27.b\n" + "zip1 z21.b, z28.b, z26.b\n" + "zip2 z29.b, z29.b, z27.b\n" + "zip2 z28.b, z28.b, z26.b\n" + "zip1 z27.b, z1.b, z25.b\n" + "zip1 z26.b, z24.b, z16.b\n" + "zip2 z25.b, z1.b, z25.b\n" + "zip2 z24.b, z24.b, z16.b\n" + "zip1 z16.b, z0.b, z17.b\n" + "zip2 z17.b, z0.b, z17.b\n" + "st1b { z16.b }, p1, [x20]\n" + "zip1 z16.b, z20.b, z19.b\n" + "zip2 z20.b, z20.b, z19.b\n" + "st1b { z17.b }, p1, [x20, #1, MUL VL]\n" + "zip1 z19.b, z31.b, z18.b\n" + "zip2 z18.b, z31.b, z18.b\n" + "st1b { z16.b }, p1, [x20, #2, MUL VL]\n" + "zip1 z17.b, z30.b, z23.b\n" + "zip2 z16.b, z30.b, z23.b\n" + "st1b { z20.b }, p1, [x20, #3, MUL VL]\n" + "st1b { z19.b }, p1, [x20, #4, MUL VL]\n" + "zip1 z23.b, z22.b, z21.b\n" + "zip2 z22.b, z22.b, z21.b\n" + "st1b { z18.b }, p1, [x20, #5, MUL VL]\n" + "zip1 z21.b, z29.b, z28.b\n" + "zip2 z20.b, z29.b, z28.b\n" + "st1b { z17.b }, p1, [x20, #6, MUL VL]\n" + "zip1 z19.b, z27.b, z26.b\n" + "zip2 z18.b, z27.b, z26.b\n" + "st1b { z16.b }, p1, [x20, #7, MUL VL]\n" + "add x20, x20, %x[out_stride]\n" + "zip1 z17.b, z25.b, z24.b\n" + "zip2 z16.b, z25.b, z24.b\n" + "st1b { z23.b }, p1, [x20]\n" + "st1b { z22.b }, p1, [x20, #1, MUL VL]\n" + "st1b { z21.b }, p1, [x20, #2, MUL VL]\n" + "st1b { z20.b }, p1, [x20, #3, MUL VL]\n" + "st1b { z19.b }, p1, [x20, #4, MUL VL]\n" + "st1b { z18.b }, p1, [x20, #5, MUL VL]\n" + "st1b { z17.b }, p1, [x20, #6, MUL VL]\n" + "st1b { z16.b }, p1, [x20, #7, MUL VL]\n" + "add x20, x20, %x[out_stride]\n" + "bge 2b\n" + "3:" // Main row loop: Unroll column loop skip + "cbz x22, 5f\n" + "4:" // Main row loop: Column loop + "whilelt p0.b, XZR, x22\n" + "ld1b { z25.b }, p0/Z, [x10]\n" + "ld1b { z27.b }, p0/Z, [x9]\n" + "decd x22, ALL, MUL #8\n" + "ld1b { z26.b }, p0/Z, [x28]\n" + "ld1b { z24.b }, p0/Z, [x27]\n" + "cmp x22, #0x0\n" + "addvl x10, x10, #1\n" + "ld1b { z22.b }, p0/Z, [x26]\n" + "ld1b { z21.b }, p0/Z, [x25]\n" + "zip1 z20.b, z25.b, z22.b\n" + "zip1 z23.b, z27.b, z21.b\n" + "ld1b { z17.b }, p0/Z, [x24]\n" + "ld1b { z16.b }, p0/Z, [x23]\n" + "zip1 z19.b, z26.b, z17.b\n" + "zip1 z18.b, z24.b, z16.b\n" + "zip2 z25.b, z25.b, z22.b\n" + "zip2 z22.b, z26.b, z17.b\n" + "addvl x9, x9, #1\n" + "addvl x28, x28, #1\n" + "zip2 z21.b, z27.b, z21.b\n" + "zip2 z16.b, z24.b, z16.b\n" + "addvl x27, x27, #1\n" + "addvl x26, x26, #1\n" + "zip1 z24.b, z20.b, z19.b\n" + "zip1 z17.b, z23.b, z18.b\n" + "addvl x25, x25, #1\n" + "addvl x24, x24, #1\n" + "zip2 z20.b, z20.b, z19.b\n" + "zip2 z19.b, z23.b, z18.b\n" + "addvl x23, x23, #1\n" + "zip1 z23.b, z25.b, z22.b\n" + "zip1 z18.b, z21.b, z16.b\n" + "zip2 z22.b, z25.b, z22.b\n" + "zip2 z21.b, z21.b, z16.b\n" + "zip1 z16.b, z24.b, z17.b\n" + "zip2 z17.b, z24.b, z17.b\n" + "st1b { z16.b }, p1, [x20]\n" + "zip1 z16.b, z20.b, z19.b\n" + "zip2 z20.b, z20.b, z19.b\n" + "st1b { z17.b }, p1, [x20, #1, MUL VL]\n" + "zip1 z19.b, z23.b, z18.b\n" + "zip2 z18.b, z23.b, z18.b\n" + "st1b { z16.b }, p1, [x20, #2, MUL VL]\n" + "zip1 z17.b, z22.b, z21.b\n" + "zip2 z16.b, z22.b, z21.b\n" + "st1b { z20.b }, p1, [x20, #3, MUL VL]\n" + "st1b { z19.b }, p1, [x20, #4, MUL VL]\n" + "st1b { z18.b }, p1, [x20, #5, MUL VL]\n" + "st1b { z17.b }, p1, [x20, #6, MUL VL]\n" + "st1b { z16.b }, p1, [x20, #7, MUL VL]\n" + "add x20, x20, %x[out_stride]\n" + "bgt 4b\n" + "5:" // Main row loop: Column loop skip + "cmp %x[height], #0x1\n" + "addvl %x[out], %x[out], #8\n" + "bge 1b\n" + : [height] "+&r" (height), [in] "+&r" (in), [out] "+&r" (out) + : [in_stride] "r" (in_stride), [out_stride] "r" (out_stride), [pad_row] "r" (pad_row), [width] "r" (width) + : "cc", "memory", "p0", "p1", "x9", "x10", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27", "x28", "z0", "z1", "z2", "z3", "z4", "z5", "z16", "z17", "z18", "z19", "z20", "z21", "z22", "z23", "z24", "z25", "z26", "z27", "z28", "z29", "z30", "z31" + ); +} + +} // anonymous namespace + +template<> +void Transform<8, 8, true, VLType::SVE>( + uint8_t *out, const uint8_t *in, int stride, int x0, int xmax, int k0, int kmax) +{ + sve_transpose_interleave_8VL_1x8( + reinterpret_cast<uint8_t *>(out), + reinterpret_cast<const uint8_t *>(in + k0 * stride + x0), + (xmax-x0) * sizeof(uint8_t) / 1, + stride * sizeof(uint8_t), + (kmax-k0) + ); +} + +template<> +void Transform<8, 8, true, VLType::SVE>( + int8_t *out, const int8_t *in, int stride, int x0, int xmax, int k0, int kmax) +{ + sve_transpose_interleave_8VL_1x8( + reinterpret_cast<uint8_t *>(out), + reinterpret_cast<const uint8_t *>(in + k0 * stride + x0), + (xmax-x0) * sizeof(int8_t) / 1, + stride * sizeof(int8_t), + (kmax-k0) + ); +} + + +#endif // defined(ARM_COMPUTE_ENABLE_SVE) diff --git a/src/core/NEON/kernels/arm_gemm/transforms/sve_transpose_interleave_8VL_2x2.hpp b/src/core/NEON/kernels/arm_gemm/transforms/sve_transpose_interleave_8VL_2x2.hpp new file mode 100644 index 0000000000..02977ecf1e --- /dev/null +++ b/src/core/NEON/kernels/arm_gemm/transforms/sve_transpose_interleave_8VL_2x2.hpp @@ -0,0 +1,378 @@ +/* + * Copyright (c) 2021, 2023 Arm Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ + +#pragma once + +#if defined(ARM_COMPUTE_ENABLE_SVE) + +namespace { + +void sve_transpose_interleave_8VL_2x2(uint16_t *out, const uint16_t *in, size_t width, size_t in_stride, size_t height) +{ + uint16_t *pad_row = reinterpret_cast<uint16_t *>(alloca(width * sizeof(uint16_t))); + + if (height % 2) { + memset(pad_row, 0, width * sizeof(uint16_t)); + } + + size_t out_stride = 8 * roundup<size_t>(height, 2) * get_vector_length<uint16_t>(); + + __asm__ __volatile__( + "cmp %x[height], #0x4\n" + "ptrue p4.b\n" + "blt 6f\n" + "1:" // Main row loop: Head + "mov x28, %x[in]\n" + "mov x27, %x[width]\n" + "cnth x26, ALL, MUL #8\n" + "add x25, x28, %x[in_stride]\n" + "add x24, x25, %x[in_stride]\n" + "add x23, x24, %x[in_stride]\n" + "cmp x27, x26\n" + "add %x[in], x23, %x[in_stride]\n" + "mov x22, %x[out]\n" + "sub %x[height], %x[height], #0x4\n" + "blt 3f\n" + "2:" // Main row loop: Unroll column loop + "ld1h { z30.h }, p4/Z, [x28]\n" + "ld1h { z12.h }, p4/Z, [x28, #1, MUL VL]\n" + "mov x21, x22\n" + "add x22, x22, %x[out_stride]\n" + "ld1h { z31.h }, p4/Z, [x28, #2, MUL VL]\n" + "ld1h { z18.h }, p4/Z, [x28, #3, MUL VL]\n" + "mov x20, x22\n" + "sub x27, x27, x26\n" + "ld1h { z20.h }, p4/Z, [x25]\n" + "ld1h { z17.h }, p4/Z, [x25, #1, MUL VL]\n" + "zip1 z3.h, z30.h, z20.h\n" + "zip2 z21.h, z30.h, z20.h\n" + "ld1h { z26.h }, p4/Z, [x25, #2, MUL VL]\n" + "ld1h { z23.h }, p4/Z, [x25, #3, MUL VL]\n" + "zip1 z13.h, z12.h, z17.h\n" + "zip2 z0.h, z12.h, z17.h\n" + "ld1h { z2.h }, p4/Z, [x28, #4, MUL VL]\n" + "ld1h { z24.h }, p4/Z, [x28, #5, MUL VL]\n" + "zip1 z12.h, z31.h, z26.h\n" + "zip2 z14.h, z31.h, z26.h\n" + "ld1h { z17.h }, p4/Z, [x28, #6, MUL VL]\n" + "ld1h { z29.h }, p4/Z, [x28, #7, MUL VL]\n" + "zip1 z16.h, z18.h, z23.h\n" + "zip2 z15.h, z18.h, z23.h\n" + "ld1h { z9.h }, p4/Z, [x25, #4, MUL VL]\n" + "ld1h { z18.h }, p4/Z, [x25, #5, MUL VL]\n" + "zip1 z11.h, z2.h, z9.h\n" + "zip2 z5.h, z2.h, z9.h\n" + "ld1h { z7.h }, p4/Z, [x25, #6, MUL VL]\n" + "ld1h { z2.h }, p4/Z, [x25, #7, MUL VL]\n" + "zip1 z10.h, z24.h, z18.h\n" + "zip2 z6.h, z24.h, z18.h\n" + "ld1h { z19.h }, p4/Z, [x24]\n" + "ld1h { z18.h }, p4/Z, [x24, #1, MUL VL]\n" + "zip1 z9.h, z17.h, z7.h\n" + "zip2 z4.h, z17.h, z7.h\n" + "ld1h { z24.h }, p4/Z, [x24, #2, MUL VL]\n" + "ld1h { z22.h }, p4/Z, [x24, #3, MUL VL]\n" + "zip1 z7.h, z29.h, z2.h\n" + "zip2 z8.h, z29.h, z2.h\n" + "ld1h { z25.h }, p4/Z, [x24, #4, MUL VL]\n" + "ld1h { z17.h }, p4/Z, [x24, #5, MUL VL]\n" + "cmp x27, x26\n" + "addvl x28, x28, #8\n" + "ld1h { z2.h }, p4/Z, [x24, #6, MUL VL]\n" + "ld1h { z30.h }, p4/Z, [x24, #7, MUL VL]\n" + "addvl x25, x25, #8\n" + "addvl x24, x24, #8\n" + "ld1h { z20.h }, p4/Z, [x23]\n" + "ld1h { z27.h }, p4/Z, [x23, #1, MUL VL]\n" + "zip1 z31.h, z19.h, z20.h\n" + "zip2 z29.h, z19.h, z20.h\n" + "ld1h { z26.h }, p4/Z, [x23, #2, MUL VL]\n" + "ld1h { z23.h }, p4/Z, [x23, #3, MUL VL]\n" + "zip1 z28.h, z18.h, z27.h\n" + "zip2 z1.h, z18.h, z27.h\n" + "ld1h { z20.h }, p4/Z, [x23, #4, MUL VL]\n" + "ld1h { z19.h }, p4/Z, [x23, #5, MUL VL]\n" + "zip1 z27.h, z24.h, z26.h\n" + "zip2 z26.h, z24.h, z26.h\n" + "ld1h { z18.h }, p4/Z, [x23, #6, MUL VL]\n" + "ld1h { z24.h }, p4/Z, [x23, #7, MUL VL]\n" + "st1h { z3.h }, p4, [x21]\n" + "zip1 z3.h, z22.h, z23.h\n" + "st1h { z21.h }, p4, [x21, #1, MUL VL]\n" + "zip2 z22.h, z22.h, z23.h\n" + "addvl x23, x23, #8\n" + "zip1 z23.h, z25.h, z20.h\n" + "st1h { z13.h }, p4, [x21, #2, MUL VL]\n" + "zip2 z25.h, z25.h, z20.h\n" + "zip1 z21.h, z17.h, z19.h\n" + "add x22, x22, %x[out_stride]\n" + "st1h { z0.h }, p4, [x21, #3, MUL VL]\n" + "zip2 z20.h, z17.h, z19.h\n" + "zip1 z19.h, z2.h, z18.h\n" + "st1h { z12.h }, p4, [x21, #4, MUL VL]\n" + "zip2 z18.h, z2.h, z18.h\n" + "zip1 z17.h, z30.h, z24.h\n" + "st1h { z14.h }, p4, [x21, #5, MUL VL]\n" + "zip2 z13.h, z30.h, z24.h\n" + "st1h { z16.h }, p4, [x21, #6, MUL VL]\n" + "st1h { z15.h }, p4, [x21, #7, MUL VL]\n" + "addvl x21, x21, #16\n" + "st1h { z31.h }, p4, [x21, #-8, MUL VL]\n" + "st1h { z29.h }, p4, [x21, #-7, MUL VL]\n" + "st1h { z28.h }, p4, [x21, #-6, MUL VL]\n" + "st1h { z1.h }, p4, [x21, #-5, MUL VL]\n" + "st1h { z27.h }, p4, [x21, #-4, MUL VL]\n" + "st1h { z26.h }, p4, [x21, #-3, MUL VL]\n" + "st1h { z3.h }, p4, [x21, #-2, MUL VL]\n" + "st1h { z22.h }, p4, [x21, #-1, MUL VL]\n" + "st1h { z11.h }, p4, [x20]\n" + "st1h { z5.h }, p4, [x20, #1, MUL VL]\n" + "st1h { z10.h }, p4, [x20, #2, MUL VL]\n" + "st1h { z6.h }, p4, [x20, #3, MUL VL]\n" + "st1h { z9.h }, p4, [x20, #4, MUL VL]\n" + "st1h { z4.h }, p4, [x20, #5, MUL VL]\n" + "st1h { z7.h }, p4, [x20, #6, MUL VL]\n" + "st1h { z8.h }, p4, [x20, #7, MUL VL]\n" + "addvl x20, x20, #16\n" + "st1h { z23.h }, p4, [x20, #-8, MUL VL]\n" + "st1h { z25.h }, p4, [x20, #-7, MUL VL]\n" + "st1h { z21.h }, p4, [x20, #-6, MUL VL]\n" + "st1h { z20.h }, p4, [x20, #-5, MUL VL]\n" + "st1h { z19.h }, p4, [x20, #-4, MUL VL]\n" + "st1h { z18.h }, p4, [x20, #-3, MUL VL]\n" + "st1h { z17.h }, p4, [x20, #-2, MUL VL]\n" + "st1h { z13.h }, p4, [x20, #-1, MUL VL]\n" + "bge 2b\n" + "3:" // Main row loop: Unroll column loop skip + "cbz x27, 5f\n" + "4:" // Main row loop: Column loop + "mov x20, x27\n" + "whilelt p3.h, XZR, x20\n" + "ld1h { z20.h }, p3/Z, [x28]\n" + "ld1h { z19.h }, p3/Z, [x25]\n" + "dech x20\n" + "whilelt p2.h, XZR, x20\n" + "ld1h { z18.h }, p2/Z, [x28, #1, MUL VL]\n" + "ld1h { z17.h }, p2/Z, [x25, #1, MUL VL]\n" + "dech x20\n" + "whilelt p1.h, XZR, x20\n" + "ld1h { z25.h }, p1/Z, [x28, #2, MUL VL]\n" + "ld1h { z16.h }, p1/Z, [x25, #2, MUL VL]\n" + "dech x20\n" + "whilelt p0.h, XZR, x20\n" + "ld1h { z0.h }, p0/Z, [x28, #3, MUL VL]\n" + "ld1h { z24.h }, p0/Z, [x25, #3, MUL VL]\n" + "mov x20, x22\n" + "decw x27, ALL, MUL #8\n" + "ld1h { z31.h }, p3/Z, [x24]\n" + "ld1h { z30.h }, p2/Z, [x24, #1, MUL VL]\n" + "ld1h { z29.h }, p1/Z, [x24, #2, MUL VL]\n" + "ld1h { z28.h }, p0/Z, [x24, #3, MUL VL]\n" + "zip1 z23.h, z20.h, z19.h\n" + "zip2 z22.h, z20.h, z19.h\n" + "ld1h { z21.h }, p3/Z, [x23]\n" + "ld1h { z27.h }, p2/Z, [x23, #1, MUL VL]\n" + "zip1 z20.h, z18.h, z17.h\n" + "zip2 z19.h, z18.h, z17.h\n" + "ld1h { z18.h }, p1/Z, [x23, #2, MUL VL]\n" + "ld1h { z26.h }, p0/Z, [x23, #3, MUL VL]\n" + "zip1 z17.h, z25.h, z16.h\n" + "zip2 z16.h, z25.h, z16.h\n" + "zip1 z25.h, z0.h, z24.h\n" + "zip2 z24.h, z0.h, z24.h\n" + "st1h { z23.h }, p4, [x20]\n" + "cmp x27, #0x0\n" + "st1h { z22.h }, p4, [x20, #1, MUL VL]\n" + "addvl x28, x28, #4\n" + "addvl x25, x25, #4\n" + "zip1 z23.h, z31.h, z21.h\n" + "st1h { z20.h }, p4, [x20, #2, MUL VL]\n" + "addvl x24, x24, #4\n" + "addvl x23, x23, #4\n" + "zip2 z22.h, z31.h, z21.h\n" + "st1h { z19.h }, p4, [x20, #3, MUL VL]\n" + "zip1 z21.h, z30.h, z27.h\n" + "zip2 z20.h, z30.h, z27.h\n" + "add x22, x22, %x[out_stride]\n" + "st1h { z17.h }, p4, [x20, #4, MUL VL]\n" + "zip1 z19.h, z29.h, z18.h\n" + "zip2 z18.h, z29.h, z18.h\n" + "st1h { z16.h }, p4, [x20, #5, MUL VL]\n" + "zip1 z17.h, z28.h, z26.h\n" + "zip2 z16.h, z28.h, z26.h\n" + "st1h { z25.h }, p4, [x20, #6, MUL VL]\n" + "st1h { z24.h }, p4, [x20, #7, MUL VL]\n" + "addvl x20, x20, #16\n" + "st1h { z23.h }, p4, [x20, #-8, MUL VL]\n" + "st1h { z22.h }, p4, [x20, #-7, MUL VL]\n" + "st1h { z21.h }, p4, [x20, #-6, MUL VL]\n" + "st1h { z20.h }, p4, [x20, #-5, MUL VL]\n" + "st1h { z19.h }, p4, [x20, #-4, MUL VL]\n" + "st1h { z18.h }, p4, [x20, #-3, MUL VL]\n" + "st1h { z17.h }, p4, [x20, #-2, MUL VL]\n" + "st1h { z16.h }, p4, [x20, #-1, MUL VL]\n" + "bgt 4b\n" + "5:" // Main row loop: Column loop skip + "cmp %x[height], #0x4\n" + "addvl %x[out], %x[out], #16\n" + "bge 1b\n" + "cbz %x[height], 12f\n" + "6:" // Main loop skip + "7:" // Tail row loop: Head + "mov x28, %x[in]\n" + "mov x21, %x[width]\n" + "cnth x20, ALL, MUL #8\n" + "add x25, x28, %x[in_stride]\n" + "cmp %x[height], #0x1\n" + "add %x[in], x25, %x[in_stride]\n" + "csel x25, x25, %x[pad_row], GT\n" + "cmp x21, x20\n" + "mov x22, %x[out]\n" + "sub %x[height], %x[height], #0x2\n" + "blt 9f\n" + "8:" // Tail row loop: Unroll column loop + "ld1h { z17.h }, p4/Z, [x28]\n" + "ld1h { z20.h }, p4/Z, [x28, #1, MUL VL]\n" + "sub x21, x21, x20\n" + "cmp x21, x20\n" + "ld1h { z23.h }, p4/Z, [x28, #2, MUL VL]\n" + "ld1h { z19.h }, p4/Z, [x28, #3, MUL VL]\n" + "ld1h { z16.h }, p4/Z, [x25]\n" + "ld1h { z18.h }, p4/Z, [x25, #1, MUL VL]\n" + "zip1 z0.h, z17.h, z16.h\n" + "zip2 z22.h, z17.h, z16.h\n" + "ld1h { z17.h }, p4/Z, [x25, #2, MUL VL]\n" + "ld1h { z16.h }, p4/Z, [x25, #3, MUL VL]\n" + "zip1 z31.h, z20.h, z18.h\n" + "zip2 z30.h, z20.h, z18.h\n" + "ld1h { z21.h }, p4/Z, [x28, #4, MUL VL]\n" + "ld1h { z20.h }, p4/Z, [x28, #5, MUL VL]\n" + "zip1 z29.h, z23.h, z17.h\n" + "zip2 z28.h, z23.h, z17.h\n" + "ld1h { z27.h }, p4/Z, [x28, #6, MUL VL]\n" + "ld1h { z26.h }, p4/Z, [x28, #7, MUL VL]\n" + "zip1 z25.h, z19.h, z16.h\n" + "zip2 z24.h, z19.h, z16.h\n" + "ld1h { z19.h }, p4/Z, [x25, #4, MUL VL]\n" + "ld1h { z18.h }, p4/Z, [x25, #5, MUL VL]\n" + "addvl x28, x28, #8\n" + "zip1 z23.h, z21.h, z19.h\n" + "ld1h { z17.h }, p4/Z, [x25, #6, MUL VL]\n" + "ld1h { z16.h }, p4/Z, [x25, #7, MUL VL]\n" + "st1h { z0.h }, p4, [x22]\n" + "addvl x25, x25, #8\n" + "st1h { z22.h }, p4, [x22, #1, MUL VL]\n" + "zip2 z22.h, z21.h, z19.h\n" + "zip1 z21.h, z20.h, z18.h\n" + "st1h { z31.h }, p4, [x22, #2, MUL VL]\n" + "zip2 z20.h, z20.h, z18.h\n" + "zip1 z19.h, z27.h, z17.h\n" + "st1h { z30.h }, p4, [x22, #3, MUL VL]\n" + "zip2 z18.h, z27.h, z17.h\n" + "zip1 z17.h, z26.h, z16.h\n" + "st1h { z29.h }, p4, [x22, #4, MUL VL]\n" + "zip2 z16.h, z26.h, z16.h\n" + "st1h { z28.h }, p4, [x22, #5, MUL VL]\n" + "st1h { z25.h }, p4, [x22, #6, MUL VL]\n" + "st1h { z24.h }, p4, [x22, #7, MUL VL]\n" + "add x22, x22, %x[out_stride]\n" + "st1h { z23.h }, p4, [x22]\n" + "st1h { z22.h }, p4, [x22, #1, MUL VL]\n" + "st1h { z21.h }, p4, [x22, #2, MUL VL]\n" + "st1h { z20.h }, p4, [x22, #3, MUL VL]\n" + "st1h { z19.h }, p4, [x22, #4, MUL VL]\n" + "st1h { z18.h }, p4, [x22, #5, MUL VL]\n" + "st1h { z17.h }, p4, [x22, #6, MUL VL]\n" + "st1h { z16.h }, p4, [x22, #7, MUL VL]\n" + "add x22, x22, %x[out_stride]\n" + "bge 8b\n" + "9:" // Tail row loop: Unroll column loop skip + "cbz x21, 11f\n" + "10:" // Tail row loop: Column loop + "mov x20, x21\n" + "whilelt p0.h, XZR, x20\n" + "ld1h { z22.h }, p0/Z, [x28]\n" + "ld1h { z21.h }, p0/Z, [x25]\n" + "dech x20\n" + "whilelt p0.h, XZR, x20\n" + "ld1h { z20.h }, p0/Z, [x28, #1, MUL VL]\n" + "ld1h { z19.h }, p0/Z, [x25, #1, MUL VL]\n" + "dech x20\n" + "whilelt p0.h, XZR, x20\n" + "ld1h { z18.h }, p0/Z, [x28, #2, MUL VL]\n" + "ld1h { z17.h }, p0/Z, [x25, #2, MUL VL]\n" + "dech x20\n" + "whilelt p0.h, XZR, x20\n" + "ld1h { z24.h }, p0/Z, [x28, #3, MUL VL]\n" + "ld1h { z23.h }, p0/Z, [x25, #3, MUL VL]\n" + "decw x21, ALL, MUL #8\n" + "cmp x21, #0x0\n" + "zip1 z16.h, z22.h, z21.h\n" + "zip2 z22.h, z22.h, z21.h\n" + "addvl x28, x28, #4\n" + "addvl x25, x25, #4\n" + "zip1 z21.h, z20.h, z19.h\n" + "zip2 z20.h, z20.h, z19.h\n" + "zip1 z19.h, z18.h, z17.h\n" + "zip2 z18.h, z18.h, z17.h\n" + "st1h { z16.h }, p4, [x22]\n" + "zip1 z17.h, z24.h, z23.h\n" + "zip2 z16.h, z24.h, z23.h\n" + "st1h { z22.h }, p4, [x22, #1, MUL VL]\n" + "st1h { z21.h }, p4, [x22, #2, MUL VL]\n" + "st1h { z20.h }, p4, [x22, #3, MUL VL]\n" + "st1h { z19.h }, p4, [x22, #4, MUL VL]\n" + "st1h { z18.h }, p4, [x22, #5, MUL VL]\n" + "st1h { z17.h }, p4, [x22, #6, MUL VL]\n" + "st1h { z16.h }, p4, [x22, #7, MUL VL]\n" + "add x22, x22, %x[out_stride]\n" + "bgt 10b\n" + "11:" // Tail row loop: Column loop skip + "cmp %x[height], #0x1\n" + "addvl %x[out], %x[out], #8\n" + "bge 7b\n" + "12:" // Done + : [height] "+&r" (height), [in] "+&r" (in), [out] "+&r" (out) + : [in_stride] "r" (in_stride), [out_stride] "r" (out_stride), [pad_row] "r" (pad_row), [width] "r" (width) + : "cc", "memory", "p0", "p1", "p2", "p3", "p4", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27", "x28", "z0", "z1", "z2", "z3", "z4", "z5", "z6", "z7", "z8", "z9", "z10", "z11", "z12", "z13", "z14", "z15", "z16", "z17", "z18", "z19", "z20", "z21", "z22", "z23", "z24", "z25", "z26", "z27", "z28", "z29", "z30", "z31" + ); +} + +} // anonymous namespace + +template<> +void Transform<8, 2, true, VLType::SVE>( + bfloat16 *out, const bfloat16 *in, int stride, int x0, int xmax, int k0, int kmax) +{ + sve_transpose_interleave_8VL_2x2( + reinterpret_cast<uint16_t *>(out), + reinterpret_cast<const uint16_t *>(in + k0 * stride + x0), + (xmax-x0) * sizeof(bfloat16) / 2, + stride * sizeof(bfloat16), + (kmax-k0) + ); +} + + +#endif // defined(ARM_COMPUTE_ENABLE_SVE) diff --git a/src/core/NEON/kernels/arm_gemm/transforms/sve_transpose_interleave_8VL_2x4.hpp b/src/core/NEON/kernels/arm_gemm/transforms/sve_transpose_interleave_8VL_2x4.hpp new file mode 100644 index 0000000000..34799c60a6 --- /dev/null +++ b/src/core/NEON/kernels/arm_gemm/transforms/sve_transpose_interleave_8VL_2x4.hpp @@ -0,0 +1,463 @@ +/* + * Copyright (c) 2021, 2023 Arm Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ + +#pragma once + +#if defined(ARM_COMPUTE_ENABLE_SVE) + +namespace { + +void sve_transpose_interleave_8VL_2x4(uint16_t *out, const uint16_t *in, size_t width, size_t in_stride, size_t height) +{ + uint16_t *pad_row = reinterpret_cast<uint16_t *>(alloca(width * sizeof(uint16_t))); + + if (height % 4) { + memset(pad_row, 0, width * sizeof(uint16_t)); + } + + size_t out_stride = 8 * roundup<size_t>(height, 4) * get_vector_length<uint32_t>(); + + __asm__ __volatile__( + "cmp %x[height], #0x8\n" + "ptrue p2.b\n" + "blt 6f\n" + "1:" // Main row loop: Head + "mov x12, %x[in]\n" + "add x11, x12, %x[in_stride]\n" + "add x10, x11, %x[in_stride]\n" + "add x9, x10, %x[in_stride]\n" + "add x28, x9, %x[in_stride]\n" + "mov x27, %x[width]\n" + "cnth x26, ALL, MUL #4\n" + "add x25, x28, %x[in_stride]\n" + "add x24, x25, %x[in_stride]\n" + "add x23, x24, %x[in_stride]\n" + "cmp x27, x26\n" + "add %x[in], x23, %x[in_stride]\n" + "mov x22, %x[out]\n" + "sub %x[height], %x[height], #0x8\n" + "blt 3f\n" + "2:" // Main row loop: Unroll column loop + "ld1h { z21.h }, p2/Z, [x12]\n" + "ld1h { z17.h }, p2/Z, [x12, #1, MUL VL]\n" + "mov x21, x22\n" + "add x22, x22, %x[out_stride]\n" + "ld1h { z31.h }, p2/Z, [x11]\n" + "ld1h { z5.h }, p2/Z, [x11, #1, MUL VL]\n" + "mov x20, x22\n" + "sub x27, x27, x26\n" + "ld1h { z15.h }, p2/Z, [x10]\n" + "ld1h { z28.h }, p2/Z, [x10, #1, MUL VL]\n" + "zip1 z24.h, z21.h, z15.h\n" + "zip2 z29.h, z21.h, z15.h\n" + "ld1h { z6.h }, p2/Z, [x9]\n" + "ld1h { z4.h }, p2/Z, [x9, #1, MUL VL]\n" + "zip1 z16.h, z31.h, z6.h\n" + "zip2 z18.h, z31.h, z6.h\n" + "ld1h { z3.h }, p2/Z, [x12, #2, MUL VL]\n" + "ld1h { z25.h }, p2/Z, [x12, #3, MUL VL]\n" + "zip1 z20.h, z17.h, z28.h\n" + "zip1 z7.h, z5.h, z4.h\n" + "ld1h { z27.h }, p2/Z, [x11, #2, MUL VL]\n" + "ld1h { z22.h }, p2/Z, [x11, #3, MUL VL]\n" + "zip2 z2.h, z17.h, z28.h\n" + "zip2 z19.h, z5.h, z4.h\n" + "ld1h { z28.h }, p2/Z, [x10, #2, MUL VL]\n" + "ld1h { z17.h }, p2/Z, [x10, #3, MUL VL]\n" + "zip1 z21.h, z24.h, z16.h\n" + "zip2 z24.h, z24.h, z16.h\n" + "ld1h { z5.h }, p2/Z, [x9, #2, MUL VL]\n" + "ld1h { z1.h }, p2/Z, [x9, #3, MUL VL]\n" + "zip1 z14.h, z29.h, z18.h\n" + "zip2 z12.h, z29.h, z18.h\n" + "ld1h { z18.h }, p2/Z, [x28]\n" + "ld1h { z31.h }, p2/Z, [x28, #1, MUL VL]\n" + "zip1 z11.h, z20.h, z7.h\n" + "zip2 z13.h, z20.h, z7.h\n" + "ld1h { z4.h }, p2/Z, [x25]\n" + "ld1h { z26.h }, p2/Z, [x25, #1, MUL VL]\n" + "zip1 z15.h, z2.h, z19.h\n" + "zip2 z10.h, z2.h, z19.h\n" + "ld1h { z16.h }, p2/Z, [x24]\n" + "ld1h { z30.h }, p2/Z, [x24, #1, MUL VL]\n" + "zip1 z19.h, z18.h, z16.h\n" + "zip2 z18.h, z18.h, z16.h\n" + "ld1h { z8.h }, p2/Z, [x23]\n" + "ld1h { z29.h }, p2/Z, [x23, #1, MUL VL]\n" + "zip1 z20.h, z4.h, z8.h\n" + "zip2 z0.h, z4.h, z8.h\n" + "ld1h { z6.h }, p2/Z, [x28, #2, MUL VL]\n" + "ld1h { z8.h }, p2/Z, [x28, #3, MUL VL]\n" + "zip1 z23.h, z31.h, z30.h\n" + "zip1 z16.h, z26.h, z29.h\n" + "ld1h { z9.h }, p2/Z, [x25, #2, MUL VL]\n" + "ld1h { z7.h }, p2/Z, [x25, #3, MUL VL]\n" + "zip2 z31.h, z31.h, z30.h\n" + "zip2 z30.h, z26.h, z29.h\n" + "ld1h { z2.h }, p2/Z, [x24, #2, MUL VL]\n" + "ld1h { z26.h }, p2/Z, [x24, #3, MUL VL]\n" + "zip1 z29.h, z3.h, z28.h\n" + "zip1 z4.h, z27.h, z5.h\n" + "zip2 z28.h, z3.h, z28.h\n" + "ld1h { z3.h }, p2/Z, [x23, #2, MUL VL]\n" + "zip2 z27.h, z27.h, z5.h\n" + "ld1h { z5.h }, p2/Z, [x23, #3, MUL VL]\n" + "st1h { z21.h }, p2, [x21]\n" + "zip1 z21.h, z25.h, z17.h\n" + "zip2 z25.h, z25.h, z17.h\n" + "cmp x27, x26\n" + "st1h { z24.h }, p2, [x21, #1, MUL VL]\n" + "zip1 z24.h, z22.h, z1.h\n" + "zip2 z22.h, z22.h, z1.h\n" + "addvl x12, x12, #4\n" + "st1h { z14.h }, p2, [x21, #2, MUL VL]\n" + "zip1 z17.h, z19.h, z20.h\n" + "zip2 z20.h, z19.h, z20.h\n" + "addvl x11, x11, #4\n" + "st1h { z12.h }, p2, [x21, #3, MUL VL]\n" + "zip1 z19.h, z18.h, z0.h\n" + "zip2 z18.h, z18.h, z0.h\n" + "addvl x10, x10, #4\n" + "st1h { z11.h }, p2, [x21, #4, MUL VL]\n" + "zip1 z14.h, z23.h, z16.h\n" + "zip2 z16.h, z23.h, z16.h\n" + "addvl x9, x9, #4\n" + "st1h { z13.h }, p2, [x21, #5, MUL VL]\n" + "zip1 z23.h, z31.h, z30.h\n" + "zip2 z1.h, z31.h, z30.h\n" + "addvl x28, x28, #4\n" + "st1h { z15.h }, p2, [x21, #6, MUL VL]\n" + "zip1 z0.h, z29.h, z4.h\n" + "zip2 z31.h, z29.h, z4.h\n" + "addvl x25, x25, #4\n" + "st1h { z10.h }, p2, [x21, #7, MUL VL]\n" + "addvl x21, x21, #16\n" + "zip1 z30.h, z28.h, z27.h\n" + "zip2 z29.h, z28.h, z27.h\n" + "st1h { z17.h }, p2, [x21, #-8, MUL VL]\n" + "zip1 z13.h, z21.h, z24.h\n" + "zip2 z27.h, z21.h, z24.h\n" + "addvl x24, x24, #4\n" + "st1h { z20.h }, p2, [x21, #-7, MUL VL]\n" + "zip1 z28.h, z25.h, z22.h\n" + "zip2 z25.h, z25.h, z22.h\n" + "addvl x23, x23, #4\n" + "st1h { z19.h }, p2, [x21, #-6, MUL VL]\n" + "zip1 z22.h, z6.h, z2.h\n" + "zip1 z21.h, z9.h, z3.h\n" + "add x22, x22, %x[out_stride]\n" + "st1h { z18.h }, p2, [x21, #-5, MUL VL]\n" + "zip2 z20.h, z6.h, z2.h\n" + "zip2 z19.h, z9.h, z3.h\n" + "st1h { z14.h }, p2, [x21, #-4, MUL VL]\n" + "zip1 z18.h, z8.h, z26.h\n" + "zip1 z17.h, z7.h, z5.h\n" + "st1h { z16.h }, p2, [x21, #-3, MUL VL]\n" + "zip2 z24.h, z8.h, z26.h\n" + "zip2 z16.h, z7.h, z5.h\n" + "st1h { z23.h }, p2, [x21, #-2, MUL VL]\n" + "zip1 z23.h, z22.h, z21.h\n" + "zip2 z22.h, z22.h, z21.h\n" + "st1h { z1.h }, p2, [x21, #-1, MUL VL]\n" + "zip1 z21.h, z20.h, z19.h\n" + "zip2 z20.h, z20.h, z19.h\n" + "st1h { z0.h }, p2, [x20]\n" + "zip1 z19.h, z18.h, z17.h\n" + "zip2 z18.h, z18.h, z17.h\n" + "st1h { z31.h }, p2, [x20, #1, MUL VL]\n" + "zip1 z17.h, z24.h, z16.h\n" + "zip2 z16.h, z24.h, z16.h\n" + "st1h { z30.h }, p2, [x20, #2, MUL VL]\n" + "st1h { z29.h }, p2, [x20, #3, MUL VL]\n" + "st1h { z13.h }, p2, [x20, #4, MUL VL]\n" + "st1h { z27.h }, p2, [x20, #5, MUL VL]\n" + "st1h { z28.h }, p2, [x20, #6, MUL VL]\n" + "st1h { z25.h }, p2, [x20, #7, MUL VL]\n" + "addvl x20, x20, #16\n" + "st1h { z23.h }, p2, [x20, #-8, MUL VL]\n" + "st1h { z22.h }, p2, [x20, #-7, MUL VL]\n" + "st1h { z21.h }, p2, [x20, #-6, MUL VL]\n" + "st1h { z20.h }, p2, [x20, #-5, MUL VL]\n" + "st1h { z19.h }, p2, [x20, #-4, MUL VL]\n" + "st1h { z18.h }, p2, [x20, #-3, MUL VL]\n" + "st1h { z17.h }, p2, [x20, #-2, MUL VL]\n" + "st1h { z16.h }, p2, [x20, #-1, MUL VL]\n" + "bge 2b\n" + "3:" // Main row loop: Unroll column loop skip + "cbz x27, 5f\n" + "4:" // Main row loop: Column loop + "mov x20, x27\n" + "whilelt p1.h, XZR, x20\n" + "ld1h { z17.h }, p1/Z, [x12]\n" + "ld1h { z19.h }, p1/Z, [x11]\n" + "dech x20\n" + "whilelt p0.h, XZR, x20\n" + "ld1h { z24.h }, p0/Z, [x12, #1, MUL VL]\n" + "ld1h { z23.h }, p0/Z, [x11, #1, MUL VL]\n" + "ld1h { z16.h }, p1/Z, [x10]\n" + "ld1h { z20.h }, p0/Z, [x10, #1, MUL VL]\n" + "zip1 z1.h, z17.h, z16.h\n" + "zip2 z22.h, z17.h, z16.h\n" + "ld1h { z18.h }, p1/Z, [x9]\n" + "ld1h { z17.h }, p0/Z, [x9, #1, MUL VL]\n" + "zip1 z16.h, z19.h, z18.h\n" + "zip2 z19.h, z19.h, z18.h\n" + "ld1h { z0.h }, p1/Z, [x28]\n" + "ld1h { z31.h }, p0/Z, [x28, #1, MUL VL]\n" + "zip1 z25.h, z24.h, z20.h\n" + "zip1 z21.h, z23.h, z17.h\n" + "ld1h { z30.h }, p1/Z, [x25]\n" + "ld1h { z29.h }, p0/Z, [x25, #1, MUL VL]\n" + "zip2 z28.h, z24.h, z20.h\n" + "zip2 z24.h, z23.h, z17.h\n" + "ld1h { z20.h }, p1/Z, [x24]\n" + "ld1h { z27.h }, p0/Z, [x24, #1, MUL VL]\n" + "mov x20, x22\n" + "decd x27, ALL, MUL #8\n" + "ld1h { z23.h }, p1/Z, [x23]\n" + "ld1h { z26.h }, p0/Z, [x23, #1, MUL VL]\n" + "zip1 z18.h, z1.h, z16.h\n" + "zip2 z17.h, z1.h, z16.h\n" + "zip1 z16.h, z22.h, z19.h\n" + "zip2 z19.h, z22.h, z19.h\n" + "st1h { z18.h }, p2, [x20]\n" + "cmp x27, #0x0\n" + "zip1 z22.h, z25.h, z21.h\n" + "zip2 z21.h, z25.h, z21.h\n" + "st1h { z17.h }, p2, [x20, #1, MUL VL]\n" + "addvl x12, x12, #2\n" + "zip1 z25.h, z28.h, z24.h\n" + "zip2 z18.h, z28.h, z24.h\n" + "st1h { z16.h }, p2, [x20, #2, MUL VL]\n" + "addvl x11, x11, #2\n" + "zip1 z17.h, z0.h, z20.h\n" + "zip1 z16.h, z30.h, z23.h\n" + "st1h { z19.h }, p2, [x20, #3, MUL VL]\n" + "addvl x10, x10, #2\n" + "zip2 z20.h, z0.h, z20.h\n" + "zip2 z19.h, z30.h, z23.h\n" + "st1h { z22.h }, p2, [x20, #4, MUL VL]\n" + "addvl x9, x9, #2\n" + "zip1 z24.h, z31.h, z27.h\n" + "zip1 z23.h, z29.h, z26.h\n" + "st1h { z21.h }, p2, [x20, #5, MUL VL]\n" + "addvl x28, x28, #2\n" + "zip2 z22.h, z31.h, z27.h\n" + "zip2 z21.h, z29.h, z26.h\n" + "st1h { z25.h }, p2, [x20, #6, MUL VL]\n" + "addvl x25, x25, #2\n" + "st1h { z18.h }, p2, [x20, #7, MUL VL]\n" + "addvl x20, x20, #16\n" + "addvl x24, x24, #2\n" + "zip1 z18.h, z17.h, z16.h\n" + "addvl x23, x23, #2\n" + "zip2 z17.h, z17.h, z16.h\n" + "zip1 z16.h, z20.h, z19.h\n" + "st1h { z18.h }, p2, [x20, #-8, MUL VL]\n" + "zip2 z20.h, z20.h, z19.h\n" + "zip1 z19.h, z24.h, z23.h\n" + "st1h { z17.h }, p2, [x20, #-7, MUL VL]\n" + "add x22, x22, %x[out_stride]\n" + "zip2 z18.h, z24.h, z23.h\n" + "zip1 z17.h, z22.h, z21.h\n" + "st1h { z16.h }, p2, [x20, #-6, MUL VL]\n" + "zip2 z16.h, z22.h, z21.h\n" + "st1h { z20.h }, p2, [x20, #-5, MUL VL]\n" + "st1h { z19.h }, p2, [x20, #-4, MUL VL]\n" + "st1h { z18.h }, p2, [x20, #-3, MUL VL]\n" + "st1h { z17.h }, p2, [x20, #-2, MUL VL]\n" + "st1h { z16.h }, p2, [x20, #-1, MUL VL]\n" + "bgt 4b\n" + "5:" // Main row loop: Column loop skip + "cmp %x[height], #0x8\n" + "addvl %x[out], %x[out], #16\n" + "bge 1b\n" + "cbz %x[height], 12f\n" + "6:" // Main loop skip + "7:" // Tail row loop: Head + "mov x12, %x[in]\n" + "add x11, x12, %x[in_stride]\n" + "add x10, x11, %x[in_stride]\n" + "mov x21, %x[width]\n" + "cnth x20, ALL, MUL #4\n" + "add x9, x10, %x[in_stride]\n" + "cmp %x[height], #0x3\n" + "add %x[in], x9, %x[in_stride]\n" + "csel x9, x9, %x[pad_row], GT\n" + "csel x10, x10, %x[pad_row], GE\n" + "cmp %x[height], #0x1\n" + "csel x11, x11, %x[pad_row], GT\n" + "cmp x21, x20\n" + "mov x22, %x[out]\n" + "sub %x[height], %x[height], #0x4\n" + "blt 9f\n" + "8:" // Tail row loop: Unroll column loop + "ld1h { z17.h }, p2/Z, [x12]\n" + "ld1h { z22.h }, p2/Z, [x12, #1, MUL VL]\n" + "sub x21, x21, x20\n" + "cmp x21, x20\n" + "ld1h { z19.h }, p2/Z, [x11]\n" + "ld1h { z21.h }, p2/Z, [x11, #1, MUL VL]\n" + "ld1h { z16.h }, p2/Z, [x10]\n" + "ld1h { z18.h }, p2/Z, [x10, #1, MUL VL]\n" + "zip1 z4.h, z17.h, z16.h\n" + "zip2 z3.h, z17.h, z16.h\n" + "ld1h { z17.h }, p2/Z, [x9]\n" + "ld1h { z16.h }, p2/Z, [x9, #1, MUL VL]\n" + "zip1 z2.h, z19.h, z17.h\n" + "zip2 z1.h, z19.h, z17.h\n" + "ld1h { z17.h }, p2/Z, [x12, #2, MUL VL]\n" + "ld1h { z24.h }, p2/Z, [x12, #3, MUL VL]\n" + "zip1 z0.h, z22.h, z18.h\n" + "zip1 z31.h, z21.h, z16.h\n" + "ld1h { z20.h }, p2/Z, [x11, #2, MUL VL]\n" + "ld1h { z19.h }, p2/Z, [x11, #3, MUL VL]\n" + "zip2 z30.h, z22.h, z18.h\n" + "zip2 z23.h, z21.h, z16.h\n" + "ld1h { z16.h }, p2/Z, [x10, #2, MUL VL]\n" + "ld1h { z18.h }, p2/Z, [x10, #3, MUL VL]\n" + "zip1 z22.h, z17.h, z16.h\n" + "zip2 z29.h, z17.h, z16.h\n" + "ld1h { z17.h }, p2/Z, [x9, #2, MUL VL]\n" + "ld1h { z16.h }, p2/Z, [x9, #3, MUL VL]\n" + "zip1 z21.h, z20.h, z17.h\n" + "zip2 z28.h, z20.h, z17.h\n" + "zip1 z27.h, z24.h, z18.h\n" + "zip1 z26.h, z19.h, z16.h\n" + "addvl x12, x12, #4\n" + "addvl x11, x11, #4\n" + "zip2 z25.h, z24.h, z18.h\n" + "zip2 z24.h, z19.h, z16.h\n" + "addvl x10, x10, #4\n" + "addvl x9, x9, #4\n" + "zip1 z16.h, z4.h, z2.h\n" + "zip2 z17.h, z4.h, z2.h\n" + "st1h { z16.h }, p2, [x22]\n" + "zip1 z16.h, z3.h, z1.h\n" + "zip2 z20.h, z3.h, z1.h\n" + "st1h { z17.h }, p2, [x22, #1, MUL VL]\n" + "zip1 z19.h, z0.h, z31.h\n" + "zip2 z18.h, z0.h, z31.h\n" + "st1h { z16.h }, p2, [x22, #2, MUL VL]\n" + "zip1 z17.h, z30.h, z23.h\n" + "zip2 z16.h, z30.h, z23.h\n" + "st1h { z20.h }, p2, [x22, #3, MUL VL]\n" + "st1h { z19.h }, p2, [x22, #4, MUL VL]\n" + "zip1 z23.h, z22.h, z21.h\n" + "zip2 z22.h, z22.h, z21.h\n" + "st1h { z18.h }, p2, [x22, #5, MUL VL]\n" + "zip1 z21.h, z29.h, z28.h\n" + "zip2 z20.h, z29.h, z28.h\n" + "st1h { z17.h }, p2, [x22, #6, MUL VL]\n" + "zip1 z19.h, z27.h, z26.h\n" + "zip2 z18.h, z27.h, z26.h\n" + "st1h { z16.h }, p2, [x22, #7, MUL VL]\n" + "add x22, x22, %x[out_stride]\n" + "zip1 z17.h, z25.h, z24.h\n" + "zip2 z16.h, z25.h, z24.h\n" + "st1h { z23.h }, p2, [x22]\n" + "st1h { z22.h }, p2, [x22, #1, MUL VL]\n" + "st1h { z21.h }, p2, [x22, #2, MUL VL]\n" + "st1h { z20.h }, p2, [x22, #3, MUL VL]\n" + "st1h { z19.h }, p2, [x22, #4, MUL VL]\n" + "st1h { z18.h }, p2, [x22, #5, MUL VL]\n" + "st1h { z17.h }, p2, [x22, #6, MUL VL]\n" + "st1h { z16.h }, p2, [x22, #7, MUL VL]\n" + "add x22, x22, %x[out_stride]\n" + "bge 8b\n" + "9:" // Tail row loop: Unroll column loop skip + "cbz x21, 11f\n" + "10:" // Tail row loop: Column loop + "mov x20, x21\n" + "whilelt p1.h, XZR, x20\n" + "ld1h { z23.h }, p1/Z, [x12]\n" + "ld1h { z22.h }, p1/Z, [x11]\n" + "dech x20\n" + "whilelt p0.h, XZR, x20\n" + "ld1h { z21.h }, p0/Z, [x12, #1, MUL VL]\n" + "ld1h { z25.h }, p0/Z, [x11, #1, MUL VL]\n" + "ld1h { z19.h }, p1/Z, [x10]\n" + "ld1h { z20.h }, p0/Z, [x10, #1, MUL VL]\n" + "decd x21, ALL, MUL #8\n" + "zip1 z24.h, z23.h, z19.h\n" + "ld1h { z18.h }, p1/Z, [x9]\n" + "ld1h { z16.h }, p0/Z, [x9, #1, MUL VL]\n" + "zip1 z17.h, z22.h, z18.h\n" + "zip2 z23.h, z23.h, z19.h\n" + "zip2 z19.h, z22.h, z18.h\n" + "zip1 z22.h, z21.h, z20.h\n" + "cmp x21, #0x0\n" + "addvl x12, x12, #2\n" + "zip1 z18.h, z25.h, z16.h\n" + "zip2 z21.h, z21.h, z20.h\n" + "addvl x11, x11, #2\n" + "addvl x10, x10, #2\n" + "zip2 z20.h, z25.h, z16.h\n" + "addvl x9, x9, #2\n" + "zip1 z16.h, z24.h, z17.h\n" + "st1h { z16.h }, p2, [x22]\n" + "zip2 z16.h, z24.h, z17.h\n" + "zip1 z17.h, z23.h, z19.h\n" + "st1h { z16.h }, p2, [x22, #1, MUL VL]\n" + "zip2 z16.h, z23.h, z19.h\n" + "zip1 z19.h, z22.h, z18.h\n" + "st1h { z17.h }, p2, [x22, #2, MUL VL]\n" + "zip2 z18.h, z22.h, z18.h\n" + "zip1 z17.h, z21.h, z20.h\n" + "st1h { z16.h }, p2, [x22, #3, MUL VL]\n" + "zip2 z16.h, z21.h, z20.h\n" + "st1h { z19.h }, p2, [x22, #4, MUL VL]\n" + "st1h { z18.h }, p2, [x22, #5, MUL VL]\n" + "st1h { z17.h }, p2, [x22, #6, MUL VL]\n" + "st1h { z16.h }, p2, [x22, #7, MUL VL]\n" + "add x22, x22, %x[out_stride]\n" + "bgt 10b\n" + "11:" // Tail row loop: Column loop skip + "cmp %x[height], #0x1\n" + "addvl %x[out], %x[out], #8\n" + "bge 7b\n" + "12:" // Done + : [height] "+&r" (height), [in] "+&r" (in), [out] "+&r" (out) + : [in_stride] "r" (in_stride), [out_stride] "r" (out_stride), [pad_row] "r" (pad_row), [width] "r" (width) + : "cc", "memory", "p0", "p1", "p2", "x9", "x10", "x11", "x12", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27", "x28", "z0", "z1", "z2", "z3", "z4", "z5", "z6", "z7", "z8", "z9", "z10", "z11", "z12", "z13", "z14", "z15", "z16", "z17", "z18", "z19", "z20", "z21", "z22", "z23", "z24", "z25", "z26", "z27", "z28", "z29", "z30", "z31" + ); +} + +} // anonymous namespace + +template<> +void Transform<8, 4, true, VLType::SVE>( + bfloat16 *out, const bfloat16 *in, int stride, int x0, int xmax, int k0, int kmax) +{ + sve_transpose_interleave_8VL_2x4( + reinterpret_cast<uint16_t *>(out), + reinterpret_cast<const uint16_t *>(in + k0 * stride + x0), + (xmax-x0) * sizeof(bfloat16) / 2, + stride * sizeof(bfloat16), + (kmax-k0) + ); +} + + +#endif // defined(ARM_COMPUTE_ENABLE_SVE) diff --git a/src/core/NEON/kernels/arm_gemm/transforms/sve_transpose_interleave_8VL_2x4_fp32bf16.hpp b/src/core/NEON/kernels/arm_gemm/transforms/sve_transpose_interleave_8VL_2x4_fp32bf16.hpp new file mode 100644 index 0000000000..5a48e579ae --- /dev/null +++ b/src/core/NEON/kernels/arm_gemm/transforms/sve_transpose_interleave_8VL_2x4_fp32bf16.hpp @@ -0,0 +1,282 @@ +/* + * Copyright (c) 2021, 2023 Arm Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ + +#pragma once + +#if defined(ARM_COMPUTE_ENABLE_SVE) + +namespace { + +void sve_transpose_interleave_8VL_2x4_fp32bf16(bfloat16 *out, const float *in, size_t width, size_t in_stride, size_t height) +{ + float *pad_row = reinterpret_cast<float *>(alloca(width * sizeof(float))); + + if (height % 4) { + memset(pad_row, 0, width * sizeof(float)); + } + + size_t out_stride = 8 * roundup<size_t>(height, 4) * get_vector_length<uint32_t>(); + + __asm__ __volatile__( + "ptrue p4.b\n" + "1:" // Main row loop: Head + "mov x26, %x[in]\n" + "add x25, x26, %x[in_stride]\n" + "add x24, x25, %x[in_stride]\n" + "mov x23, %x[width]\n" + "cnth x20, ALL, MUL #4\n" + "add x22, x24, %x[in_stride]\n" + "cmp %x[height], #0x3\n" + "add %x[in], x22, %x[in_stride]\n" + "csel x22, x22, %x[pad_row], GT\n" + "csel x24, x24, %x[pad_row], GE\n" + "cmp %x[height], #0x1\n" + "csel x25, x25, %x[pad_row], GT\n" + "cmp x23, x20\n" + "mov x21, %x[out]\n" + "sub %x[height], %x[height], #0x4\n" + "blt 3f\n" + "2:" // Main row loop: Unroll column loop + "ld1w { z19.s }, p4/Z, [x26]\n" + "ld1w { z18.s }, p4/Z, [x26, #1, MUL VL]\n" + "sub x23, x23, x20\n" + "cmp x23, x20\n" + "ld1w { z20.s }, p4/Z, [x26, #2, MUL VL]\n" + "ld1w { z24.s }, p4/Z, [x26, #3, MUL VL]\n" + "ld1w { z23.s }, p4/Z, [x24]\n" + "ld1w { z17.s }, p4/Z, [x24, #1, MUL VL]\n" + "zip1 z22.s, z19.s, z23.s\n" + "zip2 z21.s, z19.s, z23.s\n" + "ld1w { z31.s }, p4/Z, [x24, #2, MUL VL]\n" + "ld1w { z16.s }, p4/Z, [x24, #3, MUL VL]\n" + "zip1 z9.s, z18.s, z17.s\n" + "zip2 z7.s, z18.s, z17.s\n" + "ld1w { z19.s }, p4/Z, [x26, #4, MUL VL]\n" + "ld1w { z18.s }, p4/Z, [x26, #5, MUL VL]\n" + "zip1 z6.s, z20.s, z31.s\n" + "zip2 z5.s, z20.s, z31.s\n" + "ld1w { z15.s }, p4/Z, [x26, #6, MUL VL]\n" + "ld1w { z20.s }, p4/Z, [x26, #7, MUL VL]\n" + "zip1 z3.s, z24.s, z16.s\n" + "zip2 z2.s, z24.s, z16.s\n" + "ld1w { z16.s }, p4/Z, [x24, #4, MUL VL]\n" + "ld1w { z17.s }, p4/Z, [x24, #5, MUL VL]\n" + "zip1 z1.s, z19.s, z16.s\n" + "zip2 z0.s, z19.s, z16.s\n" + "ld1w { z16.s }, p4/Z, [x24, #6, MUL VL]\n" + "ld1w { z19.s }, p4/Z, [x24, #7, MUL VL]\n" + "zip1 z31.s, z18.s, z17.s\n" + "zip2 z30.s, z18.s, z17.s\n" + "ld1w { z18.s }, p4/Z, [x25]\n" + "ld1w { z17.s }, p4/Z, [x25, #1, MUL VL]\n" + "zip1 z29.s, z15.s, z16.s\n" + "zip2 z28.s, z15.s, z16.s\n" + "ld1w { z16.s }, p4/Z, [x25, #2, MUL VL]\n" + "ld1w { z23.s }, p4/Z, [x25, #3, MUL VL]\n" + "zip1 z27.s, z20.s, z19.s\n" + "zip2 z26.s, z20.s, z19.s\n" + "ld1w { z11.s }, p4/Z, [x22]\n" + "ld1w { z8.s }, p4/Z, [x22, #1, MUL VL]\n" + ".inst 0x658ab2d8 // bfcvt z24.h, p4/M, z22.s\n" + "zip1 z25.s, z18.s, z11.s\n" + "ld1w { z4.s }, p4/Z, [x22, #2, MUL VL]\n" + "ld1w { z22.s }, p4/Z, [x22, #3, MUL VL]\n" + ".inst 0x658ab2af // bfcvt z15.h, p4/M, z21.s\n" + "zip2 z14.s, z18.s, z11.s\n" + "ld1w { z21.s }, p4/Z, [x25, #4, MUL VL]\n" + "ld1w { z20.s }, p4/Z, [x25, #5, MUL VL]\n" + ".inst 0x658ab12d // bfcvt z13.h, p4/M, z9.s\n" + "zip1 z12.s, z17.s, z8.s\n" + "ld1w { z11.s }, p4/Z, [x25, #6, MUL VL]\n" + "ld1w { z10.s }, p4/Z, [x25, #7, MUL VL]\n" + ".inst 0x658ab0e9 // bfcvt z9.h, p4/M, z7.s\n" + "zip2 z8.s, z17.s, z8.s\n" + "ld1w { z19.s }, p4/Z, [x22, #4, MUL VL]\n" + "ld1w { z18.s }, p4/Z, [x22, #5, MUL VL]\n" + ".inst 0x658ab0c7 // bfcvt z7.h, p4/M, z6.s\n" + "zip1 z6.s, z16.s, z4.s\n" + "ld1w { z17.s }, p4/Z, [x22, #6, MUL VL]\n" + ".inst 0x658ab0a5 // bfcvt z5.h, p4/M, z5.s\n" + "zip2 z4.s, z16.s, z4.s\n" + "ld1w { z16.s }, p4/Z, [x22, #7, MUL VL]\n" + ".inst 0x658ab063 // bfcvt z3.h, p4/M, z3.s\n" + ".inst 0x658ab042 // bfcvt z2.h, p4/M, z2.s\n" + "addvl x26, x26, #8\n" + "addvl x25, x25, #8\n" + ".inst 0x658ab021 // bfcvt z1.h, p4/M, z1.s\n" + ".inst 0x658ab000 // bfcvt z0.h, p4/M, z0.s\n" + "addvl x24, x24, #8\n" + "addvl x22, x22, #8\n" + ".inst 0x658ab3ff // bfcvt z31.h, p4/M, z31.s\n" + ".inst 0x658ab3de // bfcvt z30.h, p4/M, z30.s\n" + ".inst 0x658ab3bd // bfcvt z29.h, p4/M, z29.s\n" + ".inst 0x658ab39c // bfcvt z28.h, p4/M, z28.s\n" + ".inst 0x658ab37b // bfcvt z27.h, p4/M, z27.s\n" + ".inst 0x658ab35a // bfcvt z26.h, p4/M, z26.s\n" + ".inst 0x648ab338 // bfcvtnt z24.h, p4/M, z25.s\n" + "zip1 z25.s, z23.s, z22.s\n" + "st1h { z24.h }, p4, [x21]\n" + "zip2 z24.s, z23.s, z22.s\n" + "zip1 z23.s, z21.s, z19.s\n" + "zip2 z22.s, z21.s, z19.s\n" + "zip1 z21.s, z20.s, z18.s\n" + "zip2 z20.s, z20.s, z18.s\n" + "zip1 z19.s, z11.s, z17.s\n" + "zip2 z18.s, z11.s, z17.s\n" + "zip1 z17.s, z10.s, z16.s\n" + "zip2 z16.s, z10.s, z16.s\n" + ".inst 0x648ab1cf // bfcvtnt z15.h, p4/M, z14.s\n" + "st1h { z15.h }, p4, [x21, #1, MUL VL]\n" + ".inst 0x648ab18d // bfcvtnt z13.h, p4/M, z12.s\n" + ".inst 0x648ab109 // bfcvtnt z9.h, p4/M, z8.s\n" + "st1h { z13.h }, p4, [x21, #2, MUL VL]\n" + ".inst 0x648ab0c7 // bfcvtnt z7.h, p4/M, z6.s\n" + ".inst 0x648ab085 // bfcvtnt z5.h, p4/M, z4.s\n" + "st1h { z9.h }, p4, [x21, #3, MUL VL]\n" + ".inst 0x648ab323 // bfcvtnt z3.h, p4/M, z25.s\n" + ".inst 0x648ab302 // bfcvtnt z2.h, p4/M, z24.s\n" + "st1h { z7.h }, p4, [x21, #4, MUL VL]\n" + "st1h { z5.h }, p4, [x21, #5, MUL VL]\n" + ".inst 0x648ab2e1 // bfcvtnt z1.h, p4/M, z23.s\n" + ".inst 0x648ab2c0 // bfcvtnt z0.h, p4/M, z22.s\n" + "st1h { z3.h }, p4, [x21, #6, MUL VL]\n" + ".inst 0x648ab2bf // bfcvtnt z31.h, p4/M, z21.s\n" + ".inst 0x648ab29e // bfcvtnt z30.h, p4/M, z20.s\n" + "st1h { z2.h }, p4, [x21, #7, MUL VL]\n" + "add x21, x21, %x[out_stride]\n" + ".inst 0x648ab27d // bfcvtnt z29.h, p4/M, z19.s\n" + ".inst 0x648ab25c // bfcvtnt z28.h, p4/M, z18.s\n" + ".inst 0x648ab23b // bfcvtnt z27.h, p4/M, z17.s\n" + ".inst 0x648ab21a // bfcvtnt z26.h, p4/M, z16.s\n" + "st1h { z1.h }, p4, [x21]\n" + "st1h { z0.h }, p4, [x21, #1, MUL VL]\n" + "st1h { z31.h }, p4, [x21, #2, MUL VL]\n" + "st1h { z30.h }, p4, [x21, #3, MUL VL]\n" + "st1h { z29.h }, p4, [x21, #4, MUL VL]\n" + "st1h { z28.h }, p4, [x21, #5, MUL VL]\n" + "st1h { z27.h }, p4, [x21, #6, MUL VL]\n" + "st1h { z26.h }, p4, [x21, #7, MUL VL]\n" + "add x21, x21, %x[out_stride]\n" + "bge 2b\n" + "3:" // Main row loop: Unroll column loop skip + "cbz x23, 5f\n" + "4:" // Main row loop: Column loop + "mov x20, x23\n" + "whilelt p3.s, XZR, x20\n" + "ld1w { z22.s }, p3/Z, [x26]\n" + "ld1w { z21.s }, p3/Z, [x24]\n" + "decw x20\n" + "whilelt p2.s, XZR, x20\n" + "ld1w { z20.s }, p2/Z, [x26, #1, MUL VL]\n" + "ld1w { z19.s }, p2/Z, [x24, #1, MUL VL]\n" + "decw x20\n" + "whilelt p1.s, XZR, x20\n" + "ld1w { z18.s }, p1/Z, [x26, #2, MUL VL]\n" + "ld1w { z17.s }, p1/Z, [x24, #2, MUL VL]\n" + "decw x20\n" + "whilelt p0.s, XZR, x20\n" + "ld1w { z28.s }, p0/Z, [x26, #3, MUL VL]\n" + "ld1w { z16.s }, p0/Z, [x24, #3, MUL VL]\n" + "ld1w { z27.s }, p3/Z, [x25]\n" + "ld1w { z3.s }, p2/Z, [x25, #1, MUL VL]\n" + "zip1 z26.s, z22.s, z21.s\n" + "zip2 z25.s, z22.s, z21.s\n" + "ld1w { z2.s }, p1/Z, [x25, #2, MUL VL]\n" + "ld1w { z1.s }, p0/Z, [x25, #3, MUL VL]\n" + "zip1 z24.s, z20.s, z19.s\n" + "zip2 z23.s, z20.s, z19.s\n" + "ld1w { z22.s }, p3/Z, [x22]\n" + "ld1w { z21.s }, p2/Z, [x22, #1, MUL VL]\n" + "zip1 z20.s, z18.s, z17.s\n" + "zip2 z19.s, z18.s, z17.s\n" + "ld1w { z18.s }, p1/Z, [x22, #2, MUL VL]\n" + "ld1w { z0.s }, p0/Z, [x22, #3, MUL VL]\n" + "zip1 z17.s, z28.s, z16.s\n" + "zip2 z16.s, z28.s, z16.s\n" + "decd x23, ALL, MUL #8\n" + ".inst 0x658ab35f // bfcvt z31.h, p4/M, z26.s\n" + "zip1 z30.s, z27.s, z22.s\n" + "cmp x23, #0x0\n" + ".inst 0x658ab33d // bfcvt z29.h, p4/M, z25.s\n" + "zip2 z28.s, z27.s, z22.s\n" + "addvl x26, x26, #4\n" + "addvl x25, x25, #4\n" + ".inst 0x658ab31b // bfcvt z27.h, p4/M, z24.s\n" + "zip1 z26.s, z3.s, z21.s\n" + "addvl x24, x24, #4\n" + "addvl x22, x22, #4\n" + ".inst 0x658ab2f9 // bfcvt z25.h, p4/M, z23.s\n" + "zip2 z24.s, z3.s, z21.s\n" + ".inst 0x658ab297 // bfcvt z23.h, p4/M, z20.s\n" + "zip1 z22.s, z2.s, z18.s\n" + ".inst 0x658ab275 // bfcvt z21.h, p4/M, z19.s\n" + "zip2 z20.s, z2.s, z18.s\n" + ".inst 0x658ab233 // bfcvt z19.h, p4/M, z17.s\n" + "zip1 z18.s, z1.s, z0.s\n" + ".inst 0x658ab211 // bfcvt z17.h, p4/M, z16.s\n" + "zip2 z16.s, z1.s, z0.s\n" + ".inst 0x648ab3df // bfcvtnt z31.h, p4/M, z30.s\n" + ".inst 0x648ab39d // bfcvtnt z29.h, p4/M, z28.s\n" + "st1h { z31.h }, p4, [x21]\n" + ".inst 0x648ab35b // bfcvtnt z27.h, p4/M, z26.s\n" + ".inst 0x648ab319 // bfcvtnt z25.h, p4/M, z24.s\n" + "st1h { z29.h }, p4, [x21, #1, MUL VL]\n" + ".inst 0x648ab2d7 // bfcvtnt z23.h, p4/M, z22.s\n" + ".inst 0x648ab295 // bfcvtnt z21.h, p4/M, z20.s\n" + "st1h { z27.h }, p4, [x21, #2, MUL VL]\n" + ".inst 0x648ab253 // bfcvtnt z19.h, p4/M, z18.s\n" + ".inst 0x648ab211 // bfcvtnt z17.h, p4/M, z16.s\n" + "st1h { z25.h }, p4, [x21, #3, MUL VL]\n" + "st1h { z23.h }, p4, [x21, #4, MUL VL]\n" + "st1h { z21.h }, p4, [x21, #5, MUL VL]\n" + "st1h { z19.h }, p4, [x21, #6, MUL VL]\n" + "st1h { z17.h }, p4, [x21, #7, MUL VL]\n" + "add x21, x21, %x[out_stride]\n" + "bgt 4b\n" + "5:" // Main row loop: Column loop skip + "cmp %x[height], #0x1\n" + "addvl %x[out], %x[out], #8\n" + "bge 1b\n" + : [height] "+&r" (height), [in] "+&r" (in), [out] "+&r" (out) + : [in_stride] "r" (in_stride), [out_stride] "r" (out_stride), [pad_row] "r" (pad_row), [width] "r" (width) + : "cc", "memory", "p0", "p1", "p2", "p3", "p4", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "z0", "z1", "z2", "z3", "z4", "z5", "z6", "z7", "z8", "z9", "z10", "z11", "z12", "z13", "z14", "z15", "z16", "z17", "z18", "z19", "z20", "z21", "z22", "z23", "z24", "z25", "z26", "z27", "z28", "z29", "z30", "z31" + ); +} + +} // anonymous namespace +template<> +void Transform<8, 4, true, VLType::SVE>( + bfloat16 *out, const float *in, int stride, int x0, int xmax, int k0, int kmax) +{ + sve_transpose_interleave_8VL_2x4_fp32bf16( + out, + in + k0 * stride + x0, + (xmax-x0), + stride * sizeof(float), + (kmax-k0) + ); +} + + +#endif // defined(ARM_COMPUTE_ENABLE_SVE) diff --git a/src/core/NEON/kernels/arm_gemm/transforms/transpose_interleave_common.hpp b/src/core/NEON/kernels/arm_gemm/transforms/transpose_interleave_common.hpp index 63e85c155a..02367bd7e7 100644 --- a/src/core/NEON/kernels/arm_gemm/transforms/transpose_interleave_common.hpp +++ b/src/core/NEON/kernels/arm_gemm/transforms/transpose_interleave_common.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017-2018 ARM Limited. + * Copyright (c) 2017-2021 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -23,6 +23,8 @@ */ #pragma once +#include "../asmlib.hpp" + template <unsigned int IntBy, typename TIn, typename TOut> struct TransposeInterleaveCommon { // Override the moveblock_1xY methods to improve performance @@ -56,7 +58,7 @@ struct TransposeInterleaveCommon { } } - static inline void Transform(TOut *out, const TIn *in, const int stride, const int x0, const int xmax, const int k0, const int kmax) { + static void Transform(TOut *out, const TIn *in, const int stride, const int x0, const int xmax, const int k0, const int kmax) { const auto ldin = stride; TOut *outarray = out; |