Diffstat (limited to 'src/core/NEON/kernels/arm_gemm')
-rw-r--r--  src/core/NEON/kernels/arm_gemm/gemm_fp16.cpp                       |  61
-rw-r--r--  src/core/NEON/kernels/arm_gemm/gemm_fp32.cpp                       | 104
-rw-r--r--  src/core/NEON/kernels/arm_gemm/gemm_implementation.hpp             | 131
-rw-r--r--  src/core/NEON/kernels/arm_gemm/gemm_int16.cpp                      |  26
-rw-r--r--  src/core/NEON/kernels/arm_gemm/gemm_int8.cpp                       |  50
-rw-r--r--  src/core/NEON/kernels/arm_gemm/gemm_interleaved.hpp                |  27
-rw-r--r--  src/core/NEON/kernels/arm_gemm/gemm_native.hpp                     |  62
-rw-r--r--  src/core/NEON/kernels/arm_gemm/gemm_uint16.cpp                     |  26
-rw-r--r--  src/core/NEON/kernels/arm_gemm/gemm_uint8.cpp                      |  52
-rw-r--r--  src/core/NEON/kernels/arm_gemm/gemv_batched.hpp                    |   9
-rw-r--r--  src/core/NEON/kernels/arm_gemm/kernels/a64_sgemm_native_16x4.hpp   |  14
11 files changed, 425 insertions, 137 deletions
diff --git a/src/core/NEON/kernels/arm_gemm/gemm_fp16.cpp b/src/core/NEON/kernels/arm_gemm/gemm_fp16.cpp
index 65f43f302b..829ae325a9 100644
--- a/src/core/NEON/kernels/arm_gemm/gemm_fp16.cpp
+++ b/src/core/NEON/kernels/arm_gemm/gemm_fp16.cpp
@@ -28,6 +28,7 @@
#include "arm_gemm.hpp"
#include "gemm_common.hpp"
+#include "gemm_implementation.hpp"
#include "gemm_interleaved.hpp"
#include "kernels/a64_hgemm_24x8.hpp"
@@ -36,37 +37,59 @@
namespace arm_gemm {
-template<>
-UniqueGemmCommon<__fp16, __fp16> gemm(const CPUInfo &ci, const unsigned int M, const unsigned int N, const unsigned int K,
- const unsigned int nbatches, const unsigned int nmulti,
- const bool trA, const bool trB, const __fp16 alpha, const __fp16 beta,
- const int maxthreads, const bool pretransposed_hint) {
#ifdef __aarch64__
-// Only consider the native FP16 kernel if it will get built.
#if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) || defined(FP16_KERNELS)
-#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
- // If the compiler is configured to enable this feature always, then assume it is available at runtime too.
- const bool use_fp16=true;
-#else
- // Otherwise, detect at runtime via CPUInfo.
- const bool use_fp16=ci.has_fp16();
+class GemmImpl_gemm_fp16_interleaved_fp16 : public GemmImplementation<__fp16, __fp16> {
+public:
+#ifndef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
+ bool is_supported(const GemmArgs<__fp16> &args) override {
+ return args._ci->has_fp16();
+ }
#endif
- // If FP16 is supported, use it.
- if (use_fp16) {
- return UniqueGemmCommon<__fp16, __fp16>(new GemmInterleaved<hgemm_24x8, __fp16, __fp16>(&ci, M, N, K, nbatches, nmulti, trA, trB, alpha, beta, maxthreads, pretransposed_hint));
+ UniqueGemmCommon<__fp16, __fp16> instantiate(const GemmArgs<__fp16> &args) override {
+ return UniqueGemmCommon<__fp16, __fp16>(new GemmInterleaved<hgemm_24x8, __fp16, __fp16>(args));
}
+
+ GemmImpl_gemm_fp16_interleaved_fp16() : GemmImplementation<__fp16, __fp16>(GemmMethod::GEMM_INTERLEAVED_FP16) { }
+};
#endif
- // Fallback to using the blocked SGEMM kernel.
- return UniqueGemmCommon<__fp16, __fp16>(new GemmInterleaved<sgemm_12x8, __fp16, __fp16>(&ci, M, N, K, nbatches, nmulti, trA, trB, alpha, beta, maxthreads, pretransposed_hint));
+#endif // __aarch64__
+
+class GemmImpl_gemm_fp16_interleaved : public GemmImplementation<__fp16, __fp16> {
+public:
+ UniqueGemmCommon<__fp16, __fp16> instantiate(const GemmArgs<__fp16> &args) override {
+#ifdef __aarch64__
+ return UniqueGemmCommon<__fp16, __fp16>(new GemmInterleaved<sgemm_12x8, __fp16, __fp16>(args));
+#elif defined(__arm__)
+ return UniqueGemmCommon<__fp16, __fp16>(new GemmInterleaved<sgemm_8x6, __fp16, __fp16>(args));
#else
- // For AArch32, only support the SGEMM route for now.
- return UniqueGemmCommon<__fp16, __fp16>(new GemmInterleaved<sgemm_8x6, __fp16, __fp16>(&ci, M, N, K, nbatches, nmulti, trA, trB, alpha, beta, maxthreads, pretransposed_hint));
+# error Unknown Architecture
#endif
+ }
+
+ GemmImpl_gemm_fp16_interleaved() : GemmImplementation<__fp16, __fp16>(GemmMethod::GEMM_INTERLEAVED) { }
+};
+
+static std::vector<GemmImplementation<__fp16, __fp16> *> gemm_fp16_methods = {
+#if defined(__aarch64__) && (defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) || defined(FP16_KERNELS))
+ new GemmImpl_gemm_fp16_interleaved_fp16(),
+#endif
+ new GemmImpl_gemm_fp16_interleaved()
+};
+
+template<>
+std::vector<GemmImplementation<__fp16, __fp16> *> &gemm_implementation_list<__fp16, __fp16>() {
+ return gemm_fp16_methods;
}
+/* Explicitly instantiate the external functions for these types. */
+template UniqueGemmCommon<__fp16, __fp16> gemm<__fp16, __fp16>(GemmArgs<__fp16> &args, GemmConfig *cfg);
+template GemmMethod get_gemm_method<__fp16, __fp16>(GemmArgs<__fp16> &args);
+template bool method_is_compatible<__fp16, __fp16>(GemmMethod method, GemmArgs<__fp16> &args);
+
} // namespace arm_gemm
#endif // __ARM_FP16_ARGS
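
A note on the pattern above: GemmImpl_gemm_fp16_interleaved_fp16 only overrides is_supported() when __ARM_FEATURE_FP16_VECTOR_ARITHMETIC is not defined, so builds with a compile-time FP16 guarantee inherit the base class's always-true default, while other builds compile in a runtime CPUInfo check. A minimal standalone sketch of this conditional-override idiom (Args and Impl are illustrative stand-ins, not the library's types):

#include <iostream>

struct Args { bool cpu_has_fp16; };

struct Impl {
    virtual bool is_supported(const Args &) { return true; }
    virtual ~Impl() = default;
};

struct Fp16Impl : public Impl {
#ifndef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
    // Only compiled when FP16 support is not a compile-time guarantee.
    bool is_supported(const Args &args) override { return args.cpu_has_fp16; }
#endif
};

int main() {
    Fp16Impl impl;
    Args args{false};
    // Without a compile-time FP16 guarantee this prints "false" (runtime
    // check compiled in); with one, the base default "true" is inherited.
    std::cout << std::boolalpha << impl.is_supported(args) << "\n";
}
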
diff --git a/src/core/NEON/kernels/arm_gemm/gemm_fp32.cpp b/src/core/NEON/kernels/arm_gemm/gemm_fp32.cpp
index 2fd040efbe..6e47adbaa4 100644
--- a/src/core/NEON/kernels/arm_gemm/gemm_fp32.cpp
+++ b/src/core/NEON/kernels/arm_gemm/gemm_fp32.cpp
@@ -23,6 +23,7 @@
*/
#include "arm_gemm.hpp"
#include "gemm_common.hpp"
+#include "gemm_implementation.hpp"
#include "gemm_interleaved.hpp"
#include "gemm_native.hpp"
#include "gemv_batched.hpp"
@@ -37,47 +38,92 @@
namespace arm_gemm {
-template<>
-UniqueGemmCommon<float, float> gemm<float, float>(const CPUInfo &ci, const unsigned int M, const unsigned int N, const unsigned int K,
- const unsigned int nbatches, const unsigned int nmulti,
- const bool trA, const bool trB, const float alpha, const float beta,
- const int maxthreads, const bool pretransposed_hint) {
- /* Handle "batched GEMV" */
- if (M==1 && nbatches>1) {
- return UniqueGemmCommon<float, float> (new GemvBatched<float, float>(ci, M, N, K, nbatches, nmulti, trA, trB, alpha, beta, maxthreads, pretransposed_hint));
+#ifdef __aarch64__
+// SGEMM implementations for AArch64
+
+// Pretransposed GEMV
+class GemmImpl_sgemm_gemv_pretransposed : public GemmImplementation<float, float> {
+public:
+ bool is_supported(const GemmArgs<float> &args) override {
+ return (args._Msize==1 && args._alpha==1.0f && args._pretransposed_hint && args._nbatches==1);
}
-#ifdef __aarch64__
- /* Cases in priority order */
- /* GemvPretransposed: requires M=1, alpha=1, and transposed hint set. nbatches must be 1 or we would have returned above so don't test. */
- if (M==1 && alpha==1.0f && pretransposed_hint) {
- return UniqueGemmCommon<float, float> (new GemvPretransposed<sgemv_pretransposed, float, float>(&ci, N, K, nmulti, trB, beta));
+ UniqueGemmCommon<float, float> instantiate(const GemmArgs<float> &args) override {
+ return UniqueGemmCommon<float, float> (new GemvPretransposed<sgemv_pretransposed, float, float>(args._ci, args._Nsize, args._Ksize, args._nmulti, args._trB, args._beta));
}
- /* GemvNativeTransposed: requires M=1, no trA or trB, doesn't handle alpha */
- if (M==1 && alpha==1.0f && !trA && !trB) {
- return UniqueGemmCommon<float, float> (new GemvNativeTransposed<sgemv_trans, float, float>(&ci, N, K, nmulti, beta));
+ GemmImpl_sgemm_gemv_pretransposed() : GemmImplementation<float, float>(GemmMethod::GEMV_PRETRANSPOSED) { }
+};
+
+// Native GEMV
+class GemmImpl_sgemm_gemv_native_transposed : public GemmImplementation<float, float> {
+public:
+ bool is_supported(const GemmArgs<float> &args) override {
+ return (args._Msize==1 && args._alpha==1.0f && !args._trA && !args._trB && args._nbatches==1);
}
- /* Native GEMM: requires K at least 4, N a multiple of 16, doesn't
- * handle alpha or transpose. Use for small N/K, or if the blocked GEMM
- * won't thread properly. */
- if ((K >= 4) && ((N % 16) == 0) && alpha==1.0f && !trA && !trB &&
- ((K <= 128 && N <= 128) || (nmulti > 1 && (M/maxthreads) < 8))) {
- return UniqueGemmCommon<float, float> (new GemmNative<sgemm_native_16x4, float, float>(&ci, M, N, K, nbatches, nmulti, beta));
+ UniqueGemmCommon<float, float> instantiate(const GemmArgs<float> &args) override {
+ return UniqueGemmCommon<float, float> (new GemvNativeTransposed<sgemv_trans, float, float>(args._ci, args._Nsize, args._Ksize, args._nmulti, args._beta));
}
- /* Blocked GEMM, handles all cases. */
- return UniqueGemmCommon<float, float> (new GemmInterleaved<sgemm_12x8, float, float>(&ci, M, N, K, nbatches, nmulti, trA, trB, alpha, beta, maxthreads, pretransposed_hint));
+ GemmImpl_sgemm_gemv_native_transposed() : GemmImplementation<float, float>(GemmMethod::GEMV_NATIVE_TRANSPOSED) { }
+};
+
+// Native GEMM
+class GemmImpl_sgemm_gemm_native : public GemmImplementation<float, float> {
+public:
+ bool is_supported(const GemmArgs<float> &args) override {
+ return (args._Ksize>4 && (args._Nsize % 16)==0 && args._alpha==1.0f && !args._trA && !args._trB);
+ }
+
+ bool is_recommended(const GemmArgs<float> &args) override {
+ return ((args._Ksize <= 128) && (args._Nsize <= 128)) || ((args._nmulti > 1) && ((args._Msize / args._maxthreads) < 8));
+ }
+
+ UniqueGemmCommon<float, float> instantiate(const GemmArgs<float> &args) override {
+ return UniqueGemmCommon<float, float> (new GemmNative<sgemm_native_16x4, float, float>(args._ci, args._Msize, args._Nsize, args._Ksize, args._nbatches, args._nmulti, args._beta));
+ }
+
+ GemmImpl_sgemm_gemm_native() : GemmImplementation<float, float>(GemmMethod::GEMM_NATIVE) { }
+};
+#endif // __aarch64__
+
+// Interleaved GEMM
+class GemmImpl_sgemm_gemm_interleaved : public GemmImplementation<float, float> {
+public:
+ UniqueGemmCommon<float, float> instantiate(const GemmArgs<float> &args) override {
+#ifdef __aarch64__
+ return UniqueGemmCommon<float, float> (new GemmInterleaved<sgemm_12x8, float, float>(args));
+#elif defined(__arm__)
+ return UniqueGemmCommon<float, float> (new GemmInterleaved<sgemm_8x6, float, float>(args));
#else
- return UniqueGemmCommon<float, float> (new GemmInterleaved<sgemm_8x6, float, float>(&ci, M, N, K, nbatches, nmulti, trA, trB, alpha, beta, maxthreads, pretransposed_hint));
+# error Unknown Architecture.
#endif
-}
+ }
-// Instantiate static class variables.
+ GemmImpl_sgemm_gemm_interleaved() : GemmImplementation<float, float>(GemmMethod::GEMM_INTERLEAVED) { }
+};
+
+/* List of implementations (order matters) */
+static std::vector<GemmImplementation<float, float> *> SGemmMethods = {
+ new GemmImpl_gemv_batched<float, float>(),
#ifdef __aarch64__
-const int sgemm_native_16x4::out_width;
-const int sgemm_native_16x4::out_height;
+ new GemmImpl_sgemm_gemv_pretransposed(),
+ new GemmImpl_sgemm_gemv_native_transposed(),
+ new GemmImpl_sgemm_gemm_native(),
#endif
+ new GemmImpl_sgemm_gemm_interleaved()
+};
+
+/* Templated function to return this list. */
+template<>
+std::vector<GemmImplementation<float, float> *> &gemm_implementation_list<float, float>() {
+ return SGemmMethods;
+}
+
+/* Explicitly instantiate the external functions for these types. */
+template UniqueGemmCommon<float, float> gemm<float, float>(GemmArgs<float> &args, GemmConfig *cfg);
+template GemmMethod get_gemm_method<float, float>(GemmArgs<float> &args);
+template bool method_is_compatible<float, float>(GemmMethod method, GemmArgs<float> &args);
} // namespace arm_gemm
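
The replacement of the hand-written if/else chain hinges on SGemmMethods being a priority-ordered list: candidates are tried front to back, and the first one that both supports and recommends itself for the given problem wins, with GEMM_INTERLEAVED as the unconditional fallback at the end. A standalone sketch of that ordering effect (simplified stand-in types, not the library's):

#include <iostream>
#include <string>
#include <vector>

struct Candidate {
    std::string name;
    bool (*supported)(int M);     // analogous to is_supported()
    bool (*recommended)(int M);   // analogous to is_recommended()
};

static bool yes(int) { return true; }

int main() {
    // Order matters: specialised methods first, catch-all fallback last.
    std::vector<Candidate> methods = {
        { "gemv_pretransposed", [](int M) { return M == 1; }, yes },
        { "gemm_native",        yes, [](int M) { return M < 8; } },
        { "gemm_interleaved",   yes, yes },
    };

    for (int M : {1, 4, 64}) {
        for (const auto &c : methods) {
            if (c.supported(M) && c.recommended(M)) {
                std::cout << "M=" << M << " -> " << c.name << "\n";
                break;   // first match wins
            }
        }
    }
    // Prints: gemv_pretransposed for M=1, gemm_native for M=4,
    // gemm_interleaved for M=64.
}
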
diff --git a/src/core/NEON/kernels/arm_gemm/gemm_implementation.hpp b/src/core/NEON/kernels/arm_gemm/gemm_implementation.hpp
new file mode 100644
index 0000000000..6734e3cce0
--- /dev/null
+++ b/src/core/NEON/kernels/arm_gemm/gemm_implementation.hpp
@@ -0,0 +1,131 @@
+/*
+ * Copyright (c) 2018 ARM Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#include "gemv_batched.hpp"
+
+namespace arm_gemm {
+
+template<typename Top, typename Tret>
+class GemmImplementation {
+public:
+ /* Is this implementation compatible with the args as provided? */
+ virtual bool is_supported(const GemmArgs<Tret> &args) { return true; }
+ /* Is this implementation "recommended" for these args (heuristic)? */
+ virtual bool is_recommended(const GemmArgs<Tret> &args) { return true; }
+ /* Instantiate this method please. */
+ virtual UniqueGemmCommon<Top, Tret> instantiate(const GemmArgs<Tret> &args) = 0;
+
+ /* Indicate the "GemmMethod" for use as a selector */
+ const GemmMethod method;
+
+ virtual ~GemmImplementation() { }
+
+ GemmImplementation(GemmMethod method) : method(method) { }
+};
+
+/* "gemv_batched" implementation is type-agnostic, so template it here. */
+template<typename Top, typename Tret>
+class GemmImpl_gemv_batched : public GemmImplementation<Top, Tret> {
+public:
+ bool is_supported(const GemmArgs<Tret> &args) override {
+ return (args._Msize==1 && args._nbatches > 1);
+ }
+
+ UniqueGemmCommon<Top, Tret> instantiate(const GemmArgs<Tret> &args) override {
+ return UniqueGemmCommon<Top, Tret> (new GemvBatched<Top, Tret>(args));
+ }
+
+ GemmImpl_gemv_batched() : GemmImplementation<Top, Tret>(GemmMethod::GEMV_BATCHED) { }
+};
+
+/* "Master" function implemented for each valid combination of types.
+ * Returns a list of GEMM implementation descriptors for processing by the
+ * other functions. */
+template<typename Top, typename Tret>
+std::vector<GemmImplementation<Top, Tret> *> &gemm_implementation_list();
+
+template<typename Top, typename Tret>
+GemmImplementation<Top, Tret> *find_implementation(GemmArgs<Tret> &args, GemmConfig *cfg) {
+ auto gemms = gemm_implementation_list<Top, Tret>();
+
+ for(auto &&i : gemms) {
+ /* Skip if this implementation doesn't support these args. */
+ if (!i->is_supported(args)) {
+ continue;
+ }
+
+ /* Skip if a specific method is requested and this is a different one. */
+ if (cfg && cfg->method != GemmMethod::DEFAULT && i->method != cfg->method) {
+ continue;
+ }
+
+ /* If no specific method is requested, check that this method recommends itself. */
+ if ((!cfg || cfg->method == GemmMethod::DEFAULT) && !i->is_recommended(args)) {
+ continue;
+ }
+
+ return i;
+ }
+
+ return nullptr;
+}
+
+template<typename Top, typename Tret>
+UniqueGemmCommon<Top, Tret> gemm(GemmArgs<Tret> &args, GemmConfig *cfg) {
+ auto impl = find_implementation<Top, Tret>(args, cfg);
+
+ if (impl) {
+ return impl->instantiate(args);
+ }
+
+ return UniqueGemmCommon<Top, Tret>(nullptr);
+}
+
+template<typename Top, typename Tret>
+GemmMethod get_gemm_method(GemmArgs<Tret> &args) {
+ auto impl = find_implementation<Top, Tret>(args, nullptr);
+
+ if (impl) {
+ return impl->method;
+ }
+
+ /* This shouldn't happen - there should always be at least one valid implementation. */
+ return GemmMethod::DEFAULT;
+}
+
+template<typename Top, typename Tret>
+bool method_is_compatible(GemmMethod method, GemmArgs<Tret> &args) {
+ /* Determine if the method is valid by attempting to obtain an implementation specifying this method. */
+ GemmConfig cfg(method);
+
+ auto impl = find_implementation<Top, Tret>(args, &cfg);
+
+ if (impl) {
+ return true;
+ }
+
+ return false;
+}
+
+} // namespace arm_gemm
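
find_implementation() treats a non-DEFAULT GemmConfig as an override: the is_recommended() heuristic is consulted only when no specific method was requested, while is_supported() is always enforced. A standalone sketch of that branch logic (simplified stand-ins for GemmImplementation and GemmConfig):

#include <iostream>
#include <vector>

enum class Method { DEFAULT, NATIVE, INTERLEAVED };

struct Impl {
    Method method;
    bool supported;
    bool recommended;
};

const Impl *find(const std::vector<Impl> &impls, const Method *requested) {
    for (const auto &i : impls) {
        // is_supported() is always enforced.
        if (!i.supported) continue;
        // A specific request skips non-matching methods...
        if (requested && *requested != Method::DEFAULT && i.method != *requested) continue;
        // ...and bypasses the recommendation heuristic entirely.
        if ((!requested || *requested == Method::DEFAULT) && !i.recommended) continue;
        return &i;
    }
    return nullptr;
}

int main() {
    std::vector<Impl> impls = {
        { Method::NATIVE,      true, false },   // supported, not recommended
        { Method::INTERLEAVED, true, true  },
    };

    // Default selection skips NATIVE because it does not recommend itself.
    std::cout << std::boolalpha
              << (find(impls, nullptr)->method == Method::INTERLEAVED) << "\n";

    // Explicitly requesting NATIVE succeeds anyway.
    Method want = Method::NATIVE;
    std::cout << (find(impls, &want)->method == Method::NATIVE) << "\n";
}
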
diff --git a/src/core/NEON/kernels/arm_gemm/gemm_int16.cpp b/src/core/NEON/kernels/arm_gemm/gemm_int16.cpp
index 57cd15f698..f61cc1358f 100644
--- a/src/core/NEON/kernels/arm_gemm/gemm_int16.cpp
+++ b/src/core/NEON/kernels/arm_gemm/gemm_int16.cpp
@@ -25,20 +25,36 @@
#include "arm_gemm.hpp"
#include "gemm_common.hpp"
+#include "gemm_implementation.hpp"
#include "gemm_interleaved.hpp"
#include "kernels/a64_gemm_s16_12x8.hpp"
namespace arm_gemm {
+class GemmImpl_gemm_s16_interleaved : public GemmImplementation<int16_t, int32_t> {
+public:
+ UniqueGemmCommon<int16_t, int32_t> instantiate(const GemmArgs<int32_t> &args) override {
+ return UniqueGemmCommon<int16_t, int32_t>(new GemmInterleaved<gemm_s16_12x8, int16_t, int32_t>(args));
+ }
+
+ GemmImpl_gemm_s16_interleaved() : GemmImplementation<int16_t, int32_t>(GemmMethod::GEMM_INTERLEAVED) { }
+};
+
+static std::vector<GemmImplementation<int16_t, int32_t> *> gemm_s16_methods = {
+ new GemmImpl_gemm_s16_interleaved()
+};
+
template<>
-UniqueGemmCommon<int16_t, int32_t> gemm<int16_t, int32_t>(const CPUInfo &ci, const unsigned int M, const unsigned int N, const unsigned int K,
- const unsigned int nbatches, const unsigned int nmulti,
- const bool trA, const bool trB, const int32_t alpha, const int32_t beta,
- const int maxthreads, const bool pretransposed_hint) {
- return UniqueGemmCommon<int16_t, int32_t>(new GemmInterleaved<gemm_s16_12x8, int16_t, int32_t>(&ci, M, N, K, nbatches, nmulti, trA, trB, alpha, beta, maxthreads, pretransposed_hint));
+std::vector<GemmImplementation<int16_t, int32_t> *> &gemm_implementation_list<int16_t, int32_t>() {
+ return gemm_s16_methods;
}
+/* Explicitly instantiate the external functions for these types. */
+template UniqueGemmCommon<int16_t, int32_t> gemm<int16_t, int32_t>(GemmArgs<int32_t> &args, GemmConfig *cfg);
+template GemmMethod get_gemm_method<int16_t, int32_t>(GemmArgs<int32_t> &args);
+template bool method_is_compatible<int16_t, int32_t>(GemmMethod method, GemmArgs<int32_t> &args);
+
} // namespace arm_gemm
#endif // __aarch64__
diff --git a/src/core/NEON/kernels/arm_gemm/gemm_int8.cpp b/src/core/NEON/kernels/arm_gemm/gemm_int8.cpp
index 04803eb81a..f50b399de5 100644
--- a/src/core/NEON/kernels/arm_gemm/gemm_int8.cpp
+++ b/src/core/NEON/kernels/arm_gemm/gemm_int8.cpp
@@ -25,32 +25,52 @@
#include "arm_gemm.hpp"
#include "gemm_common.hpp"
+#include "gemm_implementation.hpp"
#include "gemm_interleaved.hpp"
-#include "kernels/a64_gemm_s8_4x4.hpp"
#include "kernels/a64_gemm_s16_12x8.hpp"
#include "kernels/a64_gemm_s8_12x8.hpp"
+#include "kernels/a64_gemm_s8_4x4.hpp"
namespace arm_gemm {
-template<>
-UniqueGemmCommon<int8_t, int32_t> gemm<int8_t, int32_t>(const CPUInfo &ci, const unsigned int M, const unsigned int N, const unsigned int K,
- const unsigned int nbatches, const unsigned int nmulti,
- const bool trA, const bool trB, const int32_t alpha, const int32_t beta,
- const int maxthreads, const bool pretransposed_hint) {
- if (ci.has_dotprod()) {
- // Dot product supporting CPUs. This family has a special version for A55r1.
- return UniqueGemmCommon<int8_t, int32_t>(new GemmInterleaved<gemm_s8_12x8, int8_t, int32_t>(&ci, M, N, K, nbatches, nmulti, trA, trB, alpha, beta, maxthreads, pretransposed_hint));
+class GemmImpl_gemm_s8_interleaved_dot : public GemmImplementation<int8_t, int32_t> {
+public:
+ bool is_supported(const GemmArgs<int32_t> &args) override {
+ return args._ci->has_dotprod();
}
- return UniqueGemmCommon<int8_t, int32_t>(new GemmInterleaved<gemm_s8_4x4, int8_t, int32_t>(&ci, M, N, K, nbatches, nmulti, trA, trB, alpha, beta, maxthreads, pretransposed_hint));
+ UniqueGemmCommon<int8_t, int32_t> instantiate(const GemmArgs<int32_t> &args) override {
+ return UniqueGemmCommon<int8_t, int32_t>(new GemmInterleaved<gemm_s8_12x8, int8_t, int32_t>(args));
+ }
+
+ GemmImpl_gemm_s8_interleaved_dot() : GemmImplementation<int8_t, int32_t>(GemmMethod::GEMM_INTERLEAVED_DOT) { }
+};
- // TODO: There's a better approach for A53, but it doesn't work
- // well on heterogeneous systems as the required data formats
- // are different. Figure out how to enable this:
- // gemm = new GemmInterleaved<gemm_s16_12x8, int8_t, int32_t>(ci, M, N, K, trA, trB);
+class GemmImpl_gemm_s8_interleaved : public GemmImplementation<int8_t, int32_t> {
+public:
+ UniqueGemmCommon<int8_t, int32_t> instantiate(const GemmArgs<int32_t> &args) override {
+ return UniqueGemmCommon<int8_t, int32_t>(new GemmInterleaved<gemm_s8_4x4, int8_t, int32_t>(args));
+ }
+
+ GemmImpl_gemm_s8_interleaved() : GemmImplementation<int8_t, int32_t>(GemmMethod::GEMM_INTERLEAVED) { }
+};
+
+static std::vector<GemmImplementation<int8_t, int32_t> *> gemm_s8_methods = {
+ new GemmImpl_gemm_s8_interleaved_dot(),
+ new GemmImpl_gemm_s8_interleaved()
+};
+
+template<>
+std::vector<GemmImplementation<int8_t, int32_t> *> &gemm_implementation_list<int8_t, int32_t>() {
+ return gemm_s8_methods;
}
+/* Explicitly instantiate the external functions for these types. */
+template UniqueGemmCommon<int8_t, int32_t> gemm<int8_t, int32_t>(GemmArgs<int32_t> &args, GemmConfig *cfg);
+template GemmMethod get_gemm_method<int8_t, int32_t>(GemmArgs<int32_t> &args);
+template bool method_is_compatible<int8_t, int32_t>(GemmMethod method, GemmArgs<int32_t> &args);
+
} // namespace arm_gemm
-#endif // aarch64
+#endif // __aarch64__
diff --git a/src/core/NEON/kernels/arm_gemm/gemm_interleaved.hpp b/src/core/NEON/kernels/arm_gemm/gemm_interleaved.hpp
index c5a43e6519..0e58a4d01f 100644
--- a/src/core/NEON/kernels/arm_gemm/gemm_interleaved.hpp
+++ b/src/core/NEON/kernels/arm_gemm/gemm_interleaved.hpp
@@ -317,16 +317,15 @@ public:
GemmInterleaved & operator= (GemmInterleaved &) = delete;
/* Constructor */
- GemmInterleaved(const CPUInfo *ci, const unsigned int M, const unsigned int N, const unsigned int K,
- const unsigned int nbatches, const unsigned int nmulti, const bool trA, const bool trB,
- const Tr alpha, const Tr beta, const int maxthreads, const bool pretransposed) :
- _ci(ci), _Msize(M), _Nsize(N), _Ksize(K), _nbatches(nbatches), _nmulti(nmulti),
- _trA(trA), _trB(trB), _alpha(alpha), _beta(beta),
- _maxthreads(maxthreads), _nthreads(maxthreads), _pretransposed(pretransposed) {
- const unsigned int L1_size = ci->get_L1_cache_size();
- const unsigned int L2_size = ci->get_L2_cache_size();
+ GemmInterleaved(const GemmArgs<Tr> &args)
+ : _ci(args._ci), _Msize(args._Msize), _Nsize(args._Nsize), _Ksize(args._Ksize),
+ _nbatches(args._nbatches), _nmulti(args._nmulti), _trA(args._trA), _trB(args._trB),
+ _alpha(args._alpha), _beta(args._beta), _maxthreads(args._maxthreads), _nthreads(args._maxthreads),
+ _pretransposed(args._pretransposed_hint) {
+ const unsigned int L1_size = _ci->get_L1_cache_size();
+ const unsigned int L2_size = _ci->get_L2_cache_size();
- assert(maxthreads > 0);
+ assert(_maxthreads > 0);
// Work out blocking parameters
@@ -339,10 +338,10 @@ public:
_k_block = std::max(_k_block, 1U) * strategy::k_unroll();
// Now tune to presented problem size; this is how many blocks we need.
- int num_k_blocks = iceildiv(K, _k_block);
+ int num_k_blocks = iceildiv(_Ksize, _k_block);
// So divide the space equally into that many blocks.
- _k_block = iceildiv(K, num_k_blocks);
+ _k_block = iceildiv(_Ksize, num_k_blocks);
// And round UP to the K unroll level required.
_k_block = iceildiv(_k_block, strategy::k_unroll());
@@ -358,14 +357,14 @@ public:
_x_block = std::max(_x_block, 1U) * strategy::out_width();
// And tune to the presented problem size.
- int num_x_blocks = iceildiv(N, _x_block);
- _x_block = iceildiv(N, num_x_blocks);
+ int num_x_blocks = iceildiv(_Nsize, _x_block);
+ _x_block = iceildiv(_Nsize, num_x_blocks);
_x_block = iceildiv(_x_block, strategy::out_width());
_x_block *= strategy::out_width();
// Work out the rounded size of M - needed for some buffers.
- _Mround = iceildiv(M, strategy::out_height());
+ _Mround = iceildiv(_Msize, strategy::out_height());
_Mround *= strategy::out_height();
}
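
The constructor's blocking logic is unchanged in substance, only re-pointed at the stored members: the cache-derived k_block first determines how many K blocks are needed, K is then split evenly across them, and the result is rounded up to the kernel's unroll factor. A worked standalone sketch with example numbers (iceildiv is assumed to be the usual ceiling division):

#include <iostream>

static unsigned int iceildiv(unsigned int a, unsigned int b) {
    return (a + b - 1) / b;
}

int main() {
    const unsigned int Ksize    = 1000;   // problem K dimension
    const unsigned int k_unroll = 4;      // strategy::k_unroll()
    unsigned int k_block        = 256;    // derived from L1 size upstream

    // How many blocks of (at most) k_block does K need?
    unsigned int num_k_blocks = iceildiv(Ksize, k_block);   // -> 4

    // Split K evenly across that many blocks...
    k_block = iceildiv(Ksize, num_k_blocks);                // -> 250

    // ...and round UP to the unroll level the kernel requires.
    k_block = iceildiv(k_block, k_unroll) * k_unroll;       // -> 252

    std::cout << "k_block = " << k_block << "\n";
}
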
diff --git a/src/core/NEON/kernels/arm_gemm/gemm_native.hpp b/src/core/NEON/kernels/arm_gemm/gemm_native.hpp
index 6fed645d82..baa1316745 100644
--- a/src/core/NEON/kernels/arm_gemm/gemm_native.hpp
+++ b/src/core/NEON/kernels/arm_gemm/gemm_native.hpp
@@ -62,6 +62,14 @@ class GemmNative : public GemmCommon<To, Tr> {
unsigned int k_block=0;
unsigned int n_block=0;
+ unsigned int window_per_batch() const {
+ return iceildiv(_Msize, strategy::out_height());
+ }
+
+ unsigned int window_per_multi() const {
+ return window_per_batch() * _nbatches;
+ }
+
public:
GemmNative(GemmNative &) = delete;
GemmNative & operator= (GemmNative &) = delete;
@@ -73,9 +81,9 @@ public:
n_block = N;
}
- // Window is number of out_height blocks
+ // Window is the per-multi window size multiplied by the number of multis.
unsigned int get_window_size() const override {
- return iceildiv(_Msize, strategy::out_height) * _nbatches * _nmultis;
+ return window_per_multi() * _nmultis;
}
// Actually execute the GEMM.
@@ -85,39 +93,39 @@ public:
#endif
strategy strat(_ci);
- const unsigned int window_per_batch = iceildiv(_Msize, strategy::out_height);
- const unsigned int window_per_multi = window_per_batch * _nbatches;
-
- const unsigned int first_multi = start / window_per_multi;
- const unsigned int last_multi = end / window_per_multi;
-
- const unsigned int first_batch = (start - (first_multi * window_per_multi)) / window_per_batch;
- const unsigned int last_batch = (end - (last_multi * window_per_multi)) / window_per_batch;
-
- const unsigned int first_row = ((start - (first_multi * window_per_multi)) % window_per_batch) * strategy::out_height;
- const unsigned int last_row = ((end - (last_multi * window_per_multi)) % window_per_batch) * strategy::out_height;
-
static_assert(std::is_same<To, Toi>::value, "gemm_native: Operand types must be the same.");
static_assert(std::is_same<Tr, Tri>::value, "gemm_native: Result types must be the same.");
- for (unsigned int multi=first_multi; multi<=last_multi; multi++) {
- const unsigned int batch_0 = (multi == first_multi) ? first_batch : 0;
- const unsigned int batch_max = (multi == last_multi) ? last_batch : (_nbatches-1);
+ /* Compute starting point based on 'start' */
+ unsigned int multi = start / window_per_multi();
+ unsigned int multi_pos = start % window_per_multi();
+
+ unsigned int batch = multi_pos / window_per_batch();
+ unsigned int batch_pos = multi_pos % window_per_batch();
- for (unsigned int batch=batch_0; batch <= batch_max; batch++) {
- const unsigned int m_start = ((multi == first_multi) && (batch==first_batch)) ? first_row : 0;
- const unsigned int m_end = ((multi == last_multi) && (batch==last_batch)) ? last_row : _Msize;
+ unsigned int y0 = batch_pos * strategy::out_height();
- for (unsigned int y0=m_start; y0<m_end; y0+=strategy::out_height) {
- const unsigned int ymax = std::min(y0 + strategy::out_height, m_end);
+ for (unsigned int pos=start; pos<end; pos++) {
+ const unsigned int ymax = std::min(y0 + strategy::out_height(), _Msize);
#ifdef CYCLE_PROFILING
- auto p = prof.ScopedProfiler(PROFILE_KERNEL, (ymax-y0) * _Nsize * _Ksize);
+ auto p = prof.ScopedProfiler(PROFILE_KERNEL, (ymax-y0) * _Nsize * _Ksize);
#endif
- strat.kernel(this->_Aptr + (multi * this->_A_multi_stride) + (batch * this->_A_batch_stride) + (y0 * this->_lda), this->_lda,
- this->_Bptr + (multi * this->_B_multi_stride), this->_ldb,
- this->_Cptr + (multi * this->_C_multi_stride) + (batch * this->_C_batch_stride) + (y0 * this->_ldc), this->_ldc,
- _beta, (ymax-y0), _Nsize, _Ksize);
+ strat.kernel(this->_Aptr + (multi * this->_A_multi_stride) + (batch * this->_A_batch_stride) + (y0 * this->_lda), this->_lda,
+ this->_Bptr + (multi * this->_B_multi_stride), this->_ldb,
+ this->_Cptr + (multi * this->_C_multi_stride) + (batch * this->_C_batch_stride) + (y0 * this->_ldc), this->_ldc,
+ _beta, (ymax-y0), _Nsize, _Ksize);
+
+ /* Advance to next item */
+ y0 += strategy::out_height();
+
+ /* Check for batch/multi overflow */
+ if (y0 >= _Msize) {
+ y0=0;
+ batch++;
+ if (batch == _nbatches) {
+ batch=0;
+ multi++;
}
}
}
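
The rewritten execute() replaces the nested multi/batch/row loops with a single flat walk over [start, end): the starting position is decomposed once into (multi, batch, y0) and then advanced with explicit carries, which makes the window arithmetic match get_window_size() by construction. A standalone sketch of that decomposition and walk (arbitrary example dimensions):

#include <algorithm>
#include <iostream>

int main() {
    const unsigned int Msize = 10, out_height = 4;   // 3 row-blocks per batch
    const unsigned int nbatches = 2;

    const unsigned int window_per_batch = (Msize + out_height - 1) / out_height;
    const unsigned int window_per_multi = window_per_batch * nbatches;

    const unsigned int start = 4, end = 9;

    // One-off decomposition of 'start' into coordinates.
    unsigned int multi = start / window_per_multi;
    unsigned int multi_pos = start % window_per_multi;
    unsigned int batch = multi_pos / window_per_batch;
    unsigned int y0 = (multi_pos % window_per_batch) * out_height;

    // Incremental walk, carrying into batch and multi as execute() does.
    for (unsigned int pos = start; pos < end; pos++) {
        const unsigned int ymax = std::min(y0 + out_height, Msize);
        std::cout << "pos=" << pos << " multi=" << multi << " batch=" << batch
                  << " rows [" << y0 << "," << ymax << ")\n";

        y0 += out_height;
        if (y0 >= Msize) {        // row overflow: next batch
            y0 = 0;
            if (++batch == nbatches) {   // batch overflow: next multi
                batch = 0;
                multi++;
            }
        }
    }
}
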
diff --git a/src/core/NEON/kernels/arm_gemm/gemm_uint16.cpp b/src/core/NEON/kernels/arm_gemm/gemm_uint16.cpp
index 6db55c02d0..f4b712c453 100644
--- a/src/core/NEON/kernels/arm_gemm/gemm_uint16.cpp
+++ b/src/core/NEON/kernels/arm_gemm/gemm_uint16.cpp
@@ -25,20 +25,36 @@
#include "arm_gemm.hpp"
#include "gemm_common.hpp"
+#include "gemm_implementation.hpp"
#include "gemm_interleaved.hpp"
#include "kernels/a64_gemm_u16_12x8.hpp"
namespace arm_gemm {
+class GemmImpl_gemm_u16_interleaved : public GemmImplementation<uint16_t, uint32_t> {
+public:
+ UniqueGemmCommon<uint16_t, uint32_t> instantiate(const GemmArgs<uint32_t> &args) override {
+ return UniqueGemmCommon<uint16_t, uint32_t>(new GemmInterleaved<gemm_u16_12x8, uint16_t, uint32_t>(args));
+ }
+
+ GemmImpl_gemm_u16_interleaved() : GemmImplementation<uint16_t, uint32_t>(GemmMethod::GEMM_INTERLEAVED) { }
+};
+
+static std::vector<GemmImplementation<uint16_t, uint32_t> *> gemm_u16_methods = {
+ new GemmImpl_gemm_u16_interleaved()
+};
+
template<>
-UniqueGemmCommon<uint16_t, uint32_t> gemm<uint16_t, uint32_t>(const CPUInfo &ci, const unsigned int M, const unsigned int N, const unsigned int K,
- const unsigned int nbatches, const unsigned int nmulti,
- const bool trA, const bool trB, uint32_t alpha, uint32_t beta,
- const int maxthreads, const bool pretransposed_hint) {
- return UniqueGemmCommon<uint16_t, uint32_t>(new GemmInterleaved<gemm_u16_12x8, uint16_t, uint32_t>(&ci, M, N, K, nbatches, nmulti, trA, trB, alpha, beta, maxthreads, pretransposed_hint));
+std::vector<GemmImplementation<uint16_t, uint32_t> *> &gemm_implementation_list<uint16_t, uint32_t>() {
+ return gemm_u16_methods;
}
+/* Explicitly instantiate the external functions for these types. */
+template UniqueGemmCommon<uint16_t, uint32_t> gemm<uint16_t, uint32_t>(GemmArgs<uint32_t> &args, GemmConfig *cfg);
+template GemmMethod get_gemm_method<uint16_t, uint32_t>(GemmArgs<uint32_t> &args);
+template bool method_is_compatible<uint16_t, uint32_t>(GemmMethod method, GemmArgs<uint32_t> &args);
+
} // namespace arm_gemm
#endif // __aarch64__
diff --git a/src/core/NEON/kernels/arm_gemm/gemm_uint8.cpp b/src/core/NEON/kernels/arm_gemm/gemm_uint8.cpp
index 1ca92f9d4e..d97dd5c3de 100644
--- a/src/core/NEON/kernels/arm_gemm/gemm_uint8.cpp
+++ b/src/core/NEON/kernels/arm_gemm/gemm_uint8.cpp
@@ -25,32 +25,52 @@
#include "arm_gemm.hpp"
#include "gemm_common.hpp"
+#include "gemm_implementation.hpp"
#include "gemm_interleaved.hpp"
-#include "kernels/a64_gemm_u8_4x4.hpp"
+#include "kernels/a64_gemm_u16_12x8.hpp"
#include "kernels/a64_gemm_u8_12x8.hpp"
+#include "kernels/a64_gemm_u8_4x4.hpp"
namespace arm_gemm {
-template<>
-UniqueGemmCommon<uint8_t, uint32_t> gemm<uint8_t, uint32_t>(const CPUInfo &ci, const unsigned int M, const unsigned int N, const unsigned int K,
- const unsigned int nbatches, const unsigned int nmulti,
- const bool trA, const bool trB, const uint32_t alpha, const uint32_t beta,
- const int maxthreads, const bool pretransposed_hint) {
- if (ci.has_dotprod()) {
- // Dot product supporting CPUs. This family has a special version for A55r1.
- return UniqueGemmCommon<uint8_t, uint32_t>(new GemmInterleaved<gemm_u8_12x8, uint8_t, uint32_t>(&ci, M, N, K, nbatches, nmulti, trA, trB, alpha, beta, maxthreads, pretransposed_hint));
+class GemmImpl_gemm_u8_interleaved_dot : public GemmImplementation<uint8_t, uint32_t> {
+public:
+ bool is_supported(const GemmArgs<uint32_t> &args) override {
+ return args._ci->has_dotprod();
}
- // Non dot-product code.
- return UniqueGemmCommon<uint8_t, uint32_t>(new GemmInterleaved<gemm_u8_4x4, uint8_t, uint32_t>(&ci, M, N, K, nbatches, nmulti, trA, trB, alpha, beta, maxthreads, pretransposed_hint));
+ UniqueGemmCommon<uint8_t, uint32_t> instantiate(const GemmArgs<uint32_t> &args) override {
+ return UniqueGemmCommon<uint8_t, uint32_t>(new GemmInterleaved<gemm_u8_12x8, uint8_t, uint32_t>(args));
+ }
+
+ GemmImpl_gemm_u8_interleaved_dot() : GemmImplementation<uint8_t, uint32_t>(GemmMethod::GEMM_INTERLEAVED_DOT) { }
+};
- // TODO: There's a better approach for A53, but it doesn't work
- // well on heterogeneous systems as the required data formats
- // are different. Figure out how to enable this:
- // gemm = new GemmInterleaved<gemm_s16_12x8, int8_t, int32_t>(ci, M, N, K, trA, trB);
+class GemmImpl_gemm_u8_interleaved : public GemmImplementation<uint8_t, uint32_t> {
+public:
+ UniqueGemmCommon<uint8_t, uint32_t> instantiate(const GemmArgs<uint32_t> &args) override {
+ return UniqueGemmCommon<uint8_t, uint32_t>(new GemmInterleaved<gemm_u8_4x4, uint8_t, uint32_t>(args));
+ }
+
+ GemmImpl_gemm_u8_interleaved() : GemmImplementation<uint8_t, uint32_t>(GemmMethod::GEMM_INTERLEAVED) { }
+};
+
+static std::vector<GemmImplementation<uint8_t, uint32_t> *> gemm_u8_methods = {
+ new GemmImpl_gemm_u8_interleaved_dot(),
+ new GemmImpl_gemm_u8_interleaved()
+};
+
+template<>
+std::vector<GemmImplementation<uint8_t, uint32_t> *> &gemm_implementation_list<uint8_t, uint32_t>() {
+ return gemm_u8_methods;
}
+/* Explicitly instantiate the external functions for these types. */
+template UniqueGemmCommon<uint8_t, uint32_t> gemm<uint8_t, uint32_t>(GemmArgs<uint32_t> &args, GemmConfig *cfg);
+template GemmMethod get_gemm_method<uint8_t, uint32_t>(GemmArgs<uint32_t> &args);
+template bool method_is_compatible<uint8_t, uint32_t>(GemmMethod method, GemmArgs<uint32_t> &args);
+
} // namespace arm_gemm
-#endif // aarch64
+#endif // __aarch64__
diff --git a/src/core/NEON/kernels/arm_gemm/gemv_batched.hpp b/src/core/NEON/kernels/arm_gemm/gemv_batched.hpp
index d91b44b9a8..d65971e47d 100644
--- a/src/core/NEON/kernels/arm_gemm/gemv_batched.hpp
+++ b/src/core/NEON/kernels/arm_gemm/gemv_batched.hpp
@@ -36,11 +36,12 @@ private:
UniqueGemmCommon<To, Tr> _subgemm = nullptr;
public:
- GemvBatched(const CPUInfo &ci, const unsigned int M, const unsigned int N, const unsigned int K,
- const unsigned int nbatches, const unsigned int nmulti, const bool trA, const bool trB,
- const To alpha, const To beta, const int maxthreads, const bool pretransposed_hint) {
+ GemvBatched(const GemmArgs<Tr> &args) {
/* Just create a subgemm with batches->M */
- _subgemm = gemm<To,Tr>(ci, nbatches, N, K, 1, nmulti, trA, trB, alpha, beta, maxthreads, pretransposed_hint);
+ GemmArgs<Tr> newargs = args;
+ newargs._Msize = args._nbatches;
+ newargs._nbatches = 1;
+ _subgemm = gemm<To,Tr>(newargs, nullptr);
}
void set_arrays(const To *A, const int lda, const int A_batch_stride, const int A_multi_stride,
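
GemvBatched's constructor now just rewrites the argument struct: a batched GEMV (Msize == 1, nbatches > 1) becomes a single sub-GEMM with the batch count folded into M, and the class's set_arrays() override is responsible for presenting pointers and strides that make the two layouts line up. A standalone sketch of the argument rewrite (MiniArgs is an illustrative stand-in for GemmArgs):

#include <cassert>
#include <iostream>

struct MiniArgs {
    unsigned int Msize;
    unsigned int nbatches;
};

MiniArgs fold_batches_into_m(const MiniArgs &args) {
    assert(args.Msize == 1 && args.nbatches > 1);
    MiniArgs newargs = args;           // everything else copied unchanged
    newargs.Msize    = args.nbatches;  // each batch becomes one GEMM row
    newargs.nbatches = 1;
    return newargs;
}

int main() {
    MiniArgs sub = fold_batches_into_m({ /*Msize=*/1, /*nbatches=*/8 });
    std::cout << "sub-GEMM: M=" << sub.Msize
              << " nbatches=" << sub.nbatches << "\n";   // M=8 nbatches=1
}
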
diff --git a/src/core/NEON/kernels/arm_gemm/kernels/a64_sgemm_native_16x4.hpp b/src/core/NEON/kernels/arm_gemm/kernels/a64_sgemm_native_16x4.hpp
index 11a589d75c..1a3596511b 100644
--- a/src/core/NEON/kernels/arm_gemm/kernels/a64_sgemm_native_16x4.hpp
+++ b/src/core/NEON/kernels/arm_gemm/kernels/a64_sgemm_native_16x4.hpp
@@ -46,9 +46,17 @@ public:
typedef void (*kern_type)(const float *, int, const float *, int, float *, int, float, int, int, int);
/* Kernel blocking parameters */
- static const int out_width = 16;
- static const int out_height = 4;
- static const int k_unroll = 1;
+ static int out_width() {
+ return 16;
+ }
+
+ static int out_height() {
+ return 4;
+ }
+
+ static int k_unroll() {
+ return 1;
+ }
// Default to the generic kernel
kern_type kernel=a64_sgemm_native_16x4;
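
A plausible motivation for turning out_width/out_height/k_unroll into static member functions, consistent with this patch also dropping the out-of-class definitions ("const int sgemm_native_16x4::out_width;" and friends) from gemm_fp32.cpp: a C++11 static const int member that is odr-used, for example by binding to std::min's const reference parameters, requires exactly such a definition in some translation unit, whereas a function returning the constant does not, while still folding to a compile-time constant. A standalone sketch:

#include <algorithm>
#include <iostream>

struct kernel_old {
    // odr-use requires "const int kernel_old::out_height;" in some .cpp file.
    static const int out_height = 4;
};

struct kernel_new {
    // No out-of-class definition needed; inlines to the constant 4.
    static int out_height() { return 4; }
};

int main() {
    int rows = 7;
    // std::min(rows, kernel_old::out_height);  // binds const int&: odr-use,
    //                                          // can fail to link
    std::cout << std::min(rows, kernel_new::out_height()) << "\n";  // prints 4
}
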