aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorGian Marco Iodice <gianmarco.iodice@arm.com>2023-04-19 16:29:26 +0100
committerGian Marco Iodice <gianmarco.iodice@arm.com>2023-04-26 13:54:32 +0000
commitf16eed979ecaa234b308c8eb145c5f9512673a54 (patch)
tree68c730ab2dc3aede9501849f2ee7fcb6a230de1c
parent905a3c1a8883d988edf5bdc749844a4565fe5623 (diff)
downloadComputeLibrary-f16eed979ecaa234b308c8eb145c5f9512673a54.tar.gz
Change fp16 GeMM heuristic for Arm® Mali™-G77
- Replace existing heuristic with look-up tables - Expected performance improvement is between 5-15% on various models Signed-off-by: Gian Marco Iodice <gianmarco.iodice@arm.com> Change-Id: Ie26ddf66895ede131aa06fde7b200ef94d2dd467 Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/9472 Tested-by: Arm Jenkins <bsgcomp@arm.com> Reviewed-by: Viet-Hoa Do <viet-hoa.do@arm.com> Comments-Addressed: Arm Jenkins <bsgcomp@arm.com> Benchmark: Arm Jenkins <bsgcomp@arm.com>
-rw-r--r--src/gpu/cl/kernels/gemm/ClGemmHelpers.cpp64
-rw-r--r--src/gpu/cl/kernels/gemm/ClGemmHelpers.h16
-rw-r--r--src/gpu/cl/kernels/gemm/reshaped_only_rhs/ClGemmDefaultConfigReshapedRhsOnlyValhall.cpp181
-rw-r--r--src/runtime/CL/gemm/CLGEMMDefaultTypeValhall.cpp98
4 files changed, 225 insertions, 134 deletions
diff --git a/src/gpu/cl/kernels/gemm/ClGemmHelpers.cpp b/src/gpu/cl/kernels/gemm/ClGemmHelpers.cpp
index 9e7029a7ae..b97ffedfe5 100644
--- a/src/gpu/cl/kernels/gemm/ClGemmHelpers.cpp
+++ b/src/gpu/cl/kernels/gemm/ClGemmHelpers.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2019-2022 Arm Limited.
+ * Copyright (c) 2019-2023 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -28,6 +28,7 @@
#include "arm_compute/core/CL/OpenCL.h"
#include "arm_compute/core/utils/misc/ShapeCalculator.h"
+#include <limits>
#include <utility>
namespace arm_compute
@@ -42,8 +43,18 @@ std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> configure_lhs_rhs_info(unsigned
bool lhs_interleave, bool rhs_interleave, bool lhs_transpose, bool rhs_transpose, bool export_to_cl_image)
{
ARM_COMPUTE_ERROR_ON(m0 == 0 || n0 == 0);
+ ARM_COMPUTE_ERROR_ON(v0 == 0);
v0 = std::max(std::min(static_cast<int>(m / m0), static_cast<int>(v0)), static_cast<int>(1));
- h0 = std::max(std::min(static_cast<int>(n / n0), static_cast<int>(h0)), static_cast<int>(1));
+
+ if(h0 == 0)
+ {
+ // When h0 is 0, we should take the maximum H0 possible
+ h0 = std::max(n / n0, 1U);
+ }
+ else
+ {
+ h0 = std::max(std::min(static_cast<int>(n / n0), static_cast<int>(h0)), static_cast<int>(1));
+ }
const GEMMLHSMatrixInfo lhs_info(m0, k0, v0, lhs_transpose, lhs_interleave);
const GEMMRHSMatrixInfo rhs_info(n0, k0, h0, rhs_transpose, rhs_interleave, export_to_cl_image);
@@ -55,6 +66,8 @@ std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> select_lhs_rhs_info(std::pair<GE
std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> info_buf,
unsigned int n, unsigned int k, unsigned int b, DataType data_type)
{
+ ARM_COMPUTE_ERROR_ON_MSG(info_buf.second.export_to_cl_image == true, "The fallback GeMM configuration cannot have export_to_cl_image = true");
+
const TensorInfo tensor_rhs_info(TensorShape(n, k, b), 1, data_type);
const TensorShape shape = misc::shape_calculator::compute_rhs_reshaped_shape(tensor_rhs_info, info_img.second);
const TensorInfo tensor_reshaped_info(shape, 1, data_type);
@@ -127,6 +140,53 @@ bool is_mmul_kernel_preferred(const unsigned int m, const unsigned int n, const
return ((k % mmul_k0) == 0) && (gws_y > 4);
}
+
+std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> find_lhs_rhs_info(const GeMMConfigsMatrix &configs, unsigned int m, unsigned int n, unsigned int k, unsigned int b)
+{
+ float min_acc = std::numeric_limits<float>::max();
+ size_t min_idx = 0;
+
+ ARM_COMPUTE_ERROR_ON(configs.size() == 0);
+ const size_t num_rows = configs.size();
+ const size_t num_cols = configs[0].size();
+
+ ARM_COMPUTE_ERROR_ON_MSG(num_cols != 14U, "The entry should have 14 integer values representing: M, N, K, B, M0, N0. K0, V0, H0, INT_LHS, INT_RHS, TRA_LHS, TRA_RHS, IMG_RHS");
+ ARM_COMPUTE_UNUSED(num_cols);
+
+ // Find nearest GeMM shape
+ for(size_t y = 0; y < num_rows; ++y)
+ {
+ float mc0 = configs[y][0];
+ float nc0 = configs[y][1];
+ float kc0 = configs[y][2];
+ float bc0 = configs[y][3];
+ float acc = 0;
+ acc += (m - mc0) * (m - mc0);
+ acc += (n - nc0) * (n - nc0);
+ acc += (k - kc0) * (n - kc0);
+ acc += (b - bc0) * (n - bc0);
+ acc = std::sqrt(acc);
+ if(acc < min_acc)
+ {
+ min_acc = acc;
+ min_idx = y;
+ }
+ }
+
+ // Get the configuration from the nearest GeMM shape
+ const int m0 = configs[min_idx][4];
+ const int n0 = configs[min_idx][5];
+ const int k0 = configs[min_idx][6];
+ const int v0 = configs[min_idx][7];
+ const int h0 = configs[min_idx][8];
+ const int i_lhs = configs[min_idx][9];
+ const int i_rhs = configs[min_idx][10];
+ const int t_lhs = configs[min_idx][11];
+ const int t_rhs = configs[min_idx][12];
+ const int im_rhs = configs[min_idx][13];
+
+ return configure_lhs_rhs_info(m, n, m0, n0, k0, v0, h0, i_lhs, i_rhs, t_lhs, t_rhs, im_rhs);
+}
} // namespace gemm
} // namespace kernels
} // namespace opencl
diff --git a/src/gpu/cl/kernels/gemm/ClGemmHelpers.h b/src/gpu/cl/kernels/gemm/ClGemmHelpers.h
index bf1e8fce82..6689b10e69 100644
--- a/src/gpu/cl/kernels/gemm/ClGemmHelpers.h
+++ b/src/gpu/cl/kernels/gemm/ClGemmHelpers.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2019-2022 Arm Limited.
+ * Copyright (c) 2019-2023 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -35,6 +35,8 @@ namespace kernels
{
namespace gemm
{
+using GeMMConfigsMatrix = std::vector<std::vector<int32_t>>;
+
/** Configure @ref GEMMLHSMatrixInfo and @ref GEMMRHSMatrixInfo
*
* @param[in] m Number of rows (M) in the LHS matrix not reshaped
@@ -103,6 +105,18 @@ Status validate_image2d_support_on_rhs(const ITensorInfo &tensor_reshaped_info,
*/
bool is_mmul_kernel_preferred(const unsigned int m, const unsigned int n, const unsigned int k, const unsigned int b,
const DataType data_type, unsigned int &best_m0, unsigned int &best_n0);
+
+/** Find the preferred configurations for the LHS and RHS tensor using the GeMMConfigsMatrix provided by the user
+ *
+ * @param[in] configs List of best configurations for a limited number of GeMM shapes
+ * @param[in] m Number of rows of the LHS matrix
+ * @param[in] n Number of columns of the RHS matrix
+ * @param[in] k Number of columns of the LHS matrix, rows of the RHS matrix
+ * @param[in] b Batch size
+ *
+ * @return @ref GEMMLHSMatrixInfo and @ref GEMMRHSMatrixInfo
+ */
+std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> find_lhs_rhs_info(const GeMMConfigsMatrix &configs, unsigned int m, unsigned int n, unsigned int k, unsigned int b);
} // namespace gemm
} // namespace kernels
} // namespace opencl
diff --git a/src/gpu/cl/kernels/gemm/reshaped_only_rhs/ClGemmDefaultConfigReshapedRhsOnlyValhall.cpp b/src/gpu/cl/kernels/gemm/reshaped_only_rhs/ClGemmDefaultConfigReshapedRhsOnlyValhall.cpp
index e4e35cb8ce..5d666c03a5 100644
--- a/src/gpu/cl/kernels/gemm/reshaped_only_rhs/ClGemmDefaultConfigReshapedRhsOnlyValhall.cpp
+++ b/src/gpu/cl/kernels/gemm/reshaped_only_rhs/ClGemmDefaultConfigReshapedRhsOnlyValhall.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2020-2022 Arm Limited.
+ * Copyright (c) 2020-2023 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -196,52 +196,159 @@ std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOn
std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOnlyValhall::configure_G77_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b)
{
- ARM_COMPUTE_UNUSED(k);
- ARM_COMPUTE_UNUSED(b);
+ const GeMMConfigsMatrix configs_1nkb_best =
+ {
+ { 1, 8984, 640, 1, 1, 8, 8, 1, 0, 1, 1, 1, 1, 0 },
+ { 1, 420, 392, 1, 1, 2, 8, 1, 0, 1, 0, 1, 0, 0 },
+ { 1, 644, 5288, 1, 1, 2, 8, 1, 0, 1, 0, 1, 0, 0 },
+ { 1, 6512, 6404, 1, 1, 4, 8, 1, 0, 1, 0, 1, 0, 0 },
+ { 1, 5304, 640, 1, 1, 4, 4, 1, 0, 1, 0, 1, 1, 0 },
+ { 1, 1352, 1520, 1, 1, 2, 8, 1, 0, 1, 1, 1, 1, 0 },
+ { 1, 4096, 25088, 1, 1, 2, 16, 1, 0, 1, 0, 1, 0, 0 },
+ { 1, 732, 8988, 1, 1, 2, 8, 1, 0, 1, 0, 1, 0, 0 }
+ };
+
+ const GeMMConfigsMatrix configs_mnkb_n_small_best =
+ {
+ { 102400, 4, 96, 1, 2, 2, 16, 1, 4, 1, 1, 1, 1, 0 },
+ { 102400, 2, 96, 1, 1, 2, 16, 1, 0, 1, 0, 1, 1, 1 },
+ { 16384, 4, 128, 1, 1, 2, 16, 1, 0, 1, 0, 1, 1, 1 },
+ { 16384, 2, 128, 1, 1, 2, 16, 1, 0, 1, 1, 1, 1, 1 }
+ };
- if(m == 1)
+ const GeMMConfigsMatrix configs_mnkb_n_small_fallback =
{
- const unsigned int h0 = std::max(n / 2, 1U);
- if(n <= 836.0)
- {
- return configure_lhs_rhs_info(m, n, 1, 2, 16, 1, h0, 0, 1, 0, 1, 0);
- }
- else
+ { 102400, 4, 96, 1, 2, 2, 16, 1, 4, 1, 1, 1, 1, 0 },
+ { 102400, 2, 96, 1, 1, 2, 16, 1, 0, 1, 1, 1, 1, 0 },
+ { 16384, 4, 128, 1, 2, 2, 16, 1, 2, 1, 1, 1, 1, 0 },
+ { 16384, 2, 128, 1, 1, 2, 16, 1, 0, 1, 1, 1, 1, 0 }
+ };
+
+ const GeMMConfigsMatrix configs_mnkb_best =
+ {
+ { 24, 88, 236, 1, 2, 2, 8, 1, 64, 1, 1, 1, 1, 0 },
+ { 24, 488, 88, 1, 2, 4, 16, 1, 4, 1, 1, 1, 0, 0 },
+ { 24, 88, 488, 1, 2, 2, 8, 1, 64, 1, 1, 1, 1, 0 },
+ { 25584, 88, 16, 1, 4, 8, 4, 1, 8, 1, 1, 1, 0, 0 },
+ { 25584, 16, 68, 1, 4, 4, 8, 1, 16, 1, 1, 1, 0, 1 },
+ { 72, 92, 136, 1, 2, 2, 8, 1, 128, 1, 1, 1, 1, 0 },
+ { 369664, 32, 28, 1, 5, 4, 4, 1, 64, 1, 1, 1, 0, 1 },
+ { 65792, 44, 24, 1, 4, 8, 4, 1, 128, 1, 1, 1, 0, 0 },
+ { 23036, 56, 736, 1, 4, 4, 8, 1, 64, 1, 1, 1, 0, 1 },
+ { 90968, 40, 600, 1, 4, 4, 8, 1, 64, 1, 1, 1, 0, 1 },
+ { 8944, 32, 776, 1, 4, 4, 8, 1, 64, 1, 1, 1, 0, 1 },
+ { 2688, 136, 1492, 1, 8, 4, 4, 1, 128, 1, 1, 1, 0, 0 },
+ { 50176, 64, 300, 1, 4, 8, 4, 1, 128, 1, 1, 1, 0, 0 },
+ { 16544, 104, 160, 1, 4, 4, 8, 1, 64, 1, 1, 1, 0, 1 },
+ { 268, 824, 5076, 1, 4, 8, 4, 1, 256, 1, 1, 1, 0, 0 },
+ { 180, 420, 952, 1, 4, 4, 8, 1, 64, 1, 1, 1, 0, 1 },
+ { 1000, 152, 304, 1, 4, 4, 8, 1, 128, 1, 1, 1, 0, 0 },
+ { 12604, 60, 160, 1, 4, 4, 8, 1, 64, 1, 1, 1, 0, 1 },
+ { 3728, 96, 196, 1, 4, 8, 4, 1, 128, 1, 1, 1, 0, 0 },
+ { 272, 400, 2116, 1, 4, 8, 4, 1, 64, 1, 1, 1, 0, 0 },
+ { 29584, 32, 28, 1, 4, 4, 4, 1, 128, 1, 1, 1, 0, 0 },
+ { 12544, 32, 27, 1, 2, 8, 8, 1, 128, 1, 1, 1, 0, 0 },
+ { 196, 512, 512, 1, 5, 4, 4, 1, 64, 1, 1, 1, 0, 1 },
+ { 49, 1024, 512, 1, 4, 4, 8, 1, 128, 1, 1, 1, 0, 1 },
+ { 49, 1024, 1024, 1, 4, 4, 8, 1, 64, 1, 1, 1, 0, 1 }
+ };
+
+ const GeMMConfigsMatrix configs_mnkb_fallback =
+ {
+ { 24, 88, 236, 1, 2, 2, 8, 1, 64, 1, 1, 1, 1, 0 },
+ { 24, 488, 88, 1, 2, 4, 16, 1, 4, 1, 1, 1, 0, 0 },
+ { 24, 88, 488, 1, 2, 2, 8, 1, 64, 1, 1, 1, 1, 0 },
+ { 25584, 88, 16, 1, 4, 8, 4, 1, 8, 1, 1, 1, 0, 0 },
+ { 25584, 16, 68, 1, 2, 4, 8, 1, 4, 1, 1, 1, 0, 0 },
+ { 72, 92, 136, 1, 2, 2, 8, 1, 128, 1, 1, 1, 1, 0 },
+ { 369664, 32, 28, 1, 5, 4, 4, 1, 256, 1, 1, 1, 0, 0 },
+ { 65792, 44, 24, 1, 4, 8, 4, 1, 128, 1, 1, 1, 0, 0 },
+ { 23036, 56, 736, 1, 4, 8, 4, 1, 64, 1, 1, 1, 0, 0 },
+ { 90968, 40, 600, 1, 4, 8, 4, 1, 64, 1, 1, 1, 0, 0 },
+ { 8944, 32, 776, 1, 4, 4, 8, 1, 64, 1, 1, 1, 0, 0 },
+ { 2688, 136, 1492, 1, 8, 4, 4, 1, 128, 1, 1, 1, 0, 0 },
+ { 50176, 64, 300, 1, 4, 8, 4, 1, 128, 1, 1, 1, 0, 0 },
+ { 16544, 104, 160, 1, 4, 4, 8, 1, 64, 1, 1, 1, 0, 0 },
+ { 268, 824, 5076, 1, 4, 8, 4, 1, 256, 1, 1, 1, 0, 0 },
+ { 180, 420, 952, 1, 4, 4, 8, 1, 128, 1, 1, 1, 0, 0 },
+ { 1000, 152, 304, 1, 4, 4, 8, 1, 128, 1, 1, 1, 0, 0 },
+ { 12604, 60, 160, 1, 4, 4, 8, 1, 256, 1, 1, 1, 0, 0 },
+ { 3728, 96, 196, 1, 4, 8, 4, 1, 128, 1, 1, 1, 0, 0 },
+ { 272, 400, 2116, 1, 4, 8, 4, 1, 64, 1, 1, 1, 0, 0 },
+ { 29584, 32, 28, 1, 4, 4, 4, 1, 128, 1, 1, 1, 0, 0 },
+ { 12544, 32, 27, 1, 2, 8, 8, 1, 128, 1, 1, 1, 0, 0 },
+ { 196, 512, 512, 1, 5, 4, 4, 1, 256, 1, 1, 1, 0, 0 },
+ { 49, 1024, 512, 1, 4, 4, 8, 1, 128, 1, 1, 1, 0, 0 },
+ { 49, 1024, 1024, 1, 4, 4, 8, 1, 256, 1, 1, 1, 0, 0 }
+ };
+
+ const GeMMConfigsMatrix configs_mnkb_best_batched =
+ {
+ { 3136, 64, 64, 36, 4, 8, 4, 1, 64, 1, 1, 1, 0, 0 },
+ { 4096, 48, 32, 36, 4, 4, 8, 1, 64, 1, 1, 1, 0, 1 },
+ { 688, 92, 68, 32, 4, 8, 4, 1, 64, 1, 1, 1, 0, 0 },
+ { 24, 464, 412, 24, 4, 4, 8, 1, 128, 1, 1, 1, 0, 0 },
+ { 112, 184, 144, 28, 4, 8, 4, 1, 64, 1, 1, 1, 0, 0 },
+ { 5776, 64, 32, 36, 4, 8, 4, 1, 64, 1, 1, 1, 0, 0 },
+ { 1568, 64, 40, 36, 4, 8, 4, 1, 64, 1, 1, 1, 0, 0 },
+ { 2920, 64, 64, 24, 4, 8, 4, 1, 64, 1, 1, 1, 0, 0 }
+ };
+
+ const GeMMConfigsMatrix configs_mnkb_fallback_batched =
+ {
+ { 3136, 64, 64, 36, 4, 8, 4, 1, 64, 1, 1, 1, 0, 0 },
+ { 4096, 48, 32, 36, 4, 4, 8, 1, 128, 1, 1, 1, 0, 0 },
+ { 688, 92, 68, 32, 4, 8, 4, 1, 64, 1, 1, 1, 0, 0 },
+ { 24, 464, 412, 24, 4, 4, 8, 1, 128, 1, 1, 1, 0, 0 },
+ { 112, 184, 144, 28, 4, 8, 4, 1, 64, 1, 1, 1, 0, 0 },
+ { 5776, 64, 32, 36, 4, 8, 4, 1, 64, 1, 1, 1, 0, 0 },
+ { 1568, 64, 40, 36, 4, 8, 4, 1, 64, 1, 1, 1, 0, 0 },
+ { 2920, 64, 64, 24, 4, 8, 4, 1, 64, 1, 1, 1, 0, 0 }
+ };
+
+ const GeMMConfigsMatrix *configs_best_to_use = nullptr;
+ const GeMMConfigsMatrix *configs_fallback_to_use = nullptr;
+
+ if(b == 1)
+ {
+ constexpr float ratio_m_gt_n = 10.f;
+ constexpr unsigned int n_small_thr = 4;
+ const float ratio = static_cast<float>(m) / static_cast<float>(n);
+
+ if(m == 1)
{
- return configure_lhs_rhs_info(m, n, 1, 2, 8, 1, h0, 0, 1, 0, 1, 0);
+ // We do not need fallback in this case, as we never use cl_image for the rhs tensor
+ configs_best_to_use = &configs_1nkb_best;
+ configs_fallback_to_use = &configs_1nkb_best;
}
- }
- else if(m < 128)
- {
- const int h0 = std::max(std::min(static_cast<int>(n / 4), static_cast<int>(256)), static_cast<int>(1));
- if(k >= 512)
+ else if(n <= n_small_thr && ratio > ratio_m_gt_n)
{
- return configure_lhs_rhs_info(m, n, 2, 4, 16, 1, h0, 0, 1, 0, 0);
+ configs_best_to_use = &configs_mnkb_n_small_best;
+ configs_fallback_to_use = &configs_mnkb_n_small_fallback;
}
else
{
- return configure_lhs_rhs_info(m, n, 2, 4, 8, 1, h0, 0, 1, 0, 0);
+ configs_best_to_use = &configs_mnkb_best;
+ configs_fallback_to_use = &configs_mnkb_fallback;
}
}
else
{
- const int h0 = std::max(std::min(static_cast<int>(n / 4), static_cast<int>(256)), static_cast<int>(1));
- if(n >= 64)
- {
- return configure_lhs_rhs_info(m, n, 4, 8, 4, 1, h0, 0, 1, 0, 0);
- }
- else
- {
- if(k >= 512)
- {
- return configure_lhs_rhs_info(m, n, 2, 4, 16, 1, h0, 0, 1, 0, 0);
- }
- else
- {
- return configure_lhs_rhs_info(m, n, 2, 4, 8, 1, h0, 0, 1, 0, 0);
- }
- }
+ configs_best_to_use = &configs_mnkb_best_batched;
+ configs_fallback_to_use = &configs_mnkb_fallback_batched;
}
+
+ GEMMLHSMatrixInfo lhs_info0;
+ GEMMRHSMatrixInfo rhs_info0;
+ GEMMLHSMatrixInfo lhs_info1;
+ GEMMRHSMatrixInfo rhs_info1;
+
+ std::tie(lhs_info0, rhs_info0) = find_lhs_rhs_info(*configs_best_to_use, m, n, k, b);
+ std::tie(lhs_info1, rhs_info1) = find_lhs_rhs_info(*configs_fallback_to_use, m, n, k, b);
+
+ return select_lhs_rhs_info(std::make_pair(lhs_info0, rhs_info0),
+ std::make_pair(lhs_info1, rhs_info1),
+ n, k, b, DataType::F16);
}
std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOnlyValhall::configure_G77_u8(unsigned int m, unsigned int n, unsigned int k, unsigned int b)
@@ -414,9 +521,9 @@ std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOn
std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOnlyValhall::configure_G78_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b)
{
const float workload = (static_cast<float>(m) * static_cast<float>(n) * static_cast<float>(b)) / 20.0f;
- const float r_mn = static_cast<float>(m) / static_cast<float>(n);
- const float r_mk = static_cast<float>(m) / static_cast<float>(k);
- const float r_nk = static_cast<float>(n) / static_cast<float>(k);
+ const float r_mn = static_cast<float>(m) / static_cast<float>(n);
+ const float r_mk = static_cast<float>(m) / static_cast<float>(k);
+ const float r_nk = static_cast<float>(n) / static_cast<float>(k);
if(m == 1)
{
diff --git a/src/runtime/CL/gemm/CLGEMMDefaultTypeValhall.cpp b/src/runtime/CL/gemm/CLGEMMDefaultTypeValhall.cpp
index 4c7daf916e..29d3177424 100644
--- a/src/runtime/CL/gemm/CLGEMMDefaultTypeValhall.cpp
+++ b/src/runtime/CL/gemm/CLGEMMDefaultTypeValhall.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2020-2022 Arm Limited.
+ * Copyright (c) 2020-2023 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -138,105 +138,15 @@ CLGEMMKernelType CLGEMMDefaultTypeValhall::default_f16(unsigned int m, unsigned
CLGEMMKernelType CLGEMMDefaultTypeValhall::g77_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b, bool is_rhs_constant)
{
+ ARM_COMPUTE_UNUSED(m, n, k, b);
+
if(!is_rhs_constant)
{
return CLGEMMKernelType::NATIVE;
}
-
- if(m == 1)
- {
- return CLGEMMKernelType::RESHAPED_ONLY_RHS;
- }
-
- const float r_mn = static_cast<float>(m) / static_cast<float>(n);
- const float r_mk = static_cast<float>(m) / static_cast<float>(k);
- const float r_nk = static_cast<float>(n) / static_cast<float>(k);
- const float workload = (static_cast<float>(m) * static_cast<float>(n) * static_cast<float>(b)) / 20.0f;
-
- if(r_mk <= 0.6817956566810608)
- {
- if(workload <= 801.6000061035156)
- {
- return CLGEMMKernelType::RESHAPED_ONLY_RHS;
- }
- else
- {
- if(r_mn <= 0.0839829258620739)
- {
- return CLGEMMKernelType::RESHAPED_ONLY_RHS;
- }
- else
- {
- if(r_mk <= 0.24917218834161758)
- {
- return CLGEMMKernelType::RESHAPED;
- }
- else
- {
- if(workload <= 2551.75)
- {
- return CLGEMMKernelType::RESHAPED_ONLY_RHS;
- }
- else
- {
- if(workload <= 5061.574951171875)
- {
- return CLGEMMKernelType::RESHAPED_ONLY_RHS;
- }
- else
- {
- return CLGEMMKernelType::RESHAPED;
- }
- }
- }
- }
- }
- }
else
{
- if(r_mk <= 4.849947690963745)
- {
- if(workload <= 17618.4501953125)
- {
- if(workload <= 5224.699951171875)
- {
- return CLGEMMKernelType::RESHAPED_ONLY_RHS;
- }
- else
- {
- if(r_nk <= 0.7933054566383362)
- {
- return CLGEMMKernelType::RESHAPED;
- }
- else
- {
- return CLGEMMKernelType::RESHAPED_ONLY_RHS;
- }
- }
- }
- else
- {
- if(workload <= 20275.2001953125)
- {
- return CLGEMMKernelType::RESHAPED;
- }
- else
- {
- if(r_mk <= 3.07421875)
- {
- return CLGEMMKernelType::RESHAPED_ONLY_RHS;
- }
- else
- {
- return CLGEMMKernelType::RESHAPED;
- }
- }
- }
- }
- else
- {
- return CLGEMMKernelType::RESHAPED_ONLY_RHS;
- }
+ return CLGEMMKernelType::RESHAPED_ONLY_RHS;
}
}