aboutsummaryrefslogtreecommitdiff
path: root/src/core/NEON/kernels/arm_gemm/gemm_qint8.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/core/NEON/kernels/arm_gemm/gemm_qint8.cpp')
-rw-r--r--src/core/NEON/kernels/arm_gemm/gemm_qint8.cpp42
1 files changed, 42 insertions, 0 deletions
diff --git a/src/core/NEON/kernels/arm_gemm/gemm_qint8.cpp b/src/core/NEON/kernels/arm_gemm/gemm_qint8.cpp
index 1d7b9c5b73..ac49536643 100644
--- a/src/core/NEON/kernels/arm_gemm/gemm_qint8.cpp
+++ b/src/core/NEON/kernels/arm_gemm/gemm_qint8.cpp
@@ -38,6 +38,14 @@
#include "kernels/a64_smallK_hybrid_s8s32_dot_6x4.hpp"
#include "kernels/a64_smallK_hybrid_s8s32_dot_8x4.hpp"
+#ifdef ARM_COMPUTE_ENABLE_SVE
+#ifdef ARM_COMPUTE_ENABLE_SME2
+#include "kernels/sme2_gemv_s8qa_dot_16VL.hpp"
+#include "kernels/sme2_interleaved_nomerge_s8q_mopa_1VLx4VL.hpp"
+#include "kernels/sme2_interleaved_nomerge_s8q_mopa_2VLx2VL.hpp"
+#include "kernels/sme2_interleaved_nomerge_s8q_mopa_4VLx1VL.hpp"
+#endif // ARM_COMPUTE_ENABLE_SME2
+
#include "kernels/sve_hybrid_s8qa_dot_4x4VL.hpp"
#include "kernels/sve_hybrid_s8qa_mmla_4x4VL.hpp"
#include "kernels/sve_hybrid_s8qs_dot_6x4VL.hpp"
@@ -47,11 +55,13 @@
#include "kernels/sve_interleaved_s8s32_dot_8x3VL.hpp"
#include "kernels/sve_interleaved_s8s32_mmla_8x3VL.hpp"
#include "kernels/sve_smallK_hybrid_s8s32_dot_8x1VL.hpp"
+#endif // ARM_COMPUTE_ENABLE_SVE
#include "gemm_hybrid_indirect.hpp"
#include "gemm_hybrid_quantized.hpp"
#include "gemm_hybrid_quantized_inline.hpp"
#include "gemm_interleaved.hpp"
+#include "gemv_pretransposed.hpp"
#include "quantize_wrapper.hpp"
#include "utils.hpp"
@@ -60,6 +70,38 @@ namespace arm_gemm {
static const GemmImplementation<int8_t, int8_t, Requantize32> gemm_qint8_methods[] =
{
#ifdef ARM_COMPUTE_ENABLE_SVE
+#ifdef ARM_COMPUTE_ENABLE_SME2
+{
+ GemmMethod::GEMM_HYBRID,
+ "sme2_gemv_s8qa_dot_16VL",
+ [](const GemmArgs &args, const Requantize32 &qp) { return args._ci->has_sme2() && quant_hybrid_asymmetric(qp) && args._Msize == 1 && !args._indirect_input && args._nbatches == 1; },
+ nullptr,
+ [](const GemmArgs &args, const Requantize32 &qp) { return new GemvPretransposed<cls_sme2_gemv_s8qa_dot_16VL, int8_t, int8_t, Requantize32>(args, qp); }
+},
+{
+ GemmMethod::GEMM_INTERLEAVED,
+ "sme2_interleaved_nomerge_s8q_mopa_1VLx4VL",
+ [](const GemmArgs &args, const Requantize32 &qp) { return args._ci->has_sme2() && args._maxthreads == 1 && ((qp.per_channel_requant && (qp.per_channel_left_shifts == nullptr)) || (!qp.per_channel_requant && (qp.per_layer_left_shift == 0)));},
+ [](const GemmArgs &args, const Requantize32 &) { const auto VL = sme::get_vector_length<int32_t>();
+ return args._Msize <= VL || (2*VL < args._Msize && args._Msize <= 3*VL); },
+ [](const GemmArgs &args, const Requantize32 &qp) { return new GemmInterleavedPretransposedNoMergeQuantizedInline<cls_sme2_interleaved_nomerge_s8q_mopa_1VLx4VL, int8_t, int8_t>(args, qp); }
+},
+{
+ GemmMethod::GEMM_INTERLEAVED,
+ "sme2_interleaved_nomerge_s8q_mopa_4VLx1VL",
+ [](const GemmArgs &args, const Requantize32 &qp) { return args._ci->has_sme2() && args._maxthreads == 1 && ((qp.per_channel_requant && (qp.per_channel_left_shifts == nullptr)) || (!qp.per_channel_requant && (qp.per_layer_left_shift == 0)));},
+ [](const GemmArgs &args, const Requantize32 &) { const auto VL = sme::get_vector_length<int32_t>();
+ return args._Nsize <= VL || (2*VL < args._Nsize && args._Nsize <= 3*VL); },
+ [](const GemmArgs &args, const Requantize32 &qp) { return new GemmInterleavedPretransposedNoMergeQuantizedInline<cls_sme2_interleaved_nomerge_s8q_mopa_4VLx1VL, int8_t, int8_t>(args, qp); }
+},
+{
+ GemmMethod::GEMM_INTERLEAVED,
+ "sme2_interleaved_nomerge_s8q_mopa_2VLx2VL",
+ [](const GemmArgs &args, const Requantize32 &qp) { return args._ci->has_sme2() && args._maxthreads == 1 && ((qp.per_channel_requant && (qp.per_channel_left_shifts == nullptr)) || (!qp.per_channel_requant && (qp.per_layer_left_shift == 0)));},
+ nullptr,
+ [](const GemmArgs &args, const Requantize32 &qp) { return new GemmInterleavedPretransposedNoMergeQuantizedInline<cls_sme2_interleaved_nomerge_s8q_mopa_2VLx2VL, int8_t, int8_t>(args, qp); }
+},
+#endif // ARM_COMPUTE_ENABLE_SME2
GemmImplementation<int8_t, int8_t, Requantize32>::with_estimate(
GemmMethod::GEMM_HYBRID,
"sve_hybrid_s8qa_mmla_4x4VL",