about | summary | refs | log | tree | commit | diff
path: root/src/core/NEON/kernels/arm_gemm/gemm_implementation.hpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/core/NEON/kernels/arm_gemm/gemm_implementation.hpp')
-rw-r--r-- src/core/NEON/kernels/arm_gemm/gemm_implementation.hpp 123
1 files changed, 83 insertions, 40 deletions
diff --git a/src/core/NEON/kernels/arm_gemm/gemm_implementation.hpp b/src/core/NEON/kernels/arm_gemm/gemm_implementation.hpp
index d952140959..55d72f88cb 100644
--- a/src/core/NEON/kernels/arm_gemm/gemm_implementation.hpp
+++ b/src/core/NEON/kernels/arm_gemm/gemm_implementation.hpp
@@ -28,21 +28,77 @@
namespace arm_gemm {
-template<typename Top, typename Tret>
+/* Structure describing an implementation. For each supported combination
+ * of types, a static list of these structures is built up to describe the
+ * implementations available.
+ */
+template<typename Top, typename Tret, class OutputStage = Nothing>
struct GemmImplementation {
+ const GemmMethod method;
+ const char * name;
+ std::function<bool(const GemmArgs<Tret> &, const OutputStage &)> is_supported;
+ std::function<bool(const GemmArgs<Tret> &, const OutputStage &)> is_recommended;
+ std::function<GemmCommon<Top, Tret> *(const GemmArgs<Tret> &, const OutputStage &)> instantiate;
+
+ bool do_is_supported(const GemmArgs<Tret> &args, const OutputStage &os) const {
+ if (is_supported != nullptr) {
+ return is_supported(args, os);
+ } else {
+ return true;
+ }
+ }
+
+ bool do_is_recommended(const GemmArgs<Tret> &args, const OutputStage &os) const {
+ if (is_recommended != nullptr) {
+ return is_recommended(args, os);
+ } else {
+ return true;
+ }
+ }
+
+ GemmCommon<Top, Tret> *do_instantiate(const GemmArgs<Tret> &args, const OutputStage &os) const {
+ return instantiate(args, os);
+ }
+};
+
+/* Slightly different version of above for straightforward GEMMs with no
+ * output stage, so the std::functions there don't have to deal with the
+ * unnecessary second argument. */
+template<typename Top, typename Tret>
+struct GemmImplementation<Top, Tret, Nothing> {
const GemmMethod method;
const char * name;
std::function<bool(const GemmArgs<Tret> &)> is_supported;
std::function<bool(const GemmArgs<Tret> &)> is_recommended;
std::function<GemmCommon<Top, Tret> *(const GemmArgs<Tret> &)> instantiate;
+
+ bool do_is_supported(const GemmArgs<Tret> &args, const Nothing &) const {
+ if (is_supported != nullptr) {
+ return is_supported(args);
+ } else {
+ return true;
+ }
+ }
+
+ bool do_is_recommended(const GemmArgs<Tret> &args, const Nothing &) const {
+ if (is_recommended != nullptr) {
+ return is_recommended(args);
+ } else {
+ return true;
+ }
+ }
+
+ GemmCommon<Top, Tret> *do_instantiate(const GemmArgs<Tret> &args, const Nothing &) const {
+ return instantiate(args);
+ }
};
/* "Master" function implemented for each valid combination of types.
* Returns a list of GEMM implementation descriptors for processing by the
* other functions, terminated by an implementation with
* method==GemmMethod::DEFAULT. */
-template<typename Top, typename Tret>
-const GemmImplementation<Top, Tret> *gemm_implementation_list();
+template<typename Top, typename Tret, class OutputStage = Nothing>
+const GemmImplementation<Top, Tret, OutputStage> *gemm_implementation_list();
/*
* Select a GEMM implementation for the given arguments.
@@ -59,16 +115,16 @@ const GemmImplementation<Top, Tret> *gemm_implementation_list();
* this function returns false and doesn't touch the provided pointer
* reference.
*/
-template<typename Top, typename Tret>
-bool find_implementation(const GemmArgs<Tret> &args, const GemmImplementation<Top, Tret> * &impl) {
- auto gemms = gemm_implementation_list<Top, Tret>();
+template<typename Top, typename Tret, class OutputStage>
+bool find_implementation(const GemmArgs<Tret> &args, const OutputStage &os, const GemmImplementation<Top, Tret, OutputStage> * &impl) {
+ auto gemms = gemm_implementation_list<Top, Tret, OutputStage>();
const GemmConfig *cfg = args._cfg;
- const GemmImplementation<Top, Tret> *saved_impl = nullptr;
+ const GemmImplementation<Top, Tret, OutputStage> *saved_impl = nullptr;
- for (auto i = gemms; i->method != GemmMethod::DEFAULT; i++) {
+ for (const GemmImplementation<Top, Tret, OutputStage> *i = gemms; i->method != GemmMethod::DEFAULT; i++) {
/* Skip if this implementation doesn't support these args. */
- if (i->is_supported != nullptr && !i->is_supported(args)) {
+ if (!i->do_is_supported(args, os)) {
continue;
}
@@ -91,7 +147,7 @@ bool find_implementation(const GemmArgs<Tret> &args, const GemmImplementation<To
}
/* Check that this method recommends itself. */
- if (i->is_recommended != nullptr && !i->is_recommended(args)) {
+ if (!i->do_is_recommended(args, os)) {
continue;
}
@@ -111,19 +167,19 @@ bool find_implementation(const GemmArgs<Tret> &args, const GemmImplementation<To
return false;
}
-template<typename Top, typename Tret>
-std::vector<KernelDescription> get_compatible_kernels(const GemmArgs<Tret> &args) {
+template<typename Top, typename Tret, class OutputStage>
+std::vector<KernelDescription> get_compatible_kernels(const GemmArgs<Tret> &args, const OutputStage &os) {
std::vector<KernelDescription> res;
/* Find out what the default implementation is so we can set the flag accordingly later. */
- const GemmImplementation<Top, Tret> *default_impl;
- find_implementation(args, default_impl);
+ const GemmImplementation<Top, Tret, OutputStage> *default_impl;
+ find_implementation(args, os, default_impl);
- auto gemms = gemm_implementation_list<Top, Tret>();
+ auto gemms = gemm_implementation_list<Top, Tret, OutputStage>();
- for (auto i = gemms; i->method != GemmMethod::DEFAULT; i++) {
+ for (const GemmImplementation<Top, Tret, OutputStage> *i = gemms; i->method != GemmMethod::DEFAULT; i++) {
/* Check that this implementation supports the presented problem. */
- if (i->is_supported != nullptr && !i->is_supported(args)) {
+ if (!i->do_is_supported(args, os)) {
continue;
}
@@ -133,22 +189,22 @@ std::vector<KernelDescription> get_compatible_kernels(const GemmArgs<Tret> &args
return res;
}
-template<typename Top, typename Tret>
-UniqueGemmCommon<Top, Tret> gemm(const GemmArgs<Tret> &args) {
- const GemmImplementation<Top, Tret> *impl;
+template<typename Top, typename Tret, class OutputStage>
+UniqueGemmCommon<Top, Tret> gemm(const GemmArgs<Tret> &args, const OutputStage &os) {
+ const GemmImplementation<Top, Tret, OutputStage> *impl;
- if (find_implementation<Top, Tret>(args, impl)) {
- return UniqueGemmCommon<Top, Tret>(impl->instantiate(args));
+ if (find_implementation<Top, Tret, OutputStage>(args, os, impl)) {
+ return UniqueGemmCommon<Top, Tret>(impl->do_instantiate(args, os));
}
return UniqueGemmCommon<Top, Tret>(nullptr);
}
-template<typename Top, typename Tret>
-KernelDescription get_gemm_method(const GemmArgs<Tret> &args) {
- const GemmImplementation<Top, Tret> *impl;
+template<typename Top, typename Tret, class OutputStage>
+KernelDescription get_gemm_method(const GemmArgs<Tret> &args, const OutputStage &os) {
+ const GemmImplementation<Top, Tret, OutputStage> *impl;
- if (find_implementation<Top, Tret>(args, impl)) {
+ if (find_implementation<Top, Tret>(args, os, impl)) {
return KernelDescription(impl->method, impl->name);
}
@@ -156,17 +212,4 @@ KernelDescription get_gemm_method(const GemmArgs<Tret> &args) {
return KernelDescription();
}
-template<typename Top, typename Tret>
-bool method_is_compatible(GemmMethod method, const GemmArgs<Tret> &args) {
- /* Determine if the method is valid by attempting to obtain an implementation specifying this method. */
- GemmConfig cfg(method);
- GemmArgs<Tret> myargs = args;
-
- myargs._cfg = &cfg;
-
- const GemmImplementation<Top, Tret> *impl;
-
- return find_implementation<Top, Tret>(myargs, impl);
-}
-
-} // namespace arm_gemm
\ No newline at end of file
+} // namespace arm_gemm