aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/runtime/CL/functions/CLGEMM.cpp33
-rw-r--r--src/runtime/CL/functions/CLGEMMLowpMatrixMultiplyCore.cpp31
-rw-r--r--src/runtime/CL/gemm_auto_heuristics/CLGEMMAutoHeuristics.cpp43
-rw-r--r--src/runtime/CL/gemm_auto_heuristics/CLGEMMAutoHeuristics.h41
4 files changed, 122 insertions, 26 deletions
diff --git a/src/runtime/CL/functions/CLGEMM.cpp b/src/runtime/CL/functions/CLGEMM.cpp
index 3e4c604740..cf1a82bc5a 100644
--- a/src/runtime/CL/functions/CLGEMM.cpp
+++ b/src/runtime/CL/functions/CLGEMM.cpp
@@ -99,6 +99,39 @@ void CLGEMMReshapeRHSMatrixKernelManaged::configure(const CLCompileContext &comp
namespace
{
+inline bool validate_gemm_kernel(CLGEMMKernelType kernel_type)
+{
+ switch(kernel_type)
+ {
+ case CLGEMMKernelType::NATIVE_V1:
+ case CLGEMMKernelType::RESHAPED_ONLY_RHS:
+ case CLGEMMKernelType::RESHAPED_V1:
+ case CLGEMMKernelType::RESHAPED:
+ {
+ return true;
+ }
+ default:
+ {
+ return false;
+ }
+ }
+}
+//Automatically select between mlgo (prioritized) and default heuristics for gemm kernel type
+inline CLGEMMKernelType auto_select_gemm_kernel(auto_heuristics::CommonQuery query, bool reshape_b_only_on_first_run)
+{
+ auto gemm_kernel = auto_heuristics::select_mlgo_gemm_kernel(query, reshape_b_only_on_first_run);
+ if(bool(gemm_kernel))
+ {
+ if(validate_gemm_kernel(gemm_kernel.gemm_type))
+ {
+ ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("Use gemm kernel from mlgo heuristics: %s.", to_string(gemm_kernel.gemm_type).c_str());
+ return gemm_kernel.gemm_type;
+ }
+ }
+ gemm_kernel = auto_heuristics::select_default_gemm_kernel(query, reshape_b_only_on_first_run);
+ ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("Use gemm kernel from default heuristics: %s.", to_string(gemm_kernel.gemm_type).c_str());
+ return gemm_kernel.gemm_type;
+}
// Validate lhs_info and rhs_info for reshaped only rhs kernel
inline bool validate_lhs_rhs_info_reshaped_only_rhs(const GEMMLHSMatrixInfo &lhs_info, const GEMMRHSMatrixInfo &rhs_info, const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c,
const ITensorInfo *output, GEMMKernelInfo gemm_kernel_info)
diff --git a/src/runtime/CL/functions/CLGEMMLowpMatrixMultiplyCore.cpp b/src/runtime/CL/functions/CLGEMMLowpMatrixMultiplyCore.cpp
index 2c9bb3cb66..b189955c04 100644
--- a/src/runtime/CL/functions/CLGEMMLowpMatrixMultiplyCore.cpp
+++ b/src/runtime/CL/functions/CLGEMMLowpMatrixMultiplyCore.cpp
@@ -52,6 +52,37 @@ using namespace arm_compute::cl_gemm;
namespace
{
+inline bool validate_gemm_kernel(CLGEMMKernelType kernel_type)
+{
+ switch(kernel_type)
+ {
+ case CLGEMMKernelType::NATIVE:
+ case CLGEMMKernelType::RESHAPED_ONLY_RHS:
+ {
+ return true;
+ }
+ default:
+ {
+ return false;
+ }
+ }
+}
+//Automatically select between mlgo (prioritized) and default heuristics for gemm kernel type
+inline CLGEMMKernelType auto_select_gemm_kernel(auto_heuristics::CommonQuery query, bool reshape_b_only_on_first_run)
+{
+ auto gemm_kernel = auto_heuristics::select_mlgo_gemm_kernel(query, reshape_b_only_on_first_run);
+ if(bool(gemm_kernel))
+ {
+ if(validate_gemm_kernel(gemm_kernel.gemm_type))
+ {
+ ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("Use gemm kernel from mlgo heuristics: %s.", to_string(gemm_kernel.gemm_type).c_str());
+ return gemm_kernel.gemm_type;
+ }
+ }
+ gemm_kernel = auto_heuristics::select_default_gemm_kernel(query, reshape_b_only_on_first_run);
+ ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("Use gemm kernel from default heuristics: %s.", to_string(gemm_kernel.gemm_type).c_str());
+ return gemm_kernel.gemm_type;
+}
// Validate lhs_info and rhs_info for native kernel
inline bool validate_lhs_rhs_info_native(const GEMMLHSMatrixInfo &lhs_info, const GEMMRHSMatrixInfo &rhs_info, const ITensorInfo *a, const ITensorInfo *b, const GEMMReshapeInfo &reshape_info)
{
diff --git a/src/runtime/CL/gemm_auto_heuristics/CLGEMMAutoHeuristics.cpp b/src/runtime/CL/gemm_auto_heuristics/CLGEMMAutoHeuristics.cpp
index c5618f2cce..ef160d1186 100644
--- a/src/runtime/CL/gemm_auto_heuristics/CLGEMMAutoHeuristics.cpp
+++ b/src/runtime/CL/gemm_auto_heuristics/CLGEMMAutoHeuristics.cpp
@@ -43,21 +43,31 @@ namespace cl_gemm
{
namespace auto_heuristics
{
-CLGEMMKernelType auto_select_gemm_kernel(const CommonQuery &query, bool reshape_b_only_on_first_run)
+GEMMTypeResult select_mlgo_gemm_kernel(const CommonQuery &query, bool reshape_b_only_on_first_run)
{
- // Select between mlgo and default heuristics
- auto mlgo_heuristics = CLScheduler::get().gemm_heuristics();
+ ARM_COMPUTE_UNUSED(reshape_b_only_on_first_run);
+ bool valid = false;
+ CLGEMMKernelType gemm_type{};
+ const auto mlgo_heuristics = CLScheduler::get().gemm_heuristics();
if(mlgo_heuristics != nullptr)
{
- auto res = mlgo_heuristics->get()->query_gemm_type(mlgo::Query{ string_from_target(query.gpu_target), query.data_type, query.m, query.n, query.k, query.b });
- if(res.first)
- {
- ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("Use gemm kernel from mlgo heuristics: %s.", to_string(res.second).c_str());
- return res.second;
- }
+ std::tie(valid, gemm_type) = mlgo_heuristics->get()->query_gemm_type(mlgo::Query{ string_from_target(query.gpu_target), query.data_type, query.m, query.n, query.k, query.b });
}
- std::unique_ptr<ICLGEMMKernelSelection> gemm_kernel = CLGEMMKernelSelectionFactory::create(query.gpu_target);
- ARM_COMPUTE_ERROR_ON_NULLPTR(gemm_kernel.get());
+ if(valid)
+ {
+ ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("MLGOHeuristics query returns gemm type: %s.", to_string(gemm_type).c_str());
+ }
+ else
+ {
+ ARM_COMPUTE_LOG_INFO_MSG_CORE("MLGOHeuristics query failed");
+ }
+ return GEMMTypeResult(valid, gemm_type);
+}
+
+GEMMTypeResult select_default_gemm_kernel(const CommonQuery &query, bool reshape_b_only_on_first_run)
+{
+ std::unique_ptr<ICLGEMMKernelSelection> default_heuristics = CLGEMMKernelSelectionFactory::create(query.gpu_target);
+ ARM_COMPUTE_ERROR_ON_NULLPTR(default_heuristics.get());
CLGEMMKernelSelectionParams params;
params.m = query.m;
@@ -67,9 +77,8 @@ CLGEMMKernelType auto_select_gemm_kernel(const CommonQuery &query, bool reshape_
params.is_rhs_constant = reshape_b_only_on_first_run;
params.data_type = query.data_type;
- const auto kernel_type = gemm_kernel->select_kernel(params);
- ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("Use gemm kernel from default heuristics: %s.", to_string(kernel_type).c_str());
- return kernel_type;
+ const auto kernel_type = default_heuristics->select_kernel(params);
+ return GEMMTypeResult(true, kernel_type);
}
GEMMConfigResult select_default_gemm_config_reshaped_only_rhs(const CommonQuery &query)
@@ -88,7 +97,7 @@ GEMMConfigResult select_mlgo_gemm_config_reshaped_only_rhs(const CommonQuery &qu
GEMMLHSMatrixInfo lhs_info;
GEMMRHSMatrixInfo rhs_info;
mlgo::GEMMConfigReshapedOnlyRHS config{};
- auto mlgo_heuristics = CLScheduler::get().gemm_heuristics();
+ const auto mlgo_heuristics = CLScheduler::get().gemm_heuristics();
if(mlgo_heuristics != nullptr)
{
std::tie(valid, config) = mlgo_heuristics->get()->query_gemm_config_reshaped_only_rhs(mlgo::Query{ string_from_target(query.gpu_target), query.data_type, query.m, query.n, query.k, query.b });
@@ -123,7 +132,7 @@ GEMMConfigResult select_mlgo_gemm_config_reshaped(const CommonQuery &query)
GEMMLHSMatrixInfo lhs_info;
GEMMRHSMatrixInfo rhs_info;
mlgo::GEMMConfigReshaped config{};
- auto mlgo_heuristics = CLScheduler::get().gemm_heuristics();
+ const auto mlgo_heuristics = CLScheduler::get().gemm_heuristics();
if(mlgo_heuristics != nullptr)
{
std::tie(valid, config) = mlgo_heuristics->get()->query_gemm_config_reshaped(mlgo::Query{ string_from_target(query.gpu_target), query.data_type, query.m, query.n, query.k, query.b });
@@ -158,7 +167,7 @@ GEMMConfigResult select_mlgo_gemm_config_native(const CommonQuery &query)
GEMMLHSMatrixInfo lhs_info;
GEMMRHSMatrixInfo rhs_info;
mlgo::GEMMConfigNative config{};
- auto mlgo_heuristics = CLScheduler::get().gemm_heuristics();
+ const auto mlgo_heuristics = CLScheduler::get().gemm_heuristics();
if(mlgo_heuristics != nullptr)
{
std::tie(valid, config) = mlgo_heuristics->get()->query_gemm_config_native(mlgo::Query{ string_from_target(query.gpu_target), query.data_type, query.m, query.n, query.k, query.b });
diff --git a/src/runtime/CL/gemm_auto_heuristics/CLGEMMAutoHeuristics.h b/src/runtime/CL/gemm_auto_heuristics/CLGEMMAutoHeuristics.h
index 486c8bd6cb..020237b7f4 100644
--- a/src/runtime/CL/gemm_auto_heuristics/CLGEMMAutoHeuristics.h
+++ b/src/runtime/CL/gemm_auto_heuristics/CLGEMMAutoHeuristics.h
@@ -47,6 +47,22 @@ struct CommonQuery
unsigned int b; /**< Batch size */
};
+/** Result of querying about GEMM type ( @ref CLGEMMKernelType) */
+struct GEMMTypeResult
+{
+ GEMMTypeResult(bool valid, CLGEMMKernelType gemm_type)
+ : valid{ valid }, gemm_type{ gemm_type }
+ {
+ }
+ /** Test if the result is valid */
+ operator bool() const
+ {
+ return valid;
+ }
+ bool valid; /** If the result is valid */
+ CLGEMMKernelType gemm_type; /** @ref CLGEMMKernelType */
+};
+
/** Result of querying about GEMM config ( @ref GEMMLHSMatrixInfo and @ref GEMMRHSMatrixInfo) */
struct GEMMConfigResult
{
@@ -64,46 +80,53 @@ struct GEMMConfigResult
GEMMRHSMatrixInfo rhs_info; /** @ref GEMMRHSMatrixInfo */
};
-/** Automatically select between mlgo and default heuristics to choose @ref CLGEMMKernelType
+/** Select gemm type based on mlgo heuristics
+ * @param query Query
+ * @param reshape_b_only_on_first_run Additional query parameter if reshape b only on first run
+ * @return GEMMTypeResult. Result is valid if bool(GEMMTypeResult) == true and invalid otherwise
+ */
+GEMMTypeResult select_mlgo_gemm_kernel(const CommonQuery &query, bool reshape_b_only_on_first_run);
+
+/** Select gemm type based on default heuristics
* @param query Query
* @param reshape_b_only_on_first_run Additional query parameter if reshape b only on first run
- * @return CLGEMMKernelType
+ * @return GEMMTypeResult. Result is valid if bool(GEMMTypeResult) == true and invalid otherwise
*/
-CLGEMMKernelType auto_select_gemm_kernel(const CommonQuery &query, bool reshape_b_only_on_first_run);
+GEMMTypeResult select_default_gemm_kernel(const CommonQuery &query, bool reshape_b_only_on_first_run);
/** Select gemm config based on mlgo heuristics
* @param query Query
- * @return GEMMConfigResult. Result is valid if bool(GEMMCOnfigResult) == true and invalid otherwise
+ * @return GEMMConfigResult. Result is valid if bool(GEMMConfigResult) == true and invalid otherwise
*/
GEMMConfigResult select_mlgo_gemm_config_reshaped_only_rhs(const CommonQuery &query);
/** Select gemm config based on default heuristics
* @param query Query
- * @return GEMMConfigResult. Result is valid if bool(GEMMCOnfigResult) == true and invalid otherwise
+ * @return GEMMConfigResult. Result is valid if bool(GEMMConfigResult) == true and invalid otherwise
*/
GEMMConfigResult select_default_gemm_config_reshaped_only_rhs(const CommonQuery &query);
/** Select gemm config based on mlgo heuristics
* @param query Query
- * @return GEMMConfigResult. Result is valid if bool(GEMMCOnfigResult) == true and invalid otherwise
+ * @return GEMMConfigResult. Result is valid if bool(GEMMConfigResult) == true and invalid otherwise
*/
GEMMConfigResult select_mlgo_gemm_config_reshaped(const CommonQuery &query);
/** Select gemm config based on default heuristics
* @param query Query
- * @return GEMMConfigResult. Result is valid if bool(GEMMCOnfigResult) == true and invalid otherwise
+ * @return GEMMConfigResult. Result is valid if bool(GEMMConfigResult) == true and invalid otherwise
*/
GEMMConfigResult select_default_gemm_config_reshaped(const CommonQuery &query);
/** Select gemm config based on mlgo heuristics
* @param query Query
- * @return GEMMConfigResult. Result is valid if bool(GEMMCOnfigResult) == true and invalid otherwise
+ * @return GEMMConfigResult. Result is valid if bool(GEMMConfigResult) == true and invalid otherwise
*/
GEMMConfigResult select_mlgo_gemm_config_native(const CommonQuery &query);
/** Select gemm config based on default heuristics
* @param query Query
- * @return GEMMConfigResult. Result is valid if bool(GEMMCOnfigResult) == true and invalid otherwise
+ * @return GEMMConfigResult. Result is valid if bool(GEMMConfigResult) == true and invalid otherwise
*/
GEMMConfigResult select_default_gemm_config_native(const CommonQuery &query);