diff options
author | SiCong Li <sicong.li@arm.com> | 2021-02-08 14:19:23 +0000 |
---|---|---|
committer | SiCong Li <sicong.li@arm.com> | 2021-02-09 18:39:43 +0000 |
commit | 8c23ba1c5041467f314a10c1da9147e41d056139 (patch) | |
tree | ee359b5e3715a7208add5b8c99e5a81845ec123b /src/runtime/CL/gemm_auto_heuristics | |
parent | 373b407558f99eb4bba632c170d03d807941dd2a (diff) | |
download | ComputeLibrary-8c23ba1c5041467f314a10c1da9147e41d056139.tar.gz |
Integrate MLGO into CLGEMM for reshaped kernel
Resolves COMPMID-3845
Signed-off-by: SiCong Li <sicong.li@arm.com>
Change-Id: I878ea6dc076177095816a75f9bc951326fd095b3
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/5031
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Gian Marco Iodice <gianmarco.iodice@arm.com>
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Diffstat (limited to 'src/runtime/CL/gemm_auto_heuristics')
-rw-r--r-- | src/runtime/CL/gemm_auto_heuristics/CLGEMMAutoHeuristics.cpp | 37 | ||||
-rw-r--r-- | src/runtime/CL/gemm_auto_heuristics/CLGEMMAutoHeuristics.h | 13 |
2 files changed, 50 insertions, 0 deletions
diff --git a/src/runtime/CL/gemm_auto_heuristics/CLGEMMAutoHeuristics.cpp b/src/runtime/CL/gemm_auto_heuristics/CLGEMMAutoHeuristics.cpp index 8f0d5f4953..5790e077d4 100644 --- a/src/runtime/CL/gemm_auto_heuristics/CLGEMMAutoHeuristics.cpp +++ b/src/runtime/CL/gemm_auto_heuristics/CLGEMMAutoHeuristics.cpp @@ -29,6 +29,7 @@ #include "arm_compute/runtime/CL/ICLGEMMKernelSelection.h" #include "src/core/CL/ICLGEMMKernelConfiguration.h" #include "src/core/CL/gemm/CLGEMMHelpers.cpp" +#include "src/core/CL/gemm/reshaped/CLGEMMReshapedKernelConfiguration.h" #include "src/core/CL/gemm/reshaped_only_rhs/CLGEMMReshapedOnlyRHSKernelConfiguration.h" #include "src/runtime/CL/gemm/CLGEMMKernelSelection.h" #include "src/runtime/CL/mlgo/MLGOHeuristics.h" @@ -103,6 +104,42 @@ GEMMConfigResult select_mlgo_gemm_config_reshaped_only_rhs(const CommonQuery &qu config.export_cl_image); return GEMMConfigResult{ valid, lhs_info, rhs_info }; } + +GEMMConfigResult select_default_gemm_config_reshaped(const CommonQuery &query) +{ + GEMMLHSMatrixInfo lhs_info; + GEMMRHSMatrixInfo rhs_info; + std::unique_ptr<ICLGEMMKernelConfiguration> gemm_config = CLGEMMReshapedKernelConfigurationFactory::create(query.gpu_target); + ARM_COMPUTE_ERROR_ON_NULLPTR(gemm_config.get()); + std::tie(lhs_info, rhs_info) = gemm_config->configure(query.m, query.n, query.k, query.b, query.data_type); + return GEMMConfigResult{ true, lhs_info, rhs_info }; +} + +GEMMConfigResult select_mlgo_gemm_config_reshaped(const CommonQuery &query) +{ + bool valid = false; + GEMMLHSMatrixInfo lhs_info; + GEMMRHSMatrixInfo rhs_info; + mlgo::GEMMConfigReshaped config{}; + auto mlgo_heuristics = CLScheduler::get().gemm_heuristics(); + if(mlgo_heuristics != nullptr) + { + std::tie(valid, config) = mlgo_heuristics->get()->query_gemm_config_reshaped(mlgo::Query{ string_from_target(query.gpu_target), query.data_type, query.m, query.n, query.k, query.b }); + } + if(valid) + { + ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("MLGOHeuristics query returns gemm config: %s.", to_string(config).c_str()); + } + else + { + ARM_COMPUTE_LOG_INFO_MSG_CORE("MLGOHeuristics query failed"); + } + std::tie(lhs_info, rhs_info) = configure_lhs_rhs_info(query.m, query.n, config.m0, config.n0, config.k0, config.v0, config.h0, config.interleave_lhs, config.interleave_rhs, !config.transpose_rhs, + config.transpose_rhs, + config.export_cl_image); + return GEMMConfigResult{ valid, lhs_info, rhs_info }; +} } // namespace auto_heuristics + } // namespace cl_gemm } // namespace arm_compute
\ No newline at end of file diff --git a/src/runtime/CL/gemm_auto_heuristics/CLGEMMAutoHeuristics.h b/src/runtime/CL/gemm_auto_heuristics/CLGEMMAutoHeuristics.h index e4fa1a6234..7cb9cab220 100644 --- a/src/runtime/CL/gemm_auto_heuristics/CLGEMMAutoHeuristics.h +++ b/src/runtime/CL/gemm_auto_heuristics/CLGEMMAutoHeuristics.h @@ -82,6 +82,19 @@ GEMMConfigResult select_mlgo_gemm_config_reshaped_only_rhs(const CommonQuery &qu * @return GEMMConfigResult */ GEMMConfigResult select_default_gemm_config_reshaped_only_rhs(const CommonQuery &query); + +/** Select gemm config based on mlgo heuristics + * @param query Query + * @return GEMMConfigResult + */ +GEMMConfigResult select_mlgo_gemm_config_reshaped(const CommonQuery &query); + +/** Select gemm config based on default heuristics + * @param query Query + * @return GEMMConfigResult + */ +GEMMConfigResult select_default_gemm_config_reshaped(const CommonQuery &query); + } // namespace auto_heuristics } // namespace cl_gemm } // namespace arm_compute |