aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorGian Marco Iodice <gianmarco.iodice@arm.com>2020-11-23 16:10:27 +0000
committerGeorgios Pinitas <georgios.pinitas@arm.com>2020-11-23 18:17:36 +0000
commit8919a1a849e425aefcd09c5db5f6f9f2e403d4e9 (patch)
treeefba31c6a28143db7724dac008e19b1636bd7b16
parent5fa963fbbc00c716e120287051747b144e2d784c (diff)
downloadComputeLibrary-8919a1a849e425aefcd09c5db5f6f9f2e403d4e9.tar.gz
COMPMID-4018: Fix heuristic fallback for CLGEMMReshapedRHSOnly for
Mali-G52 - Missing fallback in case of export to cl_image Change-Id: I5bb3013fd1350628f16e4709c4bb31999fece22d Signed-off-by: Gian Marco Iodice <gianmarco.iodice@arm.com> Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/4531 Reviewed-by: Pablo Marquez Tello <pablo.tello@arm.com> Reviewed-by: Michele Di Giorgio <michele.digiorgio@arm.com> Tested-by: Arm Jenkins <bsgcomp@arm.com> Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
-rw-r--r--src/core/CL/gemm/reshaped/CLGEMMReshapedKernelConfigurationBifrost.cpp6
-rw-r--r--src/core/CL/gemm/reshaped_only_rhs/CLGEMMReshapedOnlyRHSKernelConfigurationBifrost.cpp29
-rw-r--r--src/runtime/CL/gemm/CLGEMMKernelSelectionBifrost.cpp125
3 files changed, 87 insertions, 73 deletions
diff --git a/src/core/CL/gemm/reshaped/CLGEMMReshapedKernelConfigurationBifrost.cpp b/src/core/CL/gemm/reshaped/CLGEMMReshapedKernelConfigurationBifrost.cpp
index 59a2a82edf..46eeff3524 100644
--- a/src/core/CL/gemm/reshaped/CLGEMMReshapedKernelConfigurationBifrost.cpp
+++ b/src/core/CL/gemm/reshaped/CLGEMMReshapedKernelConfigurationBifrost.cpp
@@ -269,13 +269,13 @@ std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> CLGEMMReshapedKernelConfiguratio
const float workload = (static_cast<float>(m) * static_cast<float>(n) * static_cast<float>(b)) / 20.0f;
- if(workload <= 232.8000f)
+ if(workload <= 323.4000f)
{
- return configure_lhs_rhs_info(m, n, 2, 4, 4, 4, 4, true, true, true, false, false);
+ return configure_lhs_rhs_info(m, n, 2, 2, 8, 4, 8, false, false, false, true, false);
}
else
{
- return configure_lhs_rhs_info(m, n, 4, 4, 4, 4, 4, true, true, true, false, false);
+ return configure_lhs_rhs_info(m, n, 4, 8, 4, 2, 2, true, true, true, false, false);
}
}
diff --git a/src/core/CL/gemm/reshaped_only_rhs/CLGEMMReshapedOnlyRHSKernelConfigurationBifrost.cpp b/src/core/CL/gemm/reshaped_only_rhs/CLGEMMReshapedOnlyRHSKernelConfigurationBifrost.cpp
index a2c1ed2c8e..d5b76d8eaf 100644
--- a/src/core/CL/gemm/reshaped_only_rhs/CLGEMMReshapedOnlyRHSKernelConfigurationBifrost.cpp
+++ b/src/core/CL/gemm/reshaped_only_rhs/CLGEMMReshapedOnlyRHSKernelConfigurationBifrost.cpp
@@ -322,8 +322,15 @@ std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> CLGEMMReshapedOnlyRHSKernelConfi
const float r_mk = static_cast<float>(m) / static_cast<float>(k);
const float r_nk = static_cast<float>(n) / static_cast<float>(k);
+ GEMMLHSMatrixInfo lhs_info_buf;
+ GEMMRHSMatrixInfo rhs_info_buf;
+ GEMMLHSMatrixInfo lhs_info_img;
+ GEMMRHSMatrixInfo rhs_info_img;
+
if(m == 1)
{
+ std::tie(lhs_info_buf, rhs_info_buf) = configure_lhs_rhs_info(m, n, 1, 4, 16, 1, 16, false, true, false, false, false);
+
if(r_mk <= 0.0026f)
{
if(r_nk <= 0.4664f)
@@ -332,7 +339,10 @@ std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> CLGEMMReshapedOnlyRHSKernelConfi
}
else
{
- return configure_lhs_rhs_info(m, n, 1, 4, 16, 1, 16, false, true, false, false, true);
+ std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 1, 4, 16, 1, 16, false, true, false, false, true);
+ return select_lhs_rhs_info(std::make_pair(lhs_info_img, rhs_info_img),
+ std::make_pair(lhs_info_buf, rhs_info_buf),
+ n, k, b, DataType::F16);
}
}
else
@@ -343,12 +353,17 @@ std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> CLGEMMReshapedOnlyRHSKernelConfi
}
else
{
- return configure_lhs_rhs_info(m, n, 1, 4, 16, 1, 16, false, true, false, false, true);
+ std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 1, 4, 16, 1, 16, false, true, false, false, true);
+ return select_lhs_rhs_info(std::make_pair(lhs_info_img, rhs_info_img),
+ std::make_pair(lhs_info_buf, rhs_info_buf),
+ n, k, b, DataType::F16);
}
}
}
else
{
+ std::tie(lhs_info_buf, rhs_info_buf) = configure_lhs_rhs_info(m, n, 5, 8, 4, 1, 2, false, false, false, false, false);
+
if(workload <= 362.6000f)
{
return configure_lhs_rhs_info(m, n, 2, 2, 8, 1, 16, false, false, false, true, false);
@@ -359,7 +374,10 @@ std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> CLGEMMReshapedOnlyRHSKernelConfi
{
if(workload <= 708.8000f)
{
- return configure_lhs_rhs_info(m, n, 5, 4, 4, 1, 2, false, false, false, false, true);
+ std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 5, 4, 4, 1, 2, false, false, false, false, true);
+ return select_lhs_rhs_info(std::make_pair(lhs_info_img, rhs_info_img),
+ std::make_pair(lhs_info_buf, rhs_info_buf),
+ n, k, b, DataType::F16);
}
else
{
@@ -374,7 +392,10 @@ std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> CLGEMMReshapedOnlyRHSKernelConfi
}
else
{
- return configure_lhs_rhs_info(m, n, 5, 4, 4, 1, 2, false, false, false, false, true);
+ std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 5, 4, 4, 1, 2, false, false, false, false, true);
+ return select_lhs_rhs_info(std::make_pair(lhs_info_img, rhs_info_img),
+ std::make_pair(lhs_info_buf, rhs_info_buf),
+ n, k, b, DataType::F16);
}
}
}
diff --git a/src/runtime/CL/gemm/CLGEMMKernelSelectionBifrost.cpp b/src/runtime/CL/gemm/CLGEMMKernelSelectionBifrost.cpp
index 46d07fffba..0bda38e5e9 100644
--- a/src/runtime/CL/gemm/CLGEMMKernelSelectionBifrost.cpp
+++ b/src/runtime/CL/gemm/CLGEMMKernelSelectionBifrost.cpp
@@ -445,8 +445,6 @@ CLGEMMKernelType CLGEMMKernelSelectionBifrost::g76_f16(unsigned int m, unsigned
CLGEMMKernelType CLGEMMKernelSelectionBifrost::g52_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b, bool is_rhs_constant)
{
- ARM_COMPUTE_UNUSED(b);
-
if (!is_rhs_constant)
{
return CLGEMMKernelType::NATIVE_V1;
@@ -457,26 +455,25 @@ CLGEMMKernelType CLGEMMKernelSelectionBifrost::g52_f16(unsigned int m, unsigned
return CLGEMMKernelType::RESHAPED_ONLY_RHS;
}
- const float r_mn = static_cast<float>(m) / static_cast<float>(n);
- const float r_mk = static_cast<float>(m) / static_cast<float>(k);
- const float r_nk = static_cast<float>(n) / static_cast<float>(k);
- const float r_mnk = static_cast<float>(m) / (static_cast<float>(n) * static_cast<float>(k));
-
- if(r_mn <= 22.9200f)
+ if(n <= 127.0000f)
{
- if(r_mk <= 0.0157f)
+ if(n <= 63.5000f)
{
return CLGEMMKernelType::RESHAPED_ONLY_RHS;
}
else
{
- if(r_mnk <= 7809.3750f)
+ if(m <= 3616.0000f)
{
- if(r_mnk <= 101.7937f)
+ if(b <= 18.5000f)
{
- if(r_mn <= 0.4594f)
+ if(m <= 2970.5000f)
{
- if(r_mk <= 0.0557f)
+ return CLGEMMKernelType::RESHAPED_ONLY_RHS;
+ }
+ else
+ {
+ if(k <= 104.0000f)
{
return CLGEMMKernelType::RESHAPED_ONLY_RHS;
}
@@ -485,80 +482,76 @@ CLGEMMKernelType CLGEMMKernelSelectionBifrost::g52_f16(unsigned int m, unsigned
return CLGEMMKernelType::RESHAPED;
}
}
- else
- {
- return CLGEMMKernelType::RESHAPED_ONLY_RHS;
- }
}
else
{
- if(r_nk <= 0.4396f)
+ return CLGEMMKernelType::RESHAPED;
+ }
+ }
+ else
+ {
+ return CLGEMMKernelType::RESHAPED;
+ }
+ }
+ }
+ else
+ {
+ if(m <= 12.5000f)
+ {
+ return CLGEMMKernelType::RESHAPED_ONLY_RHS;
+ }
+ else
+ {
+ if(k <= 104.0000f)
+ {
+ if(b <= 18.5000f)
+ {
+ if(m <= 490.0000f)
{
- if(r_mn <= 1.5182f)
+ if(n <= 272.0000f)
{
- if(r_mnk <= 1709.9167f)
- {
- return CLGEMMKernelType::RESHAPED;
- }
- else
- {
- return CLGEMMKernelType::RESHAPED_ONLY_RHS;
- }
+ return CLGEMMKernelType::RESHAPED_ONLY_RHS;
}
else
{
- if(r_mnk <= 1330.6000f)
- {
- return CLGEMMKernelType::RESHAPED_ONLY_RHS;
- }
- else
- {
- return CLGEMMKernelType::RESHAPED;
- }
+ return CLGEMMKernelType::RESHAPED;
}
}
else
{
- if(r_mn <= 2.5896f)
+ return CLGEMMKernelType::RESHAPED;
+ }
+ }
+ else
+ {
+ return CLGEMMKernelType::RESHAPED;
+ }
+ }
+ else
+ {
+ if(m <= 226.0000f)
+ {
+ if(n <= 140.0000f)
+ {
+ if(m <= 179.5000f)
{
return CLGEMMKernelType::RESHAPED;
}
else
{
- if(r_mnk <= 326.6667f)
- {
- return CLGEMMKernelType::RESHAPED_ONLY_RHS;
- }
- else
- {
- return CLGEMMKernelType::RESHAPED;
- }
+ return CLGEMMKernelType::RESHAPED_ONLY_RHS;
}
}
+ else
+ {
+ return CLGEMMKernelType::RESHAPED;
+ }
+ }
+ else
+ {
+ return CLGEMMKernelType::RESHAPED;
}
}
- else
- {
- return CLGEMMKernelType::RESHAPED_ONLY_RHS;
- }
- }
- }
- else
- {
- if(r_mn <= 86.7578f)
- {
- if(r_mnk <= 11231.6406f)
- {
- return CLGEMMKernelType::RESHAPED_ONLY_RHS;
- }
- else
- {
- return CLGEMMKernelType::RESHAPED;
- }
- }
- else
- {
- return CLGEMMKernelType::RESHAPED_ONLY_RHS;
}
}
}