aboutsummaryrefslogtreecommitdiff
path: root/src/core/NEON/kernels/arm_gemm/gemm_quint8.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/core/NEON/kernels/arm_gemm/gemm_quint8.cpp')
-rw-r--r--src/core/NEON/kernels/arm_gemm/gemm_quint8.cpp7
1 files changed, 4 insertions, 3 deletions
diff --git a/src/core/NEON/kernels/arm_gemm/gemm_quint8.cpp b/src/core/NEON/kernels/arm_gemm/gemm_quint8.cpp
index f93f56b57d..6254ec668d 100644
--- a/src/core/NEON/kernels/arm_gemm/gemm_quint8.cpp
+++ b/src/core/NEON/kernels/arm_gemm/gemm_quint8.cpp
@@ -76,7 +76,7 @@ static const GemmImplementation<uint8_t, uint8_t, Requantize32> gemm_quint8_meth
{
GemmMethod::GEMM_INTERLEAVED,
"sme2_interleaved_nomerge_u8q_mopa_1VLx4VL",
- [](const GemmArgs &args, const Requantize32 &qp) { return args._ci->has_sme2() && ((qp.per_channel_requant && (qp.per_channel_left_shifts == nullptr)) || (!qp.per_channel_requant && (qp.per_layer_left_shift == 0)));},
+ [](const GemmArgs &args, const Requantize32 &qp) { return args._ci->has_sme2() && args._maxthreads == 1 && ((qp.per_channel_requant && (qp.per_channel_left_shifts == nullptr)) || (!qp.per_channel_requant && (qp.per_layer_left_shift == 0)));},
[](const GemmArgs &args, const Requantize32 &) { const auto VL = sme::get_vector_length<uint32_t>();
return args._Msize <= VL || (2*VL < args._Msize && args._Msize <= 3*VL); },
[](const GemmArgs &args, const Requantize32 &qp) { return new GemmInterleavedPretransposedNoMergeQuantizedInline<cls_sme2_interleaved_nomerge_u8q_mopa_1VLx4VL, uint8_t, uint8_t>(args, qp); }
@@ -84,7 +84,7 @@ static const GemmImplementation<uint8_t, uint8_t, Requantize32> gemm_quint8_meth
{
GemmMethod::GEMM_INTERLEAVED,
"sme2_interleaved_nomerge_u8q_mopa_4VLx1VL",
- [](const GemmArgs &args, const Requantize32 &qp) { return args._ci->has_sme2() && ((qp.per_channel_requant && (qp.per_channel_left_shifts == nullptr)) || (!qp.per_channel_requant && (qp.per_layer_left_shift == 0)));},
+ [](const GemmArgs &args, const Requantize32 &qp) { return args._ci->has_sme2() && args._maxthreads == 1 && ((qp.per_channel_requant && (qp.per_channel_left_shifts == nullptr)) || (!qp.per_channel_requant && (qp.per_layer_left_shift == 0)));},
[](const GemmArgs &args, const Requantize32 &) { const auto VL = sme::get_vector_length<int32_t>();
return args._Nsize <= VL || (2*VL < args._Nsize && args._Nsize <= 3*VL); },
[](const GemmArgs &args, const Requantize32 &qp) { return new GemmInterleavedPretransposedNoMergeQuantizedInline<cls_sme2_interleaved_nomerge_u8q_mopa_4VLx1VL, uint8_t, uint8_t>(args, qp); }
@@ -92,7 +92,7 @@ static const GemmImplementation<uint8_t, uint8_t, Requantize32> gemm_quint8_meth
{
GemmMethod::GEMM_INTERLEAVED,
"sme2_interleaved_nomerge_u8q_mopa_2VLx2VL",
- [](const GemmArgs &args, const Requantize32 &qp) { return args._ci->has_sme2() && ((qp.per_channel_requant && (qp.per_channel_left_shifts == nullptr)) || (!qp.per_channel_requant && (qp.per_layer_left_shift == 0)));},
+ [](const GemmArgs &args, const Requantize32 &qp) { return args._ci->has_sme2() && args._maxthreads == 1 && ((qp.per_channel_requant && (qp.per_channel_left_shifts == nullptr)) || (!qp.per_channel_requant && (qp.per_layer_left_shift == 0)));},
nullptr,
[](const GemmArgs &args, const Requantize32 &qp) { return new GemmInterleavedPretransposedNoMergeQuantizedInline<cls_sme2_interleaved_nomerge_u8q_mopa_2VLx2VL, uint8_t, uint8_t>(args, qp); }
},
@@ -233,6 +233,7 @@ const GemmImplementation<uint8_t, uint8_t, Requantize32> *gemm_implementation_li
template UniqueGemmCommon<uint8_t, uint8_t> gemm<uint8_t, uint8_t, Requantize32>(const GemmArgs &args, const Requantize32 &os);
template bool has_opt_gemm<uint8_t, uint8_t, Requantize32>(WeightFormat &weight_format, const GemmArgs &args, const Requantize32 &os);
+template KernelDescription get_gemm_method<uint8_t, uint8_t, Requantize32>(const GemmArgs &args, const Requantize32 &os);
template std::vector<KernelDescription> get_compatible_kernels<uint8_t, uint8_t, Requantize32>(const GemmArgs &args, const Requantize32 &os);
} // namespace arm_gemm