aboutsummaryrefslogtreecommitdiff
path: root/src/runtime/CL/gemm_auto_heuristics
diff options
context:
space:
mode:
Diffstat (limited to 'src/runtime/CL/gemm_auto_heuristics')
-rw-r--r--src/runtime/CL/gemm_auto_heuristics/CLGEMMAutoHeuristics.cpp36
-rw-r--r--src/runtime/CL/gemm_auto_heuristics/CLGEMMAutoHeuristics.h20
2 files changed, 52 insertions, 4 deletions
diff --git a/src/runtime/CL/gemm_auto_heuristics/CLGEMMAutoHeuristics.cpp b/src/runtime/CL/gemm_auto_heuristics/CLGEMMAutoHeuristics.cpp
index 5790e077d4..c5618f2cce 100644
--- a/src/runtime/CL/gemm_auto_heuristics/CLGEMMAutoHeuristics.cpp
+++ b/src/runtime/CL/gemm_auto_heuristics/CLGEMMAutoHeuristics.cpp
@@ -29,6 +29,7 @@
#include "arm_compute/runtime/CL/ICLGEMMKernelSelection.h"
#include "src/core/CL/ICLGEMMKernelConfiguration.h"
#include "src/core/CL/gemm/CLGEMMHelpers.cpp"
+#include "src/core/CL/gemm/native/CLGEMMNativeKernelConfiguration.h"
#include "src/core/CL/gemm/reshaped/CLGEMMReshapedKernelConfiguration.h"
#include "src/core/CL/gemm/reshaped_only_rhs/CLGEMMReshapedOnlyRHSKernelConfiguration.h"
#include "src/runtime/CL/gemm/CLGEMMKernelSelection.h"
@@ -100,6 +101,7 @@ GEMMConfigResult select_mlgo_gemm_config_reshaped_only_rhs(const CommonQuery &qu
{
ARM_COMPUTE_LOG_INFO_MSG_CORE("MLGOHeuristics query failed");
}
+ // Setting irrelevant unsigned int parameters to 1 and bool parameters to false as they do no matter
std::tie(lhs_info, rhs_info) = configure_lhs_rhs_info(query.m, query.n, config.m0, config.n0, config.k0, 1, config.h0, false, config.interleave_rhs, !config.transpose_rhs, config.transpose_rhs,
config.export_cl_image);
return GEMMConfigResult{ valid, lhs_info, rhs_info };
@@ -139,6 +141,40 @@ GEMMConfigResult select_mlgo_gemm_config_reshaped(const CommonQuery &query)
config.export_cl_image);
return GEMMConfigResult{ valid, lhs_info, rhs_info };
}
+
+GEMMConfigResult select_default_gemm_config_native(const CommonQuery &query)
+{
+ GEMMLHSMatrixInfo lhs_info;
+ GEMMRHSMatrixInfo rhs_info;
+ std::unique_ptr<ICLGEMMKernelConfiguration> gemm_config = CLGEMMNativeKernelConfigurationFactory::create(query.gpu_target);
+ ARM_COMPUTE_ERROR_ON_NULLPTR(gemm_config.get());
+ std::tie(lhs_info, rhs_info) = gemm_config->configure(query.m, query.n, query.k, query.b, query.data_type);
+ return GEMMConfigResult{ true, lhs_info, rhs_info };
+}
+
+GEMMConfigResult select_mlgo_gemm_config_native(const CommonQuery &query)
+{
+ bool valid = false;
+ GEMMLHSMatrixInfo lhs_info;
+ GEMMRHSMatrixInfo rhs_info;
+ mlgo::GEMMConfigNative config{};
+ auto mlgo_heuristics = CLScheduler::get().gemm_heuristics();
+ if(mlgo_heuristics != nullptr)
+ {
+ std::tie(valid, config) = mlgo_heuristics->get()->query_gemm_config_native(mlgo::Query{ string_from_target(query.gpu_target), query.data_type, query.m, query.n, query.k, query.b });
+ }
+ if(valid)
+ {
+ ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("MLGOHeuristics query returns gemm config: %s.", to_string(config).c_str());
+ }
+ else
+ {
+ ARM_COMPUTE_LOG_INFO_MSG_CORE("MLGOHeuristics query failed");
+ }
+ // Setting irrelevant unsigned int parameters to 1 and bool parameters to false as they do no matter
+ std::tie(lhs_info, rhs_info) = configure_lhs_rhs_info(query.m, query.n, config.m0, config.n0, config.k0, 1, 1, false, false, false, false, false);
+ return GEMMConfigResult{ valid, lhs_info, rhs_info };
+}
} // namespace auto_heuristics
} // namespace cl_gemm
diff --git a/src/runtime/CL/gemm_auto_heuristics/CLGEMMAutoHeuristics.h b/src/runtime/CL/gemm_auto_heuristics/CLGEMMAutoHeuristics.h
index 7cb9cab220..486c8bd6cb 100644
--- a/src/runtime/CL/gemm_auto_heuristics/CLGEMMAutoHeuristics.h
+++ b/src/runtime/CL/gemm_auto_heuristics/CLGEMMAutoHeuristics.h
@@ -73,28 +73,40 @@ CLGEMMKernelType auto_select_gemm_kernel(const CommonQuery &query, bool reshape_
/** Select gemm config based on mlgo heuristics
* @param query Query
- * @return GEMMConfigResult
+ * @return GEMMConfigResult. Result is valid if bool(GEMMCOnfigResult) == true and invalid otherwise
*/
GEMMConfigResult select_mlgo_gemm_config_reshaped_only_rhs(const CommonQuery &query);
/** Select gemm config based on default heuristics
* @param query Query
- * @return GEMMConfigResult
+ * @return GEMMConfigResult. Result is valid if bool(GEMMCOnfigResult) == true and invalid otherwise
*/
GEMMConfigResult select_default_gemm_config_reshaped_only_rhs(const CommonQuery &query);
/** Select gemm config based on mlgo heuristics
* @param query Query
- * @return GEMMConfigResult
+ * @return GEMMConfigResult. Result is valid if bool(GEMMCOnfigResult) == true and invalid otherwise
*/
GEMMConfigResult select_mlgo_gemm_config_reshaped(const CommonQuery &query);
/** Select gemm config based on default heuristics
* @param query Query
- * @return GEMMConfigResult
+ * @return GEMMConfigResult. Result is valid if bool(GEMMCOnfigResult) == true and invalid otherwise
*/
GEMMConfigResult select_default_gemm_config_reshaped(const CommonQuery &query);
+/** Select gemm config based on mlgo heuristics
+ * @param query Query
+ * @return GEMMConfigResult. Result is valid if bool(GEMMCOnfigResult) == true and invalid otherwise
+ */
+GEMMConfigResult select_mlgo_gemm_config_native(const CommonQuery &query);
+
+/** Select gemm config based on default heuristics
+ * @param query Query
+ * @return GEMMConfigResult. Result is valid if bool(GEMMCOnfigResult) == true and invalid otherwise
+ */
+GEMMConfigResult select_default_gemm_config_native(const CommonQuery &query);
+
} // namespace auto_heuristics
} // namespace cl_gemm
} // namespace arm_compute