aboutsummaryrefslogtreecommitdiff
path: root/src/core/NEON/kernels/arm_gemm
diff options
context:
space:
mode:
authorDavid Mansell <David.Mansell@arm.com>2018-07-06 17:53:35 +0100
committerAnthony Barbier <anthony.barbier@arm.com>2018-11-02 16:54:10 +0000
commite39334c15c7fd141bb8173d5017ea5ca157fca2c (patch)
treefffa2f7b136525037c4d99586bc194374e5bd3dc /src/core/NEON/kernels/arm_gemm
parente8bd2c729546e59aa0adc241976ea91fc6f25b52 (diff)
downloadComputeLibrary-e39334c15c7fd141bb8173d5017ea5ca157fca2c.tar.gz
COMPMID-1271: New system for GEMM heuristics
This patch implements a system for separating the "validity" from "preferred" aspect of the current heuristics in gemm_*.cpp. Now, each gemm_*.cpp defines a list of candidate implementations, each of which supplies an is_valid() function (to check for validity), an is_preferred() function (the "heuristic" part), and an instantiate() function which actually produces the GemmCommon object pointer. The actual gemm() function is now templated and uses this list to select an implementation. This patch also implements a mechanism to identify the preferred implementation, and override it via the GemmConfig structure. Change-Id: Id49ab7af8bf2e3e9fd951a9698883ade234d40e1 Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/139120 Reviewed-by: Anthony Barbier <anthony.barbier@arm.com> Tested-by: Jenkins <bsgcomp@arm.com>
Diffstat (limited to 'src/core/NEON/kernels/arm_gemm')
-rw-r--r--src/core/NEON/kernels/arm_gemm/gemm_fp16.cpp61
-rw-r--r--src/core/NEON/kernels/arm_gemm/gemm_fp32.cpp104
-rw-r--r--src/core/NEON/kernels/arm_gemm/gemm_implementation.hpp131
-rw-r--r--src/core/NEON/kernels/arm_gemm/gemm_int16.cpp26
-rw-r--r--src/core/NEON/kernels/arm_gemm/gemm_int8.cpp50
-rw-r--r--src/core/NEON/kernels/arm_gemm/gemm_interleaved.hpp27
-rw-r--r--src/core/NEON/kernels/arm_gemm/gemm_native.hpp62
-rw-r--r--src/core/NEON/kernels/arm_gemm/gemm_uint16.cpp26
-rw-r--r--src/core/NEON/kernels/arm_gemm/gemm_uint8.cpp52
-rw-r--r--src/core/NEON/kernels/arm_gemm/gemv_batched.hpp9
-rw-r--r--src/core/NEON/kernels/arm_gemm/kernels/a64_sgemm_native_16x4.hpp14
11 files changed, 425 insertions, 137 deletions
diff --git a/src/core/NEON/kernels/arm_gemm/gemm_fp16.cpp b/src/core/NEON/kernels/arm_gemm/gemm_fp16.cpp
index 65f43f302b..829ae325a9 100644
--- a/src/core/NEON/kernels/arm_gemm/gemm_fp16.cpp
+++ b/src/core/NEON/kernels/arm_gemm/gemm_fp16.cpp
@@ -28,6 +28,7 @@
#include "arm_gemm.hpp"
#include "gemm_common.hpp"
+#include "gemm_implementation.hpp"
#include "gemm_interleaved.hpp"
#include "kernels/a64_hgemm_24x8.hpp"
@@ -36,37 +37,59 @@
namespace arm_gemm {
-template<>
-UniqueGemmCommon<__fp16, __fp16> gemm(const CPUInfo &ci, const unsigned int M, const unsigned int N, const unsigned int K,
- const unsigned int nbatches, const unsigned int nmulti,
- const bool trA, const bool trB, const __fp16 alpha, const __fp16 beta,
- const int maxthreads, const bool pretransposed_hint) {
#ifdef __aarch64__
-// Only consider the native FP16 kernel if it will get built.
#if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) || defined(FP16_KERNELS)
-#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
- // If the compiler is configured to enable this feature always, then assume it is available at runtime too.
- const bool use_fp16=true;
-#else
- // Otherwise, detect at runtime via CPUInfo.
- const bool use_fp16=ci.has_fp16();
+class GemmImpl_gemm_fp16_interleaved_fp16 : public GemmImplementation<__fp16, __fp16> {
+public:
+#ifndef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
+ bool is_supported(const GemmArgs<__fp16> &args) override {
+ return args._ci->has_fp16();
+ }
#endif
- // If FP16 is supported, use it.
- if (use_fp16) {
- return UniqueGemmCommon<__fp16, __fp16>(new GemmInterleaved<hgemm_24x8, __fp16, __fp16>(&ci, M, N, K, nbatches, nmulti, trA, trB, alpha, beta, maxthreads, pretransposed_hint));
+ UniqueGemmCommon<__fp16, __fp16> instantiate(const GemmArgs<__fp16> &args) override {
+ return UniqueGemmCommon<__fp16, __fp16>(new GemmInterleaved<hgemm_24x8, __fp16, __fp16>(args));
}
+
+ GemmImpl_gemm_fp16_interleaved_fp16() : GemmImplementation<__fp16, __fp16>(GemmMethod::GEMM_INTERLEAVED_FP16) { }
+};
#endif
- // Fallback to using the blocked SGEMM kernel.
- return UniqueGemmCommon<__fp16, __fp16>(new GemmInterleaved<sgemm_12x8, __fp16, __fp16>(&ci, M, N, K, nbatches, nmulti, trA, trB, alpha, beta, maxthreads, pretransposed_hint));
+#endif // __aarch64__
+
+class GemmImpl_gemm_fp16_interleaved : public GemmImplementation<__fp16, __fp16> {
+public:
+ UniqueGemmCommon<__fp16, __fp16> instantiate(const GemmArgs<__fp16> &args) override {
+#ifdef __aarch64__
+ return UniqueGemmCommon<__fp16, __fp16>(new GemmInterleaved<sgemm_12x8, __fp16, __fp16>(args));
+#elif defined(__arm__)
+ return UniqueGemmCommon<__fp16, __fp16>(new GemmInterleaved<sgemm_8x6, __fp16, __fp16>(args));
#else
- // For AArch32, only support the SGEMM route for now.
- return UniqueGemmCommon<__fp16, __fp16>(new GemmInterleaved<sgemm_8x6, __fp16, __fp16>(&ci, M, N, K, nbatches, nmulti, trA, trB, alpha, beta, maxthreads, pretransposed_hint));
+# error Unknown Architecture
#endif
+ }
+
+ GemmImpl_gemm_fp16_interleaved() : GemmImplementation<__fp16, __fp16>(GemmMethod::GEMM_INTERLEAVED) { }
+};
+
+static std::vector<GemmImplementation<__fp16, __fp16> *> gemm_fp16_methods = {
+#if defined(__aarch64__) && (defined(__ARM_FEATURE_VECTOR_ARITHMETIC) || defined(FP16_KERNELS))
+ new GemmImpl_gemm_fp16_interleaved_fp16(),
+#endif
+ new GemmImpl_gemm_fp16_interleaved()
+};
+
+template<>
+std::vector<GemmImplementation<__fp16, __fp16> *> &gemm_implementation_list<__fp16, __fp16>() {
+ return gemm_fp16_methods;
}
+/* Explicitly instantiate the external functions for these types. */
+template UniqueGemmCommon<__fp16, __fp16> gemm<__fp16, __fp16>(GemmArgs<__fp16> &args, GemmConfig *cfg);
+template GemmMethod get_gemm_method<__fp16, __fp16>(GemmArgs<__fp16> &args);
+template bool method_is_compatible<__fp16, __fp16>(GemmMethod method, GemmArgs<__fp16> &args);
+
} // namespace arm_gemm
#endif // __ARM_FP16_ARGS
diff --git a/src/core/NEON/kernels/arm_gemm/gemm_fp32.cpp b/src/core/NEON/kernels/arm_gemm/gemm_fp32.cpp
index 2fd040efbe..6e47adbaa4 100644
--- a/src/core/NEON/kernels/arm_gemm/gemm_fp32.cpp
+++ b/src/core/NEON/kernels/arm_gemm/gemm_fp32.cpp
@@ -23,6 +23,7 @@
*/
#include "arm_gemm.hpp"
#include "gemm_common.hpp"
+#include "gemm_implementation.hpp"
#include "gemm_interleaved.hpp"
#include "gemm_native.hpp"
#include "gemv_batched.hpp"
@@ -37,47 +38,92 @@
namespace arm_gemm {
-template<>
-UniqueGemmCommon<float, float> gemm<float, float>(const CPUInfo &ci, const unsigned int M, const unsigned int N, const unsigned int K,
- const unsigned int nbatches, const unsigned int nmulti,
- const bool trA, const bool trB, const float alpha, const float beta,
- const int maxthreads, const bool pretransposed_hint) {
- /* Handle "batched GEMV" */
- if (M==1 && nbatches>1) {
- return UniqueGemmCommon<float, float> (new GemvBatched<float, float>(ci, M, N, K, nbatches, nmulti, trA, trB, alpha, beta, maxthreads, pretransposed_hint));
+#ifdef __aarch64__
+// SGEMM implementations for AArch64
+
+// Pretransposed GEMV
+class GemmImpl_sgemm_gemv_pretransposed : public GemmImplementation<float, float> {
+public:
+ bool is_supported(const GemmArgs<float> &args) override {
+ return (args._Msize==1 && args._alpha==1.0f && args._pretransposed_hint && args._nbatches==1);
}
-#ifdef __aarch64__
- /* Cases in priority order */
- /* GemvPretransposed: requires M=1, alpha=1, and transposed hint set. nbatches must be 1 or we would have returned above so don't test. */
- if (M==1 && alpha==1.0f && pretransposed_hint) {
- return UniqueGemmCommon<float, float> (new GemvPretransposed<sgemv_pretransposed, float, float>(&ci, N, K, nmulti, trB, beta));
+ UniqueGemmCommon<float, float> instantiate(const GemmArgs<float> &args) override {
+ return UniqueGemmCommon<float, float> (new GemvPretransposed<sgemv_pretransposed, float, float>(args._ci, args._Nsize, args._Ksize, args._nmulti, args._trB, args._beta));
}
- /* GemvNativeTransposed: requires M=1, no trA or trB, doesn't handle alpha */
- if (M==1 && alpha==1.0f && !trA && !trB) {
- return UniqueGemmCommon<float, float> (new GemvNativeTransposed<sgemv_trans, float, float>(&ci, N, K, nmulti, beta));
+ GemmImpl_sgemm_gemv_pretransposed() : GemmImplementation<float, float>(GemmMethod::GEMV_PRETRANSPOSED) { }
+};
+
+// Native GEMV
+class GemmImpl_sgemm_gemv_native_transposed : public GemmImplementation<float, float> {
+public:
+ bool is_supported(const GemmArgs<float> &args) override {
+ return (args._Msize==1 && args._alpha==1.0f && !args._trA && !args._trB && args._nbatches==1);
}
- /* Native GEMM: requires K at least 4, N a multiple of 16, doesn't
- * handle alpha or transpose. Use for small N/K, or if the blocked GEMM
- * won't thread properly. */
- if ((K >= 4) && ((N % 16) == 0) && alpha==1.0f && !trA && !trB &&
- ((K <= 128 && N <= 128) || (nmulti > 1 && (M/maxthreads) < 8))) {
- return UniqueGemmCommon<float, float> (new GemmNative<sgemm_native_16x4, float, float>(&ci, M, N, K, nbatches, nmulti, beta));
+ UniqueGemmCommon<float, float> instantiate(const GemmArgs<float> &args) override {
+ return UniqueGemmCommon<float, float> (new GemvNativeTransposed<sgemv_trans, float, float>(args._ci, args._Nsize, args._Ksize, args._nmulti, args._beta));
}
- /* Blocked GEMM, handles all cases. */
- return UniqueGemmCommon<float, float> (new GemmInterleaved<sgemm_12x8, float, float>(&ci, M, N, K, nbatches, nmulti, trA, trB, alpha, beta, maxthreads, pretransposed_hint));
+ GemmImpl_sgemm_gemv_native_transposed() : GemmImplementation<float, float>(GemmMethod::GEMV_NATIVE_TRANSPOSED) { }
+};
+
+// Native GEMM
+class GemmImpl_sgemm_gemm_native : public GemmImplementation<float, float> {
+public:
+ bool is_supported(const GemmArgs<float> &args) override {
+ return (args._Ksize>4 && (args._Nsize % 16)==0 && args._alpha==1.0f && !args._trA && !args._trB);
+ }
+
+ bool is_recommended(const GemmArgs<float> &args) override {
+ return ((args._Ksize <= 128) && (args._Nsize <= 128)) || ((args._nmulti > 1) && ((args._Msize / args._maxthreads) < 8));
+ }
+
+ UniqueGemmCommon<float, float> instantiate(const GemmArgs<float> &args) override {
+ return UniqueGemmCommon<float, float> (new GemmNative<sgemm_native_16x4, float, float>(args._ci, args._Msize, args._Nsize, args._Ksize, args._nbatches, args._nmulti, args._beta));
+ }
+
+ GemmImpl_sgemm_gemm_native() : GemmImplementation<float, float>(GemmMethod::GEMM_NATIVE) { }
+};
+#endif // __aarch64__
+
+// Interleaved GEMM
+class GemmImpl_sgemm_gemm_interleaved : public GemmImplementation<float, float> {
+public:
+ UniqueGemmCommon<float, float> instantiate(const GemmArgs<float> &args) override {
+#ifdef __aarch64__
+ return UniqueGemmCommon<float, float> (new GemmInterleaved<sgemm_12x8, float, float>(args));
+#elif defined(__arm__)
+ return UniqueGemmCommon<float, float> (new GemmInterleaved<sgemm_8x6, float, float>(args));
#else
- return UniqueGemmCommon<float, float> (new GemmInterleaved<sgemm_8x6, float, float>(&ci, M, N, K, nbatches, nmulti, trA, trB, alpha, beta, maxthreads, pretransposed_hint));
+# error Unknown Architecture.
#endif
-}
+ }
-// Instantiate static class variables.
+ GemmImpl_sgemm_gemm_interleaved() : GemmImplementation<float, float>(GemmMethod::GEMM_INTERLEAVED) { }
+};
+
+/* List of implementations (order matters) */
+static std::vector<GemmImplementation<float, float> *> SGemmMethods = {
+ new GemmImpl_gemv_batched<float, float>(),
#ifdef __aarch64__
-const int sgemm_native_16x4::out_width;
-const int sgemm_native_16x4::out_height;
+ new GemmImpl_sgemm_gemv_pretransposed(),
+ new GemmImpl_sgemm_gemv_native_transposed(),
+ new GemmImpl_sgemm_gemm_native(),
#endif
+ new GemmImpl_sgemm_gemm_interleaved()
+};
+
+/* Templated function to return this list. */
+template<>
+std::vector<GemmImplementation<float, float> *> &gemm_implementation_list<float, float>() {
+ return SGemmMethods;
+}
+
+/* Explicitly instantiate the external functions for these types. */
+template UniqueGemmCommon<float, float> gemm<float, float>(GemmArgs<float> &args, GemmConfig *cfg);
+template GemmMethod get_gemm_method<float, float>(GemmArgs<float> &args);
+template bool method_is_compatible<float, float>(GemmMethod method, GemmArgs<float> &args);
} // namespace arm_gemm
diff --git a/src/core/NEON/kernels/arm_gemm/gemm_implementation.hpp b/src/core/NEON/kernels/arm_gemm/gemm_implementation.hpp
new file mode 100644
index 0000000000..6734e3cce0
--- /dev/null
+++ b/src/core/NEON/kernels/arm_gemm/gemm_implementation.hpp
@@ -0,0 +1,131 @@
+/*
+ * Copyright (c) 2018 ARM Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#include "gemv_batched.hpp"
+
+namespace arm_gemm {
+
+template<typename Top, typename Tret>
+class GemmImplementation {
+public:
+ /* Is this implementation compatible with the args as provided? */
+ virtual bool is_supported(const GemmArgs<Tret> &args) { return true; }
+ /* Is this implementation "recommended" for these args (heuristic)? */
+ virtual bool is_recommended(const GemmArgs<Tret> &args) { return true; }
+ /* Instantiate this method please. */
+ virtual UniqueGemmCommon<Top, Tret> instantiate(const GemmArgs<Tret> &args) = 0;
+
+ /* Indicate the "GemmMethod" for use as a selector */
+ const GemmMethod method;
+
+ virtual ~GemmImplementation() { }
+
+ GemmImplementation(GemmMethod method) : method(method) { }
+};
+
+/* "gemv_batched" implementation is type-agnostic, so template it here. */
+template<typename Top, typename Tret>
+class GemmImpl_gemv_batched : public GemmImplementation<Top, Tret> {
+public:
+ bool is_supported(const GemmArgs<Tret> &args) override {
+ return (args._Msize==1 && args._nbatches > 1);
+ }
+
+ UniqueGemmCommon<Top, Tret> instantiate(const GemmArgs<Tret> &args) override {
+ return UniqueGemmCommon<Top, Tret> (new GemvBatched<Top, Tret>(args));
+ }
+
+ GemmImpl_gemv_batched() : GemmImplementation<Top, Tret>(GemmMethod::GEMV_BATCHED) { }
+};
+
+/* "Master" function implemented for each valid combination of types.
+ * Returns a list of GEMM implementation descriptors for processing by the
+ * other functions. */
+template<typename Top, typename Tret>
+std::vector<GemmImplementation<Top, Tret> *> &gemm_implementation_list();
+
+template<typename Top, typename Tret>
+GemmImplementation<Top, Tret> *find_implementation(GemmArgs<Tret> &args, GemmConfig *cfg) {
+ auto gemms = gemm_implementation_list<Top, Tret>();
+
+ for(auto &&i : gemms) {
+ /* Skip if this implementation doesn't support these args. */
+ if (!i->is_supported(args)) {
+ continue;
+ }
+
+ /* Skip if a specific method is requested and this is a different one. */
+ if (cfg && cfg->method != GemmMethod::DEFAULT && i->method != cfg->method) {
+ continue;
+ }
+
+ /* If no specific method is requested, check that this method recommends itself. */
+ if ((!cfg || cfg->method == GemmMethod::DEFAULT) && !i->is_recommended(args)) {
+ continue;
+ }
+
+ return i;
+ }
+
+ return nullptr;
+}
+
+template<typename Top, typename Tret>
+UniqueGemmCommon<Top, Tret> gemm(GemmArgs<Tret> &args, GemmConfig *cfg) {
+ auto impl = find_implementation<Top, Tret>(args, cfg);
+
+ if (impl) {
+ return impl->instantiate(args);
+ }
+
+ return UniqueGemmCommon<Top, Tret>(nullptr);
+}
+
+template<typename Top, typename Tret>
+GemmMethod get_gemm_method(GemmArgs<Tret> &args) {
+ auto impl = find_implementation<Top, Tret>(args, nullptr);
+
+ if (impl) {
+ return impl->method;
+ }
+
+ /* This shouldn't happen - there should always be at least one valid implementation. */
+ return GemmMethod::DEFAULT;
+}
+
+template<typename Top, typename Tret>
+bool method_is_compatible(GemmMethod method, GemmArgs<Tret> &args) {
+ /* Determine if the method is valid by attempting to obtain an implementation specifying this method. */
+ GemmConfig cfg(method);
+
+ auto impl = find_implementation<Top, Tret>(args, &cfg);
+
+ if (impl) {
+ return true;
+ }
+
+ return false;
+}
+
+} // namespace arm_gemm
diff --git a/src/core/NEON/kernels/arm_gemm/gemm_int16.cpp b/src/core/NEON/kernels/arm_gemm/gemm_int16.cpp
index 57cd15f698..f61cc1358f 100644
--- a/src/core/NEON/kernels/arm_gemm/gemm_int16.cpp
+++ b/src/core/NEON/kernels/arm_gemm/gemm_int16.cpp
@@ -25,20 +25,36 @@
#include "arm_gemm.hpp"
#include "gemm_common.hpp"
+#include "gemm_implementation.hpp"
#include "gemm_interleaved.hpp"
#include "kernels/a64_gemm_s16_12x8.hpp"
namespace arm_gemm {
+class GemmImpl_gemm_s16_interleaved : public GemmImplementation<int16_t, int32_t> {
+public:
+ UniqueGemmCommon<int16_t, int32_t> instantiate(const GemmArgs<int32_t> &args) override {
+ return UniqueGemmCommon<int16_t, int32_t>(new GemmInterleaved<gemm_s16_12x8, int16_t, int32_t>(args));
+ }
+
+ GemmImpl_gemm_s16_interleaved() : GemmImplementation<int16_t, int32_t>(GemmMethod::GEMM_INTERLEAVED) { }
+};
+
+static std::vector<GemmImplementation<int16_t, int32_t> *> gemm_s16_methods = {
+ new GemmImpl_gemm_s16_interleaved()
+};
+
template<>
-UniqueGemmCommon<int16_t, int32_t> gemm<int16_t, int32_t>(const CPUInfo &ci, const unsigned int M, const unsigned int N, const unsigned int K,
- const unsigned int nbatches, const unsigned int nmulti,
- const bool trA, const bool trB, const int32_t alpha, const int32_t beta,
- const int maxthreads, const bool pretransposed_hint) {
- return UniqueGemmCommon<int16_t, int32_t>(new GemmInterleaved<gemm_s16_12x8, int16_t, int32_t>(&ci, M, N, K, nbatches, nmulti, trA, trB, alpha, beta, maxthreads, pretransposed_hint));
+std::vector<GemmImplementation<int16_t, int32_t> *> &gemm_implementation_list<int16_t, int32_t>() {
+ return gemm_s16_methods;
}
+/* Explicitly instantiate the external functions for these types. */
+template UniqueGemmCommon<int16_t, int32_t> gemm<int16_t, int32_t>(GemmArgs<int32_t> &args, GemmConfig *cfg);
+template GemmMethod get_gemm_method<int16_t, int32_t>(GemmArgs<int32_t> &args);
+template bool method_is_compatible<int16_t, int32_t>(GemmMethod method, GemmArgs<int32_t> &args);
+
} // namespace arm_gemm
#endif // __aarch64__
diff --git a/src/core/NEON/kernels/arm_gemm/gemm_int8.cpp b/src/core/NEON/kernels/arm_gemm/gemm_int8.cpp
index 04803eb81a..f50b399de5 100644
--- a/src/core/NEON/kernels/arm_gemm/gemm_int8.cpp
+++ b/src/core/NEON/kernels/arm_gemm/gemm_int8.cpp
@@ -25,32 +25,52 @@
#include "arm_gemm.hpp"
#include "gemm_common.hpp"
+#include "gemm_implementation.hpp"
#include "gemm_interleaved.hpp"
-#include "kernels/a64_gemm_s8_4x4.hpp"
#include "kernels/a64_gemm_s16_12x8.hpp"
#include "kernels/a64_gemm_s8_12x8.hpp"
+#include "kernels/a64_gemm_s8_4x4.hpp"
namespace arm_gemm {
-template<>
-UniqueGemmCommon<int8_t, int32_t> gemm<int8_t, int32_t>(const CPUInfo &ci, const unsigned int M, const unsigned int N, const unsigned int K,
- const unsigned int nbatches, const unsigned int nmulti,
- const bool trA, const bool trB, const int32_t alpha, const int32_t beta,
- const int maxthreads, const bool pretransposed_hint) {
- if (ci.has_dotprod()) {
- // Dot product supporting CPUs. This family has a special version for A55r1.
- return UniqueGemmCommon<int8_t, int32_t>(new GemmInterleaved<gemm_s8_12x8, int8_t, int32_t>(&ci, M, N, K, nbatches, nmulti, trA, trB, alpha, beta, maxthreads, pretransposed_hint));
+class GemmImpl_gemm_s8_interleaved_dot : public GemmImplementation<int8_t, int32_t> {
+public:
+ bool is_supported(const GemmArgs<int32_t> &args) override {
+ return args._ci->has_dotprod();
}
- return UniqueGemmCommon<int8_t, int32_t>(new GemmInterleaved<gemm_s8_4x4, int8_t, int32_t>(&ci, M, N, K, nbatches, nmulti, trA, trB, alpha, beta, maxthreads, pretransposed_hint));
+ UniqueGemmCommon<int8_t, int32_t> instantiate(const GemmArgs<int32_t> &args) override {
+ return UniqueGemmCommon<int8_t, int32_t>(new GemmInterleaved<gemm_s8_12x8, int8_t, int32_t>(args));
+ }
+
+ GemmImpl_gemm_s8_interleaved_dot() : GemmImplementation<int8_t, int32_t>(GemmMethod::GEMM_INTERLEAVED_DOT) { }
+};
- // TODO: There's a better approach for A53, but it doesn't work
- // well on heterogeneous systems as the required data formats
- // are different. Figure out how to enable this:
- // gemm = new GemmInterleaved<gemm_s16_12x8, int8_t, int32_t>(ci, M, N, K, trA, trB);
+class GemmImpl_gemm_s8_interleaved : public GemmImplementation<int8_t, int32_t> {
+public:
+ UniqueGemmCommon<int8_t, int32_t> instantiate(const GemmArgs<int32_t> &args) override {
+ return UniqueGemmCommon<int8_t, int32_t>(new GemmInterleaved<gemm_s8_4x4, int8_t, int32_t>(args));
+ }
+
+ GemmImpl_gemm_s8_interleaved() : GemmImplementation<int8_t, int32_t>(GemmMethod::GEMM_INTERLEAVED) { }
+};
+
+static std::vector<GemmImplementation<int8_t, int32_t> *> gemm_s8_methods = {
+ new GemmImpl_gemm_s8_interleaved_dot(),
+ new GemmImpl_gemm_s8_interleaved()
+};
+
+template<>
+std::vector<GemmImplementation<int8_t, int32_t> *> &gemm_implementation_list<int8_t, int32_t>() {
+ return gemm_s8_methods;
}
+/* Explicitly instantiate the external functions for these types. */
+template UniqueGemmCommon<int8_t, int32_t> gemm<int8_t, int32_t>(GemmArgs<int32_t> &args, GemmConfig *cfg);
+template GemmMethod get_gemm_method<int8_t, int32_t>(GemmArgs<int32_t> &args);
+template bool method_is_compatible<int8_t, int32_t>(GemmMethod method, GemmArgs<int32_t> &args);
+
} // namespace arm_gemm
-#endif // aarch64
+#endif // __aarch64__
diff --git a/src/core/NEON/kernels/arm_gemm/gemm_interleaved.hpp b/src/core/NEON/kernels/arm_gemm/gemm_interleaved.hpp
index c5a43e6519..0e58a4d01f 100644
--- a/src/core/NEON/kernels/arm_gemm/gemm_interleaved.hpp
+++ b/src/core/NEON/kernels/arm_gemm/gemm_interleaved.hpp
@@ -317,16 +317,15 @@ public:
GemmInterleaved & operator= (GemmInterleaved &) = delete;
/* Constructor */
- GemmInterleaved(const CPUInfo *ci, const unsigned int M, const unsigned int N, const unsigned int K,
- const unsigned int nbatches, const unsigned int nmulti, const bool trA, const bool trB,
- const Tr alpha, const Tr beta, const int maxthreads, const bool pretransposed) :
- _ci(ci), _Msize(M), _Nsize(N), _Ksize(K), _nbatches(nbatches), _nmulti(nmulti),
- _trA(trA), _trB(trB), _alpha(alpha), _beta(beta),
- _maxthreads(maxthreads), _nthreads(maxthreads), _pretransposed(pretransposed) {
- const unsigned int L1_size = ci->get_L1_cache_size();
- const unsigned int L2_size = ci->get_L2_cache_size();
+ GemmInterleaved(const GemmArgs<Tr> &args)
+ : _ci(args._ci), _Msize(args._Msize), _Nsize(args._Nsize), _Ksize(args._Ksize),
+ _nbatches(args._nbatches), _nmulti(args._nmulti), _trA(args._trA), _trB(args._trB),
+ _alpha(args._alpha), _beta(args._beta), _maxthreads(args._maxthreads), _nthreads(args._maxthreads),
+ _pretransposed(args._pretransposed_hint) {
+ const unsigned int L1_size = _ci->get_L1_cache_size();
+ const unsigned int L2_size = _ci->get_L2_cache_size();
- assert(maxthreads > 0);
+ assert(_maxthreads > 0);
// Work out blocking parameters
@@ -339,10 +338,10 @@ public:
_k_block = std::max(_k_block, 1U) * strategy::k_unroll();
// Now tune to presented problem size; this is how many blocks we need.
- int num_k_blocks = iceildiv(K, _k_block);
+ int num_k_blocks = iceildiv(_Ksize, _k_block);
// So divide the space equally into that many blocks.
- _k_block = iceildiv(K, num_k_blocks);
+ _k_block = iceildiv(_Ksize, num_k_blocks);
// And round UP to the K unroll level required.
_k_block = iceildiv(_k_block, strategy::k_unroll());
@@ -358,14 +357,14 @@ public:
_x_block = std::max(_x_block, 1U) * strategy::out_width();
// And tune to the presented problem size.
- int num_x_blocks = iceildiv(N, _x_block);
- _x_block = iceildiv(N, num_x_blocks);
+ int num_x_blocks = iceildiv(_Nsize, _x_block);
+ _x_block = iceildiv(_Nsize, num_x_blocks);
_x_block = iceildiv(_x_block, strategy::out_width());
_x_block *= strategy::out_width();
// Work out the rounded size of M - needed for some buffers.
- _Mround = iceildiv(M, strategy::out_height());
+ _Mround = iceildiv(_Msize, strategy::out_height());
_Mround *= strategy::out_height();
}
diff --git a/src/core/NEON/kernels/arm_gemm/gemm_native.hpp b/src/core/NEON/kernels/arm_gemm/gemm_native.hpp
index 6fed645d82..baa1316745 100644
--- a/src/core/NEON/kernels/arm_gemm/gemm_native.hpp
+++ b/src/core/NEON/kernels/arm_gemm/gemm_native.hpp
@@ -62,6 +62,14 @@ class GemmNative : public GemmCommon<To, Tr> {
unsigned int k_block=0;
unsigned int n_block=0;
+ unsigned int window_per_batch() const {
+ return iceildiv(_Msize, strategy::out_height());
+ }
+
+ unsigned int window_per_multi() const {
+ return window_per_batch() * _nbatches;
+ }
+
public:
GemmNative(GemmNative &) = delete;
GemmNative & operator= (GemmNative &) = delete;
@@ -73,9 +81,9 @@ public:
n_block = N;
}
- // Window is number of out_height blocks
+ // Window is amount per multi multiplied by total number of multis.
unsigned int get_window_size() const override {
- return iceildiv(_Msize, strategy::out_height) * _nbatches * _nmultis;
+ return window_per_multi() * _nmultis;
}
// Actually execute the GEMM.
@@ -85,39 +93,39 @@ public:
#endif
strategy strat(_ci);
- const unsigned int window_per_batch = iceildiv(_Msize, strategy::out_height);
- const unsigned int window_per_multi = window_per_batch * _nbatches;
-
- const unsigned int first_multi = start / window_per_multi;
- const unsigned int last_multi = end / window_per_multi;
-
- const unsigned int first_batch = (start - (first_multi * window_per_multi)) / window_per_batch;
- const unsigned int last_batch = (end - (last_multi * window_per_multi)) / window_per_batch;
-
- const unsigned int first_row = ((start - (first_multi * window_per_multi)) % window_per_batch) * strategy::out_height;
- const unsigned int last_row = ((end - (last_multi * window_per_multi)) % window_per_batch) * strategy::out_height;
-
static_assert(std::is_same<To, Toi>::value, "gemm_native: Operand types must be the same.");
static_assert(std::is_same<Tr, Tri>::value, "gemm_native: Result types must be the same.");
- for (unsigned int multi=first_multi; multi<=last_multi; multi++) {
- const unsigned int batch_0 = (multi == first_multi) ? first_batch : 0;
- const unsigned int batch_max = (multi == last_multi) ? last_batch : (_nbatches-1);
+ /* Compute starting point based on 'start' */
+ unsigned int multi = start / window_per_multi();
+ unsigned int multi_pos = start % window_per_multi();
+
+ unsigned int batch = multi_pos / window_per_batch();
+ unsigned int batch_pos = multi_pos % window_per_batch();
- for (unsigned int batch=batch_0; batch <= batch_max; batch++) {
- const unsigned int m_start = ((multi == first_multi) && (batch==first_batch)) ? first_row : 0;
- const unsigned int m_end = ((multi == last_multi) && (batch==last_batch)) ? last_row : _Msize;
+ unsigned int y0 = batch_pos * strategy::out_height();
- for (unsigned int y0=m_start; y0<m_end; y0+=strategy::out_height) {
- const unsigned int ymax = std::min(y0 + strategy::out_height, m_end);
+ for (unsigned int pos=start; pos<end; pos++) {
+ const unsigned int ymax = std::min(y0 + strategy::out_height(), _Msize);
#ifdef CYCLE_PROFILING
- auto p = prof.ScopedProfiler(PROFILE_KERNEL, (ymax-y0) * _Nsize * _Ksize);
+ auto p = prof.ScopedProfiler(PROFILE_KERNEL, (ymax-y0) * _Nsize * _Ksize);
#endif
- strat.kernel(this->_Aptr + (multi * this->_A_multi_stride) + (batch * this->_A_batch_stride) + (y0 * this->_lda), this->_lda,
- this->_Bptr + (multi * this->_B_multi_stride), this->_ldb,
- this->_Cptr + (multi * this->_C_multi_stride) + (batch * this->_C_batch_stride) + (y0 * this->_ldc), this->_ldc,
- _beta, (ymax-y0), _Nsize, _Ksize);
+ strat.kernel(this->_Aptr + (multi * this->_A_multi_stride) + (batch * this->_A_batch_stride) + (y0 * this->_lda), this->_lda,
+ this->_Bptr + (multi * this->_B_multi_stride), this->_ldb,
+ this->_Cptr + (multi * this->_C_multi_stride) + (batch * this->_C_batch_stride) + (y0 * this->_ldc), this->_ldc,
+ _beta, (ymax-y0), _Nsize, _Ksize);
+
+ /* Advance to next item */
+ y0 += strategy::out_height();
+
+ /* Check for batch/multi overflow */
+ if (y0 >= _Msize) {
+ y0=0;
+ batch++;
+ if (batch == _nbatches) {
+ batch=0;
+ multi++;
}
}
}
diff --git a/src/core/NEON/kernels/arm_gemm/gemm_uint16.cpp b/src/core/NEON/kernels/arm_gemm/gemm_uint16.cpp
index 6db55c02d0..f4b712c453 100644
--- a/src/core/NEON/kernels/arm_gemm/gemm_uint16.cpp
+++ b/src/core/NEON/kernels/arm_gemm/gemm_uint16.cpp
@@ -25,20 +25,36 @@
#include "arm_gemm.hpp"
#include "gemm_common.hpp"
+#include "gemm_implementation.hpp"
#include "gemm_interleaved.hpp"
#include "kernels/a64_gemm_u16_12x8.hpp"
namespace arm_gemm {
+class GemmImpl_gemm_u16_interleaved : public GemmImplementation<uint16_t, uint32_t> {
+public:
+ UniqueGemmCommon<uint16_t, uint32_t> instantiate(const GemmArgs<uint32_t> &args) override {
+ return UniqueGemmCommon<uint16_t, uint32_t>(new GemmInterleaved<gemm_u16_12x8, uint16_t, uint32_t>(args));
+ }
+
+ GemmImpl_gemm_u16_interleaved() : GemmImplementation<uint16_t, uint32_t>(GemmMethod::GEMM_INTERLEAVED) { }
+};
+
+static std::vector<GemmImplementation<uint16_t, uint32_t> *> gemm_u16_methods = {
+ new GemmImpl_gemm_u16_interleaved()
+};
+
template<>
-UniqueGemmCommon<uint16_t, uint32_t> gemm<uint16_t, uint32_t>(const CPUInfo &ci, const unsigned int M, const unsigned int N, const unsigned int K,
- const unsigned int nbatches, const unsigned int nmulti,
- const bool trA, const bool trB, uint32_t alpha, uint32_t beta,
- const int maxthreads, const bool pretransposed_hint) {
- return UniqueGemmCommon<uint16_t, uint32_t>(new GemmInterleaved<gemm_u16_12x8, uint16_t, uint32_t>(&ci, M, N, K, nbatches, nmulti, trA, trB, alpha, beta, maxthreads, pretransposed_hint));
+std::vector<GemmImplementation<uint16_t, uint32_t> *> &gemm_implementation_list<uint16_t, uint32_t>() {
+ return gemm_u16_methods;
}
+/* Explicitly instantiate the external functions for these types. */
+template UniqueGemmCommon<uint16_t, uint32_t> gemm<uint16_t, uint32_t>(GemmArgs<uint32_t> &args, GemmConfig *cfg);
+template GemmMethod get_gemm_method<uint16_t, uint32_t>(GemmArgs<uint32_t> &args);
+template bool method_is_compatible<uint16_t, uint32_t>(GemmMethod method, GemmArgs<uint32_t> &args);
+
} // namespace arm_gemm
#endif // __aarch64__
diff --git a/src/core/NEON/kernels/arm_gemm/gemm_uint8.cpp b/src/core/NEON/kernels/arm_gemm/gemm_uint8.cpp
index 1ca92f9d4e..d97dd5c3de 100644
--- a/src/core/NEON/kernels/arm_gemm/gemm_uint8.cpp
+++ b/src/core/NEON/kernels/arm_gemm/gemm_uint8.cpp
@@ -25,32 +25,52 @@
#include "arm_gemm.hpp"
#include "gemm_common.hpp"
+#include "gemm_implementation.hpp"
#include "gemm_interleaved.hpp"
-#include "kernels/a64_gemm_u8_4x4.hpp"
+#include "kernels/a64_gemm_u16_12x8.hpp"
#include "kernels/a64_gemm_u8_12x8.hpp"
+#include "kernels/a64_gemm_u8_4x4.hpp"
namespace arm_gemm {
-template<>
-UniqueGemmCommon<uint8_t, uint32_t> gemm<uint8_t, uint32_t>(const CPUInfo &ci, const unsigned int M, const unsigned int N, const unsigned int K,
- const unsigned int nbatches, const unsigned int nmulti,
- const bool trA, const bool trB, const uint32_t alpha, const uint32_t beta,
- const int maxthreads, const bool pretransposed_hint) {
- if (ci.has_dotprod()) {
- // Dot product supporting CPUs. This family has a special version for A55r1.
- return UniqueGemmCommon<uint8_t, uint32_t>(new GemmInterleaved<gemm_u8_12x8, uint8_t, uint32_t>(&ci, M, N, K, nbatches, nmulti, trA, trB, alpha, beta, maxthreads, pretransposed_hint));
+class GemmImpl_gemm_u8_interleaved_dot : public GemmImplementation<uint8_t, uint32_t> {
+public:
+ bool is_supported(const GemmArgs<uint32_t> &args) override {
+ return args._ci->has_dotprod();
}
- // Non dot-product code.
- return UniqueGemmCommon<uint8_t, uint32_t>(new GemmInterleaved<gemm_u8_4x4, uint8_t, uint32_t>(&ci, M, N, K, nbatches, nmulti, trA, trB, alpha, beta, maxthreads, pretransposed_hint));
+ UniqueGemmCommon<uint8_t, uint32_t> instantiate(const GemmArgs<uint32_t> &args) override {
+ return UniqueGemmCommon<uint8_t, uint32_t>(new GemmInterleaved<gemm_u8_12x8, uint8_t, uint32_t>(args));
+ }
+
+ GemmImpl_gemm_u8_interleaved_dot() : GemmImplementation<uint8_t, uint32_t>(GemmMethod::GEMM_INTERLEAVED_DOT) { }
+};
- // TODO: There's a better approach for A53, but it doesn't work
- // well on heterogeneous systems as the required data formats
- // are different. Figure out how to enable this:
- // gemm = new GemmInterleaved<gemm_s16_12x8, int8_t, int32_t>(ci, M, N, K, trA, trB);
+class GemmImpl_gemm_u8_interleaved : public GemmImplementation<uint8_t, uint32_t> {
+public:
+ UniqueGemmCommon<uint8_t, uint32_t> instantiate(const GemmArgs<uint32_t> &args) override {
+ return UniqueGemmCommon<uint8_t, uint32_t>(new GemmInterleaved<gemm_u8_4x4, uint8_t, uint32_t>(args));
+ }
+
+ GemmImpl_gemm_u8_interleaved() : GemmImplementation<uint8_t, uint32_t>(GemmMethod::GEMM_INTERLEAVED) { }
+};
+
+static std::vector<GemmImplementation<uint8_t, uint32_t> *> gemm_u8_methods = {
+ new GemmImpl_gemm_u8_interleaved_dot(),
+ new GemmImpl_gemm_u8_interleaved()
+};
+
+template<>
+std::vector<GemmImplementation<uint8_t, uint32_t> *> &gemm_implementation_list<uint8_t, uint32_t>() {
+ return gemm_u8_methods;
}
+/* Explicitly instantiate the external functions for these types. */
+template UniqueGemmCommon<uint8_t, uint32_t> gemm<uint8_t, uint32_t>(GemmArgs<uint32_t> &args, GemmConfig *cfg);
+template GemmMethod get_gemm_method<uint8_t, uint32_t>(GemmArgs<uint32_t> &args);
+template bool method_is_compatible<uint8_t, uint32_t>(GemmMethod method, GemmArgs<uint32_t> &args);
+
} // namespace arm_gemm
-#endif // aarch64
+#endif // __aarch64__
diff --git a/src/core/NEON/kernels/arm_gemm/gemv_batched.hpp b/src/core/NEON/kernels/arm_gemm/gemv_batched.hpp
index d91b44b9a8..d65971e47d 100644
--- a/src/core/NEON/kernels/arm_gemm/gemv_batched.hpp
+++ b/src/core/NEON/kernels/arm_gemm/gemv_batched.hpp
@@ -36,11 +36,12 @@ private:
UniqueGemmCommon<To, Tr> _subgemm = nullptr;
public:
- GemvBatched(const CPUInfo &ci, const unsigned int M, const unsigned int N, const unsigned int K,
- const unsigned int nbatches, const unsigned int nmulti, const bool trA, const bool trB,
- const To alpha, const To beta, const int maxthreads, const bool pretransposed_hint) {
+ GemvBatched(const GemmArgs<Tr> &args) {
/* Just create a subgemm with batches->M */
- _subgemm = gemm<To,Tr>(ci, nbatches, N, K, 1, nmulti, trA, trB, alpha, beta, maxthreads, pretransposed_hint);
+ GemmArgs<Tr> newargs = args;
+ newargs._Msize = args._nbatches;
+ newargs._nbatches = 1;
+ _subgemm = gemm<To,Tr>(newargs, nullptr);
}
void set_arrays(const To *A, const int lda, const int A_batch_stride, const int A_multi_stride,
diff --git a/src/core/NEON/kernels/arm_gemm/kernels/a64_sgemm_native_16x4.hpp b/src/core/NEON/kernels/arm_gemm/kernels/a64_sgemm_native_16x4.hpp
index 11a589d75c..1a3596511b 100644
--- a/src/core/NEON/kernels/arm_gemm/kernels/a64_sgemm_native_16x4.hpp
+++ b/src/core/NEON/kernels/arm_gemm/kernels/a64_sgemm_native_16x4.hpp
@@ -46,9 +46,17 @@ public:
typedef void (*kern_type)(const float *, int, const float *, int, float *, int, float, int, int, int);
/* Kernel blocking parameters */
- static const int out_width = 16;
- static const int out_height = 4;
- static const int k_unroll = 1;
+ static int out_width() {
+ return 16;
+ }
+
+ static int out_height() {
+ return 4;
+ }
+
+ static int k_unroll() {
+ return 1;
+ }
// Default to the generic kernel
kern_type kernel=a64_sgemm_native_16x4;