diff options
Diffstat (limited to 'src/runtime/CL/gemm_auto_heuristics/CLGEMMAutoHeuristics.cpp')
-rw-r--r-- | src/runtime/CL/gemm_auto_heuristics/CLGEMMAutoHeuristics.cpp | 37 |
1 files changed, 37 insertions, 0 deletions
diff --git a/src/runtime/CL/gemm_auto_heuristics/CLGEMMAutoHeuristics.cpp b/src/runtime/CL/gemm_auto_heuristics/CLGEMMAutoHeuristics.cpp index 8f0d5f4953..5790e077d4 100644 --- a/src/runtime/CL/gemm_auto_heuristics/CLGEMMAutoHeuristics.cpp +++ b/src/runtime/CL/gemm_auto_heuristics/CLGEMMAutoHeuristics.cpp @@ -29,6 +29,7 @@ #include "arm_compute/runtime/CL/ICLGEMMKernelSelection.h" #include "src/core/CL/ICLGEMMKernelConfiguration.h" #include "src/core/CL/gemm/CLGEMMHelpers.cpp" +#include "src/core/CL/gemm/reshaped/CLGEMMReshapedKernelConfiguration.h" #include "src/core/CL/gemm/reshaped_only_rhs/CLGEMMReshapedOnlyRHSKernelConfiguration.h" #include "src/runtime/CL/gemm/CLGEMMKernelSelection.h" #include "src/runtime/CL/mlgo/MLGOHeuristics.h" @@ -103,6 +104,42 @@ GEMMConfigResult select_mlgo_gemm_config_reshaped_only_rhs(const CommonQuery &qu config.export_cl_image); return GEMMConfigResult{ valid, lhs_info, rhs_info }; } + +GEMMConfigResult select_default_gemm_config_reshaped(const CommonQuery &query) +{ + GEMMLHSMatrixInfo lhs_info; + GEMMRHSMatrixInfo rhs_info; + std::unique_ptr<ICLGEMMKernelConfiguration> gemm_config = CLGEMMReshapedKernelConfigurationFactory::create(query.gpu_target); + ARM_COMPUTE_ERROR_ON_NULLPTR(gemm_config.get()); + std::tie(lhs_info, rhs_info) = gemm_config->configure(query.m, query.n, query.k, query.b, query.data_type); + return GEMMConfigResult{ true, lhs_info, rhs_info }; +} + +GEMMConfigResult select_mlgo_gemm_config_reshaped(const CommonQuery &query) +{ + bool valid = false; + GEMMLHSMatrixInfo lhs_info; + GEMMRHSMatrixInfo rhs_info; + mlgo::GEMMConfigReshaped config{}; + auto mlgo_heuristics = CLScheduler::get().gemm_heuristics(); + if(mlgo_heuristics != nullptr) + { + std::tie(valid, config) = mlgo_heuristics->get()->query_gemm_config_reshaped(mlgo::Query{ string_from_target(query.gpu_target), query.data_type, query.m, query.n, query.k, query.b }); + } + if(valid) + { + ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("MLGOHeuristics query returns gemm config: %s.", to_string(config).c_str()); + } + else + { + ARM_COMPUTE_LOG_INFO_MSG_CORE("MLGOHeuristics query failed"); + } + std::tie(lhs_info, rhs_info) = configure_lhs_rhs_info(query.m, query.n, config.m0, config.n0, config.k0, config.v0, config.h0, config.interleave_lhs, config.interleave_rhs, !config.transpose_rhs, + config.transpose_rhs, + config.export_cl_image); + return GEMMConfigResult{ valid, lhs_info, rhs_info }; +} } // namespace auto_heuristics + } // namespace cl_gemm } // namespace arm_compute
\ No newline at end of file |