aboutsummaryrefslogtreecommitdiff
path: root/src/core/NEON/kernels/arm_gemm/gemm_int8.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/core/NEON/kernels/arm_gemm/gemm_int8.cpp')
-rw-r--r--src/core/NEON/kernels/arm_gemm/gemm_int8.cpp34
1 files changed, 34 insertions, 0 deletions
diff --git a/src/core/NEON/kernels/arm_gemm/gemm_int8.cpp b/src/core/NEON/kernels/arm_gemm/gemm_int8.cpp
index 24507486ac..38a7c94ef0 100644
--- a/src/core/NEON/kernels/arm_gemm/gemm_int8.cpp
+++ b/src/core/NEON/kernels/arm_gemm/gemm_int8.cpp
@@ -39,16 +39,50 @@
#include "kernels/a64_smallK_hybrid_s8s32_dot_6x4.hpp"
#include "kernels/a64_smallK_hybrid_s8s32_dot_8x4.hpp"
+#ifdef ARM_COMPUTE_ENABLE_SVE
+#ifdef ARM_COMPUTE_ENABLE_SME2
+#include "kernels/sme2_interleaved_nomerge_s8s32_mopa_1VLx4VL.hpp"
+#include "kernels/sme2_interleaved_nomerge_s8s32_mopa_2VLx2VL.hpp"
+#include "kernels/sme2_interleaved_nomerge_s8s32_mopa_4VLx1VL.hpp"
+#endif // ARM_COMPUTE_ENABLE_SME2
+
#include "kernels/sve_hybrid_s8s32_dot_6x4VL.hpp"
#include "kernels/sve_hybrid_s8s32_mmla_6x4VL.hpp"
#include "kernels/sve_interleaved_s8s32_dot_8x3VL.hpp"
#include "kernels/sve_interleaved_s8s32_mmla_8x3VL.hpp"
#include "kernels/sve_smallK_hybrid_s8s32_dot_8x1VL.hpp"
+#endif // ARM_COMPUTE_ENABLE_SVE
namespace arm_gemm {
static const GemmImplementation<int8_t, int32_t> gemm_s8_methods[] = {
#ifdef ARM_COMPUTE_ENABLE_SVE
+#ifdef ARM_COMPUTE_ENABLE_SME2
+// SME kernels
+{
+ GemmMethod::GEMM_INTERLEAVED,
+ "sme2_interleaved_nomerge_s8s32_mopa_1VLx4VL",
+ [](const GemmArgs &args) { return args._ci->has_sme2(); },
+ [](const GemmArgs &args) { const auto VL = sme::get_vector_length<int32_t>();
+ return args._Msize <= VL || (2*VL < args._Msize && args._Msize <= 3*VL); },
+ [](const GemmArgs &args) { return new GemmInterleavedNoMerge<cls_sme2_interleaved_nomerge_s8s32_mopa_1VLx4VL, int8_t, int32_t>(args); }
+},
+{
+ GemmMethod::GEMM_INTERLEAVED,
+ "sme2_interleaved_nomerge_s8s32_mopa_4VLx1VL",
+ [](const GemmArgs &args) { return args._ci->has_sme2(); },
+ [](const GemmArgs &args) { const auto VL = sme::get_vector_length<int32_t>();
+ return args._Nsize <= VL || (2*VL < args._Nsize && args._Nsize <= 3*VL); },
+ [](const GemmArgs &args) { return new GemmInterleavedNoMerge<cls_sme2_interleaved_nomerge_s8s32_mopa_4VLx1VL, int8_t, int32_t>(args); }
+},
+{
+ GemmMethod::GEMM_INTERLEAVED,
+ "sme2_interleaved_nomerge_s8s32_mopa_2VLx2VL",
+ [](const GemmArgs &args) { return args._ci->has_sme2(); },
+ nullptr,
+ [](const GemmArgs &args) { return new GemmInterleavedNoMerge<cls_sme2_interleaved_nomerge_s8s32_mopa_2VLx2VL, int8_t, int32_t>(args); }
+},
+#endif // ARM_COMPUTE_ENABLE_SME2
GemmImplementation<int8_t, int32_t>::with_estimate(
GemmMethod::GEMM_HYBRID,
"sve_hybrid_s8s32_mmla_6x4VL",