diff options
Diffstat (limited to 'arm_compute/runtime/CL/functions/CLGEMMLowpMatrixMultiplyCore.h')
-rw-r--r-- | arm_compute/runtime/CL/functions/CLGEMMLowpMatrixMultiplyCore.h | 90 |
1 file changed, 43 insertions, 47 deletions
diff --git a/arm_compute/runtime/CL/functions/CLGEMMLowpMatrixMultiplyCore.h b/arm_compute/runtime/CL/functions/CLGEMMLowpMatrixMultiplyCore.h index 4cc8899690..1b8e5dcc1d 100644 --- a/arm_compute/runtime/CL/functions/CLGEMMLowpMatrixMultiplyCore.h +++ b/arm_compute/runtime/CL/functions/CLGEMMLowpMatrixMultiplyCore.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017-2020 Arm Limited. + * Copyright (c) 2017-2021, 2023 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -24,24 +24,19 @@ #ifndef ARM_COMPUTE_CLGEMMLOWPMATRIXMULTIPLYCORE_H #define ARM_COMPUTE_CLGEMMLOWPMATRIXMULTIPLYCORE_H +#include "arm_compute/function_info/GEMMInfo.h" #include "arm_compute/runtime/CL/CLTensor.h" #include "arm_compute/runtime/IFunction.h" #include "arm_compute/runtime/MemoryGroup.h" +#include <memory> + namespace arm_compute { class CLCompileContext; class IMemoryManager; class ICLTensor; class ITensorInfo; -class CLDepthConvertLayerKernel; -class CLGEMMLowpMatrixMultiplyNativeKernel; -class CLGEMMLowpMatrixMultiplyReshapedOnlyRHSKernel; -class CLGEMMLowpOffsetContributionKernel; -class CLGEMMLowpOffsetContributionOutputStageKernel; -class CLGEMMLowpMatrixAReductionKernel; -class CLGEMMLowpMatrixBReductionKernel; -class CLGEMMReshapeRHSMatrixKernel; /** Basic function to execute GEMMLowpMatrixMultiplyCore on OpenCL. 
*/ class CLGEMMLowpMatrixMultiplyCore : public IFunction @@ -61,6 +56,26 @@ public: ~CLGEMMLowpMatrixMultiplyCore(); /** Initialise the kernel's inputs, output * + * Valid data layouts: + * - NHWC + * - NCHW + * + * Valid data type configurations: + * |src0 |src1 |src2 |dst | + * |:--------------|:------------------|:--------|:--------------| + * |QASYMM8 |QASYMM8 |S32 |QASYMM8 | + * |QASYMM8 |QSYMM8_PER_CHANNEL |S32 |QASYMM8 | + * |QASYMM8 |QSYMM8 |S32 |QASYMM8 | + * |QASYMM8 |QASYMM8 |S32 |S32 | + * |QASYMM8 |QSYMM8_PER_CHANNEL |S32 |S32 | + * |QASYMM8 |QSYMM8 |S32 |S32 | + * |QASYMM8_SIGNED |QASYMM8_SIGNED |S32 |QASYMM8_SIGNED | + * |QASYMM8_SIGNED |QSYMM8_PER_CHANNEL |S32 |QASYMM8_SIGNED | + * |QASYMM8_SIGNED |QSYMM8 |S32 |QASYMM8_SIGNED | + * |QASYMM8_SIGNED |QASYMM8_SIGNED |S32 |S32 | + * |QASYMM8_SIGNED |QSYMM8_PER_CHANNEL |S32 |S32 | + * |QASYMM8_SIGNED |QSYMM8 |S32 |S32 | + * * @note GEMMLowp: low precision GEMM kernel. [A * B + C] * This kernel performs the following computations: * @@ -76,7 +91,11 @@ public: * @param[in] gemm_info (Optional) Specifies if the matrix A and/or matrix B have been reshaped and * if the reshape of matrix B should be executed only for the first run */ - void configure(const ICLTensor *a, const ICLTensor *b, const ICLTensor *c, ICLTensor *output, const GEMMInfo &gemm_info = GEMMInfo()); + void configure(const ICLTensor *a, + const ICLTensor *b, + const ICLTensor *c, + ICLTensor *output, + const GEMMInfo &gemm_info = GEMMInfo()); /** Initialise the kernel's inputs, output * * @note GEMMLowp: low precision GEMM kernel. 
[A * B + C] @@ -95,7 +114,12 @@ public: * @param[in] gemm_info (Optional) Specifies if the matrix A and/or matrix B have been reshaped and * if the reshape of matrix B should be executed only for the first run */ - void configure(const CLCompileContext &compile_context, const ICLTensor *a, const ICLTensor *b, const ICLTensor *c, ICLTensor *output, const GEMMInfo &gemm_info = GEMMInfo()); + void configure(const CLCompileContext &compile_context, + const ICLTensor *a, + const ICLTensor *b, + const ICLTensor *c, + ICLTensor *output, + const GEMMInfo &gemm_info = GEMMInfo()); /** Static function to check if given info will lead to a valid configuration of @ref CLGEMMLowpMatrixMultiplyCore * * @param[in] a First input tensor info (Matrix A). Data type supported: QASYMM8. @@ -107,47 +131,19 @@ public: * * @return a status */ - static Status validate(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *output, const GEMMInfo &gemm_info = GEMMInfo()); + static Status validate(const ITensorInfo *a, + const ITensorInfo *b, + const ITensorInfo *c, + const ITensorInfo *output, + const GEMMInfo &gemm_info = GEMMInfo()); // Inherited methods overridden: void run() override; void prepare() override; private: - MemoryGroup _memory_group; - - // Kernels used - std::unique_ptr<CLDepthConvertLayerKernel> _weights_to_qasymm8; - std::unique_ptr<CLGEMMLowpMatrixMultiplyNativeKernel> _mm_native_kernel; - std::unique_ptr<CLGEMMLowpMatrixMultiplyReshapedOnlyRHSKernel> _mm_reshaped_only_rhs_kernel; - std::unique_ptr<CLGEMMReshapeRHSMatrixKernel> _mtx_b_reshape_kernel; - std::unique_ptr<CLGEMMLowpMatrixAReductionKernel> _mtx_a_reduction_kernel; - std::unique_ptr<CLGEMMLowpMatrixBReductionKernel> _mtx_b_reduction_kernel; - std::unique_ptr<CLGEMMLowpOffsetContributionKernel> _offset_contribution_kernel; - std::unique_ptr<CLGEMMLowpOffsetContributionOutputStageKernel> _offset_contribution_output_stage_kernel; - - // Temporary tensors - CLTensor _qasymm8_weights; 
- CLTensor _vector_sum_col; - CLTensor _vector_sum_row; - CLTensor _tmp_b; - CLTensor _mm_result_s32; - CLTensor _gemm_output_stage_multipliers; - CLTensor _gemm_output_stage_shifts; - - // Tensor pointers - const ICLTensor *_matrix_a; - const ICLTensor *_original_b; - const ICLTensor *_output; - - int32_t _a_offset; - int32_t _b_offset; - bool _is_gemm_reshaped; - bool _reshape_b_only_on_first_run; - bool _is_prepared; - bool _run_output_stage; - bool _convert_to_qasymm8; - bool _run_offset_contribution; + struct Impl; + std::unique_ptr<Impl> _impl; }; } // namespace arm_compute -#endif /*ARM_COMPUTE_CLGEMMLOWPMATRIXMULTIPLYCORE_H */
\ No newline at end of file +#endif /*ARM_COMPUTE_CLGEMMLOWPMATRIXMULTIPLYCORE_H */ |