diff options
author | Georgios Pinitas <georgios.pinitas@arm.com> | 2020-11-02 01:37:17 +0000 |
---|---|---|
committer | Georgios Pinitas <georgios.pinitas@arm.com> | 2020-11-12 15:59:25 +0000 |
commit | c0b6f76561580414f08633a804fc548ccad65659 (patch) | |
tree | 4d46b7f479de04f799e29095392948aeb370c029 /src/core/NEON/kernels/arm_gemm/gemm_implementation.hpp | |
parent | 824061d9910ebb42cbe46b677c0b843db212c9a2 (diff) | |
download | ComputeLibrary-c0b6f76561580414f08633a804fc548ccad65659.tar.gz |
COMPMID-3776: Indirect GEMM
Signed-off-by: Georgios Pinitas <georgios.pinitas@arm.com>
Change-Id: I51a1b0f098bc3a8c408c50c92221e4df3061e12c
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/4343
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Sang-Hoon Park <sang-hoon.park@arm.com>
Reviewed-by: Michele Di Giorgio <michele.digiorgio@arm.com>
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Diffstat (limited to 'src/core/NEON/kernels/arm_gemm/gemm_implementation.hpp')
-rw-r--r-- | src/core/NEON/kernels/arm_gemm/gemm_implementation.hpp | 28 |
1 files changed, 14 insertions, 14 deletions
diff --git a/src/core/NEON/kernels/arm_gemm/gemm_implementation.hpp b/src/core/NEON/kernels/arm_gemm/gemm_implementation.hpp index 261e7d2d9c..f6a0fc5d52 100644 --- a/src/core/NEON/kernels/arm_gemm/gemm_implementation.hpp +++ b/src/core/NEON/kernels/arm_gemm/gemm_implementation.hpp @@ -37,9 +37,9 @@ template<typename Top, typename Tret, class OutputStage = Nothing> struct GemmImplementation { const GemmMethod method; const char * name; - std::function<bool(const GemmArgs &, const OutputStage &)> is_supported; - std::function<uint64_t(const GemmArgs &, const OutputStage &)> cycle_estimate; - std::function<GemmCommon<Top, Tret> *(const GemmArgs &, const OutputStage &)> instantiate; + std::function<bool(const GemmArgs &, const OutputStage &)> is_supported = {}; + std::function<uint64_t(const GemmArgs &, const OutputStage &)> cycle_estimate = {}; + std::function<GemmCommon<Top, Tret> *(const GemmArgs &, const OutputStage &)> instantiate = {}; bool do_is_supported(const GemmArgs &args, const OutputStage &os) const { if (is_supported != nullptr) { @@ -57,13 +57,13 @@ struct GemmImplementation { } } - GemmImplementation(const GemmImplementation &) = default; - GemmImplementation &operator= (const GemmImplementation &) = default; - GemmCommon<Top, Tret> *do_instantiate(const GemmArgs &args, const OutputStage &os) const { return instantiate(args, os); } + GemmImplementation(const GemmImplementation &) = default; + GemmImplementation & operator= (const GemmImplementation &) = default; + GemmImplementation(GemmMethod m, const char *n, std::function<bool(const GemmArgs &, const OutputStage &)> is_supported, std::function<bool(const GemmArgs &, const OutputStage &)> is_recommended, std::function<GemmCommon<Top, Tret> *(const GemmArgs &, const OutputStage &)> instantiate) : @@ -79,9 +79,9 @@ template<typename Top, typename Tret> struct GemmImplementation<Top, Tret, Nothing> { const GemmMethod method; const char * name; - std::function<bool(const GemmArgs &)> is_supported; - std::function<uint64_t(const GemmArgs &)> cycle_estimate; - std::function<GemmCommon<Top, Tret> *(const GemmArgs &)> instantiate; + std::function<bool(const GemmArgs &)> is_supported = {}; + std::function<uint64_t(const GemmArgs &)> cycle_estimate = {}; + std::function<GemmCommon<Top, Tret> *(const GemmArgs &)> instantiate = {}; bool do_is_supported(const GemmArgs &args, const Nothing &) const { if (is_supported != nullptr) { @@ -103,7 +103,6 @@ struct GemmImplementation<Top, Tret, Nothing> { return instantiate(args); } - static GemmImplementation with_estimate(GemmMethod m, const char *n, std::function<bool(const GemmArgs &)> is_supported, std::function<uint64_t(const GemmArgs &)> cycle_estimate, std::function<GemmCommon<Top, Tret> *(const GemmArgs &)> instantiate) { @@ -116,7 +115,10 @@ struct GemmImplementation<Top, Tret, Nothing> { return impl; } - GemmImplementation(GemmMethod m, const char * n) : method(m), name(n), is_supported(nullptr), cycle_estimate(nullptr), instantiate(nullptr) {} + GemmImplementation(const GemmImplementation &) = default; + GemmImplementation & operator= (const GemmImplementation &) = default; + + GemmImplementation(GemmMethod m, const char * n) : method(m), name(n) {} GemmImplementation(GemmMethod m, const char *n, std::function<bool(const GemmArgs &)> is_supported, std::function<bool(const GemmArgs &)> is_recommended, @@ -124,9 +126,6 @@ struct GemmImplementation<Top, Tret, Nothing> { method(m), name(n), is_supported(is_supported), cycle_estimate( [is_recommended](const GemmArgs &args) -> uint64_t { return (is_recommended == nullptr) ? 0 : (is_recommended(args) ? 0 : UINT64_MAX); } ), instantiate(instantiate) { } - - GemmImplementation(const GemmImplementation &) = default; - GemmImplementation &operator=(const GemmImplementation &) = default; }; /* "Master" function implemented for each valid combination of types. @@ -211,6 +210,7 @@ std::vector<KernelDescription> get_compatible_kernels(const GemmArgs &args, cons for (const GemmImplementation<Top, Tret, OutputStage> *i = gemms; i->method != GemmMethod::DEFAULT; i++) { /* Check that this implementation supports the presented problem. */ + if (!i->do_is_supported(args, os)) { continue; } |