From 8c23ba1c5041467f314a10c1da9147e41d056139 Mon Sep 17 00:00:00 2001 From: SiCong Li Date: Mon, 8 Feb 2021 14:19:23 +0000 Subject: Integrate MLGO into CLGEMM for reshaped kernel Resolves COMPMID-3845 Signed-off-by: SiCong Li Change-Id: I878ea6dc076177095816a75f9bc951326fd095b3 Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/5031 Comments-Addressed: Arm Jenkins Reviewed-by: Gian Marco Iodice Tested-by: Arm Jenkins --- src/runtime/CL/functions/CLGEMM.cpp | 71 +++++++++++++++++----- .../gemm_auto_heuristics/CLGEMMAutoHeuristics.cpp | 37 +++++++++++ .../CL/gemm_auto_heuristics/CLGEMMAutoHeuristics.h | 13 ++++ 3 files changed, 107 insertions(+), 14 deletions(-) diff --git a/src/runtime/CL/functions/CLGEMM.cpp b/src/runtime/CL/functions/CLGEMM.cpp index dcb9cb23ec..a0aaabf5fe 100644 --- a/src/runtime/CL/functions/CLGEMM.cpp +++ b/src/runtime/CL/functions/CLGEMM.cpp @@ -37,9 +37,6 @@ #include "arm_compute/core/utils/misc/ShapeCalculator.h" #include "arm_compute/runtime/CL/CLScheduler.h" #include "arm_compute/runtime/ITensorAllocator.h" -#include "src/core/CL/ICLGEMMKernelConfiguration.h" -#include "src/core/CL/gemm/reshaped/CLGEMMReshapedKernelConfiguration.h" -#include "src/core/CL/gemm/reshaped_only_rhs/CLGEMMReshapedOnlyRHSKernelConfiguration.h" #include "src/core/CL/kernels/CLGEMMMatrixMultiplyKernel.h" #include "src/core/CL/kernels/CLGEMMMatrixMultiplyReshapedKernel.h" #include "src/core/CL/kernels/CLGEMMMatrixMultiplyReshapedOnlyRHSKernel.h" @@ -102,6 +99,7 @@ void CLGEMMReshapeRHSMatrixKernelManaged::configure(const CLCompileContext &comp namespace { +// Validate lhs_info and rhs_info for reshaped only rhs kernel inline bool validate_lhs_rhs_info_reshaped_only_rhs(const GEMMLHSMatrixInfo &lhs_info, const GEMMRHSMatrixInfo &rhs_info, const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *output, GEMMKernelInfo gemm_kernel_info) { @@ -129,6 +127,7 @@ inline bool validate_lhs_rhs_info_reshaped_only_rhs(const GEMMLHSMatrixInfo &lhs return true; } +//Automatically select between mlgo (prioritized) and default heuristics for reshaped only rhs kernel configs inline std::pair auto_select_gemm_config_reshaped_only_rhs(auto_heuristics::CommonQuery query, GEMMKernelInfo kernel_info, const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *output) @@ -147,6 +146,55 @@ inline std::pair auto_select_gemm_config_r return { config.lhs_info, config.rhs_info }; } +// Validate lhs_info and rhs_info for reshaped kernel +inline bool validate_lhs_rhs_info_reshaped(const GEMMLHSMatrixInfo &lhs_info, const GEMMRHSMatrixInfo &rhs_info, const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, + const ITensorInfo *output, GEMMKernelInfo gemm_kernel_info, bool reinterpret_input_as_3d) +{ + // Validate GEMMLHSMatrixInfo and GEMMRHSMatrixInfo for reshaped kernel + TensorInfo tmp_a_info{}; + TensorInfo tmp_b_info{}; + + // Validate reshape LHS kernel + auto_init_if_empty(tmp_a_info, a->clone()->set_tensor_shape(compute_lhs_reshaped_shape(*a, lhs_info, reinterpret_input_as_3d))); + if(!bool(CLGEMMReshapeLHSMatrixKernel::validate(a, &tmp_a_info, lhs_info, reinterpret_input_as_3d))) + { + return false; + } + + // Validate reshape RHS kernel + auto_init_if_empty(tmp_b_info, b->clone()->set_tensor_shape(compute_rhs_reshaped_shape(*b, rhs_info))); + if(!bool(CLGEMMReshapeRHSMatrixKernel::validate(b, &tmp_b_info, rhs_info))) + { + return false; + } + // Validate mm kernel + gemm_kernel_info.lhs_info = lhs_info; + gemm_kernel_info.rhs_info = rhs_info; + if(!bool(CLGEMMMatrixMultiplyReshapedKernel::validate(&tmp_a_info, &tmp_b_info, c, output, 1.f, 0.f, lhs_info, rhs_info, gemm_kernel_info))) + { + return false; + } + return true; +} + +//Automatically select between mlgo (prioritized) and default heuristics for reshaped kernel configs +inline std::pair auto_select_gemm_config_reshaped(auto_heuristics::CommonQuery query, GEMMKernelInfo kernel_info, const ITensorInfo *a, const ITensorInfo *b, + const ITensorInfo *c, const ITensorInfo *output, bool reinterpret_input_as_3d) +{ + auto config = auto_heuristics::select_mlgo_gemm_config_reshaped(query); + if(config) + { + if(validate_lhs_rhs_info_reshaped(config.lhs_info, config.rhs_info, a, b, c, output, kernel_info, reinterpret_input_as_3d)) + { + ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("Use reshaped config from mlgo heuristics: LHS info: %s ; RHS info: %s ", to_string(config.lhs_info).c_str(), to_string(config.rhs_info).c_str()); + return { config.lhs_info, config.rhs_info }; + } + } + config = select_default_gemm_config_reshaped(query); + ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("Use reshaped config from default heuristics: LHS info: %s ; RHS info: %s ", to_string(config.lhs_info).c_str(), to_string(config.rhs_info).c_str()); + return { config.lhs_info, config.rhs_info }; +} + } // namespace CLGEMM::CLGEMM(std::shared_ptr memory_manager, IWeightsManager *weights_manager) @@ -311,10 +359,8 @@ void CLGEMM::configure_reshaped_v2(const CLCompileContext &compile_context, cons GEMMRHSMatrixInfo rhs_info{}; // Pick up the GEMM configuration - std::unique_ptr gemm_config = CLGEMMReshapedKernelConfigurationFactory::create(gpu_target); - ARM_COMPUTE_ERROR_ON_NULLPTR(gemm_config.get()); - // Configure lhs_info and rhs_info - std::tie(lhs_info, rhs_info) = gemm_config->configure(m, n, k, batch_size, data_type); + std::tie(lhs_info, rhs_info) = auto_select_gemm_config_reshaped(auto_heuristics::CommonQuery{ gpu_target, data_type, m, n, k, batch_size }, kernel_info, a->info(), b->info(), + c == nullptr ? nullptr : c->info(), output->info(), gemm_info.reinterpret_input_as_3d()); _reshape_lhs_kernel->configure(compile_context, a, &_tmp_a, lhs_info, gemm_info.reinterpret_input_as_3d()); @@ -518,11 +564,10 @@ Status CLGEMM::validate_reshaped(const ITensorInfo *a, const ITensorInfo *b, con GEMMRHSMatrixInfo rhs_info; // Pick up the GEMM configuration - std::unique_ptr gemm_config = CLGEMMReshapedKernelConfigurationFactory::create(gpu_target); - ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(gemm_config.get()); - - // Configure lhs_info and rhs_info - std::tie(lhs_info, rhs_info) = gemm_config->configure(m, n, k, batch_size, data_type); + // NOTE: No need to validate mlgo configurations as they automatically fall back to default heuristics if validation fails + const auto gemm_config = select_default_gemm_config_reshaped(auto_heuristics::CommonQuery{ gpu_target, data_type, m, n, k, batch_size }); + lhs_info = gemm_config.lhs_info; + rhs_info = gemm_config.rhs_info; auto_init_if_empty(tmp_a_info, a->clone()->set_tensor_shape(compute_lhs_reshaped_shape(*a, lhs_info, gemm_info.reinterpret_input_as_3d()))); ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMReshapeLHSMatrixKernel::validate(a, &tmp_a_info, lhs_info, gemm_info.reinterpret_input_as_3d())); @@ -567,8 +612,6 @@ Status CLGEMM::validate_reshaped_only_rhs(const ITensorInfo *a, const ITensorInf GEMMRHSMatrixInfo rhs_info; // Pick up the GEMM configuration - // Note there is no need to validate the configuration from mlgo heuristics as it is already validated in configure() and will fall back - // to default heuristics should it fail // NOTE: No need to validate mlgo configurations as they automatically fall back to default heuristics if validation fails const auto gemm_config = select_default_gemm_config_reshaped_only_rhs(auto_heuristics::CommonQuery{ gpu_target, data_type, m, n, k, batch_size }); lhs_info = gemm_config.lhs_info; diff --git a/src/runtime/CL/gemm_auto_heuristics/CLGEMMAutoHeuristics.cpp b/src/runtime/CL/gemm_auto_heuristics/CLGEMMAutoHeuristics.cpp index 8f0d5f4953..5790e077d4 100644 --- a/src/runtime/CL/gemm_auto_heuristics/CLGEMMAutoHeuristics.cpp +++ b/src/runtime/CL/gemm_auto_heuristics/CLGEMMAutoHeuristics.cpp @@ -29,6 +29,7 @@ #include "arm_compute/runtime/CL/ICLGEMMKernelSelection.h" #include "src/core/CL/ICLGEMMKernelConfiguration.h" #include "src/core/CL/gemm/CLGEMMHelpers.cpp" +#include "src/core/CL/gemm/reshaped/CLGEMMReshapedKernelConfiguration.h" #include "src/core/CL/gemm/reshaped_only_rhs/CLGEMMReshapedOnlyRHSKernelConfiguration.h" #include "src/runtime/CL/gemm/CLGEMMKernelSelection.h" #include "src/runtime/CL/mlgo/MLGOHeuristics.h" @@ -103,6 +104,42 @@ GEMMConfigResult select_mlgo_gemm_config_reshaped_only_rhs(const CommonQuery &qu config.export_cl_image); return GEMMConfigResult{ valid, lhs_info, rhs_info }; } + +GEMMConfigResult select_default_gemm_config_reshaped(const CommonQuery &query) +{ + GEMMLHSMatrixInfo lhs_info; + GEMMRHSMatrixInfo rhs_info; + std::unique_ptr gemm_config = CLGEMMReshapedKernelConfigurationFactory::create(query.gpu_target); + ARM_COMPUTE_ERROR_ON_NULLPTR(gemm_config.get()); + std::tie(lhs_info, rhs_info) = gemm_config->configure(query.m, query.n, query.k, query.b, query.data_type); + return GEMMConfigResult{ true, lhs_info, rhs_info }; +} + +GEMMConfigResult select_mlgo_gemm_config_reshaped(const CommonQuery &query) +{ + bool valid = false; + GEMMLHSMatrixInfo lhs_info; + GEMMRHSMatrixInfo rhs_info; + mlgo::GEMMConfigReshaped config{}; + auto mlgo_heuristics = CLScheduler::get().gemm_heuristics(); + if(mlgo_heuristics != nullptr) + { + std::tie(valid, config) = mlgo_heuristics->get()->query_gemm_config_reshaped(mlgo::Query{ string_from_target(query.gpu_target), query.data_type, query.m, query.n, query.k, query.b }); + } + if(valid) + { + ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("MLGOHeuristics query returns gemm config: %s.", to_string(config).c_str()); + } + else + { + ARM_COMPUTE_LOG_INFO_MSG_CORE("MLGOHeuristics query failed"); + } + std::tie(lhs_info, rhs_info) = configure_lhs_rhs_info(query.m, query.n, config.m0, config.n0, config.k0, config.v0, config.h0, config.interleave_lhs, config.interleave_rhs, !config.transpose_rhs, + config.transpose_rhs, + config.export_cl_image); + return GEMMConfigResult{ valid, lhs_info, rhs_info }; +} } // namespace auto_heuristics + } // namespace cl_gemm } // namespace arm_compute \ No newline at end of file diff --git a/src/runtime/CL/gemm_auto_heuristics/CLGEMMAutoHeuristics.h b/src/runtime/CL/gemm_auto_heuristics/CLGEMMAutoHeuristics.h index e4fa1a6234..7cb9cab220 100644 --- a/src/runtime/CL/gemm_auto_heuristics/CLGEMMAutoHeuristics.h +++ b/src/runtime/CL/gemm_auto_heuristics/CLGEMMAutoHeuristics.h @@ -82,6 +82,19 @@ GEMMConfigResult select_mlgo_gemm_config_reshaped_only_rhs(const CommonQuery &qu * @return GEMMConfigResult */ GEMMConfigResult select_default_gemm_config_reshaped_only_rhs(const CommonQuery &query); + +/** Select gemm config based on mlgo heuristics + * @param query Query + * @return GEMMConfigResult + */ +GEMMConfigResult select_mlgo_gemm_config_reshaped(const CommonQuery &query); + +/** Select gemm config based on default heuristics + * @param query Query + * @return GEMMConfigResult + */ +GEMMConfigResult select_default_gemm_config_reshaped(const CommonQuery &query); + } // namespace auto_heuristics } // namespace cl_gemm } // namespace arm_compute -- cgit v1.2.1