From 491f30c0fff416007d97f4a5a043923861ef7b64 Mon Sep 17 00:00:00 2001 From: Gian Marco Iodice Date: Mon, 2 Nov 2020 15:43:57 +0000 Subject: COMPMID-3939: Update GEMM heuristic Mali-G77 - Update heuristic for GEMM reshaped RHS only - Fix left-over block size in CLGEMMMatrixMultiplyReshapedOlyRHSKernel Change-Id: I34c738821ed2e4a537da4a15058eec164cb6b61f Signed-off-by: Gian Marco Iodice Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/4305 Tested-by: Arm Jenkins Reviewed-by: Georgios Pinitas Comments-Addressed: Arm Jenkins --- ...MMReshapedOnlyRHSKernelConfigurationValhall.cpp | 129 ++++++++++++++------- 1 file changed, 85 insertions(+), 44 deletions(-) (limited to 'src/core/CL/gemm/reshaped_only_rhs/CLGEMMReshapedOnlyRHSKernelConfigurationValhall.cpp') diff --git a/src/core/CL/gemm/reshaped_only_rhs/CLGEMMReshapedOnlyRHSKernelConfigurationValhall.cpp b/src/core/CL/gemm/reshaped_only_rhs/CLGEMMReshapedOnlyRHSKernelConfigurationValhall.cpp index f7939d29c0..e0991674b1 100644 --- a/src/core/CL/gemm/reshaped_only_rhs/CLGEMMReshapedOnlyRHSKernelConfigurationValhall.cpp +++ b/src/core/CL/gemm/reshaped_only_rhs/CLGEMMReshapedOnlyRHSKernelConfigurationValhall.cpp @@ -78,66 +78,107 @@ std::pair CLGEMMReshapedOnlyRHSKernelConfi std::pair CLGEMMReshapedOnlyRHSKernelConfigurationValhall::configure_G77_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b) { - ARM_COMPUTE_UNUSED(k); - - GEMMLHSMatrixInfo lhs_info_buf; - GEMMRHSMatrixInfo rhs_info_buf; - GEMMLHSMatrixInfo lhs_info_img; - GEMMRHSMatrixInfo rhs_info_img; - - // Get lhs_info/rhs_info in case of OpenCL buffer if(m == 1) { - const unsigned int h0 = std::max(n / 4, 1U); - std::tie(lhs_info_buf, rhs_info_buf) = configure_lhs_rhs_info(m, n, 1, 4, 4, 1, h0, false, true, false, true); - } - else - { - if(m > 256) + const float r_mn = static_cast(m) / static_cast(n); + const float r_mk = static_cast(m) / static_cast(k); + + if(r_mk <= 0.0064484127797186375) { - const int v0 = std::max(std::min(static_cast(n / 4), static_cast(8)), static_cast(1)); - std::tie(lhs_info_buf, rhs_info_buf) = configure_lhs_rhs_info(m, n, 4, 4, 4, 1, v0, false, true, false, true); + if(r_mn <= 0.0028273810748942196) + { + GEMMLHSMatrixInfo lhs_info_buf; + GEMMRHSMatrixInfo rhs_info_buf; + GEMMLHSMatrixInfo lhs_info_img; + GEMMRHSMatrixInfo rhs_info_img; + + const unsigned int h0 = std::max(n / 4, 1U); + std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 1, 4, 8, 1, 16, false, true, false, false, true); + std::tie(lhs_info_buf, rhs_info_buf) = configure_lhs_rhs_info(m, n, 1, 4, 4, 1, h0, false, true, false, true, false); + + return select_lhs_rhs_info(std::make_pair(lhs_info_img, rhs_info_img), + std::make_pair(lhs_info_buf, rhs_info_buf), + n, k, b, DataType::F32); + } + else + { + return configure_lhs_rhs_info(m, n, 1, 2, 16, 1, 8, false, true, false, false, false); + } } else { - const int v0 = std::max(std::min(static_cast(n / 4), static_cast(8)), static_cast(1)); - std::tie(lhs_info_buf, rhs_info_buf) = configure_lhs_rhs_info(m, n, 2, 4, 4, 1, v0, false, true, false, true); + if(r_mk <= 0.020312500186264515) + { + return configure_lhs_rhs_info(m, n, 1, 2, 16, 1, 4, false, true, false, false, false); + } + else + { + return configure_lhs_rhs_info(m, n, 1, 4, 16, 1, 16, false, true, false, true, false); + } } } - - // Get lhs_info/rhs_info in case of OpenCL image - if(m == 1) - { - std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 1, 4, 4, 1, 8, true, true, false, false, true); - } else { - if((m / 4) * (n / 4) > 4096) + const float r_mn = static_cast(m) / static_cast(n); + const float workload = (static_cast(m) * static_cast(n) * static_cast(b)) / 20.0f; + const float r_mk = static_cast(m) / static_cast(k); + + if(workload <= 1999.2000122070312) { - const int h0 = std::max(std::min(static_cast(n / 4), static_cast(8)), static_cast(1)); - std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 4, 4, 4, 1, h0, false, true, false, false, true); + if(workload <= 747.1999816894531) + { + return configure_lhs_rhs_info(m, n, 2, 2, 4, 1, 8, false, true, false, true, false); + } + else + { + GEMMLHSMatrixInfo lhs_info_buf; + GEMMRHSMatrixInfo rhs_info_buf; + GEMMLHSMatrixInfo lhs_info_img; + GEMMRHSMatrixInfo rhs_info_img; + std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 2, 4, 8, 1, 2, false, false, false, true, true); + std::tie(lhs_info_buf, rhs_info_buf) = configure_lhs_rhs_info(m, n, 2, 2, 4, 1, 8, false, true, false, true, false); + + return select_lhs_rhs_info(std::make_pair(lhs_info_img, rhs_info_img), + std::make_pair(lhs_info_buf, rhs_info_buf), + n, k, b, DataType::F32); + } } else { - const int h0 = std::max(std::min(static_cast(n / 4), static_cast(8)), static_cast(1)); - std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 2, 4, 4, 1, h0, false, true, false, false, true); - } - } - - const TensorInfo tensor_rhs_info(TensorShape(n, k, b), 1, DataType::F32); - const TensorShape shape = compute_rhs_reshaped_shape(tensor_rhs_info, rhs_info_img); - const TensorInfo tensor_reshaped_info(shape, 1, DataType::F32); + if(r_mn <= 0.03348214365541935) + { + if(r_mk <= 0.028125000186264515) + { + return configure_lhs_rhs_info(m, n, 2, 2, 4, 1, 8, false, true, false, true, false); + } + else + { + GEMMLHSMatrixInfo lhs_info_buf; + GEMMRHSMatrixInfo rhs_info_buf; + GEMMLHSMatrixInfo lhs_info_img; + GEMMRHSMatrixInfo rhs_info_img; + std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 2, 4, 8, 1, 2, false, false, false, true, true); + std::tie(lhs_info_buf, rhs_info_buf) = configure_lhs_rhs_info(m, n, 2, 2, 4, 1, 8, false, true, false, true, false); - // In case of small workloads, we use the OpenCL buffer rather than the OpenCL image2d - const bool use_cl_image2d = ((m / lhs_info_img.m0) * (n / rhs_info_img.n0)) * b < 1024 ? false : true; + return select_lhs_rhs_info(std::make_pair(lhs_info_img, rhs_info_img), + std::make_pair(lhs_info_buf, rhs_info_buf), + n, k, b, DataType::F32); + } + } + else + { + GEMMLHSMatrixInfo lhs_info_buf; + GEMMRHSMatrixInfo rhs_info_buf; + GEMMLHSMatrixInfo lhs_info_img; + GEMMRHSMatrixInfo rhs_info_img; + std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 4, 4, 4, 1, 2, false, true, false, false, true); + std::tie(lhs_info_buf, rhs_info_buf) = configure_lhs_rhs_info(m, n, 4, 4, 4, 1, 16, false, true, false, true, false); - if(bool(validate_image2d_support_on_rhs(tensor_reshaped_info, rhs_info_img)) && use_cl_image2d) - { - return std::make_pair(lhs_info_img, rhs_info_img); - } - else - { - return std::make_pair(lhs_info_buf, rhs_info_buf); + return select_lhs_rhs_info(std::make_pair(lhs_info_img, rhs_info_img), + std::make_pair(lhs_info_buf, rhs_info_buf), + n, k, b, DataType::F32); + } + } } } -- cgit v1.2.1