diff options
Diffstat (limited to 'src/core/NEON/kernels/arm_gemm/gemm_implementation.hpp')
-rw-r--r-- | src/core/NEON/kernels/arm_gemm/gemm_implementation.hpp | 224 |
1 files changed, 176 insertions, 48 deletions
diff --git a/src/core/NEON/kernels/arm_gemm/gemm_implementation.hpp b/src/core/NEON/kernels/arm_gemm/gemm_implementation.hpp index 569d1f44ca..5e77df7d4a 100644 --- a/src/core/NEON/kernels/arm_gemm/gemm_implementation.hpp +++ b/src/core/NEON/kernels/arm_gemm/gemm_implementation.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018-2019 ARM Limited. + * Copyright (c) 2018-2020, 2022-2023 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -22,8 +22,11 @@ * SOFTWARE. */ -#include <arm_gemm.hpp> +#include "arm_gemm.hpp" +#include "kernel_weight_format.hpp" + +#include <cstdint> #include <functional> namespace arm_gemm { @@ -36,29 +39,81 @@ template<typename Top, typename Tret, class OutputStage = Nothing> struct GemmImplementation { const GemmMethod method; const char * name; - std::function<bool(const GemmArgs &, const OutputStage &)> is_supported; - std::function<bool(const GemmArgs &, const OutputStage &)> is_recommended; - std::function<GemmCommon<Top, Tret> *(const GemmArgs &, const OutputStage &)> instantiate; + const KernelWeightFormat kernel_weight_format = KernelWeightFormat::NON_FIXED; + std::function<bool(const GemmArgs &, const OutputStage &)> is_supported = {}; + std::function<uint64_t(const GemmArgs &, const OutputStage &)> cycle_estimate = {}; + std::function<GemmCommon<Top, Tret> *(const GemmArgs &, const OutputStage &)> instantiate = {}; bool do_is_supported(const GemmArgs &args, const OutputStage &os) const { - if (is_supported != nullptr) { - return is_supported(args, os); + // Check supplied is_supported() function first. + if (is_supported != nullptr && !is_supported(args, os)) { + return false; + } + + // Check weight format is appropriate. + if (args._fixed_format == false) { + // Can't return a fixed format kernel if we weren't asked for one. + return (kernel_weight_format == KernelWeightFormat::NON_FIXED); } else { - return true; + // Fixed format kernel requested: if this is a non-fixed format kernel we can't use it. + if (kernel_weight_format == KernelWeightFormat::NON_FIXED) { + return false; + } + + // If there's no config, or the config says ANY then this one is OK. + if (!args._cfg || args._cfg->weight_format == WeightFormat::ANY) { + return true; + } + + // If we get here it means there is a config and it specifies a format. Check it matches this kernel. + // NOTE: this will execute SVE instructions if it's an SVE kernel, so it's important that is_supported() + // was called above first. + return (args._cfg->weight_format == get_weight_format(kernel_weight_format, sizeof(Top))); } } - bool do_is_recommended(const GemmArgs &args, const OutputStage &os) const { - if (is_recommended != nullptr) { - return is_recommended(args, os); + uint64_t do_cycle_estimate(const GemmArgs &args, const OutputStage &os) const { + if (cycle_estimate != nullptr) { + return cycle_estimate(args, os); } else { - return true; + return 0; } } GemmCommon<Top, Tret> *do_instantiate(const GemmArgs &args, const OutputStage &os) const { return instantiate(args, os); } + + static GemmImplementation with_estimate(GemmMethod m, const char *n, + std::function<bool(const GemmArgs &, const OutputStage &)> is_supported, std::function<uint64_t(const GemmArgs &, const OutputStage &)> cycle_estimate, + std::function<GemmCommon<Top, Tret> *(const GemmArgs &, const OutputStage &)> instantiate) { + GemmImplementation impl(m,n); + + impl.is_supported=is_supported; + impl.cycle_estimate=cycle_estimate; + impl.instantiate=instantiate; + + return impl; + } + + GemmImplementation(const GemmImplementation &) = default; + GemmImplementation & operator= (const GemmImplementation &) = default; + + GemmImplementation(GemmMethod m, const char * n) : method(m), name(n) {} + + GemmImplementation(GemmMethod m, const char *n, + std::function<bool(const GemmArgs &, const OutputStage &)> is_supported, std::function<bool(const GemmArgs &, const OutputStage &)> is_recommended, + std::function<GemmCommon<Top, Tret> *(const GemmArgs &, const OutputStage &)> instantiate) : + method(m), name(n), is_supported(is_supported), + cycle_estimate( [is_recommended](const GemmArgs &args, const OutputStage &os) { return (is_recommended == nullptr) ? 0 : (is_recommended(args, os) ? 0 : UINT64_MAX); } ), + instantiate(instantiate) { } + + GemmImplementation(GemmMethod m, const char *n, KernelWeightFormat kwf, + std::function<bool(const GemmArgs &, const OutputStage &)> is_supported, std::function<bool(const GemmArgs &, const OutputStage &)> is_recommended, + std::function<GemmCommon<Top, Tret> *(const GemmArgs &, const OutputStage &)> instantiate) : + method(m), name(n), kernel_weight_format(kwf), is_supported(is_supported), + cycle_estimate( [is_recommended](const GemmArgs &args, const OutputStage &os) { return (is_recommended == nullptr) ? 0 : (is_recommended(args, os) ? 0 : UINT64_MAX); } ), + instantiate(instantiate) { } }; /* Slightly different version of above for straightforward GEMMs with no @@ -68,34 +123,100 @@ template<typename Top, typename Tret> struct GemmImplementation<Top, Tret, Nothing> { const GemmMethod method; const char * name; - std::function<bool(const GemmArgs &)> is_supported; - std::function<bool(const GemmArgs &)> is_recommended; - std::function<GemmCommon<Top, Tret> *(const GemmArgs &)> instantiate; + const KernelWeightFormat kernel_weight_format = KernelWeightFormat::NON_FIXED; + std::function<bool(const GemmArgs &)> is_supported = {}; + std::function<uint64_t(const GemmArgs &)> cycle_estimate = {}; + std::function<GemmCommon<Top, Tret> *(const GemmArgs &)> instantiate = {}; bool do_is_supported(const GemmArgs &args, const Nothing &) const { - if (is_supported != nullptr) { - return is_supported(args); + // Check supplied is_supported() function first. + if (is_supported != nullptr && !is_supported(args)) { + return false; + } + + // Check weight format is appropriate. + if (args._fixed_format == false) { + // Can't return a fixed format kernel if we weren't asked for one. + return (kernel_weight_format == KernelWeightFormat::NON_FIXED); } else { - return true; + // Fixed format kernel requested: if this is a non-fixed format kernel we can't use it. + if (kernel_weight_format == KernelWeightFormat::NON_FIXED) { + return false; + } + + // If there's no config, or the config says ANY then this one is OK. + if (!args._cfg || args._cfg->weight_format == WeightFormat::ANY) { + return true; + } + + // If we get here it means there is a config and it specifies a format. Check it matches this kernel. + // NOTE: this will execute SVE instructions if it's an SVE kernel, so it's important that is_supported() + // was called above first. + return (args._cfg->weight_format == get_weight_format(kernel_weight_format, sizeof(Top))); } } - bool do_is_recommended(const GemmArgs &args, const Nothing &) const { - if (is_recommended != nullptr) { - return is_recommended(args); + uint64_t do_cycle_estimate(const GemmArgs &args, const Nothing &) const { + if (cycle_estimate != nullptr) { + return cycle_estimate(args); } else { - return true; + return 0; } } GemmCommon<Top, Tret> *do_instantiate(const GemmArgs &args, const Nothing &) const { return instantiate(args); } + + static GemmImplementation with_estimate(GemmMethod m, const char *n, + std::function<bool(const GemmArgs &)> is_supported, std::function<uint64_t(const GemmArgs &)> cycle_estimate, + std::function<GemmCommon<Top, Tret> *(const GemmArgs &)> instantiate) { + GemmImplementation impl(m,n); + + impl.is_supported=is_supported; + impl.cycle_estimate=cycle_estimate; + impl.instantiate=instantiate; + + return impl; + } + + static GemmImplementation with_estimate(GemmMethod m, const char *n, KernelWeightFormat f, + std::function<bool(const GemmArgs &)> is_supported, std::function<uint64_t(const GemmArgs &)> cycle_estimate, + std::function<GemmCommon<Top, Tret> *(const GemmArgs &)> instantiate) { + GemmImplementation impl(m,n,f); + + impl.is_supported=is_supported; + impl.cycle_estimate=cycle_estimate; + impl.instantiate=instantiate; + + return impl; + } + + GemmImplementation(const GemmImplementation &) = default; + GemmImplementation & operator= (const GemmImplementation &) = default; + + GemmImplementation(GemmMethod m, const char *n, KernelWeightFormat f=KernelWeightFormat::NON_FIXED) : method(m), name(n), kernel_weight_format(f) {} + + GemmImplementation(GemmMethod m, const char *n, + std::function<bool(const GemmArgs &)> is_supported, std::function<bool(const GemmArgs &)> is_recommended, + std::function<GemmCommon<Top, Tret> *(const GemmArgs &)> instantiate) : + method(m), name(n), is_supported(is_supported), + cycle_estimate( [is_recommended](const GemmArgs &args) -> uint64_t { return (is_recommended == nullptr) ? 0 : (is_recommended(args) ? 0 : UINT64_MAX); } ), + instantiate(instantiate) { } + + GemmImplementation(GemmMethod m, const char *n, KernelWeightFormat kwf, + std::function<bool(const GemmArgs &)> is_supported, std::function<bool(const GemmArgs &)> is_recommended, + std::function<GemmCommon<Top, Tret> *(const GemmArgs &)> instantiate) : + method(m), name(n), kernel_weight_format(kwf), is_supported(is_supported), + cycle_estimate( [is_recommended](const GemmArgs &args) -> uint64_t { return (is_recommended == nullptr) ? 0 : (is_recommended(args) ? 0 : UINT64_MAX); } ), + instantiate(instantiate) { } }; -/* "Master" function implemented for each valid combination of types. - * Returns a list of GEMM implementation descriptors for processing by the - * other functions, terminated by an implementation with +/* Provides the list of implementation descriptors which is processed by the + * other functions. + * + * A specialised version is provided for each supported combination of types. + * The end of the list is indicated by a sentinel descriptor with * method==GemmMethod::DEFAULT. */ template<typename Top, typename Tret, class OutputStage = Nothing> const GemmImplementation<Top, Tret, OutputStage> *gemm_implementation_list(); @@ -103,13 +224,11 @@ const GemmImplementation<Top, Tret, OutputStage> *gemm_implementation_list(); /* * Select a GEMM implementation for the given arguments. * - * The logic here returns the first method on the list which supports the + * The logic here returns the method on the list which supports the * requested problem parameters, matches the provided filters (method and/or - * name string match) and recommends itself. - * - * If there is no such method, it will return the first method which - * supports the requested parameters and passes the filters, regardless of - * recommendation. + * name string match) and offers the lowest cycle estimate. A cycle + * estimate of '0' is treated as a special value, causing the corresponding + * method to be selected immediately. * * If no method supports the requested parameters and passes the filters, * this function returns false and doesn't touch the provided pointer @@ -121,6 +240,7 @@ bool find_implementation(const GemmArgs &args, const OutputStage &os, const Gemm const GemmConfig *cfg = args._cfg; const GemmImplementation<Top, Tret, OutputStage> *saved_impl = nullptr; + uint64_t best_estimate = 0; for (const GemmImplementation<Top, Tret, OutputStage> *i = gemms; i->method != GemmMethod::DEFAULT; i++) { /* Skip if this implementation doesn't support these args. */ @@ -138,27 +258,24 @@ bool find_implementation(const GemmArgs &args, const OutputStage &os, const Gemm continue; } - /* At this point, if we don't have a saved implementation, save this - * one. This is so that we always return something if a filter - * matches, even if it doesn't recommend itself. - */ - if (saved_impl == nullptr) { - saved_impl=i; - } + /* Test the cycle estimate */ + uint64_t estimate = i->do_cycle_estimate(args, os); - /* Check that this method recommends itself. */ - if (!i->do_is_recommended(args, os)) { - continue; + /* Short circuit - if the estimate is zero, return this one immediately. */ + if (estimate==0) { + impl=i; + return true; } - impl=i; - - return true; + /* Otherwise, remember this is our best so far if we don't yet have + * a valid candidate, or we beat the estimate. */ + if ((saved_impl == nullptr) || (estimate < best_estimate)) { + saved_impl = i; + best_estimate = estimate; + } } - /* We didn't find an option matching the filters that recommended - * itself. But if we found something earlier that matched the filters - * but wasn't recommended, return it here. */ + /* Return whichever method gave the best estimate. */ if (saved_impl != nullptr) { impl = saved_impl; return true; @@ -179,17 +296,27 @@ std::vector<KernelDescription> get_compatible_kernels(const GemmArgs &args, cons for (const GemmImplementation<Top, Tret, OutputStage> *i = gemms; i->method != GemmMethod::DEFAULT; i++) { /* Check that this implementation supports the presented problem. */ + if (!i->do_is_supported(args, os)) { continue; } - res.push_back(KernelDescription(i->method, i->name, i==default_impl)); + res.push_back(KernelDescription(i->method, i->name, i==default_impl, i->do_cycle_estimate(args, os))); } return res; } template<typename Top, typename Tret, class OutputStage> +bool has_opt_gemm(WeightFormat &wf, const GemmArgs &args, const OutputStage &os) { + const GemmImplementation<Top, Tret, OutputStage> *impl; + const bool success = find_implementation<Top, Tret, OutputStage>(args, os, impl); + if (success) + wf = UniqueGemmCommon<Top, Tret>(impl->do_instantiate(args, os))->get_config().weight_format; + return success; +} + +template<typename Top, typename Tret, class OutputStage> UniqueGemmCommon<Top, Tret> gemm(const GemmArgs &args, const OutputStage &os) { const GemmImplementation<Top, Tret, OutputStage> *impl; @@ -212,4 +339,5 @@ KernelDescription get_gemm_method(const GemmArgs &args, const OutputStage &os) { return KernelDescription(); } + } // namespace arm_gemm |