diff options
Diffstat (limited to 'src/core/NEON/kernels/arm_gemm/gemm_implementation.hpp')
-rw-r--r-- | src/core/NEON/kernels/arm_gemm/gemm_implementation.hpp | 28 |
1 files changed, 14 insertions, 14 deletions
diff --git a/src/core/NEON/kernels/arm_gemm/gemm_implementation.hpp b/src/core/NEON/kernels/arm_gemm/gemm_implementation.hpp index 261e7d2d9c..f6a0fc5d52 100644 --- a/src/core/NEON/kernels/arm_gemm/gemm_implementation.hpp +++ b/src/core/NEON/kernels/arm_gemm/gemm_implementation.hpp @@ -37,9 +37,9 @@ template<typename Top, typename Tret, class OutputStage = Nothing> struct GemmImplementation { const GemmMethod method; const char * name; - std::function<bool(const GemmArgs &, const OutputStage &)> is_supported; - std::function<uint64_t(const GemmArgs &, const OutputStage &)> cycle_estimate; - std::function<GemmCommon<Top, Tret> *(const GemmArgs &, const OutputStage &)> instantiate; + std::function<bool(const GemmArgs &, const OutputStage &)> is_supported = {}; + std::function<uint64_t(const GemmArgs &, const OutputStage &)> cycle_estimate = {}; + std::function<GemmCommon<Top, Tret> *(const GemmArgs &, const OutputStage &)> instantiate = {}; bool do_is_supported(const GemmArgs &args, const OutputStage &os) const { if (is_supported != nullptr) { @@ -57,13 +57,13 @@ struct GemmImplementation { } } - GemmImplementation(const GemmImplementation &) = default; - GemmImplementation &operator= (const GemmImplementation &) = default; - GemmCommon<Top, Tret> *do_instantiate(const GemmArgs &args, const OutputStage &os) const { return instantiate(args, os); } + GemmImplementation(const GemmImplementation &) = default; + GemmImplementation & operator= (const GemmImplementation &) = default; + GemmImplementation(GemmMethod m, const char *n, std::function<bool(const GemmArgs &, const OutputStage &)> is_supported, std::function<bool(const GemmArgs &, const OutputStage &)> is_recommended, std::function<GemmCommon<Top, Tret> *(const GemmArgs &, const OutputStage &)> instantiate) : @@ -79,9 +79,9 @@ template<typename Top, typename Tret> struct GemmImplementation<Top, Tret, Nothing> { const GemmMethod method; const char * name; - std::function<bool(const GemmArgs &)> is_supported; - std::function<uint64_t(const GemmArgs &)> cycle_estimate; - std::function<GemmCommon<Top, Tret> *(const GemmArgs &)> instantiate; + std::function<bool(const GemmArgs &)> is_supported = {}; + std::function<uint64_t(const GemmArgs &)> cycle_estimate = {}; + std::function<GemmCommon<Top, Tret> *(const GemmArgs &)> instantiate = {}; bool do_is_supported(const GemmArgs &args, const Nothing &) const { if (is_supported != nullptr) { @@ -103,7 +103,6 @@ struct GemmImplementation<Top, Tret, Nothing> { return instantiate(args); } - static GemmImplementation with_estimate(GemmMethod m, const char *n, std::function<bool(const GemmArgs &)> is_supported, std::function<uint64_t(const GemmArgs &)> cycle_estimate, std::function<GemmCommon<Top, Tret> *(const GemmArgs &)> instantiate) { @@ -116,7 +115,10 @@ struct GemmImplementation<Top, Tret, Nothing> { return impl; } - GemmImplementation(GemmMethod m, const char * n) : method(m), name(n), is_supported(nullptr), cycle_estimate(nullptr), instantiate(nullptr) {} + GemmImplementation(const GemmImplementation &) = default; + GemmImplementation & operator= (const GemmImplementation &) = default; + + GemmImplementation(GemmMethod m, const char * n) : method(m), name(n) {} GemmImplementation(GemmMethod m, const char *n, std::function<bool(const GemmArgs &)> is_supported, std::function<bool(const GemmArgs &)> is_recommended, @@ -124,9 +126,6 @@ struct GemmImplementation<Top, Tret, Nothing> { method(m), name(n), is_supported(is_supported), cycle_estimate( [is_recommended](const GemmArgs &args) -> uint64_t { return (is_recommended == nullptr) ? 0 : (is_recommended(args) ? 0 : UINT64_MAX); } ), instantiate(instantiate) { } - - GemmImplementation(const GemmImplementation &) = default; - GemmImplementation &operator=(const GemmImplementation &) = default; }; /* "Master" function implemented for each valid combination of types. @@ -211,6 +210,7 @@ std::vector<KernelDescription> get_compatible_kernels(const GemmArgs &args, cons for (const GemmImplementation<Top, Tret, OutputStage> *i = gemms; i->method != GemmMethod::DEFAULT; i++) { /* Check that this implementation supports the presented problem. */ + if (!i->do_is_supported(args, os)) { continue; } |