aboutsummaryrefslogtreecommitdiff
path: root/arm_compute/runtime/CL/functions/CLGEMM.h
diff options
context:
space:
mode:
Diffstat (limited to 'arm_compute/runtime/CL/functions/CLGEMM.h')
-rw-r--r--arm_compute/runtime/CL/functions/CLGEMM.h59
1 files changed, 41 insertions, 18 deletions
diff --git a/arm_compute/runtime/CL/functions/CLGEMM.h b/arm_compute/runtime/CL/functions/CLGEMM.h
index 0bad446551..8c462fa4cb 100644
--- a/arm_compute/runtime/CL/functions/CLGEMM.h
+++ b/arm_compute/runtime/CL/functions/CLGEMM.h
@@ -27,6 +27,7 @@
#include "arm_compute/core/CL/kernels/CLGEMMMatrixAdditionKernel.h"
#include "arm_compute/core/CL/kernels/CLGEMMMatrixMultiplyKernel.h"
#include "arm_compute/core/CL/kernels/CLGEMMMatrixMultiplyReshapedKernel.h"
+#include "arm_compute/core/CL/kernels/CLGEMMMatrixMultiplyReshapedOnlyRHSKernel.h"
#include "arm_compute/core/CL/kernels/CLGEMMReshapeLHSMatrixKernel.h"
#include "arm_compute/core/CL/kernels/CLGEMMReshapeRHSMatrixKernel.h"
#include "arm_compute/runtime/CL/CLMemoryGroup.h"
@@ -40,10 +41,11 @@ class ICLTensor;
/** Basic function to execute GEMM on OpenCL. This function calls the following OpenCL kernels:
*
- * -# @ref CLGEMMReshapeLHSMatrixKernel (only if the reshaped GEMM is selected by the heuristic model)
- * -# @ref CLGEMMReshapeRHSMatrixKernel (only if the reshaped GEMM is selected by the heuristic model)
- * -# @ref CLGEMMMatrixMultiplyKernel (if GPU target is NOT G76 or if the reshaped GEMM is NOT selected)
- * -# @ref CLGEMMMatrixMultiplyReshapedKernel (only if the reshaped GEMM is selected by the heuristic model and the GPU target IS Mali-G76)
+ * -# @ref CLGEMMReshapeLHSMatrixKernel (only if the RESHAPED_V1 is selected by the heuristic model)
+ * -# @ref CLGEMMReshapeRHSMatrixKernel (only if either the RESHAPED_V1 or RESHAPED_ONLY_RHS is selected by the select_gemm_type method())
+ * -# @ref CLGEMMMatrixMultiplyKernel (only if either the NATIVE or RESHAPED_V1 is selected by the select_gemm_type method())
+ * -# @ref CLGEMMMatrixMultiplyReshapedKernel (only if RESHAPED_V1 is selected by the select_gemm_type method())
+ * -# @ref CLGEMMMatrixMultiplyReshapedOnlyRHSKernel (only if RESHAPED_ONLY_RHS is selected by the select_gemm_type method())
* -# @ref CLGEMMMatrixAdditionKernel (if c != nullptr and beta != 0.0)
*
*/
@@ -102,20 +104,41 @@ public:
void prepare() override;
private:
- CLMemoryGroup _memory_group;
- CLGEMMMatrixMultiplyKernel _mm_kernel;
- CLGEMMMatrixAdditionKernel _ma_kernel;
- CLGEMMReshapeLHSMatrixKernel _reshape_lhs_kernel;
- CLGEMMReshapeRHSMatrixKernel _reshape_rhs_kernel;
- CLGEMMMatrixMultiplyReshapedKernel _mm_reshaped_kernel;
- CLTensor _tmp_a;
- CLTensor _tmp_b;
- const ICLTensor *_original_b;
- bool _is_interleaved_transposed;
- bool _run_addition;
- bool _reshape_b_only_on_first_run;
- bool _is_prepared;
- bool _is_new_gemm_reshaped; // Remove when COMPMID-1892 is completed
+ enum class GEMMType
+ {
+ NATIVE,
+ RESHAPED_V1,
+ RESHAPED_V2,
+ RESHAPED_ONLY_RHS
+ };
+
+ // TODO (COMPMID-2095)
+ static GEMMType select_gemm_type(unsigned int m, unsigned int n, unsigned int k, DataType data_type, bool reshape_b_only_on_first_run, GPUTarget gpu_target);
+
+ void configure_native(const ICLTensor *a, const ICLTensor *b, const ICLTensor *c, ICLTensor *output, float alpha, float beta, const GEMMInfo &gemm_info);
+ void configure_reshaped_v1(const ICLTensor *a, const ICLTensor *b, const ICLTensor *c, ICLTensor *output, float alpha, float beta, const GEMMInfo &gemm_info);
+ void configure_reshaped_v2(const ICLTensor *a, const ICLTensor *b, const ICLTensor *c, ICLTensor *output, float alpha, float beta, const GEMMInfo &gemm_info);
+ void configure_reshaped_only_rhs(const ICLTensor *a, const ICLTensor *b, const ICLTensor *c, ICLTensor *output, float alpha, float beta, const GEMMInfo &gemm_info);
+
+ static Status validate_native(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *output, float alpha, float beta, const GEMMInfo &gemm_info);
+ static Status validate_reshaped_v1(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *output, float alpha, float beta, const GEMMInfo &gemm_info);
+ static Status validate_reshaped_v2(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *output, float alpha, float beta, const GEMMInfo &gemm_info);
+ static Status validate_reshaped_only_rhs(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *output, float alpha, float beta, const GEMMInfo &gemm_info);
+
+ CLMemoryGroup _memory_group;
+ CLGEMMMatrixMultiplyKernel _mm_kernel;
+ CLGEMMMatrixAdditionKernel _ma_kernel;
+ CLGEMMReshapeLHSMatrixKernel _reshape_lhs_kernel;
+ CLGEMMReshapeRHSMatrixKernel _reshape_rhs_kernel;
+ CLGEMMMatrixMultiplyReshapedKernel _mm_reshaped_kernel;
+ CLGEMMMatrixMultiplyReshapedOnlyRHSKernel _mm_reshaped_only_rhs_kernel;
+ CLTensor _tmp_a;
+ CLTensor _tmp_b;
+ const ICLTensor *_original_b;
+ bool _run_addition;
+ bool _reshape_b_only_on_first_run;
+ bool _is_prepared;
+ GEMMType _gemm_type;
};
} // namespace arm_compute