aboutsummaryrefslogtreecommitdiff
path: root/src/gpu/cl/kernels/gemm/ClGemmHelpers.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/gpu/cl/kernels/gemm/ClGemmHelpers.cpp')
-rw-r--r--src/gpu/cl/kernels/gemm/ClGemmHelpers.cpp64
1 files changed, 62 insertions, 2 deletions
diff --git a/src/gpu/cl/kernels/gemm/ClGemmHelpers.cpp b/src/gpu/cl/kernels/gemm/ClGemmHelpers.cpp
index 9e7029a7ae..b97ffedfe5 100644
--- a/src/gpu/cl/kernels/gemm/ClGemmHelpers.cpp
+++ b/src/gpu/cl/kernels/gemm/ClGemmHelpers.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2019-2022 Arm Limited.
+ * Copyright (c) 2019-2023 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -28,6 +28,7 @@
#include "arm_compute/core/CL/OpenCL.h"
#include "arm_compute/core/utils/misc/ShapeCalculator.h"
+#include <limits>
#include <utility>
namespace arm_compute
@@ -42,8 +43,18 @@ std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> configure_lhs_rhs_info(unsigned
bool lhs_interleave, bool rhs_interleave, bool lhs_transpose, bool rhs_transpose, bool export_to_cl_image)
{
ARM_COMPUTE_ERROR_ON(m0 == 0 || n0 == 0);
+ ARM_COMPUTE_ERROR_ON(v0 == 0);
v0 = std::max(std::min(static_cast<int>(m / m0), static_cast<int>(v0)), static_cast<int>(1));
- h0 = std::max(std::min(static_cast<int>(n / n0), static_cast<int>(h0)), static_cast<int>(1));
+
+ if(h0 == 0)
+ {
+ // When h0 is 0, we should take the maximum H0 possible
+ h0 = std::max(n / n0, 1U);
+ }
+ else
+ {
+ h0 = std::max(std::min(static_cast<int>(n / n0), static_cast<int>(h0)), static_cast<int>(1));
+ }
const GEMMLHSMatrixInfo lhs_info(m0, k0, v0, lhs_transpose, lhs_interleave);
const GEMMRHSMatrixInfo rhs_info(n0, k0, h0, rhs_transpose, rhs_interleave, export_to_cl_image);
@@ -55,6 +66,8 @@ std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> select_lhs_rhs_info(std::pair<GE
std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> info_buf,
unsigned int n, unsigned int k, unsigned int b, DataType data_type)
{
+ ARM_COMPUTE_ERROR_ON_MSG(info_buf.second.export_to_cl_image == true, "The fallback GeMM configuration cannot have export_to_cl_image = true");
+
const TensorInfo tensor_rhs_info(TensorShape(n, k, b), 1, data_type);
const TensorShape shape = misc::shape_calculator::compute_rhs_reshaped_shape(tensor_rhs_info, info_img.second);
const TensorInfo tensor_reshaped_info(shape, 1, data_type);
@@ -127,6 +140,53 @@ bool is_mmul_kernel_preferred(const unsigned int m, const unsigned int n, const
return ((k % mmul_k0) == 0) && (gws_y > 4);
}
+
+std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> find_lhs_rhs_info(const GeMMConfigsMatrix &configs, unsigned int m, unsigned int n, unsigned int k, unsigned int b)
+{
+ float min_acc = std::numeric_limits<float>::max();
+ size_t min_idx = 0;
+
+ ARM_COMPUTE_ERROR_ON(configs.size() == 0);
+ const size_t num_rows = configs.size();
+ const size_t num_cols = configs[0].size();
+
+ ARM_COMPUTE_ERROR_ON_MSG(num_cols != 14U, "The entry should have 14 integer values representing: M, N, K, B, M0, N0. K0, V0, H0, INT_LHS, INT_RHS, TRA_LHS, TRA_RHS, IMG_RHS");
+ ARM_COMPUTE_UNUSED(num_cols);
+
+ // Find nearest GeMM shape
+ for(size_t y = 0; y < num_rows; ++y)
+ {
+ float mc0 = configs[y][0];
+ float nc0 = configs[y][1];
+ float kc0 = configs[y][2];
+ float bc0 = configs[y][3];
+ float acc = 0;
+ acc += (m - mc0) * (m - mc0);
+ acc += (n - nc0) * (n - nc0);
+ acc += (k - kc0) * (n - kc0);
+ acc += (b - bc0) * (n - bc0);
+ acc = std::sqrt(acc);
+ if(acc < min_acc)
+ {
+ min_acc = acc;
+ min_idx = y;
+ }
+ }
+
+ // Get the configuration from the nearest GeMM shape
+ const int m0 = configs[min_idx][4];
+ const int n0 = configs[min_idx][5];
+ const int k0 = configs[min_idx][6];
+ const int v0 = configs[min_idx][7];
+ const int h0 = configs[min_idx][8];
+ const int i_lhs = configs[min_idx][9];
+ const int i_rhs = configs[min_idx][10];
+ const int t_lhs = configs[min_idx][11];
+ const int t_rhs = configs[min_idx][12];
+ const int im_rhs = configs[min_idx][13];
+
+ return configure_lhs_rhs_info(m, n, m0, n0, k0, v0, h0, i_lhs, i_rhs, t_lhs, t_rhs, im_rhs);
+}
} // namespace gemm
} // namespace kernels
} // namespace opencl