diff options
Diffstat (limited to 'src/core/NEON/kernels/arm_gemm/gemm_implementation.hpp')
-rw-r--r-- | src/core/NEON/kernels/arm_gemm/gemm_implementation.hpp | 123 |
1 files changed, 83 insertions, 40 deletions
diff --git a/src/core/NEON/kernels/arm_gemm/gemm_implementation.hpp b/src/core/NEON/kernels/arm_gemm/gemm_implementation.hpp index d952140959..55d72f88cb 100644 --- a/src/core/NEON/kernels/arm_gemm/gemm_implementation.hpp +++ b/src/core/NEON/kernels/arm_gemm/gemm_implementation.hpp @@ -28,21 +28,77 @@ namespace arm_gemm { -template<typename Top, typename Tret> +/* Structure describing an implementation. For each supported combination + * of types, a static list of these structures is built up to describe the + * implementations available. + */ +template<typename Top, typename Tret, class OutputStage = Nothing> struct GemmImplementation { + const GemmMethod method; + const char * name; + std::function<bool(const GemmArgs<Tret> &, const OutputStage &)> is_supported; + std::function<bool(const GemmArgs<Tret> &, const OutputStage &)> is_recommended; + std::function<GemmCommon<Top, Tret> *(const GemmArgs<Tret> &, const OutputStage &)> instantiate; + + bool do_is_supported(const GemmArgs<Tret> &args, const OutputStage &os) const { + if (is_supported != nullptr) { + return is_supported(args, os); + } else { + return true; + } + } + + bool do_is_recommended(const GemmArgs<Tret> &args, const OutputStage &os) const { + if (is_recommended != nullptr) { + return is_recommended(args, os); + } else { + return true; + } + } + + GemmCommon<Top, Tret> *do_instantiate(const GemmArgs<Tret> &args, const OutputStage &os) const { + return instantiate(args, os); + } +}; + +/* Slightly different version of above for straightforward GEMMs with no + * output stage, so the std::functions there don't have to deal with the + * unnecessary second argument. */ +template<typename Top, typename Tret> +struct GemmImplementation<Top, Tret, Nothing> { const GemmMethod method; const char * name; std::function<bool(const GemmArgs<Tret> &)> is_supported; std::function<bool(const GemmArgs<Tret> &)> is_recommended; std::function<GemmCommon<Top, Tret> *(const GemmArgs<Tret> &)> instantiate; + + bool do_is_supported(const GemmArgs<Tret> &args, const Nothing &) const { + if (is_supported != nullptr) { + return is_supported(args); + } else { + return true; + } + } + + bool do_is_recommended(const GemmArgs<Tret> &args, const Nothing &) const { + if (is_recommended != nullptr) { + return is_recommended(args); + } else { + return true; + } + } + + GemmCommon<Top, Tret> *do_instantiate(const GemmArgs<Tret> &args, const Nothing &) const { + return instantiate(args); + } }; /* "Master" function implemented for each valid combination of types. * Returns a list of GEMM implementation descriptors for processing by the * other functions, terminated by an implementation with * method==GemmMethod::DEFAULT. */ -template<typename Top, typename Tret> -const GemmImplementation<Top, Tret> *gemm_implementation_list(); +template<typename Top, typename Tret, class OutputStage = Nothing> +const GemmImplementation<Top, Tret, OutputStage> *gemm_implementation_list(); /* * Select a GEMM implementation for the given arguments. @@ -59,16 +115,16 @@ const GemmImplementation<Top, Tret> *gemm_implementation_list(); * this function returns false and doesn't touch the provided pointer * reference. */ -template<typename Top, typename Tret> -bool find_implementation(const GemmArgs<Tret> &args, const GemmImplementation<Top, Tret> * &impl) { - auto gemms = gemm_implementation_list<Top, Tret>(); +template<typename Top, typename Tret, class OutputStage> +bool find_implementation(const GemmArgs<Tret> &args, const OutputStage &os, const GemmImplementation<Top, Tret, OutputStage> * &impl) { + auto gemms = gemm_implementation_list<Top, Tret, OutputStage>(); const GemmConfig *cfg = args._cfg; - const GemmImplementation<Top, Tret> *saved_impl = nullptr; + const GemmImplementation<Top, Tret, OutputStage> *saved_impl = nullptr; - for (auto i = gemms; i->method != GemmMethod::DEFAULT; i++) { + for (const GemmImplementation<Top, Tret, OutputStage> *i = gemms; i->method != GemmMethod::DEFAULT; i++) { /* Skip if this implementation doesn't support these args. */ - if (i->is_supported != nullptr && !i->is_supported(args)) { + if (!i->do_is_supported(args, os)) { continue; } @@ -91,7 +147,7 @@ bool find_implementation(const GemmArgs<Tret> &args, const GemmImplementation<To } /* Check that this method recommends itself. */ - if (i->is_recommended != nullptr && !i->is_recommended(args)) { + if (!i->do_is_recommended(args, os)) { continue; } @@ -111,19 +167,19 @@ bool find_implementation(const GemmArgs<Tret> &args, const GemmImplementation<To return false; } -template<typename Top, typename Tret> -std::vector<KernelDescription> get_compatible_kernels(const GemmArgs<Tret> &args) { +template<typename Top, typename Tret, class OutputStage> +std::vector<KernelDescription> get_compatible_kernels(const GemmArgs<Tret> &args, const OutputStage &os) { std::vector<KernelDescription> res; /* Find out what the default implementation in so we can set the flag accordingly later. */ - const GemmImplementation<Top, Tret> *default_impl; - find_implementation(args, default_impl); + const GemmImplementation<Top, Tret, OutputStage> *default_impl; + find_implementation(args, os, default_impl); - auto gemms = gemm_implementation_list<Top, Tret>(); + auto gemms = gemm_implementation_list<Top, Tret, OutputStage>(); - for (auto i = gemms; i->method != GemmMethod::DEFAULT; i++) { + for (const GemmImplementation<Top, Tret, OutputStage> *i = gemms; i->method != GemmMethod::DEFAULT; i++) { /* Check that this implementation supports the presented problem. */ - if (i->is_supported != nullptr && !i->is_supported(args)) { + if (!i->do_is_supported(args, os)) { continue; } @@ -133,22 +189,22 @@ std::vector<KernelDescription> get_compatible_kernels(const GemmArgs<Tret> &args return res; } -template<typename Top, typename Tret> -UniqueGemmCommon<Top, Tret> gemm(const GemmArgs<Tret> &args) { - const GemmImplementation<Top, Tret> *impl; +template<typename Top, typename Tret, class OutputStage> +UniqueGemmCommon<Top, Tret> gemm(const GemmArgs<Tret> &args, const OutputStage &os) { + const GemmImplementation<Top, Tret, OutputStage> *impl; - if (find_implementation<Top, Tret>(args, impl)) { - return UniqueGemmCommon<Top, Tret>(impl->instantiate(args)); + if (find_implementation<Top, Tret, OutputStage>(args, os, impl)) { + return UniqueGemmCommon<Top, Tret>(impl->do_instantiate(args, os)); } return UniqueGemmCommon<Top, Tret>(nullptr); } -template<typename Top, typename Tret> -KernelDescription get_gemm_method(const GemmArgs<Tret> &args) { - const GemmImplementation<Top, Tret> *impl; +template<typename Top, typename Tret, class OutputStage> +KernelDescription get_gemm_method(const GemmArgs<Tret> &args, const OutputStage &os) { + const GemmImplementation<Top, Tret, OutputStage> *impl; - if (find_implementation<Top, Tret>(args, impl)) { + if (find_implementation<Top, Tret>(args, os, impl)) { return KernelDescription(impl->method, impl->name); } @@ -156,17 +212,4 @@ KernelDescription get_gemm_method(const GemmArgs<Tret> &args) { return KernelDescription(); } -template<typename Top, typename Tret> -bool method_is_compatible(GemmMethod method, const GemmArgs<Tret> &args) { - /* Determine if the method is valid by attempting to obtain an implementation specifying this method. */ - GemmConfig cfg(method); - GemmArgs<Tret> myargs = args; - - myargs._cfg = &cfg; - - const GemmImplementation<Top, Tret> *impl; - - return find_implementation<Top, Tret>(myargs, impl); -} - -} // namespace arm_gemm
\ No newline at end of file +} // namespace arm_gemm |