aboutsummaryrefslogtreecommitdiff
path: root/src/gpu/cl/kernels/gemm/ClGemmHelpers.cpp
diff options
context:
space:
mode:
author: Gian Marco Iodice <gianmarco.iodice@arm.com> 2023-04-26 14:55:02 +0100
committer: Gian Marco Iodice <gianmarco.iodice@arm.com> 2023-05-02 09:27:49 +0000
commit: 7a0f1bdaf74cde263b2919c7d1652b0cb87a94f3 (patch)
tree: 62886dac919eb95811efd76d907960dfddef0b61 /src/gpu/cl/kernels/gemm/ClGemmHelpers.cpp
parent: a62129a02397ba87171ebf4477795f628dcec0f6 (diff)
download: ComputeLibrary-7a0f1bdaf74cde263b2919c7d1652b0cb87a94f3.tar.gz
Add fp16 GeMM heuristic for Arm® Mali™-G710
- Performance improvements on various networks between 5-20%

Resolves COMPMID-6030

Signed-off-by: Gian Marco Iodice <gianmarco.iodice@arm.com>
Change-Id: Idcf7de57e6f5a94a6a94ec78229dd53c24de44f4
Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/c/VisualCompute/ComputeLibrary/+/514481
Tested-by: Viet-Hoa Do <viet-hoa.do@arm.com>
Reviewed-by: Viet-Hoa Do <viet-hoa.do@arm.com>
Comments-Addressed: bsgcomp <bsgcomp@arm.com>
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/9524
Benchmark: Arm Jenkins <bsgcomp@arm.com>
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Diffstat (limited to 'src/gpu/cl/kernels/gemm/ClGemmHelpers.cpp')
-rw-r--r-- src/gpu/cl/kernels/gemm/ClGemmHelpers.cpp 20
1 files changed, 11 insertions, 9 deletions
diff --git a/src/gpu/cl/kernels/gemm/ClGemmHelpers.cpp b/src/gpu/cl/kernels/gemm/ClGemmHelpers.cpp
index b97ffedfe5..9350bf74bb 100644
--- a/src/gpu/cl/kernels/gemm/ClGemmHelpers.cpp
+++ b/src/gpu/cl/kernels/gemm/ClGemmHelpers.cpp
@@ -143,7 +143,7 @@ bool is_mmul_kernel_preferred(const unsigned int m, const unsigned int n, const
std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> find_lhs_rhs_info(const GeMMConfigsMatrix &configs, unsigned int m, unsigned int n, unsigned int k, unsigned int b)
{
- float min_acc = std::numeric_limits<float>::max();
+ size_t min_acc = std::numeric_limits<size_t>::max();
size_t min_idx = 0;
ARM_COMPUTE_ERROR_ON(configs.size() == 0);
@@ -153,18 +153,20 @@ std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> find_lhs_rhs_info(const GeMMConf
ARM_COMPUTE_ERROR_ON_MSG(num_cols != 14U, "The entry should have 14 integer values representing: M, N, K, B, M0, N0. K0, V0, H0, INT_LHS, INT_RHS, TRA_LHS, TRA_RHS, IMG_RHS");
ARM_COMPUTE_UNUSED(num_cols);
- // Find nearest GeMM shape
+ // Find nearest GeMM workload
+ // Note: the workload does not depend on the K dimension
for(size_t y = 0; y < num_rows; ++y)
{
- float mc0 = configs[y][0];
- float nc0 = configs[y][1];
- float kc0 = configs[y][2];
- float bc0 = configs[y][3];
- float acc = 0;
+ size_t mc0 = static_cast<size_t>(configs[y][0]);
+ size_t nc0 = static_cast<size_t>(configs[y][1]);
+ size_t kc0 = static_cast<size_t>(configs[y][2]);
+ size_t bc0 = static_cast<size_t>(configs[y][3]);
+
+ size_t acc = 0;
acc += (m - mc0) * (m - mc0);
acc += (n - nc0) * (n - nc0);
- acc += (k - kc0) * (n - kc0);
- acc += (b - bc0) * (n - bc0);
+ acc += (k - kc0) * (k - kc0);
+ acc += (b - bc0) * (b - bc0);
acc = std::sqrt(acc);
if(acc < min_acc)
{