aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--Android.bp1
-rw-r--r--SConscript1
-rw-r--r--arm_compute/runtime/CL/functions/CLGEMM.h4
-rw-r--r--src/runtime/CL/functions/CLGEMM.cpp98
-rw-r--r--src/runtime/CL/gemm_auto_heuristics/CLGEMMAutoHeuristics.cpp108
-rw-r--r--src/runtime/CL/gemm_auto_heuristics/CLGEMMAutoHeuristics.h89
6 files changed, 267 insertions, 34 deletions
diff --git a/Android.bp b/Android.bp
index b531457673..31bc14b6b3 100644
--- a/Android.bp
+++ b/Android.bp
@@ -613,6 +613,7 @@ cc_library_static {
"src/runtime/CL/gemm/CLGEMMDefaultTypeBifrost.cpp",
"src/runtime/CL/gemm/CLGEMMDefaultTypeMidgard.cpp",
"src/runtime/CL/gemm/CLGEMMDefaultTypeValhall.cpp",
+ "src/runtime/CL/gemm_auto_heuristics/CLGEMMAutoHeuristics.cpp",
"src/runtime/CL/mlgo/HeuristicTree.cpp",
"src/runtime/CL/mlgo/MLGOHeuristics.cpp",
"src/runtime/CL/mlgo/MLGOParser.cpp",
diff --git a/SConscript b/SConscript
index 307a0983d8..7d20ffe3f1 100644
--- a/SConscript
+++ b/SConscript
@@ -227,6 +227,7 @@ if env['opencl']:
runtime_files += Glob('src/runtime/gpu/cl/*.cpp')
runtime_files += Glob('src/runtime/gpu/cl/operators/*.cpp')
runtime_files += Glob('src/runtime/CL/mlgo/*.cpp')
+ runtime_files += Glob('src/runtime/CL/gemm_auto_heuristics/*.cpp')
graph_files += Glob('src/graph/backends/CL/*.cpp')
diff --git a/arm_compute/runtime/CL/functions/CLGEMM.h b/arm_compute/runtime/CL/functions/CLGEMM.h
index 3d645bdbff..8a210a2ba5 100644
--- a/arm_compute/runtime/CL/functions/CLGEMM.h
+++ b/arm_compute/runtime/CL/functions/CLGEMM.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2016-2020 Arm Limited.
+ * Copyright (c) 2016-2021 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -183,8 +183,6 @@ public:
void prepare() override;
private:
- static CLGEMMKernelType select_gemm_kernel(unsigned int m, unsigned int n, unsigned int k, unsigned int b, DataType data_type, bool reshape_b_only_on_first_run);
-
void configure_native_v1(const CLCompileContext &compile_context, const ICLTensor *a, const ICLTensor *b, const ICLTensor *c, ICLTensor *output, float alpha, float beta, const GEMMInfo &gemm_info);
void configure_reshaped_v1(const CLCompileContext &compile_context, const ICLTensor *a, const ICLTensor *b, const ICLTensor *c, ICLTensor *output, float alpha, float beta, const GEMMInfo &gemm_info);
void configure_reshaped_v2(const CLCompileContext &compile_context, const ICLTensor *a, const ICLTensor *b, const ICLTensor *c, ICLTensor *output, float alpha, float beta, const GEMMInfo &gemm_info);
diff --git a/src/runtime/CL/functions/CLGEMM.cpp b/src/runtime/CL/functions/CLGEMM.cpp
index 181ae2843b..dcb9cb23ec 100644
--- a/src/runtime/CL/functions/CLGEMM.cpp
+++ b/src/runtime/CL/functions/CLGEMM.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017-2020 Arm Limited.
+ * Copyright (c) 2017-2021 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -29,6 +29,7 @@
#include "arm_compute/core/GPUTarget.h"
#include "arm_compute/core/Helpers.h"
#include "arm_compute/core/KernelDescriptors.h"
+#include "arm_compute/core/Log.h"
#include "arm_compute/core/TensorInfo.h"
#include "arm_compute/core/Types.h"
#include "arm_compute/core/Utils.h"
@@ -47,7 +48,9 @@
#include "src/core/helpers/AutoConfiguration.h"
#include "src/core/utils/helpers/float_ops.h"
#include "src/runtime/CL/gemm/CLGEMMKernelSelection.h"
+#include "src/runtime/CL/gemm_auto_heuristics/CLGEMMAutoHeuristics.h"
#include "support/Cast.h"
+#include "utils/TypePrinter.h"
namespace arm_compute
{
@@ -97,6 +100,55 @@ void CLGEMMReshapeRHSMatrixKernelManaged::configure(const CLCompileContext &comp
}
} // namespace weights_transformations
+namespace
+{
+inline bool validate_lhs_rhs_info_reshaped_only_rhs(const GEMMLHSMatrixInfo &lhs_info, const GEMMRHSMatrixInfo &rhs_info, const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c,
+ const ITensorInfo *output, GEMMKernelInfo gemm_kernel_info)
+{
+ // Validate GEMMLHSMatrixInfo and GEMMRHSMatrixInfo for reshaped only rhs kernel
+ TensorInfo tmp_b_info{};
+ // Validate reshape RHS kernel
+ auto_init_if_empty(tmp_b_info, b->clone()->set_tensor_shape(compute_rhs_reshaped_shape(*b, rhs_info)));
+ if(!bool(CLGEMMReshapeRHSMatrixKernel::validate(b, &tmp_b_info, rhs_info)))
+ {
+ return false;
+ }
+ // Validate mm kernel
+ gemm_kernel_info.lhs_info = lhs_info;
+ gemm_kernel_info.rhs_info = rhs_info;
+ gemm_kernel_info.has_pad_y = false;
+ if(!bool(CLGEMMMatrixMultiplyReshapedOnlyRHSKernel::validate(a, &tmp_b_info, c, output, 1.f, 0.f, lhs_info, rhs_info, gemm_kernel_info)))
+ {
+ return false;
+ }
+ gemm_kernel_info.has_pad_y = true;
+ if(!bool(CLGEMMMatrixMultiplyReshapedOnlyRHSKernel::validate(a, &tmp_b_info, c, output, 1.f, 0.f, lhs_info, rhs_info, gemm_kernel_info)))
+ {
+ return false;
+ }
+ return true;
+}
+
+inline std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> auto_select_gemm_config_reshaped_only_rhs(auto_heuristics::CommonQuery query, GEMMKernelInfo kernel_info, const ITensorInfo *a,
+ const ITensorInfo *b,
+ const ITensorInfo *c, const ITensorInfo *output)
+{
+ auto config = auto_heuristics::select_mlgo_gemm_config_reshaped_only_rhs(query);
+ if(config)
+ {
+ if(validate_lhs_rhs_info_reshaped_only_rhs(config.lhs_info, config.rhs_info, a, b, c, output, kernel_info))
+ {
+ ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("Use reshaped_only_rhs config from mlgo heuristics: LHS info: %s ; RHS info: %s ", to_string(config.lhs_info).c_str(), to_string(config.rhs_info).c_str());
+ return { config.lhs_info, config.rhs_info };
+ }
+ }
+ config = select_default_gemm_config_reshaped_only_rhs(query);
+ ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("Use reshaped_only_rhs config from default heuristics: LHS info: %s ; RHS info: %s ", to_string(config.lhs_info).c_str(), to_string(config.rhs_info).c_str());
+ return { config.lhs_info, config.rhs_info };
+}
+
+} // namespace
+
CLGEMM::CLGEMM(std::shared_ptr<IMemoryManager> memory_manager, IWeightsManager *weights_manager)
: _memory_group(std::move(memory_manager)),
_weights_manager(weights_manager),
@@ -120,22 +172,6 @@ CLGEMM::CLGEMM(std::shared_ptr<IMemoryManager> memory_manager, IWeightsManager *
CLGEMM::~CLGEMM() = default;
-CLGEMMKernelType CLGEMM::select_gemm_kernel(unsigned int m, unsigned int n, unsigned int k, unsigned int b, DataType data_type, bool reshape_b_only_on_first_run)
-{
- std::unique_ptr<ICLGEMMKernelSelection> gemm_kernel = CLGEMMKernelSelectionFactory::create(CLScheduler::get().target());
- ARM_COMPUTE_ERROR_ON_NULLPTR(gemm_kernel.get());
-
- CLGEMMKernelSelectionParams params;
- params.m = m;
- params.n = n;
- params.k = k;
- params.b = b;
- params.is_rhs_constant = reshape_b_only_on_first_run;
- params.data_type = data_type;
-
- return gemm_kernel->select_kernel(params);
-}
-
void CLGEMM::configure_native_v1(const CLCompileContext &compile_context, const ICLTensor *a, const ICLTensor *b, const ICLTensor *c, ICLTensor *output, float alpha, float beta,
const GEMMInfo &gemm_info)
{
@@ -277,7 +313,6 @@ void CLGEMM::configure_reshaped_v2(const CLCompileContext &compile_context, cons
// Pick up the GEMM configuration
std::unique_ptr<ICLGEMMKernelConfiguration> gemm_config = CLGEMMReshapedKernelConfigurationFactory::create(gpu_target);
ARM_COMPUTE_ERROR_ON_NULLPTR(gemm_config.get());
-
// Configure lhs_info and rhs_info
std::tie(lhs_info, rhs_info) = gemm_config->configure(m, n, k, batch_size, data_type);
@@ -343,11 +378,8 @@ void CLGEMM::configure_reshaped_only_rhs(const CLCompileContext &compile_context
GEMMRHSMatrixInfo rhs_info{};
// Pick up the GEMM configuration
- std::unique_ptr<ICLGEMMKernelConfiguration> gemm_config = CLGEMMReshapedOnlyRHSKernelConfigurationFactory::create(gpu_target);
- ARM_COMPUTE_ERROR_ON_NULLPTR(gemm_config.get());
-
- // Configure lhs_info and rhs_info
- std::tie(lhs_info, rhs_info) = gemm_config->configure(m, n, k, batch_size, data_type);
+ std::tie(lhs_info, rhs_info) = auto_select_gemm_config_reshaped_only_rhs(auto_heuristics::CommonQuery{ gpu_target, data_type, m, n, k, batch_size }, kernel_info, a->info(), b->info(),
+ c == nullptr ? nullptr : c->info(), output->info());
ICLTensor *reshaped_rhs = &_tmp_b;
if(_weights_manager && _weights_manager->are_weights_managed(b))
@@ -535,11 +567,12 @@ Status CLGEMM::validate_reshaped_only_rhs(const ITensorInfo *a, const ITensorInf
GEMMRHSMatrixInfo rhs_info;
// Pick up the GEMM configuration
- std::unique_ptr<ICLGEMMKernelConfiguration> gemm_config = CLGEMMReshapedOnlyRHSKernelConfigurationFactory::create(gpu_target);
- ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(gemm_config.get());
-
- // Configure lhs_info and rhs_info
- std::tie(lhs_info, rhs_info) = gemm_config->configure(m, n, k, batch_size, data_type);
+ // Note there is no need to validate the configuration from mlgo heuristics as it is already validated in configure() and will fall back
+ // to default heuristics should it fail
+ // NOTE: No need to validate mlgo configurations as they automatically fall back to default heuristics if validation fails
+ const auto gemm_config = select_default_gemm_config_reshaped_only_rhs(auto_heuristics::CommonQuery{ gpu_target, data_type, m, n, k, batch_size });
+ lhs_info = gemm_config.lhs_info;
+ rhs_info = gemm_config.rhs_info;
auto_init_if_empty(tmp_b_info, b->clone()->set_tensor_shape(compute_rhs_reshaped_shape(*b, rhs_info)));
ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMReshapeRHSMatrixKernel::validate(b, &tmp_b_info, rhs_info));
@@ -573,7 +606,6 @@ void CLGEMM::configure(const CLCompileContext &compile_context, const ICLTensor
_lhs = a;
_dst = output;
- // Get the GPU target
bool reinterpret_input_as_3d = gemm_info.reinterpret_input_as_3d();
const unsigned int m = reinterpret_input_as_3d ? (a->info()->dimension(1) * a->info()->dimension(2)) : a->info()->dimension(1);
const unsigned int n = b->info()->dimension(0);
@@ -581,7 +613,7 @@ void CLGEMM::configure(const CLCompileContext &compile_context, const ICLTensor
const unsigned int batch_size = reinterpret_input_as_3d ? a->info()->dimension(3) : a->info()->dimension(2);
// Select GEMMType
- _gemm_kernel_type = select_gemm_kernel(m, n, k, batch_size, a->info()->data_type(), _reshape_b_only_on_first_run);
+ _gemm_kernel_type = auto_select_gemm_kernel(auto_heuristics::CommonQuery{ CLScheduler::get().target(), a->info()->data_type(), m, n, k, batch_size }, _reshape_b_only_on_first_run);
const bool fuse_add_c = (!(helpers::float_ops::is_zero(beta)) && c != nullptr);
@@ -626,7 +658,11 @@ Status CLGEMM::validate(const ITensorInfo *a, const ITensorInfo *b, const ITenso
const unsigned int batch_size = reinterpret_input_as_3d ? a->dimension(3) : a->dimension(2);
// Select GEMMType
- CLGEMMKernelType gemm_kernel_type = select_gemm_kernel(m, n, k, batch_size, a->data_type(), gemm_info.reshape_b_only_on_first_run());
+ CLGEMMKernelType gemm_kernel_type = auto_select_gemm_kernel(auto_heuristics::CommonQuery
+ {
+ CLScheduler::get().target(), a->data_type(), m, n, k, batch_size,
+ },
+ gemm_info.reshape_b_only_on_first_run());
const bool fuse_add_c = (!(helpers::float_ops::is_zero(beta)) && c != nullptr);
diff --git a/src/runtime/CL/gemm_auto_heuristics/CLGEMMAutoHeuristics.cpp b/src/runtime/CL/gemm_auto_heuristics/CLGEMMAutoHeuristics.cpp
new file mode 100644
index 0000000000..8f0d5f4953
--- /dev/null
+++ b/src/runtime/CL/gemm_auto_heuristics/CLGEMMAutoHeuristics.cpp
@@ -0,0 +1,108 @@
+/*
+ * Copyright (c) 2021 Arm Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+#include "src/runtime/CL/gemm_auto_heuristics/CLGEMMAutoHeuristics.h"
+
+#include "arm_compute/core/Log.h"
+#include "arm_compute/core/Validate.h"
+#include "arm_compute/runtime/CL/CLScheduler.h"
+#include "arm_compute/runtime/CL/ICLGEMMKernelSelection.h"
+#include "src/core/CL/ICLGEMMKernelConfiguration.h"
+#include "src/core/CL/gemm/CLGEMMHelpers.cpp"
+#include "src/core/CL/gemm/reshaped_only_rhs/CLGEMMReshapedOnlyRHSKernelConfiguration.h"
+#include "src/runtime/CL/gemm/CLGEMMKernelSelection.h"
+#include "src/runtime/CL/mlgo/MLGOHeuristics.h"
+#include "src/runtime/CL/mlgo/Utils.h"
+#include "utils/TypePrinter.h"
+
+namespace arm_compute
+{
+namespace cl_gemm
+{
+namespace auto_heuristics
+{
+CLGEMMKernelType auto_select_gemm_kernel(const CommonQuery &query, bool reshape_b_only_on_first_run)
+{
+ // Select between mlgo and default heuristics
+ auto mlgo_heuristics = CLScheduler::get().gemm_heuristics();
+ if(mlgo_heuristics != nullptr)
+ {
+ auto res = mlgo_heuristics->get()->query_gemm_type(mlgo::Query{ string_from_target(query.gpu_target), query.data_type, query.m, query.n, query.k, query.b });
+ if(res.first)
+ {
+ ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("Use gemm kernel from mlgo heuristics: %s.", to_string(res.second).c_str());
+ return res.second;
+ }
+ }
+ std::unique_ptr<ICLGEMMKernelSelection> gemm_kernel = CLGEMMKernelSelectionFactory::create(query.gpu_target);
+ ARM_COMPUTE_ERROR_ON_NULLPTR(gemm_kernel.get());
+
+ CLGEMMKernelSelectionParams params;
+ params.m = query.m;
+ params.n = query.n;
+ params.k = query.k;
+ params.b = query.b;
+ params.is_rhs_constant = reshape_b_only_on_first_run;
+ params.data_type = query.data_type;
+
+ const auto kernel_type = gemm_kernel->select_kernel(params);
+ ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("Use gemm kernel from default heuristics: %s.", to_string(kernel_type).c_str());
+ return kernel_type;
+}
+
+GEMMConfigResult select_default_gemm_config_reshaped_only_rhs(const CommonQuery &query)
+{
+ GEMMLHSMatrixInfo lhs_info;
+ GEMMRHSMatrixInfo rhs_info;
+ std::unique_ptr<ICLGEMMKernelConfiguration> gemm_config = CLGEMMReshapedOnlyRHSKernelConfigurationFactory::create(query.gpu_target);
+ ARM_COMPUTE_ERROR_ON_NULLPTR(gemm_config.get());
+ std::tie(lhs_info, rhs_info) = gemm_config->configure(query.m, query.n, query.k, query.b, query.data_type);
+ return GEMMConfigResult{ true, lhs_info, rhs_info };
+}
+
+GEMMConfigResult select_mlgo_gemm_config_reshaped_only_rhs(const CommonQuery &query)
+{
+ bool valid = false;
+ GEMMLHSMatrixInfo lhs_info;
+ GEMMRHSMatrixInfo rhs_info;
+ mlgo::GEMMConfigReshapedOnlyRHS config{};
+ auto mlgo_heuristics = CLScheduler::get().gemm_heuristics();
+ if(mlgo_heuristics != nullptr)
+ {
+ std::tie(valid, config) = mlgo_heuristics->get()->query_gemm_config_reshaped_only_rhs(mlgo::Query{ string_from_target(query.gpu_target), query.data_type, query.m, query.n, query.k, query.b });
+ }
+ if(valid)
+ {
+ ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("MLGOHeuristics query returns gemm config: %s.", to_string(config).c_str());
+ }
+ else
+ {
+ ARM_COMPUTE_LOG_INFO_MSG_CORE("MLGOHeuristics query failed");
+ }
+ std::tie(lhs_info, rhs_info) = configure_lhs_rhs_info(query.m, query.n, config.m0, config.n0, config.k0, 1, config.h0, false, config.interleave_rhs, !config.transpose_rhs, config.transpose_rhs,
+ config.export_cl_image);
+ return GEMMConfigResult{ valid, lhs_info, rhs_info };
+}
+} // namespace auto_heuristics
+} // namespace cl_gemm
+} // namespace arm_compute \ No newline at end of file
diff --git a/src/runtime/CL/gemm_auto_heuristics/CLGEMMAutoHeuristics.h b/src/runtime/CL/gemm_auto_heuristics/CLGEMMAutoHeuristics.h
new file mode 100644
index 0000000000..e4fa1a6234
--- /dev/null
+++ b/src/runtime/CL/gemm_auto_heuristics/CLGEMMAutoHeuristics.h
@@ -0,0 +1,89 @@
+/*
+ * Copyright (c) 2021 Arm Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+#ifndef SRC_RUNTIME_CL_GEMM_AUTO_HEURISTICS_CL_GEMM_AUTO_HEURISTICS_H
+#define SRC_RUNTIME_CL_GEMM_AUTO_HEURISTICS_CL_GEMM_AUTO_HEURISTICS_H
+
+#include "arm_compute/core/GPUTarget.h"
+#include "arm_compute/core/Types.h"
+#include "arm_compute/runtime/CL/CLTypes.h"
+
+namespace arm_compute
+{
+namespace cl_gemm
+{
+namespace auto_heuristics
+{
+/** A collection of adaptor functions that enable the auto selection between mlgo-based heuristics and default heuristics */
+
+/** Common query */
+struct CommonQuery
+{
+ GPUTarget gpu_target; /**< Which @ref GPUTarget to query about */
+ DataType data_type; /**< Data type */
+ unsigned int m; /**< Number of rows for the lhs matrix. Lhs matrix NOT transposed */
+ unsigned int n; /**< Number of columns for the rhs matrix. Rhs matrix NOT transposed */
+ unsigned int k; /**< Number of rows for the rhs matrix. Rhs matrix NOT transposed */
+ unsigned int b; /**< Batch size */
+};
+
+/** Result of querying about GEMM config ( @ref GEMMLHSMatrixInfo and @ref GEMMRHSMatrixInfo) */
+struct GEMMConfigResult
+{
+ GEMMConfigResult(bool valid, const GEMMLHSMatrixInfo &lhs_info, const GEMMRHSMatrixInfo &rhs_info)
+ : valid{ valid }, lhs_info{ lhs_info }, rhs_info{ rhs_info }
+ {
+ }
+ /** Test if the result is valid */
+ operator bool() const
+ {
+ return valid;
+ }
+ bool valid; /** If the result is valid */
+ GEMMLHSMatrixInfo lhs_info; /** @ref GEMMLHSMatrixInfo */
+ GEMMRHSMatrixInfo rhs_info; /** @ref GEMMRHSMatrixInfo */
+};
+
+/** Automatically select between mlgo and default heuristics to choose @ref CLGEMMKernelType
+ * @param query Query
+ * @param reshape_b_only_on_first_run Additional query parameter if reshape b only on first run
+ * @return CLGEMMKernelType
+ */
+CLGEMMKernelType auto_select_gemm_kernel(const CommonQuery &query, bool reshape_b_only_on_first_run);
+
+/** Select gemm config based on mlgo heuristics
+ * @param query Query
+ * @return GEMMConfigResult
+ */
+GEMMConfigResult select_mlgo_gemm_config_reshaped_only_rhs(const CommonQuery &query);
+
+/** Select gemm config based on default heuristics
+ * @param query Query
+ * @return GEMMConfigResult
+ */
+GEMMConfigResult select_default_gemm_config_reshaped_only_rhs(const CommonQuery &query);
+} // namespace auto_heuristics
+} // namespace cl_gemm
+} // namespace arm_compute
+
+#endif // SRC_RUNTIME_CL_GEMM_AUTO_HEURISTICS_CL_GEMM_AUTO_HEURISTICS_H \ No newline at end of file