aboutsummaryrefslogtreecommitdiff
path: root/src/core/NEON/kernels/arm_gemm/gemm_implementation.hpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/core/NEON/kernels/arm_gemm/gemm_implementation.hpp')
-rw-r--r--src/core/NEON/kernels/arm_gemm/gemm_implementation.hpp40
1 files changed, 20 insertions, 20 deletions
diff --git a/src/core/NEON/kernels/arm_gemm/gemm_implementation.hpp b/src/core/NEON/kernels/arm_gemm/gemm_implementation.hpp
index 55d72f88cb..569d1f44ca 100644
--- a/src/core/NEON/kernels/arm_gemm/gemm_implementation.hpp
+++ b/src/core/NEON/kernels/arm_gemm/gemm_implementation.hpp
@@ -34,13 +34,13 @@ namespace arm_gemm {
*/
template<typename Top, typename Tret, class OutputStage = Nothing>
struct GemmImplementation {
- const GemmMethod method;
- const char * name;
- std::function<bool(const GemmArgs<Tret> &, const OutputStage &)> is_supported;
- std::function<bool(const GemmArgs<Tret> &, const OutputStage &)> is_recommended;
- std::function<GemmCommon<Top, Tret> *(const GemmArgs<Tret> &, const OutputStage &)> instantiate;
+ const GemmMethod method;
+ const char * name;
+ std::function<bool(const GemmArgs &, const OutputStage &)> is_supported;
+ std::function<bool(const GemmArgs &, const OutputStage &)> is_recommended;
+ std::function<GemmCommon<Top, Tret> *(const GemmArgs &, const OutputStage &)> instantiate;
- bool do_is_supported(const GemmArgs<Tret> &args, const OutputStage &os) const {
+ bool do_is_supported(const GemmArgs &args, const OutputStage &os) const {
if (is_supported != nullptr) {
return is_supported(args, os);
} else {
@@ -48,7 +48,7 @@ struct GemmImplementation {
}
}
- bool do_is_recommended(const GemmArgs<Tret> &args, const OutputStage &os) const {
+ bool do_is_recommended(const GemmArgs &args, const OutputStage &os) const {
if (is_recommended != nullptr) {
return is_recommended(args, os);
} else {
@@ -56,7 +56,7 @@ struct GemmImplementation {
}
}
- GemmCommon<Top, Tret> *do_instantiate(const GemmArgs<Tret> &args, const OutputStage &os) const {
+ GemmCommon<Top, Tret> *do_instantiate(const GemmArgs &args, const OutputStage &os) const {
return instantiate(args, os);
}
};
@@ -66,13 +66,13 @@ struct GemmImplementation {
* unnecessary second argument. */
template<typename Top, typename Tret>
struct GemmImplementation<Top, Tret, Nothing> {
- const GemmMethod method;
- const char * name;
- std::function<bool(const GemmArgs<Tret> &)> is_supported;
- std::function<bool(const GemmArgs<Tret> &)> is_recommended;
- std::function<GemmCommon<Top, Tret> *(const GemmArgs<Tret> &)> instantiate;
+ const GemmMethod method;
+ const char * name;
+ std::function<bool(const GemmArgs &)> is_supported;
+ std::function<bool(const GemmArgs &)> is_recommended;
+ std::function<GemmCommon<Top, Tret> *(const GemmArgs &)> instantiate;
- bool do_is_supported(const GemmArgs<Tret> &args, const Nothing &) const {
+ bool do_is_supported(const GemmArgs &args, const Nothing &) const {
if (is_supported != nullptr) {
return is_supported(args);
} else {
@@ -80,7 +80,7 @@ struct GemmImplementation<Top, Tret, Nothing> {
}
}
- bool do_is_recommended(const GemmArgs<Tret> &args, const Nothing &) const {
+ bool do_is_recommended(const GemmArgs &args, const Nothing &) const {
if (is_recommended != nullptr) {
return is_recommended(args);
} else {
@@ -88,7 +88,7 @@ struct GemmImplementation<Top, Tret, Nothing> {
}
}
- GemmCommon<Top, Tret> *do_instantiate(const GemmArgs<Tret> &args, const Nothing &) const {
+ GemmCommon<Top, Tret> *do_instantiate(const GemmArgs &args, const Nothing &) const {
return instantiate(args);
}
};
@@ -116,7 +116,7 @@ const GemmImplementation<Top, Tret, OutputStage> *gemm_implementation_list();
* reference.
*/
template<typename Top, typename Tret, class OutputStage>
-bool find_implementation(const GemmArgs<Tret> &args, const OutputStage &os, const GemmImplementation<Top, Tret, OutputStage> * &impl) {
+bool find_implementation(const GemmArgs &args, const OutputStage &os, const GemmImplementation<Top, Tret, OutputStage> * &impl) {
auto gemms = gemm_implementation_list<Top, Tret, OutputStage>();
const GemmConfig *cfg = args._cfg;
@@ -168,7 +168,7 @@ bool find_implementation(const GemmArgs<Tret> &args, const OutputStage &os, cons
}
template<typename Top, typename Tret, class OutputStage>
-std::vector<KernelDescription> get_compatible_kernels(const GemmArgs<Tret> &args, const OutputStage &os) {
+std::vector<KernelDescription> get_compatible_kernels(const GemmArgs &args, const OutputStage &os) {
std::vector<KernelDescription> res;
/* Find out what the default implementation in so we can set the flag accordingly later. */
@@ -190,7 +190,7 @@ std::vector<KernelDescription> get_compatible_kernels(const GemmArgs<Tret> &args
}
template<typename Top, typename Tret, class OutputStage>
-UniqueGemmCommon<Top, Tret> gemm(const GemmArgs<Tret> &args, const OutputStage &os) {
+UniqueGemmCommon<Top, Tret> gemm(const GemmArgs &args, const OutputStage &os) {
const GemmImplementation<Top, Tret, OutputStage> *impl;
if (find_implementation<Top, Tret, OutputStage>(args, os, impl)) {
@@ -201,7 +201,7 @@ UniqueGemmCommon<Top, Tret> gemm(const GemmArgs<Tret> &args, const OutputStage &
}
template<typename Top, typename Tret, class OutputStage>
-KernelDescription get_gemm_method(const GemmArgs<Tret> &args, const OutputStage &os) {
+KernelDescription get_gemm_method(const GemmArgs &args, const OutputStage &os) {
const GemmImplementation<Top, Tret, OutputStage> *impl;
if (find_implementation<Top, Tret>(args, os, impl)) {