diff options
Diffstat (limited to 'src/gpu/cl/kernels/gemm/ClGemmHelpers.h')
-rw-r--r-- | src/gpu/cl/kernels/gemm/ClGemmHelpers.h | 36 |
1 files changed, 28 insertions, 8 deletions
diff --git a/src/gpu/cl/kernels/gemm/ClGemmHelpers.h b/src/gpu/cl/kernels/gemm/ClGemmHelpers.h index 6689b10e69..84776fb207 100644 --- a/src/gpu/cl/kernels/gemm/ClGemmHelpers.h +++ b/src/gpu/cl/kernels/gemm/ClGemmHelpers.h @@ -54,8 +54,18 @@ using GeMMConfigsMatrix = std::vector<std::vector<int32_t>>; * * @return @ref GEMMLHSMatrixInfo and @ref GEMMRHSMatrixInfo */ -std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> configure_lhs_rhs_info(unsigned int m, unsigned int n, unsigned int m0, unsigned int n0, unsigned int k0, unsigned int v0, unsigned int h0, - bool lhs_interleave, bool rhs_interleave, bool lhs_transpose, bool rhs_transpose, bool export_to_cl_image = false); +std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> configure_lhs_rhs_info(unsigned int m, + unsigned int n, + unsigned int m0, + unsigned int n0, + unsigned int k0, + unsigned int v0, + unsigned int h0, + bool lhs_interleave, + bool rhs_interleave, + bool lhs_transpose, + bool rhs_transpose, + bool export_to_cl_image = false); /** Select @ref GEMMLHSMatrixInfo and @ref GEMMRHSMatrixInfo * @@ -72,9 +82,13 @@ std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> configure_lhs_rhs_info(unsigned * * @return @ref GEMMLHSMatrixInfo and @ref GEMMRHSMatrixInfo */ -std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> select_lhs_rhs_info(std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> info_img, - std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> info_buf, - unsigned int n, unsigned int k, unsigned int b, DataType data_type); +std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> +select_lhs_rhs_info(std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> info_img, + std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> info_buf, + unsigned int n, + unsigned int k, + unsigned int b, + DataType data_type); /** Update padding required to export the OpenCL buffer to OpenCL image2d * @@ -103,8 +117,13 @@ Status validate_image2d_support_on_rhs(const ITensorInfo &tensor_reshaped_info, * * @return true if MMUL kernel is preferred over kernels w/o MMUL, false otherwise */ -bool is_mmul_kernel_preferred(const unsigned int m, const unsigned int n, const unsigned int k, const unsigned int b, - const DataType data_type, unsigned int &best_m0, unsigned int &best_n0); +bool is_mmul_kernel_preferred(const unsigned int m, + const unsigned int n, + const unsigned int k, + const unsigned int b, + const DataType data_type, + unsigned int &best_m0, + unsigned int &best_n0); /** Find the preferred configurations for the LHS and RHS tensor using the GeMMConfigsMatrix provided by the user * @@ -116,7 +135,8 @@ bool is_mmul_kernel_preferred(const unsigned int m, const unsigned int n, const * * @return @ref GEMMLHSMatrixInfo and @ref GEMMRHSMatrixInfo */ -std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> find_lhs_rhs_info(const GeMMConfigsMatrix &configs, unsigned int m, unsigned int n, unsigned int k, unsigned int b); +std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> +find_lhs_rhs_info(const GeMMConfigsMatrix &configs, unsigned int m, unsigned int n, unsigned int k, unsigned int b); } // namespace gemm } // namespace kernels } // namespace opencl |