aboutsummaryrefslogtreecommitdiff
path: root/src/runtime/CL/functions/CLGEMMLowpMatrixMultiplyCore.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/runtime/CL/functions/CLGEMMLowpMatrixMultiplyCore.cpp')
-rw-r--r--src/runtime/CL/functions/CLGEMMLowpMatrixMultiplyCore.cpp31
1 files changed, 31 insertions, 0 deletions
diff --git a/src/runtime/CL/functions/CLGEMMLowpMatrixMultiplyCore.cpp b/src/runtime/CL/functions/CLGEMMLowpMatrixMultiplyCore.cpp
index 2c9bb3cb66..b189955c04 100644
--- a/src/runtime/CL/functions/CLGEMMLowpMatrixMultiplyCore.cpp
+++ b/src/runtime/CL/functions/CLGEMMLowpMatrixMultiplyCore.cpp
@@ -52,6 +52,37 @@ using namespace arm_compute::cl_gemm;
namespace
{
+inline bool validate_gemm_kernel(CLGEMMKernelType kernel_type)
+{
+ switch(kernel_type)
+ {
+ case CLGEMMKernelType::NATIVE:
+ case CLGEMMKernelType::RESHAPED_ONLY_RHS:
+ {
+ return true;
+ }
+ default:
+ {
+ return false;
+ }
+ }
+}
+//Automatically select between mlgo (prioritized) and default heuristics for gemm kernel type
+inline CLGEMMKernelType auto_select_gemm_kernel(auto_heuristics::CommonQuery query, bool reshape_b_only_on_first_run)
+{
+ auto gemm_kernel = auto_heuristics::select_mlgo_gemm_kernel(query, reshape_b_only_on_first_run);
+ if(bool(gemm_kernel))
+ {
+ if(validate_gemm_kernel(gemm_kernel.gemm_type))
+ {
+ ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("Use gemm kernel from mlgo heuristics: %s.", to_string(gemm_kernel.gemm_type).c_str());
+ return gemm_kernel.gemm_type;
+ }
+ }
+ gemm_kernel = auto_heuristics::select_default_gemm_kernel(query, reshape_b_only_on_first_run);
+ ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("Use gemm kernel from default heuristics: %s.", to_string(gemm_kernel.gemm_type).c_str());
+ return gemm_kernel.gemm_type;
+}
// Validate lhs_info and rhs_info for native kernel
inline bool validate_lhs_rhs_info_native(const GEMMLHSMatrixInfo &lhs_info, const GEMMRHSMatrixInfo &rhs_info, const ITensorInfo *a, const ITensorInfo *b, const GEMMReshapeInfo &reshape_info)
{