diff options
Diffstat (limited to 'src/core/NEON/kernels/arm_gemm/gemm_implementation.hpp')
-rw-r--r-- | src/core/NEON/kernels/arm_gemm/gemm_implementation.hpp | 118 |
1 files changed, 107 insertions, 11 deletions
diff --git a/src/core/NEON/kernels/arm_gemm/gemm_implementation.hpp b/src/core/NEON/kernels/arm_gemm/gemm_implementation.hpp index f6a0fc5d52..5e77df7d4a 100644 --- a/src/core/NEON/kernels/arm_gemm/gemm_implementation.hpp +++ b/src/core/NEON/kernels/arm_gemm/gemm_implementation.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018-2020 Arm Limited. + * Copyright (c) 2018-2020, 2022-2023 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -24,6 +24,8 @@ #include "arm_gemm.hpp" +#include "kernel_weight_format.hpp" + #include <cstdint> #include <functional> @@ -37,15 +39,36 @@ template<typename Top, typename Tret, class OutputStage = Nothing> struct GemmImplementation { const GemmMethod method; const char * name; + const KernelWeightFormat kernel_weight_format = KernelWeightFormat::NON_FIXED; std::function<bool(const GemmArgs &, const OutputStage &)> is_supported = {}; std::function<uint64_t(const GemmArgs &, const OutputStage &)> cycle_estimate = {}; std::function<GemmCommon<Top, Tret> *(const GemmArgs &, const OutputStage &)> instantiate = {}; bool do_is_supported(const GemmArgs &args, const OutputStage &os) const { - if (is_supported != nullptr) { - return is_supported(args, os); + // Check supplied is_supported() function first. + if (is_supported != nullptr && !is_supported(args, os)) { + return false; + } + + // Check weight format is appropriate. + if (args._fixed_format == false) { + // Can't return a fixed format kernel if we weren't asked for one. + return (kernel_weight_format == KernelWeightFormat::NON_FIXED); } else { - return true; + // Fixed format kernel requested: if this is a non-fixed format kernel we can't use it. + if (kernel_weight_format == KernelWeightFormat::NON_FIXED) { + return false; + } + + // If there's no config, or the config says ANY then this one is OK. + if (!args._cfg || args._cfg->weight_format == WeightFormat::ANY) { + return true; + } + + // If we get here it means there is a config and it specifies a format. Check it matches this kernel. + // NOTE: this will execute SVE instructions if it's an SVE kernel, so it's important that is_supported() + // was called above first. + return (args._cfg->weight_format == get_weight_format(kernel_weight_format, sizeof(Top))); } } @@ -61,15 +84,36 @@ struct GemmImplementation { return instantiate(args, os); } + static GemmImplementation with_estimate(GemmMethod m, const char *n, + std::function<bool(const GemmArgs &, const OutputStage &)> is_supported, std::function<uint64_t(const GemmArgs &, const OutputStage &)> cycle_estimate, + std::function<GemmCommon<Top, Tret> *(const GemmArgs &, const OutputStage &)> instantiate) { + GemmImplementation impl(m,n); + + impl.is_supported=is_supported; + impl.cycle_estimate=cycle_estimate; + impl.instantiate=instantiate; + + return impl; + } + GemmImplementation(const GemmImplementation &) = default; GemmImplementation & operator= (const GemmImplementation &) = default; + GemmImplementation(GemmMethod m, const char * n) : method(m), name(n) {} + GemmImplementation(GemmMethod m, const char *n, std::function<bool(const GemmArgs &, const OutputStage &)> is_supported, std::function<bool(const GemmArgs &, const OutputStage &)> is_recommended, std::function<GemmCommon<Top, Tret> *(const GemmArgs &, const OutputStage &)> instantiate) : method(m), name(n), is_supported(is_supported), cycle_estimate( [is_recommended](const GemmArgs &args, const OutputStage &os) { return (is_recommended == nullptr) ? 0 : (is_recommended(args, os) ? 0 : UINT64_MAX); } ), instantiate(instantiate) { } + + GemmImplementation(GemmMethod m, const char *n, KernelWeightFormat kwf, + std::function<bool(const GemmArgs &, const OutputStage &)> is_supported, std::function<bool(const GemmArgs &, const OutputStage &)> is_recommended, + std::function<GemmCommon<Top, Tret> *(const GemmArgs &, const OutputStage &)> instantiate) : + method(m), name(n), kernel_weight_format(kwf), is_supported(is_supported), + cycle_estimate( [is_recommended](const GemmArgs &args, const OutputStage &os) { return (is_recommended == nullptr) ? 0 : (is_recommended(args, os) ? 0 : UINT64_MAX); } ), + instantiate(instantiate) { } }; /* Slightly different version of above for straightforward GEMMs with no @@ -79,15 +123,36 @@ template<typename Top, typename Tret> struct GemmImplementation<Top, Tret, Nothing> { const GemmMethod method; const char * name; + const KernelWeightFormat kernel_weight_format = KernelWeightFormat::NON_FIXED; std::function<bool(const GemmArgs &)> is_supported = {}; std::function<uint64_t(const GemmArgs &)> cycle_estimate = {}; std::function<GemmCommon<Top, Tret> *(const GemmArgs &)> instantiate = {}; bool do_is_supported(const GemmArgs &args, const Nothing &) const { - if (is_supported != nullptr) { - return is_supported(args); + // Check supplied is_supported() function first. + if (is_supported != nullptr && !is_supported(args)) { + return false; + } + + // Check weight format is appropriate. + if (args._fixed_format == false) { + // Can't return a fixed format kernel if we weren't asked for one. + return (kernel_weight_format == KernelWeightFormat::NON_FIXED); } else { - return true; + // Fixed format kernel requested: if this is a non-fixed format kernel we can't use it. + if (kernel_weight_format == KernelWeightFormat::NON_FIXED) { + return false; + } + + // If there's no config, or the config says ANY then this one is OK. + if (!args._cfg || args._cfg->weight_format == WeightFormat::ANY) { + return true; + } + + // If we get here it means there is a config and it specifies a format. Check it matches this kernel. + // NOTE: this will execute SVE instructions if it's an SVE kernel, so it's important that is_supported() + // was called above first. + return (args._cfg->weight_format == get_weight_format(kernel_weight_format, sizeof(Top))); } } @@ -115,10 +180,22 @@ struct GemmImplementation<Top, Tret, Nothing> { return impl; } + static GemmImplementation with_estimate(GemmMethod m, const char *n, KernelWeightFormat f, + std::function<bool(const GemmArgs &)> is_supported, std::function<uint64_t(const GemmArgs &)> cycle_estimate, + std::function<GemmCommon<Top, Tret> *(const GemmArgs &)> instantiate) { + GemmImplementation impl(m,n,f); + + impl.is_supported=is_supported; + impl.cycle_estimate=cycle_estimate; + impl.instantiate=instantiate; + + return impl; + } + GemmImplementation(const GemmImplementation &) = default; GemmImplementation & operator= (const GemmImplementation &) = default; - GemmImplementation(GemmMethod m, const char * n) : method(m), name(n) {} + GemmImplementation(GemmMethod m, const char *n, KernelWeightFormat f=KernelWeightFormat::NON_FIXED) : method(m), name(n), kernel_weight_format(f) {} GemmImplementation(GemmMethod m, const char *n, std::function<bool(const GemmArgs &)> is_supported, std::function<bool(const GemmArgs &)> is_recommended, @@ -126,11 +203,20 @@ struct GemmImplementation<Top, Tret, Nothing> { method(m), name(n), is_supported(is_supported), cycle_estimate( [is_recommended](const GemmArgs &args) -> uint64_t { return (is_recommended == nullptr) ? 0 : (is_recommended(args) ? 0 : UINT64_MAX); } ), instantiate(instantiate) { } + + GemmImplementation(GemmMethod m, const char *n, KernelWeightFormat kwf, + std::function<bool(const GemmArgs &)> is_supported, std::function<bool(const GemmArgs &)> is_recommended, + std::function<GemmCommon<Top, Tret> *(const GemmArgs &)> instantiate) : + method(m), name(n), kernel_weight_format(kwf), is_supported(is_supported), + cycle_estimate( [is_recommended](const GemmArgs &args) -> uint64_t { return (is_recommended == nullptr) ? 0 : (is_recommended(args) ? 0 : UINT64_MAX); } ), + instantiate(instantiate) { } }; -/* "Master" function implemented for each valid combination of types. - * Returns a list of GEMM implementation descriptors for processing by the - * other functions, terminated by an implementation with +/* Provides the list of implementation descriptors which is processed by the + * other functions. + * + * A specialised version is provided for each supported combination of types. + * The end of the list is indicated by a sentinel descriptor with * method==GemmMethod::DEFAULT. */ template<typename Top, typename Tret, class OutputStage = Nothing> const GemmImplementation<Top, Tret, OutputStage> *gemm_implementation_list(); @@ -222,6 +308,15 @@ std::vector<KernelDescription> get_compatible_kernels(const GemmArgs &args, cons } template<typename Top, typename Tret, class OutputStage> +bool has_opt_gemm(WeightFormat &wf, const GemmArgs &args, const OutputStage &os) { + const GemmImplementation<Top, Tret, OutputStage> *impl; + const bool success = find_implementation<Top, Tret, OutputStage>(args, os, impl); + if (success) + wf = UniqueGemmCommon<Top, Tret>(impl->do_instantiate(args, os))->get_config().weight_format; + return success; +} + +template<typename Top, typename Tret, class OutputStage> UniqueGemmCommon<Top, Tret> gemm(const GemmArgs &args, const OutputStage &os) { const GemmImplementation<Top, Tret, OutputStage> *impl; @@ -244,4 +339,5 @@ KernelDescription get_gemm_method(const GemmArgs &args, const OutputStage &os) { return KernelDescription(); } + } // namespace arm_gemm |