diff options
Diffstat (limited to 'src/gpu/cl/operators/ClGemm.h')
-rw-r--r-- | src/gpu/cl/operators/ClGemm.h | 29 |
1 files changed, 17 insertions, 12 deletions
diff --git a/src/gpu/cl/operators/ClGemm.h b/src/gpu/cl/operators/ClGemm.h index 3c0cad3ca4..aac463f0b8 100644 --- a/src/gpu/cl/operators/ClGemm.h +++ b/src/gpu/cl/operators/ClGemm.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2016-2021 Arm Limited. + * Copyright (c) 2016-2022 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -34,6 +34,7 @@ #include "src/gpu/cl/kernels/ClGemmMatrixMultiplyNativeKernel.h" #include "src/gpu/cl/kernels/ClGemmMatrixMultiplyReshapedKernel.h" #include "src/gpu/cl/kernels/ClGemmMatrixMultiplyReshapedOnlyRhsKernel.h" +#include "src/gpu/cl/kernels/ClGemmMatrixMultiplyReshapedOnlyRhsMMULKernel.h" #include "src/gpu/cl/kernels/ClGemmReshapeLhsMatrixKernel.h" #include "src/gpu/cl/kernels/ClGemmReshapeRhsMatrixKernel.h" @@ -50,6 +51,7 @@ namespace opencl * -# @ref kernels::ClGemmMatrixMultiplyNativeKernel (only if NATIVE is selected by the select_gemm_kernel method()) * -# @ref kernels::ClGemmMatrixMultiplyReshapedKernel (only if RESHAPED is selected by the select_gemm_kernel method()) * -# @ref kernels::ClGemmMatrixMultiplyReshapedOnlyRhsKernel (only if RESHAPED_ONLY_RHS is selected by the select_gemm_kernel method()) + * -# @ref kernels::ClGemmMatrixMultiplyReshapedOnlyRhsMMULKernel (only if RESHAPED_ONLY_RHS_MMUL is selected by the select_gemm_kernel method()) */ class ClGemm : public IClOperator { @@ -102,10 +104,12 @@ private: void configure_native(const CLCompileContext &compile_context, ITensorInfo *a, ITensorInfo *b, ITensorInfo *c, ITensorInfo *output, float alpha, float beta, const GEMMInfo &gemm_info); void configure_reshaped(const CLCompileContext &compile_context, ITensorInfo *a, ITensorInfo *b, ITensorInfo *c, ITensorInfo *output, float alpha, float beta, const GEMMInfo &gemm_info); void configure_reshaped_only_rhs(const CLCompileContext &compile_context, ITensorInfo *a, ITensorInfo *b, ITensorInfo *c, ITensorInfo *output, float alpha, float beta, const GEMMInfo &gemm_info); + void configure_reshaped_only_rhs_mmul(const CLCompileContext &compile_context, ITensorInfo *a, ITensorInfo *b, ITensorInfo *c, ITensorInfo *output, float alpha, float beta, const GEMMInfo &gemm_info); static Status validate_native(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *output, float alpha, float beta, const GEMMInfo &gemm_info); static Status validate_reshaped(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *output, float alpha, float beta, const GEMMInfo &gemm_info); static Status validate_reshaped_only_rhs(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *output, float alpha, float beta, const GEMMInfo &gemm_info); + static Status validate_reshaped_only_rhs_mmul(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *output, float alpha, float beta, const GEMMInfo &gemm_info); private: enum AuxTensorIdx @@ -116,17 +120,18 @@ private: }; private: - std::unique_ptr<kernels::ClGemmReshapeLhsMatrixKernel> _reshape_lhs_kernel; - std::unique_ptr<kernels::ClGemmReshapeRhsMatrixKernel> _reshape_rhs_kernel; - std::unique_ptr<kernels::ClGemmMatrixMultiplyNativeKernel> _mm_native_kernel; - std::unique_ptr<kernels::ClGemmMatrixMultiplyReshapedKernel> _mm_reshaped_kernel; - std::unique_ptr<kernels::ClGemmMatrixMultiplyReshapedOnlyRhsKernel> _mm_reshaped_only_rhs_kernel; - TensorInfo _tmp_a; - TensorInfo _tmp_b; - bool _reshape_b_only_on_first_run; - CLGEMMKernelType _gemm_kernel_type; - bool _is_prepared; - experimental::MemoryRequirements _aux_mem{}; + std::unique_ptr<kernels::ClGemmReshapeLhsMatrixKernel> _reshape_lhs_kernel; + std::unique_ptr<kernels::ClGemmReshapeRhsMatrixKernel> _reshape_rhs_kernel; + std::unique_ptr<kernels::ClGemmMatrixMultiplyNativeKernel> _mm_native_kernel; + std::unique_ptr<kernels::ClGemmMatrixMultiplyReshapedKernel> _mm_reshaped_kernel; + std::unique_ptr<kernels::ClGemmMatrixMultiplyReshapedOnlyRhsKernel> _mm_reshaped_only_rhs_kernel; + std::unique_ptr<kernels::ClGemmMatrixMultiplyReshapedOnlyRhsMMULKernel> _mm_reshaped_only_rhs_mmul_kernel; + TensorInfo _tmp_a; + TensorInfo _tmp_b; + bool _reshape_b_only_on_first_run; + CLGEMMKernelType _gemm_kernel_type; + bool _is_prepared; + experimental::MemoryRequirements _aux_mem{}; }; } // namespace opencl } // namespace arm_compute |