diff options
Diffstat (limited to 'src/core/NEON/kernels/arm_gemm/gemm_fp16.cpp')
-rw-r--r-- | src/core/NEON/kernels/arm_gemm/gemm_fp16.cpp | 62 |
1 files changed, 30 insertions, 32 deletions
diff --git a/src/core/NEON/kernels/arm_gemm/gemm_fp16.cpp b/src/core/NEON/kernels/arm_gemm/gemm_fp16.cpp index c7adf8e4ac..0d9f53b84d 100644 --- a/src/core/NEON/kernels/arm_gemm/gemm_fp16.cpp +++ b/src/core/NEON/kernels/arm_gemm/gemm_fp16.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017-2020, 2022-2023 Arm Limited. + * Copyright (c) 2017-2020, 2022-2024 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -23,7 +23,7 @@ */ // This can only be built if the target/compiler supports FP16 arguments. -#if defined(__aarch64__) && (defined(FP16_KERNELS) || defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC)) +#if defined(__aarch64__) && (defined(ENABLE_FP16_KERNELS) || defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC)) #include "arm_gemm.hpp" @@ -42,12 +42,10 @@ #include "kernels/a64_hgemm_8x24.hpp" #include "kernels/a64_hybrid_fp16_mla_6x32.hpp" #include "kernels/a64_sgemm_8x12.hpp" -#ifdef ARM_COMPUTE_ENABLE_SME2 #include "kernels/sme2_gemv_fp16fp32fp16_dot_16VL.hpp" #include "kernels/sme2_interleaved_nomerge_fp16fp32fp16_mopa_1VLx4VL.hpp" #include "kernels/sme2_interleaved_nomerge_fp16fp32fp16_mopa_2VLx2VL.hpp" #include "kernels/sme2_interleaved_nomerge_fp16fp32fp16_mopa_4VLx1VL.hpp" -#endif // ARM_COMPUTE_ENABLE_SME2 #ifdef ARM_COMPUTE_ENABLE_FIXED_FORMAT_KERNELS #include "kernels/sve_ffhybrid_fp16_mla_6x4VL.hpp" #include "kernels/sve_ffinterleaved_fp16_mla_8x3VL.hpp" @@ -57,7 +55,7 @@ namespace arm_gemm { -static const GemmImplementation<__fp16, __fp16> gemm_fp16_methods[] = { +static const GemmImplementation<__fp16, __fp16, __fp16> gemm_fp16_methods[] = { #ifdef ARM_COMPUTE_ENABLE_SVE #ifdef ARM_COMPUTE_ENABLE_SME2 { @@ -73,7 +71,7 @@ static const GemmImplementation<__fp16, __fp16> gemm_fp16_methods[] = { [](const GemmArgs &args) { return args._ci->has_sme2(); }, [](const GemmArgs &args) { const auto VL = sme::get_vector_length<float>(); return args._Nsize >= 8*VL || args._Msize <= VL || (2*VL < args._Msize && args._Msize <= 3*VL); }, - [](const GemmArgs &args) { return new GemmInterleaved<cls_sme2_interleaved_nomerge_fp16fp32fp16_mopa_1VLx4VL, __fp16, __fp16, Nothing, false, false, false, true>(args); } + [](const GemmArgs &args) { return new GemmInterleaved<cls_sme2_interleaved_nomerge_fp16fp32fp16_mopa_1VLx4VL, __fp16, __fp16, __fp16, Nothing, false, false, false, true>(args); } }, { GemmMethod::GEMM_INTERLEAVED, @@ -81,32 +79,32 @@ static const GemmImplementation<__fp16, __fp16> gemm_fp16_methods[] = { [](const GemmArgs &args) { return args._ci->has_sme2(); }, [](const GemmArgs &args) { const auto VL = sme::get_vector_length<float>(); return args._Nsize <= VL || (2*VL < args._Nsize && args._Nsize <= 3*VL); }, - [](const GemmArgs &args) { return new GemmInterleaved<cls_sme2_interleaved_nomerge_fp16fp32fp16_mopa_4VLx1VL, __fp16, __fp16, Nothing, false, false, false, true>(args); } + [](const GemmArgs &args) { return new GemmInterleaved<cls_sme2_interleaved_nomerge_fp16fp32fp16_mopa_4VLx1VL, __fp16, __fp16, __fp16, Nothing, false, false, false, true>(args); } }, { GemmMethod::GEMM_INTERLEAVED, "sme2_interleaved_nomerge_fp16fp32fp16_mopa_2VLx2VL", [](const GemmArgs &args) { return args._ci->has_sme2(); }, nullptr, - [](const GemmArgs &args) { return new GemmInterleaved<cls_sme2_interleaved_nomerge_fp16fp32fp16_mopa_2VLx2VL, __fp16, __fp16, Nothing, false, false, false, true>(args); } + [](const GemmArgs &args) { return new GemmInterleaved<cls_sme2_interleaved_nomerge_fp16fp32fp16_mopa_2VLx2VL, __fp16, __fp16, __fp16, Nothing, false, false, false, true>(args); } }, #endif // ARM_COMPUTE_ENABLE_SME2 -GemmImplementation<__fp16, __fp16>::with_estimate( +GemmImplementation<__fp16, __fp16, __fp16>::with_estimate( GemmMethod::GEMM_HYBRID, "sve_hybrid_fp16_mla_6x4VL", [](const GemmArgs &args) { return args._ci->has_sve(); }, - [](const GemmArgs &args) { return GemmHybridIndirect<cls_sve_hybrid_fp16_mla_6x4VL, __fp16, __fp16>::estimate_cycles<__fp16>(args); }, - [](const GemmArgs &args) { return new GemmHybridIndirect<cls_sve_hybrid_fp16_mla_6x4VL, __fp16, __fp16>(args); } + [](const GemmArgs &args) { return GemmHybridIndirect<cls_sve_hybrid_fp16_mla_6x4VL, __fp16, __fp16, __fp16>::estimate_cycles<__fp16>(args); }, + [](const GemmArgs &args) { return new GemmHybridIndirect<cls_sve_hybrid_fp16_mla_6x4VL, __fp16, __fp16, __fp16>(args); } ), -GemmImplementation<__fp16, __fp16>::with_estimate( +GemmImplementation<__fp16, __fp16, __fp16>::with_estimate( GemmMethod::GEMM_INTERLEAVED, "sve_interleaved_fp16_mla_8x3VL", [](const GemmArgs &args) { return args._ci->has_sve() && (args._Ksize > 4); }, - [](const GemmArgs &args) { return GemmInterleaved<cls_sve_interleaved_fp16_mla_8x3VL, __fp16, __fp16>::estimate_cycles<__fp16>(args); }, - [](const GemmArgs &args) { return new GemmInterleaved<cls_sve_interleaved_fp16_mla_8x3VL, __fp16, __fp16>(args); } + [](const GemmArgs &args) { return GemmInterleaved<cls_sve_interleaved_fp16_mla_8x3VL, __fp16, __fp16, __fp16>::estimate_cycles<__fp16>(args); }, + [](const GemmArgs &args) { return new GemmInterleaved<cls_sve_interleaved_fp16_mla_8x3VL, __fp16, __fp16, __fp16>(args); } ), #ifdef ARM_COMPUTE_ENABLE_FIXED_FORMAT_KERNELS -GemmImplementation<__fp16, __fp16>::with_estimate( +GemmImplementation<__fp16, __fp16, __fp16>::with_estimate( GemmMethod::GEMM_INTERLEAVED, "sve_ffinterleaved_fp16_mla_8x3VL", KernelWeightFormat::VL1VL_BL16, @@ -114,7 +112,7 @@ GemmImplementation<__fp16, __fp16>::with_estimate( [](const GemmArgs &args) { return GemmInterleavedFixedFormat<cls_sve_ffinterleaved_fp16_mla_8x3VL, __fp16, __fp16>::estimate_cycles<__fp16>(args); }, [](const GemmArgs &args) { return new GemmInterleavedFixedFormat<cls_sve_ffinterleaved_fp16_mla_8x3VL, __fp16, __fp16>(args); } ), -GemmImplementation<__fp16, __fp16>::with_estimate( +GemmImplementation<__fp16, __fp16, __fp16>::with_estimate( GemmMethod::GEMM_HYBRID, "sve_ffhybrid_fp16_mla_6x4VL", KernelWeightFormat::VL1VL_BL16, @@ -125,22 +123,22 @@ GemmImplementation<__fp16, __fp16>::with_estimate( #endif // ARM_COMPUTE_ENABLE_FIXED_FORMAT_KERNELS #endif // ARM_COMPUTE_ENABLE_SVE #if defined(__aarch64__) -GemmImplementation<__fp16, __fp16>::with_estimate( +GemmImplementation<__fp16, __fp16, __fp16>::with_estimate( GemmMethod::GEMM_HYBRID, "a64_hybrid_fp16_mla_6x32", [](const GemmArgs &args) { return args._ci->has_fp16(); }, - [](const GemmArgs &args) { return GemmHybridIndirect<cls_a64_hybrid_fp16_mla_6x32, __fp16, __fp16>::estimate_cycles<__fp16>(args); }, - [](const GemmArgs &args) { return new GemmHybridIndirect<cls_a64_hybrid_fp16_mla_6x32, __fp16, __fp16>(args); } + [](const GemmArgs &args) { return GemmHybridIndirect<cls_a64_hybrid_fp16_mla_6x32, __fp16, __fp16, __fp16>::estimate_cycles<__fp16>(args); }, + [](const GemmArgs &args) { return new GemmHybridIndirect<cls_a64_hybrid_fp16_mla_6x32, __fp16, __fp16, __fp16>(args); } ), -GemmImplementation<__fp16, __fp16>::with_estimate( +GemmImplementation<__fp16, __fp16, __fp16>::with_estimate( GemmMethod::GEMM_INTERLEAVED, "a64_hgemm_8x24", [](const GemmArgs &args) { return args._ci->has_fp16(); }, - [](const GemmArgs &args) { return GemmInterleaved<cls_a64_hgemm_8x24, __fp16, __fp16>::estimate_cycles<__fp16>(args); }, - [](const GemmArgs &args) { return new GemmInterleaved<cls_a64_hgemm_8x24, __fp16, __fp16>(args); } + [](const GemmArgs &args) { return GemmInterleaved<cls_a64_hgemm_8x24, __fp16, __fp16, __fp16>::estimate_cycles<__fp16>(args); }, + [](const GemmArgs &args) { return new GemmInterleaved<cls_a64_hgemm_8x24, __fp16, __fp16, __fp16>(args); } ), #ifdef ARM_COMPUTE_ENABLE_FIXED_FORMAT_KERNELS -GemmImplementation<__fp16, __fp16>::with_estimate( +GemmImplementation<__fp16, __fp16, __fp16>::with_estimate( GemmMethod::GEMM_INTERLEAVED, "a64_ffinterleaved_fp16_mla_8x24", KernelWeightFormat::VL128_BL16, @@ -148,7 +146,7 @@ GemmImplementation<__fp16, __fp16>::with_estimate( [](const GemmArgs &args) { return GemmInterleavedFixedFormat<cls_a64_ffinterleaved_fp16_mla_8x24, __fp16, __fp16>::estimate_cycles<__fp16>(args); }, [](const GemmArgs &args) { return new GemmInterleavedFixedFormat<cls_a64_ffinterleaved_fp16_mla_8x24, __fp16, __fp16>(args); } ), -GemmImplementation<__fp16, __fp16>::with_estimate( +GemmImplementation<__fp16, __fp16, __fp16>::with_estimate( GemmMethod::GEMM_HYBRID, "a64_ffhybrid_fp16_mla_6x32", KernelWeightFormat::VL128_BL16, @@ -162,7 +160,7 @@ GemmImplementation<__fp16, __fp16>::with_estimate( "a64_sgemm_8x12", nullptr, [](const GemmArgs &args) { return !args._ci->has_fp16(); }, - [](const GemmArgs &args) { return new GemmInterleaved<cls_a64_sgemm_8x12, __fp16, __fp16>(args); } + [](const GemmArgs &args) { return new GemmInterleaved<cls_a64_sgemm_8x12, __fp16, __fp16, __fp16>(args); } }, #elif defined(__arm__) { @@ -170,7 +168,7 @@ GemmImplementation<__fp16, __fp16>::with_estimate( "sgemm_8x6", nullptr, nullptr, - [](const GemmArgs &args) { return new GemmInterleaved<sgemm_8x6, __fp16, __fp16>(args); } + [](const GemmArgs &args) { return new GemmInterleaved<sgemm_8x6, __fp16, __fp16, __fp16>(args); } }, #else // not AArch64 or AArch32 # error Unknown Architecture @@ -185,16 +183,16 @@ GemmImplementation<__fp16, __fp16>::with_estimate( }; template<> -const GemmImplementation<__fp16, __fp16> *gemm_implementation_list<__fp16, __fp16>() { +const GemmImplementation<__fp16, __fp16, __fp16> *gemm_implementation_list<__fp16, __fp16, __fp16>() { return gemm_fp16_methods; } /* Explicitly instantiate the external functions for these types. */ -template UniqueGemmCommon<__fp16, __fp16> gemm<__fp16, __fp16, Nothing>(const GemmArgs &args, const Nothing &); -template bool has_opt_gemm<__fp16, __fp16, Nothing>(WeightFormat &weight_format, const GemmArgs &args, const Nothing &); -template KernelDescription get_gemm_method<__fp16, __fp16, Nothing>(const GemmArgs &args, const Nothing &); -template std::vector<KernelDescription> get_compatible_kernels<__fp16, __fp16, Nothing>(const GemmArgs &args, const Nothing &); +template UniqueGemmCommon<__fp16, __fp16, __fp16> gemm<__fp16, __fp16, __fp16, Nothing>(const GemmArgs &args, const Nothing &); +template bool has_opt_gemm<__fp16, __fp16, __fp16, Nothing>(WeightFormat &weight_format, const GemmArgs &args, const Nothing &); +template KernelDescription get_gemm_method<__fp16, __fp16, __fp16, Nothing>(const GemmArgs &args, const Nothing &); +template std::vector<KernelDescription> get_compatible_kernels<__fp16, __fp16, __fp16, Nothing>(const GemmArgs &args, const Nothing &); } // namespace arm_gemm -#endif // defined(__aarch64__) && (defined(FP16_KERNELS) || defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC)) +#endif // defined(__aarch64__) && (defined(ENABLE_FP16_KERNELS) || defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC)) |