From 4ee8b1599dbaf7634d25607fa5ac96ba3dc6b0f2 Mon Sep 17 00:00:00 2001 From: Georgios Pinitas Date: Fri, 16 Jul 2021 16:16:43 +0100 Subject: Update GEMM assembly kernels - Introduce Fp32 kernels with internal calculations in Bfloat16 when fast_mode is enabled - Improve kernel selection heuristics Signed-off-by: Georgios Pinitas Change-Id: I68a9e7e862b6fd2721b46e0d7cc791091c4ab279 Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/5965 Tested-by: Arm Jenkins Comments-Addressed: Arm Jenkins --- src/core/NEON/kernels/arm_gemm/gemm_fp32.cpp | 118 ++++++++++++++++++--------- 1 file changed, 79 insertions(+), 39 deletions(-) (limited to 'src/core/NEON/kernels/arm_gemm/gemm_fp32.cpp') diff --git a/src/core/NEON/kernels/arm_gemm/gemm_fp32.cpp b/src/core/NEON/kernels/arm_gemm/gemm_fp32.cpp index 1632e301ac..3cf84a614a 100644 --- a/src/core/NEON/kernels/arm_gemm/gemm_fp32.cpp +++ b/src/core/NEON/kernels/arm_gemm/gemm_fp32.cpp @@ -31,17 +31,22 @@ #include "gemv_pretransposed.hpp" #include "kernels/a32_sgemm_8x6.hpp" -#include "kernels/a64_gemv_fp32_mla_32.hpp" +#include "kernels/a64_hybrid_fp32bf16fp32_mmla_4x24.hpp" +#include "kernels/a64_hybrid_fp32bf16fp32_mmla_6x16.hpp" +#include "kernels/a64_hybrid_fp32_mla_4x24.hpp" #include "kernels/a64_hybrid_fp32_mla_6x16.hpp" #include "kernels/a64_hybrid_fp32_mla_8x4.hpp" +#include "kernels/a64_interleaved_bf16fp32_mmla_8x12.hpp" #include "kernels/a64_sgemm_8x12.hpp" #include "kernels/a64_sgemm_8x6.hpp" #include "kernels/a64_smallK_hybrid_fp32_mla_6x4.hpp" #include "kernels/a64_smallK_hybrid_fp32_mla_8x4.hpp" -#include "kernels/sve_gemv_fp32_mla_8VL.hpp" +#include "kernels/sve_hybrid_fp32bf16fp32_mmla_4x6VL.hpp" +#include "kernels/sve_hybrid_fp32bf16fp32_mmla_6x4VL.hpp" #include "kernels/sve_hybrid_fp32_mla_6x4VL.hpp" #include "kernels/sve_hybrid_fp32_mla_8x1VL.hpp" +#include "kernels/sve_interleaved_bf16fp32_mmla_8x3VL.hpp" #include "kernels/sve_interleaved_fp32_mla_8x3VL.hpp" #include "kernels/sve_interleaved_fp32_mmla_8x3VL.hpp" #include "kernels/sve_smallK_hybrid_fp32_mla_8x1VL.hpp" @@ -59,57 +64,94 @@ static const GemmImplementation gemm_fp32_methods[] = [](const GemmArgs &args) { return new GemvBatched(args); } }, #ifdef __aarch64__ +#ifdef ARM_COMPUTE_ENABLE_BF16 +// "fast mode" (BF16) kernels +GemmImplementation::with_estimate( + GemmMethod::GEMM_INTERLEAVED, + "a64_interleaved_bf16fp32_mmla_8x12", + [](const GemmArgs &args) { return args._fast_mode && args._ci->has_bf16(); }, + [](const GemmArgs &args) { return GemmInterleaved::estimate_cycles(args); }, + [](const GemmArgs &args) { return new GemmInterleaved(args); } +), +GemmImplementation::with_estimate( + GemmMethod::GEMM_HYBRID, + "a64_hybrid_fp32bf16fp32_mmla_6x16", + [](const GemmArgs &args) { return args._fast_mode && args._ci->has_bf16(); }, + [](const GemmArgs &args) { return GemmHybridIndirect::estimate_cycles(args); }, + [](const GemmArgs &args) { return new GemmHybridIndirect(args); } +), +GemmImplementation::with_estimate( + GemmMethod::GEMM_HYBRID, + "a64_hybrid_fp32bf16fp32_mmla_4x24", + [](const GemmArgs &args) { return args._fast_mode && args._ci->has_bf16(); }, + [](const GemmArgs &args) { return GemmHybridIndirect::estimate_cycles(args); }, + [](const GemmArgs &args) { return new GemmHybridIndirect(args); } +), +#endif // ARM_COMPUTE_ENABLE_BF16 #ifdef ARM_COMPUTE_ENABLE_SVE -{ +#ifdef ARM_COMPUTE_ENABLE_BF16 +GemmImplementation::with_estimate( + GemmMethod::GEMM_INTERLEAVED, + "sve_interleaved_bf16fp32_mmla_8x3VL", + [](const GemmArgs &args) { return args._fast_mode && args._ci->has_svebf16(); }, + [](const GemmArgs &args) { return GemmInterleaved::estimate_cycles(args); }, + [](const GemmArgs &args) { return new GemmInterleaved(args); } +), +GemmImplementation::with_estimate( GemmMethod::GEMM_HYBRID, - "sve_gemv_fp32_mla_8VL", - [](const GemmArgs &args) { return args._ci->has_sve() && args._Msize==1 && args._nbatches==1 && !args._indirect_input; }, - [](const GemmArgs &args) { return args._ci->get_cpu_model() != CPUModel::KLEIN; }, - [](const GemmArgs &args) { return new GemvPretransposed(args); } -}, -#endif -{ + "sve_hybrid_fp32bf16fp32_mmla_6x4VL", + [](const GemmArgs &args) { return args._fast_mode && args._ci->has_bf16(); }, + [](const GemmArgs &args) { return GemmHybridIndirect::estimate_cycles(args); }, + [](const GemmArgs &args) { return new GemmHybridIndirect(args); } +), +GemmImplementation::with_estimate( GemmMethod::GEMM_HYBRID, - "a64_gemv_fp32_mla_32", - [](const GemmArgs &args) { return args._Msize==1 && args._nbatches==1 && !args._indirect_input; }, - nullptr, - [](const GemmArgs &args) { return new GemvPretransposed(args); } -}, - -// MMLA next due to higher throughput (SVE only) -#if defined(ARM_COMPUTE_ENABLE_SVE) && defined(ARM_COMPUTE_ENABLE_SVEF32MM) + "sve_hybrid_fp32bf16fp32_mmla_4x6VL", + [](const GemmArgs &args) { return args._fast_mode && args._ci->has_bf16(); }, + [](const GemmArgs &args) { return GemmHybridIndirect::estimate_cycles(args); }, + [](const GemmArgs &args) { return new GemmHybridIndirect(args); } +), +#endif // ARM_COMPUTE_ENABLE_BF16 +#ifdef ARM_COMPUTE_ENABLE_SVEF32MM +// MMLA next due to higher throughput (which is SVE only) +// Prefer this in all cases, except if fast mode is requested and BF16 is available. { GemmMethod::GEMM_INTERLEAVED, "sve_interleaved_fp32_mmla_8x3VL", [](const GemmArgs &args) { return args._ci->has_svef32mm() && (args._Ksize>4); }, - [](const GemmArgs &args) { return args._ci->get_cpu_model() != CPUModel::KLEIN; }, + [](const GemmArgs &args) { return !(args._fast_mode && args._ci->has_bf16()); }, [](const GemmArgs &args) { return new GemmInterleaved(args); } }, -#endif // ARM_COMPUTE_ENABLE_SVE && ARM_COMPUTE_ENABLE_SVEF32MM - -#ifdef ARM_COMPUTE_ENABLE_SVE -// SVE smallk / hybrid methods +#endif // ARM_COMPUTE_ENABLE_SVEF32MM +// SVE kernels { GemmMethod::GEMM_HYBRID, "sve_smallK_hybrid_fp32_mla_8x1VL", [](const GemmArgs &args) { return args._ci->has_sve() && args._Ksize <= 24 && !args._indirect_input; }, - [](const GemmArgs &args) { return args._ci->get_cpu_model() != CPUModel::KLEIN; }, + nullptr, [](const GemmArgs &args) { return new GemmHybrid(args); } }, { GemmMethod::GEMM_HYBRID, "sve_hybrid_fp32_mla_8x1VL", [](const GemmArgs &args) { return args._ci->has_sve(); }, - [](const GemmArgs &args) { return args._ci->get_cpu_model() != CPUModel::KLEIN && (args._Nsize < 12); }, + [](const GemmArgs &args) { return (args._Nsize < 12); }, [](const GemmArgs &args) { return new GemmHybridIndirect(args); } }, -{ +GemmImplementation::with_estimate( GemmMethod::GEMM_HYBRID, "sve_hybrid_fp32_mla_6x4VL", [](const GemmArgs &args) { return args._ci->has_sve(); }, - [](const GemmArgs &args) { return args._ci->get_cpu_model() != CPUModel::KLEIN && (((args._Ksize <= 256) && (args._Nsize <= 256)) || ((args._nmulti > 1) && ((args._Msize / args._maxthreads) < 8))); }, + [](const GemmArgs &args) { return GemmHybridIndirect::estimate_cycles(args); }, [](const GemmArgs &args) { return new GemmHybridIndirect(args); } -}, +), +GemmImplementation::with_estimate( + GemmMethod::GEMM_INTERLEAVED, + "sve_interleaved_fp32_mla_8x3VL", + [](const GemmArgs &args) { return args._ci->has_sve(); }, + [](const GemmArgs &args) { return GemmInterleaved::estimate_cycles(args); }, + [](const GemmArgs &args) { return new GemmInterleaved(args); } +), #endif // ARM_COMPUTE_ENABLE_SVE // Cortex-A35 specific kernel - use for any problem on A35, and never in any other cases. { @@ -141,27 +183,25 @@ static const GemmImplementation gemm_fp32_methods[] = [](const GemmArgs &args) { return (args._Nsize < 12); }, [](const GemmArgs &args) { return new GemmHybridIndirect(args); } }, +GemmImplementation::with_estimate( + GemmMethod::GEMM_HYBRID, + "a64_hybrid_fp32_mla_4x24", + nullptr, + [](const GemmArgs &args) { return GemmHybridIndirect::estimate_cycles(args); }, + [](const GemmArgs &args) { return new GemmHybridIndirect(args); } +), GemmImplementation::with_estimate( GemmMethod::GEMM_HYBRID, "a64_hybrid_fp32_mla_6x16", nullptr, - [](const GemmArgs &args) { return GemmHybridIndirect::estimate_cycles(args, cls_a64_hybrid_fp32_mla_6x16::get_performance_parameters(args._ci)); }, + [](const GemmArgs &args) { return GemmHybridIndirect::estimate_cycles(args); }, [](const GemmArgs &args) { return new GemmHybridIndirect(args); } ), -#ifdef ARM_COMPUTE_ENABLE_SVE -{ - GemmMethod::GEMM_INTERLEAVED, - "sve_interleaved_fp32_mla_8x3VL", - [](const GemmArgs &args) { return args._ci->has_sve() && (args._Ksize>4); }, - [](const GemmArgs &args) { return args._ci->get_cpu_model() != CPUModel::KLEIN; }, - [](const GemmArgs &args) { return new GemmInterleaved(args); } -}, -#endif // ARM_COMPUTE_ENABLE_SVE GemmImplementation::with_estimate( GemmMethod::GEMM_INTERLEAVED, "a64_sgemm_8x12", nullptr, - [](const GemmArgs &args) { return GemmInterleaved::estimate_cycles(args, cls_a64_sgemm_8x12::get_performance_parameters(args._ci)); }, + [](const GemmArgs &args) { return GemmInterleaved::estimate_cycles(args); }, [](const GemmArgs &args) { return new GemmInterleaved(args); } ), #endif // __aarch64__ -- cgit v1.2.1