aboutsummaryrefslogtreecommitdiff
path: root/src/gpu/cl/kernels/gemm/ClGemmHelpers.h
diff options
context:
space:
mode:
Diffstat (limited to 'src/gpu/cl/kernels/gemm/ClGemmHelpers.h')
-rw-r--r--src/gpu/cl/kernels/gemm/ClGemmHelpers.h16
1 files changed, 15 insertions, 1 deletions
diff --git a/src/gpu/cl/kernels/gemm/ClGemmHelpers.h b/src/gpu/cl/kernels/gemm/ClGemmHelpers.h
index bf1e8fce82..6689b10e69 100644
--- a/src/gpu/cl/kernels/gemm/ClGemmHelpers.h
+++ b/src/gpu/cl/kernels/gemm/ClGemmHelpers.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2019-2022 Arm Limited.
+ * Copyright (c) 2019-2023 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -35,6 +35,8 @@ namespace kernels
{
namespace gemm
{
+using GeMMConfigsMatrix = std::vector<std::vector<int32_t>>;
+
/** Configure @ref GEMMLHSMatrixInfo and @ref GEMMRHSMatrixInfo
*
* @param[in] m Number of rows (M) in the LHS matrix not reshaped
@@ -103,6 +105,18 @@ Status validate_image2d_support_on_rhs(const ITensorInfo &tensor_reshaped_info,
*/
bool is_mmul_kernel_preferred(const unsigned int m, const unsigned int n, const unsigned int k, const unsigned int b,
const DataType data_type, unsigned int &best_m0, unsigned int &best_n0);
+
+/** Find the preferred configurations for the LHS and RHS tensor using the GeMMConfigsMatrix provided by the user
+ *
+ * @param[in] configs List of best configurations for a limited number of GeMM shapes
+ * @param[in] m Number of rows of the LHS matrix
+ * @param[in] n Number of columns of the RHS matrix
+ * @param[in] k Number of columns of the LHS matrix, rows of the RHS matrix
+ * @param[in] b Batch size
+ *
+ * @return @ref GEMMLHSMatrixInfo and @ref GEMMRHSMatrixInfo
+ */
+std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> find_lhs_rhs_info(const GeMMConfigsMatrix &configs, unsigned int m, unsigned int n, unsigned int k, unsigned int b);
} // namespace gemm
} // namespace kernels
} // namespace opencl