aboutsummaryrefslogtreecommitdiff
path: root/src/core/NEON/kernels/arm_gemm/gemm_fp32.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/core/NEON/kernels/arm_gemm/gemm_fp32.cpp')
-rw-r--r--src/core/NEON/kernels/arm_gemm/gemm_fp32.cpp24
1 files changed, 12 insertions, 12 deletions
diff --git a/src/core/NEON/kernels/arm_gemm/gemm_fp32.cpp b/src/core/NEON/kernels/arm_gemm/gemm_fp32.cpp
index 290fe87230..0c1d3a387b 100644
--- a/src/core/NEON/kernels/arm_gemm/gemm_fp32.cpp
+++ b/src/core/NEON/kernels/arm_gemm/gemm_fp32.cpp
@@ -124,14 +124,14 @@ GemmImplementation<float, float>::with_estimate(
{
GemmMethod::GEMM_HYBRID,
"sme2_gemv_fp32bf16fp32_dot_16VL",
- [](const GemmArgs &args) { return args._fast_mode && args._ci->has_sme2() && args._Msize==1 && args._nbatches==1 && !args._indirect_input; },
+ [](const GemmArgs &args) { return args._fast_mode && args._ci->has_sme2() && args._Msize==1 && args._nbatches==1 && !args._indirect_input && !args._accumulate; },
nullptr,
[](const GemmArgs &args) { return new GemvPretransposed<cls_sme2_gemv_fp32bf16fp32_dot_16VL, float, float>(args); }
},
{
GemmMethod::GEMM_HYBRID,
"sme2_gemv_fp32_mla_16VL",
- [](const GemmArgs &args) { return args._ci->has_sme2() && args._Msize==1 && args._nbatches==1 && !args._indirect_input; },
+ [](const GemmArgs &args) { return args._ci->has_sme2() && args._Msize==1 && args._nbatches==1 && !args._indirect_input && !args._accumulate; },
nullptr,
[](const GemmArgs &args) { return new GemvPretransposed<cls_sme2_gemv_fp32_mla_16VL, float, float>(args); }
},
@@ -139,25 +139,25 @@ GemmImplementation<float, float>::with_estimate(
{
GemmMethod::GEMM_INTERLEAVED,
"sme2_interleaved_nomerge_bf16fp32_mopa_1VLx4VL",
- [](const GemmArgs &args) { return args._fast_mode && args._ci->has_sme2(); },
+ [](const GemmArgs &args) { return args._fast_mode && args._ci->has_sme2() && !args._accumulate; },
[](const GemmArgs &args) { const auto VL = sme::get_vector_length<float>();
- return args._Msize <= VL || (2*VL < args._Msize && args._Msize <= 3*VL); },
+ return args._Nsize >= 8*VL || args._Msize <= VL || (2*VL < args._Msize && args._Msize <= 3*VL); },
[](const GemmArgs &args) { return new GemmInterleavedNoMerge<cls_sme2_interleaved_nomerge_bf16fp32_mopa_1VLx4VL, float, float>(args); }
},
#endif // ARM_COMPUTE_ENABLE_BF16
{
GemmMethod::GEMM_INTERLEAVED,
"sme2_interleaved_nomerge_fp32_mopa_1VLx4VL",
- [](const GemmArgs &args) { return args._ci->has_sme2(); },
+ [](const GemmArgs &args) { return args._ci->has_sme2() && !args._accumulate; },
[](const GemmArgs &args) { const auto VL = sme::get_vector_length<float>();
- return args._Msize <= VL || (2*VL < args._Msize && args._Msize <= 3*VL); },
+ return args._Nsize >= 8*VL || args._Msize <= VL || (2*VL < args._Msize && args._Msize <= 3*VL); },
[](const GemmArgs &args) { return new GemmInterleavedNoMerge<cls_sme2_interleaved_nomerge_fp32_mopa_1VLx4VL, float, float>(args); }
},
#ifdef ARM_COMPUTE_ENABLE_BF16
{
GemmMethod::GEMM_INTERLEAVED,
"sme2_interleaved_nomerge_bf16fp32_mopa_4VLx1VL",
- [](const GemmArgs &args) { return args._fast_mode && args._ci->has_sme2(); },
+ [](const GemmArgs &args) { return args._fast_mode && args._ci->has_sme2() && !args._accumulate; },
[](const GemmArgs &args) { const auto VL = sme::get_vector_length<float>();
return args._Nsize <= VL || (2*VL < args._Nsize && args._Nsize <= 3*VL); },
[](const GemmArgs &args) { return new GemmInterleavedNoMerge<cls_sme2_interleaved_nomerge_bf16fp32_mopa_4VLx1VL, float, float>(args); }
@@ -166,7 +166,7 @@ GemmImplementation<float, float>::with_estimate(
{
GemmMethod::GEMM_INTERLEAVED,
"sme2_interleaved_nomerge_fp32_mopa_4VLx1VL",
- [](const GemmArgs &args) { return args._ci->has_sme2(); },
+ [](const GemmArgs &args) { return args._ci->has_sme2() && !args._accumulate; },
[](const GemmArgs &args) { const auto VL = sme::get_vector_length<float>();
return args._Nsize <= VL || (2*VL < args._Nsize && args._Nsize <= 3*VL); },
[](const GemmArgs &args) { return new GemmInterleavedNoMerge<cls_sme2_interleaved_nomerge_fp32_mopa_4VLx1VL, float, float>(args); }
@@ -175,7 +175,7 @@ GemmImplementation<float, float>::with_estimate(
{
GemmMethod::GEMM_INTERLEAVED,
"sme2_interleaved_nomerge_bf16fp32_mopa_2VLx2VL",
- [](const GemmArgs &args) { return args._fast_mode && args._ci->has_sme2(); },
+ [](const GemmArgs &args) { return args._fast_mode && args._ci->has_sme2() && !args._accumulate; },
nullptr,
[](const GemmArgs &args) { return new GemmInterleavedNoMerge<cls_sme2_interleaved_nomerge_bf16fp32_mopa_2VLx2VL, float, float>(args); }
},
@@ -183,7 +183,7 @@ GemmImplementation<float, float>::with_estimate(
{
GemmMethod::GEMM_INTERLEAVED,
"sme2_interleaved_nomerge_fp32_mopa_2VLx2VL",
- [](const GemmArgs &args) { return args._ci->has_sme2(); },
+ [](const GemmArgs &args) { return args._ci->has_sme2() && !args._accumulate; },
nullptr,
[](const GemmArgs &args) { return new GemmInterleavedNoMerge<cls_sme2_interleaved_nomerge_fp32_mopa_2VLx2VL, float, float>(args); }
},
@@ -199,14 +199,14 @@ GemmImplementation<float, float>::with_estimate(
GemmImplementation<float, float>::with_estimate(
GemmMethod::GEMM_HYBRID,
"sve_hybrid_fp32bf16fp32_mmla_6x4VL",
- [](const GemmArgs &args) { return args._fast_mode && args._ci->has_bf16(); },
+ [](const GemmArgs &args) { return args._fast_mode && args._ci->has_svebf16(); },
[](const GemmArgs &args) { return GemmHybridIndirect<cls_sve_hybrid_fp32bf16fp32_mmla_6x4VL, float, float>::estimate_cycles<float>(args); },
[](const GemmArgs &args) { return new GemmHybridIndirect<cls_sve_hybrid_fp32bf16fp32_mmla_6x4VL, float, float>(args); }
),
GemmImplementation<float, float>::with_estimate(
GemmMethod::GEMM_HYBRID,
"sve_hybrid_fp32bf16fp32_mmla_4x6VL",
- [](const GemmArgs &args) { return args._fast_mode && args._ci->has_bf16(); },
+ [](const GemmArgs &args) { return args._fast_mode && args._ci->has_svebf16(); },
[](const GemmArgs &args) { return GemmHybridIndirect<cls_sve_hybrid_fp32bf16fp32_mmla_4x6VL, float, float>::estimate_cycles<float>(args); },
[](const GemmArgs &args) { return new GemmHybridIndirect<cls_sve_hybrid_fp32bf16fp32_mmla_4x6VL, float, float>(args); }
),