aboutsummaryrefslogtreecommitdiff
path: root/src/runtime/CL/gemm_auto_heuristics/CLGEMMAutoHeuristics.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/runtime/CL/gemm_auto_heuristics/CLGEMMAutoHeuristics.cpp')
-rw-r--r--src/runtime/CL/gemm_auto_heuristics/CLGEMMAutoHeuristics.cpp36
1 files changed, 36 insertions, 0 deletions
diff --git a/src/runtime/CL/gemm_auto_heuristics/CLGEMMAutoHeuristics.cpp b/src/runtime/CL/gemm_auto_heuristics/CLGEMMAutoHeuristics.cpp
index 5790e077d4..c5618f2cce 100644
--- a/src/runtime/CL/gemm_auto_heuristics/CLGEMMAutoHeuristics.cpp
+++ b/src/runtime/CL/gemm_auto_heuristics/CLGEMMAutoHeuristics.cpp
@@ -29,6 +29,7 @@
#include "arm_compute/runtime/CL/ICLGEMMKernelSelection.h"
#include "src/core/CL/ICLGEMMKernelConfiguration.h"
#include "src/core/CL/gemm/CLGEMMHelpers.cpp"
+#include "src/core/CL/gemm/native/CLGEMMNativeKernelConfiguration.h"
#include "src/core/CL/gemm/reshaped/CLGEMMReshapedKernelConfiguration.h"
#include "src/core/CL/gemm/reshaped_only_rhs/CLGEMMReshapedOnlyRHSKernelConfiguration.h"
#include "src/runtime/CL/gemm/CLGEMMKernelSelection.h"
@@ -100,6 +101,7 @@ GEMMConfigResult select_mlgo_gemm_config_reshaped_only_rhs(const CommonQuery &qu
{
ARM_COMPUTE_LOG_INFO_MSG_CORE("MLGOHeuristics query failed");
}
+ // Setting irrelevant unsigned int parameters to 1 and bool parameters to false as they do no matter
std::tie(lhs_info, rhs_info) = configure_lhs_rhs_info(query.m, query.n, config.m0, config.n0, config.k0, 1, config.h0, false, config.interleave_rhs, !config.transpose_rhs, config.transpose_rhs,
config.export_cl_image);
return GEMMConfigResult{ valid, lhs_info, rhs_info };
@@ -139,6 +141,40 @@ GEMMConfigResult select_mlgo_gemm_config_reshaped(const CommonQuery &query)
config.export_cl_image);
return GEMMConfigResult{ valid, lhs_info, rhs_info };
}
+
+GEMMConfigResult select_default_gemm_config_native(const CommonQuery &query)
+{
+ GEMMLHSMatrixInfo lhs_info;
+ GEMMRHSMatrixInfo rhs_info;
+ std::unique_ptr<ICLGEMMKernelConfiguration> gemm_config = CLGEMMNativeKernelConfigurationFactory::create(query.gpu_target);
+ ARM_COMPUTE_ERROR_ON_NULLPTR(gemm_config.get());
+ std::tie(lhs_info, rhs_info) = gemm_config->configure(query.m, query.n, query.k, query.b, query.data_type);
+ return GEMMConfigResult{ true, lhs_info, rhs_info };
+}
+
+GEMMConfigResult select_mlgo_gemm_config_native(const CommonQuery &query)
+{
+ bool valid = false;
+ GEMMLHSMatrixInfo lhs_info;
+ GEMMRHSMatrixInfo rhs_info;
+ mlgo::GEMMConfigNative config{};
+ auto mlgo_heuristics = CLScheduler::get().gemm_heuristics();
+ if(mlgo_heuristics != nullptr)
+ {
+ std::tie(valid, config) = mlgo_heuristics->get()->query_gemm_config_native(mlgo::Query{ string_from_target(query.gpu_target), query.data_type, query.m, query.n, query.k, query.b });
+ }
+ if(valid)
+ {
+ ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("MLGOHeuristics query returns gemm config: %s.", to_string(config).c_str());
+ }
+ else
+ {
+ ARM_COMPUTE_LOG_INFO_MSG_CORE("MLGOHeuristics query failed");
+ }
+ // Setting irrelevant unsigned int parameters to 1 and bool parameters to false as they do no matter
+ std::tie(lhs_info, rhs_info) = configure_lhs_rhs_info(query.m, query.n, config.m0, config.n0, config.k0, 1, 1, false, false, false, false, false);
+ return GEMMConfigResult{ valid, lhs_info, rhs_info };
+}
} // namespace auto_heuristics
} // namespace cl_gemm