aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorGian Marco Iodice <gianmarco.iodice@arm.com>2022-10-04 15:29:34 +0100
committerGian Marco Iodice <gianmarco.iodice@arm.com>2022-10-04 16:35:17 +0000
commitdb14af697b934d684d8b3d63a00ad5bea5c07bfb (patch)
treefce94cd25cb6a4285b1f58473452157baf302830
parentb5368fb3da65ca1d31e6acd6cd45b8b6b789f1eb (diff)
downloadComputeLibrary-db14af697b934d684d8b3d63a00ad5bea5c07bfb.tar.gz
Update GEMM reshaped rhs only heuristic
Resolves COMPMID-5631 Change-Id: I37d1d0d043f8d44d782d2225091af607ad131b58 Signed-off-by: Gian Marco Iodice <gianmarco.iodice@arm.com> Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/8364 Benchmark: Arm Jenkins <bsgcomp@arm.com> Tested-by: Arm Jenkins <bsgcomp@arm.com> Reviewed-by: Viet-Hoa Do <viet-hoa.do@arm.com>
-rw-r--r--src/gpu/cl/kernels/gemm/reshaped_only_rhs/ClGemmDefaultConfigReshapedRhsOnlyValhall.cpp143
1 files changed, 40 insertions, 103 deletions
diff --git a/src/gpu/cl/kernels/gemm/reshaped_only_rhs/ClGemmDefaultConfigReshapedRhsOnlyValhall.cpp b/src/gpu/cl/kernels/gemm/reshaped_only_rhs/ClGemmDefaultConfigReshapedRhsOnlyValhall.cpp
index 97762980be..e4e35cb8ce 100644
--- a/src/gpu/cl/kernels/gemm/reshaped_only_rhs/ClGemmDefaultConfigReshapedRhsOnlyValhall.cpp
+++ b/src/gpu/cl/kernels/gemm/reshaped_only_rhs/ClGemmDefaultConfigReshapedRhsOnlyValhall.cpp
@@ -413,163 +413,100 @@ std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOn
std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOnlyValhall::configure_G78_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b)
{
- const float r_mn = static_cast<float>(m) / static_cast<float>(n);
- const float r_mk = static_cast<float>(m) / static_cast<float>(k);
- const float r_nk = static_cast<float>(n) / static_cast<float>(k);
const float workload = (static_cast<float>(m) * static_cast<float>(n) * static_cast<float>(b)) / 20.0f;
+ const float r_mn = static_cast<float>(m) / static_cast<float>(n);
+ const float r_mk = static_cast<float>(m) / static_cast<float>(k);
+ const float r_nk = static_cast<float>(n) / static_cast<float>(k);
if(m == 1)
{
- if(r_mn <= 0.0038f)
+ if(r_mn <= 0.0045f)
{
- if(workload <= 353.9000f)
+ if(workload <= 278.7000f)
{
- if(workload <= 278.7000f)
- {
- return configure_lhs_rhs_info(m, n, 1, 2, 4, 1, 32, 0, 0, 1, 0, 0);
- }
- else
- {
- if(r_mk <= 0.0004f)
- {
- return configure_lhs_rhs_info(m, n, 1, 2, 4, 1, 32, 0, 0, 1, 0, 0);
- }
- else
- {
- if(r_mk <= 0.0030f)
- {
- return configure_lhs_rhs_info(m, n, 1, 8, 4, 1, 8, 0, 1, 1, 0, 1);
- }
- else
- {
- return configure_lhs_rhs_info(m, n, 1, 2, 4, 1, 32, 0, 0, 1, 0, 0);
- }
- }
- }
+ return configure_lhs_rhs_info(m, n, 1, 2, 16, 1, 8, 0, 0, 0, 1, 1);
}
else
{
- if(r_nk <= 1.9384f)
- {
- return configure_lhs_rhs_info(m, n, 1, 2, 4, 1, 32, 0, 0, 1, 0, 0);
- }
- else
- {
- return configure_lhs_rhs_info(m, n, 1, 8, 4, 1, 8, 0, 1, 1, 0, 1);
- }
+ return configure_lhs_rhs_info(m, n, 1, 4, 8, 1, 32, 0, 0, 1, 0, 0);
}
}
else
{
- if(r_nk <= 1.0368f)
- {
- return configure_lhs_rhs_info(m, n, 1, 2, 16, 1, 32, 0, 0, 1, 0, 0);
- }
- else
- {
- return configure_lhs_rhs_info(m, n, 1, 2, 4, 1, 32, 0, 0, 1, 0, 0);
- }
+ return configure_lhs_rhs_info(m, n, 1, 2, 16, 1, 8, 0, 0, 1, 0, 0);
}
}
else
{
- if(workload <= 1422.4000f)
+ if(workload <= 1384.8000f)
{
- if(workload <= 704.0000f)
+ if(r_nk <= 0.8333f)
{
- return configure_lhs_rhs_info(m, n, 2, 2, 8, 1, 32, 0, 0, 1, 0, 0);
- }
- else
- {
- if(workload <= 1197.6000f)
+ if(r_mk <= 0.9119f)
{
- return configure_lhs_rhs_info(m, n, 2, 4, 8, 1, 8, 0, 1, 1, 0, 1);
+ return configure_lhs_rhs_info(m, n, 2, 2, 16, 1, 4, 0, 1, 0, 1, 1);
}
else
{
- if(workload <= 1241.6000f)
+ if(r_nk <= 0.1181f)
{
- return configure_lhs_rhs_info(m, n, 2, 8, 8, 1, 16, 0, 1, 1, 0, 0);
+ return configure_lhs_rhs_info(m, n, 2, 2, 8, 1, 32, 0, 0, 1, 0, 0);
}
else
{
- return configure_lhs_rhs_info(m, n, 2, 4, 8, 1, 8, 0, 1, 1, 0, 1);
+ return configure_lhs_rhs_info(m, n, 4, 4, 8, 1, 32, 0, 1, 1, 0, 0);
}
}
}
+ else
+ {
+ if(r_mk <= 1.0013f)
+ {
+ return configure_lhs_rhs_info(m, n, 4, 4, 8, 1, 32, 0, 1, 1, 0, 1);
+ }
+ else
+ {
+ return configure_lhs_rhs_info(m, n, 5, 4, 8, 1, 4, 0, 1, 1, 0, 1);
+ }
+ }
}
else
{
- if(workload <= 2769.6000f)
+ if(workload <= 11404.7998f)
{
- if(workload <= 1846.4000f)
+ if(r_mk <= 2.2884f)
{
- if(r_mn <= 2.4927f)
+ if(r_nk <= 0.9286f)
{
- return configure_lhs_rhs_info(m, n, 2, 8, 8, 1, 16, 0, 1, 1, 0, 0);
+ return configure_lhs_rhs_info(m, n, 4, 4, 8, 1, 4, 0, 1, 1, 0, 1);
}
else
{
- return configure_lhs_rhs_info(m, n, 4, 4, 8, 1, 32, 0, 1, 1, 0, 0);
+ return configure_lhs_rhs_info(m, n, 4, 4, 8, 1, 32, 0, 1, 1, 0, 1);
}
}
else
{
- if(r_mn <= 0.6261f)
- {
- return configure_lhs_rhs_info(m, n, 4, 4, 8, 1, 32, 0, 1, 1, 0, 0);
- }
- else
- {
- if(r_mk <= 3.4453f)
- {
- if(r_mn <= 1.4135f)
- {
- return configure_lhs_rhs_info(m, n, 2, 8, 8, 1, 16, 0, 1, 1, 0, 0);
- }
- else
- {
- return configure_lhs_rhs_info(m, n, 4, 4, 8, 1, 32, 0, 1, 1, 0, 0);
- }
- }
- else
- {
- return configure_lhs_rhs_info(m, n, 2, 8, 8, 1, 16, 0, 1, 1, 0, 0);
- }
- }
+ return configure_lhs_rhs_info(m, n, 5, 4, 8, 1, 4, 0, 1, 1, 0, 1);
}
}
else
{
- if(r_nk <= 0.0302f)
- {
- return configure_lhs_rhs_info(m, n, 2, 4, 8, 1, 8, 0, 1, 1, 0, 1);
- }
- else
+ if(r_nk <= 1.1926f)
{
- if(r_mk <= 181.3750f)
+ if(r_mn <= 1385.7917f)
{
- return configure_lhs_rhs_info(m, n, 4, 4, 8, 1, 32, 0, 1, 1, 0, 0);
+ return configure_lhs_rhs_info(m, n, 6, 4, 8, 1, 4, 0, 1, 1, 0, 1);
}
else
{
- if(workload <= 28035.2002f)
- {
- return configure_lhs_rhs_info(m, n, 2, 8, 8, 1, 16, 0, 1, 1, 0, 0);
- }
- else
- {
- if(r_mk <= 808.6667f)
- {
- return configure_lhs_rhs_info(m, n, 4, 4, 8, 1, 32, 0, 1, 1, 0, 0);
- }
- else
- {
- return configure_lhs_rhs_info(m, n, 2, 8, 8, 1, 16, 0, 1, 1, 0, 0);
- }
- }
+ return configure_lhs_rhs_info(m, n, 2, 8, 8, 1, 32, 0, 1, 1, 0, 0);
}
}
+ else
+ {
+ return configure_lhs_rhs_info(m, n, 6, 4, 8, 1, 32, 0, 1, 1, 0, 1);
+ }
}
}
}