aboutsummaryrefslogtreecommitdiff
path: root/src/runtime/CL/gemm_auto_heuristics/CLGEMMAutoHeuristics.h
diff options
context:
space:
mode:
Diffstat (limited to 'src/runtime/CL/gemm_auto_heuristics/CLGEMMAutoHeuristics.h')
-rw-r--r--src/runtime/CL/gemm_auto_heuristics/CLGEMMAutoHeuristics.h41
1 files changed, 32 insertions, 9 deletions
diff --git a/src/runtime/CL/gemm_auto_heuristics/CLGEMMAutoHeuristics.h b/src/runtime/CL/gemm_auto_heuristics/CLGEMMAutoHeuristics.h
index 486c8bd6cb..020237b7f4 100644
--- a/src/runtime/CL/gemm_auto_heuristics/CLGEMMAutoHeuristics.h
+++ b/src/runtime/CL/gemm_auto_heuristics/CLGEMMAutoHeuristics.h
@@ -47,6 +47,22 @@ struct CommonQuery
unsigned int b; /**< Batch size */
};
+/** Result of querying about GEMM type ( @ref CLGEMMKernelType) */
+struct GEMMTypeResult
+{
+ GEMMTypeResult(bool valid, CLGEMMKernelType gemm_type)
+ : valid{ valid }, gemm_type{ gemm_type }
+ {
+ }
+ /** Test if the result is valid */
+ operator bool() const
+ {
+ return valid;
+ }
+ bool valid; /** If the result is valid */
+ CLGEMMKernelType gemm_type; /** @ref CLGEMMKernelType */
+};
+
/** Result of querying about GEMM config ( @ref GEMMLHSMatrixInfo and @ref GEMMRHSMatrixInfo) */
struct GEMMConfigResult
{
@@ -64,46 +80,53 @@ struct GEMMConfigResult
GEMMRHSMatrixInfo rhs_info; /** @ref GEMMRHSMatrixInfo */
};
-/** Automatically select between mlgo and default heuristics to choose @ref CLGEMMKernelType
+/** Select gemm type based on mlgo heuristics
+ * @param query Query
+ * @param reshape_b_only_on_first_run Additional query parameter if reshape b only on first run
+ * @return GEMMTypeResult. Result is valid if bool(GEMMTypeResult) == true and invalid otherwise
+ */
+GEMMTypeResult select_mlgo_gemm_kernel(const CommonQuery &query, bool reshape_b_only_on_first_run);
+
+/** Select gemm type based on default heuristics
* @param query Query
* @param reshape_b_only_on_first_run Additional query parameter if reshape b only on first run
- * @return CLGEMMKernelType
+ * @return GEMMTypeResult. Result is valid if bool(GEMMTypeResult) == true and invalid otherwise
*/
-CLGEMMKernelType auto_select_gemm_kernel(const CommonQuery &query, bool reshape_b_only_on_first_run);
+GEMMTypeResult select_default_gemm_kernel(const CommonQuery &query, bool reshape_b_only_on_first_run);
/** Select gemm config based on mlgo heuristics
* @param query Query
- * @return GEMMConfigResult. Result is valid if bool(GEMMCOnfigResult) == true and invalid otherwise
+ * @return GEMMConfigResult. Result is valid if bool(GEMMConfigResult) == true and invalid otherwise
*/
GEMMConfigResult select_mlgo_gemm_config_reshaped_only_rhs(const CommonQuery &query);
/** Select gemm config based on default heuristics
* @param query Query
- * @return GEMMConfigResult. Result is valid if bool(GEMMCOnfigResult) == true and invalid otherwise
+ * @return GEMMConfigResult. Result is valid if bool(GEMMConfigResult) == true and invalid otherwise
*/
GEMMConfigResult select_default_gemm_config_reshaped_only_rhs(const CommonQuery &query);
/** Select gemm config based on mlgo heuristics
* @param query Query
- * @return GEMMConfigResult. Result is valid if bool(GEMMCOnfigResult) == true and invalid otherwise
+ * @return GEMMConfigResult. Result is valid if bool(GEMMConfigResult) == true and invalid otherwise
*/
GEMMConfigResult select_mlgo_gemm_config_reshaped(const CommonQuery &query);
/** Select gemm config based on default heuristics
* @param query Query
- * @return GEMMConfigResult. Result is valid if bool(GEMMCOnfigResult) == true and invalid otherwise
+ * @return GEMMConfigResult. Result is valid if bool(GEMMConfigResult) == true and invalid otherwise
*/
GEMMConfigResult select_default_gemm_config_reshaped(const CommonQuery &query);
/** Select gemm config based on mlgo heuristics
* @param query Query
- * @return GEMMConfigResult. Result is valid if bool(GEMMCOnfigResult) == true and invalid otherwise
+ * @return GEMMConfigResult. Result is valid if bool(GEMMConfigResult) == true and invalid otherwise
*/
GEMMConfigResult select_mlgo_gemm_config_native(const CommonQuery &query);
/** Select gemm config based on default heuristics
* @param query Query
- * @return GEMMConfigResult. Result is valid if bool(GEMMCOnfigResult) == true and invalid otherwise
+ * @return GEMMConfigResult. Result is valid if bool(GEMMConfigResult) == true and invalid otherwise
*/
GEMMConfigResult select_default_gemm_config_native(const CommonQuery &query);