From f9225b4a8e3fd96edd07765420d37e0efac1c27f Mon Sep 17 00:00:00 2001 From: David Mansell Date: Tue, 7 Mar 2023 14:16:37 +0000 Subject: GEMM: SME: Allow threading for quantized GEMMs. The SME kernels for quantized int8/uint8 GEMMs erroneously required that maxthreads==1 before they could be selected. This resulted in them not being available on multi-thread runs. Remove that restriction. Resolves COMPMID-5962 Change-Id: Ia7933d0c66020b5e2981604ca97ff7ead95ec14e Signed-off-by: David Mansell Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/9274 Reviewed-by: Viet-Hoa Do Comments-Addressed: Arm Jenkins Benchmark: Arm Jenkins Tested-by: Arm Jenkins --- src/core/NEON/kernels/arm_gemm/gemm_qint8.cpp | 6 +++--- src/core/NEON/kernels/arm_gemm/gemm_quint8.cpp | 6 +++--- 2 files changed, 6 insertions(+), 6 deletions(-) (limited to 'src/core/NEON/kernels') diff --git a/src/core/NEON/kernels/arm_gemm/gemm_qint8.cpp b/src/core/NEON/kernels/arm_gemm/gemm_qint8.cpp index d168abcf6d..9e8907d60f 100644 --- a/src/core/NEON/kernels/arm_gemm/gemm_qint8.cpp +++ b/src/core/NEON/kernels/arm_gemm/gemm_qint8.cpp @@ -80,7 +80,7 @@ static const GemmImplementation gemm_qint8_methods { GemmMethod::GEMM_INTERLEAVED, "sme2_interleaved_nomerge_s8q_mopa_1VLx4VL", - [](const GemmArgs &args, const Requantize32 &qp) { return args._ci->has_sme2() && args._maxthreads == 1 && ((qp.per_channel_requant && (qp.per_channel_left_shifts == nullptr)) || (!qp.per_channel_requant && (qp.per_layer_left_shift == 0)));}, + [](const GemmArgs &args, const Requantize32 &qp) { return args._ci->has_sme2() && ((qp.per_channel_requant && (qp.per_channel_left_shifts == nullptr)) || (!qp.per_channel_requant && (qp.per_layer_left_shift == 0)));}, [](const GemmArgs &args, const Requantize32 &) { const auto VL = sme::get_vector_length(); return args._Msize <= VL || (2*VL < args._Msize && args._Msize <= 3*VL); }, [](const GemmArgs &args, const Requantize32 &qp) { return new GemmInterleavedPretransposedNoMergeQuantizedInline(args, qp); } @@ -88,7 +88,7 @@ static const GemmImplementation gemm_qint8_methods { GemmMethod::GEMM_INTERLEAVED, "sme2_interleaved_nomerge_s8q_mopa_4VLx1VL", - [](const GemmArgs &args, const Requantize32 &qp) { return args._ci->has_sme2() && args._maxthreads == 1 && ((qp.per_channel_requant && (qp.per_channel_left_shifts == nullptr)) || (!qp.per_channel_requant && (qp.per_layer_left_shift == 0)));}, + [](const GemmArgs &args, const Requantize32 &qp) { return args._ci->has_sme2() && ((qp.per_channel_requant && (qp.per_channel_left_shifts == nullptr)) || (!qp.per_channel_requant && (qp.per_layer_left_shift == 0)));}, [](const GemmArgs &args, const Requantize32 &) { const auto VL = sme::get_vector_length(); return args._Nsize <= VL || (2*VL < args._Nsize && args._Nsize <= 3*VL); }, [](const GemmArgs &args, const Requantize32 &qp) { return new GemmInterleavedPretransposedNoMergeQuantizedInline(args, qp); } @@ -96,7 +96,7 @@ static const GemmImplementation gemm_qint8_methods { GemmMethod::GEMM_INTERLEAVED, "sme2_interleaved_nomerge_s8q_mopa_2VLx2VL", - [](const GemmArgs &args, const Requantize32 &qp) { return args._ci->has_sme2() && args._maxthreads == 1 && ((qp.per_channel_requant && (qp.per_channel_left_shifts == nullptr)) || (!qp.per_channel_requant && (qp.per_layer_left_shift == 0)));}, + [](const GemmArgs &args, const Requantize32 &qp) { return args._ci->has_sme2() && ((qp.per_channel_requant && (qp.per_channel_left_shifts == nullptr)) || (!qp.per_channel_requant && (qp.per_layer_left_shift == 0)));}, nullptr, [](const GemmArgs &args, const Requantize32 &qp) { return new GemmInterleavedPretransposedNoMergeQuantizedInline(args, qp); } }, diff --git a/src/core/NEON/kernels/arm_gemm/gemm_quint8.cpp b/src/core/NEON/kernels/arm_gemm/gemm_quint8.cpp index 01f5124a2c..f93f56b57d 100644 --- a/src/core/NEON/kernels/arm_gemm/gemm_quint8.cpp +++ b/src/core/NEON/kernels/arm_gemm/gemm_quint8.cpp @@ -76,7 +76,7 @@ static const GemmImplementation gemm_quint8_meth { GemmMethod::GEMM_INTERLEAVED, "sme2_interleaved_nomerge_u8q_mopa_1VLx4VL", - [](const GemmArgs &args, const Requantize32 &qp) { return args._ci->has_sme2() && args._maxthreads == 1 && ((qp.per_channel_requant && (qp.per_channel_left_shifts == nullptr)) || (!qp.per_channel_requant && (qp.per_layer_left_shift == 0)));}, + [](const GemmArgs &args, const Requantize32 &qp) { return args._ci->has_sme2() && ((qp.per_channel_requant && (qp.per_channel_left_shifts == nullptr)) || (!qp.per_channel_requant && (qp.per_layer_left_shift == 0)));}, [](const GemmArgs &args, const Requantize32 &) { const auto VL = sme::get_vector_length(); return args._Msize <= VL || (2*VL < args._Msize && args._Msize <= 3*VL); }, [](const GemmArgs &args, const Requantize32 &qp) { return new GemmInterleavedPretransposedNoMergeQuantizedInline(args, qp); } @@ -84,7 +84,7 @@ static const GemmImplementation gemm_quint8_meth { GemmMethod::GEMM_INTERLEAVED, "sme2_interleaved_nomerge_u8q_mopa_4VLx1VL", - [](const GemmArgs &args, const Requantize32 &qp) { return args._ci->has_sme2() && args._maxthreads == 1 && ((qp.per_channel_requant && (qp.per_channel_left_shifts == nullptr)) || (!qp.per_channel_requant && (qp.per_layer_left_shift == 0)));}, + [](const GemmArgs &args, const Requantize32 &qp) { return args._ci->has_sme2() && ((qp.per_channel_requant && (qp.per_channel_left_shifts == nullptr)) || (!qp.per_channel_requant && (qp.per_layer_left_shift == 0)));}, [](const GemmArgs &args, const Requantize32 &) { const auto VL = sme::get_vector_length(); return args._Nsize <= VL || (2*VL < args._Nsize && args._Nsize <= 3*VL); }, [](const GemmArgs &args, const Requantize32 &qp) { return new GemmInterleavedPretransposedNoMergeQuantizedInline(args, qp); } @@ -92,7 +92,7 @@ static const GemmImplementation gemm_quint8_meth { GemmMethod::GEMM_INTERLEAVED, "sme2_interleaved_nomerge_u8q_mopa_2VLx2VL", - [](const GemmArgs &args, const Requantize32 &qp) { return args._ci->has_sme2() && args._maxthreads == 1 && ((qp.per_channel_requant && (qp.per_channel_left_shifts == nullptr)) || (!qp.per_channel_requant && (qp.per_layer_left_shift == 0)));}, + [](const GemmArgs &args, const Requantize32 &qp) { return args._ci->has_sme2() && ((qp.per_channel_requant && (qp.per_channel_left_shifts == nullptr)) || (!qp.per_channel_requant && (qp.per_layer_left_shift == 0)));}, nullptr, [](const GemmArgs &args, const Requantize32 &qp) { return new GemmInterleavedPretransposedNoMergeQuantizedInline(args, qp); } }, -- cgit v1.2.1