aboutsummaryrefslogtreecommitdiff
path: root/src/runtime/CL/gemm_auto_heuristics/CLGEMMAutoHeuristics.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/runtime/CL/gemm_auto_heuristics/CLGEMMAutoHeuristics.cpp')
-rw-r--r--src/runtime/CL/gemm_auto_heuristics/CLGEMMAutoHeuristics.cpp43
1 files changed, 26 insertions, 17 deletions
diff --git a/src/runtime/CL/gemm_auto_heuristics/CLGEMMAutoHeuristics.cpp b/src/runtime/CL/gemm_auto_heuristics/CLGEMMAutoHeuristics.cpp
index c5618f2cce..ef160d1186 100644
--- a/src/runtime/CL/gemm_auto_heuristics/CLGEMMAutoHeuristics.cpp
+++ b/src/runtime/CL/gemm_auto_heuristics/CLGEMMAutoHeuristics.cpp
@@ -43,21 +43,31 @@ namespace cl_gemm
{
namespace auto_heuristics
{
-CLGEMMKernelType auto_select_gemm_kernel(const CommonQuery &query, bool reshape_b_only_on_first_run)
+GEMMTypeResult select_mlgo_gemm_kernel(const CommonQuery &query, bool reshape_b_only_on_first_run)
{
- // Select between mlgo and default heuristics
- auto mlgo_heuristics = CLScheduler::get().gemm_heuristics();
+ ARM_COMPUTE_UNUSED(reshape_b_only_on_first_run);
+ bool valid = false;
+ CLGEMMKernelType gemm_type{};
+ const auto mlgo_heuristics = CLScheduler::get().gemm_heuristics();
if(mlgo_heuristics != nullptr)
{
- auto res = mlgo_heuristics->get()->query_gemm_type(mlgo::Query{ string_from_target(query.gpu_target), query.data_type, query.m, query.n, query.k, query.b });
- if(res.first)
- {
- ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("Use gemm kernel from mlgo heuristics: %s.", to_string(res.second).c_str());
- return res.second;
- }
+ std::tie(valid, gemm_type) = mlgo_heuristics->get()->query_gemm_type(mlgo::Query{ string_from_target(query.gpu_target), query.data_type, query.m, query.n, query.k, query.b });
}
- std::unique_ptr<ICLGEMMKernelSelection> gemm_kernel = CLGEMMKernelSelectionFactory::create(query.gpu_target);
- ARM_COMPUTE_ERROR_ON_NULLPTR(gemm_kernel.get());
+ if(valid)
+ {
+ ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("MLGOHeuristics query returns gemm type: %s.", to_string(gemm_type).c_str());
+ }
+ else
+ {
+ ARM_COMPUTE_LOG_INFO_MSG_CORE("MLGOHeuristics query failed");
+ }
+ return GEMMTypeResult(valid, gemm_type);
+}
+
+GEMMTypeResult select_default_gemm_kernel(const CommonQuery &query, bool reshape_b_only_on_first_run)
+{
+ std::unique_ptr<ICLGEMMKernelSelection> default_heuristics = CLGEMMKernelSelectionFactory::create(query.gpu_target);
+ ARM_COMPUTE_ERROR_ON_NULLPTR(default_heuristics.get());
CLGEMMKernelSelectionParams params;
params.m = query.m;
@@ -67,9 +77,8 @@ CLGEMMKernelType auto_select_gemm_kernel(const CommonQuery &query, bool reshape_
params.is_rhs_constant = reshape_b_only_on_first_run;
params.data_type = query.data_type;
- const auto kernel_type = gemm_kernel->select_kernel(params);
- ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("Use gemm kernel from default heuristics: %s.", to_string(kernel_type).c_str());
- return kernel_type;
+ const auto kernel_type = default_heuristics->select_kernel(params);
+ return GEMMTypeResult(true, kernel_type);
}
GEMMConfigResult select_default_gemm_config_reshaped_only_rhs(const CommonQuery &query)
@@ -88,7 +97,7 @@ GEMMConfigResult select_mlgo_gemm_config_reshaped_only_rhs(const CommonQuery &qu
GEMMLHSMatrixInfo lhs_info;
GEMMRHSMatrixInfo rhs_info;
mlgo::GEMMConfigReshapedOnlyRHS config{};
- auto mlgo_heuristics = CLScheduler::get().gemm_heuristics();
+ const auto mlgo_heuristics = CLScheduler::get().gemm_heuristics();
if(mlgo_heuristics != nullptr)
{
std::tie(valid, config) = mlgo_heuristics->get()->query_gemm_config_reshaped_only_rhs(mlgo::Query{ string_from_target(query.gpu_target), query.data_type, query.m, query.n, query.k, query.b });
@@ -123,7 +132,7 @@ GEMMConfigResult select_mlgo_gemm_config_reshaped(const CommonQuery &query)
GEMMLHSMatrixInfo lhs_info;
GEMMRHSMatrixInfo rhs_info;
mlgo::GEMMConfigReshaped config{};
- auto mlgo_heuristics = CLScheduler::get().gemm_heuristics();
+ const auto mlgo_heuristics = CLScheduler::get().gemm_heuristics();
if(mlgo_heuristics != nullptr)
{
std::tie(valid, config) = mlgo_heuristics->get()->query_gemm_config_reshaped(mlgo::Query{ string_from_target(query.gpu_target), query.data_type, query.m, query.n, query.k, query.b });
@@ -158,7 +167,7 @@ GEMMConfigResult select_mlgo_gemm_config_native(const CommonQuery &query)
GEMMLHSMatrixInfo lhs_info;
GEMMRHSMatrixInfo rhs_info;
mlgo::GEMMConfigNative config{};
- auto mlgo_heuristics = CLScheduler::get().gemm_heuristics();
+ const auto mlgo_heuristics = CLScheduler::get().gemm_heuristics();
if(mlgo_heuristics != nullptr)
{
std::tie(valid, config) = mlgo_heuristics->get()->query_gemm_config_native(mlgo::Query{ string_from_target(query.gpu_target), query.data_type, query.m, query.n, query.k, query.b });