From 03b2971ac69a86f10a1566938d1a25afee15746c Mon Sep 17 00:00:00 2001 From: Viet-Hoa Do Date: Wed, 1 Jun 2022 11:47:14 +0100 Subject: Integrate SME2 kernels * Add SME/SME2 detection. * Integrate SME2 implementation for: - Normal convolution - Winograd - Depthwise convolution - Pooling Resolves: COMPMID-5700 Signed-off-by: Viet-Hoa Do Change-Id: I2f1ca1d05f8cfeee9309ed1c0a36096a4a6aad5c Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/8692 Reviewed-by: Gunes Bayir Tested-by: Arm Jenkins Comments-Addressed: Arm Jenkins --- src/core/NEON/kernels/arm_gemm/gemm_bf16.cpp | 43 ++++++++++++++++++++++++++++ 1 file changed, 43 insertions(+) (limited to 'src/core/NEON/kernels/arm_gemm/gemm_bf16.cpp') diff --git a/src/core/NEON/kernels/arm_gemm/gemm_bf16.cpp b/src/core/NEON/kernels/arm_gemm/gemm_bf16.cpp index 58e4861bc0..515d55c73b 100644 --- a/src/core/NEON/kernels/arm_gemm/gemm_bf16.cpp +++ b/src/core/NEON/kernels/arm_gemm/gemm_bf16.cpp @@ -48,10 +48,20 @@ #include "kernels/sve_ffhybrid_bf16fp32_mmla_6x4VL.hpp" #include "kernels/sve_ffinterleaved_bf16fp32_mmla_8x3VL.hpp" #endif // ARM_COMPUTE_ENABLE_FIXED_FORMAT_KERNELS + +#ifdef ARM_COMPUTE_ENABLE_SVE +#ifdef ARM_COMPUTE_ENABLE_SME2 +#include "kernels/sme2_gemv_bf16fp32_dot_16VL.hpp" +#include "kernels/sme2_interleaved_nomerge_bf16fp32_mopa_1VLx4VL.hpp" +#include "kernels/sme2_interleaved_nomerge_bf16fp32_mopa_2VLx2VL.hpp" +#include "kernels/sme2_interleaved_nomerge_bf16fp32_mopa_4VLx1VL.hpp" +#endif // ARM_COMPUTE_ENABLE_SME2 + #include "kernels/sve_hybrid_bf16fp32_dot_6x4VL.hpp" #include "kernels/sve_hybrid_bf16fp32_mmla_6x4VL.hpp" #include "kernels/sve_interleaved_bf16fp32_dot_8x3VL.hpp" #include "kernels/sve_interleaved_bf16fp32_mmla_8x3VL.hpp" +#endif // ARM_COMPUTE_ENABLE_SVE namespace arm_gemm { @@ -60,6 +70,39 @@ static const GemmImplementation gemm_bf16_methods[] = #ifdef __aarch64__ #ifdef ARM_COMPUTE_ENABLE_BF16 #ifdef ARM_COMPUTE_ENABLE_SVE +#ifdef ARM_COMPUTE_ENABLE_SME2 +// SME kernels +{ + GemmMethod::GEMM_HYBRID, + "sme2_gemv_bf16fp32_dot_16VL", + [](const GemmArgs &args) { return args._ci->has_sme2() && args._Msize==1 && args._nbatches==1 && !args._indirect_input; }, + nullptr, + [](const GemmArgs &args) { return new GemvPretransposed(args); } +}, +{ + GemmMethod::GEMM_INTERLEAVED, + "sme2_interleaved_nomerge_bf16fp32_mopa_1VLx4VL", + [](const GemmArgs &args) { return args._ci->has_sme2(); }, + [](const GemmArgs &args) { const auto VL = sme::get_vector_length(); + return args._Msize <= VL || (2*VL < args._Msize && args._Msize <= 3*VL); }, + [](const GemmArgs &args) { return new GemmInterleavedNoMerge(args); } +}, +{ + GemmMethod::GEMM_INTERLEAVED, + "sme2_interleaved_nomerge_bf16fp32_mopa_4VLx1VL", + [](const GemmArgs &args) { return args._ci->has_sme2(); }, + [](const GemmArgs &args) { const auto VL = sme::get_vector_length(); + return args._Nsize <= VL || (2*VL < args._Nsize && args._Nsize <= 3*VL); }, + [](const GemmArgs &args) { return new GemmInterleavedNoMerge(args); } +}, +{ + GemmMethod::GEMM_INTERLEAVED, + "sme2_interleaved_nomerge_bf16fp32_mopa_2VLx2VL", + [](const GemmArgs &args) { return args._ci->has_sme2(); }, + nullptr, + [](const GemmArgs &args) { return new GemmInterleavedNoMerge(args); } +}, +#endif // ARM_COMPUTE_ENABLE_SME2 // gemm_bf16_interleaved GemmImplementation::with_estimate( GemmMethod::GEMM_INTERLEAVED, -- cgit v1.2.1