diff options
4 files changed, 122 insertions, 26 deletions
diff --git a/src/runtime/CL/functions/CLGEMM.cpp b/src/runtime/CL/functions/CLGEMM.cpp index 3e4c604740..cf1a82bc5a 100644 --- a/src/runtime/CL/functions/CLGEMM.cpp +++ b/src/runtime/CL/functions/CLGEMM.cpp @@ -99,6 +99,39 @@ void CLGEMMReshapeRHSMatrixKernelManaged::configure(const CLCompileContext &comp namespace { +inline bool validate_gemm_kernel(CLGEMMKernelType kernel_type) +{ + switch(kernel_type) + { + case CLGEMMKernelType::NATIVE_V1: + case CLGEMMKernelType::RESHAPED_ONLY_RHS: + case CLGEMMKernelType::RESHAPED_V1: + case CLGEMMKernelType::RESHAPED: + { + return true; + } + default: + { + return false; + } + } +} +//Automatically select between mlgo (prioritized) and default heuristics for gemm kernel type +inline CLGEMMKernelType auto_select_gemm_kernel(auto_heuristics::CommonQuery query, bool reshape_b_only_on_first_run) +{ + auto gemm_kernel = auto_heuristics::select_mlgo_gemm_kernel(query, reshape_b_only_on_first_run); + if(bool(gemm_kernel)) + { + if(validate_gemm_kernel(gemm_kernel.gemm_type)) + { + ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("Use gemm kernel from mlgo heuristics: %s.", to_string(gemm_kernel.gemm_type).c_str()); + return gemm_kernel.gemm_type; + } + } + gemm_kernel = auto_heuristics::select_default_gemm_kernel(query, reshape_b_only_on_first_run); + ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("Use gemm kernel from default heuristics: %s.", to_string(gemm_kernel.gemm_type).c_str()); + return gemm_kernel.gemm_type; +} // Validate lhs_info and rhs_info for reshaped only rhs kernel inline bool validate_lhs_rhs_info_reshaped_only_rhs(const GEMMLHSMatrixInfo &lhs_info, const GEMMRHSMatrixInfo &rhs_info, const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *output, GEMMKernelInfo gemm_kernel_info) diff --git a/src/runtime/CL/functions/CLGEMMLowpMatrixMultiplyCore.cpp b/src/runtime/CL/functions/CLGEMMLowpMatrixMultiplyCore.cpp index 2c9bb3cb66..b189955c04 100644 --- a/src/runtime/CL/functions/CLGEMMLowpMatrixMultiplyCore.cpp +++ b/src/runtime/CL/functions/CLGEMMLowpMatrixMultiplyCore.cpp @@ -52,6 +52,37 @@ using namespace arm_compute::cl_gemm; namespace { +inline bool validate_gemm_kernel(CLGEMMKernelType kernel_type) +{ + switch(kernel_type) + { + case CLGEMMKernelType::NATIVE: + case CLGEMMKernelType::RESHAPED_ONLY_RHS: + { + return true; + } + default: + { + return false; + } + } +} +//Automatically select between mlgo (prioritized) and default heuristics for gemm kernel type +inline CLGEMMKernelType auto_select_gemm_kernel(auto_heuristics::CommonQuery query, bool reshape_b_only_on_first_run) +{ + auto gemm_kernel = auto_heuristics::select_mlgo_gemm_kernel(query, reshape_b_only_on_first_run); + if(bool(gemm_kernel)) + { + if(validate_gemm_kernel(gemm_kernel.gemm_type)) + { + ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("Use gemm kernel from mlgo heuristics: %s.", to_string(gemm_kernel.gemm_type).c_str()); + return gemm_kernel.gemm_type; + } + } + gemm_kernel = auto_heuristics::select_default_gemm_kernel(query, reshape_b_only_on_first_run); + ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("Use gemm kernel from default heuristics: %s.", to_string(gemm_kernel.gemm_type).c_str()); + return gemm_kernel.gemm_type; +} // Validate lhs_info and rhs_info for native kernel inline bool validate_lhs_rhs_info_native(const GEMMLHSMatrixInfo &lhs_info, const GEMMRHSMatrixInfo &rhs_info, const ITensorInfo *a, const ITensorInfo *b, const GEMMReshapeInfo &reshape_info) { diff --git a/src/runtime/CL/gemm_auto_heuristics/CLGEMMAutoHeuristics.cpp b/src/runtime/CL/gemm_auto_heuristics/CLGEMMAutoHeuristics.cpp index c5618f2cce..ef160d1186 100644 --- a/src/runtime/CL/gemm_auto_heuristics/CLGEMMAutoHeuristics.cpp +++ b/src/runtime/CL/gemm_auto_heuristics/CLGEMMAutoHeuristics.cpp @@ -43,21 +43,31 @@ namespace cl_gemm { namespace auto_heuristics { -CLGEMMKernelType auto_select_gemm_kernel(const CommonQuery &query, bool reshape_b_only_on_first_run) +GEMMTypeResult select_mlgo_gemm_kernel(const CommonQuery &query, bool reshape_b_only_on_first_run) { - // Select between mlgo and default heuristics - auto mlgo_heuristics = CLScheduler::get().gemm_heuristics(); + ARM_COMPUTE_UNUSED(reshape_b_only_on_first_run); + bool valid = false; + CLGEMMKernelType gemm_type{}; + const auto mlgo_heuristics = CLScheduler::get().gemm_heuristics(); if(mlgo_heuristics != nullptr) { - auto res = mlgo_heuristics->get()->query_gemm_type(mlgo::Query{ string_from_target(query.gpu_target), query.data_type, query.m, query.n, query.k, query.b }); - if(res.first) - { - ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("Use gemm kernel from mlgo heuristics: %s.", to_string(res.second).c_str()); - return res.second; - } + std::tie(valid, gemm_type) = mlgo_heuristics->get()->query_gemm_type(mlgo::Query{ string_from_target(query.gpu_target), query.data_type, query.m, query.n, query.k, query.b }); } - std::unique_ptr<ICLGEMMKernelSelection> gemm_kernel = CLGEMMKernelSelectionFactory::create(query.gpu_target); - ARM_COMPUTE_ERROR_ON_NULLPTR(gemm_kernel.get()); + if(valid) + { + ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("MLGOHeuristics query returns gemm type: %s.", to_string(gemm_type).c_str()); + } + else + { + ARM_COMPUTE_LOG_INFO_MSG_CORE("MLGOHeuristics query failed"); + } + return GEMMTypeResult(valid, gemm_type); +} + +GEMMTypeResult select_default_gemm_kernel(const CommonQuery &query, bool reshape_b_only_on_first_run) +{ + std::unique_ptr<ICLGEMMKernelSelection> default_heuristics = CLGEMMKernelSelectionFactory::create(query.gpu_target); + ARM_COMPUTE_ERROR_ON_NULLPTR(default_heuristics.get()); CLGEMMKernelSelectionParams params; params.m = query.m; @@ -67,9 +77,8 @@ CLGEMMKernelType auto_select_gemm_kernel(const CommonQuery &query, bool reshape_ params.is_rhs_constant = reshape_b_only_on_first_run; params.data_type = query.data_type; - const auto kernel_type = gemm_kernel->select_kernel(params); - ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("Use gemm kernel from default heuristics: %s.", to_string(kernel_type).c_str()); - return kernel_type; + const auto kernel_type = default_heuristics->select_kernel(params); + return GEMMTypeResult(true, kernel_type); } GEMMConfigResult select_default_gemm_config_reshaped_only_rhs(const CommonQuery &query) @@ -88,7 +97,7 @@ GEMMConfigResult select_mlgo_gemm_config_reshaped_only_rhs(const CommonQuery &qu GEMMLHSMatrixInfo lhs_info; GEMMRHSMatrixInfo rhs_info; mlgo::GEMMConfigReshapedOnlyRHS config{}; - auto mlgo_heuristics = CLScheduler::get().gemm_heuristics(); + const auto mlgo_heuristics = CLScheduler::get().gemm_heuristics(); if(mlgo_heuristics != nullptr) { std::tie(valid, config) = mlgo_heuristics->get()->query_gemm_config_reshaped_only_rhs(mlgo::Query{ string_from_target(query.gpu_target), query.data_type, query.m, query.n, query.k, query.b }); @@ -123,7 +132,7 @@ GEMMConfigResult select_mlgo_gemm_config_reshaped(const CommonQuery &query) GEMMLHSMatrixInfo lhs_info; GEMMRHSMatrixInfo rhs_info; mlgo::GEMMConfigReshaped config{}; - auto mlgo_heuristics = CLScheduler::get().gemm_heuristics(); + const auto mlgo_heuristics = CLScheduler::get().gemm_heuristics(); if(mlgo_heuristics != nullptr) { std::tie(valid, config) = mlgo_heuristics->get()->query_gemm_config_reshaped(mlgo::Query{ string_from_target(query.gpu_target), query.data_type, query.m, query.n, query.k, query.b }); @@ -158,7 +167,7 @@ GEMMConfigResult select_mlgo_gemm_config_native(const CommonQuery &query) GEMMLHSMatrixInfo lhs_info; GEMMRHSMatrixInfo rhs_info; mlgo::GEMMConfigNative config{}; - auto mlgo_heuristics = CLScheduler::get().gemm_heuristics(); + const auto mlgo_heuristics = CLScheduler::get().gemm_heuristics(); if(mlgo_heuristics != nullptr) { std::tie(valid, config) = mlgo_heuristics->get()->query_gemm_config_native(mlgo::Query{ string_from_target(query.gpu_target), query.data_type, query.m, query.n, query.k, query.b }); diff --git a/src/runtime/CL/gemm_auto_heuristics/CLGEMMAutoHeuristics.h b/src/runtime/CL/gemm_auto_heuristics/CLGEMMAutoHeuristics.h index 486c8bd6cb..020237b7f4 100644 --- a/src/runtime/CL/gemm_auto_heuristics/CLGEMMAutoHeuristics.h +++ b/src/runtime/CL/gemm_auto_heuristics/CLGEMMAutoHeuristics.h @@ -47,6 +47,22 @@ struct CommonQuery unsigned int b; /**< Batch size */ }; +/** Result of querying about GEMM type ( @ref CLGEMMKernelType) */ +struct GEMMTypeResult +{ + GEMMTypeResult(bool valid, CLGEMMKernelType gemm_type) + : valid{ valid }, gemm_type{ gemm_type } + { + } + /** Test if the result is valid */ + operator bool() const + { + return valid; + } + bool valid; /** If the result is valid */ + CLGEMMKernelType gemm_type; /** @ref CLGEMMKernelType */ +}; + /** Result of querying about GEMM config ( @ref GEMMLHSMatrixInfo and @ref GEMMRHSMatrixInfo) */ struct GEMMConfigResult { @@ -64,46 +80,53 @@ struct GEMMConfigResult GEMMRHSMatrixInfo rhs_info; /** @ref GEMMRHSMatrixInfo */ }; -/** Automatically select between mlgo and default heuristics to choose @ref CLGEMMKernelType +/** Select gemm type based on mlgo heuristics + * @param query Query + * @param reshape_b_only_on_first_run Additional query parameter if reshape b only on first run + * @return GEMMTypeResult. Result is valid if bool(GEMMTypeResult) == true and invalid otherwise + */ +GEMMTypeResult select_mlgo_gemm_kernel(const CommonQuery &query, bool reshape_b_only_on_first_run); + +/** Select gemm type based on default heuristics * @param query Query * @param reshape_b_only_on_first_run Additional query parameter if reshape b only on first run - * @return CLGEMMKernelType + * @return GEMMTypeResult. Result is valid if bool(GEMMTypeResult) == true and invalid otherwise */ -CLGEMMKernelType auto_select_gemm_kernel(const CommonQuery &query, bool reshape_b_only_on_first_run); +GEMMTypeResult select_default_gemm_kernel(const CommonQuery &query, bool reshape_b_only_on_first_run); /** Select gemm config based on mlgo heuristics * @param query Query - * @return GEMMConfigResult. Result is valid if bool(GEMMCOnfigResult) == true and invalid otherwise + * @return GEMMConfigResult. Result is valid if bool(GEMMConfigResult) == true and invalid otherwise */ GEMMConfigResult select_mlgo_gemm_config_reshaped_only_rhs(const CommonQuery &query); /** Select gemm config based on default heuristics * @param query Query - * @return GEMMConfigResult. Result is valid if bool(GEMMCOnfigResult) == true and invalid otherwise + * @return GEMMConfigResult. Result is valid if bool(GEMMConfigResult) == true and invalid otherwise */ GEMMConfigResult select_default_gemm_config_reshaped_only_rhs(const CommonQuery &query); /** Select gemm config based on mlgo heuristics * @param query Query - * @return GEMMConfigResult. Result is valid if bool(GEMMCOnfigResult) == true and invalid otherwise + * @return GEMMConfigResult. Result is valid if bool(GEMMConfigResult) == true and invalid otherwise */ GEMMConfigResult select_mlgo_gemm_config_reshaped(const CommonQuery &query); /** Select gemm config based on default heuristics * @param query Query - * @return GEMMConfigResult. Result is valid if bool(GEMMCOnfigResult) == true and invalid otherwise + * @return GEMMConfigResult. Result is valid if bool(GEMMConfigResult) == true and invalid otherwise */ GEMMConfigResult select_default_gemm_config_reshaped(const CommonQuery &query); /** Select gemm config based on mlgo heuristics * @param query Query - * @return GEMMConfigResult. Result is valid if bool(GEMMCOnfigResult) == true and invalid otherwise + * @return GEMMConfigResult. Result is valid if bool(GEMMConfigResult) == true and invalid otherwise */ GEMMConfigResult select_mlgo_gemm_config_native(const CommonQuery &query); /** Select gemm config based on default heuristics * @param query Query - * @return GEMMConfigResult. Result is valid if bool(GEMMCOnfigResult) == true and invalid otherwise + * @return GEMMConfigResult. Result is valid if bool(GEMMConfigResult) == true and invalid otherwise */ GEMMConfigResult select_default_gemm_config_native(const CommonQuery &query); |