aboutsummaryrefslogtreecommitdiff
path: root/src/core/NEON/kernels/arm_gemm/gemm_fp16.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/core/NEON/kernels/arm_gemm/gemm_fp16.cpp')
-rw-r--r--src/core/NEON/kernels/arm_gemm/gemm_fp16.cpp92
1 files changed, 34 insertions, 58 deletions
diff --git a/src/core/NEON/kernels/arm_gemm/gemm_fp16.cpp b/src/core/NEON/kernels/arm_gemm/gemm_fp16.cpp
index 9194bdd4d4..1a90e96140 100644
--- a/src/core/NEON/kernels/arm_gemm/gemm_fp16.cpp
+++ b/src/core/NEON/kernels/arm_gemm/gemm_fp16.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017-2018 ARM Limited.
+ * Copyright (c) 2017-2019 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -38,75 +38,51 @@
namespace arm_gemm {
-#ifdef __ARM_FEATURE_SVE
-class GemmImpl_gemm_fp16_interleaved_fp16 : public GemmImplementation<__fp16, __fp16> {
-public:
-
- UniqueGemmCommon<__fp16, __fp16> instantiate(const GemmArgs<__fp16> &args) override {
- return UniqueGemmCommon<__fp16, __fp16>(new GemmInterleaved<interleaved_fp16_mla_3VLx8, __fp16, __fp16>(args));
- }
-
- GemmImpl_gemm_fp16_interleaved_fp16() : GemmImplementation<__fp16, __fp16>(GemmMethod::GEMM_INTERLEAVED_FP16) { }
-};
-
-#elif defined(__aarch64__)
-
-#if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) || defined(FP16_KERNELS)
-class GemmImpl_gemm_fp16_interleaved_fp16 : public GemmImplementation<__fp16, __fp16> {
-public:
+static const GemmImplementation<__fp16, __fp16> gemm_fp16_methods[] = {
+#if defined(__ARM_FEATURE_SVE)
+{
+ GemmMethod::GEMM_INTERLEAVED,
+ "interleaved_fp16_mla_3VLx8",
+ [](const GemmArgs<__fp16> &args) { return (args._Ksize > 4); },
+ [](const GemmArgs<__fp16> &args) { return true; },
+ [](const GemmArgs<__fp16> &args) { return new GemmInterleaved<interleaved_fp16_mla_3VLx8, __fp16, __fp16>(args); }
+},
+#endif
+#if defined(__aarch64__) && (defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) || defined(FP16_KERNELS))
+{
+ GemmMethod::GEMM_INTERLEAVED,
+ "hgemm_24x8",
+ [](const GemmArgs<__fp16> &args) {
#ifndef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
- bool is_supported(const GemmArgs<__fp16> &args) override {
return args._ci->has_fp16();
- }
-#endif
-
- UniqueGemmCommon<__fp16, __fp16> instantiate(const GemmArgs<__fp16> &args) override {
- return UniqueGemmCommon<__fp16, __fp16>(new GemmInterleaved<hgemm_24x8, __fp16, __fp16>(args));
- }
-
- GemmImpl_gemm_fp16_interleaved_fp16() : GemmImplementation<__fp16, __fp16>(GemmMethod::GEMM_INTERLEAVED_FP16) { }
-};
-#endif
-
-#endif // __aarch64__
-
-class GemmImpl_gemm_fp16_interleaved : public GemmImplementation<__fp16, __fp16> {
-public:
- UniqueGemmCommon<__fp16, __fp16> instantiate(const GemmArgs<__fp16> &args) override {
-#ifdef __aarch64__
- return UniqueGemmCommon<__fp16, __fp16>(new GemmInterleaved<sgemm_12x8, __fp16, __fp16>(args));
-#elif defined(__arm__)
- return UniqueGemmCommon<__fp16, __fp16>(new GemmInterleaved<sgemm_8x6, __fp16, __fp16>(args));
#else
-# error Unknown Architecture
+ return true;
#endif
- }
-
- GemmImpl_gemm_fp16_interleaved() : GemmImplementation<__fp16, __fp16>(GemmMethod::GEMM_INTERLEAVED) { }
-};
-
-#if defined(__aarch64__) && (defined(__ARM_FEATURE_VECTOR_ARITHMETIC) || defined(FP16_KERNELS) || defined(__ARM_FEATURE_SVE))
-static GemmImpl_gemm_fp16_interleaved_fp16 gemm_fp16_interleaved_fp16_impl{};
-#endif
-static GemmImpl_gemm_fp16_interleaved gemm_fp16_interleaved_impl{};
-
-static std::vector<GemmImplementation<__fp16, __fp16> *> gemm_fp16_methods = {
-#if defined(__aarch64__) && (defined(__ARM_FEATURE_VECTOR_ARITHMETIC) || defined(FP16_KERNELS) || defined(__ARM_FEATURE_SVE))
- &gemm_fp16_interleaved_fp16_impl,
+ },
+ [](const GemmArgs<__fp16> &args) { return true; },
+ [](const GemmArgs<__fp16> &args) { return new GemmInterleaved<hgemm_24x8, __fp16, __fp16>(args); }
+},
#endif
- &gemm_fp16_interleaved_impl
+{
+ GemmMethod::DEFAULT,
+ "",
+ nullptr,
+ nullptr,
+ nullptr,
+}
};
template<>
-std::vector<GemmImplementation<__fp16, __fp16> *> &gemm_implementation_list<__fp16, __fp16>() {
+const GemmImplementation<__fp16, __fp16> *gemm_implementation_list<__fp16, __fp16>() {
return gemm_fp16_methods;
}
/* Explicitly instantiate the external functions for these types. */
-template UniqueGemmCommon<__fp16, __fp16> gemm<__fp16, __fp16>(GemmArgs<__fp16> &args, GemmConfig *cfg);
-template GemmMethod get_gemm_method<__fp16, __fp16>(GemmArgs<__fp16> &args);
-template bool method_is_compatible<__fp16, __fp16>(GemmMethod method, GemmArgs<__fp16> &args);
+template UniqueGemmCommon<__fp16, __fp16> gemm<__fp16, __fp16>(const GemmArgs<__fp16> &args);
+template KernelDescription get_gemm_method<__fp16, __fp16>(const GemmArgs<__fp16> &args);
+template bool method_is_compatible<__fp16, __fp16>(GemmMethod method, const GemmArgs<__fp16> &args);
+template std::vector<std::string> get_compatible_kernels<__fp16, __fp16> (const GemmArgs<__fp16> &args);
} // namespace arm_gemm
-#endif // __ARM_FP16_ARGS
+#endif // __ARM_FP16_ARGS \ No newline at end of file