From d1f54767fc9d6398a5eea38e639dd0ce3df8e5d8 Mon Sep 17 00:00:00 2001 From: Gian Marco Iodice Date: Fri, 19 Jul 2019 09:54:47 +0100 Subject: COMPMID-1979: Fuse Activation Function in CLGEMM - part 3 Fused beta*bias in the old cl gemm kernels Fused activation function in the old cl gemm kernels Change-Id: I695fb9189e6d4792010abd256784624982d17d79 Signed-off-by: Gian Marco Iodice Reviewed-on: https://review.mlplatform.org/c/1587 Reviewed-by: Giuseppe Rossini Comments-Addressed: Arm Jenkins Tested-by: Arm Jenkins --- .../core/CL/kernels/CLGEMMMatrixMultiplyKernel.h | 18 +- src/core/CL/cl_kernels/gemm.cl | 2181 ++++++++++---------- src/core/CL/kernels/CLGEMMMatrixMultiplyKernel.cpp | 133 +- tests/CL/Helper.h | 13 + tests/validation/CL/GEMMMatrixMultiply.cpp | 344 +++ .../CL/GEMMMatrixMultiplyInterleavedTransposed.cpp | 404 ++++ tests/validation/CL/GEMMMatrixMultiplyReshaped.cpp | 2 +- tests/validation/fixtures/GEMMFixture.h | 500 +++++ 8 files changed, 2470 insertions(+), 1125 deletions(-) create mode 100644 tests/validation/CL/GEMMMatrixMultiply.cpp create mode 100644 tests/validation/CL/GEMMMatrixMultiplyInterleavedTransposed.cpp diff --git a/arm_compute/core/CL/kernels/CLGEMMMatrixMultiplyKernel.h b/arm_compute/core/CL/kernels/CLGEMMMatrixMultiplyKernel.h index 724a7d67e6..8e6e07973c 100644 --- a/arm_compute/core/CL/kernels/CLGEMMMatrixMultiplyKernel.h +++ b/arm_compute/core/CL/kernels/CLGEMMMatrixMultiplyKernel.h @@ -30,13 +30,12 @@ namespace arm_compute { class ICLTensor; -/** OpenCL kernel to multiply two input matrices "A" and "B" and add a vector "C" if provided. All elements of the output matrix will be multiplied by alpha. In case vector C is passed, it will be added to the previous result (a broadcast addition will be performed). +/** OpenCL kernel to multiply two input matrices "A" and "B" and add a matrix "C" if provided. All elements of the output matrix will be multiplied by alpha. In case matrix C is passed, it will be added to the previous result. + * For the matrix C, the broadcast addition is supported if the flag "broadcast_bias" is set in the GEMMReshapeInfo object * * @note If the input tensors @p input0 and @p input1 have been reshaped respectively with @ref CLGEMMReshapeLHSMatrixKernel" and @ref CLGEMMReshapeRHSMatrixKernel, * the flag @p is_interleaved_transposed must be set to true * - * @attention Vector C (@p input2) must be 1D. A broadcast addition is performed. - * * @attention @p input1 tensor must have at least 2 dimensions (matrix) * */ @@ -57,22 +56,23 @@ public: * * @param[in] input0 Input tensor containing the Matrix A. Data types supported: F16/F32 * @param[in] input1 Input tensor containing the Matrix B. Data type supported: same as @p input0 - * @param[in] input2 Input tensor containing the Vector C. Can be nullptr. Data type supported: same as @p input0 + * @param[in] input2 Input tensor containing the Matrix C (bias). Can be nullptr. Data type supported: same as @p input0 * @param[out] output Output tensor to store the result of matrix multiplication. Data type supported: same as @p input0 * @param[in] alpha Weight of the matrix product * @param[in] beta (Optional) Weight of vector C. Default value is 0. Only beta = 1 is currently supported. * @param[in] is_interleaved_transposed (Optional) True if input0 and input1 have been reshaped respectively using @ref CLGEMMReshapeLHSMatrixKernel and @ref CLGEMMReshapeRHSMatrixKernel * @param[in] reshape_info (Optional) GEMM reshape info. 
If is_interleaved_transposed = true, this object must contain the information to understand how the matrix A and matrix B have been reshaped * @param[in] fp_mixed_precision (Optional) Use wider accumulators (32 bit instead of 16 for FP16) to improve accuracy + * @param[in] activation_info (Optional) Activation to apply after the matrix multiplication * */ void configure(const ICLTensor *input0, const ICLTensor *input1, const ICLTensor *input2, ICLTensor *output, float alpha, float beta = 0.f, - bool is_interleaved_transposed = true, const GEMMReshapeInfo &reshape_info = GEMMReshapeInfo(), bool fp_mixed_precision = false); + bool is_interleaved_transposed = true, const GEMMReshapeInfo &reshape_info = GEMMReshapeInfo(), bool fp_mixed_precision = false, const ActivationLayerInfo &activation_info = ActivationLayerInfo()); /** Static function to check if given info will lead to a valid configuration of @ref CLGEMMMatrixMultiplyKernel * * @param[in] input0 Input tensor containing the Matrix A info. Data types supported: F16/F32 * @param[in] input1 Input tensor containing the Matrix B info. Data type supported: same as @p input0 - * @param[in] input2 Input tensor containing the Vector C info. Can be nullptr. Data type supported: same as @p input0 + * @param[in] input2 Input tensor containing the Matrix C (bias) info. Can be nullptr. Data type supported: same as @p input0 * @param[in] output Output tensor to store the result of matrix multiplication. Data type supported: same as @p input0 * @param[in] alpha Weight of the matrix product * @param[in] beta Weight of vector C. Default value is 0. Only beta = 1 is currently supported. @@ -80,11 +80,12 @@ public: * @param[in] reshape_info GEMM reshape info. If is_interleaved_transposed = true, this object must contain the information to understand how the matrix A and matrix B have been reshaped * @param[in] gpu_target GPU Target * @param[in] fp_mixed_precision (Optional) Use wider accumulators (32 bit instead of 16 for FP16) to improve accuracy + * @param[in] activation_info (Optional) Activation to apply after the matrix multiplication * * @return a status */ static Status validate(const ITensorInfo *input0, const ITensorInfo *input1, const ITensorInfo *input2, const ITensorInfo *output, float alpha, float beta, - bool is_interleaved_transposed, const GEMMReshapeInfo &reshape_info, GPUTarget gpu_target, bool fp_mixed_precision = false); + bool is_interleaved_transposed, const GEMMReshapeInfo &reshape_info, GPUTarget gpu_target, bool fp_mixed_precision = false, const ActivationLayerInfo &activation_info = ActivationLayerInfo()); // Inherited methods overridden: void run(const Window &window, cl::CommandQueue &queue) override; @@ -97,7 +98,8 @@ public: bool _slide_matrix_b; bool _reinterpret_input_as_3d; bool _reinterpret_output_as_3d; - bool _has_vec_c; + bool _add_bias; + bool _broadcast_bias; }; } // namespace arm_compute #endif /* __ARM_COMPUTE_CLGEMMMATRIXMULTIPLYKERNEL_H__ */ diff --git a/src/core/CL/cl_kernels/gemm.cl b/src/core/CL/cl_kernels/gemm.cl index 213075df07..8d638bc6bb 100644 --- a/src/core/CL/cl_kernels/gemm.cl +++ b/src/core/CL/cl_kernels/gemm.cl @@ -46,15 +46,15 @@ /** This OpenCL kernel reshapes the lhs input matrix. The kernel splits the input matrix in blocks of size M0xK0 and stores each one (not transposed) in * the output matrix unrolling the values. * - * @note The data type must be passed at compile time using -DDATA_TYPE (i.e. 
-DDATA_TYPE=float) - * @note The width of the input tensor must be passed at compile time using -DSRC_WIDTH (i.e. -DSRC_WIDTH=16) - * @note The block's dimensions (M0 and K0) must be passed at compile time using -DM0 and -DK0 (i.e. -DM0=2, -DK0=2). - * @note The number of M0xK0 vertical blocks to store on the same output row must be passed at compile time using -DV0 (i.e. -DV0=2) + * @note The data type must be passed at compile time using -DDATA_TYPE (e.g. -DDATA_TYPE=float) + * @note The width of the input tensor must be passed at compile time using -DSRC_WIDTH (e.g. -DSRC_WIDTH=16) + * @note The block's dimensions (M0 and K0) must be passed at compile time using -DM0 and -DK0 (e.g. -DM0=2, -DK0=2). + * @note The number of M0xK0 vertical blocks to store on the same output row must be passed at compile time using -DV0 (e.g. -DV0=2) * @note Only the following values for M0, K0 and V0 are supported: * M0: 2,3,4,5,6,7,8 * K0: 2,3,4,8,16 * V0: greater than 0 - * @note In case the input has to be reinterpreted as a 3D tensor (i.e. input of convolution layer 1x1), the following information must be passed at compile time: + * @note In case the input has to be reinterpreted as a 3D tensor (e.g. input of convolution layer 1x1), the following information must be passed at compile time: * -# REINTERPRET_INPUT_AS_3D: To reinterpret the input as 3D * -# HEIGHT_GEMM3D: The height of the input in case it has to be reinterpreted as a 3D tensor. * -# DEPTH_GEMM3D: The depth of the input in case it has to be reinterpreted as a 3D tensor @@ -246,15 +246,15 @@ __kernel void gemm_reshape_lhs_matrix_nt(TENSOR3D_DECLARATION(src), /** This OpenCL kernel reshapes the lhs input matrix. The kernel splits the input matrix in blocks of size M0xK0 and stores each one (transposed) in * the output matrix unrolling the values. * - * @note The data type must be passed at compile time using -DDATA_TYPE (i.e. -DDATA_TYPE=float) - * @note The width of the input tensor must be passed at compile time using -DSRC_WIDTH (i.e. -DSRC_WIDTH=16) - * @note The block's dimensions (M0 and K0) must be passed at compile time using -DM0 and -DK0 (i.e. -DM0=2, -DK0=2). - * @note The number of M0xK0 vertical blocks to store on the same output row must be passed at compile time using -DV0 (i.e. -DV0=2) + * @note The data type must be passed at compile time using -DDATA_TYPE (e.g. -DDATA_TYPE=float) + * @note The width of the input tensor must be passed at compile time using -DSRC_WIDTH (e.g. -DSRC_WIDTH=16) + * @note The block's dimensions (M0 and K0) must be passed at compile time using -DM0 and -DK0 (e.g. -DM0=2, -DK0=2). + * @note The number of M0xK0 vertical blocks to store on the same output row must be passed at compile time using -DV0 (e.g. -DV0=2) * @note Only the following values for M0, K0 and V0 are supported: * M0: 2,3,4,5,6,7,8 * K0: 2,3,4,8,16 * V0: greater than 0 - * @note In case the input has to be reinterpreted as a 3D tensor (i.e. input of convolution layer 1x1), the following information must be passed at compile time: + * @note In case the input has to be reinterpreted as a 3D tensor (e.g. input of convolution layer 1x1), the following information must be passed at compile time: * -# REINTERPRET_INPUT_AS_3D: To reinterpret the input as 3D * -# HEIGHT_GEMM3D: The height of the input in case it has to be reinterpreted as a 3D tensor. 
* -# DEPTH_GEMM3D: The depth of the input in case it has to be reinterpreted as a 3D tensor @@ -402,10 +402,10 @@ __kernel void gemm_reshape_lhs_matrix_t(TENSOR3D_DECLARATION(src), /** This OpenCL kernel reshapes the rhs input matrix. The kernel splits the input matrix in blocks of size K0xN0 and stores each one (not transposed) in * the output matrix unrolling the values. * - * @note The data type must be passed at compile time using -DDATA_TYPE (i.e. -DDATA_TYPE=float) - * @note The height of the input tensor must be passed at compile time using -DSRC_HEIGHT (i.e. -DSRC_HEIGHT=16) - * @note The block's dimensions (K0 and N0) must be passed at compile time using -DK0 and -DN0 (i.e. -DK0=2, -DN0=2). - * @note The number of K0xN0 vertical blocks to store on the same output row must be passed at compile time using -DH0 (i.e. -DH0=2) + * @note The data type must be passed at compile time using -DDATA_TYPE (e.g. -DDATA_TYPE=float) + * @note The height of the input tensor must be passed at compile time using -DSRC_HEIGHT (e.g. -DSRC_HEIGHT=16) + * @note The block's dimensions (K0 and N0) must be passed at compile time using -DK0 and -DN0 (e.g. -DK0=2, -DN0=2). + * @note The number of K0xN0 vertical blocks to store on the same output row must be passed at compile time using -DH0 (e.g. -DH0=2) * @note If the K0xN0 blocks have to be interleaved, the option -DINTERLEAVE must passed at compile time. * @note Only the following values for K0, N0 and H0 are supported: * N0: 2,3,4,8,16 @@ -555,10 +555,10 @@ __kernel void gemm_reshape_rhs_matrix_nt(TENSOR3D_DECLARATION(src), /** This OpenCL kernel reshapes the rhs input matrix. The kernel splits the input matrix in blocks of size K0xN0 and stores each one (transposed) in * the output matrix unrolling the values. * - * @note The data type must be passed at compile time using -DDATA_TYPE (i.e. -DDATA_TYPE=float) - * @note The height of the input tensor must be passed at compile time using -DSRC_HEIGHT (i.e. -DSRC_HEIGHT=16) - * @note The block's dimensions (K0 and N0) must be passed at compile time using -DK0 and -DN0 (i.e. -DK0=2, -DN0=2). - * @note The number of K0xN0 vertical blocks to store on the same output row must be passed at compile time using -DH0 (i.e. -DH0=2) + * @note The data type must be passed at compile time using -DDATA_TYPE (e.g. -DDATA_TYPE=float) + * @note The height of the input tensor must be passed at compile time using -DSRC_HEIGHT (e.g. -DSRC_HEIGHT=16) + * @note The block's dimensions (K0 and N0) must be passed at compile time using -DK0 and -DN0 (e.g. -DK0=2, -DN0=2). + * @note The number of K0xN0 vertical blocks to store on the same output row must be passed at compile time using -DH0 (e.g. -DH0=2) * @note If the K0xN0 blocks have to be interleaved, the option -DINTERLEAVE must passed at compile time. * @note The option -DTRANSPOSE must passed at compile time. * @note Only the following values for K0, N0 and H0 are supported: @@ -1010,11 +1010,11 @@ __kernel void gemm_reshape_rhs_matrix_t(TENSOR3D_DECLARATION(src), * The RHS is reshaped with @ref CLGEMMReshapeRHSMatrixKernel and the block K0xN0 is transposed * * @note If the first two dimensions of NDRange have been dispatched with "dummy_work_items" support, the option -DDUMMY_WORK_ITEMS must be passed at compile time. - * @note The GEMM's dimensions (M,N and K) must be passed at compile time using -DM, -DN and and -DK (i.e. -DM=52, -DN=30 and -DK=90) - * @note The number of columns of LHS matrix must be passed at compile time using -DK (i.e. 
-DK=64) - * @note The block's dimensions used for reshaping the RHS matrix (N0 and K0) must be passed at compile time using -DN0 and -DK0 (i.e. -DN0=8, -DK0=4). - * @note The number of M0 rows to process must be passed at compile time using -DM0 (i.e. -DM0=2) - * @note The number of K0xN0 horizontal blocks stored on the same output row of the reshaped RHS matrix must be passed at compile time using -DH0 (i.e. -DH0=2) + * @note The GEMM's dimensions (M,N and K) must be passed at compile time using -DM, -DN and and -DK (e.g. -DM=52, -DN=30 and -DK=90) + * @note The number of columns of LHS matrix must be passed at compile time using -DK (e.g. -DK=64) + * @note The block's dimensions used for reshaping the RHS matrix (N0 and K0) must be passed at compile time using -DN0 and -DK0 (e.g. -DN0=8, -DK0=4). + * @note The number of M0 rows to process must be passed at compile time using -DM0 (e.g. -DM0=2) + * @note The number of K0xN0 horizontal blocks stored on the same output row of the reshaped RHS matrix must be passed at compile time using -DH0 (e.g. -DH0=2) * @note If the K0xN0 blocks in the reshaped RHS matrix have been interleaved, the option -DRHS_INTERLEAVE must passed at compile time. * @note Only the following configurations of M0, N0 and K0 are currently supported: * - M0 = 1, 2, 3, 4, 5, 6, 7, 8 @@ -1022,7 +1022,7 @@ __kernel void gemm_reshape_rhs_matrix_t(TENSOR3D_DECLARATION(src), * - K0 = 2, 3, 4, 8, 16 * - H0 >= 1 * - * @note If the activation type were passed at compile time through -DACTIVATION_TYPE (i.e. -DACTIVATION_TYPE=RELU), A, B variables required by some activation functions and should be passed at compile time as well using -DA_VAL= and -DB_VAL= respectively. + * @note If the activation type were passed at compile time through -DACTIVATION_TYPE (e.g. -DACTIVATION_TYPE=RELU), A, B variables, required by some activation functions, should be passed at compile time as well using -DA_VAL= and -DB_VAL= respectively. * The activation function is performed after the bias addition * @note In case the input or output have to be reinterpreted as a 3D tensor, the following information must be passed at compile time: * -# REINTERPRET_INPUT_AS_3D: To reinterpret the input as 3D @@ -1043,7 +1043,6 @@ __kernel void gemm_reshape_rhs_matrix_t(TENSOR3D_DECLARATION(src), * @param[in] rhs_stride_y Stride of the RHS reshaped matrix in Y dimension (in bytes) * @param[in] rhs_step_y src_stride_y * number of elements along Y processed per workitem(in bytes) * @param[in] rhs_offset_first_element_in_bytes The offset of the first element in the RHS reshaped matrix - * @param[in] bias_ptr (Optional)Pointer to the bias reshaped matrix. Supported data type: same as @p lhs_ptr * @param[in] bias_ptr (Optional) Pointer to the bias matrix. Supported data type: same as @p lhs_ptr * @param[in] bias_stride_x (Optional) Stride of the bias matrix in X dimension (in bytes) * @param[in] bias_step_x (Optional) bias_stride_x * number of elements along X processed per workitem(in bytes) @@ -1392,10 +1391,10 @@ __kernel void gemm_mm_reshaped_only_rhs_t(IMAGE_DECLARATION(lhs), * The RHS is reshaped with @ref CLGEMMReshapeRHSMatrixKernel and the block K0xN0 is NOT transposed * * @note If the first two dimensions of NDRange have been dispatched with "dummy_work_items" support, the option -DDUMMY_WORK_ITEMS must be passed at compile time. - * @note The GEMM's dimensions (M,N and K) must be passed at compile time using -DM, -DN and and -DK (i.e. -DM=52, -DN=30 and -DK=90). 
- * @note The block's dimensions used for reshaping the RHS matrix (N0 and K0) must be passed at compile time using -DN0 and -DK0 (i.e. -DN0=8, -DK0=4). - * @note The number of M0 rows to process must be passed at compile time using -DM0 (i.e. -DM0=2) - * @note The number of K0xN0 horizontal blocks stored on the same output row of the reshaped RHS matrix must be passed at compile time using -DH0 (i.e. -DH0=2) + * @note The GEMM's dimensions (M,N and K) must be passed at compile time using -DM, -DN and and -DK (e.g. -DM=52, -DN=30 and -DK=90). + * @note The block's dimensions used for reshaping the RHS matrix (N0 and K0) must be passed at compile time using -DN0 and -DK0 (e.g. -DN0=8, -DK0=4). + * @note The number of M0 rows to process must be passed at compile time using -DM0 (e.g. -DM0=2) + * @note The number of K0xN0 horizontal blocks stored on the same output row of the reshaped RHS matrix must be passed at compile time using -DH0 (e.g. -DH0=2) * @note If the K0xN0 blocks in the reshaped RHS matrix have been interleaved, the option -DRHS_INTERLEAVE must passed at compile time. * @note Only the following configurations of M0, N0 and K0 are currently supported: * - M0 = 1, 2, 3, 4, 5, 6, 7, 8 @@ -1403,7 +1402,7 @@ __kernel void gemm_mm_reshaped_only_rhs_t(IMAGE_DECLARATION(lhs), * - K0 = 2, 3, 4, 8, 16 * - H0 >= 1 * - * @note If the activation type were passed at compile time through -DACTIVATION_TYPE (i.e. -DACTIVATION_TYPE=RELU), A, B variables required by some activation functions and should be passed at compile time as well using -DA_VAL= and -DB_VAL= respectively. + * @note If the activation type were passed at compile time through -DACTIVATION_TYPE (e.g. -DACTIVATION_TYPE=RELU), A, B variables, required by some activation functions, should be passed at compile time as well using -DA_VAL= and -DB_VAL= respectively. * The activation function is performed after the bias addition * @note In case the input or output have to be reinterpreted as a 3D tensor, the following information must be passed at compile time: * -# REINTERPRET_INPUT_AS_3D: To reinterpret the input as 3D @@ -1798,10 +1797,10 @@ __kernel void gemm_mm_reshaped_only_rhs_nt(IMAGE_DECLARATION(lhs), * The RHS matrix must be reshaped with @ref CLGEMMReshapeRHSMatrixKernel and the K0xN0 must be transposed * * @note If the first two dimensions of NDRange have been dispatched with "dummy_work_items" support, the option -DDUMMY_WORK_ITEMS must be passed at compile time. - * @note The GEMM's dimensions M and N must be passed at compile time using -DM and -DN (i.e. -DM=52 and -DN=90). - * @note The block's dimensions used for reshaping the LHS matrix and the RHS matrix (M0, N0 and K0) must be passed at compile time using -DM0, -DN0 and -DK0 (i.e. -DM0=4, -DN0=8, -DK0=4). - * @note The number of M0xK0 vertical blocks stored on the same output row of the reshaped LHS matrix must be passed at compile time using -DV0 (i.e. -DV0=2) - * @note The number of K0xN0 horizontal blocks stored on the same output row of the reshaped RHS matrix must be passed at compile time using -DH0 (i.e. -DH0=2) + * @note The GEMM's dimensions M and N must be passed at compile time using -DM and -DN (e.g. -DM=52 and -DN=90). + * @note The block's dimensions used for reshaping the LHS matrix and the RHS matrix (M0, N0 and K0) must be passed at compile time using -DM0, -DN0 and -DK0 (e.g. -DM0=4, -DN0=8, -DK0=4). + * @note The number of M0xK0 vertical blocks stored on the same output row of the reshaped LHS matrix must be passed at compile time using -DV0 (e.g. 
-DV0=2) + * @note The number of K0xN0 horizontal blocks stored on the same output row of the reshaped RHS matrix must be passed at compile time using -DH0 (e.g. -DH0=2) * @note If the M0xK0 blocks in the reshaped LHS matrix have been interleaved, the option -DLHS_INTERLEAVE must passed at compile time. * @note If the K0xN0 blocks in the reshaped RHS matrix have been interleaved, the option -DRHS_INTERLEAVE must passed at compile time. * @note Only the following configurations of M0, N0 and K0 are currently supported: @@ -1811,9 +1810,9 @@ __kernel void gemm_mm_reshaped_only_rhs_nt(IMAGE_DECLARATION(lhs), * - V0 >= 1 * - H0 >= 1 * - * @note If the activation type were passed at compile time through -DACTIVATION_TYPE (i.e. -DACTIVATION_TYPE=RELU), A, B variables required by some activation functions and should be passed at compile time as well using -DA_VAL= and -DB_VAL= respectively. + * @note If the activation type were passed at compile time through -DACTIVATION_TYPE (e.g. -DACTIVATION_TYPE=RELU), A, B variables, required by some activation functions, should be passed at compile time as well using -DA_VAL= and -DB_VAL= respectively. * The activation function is performed after the bias addition - * @note In case the output has to be reinterpreted as a 3D tensor (i.e. output of convolution layer), the following information must be passed at compile time: + * @note In case the output has to be reinterpreted as a 3D tensor (e.g. output of convolution layer), the following information must be passed at compile time: * -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D * -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor. * -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor @@ -2123,17 +2122,17 @@ __kernel void gemm_mm_reshaped_lhs_nt_rhs_t(IMAGE_DECLARATION(lhs), * The RHS matrix is NOT reshaped * * @note If the first two dimensions of NDRange have been dispatched with "dummy_work_items" support, the option -DDUMMY_WORK_ITEMS must be passed at compile time. - * @note The GEMM's dimensions (M,N and K) must be passed at compile time using -DM, -DN and and -DK (i.e. -DM=52, -DN=30 and -DK=90) - * @note The number of columns of LHS matrix must be passed at compile time using -DK (i.e. -DK=64) - * @note The number of M0 rows to process must be passed at compile time using -DM0 (i.e. -DM0=2) - * @note The number of K0 partial accumulations must be passed at compile time using -DK0 (i.e., -DK0=2) - * @note The number of N0 columns to process must be passed at compile time using -DN0 (i.e. -DN0=2) + * @note The GEMM's dimensions (M,N and K) must be passed at compile time using -DM, -DN and and -DK (e.g. -DM=52, -DN=30 and -DK=90) + * @note The number of columns of LHS matrix must be passed at compile time using -DK (e.g. -DK=64) + * @note The number of M0 rows to process must be passed at compile time using -DM0 (e.g. -DM0=2) + * @note The number of K0 partial accumulations must be passed at compile time using -DK0 (e.g., -DK0=2) + * @note The number of N0 columns to process must be passed at compile time using -DN0 (e.g. -DN0=2) * @note Only the following configurations of M0, N0 and K0 are currently supported: * - M0 = 1, 2, 3, 4, 5, 6, 7, 8 * - N0 = 2, 3, 4, 8, 16 * - K0 = 2, 3, 4, 8, 16 * - * @note If the activation type were passed at compile time through -DACTIVATION_TYPE (i.e. 
-DACTIVATION_TYPE=RELU), A, B variables required by some activation functions and should be passed at compile time as well using -DA_VAL= and -DB_VAL= respectively. + * @note If the activation type were passed at compile time through -DACTIVATION_TYPE (e.g. -DACTIVATION_TYPE=RELU), A, B variables, required by some activation functions, should be passed at compile time as well using -DA_VAL= and -DB_VAL= respectively. * The activation function is performed after the bias addition * @note In case the input or output have to be reinterpreted as a 3D tensor, the following information must be passed at compile time: * -# REINTERPRET_INPUT_AS_3D: To reinterpret the input as 3D @@ -2154,7 +2153,6 @@ __kernel void gemm_mm_reshaped_lhs_nt_rhs_t(IMAGE_DECLARATION(lhs), * @param[in] rhs_stride_y Stride of the RHS matrix in Y dimension (in bytes) * @param[in] rhs_step_y rhs_stride_y * number of elements along Y processed per workitem(in bytes) * @param[in] rhs_offset_first_element_in_bytes The offset of the first element in the RHS matrix - * @param[in] bias_ptr (Optional)Pointer to the bias reshaped matrix. Supported data type: same as @p lhs_ptr * @param[in] bias_ptr (Optional) Pointer to the bias matrix. Supported data type: same as @p lhs_ptr * @param[in] bias_stride_x (Optional) Stride of the bias matrix in X dimension (in bytes) * @param[in] bias_step_x (Optional) bias_stride_x * number of elements along X processed per workitem(in bytes) @@ -2405,25 +2403,22 @@ __kernel void gemm_mm_native(IMAGE_DECLARATION(lhs), #endif // defined(M0) && defined(N0) && defined(K0) && defined(K) && defined(DATA_TYPE) #if defined(COLS_B) && defined(MULT_TRANSPOSE1XW_WIDTH) && defined(MULT_INTERLEAVE4X4_HEIGHT) -/** This OpenCL kernel is optimised for Midgard. It computes the matrix multiplication between matrix A (src0) and matrix B (src1) - * Matrix A and matrix B must be reshaped respectively with @ref gemm_interleave4x4_32bit and @ref gemm_transpose1x4 before running the matrix multiplication - * - * Moreover, it can add a vector (src2) if the ADD_VEC_C parameter is passed at compile time. +/** This OpenCL kernel is optimised for Midgard. It computes the matrix multiplication between matrix A reshaped (src0) and matrix B reshaped (src1) * * @note The number of columns of matrix B and the optional alpha's value need to be passed at compile time using -DCOLS_B and -DALPHA - * @note The multiplication factor for the transposition width (mult_transpose1xW_width) must be passed at compile time using -DMULT_TRANSPOSE1XW_WIDTH (i.e. -DMULT_TRANSPOSE1XW_WIDTH=2) - * @note The multiplication factor for the height of the 4x4 interleaved block must be passed at compile time using -DMULT_INTERLEAVE4X4_HEIGHT (i.e. -DMULT_INTERLEAVE4X4_HEIGHT=2) - * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (i.e. -DMATRIX_B_DEPTH=16) - * This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (i.e. a = [K, M, 16, Batches], b = [N, K, 16]) + * @note The multiplication factor for the transposition width (mult_transpose1xW_width) must be passed at compile time using -DMULT_TRANSPOSE1XW_WIDTH (e.g. -DMULT_TRANSPOSE1XW_WIDTH=2) + * @note The multiplication factor for the height of the 4x4 interleaved block must be passed at compile time using -DMULT_INTERLEAVE4X4_HEIGHT (e.g. 
-DMULT_INTERLEAVE4X4_HEIGHT=2) + * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (e.g. -DMATRIX_B_DEPTH=16) + * This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (e.g. a = [K, M, 16, Batches], b = [N, K, 16]) * - * @note In case the output has to be reinterpreted as a 3D tensor (i.e. output of convolution layer), the following information must be passed at compile time: + * @note If the activation type were passed at compile time through -DACTIVATION_TYPE (e.g. -DACTIVATION_TYPE=RELU), A, B variables, required by some activation functions, should be passed at compile time as well using -DA_VAL= and -DB_VAL= respectively. + * The activation function is performed after the bias addition + * @note In case the output has to be reinterpreted as a 3D tensor (e.g. output of convolution layer), the following information must be passed at compile time: * -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D * -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor. * -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor * (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns matrix A NOT reshaped * - * @note In case a 3rd input (src2) needs to be added, the ADD_VEC_C parameter has to be passed at compile time as -DADD_VEC_C - * * @param[in] src0_ptr Pointer to the source matrix. Supported data types: F32 * @param[in] src0_stride_x Stride of the source matrix in X dimension (in bytes) * @param[in] src0_step_x src_stride_x * number of elements along X processed per workitem(in bytes) @@ -2436,10 +2431,12 @@ __kernel void gemm_mm_native(IMAGE_DECLARATION(lhs), * @param[in] src1_stride_y Stride of the source matrix in Y dimension (in bytes) * @param[in] src1_step_y src_stride_y * number of elements along Y processed per workitem(in bytes) * @param[in] src1_offset_first_element_in_bytes The offset of the first element in the source matrix - * @param[in] src2_ptr (Optional) Pointer to the source matrix. Supported data types: same as @p src0_ptr - * @param[in] src2_stride_x (Optional) Stride of the source vector in X dimension (in bytes) - * @param[in] src2_step_x (Optional) src_stride_x * number of elements along X processed per workitem(in bytes) - * @param[in] src2_offset_first_element_in_bytes (Optional) The offset of the first element in the source matrix + * @param[in] src2_ptr (Optional) Pointer to the bias matrix. 
Supported data type: same as @p lhs_ptr + * @param[in] src2_stride_x (Optional) Stride of the bias matrix in X dimension (in bytes) + * @param[in] src2_step_x (Optional) src2_stride_x * number of elements along X processed per workitem(in bytes) + * @param[in] src2_stride_y (Optional) Stride of the bias matrix in Y dimension (in bytes) + * @param[in] src2_step_y (Optional) src2_stride_y * number of elements along Y processed per workitem(in bytes) + * @param[in] src2_offset_first_element_in_bytes (Optional) The offset of the first element in the bias matrix * @param[out] dst_ptr Pointer to the destination matrix Supported data types: same as @p src0_ptr * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes) * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes) @@ -2448,17 +2445,21 @@ __kernel void gemm_mm_native(IMAGE_DECLARATION(lhs), * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix * @param[in] src0_stride_z Stride of the source matrix in Z dimension (in bytes) * @param[in] src1_stride_z Stride of the source matrix in Z dimension (in bytes) + * @param[in] src2_stride_z (Optional) Stride of the bias matrix in Z dimension (in bytes) * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes) * @param[in] cross_plane_pad (Optional) Bottom paddings in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D) */ __kernel void gemm_mm_interleaved_transposed_f32(IMAGE_DECLARATION(src0), IMAGE_DECLARATION(src1), -#if defined(ADD_VEC_C) - VECTOR_DECLARATION(src2), -#endif /* defined(ADD_VEC_C) */ +#if defined(BETA) + IMAGE_DECLARATION(src2), +#endif // defined(BETA) IMAGE_DECLARATION(dst), uint src0_stride_z, uint src1_stride_z, +#if defined(BETA) + uint src2_stride_z, +#endif //defined(BETA) uint dst_stride_z #if defined(REINTERPRET_OUTPUT_AS_3D) , @@ -2496,10 +2497,10 @@ __kernel void gemm_mm_interleaved_transposed_f32(IMAGE_DECLARATION(src0), src_addr_b += offset_row_b; // Reset accumulators - float4 c00 = 0.0f; - float4 c10 = 0.0f; - float4 c20 = 0.0f; - float4 c30 = 0.0f; + float4 c0 = 0.0f; + float4 c1 = 0.0f; + float4 c2 = 0.0f; + float4 c3 = 0.0f; for(; src_addr_b <= (src_end_addr_b - (int)(8 * MULT_TRANSPOSE1XW_WIDTH)); src_addr_a += 8 * MULT_INTERLEAVE4X4_HEIGHT, src_addr_b += 8 * MULT_TRANSPOSE1XW_WIDTH) { @@ -2507,19 +2508,19 @@ __kernel void gemm_mm_interleaved_transposed_f32(IMAGE_DECLARATION(src0), float4 a0 = vload4(0, src_addr_a); float4 b0 = vload4(0, src_addr_b); - c00 += (float4)a0.s0 * b0; - c10 += (float4)a0.s1 * b0; - c20 += (float4)a0.s2 * b0; - c30 += (float4)a0.s3 * b0; + c0 += (float4)a0.s0 * b0; + c1 += (float4)a0.s1 * b0; + c2 += (float4)a0.s2 * b0; + c3 += (float4)a0.s3 * b0; // Load values from matrix A (interleaved) and matrix B (transposed) a0 = vload4(0, src_addr_a + 4 * MULT_INTERLEAVE4X4_HEIGHT); b0 = vload4(0, src_addr_b + 4 * MULT_TRANSPOSE1XW_WIDTH); - c00 += (float4)a0.s0 * b0; - c10 += (float4)a0.s1 * b0; - c20 += (float4)a0.s2 * b0; - c30 += (float4)a0.s3 * b0; + c0 += (float4)a0.s0 * b0; + c1 += (float4)a0.s1 * b0; + c2 += (float4)a0.s2 * b0; + c3 += (float4)a0.s3 * b0; } for(; src_addr_b < src_end_addr_b; src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT, src_addr_b += 4 * MULT_TRANSPOSE1XW_WIDTH) @@ -2528,36 +2529,20 @@ __kernel void gemm_mm_interleaved_transposed_f32(IMAGE_DECLARATION(src0), float4 a0 = vload4(0, src_addr_a); float4 b0 = vload4(0, src_addr_b); - c00 += (float4)a0.s0 * b0; - c10 += 
(float4)a0.s1 * b0; - c20 += (float4)a0.s2 * b0; - c30 += (float4)a0.s3 * b0; + c0 += (float4)a0.s0 * b0; + c1 += (float4)a0.s1 * b0; + c2 += (float4)a0.s2 * b0; + c3 += (float4)a0.s3 * b0; } // Compute destination address Image dst = CONVERT_TO_IMAGE_STRUCT(dst); -#if defined(ALPHA) - // Multiply by the weight of matrix product - c00 = c00 * (float4)ALPHA; - c10 = c10 * (float4)ALPHA; - c20 = c20 * (float4)ALPHA; - c30 = c30 * (float4)ALPHA; -#endif // defined(ALPHA) - -#if defined(ADD_VEC_C) - __global float *src2_addr = (__global float *)(src2_ptr + src2_offset_first_element_in_bytes + get_global_id(0) * src2_step_x); - float4 c0 = vload4(0, src2_addr); - - c00 += c0; - c10 += c0; - c20 += c0; - c30 += c0; -#endif /* defined(ADD_VEC_C) */ - // Compute dst address __global uchar *dst_addr = offset(&dst, 0, 0); + uint4 zout = 0; + #if defined(REINTERPRET_OUTPUT_AS_3D) // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension // in order to take into account the presence of possible cross plane paddings @@ -2575,8 +2560,8 @@ __kernel void gemm_mm_interleaved_transposed_f32(IMAGE_DECLARATION(src0), // |__________________| // The plane (zout) is calculated dividing M (get_global_id(1) * 4) by HEIGHT_GEMM3D - uint4 zout = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * 4)) / (uint4)HEIGHT_GEMM3D; - zout = min(DEPTH_GEMM3D - 1, zout); + zout = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * 4)) / (uint4)HEIGHT_GEMM3D; + zout = min(DEPTH_GEMM3D - 1, zout); // Add offset due to the cross plane paddings zout *= (cross_plane_pad * dst_stride_y); @@ -2584,45 +2569,76 @@ __kernel void gemm_mm_interleaved_transposed_f32(IMAGE_DECLARATION(src0), // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we // multiply dst_stride_z by DEPTH_GEMM3D dst_addr += z * dst_stride_z * DEPTH_GEMM3D; - - // Store 4x4 block - vstore4(c00, 0, (__global float *)(dst_addr + 0 * dst_stride_y + zout.s0)); - vstore4(c10, 0, (__global float *)(dst_addr + 1 * dst_stride_y + zout.s1)); - vstore4(c20, 0, (__global float *)(dst_addr + 2 * dst_stride_y + zout.s2)); - vstore4(c30, 0, (__global float *)(dst_addr + 3 * dst_stride_y + zout.s3)); - #else // defined(REINTERPRET_OUTPUT_AS_3D) // Add offset for batched GEMM dst_addr += z * dst_stride_z; +#endif // defined(REINTERPRET_OUTPUT_AS_3D) + + // Multiply by the weight of matrix-matrix product and store the result +#if defined(ALPHA) + SCALE_BLOCK(4, float, c, ALPHA); +#endif // defined(ALPHA) + + // Add beta*bias +#if defined(BETA) + REPEAT_VAR_INIT_TO_CONST(4, uint, zero, 0); + +#if defined(BROADCAST_BIAS) + __global uchar *src2_addr = src2_ptr + src2_offset_first_element_in_bytes + (get_global_id(0) * (uint)4 * sizeof(float)); + + LOAD_BLOCK(1, 4, float, bias, src2_addr, 0, src2_stride_y, zero); + +#ifndef UNIT_BETA + SCALE_BLOCK(1, float, bias, BETA); +#endif // UNIT_BIAS + + // c = c + bias[broadcasted] + ADD_BLOCK_BROADCAST(4, c, bias0); + +#else // defined(BROADCAST_BIAS) + __global uchar *src2_addr = src2_ptr + src2_offset_first_element_in_bytes + (get_global_id(0) * (uint)4 * sizeof(float)) + (get_global_id(1) * (uint)4 * src2_stride_y) + get_global_id( + 2) * src2_stride_z; + + LOAD_BLOCK(4, 4, float, bias, src2_addr, 0, src2_stride_y, zero); + +#ifndef UNIT_BETA + SCALE_BLOCK(4, float, bias, BETA); +#endif // UNIT_BIAS + + // c = c + bias + ADD_BLOCK(4, c, bias); + +#endif // defined(BROADCAST_BIAS) +#endif // defined(BETA) + +#if defined(ACTIVATION_TYPE) + 
ACTIVATION_BLOCK(4, ACTIVATION_TYPE, float, c, A_VAL, B_VAL); +#endif // defined(ACTIVATION_TYPE) // Store 4x4 block - vstore4(c00, 0, (__global float *)(dst_addr + 0 * dst_stride_y)); - vstore4(c10, 0, (__global float *)(dst_addr + 1 * dst_stride_y)); - vstore4(c20, 0, (__global float *)(dst_addr + 2 * dst_stride_y)); - vstore4(c30, 0, (__global float *)(dst_addr + 3 * dst_stride_y)); -#endif // defined(REINTERPRET_OUTPUT_AS_3D) + vstore4(c0, 0, (__global float *)(dst_addr + 0 * dst_stride_y + zout.s0)); + vstore4(c1, 0, (__global float *)(dst_addr + 1 * dst_stride_y + zout.s1)); + vstore4(c2, 0, (__global float *)(dst_addr + 2 * dst_stride_y + zout.s2)); + vstore4(c3, 0, (__global float *)(dst_addr + 3 * dst_stride_y + zout.s3)); } -/** This OpenCL kernel is optimized for Bifrost. It computes the matrix multiplication between matrix A (src0) and matrix B (src1) - * Matrix A and matrix B must be reshaped respectively with @ref gemm_interleave4x4_32bit and @ref gemm_transpose1x4 before running the matrix multiplication. - * - * Moreover, it can add a vector (src2) if the ADD_VEC_C parameter is passed at compile time. +/** This OpenCL kernel is optimized for Bifrost and it computes the matrix multiplication between matrix A reshaped (src0) and matrix B reshaped (src1) * * @note The number of columns of matrix B and the optional alpha's value need to be passed at compile time using -DCOLS_B and -DALPHA - * @note The multiplication factor for the transposition width (mult_transpose1xW_width) must be passed at compile time using -DMULT_TRANSPOSE1XW_WIDTH (i.e. -DMULT_TRANSPOSE1XW_WIDTH=2) - * @note The multiplication factor for the height of the 4x4 interleaved block must be passed at compile time using -DMULT_INTERLEAVE4X4_HEIGHT (i.e. -DMULT_INTERLEAVE4X4_HEIGHT=2) - * @note The multiplication factor for the height of the 4x4 interleaved block must be passed at compile time using -DMULT_INTERLEAVE4X4_HEIGHT (i.e. -DMULT_INTERLEAVE4X4_HEIGHT=2) - * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (i.e. -DMATRIX_B_DEPTH=16) - * This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (i.e. a = [K, M, 16, Batches], b = [N, K, 16]) + * @note The multiplication factor for the transposition width (mult_transpose1xW_width) must be passed at compile time using -DMULT_TRANSPOSE1XW_WIDTH (e.g. -DMULT_TRANSPOSE1XW_WIDTH=2) + * @note The multiplication factor for the height of the 4x4 interleaved block must be passed at compile time using -DMULT_INTERLEAVE4X4_HEIGHT (e.g. -DMULT_INTERLEAVE4X4_HEIGHT=2) + * @note The multiplication factor for the height of the 4x4 interleaved block must be passed at compile time using -DMULT_INTERLEAVE4X4_HEIGHT (e.g. -DMULT_INTERLEAVE4X4_HEIGHT=2) + * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (e.g. -DMATRIX_B_DEPTH=16) + * This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (e.g. a = [K, M, 16, Batches], b = [N, K, 16]) * - * @note In case the output has to be reinterpreted as a 3D tensor (i.e. 
output of convolution layer), the following information must be passed at compile time: + * @note If the activation type were passed at compile time through -DACTIVATION_TYPE (e.g. -DACTIVATION_TYPE=RELU), A, B variables, required by some activation functions, should be passed at compile time as well using -DA_VAL= and -DB_VAL= respectively. + * The activation function is performed after the bias addition + * @note In case the output has to be reinterpreted as a 3D tensor (e.g. output of convolution layer), the following information must be passed at compile time: * -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D * -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor. * -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor * (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns matrix A NOT reshaped * - * @note In case a 3rd input (src2) needs to be added, the ADD_VEC_C parameter has to be passed at compile time as -DADD_VEC_C - * * @param[in] src0_ptr Pointer to the source matrix. Supported data types: F32 * @param[in] src0_stride_x Stride of the source matrix in X dimension (in bytes) * @param[in] src0_step_x src_stride_x * number of elements along X processed per workitem(in bytes) @@ -2635,10 +2651,12 @@ __kernel void gemm_mm_interleaved_transposed_f32(IMAGE_DECLARATION(src0), * @param[in] src1_stride_y Stride of the source matrix in Y dimension (in bytes) * @param[in] src1_step_y src_stride_y * number of elements along Y processed per workitem(in bytes) * @param[in] src1_offset_first_element_in_bytes The offset of the first element in the source matrix - * @param[in] src2_ptr (Optional) Pointer to the source matrix. Supported data types: same as @p src0_ptr - * @param[in] src2_stride_x (Optional) Stride of the source vector in X dimension (in bytes) - * @param[in] src2_step_x (Optional) src_stride_x * number of elements along X processed per workitem(in bytes) - * @param[in] src2_offset_first_element_in_bytes (Optional) The offset of the first element in the source matrix + * @param[in] src2_ptr (Optional) Pointer to the bias matrix. 
Supported data type: same as @p lhs_ptr + * @param[in] src2_stride_x (Optional) Stride of the bias matrix in X dimension (in bytes) + * @param[in] src2_step_x (Optional) src2_stride_x * number of elements along X processed per workitem(in bytes) + * @param[in] src2_stride_y (Optional) Stride of the bias matrix in Y dimension (in bytes) + * @param[in] src2_step_y (Optional) src2_stride_y * number of elements along Y processed per workitem(in bytes) + * @param[in] src2_offset_first_element_in_bytes (Optional) The offset of the first element in the bias matrix * @param[out] dst_ptr Pointer to the destination matrix Supported data types: same as @p src0_ptr * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes) * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes) @@ -2647,17 +2665,21 @@ __kernel void gemm_mm_interleaved_transposed_f32(IMAGE_DECLARATION(src0), * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix * @param[in] src0_stride_z Stride of the source matrix in Z dimension (in bytes) * @param[in] src1_stride_z Stride of the source matrix in Z dimension (in bytes) + * @param[in] src2_stride_z (Optional) Stride of the bias matrix in Z dimension (in bytes) * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes) * @param[in] cross_plane_pad (Optional) Bottom paddings in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D) */ __kernel void gemm_mm_interleaved_transposed_f32_bifrost(IMAGE_DECLARATION(src0), IMAGE_DECLARATION(src1), -#if defined(ADD_VEC_C) - VECTOR_DECLARATION(src2), -#endif /* defined(ADD_VEC_C) */ +#if defined(BETA) + IMAGE_DECLARATION(src2), +#endif // defined(BETA) IMAGE_DECLARATION(dst), uint src0_stride_z, uint src1_stride_z, +#if defined(BETA) + uint src2_stride_z, +#endif //defined(BETA) uint dst_stride_z #if defined(REINTERPRET_OUTPUT_AS_3D) , @@ -2692,22 +2714,10 @@ __kernel void gemm_mm_interleaved_transposed_f32_bifrost(IMAGE_DECLARATION(src0) src_addr_b += offset_row_b; // Reset accumulators - float c00 = 0.0f; - float c01 = 0.0f; - float c02 = 0.0f; - float c03 = 0.0f; - float c10 = 0.0f; - float c11 = 0.0f; - float c12 = 0.0f; - float c13 = 0.0f; - float c20 = 0.0f; - float c21 = 0.0f; - float c22 = 0.0f; - float c23 = 0.0f; - float c30 = 0.0f; - float c31 = 0.0f; - float c32 = 0.0f; - float c33 = 0.0f; + float4 c0 = 0.0f; + float4 c1 = 0.0f; + float4 c2 = 0.0f; + float4 c3 = 0.0f; #define COLS_MTX_B (COLS_B / (4 * MULT_TRANSPOSE1XW_WIDTH)) @@ -2721,25 +2731,25 @@ __kernel void gemm_mm_interleaved_transposed_f32_bifrost(IMAGE_DECLARATION(src0) src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT; src_addr_b += 4 * MULT_TRANSPOSE1XW_WIDTH; - c00 = fma(a0.s0, b0.s0, c00); - c01 = fma(a0.s0, b0.s1, c01); - c02 = fma(a0.s0, b0.s2, c02); - c03 = fma(a0.s0, b0.s3, c03); + c0.s0 = fma(a0.s0, b0.s0, c0.s0); + c0.s1 = fma(a0.s0, b0.s1, c0.s1); + c0.s2 = fma(a0.s0, b0.s2, c0.s2); + c0.s3 = fma(a0.s0, b0.s3, c0.s3); - c10 = fma(a0.s1, b0.s0, c10); - c11 = fma(a0.s1, b0.s1, c11); - c12 = fma(a0.s1, b0.s2, c12); - c13 = fma(a0.s1, b0.s3, c13); + c1.s0 = fma(a0.s1, b0.s0, c1.s0); + c1.s1 = fma(a0.s1, b0.s1, c1.s1); + c1.s2 = fma(a0.s1, b0.s2, c1.s2); + c1.s3 = fma(a0.s1, b0.s3, c1.s3); - c20 = fma(a0.s2, b0.s0, c20); - c21 = fma(a0.s2, b0.s1, c21); - c22 = fma(a0.s2, b0.s2, c22); - c23 = fma(a0.s2, b0.s3, c23); + c2.s0 = fma(a0.s2, b0.s0, c2.s0); + c2.s1 = fma(a0.s2, b0.s1, c2.s1); + c2.s2 = fma(a0.s2, b0.s2, c2.s2); + 
c2.s3 = fma(a0.s2, b0.s3, c2.s3); - c30 = fma(a0.s3, b0.s0, c30); - c31 = fma(a0.s3, b0.s1, c31); - c32 = fma(a0.s3, b0.s2, c32); - c33 = fma(a0.s3, b0.s3, c33); + c3.s0 = fma(a0.s3, b0.s0, c3.s0); + c3.s1 = fma(a0.s3, b0.s1, c3.s1); + c3.s2 = fma(a0.s3, b0.s2, c3.s2); + c3.s3 = fma(a0.s3, b0.s3, c3.s3); // Load values from matrix A (interleaved) and matrix B (transposed) a0 = vload4(0, src_addr_a); @@ -2748,25 +2758,25 @@ __kernel void gemm_mm_interleaved_transposed_f32_bifrost(IMAGE_DECLARATION(src0) src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT; src_addr_b += 4 * MULT_TRANSPOSE1XW_WIDTH; - c00 = fma(a0.s0, b0.s0, c00); - c01 = fma(a0.s0, b0.s1, c01); - c02 = fma(a0.s0, b0.s2, c02); - c03 = fma(a0.s0, b0.s3, c03); + c0.s0 = fma(a0.s0, b0.s0, c0.s0); + c0.s1 = fma(a0.s0, b0.s1, c0.s1); + c0.s2 = fma(a0.s0, b0.s2, c0.s2); + c0.s3 = fma(a0.s0, b0.s3, c0.s3); - c10 = fma(a0.s1, b0.s0, c10); - c11 = fma(a0.s1, b0.s1, c11); - c12 = fma(a0.s1, b0.s2, c12); - c13 = fma(a0.s1, b0.s3, c13); + c1.s0 = fma(a0.s1, b0.s0, c1.s0); + c1.s1 = fma(a0.s1, b0.s1, c1.s1); + c1.s2 = fma(a0.s1, b0.s2, c1.s2); + c1.s3 = fma(a0.s1, b0.s3, c1.s3); - c20 = fma(a0.s2, b0.s0, c20); - c21 = fma(a0.s2, b0.s1, c21); - c22 = fma(a0.s2, b0.s2, c22); - c23 = fma(a0.s2, b0.s3, c23); + c2.s0 = fma(a0.s2, b0.s0, c2.s0); + c2.s1 = fma(a0.s2, b0.s1, c2.s1); + c2.s2 = fma(a0.s2, b0.s2, c2.s2); + c2.s3 = fma(a0.s2, b0.s3, c2.s3); - c30 = fma(a0.s3, b0.s0, c30); - c31 = fma(a0.s3, b0.s1, c31); - c32 = fma(a0.s3, b0.s2, c32); - c33 = fma(a0.s3, b0.s3, c33); + c3.s0 = fma(a0.s3, b0.s0, c3.s0); + c3.s1 = fma(a0.s3, b0.s1, c3.s1); + c3.s2 = fma(a0.s3, b0.s2, c3.s2); + c3.s3 = fma(a0.s3, b0.s3, c3.s3); // Load values from matrix A (interleaved) and matrix B (transposed) a0 = vload4(0, src_addr_a); @@ -2775,25 +2785,25 @@ __kernel void gemm_mm_interleaved_transposed_f32_bifrost(IMAGE_DECLARATION(src0) src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT; src_addr_b += 4 * MULT_TRANSPOSE1XW_WIDTH; - c00 = fma(a0.s0, b0.s0, c00); - c01 = fma(a0.s0, b0.s1, c01); - c02 = fma(a0.s0, b0.s2, c02); - c03 = fma(a0.s0, b0.s3, c03); + c0.s0 = fma(a0.s0, b0.s0, c0.s0); + c0.s1 = fma(a0.s0, b0.s1, c0.s1); + c0.s2 = fma(a0.s0, b0.s2, c0.s2); + c0.s3 = fma(a0.s0, b0.s3, c0.s3); - c10 = fma(a0.s1, b0.s0, c10); - c11 = fma(a0.s1, b0.s1, c11); - c12 = fma(a0.s1, b0.s2, c12); - c13 = fma(a0.s1, b0.s3, c13); + c1.s0 = fma(a0.s1, b0.s0, c1.s0); + c1.s1 = fma(a0.s1, b0.s1, c1.s1); + c1.s2 = fma(a0.s1, b0.s2, c1.s2); + c1.s3 = fma(a0.s1, b0.s3, c1.s3); - c20 = fma(a0.s2, b0.s0, c20); - c21 = fma(a0.s2, b0.s1, c21); - c22 = fma(a0.s2, b0.s2, c22); - c23 = fma(a0.s2, b0.s3, c23); + c2.s0 = fma(a0.s2, b0.s0, c2.s0); + c2.s1 = fma(a0.s2, b0.s1, c2.s1); + c2.s2 = fma(a0.s2, b0.s2, c2.s2); + c2.s3 = fma(a0.s2, b0.s3, c2.s3); - c30 = fma(a0.s3, b0.s0, c30); - c31 = fma(a0.s3, b0.s1, c31); - c32 = fma(a0.s3, b0.s2, c32); - c33 = fma(a0.s3, b0.s3, c33); + c3.s0 = fma(a0.s3, b0.s0, c3.s0); + c3.s1 = fma(a0.s3, b0.s1, c3.s1); + c3.s2 = fma(a0.s3, b0.s2, c3.s2); + c3.s3 = fma(a0.s3, b0.s3, c3.s3); // Load values from matrix A (interleaved) and matrix B (transposed) a0 = vload4(0, src_addr_a); @@ -2802,25 +2812,25 @@ __kernel void gemm_mm_interleaved_transposed_f32_bifrost(IMAGE_DECLARATION(src0) src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT; src_addr_b += 4 * MULT_TRANSPOSE1XW_WIDTH; - c00 = fma(a0.s0, b0.s0, c00); - c01 = fma(a0.s0, b0.s1, c01); - c02 = fma(a0.s0, b0.s2, c02); - c03 = fma(a0.s0, b0.s3, c03); - - c10 = fma(a0.s1, b0.s0, c10); - c11 = fma(a0.s1, b0.s1, c11); - c12 = 
fma(a0.s1, b0.s2, c12); - c13 = fma(a0.s1, b0.s3, c13); - - c20 = fma(a0.s2, b0.s0, c20); - c21 = fma(a0.s2, b0.s1, c21); - c22 = fma(a0.s2, b0.s2, c22); - c23 = fma(a0.s2, b0.s3, c23); - - c30 = fma(a0.s3, b0.s0, c30); - c31 = fma(a0.s3, b0.s1, c31); - c32 = fma(a0.s3, b0.s2, c32); - c33 = fma(a0.s3, b0.s3, c33); + c0.s0 = fma(a0.s0, b0.s0, c0.s0); + c0.s1 = fma(a0.s0, b0.s1, c0.s1); + c0.s2 = fma(a0.s0, b0.s2, c0.s2); + c0.s3 = fma(a0.s0, b0.s3, c0.s3); + + c1.s0 = fma(a0.s1, b0.s0, c1.s0); + c1.s1 = fma(a0.s1, b0.s1, c1.s1); + c1.s2 = fma(a0.s1, b0.s2, c1.s2); + c1.s3 = fma(a0.s1, b0.s3, c1.s3); + + c2.s0 = fma(a0.s2, b0.s0, c2.s0); + c2.s1 = fma(a0.s2, b0.s1, c2.s1); + c2.s2 = fma(a0.s2, b0.s2, c2.s2); + c2.s3 = fma(a0.s2, b0.s3, c2.s3); + + c3.s0 = fma(a0.s3, b0.s0, c3.s0); + c3.s1 = fma(a0.s3, b0.s1, c3.s1); + c3.s2 = fma(a0.s3, b0.s2, c3.s2); + c3.s3 = fma(a0.s3, b0.s3, c3.s3); } for(; i < (int)(COLS_MTX_B); ++i) @@ -2832,74 +2842,34 @@ __kernel void gemm_mm_interleaved_transposed_f32_bifrost(IMAGE_DECLARATION(src0) src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT; src_addr_b += 4 * MULT_TRANSPOSE1XW_WIDTH; - c00 = fma(a0.s0, b0.s0, c00); - c01 = fma(a0.s0, b0.s1, c01); - c02 = fma(a0.s0, b0.s2, c02); - c03 = fma(a0.s0, b0.s3, c03); - - c10 = fma(a0.s1, b0.s0, c10); - c11 = fma(a0.s1, b0.s1, c11); - c12 = fma(a0.s1, b0.s2, c12); - c13 = fma(a0.s1, b0.s3, c13); - - c20 = fma(a0.s2, b0.s0, c20); - c21 = fma(a0.s2, b0.s1, c21); - c22 = fma(a0.s2, b0.s2, c22); - c23 = fma(a0.s2, b0.s3, c23); - - c30 = fma(a0.s3, b0.s0, c30); - c31 = fma(a0.s3, b0.s1, c31); - c32 = fma(a0.s3, b0.s2, c32); - c33 = fma(a0.s3, b0.s3, c33); + c0.s0 = fma(a0.s0, b0.s0, c0.s0); + c0.s1 = fma(a0.s0, b0.s1, c0.s1); + c0.s2 = fma(a0.s0, b0.s2, c0.s2); + c0.s3 = fma(a0.s0, b0.s3, c0.s3); + + c1.s0 = fma(a0.s1, b0.s0, c1.s0); + c1.s1 = fma(a0.s1, b0.s1, c1.s1); + c1.s2 = fma(a0.s1, b0.s2, c1.s2); + c1.s3 = fma(a0.s1, b0.s3, c1.s3); + + c2.s0 = fma(a0.s2, b0.s0, c2.s0); + c2.s1 = fma(a0.s2, b0.s1, c2.s1); + c2.s2 = fma(a0.s2, b0.s2, c2.s2); + c2.s3 = fma(a0.s2, b0.s3, c2.s3); + + c3.s0 = fma(a0.s3, b0.s0, c3.s0); + c3.s1 = fma(a0.s3, b0.s1, c3.s1); + c3.s2 = fma(a0.s3, b0.s2, c3.s2); + c3.s3 = fma(a0.s3, b0.s3, c3.s3); } // Compute destination address Image dst = CONVERT_TO_IMAGE_STRUCT(dst); -#if defined(ALPHA) - // Multiply by the weight of matrix product - c00 = c00 * ALPHA; - c01 = c01 * ALPHA; - c02 = c02 * ALPHA; - c03 = c03 * ALPHA; - c10 = c10 * ALPHA; - c11 = c11 * ALPHA; - c12 = c12 * ALPHA; - c13 = c13 * ALPHA; - c20 = c20 * ALPHA; - c21 = c21 * ALPHA; - c22 = c22 * ALPHA; - c23 = c23 * ALPHA; - c30 = c30 * ALPHA; - c31 = c31 * ALPHA; - c32 = c32 * ALPHA; - c33 = c33 * ALPHA; -#endif // defined(ALPHA) - // Compute dst address __global uchar *dst_addr = offset(&dst, 0, 0); -#if defined(ADD_VEC_C) - __global float *src2_addr = (__global float *)(src2_ptr + src2_offset_first_element_in_bytes + get_global_id(0) * src2_step_x); - float4 c0 = vload4(0, src2_addr); - - c00 += c0.s0; - c01 += c0.s1; - c02 += c0.s2; - c03 += c0.s3; - c10 += c0.s0; - c11 += c0.s1; - c12 += c0.s2; - c13 += c0.s3; - c20 += c0.s0; - c21 += c0.s1; - c22 += c0.s2; - c23 += c0.s3; - c30 += c0.s0; - c31 += c0.s1; - c32 += c0.s2; - c33 += c0.s3; -#endif /* defined(ADD_VEC_C) */ + uint4 zout = 0; #if defined(REINTERPRET_OUTPUT_AS_3D) // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension @@ -2918,8 +2888,8 @@ __kernel void 
gemm_mm_interleaved_transposed_f32_bifrost(IMAGE_DECLARATION(src0) // |__________________| // The plane (zout) is calculated dividing M (get_global_id(1) * 4) by HEIGHT_GEMM3D - uint4 zout = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * 4)) / (uint4)HEIGHT_GEMM3D; - zout = min(DEPTH_GEMM3D - 1, zout); + zout = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * 4)) / (uint4)HEIGHT_GEMM3D; + zout = min(DEPTH_GEMM3D - 1, zout); // Add offset due to the cross plane paddings zout *= (cross_plane_pad * dst_stride_y); @@ -2927,48 +2897,79 @@ __kernel void gemm_mm_interleaved_transposed_f32_bifrost(IMAGE_DECLARATION(src0) // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we // multiply dst_stride_z by DEPTH_GEMM3D dst_addr += z * dst_stride_z * DEPTH_GEMM3D; - - // Store 4x4 block - vstore4((float4)(c00, c01, c02, c03), 0, (__global float *)(dst_addr + 0 * dst_stride_y + zout.s0)); - vstore4((float4)(c10, c11, c12, c13), 0, (__global float *)(dst_addr + 1 * dst_stride_y + zout.s1)); - vstore4((float4)(c20, c21, c22, c23), 0, (__global float *)(dst_addr + 2 * dst_stride_y + zout.s2)); - vstore4((float4)(c30, c31, c32, c33), 0, (__global float *)(dst_addr + 3 * dst_stride_y + zout.s3)); - #else // defined(REINTERPRET_OUTPUT_AS_3D) // Add offset for batched GEMM dst_addr += z * dst_stride_z; +#endif // defined(REINTERPRET_OUTPUT_AS_3D) + + // Multiply by the weight of matrix-matrix product and store the result +#if defined(ALPHA) + SCALE_BLOCK(4, float, c, ALPHA); +#endif // defined(ALPHA) + + // Add beta*bias +#if defined(BETA) + REPEAT_VAR_INIT_TO_CONST(4, uint, zero, 0); + +#if defined(BROADCAST_BIAS) + __global uchar *src2_addr = src2_ptr + src2_offset_first_element_in_bytes + (get_global_id(0) * (uint)4 * sizeof(float)); + + LOAD_BLOCK(1, 4, float, bias, src2_addr, 0, src2_stride_y, zero); + +#ifndef UNIT_BETA + SCALE_BLOCK(1, float, bias, BETA); +#endif // UNIT_BIAS + + // c = c + bias[broadcasted] + ADD_BLOCK_BROADCAST(4, c, bias0); + +#else // defined(BROADCAST_BIAS) + __global uchar *src2_addr = src2_ptr + src2_offset_first_element_in_bytes + (get_global_id(0) * (uint)4 * sizeof(float)) + (get_global_id(1) * (uint)4 * src2_stride_y) + get_global_id( + 2) * src2_stride_z; + + LOAD_BLOCK(4, 4, float, bias, src2_addr, 0, src2_stride_y, zero); + +#ifndef UNIT_BETA + SCALE_BLOCK(4, float, bias, BETA); +#endif // UNIT_BIAS + + // c = c + bias + ADD_BLOCK(4, c, bias); + +#endif // defined(BROADCAST_BIAS) +#endif // defined(BETA) + +#if defined(ACTIVATION_TYPE) + ACTIVATION_BLOCK(4, ACTIVATION_TYPE, float, c, A_VAL, B_VAL); +#endif // defined(ACTIVATION_TYPE) // Store 4x4 block - vstore4((float4)(c00, c01, c02, c03), 0, (__global float *)(dst_addr + 0 * dst_stride_y)); - vstore4((float4)(c10, c11, c12, c13), 0, (__global float *)(dst_addr + 1 * dst_stride_y)); - vstore4((float4)(c20, c21, c22, c23), 0, (__global float *)(dst_addr + 2 * dst_stride_y)); - vstore4((float4)(c30, c31, c32, c33), 0, (__global float *)(dst_addr + 3 * dst_stride_y)); -#endif // defined(REINTERPRET_OUTPUT_AS_3D) + vstore4(c0, 0, (__global float *)(dst_addr + 0 * dst_stride_y + zout.s0)); + vstore4(c1, 0, (__global float *)(dst_addr + 1 * dst_stride_y + zout.s1)); + vstore4(c2, 0, (__global float *)(dst_addr + 2 * dst_stride_y + zout.s2)); + vstore4(c3, 0, (__global float *)(dst_addr + 3 * dst_stride_y + zout.s3)); } // Undefine local defines #undef COLS_MTX_B #if defined(ARM_COMPUTE_OPENCL_FP16_ENABLED) -/** This OpenCL kernel computes the matrix multiplication between matrix 
A (src0) and matrix B (src1) - * Matrix A and matrix B must be reshaped respectively with @ref gemm_interleave4x4_16bit and @ref gemm_transpose1x8 before running the matrix multiplication - * - * Moreover, it can add a vector (src2) if the ADD_VEC_C parameter is passed at compile time. +/** This OpenCL kernel computes the matrix multiplication between matrix A reshaped (src0) and matrix B reshaped (src1) * * @note The number of columns of matrix B and the optional alpha's value need to be passed at compile time using -DCOLS_B and -DALPHA - * @note The multiplication factor for the transposition width (mult_transpose1xW_width) must be passed at compile time using -DMULT_TRANSPOSE1XW_WIDTH (i.e. -DMULT_TRANSPOSE1XW_WIDTH=2) - * @note The multiplication factor for the height of the 4x4 interleaved block must be passed at compile time using -DMULT_INTERLEAVE4X4_HEIGHT (i.e. -DMULT_INTERLEAVE4X4_HEIGHT=2) - * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (i.e. -DMATRIX_B_DEPTH=16) - * This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (i.e. a = [K, M, 16, Batches], b = [N, K, 16]) + * @note The multiplication factor for the transposition width (mult_transpose1xW_width) must be passed at compile time using -DMULT_TRANSPOSE1XW_WIDTH (e.g. -DMULT_TRANSPOSE1XW_WIDTH=2) + * @note The multiplication factor for the height of the 4x4 interleaved block must be passed at compile time using -DMULT_INTERLEAVE4X4_HEIGHT (e.g. -DMULT_INTERLEAVE4X4_HEIGHT=2) + * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (e.g. -DMATRIX_B_DEPTH=16) + * This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (e.g. a = [K, M, 16, Batches], b = [N, K, 16]) * - * @note In case the output has to be reinterpreted as a 3D tensor (i.e. output of convolution layer), the following information must be passed at compile time: + * @note If the activation type were passed at compile time through -DACTIVATION_TYPE (e.g. -DACTIVATION_TYPE=RELU), A, B variables, required by some activation functions, should be passed at compile time as well using -DA_VAL= and -DB_VAL= respectively. + * The activation function is performed after the bias addition + * @note In case the output has to be reinterpreted as a 3D tensor (e.g. output of convolution layer), the following information must be passed at compile time: * -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D * -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor. * -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor * (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns matrix A NOT reshaped * - * @note In case a 3rd input (src2) needs to be added, the ADD_VEC_C parameter has to be passed at compile time as -DADD_VEC_C - * * @param[in] src0_ptr Pointer to the source matrix. 
Supported data types: F16 * @param[in] src0_stride_x Stride of the source matrix in X dimension (in bytes) * @param[in] src0_step_x src_stride_x * number of elements along X processed per workitem(in bytes) @@ -2981,10 +2982,12 @@ __kernel void gemm_mm_interleaved_transposed_f32_bifrost(IMAGE_DECLARATION(src0) * @param[in] src1_stride_y Stride of the source matrix in Y dimension (in bytes) * @param[in] src1_step_y src_stride_y * number of elements along Y processed per workitem(in bytes) * @param[in] src1_offset_first_element_in_bytes The offset of the first element in the source matrix - * @param[in] src2_ptr (Optional) Pointer to the source matrix. Supported data types: same as @p src0_ptr - * @param[in] src2_stride_x (Optional) Stride of the source vector in X dimension (in bytes) - * @param[in] src2_step_x (Optional) src_stride_x * number of elements along X processed per workitem(in bytes) - * @param[in] src2_offset_first_element_in_bytes (Optional) The offset of the first element in the source matrix + * @param[in] src2_ptr (Optional) Pointer to the bias matrix. Supported data type: same as @p lhs_ptr + * @param[in] src2_stride_x (Optional) Stride of the bias matrix in X dimension (in bytes) + * @param[in] src2_step_x (Optional) src2_stride_x * number of elements along X processed per workitem(in bytes) + * @param[in] src2_stride_y (Optional) Stride of the bias matrix in Y dimension (in bytes) + * @param[in] src2_step_y (Optional) src2_stride_y * number of elements along Y processed per workitem(in bytes) + * @param[in] src2_offset_first_element_in_bytes (Optional) The offset of the first element in the bias matrix * @param[out] dst_ptr Pointer to the destination matrix Supported data types: same as @p src0_ptr * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes) * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes) @@ -2993,17 +2996,21 @@ __kernel void gemm_mm_interleaved_transposed_f32_bifrost(IMAGE_DECLARATION(src0) * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix * @param[in] src0_stride_z Stride of the source matrix in Z dimension (in bytes) * @param[in] src1_stride_z Stride of the source matrix in Z dimension (in bytes) + * @param[in] src2_stride_z (Optional) Stride of the bias matrix in Z dimension (in bytes) * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes) * @param[in] cross_plane_pad (Optional) Bottom paddings in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D) */ __kernel void gemm_mm_interleaved_transposed_f16(IMAGE_DECLARATION(src0), IMAGE_DECLARATION(src1), -#if defined(ADD_VEC_C) - VECTOR_DECLARATION(src2), -#endif /* defined(ADD_VEC_C) */ +#if defined(BETA) + IMAGE_DECLARATION(src2), +#endif // defined(BETA) IMAGE_DECLARATION(dst), uint src0_stride_z, uint src1_stride_z, +#if defined(BETA) + uint src2_stride_z, +#endif //defined(BETA) uint dst_stride_z #if defined(REINTERPRET_OUTPUT_AS_3D) , @@ -3041,10 +3048,10 @@ __kernel void gemm_mm_interleaved_transposed_f16(IMAGE_DECLARATION(src0), src_addr_b += offset_row_b; // Reset accumulators - half8 c00 = 0.0f; - half8 c10 = 0.0f; - half8 c20 = 0.0f; - half8 c30 = 0.0f; + half8 c0 = 0.0f; + half8 c1 = 0.0f; + half8 c2 = 0.0f; + half8 c3 = 0.0f; for(; src_addr_b <= (src_end_addr_b - (int)(16 * MULT_TRANSPOSE1XW_WIDTH)); src_addr_a += 8 * MULT_INTERLEAVE4X4_HEIGHT, src_addr_b += 16 * MULT_TRANSPOSE1XW_WIDTH) { @@ -3052,19 +3059,19 @@ __kernel void 
gemm_mm_interleaved_transposed_f16(IMAGE_DECLARATION(src0), half4 a0 = vload4(0, src_addr_a); half8 b0 = vload8(0, src_addr_b); - c00 += (half8)a0.s0 * b0; - c10 += (half8)a0.s1 * b0; - c20 += (half8)a0.s2 * b0; - c30 += (half8)a0.s3 * b0; + c0 += (half8)a0.s0 * b0; + c1 += (half8)a0.s1 * b0; + c2 += (half8)a0.s2 * b0; + c3 += (half8)a0.s3 * b0; // Load values from matrix A (interleaved) and matrix B (transposed) a0 = vload4(0, src_addr_a + 4 * MULT_INTERLEAVE4X4_HEIGHT); b0 = vload8(0, src_addr_b + 8 * MULT_TRANSPOSE1XW_WIDTH); - c00 += (half8)a0.s0 * b0; - c10 += (half8)a0.s1 * b0; - c20 += (half8)a0.s2 * b0; - c30 += (half8)a0.s3 * b0; + c0 += (half8)a0.s0 * b0; + c1 += (half8)a0.s1 * b0; + c2 += (half8)a0.s2 * b0; + c3 += (half8)a0.s3 * b0; } for(; src_addr_b < src_end_addr_b; src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT, src_addr_b += 8 * MULT_TRANSPOSE1XW_WIDTH) @@ -3073,40 +3080,20 @@ __kernel void gemm_mm_interleaved_transposed_f16(IMAGE_DECLARATION(src0), half4 a0 = vload4(0, src_addr_a); half8 b0 = vload8(0, src_addr_b); - c00 += (half8)a0.s0 * b0; - c10 += (half8)a0.s1 * b0; - c20 += (half8)a0.s2 * b0; - c30 += (half8)a0.s3 * b0; + c0 += (half8)a0.s0 * b0; + c1 += (half8)a0.s1 * b0; + c2 += (half8)a0.s2 * b0; + c3 += (half8)a0.s3 * b0; } // Compute destination address Image dst = CONVERT_TO_IMAGE_STRUCT(dst); -#if defined(ALPHA) - // Multiply by the weight of matrix product - c00 = c00 * (half8)ALPHA; - c10 = c10 * (half8)ALPHA; - c20 = c20 * (half8)ALPHA; - c30 = c30 * (half8)ALPHA; -#endif // defined(ALPHA) - -#if defined(ADD_VEC_C) - // *INDENT-OFF* - // clang-format off - __global half *src2_addr = (__global half *)(src2_ptr + src2_offset_first_element_in_bytes + get_global_id(0) * src2_step_x); - half8 c0 = vload8(0, src2_addr); - // clang-format on - // *INDENT-ON* - - c00 += c0; - c10 += c0; - c20 += c0; - c30 += c0; -#endif /* defined(ADD_VEC_C) */ - // Compute dst address __global uchar *dst_addr = offset(&dst, 0, 0); + uint4 zout = 0; + #if defined(REINTERPRET_OUTPUT_AS_3D) // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension // in order to take into account the presence of possible cross plane paddings @@ -3124,8 +3111,8 @@ __kernel void gemm_mm_interleaved_transposed_f16(IMAGE_DECLARATION(src0), // |__________________| // The plane (zout) is calculated dividing M (get_global_id(1) * 4) by HEIGHT_GEMM3D - uint4 zout = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * 4)) / (uint4)HEIGHT_GEMM3D; - zout = min(DEPTH_GEMM3D - 1, zout); + zout = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * 4)) / (uint4)HEIGHT_GEMM3D; + zout = min(DEPTH_GEMM3D - 1, zout); // Add offset due to the cross plane paddings zout *= (cross_plane_pad * dst_stride_y); @@ -3133,44 +3120,76 @@ __kernel void gemm_mm_interleaved_transposed_f16(IMAGE_DECLARATION(src0), // Add offset for batched GEMM. 
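// Reading aid (not part of the patch): a hand-written sketch of what the fused
// alpha/beta/activation epilogue used by these kernels amounts to for a 4-row
// half8 block c0..c3, assuming ALPHA, BETA, BROADCAST_BIAS and
// ACTIVATION_TYPE=RELU are all defined. The kernels express the same steps with
// the SCALE_BLOCK/LOAD_BLOCK/ADD_BLOCK_BROADCAST/ACTIVATION_BLOCK helper macros,
// whose exact expansion may differ from this illustration.
//
//     // Scale by alpha
//     c0 *= (half8)ALPHA; c1 *= (half8)ALPHA; c2 *= (half8)ALPHA; c3 *= (half8)ALPHA;
//
//     // Load one bias row of length N and scale it by beta (skipped if UNIT_BETA)
//     __global half *bias_ptr = (__global half *)(src2_ptr + src2_offset_first_element_in_bytes
//                                                 + get_global_id(0) * (uint)8 * sizeof(half));
//     half8 bias0 = vload8(0, bias_ptr) * (half8)BETA;
//
//     // Broadcast-add the bias row to every accumulator row
//     c0 += bias0; c1 += bias0; c2 += bias0; c3 += bias0;
//
//     // Apply the activation (RELU shown as the example)
//     c0 = fmax(c0, (half8)0.0f); c1 = fmax(c1, (half8)0.0f);
//     c2 = fmax(c2, (half8)0.0f); c3 = fmax(c3, (half8)0.0f);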
The batches will be in the fourth dimension and for this reason we // multiply dst_stride_z by DEPTH_GEMM3D dst_addr += z * dst_stride_z * DEPTH_GEMM3D; - - // Store 4x8 block - vstore8(c00, 0, (__global half *)(dst_addr + 0 * dst_stride_y + zout.s0)); - vstore8(c10, 0, (__global half *)(dst_addr + 1 * dst_stride_y + zout.s1)); - vstore8(c20, 0, (__global half *)(dst_addr + 2 * dst_stride_y + zout.s2)); - vstore8(c30, 0, (__global half *)(dst_addr + 3 * dst_stride_y + zout.s3)); - #else // defined(REINTERPRET_OUTPUT_AS_3D) // Add offset for batched GEMM dst_addr += z * dst_stride_z; +#endif // defined(REINTERPRET_OUTPUT_AS_3D) + + // Multiply by the weight of matrix-matrix product and store the result +#if defined(ALPHA) + SCALE_BLOCK(4, half, c, ALPHA); +#endif // defined(ALPHA) + + // Add beta*bias +#if defined(BETA) + REPEAT_VAR_INIT_TO_CONST(4, uint, zero, 0); + +#if defined(BROADCAST_BIAS) + __global uchar *src2_addr = src2_ptr + src2_offset_first_element_in_bytes + (get_global_id(0) * (uint)8 * sizeof(half)); + + LOAD_BLOCK(1, 8, half, bias, src2_addr, 0, src2_stride_y, zero); + +#ifndef UNIT_BETA + SCALE_BLOCK(1, half, bias, BETA); +#endif // UNIT_BIAS + + // c = c + bias[broadcasted] + ADD_BLOCK_BROADCAST(4, c, bias0); + +#else // defined(BROADCAST_BIAS) + + __global uchar *src2_addr = src2_ptr + src2_offset_first_element_in_bytes + (get_global_id(0) * (uint)8 * sizeof(half)) + (get_global_id(1) * (uint)4 * src2_stride_y) + get_global_id( + 2) * src2_stride_z; + + LOAD_BLOCK(4, 8, half, bias, src2_addr, 0, src2_stride_y, zero); + +#ifndef UNIT_BETA + SCALE_BLOCK(4, half, bias, BETA); +#endif // UNIT_BIAS + + // c = c + bias + ADD_BLOCK(4, c, bias); + +#endif // defined(BROADCAST_BIAS) +#endif // defined(BETA) + +#if defined(ACTIVATION_TYPE) + ACTIVATION_BLOCK(4, ACTIVATION_TYPE, half, c, A_VAL, B_VAL); +#endif // defined(ACTIVATION_TYPE) // Store 4x8 block - vstore8(c00, 0, (__global half *)(dst_addr + 0 * dst_stride_y)); - vstore8(c10, 0, (__global half *)(dst_addr + 1 * dst_stride_y)); - vstore8(c20, 0, (__global half *)(dst_addr + 2 * dst_stride_y)); - vstore8(c30, 0, (__global half *)(dst_addr + 3 * dst_stride_y)); -#endif // defined(REINTERPRET_OUTPUT_AS_3D) + vstore8(c0, 0, (__global half *)(dst_addr + 0 * dst_stride_y + zout.s0)); + vstore8(c1, 0, (__global half *)(dst_addr + 1 * dst_stride_y + zout.s1)); + vstore8(c2, 0, (__global half *)(dst_addr + 2 * dst_stride_y + zout.s2)); + vstore8(c3, 0, (__global half *)(dst_addr + 3 * dst_stride_y + zout.s3)); } -/** This OpenCL kernel computes the matrix multiplication between matrix A (src0) and matrix B (src1) while accumulating the result in a 32 floating point variable. - * Matrix A and matrix B must be reshaped respectively with @ref gemm_interleave4x4_16bit and @ref gemm_transpose1x8 before running the matrix multiplication - * - * Moreover, it can add a vector (src2) if the ADD_VEC_C parameter is passed at compile time. +/** This OpenCL kernel computes the matrix multiplication between matrix A reshaped (src0) and matrix B reshaped (src1) while accumulating the result in a 32 floating point variable. * * @note The number of columns of matrix B and the optional alpha's value need to be passed at compile time using -DCOLS_B and -DALPHA - * @note The multiplication factor for the transposition width (mult_transpose1xW_width) must be passed at compile time using -DMULT_TRANSPOSE1XW_WIDTH (i.e. 
-DMULT_TRANSPOSE1XW_WIDTH=2) - * @note The multiplication factor for the height of the 4x4 interleaved block must be passed at compile time using -DMULT_INTERLEAVE4X4_HEIGHT (i.e. -DMULT_INTERLEAVE4X4_HEIGHT=2) - * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (i.e. -DMATRIX_B_DEPTH=16) - * This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (i.e. a = [K, M, 16, Batches], b = [N, K, 16]) + * @note The multiplication factor for the transposition width (mult_transpose1xW_width) must be passed at compile time using -DMULT_TRANSPOSE1XW_WIDTH (e.g. -DMULT_TRANSPOSE1XW_WIDTH=2) + * @note The multiplication factor for the height of the 4x4 interleaved block must be passed at compile time using -DMULT_INTERLEAVE4X4_HEIGHT (e.g. -DMULT_INTERLEAVE4X4_HEIGHT=2) + * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (e.g. -DMATRIX_B_DEPTH=16) + * This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (e.g. a = [K, M, 16, Batches], b = [N, K, 16]) * - * @note In case the output has to be reinterpreted as a 3D tensor (i.e. output of convolution layer), the following information must be passed at compile time: + * @note If the activation type were passed at compile time through -DACTIVATION_TYPE (e.g. -DACTIVATION_TYPE=RELU), A, B variables, required by some activation functions, should be passed at compile time as well using -DA_VAL= and -DB_VAL= respectively. + * The activation function is performed after the bias addition + * @note In case the output has to be reinterpreted as a 3D tensor (e.g. output of convolution layer), the following information must be passed at compile time: * -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D * -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor. * -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor * (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns matrix A NOT reshaped * - * @note In case a 3rd input (src2) needs to be added, the ADD_VEC_C parameter has to be passed at compile time as -DADD_VEC_C - * * @param[in] src0_ptr Pointer to the source matrix. Supported data types: F16 * @param[in] src0_stride_x Stride of the source matrix in X dimension (in bytes) * @param[in] src0_step_x src_stride_x * number of elements along X processed per workitem(in bytes) @@ -3183,10 +3202,12 @@ __kernel void gemm_mm_interleaved_transposed_f16(IMAGE_DECLARATION(src0), * @param[in] src1_stride_y Stride of the source matrix in Y dimension (in bytes) * @param[in] src1_step_y src_stride_y * number of elements along Y processed per workitem(in bytes) * @param[in] src1_offset_first_element_in_bytes The offset of the first element in the source matrix - * @param[in] src2_ptr (Optional) Pointer to the source matrix. 
Supported data types: same as @p src0_ptr - * @param[in] src2_stride_x (Optional) Stride of the source vector in X dimension (in bytes) - * @param[in] src2_step_x (Optional) src_stride_x * number of elements along X processed per workitem(in bytes) - * @param[in] src2_offset_first_element_in_bytes (Optional) The offset of the first element in the source matrix + * @param[in] src2_ptr (Optional) Pointer to the bias matrix. Supported data type: same as @p lhs_ptr + * @param[in] src2_stride_x (Optional) Stride of the bias matrix in X dimension (in bytes) + * @param[in] src2_step_x (Optional) src2_stride_x * number of elements along X processed per workitem(in bytes) + * @param[in] src2_stride_y (Optional) Stride of the bias matrix in Y dimension (in bytes) + * @param[in] src2_step_y (Optional) src2_stride_y * number of elements along Y processed per workitem(in bytes) + * @param[in] src2_offset_first_element_in_bytes (Optional) The offset of the first element in the bias matrix * @param[out] dst_ptr Pointer to the destination matrix Supported data types: same as @p src0_ptr * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes) * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes) @@ -3195,17 +3216,21 @@ __kernel void gemm_mm_interleaved_transposed_f16(IMAGE_DECLARATION(src0), * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix * @param[in] src0_stride_z Stride of the source matrix in Z dimension (in bytes) * @param[in] src1_stride_z Stride of the source matrix in Z dimension (in bytes) + * @param[in] src2_stride_z (Optional) Stride of the bias matrix in Z dimension (in bytes) * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes) * @param[in] cross_plane_pad (Optional) Bottom paddings in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D) */ __kernel void gemm_mm_interleaved_transposed_f16_acc32(IMAGE_DECLARATION(src0), IMAGE_DECLARATION(src1), -#if defined(ADD_VEC_C) - VECTOR_DECLARATION(src2), -#endif /* defined(ADD_VEC_C) */ +#if defined(BETA) + IMAGE_DECLARATION(src2), +#endif // defined(BETA) IMAGE_DECLARATION(dst), uint src0_stride_z, uint src1_stride_z, +#if defined(BETA) + uint src2_stride_z, +#endif //defined(BETA) uint dst_stride_z #if defined(REINTERPRET_OUTPUT_AS_3D) , @@ -3243,10 +3268,10 @@ __kernel void gemm_mm_interleaved_transposed_f16_acc32(IMAGE_DECLARATION(src0), src_addr_b += offset_row_b; // Reset accumulators - float8 c00 = 0.0f; - float8 c10 = 0.0f; - float8 c20 = 0.0f; - float8 c30 = 0.0f; + float8 c0 = 0.0f; + float8 c1 = 0.0f; + float8 c2 = 0.0f; + float8 c3 = 0.0f; for(; src_addr_b <= (src_end_addr_b - (int)(16 * MULT_TRANSPOSE1XW_WIDTH)); src_addr_a += 8 * MULT_INTERLEAVE4X4_HEIGHT, src_addr_b += 16 * MULT_TRANSPOSE1XW_WIDTH) { @@ -3254,19 +3279,19 @@ __kernel void gemm_mm_interleaved_transposed_f16_acc32(IMAGE_DECLARATION(src0), float4 a0 = convert_float4(vload4(0, src_addr_a)); float8 b0 = convert_float8(vload8(0, src_addr_b)); - c00 += (float8)a0.s0 * b0; - c10 += (float8)a0.s1 * b0; - c20 += (float8)a0.s2 * b0; - c30 += (float8)a0.s3 * b0; + c0 += (float8)a0.s0 * b0; + c1 += (float8)a0.s1 * b0; + c2 += (float8)a0.s2 * b0; + c3 += (float8)a0.s3 * b0; // Load values from matrix A (interleaved) and matrix B (transposed) a0 = convert_float4(vload4(0, src_addr_a + 4 * MULT_INTERLEAVE4X4_HEIGHT)); b0 = convert_float8(vload8(0, src_addr_b + 8 * MULT_TRANSPOSE1XW_WIDTH)); - c00 += (float8)a0.s0 * b0; - 
c10 += (float8)a0.s1 * b0; - c20 += (float8)a0.s2 * b0; - c30 += (float8)a0.s3 * b0; + c0 += (float8)a0.s0 * b0; + c1 += (float8)a0.s1 * b0; + c2 += (float8)a0.s2 * b0; + c3 += (float8)a0.s3 * b0; } for(; src_addr_b < src_end_addr_b; src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT, src_addr_b += 8 * MULT_TRANSPOSE1XW_WIDTH) @@ -3275,40 +3300,20 @@ __kernel void gemm_mm_interleaved_transposed_f16_acc32(IMAGE_DECLARATION(src0), float4 a0 = convert_float4(vload4(0, src_addr_a)); float8 b0 = convert_float8(vload8(0, src_addr_b)); - c00 += (float8)a0.s0 * b0; - c10 += (float8)a0.s1 * b0; - c20 += (float8)a0.s2 * b0; - c30 += (float8)a0.s3 * b0; + c0 += (float8)a0.s0 * b0; + c1 += (float8)a0.s1 * b0; + c2 += (float8)a0.s2 * b0; + c3 += (float8)a0.s3 * b0; } // Compute destination address Image dst = CONVERT_TO_IMAGE_STRUCT(dst); -#if defined(ALPHA) - // Multiply by the weight of matrix product - c00 = c00 * (float8)ALPHA; - c10 = c10 * (float8)ALPHA; - c20 = c20 * (float8)ALPHA; - c30 = c30 * (float8)ALPHA; -#endif // defined(ALPHA) - -#if defined(ADD_VEC_C) - // *INDENT-OFF* - // clang-format off - __global half *src2_addr = (__global half *)(src2_ptr + src2_offset_first_element_in_bytes + get_global_id(0) * src2_step_x); - float8 c0 = convert_float8(vload8(0, src2_addr)); - // clang-format on - // *INDENT-ON* - - c00 += c0; - c10 += c0; - c20 += c0; - c30 += c0; -#endif /* defined(ADD_VEC_C) */ - // Compute dst address __global uchar *dst_addr = offset(&dst, 0, 0); + uint4 zout = 0; + #if defined(REINTERPRET_OUTPUT_AS_3D) // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension // in order to take into account the presence of possible cross plane paddings @@ -3326,8 +3331,8 @@ __kernel void gemm_mm_interleaved_transposed_f16_acc32(IMAGE_DECLARATION(src0), // |__________________| // The plane (zout) is calculated dividing M (get_global_id(1) * 4) by HEIGHT_GEMM3D - uint4 zout = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * 4)) / (uint4)HEIGHT_GEMM3D; - zout = min(DEPTH_GEMM3D - 1, zout); + zout = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * 4)) / (uint4)HEIGHT_GEMM3D; + zout = min(DEPTH_GEMM3D - 1, zout); // Add offset due to the cross plane paddings zout *= (cross_plane_pad * dst_stride_y); @@ -3335,44 +3340,86 @@ __kernel void gemm_mm_interleaved_transposed_f16_acc32(IMAGE_DECLARATION(src0), // Add offset for batched GEMM. 
The batches will be in the fourth dimension and for this reason we // multiply dst_stride_z by DEPTH_GEMM3D dst_addr += z * dst_stride_z * DEPTH_GEMM3D; - - // Store 4x8 block - vstore8(convert_half8(c00), 0, (__global half *)(dst_addr + 0 * dst_stride_y + zout.s0)); - vstore8(convert_half8(c10), 0, (__global half *)(dst_addr + 1 * dst_stride_y + zout.s1)); - vstore8(convert_half8(c20), 0, (__global half *)(dst_addr + 2 * dst_stride_y + zout.s2)); - vstore8(convert_half8(c30), 0, (__global half *)(dst_addr + 3 * dst_stride_y + zout.s3)); - #else // defined(REINTERPRET_OUTPUT_AS_3D) // Add offset for batched GEMM dst_addr += z * dst_stride_z; +#endif // defined(REINTERPRET_OUTPUT_AS_3D) + + // Multiply by the weight of matrix-matrix product and store the result +#if defined(ALPHA) + SCALE_BLOCK(4, float, c, ALPHA); +#endif // defined(ALPHA) + +#if defined(BETA) + REPEAT_VAR_INIT_TO_CONST(4, uint, zero, 0); + +#if defined(BROADCAST_BIAS) + __global uchar *src2_addr = src2_ptr + src2_offset_first_element_in_bytes + (get_global_id(0) * (uint)8 * sizeof(half)); + + LOAD_BLOCK(1, 8, half, bias, src2_addr, 0, src2_stride_y, zero); + + float8 bias_f0 = convert_float8(bias0); + +#ifndef UNIT_BETA + SCALE_BLOCK(1, float, bias_f, BETA); +#endif // UNIT_BIAS + + // c = c + bias[broadcasted] + ADD_BLOCK_BROADCAST(4, c, bias_f0); + +#else // defined(BROADCAST_BIAS) + __global uchar *src2_addr = src2_ptr + src2_offset_first_element_in_bytes + (get_global_id(0) * (uint)8 * sizeof(half)) + (get_global_id(1) * (uint)4 * src2_stride_y) + get_global_id( + 2) * src2_stride_z; + + LOAD_BLOCK(4, 8, half, bias, src2_addr, 0, src2_stride_y, zero); + + float8 bias_f0 = convert_float8(bias0); + float8 bias_f1 = convert_float8(bias1); + float8 bias_f2 = convert_float8(bias2); + float8 bias_f3 = convert_float8(bias3); + +#ifndef UNIT_BETA + SCALE_BLOCK(4, float, bias_f, BETA); +#endif // UNIT_BIAS + + // c = c + bias + ADD_BLOCK(4, c, bias_f); + +#endif // defined(BROADCAST_BIAS) +#endif // defined(BETA) + + half8 c_h0 = convert_half8(c0); + half8 c_h1 = convert_half8(c1); + half8 c_h2 = convert_half8(c2); + half8 c_h3 = convert_half8(c3); + +#if defined(ACTIVATION_TYPE) + ACTIVATION_BLOCK(4, ACTIVATION_TYPE, half, c_h, A_VAL, B_VAL); +#endif // defined(ACTIVATION_TYPE) // Store 4x8 block - vstore8(convert_half8(c00), 0, (__global half *)(dst_addr + 0 * dst_stride_y)); - vstore8(convert_half8(c10), 0, (__global half *)(dst_addr + 1 * dst_stride_y)); - vstore8(convert_half8(c20), 0, (__global half *)(dst_addr + 2 * dst_stride_y)); - vstore8(convert_half8(c30), 0, (__global half *)(dst_addr + 3 * dst_stride_y)); -#endif // defined(REINTERPRET_OUTPUT_AS_3D) + vstore8(c_h0, 0, (__global half *)(dst_addr + 0 * dst_stride_y + zout.s0)); + vstore8(c_h1, 0, (__global half *)(dst_addr + 1 * dst_stride_y + zout.s1)); + vstore8(c_h2, 0, (__global half *)(dst_addr + 2 * dst_stride_y + zout.s2)); + vstore8(c_h3, 0, (__global half *)(dst_addr + 3 * dst_stride_y + zout.s3)); } -/** This OpenCL kernel optimized for Bifrost architectures computes the matrix multiplication between matrix A (src0) and matrix B (src1) - * Matrix A and matrix B must be reshaped respectively with @ref gemm_interleave4x4_16bit and @ref gemm_transpose1x8 before running the matrix multiplication - * - * Moreover, it can add a vector (src2) if the ADD_VEC_C parameter is passed at compile time. 
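// Note on the fp_mixed_precision (_acc32) epilogue above: the bias tile is
// stored as fp16 but, like the accumulators, is processed in fp32; only the
// activation and the final store happen in fp16. A condensed, illustrative view
// (bias_addr and dst_row stand for the half* addresses computed in the kernel;
// this is not the literal macro expansion):
//
//     float8 acc = ...;                             // fp32 accumulator
//     half8  b   = vload8(0, bias_addr);            // bias kept as fp16 in memory
//     float8 bf  = convert_float8(b) * (float8)BETA;
//     acc       += bf;                              // beta*bias added in fp32
//     half8  out = convert_half8(acc);              // back to fp16
//     out        = fmax(out, (half8)0.0f);          // e.g. ACTIVATION_TYPE=RELU
//     vstore8(out, 0, (__global half *)dst_row);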
+/** This OpenCL kernel optimized for Bifrost architectures computes the matrix multiplication between matrix A reshaped (src0) and matrix B reshaped (src1) * * @note The number of columns of matrix B and the optional alpha's value need to be passed at compile time using -DCOLS_B and -DALPHA - * @note The multiplication factor for the transposition width (mult_transpose1xW_width) must be passed at compile time using -DMULT_TRANSPOSE1XW_WIDTH (i.e. -DMULT_TRANSPOSE1XW_WIDTH=2) - * @note The multiplication factor for the height of the 4x4 interleaved block must be passed at compile time using -DMULT_INTERLEAVE4X4_HEIGHT (i.e. -DMULT_INTERLEAVE4X4_HEIGHT=2) - * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (i.e. -DMATRIX_B_DEPTH=16) - * This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (i.e. a = [K, M, 16, Batches], b = [N, K, 16]) + * @note The multiplication factor for the transposition width (mult_transpose1xW_width) must be passed at compile time using -DMULT_TRANSPOSE1XW_WIDTH (e.g. -DMULT_TRANSPOSE1XW_WIDTH=2) + * @note The multiplication factor for the height of the 4x4 interleaved block must be passed at compile time using -DMULT_INTERLEAVE4X4_HEIGHT (e.g. -DMULT_INTERLEAVE4X4_HEIGHT=2) + * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (e.g. -DMATRIX_B_DEPTH=16) + * This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (e.g. a = [K, M, 16, Batches], b = [N, K, 16]) * - * @note In case the output has to be reinterpreted as a 3D tensor (i.e. output of convolution layer), the following information must be passed at compile time: + * @note If the activation type were passed at compile time through -DACTIVATION_TYPE (e.g. -DACTIVATION_TYPE=RELU), A, B variables, required by some activation functions, should be passed at compile time as well using -DA_VAL= and -DB_VAL= respectively. + * The activation function is performed after the bias addition + * @note In case the output has to be reinterpreted as a 3D tensor (e.g. output of convolution layer), the following information must be passed at compile time: * -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D * -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor. * -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor * (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns matrix A NOT reshaped * - * @note In case a 3rd input (src2) needs to be added, the ADD_VEC_C parameter has to be passed at compile time as -DADD_VEC_C - * * @param[in] src0_ptr Pointer to the source matrix. 
Supported data types: F16 * @param[in] src0_stride_x Stride of the source matrix in X dimension (in bytes) * @param[in] src0_step_x src_stride_x * number of elements along X processed per workitem(in bytes) @@ -3385,26 +3432,34 @@ __kernel void gemm_mm_interleaved_transposed_f16_acc32(IMAGE_DECLARATION(src0), * @param[in] src1_stride_y Stride of the source matrix in Y dimension (in bytes) * @param[in] src1_step_y src_stride_y * number of elements along Y processed per workitem(in bytes) * @param[in] src1_offset_first_element_in_bytes The offset of the first element in the source matrix - * @param[in] src2_ptr (Optional) Pointer to the source matrix. Supported data types: same as @p src0_ptr - * @param[in] src2_stride_x (Optional) Stride of the source vector in X dimension (in bytes) - * @param[in] src2_step_x (Optional) src_stride_x * number of elements along X processed per workitem(in bytes) - * @param[in] src2_offset_first_element_in_bytes (Optional) The offset of the first element in the source matrix + * @param[in] src2_ptr (Optional) Pointer to the bias matrix. Supported data type: same as @p lhs_ptr + * @param[in] src2_stride_x (Optional) Stride of the bias matrix in X dimension (in bytes) + * @param[in] src2_step_x (Optional) src2_stride_x * number of elements along X processed per workitem(in bytes) + * @param[in] src2_stride_y (Optional) Stride of the bias matrix in Y dimension (in bytes) + * @param[in] src2_step_y (Optional) src2_stride_y * number of elements along Y processed per workitem(in bytes) + * @param[in] src2_offset_first_element_in_bytes (Optional) The offset of the first element in the bias matrix * @param[out] dst_ptr Pointer to the destination matrix Supported data types: same as @p src0_ptr * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes) * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes) * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes) * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes) * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix + * @param[in] src0_stride_z Stride of the source matrix in Z dimension (in bytes) + * @param[in] src1_stride_z Stride of the source matrix in Z dimension (in bytes) + * @param[in] src2_stride_z (Optional) Stride of the bias matrix in Z dimension (in bytes) * @param[in] cross_plane_pad (Optional) Bottom paddings in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D) */ __kernel void gemm_mm_interleaved_transposed_f16_bifrost(IMAGE_DECLARATION(src0), IMAGE_DECLARATION(src1), -#if defined(ADD_VEC_C) - VECTOR_DECLARATION(src2), -#endif /* defined(ADD_VEC_C) */ +#if defined(BETA) + IMAGE_DECLARATION(src2), +#endif // defined(BETA) IMAGE_DECLARATION(dst), uint src0_stride_z, uint src1_stride_z, +#if defined(BETA) + uint src2_stride_z, +#endif //defined(BETA) uint dst_stride_z #if defined(REINTERPRET_OUTPUT_AS_3D) , @@ -3442,10 +3497,10 @@ __kernel void gemm_mm_interleaved_transposed_f16_bifrost(IMAGE_DECLARATION(src0) src_addr_b += offset_row_b; // Reset accumulators - half8 c00 = 0.0f; - half8 c10 = 0.0f; - half8 c20 = 0.0f; - half8 c30 = 0.0f; + half8 c0 = 0.0f; + half8 c1 = 0.0f; + half8 c2 = 0.0f; + half8 c3 = 0.0f; #define COLS_MTX_B (COLS_B / (8 * MULT_TRANSPOSE1XW_WIDTH)) @@ -3460,20 +3515,20 @@ __kernel void gemm_mm_interleaved_transposed_f16_bifrost(IMAGE_DECLARATION(src0) src_addr_a += 8 * 
MULT_INTERLEAVE4X4_HEIGHT; src_addr_b += 8 * MULT_TRANSPOSE1XW_WIDTH; - c00 = fma((half8)a0.s0, b0, c00); - c10 = fma((half8)a0.s1, b0, c10); - c20 = fma((half8)a0.s2, b0, c20); - c30 = fma((half8)a0.s3, b0, c30); + c0 = fma((half8)a0.s0, b0, c0); + c1 = fma((half8)a0.s1, b0, c1); + c2 = fma((half8)a0.s2, b0, c2); + c3 = fma((half8)a0.s3, b0, c3); // Load values from matrix B (transposed) b0 = vload8(0, src_addr_b); src_addr_b += 8 * MULT_TRANSPOSE1XW_WIDTH; - c00 = fma((half8)a0.s4, b0, c00); - c10 = fma((half8)a0.s5, b0, c10); - c20 = fma((half8)a0.s6, b0, c20); - c30 = fma((half8)a0.s7, b0, c30); + c0 = fma((half8)a0.s4, b0, c0); + c1 = fma((half8)a0.s5, b0, c1); + c2 = fma((half8)a0.s6, b0, c2); + c3 = fma((half8)a0.s7, b0, c3); // Load values from matrix A (interleaved) and matrix B (transposed) a0 = vload8(0, src_addr_a); @@ -3482,20 +3537,20 @@ __kernel void gemm_mm_interleaved_transposed_f16_bifrost(IMAGE_DECLARATION(src0) src_addr_a += 8 * MULT_INTERLEAVE4X4_HEIGHT; src_addr_b += 8 * MULT_TRANSPOSE1XW_WIDTH; - c00 = fma((half8)a0.s0, b0, c00); - c10 = fma((half8)a0.s1, b0, c10); - c20 = fma((half8)a0.s2, b0, c20); - c30 = fma((half8)a0.s3, b0, c30); + c0 = fma((half8)a0.s0, b0, c0); + c1 = fma((half8)a0.s1, b0, c1); + c2 = fma((half8)a0.s2, b0, c2); + c3 = fma((half8)a0.s3, b0, c3); // Load values from matrix B (transposed) b0 = vload8(0, src_addr_b); src_addr_b += 8 * MULT_TRANSPOSE1XW_WIDTH; - c00 = fma((half8)a0.s4, b0, c00); - c10 = fma((half8)a0.s5, b0, c10); - c20 = fma((half8)a0.s6, b0, c20); - c30 = fma((half8)a0.s7, b0, c30); + c0 = fma((half8)a0.s4, b0, c0); + c1 = fma((half8)a0.s5, b0, c1); + c2 = fma((half8)a0.s6, b0, c2); + c3 = fma((half8)a0.s7, b0, c3); #else // MULT_INTERLEAVE4X4_HEIGHT == 1 // Load values from matrix A (interleaved) and matrix B (transposed) half4 a0 = vload4(0, src_addr_a); @@ -3504,10 +3559,10 @@ __kernel void gemm_mm_interleaved_transposed_f16_bifrost(IMAGE_DECLARATION(src0) src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT; src_addr_b += 8 * MULT_TRANSPOSE1XW_WIDTH; - c00 = fma((half8)a0.s0, b0, c00); - c10 = fma((half8)a0.s1, b0, c10); - c20 = fma((half8)a0.s2, b0, c20); - c30 = fma((half8)a0.s3, b0, c30); + c0 = fma((half8)a0.s0, b0, c0); + c1 = fma((half8)a0.s1, b0, c1); + c2 = fma((half8)a0.s2, b0, c2); + c3 = fma((half8)a0.s3, b0, c3); // Load values from matrix A (interleaved) and matrix B (transposed) a0 = vload4(0, src_addr_a); @@ -3516,10 +3571,10 @@ __kernel void gemm_mm_interleaved_transposed_f16_bifrost(IMAGE_DECLARATION(src0) src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT; src_addr_b += 8 * MULT_TRANSPOSE1XW_WIDTH; - c00 = fma((half8)a0.s0, b0, c00); - c10 = fma((half8)a0.s1, b0, c10); - c20 = fma((half8)a0.s2, b0, c20); - c30 = fma((half8)a0.s3, b0, c30); + c0 = fma((half8)a0.s0, b0, c0); + c1 = fma((half8)a0.s1, b0, c1); + c2 = fma((half8)a0.s2, b0, c2); + c3 = fma((half8)a0.s3, b0, c3); // Load values from matrix A (interleaved) and matrix B (transposed) a0 = vload4(0, src_addr_a); @@ -3528,10 +3583,10 @@ __kernel void gemm_mm_interleaved_transposed_f16_bifrost(IMAGE_DECLARATION(src0) src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT; src_addr_b += 8 * MULT_TRANSPOSE1XW_WIDTH; - c00 = fma((half8)a0.s0, b0, c00); - c10 = fma((half8)a0.s1, b0, c10); - c20 = fma((half8)a0.s2, b0, c20); - c30 = fma((half8)a0.s3, b0, c30); + c0 = fma((half8)a0.s0, b0, c0); + c1 = fma((half8)a0.s1, b0, c1); + c2 = fma((half8)a0.s2, b0, c2); + c3 = fma((half8)a0.s3, b0, c3); // Load values from matrix A (interleaved) and matrix B (transposed) a0 = vload4(0, 
src_addr_a); @@ -3540,10 +3595,10 @@ __kernel void gemm_mm_interleaved_transposed_f16_bifrost(IMAGE_DECLARATION(src0) src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT; src_addr_b += 8 * MULT_TRANSPOSE1XW_WIDTH; - c00 = fma((half8)a0.s0, b0, c00); - c10 = fma((half8)a0.s1, b0, c10); - c20 = fma((half8)a0.s2, b0, c20); - c30 = fma((half8)a0.s3, b0, c30); + c0 = fma((half8)a0.s0, b0, c0); + c1 = fma((half8)a0.s1, b0, c1); + c2 = fma((half8)a0.s2, b0, c2); + c3 = fma((half8)a0.s3, b0, c3); #endif // MULT_INTERLEAVE4X4_HEIGHT == 1 } @@ -3556,40 +3611,20 @@ __kernel void gemm_mm_interleaved_transposed_f16_bifrost(IMAGE_DECLARATION(src0) src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT; src_addr_b += 8 * MULT_TRANSPOSE1XW_WIDTH; - c00 = fma((half8)a0.s0, b0, c00); - c10 = fma((half8)a0.s1, b0, c10); - c20 = fma((half8)a0.s2, b0, c20); - c30 = fma((half8)a0.s3, b0, c30); + c0 = fma((half8)a0.s0, b0, c0); + c1 = fma((half8)a0.s1, b0, c1); + c2 = fma((half8)a0.s2, b0, c2); + c3 = fma((half8)a0.s3, b0, c3); } // Compute destination address Image dst = CONVERT_TO_IMAGE_STRUCT(dst); -#if defined(ALPHA) - // Multiply by the weight of matrix product - c00 = c00 * (half8)ALPHA; - c10 = c10 * (half8)ALPHA; - c20 = c20 * (half8)ALPHA; - c30 = c30 * (half8)ALPHA; -#endif // defined(ALPHA) - -#if defined(ADD_VEC_C) - // *INDENT-OFF* - // clang-format off - __global half *src2_addr = (__global half *)(src2_ptr + src2_offset_first_element_in_bytes + get_global_id(0) * src2_step_x); - half8 c0 = vload8(0, src2_addr); - // clang-format on - // *INDENT-ON* - - c00 += c0; - c10 += c0; - c20 += c0; - c30 += c0; -#endif /* defined(ADD_VEC_C) */ - // Compute dst address __global uchar *dst_addr = offset(&dst, 0, 0); + uint4 zout = 0; + #if defined(REINTERPRET_OUTPUT_AS_3D) // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension // in order to take into account the presence of possible cross plane paddings @@ -3607,8 +3642,8 @@ __kernel void gemm_mm_interleaved_transposed_f16_bifrost(IMAGE_DECLARATION(src0) // |__________________| // The plane (zout) is calculated dividing M (get_global_id(1) * 4) by HEIGHT_GEMM3D - uint4 zout = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * 4)) / (uint4)HEIGHT_GEMM3D; - zout = min(DEPTH_GEMM3D - 1, zout); + zout = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * 4)) / (uint4)HEIGHT_GEMM3D; + zout = min(DEPTH_GEMM3D - 1, zout); // Add offset due to the cross plane paddings zout *= (cross_plane_pad * dst_stride_y); @@ -3616,23 +3651,57 @@ __kernel void gemm_mm_interleaved_transposed_f16_bifrost(IMAGE_DECLARATION(src0) // Add offset for batched GEMM. 
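// For reference, the two BETA paths in these kernels differ only in how the
// bias (src2) tile is addressed. With x = get_global_id(0), y = get_global_id(1)
// and z = get_global_id(2), and the 4x8 half tile used here:
//
//     // BROADCAST_BIAS: a single row of length N, reused for every output row
//     // and every batch
//     src2_addr = src2_ptr + src2_offset_first_element_in_bytes
//               + x * (uint)8 * sizeof(half);
//
//     // No BROADCAST_BIAS: a full matrix addressed by N, M and the batch
//     src2_addr = src2_ptr + src2_offset_first_element_in_bytes
//               + x * (uint)8 * sizeof(half)
//               + y * (uint)4 * src2_stride_y
//               + z * src2_stride_z;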
The batches will be in the fourth dimension and for this reason we // multiply dst_stride_z by DEPTH_GEMM3D dst_addr += z * dst_stride_z * DEPTH_GEMM3D; - - // Store 4x8 block - vstore8(c00, 0, (__global half *)(dst_addr + 0 * dst_stride_y + zout.s0)); - vstore8(c10, 0, (__global half *)(dst_addr + 1 * dst_stride_y + zout.s1)); - vstore8(c20, 0, (__global half *)(dst_addr + 2 * dst_stride_y + zout.s2)); - vstore8(c30, 0, (__global half *)(dst_addr + 3 * dst_stride_y + zout.s3)); - #else // defined(REINTERPRET_OUTPUT_AS_3D) // Add offset for batched GEMM dst_addr += z * dst_stride_z; +#endif // defined(REINTERPRET_OUTPUT_AS_3D) + + // Multiply by the weight of matrix-matrix product and store the result +#if defined(ALPHA) + SCALE_BLOCK(4, half, c, ALPHA); +#endif // defined(ALPHA) + + // Add beta*bias +#if defined(BETA) + REPEAT_VAR_INIT_TO_CONST(4, uint, zero, 0); + +#if defined(BROADCAST_BIAS) + __global uchar *src2_addr = src2_ptr + src2_offset_first_element_in_bytes + (get_global_id(0) * (uint)8 * sizeof(half)); + + LOAD_BLOCK(1, 8, half, bias, src2_addr, 0, src2_stride_y, zero); + +#ifndef UNIT_BETA + SCALE_BLOCK(1, half, bias, BETA); +#endif // UNIT_BIAS + + // c = c + bias[broadcasted] + ADD_BLOCK_BROADCAST(4, c, bias0); + +#else // defined(BROADCAST_BIAS) + __global uchar *src2_addr = src2_ptr + src2_offset_first_element_in_bytes + (get_global_id(0) * (uint)8 * sizeof(half)) + (get_global_id(1) * (uint)4 * src2_stride_y) + get_global_id( + 2) * src2_stride_z; + + LOAD_BLOCK(4, 8, half, bias, src2_addr, 0, src2_stride_y, zero); + +#ifndef UNIT_BETA + SCALE_BLOCK(4, half, bias, BETA); +#endif // UNIT_BIAS + + // c = c + bias + ADD_BLOCK(4, c, bias); + +#endif // defined(BROADCAST_BIAS) +#endif // defined(BETA) + +#if defined(ACTIVATION_TYPE) + ACTIVATION_BLOCK(4, ACTIVATION_TYPE, half, c, A_VAL, B_VAL); +#endif // defined(ACTIVATION_TYPE) // Store 4x8 block - vstore8(c00, 0, (__global half *)(dst_addr + 0 * dst_stride_y)); - vstore8(c10, 0, (__global half *)(dst_addr + 1 * dst_stride_y)); - vstore8(c20, 0, (__global half *)(dst_addr + 2 * dst_stride_y)); - vstore8(c30, 0, (__global half *)(dst_addr + 3 * dst_stride_y)); -#endif // defined(REINTERPRET_OUTPUT_AS_3D) + vstore8(c0, 0, (__global half *)(dst_addr + 0 * dst_stride_y + zout.s0)); + vstore8(c1, 0, (__global half *)(dst_addr + 1 * dst_stride_y + zout.s1)); + vstore8(c2, 0, (__global half *)(dst_addr + 2 * dst_stride_y + zout.s2)); + vstore8(c3, 0, (__global half *)(dst_addr + 3 * dst_stride_y + zout.s3)); } // Undefine local defines @@ -3646,16 +3715,16 @@ __kernel void gemm_mm_interleaved_transposed_f16_bifrost(IMAGE_DECLARATION(src0) #if defined(DATA_TYPE) #define VECTOR_TYPE VEC_DATA_TYPE(DATA_TYPE, NUM_ELEMS_PROCESSED_PER_THREAD_X) /** This OpenCL kernel computes the matrix by matrix multiplication between the matrix A (src0) and matrix B (src1) in case both matrices have not been reshaped. - * - * Moreover, it can add a vector (src2) if the ADD_VEC_C parameter is passed at compile time. * * @note This OpenCL kernel works with floating point data types (F16/F32) * @note The floating point data type must be passed at compile time using -DDATA_TYPE (e.g. 
-DDATA_TYPE=float) * @note The number of elements processed along the x and y directions must be passed at compile time using -DNUM_ELEMS_PROCESSED_PER_THREAD_X and -DNUM_ELEMS_PROCESSED_PER_THREAD_Y * @note The number of matrix A columns and the optional alpha's value need to be passed at compile time using -DCOLS_A and -DALPHA - * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (i.e. -DMATRIX_B_DEPTH=16) - * This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (i.e. a = [K, M, 16, Batches], b = [N, K, 16]) + * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (e.g. -DMATRIX_B_DEPTH=16) + * This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (e.g. a = [K, M, 16, Batches], b = [N, K, 16]) * + * @note If the activation type were passed at compile time through -DACTIVATION_TYPE (e.g. -DACTIVATION_TYPE=RELU), A, B variables, required by some activation functions, should be passed at compile time as well using -DA_VAL= and -DB_VAL= respectively. + * The activation function is performed after the bias addition * @note In case the input or output have to be reinterpreted as a 3D tensor, the following information must be passed at compile time: * -# REINTERPRET_INPUT_AS_3D: To reinterpret the input as 3D * -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D @@ -3663,8 +3732,6 @@ __kernel void gemm_mm_interleaved_transposed_f16_bifrost(IMAGE_DECLARATION(src0) * -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor * (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns matrix A NOT reshaped * - * @note In case a 3rd input (src2) needs to be added, the ADD_VEC_C parameter has to be passed at compile time as -DADD_VEC_C - * * @param[in] src0_ptr Pointer to the source matrix. Supported data types: F16/F32 * @param[in] src0_stride_x Stride of the source matrix in X dimension (in bytes) * @param[in] src0_step_x src_stride_x * number of elements along X processed per workitem(in bytes) @@ -3677,10 +3744,12 @@ __kernel void gemm_mm_interleaved_transposed_f16_bifrost(IMAGE_DECLARATION(src0) * @param[in] src1_stride_y Stride of the source matrix in Y dimension (in bytes) * @param[in] src1_step_y src_stride_y * number of elements along Y processed per workitem(in bytes) * @param[in] src1_offset_first_element_in_bytes The offset of the first element in the source matrix - * @param[in] src2_ptr (Optional) Pointer to the source matrix. Supported data types: same as @p src0_ptr - * @param[in] src2_stride_x (Optional) Stride of the source vector in X dimension (in bytes) - * @param[in] src2_step_x (Optional) src_stride_x * number of elements along X processed per workitem(in bytes) - * @param[in] src2_offset_first_element_in_bytes (Optional) The offset of the first element in the source matrix + * @param[in] src2_ptr (Optional) Pointer to the bias matrix. 
Supported data type: same as @p lhs_ptr + * @param[in] src2_stride_x (Optional) Stride of the bias matrix in X dimension (in bytes) + * @param[in] src2_step_x (Optional) src2_stride_x * number of elements along X processed per workitem(in bytes) + * @param[in] src2_stride_y (Optional) Stride of the bias matrix in Y dimension (in bytes) + * @param[in] src2_step_y (Optional) src2_stride_y * number of elements along Y processed per workitem(in bytes) + * @param[in] src2_offset_first_element_in_bytes (Optional) The offset of the first element in the bias matrix * @param[out] dst_ptr Pointer to the destination matrix Supported data types: same as @p src0_ptr * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes) * @param[in] dst_step_x dst_gx_stride_x * number of elements along X processed per workitem(in bytes) @@ -3689,18 +3758,22 @@ __kernel void gemm_mm_interleaved_transposed_f16_bifrost(IMAGE_DECLARATION(src0) * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix * @param[in] src0_stride_z Stride of the source matrix in Z dimension (in bytes) * @param[in] src1_stride_z Stride of the source matrix in Z dimension (in bytes) + * @param[in] src2_stride_z (Optional) Stride of the bias matrix in Z dimension (in bytes) * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes) * @param[in] src_cross_plane_pad (Optional) Bottom paddings in unit of elements for the input tensor (only if defined REINTERPRET_INPUT_AS_3D) * @param[in] dst_cross_plane_pad (Optional) Bottom paddings in unit of elements for the output tensor (only if defined REINTERPRET_OUTPUT_AS_3D) */ __kernel void gemm_mm_floating_point(IMAGE_DECLARATION(src0), IMAGE_DECLARATION(src1), -#if defined(ADD_VEC_C) - VECTOR_DECLARATION(src2), -#endif /* defined(ADD_VEC_C) */ +#if defined(BETA) + IMAGE_DECLARATION(src2), +#endif // defined(BETA) IMAGE_DECLARATION(dst), uint src0_stride_z, uint src1_stride_z, +#if defined(BETA) + uint src2_stride_z, +#endif //defined(BETA) uint dst_stride_z #if defined(REINTERPRET_INPUT_AS_3D) , @@ -3865,49 +3938,18 @@ __kernel void gemm_mm_floating_point(IMAGE_DECLARATION(src0), #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 } + int z = get_global_id(2); + // Compute destination address Image dst = CONVERT_TO_IMAGE_STRUCT(dst); // Compute dst address __global uchar *dst_addr = offset(&dst, 0, 0); - // Multiply by the weight of matrix-matrix product and store the result -#if defined(ALPHA) - acc0 = acc0 * (VECTOR_TYPE)ALPHA; -#endif // defined(ALPHA) -#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 && defined(ALPHA) - acc1 = acc1 * (VECTOR_TYPE)ALPHA; -#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 && defined(ALPHA) -#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 && defined(ALPHA) - acc2 = acc2 * (VECTOR_TYPE)ALPHA; -#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 && defined(ALPHA) -#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 && defined(ALPHA) - acc3 = acc3 * (VECTOR_TYPE)ALPHA; -#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 && defined(ALPHA) - -#if defined(ADD_VEC_C) - // *INDENT-OFF* - // clang-format off - __global DATA_TYPE *src2_addr = (__global DATA_TYPE *)(src2_ptr + src2_offset_first_element_in_bytes + get_global_id(0) * src2_step_x); - VECTOR_TYPE c0 = VLOAD(NUM_ELEMS_PROCESSED_PER_THREAD_X)(0, src2_addr); - // clang-format on - // *INDENT-ON* - - acc0 += c0; -#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 - acc1 += c0; -#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 -#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 - acc2 
+= c0; -#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 -#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 - acc3 += c0; -#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 -#endif /* defined(ADD_VEC_C) */ - - int z = get_global_id(2); + uint4 zout = 0; #if defined(REINTERPRET_OUTPUT_AS_3D) + // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension // in order to take into account the presence of possible cross plane paddings // @@ -3924,8 +3966,8 @@ __kernel void gemm_mm_floating_point(IMAGE_DECLARATION(src0), // |__________________| // The plane (zout) is calculated dividing M (get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y) by HEIGHT_GEMM3D - uint4 zout = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y)) / (uint4)HEIGHT_GEMM3D; - zout = min(DEPTH_GEMM3D - 1, zout); + zout = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y)) / (uint4)HEIGHT_GEMM3D; + zout = min(DEPTH_GEMM3D - 1, zout); // Add offset due to the cross plane paddings zout *= (dst_cross_plane_pad * dst_stride_y); @@ -3933,44 +3975,69 @@ __kernel void gemm_mm_floating_point(IMAGE_DECLARATION(src0), // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we // multiply dst_stride_z by DEPTH_GEMM3D dst_addr += z * dst_stride_z * DEPTH_GEMM3D; - - // Store output block - STORE_BLOCK(NUM_ELEMS_PROCESSED_PER_THREAD_Y, NUM_ELEMS_PROCESSED_PER_THREAD_X, DATA_TYPE, acc, dst_addr, dst_stride_y, zout.s); -#else // defined(REINTERPRET_OUTPUT_AS_3D) +#else // defined(REINTERPRET_OUTPUT_AS_3D) // Add offset for batched GEMM dst_addr += z * dst_stride_z; +#endif // defined(REINTERPRET_OUTPUT_AS_3D) + + // Multiply by the weight of matrix-matrix product and store the result +#if defined(ALPHA) + SCALE_BLOCK(NUM_ELEMS_PROCESSED_PER_THREAD_Y, DATA_TYPE, acc, ALPHA); +#endif // defined(ALPHA) + + // Add beta*bias +#if defined(BETA) + REPEAT_VAR_INIT_TO_CONST(NUM_ELEMS_PROCESSED_PER_THREAD_Y, uint, zero, 0); + +#if defined(BROADCAST_BIAS) + __global uchar *src2_addr = src2_ptr + src2_offset_first_element_in_bytes + (get_global_id(0) * (uint)NUM_ELEMS_PROCESSED_PER_THREAD_X * sizeof(DATA_TYPE)); + + LOAD_BLOCK(1, NUM_ELEMS_PROCESSED_PER_THREAD_X, DATA_TYPE, bias, src2_addr, 0, src2_stride_y, zero); + +#ifndef UNIT_BETA + SCALE_BLOCK(1, DATA_TYPE, bias, BETA); +#endif // UNIT_BIAS + + // c = c + bias[broadcasted] + ADD_BLOCK_BROADCAST(NUM_ELEMS_PROCESSED_PER_THREAD_Y, acc, bias0); + +#else // defined(BROADCAST_BIAS) + __global uchar *src2_addr = src2_ptr + src2_offset_first_element_in_bytes + (get_global_id(0) * (uint)NUM_ELEMS_PROCESSED_PER_THREAD_X * sizeof(DATA_TYPE)) + (get_global_id(1) * + (uint)NUM_ELEMS_PROCESSED_PER_THREAD_Y * src2_stride_y) + get_global_id(2) * src2_stride_z; + + LOAD_BLOCK(NUM_ELEMS_PROCESSED_PER_THREAD_Y, NUM_ELEMS_PROCESSED_PER_THREAD_X, DATA_TYPE, bias, src2_addr, 0, src2_stride_y, zero); + +#ifndef UNIT_BETA + SCALE_BLOCK(NUM_ELEMS_PROCESSED_PER_THREAD_Y, DATA_TYPE, bias, BETA); +#endif // UNIT_BIAS + + // c = c + bias + ADD_BLOCK(NUM_ELEMS_PROCESSED_PER_THREAD_Y, acc, bias); + +#endif // defined(BROADCAST_BIAS) +#endif // defined(BETA) + +#if defined(ACTIVATION_TYPE) + ACTIVATION_BLOCK(NUM_ELEMS_PROCESSED_PER_THREAD_Y, ACTIVATION_TYPE, DATA_TYPE, acc, A_VAL, B_VAL); +#endif // defined(ACTIVATION_TYPE) // Store output block - VSTORE(NUM_ELEMS_PROCESSED_PER_THREAD_X) - (acc0, 0, (__global DATA_TYPE *)(dst_addr + 0 * dst_stride_y)); -#if 
NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 - VSTORE(NUM_ELEMS_PROCESSED_PER_THREAD_X) - (acc1, 0, (__global DATA_TYPE *)(dst_addr + 1 * dst_stride_y)); -#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 -#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 - VSTORE(NUM_ELEMS_PROCESSED_PER_THREAD_X) - (acc2, 0, (__global DATA_TYPE *)(dst_addr + 2 * dst_stride_y)); -#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 -#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 - VSTORE(NUM_ELEMS_PROCESSED_PER_THREAD_X) - (acc3, 0, (__global DATA_TYPE *)(dst_addr + 3 * dst_stride_y)); -#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 -#endif // defined(REINTERPRET_OUTPUT_AS_3D) + STORE_BLOCK(NUM_ELEMS_PROCESSED_PER_THREAD_Y, NUM_ELEMS_PROCESSED_PER_THREAD_X, DATA_TYPE, acc, dst_addr, dst_stride_y, zout.s); } #endif // defined(DATA_TYPE) /** This OpenCL kernel computes the matrix by matrix multiplication between the matrix A (src0) and matrix B (src1) in case both matrices have not been reshaped - * - * Moreover, it can add a vector (src2) if the ADD_VEC_C parameter is passed at compile time. * * @note This OpenCL kernel works with the 32-bit floating point data type (float) and uses the fma units. * @note The number of elements processed along the x and y directions must be passed at compile time using -DNUM_ELEMS_PROCESSED_PER_THREAD_X and -DNUM_ELEMS_PROCESSED_PER_THREAD_Y. * This kernel optimally uses -DNUM_ELEMS_PROCESSED_PER_THREAD_X=4. * @note The number of matrix A columns must be passed at compile time using -DCOLS_A. * @note The optional value of scalar alpha is passed at compile time using -DALPHA=alpha - * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (i.e. -DMATRIX_B_DEPTH=16) - * This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (i.e. a = [K, M, 16, Batches], b = [N, K, 16]) + * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (e.g. -DMATRIX_B_DEPTH=16) + * This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (e.g. a = [K, M, 16, Batches], b = [N, K, 16]) * + * @note If the activation type were passed at compile time through -DACTIVATION_TYPE (e.g. -DACTIVATION_TYPE=RELU), A, B variables, required by some activation functions, should be passed at compile time as well using -DA_VAL= and -DB_VAL= respectively. + * The activation function is performed after the bias addition * @note In case the input or output have to be reinterpreted as a 3D tensor, the following information must be passed at compile time: * -# REINTERPRET_INPUT_AS_3D: To reinterpret the input as 3D * -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D @@ -3978,9 +4045,7 @@ __kernel void gemm_mm_floating_point(IMAGE_DECLARATION(src0), * -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor * (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns matrix A NOT reshaped * - * @note In case a 3rd input (src2) needs to be added, the ADD_VEC_C parameter has to be passed at compile time as -DADD_VEC_C - * - * @param[in] src0_ptr Pointer to the source matrix. 
Supported data types: F16/F32 + * @param[in] src0_ptr Pointer to the source matrix. Supported data types: F32 * @param[in] src0_stride_x Stride of the source matrix in X dimension (in bytes) * @param[in] src0_step_x src_stride_x * number of elements along X processed per workitem(in bytes) * @param[in] src0_stride_y Stride of the source matrix in Y dimension (in bytes) @@ -3992,10 +4057,12 @@ __kernel void gemm_mm_floating_point(IMAGE_DECLARATION(src0), * @param[in] src1_stride_y Stride of the source matrix in Y dimension (in bytes) * @param[in] src1_step_y src_stride_y * number of elements along Y processed per workitem(in bytes) * @param[in] src1_offset_first_element_in_bytes The offset of the first element in the source matrix - * @param[in] src2_ptr (Optional) Pointer to the source matrix. Supported data types: same as @p src0_ptr - * @param[in] src2_stride_x (Optional) Stride of the source vector in X dimension (in bytes) - * @param[in] src2_step_x (Optional) src_stride_x * number of elements along X processed per workitem(in bytes) - * @param[in] src2_offset_first_element_in_bytes (Optional) The offset of the first element in the source matrix + * @param[in] src2_ptr (Optional) Pointer to the bias matrix. Supported data type: same as @p lhs_ptr + * @param[in] src2_stride_x (Optional) Stride of the bias matrix in X dimension (in bytes) + * @param[in] src2_step_x (Optional) src2_stride_x * number of elements along X processed per workitem(in bytes) + * @param[in] src2_stride_y (Optional) Stride of the bias matrix in Y dimension (in bytes) + * @param[in] src2_step_y (Optional) src2_stride_y * number of elements along Y processed per workitem(in bytes) + * @param[in] src2_offset_first_element_in_bytes (Optional) The offset of the first element in the bias matrix * @param[out] dst_ptr Pointer to the destination matrix Supported data types: same as @p src0_ptr * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes) * @param[in] dst_step_x dst_gx_stride_x * number of elements along X processed per workitem(in bytes) @@ -4004,18 +4071,22 @@ __kernel void gemm_mm_floating_point(IMAGE_DECLARATION(src0), * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix * @param[in] src0_stride_z Stride of the source matrix in Z dimension (in bytes) * @param[in] src1_stride_z Stride of the source matrix in Z dimension (in bytes) + * @param[in] src2_stride_z (Optional) Stride of the bias matrix in Z dimension (in bytes) * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes) * @param[in] src_cross_plane_pad (Optional) Bottom paddings in unit of elements for the input tensor (only if defined REINTERPRET_INPUT_AS_3D) * @param[in] dst_cross_plane_pad (Optional) Bottom paddings in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D) */ __kernel void gemm_mm_floating_point_f32_bifrost(IMAGE_DECLARATION(src0), IMAGE_DECLARATION(src1), -#if defined(ADD_VEC_C) - VECTOR_DECLARATION(src2), -#endif /* defined(ADD_VEC_C) */ +#if defined(BETA) + IMAGE_DECLARATION(src2), +#endif // defined(BETA) IMAGE_DECLARATION(dst), uint src0_stride_z, uint src1_stride_z, +#if defined(BETA) + uint src2_stride_z, +#endif //defined(BETA) uint dst_stride_z #if defined(REINTERPRET_INPUT_AS_3D) , @@ -4080,30 +4151,18 @@ __kernel void gemm_mm_floating_point_f32_bifrost(IMAGE_DECLARATION(src0), #endif // defined(MATRIX_B_DEPTH) // Initialize accumulators - float acc00 = 0.0f; - float acc01 = 0.0f; - float acc02 = 0.0f; - float 
acc03 = 0.0f; + float4 acc0 = 0.0f; #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 - float acc10 = 0.0f; - float acc11 = 0.0f; - float acc12 = 0.0f; - float acc13 = 0.0f; + float4 acc1 = 0.0f; #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 - float acc20 = 0.0f; - float acc21 = 0.0f; - float acc22 = 0.0f; - float acc23 = 0.0f; + float4 acc2 = 0.0f; #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 - float acc30 = 0.0f; - float acc31 = 0.0f; - float acc32 = 0.0f; - float acc33 = 0.0f; + float4 acc3 = 0.0f; #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 // A and B src indices get incremented at the same time. @@ -4131,33 +4190,33 @@ __kernel void gemm_mm_floating_point_f32_bifrost(IMAGE_DECLARATION(src0), src_addr.s1 += src1_stride_y; // Multiply and accumulate - acc00 = fma(a0.s0, b0.s0, acc00); - acc01 = fma(a0.s0, b0.s1, acc01); - acc02 = fma(a0.s0, b0.s2, acc02); - acc03 = fma(a0.s0, b0.s3, acc03); + acc0.s0 = fma(a0.s0, b0.s0, acc0.s0); + acc0.s1 = fma(a0.s0, b0.s1, acc0.s1); + acc0.s2 = fma(a0.s0, b0.s2, acc0.s2); + acc0.s3 = fma(a0.s0, b0.s3, acc0.s3); #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 - acc10 = fma(a1.s0, b0.s0, acc10); - acc11 = fma(a1.s0, b0.s1, acc11); - acc12 = fma(a1.s0, b0.s2, acc12); - acc13 = fma(a1.s0, b0.s3, acc13); + acc1.s0 = fma(a1.s0, b0.s0, acc1.s0); + acc1.s1 = fma(a1.s0, b0.s1, acc1.s1); + acc1.s2 = fma(a1.s0, b0.s2, acc1.s2); + acc1.s3 = fma(a1.s0, b0.s3, acc1.s3); #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 - acc20 = fma(a2.s0, b0.s0, acc20); - acc21 = fma(a2.s0, b0.s1, acc21); - acc22 = fma(a2.s0, b0.s2, acc22); - acc23 = fma(a2.s0, b0.s3, acc23); + acc2.s0 = fma(a2.s0, b0.s0, acc2.s0); + acc2.s1 = fma(a2.s0, b0.s1, acc2.s1); + acc2.s2 = fma(a2.s0, b0.s2, acc2.s2); + acc2.s3 = fma(a2.s0, b0.s3, acc2.s3); #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 - acc30 = fma(a3.s0, b0.s0, acc30); - acc31 = fma(a3.s0, b0.s1, acc31); - acc32 = fma(a3.s0, b0.s2, acc32); - acc33 = fma(a3.s0, b0.s3, acc33); + acc3.s0 = fma(a3.s0, b0.s0, acc3.s0); + acc3.s1 = fma(a3.s0, b0.s1, acc3.s1); + acc3.s2 = fma(a3.s0, b0.s2, acc3.s2); + acc3.s3 = fma(a3.s0, b0.s3, acc3.s3); #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 // Load values from matrix A and matrix B @@ -4165,33 +4224,33 @@ __kernel void gemm_mm_floating_point_f32_bifrost(IMAGE_DECLARATION(src0), src_addr.s1 += src1_stride_y; // Multiply and accumulate - acc00 = fma(a0.s1, b0.s0, acc00); - acc01 = fma(a0.s1, b0.s1, acc01); - acc02 = fma(a0.s1, b0.s2, acc02); - acc03 = fma(a0.s1, b0.s3, acc03); + acc0.s0 = fma(a0.s1, b0.s0, acc0.s0); + acc0.s1 = fma(a0.s1, b0.s1, acc0.s1); + acc0.s2 = fma(a0.s1, b0.s2, acc0.s2); + acc0.s3 = fma(a0.s1, b0.s3, acc0.s3); #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 - acc10 = fma(a1.s1, b0.s0, acc10); - acc11 = fma(a1.s1, b0.s1, acc11); - acc12 = fma(a1.s1, b0.s2, acc12); - acc13 = fma(a1.s1, b0.s3, acc13); + acc1.s0 = fma(a1.s1, b0.s0, acc1.s0); + acc1.s1 = fma(a1.s1, b0.s1, acc1.s1); + acc1.s2 = fma(a1.s1, b0.s2, acc1.s2); + acc1.s3 = fma(a1.s1, b0.s3, acc1.s3); #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 - acc20 = fma(a2.s1, b0.s0, acc20); - acc21 = fma(a2.s1, b0.s1, acc21); - acc22 = fma(a2.s1, b0.s2, acc22); - acc23 = fma(a2.s1, b0.s3, acc23); + acc2.s0 = fma(a2.s1, b0.s0, acc2.s0); + acc2.s1 = fma(a2.s1, b0.s1, acc2.s1); + acc2.s2 = fma(a2.s1, b0.s2, acc2.s2); + acc2.s3 = fma(a2.s1, b0.s3, 
acc2.s3); #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 - acc30 = fma(a3.s1, b0.s0, acc30); - acc31 = fma(a3.s1, b0.s1, acc31); - acc32 = fma(a3.s1, b0.s2, acc32); - acc33 = fma(a3.s1, b0.s3, acc33); + acc3.s0 = fma(a3.s1, b0.s0, acc3.s0); + acc3.s1 = fma(a3.s1, b0.s1, acc3.s1); + acc3.s2 = fma(a3.s1, b0.s2, acc3.s2); + acc3.s3 = fma(a3.s1, b0.s3, acc3.s3); #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 // Load values from matrix A and matrix B @@ -4199,33 +4258,33 @@ __kernel void gemm_mm_floating_point_f32_bifrost(IMAGE_DECLARATION(src0), src_addr.s1 += src1_stride_y; // Multiply and accumulate - acc00 = fma(a0.s2, b0.s0, acc00); - acc01 = fma(a0.s2, b0.s1, acc01); - acc02 = fma(a0.s2, b0.s2, acc02); - acc03 = fma(a0.s2, b0.s3, acc03); + acc0.s0 = fma(a0.s2, b0.s0, acc0.s0); + acc0.s1 = fma(a0.s2, b0.s1, acc0.s1); + acc0.s2 = fma(a0.s2, b0.s2, acc0.s2); + acc0.s3 = fma(a0.s2, b0.s3, acc0.s3); #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 - acc10 = fma(a1.s2, b0.s0, acc10); - acc11 = fma(a1.s2, b0.s1, acc11); - acc12 = fma(a1.s2, b0.s2, acc12); - acc13 = fma(a1.s2, b0.s3, acc13); + acc1.s0 = fma(a1.s2, b0.s0, acc1.s0); + acc1.s1 = fma(a1.s2, b0.s1, acc1.s1); + acc1.s2 = fma(a1.s2, b0.s2, acc1.s2); + acc1.s3 = fma(a1.s2, b0.s3, acc1.s3); #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 - acc20 = fma(a2.s2, b0.s0, acc20); - acc21 = fma(a2.s2, b0.s1, acc21); - acc22 = fma(a2.s2, b0.s2, acc22); - acc23 = fma(a2.s2, b0.s3, acc23); + acc2.s0 = fma(a2.s2, b0.s0, acc2.s0); + acc2.s1 = fma(a2.s2, b0.s1, acc2.s1); + acc2.s2 = fma(a2.s2, b0.s2, acc2.s2); + acc2.s3 = fma(a2.s2, b0.s3, acc2.s3); #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 - acc30 = fma(a3.s2, b0.s0, acc30); - acc31 = fma(a3.s2, b0.s1, acc31); - acc32 = fma(a3.s2, b0.s2, acc32); - acc33 = fma(a3.s2, b0.s3, acc33); + acc3.s0 = fma(a3.s2, b0.s0, acc3.s0); + acc3.s1 = fma(a3.s2, b0.s1, acc3.s1); + acc3.s2 = fma(a3.s2, b0.s2, acc3.s2); + acc3.s3 = fma(a3.s2, b0.s3, acc3.s3); #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 // Load values from matrix A and matrix B @@ -4233,33 +4292,33 @@ __kernel void gemm_mm_floating_point_f32_bifrost(IMAGE_DECLARATION(src0), src_addr.s1 += src1_stride_y; // Multiply and accumulate - acc00 = fma(a0.s3, b0.s0, acc00); - acc01 = fma(a0.s3, b0.s1, acc01); - acc02 = fma(a0.s3, b0.s2, acc02); - acc03 = fma(a0.s3, b0.s3, acc03); + acc0.s0 = fma(a0.s3, b0.s0, acc0.s0); + acc0.s1 = fma(a0.s3, b0.s1, acc0.s1); + acc0.s2 = fma(a0.s3, b0.s2, acc0.s2); + acc0.s3 = fma(a0.s3, b0.s3, acc0.s3); #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 - acc10 = fma(a1.s3, b0.s0, acc10); - acc11 = fma(a1.s3, b0.s1, acc11); - acc12 = fma(a1.s3, b0.s2, acc12); - acc13 = fma(a1.s3, b0.s3, acc13); + acc1.s0 = fma(a1.s3, b0.s0, acc1.s0); + acc1.s1 = fma(a1.s3, b0.s1, acc1.s1); + acc1.s2 = fma(a1.s3, b0.s2, acc1.s2); + acc1.s3 = fma(a1.s3, b0.s3, acc1.s3); #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 - acc20 = fma(a2.s3, b0.s0, acc20); - acc21 = fma(a2.s3, b0.s1, acc21); - acc22 = fma(a2.s3, b0.s2, acc22); - acc23 = fma(a2.s3, b0.s3, acc23); + acc2.s0 = fma(a2.s3, b0.s0, acc2.s0); + acc2.s1 = fma(a2.s3, b0.s1, acc2.s1); + acc2.s2 = fma(a2.s3, b0.s2, acc2.s2); + acc2.s3 = fma(a2.s3, b0.s3, acc2.s3); #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 - acc30 = fma(a3.s3, b0.s0, acc30); - acc31 = fma(a3.s3, b0.s1, acc31); - acc32 = fma(a3.s3, b0.s2, 
acc32); - acc33 = fma(a3.s3, b0.s3, acc33); + acc3.s0 = fma(a3.s3, b0.s0, acc3.s0); + acc3.s1 = fma(a3.s3, b0.s1, acc3.s1); + acc3.s2 = fma(a3.s3, b0.s2, acc3.s2); + acc3.s3 = fma(a3.s3, b0.s3, acc3.s3); #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 src_addr.s0 += 4 * sizeof(float); @@ -4298,27 +4357,27 @@ __kernel void gemm_mm_floating_point_f32_bifrost(IMAGE_DECLARATION(src0), src_addr.s1 += src1_stride_y; // Multiply and accumulate - acc00 = fma(a0, b0.s0, acc00); - acc01 = fma(a0, b0.s1, acc01); - acc02 = fma(a0, b0.s2, acc02); - acc03 = fma(a0, b0.s3, acc03); + acc0.s0 = fma(a0, b0.s0, acc0.s0); + acc0.s1 = fma(a0, b0.s1, acc0.s1); + acc0.s2 = fma(a0, b0.s2, acc0.s2); + acc0.s3 = fma(a0, b0.s3, acc0.s3); #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 - acc10 = fma(a1, b0.s0, acc10); - acc11 = fma(a1, b0.s1, acc11); - acc12 = fma(a1, b0.s2, acc12); - acc13 = fma(a1, b0.s3, acc13); + acc1.s0 = fma(a1, b0.s0, acc1.s0); + acc1.s1 = fma(a1, b0.s1, acc1.s1); + acc1.s2 = fma(a1, b0.s2, acc1.s2); + acc1.s3 = fma(a1, b0.s3, acc1.s3); #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 - acc20 = fma(a2, b0.s0, acc20); - acc21 = fma(a2, b0.s1, acc21); - acc22 = fma(a2, b0.s2, acc22); - acc23 = fma(a2, b0.s3, acc23); + acc2.s0 = fma(a2, b0.s0, acc2.s0); + acc2.s1 = fma(a2, b0.s1, acc2.s1); + acc2.s2 = fma(a2, b0.s2, acc2.s2); + acc2.s3 = fma(a2, b0.s3, acc2.s3); #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 - acc30 = fma(a3, b0.s0, acc30); - acc31 = fma(a3, b0.s1, acc31); - acc32 = fma(a3, b0.s2, acc32); - acc33 = fma(a3, b0.s3, acc33); + acc3.s0 = fma(a3, b0.s0, acc3.s0); + acc3.s1 = fma(a3, b0.s1, acc3.s1); + acc3.s2 = fma(a3, b0.s2, acc3.s2); + acc3.s3 = fma(a3, b0.s3, acc3.s3); #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 src_addr.s0 += sizeof(float); @@ -4329,62 +4388,10 @@ __kernel void gemm_mm_floating_point_f32_bifrost(IMAGE_DECLARATION(src0), // Compute destination address Image dst = CONVERT_TO_IMAGE_STRUCT(dst); - // Multiply by the weight of matrix-matrix product and store the result -#if defined(ALPHA) - acc00 = acc00 * ALPHA; - acc01 = acc01 * ALPHA; - acc02 = acc02 * ALPHA; - acc03 = acc03 * ALPHA; -#endif // defined(ALPHA) -#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 && defined(ALPHA) - acc10 = acc10 * ALPHA; - acc11 = acc11 * ALPHA; - acc12 = acc12 * ALPHA; - acc13 = acc13 * ALPHA; -#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 && defined(ALPHA) -#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 && defined(ALPHA) - acc20 = acc20 * ALPHA; - acc21 = acc21 * ALPHA; - acc22 = acc22 * ALPHA; - acc23 = acc23 * ALPHA; -#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 && defined(ALPHA) -#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 && defined(ALPHA) - acc30 = acc30 * ALPHA; - acc31 = acc31 * ALPHA; - acc32 = acc32 * ALPHA; - acc33 = acc33 * ALPHA; -#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 && defined(ALPHA) - // Compute dst address __global uchar *dst_addr = offset(&dst, 0, 0); -#if defined(ADD_VEC_C) - __global float *src2_addr = (__global float *)(src2_ptr + src2_offset_first_element_in_bytes + get_global_id(0) * src2_step_x); - float4 c0 = vload4(0, src2_addr); - - acc00 += c0.s0; - acc01 += c0.s1; - acc02 += c0.s2; - acc03 += c0.s3; -#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 - acc10 += c0.s0; - acc11 += c0.s1; - acc12 += c0.s2; - acc13 += c0.s3; -#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 -#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 - acc20 += c0.s0; - acc21 += c0.s1; - acc22 += c0.s2; - acc23 += c0.s3; -#endif // 
NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 -#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 - acc30 += c0.s0; - acc31 += c0.s1; - acc32 += c0.s2; - acc33 += c0.s3; -#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 -#endif /* defined(ADD_VEC_C) */ + uint4 zout = 0; #if defined(REINTERPRET_OUTPUT_AS_3D) // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension @@ -4403,8 +4410,8 @@ __kernel void gemm_mm_floating_point_f32_bifrost(IMAGE_DECLARATION(src0), // |__________________| // The plane (zout) is calculated dividing M (get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y) by HEIGHT_GEMM3D - uint4 zout = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y)) / (uint4)HEIGHT_GEMM3D; - zout = min(DEPTH_GEMM3D - 1, zout); + zout = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y)) / (uint4)HEIGHT_GEMM3D; + zout = min(DEPTH_GEMM3D - 1, zout); // Add offset due to the cross plane paddings zout *= (dst_cross_plane_pad * dst_stride_y); @@ -4412,40 +4419,66 @@ __kernel void gemm_mm_floating_point_f32_bifrost(IMAGE_DECLARATION(src0), // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we // multiply dst_stride_z by DEPTH_GEMM3D dst_addr += z * dst_stride_z * DEPTH_GEMM3D; - - // Store the output block - vstore4((float4)(acc00, acc01, acc02, acc03), 0, (__global float *)(dst_addr + 0 * dst_stride_y + zout.s0)); -#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 - vstore4((float4)(acc10, acc11, acc12, acc13), 0, (__global float *)(dst_addr + 1 * dst_stride_y + zout.s1)); -#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 -#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 - vstore4((float4)(acc20, acc21, acc22, acc23), 0, (__global float *)(dst_addr + 2 * dst_stride_y + zout.s2)); -#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 -#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 - vstore4((float4)(acc30, acc31, acc32, acc33), 0, (__global float *)(dst_addr + 3 * dst_stride_y + zout.s3)); -#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 - -#else // defined(REINTERPRET_OUTPUT_AS_3D) +#else // defined(REINTERPRET_OUTPUT_AS_3D) // Add offset for batched GEMM dst_addr += z * dst_stride_z; +#endif // defined(REINTERPRET_OUTPUT_AS_3D) + + // Multiply by the weight of matrix-matrix product and store the result +#if defined(ALPHA) + SCALE_BLOCK(NUM_ELEMS_PROCESSED_PER_THREAD_Y, float, acc, ALPHA); +#endif // defined(ALPHA) + + // Add beta*bias +#if defined(BETA) + REPEAT_VAR_INIT_TO_CONST(NUM_ELEMS_PROCESSED_PER_THREAD_Y, uint, zero, 0); + +#if defined(BROADCAST_BIAS) + __global uchar *src2_addr = src2_ptr + src2_offset_first_element_in_bytes + (get_global_id(0) * (uint)4 * sizeof(float)); + + LOAD_BLOCK(1, 4, float, bias, src2_addr, 0, src2_stride_y, zero); + +#ifndef UNIT_BETA + SCALE_BLOCK(1, float, bias, BETA); +#endif // UNIT_BIAS + + // acc = acc + bias[broadcasted] + ADD_BLOCK_BROADCAST(NUM_ELEMS_PROCESSED_PER_THREAD_Y, acc, bias0); + +#else // defined(BROADCAST_BIAS) + __global uchar *src2_addr = src2_ptr + src2_offset_first_element_in_bytes + (get_global_id(0) * (uint)4 * sizeof(float)) + (get_global_id(1) * + (uint)NUM_ELEMS_PROCESSED_PER_THREAD_Y * src2_stride_y) + get_global_id(2) * src2_stride_z; + + LOAD_BLOCK(NUM_ELEMS_PROCESSED_PER_THREAD_Y, 4, float, bias, src2_addr, 0, src2_stride_y, zero); + +#ifndef UNIT_BETA + SCALE_BLOCK(NUM_ELEMS_PROCESSED_PER_THREAD_Y, float, bias, BETA); +#endif // UNIT_BIAS + + // acc = acc + bias + ADD_BLOCK(NUM_ELEMS_PROCESSED_PER_THREAD_Y, acc, 
bias); + +#endif // defined(BROADCAST_BIAS) +#endif // defined(BETA) + +#if defined(ACTIVATION_TYPE) + ACTIVATION_BLOCK(NUM_ELEMS_PROCESSED_PER_THREAD_Y, ACTIVATION_TYPE, float, acc, A_VAL, B_VAL); +#endif // defined(ACTIVATION_TYPE) // Store the output block - vstore4((float4)(acc00, acc01, acc02, acc03), 0, (__global float *)(dst_addr + 0 * dst_stride_y)); + vstore4(acc0, 0, (__global float *)(dst_addr + 0 * dst_stride_y + zout.s0)); #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 - vstore4((float4)(acc10, acc11, acc12, acc13), 0, (__global float *)(dst_addr + 1 * dst_stride_y)); + vstore4(acc1, 0, (__global float *)(dst_addr + 1 * dst_stride_y + zout.s1)); #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 - vstore4((float4)(acc20, acc21, acc22, acc23), 0, (__global float *)(dst_addr + 2 * dst_stride_y)); + vstore4(acc2, 0, (__global float *)(dst_addr + 2 * dst_stride_y + zout.s2)); #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 - vstore4((float4)(acc30, acc31, acc32, acc33), 0, (__global float *)(dst_addr + 3 * dst_stride_y)); + vstore4(acc3, 0, (__global float *)(dst_addr + 3 * dst_stride_y + zout.s3)); #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 -#endif // defined(REINTERPRET_OUTPUT_AS_3D) } /** This OpenCL kernel computes the matrix by matrix multiplication between the matrix A (src0) and matrix B (src1) in case both matrices have not been reshaped - * - * Moreover, it can add a vector (src2) if the ADD_VEC_C parameter is passed at compile time. * * @note This OpenCL kernel works with the 32-bit floating point data type (float) and uses the fma units. * This OpenCL kernel is optimized for Bifrost when the number of matrix B columns is less or equal to 1000. @@ -4453,9 +4486,11 @@ __kernel void gemm_mm_floating_point_f32_bifrost(IMAGE_DECLARATION(src0), * This kernel optimally uses -DNUM_ELEMS_PROCESSED_PER_THREAD_X=2. * @note The number of matrix A columns must be passed at compile time using -DCOLS_A. * @note The optional value of scalar alpha is passed at compile time using -DALPHA=alpha if alpha!=1.0f. - * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (i.e. -DMATRIX_B_DEPTH=16) - * This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (i.e. a = [K, M, 16, Batches], b = [N, K, 16]) + * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (e.g. -DMATRIX_B_DEPTH=16) + * This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (e.g. a = [K, M, 16, Batches], b = [N, K, 16]) * + * @note If the activation type were passed at compile time through -DACTIVATION_TYPE (e.g. -DACTIVATION_TYPE=RELU), A, B variables, required by some activation functions, should be passed at compile time as well using -DA_VAL= and -DB_VAL= respectively. 
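For reference, the sequence of *_BLOCK macros introduced above (SCALE_BLOCK, LOAD_BLOCK, ADD_BLOCK / ADD_BLOCK_BROADCAST, ACTIVATION_BLOCK) implements out = act(alpha * A*B + beta * bias) on each register tile. Below is a rough scalar sketch of that fused epilogue in plain C, assuming a RELU-style activation purely for illustration; the function and parameter names are not part of the library.

#include <stddef.h>

/* Scalar reference of the fused GEMM epilogue: acc holds the m x n tile of
 * A*B computed above; bias is either a single row (broadcast_bias != 0) or a
 * full m x n matrix. Sketch only - the kernel performs the same steps per
 * register tile through the *_BLOCK macros. */
static float act_relu(float x)
{
    return x > 0.0f ? x : 0.0f; /* A_VAL/B_VAL would parameterise bounded variants */
}

void gemm_fused_epilogue_ref(float *acc, size_t m, size_t n,
                             const float *bias, int broadcast_bias,
                             float alpha, float beta)
{
    for(size_t y = 0; y < m; ++y)
    {
        for(size_t x = 0; x < n; ++x)
        {
            float v = alpha * acc[y * n + x];               /* SCALE_BLOCK(..., ALPHA) */
            v += beta * (broadcast_bias ? bias[x]           /* ADD_BLOCK_BROADCAST     */
                                        : bias[y * n + x]); /* ADD_BLOCK               */
            acc[y * n + x] = act_relu(v);                   /* ACTIVATION_BLOCK        */
        }
    }
}

When beta is 1 the kernels are built with UNIT_BETA and skip the bias scaling entirely; likewise, when alpha is 1 the ALPHA define (and hence SCALE_BLOCK on the accumulators) is dropped.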
+ * The activation function is performed after the bias addition * @note In case the input or output have to be reinterpreted as a 3D tensor, the following information must be passed at compile time: * -# REINTERPRET_INPUT_AS_3D: To reinterpret the input as 3D * -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D @@ -4463,9 +4498,7 @@ __kernel void gemm_mm_floating_point_f32_bifrost(IMAGE_DECLARATION(src0), * -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor * (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns matrix A NOT reshaped * - * @note In case a 3rd input (src2) needs to be added, the ADD_VEC_C parameter has to be passed at compile time as -DADD_VEC_C - * - * @param[in] src0_ptr Pointer to the source matrix. Supported data types: F16/F32 + * @param[in] src0_ptr Pointer to the source matrix. Supported data types: F32 * @param[in] src0_stride_x Stride of the source matrix in X dimension (in bytes) * @param[in] src0_step_x src_stride_x * number of elements along X processed per workitem(in bytes) * @param[in] src0_stride_y Stride of the source matrix in Y dimension (in bytes) @@ -4477,10 +4510,12 @@ __kernel void gemm_mm_floating_point_f32_bifrost(IMAGE_DECLARATION(src0), * @param[in] src1_stride_y Stride of the source matrix in Y dimension (in bytes) * @param[in] src1_step_y src_stride_y * number of elements along Y processed per workitem(in bytes) * @param[in] src1_offset_first_element_in_bytes The offset of the first element in the source matrix - * @param[in] src2_ptr (Optional) Pointer to the source matrix. Supported data types: same as @p src0_ptr - * @param[in] src2_stride_x (Optional) Stride of the source vector in X dimension (in bytes) - * @param[in] src2_step_x (Optional) src_stride_x * number of elements along X processed per workitem(in bytes) - * @param[in] src2_offset_first_element_in_bytes (Optional) The offset of the first element in the source matrix + * @param[in] src2_ptr (Optional) Pointer to the bias matrix. 
Supported data type: same as @p lhs_ptr + * @param[in] src2_stride_x (Optional) Stride of the bias matrix in X dimension (in bytes) + * @param[in] src2_step_x (Optional) src2_stride_x * number of elements along X processed per workitem(in bytes) + * @param[in] src2_stride_y (Optional) Stride of the bias matrix in Y dimension (in bytes) + * @param[in] src2_step_y (Optional) src2_stride_y * number of elements along Y processed per workitem(in bytes) + * @param[in] src2_offset_first_element_in_bytes (Optional) The offset of the first element in the bias matrix * @param[out] dst_ptr Pointer to the destination matrix Supported data types: same as @p src0_ptr * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes) * @param[in] dst_step_x dst_gx_stride_x * number of elements along X processed per workitem(in bytes) @@ -4489,18 +4524,22 @@ __kernel void gemm_mm_floating_point_f32_bifrost(IMAGE_DECLARATION(src0), * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix * @param[in] src0_stride_z Stride of the source matrix in Z dimension (in bytes) * @param[in] src1_stride_z Stride of the source matrix in Z dimension (in bytes) + * @param[in] src2_stride_z (Optional) Stride of the bias matrix in Z dimension (in bytes) * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes) * @param[in] src_cross_plane_pad (Optional) Bottom paddings in unit of elements for the input tensor (only if defined REINTERPRET_INPUT_AS_3D) * @param[in] dst_cross_plane_pad (Optional) Bottom paddings in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D) */ __kernel void gemm_mm_floating_point_f32_bifrost_1000(IMAGE_DECLARATION(src0), IMAGE_DECLARATION(src1), -#if defined(ADD_VEC_C) - VECTOR_DECLARATION(src2), -#endif /* defined(ADD_VEC_C) */ +#if defined(BETA) + IMAGE_DECLARATION(src2), +#endif // defined(BETA) IMAGE_DECLARATION(dst), uint src0_stride_z, uint src1_stride_z, +#if defined(BETA) + uint src2_stride_z, +#endif //defined(BETA) uint dst_stride_z #if defined(REINTERPRET_INPUT_AS_3D) , @@ -4566,20 +4605,15 @@ __kernel void gemm_mm_floating_point_f32_bifrost_1000(IMAGE_DECLARATION(src0), #endif // defined(MATRIX_B_DEPTH) // Initialize accumulators - float acc00 = 0.0f; - float acc01 = 0.0f; - + float2 acc0 = 0.0f; #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 - float acc10 = 0.0f; - float acc11 = 0.0f; + float2 acc1 = 0.0f; #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 - float acc20 = 0.0f; - float acc21 = 0.0f; + float2 acc2 = 0.0f; #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 - float acc30 = 0.0f; - float acc31 = 0.0f; + float2 acc3 = 0.0f; #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 // A and B src indices get incremented at the same time. 
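The (Optional) src2_* arguments documented above are only compiled in when the kernel is built with -DBETA; that switch is also what turns the bias from the old 1D vector into a full image with Y and Z strides. The sketch below shows, in plain C, how each work-item derives its bias address from those strides, with gid0/gid1/gid2 standing for get_global_id(0/1/2); the helper name and parameters are illustrative, not library API.

#include <stddef.h>
#include <stdint.h>

/* Byte offset (relative to src2_offset_first_element_in_bytes) of the bias
 * block read by one work-item. cols_per_thread is 2 for this *_1000 kernel,
 * 4 for the generic F32 kernel and 8 for the F16 kernels; elem_size is
 * sizeof(float) or sizeof(half) accordingly. */
size_t bias_block_offset(uint32_t gid0, uint32_t gid1, uint32_t gid2,
                         uint32_t rows_per_thread, uint32_t cols_per_thread,
                         size_t elem_size, size_t stride_y, size_t stride_z,
                         int broadcast_bias)
{
    size_t offset = (size_t)gid0 * cols_per_thread * elem_size; /* column of the output tile */

    if(!broadcast_bias)
    {
        offset += (size_t)gid1 * rows_per_thread * stride_y     /* row of the output tile    */
                + (size_t)gid2 * stride_z;                      /* batch / Z slice           */
    }
    return offset;
}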
@@ -4613,95 +4647,95 @@ __kernel void gemm_mm_floating_point_f32_bifrost_1000(IMAGE_DECLARATION(src0), src_addr.s1 += src1_stride_y; // Multiply and accumulate - acc00 = fma(a0.s0, b0.s0, acc00); - acc00 = fma(a0.s1, b1.s0, acc00); - acc00 = fma(a0.s2, b2.s0, acc00); - acc00 = fma(a0.s3, b3.s0, acc00); - acc00 = fma(a0.s4, b4.s0, acc00); - acc00 = fma(a0.s5, b5.s0, acc00); - acc00 = fma(a0.s6, b6.s0, acc00); - acc00 = fma(a0.s7, b7.s0, acc00); - - acc01 = fma(a0.s0, b0.s1, acc01); - acc01 = fma(a0.s1, b1.s1, acc01); - acc01 = fma(a0.s2, b2.s1, acc01); - acc01 = fma(a0.s3, b3.s1, acc01); - acc01 = fma(a0.s4, b4.s1, acc01); - acc01 = fma(a0.s5, b5.s1, acc01); - acc01 = fma(a0.s6, b6.s1, acc01); - acc01 = fma(a0.s7, b7.s1, acc01); + acc0.s0 = fma(a0.s0, b0.s0, acc0.s0); + acc0.s0 = fma(a0.s1, b1.s0, acc0.s0); + acc0.s0 = fma(a0.s2, b2.s0, acc0.s0); + acc0.s0 = fma(a0.s3, b3.s0, acc0.s0); + acc0.s0 = fma(a0.s4, b4.s0, acc0.s0); + acc0.s0 = fma(a0.s5, b5.s0, acc0.s0); + acc0.s0 = fma(a0.s6, b6.s0, acc0.s0); + acc0.s0 = fma(a0.s7, b7.s0, acc0.s0); + + acc0.s1 = fma(a0.s0, b0.s1, acc0.s1); + acc0.s1 = fma(a0.s1, b1.s1, acc0.s1); + acc0.s1 = fma(a0.s2, b2.s1, acc0.s1); + acc0.s1 = fma(a0.s3, b3.s1, acc0.s1); + acc0.s1 = fma(a0.s4, b4.s1, acc0.s1); + acc0.s1 = fma(a0.s5, b5.s1, acc0.s1); + acc0.s1 = fma(a0.s6, b6.s1, acc0.s1); + acc0.s1 = fma(a0.s7, b7.s1, acc0.s1); #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 #if defined(REINTERPRET_INPUT_AS_3D) a0 = vload8(0, (__global float *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y + zin.s1)); #else // defined(REINTERPRET_INPUT_AS_3D) - a0 = vload8(0, (__global float *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y)); + a0 = vload8(0, (__global float *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y)); #endif // defined(REINTERPRET_INPUT_AS_3D) - acc10 = fma(a0.s0, b0.s0, acc10); - acc10 = fma(a0.s1, b1.s0, acc10); - acc10 = fma(a0.s2, b2.s0, acc10); - acc10 = fma(a0.s3, b3.s0, acc10); - acc10 = fma(a0.s4, b4.s0, acc10); - acc10 = fma(a0.s5, b5.s0, acc10); - acc10 = fma(a0.s6, b6.s0, acc10); - acc10 = fma(a0.s7, b7.s0, acc10); - - acc11 = fma(a0.s0, b0.s1, acc11); - acc11 = fma(a0.s1, b1.s1, acc11); - acc11 = fma(a0.s2, b2.s1, acc11); - acc11 = fma(a0.s3, b3.s1, acc11); - acc11 = fma(a0.s4, b4.s1, acc11); - acc11 = fma(a0.s5, b5.s1, acc11); - acc11 = fma(a0.s6, b6.s1, acc11); - acc11 = fma(a0.s7, b7.s1, acc11); + acc1.s0 = fma(a0.s0, b0.s0, acc1.s0); + acc1.s0 = fma(a0.s1, b1.s0, acc1.s0); + acc1.s0 = fma(a0.s2, b2.s0, acc1.s0); + acc1.s0 = fma(a0.s3, b3.s0, acc1.s0); + acc1.s0 = fma(a0.s4, b4.s0, acc1.s0); + acc1.s0 = fma(a0.s5, b5.s0, acc1.s0); + acc1.s0 = fma(a0.s6, b6.s0, acc1.s0); + acc1.s0 = fma(a0.s7, b7.s0, acc1.s0); + + acc1.s1 = fma(a0.s0, b0.s1, acc1.s1); + acc1.s1 = fma(a0.s1, b1.s1, acc1.s1); + acc1.s1 = fma(a0.s2, b2.s1, acc1.s1); + acc1.s1 = fma(a0.s3, b3.s1, acc1.s1); + acc1.s1 = fma(a0.s4, b4.s1, acc1.s1); + acc1.s1 = fma(a0.s5, b5.s1, acc1.s1); + acc1.s1 = fma(a0.s6, b6.s1, acc1.s1); + acc1.s1 = fma(a0.s7, b7.s1, acc1.s1); #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 #if defined(REINTERPRET_INPUT_AS_3D) a0 = vload8(0, (__global float *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y + zin.s2)); #else // defined(REINTERPRET_INPUT_AS_3D) - a0 = vload8(0, (__global float *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y)); + a0 = vload8(0, (__global float *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y)); #endif // defined(REINTERPRET_INPUT_AS_3D) - acc20 = fma(a0.s0, b0.s0, acc20); - acc20 = fma(a0.s1, b1.s0, acc20); - 
acc20 = fma(a0.s2, b2.s0, acc20); - acc20 = fma(a0.s3, b3.s0, acc20); - acc20 = fma(a0.s4, b4.s0, acc20); - acc20 = fma(a0.s5, b5.s0, acc20); - acc20 = fma(a0.s6, b6.s0, acc20); - acc20 = fma(a0.s7, b7.s0, acc20); - - acc21 = fma(a0.s0, b0.s1, acc21); - acc21 = fma(a0.s1, b1.s1, acc21); - acc21 = fma(a0.s2, b2.s1, acc21); - acc21 = fma(a0.s3, b3.s1, acc21); - acc21 = fma(a0.s4, b4.s1, acc21); - acc21 = fma(a0.s5, b5.s1, acc21); - acc21 = fma(a0.s6, b6.s1, acc21); - acc21 = fma(a0.s7, b7.s1, acc21); + acc2.s0 = fma(a0.s0, b0.s0, acc2.s0); + acc2.s0 = fma(a0.s1, b1.s0, acc2.s0); + acc2.s0 = fma(a0.s2, b2.s0, acc2.s0); + acc2.s0 = fma(a0.s3, b3.s0, acc2.s0); + acc2.s0 = fma(a0.s4, b4.s0, acc2.s0); + acc2.s0 = fma(a0.s5, b5.s0, acc2.s0); + acc2.s0 = fma(a0.s6, b6.s0, acc2.s0); + acc2.s0 = fma(a0.s7, b7.s0, acc2.s0); + + acc2.s1 = fma(a0.s0, b0.s1, acc2.s1); + acc2.s1 = fma(a0.s1, b1.s1, acc2.s1); + acc2.s1 = fma(a0.s2, b2.s1, acc2.s1); + acc2.s1 = fma(a0.s3, b3.s1, acc2.s1); + acc2.s1 = fma(a0.s4, b4.s1, acc2.s1); + acc2.s1 = fma(a0.s5, b5.s1, acc2.s1); + acc2.s1 = fma(a0.s6, b6.s1, acc2.s1); + acc2.s1 = fma(a0.s7, b7.s1, acc2.s1); #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 #if defined(REINTERPRET_INPUT_AS_3D) a0 = vload8(0, (__global float *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y + zin.s3)); #else // defined(REINTERPRET_INPUT_AS_3D) - a0 = vload8(0, (__global float *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y)); + a0 = vload8(0, (__global float *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y)); #endif // defined(REINTERPRET_INPUT_AS_3D) - acc30 = fma(a0.s0, b0.s0, acc30); - acc30 = fma(a0.s1, b1.s0, acc30); - acc30 = fma(a0.s2, b2.s0, acc30); - acc30 = fma(a0.s3, b3.s0, acc30); - acc30 = fma(a0.s4, b4.s0, acc30); - acc30 = fma(a0.s5, b5.s0, acc30); - acc30 = fma(a0.s6, b6.s0, acc30); - acc30 = fma(a0.s7, b7.s0, acc30); - - acc31 = fma(a0.s0, b0.s1, acc31); - acc31 = fma(a0.s1, b1.s1, acc31); - acc31 = fma(a0.s2, b2.s1, acc31); - acc31 = fma(a0.s3, b3.s1, acc31); - acc31 = fma(a0.s4, b4.s1, acc31); - acc31 = fma(a0.s5, b5.s1, acc31); - acc31 = fma(a0.s6, b6.s1, acc31); - acc31 = fma(a0.s7, b7.s1, acc31); + acc3.s0 = fma(a0.s0, b0.s0, acc3.s0); + acc3.s0 = fma(a0.s1, b1.s0, acc3.s0); + acc3.s0 = fma(a0.s2, b2.s0, acc3.s0); + acc3.s0 = fma(a0.s3, b3.s0, acc3.s0); + acc3.s0 = fma(a0.s4, b4.s0, acc3.s0); + acc3.s0 = fma(a0.s5, b5.s0, acc3.s0); + acc3.s0 = fma(a0.s6, b6.s0, acc3.s0); + acc3.s0 = fma(a0.s7, b7.s0, acc3.s0); + + acc3.s1 = fma(a0.s0, b0.s1, acc3.s1); + acc3.s1 = fma(a0.s1, b1.s1, acc3.s1); + acc3.s1 = fma(a0.s2, b2.s1, acc3.s1); + acc3.s1 = fma(a0.s3, b3.s1, acc3.s1); + acc3.s1 = fma(a0.s4, b4.s1, acc3.s1); + acc3.s1 = fma(a0.s5, b5.s1, acc3.s1); + acc3.s1 = fma(a0.s6, b6.s1, acc3.s1); + acc3.s1 = fma(a0.s7, b7.s1, acc3.s1); #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 src_addr.s0 += sizeof(float) * 8; @@ -4740,42 +4774,24 @@ __kernel void gemm_mm_floating_point_f32_bifrost_1000(IMAGE_DECLARATION(src0), src_addr.s1 += src1_stride_y; // Multiply and accumulate - acc00 = fma(a0, b0.s0, acc00); - acc01 = fma(a0, b0.s1, acc01); + acc0.s0 = fma(a0, b0.s0, acc0.s0); + acc0.s1 = fma(a0, b0.s1, acc0.s1); #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 - acc10 = fma(a1, b0.s0, acc10); - acc11 = fma(a1, b0.s1, acc11); + acc1.s0 = fma(a1, b0.s0, acc1.s0); + acc1.s1 = fma(a1, b0.s1, acc1.s1); #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 - acc20 = fma(a2, b0.s0, acc20); - acc21 = fma(a2, b0.s1, acc21); + acc2.s0 
= fma(a2, b0.s0, acc2.s0); + acc2.s1 = fma(a2, b0.s1, acc2.s1); #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 - acc30 = fma(a3, b0.s0, acc30); - acc31 = fma(a3, b0.s1, acc31); + acc3.s0 = fma(a3, b0.s0, acc3.s0); + acc3.s1 = fma(a3, b0.s1, acc3.s1); #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 src_addr.s0 += sizeof(float); } - // Multiply by the weight of matrix-matrix product and store the result -#if defined(ALPHA) - acc00 = acc00 * ALPHA; - acc01 = acc01 * ALPHA; -#endif // defined(ALPHA) -#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 && defined(ALPHA) - acc10 = acc10 * ALPHA; - acc11 = acc11 * ALPHA; -#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 && defined(ALPHA) -#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 && defined(ALPHA) - acc20 = acc20 * ALPHA; - acc21 = acc21 * ALPHA; -#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 && defined(ALPHA) -#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 && defined(ALPHA) - acc30 = acc30 * ALPHA; - acc31 = acc31 * ALPHA; -#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 && defined(ALPHA) - int z = get_global_id(2); // Compute destination address @@ -4784,27 +4800,10 @@ __kernel void gemm_mm_floating_point_f32_bifrost_1000(IMAGE_DECLARATION(src0), // Compute dst address __global uchar *dst_addr = offset(&dst, 0, 0); -#if defined(ADD_VEC_C) - __global float *src2_addr = (__global float *)(src2_ptr + src2_offset_first_element_in_bytes + get_global_id(0) * src2_step_x); - float2 c0 = vload2(0, src2_addr); - - acc00 += c0.s0; - acc01 += c0.s1; -#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 - acc10 += c0.s0; - acc11 += c0.s1; -#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 -#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 - acc20 += c0.s0; - acc21 += c0.s1; -#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 -#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 - acc30 += c0.s0; - acc31 += c0.s1; -#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 -#endif /* defined(ADD_VEC_C) */ + uint4 zout = 0; #if defined(REINTERPRET_OUTPUT_AS_3D) + // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension // in order to take into account the presence of possible cross plane paddings // @@ -4821,8 +4820,8 @@ __kernel void gemm_mm_floating_point_f32_bifrost_1000(IMAGE_DECLARATION(src0), // |__________________| // The plane (zout) is calculated dividing M (get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y) by HEIGHT_GEMM3D - uint4 zout = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y)) / (uint4)HEIGHT_GEMM3D; - zout = min(DEPTH_GEMM3D - 1, zout); + zout = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y)) / (uint4)HEIGHT_GEMM3D; + zout = min(DEPTH_GEMM3D - 1, zout); // Add offset due to the cross plane paddings zout *= (dst_cross_plane_pad * dst_stride_y); @@ -4830,50 +4829,78 @@ __kernel void gemm_mm_floating_point_f32_bifrost_1000(IMAGE_DECLARATION(src0), // Add offset for batched GEMM. 
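The zout logic above, shared by all the non-reshaped kernels touched by this patch, maps each row of the 2D output tile onto a plane of the 3D destination and accounts for the cross-plane padding. A scalar sketch of what one lane of that uint4 computation does, with hypothetical names:

#include <stdint.h>

/* Extra byte offset added to dst_addr for one output row when the output is
 * reinterpreted as 3D (REINTERPRET_OUTPUT_AS_3D). row is 0..3 within the
 * tile; the kernel computes all four lanes at once in a uint4. */
uint32_t cross_plane_offset(uint32_t gid1, uint32_t row, uint32_t rows_per_thread,
                            uint32_t height_gemm3d, uint32_t depth_gemm3d,
                            uint32_t dst_cross_plane_pad, uint32_t dst_stride_y)
{
    uint32_t zout = (gid1 * rows_per_thread + row) / height_gemm3d; /* destination plane      */

    if(zout > depth_gemm3d - 1)
    {
        zout = depth_gemm3d - 1;                                    /* min(DEPTH_GEMM3D-1, .) */
    }
    return zout * dst_cross_plane_pad * dst_stride_y;               /* padding rows skipped   */
}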
The batches will be in the fourth dimension and for this reason we // multiply dst_stride_z by DEPTH_GEMM3D dst_addr += z * dst_stride_z * DEPTH_GEMM3D; - - // Store the output block - vstore2((float2)(acc00, acc01), 0, (__global float *)(dst_addr + 0 * dst_stride_y + zout.s0)); -#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 - vstore2((float2)(acc10, acc11), 0, (__global float *)(dst_addr + 1 * dst_stride_y + zout.s1)); -#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 -#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 - vstore2((float2)(acc20, acc21), 0, (__global float *)(dst_addr + 2 * dst_stride_y + zout.s2)); -#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 -#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 - vstore2((float2)(acc30, acc31), 0, (__global float *)(dst_addr + 3 * dst_stride_y + zout.s3)); -#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 - -#else // defined(REINTERPRET_OUTPUT_AS_3D) +#else // defined(REINTERPRET_OUTPUT_AS_3D) // Add offset for batched GEMM dst_addr += z * dst_stride_z; +#endif // defined(REINTERPRET_OUTPUT_AS_3D) + + // Multiply by the weight of matrix-matrix product and store the result +#if defined(ALPHA) + SCALE_BLOCK(NUM_ELEMS_PROCESSED_PER_THREAD_Y, float, acc, ALPHA); +#endif // defined(ALPHA) + + // Add beta*bias +#if defined(BETA) + REPEAT_VAR_INIT_TO_CONST(NUM_ELEMS_PROCESSED_PER_THREAD_Y, uint, zero, 0); + +#if defined(BROADCAST_BIAS) + __global uchar *src2_addr = src2_ptr + src2_offset_first_element_in_bytes + (get_global_id(0) * (uint)2 * sizeof(float)); + + LOAD_BLOCK(1, 2, float, bias, src2_addr, 0, src2_stride_y, zero); + +#ifndef UNIT_BETA + SCALE_BLOCK(1, float, bias, BETA); +#endif // UNIT_BIAS + + // acc = acc + bias[broadcasted] + ADD_BLOCK_BROADCAST(NUM_ELEMS_PROCESSED_PER_THREAD_Y, acc, bias0); + +#else // defined(BROADCAST_BIAS) + __global uchar *src2_addr = src2_ptr + src2_offset_first_element_in_bytes + (get_global_id(0) * (uint)2 * sizeof(float)) + (get_global_id(1) * + (uint)NUM_ELEMS_PROCESSED_PER_THREAD_Y * src2_stride_y) + get_global_id(2) * src2_stride_z; + + LOAD_BLOCK(NUM_ELEMS_PROCESSED_PER_THREAD_Y, 2, float, bias, src2_addr, 0, src2_stride_y, zero); + +#ifndef UNIT_BETA + SCALE_BLOCK(NUM_ELEMS_PROCESSED_PER_THREAD_Y, float, bias, BETA); +#endif // UNIT_BIAS + + // acc = acc + bias + ADD_BLOCK(NUM_ELEMS_PROCESSED_PER_THREAD_Y, acc, bias); + +#endif // defined(BROADCAST_BIAS) +#endif // defined(BETA) + +#if defined(ACTIVATION_TYPE) + ACTIVATION_BLOCK(NUM_ELEMS_PROCESSED_PER_THREAD_Y, ACTIVATION_TYPE, float, acc, A_VAL, B_VAL); +#endif // defined(ACTIVATION_TYPE) // Store the output block - vstore2((float2)(acc00, acc01), 0, (__global float *)(dst_addr + 0 * dst_stride_y)); + vstore2(acc0, 0, (__global float *)(dst_addr + 0 * dst_stride_y + zout.s0)); #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 - vstore2((float2)(acc10, acc11), 0, (__global float *)(dst_addr + 1 * dst_stride_y)); + vstore2(acc1, 0, (__global float *)(dst_addr + 1 * dst_stride_y + zout.s1)); #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 - vstore2((float2)(acc20, acc21), 0, (__global float *)(dst_addr + 2 * dst_stride_y)); + vstore2(acc2, 0, (__global float *)(dst_addr + 2 * dst_stride_y + zout.s2)); #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 - vstore2((float2)(acc30, acc31), 0, (__global float *)(dst_addr + 3 * dst_stride_y)); + vstore2(acc3, 0, (__global float *)(dst_addr + 3 * dst_stride_y + zout.s3)); #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 -#endif // defined(REINTERPRET_OUTPUT_AS_3D) } #if 
defined(ARM_COMPUTE_OPENCL_FP16_ENABLED) /** This OpenCL kernel computes the matrix by matrix multiplication between the matrix A (src0) and matrix B (src1) in case both matrices have not beed reshaped - * - * Moreover, it can add a vector (src2) if the ADD_VEC_C parameter is passed at compile time. * * @note This OpenCL kernel works with the 16-bit floating point data type (half) and accumulating the result in a 32 floating point variable. * @note The number of elements processed along the x and y directions must be passed at compile time using -DNUM_ELEMS_PROCESSED_PER_THREAD_X and -DNUM_ELEMS_PROCESSED_PER_THREAD_Y. * This kernel optimally uses -DNUM_ELEMS_PROCESSED_PER_THREAD_X=4. * @note The number of matrix A columns must be passed at compile time using -DCOLS_A. * @note The optional value of scalar alpha is passed at compile time using -DALPHA=alpha - * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (i.e. -DMATRIX_B_DEPTH=16) - * This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (i.e. a = [K, M, 16, Batches], b = [N, K, 16]) + * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (e.g. -DMATRIX_B_DEPTH=16) + * This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (e.g. a = [K, M, 16, Batches], b = [N, K, 16]) * + * @note If the activation type were passed at compile time through -DACTIVATION_TYPE (e.g. -DACTIVATION_TYPE=RELU), A, B variables, required by some activation functions, should be passed at compile time as well using -DA_VAL= and -DB_VAL= respectively. + * The activation function is performed after the bias addition * @note In case the input or output have to be reinterpreted as a 3D tensor, the following information must be passed at compile time: * -# REINTERPRET_INPUT_AS_3D: To reinterpret the input as 3D * -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D @@ -4881,8 +4908,6 @@ __kernel void gemm_mm_floating_point_f32_bifrost_1000(IMAGE_DECLARATION(src0), * -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor * (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns matrix A NOT reshaped * - * @note In case a 3rd input (src2) needs to be added, the ADD_VEC_C parameter has to be passed at compile time as -DADD_VEC_C - * * @param[in] src0_ptr Pointer to the source matrix. Supported data types: F16 * @param[in] src0_stride_x Stride of the source matrix in X dimension (in bytes) * @param[in] src0_step_x src_stride_x * number of elements along X processed per workitem(in bytes) @@ -4895,10 +4920,12 @@ __kernel void gemm_mm_floating_point_f32_bifrost_1000(IMAGE_DECLARATION(src0), * @param[in] src1_stride_y Stride of the source matrix in Y dimension (in bytes) * @param[in] src1_step_y src_stride_y * number of elements along Y processed per workitem(in bytes) * @param[in] src1_offset_first_element_in_bytes The offset of the first element in the source matrix - * @param[in] src2_ptr (Optional) Pointer to the source matrix. 
Supported data types: same as @p src0_ptr - * @param[in] src2_stride_x (Optional) Stride of the source vector in X dimension (in bytes) - * @param[in] src2_step_x (Optional) src_stride_x * number of elements along X processed per workitem(in bytes) - * @param[in] src2_offset_first_element_in_bytes (Optional) The offset of the first element in the source matrix + * @param[in] src2_ptr (Optional) Pointer to the bias matrix. Supported data type: same as @p lhs_ptr + * @param[in] src2_stride_x (Optional) Stride of the bias matrix in X dimension (in bytes) + * @param[in] src2_step_x (Optional) src2_stride_x * number of elements along X processed per workitem(in bytes) + * @param[in] src2_stride_y (Optional) Stride of the bias matrix in Y dimension (in bytes) + * @param[in] src2_step_y (Optional) src2_stride_y * number of elements along Y processed per workitem(in bytes) + * @param[in] src2_offset_first_element_in_bytes (Optional) The offset of the first element in the bias matrix * @param[out] dst_ptr Pointer to the destination matrix Supported data types: same as @p src0_ptr * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes) * @param[in] dst_step_x dst_gx_stride_x * number of elements along X processed per workitem(in bytes) @@ -4907,18 +4934,22 @@ __kernel void gemm_mm_floating_point_f32_bifrost_1000(IMAGE_DECLARATION(src0), * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix * @param[in] src0_stride_z Stride of the source matrix in Z dimension (in bytes) * @param[in] src1_stride_z Stride of the source matrix in Z dimension (in bytes) + * @param[in] src2_stride_z (Optional) Stride of the bias matrix in Z dimension (in bytes) * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes) * @param[in] src_cross_plane_pad (Optional) Bottom paddings in unit of elements for the input tensor (only if defined REINTERPRET_INPUT_AS_3D) * @param[in] dst_cross_plane_pad (Optional) Bottom paddings in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D) */ __kernel void gemm_mm_floating_point_f16_bifrost_acc32(IMAGE_DECLARATION(src0), IMAGE_DECLARATION(src1), -#if defined(ADD_VEC_C) - VECTOR_DECLARATION(src2), -#endif /* defined(ADD_VEC_C) */ +#if defined(BETA) + IMAGE_DECLARATION(src2), +#endif // defined(BETA) IMAGE_DECLARATION(dst), uint src0_stride_z, uint src1_stride_z, +#if defined(BETA) + uint src2_stride_z, +#endif //defined(BETA) uint dst_stride_z #if defined(REINTERPRET_INPUT_AS_3D) , @@ -5117,56 +5148,6 @@ __kernel void gemm_mm_floating_point_f16_bifrost_acc32(IMAGE_DECLARATION(src0), #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 } - // Multiply by the weight of matrix-matrix product and store the result -#if defined(ALPHA) - half8 hacc0 = convert_half8(acc0) * (half8)ALPHA; -#else //defined(ALPHA) - half8 hacc0 = convert_half8(acc0); -#endif // defined(ALPHA) -#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 -#if defined(ALPHA) - half8 hacc1 = convert_half8(acc1) * (half8)ALPHA; -#else //defined(ALPHA) - half8 hacc1 = convert_half8(acc1); -#endif //defined(ALPHA) -#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y - -#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 -#if defined(ALPHA) - half8 hacc2 = convert_half8(acc2) * (half8)ALPHA; -#else //defined(ALPHA) - half8 hacc2 = convert_half8(acc2); -#endif //defined(ALPHA) -#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 - -#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 -#if defined(ALPHA) - half8 hacc3 = convert_half8(acc3) * (half8)ALPHA; -#else 
//defined(ALPHA) - half8 hacc3 = convert_half8(acc3); -#endif // defined(ALPHA) -#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 - -#if defined(ADD_VEC_C) - // *INDENT-OFF* - // clang-format off - __global half *src2_addr = (__global half *)(src2_ptr + src2_offset_first_element_in_bytes + get_global_id(0) * src2_step_x); - half8 c0 = vload8(0, src2_addr); - // clang-format on - // *INDENT-ON* - - hacc0 += c0; -#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 - hacc1 += c0; -#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 -#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 - hacc2 += c0; -#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 -#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 - hacc3 += c0; -#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 -#endif /* defined(ADD_VEC_C) */ - int z = get_global_id(2); // Compute destination address @@ -5175,7 +5156,10 @@ __kernel void gemm_mm_floating_point_f16_bifrost_acc32(IMAGE_DECLARATION(src0), // Compute dst address __global uchar *dst_addr = offset(&dst, 0, 0); + uint4 zout = 0; + #if defined(REINTERPRET_OUTPUT_AS_3D) + // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension // in order to take into account the presence of possible cross plane paddings // @@ -5192,8 +5176,8 @@ __kernel void gemm_mm_floating_point_f16_bifrost_acc32(IMAGE_DECLARATION(src0), // |__________________| // The plane (zout) is calculated dividing M (get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y) by HEIGHT_GEMM3D - uint4 zout = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y)) / (uint4)HEIGHT_GEMM3D; - zout = min(DEPTH_GEMM3D - 1, zout); + zout = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y)) / (uint4)HEIGHT_GEMM3D; + zout = min(DEPTH_GEMM3D - 1, zout); // Add offset due to the cross plane paddings zout *= (dst_cross_plane_pad * dst_stride_y); @@ -5201,38 +5185,91 @@ __kernel void gemm_mm_floating_point_f16_bifrost_acc32(IMAGE_DECLARATION(src0), // Add offset for batched GEMM. 
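In the FP16 kernel with 32-bit accumulation the bias path has to cross precisions: the half bias is converted to float, optionally scaled by BETA, added to the float accumulators, and only then is the result converted back to half for the activation and the store. A portable C sketch of that ordering, using float in place of half so it compiles outside OpenCL; the preprocessor switches mirror the kernel's -D options and the function name is illustrative.

#include <stddef.h>

#define BETA 1.0f   /* mirrors -DBETA=...; with beta == 1 the host also defines UNIT_BETA */
#define UNIT_BETA

typedef float narrow_t; /* stand-in for OpenCL half in this sketch */

void f16_acc32_bias_ref(float *acc, size_t n, const narrow_t *bias, narrow_t *dst)
{
    for(size_t x = 0; x < n; ++x)
    {
        float v = acc[x];             /* accumulation (and ALPHA scaling) done in 32 bit    */
#if defined(BETA)
        float b = (float)bias[x];     /* convert_float8(bias0)                              */
#ifndef UNIT_BETA
        b *= BETA;                    /* SCALE_BLOCK(..., bias_f, BETA), skipped if beta==1 */
#endif
        v += b;                       /* ADD_BLOCK / ADD_BLOCK_BROADCAST                    */
#endif
        dst[x] = (narrow_t)v;         /* convert_half8(); activation and store follow in half */
    }
}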
The batches will be in the fourth dimension and for this reason we // multiply dst_stride_z by DEPTH_GEMM3D dst_addr += z * dst_stride_z * DEPTH_GEMM3D; - // Store the output block - STORE_BLOCK(NUM_ELEMS_PROCESSED_PER_THREAD_Y, 8, half, hacc, dst_addr, dst_stride_y, zout.s); -#else // defined(REINTERPRET_OUTPUT_AS_3D) +#else // defined(REINTERPRET_OUTPUT_AS_3D) // Add offset for batched GEMM dst_addr += z * dst_stride_z; +#endif // defined(REINTERPRET_OUTPUT_AS_3D) - // Store the output block - vstore8(hacc0, 0, (__global half *)(dst_addr + 0 * dst_stride_y)); + // Multiply by the weight of matrix-matrix product and store the result +#if defined(ALPHA) + SCALE_BLOCK(NUM_ELEMS_PROCESSED_PER_THREAD_Y, float, acc, ALPHA); +#endif // defined(ALPHA) + +#if defined(BETA) + REPEAT_VAR_INIT_TO_CONST(NUM_ELEMS_PROCESSED_PER_THREAD_Y, uint, zero, 0); + +#if defined(BROADCAST_BIAS) + __global uchar *src2_addr = src2_ptr + src2_offset_first_element_in_bytes + (get_global_id(0) * (uint)8 * sizeof(half)); + + LOAD_BLOCK(1, 8, half, bias, src2_addr, 0, src2_stride_y, zero); + + float8 bias_f0 = convert_float8(bias0); + +#ifndef UNIT_BETA + SCALE_BLOCK(1, float, bias_f, BETA); +#endif // UNIT_BIAS + + // acc = acc + bias[broadcasted] + ADD_BLOCK_BROADCAST(NUM_ELEMS_PROCESSED_PER_THREAD_Y, acc, bias_f0); + +#else // defined(BROADCAST_BIAS) + __global uchar *src2_addr = src2_ptr + src2_offset_first_element_in_bytes + (get_global_id(0) * (uint)8 * sizeof(half)) + (get_global_id(1) * + (uint)NUM_ELEMS_PROCESSED_PER_THREAD_Y * src2_stride_y) + get_global_id(2) * src2_stride_z; + + LOAD_BLOCK(NUM_ELEMS_PROCESSED_PER_THREAD_Y, 8, half, bias, src2_addr, 0, src2_stride_y, zero); + + float8 bias_f0 = convert_float8(bias0); #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 - vstore8(hacc1, 0, (__global half *)(dst_addr + 1 * dst_stride_y)); + float8 bias_f1 = convert_float8(bias1); #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 - vstore8(hacc2, 0, (__global half *)(dst_addr + 2 * dst_stride_y)); + float8 bias_f2 = convert_float8(bias2); #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 - vstore8(hacc3, 0, (__global half *)(dst_addr + 3 * dst_stride_y)); + float8 bias_f3 = convert_float8(bias3); #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 -#endif // REINTERPRET_OUTPUT_AS_3D + +#ifndef UNIT_BETA + SCALE_BLOCK(NUM_ELEMS_PROCESSED_PER_THREAD_Y, float, bias_f, BETA); +#endif // UNIT_BIAS + + // acc = acc + bias + ADD_BLOCK(NUM_ELEMS_PROCESSED_PER_THREAD_Y, acc, bias_f); + +#endif // defined(BROADCAST_BIAS) +#endif // defined(BETA) + + half8 acc_h0 = convert_half8(acc0); +#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 + half8 acc_h1 = convert_half8(acc1); +#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 +#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 + half8 acc_h2 = convert_half8(acc2); +#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 +#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 + half8 acc_h3 = convert_half8(acc3); +#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 + +#if defined(ACTIVATION_TYPE) + ACTIVATION_BLOCK(NUM_ELEMS_PROCESSED_PER_THREAD_Y, ACTIVATION_TYPE, half, acc_h, A_VAL, B_VAL); +#endif // defined(ACTIVATION_TYPE) + + // Store the output block + STORE_BLOCK(NUM_ELEMS_PROCESSED_PER_THREAD_Y, 8, half, acc_h, dst_addr, dst_stride_y, zout.s); } /** This OpenCL kernel computes the matrix by matrix multiplication between the matrix A (src0) and matrix B (src1) in case both matrices have not beed reshaped - * - * Moreover, it can add a vector (src2) if the 
ADD_VEC_C parameter is passed at compile time. * * @note This OpenCL kernel works with the 16-bit floating point data type (half) and uses the fma units. * @note The number of elements processed along the x and y directions must be passed at compile time using -DNUM_ELEMS_PROCESSED_PER_THREAD_X and -DNUM_ELEMS_PROCESSED_PER_THREAD_Y. * This kernel optimally uses -DNUM_ELEMS_PROCESSED_PER_THREAD_X=4. * @note The number of matrix A columns must be passed at compile time using -DCOLS_A. * @note The optional value of scalar alpha is passed at compile time using -DALPHA=alpha - * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (i.e. -DMATRIX_B_DEPTH=16) - * This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (i.e. a = [K, M, 16, Batches], b = [N, K, 16]) + * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (e.g. -DMATRIX_B_DEPTH=16) + * This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (e.g. a = [K, M, 16, Batches], b = [N, K, 16]) * + * @note If the activation type were passed at compile time through -DACTIVATION_TYPE (e.g. -DACTIVATION_TYPE=RELU), A, B variables, required by some activation functions, should be passed at compile time as well using -DA_VAL= and -DB_VAL= respectively. + * The activation function is performed after the bias addition * @note In case the input or output have to be reinterpreted as a 3D tensor, the following information must be passed at compile time: * -# REINTERPRET_INPUT_AS_3D: To reinterpret the input as 3D * -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D @@ -5240,8 +5277,6 @@ __kernel void gemm_mm_floating_point_f16_bifrost_acc32(IMAGE_DECLARATION(src0), * -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor * (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns matrix A NOT reshaped * - * @note In case a 3rd input (src2) needs to be added, the ADD_VEC_C parameter has to be passed at compile time as -DADD_VEC_C - * * @param[in] src0_ptr Pointer to the source matrix. Supported data types: F16 * @param[in] src0_stride_x Stride of the source matrix in X dimension (in bytes) * @param[in] src0_step_x src_stride_x * number of elements along X processed per workitem(in bytes) @@ -5254,10 +5289,12 @@ __kernel void gemm_mm_floating_point_f16_bifrost_acc32(IMAGE_DECLARATION(src0), * @param[in] src1_stride_y Stride of the source matrix in Y dimension (in bytes) * @param[in] src1_step_y src_stride_y * number of elements along Y processed per workitem(in bytes) * @param[in] src1_offset_first_element_in_bytes The offset of the first element in the source matrix - * @param[in] src2_ptr (Optional) Pointer to the source matrix. 
Supported data types: same as @p src0_ptr - * @param[in] src2_stride_x (Optional) Stride of the source vector in X dimension (in bytes) - * @param[in] src2_step_x (Optional) src_stride_x * number of elements along X processed per workitem(in bytes) - * @param[in] src2_offset_first_element_in_bytes (Optional) The offset of the first element in the source matrix + * @param[in] src2_ptr (Optional) Pointer to the bias matrix. Supported data type: same as @p lhs_ptr + * @param[in] src2_stride_x (Optional) Stride of the bias matrix in X dimension (in bytes) + * @param[in] src2_step_x (Optional) src2_stride_x * number of elements along X processed per workitem(in bytes) + * @param[in] src2_stride_y (Optional) Stride of the bias matrix in Y dimension (in bytes) + * @param[in] src2_step_y (Optional) src2_stride_y * number of elements along Y processed per workitem(in bytes) + * @param[in] src2_offset_first_element_in_bytes (Optional) The offset of the first element in the bias matrix * @param[out] dst_ptr Pointer to the destination matrix Supported data types: same as @p src0_ptr * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes) * @param[in] dst_step_x dst_gx_stride_x * number of elements along X processed per workitem(in bytes) @@ -5266,18 +5303,22 @@ __kernel void gemm_mm_floating_point_f16_bifrost_acc32(IMAGE_DECLARATION(src0), * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix * @param[in] src0_stride_z Stride of the source matrix in Z dimension (in bytes) * @param[in] src1_stride_z Stride of the source matrix in Z dimension (in bytes) + * @param[in] src2_stride_z (Optional) Stride of the bias matrix in Z dimension (in bytes) * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes) * @param[in] src_cross_plane_pad (Optional) Bottom paddings in unit of elements for the input tensor (only if defined REINTERPRET_INPUT_AS_3D) * @param[in] dst_cross_plane_pad (Optional) Bottom paddings in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D) */ __kernel void gemm_mm_floating_point_f16_bifrost(IMAGE_DECLARATION(src0), IMAGE_DECLARATION(src1), -#if defined(ADD_VEC_C) - VECTOR_DECLARATION(src2), -#endif /* defined(ADD_VEC_C) */ +#if defined(BETA) + IMAGE_DECLARATION(src2), +#endif // defined(BETA) IMAGE_DECLARATION(dst), uint src0_stride_z, uint src1_stride_z, +#if defined(BETA) + uint src2_stride_z, +#endif //defined(BETA) uint dst_stride_z #if defined(REINTERPRET_INPUT_AS_3D) , @@ -5476,40 +5517,6 @@ __kernel void gemm_mm_floating_point_f16_bifrost(IMAGE_DECLARATION(src0), #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 } - // Multiply by the weight of matrix-matrix product and store the result -#if defined(ALPHA) - acc0 = acc0 * (half8)ALPHA; -#endif // defined(ALPHA) -#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 && defined(ALPHA) - acc1 = acc1 * (half8)ALPHA; -#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 && defined(ALPHA) -#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 && defined(ALPHA) - acc2 = acc2 * (half8)ALPHA; -#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 && defined(ALPHA) -#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 && defined(ALPHA) - acc3 = acc3 * (half8)ALPHA; -#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 && defined(ALPHA) - -#if defined(ADD_VEC_C) - // *INDENT-OFF* - // clang-format off - __global half *src2_addr = (__global half *)(src2_ptr + src2_offset_first_element_in_bytes + get_global_id(0) * src2_step_x); - half8 c0 = vload8(0, src2_addr); - // clang-format on - // 
*INDENT-ON* - - acc0 += c0; -#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 - acc1 += c0; -#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 -#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 - acc2 += c0; -#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 -#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 - acc3 += c0; -#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 -#endif /* defined(ADD_VEC_C) */ - int z = get_global_id(2); // Compute destination address @@ -5518,7 +5525,10 @@ __kernel void gemm_mm_floating_point_f16_bifrost(IMAGE_DECLARATION(src0), // Compute dst address __global uchar *dst_addr = offset(&dst, 0, 0); + uint4 zout = 0; + #if defined(REINTERPRET_OUTPUT_AS_3D) + // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension // in order to take into account the presence of possible cross plane paddings // @@ -5535,8 +5545,8 @@ __kernel void gemm_mm_floating_point_f16_bifrost(IMAGE_DECLARATION(src0), // |__________________| // The plane (zout) is calculated dividing M (get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y) by HEIGHT_GEMM3D - uint4 zout = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y)) / (uint4)HEIGHT_GEMM3D; - zout = min(DEPTH_GEMM3D - 1, zout); + zout = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y)) / (uint4)HEIGHT_GEMM3D; + zout = min(DEPTH_GEMM3D - 1, zout); // Add offset due to the cross plane paddings zout *= (dst_cross_plane_pad * dst_stride_y); @@ -5544,25 +5554,54 @@ __kernel void gemm_mm_floating_point_f16_bifrost(IMAGE_DECLARATION(src0), // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we // multiply dst_stride_z by DEPTH_GEMM3D dst_addr += z * dst_stride_z * DEPTH_GEMM3D; - - // Store the output block - STORE_BLOCK(NUM_ELEMS_PROCESSED_PER_THREAD_Y, 8, half, acc, dst_addr, dst_stride_y, zout.s); -#else // defined(REINTERPRET_OUTPUT_AS_3D) +#else // defined(REINTERPRET_OUTPUT_AS_3D) // Add offset for batched GEMM dst_addr += z * dst_stride_z; +#endif // defined(REINTERPRET_OUTPUT_AS_3D) + + // Multiply by the weight of matrix-matrix product and store the result +#if defined(ALPHA) + SCALE_BLOCK(NUM_ELEMS_PROCESSED_PER_THREAD_Y, half, acc, ALPHA); +#endif // defined(ALPHA) + + // Add beta*bias +#if defined(BETA) + REPEAT_VAR_INIT_TO_CONST(NUM_ELEMS_PROCESSED_PER_THREAD_Y, uint, zero, 0); + +#if defined(BROADCAST_BIAS) + __global uchar *src2_addr = src2_ptr + src2_offset_first_element_in_bytes + (get_global_id(0) * (uint)8 * sizeof(half)); + + LOAD_BLOCK(1, 8, half, bias, src2_addr, 0, src2_stride_y, zero); + +#ifndef UNIT_BETA + SCALE_BLOCK(1, half, bias, BETA); +#endif // UNIT_BIAS + + // acc = acc + bias[broadcasted] + ADD_BLOCK_BROADCAST(NUM_ELEMS_PROCESSED_PER_THREAD_Y, acc, bias0); + +#else // defined(BROADCAST_BIAS) + __global uchar *src2_addr = src2_ptr + src2_offset_first_element_in_bytes + (get_global_id(0) * (uint)8 * sizeof(half)) + (get_global_id(1) * + (uint)NUM_ELEMS_PROCESSED_PER_THREAD_Y * src2_stride_y) + get_global_id(2) * src2_stride_z; + + LOAD_BLOCK(NUM_ELEMS_PROCESSED_PER_THREAD_Y, 8, half, bias, src2_addr, 0, src2_stride_y, zero); + +#ifndef UNIT_BETA + SCALE_BLOCK(NUM_ELEMS_PROCESSED_PER_THREAD_Y, half, bias, BETA); +#endif // UNIT_BIAS + + // acc = acc + bias + ADD_BLOCK(NUM_ELEMS_PROCESSED_PER_THREAD_Y, acc, bias); + +#endif // defined(BROADCAST_BIAS) +#endif // defined(BETA) + +#if defined(ACTIVATION_TYPE) + ACTIVATION_BLOCK(NUM_ELEMS_PROCESSED_PER_THREAD_Y, ACTIVATION_TYPE, half, 
acc, A_VAL, B_VAL); +#endif // defined(ACTIVATION_TYPE) // Store the output block - vstore8(acc0, 0, (__global half *)(dst_addr + 0 * dst_stride_y)); -#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 - vstore8(acc1, 0, (__global half *)(dst_addr + 1 * dst_stride_y)); -#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 -#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 - vstore8(acc2, 0, (__global half *)(dst_addr + 2 * dst_stride_y)); -#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 -#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 - vstore8(acc3, 0, (__global half *)(dst_addr + 3 * dst_stride_y)); -#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 -#endif // REINTERPRET_OUTPUT_AS_3D + STORE_BLOCK(NUM_ELEMS_PROCESSED_PER_THREAD_Y, 8, half, acc, dst_addr, dst_stride_y, zout.s); } #endif // defined(ARM_COMPUTE_OPENCL_FP16_ENABLED) @@ -5746,7 +5785,7 @@ __kernel void gemm_accumulate_biases( Image accum = CONVERT_TO_IMAGE_STRUCT(accum); Vector biases = CONVERT_TO_VECTOR_STRUCT(biases); - // Vector size, i.e. number of vector elements. + // Vector size, e.g. number of vector elements. VEC_DATA_TYPE(DATA_TYPE, VECTOR_SIZE) accum_value = VLOAD(VECTOR_SIZE)(0, (__global DATA_TYPE *)accum.ptr); VEC_DATA_TYPE(DATA_TYPE, VECTOR_SIZE) diff --git a/src/core/CL/kernels/CLGEMMMatrixMultiplyKernel.cpp b/src/core/CL/kernels/CLGEMMMatrixMultiplyKernel.cpp index b3ea309c93..e793c65059 100644 --- a/src/core/CL/kernels/CLGEMMMatrixMultiplyKernel.cpp +++ b/src/core/CL/kernels/CLGEMMMatrixMultiplyKernel.cpp @@ -62,27 +62,34 @@ inline Status validate_arguments(const ITensorInfo *input0, const ITensorInfo *i ARM_COMPUTE_RETURN_ERROR_ON_MSG(is_interleaved_transposed && reshape_info.reinterpret_input_as_3d(), "The input tensor cannot be reinterpreted as 3D if is_interleaved_transposed is true"); ARM_COMPUTE_RETURN_ERROR_ON_MSG(input1->num_dimensions() > 2 && reshape_info.reinterpret_input_as_3d(), "The input1 tensor cannot have more than 2 dimensions if input0 has to be reinterpreted as 3D"); - const bool is_beta_one = std::abs(1.0f - beta) < 0.00001f; - const bool has_vec_c = input2 != nullptr && beta != 0.f; - ARM_COMPUTE_RETURN_ERROR_ON_MSG(has_vec_c && !is_beta_one, "Adding input2 is only supported for beta equal to 1"); - if(!is_interleaved_transposed) { ARM_COMPUTE_RETURN_ERROR_ON(input0->dimension(0) != input1->dimension(1)); - if(has_vec_c) + if(input2 != nullptr && !(helpers::float_ops::is_zero(beta))) { - ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input0, input2); - ARM_COMPUTE_RETURN_ERROR_ON_MSG(input2->num_dimensions() > 1, "input2 must be a 1D tensor"); - ARM_COMPUTE_RETURN_ERROR_ON_MSG(input2->dimension(0) != input1->dimension(0), "Length of Vector C must match the number of columns of matrix B"); + const unsigned int m = reshape_info.reinterpret_input_as_3d() ? 
input0->dimension(1) * input0->dimension(2) : input0->dimension(1); + const unsigned int n = input1->dimension(0); + const unsigned int input2_dim0 = input2->dimension(0); + const unsigned int input2_dim1 = input2->dimension(1); + + ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input2, input1); + if(reshape_info.broadcast_bias()) + { + ARM_COMPUTE_RETURN_ERROR_ON_MSG((input2_dim1 != 1 || input2_dim0 != n), "Incorrect dimension of bias matrix which is to be broadcasted"); + } + else + { + ARM_COMPUTE_RETURN_ERROR_ON_MSG((input2_dim0 != n || input2_dim1 != m), "Incorrect dimension of bias matrix"); + } } } else { GEMMRHSMatrixInfo rhs_info; GEMMLHSMatrixInfo lhs_info; - const int m = reshape_info.m(); - const int n = reshape_info.n(); + const auto m = static_cast(reshape_info.m()); + const auto n = static_cast(reshape_info.n()); const int k = reshape_info.k(); const int mult_transpose1xW_width = reshape_info.mult_transpose1xW_width(); const int mult_interleave4x4_height = reshape_info.mult_interleave4x4_height(); @@ -114,10 +121,20 @@ inline Status validate_arguments(const ITensorInfo *input0, const ITensorInfo *i ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(input0, &tensor_info_reshaped0); ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(input1, &tensor_info_reshaped1); - if(has_vec_c) + if(input2 != nullptr && !(helpers::float_ops::is_zero(beta))) { - ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input0, input2); - ARM_COMPUTE_RETURN_ERROR_ON_MSG(input2->num_dimensions() > 1, "input2 must be a 1D tensor"); + const unsigned int input2_dim0 = input2->dimension(0); + const unsigned int input2_dim1 = input2->dimension(1); + + ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input2, input1); + if(reshape_info.broadcast_bias()) + { + ARM_COMPUTE_RETURN_ERROR_ON_MSG((input2_dim1 != 1 || input2_dim0 != n), "Incorrect dimension of bias matrix which is to be broadcasted"); + } + else + { + ARM_COMPUTE_RETURN_ERROR_ON_MSG((input2_dim0 != n || input2_dim1 != m), "Incorrect dimension of bias matrix"); + } } } @@ -145,7 +162,6 @@ inline std::pair validate_and_configure_window(ITensorInfo *inpu unsigned int &num_elems_processed_per_iteration_y = num_elements_processed[1]; bool reinterpret_input_as_3d = reshape_info.reinterpret_input_as_3d(); bool reinterpret_output_as_3d = (reshape_info.depth_output_gemm3d() != 0); - const bool has_vec_c = input2 != nullptr && beta != 0.f; // In case both input and output have to be reinterpreted as 3D tensors, // force reinterpret_input_as_3d and reinterpret_output_as_3d to be false. @@ -194,12 +210,23 @@ inline std::pair validate_and_configure_window(ITensorInfo *inpu ceil_to_multiple(output->dimension(0), num_elems_processed_per_iteration_x), output->dimension(1) + bottom_pad); - window_changed = update_window_and_padding(win, input0_access, input1_access) || // window used by the execute_window_loop - update_window_and_padding(win_out, output_access); // window used to update the padding requirements of output tensor - if(has_vec_c) + if(input2 != nullptr) + { + const int bias_processed_per_iteration_x = num_elems_processed_per_iteration_x; + + const int bias_processed_per_iteration_y = reshape_info.broadcast_bias() ? 
1 : num_elems_processed_per_iteration_y; + + AccessWindowStatic input2_access(input2, 0, 0, + ceil_to_multiple(input2->dimension(0), bias_processed_per_iteration_x), + ceil_to_multiple(input2->dimension(1), bias_processed_per_iteration_y)); + + window_changed = update_window_and_padding(win, input0_access, input1_access, input2_access) || // window used by the execute_window_loop + update_window_and_padding(win_out, output_access); // window used to update the padding requirements of output tensor + } + else { - AccessWindowHorizontal input2_access(input2, 0, num_elems_processed_per_iteration_x); - window_changed = window_changed || update_window_and_padding(win, input2_access); + window_changed = update_window_and_padding(win, input0_access, input1_access) || // window used by the execute_window_loop + update_window_and_padding(win_out, output_access); // window used to update the padding requirements of output tensor } output_access.set_valid_region(win_out, ValidRegion(Coordinates(0, 0), output->tensor_shape())); @@ -232,12 +259,23 @@ inline std::pair validate_and_configure_window(ITensorInfo *inpu ceil_to_multiple(output->dimension(0), num_elems_processed_per_iteration_x), output->dimension(1) + bottom_pad); - window_changed = update_window_and_padding(win, input0_access, input1_access) || // window used by the execute_window_loop - update_window_and_padding(win_out, output_access); // window used to update the padding requirements of output tensor - if(has_vec_c) + if(input2 != nullptr) + { + const int bias_processed_per_iteration_x = num_elems_processed_per_iteration_x; + + const int bias_processed_per_iteration_y = reshape_info.broadcast_bias() ? 1 : num_elems_processed_per_iteration_y; + + AccessWindowStatic input2_access(input2, 0, 0, + ceil_to_multiple(input2->dimension(0), bias_processed_per_iteration_x), + ceil_to_multiple(input2->dimension(1), bias_processed_per_iteration_y)); + + window_changed = update_window_and_padding(win, input0_access, input1_access, input2_access) || // window used by the execute_window_loop + update_window_and_padding(win_out, output_access); // window used to update the padding requirements of output tensor + } + else { - AccessWindowHorizontal input2_access(input2, 0, num_elems_processed_per_iteration_x); - window_changed = window_changed || update_window_and_padding(win, input2_access); + window_changed = update_window_and_padding(win, input0_access, input1_access) || // window used by the execute_window_loop + update_window_and_padding(win_out, output_access); // window used to update the padding requirements of output tensor } Coordinates coord; @@ -257,12 +295,13 @@ inline std::pair validate_and_configure_window(ITensorInfo *inpu } // namespace CLGEMMMatrixMultiplyKernel::CLGEMMMatrixMultiplyKernel() - : _input0(nullptr), _input1(nullptr), _input2(nullptr), _output(nullptr), _slide_matrix_b(true), _reinterpret_input_as_3d(false), _reinterpret_output_as_3d(false), _has_vec_c(false) + : _input0(nullptr), _input1(nullptr), _input2(nullptr), _output(nullptr), _slide_matrix_b(true), _reinterpret_input_as_3d(false), _reinterpret_output_as_3d(false), _add_bias(false), + _broadcast_bias(false) { } void CLGEMMMatrixMultiplyKernel::configure(const ICLTensor *input0, const ICLTensor *input1, const ICLTensor *input2, ICLTensor *output, float alpha, float beta, - bool is_interleaved_transposed, const GEMMReshapeInfo &reshape_info, bool fp_mixed_precision) + bool is_interleaved_transposed, const GEMMReshapeInfo &reshape_info, bool fp_mixed_precision, const 
ActivationLayerInfo &activation_info) { ARM_COMPUTE_ERROR_ON_NULLPTR(input0, input1, output); @@ -272,10 +311,12 @@ void CLGEMMMatrixMultiplyKernel::configure(const ICLTensor *input0, const ICLTen _input0 = input0; _input1 = input1; - _input2 = input2; + _input2 = helpers::float_ops::is_zero(beta) ? nullptr : input2; _output = output; _reinterpret_input_as_3d = reshape_info.reinterpret_input_as_3d(); _reinterpret_output_as_3d = (reshape_info.depth_output_gemm3d() != 0); + _add_bias = _input2 != nullptr; + _broadcast_bias = reshape_info.broadcast_bias(); // In case both input and output have to be reinterpreted as 3D tensors, // force reinterpret_input_as_3d and reinterpret_output_as_3d to be false. @@ -306,23 +347,21 @@ void CLGEMMMatrixMultiplyKernel::configure(const ICLTensor *input0, const ICLTen // Create build options CLBuildOptions build_opts; - // Only define ALPHA when alpha is not 1.0f. This avoids performing unnecessary multiplications. - if(!(helpers::float_ops::is_one(alpha))) - { - build_opts.add_option("-DALPHA=" + float_to_string_with_full_precision(alpha)); - } + build_opts.add_option_if(!(helpers::float_ops::is_one(alpha)), "-DALPHA=" + float_to_string_with_full_precision(alpha)); + build_opts.add_option_if(_input2 != nullptr, "-DBETA=" + float_to_string_with_full_precision(beta)); + build_opts.add_option_if(helpers::float_ops::is_one(beta), "-DUNIT_BETA"); + build_opts.add_option_if(reshape_info.broadcast_bias(), "-DBROADCAST_BIAS"); build_opts.add_option_if(_reinterpret_input_as_3d, "-DREINTERPRET_INPUT_AS_3D"); build_opts.add_option_if(_reinterpret_output_as_3d, "-DREINTERPRET_OUTPUT_AS_3D"); build_opts.add_option_if(_reinterpret_input_as_3d || _reinterpret_output_as_3d, "-DHEIGHT_GEMM3D=" + support::cpp11::to_string(output->info()->dimension(1))); build_opts.add_option_if(_reinterpret_input_as_3d || _reinterpret_output_as_3d, "-DDEPTH_GEMM3D=" + support::cpp11::to_string(output->info()->dimension(2))); - - // Do not slide matrix B if _slide_matrix_b = false build_opts.add_option_if(!_slide_matrix_b, "-DMATRIX_B_DEPTH=" + support::cpp11::to_string(input1->info()->dimension(2))); + build_opts.add_option_if(activation_info.enabled(), "-DACTIVATION_TYPE=" + lower_string(string_from_activation_func(activation_info.activation()))); + build_opts.add_option_if(activation_info.enabled(), "-DA_VAL=" + float_to_string_with_full_precision(activation_info.a())); + build_opts.add_option_if(activation_info.enabled(), "-DB_VAL=" + float_to_string_with_full_precision(activation_info.b())); const bool is_bifrost = get_arch_from_target(gpu_target) == GPUTarget::BIFROST; - _has_vec_c = input2 != nullptr && beta != 0.f; - std::string kernel_name; if(is_interleaved_transposed) { @@ -386,15 +425,14 @@ void CLGEMMMatrixMultiplyKernel::configure(const ICLTensor *input0, const ICLTen build_opts.add_option("-DNUM_ELEMS_PROCESSED_PER_THREAD_X=" + support::cpp11::to_string(num_elements_processed.x())); } - // Configure matrix C addition if necessary - build_opts.add_option_if(_has_vec_c, "-DADD_VEC_C"); - // Create kernel _kernel = static_cast(CLKernelLibrary::get().create_kernel(kernel_name, build_opts.options())); // Set config_id for enabling LWS tuning _config_id = "gemm_"; _config_id += (is_interleaved_transposed ? "reshaped_" : ""); + _config_id += (_add_bias ? "add_bias_" : ""); + _config_id += (_broadcast_bias ? "broadcast_bias_" : ""); _config_id += (fp_mixed_precision ? "fp_mixed_" : ""); _config_id += (_reinterpret_input_as_3d ? "3di_" : ""); _config_id += (_reinterpret_output_as_3d ? 
"3do_" : ""); @@ -412,11 +450,12 @@ void CLGEMMMatrixMultiplyKernel::configure(const ICLTensor *input0, const ICLTen } Status CLGEMMMatrixMultiplyKernel::validate(const ITensorInfo *input0, const ITensorInfo *input1, const ITensorInfo *input2, const ITensorInfo *output, float alpha, float beta, - bool is_interleaved_transposed, const GEMMReshapeInfo &reshape_info, GPUTarget gpu_target, bool fp_mixed_precision) + bool is_interleaved_transposed, const GEMMReshapeInfo &reshape_info, GPUTarget gpu_target, bool fp_mixed_precision, const ActivationLayerInfo &activation_info) { // Note: num_elements_processed will be set in validate_and_configure_window() ElementsProcessed num_elements_processed{}; ARM_COMPUTE_UNUSED(alpha); + ARM_COMPUTE_UNUSED(activation_info); ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(input0, input1, input2, output, beta, is_interleaved_transposed, reshape_info, fp_mixed_precision)); ARM_COMPUTE_RETURN_ON_ERROR(validate_and_configure_window(input0->clone().get(), input1->clone().get(), @@ -449,12 +488,12 @@ void CLGEMMMatrixMultiplyKernel::run(const Window &window, cl::CommandQueue &que slice_matrix_b.set(Window::DimX, Window::Dimension(0, 1, 1)); slice_matrix_b.set(Window::DimY, Window::Dimension(0, 1, 1)); - const unsigned int num_arguments_vec_c = (_has_vec_c) ? num_arguments_per_1D_tensor() : 0; + const unsigned int num_arguments_bias = _add_bias ? num_arguments_per_2D_tensor() + 1 : 0; if(_reinterpret_input_as_3d) { // Pass bottom paddings to the kernel if the input has to be reinterpreted as 3D tensor - const unsigned int idx0 = 3 * num_arguments_per_2D_tensor() + 3 + num_arguments_vec_c; + const unsigned int idx0 = 3 * num_arguments_per_2D_tensor() + 3 + num_arguments_bias; const unsigned int total_cross_plane_pad = _input0->info()->padding().top + _input0->info()->padding().bottom; _kernel.setArg(idx0, static_cast(total_cross_plane_pad)); } @@ -462,7 +501,7 @@ void CLGEMMMatrixMultiplyKernel::run(const Window &window, cl::CommandQueue &que if(_reinterpret_output_as_3d) { // Pass bottom paddings to the kernel if the output has to be reinterpreted as 3D tensor - const unsigned int idx0 = 3 * num_arguments_per_2D_tensor() + 3 + (_reinterpret_input_as_3d ? 1 : 0) + num_arguments_vec_c; + const unsigned int idx0 = 3 * num_arguments_per_2D_tensor() + 3 + (_reinterpret_input_as_3d ? 
1 : 0) + num_arguments_bias; const unsigned int total_cross_plane_pad = _output->info()->padding().top + _output->info()->padding().bottom; _kernel.setArg(idx0, static_cast(total_cross_plane_pad)); } @@ -480,13 +519,17 @@ void CLGEMMMatrixMultiplyKernel::run(const Window &window, cl::CommandQueue &que unsigned int idx = 0; add_2D_tensor_argument(idx, _input0, slice); add_2D_tensor_argument(idx, _input1, slice_b); - if(_has_vec_c) + if(_add_bias) { - add_1D_tensor_argument(idx, _input2, slice); + add_2D_tensor_argument(idx, _input2, slice); } add_2D_tensor_argument(idx, _output, slice); _kernel.setArg(idx++, static_cast(_input0->info()->strides_in_bytes()[2])); _kernel.setArg(idx++, static_cast(_input1->info()->strides_in_bytes()[2])); + if(_add_bias) + { + _kernel.setArg(idx++, static_cast(_input2->info()->strides_in_bytes()[2])); + } _kernel.setArg(idx++, static_cast(_output->info()->strides_in_bytes()[2])); enqueue(queue, *this, slice, lws_hint()); } diff --git a/tests/CL/Helper.h b/tests/CL/Helper.h index ab2f8ccb22..0a4566be8d 100644 --- a/tests/CL/Helper.h +++ b/tests/CL/Helper.h @@ -53,6 +53,19 @@ public: k->configure(std::forward(args)...); _kernel = std::move(k); } + /** Configure the kernel setting the GPU target as well + * + * @param[in] gpu_target GPUTarget to set + * @param[in] args Configuration arguments. + */ + template + void configure(GPUTarget gpu_target, Args &&... args) + { + auto k = arm_compute::support::cpp14::make_unique(); + k->set_target(gpu_target); + k->configure(std::forward(args)...); + _kernel = std::move(k); + } /** Validate input arguments * * @param[in] args Configuration arguments. diff --git a/tests/validation/CL/GEMMMatrixMultiply.cpp b/tests/validation/CL/GEMMMatrixMultiply.cpp new file mode 100644 index 0000000000..21fd7125ec --- /dev/null +++ b/tests/validation/CL/GEMMMatrixMultiply.cpp @@ -0,0 +1,344 @@ +/* + * Copyright (c) 2019 ARM Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. 
+ */ +#include "arm_compute/core/CL/kernels/CLGEMMMatrixMultiplyKernel.h" +#include "arm_compute/core/KernelDescriptors.h" +#include "arm_compute/core/Types.h" +#include "arm_compute/core/utils/misc/ShapeCalculator.h" +#include "arm_compute/runtime/CL/CLTensor.h" +#include "arm_compute/runtime/CL/CLTensorAllocator.h" +#include "tests/CL/CLAccessor.h" +#include "tests/CL/Helper.h" +#include "tests/PaddingCalculator.h" +#include "tests/datasets/ShapeDatasets.h" +#include "tests/framework/Asserts.h" +#include "tests/framework/Macros.h" +#include "tests/framework/datasets/Datasets.h" +#include "tests/validation/Validation.h" +#include "tests/validation/fixtures/GEMMFixture.h" + +namespace arm_compute +{ +namespace test +{ +namespace validation +{ +using namespace arm_compute::misc::shape_calculator; + +// Create function for CLGEMMMatrixMultiplyKernel +using CLGEMMMatrixMultiplyNative = CLSynthetizeFunction; + +// Fixture for GEMMMatrixMultiplyValidationFixture +template +using CLGEMMMatrixMultiplyNativeFixture = GEMMMatrixMultiplyValidationFixture; + +// Fixture for GEMMMatrixMultiply3DValidationFixture +template +using CLGEMMMatrixMultiplyNative3DFixture = GEMMMatrixMultiply3DValidationFixture; + +namespace +{ +// *INDENT-OFF* +// clang-format off +RelativeTolerance rel_tolerance_f32(0.001f); +constexpr float abs_tolerance_f32(0.0001f); + +RelativeTolerance rel_tolerance_f16(half(0.2)); +constexpr float tolerance_num_f16 = 0.02f; + +/** Alpha values to test - Precommit */ +const auto alpha_values = framework::dataset::make("alpha", {0.0f, 1.0f, -0.75f} ); + +/** Beta values to test - Precommit */ +const auto beta_values = framework::dataset::make("beta", {-0.75f, 0.0f} ); + +/** M values to test - Precommit */ +const auto m_values_precommit = framework::dataset::make("M", {37, 1}); + +/** N values to test - Precommit */ +const auto n_values_precommit = framework::dataset::make("N", 51); + +/** K values to test - Precommit */ +const auto k_values_precommit = framework::dataset::make("K", 23); + +/** M values to test - Nightly */ +const auto m_values_nightly = framework::dataset::make("M", {421, 1}); + +/** N values to test - Nightly */ +const auto n_values_nightly = framework::dataset::make("N", {323, 1103}); + +/** K values to test - Nightly */ +const auto k_values_nightly = framework::dataset::make("K", 207); + +/** M_W values to test - Precommit */ +const auto m_w_values_precommit = framework::dataset::make("M_W", 5); + +/** M_H values to test - Precommit */ +const auto m_h_values_precommit = framework::dataset::make("M_H", 7); + +/** M_W values to test - Nightly */ +const auto m_w_values_nightly = framework::dataset::make("M_W", 13); + +/** M_H values to test - Nightly */ +const auto m_h_values_nightly = framework::dataset::make("M_H", 27); + +/** Batch size values to test */ +const auto b_values = framework::dataset::make("batch_size", 1, 3); + +/** Activation values to test */ +const auto act_values = framework::dataset::make("Activation", +{ + ActivationLayerInfo(), + ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU, 8.f, 2.f), +}); + +/** Broadcast bias from vector to matrix */ +const auto broadcast_bias_values = framework::dataset::make("broadcast_bias", {false, true} ); + +/** GPU architectures values to test */ +const auto gpu_arch_values = framework::dataset::make("GPUArch", +{ + GPUTarget::MIDGARD, + GPUTarget::BIFROST +}); + +/** Data types values to test in the configuration */ +const auto data_type_values = framework::dataset::make("DataType", +{ + 
DataType::F32, + DataType::F16 +}); + +/** M values to test */ +const auto fp16_mixed_precision_values = framework::dataset::make("fp16_mixed_precision", {true, false}); + +/** Configuration test */ +void validate_configuration(unsigned int m_value, unsigned int n_value, unsigned int k_value, unsigned int b_value, bool broadcast_bias, bool fp16_mixed_precision, const ActivationLayerInfo &act_info, DataType data_type, GPUTarget gpu_arch_value) +{ + GEMMReshapeInfo reshape_info(m_value, n_value, k_value, 1, 1, 0, false, broadcast_bias); + + const TensorShape lhs_shape(k_value, m_value, b_value); + const TensorShape rhs_shape(n_value, k_value, b_value); + + const TensorShape dst_shape = compute_mm_shape(TensorInfo(lhs_shape, 1, data_type), + TensorInfo(rhs_shape, 1, data_type), + reshape_info); + + const TensorShape bias_shape(n_value, + broadcast_bias? 1 : m_value, + broadcast_bias? 1 : b_value); + + // Create tensors + CLTensor lhs = create_tensor(lhs_shape, data_type); + CLTensor rhs = create_tensor(rhs_shape, data_type); + CLTensor bias = create_tensor(bias_shape, data_type); + CLTensor dst = create_tensor(dst_shape, data_type); + + ARM_COMPUTE_EXPECT(lhs.info()->is_resizable(), framework::LogLevel::ERRORS); + ARM_COMPUTE_EXPECT(rhs.info()->is_resizable(), framework::LogLevel::ERRORS); + ARM_COMPUTE_EXPECT(bias.info()->is_resizable(), framework::LogLevel::ERRORS); + ARM_COMPUTE_EXPECT(dst.info()->is_resizable(), framework::LogLevel::ERRORS); + + // Create and configure function + CLGEMMMatrixMultiplyNative gemm; + gemm.configure(gpu_arch_value, &lhs, &rhs, &bias, &dst, 1.0f, 2.0f, false, reshape_info, fp16_mixed_precision, act_info); +} +} // namespace + +TEST_SUITE(CL) +TEST_SUITE(GEMMMatrixMultiply) +TEST_SUITE(Float) +TEST_SUITE(FP32) +DATA_TEST_CASE(Configuration, framework::DatasetMode::ALL, combine(combine(combine(combine(combine(combine(combine(combine( + m_values_precommit, + n_values_precommit), + k_values_precommit), + framework::dataset::make("batch_size", 1)), + broadcast_bias_values), + framework::dataset::make("fp16_mixed_precision", false)), + act_values), + data_type_values), + gpu_arch_values), +m_value, n_value, k_value, b_value, broadcast_bias, fp16_mixed_precision_value, act_value, data_type_value, gpu_arch_value) +{ + validate_configuration(m_value, n_value, k_value, b_value, broadcast_bias, fp16_mixed_precision_value, act_value, data_type_value, gpu_arch_value); +} + +FIXTURE_DATA_TEST_CASE(RunSmall, CLGEMMMatrixMultiplyNativeFixture, framework::DatasetMode::ALL, + combine(combine(combine(combine(combine(combine(combine(combine(combine(combine( + m_values_precommit, + n_values_precommit), + k_values_precommit), + b_values), + alpha_values), + beta_values), + broadcast_bias_values), + framework::dataset::make("fp16_mixed_precision", false)), + act_values), + framework::dataset::make("DataType", DataType::F32)), + gpu_arch_values)) +{ + // Validate output + validate(CLAccessor(_target), _reference, rel_tolerance_f32, 0.f, abs_tolerance_f32); +} + +FIXTURE_DATA_TEST_CASE(RunLarge, CLGEMMMatrixMultiplyNativeFixture, framework::DatasetMode::NIGHTLY, + combine(combine(combine(combine(combine(combine(combine(combine(combine(combine( + m_values_nightly, + n_values_nightly), + k_values_nightly), + b_values), + alpha_values), + beta_values), + broadcast_bias_values), + framework::dataset::make("fp16_mixed_precision", false)), + act_values), + framework::dataset::make("DataType", DataType::F32)), + gpu_arch_values)) +{ + // Validate output + validate(CLAccessor(_target), 
_reference, rel_tolerance_f32, 0.f, abs_tolerance_f32); +} + +FIXTURE_DATA_TEST_CASE(RunSmall3D, CLGEMMMatrixMultiplyNative3DFixture, framework::DatasetMode::ALL, + combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine( + m_w_values_precommit, + m_h_values_precommit), + n_values_precommit), + k_values_precommit), + b_values), + alpha_values), + beta_values), + broadcast_bias_values), + framework::dataset::make("fp16_mixed_precision", false)), + act_values), + framework::dataset::make("DataType", DataType::F32)), + gpu_arch_values)) +{ + // Validate output + validate(CLAccessor(_target), _reference, rel_tolerance_f32, 0.f, abs_tolerance_f32); +} + +FIXTURE_DATA_TEST_CASE(RunLarge3D, CLGEMMMatrixMultiplyNative3DFixture, framework::DatasetMode::NIGHTLY, + combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine( + m_w_values_nightly, + m_h_values_nightly), + n_values_nightly), + k_values_nightly), + b_values), + alpha_values), + beta_values), + broadcast_bias_values), + framework::dataset::make("fp16_mixed_precision", false)), + act_values), + framework::dataset::make("DataType", DataType::F32)), + gpu_arch_values)) +{ + // Validate output + validate(CLAccessor(_target), _reference, rel_tolerance_f32, 0.f, abs_tolerance_f32); +} + +TEST_SUITE_END() // FP32 + +TEST_SUITE(FP16) +FIXTURE_DATA_TEST_CASE(RunSmall, CLGEMMMatrixMultiplyNativeFixture, framework::DatasetMode::ALL, + combine(combine(combine(combine(combine(combine(combine(combine(combine(combine( + m_values_precommit, + n_values_precommit), + k_values_precommit), + b_values), + alpha_values), + beta_values), + broadcast_bias_values), + fp16_mixed_precision_values), + act_values), + framework::dataset::make("DataType", DataType::F16)), + gpu_arch_values)) +{ + // Validate output + validate(CLAccessor(_target), _reference, rel_tolerance_f16, tolerance_num_f16); +} + +FIXTURE_DATA_TEST_CASE(RunLarge, CLGEMMMatrixMultiplyNativeFixture, framework::DatasetMode::NIGHTLY, + combine(combine(combine(combine(combine(combine(combine(combine(combine(combine( + m_values_nightly, + n_values_nightly), + k_values_nightly), + b_values), + alpha_values), + beta_values), + broadcast_bias_values), + fp16_mixed_precision_values), + act_values), + framework::dataset::make("DataType", DataType::F16)), + gpu_arch_values)) +{ + // Validate output + validate(CLAccessor(_target), _reference, rel_tolerance_f16, tolerance_num_f16); +} + +FIXTURE_DATA_TEST_CASE(RunSmall3D, CLGEMMMatrixMultiplyNative3DFixture, framework::DatasetMode::ALL, + combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine( + m_w_values_precommit, + m_h_values_precommit), + n_values_precommit), + k_values_precommit), + b_values), + alpha_values), + beta_values), + broadcast_bias_values), + fp16_mixed_precision_values), + act_values), + framework::dataset::make("DataType", DataType::F16)), + gpu_arch_values)) +{ + // Validate output + validate(CLAccessor(_target), _reference, rel_tolerance_f16, tolerance_num_f16); +} + +FIXTURE_DATA_TEST_CASE(RunLarge3D, CLGEMMMatrixMultiplyNative3DFixture, framework::DatasetMode::NIGHTLY, + combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine( + m_w_values_nightly, + m_h_values_nightly), + n_values_nightly), + k_values_nightly), + b_values), + alpha_values), + beta_values), + broadcast_bias_values), + fp16_mixed_precision_values), + act_values), + framework::dataset::make("DataType", DataType::F16)), + gpu_arch_values)) +{ + // 
Validate output + validate(CLAccessor(_target), _reference, rel_tolerance_f16, tolerance_num_f16); +} + +TEST_SUITE_END() // FP16 +TEST_SUITE_END() // Float +TEST_SUITE_END() // GEMMMatrixMuliplty +TEST_SUITE_END() // CL +} // namespace validation +} // namespace test +} // namespace arm_compute \ No newline at end of file diff --git a/tests/validation/CL/GEMMMatrixMultiplyInterleavedTransposed.cpp b/tests/validation/CL/GEMMMatrixMultiplyInterleavedTransposed.cpp new file mode 100644 index 0000000000..cae94b2e15 --- /dev/null +++ b/tests/validation/CL/GEMMMatrixMultiplyInterleavedTransposed.cpp @@ -0,0 +1,404 @@ +/* + * Copyright (c) 2019 ARM Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. 
+ */ +#include "arm_compute/core/CL/kernels/CLGEMMMatrixMultiplyKernel.h" +#include "arm_compute/core/CL/kernels/CLGEMMReshapeLHSMatrixKernel.h" +#include "arm_compute/core/CL/kernels/CLGEMMReshapeRHSMatrixKernel.h" +#include "arm_compute/core/KernelDescriptors.h" +#include "arm_compute/core/Types.h" +#include "arm_compute/core/utils/misc/ShapeCalculator.h" +#include "arm_compute/runtime/CL/CLTensor.h" +#include "arm_compute/runtime/CL/CLTensorAllocator.h" +#include "tests/CL/CLAccessor.h" +#include "tests/CL/Helper.h" +#include "tests/PaddingCalculator.h" +#include "tests/datasets/ShapeDatasets.h" +#include "tests/framework/Asserts.h" +#include "tests/framework/Macros.h" +#include "tests/framework/datasets/Datasets.h" +#include "tests/validation/Validation.h" +#include "tests/validation/fixtures/GEMMFixture.h" + +namespace arm_compute +{ +namespace test +{ +namespace validation +{ +using namespace arm_compute::misc::shape_calculator; + +// Create function for CLGEMMReshapeLHSMatrixKernel +using CLGEMMReshapeLHSMatrix = CLSynthetizeFunction; + +// Create function for CLGEMMReshapeRHSMatrixKernel +using CLGEMMReshapeRHSMatrix = CLSynthetizeFunction; + +// Create function for CLGEMMMatrixMultiplyKernel +using CLGEMMMatrixMultiplyReshaped = CLSynthetizeFunction; + +// Fixture for GEMMMatrixMultiplyInterleavedTransposedValidationFixture +template +using CLGEMMMatrixMultiplyReshapedFixture = + GEMMMatrixMultiplyInterleavedTransposedValidationFixture; + +// Fixture for GEMMMatrixMultiplyInterleavedTransposed3DValidationFixture +template +using CLGEMMMatrixMultiplyReshaped3DFixture = + GEMMMatrixMultiplyInterleavedTransposed3DValidationFixture; + +namespace +{ +// *INDENT-OFF* +// clang-format off +RelativeTolerance rel_tolerance_f32(0.001f); +constexpr float abs_tolerance_f32(0.0001f); + +RelativeTolerance rel_tolerance_f16(half(0.2)); +constexpr float tolerance_num_f16 = 0.02f; + +/** Alpha values to test - Precommit */ +const auto alpha_values = framework::dataset::make("alpha", {0.0f, 1.0f, -0.75f} ); + +/** Beta values to test - Precommit */ +const auto beta_values = framework::dataset::make("beta", {-0.75f, 0.0f} ); + +/** M values to test - Precommit */ +const auto m_values_precommit = framework::dataset::make("M", 37); + +/** N values to test - Precommit */ +const auto n_values_precommit = framework::dataset::make("N", 51); + +/** K values to test - Precommit */ +const auto k_values_precommit = framework::dataset::make("K", 23); + +/** M values to test - Nightly */ +const auto m_values_nightly = framework::dataset::make("M", {421, 1}); + +/** N values to test - Nightly */ +const auto n_values_nightly = framework::dataset::make("N", 323); + +/** K values to test - Nightly */ +const auto k_values_nightly = framework::dataset::make("K", 207); + +/** M_W values to test - Precommit */ +const auto m_w_values_precommit = framework::dataset::make("M_W", 5); + +/** M_H values to test - Precommit */ +const auto m_h_values_precommit = framework::dataset::make("M_H", 7); + +/** M_W values to test - Nightly */ +const auto m_w_values_nightly = framework::dataset::make("M_W", 13); + +/** M_H values to test - Nightly */ +const auto m_h_values_nightly = framework::dataset::make("M_H", 27); + +/** Batch size values to test */ +const auto b_values = framework::dataset::make("batch_size", 1, 3); + +/** Activation values to test */ +const auto act_values = framework::dataset::make("Activation", +{ + ActivationLayerInfo(), + ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU, 8.f, 
2.f), +}); + +/** V0 values to test - Precommit */ +const auto v0_values_precommit = framework::dataset::make("V0", 2); + +/** H0 values to test - Precommit */ +const auto h0_values_precommit = framework::dataset::make("H0", 4); + +/** V0 values to test - Nightly */ +const auto v0_values_nightly = framework::dataset::make("V0", {2, 4}); + +/** H0 values to test - Nightly */ +const auto h0_values_nightly = framework::dataset::make("H0", { 2, 4 }); + +/** Broadcast bias from vector to matrix */ +const auto broadcast_bias_values = framework::dataset::make("broadcast_bias", {false, true} ); + +/** GPU architectures values to test */ +const auto gpu_arch_values = framework::dataset::make("GPUArch", +{ + GPUTarget::MIDGARD, + GPUTarget::BIFROST +}); + +/** Data types values to test in the configuration */ +const auto data_type_values = framework::dataset::make("DataType", +{ + DataType::F32, + DataType::F16 +}); + +/** M values to test */ +const auto fp16_mixed_precision_values = framework::dataset::make("fp16_mixed_precision", {true, false}); + +/** Configuration test */ +void validate_configuration(unsigned int m_value, unsigned int n_value, unsigned int k_value, unsigned int b_value, unsigned int v0_value, unsigned int h0_value, bool broadcast_bias, bool fp16_mixed_precision, const ActivationLayerInfo &act_info, DataType data_type, GPUTarget gpu_arch_value) +{ + GEMMLHSMatrixInfo lhs_info; + lhs_info.m0 = 4; + lhs_info.k0 = 4; + lhs_info.v0 = v0_value; + lhs_info.interleave = true; + lhs_info.transpose = true; + + GEMMRHSMatrixInfo rhs_info; + rhs_info.n0 = data_type == DataType::F32? 4 : 8; + rhs_info.k0 = 1; + rhs_info.h0 = h0_value; + rhs_info.interleave = false; + rhs_info.transpose = false; + + GEMMReshapeInfo reshape_info(m_value, n_value, k_value, rhs_info.h0, lhs_info.v0, 0, false, broadcast_bias); + + const TensorShape lhs_shape(k_value, m_value, b_value); + const TensorShape lhs_shape_reshaped = compute_lhs_reshaped_shape(TensorInfo(lhs_shape, 1, data_type), + lhs_info, + false); + + const TensorShape rhs_shape(n_value, k_value, b_value); + const TensorShape rhs_shape_reshaped = compute_rhs_reshaped_shape(TensorInfo(rhs_shape, 1, data_type), + rhs_info); + + const TensorShape dst_shape = compute_mm_shape(TensorInfo(lhs_shape_reshaped, 1, data_type), + TensorInfo(rhs_shape_reshaped, 1, data_type), + reshape_info); + + const TensorShape bias_shape(n_value, + broadcast_bias? 1 : m_value, + broadcast_bias? 
1 : b_value); + + // Create tensors + CLTensor lhs_reshaped = create_tensor(lhs_shape_reshaped, data_type); + CLTensor rhs_reshaped = create_tensor(rhs_shape_reshaped, data_type); + CLTensor bias = create_tensor(bias_shape, data_type); + CLTensor dst = create_tensor(dst_shape, data_type); + + ARM_COMPUTE_EXPECT(lhs_reshaped.info()->is_resizable(), framework::LogLevel::ERRORS); + ARM_COMPUTE_EXPECT(rhs_reshaped.info()->is_resizable(), framework::LogLevel::ERRORS); + ARM_COMPUTE_EXPECT(bias.info()->is_resizable(), framework::LogLevel::ERRORS); + ARM_COMPUTE_EXPECT(dst.info()->is_resizable(), framework::LogLevel::ERRORS); + + // Create and configure function + CLGEMMMatrixMultiplyReshaped gemm; + gemm.configure(gpu_arch_value, &lhs_reshaped, &rhs_reshaped, &bias, &dst, 1.0f, 2.0f, true, reshape_info, fp16_mixed_precision, act_info); +} +} // namespace + +TEST_SUITE(CL) +TEST_SUITE(GEMMMatrixMultiplyInterleavedTransposed) +TEST_SUITE(Float) +TEST_SUITE(FP32) +DATA_TEST_CASE(Configuration, framework::DatasetMode::ALL, combine(combine(combine(combine(combine(combine(combine(combine(combine(combine( + m_values_precommit, + n_values_precommit), + k_values_precommit), + framework::dataset::make("batch_size", 1)), + v0_values_precommit), + h0_values_precommit), + broadcast_bias_values), + framework::dataset::make("fp16_mixed_precision", false)), + act_values), + data_type_values), + gpu_arch_values), +m_value, n_value, k_value, b_value, v0_value, h0_value, broadcast_bias, fp16_mixed_precision_value, act_value, data_type_value, gpu_arch_value) +{ + validate_configuration(m_value, n_value, k_value, b_value, v0_value, h0_value, broadcast_bias, fp16_mixed_precision_value, act_value, data_type_value, gpu_arch_value); +} + +FIXTURE_DATA_TEST_CASE(RunSmall, CLGEMMMatrixMultiplyReshapedFixture, framework::DatasetMode::ALL, + combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine( + m_values_precommit, + n_values_precommit), + k_values_precommit), + b_values), + alpha_values), + beta_values), + v0_values_precommit), + h0_values_precommit), + broadcast_bias_values), + framework::dataset::make("fp16_mixed_precision", false)), + act_values), + framework::dataset::make("DataType", DataType::F32)), + gpu_arch_values)) +{ + // Validate output + validate(CLAccessor(_target), _reference, rel_tolerance_f32, 0.f, abs_tolerance_f32); +} + +FIXTURE_DATA_TEST_CASE(RunLarge, CLGEMMMatrixMultiplyReshapedFixture, framework::DatasetMode::NIGHTLY, + combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine( + m_values_nightly, + n_values_nightly), + k_values_nightly), + b_values), + alpha_values), + beta_values), + v0_values_nightly), + h0_values_nightly), + broadcast_bias_values), + framework::dataset::make("fp16_mixed_precision", false)), + act_values), + framework::dataset::make("DataType", DataType::F32)), + gpu_arch_values)) +{ + // Validate output + validate(CLAccessor(_target), _reference, rel_tolerance_f32, 0.f, abs_tolerance_f32); +} + +FIXTURE_DATA_TEST_CASE(RunSmall3D, CLGEMMMatrixMultiplyReshaped3DFixture, framework::DatasetMode::ALL, + combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine( + m_w_values_precommit, + m_h_values_precommit), + n_values_precommit), + k_values_precommit), + b_values), + alpha_values), + beta_values), + v0_values_precommit), + h0_values_precommit), + broadcast_bias_values), + framework::dataset::make("fp16_mixed_precision", false)), + act_values), + 
framework::dataset::make("DataType", DataType::F32)), + gpu_arch_values)) +{ + // Validate output + validate(CLAccessor(_target), _reference, rel_tolerance_f32, 0.f, abs_tolerance_f32); +} + +FIXTURE_DATA_TEST_CASE(RunLarge3D, CLGEMMMatrixMultiplyReshaped3DFixture, framework::DatasetMode::NIGHTLY, + combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine( + m_w_values_nightly, + m_h_values_nightly), + n_values_nightly), + k_values_nightly), + b_values), + alpha_values), + beta_values), + v0_values_nightly), + h0_values_nightly), + broadcast_bias_values), + framework::dataset::make("fp16_mixed_precision", false)), + act_values), + framework::dataset::make("DataType", DataType::F32)), + gpu_arch_values)) +{ + // Validate output + validate(CLAccessor(_target), _reference, rel_tolerance_f32, 0.f, abs_tolerance_f32); +} + +TEST_SUITE_END() // FP32 + +TEST_SUITE(FP16) +FIXTURE_DATA_TEST_CASE(RunSmall, CLGEMMMatrixMultiplyReshapedFixture, framework::DatasetMode::ALL, + combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine( + m_values_precommit, + n_values_precommit), + k_values_precommit), + b_values), + alpha_values), + beta_values), + v0_values_precommit), + h0_values_precommit), + broadcast_bias_values), + fp16_mixed_precision_values), + act_values), + framework::dataset::make("DataType", DataType::F16)), + gpu_arch_values)) +{ + // Validate output + validate(CLAccessor(_target), _reference, rel_tolerance_f16, tolerance_num_f16); +} + +FIXTURE_DATA_TEST_CASE(RunLarge, CLGEMMMatrixMultiplyReshapedFixture, framework::DatasetMode::NIGHTLY, + combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine( + m_values_nightly, + n_values_nightly), + k_values_nightly), + b_values), + alpha_values), + beta_values), + v0_values_nightly), + h0_values_nightly), + broadcast_bias_values), + fp16_mixed_precision_values), + act_values), + framework::dataset::make("DataType", DataType::F16)), + gpu_arch_values)) +{ + // Validate output + validate(CLAccessor(_target), _reference, rel_tolerance_f16, tolerance_num_f16); +} + +FIXTURE_DATA_TEST_CASE(RunSmall3D, CLGEMMMatrixMultiplyReshaped3DFixture, framework::DatasetMode::ALL, + combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine( + m_w_values_precommit, + m_h_values_precommit), + n_values_precommit), + k_values_precommit), + b_values), + alpha_values), + beta_values), + v0_values_precommit), + h0_values_precommit), + broadcast_bias_values), + fp16_mixed_precision_values), + act_values), + framework::dataset::make("DataType", DataType::F16)), + gpu_arch_values)) +{ + // Validate output + validate(CLAccessor(_target), _reference, rel_tolerance_f16, tolerance_num_f16); +} + +FIXTURE_DATA_TEST_CASE(RunLarge3D, CLGEMMMatrixMultiplyReshaped3DFixture, framework::DatasetMode::NIGHTLY, + combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine( + m_w_values_nightly, + m_h_values_nightly), + n_values_nightly), + k_values_nightly), + b_values), + alpha_values), + beta_values), + v0_values_nightly), + h0_values_nightly), + broadcast_bias_values), + fp16_mixed_precision_values), + act_values), + framework::dataset::make("DataType", DataType::F16)), + gpu_arch_values)) +{ + // Validate output + validate(CLAccessor(_target), _reference, rel_tolerance_f16, tolerance_num_f16); +} + +TEST_SUITE_END() // FP16 +TEST_SUITE_END() // Float +TEST_SUITE_END() // 
GEMMMatrixMulipltyInterleavedTransposed +TEST_SUITE_END() // CL +} // namespace validation +} // namespace test +} // namespace arm_compute \ No newline at end of file diff --git a/tests/validation/CL/GEMMMatrixMultiplyReshaped.cpp b/tests/validation/CL/GEMMMatrixMultiplyReshaped.cpp index 99af2965d2..25221451ed 100644 --- a/tests/validation/CL/GEMMMatrixMultiplyReshaped.cpp +++ b/tests/validation/CL/GEMMMatrixMultiplyReshaped.cpp @@ -418,7 +418,7 @@ FIXTURE_DATA_TEST_CASE(RunLarge3D, CLGEMMMatrixMultiplyReshaped3DFixture, } TEST_SUITE_END() // FP16 TEST_SUITE_END() // Float -TEST_SUITE_END() // GEMMMatrixMulipltyReshaped +TEST_SUITE_END() // GEMMMatrixMultiplyReshaped TEST_SUITE_END() // CL } // namespace validation } // namespace test diff --git a/tests/validation/fixtures/GEMMFixture.h b/tests/validation/fixtures/GEMMFixture.h index ac8ab2a949..b36bb99246 100644 --- a/tests/validation/fixtures/GEMMFixture.h +++ b/tests/validation/fixtures/GEMMFixture.h @@ -153,6 +153,506 @@ protected: SimpleTensor _reference{}; }; +template +class GEMMMatrixMultiplyValidationFixture : public framework::Fixture +{ +public: + template + void setup(unsigned int m, unsigned int n, unsigned int k, unsigned int batch_size, float alpha, float beta, bool broadcast_bias, bool fp16_mixed_precision, const ActivationLayerInfo &act_info, + DataType data_type, GPUTarget gpu_arch) + { + // Set the tensor shapes for LHS and RHS matrices + const TensorShape lhs_shape(k, m, batch_size); + const TensorShape rhs_shape(n, k, batch_size); + const TensorShape bias_shape(n, + broadcast_bias ? 1 : m, + broadcast_bias ? 1 : batch_size); + + _target = compute_target(lhs_shape, rhs_shape, bias_shape, data_type, alpha, beta, broadcast_bias, fp16_mixed_precision, act_info, gpu_arch); + _reference = compute_reference(lhs_shape, rhs_shape, bias_shape, data_type, alpha, beta, broadcast_bias, act_info); + } + +protected: + template + void fill(U &&tensor, int i) + { + std::uniform_real_distribution<> distribution(-1.0f, 1.0f); + library->fill(tensor, distribution, i); + + // Fill border with infinity in order to check the presence of NaN values (i.e. 
inf * 0) + std::uniform_real_distribution<> distribution_inf(std::numeric_limits::infinity(), std::numeric_limits::infinity()); + library->fill_borders_with_garbage(tensor, distribution_inf, i); + } + + TensorType compute_target(const TensorShape &lhs_shape, const TensorShape &rhs_shape, const TensorShape &bias_shape, DataType data_type, float alpha, float beta, bool broadcast_bias, + bool fp16_mixed_precision, const ActivationLayerInfo &act_info, GPUTarget gpu_arch) + { + // Create tensors + TensorType lhs = create_tensor(lhs_shape, data_type, 1); + TensorType rhs = create_tensor(rhs_shape, data_type, 1); + TensorType bias = create_tensor(bias_shape, data_type, 1); + TensorType dst; + + const unsigned int m = lhs_shape[1]; + const unsigned int n = rhs_shape[0]; + const unsigned int k = lhs_shape[0]; + GEMMReshapeInfo reshape_info(m, n, k, 1, 1, 0, false, broadcast_bias); + + // The output tensor will be auto-initialized within the function + + // Create and configure function + GEMMFunctionType gemm; + gemm.configure(gpu_arch, &lhs, &rhs, &bias, &dst, alpha, beta, false, reshape_info, fp16_mixed_precision, act_info); + + ARM_COMPUTE_EXPECT(lhs.info()->is_resizable(), framework::LogLevel::ERRORS); + ARM_COMPUTE_EXPECT(rhs.info()->is_resizable(), framework::LogLevel::ERRORS); + ARM_COMPUTE_EXPECT(bias.info()->is_resizable(), framework::LogLevel::ERRORS); + + // Allocate tensors + lhs.allocator()->allocate(); + rhs.allocator()->allocate(); + bias.allocator()->allocate(); + dst.allocator()->allocate(); + + ARM_COMPUTE_EXPECT(!lhs.info()->is_resizable(), framework::LogLevel::ERRORS); + ARM_COMPUTE_EXPECT(!rhs.info()->is_resizable(), framework::LogLevel::ERRORS); + ARM_COMPUTE_EXPECT(!bias.info()->is_resizable(), framework::LogLevel::ERRORS); + ARM_COMPUTE_EXPECT(!dst.info()->is_resizable(), framework::LogLevel::ERRORS); + + // Fill tensors + fill(AccessorType(lhs), 0); + fill(AccessorType(rhs), 1); + fill(AccessorType(bias), 2); + + // Compute GEMM + gemm.run(); + + return dst; + } + + SimpleTensor compute_reference(const TensorShape &lhs_shape, const TensorShape &rhs_shape, const TensorShape &bias_shape, DataType data_type, float alpha, float beta, bool broadcast_bias, + const ActivationLayerInfo &act_info) + { + TensorShape dst_shape = lhs_shape; + dst_shape[0] = rhs_shape[0]; + dst_shape[1] = lhs_shape[1]; + + // Create reference + SimpleTensor lhs{ lhs_shape, data_type, 1 }; + SimpleTensor rhs{ rhs_shape, data_type, 1 }; + SimpleTensor bias{ dst_shape, data_type, 1 }; + + const int n = rhs_shape[0]; + const int m = lhs_shape[1]; + const int batch_size = lhs_shape[2]; + + // Fill reference + fill(lhs, 0); + fill(rhs, 1); + fill(bias, 2); + + if(broadcast_bias) + { + // In case of broadcast, we need simply copy the first into the following "M" ones + for(int i = 1; i < m * batch_size; i++) + { + memcpy(bias.data() + i * n, bias.data(), n * sizeof(T)); + } + } + + return reference::activation_layer(reference::gemm(lhs, rhs, bias, alpha, beta), act_info); + } + + TensorType _target{}; + SimpleTensor _reference{}; +}; + +template +class GEMMMatrixMultiply3DValidationFixture : public framework::Fixture +{ +public: + template + void setup(unsigned int m_w, unsigned int m_h, unsigned int n, unsigned int k, unsigned int batch_size, float alpha, float beta, bool broadcast_bias, bool fp16_mixed_precision, + const ActivationLayerInfo &act_info, DataType data_type, GPUTarget gpu_arch) + { + // In case of GEMM3D, m is the product between m_w and m_h + const unsigned int m = m_w * m_h; + + // Set the 
tensor shapes for LHS and RHS matrices + const TensorShape lhs_shape(k, m, batch_size); + const TensorShape rhs_shape(n, k, batch_size); + const TensorShape bias_shape(n, 1, 1); + + _target = compute_target(lhs_shape, rhs_shape, bias_shape, data_type, alpha, beta, m_h, fp16_mixed_precision, act_info, gpu_arch); + _reference = compute_reference(lhs_shape, rhs_shape, bias_shape, data_type, alpha, beta, m_h, act_info); + } + +protected: + template + void fill(U &&tensor, int i) + { + std::uniform_real_distribution<> distribution(-1.0f, 1.0f); + library->fill(tensor, distribution, i); + } + + TensorType compute_target(const TensorShape &lhs_shape, const TensorShape &rhs_shape, const TensorShape &bias_shape, DataType data_type, float alpha, float beta, unsigned int m_h, + bool fp16_mixed_precision, const ActivationLayerInfo &act_info, GPUTarget gpu_arch) + { + // Create tensors + TensorType lhs = create_tensor(lhs_shape, data_type, 1); + TensorType rhs = create_tensor(rhs_shape, data_type, 1); + TensorType bias = create_tensor(bias_shape, data_type, 1); + TensorType dst; + + const unsigned int m = lhs_shape[1]; + const unsigned int n = rhs_shape[0]; + const unsigned int k = lhs_shape[0]; + GEMMReshapeInfo reshape_info(m, n, k, 1, 1, m_h, false, true); + + // The output tensor will be auto-initialized within the function + + // Create and configure function + GEMMFunctionType gemm; + gemm.configure(gpu_arch, &lhs, &rhs, &bias, &dst, alpha, beta, false, reshape_info, fp16_mixed_precision, act_info); + + ARM_COMPUTE_EXPECT(lhs.info()->is_resizable(), framework::LogLevel::ERRORS); + ARM_COMPUTE_EXPECT(rhs.info()->is_resizable(), framework::LogLevel::ERRORS); + ARM_COMPUTE_EXPECT(bias.info()->is_resizable(), framework::LogLevel::ERRORS); + + // Allocate tensors + lhs.allocator()->allocate(); + rhs.allocator()->allocate(); + bias.allocator()->allocate(); + dst.allocator()->allocate(); + + ARM_COMPUTE_EXPECT(!lhs.info()->is_resizable(), framework::LogLevel::ERRORS); + ARM_COMPUTE_EXPECT(!rhs.info()->is_resizable(), framework::LogLevel::ERRORS); + ARM_COMPUTE_EXPECT(!bias.info()->is_resizable(), framework::LogLevel::ERRORS); + ARM_COMPUTE_EXPECT(!dst.info()->is_resizable(), framework::LogLevel::ERRORS); + + // Fill tensors + fill(AccessorType(lhs), 0); + fill(AccessorType(rhs), 1); + fill(AccessorType(bias), 2); + + // Compute GEMM + gemm.run(); + + return dst; + } + + SimpleTensor compute_reference(const TensorShape &lhs_shape, const TensorShape &rhs_shape, const TensorShape &bias_shape, DataType data_type, float alpha, float beta, unsigned int m_h, + const ActivationLayerInfo &act_info) + { + TensorShape dst_shape = lhs_shape; + dst_shape.set(0, rhs_shape[0]); + dst_shape.set(1, lhs_shape[1] / m_h); + dst_shape.set(2, m_h); + dst_shape.set(3, lhs_shape[2]); + + // Create reference + SimpleTensor lhs{ lhs_shape, data_type, 1 }; + SimpleTensor rhs{ rhs_shape, data_type, 1 }; + SimpleTensor bias{ dst_shape, data_type, 1 }; + + const int n = rhs_shape[0]; + const int m = lhs_shape[1]; + const int batch_size = lhs_shape[2]; + + // Fill reference + fill(lhs, 0); + fill(rhs, 1); + fill(bias, 2); + + // In case of broadcast, we need simply copy the first into the following "M" ones + for(int i = 1; i < m * batch_size; i++) + { + memcpy(bias.data() + i * n, bias.data(), n * sizeof(T)); + } + + return reference::activation_layer(reference::gemm(lhs, rhs, bias, alpha, beta), act_info); + } + + TensorType _target{}; + SimpleTensor _reference{}; +}; + +template +class 
GEMMMatrixMultiplyInterleavedTransposedValidationFixture : public framework::Fixture +{ +public: + template + void setup(unsigned int m, unsigned int n, unsigned int k, unsigned int batch_size, float alpha, float beta, unsigned int v0, unsigned int h0, bool broadcast_bias, bool fp16_mixed_precision, + const ActivationLayerInfo &act_info, DataType data_type, GPUTarget gpu_arch) + { + GEMMLHSMatrixInfo lhs_info; + lhs_info.m0 = 4; + lhs_info.k0 = 4; + lhs_info.v0 = v0; + lhs_info.interleave = true; + lhs_info.transpose = true; + + GEMMRHSMatrixInfo rhs_info; + rhs_info.n0 = 16 / sizeof(T); + rhs_info.k0 = 1; + rhs_info.h0 = h0; + rhs_info.interleave = false; + rhs_info.transpose = false; + + // Set the tensor shapes for LHS and RHS matrices + const TensorShape lhs_shape(k, m, batch_size); + const TensorShape rhs_shape(n, k, batch_size); + const TensorShape bias_shape(n, + broadcast_bias ? 1 : m, + broadcast_bias ? 1 : batch_size); + + _target = compute_target(lhs_shape, rhs_shape, bias_shape, lhs_info, rhs_info, data_type, alpha, beta, broadcast_bias, fp16_mixed_precision, act_info, gpu_arch); + _reference = compute_reference(lhs_shape, rhs_shape, bias_shape, data_type, alpha, beta, broadcast_bias, act_info); + } + +protected: + template + void fill(U &&tensor, int i) + { + std::uniform_real_distribution<> distribution(-1.0f, 1.0f); + library->fill(tensor, distribution, i); + + // Fill border with infinity in order to check the presence of NaN values (i.e. inf * 0) + std::uniform_real_distribution<> distribution_inf(std::numeric_limits::infinity(), std::numeric_limits::infinity()); + library->fill_borders_with_garbage(tensor, distribution_inf, i); + } + + TensorType compute_target(const TensorShape &lhs_shape, const TensorShape &rhs_shape, const TensorShape &bias_shape, const GEMMLHSMatrixInfo &lhs_info, const GEMMRHSMatrixInfo &rhs_info, + DataType data_type, float alpha, float beta, bool broadcast_bias, bool fp16_mixed_precision, const ActivationLayerInfo &act_info, GPUTarget gpu_arch) + { + // Create tensors + TensorType lhs = create_tensor(lhs_shape, data_type, 1); + TensorType rhs = create_tensor(rhs_shape, data_type, 1); + TensorType bias = create_tensor(bias_shape, data_type, 1); + TensorType lhs_reshaped; + TensorType rhs_reshaped; + TensorType dst; + + const unsigned int m = lhs_shape[1]; + const unsigned int n = rhs_shape[0]; + const unsigned int k = lhs_shape[0]; + GEMMReshapeInfo reshape_info(m, n, k, rhs_info.h0, lhs_info.v0, 0, false, broadcast_bias); + + // The output tensor will be auto-initialized within the function + + // Create and configure function + ReshapeLHSFunctionType reshape_lhs; + ReshapeRHSFunctionType reshape_rhs; + GEMMFunctionType gemm; + reshape_lhs.configure(&lhs, &lhs_reshaped, lhs_info); + reshape_rhs.configure(&rhs, &rhs_reshaped, rhs_info); + gemm.configure(gpu_arch, &lhs_reshaped, &rhs_reshaped, &bias, &dst, alpha, beta, true, reshape_info, fp16_mixed_precision, act_info); + + ARM_COMPUTE_EXPECT(lhs.info()->is_resizable(), framework::LogLevel::ERRORS); + ARM_COMPUTE_EXPECT(rhs.info()->is_resizable(), framework::LogLevel::ERRORS); + ARM_COMPUTE_EXPECT(bias.info()->is_resizable(), framework::LogLevel::ERRORS); + + // Allocate tensors + lhs.allocator()->allocate(); + rhs.allocator()->allocate(); + lhs_reshaped.allocator()->allocate(); + rhs_reshaped.allocator()->allocate(); + bias.allocator()->allocate(); + dst.allocator()->allocate(); + + ARM_COMPUTE_EXPECT(!lhs.info()->is_resizable(), framework::LogLevel::ERRORS); + 
ARM_COMPUTE_EXPECT(!rhs.info()->is_resizable(), framework::LogLevel::ERRORS); + ARM_COMPUTE_EXPECT(!bias.info()->is_resizable(), framework::LogLevel::ERRORS); + ARM_COMPUTE_EXPECT(!lhs_reshaped.info()->is_resizable(), framework::LogLevel::ERRORS); + ARM_COMPUTE_EXPECT(!rhs_reshaped.info()->is_resizable(), framework::LogLevel::ERRORS); + ARM_COMPUTE_EXPECT(!dst.info()->is_resizable(), framework::LogLevel::ERRORS); + + // Fill tensors + fill(AccessorType(lhs), 0); + fill(AccessorType(rhs), 1); + fill(AccessorType(bias), 2); + + // Compute GEMM + reshape_lhs.run(); + reshape_rhs.run(); + gemm.run(); + + return dst; + } + + SimpleTensor compute_reference(const TensorShape &lhs_shape, const TensorShape &rhs_shape, const TensorShape &bias_shape, DataType data_type, float alpha, float beta, bool broadcast_bias, + const ActivationLayerInfo &act_info) + { + TensorShape dst_shape = lhs_shape; + dst_shape[0] = rhs_shape[0]; + dst_shape[1] = lhs_shape[1]; + + // Create reference + SimpleTensor lhs{ lhs_shape, data_type, 1 }; + SimpleTensor rhs{ rhs_shape, data_type, 1 }; + SimpleTensor bias{ dst_shape, data_type, 1 }; + + const int n = rhs_shape[0]; + const int m = lhs_shape[1]; + const int batch_size = lhs_shape[2]; + + // Fill reference + fill(lhs, 0); + fill(rhs, 1); + fill(bias, 2); + + if(broadcast_bias) + { + // In case of broadcast, we need simply copy the first into the following "M" ones + for(int i = 1; i < m * batch_size; i++) + { + memcpy(bias.data() + i * n, bias.data(), n * sizeof(T)); + } + } + + return reference::activation_layer(reference::gemm(lhs, rhs, bias, alpha, beta), act_info); + } + + TensorType _target{}; + SimpleTensor _reference{}; +}; + +template +class GEMMMatrixMultiplyInterleavedTransposed3DValidationFixture : public framework::Fixture +{ +public: + template + void setup(unsigned int m_w, unsigned int m_h, unsigned int n, unsigned int k, unsigned int batch_size, float alpha, float beta, unsigned int v0, unsigned int h0, bool broadcast_bias, + bool fp16_mixed_precision, const ActivationLayerInfo &act_info, DataType data_type, GPUTarget gpu_arch) + { + GEMMLHSMatrixInfo lhs_info; + lhs_info.m0 = 4; + lhs_info.k0 = 4; + lhs_info.v0 = v0; + lhs_info.interleave = true; + lhs_info.transpose = true; + + GEMMRHSMatrixInfo rhs_info; + rhs_info.n0 = 16 / sizeof(T); + rhs_info.k0 = 1; + rhs_info.h0 = h0; + rhs_info.interleave = false; + rhs_info.transpose = false; + + // In case of GEMM3D, m is the product between m_w and m_h + const unsigned int m = m_w * m_h; + + // Set the tensor shapes for LHS and RHS matrices + const TensorShape lhs_shape(k, m, batch_size); + const TensorShape rhs_shape(n, k, batch_size); + const TensorShape bias_shape(n, 1, 1); + + _target = compute_target(lhs_shape, rhs_shape, bias_shape, lhs_info, rhs_info, data_type, alpha, beta, m_h, fp16_mixed_precision, act_info, gpu_arch); + _reference = compute_reference(lhs_shape, rhs_shape, bias_shape, data_type, alpha, beta, m_h, act_info); + } + +protected: + template + void fill(U &&tensor, int i) + { + std::uniform_real_distribution<> distribution(-1.0f, 1.0f); + library->fill(tensor, distribution, i); + } + + TensorType compute_target(const TensorShape &lhs_shape, const TensorShape &rhs_shape, const TensorShape &bias_shape, const GEMMLHSMatrixInfo &lhs_info, const GEMMRHSMatrixInfo &rhs_info, + DataType data_type, float alpha, float beta, unsigned int m_h, bool fp16_mixed_precision, const ActivationLayerInfo &act_info, GPUTarget gpu_arch) + { + // Create tensors + TensorType lhs = create_tensor(lhs_shape, 
data_type, 1); + TensorType rhs = create_tensor(rhs_shape, data_type, 1); + TensorType bias = create_tensor(bias_shape, data_type, 1); + TensorType lhs_reshaped; + TensorType rhs_reshaped; + TensorType dst; + + const unsigned int m = lhs_shape[1]; + const unsigned int n = rhs_shape[0]; + const unsigned int k = lhs_shape[0]; + GEMMReshapeInfo reshape_info(m, n, k, rhs_info.h0, lhs_info.v0, m_h, false, true); + + // The output tensor will be auto-initialized within the function + + // Create and configure function + ReshapeLHSFunctionType reshape_lhs; + ReshapeRHSFunctionType reshape_rhs; + GEMMFunctionType gemm; + reshape_lhs.configure(&lhs, &lhs_reshaped, lhs_info); + reshape_rhs.configure(&rhs, &rhs_reshaped, rhs_info); + gemm.configure(gpu_arch, &lhs_reshaped, &rhs_reshaped, &bias, &dst, alpha, beta, true, reshape_info, fp16_mixed_precision, act_info); + + ARM_COMPUTE_EXPECT(lhs.info()->is_resizable(), framework::LogLevel::ERRORS); + ARM_COMPUTE_EXPECT(rhs.info()->is_resizable(), framework::LogLevel::ERRORS); + ARM_COMPUTE_EXPECT(bias.info()->is_resizable(), framework::LogLevel::ERRORS); + + // Allocate tensors + lhs.allocator()->allocate(); + rhs.allocator()->allocate(); + lhs_reshaped.allocator()->allocate(); + rhs_reshaped.allocator()->allocate(); + bias.allocator()->allocate(); + dst.allocator()->allocate(); + + ARM_COMPUTE_EXPECT(!lhs.info()->is_resizable(), framework::LogLevel::ERRORS); + ARM_COMPUTE_EXPECT(!rhs.info()->is_resizable(), framework::LogLevel::ERRORS); + ARM_COMPUTE_EXPECT(!lhs_reshaped.info()->is_resizable(), framework::LogLevel::ERRORS); + ARM_COMPUTE_EXPECT(!rhs_reshaped.info()->is_resizable(), framework::LogLevel::ERRORS); + ARM_COMPUTE_EXPECT(!bias.info()->is_resizable(), framework::LogLevel::ERRORS); + ARM_COMPUTE_EXPECT(!dst.info()->is_resizable(), framework::LogLevel::ERRORS); + + // Fill tensors + fill(AccessorType(lhs), 0); + fill(AccessorType(rhs), 1); + fill(AccessorType(bias), 2); + + // Compute GEMM + reshape_lhs.run(); + reshape_rhs.run(); + gemm.run(); + + return dst; + } + + SimpleTensor compute_reference(const TensorShape &lhs_shape, const TensorShape &rhs_shape, const TensorShape &bias_shape, DataType data_type, float alpha, float beta, unsigned int m_h, + const ActivationLayerInfo &act_info) + { + TensorShape dst_shape = lhs_shape; + dst_shape.set(0, rhs_shape[0]); + dst_shape.set(1, lhs_shape[1] / m_h); + dst_shape.set(2, m_h); + dst_shape.set(3, lhs_shape[2]); + + // Create reference + SimpleTensor lhs{ lhs_shape, data_type, 1 }; + SimpleTensor rhs{ rhs_shape, data_type, 1 }; + SimpleTensor bias{ dst_shape, data_type, 1 }; + + const int n = rhs_shape[0]; + const int m = lhs_shape[1]; + const int batch_size = lhs_shape[2]; + + // Fill reference + fill(lhs, 0); + fill(rhs, 1); + fill(bias, 2); + + // In case of broadcast, we need simply copy the first into the following "M" ones + for(int i = 1; i < m * batch_size; i++) + { + memcpy(bias.data() + i * n, bias.data(), n * sizeof(T)); + } + + return reference::activation_layer(reference::gemm(lhs, rhs, bias, alpha, beta), act_info); + } + + TensorType _target{}; + SimpleTensor _reference{}; +}; + template class GEMMMatrixMultiplyReshapedValidationFixture : public framework::Fixture { -- cgit v1.2.1
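
For reference, a minimal usage sketch of the fused beta*bias + activation path added by this change. It is not part of the patch: only the CLGEMMMatrixMultiplyKernel::configure() signature, the GEMMReshapeInfo broadcast_bias flag and the ActivationLayerInfo constructor mirror what the diff above shows; the tensor sizes, the LU_BOUNDED_RELU parameters and the helper name example_fused_gemm_configure are illustrative assumptions.

#include "arm_compute/core/CL/kernels/CLGEMMMatrixMultiplyKernel.h"
#include "arm_compute/core/TensorInfo.h"
#include "arm_compute/core/Types.h"
#include "arm_compute/runtime/CL/CLScheduler.h"
#include "arm_compute/runtime/CL/CLTensor.h"

using namespace arm_compute;

// Illustrative helper (hypothetical, not from the patch): configure the non-reshaped
// (is_interleaved_transposed = false) kernel with a broadcast bias and a fused bounded ReLU.
void example_fused_gemm_configure()
{
    CLScheduler::get().default_init();

    const unsigned int m = 37, n = 51, k = 23;

    CLTensor a, b, bias, dst;
    a.allocator()->init(TensorInfo(TensorShape(k, m), 1, DataType::F32));    // Matrix A: K x M
    b.allocator()->init(TensorInfo(TensorShape(n, k), 1, DataType::F32));    // Matrix B: N x K
    bias.allocator()->init(TensorInfo(TensorShape(n, 1), 1, DataType::F32)); // 1 x N bias row, broadcast over M
    dst.allocator()->init(TensorInfo(TensorShape(n, m), 1, DataType::F32));  // Output: N x M

    // The last argument enables broadcasting of the bias row (adds the BROADCAST_BIAS build option)
    GEMMReshapeInfo reshape_info(m, n, k, 1, 1, 0, false, true);

    // A non-zero beta keeps input2 and defines -DBETA; the activation info adds
    // -DACTIVATION_TYPE, -DA_VAL and -DB_VAL so the activation is fused in the kernel
    CLGEMMMatrixMultiplyKernel mm_kernel;
    mm_kernel.set_target(CLScheduler::get().target());
    mm_kernel.configure(&a, &b, &bias, &dst,
                        1.0f,  // alpha
                        1.0f,  // beta
                        false, // is_interleaved_transposed
                        reshape_info,
                        false, // fp_mixed_precision
                        ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU, 6.f, 0.f));

    a.allocator()->allocate();
    b.allocator()->allocate();
    bias.allocator()->allocate();
    dst.allocator()->allocate();

    // ... fill a, b and bias, then run the kernel:
    CLScheduler::get().enqueue(mm_kernel);
}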