From db14af697b934d684d8b3d63a00ad5bea5c07bfb Mon Sep 17 00:00:00 2001 From: Gian Marco Iodice Date: Tue, 4 Oct 2022 15:29:34 +0100 Subject: Update GEMM reshaped rhs only heuristic Resolves COMPMID-5631 Change-Id: I37d1d0d043f8d44d782d2225091af607ad131b58 Signed-off-by: Gian Marco Iodice Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/8364 Benchmark: Arm Jenkins Tested-by: Arm Jenkins Reviewed-by: Viet-Hoa Do --- .../ClGemmDefaultConfigReshapedRhsOnlyValhall.cpp | 143 ++++++--------------- 1 file changed, 40 insertions(+), 103 deletions(-) diff --git a/src/gpu/cl/kernels/gemm/reshaped_only_rhs/ClGemmDefaultConfigReshapedRhsOnlyValhall.cpp b/src/gpu/cl/kernels/gemm/reshaped_only_rhs/ClGemmDefaultConfigReshapedRhsOnlyValhall.cpp index 97762980be..e4e35cb8ce 100644 --- a/src/gpu/cl/kernels/gemm/reshaped_only_rhs/ClGemmDefaultConfigReshapedRhsOnlyValhall.cpp +++ b/src/gpu/cl/kernels/gemm/reshaped_only_rhs/ClGemmDefaultConfigReshapedRhsOnlyValhall.cpp @@ -413,163 +413,100 @@ std::pair ClGemmDefaultConfigReshapedRhsOn std::pair ClGemmDefaultConfigReshapedRhsOnlyValhall::configure_G78_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b) { - const float r_mn = static_cast(m) / static_cast(n); - const float r_mk = static_cast(m) / static_cast(k); - const float r_nk = static_cast(n) / static_cast(k); const float workload = (static_cast(m) * static_cast(n) * static_cast(b)) / 20.0f; + const float r_mn = static_cast(m) / static_cast(n); + const float r_mk = static_cast(m) / static_cast(k); + const float r_nk = static_cast(n) / static_cast(k); if(m == 1) { - if(r_mn <= 0.0038f) + if(r_mn <= 0.0045f) { - if(workload <= 353.9000f) + if(workload <= 278.7000f) { - if(workload <= 278.7000f) - { - return configure_lhs_rhs_info(m, n, 1, 2, 4, 1, 32, 0, 0, 1, 0, 0); - } - else - { - if(r_mk <= 0.0004f) - { - return configure_lhs_rhs_info(m, n, 1, 2, 4, 1, 32, 0, 0, 1, 0, 0); - } - else - { - if(r_mk <= 0.0030f) - { - return configure_lhs_rhs_info(m, n, 1, 8, 4, 1, 8, 0, 1, 1, 0, 1); - } - else - { - return configure_lhs_rhs_info(m, n, 1, 2, 4, 1, 32, 0, 0, 1, 0, 0); - } - } - } + return configure_lhs_rhs_info(m, n, 1, 2, 16, 1, 8, 0, 0, 0, 1, 1); } else { - if(r_nk <= 1.9384f) - { - return configure_lhs_rhs_info(m, n, 1, 2, 4, 1, 32, 0, 0, 1, 0, 0); - } - else - { - return configure_lhs_rhs_info(m, n, 1, 8, 4, 1, 8, 0, 1, 1, 0, 1); - } + return configure_lhs_rhs_info(m, n, 1, 4, 8, 1, 32, 0, 0, 1, 0, 0); } } else { - if(r_nk <= 1.0368f) - { - return configure_lhs_rhs_info(m, n, 1, 2, 16, 1, 32, 0, 0, 1, 0, 0); - } - else - { - return configure_lhs_rhs_info(m, n, 1, 2, 4, 1, 32, 0, 0, 1, 0, 0); - } + return configure_lhs_rhs_info(m, n, 1, 2, 16, 1, 8, 0, 0, 1, 0, 0); } } else { - if(workload <= 1422.4000f) + if(workload <= 1384.8000f) { - if(workload <= 704.0000f) + if(r_nk <= 0.8333f) { - return configure_lhs_rhs_info(m, n, 2, 2, 8, 1, 32, 0, 0, 1, 0, 0); - } - else - { - if(workload <= 1197.6000f) + if(r_mk <= 0.9119f) { - return configure_lhs_rhs_info(m, n, 2, 4, 8, 1, 8, 0, 1, 1, 0, 1); + return configure_lhs_rhs_info(m, n, 2, 2, 16, 1, 4, 0, 1, 0, 1, 1); } else { - if(workload <= 1241.6000f) + if(r_nk <= 0.1181f) { - return configure_lhs_rhs_info(m, n, 2, 8, 8, 1, 16, 0, 1, 1, 0, 0); + return configure_lhs_rhs_info(m, n, 2, 2, 8, 1, 32, 0, 0, 1, 0, 0); } else { - return configure_lhs_rhs_info(m, n, 2, 4, 8, 1, 8, 0, 1, 1, 0, 1); + return configure_lhs_rhs_info(m, n, 4, 4, 8, 1, 32, 0, 1, 1, 0, 0); } } } + else + { + if(r_mk <= 1.0013f) + { + return configure_lhs_rhs_info(m, n, 4, 4, 8, 1, 32, 0, 1, 1, 0, 1); + } + else + { + return configure_lhs_rhs_info(m, n, 5, 4, 8, 1, 4, 0, 1, 1, 0, 1); + } + } } else { - if(workload <= 2769.6000f) + if(workload <= 11404.7998f) { - if(workload <= 1846.4000f) + if(r_mk <= 2.2884f) { - if(r_mn <= 2.4927f) + if(r_nk <= 0.9286f) { - return configure_lhs_rhs_info(m, n, 2, 8, 8, 1, 16, 0, 1, 1, 0, 0); + return configure_lhs_rhs_info(m, n, 4, 4, 8, 1, 4, 0, 1, 1, 0, 1); } else { - return configure_lhs_rhs_info(m, n, 4, 4, 8, 1, 32, 0, 1, 1, 0, 0); + return configure_lhs_rhs_info(m, n, 4, 4, 8, 1, 32, 0, 1, 1, 0, 1); } } else { - if(r_mn <= 0.6261f) - { - return configure_lhs_rhs_info(m, n, 4, 4, 8, 1, 32, 0, 1, 1, 0, 0); - } - else - { - if(r_mk <= 3.4453f) - { - if(r_mn <= 1.4135f) - { - return configure_lhs_rhs_info(m, n, 2, 8, 8, 1, 16, 0, 1, 1, 0, 0); - } - else - { - return configure_lhs_rhs_info(m, n, 4, 4, 8, 1, 32, 0, 1, 1, 0, 0); - } - } - else - { - return configure_lhs_rhs_info(m, n, 2, 8, 8, 1, 16, 0, 1, 1, 0, 0); - } - } + return configure_lhs_rhs_info(m, n, 5, 4, 8, 1, 4, 0, 1, 1, 0, 1); } } else { - if(r_nk <= 0.0302f) - { - return configure_lhs_rhs_info(m, n, 2, 4, 8, 1, 8, 0, 1, 1, 0, 1); - } - else + if(r_nk <= 1.1926f) { - if(r_mk <= 181.3750f) + if(r_mn <= 1385.7917f) { - return configure_lhs_rhs_info(m, n, 4, 4, 8, 1, 32, 0, 1, 1, 0, 0); + return configure_lhs_rhs_info(m, n, 6, 4, 8, 1, 4, 0, 1, 1, 0, 1); } else { - if(workload <= 28035.2002f) - { - return configure_lhs_rhs_info(m, n, 2, 8, 8, 1, 16, 0, 1, 1, 0, 0); - } - else - { - if(r_mk <= 808.6667f) - { - return configure_lhs_rhs_info(m, n, 4, 4, 8, 1, 32, 0, 1, 1, 0, 0); - } - else - { - return configure_lhs_rhs_info(m, n, 2, 8, 8, 1, 16, 0, 1, 1, 0, 0); - } - } + return configure_lhs_rhs_info(m, n, 2, 8, 8, 1, 32, 0, 1, 1, 0, 0); } } + else + { + return configure_lhs_rhs_info(m, n, 6, 4, 8, 1, 32, 0, 1, 1, 0, 1); + } } } } -- cgit v1.2.1