diff options
author | SiCong Li <sicong.li@arm.com> | 2021-02-08 15:16:13 +0000 |
---|---|---|
committer | SiCong Li <sicong.li@arm.com> | 2021-02-10 15:25:40 +0000 |
commit | db35345753e4ba81384c8a92ece6a8f598fd841a (patch) | |
tree | 100d05c21fb794fc1a1f5fdeb32bf21572719e3d /src/runtime/CL/gemm_auto_heuristics | |
parent | 79144a642b33ff1ac40a44aaa1881261d12e6376 (diff) | |
download | ComputeLibrary-db35345753e4ba81384c8a92ece6a8f598fd841a.tar.gz |
Integrate MLGO into CLGEMMLowpMatrixMultiplyCore for native kernel
Resolves COMPMID-3846
Signed-off-by: SiCong Li <sicong.li@arm.com>
Change-Id: Iad66f6dd7fa5b13ebace9f95fbc2fc4d677cf6a9
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/5032
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Gian Marco Iodice <gianmarco.iodice@arm.com>
Reviewed-by: Pablo Marquez Tello <pablo.tello@arm.com>
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Diffstat (limited to 'src/runtime/CL/gemm_auto_heuristics')
-rw-r--r-- | src/runtime/CL/gemm_auto_heuristics/CLGEMMAutoHeuristics.cpp | 36 | ||||
-rw-r--r-- | src/runtime/CL/gemm_auto_heuristics/CLGEMMAutoHeuristics.h | 20 |
2 files changed, 52 insertions, 4 deletions
diff --git a/src/runtime/CL/gemm_auto_heuristics/CLGEMMAutoHeuristics.cpp b/src/runtime/CL/gemm_auto_heuristics/CLGEMMAutoHeuristics.cpp index 5790e077d4..c5618f2cce 100644 --- a/src/runtime/CL/gemm_auto_heuristics/CLGEMMAutoHeuristics.cpp +++ b/src/runtime/CL/gemm_auto_heuristics/CLGEMMAutoHeuristics.cpp @@ -29,6 +29,7 @@ #include "arm_compute/runtime/CL/ICLGEMMKernelSelection.h" #include "src/core/CL/ICLGEMMKernelConfiguration.h" #include "src/core/CL/gemm/CLGEMMHelpers.cpp" +#include "src/core/CL/gemm/native/CLGEMMNativeKernelConfiguration.h" #include "src/core/CL/gemm/reshaped/CLGEMMReshapedKernelConfiguration.h" #include "src/core/CL/gemm/reshaped_only_rhs/CLGEMMReshapedOnlyRHSKernelConfiguration.h" #include "src/runtime/CL/gemm/CLGEMMKernelSelection.h" @@ -100,6 +101,7 @@ GEMMConfigResult select_mlgo_gemm_config_reshaped_only_rhs(const CommonQuery &qu { ARM_COMPUTE_LOG_INFO_MSG_CORE("MLGOHeuristics query failed"); } + // Setting irrelevant unsigned int parameters to 1 and bool parameters to false as they do no matter std::tie(lhs_info, rhs_info) = configure_lhs_rhs_info(query.m, query.n, config.m0, config.n0, config.k0, 1, config.h0, false, config.interleave_rhs, !config.transpose_rhs, config.transpose_rhs, config.export_cl_image); return GEMMConfigResult{ valid, lhs_info, rhs_info }; @@ -139,6 +141,40 @@ GEMMConfigResult select_mlgo_gemm_config_reshaped(const CommonQuery &query) config.export_cl_image); return GEMMConfigResult{ valid, lhs_info, rhs_info }; } + +GEMMConfigResult select_default_gemm_config_native(const CommonQuery &query) +{ + GEMMLHSMatrixInfo lhs_info; + GEMMRHSMatrixInfo rhs_info; + std::unique_ptr<ICLGEMMKernelConfiguration> gemm_config = CLGEMMNativeKernelConfigurationFactory::create(query.gpu_target); + ARM_COMPUTE_ERROR_ON_NULLPTR(gemm_config.get()); + std::tie(lhs_info, rhs_info) = gemm_config->configure(query.m, query.n, query.k, query.b, query.data_type); + return GEMMConfigResult{ true, lhs_info, rhs_info }; +} + +GEMMConfigResult select_mlgo_gemm_config_native(const CommonQuery &query) +{ + bool valid = false; + GEMMLHSMatrixInfo lhs_info; + GEMMRHSMatrixInfo rhs_info; + mlgo::GEMMConfigNative config{}; + auto mlgo_heuristics = CLScheduler::get().gemm_heuristics(); + if(mlgo_heuristics != nullptr) + { + std::tie(valid, config) = mlgo_heuristics->get()->query_gemm_config_native(mlgo::Query{ string_from_target(query.gpu_target), query.data_type, query.m, query.n, query.k, query.b }); + } + if(valid) + { + ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("MLGOHeuristics query returns gemm config: %s.", to_string(config).c_str()); + } + else + { + ARM_COMPUTE_LOG_INFO_MSG_CORE("MLGOHeuristics query failed"); + } + // Setting irrelevant unsigned int parameters to 1 and bool parameters to false as they do no matter + std::tie(lhs_info, rhs_info) = configure_lhs_rhs_info(query.m, query.n, config.m0, config.n0, config.k0, 1, 1, false, false, false, false, false); + return GEMMConfigResult{ valid, lhs_info, rhs_info }; +} } // namespace auto_heuristics } // namespace cl_gemm diff --git a/src/runtime/CL/gemm_auto_heuristics/CLGEMMAutoHeuristics.h b/src/runtime/CL/gemm_auto_heuristics/CLGEMMAutoHeuristics.h index 7cb9cab220..486c8bd6cb 100644 --- a/src/runtime/CL/gemm_auto_heuristics/CLGEMMAutoHeuristics.h +++ b/src/runtime/CL/gemm_auto_heuristics/CLGEMMAutoHeuristics.h @@ -73,28 +73,40 @@ CLGEMMKernelType auto_select_gemm_kernel(const CommonQuery &query, bool reshape_ /** Select gemm config based on mlgo heuristics * @param query Query - * @return GEMMConfigResult + * @return GEMMConfigResult. Result is valid if bool(GEMMCOnfigResult) == true and invalid otherwise */ GEMMConfigResult select_mlgo_gemm_config_reshaped_only_rhs(const CommonQuery &query); /** Select gemm config based on default heuristics * @param query Query - * @return GEMMConfigResult + * @return GEMMConfigResult. Result is valid if bool(GEMMCOnfigResult) == true and invalid otherwise */ GEMMConfigResult select_default_gemm_config_reshaped_only_rhs(const CommonQuery &query); /** Select gemm config based on mlgo heuristics * @param query Query - * @return GEMMConfigResult + * @return GEMMConfigResult. Result is valid if bool(GEMMCOnfigResult) == true and invalid otherwise */ GEMMConfigResult select_mlgo_gemm_config_reshaped(const CommonQuery &query); /** Select gemm config based on default heuristics * @param query Query - * @return GEMMConfigResult + * @return GEMMConfigResult. Result is valid if bool(GEMMCOnfigResult) == true and invalid otherwise */ GEMMConfigResult select_default_gemm_config_reshaped(const CommonQuery &query); +/** Select gemm config based on mlgo heuristics + * @param query Query + * @return GEMMConfigResult. Result is valid if bool(GEMMCOnfigResult) == true and invalid otherwise + */ +GEMMConfigResult select_mlgo_gemm_config_native(const CommonQuery &query); + +/** Select gemm config based on default heuristics + * @param query Query + * @return GEMMConfigResult. Result is valid if bool(GEMMCOnfigResult) == true and invalid otherwise + */ +GEMMConfigResult select_default_gemm_config_native(const CommonQuery &query); + } // namespace auto_heuristics } // namespace cl_gemm } // namespace arm_compute |