From db35345753e4ba81384c8a92ece6a8f598fd841a Mon Sep 17 00:00:00 2001
From: SiCong Li
Date: Mon, 8 Feb 2021 15:16:13 +0000
Subject: Integrate MLGO into CLGEMMLowpMatrixMultiplyCore for native kernel

Resolves COMPMID-3846

Signed-off-by: SiCong Li
Change-Id: Iad66f6dd7fa5b13ebace9f95fbc2fc4d677cf6a9
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/5032
Tested-by: Arm Jenkins
Reviewed-by: Gian Marco Iodice
Reviewed-by: Pablo Marquez Tello
Comments-Addressed: Arm Jenkins
---
 src/runtime/CL/functions/CLGEMM.cpp                |  4 +-
 .../CL/functions/CLGEMMLowpMatrixMultiplyCore.cpp  | 68 ++++++++++++++++++----
 .../gemm_auto_heuristics/CLGEMMAutoHeuristics.cpp  | 36 ++++++++++++
 .../CL/gemm_auto_heuristics/CLGEMMAutoHeuristics.h | 20 +++++--
 4 files changed, 112 insertions(+), 16 deletions(-)

diff --git a/src/runtime/CL/functions/CLGEMM.cpp b/src/runtime/CL/functions/CLGEMM.cpp
index a0aaabf5fe..3e4c604740 100644
--- a/src/runtime/CL/functions/CLGEMM.cpp
+++ b/src/runtime/CL/functions/CLGEMM.cpp
@@ -141,7 +141,7 @@ inline std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> auto_select_gemm_config_r
             return { config.lhs_info, config.rhs_info };
         }
     }
-    config = select_default_gemm_config_reshaped_only_rhs(query);
+    config = auto_heuristics::select_default_gemm_config_reshaped_only_rhs(query);
     ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("Use reshaped_only_rhs config from default heuristics: LHS info: %s ; RHS info: %s ", to_string(config.lhs_info).c_str(), to_string(config.rhs_info).c_str());
     return { config.lhs_info, config.rhs_info };
 }
@@ -190,7 +190,7 @@ inline std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> auto_select_gemm_config_r
             return { config.lhs_info, config.rhs_info };
         }
     }
-    config = select_default_gemm_config_reshaped(query);
+    config = auto_heuristics::select_default_gemm_config_reshaped(query);
     ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("Use reshaped config from default heuristics: LHS info: %s ; RHS info: %s ", to_string(config.lhs_info).c_str(), to_string(config.rhs_info).c_str());
     return { config.lhs_info, config.rhs_info };
 }
diff --git a/src/runtime/CL/functions/CLGEMMLowpMatrixMultiplyCore.cpp b/src/runtime/CL/functions/CLGEMMLowpMatrixMultiplyCore.cpp
index 6c4d9ef54a..2c9bb3cb66 100644
--- a/src/runtime/CL/functions/CLGEMMLowpMatrixMultiplyCore.cpp
+++ b/src/runtime/CL/functions/CLGEMMLowpMatrixMultiplyCore.cpp
@@ -34,8 +34,6 @@
 #include "arm_compute/core/utils/misc/ShapeCalculator.h"
 #include "arm_compute/core/utils/quantization/AsymmHelpers.h"
 #include "arm_compute/runtime/CL/CLScheduler.h"
-#include "src/core/CL/gemm/native/CLGEMMNativeKernelConfiguration.h"
-#include "src/core/CL/gemm/reshaped_only_rhs/CLGEMMReshapedOnlyRHSKernelConfiguration.h"
 #include "src/core/CL/kernels/CLDepthConvertLayerKernel.h"
 #include "src/core/CL/kernels/CLGEMMLowpMatrixMultiplyNativeKernel.h"
 #include "src/core/CL/kernels/CLGEMMLowpMatrixMultiplyReshapedOnlyRHSKernel.h"
@@ -44,7 +42,6 @@
 #include "src/core/CL/kernels/CLGEMMLowpReductionKernel.h"
 #include "src/core/CL/kernels/CLGEMMReshapeRHSMatrixKernel.h"
 #include "src/core/helpers/AutoConfiguration.h"
-#include "src/runtime/CL/gemm/CLGEMMKernelSelection.h"
 #include "src/runtime/CL/gemm_auto_heuristics/CLGEMMAutoHeuristics.h"
 #include "utils/TypePrinter.h"
 
@@ -55,6 +52,43 @@ using namespace arm_compute::cl_gemm;
 
 namespace
 {
+// Validate lhs_info and rhs_info for native kernel
+inline bool validate_lhs_rhs_info_native(const GEMMLHSMatrixInfo &lhs_info, const GEMMRHSMatrixInfo &rhs_info, const ITensorInfo *a, const ITensorInfo *b, const GEMMReshapeInfo &reshape_info)
+{
+    // Validate GEMMLHSMatrixInfo and GEMMRHSMatrixInfo for native kernel
+    TensorInfo mm_result_s32_info{};
+    // Output tensor auto initialization if not yet initialized
+    auto_init_if_empty(mm_result_s32_info, a->clone()->set_tensor_shape(compute_mm_shape(*a, *b, false, reshape_info)).set_data_type(DataType::S32));
+    // Validate mm kernel
+    // NOTE: Ignore all other parameters (e.g. output stage etc.) and only validate lhs and rhs info
+    // NOTE: This assumes:
+    //  1. lhs and rhs info's validity does not depend on these other parameters and vice versa (in CLGEMMLowpMatrixMultiplyNativeKernel.cpp validate_arguments).
+    //  2. lhs and rhs info does not cause window and padding issues through side effects (in CLGEMMLowpMatrixMultiplyNativeKernel.cpp validate_and_configure_window).
+    if(!bool(CLGEMMLowpMatrixMultiplyNativeKernel::validate(a, b, &mm_result_s32_info, lhs_info, rhs_info, reshape_info)))
+    {
+        return false;
+    }
+    return true;
+}
+
+// Automatically select between mlgo (prioritized) and default heuristics for native kernel configs
+std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> auto_select_gemm_config_native(auto_heuristics::CommonQuery query, const ITensorInfo *a, const ITensorInfo *b, const GEMMReshapeInfo &reshape_info)
+{
+    auto config = auto_heuristics::select_mlgo_gemm_config_native(query);
+    if(config)
+    {
+        if(validate_lhs_rhs_info_native(config.lhs_info, config.rhs_info, a, b, reshape_info))
+        {
+            ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("Use native config from mlgo heuristics: LHS info: %s ; RHS info: %s ", to_string(config.lhs_info).c_str(), to_string(config.rhs_info).c_str());
+            return { config.lhs_info, config.rhs_info };
+        }
+    }
+    config = auto_heuristics::select_default_gemm_config_native(query);
+    ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("Use native config from default heuristics: LHS info: %s ; RHS info: %s ", to_string(config.lhs_info).c_str(), to_string(config.rhs_info).c_str());
+    return { config.lhs_info, config.rhs_info };
+}
+
+// Validate lhs_info and rhs_info for reshaped only rhs kernel
 inline bool validate_lhs_rhs_info_reshaped_only_rhs(const GEMMLHSMatrixInfo &lhs_info, const GEMMRHSMatrixInfo &rhs_info, const ITensorInfo *a, const ITensorInfo *b,
                                                     const ITensorInfo *output, unsigned int m, unsigned int n, unsigned int k, bool reinterpret_input_as_3d, int depth_output_gemm3d)
 {
@@ -89,6 +123,7 @@ inline bool validate_lhs_rhs_info_reshaped_only_rhs(const GEMMLHSMatrixInfo &lhs
     return true;
 }
 
+// Automatically select between mlgo (prioritized) and default heuristics for reshaped only rhs kernel configs
 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> auto_select_gemm_config_reshaped_only_rhs(auto_heuristics::CommonQuery query, bool reinterpret_input_as_3d, int depth_output_gemm3d,
                                                                                           const ITensorInfo *a,
                                                                                           const ITensorInfo *b, const ITensorInfo *output)
@@ -102,7 +137,7 @@ std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> auto_select_gemm_config_reshaped
             return { config.lhs_info, config.rhs_info };
         }
     }
-    config = select_default_gemm_config_reshaped_only_rhs(query);
+    config = auto_heuristics::select_default_gemm_config_reshaped_only_rhs(query);
     ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("Use reshaped_only_rhs config from default heuristics: LHS info: %s ; RHS info: %s ", to_string(config.lhs_info).c_str(), to_string(config.rhs_info).c_str());
     return { config.lhs_info, config.rhs_info };
 }
@@ -195,6 +230,8 @@ void CLGEMMLowpMatrixMultiplyCore::configure(const CLCompileContext &compile_con
     const unsigned int batch_size          = reinterpret_input_as_3d ? a->info()->dimension(3) : a->info()->dimension(2);
     const int          depth_output_gemm3d = gemm_info.depth_output_gemm3d();
 
+    const auto reshape_info = GEMMReshapeInfo(m, n, k, 1, 1, depth_output_gemm3d, reinterpret_input_as_3d);
+
     // Check if we need to reshape the matrix A and matrix B
     _is_gemm_reshaped = is_gemm_reshaped(auto_select_gemm_kernel(auto_heuristics::CommonQuery{ gpu_target, a->info()->data_type(), m, n, k, batch_size }, _reshape_b_only_on_first_run));
@@ -298,10 +335,12 @@ void CLGEMMLowpMatrixMultiplyCore::configure(const CLCompileContext &compile_con
         else
         {
             // Pick up the GEMM configuration
-            std::tie(lhs_info, rhs_info) = CLGEMMNativeKernelConfigurationFactory::create(gpu_target)->configure(m, n, k, batch_size, DataType::QASYMM8);
+            // It doesn't matter whether the data type is DataType::QASYMM8 or DataType::QASYMM8_SIGNED, since it only affects the shape configuration
+            std::tie(lhs_info, rhs_info) = auto_select_gemm_config_native(auto_heuristics::CommonQuery{ gpu_target, DataType::QASYMM8, m, n, k, batch_size },
+                                                                          _matrix_a->info(), _convert_to_qasymm8 ? _qasymm8_weights.info() : matrix_b->info(), reshape_info);
 
             // Configure matrix multiply kernel
-            _mm_native_kernel->configure(compile_context, _matrix_a, matrix_b, &_mm_result_s32, lhs_info, rhs_info, GEMMReshapeInfo(m, n, k, 1, 1, depth_output_gemm3d, reinterpret_input_as_3d));
+            _mm_native_kernel->configure(compile_context, _matrix_a, matrix_b, &_mm_result_s32, lhs_info, rhs_info, reshape_info);
 
             _offset_contribution_output_stage_kernel->configure(compile_context, &_mm_result_s32, _a_offset == 0 ? nullptr : &_vector_sum_col, _b_offset == 0 ? nullptr : &_vector_sum_row, c, output, a->info()->dimension(0),
@@ -331,10 +370,12 @@ void CLGEMMLowpMatrixMultiplyCore::configure(const CLCompileContext &compile_con
         else
         {
             // Pick up the GEMM configuration
-            std::tie(lhs_info, rhs_info) = CLGEMMNativeKernelConfigurationFactory::create(gpu_target)->configure(m, n, k, batch_size, DataType::QASYMM8);
+            // It doesn't matter whether the data type is DataType::QASYMM8 or DataType::QASYMM8_SIGNED, since it only affects the shape configuration
+            std::tie(lhs_info, rhs_info) = auto_select_gemm_config_native(auto_heuristics::CommonQuery{ gpu_target, DataType::QASYMM8, m, n, k, batch_size },
+                                                                          a->info(), _convert_to_qasymm8 ? _qasymm8_weights.info() : b->info(), reshape_info);
 
             // Configure matrix multiply kernel
-            _mm_native_kernel->configure(compile_context, _matrix_a, matrix_b, output, lhs_info, rhs_info, GEMMReshapeInfo(m, n, k, 1, 1, depth_output_gemm3d, reinterpret_input_as_3d));
+            _mm_native_kernel->configure(compile_context, _matrix_a, matrix_b, output, lhs_info, rhs_info, reshape_info);
         }
 
         // Configure offset contribution kernel
@@ -490,7 +531,11 @@ Status CLGEMMLowpMatrixMultiplyCore::validate(const ITensorInfo *a, const ITenso
             auto_init_if_empty(mm_result_s32_info, a->clone()->set_tensor_shape(compute_mm_shape(*matrix_a_info, *matrix_b_info, false, reshape_info)).set_data_type(DataType::S32));
 
             // Pick up the GEMM configuration
-            std::tie(lhs_info, rhs_info) = CLGEMMNativeKernelConfigurationFactory::create(gpu_target)->configure(m, n, k, batch_size, DataType::QASYMM8);
+            // NOTE: No need to validate mlgo configurations as they automatically fall back to default heuristics if validation fails
+            // It doesn't matter whether the data type is DataType::QASYMM8 or DataType::QASYMM8_SIGNED, since it only affects the shape configuration
+            const auto res = select_default_gemm_config_native(auto_heuristics::CommonQuery{ gpu_target, DataType::QASYMM8, m, n, k, batch_size });
+            lhs_info       = res.lhs_info;
+            rhs_info       = res.rhs_info;
 
             // Validate matrix multiply
             ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMLowpMatrixMultiplyNativeKernel::validate(matrix_a_info, matrix_b_info, &mm_result_s32_info, lhs_info, rhs_info, reshape_info));
@@ -518,7 +563,10 @@ Status CLGEMMLowpMatrixMultiplyCore::validate(const ITensorInfo *a, const ITenso
         else
         {
             // Pick up the GEMM configuration
-            std::tie(lhs_info, rhs_info) = CLGEMMNativeKernelConfigurationFactory::create(gpu_target)->configure(m, n, k, batch_size, DataType::QASYMM8);
+            // It doesn't matter whether the data type is DataType::QASYMM8 or DataType::QASYMM8_SIGNED, since it only affects the shape configuration
+            const auto res = select_default_gemm_config_native(auto_heuristics::CommonQuery{ gpu_target, DataType::QASYMM8, m, n, k, batch_size });
+            lhs_info       = res.lhs_info;
+            rhs_info       = res.rhs_info;
 
             // Validate matrix multiply
             ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMLowpMatrixMultiplyNativeKernel::validate(matrix_a_info, matrix_b_info, output, lhs_info, rhs_info, reshape_info));
diff --git a/src/runtime/CL/gemm_auto_heuristics/CLGEMMAutoHeuristics.cpp b/src/runtime/CL/gemm_auto_heuristics/CLGEMMAutoHeuristics.cpp
index 5790e077d4..c5618f2cce 100644
--- a/src/runtime/CL/gemm_auto_heuristics/CLGEMMAutoHeuristics.cpp
+++ b/src/runtime/CL/gemm_auto_heuristics/CLGEMMAutoHeuristics.cpp
@@ -29,6 +29,7 @@
 #include "arm_compute/runtime/CL/ICLGEMMKernelSelection.h"
 #include "src/core/CL/ICLGEMMKernelConfiguration.h"
 #include "src/core/CL/gemm/CLGEMMHelpers.h"
+#include "src/core/CL/gemm/native/CLGEMMNativeKernelConfiguration.h"
 #include "src/core/CL/gemm/reshaped/CLGEMMReshapedKernelConfiguration.h"
 #include "src/core/CL/gemm/reshaped_only_rhs/CLGEMMReshapedOnlyRHSKernelConfiguration.h"
 #include "src/runtime/CL/gemm/CLGEMMKernelSelection.h"
@@ -100,6 +101,7 @@ GEMMConfigResult select_mlgo_gemm_config_reshaped_only_rhs(const CommonQuery &qu
     {
         ARM_COMPUTE_LOG_INFO_MSG_CORE("MLGOHeuristics query failed");
     }
+    // Setting irrelevant unsigned int parameters to 1 and bool parameters to false as they do not matter
     std::tie(lhs_info, rhs_info) = configure_lhs_rhs_info(query.m, query.n, config.m0, config.n0, config.k0, 1, config.h0, false, config.interleave_rhs, !config.transpose_rhs, config.transpose_rhs,
                                                           config.export_cl_image);
     return GEMMConfigResult{ valid, lhs_info, rhs_info };
@@ -139,6 +141,40 @@ GEMMConfigResult select_mlgo_gemm_config_reshaped(const CommonQuery &query)
                                                           config.export_cl_image);
     return GEMMConfigResult{ valid, lhs_info, rhs_info };
 }
+
+GEMMConfigResult select_default_gemm_config_native(const CommonQuery &query)
+{
+    GEMMLHSMatrixInfo lhs_info;
+    GEMMRHSMatrixInfo rhs_info;
+    std::unique_ptr<ICLGEMMKernelConfiguration> gemm_config = CLGEMMNativeKernelConfigurationFactory::create(query.gpu_target);
+    ARM_COMPUTE_ERROR_ON_NULLPTR(gemm_config.get());
+    std::tie(lhs_info, rhs_info) = gemm_config->configure(query.m, query.n, query.k, query.b, query.data_type);
+    return GEMMConfigResult{ true, lhs_info, rhs_info };
+}
+
+GEMMConfigResult select_mlgo_gemm_config_native(const CommonQuery &query)
+{
+    bool                   valid = false;
+    GEMMLHSMatrixInfo      lhs_info;
+    GEMMRHSMatrixInfo      rhs_info;
+    mlgo::GEMMConfigNative config{};
+    auto                   mlgo_heuristics = CLScheduler::get().gemm_heuristics();
+    if(mlgo_heuristics != nullptr)
+    {
+        std::tie(valid, config) = mlgo_heuristics->get()->query_gemm_config_native(mlgo::Query{ string_from_target(query.gpu_target), query.data_type, query.m, query.n, query.k, query.b });
+    }
+    if(valid)
+    {
+        ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("MLGOHeuristics query returns gemm config: %s.", to_string(config).c_str());
+    }
+    else
+    {
+        ARM_COMPUTE_LOG_INFO_MSG_CORE("MLGOHeuristics query failed");
+    }
+    // Setting irrelevant unsigned int parameters to 1 and bool parameters to false as they do not matter
+    std::tie(lhs_info, rhs_info) = configure_lhs_rhs_info(query.m, query.n, config.m0, config.n0, config.k0, 1, 1, false, false, false, false, false);
+    return GEMMConfigResult{ valid, lhs_info, rhs_info };
+}
 } // namespace auto_heuristics
 } // namespace cl_gemm
diff --git a/src/runtime/CL/gemm_auto_heuristics/CLGEMMAutoHeuristics.h b/src/runtime/CL/gemm_auto_heuristics/CLGEMMAutoHeuristics.h
index 7cb9cab220..486c8bd6cb 100644
--- a/src/runtime/CL/gemm_auto_heuristics/CLGEMMAutoHeuristics.h
+++ b/src/runtime/CL/gemm_auto_heuristics/CLGEMMAutoHeuristics.h
@@ -73,28 +73,40 @@ CLGEMMKernelType auto_select_gemm_kernel(const CommonQuery &query, bool reshape_
 
 /** Select gemm config based on mlgo heuristics
  * @param query Query
- * @return GEMMConfigResult
+ * @return GEMMConfigResult. Result is valid if bool(GEMMConfigResult) == true and invalid otherwise
  */
 GEMMConfigResult select_mlgo_gemm_config_reshaped_only_rhs(const CommonQuery &query);
 
 /** Select gemm config based on default heuristics
  * @param query Query
- * @return GEMMConfigResult
+ * @return GEMMConfigResult. Result is valid if bool(GEMMConfigResult) == true and invalid otherwise
  */
 GEMMConfigResult select_default_gemm_config_reshaped_only_rhs(const CommonQuery &query);
 
 /** Select gemm config based on mlgo heuristics
  * @param query Query
- * @return GEMMConfigResult
+ * @return GEMMConfigResult. Result is valid if bool(GEMMConfigResult) == true and invalid otherwise
  */
 GEMMConfigResult select_mlgo_gemm_config_reshaped(const CommonQuery &query);
 
 /** Select gemm config based on default heuristics
  * @param query Query
- * @return GEMMConfigResult
+ * @return GEMMConfigResult. Result is valid if bool(GEMMConfigResult) == true and invalid otherwise
  */
 GEMMConfigResult select_default_gemm_config_reshaped(const CommonQuery &query);
 
+/** Select gemm config based on mlgo heuristics
+ * @param query Query
+ * @return GEMMConfigResult. Result is valid if bool(GEMMConfigResult) == true and invalid otherwise
+ */
+GEMMConfigResult select_mlgo_gemm_config_native(const CommonQuery &query);
+
+/** Select gemm config based on default heuristics
+ * @param query Query
+ * @return GEMMConfigResult. Result is valid if bool(GEMMConfigResult) == true and invalid otherwise
+ */
+GEMMConfigResult select_default_gemm_config_native(const CommonQuery &query);
+
 } // namespace auto_heuristics
 } // namespace cl_gemm
 } // namespace arm_compute
-- 
cgit v1.2.1
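
The native-kernel selection logic added by this patch follows the same pattern as the existing reshaped and reshaped-only-rhs paths: query the MLGO heuristics first, validate the suggested configuration against the kernel, and fall back to the default heuristics when the query fails or validation rejects the suggestion. The stand-alone sketch below illustrates that pattern in isolation; every name in it (MockConfig, ConfigResult, select_mlgo, select_default, is_valid_for_kernel, auto_select) is a hypothetical stand-in for illustration, not a Compute Library API.

#include <iostream>

// Hypothetical stand-in for a kernel configuration (not a Compute Library type).
struct MockConfig
{
    int m0{ 1 };
    int n0{ 1 };
};

// Hypothetical stand-in for GEMMConfigResult: a config plus a validity flag.
struct ConfigResult
{
    bool       valid{ false };
    MockConfig config{};

    explicit operator bool() const
    {
        return valid;
    }
};

// Stand-in for select_mlgo_gemm_config_native(): may fail, e.g. when no MLGO
// heuristics have been loaded into the scheduler.
ConfigResult select_mlgo(bool mlgo_available)
{
    return mlgo_available ? ConfigResult{ true, { 4, 4 } } : ConfigResult{};
}

// Stand-in for select_default_gemm_config_native(): always yields a config.
ConfigResult select_default()
{
    return ConfigResult{ true, { 2, 2 } };
}

// Stand-in for validate_lhs_rhs_info_native(): an MLGO suggestion must still
// pass the kernel's own validation before it is trusted.
bool is_valid_for_kernel(const MockConfig &cfg)
{
    return cfg.m0 > 0 && cfg.n0 > 0;
}

// The selection pattern: MLGO first (validated), default heuristics otherwise.
MockConfig auto_select(bool mlgo_available)
{
    ConfigResult result = select_mlgo(mlgo_available);
    if(result && is_valid_for_kernel(result.config))
    {
        return result.config; // MLGO heuristics win when valid
    }
    return select_default().config; // fall back to default heuristics
}

int main()
{
    const MockConfig cfg = auto_select(/* mlgo_available = */ false);
    std::cout << "m0=" << cfg.m0 << " n0=" << cfg.n0 << "\n"; // prints m0=2 n0=2
    return 0;
}

The explicit operator bool mirrors the validity check described in the header comments above (bool(GEMMConfigResult) == true), which is what lets auto_select_gemm_config_native test if(config) before trusting the MLGO suggestion.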