aboutsummaryrefslogtreecommitdiff
path: root/src/gpu/cl/kernels/gemm/reshaped/ClGemmDefaultConfigReshapedValhall.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/gpu/cl/kernels/gemm/reshaped/ClGemmDefaultConfigReshapedValhall.cpp')
-rw-r--r--src/gpu/cl/kernels/gemm/reshaped/ClGemmDefaultConfigReshapedValhall.cpp168
1 files changed, 84 insertions, 84 deletions
diff --git a/src/gpu/cl/kernels/gemm/reshaped/ClGemmDefaultConfigReshapedValhall.cpp b/src/gpu/cl/kernels/gemm/reshaped/ClGemmDefaultConfigReshapedValhall.cpp
index 58d0873b86..70b324eb5a 100644
--- a/src/gpu/cl/kernels/gemm/reshaped/ClGemmDefaultConfigReshapedValhall.cpp
+++ b/src/gpu/cl/kernels/gemm/reshaped/ClGemmDefaultConfigReshapedValhall.cpp
@@ -26,6 +26,7 @@
#include "arm_compute/core/CL/CLHelpers.h"
#include "arm_compute/core/CL/CLKernelLibrary.h"
#include "arm_compute/core/GPUTarget.h"
+
#include "src/gpu/cl/kernels/gemm/ClGemmHelpers.h"
#include <utility>
@@ -38,26 +39,27 @@ namespace kernels
{
namespace gemm
{
-ClGemmDefaultConfigReshapedValhall::ClGemmDefaultConfigReshapedValhall(GPUTarget gpu)
- : IClGemmKernelConfig(gpu)
+ClGemmDefaultConfigReshapedValhall::ClGemmDefaultConfigReshapedValhall(GPUTarget gpu) : IClGemmKernelConfig(gpu)
{
}
-std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedValhall::configure(unsigned int m, unsigned int n, unsigned int k, unsigned int b, DataType data_type)
+std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedValhall::configure(
+ unsigned int m, unsigned int n, unsigned int k, unsigned int b, DataType data_type)
{
- using ConfigurationFunctionExecutorPtr = std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> (ClGemmDefaultConfigReshapedValhall::*)(unsigned int m, unsigned int n, unsigned int k, unsigned int b);
+ using ConfigurationFunctionExecutorPtr = std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> (
+ ClGemmDefaultConfigReshapedValhall::*)(unsigned int m, unsigned int n, unsigned int k, unsigned int b);
- CLGEMMConfigArray<ConfigurationFunctionExecutorPtr> configs_G77(&ClGemmDefaultConfigReshapedValhall::configure_G77_f32,
- &ClGemmDefaultConfigReshapedValhall::configure_G77_f16,
- &ClGemmDefaultConfigReshapedValhall::configure_G77_u8);
+ CLGEMMConfigArray<ConfigurationFunctionExecutorPtr> configs_G77(
+ &ClGemmDefaultConfigReshapedValhall::configure_G77_f32, &ClGemmDefaultConfigReshapedValhall::configure_G77_f16,
+ &ClGemmDefaultConfigReshapedValhall::configure_G77_u8);
- CLGEMMConfigArray<ConfigurationFunctionExecutorPtr> configs_G78(&ClGemmDefaultConfigReshapedValhall::configure_G78_f32,
- &ClGemmDefaultConfigReshapedValhall::configure_G78_f16,
- &ClGemmDefaultConfigReshapedValhall::configure_G77_u8);
+ CLGEMMConfigArray<ConfigurationFunctionExecutorPtr> configs_G78(
+ &ClGemmDefaultConfigReshapedValhall::configure_G78_f32, &ClGemmDefaultConfigReshapedValhall::configure_G78_f16,
+ &ClGemmDefaultConfigReshapedValhall::configure_G77_u8);
ConfigurationFunctionExecutorPtr func = nullptr;
- switch(_target)
+ switch (_target)
{
case GPUTarget::G78:
func = configs_G78.get_function(data_type);
@@ -72,12 +74,13 @@ std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedValha
return (this->*func)(m, n, k, b);
}
-std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedValhall::configure_G77_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b)
+std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo>
+ClGemmDefaultConfigReshapedValhall::configure_G77_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b)
{
ARM_COMPUTE_UNUSED(k);
ARM_COMPUTE_UNUSED(b);
- if(n <= 4)
+ if (n <= 4)
{
return configure_lhs_rhs_info(m, n, 4, 2, 8, 16, 16, 1, 0, 0, 1);
}
@@ -87,7 +90,8 @@ std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedValha
}
}
-std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedValhall::configure_G77_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b)
+std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo>
+ClGemmDefaultConfigReshapedValhall::configure_G77_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b)
{
ARM_COMPUTE_UNUSED(k);
ARM_COMPUTE_UNUSED(b);
@@ -104,17 +108,17 @@ std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedValha
std::tie(lhs_info_buf, rhs_info_buf) = configure_lhs_rhs_info(m, n, 4, 4, 4, 4, 4, 0, 0, 1, 0, 0);
- if(r_mk <= 0.11824845522642136)
+ if (r_mk <= 0.11824845522642136)
{
- if(workload <= 880.0)
+ if (workload <= 880.0)
{
return configure_lhs_rhs_info(m, n, 2, 4, 4, 1, 4, 0, 0, 1, 0, 0);
}
else
{
- if(r_nk <= 0.42521367967128754)
+ if (r_nk <= 0.42521367967128754)
{
- if(workload <= 1726.4000244140625)
+ if (workload <= 1726.4000244140625)
{
return configure_lhs_rhs_info(m, n, 4, 4, 4, 2, 2, 0, 0, 1, 0, 0);
}
@@ -123,13 +127,12 @@ std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedValha
std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 4, 4, 4, 2, 1, 0, 1, 1, 0, 1);
return select_lhs_rhs_info(std::make_pair(lhs_info_img, rhs_info_img),
- std::make_pair(lhs_info_buf, rhs_info_buf),
- n, k, b, DataType::F16);
+ std::make_pair(lhs_info_buf, rhs_info_buf), n, k, b, DataType::F16);
}
}
else
{
- if(workload <= 1241.6000366210938)
+ if (workload <= 1241.6000366210938)
{
return configure_lhs_rhs_info(m, n, 2, 4, 4, 1, 4, 0, 0, 1, 0, 0);
}
@@ -142,17 +145,16 @@ std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedValha
}
else
{
- if(workload <= 11404.7998046875)
+ if (workload <= 11404.7998046875)
{
- if(r_mk <= 1.0126488208770752)
+ if (r_mk <= 1.0126488208770752)
{
- if(r_mn <= 2.545312523841858)
+ if (r_mn <= 2.545312523841858)
{
std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 4, 4, 4, 2, 1, 0, 1, 1, 0, 1);
return select_lhs_rhs_info(std::make_pair(lhs_info_img, rhs_info_img),
- std::make_pair(lhs_info_buf, rhs_info_buf),
- n, k, b, DataType::F16);
+ std::make_pair(lhs_info_buf, rhs_info_buf), n, k, b, DataType::F16);
}
else
{
@@ -161,43 +163,39 @@ std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedValha
}
else
{
- if(workload <= 2881.199951171875)
+ if (workload <= 2881.199951171875)
{
std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 4, 4, 4, 4, 2, 0, 0, 1, 0, 1);
return select_lhs_rhs_info(std::make_pair(lhs_info_img, rhs_info_img),
- std::make_pair(lhs_info_buf, rhs_info_buf),
- n, k, b, DataType::F16);
+ std::make_pair(lhs_info_buf, rhs_info_buf), n, k, b, DataType::F16);
}
else
{
std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 4, 4, 4, 2, 1, 0, 1, 1, 0, 1);
return select_lhs_rhs_info(std::make_pair(lhs_info_img, rhs_info_img),
- std::make_pair(lhs_info_buf, rhs_info_buf),
- n, k, b, DataType::F16);
+ std::make_pair(lhs_info_buf, rhs_info_buf), n, k, b, DataType::F16);
}
}
}
else
{
- if(r_nk <= 0.5765306055545807)
+ if (r_nk <= 0.5765306055545807)
{
- if(r_mn <= 6.010416746139526)
+ if (r_mn <= 6.010416746139526)
{
std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 4, 4, 4, 2, 1, 0, 1, 1, 0, 1);
return select_lhs_rhs_info(std::make_pair(lhs_info_img, rhs_info_img),
- std::make_pair(lhs_info_buf, rhs_info_buf),
- n, k, b, DataType::F16);
+ std::make_pair(lhs_info_buf, rhs_info_buf), n, k, b, DataType::F16);
}
else
{
std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 4, 4, 4, 2, 1, 1, 0, 1, 0, 1);
return select_lhs_rhs_info(std::make_pair(lhs_info_img, rhs_info_img),
- std::make_pair(lhs_info_buf, rhs_info_buf),
- n, k, b, DataType::F16);
+ std::make_pair(lhs_info_buf, rhs_info_buf), n, k, b, DataType::F16);
}
}
else
@@ -205,27 +203,27 @@ std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedValha
std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 4, 4, 4, 2, 1, 1, 0, 1, 0, 1);
return select_lhs_rhs_info(std::make_pair(lhs_info_img, rhs_info_img),
- std::make_pair(lhs_info_buf, rhs_info_buf),
- n, k, b, DataType::F16);
+ std::make_pair(lhs_info_buf, rhs_info_buf), n, k, b, DataType::F16);
}
}
}
}
-std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedValhall::configure_G78_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b)
+std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo>
+ClGemmDefaultConfigReshapedValhall::configure_G78_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b)
{
const float r_mn = static_cast<float>(m) / static_cast<float>(n);
const float r_mk = static_cast<float>(m) / static_cast<float>(k);
const float r_nk = static_cast<float>(n) / static_cast<float>(k);
const float workload = (static_cast<float>(m) * static_cast<float>(n) * static_cast<float>(b)) / 20.0f;
- if(workload <= 1288.0000f)
+ if (workload <= 1288.0000f)
{
- if(workload <= 505.6000f)
+ if (workload <= 505.6000f)
{
- if(r_mn <= 0.4466f)
+ if (r_mn <= 0.4466f)
{
- if(r_nk <= 0.2384f)
+ if (r_nk <= 0.2384f)
{
return configure_lhs_rhs_info(m, n, 2, 4, 8, 4, 4, 0, 0, 1, 0, 1);
}
@@ -241,9 +239,9 @@ std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedValha
}
else
{
- if(r_mn <= 0.2250f)
+ if (r_mn <= 0.2250f)
{
- if(r_mn <= 0.1599f)
+ if (r_mn <= 0.1599f)
{
return configure_lhs_rhs_info(m, n, 2, 4, 8, 4, 4, 0, 0, 1, 0, 1);
}
@@ -254,11 +252,11 @@ std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedValha
}
else
{
- if(r_mk <= 0.7609f)
+ if (r_mk <= 0.7609f)
{
- if(r_mn <= 2.5453f)
+ if (r_mn <= 2.5453f)
{
- if(workload <= 1089.6000f)
+ if (workload <= 1089.6000f)
{
return configure_lhs_rhs_info(m, n, 2, 4, 8, 4, 4, 0, 0, 1, 0, 1);
}
@@ -281,29 +279,29 @@ std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedValha
}
else
{
- if(workload <= 5434.4001f)
+ if (workload <= 5434.4001f)
{
- if(workload <= 1603.2000f)
+ if (workload <= 1603.2000f)
{
return configure_lhs_rhs_info(m, n, 4, 4, 4, 2, 2, 0, 0, 1, 0, 1);
}
else
{
- if(r_nk <= 0.6192f)
+ if (r_nk <= 0.6192f)
{
- if(r_mn <= 16.1016f)
+ if (r_mn <= 16.1016f)
{
return configure_lhs_rhs_info(m, n, 4, 4, 4, 2, 2, 0, 0, 1, 0, 1);
}
else
{
- if(workload <= 2750.0000f)
+ if (workload <= 2750.0000f)
{
return configure_lhs_rhs_info(m, n, 4, 4, 4, 2, 2, 0, 0, 1, 0, 1);
}
else
{
- if(r_mk <= 6.3151f)
+ if (r_mk <= 6.3151f)
{
return configure_lhs_rhs_info(m, n, 4, 4, 4, 4, 4, 0, 0, 0, 1, 1);
}
@@ -316,15 +314,15 @@ std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedValha
}
else
{
- if(r_mk <= 0.0387f)
+ if (r_mk <= 0.0387f)
{
return configure_lhs_rhs_info(m, n, 4, 4, 4, 4, 4, 0, 0, 1, 0, 1);
}
else
{
- if(r_mk <= 2.5859f)
+ if (r_mk <= 2.5859f)
{
- if(r_mk <= 0.2734f)
+ if (r_mk <= 0.2734f)
{
return configure_lhs_rhs_info(m, n, 4, 4, 4, 4, 4, 0, 0, 1, 0, 1);
}
@@ -343,13 +341,13 @@ std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedValha
}
else
{
- if(r_mk <= 25.7500f)
+ if (r_mk <= 25.7500f)
{
- if(r_mk <= 0.3615f)
+ if (r_mk <= 0.3615f)
{
- if(r_mn <= 0.0913f)
+ if (r_mn <= 0.0913f)
{
- if(r_mk <= 0.0683f)
+ if (r_mk <= 0.0683f)
{
return configure_lhs_rhs_info(m, n, 8, 4, 4, 4, 2, 0, 0, 1, 0, 1);
}
@@ -365,15 +363,15 @@ std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedValha
}
else
{
- if(workload <= 11174.3999f)
+ if (workload <= 11174.3999f)
{
- if(r_mk <= 0.8047f)
+ if (r_mk <= 0.8047f)
{
return configure_lhs_rhs_info(m, n, 8, 4, 4, 2, 2, 0, 0, 1, 0, 1);
}
else
{
- if(workload <= 7185.5999f)
+ if (workload <= 7185.5999f)
{
return configure_lhs_rhs_info(m, n, 4, 4, 4, 4, 4, 0, 0, 1, 0, 1);
}
@@ -385,9 +383,9 @@ std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedValha
}
else
{
- if(workload <= 17917.5000f)
+ if (workload <= 17917.5000f)
{
- if(r_mk <= 1.5078f)
+ if (r_mk <= 1.5078f)
{
return configure_lhs_rhs_info(m, n, 4, 4, 4, 2, 2, 0, 0, 1, 0, 1);
}
@@ -398,7 +396,7 @@ std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedValha
}
else
{
- if(workload <= 34449.6016f)
+ if (workload <= 34449.6016f)
{
return configure_lhs_rhs_info(m, n, 4, 4, 4, 2, 2, 0, 0, 1, 0, 1);
}
@@ -412,11 +410,11 @@ std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedValha
}
else
{
- if(r_mk <= 331.1111f)
+ if (r_mk <= 331.1111f)
{
- if(workload <= 53397.5996f)
+ if (workload <= 53397.5996f)
{
- if(r_mn <= 57.8063f)
+ if (r_mn <= 57.8063f)
{
return configure_lhs_rhs_info(m, n, 4, 4, 4, 2, 2, 0, 0, 1, 0, 1);
}
@@ -427,7 +425,7 @@ std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedValha
}
else
{
- if(r_nk <= 0.9211f)
+ if (r_nk <= 0.9211f)
{
return configure_lhs_rhs_info(m, n, 8, 4, 4, 4, 2, 0, 0, 1, 0, 1);
}
@@ -439,7 +437,7 @@ std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedValha
}
else
{
- if(workload <= 38070.4004f)
+ if (workload <= 38070.4004f)
{
return configure_lhs_rhs_info(m, n, 4, 4, 4, 4, 4, 0, 0, 0, 1, 1);
}
@@ -453,27 +451,28 @@ std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedValha
}
}
-std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedValhall::configure_G78_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b)
+std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo>
+ClGemmDefaultConfigReshapedValhall::configure_G78_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b)
{
const float r_mn = static_cast<float>(m) / static_cast<float>(n);
const float r_nk = static_cast<float>(n) / static_cast<float>(k);
const float workload = (static_cast<float>(m) * static_cast<float>(n) * static_cast<float>(b)) / 20.0f;
- if(workload <= 801.6000f)
+ if (workload <= 801.6000f)
{
return configure_lhs_rhs_info(m, n, 8, 4, 4, 1, 1, 0, 0, 1, 0, 1);
}
else
{
- if(r_mn <= 0.1211f)
+ if (r_mn <= 0.1211f)
{
- if(workload <= 3296.0000f)
+ if (workload <= 3296.0000f)
{
return configure_lhs_rhs_info(m, n, 8, 4, 4, 2, 2, 0, 0, 1, 0, 1);
}
else
{
- if(r_nk <= 1.0625f)
+ if (r_nk <= 1.0625f)
{
return configure_lhs_rhs_info(m, n, 8, 4, 4, 2, 2, 0, 0, 1, 0, 1);
}
@@ -485,15 +484,15 @@ std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedValha
}
else
{
- if(workload <= 5068.8000f)
+ if (workload <= 5068.8000f)
{
return configure_lhs_rhs_info(m, n, 8, 4, 4, 1, 1, 0, 0, 1, 0, 1);
}
else
{
- if(r_nk <= 0.2361f)
+ if (r_nk <= 0.2361f)
{
- if(workload <= 12630.0000f)
+ if (workload <= 12630.0000f)
{
return configure_lhs_rhs_info(m, n, 8, 4, 4, 1, 1, 0, 0, 1, 0, 1);
}
@@ -504,7 +503,7 @@ std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedValha
}
else
{
- if(workload <= 178790.3984f)
+ if (workload <= 178790.3984f)
{
return configure_lhs_rhs_info(m, n, 8, 4, 4, 2, 2, 0, 0, 1, 0, 1);
}
@@ -518,12 +517,13 @@ std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedValha
}
}
-std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedValhall::configure_G77_u8(unsigned int m, unsigned int n, unsigned int k, unsigned int b)
+std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo>
+ClGemmDefaultConfigReshapedValhall::configure_G77_u8(unsigned int m, unsigned int n, unsigned int k, unsigned int b)
{
ARM_COMPUTE_UNUSED(k);
ARM_COMPUTE_UNUSED(b);
- if(n <= 4)
+ if (n <= 4)
{
return configure_lhs_rhs_info(m, n, 4, 2, 16, 4, 1, 0, 0, 0, 1);
}