diff options
Diffstat (limited to 'src/core/CL/gemm/reshaped_only_rhs')
-rw-r--r-- | src/core/CL/gemm/reshaped_only_rhs/CLGEMMReshapedOnlyRHSKernelConfigurationBifrost.cpp | 69 |
1 files changed, 62 insertions, 7 deletions
diff --git a/src/core/CL/gemm/reshaped_only_rhs/CLGEMMReshapedOnlyRHSKernelConfigurationBifrost.cpp b/src/core/CL/gemm/reshaped_only_rhs/CLGEMMReshapedOnlyRHSKernelConfigurationBifrost.cpp index 0a0fc5d152..3105db6693 100644 --- a/src/core/CL/gemm/reshaped_only_rhs/CLGEMMReshapedOnlyRHSKernelConfigurationBifrost.cpp +++ b/src/core/CL/gemm/reshaped_only_rhs/CLGEMMReshapedOnlyRHSKernelConfigurationBifrost.cpp @@ -151,15 +151,13 @@ std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> CLGEMMReshapedOnlyRHSKernelConfi // Get lhs_info/rhs_info in case of OpenCL buffer if(m == 1) { - if((n / 4) >= 2048) + if(n <= 204.0) { - const unsigned int h0 = std::max(n / 4, 1U); - std::tie(lhs_info_buf, rhs_info_buf) = configure_lhs_rhs_info(m, n, 1, 4, 8, 1, h0, false, true, false, true); + return configure_lhs_rhs_info(m, n, 1, 2, 16, 1, 16, false, true, false, true, false); } else { - const unsigned int h0 = std::max(n / 2, 1U); - std::tie(lhs_info_buf, rhs_info_buf) = configure_lhs_rhs_info(m, n, 1, 2, 8, 1, h0, false, true, false, true); + return configure_lhs_rhs_info(m, n, 1, 2, 8, 1, 32, false, true, false, true, false); } } else @@ -247,7 +245,6 @@ std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> CLGEMMReshapedOnlyRHSKernelConfi std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> CLGEMMReshapedOnlyRHSKernelConfigurationBifrost::configure_G76_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b) { ARM_COMPUTE_UNUSED(k); - ARM_COMPUTE_UNUSED(b); if(m == 1) { @@ -255,7 +252,65 @@ std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> CLGEMMReshapedOnlyRHSKernelConfi } else { - return configure_lhs_rhs_info(m, n, 4, 4, 4, 1, 2, false, true, false, true); + const float r_mn = static_cast<float>(m) / static_cast<float>(n); + const float workload = (static_cast<float>(m) * static_cast<float>(n) * static_cast<float>(b)) / 20.0f; + + if(workload <= 7449.60f) + { + if(workload <= 691.60f) + { + return configure_lhs_rhs_info(m, n, 2, 2, 8, 1, 8, false, false, false, false, false); + } + else + { + if(workload <= 4155.20f) + { + return configure_lhs_rhs_info(m, n, 5, 2, 8, 1, 16, false, false, false, false, false); + } + else + { + return configure_lhs_rhs_info(m, n, 5, 8, 2, 1, 32, false, false, false, false, false); + } + } + } + else + { + if(workload <= 16300.80f) + { + if(r_mn <= 44.56f) + { + GEMMLHSMatrixInfo lhs_info_buf; + GEMMRHSMatrixInfo rhs_info_buf; + GEMMLHSMatrixInfo lhs_info_img; + GEMMRHSMatrixInfo rhs_info_img; + + std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 5, 4, 4, 1, 2, false, true, false, false, true); + std::tie(lhs_info_buf, rhs_info_buf) = configure_lhs_rhs_info(m, n, 5, 2, 8, 1, 16, false, false, false, false, false); + + return select_lhs_rhs_info(std::make_pair(lhs_info_img, rhs_info_img), + std::make_pair(lhs_info_buf, rhs_info_buf), + n, k, b, DataType::F16); + } + else + { + return configure_lhs_rhs_info(m, n, 5, 2, 8, 1, 16, false, false, false, false, false); + } + } + else + { + GEMMLHSMatrixInfo lhs_info_buf; + GEMMRHSMatrixInfo rhs_info_buf; + GEMMLHSMatrixInfo lhs_info_img; + GEMMRHSMatrixInfo rhs_info_img; + + std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 5, 4, 4, 1, 2, false, true, false, false, true); + std::tie(lhs_info_buf, rhs_info_buf) = configure_lhs_rhs_info(m, n, 5, 2, 8, 1, 16, false, false, false, false, false); + + return select_lhs_rhs_info(std::make_pair(lhs_info_img, rhs_info_img), + std::make_pair(lhs_info_buf, rhs_info_buf), + n, k, b, DataType::F16); + } + } } } |