aboutsummaryrefslogtreecommitdiff
path: root/src/core/NEON/kernels/arm_gemm/gemm_fp16.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/core/NEON/kernels/arm_gemm/gemm_fp16.cpp')
-rw-r--r--src/core/NEON/kernels/arm_gemm/gemm_fp16.cpp45
1 files changed, 45 insertions, 0 deletions
diff --git a/src/core/NEON/kernels/arm_gemm/gemm_fp16.cpp b/src/core/NEON/kernels/arm_gemm/gemm_fp16.cpp
index 42f4528066..2796b0d204 100644
--- a/src/core/NEON/kernels/arm_gemm/gemm_fp16.cpp
+++ b/src/core/NEON/kernels/arm_gemm/gemm_fp16.cpp
@@ -34,9 +34,17 @@
#include "gemm_interleaved.hpp"
#include "kernels/a32_sgemm_8x6.hpp"
+#ifdef ENABLE_FIXED_FORMAT_KERNELS
+#include "kernels/a64_ffhybrid_fp16_mla_6x32.hpp"
+#include "kernels/a64_ffinterleaved_fp16_mla_8x24.hpp"
+#endif // ENABLE_FIXED_FORMAT_KERNELS
#include "kernels/a64_hgemm_8x24.hpp"
#include "kernels/a64_hybrid_fp16_mla_6x32.hpp"
#include "kernels/a64_sgemm_8x12.hpp"
+#ifdef ENABLE_FIXED_FORMAT_KERNELS
+#include "kernels/sve_ffhybrid_fp16_mla_6x4VL.hpp"
+#include "kernels/sve_ffinterleaved_fp16_mla_8x3VL.hpp"
+#endif // ENABLE_FIXED_FORMAT_KERNELS
#include "kernels/sve_hybrid_fp16_mla_6x4VL.hpp"
#include "kernels/sve_interleaved_fp16_mla_8x3VL.hpp"
@@ -58,6 +66,24 @@ GemmImplementation<__fp16, __fp16>::with_estimate(
[](const GemmArgs &args) { return GemmInterleaved<cls_sve_interleaved_fp16_mla_8x3VL, __fp16, __fp16>::estimate_cycles<__fp16>(args); },
[](const GemmArgs &args) { return new GemmInterleaved<cls_sve_interleaved_fp16_mla_8x3VL, __fp16, __fp16>(args); }
),
+#ifdef ENABLE_FIXED_FORMAT_KERNELS
+GemmImplementation<__fp16, __fp16>::with_estimate(
+ GemmMethod::GEMM_INTERLEAVED,
+ "sve_ffinterleaved_fp16_mla_8x3VL",
+ KernelWeightFormat::VL1VL_BL16,
+ [](const GemmArgs &args) { return args._ci->has_sve(); },
+ [](const GemmArgs &args) { return GemmInterleavedFixedFormat<cls_sve_ffinterleaved_fp16_mla_8x3VL, __fp16, __fp16>::estimate_cycles<__fp16>(args); },
+ [](const GemmArgs &args) { return new GemmInterleavedFixedFormat<cls_sve_ffinterleaved_fp16_mla_8x3VL, __fp16, __fp16>(args); }
+),
+GemmImplementation<__fp16, __fp16>::with_estimate(
+ GemmMethod::GEMM_HYBRID,
+ "sve_ffhybrid_fp16_mla_6x4VL",
+ KernelWeightFormat::VL1VL_BL16,
+ [](const GemmArgs &args) { return args._ci->has_sve(); },
+ [](const GemmArgs &args) { return GemmHybridIndirectFixedFormat<cls_sve_ffhybrid_fp16_mla_6x4VL, __fp16, __fp16>::estimate_cycles<__fp16>(args); },
+ [](const GemmArgs &args) { return new GemmHybridIndirectFixedFormat<cls_sve_ffhybrid_fp16_mla_6x4VL, __fp16, __fp16>(args); }
+),
+#endif // ENABLE_FIXED_FORMAT_KERNELS
#endif // ARM_COMPUTE_ENABLE_SVE
#if defined(__aarch64__)
GemmImplementation<__fp16, __fp16>::with_estimate(
@@ -74,6 +100,24 @@ GemmImplementation<__fp16, __fp16>::with_estimate(
[](const GemmArgs &args) { return GemmInterleaved<cls_a64_hgemm_8x24, __fp16, __fp16>::estimate_cycles<__fp16>(args); },
[](const GemmArgs &args) { return new GemmInterleaved<cls_a64_hgemm_8x24, __fp16, __fp16>(args); }
),
+#ifdef ENABLE_FIXED_FORMAT_KERNELS
+GemmImplementation<__fp16, __fp16>::with_estimate(
+ GemmMethod::GEMM_INTERLEAVED,
+ "a64_ffinterleaved_fp16_mla_8x24",
+ KernelWeightFormat::VL128_BL16,
+ [](const GemmArgs &args) { return args._ci->has_fp16(); },
+ [](const GemmArgs &args) { return GemmInterleavedFixedFormat<cls_a64_ffinterleaved_fp16_mla_8x24, __fp16, __fp16>::estimate_cycles<__fp16>(args); },
+ [](const GemmArgs &args) { return new GemmInterleavedFixedFormat<cls_a64_ffinterleaved_fp16_mla_8x24, __fp16, __fp16>(args); }
+),
+GemmImplementation<__fp16, __fp16>::with_estimate(
+ GemmMethod::GEMM_HYBRID,
+ "a64_ffhybrid_fp16_mla_6x32",
+ KernelWeightFormat::VL128_BL16,
+ [](const GemmArgs &args) { return args._ci->has_fp16(); },
+ [](const GemmArgs &args) { return GemmHybridIndirectFixedFormat<cls_a64_ffhybrid_fp16_mla_6x32, __fp16, __fp16>::estimate_cycles<__fp16>(args); },
+ [](const GemmArgs &args) { return new GemmHybridIndirectFixedFormat<cls_a64_ffhybrid_fp16_mla_6x32, __fp16, __fp16>(args); }
+),
+#endif // ENABLE_FIXED_FORMAT_KERNELS
{
GemmMethod::GEMM_INTERLEAVED,
"a64_sgemm_8x12",
@@ -109,6 +153,7 @@ const GemmImplementation<__fp16, __fp16> *gemm_implementation_list<__fp16, __fp1
/* Explicitly instantiate the external functions for these types. */
template UniqueGemmCommon<__fp16, __fp16> gemm<__fp16, __fp16, Nothing>(const GemmArgs &args, const Nothing &);
template bool has_opt_gemm<__fp16, __fp16, Nothing>(const GemmArgs &args, const Nothing &);
+template KernelDescription get_gemm_method<__fp16, __fp16, Nothing>(const GemmArgs &args, const Nothing &);
template std::vector<KernelDescription> get_compatible_kernels<__fp16, __fp16, Nothing>(const GemmArgs &args, const Nothing &);
} // namespace arm_gemm