diff options
Diffstat (limited to 'src/runtime/CL/gemm_auto_heuristics/CLGEMMAutoHeuristics.cpp')
-rw-r--r-- | src/runtime/CL/gemm_auto_heuristics/CLGEMMAutoHeuristics.cpp | 43 |
1 files changed, 26 insertions, 17 deletions
diff --git a/src/runtime/CL/gemm_auto_heuristics/CLGEMMAutoHeuristics.cpp b/src/runtime/CL/gemm_auto_heuristics/CLGEMMAutoHeuristics.cpp index c5618f2cce..ef160d1186 100644 --- a/src/runtime/CL/gemm_auto_heuristics/CLGEMMAutoHeuristics.cpp +++ b/src/runtime/CL/gemm_auto_heuristics/CLGEMMAutoHeuristics.cpp @@ -43,21 +43,31 @@ namespace cl_gemm { namespace auto_heuristics { -CLGEMMKernelType auto_select_gemm_kernel(const CommonQuery &query, bool reshape_b_only_on_first_run) +GEMMTypeResult select_mlgo_gemm_kernel(const CommonQuery &query, bool reshape_b_only_on_first_run) { - // Select between mlgo and default heuristics - auto mlgo_heuristics = CLScheduler::get().gemm_heuristics(); + ARM_COMPUTE_UNUSED(reshape_b_only_on_first_run); + bool valid = false; + CLGEMMKernelType gemm_type{}; + const auto mlgo_heuristics = CLScheduler::get().gemm_heuristics(); if(mlgo_heuristics != nullptr) { - auto res = mlgo_heuristics->get()->query_gemm_type(mlgo::Query{ string_from_target(query.gpu_target), query.data_type, query.m, query.n, query.k, query.b }); - if(res.first) - { - ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("Use gemm kernel from mlgo heuristics: %s.", to_string(res.second).c_str()); - return res.second; - } + std::tie(valid, gemm_type) = mlgo_heuristics->get()->query_gemm_type(mlgo::Query{ string_from_target(query.gpu_target), query.data_type, query.m, query.n, query.k, query.b }); } - std::unique_ptr<ICLGEMMKernelSelection> gemm_kernel = CLGEMMKernelSelectionFactory::create(query.gpu_target); - ARM_COMPUTE_ERROR_ON_NULLPTR(gemm_kernel.get()); + if(valid) + { + ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("MLGOHeuristics query returns gemm type: %s.", to_string(gemm_type).c_str()); + } + else + { + ARM_COMPUTE_LOG_INFO_MSG_CORE("MLGOHeuristics query failed"); + } + return GEMMTypeResult(valid, gemm_type); +} + +GEMMTypeResult select_default_gemm_kernel(const CommonQuery &query, bool reshape_b_only_on_first_run) +{ + std::unique_ptr<ICLGEMMKernelSelection> default_heuristics = CLGEMMKernelSelectionFactory::create(query.gpu_target); + ARM_COMPUTE_ERROR_ON_NULLPTR(default_heuristics.get()); CLGEMMKernelSelectionParams params; params.m = query.m; @@ -67,9 +77,8 @@ CLGEMMKernelType auto_select_gemm_kernel(const CommonQuery &query, bool reshape_ params.is_rhs_constant = reshape_b_only_on_first_run; params.data_type = query.data_type; - const auto kernel_type = gemm_kernel->select_kernel(params); - ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("Use gemm kernel from default heuristics: %s.", to_string(kernel_type).c_str()); - return kernel_type; + const auto kernel_type = default_heuristics->select_kernel(params); + return GEMMTypeResult(true, kernel_type); } GEMMConfigResult select_default_gemm_config_reshaped_only_rhs(const CommonQuery &query) @@ -88,7 +97,7 @@ GEMMConfigResult select_mlgo_gemm_config_reshaped_only_rhs(const CommonQuery &qu GEMMLHSMatrixInfo lhs_info; GEMMRHSMatrixInfo rhs_info; mlgo::GEMMConfigReshapedOnlyRHS config{}; - auto mlgo_heuristics = CLScheduler::get().gemm_heuristics(); + const auto mlgo_heuristics = CLScheduler::get().gemm_heuristics(); if(mlgo_heuristics != nullptr) { std::tie(valid, config) = mlgo_heuristics->get()->query_gemm_config_reshaped_only_rhs(mlgo::Query{ string_from_target(query.gpu_target), query.data_type, query.m, query.n, query.k, query.b }); @@ -123,7 +132,7 @@ GEMMConfigResult select_mlgo_gemm_config_reshaped(const CommonQuery &query) GEMMLHSMatrixInfo lhs_info; GEMMRHSMatrixInfo rhs_info; mlgo::GEMMConfigReshaped config{}; - auto mlgo_heuristics = CLScheduler::get().gemm_heuristics(); + const auto mlgo_heuristics = CLScheduler::get().gemm_heuristics(); if(mlgo_heuristics != nullptr) { std::tie(valid, config) = mlgo_heuristics->get()->query_gemm_config_reshaped(mlgo::Query{ string_from_target(query.gpu_target), query.data_type, query.m, query.n, query.k, query.b }); @@ -158,7 +167,7 @@ GEMMConfigResult select_mlgo_gemm_config_native(const CommonQuery &query) GEMMLHSMatrixInfo lhs_info; GEMMRHSMatrixInfo rhs_info; mlgo::GEMMConfigNative config{}; - auto mlgo_heuristics = CLScheduler::get().gemm_heuristics(); + const auto mlgo_heuristics = CLScheduler::get().gemm_heuristics(); if(mlgo_heuristics != nullptr) { std::tie(valid, config) = mlgo_heuristics->get()->query_gemm_config_native(mlgo::Query{ string_from_target(query.gpu_target), query.data_type, query.m, query.n, query.k, query.b }); |