diff options
Diffstat (limited to 'src/core/NEON/kernels/arm_gemm/gemm_implementation.hpp')
-rw-r--r-- | src/core/NEON/kernels/arm_gemm/gemm_implementation.hpp | 98 |
1 files changed, 65 insertions, 33 deletions
diff --git a/src/core/NEON/kernels/arm_gemm/gemm_implementation.hpp b/src/core/NEON/kernels/arm_gemm/gemm_implementation.hpp index c726d7b0aa..261e7d2d9c 100644 --- a/src/core/NEON/kernels/arm_gemm/gemm_implementation.hpp +++ b/src/core/NEON/kernels/arm_gemm/gemm_implementation.hpp @@ -24,6 +24,7 @@ #include "arm_gemm.hpp" +#include <cstdint> #include <functional> namespace arm_gemm { @@ -37,7 +38,7 @@ struct GemmImplementation { const GemmMethod method; const char * name; std::function<bool(const GemmArgs &, const OutputStage &)> is_supported; - std::function<bool(const GemmArgs &, const OutputStage &)> is_recommended; + std::function<uint64_t(const GemmArgs &, const OutputStage &)> cycle_estimate; std::function<GemmCommon<Top, Tret> *(const GemmArgs &, const OutputStage &)> instantiate; bool do_is_supported(const GemmArgs &args, const OutputStage &os) const { @@ -48,17 +49,27 @@ struct GemmImplementation { } } - bool do_is_recommended(const GemmArgs &args, const OutputStage &os) const { - if (is_recommended != nullptr) { - return is_recommended(args, os); + uint64_t do_cycle_estimate(const GemmArgs &args, const OutputStage &os) const { + if (cycle_estimate != nullptr) { + return cycle_estimate(args, os); } else { - return true; + return 0; } } + GemmImplementation(const GemmImplementation &) = default; + GemmImplementation &operator= (const GemmImplementation &) = default; + GemmCommon<Top, Tret> *do_instantiate(const GemmArgs &args, const OutputStage &os) const { return instantiate(args, os); } + + GemmImplementation(GemmMethod m, const char *n, + std::function<bool(const GemmArgs &, const OutputStage &)> is_supported, std::function<bool(const GemmArgs &, const OutputStage &)> is_recommended, + std::function<GemmCommon<Top, Tret> *(const GemmArgs &, const OutputStage &)> instantiate) : + method(m), name(n), is_supported(is_supported), + cycle_estimate( [is_recommended](const GemmArgs &args, const OutputStage &os) { return (is_recommended == nullptr) ? 0 : (is_recommended(args, os) ? 0 : UINT64_MAX); } ), + instantiate(instantiate) { } }; /* Slightly different version of above for straightforward GEMMs with no @@ -69,7 +80,7 @@ struct GemmImplementation<Top, Tret, Nothing> { const GemmMethod method; const char * name; std::function<bool(const GemmArgs &)> is_supported; - std::function<bool(const GemmArgs &)> is_recommended; + std::function<uint64_t(const GemmArgs &)> cycle_estimate; std::function<GemmCommon<Top, Tret> *(const GemmArgs &)> instantiate; bool do_is_supported(const GemmArgs &args, const Nothing &) const { @@ -80,17 +91,42 @@ struct GemmImplementation<Top, Tret, Nothing> { } } - bool do_is_recommended(const GemmArgs &args, const Nothing &) const { - if (is_recommended != nullptr) { - return is_recommended(args); + uint64_t do_cycle_estimate(const GemmArgs &args, const Nothing &) const { + if (cycle_estimate != nullptr) { + return cycle_estimate(args); } else { - return true; + return 0; } } GemmCommon<Top, Tret> *do_instantiate(const GemmArgs &args, const Nothing &) const { return instantiate(args); } + + + static GemmImplementation with_estimate(GemmMethod m, const char *n, + std::function<bool(const GemmArgs &)> is_supported, std::function<uint64_t(const GemmArgs &)> cycle_estimate, + std::function<GemmCommon<Top, Tret> *(const GemmArgs &)> instantiate) { + GemmImplementation impl(m,n); + + impl.is_supported=is_supported; + impl.cycle_estimate=cycle_estimate; + impl.instantiate=instantiate; + + return impl; + } + + GemmImplementation(GemmMethod m, const char * n) : method(m), name(n), is_supported(nullptr), cycle_estimate(nullptr), instantiate(nullptr) {} + + GemmImplementation(GemmMethod m, const char *n, + std::function<bool(const GemmArgs &)> is_supported, std::function<bool(const GemmArgs &)> is_recommended, + std::function<GemmCommon<Top, Tret> *(const GemmArgs &)> instantiate) : + method(m), name(n), is_supported(is_supported), + cycle_estimate( [is_recommended](const GemmArgs &args) -> uint64_t { return (is_recommended == nullptr) ? 0 : (is_recommended(args) ? 0 : UINT64_MAX); } ), + instantiate(instantiate) { } + + GemmImplementation(const GemmImplementation &) = default; + GemmImplementation &operator=(const GemmImplementation &) = default; }; /* "Master" function implemented for each valid combination of types. @@ -103,13 +139,11 @@ const GemmImplementation<Top, Tret, OutputStage> *gemm_implementation_list(); /* * Select a GEMM implementation for the given arguments. * - * The logic here returns the first method on the list which supports the + * The logic here returns the method on the list which supports the * requested problem parameters, matches the provided filters (method and/or - * name string match) and recommends itself. - * - * If there is no such method, it will return the first method which - * supports the requested parameters and passes the filters, regardless of - * recommendation. + * name string match) and offers the lowest cycle estimate. A cycle + * estimate of '0' is treated as a special value, causing the corresponding + * method to be selected immediately. * * If no method supports the requested parameters and passes the filters, * this function returns false and doesn't touch the provided pointer @@ -121,6 +155,7 @@ bool find_implementation(const GemmArgs &args, const OutputStage &os, const Gemm const GemmConfig *cfg = args._cfg; const GemmImplementation<Top, Tret, OutputStage> *saved_impl = nullptr; + uint64_t best_estimate = 0; for (const GemmImplementation<Top, Tret, OutputStage> *i = gemms; i->method != GemmMethod::DEFAULT; i++) { /* Skip if this implementation doesn't support these args. */ @@ -138,27 +173,24 @@ bool find_implementation(const GemmArgs &args, const OutputStage &os, const Gemm continue; } - /* At this point, if we don't have a saved implementation, save this - * one. This is so that we always return something if a filter - * matches, even if it doesn't recommend itself. - */ - if (saved_impl == nullptr) { - saved_impl=i; - } + /* Test the cycle estimate */ + uint64_t estimate = i->do_cycle_estimate(args, os); - /* Check that this method recommends itself. */ - if (!i->do_is_recommended(args, os)) { - continue; + /* Short circuit - if the estimate is zero, return this one immediately. */ + if (estimate==0) { + impl=i; + return true; } - impl=i; - - return true; + /* Otherwise, remember this is our best so far if we don't yet have + * a valid candidate, or we beat the estimate. */ + if ((saved_impl == nullptr) || (estimate < best_estimate)) { + saved_impl = i; + best_estimate = estimate; + } } - /* We didn't find an option matching the filters that recommended - * itself. But if we found something earlier that matched the filters - * but wasn't recommended, return it here. */ + /* Return whichever method gave the best estimate. */ if (saved_impl != nullptr) { impl = saved_impl; return true; @@ -183,7 +215,7 @@ std::vector<KernelDescription> get_compatible_kernels(const GemmArgs &args, cons continue; } - res.push_back(KernelDescription(i->method, i->name, i==default_impl)); + res.push_back(KernelDescription(i->method, i->name, i==default_impl, i->do_cycle_estimate(args, os))); } return res; |