From cfa2bba98169cb5ab1945462514be1b6badf7d98 Mon Sep 17 00:00:00 2001 From: Georgios Pinitas Date: Thu, 27 Jun 2019 17:00:52 +0100 Subject: COMPMID-2178: Update GEMM assembly code. Perform offset reduction and requantization within the assembly wrapper. Change-Id: I5d5b3e1f6f9ef4c71805362c57f88ff199c027a3 Signed-off-by: Georgios Pinitas Reviewed-on: https://review.mlplatform.org/c/1541 Comments-Addressed: Pablo Marquez Reviewed-by: Gian Marco Iodice Comments-Addressed: Arm Jenkins Tested-by: Arm Jenkins --- .../NEON/kernels/arm_gemm/gemm_implementation.hpp | 123 ++++++++++++++------- 1 file changed, 83 insertions(+), 40 deletions(-) (limited to 'src/core/NEON/kernels/arm_gemm/gemm_implementation.hpp') diff --git a/src/core/NEON/kernels/arm_gemm/gemm_implementation.hpp b/src/core/NEON/kernels/arm_gemm/gemm_implementation.hpp index d952140959..55d72f88cb 100644 --- a/src/core/NEON/kernels/arm_gemm/gemm_implementation.hpp +++ b/src/core/NEON/kernels/arm_gemm/gemm_implementation.hpp @@ -28,21 +28,77 @@ namespace arm_gemm { -template +/* Structure describing an implementation. For each supported combination + * of types, a static list of these structures is built up to describe the + * implementations available. + */ +template struct GemmImplementation { + const GemmMethod method; + const char * name; + std::function &, const OutputStage &)> is_supported; + std::function &, const OutputStage &)> is_recommended; + std::function *(const GemmArgs &, const OutputStage &)> instantiate; + + bool do_is_supported(const GemmArgs &args, const OutputStage &os) const { + if (is_supported != nullptr) { + return is_supported(args, os); + } else { + return true; + } + } + + bool do_is_recommended(const GemmArgs &args, const OutputStage &os) const { + if (is_recommended != nullptr) { + return is_recommended(args, os); + } else { + return true; + } + } + + GemmCommon *do_instantiate(const GemmArgs &args, const OutputStage &os) const { + return instantiate(args, os); + } +}; + +/* Slightly different version of above for straightforward GEMMs with no + * output stage, so the std::functions there don't have to deal with the + * unnecessary second argument. */ +template +struct GemmImplementation { const GemmMethod method; const char * name; std::function &)> is_supported; std::function &)> is_recommended; std::function *(const GemmArgs &)> instantiate; + + bool do_is_supported(const GemmArgs &args, const Nothing &) const { + if (is_supported != nullptr) { + return is_supported(args); + } else { + return true; + } + } + + bool do_is_recommended(const GemmArgs &args, const Nothing &) const { + if (is_recommended != nullptr) { + return is_recommended(args); + } else { + return true; + } + } + + GemmCommon *do_instantiate(const GemmArgs &args, const Nothing &) const { + return instantiate(args); + } }; /* "Master" function implemented for each valid combination of types. * Returns a list of GEMM implementation descriptors for processing by the * other functions, terminated by an implementation with * method==GemmMethod::DEFAULT. */ -template -const GemmImplementation *gemm_implementation_list(); +template +const GemmImplementation *gemm_implementation_list(); /* * Select a GEMM implementation for the given arguments. @@ -59,16 +115,16 @@ const GemmImplementation *gemm_implementation_list(); * this function returns false and doesn't touch the provided pointer * reference. */ -template -bool find_implementation(const GemmArgs &args, const GemmImplementation * &impl) { - auto gemms = gemm_implementation_list(); +template +bool find_implementation(const GemmArgs &args, const OutputStage &os, const GemmImplementation * &impl) { + auto gemms = gemm_implementation_list(); const GemmConfig *cfg = args._cfg; - const GemmImplementation *saved_impl = nullptr; + const GemmImplementation *saved_impl = nullptr; - for (auto i = gemms; i->method != GemmMethod::DEFAULT; i++) { + for (const GemmImplementation *i = gemms; i->method != GemmMethod::DEFAULT; i++) { /* Skip if this implementation doesn't support these args. */ - if (i->is_supported != nullptr && !i->is_supported(args)) { + if (!i->do_is_supported(args, os)) { continue; } @@ -91,7 +147,7 @@ bool find_implementation(const GemmArgs &args, const GemmImplementationis_recommended != nullptr && !i->is_recommended(args)) { + if (!i->do_is_recommended(args, os)) { continue; } @@ -111,19 +167,19 @@ bool find_implementation(const GemmArgs &args, const GemmImplementation -std::vector get_compatible_kernels(const GemmArgs &args) { +template +std::vector get_compatible_kernels(const GemmArgs &args, const OutputStage &os) { std::vector res; /* Find out what the default implementation in so we can set the flag accordingly later. */ - const GemmImplementation *default_impl; - find_implementation(args, default_impl); + const GemmImplementation *default_impl; + find_implementation(args, os, default_impl); - auto gemms = gemm_implementation_list(); + auto gemms = gemm_implementation_list(); - for (auto i = gemms; i->method != GemmMethod::DEFAULT; i++) { + for (const GemmImplementation *i = gemms; i->method != GemmMethod::DEFAULT; i++) { /* Check that this implementation supports the presented problem. */ - if (i->is_supported != nullptr && !i->is_supported(args)) { + if (!i->do_is_supported(args, os)) { continue; } @@ -133,22 +189,22 @@ std::vector get_compatible_kernels(const GemmArgs &args return res; } -template -UniqueGemmCommon gemm(const GemmArgs &args) { - const GemmImplementation *impl; +template +UniqueGemmCommon gemm(const GemmArgs &args, const OutputStage &os) { + const GemmImplementation *impl; - if (find_implementation(args, impl)) { - return UniqueGemmCommon(impl->instantiate(args)); + if (find_implementation(args, os, impl)) { + return UniqueGemmCommon(impl->do_instantiate(args, os)); } return UniqueGemmCommon(nullptr); } -template -KernelDescription get_gemm_method(const GemmArgs &args) { - const GemmImplementation *impl; +template +KernelDescription get_gemm_method(const GemmArgs &args, const OutputStage &os) { + const GemmImplementation *impl; - if (find_implementation(args, impl)) { + if (find_implementation(args, os, impl)) { return KernelDescription(impl->method, impl->name); } @@ -156,17 +212,4 @@ KernelDescription get_gemm_method(const GemmArgs &args) { return KernelDescription(); } -template -bool method_is_compatible(GemmMethod method, const GemmArgs &args) { - /* Determine if the method is valid by attempting to obtain an implementation specifying this method. */ - GemmConfig cfg(method); - GemmArgs myargs = args; - - myargs._cfg = &cfg; - - const GemmImplementation *impl; - - return find_implementation(myargs, impl); -} - -} // namespace arm_gemm \ No newline at end of file +} // namespace arm_gemm -- cgit v1.2.1