From 926afe1c8ad6ba6a7bada62a4027fcb79d727104 Mon Sep 17 00:00:00 2001 From: Gian Marco Iodice Date: Tue, 19 Mar 2019 11:44:13 +0000 Subject: COMPMID-2097: Implement a heuristic to dispatch CLGEMMReshapedOnlyRHS kernel from CLGEMM Change-Id: I4170a80647b02501aa669e2c0347ddc39888ee76 Signed-off-by: Gian Marco Iodice Reviewed-on: https://review.mlplatform.org/c/928 Reviewed-by: Giuseppe Rossini Comments-Addressed: Arm Jenkins Tested-by: Arm Jenkins --- arm_compute/runtime/CL/functions/CLGEMM.h | 59 +++++++++++++++++++++---------- 1 file changed, 41 insertions(+), 18 deletions(-) (limited to 'arm_compute/runtime/CL/functions') diff --git a/arm_compute/runtime/CL/functions/CLGEMM.h b/arm_compute/runtime/CL/functions/CLGEMM.h index 0bad446551..8c462fa4cb 100644 --- a/arm_compute/runtime/CL/functions/CLGEMM.h +++ b/arm_compute/runtime/CL/functions/CLGEMM.h @@ -27,6 +27,7 @@ #include "arm_compute/core/CL/kernels/CLGEMMMatrixAdditionKernel.h" #include "arm_compute/core/CL/kernels/CLGEMMMatrixMultiplyKernel.h" #include "arm_compute/core/CL/kernels/CLGEMMMatrixMultiplyReshapedKernel.h" +#include "arm_compute/core/CL/kernels/CLGEMMMatrixMultiplyReshapedOnlyRHSKernel.h" #include "arm_compute/core/CL/kernels/CLGEMMReshapeLHSMatrixKernel.h" #include "arm_compute/core/CL/kernels/CLGEMMReshapeRHSMatrixKernel.h" #include "arm_compute/runtime/CL/CLMemoryGroup.h" @@ -40,10 +41,11 @@ class ICLTensor; /** Basic function to execute GEMM on OpenCL. 
This function calls the following OpenCL kernels: * - * -# @ref CLGEMMReshapeLHSMatrixKernel (only if the reshaped GEMM is selected by the heuristic model) - * -# @ref CLGEMMReshapeRHSMatrixKernel (only if the reshaped GEMM is selected by the heuristic model) - * -# @ref CLGEMMMatrixMultiplyKernel (if GPU target is NOT G76 or if the reshaped GEMM is NOT selected) - * -# @ref CLGEMMMatrixMultiplyReshapedKernel (only if the reshaped GEMM is selected by the heuristic model and the GPU target IS Mali-G76) + * -# @ref CLGEMMReshapeLHSMatrixKernel (only if the RESHAPED_V1 is selected by the heuristic model) + * -# @ref CLGEMMReshapeRHSMatrixKernel (only if either the RESHAPED_V1 or RESHAPED_ONLY_RHS is selected by the select_gemm_type() method) + * -# @ref CLGEMMMatrixMultiplyKernel (only if either the NATIVE or RESHAPED_V1 is selected by the select_gemm_type() method) + * -# @ref CLGEMMMatrixMultiplyReshapedKernel (only if RESHAPED_V1 is selected by the select_gemm_type() method) + * -# @ref CLGEMMMatrixMultiplyReshapedOnlyRHSKernel (only if RESHAPED_ONLY_RHS is selected by the select_gemm_type() method) * -# @ref CLGEMMMatrixAdditionKernel (if c != nullptr and beta != 0.0) * */ @@ -102,20 +104,41 @@ public: void prepare() override; private: - CLMemoryGroup _memory_group; - CLGEMMMatrixMultiplyKernel _mm_kernel; - CLGEMMMatrixAdditionKernel _ma_kernel; - CLGEMMReshapeLHSMatrixKernel _reshape_lhs_kernel; - CLGEMMReshapeRHSMatrixKernel _reshape_rhs_kernel; - CLGEMMMatrixMultiplyReshapedKernel _mm_reshaped_kernel; - CLTensor _tmp_a; - CLTensor _tmp_b; - const ICLTensor *_original_b; - bool _is_interleaved_transposed; - bool _run_addition; - bool _reshape_b_only_on_first_run; - bool _is_prepared; - bool _is_new_gemm_reshaped; // Remove when COMPMID-1892 is completed + enum class GEMMType + { + NATIVE, + RESHAPED_V1, + RESHAPED_V2, + RESHAPED_ONLY_RHS + }; + + // TODO (COMPMID-2095) + static GEMMType select_gemm_type(unsigned int m, unsigned int n, unsigned int k, DataType 
data_type, bool reshape_b_only_on_first_run, GPUTarget gpu_target); + + void configure_native(const ICLTensor *a, const ICLTensor *b, const ICLTensor *c, ICLTensor *output, float alpha, float beta, const GEMMInfo &gemm_info); + void configure_reshaped_v1(const ICLTensor *a, const ICLTensor *b, const ICLTensor *c, ICLTensor *output, float alpha, float beta, const GEMMInfo &gemm_info); + void configure_reshaped_v2(const ICLTensor *a, const ICLTensor *b, const ICLTensor *c, ICLTensor *output, float alpha, float beta, const GEMMInfo &gemm_info); + void configure_reshaped_only_rhs(const ICLTensor *a, const ICLTensor *b, const ICLTensor *c, ICLTensor *output, float alpha, float beta, const GEMMInfo &gemm_info); + + static Status validate_native(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *output, float alpha, float beta, const GEMMInfo &gemm_info); + static Status validate_reshaped_v1(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *output, float alpha, float beta, const GEMMInfo &gemm_info); + static Status validate_reshaped_v2(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *output, float alpha, float beta, const GEMMInfo &gemm_info); + static Status validate_reshaped_only_rhs(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *output, float alpha, float beta, const GEMMInfo &gemm_info); + + CLMemoryGroup _memory_group; + CLGEMMMatrixMultiplyKernel _mm_kernel; + CLGEMMMatrixAdditionKernel _ma_kernel; + CLGEMMReshapeLHSMatrixKernel _reshape_lhs_kernel; + CLGEMMReshapeRHSMatrixKernel _reshape_rhs_kernel; + CLGEMMMatrixMultiplyReshapedKernel _mm_reshaped_kernel; + CLGEMMMatrixMultiplyReshapedOnlyRHSKernel _mm_reshaped_only_rhs_kernel; + CLTensor _tmp_a; + CLTensor _tmp_b; + const ICLTensor *_original_b; + bool _run_addition; + bool _reshape_b_only_on_first_run; + bool _is_prepared; + GEMMType _gemm_type; }; } // namespace 
arm_compute -- cgit v1.2.1