diff options
Diffstat (limited to 'src/gpu/cl/kernels/gemm/ClGemmHelpers.cpp')
-rw-r--r-- | src/gpu/cl/kernels/gemm/ClGemmHelpers.cpp | 20 |
1 files changed, 11 insertions, 9 deletions
diff --git a/src/gpu/cl/kernels/gemm/ClGemmHelpers.cpp b/src/gpu/cl/kernels/gemm/ClGemmHelpers.cpp
index b97ffedfe5..9350bf74bb 100644
--- a/src/gpu/cl/kernels/gemm/ClGemmHelpers.cpp
+++ b/src/gpu/cl/kernels/gemm/ClGemmHelpers.cpp
@@ -143,7 +143,7 @@ bool is_mmul_kernel_preferred(const unsigned int m, const unsigned int n, const
 
 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> find_lhs_rhs_info(const GeMMConfigsMatrix &configs, unsigned int m, unsigned int n, unsigned int k, unsigned int b)
 {
-    float min_acc = std::numeric_limits<float>::max();
+    size_t min_acc = std::numeric_limits<size_t>::max();
     size_t min_idx = 0;
 
     ARM_COMPUTE_ERROR_ON(configs.size() == 0);
@@ -153,18 +153,20 @@ std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> find_lhs_rhs_info(const GeMMConf
     ARM_COMPUTE_ERROR_ON_MSG(num_cols != 14U, "The entry should have 14 integer values representing: M, N, K, B, M0, N0. K0, V0, H0, INT_LHS, INT_RHS, TRA_LHS, TRA_RHS, IMG_RHS");
     ARM_COMPUTE_UNUSED(num_cols);
 
-    // Find nearest GeMM shape
+    // Find nearest GeMM workload
+    // Note: the workload does not depend on the K dimension
     for(size_t y = 0; y < num_rows; ++y)
     {
-        float mc0 = configs[y][0];
-        float nc0 = configs[y][1];
-        float kc0 = configs[y][2];
-        float bc0 = configs[y][3];
-        float acc = 0;
+        size_t mc0 = static_cast<size_t>(configs[y][0]);
+        size_t nc0 = static_cast<size_t>(configs[y][1]);
+        size_t kc0 = static_cast<size_t>(configs[y][2]);
+        size_t bc0 = static_cast<size_t>(configs[y][3]);
+
+        size_t acc = 0;
         acc += (m - mc0) * (m - mc0);
         acc += (n - nc0) * (n - nc0);
-        acc += (k - kc0) * (n - kc0);
-        acc += (b - bc0) * (n - bc0);
+        acc += (k - kc0) * (k - kc0);
+        acc += (b - bc0) * (b - bc0);
         acc = std::sqrt(acc);
         if(acc < min_acc)
         {