/* * Copyright (c) 2018-2019 ARM Limited. * * SPDX-License-Identifier: MIT * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to * deal in the Software without restriction, including without limitation the * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or * sell copies of the Software, and to permit persons to whom the Software is * furnished to do so, subject to the following conditions: * * The above copyright notice and this permission notice shall be included in all * copies or substantial portions of the Software. * * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE * SOFTWARE. */ #include #include namespace arm_gemm { /* Structure describing an implementation. For each supported combination * of types, a static list of these structures is built up to describe the * implementations available. */ template struct GemmImplementation { const GemmMethod method; const char * name; std::function &, const OutputStage &)> is_supported; std::function &, const OutputStage &)> is_recommended; std::function *(const GemmArgs &, const OutputStage &)> instantiate; bool do_is_supported(const GemmArgs &args, const OutputStage &os) const { if (is_supported != nullptr) { return is_supported(args, os); } else { return true; } } bool do_is_recommended(const GemmArgs &args, const OutputStage &os) const { if (is_recommended != nullptr) { return is_recommended(args, os); } else { return true; } } GemmCommon *do_instantiate(const GemmArgs &args, const OutputStage &os) const { return instantiate(args, os); } }; /* Slightly different version of above for straightforward GEMMs with no * output stage, so the std::functions there don't have to deal with the * unnecessary second argument. */ template struct GemmImplementation { const GemmMethod method; const char * name; std::function &)> is_supported; std::function &)> is_recommended; std::function *(const GemmArgs &)> instantiate; bool do_is_supported(const GemmArgs &args, const Nothing &) const { if (is_supported != nullptr) { return is_supported(args); } else { return true; } } bool do_is_recommended(const GemmArgs &args, const Nothing &) const { if (is_recommended != nullptr) { return is_recommended(args); } else { return true; } } GemmCommon *do_instantiate(const GemmArgs &args, const Nothing &) const { return instantiate(args); } }; /* "Master" function implemented for each valid combination of types. * Returns a list of GEMM implementation descriptors for processing by the * other functions, terminated by an implementation with * method==GemmMethod::DEFAULT. */ template const GemmImplementation *gemm_implementation_list(); /* * Select a GEMM implementation for the given arguments. * * The logic here returns the first method on the list which supports the * requested problem parameters, matches the provided filters (method and/or * name string match) and recommends itself. * * If there is no such method, it will return the first method which * supports the requested parameters and passes the filters, regardless of * recommendation. * * If no method supports the requested parameters and passes the filters, * this function returns false and doesn't touch the provided pointer * reference. */ template bool find_implementation(const GemmArgs &args, const OutputStage &os, const GemmImplementation * &impl) { auto gemms = gemm_implementation_list(); const GemmConfig *cfg = args._cfg; const GemmImplementation *saved_impl = nullptr; for (const GemmImplementation *i = gemms; i->method != GemmMethod::DEFAULT; i++) { /* Skip if this implementation doesn't support these args. */ if (!i->do_is_supported(args, os)) { continue; } /* Skip if a specific method is requested and this is a different one. */ if (cfg && cfg->method != GemmMethod::DEFAULT && i->method != cfg->method) { continue; } /* Skip if a filter is to be applied and it doesn't match. */ if (cfg && cfg->filter != "" && !strstr(i->name, cfg->filter.c_str())) { continue; } /* At this point, if we don't have a saved implementation, save this * one. This is so that we always return something if a filter * matches, even if it doesn't recommend itself. */ if (saved_impl == nullptr) { saved_impl=i; } /* Check that this method recommends itself. */ if (!i->do_is_recommended(args, os)) { continue; } impl=i; return true; } /* We didn't find an option matching the filters that recommended * itself. But if we found something earlier that matched the filters * but wasn't recommended, return it here. */ if (saved_impl != nullptr) { impl = saved_impl; return true; } return false; } template std::vector get_compatible_kernels(const GemmArgs &args, const OutputStage &os) { std::vector res; /* Find out what the default implementation in so we can set the flag accordingly later. */ const GemmImplementation *default_impl; find_implementation(args, os, default_impl); auto gemms = gemm_implementation_list(); for (const GemmImplementation *i = gemms; i->method != GemmMethod::DEFAULT; i++) { /* Check that this implementation supports the presented problem. */ if (!i->do_is_supported(args, os)) { continue; } res.push_back(KernelDescription(i->method, i->name, i==default_impl)); } return res; } template UniqueGemmCommon gemm(const GemmArgs &args, const OutputStage &os) { const GemmImplementation *impl; if (find_implementation(args, os, impl)) { return UniqueGemmCommon(impl->do_instantiate(args, os)); } return UniqueGemmCommon(nullptr); } template KernelDescription get_gemm_method(const GemmArgs &args, const OutputStage &os) { const GemmImplementation *impl; if (find_implementation(args, os, impl)) { return KernelDescription(impl->method, impl->name); } /* This shouldn't happen - there should always be at least one valid implementation. */ return KernelDescription(); } } // namespace arm_gemm