aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSiCong Li <sicong.li@arm.com>2021-02-10 16:57:33 +0000
committerGian Marco Iodice <gianmarco.iodice@arm.com>2021-02-11 09:57:54 +0000
commit1a28e73829385cabec7549f90b4cf468badf72fc (patch)
tree6b39fd02467212db59d33e04f772e64a7d40b99d
parent35981ca1cf8fab278831902664b7a103bcb216e3 (diff)
downloadComputeLibrary-1a28e73829385cabec7549f90b4cf468badf72fc.tar.gz
Validate mlgo gemm type selection and fall back to default heuristics
GEMM kernel type returned by mlgo heuristics in each of the CLGEMM and CLGEMMLowpMatrixMultiplyCore could also be invalid. Fix this by falling back to default heuristics, similar to how we deal with gemm configs for now. Resolves COMPMID-3847 Change-Id: Iae7c1dcd7def04969ad13a4c132873fda8c8a571 Signed-off-by: SiCong Li <sicong.li@arm.com> Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/5044 Tested-by: Arm Jenkins <bsgcomp@arm.com> Reviewed-by: Gian Marco Iodice <gianmarco.iodice@arm.com>
-rw-r--r--src/runtime/CL/functions/CLGEMM.cpp33
-rw-r--r--src/runtime/CL/functions/CLGEMMLowpMatrixMultiplyCore.cpp31
-rw-r--r--src/runtime/CL/gemm_auto_heuristics/CLGEMMAutoHeuristics.cpp43
-rw-r--r--src/runtime/CL/gemm_auto_heuristics/CLGEMMAutoHeuristics.h41
4 files changed, 122 insertions, 26 deletions
diff --git a/src/runtime/CL/functions/CLGEMM.cpp b/src/runtime/CL/functions/CLGEMM.cpp
index 3e4c604740..cf1a82bc5a 100644
--- a/src/runtime/CL/functions/CLGEMM.cpp
+++ b/src/runtime/CL/functions/CLGEMM.cpp
@@ -99,6 +99,39 @@ void CLGEMMReshapeRHSMatrixKernelManaged::configure(const CLCompileContext &comp
namespace
{
+inline bool validate_gemm_kernel(CLGEMMKernelType kernel_type)
+{
+ switch(kernel_type)
+ {
+ case CLGEMMKernelType::NATIVE_V1:
+ case CLGEMMKernelType::RESHAPED_ONLY_RHS:
+ case CLGEMMKernelType::RESHAPED_V1:
+ case CLGEMMKernelType::RESHAPED:
+ {
+ return true;
+ }
+ default:
+ {
+ return false;
+ }
+ }
+}
+//Automatically select between mlgo (prioritized) and default heuristics for gemm kernel type
+inline CLGEMMKernelType auto_select_gemm_kernel(auto_heuristics::CommonQuery query, bool reshape_b_only_on_first_run)
+{
+ auto gemm_kernel = auto_heuristics::select_mlgo_gemm_kernel(query, reshape_b_only_on_first_run);
+ if(bool(gemm_kernel))
+ {
+ if(validate_gemm_kernel(gemm_kernel.gemm_type))
+ {
+ ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("Use gemm kernel from mlgo heuristics: %s.", to_string(gemm_kernel.gemm_type).c_str());
+ return gemm_kernel.gemm_type;
+ }
+ }
+ gemm_kernel = auto_heuristics::select_default_gemm_kernel(query, reshape_b_only_on_first_run);
+ ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("Use gemm kernel from default heuristics: %s.", to_string(gemm_kernel.gemm_type).c_str());
+ return gemm_kernel.gemm_type;
+}
// Validate lhs_info and rhs_info for reshaped only rhs kernel
inline bool validate_lhs_rhs_info_reshaped_only_rhs(const GEMMLHSMatrixInfo &lhs_info, const GEMMRHSMatrixInfo &rhs_info, const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c,
const ITensorInfo *output, GEMMKernelInfo gemm_kernel_info)
diff --git a/src/runtime/CL/functions/CLGEMMLowpMatrixMultiplyCore.cpp b/src/runtime/CL/functions/CLGEMMLowpMatrixMultiplyCore.cpp
index 2c9bb3cb66..b189955c04 100644
--- a/src/runtime/CL/functions/CLGEMMLowpMatrixMultiplyCore.cpp
+++ b/src/runtime/CL/functions/CLGEMMLowpMatrixMultiplyCore.cpp
@@ -52,6 +52,37 @@ using namespace arm_compute::cl_gemm;
namespace
{
+inline bool validate_gemm_kernel(CLGEMMKernelType kernel_type)
+{
+ switch(kernel_type)
+ {
+ case CLGEMMKernelType::NATIVE:
+ case CLGEMMKernelType::RESHAPED_ONLY_RHS:
+ {
+ return true;
+ }
+ default:
+ {
+ return false;
+ }
+ }
+}
+//Automatically select between mlgo (prioritized) and default heuristics for gemm kernel type
+inline CLGEMMKernelType auto_select_gemm_kernel(auto_heuristics::CommonQuery query, bool reshape_b_only_on_first_run)
+{
+ auto gemm_kernel = auto_heuristics::select_mlgo_gemm_kernel(query, reshape_b_only_on_first_run);
+ if(bool(gemm_kernel))
+ {
+ if(validate_gemm_kernel(gemm_kernel.gemm_type))
+ {
+ ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("Use gemm kernel from mlgo heuristics: %s.", to_string(gemm_kernel.gemm_type).c_str());
+ return gemm_kernel.gemm_type;
+ }
+ }
+ gemm_kernel = auto_heuristics::select_default_gemm_kernel(query, reshape_b_only_on_first_run);
+ ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("Use gemm kernel from default heuristics: %s.", to_string(gemm_kernel.gemm_type).c_str());
+ return gemm_kernel.gemm_type;
+}
// Validate lhs_info and rhs_info for native kernel
inline bool validate_lhs_rhs_info_native(const GEMMLHSMatrixInfo &lhs_info, const GEMMRHSMatrixInfo &rhs_info, const ITensorInfo *a, const ITensorInfo *b, const GEMMReshapeInfo &reshape_info)
{
diff --git a/src/runtime/CL/gemm_auto_heuristics/CLGEMMAutoHeuristics.cpp b/src/runtime/CL/gemm_auto_heuristics/CLGEMMAutoHeuristics.cpp
index c5618f2cce..ef160d1186 100644
--- a/src/runtime/CL/gemm_auto_heuristics/CLGEMMAutoHeuristics.cpp
+++ b/src/runtime/CL/gemm_auto_heuristics/CLGEMMAutoHeuristics.cpp
@@ -43,21 +43,31 @@ namespace cl_gemm
{
namespace auto_heuristics
{
-CLGEMMKernelType auto_select_gemm_kernel(const CommonQuery &query, bool reshape_b_only_on_first_run)
+GEMMTypeResult select_mlgo_gemm_kernel(const CommonQuery &query, bool reshape_b_only_on_first_run)
{
- // Select between mlgo and default heuristics
- auto mlgo_heuristics = CLScheduler::get().gemm_heuristics();
+ ARM_COMPUTE_UNUSED(reshape_b_only_on_first_run);
+ bool valid = false;
+ CLGEMMKernelType gemm_type{};
+ const auto mlgo_heuristics = CLScheduler::get().gemm_heuristics();
if(mlgo_heuristics != nullptr)
{
- auto res = mlgo_heuristics->get()->query_gemm_type(mlgo::Query{ string_from_target(query.gpu_target), query.data_type, query.m, query.n, query.k, query.b });
- if(res.first)
- {
- ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("Use gemm kernel from mlgo heuristics: %s.", to_string(res.second).c_str());
- return res.second;
- }
+ std::tie(valid, gemm_type) = mlgo_heuristics->get()->query_gemm_type(mlgo::Query{ string_from_target(query.gpu_target), query.data_type, query.m, query.n, query.k, query.b });
}
- std::unique_ptr<ICLGEMMKernelSelection> gemm_kernel = CLGEMMKernelSelectionFactory::create(query.gpu_target);
- ARM_COMPUTE_ERROR_ON_NULLPTR(gemm_kernel.get());
+ if(valid)
+ {
+ ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("MLGOHeuristics query returns gemm type: %s.", to_string(gemm_type).c_str());
+ }
+ else
+ {
+ ARM_COMPUTE_LOG_INFO_MSG_CORE("MLGOHeuristics query failed");
+ }
+ return GEMMTypeResult(valid, gemm_type);
+}
+
+GEMMTypeResult select_default_gemm_kernel(const CommonQuery &query, bool reshape_b_only_on_first_run)
+{
+ std::unique_ptr<ICLGEMMKernelSelection> default_heuristics = CLGEMMKernelSelectionFactory::create(query.gpu_target);
+ ARM_COMPUTE_ERROR_ON_NULLPTR(default_heuristics.get());
CLGEMMKernelSelectionParams params;
params.m = query.m;
@@ -67,9 +77,8 @@ CLGEMMKernelType auto_select_gemm_kernel(const CommonQuery &query, bool reshape_
params.is_rhs_constant = reshape_b_only_on_first_run;
params.data_type = query.data_type;
- const auto kernel_type = gemm_kernel->select_kernel(params);
- ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("Use gemm kernel from default heuristics: %s.", to_string(kernel_type).c_str());
- return kernel_type;
+ const auto kernel_type = default_heuristics->select_kernel(params);
+ return GEMMTypeResult(true, kernel_type);
}
GEMMConfigResult select_default_gemm_config_reshaped_only_rhs(const CommonQuery &query)
@@ -88,7 +97,7 @@ GEMMConfigResult select_mlgo_gemm_config_reshaped_only_rhs(const CommonQuery &qu
GEMMLHSMatrixInfo lhs_info;
GEMMRHSMatrixInfo rhs_info;
mlgo::GEMMConfigReshapedOnlyRHS config{};
- auto mlgo_heuristics = CLScheduler::get().gemm_heuristics();
+ const auto mlgo_heuristics = CLScheduler::get().gemm_heuristics();
if(mlgo_heuristics != nullptr)
{
std::tie(valid, config) = mlgo_heuristics->get()->query_gemm_config_reshaped_only_rhs(mlgo::Query{ string_from_target(query.gpu_target), query.data_type, query.m, query.n, query.k, query.b });
@@ -123,7 +132,7 @@ GEMMConfigResult select_mlgo_gemm_config_reshaped(const CommonQuery &query)
GEMMLHSMatrixInfo lhs_info;
GEMMRHSMatrixInfo rhs_info;
mlgo::GEMMConfigReshaped config{};
- auto mlgo_heuristics = CLScheduler::get().gemm_heuristics();
+ const auto mlgo_heuristics = CLScheduler::get().gemm_heuristics();
if(mlgo_heuristics != nullptr)
{
std::tie(valid, config) = mlgo_heuristics->get()->query_gemm_config_reshaped(mlgo::Query{ string_from_target(query.gpu_target), query.data_type, query.m, query.n, query.k, query.b });
@@ -158,7 +167,7 @@ GEMMConfigResult select_mlgo_gemm_config_native(const CommonQuery &query)
GEMMLHSMatrixInfo lhs_info;
GEMMRHSMatrixInfo rhs_info;
mlgo::GEMMConfigNative config{};
- auto mlgo_heuristics = CLScheduler::get().gemm_heuristics();
+ const auto mlgo_heuristics = CLScheduler::get().gemm_heuristics();
if(mlgo_heuristics != nullptr)
{
std::tie(valid, config) = mlgo_heuristics->get()->query_gemm_config_native(mlgo::Query{ string_from_target(query.gpu_target), query.data_type, query.m, query.n, query.k, query.b });
diff --git a/src/runtime/CL/gemm_auto_heuristics/CLGEMMAutoHeuristics.h b/src/runtime/CL/gemm_auto_heuristics/CLGEMMAutoHeuristics.h
index 486c8bd6cb..020237b7f4 100644
--- a/src/runtime/CL/gemm_auto_heuristics/CLGEMMAutoHeuristics.h
+++ b/src/runtime/CL/gemm_auto_heuristics/CLGEMMAutoHeuristics.h
@@ -47,6 +47,22 @@ struct CommonQuery
unsigned int b; /**< Batch size */
};
+/** Result of querying about GEMM type ( @ref CLGEMMKernelType) */
+struct GEMMTypeResult
+{
+ GEMMTypeResult(bool valid, CLGEMMKernelType gemm_type)
+ : valid{ valid }, gemm_type{ gemm_type }
+ {
+ }
+ /** Test if the result is valid */
+ operator bool() const
+ {
+ return valid;
+ }
+ bool valid; /** If the result is valid */
+ CLGEMMKernelType gemm_type; /** @ref CLGEMMKernelType */
+};
+
/** Result of querying about GEMM config ( @ref GEMMLHSMatrixInfo and @ref GEMMRHSMatrixInfo) */
struct GEMMConfigResult
{
@@ -64,46 +80,53 @@ struct GEMMConfigResult
GEMMRHSMatrixInfo rhs_info; /** @ref GEMMRHSMatrixInfo */
};
-/** Automatically select between mlgo and default heuristics to choose @ref CLGEMMKernelType
+/** Select gemm type based on mlgo heuristics
+ * @param query Query
+ * @param reshape_b_only_on_first_run Additional query parameter if reshape b only on first run
+ * @return GEMMTypeResult. Result is valid if bool(GEMMTypeResult) == true and invalid otherwise
+ */
+GEMMTypeResult select_mlgo_gemm_kernel(const CommonQuery &query, bool reshape_b_only_on_first_run);
+
+/** Select gemm type based on default heuristics
* @param query Query
* @param reshape_b_only_on_first_run Additional query parameter if reshape b only on first run
- * @return CLGEMMKernelType
+ * @return GEMMTypeResult. Result is valid if bool(GEMMTypeResult) == true and invalid otherwise
*/
-CLGEMMKernelType auto_select_gemm_kernel(const CommonQuery &query, bool reshape_b_only_on_first_run);
+GEMMTypeResult select_default_gemm_kernel(const CommonQuery &query, bool reshape_b_only_on_first_run);
/** Select gemm config based on mlgo heuristics
* @param query Query
- * @return GEMMConfigResult. Result is valid if bool(GEMMCOnfigResult) == true and invalid otherwise
+ * @return GEMMConfigResult. Result is valid if bool(GEMMConfigResult) == true and invalid otherwise
*/
GEMMConfigResult select_mlgo_gemm_config_reshaped_only_rhs(const CommonQuery &query);
/** Select gemm config based on default heuristics
* @param query Query
- * @return GEMMConfigResult. Result is valid if bool(GEMMCOnfigResult) == true and invalid otherwise
+ * @return GEMMConfigResult. Result is valid if bool(GEMMConfigResult) == true and invalid otherwise
*/
GEMMConfigResult select_default_gemm_config_reshaped_only_rhs(const CommonQuery &query);
/** Select gemm config based on mlgo heuristics
* @param query Query
- * @return GEMMConfigResult. Result is valid if bool(GEMMCOnfigResult) == true and invalid otherwise
+ * @return GEMMConfigResult. Result is valid if bool(GEMMConfigResult) == true and invalid otherwise
*/
GEMMConfigResult select_mlgo_gemm_config_reshaped(const CommonQuery &query);
/** Select gemm config based on default heuristics
* @param query Query
- * @return GEMMConfigResult. Result is valid if bool(GEMMCOnfigResult) == true and invalid otherwise
+ * @return GEMMConfigResult. Result is valid if bool(GEMMConfigResult) == true and invalid otherwise
*/
GEMMConfigResult select_default_gemm_config_reshaped(const CommonQuery &query);
/** Select gemm config based on mlgo heuristics
* @param query Query
- * @return GEMMConfigResult. Result is valid if bool(GEMMCOnfigResult) == true and invalid otherwise
+ * @return GEMMConfigResult. Result is valid if bool(GEMMConfigResult) == true and invalid otherwise
*/
GEMMConfigResult select_mlgo_gemm_config_native(const CommonQuery &query);
/** Select gemm config based on default heuristics
* @param query Query
- * @return GEMMConfigResult. Result is valid if bool(GEMMCOnfigResult) == true and invalid otherwise
+ * @return GEMMConfigResult. Result is valid if bool(GEMMConfigResult) == true and invalid otherwise
*/
GEMMConfigResult select_default_gemm_config_native(const CommonQuery &query);