aboutsummaryrefslogtreecommitdiff
path: root/src/runtime/CL/functions/CLGEMM.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/runtime/CL/functions/CLGEMM.cpp')
-rw-r--r--src/runtime/CL/functions/CLGEMM.cpp71
1 files changed, 57 insertions, 14 deletions
diff --git a/src/runtime/CL/functions/CLGEMM.cpp b/src/runtime/CL/functions/CLGEMM.cpp
index dcb9cb23ec..a0aaabf5fe 100644
--- a/src/runtime/CL/functions/CLGEMM.cpp
+++ b/src/runtime/CL/functions/CLGEMM.cpp
@@ -37,9 +37,6 @@
#include "arm_compute/core/utils/misc/ShapeCalculator.h"
#include "arm_compute/runtime/CL/CLScheduler.h"
#include "arm_compute/runtime/ITensorAllocator.h"
-#include "src/core/CL/ICLGEMMKernelConfiguration.h"
-#include "src/core/CL/gemm/reshaped/CLGEMMReshapedKernelConfiguration.h"
-#include "src/core/CL/gemm/reshaped_only_rhs/CLGEMMReshapedOnlyRHSKernelConfiguration.h"
#include "src/core/CL/kernels/CLGEMMMatrixMultiplyKernel.h"
#include "src/core/CL/kernels/CLGEMMMatrixMultiplyReshapedKernel.h"
#include "src/core/CL/kernels/CLGEMMMatrixMultiplyReshapedOnlyRHSKernel.h"
@@ -102,6 +99,7 @@ void CLGEMMReshapeRHSMatrixKernelManaged::configure(const CLCompileContext &comp
namespace
{
+// Validate lhs_info and rhs_info for reshaped only rhs kernel
inline bool validate_lhs_rhs_info_reshaped_only_rhs(const GEMMLHSMatrixInfo &lhs_info, const GEMMRHSMatrixInfo &rhs_info, const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c,
const ITensorInfo *output, GEMMKernelInfo gemm_kernel_info)
{
@@ -129,6 +127,7 @@ inline bool validate_lhs_rhs_info_reshaped_only_rhs(const GEMMLHSMatrixInfo &lhs
return true;
}
+//Automatically select between mlgo (prioritized) and default heuristics for reshaped only rhs kernel configs
inline std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> auto_select_gemm_config_reshaped_only_rhs(auto_heuristics::CommonQuery query, GEMMKernelInfo kernel_info, const ITensorInfo *a,
const ITensorInfo *b,
const ITensorInfo *c, const ITensorInfo *output)
@@ -147,6 +146,55 @@ inline std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> auto_select_gemm_config_r
return { config.lhs_info, config.rhs_info };
}
+// Validate lhs_info and rhs_info for reshaped kernel
+inline bool validate_lhs_rhs_info_reshaped(const GEMMLHSMatrixInfo &lhs_info, const GEMMRHSMatrixInfo &rhs_info, const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c,
+ const ITensorInfo *output, GEMMKernelInfo gemm_kernel_info, bool reinterpret_input_as_3d)
+{
+ // Validate GEMMLHSMatrixInfo and GEMMRHSMatrixInfo for reshaped kernel
+ TensorInfo tmp_a_info{};
+ TensorInfo tmp_b_info{};
+
+ // Validate reshape LHS kernel
+ auto_init_if_empty(tmp_a_info, a->clone()->set_tensor_shape(compute_lhs_reshaped_shape(*a, lhs_info, reinterpret_input_as_3d)));
+ if(!bool(CLGEMMReshapeLHSMatrixKernel::validate(a, &tmp_a_info, lhs_info, reinterpret_input_as_3d)))
+ {
+ return false;
+ }
+
+ // Validate reshape RHS kernel
+ auto_init_if_empty(tmp_b_info, b->clone()->set_tensor_shape(compute_rhs_reshaped_shape(*b, rhs_info)));
+ if(!bool(CLGEMMReshapeRHSMatrixKernel::validate(b, &tmp_b_info, rhs_info)))
+ {
+ return false;
+ }
+ // Validate mm kernel
+ gemm_kernel_info.lhs_info = lhs_info;
+ gemm_kernel_info.rhs_info = rhs_info;
+ if(!bool(CLGEMMMatrixMultiplyReshapedKernel::validate(&tmp_a_info, &tmp_b_info, c, output, 1.f, 0.f, lhs_info, rhs_info, gemm_kernel_info)))
+ {
+ return false;
+ }
+ return true;
+}
+
+//Automatically select between mlgo (prioritized) and default heuristics for reshaped kernel configs
+inline std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> auto_select_gemm_config_reshaped(auto_heuristics::CommonQuery query, GEMMKernelInfo kernel_info, const ITensorInfo *a, const ITensorInfo *b,
+ const ITensorInfo *c, const ITensorInfo *output, bool reinterpret_input_as_3d)
+{
+ auto config = auto_heuristics::select_mlgo_gemm_config_reshaped(query);
+ if(config)
+ {
+ if(validate_lhs_rhs_info_reshaped(config.lhs_info, config.rhs_info, a, b, c, output, kernel_info, reinterpret_input_as_3d))
+ {
+ ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("Use reshaped config from mlgo heuristics: LHS info: %s ; RHS info: %s ", to_string(config.lhs_info).c_str(), to_string(config.rhs_info).c_str());
+ return { config.lhs_info, config.rhs_info };
+ }
+ }
+ config = select_default_gemm_config_reshaped(query);
+ ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("Use reshaped config from default heuristics: LHS info: %s ; RHS info: %s ", to_string(config.lhs_info).c_str(), to_string(config.rhs_info).c_str());
+ return { config.lhs_info, config.rhs_info };
+}
+
} // namespace
CLGEMM::CLGEMM(std::shared_ptr<IMemoryManager> memory_manager, IWeightsManager *weights_manager)
@@ -311,10 +359,8 @@ void CLGEMM::configure_reshaped_v2(const CLCompileContext &compile_context, cons
GEMMRHSMatrixInfo rhs_info{};
// Pick up the GEMM configuration
- std::unique_ptr<ICLGEMMKernelConfiguration> gemm_config = CLGEMMReshapedKernelConfigurationFactory::create(gpu_target);
- ARM_COMPUTE_ERROR_ON_NULLPTR(gemm_config.get());
- // Configure lhs_info and rhs_info
- std::tie(lhs_info, rhs_info) = gemm_config->configure(m, n, k, batch_size, data_type);
+ std::tie(lhs_info, rhs_info) = auto_select_gemm_config_reshaped(auto_heuristics::CommonQuery{ gpu_target, data_type, m, n, k, batch_size }, kernel_info, a->info(), b->info(),
+ c == nullptr ? nullptr : c->info(), output->info(), gemm_info.reinterpret_input_as_3d());
_reshape_lhs_kernel->configure(compile_context, a, &_tmp_a, lhs_info, gemm_info.reinterpret_input_as_3d());
@@ -518,11 +564,10 @@ Status CLGEMM::validate_reshaped(const ITensorInfo *a, const ITensorInfo *b, con
GEMMRHSMatrixInfo rhs_info;
// Pick up the GEMM configuration
- std::unique_ptr<ICLGEMMKernelConfiguration> gemm_config = CLGEMMReshapedKernelConfigurationFactory::create(gpu_target);
- ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(gemm_config.get());
-
- // Configure lhs_info and rhs_info
- std::tie(lhs_info, rhs_info) = gemm_config->configure(m, n, k, batch_size, data_type);
+ // NOTE: No need to validate mlgo configurations as they automatically fall back to default heuristics if validation fails
+ const auto gemm_config = select_default_gemm_config_reshaped(auto_heuristics::CommonQuery{ gpu_target, data_type, m, n, k, batch_size });
+ lhs_info = gemm_config.lhs_info;
+ rhs_info = gemm_config.rhs_info;
auto_init_if_empty(tmp_a_info, a->clone()->set_tensor_shape(compute_lhs_reshaped_shape(*a, lhs_info, gemm_info.reinterpret_input_as_3d())));
ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMReshapeLHSMatrixKernel::validate(a, &tmp_a_info, lhs_info, gemm_info.reinterpret_input_as_3d()));
@@ -567,8 +612,6 @@ Status CLGEMM::validate_reshaped_only_rhs(const ITensorInfo *a, const ITensorInf
GEMMRHSMatrixInfo rhs_info;
// Pick up the GEMM configuration
- // Note there is no need to validate the configuration from mlgo heuristics as it is already validated in configure() and will fall back
- // to default heuristics should it fail
// NOTE: No need to validate mlgo configurations as they automatically fall back to default heuristics if validation fails
const auto gemm_config = select_default_gemm_config_reshaped_only_rhs(auto_heuristics::CommonQuery{ gpu_target, data_type, m, n, k, batch_size });
lhs_info = gemm_config.lhs_info;