diff options
Diffstat (limited to 'src/gpu/cl/kernels/gemm/reshaped_only_rhs/ClGemmDefaultConfigReshapedRhsOnlyValhall.cpp')
-rw-r--r-- | src/gpu/cl/kernels/gemm/reshaped_only_rhs/ClGemmDefaultConfigReshapedRhsOnlyValhall.cpp | 181 |
1 files changed, 144 insertions, 37 deletions
diff --git a/src/gpu/cl/kernels/gemm/reshaped_only_rhs/ClGemmDefaultConfigReshapedRhsOnlyValhall.cpp b/src/gpu/cl/kernels/gemm/reshaped_only_rhs/ClGemmDefaultConfigReshapedRhsOnlyValhall.cpp index e4e35cb8ce..5d666c03a5 100644 --- a/src/gpu/cl/kernels/gemm/reshaped_only_rhs/ClGemmDefaultConfigReshapedRhsOnlyValhall.cpp +++ b/src/gpu/cl/kernels/gemm/reshaped_only_rhs/ClGemmDefaultConfigReshapedRhsOnlyValhall.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020-2022 Arm Limited. + * Copyright (c) 2020-2023 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -196,52 +196,159 @@ std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOn std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOnlyValhall::configure_G77_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b) { - ARM_COMPUTE_UNUSED(k); - ARM_COMPUTE_UNUSED(b); + const GeMMConfigsMatrix configs_1nkb_best = + { + { 1, 8984, 640, 1, 1, 8, 8, 1, 0, 1, 1, 1, 1, 0 }, + { 1, 420, 392, 1, 1, 2, 8, 1, 0, 1, 0, 1, 0, 0 }, + { 1, 644, 5288, 1, 1, 2, 8, 1, 0, 1, 0, 1, 0, 0 }, + { 1, 6512, 6404, 1, 1, 4, 8, 1, 0, 1, 0, 1, 0, 0 }, + { 1, 5304, 640, 1, 1, 4, 4, 1, 0, 1, 0, 1, 1, 0 }, + { 1, 1352, 1520, 1, 1, 2, 8, 1, 0, 1, 1, 1, 1, 0 }, + { 1, 4096, 25088, 1, 1, 2, 16, 1, 0, 1, 0, 1, 0, 0 }, + { 1, 732, 8988, 1, 1, 2, 8, 1, 0, 1, 0, 1, 0, 0 } + }; + + const GeMMConfigsMatrix configs_mnkb_n_small_best = + { + { 102400, 4, 96, 1, 2, 2, 16, 1, 4, 1, 1, 1, 1, 0 }, + { 102400, 2, 96, 1, 1, 2, 16, 1, 0, 1, 0, 1, 1, 1 }, + { 16384, 4, 128, 1, 1, 2, 16, 1, 0, 1, 0, 1, 1, 1 }, + { 16384, 2, 128, 1, 1, 2, 16, 1, 0, 1, 1, 1, 1, 1 } + }; - if(m == 1) + const GeMMConfigsMatrix configs_mnkb_n_small_fallback = { - const unsigned int h0 = std::max(n / 2, 1U); - if(n <= 836.0) - { - return configure_lhs_rhs_info(m, n, 1, 2, 16, 1, h0, 0, 1, 0, 1, 0); - } - else + { 102400, 4, 96, 1, 2, 2, 16, 1, 4, 1, 1, 1, 1, 0 }, + { 102400, 2, 96, 1, 1, 2, 16, 1, 0, 1, 1, 1, 1, 0 }, + { 16384, 4, 128, 1, 2, 2, 16, 1, 2, 1, 1, 1, 1, 0 }, + { 16384, 2, 128, 1, 1, 2, 16, 1, 0, 1, 1, 1, 1, 0 } + }; + + const GeMMConfigsMatrix configs_mnkb_best = + { + { 24, 88, 236, 1, 2, 2, 8, 1, 64, 1, 1, 1, 1, 0 }, + { 24, 488, 88, 1, 2, 4, 16, 1, 4, 1, 1, 1, 0, 0 }, + { 24, 88, 488, 1, 2, 2, 8, 1, 64, 1, 1, 1, 1, 0 }, + { 25584, 88, 16, 1, 4, 8, 4, 1, 8, 1, 1, 1, 0, 0 }, + { 25584, 16, 68, 1, 4, 4, 8, 1, 16, 1, 1, 1, 0, 1 }, + { 72, 92, 136, 1, 2, 2, 8, 1, 128, 1, 1, 1, 1, 0 }, + { 369664, 32, 28, 1, 5, 4, 4, 1, 64, 1, 1, 1, 0, 1 }, + { 65792, 44, 24, 1, 4, 8, 4, 1, 128, 1, 1, 1, 0, 0 }, + { 23036, 56, 736, 1, 4, 4, 8, 1, 64, 1, 1, 1, 0, 1 }, + { 90968, 40, 600, 1, 4, 4, 8, 1, 64, 1, 1, 1, 0, 1 }, + { 8944, 32, 776, 1, 4, 4, 8, 1, 64, 1, 1, 1, 0, 1 }, + { 2688, 136, 1492, 1, 8, 4, 4, 1, 128, 1, 1, 1, 0, 0 }, + { 50176, 64, 300, 1, 4, 8, 4, 1, 128, 1, 1, 1, 0, 0 }, + { 16544, 104, 160, 1, 4, 4, 8, 1, 64, 1, 1, 1, 0, 1 }, + { 268, 824, 5076, 1, 4, 8, 4, 1, 256, 1, 1, 1, 0, 0 }, + { 180, 420, 952, 1, 4, 4, 8, 1, 64, 1, 1, 1, 0, 1 }, + { 1000, 152, 304, 1, 4, 4, 8, 1, 128, 1, 1, 1, 0, 0 }, + { 12604, 60, 160, 1, 4, 4, 8, 1, 64, 1, 1, 1, 0, 1 }, + { 3728, 96, 196, 1, 4, 8, 4, 1, 128, 1, 1, 1, 0, 0 }, + { 272, 400, 2116, 1, 4, 8, 4, 1, 64, 1, 1, 1, 0, 0 }, + { 29584, 32, 28, 1, 4, 4, 4, 1, 128, 1, 1, 1, 0, 0 }, + { 12544, 32, 27, 1, 2, 8, 8, 1, 128, 1, 1, 1, 0, 0 }, + { 196, 512, 512, 1, 5, 4, 4, 1, 64, 1, 1, 1, 0, 1 }, + { 49, 1024, 512, 1, 4, 4, 8, 1, 128, 1, 1, 1, 0, 1 }, + { 49, 1024, 1024, 1, 4, 4, 8, 1, 64, 1, 1, 1, 0, 1 } + }; + + const GeMMConfigsMatrix configs_mnkb_fallback = + { + { 24, 88, 236, 1, 2, 2, 8, 1, 64, 1, 1, 1, 1, 0 }, + { 24, 488, 88, 1, 2, 4, 16, 1, 4, 1, 1, 1, 0, 0 }, + { 24, 88, 488, 1, 2, 2, 8, 1, 64, 1, 1, 1, 1, 0 }, + { 25584, 88, 16, 1, 4, 8, 4, 1, 8, 1, 1, 1, 0, 0 }, + { 25584, 16, 68, 1, 2, 4, 8, 1, 4, 1, 1, 1, 0, 0 }, + { 72, 92, 136, 1, 2, 2, 8, 1, 128, 1, 1, 1, 1, 0 }, + { 369664, 32, 28, 1, 5, 4, 4, 1, 256, 1, 1, 1, 0, 0 }, + { 65792, 44, 24, 1, 4, 8, 4, 1, 128, 1, 1, 1, 0, 0 }, + { 23036, 56, 736, 1, 4, 8, 4, 1, 64, 1, 1, 1, 0, 0 }, + { 90968, 40, 600, 1, 4, 8, 4, 1, 64, 1, 1, 1, 0, 0 }, + { 8944, 32, 776, 1, 4, 4, 8, 1, 64, 1, 1, 1, 0, 0 }, + { 2688, 136, 1492, 1, 8, 4, 4, 1, 128, 1, 1, 1, 0, 0 }, + { 50176, 64, 300, 1, 4, 8, 4, 1, 128, 1, 1, 1, 0, 0 }, + { 16544, 104, 160, 1, 4, 4, 8, 1, 64, 1, 1, 1, 0, 0 }, + { 268, 824, 5076, 1, 4, 8, 4, 1, 256, 1, 1, 1, 0, 0 }, + { 180, 420, 952, 1, 4, 4, 8, 1, 128, 1, 1, 1, 0, 0 }, + { 1000, 152, 304, 1, 4, 4, 8, 1, 128, 1, 1, 1, 0, 0 }, + { 12604, 60, 160, 1, 4, 4, 8, 1, 256, 1, 1, 1, 0, 0 }, + { 3728, 96, 196, 1, 4, 8, 4, 1, 128, 1, 1, 1, 0, 0 }, + { 272, 400, 2116, 1, 4, 8, 4, 1, 64, 1, 1, 1, 0, 0 }, + { 29584, 32, 28, 1, 4, 4, 4, 1, 128, 1, 1, 1, 0, 0 }, + { 12544, 32, 27, 1, 2, 8, 8, 1, 128, 1, 1, 1, 0, 0 }, + { 196, 512, 512, 1, 5, 4, 4, 1, 256, 1, 1, 1, 0, 0 }, + { 49, 1024, 512, 1, 4, 4, 8, 1, 128, 1, 1, 1, 0, 0 }, + { 49, 1024, 1024, 1, 4, 4, 8, 1, 256, 1, 1, 1, 0, 0 } + }; + + const GeMMConfigsMatrix configs_mnkb_best_batched = + { + { 3136, 64, 64, 36, 4, 8, 4, 1, 64, 1, 1, 1, 0, 0 }, + { 4096, 48, 32, 36, 4, 4, 8, 1, 64, 1, 1, 1, 0, 1 }, + { 688, 92, 68, 32, 4, 8, 4, 1, 64, 1, 1, 1, 0, 0 }, + { 24, 464, 412, 24, 4, 4, 8, 1, 128, 1, 1, 1, 0, 0 }, + { 112, 184, 144, 28, 4, 8, 4, 1, 64, 1, 1, 1, 0, 0 }, + { 5776, 64, 32, 36, 4, 8, 4, 1, 64, 1, 1, 1, 0, 0 }, + { 1568, 64, 40, 36, 4, 8, 4, 1, 64, 1, 1, 1, 0, 0 }, + { 2920, 64, 64, 24, 4, 8, 4, 1, 64, 1, 1, 1, 0, 0 } + }; + + const GeMMConfigsMatrix configs_mnkb_fallback_batched = + { + { 3136, 64, 64, 36, 4, 8, 4, 1, 64, 1, 1, 1, 0, 0 }, + { 4096, 48, 32, 36, 4, 4, 8, 1, 128, 1, 1, 1, 0, 0 }, + { 688, 92, 68, 32, 4, 8, 4, 1, 64, 1, 1, 1, 0, 0 }, + { 24, 464, 412, 24, 4, 4, 8, 1, 128, 1, 1, 1, 0, 0 }, + { 112, 184, 144, 28, 4, 8, 4, 1, 64, 1, 1, 1, 0, 0 }, + { 5776, 64, 32, 36, 4, 8, 4, 1, 64, 1, 1, 1, 0, 0 }, + { 1568, 64, 40, 36, 4, 8, 4, 1, 64, 1, 1, 1, 0, 0 }, + { 2920, 64, 64, 24, 4, 8, 4, 1, 64, 1, 1, 1, 0, 0 } + }; + + const GeMMConfigsMatrix *configs_best_to_use = nullptr; + const GeMMConfigsMatrix *configs_fallback_to_use = nullptr; + + if(b == 1) + { + constexpr float ratio_m_gt_n = 10.f; + constexpr unsigned int n_small_thr = 4; + const float ratio = static_cast<float>(m) / static_cast<float>(n); + + if(m == 1) { - return configure_lhs_rhs_info(m, n, 1, 2, 8, 1, h0, 0, 1, 0, 1, 0); + // We do not need fallback in this case, as we never use cl_image for the rhs tensor + configs_best_to_use = &configs_1nkb_best; + configs_fallback_to_use = &configs_1nkb_best; } - } - else if(m < 128) - { - const int h0 = std::max(std::min(static_cast<int>(n / 4), static_cast<int>(256)), static_cast<int>(1)); - if(k >= 512) + else if(n <= n_small_thr && ratio > ratio_m_gt_n) { - return configure_lhs_rhs_info(m, n, 2, 4, 16, 1, h0, 0, 1, 0, 0); + configs_best_to_use = &configs_mnkb_n_small_best; + configs_fallback_to_use = &configs_mnkb_n_small_fallback; } else { - return configure_lhs_rhs_info(m, n, 2, 4, 8, 1, h0, 0, 1, 0, 0); + configs_best_to_use = &configs_mnkb_best; + configs_fallback_to_use = &configs_mnkb_fallback; } } else { - const int h0 = std::max(std::min(static_cast<int>(n / 4), static_cast<int>(256)), static_cast<int>(1)); - if(n >= 64) - { - return configure_lhs_rhs_info(m, n, 4, 8, 4, 1, h0, 0, 1, 0, 0); - } - else - { - if(k >= 512) - { - return configure_lhs_rhs_info(m, n, 2, 4, 16, 1, h0, 0, 1, 0, 0); - } - else - { - return configure_lhs_rhs_info(m, n, 2, 4, 8, 1, h0, 0, 1, 0, 0); - } - } + configs_best_to_use = &configs_mnkb_best_batched; + configs_fallback_to_use = &configs_mnkb_fallback_batched; } + + GEMMLHSMatrixInfo lhs_info0; + GEMMRHSMatrixInfo rhs_info0; + GEMMLHSMatrixInfo lhs_info1; + GEMMRHSMatrixInfo rhs_info1; + + std::tie(lhs_info0, rhs_info0) = find_lhs_rhs_info(*configs_best_to_use, m, n, k, b); + std::tie(lhs_info1, rhs_info1) = find_lhs_rhs_info(*configs_fallback_to_use, m, n, k, b); + + return select_lhs_rhs_info(std::make_pair(lhs_info0, rhs_info0), + std::make_pair(lhs_info1, rhs_info1), + n, k, b, DataType::F16); } std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOnlyValhall::configure_G77_u8(unsigned int m, unsigned int n, unsigned int k, unsigned int b) @@ -414,9 +521,9 @@ std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOn std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOnlyValhall::configure_G78_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b) { const float workload = (static_cast<float>(m) * static_cast<float>(n) * static_cast<float>(b)) / 20.0f; - const float r_mn = static_cast<float>(m) / static_cast<float>(n); - const float r_mk = static_cast<float>(m) / static_cast<float>(k); - const float r_nk = static_cast<float>(n) / static_cast<float>(k); + const float r_mn = static_cast<float>(m) / static_cast<float>(n); + const float r_mk = static_cast<float>(m) / static_cast<float>(k); + const float r_nk = static_cast<float>(n) / static_cast<float>(k); if(m == 1) { |