diff options
Diffstat (limited to 'src/runtime/CL/gemm_auto_heuristics')
-rw-r--r-- | src/runtime/CL/gemm_auto_heuristics/CLGEMMAutoHeuristics.cpp | 43 | ||||
-rw-r--r-- | src/runtime/CL/gemm_auto_heuristics/CLGEMMAutoHeuristics.h | 41 |
2 files changed, 58 insertions, 26 deletions
diff --git a/src/runtime/CL/gemm_auto_heuristics/CLGEMMAutoHeuristics.cpp b/src/runtime/CL/gemm_auto_heuristics/CLGEMMAutoHeuristics.cpp index c5618f2cce..ef160d1186 100644 --- a/src/runtime/CL/gemm_auto_heuristics/CLGEMMAutoHeuristics.cpp +++ b/src/runtime/CL/gemm_auto_heuristics/CLGEMMAutoHeuristics.cpp @@ -43,21 +43,31 @@ namespace cl_gemm { namespace auto_heuristics { -CLGEMMKernelType auto_select_gemm_kernel(const CommonQuery &query, bool reshape_b_only_on_first_run) +GEMMTypeResult select_mlgo_gemm_kernel(const CommonQuery &query, bool reshape_b_only_on_first_run) { - // Select between mlgo and default heuristics - auto mlgo_heuristics = CLScheduler::get().gemm_heuristics(); + ARM_COMPUTE_UNUSED(reshape_b_only_on_first_run); + bool valid = false; + CLGEMMKernelType gemm_type{}; + const auto mlgo_heuristics = CLScheduler::get().gemm_heuristics(); if(mlgo_heuristics != nullptr) { - auto res = mlgo_heuristics->get()->query_gemm_type(mlgo::Query{ string_from_target(query.gpu_target), query.data_type, query.m, query.n, query.k, query.b }); - if(res.first) - { - ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("Use gemm kernel from mlgo heuristics: %s.", to_string(res.second).c_str()); - return res.second; - } + std::tie(valid, gemm_type) = mlgo_heuristics->get()->query_gemm_type(mlgo::Query{ string_from_target(query.gpu_target), query.data_type, query.m, query.n, query.k, query.b }); } - std::unique_ptr<ICLGEMMKernelSelection> gemm_kernel = CLGEMMKernelSelectionFactory::create(query.gpu_target); - ARM_COMPUTE_ERROR_ON_NULLPTR(gemm_kernel.get()); + if(valid) + { + ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("MLGOHeuristics query returns gemm type: %s.", to_string(gemm_type).c_str()); + } + else + { + ARM_COMPUTE_LOG_INFO_MSG_CORE("MLGOHeuristics query failed"); + } + return GEMMTypeResult(valid, gemm_type); +} + +GEMMTypeResult select_default_gemm_kernel(const CommonQuery &query, bool reshape_b_only_on_first_run) +{ + std::unique_ptr<ICLGEMMKernelSelection> default_heuristics = CLGEMMKernelSelectionFactory::create(query.gpu_target); + ARM_COMPUTE_ERROR_ON_NULLPTR(default_heuristics.get()); CLGEMMKernelSelectionParams params; params.m = query.m; @@ -67,9 +77,8 @@ CLGEMMKernelType auto_select_gemm_kernel(const CommonQuery &query, bool reshape_ params.is_rhs_constant = reshape_b_only_on_first_run; params.data_type = query.data_type; - const auto kernel_type = gemm_kernel->select_kernel(params); - ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("Use gemm kernel from default heuristics: %s.", to_string(kernel_type).c_str()); - return kernel_type; + const auto kernel_type = default_heuristics->select_kernel(params); + return GEMMTypeResult(true, kernel_type); } GEMMConfigResult select_default_gemm_config_reshaped_only_rhs(const CommonQuery &query) @@ -88,7 +97,7 @@ GEMMConfigResult select_mlgo_gemm_config_reshaped_only_rhs(const CommonQuery &qu GEMMLHSMatrixInfo lhs_info; GEMMRHSMatrixInfo rhs_info; mlgo::GEMMConfigReshapedOnlyRHS config{}; - auto mlgo_heuristics = CLScheduler::get().gemm_heuristics(); + const auto mlgo_heuristics = CLScheduler::get().gemm_heuristics(); if(mlgo_heuristics != nullptr) { std::tie(valid, config) = mlgo_heuristics->get()->query_gemm_config_reshaped_only_rhs(mlgo::Query{ string_from_target(query.gpu_target), query.data_type, query.m, query.n, query.k, query.b }); @@ -123,7 +132,7 @@ GEMMConfigResult select_mlgo_gemm_config_reshaped(const CommonQuery &query) GEMMLHSMatrixInfo lhs_info; GEMMRHSMatrixInfo rhs_info; mlgo::GEMMConfigReshaped config{}; - auto mlgo_heuristics = CLScheduler::get().gemm_heuristics(); + const auto mlgo_heuristics = CLScheduler::get().gemm_heuristics(); if(mlgo_heuristics != nullptr) { std::tie(valid, config) = mlgo_heuristics->get()->query_gemm_config_reshaped(mlgo::Query{ string_from_target(query.gpu_target), query.data_type, query.m, query.n, query.k, query.b }); @@ -158,7 +167,7 @@ GEMMConfigResult select_mlgo_gemm_config_native(const CommonQuery &query) GEMMLHSMatrixInfo lhs_info; GEMMRHSMatrixInfo rhs_info; mlgo::GEMMConfigNative config{}; - auto mlgo_heuristics = CLScheduler::get().gemm_heuristics(); + const auto mlgo_heuristics = CLScheduler::get().gemm_heuristics(); if(mlgo_heuristics != nullptr) { std::tie(valid, config) = mlgo_heuristics->get()->query_gemm_config_native(mlgo::Query{ string_from_target(query.gpu_target), query.data_type, query.m, query.n, query.k, query.b }); diff --git a/src/runtime/CL/gemm_auto_heuristics/CLGEMMAutoHeuristics.h b/src/runtime/CL/gemm_auto_heuristics/CLGEMMAutoHeuristics.h index 486c8bd6cb..020237b7f4 100644 --- a/src/runtime/CL/gemm_auto_heuristics/CLGEMMAutoHeuristics.h +++ b/src/runtime/CL/gemm_auto_heuristics/CLGEMMAutoHeuristics.h @@ -47,6 +47,22 @@ struct CommonQuery unsigned int b; /**< Batch size */ }; +/** Result of querying about GEMM type ( @ref CLGEMMKernelType) */ +struct GEMMTypeResult +{ + GEMMTypeResult(bool valid, CLGEMMKernelType gemm_type) + : valid{ valid }, gemm_type{ gemm_type } + { + } + /** Test if the result is valid */ + operator bool() const + { + return valid; + } + bool valid; /** If the result is valid */ + CLGEMMKernelType gemm_type; /** @ref CLGEMMKernelType */ +}; + /** Result of querying about GEMM config ( @ref GEMMLHSMatrixInfo and @ref GEMMRHSMatrixInfo) */ struct GEMMConfigResult { @@ -64,46 +80,53 @@ struct GEMMConfigResult GEMMRHSMatrixInfo rhs_info; /** @ref GEMMRHSMatrixInfo */ }; -/** Automatically select between mlgo and default heuristics to choose @ref CLGEMMKernelType +/** Select gemm type based on mlgo heuristics + * @param query Query + * @param reshape_b_only_on_first_run Additional query parameter if reshape b only on first run + * @return GEMMTypeResult. Result is valid if bool(GEMMTypeResult) == true and invalid otherwise + */ +GEMMTypeResult select_mlgo_gemm_kernel(const CommonQuery &query, bool reshape_b_only_on_first_run); + +/** Select gemm type based on default heuristics * @param query Query * @param reshape_b_only_on_first_run Additional query parameter if reshape b only on first run - * @return CLGEMMKernelType + * @return GEMMTypeResult. Result is valid if bool(GEMMTypeResult) == true and invalid otherwise */ -CLGEMMKernelType auto_select_gemm_kernel(const CommonQuery &query, bool reshape_b_only_on_first_run); +GEMMTypeResult select_default_gemm_kernel(const CommonQuery &query, bool reshape_b_only_on_first_run); /** Select gemm config based on mlgo heuristics * @param query Query - * @return GEMMConfigResult. Result is valid if bool(GEMMCOnfigResult) == true and invalid otherwise + * @return GEMMConfigResult. Result is valid if bool(GEMMConfigResult) == true and invalid otherwise */ GEMMConfigResult select_mlgo_gemm_config_reshaped_only_rhs(const CommonQuery &query); /** Select gemm config based on default heuristics * @param query Query - * @return GEMMConfigResult. Result is valid if bool(GEMMCOnfigResult) == true and invalid otherwise + * @return GEMMConfigResult. Result is valid if bool(GEMMConfigResult) == true and invalid otherwise */ GEMMConfigResult select_default_gemm_config_reshaped_only_rhs(const CommonQuery &query); /** Select gemm config based on mlgo heuristics * @param query Query - * @return GEMMConfigResult. Result is valid if bool(GEMMCOnfigResult) == true and invalid otherwise + * @return GEMMConfigResult. Result is valid if bool(GEMMConfigResult) == true and invalid otherwise */ GEMMConfigResult select_mlgo_gemm_config_reshaped(const CommonQuery &query); /** Select gemm config based on default heuristics * @param query Query - * @return GEMMConfigResult. Result is valid if bool(GEMMCOnfigResult) == true and invalid otherwise + * @return GEMMConfigResult. Result is valid if bool(GEMMConfigResult) == true and invalid otherwise */ GEMMConfigResult select_default_gemm_config_reshaped(const CommonQuery &query); /** Select gemm config based on mlgo heuristics * @param query Query - * @return GEMMConfigResult. Result is valid if bool(GEMMCOnfigResult) == true and invalid otherwise + * @return GEMMConfigResult. Result is valid if bool(GEMMConfigResult) == true and invalid otherwise */ GEMMConfigResult select_mlgo_gemm_config_native(const CommonQuery &query); /** Select gemm config based on default heuristics * @param query Query - * @return GEMMConfigResult. Result is valid if bool(GEMMCOnfigResult) == true and invalid otherwise + * @return GEMMConfigResult. Result is valid if bool(GEMMConfigResult) == true and invalid otherwise */ GEMMConfigResult select_default_gemm_config_native(const CommonQuery &query); |