From 4bfc70e31766587c951204c93a127a486e007d0c Mon Sep 17 00:00:00 2001 From: Gunes Bayir Date: Fri, 10 Dec 2021 16:17:56 +0000 Subject: Add Gemm MMUL Reshaped Only Rhs Support for FP32/FP16 This patch introduces a GEMM routine that is optimized for Arm(R) Mali(TM)-G715 and Arm(R) Mali(TM)-G615 Resolves: COMPMID-5216 Signed-off-by: Gunes Bayir Change-Id: I2e5d7806f5904347185bb3e250f73d73d6669dba Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/7914 Reviewed-by: SiCong Li Tested-by: Arm Jenkins Comments-Addressed: Arm Jenkins Benchmark: Arm Jenkins --- Android.bp | 2 + SConscript | 1 + arm_compute/core/CL/CLHelpers.h | 7 + arm_compute/core/GPUTarget.h | 4 +- arm_compute/runtime/CL/CLTypes.h | 6 +- filelist.json | 1 + src/core/CL/CLHelpers.cpp | 4 + .../common/gemm_reshaped_only_rhs_mmul.cl | 528 +++++++++++++++++++++ src/core/CL/cl_kernels/tile_helpers.h | 55 ++- src/core/GPUTarget.cpp | 12 +- src/gpu/cl/ClKernelLibrary.cpp | 6 + ...GemmMatrixMultiplyReshapedOnlyRhsMMULKernel.cpp | 365 ++++++++++++++ ...ClGemmMatrixMultiplyReshapedOnlyRhsMMULKernel.h | 89 ++++ src/gpu/cl/kernels/gemm/ClGemmHelpers.cpp | 19 +- src/gpu/cl/kernels/gemm/ClGemmHelpers.h | 17 +- .../ClGemmDefaultConfigReshapedRhsOnlyValhall.cpp | 42 +- .../ClGemmDefaultConfigReshapedRhsOnlyValhall.h | 4 +- src/gpu/cl/operators/ClGemm.cpp | 134 ++++++ src/gpu/cl/operators/ClGemm.h | 29 +- src/runtime/CL/functions/CLGEMM.cpp | 3 +- src/runtime/CL/gemm/CLGEMMDefaultTypeValhall.cpp | 61 ++- src/runtime/CL/gemm/CLGEMMDefaultTypeValhall.h | 4 +- .../CL/GEMMMatrixMultiplyReshapedOnlyRhsMMUL.cpp | 231 +++++++++ tests/validation/fixtures/GEMMFixture.h | 197 +++++++- utils/TypePrinter.h | 6 + 25 files changed, 1777 insertions(+), 50 deletions(-) create mode 100644 src/core/CL/cl_kernels/common/gemm_reshaped_only_rhs_mmul.cl create mode 100644 src/gpu/cl/kernels/ClGemmMatrixMultiplyReshapedOnlyRhsMMULKernel.cpp create mode 100644 src/gpu/cl/kernels/ClGemmMatrixMultiplyReshapedOnlyRhsMMULKernel.h create mode 100644 tests/validation/CL/GEMMMatrixMultiplyReshapedOnlyRhsMMUL.cpp diff --git a/Android.bp b/Android.bp index 2590469673..16e67ad893 100644 --- a/Android.bp +++ b/Android.bp @@ -40,6 +40,7 @@ opencl_srcs = [ "src/core/CL/cl_kernels/common/floor.cl", "src/core/CL/cl_kernels/common/gather.cl", "src/core/CL/cl_kernels/common/gemm.cl", + "src/core/CL/cl_kernels/common/gemm_reshaped_only_rhs_mmul.cl", "src/core/CL/cl_kernels/common/gemm_utils.cl", "src/core/CL/cl_kernels/common/gemmlowp.cl", "src/core/CL/cl_kernels/common/gemv.cl", @@ -617,6 +618,7 @@ cc_library_static { "src/gpu/cl/kernels/ClGemmMatrixMultiplyNativeKernel.cpp", "src/gpu/cl/kernels/ClGemmMatrixMultiplyReshapedKernel.cpp", "src/gpu/cl/kernels/ClGemmMatrixMultiplyReshapedOnlyRhsKernel.cpp", + "src/gpu/cl/kernels/ClGemmMatrixMultiplyReshapedOnlyRhsMMULKernel.cpp", "src/gpu/cl/kernels/ClGemmReshapeLhsMatrixKernel.cpp", "src/gpu/cl/kernels/ClGemmReshapeRhsMatrixKernel.cpp", "src/gpu/cl/kernels/ClHeightConcatenateKernel.cpp", diff --git a/SConscript b/SConscript index 358f9dd971..6f6b078b63 100644 --- a/SConscript +++ b/SConscript @@ -369,6 +369,7 @@ if env['opencl'] and env['embed_kernels']: 'src/core/CL/cl_kernels/common/floor.cl', 'src/core/CL/cl_kernels/common/gather.cl', 'src/core/CL/cl_kernels/common/gemm.cl', + 'src/core/CL/cl_kernels/common/gemm_reshaped_only_rhs_mmul.cl', 'src/core/CL/cl_kernels/common/gemm_utils.cl', 'src/core/CL/cl_kernels/common/experimental/gemm_fused_post_ops/act_eltwise_op_act/gemm_mm_native.cl', 
'src/core/CL/cl_kernels/common/experimental/gemm_fused_post_ops/act_eltwise_op_act/gemm_mm_reshaped.cl', diff --git a/arm_compute/core/CL/CLHelpers.h b/arm_compute/core/CL/CLHelpers.h index 94ec5d781b..edbc705c6f 100644 --- a/arm_compute/core/CL/CLHelpers.h +++ b/arm_compute/core/CL/CLHelpers.h @@ -260,5 +260,12 @@ bool export_weights_to_cl_image(const ITensorInfo *tensor); */ void set_unroll_with_pragma(CLBuildOptions &built_opts, std::initializer_list values); +/** Helper function to check whether the cl_arm_matrix_multiply extension is supported + * + * @param[in] device A CL device + * + * @return True if the extension is supported + */ +bool arm_matrix_multiply_supported(const cl::Device &device); } // namespace arm_compute #endif /* ARM_COMPUTE_CLHELPERS_H */ diff --git a/arm_compute/core/GPUTarget.h b/arm_compute/core/GPUTarget.h index 6a8577ac4d..7e2cfe1b6b 100644 --- a/arm_compute/core/GPUTarget.h +++ b/arm_compute/core/GPUTarget.h @@ -51,9 +51,11 @@ enum class GPUTarget G31 = 0x242, G76 = 0x250, G77 = 0x310, + G57 = 0x311, G78 = 0x320, G710 = 0x330, - G57 = 0x340, + G715 = 0x340, + G615 = 0x341 }; /** Enable bitwise operations on GPUTarget enumerations */ diff --git a/arm_compute/runtime/CL/CLTypes.h b/arm_compute/runtime/CL/CLTypes.h index bba25c6d64..d298ecd614 100644 --- a/arm_compute/runtime/CL/CLTypes.h +++ b/arm_compute/runtime/CL/CLTypes.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020-2021 Arm Limited. + * Copyright (c) 2020-2022 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -35,7 +35,9 @@ enum class CLGEMMKernelType /** Reshaped GEMM kernel where both lhs and rhs matrices are reshaped. Configurable reshape and block size */ RESHAPED, /** Reshaped GEMM kernel where only the rhs matrix is reshaped. Configurable reshape and block size */ - RESHAPED_ONLY_RHS + RESHAPED_ONLY_RHS, + /** Reshaped GEMM kernel where only the rhs matrix is reshaped. Using MMUL with configurable block size. */ + RESHAPED_ONLY_RHS_MMUL }; /** OpenCL GEMM kernel selection parameters. These information are retrieved to select the GEMM kernel on OpenCL */ diff --git a/filelist.json b/filelist.json index ab2cc83a84..513a2207c1 100644 --- a/filelist.json +++ b/filelist.json @@ -479,6 +479,7 @@ "src/gpu/cl/kernels/ClGemmLowpQuantizeDownInt32ScaleByFloatKernel.cpp", "src/gpu/cl/kernels/ClGemmLowpQuantizeDownInt32ScaleKernel.cpp", "src/gpu/cl/kernels/ClGemmMatrixMultiplyNativeKernel.cpp", + "src/gpu/cl/kernels/ClGemmMatrixMultiplyReshapedOnlyRhsMMULKernel.cpp", "src/gpu/cl/kernels/ClGemmMatrixMultiplyReshapedKernel.cpp", "src/gpu/cl/kernels/ClGemmMatrixMultiplyReshapedOnlyRhsKernel.cpp", "src/gpu/cl/kernels/ClGemmReshapeLhsMatrixKernel.cpp", diff --git a/src/core/CL/CLHelpers.cpp b/src/core/CL/CLHelpers.cpp index 5172a7730a..94675d60cc 100644 --- a/src/core/CL/CLHelpers.cpp +++ b/src/core/CL/CLHelpers.cpp @@ -491,4 +491,8 @@ void set_unroll_with_pragma(CLBuildOptions &built_opts, std::initializer_list 0 + * - N0 = 1, 2, 3, 4, 8, 16 + * - K0 = 1 + * + * @note If the activation type were passed at compile time through -DACTIVATION_TYPE (e.g. -DACTIVATION_TYPE=RELU), A, B variables, required by some activation functions, should be passed at compile time as well using -DA_VAL= and -DB_VAL= respectively. + * The activation function is performed after the bias addition + * + * @param[in] lhs_ptr Pointer to the LHS tensor. 
Supported data types: F16/F32 + * @param[in] lhs_stride_y Stride of the LHS tensor in Y dimension (in bytes) + * @param[in] lhs_stride_z Stride of the LHS tensor in Z dimension (in bytes) + * @param[in] lhs_w The size of the width dimension of the LHS tensor + * @param[in] lhs_h The size of the height dimension of the LHS tensor + * @param[in] lhs_n The size of the depth dimension of the LHS tensor + * @param[in] lhs_offset_first_element_in_bytes The offset of the first element in the LHS tensor + * @param[in] rhs_ptr Pointer to the RHS reshaped tensor. Supported data type: same as @p lhs_ptr + * @param[in] rhs_stride_y Stride of the RHS tensor in Y dimension (in bytes) + * @param[in] rhs_stride_z Stride of the RHS tensor in Z dimension (in bytes) + * @param[in] rhs_w The size of the width dimension of the RHS tensor + * @param[in] rhs_h The size of the height dimension of the RHS tensor + * @param[in] rhs_n The size of the depth dimension of the RHS tensor + * @param[in] rhs_offset_first_element_in_bytes The offset of the first element in the RHS tensor + * @param[in] bia_ptr (Optional) Pointer to the bias tensor. Supported data type: same as @p lhs_ptr + * @param[in] bia_stride_y (Optional) Stride of the bias tensor in Y dimension (in bytes) + * @param[in] bia_stride_z (Optional) Stride of the bias tensor in Z dimension (in bytes) + * @param[in] bia_w (Optional) The size of the width dimension of the bias tensor + * @param[in] bia_h (Optional) The size of the height dimension of the bias tensor + * @param[in] bia_n (Optional) The size of the depth dimension of the bias tensor + * @param[in] bia_offset_first_element_in_bytes (Optional) The offset of the first element in the bias tensor + * @param[out] dst_ptr Pointer to the destination tensor. Supported data type: same as @p lhs_ptr + * @param[in] dst_stride_y Stride of the destination tensor in Y dimension (in bytes) + * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes) + * @param[in] dst_w The size of the width dimension of the destination tensor + * @param[in] dst_h The size of the height dimension of the destination tensor + * @param[in] dst_n The size of the depth dimension of the destination tensor + * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor + * @param[in] M Number of rows in LHS matrix not reshaped + * @param[in] N Number of columns in RHS matrix not reshaped + * @param[in] K Number of columns in LHS matrix and rows in RHS matrix not reshaped + */ +__kernel void gemm_mm_reshaped_only_rhs_nt_mmul( + TENSOR3D_T(lhs, BUFFER), + TENSOR3D_T(rhs, BUFFER), +#if defined(BETA) + TENSOR3D_T(bia, BUFFER), +#endif // defined(BETA) + TENSOR3D_T(dst, BUFFER), + const int M, + const int N, + const int K) +{ +#define MMUL_BLOCK_SIZE (MMUL_N0 * MMUL_K0) + + uint x0 = get_global_id(0); // (N / N0) * MMUL_K0 + uint y0 = get_global_id(1); // (M / M0) / MMUL_M0 + uint z = get_global_id(2); // Batch + + // Get block ID and thread ID within the block + uint block_id = (x0 / MMUL_BLOCK_SIZE); + uint thread_id = (x0 % MMUL_BLOCK_SIZE); + + // Coordinate within a block + uint block_x = thread_id % MMUL_N0; + uint block_y = (thread_id / MMUL_M0); + + // Starting destination coordinates + uint dst_x = min(block_x * N0 + block_id * MMUL_N0 * N0, (uint)(N - 1)); + uint dst_y = min(block_y * M0 + y0 * M0 * MMUL_M0, (uint)(M - M0)); + + // Note: We need to clamp dst_x and dst_y because we always need to execute a complete MMUL block! 
Only after the matrix multiplication + // part can we exit the kernel if it is out-of-bound. Remember, we have a cooperative matrix multiplication. Therefore, we need a full block to get the correct results + + // Starting LHS coordinates + uint lhs_x = block_x; + uint lhs_y = dst_y; + + // Starting RHS coordinates + uint rhs_x = block_y * N0 * MMUL_N0 + block_x * N0; + uint rhs_y = block_id; + + // Compute LHS/RHS/DST matrix address + lhs_offset_first_element_in_bytes += lhs_x * sizeof(DATA_TYPE) + lhs_y * lhs_stride_y + z * lhs_stride_z; + rhs_offset_first_element_in_bytes += rhs_x * sizeof(DATA_TYPE) + rhs_y * rhs_stride_y + z * rhs_stride_z; + dst_offset_first_element_in_bytes += dst_x * sizeof(DATA_TYPE) + dst_y * dst_stride_y + z * dst_stride_z; + + // Note: If RHS derives from the weights of convolution 2d layer, RHS will always be 2D and rhs_stride_z will always be equal to 0 for + // not sliding the tensor + + // Initialize the accumulators + // MMUL extension accumulate the result in F32 for both F32 and F16 + TILE(float, M0, N0, c_f32); + +#if !defined(HALF_PRECISION) +#define c c_f32 +#endif // !defined(HALF_PRECISION) + + LOOP_UNROLLING(int, i, 0, 1, M0, + { + c_f32[i].v = 0; + }) + + for(int k = 0; k <= K - MMUL_K0; k += MMUL_K0) + { + TILE(DATA_TYPE, M0, 1, a); + TILE(DATA_TYPE, 1, N0, b); + + // Load tile from the lhs/rhs tensors + T_LOAD(DATA_TYPE, M0, 1, BUFFER, lhs, 0, 0, 1, lhs_stride_y, a); + T_LOAD(DATA_TYPE, 1, N0, BUFFER, rhs, 0, 0, 1, 0, b); + + LOOP_UNROLLING(int, m0, 0, 1, M0, + { + LOOP_UNROLLING(int, n0, 0, 1, N0, + { + c_f32[m0].s[n0] = arm_matrix_multiply(a[m0].s[0], b[0].s[n0], c_f32[m0].s[n0]); + }) + }) + + lhs_offset_first_element_in_bytes += MMUL_K0 * sizeof(DATA_TYPE); + rhs_offset_first_element_in_bytes += MMUL_K0 * MMUL_N0 * N0 * sizeof(DATA_TYPE); + } + + if(block_x * N0 + block_id * MMUL_N0 * N0 >= N) + { + return; + } + + if(block_y * M0 + y0 * M0 * MMUL_M0 >= M) + { + return; + } + +#if defined(HALF_PRECISION) + TILE(DATA_TYPE, M0, N0, c); + + // Conversion required for the half precision + LOOP_UNROLLING(int, m0, 0, 1, M0, + { + LOOP_UNROLLING(int, n0, 0, 1, N0, + { + c[m0].s[n0] = c_f32[m0].s[n0]; + }) + }) +#endif // defined(HALF_PRECISION) + + // Multiply by the weight of matrix-matrix product and store the result +#if defined(ALPHA) + T_SCALE_CONSTANT(DATA_TYPE, M0, N0, c, (DATA_TYPE)ALPHA, c); +#endif // defined(ALPHA) + + // Add beta*bias +#if defined(BETA) +#if defined(BROADCAST_BIAS) + bia_offset_first_element_in_bytes += dst_x * sizeof(DATA_TYPE); + + TILE(DATA_TYPE, 1, N0, bias0); + + if(dst_x + N0 <= N || N0_LEFTOVER == 0) + { + bias0[0].v = VLOAD(N0)(0, (DATA_TYPE *)(bia_ptr + bia_offset_first_element_in_bytes)); + } + else + { + VLOAD_PARTIAL(N0, N0_LEFTOVER) + (bias0[0].v, 0, (DATA_TYPE *)(bia_ptr + bia_offset_first_element_in_bytes)); + } + +#ifndef UNIT_BETA + T_SCALE_CONSTANT(DATA_TYPE, 1, N0, bias0, (DATA_TYPE)BETA, bias0); +#endif // UNIT_BIAS + + // c = c + bias[broadcasted] + T_ELTWISE_BROADCAST_X(V_ADD, DATA_TYPE, M0, N0, c, bias0, c); +#else // defined(BROADCAST_BIAS) + TILE(DATA_TYPE, M0, N0, bias0); + + bia_offset_first_element_in_bytes += dst_x * sizeof(DATA_TYPE) + dst_y * bia_stride_y + z * bia_stride_z; + + if(dst_x + N0 <= N || N0_LEFTOVER == 0) + { + LOOP_UNROLLING(int, m0, 0, 1, M0, + { + if(dst_y + m0 < M || M0_LEFTOVER == 0) + { + bias0[m0].v = VLOAD(N0)(0, (DATA_TYPE *)(bia_ptr + bia_offset_first_element_in_bytes + m0 * bia_stride_y)); + } + }) + } + else + { + LOOP_UNROLLING(int, m0, 0, 1, M0, + { + if(dst_y + 
m0 < M || M0_LEFTOVER == 0) + { + VLOAD_PARTIAL(N0, N0_LEFTOVER) + (bias0[m0].v, 0, (DATA_TYPE *)(bia_ptr + bia_offset_first_element_in_bytes + m0 * bia_stride_y)); + } + }) + } + +#ifndef UNIT_BETA + T_SCALE_CONSTANT(DATA_TYPE, M0, N0, bias0, (DATA_TYPE)BETA, bias0); +#endif // UNIT_BIAS + + // c = c + bias + T_ADD(DATA_TYPE, M0, N0, c, bias0, c); + // c = c + bias +#endif // defined(BROADCAST_BIAS) +#endif // defined(BETA) + + T_ACTIVATION(DATA_TYPE, M0, N0, ACTIVATION_TYPE, A_VAL, B_VAL, c, c); + + // Store + if(dst_x + N0 <= N || N0_LEFTOVER == 0) + { + LOOP_UNROLLING(int, m0, 0, 1, M0, + { + if(dst_y + m0 < M || M0_LEFTOVER == 0) + { + VSTORE(N0) + (c[m0].v, 0, (__global DATA_TYPE *)(dst_ptr + dst_offset_first_element_in_bytes + m0 * dst_stride_y)); + } + }) + } + else + { + LOOP_UNROLLING(int, m0, 0, 1, M0, + { + if(dst_y + m0 < M || M0_LEFTOVER == 0) + { + VSTORE_PARTIAL(N0, N0_LEFTOVER) + (c[m0].v, 0, (__global DATA_TYPE *)(dst_ptr + dst_offset_first_element_in_bytes + m0 * dst_stride_y)); + } + }) + } + +#undef RHS_BLOCK_SIZE +#undef RHS_OFFSET_X +#undef RHS_STEP_X +} +#endif // defined(GEMM_MM_RESHAPED_ONLY_RHS_MMUL) + +#if defined(GEMM_MM_RESHAPED_ONLY_RHS_NT_MMUL_TEXTURE) +/** This OpenCL kernel computes the matrix multiplication between 2 matrices using the MMUL extension and the OpenCL image for RHS: + * + * The LHS matrix is NOT reshaped + * The RHS is reshaped with @ref ClGemmMatrixMultiplyReshapedOnlyRhsKernel and the block K0xN0 is NOT transposed + * + * @note The block's dimensions used for reshaping the RHS matrix (N0 and K0) must be passed at compile time using -DN0 and -DK0 (e.g. -DN0=8, -DK0=4). + * @note The number of M0 rows to process must be passed at compile time using -DM0 (e.g. -DM0=2) + * @note The number of output columns processed by the the cooperative mmul extension must be passed at compile time using -DMMUL_N0 (e.g., -DMMUL_N0=2) + * @note The number of output rows processed by the the cooperative mmul extension must be passed at compile time using -DMMUL_M0 (e.g., -DMMUL_M0=2) + * @note The number of lhs columns (or rhs rows) processed by the the cooperative mmul extension must be passed at compile time using -DMMUL_K0 (e.g., -DMMUL_K0=2) + * @note Only the following configurations of M0, N0 and K0 are currently supported: + * - M0 > 0 + * - N0 = 1, 2, 3, 4, 8, 16 + * - K0 = 1 + * + * @note If the activation type were passed at compile time through -DACTIVATION_TYPE (e.g. -DACTIVATION_TYPE=RELU), A, B variables, required by some activation functions, should be passed at compile time as well using -DA_VAL= and -DB_VAL= respectively. + * The activation function is performed after the bias addition + * + * @param[in] lhs_ptr Pointer to the LHS tensor. Supported data types: F16/F32 + * @param[in] lhs_stride_y Stride of the LHS tensor in Y dimension (in bytes) + * @param[in] lhs_stride_z Stride of the LHS tensor in Z dimension (in bytes) + * @param[in] lhs_w The size of the width dimension of the LHS tensor + * @param[in] lhs_h The size of the height dimension of the LHS tensor + * @param[in] lhs_n The size of the depth dimension of the LHS tensor + * @param[in] lhs_offset_first_element_in_bytes The offset of the first element in the LHS tensor + * @param[in] rhs_ptr Pointer to the RHS reshaped tensor. 
Supported data type: same as @p lhs_ptr + * @param[in] rhs_stride_y Stride of the RHS tensor in Y dimension (in bytes) + * @param[in] rhs_stride_z Stride of the RHS tensor in Z dimension (in bytes) + * @param[in] rhs_w The size of the width dimension of the RHS tensor + * @param[in] rhs_h The size of the height dimension of the RHS tensor + * @param[in] rhs_n The size of the depth dimension of the RHS tensor + * @param[in] rhs_offset_first_element_in_bytes The offset of the first element in the RHS tensor + * @param[in] bia_ptr (Optional) Pointer to the bias tensor. Supported data type: same as @p lhs_ptr + * @param[in] bia_stride_y (Optional) Stride of the bias tensor in Y dimension (in bytes) + * @param[in] bia_stride_z (Optional) Stride of the bias tensor in Z dimension (in bytes) + * @param[in] bia_w (Optional) The size of the width dimension of the bias tensor + * @param[in] bia_h (Optional) The size of the height dimension of the bias tensor + * @param[in] bia_n (Optional) The size of the depth dimension of the bias tensor + * @param[in] bia_offset_first_element_in_bytes (Optional) The offset of the first element in the bias tensor + * @param[out] dst_ptr Pointer to the destination tensor. Supported data type: same as @p lhs_ptr + * @param[in] dst_stride_y Stride of the destination tensor in Y dimension (in bytes) + * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes) + * @param[in] dst_w The size of the width dimension of the destination tensor + * @param[in] dst_h The size of the height dimension of the destination tensor + * @param[in] dst_n The size of the depth dimension of the destination tensor + * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor + * @param[in] M Number of rows in LHS matrix not reshaped + * @param[in] N Number of columns in RHS matrix not reshaped + * @param[in] K Number of columns in LHS matrix and rows in RHS matrix not reshaped + */ +__kernel void gemm_mm_reshaped_only_rhs_nt_mmul_texture( + TENSOR3D_T(lhs, BUFFER), + TENSOR3D_T(rhs, IMAGE), +#if defined(BETA) + TENSOR3D_T(bia, BUFFER), +#endif // defined(BETA) + TENSOR3D_T(dst, BUFFER), + const int M, + const int N, + const int K) +{ +#define MMUL_BLOCK_SIZE (MMUL_N0 * MMUL_K0) + + uint x0 = get_global_id(0); // (N / N0) * MMUL_K0 + uint y0 = get_global_id(1); // (M / M0) / MMUL_M0 + uint z = get_global_id(2); // Batch + + // Get block ID and thread ID within the block + uint block_id = (x0 / MMUL_BLOCK_SIZE); + uint thread_id = (x0 % MMUL_BLOCK_SIZE); + + // Coordinate within a block + uint block_x = thread_id % MMUL_N0; + uint block_y = (thread_id / MMUL_M0); + + // Starting destination coordinates + uint dst_x = min(block_x * N0 + block_id * MMUL_N0 * N0, (uint)(N - 1)); + uint dst_y = min(block_y * M0 + y0 * M0 * MMUL_M0, (uint)(M - M0)); + + // Note: We need to clamp dst_x and dst_y because we always need to execute a complete MMUL block! Only after the matrix multiplication + // part can we exit the kernel if it is out-of-bound. Remember, we have a cooperative matrix multiplication. 
Therefore, we need a full block to get the correct results + + // Starting LHS coordinates + uint lhs_x = block_x; + uint lhs_y = dst_y; + + // Starting RHS coordinates + uint rhs_x = block_y * N0 * MMUL_N0 + block_x * N0; + uint rhs_y = block_id + z * rhs_h; + + // Compute LHS/RHS/DST matrix address + lhs_offset_first_element_in_bytes += lhs_x * sizeof(DATA_TYPE) + lhs_y * lhs_stride_y + z * lhs_stride_z; + dst_offset_first_element_in_bytes += dst_x * sizeof(DATA_TYPE) + dst_y * dst_stride_y + z * dst_stride_z; + + // Initialize the accumulators + // MMUL extension accumulate the result in F32 for both F32 and F16 + TILE(float, M0, N0, c_f32); + +#if !defined(HALF_PRECISION) +#define c c_f32 +#endif // !defined(HALF_PRECISION) + + LOOP_UNROLLING(int, i, 0, 1, M0, + { + c_f32[i].v = 0; + }) + + for(int k = 0; k <= K - MMUL_K0; k += MMUL_K0) + { + TILE(DATA_TYPE, M0, 1, a); + TILE(DATA_TYPE, 1, N0, b); + + // Load tile from the lhs/rhs tensors + T_LOAD(DATA_TYPE, M0, 1, BUFFER, lhs, 0, 0, 1, lhs_stride_y, a); + T_LOAD(DATA_TYPE, 1, N0, IMAGE, rhs, rhs_x, rhs_y, 1, rhs_stride_y, b); + + LOOP_UNROLLING(int, m0, 0, 1, M0, + { + LOOP_UNROLLING(int, n0, 0, 1, N0, + { + c_f32[m0].s[n0] = arm_matrix_multiply(a[m0].s[0], b[0].s[n0], c_f32[m0].s[n0]); + }) + }) + + lhs_offset_first_element_in_bytes += MMUL_K0 * sizeof(DATA_TYPE); + rhs_x += MMUL_K0 * MMUL_N0 * N0; + } + + if(block_x * N0 + block_id * MMUL_N0 * N0 >= N) + { + return; + } + + if(block_y * M0 + y0 * M0 * MMUL_M0 >= M) + { + return; + } + +#if defined(HALF_PRECISION) + TILE(DATA_TYPE, M0, N0, c); + + // Conversion required for the half precision + LOOP_UNROLLING(int, m0, 0, 1, M0, + { + LOOP_UNROLLING(int, n0, 0, 1, N0, + { + c[m0].s[n0] = c_f32[m0].s[n0]; + }) + }) +#endif // defined(HALF_PRECISION) + + // Multiply by the weight of matrix-matrix product and store the result +#if defined(ALPHA) + T_SCALE_CONSTANT(DATA_TYPE, M0, N0, c, (DATA_TYPE)ALPHA, c); +#endif // defined(ALPHA) + + // Add beta*bias +#if defined(BETA) +#if defined(BROADCAST_BIAS) + bia_offset_first_element_in_bytes += dst_x * sizeof(DATA_TYPE); + + TILE(DATA_TYPE, 1, N0, bias0); + + if(dst_x + N0 <= N || N0_LEFTOVER == 0) + { + bias0[0].v = VLOAD(N0)(0, (DATA_TYPE *)(bia_ptr + bia_offset_first_element_in_bytes)); + } + else + { + VLOAD_PARTIAL(N0, N0_LEFTOVER) + (bias0[0].v, 0, (DATA_TYPE *)(bia_ptr + bia_offset_first_element_in_bytes)); + } + +#ifndef UNIT_BETA + T_SCALE_CONSTANT(DATA_TYPE, 1, N0, bias0, (DATA_TYPE)BETA, bias0); +#endif // UNIT_BIAS + + // c = c + bias[broadcasted] + T_ELTWISE_BROADCAST_X(V_ADD, DATA_TYPE, M0, N0, c, bias0, c); +#else // defined(BROADCAST_BIAS) + TILE(DATA_TYPE, M0, N0, bias0); + + bia_offset_first_element_in_bytes += dst_x * sizeof(DATA_TYPE) + dst_y * bia_stride_y + z * bia_stride_z; + + if(dst_x + N0 <= N || N0_LEFTOVER == 0) + { + LOOP_UNROLLING(int, m0, 0, 1, M0, + { + if(dst_y + m0 < M || M0_LEFTOVER == 0) + { + bias0[m0].v = VLOAD(N0)(0, (DATA_TYPE *)(bia_ptr + bia_offset_first_element_in_bytes + m0 * bia_stride_y)); + } + }) + } + else + { + LOOP_UNROLLING(int, m0, 0, 1, M0, + { + if(dst_y + m0 < M || M0_LEFTOVER == 0) + { + VLOAD_PARTIAL(N0, N0_LEFTOVER) + (bias0[m0].v, 0, (DATA_TYPE *)(bia_ptr + bia_offset_first_element_in_bytes + m0 * bia_stride_y)); + } + }) + } + +#ifndef UNIT_BETA + T_SCALE_CONSTANT(DATA_TYPE, M0, N0, bias0, (DATA_TYPE)BETA, bias0); +#endif // UNIT_BIAS + + // c = c + bias + T_ADD(DATA_TYPE, M0, N0, c, bias0, c); + // c = c + bias +#endif // defined(BROADCAST_BIAS) +#endif // defined(BETA) + + 
T_ACTIVATION(DATA_TYPE, M0, N0, ACTIVATION_TYPE, A_VAL, B_VAL, c, c); + + // Store + if(dst_x + N0 <= N || N0_LEFTOVER == 0) + { + LOOP_UNROLLING(int, m0, 0, 1, M0, + { + if(dst_y + m0 < M || M0_LEFTOVER == 0) + { + VSTORE(N0) + (c[m0].v, 0, (__global DATA_TYPE *)(dst_ptr + dst_offset_first_element_in_bytes + m0 * dst_stride_y)); + } + }) + } + else + { + LOOP_UNROLLING(int, m0, 0, 1, M0, + { + if(dst_y + m0 < M || M0_LEFTOVER == 0) + { + VSTORE_PARTIAL(N0, N0_LEFTOVER) + (c[m0].v, 0, (__global DATA_TYPE *)(dst_ptr + dst_offset_first_element_in_bytes + m0 * dst_stride_y)); + } + }) + } + +#undef RHS_BLOCK_SIZE +#undef RHS_OFFSET_X +#undef RHS_STEP_X +} +#endif // defined(GEMM_MM_RESHAPED_ONLY_RHS_MMUL_TEXTURE) \ No newline at end of file diff --git a/src/core/CL/cl_kernels/tile_helpers.h b/src/core/CL/cl_kernels/tile_helpers.h index 0ce343e3ec..4b6144a22d 100644 --- a/src/core/CL/cl_kernels/tile_helpers.h +++ b/src/core/CL/cl_kernels/tile_helpers.h @@ -970,8 +970,8 @@ #define ACT_OP_QUANTIZED(op, DATA_TYPE, VEC_SIZE, ZERO_VALUE, A_VAL, B_VAL, x) op##_op_quantized(DATA_TYPE, VEC_SIZE, ZERO_VALUE, A_VAL, B_VAL, x) #define ACTIVATION_QUANTIZED(op, DATA_TYPE, VEC_SIZE, ZERO_VALUE, A_VAL, B_VAL, x) ACT_OP_QUANTIZED(op, DATA_TYPE, VEC_SIZE, ZERO_VALUE, A_VAL, B_VAL, x) -#define T_ADD(A_VAL, B_VAL) ((A_VAL) + (B_VAL)) -#define T_DIV(A_VAL, B_VAL) ((A_VAL) / (B_VAL)) +#define V_ADD(A_VAL, B_VAL) ((A_VAL) + (B_VAL)) +#define V_DIV(A_VAL, B_VAL) ((A_VAL) / (B_VAL)) /** Element-wise activation for quantized types * @@ -995,6 +995,25 @@ }) \ }) +/** Element-wise addition between two tiles + * + * @note Performs: LHS + RHS = DST + * + * @param[in] DATA_TYPE LHS/RHS/DST data type + * @param[in] M0 Number of LHS rows + * @param[in] N0 Number of LHS columns + * @param[in] lhs LHS tile + * @param[in] rhs Constant RHS tile + * @param[out] dst DST tile + */ +#define T_ADD(DATA_TYPE, M0, N0, lhs, rhs, dst) \ + ({ \ + LOOP_UNROLLING(int, _m0, 0, 1, M0, \ + { \ + dst[_m0].v = lhs[_m0].v + rhs[_m0].v; \ + }) \ + }) + /** Element-wise addition with a constant value * * @note Performs: LHS + constant = DST @@ -1010,15 +1029,31 @@ ({ \ LOOP_UNROLLING(int, _m0, 0, 1, M0, \ { \ - LOOP_UNROLLING(int, _n0, 0, 1, N0, \ - { \ - dst[_m0].s[_n0] = lhs[_m0].s[_n0] + rhs_constant; \ - }) \ + dst[_m0].v = lhs[_m0].v + (DATA_TYPE)rhs_constant; \ }) \ }) -#define T_ELTWISE_BROADCAST_ADD_X(DST_DATA_TYPE, M0, N0, lhs, rhs, dst) T_ELTWISE_BROADCAST_X(T_ADD, DST_DATA_TYPE, M0, N0, lhs, rhs, dst) -#define T_ELTWISE_BROADCAST_DIV_X(DST_DATA_TYPE, M0, N0, lhs, rhs, dst) T_ELTWISE_BROADCAST_X(T_DIV, DST_DATA_TYPE, M0, N0, lhs, rhs, dst) +#define T_ELTWISE_BROADCAST_ADD_X(DST_DATA_TYPE, M0, N0, lhs, rhs, dst) T_ELTWISE_BROADCAST_X(V_ADD, DST_DATA_TYPE, M0, N0, lhs, rhs, dst) +#define T_ELTWISE_BROADCAST_DIV_X(DST_DATA_TYPE, M0, N0, lhs, rhs, dst) T_ELTWISE_BROADCAST_X(V_DIV, DST_DATA_TYPE, M0, N0, lhs, rhs, dst) + +/** Element-wise scale with a constant value + * + * @note Performs: LHS * constant = DST + * + * @param[in] DATA_TYPE LHS/RHS/DST data type + * @param[in] M0 Number of LHS rows + * @param[in] N0 Number of LHS columns + * @param[in] lhs LHS tile + * @param[in] rhs_constant Constant value + * @param[out] dst DST tile + */ +#define T_SCALE_CONSTANT(DATA_TYPE, M0, N0, lhs, rhs_constant, dst) \ + ({ \ + LOOP_UNROLLING(int, _m0, 0, 1, M0, \ + { \ + dst[_m0].v = lhs[_m0].v * (DATA_TYPE)rhs_constant; \ + }) \ + }) /** Element-wise operation with RHS broadcasted (RHS has the X dimension only) * @@ -1041,8 +1076,8 @@ }) \ }) -#define 
T_ELTWISE_ADD(DST_DATA_TYPE, M0, N0, lhs, rhs, dst) T_ELTWISE(T_ADD, DST_DATA_TYPE, M0, N0, lhs, rhs, dst) -#define T_ELTWISE_DIV(DST_DATA_TYPE, M0, N0, lhs, rhs, dst) T_ELTWISE(T_DIV, DST_DATA_TYPE, M0, N0, lhs, rhs, dst) +#define T_ELTWISE_ADD(DST_DATA_TYPE, M0, N0, lhs, rhs, dst) T_ELTWISE(V_ADD, DST_DATA_TYPE, M0, N0, lhs, rhs, dst) +#define T_ELTWISE_DIV(DST_DATA_TYPE, M0, N0, lhs, rhs, dst) T_ELTWISE(V_DIV, DST_DATA_TYPE, M0, N0, lhs, rhs, dst) /** Element-wise operation between two tiles (LHS and RHS) * diff --git a/src/core/GPUTarget.cpp b/src/core/GPUTarget.cpp index 5984c88099..e74abf6e7e 100644 --- a/src/core/GPUTarget.cpp +++ b/src/core/GPUTarget.cpp @@ -47,6 +47,14 @@ arm_compute::GPUTarget get_valhall_target(const std::string &version) { return arm_compute::GPUTarget::G57; } + else if(version.find("G715") != std::string::npos) + { + return arm_compute::GPUTarget::G715; + } + else if(version.find("G615") != std::string::npos) + { + return arm_compute::GPUTarget::G615; + } else { return arm_compute::GPUTarget::UNKNOWN; @@ -141,7 +149,9 @@ const std::string &string_from_target(GPUTarget target) { GPUTarget::G77, "g77" }, { GPUTarget::G78, "g78" }, { GPUTarget::G710, "g710" }, - { GPUTarget::G57, "g57" } + { GPUTarget::G57, "g57" }, + { GPUTarget::G715, "g715" }, + { GPUTarget::G615, "g615" } }; return gpu_target_map[target]; diff --git a/src/gpu/cl/ClKernelLibrary.cpp b/src/gpu/cl/ClKernelLibrary.cpp index 1bf7f2b3ac..52661d6d79 100644 --- a/src/gpu/cl/ClKernelLibrary.cpp +++ b/src/gpu/cl/ClKernelLibrary.cpp @@ -272,6 +272,8 @@ const std::map ClKernelLibrary::_kernel_program_map = { "gemm_mv", "common/gemv.cl" }, { "gemm_mv_quantized", "common/gemv.cl" }, { "gemm_mm_native", "common/gemm.cl" }, + { "gemm_mm_reshaped_only_rhs_nt_mmul", "common/gemm_reshaped_only_rhs_mmul.cl" }, + { "gemm_mm_reshaped_only_rhs_nt_mmul_texture", "common/gemm_reshaped_only_rhs_mmul.cl" }, { "gemm_mm_native_post_act_eltwise_op_act", "common/experimental/gemm_fused_post_ops/act_eltwise_op_act/gemm_mm_native.cl" }, { "gemm_mm_reshaped_lhs_nt_rhs_t", "common/gemm.cl" }, { "gemm_mm_reshaped_lhs_nt_rhs_t_texture", "common/gemm.cl" }, @@ -582,6 +584,10 @@ const std::map ClKernelLibrary::_program_source_map = { "common/gemm.cl", #include "./cl_kernels/common/gemm.clembed" + }, + { + "common/gemm_reshaped_only_rhs_mmul.cl", +#include "./cl_kernels/common/gemm_reshaped_only_rhs_mmul.clembed" }, { "common/gemm_utils.cl", diff --git a/src/gpu/cl/kernels/ClGemmMatrixMultiplyReshapedOnlyRhsMMULKernel.cpp b/src/gpu/cl/kernels/ClGemmMatrixMultiplyReshapedOnlyRhsMMULKernel.cpp new file mode 100644 index 0000000000..fe46913517 --- /dev/null +++ b/src/gpu/cl/kernels/ClGemmMatrixMultiplyReshapedOnlyRhsMMULKernel.cpp @@ -0,0 +1,365 @@ +/* + * Copyright (c) 2022 Arm Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. 
+ * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ +#include "src/gpu/cl/kernels/ClGemmMatrixMultiplyReshapedOnlyRhsMMULKernel.h" + +#include "arm_compute/core/CL/CLHelpers.h" +#include "arm_compute/core/CL/CLKernelLibrary.h" +#include "arm_compute/core/CL/ICLTensor.h" +#include "arm_compute/core/CL/OpenCL.h" +#include "arm_compute/core/Helpers.h" +#include "arm_compute/core/TensorInfo.h" +#include "arm_compute/core/Utils.h" +#include "arm_compute/core/Validate.h" +#include "arm_compute/core/utils/misc/ShapeCalculator.h" +#include "src/core/CL/CLUtils.h" +#include "src/core/helpers/AutoConfiguration.h" +#include "src/core/helpers/WindowHelpers.h" +#include "src/core/utils/helpers/float_ops.h" +#include "src/gpu/cl/kernels/gemm/ClGemmHelpers.h" +#include "support/Cast.h" +#include "support/StringSupport.h" + +namespace arm_compute +{ +namespace opencl +{ +namespace kernels +{ +namespace +{ +using ElementsProcessed = Steps; + +// Block size dimensions for the MMUL extension +constexpr int mmul_m0 = 4; +constexpr int mmul_n0 = 4; +constexpr int mmul_k0 = 4; + +Status validate_arguments(const ITensorInfo *src0, const ITensorInfo *src1, const ITensorInfo *src2, const ITensorInfo *dst, float alpha, float beta, const GEMMLHSMatrixInfo &lhs_info, + const GEMMRHSMatrixInfo &rhs_info, + const GEMMKernelInfo &gemm_info) +{ + ARM_COMPUTE_UNUSED(alpha); + ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(src0, src1, dst); + ARM_COMPUTE_RETURN_ERROR_ON_MSG(!arm_matrix_multiply_supported(CLKernelLibrary::get().get_device()), "The extension cl_arm_matrix_multiply is not supported on the target platform"); + ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(src0, 1, DataType::F16, DataType::F32); + ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(src0, src1); + ARM_COMPUTE_RETURN_ERROR_ON_MSG(src0->num_dimensions() > 4, "The number of dimensions for the LHS matrix must be <= 4"); + ARM_COMPUTE_RETURN_ERROR_ON_MSG(src1->num_dimensions() > 3, "The number of dimensions for the RHS matrix must be <= 3"); + ARM_COMPUTE_RETURN_ERROR_ON_MSG(lhs_info.m0 < 1, "Only values greater than 0 are supported for m0"); + ARM_COMPUTE_RETURN_ERROR_ON_MSG(rhs_info.n0 != 1 && rhs_info.n0 != 2 && rhs_info.n0 != 3 && rhs_info.n0 != 4 && rhs_info.n0 != 8 && rhs_info.n0 != 16, "Only 1,2,3,4,8, and 16 are supported for n0"); + ARM_COMPUTE_RETURN_ERROR_ON_MSG((rhs_info.k0 != 1 || lhs_info.k0 != 1), "Only 1 is supported for k0"); + ARM_COMPUTE_RETURN_ERROR_ON_MSG((rhs_info.h0 != 4), "Only 4 is supported for h0"); + ARM_COMPUTE_RETURN_ERROR_ON_MSG(rhs_info.interleave != true, "Only true is supported for interleave with mmul extension enabled"); + ARM_COMPUTE_RETURN_ERROR_ON_MSG(rhs_info.transpose != false, "Only false is supported for transpose with mmul extension enabled"); + ARM_COMPUTE_RETURN_ERROR_ON_MSG(gemm_info.fp_mixed_precision, "Mixed precision not supported"); + ARM_COMPUTE_RETURN_ON_ERROR(gemm::validate_image2d_support_on_rhs(*src1, rhs_info)); + + const unsigned int m = gemm_info.m; + const unsigned int n = gemm_info.n; + const unsigned int k = gemm_info.k; + + ARM_COMPUTE_UNUSED(m); + 
ARM_COMPUTE_UNUSED(n); + ARM_COMPUTE_UNUSED(k); + + ARM_COMPUTE_RETURN_ERROR_ON(src0->dimension(0) != k); + + // Validate the reinterpreted-as-3D-case + if(gemm_info.depth_output_gemm3d != 0) + { + ARM_COMPUTE_RETURN_ERROR_ON(src0->dimension(1) * src0->dimension(2) != m); + } + else + { + ARM_COMPUTE_RETURN_ERROR_ON(src0->dimension(1) != m); + } + + // Validate the gemm-batched case + if(src1->num_dimensions() > 2) + { + if(gemm_info.depth_output_gemm3d != 0) + { + ARM_COMPUTE_RETURN_ERROR_ON(src0->dimension(3) != src1->dimension(2)); + } + else + { + ARM_COMPUTE_RETURN_ERROR_ON(src0->dimension(2) != src1->dimension(2)); + } + } + + if(src2 != nullptr && !(helpers::float_ops::is_zero(beta))) + { + const unsigned int src2_dim0 = src2->dimension(0); + const unsigned int src2_dim1 = src2->dimension(1); + + ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(src2, src1); + if(gemm_info.broadcast_bias) + { + ARM_COMPUTE_RETURN_ERROR_ON_MSG((src2_dim1 != 1 || src2_dim0 != n), "Incorrect dimension of bias matrix which is to be broadcasted"); + } + else + { + ARM_COMPUTE_RETURN_ERROR_ON_MSG((src2_dim0 != n || src2_dim1 != m), "Incorrect dimension of bias matrix"); + } + } + + TensorShape tensor_shape1{ src1->tensor_shape() }; + tensor_shape1.set(0, n); + tensor_shape1.set(1, k); + + const TensorInfo tensor_info1 = src1->clone()->set_tensor_shape(tensor_shape1); + const TensorInfo tensor_info_reshaped1 = src1->clone()->set_tensor_shape(misc::shape_calculator::compute_rhs_reshaped_shape(tensor_info1, rhs_info)); + + ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(src1, &tensor_info_reshaped1); + + if(dst->total_size() != 0) + { + const TensorInfo tensor_info_dst = dst->clone()->set_tensor_shape(misc::shape_calculator::compute_mm_shape(*src0, *src1, gemm_info)); + ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(dst, &tensor_info_dst); + ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(src0, dst); + } + + return Status{}; +} + +std::pair validate_and_configure_window(ITensorInfo *src0, ITensorInfo *src1, ITensorInfo *src2, ITensorInfo *dst, const GEMMLHSMatrixInfo &lhs_info, + const GEMMRHSMatrixInfo &rhs_info, + const GEMMKernelInfo &gemm_info) +{ + ARM_COMPUTE_UNUSED(src0, src1, src2); + bool reinterpret_output_as_3d = gemm_info.depth_output_gemm3d != 0; + + // dst tensor auto initialization if not yet initialized + auto_init_if_empty(*dst, src0->clone()->set_tensor_shape(misc::shape_calculator::compute_mm_shape(*src0, *src1, gemm_info))); + + TensorInfo tmp_info(*dst); + + if(reinterpret_output_as_3d) + { + // Since the dst tensor has to be reinterpreted as 3D and the execute window is based on a 2D GEMM, + // the window needs to be constructed on the 2D collapsed version of the tensor + TensorShape tmp_shape(dst->tensor_shape()); + tmp_shape.collapse(2U, 1U); + tmp_info.set_tensor_shape(tmp_shape); + } + + Window win = calculate_max_window(tmp_info, Steps(1, 1)); + + // Collapse along the Z direction + // This collapse needs to be here in order to tune the Z dimension of LWS + const unsigned int dimension_to_collapse = std::min(static_cast(dst->num_dimensions()), 2u); + Window collapsed = win.collapse(win, dimension_to_collapse); + + // Reconfigure window size, one arm_matrix_multiply kernel needs 16 threads to finish. 
+ Window::Dimension x_dimension = collapsed.x(); + Window::Dimension y_dimension = collapsed.y(); + + // Make M and N multiple of M0 and N0 respectively + const unsigned int ceil_to_multiple_n_n0 = ceil_to_multiple(x_dimension.end(), rhs_info.n0); + const unsigned int ceil_to_multiple_m_m0 = ceil_to_multiple(y_dimension.end(), lhs_info.m0); + + // Divide M and N by M0 and N0 respectively + const unsigned int n_div_n0 = ceil_to_multiple_n_n0 / rhs_info.n0; + const unsigned int m_div_m0 = ceil_to_multiple_m_m0 / lhs_info.m0; + + // Make n_div_n0 and m_div_m0 multiple of mmul_n0 and mmul_k0 respectively + const unsigned int ceil_to_multiple_n_div_n0_mmul_n0 = ceil_to_multiple(n_div_n0, mmul_n0); + const unsigned int ceil_to_multiple_m_div_m0_mmul_k0 = ceil_to_multiple(m_div_m0, mmul_k0); + + // Ensure x_dimension is multiple of MMUL block size (mmul_n0 * mmul_k0) + x_dimension.set_end(ceil_to_multiple_n_div_n0_mmul_n0 * mmul_k0); + y_dimension.set_end(ceil_to_multiple_m_div_m0_mmul_k0 / mmul_k0); + + collapsed.set(Window::DimX, x_dimension); + collapsed.set(Window::DimY, y_dimension); + + return std::make_pair(Status{}, collapsed); +} +} // namespace + +ClGemmMatrixMultiplyReshapedOnlyRhsMMULKernel::ClGemmMatrixMultiplyReshapedOnlyRhsMMULKernel() +{ + _type = CLKernelType::GEMM; +} + +void ClGemmMatrixMultiplyReshapedOnlyRhsMMULKernel::configure(const CLCompileContext &compile_context, ITensorInfo *src0, ITensorInfo *src1, ITensorInfo *src2, ITensorInfo *dst, float alpha, + float beta, + const GEMMLHSMatrixInfo &lhs_info, + const GEMMRHSMatrixInfo &rhs_info, const GEMMKernelInfo &gemm_info) +{ + ARM_COMPUTE_ERROR_ON_NULLPTR(src0, src1, dst); + + // dst tensor auto initialization if not yet initialized + auto_init_if_empty(*dst, src0->clone()->set_tensor_shape(misc::shape_calculator::compute_mm_shape(*src0, *src1, gemm_info))); + + ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(src0, src1, src2, dst, alpha, beta, lhs_info, rhs_info, gemm_info)); + + auto padding_info = get_padding_info({ src0, src1, src2, dst }); + _add_bias = src2 != nullptr; + _export_to_cl_image = rhs_info.export_to_cl_image; + + // Configure kernel window + auto win_config = validate_and_configure_window(src0, src1, src2, dst, lhs_info, rhs_info, gemm_info); + ARM_COMPUTE_ERROR_THROW_ON(win_config.first); + + IClKernel::configure_internal(win_config.second); + + _m = gemm_info.m; + _n = gemm_info.n; + _k = gemm_info.k; + + const unsigned int m0_leftover = _m % lhs_info.m0; + const unsigned int n0_leftover = _n % rhs_info.n0; + + // Create build options + CLBuildOptions build_opts; + build_opts.add_option("-DDATA_TYPE=" + get_cl_type_from_data_type(src0->data_type())); + build_opts.add_option_if(!(helpers::float_ops::is_one(alpha)), "-DALPHA=" + float_to_string_with_full_precision(alpha)); + build_opts.add_option_if(src2 != nullptr, "-DBETA=" + float_to_string_with_full_precision(beta)); + build_opts.add_option_if(helpers::float_ops::is_one(beta), "-DUNIT_BETA"); + build_opts.add_option_if(gemm_info.broadcast_bias, "-DBROADCAST_BIAS"); + build_opts.add_option_if(src0->data_type() == DataType::F16, "-DHALF_PRECISION"); + build_opts.add_option("-DM0=" + support::cpp11::to_string(lhs_info.m0)); + build_opts.add_option("-DN0=" + support::cpp11::to_string(rhs_info.n0)); + build_opts.add_option("-DK0=" + support::cpp11::to_string(rhs_info.k0)); + build_opts.add_option("-DM0_LEFTOVER=" + support::cpp11::to_string(m0_leftover)); + build_opts.add_option("-DN0_LEFTOVER=" + support::cpp11::to_string(n0_leftover)); + 
build_opts.add_option("-DMMUL_M0=" + support::cpp11::to_string(mmul_m0)); + build_opts.add_option("-DMMUL_N0=" + support::cpp11::to_string(mmul_n0)); + build_opts.add_option("-DMMUL_K0=" + support::cpp11::to_string(mmul_k0)); + build_opts.add_option("-DACTIVATION_TYPE=" + lower_string(string_from_activation_func(gemm_info.activation_info.activation()))); + build_opts.add_option("-DA_VAL=" + float_to_string_with_full_precision(gemm_info.activation_info.a())); + build_opts.add_option("-DB_VAL=" + float_to_string_with_full_precision(gemm_info.activation_info.b())); + + std::string kernel_name("gemm_mm_reshaped_only_rhs_nt_mmul"); + kernel_name += rhs_info.export_to_cl_image ? "_texture" : ""; + + // A macro guard to compile ONLY the kernel of interest + build_opts.add_option("-D" + upper_string(kernel_name)); + + // Create kernel + _kernel = create_kernel(compile_context, kernel_name, build_opts.options()); + + // Set config_id for enabling LWS tuning + _config_id = kernel_name; + _config_id += "_"; + _config_id += (_add_bias ? "add_bias_" : ""); + _config_id += (gemm_info.broadcast_bias ? "broadcast_bias_" : ""); + _config_id += (gemm_info.activation_info.enabled() ? "fused_activation_" : ""); + _config_id += lower_string(string_from_data_type(src0->data_type())); + _config_id += "_"; + _config_id += support::cpp11::to_string(_m); + _config_id += "_"; + _config_id += support::cpp11::to_string(_n); + _config_id += "_"; + _config_id += support::cpp11::to_string(_k); + _config_id += "_"; + _config_id += support::cpp11::to_string(lhs_info.m0); + _config_id += "_"; + _config_id += support::cpp11::to_string(rhs_info.n0); + + ARM_COMPUTE_ERROR_ON(has_padding_changed(padding_info)); +} + +Status ClGemmMatrixMultiplyReshapedOnlyRhsMMULKernel::validate(const ITensorInfo *src0, const ITensorInfo *src1, const ITensorInfo *src2, const ITensorInfo *dst, float alpha, float beta, + const GEMMLHSMatrixInfo &lhs_info, + const GEMMRHSMatrixInfo &rhs_info, const GEMMKernelInfo &gemm_info) +{ + ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(src0, src1, src2, dst, alpha, beta, lhs_info, rhs_info, gemm_info)); + ARM_COMPUTE_RETURN_ON_ERROR(validate_and_configure_window(src0->clone().get(), + src1->clone().get(), + src2 != nullptr ? 
src2->clone().get() : nullptr, + dst->clone().get(), + lhs_info, + rhs_info, + gemm_info) + .first); + + return Status{}; +} + +void ClGemmMatrixMultiplyReshapedOnlyRhsMMULKernel::run_op(ITensorPack &tensors, const Window &window, cl::CommandQueue &queue) +{ + ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this); + ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(ICLKernel::window(), window); + + const auto src0 = utils::cast::polymorphic_downcast(tensors.get_const_tensor(TensorType::ACL_SRC_0)); + const auto src1 = utils::cast::polymorphic_downcast(tensors.get_const_tensor(TensorType::ACL_SRC_1)); + const auto src2 = utils::cast::polymorphic_downcast(tensors.get_const_tensor(TensorType::ACL_SRC_2)); + auto dst = utils::cast::polymorphic_downcast(tensors.get_tensor(TensorType::ACL_DST)); + + ARM_COMPUTE_ERROR_ON_NULLPTR(src0, src1, dst); + ARM_COMPUTE_ERROR_ON(_add_bias && src2 == nullptr); + + if(src1->info()->num_dimensions() < 3) + { + // The stride_z for matrix B must be zero if we do not slice + ARM_COMPUTE_ERROR_ON(src1->info()->strides_in_bytes()[3] != 0); + } + + cl::Image2D src1_image2d; + + if(_export_to_cl_image) + { + const TensorShape shape2d(src1->info()->dimension(0) / 4, src1->info()->dimension(1) * src1->info()->dimension(2)); + const size_t image_row_pitch = src1->info()->strides_in_bytes()[1]; + + src1_image2d = create_image2d_from_buffer(CLKernelLibrary::get().context(), src1->cl_buffer(), shape2d, src1->info()->data_type(), image_row_pitch); + } + + Window slice = window.first_slice_window_3D(); + + do + { + unsigned int idx = 0; + + add_3d_tensor_nhw_argument(idx, src0); + if(_export_to_cl_image) + { + _kernel.setArg(idx++, src1_image2d); + } + add_3d_tensor_nhw_argument(idx, src1); + + // Bias buffer (_add_bias == true) + if(_add_bias) + { + add_3d_tensor_nhw_argument(idx, src2); + } + // dst buffer + add_3d_tensor_nhw_argument(idx, dst); + + // Pass m, n and k at runtime as signed ints, to ensure results of any subtractions they could be operand in, would still be signed. + _kernel.setArg(idx++, _m); + _kernel.setArg(idx++, _n); + _kernel.setArg(idx++, _k); + + // LWS_x should be multiple of 16 at least. (32, 2) has been chosen to have more work-items on a single core + // LWS also enforces the order of execution of the workitems which improves cache utilization + enqueue(queue, *this, slice, cl::NDRange(32, 2), false); + } + while(window.slide_window_slice_3D(slice)); +} +} // namespace kernels +} // namespace opencl +} // namespace arm_compute diff --git a/src/gpu/cl/kernels/ClGemmMatrixMultiplyReshapedOnlyRhsMMULKernel.h b/src/gpu/cl/kernels/ClGemmMatrixMultiplyReshapedOnlyRhsMMULKernel.h new file mode 100644 index 0000000000..59612fcf5d --- /dev/null +++ b/src/gpu/cl/kernels/ClGemmMatrixMultiplyReshapedOnlyRhsMMULKernel.h @@ -0,0 +1,89 @@ +/* + * Copyright (c) 2022 Arm Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. 
+ * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ +#ifndef ARM_COMPUTE_CL_GEMM_MATRIXMULTIPLY_RESHAPED_ONLY_RHS_MMUL_KERNEL_H +#define ARM_COMPUTE_CL_GEMM_MATRIXMULTIPLY_RESHAPED_ONLY_RHS_MMUL_KERNEL_H + +#include "arm_compute/core/KernelDescriptors.h" +#include "src/core/common/Macros.h" +#include "src/gpu/cl/ClCompileContext.h" +#include "src/gpu/cl/IClKernel.h" + +namespace arm_compute +{ +namespace opencl +{ +namespace kernels +{ +/** OpenCL kernel to multiply matrices using MMUL when only the input matrix RHS (src1) has been reshaped */ +class ClGemmMatrixMultiplyReshapedOnlyRhsMMULKernel : public IClKernel +{ +public: + ClGemmMatrixMultiplyReshapedOnlyRhsMMULKernel(); + ARM_COMPUTE_DISALLOW_COPY_ALLOW_MOVE(ClGemmMatrixMultiplyReshapedOnlyRhsMMULKernel); + /** Initialize the kernel's input and dst. + * + * @param[in] compile_context The compile context to be used. + * @param[in] src0 Input tensor for the LHS matrix. Data type supported: F16/F32. + * @param[in] src1 Input tensor containing the RHS reshaped matrix. Data type supported: same as @p src0. + * @param[in] src2 Input tensor containing the bias matrix. Data type supported: same as @p src0. + * @param[out] dst dst tensor info. Data type supported: same as @p src0 + * @param[in] alpha Weight of the matrix product + * @param[in] beta Weight of the matrix bias + * @param[in] lhs_info LHS matrix information used to retrieve the number of rows and accumulations to be processed by each thread. Only the following values are supported: + * lhs_info.m0 > 0 + * lhs_info.k0: 1 + * @param[in] rhs_info RHS matrix information used to retrieve the number of columns and accumulations to be processed by each thread. 
Only the following values are supported: + * rhs_info.n0: 1,2,3,4,8,16 + * rhs_info.k0: same of lhs_info.k0 + * rhs_info.transpose: false + * @param[in] gemm_info GEMM information used to retrieve the original dimensions of the input matrices + */ + void configure(const ClCompileContext &compile_context, ITensorInfo *src0, ITensorInfo *src1, ITensorInfo *src2, ITensorInfo *dst, float alpha, float beta, + const GEMMLHSMatrixInfo &lhs_info, + const GEMMRHSMatrixInfo &rhs_info, + const GEMMKernelInfo &gemm_info); + /** Static function to check if given info will lead to a valid configuration + * + * Similar to @ref ClGemmMatrixMultiplyReshapedOnlyRhsMMULKernel::configure() + * + * @return a status + */ + static Status validate(const ITensorInfo *src0, const ITensorInfo *src1, const ITensorInfo *src2, const ITensorInfo *dst, float alpha, float beta, const GEMMLHSMatrixInfo &lhs_info, + const GEMMRHSMatrixInfo &rhs_info, + const GEMMKernelInfo &gemm_info); + + // Inherited methods overridden: + void run_op(ITensorPack &tensors, const Window &window, cl::CommandQueue &queue) override; + +private: + bool _add_bias{ false }; + bool _export_to_cl_image{ false }; + signed int _m{ 1 }; + signed int _n{ 1 }; + signed int _k{ 1 }; +}; +} // namespace kernels +} // namespace opencl +} // namespace arm_compute +#endif /* ARM_COMPUTE_CL_GEMM_MATRIXMULTIPLY_RESHAPED_ONLY_RHS_MMUL_KERNEL_H */ diff --git a/src/gpu/cl/kernels/gemm/ClGemmHelpers.cpp b/src/gpu/cl/kernels/gemm/ClGemmHelpers.cpp index 1bf27ba277..67da06102d 100644 --- a/src/gpu/cl/kernels/gemm/ClGemmHelpers.cpp +++ b/src/gpu/cl/kernels/gemm/ClGemmHelpers.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2021 Arm Limited. + * Copyright (c) 2019-2022 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -110,6 +110,23 @@ Status validate_image2d_support_on_rhs(const ITensorInfo &tensor_reshaped_info, return Status{}; } + +bool is_mmul_kernel_preferred(const unsigned int m, const unsigned int n, const unsigned int k, const unsigned int b, + const DataType data_type, unsigned int &best_m0, unsigned int &best_n0) +{ + ARM_COMPUTE_UNUSED(n, k, b, data_type); + + const unsigned int mmul_k0 = 4; + best_m0 = 4; + best_n0 = 4; + + const unsigned int ceil_to_multiple_m_m0 = ceil_to_multiple(m, best_m0); + const unsigned int m_div_m0 = ceil_to_multiple_m_m0 / best_m0; + const unsigned int ceil_to_multiple_m_div_m0_mmul_k0 = ceil_to_multiple(m_div_m0, mmul_k0); + const unsigned int gws_y = ceil_to_multiple_m_div_m0_mmul_k0 / mmul_k0; + + return ((k % mmul_k0) == 0) && (gws_y > 4); +} } // namespace gemm } // namespace kernels } // namespace opencl diff --git a/src/gpu/cl/kernels/gemm/ClGemmHelpers.h b/src/gpu/cl/kernels/gemm/ClGemmHelpers.h index 3fce8c9173..bf1e8fce82 100644 --- a/src/gpu/cl/kernels/gemm/ClGemmHelpers.h +++ b/src/gpu/cl/kernels/gemm/ClGemmHelpers.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2021 Arm Limited. + * Copyright (c) 2019-2022 Arm Limited. 
* * SPDX-License-Identifier: MIT * @@ -88,6 +88,21 @@ void update_padding_for_cl_image(ITensorInfo *tensor); * @return Status reporting if we can use the image2d OpenCL object on the RHS reshaped matrix */ Status validate_image2d_support_on_rhs(const ITensorInfo &tensor_reshaped_info, const GEMMRHSMatrixInfo &rhs_info); + +/** Determine if the MMUL kernels should be preferred + * + * @param[in] m Number of rows of the LHS matrix + * @param[in] n Number of columns of the RHS matrix + * @param[in] k Number of columns of the LHS matrix, rows of the RHS matrix + * @param[in] b Batch size + * @param[in] data_type Data type FP32/FP16 + * @param[in, out] best_m0 Suggested M0 (number of rows of the output block) for the kernel + * @param[in, out] best_n0 Suggested N0 (number of columns of the output block) for the kernel + * + * @return true if MMUL kernel is preferred over kernels w/o MMUL, false otherwise + */ +bool is_mmul_kernel_preferred(const unsigned int m, const unsigned int n, const unsigned int k, const unsigned int b, + const DataType data_type, unsigned int &best_m0, unsigned int &best_n0); } // namespace gemm } // namespace kernels } // namespace opencl diff --git a/src/gpu/cl/kernels/gemm/reshaped_only_rhs/ClGemmDefaultConfigReshapedRhsOnlyValhall.cpp b/src/gpu/cl/kernels/gemm/reshaped_only_rhs/ClGemmDefaultConfigReshapedRhsOnlyValhall.cpp index a82084a8df..97762980be 100644 --- a/src/gpu/cl/kernels/gemm/reshaped_only_rhs/ClGemmDefaultConfigReshapedRhsOnlyValhall.cpp +++ b/src/gpu/cl/kernels/gemm/reshaped_only_rhs/ClGemmDefaultConfigReshapedRhsOnlyValhall.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020-2021 Arm Limited. + * Copyright (c) 2020-2022 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -29,7 +29,9 @@ #include "arm_compute/core/TensorInfo.h" #include "arm_compute/core/TensorShape.h" #include "arm_compute/core/utils/misc/ShapeCalculator.h" + #include "src/gpu/cl/kernels/gemm/ClGemmHelpers.h" +#include "src/runtime/CL/gemm/CLGEMMDefaultTypeValhall.h" #include @@ -61,6 +63,10 @@ std::pair ClGemmDefaultConfigReshapedRhsOn &ClGemmDefaultConfigReshapedRhsOnlyValhall::configure_G78_f16, &ClGemmDefaultConfigReshapedRhsOnlyValhall::configure_G77_u8); + CLGEMMConfigArray configs_G715(&ClGemmDefaultConfigReshapedRhsOnlyValhall::configure_G715_f32, + &ClGemmDefaultConfigReshapedRhsOnlyValhall::configure_G715_f16, + &ClGemmDefaultConfigReshapedRhsOnlyValhall::configure_G77_u8); + ConfigurationFunctionExecutorPtr func = nullptr; switch(_target) @@ -68,6 +74,10 @@ std::pair ClGemmDefaultConfigReshapedRhsOn case GPUTarget::G78: func = configs_G78.get_function(data_type); break; + case GPUTarget::G715: + case GPUTarget::G615: + func = configs_G715.get_function(data_type); + break; case GPUTarget::G77: default: func = configs_G77.get_function(data_type); @@ -564,6 +574,36 @@ std::pair ClGemmDefaultConfigReshapedRhsOn } } } + +std::pair ClGemmDefaultConfigReshapedRhsOnlyValhall::configure_G715_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b) +{ + unsigned int best_m0; + unsigned int best_n0; + + if(is_mmul_kernel_preferred(m, n, k, b, DataType::F32, best_m0, best_n0)) + { + return configure_lhs_rhs_info(m, n, best_m0, best_n0, 1, 1, 4, false, true, false, false, true); + } + else + { + return configure_G77_f32(m, n, k, b); + } +} + +std::pair ClGemmDefaultConfigReshapedRhsOnlyValhall::configure_G715_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b) +{ + unsigned int best_m0; + unsigned int best_n0; + + if(is_mmul_kernel_preferred(m, n, k, b, DataType::F16, 
best_m0, best_n0)) + { + return configure_lhs_rhs_info(m, n, best_m0, best_n0, 1, 1, 4, false, true, false, false, true); + } + else + { + return configure_G78_f16(m, n, k, b); + } +} } // namespace gemm } // namespace kernels } // namespace opencl diff --git a/src/gpu/cl/kernels/gemm/reshaped_only_rhs/ClGemmDefaultConfigReshapedRhsOnlyValhall.h b/src/gpu/cl/kernels/gemm/reshaped_only_rhs/ClGemmDefaultConfigReshapedRhsOnlyValhall.h index c5e80a7ddc..0ec068fffd 100644 --- a/src/gpu/cl/kernels/gemm/reshaped_only_rhs/ClGemmDefaultConfigReshapedRhsOnlyValhall.h +++ b/src/gpu/cl/kernels/gemm/reshaped_only_rhs/ClGemmDefaultConfigReshapedRhsOnlyValhall.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020-2021 Arm Limited. + * Copyright (c) 2020-2022 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -53,6 +53,8 @@ private: std::pair configure_G78_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b); std::pair configure_G78_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b); std::pair configure_G77_u8(unsigned int m, unsigned int n, unsigned int k, unsigned int b); + std::pair configure_G715_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b); + std::pair configure_G715_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b); }; } // namespace gemm } // namespace kernels diff --git a/src/gpu/cl/operators/ClGemm.cpp b/src/gpu/cl/operators/ClGemm.cpp index 88f6b79b56..4db39a635d 100644 --- a/src/gpu/cl/operators/ClGemm.cpp +++ b/src/gpu/cl/operators/ClGemm.cpp @@ -191,6 +191,7 @@ ClGemm::ClGemm() _mm_native_kernel(std::make_unique()), _mm_reshaped_kernel(std::make_unique()), _mm_reshaped_only_rhs_kernel(std::make_unique()), + _mm_reshaped_only_rhs_mmul_kernel(std::make_unique()), _tmp_a(), _tmp_b(), _reshape_b_only_on_first_run(false), @@ -324,6 +325,53 @@ void ClGemm::configure_reshaped_only_rhs(const CLCompileContext &compile_context _aux_mem[RhsReshape] = MemoryInfo(offset_int_vec(RhsReshape), _reshape_b_only_on_first_run ? MemoryLifetime::Persistent : MemoryLifetime::Temporary, _tmp_b.total_size()); } +void ClGemm::configure_reshaped_only_rhs_mmul(const CLCompileContext &compile_context, ITensorInfo *a, ITensorInfo *b, ITensorInfo *c, ITensorInfo *output, float alpha, float beta, + const GEMMInfo &gemm_info) +{ + DataType data_type = a->data_type(); + bool reinterpret_input_as_3d = gemm_info.reinterpret_input_as_3d(); + const unsigned int m = reinterpret_input_as_3d ? (a->dimension(1) * a->dimension(2)) : a->dimension(1); + const unsigned int n = b->dimension(0); + const unsigned int k = a->dimension(0); + const unsigned int batch_size = reinterpret_input_as_3d ? 
a->dimension(3) : a->dimension(2); + const int depth_output_gemm3d = gemm_info.depth_output_gemm3d(); + const GPUTarget gpu_target = CLScheduler::get().target(); + bool broadcast_bias = gemm_info.broadcast_bias(); + + GEMMKernelInfo kernel_info; + kernel_info.m = m; + kernel_info.n = n; + kernel_info.k = k; + kernel_info.depth_output_gemm3d = depth_output_gemm3d; + kernel_info.reinterpret_input_as_3d = reinterpret_input_as_3d; + kernel_info.broadcast_bias = broadcast_bias; + kernel_info.activation_info = gemm_info.activation_info(); + kernel_info.post_ops = gemm_info.post_ops(); + + // Set the target for the kernels + _mm_reshaped_only_rhs_mmul_kernel->set_target(gpu_target); + + GEMMLHSMatrixInfo lhs_info{}; + GEMMRHSMatrixInfo rhs_info{}; + + // Pick up the GEMM configuration + auto gemm_config = select_default_gemm_config_reshaped_only_rhs(auto_heuristics::CommonQuery{ gpu_target, data_type, m, n, k, batch_size }); + lhs_info = gemm_config.lhs_info; + rhs_info = gemm_config.rhs_info; + // Force H0 to 4 in order to use the MMUL extension + rhs_info.h0 = 4; + + // Reshape Rhs matrix + _reshape_rhs_kernel->configure(compile_context, b, &_tmp_b, rhs_info); + + // Configure matrix multiply kernel with no y padding support + kernel_info.has_pad_y = false; + _mm_reshaped_only_rhs_mmul_kernel->configure(compile_context, a, &_tmp_b, c, output, alpha, beta, lhs_info, rhs_info, kernel_info); + + // Request memory for RHS reshape matrix + _aux_mem[RhsReshape] = MemoryInfo(offset_int_vec(RhsReshape), _reshape_b_only_on_first_run ? MemoryLifetime::Persistent : MemoryLifetime::Temporary, _tmp_b.total_size()); +} + Status ClGemm::validate_native(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *output, float alpha, float beta, const GEMMInfo &gemm_info) { ARM_COMPUTE_UNUSED(alpha); @@ -458,6 +506,54 @@ Status ClGemm::validate_reshaped_only_rhs(const ITensorInfo *a, const ITensorInf return Status{}; } +Status ClGemm::validate_reshaped_only_rhs_mmul(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *output, float alpha, float beta, const GEMMInfo &gemm_info) +{ + ARM_COMPUTE_UNUSED(alpha); + ARM_COMPUTE_UNUSED(output); + TensorInfo tmp_b_info{}; + + // Get the GPU target + const GPUTarget gpu_target = CLScheduler::get().target(); + const DataType data_type = a->data_type(); + bool reinterpret_input_as_3d = gemm_info.reinterpret_input_as_3d(); + const unsigned int m = reinterpret_input_as_3d ? (a->dimension(1) * a->dimension(2)) : a->dimension(1); + const unsigned int n = b->dimension(0); + const unsigned int k = a->dimension(0); + const unsigned int batch_size = reinterpret_input_as_3d ? 
a->dimension(3) : a->dimension(2); + const int depth_output_gemm3d = gemm_info.depth_output_gemm3d(); + const bool broadcast_bias = gemm_info.broadcast_bias(); + + GEMMKernelInfo kernel_info; + kernel_info.m = m; + kernel_info.n = n; + kernel_info.k = k; + kernel_info.depth_output_gemm3d = depth_output_gemm3d; + kernel_info.reinterpret_input_as_3d = reinterpret_input_as_3d; + kernel_info.broadcast_bias = broadcast_bias; + kernel_info.activation_info = gemm_info.activation_info(); + kernel_info.post_ops = gemm_info.post_ops(); + + GEMMLHSMatrixInfo lhs_info; + GEMMRHSMatrixInfo rhs_info; + + // Pick up the GEMM configuration + // NOTE: No need to validate mlgo configurations as they automatically fall back to default heuristics if validation fails + const auto gemm_config = select_default_gemm_config_reshaped_only_rhs(auto_heuristics::CommonQuery{ gpu_target, data_type, m, n, k, batch_size }); + lhs_info = gemm_config.lhs_info; + rhs_info = gemm_config.rhs_info; + // Force H0 to 4 in order to use the MMUL extension + rhs_info.h0 = 4; + + auto_init_if_empty(tmp_b_info, b->clone()->set_tensor_shape(compute_rhs_reshaped_shape(*b, rhs_info))); + ARM_COMPUTE_RETURN_ON_ERROR(ClGemmReshapeRhsMatrixKernel::validate(b, &tmp_b_info, rhs_info)); + + // Validate matrix multiply + kernel_info.has_pad_y = false; + ARM_COMPUTE_RETURN_ON_ERROR(ClGemmMatrixMultiplyReshapedOnlyRhsMMULKernel::validate(a, &tmp_b_info, c, output, alpha, beta, lhs_info, rhs_info, kernel_info)); + + return Status{}; +} + void ClGemm::configure(const CLCompileContext &compile_context, ITensorInfo *a, ITensorInfo *b, ITensorInfo *c, ITensorInfo *output, float alpha, float beta, const GEMMInfo &gemm_info) { ARM_COMPUTE_ERROR_ON_NULLPTR(a, b, output); @@ -501,6 +597,11 @@ void ClGemm::configure(const CLCompileContext &compile_context, ITensorInfo *a, configure_reshaped_only_rhs(compile_context, a, b, c_to_use, output, alpha, beta, gemm_info); break; } + case CLGEMMKernelType::RESHAPED_ONLY_RHS_MMUL: + { + configure_reshaped_only_rhs_mmul(compile_context, a, b, c_to_use, output, alpha, beta, gemm_info); + break; + } default: { ARM_COMPUTE_ERROR("GEMMType not supported"); @@ -545,6 +646,11 @@ Status ClGemm::validate(const ITensorInfo *a, const ITensorInfo *b, const ITenso ARM_COMPUTE_RETURN_ON_ERROR(validate_reshaped_only_rhs(a, b, c_to_use, output, alpha, beta, gemm_info)); break; } + case CLGEMMKernelType::RESHAPED_ONLY_RHS_MMUL: + { + ARM_COMPUTE_RETURN_ON_ERROR(validate_reshaped_only_rhs_mmul(a, b, c_to_use, output, alpha, beta, gemm_info)); + break; + } default: { ARM_COMPUTE_RETURN_ERROR_MSG("GEMMType not supported"); @@ -627,6 +733,34 @@ void ClGemm::run(ITensorPack &tensors) } break; } + case CLGEMMKernelType::RESHAPED_ONLY_RHS_MMUL: + { + if(!_reshape_b_only_on_first_run) + { + // Run transpose kernel + ITensorPack reshape_rhs_pack{ { ACL_SRC, rhs }, { ACL_DST, rhs_reshaped.get() } }; + CLScheduler::get().enqueue_op(*_reshape_rhs_kernel, reshape_rhs_pack, false); + } + // In case of RESHAPED_ONLY_RHS, we need to check the padding requirement + // Check if the lhs or dst tensors have padding + const unsigned int cross_plane_pad_lhs = lhs->info()->padding().top + lhs->info()->padding().bottom; + const unsigned int cross_plane_pad_dst = dst->info()->padding().top + dst->info()->padding().bottom; + bool has_pad_y = (cross_plane_pad_lhs != 0) || (cross_plane_pad_dst != 0); + + // Copy original tensor pack and overwrite rhs with reshaped counterpart + ITensorPack gemm_reshaped_onlyrhs_pack(tensors); + 
gemm_reshaped_onlyrhs_pack.add_const_tensor(ACL_SRC_1, rhs_reshaped.get()); + + if(has_pad_y) + { + ARM_COMPUTE_ERROR_ON(has_pad_y); + } + else + { + CLScheduler::get().enqueue_op(*_mm_reshaped_only_rhs_mmul_kernel, gemm_reshaped_onlyrhs_pack, true); + } + break; + } default: { ARM_COMPUTE_ERROR("GEMMType not supported"); diff --git a/src/gpu/cl/operators/ClGemm.h b/src/gpu/cl/operators/ClGemm.h index 3c0cad3ca4..aac463f0b8 100644 --- a/src/gpu/cl/operators/ClGemm.h +++ b/src/gpu/cl/operators/ClGemm.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2016-2021 Arm Limited. + * Copyright (c) 2016-2022 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -34,6 +34,7 @@ #include "src/gpu/cl/kernels/ClGemmMatrixMultiplyNativeKernel.h" #include "src/gpu/cl/kernels/ClGemmMatrixMultiplyReshapedKernel.h" #include "src/gpu/cl/kernels/ClGemmMatrixMultiplyReshapedOnlyRhsKernel.h" +#include "src/gpu/cl/kernels/ClGemmMatrixMultiplyReshapedOnlyRhsMMULKernel.h" #include "src/gpu/cl/kernels/ClGemmReshapeLhsMatrixKernel.h" #include "src/gpu/cl/kernels/ClGemmReshapeRhsMatrixKernel.h" @@ -50,6 +51,7 @@ namespace opencl * -# @ref kernels::ClGemmMatrixMultiplyNativeKernel (only if NATIVE is selected by the select_gemm_kernel method()) * -# @ref kernels::ClGemmMatrixMultiplyReshapedKernel (only if RESHAPED is selected by the select_gemm_kernel method()) * -# @ref kernels::ClGemmMatrixMultiplyReshapedOnlyRhsKernel (only if RESHAPED_ONLY_RHS is selected by the select_gemm_kernel method()) + * -# @ref kernels::ClGemmMatrixMultiplyReshapedOnlyRhsMMULKernel (only if RESHAPED_ONLY_RHS_MMUL is selected by the select_gemm_kernel method()) */ class ClGemm : public IClOperator { @@ -102,10 +104,12 @@ private: void configure_native(const CLCompileContext &compile_context, ITensorInfo *a, ITensorInfo *b, ITensorInfo *c, ITensorInfo *output, float alpha, float beta, const GEMMInfo &gemm_info); void configure_reshaped(const CLCompileContext &compile_context, ITensorInfo *a, ITensorInfo *b, ITensorInfo *c, ITensorInfo *output, float alpha, float beta, const GEMMInfo &gemm_info); void configure_reshaped_only_rhs(const CLCompileContext &compile_context, ITensorInfo *a, ITensorInfo *b, ITensorInfo *c, ITensorInfo *output, float alpha, float beta, const GEMMInfo &gemm_info); + void configure_reshaped_only_rhs_mmul(const CLCompileContext &compile_context, ITensorInfo *a, ITensorInfo *b, ITensorInfo *c, ITensorInfo *output, float alpha, float beta, const GEMMInfo &gemm_info); static Status validate_native(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *output, float alpha, float beta, const GEMMInfo &gemm_info); static Status validate_reshaped(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *output, float alpha, float beta, const GEMMInfo &gemm_info); static Status validate_reshaped_only_rhs(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *output, float alpha, float beta, const GEMMInfo &gemm_info); + static Status validate_reshaped_only_rhs_mmul(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *output, float alpha, float beta, const GEMMInfo &gemm_info); private: enum AuxTensorIdx @@ -116,17 +120,18 @@ private: }; private: - std::unique_ptr _reshape_lhs_kernel; - std::unique_ptr _reshape_rhs_kernel; - std::unique_ptr _mm_native_kernel; - std::unique_ptr _mm_reshaped_kernel; - std::unique_ptr _mm_reshaped_only_rhs_kernel; - TensorInfo _tmp_a; - TensorInfo _tmp_b; - bool _reshape_b_only_on_first_run; - 
CLGEMMKernelType _gemm_kernel_type; - bool _is_prepared; - experimental::MemoryRequirements _aux_mem{}; + std::unique_ptr _reshape_lhs_kernel; + std::unique_ptr _reshape_rhs_kernel; + std::unique_ptr _mm_native_kernel; + std::unique_ptr _mm_reshaped_kernel; + std::unique_ptr _mm_reshaped_only_rhs_kernel; + std::unique_ptr _mm_reshaped_only_rhs_mmul_kernel; + TensorInfo _tmp_a; + TensorInfo _tmp_b; + bool _reshape_b_only_on_first_run; + CLGEMMKernelType _gemm_kernel_type; + bool _is_prepared; + experimental::MemoryRequirements _aux_mem{}; }; } // namespace opencl } // namespace arm_compute diff --git a/src/runtime/CL/functions/CLGEMM.cpp b/src/runtime/CL/functions/CLGEMM.cpp index cc6689c504..427ea51ab9 100644 --- a/src/runtime/CL/functions/CLGEMM.cpp +++ b/src/runtime/CL/functions/CLGEMM.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017-2021 Arm Limited. + * Copyright (c) 2017-2022 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -30,7 +30,6 @@ #include "arm_compute/core/TensorInfo.h" #include "arm_compute/core/Types.h" #include "arm_compute/core/Utils.h" -#include "arm_compute/runtime/CL/functions/CLGEMM.h" #include "src/core/helpers/MemoryHelpers.h" #include "src/gpu/cl/operators/ClGemm.h" diff --git a/src/runtime/CL/gemm/CLGEMMDefaultTypeValhall.cpp b/src/runtime/CL/gemm/CLGEMMDefaultTypeValhall.cpp index 64271a8801..4c7daf916e 100644 --- a/src/runtime/CL/gemm/CLGEMMDefaultTypeValhall.cpp +++ b/src/runtime/CL/gemm/CLGEMMDefaultTypeValhall.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020-2021 Arm Limited. + * Copyright (c) 2020-2022 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -79,10 +79,28 @@ CLGEMMKernelType CLGEMMDefaultTypeValhall::select_kernel(const CLGEMMKernelSelec { DataType::QSYMM8_PER_CHANNEL, &CLGEMMDefaultTypeValhall::default_q8 } }; + // Mali-G715 and Mali-G615 configurations + static std::map gemm_g715_configs = + { + { DataType::F32, &CLGEMMDefaultTypeValhall::g715_f32 }, + { DataType::F16, &CLGEMMDefaultTypeValhall::g715_f16 }, + { DataType::QASYMM8, &CLGEMMDefaultTypeValhall::default_q8 }, + { DataType::QASYMM8_SIGNED, &CLGEMMDefaultTypeValhall::default_q8 }, + { DataType::QSYMM8, &CLGEMMDefaultTypeValhall::default_q8 }, + { DataType::QSYMM8_PER_CHANNEL, &CLGEMMDefaultTypeValhall::default_q8 } + }; + const DataType data_type = params.data_type; switch(_target) { + case GPUTarget::G715: + case GPUTarget::G615: + if(gemm_g715_configs.find(data_type) != gemm_g715_configs.end()) + { + return (this->*gemm_g715_configs[data_type])(params.m, params.n, params.k, params.b, params.is_rhs_constant); + } + ARM_COMPUTE_ERROR("Not supported data type"); case GPUTarget::G78: if(gemm_g78_configs.find(data_type) != gemm_g78_configs.end()) { @@ -306,5 +324,46 @@ CLGEMMKernelType CLGEMMDefaultTypeValhall::g78_f16(unsigned int m, unsigned int return CLGEMMKernelType::RESHAPED_ONLY_RHS; } + +CLGEMMKernelType CLGEMMDefaultTypeValhall::g715_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b, bool is_rhs_constant) +{ + if(!is_rhs_constant) + { + return default_f32(m, n, k, b, is_rhs_constant); + } + + unsigned int best_m0; + unsigned int best_n0; + + if(opencl::kernels::gemm::is_mmul_kernel_preferred(m, n, k, b, DataType::F32, best_m0, best_n0)) + { + return CLGEMMKernelType::RESHAPED_ONLY_RHS_MMUL; + } + else + { + return default_f32(m, n, k, b, is_rhs_constant); + } +} + +CLGEMMKernelType CLGEMMDefaultTypeValhall::g715_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b, bool is_rhs_constant) +{ + if(!is_rhs_constant) + { + return g78_f16(m, n, k, b, 
is_rhs_constant); + } + + unsigned int best_m0; + unsigned int best_n0; + + if(opencl::kernels::gemm::is_mmul_kernel_preferred(m, n, k, b, DataType::F16, best_m0, best_n0)) + { + return CLGEMMKernelType::RESHAPED_ONLY_RHS_MMUL; + } + else + { + return g78_f16(m, n, k, b, is_rhs_constant); + } +} + } // namespace cl_gemm } // namespace arm_compute diff --git a/src/runtime/CL/gemm/CLGEMMDefaultTypeValhall.h b/src/runtime/CL/gemm/CLGEMMDefaultTypeValhall.h index c88fbcf557..0893f11132 100644 --- a/src/runtime/CL/gemm/CLGEMMDefaultTypeValhall.h +++ b/src/runtime/CL/gemm/CLGEMMDefaultTypeValhall.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020-2021 Arm Limited. + * Copyright (c) 2020-2022 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -50,6 +50,8 @@ private: CLGEMMKernelType g77_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b, bool is_rhs_constant); CLGEMMKernelType g78_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b, bool is_rhs_constant); CLGEMMKernelType g78_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b, bool is_rhs_constant); + CLGEMMKernelType g715_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b, bool is_rhs_constant); + CLGEMMKernelType g715_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b, bool is_rhs_constant); }; } // namespace cl_gemm } // namespace arm_compute diff --git a/tests/validation/CL/GEMMMatrixMultiplyReshapedOnlyRhsMMUL.cpp b/tests/validation/CL/GEMMMatrixMultiplyReshapedOnlyRhsMMUL.cpp new file mode 100644 index 0000000000..7808be8529 --- /dev/null +++ b/tests/validation/CL/GEMMMatrixMultiplyReshapedOnlyRhsMMUL.cpp @@ -0,0 +1,231 @@ +/* + * Copyright (c) 2022 Arm Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. 
+ */ + +#include "src/gpu/cl/kernels/ClGemmMatrixMultiplyReshapedOnlyRhsMMULKernel.h" +#include "src/gpu/cl/kernels/ClGemmReshapeRhsMatrixKernel.h" +#include "tests/CL/CLAccessor.h" +#include "tests/CL/Helper.h" +#include "tests/framework/Macros.h" +#include "tests/framework/datasets/Datasets.h" +#include "tests/validation/Validation.h" +#include "tests/validation/fixtures/GEMMFixture.h" + +namespace arm_compute +{ +namespace test +{ +namespace validation +{ +using namespace arm_compute::opencl::kernels; + +// Create function for ClGemmReshapeRhsMatrixKernel +using CLGEMMReshapeRHSMatrix = CLSynthetizeOperator; + +// Create function for ClGemmMatrixMultiplyReshapedOnlyRhsMMULKernel +using CLGEMMMatrixMultiplyReshapedOnlyRhsMMUL = CLSynthetizeOperator; + +// Fixture for CLGEMMMatrixMultiplyReshapedOnlyRhsMMUL +template +using CLGEMMMatrixMultiplyReshapedOnlyRhsMMULFixture = GEMMMatrixMultiplyReshapedOnlyRhsMMULValidationFixture; + +namespace +{ +// *INDENT-OFF* +// clang-format off +RelativeTolerance rel_tolerance_f32(0.001f); +constexpr float abs_tolerance_f32(0.0001f); +RelativeTolerance rel_tolerance_f16(half_float::half(0.001f)); +constexpr float abs_tolerance_f16(0.3f); + +/** Alpha values to test - Precommit */ +const auto a_values = framework::dataset::make("alpha", {1.0f, 0.75f} ); + +/** Beta values to test - Precommit */ +const auto beta_values = framework::dataset::make("beta", {0.0f, -0.75f} ); + +/** M values to test */ +const auto m_values = framework::dataset::make("M", {49}); + +/** N values to test */ +const auto n_values = framework::dataset::make("N", {257}); + +/** K values to test */ +/** The test case requires this to be multiple of 4*/ +const auto k_values = framework::dataset::make("K", {192}); + +/** Batch size values to test */ +const auto b_values = framework::dataset::make("batch_size", {1, 2}); + +/** Activation values to test */ +const auto act_values = framework::dataset::make("Activation", +{ + ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::RELU), +}); + +/** M0 values to test - Precommit */ +const auto m0_values_precommit = framework::dataset::make("M0", { 1, 2, 4 }); + +/** N0 values to test - Precommit */ +const auto n0_values_precommit = framework::dataset::make("N0", { 4, 8 }); + +/** K0 values to test - Precommit */ +const auto k0_values_precommit = framework::dataset::make("K0", { 1 }); + +/** Broadcast bias from vector to matrix */ +const auto broadcast_bias_values = framework::dataset::make("broadcast_bias", { false, true } ); + +} // namespace + +TEST_SUITE(CL) +TEST_SUITE(GEMMMatrixMultiplyReshapedOnlyRhsMMUL) +TEST_SUITE(Float) +TEST_SUITE(FP32) +FIXTURE_DATA_TEST_CASE(RunSmall, CLGEMMMatrixMultiplyReshapedOnlyRhsMMULFixture, framework::DatasetMode::ALL, + combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine( + m_values, + n_values), + k_values), + b_values), + m0_values_precommit), + n0_values_precommit), + k0_values_precommit), + framework::dataset::make("ExportToCLImage", false)), + framework::dataset::make("DataType", DataType::F32)), + a_values), + beta_values), + broadcast_bias_values), + act_values)) +{ + // Validate output + if(validate_result) + { + validate(CLAccessor(_target), _reference, rel_tolerance_f32, 0.f, abs_tolerance_f32); + } + else + { + ARM_COMPUTE_TEST_INFO("cl_arm_matrix_multiply not supported. 
TEST skipped"); + framework::ARM_COMPUTE_PRINT_INFO(); + } +} + +TEST_SUITE_END() // FP32 + +TEST_SUITE(FP16) +FIXTURE_DATA_TEST_CASE(RunSmall, CLGEMMMatrixMultiplyReshapedOnlyRhsMMULFixture, framework::DatasetMode::ALL, + combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine( + m_values, + n_values), + k_values), + b_values), + m0_values_precommit), + n0_values_precommit), + k0_values_precommit), + framework::dataset::make("ExportToCLImage", false)), + framework::dataset::make("DataType", DataType::F16)), + a_values), + beta_values), + broadcast_bias_values), + act_values)) +{ + // Validate output + if(validate_result) + { + validate(CLAccessor(_target), _reference, rel_tolerance_f16, 0.f, abs_tolerance_f16); + } + else + { + ARM_COMPUTE_TEST_INFO("cl_arm_matrix_multiply not supported. TEST skipped"); + framework::ARM_COMPUTE_PRINT_INFO(); + } +} +TEST_SUITE_END() // FP16 + +TEST_SUITE(ExportToCLImage) +TEST_SUITE(FP32) +FIXTURE_DATA_TEST_CASE(RunSmall, CLGEMMMatrixMultiplyReshapedOnlyRhsMMULFixture, framework::DatasetMode::ALL, + combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine( + m_values, + n_values), + k_values), + b_values), + m0_values_precommit), + n0_values_precommit), + k0_values_precommit), + framework::dataset::make("ExportToCLImage", true)), + framework::dataset::make("DataType", DataType::F32)), + a_values), + beta_values), + broadcast_bias_values), + act_values)) +{ + // Validate output + if(validate_result) + { + validate(CLAccessor(_target), _reference, rel_tolerance_f32, 0.f, abs_tolerance_f32); + } + else + { + ARM_COMPUTE_TEST_INFO("cl_arm_matrix_multiply not supported. TEST skipped"); + framework::ARM_COMPUTE_PRINT_INFO(); + } +} + +TEST_SUITE_END() // FP32 + +TEST_SUITE(FP16) +FIXTURE_DATA_TEST_CASE(RunSmall, CLGEMMMatrixMultiplyReshapedOnlyRhsMMULFixture, framework::DatasetMode::ALL, + combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine( + m_values, + n_values), + k_values), + b_values), + m0_values_precommit), + n0_values_precommit), + k0_values_precommit), + framework::dataset::make("ExportToCLImage", true)), + framework::dataset::make("DataType", DataType::F16)), + a_values), + beta_values), + broadcast_bias_values), + act_values)) +{ + // Validate output + if(validate_result) + { + validate(CLAccessor(_target), _reference, rel_tolerance_f16, 0.f, abs_tolerance_f16); + } + else + { + ARM_COMPUTE_TEST_INFO("cl_arm_matrix_multiply not supported. TEST skipped"); + framework::ARM_COMPUTE_PRINT_INFO(); + } +} +TEST_SUITE_END() // FP16 +TEST_SUITE_END() // ExportToCLImage +TEST_SUITE_END() // Float +TEST_SUITE_END() // GEMMMatrixMultiplyReshapedOnlyRhsMMUL +TEST_SUITE_END() // CL +} // namespace validation +} // namespace test +} // namespace arm_compute diff --git a/tests/validation/fixtures/GEMMFixture.h b/tests/validation/fixtures/GEMMFixture.h index 884b13da80..55bbbdaf80 100644 --- a/tests/validation/fixtures/GEMMFixture.h +++ b/tests/validation/fixtures/GEMMFixture.h @@ -163,18 +163,18 @@ protected: const int m = reinterpret_output_as_3d ? output_shape[1] * output_shape[2] : output_shape[1]; const int batch_size = reinterpret_output_as_3d ? 
output_shape[3] : output_shape[2]; - // In case of broadcast, we need simply copy the first into the following "M" ones + // In case of broadcast, we need to simply copy the first into the following "M" ones for(int i = 1; i < m * batch_size; i++) { memcpy(c.data() + i * n, c.data(), n * sizeof(T)); } } - + /* Note: Assuming the usual batch matmul dimensions A = (B x M x K), B = (B x K x N), if pretranspose_A is set to true, then A is assumed to be (B x K x M), therefore, A must be pre-transposed before passing it to the fixture. And, we transpose A again in the fixture to make it (B x M x K) in order to be able to call reference implementation that works with (B x M x K) input. Similarly, if pretranspose_B is set to true, then B is assumed to be (B x N x K), B must be pre-transposed before passing it to the fixture. */ - + // Define transposed shapes TensorShape a_transposed_shape(a.shape().y(), a.shape().x()); TensorShape b_transposed_shape(b.shape().y(), b.shape().x()); @@ -315,7 +315,7 @@ protected: if(broadcast_bias) { - // In case of broadcast, we need simply copy the first into the following "M" ones + // In case of broadcast, we need to simply copy the first into the following "M" ones for(int i = 1; i < m * batch_size; i++) { memcpy(bias.data() + i * n, bias.data(), n * sizeof(T)); @@ -438,7 +438,7 @@ protected: fill(rhs, 1); fill(bias, 2); - // In case of broadcast, we need simply copy the first into the following "M" ones + // In case of broadcast, we need to simply copy the first into the following "M" ones for(int i = 1; i < m * batch_size; i++) { memcpy(bias.data() + i * n, bias.data(), n * sizeof(T)); @@ -593,7 +593,7 @@ protected: if(broadcast_bias) { - // In case of broadcast, we need simply copy the first into the following "M" ones + // In case of broadcast, we need to simply copy the first into the following "M" ones for(int i = 1; i < m * batch_size; i++) { memcpy(bias.data() + i * n, bias.data(), n * sizeof(T)); @@ -748,7 +748,7 @@ protected: fill(rhs, 1); fill(bias, 2); - // In case of broadcast, we need simply copy the first into the following "M" ones + // In case of broadcast, we need to simply copy the first into the following "M" ones for(int i = 1; i < m * batch_size; i++) { memcpy(bias.data() + i * n, bias.data(), n * sizeof(T)); @@ -923,7 +923,7 @@ protected: if(broadcast_bias) { - // In case of broadcast, we need simply copy the first into the following "M" ones + // In case of broadcast, we need to simply copy the first into the following "M" ones for(int i = 1; i < m * batch_size; i++) { memcpy(bias.data() + i * n, bias.data(), n * sizeof(T)); @@ -1169,7 +1169,7 @@ protected: if(broadcast_bias) { - // In case of broadcast, we need simply copy the first into the following "M" ones + // In case of broadcast, we need to simply copy the first into the following "M" ones for(int i = 1; i < m * batch_size; i++) { memcpy(bias.data() + i * n, bias.data(), n * sizeof(T)); @@ -1361,7 +1361,7 @@ protected: fill(rhs, 1); fill(bias, 2); - // In case of broadcast, we need simply copy the first into the following "M" ones + // In case of broadcast, we need to simply copy the first into the following "M" ones for(int i = 1; i < m * batch_size; i++) { memcpy(bias.data() + i * n, bias.data(), n * sizeof(T)); @@ -1533,7 +1533,7 @@ protected: if(broadcast_bias) { - // In case of broadcast, we need simply copy the first into the following "M" ones + // In case of broadcast, we need to simply copy the first into the following "M" ones for(int i = 1; i < m * batch_size; i++) { 
memcpy(bias.data() + i * n, bias.data(), n * sizeof(T)); @@ -1759,7 +1759,7 @@ protected: if(broadcast_bias) { - // In case of broadcast, we need simply copy the first into the following "M" ones + // In case of broadcast, we need to simply copy the first into the following "M" ones for(int i = 1; i < m * batch_size; i++) { memcpy(bias.data() + i * n, bias.data(), n * sizeof(T)); @@ -1941,7 +1941,7 @@ protected: fill(rhs, 1); fill(bias, 2); - // In case of broadcast, we need simply copy the first into the following "M" ones + // In case of broadcast, we need to simply copy the first into the following "M" ones for(int i = 1; i < m * batch_size; i++) { memcpy(bias.data() + i * n, bias.data(), n * sizeof(T)); @@ -2078,7 +2078,7 @@ protected: if(broadcast_bias) { - // In case of broadcast, we need simply copy the first into the following "M" ones + // In case of broadcast, we need to simply copy the first into the following "M" ones for(int i = 1; i < m * batch_size; i++) { memcpy(bias.data() + i * n, bias.data(), n * sizeof(T)); @@ -2274,7 +2274,7 @@ protected: if(broadcast_bias) { - // In case of broadcast, we need simply copy the first into the following "M" ones + // In case of broadcast, we need to simply copy the first into the following "M" ones for(int i = 1; i < m * batch_size; i++) { memcpy(bias.data() + i * n, bias.data(), n * sizeof(T)); @@ -2421,7 +2421,7 @@ protected: fill(rhs, 1); fill(bias, 2); - // In case of broadcast, we need simply copy the first into the following "M" ones + // In case of broadcast, we need to simply copy the first into the following "M" ones for(int i = 1; i < m * batch_size; i++) { memcpy(bias.data() + i * n, bias.data(), n * sizeof(T)); @@ -2434,6 +2434,171 @@ protected: SimpleTensor _reference{}; }; +template +class GEMMMatrixMultiplyReshapedOnlyRhsMMULValidationFixture : public framework::Fixture +{ +public: + template + void setup(unsigned int m, unsigned int n, unsigned int k, unsigned int batch_size, unsigned int m0, unsigned int n0, unsigned int k0, bool export_to_cl_image, DataType data_type, float alpha, + float beta, bool broadcast_bias, + const ActivationLayerInfo &act_info) + { + GEMMLHSMatrixInfo lhs_info; + lhs_info.m0 = m0; + lhs_info.k0 = k0; + + GEMMRHSMatrixInfo rhs_info; + rhs_info.n0 = n0; + rhs_info.k0 = k0; + rhs_info.interleave = true; + rhs_info.transpose = false; + rhs_info.h0 = 4; + rhs_info.export_to_cl_image = export_to_cl_image; + + // Set the tensor shapes for LHS and RHS matrices + const TensorShape lhs_shape(k, m, batch_size); + const TensorShape rhs_shape(n, k, batch_size); + const TensorShape bias_shape(n, + broadcast_bias ? 1 : m, + broadcast_bias ? 1 : batch_size); + + _target = compute_target(lhs_shape, rhs_shape, bias_shape, lhs_info, rhs_info, data_type, alpha, beta, broadcast_bias, act_info); + _reference = compute_reference(lhs_shape, rhs_shape, data_type, alpha, beta, broadcast_bias, act_info); + } + +protected: + template + void fill(U &&tensor, int i) + { + static_assert(std::is_floating_point::value || std::is_same::value, "Only floating point data types supported."); + using DistributionType = typename std::conditional::value, arm_compute::utils::uniform_real_distribution_16bit, std::uniform_real_distribution>::type; + + DistributionType distribution{ T(-1.0f), T(1.0f) }; + library->fill(tensor, distribution, i); + + // Fill border with infinity in order to check the presence of NaN values (i.e. 
inf * 0) + DistributionType distribution_inf{ T(std::numeric_limits::infinity()), T(std::numeric_limits::infinity()) }; + library->fill_borders_with_garbage(tensor, distribution_inf, i); + } + + TensorType compute_target(const TensorShape &lhs_shape, const TensorShape &rhs_shape, const TensorShape &bias_shape, const GEMMLHSMatrixInfo &lhs_info, const GEMMRHSMatrixInfo &rhs_info, + DataType data_type, float alpha, float beta, bool broadcast_bias, const ActivationLayerInfo &act_info) + { + // Create tensors + TensorType lhs = create_tensor(lhs_shape, data_type, 1); + TensorType rhs = create_tensor(rhs_shape, data_type, 1); + TensorType bias = create_tensor(bias_shape, data_type, 1); + TensorType rhs_reshaped; + TensorType dst; + + const unsigned int M = lhs_shape[1]; + const unsigned int N = rhs_shape[0]; + const unsigned int K = lhs_shape[0]; + GEMMKernelInfo kernel_info; + kernel_info.m = M; + kernel_info.n = N; + kernel_info.k = K; + kernel_info.depth_output_gemm3d = 0; + kernel_info.reinterpret_input_as_3d = false; + kernel_info.broadcast_bias = broadcast_bias; + kernel_info.activation_info = act_info; + + // Create and configure function + ReshapeRHSOperatorType reshape_rhs; + GEMMOperatorType gemm; + + validate_result = bool(reshape_rhs.validate(rhs.info(), rhs_reshaped.info(), rhs_info)); + if(!validate_result) + { + return nullptr; + } + + reshape_rhs.configure(rhs.info(), rhs_reshaped.info(), rhs_info); + + validate_result = bool(gemm.validate(lhs.info(), rhs_reshaped.info(), bias.info(), dst.info(), alpha, beta, lhs_info, rhs_info, kernel_info)); + if(!validate_result) + { + return nullptr; + } + + gemm.configure(lhs.info(), rhs_reshaped.info(), bias.info(), dst.info(), alpha, beta, lhs_info, rhs_info, kernel_info); + + ARM_COMPUTE_ASSERT(lhs.info()->is_resizable()); + ARM_COMPUTE_ASSERT(rhs.info()->is_resizable()); + ARM_COMPUTE_ASSERT(bias.info()->is_resizable()); + + // Allocate tensors + lhs.allocator()->allocate(); + rhs.allocator()->allocate(); + rhs_reshaped.allocator()->allocate(); + bias.allocator()->allocate(); + dst.allocator()->allocate(); + + ARM_COMPUTE_ASSERT(!lhs.info()->is_resizable()); + ARM_COMPUTE_ASSERT(!rhs.info()->is_resizable()); + ARM_COMPUTE_ASSERT(!rhs_reshaped.info()->is_resizable()); + ARM_COMPUTE_ASSERT(!bias.info()->is_resizable()); + ARM_COMPUTE_ASSERT(!dst.info()->is_resizable()); + + // Fill tensors + fill(AccessorType(lhs), 0); + fill(AccessorType(rhs), 1); + fill(AccessorType(bias), 2); + + // Compute GEMM + ITensorPack reshape_rhs_pack = { { ACL_SRC, &rhs }, { ACL_DST, &rhs_reshaped } }; + reshape_rhs.run(reshape_rhs_pack); + ITensorPack gemm_pack({ { ACL_SRC_0, &lhs }, + { ACL_SRC_1, &rhs_reshaped }, + { ACL_SRC_2, &bias }, + { ACL_DST, &dst } + }); + gemm.run(gemm_pack); + + return dst; + } + + SimpleTensor compute_reference(const TensorShape &lhs_shape, const TensorShape &rhs_shape, DataType data_type, float alpha, float beta, bool broadcast_bias, + const ActivationLayerInfo &act_info) + { + if(!validate_result) + return SimpleTensor(); + + TensorShape dst_shape = lhs_shape; + dst_shape[0] = rhs_shape[0]; + dst_shape[1] = lhs_shape[1]; + + // Create reference + SimpleTensor lhs{ lhs_shape, data_type, 1 }; + SimpleTensor rhs{ rhs_shape, data_type, 1 }; + SimpleTensor bias{ dst_shape, data_type, 1 }; + + const int n = rhs_shape[0]; + const int m = lhs_shape[1]; + const int batch_size = lhs_shape[2]; + + // Fill reference + fill(lhs, 0); + fill(rhs, 1); + fill(bias, 2); + + if(broadcast_bias) + { + // In case of broadcast, we need to simply copy 
the first into the following "M" ones + for(int i = 1; i < m * batch_size; i++) + { + memcpy(bias.data() + i * n, bias.data(), n * sizeof(T)); + } + } + + return reference::activation_layer(reference::gemm(lhs, rhs, bias, alpha, beta), act_info); + } + + bool validate_result = true; + TensorType _target{}; + SimpleTensor _reference{}; +}; + } // namespace validation } // namespace test } // namespace arm_compute diff --git a/utils/TypePrinter.h b/utils/TypePrinter.h index dae81e4a5a..31eff57e6b 100644 --- a/utils/TypePrinter.h +++ b/utils/TypePrinter.h @@ -2358,6 +2358,12 @@ inline ::std::ostream &operator<<(::std::ostream &os, const GPUTarget &gpu_targe case GPUTarget::G710: os << "G710"; break; + case GPUTarget::G715: + os << "G715"; + break; + case GPUTarget::G615: + os << "G615"; + break; default: ARM_COMPUTE_ERROR("NOT_SUPPORTED!"); } -- cgit v1.2.1
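
Editor's note — usage sketch (not part of the patch): the hunks above route the new RESHAPED_ONLY_RHS_MMUL kernel type through the Valhall heuristic, so callers of the public CLGEMM function do not change; on an Arm(R) Mali(TM)-G715/G615 device with the cl_arm_matrix_multiply extension the MMUL path may simply be picked for them. The sketch below assumes the standard CLGEMM public API and reuses the FP32 precommit problem size from the tests above (M=49, N=257, K=192, K being a multiple of 4); the fill step and main() are illustrative only.

    #include "arm_compute/core/TensorInfo.h"
    #include "arm_compute/core/Types.h"
    #include "arm_compute/runtime/CL/CLScheduler.h"
    #include "arm_compute/runtime/CL/CLTensor.h"
    #include "arm_compute/runtime/CL/functions/CLGEMM.h"

    using namespace arm_compute;

    int main()
    {
        // Initialise the OpenCL scheduler/context. On Mali-G715/G615 the
        // CLGEMMDefaultTypeValhall heuristic shown above may return
        // RESHAPED_ONLY_RHS_MMUL for FP32/FP16 when the RHS is constant.
        CLScheduler::get().default_init();

        // Problem size taken from the precommit datasets in the new test file.
        const unsigned int M = 49, N = 257, K = 192;

        // ACL shapes are (width, height): LHS is (K, M), RHS is (K x N) stored as (N, K).
        CLTensor a, b, dst;
        a.allocator()->init(TensorInfo(TensorShape(K, M), 1, DataType::F32));
        b.allocator()->init(TensorInfo(TensorShape(N, K), 1, DataType::F32));
        dst.allocator()->init(TensorInfo(TensorShape(N, M), 1, DataType::F32));

        // No bias (c = nullptr), alpha = 1, beta = 0; kernel selection is internal.
        CLGEMM gemm;
        gemm.configure(&a, &b, nullptr, &dst, 1.0f, 0.0f);

        a.allocator()->allocate();
        b.allocator()->allocate();
        dst.allocator()->allocate();

        // ... fill a and b with input data here (illustrative, omitted) ...

        gemm.run();
        CLScheduler::get().sync();
        return 0;
    }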