aboutsummaryrefslogtreecommitdiff
path: root/src/gpu/cl/operators/ClGemm.h
diff options
context:
space:
mode:
Diffstat (limited to 'src/gpu/cl/operators/ClGemm.h')
-rw-r--r--src/gpu/cl/operators/ClGemm.h29
1 files changed, 17 insertions, 12 deletions
diff --git a/src/gpu/cl/operators/ClGemm.h b/src/gpu/cl/operators/ClGemm.h
index 3c0cad3ca4..aac463f0b8 100644
--- a/src/gpu/cl/operators/ClGemm.h
+++ b/src/gpu/cl/operators/ClGemm.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2016-2021 Arm Limited.
+ * Copyright (c) 2016-2022 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -34,6 +34,7 @@
#include "src/gpu/cl/kernels/ClGemmMatrixMultiplyNativeKernel.h"
#include "src/gpu/cl/kernels/ClGemmMatrixMultiplyReshapedKernel.h"
#include "src/gpu/cl/kernels/ClGemmMatrixMultiplyReshapedOnlyRhsKernel.h"
+#include "src/gpu/cl/kernels/ClGemmMatrixMultiplyReshapedOnlyRhsMMULKernel.h"
#include "src/gpu/cl/kernels/ClGemmReshapeLhsMatrixKernel.h"
#include "src/gpu/cl/kernels/ClGemmReshapeRhsMatrixKernel.h"
@@ -50,6 +51,7 @@ namespace opencl
* -# @ref kernels::ClGemmMatrixMultiplyNativeKernel (only if NATIVE is selected by the select_gemm_kernel method())
* -# @ref kernels::ClGemmMatrixMultiplyReshapedKernel (only if RESHAPED is selected by the select_gemm_kernel method())
* -# @ref kernels::ClGemmMatrixMultiplyReshapedOnlyRhsKernel (only if RESHAPED_ONLY_RHS is selected by the select_gemm_kernel method())
+ * -# @ref kernels::ClGemmMatrixMultiplyReshapedOnlyRhsMMULKernel (only if RESHAPED_ONLY_RHS_MMUL is selected by the select_gemm_kernel method())
*/
class ClGemm : public IClOperator
{
@@ -102,10 +104,12 @@ private:
void configure_native(const CLCompileContext &compile_context, ITensorInfo *a, ITensorInfo *b, ITensorInfo *c, ITensorInfo *output, float alpha, float beta, const GEMMInfo &gemm_info);
void configure_reshaped(const CLCompileContext &compile_context, ITensorInfo *a, ITensorInfo *b, ITensorInfo *c, ITensorInfo *output, float alpha, float beta, const GEMMInfo &gemm_info);
void configure_reshaped_only_rhs(const CLCompileContext &compile_context, ITensorInfo *a, ITensorInfo *b, ITensorInfo *c, ITensorInfo *output, float alpha, float beta, const GEMMInfo &gemm_info);
+ void configure_reshaped_only_rhs_mmul(const CLCompileContext &compile_context, ITensorInfo *a, ITensorInfo *b, ITensorInfo *c, ITensorInfo *output, float alpha, float beta, const GEMMInfo &gemm_info);
static Status validate_native(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *output, float alpha, float beta, const GEMMInfo &gemm_info);
static Status validate_reshaped(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *output, float alpha, float beta, const GEMMInfo &gemm_info);
static Status validate_reshaped_only_rhs(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *output, float alpha, float beta, const GEMMInfo &gemm_info);
+ static Status validate_reshaped_only_rhs_mmul(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *output, float alpha, float beta, const GEMMInfo &gemm_info);
private:
enum AuxTensorIdx
@@ -116,17 +120,18 @@ private:
};
private:
- std::unique_ptr<kernels::ClGemmReshapeLhsMatrixKernel> _reshape_lhs_kernel;
- std::unique_ptr<kernels::ClGemmReshapeRhsMatrixKernel> _reshape_rhs_kernel;
- std::unique_ptr<kernels::ClGemmMatrixMultiplyNativeKernel> _mm_native_kernel;
- std::unique_ptr<kernels::ClGemmMatrixMultiplyReshapedKernel> _mm_reshaped_kernel;
- std::unique_ptr<kernels::ClGemmMatrixMultiplyReshapedOnlyRhsKernel> _mm_reshaped_only_rhs_kernel;
- TensorInfo _tmp_a;
- TensorInfo _tmp_b;
- bool _reshape_b_only_on_first_run;
- CLGEMMKernelType _gemm_kernel_type;
- bool _is_prepared;
- experimental::MemoryRequirements _aux_mem{};
+ std::unique_ptr<kernels::ClGemmReshapeLhsMatrixKernel> _reshape_lhs_kernel;
+ std::unique_ptr<kernels::ClGemmReshapeRhsMatrixKernel> _reshape_rhs_kernel;
+ std::unique_ptr<kernels::ClGemmMatrixMultiplyNativeKernel> _mm_native_kernel;
+ std::unique_ptr<kernels::ClGemmMatrixMultiplyReshapedKernel> _mm_reshaped_kernel;
+ std::unique_ptr<kernels::ClGemmMatrixMultiplyReshapedOnlyRhsKernel> _mm_reshaped_only_rhs_kernel;
+ std::unique_ptr<kernels::ClGemmMatrixMultiplyReshapedOnlyRhsMMULKernel> _mm_reshaped_only_rhs_mmul_kernel;
+ TensorInfo _tmp_a;
+ TensorInfo _tmp_b;
+ bool _reshape_b_only_on_first_run;
+ CLGEMMKernelType _gemm_kernel_type;
+ bool _is_prepared;
+ experimental::MemoryRequirements _aux_mem{};
};
} // namespace opencl
} // namespace arm_compute