From e5ef8c159a14872dda5e36e320f07b0963858d8c Mon Sep 17 00:00:00 2001 From: Gunes Bayir Date: Fri, 26 Apr 2024 16:51:54 +0100 Subject: Disable SME2 Gemmlowp s8f32 kernel selection in case results needs to be accumulated Similar to https://review.mlplatform.org/c/ml/ComputeLibrary/+/11500, s8f32 kernels do not support accumulate mode. This patch modifies the kernel selection and also adds more tests to stress these test cases better. Partially Resolves: COMPMID-6995 Change-Id: I40e19446c012eb7334e4511e254cce0d635aa234 Signed-off-by: Gunes Bayir Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/11503 Benchmark: Arm Jenkins Tested-by: Arm Jenkins Reviewed-by: Radu Salavat Reviewed-by: Jakub Sujak Comments-Addressed: Arm Jenkins --- src/core/NEON/kernels/arm_gemm/gemm_s8fp32.cpp | 6 +++--- tests/validation/NEON/GEMMLowp.cpp | 12 ++++++++++-- tests/validation/fixtures/GEMMLowpFixture.h | 10 ++-------- 3 files changed, 15 insertions(+), 13 deletions(-) diff --git a/src/core/NEON/kernels/arm_gemm/gemm_s8fp32.cpp b/src/core/NEON/kernels/arm_gemm/gemm_s8fp32.cpp index 782399df8c..38d9b763f6 100644 --- a/src/core/NEON/kernels/arm_gemm/gemm_s8fp32.cpp +++ b/src/core/NEON/kernels/arm_gemm/gemm_s8fp32.cpp @@ -55,7 +55,7 @@ static const GemmImplementation gemm_s8fp32_meth { GemmMethod::GEMM_INTERLEAVED, "sme2_interleaved_nomerge_s8qfp32_mopa_1VLx4VL.hpp", - [](const GemmArgs &args, const DequantizeFloat &) { return args._ci->has_sme2(); }, + [](const GemmArgs &args, const DequantizeFloat &) { return args._ci->has_sme2() && !args._accumulate; }, [](const GemmArgs &args, const DequantizeFloat &) { const auto VL = sme::get_vector_length(); return args._Msize <= VL || (2*VL < args._Msize && args._Msize <= 3*VL); }, [](const GemmArgs &args, const DequantizeFloat &dq) { return new GemmInterleavedNoMergeDequantized(args, dq); } @@ -63,7 +63,7 @@ static const GemmImplementation gemm_s8fp32_meth { GemmMethod::GEMM_INTERLEAVED, "sme2_interleaved_nomerge_s8qfp32_mopa_4Vx1VL.hpp", - [](const GemmArgs &args, const DequantizeFloat &) { return args._ci->has_sme2(); }, + [](const GemmArgs &args, const DequantizeFloat &) { return args._ci->has_sme2() && !args._accumulate; }, [](const GemmArgs &args, const DequantizeFloat &) { const auto VL = sme::get_vector_length(); return args._Nsize <= VL || (2*VL < args._Nsize && args._Nsize <= 3*VL); }, [](const GemmArgs &args, const DequantizeFloat &dq) { return new GemmInterleavedNoMergeDequantized(args, dq); } @@ -71,7 +71,7 @@ static const GemmImplementation gemm_s8fp32_meth { GemmMethod::GEMM_INTERLEAVED, "sme2_interleaved_nomerge_s8qfp32_mopa_2Vx2VL.hpp", - [](const GemmArgs &args, const DequantizeFloat &) { return args._ci->has_sme2(); }, + [](const GemmArgs &args, const DequantizeFloat &) { return args._ci->has_sme2() && !args._accumulate; }, nullptr, [](const GemmArgs &args, const DequantizeFloat &dq) { return new GemmInterleavedNoMergeDequantized(args, dq); } }, diff --git a/tests/validation/NEON/GEMMLowp.cpp b/tests/validation/NEON/GEMMLowp.cpp index 9b1da61ed7..d25f43a330 100644 --- a/tests/validation/NEON/GEMMLowp.cpp +++ b/tests/validation/NEON/GEMMLowp.cpp @@ -360,13 +360,21 @@ TEST_SUITE_END() // DynamicQuantization // Deqaunt tests involve returning F32 from the MatrixMultiplyCore kernels and is only implemented in aarch64 TEST_SUITE(Dequant) constexpr AbsoluteTolerance tolerance_dequantized(0.01f); -FIXTURE_DATA_TEST_CASE(RunSmall, NEGEMMLowpDequantizedMatrixMultiplyValidationFixture, framework::DatasetMode::ALL, datasets::SmallGEMMLowpDataset()) +FIXTURE_DATA_TEST_CASE(RunSmall, NEGEMMLowpDequantizedMatrixMultiplyValidationFixture, framework::DatasetMode::ALL, + combine( + datasets::SmallGEMMLowpDataset(), + make("accumulate", {true, false}) + )) { // Validate output validate(Accessor(_target), _reference, tolerance_dequantized); } -FIXTURE_DATA_TEST_CASE(RunLarge, NEGEMMLowpDequantizedMatrixMultiplyValidationFixture, framework::DatasetMode::NIGHTLY, datasets::LargeGEMMLowpDataset()) +FIXTURE_DATA_TEST_CASE(RunLarge, NEGEMMLowpDequantizedMatrixMultiplyValidationFixture, framework::DatasetMode::NIGHTLY, + combine( + datasets::LargeGEMMLowpDataset(), + make("accumulate", {false}) + )) { // Validate output validate(Accessor(_target), _reference, tolerance_dequantized); diff --git a/tests/validation/fixtures/GEMMLowpFixture.h b/tests/validation/fixtures/GEMMLowpFixture.h index 6b7cbba92e..aa4eedb75d 100644 --- a/tests/validation/fixtures/GEMMLowpFixture.h +++ b/tests/validation/fixtures/GEMMLowpFixture.h @@ -472,15 +472,9 @@ template