aboutsummaryrefslogtreecommitdiff
path: root/src/core/NEON/kernels/arm_gemm/gemm_implementation.hpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/core/NEON/kernels/arm_gemm/gemm_implementation.hpp')
-rw-r--r--src/core/NEON/kernels/arm_gemm/gemm_implementation.hpp153
1 files changed, 95 insertions, 58 deletions
diff --git a/src/core/NEON/kernels/arm_gemm/gemm_implementation.hpp b/src/core/NEON/kernels/arm_gemm/gemm_implementation.hpp
index 6734e3cce0..bf80784b79 100644
--- a/src/core/NEON/kernels/arm_gemm/gemm_implementation.hpp
+++ b/src/core/NEON/kernels/arm_gemm/gemm_implementation.hpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2018 ARM Limited.
+ * Copyright (c) 2018-2019 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -22,56 +22,53 @@
* SOFTWARE.
*/
-#include "gemv_batched.hpp"
+#include <arm_gemm.hpp>
-namespace arm_gemm {
-
-template<typename Top, typename Tret>
-class GemmImplementation {
-public:
- /* Is this implementation compatible with the args as provided? */
- virtual bool is_supported(const GemmArgs<Tret> &args) { return true; }
- /* Is this implementation "recommended" for these args (heuristic)? */
- virtual bool is_recommended(const GemmArgs<Tret> &args) { return true; }
- /* Instantiate this method please. */
- virtual UniqueGemmCommon<Top, Tret> instantiate(const GemmArgs<Tret> &args) = 0;
+#include <functional>
- /* Indicate the "GemmMethod" for use as a selector */
- const GemmMethod method;
-
- virtual ~GemmImplementation() { }
-
- GemmImplementation(GemmMethod method) : method(method) { }
-};
+namespace arm_gemm {
-/* "gemv_batched" implementation is type-agnostic, so template it here. */
template<typename Top, typename Tret>
-class GemmImpl_gemv_batched : public GemmImplementation<Top, Tret> {
-public:
- bool is_supported(const GemmArgs<Tret> &args) override {
- return (args._Msize==1 && args._nbatches > 1);
- }
-
- UniqueGemmCommon<Top, Tret> instantiate(const GemmArgs<Tret> &args) override {
- return UniqueGemmCommon<Top, Tret> (new GemvBatched<Top, Tret>(args));
- }
-
- GemmImpl_gemv_batched() : GemmImplementation<Top, Tret>(GemmMethod::GEMV_BATCHED) { }
+struct GemmImplementation {
+ const GemmMethod method;
+ const char * name;
+ std::function<bool(const GemmArgs<Tret> &)> is_supported;
+ std::function<bool(const GemmArgs<Tret> &)> is_recommended;
+ std::function<GemmCommon<Top, Tret> *(const GemmArgs<Tret> &)> instantiate;
};
/* "Master" function implemented for each valid combination of types.
* Returns a list of GEMM implementation descriptors for processing by the
- * other functions. */
+ * other functions, terminated by an implementation with
+ * method==GemmMethod::DEFAULT. */
template<typename Top, typename Tret>
-std::vector<GemmImplementation<Top, Tret> *> &gemm_implementation_list();
+const GemmImplementation<Top, Tret> *gemm_implementation_list();
+/*
+ * Select a GEMM implementation for the given arguments.
+ *
+ * The logic here returns the first method on the list which supports the
+ * requested problem parameters, matches the provided filters (method and/or
+ * name string match) and recommends itself.
+ *
+ * If there is no such method, it will return the first method which
+ * supports the requested parameters and passes the filters, regardless of
+ * recommendation.
+ *
+ * If no method supports the requested parameters and passes the filters,
+ * this function returns false and doesn't touch the provided pointer
+ * reference.
+ */
template<typename Top, typename Tret>
-GemmImplementation<Top, Tret> *find_implementation(GemmArgs<Tret> &args, GemmConfig *cfg) {
+bool find_implementation(const GemmArgs<Tret> &args, const GemmImplementation<Top, Tret> * &impl) {
auto gemms = gemm_implementation_list<Top, Tret>();
+ const GemmConfig *cfg = args._cfg;
- for(auto &&i : gemms) {
+ const GemmImplementation<Top, Tret> *saved_impl = nullptr;
+
+ for (auto i = gemms; i->method != GemmMethod::DEFAULT; i++) {
/* Skip if this implementation doesn't support these args. */
- if (!i->is_supported(args)) {
+ if (i->is_supported != nullptr && !i->is_supported(args)) {
continue;
}
@@ -80,52 +77,92 @@ GemmImplementation<Top, Tret> *find_implementation(GemmArgs<Tret> &args, GemmCon
continue;
}
- /* If no specific method is requested, check that this method recommends itself. */
- if ((!cfg || cfg->method == GemmMethod::DEFAULT) && !i->is_recommended(args)) {
+ /* Skip if a filter is to be applied and it doesn't match. */
+ if (cfg && cfg->filter != "" && !strstr(i->name, cfg->filter.c_str())) {
+ continue;
+ }
+
+ /* At this point, if we don't have a saved implementation, save this
+ * one. This is so that we always return something if a filter
+ * matches, even if it doesn't recommend itself.
+ */
+ if (saved_impl == nullptr) {
+ saved_impl=i;
+ }
+
+ /* Check that this method recommends itself. */
+ if (i->is_recommended != nullptr && !i->is_recommended(args)) {
+ continue;
+ }
+
+ impl=i;
+
+ return true;
+ }
+
+ /* We didn't find an option matching the filters that recommended
+ * itself. But if we found something earlier that matched the filters
+ * but wasn't recommended, return it here. */
+ if (saved_impl != nullptr) {
+ impl = saved_impl;
+ return true;
+ }
+
+ return false;
+}
+
+template<typename Top, typename Tret>
+std::vector<std::string> get_compatible_kernels(const GemmArgs<Tret> &args) {
+ std::vector<std::string> res;
+
+ auto gemms = gemm_implementation_list<Top, Tret>();
+
+ for (auto i = gemms; i->method != GemmMethod::DEFAULT; i++) {
+ /* Check that this implementation supports the presented problem. */
+ if (i->is_supported != nullptr && !i->is_supported(args)) {
continue;
}
- return i;
+ res.push_back(i->name);
}
- return nullptr;
+ return res;
}
template<typename Top, typename Tret>
-UniqueGemmCommon<Top, Tret> gemm(GemmArgs<Tret> &args, GemmConfig *cfg) {
- auto impl = find_implementation<Top, Tret>(args, cfg);
+UniqueGemmCommon<Top, Tret> gemm(const GemmArgs<Tret> &args) {
+ const GemmImplementation<Top, Tret> *impl;
- if (impl) {
- return impl->instantiate(args);
+ if (find_implementation<Top, Tret>(args, impl)) {
+ return UniqueGemmCommon<Top, Tret>(impl->instantiate(args));
}
return UniqueGemmCommon<Top, Tret>(nullptr);
}
template<typename Top, typename Tret>
-GemmMethod get_gemm_method(GemmArgs<Tret> &args) {
- auto impl = find_implementation<Top, Tret>(args, nullptr);
+KernelDescription get_gemm_method(const GemmArgs<Tret> &args) {
+ const GemmImplementation<Top, Tret> *impl;
- if (impl) {
- return impl->method;
+ if (find_implementation<Top, Tret>(args, impl)) {
+ return KernelDescription(impl->method, impl->name);
}
/* This shouldn't happen - there should always be at least one valid implementation. */
- return GemmMethod::DEFAULT;
+ return KernelDescription();
}
template<typename Top, typename Tret>
-bool method_is_compatible(GemmMethod method, GemmArgs<Tret> &args) {
+bool method_is_compatible(GemmMethod method, const GemmArgs<Tret> &args) {
/* Determine if the method is valid by attempting to obtain an implementation specifying this method. */
- GemmConfig cfg(method);
+ GemmConfig cfg(method);
+ GemmArgs<Tret> myargs = args;
- auto impl = find_implementation<Top, Tret>(args, &cfg);
+ myargs._cfg = &cfg;
- if (impl) {
- return true;
- }
+ const GemmImplementation<Top, Tret> *impl;
- return false;
+ return find_implementation<Top, Tret>(myargs, impl);
}
-} // namespace arm_gemm
+} // namespace arm_gemm \ No newline at end of file