diff options
Diffstat (limited to 'arm_compute/runtime/CL/functions/CLGEMMLowpMatrixMultiplyCore.h')
-rw-r--r-- | arm_compute/runtime/CL/functions/CLGEMMLowpMatrixMultiplyCore.h | 54 |
1 file changed, 31 insertions, 23 deletions
diff --git a/arm_compute/runtime/CL/functions/CLGEMMLowpMatrixMultiplyCore.h b/arm_compute/runtime/CL/functions/CLGEMMLowpMatrixMultiplyCore.h index f404ccdf4c..82f307a773 100644 --- a/arm_compute/runtime/CL/functions/CLGEMMLowpMatrixMultiplyCore.h +++ b/arm_compute/runtime/CL/functions/CLGEMMLowpMatrixMultiplyCore.h @@ -27,6 +27,7 @@ #include "arm_compute/core/CL/kernels/CLGEMMInterleave4x4Kernel.h" #include "arm_compute/core/CL/kernels/CLGEMMLowpMatrixMultiplyKernel.h" #include "arm_compute/core/CL/kernels/CLGEMMLowpOffsetContributionKernel.h" +#include "arm_compute/core/CL/kernels/CLGEMMLowpOffsetContributionOutputStageKernel.h" #include "arm_compute/core/CL/kernels/CLGEMMLowpReductionKernel.h" #include "arm_compute/core/CL/kernels/CLGEMMTranspose1xWKernel.h" #include "arm_compute/runtime/CL/CLMemoryGroup.h" @@ -45,7 +46,8 @@ class ICLTensor; * -# @ref CLGEMMLowpMatrixMultiplyKernel * -# @ref CLGEMMLowpMatrixAReductionKernel (if the offset of matrix B is not 0) * -# @ref CLGEMMLowpMatrixBReductionKernel (if the offset of matrix A is not 0) - * -# @ref CLGEMMLowpOffsetContributionKernel + * -# @ref CLGEMMLowpOffsetContributionKernel (if gemm_info.gemmlowp_output_stage == NONE) + * -# @ref CLGEMMLowpOffsetContributionOutputStageKernel (if gemm_info.gemmlowp_output_stage != NONE) * */ class CLGEMMLowpMatrixMultiplyCore : public IFunction @@ -63,54 +65,60 @@ public: CLGEMMLowpMatrixMultiplyCore &operator=(CLGEMMLowpMatrixMultiplyCore &&) = default; /** Initialise the kernel's inputs, output * - * @note GEMM_LOWP: low precision GEMM kernel + * @note GEMMLowp: low precision GEMM kernel. [A * B + C] * This kernel performs the following computations: * * -# Convert a values from QASYMM8 to int32 and add a_offset to each of them. * -# Convert b values from QASYMM8 to int32 add b_offset to each of them. * -# Compute the matrix product of the resulting a * b in int32. 
+ * -# Quantize to uint8 if gemm_info.gemmlowp_output_stage != NONE * * @param[in] a First input tensor (Matrix A). Data type supported: QASYMM8. * @param[in] b Second input tensor (Matrix B). Data type supported: same as @p a - * @param[out] output Output tensor. Data type supported: Data type supported: S32 + * @param[in] c Third input tensor (Matrix C). It can be a nullptr. Data type supported: S32 + * @param[out] output Output tensor. Data type supported: S32 or QASYMM8 if gemm_info.gemmlowp_output_stage != NONE * @param[in] gemm_info (Optional) Specifies if the matrix A and/or matrix B have been reshaped and * if the reshape of matrix B should be executed only for the first run */ - void configure(const ICLTensor *a, const ICLTensor *b, ICLTensor *output, const GEMMInfo &gemm_info = GEMMInfo()); + void configure(const ICLTensor *a, const ICLTensor *b, const ICLTensor *c, ICLTensor *output, const GEMMInfo &gemm_info = GEMMInfo()); /** Static function to check if given info will lead to a valid configuration of @ref CLGEMMLowpMatrixMultiplyCore * * @param[in] a First input tensor (Matrix A). Data type supported: QASYMM8. * @param[in] b Second input tensor (Matrix B). Data type supported: same as @p a - * @param[in] output Output tensor. Data type supported: Data type supported: S32 + * @param[in] c Third input tensor (Matrix C). It can be a nullptr. Data type supported: S32 + * @param[in] output Output tensor. 
Data type supported: S32 or QASYMM8 if gemm_info.gemmlowp_output_stage != NONE * @param[in] gemm_info (Optional) Specifies if the matrix A and/or matrix B have been reshaped and * if the reshape of matrix B should be executed only for the first run * * @return a status */ - static Status validate(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *output, const GEMMInfo &gemm_info = GEMMInfo()); + static Status validate(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *output, const GEMMInfo &gemm_info = GEMMInfo()); // Inherited methods overridden: void run() override; void prepare() override; private: - CLMemoryGroup _memory_group; - CLGEMMLowpMatrixMultiplyKernel _mm_kernel; - CLGEMMInterleave4x4Kernel _mtx_a_reshape_kernel; - CLGEMMTranspose1xWKernel _mtx_b_reshape_kernel; - CLGEMMLowpMatrixAReductionKernel _mtx_a_reduction_kernel; - CLGEMMLowpMatrixBReductionKernel _mtx_b_reduction_kernel; - CLGEMMLowpOffsetContributionKernel _offset_contribution_kernel; - CLTensor _vector_sum_col; - CLTensor _vector_sum_row; - CLTensor _tmp_a; - CLTensor _tmp_b; - const ICLTensor *_original_b; - int32_t _a_offset; - int32_t _b_offset; - bool _is_interleaved_transposed; - bool _reshape_b_only_on_first_run; - bool _is_prepared; + CLMemoryGroup _memory_group; + CLGEMMLowpMatrixMultiplyKernel _mm_kernel; + CLGEMMInterleave4x4Kernel _mtx_a_reshape_kernel; + CLGEMMTranspose1xWKernel _mtx_b_reshape_kernel; + CLGEMMLowpMatrixAReductionKernel _mtx_a_reduction_kernel; + CLGEMMLowpMatrixBReductionKernel _mtx_b_reduction_kernel; + CLGEMMLowpOffsetContributionKernel _offset_contribution_kernel; + CLGEMMLowpOffsetContributionOutputStageKernel _offset_contribution_output_stage_kernel; + CLTensor _vector_sum_col; + CLTensor _vector_sum_row; + CLTensor _tmp_a; + CLTensor _tmp_b; + CLTensor _mm_result_s32; + const ICLTensor *_original_b; + int32_t _a_offset; + int32_t _b_offset; + bool _is_interleaved_transposed; + bool 
_reshape_b_only_on_first_run; + bool _is_prepared; + bool _fuse_output_stage; }; } #endif /*__ARM_COMPUTE_CLGEMMLOWPMATRIXMULTIPLYCORE_H__ */ |