aboutsummaryrefslogtreecommitdiff
path: root/src/core/NEON/kernels/arm_gemm/gemm_implementation.hpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/core/NEON/kernels/arm_gemm/gemm_implementation.hpp')
-rw-r--r--src/core/NEON/kernels/arm_gemm/gemm_implementation.hpp98
1 files changed, 65 insertions, 33 deletions
diff --git a/src/core/NEON/kernels/arm_gemm/gemm_implementation.hpp b/src/core/NEON/kernels/arm_gemm/gemm_implementation.hpp
index c726d7b0aa..261e7d2d9c 100644
--- a/src/core/NEON/kernels/arm_gemm/gemm_implementation.hpp
+++ b/src/core/NEON/kernels/arm_gemm/gemm_implementation.hpp
@@ -24,6 +24,7 @@
#include "arm_gemm.hpp"
+#include <cstdint>
#include <functional>
namespace arm_gemm {
@@ -37,7 +38,7 @@ struct GemmImplementation {
const GemmMethod method;
const char * name;
std::function<bool(const GemmArgs &, const OutputStage &)> is_supported;
- std::function<bool(const GemmArgs &, const OutputStage &)> is_recommended;
+ std::function<uint64_t(const GemmArgs &, const OutputStage &)> cycle_estimate;
std::function<GemmCommon<Top, Tret> *(const GemmArgs &, const OutputStage &)> instantiate;
bool do_is_supported(const GemmArgs &args, const OutputStage &os) const {
@@ -48,17 +49,27 @@ struct GemmImplementation {
}
}
- bool do_is_recommended(const GemmArgs &args, const OutputStage &os) const {
- if (is_recommended != nullptr) {
- return is_recommended(args, os);
+ uint64_t do_cycle_estimate(const GemmArgs &args, const OutputStage &os) const {
+ if (cycle_estimate != nullptr) {
+ return cycle_estimate(args, os);
} else {
- return true;
+ return 0;
}
}
+ GemmImplementation(const GemmImplementation &) = default;
+ GemmImplementation &operator= (const GemmImplementation &) = default;
+
GemmCommon<Top, Tret> *do_instantiate(const GemmArgs &args, const OutputStage &os) const {
return instantiate(args, os);
}
+
+ GemmImplementation(GemmMethod m, const char *n,
+ std::function<bool(const GemmArgs &, const OutputStage &)> is_supported, std::function<bool(const GemmArgs &, const OutputStage &)> is_recommended,
+ std::function<GemmCommon<Top, Tret> *(const GemmArgs &, const OutputStage &)> instantiate) :
+ method(m), name(n), is_supported(is_supported),
+ cycle_estimate( [is_recommended](const GemmArgs &args, const OutputStage &os) { return (is_recommended == nullptr) ? 0 : (is_recommended(args, os) ? 0 : UINT64_MAX); } ),
+ instantiate(instantiate) { }
};
/* Slightly different version of above for straightforward GEMMs with no
@@ -69,7 +80,7 @@ struct GemmImplementation<Top, Tret, Nothing> {
const GemmMethod method;
const char * name;
std::function<bool(const GemmArgs &)> is_supported;
- std::function<bool(const GemmArgs &)> is_recommended;
+ std::function<uint64_t(const GemmArgs &)> cycle_estimate;
std::function<GemmCommon<Top, Tret> *(const GemmArgs &)> instantiate;
bool do_is_supported(const GemmArgs &args, const Nothing &) const {
@@ -80,17 +91,42 @@ struct GemmImplementation<Top, Tret, Nothing> {
}
}
- bool do_is_recommended(const GemmArgs &args, const Nothing &) const {
- if (is_recommended != nullptr) {
- return is_recommended(args);
+ uint64_t do_cycle_estimate(const GemmArgs &args, const Nothing &) const {
+ if (cycle_estimate != nullptr) {
+ return cycle_estimate(args);
} else {
- return true;
+ return 0;
}
}
GemmCommon<Top, Tret> *do_instantiate(const GemmArgs &args, const Nothing &) const {
return instantiate(args);
}
+
+
+ static GemmImplementation with_estimate(GemmMethod m, const char *n,
+ std::function<bool(const GemmArgs &)> is_supported, std::function<uint64_t(const GemmArgs &)> cycle_estimate,
+ std::function<GemmCommon<Top, Tret> *(const GemmArgs &)> instantiate) {
+ GemmImplementation impl(m,n);
+
+ impl.is_supported=is_supported;
+ impl.cycle_estimate=cycle_estimate;
+ impl.instantiate=instantiate;
+
+ return impl;
+ }
+
+ GemmImplementation(GemmMethod m, const char * n) : method(m), name(n), is_supported(nullptr), cycle_estimate(nullptr), instantiate(nullptr) {}
+
+ GemmImplementation(GemmMethod m, const char *n,
+ std::function<bool(const GemmArgs &)> is_supported, std::function<bool(const GemmArgs &)> is_recommended,
+ std::function<GemmCommon<Top, Tret> *(const GemmArgs &)> instantiate) :
+ method(m), name(n), is_supported(is_supported),
+ cycle_estimate( [is_recommended](const GemmArgs &args) -> uint64_t { return (is_recommended == nullptr) ? 0 : (is_recommended(args) ? 0 : UINT64_MAX); } ),
+ instantiate(instantiate) { }
+
+ GemmImplementation(const GemmImplementation &) = default;
+ GemmImplementation &operator=(const GemmImplementation &) = default;
};
/* "Master" function implemented for each valid combination of types.
@@ -103,13 +139,11 @@ const GemmImplementation<Top, Tret, OutputStage> *gemm_implementation_list();
/*
* Select a GEMM implementation for the given arguments.
*
- * The logic here returns the first method on the list which supports the
+ * The logic here returns the method on the list which supports the
* requested problem parameters, matches the provided filters (method and/or
- * name string match) and recommends itself.
- *
- * If there is no such method, it will return the first method which
- * supports the requested parameters and passes the filters, regardless of
- * recommendation.
+ * name string match) and offers the lowest cycle estimate. A cycle
+ * estimate of '0' is treated as a special value, causing the corresponding
+ * method to be selected immediately.
*
* If no method supports the requested parameters and passes the filters,
* this function returns false and doesn't touch the provided pointer
@@ -121,6 +155,7 @@ bool find_implementation(const GemmArgs &args, const OutputStage &os, const Gemm
const GemmConfig *cfg = args._cfg;
const GemmImplementation<Top, Tret, OutputStage> *saved_impl = nullptr;
+ uint64_t best_estimate = 0;
for (const GemmImplementation<Top, Tret, OutputStage> *i = gemms; i->method != GemmMethod::DEFAULT; i++) {
/* Skip if this implementation doesn't support these args. */
@@ -138,27 +173,24 @@ bool find_implementation(const GemmArgs &args, const OutputStage &os, const Gemm
continue;
}
- /* At this point, if we don't have a saved implementation, save this
- * one. This is so that we always return something if a filter
- * matches, even if it doesn't recommend itself.
- */
- if (saved_impl == nullptr) {
- saved_impl=i;
- }
+ /* Test the cycle estimate */
+ uint64_t estimate = i->do_cycle_estimate(args, os);
- /* Check that this method recommends itself. */
- if (!i->do_is_recommended(args, os)) {
- continue;
+ /* Short circuit - if the estimate is zero, return this one immediately. */
+ if (estimate==0) {
+ impl=i;
+ return true;
}
- impl=i;
-
- return true;
+ /* Otherwise, remember this is our best so far if we don't yet have
+ * a valid candidate, or we beat the estimate. */
+ if ((saved_impl == nullptr) || (estimate < best_estimate)) {
+ saved_impl = i;
+ best_estimate = estimate;
+ }
}
- /* We didn't find an option matching the filters that recommended
- * itself. But if we found something earlier that matched the filters
- * but wasn't recommended, return it here. */
+ /* Return whichever method gave the best estimate. */
if (saved_impl != nullptr) {
impl = saved_impl;
return true;
@@ -183,7 +215,7 @@ std::vector<KernelDescription> get_compatible_kernels(const GemmArgs &args, cons
continue;
}
- res.push_back(KernelDescription(i->method, i->name, i==default_impl));
+ res.push_back(KernelDescription(i->method, i->name, i==default_impl, i->do_cycle_estimate(args, os)));
}
return res;