diff options
Diffstat (limited to 'src/core/NEON/kernels/arm_gemm/gemm_fp32.cpp')
-rw-r--r-- | src/core/NEON/kernels/arm_gemm/gemm_fp32.cpp | 8 |
1 files changed, 4 insertions, 4 deletions
diff --git a/src/core/NEON/kernels/arm_gemm/gemm_fp32.cpp b/src/core/NEON/kernels/arm_gemm/gemm_fp32.cpp index af0d38ec37..0c1d3a387b 100644 --- a/src/core/NEON/kernels/arm_gemm/gemm_fp32.cpp +++ b/src/core/NEON/kernels/arm_gemm/gemm_fp32.cpp @@ -141,7 +141,7 @@ GemmImplementation<float, float>::with_estimate( "sme2_interleaved_nomerge_bf16fp32_mopa_1VLx4VL", [](const GemmArgs &args) { return args._fast_mode && args._ci->has_sme2() && !args._accumulate; }, [](const GemmArgs &args) { const auto VL = sme::get_vector_length<float>(); - return args._Msize <= VL || (2*VL < args._Msize && args._Msize <= 3*VL); }, + return args._Nsize >= 8*VL || args._Msize <= VL || (2*VL < args._Msize && args._Msize <= 3*VL); }, [](const GemmArgs &args) { return new GemmInterleavedNoMerge<cls_sme2_interleaved_nomerge_bf16fp32_mopa_1VLx4VL, float, float>(args); } }, #endif // ARM_COMPUTE_ENABLE_BF16 @@ -150,7 +150,7 @@ GemmImplementation<float, float>::with_estimate( "sme2_interleaved_nomerge_fp32_mopa_1VLx4VL", [](const GemmArgs &args) { return args._ci->has_sme2() && !args._accumulate; }, [](const GemmArgs &args) { const auto VL = sme::get_vector_length<float>(); - return args._Msize <= VL || (2*VL < args._Msize && args._Msize <= 3*VL); }, + return args._Nsize >= 8*VL || args._Msize <= VL || (2*VL < args._Msize && args._Msize <= 3*VL); }, [](const GemmArgs &args) { return new GemmInterleavedNoMerge<cls_sme2_interleaved_nomerge_fp32_mopa_1VLx4VL, float, float>(args); } }, #ifdef ARM_COMPUTE_ENABLE_BF16 @@ -199,14 +199,14 @@ GemmImplementation<float, float>::with_estimate( GemmImplementation<float, float>::with_estimate( GemmMethod::GEMM_HYBRID, "sve_hybrid_fp32bf16fp32_mmla_6x4VL", - [](const GemmArgs &args) { return args._fast_mode && args._ci->has_bf16(); }, + [](const GemmArgs &args) { return args._fast_mode && args._ci->has_svebf16(); }, [](const GemmArgs &args) { return GemmHybridIndirect<cls_sve_hybrid_fp32bf16fp32_mmla_6x4VL, float, float>::estimate_cycles<float>(args); }, [](const GemmArgs &args) { return new GemmHybridIndirect<cls_sve_hybrid_fp32bf16fp32_mmla_6x4VL, float, float>(args); } ), GemmImplementation<float, float>::with_estimate( GemmMethod::GEMM_HYBRID, "sve_hybrid_fp32bf16fp32_mmla_4x6VL", - [](const GemmArgs &args) { return args._fast_mode && args._ci->has_bf16(); }, + [](const GemmArgs &args) { return args._fast_mode && args._ci->has_svebf16(); }, [](const GemmArgs &args) { return GemmHybridIndirect<cls_sve_hybrid_fp32bf16fp32_mmla_4x6VL, float, float>::estimate_cycles<float>(args); }, [](const GemmArgs &args) { return new GemmHybridIndirect<cls_sve_hybrid_fp32bf16fp32_mmla_4x6VL, float, float>(args); } ), |