aboutsummaryrefslogtreecommitdiff
path: root/src/core/NEON/kernels/arm_gemm/gemm_implementation.hpp
diff options
context:
space:
mode:
authorFrancesco.Petrogalli@arm.com <francesco.petrogalli@arm.com>2022-04-05 10:31:08 +0000
committerFrancesco Petrogalli <francesco.petrogalli@arm.com>2022-05-24 14:28:27 +0000
commit5fcf22dadf092efd7aafb359f9229aa270eb1129 (patch)
treef309426ed19bd6710329da3b530167db72d1c6b2 /src/core/NEON/kernels/arm_gemm/gemm_implementation.hpp
parenta8caa023f0d7b71b3a250a14ceee935052fcc74a (diff)
downloadComputeLibrary-5fcf22dadf092efd7aafb359f9229aa270eb1129.tar.gz
[arm_gemm] Import fixed-format kernels from gemm_linux.
This is a No Functional Change Intended (NFCI) patch. It imports the kernel in the code, but the interface to select them and expose the format of the weight tensors to the user will be provided in a subsequent patch. Kernels and kernel selection code in arm_gemm has been provided by David.Mansell <David.Mansell@arm.com>. The kernels are not compiled in the library by default, but need to be selected via the `scons` option `experimental_fixed_format_kernels=1`. Resolves: ONCPUML-829 Signed-off-by: Francesco.Petrogalli@arm.com <francesco.petrogalli@arm.com> Change-Id: If00ccb2b9b7221e01b214cf9783111226ccc8bf4 Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/7380 Tested-by: Arm Jenkins <bsgcomp@arm.com> Reviewed-by: Gian Marco Iodice <gianmarco.iodice@arm.com> Reviewed-by: SiCong Li <sicong.li@arm.com> Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Diffstat (limited to 'src/core/NEON/kernels/arm_gemm/gemm_implementation.hpp')
-rw-r--r--src/core/NEON/kernels/arm_gemm/gemm_implementation.hpp97
1 files changed, 90 insertions, 7 deletions
diff --git a/src/core/NEON/kernels/arm_gemm/gemm_implementation.hpp b/src/core/NEON/kernels/arm_gemm/gemm_implementation.hpp
index cb3ff7aa29..75fb1cb306 100644
--- a/src/core/NEON/kernels/arm_gemm/gemm_implementation.hpp
+++ b/src/core/NEON/kernels/arm_gemm/gemm_implementation.hpp
@@ -24,6 +24,8 @@
#include "arm_gemm.hpp"
+#include "kernel_weight_format.hpp"
+
#include <cstdint>
#include <functional>
@@ -37,15 +39,36 @@ template<typename Top, typename Tret, class OutputStage = Nothing>
struct GemmImplementation {
const GemmMethod method;
const char * name;
+ const KernelWeightFormat kernel_weight_format = KernelWeightFormat::NON_FIXED;
std::function<bool(const GemmArgs &, const OutputStage &)> is_supported = {};
std::function<uint64_t(const GemmArgs &, const OutputStage &)> cycle_estimate = {};
std::function<GemmCommon<Top, Tret> *(const GemmArgs &, const OutputStage &)> instantiate = {};
bool do_is_supported(const GemmArgs &args, const OutputStage &os) const {
- if (is_supported != nullptr) {
- return is_supported(args, os);
+ // Check supplied is_supported() function first.
+ if (is_supported != nullptr && !is_supported(args, os)) {
+ return false;
+ }
+
+ // Check weight format is appropriate.
+ if (args._fixed_format == false) {
+ // Can't return a fixed format kernel if we weren't asked for one.
+ return (kernel_weight_format == KernelWeightFormat::NON_FIXED);
} else {
- return true;
+ // Fixed format kernel requested: if this is a non-fixed format kernel we can't use it.
+ if (kernel_weight_format == KernelWeightFormat::NON_FIXED) {
+ return false;
+ }
+
+ // If there's no config, or the config says ANY then this one is OK.
+ if (!args._cfg || args._cfg->weight_format == WeightFormat::ANY) {
+ return true;
+ }
+
+ // If we get here it means there is a config and it specifies a format. Check it matches this kernel.
+ // NOTE: this will execute SVE instructions if it's an SVE kernel, so it's important that is_supported()
+ // was called above first.
+ return (args._cfg->weight_format == get_weight_format(kernel_weight_format, sizeof(Top)));
}
}
@@ -84,6 +107,13 @@ struct GemmImplementation {
method(m), name(n), is_supported(is_supported),
cycle_estimate( [is_recommended](const GemmArgs &args, const OutputStage &os) { return (is_recommended == nullptr) ? 0 : (is_recommended(args, os) ? 0 : UINT64_MAX); } ),
instantiate(instantiate) { }
+
+ GemmImplementation(GemmMethod m, const char *n, KernelWeightFormat kwf,
+ std::function<bool(const GemmArgs &, const OutputStage &)> is_supported, std::function<bool(const GemmArgs &, const OutputStage &)> is_recommended,
+ std::function<GemmCommon<Top, Tret> *(const GemmArgs &, const OutputStage &)> instantiate) :
+ method(m), name(n), kernel_weight_format(kwf), is_supported(is_supported),
+ cycle_estimate( [is_recommended](const GemmArgs &args, const OutputStage &os) { return (is_recommended == nullptr) ? 0 : (is_recommended(args, os) ? 0 : UINT64_MAX); } ),
+ instantiate(instantiate) { }
};
/* Slightly different version of above for straightforward GEMMs with no
@@ -93,15 +123,36 @@ template<typename Top, typename Tret>
struct GemmImplementation<Top, Tret, Nothing> {
const GemmMethod method;
const char * name;
+ const KernelWeightFormat kernel_weight_format = KernelWeightFormat::NON_FIXED;
std::function<bool(const GemmArgs &)> is_supported = {};
std::function<uint64_t(const GemmArgs &)> cycle_estimate = {};
std::function<GemmCommon<Top, Tret> *(const GemmArgs &)> instantiate = {};
bool do_is_supported(const GemmArgs &args, const Nothing &) const {
- if (is_supported != nullptr) {
- return is_supported(args);
+ // Check supplied is_supported() function first.
+ if (is_supported != nullptr && !is_supported(args)) {
+ return false;
+ }
+
+ // Check weight format is appropriate.
+ if (args._fixed_format == false) {
+ // Can't return a fixed format kernel if we weren't asked for one.
+ return (kernel_weight_format == KernelWeightFormat::NON_FIXED);
} else {
- return true;
+ // Fixed format kernel requested: if this is a non-fixed format kernel we can't use it.
+ if (kernel_weight_format == KernelWeightFormat::NON_FIXED) {
+ return false;
+ }
+
+ // If there's no config, or the config says ANY then this one is OK.
+ if (!args._cfg || args._cfg->weight_format == WeightFormat::ANY) {
+ return true;
+ }
+
+ // If we get here it means there is a config and it specifies a format. Check it matches this kernel.
+ // NOTE: this will execute SVE instructions if it's an SVE kernel, so it's important that is_supported()
+ // was called above first.
+ return (args._cfg->weight_format == get_weight_format(kernel_weight_format, sizeof(Top)));
}
}
@@ -129,10 +180,22 @@ struct GemmImplementation<Top, Tret, Nothing> {
return impl;
}
+ static GemmImplementation with_estimate(GemmMethod m, const char *n, KernelWeightFormat f,
+ std::function<bool(const GemmArgs &)> is_supported, std::function<uint64_t(const GemmArgs &)> cycle_estimate,
+ std::function<GemmCommon<Top, Tret> *(const GemmArgs &)> instantiate) {
+ GemmImplementation impl(m,n,f);
+
+ impl.is_supported=is_supported;
+ impl.cycle_estimate=cycle_estimate;
+ impl.instantiate=instantiate;
+
+ return impl;
+ }
+
GemmImplementation(const GemmImplementation &) = default;
GemmImplementation & operator= (const GemmImplementation &) = default;
- GemmImplementation(GemmMethod m, const char * n) : method(m), name(n) {}
+ GemmImplementation(GemmMethod m, const char *n, KernelWeightFormat f=KernelWeightFormat::NON_FIXED) : method(m), name(n), kernel_weight_format(f) {}
GemmImplementation(GemmMethod m, const char *n,
std::function<bool(const GemmArgs &)> is_supported, std::function<bool(const GemmArgs &)> is_recommended,
@@ -140,6 +203,13 @@ struct GemmImplementation<Top, Tret, Nothing> {
method(m), name(n), is_supported(is_supported),
cycle_estimate( [is_recommended](const GemmArgs &args) -> uint64_t { return (is_recommended == nullptr) ? 0 : (is_recommended(args) ? 0 : UINT64_MAX); } ),
instantiate(instantiate) { }
+
+ GemmImplementation(GemmMethod m, const char *n, KernelWeightFormat kwf,
+ std::function<bool(const GemmArgs &)> is_supported, std::function<bool(const GemmArgs &)> is_recommended,
+ std::function<GemmCommon<Top, Tret> *(const GemmArgs &)> instantiate) :
+ method(m), name(n), kernel_weight_format(kwf), is_supported(is_supported),
+ cycle_estimate( [is_recommended](const GemmArgs &args) -> uint64_t { return (is_recommended == nullptr) ? 0 : (is_recommended(args) ? 0 : UINT64_MAX); } ),
+ instantiate(instantiate) { }
};
/* "Main" function implemented for each valid combination of types.
@@ -252,4 +322,17 @@ UniqueGemmCommon<Top, Tret> gemm(const GemmArgs &args, const OutputStage &os) {
return UniqueGemmCommon<Top, Tret>(nullptr);
}
+template<typename Top, typename Tret, class OutputStage>
+KernelDescription get_gemm_method(const GemmArgs &args, const OutputStage &os) {
+ const GemmImplementation<Top, Tret, OutputStage> *impl;
+
+ if (find_implementation<Top, Tret>(args, os, impl)) {
+ return KernelDescription(impl->method, impl->name);
+ }
+
+ /* This shouldn't happen - there should always be at least one valid implementation. */
+ return KernelDescription();
+}
+
+
} // namespace arm_gemm