aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorGian Marco Iodice <gianmarco.iodice@arm.com>2023-04-26 14:55:02 +0100
committerGian Marco Iodice <gianmarco.iodice@arm.com>2023-05-02 09:27:49 +0000
commit7a0f1bdaf74cde263b2919c7d1652b0cb87a94f3 (patch)
tree62886dac919eb95811efd76d907960dfddef0b61
parenta62129a02397ba87171ebf4477795f628dcec0f6 (diff)
downloadComputeLibrary-7a0f1bdaf74cde263b2919c7d1652b0cb87a94f3.tar.gz
Add fp16 GeMM heuristic for Arm® Mali™-G710
- Performance improvements on various networks between 5-20% Resolves COMPMID-6030 Signed-off-by: Gian Marco Iodice <gianmarco.iodice@arm.com> Change-Id: Idcf7de57e6f5a94a6a94ec78229dd53c24de44f4 Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/c/VisualCompute/ComputeLibrary/+/514481 Tested-by: Viet-Hoa Do <viet-hoa.do@arm.com> Reviewed-by: Viet-Hoa Do <viet-hoa.do@arm.com> Comments-Addressed: bsgcomp <bsgcomp@arm.com> Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/9524 Benchmark: Arm Jenkins <bsgcomp@arm.com> Tested-by: Arm Jenkins <bsgcomp@arm.com> Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
-rw-r--r--src/gpu/cl/kernels/gemm/ClGemmHelpers.cpp20
-rw-r--r--src/gpu/cl/kernels/gemm/reshaped_only_rhs/ClGemmDefaultConfigReshapedRhsOnlyValhall.cpp267
-rw-r--r--src/gpu/cl/kernels/gemm/reshaped_only_rhs/ClGemmDefaultConfigReshapedRhsOnlyValhall.h3
-rw-r--r--src/runtime/CL/gemm/CLGEMMDefaultTypeValhall.cpp34
-rw-r--r--src/runtime/CL/gemm/CLGEMMDefaultTypeValhall.h3
5 files changed, 280 insertions, 47 deletions
diff --git a/src/gpu/cl/kernels/gemm/ClGemmHelpers.cpp b/src/gpu/cl/kernels/gemm/ClGemmHelpers.cpp
index b97ffedfe5..9350bf74bb 100644
--- a/src/gpu/cl/kernels/gemm/ClGemmHelpers.cpp
+++ b/src/gpu/cl/kernels/gemm/ClGemmHelpers.cpp
@@ -143,7 +143,7 @@ bool is_mmul_kernel_preferred(const unsigned int m, const unsigned int n, const
std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> find_lhs_rhs_info(const GeMMConfigsMatrix &configs, unsigned int m, unsigned int n, unsigned int k, unsigned int b)
{
- float min_acc = std::numeric_limits<float>::max();
+ size_t min_acc = std::numeric_limits<size_t>::max();
size_t min_idx = 0;
ARM_COMPUTE_ERROR_ON(configs.size() == 0);
@@ -153,18 +153,20 @@ std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> find_lhs_rhs_info(const GeMMConf
ARM_COMPUTE_ERROR_ON_MSG(num_cols != 14U, "The entry should have 14 integer values representing: M, N, K, B, M0, N0. K0, V0, H0, INT_LHS, INT_RHS, TRA_LHS, TRA_RHS, IMG_RHS");
ARM_COMPUTE_UNUSED(num_cols);
- // Find nearest GeMM shape
+ // Find nearest GeMM workload
+ // Note: the workload does not depend on the K dimension
for(size_t y = 0; y < num_rows; ++y)
{
- float mc0 = configs[y][0];
- float nc0 = configs[y][1];
- float kc0 = configs[y][2];
- float bc0 = configs[y][3];
- float acc = 0;
+ size_t mc0 = static_cast<size_t>(configs[y][0]);
+ size_t nc0 = static_cast<size_t>(configs[y][1]);
+ size_t kc0 = static_cast<size_t>(configs[y][2]);
+ size_t bc0 = static_cast<size_t>(configs[y][3]);
+
+ size_t acc = 0;
acc += (m - mc0) * (m - mc0);
acc += (n - nc0) * (n - nc0);
- acc += (k - kc0) * (n - kc0);
- acc += (b - bc0) * (n - bc0);
+ acc += (k - kc0) * (k - kc0);
+ acc += (b - bc0) * (b - bc0);
acc = std::sqrt(acc);
if(acc < min_acc)
{
diff --git a/src/gpu/cl/kernels/gemm/reshaped_only_rhs/ClGemmDefaultConfigReshapedRhsOnlyValhall.cpp b/src/gpu/cl/kernels/gemm/reshaped_only_rhs/ClGemmDefaultConfigReshapedRhsOnlyValhall.cpp
index 5d666c03a5..76551b076a 100644
--- a/src/gpu/cl/kernels/gemm/reshaped_only_rhs/ClGemmDefaultConfigReshapedRhsOnlyValhall.cpp
+++ b/src/gpu/cl/kernels/gemm/reshaped_only_rhs/ClGemmDefaultConfigReshapedRhsOnlyValhall.cpp
@@ -63,6 +63,10 @@ std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOn
&ClGemmDefaultConfigReshapedRhsOnlyValhall::configure_G78_f16,
&ClGemmDefaultConfigReshapedRhsOnlyValhall::configure_G77_u8);
+ CLGEMMConfigArray<ConfigurationFunctionExecutorPtr> configs_G710(&ClGemmDefaultConfigReshapedRhsOnlyValhall::configure_G77_f32,
+ &ClGemmDefaultConfigReshapedRhsOnlyValhall::configure_G710_f16,
+ &ClGemmDefaultConfigReshapedRhsOnlyValhall::configure_G77_u8);
+
CLGEMMConfigArray<ConfigurationFunctionExecutorPtr> configs_G715(&ClGemmDefaultConfigReshapedRhsOnlyValhall::configure_G715_f32,
&ClGemmDefaultConfigReshapedRhsOnlyValhall::configure_G715_f16,
&ClGemmDefaultConfigReshapedRhsOnlyValhall::configure_G77_u8);
@@ -74,6 +78,10 @@ std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOn
case GPUTarget::G78:
func = configs_G78.get_function(data_type);
break;
+ case GPUTarget::G710:
+ case GPUTarget::G610:
+ func = configs_G710.get_function(data_type);
+ break;
case GPUTarget::G715:
case GPUTarget::G615:
func = configs_G715.get_function(data_type);
@@ -224,62 +232,78 @@ std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOn
{ 16384, 2, 128, 1, 1, 2, 16, 1, 0, 1, 1, 1, 1, 0 }
};
- const GeMMConfigsMatrix configs_mnkb_best =
+ const GeMMConfigsMatrix configs_mnkb_m_gt_n_best =
{
- { 24, 88, 236, 1, 2, 2, 8, 1, 64, 1, 1, 1, 1, 0 },
- { 24, 488, 88, 1, 2, 4, 16, 1, 4, 1, 1, 1, 0, 0 },
- { 24, 88, 488, 1, 2, 2, 8, 1, 64, 1, 1, 1, 1, 0 },
{ 25584, 88, 16, 1, 4, 8, 4, 1, 8, 1, 1, 1, 0, 0 },
{ 25584, 16, 68, 1, 4, 4, 8, 1, 16, 1, 1, 1, 0, 1 },
- { 72, 92, 136, 1, 2, 2, 8, 1, 128, 1, 1, 1, 1, 0 },
{ 369664, 32, 28, 1, 5, 4, 4, 1, 64, 1, 1, 1, 0, 1 },
{ 65792, 44, 24, 1, 4, 8, 4, 1, 128, 1, 1, 1, 0, 0 },
{ 23036, 56, 736, 1, 4, 4, 8, 1, 64, 1, 1, 1, 0, 1 },
{ 90968, 40, 600, 1, 4, 4, 8, 1, 64, 1, 1, 1, 0, 1 },
{ 8944, 32, 776, 1, 4, 4, 8, 1, 64, 1, 1, 1, 0, 1 },
- { 2688, 136, 1492, 1, 8, 4, 4, 1, 128, 1, 1, 1, 0, 0 },
{ 50176, 64, 300, 1, 4, 8, 4, 1, 128, 1, 1, 1, 0, 0 },
{ 16544, 104, 160, 1, 4, 4, 8, 1, 64, 1, 1, 1, 0, 1 },
- { 268, 824, 5076, 1, 4, 8, 4, 1, 256, 1, 1, 1, 0, 0 },
- { 180, 420, 952, 1, 4, 4, 8, 1, 64, 1, 1, 1, 0, 1 },
- { 1000, 152, 304, 1, 4, 4, 8, 1, 128, 1, 1, 1, 0, 0 },
{ 12604, 60, 160, 1, 4, 4, 8, 1, 64, 1, 1, 1, 0, 1 },
- { 3728, 96, 196, 1, 4, 8, 4, 1, 128, 1, 1, 1, 0, 0 },
- { 272, 400, 2116, 1, 4, 8, 4, 1, 64, 1, 1, 1, 0, 0 },
{ 29584, 32, 28, 1, 4, 4, 4, 1, 128, 1, 1, 1, 0, 0 },
{ 12544, 32, 27, 1, 2, 8, 8, 1, 128, 1, 1, 1, 0, 0 },
- { 196, 512, 512, 1, 5, 4, 4, 1, 64, 1, 1, 1, 0, 1 },
- { 49, 1024, 512, 1, 4, 4, 8, 1, 128, 1, 1, 1, 0, 1 },
- { 49, 1024, 1024, 1, 4, 4, 8, 1, 64, 1, 1, 1, 0, 1 }
+ { 2688, 136, 1492, 1, 8, 4, 4, 1, 128, 1, 1, 1, 0, 0 },
+ { 3728, 96, 196, 1, 4, 8, 4, 1, 128, 1, 1, 1, 0, 0 }
};
- const GeMMConfigsMatrix configs_mnkb_fallback =
+ const GeMMConfigsMatrix configs_mnkb_m_gt_n_fallback =
{
- { 24, 88, 236, 1, 2, 2, 8, 1, 64, 1, 1, 1, 1, 0 },
- { 24, 488, 88, 1, 2, 4, 16, 1, 4, 1, 1, 1, 0, 0 },
- { 24, 88, 488, 1, 2, 2, 8, 1, 64, 1, 1, 1, 1, 0 },
{ 25584, 88, 16, 1, 4, 8, 4, 1, 8, 1, 1, 1, 0, 0 },
{ 25584, 16, 68, 1, 2, 4, 8, 1, 4, 1, 1, 1, 0, 0 },
- { 72, 92, 136, 1, 2, 2, 8, 1, 128, 1, 1, 1, 1, 0 },
{ 369664, 32, 28, 1, 5, 4, 4, 1, 256, 1, 1, 1, 0, 0 },
{ 65792, 44, 24, 1, 4, 8, 4, 1, 128, 1, 1, 1, 0, 0 },
{ 23036, 56, 736, 1, 4, 8, 4, 1, 64, 1, 1, 1, 0, 0 },
{ 90968, 40, 600, 1, 4, 8, 4, 1, 64, 1, 1, 1, 0, 0 },
{ 8944, 32, 776, 1, 4, 4, 8, 1, 64, 1, 1, 1, 0, 0 },
- { 2688, 136, 1492, 1, 8, 4, 4, 1, 128, 1, 1, 1, 0, 0 },
{ 50176, 64, 300, 1, 4, 8, 4, 1, 128, 1, 1, 1, 0, 0 },
{ 16544, 104, 160, 1, 4, 4, 8, 1, 64, 1, 1, 1, 0, 0 },
+ { 12604, 60, 160, 1, 4, 4, 8, 1, 256, 1, 1, 1, 0, 0 },
+ { 29584, 32, 28, 1, 4, 4, 4, 1, 128, 1, 1, 1, 0, 0 },
+ { 12544, 32, 27, 1, 2, 8, 8, 1, 128, 1, 1, 1, 0, 0 },
+ { 2688, 136, 1492, 1, 8, 4, 4, 1, 128, 1, 1, 1, 0, 0 },
+ { 3728, 96, 196, 1, 4, 8, 4, 1, 128, 1, 1, 1, 0, 0 }
+ };
+
+ const GeMMConfigsMatrix configs_mnkb_n_gt_m_best =
+ {
+ { 24, 488, 88, 1, 2, 4, 16, 1, 4, 1, 1, 1, 0, 0 },
+ { 49, 1024, 512, 1, 4, 4, 8, 1, 128, 1, 1, 1, 0, 1 },
+ { 49, 1024, 1024, 1, 4, 4, 8, 1, 64, 1, 1, 1, 0, 1 },
+ };
+
+ const GeMMConfigsMatrix configs_mnkb_n_gt_m_fallback =
+ {
+ { 24, 488, 88, 1, 2, 4, 16, 1, 4, 1, 1, 1, 0, 0 },
+ { 49, 1024, 512, 1, 4, 4, 8, 1, 128, 1, 1, 1, 0, 0 },
+ { 49, 1024, 1024, 1, 4, 4, 8, 1, 256, 1, 1, 1, 0, 0 },
+ };
+
+ const GeMMConfigsMatrix configs_mnkb_squared_best =
+ {
+ { 72, 92, 136, 1, 2, 2, 8, 1, 128, 1, 1, 1, 1, 0 },
+ { 268, 824, 5076, 1, 4, 8, 4, 1, 256, 1, 1, 1, 0, 0 },
+ { 180, 420, 952, 1, 4, 4, 8, 1, 64, 1, 1, 1, 0, 1 },
+ { 1000, 152, 304, 1, 4, 4, 8, 1, 128, 1, 1, 1, 0, 0 },
+ { 272, 400, 2116, 1, 4, 8, 4, 1, 64, 1, 1, 1, 0, 0 },
+ { 196, 512, 512, 1, 5, 4, 4, 1, 64, 1, 1, 1, 0, 1 },
+ { 24, 88, 236, 1, 2, 2, 8, 1, 64, 1, 1, 1, 1, 0 },
+ { 24, 88, 488, 1, 2, 2, 8, 1, 64, 1, 1, 1, 1, 0 }
+ };
+
+ const GeMMConfigsMatrix configs_mnkb_squared_fallback =
+ {
+ { 72, 92, 136, 1, 2, 2, 8, 1, 128, 1, 1, 1, 1, 0 },
{ 268, 824, 5076, 1, 4, 8, 4, 1, 256, 1, 1, 1, 0, 0 },
{ 180, 420, 952, 1, 4, 4, 8, 1, 128, 1, 1, 1, 0, 0 },
{ 1000, 152, 304, 1, 4, 4, 8, 1, 128, 1, 1, 1, 0, 0 },
- { 12604, 60, 160, 1, 4, 4, 8, 1, 256, 1, 1, 1, 0, 0 },
- { 3728, 96, 196, 1, 4, 8, 4, 1, 128, 1, 1, 1, 0, 0 },
{ 272, 400, 2116, 1, 4, 8, 4, 1, 64, 1, 1, 1, 0, 0 },
- { 29584, 32, 28, 1, 4, 4, 4, 1, 128, 1, 1, 1, 0, 0 },
- { 12544, 32, 27, 1, 2, 8, 8, 1, 128, 1, 1, 1, 0, 0 },
{ 196, 512, 512, 1, 5, 4, 4, 1, 256, 1, 1, 1, 0, 0 },
- { 49, 1024, 512, 1, 4, 4, 8, 1, 128, 1, 1, 1, 0, 0 },
- { 49, 1024, 1024, 1, 4, 4, 8, 1, 256, 1, 1, 1, 0, 0 }
+ { 24, 88, 236, 1, 2, 2, 8, 1, 64, 1, 1, 1, 1, 0 },
+ { 24, 88, 488, 1, 2, 2, 8, 1, 64, 1, 1, 1, 1, 0 }
};
const GeMMConfigsMatrix configs_mnkb_best_batched =
@@ -312,6 +336,7 @@ std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOn
if(b == 1)
{
constexpr float ratio_m_gt_n = 10.f;
+ constexpr float ratio_n_gt_m = 0.1f;
constexpr unsigned int n_small_thr = 4;
const float ratio = static_cast<float>(m) / static_cast<float>(n);
@@ -326,10 +351,20 @@ std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOn
configs_best_to_use = &configs_mnkb_n_small_best;
configs_fallback_to_use = &configs_mnkb_n_small_fallback;
}
+ else if(ratio > ratio_m_gt_n)
+ {
+ configs_best_to_use = &configs_mnkb_m_gt_n_best;
+ configs_fallback_to_use = &configs_mnkb_m_gt_n_fallback;
+ }
+ else if(ratio < ratio_n_gt_m)
+ {
+ configs_best_to_use = &configs_mnkb_n_gt_m_best;
+ configs_fallback_to_use = &configs_mnkb_n_gt_m_fallback;
+ }
else
{
- configs_best_to_use = &configs_mnkb_best;
- configs_fallback_to_use = &configs_mnkb_fallback;
+ configs_best_to_use = &configs_mnkb_squared_best;
+ configs_fallback_to_use = &configs_mnkb_squared_fallback;
}
}
else
@@ -634,6 +669,182 @@ std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOn
}
}
+std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOnlyValhall::configure_G710_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b)
+{
+ const GeMMConfigsMatrix configs_1nkb_best =
+ {
+ { 1, 8984, 640, 1, 1, 2, 2, 1, 0, 1, 0, 1, 0, 0 },
+ { 1, 420, 392, 1, 1, 2, 8, 1, 0, 1, 0, 1, 0, 0 },
+ { 1, 644, 5288, 1, 1, 2, 8, 1, 0, 1, 0, 1, 0, 0 },
+ { 1, 6512, 6404, 1, 1, 2, 4, 1, 0, 1, 0, 1, 0, 0 },
+ { 1, 5304, 640, 1, 1, 2, 4, 1, 0, 1, 0, 1, 0, 0 },
+ { 1, 1352, 1520, 1, 1, 2, 4, 1, 0, 1, 0, 1, 0, 0 },
+ { 1, 4096, 25088, 1, 1, 2, 8, 1, 0, 1, 0, 1, 1, 0 },
+ { 1, 732, 8988, 1, 1, 2, 8, 1, 0, 1, 0, 1, 0, 0 }
+ };
+
+ const GeMMConfigsMatrix configs_mnkb_n_small_best =
+ {
+ { 102400, 4, 96, 1, 1, 2, 16, 1, 0, 1, 0, 1, 0, 0 },
+ { 102400, 2, 96, 1, 1, 2, 16, 1, 0, 1, 0, 1, 0, 0 },
+ { 16384, 4, 128, 1, 1, 2, 16, 1, 0, 1, 0, 1, 0, 0 },
+ { 16384, 2, 128, 1, 1, 2, 16, 1, 0, 1, 0, 1, 0, 0 }
+ };
+
+ const GeMMConfigsMatrix configs_mnkb_m_gt_n_best =
+ {
+ { 25584, 88, 16, 1, 4, 8, 4, 1, 4, 1, 1, 1, 0, 0 },
+ { 25584, 16, 68, 1, 2, 4, 16, 1, 8, 1, 1, 1, 0, 1 },
+ { 369664, 32, 28, 1, 2, 8, 4, 1, 128, 1, 1, 1, 0, 0 },
+ { 65792, 44, 24, 1, 4, 8, 4, 1, 8, 1, 1, 1, 0, 0 },
+ { 23036, 56, 736, 1, 4, 4, 8, 1, 4, 1, 1, 1, 0, 1 },
+ { 90968, 40, 600, 1, 4, 4, 8, 1, 4, 1, 1, 1, 0, 1 },
+ { 8944, 32, 776, 1, 4, 4, 8, 1, 4, 1, 1, 1, 0, 1 },
+ { 2688, 136, 1492, 1, 4, 4, 8, 1, 4, 1, 1, 1, 0, 1 },
+ { 50176, 64, 300, 1, 4, 8, 4, 1, 8, 1, 1, 1, 0, 1 },
+ { 16544, 104, 160, 1, 4, 4, 8, 1, 4, 1, 1, 1, 0, 1 },
+ { 12604, 60, 160, 1, 4, 4, 8, 1, 4, 1, 1, 1, 0, 1 },
+ { 3728, 96, 196, 1, 4, 4, 8, 1, 4, 1, 1, 1, 0, 1 },
+ { 29584, 32, 28, 1, 2, 8, 4, 1, 16, 1, 1, 1, 0, 0 },
+ { 12544, 32, 27, 1, 2, 8, 8, 1, 16, 1, 1, 1, 0, 0 },
+ };
+
+ const GeMMConfigsMatrix configs_mnkb_m_gt_n_fallback =
+ {
+ { 25584, 88, 16, 1, 4, 8, 4, 1, 4, 1, 1, 1, 0, 0 },
+ { 25584, 16, 68, 1, 2, 4, 8, 1, 4, 1, 1, 1, 1, 0 },
+ { 369664, 32, 28, 1, 2, 8, 4, 1, 128, 1, 1, 1, 0, 0 },
+ { 65792, 44, 24, 1, 4, 8, 4, 1, 8, 1, 1, 1, 0, 0 },
+ { 23036, 56, 736, 1, 4, 8, 4, 1, 16, 1, 1, 1, 0, 0 },
+ { 90968, 40, 600, 1, 4, 4, 8, 1, 4, 1, 1, 1, 0, 0 },
+ { 8944, 32, 776, 1, 2, 8, 8, 1, 16, 1, 1, 1, 0, 0 },
+ { 2688, 136, 1492, 1, 4, 4, 8, 1, 8, 1, 1, 1, 0, 0 },
+ { 50176, 64, 300, 1, 4, 8, 4, 1, 128, 1, 1, 1, 0, 0 },
+ { 16544, 104, 160, 1, 4, 8, 4, 1, 16, 1, 1, 1, 0, 0 },
+ { 12604, 60, 160, 1, 2, 8, 8, 1, 8, 1, 1, 1, 0, 0 },
+ { 3728, 96, 196, 1, 2, 8, 8, 1, 64, 1, 1, 1, 0, 0 },
+ { 29584, 32, 28, 1, 2, 8, 4, 1, 16, 1, 1, 1, 0, 0 },
+ { 12544, 32, 27, 1, 2, 8, 8, 1, 16, 1, 1, 1, 0, 0 },
+ };
+
+ const GeMMConfigsMatrix configs_mnkb_n_gt_m_best =
+ {
+ { 24, 488, 88, 1, 2, 2, 8, 1, 8, 1, 1, 1, 1, 0 },
+ { 49, 1024, 512, 1, 2, 4, 8, 1, 8, 1, 1, 1, 1, 0 },
+ { 49, 1024, 1024, 1, 2, 4, 8, 1, 4, 1, 1, 1, 1, 0 }
+ };
+
+ const GeMMConfigsMatrix configs_mnkb_n_gt_m_fallback =
+ {
+ { 24, 488, 88, 1, 2, 2, 8, 1, 8, 1, 1, 1, 1, 0 },
+ { 49, 1024, 512, 1, 2, 4, 8, 1, 8, 1, 1, 1, 1, 0 },
+ { 49, 1024, 1024, 1, 2, 4, 8, 1, 4, 1, 1, 1, 1, 0 }
+ };
+
+ const GeMMConfigsMatrix configs_mnkb_squared_best =
+ {
+ { 24, 88, 236, 1, 2, 2, 8, 1, 4, 1, 1, 1, 1, 0 },
+ { 24, 88, 488, 1, 2, 2, 8, 1, 4, 1, 1, 1, 1, 0 },
+ { 72, 92, 136, 1, 2, 2, 8, 1, 32, 1, 1, 1, 1, 0 },
+ { 268, 824, 5076, 1, 4, 4, 8, 1, 4, 1, 1, 1, 0, 1 },
+ { 180, 420, 952, 1, 4, 4, 8, 1, 16, 1, 1, 1, 0, 1 },
+ { 1000, 152, 304, 1, 4, 8, 4, 1, 32, 1, 1, 1, 0, 0 },
+ { 272, 400, 2116, 1, 4, 4, 8, 1, 4, 1, 1, 1, 0, 1 },
+ { 196, 512, 512, 1, 5, 2, 8, 1, 4, 1, 1, 1, 1, 1 },
+ };
+
+ const GeMMConfigsMatrix configs_mnkb_squared_fallback =
+ {
+ { 24, 88, 236, 1, 2, 2, 8, 1, 4, 1, 1, 1, 1, 0 },
+ { 24, 88, 488, 1, 2, 2, 8, 1, 4, 1, 1, 1, 1, 0 },
+ { 72, 92, 136, 1, 2, 2, 8, 1, 32, 1, 1, 1, 1, 0 },
+ { 268, 824, 5076, 1, 4, 8, 4, 1, 8, 1, 1, 1, 0, 0 },
+ { 180, 420, 952, 1, 5, 2, 8, 1, 8, 1, 1, 1, 1, 0 },
+ { 1000, 152, 304, 1, 4, 8, 4, 1, 32, 1, 1, 1, 0, 0 },
+ { 272, 400, 2116, 1, 2, 8, 4, 1, 4, 1, 1, 1, 0, 0 },
+ { 196, 512, 512, 1, 5, 2, 8, 1, 8, 1, 1, 1, 1, 0 },
+ };
+
+ const GeMMConfigsMatrix configs_mnkb_best_batched =
+ {
+ { 3136, 64, 64, 36, 4, 8, 4, 1, 16, 1, 1, 1, 0, 1 },
+ { 4096, 48, 32, 36, 4, 4, 8, 1, 4, 1, 1, 1, 0, 1 },
+ { 688, 92, 68, 32, 4, 8, 4, 1, 32, 1, 1, 1, 0, 1 },
+ { 24, 464, 412, 24, 4, 4, 8, 1, 4, 1, 1, 1, 0, 1 },
+ { 112, 184, 144, 28, 4, 4, 8, 1, 4, 1, 1, 1, 0, 1 },
+ { 5776, 64, 32, 36, 4, 4, 8, 1, 4, 1, 1, 1, 0, 1 },
+ { 1568, 64, 40, 36, 4, 8, 4, 1, 8, 1, 1, 1, 0, 1 },
+ { 2920, 64, 64, 24, 4, 8, 4, 1, 8, 1, 1, 1, 0, 1 }
+ };
+
+ const GeMMConfigsMatrix configs_mnkb_fallback_batched =
+ {
+ { 3136, 64, 64, 36, 4, 8, 4, 1, 8, 1, 1, 1, 0, 0 },
+ { 4096, 48, 32, 36, 4, 4, 8, 1, 64, 1, 1, 1, 0, 0 },
+ { 688, 92, 68, 32, 4, 8, 4, 1, 32, 1, 1, 1, 0, 0 },
+ { 24, 464, 412, 24, 2, 8, 4, 1, 32, 1, 1, 1, 0, 0 },
+ { 112, 184, 144, 28, 4, 4, 8, 1, 8, 1, 1, 1, 0, 0 },
+ { 5776, 64, 32, 36, 2, 8, 8, 1, 32, 1, 1, 1, 0, 0 },
+ { 1568, 64, 40, 36, 4, 8, 4, 1, 16, 1, 1, 1, 0, 0 },
+ { 2920, 64, 64, 24, 4, 8, 4, 1, 8, 1, 1, 1, 0, 0 }
+ };
+
+ const GeMMConfigsMatrix *configs_best_to_use = nullptr;
+ const GeMMConfigsMatrix *configs_fallback_to_use = nullptr;
+
+ if(b == 1)
+ {
+ constexpr float ratio_m_gt_n = 10.f;
+ constexpr float ratio_n_gt_m = 0.1f;
+ constexpr unsigned int n_small_thr = 4;
+ const float ratio = static_cast<float>(m) / static_cast<float>(n);
+
+ if(m == 1)
+ {
+ // We do not need fallback in this case, as we never use cl_image for the rhs tensor
+ configs_best_to_use = &configs_1nkb_best;
+ configs_fallback_to_use = &configs_1nkb_best;
+ }
+ else if(n <= n_small_thr && ratio > ratio_m_gt_n)
+ {
+ configs_best_to_use = &configs_mnkb_n_small_best;
+ configs_fallback_to_use = &configs_mnkb_n_small_best;
+ }
+ else if(ratio > ratio_m_gt_n)
+ {
+ configs_best_to_use = &configs_mnkb_m_gt_n_best;
+ configs_fallback_to_use = &configs_mnkb_m_gt_n_fallback;
+ }
+ else if(ratio < ratio_n_gt_m)
+ {
+ configs_best_to_use = &configs_mnkb_n_gt_m_best;
+ configs_fallback_to_use = &configs_mnkb_n_gt_m_fallback;
+ }
+ else
+ {
+ configs_best_to_use = &configs_mnkb_squared_best;
+ configs_fallback_to_use = &configs_mnkb_squared_fallback;
+ }
+ }
+ else
+ {
+ configs_best_to_use = &configs_mnkb_best_batched;
+ configs_fallback_to_use = &configs_mnkb_fallback_batched;
+ }
+
+ GEMMLHSMatrixInfo lhs_info0;
+ GEMMRHSMatrixInfo rhs_info0;
+ GEMMLHSMatrixInfo lhs_info1;
+ GEMMRHSMatrixInfo rhs_info1;
+
+ std::tie(lhs_info0, rhs_info0) = find_lhs_rhs_info(*configs_best_to_use, m, n, k, b);
+ std::tie(lhs_info1, rhs_info1) = find_lhs_rhs_info(*configs_fallback_to_use, m, n, k, b);
+
+ return select_lhs_rhs_info(std::make_pair(lhs_info0, rhs_info0),
+ std::make_pair(lhs_info1, rhs_info1),
+ n, k, b, DataType::F16);
+}
+
std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOnlyValhall::configure_G715_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b)
{
unsigned int best_m0;
diff --git a/src/gpu/cl/kernels/gemm/reshaped_only_rhs/ClGemmDefaultConfigReshapedRhsOnlyValhall.h b/src/gpu/cl/kernels/gemm/reshaped_only_rhs/ClGemmDefaultConfigReshapedRhsOnlyValhall.h
index 0ec068fffd..f2952a3d30 100644
--- a/src/gpu/cl/kernels/gemm/reshaped_only_rhs/ClGemmDefaultConfigReshapedRhsOnlyValhall.h
+++ b/src/gpu/cl/kernels/gemm/reshaped_only_rhs/ClGemmDefaultConfigReshapedRhsOnlyValhall.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2020-2022 Arm Limited.
+ * Copyright (c) 2020-2023 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -53,6 +53,7 @@ private:
std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> configure_G78_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b);
std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> configure_G78_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b);
std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> configure_G77_u8(unsigned int m, unsigned int n, unsigned int k, unsigned int b);
+ std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> configure_G710_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b);
std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> configure_G715_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b);
std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> configure_G715_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b);
};
diff --git a/src/runtime/CL/gemm/CLGEMMDefaultTypeValhall.cpp b/src/runtime/CL/gemm/CLGEMMDefaultTypeValhall.cpp
index 29d3177424..9e779d3752 100644
--- a/src/runtime/CL/gemm/CLGEMMDefaultTypeValhall.cpp
+++ b/src/runtime/CL/gemm/CLGEMMDefaultTypeValhall.cpp
@@ -79,6 +79,17 @@ CLGEMMKernelType CLGEMMDefaultTypeValhall::select_kernel(const CLGEMMKernelSelec
{ DataType::QSYMM8_PER_CHANNEL, &CLGEMMDefaultTypeValhall::default_q8 }
};
+ // Mali-G710 and Mali-G610 configurations
+ static std::map<DataType, FunctionExecutorPtr> gemm_g710_configs =
+ {
+ { DataType::F32, &CLGEMMDefaultTypeValhall::default_f32 },
+ { DataType::F16, &CLGEMMDefaultTypeValhall::g710_f16 },
+ { DataType::QASYMM8, &CLGEMMDefaultTypeValhall::default_q8 },
+ { DataType::QASYMM8_SIGNED, &CLGEMMDefaultTypeValhall::default_q8 },
+ { DataType::QSYMM8, &CLGEMMDefaultTypeValhall::default_q8 },
+ { DataType::QSYMM8_PER_CHANNEL, &CLGEMMDefaultTypeValhall::default_q8 }
+ };
+
// Mali-G715 and Mali-G615 configurations
static std::map<DataType, FunctionExecutorPtr> gemm_g715_configs =
{
@@ -94,6 +105,13 @@ CLGEMMKernelType CLGEMMDefaultTypeValhall::select_kernel(const CLGEMMKernelSelec
switch(_target)
{
+ case GPUTarget::G710:
+ case GPUTarget::G610:
+ if(gemm_g710_configs.find(data_type) != gemm_g710_configs.end())
+ {
+ return (this->*gemm_g710_configs[data_type])(params.m, params.n, params.k, params.b, params.is_rhs_constant);
+ }
+ ARM_COMPUTE_ERROR("Not supported data type");
case GPUTarget::G715:
case GPUTarget::G615:
if(gemm_g715_configs.find(data_type) != gemm_g715_configs.end())
@@ -140,14 +158,14 @@ CLGEMMKernelType CLGEMMDefaultTypeValhall::g77_f16(unsigned int m, unsigned int
{
ARM_COMPUTE_UNUSED(m, n, k, b);
- if(!is_rhs_constant)
- {
- return CLGEMMKernelType::NATIVE;
- }
- else
- {
- return CLGEMMKernelType::RESHAPED_ONLY_RHS;
- }
+ return is_rhs_constant ? CLGEMMKernelType::RESHAPED_ONLY_RHS : CLGEMMKernelType::NATIVE;
+}
+
+CLGEMMKernelType CLGEMMDefaultTypeValhall::g710_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b, bool is_rhs_constant)
+{
+ ARM_COMPUTE_UNUSED(m, n, k, b);
+
+ return is_rhs_constant ? CLGEMMKernelType::RESHAPED_ONLY_RHS : CLGEMMKernelType::NATIVE;
}
CLGEMMKernelType CLGEMMDefaultTypeValhall::default_q8(unsigned int m, unsigned int n, unsigned int k, unsigned int b, bool is_rhs_constant)
diff --git a/src/runtime/CL/gemm/CLGEMMDefaultTypeValhall.h b/src/runtime/CL/gemm/CLGEMMDefaultTypeValhall.h
index 0893f11132..e190295ee4 100644
--- a/src/runtime/CL/gemm/CLGEMMDefaultTypeValhall.h
+++ b/src/runtime/CL/gemm/CLGEMMDefaultTypeValhall.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2020-2022 Arm Limited.
+ * Copyright (c) 2020-2023 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -50,6 +50,7 @@ private:
CLGEMMKernelType g77_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b, bool is_rhs_constant);
CLGEMMKernelType g78_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b, bool is_rhs_constant);
CLGEMMKernelType g78_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b, bool is_rhs_constant);
+ CLGEMMKernelType g710_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b, bool is_rhs_constant);
CLGEMMKernelType g715_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b, bool is_rhs_constant);
CLGEMMKernelType g715_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b, bool is_rhs_constant);
};