aboutsummaryrefslogtreecommitdiff
path: root/src/gpu/cl/kernels/gemm/ClGemmHelpers.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/gpu/cl/kernels/gemm/ClGemmHelpers.cpp')
-rw-r--r-- src/gpu/cl/kernels/gemm/ClGemmHelpers.cpp 20
1 file changed, 11 insertions, 9 deletions
diff --git a/src/gpu/cl/kernels/gemm/ClGemmHelpers.cpp b/src/gpu/cl/kernels/gemm/ClGemmHelpers.cpp
index b97ffedfe5..9350bf74bb 100644
--- a/src/gpu/cl/kernels/gemm/ClGemmHelpers.cpp
+++ b/src/gpu/cl/kernels/gemm/ClGemmHelpers.cpp
@@ -143,7 +143,7 @@ bool is_mmul_kernel_preferred(const unsigned int m, const unsigned int n, const
std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> find_lhs_rhs_info(const GeMMConfigsMatrix &configs, unsigned int m, unsigned int n, unsigned int k, unsigned int b)
{
- float min_acc = std::numeric_limits<float>::max();
+ size_t min_acc = std::numeric_limits<size_t>::max();
size_t min_idx = 0;
ARM_COMPUTE_ERROR_ON(configs.size() == 0);
@@ -153,18 +153,20 @@ std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> find_lhs_rhs_info(const GeMMConf
ARM_COMPUTE_ERROR_ON_MSG(num_cols != 14U, "The entry should have 14 integer values representing: M, N, K, B, M0, N0. K0, V0, H0, INT_LHS, INT_RHS, TRA_LHS, TRA_RHS, IMG_RHS");
ARM_COMPUTE_UNUSED(num_cols);
- // Find nearest GeMM shape
+ // Find nearest GeMM workload
+ // Note: the workload does not depend on the K dimension
for(size_t y = 0; y < num_rows; ++y)
{
- float mc0 = configs[y][0];
- float nc0 = configs[y][1];
- float kc0 = configs[y][2];
- float bc0 = configs[y][3];
- float acc = 0;
+ size_t mc0 = static_cast<size_t>(configs[y][0]);
+ size_t nc0 = static_cast<size_t>(configs[y][1]);
+ size_t kc0 = static_cast<size_t>(configs[y][2]);
+ size_t bc0 = static_cast<size_t>(configs[y][3]);
+
+ size_t acc = 0;
acc += (m - mc0) * (m - mc0);
acc += (n - nc0) * (n - nc0);
- acc += (k - kc0) * (n - kc0);
- acc += (b - bc0) * (n - bc0);
+ acc += (k - kc0) * (k - kc0);
+ acc += (b - bc0) * (b - bc0);
acc = std::sqrt(acc);
if(acc < min_acc)
{