diff options
Diffstat (limited to 'src/core/NEON/kernels/arm_gemm/gemm_implementation.hpp')
-rw-r--r-- | src/core/NEON/kernels/arm_gemm/gemm_implementation.hpp | 97 |
1 files changed, 90 insertions, 7 deletions
diff --git a/src/core/NEON/kernels/arm_gemm/gemm_implementation.hpp b/src/core/NEON/kernels/arm_gemm/gemm_implementation.hpp index cb3ff7aa29..75fb1cb306 100644 --- a/src/core/NEON/kernels/arm_gemm/gemm_implementation.hpp +++ b/src/core/NEON/kernels/arm_gemm/gemm_implementation.hpp @@ -24,6 +24,8 @@ #include "arm_gemm.hpp" +#include "kernel_weight_format.hpp" + #include <cstdint> #include <functional> @@ -37,15 +39,36 @@ template<typename Top, typename Tret, class OutputStage = Nothing> struct GemmImplementation { const GemmMethod method; const char * name; + const KernelWeightFormat kernel_weight_format = KernelWeightFormat::NON_FIXED; std::function<bool(const GemmArgs &, const OutputStage &)> is_supported = {}; std::function<uint64_t(const GemmArgs &, const OutputStage &)> cycle_estimate = {}; std::function<GemmCommon<Top, Tret> *(const GemmArgs &, const OutputStage &)> instantiate = {}; bool do_is_supported(const GemmArgs &args, const OutputStage &os) const { - if (is_supported != nullptr) { - return is_supported(args, os); + // Check supplied is_supported() function first. + if (is_supported != nullptr && !is_supported(args, os)) { + return false; + } + + // Check weight format is appropriate. + if (args._fixed_format == false) { + // Can't return a fixed format kernel if we weren't asked for one. + return (kernel_weight_format == KernelWeightFormat::NON_FIXED); } else { - return true; + // Fixed format kernel requested: if this is a non-fixed format kernel we can't use it. + if (kernel_weight_format == KernelWeightFormat::NON_FIXED) { + return false; + } + + // If there's no config, or the config says ANY then this one is OK. + if (!args._cfg || args._cfg->weight_format == WeightFormat::ANY) { + return true; + } + + // If we get here it means there is a config and it specifies a format. Check it matches this kernel. + // NOTE: this will execute SVE instructions if it's an SVE kernel, so it's important that is_supported() + // was called above first. + return (args._cfg->weight_format == get_weight_format(kernel_weight_format, sizeof(Top))); } } @@ -84,6 +107,13 @@ struct GemmImplementation { method(m), name(n), is_supported(is_supported), cycle_estimate( [is_recommended](const GemmArgs &args, const OutputStage &os) { return (is_recommended == nullptr) ? 0 : (is_recommended(args, os) ? 0 : UINT64_MAX); } ), instantiate(instantiate) { } + + GemmImplementation(GemmMethod m, const char *n, KernelWeightFormat kwf, + std::function<bool(const GemmArgs &, const OutputStage &)> is_supported, std::function<bool(const GemmArgs &, const OutputStage &)> is_recommended, + std::function<GemmCommon<Top, Tret> *(const GemmArgs &, const OutputStage &)> instantiate) : + method(m), name(n), kernel_weight_format(kwf), is_supported(is_supported), + cycle_estimate( [is_recommended](const GemmArgs &args, const OutputStage &os) { return (is_recommended == nullptr) ? 0 : (is_recommended(args, os) ? 0 : UINT64_MAX); } ), + instantiate(instantiate) { } }; /* Slightly different version of above for straightforward GEMMs with no @@ -93,15 +123,36 @@ template<typename Top, typename Tret> struct GemmImplementation<Top, Tret, Nothing> { const GemmMethod method; const char * name; + const KernelWeightFormat kernel_weight_format = KernelWeightFormat::NON_FIXED; std::function<bool(const GemmArgs &)> is_supported = {}; std::function<uint64_t(const GemmArgs &)> cycle_estimate = {}; std::function<GemmCommon<Top, Tret> *(const GemmArgs &)> instantiate = {}; bool do_is_supported(const GemmArgs &args, const Nothing &) const { - if (is_supported != nullptr) { - return is_supported(args); + // Check supplied is_supported() function first. + if (is_supported != nullptr && !is_supported(args)) { + return false; + } + + // Check weight format is appropriate. + if (args._fixed_format == false) { + // Can't return a fixed format kernel if we weren't asked for one. + return (kernel_weight_format == KernelWeightFormat::NON_FIXED); } else { - return true; + // Fixed format kernel requested: if this is a non-fixed format kernel we can't use it. + if (kernel_weight_format == KernelWeightFormat::NON_FIXED) { + return false; + } + + // If there's no config, or the config says ANY then this one is OK. + if (!args._cfg || args._cfg->weight_format == WeightFormat::ANY) { + return true; + } + + // If we get here it means there is a config and it specifies a format. Check it matches this kernel. + // NOTE: this will execute SVE instructions if it's an SVE kernel, so it's important that is_supported() + // was called above first. + return (args._cfg->weight_format == get_weight_format(kernel_weight_format, sizeof(Top))); } } @@ -129,10 +180,22 @@ struct GemmImplementation<Top, Tret, Nothing> { return impl; } + static GemmImplementation with_estimate(GemmMethod m, const char *n, KernelWeightFormat f, + std::function<bool(const GemmArgs &)> is_supported, std::function<uint64_t(const GemmArgs &)> cycle_estimate, + std::function<GemmCommon<Top, Tret> *(const GemmArgs &)> instantiate) { + GemmImplementation impl(m,n,f); + + impl.is_supported=is_supported; + impl.cycle_estimate=cycle_estimate; + impl.instantiate=instantiate; + + return impl; + } + GemmImplementation(const GemmImplementation &) = default; GemmImplementation & operator= (const GemmImplementation &) = default; - GemmImplementation(GemmMethod m, const char * n) : method(m), name(n) {} + GemmImplementation(GemmMethod m, const char *n, KernelWeightFormat f=KernelWeightFormat::NON_FIXED) : method(m), name(n), kernel_weight_format(f) {} GemmImplementation(GemmMethod m, const char *n, std::function<bool(const GemmArgs &)> is_supported, std::function<bool(const GemmArgs &)> is_recommended, @@ -140,6 +203,13 @@ struct GemmImplementation<Top, Tret, Nothing> { method(m), name(n), is_supported(is_supported), cycle_estimate( [is_recommended](const GemmArgs &args) -> uint64_t { return (is_recommended == nullptr) ? 0 : (is_recommended(args) ? 0 : UINT64_MAX); } ), instantiate(instantiate) { } + + GemmImplementation(GemmMethod m, const char *n, KernelWeightFormat kwf, + std::function<bool(const GemmArgs &)> is_supported, std::function<bool(const GemmArgs &)> is_recommended, + std::function<GemmCommon<Top, Tret> *(const GemmArgs &)> instantiate) : + method(m), name(n), kernel_weight_format(kwf), is_supported(is_supported), + cycle_estimate( [is_recommended](const GemmArgs &args) -> uint64_t { return (is_recommended == nullptr) ? 0 : (is_recommended(args) ? 0 : UINT64_MAX); } ), + instantiate(instantiate) { } }; /* "Main" function implemented for each valid combination of types. @@ -252,4 +322,17 @@ UniqueGemmCommon<Top, Tret> gemm(const GemmArgs &args, const OutputStage &os) { return UniqueGemmCommon<Top, Tret>(nullptr); } +template<typename Top, typename Tret, class OutputStage> +KernelDescription get_gemm_method(const GemmArgs &args, const OutputStage &os) { + const GemmImplementation<Top, Tret, OutputStage> *impl; + + if (find_implementation<Top, Tret>(args, os, impl)) { + return KernelDescription(impl->method, impl->name); + } + + /* This shouldn't happen - there should always be at least one valid implementation. */ + return KernelDescription(); +} + + } // namespace arm_gemm |