From 944170e1591ff23c9e6ede2201f0f6aba0f3439b Mon Sep 17 00:00:00 2001 From: Gian Marco Iodice Date: Mon, 24 Jun 2019 14:40:30 +0100 Subject: COMPMID-2172: Fuse bias addition with CLGEMMMatrixMultiplyNativeKernel Change-Id: I714b92ec001fc71172719b67fb66d490538b6948 Signed-off-by: Gian Marco Iodice Reviewed-on: https://review.mlplatform.org/c/1399 Reviewed-by: Giuseppe Rossini Comments-Addressed: Arm Jenkins Tested-by: Arm Jenkins --- .../core/CL/kernels/CLGEMMMatrixMultiplyNativeKernel.h | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) (limited to 'arm_compute/core/CL') diff --git a/arm_compute/core/CL/kernels/CLGEMMMatrixMultiplyNativeKernel.h b/arm_compute/core/CL/kernels/CLGEMMMatrixMultiplyNativeKernel.h index c611dc4c1f..79689a2894 100644 --- a/arm_compute/core/CL/kernels/CLGEMMMatrixMultiplyNativeKernel.h +++ b/arm_compute/core/CL/kernels/CLGEMMMatrixMultiplyNativeKernel.h @@ -48,8 +48,10 @@ public: * * @param[in] input0 Input tensor for the LHS matrix. Data type supported: F32/F16. The number of dimensions for the LHS matrix must be less or equal than 4. * @param[in] input1 Input tensor for the RHS matrix. Data type supported: same as @p input0. The number of dimensions for the RHS matrix must be less or equal than 3. + * @param[in] input2 Input tensor containing the bias matrix. Data type supported: same as @p input0. * @param[out] output Output tensor info. Data type supported: same as @p input0 * @param[in] alpha Weight of the matrix product + * @param[in] beta Weight of the matrix bias * @param[in] lhs_info LHS matrix information used to retrieve the number of rows and accumulations to be processed by each thread. Only the following values are supported: * lhs_info.m0: 1,2,3,4,5,6,7,8 * lhs_info.k0: 2,3,4,8,16 @@ -58,14 +60,17 @@ public: * rhs_info.k0: same of lhs_info.k0 * @param[in] gemm_info GEMM information used to retrieve the original dimensions of the input matrices */ - void configure(const ICLTensor *input0, const ICLTensor *input1, ICLTensor *output, float alpha, const GEMMLHSMatrixInfo &lhs_info, const GEMMRHSMatrixInfo &rhs_info, + void configure(const ICLTensor *input0, const ICLTensor *input1, const ICLTensor *input2, ICLTensor *output, float alpha, float beta, const GEMMLHSMatrixInfo &lhs_info, + const GEMMRHSMatrixInfo &rhs_info, const GEMMReshapeInfo &gemm_info); /** Static function to check if given info will lead to a valid configuration of @ref CLGEMMMatrixMultiplyNativeKernel * * @param[in] input0 Input tensor info for the LHS matrix. Data type supported: F32/F16. The number of dimensions for the LHS matrix must be less or equal than 4. * @param[in] input1 Input tensor info for the RHS matrix. Data type supported: same as @p input0. The number of dimensions for the RHS matrix must be less or equal than 3. + * @param[in] input2 Input tensor info containing the bias matrix. Data type supported: same as @p input0. * @param[in] output Output tensor info. Data type supported: same as @p input0 * @param[in] alpha Weight of the matrix product + * @param[in] beta Weight of the matrix bias * @param[in] lhs_info LHS matrix information used to retrieve the number of rows and accumulations to be processed by each thread. Only the following values are supported: * lhs_info.m0: 1,2,3,4,5,6,7,8 * lhs_info.k0: 2,3,4,8,16 @@ -76,7 +81,8 @@ public: * * @return a status */ - static Status validate(const ITensorInfo *input0, const ITensorInfo *input1, const ITensorInfo *output, float alpha, const GEMMLHSMatrixInfo &lhs_info, const GEMMRHSMatrixInfo &rhs_info, + static Status validate(const ITensorInfo *input0, const ITensorInfo *input1, const ITensorInfo *input2, const ITensorInfo *output, float alpha, float beta, const GEMMLHSMatrixInfo &lhs_info, + const GEMMRHSMatrixInfo &rhs_info, const GEMMReshapeInfo &gemm_info); // Inherited methods overridden: @@ -85,11 +91,14 @@ public: private: const ICLTensor *_input0; const ICLTensor *_input1; + const ICLTensor *_input2; ICLTensor *_output; bool _slide_matrix_b; bool _reinterpret_input_as_3d; bool _reinterpret_output_as_3d; bool _use_dummy_work_items; + bool _add_bias; + bool _broadcast_bias; }; } // namespace arm_compute #endif /*__ARM_COMPUTE_CLGEMMMATRIXMULTIPLYNATIVEKERNEL_H__*/ -- cgit v1.2.1