diff options
Diffstat (limited to 'src/runtime/CL')
-rw-r--r-- | src/runtime/CL/functions/CLGEMM.cpp | 98 | ||||
-rw-r--r-- | src/runtime/CL/gemm_auto_heuristics/CLGEMMAutoHeuristics.cpp | 108 | ||||
-rw-r--r-- | src/runtime/CL/gemm_auto_heuristics/CLGEMMAutoHeuristics.h | 89 |
3 files changed, 264 insertions, 31 deletions
diff --git a/src/runtime/CL/functions/CLGEMM.cpp b/src/runtime/CL/functions/CLGEMM.cpp index 181ae2843b..dcb9cb23ec 100644 --- a/src/runtime/CL/functions/CLGEMM.cpp +++ b/src/runtime/CL/functions/CLGEMM.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017-2020 Arm Limited. + * Copyright (c) 2017-2021 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -29,6 +29,7 @@ #include "arm_compute/core/GPUTarget.h" #include "arm_compute/core/Helpers.h" #include "arm_compute/core/KernelDescriptors.h" +#include "arm_compute/core/Log.h" #include "arm_compute/core/TensorInfo.h" #include "arm_compute/core/Types.h" #include "arm_compute/core/Utils.h" @@ -47,7 +48,9 @@ #include "src/core/helpers/AutoConfiguration.h" #include "src/core/utils/helpers/float_ops.h" #include "src/runtime/CL/gemm/CLGEMMKernelSelection.h" +#include "src/runtime/CL/gemm_auto_heuristics/CLGEMMAutoHeuristics.h" #include "support/Cast.h" +#include "utils/TypePrinter.h" namespace arm_compute { @@ -97,6 +100,55 @@ void CLGEMMReshapeRHSMatrixKernelManaged::configure(const CLCompileContext &comp } } // namespace weights_transformations +namespace +{ +inline bool validate_lhs_rhs_info_reshaped_only_rhs(const GEMMLHSMatrixInfo &lhs_info, const GEMMRHSMatrixInfo &rhs_info, const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, + const ITensorInfo *output, GEMMKernelInfo gemm_kernel_info) +{ + // Validate GEMMLHSMatrixInfo and GEMMRHSMatrixInfo for reshaped only rhs kernel + TensorInfo tmp_b_info{}; + // Validate reshape RHS kernel + auto_init_if_empty(tmp_b_info, b->clone()->set_tensor_shape(compute_rhs_reshaped_shape(*b, rhs_info))); + if(!bool(CLGEMMReshapeRHSMatrixKernel::validate(b, &tmp_b_info, rhs_info))) + { + return false; + } + // Validate mm kernel + gemm_kernel_info.lhs_info = lhs_info; + gemm_kernel_info.rhs_info = rhs_info; + gemm_kernel_info.has_pad_y = false; + if(!bool(CLGEMMMatrixMultiplyReshapedOnlyRHSKernel::validate(a, &tmp_b_info, c, output, 1.f, 0.f, lhs_info, rhs_info, gemm_kernel_info))) + { + return false; + } + gemm_kernel_info.has_pad_y = true; + if(!bool(CLGEMMMatrixMultiplyReshapedOnlyRHSKernel::validate(a, &tmp_b_info, c, output, 1.f, 0.f, lhs_info, rhs_info, gemm_kernel_info))) + { + return false; + } + return true; +} + +inline std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> auto_select_gemm_config_reshaped_only_rhs(auto_heuristics::CommonQuery query, GEMMKernelInfo kernel_info, const ITensorInfo *a, + const ITensorInfo *b, + const ITensorInfo *c, const ITensorInfo *output) +{ + auto config = auto_heuristics::select_mlgo_gemm_config_reshaped_only_rhs(query); + if(config) + { + if(validate_lhs_rhs_info_reshaped_only_rhs(config.lhs_info, config.rhs_info, a, b, c, output, kernel_info)) + { + ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("Use reshaped_only_rhs config from mlgo heuristics: LHS info: %s ; RHS info: %s ", to_string(config.lhs_info).c_str(), to_string(config.rhs_info).c_str()); + return { config.lhs_info, config.rhs_info }; + } + } + config = select_default_gemm_config_reshaped_only_rhs(query); + ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("Use reshaped_only_rhs config from default heuristics: LHS info: %s ; RHS info: %s ", to_string(config.lhs_info).c_str(), to_string(config.rhs_info).c_str()); + return { config.lhs_info, config.rhs_info }; +} + +} // namespace + CLGEMM::CLGEMM(std::shared_ptr<IMemoryManager> memory_manager, IWeightsManager *weights_manager) : _memory_group(std::move(memory_manager)), _weights_manager(weights_manager), @@ -120,22 +172,6 @@ CLGEMM::CLGEMM(std::shared_ptr<IMemoryManager> memory_manager, IWeightsManager * CLGEMM::~CLGEMM() = default; -CLGEMMKernelType CLGEMM::select_gemm_kernel(unsigned int m, unsigned int n, unsigned int k, unsigned int b, DataType data_type, bool reshape_b_only_on_first_run) -{ - std::unique_ptr<ICLGEMMKernelSelection> gemm_kernel = CLGEMMKernelSelectionFactory::create(CLScheduler::get().target()); - ARM_COMPUTE_ERROR_ON_NULLPTR(gemm_kernel.get()); - - CLGEMMKernelSelectionParams params; - params.m = m; - params.n = n; - params.k = k; - params.b = b; - params.is_rhs_constant = reshape_b_only_on_first_run; - params.data_type = data_type; - - return gemm_kernel->select_kernel(params); -} - void CLGEMM::configure_native_v1(const CLCompileContext &compile_context, const ICLTensor *a, const ICLTensor *b, const ICLTensor *c, ICLTensor *output, float alpha, float beta, const GEMMInfo &gemm_info) { @@ -277,7 +313,6 @@ void CLGEMM::configure_reshaped_v2(const CLCompileContext &compile_context, cons // Pick up the GEMM configuration std::unique_ptr<ICLGEMMKernelConfiguration> gemm_config = CLGEMMReshapedKernelConfigurationFactory::create(gpu_target); ARM_COMPUTE_ERROR_ON_NULLPTR(gemm_config.get()); - // Configure lhs_info and rhs_info std::tie(lhs_info, rhs_info) = gemm_config->configure(m, n, k, batch_size, data_type); @@ -343,11 +378,8 @@ void CLGEMM::configure_reshaped_only_rhs(const CLCompileContext &compile_context GEMMRHSMatrixInfo rhs_info{}; // Pick up the GEMM configuration - std::unique_ptr<ICLGEMMKernelConfiguration> gemm_config = CLGEMMReshapedOnlyRHSKernelConfigurationFactory::create(gpu_target); - ARM_COMPUTE_ERROR_ON_NULLPTR(gemm_config.get()); - - // Configure lhs_info and rhs_info - std::tie(lhs_info, rhs_info) = gemm_config->configure(m, n, k, batch_size, data_type); + std::tie(lhs_info, rhs_info) = auto_select_gemm_config_reshaped_only_rhs(auto_heuristics::CommonQuery{ gpu_target, data_type, m, n, k, batch_size }, kernel_info, a->info(), b->info(), + c == nullptr ? nullptr : c->info(), output->info()); ICLTensor *reshaped_rhs = &_tmp_b; if(_weights_manager && _weights_manager->are_weights_managed(b)) @@ -535,11 +567,12 @@ Status CLGEMM::validate_reshaped_only_rhs(const ITensorInfo *a, const ITensorInf GEMMRHSMatrixInfo rhs_info; // Pick up the GEMM configuration - std::unique_ptr<ICLGEMMKernelConfiguration> gemm_config = CLGEMMReshapedOnlyRHSKernelConfigurationFactory::create(gpu_target); - ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(gemm_config.get()); - - // Configure lhs_info and rhs_info - std::tie(lhs_info, rhs_info) = gemm_config->configure(m, n, k, batch_size, data_type); + // Note there is no need to validate the configuration from mlgo heuristics as it is already validated in configure() and will fall back + // to default heuristics should it fail + // NOTE: No need to validate mlgo configurations as they automatically fall back to default heuristics if validation fails + const auto gemm_config = select_default_gemm_config_reshaped_only_rhs(auto_heuristics::CommonQuery{ gpu_target, data_type, m, n, k, batch_size }); + lhs_info = gemm_config.lhs_info; + rhs_info = gemm_config.rhs_info; auto_init_if_empty(tmp_b_info, b->clone()->set_tensor_shape(compute_rhs_reshaped_shape(*b, rhs_info))); ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMReshapeRHSMatrixKernel::validate(b, &tmp_b_info, rhs_info)); @@ -573,7 +606,6 @@ void CLGEMM::configure(const CLCompileContext &compile_context, const ICLTensor _lhs = a; _dst = output; - // Get the GPU target bool reinterpret_input_as_3d = gemm_info.reinterpret_input_as_3d(); const unsigned int m = reinterpret_input_as_3d ? (a->info()->dimension(1) * a->info()->dimension(2)) : a->info()->dimension(1); const unsigned int n = b->info()->dimension(0); @@ -581,7 +613,7 @@ void CLGEMM::configure(const CLCompileContext &compile_context, const ICLTensor const unsigned int batch_size = reinterpret_input_as_3d ? a->info()->dimension(3) : a->info()->dimension(2); // Select GEMMType - _gemm_kernel_type = select_gemm_kernel(m, n, k, batch_size, a->info()->data_type(), _reshape_b_only_on_first_run); + _gemm_kernel_type = auto_select_gemm_kernel(auto_heuristics::CommonQuery{ CLScheduler::get().target(), a->info()->data_type(), m, n, k, batch_size }, _reshape_b_only_on_first_run); const bool fuse_add_c = (!(helpers::float_ops::is_zero(beta)) && c != nullptr); @@ -626,7 +658,11 @@ Status CLGEMM::validate(const ITensorInfo *a, const ITensorInfo *b, const ITenso const unsigned int batch_size = reinterpret_input_as_3d ? a->dimension(3) : a->dimension(2); // Select GEMMType - CLGEMMKernelType gemm_kernel_type = select_gemm_kernel(m, n, k, batch_size, a->data_type(), gemm_info.reshape_b_only_on_first_run()); + CLGEMMKernelType gemm_kernel_type = auto_select_gemm_kernel(auto_heuristics::CommonQuery + { + CLScheduler::get().target(), a->data_type(), m, n, k, batch_size, + }, + gemm_info.reshape_b_only_on_first_run()); const bool fuse_add_c = (!(helpers::float_ops::is_zero(beta)) && c != nullptr); diff --git a/src/runtime/CL/gemm_auto_heuristics/CLGEMMAutoHeuristics.cpp b/src/runtime/CL/gemm_auto_heuristics/CLGEMMAutoHeuristics.cpp new file mode 100644 index 0000000000..8f0d5f4953 --- /dev/null +++ b/src/runtime/CL/gemm_auto_heuristics/CLGEMMAutoHeuristics.cpp @@ -0,0 +1,108 @@ +/* + * Copyright (c) 2021 Arm Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ +#include "src/runtime/CL/gemm_auto_heuristics/CLGEMMAutoHeuristics.h" + +#include "arm_compute/core/Log.h" +#include "arm_compute/core/Validate.h" +#include "arm_compute/runtime/CL/CLScheduler.h" +#include "arm_compute/runtime/CL/ICLGEMMKernelSelection.h" +#include "src/core/CL/ICLGEMMKernelConfiguration.h" +#include "src/core/CL/gemm/CLGEMMHelpers.cpp" +#include "src/core/CL/gemm/reshaped_only_rhs/CLGEMMReshapedOnlyRHSKernelConfiguration.h" +#include "src/runtime/CL/gemm/CLGEMMKernelSelection.h" +#include "src/runtime/CL/mlgo/MLGOHeuristics.h" +#include "src/runtime/CL/mlgo/Utils.h" +#include "utils/TypePrinter.h" + +namespace arm_compute +{ +namespace cl_gemm +{ +namespace auto_heuristics +{ +CLGEMMKernelType auto_select_gemm_kernel(const CommonQuery &query, bool reshape_b_only_on_first_run) +{ + // Select between mlgo and default heuristics + auto mlgo_heuristics = CLScheduler::get().gemm_heuristics(); + if(mlgo_heuristics != nullptr) + { + auto res = mlgo_heuristics->get()->query_gemm_type(mlgo::Query{ string_from_target(query.gpu_target), query.data_type, query.m, query.n, query.k, query.b }); + if(res.first) + { + ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("Use gemm kernel from mlgo heuristics: %s.", to_string(res.second).c_str()); + return res.second; + } + } + std::unique_ptr<ICLGEMMKernelSelection> gemm_kernel = CLGEMMKernelSelectionFactory::create(query.gpu_target); + ARM_COMPUTE_ERROR_ON_NULLPTR(gemm_kernel.get()); + + CLGEMMKernelSelectionParams params; + params.m = query.m; + params.n = query.n; + params.k = query.k; + params.b = query.b; + params.is_rhs_constant = reshape_b_only_on_first_run; + params.data_type = query.data_type; + + const auto kernel_type = gemm_kernel->select_kernel(params); + ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("Use gemm kernel from default heuristics: %s.", to_string(kernel_type).c_str()); + return kernel_type; +} + +GEMMConfigResult select_default_gemm_config_reshaped_only_rhs(const CommonQuery &query) +{ + GEMMLHSMatrixInfo lhs_info; + GEMMRHSMatrixInfo rhs_info; + std::unique_ptr<ICLGEMMKernelConfiguration> gemm_config = CLGEMMReshapedOnlyRHSKernelConfigurationFactory::create(query.gpu_target); + ARM_COMPUTE_ERROR_ON_NULLPTR(gemm_config.get()); + std::tie(lhs_info, rhs_info) = gemm_config->configure(query.m, query.n, query.k, query.b, query.data_type); + return GEMMConfigResult{ true, lhs_info, rhs_info }; +} + +GEMMConfigResult select_mlgo_gemm_config_reshaped_only_rhs(const CommonQuery &query) +{ + bool valid = false; + GEMMLHSMatrixInfo lhs_info; + GEMMRHSMatrixInfo rhs_info; + mlgo::GEMMConfigReshapedOnlyRHS config{}; + auto mlgo_heuristics = CLScheduler::get().gemm_heuristics(); + if(mlgo_heuristics != nullptr) + { + std::tie(valid, config) = mlgo_heuristics->get()->query_gemm_config_reshaped_only_rhs(mlgo::Query{ string_from_target(query.gpu_target), query.data_type, query.m, query.n, query.k, query.b }); + } + if(valid) + { + ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("MLGOHeuristics query returns gemm config: %s.", to_string(config).c_str()); + } + else + { + ARM_COMPUTE_LOG_INFO_MSG_CORE("MLGOHeuristics query failed"); + } + std::tie(lhs_info, rhs_info) = configure_lhs_rhs_info(query.m, query.n, config.m0, config.n0, config.k0, 1, config.h0, false, config.interleave_rhs, !config.transpose_rhs, config.transpose_rhs, + config.export_cl_image); + return GEMMConfigResult{ valid, lhs_info, rhs_info }; +} +} // namespace auto_heuristics +} // namespace cl_gemm +} // namespace arm_compute
\ No newline at end of file diff --git a/src/runtime/CL/gemm_auto_heuristics/CLGEMMAutoHeuristics.h b/src/runtime/CL/gemm_auto_heuristics/CLGEMMAutoHeuristics.h new file mode 100644 index 0000000000..e4fa1a6234 --- /dev/null +++ b/src/runtime/CL/gemm_auto_heuristics/CLGEMMAutoHeuristics.h @@ -0,0 +1,89 @@ +/* + * Copyright (c) 2021 Arm Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ +#ifndef SRC_RUNTIME_CL_GEMM_AUTO_HEURISTICS_CL_GEMM_AUTO_HEURISTICS_H +#define SRC_RUNTIME_CL_GEMM_AUTO_HEURISTICS_CL_GEMM_AUTO_HEURISTICS_H + +#include "arm_compute/core/GPUTarget.h" +#include "arm_compute/core/Types.h" +#include "arm_compute/runtime/CL/CLTypes.h" + +namespace arm_compute +{ +namespace cl_gemm +{ +namespace auto_heuristics +{ +/** A collection of adaptor functions that enable the auto selection between mlgo-based heuristics and default heuristics */ + +/** Common query */ +struct CommonQuery +{ + GPUTarget gpu_target; /**< Which @ref GPUTarget to query about */ + DataType data_type; /**< Data type */ + unsigned int m; /**< Number of rows for the lhs matrix. Lhs matrix NOT transposed */ + unsigned int n; /**< Number of columns for the rhs matrix. Rhs matrix NOT transposed */ + unsigned int k; /**< Number of rows for the rhs matrix. Rhs matrix NOT transposed */ + unsigned int b; /**< Batch size */ +}; + +/** Result of querying about GEMM config ( @ref GEMMLHSMatrixInfo and @ref GEMMRHSMatrixInfo) */ +struct GEMMConfigResult +{ + GEMMConfigResult(bool valid, const GEMMLHSMatrixInfo &lhs_info, const GEMMRHSMatrixInfo &rhs_info) + : valid{ valid }, lhs_info{ lhs_info }, rhs_info{ rhs_info } + { + } + /** Test if the result is valid */ + operator bool() const + { + return valid; + } + bool valid; /** If the result is valid */ + GEMMLHSMatrixInfo lhs_info; /** @ref GEMMLHSMatrixInfo */ + GEMMRHSMatrixInfo rhs_info; /** @ref GEMMRHSMatrixInfo */ +}; + +/** Automatically select between mlgo and default heuristics to choose @ref CLGEMMKernelType + * @param query Query + * @param reshape_b_only_on_first_run Additional query parameter if reshape b only on first run + * @return CLGEMMKernelType + */ +CLGEMMKernelType auto_select_gemm_kernel(const CommonQuery &query, bool reshape_b_only_on_first_run); + +/** Select gemm config based on mlgo heuristics + * @param query Query + * @return GEMMConfigResult + */ +GEMMConfigResult select_mlgo_gemm_config_reshaped_only_rhs(const CommonQuery &query); + +/** Select gemm config based on default heuristics + * @param query Query + * @return GEMMConfigResult + */ +GEMMConfigResult select_default_gemm_config_reshaped_only_rhs(const CommonQuery &query); +} // namespace auto_heuristics +} // namespace cl_gemm +} // namespace arm_compute + +#endif // SRC_RUNTIME_CL_GEMM_AUTO_HEURISTICS_CL_GEMM_AUTO_HEURISTICS_H
\ No newline at end of file |