diff options
Diffstat (limited to 'src/gpu/cl/kernels/gemm/ClGemmHelpers.h')
-rw-r--r-- | src/gpu/cl/kernels/gemm/ClGemmHelpers.h | 16 |
1 files changed, 15 insertions, 1 deletions
diff --git a/src/gpu/cl/kernels/gemm/ClGemmHelpers.h b/src/gpu/cl/kernels/gemm/ClGemmHelpers.h index bf1e8fce82..6689b10e69 100644 --- a/src/gpu/cl/kernels/gemm/ClGemmHelpers.h +++ b/src/gpu/cl/kernels/gemm/ClGemmHelpers.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2022 Arm Limited. + * Copyright (c) 2019-2023 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -35,6 +35,8 @@ namespace kernels { namespace gemm { +using GeMMConfigsMatrix = std::vector<std::vector<int32_t>>; + /** Configure @ref GEMMLHSMatrixInfo and @ref GEMMRHSMatrixInfo * * @param[in] m Number of rows (M) in the LHS matrix not reshaped @@ -103,6 +105,18 @@ Status validate_image2d_support_on_rhs(const ITensorInfo &tensor_reshaped_info, */ bool is_mmul_kernel_preferred(const unsigned int m, const unsigned int n, const unsigned int k, const unsigned int b, const DataType data_type, unsigned int &best_m0, unsigned int &best_n0); + +/** Find the preferred configurations for the LHS and RHS tensor using the GeMMConfigsMatrix provided by the user + * + * @param[in] configs List of best configurations for a limited number of GeMM shapes + * @param[in] m Number of rows of the LHS matrix + * @param[in] n Number of columns of the RHS matrix + * @param[in] k Number of columns of the LHS matrix, rows of the RHS matrix + * @param[in] b Batch size + * + * @return @ref GEMMLHSMatrixInfo and @ref GEMMRHSMatrixInfo + */ +std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> find_lhs_rhs_info(const GeMMConfigsMatrix &configs, unsigned int m, unsigned int n, unsigned int k, unsigned int b); } // namespace gemm } // namespace kernels } // namespace opencl |