diff options
Diffstat (limited to 'src/core/NEON/kernels/arm_gemm/gemm_fp32.cpp')
-rw-r--r-- | src/core/NEON/kernels/arm_gemm/gemm_fp32.cpp | 14 |
1 files changed, 9 insertions, 5 deletions
diff --git a/src/core/NEON/kernels/arm_gemm/gemm_fp32.cpp b/src/core/NEON/kernels/arm_gemm/gemm_fp32.cpp index e840e90eec..7d14971b70 100644 --- a/src/core/NEON/kernels/arm_gemm/gemm_fp32.cpp +++ b/src/core/NEON/kernels/arm_gemm/gemm_fp32.cpp @@ -36,10 +36,12 @@ #include "kernels/a64_sgemv_pretransposed.hpp" #include "kernels/a64_sgemm_native_16x4.hpp" +#include "kernels/sve_interleaved_fp32_mla_3VLx8.hpp" + namespace arm_gemm { -#ifdef __aarch64__ -// SGEMM implementations for AArch64 +#if defined(__aarch64__) && !defined(__ARM_FEATURE_SVE) +// SGEMM implementations for AArch64 without SVE // Pretransposed GEMV class GemmImpl_sgemm_gemv_pretransposed : public GemmImplementation<float, float> { @@ -92,7 +94,9 @@ public: class GemmImpl_sgemm_gemm_interleaved : public GemmImplementation<float, float> { public: UniqueGemmCommon<float, float> instantiate(const GemmArgs<float> &args) override { -#ifdef __aarch64__ +#ifdef __ARM_FEATURE_SVE + return UniqueGemmCommon<float, float> (new GemmInterleaved<interleaved_fp32_mla_3VLx8, float, float>(args)); +#elif defined(__aarch64__) return UniqueGemmCommon<float, float> (new GemmInterleaved<sgemm_12x8, float, float>(args)); #elif defined(__arm__) return UniqueGemmCommon<float, float> (new GemmInterleaved<sgemm_8x6, float, float>(args)); @@ -105,7 +109,7 @@ public: }; static GemmImpl_gemv_batched<float, float> gemv_batched_impl{}; -#ifdef __aarch64__ +#if defined(__aarch64__) && !defined(__ARM_FEATURE_SVE) static GemmImpl_sgemm_gemv_pretransposed sgemm_gemv_pretransposed_impl{}; static GemmImpl_sgemm_gemv_native_transposed sgemm_gemv_native_transposed_impl{}; static GemmImpl_sgemm_gemm_native sgemm_gemm_native_impl{}; @@ -115,7 +119,7 @@ static GemmImpl_sgemm_gemm_interleaved sgemm_gemm_interleaved_impl{}; /* List of implementations (order matters) */ static std::vector<GemmImplementation<float, float> *> SGemmMethods = { &gemv_batched_impl, -#ifdef __aarch64__ +#if defined(__aarch64__) && !defined(__ARM_FEATURE_SVE) &sgemm_gemv_pretransposed_impl, &sgemm_gemv_native_transposed_impl, &sgemm_gemm_native_impl, |