diff options
Diffstat (limited to 'src/core/NEON/kernels/arm_gemm/gemm_quint8.cpp')
-rw-r--r-- | src/core/NEON/kernels/arm_gemm/gemm_quint8.cpp | 43 |
1 files changed, 43 insertions, 0 deletions
diff --git a/src/core/NEON/kernels/arm_gemm/gemm_quint8.cpp b/src/core/NEON/kernels/arm_gemm/gemm_quint8.cpp index be7a4ee570..ba9649c0e7 100644 --- a/src/core/NEON/kernels/arm_gemm/gemm_quint8.cpp +++ b/src/core/NEON/kernels/arm_gemm/gemm_quint8.cpp @@ -36,6 +36,14 @@ #include "kernels/a64_smallK_hybrid_u8u32_dot_6x4.hpp" #include "kernels/a64_smallK_hybrid_u8u32_dot_8x4.hpp" +#ifdef ARM_COMPUTE_ENABLE_SVE +#ifdef ARM_COMPUTE_ENABLE_SME2 +#include "kernels/sme2_gemv_u8qa_dot_16VL.hpp" +#include "kernels/sme2_interleaved_nomerge_u8q_mopa_1VLx4VL.hpp" +#include "kernels/sme2_interleaved_nomerge_u8q_mopa_2VLx2VL.hpp" +#include "kernels/sme2_interleaved_nomerge_u8q_mopa_4VLx1VL.hpp" +#endif // ARM_COMPUTE_ENABLE_SME2 + #include "kernels/sve_hybrid_u8qa_dot_4x4VL.hpp" #include "kernels/sve_hybrid_u8qa_mmla_4x4VL.hpp" #include "kernels/sve_hybrid_u8u32_dot_6x4VL.hpp" @@ -43,11 +51,13 @@ #include "kernels/sve_interleaved_u8u32_dot_8x3VL.hpp" #include "kernels/sve_interleaved_u8u32_mmla_8x3VL.hpp" #include "kernels/sve_smallK_hybrid_u8u32_dot_8x1VL.hpp" +#endif // ARM_COMPUTE_ENABLE_SVE #include "gemm_hybrid_indirect.hpp" #include "gemm_hybrid_quantized.hpp" #include "gemm_hybrid_quantized_inline.hpp" #include "gemm_interleaved.hpp" +#include "gemv_pretransposed.hpp" #include "quantize_wrapper.hpp" namespace arm_gemm { @@ -55,6 +65,39 @@ namespace arm_gemm { static const GemmImplementation<uint8_t, uint8_t, Requantize32> gemm_quint8_methods[] = { #ifdef ARM_COMPUTE_ENABLE_SVE +#ifdef ARM_COMPUTE_ENABLE_SME2 +// SME kernels +{ + GemmMethod::GEMM_HYBRID, + "sme2_gemv_u8qa_dot_16VL", + [](const GemmArgs &args, const Requantize32 &qp) { return args._ci->has_sme2() && quant_hybrid_asymmetric(qp) && args._Msize == 1 && !args._indirect_input && args._nbatches == 1; }, + nullptr, + [](const GemmArgs &args, const Requantize32 &qp) { return new GemvPretransposed<cls_sme2_gemv_u8qa_dot_16VL, uint8_t, uint8_t, Requantize32>(args, qp); } +}, +{ + GemmMethod::GEMM_INTERLEAVED, + "sme2_interleaved_nomerge_u8q_mopa_1VLx4VL", + [](const GemmArgs &args, const Requantize32 &qp) { return args._ci->has_sme2() && args._maxthreads == 1 && ((qp.per_channel_requant && (qp.per_channel_left_shifts == nullptr)) || (!qp.per_channel_requant && (qp.per_layer_left_shift == 0)));}, + [](const GemmArgs &args, const Requantize32 &) { const auto VL = sme::get_vector_length<uint32_t>(); + return args._Msize <= VL || (2*VL < args._Msize && args._Msize <= 3*VL); }, + [](const GemmArgs &args, const Requantize32 &qp) { return new GemmInterleavedPretransposedNoMergeQuantizedInline<cls_sme2_interleaved_nomerge_u8q_mopa_1VLx4VL, uint8_t, uint8_t>(args, qp); } +}, +{ + GemmMethod::GEMM_INTERLEAVED, + "sme2_interleaved_nomerge_u8q_mopa_4VLx1VL", + [](const GemmArgs &args, const Requantize32 &qp) { return args._ci->has_sme2() && args._maxthreads == 1 && ((qp.per_channel_requant && (qp.per_channel_left_shifts == nullptr)) || (!qp.per_channel_requant && (qp.per_layer_left_shift == 0)));}, + [](const GemmArgs &args, const Requantize32 &) { const auto VL = sme::get_vector_length<int32_t>(); + return args._Nsize <= VL || (2*VL < args._Nsize && args._Nsize <= 3*VL); }, + [](const GemmArgs &args, const Requantize32 &qp) { return new GemmInterleavedPretransposedNoMergeQuantizedInline<cls_sme2_interleaved_nomerge_u8q_mopa_4VLx1VL, uint8_t, uint8_t>(args, qp); } +}, +{ + GemmMethod::GEMM_INTERLEAVED, + "sme2_interleaved_nomerge_u8q_mopa_2VLx2VL", + [](const GemmArgs &args, const Requantize32 &qp) { return args._ci->has_sme2() && args._maxthreads == 1 && ((qp.per_channel_requant && (qp.per_channel_left_shifts == nullptr)) || (!qp.per_channel_requant && (qp.per_layer_left_shift == 0)));}, + nullptr, + [](const GemmArgs &args, const Requantize32 &qp) { return new GemmInterleavedPretransposedNoMergeQuantizedInline<cls_sme2_interleaved_nomerge_u8q_mopa_2VLx2VL, uint8_t, uint8_t>(args, qp); } +}, +#endif // ARM_COMPUTE_ENABLE_SME2 GemmImplementation<uint8_t, uint8_t, Requantize32>::with_estimate( GemmMethod::GEMM_HYBRID, "sve_hybrid_u8qa_mmla_4x4VL", |