diff options
Diffstat (limited to 'src/core/NEON/kernels/arm_gemm/gemm_fp32.cpp')
-rw-r--r-- | src/core/NEON/kernels/arm_gemm/gemm_fp32.cpp | 18 |
1 file changed, 9 insertions, 9 deletions
diff --git a/src/core/NEON/kernels/arm_gemm/gemm_fp32.cpp b/src/core/NEON/kernels/arm_gemm/gemm_fp32.cpp index aa206e3f42..ddb438f06c 100644 --- a/src/core/NEON/kernels/arm_gemm/gemm_fp32.cpp +++ b/src/core/NEON/kernels/arm_gemm/gemm_fp32.cpp @@ -120,13 +120,13 @@ static const GemmImplementation<float, float> gemm_fp32_methods[] = [](const GemmArgs &args) { return (args._Nsize < 12); }, [](const GemmArgs &args) { return new GemmHybrid<hybrid_fp32_mla_4x8, float, float>(args); } }, -{ +GemmImplementation<float, float>::with_estimate( GemmMethod::GEMM_HYBRID, "hybrid_fp32_mla_16x4", [](const GemmArgs &args) { return (args._Ksize >= 4); }, - [](const GemmArgs &args) { return ((args._Ksize <= 256) && (args._Nsize <= 256)) || (args._Msize < 16) || (args._nmulti > 1); }, + [](const GemmArgs &args) { return GemmHybrid<hybrid_fp32_mla_16x4, float, float>::estimate_cycles(args, hybrid_fp32_mla_16x4::get_performance_parameters(args._ci)); }, [](const GemmArgs &args) { return new GemmHybrid<hybrid_fp32_mla_16x4, float, float>(args); } -}, +), #ifdef __ARM_FEATURE_SVE { @@ -138,21 +138,21 @@ static const GemmImplementation<float, float> gemm_fp32_methods[] = }, #endif // __ARM_FEATURE_SVE // Pretranposed, 2D split -{ +GemmImplementation<float, float>::with_estimate( GemmMethod::GEMM_INTERLEAVED_2D, "sgemm_12x8_2d", nullptr, - [](const GemmArgs &args) { return args._maxthreads >= 8; }, + [](const GemmArgs &args) { return GemmInterleavedPretransposed2d<sgemm_12x8, float, float>::estimate_cycles(args, sgemm_12x8::get_performance_parameters(args._ci)); }, [](const GemmArgs &args) { return new GemmInterleavedPretransposed2d<sgemm_12x8, float, float>(args); } -}, +), // 1D split (with pretransposed or not) -{ +GemmImplementation<float, float>::with_estimate( GemmMethod::GEMM_INTERLEAVED, "sgemm_12x8_1d", nullptr, - nullptr, + [](const GemmArgs &args) { return GemmInterleaved<sgemm_12x8, float, float>::estimate_cycles(args, sgemm_12x8::get_performance_parameters(args._ci)); }, 
[](const GemmArgs &args) { return new GemmInterleaved<sgemm_12x8, float, float>(args); } -}, +), #endif // __aarch64__ #ifdef __arm__ |