diff options
author | Francesco.Petrogalli@arm.com <francesco.petrogalli@arm.com> | 2022-04-05 10:31:08 +0000 |
---|---|---|
committer | Francesco Petrogalli <francesco.petrogalli@arm.com> | 2022-05-24 14:28:27 +0000 |
commit | 5fcf22dadf092efd7aafb359f9229aa270eb1129 (patch) | |
tree | f309426ed19bd6710329da3b530167db72d1c6b2 /src/core/NEON/kernels/arm_gemm/gemm_fp32.cpp | |
parent | a8caa023f0d7b71b3a250a14ceee935052fcc74a (diff) | |
download | ComputeLibrary-5fcf22dadf092efd7aafb359f9229aa270eb1129.tar.gz |
[arm_gemm] Import fixed-format kernels from gemm_linux.
This is a No Functional Change Intended (NFCI) patch. It imports the
kernel in the code, but the interface to select them and expose the
format of the weight tensors to the user will be provided in a
subsequent patch.
Kernels and kernel selection code in arm_gemm has been provided
by David.Mansell <David.Mansell@arm.com>.
The kernels are not compiled in the library by default, but need to be
selected via the `scons` option `experimental_fixed_format_kernels=1`.
Resolves: ONCPUML-829
Signed-off-by: Francesco.Petrogalli@arm.com <francesco.petrogalli@arm.com>
Change-Id: If00ccb2b9b7221e01b214cf9783111226ccc8bf4
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/7380
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Gian Marco Iodice <gianmarco.iodice@arm.com>
Reviewed-by: SiCong Li <sicong.li@arm.com>
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Diffstat (limited to 'src/core/NEON/kernels/arm_gemm/gemm_fp32.cpp')
-rw-r--r-- | src/core/NEON/kernels/arm_gemm/gemm_fp32.cpp | 87 |
1 files changed, 87 insertions, 0 deletions
diff --git a/src/core/NEON/kernels/arm_gemm/gemm_fp32.cpp b/src/core/NEON/kernels/arm_gemm/gemm_fp32.cpp index 69a2803903..4f7e191fb3 100644 --- a/src/core/NEON/kernels/arm_gemm/gemm_fp32.cpp +++ b/src/core/NEON/kernels/arm_gemm/gemm_fp32.cpp @@ -31,6 +31,12 @@ #include "gemv_pretransposed.hpp" #include "kernels/a32_sgemm_8x6.hpp" +#ifdef ENABLE_FIXED_FORMAT_KERNELS +#include "kernels/a64_ffhybrid_fp32_mla_6x16.hpp" +#include "kernels/a64_ffhybrid_fp32bf16fp32_mmla_4x24.hpp" +#include "kernels/a64_ffinterleaved_bf16fp32_mmla_8x12.hpp" +#include "kernels/a64_ffinterleaved_fp32_mla_8x12.hpp" +#endif // ENABLE_FIXED_FORMAT_KERNELS #include "kernels/a64_hybrid_fp32bf16fp32_mmla_4x24.hpp" #include "kernels/a64_hybrid_fp32bf16fp32_mmla_6x16.hpp" #include "kernels/a64_hybrid_fp32_mla_4x24.hpp" @@ -42,6 +48,12 @@ #include "kernels/a64_smallK_hybrid_fp32_mla_6x4.hpp" #include "kernels/a64_smallK_hybrid_fp32_mla_8x4.hpp" +#ifdef ENABLE_FIXED_FORMAT_KERNELS +#include "kernels/sve_ffhybrid_fp32_mla_6x4VL.hpp" +#include "kernels/sve_ffhybrid_fp32bf16fp32_mmla_4x6VL.hpp" +#include "kernels/sve_ffinterleaved_fp32_mla_8x3VL.hpp" +#include "kernels/sve_ffinterleaved_bf16fp32_mmla_8x3VL.hpp" +#endif // ENABLE_FIXED_FORMAT_KERNELS #include "kernels/sve_hybrid_fp32bf16fp32_mmla_4x6VL.hpp" #include "kernels/sve_hybrid_fp32bf16fp32_mmla_6x4VL.hpp" #include "kernels/sve_hybrid_fp32_mla_6x4VL.hpp" @@ -73,6 +85,7 @@ GemmImplementation<float, float>::with_estimate( [](const GemmArgs &args) { return GemmInterleaved<cls_a64_interleaved_bf16fp32_mmla_8x12, float, float>::estimate_cycles<float>(args); }, [](const GemmArgs &args) { return new GemmInterleaved<cls_a64_interleaved_bf16fp32_mmla_8x12, float, float>(args); } ), + GemmImplementation<float, float>::with_estimate( GemmMethod::GEMM_HYBRID, "a64_hybrid_fp32bf16fp32_mmla_6x16", @@ -152,6 +165,42 @@ GemmImplementation<float, float>::with_estimate( [](const GemmArgs &args) { return GemmInterleaved<cls_sve_interleaved_fp32_mla_8x3VL, float, float>::estimate_cycles<float>(args); }, [](const GemmArgs &args) { return new GemmInterleaved<cls_sve_interleaved_fp32_mla_8x3VL, float, float>(args); } ), + #ifdef ENABLE_FIXED_FORMAT_KERNELS +#ifdef ARM_COMPUTE_ENABLE_BF16 +GemmImplementation<float, float>::with_estimate( + GemmMethod::GEMM_INTERLEAVED, + "sve_ffinterleaved_bf16fp32_mmla_8x3VL", + KernelWeightFormat::VL2VL_BL64_BF16, + [](const GemmArgs &args) { return args._fast_mode && args._ci->has_svebf16(); }, + [](const GemmArgs &args) { return GemmInterleavedFixedFormat<cls_sve_ffinterleaved_bf16fp32_mmla_8x3VL, float, float>::estimate_cycles<float>(args); }, + [](const GemmArgs &args) { return new GemmInterleavedFixedFormat<cls_sve_ffinterleaved_bf16fp32_mmla_8x3VL, float, float>(args); } +), +GemmImplementation<float, float>::with_estimate( + GemmMethod::GEMM_HYBRID, + "sve_ffhybrid_fp32bf16fp32_mmla_4x6VL", + KernelWeightFormat::VL2VL_BL64_BF16, + [](const GemmArgs &args) { return args._fast_mode && args._ci->has_svebf16(); }, + [](const GemmArgs &args) { return GemmHybridIndirectFixedFormat<cls_sve_ffhybrid_fp32bf16fp32_mmla_4x6VL, float, float>::estimate_cycles<float>(args); }, + [](const GemmArgs &args) { return new GemmHybridIndirectFixedFormat<cls_sve_ffhybrid_fp32bf16fp32_mmla_4x6VL, float, float>(args); } +), +#endif +GemmImplementation<float, float>::with_estimate( + GemmMethod::GEMM_INTERLEAVED, + "sve_ffinterleaved_fp32_mla_8x3VL", + KernelWeightFormat::VL1VL_BL32, + [](const GemmArgs &args) { return args._ci->has_sve(); }, + [](const GemmArgs &args) { return GemmInterleavedFixedFormat<cls_sve_ffinterleaved_fp32_mla_8x3VL, float, float>::estimate_cycles<float>(args); }, + [](const GemmArgs &args) { return new GemmInterleavedFixedFormat<cls_sve_ffinterleaved_fp32_mla_8x3VL, float, float>(args); } +), +GemmImplementation<float, float>::with_estimate( + GemmMethod::GEMM_HYBRID, + "sve_ffhybrid_fp32_mla_6x4VL", + KernelWeightFormat::VL1VL_BL32, + [](const GemmArgs &args) { return args._ci->has_sve(); }, + [](const GemmArgs &args) { return GemmHybridIndirectFixedFormat<cls_sve_ffhybrid_fp32_mla_6x4VL, float, float>::estimate_cycles<float>(args); }, + [](const GemmArgs &args) { return new GemmHybridIndirectFixedFormat<cls_sve_ffhybrid_fp32_mla_6x4VL, float, float>(args); } +), +#endif // ENABLE_FIXED_FORMAT_KERNELS #endif // ARM_COMPUTE_ENABLE_SVE // Cortex-A35 specific kernel - use for any problem on A35, and never in any other cases. { @@ -204,6 +253,43 @@ GemmImplementation<float, float>::with_estimate( [](const GemmArgs &args) { return GemmInterleaved<cls_a64_sgemm_8x12, float, float>::estimate_cycles<float>(args); }, [](const GemmArgs &args) { return new GemmInterleaved<cls_a64_sgemm_8x12, float, float>(args); } ), +#ifdef ENABLE_FIXED_FORMAT_KERNELS +#ifdef ARM_COMPUTE_ENABLE_BF16 +// "fast mode" (BF16) kernels +GemmImplementation<float, float>::with_estimate( + GemmMethod::GEMM_INTERLEAVED, + "a64_ffinterleaved_bf16fp32_mmla_8x12", + KernelWeightFormat::VL256_BL64_BF16, + [](const GemmArgs &args) { return args._fast_mode && args._ci->has_bf16(); }, + [](const GemmArgs &args) { return GemmInterleavedFixedFormat<cls_a64_ffinterleaved_bf16fp32_mmla_8x12, float, float>::estimate_cycles<float>(args); }, + [](const GemmArgs &args) { return new GemmInterleavedFixedFormat<cls_a64_ffinterleaved_bf16fp32_mmla_8x12, float, float>(args); } +), +GemmImplementation<float, float>::with_estimate( + GemmMethod::GEMM_HYBRID, + "a64_ffhybrid_fp32bf16fp32_mmla_4x24", + KernelWeightFormat::VL256_BL64_BF16, + [](const GemmArgs &args) { return args._fast_mode && args._ci->has_bf16(); }, + [](const GemmArgs &args) { return GemmHybridIndirectFixedFormat<cls_a64_ffhybrid_fp32bf16fp32_mmla_4x24, float, float>::estimate_cycles<float>(args); }, + [](const GemmArgs &args) { return new GemmHybridIndirectFixedFormat<cls_a64_ffhybrid_fp32bf16fp32_mmla_4x24, float, float>(args); } +), +#endif // BF16 +GemmImplementation<float, float>::with_estimate( + GemmMethod::GEMM_INTERLEAVED, + "a64_ffinterleaved_fp32_mla_8x12", + KernelWeightFormat::VL128_BL32, + nullptr, + [](const GemmArgs &args) { return GemmInterleavedFixedFormat<cls_a64_ffinterleaved_fp32_mla_8x12, float, float>::estimate_cycles<float>(args); }, + [](const GemmArgs &args) { return new GemmInterleavedFixedFormat<cls_a64_ffinterleaved_fp32_mla_8x12, float, float>(args); } +), +GemmImplementation<float, float>::with_estimate( + GemmMethod::GEMM_HYBRID, + "a64_ffhybrid_fp32_mla_6x16", + KernelWeightFormat::VL128_BL32, + nullptr, + [](const GemmArgs &args) { return GemmHybridIndirectFixedFormat<cls_a64_ffhybrid_fp32_mla_6x16, float, float>::estimate_cycles<float>(args); }, + [](const GemmArgs &args) { return new GemmHybridIndirectFixedFormat<cls_a64_ffhybrid_fp32_mla_6x16, float, float>(args); } +), +#endif // ENABLE_FIXED_FORMAT_KERNELS #endif // __aarch64__ #ifdef __arm__ @@ -233,6 +319,7 @@ const GemmImplementation<float, float> *gemm_implementation_list<float, float>() /* Explicitly instantiate the external functions for these types. */ template UniqueGemmCommon<float, float> gemm<float, float, Nothing>(const GemmArgs &args, const Nothing &); template bool has_opt_gemm<float, float, Nothing>(const GemmArgs &args, const Nothing &); +template KernelDescription get_gemm_method<float, float, Nothing>(const GemmArgs &args, const Nothing &); template std::vector<KernelDescription> get_compatible_kernels<float, float, Nothing> (const GemmArgs &args, const Nothing &); } // namespace arm_gemm |