author    Georgios Pinitas <georgios.pinitas@arm.com>   2019-05-21 13:32:43 +0100
committer Giuseppe Rossini <giuseppe.rossini@arm.com>   2019-06-04 15:58:08 +0000
commit    b0f342ec315397e4b87d3a9cc3d12f3645c153bc
tree      3bfd95d4196f6c45feb368b0a020f3bb304e79cd
parent    bbac660f1959ed2ab58b31a8d5db524883da1754
COMPMID-2171: Fuse bias addition with CLGEMMMatrixMultiplyReshapedOnlyRHSKernel
Change-Id: I1d1e1f28fe7022309d72900893e8368820ca0f89
Signed-off-by: giuros01 <giuseppe.rossini@arm.com>
Reviewed-on: https://review.mlplatform.org/c/1259
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Michele Di Giorgio <michele.digiorgio@arm.com>
Reviewed-by: Gian Marco Iodice <gianmarco.iodice@arm.com>
Tested-by: Arm Jenkins <bsgcomp@arm.com>
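Read together with the parameter documentation in the hunks below, the fused kernel is understood to compute dst = alpha * (input0 * input1) + beta * input2 in a single kernel launch, rather than running a separate element-wise addition afterwards; the new broadcast_bias flags select whether input2 is a full bias matrix or a vector broadcast across the output rows.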
Diffstat (limited to 'arm_compute/core')
-rw-r--r--  arm_compute/core/CL/kernels/CLGEMMMatrixMultiplyReshapedOnlyRHSKernel.h | 19
-rw-r--r--  arm_compute/core/Types.h                                                 | 36
2 files changed, 42 insertions(+), 13 deletions(-)
diff --git a/arm_compute/core/CL/kernels/CLGEMMMatrixMultiplyReshapedOnlyRHSKernel.h b/arm_compute/core/CL/kernels/CLGEMMMatrixMultiplyReshapedOnlyRHSKernel.h
index 26a1378d27..e3b3880a37 100644
--- a/arm_compute/core/CL/kernels/CLGEMMMatrixMultiplyReshapedOnlyRHSKernel.h
+++ b/arm_compute/core/CL/kernels/CLGEMMMatrixMultiplyReshapedOnlyRHSKernel.h
@@ -51,8 +51,10 @@ public:
*
* @param[in] input0 Input tensor containing the LHS matrix. Data type supported: F32/F16. The number of dimensions for the LHS matrix must be less than or equal to 4.
* @param[in] input1 Input tensor containing the RHS reshaped matrix. Data type supported: same as @p input0. The number of dimensions for the RHS matrix must be less than or equal to 3.
+ * @param[in] input2 Input tensor containing the bias matrix. Data type supported: same as @p input0.
* @param[out] output Output tensor to store the result of matrix multiplication. Data type supported: same as @p input0
* @param[in] alpha Weight of the matrix product
+ * @param[in] beta Weight of the matrix bias
* @param[in] lhs_info LHS matrix information used to retrieve the number of rows to be processed by each thread. Only the following values are supported:
* lhs_info.m0: 1,2,3,4,5,6,7,8
* @param[in] rhs_info RHS matrix information used for reshaping the input1 tensor. Only the following values are supported:
@@ -61,14 +63,17 @@ public:
* rhs_info.transpose: true,false
* @param[in] gemm_info GEMM information used to retrieve the original dimensions of the input matrices
*/
- void configure(const ICLTensor *input0, const ICLTensor *input1, ICLTensor *output, float alpha, const GEMMLHSMatrixInfo &lhs_info, const GEMMRHSMatrixInfo &rhs_info,
- const GEMMReshapeInfo &gemm_info);
+ void configure(const ICLTensor *input0, const ICLTensor *input1, const ICLTensor *input2, ICLTensor *output, float alpha, float beta, const GEMMLHSMatrixInfo &lhs_info,
+ const GEMMRHSMatrixInfo &rhs_info,
+ const GEMMReshapeInfo &gemm_info);
/** Static function to check if given info will lead to a valid configuration of @ref CLGEMMMatrixMultiplyReshapedOnlyRHSKernel
*
* @param[in] input0 Input tensor info for the LHS matrix. Data type supported: F32/F16. The number of dimensions for the LHS matrix must be less than or equal to 4.
* @param[in] input1 Input tensor info for the RHS reshaped matrix. Data type supported: same as @p input0. The number of dimensions for the RHS matrix must be less than or equal to 3.
+ * @param[in] input2 Input tensor info containing the bias matrix. Data type supported: same as @p input0.
* @param[in] output Output tensor info. Data type supported: same as @p input0
* @param[in] alpha Weight of the matrix product
+ * @param[in] beta Weight of the matrix bias
* @param[in] lhs_info LHS matrix information used to retrieve the number of rows to be processed by each thread. Only the following values are supported:
* lhs_info.m0: 1,2,3,4,5,6,7,8
* @param[in] rhs_info RHS matrix information used for reshaping the input1 tensor. Only the following values are supported:
@@ -79,8 +84,9 @@ public:
*
* @return a status
*/
- static Status validate(const ITensorInfo *input0, const ITensorInfo *input1, const ITensorInfo *output, float alpha, const GEMMLHSMatrixInfo &lhs_info, const GEMMRHSMatrixInfo &rhs_info,
- const GEMMReshapeInfo &gemm_info);
+ static Status validate(const ITensorInfo *input0, const ITensorInfo *input1, const ITensorInfo *input2, const ITensorInfo *output, float alpha, float beta, const GEMMLHSMatrixInfo &lhs_info,
+ const GEMMRHSMatrixInfo &rhs_info,
+ const GEMMReshapeInfo &gemm_info);
// Inherited methods overridden:
void run(const Window &window, cl::CommandQueue &queue) override;
@@ -88,11 +94,14 @@ public:
private:
const ICLTensor *_input0;
const ICLTensor *_input1;
+ const ICLTensor *_input2;
ICLTensor *_output;
bool _slide_matrix_b;
bool _reinterpret_input_as_3d;
bool _reinterpret_output_as_3d;
bool _use_dummy_work_items;
+ bool _add_bias;
+ bool _broadcast_bias;
};
} // namespace arm_compute
-#endif /*__ARM_COMPUTE_CLGEMMMATRIXMULTIPLYRESHAPEDONLYRHSKERNEL_H__*/
\ No newline at end of file
+#endif /*__ARM_COMPUTE_CLGEMMMATRIXMULTIPLYRESHAPEDONLYRHSKERNEL_H__*/
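As a usage illustration for the interface change above, here is a minimal C++ sketch. Only the configure()/validate() signatures come from this diff; the helper name configure_fused_gemm, the tensor names, and the m0/n0/k0/h0 block sizes are illustrative assumptions, not values taken from this patch or the library's tests.

#include "arm_compute/core/CL/kernels/CLGEMMMatrixMultiplyReshapedOnlyRHSKernel.h"
#include "arm_compute/core/Error.h"
#include "arm_compute/core/Types.h"
#include "arm_compute/runtime/CL/CLTensor.h"

using namespace arm_compute;

// Sketch: configure the kernel so that dst = alpha * LHS * RHS + beta * bias
// is computed in one launch, instead of a separate addition afterwards.
void configure_fused_gemm(CLTensor &lhs, CLTensor &rhs_reshaped, CLTensor &bias, CLTensor &dst,
                          int m, int n, int k)
{
    GEMMLHSMatrixInfo lhs_info;
    lhs_info.m0 = 4; // rows per work-item; the docs above allow 1..8

    GEMMRHSMatrixInfo rhs_info; // describes how input1 was reshaped
    rhs_info.n0        = 4;
    rhs_info.k0        = 4;
    rhs_info.h0        = 4;
    rhs_info.transpose = true;

    // Last argument is the new broadcast_bias flag: true means the bias is a
    // vector broadcast across every row of the output matrix.
    GEMMReshapeInfo gemm_info(m, n, k, 1, 1, 0, false, true);

    // validate() first, then configure(); input2 and beta are the arguments
    // this patch adds.
    ARM_COMPUTE_ERROR_THROW_ON(CLGEMMMatrixMultiplyReshapedOnlyRHSKernel::validate(
        lhs.info(), rhs_reshaped.info(), bias.info(), dst.info(),
        1.f /* alpha */, 1.f /* beta */, lhs_info, rhs_info, gemm_info));

    CLGEMMMatrixMultiplyReshapedOnlyRHSKernel kernel;
    kernel.configure(&lhs, &rhs_reshaped, &bias, &dst, 1.f, 1.f, lhs_info, rhs_info, gemm_info);
}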
diff --git a/arm_compute/core/Types.h b/arm_compute/core/Types.h
index 1787e68130..d49315d591 100644
--- a/arm_compute/core/Types.h
+++ b/arm_compute/core/Types.h
@@ -1602,7 +1602,7 @@ class GEMMReshapeInfo final
public:
/** Default constructor */
GEMMReshapeInfo()
- : _m(1), _n(1), _k(1), _mult_transpose1xW_width(1), _mult_interleave4x4_height(1), _depth_output_gemm3d(0), _reinterpret_input_as_3d(false)
+ : _m(1), _n(1), _k(1), _mult_transpose1xW_width(1), _mult_interleave4x4_height(1), _depth_output_gemm3d(0), _reinterpret_input_as_3d(false), _broadcast_bias(false)
{
}
/** Constructor
@@ -1615,11 +1615,12 @@ public:
* @param[in] depth_output_gemm3d (Optional) Depth (third dimension) of the output tensor to be used with the GEMM3D kernel.
* If 0 the output will not be reinterpreted as 3D. Default 0
* @param[in] reinterpret_input_as_3d (Optional) Reinterpret the input as 3D tensor. (i.e. this flag should be set to true when GEMM is used
- * to perform 1x1 convolutions with the NHWC data layout)
+ * to perform 1x1 convolutions with the NHWC data layout)
+ * @param[in] broadcast_bias (Optional) Broadcast the shape of the bias tensor from a vector to a matrix.
*/
- GEMMReshapeInfo(int m, int n, int k, int mult_transpose1xW_width = 1, int mult_interleave4x4_height = 1, int depth_output_gemm3d = 0, bool reinterpret_input_as_3d = false)
+ GEMMReshapeInfo(int m, int n, int k, int mult_transpose1xW_width = 1, int mult_interleave4x4_height = 1, int depth_output_gemm3d = 0, bool reinterpret_input_as_3d = false, bool broadcast_bias = false)
: _m(m), _n(n), _k(k), _mult_transpose1xW_width(mult_transpose1xW_width), _mult_interleave4x4_height(mult_interleave4x4_height), _depth_output_gemm3d(depth_output_gemm3d),
- _reinterpret_input_as_3d(reinterpret_input_as_3d)
+ _reinterpret_input_as_3d(reinterpret_input_as_3d), _broadcast_bias(broadcast_bias)
{
}
/** Number of matrix A rows
@@ -1681,6 +1682,14 @@ public:
{
return _reinterpret_input_as_3d;
};
+ /** Flag which specifies whether to broadcast the shape of the bias tensor.
+ *
+ * @return True if the shape of the bias tensor is to be broadcast.
+ */
+ bool broadcast_bias() const
+ {
+ return _broadcast_bias;
+ };
private:
const int _m;
@@ -1690,6 +1699,7 @@ private:
const int _mult_interleave4x4_height;
const int _depth_output_gemm3d;
const bool _reinterpret_input_as_3d;
+ const bool _broadcast_bias;
};
struct DepthwiseConvolutionReshapeInfo
@@ -1749,7 +1759,7 @@ public:
/** Default constructor */
GEMMInfo()
: _is_a_reshaped(false), _is_b_reshaped(false), _reshape_b_only_on_first_run(true), _depth_output_gemm3d(0), _reinterpret_input_as_3d(false), _retain_internal_weights(false), _gemmlowp_output_stage(),
- _fp_mixed_precision(false)
+ _fp_mixed_precision(false), _broadcast_bias(false)
{
}
/** Constructor
@@ -1764,12 +1774,13 @@ public:
* @param[in] retain_internal_weights (Optional) Retain the weights tensor from previous run
* @param[in] gemmlowp_output_stage (Optional) GEMMLowp Output stage info
* @param[in] fp_mixed_precision (Optional) Use wider accumulators (32 bit instead of 16 for FP16) to improve accuracy.
- *
+ * @param[in] broadcast_bias (Optional) Broadcast the shape of the bias tensor from a vector to a matrix.
*/
GEMMInfo(bool is_a_reshaped, bool is_b_reshaped, bool reshape_b_only_on_first_run, int depth_output_gemm3d = 0, bool reinterpret_input_as_3d = false, bool retain_internal_weights = false,
- GEMMLowpOutputStageInfo gemmlowp_output_stage = GEMMLowpOutputStageInfo(), bool fp_mixed_precision = false)
+ GEMMLowpOutputStageInfo gemmlowp_output_stage = GEMMLowpOutputStageInfo(), bool fp_mixed_precision = false, bool broadcast_bias = false)
: _is_a_reshaped(is_a_reshaped), _is_b_reshaped(is_b_reshaped), _reshape_b_only_on_first_run(reshape_b_only_on_first_run), _depth_output_gemm3d(depth_output_gemm3d),
- _reinterpret_input_as_3d(reinterpret_input_as_3d), _retain_internal_weights(retain_internal_weights), _gemmlowp_output_stage(gemmlowp_output_stage), _fp_mixed_precision(fp_mixed_precision)
+ _reinterpret_input_as_3d(reinterpret_input_as_3d), _retain_internal_weights(retain_internal_weights), _gemmlowp_output_stage(gemmlowp_output_stage), _fp_mixed_precision(fp_mixed_precision),
+ _broadcast_bias(broadcast_bias)
{
}
/** Flag which specifies if the matrix A has been reshaped
@@ -1838,6 +1849,14 @@ public:
{
return _fp_mixed_precision;
};
+ /** Flag which specifies whether to broadcast the shape of the bias tensor.
+ *
+ * @return True if the shape of the bias tensor is to be broadcast.
+ */
+ bool broadcast_bias() const
+ {
+ return _broadcast_bias;
+ };
private:
const bool _is_a_reshaped;
@@ -1848,6 +1867,7 @@ private:
const bool _retain_internal_weights;
const GEMMLowpOutputStageInfo _gemmlowp_output_stage;
const bool _fp_mixed_precision;
+ const bool _broadcast_bias;
};
/** Winograd information */
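For the function-level API, a minimal sketch of how a caller might enable the new flag when building a GEMMInfo; the positional arguments follow the constructor in the hunk above, while the helper name is hypothetical and not part of this patch.

#include "arm_compute/core/Types.h"

using namespace arm_compute;

// Hypothetical helper: build a GEMMInfo that requests fused bias addition
// with a 1-D bias vector broadcast across the rows of the result.
GEMMInfo make_gemm_info_with_broadcast_bias()
{
    return GEMMInfo(false,                     // is_a_reshaped
                    false,                     // is_b_reshaped
                    true,                      // reshape_b_only_on_first_run
                    0,                         // depth_output_gemm3d
                    false,                     // reinterpret_input_as_3d
                    false,                     // retain_internal_weights
                    GEMMLowpOutputStageInfo(), // gemmlowp_output_stage (default)
                    false,                     // fp_mixed_precision
                    true);                     // broadcast_bias (new in this patch)
}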