aboutsummaryrefslogtreecommitdiff
path: root/src/core/NEON/kernels
diff options
context:
space:
mode:
authorDavid Mansell <David.Mansell@arm.com>2023-03-07 14:16:37 +0000
committerDavid Mansell <David.Mansell@arm.com>2023-03-07 15:21:59 +0000
commitf9225b4a8e3fd96edd07765420d37e0efac1c27f (patch)
tree7ef56e092936a209afb25d08f74a2291f0a97c6b /src/core/NEON/kernels
parent3c7c1fa6c589f75fd8146a47fe71094c172fd56a (diff)
downloadComputeLibrary-f9225b4a8e3fd96edd07765420d37e0efac1c27f.tar.gz
GEMM: SME: Allow threading for quantized GEMMs.
The SME kernels for quantized int8/uint8 GEMMs erroneously required that maxthreads==1 before they could be selected. This resulted in them not being available on multi-thread runs. Remove that restriction. Resolves COMPMID-5962 Change-Id: Ia7933d0c66020b5e2981604ca97ff7ead95ec14e Signed-off-by: David Mansell <David.Mansell@arm.com> Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/9274 Reviewed-by: Viet-Hoa Do <viet-hoa.do@arm.com> Comments-Addressed: Arm Jenkins <bsgcomp@arm.com> Benchmark: Arm Jenkins <bsgcomp@arm.com> Tested-by: Arm Jenkins <bsgcomp@arm.com>
Diffstat (limited to 'src/core/NEON/kernels')
-rw-r--r--src/core/NEON/kernels/arm_gemm/gemm_qint8.cpp6
-rw-r--r--src/core/NEON/kernels/arm_gemm/gemm_quint8.cpp6
2 files changed, 6 insertions, 6 deletions
diff --git a/src/core/NEON/kernels/arm_gemm/gemm_qint8.cpp b/src/core/NEON/kernels/arm_gemm/gemm_qint8.cpp
index d168abcf6d..9e8907d60f 100644
--- a/src/core/NEON/kernels/arm_gemm/gemm_qint8.cpp
+++ b/src/core/NEON/kernels/arm_gemm/gemm_qint8.cpp
@@ -80,7 +80,7 @@ static const GemmImplementation<int8_t, int8_t, Requantize32> gemm_qint8_methods
{
GemmMethod::GEMM_INTERLEAVED,
"sme2_interleaved_nomerge_s8q_mopa_1VLx4VL",
- [](const GemmArgs &args, const Requantize32 &qp) { return args._ci->has_sme2() && args._maxthreads == 1 && ((qp.per_channel_requant && (qp.per_channel_left_shifts == nullptr)) || (!qp.per_channel_requant && (qp.per_layer_left_shift == 0)));},
+ [](const GemmArgs &args, const Requantize32 &qp) { return args._ci->has_sme2() && ((qp.per_channel_requant && (qp.per_channel_left_shifts == nullptr)) || (!qp.per_channel_requant && (qp.per_layer_left_shift == 0)));},
[](const GemmArgs &args, const Requantize32 &) { const auto VL = sme::get_vector_length<int32_t>();
return args._Msize <= VL || (2*VL < args._Msize && args._Msize <= 3*VL); },
[](const GemmArgs &args, const Requantize32 &qp) { return new GemmInterleavedPretransposedNoMergeQuantizedInline<cls_sme2_interleaved_nomerge_s8q_mopa_1VLx4VL, int8_t, int8_t>(args, qp); }
@@ -88,7 +88,7 @@ static const GemmImplementation<int8_t, int8_t, Requantize32> gemm_qint8_methods
{
GemmMethod::GEMM_INTERLEAVED,
"sme2_interleaved_nomerge_s8q_mopa_4VLx1VL",
- [](const GemmArgs &args, const Requantize32 &qp) { return args._ci->has_sme2() && args._maxthreads == 1 && ((qp.per_channel_requant && (qp.per_channel_left_shifts == nullptr)) || (!qp.per_channel_requant && (qp.per_layer_left_shift == 0)));},
+ [](const GemmArgs &args, const Requantize32 &qp) { return args._ci->has_sme2() && ((qp.per_channel_requant && (qp.per_channel_left_shifts == nullptr)) || (!qp.per_channel_requant && (qp.per_layer_left_shift == 0)));},
[](const GemmArgs &args, const Requantize32 &) { const auto VL = sme::get_vector_length<int32_t>();
return args._Nsize <= VL || (2*VL < args._Nsize && args._Nsize <= 3*VL); },
[](const GemmArgs &args, const Requantize32 &qp) { return new GemmInterleavedPretransposedNoMergeQuantizedInline<cls_sme2_interleaved_nomerge_s8q_mopa_4VLx1VL, int8_t, int8_t>(args, qp); }
@@ -96,7 +96,7 @@ static const GemmImplementation<int8_t, int8_t, Requantize32> gemm_qint8_methods
{
GemmMethod::GEMM_INTERLEAVED,
"sme2_interleaved_nomerge_s8q_mopa_2VLx2VL",
- [](const GemmArgs &args, const Requantize32 &qp) { return args._ci->has_sme2() && args._maxthreads == 1 && ((qp.per_channel_requant && (qp.per_channel_left_shifts == nullptr)) || (!qp.per_channel_requant && (qp.per_layer_left_shift == 0)));},
+ [](const GemmArgs &args, const Requantize32 &qp) { return args._ci->has_sme2() && ((qp.per_channel_requant && (qp.per_channel_left_shifts == nullptr)) || (!qp.per_channel_requant && (qp.per_layer_left_shift == 0)));},
nullptr,
[](const GemmArgs &args, const Requantize32 &qp) { return new GemmInterleavedPretransposedNoMergeQuantizedInline<cls_sme2_interleaved_nomerge_s8q_mopa_2VLx2VL, int8_t, int8_t>(args, qp); }
},
diff --git a/src/core/NEON/kernels/arm_gemm/gemm_quint8.cpp b/src/core/NEON/kernels/arm_gemm/gemm_quint8.cpp
index 01f5124a2c..f93f56b57d 100644
--- a/src/core/NEON/kernels/arm_gemm/gemm_quint8.cpp
+++ b/src/core/NEON/kernels/arm_gemm/gemm_quint8.cpp
@@ -76,7 +76,7 @@ static const GemmImplementation<uint8_t, uint8_t, Requantize32> gemm_quint8_meth
{
GemmMethod::GEMM_INTERLEAVED,
"sme2_interleaved_nomerge_u8q_mopa_1VLx4VL",
- [](const GemmArgs &args, const Requantize32 &qp) { return args._ci->has_sme2() && args._maxthreads == 1 && ((qp.per_channel_requant && (qp.per_channel_left_shifts == nullptr)) || (!qp.per_channel_requant && (qp.per_layer_left_shift == 0)));},
+ [](const GemmArgs &args, const Requantize32 &qp) { return args._ci->has_sme2() && ((qp.per_channel_requant && (qp.per_channel_left_shifts == nullptr)) || (!qp.per_channel_requant && (qp.per_layer_left_shift == 0)));},
[](const GemmArgs &args, const Requantize32 &) { const auto VL = sme::get_vector_length<uint32_t>();
return args._Msize <= VL || (2*VL < args._Msize && args._Msize <= 3*VL); },
[](const GemmArgs &args, const Requantize32 &qp) { return new GemmInterleavedPretransposedNoMergeQuantizedInline<cls_sme2_interleaved_nomerge_u8q_mopa_1VLx4VL, uint8_t, uint8_t>(args, qp); }
@@ -84,7 +84,7 @@ static const GemmImplementation<uint8_t, uint8_t, Requantize32> gemm_quint8_meth
{
GemmMethod::GEMM_INTERLEAVED,
"sme2_interleaved_nomerge_u8q_mopa_4VLx1VL",
- [](const GemmArgs &args, const Requantize32 &qp) { return args._ci->has_sme2() && args._maxthreads == 1 && ((qp.per_channel_requant && (qp.per_channel_left_shifts == nullptr)) || (!qp.per_channel_requant && (qp.per_layer_left_shift == 0)));},
+ [](const GemmArgs &args, const Requantize32 &qp) { return args._ci->has_sme2() && ((qp.per_channel_requant && (qp.per_channel_left_shifts == nullptr)) || (!qp.per_channel_requant && (qp.per_layer_left_shift == 0)));},
[](const GemmArgs &args, const Requantize32 &) { const auto VL = sme::get_vector_length<int32_t>();
return args._Nsize <= VL || (2*VL < args._Nsize && args._Nsize <= 3*VL); },
[](const GemmArgs &args, const Requantize32 &qp) { return new GemmInterleavedPretransposedNoMergeQuantizedInline<cls_sme2_interleaved_nomerge_u8q_mopa_4VLx1VL, uint8_t, uint8_t>(args, qp); }
@@ -92,7 +92,7 @@ static const GemmImplementation<uint8_t, uint8_t, Requantize32> gemm_quint8_meth
{
GemmMethod::GEMM_INTERLEAVED,
"sme2_interleaved_nomerge_u8q_mopa_2VLx2VL",
- [](const GemmArgs &args, const Requantize32 &qp) { return args._ci->has_sme2() && args._maxthreads == 1 && ((qp.per_channel_requant && (qp.per_channel_left_shifts == nullptr)) || (!qp.per_channel_requant && (qp.per_layer_left_shift == 0)));},
+ [](const GemmArgs &args, const Requantize32 &qp) { return args._ci->has_sme2() && ((qp.per_channel_requant && (qp.per_channel_left_shifts == nullptr)) || (!qp.per_channel_requant && (qp.per_layer_left_shift == 0)));},
nullptr,
[](const GemmArgs &args, const Requantize32 &qp) { return new GemmInterleavedPretransposedNoMergeQuantizedInline<cls_sme2_interleaved_nomerge_u8q_mopa_2VLx2VL, uint8_t, uint8_t>(args, qp); }
},