diff options
Diffstat (limited to 'src/core/NEON/kernels/arm_gemm/gemm_int8.cpp')
-rw-r--r-- | src/core/NEON/kernels/arm_gemm/gemm_int8.cpp | 34 |
1 files changed, 34 insertions, 0 deletions
diff --git a/src/core/NEON/kernels/arm_gemm/gemm_int8.cpp b/src/core/NEON/kernels/arm_gemm/gemm_int8.cpp index 24507486ac..38a7c94ef0 100644 --- a/src/core/NEON/kernels/arm_gemm/gemm_int8.cpp +++ b/src/core/NEON/kernels/arm_gemm/gemm_int8.cpp @@ -39,16 +39,50 @@ #include "kernels/a64_smallK_hybrid_s8s32_dot_6x4.hpp" #include "kernels/a64_smallK_hybrid_s8s32_dot_8x4.hpp" +#ifdef ARM_COMPUTE_ENABLE_SVE +#ifdef ARM_COMPUTE_ENABLE_SME2 +#include "kernels/sme2_interleaved_nomerge_s8s32_mopa_1VLx4VL.hpp" +#include "kernels/sme2_interleaved_nomerge_s8s32_mopa_2VLx2VL.hpp" +#include "kernels/sme2_interleaved_nomerge_s8s32_mopa_4VLx1VL.hpp" +#endif // ARM_COMPUTE_ENABLE_SME2 + #include "kernels/sve_hybrid_s8s32_dot_6x4VL.hpp" #include "kernels/sve_hybrid_s8s32_mmla_6x4VL.hpp" #include "kernels/sve_interleaved_s8s32_dot_8x3VL.hpp" #include "kernels/sve_interleaved_s8s32_mmla_8x3VL.hpp" #include "kernels/sve_smallK_hybrid_s8s32_dot_8x1VL.hpp" +#endif // ARM_COMPUTE_ENABLE_SVE namespace arm_gemm { static const GemmImplementation<int8_t, int32_t> gemm_s8_methods[] = { #ifdef ARM_COMPUTE_ENABLE_SVE +#ifdef ARM_COMPUTE_ENABLE_SME2 +// SME kernels +{ + GemmMethod::GEMM_INTERLEAVED, + "sme2_interleaved_nomerge_s8s32_mopa_1VLx4VL", + [](const GemmArgs &args) { return args._ci->has_sme2(); }, + [](const GemmArgs &args) { const auto VL = sme::get_vector_length<int32_t>(); + return args._Msize <= VL || (2*VL < args._Msize && args._Msize <= 3*VL); }, + [](const GemmArgs &args) { return new GemmInterleavedNoMerge<cls_sme2_interleaved_nomerge_s8s32_mopa_1VLx4VL, int8_t, int32_t>(args); } +}, +{ + GemmMethod::GEMM_INTERLEAVED, + "sme2_interleaved_nomerge_s8s32_mopa_4VLx1VL", + [](const GemmArgs &args) { return args._ci->has_sme2(); }, + [](const GemmArgs &args) { const auto VL = sme::get_vector_length<int32_t>(); + return args._Nsize <= VL || (2*VL < args._Nsize && args._Nsize <= 3*VL); }, + [](const GemmArgs &args) { return new GemmInterleavedNoMerge<cls_sme2_interleaved_nomerge_s8s32_mopa_4VLx1VL, int8_t, int32_t>(args); } +}, +{ + GemmMethod::GEMM_INTERLEAVED, + "sme2_interleaved_nomerge_s8s32_mopa_2VLx2VL", + [](const GemmArgs &args) { return args._ci->has_sme2(); }, + nullptr, + [](const GemmArgs &args) { return new GemmInterleavedNoMerge<cls_sme2_interleaved_nomerge_s8s32_mopa_2VLx2VL, int8_t, int32_t>(args); } +}, +#endif // ARM_COMPUTE_ENABLE_SME2 GemmImplementation<int8_t, int32_t>::with_estimate( GemmMethod::GEMM_HYBRID, "sve_hybrid_s8s32_mmla_6x4VL", |