diff options
Diffstat (limited to 'src/core/NEON/kernels/arm_gemm/gemm_implementation.hpp')
-rw-r--r-- | src/core/NEON/kernels/arm_gemm/gemm_implementation.hpp | 84 |
1 files changed, 42 insertions, 42 deletions
diff --git a/src/core/NEON/kernels/arm_gemm/gemm_implementation.hpp b/src/core/NEON/kernels/arm_gemm/gemm_implementation.hpp index 19d5e3e23d..db5155f500 100644 --- a/src/core/NEON/kernels/arm_gemm/gemm_implementation.hpp +++ b/src/core/NEON/kernels/arm_gemm/gemm_implementation.hpp @@ -35,17 +35,17 @@ namespace arm_gemm { * of types, a static list of these structures is built up to describe the * implementations available. */ -template<typename Tlop, typename Trop, typename Tret, class OutputStage = Nothing> +template<typename Top, typename Tret, class OutputStage = Nothing> struct GemmImplementation { const GemmMethod method; const char * name; const KernelWeightFormat kernel_weight_format = KernelWeightFormat::NON_FIXED; std::function<bool(const GemmArgs &, const OutputStage &)> is_supported = {}; std::function<uint64_t(const GemmArgs &, const OutputStage &)> cycle_estimate = {}; - std::function<GemmCommon<Tlop, Trop, Tret> *(const GemmArgs &, const OutputStage &)> instantiate = {}; + std::function<GemmCommon<Top, Tret> *(const GemmArgs &, const OutputStage &)> instantiate = {}; bool do_is_supported(const GemmArgs &args, const OutputStage &os) const { - // Check supplied is_supported() function first. + // Check supplied is_supported() function first. if (is_supported != nullptr && !is_supported(args, os)) { return false; } @@ -68,7 +68,7 @@ struct GemmImplementation { // If we get here it means there is a config and it specifies a format. Check it matches this kernel. // NOTE: this will execute SVE instructions if it's an SVE kernel, so it's important that is_supported() // was called above first. - return (args._cfg->weight_format == get_weight_format(kernel_weight_format, sizeof(Tlop))); + return (args._cfg->weight_format == get_weight_format(kernel_weight_format, sizeof(Top))); } } @@ -80,13 +80,13 @@ struct GemmImplementation { } } - GemmCommon<Tlop, Trop, Tret> *do_instantiate(const GemmArgs &args, const OutputStage &os) const { + GemmCommon<Top, Tret> *do_instantiate(const GemmArgs &args, const OutputStage &os) const { return instantiate(args, os); } static GemmImplementation with_estimate(GemmMethod m, const char *n, std::function<bool(const GemmArgs &, const OutputStage &)> is_supported, std::function<uint64_t(const GemmArgs &, const OutputStage &)> cycle_estimate, - std::function<GemmCommon<Tlop, Trop, Tret> *(const GemmArgs &, const OutputStage &)> instantiate) { + std::function<GemmCommon<Top, Tret> *(const GemmArgs &, const OutputStage &)> instantiate) { GemmImplementation impl(m,n); impl.is_supported=is_supported; @@ -103,14 +103,14 @@ struct GemmImplementation { GemmImplementation(GemmMethod m, const char *n, std::function<bool(const GemmArgs &, const OutputStage &)> is_supported, std::function<bool(const GemmArgs &, const OutputStage &)> is_recommended, - std::function<GemmCommon<Tlop, Trop, Tret> *(const GemmArgs &, const OutputStage &)> instantiate) : + std::function<GemmCommon<Top, Tret> *(const GemmArgs &, const OutputStage &)> instantiate) : method(m), name(n), is_supported(is_supported), cycle_estimate( [is_recommended](const GemmArgs &args, const OutputStage &os) { return (is_recommended == nullptr) ? 0 : (is_recommended(args, os) ? 0 : UINT64_MAX); } ), instantiate(instantiate) { } GemmImplementation(GemmMethod m, const char *n, KernelWeightFormat kwf, std::function<bool(const GemmArgs &, const OutputStage &)> is_supported, std::function<bool(const GemmArgs &, const OutputStage &)> is_recommended, - std::function<GemmCommon<Tlop, Trop, Tret> *(const GemmArgs &, const OutputStage &)> instantiate) : + std::function<GemmCommon<Top, Tret> *(const GemmArgs &, const OutputStage &)> instantiate) : method(m), name(n), kernel_weight_format(kwf), is_supported(is_supported), cycle_estimate( [is_recommended](const GemmArgs &args, const OutputStage &os) { return (is_recommended == nullptr) ? 0 : (is_recommended(args, os) ? 0 : UINT64_MAX); } ), instantiate(instantiate) { } @@ -119,17 +119,17 @@ struct GemmImplementation { /* Slightly different version of above for straightforward GEMMs with no * output stage, so the std::functions there don't have to deal with the * unnecessary second argument. */ -template<typename Tlop, typename Trop, typename Tret> -struct GemmImplementation<Tlop, Trop, Tret, Nothing> { +template<typename Top, typename Tret> +struct GemmImplementation<Top, Tret, Nothing> { const GemmMethod method; const char * name; const KernelWeightFormat kernel_weight_format = KernelWeightFormat::NON_FIXED; std::function<bool(const GemmArgs &)> is_supported = {}; std::function<uint64_t(const GemmArgs &)> cycle_estimate = {}; - std::function<GemmCommon<Tlop, Trop, Tret> *(const GemmArgs &)> instantiate = {}; + std::function<GemmCommon<Top, Tret> *(const GemmArgs &)> instantiate = {}; bool do_is_supported(const GemmArgs &args, const Nothing &) const { - // Check supplied is_supported() function first. + // Check supplied is_supported() function first. if (is_supported != nullptr && !is_supported(args)) { return false; } @@ -152,7 +152,7 @@ struct GemmImplementation<Tlop, Trop, Tret, Nothing> { // If we get here it means there is a config and it specifies a format. Check it matches this kernel. // NOTE: this will execute SVE instructions if it's an SVE kernel, so it's important that is_supported() // was called above first. - return (args._cfg->weight_format == get_weight_format(kernel_weight_format, sizeof(Tlop))); + return (args._cfg->weight_format == get_weight_format(kernel_weight_format, sizeof(Top))); } } @@ -164,13 +164,13 @@ struct GemmImplementation<Tlop, Trop, Tret, Nothing> { } } - GemmCommon<Tlop, Trop, Tret> *do_instantiate(const GemmArgs &args, const Nothing &) const { + GemmCommon<Top, Tret> *do_instantiate(const GemmArgs &args, const Nothing &) const { return instantiate(args); } static GemmImplementation with_estimate(GemmMethod m, const char *n, std::function<bool(const GemmArgs &)> is_supported, std::function<uint64_t(const GemmArgs &)> cycle_estimate, - std::function<GemmCommon<Tlop, Trop, Tret> *(const GemmArgs &)> instantiate) { + std::function<GemmCommon<Top, Tret> *(const GemmArgs &)> instantiate) { GemmImplementation impl(m,n); impl.is_supported=is_supported; @@ -182,7 +182,7 @@ struct GemmImplementation<Tlop, Trop, Tret, Nothing> { static GemmImplementation with_estimate(GemmMethod m, const char *n, KernelWeightFormat f, std::function<bool(const GemmArgs &)> is_supported, std::function<uint64_t(const GemmArgs &)> cycle_estimate, - std::function<GemmCommon<Tlop, Trop, Tret> *(const GemmArgs &)> instantiate) { + std::function<GemmCommon<Top, Tret> *(const GemmArgs &)> instantiate) { GemmImplementation impl(m,n,f); impl.is_supported=is_supported; @@ -199,14 +199,14 @@ struct GemmImplementation<Tlop, Trop, Tret, Nothing> { GemmImplementation(GemmMethod m, const char *n, std::function<bool(const GemmArgs &)> is_supported, std::function<bool(const GemmArgs &)> is_recommended, - std::function<GemmCommon<Tlop, Trop, Tret> *(const GemmArgs &)> instantiate) : + std::function<GemmCommon<Top, Tret> *(const GemmArgs &)> instantiate) : method(m), name(n), is_supported(is_supported), cycle_estimate( [is_recommended](const GemmArgs &args) -> uint64_t { return (is_recommended == nullptr) ? 0 : (is_recommended(args) ? 0 : UINT64_MAX); } ), instantiate(instantiate) { } GemmImplementation(GemmMethod m, const char *n, KernelWeightFormat kwf, std::function<bool(const GemmArgs &)> is_supported, std::function<bool(const GemmArgs &)> is_recommended, - std::function<GemmCommon<Tlop, Trop, Tret> *(const GemmArgs &)> instantiate) : + std::function<GemmCommon<Top, Tret> *(const GemmArgs &)> instantiate) : method(m), name(n), kernel_weight_format(kwf), is_supported(is_supported), cycle_estimate( [is_recommended](const GemmArgs &args) -> uint64_t { return (is_recommended == nullptr) ? 0 : (is_recommended(args) ? 0 : UINT64_MAX); } ), instantiate(instantiate) { } @@ -218,8 +218,8 @@ struct GemmImplementation<Tlop, Trop, Tret, Nothing> { * A specialised version is provided for each supported combination of types. * The end of the list is indicated by a sentinel descriptor with * method==GemmMethod::DEFAULT. */ -template<typename Tlop, typename Trop, typename Tret, class OutputStage = Nothing> -const GemmImplementation<Tlop, Trop, Tret, OutputStage> *gemm_implementation_list(); +template<typename Top, typename Tret, class OutputStage = Nothing> +const GemmImplementation<Top, Tret, OutputStage> *gemm_implementation_list(); /* * Select a GEMM implementation for the given arguments. @@ -234,15 +234,15 @@ const GemmImplementation<Tlop, Trop, Tret, OutputStage> *gemm_implementation_lis * this function returns false and doesn't touch the provided pointer * reference. */ -template<typename Tlop, typename Trop, typename Tret, class OutputStage> -bool find_implementation(const GemmArgs &args, const OutputStage &os, const GemmImplementation<Tlop, Trop, Tret, OutputStage> * &impl) { - auto gemms = gemm_implementation_list<Tlop, Trop, Tret, OutputStage>(); +template<typename Top, typename Tret, class OutputStage> +bool find_implementation(const GemmArgs &args, const OutputStage &os, const GemmImplementation<Top, Tret, OutputStage> * &impl) { + auto gemms = gemm_implementation_list<Top, Tret, OutputStage>(); const GemmConfig *cfg = args._cfg; - const GemmImplementation<Tlop, Trop, Tret, OutputStage> *saved_impl = nullptr; + const GemmImplementation<Top, Tret, OutputStage> *saved_impl = nullptr; uint64_t best_estimate = 0; - for (const GemmImplementation<Tlop, Trop, Tret, OutputStage> *i = gemms; i->method != GemmMethod::DEFAULT; i++) { + for (const GemmImplementation<Top, Tret, OutputStage> *i = gemms; i->method != GemmMethod::DEFAULT; i++) { /* Skip if this implementation doesn't support these args. */ if (!i->do_is_supported(args, os)) { continue; @@ -284,17 +284,17 @@ bool find_implementation(const GemmArgs &args, const OutputStage &os, const Gemm return false; } -template<typename Tlop, typename Trop, typename Tret, class OutputStage> +template<typename Top, typename Tret, class OutputStage> std::vector<KernelDescription> get_compatible_kernels(const GemmArgs &args, const OutputStage &os) { std::vector<KernelDescription> res; /* Find out what the default implementation in so we can set the flag accordingly later. */ - const GemmImplementation<Tlop, Trop, Tret, OutputStage> *default_impl; + const GemmImplementation<Top, Tret, OutputStage> *default_impl; find_implementation(args, os, default_impl); - auto gemms = gemm_implementation_list<Tlop, Trop, Tret, OutputStage>(); + auto gemms = gemm_implementation_list<Top, Tret, OutputStage>(); - for (const GemmImplementation<Tlop, Trop, Tret, OutputStage> *i = gemms; i->method != GemmMethod::DEFAULT; i++) { + for (const GemmImplementation<Top, Tret, OutputStage> *i = gemms; i->method != GemmMethod::DEFAULT; i++) { /* Check that this implementation supports the presented problem. */ if (!i->do_is_supported(args, os)) { @@ -307,31 +307,31 @@ std::vector<KernelDescription> get_compatible_kernels(const GemmArgs &args, cons return res; } -template<typename Tlop, typename Trop, typename Tret, class OutputStage> +template<typename Top, typename Tret, class OutputStage> bool has_opt_gemm(WeightFormat &wf, const GemmArgs &args, const OutputStage &os) { - const GemmImplementation<Tlop, Trop, Tret, OutputStage> *impl; - const bool success = find_implementation<Tlop, Trop, Tret, OutputStage>(args, os, impl); + const GemmImplementation<Top, Tret, OutputStage> *impl; + const bool success = find_implementation<Top, Tret, OutputStage>(args, os, impl); if (success) - wf = UniqueGemmCommon<Tlop, Trop, Tret>(impl->do_instantiate(args, os))->get_config().weight_format; + wf = UniqueGemmCommon<Top, Tret>(impl->do_instantiate(args, os))->get_config().weight_format; return success; } -template<typename Tlop, typename Trop, typename Tret, class OutputStage> -UniqueGemmCommon<Tlop, Trop, Tret> gemm(const GemmArgs &args, const OutputStage &os) { - const GemmImplementation<Tlop, Trop, Tret, OutputStage> *impl; +template<typename Top, typename Tret, class OutputStage> +UniqueGemmCommon<Top, Tret> gemm(const GemmArgs &args, const OutputStage &os) { + const GemmImplementation<Top, Tret, OutputStage> *impl; - if (find_implementation<Tlop, Trop, Tret, OutputStage>(args, os, impl)) { - return UniqueGemmCommon<Tlop, Trop, Tret>(impl->do_instantiate(args, os)); + if (find_implementation<Top, Tret, OutputStage>(args, os, impl)) { + return UniqueGemmCommon<Top, Tret>(impl->do_instantiate(args, os)); } - return UniqueGemmCommon<Tlop, Trop, Tret>(nullptr); + return UniqueGemmCommon<Top, Tret>(nullptr); } -template<typename Tlop, typename Trop, typename Tret, class OutputStage> +template<typename Top, typename Tret, class OutputStage> KernelDescription get_gemm_method(const GemmArgs &args, const OutputStage &os) { - const GemmImplementation<Tlop, Trop, Tret, OutputStage> *impl; + const GemmImplementation<Top, Tret, OutputStage> *impl; - if (find_implementation<Tlop, Trop, Tret>(args, os, impl)) { + if (find_implementation<Top, Tret>(args, os, impl)) { return KernelDescription(impl->method, impl->name); } |