aboutsummaryrefslogtreecommitdiff
path: root/src/gpu/cl/kernels/gemm/reshaped_only_rhs
diff options
context:
space:
mode:
Diffstat (limited to 'src/gpu/cl/kernels/gemm/reshaped_only_rhs')
-rw-r--r--src/gpu/cl/kernels/gemm/reshaped_only_rhs/ClGemmDefaultConfigReshapedRhsOnlyBifrost.cpp242
-rw-r--r--src/gpu/cl/kernels/gemm/reshaped_only_rhs/ClGemmDefaultConfigReshapedRhsOnlyBifrost.h39
-rw-r--r--src/gpu/cl/kernels/gemm/reshaped_only_rhs/ClGemmDefaultConfigReshapedRhsOnlyValhall.cpp550
-rw-r--r--src/gpu/cl/kernels/gemm/reshaped_only_rhs/ClGemmDefaultConfigReshapedRhsOnlyValhall.h27
-rw-r--r--src/gpu/cl/kernels/gemm/reshaped_only_rhs/ClGemmReshapedOnlyRhsKernelConfig.h2
5 files changed, 402 insertions, 458 deletions
diff --git a/src/gpu/cl/kernels/gemm/reshaped_only_rhs/ClGemmDefaultConfigReshapedRhsOnlyBifrost.cpp b/src/gpu/cl/kernels/gemm/reshaped_only_rhs/ClGemmDefaultConfigReshapedRhsOnlyBifrost.cpp
index 9c23d9c998..c4825bfbeb 100644
--- a/src/gpu/cl/kernels/gemm/reshaped_only_rhs/ClGemmDefaultConfigReshapedRhsOnlyBifrost.cpp
+++ b/src/gpu/cl/kernels/gemm/reshaped_only_rhs/ClGemmDefaultConfigReshapedRhsOnlyBifrost.cpp
@@ -29,7 +29,9 @@
#include "arm_compute/core/TensorInfo.h"
#include "arm_compute/core/TensorShape.h"
#include "arm_compute/core/utils/misc/ShapeCalculator.h"
+
#include "src/gpu/cl/kernels/gemm/ClGemmHelpers.h"
+
#include <utility>
namespace arm_compute
@@ -47,33 +49,39 @@ ClGemmDefaultConfigReshapedRhsOnlyBifrost::ClGemmDefaultConfigReshapedRhsOnlyBif
{
}
-std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure(unsigned int m, unsigned int n, unsigned int k, unsigned int b, DataType data_type)
+std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure(
+ unsigned int m, unsigned int n, unsigned int k, unsigned int b, DataType data_type)
{
- using ConfigurationFunctionExecutorPtr = std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> (ClGemmDefaultConfigReshapedRhsOnlyBifrost::*)(unsigned int m, unsigned int n, unsigned int k,
- unsigned int b);
-
- CLGEMMConfigArray<ConfigurationFunctionExecutorPtr> configs_G51(&ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G51_f32,
- &ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G51_f16,
- &ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G51_u8);
-
- CLGEMMConfigArray<ConfigurationFunctionExecutorPtr> configs_G52(&ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G52_f32,
- &ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G52_f16,
- &ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G7x_u8);
-
- CLGEMMConfigArray<ConfigurationFunctionExecutorPtr> configs_G31(&ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G7x_f32,
- &ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G7x_f16,
- &ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G31_u8);
-
- CLGEMMConfigArray<ConfigurationFunctionExecutorPtr> configs_G76(&ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G76_f32,
- &ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G76_f16,
- &ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G76_u8);
-
- CLGEMMConfigArray<ConfigurationFunctionExecutorPtr> configs_G7x(&ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G7x_f32,
- &ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G7x_f16,
- &ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G7x_u8);
+ using ConfigurationFunctionExecutorPtr = std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> (
+ ClGemmDefaultConfigReshapedRhsOnlyBifrost::*)(unsigned int m, unsigned int n, unsigned int k, unsigned int b);
+
+ CLGEMMConfigArray<ConfigurationFunctionExecutorPtr> configs_G51(
+ &ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G51_f32,
+ &ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G51_f16,
+ &ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G51_u8);
+
+ CLGEMMConfigArray<ConfigurationFunctionExecutorPtr> configs_G52(
+ &ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G52_f32,
+ &ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G52_f16,
+ &ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G7x_u8);
+
+ CLGEMMConfigArray<ConfigurationFunctionExecutorPtr> configs_G31(
+ &ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G7x_f32,
+ &ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G7x_f16,
+ &ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G31_u8);
+
+ CLGEMMConfigArray<ConfigurationFunctionExecutorPtr> configs_G76(
+ &ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G76_f32,
+ &ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G76_f16,
+ &ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G76_u8);
+
+ CLGEMMConfigArray<ConfigurationFunctionExecutorPtr> configs_G7x(
+ &ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G7x_f32,
+ &ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G7x_f16,
+ &ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G7x_u8);
ConfigurationFunctionExecutorPtr func = nullptr;
- switch(_target)
+ switch (_target)
{
case GPUTarget::G76:
func = configs_G76.get_function(data_type);
@@ -96,14 +104,15 @@ std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOn
return (this->*func)(m, n, k, b);
}
-std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G7x_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b)
+std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G7x_f32(
+ unsigned int m, unsigned int n, unsigned int k, unsigned int b)
{
ARM_COMPUTE_UNUSED(k);
ARM_COMPUTE_UNUSED(b);
- if(m == 1)
+ if (m == 1)
{
- if(n <= 2548)
+ if (n <= 2548)
{
return configure_lhs_rhs_info(m, n, 1, 2, 16, 1, 4, false, true, false, true, false);
}
@@ -118,12 +127,13 @@ std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOn
}
}
-std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G31_u8(unsigned int m, unsigned int n, unsigned int k, unsigned int b)
+std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G31_u8(
+ unsigned int m, unsigned int n, unsigned int k, unsigned int b)
{
ARM_COMPUTE_UNUSED(k);
ARM_COMPUTE_UNUSED(b);
- if(m == 1)
+ if (m == 1)
{
const unsigned int h0 = std::max(n / 2, 1U);
return configure_lhs_rhs_info(m, n, 1, 4, 16, 1, h0, 0, 1, 0, 1);
@@ -131,7 +141,7 @@ std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOn
else
{
const int h0 = std::max(std::min(static_cast<int>(n / 4), static_cast<int>(256)), static_cast<int>(1));
- if(m >= 28)
+ if (m >= 28)
{
return configure_lhs_rhs_info(m, n, 4, 4, 4, 1, h0, 0, 1, 0, 1);
}
@@ -142,7 +152,8 @@ std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOn
}
}
-std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G76_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b)
+std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G76_f32(
+ unsigned int m, unsigned int n, unsigned int k, unsigned int b)
{
ARM_COMPUTE_UNUSED(k);
ARM_COMPUTE_UNUSED(b);
@@ -154,9 +165,9 @@ std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOn
const bool is_workload_big = ((m * n * b) / 16) >= 2048;
- if(m == 1)
+ if (m == 1)
{
- if(n >= 8192)
+ if (n >= 8192)
{
const unsigned int h0 = std::max(n / 4, 1U);
return configure_lhs_rhs_info(m, n, 1, 4, 8, 1, h0, false, true, false, true, false);
@@ -164,7 +175,7 @@ std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOn
else
{
const unsigned int h0 = std::max(n / 2, 1U);
- if(n <= 204)
+ if (n <= 204)
{
return configure_lhs_rhs_info(m, n, 1, 2, 16, 1, h0, false, true, false, true, false);
}
@@ -177,25 +188,29 @@ std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOn
else
{
const int h0 = std::max(std::min(static_cast<int>(n / 4), static_cast<int>(16)), static_cast<int>(1));
- if(is_workload_big)
+ if (is_workload_big)
{
- std::tie(lhs_info_buf, rhs_info_buf) = configure_lhs_rhs_info(m, n, 4, 4, 4, 1, h0, false, true, false, true);
+ std::tie(lhs_info_buf, rhs_info_buf) =
+ configure_lhs_rhs_info(m, n, 4, 4, 4, 1, h0, false, true, false, true);
}
else
{
- std::tie(lhs_info_buf, rhs_info_buf) = configure_lhs_rhs_info(m, n, 2, 4, 8, 1, h0, false, true, false, true);
+ std::tie(lhs_info_buf, rhs_info_buf) =
+ configure_lhs_rhs_info(m, n, 2, 4, 8, 1, h0, false, true, false, true);
}
}
// Get lhs_info/rhs_info in case of OpenCL image
const int h0 = std::max(std::min(static_cast<int>(n / 4), static_cast<int>(16)), static_cast<int>(1));
- if(is_workload_big)
+ if (is_workload_big)
{
- std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 4, 4, 4, 1, h0, false, true, false, false, true);
+ std::tie(lhs_info_img, rhs_info_img) =
+ configure_lhs_rhs_info(m, n, 4, 4, 4, 1, h0, false, true, false, false, true);
}
else
{
- std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 2, 4, 8, 1, h0, false, true, false, true, true);
+ std::tie(lhs_info_img, rhs_info_img) =
+ configure_lhs_rhs_info(m, n, 2, 4, 8, 1, h0, false, true, false, true, true);
}
const TensorInfo tensor_rhs_info(TensorShape(n, k, b), 1, DataType::F32);
@@ -205,7 +220,7 @@ std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOn
// In case of vector by matrix or small workloads, we use the OpenCL buffer rather than the OpenCL image2d
const bool use_cl_image2d = ((m == 1) || ((((m * n * b) / 16) < 2048) && n < 128)) ? false : true;
- if(bool(validate_image2d_support_on_rhs(tensor_reshaped_info, rhs_info_img)) && use_cl_image2d)
+ if (bool(validate_image2d_support_on_rhs(tensor_reshaped_info, rhs_info_img)) && use_cl_image2d)
{
return std::make_pair(lhs_info_img, rhs_info_img);
}
@@ -215,7 +230,8 @@ std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOn
}
}
-std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G52_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b)
+std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G52_f32(
+ unsigned int m, unsigned int n, unsigned int k, unsigned int b)
{
const float workload = (static_cast<float>(m) * static_cast<float>(n) * static_cast<float>(b)) / 20.0f;
const float r_nk = static_cast<float>(n) / static_cast<float>(k);
@@ -225,46 +241,49 @@ std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOn
GEMMLHSMatrixInfo lhs_info_img;
GEMMRHSMatrixInfo rhs_info_img;
- if(m == 1)
+ if (m == 1)
{
- if(r_nk <= 0.4664f)
+ if (r_nk <= 0.4664f)
{
return configure_lhs_rhs_info(m, n, 1, 2, 16, 1, 16, false, true, false, true, false);
}
else
{
- std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 1, 4, 8, 1, 16, false, true, false, true, true);
- std::tie(lhs_info_buf, rhs_info_buf) = configure_lhs_rhs_info(m, n, 1, 4, 8, 1, 16, false, true, false, true, false);
+ std::tie(lhs_info_img, rhs_info_img) =
+ configure_lhs_rhs_info(m, n, 1, 4, 8, 1, 16, false, true, false, true, true);
+ std::tie(lhs_info_buf, rhs_info_buf) =
+ configure_lhs_rhs_info(m, n, 1, 4, 8, 1, 16, false, true, false, true, false);
return select_lhs_rhs_info(std::make_pair(lhs_info_img, rhs_info_img),
- std::make_pair(lhs_info_buf, rhs_info_buf),
- n, k, b, DataType::F32);
+ std::make_pair(lhs_info_buf, rhs_info_buf), n, k, b, DataType::F32);
}
}
else
{
- if(workload <= 274.4000f)
+ if (workload <= 274.4000f)
{
return configure_lhs_rhs_info(m, n, 2, 2, 4, 1, 16, false, false, false, true, false);
}
else
{
- std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 4, 4, 4, 1, 2, false, false, false, true, true);
- std::tie(lhs_info_buf, rhs_info_buf) = configure_lhs_rhs_info(m, n, 4, 4, 4, 1, 2, false, false, false, true, false);
+ std::tie(lhs_info_img, rhs_info_img) =
+ configure_lhs_rhs_info(m, n, 4, 4, 4, 1, 2, false, false, false, true, true);
+ std::tie(lhs_info_buf, rhs_info_buf) =
+ configure_lhs_rhs_info(m, n, 4, 4, 4, 1, 2, false, false, false, true, false);
return select_lhs_rhs_info(std::make_pair(lhs_info_img, rhs_info_img),
- std::make_pair(lhs_info_buf, rhs_info_buf),
- n, k, b, DataType::F32);
+ std::make_pair(lhs_info_buf, rhs_info_buf), n, k, b, DataType::F32);
}
}
}
-std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G51_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b)
+std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G51_f32(
+ unsigned int m, unsigned int n, unsigned int k, unsigned int b)
{
ARM_COMPUTE_UNUSED(k);
ARM_COMPUTE_UNUSED(b);
- if(m == 1)
+ if (m == 1)
{
const unsigned int n0 = n < 1280 ? 2 : 4;
const unsigned int h0 = std::max(n / n0, 1U);
@@ -276,14 +295,15 @@ std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOn
}
}
-std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G7x_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b)
+std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G7x_f16(
+ unsigned int m, unsigned int n, unsigned int k, unsigned int b)
{
ARM_COMPUTE_UNUSED(k);
ARM_COMPUTE_UNUSED(b);
- if(m == 1)
+ if (m == 1)
{
- if(n > 2048)
+ if (n > 2048)
{
const unsigned int h0 = std::max(n / 4, 1U);
return configure_lhs_rhs_info(m, n, 1, 4, 4, 1, h0, false, true, false, true);
@@ -300,7 +320,8 @@ std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOn
}
}
-std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G52_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b)
+std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G52_f16(
+ unsigned int m, unsigned int n, unsigned int k, unsigned int b)
{
const float r_mn = static_cast<float>(m) / static_cast<float>(n);
const float workload = (static_cast<float>(m) * static_cast<float>(n) * static_cast<float>(b)) / 20.0f;
@@ -312,57 +333,59 @@ std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOn
GEMMLHSMatrixInfo lhs_info_img;
GEMMRHSMatrixInfo rhs_info_img;
- if(m == 1)
+ if (m == 1)
{
- std::tie(lhs_info_buf, rhs_info_buf) = configure_lhs_rhs_info(m, n, 1, 4, 16, 1, 16, false, true, false, false, false);
+ std::tie(lhs_info_buf, rhs_info_buf) =
+ configure_lhs_rhs_info(m, n, 1, 4, 16, 1, 16, false, true, false, false, false);
- if(r_mk <= 0.0026f)
+ if (r_mk <= 0.0026f)
{
- if(r_nk <= 0.4664f)
+ if (r_nk <= 0.4664f)
{
return configure_lhs_rhs_info(m, n, 1, 2, 16, 1, 32, false, true, false, true, false);
}
else
{
- std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 1, 4, 16, 1, 16, false, true, false, false, true);
+ std::tie(lhs_info_img, rhs_info_img) =
+ configure_lhs_rhs_info(m, n, 1, 4, 16, 1, 16, false, true, false, false, true);
return select_lhs_rhs_info(std::make_pair(lhs_info_img, rhs_info_img),
- std::make_pair(lhs_info_buf, rhs_info_buf),
- n, k, b, DataType::F16);
+ std::make_pair(lhs_info_buf, rhs_info_buf), n, k, b, DataType::F16);
}
}
else
{
- if(r_mk <= 0.0148f)
+ if (r_mk <= 0.0148f)
{
return configure_lhs_rhs_info(m, n, 1, 2, 16, 1, 32, false, true, false, true, false);
}
else
{
- std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 1, 4, 16, 1, 16, false, true, false, false, true);
+ std::tie(lhs_info_img, rhs_info_img) =
+ configure_lhs_rhs_info(m, n, 1, 4, 16, 1, 16, false, true, false, false, true);
return select_lhs_rhs_info(std::make_pair(lhs_info_img, rhs_info_img),
- std::make_pair(lhs_info_buf, rhs_info_buf),
- n, k, b, DataType::F16);
+ std::make_pair(lhs_info_buf, rhs_info_buf), n, k, b, DataType::F16);
}
}
}
else
{
- std::tie(lhs_info_buf, rhs_info_buf) = configure_lhs_rhs_info(m, n, 5, 8, 4, 1, 2, false, false, false, false, false);
+ std::tie(lhs_info_buf, rhs_info_buf) =
+ configure_lhs_rhs_info(m, n, 5, 8, 4, 1, 2, false, false, false, false, false);
- if(workload <= 362.6000f)
+ if (workload <= 362.6000f)
{
return configure_lhs_rhs_info(m, n, 2, 2, 8, 1, 16, false, false, false, true, false);
}
else
{
- if(r_mn <= 22.6067f)
+ if (r_mn <= 22.6067f)
{
- if(workload <= 708.8000f)
+ if (workload <= 708.8000f)
{
- std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 5, 4, 4, 1, 2, false, false, false, false, true);
+ std::tie(lhs_info_img, rhs_info_img) =
+ configure_lhs_rhs_info(m, n, 5, 4, 4, 1, 2, false, false, false, false, true);
return select_lhs_rhs_info(std::make_pair(lhs_info_img, rhs_info_img),
- std::make_pair(lhs_info_buf, rhs_info_buf),
- n, k, b, DataType::F16);
+ std::make_pair(lhs_info_buf, rhs_info_buf), n, k, b, DataType::F16);
}
else
{
@@ -371,27 +394,28 @@ std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOn
}
else
{
- if(r_nk <= 0.0917f)
+ if (r_nk <= 0.0917f)
{
return configure_lhs_rhs_info(m, n, 2, 2, 8, 1, 16, false, false, false, true, false);
}
else
{
- std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 5, 4, 4, 1, 2, false, false, false, false, true);
+ std::tie(lhs_info_img, rhs_info_img) =
+ configure_lhs_rhs_info(m, n, 5, 4, 4, 1, 2, false, false, false, false, true);
return select_lhs_rhs_info(std::make_pair(lhs_info_img, rhs_info_img),
- std::make_pair(lhs_info_buf, rhs_info_buf),
- n, k, b, DataType::F16);
+ std::make_pair(lhs_info_buf, rhs_info_buf), n, k, b, DataType::F16);
}
}
}
}
}
-std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G76_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b)
+std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G76_f16(
+ unsigned int m, unsigned int n, unsigned int k, unsigned int b)
{
ARM_COMPUTE_UNUSED(k);
- if(m == 1)
+ if (m == 1)
{
return configure_lhs_rhs_info(m, n, 1, 2, 16, 1, 32, false, true, false, true, false);
}
@@ -400,15 +424,15 @@ std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOn
const float r_mn = static_cast<float>(m) / static_cast<float>(n);
const float workload = (static_cast<float>(m) * static_cast<float>(n) * static_cast<float>(b)) / 20.0f;
- if(workload <= 7449.60f)
+ if (workload <= 7449.60f)
{
- if(workload <= 691.60f)
+ if (workload <= 691.60f)
{
return configure_lhs_rhs_info(m, n, 2, 2, 8, 1, 8, false, false, false, false, false);
}
else
{
- if(workload <= 4155.20f)
+ if (workload <= 4155.20f)
{
return configure_lhs_rhs_info(m, n, 5, 2, 8, 1, 16, false, false, false, false, false);
}
@@ -420,21 +444,22 @@ std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOn
}
else
{
- if(workload <= 16300.80f)
+ if (workload <= 16300.80f)
{
- if(r_mn <= 44.56f)
+ if (r_mn <= 44.56f)
{
GEMMLHSMatrixInfo lhs_info_buf;
GEMMRHSMatrixInfo rhs_info_buf;
GEMMLHSMatrixInfo lhs_info_img;
GEMMRHSMatrixInfo rhs_info_img;
- std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 8, 4, 4, 1, 1, false, true, false, false, true);
- std::tie(lhs_info_buf, rhs_info_buf) = configure_lhs_rhs_info(m, n, 5, 2, 8, 1, 16, false, false, false, false, false);
+ std::tie(lhs_info_img, rhs_info_img) =
+ configure_lhs_rhs_info(m, n, 8, 4, 4, 1, 1, false, true, false, false, true);
+ std::tie(lhs_info_buf, rhs_info_buf) =
+ configure_lhs_rhs_info(m, n, 5, 2, 8, 1, 16, false, false, false, false, false);
return select_lhs_rhs_info(std::make_pair(lhs_info_img, rhs_info_img),
- std::make_pair(lhs_info_buf, rhs_info_buf),
- n, k, b, DataType::F16);
+ std::make_pair(lhs_info_buf, rhs_info_buf), n, k, b, DataType::F16);
}
else
{
@@ -448,23 +473,25 @@ std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOn
GEMMLHSMatrixInfo lhs_info_img;
GEMMRHSMatrixInfo rhs_info_img;
- std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 5, 4, 4, 1, 2, false, true, false, false, true);
- std::tie(lhs_info_buf, rhs_info_buf) = configure_lhs_rhs_info(m, n, 5, 2, 8, 1, 16, false, false, false, false, false);
+ std::tie(lhs_info_img, rhs_info_img) =
+ configure_lhs_rhs_info(m, n, 5, 4, 4, 1, 2, false, true, false, false, true);
+ std::tie(lhs_info_buf, rhs_info_buf) =
+ configure_lhs_rhs_info(m, n, 5, 2, 8, 1, 16, false, false, false, false, false);
return select_lhs_rhs_info(std::make_pair(lhs_info_img, rhs_info_img),
- std::make_pair(lhs_info_buf, rhs_info_buf),
- n, k, b, DataType::F16);
+ std::make_pair(lhs_info_buf, rhs_info_buf), n, k, b, DataType::F16);
}
}
}
}
-std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G51_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b)
+std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G51_f16(
+ unsigned int m, unsigned int n, unsigned int k, unsigned int b)
{
ARM_COMPUTE_UNUSED(k);
ARM_COMPUTE_UNUSED(b);
- if(m == 1)
+ if (m == 1)
{
const unsigned int n0 = n < 1280 ? 2 : 4;
const unsigned int h0 = std::max(n / n0, 1U);
@@ -476,14 +503,15 @@ std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOn
}
}
-std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G7x_u8(unsigned int m, unsigned int n, unsigned int k, unsigned int b)
+std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G7x_u8(
+ unsigned int m, unsigned int n, unsigned int k, unsigned int b)
{
ARM_COMPUTE_UNUSED(k);
ARM_COMPUTE_UNUSED(b);
- if(dot8_supported(CLKernelLibrary::get().get_device()))
+ if (dot8_supported(CLKernelLibrary::get().get_device()))
{
- if(m == 1)
+ if (m == 1)
{
const unsigned int h0 = std::max(n / 2, 1U);
return configure_lhs_rhs_info(m, n, 1, 2, 16, 1, h0, false, true, false, true);
@@ -497,7 +525,7 @@ std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOn
else
{
const int h0 = std::max(std::min(static_cast<int>(n / 2), static_cast<int>(128)), static_cast<int>(1));
- if(m == 1)
+ if (m == 1)
{
return configure_lhs_rhs_info(m, n, 1, 2, 4, 1, h0, false, true, false, true);
}
@@ -508,12 +536,13 @@ std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOn
}
}
-std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G76_u8(unsigned int m, unsigned int n, unsigned int k, unsigned int b)
+std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G76_u8(
+ unsigned int m, unsigned int n, unsigned int k, unsigned int b)
{
ARM_COMPUTE_UNUSED(k);
ARM_COMPUTE_UNUSED(b);
- if(m == 1)
+ if (m == 1)
{
const unsigned int h0 = std::max(n / 2, 1U);
return configure_lhs_rhs_info(m, n, 1, 2, 16, 1, h0, false, true, false, true);
@@ -524,12 +553,13 @@ std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOn
}
}
-std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G51_u8(unsigned int m, unsigned int n, unsigned int k, unsigned int b)
+std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G51_u8(
+ unsigned int m, unsigned int n, unsigned int k, unsigned int b)
{
ARM_COMPUTE_UNUSED(k);
ARM_COMPUTE_UNUSED(b);
- if(m == 1)
+ if (m == 1)
{
const unsigned int h0 = std::max(n / 2, 1U);
return configure_lhs_rhs_info(m, n, 1, 4, 16, 1, h0, false, true, false, true);
diff --git a/src/gpu/cl/kernels/gemm/reshaped_only_rhs/ClGemmDefaultConfigReshapedRhsOnlyBifrost.h b/src/gpu/cl/kernels/gemm/reshaped_only_rhs/ClGemmDefaultConfigReshapedRhsOnlyBifrost.h
index 321cbb5250..77c0c8d500 100644
--- a/src/gpu/cl/kernels/gemm/reshaped_only_rhs/ClGemmDefaultConfigReshapedRhsOnlyBifrost.h
+++ b/src/gpu/cl/kernels/gemm/reshaped_only_rhs/ClGemmDefaultConfigReshapedRhsOnlyBifrost.h
@@ -45,21 +45,34 @@ public:
ClGemmDefaultConfigReshapedRhsOnlyBifrost(GPUTarget gpu);
// Inherited overridden method
- std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> configure(unsigned int m, unsigned int n, unsigned int k, unsigned int b, DataType data_type) override;
+ std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo>
+ configure(unsigned int m, unsigned int n, unsigned int k, unsigned int b, DataType data_type) override;
private:
- std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> configure_G7x_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b);
- std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> configure_G76_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b);
- std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> configure_G52_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b);
- std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> configure_G51_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b);
- std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> configure_G7x_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b);
- std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> configure_G52_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b);
- std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> configure_G76_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b);
- std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> configure_G51_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b);
- std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> configure_G7x_u8(unsigned int m, unsigned int n, unsigned int k, unsigned int b);
- std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> configure_G76_u8(unsigned int m, unsigned int n, unsigned int k, unsigned int b);
- std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> configure_G51_u8(unsigned int m, unsigned int n, unsigned int k, unsigned int b);
- std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> configure_G31_u8(unsigned int m, unsigned int n, unsigned int k, unsigned int b);
+ std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo>
+ configure_G7x_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b);
+ std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo>
+ configure_G76_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b);
+ std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo>
+ configure_G52_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b);
+ std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo>
+ configure_G51_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b);
+ std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo>
+ configure_G7x_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b);
+ std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo>
+ configure_G52_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b);
+ std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo>
+ configure_G76_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b);
+ std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo>
+ configure_G51_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b);
+ std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo>
+ configure_G7x_u8(unsigned int m, unsigned int n, unsigned int k, unsigned int b);
+ std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo>
+ configure_G76_u8(unsigned int m, unsigned int n, unsigned int k, unsigned int b);
+ std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo>
+ configure_G51_u8(unsigned int m, unsigned int n, unsigned int k, unsigned int b);
+ std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo>
+ configure_G31_u8(unsigned int m, unsigned int n, unsigned int k, unsigned int b);
};
} // namespace gemm
} // namespace kernels
diff --git a/src/gpu/cl/kernels/gemm/reshaped_only_rhs/ClGemmDefaultConfigReshapedRhsOnlyValhall.cpp b/src/gpu/cl/kernels/gemm/reshaped_only_rhs/ClGemmDefaultConfigReshapedRhsOnlyValhall.cpp
index d08bf84c72..da3e2ec912 100644
--- a/src/gpu/cl/kernels/gemm/reshaped_only_rhs/ClGemmDefaultConfigReshapedRhsOnlyValhall.cpp
+++ b/src/gpu/cl/kernels/gemm/reshaped_only_rhs/ClGemmDefaultConfigReshapedRhsOnlyValhall.cpp
@@ -50,30 +50,35 @@ ClGemmDefaultConfigReshapedRhsOnlyValhall::ClGemmDefaultConfigReshapedRhsOnlyVal
{
}
-std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOnlyValhall::configure(unsigned int m, unsigned int n, unsigned int k, unsigned int b, DataType data_type)
+std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOnlyValhall::configure(
+ unsigned int m, unsigned int n, unsigned int k, unsigned int b, DataType data_type)
{
- using ConfigurationFunctionExecutorPtr = std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> (ClGemmDefaultConfigReshapedRhsOnlyValhall::*)(unsigned int m, unsigned int n, unsigned int k,
- unsigned int b);
+ using ConfigurationFunctionExecutorPtr = std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> (
+ ClGemmDefaultConfigReshapedRhsOnlyValhall::*)(unsigned int m, unsigned int n, unsigned int k, unsigned int b);
- CLGEMMConfigArray<ConfigurationFunctionExecutorPtr> configs_G77(&ClGemmDefaultConfigReshapedRhsOnlyValhall::configure_G77_f32,
- &ClGemmDefaultConfigReshapedRhsOnlyValhall::configure_G77_f16,
- &ClGemmDefaultConfigReshapedRhsOnlyValhall::configure_G77_u8);
+ CLGEMMConfigArray<ConfigurationFunctionExecutorPtr> configs_G77(
+ &ClGemmDefaultConfigReshapedRhsOnlyValhall::configure_G77_f32,
+ &ClGemmDefaultConfigReshapedRhsOnlyValhall::configure_G77_f16,
+ &ClGemmDefaultConfigReshapedRhsOnlyValhall::configure_G77_u8);
- CLGEMMConfigArray<ConfigurationFunctionExecutorPtr> configs_G78(&ClGemmDefaultConfigReshapedRhsOnlyValhall::configure_G78_f32,
- &ClGemmDefaultConfigReshapedRhsOnlyValhall::configure_G78_f16,
- &ClGemmDefaultConfigReshapedRhsOnlyValhall::configure_G77_u8);
+ CLGEMMConfigArray<ConfigurationFunctionExecutorPtr> configs_G78(
+ &ClGemmDefaultConfigReshapedRhsOnlyValhall::configure_G78_f32,
+ &ClGemmDefaultConfigReshapedRhsOnlyValhall::configure_G78_f16,
+ &ClGemmDefaultConfigReshapedRhsOnlyValhall::configure_G77_u8);
- CLGEMMConfigArray<ConfigurationFunctionExecutorPtr> configs_G710(&ClGemmDefaultConfigReshapedRhsOnlyValhall::configure_G77_f32,
- &ClGemmDefaultConfigReshapedRhsOnlyValhall::configure_G710_f16,
- &ClGemmDefaultConfigReshapedRhsOnlyValhall::configure_G77_u8);
+ CLGEMMConfigArray<ConfigurationFunctionExecutorPtr> configs_G710(
+ &ClGemmDefaultConfigReshapedRhsOnlyValhall::configure_G77_f32,
+ &ClGemmDefaultConfigReshapedRhsOnlyValhall::configure_G710_f16,
+ &ClGemmDefaultConfigReshapedRhsOnlyValhall::configure_G77_u8);
- CLGEMMConfigArray<ConfigurationFunctionExecutorPtr> configs_G715(&ClGemmDefaultConfigReshapedRhsOnlyValhall::configure_G715_f32,
- &ClGemmDefaultConfigReshapedRhsOnlyValhall::configure_G715_f16,
- &ClGemmDefaultConfigReshapedRhsOnlyValhall::configure_G77_u8);
+ CLGEMMConfigArray<ConfigurationFunctionExecutorPtr> configs_G715(
+ &ClGemmDefaultConfigReshapedRhsOnlyValhall::configure_G715_f32,
+ &ClGemmDefaultConfigReshapedRhsOnlyValhall::configure_G715_f16,
+ &ClGemmDefaultConfigReshapedRhsOnlyValhall::configure_G77_u8);
ConfigurationFunctionExecutorPtr func = nullptr;
- switch(_target)
+ switch (_target)
{
case GPUTarget::G78:
func = configs_G78.get_function(data_type);
@@ -96,29 +101,29 @@ std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOn
return (this->*func)(m, n, k, b);
}
-std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOnlyValhall::configure_G77_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b)
+std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOnlyValhall::configure_G77_f32(
+ unsigned int m, unsigned int n, unsigned int k, unsigned int b)
{
- if(m == 1)
+ if (m == 1)
{
const float r_mn = static_cast<float>(m) / static_cast<float>(n);
const float r_mk = static_cast<float>(m) / static_cast<float>(k);
- if(r_mk <= 0.0064484127797186375)
+ if (r_mk <= 0.0064484127797186375)
{
- if(r_mn <= 0.0028273810748942196)
+ if (r_mn <= 0.0028273810748942196)
{
GEMMLHSMatrixInfo lhs_info_buf;
GEMMRHSMatrixInfo rhs_info_buf;
GEMMLHSMatrixInfo lhs_info_img;
GEMMRHSMatrixInfo rhs_info_img;
- const unsigned int h0 = std::max(n / 4, 1U);
+ const unsigned int h0 = std::max(n / 4, 1U);
std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 1, 4, 8, 1, 16, 0, 1, 0, 0, 1);
std::tie(lhs_info_buf, rhs_info_buf) = configure_lhs_rhs_info(m, n, 1, 4, 4, 1, h0, 0, 1, 0, 1, 0);
return select_lhs_rhs_info(std::make_pair(lhs_info_img, rhs_info_img),
- std::make_pair(lhs_info_buf, rhs_info_buf),
- n, k, b, DataType::F32);
+ std::make_pair(lhs_info_buf, rhs_info_buf), n, k, b, DataType::F32);
}
else
{
@@ -127,7 +132,7 @@ std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOn
}
else
{
- if(r_mk <= 0.020312500186264515)
+ if (r_mk <= 0.020312500186264515)
{
return configure_lhs_rhs_info(m, n, 1, 2, 16, 1, 4, 0, 1, 0, 0, 0);
}
@@ -143,9 +148,9 @@ std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOn
const float workload = (static_cast<float>(m) * static_cast<float>(n) * static_cast<float>(b)) / 20.0f;
const float r_mk = static_cast<float>(m) / static_cast<float>(k);
- if(workload <= 1999.2000122070312)
+ if (workload <= 1999.2000122070312)
{
- if(workload <= 747.1999816894531)
+ if (workload <= 747.1999816894531)
{
return configure_lhs_rhs_info(m, n, 2, 2, 4, 1, 8, 0, 1, 0, 1, 0);
}
@@ -159,15 +164,14 @@ std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOn
std::tie(lhs_info_buf, rhs_info_buf) = configure_lhs_rhs_info(m, n, 2, 2, 4, 1, 8, 0, 1, 0, 1, 0);
return select_lhs_rhs_info(std::make_pair(lhs_info_img, rhs_info_img),
- std::make_pair(lhs_info_buf, rhs_info_buf),
- n, k, b, DataType::F32);
+ std::make_pair(lhs_info_buf, rhs_info_buf), n, k, b, DataType::F32);
}
}
else
{
- if(r_mn <= 0.03348214365541935)
+ if (r_mn <= 0.03348214365541935)
{
- if(r_mk <= 0.028125000186264515)
+ if (r_mk <= 0.028125000186264515)
{
return configure_lhs_rhs_info(m, n, 2, 2, 4, 1, 8, 0, 1, 0, 1, 0);
}
@@ -181,8 +185,7 @@ std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOn
std::tie(lhs_info_buf, rhs_info_buf) = configure_lhs_rhs_info(m, n, 2, 2, 4, 1, 8, 0, 1, 0, 1, 0);
return select_lhs_rhs_info(std::make_pair(lhs_info_img, rhs_info_img),
- std::make_pair(lhs_info_buf, rhs_info_buf),
- n, k, b, DataType::F32);
+ std::make_pair(lhs_info_buf, rhs_info_buf), n, k, b, DataType::F32);
}
}
else
@@ -195,168 +198,112 @@ std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOn
std::tie(lhs_info_buf, rhs_info_buf) = configure_lhs_rhs_info(m, n, 4, 4, 4, 1, 16, 0, 1, 0, 1, 0);
return select_lhs_rhs_info(std::make_pair(lhs_info_img, rhs_info_img),
- std::make_pair(lhs_info_buf, rhs_info_buf),
- n, k, b, DataType::F32);
+ std::make_pair(lhs_info_buf, rhs_info_buf), n, k, b, DataType::F32);
}
}
}
}
-std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOnlyValhall::configure_G77_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b)
+std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOnlyValhall::configure_G77_f16(
+ unsigned int m, unsigned int n, unsigned int k, unsigned int b)
{
- const GeMMConfigsMatrix configs_1nkb_best =
- {
- { 1, 8984, 640, 1, 1, 8, 8, 1, 0, 1, 1, 1, 1, 0 },
- { 1, 420, 392, 1, 1, 2, 8, 1, 0, 1, 0, 1, 0, 0 },
- { 1, 644, 5288, 1, 1, 2, 8, 1, 0, 1, 0, 1, 0, 0 },
- { 1, 6512, 6404, 1, 1, 4, 8, 1, 0, 1, 0, 1, 0, 0 },
- { 1, 5304, 640, 1, 1, 4, 4, 1, 0, 1, 0, 1, 1, 0 },
- { 1, 1352, 1520, 1, 1, 2, 8, 1, 0, 1, 1, 1, 1, 0 },
- { 1, 4096, 25088, 1, 1, 2, 16, 1, 0, 1, 0, 1, 0, 0 },
- { 1, 732, 8988, 1, 1, 2, 8, 1, 0, 1, 0, 1, 0, 0 }
- };
-
- const GeMMConfigsMatrix configs_mnkb_n_small_best =
- {
- { 102400, 4, 96, 1, 2, 2, 16, 1, 4, 1, 1, 1, 1, 0 },
- { 102400, 2, 96, 1, 1, 2, 16, 1, 0, 1, 0, 1, 1, 1 },
- { 16384, 4, 128, 1, 1, 2, 16, 1, 0, 1, 0, 1, 1, 1 },
- { 16384, 2, 128, 1, 1, 2, 16, 1, 0, 1, 1, 1, 1, 1 }
- };
-
- const GeMMConfigsMatrix configs_mnkb_n_small_fallback =
- {
- { 102400, 4, 96, 1, 2, 2, 16, 1, 4, 1, 1, 1, 1, 0 },
- { 102400, 2, 96, 1, 1, 2, 16, 1, 0, 1, 1, 1, 1, 0 },
- { 16384, 4, 128, 1, 2, 2, 16, 1, 2, 1, 1, 1, 1, 0 },
- { 16384, 2, 128, 1, 1, 2, 16, 1, 0, 1, 1, 1, 1, 0 }
- };
-
- const GeMMConfigsMatrix configs_mnkb_m_gt_n_best =
- {
- { 25584, 88, 16, 1, 4, 8, 4, 1, 8, 1, 1, 1, 0, 0 },
- { 25584, 16, 68, 1, 4, 4, 8, 1, 16, 1, 1, 1, 0, 1 },
- { 369664, 32, 28, 1, 5, 4, 4, 1, 64, 1, 1, 1, 0, 1 },
- { 65792, 44, 24, 1, 4, 8, 4, 1, 128, 1, 1, 1, 0, 0 },
- { 23036, 56, 736, 1, 4, 4, 8, 1, 64, 1, 1, 1, 0, 1 },
- { 90968, 40, 600, 1, 4, 4, 8, 1, 64, 1, 1, 1, 0, 1 },
- { 8944, 32, 776, 1, 4, 4, 8, 1, 64, 1, 1, 1, 0, 1 },
- { 50176, 64, 300, 1, 4, 8, 4, 1, 128, 1, 1, 1, 0, 0 },
- { 16544, 104, 160, 1, 4, 4, 8, 1, 64, 1, 1, 1, 0, 1 },
- { 12604, 60, 160, 1, 4, 4, 8, 1, 64, 1, 1, 1, 0, 1 },
- { 29584, 32, 28, 1, 4, 4, 4, 1, 128, 1, 1, 1, 0, 0 },
- { 12544, 32, 27, 1, 2, 8, 8, 1, 128, 1, 1, 1, 0, 0 },
- { 2688, 136, 1492, 1, 8, 4, 4, 1, 128, 1, 1, 1, 0, 0 },
- { 3728, 96, 196, 1, 4, 8, 4, 1, 128, 1, 1, 1, 0, 0 }
- };
-
- const GeMMConfigsMatrix configs_mnkb_m_gt_n_fallback =
- {
- { 25584, 88, 16, 1, 4, 8, 4, 1, 8, 1, 1, 1, 0, 0 },
- { 25584, 16, 68, 1, 2, 4, 8, 1, 4, 1, 1, 1, 0, 0 },
- { 369664, 32, 28, 1, 5, 4, 4, 1, 256, 1, 1, 1, 0, 0 },
- { 65792, 44, 24, 1, 4, 8, 4, 1, 128, 1, 1, 1, 0, 0 },
- { 23036, 56, 736, 1, 4, 8, 4, 1, 64, 1, 1, 1, 0, 0 },
- { 90968, 40, 600, 1, 4, 8, 4, 1, 64, 1, 1, 1, 0, 0 },
- { 8944, 32, 776, 1, 4, 4, 8, 1, 64, 1, 1, 1, 0, 0 },
- { 50176, 64, 300, 1, 4, 8, 4, 1, 128, 1, 1, 1, 0, 0 },
- { 16544, 104, 160, 1, 4, 4, 8, 1, 64, 1, 1, 1, 0, 0 },
- { 12604, 60, 160, 1, 4, 4, 8, 1, 256, 1, 1, 1, 0, 0 },
- { 29584, 32, 28, 1, 4, 4, 4, 1, 128, 1, 1, 1, 0, 0 },
- { 12544, 32, 27, 1, 2, 8, 8, 1, 128, 1, 1, 1, 0, 0 },
- { 2688, 136, 1492, 1, 8, 4, 4, 1, 128, 1, 1, 1, 0, 0 },
- { 3728, 96, 196, 1, 4, 8, 4, 1, 128, 1, 1, 1, 0, 0 }
- };
-
- const GeMMConfigsMatrix configs_mnkb_n_gt_m_best =
- {
- { 24, 488, 88, 1, 2, 4, 16, 1, 4, 1, 1, 1, 0, 0 },
- { 49, 1024, 512, 1, 4, 4, 8, 1, 128, 1, 1, 1, 0, 1 },
- { 49, 1024, 1024, 1, 4, 4, 8, 1, 64, 1, 1, 1, 0, 1 },
- };
-
- const GeMMConfigsMatrix configs_mnkb_n_gt_m_fallback =
- {
- { 24, 488, 88, 1, 2, 4, 16, 1, 4, 1, 1, 1, 0, 0 },
- { 49, 1024, 512, 1, 4, 4, 8, 1, 128, 1, 1, 1, 0, 0 },
- { 49, 1024, 1024, 1, 4, 4, 8, 1, 256, 1, 1, 1, 0, 0 },
+ const GeMMConfigsMatrix configs_1nkb_best = {
+ {1, 8984, 640, 1, 1, 8, 8, 1, 0, 1, 1, 1, 1, 0}, {1, 420, 392, 1, 1, 2, 8, 1, 0, 1, 0, 1, 0, 0},
+ {1, 644, 5288, 1, 1, 2, 8, 1, 0, 1, 0, 1, 0, 0}, {1, 6512, 6404, 1, 1, 4, 8, 1, 0, 1, 0, 1, 0, 0},
+ {1, 5304, 640, 1, 1, 4, 4, 1, 0, 1, 0, 1, 1, 0}, {1, 1352, 1520, 1, 1, 2, 8, 1, 0, 1, 1, 1, 1, 0},
+ {1, 4096, 25088, 1, 1, 2, 16, 1, 0, 1, 0, 1, 0, 0}, {1, 732, 8988, 1, 1, 2, 8, 1, 0, 1, 0, 1, 0, 0}};
+
+ const GeMMConfigsMatrix configs_mnkb_n_small_best = {{102400, 4, 96, 1, 2, 2, 16, 1, 4, 1, 1, 1, 1, 0},
+ {102400, 2, 96, 1, 1, 2, 16, 1, 0, 1, 0, 1, 1, 1},
+ {16384, 4, 128, 1, 1, 2, 16, 1, 0, 1, 0, 1, 1, 1},
+ {16384, 2, 128, 1, 1, 2, 16, 1, 0, 1, 1, 1, 1, 1}};
+
+ const GeMMConfigsMatrix configs_mnkb_n_small_fallback = {{102400, 4, 96, 1, 2, 2, 16, 1, 4, 1, 1, 1, 1, 0},
+ {102400, 2, 96, 1, 1, 2, 16, 1, 0, 1, 1, 1, 1, 0},
+ {16384, 4, 128, 1, 2, 2, 16, 1, 2, 1, 1, 1, 1, 0},
+ {16384, 2, 128, 1, 1, 2, 16, 1, 0, 1, 1, 1, 1, 0}};
+
+ const GeMMConfigsMatrix configs_mnkb_m_gt_n_best = {
+ {25584, 88, 16, 1, 4, 8, 4, 1, 8, 1, 1, 1, 0, 0}, {25584, 16, 68, 1, 4, 4, 8, 1, 16, 1, 1, 1, 0, 1},
+ {369664, 32, 28, 1, 5, 4, 4, 1, 64, 1, 1, 1, 0, 1}, {65792, 44, 24, 1, 4, 8, 4, 1, 128, 1, 1, 1, 0, 0},
+ {23036, 56, 736, 1, 4, 4, 8, 1, 64, 1, 1, 1, 0, 1}, {90968, 40, 600, 1, 4, 4, 8, 1, 64, 1, 1, 1, 0, 1},
+ {8944, 32, 776, 1, 4, 4, 8, 1, 64, 1, 1, 1, 0, 1}, {50176, 64, 300, 1, 4, 8, 4, 1, 128, 1, 1, 1, 0, 0},
+ {16544, 104, 160, 1, 4, 4, 8, 1, 64, 1, 1, 1, 0, 1}, {12604, 60, 160, 1, 4, 4, 8, 1, 64, 1, 1, 1, 0, 1},
+ {29584, 32, 28, 1, 4, 4, 4, 1, 128, 1, 1, 1, 0, 0}, {12544, 32, 27, 1, 2, 8, 8, 1, 128, 1, 1, 1, 0, 0},
+ {2688, 136, 1492, 1, 8, 4, 4, 1, 128, 1, 1, 1, 0, 0}, {3728, 96, 196, 1, 4, 8, 4, 1, 128, 1, 1, 1, 0, 0}};
+
+ const GeMMConfigsMatrix configs_mnkb_m_gt_n_fallback = {
+ {25584, 88, 16, 1, 4, 8, 4, 1, 8, 1, 1, 1, 0, 0}, {25584, 16, 68, 1, 2, 4, 8, 1, 4, 1, 1, 1, 0, 0},
+ {369664, 32, 28, 1, 5, 4, 4, 1, 256, 1, 1, 1, 0, 0}, {65792, 44, 24, 1, 4, 8, 4, 1, 128, 1, 1, 1, 0, 0},
+ {23036, 56, 736, 1, 4, 8, 4, 1, 64, 1, 1, 1, 0, 0}, {90968, 40, 600, 1, 4, 8, 4, 1, 64, 1, 1, 1, 0, 0},
+ {8944, 32, 776, 1, 4, 4, 8, 1, 64, 1, 1, 1, 0, 0}, {50176, 64, 300, 1, 4, 8, 4, 1, 128, 1, 1, 1, 0, 0},
+ {16544, 104, 160, 1, 4, 4, 8, 1, 64, 1, 1, 1, 0, 0}, {12604, 60, 160, 1, 4, 4, 8, 1, 256, 1, 1, 1, 0, 0},
+ {29584, 32, 28, 1, 4, 4, 4, 1, 128, 1, 1, 1, 0, 0}, {12544, 32, 27, 1, 2, 8, 8, 1, 128, 1, 1, 1, 0, 0},
+ {2688, 136, 1492, 1, 8, 4, 4, 1, 128, 1, 1, 1, 0, 0}, {3728, 96, 196, 1, 4, 8, 4, 1, 128, 1, 1, 1, 0, 0}};
+
+ const GeMMConfigsMatrix configs_mnkb_n_gt_m_best = {
+ {24, 488, 88, 1, 2, 4, 16, 1, 4, 1, 1, 1, 0, 0},
+ {49, 1024, 512, 1, 4, 4, 8, 1, 128, 1, 1, 1, 0, 1},
+ {49, 1024, 1024, 1, 4, 4, 8, 1, 64, 1, 1, 1, 0, 1},
};
- const GeMMConfigsMatrix configs_mnkb_squared_best =
- {
- { 72, 92, 136, 1, 2, 2, 8, 1, 128, 1, 1, 1, 1, 0 },
- { 268, 824, 5076, 1, 4, 8, 4, 1, 256, 1, 1, 1, 0, 0 },
- { 180, 420, 952, 1, 4, 4, 8, 1, 64, 1, 1, 1, 0, 1 },
- { 1000, 152, 304, 1, 4, 4, 8, 1, 128, 1, 1, 1, 0, 0 },
- { 272, 400, 2116, 1, 4, 8, 4, 1, 64, 1, 1, 1, 0, 0 },
- { 196, 512, 512, 1, 5, 4, 4, 1, 64, 1, 1, 1, 0, 1 },
- { 24, 88, 236, 1, 2, 2, 8, 1, 64, 1, 1, 1, 1, 0 },
- { 24, 88, 488, 1, 2, 2, 8, 1, 64, 1, 1, 1, 1, 0 }
+ const GeMMConfigsMatrix configs_mnkb_n_gt_m_fallback = {
+ {24, 488, 88, 1, 2, 4, 16, 1, 4, 1, 1, 1, 0, 0},
+ {49, 1024, 512, 1, 4, 4, 8, 1, 128, 1, 1, 1, 0, 0},
+ {49, 1024, 1024, 1, 4, 4, 8, 1, 256, 1, 1, 1, 0, 0},
};
- const GeMMConfigsMatrix configs_mnkb_squared_fallback =
- {
- { 72, 92, 136, 1, 2, 2, 8, 1, 128, 1, 1, 1, 1, 0 },
- { 268, 824, 5076, 1, 4, 8, 4, 1, 256, 1, 1, 1, 0, 0 },
- { 180, 420, 952, 1, 4, 4, 8, 1, 128, 1, 1, 1, 0, 0 },
- { 1000, 152, 304, 1, 4, 4, 8, 1, 128, 1, 1, 1, 0, 0 },
- { 272, 400, 2116, 1, 4, 8, 4, 1, 64, 1, 1, 1, 0, 0 },
- { 196, 512, 512, 1, 5, 4, 4, 1, 256, 1, 1, 1, 0, 0 },
- { 24, 88, 236, 1, 2, 2, 8, 1, 64, 1, 1, 1, 1, 0 },
- { 24, 88, 488, 1, 2, 2, 8, 1, 64, 1, 1, 1, 1, 0 }
- };
-
- const GeMMConfigsMatrix configs_mnkb_best_batched =
- {
- { 3136, 64, 64, 36, 4, 8, 4, 1, 64, 1, 1, 1, 0, 0 },
- { 4096, 48, 32, 36, 4, 4, 8, 1, 64, 1, 1, 1, 0, 1 },
- { 688, 92, 68, 32, 4, 8, 4, 1, 64, 1, 1, 1, 0, 0 },
- { 24, 464, 412, 24, 4, 4, 8, 1, 128, 1, 1, 1, 0, 0 },
- { 112, 184, 144, 28, 4, 8, 4, 1, 64, 1, 1, 1, 0, 0 },
- { 5776, 64, 32, 36, 4, 8, 4, 1, 64, 1, 1, 1, 0, 0 },
- { 1568, 64, 40, 36, 4, 8, 4, 1, 64, 1, 1, 1, 0, 0 },
- { 2920, 64, 64, 24, 4, 8, 4, 1, 64, 1, 1, 1, 0, 0 }
- };
-
- const GeMMConfigsMatrix configs_mnkb_fallback_batched =
- {
- { 3136, 64, 64, 36, 4, 8, 4, 1, 64, 1, 1, 1, 0, 0 },
- { 4096, 48, 32, 36, 4, 4, 8, 1, 128, 1, 1, 1, 0, 0 },
- { 688, 92, 68, 32, 4, 8, 4, 1, 64, 1, 1, 1, 0, 0 },
- { 24, 464, 412, 24, 4, 4, 8, 1, 128, 1, 1, 1, 0, 0 },
- { 112, 184, 144, 28, 4, 8, 4, 1, 64, 1, 1, 1, 0, 0 },
- { 5776, 64, 32, 36, 4, 8, 4, 1, 64, 1, 1, 1, 0, 0 },
- { 1568, 64, 40, 36, 4, 8, 4, 1, 64, 1, 1, 1, 0, 0 },
- { 2920, 64, 64, 24, 4, 8, 4, 1, 64, 1, 1, 1, 0, 0 }
- };
+ const GeMMConfigsMatrix configs_mnkb_squared_best = {
+ {72, 92, 136, 1, 2, 2, 8, 1, 128, 1, 1, 1, 1, 0}, {268, 824, 5076, 1, 4, 8, 4, 1, 256, 1, 1, 1, 0, 0},
+ {180, 420, 952, 1, 4, 4, 8, 1, 64, 1, 1, 1, 0, 1}, {1000, 152, 304, 1, 4, 4, 8, 1, 128, 1, 1, 1, 0, 0},
+ {272, 400, 2116, 1, 4, 8, 4, 1, 64, 1, 1, 1, 0, 0}, {196, 512, 512, 1, 5, 4, 4, 1, 64, 1, 1, 1, 0, 1},
+ {24, 88, 236, 1, 2, 2, 8, 1, 64, 1, 1, 1, 1, 0}, {24, 88, 488, 1, 2, 2, 8, 1, 64, 1, 1, 1, 1, 0}};
+
+ const GeMMConfigsMatrix configs_mnkb_squared_fallback = {
+ {72, 92, 136, 1, 2, 2, 8, 1, 128, 1, 1, 1, 1, 0}, {268, 824, 5076, 1, 4, 8, 4, 1, 256, 1, 1, 1, 0, 0},
+ {180, 420, 952, 1, 4, 4, 8, 1, 128, 1, 1, 1, 0, 0}, {1000, 152, 304, 1, 4, 4, 8, 1, 128, 1, 1, 1, 0, 0},
+ {272, 400, 2116, 1, 4, 8, 4, 1, 64, 1, 1, 1, 0, 0}, {196, 512, 512, 1, 5, 4, 4, 1, 256, 1, 1, 1, 0, 0},
+ {24, 88, 236, 1, 2, 2, 8, 1, 64, 1, 1, 1, 1, 0}, {24, 88, 488, 1, 2, 2, 8, 1, 64, 1, 1, 1, 1, 0}};
+
+ const GeMMConfigsMatrix configs_mnkb_best_batched = {
+ {3136, 64, 64, 36, 4, 8, 4, 1, 64, 1, 1, 1, 0, 0}, {4096, 48, 32, 36, 4, 4, 8, 1, 64, 1, 1, 1, 0, 1},
+ {688, 92, 68, 32, 4, 8, 4, 1, 64, 1, 1, 1, 0, 0}, {24, 464, 412, 24, 4, 4, 8, 1, 128, 1, 1, 1, 0, 0},
+ {112, 184, 144, 28, 4, 8, 4, 1, 64, 1, 1, 1, 0, 0}, {5776, 64, 32, 36, 4, 8, 4, 1, 64, 1, 1, 1, 0, 0},
+ {1568, 64, 40, 36, 4, 8, 4, 1, 64, 1, 1, 1, 0, 0}, {2920, 64, 64, 24, 4, 8, 4, 1, 64, 1, 1, 1, 0, 0}};
+
+ const GeMMConfigsMatrix configs_mnkb_fallback_batched = {
+ {3136, 64, 64, 36, 4, 8, 4, 1, 64, 1, 1, 1, 0, 0}, {4096, 48, 32, 36, 4, 4, 8, 1, 128, 1, 1, 1, 0, 0},
+ {688, 92, 68, 32, 4, 8, 4, 1, 64, 1, 1, 1, 0, 0}, {24, 464, 412, 24, 4, 4, 8, 1, 128, 1, 1, 1, 0, 0},
+ {112, 184, 144, 28, 4, 8, 4, 1, 64, 1, 1, 1, 0, 0}, {5776, 64, 32, 36, 4, 8, 4, 1, 64, 1, 1, 1, 0, 0},
+ {1568, 64, 40, 36, 4, 8, 4, 1, 64, 1, 1, 1, 0, 0}, {2920, 64, 64, 24, 4, 8, 4, 1, 64, 1, 1, 1, 0, 0}};
const GeMMConfigsMatrix *configs_best_to_use = nullptr;
const GeMMConfigsMatrix *configs_fallback_to_use = nullptr;
- if(b == 1)
+ if (b == 1)
{
constexpr float ratio_m_gt_n = 10.f;
constexpr float ratio_n_gt_m = 0.1f;
constexpr unsigned int n_small_thr = 4;
const float ratio = static_cast<float>(m) / static_cast<float>(n);
- if(m == 1)
+ if (m == 1)
{
// We do not need fallback in this case, as we never use cl_image for the rhs tensor
configs_best_to_use = &configs_1nkb_best;
configs_fallback_to_use = &configs_1nkb_best;
}
- else if(n <= n_small_thr && ratio > ratio_m_gt_n)
+ else if (n <= n_small_thr && ratio > ratio_m_gt_n)
{
configs_best_to_use = &configs_mnkb_n_small_best;
configs_fallback_to_use = &configs_mnkb_n_small_fallback;
}
- else if(ratio > ratio_m_gt_n)
+ else if (ratio > ratio_m_gt_n)
{
configs_best_to_use = &configs_mnkb_m_gt_n_best;
configs_fallback_to_use = &configs_mnkb_m_gt_n_fallback;
}
- else if(ratio < ratio_n_gt_m)
+ else if (ratio < ratio_n_gt_m)
{
configs_best_to_use = &configs_mnkb_n_gt_m_best;
configs_fallback_to_use = &configs_mnkb_n_gt_m_fallback;
@@ -381,17 +328,17 @@ std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOn
std::tie(lhs_info0, rhs_info0) = find_lhs_rhs_info(*configs_best_to_use, m, n, k, b);
std::tie(lhs_info1, rhs_info1) = find_lhs_rhs_info(*configs_fallback_to_use, m, n, k, b);
- return select_lhs_rhs_info(std::make_pair(lhs_info0, rhs_info0),
- std::make_pair(lhs_info1, rhs_info1),
- n, k, b, DataType::F16);
+ return select_lhs_rhs_info(std::make_pair(lhs_info0, rhs_info0), std::make_pair(lhs_info1, rhs_info1), n, k, b,
+ DataType::F16);
}
-std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOnlyValhall::configure_G77_u8(unsigned int m, unsigned int n, unsigned int k, unsigned int b)
+std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOnlyValhall::configure_G77_u8(
+ unsigned int m, unsigned int n, unsigned int k, unsigned int b)
{
ARM_COMPUTE_UNUSED(k);
ARM_COMPUTE_UNUSED(b);
- if(m == 1)
+ if (m == 1)
{
const unsigned int h0 = std::max(n / 2, 1U);
return configure_lhs_rhs_info(m, n, 1, 4, 16, 1, h0, 0, 1, 0, 1);
@@ -399,7 +346,7 @@ std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOn
else
{
const int h0 = std::max(std::min(static_cast<int>(n / 4), static_cast<int>(256)), static_cast<int>(1));
- if(m >= 28)
+ if (m >= 28)
{
return configure_lhs_rhs_info(m, n, 4, 4, 16, 1, h0, 0, 1, 0, 1);
}
@@ -410,30 +357,31 @@ std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOn
}
}
-std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOnlyValhall::configure_G78_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b)
+std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOnlyValhall::configure_G78_f32(
+ unsigned int m, unsigned int n, unsigned int k, unsigned int b)
{
const float r_mn = static_cast<float>(m) / static_cast<float>(n);
const float r_mk = static_cast<float>(m) / static_cast<float>(k);
const float r_nk = static_cast<float>(n) / static_cast<float>(k);
const float workload = (static_cast<float>(m) * static_cast<float>(n) * static_cast<float>(b)) / 20.0f;
- if(m == 1)
+ if (m == 1)
{
- if(workload <= 278.7000f)
+ if (workload <= 278.7000f)
{
- if(workload <= 7.5000f)
+ if (workload <= 7.5000f)
{
return configure_lhs_rhs_info(m, n, 1, 2, 8, 1, 2, 0, 1, 1, 0, 0);
}
else
{
- if(r_mn <= 0.0031f)
+ if (r_mn <= 0.0031f)
{
- if(workload <= 256.6000f)
+ if (workload <= 256.6000f)
{
- if(workload <= 16.7500f)
+ if (workload <= 16.7500f)
{
- if(r_nk <= 1.6671f)
+ if (r_nk <= 1.6671f)
{
return configure_lhs_rhs_info(m, n, 1, 2, 2, 1, 32, 0, 0, 0, 1, 0);
}
@@ -454,15 +402,15 @@ std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOn
}
else
{
- if(r_mk <= 0.0027f)
+ if (r_mk <= 0.0027f)
{
- if(r_mk <= 0.0014f)
+ if (r_mk <= 0.0014f)
{
return configure_lhs_rhs_info(m, n, 1, 2, 2, 1, 32, 0, 0, 0, 1, 0);
}
else
{
- if(workload <= 8.9500f)
+ if (workload <= 8.9500f)
{
return configure_lhs_rhs_info(m, n, 1, 2, 8, 1, 2, 0, 1, 1, 0, 0);
}
@@ -474,13 +422,13 @@ std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOn
}
else
{
- if(workload <= 14.1500f)
+ if (workload <= 14.1500f)
{
return configure_lhs_rhs_info(m, n, 1, 2, 8, 1, 2, 0, 1, 1, 0, 0);
}
else
{
- if(r_mk <= 0.0041f)
+ if (r_mk <= 0.0041f)
{
return configure_lhs_rhs_info(m, n, 1, 2, 2, 1, 32, 0, 0, 0, 1, 0);
}
@@ -495,9 +443,9 @@ std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOn
}
else
{
- if(workload <= 363.7000f)
+ if (workload <= 363.7000f)
{
- if(r_mk <= 0.0031f)
+ if (r_mk <= 0.0031f)
{
return configure_lhs_rhs_info(m, n, 1, 4, 2, 1, 32, 0, 1, 0, 1, 0);
}
@@ -514,9 +462,9 @@ std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOn
}
else
{
- if(workload <= 1384.8000f)
+ if (workload <= 1384.8000f)
{
- if(workload <= 704.0000f)
+ if (workload <= 704.0000f)
{
return configure_lhs_rhs_info(m, n, 2, 2, 4, 1, 32, 0, 1, 0, 1, 0);
}
@@ -527,9 +475,9 @@ std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOn
}
else
{
- if(workload <= 16761.6006f)
+ if (workload <= 16761.6006f)
{
- if(r_mn <= 187.1250f)
+ if (r_mn <= 187.1250f)
{
return configure_lhs_rhs_info(m, n, 4, 4, 4, 1, 16, 0, 0, 0, 1, 1);
}
@@ -540,7 +488,7 @@ std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOn
}
else
{
- if(r_mk <= 432.4630f)
+ if (r_mk <= 432.4630f)
{
return configure_lhs_rhs_info(m, n, 5, 4, 4, 1, 16, 0, 0, 0, 1, 1);
}
@@ -553,42 +501,37 @@ std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOn
}
}
-std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOnlyValhall::configure_G78_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b)
+std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOnlyValhall::configure_G78_f16(
+ unsigned int m, unsigned int n, unsigned int k, unsigned int b)
{
const float workload = (static_cast<float>(m) * static_cast<float>(n) * static_cast<float>(b)) / 20.0f;
const float r_mn = static_cast<float>(m) / static_cast<float>(n);
const float r_mk = static_cast<float>(m) / static_cast<float>(k);
const float r_nk = static_cast<float>(n) / static_cast<float>(k);
- if(m == 1)
+ if (m == 1)
{
- const GeMMConfigsMatrix configs_mnkb_best =
- {
- { 1, 8984, 640, 1, 1, 4, 2, 1, 0, 1, 0, 1, 1, 0 },
- { 1, 420, 392, 1, 1, 2, 4, 1, 0, 1, 0, 1, 0, 0 },
- { 1, 644, 5288, 1, 1, 2, 4, 1, 0, 1, 0, 1, 0, 0 },
- { 1, 6512, 6404, 1, 1, 2, 2, 1, 0, 1, 0, 1, 1, 0 },
- { 1, 5304, 640, 1, 1, 2, 2, 1, 0, 1, 0, 1, 0, 0 },
- { 1, 1352, 1520, 1, 1, 2, 4, 1, 0, 1, 0, 1, 0, 0 },
- { 1, 4096, 25088, 1, 1, 2, 4, 1, 0, 1, 0, 1, 0, 0 },
- { 1, 732, 8988, 1, 1, 2, 4, 1, 0, 1, 0, 1, 0, 0 }
- };
+ const GeMMConfigsMatrix configs_mnkb_best = {
+ {1, 8984, 640, 1, 1, 4, 2, 1, 0, 1, 0, 1, 1, 0}, {1, 420, 392, 1, 1, 2, 4, 1, 0, 1, 0, 1, 0, 0},
+ {1, 644, 5288, 1, 1, 2, 4, 1, 0, 1, 0, 1, 0, 0}, {1, 6512, 6404, 1, 1, 2, 2, 1, 0, 1, 0, 1, 1, 0},
+ {1, 5304, 640, 1, 1, 2, 2, 1, 0, 1, 0, 1, 0, 0}, {1, 1352, 1520, 1, 1, 2, 4, 1, 0, 1, 0, 1, 0, 0},
+ {1, 4096, 25088, 1, 1, 2, 4, 1, 0, 1, 0, 1, 0, 0}, {1, 732, 8988, 1, 1, 2, 4, 1, 0, 1, 0, 1, 0, 0}};
return find_lhs_rhs_info(configs_mnkb_best, m, n, k, b);
}
else
{
- if(workload <= 1384.8000f)
+ if (workload <= 1384.8000f)
{
- if(r_nk <= 0.8333f)
+ if (r_nk <= 0.8333f)
{
- if(r_mk <= 0.9119f)
+ if (r_mk <= 0.9119f)
{
return configure_lhs_rhs_info(m, n, 2, 2, 16, 1, 4, 0, 1, 0, 1, 1);
}
else
{
- if(r_nk <= 0.1181f)
+ if (r_nk <= 0.1181f)
{
return configure_lhs_rhs_info(m, n, 2, 2, 8, 1, 32, 0, 0, 1, 0, 0);
}
@@ -600,7 +543,7 @@ std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOn
}
else
{
- if(r_mk <= 1.0013f)
+ if (r_mk <= 1.0013f)
{
return configure_lhs_rhs_info(m, n, 4, 4, 8, 1, 32, 0, 1, 1, 0, 1);
}
@@ -612,11 +555,11 @@ std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOn
}
else
{
- if(workload <= 11404.7998f)
+ if (workload <= 11404.7998f)
{
- if(r_mk <= 2.2884f)
+ if (r_mk <= 2.2884f)
{
- if(r_nk <= 0.9286f)
+ if (r_nk <= 0.9286f)
{
return configure_lhs_rhs_info(m, n, 4, 4, 8, 1, 4, 0, 1, 1, 0, 1);
}
@@ -632,9 +575,9 @@ std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOn
}
else
{
- if(r_nk <= 1.1926f)
+ if (r_nk <= 1.1926f)
{
- if(r_mn <= 1385.7917f)
+ if (r_mn <= 1385.7917f)
{
return configure_lhs_rhs_info(m, n, 6, 4, 8, 1, 4, 0, 1, 1, 0, 1);
}
@@ -652,12 +595,13 @@ std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOn
}
}
-std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOnlyValhall::configure_G715_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b)
+std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOnlyValhall::configure_G715_f32(
+ unsigned int m, unsigned int n, unsigned int k, unsigned int b)
{
unsigned int best_m0;
unsigned int best_n0;
- if(is_mmul_kernel_preferred(m, n, k, b, DataType::F32, best_m0, best_n0))
+ if (is_mmul_kernel_preferred(m, n, k, b, DataType::F32, best_m0, best_n0))
{
return configure_lhs_rhs_info(m, n, best_m0, best_n0, 1, 1, 4, false, true, false, false, true);
}
@@ -667,153 +611,101 @@ std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOn
}
}
-std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOnlyValhall::configure_G710_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b)
+std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOnlyValhall::configure_G710_f16(
+ unsigned int m, unsigned int n, unsigned int k, unsigned int b)
{
- const GeMMConfigsMatrix configs_1nkb_best =
- {
- { 1, 8984, 640, 1, 1, 2, 2, 1, 0, 1, 0, 1, 0, 0 },
- { 1, 420, 392, 1, 1, 2, 8, 1, 0, 1, 0, 1, 0, 0 },
- { 1, 644, 5288, 1, 1, 2, 8, 1, 0, 1, 0, 1, 0, 0 },
- { 1, 6512, 6404, 1, 1, 2, 4, 1, 0, 1, 0, 1, 0, 0 },
- { 1, 5304, 640, 1, 1, 2, 4, 1, 0, 1, 0, 1, 0, 0 },
- { 1, 1352, 1520, 1, 1, 2, 4, 1, 0, 1, 0, 1, 0, 0 },
- { 1, 4096, 25088, 1, 1, 2, 8, 1, 0, 1, 0, 1, 1, 0 },
- { 1, 732, 8988, 1, 1, 2, 8, 1, 0, 1, 0, 1, 0, 0 }
+ const GeMMConfigsMatrix configs_1nkb_best = {
+ {1, 8984, 640, 1, 1, 2, 2, 1, 0, 1, 0, 1, 0, 0}, {1, 420, 392, 1, 1, 2, 8, 1, 0, 1, 0, 1, 0, 0},
+ {1, 644, 5288, 1, 1, 2, 8, 1, 0, 1, 0, 1, 0, 0}, {1, 6512, 6404, 1, 1, 2, 4, 1, 0, 1, 0, 1, 0, 0},
+ {1, 5304, 640, 1, 1, 2, 4, 1, 0, 1, 0, 1, 0, 0}, {1, 1352, 1520, 1, 1, 2, 4, 1, 0, 1, 0, 1, 0, 0},
+ {1, 4096, 25088, 1, 1, 2, 8, 1, 0, 1, 0, 1, 1, 0}, {1, 732, 8988, 1, 1, 2, 8, 1, 0, 1, 0, 1, 0, 0}};
+
+ const GeMMConfigsMatrix configs_mnkb_n_small_best = {{102400, 4, 96, 1, 1, 2, 16, 1, 0, 1, 0, 1, 0, 0},
+ {102400, 2, 96, 1, 1, 2, 16, 1, 0, 1, 0, 1, 0, 0},
+ {16384, 4, 128, 1, 1, 2, 16, 1, 0, 1, 0, 1, 0, 0},
+ {16384, 2, 128, 1, 1, 2, 16, 1, 0, 1, 0, 1, 0, 0}};
+
+ const GeMMConfigsMatrix configs_mnkb_m_gt_n_best = {
+ {25584, 88, 16, 1, 4, 8, 4, 1, 4, 1, 1, 1, 0, 0}, {25584, 16, 68, 1, 2, 4, 16, 1, 8, 1, 1, 1, 0, 1},
+ {369664, 32, 28, 1, 2, 8, 4, 1, 128, 1, 1, 1, 0, 0}, {65792, 44, 24, 1, 4, 8, 4, 1, 8, 1, 1, 1, 0, 0},
+ {23036, 56, 736, 1, 4, 4, 8, 1, 4, 1, 1, 1, 0, 1}, {90968, 40, 600, 1, 4, 4, 8, 1, 4, 1, 1, 1, 0, 1},
+ {8944, 32, 776, 1, 4, 4, 8, 1, 4, 1, 1, 1, 0, 1}, {2688, 136, 1492, 1, 4, 4, 8, 1, 4, 1, 1, 1, 0, 1},
+ {50176, 64, 300, 1, 4, 8, 4, 1, 8, 1, 1, 1, 0, 1}, {16544, 104, 160, 1, 4, 4, 8, 1, 4, 1, 1, 1, 0, 1},
+ {12604, 60, 160, 1, 4, 4, 8, 1, 4, 1, 1, 1, 0, 1}, {3728, 96, 196, 1, 4, 4, 8, 1, 4, 1, 1, 1, 0, 1},
+ {29584, 32, 28, 1, 2, 8, 4, 1, 16, 1, 1, 1, 0, 0}, {12544, 32, 27, 1, 2, 8, 8, 1, 16, 1, 1, 1, 0, 0},
};
- const GeMMConfigsMatrix configs_mnkb_n_small_best =
- {
- { 102400, 4, 96, 1, 1, 2, 16, 1, 0, 1, 0, 1, 0, 0 },
- { 102400, 2, 96, 1, 1, 2, 16, 1, 0, 1, 0, 1, 0, 0 },
- { 16384, 4, 128, 1, 1, 2, 16, 1, 0, 1, 0, 1, 0, 0 },
- { 16384, 2, 128, 1, 1, 2, 16, 1, 0, 1, 0, 1, 0, 0 }
+ const GeMMConfigsMatrix configs_mnkb_m_gt_n_fallback = {
+ {25584, 88, 16, 1, 4, 8, 4, 1, 4, 1, 1, 1, 0, 0}, {25584, 16, 68, 1, 2, 4, 8, 1, 4, 1, 1, 1, 1, 0},
+ {369664, 32, 28, 1, 2, 8, 4, 1, 128, 1, 1, 1, 0, 0}, {65792, 44, 24, 1, 4, 8, 4, 1, 8, 1, 1, 1, 0, 0},
+ {23036, 56, 736, 1, 4, 8, 4, 1, 16, 1, 1, 1, 0, 0}, {90968, 40, 600, 1, 4, 4, 8, 1, 4, 1, 1, 1, 0, 0},
+ {8944, 32, 776, 1, 2, 8, 8, 1, 16, 1, 1, 1, 0, 0}, {2688, 136, 1492, 1, 4, 4, 8, 1, 8, 1, 1, 1, 0, 0},
+ {50176, 64, 300, 1, 4, 8, 4, 1, 128, 1, 1, 1, 0, 0}, {16544, 104, 160, 1, 4, 8, 4, 1, 16, 1, 1, 1, 0, 0},
+ {12604, 60, 160, 1, 2, 8, 8, 1, 8, 1, 1, 1, 0, 0}, {3728, 96, 196, 1, 2, 8, 8, 1, 64, 1, 1, 1, 0, 0},
+ {29584, 32, 28, 1, 2, 8, 4, 1, 16, 1, 1, 1, 0, 0}, {12544, 32, 27, 1, 2, 8, 8, 1, 16, 1, 1, 1, 0, 0},
};
- const GeMMConfigsMatrix configs_mnkb_m_gt_n_best =
- {
- { 25584, 88, 16, 1, 4, 8, 4, 1, 4, 1, 1, 1, 0, 0 },
- { 25584, 16, 68, 1, 2, 4, 16, 1, 8, 1, 1, 1, 0, 1 },
- { 369664, 32, 28, 1, 2, 8, 4, 1, 128, 1, 1, 1, 0, 0 },
- { 65792, 44, 24, 1, 4, 8, 4, 1, 8, 1, 1, 1, 0, 0 },
- { 23036, 56, 736, 1, 4, 4, 8, 1, 4, 1, 1, 1, 0, 1 },
- { 90968, 40, 600, 1, 4, 4, 8, 1, 4, 1, 1, 1, 0, 1 },
- { 8944, 32, 776, 1, 4, 4, 8, 1, 4, 1, 1, 1, 0, 1 },
- { 2688, 136, 1492, 1, 4, 4, 8, 1, 4, 1, 1, 1, 0, 1 },
- { 50176, 64, 300, 1, 4, 8, 4, 1, 8, 1, 1, 1, 0, 1 },
- { 16544, 104, 160, 1, 4, 4, 8, 1, 4, 1, 1, 1, 0, 1 },
- { 12604, 60, 160, 1, 4, 4, 8, 1, 4, 1, 1, 1, 0, 1 },
- { 3728, 96, 196, 1, 4, 4, 8, 1, 4, 1, 1, 1, 0, 1 },
- { 29584, 32, 28, 1, 2, 8, 4, 1, 16, 1, 1, 1, 0, 0 },
- { 12544, 32, 27, 1, 2, 8, 8, 1, 16, 1, 1, 1, 0, 0 },
- };
+ const GeMMConfigsMatrix configs_mnkb_n_gt_m_best = {{24, 488, 88, 1, 2, 2, 8, 1, 8, 1, 1, 1, 1, 0},
+ {49, 1024, 512, 1, 2, 4, 8, 1, 8, 1, 1, 1, 1, 0},
+ {49, 1024, 1024, 1, 2, 4, 8, 1, 4, 1, 1, 1, 1, 0}};
- const GeMMConfigsMatrix configs_mnkb_m_gt_n_fallback =
- {
- { 25584, 88, 16, 1, 4, 8, 4, 1, 4, 1, 1, 1, 0, 0 },
- { 25584, 16, 68, 1, 2, 4, 8, 1, 4, 1, 1, 1, 1, 0 },
- { 369664, 32, 28, 1, 2, 8, 4, 1, 128, 1, 1, 1, 0, 0 },
- { 65792, 44, 24, 1, 4, 8, 4, 1, 8, 1, 1, 1, 0, 0 },
- { 23036, 56, 736, 1, 4, 8, 4, 1, 16, 1, 1, 1, 0, 0 },
- { 90968, 40, 600, 1, 4, 4, 8, 1, 4, 1, 1, 1, 0, 0 },
- { 8944, 32, 776, 1, 2, 8, 8, 1, 16, 1, 1, 1, 0, 0 },
- { 2688, 136, 1492, 1, 4, 4, 8, 1, 8, 1, 1, 1, 0, 0 },
- { 50176, 64, 300, 1, 4, 8, 4, 1, 128, 1, 1, 1, 0, 0 },
- { 16544, 104, 160, 1, 4, 8, 4, 1, 16, 1, 1, 1, 0, 0 },
- { 12604, 60, 160, 1, 2, 8, 8, 1, 8, 1, 1, 1, 0, 0 },
- { 3728, 96, 196, 1, 2, 8, 8, 1, 64, 1, 1, 1, 0, 0 },
- { 29584, 32, 28, 1, 2, 8, 4, 1, 16, 1, 1, 1, 0, 0 },
- { 12544, 32, 27, 1, 2, 8, 8, 1, 16, 1, 1, 1, 0, 0 },
- };
+ const GeMMConfigsMatrix configs_mnkb_n_gt_m_fallback = {{24, 488, 88, 1, 2, 2, 8, 1, 8, 1, 1, 1, 1, 0},
+ {49, 1024, 512, 1, 2, 4, 8, 1, 8, 1, 1, 1, 1, 0},
+ {49, 1024, 1024, 1, 2, 4, 8, 1, 4, 1, 1, 1, 1, 0}};
- const GeMMConfigsMatrix configs_mnkb_n_gt_m_best =
- {
- { 24, 488, 88, 1, 2, 2, 8, 1, 8, 1, 1, 1, 1, 0 },
- { 49, 1024, 512, 1, 2, 4, 8, 1, 8, 1, 1, 1, 1, 0 },
- { 49, 1024, 1024, 1, 2, 4, 8, 1, 4, 1, 1, 1, 1, 0 }
+ const GeMMConfigsMatrix configs_mnkb_squared_best = {
+ {24, 88, 236, 1, 2, 2, 8, 1, 4, 1, 1, 1, 1, 0}, {24, 88, 488, 1, 2, 2, 8, 1, 4, 1, 1, 1, 1, 0},
+ {72, 92, 136, 1, 2, 2, 8, 1, 32, 1, 1, 1, 1, 0}, {268, 824, 5076, 1, 4, 4, 8, 1, 4, 1, 1, 1, 0, 1},
+ {180, 420, 952, 1, 4, 4, 8, 1, 16, 1, 1, 1, 0, 1}, {1000, 152, 304, 1, 4, 8, 4, 1, 32, 1, 1, 1, 0, 0},
+ {272, 400, 2116, 1, 4, 4, 8, 1, 4, 1, 1, 1, 0, 1}, {196, 512, 512, 1, 5, 2, 8, 1, 4, 1, 1, 1, 1, 1},
};
- const GeMMConfigsMatrix configs_mnkb_n_gt_m_fallback =
- {
- { 24, 488, 88, 1, 2, 2, 8, 1, 8, 1, 1, 1, 1, 0 },
- { 49, 1024, 512, 1, 2, 4, 8, 1, 8, 1, 1, 1, 1, 0 },
- { 49, 1024, 1024, 1, 2, 4, 8, 1, 4, 1, 1, 1, 1, 0 }
- };
-
- const GeMMConfigsMatrix configs_mnkb_squared_best =
- {
- { 24, 88, 236, 1, 2, 2, 8, 1, 4, 1, 1, 1, 1, 0 },
- { 24, 88, 488, 1, 2, 2, 8, 1, 4, 1, 1, 1, 1, 0 },
- { 72, 92, 136, 1, 2, 2, 8, 1, 32, 1, 1, 1, 1, 0 },
- { 268, 824, 5076, 1, 4, 4, 8, 1, 4, 1, 1, 1, 0, 1 },
- { 180, 420, 952, 1, 4, 4, 8, 1, 16, 1, 1, 1, 0, 1 },
- { 1000, 152, 304, 1, 4, 8, 4, 1, 32, 1, 1, 1, 0, 0 },
- { 272, 400, 2116, 1, 4, 4, 8, 1, 4, 1, 1, 1, 0, 1 },
- { 196, 512, 512, 1, 5, 2, 8, 1, 4, 1, 1, 1, 1, 1 },
+ const GeMMConfigsMatrix configs_mnkb_squared_fallback = {
+ {24, 88, 236, 1, 2, 2, 8, 1, 4, 1, 1, 1, 1, 0}, {24, 88, 488, 1, 2, 2, 8, 1, 4, 1, 1, 1, 1, 0},
+ {72, 92, 136, 1, 2, 2, 8, 1, 32, 1, 1, 1, 1, 0}, {268, 824, 5076, 1, 4, 8, 4, 1, 8, 1, 1, 1, 0, 0},
+ {180, 420, 952, 1, 5, 2, 8, 1, 8, 1, 1, 1, 1, 0}, {1000, 152, 304, 1, 4, 8, 4, 1, 32, 1, 1, 1, 0, 0},
+ {272, 400, 2116, 1, 2, 8, 4, 1, 4, 1, 1, 1, 0, 0}, {196, 512, 512, 1, 5, 2, 8, 1, 8, 1, 1, 1, 1, 0},
};
- const GeMMConfigsMatrix configs_mnkb_squared_fallback =
- {
- { 24, 88, 236, 1, 2, 2, 8, 1, 4, 1, 1, 1, 1, 0 },
- { 24, 88, 488, 1, 2, 2, 8, 1, 4, 1, 1, 1, 1, 0 },
- { 72, 92, 136, 1, 2, 2, 8, 1, 32, 1, 1, 1, 1, 0 },
- { 268, 824, 5076, 1, 4, 8, 4, 1, 8, 1, 1, 1, 0, 0 },
- { 180, 420, 952, 1, 5, 2, 8, 1, 8, 1, 1, 1, 1, 0 },
- { 1000, 152, 304, 1, 4, 8, 4, 1, 32, 1, 1, 1, 0, 0 },
- { 272, 400, 2116, 1, 2, 8, 4, 1, 4, 1, 1, 1, 0, 0 },
- { 196, 512, 512, 1, 5, 2, 8, 1, 8, 1, 1, 1, 1, 0 },
- };
+ const GeMMConfigsMatrix configs_mnkb_best_batched = {
+ {3136, 64, 64, 36, 4, 8, 4, 1, 16, 1, 1, 1, 0, 1}, {4096, 48, 32, 36, 4, 4, 8, 1, 4, 1, 1, 1, 0, 1},
+ {688, 92, 68, 32, 4, 8, 4, 1, 32, 1, 1, 1, 0, 1}, {24, 464, 412, 24, 4, 4, 8, 1, 4, 1, 1, 1, 0, 1},
+ {112, 184, 144, 28, 4, 4, 8, 1, 4, 1, 1, 1, 0, 1}, {5776, 64, 32, 36, 4, 4, 8, 1, 4, 1, 1, 1, 0, 1},
+ {1568, 64, 40, 36, 4, 8, 4, 1, 8, 1, 1, 1, 0, 1}, {2920, 64, 64, 24, 4, 8, 4, 1, 8, 1, 1, 1, 0, 1}};
- const GeMMConfigsMatrix configs_mnkb_best_batched =
- {
- { 3136, 64, 64, 36, 4, 8, 4, 1, 16, 1, 1, 1, 0, 1 },
- { 4096, 48, 32, 36, 4, 4, 8, 1, 4, 1, 1, 1, 0, 1 },
- { 688, 92, 68, 32, 4, 8, 4, 1, 32, 1, 1, 1, 0, 1 },
- { 24, 464, 412, 24, 4, 4, 8, 1, 4, 1, 1, 1, 0, 1 },
- { 112, 184, 144, 28, 4, 4, 8, 1, 4, 1, 1, 1, 0, 1 },
- { 5776, 64, 32, 36, 4, 4, 8, 1, 4, 1, 1, 1, 0, 1 },
- { 1568, 64, 40, 36, 4, 8, 4, 1, 8, 1, 1, 1, 0, 1 },
- { 2920, 64, 64, 24, 4, 8, 4, 1, 8, 1, 1, 1, 0, 1 }
- };
-
- const GeMMConfigsMatrix configs_mnkb_fallback_batched =
- {
- { 3136, 64, 64, 36, 4, 8, 4, 1, 8, 1, 1, 1, 0, 0 },
- { 4096, 48, 32, 36, 4, 4, 8, 1, 64, 1, 1, 1, 0, 0 },
- { 688, 92, 68, 32, 4, 8, 4, 1, 32, 1, 1, 1, 0, 0 },
- { 24, 464, 412, 24, 2, 8, 4, 1, 32, 1, 1, 1, 0, 0 },
- { 112, 184, 144, 28, 4, 4, 8, 1, 8, 1, 1, 1, 0, 0 },
- { 5776, 64, 32, 36, 2, 8, 8, 1, 32, 1, 1, 1, 0, 0 },
- { 1568, 64, 40, 36, 4, 8, 4, 1, 16, 1, 1, 1, 0, 0 },
- { 2920, 64, 64, 24, 4, 8, 4, 1, 8, 1, 1, 1, 0, 0 }
- };
+ const GeMMConfigsMatrix configs_mnkb_fallback_batched = {
+ {3136, 64, 64, 36, 4, 8, 4, 1, 8, 1, 1, 1, 0, 0}, {4096, 48, 32, 36, 4, 4, 8, 1, 64, 1, 1, 1, 0, 0},
+ {688, 92, 68, 32, 4, 8, 4, 1, 32, 1, 1, 1, 0, 0}, {24, 464, 412, 24, 2, 8, 4, 1, 32, 1, 1, 1, 0, 0},
+ {112, 184, 144, 28, 4, 4, 8, 1, 8, 1, 1, 1, 0, 0}, {5776, 64, 32, 36, 2, 8, 8, 1, 32, 1, 1, 1, 0, 0},
+ {1568, 64, 40, 36, 4, 8, 4, 1, 16, 1, 1, 1, 0, 0}, {2920, 64, 64, 24, 4, 8, 4, 1, 8, 1, 1, 1, 0, 0}};
const GeMMConfigsMatrix *configs_best_to_use = nullptr;
const GeMMConfigsMatrix *configs_fallback_to_use = nullptr;
- if(b == 1)
+ if (b == 1)
{
constexpr float ratio_m_gt_n = 10.f;
constexpr float ratio_n_gt_m = 0.1f;
constexpr unsigned int n_small_thr = 4;
const float ratio = static_cast<float>(m) / static_cast<float>(n);
- if(m == 1)
+ if (m == 1)
{
// We do not need fallback in this case, as we never use cl_image for the rhs tensor
configs_best_to_use = &configs_1nkb_best;
configs_fallback_to_use = &configs_1nkb_best;
}
- else if(n <= n_small_thr && ratio > ratio_m_gt_n)
+ else if (n <= n_small_thr && ratio > ratio_m_gt_n)
{
configs_best_to_use = &configs_mnkb_n_small_best;
configs_fallback_to_use = &configs_mnkb_n_small_best;
}
- else if(ratio > ratio_m_gt_n)
+ else if (ratio > ratio_m_gt_n)
{
configs_best_to_use = &configs_mnkb_m_gt_n_best;
configs_fallback_to_use = &configs_mnkb_m_gt_n_fallback;
}
- else if(ratio < ratio_n_gt_m)
+ else if (ratio < ratio_n_gt_m)
{
configs_best_to_use = &configs_mnkb_n_gt_m_best;
configs_fallback_to_use = &configs_mnkb_n_gt_m_fallback;
@@ -838,17 +730,17 @@ std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOn
std::tie(lhs_info0, rhs_info0) = find_lhs_rhs_info(*configs_best_to_use, m, n, k, b);
std::tie(lhs_info1, rhs_info1) = find_lhs_rhs_info(*configs_fallback_to_use, m, n, k, b);
- return select_lhs_rhs_info(std::make_pair(lhs_info0, rhs_info0),
- std::make_pair(lhs_info1, rhs_info1),
- n, k, b, DataType::F16);
+ return select_lhs_rhs_info(std::make_pair(lhs_info0, rhs_info0), std::make_pair(lhs_info1, rhs_info1), n, k, b,
+ DataType::F16);
}
-std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOnlyValhall::configure_G715_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b)
+std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOnlyValhall::configure_G715_f16(
+ unsigned int m, unsigned int n, unsigned int k, unsigned int b)
{
unsigned int best_m0;
unsigned int best_n0;
- if(is_mmul_kernel_preferred(m, n, k, b, DataType::F16, best_m0, best_n0))
+ if (is_mmul_kernel_preferred(m, n, k, b, DataType::F16, best_m0, best_n0))
{
return configure_lhs_rhs_info(m, n, best_m0, best_n0, 1, 1, 4, false, true, false, false, true);
}
diff --git a/src/gpu/cl/kernels/gemm/reshaped_only_rhs/ClGemmDefaultConfigReshapedRhsOnlyValhall.h b/src/gpu/cl/kernels/gemm/reshaped_only_rhs/ClGemmDefaultConfigReshapedRhsOnlyValhall.h
index f2952a3d30..a0ea337eb1 100644
--- a/src/gpu/cl/kernels/gemm/reshaped_only_rhs/ClGemmDefaultConfigReshapedRhsOnlyValhall.h
+++ b/src/gpu/cl/kernels/gemm/reshaped_only_rhs/ClGemmDefaultConfigReshapedRhsOnlyValhall.h
@@ -45,17 +45,26 @@ public:
ClGemmDefaultConfigReshapedRhsOnlyValhall(GPUTarget gpu);
// Inherited overridden method
- std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> configure(unsigned int m, unsigned int n, unsigned int k, unsigned int b, DataType data_type) override;
+ std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo>
+ configure(unsigned int m, unsigned int n, unsigned int k, unsigned int b, DataType data_type) override;
private:
- std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> configure_G77_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b);
- std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> configure_G77_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b);
- std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> configure_G78_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b);
- std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> configure_G78_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b);
- std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> configure_G77_u8(unsigned int m, unsigned int n, unsigned int k, unsigned int b);
- std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> configure_G710_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b);
- std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> configure_G715_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b);
- std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> configure_G715_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b);
+ std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo>
+ configure_G77_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b);
+ std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo>
+ configure_G77_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b);
+ std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo>
+ configure_G78_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b);
+ std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo>
+ configure_G78_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b);
+ std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo>
+ configure_G77_u8(unsigned int m, unsigned int n, unsigned int k, unsigned int b);
+ std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo>
+ configure_G710_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b);
+ std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo>
+ configure_G715_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b);
+ std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo>
+ configure_G715_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b);
};
} // namespace gemm
} // namespace kernels
diff --git a/src/gpu/cl/kernels/gemm/reshaped_only_rhs/ClGemmReshapedOnlyRhsKernelConfig.h b/src/gpu/cl/kernels/gemm/reshaped_only_rhs/ClGemmReshapedOnlyRhsKernelConfig.h
index 1503e74eb6..e07ad993ed 100644
--- a/src/gpu/cl/kernels/gemm/reshaped_only_rhs/ClGemmReshapedOnlyRhsKernelConfig.h
+++ b/src/gpu/cl/kernels/gemm/reshaped_only_rhs/ClGemmReshapedOnlyRhsKernelConfig.h
@@ -50,7 +50,7 @@ public:
*/
static std::unique_ptr<IClGemmKernelConfig> create(GPUTarget gpu)
{
- switch(get_arch_from_target(gpu))
+ switch (get_arch_from_target(gpu))
{
case GPUTarget::MIDGARD:
case GPUTarget::BIFROST: