diff options
Diffstat (limited to 'src/core')
-rw-r--r-- | src/core/NEON/kernels/arm_gemm/gemm_fp32.cpp | 16 |
1 files changed, 8 insertions, 8 deletions
diff --git a/src/core/NEON/kernels/arm_gemm/gemm_fp32.cpp b/src/core/NEON/kernels/arm_gemm/gemm_fp32.cpp index 290fe87230..af0d38ec37 100644 --- a/src/core/NEON/kernels/arm_gemm/gemm_fp32.cpp +++ b/src/core/NEON/kernels/arm_gemm/gemm_fp32.cpp @@ -124,14 +124,14 @@ GemmImplementation<float, float>::with_estimate( { GemmMethod::GEMM_HYBRID, "sme2_gemv_fp32bf16fp32_dot_16VL", - [](const GemmArgs &args) { return args._fast_mode && args._ci->has_sme2() && args._Msize==1 && args._nbatches==1 && !args._indirect_input; }, + [](const GemmArgs &args) { return args._fast_mode && args._ci->has_sme2() && args._Msize==1 && args._nbatches==1 && !args._indirect_input && !args._accumulate; }, nullptr, [](const GemmArgs &args) { return new GemvPretransposed<cls_sme2_gemv_fp32bf16fp32_dot_16VL, float, float>(args); } }, { GemmMethod::GEMM_HYBRID, "sme2_gemv_fp32_mla_16VL", - [](const GemmArgs &args) { return args._ci->has_sme2() && args._Msize==1 && args._nbatches==1 && !args._indirect_input; }, + [](const GemmArgs &args) { return args._ci->has_sme2() && args._Msize==1 && args._nbatches==1 && !args._indirect_input && !args._accumulate; }, nullptr, [](const GemmArgs &args) { return new GemvPretransposed<cls_sme2_gemv_fp32_mla_16VL, float, float>(args); } }, @@ -139,7 +139,7 @@ GemmImplementation<float, float>::with_estimate( { GemmMethod::GEMM_INTERLEAVED, "sme2_interleaved_nomerge_bf16fp32_mopa_1VLx4VL", - [](const GemmArgs &args) { return args._fast_mode && args._ci->has_sme2(); }, + [](const GemmArgs &args) { return args._fast_mode && args._ci->has_sme2() && !args._accumulate; }, [](const GemmArgs &args) { const auto VL = sme::get_vector_length<float>(); return args._Msize <= VL || (2*VL < args._Msize && args._Msize <= 3*VL); }, [](const GemmArgs &args) { return new GemmInterleavedNoMerge<cls_sme2_interleaved_nomerge_bf16fp32_mopa_1VLx4VL, float, float>(args); } @@ -148,7 +148,7 @@ GemmImplementation<float, float>::with_estimate( { GemmMethod::GEMM_INTERLEAVED, "sme2_interleaved_nomerge_fp32_mopa_1VLx4VL", - [](const GemmArgs &args) { return args._ci->has_sme2(); }, + [](const GemmArgs &args) { return args._ci->has_sme2() && !args._accumulate; }, [](const GemmArgs &args) { const auto VL = sme::get_vector_length<float>(); return args._Msize <= VL || (2*VL < args._Msize && args._Msize <= 3*VL); }, [](const GemmArgs &args) { return new GemmInterleavedNoMerge<cls_sme2_interleaved_nomerge_fp32_mopa_1VLx4VL, float, float>(args); } @@ -157,7 +157,7 @@ GemmImplementation<float, float>::with_estimate( { GemmMethod::GEMM_INTERLEAVED, "sme2_interleaved_nomerge_bf16fp32_mopa_4VLx1VL", - [](const GemmArgs &args) { return args._fast_mode && args._ci->has_sme2(); }, + [](const GemmArgs &args) { return args._fast_mode && args._ci->has_sme2() && !args._accumulate; }, [](const GemmArgs &args) { const auto VL = sme::get_vector_length<float>(); return args._Nsize <= VL || (2*VL < args._Nsize && args._Nsize <= 3*VL); }, [](const GemmArgs &args) { return new GemmInterleavedNoMerge<cls_sme2_interleaved_nomerge_bf16fp32_mopa_4VLx1VL, float, float>(args); } @@ -166,7 +166,7 @@ GemmImplementation<float, float>::with_estimate( { GemmMethod::GEMM_INTERLEAVED, "sme2_interleaved_nomerge_fp32_mopa_4VLx1VL", - [](const GemmArgs &args) { return args._ci->has_sme2(); }, + [](const GemmArgs &args) { return args._ci->has_sme2() && !args._accumulate; }, [](const GemmArgs &args) { const auto VL = sme::get_vector_length<float>(); return args._Nsize <= VL || (2*VL < args._Nsize && args._Nsize <= 3*VL); }, [](const GemmArgs &args) { return new GemmInterleavedNoMerge<cls_sme2_interleaved_nomerge_fp32_mopa_4VLx1VL, float, float>(args); } @@ -175,7 +175,7 @@ GemmImplementation<float, float>::with_estimate( { GemmMethod::GEMM_INTERLEAVED, "sme2_interleaved_nomerge_bf16fp32_mopa_2VLx2VL", - [](const GemmArgs &args) { return args._fast_mode && args._ci->has_sme2(); }, + [](const GemmArgs &args) { return args._fast_mode && args._ci->has_sme2() && !args._accumulate; }, nullptr, [](const GemmArgs &args) { return new GemmInterleavedNoMerge<cls_sme2_interleaved_nomerge_bf16fp32_mopa_2VLx2VL, float, float>(args); } }, @@ -183,7 +183,7 @@ GemmImplementation<float, float>::with_estimate( { GemmMethod::GEMM_INTERLEAVED, "sme2_interleaved_nomerge_fp32_mopa_2VLx2VL", - [](const GemmArgs &args) { return args._ci->has_sme2(); }, + [](const GemmArgs &args) { return args._ci->has_sme2() && !args._accumulate; }, nullptr, [](const GemmArgs &args) { return new GemmInterleavedNoMerge<cls_sme2_interleaved_nomerge_fp32_mopa_2VLx2VL, float, float>(args); } }, |