/* * Copyright (c) 2021 Arm Limited. * * SPDX-License-Identifier: MIT * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to * deal in the Software without restriction, including without limitation the * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or * sell copies of the Software, and to permit persons to whom the Software is * furnished to do so, subject to the following conditions: * * The above copyright notice and this permission notice shall be included in all * copies or substantial portions of the Software. * * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE * SOFTWARE. */ #include "src/runtime/CL/gemm_auto_heuristics/CLGEMMAutoHeuristics.h" #include "arm_compute/core/Log.h" #include "arm_compute/core/Validate.h" #include "arm_compute/runtime/CL/CLScheduler.h" #include "arm_compute/runtime/CL/ICLGEMMKernelSelection.h" #include "src/core/CL/ICLGEMMKernelConfiguration.h" #include "src/core/CL/gemm/CLGEMMHelpers.cpp" #include "src/core/CL/gemm/native/CLGEMMNativeKernelConfiguration.h" #include "src/core/CL/gemm/reshaped/CLGEMMReshapedKernelConfiguration.h" #include "src/core/CL/gemm/reshaped_only_rhs/CLGEMMReshapedOnlyRHSKernelConfiguration.h" #include "src/runtime/CL/gemm/CLGEMMKernelSelection.h" #include "src/runtime/CL/mlgo/MLGOHeuristics.h" #include "src/runtime/CL/mlgo/Utils.h" #include "utils/TypePrinter.h" namespace arm_compute { namespace cl_gemm { namespace auto_heuristics { CLGEMMKernelType auto_select_gemm_kernel(const CommonQuery &query, bool reshape_b_only_on_first_run) { // Select between mlgo and default heuristics auto mlgo_heuristics = CLScheduler::get().gemm_heuristics(); if(mlgo_heuristics != nullptr) { auto res = mlgo_heuristics->get()->query_gemm_type(mlgo::Query{ string_from_target(query.gpu_target), query.data_type, query.m, query.n, query.k, query.b }); if(res.first) { ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("Use gemm kernel from mlgo heuristics: %s.", to_string(res.second).c_str()); return res.second; } } std::unique_ptr gemm_kernel = CLGEMMKernelSelectionFactory::create(query.gpu_target); ARM_COMPUTE_ERROR_ON_NULLPTR(gemm_kernel.get()); CLGEMMKernelSelectionParams params; params.m = query.m; params.n = query.n; params.k = query.k; params.b = query.b; params.is_rhs_constant = reshape_b_only_on_first_run; params.data_type = query.data_type; const auto kernel_type = gemm_kernel->select_kernel(params); ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("Use gemm kernel from default heuristics: %s.", to_string(kernel_type).c_str()); return kernel_type; } GEMMConfigResult select_default_gemm_config_reshaped_only_rhs(const CommonQuery &query) { GEMMLHSMatrixInfo lhs_info; GEMMRHSMatrixInfo rhs_info; std::unique_ptr gemm_config = CLGEMMReshapedOnlyRHSKernelConfigurationFactory::create(query.gpu_target); ARM_COMPUTE_ERROR_ON_NULLPTR(gemm_config.get()); std::tie(lhs_info, rhs_info) = gemm_config->configure(query.m, query.n, query.k, query.b, query.data_type); return GEMMConfigResult{ true, lhs_info, rhs_info }; } GEMMConfigResult select_mlgo_gemm_config_reshaped_only_rhs(const CommonQuery &query) { bool valid = false; GEMMLHSMatrixInfo lhs_info; GEMMRHSMatrixInfo rhs_info; mlgo::GEMMConfigReshapedOnlyRHS config{}; auto mlgo_heuristics = CLScheduler::get().gemm_heuristics(); if(mlgo_heuristics != nullptr) { std::tie(valid, config) = mlgo_heuristics->get()->query_gemm_config_reshaped_only_rhs(mlgo::Query{ string_from_target(query.gpu_target), query.data_type, query.m, query.n, query.k, query.b }); } if(valid) { ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("MLGOHeuristics query returns gemm config: %s.", to_string(config).c_str()); } else { ARM_COMPUTE_LOG_INFO_MSG_CORE("MLGOHeuristics query failed"); } // Setting irrelevant unsigned int parameters to 1 and bool parameters to false as they do no matter std::tie(lhs_info, rhs_info) = configure_lhs_rhs_info(query.m, query.n, config.m0, config.n0, config.k0, 1, config.h0, false, config.interleave_rhs, !config.transpose_rhs, config.transpose_rhs, config.export_cl_image); return GEMMConfigResult{ valid, lhs_info, rhs_info }; } GEMMConfigResult select_default_gemm_config_reshaped(const CommonQuery &query) { GEMMLHSMatrixInfo lhs_info; GEMMRHSMatrixInfo rhs_info; std::unique_ptr gemm_config = CLGEMMReshapedKernelConfigurationFactory::create(query.gpu_target); ARM_COMPUTE_ERROR_ON_NULLPTR(gemm_config.get()); std::tie(lhs_info, rhs_info) = gemm_config->configure(query.m, query.n, query.k, query.b, query.data_type); return GEMMConfigResult{ true, lhs_info, rhs_info }; } GEMMConfigResult select_mlgo_gemm_config_reshaped(const CommonQuery &query) { bool valid = false; GEMMLHSMatrixInfo lhs_info; GEMMRHSMatrixInfo rhs_info; mlgo::GEMMConfigReshaped config{}; auto mlgo_heuristics = CLScheduler::get().gemm_heuristics(); if(mlgo_heuristics != nullptr) { std::tie(valid, config) = mlgo_heuristics->get()->query_gemm_config_reshaped(mlgo::Query{ string_from_target(query.gpu_target), query.data_type, query.m, query.n, query.k, query.b }); } if(valid) { ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("MLGOHeuristics query returns gemm config: %s.", to_string(config).c_str()); } else { ARM_COMPUTE_LOG_INFO_MSG_CORE("MLGOHeuristics query failed"); } std::tie(lhs_info, rhs_info) = configure_lhs_rhs_info(query.m, query.n, config.m0, config.n0, config.k0, config.v0, config.h0, config.interleave_lhs, config.interleave_rhs, !config.transpose_rhs, config.transpose_rhs, config.export_cl_image); return GEMMConfigResult{ valid, lhs_info, rhs_info }; } GEMMConfigResult select_default_gemm_config_native(const CommonQuery &query) { GEMMLHSMatrixInfo lhs_info; GEMMRHSMatrixInfo rhs_info; std::unique_ptr gemm_config = CLGEMMNativeKernelConfigurationFactory::create(query.gpu_target); ARM_COMPUTE_ERROR_ON_NULLPTR(gemm_config.get()); std::tie(lhs_info, rhs_info) = gemm_config->configure(query.m, query.n, query.k, query.b, query.data_type); return GEMMConfigResult{ true, lhs_info, rhs_info }; } GEMMConfigResult select_mlgo_gemm_config_native(const CommonQuery &query) { bool valid = false; GEMMLHSMatrixInfo lhs_info; GEMMRHSMatrixInfo rhs_info; mlgo::GEMMConfigNative config{}; auto mlgo_heuristics = CLScheduler::get().gemm_heuristics(); if(mlgo_heuristics != nullptr) { std::tie(valid, config) = mlgo_heuristics->get()->query_gemm_config_native(mlgo::Query{ string_from_target(query.gpu_target), query.data_type, query.m, query.n, query.k, query.b }); } if(valid) { ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("MLGOHeuristics query returns gemm config: %s.", to_string(config).c_str()); } else { ARM_COMPUTE_LOG_INFO_MSG_CORE("MLGOHeuristics query failed"); } // Setting irrelevant unsigned int parameters to 1 and bool parameters to false as they do no matter std::tie(lhs_info, rhs_info) = configure_lhs_rhs_info(query.m, query.n, config.m0, config.n0, config.k0, 1, 1, false, false, false, false, false); return GEMMConfigResult{ valid, lhs_info, rhs_info }; } } // namespace auto_heuristics } // namespace cl_gemm } // namespace arm_compute