aboutsummaryrefslogtreecommitdiff
path: root/arm_compute/core/CL
diff options
context:
space:
mode:
Diffstat (limited to 'arm_compute/core/CL')
-rw-r--r--arm_compute/core/CL/kernels/CLGEMMMatrixMultiplyNativeKernel.h13
1 files changed, 11 insertions, 2 deletions
diff --git a/arm_compute/core/CL/kernels/CLGEMMMatrixMultiplyNativeKernel.h b/arm_compute/core/CL/kernels/CLGEMMMatrixMultiplyNativeKernel.h
index c611dc4c1f..79689a2894 100644
--- a/arm_compute/core/CL/kernels/CLGEMMMatrixMultiplyNativeKernel.h
+++ b/arm_compute/core/CL/kernels/CLGEMMMatrixMultiplyNativeKernel.h
@@ -48,8 +48,10 @@ public:
*
* @param[in] input0 Input tensor for the LHS matrix. Data type supported: F32/F16. The number of dimensions for the LHS matrix must be less or equal than 4.
* @param[in] input1 Input tensor for the RHS matrix. Data type supported: same as @p input0. The number of dimensions for the RHS matrix must be less or equal than 3.
+ * @param[in] input2 Input tensor containing the bias matrix. Data type supported: same as @p input0.
* @param[out] output Output tensor info. Data type supported: same as @p input0
* @param[in] alpha Weight of the matrix product
+ * @param[in] beta Weight of the matrix bias
* @param[in] lhs_info LHS matrix information used to retrieve the number of rows and accumulations to be processed by each thread. Only the following values are supported:
* lhs_info.m0: 1,2,3,4,5,6,7,8
* lhs_info.k0: 2,3,4,8,16
@@ -58,14 +60,17 @@ public:
* rhs_info.k0: same of lhs_info.k0
* @param[in] gemm_info GEMM information used to retrieve the original dimensions of the input matrices
*/
- void configure(const ICLTensor *input0, const ICLTensor *input1, ICLTensor *output, float alpha, const GEMMLHSMatrixInfo &lhs_info, const GEMMRHSMatrixInfo &rhs_info,
+ void configure(const ICLTensor *input0, const ICLTensor *input1, const ICLTensor *input2, ICLTensor *output, float alpha, float beta, const GEMMLHSMatrixInfo &lhs_info,
+ const GEMMRHSMatrixInfo &rhs_info,
const GEMMReshapeInfo &gemm_info);
/** Static function to check if given info will lead to a valid configuration of @ref CLGEMMMatrixMultiplyNativeKernel
*
* @param[in] input0 Input tensor info for the LHS matrix. Data type supported: F32/F16. The number of dimensions for the LHS matrix must be less or equal than 4.
* @param[in] input1 Input tensor info for the RHS matrix. Data type supported: same as @p input0. The number of dimensions for the RHS matrix must be less or equal than 3.
+ * @param[in] input2 Input tensor info containing the bias matrix. Data type supported: same as @p input0.
* @param[in] output Output tensor info. Data type supported: same as @p input0
* @param[in] alpha Weight of the matrix product
+ * @param[in] beta Weight of the matrix bias
* @param[in] lhs_info LHS matrix information used to retrieve the number of rows and accumulations to be processed by each thread. Only the following values are supported:
* lhs_info.m0: 1,2,3,4,5,6,7,8
* lhs_info.k0: 2,3,4,8,16
@@ -76,7 +81,8 @@ public:
*
* @return a status
*/
- static Status validate(const ITensorInfo *input0, const ITensorInfo *input1, const ITensorInfo *output, float alpha, const GEMMLHSMatrixInfo &lhs_info, const GEMMRHSMatrixInfo &rhs_info,
+ static Status validate(const ITensorInfo *input0, const ITensorInfo *input1, const ITensorInfo *input2, const ITensorInfo *output, float alpha, float beta, const GEMMLHSMatrixInfo &lhs_info,
+ const GEMMRHSMatrixInfo &rhs_info,
const GEMMReshapeInfo &gemm_info);
// Inherited methods overridden:
@@ -85,11 +91,14 @@ public:
private:
const ICLTensor *_input0;
const ICLTensor *_input1;
+ const ICLTensor *_input2;
ICLTensor *_output;
bool _slide_matrix_b;
bool _reinterpret_input_as_3d;
bool _reinterpret_output_as_3d;
bool _use_dummy_work_items;
+ bool _add_bias;
+ bool _broadcast_bias;
};
} // namespace arm_compute
#endif /*__ARM_COMPUTE_CLGEMMMATRIXMULTIPLYNATIVEKERNEL_H__*/