aboutsummaryrefslogtreecommitdiff
path: root/src/core/CL/gemm/reshaped/CLGEMMDefaultConfigReshapedValhall.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/core/CL/gemm/reshaped/CLGEMMDefaultConfigReshapedValhall.cpp')
-rw-r--r--src/core/CL/gemm/reshaped/CLGEMMDefaultConfigReshapedValhall.cpp378
1 files changed, 341 insertions, 37 deletions
diff --git a/src/core/CL/gemm/reshaped/CLGEMMDefaultConfigReshapedValhall.cpp b/src/core/CL/gemm/reshaped/CLGEMMDefaultConfigReshapedValhall.cpp
index 13c139e201..b07092ab83 100644
--- a/src/core/CL/gemm/reshaped/CLGEMMDefaultConfigReshapedValhall.cpp
+++ b/src/core/CL/gemm/reshaped/CLGEMMDefaultConfigReshapedValhall.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2020 Arm Limited.
+ * Copyright (c) 2020-2021 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -28,7 +28,6 @@
#include "arm_compute/core/GPUTarget.h"
#include "src/core/CL/gemm/CLGEMMHelpers.h"
-#include <map>
#include <utility>
namespace arm_compute
@@ -44,30 +43,29 @@ std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> CLGEMMDefaultConfigReshapedValha
{
using ConfigurationFunctionExecutorPtr = std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> (CLGEMMDefaultConfigReshapedValhall::*)(unsigned int m, unsigned int n, unsigned int k, unsigned int b);
- // Configurations for Mali-G77
- static std::map<DataType, ConfigurationFunctionExecutorPtr> gemm_configs_G77 =
- {
- { DataType::F32, &CLGEMMDefaultConfigReshapedValhall::configure_G77_f32 },
- { DataType::F16, &CLGEMMDefaultConfigReshapedValhall::configure_G77_f16 },
- { DataType::QASYMM8, &CLGEMMDefaultConfigReshapedValhall::configure_G77_u8 },
- { DataType::QSYMM8, &CLGEMMDefaultConfigReshapedValhall::configure_G77_u8 },
- { DataType::QASYMM8_SIGNED, &CLGEMMDefaultConfigReshapedValhall::configure_G77_u8 },
- { DataType::QSYMM8_PER_CHANNEL, &CLGEMMDefaultConfigReshapedValhall::configure_G77_u8 }
- };
+ CLGEMMConfigArray<ConfigurationFunctionExecutorPtr> configs_G77(&CLGEMMDefaultConfigReshapedValhall::configure_G77_f32,
+ &CLGEMMDefaultConfigReshapedValhall::configure_G77_f16,
+ &CLGEMMDefaultConfigReshapedValhall::configure_G77_u8);
+
+ CLGEMMConfigArray<ConfigurationFunctionExecutorPtr> configs_G78(&CLGEMMDefaultConfigReshapedValhall::configure_G78_f32,
+ &CLGEMMDefaultConfigReshapedValhall::configure_G78_f16,
+ &CLGEMMDefaultConfigReshapedValhall::configure_G77_u8);
+
+ ConfigurationFunctionExecutorPtr func = nullptr;
switch(_target)
{
+ case GPUTarget::G78:
+ func = configs_G78.get_function(data_type);
+ break;
case GPUTarget::G77:
default:
- if(gemm_configs_G77.find(data_type) != gemm_configs_G77.end())
- {
- return (this->*gemm_configs_G77[data_type])(m, n, k, b);
- }
- else
- {
- ARM_COMPUTE_ERROR("Not supported data type");
- }
+ func = configs_G77.get_function(data_type);
+ break;
}
+
+ ARM_COMPUTE_ERROR_ON_MSG(func == nullptr, "Data type not support for GEMM");
+ return (this->*func)(m, n, k, b);
}
std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> CLGEMMDefaultConfigReshapedValhall::configure_G77_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b)
@@ -77,11 +75,11 @@ std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> CLGEMMDefaultConfigReshapedValha
if(n <= 4)
{
- return configure_lhs_rhs_info(m, n, 4, 2, 8, 16, 16, true, false, false, true);
+ return configure_lhs_rhs_info(m, n, 4, 2, 8, 16, 16, 1, 0, 0, 1);
}
else
{
- return configure_lhs_rhs_info(m, n, 5, 4, 4, 2, 16, false, true, false, true);
+ return configure_lhs_rhs_info(m, n, 5, 4, 4, 2, 16, 0, 1, 0, 1);
}
}
@@ -100,13 +98,13 @@ std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> CLGEMMDefaultConfigReshapedValha
GEMMLHSMatrixInfo lhs_info_img;
GEMMRHSMatrixInfo rhs_info_img;
- std::tie(lhs_info_buf, rhs_info_buf) = configure_lhs_rhs_info(m, n, 4, 4, 4, 4, 4, false, false, true, false, false);
+ std::tie(lhs_info_buf, rhs_info_buf) = configure_lhs_rhs_info(m, n, 4, 4, 4, 4, 4, 0, 0, 1, 0, 0);
if(r_mk <= 0.11824845522642136)
{
if(workload <= 880.0)
{
- return configure_lhs_rhs_info(m, n, 2, 4, 4, 1, 4, false, false, true, false, false);
+ return configure_lhs_rhs_info(m, n, 2, 4, 4, 1, 4, 0, 0, 1, 0, 0);
}
else
{
@@ -114,11 +112,11 @@ std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> CLGEMMDefaultConfigReshapedValha
{
if(workload <= 1726.4000244140625)
{
- return configure_lhs_rhs_info(m, n, 4, 4, 4, 2, 2, false, false, true, false, false);
+ return configure_lhs_rhs_info(m, n, 4, 4, 4, 2, 2, 0, 0, 1, 0, 0);
}
else
{
- std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 4, 4, 4, 2, 1, false, true, true, false, true);
+ std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 4, 4, 4, 2, 1, 0, 1, 1, 0, 1);
return select_lhs_rhs_info(std::make_pair(lhs_info_img, rhs_info_img),
std::make_pair(lhs_info_buf, rhs_info_buf),
@@ -129,11 +127,11 @@ std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> CLGEMMDefaultConfigReshapedValha
{
if(workload <= 1241.6000366210938)
{
- return configure_lhs_rhs_info(m, n, 2, 4, 4, 1, 4, false, false, true, false, false);
+ return configure_lhs_rhs_info(m, n, 2, 4, 4, 1, 4, 0, 0, 1, 0, 0);
}
else
{
- return configure_lhs_rhs_info(m, n, 4, 4, 4, 4, 4, false, false, true, false, false);
+ return configure_lhs_rhs_info(m, n, 4, 4, 4, 4, 4, 0, 0, 1, 0, 0);
}
}
}
@@ -146,7 +144,7 @@ std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> CLGEMMDefaultConfigReshapedValha
{
if(r_mn <= 2.545312523841858)
{
- std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 4, 4, 4, 2, 1, false, true, true, false, true);
+ std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 4, 4, 4, 2, 1, 0, 1, 1, 0, 1);
return select_lhs_rhs_info(std::make_pair(lhs_info_img, rhs_info_img),
std::make_pair(lhs_info_buf, rhs_info_buf),
@@ -154,14 +152,14 @@ std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> CLGEMMDefaultConfigReshapedValha
}
else
{
- return configure_lhs_rhs_info(m, n, 2, 4, 4, 1, 4, false, false, true, false, false);
+ return configure_lhs_rhs_info(m, n, 2, 4, 4, 1, 4, 0, 0, 1, 0, 0);
}
}
else
{
if(workload <= 2881.199951171875)
{
- std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 4, 4, 4, 4, 2, false, false, true, false, true);
+ std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 4, 4, 4, 4, 2, 0, 0, 1, 0, 1);
return select_lhs_rhs_info(std::make_pair(lhs_info_img, rhs_info_img),
std::make_pair(lhs_info_buf, rhs_info_buf),
@@ -169,7 +167,7 @@ std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> CLGEMMDefaultConfigReshapedValha
}
else
{
- std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 4, 4, 4, 2, 1, false, true, true, false, true);
+ std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 4, 4, 4, 2, 1, 0, 1, 1, 0, 1);
return select_lhs_rhs_info(std::make_pair(lhs_info_img, rhs_info_img),
std::make_pair(lhs_info_buf, rhs_info_buf),
@@ -183,7 +181,7 @@ std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> CLGEMMDefaultConfigReshapedValha
{
if(r_mn <= 6.010416746139526)
{
- std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 4, 4, 4, 2, 1, false, true, true, false, true);
+ std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 4, 4, 4, 2, 1, 0, 1, 1, 0, 1);
return select_lhs_rhs_info(std::make_pair(lhs_info_img, rhs_info_img),
std::make_pair(lhs_info_buf, rhs_info_buf),
@@ -191,7 +189,7 @@ std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> CLGEMMDefaultConfigReshapedValha
}
else
{
- std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 4, 4, 4, 2, 1, true, false, true, false, true);
+ std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 4, 4, 4, 2, 1, 1, 0, 1, 0, 1);
return select_lhs_rhs_info(std::make_pair(lhs_info_img, rhs_info_img),
std::make_pair(lhs_info_buf, rhs_info_buf),
@@ -200,7 +198,7 @@ std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> CLGEMMDefaultConfigReshapedValha
}
else
{
- std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 4, 4, 4, 2, 1, true, false, true, false, true);
+ std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 4, 4, 4, 2, 1, 1, 0, 1, 0, 1);
return select_lhs_rhs_info(std::make_pair(lhs_info_img, rhs_info_img),
std::make_pair(lhs_info_buf, rhs_info_buf),
@@ -210,6 +208,312 @@ std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> CLGEMMDefaultConfigReshapedValha
}
}
+std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> CLGEMMDefaultConfigReshapedValhall::configure_G78_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b)
+{
+ const float r_mn = static_cast<float>(m) / static_cast<float>(n);
+ const float r_mk = static_cast<float>(m) / static_cast<float>(k);
+ const float r_nk = static_cast<float>(n) / static_cast<float>(k);
+ const float workload = (static_cast<float>(m) * static_cast<float>(n) * static_cast<float>(b)) / 20.0f;
+
+ if(workload <= 1288.0000f)
+ {
+ if(workload <= 505.6000f)
+ {
+ if(r_mn <= 0.4466f)
+ {
+ if(r_nk <= 0.2384f)
+ {
+ return configure_lhs_rhs_info(m, n, 2, 4, 8, 4, 4, 0, 0, 1, 0, 1);
+ }
+ else
+ {
+ return configure_lhs_rhs_info(m, n, 2, 2, 4, 2, 2, 0, 0, 1, 0, 0);
+ }
+ }
+ else
+ {
+ return configure_lhs_rhs_info(m, n, 2, 2, 4, 2, 2, 0, 0, 1, 0, 0);
+ }
+ }
+ else
+ {
+ if(r_mn <= 0.2250f)
+ {
+ if(r_mn <= 0.1599f)
+ {
+ return configure_lhs_rhs_info(m, n, 2, 4, 8, 4, 4, 0, 0, 1, 0, 1);
+ }
+ else
+ {
+ return configure_lhs_rhs_info(m, n, 4, 4, 4, 2, 2, 0, 0, 1, 0, 1);
+ }
+ }
+ else
+ {
+ if(r_mk <= 0.7609f)
+ {
+ if(r_mn <= 2.5453f)
+ {
+ if(workload <= 1089.6000f)
+ {
+ return configure_lhs_rhs_info(m, n, 2, 4, 8, 4, 4, 0, 0, 1, 0, 1);
+ }
+ else
+ {
+ return configure_lhs_rhs_info(m, n, 2, 4, 8, 2, 4, 0, 0, 1, 0, 1);
+ }
+ }
+ else
+ {
+ return configure_lhs_rhs_info(m, n, 2, 4, 16, 4, 4, 0, 0, 1, 0, 1);
+ }
+ }
+ else
+ {
+ return configure_lhs_rhs_info(m, n, 2, 4, 8, 4, 4, 0, 0, 1, 0, 1);
+ }
+ }
+ }
+ }
+ else
+ {
+ if(workload <= 5434.4001f)
+ {
+ if(workload <= 1603.2000f)
+ {
+ return configure_lhs_rhs_info(m, n, 4, 4, 4, 2, 2, 0, 0, 1, 0, 1);
+ }
+ else
+ {
+ if(r_nk <= 0.6192f)
+ {
+ if(r_mn <= 16.1016f)
+ {
+ return configure_lhs_rhs_info(m, n, 4, 4, 4, 2, 2, 0, 0, 1, 0, 1);
+ }
+ else
+ {
+ if(workload <= 2750.0000f)
+ {
+ return configure_lhs_rhs_info(m, n, 4, 4, 4, 2, 2, 0, 0, 1, 0, 1);
+ }
+ else
+ {
+ if(r_mk <= 6.3151f)
+ {
+ return configure_lhs_rhs_info(m, n, 4, 4, 4, 4, 4, 0, 0, 0, 1, 1);
+ }
+ else
+ {
+ return configure_lhs_rhs_info(m, n, 4, 4, 4, 2, 2, 0, 0, 1, 0, 1);
+ }
+ }
+ }
+ }
+ else
+ {
+ if(r_mk <= 0.0387f)
+ {
+ return configure_lhs_rhs_info(m, n, 4, 4, 4, 4, 4, 0, 0, 1, 0, 1);
+ }
+ else
+ {
+ if(r_mk <= 2.5859f)
+ {
+ if(r_mk <= 0.2734f)
+ {
+ return configure_lhs_rhs_info(m, n, 4, 4, 4, 4, 4, 0, 0, 1, 0, 1);
+ }
+ else
+ {
+ return configure_lhs_rhs_info(m, n, 4, 4, 4, 2, 2, 0, 0, 1, 0, 1);
+ }
+ }
+ else
+ {
+ return configure_lhs_rhs_info(m, n, 4, 4, 4, 2, 2, 0, 0, 1, 0, 1);
+ }
+ }
+ }
+ }
+ }
+ else
+ {
+ if(r_mk <= 25.7500f)
+ {
+ if(r_mk <= 0.3615f)
+ {
+ if(r_mn <= 0.0913f)
+ {
+ if(r_mk <= 0.0683f)
+ {
+ return configure_lhs_rhs_info(m, n, 8, 4, 4, 4, 2, 0, 0, 1, 0, 1);
+ }
+ else
+ {
+ return configure_lhs_rhs_info(m, n, 2, 4, 8, 4, 4, 0, 0, 1, 0, 1);
+ }
+ }
+ else
+ {
+ return configure_lhs_rhs_info(m, n, 8, 4, 4, 2, 2, 0, 0, 1, 0, 1);
+ }
+ }
+ else
+ {
+ if(workload <= 11174.3999f)
+ {
+ if(r_mk <= 0.8047f)
+ {
+ return configure_lhs_rhs_info(m, n, 8, 4, 4, 2, 2, 0, 0, 1, 0, 1);
+ }
+ else
+ {
+ if(workload <= 7185.5999f)
+ {
+ return configure_lhs_rhs_info(m, n, 4, 4, 4, 4, 4, 0, 0, 1, 0, 1);
+ }
+ else
+ {
+ return configure_lhs_rhs_info(m, n, 8, 4, 4, 4, 2, 0, 0, 1, 0, 1);
+ }
+ }
+ }
+ else
+ {
+ if(workload <= 17917.5000f)
+ {
+ if(r_mk <= 1.5078f)
+ {
+ return configure_lhs_rhs_info(m, n, 4, 4, 4, 2, 2, 0, 0, 1, 0, 1);
+ }
+ else
+ {
+ return configure_lhs_rhs_info(m, n, 4, 4, 4, 4, 4, 0, 0, 1, 0, 1);
+ }
+ }
+ else
+ {
+ if(workload <= 34449.6016f)
+ {
+ return configure_lhs_rhs_info(m, n, 4, 4, 4, 2, 2, 0, 0, 1, 0, 1);
+ }
+ else
+ {
+ return configure_lhs_rhs_info(m, n, 8, 4, 4, 2, 4, 0, 0, 1, 0, 1);
+ }
+ }
+ }
+ }
+ }
+ else
+ {
+ if(r_mk <= 331.1111f)
+ {
+ if(workload <= 53397.5996f)
+ {
+ if(r_mn <= 57.8063f)
+ {
+ return configure_lhs_rhs_info(m, n, 4, 4, 4, 2, 2, 0, 0, 1, 0, 1);
+ }
+ else
+ {
+ return configure_lhs_rhs_info(m, n, 4, 4, 4, 4, 4, 0, 0, 0, 1, 1);
+ }
+ }
+ else
+ {
+ if(r_nk <= 0.9211f)
+ {
+ return configure_lhs_rhs_info(m, n, 8, 4, 4, 4, 2, 0, 0, 1, 0, 1);
+ }
+ else
+ {
+ return configure_lhs_rhs_info(m, n, 4, 4, 4, 4, 4, 0, 0, 0, 1, 1);
+ }
+ }
+ }
+ else
+ {
+ if(workload <= 38070.4004f)
+ {
+ return configure_lhs_rhs_info(m, n, 4, 4, 4, 4, 4, 0, 0, 0, 1, 1);
+ }
+ else
+ {
+ return configure_lhs_rhs_info(m, n, 4, 4, 4, 2, 2, 0, 0, 1, 0, 1);
+ }
+ }
+ }
+ }
+ }
+}
+
+std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> CLGEMMDefaultConfigReshapedValhall::configure_G78_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b)
+{
+ const float r_mn = static_cast<float>(m) / static_cast<float>(n);
+ const float r_nk = static_cast<float>(n) / static_cast<float>(k);
+ const float workload = (static_cast<float>(m) * static_cast<float>(n) * static_cast<float>(b)) / 20.0f;
+
+ if(workload <= 801.6000f)
+ {
+ return configure_lhs_rhs_info(m, n, 8, 4, 4, 1, 1, 0, 0, 1, 0, 1);
+ }
+ else
+ {
+ if(r_mn <= 0.1211f)
+ {
+ if(workload <= 3296.0000f)
+ {
+ return configure_lhs_rhs_info(m, n, 8, 4, 4, 2, 2, 0, 0, 1, 0, 1);
+ }
+ else
+ {
+ if(r_nk <= 1.0625f)
+ {
+ return configure_lhs_rhs_info(m, n, 8, 4, 4, 2, 2, 0, 0, 1, 0, 1);
+ }
+ else
+ {
+ return configure_lhs_rhs_info(m, n, 8, 4, 4, 2, 4, 0, 0, 1, 0, 1);
+ }
+ }
+ }
+ else
+ {
+ if(workload <= 5068.8000f)
+ {
+ return configure_lhs_rhs_info(m, n, 8, 4, 4, 1, 1, 0, 0, 1, 0, 1);
+ }
+ else
+ {
+ if(r_nk <= 0.2361f)
+ {
+ if(workload <= 12630.0000f)
+ {
+ return configure_lhs_rhs_info(m, n, 8, 4, 4, 1, 1, 0, 0, 1, 0, 1);
+ }
+ else
+ {
+ return configure_lhs_rhs_info(m, n, 8, 4, 4, 2, 1, 0, 0, 1, 0, 1);
+ }
+ }
+ else
+ {
+ if(workload <= 178790.3984f)
+ {
+ return configure_lhs_rhs_info(m, n, 8, 4, 4, 2, 2, 0, 0, 1, 0, 1);
+ }
+ else
+ {
+ return configure_lhs_rhs_info(m, n, 8, 4, 4, 1, 1, 0, 0, 1, 0, 1);
+ }
+ }
+ }
+ }
+ }
+}
+
std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> CLGEMMDefaultConfigReshapedValhall::configure_G77_u8(unsigned int m, unsigned int n, unsigned int k, unsigned int b)
{
ARM_COMPUTE_UNUSED(k);
@@ -217,11 +521,11 @@ std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> CLGEMMDefaultConfigReshapedValha
if(n <= 4)
{
- return configure_lhs_rhs_info(m, n, 4, 2, 16, 4, 1, false, false, false, true);
+ return configure_lhs_rhs_info(m, n, 4, 2, 16, 4, 1, 0, 0, 0, 1);
}
else
{
- return configure_lhs_rhs_info(m, n, 4, 4, 16, 2, 2, false, true, false, true);
+ return configure_lhs_rhs_info(m, n, 4, 4, 16, 2, 2, 0, 1, 0, 1);
}
}
} // namespace cl_gemm