aboutsummaryrefslogtreecommitdiff
path: root/src/core/NEON/kernels/arm_gemm/gemm_fp32.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/core/NEON/kernels/arm_gemm/gemm_fp32.cpp')
-rw-r--r-- src/core/NEON/kernels/arm_gemm/gemm_fp32.cpp | 104
1 files changed, 75 insertions, 29 deletions
diff --git a/src/core/NEON/kernels/arm_gemm/gemm_fp32.cpp b/src/core/NEON/kernels/arm_gemm/gemm_fp32.cpp
index 2fd040efbe..6e47adbaa4 100644
--- a/src/core/NEON/kernels/arm_gemm/gemm_fp32.cpp
+++ b/src/core/NEON/kernels/arm_gemm/gemm_fp32.cpp
@@ -23,6 +23,7 @@
*/
#include "arm_gemm.hpp"
#include "gemm_common.hpp"
+#include "gemm_implementation.hpp"
#include "gemm_interleaved.hpp"
#include "gemm_native.hpp"
#include "gemv_batched.hpp"
@@ -37,47 +38,92 @@
namespace arm_gemm {
-template<>
-UniqueGemmCommon<float, float> gemm<float, float>(const CPUInfo &ci, const unsigned int M, const unsigned int N, const unsigned int K,
- const unsigned int nbatches, const unsigned int nmulti,
- const bool trA, const bool trB, const float alpha, const float beta,
- const int maxthreads, const bool pretransposed_hint) {
- /* Handle "batched GEMV" */
- if (M==1 && nbatches>1) {
- return UniqueGemmCommon<float, float> (new GemvBatched<float, float>(ci, M, N, K, nbatches, nmulti, trA, trB, alpha, beta, maxthreads, pretransposed_hint));
+#ifdef __aarch64__
+// SGEMM implementations for AArch64
+
+// Pretransposed GEMV: offered only for M==1, alpha==1.0f, pretransposed hint set and exactly one batch.
+class GemmImpl_sgemm_gemv_pretransposed : public GemmImplementation<float, float> {
+public:
+ bool is_supported(const GemmArgs<float> &args) override { // every precondition must hold; the single-batch case is tested explicitly here
+ return (args._Msize==1 && args._alpha==1.0f && args._pretransposed_hint && args._nbatches==1);
}
-#ifdef __aarch64__
- /* Cases in priority order */
- /* GemvPretransposed: requires M=1, alpha=1, and transposed hint set. nbatches must be 1 or we would have returned above so don't test. */
- if (M==1 && alpha==1.0f && pretransposed_hint) {
- return UniqueGemmCommon<float, float> (new GemvPretransposed<sgemv_pretransposed, float, float>(&ci, N, K, nmulti, trB, beta));
+ UniqueGemmCommon<float, float> instantiate(const GemmArgs<float> &args) override { // only ci, N, K, nmulti, trB and beta are forwarded to GemvPretransposed
+ return UniqueGemmCommon<float, float> (new GemvPretransposed<sgemv_pretransposed, float, float>(args._ci, args._Nsize, args._Ksize, args._nmulti, args._trB, args._beta));
}
- /* GemvNativeTransposed: requires M=1, no trA or trB, doesn't handle alpha */
- if (M==1 && alpha==1.0f && !trA && !trB) {
- return UniqueGemmCommon<float, float> (new GemvNativeTransposed<sgemv_trans, float, float>(&ci, N, K, nmulti, beta));
+ GemmImpl_sgemm_gemv_pretransposed() : GemmImplementation<float, float>(GemmMethod::GEMV_PRETRANSPOSED) { } // passes GemmMethod::GEMV_PRETRANSPOSED to the base-class constructor
+};
+
+// Native (transposed) GEMV: offered only for M==1, alpha==1.0f, neither operand transposed and exactly one batch.
+class GemmImpl_sgemm_gemv_native_transposed : public GemmImplementation<float, float> {
+public:
+ bool is_supported(const GemmArgs<float> &args) override { // same conditions as the replaced if-test, plus an explicit single-batch check
+ return (args._Msize==1 && args._alpha==1.0f && !args._trA && !args._trB && args._nbatches==1);
}
- /* Native GEMM: requires K at least 4, N a multiple of 16, doesn't
- * handle alpha or transpose. Use for small N/K, or if the blocked GEMM
- * won't thread properly. */
- if ((K >= 4) && ((N % 16) == 0) && alpha==1.0f && !trA && !trB &&
- ((K <= 128 && N <= 128) || (nmulti > 1 && (M/maxthreads) < 8))) {
- return UniqueGemmCommon<float, float> (new GemmNative<sgemm_native_16x4, float, float>(&ci, M, N, K, nbatches, nmulti, beta));
+ UniqueGemmCommon<float, float> instantiate(const GemmArgs<float> &args) override { // only ci, N, K, nmulti and beta are forwarded to GemvNativeTransposed
+ return UniqueGemmCommon<float, float> (new GemvNativeTransposed<sgemv_trans, float, float>(args._ci, args._Nsize, args._Ksize, args._nmulti, args._beta));
}
- /* Blocked GEMM, handles all cases. */
- return UniqueGemmCommon<float, float> (new GemmInterleaved<sgemm_12x8, float, float>(&ci, M, N, K, nbatches, nmulti, trA, trB, alpha, beta, maxthreads, pretransposed_hint));
+ GemmImpl_sgemm_gemv_native_transposed() : GemmImplementation<float, float>(GemmMethod::GEMV_NATIVE_TRANSPOSED) { } // passes GemmMethod::GEMV_NATIVE_TRANSPOSED to the base-class constructor
+};
+
+// Native GEMM: supported when K>4, N is a multiple of 16, alpha==1.0f and neither operand is transposed; recommended for small N/K, or when nmulti>1 and M per thread is below 8.
+class GemmImpl_sgemm_gemm_native : public GemmImplementation<float, float> {
+public:
+ bool is_supported(const GemmArgs<float> &args) override {
+ return (args._Ksize>4 && (args._Nsize % 16)==0 && args._alpha==1.0f && !args._trA && !args._trB); // NOTE(review): the if-chain this replaces tested K >= 4 ("requires K at least 4") - confirm that excluding K==4 is intended
+ }
+
+ bool is_recommended(const GemmArgs<float> &args) override { // heuristic carried over from the second half of the replaced if-condition
+ return ((args._Ksize <= 128) && (args._Nsize <= 128)) || ((args._nmulti > 1) && ((args._Msize / args._maxthreads) < 8)); // NOTE(review): divides by _maxthreads - presumably never 0; verify against callers
+ }
+
+ UniqueGemmCommon<float, float> instantiate(const GemmArgs<float> &args) override { // forwards ci, M, N, K, nbatches, nmulti and beta to GemmNative
+ return UniqueGemmCommon<float, float> (new GemmNative<sgemm_native_16x4, float, float>(args._ci, args._Msize, args._Nsize, args._Ksize, args._nbatches, args._nmulti, args._beta));
+ }
+
+ GemmImpl_sgemm_gemm_native() : GemmImplementation<float, float>(GemmMethod::GEMM_NATIVE) { } // passes GemmMethod::GEMM_NATIVE to the base-class constructor
+};
+#endif // __aarch64__
+
+// Interleaved (blocked) GEMM: overrides neither is_supported() nor is_recommended() here; uses sgemm_12x8 when __aarch64__ is defined and sgemm_8x6 when __arm__ is defined.
+class GemmImpl_sgemm_gemm_interleaved : public GemmImplementation<float, float> {
+public:
+ UniqueGemmCommon<float, float> instantiate(const GemmArgs<float> &args) override { // the whole GemmArgs object is handed to GemmInterleaved; any other target is a compile-time error
+#ifdef __aarch64__
+ return UniqueGemmCommon<float, float> (new GemmInterleaved<sgemm_12x8, float, float>(args));
+#elif defined(__arm__)
+ return UniqueGemmCommon<float, float> (new GemmInterleaved<sgemm_8x6, float, float>(args));
#else
- return UniqueGemmCommon<float, float> (new GemmInterleaved<sgemm_8x6, float, float>(&ci, M, N, K, nbatches, nmulti, trA, trB, alpha, beta, maxthreads, pretransposed_hint));
+# error Unknown Architecture.
#endif
-}
+ }
-// Instantiate static class variables.
+ GemmImpl_sgemm_gemm_interleaved() : GemmImplementation<float, float>(GemmMethod::GEMM_INTERLEAVED) { } // passes GemmMethod::GEMM_INTERLEAVED to the base-class constructor
+};
+
+/* List of SGEMM implementations. Order matters: entries follow the priority order of the if-chain this list replaces (batched GEMV, pretransposed GEMV, native GEMV, native GEMM, interleaved GEMM); presumably the first suitable entry is chosen - confirm in gemm_implementation.hpp. */
+static std::vector<GemmImplementation<float, float> *> SGemmMethods = {
+ new GemmImpl_gemv_batched<float, float>(), // entries are raw pointers created with new; nothing in the visible code deletes them
#ifdef __aarch64__
-const int sgemm_native_16x4::out_width;
-const int sgemm_native_16x4::out_height;
+ new GemmImpl_sgemm_gemv_pretransposed(),
+ new GemmImpl_sgemm_gemv_native_transposed(),
+ new GemmImpl_sgemm_gemm_native(),
#endif
+ new GemmImpl_sgemm_gemm_interleaved() // listed last; the only entry besides the batched GEMV that is present on every supported architecture
+};
+
+/* Specialization of gemm_implementation_list for <float, float>: returns a non-const reference to the file-static SGemmMethods vector defined above. */
+template<>
+std::vector<GemmImplementation<float, float> *> &gemm_implementation_list<float, float>() {
+ return SGemmMethods;
+}
+
+/* Explicitly instantiate the <float, float> versions of the external entry points (gemm, get_gemm_method, method_is_compatible) in this translation unit. */
+template UniqueGemmCommon<float, float> gemm<float, float>(GemmArgs<float> &args, GemmConfig *cfg);
+template GemmMethod get_gemm_method<float, float>(GemmArgs<float> &args);
+template bool method_is_compatible<float, float>(GemmMethod method, GemmArgs<float> &args);
} // namespace arm_gemm