aboutsummaryrefslogtreecommitdiff
path: root/src/gpu/cl/kernels/gemm/reshaped/ClGemmDefaultConfigReshapedBifrost.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/gpu/cl/kernels/gemm/reshaped/ClGemmDefaultConfigReshapedBifrost.cpp')
-rw-r--r--src/gpu/cl/kernels/gemm/reshaped/ClGemmDefaultConfigReshapedBifrost.cpp163
1 files changed, 91 insertions, 72 deletions
diff --git a/src/gpu/cl/kernels/gemm/reshaped/ClGemmDefaultConfigReshapedBifrost.cpp b/src/gpu/cl/kernels/gemm/reshaped/ClGemmDefaultConfigReshapedBifrost.cpp
index 657018eb53..c956c347ef 100644
--- a/src/gpu/cl/kernels/gemm/reshaped/ClGemmDefaultConfigReshapedBifrost.cpp
+++ b/src/gpu/cl/kernels/gemm/reshaped/ClGemmDefaultConfigReshapedBifrost.cpp
@@ -29,6 +29,7 @@
#include "arm_compute/core/TensorInfo.h"
#include "arm_compute/core/TensorShape.h"
#include "arm_compute/core/utils/misc/ShapeCalculator.h"
+
#include "src/gpu/cl/kernels/gemm/ClGemmHelpers.h"
#include <utility>
@@ -43,30 +44,31 @@ namespace gemm
{
using namespace arm_compute::misc::shape_calculator;
-ClGemmDefaultConfigReshapedBifrost::ClGemmDefaultConfigReshapedBifrost(GPUTarget gpu)
- : IClGemmKernelConfig(gpu)
+ClGemmDefaultConfigReshapedBifrost::ClGemmDefaultConfigReshapedBifrost(GPUTarget gpu) : IClGemmKernelConfig(gpu)
{
}
-std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedBifrost::configure(unsigned int m, unsigned int n, unsigned int k, unsigned int b, DataType data_type)
+std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedBifrost::configure(
+ unsigned int m, unsigned int n, unsigned int k, unsigned int b, DataType data_type)
{
- using ConfigurationFunctionExecutorPtr = std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> (ClGemmDefaultConfigReshapedBifrost::*)(unsigned int m, unsigned int n, unsigned int k, unsigned int b);
+ using ConfigurationFunctionExecutorPtr = std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> (
+ ClGemmDefaultConfigReshapedBifrost::*)(unsigned int m, unsigned int n, unsigned int k, unsigned int b);
- CLGEMMConfigArray<ConfigurationFunctionExecutorPtr> configs_G7x(&ClGemmDefaultConfigReshapedBifrost::configure_G7x_f32,
- &ClGemmDefaultConfigReshapedBifrost::configure_G7x_f16,
- &ClGemmDefaultConfigReshapedBifrost::configure_G7x_u8);
+ CLGEMMConfigArray<ConfigurationFunctionExecutorPtr> configs_G7x(
+ &ClGemmDefaultConfigReshapedBifrost::configure_G7x_f32, &ClGemmDefaultConfigReshapedBifrost::configure_G7x_f16,
+ &ClGemmDefaultConfigReshapedBifrost::configure_G7x_u8);
- CLGEMMConfigArray<ConfigurationFunctionExecutorPtr> configs_G52(&ClGemmDefaultConfigReshapedBifrost::configure_G52_f32,
- &ClGemmDefaultConfigReshapedBifrost::configure_G52_f16,
- &ClGemmDefaultConfigReshapedBifrost::configure_G7x_u8);
+ CLGEMMConfigArray<ConfigurationFunctionExecutorPtr> configs_G52(
+ &ClGemmDefaultConfigReshapedBifrost::configure_G52_f32, &ClGemmDefaultConfigReshapedBifrost::configure_G52_f16,
+ &ClGemmDefaultConfigReshapedBifrost::configure_G7x_u8);
- CLGEMMConfigArray<ConfigurationFunctionExecutorPtr> configs_G76(&ClGemmDefaultConfigReshapedBifrost::configure_G76_f32,
- &ClGemmDefaultConfigReshapedBifrost::configure_G76_f16,
- &ClGemmDefaultConfigReshapedBifrost::configure_G76_u8);
+ CLGEMMConfigArray<ConfigurationFunctionExecutorPtr> configs_G76(
+ &ClGemmDefaultConfigReshapedBifrost::configure_G76_f32, &ClGemmDefaultConfigReshapedBifrost::configure_G76_f16,
+ &ClGemmDefaultConfigReshapedBifrost::configure_G76_u8);
ConfigurationFunctionExecutorPtr func = nullptr;
- switch(_target)
+ switch (_target)
{
case GPUTarget::G76:
func = configs_G76.get_function(data_type);
@@ -83,12 +85,13 @@ std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedBifro
return (this->*func)(m, n, k, b);
}
-std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedBifrost::configure_G7x_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b)
+std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo>
+ClGemmDefaultConfigReshapedBifrost::configure_G7x_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b)
{
ARM_COMPUTE_UNUSED(k);
ARM_COMPUTE_UNUSED(b);
- if(n <= 4)
+ if (n <= 4)
{
return configure_lhs_rhs_info(m, n, 4, 2, 8, 16, 16, true, false, false, true);
}
@@ -98,12 +101,13 @@ std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedBifro
}
}
-std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedBifrost::configure_G7x_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b)
+std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo>
+ClGemmDefaultConfigReshapedBifrost::configure_G7x_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b)
{
ARM_COMPUTE_UNUSED(k);
ARM_COMPUTE_UNUSED(b);
- if(n <= 4)
+ if (n <= 4)
{
return configure_lhs_rhs_info(m, n, 4, 2, 8, 8, 2, true, true, true, false);
}
@@ -113,14 +117,15 @@ std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedBifro
}
}
-std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedBifrost::configure_G7x_u8(unsigned int m, unsigned int n, unsigned int k, unsigned int b)
+std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo>
+ClGemmDefaultConfigReshapedBifrost::configure_G7x_u8(unsigned int m, unsigned int n, unsigned int k, unsigned int b)
{
ARM_COMPUTE_UNUSED(k);
ARM_COMPUTE_UNUSED(b);
- if(dot8_supported(CLKernelLibrary::get().get_device()))
+ if (dot8_supported(CLKernelLibrary::get().get_device()))
{
- if(n <= 4)
+ if (n <= 4)
{
return configure_lhs_rhs_info(m, n, 4, 2, 16, 2, 2, true, false, false, true);
}
@@ -131,7 +136,7 @@ std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedBifro
}
else
{
- if(n <= 4)
+ if (n <= 4)
{
return configure_lhs_rhs_info(m, n, 4, 2, 8, 2, 2, true, false, false, true);
}
@@ -142,7 +147,8 @@ std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedBifro
}
}
-std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedBifrost::configure_G52_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b)
+std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo>
+ClGemmDefaultConfigReshapedBifrost::configure_G52_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b)
{
const float r_mn = static_cast<float>(m) / static_cast<float>(n);
const float workload = (static_cast<float>(m) * static_cast<float>(n) * static_cast<float>(b)) / 20.0f;
@@ -154,100 +160,108 @@ std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedBifro
GEMMLHSMatrixInfo lhs_info_img;
GEMMRHSMatrixInfo rhs_info_img;
- if(workload <= 274.4000f)
+ if (workload <= 274.4000f)
{
- if(r_nk <= 0.7461f)
+ if (r_nk <= 0.7461f)
{
- if(r_mn <= 21.1667f)
+ if (r_mn <= 21.1667f)
{
return configure_lhs_rhs_info(m, n, 4, 2, 4, 4, 4, false, true, true, false, false);
}
else
{
- std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 4, 4, 4, 4, 2, true, true, false, true, true);
- std::tie(lhs_info_buf, rhs_info_buf) = configure_lhs_rhs_info(m, n, 4, 4, 4, 4, 2, true, true, false, true, false);
+ std::tie(lhs_info_img, rhs_info_img) =
+ configure_lhs_rhs_info(m, n, 4, 4, 4, 4, 2, true, true, false, true, true);
+ std::tie(lhs_info_buf, rhs_info_buf) =
+ configure_lhs_rhs_info(m, n, 4, 4, 4, 4, 2, true, true, false, true, false);
return select_lhs_rhs_info(std::make_pair(lhs_info_img, rhs_info_img),
- std::make_pair(lhs_info_buf, rhs_info_buf),
- n, k, b, DataType::F32);
+ std::make_pair(lhs_info_buf, rhs_info_buf), n, k, b, DataType::F32);
}
}
else
{
- std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 4, 4, 4, 4, 2, true, true, false, true, true);
- std::tie(lhs_info_buf, rhs_info_buf) = configure_lhs_rhs_info(m, n, 4, 4, 4, 4, 2, true, true, false, true, false);
+ std::tie(lhs_info_img, rhs_info_img) =
+ configure_lhs_rhs_info(m, n, 4, 4, 4, 4, 2, true, true, false, true, true);
+ std::tie(lhs_info_buf, rhs_info_buf) =
+ configure_lhs_rhs_info(m, n, 4, 4, 4, 4, 2, true, true, false, true, false);
return select_lhs_rhs_info(std::make_pair(lhs_info_img, rhs_info_img),
- std::make_pair(lhs_info_buf, rhs_info_buf),
- n, k, b, DataType::F32);
+ std::make_pair(lhs_info_buf, rhs_info_buf), n, k, b, DataType::F32);
}
}
else
{
- if(r_mk <= 17.3926f)
+ if (r_mk <= 17.3926f)
{
- if(workload <= 542.4000f)
+ if (workload <= 542.4000f)
{
- std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 4, 4, 4, 4, 2, true, true, false, true, true);
- std::tie(lhs_info_buf, rhs_info_buf) = configure_lhs_rhs_info(m, n, 4, 4, 4, 4, 2, true, true, false, true, false);
+ std::tie(lhs_info_img, rhs_info_img) =
+ configure_lhs_rhs_info(m, n, 4, 4, 4, 4, 2, true, true, false, true, true);
+ std::tie(lhs_info_buf, rhs_info_buf) =
+ configure_lhs_rhs_info(m, n, 4, 4, 4, 4, 2, true, true, false, true, false);
return select_lhs_rhs_info(std::make_pair(lhs_info_img, rhs_info_img),
- std::make_pair(lhs_info_buf, rhs_info_buf),
- n, k, b, DataType::F32);
+ std::make_pair(lhs_info_buf, rhs_info_buf), n, k, b, DataType::F32);
}
else
{
- std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 4, 4, 4, 2, 1, true, true, false, true, true);
- std::tie(lhs_info_buf, rhs_info_buf) = configure_lhs_rhs_info(m, n, 4, 4, 4, 2, 1, true, true, false, true, false);
+ std::tie(lhs_info_img, rhs_info_img) =
+ configure_lhs_rhs_info(m, n, 4, 4, 4, 2, 1, true, true, false, true, true);
+ std::tie(lhs_info_buf, rhs_info_buf) =
+ configure_lhs_rhs_info(m, n, 4, 4, 4, 2, 1, true, true, false, true, false);
return select_lhs_rhs_info(std::make_pair(lhs_info_img, rhs_info_img),
- std::make_pair(lhs_info_buf, rhs_info_buf),
- n, k, b, DataType::F32);
+ std::make_pair(lhs_info_buf, rhs_info_buf), n, k, b, DataType::F32);
}
}
else
{
- if(r_nk <= 0.5463f)
+ if (r_nk <= 0.5463f)
{
- if(workload <= 11767.6001f)
+ if (workload <= 11767.6001f)
{
- std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 4, 4, 4, 4, 2, true, true, false, true, true);
- std::tie(lhs_info_buf, rhs_info_buf) = configure_lhs_rhs_info(m, n, 4, 4, 4, 4, 2, true, true, false, true, false);
+ std::tie(lhs_info_img, rhs_info_img) =
+ configure_lhs_rhs_info(m, n, 4, 4, 4, 4, 2, true, true, false, true, true);
+ std::tie(lhs_info_buf, rhs_info_buf) =
+ configure_lhs_rhs_info(m, n, 4, 4, 4, 4, 2, true, true, false, true, false);
return select_lhs_rhs_info(std::make_pair(lhs_info_img, rhs_info_img),
- std::make_pair(lhs_info_buf, rhs_info_buf),
- n, k, b, DataType::F32);
+ std::make_pair(lhs_info_buf, rhs_info_buf), n, k, b, DataType::F32);
}
else
{
- std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 4, 4, 4, 2, 1, true, true, false, true, true);
- std::tie(lhs_info_buf, rhs_info_buf) = configure_lhs_rhs_info(m, n, 4, 4, 4, 2, 1, true, true, false, true, false);
+ std::tie(lhs_info_img, rhs_info_img) =
+ configure_lhs_rhs_info(m, n, 4, 4, 4, 2, 1, true, true, false, true, true);
+ std::tie(lhs_info_buf, rhs_info_buf) =
+ configure_lhs_rhs_info(m, n, 4, 4, 4, 2, 1, true, true, false, true, false);
return select_lhs_rhs_info(std::make_pair(lhs_info_img, rhs_info_img),
- std::make_pair(lhs_info_buf, rhs_info_buf),
- n, k, b, DataType::F32);
+ std::make_pair(lhs_info_buf, rhs_info_buf), n, k, b, DataType::F32);
}
}
else
{
- std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 4, 4, 4, 4, 2, true, true, false, true, true);
- std::tie(lhs_info_buf, rhs_info_buf) = configure_lhs_rhs_info(m, n, 4, 4, 4, 4, 2, true, true, false, true, false);
+ std::tie(lhs_info_img, rhs_info_img) =
+ configure_lhs_rhs_info(m, n, 4, 4, 4, 4, 2, true, true, false, true, true);
+ std::tie(lhs_info_buf, rhs_info_buf) =
+ configure_lhs_rhs_info(m, n, 4, 4, 4, 4, 2, true, true, false, true, false);
return select_lhs_rhs_info(std::make_pair(lhs_info_img, rhs_info_img),
- std::make_pair(lhs_info_buf, rhs_info_buf),
- n, k, b, DataType::F32);
+ std::make_pair(lhs_info_buf, rhs_info_buf), n, k, b, DataType::F32);
}
}
}
}
-std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedBifrost::configure_G52_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b)
+std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo>
+ClGemmDefaultConfigReshapedBifrost::configure_G52_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b)
{
ARM_COMPUTE_UNUSED(k);
const float workload = (static_cast<float>(m) * static_cast<float>(n) * static_cast<float>(b)) / 20.0f;
- if(workload <= 323.4000f)
+ if (workload <= 323.4000f)
{
return configure_lhs_rhs_info(m, n, 2, 2, 8, 4, 8, false, false, false, true, false);
}
@@ -257,7 +271,8 @@ std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedBifro
}
}
-std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedBifrost::configure_G76_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b)
+std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo>
+ClGemmDefaultConfigReshapedBifrost::configure_G76_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b)
{
ARM_COMPUTE_UNUSED(k);
ARM_COMPUTE_UNUSED(b);
@@ -268,7 +283,7 @@ std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedBifro
GEMMRHSMatrixInfo rhs_info_img;
// Get lhs_info/rhs_info in case of OpenCL buffer
- if(n <= 4)
+ if (n <= 4)
{
std::tie(lhs_info_buf, rhs_info_buf) = configure_lhs_rhs_info(m, n, 4, 2, 8, 16, 16, true, false, false, true);
}
@@ -279,15 +294,17 @@ std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedBifro
// Get lhs_info/rhs_info in case of OpenCL image
// Condition on the GPU workload
- if((m / 4) * (n / 4) >= 2560)
+ if ((m / 4) * (n / 4) >= 2560)
{
// Big workload
- std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 4, 4, 4, 2, 8, true, true, true, false, true);
+ std::tie(lhs_info_img, rhs_info_img) =
+ configure_lhs_rhs_info(m, n, 4, 4, 4, 2, 8, true, true, true, false, true);
}
else
{
// Small workload
- std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 2, 4, 4, 1, 1, true, true, true, false, true);
+ std::tie(lhs_info_img, rhs_info_img) =
+ configure_lhs_rhs_info(m, n, 2, 4, 4, 1, 1, true, true, true, false, true);
}
const TensorInfo tensor_rhs_info(TensorShape(n, k, b), 1, DataType::F32);
@@ -297,7 +314,7 @@ std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedBifro
// In case of vector by matrix with few work-items, we use the OpenCL buffer rather than the OpenCL image2d
const bool use_cl_image2d = (n <= 4) ? false : true;
- if(bool(validate_image2d_support_on_rhs(tensor_reshaped_info, rhs_info_img)) && use_cl_image2d)
+ if (bool(validate_image2d_support_on_rhs(tensor_reshaped_info, rhs_info_img)) && use_cl_image2d)
{
return std::make_pair(lhs_info_img, rhs_info_img);
}
@@ -307,16 +324,17 @@ std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedBifro
}
}
-std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedBifrost::configure_G76_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b)
+std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo>
+ClGemmDefaultConfigReshapedBifrost::configure_G76_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b)
{
const float workload = (static_cast<float>(m) * static_cast<float>(n) * static_cast<float>(b)) / 20.0f;
const float r_mk = static_cast<float>(m) / static_cast<float>(k);
- if(workload <= 1595.2000f)
+ if (workload <= 1595.2000f)
{
- if(r_mk <= 2.1044f)
+ if (r_mk <= 2.1044f)
{
- if(workload <= 870.4000f)
+ if (workload <= 870.4000f)
{
return configure_lhs_rhs_info(m, n, 2, 4, 4, 1, 2, true, false, true, false, false);
}
@@ -336,12 +354,13 @@ std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedBifro
}
}
-std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedBifrost::configure_G76_u8(unsigned int m, unsigned int n, unsigned int k, unsigned int b)
+std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo>
+ClGemmDefaultConfigReshapedBifrost::configure_G76_u8(unsigned int m, unsigned int n, unsigned int k, unsigned int b)
{
ARM_COMPUTE_UNUSED(k);
ARM_COMPUTE_UNUSED(b);
- if(n <= 4)
+ if (n <= 4)
{
return configure_lhs_rhs_info(m, n, 4, 2, 16, 4, 1, false, false, false, true);
}