From eb65f6da695ac0d3e495817145cceb1c4de4f048 Mon Sep 17 00:00:00 2001 From: Gian Marco Iodice Date: Wed, 15 Apr 2020 11:42:15 +0100 Subject: COMPMID-3304: Update OpenCL GEMM heuristic for Int8 Change-Id: I6b7ff678d8d0437a1639db2ff602ea1cdb155464 Signed-off-by: Gian Marco Iodice Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/3056 Tested-by: Arm Jenkins Reviewed-by: Georgios Pinitas Comments-Addressed: Arm Jenkins --- src/core/CL/CLKernelLibrary.cpp | 1 - src/core/CL/cl_kernels/gemmlowp.cl | 285 ------------------- .../CLGEMMNativeKernelConfigurationBifrost.cpp | 45 ++- .../CLGEMMNativeKernelConfigurationMidgard.cpp | 75 +++++ .../CLGEMMNativeKernelConfigurationValhall.cpp | 7 +- .../CLGEMMReshapedKernelConfigurationBifrost.cpp | 10 +- .../CLGEMMReshapedKernelConfigurationValhall.cpp | 5 +- ...MMReshapedOnlyRHSKernelConfigurationBifrost.cpp | 20 +- ...MMReshapedOnlyRHSKernelConfigurationValhall.cpp | 15 +- .../CL/kernels/CLGEMMLowpMatrixMultiplyKernel.cpp | 307 --------------------- .../CL/functions/CLGEMMLowpMatrixMultiplyCore.cpp | 110 +++----- .../CL/gemm/CLGEMMKernelSelectionBifrost.cpp | 20 +- .../CL/gemm/CLGEMMKernelSelectionMidgard.cpp | 5 +- .../CL/gemm/CLGEMMKernelSelectionValhall.cpp | 13 +- 14 files changed, 207 insertions(+), 711 deletions(-) create mode 100644 src/core/CL/gemm/native/CLGEMMNativeKernelConfigurationMidgard.cpp delete mode 100644 src/core/CL/kernels/CLGEMMLowpMatrixMultiplyKernel.cpp (limited to 'src') diff --git a/src/core/CL/CLKernelLibrary.cpp b/src/core/CL/CLKernelLibrary.cpp index 00e7b2bc5c..d4073c6f30 100644 --- a/src/core/CL/CLKernelLibrary.cpp +++ b/src/core/CL/CLKernelLibrary.cpp @@ -231,7 +231,6 @@ const std::map CLKernelLibrary::_kernel_program_map = { "gemmlowp_matrix_a_reduction", "gemmlowp.cl" }, { "gemmlowp_matrix_a_reduction_dot8", "gemmlowp.cl" }, { "gemmlowp_matrix_b_reduction", "gemmlowp.cl" }, - { "gemmlowp_mm_midgard", "gemmlowp.cl" }, { "gemmlowp_mm_native", "gemmlowp.cl" }, { "gemmlowp_mm_reshaped_lhs_nt_rhs_t", "gemmlowp.cl" }, { "gemmlowp_mm_reshaped_only_rhs_t", "gemmlowp.cl" }, diff --git a/src/core/CL/cl_kernels/gemmlowp.cl b/src/core/CL/cl_kernels/gemmlowp.cl index 127df063f6..d9625e7117 100644 --- a/src/core/CL/cl_kernels/gemmlowp.cl +++ b/src/core/CL/cl_kernels/gemmlowp.cl @@ -195,291 +195,6 @@ (n0, k0, a, b, c); \ }) -#if defined(NUM_ELEMS_PROCESSED_PER_THREAD_X) && defined(NUM_ELEMS_PROCESSED_PER_THREAD_Y) && defined(COLS_A) -#define VECTOR_TYPE VEC_DATA_TYPE(DATA_TYPE, NUM_ELEMS_PROCESSED_PER_THREAD_X) -#define VECTOR_ACC_TYPE VEC_DATA_TYPE(ACC_DATA_TYPE, NUM_ELEMS_PROCESSED_PER_THREAD_X) -#define VECTOR_INT VEC_DATA_TYPE(int, NUM_ELEMS_PROCESSED_PER_THREAD_X) -/** This OpenCL kernel computes the matrix multiplication between matrix A (src0) and matrix B (src1) in case both matrices have not been reshaped - * - * @attention The number of matrix A columns needs to be passed at compile time using -DCOLS_A - * - * @note The input data type must be passed at compile time using -DDATA_TYPE (i.e. -DDATA_TYPE=uchar) - * @note The accumulator data type must be passed at compile time using -DACC_DATA_TYPE (i.e. -DACC_DATA_TYPE=uint) - * @note In case the input or output have to be reinterpreted as a 3D tensor, the following information must be passed at compile time: - * -# REINTERPRET_INPUT_AS_3D: To reinterpret the input as 3D - * -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D - * -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor. 
- * -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor - * (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns matrix A NOT reshaped - * - * @param[in] src0_ptr Pointer to the source matrix. Supported data type: QASYMM8 - * @param[in] src0_stride_x Stride of the source matrix in X dimension (in bytes) - * @param[in] src0_step_x src_stride_x * number of elements along X processed per workitem(in bytes) - * @param[in] src0_stride_y Stride of the source matrix in Y dimension (in bytes) - * @param[in] src0_step_y src_stride_y * number of elements along Y processed per workitem(in bytes) - * @param[in] src0_offset_first_element_in_bytes The offset of the first element in the source matrix - * @param[in] src1_ptr Pointer to the source matrix. Supported data type: same as @p src0_ptr - * @param[in] src1_stride_x Stride of the source matrix in X dimension (in bytes) - * @param[in] src1_step_x src_stride_x * number of elements along X processed per workitem(in bytes) - * @param[in] src1_stride_y Stride of the source matrix in Y dimension (in bytes) - * @param[in] src1_step_y src_stride_y * number of elements along Y processed per workitem(in bytes) - * @param[in] src1_offset_first_element_in_bytes The offset of the first element in the source matrix - * @param[out] dst_ptr Pointer to the destination matrix Supported data type: S32 - * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes) - * @param[in] dst_step_x dst_gx_stride_x * number of elements along X processed per workitem(in bytes) - * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes) - * @param[in] dst_step_y dst_gx_stride_y * number of elements along Y processed per workitem(in bytes) - * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix - * @param[in] src0_stride_z Stride of the source matrix in Z dimension (in bytes) - * @param[in] src1_stride_z Stride of the source matrix in Z dimension (in bytes) - * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes) - * @param[in] src_cross_plane_pad (Optional) Bottom paddings in unit of elements for the input tensor (only if defined REINTERPRET_INPUT_AS_3D) - * @param[in] dst_cross_plane_pad (Optional) Bottom paddings in unit of elements for the output tensor (only if defined REINTERPRET_OUTPUT_AS_3D) - */ -__kernel void gemmlowp_mm_midgard(IMAGE_DECLARATION(src0), - IMAGE_DECLARATION(src1), - IMAGE_DECLARATION(dst), - uint src0_stride_z, - uint src1_stride_z, - uint dst_stride_z -#if defined(REINTERPRET_INPUT_AS_3D) - , - uint src_cross_plane_pad -#endif // REINTERPRET_INPUT_AS_3D -#if defined(REINTERPRET_OUTPUT_AS_3D) - , - uint dst_cross_plane_pad -#endif // REINTERPRET_OUTPUT_AS_3D - ) -{ - int idx = get_global_id(0) * NUM_ELEMS_PROCESSED_PER_THREAD_X; - - // Compute starting address for matrix A and Matrix B - int2 src_addr = ((int2)(src0_offset_first_element_in_bytes, src1_offset_first_element_in_bytes)); - - // Update address for the matrix A - src_addr.s0 += get_global_id(1) * src0_stride_y * NUM_ELEMS_PROCESSED_PER_THREAD_Y; - - // Update address for the matrix B - src_addr.s1 += idx; - -#if defined(REINTERPRET_INPUT_AS_3D) - // Since we load a 2D input tile from a 3D tensor, we need to check when the plane changes across the z dimension - // in order to take into account the presence of possible cross plane paddings - // - // | | - // | plane0 | - // | | - // |__________________| - // |******************| - // | 
cross_plane_pad | - // |******************| - // | | - // | plane1 | - // | | - // |__________________| - - // The plane (zin) is calculated dividing M (get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y) by HEIGHT_GEMM3D - uint4 zin = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y)) / (uint4)HEIGHT_GEMM3D; - zin = min(DEPTH_GEMM3D - 1, zin); - - // Add offset due to the cross plane paddings - zin *= (src_cross_plane_pad * src0_stride_y); - - // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we - // multiply src0_stride_z by DEPTH_GEMM3D - src_addr.s0 += get_global_id(2) * src0_stride_z * DEPTH_GEMM3D; - -#else // defined(REINTERPRET_INPUT_AS_3D) - - // Add offset for batched GEMM - src_addr.s0 += get_global_id(2) * src0_stride_z; - -#endif // defined(REINTERPRET_INPUT_AS_3D) - -#if defined(MATRIX_B_DEPTH) - // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3 - src_addr.s1 += (get_global_id(2) % MATRIX_B_DEPTH) * src1_stride_z; -#else // defined(MATRIX_B_DEPTH) - src_addr.s1 += get_global_id(2) * src1_stride_z; -#endif // defined(MATRIX_B_DEPTH) - - int end_row_vec_a = src_addr.s0 + COLS_A; - - VECTOR_ACC_TYPE acc0 = 0; -#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 - VECTOR_ACC_TYPE acc1 = 0; -#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 -#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 - VECTOR_ACC_TYPE acc2 = 0; -#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 -#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 - VECTOR_ACC_TYPE acc3 = 0; -#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 -#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 4 - VECTOR_ACC_TYPE acc4 = 0; -#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 4 - - for(; src_addr.s0 <= (end_row_vec_a - 2); src_addr += (int2)(2, 2 * src1_stride_y)) - { - // Load values from matrix A - VEC_DATA_TYPE(DATA_TYPE, 2) - a0 = vload2(0, (__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y)); -#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 - VEC_DATA_TYPE(DATA_TYPE, 2) - a1 = vload2(0, (__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y)); -#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 -#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 - VEC_DATA_TYPE(DATA_TYPE, 2) - a2 = vload2(0, (__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y)); -#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 -#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 - VEC_DATA_TYPE(DATA_TYPE, 2) - a3 = vload2(0, (__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y)); -#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 -#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 4 - VEC_DATA_TYPE(DATA_TYPE, 2) - a4 = vload2(0, (__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 4 * src0_stride_y)); -#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 4 - // Load values from matrix B - VECTOR_TYPE b0 = VLOAD(NUM_ELEMS_PROCESSED_PER_THREAD_X)(0, (__global DATA_TYPE *)(src1_ptr + src_addr.s1)); - VECTOR_TYPE b1 = VLOAD(NUM_ELEMS_PROCESSED_PER_THREAD_X)(0, (__global DATA_TYPE *)(src1_ptr + src_addr.s1 + src1_stride_y)); - - // Accumulate - acc0 += CONVERT(b0, VECTOR_ACC_TYPE) * (VECTOR_ACC_TYPE)a0.s0; - acc0 += CONVERT(b1, VECTOR_ACC_TYPE) * (VECTOR_ACC_TYPE)a0.s1; -#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 - acc1 += CONVERT(b0, VECTOR_ACC_TYPE) * (VECTOR_ACC_TYPE)a1.s0; - acc1 += CONVERT(b1, VECTOR_ACC_TYPE) * (VECTOR_ACC_TYPE)a1.s1; -#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 -#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 - acc2 += CONVERT(b0, VECTOR_ACC_TYPE) * (VECTOR_ACC_TYPE)a2.s0; - acc2 += CONVERT(b1, 
VECTOR_ACC_TYPE) * (VECTOR_ACC_TYPE)a2.s1; -#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 -#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 - acc3 += CONVERT(b0, VECTOR_ACC_TYPE) * (VECTOR_ACC_TYPE)a3.s0; - acc3 += CONVERT(b1, VECTOR_ACC_TYPE) * (VECTOR_ACC_TYPE)a3.s1; -#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 -#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 4 - acc4 += CONVERT(b0, VECTOR_ACC_TYPE) * (VECTOR_ACC_TYPE)a4.s0; - acc4 += CONVERT(b1, VECTOR_ACC_TYPE) * (VECTOR_ACC_TYPE)a4.s1; -#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 4 - } - - for(; src_addr.s0 < end_row_vec_a; src_addr += (int2)(1, src1_stride_y)) - { - // Load values from matrix A - DATA_TYPE a0 = *((__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y)); -#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 - DATA_TYPE a1 = *((__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y)); -#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 -#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 - DATA_TYPE a2 = *((__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y)); -#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 -#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 - DATA_TYPE a3 = *((__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y)); -#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 -#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 4 - DATA_TYPE a4 = *((__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 4 * src0_stride_y)); -#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 4 - // Load values from matrix B - VECTOR_TYPE b0 = VLOAD(NUM_ELEMS_PROCESSED_PER_THREAD_X)(0, (__global DATA_TYPE *)(src1_ptr + src_addr.s1)); - - // Accumulate - acc0 += CONVERT(b0, VECTOR_ACC_TYPE) * (VECTOR_ACC_TYPE)a0; -#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 - acc1 += CONVERT(b0, VECTOR_ACC_TYPE) * (VECTOR_ACC_TYPE)a1; -#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 -#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 - acc2 += CONVERT(b0, VECTOR_ACC_TYPE) * (VECTOR_ACC_TYPE)a2; -#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 -#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 - acc3 += CONVERT(b0, VECTOR_ACC_TYPE) * (VECTOR_ACC_TYPE)a3; -#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 -#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 4 - acc4 += CONVERT(b0, VECTOR_ACC_TYPE) * (VECTOR_ACC_TYPE)a4; -#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 4 - } - - const int z = get_global_id(2); - - // Compute destination address - Image dst = CONVERT_TO_IMAGE_STRUCT(dst); - -#if defined(REINTERPRET_OUTPUT_AS_3D) - // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension - // in order to take into account the presence of possible cross plane paddings - // - // | | - // | plane0 | - // | | - // |__________________| - // |******************| - // | cross_plane_pad | - // |******************| - // | | - // | plane1 | - // | | - // |__________________| - - // The plane (zout) is calculated dividing M (get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y) by HEIGHT_GEMM3D - uint8 zout = ((uint8)(0, 1, 2, 3, 4, 5, 6, 7) + (uint8)(get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y)) / (uint8)HEIGHT_GEMM3D; - zout = min(DEPTH_GEMM3D - 1, zout); - - // Add offset due to the cross plane paddings - zout *= (dst_cross_plane_pad * dst_stride_y); - - // Add offset for batched GEMM. 
The batches will be in the fourth dimension and for this reason we - // multiply dst_stride_z by DEPTH_GEMM3D - dst.ptr += z * dst_stride_z * DEPTH_GEMM3D; - - // Store the result - VSTORE(NUM_ELEMS_PROCESSED_PER_THREAD_X) - (CONVERT(acc0, VECTOR_INT), 0, (__global int *)(dst.ptr + 0 * dst_stride_y + zout.s0)); -#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 - VSTORE(NUM_ELEMS_PROCESSED_PER_THREAD_X) - (CONVERT(acc1, VECTOR_INT), 0, (__global int *)(dst.ptr + 1 * dst_stride_y + zout.s1)); -#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 -#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 - VSTORE(NUM_ELEMS_PROCESSED_PER_THREAD_X) - (CONVERT(acc2, VECTOR_INT), 0, (__global int *)(dst.ptr + 2 * dst_stride_y + zout.s2)); -#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 -#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 - VSTORE(NUM_ELEMS_PROCESSED_PER_THREAD_X) - (CONVERT(acc3, VECTOR_INT), 0, (__global int *)(dst.ptr + 3 * dst_stride_y + zout.s3)); -#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 -#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 4 - VSTORE(NUM_ELEMS_PROCESSED_PER_THREAD_X) - (CONVERT(acc4, VECTOR_INT), 0, (__global int *)(dst.ptr + 4 * dst_stride_y + zout.s4)); -#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 4 - -#else // defined(REINTERPRET_OUTPUT_AS_3D) - // Add offset for batched GEMM - dst.ptr += z * dst_stride_z; - - // Store the result - VSTORE(NUM_ELEMS_PROCESSED_PER_THREAD_X) - (CONVERT(acc0, VECTOR_INT), 0, (__global int *)(dst.ptr + 0 * dst_stride_y)); -#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 - VSTORE(NUM_ELEMS_PROCESSED_PER_THREAD_X) - (CONVERT(acc1, VECTOR_INT), 0, (__global int *)(dst.ptr + 1 * dst_stride_y)); -#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 -#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 - VSTORE(NUM_ELEMS_PROCESSED_PER_THREAD_X) - (CONVERT(acc2, VECTOR_INT), 0, (__global int *)(dst.ptr + 2 * dst_stride_y)); -#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 -#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 - VSTORE(NUM_ELEMS_PROCESSED_PER_THREAD_X) - (CONVERT(acc3, VECTOR_INT), 0, (__global int *)(dst.ptr + 3 * dst_stride_y)); -#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 -#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 4 - VSTORE(NUM_ELEMS_PROCESSED_PER_THREAD_X) - (CONVERT(acc4, VECTOR_INT), 0, (__global int *)(dst.ptr + 4 * dst_stride_y)); -#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 4 -#endif // defined(REINTERPRET_OUTPUT_AS_3D) -} -#endif // defined(NUM_ELEMS_PROCESSED_PER_THREAD_X) && defined(NUM_ELEMS_PROCESSED_PER_THREAD_Y) && defined(COLS_A) - #if defined(M0) && defined(N0) && defined(K0) && defined(V0) && defined(H0) && defined(M) && defined(N) /** This OpenCL kernel computes the matrix multiplication between 2 matrices with QASYMM/QASYMM_SIGNED data type. 
* The LHS matrix must be reshaped with @ref CLGEMMReshapeLHSMatrixKernel and the M0xK0 must be NOT transposed diff --git a/src/core/CL/gemm/native/CLGEMMNativeKernelConfigurationBifrost.cpp b/src/core/CL/gemm/native/CLGEMMNativeKernelConfigurationBifrost.cpp index c4a9ccd703..c6b51c698a 100644 --- a/src/core/CL/gemm/native/CLGEMMNativeKernelConfigurationBifrost.cpp +++ b/src/core/CL/gemm/native/CLGEMMNativeKernelConfigurationBifrost.cpp @@ -42,9 +42,6 @@ CLGEMMNativeKernelConfigurationBifrost::CLGEMMNativeKernelConfigurationBifrost(G std::pair CLGEMMNativeKernelConfigurationBifrost::configure(unsigned int m, unsigned int n, unsigned int k, unsigned int b, DataType data_type) { - ARM_COMPUTE_ERROR_ON(data_type != DataType::F32 && data_type != DataType::QASYMM8); - ARM_COMPUTE_UNUSED(data_type); - using ConfigurationFunctionExecutorPtr = std::pair (CLGEMMNativeKernelConfigurationBifrost::*)(unsigned int m, unsigned int n, unsigned int k, unsigned int b); @@ -52,31 +49,61 @@ std::pair CLGEMMNativeKernelConfigurationB static std::map gemm_configs_G71 = { { DataType::F32, &CLGEMMNativeKernelConfigurationBifrost::configure_G71_f32 }, - { DataType::QASYMM8, &CLGEMMNativeKernelConfigurationBifrost::configure_G71_u8 } + { DataType::QASYMM8, &CLGEMMNativeKernelConfigurationBifrost::configure_G71_u8 }, + { DataType::QSYMM8, &CLGEMMNativeKernelConfigurationBifrost::configure_G71_u8 }, + { DataType::QASYMM8_SIGNED, &CLGEMMNativeKernelConfigurationBifrost::configure_G71_u8 }, + { DataType::QSYMM8_PER_CHANNEL, &CLGEMMNativeKernelConfigurationBifrost::configure_G71_u8 } }; // Configurations for Mali-G76 static std::map gemm_configs_G76 = { { DataType::F32, &CLGEMMNativeKernelConfigurationBifrost::configure_G76_f32 }, - { DataType::QASYMM8, &CLGEMMNativeKernelConfigurationBifrost::configure_G76_u8 } + { DataType::QASYMM8, &CLGEMMNativeKernelConfigurationBifrost::configure_G76_u8 }, + { DataType::QSYMM8, &CLGEMMNativeKernelConfigurationBifrost::configure_G76_u8 }, + { DataType::QASYMM8_SIGNED, &CLGEMMNativeKernelConfigurationBifrost::configure_G76_u8 }, + { DataType::QSYMM8_PER_CHANNEL, &CLGEMMNativeKernelConfigurationBifrost::configure_G76_u8 } }; // Default configurations static std::map gemm_configs_default = { { DataType::F32, &CLGEMMNativeKernelConfigurationBifrost::configure_default_f32 }, - { DataType::QASYMM8, &CLGEMMNativeKernelConfigurationBifrost::configure_default_u8 } + { DataType::QASYMM8, &CLGEMMNativeKernelConfigurationBifrost::configure_default_u8 }, + { DataType::QSYMM8, &CLGEMMNativeKernelConfigurationBifrost::configure_default_u8 }, + { DataType::QASYMM8_SIGNED, &CLGEMMNativeKernelConfigurationBifrost::configure_default_u8 }, + { DataType::QSYMM8_PER_CHANNEL, &CLGEMMNativeKernelConfigurationBifrost::configure_default_u8 } }; switch(_target) { case GPUTarget::G71: - return (this->*gemm_configs_G71[data_type])(m, n, k, b); + if(gemm_configs_G71.find(data_type) != gemm_configs_G71.end()) + { + return (this->*gemm_configs_G71[data_type])(m, n, k, b); + } + else + { + ARM_COMPUTE_ERROR("Not supported data type"); + } case GPUTarget::G76: - return (this->*gemm_configs_G76[data_type])(m, n, k, b); + if(gemm_configs_G76.find(data_type) != gemm_configs_G76.end()) + { + return (this->*gemm_configs_G76[data_type])(m, n, k, b); + } + else + { + ARM_COMPUTE_ERROR("Not supported data type"); + } default: - return (this->*gemm_configs_default[data_type])(m, n, k, b); + if(gemm_configs_default.find(data_type) != gemm_configs_default.end()) + { + return (this->*gemm_configs_default[data_type])(m, n, k, 
b); + } + else + { + ARM_COMPUTE_ERROR("Not supported data type"); + } } } diff --git a/src/core/CL/gemm/native/CLGEMMNativeKernelConfigurationMidgard.cpp b/src/core/CL/gemm/native/CLGEMMNativeKernelConfigurationMidgard.cpp new file mode 100644 index 0000000000..86c056ffc2 --- /dev/null +++ b/src/core/CL/gemm/native/CLGEMMNativeKernelConfigurationMidgard.cpp @@ -0,0 +1,75 @@ +/* + * Copyright (c) 2020 ARM Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ +#include "arm_compute/core/CL/gemm/native/CLGEMMNativeKernelConfigurationMidgard.h" + +#include "arm_compute/core/CL/CLHelpers.h" +#include "arm_compute/core/CL/CLKernelLibrary.h" +#include "arm_compute/core/CL/gemm/CLGEMMHelpers.h" +#include "arm_compute/core/GPUTarget.h" + +#include +#include + +namespace arm_compute +{ +namespace cl_gemm +{ +CLGEMMNativeKernelConfigurationMidgard::CLGEMMNativeKernelConfigurationMidgard(GPUTarget gpu) + : ICLGEMMKernelConfiguration(gpu) +{ +} + +std::pair CLGEMMNativeKernelConfigurationMidgard::configure(unsigned int m, unsigned int n, unsigned int k, unsigned int b, DataType data_type) +{ + using ConfigurationFunctionExecutorPtr = std::pair (CLGEMMNativeKernelConfigurationMidgard::*)(unsigned int m, unsigned int n, unsigned int k, + unsigned int b); + + // Configurations for Midgard architectures + static std::map default_configs = + { + { DataType::QASYMM8, &CLGEMMNativeKernelConfigurationMidgard::default_q8 }, + { DataType::QASYMM8_SIGNED, &CLGEMMNativeKernelConfigurationMidgard::default_q8 }, + { DataType::QSYMM8, &CLGEMMNativeKernelConfigurationMidgard::default_q8 }, + { DataType::QSYMM8_PER_CHANNEL, &CLGEMMNativeKernelConfigurationMidgard::default_q8 } + }; + + if(default_configs.find(data_type) != default_configs.end()) + { + return (this->*default_configs[data_type])(m, n, k, b); + } + ARM_COMPUTE_ERROR("Not supported data type"); +} + +std::pair CLGEMMNativeKernelConfigurationMidgard::default_q8(unsigned int m, unsigned int n, unsigned int k, unsigned int b) +{ + ARM_COMPUTE_UNUSED(k); + ARM_COMPUTE_UNUSED(b); + + const unsigned int m0 = std::min(m, static_cast(4)); + const unsigned int n0 = std::min(n, static_cast(4)); + + return configure_lhs_rhs_info(m, n, m0, n0, 2, 1, 1, false, false, false, false); +} +} // namespace cl_gemm +} // namespace arm_compute \ No newline at end of file diff --git a/src/core/CL/gemm/native/CLGEMMNativeKernelConfigurationValhall.cpp 
b/src/core/CL/gemm/native/CLGEMMNativeKernelConfigurationValhall.cpp index 7cf0f0e1a8..c25cdac81a 100644 --- a/src/core/CL/gemm/native/CLGEMMNativeKernelConfigurationValhall.cpp +++ b/src/core/CL/gemm/native/CLGEMMNativeKernelConfigurationValhall.cpp @@ -45,12 +45,15 @@ std::pair CLGEMMNativeKernelConfigurationV using ConfigurationFunctionExecutorPtr = std::pair (CLGEMMNativeKernelConfigurationValhall::*)(unsigned int m, unsigned int n, unsigned int k, unsigned int b); - // Configurations for Mali-G71 + // Configurations for Mali-G77 static std::map gemm_configs_G77 = { { DataType::F32, &CLGEMMNativeKernelConfigurationValhall::configure_G77_f32 }, { DataType::F16, &CLGEMMNativeKernelConfigurationValhall::configure_G77_f16 }, - { DataType::QASYMM8, &CLGEMMNativeKernelConfigurationValhall::configure_G77_u8 } + { DataType::QASYMM8, &CLGEMMNativeKernelConfigurationValhall::configure_G77_u8 }, + { DataType::QSYMM8, &CLGEMMNativeKernelConfigurationValhall::configure_G77_u8 }, + { DataType::QASYMM8_SIGNED, &CLGEMMNativeKernelConfigurationValhall::configure_G77_u8 }, + { DataType::QSYMM8_PER_CHANNEL, &CLGEMMNativeKernelConfigurationValhall::configure_G77_u8 } }; switch(_target) diff --git a/src/core/CL/gemm/reshaped/CLGEMMReshapedKernelConfigurationBifrost.cpp b/src/core/CL/gemm/reshaped/CLGEMMReshapedKernelConfigurationBifrost.cpp index 144c23a798..990cc72eb0 100644 --- a/src/core/CL/gemm/reshaped/CLGEMMReshapedKernelConfigurationBifrost.cpp +++ b/src/core/CL/gemm/reshaped/CLGEMMReshapedKernelConfigurationBifrost.cpp @@ -49,7 +49,10 @@ std::pair CLGEMMReshapedKernelConfiguratio { { DataType::F32, &CLGEMMReshapedKernelConfigurationBifrost::configure_G76_f32 }, { DataType::F16, &CLGEMMReshapedKernelConfigurationBifrost::configure_G76_f16 }, - { DataType::QASYMM8, &CLGEMMReshapedKernelConfigurationBifrost::configure_G76_u8 } + { DataType::QASYMM8, &CLGEMMReshapedKernelConfigurationBifrost::configure_G76_u8 }, + { DataType::QSYMM8, &CLGEMMReshapedKernelConfigurationBifrost::configure_G76_u8 }, + { DataType::QASYMM8_SIGNED, &CLGEMMReshapedKernelConfigurationBifrost::configure_G76_u8 }, + { DataType::QSYMM8_PER_CHANNEL, &CLGEMMReshapedKernelConfigurationBifrost::configure_G76_u8 } }; // Configurations for Mali-G7x @@ -57,7 +60,10 @@ std::pair CLGEMMReshapedKernelConfiguratio { { DataType::F32, &CLGEMMReshapedKernelConfigurationBifrost::configure_G7x_f32 }, { DataType::F16, &CLGEMMReshapedKernelConfigurationBifrost::configure_G7x_f16 }, - { DataType::QASYMM8, &CLGEMMReshapedKernelConfigurationBifrost::configure_G7x_u8 } + { DataType::QASYMM8, &CLGEMMReshapedKernelConfigurationBifrost::configure_G7x_u8 }, + { DataType::QSYMM8, &CLGEMMReshapedKernelConfigurationBifrost::configure_G7x_u8 }, + { DataType::QASYMM8_SIGNED, &CLGEMMReshapedKernelConfigurationBifrost::configure_G7x_u8 }, + { DataType::QSYMM8_PER_CHANNEL, &CLGEMMReshapedKernelConfigurationBifrost::configure_G7x_u8 } }; switch(_target) diff --git a/src/core/CL/gemm/reshaped/CLGEMMReshapedKernelConfigurationValhall.cpp b/src/core/CL/gemm/reshaped/CLGEMMReshapedKernelConfigurationValhall.cpp index 20fa3d65bf..b96dc96e87 100644 --- a/src/core/CL/gemm/reshaped/CLGEMMReshapedKernelConfigurationValhall.cpp +++ b/src/core/CL/gemm/reshaped/CLGEMMReshapedKernelConfigurationValhall.cpp @@ -49,7 +49,10 @@ std::pair CLGEMMReshapedKernelConfiguratio { { DataType::F32, &CLGEMMReshapedKernelConfigurationValhall::configure_G77_f32 }, { DataType::F16, &CLGEMMReshapedKernelConfigurationValhall::configure_G77_f16 }, - { DataType::QASYMM8, 
&CLGEMMReshapedKernelConfigurationValhall::configure_G77_u8 } + { DataType::QASYMM8, &CLGEMMReshapedKernelConfigurationValhall::configure_G77_u8 }, + { DataType::QSYMM8, &CLGEMMReshapedKernelConfigurationValhall::configure_G77_u8 }, + { DataType::QASYMM8_SIGNED, &CLGEMMReshapedKernelConfigurationValhall::configure_G77_u8 }, + { DataType::QSYMM8_PER_CHANNEL, &CLGEMMReshapedKernelConfigurationValhall::configure_G77_u8 } }; switch(_target) diff --git a/src/core/CL/gemm/reshaped_only_rhs/CLGEMMReshapedOnlyRHSKernelConfigurationBifrost.cpp b/src/core/CL/gemm/reshaped_only_rhs/CLGEMMReshapedOnlyRHSKernelConfigurationBifrost.cpp index 8e798116bf..8826cca11b 100644 --- a/src/core/CL/gemm/reshaped_only_rhs/CLGEMMReshapedOnlyRHSKernelConfigurationBifrost.cpp +++ b/src/core/CL/gemm/reshaped_only_rhs/CLGEMMReshapedOnlyRHSKernelConfigurationBifrost.cpp @@ -50,7 +50,10 @@ std::pair CLGEMMReshapedOnlyRHSKernelConfi { { DataType::F32, &CLGEMMReshapedOnlyRHSKernelConfigurationBifrost::configure_G51_f32 }, { DataType::F16, &CLGEMMReshapedOnlyRHSKernelConfigurationBifrost::configure_G51_f16 }, - { DataType::QASYMM8, &CLGEMMReshapedOnlyRHSKernelConfigurationBifrost::configure_G51_u8 } + { DataType::QASYMM8, &CLGEMMReshapedOnlyRHSKernelConfigurationBifrost::configure_G51_u8 }, + { DataType::QSYMM8, &CLGEMMReshapedOnlyRHSKernelConfigurationBifrost::configure_G51_u8 }, + { DataType::QASYMM8_SIGNED, &CLGEMMReshapedOnlyRHSKernelConfigurationBifrost::configure_G51_u8 }, + { DataType::QSYMM8_PER_CHANNEL, &CLGEMMReshapedOnlyRHSKernelConfigurationBifrost::configure_G51_u8 } }; // Configurations for Mali-G76 @@ -58,7 +61,10 @@ std::pair CLGEMMReshapedOnlyRHSKernelConfi { { DataType::F32, &CLGEMMReshapedOnlyRHSKernelConfigurationBifrost::configure_G76_f32 }, { DataType::F16, &CLGEMMReshapedOnlyRHSKernelConfigurationBifrost::configure_G76_f16 }, - { DataType::QASYMM8, &CLGEMMReshapedOnlyRHSKernelConfigurationBifrost::configure_G76_u8 } + { DataType::QASYMM8, &CLGEMMReshapedOnlyRHSKernelConfigurationBifrost::configure_G76_u8 }, + { DataType::QSYMM8, &CLGEMMReshapedOnlyRHSKernelConfigurationBifrost::configure_G76_u8 }, + { DataType::QASYMM8_SIGNED, &CLGEMMReshapedOnlyRHSKernelConfigurationBifrost::configure_G76_u8 }, + { DataType::QSYMM8_PER_CHANNEL, &CLGEMMReshapedOnlyRHSKernelConfigurationBifrost::configure_G76_u8 } }; // Configurations for Mali-G7x @@ -66,7 +72,10 @@ std::pair CLGEMMReshapedOnlyRHSKernelConfi { { DataType::F32, &CLGEMMReshapedOnlyRHSKernelConfigurationBifrost::configure_G7x_f32 }, { DataType::F16, &CLGEMMReshapedOnlyRHSKernelConfigurationBifrost::configure_G7x_f16 }, - { DataType::QASYMM8, &CLGEMMReshapedOnlyRHSKernelConfigurationBifrost::configure_G7x_u8 } + { DataType::QASYMM8, &CLGEMMReshapedOnlyRHSKernelConfigurationBifrost::configure_G7x_u8 }, + { DataType::QSYMM8, &CLGEMMReshapedOnlyRHSKernelConfigurationBifrost::configure_G7x_u8 }, + { DataType::QASYMM8_SIGNED, &CLGEMMReshapedOnlyRHSKernelConfigurationBifrost::configure_G7x_u8 }, + { DataType::QSYMM8_PER_CHANNEL, &CLGEMMReshapedOnlyRHSKernelConfigurationBifrost::configure_G7x_u8 } }; switch(_target) @@ -235,15 +244,14 @@ std::pair CLGEMMReshapedOnlyRHSKernelConfi } else { + const int h0 = std::max(std::min(static_cast(n / 2), static_cast(128)), static_cast(1)); if(m == 1) { - const unsigned int h0 = std::max(n / 2, 1U); return configure_lhs_rhs_info(m, n, 1, 2, 4, 1, h0, false, true, false, true); } else { - const unsigned int h0 = std::max(n / 4, 1U); - return configure_lhs_rhs_info(m, n, 2, 2, 16, 1, h0, false, true, false, true); + return 
configure_lhs_rhs_info(m, n, 4, 2, 16, 1, h0, false, true, false, true); } } } diff --git a/src/core/CL/gemm/reshaped_only_rhs/CLGEMMReshapedOnlyRHSKernelConfigurationValhall.cpp b/src/core/CL/gemm/reshaped_only_rhs/CLGEMMReshapedOnlyRHSKernelConfigurationValhall.cpp index 951447e1a0..783d0fe91b 100644 --- a/src/core/CL/gemm/reshaped_only_rhs/CLGEMMReshapedOnlyRHSKernelConfigurationValhall.cpp +++ b/src/core/CL/gemm/reshaped_only_rhs/CLGEMMReshapedOnlyRHSKernelConfigurationValhall.cpp @@ -50,7 +50,10 @@ std::pair CLGEMMReshapedOnlyRHSKernelConfi { { DataType::F32, &CLGEMMReshapedOnlyRHSKernelConfigurationValhall::configure_G77_f32 }, { DataType::F16, &CLGEMMReshapedOnlyRHSKernelConfigurationValhall::configure_G77_f16 }, - { DataType::QASYMM8, &CLGEMMReshapedOnlyRHSKernelConfigurationValhall::configure_G77_u8 } + { DataType::QASYMM8, &CLGEMMReshapedOnlyRHSKernelConfigurationValhall::configure_G77_u8 }, + { DataType::QSYMM8, &CLGEMMReshapedOnlyRHSKernelConfigurationValhall::configure_G77_u8 }, + { DataType::QASYMM8_SIGNED, &CLGEMMReshapedOnlyRHSKernelConfigurationValhall::configure_G77_u8 }, + { DataType::QSYMM8_PER_CHANNEL, &CLGEMMReshapedOnlyRHSKernelConfigurationValhall::configure_G77_u8 } }; switch(_target) @@ -135,7 +138,15 @@ std::pair CLGEMMReshapedOnlyRHSKernelConfi } else { - return configure_lhs_rhs_info(m, n, 4, 4, 16, 1, 4, false, true, false, true); + const int h0 = std::max(std::min(static_cast(n / 4), static_cast(256)), static_cast(1)); + if(m >= 28) + { + return configure_lhs_rhs_info(m, n, 4, 4, 16, 1, h0, false, true, false, true); + } + else + { + return configure_lhs_rhs_info(m, n, 2, 4, 16, 1, h0, false, true, false, true); + } } } } // namespace cl_gemm diff --git a/src/core/CL/kernels/CLGEMMLowpMatrixMultiplyKernel.cpp b/src/core/CL/kernels/CLGEMMLowpMatrixMultiplyKernel.cpp deleted file mode 100644 index 0d4bbba0d4..0000000000 --- a/src/core/CL/kernels/CLGEMMLowpMatrixMultiplyKernel.cpp +++ /dev/null @@ -1,307 +0,0 @@ -/* - * Copyright (c) 2017-2020 ARM Limited. - * - * SPDX-License-Identifier: MIT - * - * Permission is hereby granted, free of charge, to any person obtaining a copy - * of this software and associated documentation files (the "Software"), to - * deal in the Software without restriction, including without limitation the - * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or - * sell copies of the Software, and to permit persons to whom the Software is - * furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in all - * copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE - * SOFTWARE. 
- */ -#include "arm_compute/core/CL/kernels/CLGEMMLowpMatrixMultiplyKernel.h" - -#include "arm_compute/core/AccessWindowStatic.h" -#include "arm_compute/core/AccessWindowTranspose.h" -#include "arm_compute/core/CL/CLHelpers.h" -#include "arm_compute/core/CL/CLKernelLibrary.h" -#include "arm_compute/core/CL/ICLTensor.h" -#include "arm_compute/core/CL/OpenCL.h" -#include "arm_compute/core/Error.h" -#include "arm_compute/core/Helpers.h" -#include "arm_compute/core/TensorInfo.h" -#include "arm_compute/core/Types.h" -#include "arm_compute/core/Utils.h" -#include "arm_compute/core/Validate.h" -#include "arm_compute/core/Window.h" -#include "arm_compute/core/utils/misc/ShapeCalculator.h" -#include "support/StringSupport.h" - -#include -#include -#include - -using namespace arm_compute; -using namespace arm_compute::misc::shape_calculator; - -namespace arm_compute -{ -class Coordinates; -} // namespace arm_compute - -namespace -{ -using ElementsProcessed = Steps; - -Status validate_arguments(const ITensorInfo *input0, const ITensorInfo *input1, const ITensorInfo *output, const GEMMReshapeInfo &gemm_info) -{ - ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input0, input1, output); - ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input0, 1, DataType::QASYMM8, DataType::QASYMM8_SIGNED); - ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input0, input1); - ARM_COMPUTE_RETURN_ERROR_ON_MSG(input0->num_dimensions() > 4, "The number of dimensions for the matrix A must be <= 4"); - ARM_COMPUTE_RETURN_ERROR_ON_MSG(input1->num_dimensions() > 3, "The number of dimensions for the matrix B must be <= 3"); - ARM_COMPUTE_RETURN_ERROR_ON_MSG(input1->num_dimensions() > 2 && gemm_info.reinterpret_input_as_3d(), "The input1 tensor cannot have more than 2 dimensions if input0 has to be reinterpreted as 3D"); - - const int m = gemm_info.m(); - const int n = gemm_info.n(); - const int k = gemm_info.k(); - - ARM_COMPUTE_UNUSED(m); - ARM_COMPUTE_UNUSED(n); - ARM_COMPUTE_UNUSED(k); - - ARM_COMPUTE_RETURN_ERROR_ON(input0->dimension(0) != static_cast(k)); - ARM_COMPUTE_RETURN_ERROR_ON(input1->dimension(0) != static_cast(n)); - ARM_COMPUTE_RETURN_ERROR_ON(input1->dimension(1) != static_cast(k)); - if(gemm_info.reinterpret_input_as_3d()) - { - ARM_COMPUTE_RETURN_ERROR_ON(input0->dimension(1) * input0->dimension(2) != static_cast(m)); - } - else - { - ARM_COMPUTE_RETURN_ERROR_ON(input0->dimension(1) != static_cast(m)); - } - - if(output->total_size() != 0) - { - const TensorInfo tensor_info_output = output->clone()->set_tensor_shape(compute_mm_shape(*input0, *input1, false, gemm_info)); - ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(output, &tensor_info_output); - ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::S32); - } - - return Status{}; -} - -std::pair validate_and_configure_window(ITensorInfo *input0, ITensorInfo *input1, ITensorInfo *output, const GEMMReshapeInfo &gemm_info, ElementsProcessed &num_elements_processed) -{ - unsigned int &num_elems_processed_per_iteration_x = num_elements_processed[0]; - unsigned int &num_elems_processed_per_iteration_y = num_elements_processed[1]; - bool reinterpret_input_as_3d = gemm_info.reinterpret_input_as_3d(); - bool reinterpret_output_as_3d = (gemm_info.depth_output_gemm3d() != 0); - - Window win{}; - Window win_out{}; - bool window_changed = false; - - // In case both input and output have to be reinterpreted as 3D tensors, - // force reinterpret_input_as_3d and reinterpret_output_as_3d to be false. 
- if(reinterpret_input_as_3d == reinterpret_output_as_3d) - { - reinterpret_input_as_3d = false; - reinterpret_output_as_3d = false; - } - - // Output tensor auto inizialitation if not yet initialized - auto_init_if_empty(*output, input0->clone()->set_tensor_shape(compute_mm_shape(*input0, *input1, false, gemm_info)).set_data_type(DataType::S32)); - - TensorInfo tmp_info(*output); - - if(reinterpret_output_as_3d) - { - // Since the output tensor has to be reinterpreted as 3D and the execute window is based on a 2D GEMM, - // the window needs to be constructed on the 2D collapsed version of the tensor - TensorShape tmp_shape(output->tensor_shape()); - tmp_shape.collapse(2U, 1U); - tmp_info.set_tensor_shape(tmp_shape); - } - - // Special case for 1xN, 2xN, 3xN and 4xN input0 tensor. num_elems_processed_per_iteration_x - // Note: if the dot product instruction is available, the 8x2 tile has to be used - num_elems_processed_per_iteration_x = 4; - num_elems_processed_per_iteration_y = std::min(static_cast(output->dimension(1)), 4); - - // Note: bottom paddings are calculated manually as the output can be reinterpreted as 3D tensor - // The only way to set properly the paddings, it is to set those explicitly through the AccessWindowStatic - const int m = reinterpret_input_as_3d ? input0->tensor_shape()[1] * input0->tensor_shape()[2] : input0->tensor_shape()[1]; - const int bottom_pad = (num_elems_processed_per_iteration_y - (m % num_elems_processed_per_iteration_y)) % num_elems_processed_per_iteration_y; - - // Configure window - win = calculate_max_window(tmp_info, Steps(num_elems_processed_per_iteration_x, num_elems_processed_per_iteration_y)); - win_out = calculate_max_window(*output, Steps(num_elems_processed_per_iteration_x, num_elems_processed_per_iteration_y)); - - AccessWindowStatic input0_access(input0, 0, 0, input0->dimension(0), input0->dimension(1) + bottom_pad); - AccessWindowStatic input1_access(input1, 0, 0, ceil_to_multiple(input1->dimension(0), num_elems_processed_per_iteration_x), input1->dimension(1)); - AccessWindowStatic output_access(output, 0, 0, - ceil_to_multiple(output->dimension(0), num_elems_processed_per_iteration_x), - output->dimension(1) + bottom_pad); - - window_changed = update_window_and_padding(win, input0_access, input1_access) || // window used by the execute_window_loop - update_window_and_padding(win_out, output_access); // window used to update the padding requirements of output tensor - - Coordinates coord; - coord.set_num_dimensions(output->num_dimensions()); - output_access.set_valid_region(win_out, ValidRegion(coord, output->tensor_shape())); - - // Collapse along the Z direction - // This collapse needs to be here in order to tune the Z dimension of LWS - Window collapsed = win; - const unsigned int dimension_to_collapse = std::min(static_cast(output->num_dimensions()), 2u); - collapsed = win.collapse(win, dimension_to_collapse); - - Status err = (window_changed) ? 
ARM_COMPUTE_CREATE_ERROR(ErrorCode::RUNTIME_ERROR, "Insufficient Padding!") : Status{}; - return std::make_pair(err, collapsed); -} -} // namespace - -CLGEMMLowpMatrixMultiplyKernel::CLGEMMLowpMatrixMultiplyKernel() - : _input0(nullptr), _input1(nullptr), _output(nullptr), _slide_matrix_b(true), _reinterpret_input_as_3d(false), _reinterpret_output_as_3d(false) -{ -} - -void CLGEMMLowpMatrixMultiplyKernel::configure(const ICLTensor *input0, const ICLTensor *input1, ICLTensor *output, const GEMMReshapeInfo &gemm_info) -{ - configure(CLKernelLibrary::get().get_compile_context(), input0, input1, output, gemm_info); -} - -void CLGEMMLowpMatrixMultiplyKernel::configure(CLCompileContext &compile_context, const ICLTensor *input0, const ICLTensor *input1, ICLTensor *output, const GEMMReshapeInfo &gemm_info) -{ - ARM_COMPUTE_ERROR_ON_NULLPTR(input0, input1, output); - - ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(input0->info(), input1->info(), output->info(), gemm_info)); - - _input0 = input0; - _input1 = input1; - _output = output; - _reinterpret_input_as_3d = gemm_info.reinterpret_input_as_3d(); - _reinterpret_output_as_3d = (gemm_info.depth_output_gemm3d() != 0); - - // In case both input and output have to be reinterpreted as 3D tensors, - // force reinterpret_input_as_3d and reinterpret_output_as_3d to be false. - if(_reinterpret_input_as_3d == _reinterpret_output_as_3d) - { - _reinterpret_input_as_3d = false; - _reinterpret_output_as_3d = false; - } - - // Check if we need to slide the matrix B - const unsigned int num_dimensions_input0 = _reinterpret_input_as_3d ? _input0->info()->num_dimensions() - 1 : _input0->info()->num_dimensions(); - _slide_matrix_b = (_input1->info()->num_dimensions() >= num_dimensions_input0); - - ElementsProcessed num_elements_processed{}; - - // Configure kernel window - auto win_config = validate_and_configure_window(input0->info(), input1->info(), output->info(), gemm_info, num_elements_processed); - ARM_COMPUTE_ERROR_THROW_ON(win_config.first); - ICLKernel::configure_internal(win_config.second); - - // Create build options - std::string kernel_name(" "); - CLBuildOptions build_opts; - build_opts.add_option_if(_reinterpret_input_as_3d, "-DREINTERPRET_INPUT_AS_3D"); - build_opts.add_option_if(_reinterpret_output_as_3d, "-DREINTERPRET_OUTPUT_AS_3D"); - build_opts.add_option_if(_reinterpret_input_as_3d || _reinterpret_output_as_3d, "-DHEIGHT_GEMM3D=" + support::cpp11::to_string(output->info()->dimension(1))); - build_opts.add_option_if(_reinterpret_input_as_3d || _reinterpret_output_as_3d, "-DDEPTH_GEMM3D=" + support::cpp11::to_string(output->info()->dimension(2))); - build_opts.add_option_if(!_slide_matrix_b, "-DMATRIX_B_DEPTH=" + support::cpp11::to_string(input1->info()->dimension(2))); - build_opts.add_option("-DCOLS_A=" + support::cpp11::to_string(input0->info()->dimension(0))); - build_opts.add_option("-DNUM_ELEMS_PROCESSED_PER_THREAD_X=" + support::cpp11::to_string(num_elements_processed.x())); - build_opts.add_option("-DNUM_ELEMS_PROCESSED_PER_THREAD_Y=" + support::cpp11::to_string(num_elements_processed.y())); - build_opts.add_option("-DDATA_TYPE=" + get_cl_type_from_data_type(input0->info()->data_type())); - build_opts.add_option("-DACC_DATA_TYPE=" + get_cl_dot8_acc_type_from_data_type(input0->info()->data_type())); - - kernel_name = "gemmlowp_mm_midgard"; - - // Create kernel - _kernel = create_kernel(compile_context, kernel_name, build_opts.options()); - - // Set config_id for enabling LWS tuning - _config_id = kernel_name; - _config_id += "_"; - 
_config_id += (_reinterpret_input_as_3d ? "3di_" : ""); - _config_id += (_reinterpret_output_as_3d ? "3do_" : ""); - _config_id += lower_string(string_from_data_type(input0->info()->data_type())); - _config_id += "_"; - _config_id += support::cpp11::to_string(output->info()->dimension(1)); - _config_id += "_"; - _config_id += support::cpp11::to_string(output->info()->dimension(0)); -} - -Status CLGEMMLowpMatrixMultiplyKernel::validate(const ITensorInfo *input0, const ITensorInfo *input1, const ITensorInfo *output, const GEMMReshapeInfo &gemm_info) -{ - ElementsProcessed num_elements_processed{}; - ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(input0, input1, output, gemm_info)); - ARM_COMPUTE_RETURN_ON_ERROR(validate_and_configure_window(input0->clone().get(), - input1->clone().get(), - output->clone().get(), - gemm_info, - num_elements_processed) - .first); - - return Status{}; -} - -void CLGEMMLowpMatrixMultiplyKernel::run(const Window &window, cl::CommandQueue &queue) -{ - ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this); - ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(ICLKernel::window(), window); - - if(_input1->info()->num_dimensions() < 3) - { - // The stride_z for matrix B must be zero if we do not slice - ARM_COMPUTE_ERROR_ON(_input1->info()->strides_in_bytes()[3] != 0); - } - - Window slice = window.first_slice_window_3D(); - Window slice_matrix_b = slice; - - slice_matrix_b.set(Window::DimX, Window::Dimension(0, 1, 1)); - slice_matrix_b.set(Window::DimY, Window::Dimension(0, 1, 1)); - - if(_reinterpret_input_as_3d) - { - // Pass bottom paddings to the kernel if the input has to be reinterpreted as 3D tensor - const unsigned int idx0 = 3 * num_arguments_per_2D_tensor() + 3; - const unsigned int total_cross_plane_pad = _input0->info()->padding().top + _input0->info()->padding().bottom; - _kernel.setArg(idx0, static_cast(total_cross_plane_pad)); - } - - if(_reinterpret_output_as_3d) - { - // Pass bottom paddings to the kernel if the output has to be reinterpreted as 3D tensor - const unsigned int idx0 = 3 * num_arguments_per_2D_tensor() + 3 + (_reinterpret_input_as_3d ? 
1 : 0); - const unsigned int total_cross_plane_pad = _output->info()->padding().top + _output->info()->padding().bottom; - _kernel.setArg(idx0, static_cast(total_cross_plane_pad)); - } - - do - { - Window slice_b = slice; - // Don't slice matrix B along the z dimension if matrix B has just 2 dimensions and matrix A more than 2 - // This scenario can happen when the matrix multiplication is used to perform a convolution operation - if(!_slide_matrix_b) - { - slice_b = slice_matrix_b; - } - - unsigned int idx = 0; - add_2D_tensor_argument(idx, _input0, slice); - add_2D_tensor_argument(idx, _input1, slice_b); - add_2D_tensor_argument(idx, _output, slice); - _kernel.setArg(idx++, static_cast(_input0->info()->strides_in_bytes()[2])); - _kernel.setArg(idx++, static_cast(_input1->info()->strides_in_bytes()[2])); - _kernel.setArg(idx++, static_cast(_output->info()->strides_in_bytes()[2])); - enqueue(queue, *this, slice, lws_hint()); - } - while(window.slide_window_slice_3D(slice)); -} diff --git a/src/runtime/CL/functions/CLGEMMLowpMatrixMultiplyCore.cpp b/src/runtime/CL/functions/CLGEMMLowpMatrixMultiplyCore.cpp index 90e5698fd8..ef17f110d0 100644 --- a/src/runtime/CL/functions/CLGEMMLowpMatrixMultiplyCore.cpp +++ b/src/runtime/CL/functions/CLGEMMLowpMatrixMultiplyCore.cpp @@ -35,6 +35,7 @@ #include "arm_compute/core/utils/misc/ShapeCalculator.h" #include "arm_compute/core/utils/quantization/AsymmHelpers.h" #include "arm_compute/runtime/CL/CLScheduler.h" +#include "arm_compute/runtime/CL/gemm/CLGEMMKernelSelection.h" namespace arm_compute { @@ -43,16 +44,33 @@ using namespace arm_compute::cl_gemm; namespace { -inline bool is_gemm_reshaped(bool reshape_b_only_on_first_run, GPUTarget gpu_target) +inline bool is_gemm_reshaped(unsigned int m, unsigned int n, unsigned int k, DataType data_type, bool reshape_b_only_on_first_run) { - return (get_arch_from_target(gpu_target) != GPUTarget::MIDGARD) && (reshape_b_only_on_first_run); + std::unique_ptr gemm_kernel = CLGEMMKernelSelectionFactory::create(CLScheduler::get().target()); + ARM_COMPUTE_ERROR_ON_NULLPTR(gemm_kernel.get()); + + CLGEMMKernelSelectionParams params; + params.m = m; + params.n = n; + params.k = k; + params.is_rhs_constant = reshape_b_only_on_first_run; + params.data_type = data_type; + + switch(gemm_kernel->select_kernel(params)) + { + case CLGEMMKernelType::NATIVE: + return false; + case CLGEMMKernelType::RESHAPED_ONLY_RHS: + return true; + default: + ARM_COMPUTE_ERROR("Not supported gemmlowp kernel!"); + } } } // namespace CLGEMMLowpMatrixMultiplyCore::CLGEMMLowpMatrixMultiplyCore(std::shared_ptr memory_manager) : _memory_group(std::move(memory_manager)), _weights_to_qasymm8(), - _mm_midgard_kernel(), _mm_native_kernel(), _mm_reshaped_only_rhs_kernel(), _mtx_b_reshape_kernel(), @@ -73,7 +91,6 @@ CLGEMMLowpMatrixMultiplyCore::CLGEMMLowpMatrixMultiplyCore(std::shared_ptrinfo()->data_type(), _reshape_b_only_on_first_run); if(_convert_to_qasymm8) { @@ -220,19 +235,12 @@ void CLGEMMLowpMatrixMultiplyCore::configure(const ICLTensor *a, const ICLTensor } else { - if(_is_midgard) - { - // Configure matrix multiply kernel - _mm_midgard_kernel.configure(_matrix_a, matrix_b, &_mm_result_s32, GEMMReshapeInfo(m, n, k, 1, 1, depth_output_gemm3d, reinterpret_input_as_3d)); - } - else - { - // Pick up the GEMM configuration - std::tie(lhs_info, rhs_info) = CLGEMMNativeKernelConfigurationFactory::create(gpu_target)->configure(m, n, k, batch_size, DataType::QASYMM8); - - // Configure matrix multiply kernel - _mm_native_kernel.configure(_matrix_a, 
matrix_b, &_mm_result_s32, lhs_info, rhs_info, GEMMReshapeInfo(m, n, k, 1, 1, depth_output_gemm3d, reinterpret_input_as_3d)); - } + // Pick up the GEMM configuration + std::tie(lhs_info, rhs_info) = CLGEMMNativeKernelConfigurationFactory::create(gpu_target)->configure(m, n, k, batch_size, DataType::QASYMM8); + + // Configure matrix multiply kernel + _mm_native_kernel.configure(_matrix_a, matrix_b, &_mm_result_s32, lhs_info, rhs_info, GEMMReshapeInfo(m, n, k, 1, 1, depth_output_gemm3d, reinterpret_input_as_3d)); + _offset_contribution_output_stage_kernel.configure(&_mm_result_s32, _a_offset == 0 ? nullptr : &_vector_sum_col, _b_offset == 0 ? nullptr : &_vector_sum_row, c, output, a->info()->dimension(0), _a_offset, _b_offset, gemmlowp_output_stage, &_gemm_output_stage_multipliers, &_gemm_output_stage_shifts); @@ -260,19 +268,11 @@ void CLGEMMLowpMatrixMultiplyCore::configure(const ICLTensor *a, const ICLTensor } else { - if(_is_midgard) - { - // Configure matrix multiply kernel - _mm_midgard_kernel.configure(_matrix_a, matrix_b, output, GEMMReshapeInfo(m, n, k, 1, 1, depth_output_gemm3d, reinterpret_input_as_3d)); - } - else - { - // Pick up the GEMM configuration - std::tie(lhs_info, rhs_info) = CLGEMMNativeKernelConfigurationFactory::create(gpu_target)->configure(m, n, k, batch_size, DataType::QASYMM8); + // Pick up the GEMM configuration + std::tie(lhs_info, rhs_info) = CLGEMMNativeKernelConfigurationFactory::create(gpu_target)->configure(m, n, k, batch_size, DataType::QASYMM8); - // Configure matrix multiply kernel - _mm_native_kernel.configure(_matrix_a, matrix_b, output, lhs_info, rhs_info, GEMMReshapeInfo(m, n, k, 1, 1, depth_output_gemm3d, reinterpret_input_as_3d)); - } + // Configure matrix multiply kernel + _mm_native_kernel.configure(_matrix_a, matrix_b, output, lhs_info, rhs_info, GEMMReshapeInfo(m, n, k, 1, 1, depth_output_gemm3d, reinterpret_input_as_3d)); } // Configure offset contribution kernel @@ -329,9 +329,8 @@ Status CLGEMMLowpMatrixMultiplyCore::validate(const ITensorInfo *a, const ITenso const unsigned int k = a->dimension(0); const unsigned int batch_size = reinterpret_input_as_3d ? 
a->dimension(3) : a->dimension(2); const int depth_output_gemm3d = gemm_info.depth_output_gemm3d(); - const bool is_midgard = gpu_target == GPUTarget::MIDGARD; - bool reshape_matrix_b = is_gemm_reshaped(gemm_info.reshape_b_only_on_first_run(), CLScheduler::get().target()); + bool reshape_matrix_b = is_gemm_reshaped(m, n, k, a->data_type(), gemm_info.reshape_b_only_on_first_run()); const GEMMReshapeInfo reshape_info = GEMMReshapeInfo(m, n, k, 1, 1, depth_output_gemm3d, reinterpret_input_as_3d); @@ -425,19 +424,11 @@ Status CLGEMMLowpMatrixMultiplyCore::validate(const ITensorInfo *a, const ITenso // Output tensor auto inizialitation if not yet initialized auto_init_if_empty(mm_result_s32_info, a->clone()->set_tensor_shape(compute_mm_shape(*matrix_a_info, *matrix_b_info, false, reshape_info)).set_data_type(DataType::S32)); - if(is_midgard) - { - // Validate matrix multiply - ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMLowpMatrixMultiplyKernel::validate(matrix_a_info, matrix_b_info, &mm_result_s32_info, reshape_info)); - } - else - { - // Pick up the GEMM configuration - std::tie(lhs_info, rhs_info) = CLGEMMNativeKernelConfigurationFactory::create(gpu_target)->configure(m, n, k, batch_size, DataType::QASYMM8); - - // Validate matrix multiply - ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMLowpMatrixMultiplyNativeKernel::validate(matrix_a_info, matrix_b_info, &mm_result_s32_info, lhs_info, rhs_info, reshape_info)); - } + // Pick up the GEMM configuration + std::tie(lhs_info, rhs_info) = CLGEMMNativeKernelConfigurationFactory::create(gpu_target)->configure(m, n, k, batch_size, DataType::QASYMM8); + + // Validate matrix multiply + ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMLowpMatrixMultiplyNativeKernel::validate(matrix_a_info, matrix_b_info, &mm_result_s32_info, lhs_info, rhs_info, reshape_info)); } // Validate offset contribution kernel @@ -461,19 +452,11 @@ Status CLGEMMLowpMatrixMultiplyCore::validate(const ITensorInfo *a, const ITenso } else { - if(is_midgard) - { - // Validate matrix multiply - ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMLowpMatrixMultiplyKernel::validate(matrix_a_info, matrix_b_info, output, reshape_info)); - } - else - { - // Pick up the GEMM configuration - std::tie(lhs_info, rhs_info) = CLGEMMNativeKernelConfigurationFactory::create(gpu_target)->configure(m, n, k, batch_size, DataType::QASYMM8); + // Pick up the GEMM configuration + std::tie(lhs_info, rhs_info) = CLGEMMNativeKernelConfigurationFactory::create(gpu_target)->configure(m, n, k, batch_size, DataType::QASYMM8); - // Validate matrix multiply - ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMLowpMatrixMultiplyNativeKernel::validate(matrix_a_info, matrix_b_info, output, lhs_info, rhs_info, reshape_info)); - } + // Validate matrix multiply + ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMLowpMatrixMultiplyNativeKernel::validate(matrix_a_info, matrix_b_info, output, lhs_info, rhs_info, reshape_info)); } if(output->total_size() != 0) @@ -524,14 +507,7 @@ void CLGEMMLowpMatrixMultiplyCore::run() } else { - if(_is_midgard) - { - CLScheduler::get().enqueue(_mm_midgard_kernel, false); - } - else - { - CLScheduler::get().enqueue(_mm_native_kernel, false); - } + CLScheduler::get().enqueue(_mm_native_kernel, false); } if(_run_output_stage) { diff --git a/src/runtime/CL/gemm/CLGEMMKernelSelectionBifrost.cpp b/src/runtime/CL/gemm/CLGEMMKernelSelectionBifrost.cpp index d30eaa9edc..041e7d6cb4 100644 --- a/src/runtime/CL/gemm/CLGEMMKernelSelectionBifrost.cpp +++ b/src/runtime/CL/gemm/CLGEMMKernelSelectionBifrost.cpp @@ -165,27 +165,15 @@ CLGEMMKernelType 
CLGEMMKernelSelectionBifrost::default_f16(unsigned int m, unsig CLGEMMKernelType CLGEMMKernelSelectionBifrost::default_q8(unsigned int m, unsigned int n, unsigned int k, bool is_rhs_constant) { + ARM_COMPUTE_UNUSED(m, n, k); + if(is_rhs_constant) { - if(m == 1) - { - if((n > k) && gpu_target_is_in(_target, GPUTarget::G71)) - { - return CLGEMMKernelType::NATIVE_V1; - } - else - { - return CLGEMMKernelType::RESHAPED_ONLY_RHS; - } - } - else - { - return CLGEMMKernelType::RESHAPED; - } + return CLGEMMKernelType::RESHAPED_ONLY_RHS; } else { - return CLGEMMKernelType::NATIVE_V1; + return CLGEMMKernelType::NATIVE; } } diff --git a/src/runtime/CL/gemm/CLGEMMKernelSelectionMidgard.cpp b/src/runtime/CL/gemm/CLGEMMKernelSelectionMidgard.cpp index b7bb720175..a94a392553 100644 --- a/src/runtime/CL/gemm/CLGEMMKernelSelectionMidgard.cpp +++ b/src/runtime/CL/gemm/CLGEMMKernelSelectionMidgard.cpp @@ -86,10 +86,9 @@ CLGEMMKernelType CLGEMMKernelSelectionMidgard::default_f16(unsigned int m, unsig CLGEMMKernelType CLGEMMKernelSelectionMidgard::default_q8(unsigned int m, unsigned int n, unsigned int k, bool is_rhs_constant) { - ARM_COMPUTE_UNUSED(n, k); + ARM_COMPUTE_UNUSED(m, n, k, is_rhs_constant); - // We reshape the matrices only if we do not have the vector-by-matrix case and we reshape the matrix B only once - return ((m != 1) && is_rhs_constant) ? CLGEMMKernelType::RESHAPED_V1 : CLGEMMKernelType::NATIVE_V1; + return CLGEMMKernelType::NATIVE; } } // namespace cl_gemm } // namespace arm_compute diff --git a/src/runtime/CL/gemm/CLGEMMKernelSelectionValhall.cpp b/src/runtime/CL/gemm/CLGEMMKernelSelectionValhall.cpp index 8016417eb9..775bb9bffd 100644 --- a/src/runtime/CL/gemm/CLGEMMKernelSelectionValhall.cpp +++ b/src/runtime/CL/gemm/CLGEMMKernelSelectionValhall.cpp @@ -83,22 +83,15 @@ CLGEMMKernelType CLGEMMKernelSelectionValhall::default_f16(unsigned int m, unsig CLGEMMKernelType CLGEMMKernelSelectionValhall::default_q8(unsigned int m, unsigned int n, unsigned int k, bool is_rhs_constant) { - ARM_COMPUTE_UNUSED(n, k); + ARM_COMPUTE_UNUSED(m, n, k); if(is_rhs_constant) { - if(m == 1) - { - return CLGEMMKernelType::RESHAPED_ONLY_RHS; - } - else - { - return CLGEMMKernelType::RESHAPED; - } + return CLGEMMKernelType::RESHAPED_ONLY_RHS; } else { - return CLGEMMKernelType::NATIVE_V1; + return CLGEMMKernelType::NATIVE; } } } // namespace cl_gemm -- cgit v1.2.1
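
Editor's illustration of the two heuristic patterns this patch introduces for the Int8 path. The sketch below is a minimal, self-contained C++ example, not Compute Library code: DataType is trimmed to the types named in the patch, Config stands in for the library's std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo>, and select_config, int8_config, f32_config and bounded_h0 are hypothetical names. The quantized tile sizes follow the Valhall reshaped-only-RHS hunk above (m0 = 4 when m >= 28, else 2; n0 = 4; k0 = 16; h0 clamped into [1, 256]); the F32 entry is a placeholder. The throw stands in for ARM_COMPUTE_ERROR.

    #include <algorithm>
    #include <cstdio>
    #include <map>
    #include <stdexcept>

    enum class DataType { F32, QASYMM8, QASYMM8_SIGNED, QSYMM8, QSYMM8_PER_CHANNEL };

    // Stand-in for the (lhs_info, rhs_info) pair returned by configure().
    struct Config { unsigned int m0, n0, k0, h0; };

    // Mirrors the patch's std::max(std::min(n / divisor, cap), 1) expression:
    // the RHS reshape factor h0 is kept in the bounded range [1, cap].
    unsigned int bounded_h0(unsigned int n, unsigned int divisor, unsigned int cap)
    {
        return std::max(std::min(n / divisor, cap), 1u);
    }

    // All four 8-bit quantized types share one Int8 routine, as in the patch.
    Config int8_config(unsigned int m, unsigned int n)
    {
        return Config{ m >= 28u ? 4u : 2u, 4u, 16u, bounded_h0(n, 4u, 256u) };
    }

    Config f32_config(unsigned int, unsigned int)
    {
        return Config{ 4u, 4u, 4u, 16u }; // illustrative placeholder values only
    }

    Config select_config(DataType dt, unsigned int m, unsigned int n)
    {
        using Fn = Config (*)(unsigned int, unsigned int);
        static const std::map<DataType, Fn> configs =
        {
            { DataType::F32, &f32_config },
            { DataType::QASYMM8, &int8_config },
            { DataType::QASYMM8_SIGNED, &int8_config },
            { DataType::QSYMM8, &int8_config },
            { DataType::QSYMM8_PER_CHANNEL, &int8_config },
        };

        // The find() guard replaces the old up-front data-type assert: an
        // unsupported type is reported instead of dereferencing a null entry
        // inserted by an unchecked operator[].
        const auto it = configs.find(dt);
        if(it == configs.end())
        {
            throw std::runtime_error("Not supported data type");
        }
        return it->second(m, n);
    }

    int main()
    {
        const Config c = select_config(DataType::QASYMM8_SIGNED, 32u, 64u);
        std::printf("m0=%u n0=%u k0=%u h0=%u\n", c.m0, c.n0, c.k0, c.h0);
        return 0;
    }

On the design choice: with the quantized coverage widening from QASYMM8 alone to four 8-bit types across three GPU generations, a single assertion at the top of configure() no longer captures which (target, type) pairs are valid, so each per-target table guards its own lookup. The same delegation shows up in the runtime change: is_gemm_reshaped() now asks the per-architecture CLGEMMKernelSelection heuristic whether to use NATIVE or RESHAPED_ONLY_RHS, rather than hardcoding a Midgard check, which is what lets the dedicated gemmlowp_mm_midgard kernel and CLGEMMLowpMatrixMultiplyKernel be deleted.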