aboutsummaryrefslogtreecommitdiff
path: root/src/core/CL/gemm/reshaped_only_rhs/CLGEMMReshapedOnlyRHSKernelConfigurationBifrost.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/core/CL/gemm/reshaped_only_rhs/CLGEMMReshapedOnlyRHSKernelConfigurationBifrost.cpp')
-rw-r--r-- src/core/CL/gemm/reshaped_only_rhs/CLGEMMReshapedOnlyRHSKernelConfigurationBifrost.cpp | 69
1 files changed, 62 insertions, 7 deletions
diff --git a/src/core/CL/gemm/reshaped_only_rhs/CLGEMMReshapedOnlyRHSKernelConfigurationBifrost.cpp b/src/core/CL/gemm/reshaped_only_rhs/CLGEMMReshapedOnlyRHSKernelConfigurationBifrost.cpp
index 0a0fc5d152..3105db6693 100644
--- a/src/core/CL/gemm/reshaped_only_rhs/CLGEMMReshapedOnlyRHSKernelConfigurationBifrost.cpp
+++ b/src/core/CL/gemm/reshaped_only_rhs/CLGEMMReshapedOnlyRHSKernelConfigurationBifrost.cpp
@@ -151,15 +151,13 @@ std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> CLGEMMReshapedOnlyRHSKernelConfi
// Get lhs_info/rhs_info in case of OpenCL buffer
if(m == 1)
{
- if((n / 4) >= 2048)
+ if(n <= 204.0)
{
- const unsigned int h0 = std::max(n / 4, 1U);
- std::tie(lhs_info_buf, rhs_info_buf) = configure_lhs_rhs_info(m, n, 1, 4, 8, 1, h0, false, true, false, true);
+ return configure_lhs_rhs_info(m, n, 1, 2, 16, 1, 16, false, true, false, true, false);
}
else
{
- const unsigned int h0 = std::max(n / 2, 1U);
- std::tie(lhs_info_buf, rhs_info_buf) = configure_lhs_rhs_info(m, n, 1, 2, 8, 1, h0, false, true, false, true);
+ return configure_lhs_rhs_info(m, n, 1, 2, 8, 1, 32, false, true, false, true, false);
}
}
else
@@ -247,7 +245,6 @@ std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> CLGEMMReshapedOnlyRHSKernelConfi
std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> CLGEMMReshapedOnlyRHSKernelConfigurationBifrost::configure_G76_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b)
{
ARM_COMPUTE_UNUSED(k);
- ARM_COMPUTE_UNUSED(b);
if(m == 1)
{
@@ -255,7 +252,65 @@ std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> CLGEMMReshapedOnlyRHSKernelConfi
}
else
{
- return configure_lhs_rhs_info(m, n, 4, 4, 4, 1, 2, false, true, false, true);
+ const float r_mn = static_cast<float>(m) / static_cast<float>(n);
+ const float workload = (static_cast<float>(m) * static_cast<float>(n) * static_cast<float>(b)) / 20.0f;
+
+ if(workload <= 7449.60f)
+ {
+ if(workload <= 691.60f)
+ {
+ return configure_lhs_rhs_info(m, n, 2, 2, 8, 1, 8, false, false, false, false, false);
+ }
+ else
+ {
+ if(workload <= 4155.20f)
+ {
+ return configure_lhs_rhs_info(m, n, 5, 2, 8, 1, 16, false, false, false, false, false);
+ }
+ else
+ {
+ return configure_lhs_rhs_info(m, n, 5, 8, 2, 1, 32, false, false, false, false, false);
+ }
+ }
+ }
+ else
+ {
+ if(workload <= 16300.80f)
+ {
+ if(r_mn <= 44.56f)
+ {
+ GEMMLHSMatrixInfo lhs_info_buf;
+ GEMMRHSMatrixInfo rhs_info_buf;
+ GEMMLHSMatrixInfo lhs_info_img;
+ GEMMRHSMatrixInfo rhs_info_img;
+
+ std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 5, 4, 4, 1, 2, false, true, false, false, true);
+ std::tie(lhs_info_buf, rhs_info_buf) = configure_lhs_rhs_info(m, n, 5, 2, 8, 1, 16, false, false, false, false, false);
+
+ return select_lhs_rhs_info(std::make_pair(lhs_info_img, rhs_info_img),
+ std::make_pair(lhs_info_buf, rhs_info_buf),
+ n, k, b, DataType::F16);
+ }
+ else
+ {
+ return configure_lhs_rhs_info(m, n, 5, 2, 8, 1, 16, false, false, false, false, false);
+ }
+ }
+ else
+ {
+ GEMMLHSMatrixInfo lhs_info_buf;
+ GEMMRHSMatrixInfo rhs_info_buf;
+ GEMMLHSMatrixInfo lhs_info_img;
+ GEMMRHSMatrixInfo rhs_info_img;
+
+ std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 5, 4, 4, 1, 2, false, true, false, false, true);
+ std::tie(lhs_info_buf, rhs_info_buf) = configure_lhs_rhs_info(m, n, 5, 2, 8, 1, 16, false, false, false, false, false);
+
+ return select_lhs_rhs_info(std::make_pair(lhs_info_img, rhs_info_img),
+ std::make_pair(lhs_info_buf, rhs_info_buf),
+ n, k, b, DataType::F16);
+ }
+ }
}
}