aboutsummaryrefslogtreecommitdiff
path: root/src/core/NEON/kernels/arm_gemm/gemm_implementation.hpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/core/NEON/kernels/arm_gemm/gemm_implementation.hpp')
-rw-r--r--src/core/NEON/kernels/arm_gemm/gemm_implementation.hpp28
1 files changed, 14 insertions, 14 deletions
diff --git a/src/core/NEON/kernels/arm_gemm/gemm_implementation.hpp b/src/core/NEON/kernels/arm_gemm/gemm_implementation.hpp
index 261e7d2d9c..f6a0fc5d52 100644
--- a/src/core/NEON/kernels/arm_gemm/gemm_implementation.hpp
+++ b/src/core/NEON/kernels/arm_gemm/gemm_implementation.hpp
@@ -37,9 +37,9 @@ template<typename Top, typename Tret, class OutputStage = Nothing>
struct GemmImplementation {
const GemmMethod method;
const char * name;
- std::function<bool(const GemmArgs &, const OutputStage &)> is_supported;
- std::function<uint64_t(const GemmArgs &, const OutputStage &)> cycle_estimate;
- std::function<GemmCommon<Top, Tret> *(const GemmArgs &, const OutputStage &)> instantiate;
+ std::function<bool(const GemmArgs &, const OutputStage &)> is_supported = {};
+ std::function<uint64_t(const GemmArgs &, const OutputStage &)> cycle_estimate = {};
+ std::function<GemmCommon<Top, Tret> *(const GemmArgs &, const OutputStage &)> instantiate = {};
bool do_is_supported(const GemmArgs &args, const OutputStage &os) const {
if (is_supported != nullptr) {
@@ -57,13 +57,13 @@ struct GemmImplementation {
}
}
- GemmImplementation(const GemmImplementation &) = default;
- GemmImplementation &operator= (const GemmImplementation &) = default;
-
GemmCommon<Top, Tret> *do_instantiate(const GemmArgs &args, const OutputStage &os) const {
return instantiate(args, os);
}
+ GemmImplementation(const GemmImplementation &) = default;
+ GemmImplementation & operator= (const GemmImplementation &) = default;
+
GemmImplementation(GemmMethod m, const char *n,
std::function<bool(const GemmArgs &, const OutputStage &)> is_supported, std::function<bool(const GemmArgs &, const OutputStage &)> is_recommended,
std::function<GemmCommon<Top, Tret> *(const GemmArgs &, const OutputStage &)> instantiate) :
@@ -79,9 +79,9 @@ template<typename Top, typename Tret>
struct GemmImplementation<Top, Tret, Nothing> {
const GemmMethod method;
const char * name;
- std::function<bool(const GemmArgs &)> is_supported;
- std::function<uint64_t(const GemmArgs &)> cycle_estimate;
- std::function<GemmCommon<Top, Tret> *(const GemmArgs &)> instantiate;
+ std::function<bool(const GemmArgs &)> is_supported = {};
+ std::function<uint64_t(const GemmArgs &)> cycle_estimate = {};
+ std::function<GemmCommon<Top, Tret> *(const GemmArgs &)> instantiate = {};
bool do_is_supported(const GemmArgs &args, const Nothing &) const {
if (is_supported != nullptr) {
@@ -103,7 +103,6 @@ struct GemmImplementation<Top, Tret, Nothing> {
return instantiate(args);
}
-
static GemmImplementation with_estimate(GemmMethod m, const char *n,
std::function<bool(const GemmArgs &)> is_supported, std::function<uint64_t(const GemmArgs &)> cycle_estimate,
std::function<GemmCommon<Top, Tret> *(const GemmArgs &)> instantiate) {
@@ -116,7 +115,10 @@ struct GemmImplementation<Top, Tret, Nothing> {
return impl;
}
- GemmImplementation(GemmMethod m, const char * n) : method(m), name(n), is_supported(nullptr), cycle_estimate(nullptr), instantiate(nullptr) {}
+ GemmImplementation(const GemmImplementation &) = default;
+ GemmImplementation & operator= (const GemmImplementation &) = default;
+
+ GemmImplementation(GemmMethod m, const char * n) : method(m), name(n) {}
GemmImplementation(GemmMethod m, const char *n,
std::function<bool(const GemmArgs &)> is_supported, std::function<bool(const GemmArgs &)> is_recommended,
@@ -124,9 +126,6 @@ struct GemmImplementation<Top, Tret, Nothing> {
method(m), name(n), is_supported(is_supported),
cycle_estimate( [is_recommended](const GemmArgs &args) -> uint64_t { return (is_recommended == nullptr) ? 0 : (is_recommended(args) ? 0 : UINT64_MAX); } ),
instantiate(instantiate) { }
-
- GemmImplementation(const GemmImplementation &) = default;
- GemmImplementation &operator=(const GemmImplementation &) = default;
};
/* "Master" function implemented for each valid combination of types.
@@ -211,6 +210,7 @@ std::vector<KernelDescription> get_compatible_kernels(const GemmArgs &args, cons
for (const GemmImplementation<Top, Tret, OutputStage> *i = gemms; i->method != GemmMethod::DEFAULT; i++) {
/* Check that this implementation supports the presented problem. */
+
if (!i->do_is_supported(args, os)) {
continue;
}