aboutsummaryrefslogtreecommitdiff
path: root/src/core/NEON/kernels/arm_gemm/gemm_implementation.hpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/core/NEON/kernels/arm_gemm/gemm_implementation.hpp')
-rw-r--r--src/core/NEON/kernels/arm_gemm/gemm_implementation.hpp86
1 files changed, 43 insertions, 43 deletions
diff --git a/src/core/NEON/kernels/arm_gemm/gemm_implementation.hpp b/src/core/NEON/kernels/arm_gemm/gemm_implementation.hpp
index 5e77df7d4a..19d5e3e23d 100644
--- a/src/core/NEON/kernels/arm_gemm/gemm_implementation.hpp
+++ b/src/core/NEON/kernels/arm_gemm/gemm_implementation.hpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2018-2020, 2022-2023 Arm Limited.
+ * Copyright (c) 2018-2020, 2022-2024 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -35,17 +35,17 @@ namespace arm_gemm {
* of types, a static list of these structures is built up to describe the
* implementations available.
*/
-template<typename Top, typename Tret, class OutputStage = Nothing>
+template<typename Tlop, typename Trop, typename Tret, class OutputStage = Nothing>
struct GemmImplementation {
const GemmMethod method;
const char * name;
const KernelWeightFormat kernel_weight_format = KernelWeightFormat::NON_FIXED;
std::function<bool(const GemmArgs &, const OutputStage &)> is_supported = {};
std::function<uint64_t(const GemmArgs &, const OutputStage &)> cycle_estimate = {};
- std::function<GemmCommon<Top, Tret> *(const GemmArgs &, const OutputStage &)> instantiate = {};
+ std::function<GemmCommon<Tlop, Trop, Tret> *(const GemmArgs &, const OutputStage &)> instantiate = {};
bool do_is_supported(const GemmArgs &args, const OutputStage &os) const {
- // Check supplied is_supported() function first.
+ // Check supplied is_supported() function first.
if (is_supported != nullptr && !is_supported(args, os)) {
return false;
}
@@ -68,7 +68,7 @@ struct GemmImplementation {
// If we get here it means there is a config and it specifies a format. Check it matches this kernel.
// NOTE: this will execute SVE instructions if it's an SVE kernel, so it's important that is_supported()
// was called above first.
- return (args._cfg->weight_format == get_weight_format(kernel_weight_format, sizeof(Top)));
+ return (args._cfg->weight_format == get_weight_format(kernel_weight_format, sizeof(Tlop)));
}
}
@@ -80,13 +80,13 @@ struct GemmImplementation {
}
}
- GemmCommon<Top, Tret> *do_instantiate(const GemmArgs &args, const OutputStage &os) const {
+ GemmCommon<Tlop, Trop, Tret> *do_instantiate(const GemmArgs &args, const OutputStage &os) const {
return instantiate(args, os);
}
static GemmImplementation with_estimate(GemmMethod m, const char *n,
std::function<bool(const GemmArgs &, const OutputStage &)> is_supported, std::function<uint64_t(const GemmArgs &, const OutputStage &)> cycle_estimate,
- std::function<GemmCommon<Top, Tret> *(const GemmArgs &, const OutputStage &)> instantiate) {
+ std::function<GemmCommon<Tlop, Trop, Tret> *(const GemmArgs &, const OutputStage &)> instantiate) {
GemmImplementation impl(m,n);
impl.is_supported=is_supported;
@@ -103,14 +103,14 @@ struct GemmImplementation {
GemmImplementation(GemmMethod m, const char *n,
std::function<bool(const GemmArgs &, const OutputStage &)> is_supported, std::function<bool(const GemmArgs &, const OutputStage &)> is_recommended,
- std::function<GemmCommon<Top, Tret> *(const GemmArgs &, const OutputStage &)> instantiate) :
+ std::function<GemmCommon<Tlop, Trop, Tret> *(const GemmArgs &, const OutputStage &)> instantiate) :
method(m), name(n), is_supported(is_supported),
cycle_estimate( [is_recommended](const GemmArgs &args, const OutputStage &os) { return (is_recommended == nullptr) ? 0 : (is_recommended(args, os) ? 0 : UINT64_MAX); } ),
instantiate(instantiate) { }
GemmImplementation(GemmMethod m, const char *n, KernelWeightFormat kwf,
std::function<bool(const GemmArgs &, const OutputStage &)> is_supported, std::function<bool(const GemmArgs &, const OutputStage &)> is_recommended,
- std::function<GemmCommon<Top, Tret> *(const GemmArgs &, const OutputStage &)> instantiate) :
+ std::function<GemmCommon<Tlop, Trop, Tret> *(const GemmArgs &, const OutputStage &)> instantiate) :
method(m), name(n), kernel_weight_format(kwf), is_supported(is_supported),
cycle_estimate( [is_recommended](const GemmArgs &args, const OutputStage &os) { return (is_recommended == nullptr) ? 0 : (is_recommended(args, os) ? 0 : UINT64_MAX); } ),
instantiate(instantiate) { }
@@ -119,17 +119,17 @@ struct GemmImplementation {
/* Slightly different version of above for straightforward GEMMs with no
* output stage, so the std::functions there don't have to deal with the
* unnecessary second argument. */
-template<typename Top, typename Tret>
-struct GemmImplementation<Top, Tret, Nothing> {
+template<typename Tlop, typename Trop, typename Tret>
+struct GemmImplementation<Tlop, Trop, Tret, Nothing> {
const GemmMethod method;
const char * name;
const KernelWeightFormat kernel_weight_format = KernelWeightFormat::NON_FIXED;
std::function<bool(const GemmArgs &)> is_supported = {};
std::function<uint64_t(const GemmArgs &)> cycle_estimate = {};
- std::function<GemmCommon<Top, Tret> *(const GemmArgs &)> instantiate = {};
+ std::function<GemmCommon<Tlop, Trop, Tret> *(const GemmArgs &)> instantiate = {};
bool do_is_supported(const GemmArgs &args, const Nothing &) const {
- // Check supplied is_supported() function first.
+ // Check supplied is_supported() function first.
if (is_supported != nullptr && !is_supported(args)) {
return false;
}
@@ -152,7 +152,7 @@ struct GemmImplementation<Top, Tret, Nothing> {
// If we get here it means there is a config and it specifies a format. Check it matches this kernel.
// NOTE: this will execute SVE instructions if it's an SVE kernel, so it's important that is_supported()
// was called above first.
- return (args._cfg->weight_format == get_weight_format(kernel_weight_format, sizeof(Top)));
+ return (args._cfg->weight_format == get_weight_format(kernel_weight_format, sizeof(Tlop)));
}
}
@@ -164,13 +164,13 @@ struct GemmImplementation<Top, Tret, Nothing> {
}
}
- GemmCommon<Top, Tret> *do_instantiate(const GemmArgs &args, const Nothing &) const {
+ GemmCommon<Tlop, Trop, Tret> *do_instantiate(const GemmArgs &args, const Nothing &) const {
return instantiate(args);
}
static GemmImplementation with_estimate(GemmMethod m, const char *n,
std::function<bool(const GemmArgs &)> is_supported, std::function<uint64_t(const GemmArgs &)> cycle_estimate,
- std::function<GemmCommon<Top, Tret> *(const GemmArgs &)> instantiate) {
+ std::function<GemmCommon<Tlop, Trop, Tret> *(const GemmArgs &)> instantiate) {
GemmImplementation impl(m,n);
impl.is_supported=is_supported;
@@ -182,7 +182,7 @@ struct GemmImplementation<Top, Tret, Nothing> {
static GemmImplementation with_estimate(GemmMethod m, const char *n, KernelWeightFormat f,
std::function<bool(const GemmArgs &)> is_supported, std::function<uint64_t(const GemmArgs &)> cycle_estimate,
- std::function<GemmCommon<Top, Tret> *(const GemmArgs &)> instantiate) {
+ std::function<GemmCommon<Tlop, Trop, Tret> *(const GemmArgs &)> instantiate) {
GemmImplementation impl(m,n,f);
impl.is_supported=is_supported;
@@ -199,14 +199,14 @@ struct GemmImplementation<Top, Tret, Nothing> {
GemmImplementation(GemmMethod m, const char *n,
std::function<bool(const GemmArgs &)> is_supported, std::function<bool(const GemmArgs &)> is_recommended,
- std::function<GemmCommon<Top, Tret> *(const GemmArgs &)> instantiate) :
+ std::function<GemmCommon<Tlop, Trop, Tret> *(const GemmArgs &)> instantiate) :
method(m), name(n), is_supported(is_supported),
cycle_estimate( [is_recommended](const GemmArgs &args) -> uint64_t { return (is_recommended == nullptr) ? 0 : (is_recommended(args) ? 0 : UINT64_MAX); } ),
instantiate(instantiate) { }
GemmImplementation(GemmMethod m, const char *n, KernelWeightFormat kwf,
std::function<bool(const GemmArgs &)> is_supported, std::function<bool(const GemmArgs &)> is_recommended,
- std::function<GemmCommon<Top, Tret> *(const GemmArgs &)> instantiate) :
+ std::function<GemmCommon<Tlop, Trop, Tret> *(const GemmArgs &)> instantiate) :
method(m), name(n), kernel_weight_format(kwf), is_supported(is_supported),
cycle_estimate( [is_recommended](const GemmArgs &args) -> uint64_t { return (is_recommended == nullptr) ? 0 : (is_recommended(args) ? 0 : UINT64_MAX); } ),
instantiate(instantiate) { }
@@ -218,8 +218,8 @@ struct GemmImplementation<Top, Tret, Nothing> {
* A specialised version is provided for each supported combination of types.
* The end of the list is indicated by a sentinel descriptor with
* method==GemmMethod::DEFAULT. */
-template<typename Top, typename Tret, class OutputStage = Nothing>
-const GemmImplementation<Top, Tret, OutputStage> *gemm_implementation_list();
+template<typename Tlop, typename Trop, typename Tret, class OutputStage = Nothing>
+const GemmImplementation<Tlop, Trop, Tret, OutputStage> *gemm_implementation_list();
/*
* Select a GEMM implementation for the given arguments.
@@ -234,15 +234,15 @@ const GemmImplementation<Top, Tret, OutputStage> *gemm_implementation_list();
* this function returns false and doesn't touch the provided pointer
* reference.
*/
-template<typename Top, typename Tret, class OutputStage>
-bool find_implementation(const GemmArgs &args, const OutputStage &os, const GemmImplementation<Top, Tret, OutputStage> * &impl) {
- auto gemms = gemm_implementation_list<Top, Tret, OutputStage>();
+template<typename Tlop, typename Trop, typename Tret, class OutputStage>
+bool find_implementation(const GemmArgs &args, const OutputStage &os, const GemmImplementation<Tlop, Trop, Tret, OutputStage> * &impl) {
+ auto gemms = gemm_implementation_list<Tlop, Trop, Tret, OutputStage>();
const GemmConfig *cfg = args._cfg;
- const GemmImplementation<Top, Tret, OutputStage> *saved_impl = nullptr;
+ const GemmImplementation<Tlop, Trop, Tret, OutputStage> *saved_impl = nullptr;
uint64_t best_estimate = 0;
- for (const GemmImplementation<Top, Tret, OutputStage> *i = gemms; i->method != GemmMethod::DEFAULT; i++) {
+ for (const GemmImplementation<Tlop, Trop, Tret, OutputStage> *i = gemms; i->method != GemmMethod::DEFAULT; i++) {
/* Skip if this implementation doesn't support these args. */
if (!i->do_is_supported(args, os)) {
continue;
@@ -284,17 +284,17 @@ bool find_implementation(const GemmArgs &args, const OutputStage &os, const Gemm
return false;
}
-template<typename Top, typename Tret, class OutputStage>
+template<typename Tlop, typename Trop, typename Tret, class OutputStage>
std::vector<KernelDescription> get_compatible_kernels(const GemmArgs &args, const OutputStage &os) {
std::vector<KernelDescription> res;
/* Find out what the default implementation in so we can set the flag accordingly later. */
- const GemmImplementation<Top, Tret, OutputStage> *default_impl;
+ const GemmImplementation<Tlop, Trop, Tret, OutputStage> *default_impl;
find_implementation(args, os, default_impl);
- auto gemms = gemm_implementation_list<Top, Tret, OutputStage>();
+ auto gemms = gemm_implementation_list<Tlop, Trop, Tret, OutputStage>();
- for (const GemmImplementation<Top, Tret, OutputStage> *i = gemms; i->method != GemmMethod::DEFAULT; i++) {
+ for (const GemmImplementation<Tlop, Trop, Tret, OutputStage> *i = gemms; i->method != GemmMethod::DEFAULT; i++) {
/* Check that this implementation supports the presented problem. */
if (!i->do_is_supported(args, os)) {
@@ -307,31 +307,31 @@ std::vector<KernelDescription> get_compatible_kernels(const GemmArgs &args, cons
return res;
}
-template<typename Top, typename Tret, class OutputStage>
+template<typename Tlop, typename Trop, typename Tret, class OutputStage>
bool has_opt_gemm(WeightFormat &wf, const GemmArgs &args, const OutputStage &os) {
- const GemmImplementation<Top, Tret, OutputStage> *impl;
- const bool success = find_implementation<Top, Tret, OutputStage>(args, os, impl);
+ const GemmImplementation<Tlop, Trop, Tret, OutputStage> *impl;
+ const bool success = find_implementation<Tlop, Trop, Tret, OutputStage>(args, os, impl);
if (success)
- wf = UniqueGemmCommon<Top, Tret>(impl->do_instantiate(args, os))->get_config().weight_format;
+ wf = UniqueGemmCommon<Tlop, Trop, Tret>(impl->do_instantiate(args, os))->get_config().weight_format;
return success;
}
-template<typename Top, typename Tret, class OutputStage>
-UniqueGemmCommon<Top, Tret> gemm(const GemmArgs &args, const OutputStage &os) {
- const GemmImplementation<Top, Tret, OutputStage> *impl;
+template<typename Tlop, typename Trop, typename Tret, class OutputStage>
+UniqueGemmCommon<Tlop, Trop, Tret> gemm(const GemmArgs &args, const OutputStage &os) {
+ const GemmImplementation<Tlop, Trop, Tret, OutputStage> *impl;
- if (find_implementation<Top, Tret, OutputStage>(args, os, impl)) {
- return UniqueGemmCommon<Top, Tret>(impl->do_instantiate(args, os));
+ if (find_implementation<Tlop, Trop, Tret, OutputStage>(args, os, impl)) {
+ return UniqueGemmCommon<Tlop, Trop, Tret>(impl->do_instantiate(args, os));
}
- return UniqueGemmCommon<Top, Tret>(nullptr);
+ return UniqueGemmCommon<Tlop, Trop, Tret>(nullptr);
}
-template<typename Top, typename Tret, class OutputStage>
+template<typename Tlop, typename Trop, typename Tret, class OutputStage>
KernelDescription get_gemm_method(const GemmArgs &args, const OutputStage &os) {
- const GemmImplementation<Top, Tret, OutputStage> *impl;
+ const GemmImplementation<Tlop, Trop, Tret, OutputStage> *impl;
- if (find_implementation<Top, Tret>(args, os, impl)) {
+ if (find_implementation<Tlop, Trop, Tret>(args, os, impl)) {
return KernelDescription(impl->method, impl->name);
}