diff options
Diffstat (limited to 'src/core/NEON/kernels/arm_gemm/gemm_fp32.cpp')
-rw-r--r-- | src/core/NEON/kernels/arm_gemm/gemm_fp32.cpp | 60 |
1 files changed, 12 insertions, 48 deletions
diff --git a/src/core/NEON/kernels/arm_gemm/gemm_fp32.cpp b/src/core/NEON/kernels/arm_gemm/gemm_fp32.cpp index 1d5b97b41a..aa206e3f42 100644 --- a/src/core/NEON/kernels/arm_gemm/gemm_fp32.cpp +++ b/src/core/NEON/kernels/arm_gemm/gemm_fp32.cpp @@ -26,28 +26,22 @@ #include "gemm_hybrid.hpp" #include "gemm_implementation.hpp" #include "gemm_interleaved.hpp" -#include "gemm_interleaved_2d.hpp" #include "gemm_interleaved_pretransposed_2d.hpp" -#include "gemm_native.hpp" #include "gemv_batched.hpp" -#include "gemv_native_transposed.hpp" #include "gemv_pretransposed.hpp" #include "kernels/a32_sgemm_8x6.hpp" #include "kernels/a64_hybrid_fp32_mla_16x4.hpp" #include "kernels/a64_hybrid_fp32_mla_4x8.hpp" -#include "kernels/a64_native_fp32_mla_16x4.hpp" #include "kernels/a64_smallK_hybrid_fp32_mla_4x6.hpp" #include "kernels/a64_smallK_hybrid_fp32_mla_4x8.hpp" #include "kernels/a64_sgemm_12x8.hpp" #include "kernels/a64_sgemv_pretransposed.hpp" -#include "kernels/a64_sgemv_trans.hpp" #include "kernels/sve_hybrid_fp32_mla_4VLx4.hpp" #include "kernels/sve_hybrid_fp32_mmla_4VLx4.hpp" #include "kernels/sve_interleaved_fp32_mla_3VLx8.hpp" #include "kernels/sve_interleaved_fp32_mmla_3VLx8.hpp" -#include "kernels/sve_native_fp32_mla_4VLx4.hpp" #include "kernels/sve_smallK_hybrid_fp32_mla_1VLx8.hpp" namespace arm_gemm { @@ -65,23 +59,15 @@ static const GemmImplementation<float, float> gemm_fp32_methods[] = { GemmMethod::GEMV_PRETRANSPOSED, "sgemv_pretransposed", - [](const GemmArgs &args) { return (args._Msize==1 && args._pretransposed_hint && args._nbatches==1); }, + [](const GemmArgs &args) { return (args._Msize==1 && args._nbatches==1); }, nullptr, [](const GemmArgs &args) { return new GemvPretransposed<sgemv_pretransposed, float, float>(args); } }, -{ - GemmMethod::GEMV_NATIVE_TRANSPOSED, - "sgemv_trans", - [](const GemmArgs &args) { return (args._Msize==1 && !args._trA && !args._trB && args._nbatches==1); }, - nullptr, - [](const GemmArgs &args) { return new GemvNativeTransposed<sgemv_trans, float, float>(args); } -}, - #if defined(__ARM_FEATURE_SVE) && defined(MMLA_FP32) { GemmMethod::GEMM_HYBRID, "hybrid_fp32_mmla_4VLx4", - [](const GemmArgs &args) { return (args._Ksize >= 4) && !args._trA && args._pretransposed_hint; }, + [](const GemmArgs &args) { return (args._Ksize >= 4); }, [](const GemmArgs &args) { return ((args._Ksize <= 256) && (args._Nsize <= 256)) || ((args._nmulti > 1) && ((args._Msize / args._maxthreads) < 8)); }, [](const GemmArgs &args) { return new GemmHybrid<hybrid_fp32_mmla_4VLx4, float, float>(args); } }, @@ -95,66 +81,52 @@ static const GemmImplementation<float, float> gemm_fp32_methods[] = #endif // __ARM_FEATURE_SVE && MMLA_FP32 #ifdef __ARM_FEATURE_SVE -// SVE smallk / native / hybrid methods +// SVE smallk / hybrid methods { GemmMethod::GEMM_HYBRID, "smallK_hybrid_fp32_mla_1VLx8", - [](const GemmArgs &args) { return (args._Ksize <= 24) && !args._trA && args._pretransposed_hint; }, + [](const GemmArgs &args) { return (args._Ksize <= 24); }, nullptr, [](const GemmArgs &args) { return new GemmHybrid<smallK_hybrid_fp32_mla_1VLx8, float, float>(args); } }, { GemmMethod::GEMM_HYBRID, "hybrid_fp32_mla_4VLx4", - [](const GemmArgs &args) { return (args._Ksize >= 4) && !args._trA && args._pretransposed_hint; }, + [](const GemmArgs &args) { return (args._Ksize >= 4); }, [](const GemmArgs &args) { return ((args._Ksize <= 256) && (args._Nsize <= 256)) || ((args._nmulti > 1) && ((args._Msize / args._maxthreads) < 8)); }, [](const GemmArgs &args) { return new GemmHybrid<hybrid_fp32_mla_4VLx4, float, float>(args); } }, -{ - GemmMethod::GEMM_NATIVE, - "native_fp32_mla_4VLx4", - [](const GemmArgs &args) { return (args._Ksize>4 && !args._trA && !args._trB); }, - [](const GemmArgs &args) { return ((args._Ksize <= 128) && (args._Nsize <= 128)) || ((args._nmulti > 1) && ((args._Msize / args._maxthreads) < 8)); }, - [](const GemmArgs &args) { return new GemmNative<native_fp32_mla_4VLx4, float, float>(args); } -}, #endif // __ARM_FEATURE_SVE -// NEON native / hybrid methods +// NEON hybrid methods { GemmMethod::GEMM_HYBRID, "smallK_hybrid_fp32_mla_4x8", - [](const GemmArgs &args) { return (args._Ksize <= 8) && (args._Nsize % 4)==0 && !args._trA && args._pretransposed_hint; }, + [](const GemmArgs &args) { return (args._Ksize <= 8) && (args._Nsize % 4)==0; }, nullptr, [](const GemmArgs &args) { return new GemmHybrid<smallK_hybrid_fp32_mla_4x8, float, float>(args); } }, { GemmMethod::GEMM_HYBRID, "smallK_hybrid_fp32_mla_4x6", - [](const GemmArgs &args) { return (args._Ksize > 8) && (args._Ksize <= 16) && (args._Nsize % 4)==0 && !args._trA && args._pretransposed_hint; }, + [](const GemmArgs &args) { return (args._Ksize > 8) && (args._Ksize <= 16) && (args._Nsize % 4)==0; }, nullptr, [](const GemmArgs &args) { return new GemmHybrid<smallK_hybrid_fp32_mla_4x6, float, float>(args); } }, { GemmMethod::GEMM_HYBRID, "hybrid_fp32_mla_4x8_normal", - [](const GemmArgs &args) { return (args._Ksize >= 4) && !args._trA && args._pretransposed_hint; }, + [](const GemmArgs &args) { return (args._Ksize >= 4); }, [](const GemmArgs &args) { return (args._Nsize < 12); }, [](const GemmArgs &args) { return new GemmHybrid<hybrid_fp32_mla_4x8, float, float>(args); } }, { GemmMethod::GEMM_HYBRID, "hybrid_fp32_mla_16x4", - [](const GemmArgs &args) { return (args._Ksize >= 4) && !args._trA && args._pretransposed_hint; }, + [](const GemmArgs &args) { return (args._Ksize >= 4); }, [](const GemmArgs &args) { return ((args._Ksize <= 256) && (args._Nsize <= 256)) || (args._Msize < 16) || (args._nmulti > 1); }, [](const GemmArgs &args) { return new GemmHybrid<hybrid_fp32_mla_16x4, float, float>(args); } }, -{ - GemmMethod::GEMM_NATIVE, - "native_fp32_mla_16x4", - [](const GemmArgs &args) { return (args._Ksize>4 && (args._Nsize % 16)==0 && !args._trA && !args._trB); }, - [](const GemmArgs &args) { return ((args._Ksize <= 128) && (args._Nsize <= 128)) || ((args._nmulti > 1) && ((args._Msize / args._maxthreads) < 8)); }, - [](const GemmArgs &args) { return new GemmNative<native_fp32_mla_16x4, float, float>(args); } -}, #ifdef __ARM_FEATURE_SVE { @@ -168,18 +140,10 @@ static const GemmImplementation<float, float> gemm_fp32_methods[] = // Pretranposed, 2D split { GemmMethod::GEMM_INTERLEAVED_2D, - "sgemm_12x8_pretranspose_2d", - [](const GemmArgs &args) { return args._pretransposed_hint; }, - [](const GemmArgs &args) { return args._maxthreads >= 8; }, - [](const GemmArgs &args) { return new GemmInterleavedPretransposed2d<sgemm_12x8, float, float>(args); } -}, -// Non-pretransposed, 2D split (no buffer manager) -{ - GemmMethod::GEMM_INTERLEAVED_2D, "sgemm_12x8_2d", nullptr, - [](const GemmArgs &args) { return (!args._pretransposed_hint) && (args._maxthreads >= 8); }, - [](const GemmArgs &args) { return new GemmInterleaved2d<sgemm_12x8, float, float>(args); } + [](const GemmArgs &args) { return args._maxthreads >= 8; }, + [](const GemmArgs &args) { return new GemmInterleavedPretransposed2d<sgemm_12x8, float, float>(args); } }, // 1D split (with pretransposed or not) { |