aboutsummaryrefslogtreecommitdiff
path: root/src/runtime/CL/gemm_auto_heuristics/CLGEMMAutoHeuristics.cpp
diff options
context:
space:
mode:
authorSiCong Li <sicong.li@arm.com>2021-02-10 16:57:33 +0000
committerGian Marco Iodice <gianmarco.iodice@arm.com>2021-02-11 09:57:54 +0000
commit1a28e73829385cabec7549f90b4cf468badf72fc (patch)
tree6b39fd02467212db59d33e04f772e64a7d40b99d /src/runtime/CL/gemm_auto_heuristics/CLGEMMAutoHeuristics.cpp
parent35981ca1cf8fab278831902664b7a103bcb216e3 (diff)
downloadComputeLibrary-1a28e73829385cabec7549f90b4cf468badf72fc.tar.gz
Validate mlgo gemm type selection and fall back to default heuristics
GEMM kernel type returned by mlgo heuristics in each of the CLGEMM and CLGEMMLowpMatrixMultiplyCore could also be invalid. Fix this by falling back to default heuristics, similar to how we deal with gemm configs for now. Resolves COMPMID-3847 Change-Id: Iae7c1dcd7def04969ad13a4c132873fda8c8a571 Signed-off-by: SiCong Li <sicong.li@arm.com> Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/5044 Tested-by: Arm Jenkins <bsgcomp@arm.com> Reviewed-by: Gian Marco Iodice <gianmarco.iodice@arm.com>
Diffstat (limited to 'src/runtime/CL/gemm_auto_heuristics/CLGEMMAutoHeuristics.cpp')
-rw-r--r--src/runtime/CL/gemm_auto_heuristics/CLGEMMAutoHeuristics.cpp43
1 files changed, 26 insertions, 17 deletions
diff --git a/src/runtime/CL/gemm_auto_heuristics/CLGEMMAutoHeuristics.cpp b/src/runtime/CL/gemm_auto_heuristics/CLGEMMAutoHeuristics.cpp
index c5618f2cce..ef160d1186 100644
--- a/src/runtime/CL/gemm_auto_heuristics/CLGEMMAutoHeuristics.cpp
+++ b/src/runtime/CL/gemm_auto_heuristics/CLGEMMAutoHeuristics.cpp
@@ -43,21 +43,31 @@ namespace cl_gemm
{
namespace auto_heuristics
{
-CLGEMMKernelType auto_select_gemm_kernel(const CommonQuery &query, bool reshape_b_only_on_first_run)
+GEMMTypeResult select_mlgo_gemm_kernel(const CommonQuery &query, bool reshape_b_only_on_first_run)
{
- // Select between mlgo and default heuristics
- auto mlgo_heuristics = CLScheduler::get().gemm_heuristics();
+ ARM_COMPUTE_UNUSED(reshape_b_only_on_first_run);
+ bool valid = false;
+ CLGEMMKernelType gemm_type{};
+ const auto mlgo_heuristics = CLScheduler::get().gemm_heuristics();
if(mlgo_heuristics != nullptr)
{
- auto res = mlgo_heuristics->get()->query_gemm_type(mlgo::Query{ string_from_target(query.gpu_target), query.data_type, query.m, query.n, query.k, query.b });
- if(res.first)
- {
- ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("Use gemm kernel from mlgo heuristics: %s.", to_string(res.second).c_str());
- return res.second;
- }
+ std::tie(valid, gemm_type) = mlgo_heuristics->get()->query_gemm_type(mlgo::Query{ string_from_target(query.gpu_target), query.data_type, query.m, query.n, query.k, query.b });
}
- std::unique_ptr<ICLGEMMKernelSelection> gemm_kernel = CLGEMMKernelSelectionFactory::create(query.gpu_target);
- ARM_COMPUTE_ERROR_ON_NULLPTR(gemm_kernel.get());
+ if(valid)
+ {
+ ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("MLGOHeuristics query returns gemm type: %s.", to_string(gemm_type).c_str());
+ }
+ else
+ {
+ ARM_COMPUTE_LOG_INFO_MSG_CORE("MLGOHeuristics query failed");
+ }
+ return GEMMTypeResult(valid, gemm_type);
+}
+
+GEMMTypeResult select_default_gemm_kernel(const CommonQuery &query, bool reshape_b_only_on_first_run)
+{
+ std::unique_ptr<ICLGEMMKernelSelection> default_heuristics = CLGEMMKernelSelectionFactory::create(query.gpu_target);
+ ARM_COMPUTE_ERROR_ON_NULLPTR(default_heuristics.get());
CLGEMMKernelSelectionParams params;
params.m = query.m;
@@ -67,9 +77,8 @@ CLGEMMKernelType auto_select_gemm_kernel(const CommonQuery &query, bool reshape_
params.is_rhs_constant = reshape_b_only_on_first_run;
params.data_type = query.data_type;
- const auto kernel_type = gemm_kernel->select_kernel(params);
- ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("Use gemm kernel from default heuristics: %s.", to_string(kernel_type).c_str());
- return kernel_type;
+ const auto kernel_type = default_heuristics->select_kernel(params);
+ return GEMMTypeResult(true, kernel_type);
}
GEMMConfigResult select_default_gemm_config_reshaped_only_rhs(const CommonQuery &query)
@@ -88,7 +97,7 @@ GEMMConfigResult select_mlgo_gemm_config_reshaped_only_rhs(const CommonQuery &qu
GEMMLHSMatrixInfo lhs_info;
GEMMRHSMatrixInfo rhs_info;
mlgo::GEMMConfigReshapedOnlyRHS config{};
- auto mlgo_heuristics = CLScheduler::get().gemm_heuristics();
+ const auto mlgo_heuristics = CLScheduler::get().gemm_heuristics();
if(mlgo_heuristics != nullptr)
{
std::tie(valid, config) = mlgo_heuristics->get()->query_gemm_config_reshaped_only_rhs(mlgo::Query{ string_from_target(query.gpu_target), query.data_type, query.m, query.n, query.k, query.b });
@@ -123,7 +132,7 @@ GEMMConfigResult select_mlgo_gemm_config_reshaped(const CommonQuery &query)
GEMMLHSMatrixInfo lhs_info;
GEMMRHSMatrixInfo rhs_info;
mlgo::GEMMConfigReshaped config{};
- auto mlgo_heuristics = CLScheduler::get().gemm_heuristics();
+ const auto mlgo_heuristics = CLScheduler::get().gemm_heuristics();
if(mlgo_heuristics != nullptr)
{
std::tie(valid, config) = mlgo_heuristics->get()->query_gemm_config_reshaped(mlgo::Query{ string_from_target(query.gpu_target), query.data_type, query.m, query.n, query.k, query.b });
@@ -158,7 +167,7 @@ GEMMConfigResult select_mlgo_gemm_config_native(const CommonQuery &query)
GEMMLHSMatrixInfo lhs_info;
GEMMRHSMatrixInfo rhs_info;
mlgo::GEMMConfigNative config{};
- auto mlgo_heuristics = CLScheduler::get().gemm_heuristics();
+ const auto mlgo_heuristics = CLScheduler::get().gemm_heuristics();
if(mlgo_heuristics != nullptr)
{
std::tie(valid, config) = mlgo_heuristics->get()->query_gemm_config_native(mlgo::Query{ string_from_target(query.gpu_target), query.data_type, query.m, query.n, query.k, query.b });