From b0f342ec315397e4b87d3a9cc3d12f3645c153bc Mon Sep 17 00:00:00 2001 From: Georgios Pinitas Date: Tue, 21 May 2019 13:32:43 +0100 Subject: COMPMID-2171: Fuse bias addition with CLGEMMMatrixMultiplyReshapedOnlyRHSKernel Change-Id: I1d1e1f28fe7022309d72900893e8368820ca0f89 Signed-off-by: giuros01 Reviewed-on: https://review.mlplatform.org/c/1259 Comments-Addressed: Arm Jenkins Reviewed-by: Michele Di Giorgio Reviewed-by: Gian Marco Iodice Tested-by: Arm Jenkins --- .../CLGEMMMatrixMultiplyReshapedOnlyRHSKernel.h | 19 +- arm_compute/core/Types.h | 36 +++- src/core/CL/cl_kernels/gemm.cl | 219 ++++++++++++++------- src/core/CL/cl_kernels/gemm_helpers.h | 174 ++++++++++++++-- .../CLGEMMMatrixMultiplyReshapedOnlyRHSKernel.cpp | 96 +++++++-- src/runtime/CL/functions/CLGEMM.cpp | 47 ++--- .../CL/GEMMMatrixMultiplyReshapedOnlyRHS.cpp | 74 ++++--- tests/validation/fixtures/GEMMFixture.h | 88 +++++++-- 8 files changed, 571 insertions(+), 182 deletions(-) diff --git a/arm_compute/core/CL/kernels/CLGEMMMatrixMultiplyReshapedOnlyRHSKernel.h b/arm_compute/core/CL/kernels/CLGEMMMatrixMultiplyReshapedOnlyRHSKernel.h index 26a1378d27..e3b3880a37 100644 --- a/arm_compute/core/CL/kernels/CLGEMMMatrixMultiplyReshapedOnlyRHSKernel.h +++ b/arm_compute/core/CL/kernels/CLGEMMMatrixMultiplyReshapedOnlyRHSKernel.h @@ -51,8 +51,10 @@ public: * * @param[in] input0 Input tensor containing the LHS matrix. Data type supported: F32/F16. The number of dimensions for the LHS matrix must be less or equal than 4. * @param[in] input1 Input tensor containing the RHS reshaped matrix. Data type supported: same as @p input0. The number of dimensions for the RHS matrix must be less or equal than 3. + * @param[in] input2 Input tensor containing the bias matrix. Data type supported: same as @p input0. * @param[out] output Output tensor to store the result of matrix multiplication. Data type supported: same as @p input0 * @param[in] alpha Weight of the matrix product + * @param[in] beta Weight of the matrix bias * @param[in] lhs_info LHS matrix information used to retrieve the number of rows to be processed by each thread. Only the following values are supported: * lhs_info.m0: 1,2,3,4,5,6,7,8 * @param[in] rhs_info RHS matrix information used for reshaping the input1 tensor. Only the following values are supported: @@ -61,14 +63,17 @@ public: * rhs_info.transpose: true,false * @param[in] gemm_info GEMM information used to retrieve the original dimensions of the input matrices */ - void configure(const ICLTensor *input0, const ICLTensor *input1, ICLTensor *output, float alpha, const GEMMLHSMatrixInfo &lhs_info, const GEMMRHSMatrixInfo &rhs_info, - const GEMMReshapeInfo &gemm_info); + void configure(const ICLTensor *input0, const ICLTensor *input1, const ICLTensor *input2, ICLTensor *output, float alpha, float beta, const GEMMLHSMatrixInfo &lhs_info, + const GEMMRHSMatrixInfo &rhs_info, + const GEMMReshapeInfo &gemm_info); /** Static function to check if given info will lead to a valid configuration of @ref CLGEMMMatrixMultiplyReshapedOnlyRHSKernel * * @param[in] input0 Input tensor info for the LHS matrix. Data type supported: F32/F16. The number of dimensions for the LHS matrix must be less or equal than 4. * @param[in] input1 Input tensor info for the RHS reshaped matrix. Data type supported: same as @p input0. The number of dimensions for the RHS matrix must be less or equal than 3. + * @param[in] input2 Input tensor info containing the bias matrix. Data type supported: same as @p input0. * @param[in] output Output tensor info. Data type supported: same as @p input0 * @param[in] alpha Weight of the matrix product + * @param[in] beta Weight of the matrix bias * @param[in] lhs_info LHS matrix information used to retrieve the number of rows to be processed by each thread. Only the following values are supported: * lhs_info.m0: 1,2,3,4,5,6,7,8 * @param[in] rhs_info RHS matrix information used for reshaping the input1 tensor. Only the following values are supported: @@ -79,8 +84,9 @@ public: * * @return a status */ - static Status validate(const ITensorInfo *input0, const ITensorInfo *input1, const ITensorInfo *output, float alpha, const GEMMLHSMatrixInfo &lhs_info, const GEMMRHSMatrixInfo &rhs_info, - const GEMMReshapeInfo &gemm_info); + static Status validate(const ITensorInfo *input0, const ITensorInfo *input1, const ITensorInfo *input2, const ITensorInfo *output, float alpha, float beta, const GEMMLHSMatrixInfo &lhs_info, + const GEMMRHSMatrixInfo &rhs_info, + const GEMMReshapeInfo &gemm_info); // Inherited methods overridden: void run(const Window &window, cl::CommandQueue &queue) override; @@ -88,11 +94,14 @@ public: private: const ICLTensor *_input0; const ICLTensor *_input1; + const ICLTensor *_input2; ICLTensor *_output; bool _slide_matrix_b; bool _reinterpret_input_as_3d; bool _reinterpret_output_as_3d; bool _use_dummy_work_items; + bool _add_bias; + bool _broadcast_bias; }; } // namespace arm_compute -#endif /*__ARM_COMPUTE_CLGEMMMATRIXMULTIPLYRESHAPEDONLYRHSKERNEL_H__*/ \ No newline at end of file +#endif /*__ARM_COMPUTE_CLGEMMMATRIXMULTIPLYRESHAPEDONLYRHSKERNEL_H__*/ diff --git a/arm_compute/core/Types.h b/arm_compute/core/Types.h index 1787e68130..d49315d591 100644 --- a/arm_compute/core/Types.h +++ b/arm_compute/core/Types.h @@ -1602,7 +1602,7 @@ class GEMMReshapeInfo final public: /** Default constructor */ GEMMReshapeInfo() - : _m(1), _n(1), _k(1), _mult_transpose1xW_width(1), _mult_interleave4x4_height(1), _depth_output_gemm3d(0), _reinterpret_input_as_3d(false) + : _m(1), _n(1), _k(1), _mult_transpose1xW_width(1), _mult_interleave4x4_height(1), _depth_output_gemm3d(0), _reinterpret_input_as_3d(false), _broadcast_bias(false) { } /** Constructor @@ -1615,11 +1615,12 @@ public: * @param[in] depth_output_gemm3d (Optional) Depth (third dimension) of the output tensor to be used with the GEMM3D kernel. * If 0 the output will not be reinterpreted as 3D. Default 0 * @param[in] reinterpret_input_as_3d (Optional) Reinterpret the input as 3D tensor. (i.e. this flag should be set to true when GEMM is used - * to perform 1x1 convolutions with the NHWC data layout) + * to perform 1x1 convolutions with the NHWC data layout) + * @param[in] broadcast_bias (Optional) Broadcast the shape of the bias tensor from a vector to a matrix. */ - GEMMReshapeInfo(int m, int n, int k, int mult_transpose1xW_width = 1, int mult_interleave4x4_height = 1, int depth_output_gemm3d = 0, bool reinterpret_input_as_3d = false) + GEMMReshapeInfo(int m, int n, int k, int mult_transpose1xW_width = 1, int mult_interleave4x4_height = 1, int depth_output_gemm3d = 0, bool reinterpret_input_as_3d = false, bool broadcast_bias = false) : _m(m), _n(n), _k(k), _mult_transpose1xW_width(mult_transpose1xW_width), _mult_interleave4x4_height(mult_interleave4x4_height), _depth_output_gemm3d(depth_output_gemm3d), - _reinterpret_input_as_3d(reinterpret_input_as_3d) + _reinterpret_input_as_3d(reinterpret_input_as_3d), _broadcast_bias(broadcast_bias) { } /** Number of matrix A rows @@ -1681,6 +1682,14 @@ public: { return _reinterpret_input_as_3d; }; + /** Flag which specifies whether to broadcast the shape of the bias tensor. + * + * @return True if the shape of the bias tensor is to be broadcasted. + */ + bool broadcast_bias() const + { + return _broadcast_bias; + }; private: const int _m; @@ -1690,6 +1699,7 @@ private: const int _mult_interleave4x4_height; const int _depth_output_gemm3d; const bool _reinterpret_input_as_3d; + const bool _broadcast_bias; }; struct DepthwiseConvolutionReshapeInfo @@ -1749,7 +1759,7 @@ public: /** Default constructor */ GEMMInfo() : _is_a_reshaped(false), _is_b_reshaped(false), _reshape_b_only_on_first_run(true), _depth_output_gemm3d(0), _reinterpret_input_as_3d(false), _retain_internal_weights(false), _gemmlowp_output_stage(), - _fp_mixed_precision(false) + _fp_mixed_precision(false), _broadcast_bias(false) { } /** Constructor @@ -1764,12 +1774,13 @@ public: * @param[in] retain_internal_weights (Optional) Retain the weights tensor from previous run * @param[in] gemmlowp_output_stage (Optional) GEMMLowp Output stage info * @param[in] fp_mixed_precision (Optional) Use wider accumulators (32 bit instead of 16 for FP16) to improve accuracy. - * + * @param[in] broadcast_bias (Optional) Broadcast the shape of the bias tensor from a vector to a matrix. */ GEMMInfo(bool is_a_reshaped, bool is_b_reshaped, bool reshape_b_only_on_first_run, int depth_output_gemm3d = 0, bool reinterpret_input_as_3d = false, bool retain_internal_weights = false, - GEMMLowpOutputStageInfo gemmlowp_output_stage = GEMMLowpOutputStageInfo(), bool fp_mixed_precision = false) + GEMMLowpOutputStageInfo gemmlowp_output_stage = GEMMLowpOutputStageInfo(), bool fp_mixed_precision = false, bool broadcast_bias = false) : _is_a_reshaped(is_a_reshaped), _is_b_reshaped(is_b_reshaped), _reshape_b_only_on_first_run(reshape_b_only_on_first_run), _depth_output_gemm3d(depth_output_gemm3d), - _reinterpret_input_as_3d(reinterpret_input_as_3d), _retain_internal_weights(retain_internal_weights), _gemmlowp_output_stage(gemmlowp_output_stage), _fp_mixed_precision(fp_mixed_precision) + _reinterpret_input_as_3d(reinterpret_input_as_3d), _retain_internal_weights(retain_internal_weights), _gemmlowp_output_stage(gemmlowp_output_stage), _fp_mixed_precision(fp_mixed_precision), + _broadcast_bias(broadcast_bias) { } /** Flag which specifies if the matrix A has been reshaped @@ -1838,6 +1849,14 @@ public: { return _fp_mixed_precision; }; + /** Flag which specifies whether to broadcast the shape of the bias tensor. + * + * @return True if the shape of the bias tensor is to be broadcasted. + */ + bool broadcast_bias() const + { + return _broadcast_bias; + }; private: const bool _is_a_reshaped; @@ -1848,6 +1867,7 @@ private: const bool _retain_internal_weights; const GEMMLowpOutputStageInfo _gemmlowp_output_stage; const bool _fp_mixed_precision; + const bool _broadcast_bias; }; /** Winograd information */ diff --git a/src/core/CL/cl_kernels/gemm.cl b/src/core/CL/cl_kernels/gemm.cl index 41e5c338b3..2ac2eb7c32 100644 --- a/src/core/CL/cl_kernels/gemm.cl +++ b/src/core/CL/cl_kernels/gemm.cl @@ -731,29 +731,29 @@ __kernel void gemm_reshape_rhs_matrix_t(TENSOR3D_DECLARATION(src), // 3x4 -> 4x3 // 3x8 -> 8x3 // 3x16 -> 16x3 - res0 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s0, a1.s0, a2.s0); - res1 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s1, a1.s1, a2.s1); + res0 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s0, a1.s0, a2.s0); + res1 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s1, a1.s1, a2.s1); #if N0 > 2 - res2 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s2, a1.s2, a2.s2); + res2 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s2, a1.s2, a2.s2); #endif // N0 > 2 #if N0 > 3 - res3 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s3, a1.s3, a2.s3); + res3 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s3, a1.s3, a2.s3); #endif // N0 > 3 #if N0 > 4 - res4 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s4, a1.s4, a2.s4); - res5 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s5, a1.s5, a2.s5); - res6 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s6, a1.s6, a2.s6); - res7 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s7, a1.s7, a2.s7); + res4 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s4, a1.s4, a2.s4); + res5 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s5, a1.s5, a2.s5); + res6 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s6, a1.s6, a2.s6); + res7 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s7, a1.s7, a2.s7); #endif // N0 > 4 #if N0 > 8 - res8 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s8, a1.s8, a2.s8); - res9 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s9, a1.s9, a2.s9); - resA = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sA, a1.sA, a2.sA); - resB = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sB, a1.sB, a2.sB); - resC = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sC, a1.sC, a2.sC); - resD = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sD, a1.sD, a2.sD); - resE = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sE, a1.sE, a2.sE); - resF = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sF, a1.sF, a2.sF); + res8 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s8, a1.s8, a2.s8); + res9 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s9, a1.s9, a2.s9); + resA = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sA, a1.sA, a2.sA); + resB = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sB, a1.sB, a2.sB); + resC = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sC, a1.sC, a2.sC); + resD = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sD, a1.sD, a2.sD); + resE = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sE, a1.sE, a2.sE); + resF = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sF, a1.sF, a2.sF); #endif // N0 > 8 #elif K0 == 4 // K0 == 4 @@ -1029,35 +1029,48 @@ __kernel void gemm_reshape_rhs_matrix_t(TENSOR3D_DECLARATION(src), * -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor * (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns LHS matrix * - * @param[in] lhs_ptr Pointer to the LHS reshaped matrix. Supported data type: F16/F32 - * @param[in] lhs_stride_x Stride of the LHS reshaped matrix in X dimension (in bytes) - * @param[in] lhs_step_x src_stride_x * number of elements along X processed per workitem(in bytes) - * @param[in] lhs_stride_y Stride of the LHS reshaped matrix in Y dimension (in bytes) - * @param[in] lhs_step_y src_stride_y * number of elements along Y processed per workitem(in bytes) - * @param[in] lhs_offset_first_element_in_bytes The offset of the first element in the LHS reshaped matrix - * @param[in] rhs_ptr Pointer to the RHS reshaped matrix. Supported data type: same as @p lhs_ptr - * @param[in] rhs_stride_x Stride of the RHS reshaped matrix in X dimension (in bytes) - * @param[in] rhs_step_x src_stride_x * number of elements along X processed per workitem(in bytes) - * @param[in] rhs_stride_y Stride of the RHS reshaped matrix in Y dimension (in bytes) - * @param[in] rhs_step_y src_stride_y * number of elements along Y processed per workitem(in bytes) - * @param[in] rhs_offset_first_element_in_bytes The offset of the first element in the RHS reshaped matrix - * @param[out] dst_ptr Pointer to the destination matrix Supported data type: same as @p lhs_ptr - * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes) - * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes) - * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes) - * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes) - * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix - * @param[in] lhs_stride_z Stride of the LHS reshaped matrix in Z dimension (in bytes) - * @param[in] rhs_stride_z Stride of the RHS reshaped matrix in Z dimension (in bytes) - * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes) - * @param[in] lhs_cross_plane_pad (Optional) Bottom paddings for LHS matrix in unit of elements (only if defined REINTERPRET_INPUT_AS_3D) - * @param[in] dst_cross_plane_pad (Optional) Bottom paddings for the output matrix in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D) + * @param[in] lhs_ptr Pointer to the LHS reshaped matrix. Supported data type: F16/F32 + * @param[in] lhs_stride_x Stride of the LHS reshaped matrix in X dimension (in bytes) + * @param[in] lhs_step_x src_stride_x * number of elements along X processed per workitem(in bytes) + * @param[in] lhs_stride_y Stride of the LHS reshaped matrix in Y dimension (in bytes) + * @param[in] lhs_step_y src_stride_y * number of elements along Y processed per workitem(in bytes) + * @param[in] lhs_offset_first_element_in_bytes The offset of the first element in the LHS reshaped matrix + * @param[in] rhs_ptr Pointer to the RHS reshaped matrix. Supported data type: same as @p lhs_ptr + * @param[in] rhs_stride_x Stride of the RHS reshaped matrix in X dimension (in bytes) + * @param[in] rhs_step_x src_stride_x * number of elements along X processed per workitem(in bytes) + * @param[in] rhs_stride_y Stride of the RHS reshaped matrix in Y dimension (in bytes) + * @param[in] rhs_step_y src_stride_y * number of elements along Y processed per workitem(in bytes) + * @param[in] rhs_offset_first_element_in_bytes The offset of the first element in the RHS reshaped matrix + * @param[in] bias_ptr (Optional)Pointer to the bias reshaped matrix. Supported data type: same as @p lhs_ptr + * @param[in] bias_stride_x (Optional)Stride of the bias reshaped matrix in X dimension (in bytes) + * @param[in] bias_step_x (Optional)bias_stride_x * number of elements along X processed per workitem(in bytes) + * @param[in] bias_stride_y (Optional)Stride of the bias reshaped matrix in Y dimension (in bytes) + * @param[in] bias_step_y (Optional)bias_stride_y * number of elements along Y processed per workitem(in bytes) + * @param[in] bias_offset_first_element_in_bytes (Optional)The offset of the first element in the bias reshaped matrix + * @param[out] dst_ptr Pointer to the destination matrix Supported data type: same as @p lhs_ptr + * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes) + * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes) + * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes) + * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes) + * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix + * @param[in] lhs_stride_z Stride of the LHS reshaped matrix in Z dimension (in bytes) + * @param[in] rhs_stride_z Stride of the RHS reshaped matrix in Z dimension (in bytes) + * @param[in] bias_stride_z (Optional) Stride of the bias matrix in Z dimension (in bytes) + * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes) + * @param[in] lhs_cross_plane_pad (Optional) Bottom paddings for LHS matrix in unit of elements (only if defined REINTERPRET_INPUT_AS_3D) + * @param[in] dst_cross_plane_pad (Optional) Bottom paddings for the output matrix in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D) */ __kernel void gemm_mm_reshaped_only_rhs_t(IMAGE_DECLARATION(lhs), IMAGE_DECLARATION(rhs), +#if defined(BETA) + IMAGE_DECLARATION(bias), +#endif // defined(BETA) IMAGE_DECLARATION(dst), uint lhs_stride_z, uint rhs_stride_z, +#if defined(BETA) + uint bias_stride_z, +#endif //defined(BETA) uint dst_stride_z #if defined(REINTERPRET_INPUT_AS_3D) , @@ -1108,7 +1121,7 @@ __kernel void gemm_mm_reshaped_only_rhs_t(IMAGE_DECLARATION(lhs), #endif // defined(MATRIX_B_DEPTH) REPEAT_VAR_INIT_TO_CONST(8, uint, zlhs, 0); //uint zlhs0=0,zlhs1=0,zlhs2=0,... zlhs7=0; - REPEAT_VAR_INIT_TO_CONST(16, uint, zrhs, 0); + REPEAT_VAR_INIT_TO_CONST(16, uint, zero, 0); #if defined(REINTERPRET_INPUT_AS_3D) // The plane (zlhs) is calculated dividing M (y * M0) by HEIGHT_GEMM3D @@ -1144,7 +1157,7 @@ __kernel void gemm_mm_reshaped_only_rhs_t(IMAGE_DECLARATION(lhs), LOAD_BLOCK(M0, K0, DATA_TYPE, a, lhs_ptr, lhs_offset, lhs_stride_y, zlhs); // Load values from RHS matrix - LOAD_BLOCK(N0, K0, DATA_TYPE, b, rhs_ptr, rhs_offset, RHS_STEP_X * sizeof(DATA_TYPE), zrhs); + LOAD_BLOCK(N0, K0, DATA_TYPE, b, rhs_ptr, rhs_offset, RHS_STEP_X * sizeof(DATA_TYPE), zero); // Accumulate ARM_DOT_K0XN0(K0, a0, b, c0); @@ -1181,7 +1194,7 @@ __kernel void gemm_mm_reshaped_only_rhs_t(IMAGE_DECLARATION(lhs), LOAD_BLOCK(M0, 1, DATA_TYPE, a, lhs_ptr, lhs_offset, lhs_stride_y, zlhs); // Load values from RHS matrix - LOAD_BLOCK(N0, 1, DATA_TYPE, b, rhs_ptr, rhs_offset, RHS_STEP_X * sizeof(DATA_TYPE), zrhs); + LOAD_BLOCK(N0, 1, DATA_TYPE, b, rhs_ptr, rhs_offset, RHS_STEP_X * sizeof(DATA_TYPE), zero); // Accumulate ARM_DOT_K0XN0(1, a0, b, c0); @@ -1236,6 +1249,36 @@ __kernel void gemm_mm_reshaped_only_rhs_t(IMAGE_DECLARATION(lhs), SCALE_BLOCK(M0, DATA_TYPE, c, ALPHA); #endif // defined(ALPHA) + // Add beta*bias +#if defined(BETA) +#if defined(BROADCAST_BIAS) + __global uchar *bias_addr = bias_ptr + bias_offset_first_element_in_bytes + (get_global_id(0) * (uint)N0 * sizeof(DATA_TYPE)); + + LOAD_BLOCK(1, N0, DATA_TYPE, bias, bias_addr, 0, bias_stride_y, zero); + +#ifndef UNIT_BETA + SCALE_BLOCK(1, DATA_TYPE, bias, BETA); +#endif // UNIT_BIAS + + // c = c + bias[broadcasted] + ADD_BLOCK_BROADCAST(M0, c, bias0); + +#else // defined(BROADCAST_BIAS) + __global uchar *bias_addr = bias_ptr + bias_offset_first_element_in_bytes + (get_global_id(0) * (uint)N0 * sizeof(DATA_TYPE)) + (get_global_id(1) * (uint)M0 * bias_stride_y) + get_global_id( + 2) * bias_stride_z; + + LOAD_BLOCK(M0, N0, DATA_TYPE, bias, bias_addr, 0, bias_stride_y, zero); + +#ifndef UNIT_BETA + SCALE_BLOCK(M0, DATA_TYPE, bias, BETA); +#endif // UNIT_BIAS + + // c = c + bias + ADD_BLOCK(M0, c, bias); + +#endif // defined(BROADCAST_BIAS) +#endif // defined(BETA) + // Store output block STORE_BLOCK(M0, N0, DATA_TYPE, c, dst_addr, dst_stride_y, zout); @@ -1360,35 +1403,48 @@ __kernel void gemm_mm_reshaped_only_rhs_t(IMAGE_DECLARATION(lhs), * -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor * (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns LHS matrix * - * @param[in] lhs_ptr Pointer to the LHS reshaped matrix. Supported data type: F16/F32 - * @param[in] lhs_stride_x Stride of the LHS reshaped matrix in X dimension (in bytes) - * @param[in] lhs_step_x src_stride_x * number of elements along X processed per workitem(in bytes) - * @param[in] lhs_stride_y Stride of the LHS reshaped matrix in Y dimension (in bytes) - * @param[in] lhs_step_y src_stride_y * number of elements along Y processed per workitem(in bytes) - * @param[in] lhs_offset_first_element_in_bytes The offset of the first element in the LHS reshaped matrix - * @param[in] rhs_ptr Pointer to the RHS reshaped matrix. Supported data type: same as @p lhs_ptr - * @param[in] rhs_stride_x Stride of the RHS reshaped matrix in X dimension (in bytes) - * @param[in] rhs_step_x src_stride_x * number of elements along X processed per workitem(in bytes) - * @param[in] rhs_stride_y Stride of the RHS reshaped matrix in Y dimension (in bytes) - * @param[in] rhs_step_y src_stride_y * number of elements along Y processed per workitem(in bytes) - * @param[in] rhs_offset_first_element_in_bytes The offset of the first element in the RHS reshaped matrix - * @param[out] dst_ptr Pointer to the destination matrix Supported data type: same as @p lhs_ptr - * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes) - * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes) - * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes) - * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes) - * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix - * @param[in] lhs_stride_z Stride of the LHS reshaped matrix in Z dimension (in bytes) - * @param[in] rhs_stride_z Stride of the RHS reshaped matrix in Z dimension (in bytes) - * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes) - * @param[in] lhs_cross_plane_pad (Optional) Bottom paddings for LHS matrix in unit of elements (only if defined REINTERPRET_INPUT_AS_3D) - * @param[in] dst_cross_plane_pad (Optional) Bottom paddings for the output matrix in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D) + * @param[in] lhs_ptr Pointer to the LHS reshaped matrix. Supported data type: F16/F32 + * @param[in] lhs_stride_x Stride of the LHS reshaped matrix in X dimension (in bytes) + * @param[in] lhs_step_x src_stride_x * number of elements along X processed per workitem(in bytes) + * @param[in] lhs_stride_y Stride of the LHS reshaped matrix in Y dimension (in bytes) + * @param[in] lhs_step_y src_stride_y * number of elements along Y processed per workitem(in bytes) + * @param[in] lhs_offset_first_element_in_bytes The offset of the first element in the LHS reshaped matrix + * @param[in] rhs_ptr Pointer to the RHS reshaped matrix. Supported data type: same as @p lhs_ptr + * @param[in] rhs_stride_x Stride of the RHS reshaped matrix in X dimension (in bytes) + * @param[in] rhs_step_x src_stride_x * number of elements along X processed per workitem(in bytes) + * @param[in] rhs_stride_y Stride of the RHS reshaped matrix in Y dimension (in bytes) + * @param[in] rhs_step_y src_stride_y * number of elements along Y processed per workitem(in bytes) + * @param[in] rhs_offset_first_element_in_bytes The offset of the first element in the RHS reshaped matrix + * @param[in] bias_ptr (Optional) Pointer to the bias reshaped matrix. Supported data type: same as @p lhs_ptr + * @param[in] bias_stride_x (Optional) Stride of the bias reshaped matrix in X dimension (in bytes) + * @param[in] bias_step_x (Optional) bias_stride_x * number of elements along X processed per workitem(in bytes) + * @param[in] bias_stride_y (Optional) Stride of the bias reshaped matrix in Y dimension (in bytes) + * @param[in] bias_step_y (Optional) bias_stride_y * number of elements along Y processed per workitem(in bytes) + * @param[in] bias_offset_first_element_in_bytes (Optional) The offset of the first element in the bias matrix + * @param[out] dst_ptr Pointer to the destination matrix Supported data type: same as @p lhs_ptr + * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes) + * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes) + * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes) + * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes) + * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix + * @param[in] lhs_stride_z Stride of the LHS reshaped matrix in Z dimension (in bytes) + * @param[in] rhs_stride_z Stride of the RHS reshaped matrix in Z dimension (in bytes) + * @param[in] bias_stride_z (Optional)Stride of the bias reshaped matrix in Z dimension (in bytes) + * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes) + * @param[in] lhs_cross_plane_pad (Optional) Bottom paddings for LHS matrix in unit of elements (only if defined REINTERPRET_INPUT_AS_3D) + * @param[in] dst_cross_plane_pad (Optional) Bottom paddings for the output matrix in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D) */ __kernel void gemm_mm_reshaped_only_rhs_nt(IMAGE_DECLARATION(lhs), IMAGE_DECLARATION(rhs), +#if defined(BETA) + IMAGE_DECLARATION(bias), +#endif // defined(BETA) IMAGE_DECLARATION(dst), uint lhs_stride_z, uint rhs_stride_z, +#if defined(BETA) + uint bias_stride_z, +#endif //defined(BETA) uint dst_stride_z #if defined(REINTERPRET_INPUT_AS_3D) , @@ -1438,7 +1494,8 @@ __kernel void gemm_mm_reshaped_only_rhs_nt(IMAGE_DECLARATION(lhs), rhs_offset += z * rhs_stride_z; #endif // defined(MATRIX_B_DEPTH) - REPEAT_VAR_INIT_TO_CONST(8, uint, zin, 0); //uint zout0=0,zout1=0,zout2=0,... zout7=0; + REPEAT_VAR_INIT_TO_CONST(8, uint, zin, 0); //uint zin0=0,zin1=0,zin2=0,... zin7=0; + REPEAT_VAR_INIT_TO_CONST(16, uint, zero, 0); //uint zero0=0,zero1=0,zero2=0,... zero7=0; #if defined(REINTERPRET_INPUT_AS_3D) @@ -1568,6 +1625,36 @@ __kernel void gemm_mm_reshaped_only_rhs_nt(IMAGE_DECLARATION(lhs), SCALE_BLOCK(M0, DATA_TYPE, c, ALPHA); #endif // defined(ALPHA) + // Add beta*bias +#if defined(BETA) +#if defined(BROADCAST_BIAS) + __global uchar *bias_addr = bias_ptr + bias_offset_first_element_in_bytes + (get_global_id(0) * (uint)N0 * sizeof(DATA_TYPE)); + + LOAD_BLOCK(1, N0, DATA_TYPE, bias, bias_addr, 0, bias_stride_y, zero); + +#ifndef UNIT_BETA + SCALE_BLOCK(1, DATA_TYPE, bias, BETA); +#endif // UNIT_BIAS + + // c = c + bias[broadcasted] + ADD_BLOCK_BROADCAST(M0, c, bias0); + +#else // defined(BROADCAST_BIAS) + __global uchar *bias_addr = bias_ptr + bias_offset_first_element_in_bytes + (get_global_id(0) * (uint)N0 * sizeof(DATA_TYPE)) + (get_global_id(1) * (uint)M0 * bias_stride_y) + get_global_id( + 2) * bias_stride_z; + + LOAD_BLOCK(M0, N0, DATA_TYPE, bias, bias_addr, 0, bias_stride_y, zero); + +#ifndef UNIT_BETA + SCALE_BLOCK(M0, DATA_TYPE, bias, BETA); +#endif // UNIT_BIAS + + // c = c + bias + ADD_BLOCK(M0, c, bias); + +#endif // defined(BROADCAST_BIAS) +#endif // defined(BETA) + // Store output block STORE_BLOCK(M0, N0, DATA_TYPE, c, dst_addr, dst_stride_y, zout); diff --git a/src/core/CL/cl_kernels/gemm_helpers.h b/src/core/CL/cl_kernels/gemm_helpers.h index 2c76992b31..cd2d39b433 100644 --- a/src/core/CL/cl_kernels/gemm_helpers.h +++ b/src/core/CL/cl_kernels/gemm_helpers.h @@ -360,69 +360,69 @@ #define CONVERT_STORE_BLOCK(M0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) CONVERT_STORE_BLOCK_STR(M0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) #define SCALE_ROW_1(DATA_TYPE, BASENAME, SCALE) \ - BASENAME##0 = BASENAME##0 * (DATA_TYPE)SCALE; + BASENAME##0 *= (DATA_TYPE)SCALE; #define SCALE_ROW_2(DATA_TYPE, BASENAME, SCALE) \ SCALE_ROW_1(DATA_TYPE, BASENAME, SCALE) \ - BASENAME##1 = BASENAME##1 * (DATA_TYPE)SCALE; + BASENAME##1 *= (DATA_TYPE)SCALE; #define SCALE_ROW_3(DATA_TYPE, BASENAME, SCALE) \ SCALE_ROW_2(DATA_TYPE, BASENAME, SCALE) \ - BASENAME##2 = BASENAME##2 * (DATA_TYPE)SCALE; + BASENAME##2 *= (DATA_TYPE)SCALE; #define SCALE_ROW_4(DATA_TYPE, BASENAME, SCALE) \ SCALE_ROW_3(DATA_TYPE, BASENAME, SCALE) \ - BASENAME##3 = BASENAME##3 * (DATA_TYPE)SCALE; + BASENAME##3 *= (DATA_TYPE)SCALE; #define SCALE_ROW_5(DATA_TYPE, BASENAME, SCALE) \ SCALE_ROW_4(DATA_TYPE, BASENAME, SCALE) \ - BASENAME##4 = BASENAME##4 * (DATA_TYPE)SCALE; + BASENAME##4 *= (DATA_TYPE)SCALE; #define SCALE_ROW_6(DATA_TYPE, BASENAME, SCALE) \ SCALE_ROW_5(DATA_TYPE, BASENAME, SCALE) \ - BASENAME##5 = BASENAME##5 * (DATA_TYPE)SCALE; + BASENAME##5 *= (DATA_TYPE)SCALE; #define SCALE_ROW_7(DATA_TYPE, BASENAME, SCALE) \ SCALE_ROW_6(DATA_TYPE, BASENAME, SCALE) \ - BASENAME##6 = BASENAME##6 * (DATA_TYPE)SCALE; + BASENAME##6 *= (DATA_TYPE)SCALE; #define SCALE_ROW_8(DATA_TYPE, BASENAME, SCALE) \ SCALE_ROW_7(DATA_TYPE, BASENAME, SCALE) \ - BASENAME##7 = BASENAME##7 * (DATA_TYPE)SCALE; + BASENAME##7 *= (DATA_TYPE)SCALE; #define SCALE_ROW_9(DATA_TYPE, BASENAME, SCALE) \ SCALE_ROW_8(DATA_TYPE, BASENAME, SCALE) \ - BASENAME##8 = BASENAME##8 * (DATA_TYPE)SCALE; + BASENAME##8 *= (DATA_TYPE)SCALE; #define SCALE_ROW_10(DATA_TYPE, BASENAME, SCALE) \ SCALE_ROW_9(DATA_TYPE, BASENAME, SCALE) \ - BASENAME##9 = BASENAME##9 * (DATA_TYPE)SCALE; + BASENAME##9 *= (DATA_TYPE)SCALE; #define SCALE_ROW_11(DATA_TYPE, BASENAME, SCALE) \ SCALE_ROW_10(DATA_TYPE, BASENAME, SCALE) \ - BASENAME##A = BASENAME##A * (DATA_TYPE)SCALE; + BASENAME##A *= (DATA_TYPE)SCALE; #define SCALE_ROW_12(DATA_TYPE, BASENAME, SCALE) \ SCALE_ROW_11(DATA_TYPE, BASENAME, SCALE) \ - BASENAME##B = BASENAME##B * (DATA_TYPE)SCALE; + BASENAME##B *= (DATA_TYPE)SCALE; #define SCALE_ROW_13(DATA_TYPE, BASENAME, SCALE) \ SCALE_ROW_12(DATA_TYPE, BASENAME, SCALE) \ - BASENAME##C = BASENAME##C * (DATA_TYPE)SCALE; + BASENAME##C *= (DATA_TYPE)SCALE; #define SCALE_ROW_14(DATA_TYPE, BASENAME, SCALE) \ SCALE_ROW_13(DATA_TYPE, BASENAME, SCALE) \ - BASENAME##D = BASENAME##D * (DATA_TYPE)SCALE; + BASENAME##D *= (DATA_TYPE)SCALE; #define SCALE_ROW_15(DATA_TYPE, BASENAME, SCALE) \ SCALE_ROW_14(DATA_TYPE, BASENAME, SCALE) \ - BASENAME##E = BASENAME##E * (DATA_TYPE)SCALE; + BASENAME##E *= (DATA_TYPE)SCALE; #define SCALE_ROW_16(DATA_TYPE, BASENAME, SCALE) \ SCALE_ROW_15(DATA_TYPE, BASENAME, SCALE) \ - BASENAME##F = BASENAME##F * (DATA_TYPE)SCALE; + BASENAME##F *= (DATA_TYPE)SCALE; -// SCALE_ROW_n scales the variables BASENAME##0 to BASENAME##(n-1) by SCALE +// SCALE_BLOCK_n scales the variables BASENAME##0 to BASENAME##(n-1) by SCALE #define SCALE_BLOCK_STR(N, DATA_TYPE, BASENAME, SCALE) SCALE_ROW_##N(DATA_TYPE, BASENAME, SCALE) /** Scale elements stored in variables BASENAME##0 to BASENAME##(N-1) by SCALE * Supported cases N=1,2,3..16, for variables BASENAME[0..N] @@ -479,3 +479,143 @@ #define TRANSPOSE_K0XN0(K0, N0, BASENAME, B) \ CONCAT(TRANSPOSE_K0X, N0) \ (K0, BASENAME, B); + +#define ADD_ROW_1(BASENAME, BIAS) \ + BASENAME##0 += BIAS##0; + +#define ADD_ROW_2(BASENAME, BIAS) \ + ADD_ROW_1(BASENAME, BIAS) \ + BASENAME##1 += BIAS##1; + +#define ADD_ROW_3(BASENAME, BIAS) \ + ADD_ROW_2(BASENAME, BIAS) \ + BASENAME##2 += BIAS##2; + +#define ADD_ROW_4(BASENAME, BIAS) \ + ADD_ROW_3(BASENAME, BIAS) \ + BASENAME##3 += BIAS##3; + +#define ADD_ROW_5(BASENAME, BIAS) \ + ADD_ROW_4(BASENAME, BIAS) \ + BASENAME##4 += BIAS##4; + +#define ADD_ROW_6(BASENAME, BIAS) \ + ADD_ROW_5(BASENAME, BIAS) \ + BASENAME##5 += BIAS##5; + +#define ADD_ROW_7(BASENAME, BIAS) \ + ADD_ROW_6(BASENAME, BIAS) \ + BASENAME##6 += BIAS##6; + +#define ADD_ROW_8(BASENAME, BIAS) \ + ADD_ROW_7(BASENAME, BIAS) \ + BASENAME##7 += BIAS##7; + +#define ADD_ROW_9(BASENAME, BIAS) \ + ADD_ROW_8(BASENAME, BIAS) \ + BASENAME##8 += BIAS##8; + +#define ADD_ROW_10(BASENAME, BIAS) \ + ADD_ROW_9(BASENAME, BIAS) \ + BASENAME##9 += BIAS##9; + +#define ADD_ROW_11(BASENAME, BIAS) \ + ADD_ROW_10(BASENAME, BIAS) \ + BASENAME##A += BIAS##A; + +#define ADD_ROW_12(BASENAME, BIAS) \ + ADD_ROW_11(BASENAME, BIAS) \ + BASENAME##B += BIAS##B; + +#define ADD_ROW_13(BASENAME, BIAS) \ + ADD_ROW_12(BASENAME, BIAS) \ + BASENAME##C += BIAS##C; + +#define ADD_ROW_14(BASENAME, BIAS) \ + ADD_ROW_13(BASENAME, BIAS) \ + BASENAME##D += BIAS##D; + +#define ADD_ROW_15(BASENAME, BIAS) \ + ADD_ROW_14(BASENAME, BIAS) \ + BASENAME##E += BIAS##E; + +#define ADD_ROW_16(BASENAME, BIAS) \ + ADD_ROW_15(BASENAME, BIAS) \ + BASENAME##F += BIAS##F; + +// ADD_ROW_n add the variables BIAS##0... BIAS##(n-1) to BASENAME##0 to BASENAME##(n-1) +#define ADD_BLOCK_STR(N, BASENAME, BIAS) ADD_ROW_##N(BASENAME, BIAS) +/** Add BIAS to BASENAME##0 ... BASENAME##(N-1) + * Supported cases N=1,2,3..16, for variables BASENAME[0..N] + */ +#define ADD_BLOCK(N, BASENAME, BIAS) ADD_BLOCK_STR(N, BASENAME, BIAS) + +#define ADD_ROW_BROADCAST_1(BASENAME, BIAS) \ + BASENAME##0 += BIAS; + +#define ADD_ROW_BROADCAST_2(BASENAME, BIAS) \ + ADD_ROW_BROADCAST_1(BASENAME, BIAS) \ + BASENAME##1 += BIAS; + +#define ADD_ROW_BROADCAST_3(BASENAME, BIAS) \ + ADD_ROW_BROADCAST_2(BASENAME, BIAS) \ + BASENAME##2 += BIAS; + +#define ADD_ROW_BROADCAST_4(BASENAME, BIAS) \ + ADD_ROW_BROADCAST_3(BASENAME, BIAS) \ + BASENAME##3 += BIAS; + +#define ADD_ROW_BROADCAST_5(BASENAME, BIAS) \ + ADD_ROW_BROADCAST_4(BASENAME, BIAS) \ + BASENAME##4 += BIAS; + +#define ADD_ROW_BROADCAST_6(BASENAME, BIAS) \ + ADD_ROW_BROADCAST_5(BASENAME, BIAS) \ + BASENAME##5 += BIAS; + +#define ADD_ROW_BROADCAST_7(BASENAME, BIAS) \ + ADD_ROW_BROADCAST_6(BASENAME, BIAS) \ + BASENAME##6 += BIAS; + +#define ADD_ROW_BROADCAST_8(BASENAME, BIAS) \ + ADD_ROW_BROADCAST_7(BASENAME, BIAS) \ + BASENAME##7 += BIAS; + +#define ADD_ROW_BROADCAST_9(BASENAME, BIAS) \ + ADD_ROW_BROADCAST_8(BASENAME, BIAS) \ + BASENAME##8 += BIAS; + +#define ADD_ROW_BROADCAST_10(BASENAME, BIAS) \ + ADD_ROW_BROADCAST_9(BASENAME, BIAS) \ + BASENAME##9 += BIAS; + +#define ADD_ROW_BROADCAST_11(BASENAME, BIAS) \ + ADD_ROW_BROADCAST_10(BASENAME, BIAS) \ + BASENAME##A += BIAS; + +#define ADD_ROW_BROADCAST_12(BASENAME, BIAS) \ + ADD_ROW_BROADCAST_11(BASENAME, BIAS) \ + BASENAME##B += BIAS; + +#define ADD_ROW_BROADCAST_13(BASENAME, BIAS) \ + ADD_ROW_BROADCAST_12(BASENAME, BIAS) \ + BASENAME##C += BIAS; + +#define ADD_ROW_BROADCAST_14(BASENAME, BIAS) \ + ADD_ROW_BROADCAST_13(BASENAME, BIAS) \ + BASENAME##D += BIAS; + +#define ADD_ROW_BROADCAST_15(BASENAME, BIAS) \ + ADD_ROW_BROADCAST_14(BASENAME, BIAS) \ + BASENAME##E += BIAS; + +#define ADD_ROW_BROADCAST_16(BASENAME, BIAS) \ + ADD_ROW_BROADCAST_15(BASENAME, BIAS) \ + BASENAME##F += BIAS; + +// ADD_ROW_n add the variables BIAS to BASENAME##0 to BASENAME##(n-1) +#define ADD_BLOCK_BROADCAST_STR(N, BASENAME, BIAS) ADD_ROW_BROADCAST_##N(BASENAME, BIAS) +/** Add elements stored in variables BIAS##0 ... BIAS##(N-1) to BASENAME##0 ... BASENAME##(N-1) + * Supported cases N=1,2,3..16, for variables BASENAME[0..N] + */ +#define ADD_BLOCK_BROADCAST(N, BASENAME, BIAS) ADD_BLOCK_BROADCAST_STR(N, BASENAME, BIAS) diff --git a/src/core/CL/kernels/CLGEMMMatrixMultiplyReshapedOnlyRHSKernel.cpp b/src/core/CL/kernels/CLGEMMMatrixMultiplyReshapedOnlyRHSKernel.cpp index 24372657f5..58c4cdd2f6 100644 --- a/src/core/CL/kernels/CLGEMMMatrixMultiplyReshapedOnlyRHSKernel.cpp +++ b/src/core/CL/kernels/CLGEMMMatrixMultiplyReshapedOnlyRHSKernel.cpp @@ -50,8 +50,9 @@ namespace { using ElementsProcessed = Steps; -Status validate_arguments(const ITensorInfo *input0, const ITensorInfo *input1, const ITensorInfo *output, float alpha, const GEMMLHSMatrixInfo &lhs_info, const GEMMRHSMatrixInfo &rhs_info, - const GEMMReshapeInfo &gemm_info) +Status validate_arguments(const ITensorInfo *input0, const ITensorInfo *input1, const ITensorInfo *input2, const ITensorInfo *output, float alpha, float beta, const GEMMLHSMatrixInfo &lhs_info, + const GEMMRHSMatrixInfo &rhs_info, + const GEMMReshapeInfo &gemm_info) { ARM_COMPUTE_UNUSED(alpha); ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input0, input1, output); @@ -72,6 +73,22 @@ Status validate_arguments(const ITensorInfo *input0, const ITensorInfo *input1, tensor_shape1.set(0, n); tensor_shape1.set(1, k); + if(input2 != nullptr && std::abs(0.0f - beta) > 0.00001f) + { + const int input2_dim0 = static_cast(input2->dimension(0)); + const int input2_dim1 = static_cast(input2->dimension(1)); + + ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input2, input1); + if(gemm_info.broadcast_bias()) + { + ARM_COMPUTE_RETURN_ERROR_ON_MSG((input2_dim1 != 1 || input2_dim0 != n), "Incorrect dimension of bias matrix which is to be broadcasted"); + } + else + { + ARM_COMPUTE_RETURN_ERROR_ON_MSG((input2_dim0 != n || input2_dim1 != m), "Incorrect dimension of bias matrix"); + } + } + const TensorInfo tensor_info1 = input1->clone()->set_tensor_shape(tensor_shape1); const TensorInfo tensor_info_reshaped1 = input1->clone()->set_tensor_shape(compute_rhs_reshaped_shape(tensor_info1, rhs_info)); @@ -97,7 +114,8 @@ Status validate_arguments(const ITensorInfo *input0, const ITensorInfo *input1, return Status{}; } -std::pair validate_and_configure_window(ITensorInfo *input0, ITensorInfo *input1, ITensorInfo *output, const GEMMLHSMatrixInfo &lhs_info, const GEMMRHSMatrixInfo &rhs_info, +std::pair validate_and_configure_window(ITensorInfo *input0, ITensorInfo *input1, ITensorInfo *input2, ITensorInfo *output, const GEMMLHSMatrixInfo &lhs_info, + const GEMMRHSMatrixInfo &rhs_info, const GEMMReshapeInfo &gemm_info, ElementsProcessed &num_elements_processed) { unsigned int &num_elems_processed_per_iteration_x = num_elements_processed[0]; @@ -152,8 +170,24 @@ std::pair validate_and_configure_window(ITensorInfo *input0, ITe ceil_to_multiple(output->dimension(0), num_elems_processed_per_iteration_x), output->dimension(1) + bottom_pad); - window_changed = update_window_and_padding(win, input0_access, input1_access) || // window used by the execute_window_loop - update_window_and_padding(win_out, output_access); // window used to update the padding requirements of output tensor + if(input2 != nullptr) + { + const int bias_processed_per_iteration_x = num_elems_processed_per_iteration_x; + + const int bias_processed_per_iteration_y = gemm_info.broadcast_bias() ? 1 : num_elems_processed_per_iteration_y; + + AccessWindowStatic input2_access(input2, 0, 0, + ceil_to_multiple(input2->dimension(0), bias_processed_per_iteration_x), + ceil_to_multiple(input2->dimension(1), bias_processed_per_iteration_y)); + + window_changed = update_window_and_padding(win, input0_access, input1_access, input2_access) || // window used by the execute_window_loop + update_window_and_padding(win_out, output_access); // window used to update the padding requirements of output tensor + } + else + { + window_changed = update_window_and_padding(win, input0_access, input1_access) || // window used by the execute_window_loop + update_window_and_padding(win_out, output_access); // window used to update the padding requirements of output tensor + } output_access.set_valid_region(win_out, ValidRegion(Coordinates(), output->tensor_shape())); @@ -169,23 +203,28 @@ std::pair validate_and_configure_window(ITensorInfo *input0, ITe } // namespace CLGEMMMatrixMultiplyReshapedOnlyRHSKernel::CLGEMMMatrixMultiplyReshapedOnlyRHSKernel() - : _input0(nullptr), _input1(nullptr), _output(nullptr), _slide_matrix_b(true), _reinterpret_input_as_3d(false), _reinterpret_output_as_3d(false), _use_dummy_work_items(false) + : _input0(nullptr), _input1(nullptr), _input2(nullptr), _output(nullptr), _slide_matrix_b(true), _reinterpret_input_as_3d(false), _reinterpret_output_as_3d(false), _use_dummy_work_items(false), + _add_bias(false), _broadcast_bias(false) { } -void CLGEMMMatrixMultiplyReshapedOnlyRHSKernel::configure(const ICLTensor *input0, const ICLTensor *input1, ICLTensor *output, float alpha, const GEMMLHSMatrixInfo &lhs_info, +void CLGEMMMatrixMultiplyReshapedOnlyRHSKernel::configure(const ICLTensor *input0, const ICLTensor *input1, const ICLTensor *input2, ICLTensor *output, float alpha, float beta, + const GEMMLHSMatrixInfo &lhs_info, const GEMMRHSMatrixInfo &rhs_info, const GEMMReshapeInfo &gemm_info) { ARM_COMPUTE_ERROR_ON_NULLPTR(input0, input1, output); - ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(input0->info(), input1->info(), output->info(), alpha, lhs_info, rhs_info, gemm_info)); + ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(input0->info(), input1->info(), (input2 != nullptr ? input2->info() : nullptr), output->info(), alpha, beta, lhs_info, rhs_info, gemm_info)); _input0 = input0; _input1 = input1; + _input2 = std::abs(0.0f - beta) > 0.00001f ? input2 : nullptr; _output = output; _reinterpret_input_as_3d = gemm_info.reinterpret_input_as_3d(); _reinterpret_output_as_3d = (gemm_info.depth_output_gemm3d() != 0); _use_dummy_work_items = preferred_dummy_work_items_support(CLKernelLibrary::get().get_device()); + _add_bias = _input2 != nullptr; + _broadcast_bias = gemm_info.broadcast_bias(); // In case both input and output have to be reinterpreted as 3D tensors, // force reinterpret_input_as_3d and reinterpret_output_as_3d to be false. @@ -202,7 +241,7 @@ void CLGEMMMatrixMultiplyReshapedOnlyRHSKernel::configure(const ICLTensor *input ElementsProcessed num_elements_processed{}; // Configure kernel window - auto win_config = validate_and_configure_window(input0->info(), input1->info(), output->info(), lhs_info, rhs_info, gemm_info, num_elements_processed); + auto win_config = validate_and_configure_window(input0->info(), input1->info(), input2 != nullptr ? input2->info() : nullptr, output->info(), lhs_info, rhs_info, gemm_info, num_elements_processed); ARM_COMPUTE_ERROR_THROW_ON(win_config.first); ICLKernel::configure_internal(win_config.second); @@ -210,8 +249,11 @@ void CLGEMMMatrixMultiplyReshapedOnlyRHSKernel::configure(const ICLTensor *input CLBuildOptions build_opts; build_opts.add_option("-DDATA_TYPE=" + get_cl_type_from_data_type(input0->info()->data_type())); build_opts.add_option_if(std::abs(1.0f - alpha) > 0.00001f, "-DALPHA=" + float_to_string_with_full_precision(alpha)); + build_opts.add_option_if(std::abs(0.0f - beta) > 0.00001f && _input2 != nullptr, "-DBETA=" + float_to_string_with_full_precision(beta)); + build_opts.add_option_if(std::abs(1.0f - beta) < 0.00001f, "-DUNIT_BETA"); build_opts.add_option_if(_reinterpret_input_as_3d, "-DREINTERPRET_INPUT_AS_3D"); build_opts.add_option_if(_reinterpret_output_as_3d, "-DREINTERPRET_OUTPUT_AS_3D"); + build_opts.add_option_if(gemm_info.broadcast_bias(), "-DBROADCAST_BIAS"); build_opts.add_option_if(_reinterpret_input_as_3d || _reinterpret_output_as_3d, "-DHEIGHT_GEMM3D=" + support::cpp11::to_string(output->info()->dimension(1))); build_opts.add_option_if(_reinterpret_input_as_3d || _reinterpret_output_as_3d, "-DDEPTH_GEMM3D=" + support::cpp11::to_string(output->info()->dimension(2))); build_opts.add_option_if(!_slide_matrix_b, "-DMATRIX_B_DEPTH=" + support::cpp11::to_string(input1->info()->dimension(2))); @@ -257,13 +299,15 @@ void CLGEMMMatrixMultiplyReshapedOnlyRHSKernel::configure(const ICLTensor *input _config_id += support::cpp11::to_string(rhs_info.interleave); } -Status CLGEMMMatrixMultiplyReshapedOnlyRHSKernel::validate(const ITensorInfo *input0, const ITensorInfo *input1, const ITensorInfo *output, float alpha, const GEMMLHSMatrixInfo &lhs_info, +Status CLGEMMMatrixMultiplyReshapedOnlyRHSKernel::validate(const ITensorInfo *input0, const ITensorInfo *input1, const ITensorInfo *input2, const ITensorInfo *output, float alpha, float beta, + const GEMMLHSMatrixInfo &lhs_info, const GEMMRHSMatrixInfo &rhs_info, const GEMMReshapeInfo &gemm_info) { ElementsProcessed num_elements_processed{}; - ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(input0, input1, output, alpha, lhs_info, rhs_info, gemm_info)); + ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(input0, input1, input2, output, alpha, beta, lhs_info, rhs_info, gemm_info)); ARM_COMPUTE_RETURN_ON_ERROR(validate_and_configure_window(input0->clone().get(), input1->clone().get(), + input2 != nullptr ? input2->clone().get() : nullptr, output->clone().get(), lhs_info, rhs_info, @@ -294,7 +338,15 @@ void CLGEMMMatrixMultiplyReshapedOnlyRHSKernel::run(const Window &window, cl::Co if(_reinterpret_input_as_3d) { // Pass bottom paddings to the kernel if the input has to be reinterpreted as 3D tensor - const unsigned int idx0 = 3 * num_arguments_per_2D_tensor() + 3; + unsigned int idx0; + if(_add_bias) + { + idx0 = 4 * num_arguments_per_2D_tensor() + 4; + } + else + { + idx0 = 3 * num_arguments_per_2D_tensor() + 3; + } const unsigned int total_cross_plane_pad = _input0->info()->padding().top + _input0->info()->padding().bottom; _kernel.setArg(idx0, static_cast(total_cross_plane_pad)); } @@ -302,7 +354,15 @@ void CLGEMMMatrixMultiplyReshapedOnlyRHSKernel::run(const Window &window, cl::Co if(_reinterpret_output_as_3d) { // Pass bottom paddings to the kernel if the output has to be reinterpreted as 3D tensor - const unsigned int idx0 = 3 * num_arguments_per_2D_tensor() + 3 + (_reinterpret_input_as_3d ? 1 : 0); + unsigned int idx0; + if(_add_bias) + { + idx0 = 4 * num_arguments_per_2D_tensor() + 4 + (_reinterpret_input_as_3d ? 1 : 0); + } + else + { + idx0 = 3 * num_arguments_per_2D_tensor() + 3 + (_reinterpret_input_as_3d ? 1 : 0); + } const unsigned int total_cross_plane_pad = _output->info()->padding().top + _output->info()->padding().bottom; _kernel.setArg(idx0, static_cast(total_cross_plane_pad)); } @@ -320,12 +380,20 @@ void CLGEMMMatrixMultiplyReshapedOnlyRHSKernel::run(const Window &window, cl::Co unsigned int idx = 0; add_2D_tensor_argument(idx, _input0, slice); add_2D_tensor_argument(idx, _input1, slice_b); + if(_add_bias) + { + add_2D_tensor_argument(idx, _input2, slice); + } add_2D_tensor_argument(idx, _output, slice); _kernel.setArg(idx++, static_cast(_input0->info()->strides_in_bytes()[2])); _kernel.setArg(idx++, static_cast(_input1->info()->strides_in_bytes()[2])); + if(_add_bias) + { + _kernel.setArg(idx++, static_cast(_input2->info()->strides_in_bytes()[2])); + } _kernel.setArg(idx++, static_cast(_output->info()->strides_in_bytes()[2])); enqueue(queue, *this, slice, lws_hint(), _use_dummy_work_items); } while(window.slide_window_slice_3D(slice)); } -} // namespace arm_compute \ No newline at end of file +} // namespace arm_compute diff --git a/src/runtime/CL/functions/CLGEMM.cpp b/src/runtime/CL/functions/CLGEMM.cpp index 492709f0d0..21a9fce233 100644 --- a/src/runtime/CL/functions/CLGEMM.cpp +++ b/src/runtime/CL/functions/CLGEMM.cpp @@ -242,10 +242,6 @@ void CLGEMM::configure_reshaped_v2(const ICLTensor *a, const ICLTensor *b, const void CLGEMM::configure_reshaped_only_rhs(const ICLTensor *a, const ICLTensor *b, const ICLTensor *c, ICLTensor *output, float alpha, float beta, const GEMMInfo &gemm_info) { - ARM_COMPUTE_ERROR_ON(c != nullptr); - ARM_COMPUTE_UNUSED(beta); - ARM_COMPUTE_UNUSED(c); - DataType data_type = a->info()->data_type(); bool reinterpret_input_as_3d = gemm_info.reinterpret_input_as_3d(); const unsigned int m = reinterpret_input_as_3d ? (a->info()->dimension(1) * a->info()->dimension(2)) : a->info()->dimension(1); @@ -254,11 +250,12 @@ void CLGEMM::configure_reshaped_only_rhs(const ICLTensor *a, const ICLTensor *b, const unsigned int batch_size = reinterpret_input_as_3d ? a->info()->dimension(3) : a->info()->dimension(2); const int depth_output_gemm3d = gemm_info.depth_output_gemm3d(); const GPUTarget gpu_target = CLScheduler::get().target(); + bool broadcast_bias = gemm_info.broadcast_bias(); // Set the target for the kernels _mm_kernel.set_target(gpu_target); - GEMMReshapeInfo reshape_info(m, n, k, 1, 1, depth_output_gemm3d, reinterpret_input_as_3d); + GEMMReshapeInfo reshape_info(m, n, k, 1, 1, depth_output_gemm3d, reinterpret_input_as_3d, broadcast_bias); // Manage intermediate buffers if(!_reshape_b_only_on_first_run) @@ -279,7 +276,7 @@ void CLGEMM::configure_reshaped_only_rhs(const ICLTensor *a, const ICLTensor *b, _reshape_rhs_kernel.configure(b, &_tmp_b, rhs_info); // Configure and tune matrix multiply kernel - _mm_reshaped_only_rhs_kernel.configure(a, &_tmp_b, output, alpha, lhs_info, rhs_info, reshape_info); + _mm_reshaped_only_rhs_kernel.configure(a, &_tmp_b, c, output, alpha, beta, lhs_info, rhs_info, reshape_info); if(!_reshape_b_only_on_first_run) { @@ -426,7 +423,6 @@ Status CLGEMM::validate_reshaped_v2(const ITensorInfo *a, const ITensorInfo *b, // Validate matrix addition kernel ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMMatrixAdditionKernel::validate(c, output, beta)); } - return Status{}; } @@ -438,17 +434,16 @@ Status CLGEMM::validate_reshaped_only_rhs(const ITensorInfo *a, const ITensorInf TensorInfo tmp_b_info{}; // Get the GPU target - const GPUTarget gpu_target = CLScheduler::get().target(); - const DataType data_type = a->data_type(); - bool reinterpret_input_as_3d = gemm_info.reinterpret_input_as_3d(); - const unsigned int m = reinterpret_input_as_3d ? (a->dimension(1) * a->dimension(2)) : a->dimension(1); - const unsigned int n = b->dimension(0); - const unsigned int k = a->dimension(0); - const unsigned int batch_size = reinterpret_input_as_3d ? a->dimension(3) : a->dimension(2); - const int depth_output_gemm3d = gemm_info.depth_output_gemm3d(); - const bool add_c = (beta != 0.f && c != nullptr); - - const GEMMReshapeInfo reshape_info = GEMMReshapeInfo(m, n, k, 1, 1, depth_output_gemm3d, reinterpret_input_as_3d); + const GPUTarget gpu_target = CLScheduler::get().target(); + const DataType data_type = a->data_type(); + bool reinterpret_input_as_3d = gemm_info.reinterpret_input_as_3d(); + const unsigned int m = reinterpret_input_as_3d ? (a->dimension(1) * a->dimension(2)) : a->dimension(1); + const unsigned int n = b->dimension(0); + const unsigned int k = a->dimension(0); + const unsigned int batch_size = reinterpret_input_as_3d ? a->dimension(3) : a->dimension(2); + const int depth_output_gemm3d = gemm_info.depth_output_gemm3d(); + const bool broadcast_bias = gemm_info.broadcast_bias(); + const GEMMReshapeInfo reshape_info = GEMMReshapeInfo(m, n, k, 1, 1, depth_output_gemm3d, reinterpret_input_as_3d, broadcast_bias); GEMMLHSMatrixInfo lhs_info; GEMMRHSMatrixInfo rhs_info; @@ -464,13 +459,7 @@ Status CLGEMM::validate_reshaped_only_rhs(const ITensorInfo *a, const ITensorInf ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMReshapeRHSMatrixKernel::validate(b, &tmp_b_info, rhs_info)); // Validate matrix multiply - ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMMatrixMultiplyReshapedOnlyRHSKernel::validate(a, &tmp_b_info, output, alpha, lhs_info, rhs_info, reshape_info)); - - if(add_c) - { - // Validate matrix addition kernel - ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMMatrixAdditionKernel::validate(c, output, beta)); - } + ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMMatrixMultiplyReshapedOnlyRHSKernel::validate(a, &tmp_b_info, c, output, alpha, beta, lhs_info, rhs_info, reshape_info)); return Status{}; } @@ -497,10 +486,10 @@ void CLGEMM::configure(const ICLTensor *a, const ICLTensor *b, const ICLTensor * // Select GEMMType _gemm_type = select_gemm_type(m, n, k, a->info()->data_type(), _reshape_b_only_on_first_run, gpu_target); - const bool is_gemm_v2 = (_gemm_type == GEMMType::RESHAPED_V2) || (_gemm_type == GEMMType::RESHAPED_ONLY_RHS); - const bool add_c = (beta != 0.f && c != nullptr); - const bool is_beta_one = std::abs(1.0f - beta) < 0.00001f; - const bool fuse_add = is_beta_one && (c != nullptr && c->info()->num_dimensions() == 1) && !is_gemm_v2; + const bool is_gemm_reshaped_only_rhs = _gemm_type == GEMMType::RESHAPED_ONLY_RHS; + const bool add_c = (beta != 0.f && c != nullptr); + const bool is_beta_one = std::abs(1.0f - beta) < 0.00001f; + const bool fuse_add = (is_beta_one && (c != nullptr && c->info()->num_dimensions() == 1)) || is_gemm_reshaped_only_rhs; switch(_gemm_type) { diff --git a/tests/validation/CL/GEMMMatrixMultiplyReshapedOnlyRHS.cpp b/tests/validation/CL/GEMMMatrixMultiplyReshapedOnlyRHS.cpp index 83051d2efe..23ae004912 100644 --- a/tests/validation/CL/GEMMMatrixMultiplyReshapedOnlyRHS.cpp +++ b/tests/validation/CL/GEMMMatrixMultiplyReshapedOnlyRHS.cpp @@ -72,6 +72,9 @@ constexpr float tolerance_num_f16 = 0.02f; /** Alpha values to test - Precommit */ const auto a_values = framework::dataset::make("alpha", {1.0f, -0.75f} ); +/** Beta values to test - Precommit */ +const auto beta_values = framework::dataset::make("beta", {-0.75f, 0.0f} ); + /** M values to test */ const auto m_values = framework::dataset::make("M", 37); @@ -120,8 +123,11 @@ const auto i_values_rhs = framework::dataset::make("interleave_rhs", { true, fal /** Transpose values to test with RHS matrix */ const auto t_values_rhs = framework::dataset::make("transpose_rhs", { true, false }); +/**Broadcast bias from vector to matrix */ +const auto broadcast_bias_values = framework::dataset::make("broadcast_bias", {false, true} ); + /** Configuration test */ -void validate_configuration(unsigned int m_value, unsigned int n_value, unsigned int k_value, unsigned int b_value, unsigned int m0_value, unsigned int n0_value, unsigned int k0_value, unsigned int h0_value, bool i_value_rhs, bool t_value_rhs, DataType data_type) +void validate_configuration(unsigned int m_value, unsigned int n_value, unsigned int k_value, unsigned int b_value, unsigned int m0_value, unsigned int n0_value, unsigned int k0_value, unsigned int h0_value, bool i_value_rhs, bool t_value_rhs, bool broadcast_bias, DataType data_type) { const unsigned int M = m_value; const unsigned int N = n_value; @@ -138,7 +144,7 @@ void validate_configuration(unsigned int m_value, unsigned int n_value, unsigned rhs_info.interleave = i_value_rhs; rhs_info.transpose = t_value_rhs; - GEMMReshapeInfo gemm_info(M, N, K); + GEMMReshapeInfo gemm_info(M, N, K, false, false, 0, false, broadcast_bias); const TensorShape lhs_shape(K, M, b_value); const TensorShape rhs_shape(N, K, b_value); @@ -154,13 +160,22 @@ void validate_configuration(unsigned int m_value, unsigned int n_value, unsigned CLTensor rhs_reshaped = create_tensor(rhs_shape_reshaped, data_type); CLTensor dst = create_tensor(dst_shape, data_type); + TensorShape bias_shape = dst_shape; + if (broadcast_bias) + { + bias_shape[1] = 1; + bias_shape[2] = 1; + } + CLTensor bias = create_tensor(bias_shape, data_type); + ARM_COMPUTE_EXPECT(lhs.info()->is_resizable(), framework::LogLevel::ERRORS); ARM_COMPUTE_EXPECT(rhs_reshaped.info()->is_resizable(), framework::LogLevel::ERRORS); + ARM_COMPUTE_EXPECT(bias.info()->is_resizable(), framework::LogLevel::ERRORS); ARM_COMPUTE_EXPECT(dst.info()->is_resizable(), framework::LogLevel::ERRORS); // Create and configure function CLGEMMMatrixMultiplyReshapedOnlyRHS gemm; - gemm.configure(&lhs, &rhs_reshaped, &dst, 1.0f, lhs_info, rhs_info, gemm_info); + gemm.configure(&lhs, &rhs_reshaped, &bias, &dst, 1.0f, 1.0f, lhs_info, rhs_info, gemm_info); } } // namespace @@ -168,7 +183,7 @@ TEST_SUITE(CL) TEST_SUITE(GEMMMatrixMultiplyReshapedOnlyRHS) TEST_SUITE(Float) TEST_SUITE(FP32) -DATA_TEST_CASE(Configuration, framework::DatasetMode::ALL, combine(combine(combine(combine(combine(combine(combine(combine(combine( +DATA_TEST_CASE(Configuration, framework::DatasetMode::ALL, combine(combine(combine(combine(combine(combine(combine(combine(combine(combine( m_values, n_values), k_values), @@ -179,13 +194,14 @@ DATA_TEST_CASE(Configuration, framework::DatasetMode::ALL, combine(combine(combi h0_values_precommit), i_values_rhs), t_values_rhs), -m_value, n_value, k_value, b_value, m0_value, n0_value, k0_value, h0_value, i_value_rhs, t_value_rhs) + broadcast_bias_values), +m_value, n_value, k_value, b_value, m0_value, n0_value, k0_value, h0_value, i_value_rhs, t_value_rhs, broadcast_bias) { - validate_configuration(m_value, n_value, k_value, b_value, m0_value, n0_value, k0_value, h0_value, i_value_rhs, t_value_rhs, DataType::F32); + validate_configuration(m_value, n_value, k_value, b_value, m0_value, n0_value, k0_value, h0_value, i_value_rhs, t_value_rhs, broadcast_bias, DataType::F32); } FIXTURE_DATA_TEST_CASE(RunSmall, CLGEMMMatrixMultiplyReshapedOnlyRHSFixture, framework::DatasetMode::ALL, - combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine( + combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine( m_values, n_values), k_values), @@ -197,14 +213,16 @@ FIXTURE_DATA_TEST_CASE(RunSmall, CLGEMMMatrixMultiplyReshapedOnlyRHSFixture, framework::DatasetMode::NIGHTLY, - combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine( + combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine( m_values, n_values), k_values), @@ -216,14 +234,16 @@ FIXTURE_DATA_TEST_CASE(RunLarge, CLGEMMMatrixMultiplyReshapedOnlyRHSFixture, framework::DatasetMode::ALL, - combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine( + combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine( m_w_values, m_h_values), n_values), @@ -236,14 +256,15 @@ FIXTURE_DATA_TEST_CASE(RunSmall3D, CLGEMMMatrixMultiplyReshapedOnlyRHS3DFixture< i_values_rhs), t_values_rhs), framework::dataset::make("DataType", DataType::F32)), - a_values)) + a_values), + b_values)) { // Validate output validate(CLAccessor(_target), _reference, rel_tolerance_f32, 0.f, abs_tolerance_f32); } FIXTURE_DATA_TEST_CASE(RunLarge3D, CLGEMMMatrixMultiplyReshapedOnlyRHS3DFixture, framework::DatasetMode::NIGHTLY, - combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine( + combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine( m_w_values, m_h_values), n_values), @@ -256,7 +277,8 @@ FIXTURE_DATA_TEST_CASE(RunLarge3D, CLGEMMMatrixMultiplyReshapedOnlyRHS3DFixture< i_values_rhs), t_values_rhs), framework::dataset::make("DataType", DataType::F32)), - a_values)) + a_values), + b_values)) { // Validate output validate(CLAccessor(_target), _reference, rel_tolerance_f32, 0.f, abs_tolerance_f32); @@ -265,7 +287,7 @@ TEST_SUITE_END() // FP32 TEST_SUITE(FP16) FIXTURE_DATA_TEST_CASE(RunSmall, CLGEMMMatrixMultiplyReshapedOnlyRHSFixture, framework::DatasetMode::ALL, - combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine( + combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine( m_values, n_values), k_values), @@ -277,14 +299,16 @@ FIXTURE_DATA_TEST_CASE(RunSmall, CLGEMMMatrixMultiplyReshapedOnlyRHSFixture, framework::DatasetMode::NIGHTLY, - combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine( + combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine( m_values, n_values), k_values), @@ -296,14 +320,16 @@ FIXTURE_DATA_TEST_CASE(RunLarge, CLGEMMMatrixMultiplyReshapedOnlyRHSFixture, framework::DatasetMode::ALL, - combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine( + combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine( m_w_values, m_h_values), n_values), @@ -316,14 +342,15 @@ FIXTURE_DATA_TEST_CASE(RunSmall3D, CLGEMMMatrixMultiplyReshapedOnlyRHS3DFixture< i_values_rhs), t_values_rhs), framework::dataset::make("DataType", DataType::F16)), - a_values)) + a_values), + b_values)) { // Validate output validate(CLAccessor(_target), _reference, rel_tolerance_f16, tolerance_num_f16); } FIXTURE_DATA_TEST_CASE(RunLarge3D, CLGEMMMatrixMultiplyReshapedOnlyRHS3DFixture, framework::DatasetMode::NIGHTLY, - combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine( + combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine( m_w_values, m_h_values), n_values), @@ -336,7 +363,8 @@ FIXTURE_DATA_TEST_CASE(RunLarge3D, CLGEMMMatrixMultiplyReshapedOnlyRHS3DFixture< i_values_rhs), t_values_rhs), framework::dataset::make("DataType", DataType::F16)), - a_values)) + a_values), + b_values)) { // Validate output validate(CLAccessor(_target), _reference, rel_tolerance_f16, tolerance_num_f16); @@ -347,4 +375,4 @@ TEST_SUITE_END() // GEMMMatrixMulipltyReshapedOnlyRHS TEST_SUITE_END() // CL } // namespace validation } // namespace test -} // namespace arm_compute \ No newline at end of file +} // namespace arm_compute diff --git a/tests/validation/fixtures/GEMMFixture.h b/tests/validation/fixtures/GEMMFixture.h index b7976104aa..34f9bd848c 100644 --- a/tests/validation/fixtures/GEMMFixture.h +++ b/tests/validation/fixtures/GEMMFixture.h @@ -390,7 +390,7 @@ class GEMMMatrixMultiplyReshapedOnlyRHSValidationFixture : public framework::Fix public: template void setup(unsigned int m, unsigned int n, unsigned int k, unsigned int batch_size, unsigned int m0, unsigned int n0, unsigned int k0, unsigned int h0, - bool interleave_rhs, bool transpose_rhs, DataType data_type, float alpha) + bool interleave_rhs, bool transpose_rhs, DataType data_type, float alpha, float beta, bool broadcast_bias) { GEMMLHSMatrixInfo lhs_info; lhs_info.m0 = m0; @@ -407,8 +407,18 @@ public: const TensorShape lhs_shape(k, m, batch_size); const TensorShape rhs_shape(n, k, batch_size); - _target = compute_target(lhs_shape, rhs_shape, lhs_info, rhs_info, data_type, alpha); - _reference = compute_reference(lhs_shape, rhs_shape, data_type, alpha); + TensorShape bias_shape; + if(broadcast_bias) + { + bias_shape = TensorShape(n, 1, 1); + } + else + { + bias_shape = TensorShape(n, m, batch_size); + } + + _target = compute_target(lhs_shape, rhs_shape, bias_shape, lhs_info, rhs_info, data_type, alpha, beta, broadcast_bias); + _reference = compute_reference(lhs_shape, rhs_shape, bias_shape, data_type, alpha, beta, broadcast_bias); } protected: @@ -423,11 +433,13 @@ protected: library->fill_borders_with_garbage(tensor, distribution_inf, i); } - TensorType compute_target(const TensorShape &lhs_shape, const TensorShape &rhs_shape, const GEMMLHSMatrixInfo &lhs_info, const GEMMRHSMatrixInfo &rhs_info, DataType data_type, float alpha) + TensorType compute_target(const TensorShape &lhs_shape, const TensorShape &rhs_shape, const TensorShape &bias_shape, const GEMMLHSMatrixInfo &lhs_info, const GEMMRHSMatrixInfo &rhs_info, + DataType data_type, float alpha, float beta, bool broadcast_bias) { // Create tensors - TensorType lhs = create_tensor(lhs_shape, data_type, 1); - TensorType rhs = create_tensor(rhs_shape, data_type, 1); + TensorType lhs = create_tensor(lhs_shape, data_type, 1); + TensorType rhs = create_tensor(rhs_shape, data_type, 1); + TensorType bias = create_tensor(bias_shape, data_type, 1); TensorType rhs_reshaped; TensorType dst; @@ -441,7 +453,7 @@ protected: ReshapeRHSFunctionType reshape_rhs; GEMMFunctionType gemm; reshape_rhs.configure(&rhs, &rhs_reshaped, rhs_info); - gemm.configure(&lhs, &rhs_reshaped, &dst, alpha, lhs_info, rhs_info, GEMMReshapeInfo(M, N, K)); + gemm.configure(&lhs, &rhs_reshaped, &bias, &dst, alpha, beta, lhs_info, rhs_info, GEMMReshapeInfo(M, N, K, 1, 1, 0, false, broadcast_bias)); ARM_COMPUTE_EXPECT(lhs.info()->is_resizable(), framework::LogLevel::ERRORS); ARM_COMPUTE_EXPECT(rhs.info()->is_resizable(), framework::LogLevel::ERRORS); @@ -450,6 +462,7 @@ protected: lhs.allocator()->allocate(); rhs.allocator()->allocate(); rhs_reshaped.allocator()->allocate(); + bias.allocator()->allocate(); dst.allocator()->allocate(); ARM_COMPUTE_EXPECT(!lhs.info()->is_resizable(), framework::LogLevel::ERRORS); @@ -460,6 +473,7 @@ protected: // Fill tensors fill(AccessorType(lhs), 0); fill(AccessorType(rhs), 1); + fill(AccessorType(bias), 2); // Compute GEMM reshape_rhs.run(); @@ -468,7 +482,7 @@ protected: return dst; } - SimpleTensor compute_reference(const TensorShape &lhs_shape, const TensorShape &rhs_shape, DataType data_type, float alpha) + SimpleTensor compute_reference(const TensorShape &lhs_shape, const TensorShape &rhs_shape, const TensorShape &bias_shape, DataType data_type, float alpha, float beta, bool broadcast_bias) { TensorShape dst_shape = lhs_shape; dst_shape[0] = rhs_shape[0]; @@ -477,13 +491,31 @@ protected: // Create reference SimpleTensor lhs{ lhs_shape, data_type, 1 }; SimpleTensor rhs{ rhs_shape, data_type, 1 }; - SimpleTensor c{ dst_shape, data_type, 1 }; + SimpleTensor bias{ dst_shape, data_type, 1 }; + + const int n = rhs_shape[0]; + const int m = lhs_shape[1]; + const int batch_size = lhs_shape[2]; // Fill reference fill(lhs, 0); fill(rhs, 1); - return reference::gemm(lhs, rhs, c, alpha, 0.0f); + if(broadcast_bias) + { + SimpleTensor tmp{ bias_shape, data_type, 1 }; + fill(tmp, 2); + for(int i = 0; i < m * batch_size; i++) + { + memcpy(bias.data() + i * n, tmp.data(), n * sizeof(T)); + } + } + else + { + fill(bias, 2); + } + + return (reference::gemm(lhs, rhs, bias, alpha, beta)); } TensorType _target{}; @@ -590,7 +622,7 @@ class GEMMMatrixMultiplyReshapedOnlyRHS3DValidationFixture : public framework::F public: template void setup(unsigned int m_w, unsigned int m_h, unsigned int n, unsigned int k, unsigned int batch_size, unsigned int m0, unsigned int n0, unsigned int k0, unsigned int h0, - bool interleave_rhs, bool transpose_rhs, DataType data_type, float alpha) + bool interleave_rhs, bool transpose_rhs, DataType data_type, float alpha, float beta) { GEMMLHSMatrixInfo lhs_info; lhs_info.m0 = m0; @@ -609,9 +641,10 @@ public: // Set the tensor shapes for LHS and RHS matrices const TensorShape lhs_shape(k, m, batch_size); const TensorShape rhs_shape(n, k, batch_size); + const TensorShape bias_shape(n, 1, 1); - _target = compute_target(lhs_shape, rhs_shape, lhs_info, rhs_info, data_type, alpha, m_h); - _reference = compute_reference(lhs_shape, rhs_shape, data_type, alpha, m_h); + _target = compute_target(lhs_shape, rhs_shape, bias_shape, lhs_info, rhs_info, data_type, alpha, beta, m_h); + _reference = compute_reference(lhs_shape, rhs_shape, bias_shape, data_type, alpha, beta, m_h); } protected: @@ -622,12 +655,14 @@ protected: library->fill(tensor, distribution, i); } - TensorType compute_target(const TensorShape &lhs_shape, const TensorShape &rhs_shape, const GEMMLHSMatrixInfo &lhs_info, const GEMMRHSMatrixInfo &rhs_info, DataType data_type, float alpha, + TensorType compute_target(const TensorShape &lhs_shape, const TensorShape &rhs_shape, const TensorShape &bias_shape, const GEMMLHSMatrixInfo &lhs_info, const GEMMRHSMatrixInfo &rhs_info, + DataType data_type, float alpha, float beta, unsigned int m_h) { // Create tensors - TensorType lhs = create_tensor(lhs_shape, data_type, 1); - TensorType rhs = create_tensor(rhs_shape, data_type, 1); + TensorType lhs = create_tensor(lhs_shape, data_type, 1); + TensorType rhs = create_tensor(rhs_shape, data_type, 1); + TensorType bias = create_tensor(bias_shape, data_type, 1); TensorType rhs_reshaped; TensorType dst; @@ -641,7 +676,7 @@ protected: ReshapeRHSFunctionType reshape_rhs; GEMMFunctionType gemm; reshape_rhs.configure(&rhs, &rhs_reshaped, rhs_info); - gemm.configure(&lhs, &rhs_reshaped, &dst, alpha, lhs_info, rhs_info, GEMMReshapeInfo(M, N, K, 1, 1, m_h)); + gemm.configure(&lhs, &rhs_reshaped, &bias, &dst, alpha, beta, lhs_info, rhs_info, GEMMReshapeInfo(M, N, K, 1, 1, m_h, false, true)); ARM_COMPUTE_EXPECT(lhs.info()->is_resizable(), framework::LogLevel::ERRORS); ARM_COMPUTE_EXPECT(rhs.info()->is_resizable(), framework::LogLevel::ERRORS); @@ -650,6 +685,7 @@ protected: lhs.allocator()->allocate(); rhs.allocator()->allocate(); rhs_reshaped.allocator()->allocate(); + bias.allocator()->allocate(); dst.allocator()->allocate(); ARM_COMPUTE_EXPECT(!lhs.info()->is_resizable(), framework::LogLevel::ERRORS); @@ -660,6 +696,7 @@ protected: // Fill tensors fill(AccessorType(lhs), 0); fill(AccessorType(rhs), 1); + fill(AccessorType(bias), 2); // Compute GEMM reshape_rhs.run(); @@ -668,7 +705,7 @@ protected: return dst; } - SimpleTensor compute_reference(const TensorShape &lhs_shape, const TensorShape &rhs_shape, DataType data_type, float alpha, unsigned int m_h) + SimpleTensor compute_reference(const TensorShape &lhs_shape, const TensorShape &rhs_shape, const TensorShape &bias_shape, DataType data_type, float alpha, float beta, unsigned int m_h) { TensorShape dst_shape = lhs_shape; dst_shape.set(0, rhs_shape[0]); @@ -679,13 +716,24 @@ protected: // Create reference SimpleTensor lhs{ lhs_shape, data_type, 1 }; SimpleTensor rhs{ rhs_shape, data_type, 1 }; - SimpleTensor c{ dst_shape, data_type, 1 }; + SimpleTensor bias{ dst_shape, data_type, 1 }; + + const int n = rhs_shape[0]; + const int m = lhs_shape[1]; + const int batch_size = lhs_shape[2]; // Fill reference fill(lhs, 0); fill(rhs, 1); - return reference::gemm(lhs, rhs, c, alpha, 0.0f); + SimpleTensor tmp{ bias_shape, data_type, 1 }; + fill(tmp, 2); + for(int i = 0; i < m * batch_size; i++) + { + memcpy(bias.data() + i * n, tmp.data(), n * sizeof(T)); + } + + return reference::gemm(lhs, rhs, bias, alpha, beta); } TensorType _target{}; -- cgit v1.2.1