about summary refs log tree commit diff
path: root/src/core/NEON/kernels/arm_gemm/gemm_fp32.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/core/NEON/kernels/arm_gemm/gemm_fp32.cpp')
-rw-r--r--  src/core/NEON/kernels/arm_gemm/gemm_fp32.cpp | 211
1 files changed, 120 insertions, 91 deletions
diff --git a/src/core/NEON/kernels/arm_gemm/gemm_fp32.cpp b/src/core/NEON/kernels/arm_gemm/gemm_fp32.cpp
index 7d14971b70..8bc33ccb69 100644
--- a/src/core/NEON/kernels/arm_gemm/gemm_fp32.cpp
+++ b/src/core/NEON/kernels/arm_gemm/gemm_fp32.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017-2018 ARM Limited.
+ * Copyright (c) 2017-2019 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -23,6 +23,7 @@
*/
#include "arm_gemm.hpp"
#include "gemm_common.hpp"
+#include "gemm_hybrid.hpp"
#include "gemm_implementation.hpp"
#include "gemm_interleaved.hpp"
#include "gemm_native.hpp"
@@ -30,112 +31,140 @@
#include "gemv_native_transposed.hpp"
#include "gemv_pretransposed.hpp"
-#include "kernels/a64_sgemm_12x8.hpp"
#include "kernels/a32_sgemm_8x6.hpp"
-#include "kernels/a64_sgemv_trans.hpp"
-#include "kernels/a64_sgemv_pretransposed.hpp"
+#include "kernels/a64_sgemm_12x8.hpp"
#include "kernels/a64_sgemm_native_16x4.hpp"
+#include "kernels/a64_sgemm_nativeA_pretransposeB_16x4.hpp"
+#include "kernels/a64_sgemv_pretransposed.hpp"
+#include "kernels/a64_sgemv_trans.hpp"
+#include "kernels/sve_hybrid_fp32_mla_4VLx4.hpp"
#include "kernels/sve_interleaved_fp32_mla_3VLx8.hpp"
+#include "kernels/sve_native_fp32_mla_4VLx4.hpp"
+#include "kernels/sve_smallK_fp32_mla_1VLx4.hpp"
+#include "kernels/sve_smallK_hybrid_fp32_mla_1VLx4.hpp"
namespace arm_gemm {
-#if defined(__aarch64__) && !defined(__ARM_FEATURE_SVE)
-// SGEMM implementations for AArch64 without SVE
-
-// Pretransposed GEMV
-class GemmImpl_sgemm_gemv_pretransposed : public GemmImplementation<float, float> {
-public:
- bool is_supported(const GemmArgs<float> &args) override {
- return (args._Msize==1 && args._alpha==1.0f && args._pretransposed_hint && args._nbatches==1);
- }
+static const GemmImplementation<float, float> gemm_fp32_methods[] =
+{
+{
+ GemmMethod::GEMV_BATCHED,
+ "gemv_batched",
+ [](const GemmArgs<float> &args) { return (args._Msize==1) && (args._nbatches>1); },
+ nullptr,
+ [](const GemmArgs<float> &args) { return new GemvBatched<float, float>(args); }
+},
+#ifdef __aarch64__
+{
+ GemmMethod::GEMV_PRETRANSPOSED,
+ "sgemv_pretransposed",
+ [](const GemmArgs<float> &args) { return (args._Msize==1 && args._alpha==1.0f && args._pretransposed_hint && args._nbatches==1); },
+ nullptr,
+ [](const GemmArgs<float> &args) { return new GemvPretransposed<sgemv_pretransposed, float, float>(args); }
+},
+{
+ GemmMethod::GEMV_NATIVE_TRANSPOSED,
+ "sgemv_trans",
+ [](const GemmArgs<float> &args) { return (args._Msize==1 && args._alpha==1.0f && !args._trA && !args._trB && args._nbatches==1); },
+ nullptr,
+ [](const GemmArgs<float> &args) { return new GemvNativeTransposed<sgemv_trans, float, float>(args); }
+},
- UniqueGemmCommon<float, float> instantiate(const GemmArgs<float> &args) override {
- return UniqueGemmCommon<float, float> (new GemvPretransposed<sgemv_pretransposed, float, float>(args._ci, args._Nsize, args._Ksize, args._nmulti, args._trB, args._beta));
- }
-
- GemmImpl_sgemm_gemv_pretransposed() : GemmImplementation<float, float>(GemmMethod::GEMV_PRETRANSPOSED) { }
-};
-
-// Native GEMV
-class GemmImpl_sgemm_gemv_native_transposed : public GemmImplementation<float, float> {
-public:
- bool is_supported(const GemmArgs<float> &args) override {
- return (args._Msize==1 && args._alpha==1.0f && !args._trA && !args._trB && args._nbatches==1);
- }
-
- UniqueGemmCommon<float, float> instantiate(const GemmArgs<float> &args) override {
- return UniqueGemmCommon<float, float> (new GemvNativeTransposed<sgemv_trans, float, float>(args._ci, args._Nsize, args._Ksize, args._nmulti, args._beta));
- }
-
- GemmImpl_sgemm_gemv_native_transposed() : GemmImplementation<float, float>(GemmMethod::GEMV_NATIVE_TRANSPOSED) { }
-};
-
-// Native GEMM
-class GemmImpl_sgemm_gemm_native : public GemmImplementation<float, float> {
-public:
- bool is_supported(const GemmArgs<float> &args) override {
- return (args._Ksize>4 && (args._Nsize % 16)==0 && args._alpha==1.0f && !args._trA && !args._trB);
- }
-
- bool is_recommended(const GemmArgs<float> &args) override {
- return ((args._Ksize <= 128) && (args._Nsize <= 128)) || ((args._nmulti > 1) && ((args._Msize / args._maxthreads) < 8));
- }
-
- UniqueGemmCommon<float, float> instantiate(const GemmArgs<float> &args) override {
- return UniqueGemmCommon<float, float> (new GemmNative<sgemm_native_16x4, float, float>(args._ci, args._Msize, args._Nsize, args._Ksize, args._nbatches, args._nmulti, args._beta));
- }
-
- GemmImpl_sgemm_gemm_native() : GemmImplementation<float, float>(GemmMethod::GEMM_NATIVE) { }
-};
-#endif // __aarch64__
-
-// Interleaved GEMM
-class GemmImpl_sgemm_gemm_interleaved : public GemmImplementation<float, float> {
-public:
- UniqueGemmCommon<float, float> instantiate(const GemmArgs<float> &args) override {
#ifdef __ARM_FEATURE_SVE
- return UniqueGemmCommon<float, float> (new GemmInterleaved<interleaved_fp32_mla_3VLx8, float, float>(args));
-#elif defined(__aarch64__)
- return UniqueGemmCommon<float, float> (new GemmInterleaved<sgemm_12x8, float, float>(args));
-#elif defined(__arm__)
- return UniqueGemmCommon<float, float> (new GemmInterleaved<sgemm_8x6, float, float>(args));
-#else
-# error Unknown Architecture.
-#endif
- }
-
- GemmImpl_sgemm_gemm_interleaved() : GemmImplementation<float, float>(GemmMethod::GEMM_INTERLEAVED) { }
-};
+ // SVE smallk / native / hybrid methods
+{
+ GemmMethod::GEMM_HYBRID,
+ "smallK_hybrid_fp32_mla_1VLx4",
+ [](const GemmArgs<float> &args) { return (args._Ksize <= 24) && !args._trA && args._alpha==1.0f && args._pretransposed_hint; },
+ nullptr,
+ [](const GemmArgs<float> &args) { return new GemmHybrid<smallK_hybrid_fp32_mla_1VLx4, float, float>(args); }
+},
+{
+ GemmMethod::GEMM_HYBRID,
+ "hybrid_fp32_mla_4VLx4",
+ [](const GemmArgs<float> &args) { return (args._Ksize >= 4) && (args._alpha == 1.0f) && !args._trA && args._pretransposed_hint; },
+ [](const GemmArgs<float> &args) { return ((args._Ksize <= 256) && (args._Nsize <= 256)) || ((args._nmulti > 1) && ((args._Msize / args._maxthreads) < 8)); },
+ [](const GemmArgs<float> &args) { return new GemmHybrid<hybrid_fp32_mla_4VLx4, float, float>(args); }
+},
+{
+ GemmMethod::GEMM_NATIVE,
+ "smallK_fp32_mla_1VLx4",
+ [](const GemmArgs<float> &args) { return (args._Ksize <= 24) && !args._trA && !args._trB && args._alpha==1.0f; },
+ nullptr,
+ [](const GemmArgs<float> &args) { return new GemmNative<smallK_fp32_mla_1VLx4, float, float>(args); }
+},
+{
+ GemmMethod::GEMM_NATIVE,
+ "native_fp32_mla_4VLx4",
+ [](const GemmArgs<float> &args) { return (args._Ksize>4 && args._alpha==1.0f && !args._trA && !args._trB); },
+ [](const GemmArgs<float> &args) { return ((args._Ksize <= 128) && (args._Nsize <= 128)) || ((args._nmulti > 1) && ((args._Msize / args._maxthreads) < 8)); },
+ [](const GemmArgs<float> &args) { return new GemmNative<native_fp32_mla_4VLx4, float, float>(args); }
+},
+#endif // __ARM_FEATURE_SVE
+
+// NEON native / hybrid methods
+{
+ GemmMethod::GEMM_HYBRID,
+ "sgemm_nativeA_pretransposeB_16x4",
+ [](const GemmArgs<float> &args) { return (args._Ksize >= 4) && (args._alpha == 1.0f) && !args._trA && args._pretransposed_hint; },
+ [](const GemmArgs<float> &args) { return ((args._Ksize <= 256) && (args._Nsize <= 256)) || ((args._nmulti > 1) && ((args._Msize / args._maxthreads) < 8)); },
+ [](const GemmArgs<float> &args) { return new GemmHybrid<sgemm_nativeA_pretransposeB_16x4, float, float>(args); }
+},
+{
+ GemmMethod::GEMM_NATIVE,
+ "sgemm_native_16x4",
+ [](const GemmArgs<float> &args) { return (args._Ksize>4 && (args._Nsize % 16)==0 && args._alpha==1.0f && !args._trA && !args._trB); },
+ [](const GemmArgs<float> &args) { return ((args._Ksize <= 128) && (args._Nsize <= 128)) || ((args._nmulti > 1) && ((args._Msize / args._maxthreads) < 8)); },
+ [](const GemmArgs<float> &args) { return new GemmNative<sgemm_native_16x4, float, float>(args); }
+},
-static GemmImpl_gemv_batched<float, float> gemv_batched_impl{};
-#if defined(__aarch64__) && !defined(__ARM_FEATURE_SVE)
-static GemmImpl_sgemm_gemv_pretransposed sgemm_gemv_pretransposed_impl{};
-static GemmImpl_sgemm_gemv_native_transposed sgemm_gemv_native_transposed_impl{};
-static GemmImpl_sgemm_gemm_native sgemm_gemm_native_impl{};
-#endif
-static GemmImpl_sgemm_gemm_interleaved sgemm_gemm_interleaved_impl{};
+#ifdef __ARM_FEATURE_SVE
+ {
+ GemmMethod::GEMM_INTERLEAVED,
+ "interleaved_fp32_mla_3VLx8",
+ [](const GemmArgs<float> &args) { return (args._Ksize>4); },
+ nullptr,
+ [](const GemmArgs<float> &args) { return new GemmInterleaved<interleaved_fp32_mla_3VLx8, float, float>(args); }
+},
+#endif // __ARM_FEATURE_SVE
+{
+ GemmMethod::GEMM_INTERLEAVED,
+ "sgemm_12x8",
+ nullptr,
+ nullptr,
+ [](const GemmArgs<float> &args) { return new GemmInterleaved<sgemm_12x8, float, float>(args); }
+},
+#endif // __aarch64__
-/* List of implementations (order matters) */
-static std::vector<GemmImplementation<float, float> *> SGemmMethods = {
- &gemv_batched_impl,
-#if defined(__aarch64__) && !defined(__ARM_FEATURE_SVE)
- &sgemm_gemv_pretransposed_impl,
- &sgemm_gemv_native_transposed_impl,
- &sgemm_gemm_native_impl,
-#endif
- &sgemm_gemm_interleaved_impl
+#ifdef __arm__
+ {
+ GemmMethod::GEMM_INTERLEAVED,
+ "sgemm_8x6",
+ nullptr,
+ nullptr,
+ [](const GemmArgs<float> &args) { return new GemmInterleaved<sgemm_8x6, float, float>(args); }
+},
+#endif // __arm__
+{
+ GemmMethod::DEFAULT,
+ "",
+ nullptr,
+ nullptr,
+ nullptr
+}
};
/* Templated function to return this list. */
template<>
-std::vector<GemmImplementation<float, float> *> &gemm_implementation_list<float, float>() {
- return SGemmMethods;
+const GemmImplementation<float, float> *gemm_implementation_list<float, float>() {
+ return gemm_fp32_methods;
}
/* Explicitly instantiate the external functions for these types. */
-template UniqueGemmCommon<float, float> gemm<float, float>(GemmArgs<float> &args, GemmConfig *cfg);
-template GemmMethod get_gemm_method<float, float>(GemmArgs<float> &args);
-template bool method_is_compatible<float, float>(GemmMethod method, GemmArgs<float> &args);
+template UniqueGemmCommon<float, float> gemm<float, float>(const GemmArgs<float> &args);
+template KernelDescription get_gemm_method<float, float>(const GemmArgs<float> &args);
+template bool method_is_compatible<float, float>(GemmMethod method, const GemmArgs<float> &args);
+template std::vector<std::string> get_compatible_kernels<float, float> (const GemmArgs<float> &args);
-} // namespace arm_gemm
+} // namespace arm_gemm \ No newline at end of file