From 499b5bca1a897461d4105ba52e4c766ddb5f564a Mon Sep 17 00:00:00 2001 From: Gunes Bayir Date: Fri, 26 Apr 2024 13:15:05 +0100 Subject: Disable SME2 Gemm kernel selection in case results need to be accumulated SME2 kernels use a different accumulation buffer, and the destination tensor is not copied to this buffer as the initial value, thus causing mismatches. This patch modifies the kernel selection algorithm such that it does not select SME2 kernels if accumulation is required. Resolves: COMPMID-6995 Change-Id: I82da3cba41729f938a046f26b41b63ff5716c02d Signed-off-by: Gunes Bayir Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/11500 Reviewed-by: Jakub Sujak Comments-Addressed: Arm Jenkins Benchmark: Arm Jenkins Tested-by: Arm Jenkins --- src/core/NEON/kernels/arm_gemm/gemm_fp32.cpp | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) (limited to 'src/core/NEON/kernels') diff --git a/src/core/NEON/kernels/arm_gemm/gemm_fp32.cpp b/src/core/NEON/kernels/arm_gemm/gemm_fp32.cpp index 290fe87230..af0d38ec37 100644 --- a/src/core/NEON/kernels/arm_gemm/gemm_fp32.cpp +++ b/src/core/NEON/kernels/arm_gemm/gemm_fp32.cpp @@ -124,14 +124,14 @@ GemmImplementation::with_estimate( { GemmMethod::GEMM_HYBRID, "sme2_gemv_fp32bf16fp32_dot_16VL", - [](const GemmArgs &args) { return args._fast_mode && args._ci->has_sme2() && args._Msize==1 && args._nbatches==1 && !args._indirect_input; }, + [](const GemmArgs &args) { return args._fast_mode && args._ci->has_sme2() && args._Msize==1 && args._nbatches==1 && !args._indirect_input && !args._accumulate; }, nullptr, [](const GemmArgs &args) { return new GemvPretransposed(args); } }, { GemmMethod::GEMM_HYBRID, "sme2_gemv_fp32_mla_16VL", - [](const GemmArgs &args) { return args._ci->has_sme2() && args._Msize==1 && args._nbatches==1 && !args._indirect_input; }, + [](const GemmArgs &args) { return args._ci->has_sme2() && args._Msize==1 && args._nbatches==1 && !args._indirect_input && !args._accumulate; }, nullptr, 
[](const GemmArgs &args) { return new GemvPretransposed(args); } }, @@ -139,7 +139,7 @@ GemmImplementation::with_estimate( { GemmMethod::GEMM_INTERLEAVED, "sme2_interleaved_nomerge_bf16fp32_mopa_1VLx4VL", - [](const GemmArgs &args) { return args._fast_mode && args._ci->has_sme2(); }, + [](const GemmArgs &args) { return args._fast_mode && args._ci->has_sme2() && !args._accumulate; }, [](const GemmArgs &args) { const auto VL = sme::get_vector_length(); return args._Msize <= VL || (2*VL < args._Msize && args._Msize <= 3*VL); }, [](const GemmArgs &args) { return new GemmInterleavedNoMerge(args); } @@ -148,7 +148,7 @@ GemmImplementation::with_estimate( { GemmMethod::GEMM_INTERLEAVED, "sme2_interleaved_nomerge_fp32_mopa_1VLx4VL", - [](const GemmArgs &args) { return args._ci->has_sme2(); }, + [](const GemmArgs &args) { return args._ci->has_sme2() && !args._accumulate; }, [](const GemmArgs &args) { const auto VL = sme::get_vector_length(); return args._Msize <= VL || (2*VL < args._Msize && args._Msize <= 3*VL); }, [](const GemmArgs &args) { return new GemmInterleavedNoMerge(args); } @@ -157,7 +157,7 @@ GemmImplementation::with_estimate( { GemmMethod::GEMM_INTERLEAVED, "sme2_interleaved_nomerge_bf16fp32_mopa_4VLx1VL", - [](const GemmArgs &args) { return args._fast_mode && args._ci->has_sme2(); }, + [](const GemmArgs &args) { return args._fast_mode && args._ci->has_sme2() && !args._accumulate; }, [](const GemmArgs &args) { const auto VL = sme::get_vector_length(); return args._Nsize <= VL || (2*VL < args._Nsize && args._Nsize <= 3*VL); }, [](const GemmArgs &args) { return new GemmInterleavedNoMerge(args); } @@ -166,7 +166,7 @@ GemmImplementation::with_estimate( { GemmMethod::GEMM_INTERLEAVED, "sme2_interleaved_nomerge_fp32_mopa_4VLx1VL", - [](const GemmArgs &args) { return args._ci->has_sme2(); }, + [](const GemmArgs &args) { return args._ci->has_sme2() && !args._accumulate; }, [](const GemmArgs &args) { const auto VL = sme::get_vector_length(); return args._Nsize <= VL || 
(2*VL < args._Nsize && args._Nsize <= 3*VL); }, [](const GemmArgs &args) { return new GemmInterleavedNoMerge(args); } @@ -175,7 +175,7 @@ GemmImplementation::with_estimate( { GemmMethod::GEMM_INTERLEAVED, "sme2_interleaved_nomerge_bf16fp32_mopa_2VLx2VL", - [](const GemmArgs &args) { return args._fast_mode && args._ci->has_sme2(); }, + [](const GemmArgs &args) { return args._fast_mode && args._ci->has_sme2() && !args._accumulate; }, nullptr, [](const GemmArgs &args) { return new GemmInterleavedNoMerge(args); } }, @@ -183,7 +183,7 @@ GemmImplementation::with_estimate( { GemmMethod::GEMM_INTERLEAVED, "sme2_interleaved_nomerge_fp32_mopa_2VLx2VL", - [](const GemmArgs &args) { return args._ci->has_sme2(); }, + [](const GemmArgs &args) { return args._ci->has_sme2() && !args._accumulate; }, nullptr, [](const GemmArgs &args) { return new GemmInterleavedNoMerge(args); } }, -- cgit v1.2.1