diff options
author | Gian Marco Iodice <gianmarco.iodice@arm.com> | 2019-06-24 14:40:30 +0100 |
---|---|---|
committer | Gian Marco Iodice <gianmarco.iodice@arm.com> | 2019-06-24 15:56:10 +0000 |
commit | 944170e1591ff23c9e6ede2201f0f6aba0f3439b (patch) | |
tree | 64d6b718c01458be04ca1b39c39704b78ce3b5d6 /arm_compute/core/CL/kernels | |
parent | 65383e21a5b82071229c6322bf65c47e3719b490 (diff) | |
download | ComputeLibrary-944170e1591ff23c9e6ede2201f0f6aba0f3439b.tar.gz |
COMPMID-2172: Fuse bias addition with CLGEMMMatrixMultiplyNativeKernel
Change-Id: I714b92ec001fc71172719b67fb66d490538b6948
Signed-off-by: Gian Marco Iodice <gianmarco.iodice@arm.com>
Reviewed-on: https://review.mlplatform.org/c/1399
Reviewed-by: Giuseppe Rossini <giuseppe.rossini@arm.com>
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Diffstat (limited to 'arm_compute/core/CL/kernels')
-rw-r--r-- | arm_compute/core/CL/kernels/CLGEMMMatrixMultiplyNativeKernel.h | 13 |
1 files changed, 11 insertions, 2 deletions
diff --git a/arm_compute/core/CL/kernels/CLGEMMMatrixMultiplyNativeKernel.h b/arm_compute/core/CL/kernels/CLGEMMMatrixMultiplyNativeKernel.h index c611dc4c1f..79689a2894 100644 --- a/arm_compute/core/CL/kernels/CLGEMMMatrixMultiplyNativeKernel.h +++ b/arm_compute/core/CL/kernels/CLGEMMMatrixMultiplyNativeKernel.h @@ -48,8 +48,10 @@ public: * * @param[in] input0 Input tensor for the LHS matrix. Data type supported: F32/F16. The number of dimensions for the LHS matrix must be less or equal than 4. * @param[in] input1 Input tensor for the RHS matrix. Data type supported: same as @p input0. The number of dimensions for the RHS matrix must be less or equal than 3. + * @param[in] input2 Input tensor containing the bias matrix. Data type supported: same as @p input0. * @param[out] output Output tensor info. Data type supported: same as @p input0 * @param[in] alpha Weight of the matrix product + * @param[in] beta Weight of the matrix bias * @param[in] lhs_info LHS matrix information used to retrieve the number of rows and accumulations to be processed by each thread. Only the following values are supported: * lhs_info.m0: 1,2,3,4,5,6,7,8 * lhs_info.k0: 2,3,4,8,16 @@ -58,14 +60,17 @@ public: * rhs_info.k0: same of lhs_info.k0 * @param[in] gemm_info GEMM information used to retrieve the original dimensions of the input matrices */ - void configure(const ICLTensor *input0, const ICLTensor *input1, ICLTensor *output, float alpha, const GEMMLHSMatrixInfo &lhs_info, const GEMMRHSMatrixInfo &rhs_info, + void configure(const ICLTensor *input0, const ICLTensor *input1, const ICLTensor *input2, ICLTensor *output, float alpha, float beta, const GEMMLHSMatrixInfo &lhs_info, + const GEMMRHSMatrixInfo &rhs_info, const GEMMReshapeInfo &gemm_info); /** Static function to check if given info will lead to a valid configuration of @ref CLGEMMMatrixMultiplyNativeKernel * * @param[in] input0 Input tensor info for the LHS matrix. Data type supported: F32/F16. The number of dimensions for the LHS matrix must be less or equal than 4. * @param[in] input1 Input tensor info for the RHS matrix. Data type supported: same as @p input0. The number of dimensions for the RHS matrix must be less or equal than 3. + * @param[in] input2 Input tensor info containing the bias matrix. Data type supported: same as @p input0. * @param[in] output Output tensor info. Data type supported: same as @p input0 * @param[in] alpha Weight of the matrix product + * @param[in] beta Weight of the matrix bias * @param[in] lhs_info LHS matrix information used to retrieve the number of rows and accumulations to be processed by each thread. Only the following values are supported: * lhs_info.m0: 1,2,3,4,5,6,7,8 * lhs_info.k0: 2,3,4,8,16 @@ -76,7 +81,8 @@ public: * * @return a status */ - static Status validate(const ITensorInfo *input0, const ITensorInfo *input1, const ITensorInfo *output, float alpha, const GEMMLHSMatrixInfo &lhs_info, const GEMMRHSMatrixInfo &rhs_info, + static Status validate(const ITensorInfo *input0, const ITensorInfo *input1, const ITensorInfo *input2, const ITensorInfo *output, float alpha, float beta, const GEMMLHSMatrixInfo &lhs_info, + const GEMMRHSMatrixInfo &rhs_info, const GEMMReshapeInfo &gemm_info); // Inherited methods overridden: @@ -85,11 +91,14 @@ public: private: const ICLTensor *_input0; const ICLTensor *_input1; + const ICLTensor *_input2; ICLTensor *_output; bool _slide_matrix_b; bool _reinterpret_input_as_3d; bool _reinterpret_output_as_3d; bool _use_dummy_work_items; + bool _add_bias; + bool _broadcast_bias; }; } // namespace arm_compute #endif /*__ARM_COMPUTE_CLGEMMMATRIXMULTIPLYNATIVEKERNEL_H__*/ |