aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/core/CL/gemm/reshaped/CLGEMMReshapedKernelConfigurationBifrost.cpp87
-rw-r--r--src/core/CL/gemm/reshaped_only_rhs/CLGEMMReshapedOnlyRHSKernelConfigurationBifrost.cpp3
2 files changed, 84 insertions, 6 deletions
diff --git a/src/core/CL/gemm/reshaped/CLGEMMReshapedKernelConfigurationBifrost.cpp b/src/core/CL/gemm/reshaped/CLGEMMReshapedKernelConfigurationBifrost.cpp
index a533f14d02..b5fc074fb4 100644
--- a/src/core/CL/gemm/reshaped/CLGEMMReshapedKernelConfigurationBifrost.cpp
+++ b/src/core/CL/gemm/reshaped/CLGEMMReshapedKernelConfigurationBifrost.cpp
@@ -206,15 +206,94 @@ std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> CLGEMMReshapedKernelConfiguratio
std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> CLGEMMReshapedKernelConfigurationBifrost::configure_G76_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b)
{
ARM_COMPUTE_UNUSED(k);
- ARM_COMPUTE_UNUSED(b);
- if(n <= 4)
+ const float r_mn = static_cast<float>(m) / static_cast<float>(n);
+ const float workload = (static_cast<float>(m) * static_cast<float>(n) * static_cast<float>(b)) / 20.0f;
+
+ if(workload <= 1049.59f)
{
- return configure_lhs_rhs_info(m, n, 4, 4, 4, 8, 2, true, true, true, false);
+ if(b <= 5)
+ {
+ if(workload <= 790.39f)
+ {
+ return configure_lhs_rhs_info(m,n,2,4,4,2,2,false,false,true,false,false);
+ }
+ else
+ {
+ if(workload <= 982.39f)
+ {
+ return configure_lhs_rhs_info(m,n,4,2,4,4,4,false,false,true,false,false);
+ }
+ else
+ {
+ return configure_lhs_rhs_info(m,n,2,4,4,2,1,false,true,true,false,false);
+ }
+ }
+ }
+ else
+ {
+ if(r_mn <= 0.21f)
+ {
+ if(r_mn <= 0.11f)
+ {
+ return configure_lhs_rhs_info(m,n,2,4,4,2,2,false,false,true,false,false);
+ }
+ else
+ {
+ return configure_lhs_rhs_info(m,n,4,4,4,4,4,false,true,true,false,false);
+ }
+ }
+ else
+ {
+ return configure_lhs_rhs_info(m,n,2,4,4,2,2,false,false,true,false,false);
+ }
+ }
}
else
{
- return configure_lhs_rhs_info(m, n, 4, 4, 4, 4, 8, true, true, true, false);
+ if(n <= 200)
+ {
+ if(workload <= 29772.79f)
+ {
+ if(m <= 64.5)
+ {
+ return configure_lhs_rhs_info(m,n,4,4,4,2,4,true,false,true,false,false);
+ }
+ else
+ {
+ return configure_lhs_rhs_info(m,n,4,4,4,2,2,false,true,true,false,false);
+ }
+ }
+ else
+ {
+ if(r_mn <= 1.09f)
+ {
+ return configure_lhs_rhs_info(m,n,4,4,4,4,4,false,true,true,false,false);
+ }
+ else
+ {
+ return configure_lhs_rhs_info(m,n,4,4,4,2,2,true,true,true,false,false);
+ }
+ }
+ }
+ else
+ {
+ if(m <= 43)
+ {
+ return configure_lhs_rhs_info(m,n,4,4,4,2,4,true,false,true,false,false);
+ }
+ else
+ {
+ if(workload <= 26364.79f)
+ {
+ return configure_lhs_rhs_info(m,n,4,4,4,2,2,false,true,true,false,false);
+ }
+ else
+ {
+ return configure_lhs_rhs_info(m,n,4,4,4,4,4,false,true,true,false,false);
+ }
+ }
+ }
}
}
diff --git a/src/core/CL/gemm/reshaped_only_rhs/CLGEMMReshapedOnlyRHSKernelConfigurationBifrost.cpp b/src/core/CL/gemm/reshaped_only_rhs/CLGEMMReshapedOnlyRHSKernelConfigurationBifrost.cpp
index dcb0e0be96..16eabf069c 100644
--- a/src/core/CL/gemm/reshaped_only_rhs/CLGEMMReshapedOnlyRHSKernelConfigurationBifrost.cpp
+++ b/src/core/CL/gemm/reshaped_only_rhs/CLGEMMReshapedOnlyRHSKernelConfigurationBifrost.cpp
@@ -251,8 +251,7 @@ std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> CLGEMMReshapedOnlyRHSKernelConfi
if(m == 1)
{
- const unsigned int h0 = std::max(n / 2, 1U);
- return configure_lhs_rhs_info(m, n, 1, 2, 8, 1, h0, false, true, false, true);
+ return configure_lhs_rhs_info(m,n,1,2,16,1,32,false,true,false,true,false);
}
else
{