aboutsummaryrefslogtreecommitdiff
path: root/src/core/NEON/kernels/arm_gemm/gemm_fp32.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/core/NEON/kernels/arm_gemm/gemm_fp32.cpp')
-rw-r--r-- src/core/NEON/kernels/arm_gemm/gemm_fp32.cpp 18
1 files changed, 9 insertions, 9 deletions
diff --git a/src/core/NEON/kernels/arm_gemm/gemm_fp32.cpp b/src/core/NEON/kernels/arm_gemm/gemm_fp32.cpp
index aa206e3f42..ddb438f06c 100644
--- a/src/core/NEON/kernels/arm_gemm/gemm_fp32.cpp
+++ b/src/core/NEON/kernels/arm_gemm/gemm_fp32.cpp
@@ -120,13 +120,13 @@ static const GemmImplementation<float, float> gemm_fp32_methods[] =
[](const GemmArgs &args) { return (args._Nsize < 12); },
[](const GemmArgs &args) { return new GemmHybrid<hybrid_fp32_mla_4x8, float, float>(args); }
},
-{
+GemmImplementation<float, float>::with_estimate(
GemmMethod::GEMM_HYBRID,
"hybrid_fp32_mla_16x4",
[](const GemmArgs &args) { return (args._Ksize >= 4); },
- [](const GemmArgs &args) { return ((args._Ksize <= 256) && (args._Nsize <= 256)) || (args._Msize < 16) || (args._nmulti > 1); },
+ [](const GemmArgs &args) { return GemmHybrid<hybrid_fp32_mla_16x4, float, float>::estimate_cycles(args, hybrid_fp32_mla_16x4::get_performance_parameters(args._ci)); },
[](const GemmArgs &args) { return new GemmHybrid<hybrid_fp32_mla_16x4, float, float>(args); }
-},
+),
#ifdef __ARM_FEATURE_SVE
{
@@ -138,21 +138,21 @@ static const GemmImplementation<float, float> gemm_fp32_methods[] =
},
#endif // __ARM_FEATURE_SVE
// Pretransposed, 2D split
-{
+GemmImplementation<float, float>::with_estimate(
GemmMethod::GEMM_INTERLEAVED_2D,
"sgemm_12x8_2d",
nullptr,
- [](const GemmArgs &args) { return args._maxthreads >= 8; },
+ [](const GemmArgs &args) { return GemmInterleavedPretransposed2d<sgemm_12x8, float, float>::estimate_cycles(args, sgemm_12x8::get_performance_parameters(args._ci)); },
[](const GemmArgs &args) { return new GemmInterleavedPretransposed2d<sgemm_12x8, float, float>(args); }
-},
+),
// 1D split (with pretransposed or not)
-{
+GemmImplementation<float, float>::with_estimate(
GemmMethod::GEMM_INTERLEAVED,
"sgemm_12x8_1d",
nullptr,
- nullptr,
+ [](const GemmArgs &args) { return GemmInterleaved<sgemm_12x8, float, float>::estimate_cycles(args, sgemm_12x8::get_performance_parameters(args._ci)); },
[](const GemmArgs &args) { return new GemmInterleaved<sgemm_12x8, float, float>(args); }
-},
+),
#endif // __aarch64__
#ifdef __arm__