diff options
author | Francesco.Petrogalli@arm.com <francesco.petrogalli@arm.com> | 2022-04-05 10:31:08 +0000 |
---|---|---|
committer | Francesco Petrogalli <francesco.petrogalli@arm.com> | 2022-05-24 14:28:27 +0000 |
commit | 5fcf22dadf092efd7aafb359f9229aa270eb1129 (patch) | |
tree | f309426ed19bd6710329da3b530167db72d1c6b2 /src/core/NEON/kernels/arm_gemm/gemm_implementation.hpp | |
parent | a8caa023f0d7b71b3a250a14ceee935052fcc74a (diff) | |
download | ComputeLibrary-5fcf22dadf092efd7aafb359f9229aa270eb1129.tar.gz |
[arm_gemm] Import fixed-format kernels from gemm_linux.
This is a No Functional Change Intended (NFCI) patch. It imports the
kernels into the code, but the interface to select them and to expose the
format of the weight tensors to the user will be provided in a
subsequent patch.
Kernels and kernel selection code in arm_gemm have been provided
by David.Mansell <David.Mansell@arm.com>.
The kernels are not compiled in the library by default, but need to be
selected via the `scons` option `experimental_fixed_format_kernels=1`.
Resolves: ONCPUML-829
Signed-off-by: Francesco.Petrogalli@arm.com <francesco.petrogalli@arm.com>
Change-Id: If00ccb2b9b7221e01b214cf9783111226ccc8bf4
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/7380
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Gian Marco Iodice <gianmarco.iodice@arm.com>
Reviewed-by: SiCong Li <sicong.li@arm.com>
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Diffstat (limited to 'src/core/NEON/kernels/arm_gemm/gemm_implementation.hpp')
-rw-r--r-- | src/core/NEON/kernels/arm_gemm/gemm_implementation.hpp | 97 |
1 files changed, 90 insertions, 7 deletions
diff --git a/src/core/NEON/kernels/arm_gemm/gemm_implementation.hpp b/src/core/NEON/kernels/arm_gemm/gemm_implementation.hpp index cb3ff7aa29..75fb1cb306 100644 --- a/src/core/NEON/kernels/arm_gemm/gemm_implementation.hpp +++ b/src/core/NEON/kernels/arm_gemm/gemm_implementation.hpp @@ -24,6 +24,8 @@ #include "arm_gemm.hpp" +#include "kernel_weight_format.hpp" + #include <cstdint> #include <functional> @@ -37,15 +39,36 @@ template<typename Top, typename Tret, class OutputStage = Nothing> struct GemmImplementation { const GemmMethod method; const char * name; + const KernelWeightFormat kernel_weight_format = KernelWeightFormat::NON_FIXED; std::function<bool(const GemmArgs &, const OutputStage &)> is_supported = {}; std::function<uint64_t(const GemmArgs &, const OutputStage &)> cycle_estimate = {}; std::function<GemmCommon<Top, Tret> *(const GemmArgs &, const OutputStage &)> instantiate = {}; bool do_is_supported(const GemmArgs &args, const OutputStage &os) const { - if (is_supported != nullptr) { - return is_supported(args, os); + // Check supplied is_supported() function first. + if (is_supported != nullptr && !is_supported(args, os)) { + return false; + } + + // Check weight format is appropriate. + if (args._fixed_format == false) { + // Can't return a fixed format kernel if we weren't asked for one. + return (kernel_weight_format == KernelWeightFormat::NON_FIXED); } else { - return true; + // Fixed format kernel requested: if this is a non-fixed format kernel we can't use it. + if (kernel_weight_format == KernelWeightFormat::NON_FIXED) { + return false; + } + + // If there's no config, or the config says ANY then this one is OK. + if (!args._cfg || args._cfg->weight_format == WeightFormat::ANY) { + return true; + } + + // If we get here it means there is a config and it specifies a format. Check it matches this kernel. + // NOTE: this will execute SVE instructions if it's an SVE kernel, so it's important that is_supported() + // was called above first. 
+ return (args._cfg->weight_format == get_weight_format(kernel_weight_format, sizeof(Top))); } } @@ -84,6 +107,13 @@ struct GemmImplementation { method(m), name(n), is_supported(is_supported), cycle_estimate( [is_recommended](const GemmArgs &args, const OutputStage &os) { return (is_recommended == nullptr) ? 0 : (is_recommended(args, os) ? 0 : UINT64_MAX); } ), instantiate(instantiate) { } + + GemmImplementation(GemmMethod m, const char *n, KernelWeightFormat kwf, + std::function<bool(const GemmArgs &, const OutputStage &)> is_supported, std::function<bool(const GemmArgs &, const OutputStage &)> is_recommended, + std::function<GemmCommon<Top, Tret> *(const GemmArgs &, const OutputStage &)> instantiate) : + method(m), name(n), kernel_weight_format(kwf), is_supported(is_supported), + cycle_estimate( [is_recommended](const GemmArgs &args, const OutputStage &os) { return (is_recommended == nullptr) ? 0 : (is_recommended(args, os) ? 0 : UINT64_MAX); } ), + instantiate(instantiate) { } }; /* Slightly different version of above for straightforward GEMMs with no @@ -93,15 +123,36 @@ template<typename Top, typename Tret> struct GemmImplementation<Top, Tret, Nothing> { const GemmMethod method; const char * name; + const KernelWeightFormat kernel_weight_format = KernelWeightFormat::NON_FIXED; std::function<bool(const GemmArgs &)> is_supported = {}; std::function<uint64_t(const GemmArgs &)> cycle_estimate = {}; std::function<GemmCommon<Top, Tret> *(const GemmArgs &)> instantiate = {}; bool do_is_supported(const GemmArgs &args, const Nothing &) const { - if (is_supported != nullptr) { - return is_supported(args); + // Check supplied is_supported() function first. + if (is_supported != nullptr && !is_supported(args)) { + return false; + } + + // Check weight format is appropriate. + if (args._fixed_format == false) { + // Can't return a fixed format kernel if we weren't asked for one. 
+ return (kernel_weight_format == KernelWeightFormat::NON_FIXED); } else { - return true; + // Fixed format kernel requested: if this is a non-fixed format kernel we can't use it. + if (kernel_weight_format == KernelWeightFormat::NON_FIXED) { + return false; + } + + // If there's no config, or the config says ANY then this one is OK. + if (!args._cfg || args._cfg->weight_format == WeightFormat::ANY) { + return true; + } + + // If we get here it means there is a config and it specifies a format. Check it matches this kernel. + // NOTE: this will execute SVE instructions if it's an SVE kernel, so it's important that is_supported() + // was called above first. + return (args._cfg->weight_format == get_weight_format(kernel_weight_format, sizeof(Top))); } } @@ -129,10 +180,22 @@ struct GemmImplementation<Top, Tret, Nothing> { return impl; } + static GemmImplementation with_estimate(GemmMethod m, const char *n, KernelWeightFormat f, + std::function<bool(const GemmArgs &)> is_supported, std::function<uint64_t(const GemmArgs &)> cycle_estimate, + std::function<GemmCommon<Top, Tret> *(const GemmArgs &)> instantiate) { + GemmImplementation impl(m,n,f); + + impl.is_supported=is_supported; + impl.cycle_estimate=cycle_estimate; + impl.instantiate=instantiate; + + return impl; + } + GemmImplementation(const GemmImplementation &) = default; GemmImplementation & operator= (const GemmImplementation &) = default; - GemmImplementation(GemmMethod m, const char * n) : method(m), name(n) {} + GemmImplementation(GemmMethod m, const char *n, KernelWeightFormat f=KernelWeightFormat::NON_FIXED) : method(m), name(n), kernel_weight_format(f) {} GemmImplementation(GemmMethod m, const char *n, std::function<bool(const GemmArgs &)> is_supported, std::function<bool(const GemmArgs &)> is_recommended, @@ -140,6 +203,13 @@ struct GemmImplementation<Top, Tret, Nothing> { method(m), name(n), is_supported(is_supported), cycle_estimate( [is_recommended](const GemmArgs &args) -> uint64_t { return 
(is_recommended == nullptr) ? 0 : (is_recommended(args) ? 0 : UINT64_MAX); } ), instantiate(instantiate) { } + + GemmImplementation(GemmMethod m, const char *n, KernelWeightFormat kwf, + std::function<bool(const GemmArgs &)> is_supported, std::function<bool(const GemmArgs &)> is_recommended, + std::function<GemmCommon<Top, Tret> *(const GemmArgs &)> instantiate) : + method(m), name(n), kernel_weight_format(kwf), is_supported(is_supported), + cycle_estimate( [is_recommended](const GemmArgs &args) -> uint64_t { return (is_recommended == nullptr) ? 0 : (is_recommended(args) ? 0 : UINT64_MAX); } ), + instantiate(instantiate) { } }; /* "Main" function implemented for each valid combination of types. @@ -252,4 +322,17 @@ UniqueGemmCommon<Top, Tret> gemm(const GemmArgs &args, const OutputStage &os) { return UniqueGemmCommon<Top, Tret>(nullptr); } +template<typename Top, typename Tret, class OutputStage> +KernelDescription get_gemm_method(const GemmArgs &args, const OutputStage &os) { + const GemmImplementation<Top, Tret, OutputStage> *impl; + + if (find_implementation<Top, Tret>(args, os, impl)) { + return KernelDescription(impl->method, impl->name); + } + + /* This shouldn't happen - there should always be at least one valid implementation. */ + return KernelDescription(); +} + + } // namespace arm_gemm |