about | summary | refs | log | tree | commit | diff
path: root/arm_compute/core/CL
diff options
context:
space:
mode:
author: Gian Marco Iodice <gianmarco.iodice@arm.com> 2019-06-24 14:40:30 +0100
committer: Gian Marco Iodice <gianmarco.iodice@arm.com> 2019-06-24 15:56:10 +0000
commit: 944170e1591ff23c9e6ede2201f0f6aba0f3439b (patch)
tree: 64d6b718c01458be04ca1b39c39704b78ce3b5d6 /arm_compute/core/CL
parent: 65383e21a5b82071229c6322bf65c47e3719b490 (diff)
download: ComputeLibrary-944170e1591ff23c9e6ede2201f0f6aba0f3439b.tar.gz
COMPMID-2172: Fuse bias addition with CLGEMMMatrixMultiplyNativeKernel
Change-Id: I714b92ec001fc71172719b67fb66d490538b6948
Signed-off-by: Gian Marco Iodice <gianmarco.iodice@arm.com>
Reviewed-on: https://review.mlplatform.org/c/1399
Reviewed-by: Giuseppe Rossini <giuseppe.rossini@arm.com>
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Diffstat (limited to 'arm_compute/core/CL')
-rw-r--r-- arm_compute/core/CL/kernels/CLGEMMMatrixMultiplyNativeKernel.h 13
1 files changed, 11 insertions, 2 deletions
diff --git a/arm_compute/core/CL/kernels/CLGEMMMatrixMultiplyNativeKernel.h b/arm_compute/core/CL/kernels/CLGEMMMatrixMultiplyNativeKernel.h
index c611dc4c1f..79689a2894 100644
--- a/arm_compute/core/CL/kernels/CLGEMMMatrixMultiplyNativeKernel.h
+++ b/arm_compute/core/CL/kernels/CLGEMMMatrixMultiplyNativeKernel.h
@@ -48,8 +48,10 @@ public:
*
* @param[in] input0 Input tensor for the LHS matrix. Data type supported: F32/F16. The number of dimensions for the LHS matrix must be less or equal than 4.
* @param[in] input1 Input tensor for the RHS matrix. Data type supported: same as @p input0. The number of dimensions for the RHS matrix must be less or equal than 3.
+ * @param[in] input2 Input tensor containing the bias matrix. Data type supported: same as @p input0.
* @param[out] output Output tensor info. Data type supported: same as @p input0
* @param[in] alpha Weight of the matrix product
+ * @param[in] beta Weight of the matrix bias
* @param[in] lhs_info LHS matrix information used to retrieve the number of rows and accumulations to be processed by each thread. Only the following values are supported:
* lhs_info.m0: 1,2,3,4,5,6,7,8
* lhs_info.k0: 2,3,4,8,16
@@ -58,14 +60,17 @@ public:
* rhs_info.k0: same of lhs_info.k0
* @param[in] gemm_info GEMM information used to retrieve the original dimensions of the input matrices
*/
- void configure(const ICLTensor *input0, const ICLTensor *input1, ICLTensor *output, float alpha, const GEMMLHSMatrixInfo &lhs_info, const GEMMRHSMatrixInfo &rhs_info,
+ void configure(const ICLTensor *input0, const ICLTensor *input1, const ICLTensor *input2, ICLTensor *output, float alpha, float beta, const GEMMLHSMatrixInfo &lhs_info,
+ const GEMMRHSMatrixInfo &rhs_info,
const GEMMReshapeInfo &gemm_info);
/** Static function to check if given info will lead to a valid configuration of @ref CLGEMMMatrixMultiplyNativeKernel
*
* @param[in] input0 Input tensor info for the LHS matrix. Data type supported: F32/F16. The number of dimensions for the LHS matrix must be less or equal than 4.
* @param[in] input1 Input tensor info for the RHS matrix. Data type supported: same as @p input0. The number of dimensions for the RHS matrix must be less or equal than 3.
+ * @param[in] input2 Input tensor info containing the bias matrix. Data type supported: same as @p input0.
* @param[in] output Output tensor info. Data type supported: same as @p input0
* @param[in] alpha Weight of the matrix product
+ * @param[in] beta Weight of the matrix bias
* @param[in] lhs_info LHS matrix information used to retrieve the number of rows and accumulations to be processed by each thread. Only the following values are supported:
* lhs_info.m0: 1,2,3,4,5,6,7,8
* lhs_info.k0: 2,3,4,8,16
@@ -76,7 +81,8 @@ public:
*
* @return a status
*/
- static Status validate(const ITensorInfo *input0, const ITensorInfo *input1, const ITensorInfo *output, float alpha, const GEMMLHSMatrixInfo &lhs_info, const GEMMRHSMatrixInfo &rhs_info,
+ static Status validate(const ITensorInfo *input0, const ITensorInfo *input1, const ITensorInfo *input2, const ITensorInfo *output, float alpha, float beta, const GEMMLHSMatrixInfo &lhs_info,
+ const GEMMRHSMatrixInfo &rhs_info,
const GEMMReshapeInfo &gemm_info);
// Inherited methods overridden:
@@ -85,11 +91,14 @@ public:
private:
const ICLTensor *_input0;
const ICLTensor *_input1;
+ const ICLTensor *_input2;
ICLTensor *_output;
bool _slide_matrix_b;
bool _reinterpret_input_as_3d;
bool _reinterpret_output_as_3d;
bool _use_dummy_work_items;
+ bool _add_bias;
+ bool _broadcast_bias;
};
} // namespace arm_compute
#endif /*__ARM_COMPUTE_CLGEMMMATRIXMULTIPLYNATIVEKERNEL_H__*/