From 99ef8407cd5b27fdec6f8dfaf8b55f820b6dea71 Mon Sep 17 00:00:00 2001 From: Pablo Tello Date: Tue, 20 Mar 2018 16:46:55 +0000 Subject: COMPMID-881: Updated arm_gemm to the latest Change-Id: Iba2664f33320e79bd15ca9c1399e65e4cc165be6 Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/125265 Tested-by: Jenkins Reviewed-by: Georgios Pinitas --- src/core/NEON/kernels/arm_gemm/gemm_fp16.cpp | 30 ++- src/core/NEON/kernels/arm_gemm/gemm_fp32.cpp | 10 +- .../kernels/arm_gemm/gemv_native_transposed.hpp | 8 +- .../kernels/arm_gemm/kernels/a64_hgemm_24x8.hpp | 4 +- .../arm_gemm/kernels/a64_hgemm_24x8/a55r1.cpp | 6 +- .../arm_gemm/kernels/a64_hgemm_24x8/generic.cpp | 6 +- .../kernels/a64_sgemm_native_16x4/generic.cpp | 147 ++++++++++++- .../arm_gemm/kernels/a64_sgemv_trans/generic.cpp | 239 +++++++++++---------- .../merges/a64_merge_float_to_half_12x8.hpp | 5 +- src/core/NEON/kernels/arm_gemm/merges/list.hpp | 3 +- .../transforms/a32_interleave_6way_32bit.hpp | 52 ++--- .../a64_interleave_8way_half_to_float.hpp | 4 +- ...64_transpose_interleave_12way_half_to_float.hpp | 4 +- 13 files changed, 337 insertions(+), 181 deletions(-) diff --git a/src/core/NEON/kernels/arm_gemm/gemm_fp16.cpp b/src/core/NEON/kernels/arm_gemm/gemm_fp16.cpp index b9729d4c5c..484892dc81 100644 --- a/src/core/NEON/kernels/arm_gemm/gemm_fp16.cpp +++ b/src/core/NEON/kernels/arm_gemm/gemm_fp16.cpp @@ -21,7 +21,9 @@ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE * SOFTWARE. */ -#ifdef __ARM_FEATURE_FP16_SCALAR_ARITHMETIC + +// This can only be built if the target/compiler supports FP16 arguments. +#ifdef __ARM_FP16_ARGS #include "arm_gemm.hpp" @@ -40,26 +42,38 @@ UniqueGemmCommon<__fp16, __fp16> gemm(const CPUInfo &ci, const unsigned int M, c const int maxthreads, const bool pretransposed_hint) { #ifdef __aarch64__ - /* If FP16 is supported, use it */ - if(ci.has_fp16()) + + // Only consider the native FP16 kernel if it will get built. +#if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) || defined(FP16_KERNELS) +#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + // If the compiler is configured to enable this feature always, then assume it is available at runtime too. + const bool use_fp16 = true; +#else + // Otherwise, detect at runtime via CPUInfo. + const bool use_fp16 = ci.has_fp16(); +#endif + + // If FP16 is supported, use it. + if(use_fp16) { return UniqueGemmCommon<__fp16, __fp16>(new GemmInterleaved(&ci, M, N, K, trA, trB, alpha, beta, maxthreads, pretransposed_hint)); } +#endif - /* Fallback to using the blocked SGEMM kernel. */ + // Fallback to using the blocked SGEMM kernel. return UniqueGemmCommon<__fp16, __fp16>(new GemmInterleaved(&ci, M, N, K, trA, trB, alpha, beta, maxthreads, pretransposed_hint)); #else - /* For AArch32, only support the SGEMM route. */ + // For AArch32, only support the SGEMM route for now. return UniqueGemmCommon<__fp16, __fp16>(new GemmInterleaved(&ci, M, N, K, trA, trB, alpha, beta, maxthreads, pretransposed_hint)); #endif } -// Instantiate static class members -#ifdef __aarch64__ +// Instantiate static class members if necessary. 
+#if defined(__aarch64__) && (defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) || defined(FP16_KERNELS)) const int hgemm_24x8::out_width; const int hgemm_24x8::out_height; #endif } // namespace arm_gemm -#endif // __ARM_FEATURE_FP16_SCALAR_ARITHMETIC +#endif // __ARM_FP16_ARGS diff --git a/src/core/NEON/kernels/arm_gemm/gemm_fp32.cpp b/src/core/NEON/kernels/arm_gemm/gemm_fp32.cpp index 1baa21fd1b..a5b41cac2f 100644 --- a/src/core/NEON/kernels/arm_gemm/gemm_fp32.cpp +++ b/src/core/NEON/kernels/arm_gemm/gemm_fp32.cpp @@ -49,16 +49,16 @@ UniqueGemmCommon gemm(const CPUInfo &ci, const unsig return UniqueGemmCommon(new GemvPretransposed(&ci, N, K, trB, beta)); } - /* GemvNativeTransposed: requires M=1, no trA or trB, doesn't handle beta */ - if(M == 1 && beta == 1.0f && !trA && !trB) + /* GemvNativeTransposed: requires M=1, no trA or trB, doesn't handle alpha */ + if(M == 1 && alpha == 1.0f && !trA && !trB) { - return UniqueGemmCommon(new GemvNativeTransposed(&ci, N, K, alpha)); + return UniqueGemmCommon(new GemvNativeTransposed(&ci, N, K, beta)); } - /* Native GEMM: requires M to be a multiple of 4, K a multiple of 4, N a + /* Native GEMM: requires M to be a multiple of 4, K at least 4, N a * multiple of 16, doesn't handle alpha and only makes sense for small * sizes. */ - if(N <= 128 && K <= 128 && ((M % 4) == 0) && ((K % 4) == 0) && ((N % 16) == 0) && alpha == 1.0f) + if(N <= 128 && K <= 128 && ((M % 4) == 0) && (K >= 4) && ((N % 16) == 0) && alpha == 1.0f) { return UniqueGemmCommon(new GemmNative(&ci, M, N, K, beta)); } diff --git a/src/core/NEON/kernels/arm_gemm/gemv_native_transposed.hpp b/src/core/NEON/kernels/arm_gemm/gemv_native_transposed.hpp index c0b886266d..29c71f2511 100644 --- a/src/core/NEON/kernels/arm_gemm/gemv_native_transposed.hpp +++ b/src/core/NEON/kernels/arm_gemm/gemv_native_transposed.hpp @@ -49,7 +49,7 @@ class GemvNativeTransposed : public GemmCommon const unsigned int _Nsize; const unsigned int _Ksize; - const Tr _alpha; + const Tr _beta; const CPUInfo *const _ci; @@ -60,8 +60,8 @@ public: GemvNativeTransposed(GemvNativeTransposed &) = delete; GemvNativeTransposed &operator=(GemvNativeTransposed &) = delete; - GemvNativeTransposed(const CPUInfo *ci, const unsigned int N, const unsigned int K, const Tr alpha) - : _Nsize(N), _Ksize(K), _alpha(alpha), _ci(ci) + GemvNativeTransposed(const CPUInfo *ci, const unsigned int N, const unsigned int K, const Tr beta) + : _Nsize(N), _Ksize(K), _beta(beta), _ci(ci) { /* For now don't do any blocking. TODO: figure out if we should. 
*/ m_block = K; @@ -97,7 +97,7 @@ public: prof(PROFILE_KERNEL, ((mmax - m0) * (nmax - n0)), [&](void) { strat.kernel(this->_Bptr + (m0 * this->_ldb) + n0, this->_Aptr + m0, this->_Cptr + n0, - _alpha, this->_ldb, (mmax - m0), (nmax - n0)); + _beta, this->_ldb, (mmax - m0), (nmax - n0)); }); } } diff --git a/src/core/NEON/kernels/arm_gemm/kernels/a64_hgemm_24x8.hpp b/src/core/NEON/kernels/arm_gemm/kernels/a64_hgemm_24x8.hpp index 77ec59aa35..5fc0a7b707 100644 --- a/src/core/NEON/kernels/arm_gemm/kernels/a64_hgemm_24x8.hpp +++ b/src/core/NEON/kernels/arm_gemm/kernels/a64_hgemm_24x8.hpp @@ -23,7 +23,7 @@ */ #pragma once -#if defined(__aarch64__) && defined(__ARM_FEATURE_FP16_SCALAR_ARITHMETIC) +#if defined(__aarch64__) && (defined(FP16_KERNELS) || defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC)) #include "arm_gemm.hpp" @@ -71,4 +71,4 @@ public: } // namespace arm_gemm -#endif // __aarch64__ +#endif // __aarch64__ && (FP16_KERNELS || __ARM_FEATURE_FP16_VECTOR_ARITHMETIC) diff --git a/src/core/NEON/kernels/arm_gemm/kernels/a64_hgemm_24x8/a55r1.cpp b/src/core/NEON/kernels/arm_gemm/kernels/a64_hgemm_24x8/a55r1.cpp index d59618dd54..2186117536 100644 --- a/src/core/NEON/kernels/arm_gemm/kernels/a64_hgemm_24x8/a55r1.cpp +++ b/src/core/NEON/kernels/arm_gemm/kernels/a64_hgemm_24x8/a55r1.cpp @@ -21,7 +21,9 @@ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE * SOFTWARE. */ -#if defined(__aarch64__) && defined(__ARM_FEATURE_FP16_SCALAR_ARITHMETIC) + +// Build on AArch64 where either FP16_KERNELS is set or FP16 is explicitly supported. +#if defined(__aarch64__) && (defined(FP16_KERNELS) || defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC)) #include @@ -357,4 +359,4 @@ void a64_hgemm_asimd_24x8_a55r1(const __fp16 *Apanel, const __fp16 *Bpanel, __fp } // namespace arm_gemm -#endif // __aarch64__ && __ARM_FEATURE_FP16_SCALAR_ARITHMETIC +#endif // __aarch64__ && (FP16_KERNELS || __ARM_FEATURE_FP16_VECTOR_ARITHMETIC) diff --git a/src/core/NEON/kernels/arm_gemm/kernels/a64_hgemm_24x8/generic.cpp b/src/core/NEON/kernels/arm_gemm/kernels/a64_hgemm_24x8/generic.cpp index 468d603484..65a5d43d1d 100644 --- a/src/core/NEON/kernels/arm_gemm/kernels/a64_hgemm_24x8/generic.cpp +++ b/src/core/NEON/kernels/arm_gemm/kernels/a64_hgemm_24x8/generic.cpp @@ -21,7 +21,9 @@ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE * SOFTWARE. */ -#if defined(__aarch64__) && defined(__ARM_FEATURE_FP16_SCALAR_ARITHMETIC) + +// Build on AArch64 where either FP16_KERNELS is set or FP16 is explicitly supported. +#if defined(__aarch64__) && (defined(FP16_KERNELS) || defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC)) #include @@ -334,4 +336,4 @@ void a64_hgemm_asimd_24x8(const __fp16 *Apanel, const __fp16 *Bpanel, __fp16 *Cp } // namespace arm_gemm -#endif // __aarch64__ && __ARM_FEATURE_FP16_SCALAR_ARITHMETIC +#endif // __aarch64__ && (FP16_KERNELS || __ARM_FEATURE_FP16_VECTOR_ARITHMETIC) diff --git a/src/core/NEON/kernels/arm_gemm/kernels/a64_sgemm_native_16x4/generic.cpp b/src/core/NEON/kernels/arm_gemm/kernels/a64_sgemm_native_16x4/generic.cpp index 1b5787ce7c..8d4a38c36d 100644 --- a/src/core/NEON/kernels/arm_gemm/kernels/a64_sgemm_native_16x4/generic.cpp +++ b/src/core/NEON/kernels/arm_gemm/kernels/a64_sgemm_native_16x4/generic.cpp @@ -31,8 +31,9 @@ namespace arm_gemm { void a64_sgemm_native_16x4(const float *A, int lda, const float *B, int ldb, float *C, int ldc, float beta, int M, int N, int K) { - int oddk = (K % 8) ? 1 : 0; - int beta0 = (beta == 0.0f) ? 
1 : 0; + const int oddk = ((K % 8) >= 4) ? 1 : 0; + const int beta0 = (beta == 0.0f) ? 1 : 0; + const int oddones = (K % 4); /* For now, very naive with no blocking */ for(int y = 0; y < M; y += 4) @@ -52,6 +53,7 @@ void a64_sgemm_native_16x4(const float *A, int lda, const float *B, int ldb, flo float *c_ptr3 = c_ptr2 + ldc; int loops = ((K + 4) / 8) - 1; + int odds = oddones; size_t ldbb = ldb * sizeof(float); @@ -434,14 +436,17 @@ void a64_sgemm_native_16x4(const float *A, int lda, const float *B, int ldb, flo "ldr b0aq, [%[b_ptr]]\n" "fmla v17.4s, b1a.4s, a0.s[1]\n" + "add %[a_ptr0], %[a_ptr0], #32\n" "fmla v21.4s, b1a.4s, a1.s[1]\n" - "subs %w[loops], %w[loops], #1\n" + "add %[a_ptr1], %[a_ptr1], #32\n" "fmla v25.4s, b1a.4s, a2.s[1]\n" + "add %[a_ptr2], %[a_ptr2], #32\n" "fmla v29.4s, b1a.4s, a3.s[1]\n" "ldr b1aq, [%[b_ptr], #16]\n" "fmla v18.4s, b2a.4s, a0.s[1]\n" "fmla v22.4s, b2a.4s, a1.s[1]\n" + "add %[a_ptr3], %[a_ptr3], #32\n" "fmla v26.4s, b2a.4s, a2.s[1]\n" "fmla v30.4s, b2a.4s, a3.s[1]\n" "ldr b2aq, [%[b_ptr], #32]\n" @@ -488,7 +493,6 @@ void a64_sgemm_native_16x4(const float *A, int lda, const float *B, int ldb, flo "fmla v17.4s, b1a.4s, a0.s[3]\n" "fmla v21.4s, b1a.4s, a1.s[3]\n" - "ldr a3aq, [%[a_ptr3], #16]\n" "fmla v25.4s, b1a.4s, a2.s[3]\n" "fmla v29.4s, b1a.4s, a3.s[3]\n" "ldr b1aq, [%[b_ptr], #16]\n" @@ -560,6 +564,7 @@ void a64_sgemm_native_16x4(const float *A, int lda, const float *B, int ldb, flo // Unroll 6 "fmla v16.4s, bb0.4s, a0a.s[2]\n" "fmla v20.4s, bb0.4s, a1a.s[2]\n" + "add %[b_ptr], %[b_ptr], %[ldb]\n" "fmla v24.4s, bb0.4s, a2a.s[2]\n" "fmla v28.4s, bb0.4s, a3a.s[2]\n" @@ -583,6 +588,7 @@ void a64_sgemm_native_16x4(const float *A, int lda, const float *B, int ldb, flo "fmla v17.4s, b1a.4s, a0a.s[3]\n" "fmla v18.4s, b2a.4s, a0a.s[3]\n" "fmla v19.4s, b3a.4s, a0a.s[3]\n" + "cbnz %w[odds], 6f\n" "fmla v20.4s, b0a.4s, a1a.s[3]\n" "str q16, [%[c_ptr0]]\n" @@ -615,12 +621,16 @@ void a64_sgemm_native_16x4(const float *A, int lda, const float *B, int ldb, flo // Odd K case: Just do 4 more. 
"2:\n" "fmla v21.4s, bb1.4s, a1.s[0]\n" + "add %[a_ptr0], %[a_ptr0], #16\n" "fmla v25.4s, bb1.4s, a2.s[0]\n" + "add %[a_ptr1], %[a_ptr1], #16\n" "fmla v29.4s, bb1.4s, a3.s[0]\n" "ldr b1q, [%[b_ptr], #16]\n" "fmla v18.4s, bb2.4s, a0.s[0]\n" + "add %[a_ptr2], %[a_ptr2], #16\n" "fmla v22.4s, bb2.4s, a1.s[0]\n" + "add %[a_ptr3], %[a_ptr3], #16\n" "fmla v26.4s, bb2.4s, a2.s[0]\n" "fmla v30.4s, bb2.4s, a3.s[0]\n" "ldr b2q, [%[b_ptr], #32]\n" @@ -641,7 +651,6 @@ void a64_sgemm_native_16x4(const float *A, int lda, const float *B, int ldb, flo "fmla v17.4s, b1a.4s, a0.s[1]\n" "fmla v21.4s, b1a.4s, a1.s[1]\n" - "subs %w[loops], %w[loops], #1\n" "fmla v25.4s, b1a.4s, a2.s[1]\n" "fmla v29.4s, b1a.4s, a3.s[1]\n" "ldr b1aq, [%[b_ptr], #16]\n" @@ -660,6 +669,7 @@ void a64_sgemm_native_16x4(const float *A, int lda, const float *B, int ldb, flo // Unroll 2 "fmla v16.4s, bb0.4s, a0.s[2]\n" + "add %[b_ptr], %[b_ptr], %[ldb]\n" "fmla v20.4s, bb0.4s, a1.s[2]\n" "fmla v24.4s, bb0.4s, a2.s[2]\n" "fmla v28.4s, bb0.4s, a3.s[2]\n" @@ -684,6 +694,7 @@ void a64_sgemm_native_16x4(const float *A, int lda, const float *B, int ldb, flo "fmla v17.4s, b1a.4s, a0.s[3]\n" "fmla v18.4s, b2a.4s, a0.s[3]\n" "fmla v19.4s, b3a.4s, a0.s[3]\n" + "cbnz %w[odds], 7f\n" "fmla v20.4s, b0a.4s, a1.s[3]\n" "str q16, [%[c_ptr0]]\n" @@ -711,6 +722,130 @@ void a64_sgemm_native_16x4(const float *A, int lda, const float *B, int ldb, flo "str q26, [%[c_ptr2], #32]\n" "fmla v31.4s, b3a.4s, a3.s[3]\n" "str q27, [%[c_ptr2], #48]\n" + "b 3f\n" + + // "Odd ones" - lead in from even + "6:\n" + "fmla v20.4s, b0a.4s, a1a.s[3]\n" + "fmla v21.4s, b1a.4s, a1a.s[3]\n" + "ldr b0q, [%[b_ptr]]\n" + "fmla v22.4s, b2a.4s, a1a.s[3]\n" + "subs %w[odds], %w[odds], #1\n" + "fmla v23.4s, b3a.4s, a1a.s[3]\n" + "ldr b1q, [%[b_ptr], #16]\n" + + "fmla v24.4s, b0a.4s, a2a.s[3]\n" + "fmla v25.4s, b1a.4s, a2a.s[3]\n" + "ldr b2q, [%[b_ptr], #32]\n" + "fmla v26.4s, b2a.4s, a2a.s[3]\n" + "fmla v27.4s, b3a.4s, a2a.s[3]\n" + "ldr b3q, [%[b_ptr], #48]\n" + + "fmla v28.4s, b0a.4s, a3a.s[3]\n" + "ld1r {a0.4s}, [%[a_ptr0]], #4\n" + "fmla v29.4s, b1a.4s, a3a.s[3]\n" + "fmla v30.4s, b2a.4s, a3a.s[3]\n" + "ld1r {a1.4s}, [%[a_ptr1]], #4\n" + "fmla v31.4s, b3a.4s, a3a.s[3]\n" + + "fmla v16.4s, bb0.4s, a0.4s\n" + "beq 9f\n" + "b 8f\n" + + // "Odd ones" - lead in from odd + "7:\n" + "fmla v20.4s, b0a.4s, a1.s[3]\n" + "subs %w[odds], %w[odds], #1\n" + "fmla v21.4s, b1a.4s, a1.s[3]\n" + "ldr b0q, [%[b_ptr]]\n" + "fmla v22.4s, b2a.4s, a1.s[3]\n" + "fmla v23.4s, b3a.4s, a1.s[3]\n" + "ldr b1q, [%[b_ptr], #16]\n" + + "fmla v24.4s, b0a.4s, a2.s[3]\n" + "fmla v25.4s, b1a.4s, a2.s[3]\n" + "ldr b2q, [%[b_ptr], #32]\n" + "fmla v26.4s, b2a.4s, a2.s[3]\n" + "fmla v27.4s, b3a.4s, a2.s[3]\n" + "ldr b3q, [%[b_ptr], #48]\n" + + "fmla v28.4s, b0a.4s, a3.s[3]\n" + "ld1r {a0.4s}, [%[a_ptr0]], #4\n" + "fmla v29.4s, b1a.4s, a3.s[3]\n" + "fmla v30.4s, b2a.4s, a3.s[3]\n" + "ld1r {a1.4s}, [%[a_ptr1]], #4\n" + "fmla v31.4s, b3a.4s, a3.s[3]\n" + + "fmla v16.4s, bb0.4s, a0.4s\n" + "beq 9f\n" + + // "Odd ones" - loop + "8:\n" + "fmla v17.4s, bb1.4s, a0.4s\n" + "ld1r {a2.4s}, [%[a_ptr2]], #4\n" + "fmla v18.4s, bb2.4s, a0.4s\n" + "add %[b_ptr], %[b_ptr], %[ldb]\n" + "fmla v19.4s, bb3.4s, a0.4s\n" + "ld1r {a3.4s}, [%[a_ptr3]], #4\n" + + "fmla v20.4s, bb0.4s, a1.4s\n" + "subs %w[odds], %w[odds], #1\n" + "fmla v21.4s, bb1.4s, a1.4s\n" + "ld1r {a0.4s}, [%[a_ptr0]], #4\n" + "fmla v22.4s, bb2.4s, a1.4s\n" + "fmla v23.4s, bb3.4s, a1.4s\n" + "ld1r {a1.4s}, [%[a_ptr1]], #4\n" + + "fmla v24.4s, bb0.4s, a2.4s\n" + "fmla v28.4s, 
bb0.4s, a3.4s\n" + "ldr b0q, [%[b_ptr]]\n" + "fmla v25.4s, bb1.4s, a2.4s\n" + "fmla v29.4s, bb1.4s, a3.4s\n" + "ldr b1q, [%[b_ptr], #16]\n" + + "fmla v26.4s, bb2.4s, a2.4s\n" + "fmla v30.4s, bb2.4s, a3.4s\n" + "ldr b2q, [%[b_ptr], #32]\n" + "fmla v27.4s, bb3.4s, a2.4s\n" + "fmla v31.4s, bb3.4s, a3.4s\n" + "ldr b3q, [%[b_ptr], #48]\n" + "fmla v16.4s, bb0.4s, a0.4s\n" + "bne 8b\n" + + // "Odd ones" - detached final iteration + "9:\n" + "fmla v17.4s, bb1.4s, a0.4s\n" + "ld1r {a2.4s}, [%[a_ptr2]], #4\n" + "fmla v18.4s, bb2.4s, a0.4s\n" + "fmla v19.4s, bb3.4s, a0.4s\n" + "ld1r {a3.4s}, [%[a_ptr3]], #4\n" + + "fmla v20.4s, bb0.4s, a1.4s\n" + "str q16, [%[c_ptr0]]\n" + "fmla v21.4s, bb1.4s, a1.4s\n" + "str q17, [%[c_ptr0], #16]\n" + "fmla v22.4s, bb2.4s, a1.4s\n" + "str q18, [%[c_ptr0], #32]\n" + "fmla v23.4s, bb3.4s, a1.4s\n" + "str q19, [%[c_ptr0], #48]\n" + + "fmla v24.4s, bb0.4s, a2.4s\n" + "str q20, [%[c_ptr1]]\n" + "fmla v25.4s, bb1.4s, a2.4s\n" + "str q21, [%[c_ptr1], #16]\n" + "fmla v26.4s, bb2.4s, a2.4s\n" + "str q22, [%[c_ptr1], #32]\n" + "fmla v27.4s, bb3.4s, a2.4s\n" + "str q23, [%[c_ptr1], #48]\n" + + "fmla v28.4s, bb0.4s, a3.4s\n" + "str q24, [%[c_ptr2]]\n" + "fmla v29.4s, bb1.4s, a3.4s\n" + "str q25, [%[c_ptr2], #16]\n" + "fmla v30.4s, bb2.4s, a3.4s\n" + "str q26, [%[c_ptr2], #32]\n" + "fmla v31.4s, bb3.4s, a3.4s\n" + "str q27, [%[c_ptr2], #48]\n" "3:\n" "str q28, [%[c_ptr3]]\n" @@ -719,7 +854,7 @@ void a64_sgemm_native_16x4(const float *A, int lda, const float *B, int ldb, flo "str q31, [%[c_ptr3], #48]\n" : [a_ptr0] "+r"(a_ptr0), [a_ptr1] "+r"(a_ptr1), [a_ptr2] "+r"(a_ptr2), [a_ptr3] "+r"(a_ptr3), - [b_ptr] "+r"(b_ptr), [loops] "+r"(loops) + [b_ptr] "+r"(b_ptr), [loops] "+r"(loops), [odds] "+r"(odds) : [ldb] "r"(ldbb), [oddk] "r"(oddk), [beta0] "r"(beta0), [betaptr] "r"(&beta), [c_ptr0] "r"(c_ptr0), [c_ptr1] "r"(c_ptr1), [c_ptr2] "r"(c_ptr2), [c_ptr3] "r"(c_ptr3) : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", diff --git a/src/core/NEON/kernels/arm_gemm/kernels/a64_sgemv_trans/generic.cpp b/src/core/NEON/kernels/arm_gemm/kernels/a64_sgemv_trans/generic.cpp index 3309baff3a..8fa403bf02 100644 --- a/src/core/NEON/kernels/arm_gemm/kernels/a64_sgemv_trans/generic.cpp +++ b/src/core/NEON/kernels/arm_gemm/kernels/a64_sgemv_trans/generic.cpp @@ -50,12 +50,12 @@ namespace arm_gemm { -void a64_sgemv_trans(const float *Astart, const float *Xstart, float *Ystart, float alpha, int lda, int M, int N) +void a64_sgemv_trans(const float *Astart, const float *Xstart, float *Ystart, float beta, int lda, int M, int N) { const float *a_ptr_base = Astart; float *y_ptr = Ystart; - register const float32x4_t va asm("v1") = vdupq_n_f32(alpha); + register const float32x4_t vb asm("v1") = vdupq_n_f32(beta); int firstpfd = FIRST_PFD; if(firstpfd > M) @@ -344,6 +344,7 @@ void a64_sgemv_trans(const float *Astart, const float *Xstart, float *Ystart, fl "fmla v17.4s, v5.4s, v0.4s\n" "ldr q5, [%[a_ptr], #0xf0]\n" "fmla v18.4s, v6.4s, v0.4s\n" + "ldr q6, [%[a_ptr], #0x100]\n" "fmla v19.4s, v7.4s, v0.4s\n" "ldr q7, [%[a_ptr], #0x110]\n" @@ -372,75 +373,75 @@ void a64_sgemv_trans(const float *Astart, const float *Xstart, float *Ystart, fl "fmla v31.4s, v7.4s, v0.4s\n" "ldr q7, [%[y_ptr], #0x50]\n" - "fmla v2.4s, v8.4s, %[va].4s\n" - "ldr q8, [%[y_ptr], #0x60]\n" - "fmla v3.4s, v9.4s, %[va].4s\n" - "ldr q9, [%[y_ptr], #0x70]\n" - "fmla v4.4s, v10.4s, %[va].4s\n" - "ldr q10, [%[y_ptr], #0x80]\n" - "fmla v5.4s, v11.4s, %[va].4s\n" - "ldr q11, [%[y_ptr], 
#0x90]\n" - "fmla v6.4s, v12.4s, %[va].4s\n" - "ldr q12, [%[y_ptr], #0xa0]\n" - "str q2, [%[y_ptr], #0x00]\n" - "fmla v7.4s, v13.4s, %[va].4s\n" - "ldr q13, [%[y_ptr], #0xb0]\n" - "str q3, [%[y_ptr], #0x10]\n" - "fmla v8.4s, v14.4s, %[va].4s\n" - "ldr q14, [%[y_ptr], #0xc0]\n" - "str q4, [%[y_ptr], #0x20]\n" - "fmla v9.4s, v15.4s, %[va].4s\n" - "ldr q15, [%[y_ptr], #0xd0]\n" - "str q5, [%[y_ptr], #0x30]\n" - "fmla v10.4s, v16.4s, %[va].4s\n" - "ldr q16, [%[y_ptr], #0xe0]\n" - "str q6, [%[y_ptr], #0x40]\n" - "fmla v11.4s, v17.4s, %[va].4s\n" - "ldr q17, [%[y_ptr], #0xf0]\n" - "str q7, [%[y_ptr], #0x50]\n" - "fmla v12.4s, v18.4s, %[va].4s\n" - "ldr q18, [%[y_ptr], #0x100]\n" - "str q8, [%[y_ptr], #0x60]\n" - "fmla v13.4s, v19.4s, %[va].4s\n" - "ldr q19, [%[y_ptr], #0x110]\n" - "str q9, [%[y_ptr], #0x70]\n" - "fmla v14.4s, v20.4s, %[va].4s\n" - "ldr q20, [%[y_ptr], #0x120]\n" - "str q10, [%[y_ptr], #0x80]\n" - "fmla v15.4s, v21.4s, %[va].4s\n" - "ldr q21, [%[y_ptr], #0x130]\n" - "str q11, [%[y_ptr], #0x90]\n" - "fmla v16.4s, v22.4s, %[va].4s\n" - "ldr q22, [%[y_ptr], #0x140]\n" - "str q12, [%[y_ptr], #0xa0]\n" - "fmla v17.4s, v23.4s, %[va].4s\n" - "ldr q23, [%[y_ptr], #0x150]\n" - "str q13, [%[y_ptr], #0xb0]\n" - "fmla v18.4s, v24.4s, %[va].4s\n" - "ldr q24, [%[y_ptr], #0x160]\n" - "str q14, [%[y_ptr], #0xc0]\n" - "fmla v19.4s, v25.4s, %[va].4s\n" - "ldr q25, [%[y_ptr], #0x170]\n" - "str q15, [%[y_ptr], #0xd0]\n" - "fmla v20.4s, v26.4s, %[va].4s\n" - "str q16, [%[y_ptr], #0xe0]\n" - "fmla v21.4s, v27.4s, %[va].4s\n" - "str q17, [%[y_ptr], #0xf0]\n" - "fmla v22.4s, v28.4s, %[va].4s\n" - "str q18, [%[y_ptr], #0x100]\n" - "fmla v23.4s, v29.4s, %[va].4s\n" - "str q19, [%[y_ptr], #0x110]\n" - "fmla v24.4s, v30.4s, %[va].4s\n" - "str q20, [%[y_ptr], #0x120]\n" - "fmla v25.4s, v31.4s, %[va].4s\n" - "str q21, [%[y_ptr], #0x130]\n" - - "stp q22, q23, [%[y_ptr], #0x140]\n" - "stp q24, q25, [%[y_ptr], #0x160]\n" + "fmla v8.4s, v2.4s, %[vb].4s\n" + "ldr q2, [%[y_ptr], #0x60]\n" + "fmla v9.4s, v3.4s, %[vb].4s\n" + "ldr q3, [%[y_ptr], #0x70]\n" + "fmla v10.4s, v4.4s, %[vb].4s\n" + "ldr q4, [%[y_ptr], #0x80]\n" + "fmla v11.4s, v5.4s, %[vb].4s\n" + "ldr q5, [%[y_ptr], #0x90]\n" + "fmla v12.4s, v6.4s, %[vb].4s\n" + "ldr q6, [%[y_ptr], #0xa0]\n" + "str q8, [%[y_ptr], #0x00]\n" + "fmla v13.4s, v7.4s, %[vb].4s\n" + "ldr q7, [%[y_ptr], #0xb0]\n" + "str q9, [%[y_ptr], #0x10]\n" + "fmla v14.4s, v2.4s, %[vb].4s\n" + "ldr q2, [%[y_ptr], #0xc0]\n" + "str q10, [%[y_ptr], #0x20]\n" + "fmla v15.4s, v3.4s, %[vb].4s\n" + "ldr q3, [%[y_ptr], #0xd0]\n" + "str q11, [%[y_ptr], #0x30]\n" + "fmla v16.4s, v4.4s, %[vb].4s\n" + "ldr q4, [%[y_ptr], #0xe0]\n" + "str q12, [%[y_ptr], #0x40]\n" + "fmla v17.4s, v5.4s, %[vb].4s\n" + "ldr q5, [%[y_ptr], #0xf0]\n" + "str q13, [%[y_ptr], #0x50]\n" + "fmla v18.4s, v6.4s, %[vb].4s\n" + "ldr q6, [%[y_ptr], #0x100]\n" + "str q14, [%[y_ptr], #0x60]\n" + "fmla v19.4s, v7.4s, %[vb].4s\n" + "ldr q7, [%[y_ptr], #0x110]\n" + "str q15, [%[y_ptr], #0x70]\n" + "fmla v20.4s, v2.4s, %[vb].4s\n" + "ldr q2, [%[y_ptr], #0x120]\n" + "str q16, [%[y_ptr], #0x80]\n" + "fmla v21.4s, v3.4s, %[vb].4s\n" + "ldr q3, [%[y_ptr], #0x130]\n" + "str q17, [%[y_ptr], #0x90]\n" + "fmla v22.4s, v4.4s, %[vb].4s\n" + "ldr q4, [%[y_ptr], #0x140]\n" + "str q18, [%[y_ptr], #0xa0]\n" + "fmla v23.4s, v5.4s, %[vb].4s\n" + "ldr q5, [%[y_ptr], #0x150]\n" + "str q19, [%[y_ptr], #0xb0]\n" + "fmla v24.4s, v6.4s, %[vb].4s\n" + "ldr q6, [%[y_ptr], #0x160]\n" + "str q20, [%[y_ptr], #0xc0]\n" + "fmla v25.4s, v7.4s, %[vb].4s\n" + "ldr q7, 
[%[y_ptr], #0x170]\n" + "str q21, [%[y_ptr], #0xd0]\n" + "fmla v26.4s, v2.4s, %[vb].4s\n" + "str q22, [%[y_ptr], #0xe0]\n" + "fmla v27.4s, v3.4s, %[vb].4s\n" + "str q23, [%[y_ptr], #0xf0]\n" + "fmla v28.4s, v4.4s, %[vb].4s\n" + "str q24, [%[y_ptr], #0x100]\n" + "fmla v29.4s, v5.4s, %[vb].4s\n" + "str q25, [%[y_ptr], #0x110]\n" + "fmla v30.4s, v6.4s, %[vb].4s\n" + "str q26, [%[y_ptr], #0x120]\n" + "fmla v31.4s, v7.4s, %[vb].4s\n" + "str q27, [%[y_ptr], #0x130]\n" + + "stp q28, q29, [%[y_ptr], #0x140]\n" + "stp q30, q31, [%[y_ptr], #0x160]\n" "add %[y_ptr], %[y_ptr], #0x180\n" : [a_ptr] "+r"(a_ptr), [x_ptr] "+r"(x_ptr), [y_ptr] "+r"(y_ptr), [k] "+r"(k), [pf_ptr] "+r"(pf_ptr), [firstpf_ptr] "+r"(firstpf_ptr) - : [jump] "r"(jump), [va] "w"(va), [pf_limit] "r"(pf_limit) + : [jump] "r"(jump), [vb] "w"(vb), [pf_limit] "r"(pf_limit) : "w0", "v0", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "cc"); @@ -747,161 +748,161 @@ void a64_sgemv_trans(const float *Astart, const float *Xstart, float *Ystart, fl // Vector 0 "subs %w[vecs], %w[vecs], #1\n" "ldr q7, [%[y_ptr]]\n" - "fmla v7.4s, v8.4s, %[va].4s\n" - "str q7, [%[y_ptr]], #0x10\n" + "fmla v8.4s, v7.4s, %[vb].4s\n" + "str q8, [%[y_ptr]], #0x10\n" "beq 12f\n" // Vector 1 "subs %w[vecs], %w[vecs], #1\n" "ldr q7, [%[y_ptr]]\n" - "fmla v7.4s, v9.4s, %[va].4s\n" - "str q7, [%[y_ptr]], #0x10\n" + "fmla v9.4s, v7.4s, %[vb].4s\n" + "str q9, [%[y_ptr]], #0x10\n" "beq 12f\n" // Vector 2 "subs %w[vecs], %w[vecs], #1\n" "ldr q7, [%[y_ptr]]\n" - "fmla v7.4s, v10.4s, %[va].4s\n" - "str q7, [%[y_ptr]], #0x10\n" + "fmla v10.4s, v7.4s, %[vb].4s\n" + "str q10, [%[y_ptr]], #0x10\n" "beq 12f\n" // Vector 3 "subs %w[vecs], %w[vecs], #1\n" "ldr q7, [%[y_ptr]]\n" - "fmla v7.4s, v11.4s, %[va].4s\n" - "str q7, [%[y_ptr]], #0x10\n" + "fmla v11.4s, v7.4s, %[vb].4s\n" + "str q11, [%[y_ptr]], #0x10\n" "beq 12f\n" // Vector 4 "subs %w[vecs], %w[vecs], #1\n" "ldr q7, [%[y_ptr]]\n" - "fmla v7.4s, v12.4s, %[va].4s\n" - "str q7, [%[y_ptr]], #0x10\n" + "fmla v12.4s, v7.4s, %[vb].4s\n" + "str q12, [%[y_ptr]], #0x10\n" "beq 12f\n" // Vector 5 "subs %w[vecs], %w[vecs], #1\n" "ldr q7, [%[y_ptr]]\n" - "fmla v7.4s, v13.4s, %[va].4s\n" - "str q7, [%[y_ptr]], #0x10\n" + "fmla v13.4s, v7.4s, %[vb].4s\n" + "str q13, [%[y_ptr]], #0x10\n" "beq 12f\n" // Vector 6 "subs %w[vecs], %w[vecs], #1\n" "ldr q7, [%[y_ptr]]\n" - "fmla v7.4s, v14.4s, %[va].4s\n" - "str q7, [%[y_ptr]], #0x10\n" + "fmla v14.4s, v7.4s, %[vb].4s\n" + "str q14, [%[y_ptr]], #0x10\n" "beq 12f\n" // Vector 7 "subs %w[vecs], %w[vecs], #1\n" "ldr q7, [%[y_ptr]]\n" - "fmla v7.4s, v15.4s, %[va].4s\n" - "str q7, [%[y_ptr]], #0x10\n" + "fmla v15.4s, v7.4s, %[vb].4s\n" + "str q15, [%[y_ptr]], #0x10\n" "beq 12f\n" // Vector 8 "subs %w[vecs], %w[vecs], #1\n" "ldr q7, [%[y_ptr]]\n" - "fmla v7.4s, v16.4s, %[va].4s\n" - "str q7, [%[y_ptr]], #0x10\n" + "fmla v16.4s, v7.4s, %[vb].4s\n" + "str q16, [%[y_ptr]], #0x10\n" "beq 12f\n" // Vector 9 "subs %w[vecs], %w[vecs], #1\n" "ldr q7, [%[y_ptr]]\n" - "fmla v7.4s, v17.4s, %[va].4s\n" - "str q7, [%[y_ptr]], #0x10\n" + "fmla v17.4s, v7.4s, %[vb].4s\n" + "str q17, [%[y_ptr]], #0x10\n" "beq 12f\n" // Vector 10 "subs %w[vecs], %w[vecs], #1\n" "ldr q7, [%[y_ptr]]\n" - "fmla v7.4s, v18.4s, %[va].4s\n" - "str q7, [%[y_ptr]], #0x10\n" + "fmla v18.4s, v7.4s, %[vb].4s\n" + "str q18, [%[y_ptr]], #0x10\n" "beq 12f\n" // Vector 11 "subs %w[vecs], %w[vecs], 
#1\n" "ldr q7, [%[y_ptr]]\n" - "fmla v7.4s, v19.4s, %[va].4s\n" - "str q7, [%[y_ptr]], #0x10\n" + "fmla v19.4s, v7.4s, %[vb].4s\n" + "str q19, [%[y_ptr]], #0x10\n" "beq 12f\n" // Vector 12 "subs %w[vecs], %w[vecs], #1\n" "ldr q7, [%[y_ptr]]\n" - "fmla v7.4s, v20.4s, %[va].4s\n" - "str q7, [%[y_ptr]], #0x10\n" + "fmla v20.4s, v7.4s, %[vb].4s\n" + "str q20, [%[y_ptr]], #0x10\n" "beq 12f\n" // Vector 13 "subs %w[vecs], %w[vecs], #1\n" "ldr q7, [%[y_ptr]]\n" - "fmla v7.4s, v21.4s, %[va].4s\n" - "str q7, [%[y_ptr]], #0x10\n" + "fmla v21.4s, v7.4s, %[vb].4s\n" + "str q21, [%[y_ptr]], #0x10\n" "beq 12f\n" // Vector 14 "subs %w[vecs], %w[vecs], #1\n" "ldr q7, [%[y_ptr]]\n" - "fmla v7.4s, v22.4s, %[va].4s\n" - "str q7, [%[y_ptr]], #0x10\n" + "fmla v22.4s, v7.4s, %[vb].4s\n" + "str q22, [%[y_ptr]], #0x10\n" "beq 12f\n" // Vector 15 "subs %w[vecs], %w[vecs], #1\n" "ldr q7, [%[y_ptr]]\n" - "fmla v7.4s, v23.4s, %[va].4s\n" - "str q7, [%[y_ptr]], #0x10\n" + "fmla v23.4s, v7.4s, %[vb].4s\n" + "str q23, [%[y_ptr]], #0x10\n" "beq 12f\n" // Vector 16 "subs %w[vecs], %w[vecs], #1\n" "ldr q7, [%[y_ptr]]\n" - "fmla v7.4s, v24.4s, %[va].4s\n" - "str q7, [%[y_ptr]], #0x10\n" + "fmla v24.4s, v7.4s, %[vb].4s\n" + "str q24, [%[y_ptr]], #0x10\n" "beq 12f\n" // Vector 17 "subs %w[vecs], %w[vecs], #1\n" "ldr q7, [%[y_ptr]]\n" - "fmla v7.4s, v25.4s, %[va].4s\n" - "str q7, [%[y_ptr]], #0x10\n" + "fmla v25.4s, v7.4s, %[vb].4s\n" + "str q25, [%[y_ptr]], #0x10\n" "beq 12f\n" // Vector 18 "subs %w[vecs], %w[vecs], #1\n" "ldr q7, [%[y_ptr]]\n" - "fmla v7.4s, v26.4s, %[va].4s\n" - "str q7, [%[y_ptr]], #0x10\n" + "fmla v26.4s, v7.4s, %[vb].4s\n" + "str q26, [%[y_ptr]], #0x10\n" "beq 12f\n" // Vector 19 "subs %w[vecs], %w[vecs], #1\n" "ldr q7, [%[y_ptr]]\n" - "fmla v7.4s, v27.4s, %[va].4s\n" - "str q7, [%[y_ptr]], #0x10\n" + "fmla v27.4s, v7.4s, %[vb].4s\n" + "str q27, [%[y_ptr]], #0x10\n" "beq 12f\n" // Vector 20 "subs %w[vecs], %w[vecs], #1\n" "ldr q7, [%[y_ptr]]\n" - "fmla v7.4s, v28.4s, %[va].4s\n" - "str q7, [%[y_ptr]], #0x10\n" + "fmla v28.4s, v7.4s, %[vb].4s\n" + "str q28, [%[y_ptr]], #0x10\n" "beq 12f\n" // Vector 21 "subs %w[vecs], %w[vecs], #1\n" "ldr q7, [%[y_ptr]]\n" - "fmla v7.4s, v29.4s, %[va].4s\n" - "str q7, [%[y_ptr]], #0x10\n" + "fmla v29.4s, v7.4s, %[vb].4s\n" + "str q29, [%[y_ptr]], #0x10\n" "beq 12f\n" // Vector 22 "subs %w[vecs], %w[vecs], #1\n" "ldr q7, [%[y_ptr]]\n" - "fmla v7.4s, v30.4s, %[va].4s\n" - "str q7, [%[y_ptr]], #0x10\n" + "fmla v30.4s, v7.4s, %[vb].4s\n" + "str q30, [%[y_ptr]], #0x10\n" // Odd 2 "12:\n" "cbz %[odd2_aptr], 13f\n" "ldr d7, [%[y_ptr]]\n" - "fmla v7.2s, v6.2s, %[va].2s\n" - "str d7, [%[y_ptr]], #0x8\n" + "fmla v6.2s, v7.2s, %[vb].2s\n" + "str d6, [%[y_ptr]], #0x8\n" // Odd 1 "13:\n" "cbz %[odd1_aptr], 14f\n" "ldr s7, [%[y_ptr]]\n" - "fmla v7.2s, v5.2s, %[va].2s\n" - "str s7, [%[y_ptr]]\n" + "fmla v5.2s, v7.2s, %[vb].2s\n" + "str s5, [%[y_ptr]]\n" "14:\n" : [a_ptr] "+r"(a_ptr), [x_ptr] "+r"(x_ptr), [y_ptr] "+r"(y_ptr), [k] "+r"(k), [pf_ptr] "+r"(pf_ptr), [firstpf_ptr] "+r"(firstpf_ptr), [odd1_aptr] "+r"(odd1_aptr), [odd2_aptr] "+r"(odd2_aptr), [dopf] "+r"(dopf), [vecs] "+r"(vecs) - : [jump] "r"(jump), [va] "w"(va), [pf_limit] "r"(pf_limit), [numvecs] "r"(numvecs) + : [jump] "r"(jump), [vb] "w"(vb), [pf_limit] "r"(pf_limit), [numvecs] "r"(numvecs) : "w0", "v0", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "cc"); diff --git 
a/src/core/NEON/kernels/arm_gemm/merges/a64_merge_float_to_half_12x8.hpp b/src/core/NEON/kernels/arm_gemm/merges/a64_merge_float_to_half_12x8.hpp index 12a090112d..9708fe189d 100644 --- a/src/core/NEON/kernels/arm_gemm/merges/a64_merge_float_to_half_12x8.hpp +++ b/src/core/NEON/kernels/arm_gemm/merges/a64_merge_float_to_half_12x8.hpp @@ -23,7 +23,8 @@ */ #pragma once -#if defined(__aarch64__) && defined(__ARM_FEATURE_FP16_SCALAR_ARITHMETIC) +// This should be possible on any AArch64 target, but some old compilers don't support __fp16 arguments. +#if defined(__aarch64__) && defined(__ARM_FP16_ARGS) #include @@ -270,4 +271,4 @@ inline void MergeResults<12, 8>(__fp16 *out, const float *in, int ldout, int y0, } } -#endif // __aarch64__ && __ARM_FEATURE_FP16_SCALAR_ARITHMETIC +#endif // __aarch64__ && __ARM_FP16_ARGS diff --git a/src/core/NEON/kernels/arm_gemm/merges/list.hpp b/src/core/NEON/kernels/arm_gemm/merges/list.hpp index 7d56e58f44..d93f1b0e6e 100644 --- a/src/core/NEON/kernels/arm_gemm/merges/list.hpp +++ b/src/core/NEON/kernels/arm_gemm/merges/list.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018 ARM Limited. + * Copyright (c) 2017-2018 ARM Limited. * * SPDX-License-Identifier: MIT * @@ -21,7 +21,6 @@ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE * SOFTWARE. */ - #include "a32_merge_float_8x6.hpp" #include "a64_merge_float_12x8.hpp" #include "a64_merge_float_to_half_12x8.hpp" diff --git a/src/core/NEON/kernels/arm_gemm/transforms/a32_interleave_6way_32bit.hpp b/src/core/NEON/kernels/arm_gemm/transforms/a32_interleave_6way_32bit.hpp index f09e5a0e78..501d6bf075 100644 --- a/src/core/NEON/kernels/arm_gemm/transforms/a32_interleave_6way_32bit.hpp +++ b/src/core/NEON/kernels/arm_gemm/transforms/a32_interleave_6way_32bit.hpp @@ -73,63 +73,65 @@ inline void TransformImpl<6, 1, false, 4, 4>::Transform(T *out, const T *in, int inptr4 = zerobuff; case 0: inptr5 = zerobuff; - default: break; + + default: + UNREACHABLE("Impossible."); } } __asm __volatile( // Load up 8 elements (2 vectors) from each of 8 sources. 
- "VLD1.32 {d0-d3}, [%[inptr0]]!\n" // q0=A0A1A2A3 - "VLD1.32 {d4-d7}, [%[inptr1]]!\n" // q2=B0B1B2B3 - "VLD1.32 {d8-d11}, [%[inptr2]]!\n" // q4=C0C1C2C3 - "VZIP.32 q0, q4\n" // q0=A0C0A1C1, q4 = A2C2A3C3 - "VLD1.32 {d12-d15}, [%[inptr3]]!\n" // q6=D0D1D2D3 - "VZIP.32 q2, q6\n" // q2=B0D0B1D1, q6 = B2D2B3D3 - "VLD1.32 {d16-d19}, [%[inptr4]]!\n" - "VLD1.32 {d20-d23}, [%[inptr5]]!\n" + "VLD1.32 {d0-d3}, [%[inptr0]]!\n" // q0=A0A1A2A3 + "VLD1.32 {d4-d7}, [%[inptr1]]!\n" // q2=B0B1B2B3 + "VLD1.32 {d8-d11}, [%[inptr2]]!\n" // q4=C0C1C2C3 + "VZIP.32 q0, q4\n" // q0=A0C0A1C1, q4 = A2C2A3C3 + "VLD1.32 {d12-d15}, [%[inptr3]]!\n" // q6=D0D1D2D3 + "VZIP.32 q2, q6\n" // q2=B0D0B1D1, q6 = B2D2B3D3 + "VLD1.32 {d16-d19}, [%[inptr4]]!\n" + "VLD1.32 {d20-d23}, [%[inptr5]]!\n" "VZIP.32 q8, q10\n" // q8=E0F0E1F1, q10 = E2F2E3F3 ASM_PREFETCH("[%[inptr0], #128]") "VZIP.32 q0, q2\n" // q0 = A0B0C0D0, q2 = A1B1C1D1 // Store first elements - "VST1.32 {d0-d1}, [%[outptr]]!\n" - "VST1.32 {d16}, [%[outptr]]!\n" + "VST1.32 {d0-d1}, [%[outptr]]!\n" + "VST1.32 {d16}, [%[outptr]]!\n" "VZIP.32 q4, q6\n" // q4 = A2B2C2D2, q6 = A3B3C3D3 // Store second elements - "VST1.32 {d4-d5}, [%[outptr]]!\n" + "VST1.32 {d4-d5}, [%[outptr]]!\n" "VZIP.32 q1, q5\n" ASM_PREFETCH("[%[inptr1], #128]") - "VST1.32 {d17}, [%[outptr]]!\n" + "VST1.32 {d17}, [%[outptr]]!\n" "VZIP.32 q3, q7\n" // Store third elements "VZIP.32 q9, q11\n" - "VST1.32 {d8-d9}, [%[outptr]]!\n" + "VST1.32 {d8-d9}, [%[outptr]]!\n" "VZIP.32 q1, q3\n" ASM_PREFETCH("[%[inptr2], #128]") - "VST1.32 {d20}, [%[outptr]]!\n" + "VST1.32 {d20}, [%[outptr]]!\n" // Store fourth elements "VZIP.32 q5, q7\n" - "VST1.32 {d12-d13}, [%[outptr]]!\n" ASM_PREFETCH("[%[inptr3], #128]") - "VST1.32 {d21}, [%[outptr]]!\n" + "VST1.32 {d12-d13}, [%[outptr]]!\n" ASM_PREFETCH("[%[inptr3], #128]") + "VST1.32 {d21}, [%[outptr]]!\n" // Fifth - "VST1.32 {d2-d3}, [%[outptr]]!\n" ASM_PREFETCH("[%[inptr4], #128]") - "VST1.32 {d18}, [%[outptr]]!\n" + "VST1.32 {d2-d3}, [%[outptr]]!\n" ASM_PREFETCH("[%[inptr4], #128]") + "VST1.32 {d18}, [%[outptr]]!\n" // Sixth - "VST1.32 {d6-d7}, [%[outptr]]!\n" ASM_PREFETCH("[%[inptr5], #128]") - "VST1.32 {d19}, [%[outptr]]!\n" + "VST1.32 {d6-d7}, [%[outptr]]!\n" ASM_PREFETCH("[%[inptr5], #128]") + "VST1.32 {d19}, [%[outptr]]!\n" // Seventh - "VST1.32 {d10-d11}, [%[outptr]]!\n" - "VST1.32 {d22}, [%[outptr]]!\n" + "VST1.32 {d10-d11}, [%[outptr]]!\n" + "VST1.32 {d22}, [%[outptr]]!\n" // Eighth - "VST1.32 {d14-d15}, [%[outptr]]!\n" - "VST1.32 {d23}, [%[outptr]]!\n" + "VST1.32 {d14-d15}, [%[outptr]]!\n" + "VST1.32 {d23}, [%[outptr]]!\n" : [inptr0] "+r"(inptr0), [inptr1] "+r"(inptr1), [inptr2] "+r"(inptr2), [inptr3] "+r"(inptr3), [inptr4] "+r"(inptr4), [inptr5] "+r"(inptr5), [outptr] "+r"(outptr) diff --git a/src/core/NEON/kernels/arm_gemm/transforms/a64_interleave_8way_half_to_float.hpp b/src/core/NEON/kernels/arm_gemm/transforms/a64_interleave_8way_half_to_float.hpp index 85ffdc2d4f..1d2d4969f6 100644 --- a/src/core/NEON/kernels/arm_gemm/transforms/a64_interleave_8way_half_to_float.hpp +++ b/src/core/NEON/kernels/arm_gemm/transforms/a64_interleave_8way_half_to_float.hpp @@ -23,7 +23,7 @@ */ #pragma once -#if defined(__aarch64__) && defined(__ARM_FEATURE_FP16_SCALAR_ARITHMETIC) +#if defined(__aarch64__) && defined(__ARM_FP16_ARGS) #include @@ -189,4 +189,4 @@ inline void TransformImpl<8, 1, false, 4, 2>::Transform(float *out, const __fp16 } } -#endif // __aarch64__ +#endif // __aarch64__ && __ARM_FP16_ARGS diff --git 
a/src/core/NEON/kernels/arm_gemm/transforms/a64_transpose_interleave_12way_half_to_float.hpp b/src/core/NEON/kernels/arm_gemm/transforms/a64_transpose_interleave_12way_half_to_float.hpp index ff1cbfb5f5..b79f32fb8b 100644 --- a/src/core/NEON/kernels/arm_gemm/transforms/a64_transpose_interleave_12way_half_to_float.hpp +++ b/src/core/NEON/kernels/arm_gemm/transforms/a64_transpose_interleave_12way_half_to_float.hpp @@ -23,7 +23,7 @@ */ #pragma once -#if defined(__aarch64__) && defined(__ARM_FEATURE_FP16_SCALAR_ARITHMETIC) +#if defined(__aarch64__) && defined(__ARM_FP16_ARGS) #include "transpose_interleave_common.hpp" @@ -110,4 +110,4 @@ inline void TransformImpl<12, 1, true, 4, 2>::Transform( TransposeInterleaveCommon<12, __fp16, float>::Transform(out, in, stride, x0, xmax, k0, kmax); } -#endif // __aarch64__ +#endif // __aarch64__ && __ARM_FP16_ARGS -- cgit v1.2.1
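For readers skimming the patch, the most consequential behavioural change is how gemm_fp16.cpp now decides whether the native FP16 kernel can be used: the choice is made at compile time when __ARM_FEATURE_FP16_VECTOR_ARITHMETIC is defined, and falls back to a runtime CPUInfo check when only FP16_KERNELS is set. The standalone C++ sketch below restates that selection logic outside the patch; CPUInfoStub and use_native_fp16 are illustrative names invented here, and the stubbed has_fp16() merely stands in for the library's real runtime CPU detection (the real code also requires __aarch64__ and returns a GemmInterleaved object rather than a bool).

// Illustrative sketch only, not part of the patch: the FP16 kernel-selection
// pattern introduced in gemm_fp16.cpp, shown as a self-contained program.
#include <cstdio>

struct CPUInfoStub
{
    // Placeholder for the library's CPUInfo::has_fp16() runtime detection.
    bool has_fp16() const { return true; }
};

static bool use_native_fp16(const CPUInfoStub &ci)
{
#if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC)
    // Compiler enables FP16 vector arithmetic unconditionally, so assume the
    // hardware supports it at runtime too.
    (void)ci;
    return true;
#elif defined(FP16_KERNELS)
    // FP16 kernels are built, but availability must be confirmed at runtime.
    return ci.has_fp16();
#else
    // FP16 kernels are not built at all; the blocked SGEMM path is used instead.
    (void)ci;
    return false;
#endif
}

int main()
{
    CPUInfoStub ci;
    std::printf("native FP16 kernel selected: %s\n", use_native_fp16(ci) ? "yes" : "no");
    return 0;
}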