diff options
author | Gian Marco Iodice <gianmarco.iodice@arm.com> | 2023-04-26 14:55:02 +0100 |
---|---|---|
committer | Gian Marco Iodice <gianmarco.iodice@arm.com> | 2023-05-02 09:27:49 +0000 |
commit | 7a0f1bdaf74cde263b2919c7d1652b0cb87a94f3 (patch) | |
tree | 62886dac919eb95811efd76d907960dfddef0b61 /src/gpu/cl/kernels/gemm/ClGemmHelpers.cpp | |
parent | a62129a02397ba87171ebf4477795f628dcec0f6 (diff) | |
download | ComputeLibrary-7a0f1bdaf74cde263b2919c7d1652b0cb87a94f3.tar.gz |
Add fp16 GeMM heuristic for Arm® Mali™-G710
- Performance improvements on various networks between 5-20%
Resolves COMPMID-6030
Signed-off-by: Gian Marco Iodice <gianmarco.iodice@arm.com>
Change-Id: Idcf7de57e6f5a94a6a94ec78229dd53c24de44f4
Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/c/VisualCompute/ComputeLibrary/+/514481
Tested-by: Viet-Hoa Do <viet-hoa.do@arm.com>
Reviewed-by: Viet-Hoa Do <viet-hoa.do@arm.com>
Comments-Addressed: bsgcomp <bsgcomp@arm.com>
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/9524
Benchmark: Arm Jenkins <bsgcomp@arm.com>
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Diffstat (limited to 'src/gpu/cl/kernels/gemm/ClGemmHelpers.cpp')
-rw-r--r-- | src/gpu/cl/kernels/gemm/ClGemmHelpers.cpp | 20 |
1 file changed, 11 insertions, 9 deletions
diff --git a/src/gpu/cl/kernels/gemm/ClGemmHelpers.cpp b/src/gpu/cl/kernels/gemm/ClGemmHelpers.cpp
index b97ffedfe5..9350bf74bb 100644
--- a/src/gpu/cl/kernels/gemm/ClGemmHelpers.cpp
+++ b/src/gpu/cl/kernels/gemm/ClGemmHelpers.cpp
@@ -143,7 +143,7 @@ bool is_mmul_kernel_preferred(const unsigned int m, const unsigned int n, const
 
 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> find_lhs_rhs_info(const GeMMConfigsMatrix &configs, unsigned int m, unsigned int n, unsigned int k, unsigned int b)
 {
-    float min_acc = std::numeric_limits<float>::max();
+    size_t min_acc = std::numeric_limits<size_t>::max();
     size_t min_idx = 0;
 
     ARM_COMPUTE_ERROR_ON(configs.size() == 0);
@@ -153,18 +153,20 @@ std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> find_lhs_rhs_info(const GeMMConf
     ARM_COMPUTE_ERROR_ON_MSG(num_cols != 14U, "The entry should have 14 integer values representing: M, N, K, B, M0, N0. K0, V0, H0, INT_LHS, INT_RHS, TRA_LHS, TRA_RHS, IMG_RHS");
     ARM_COMPUTE_UNUSED(num_cols);
 
-    // Find nearest GeMM shape
+    // Find nearest GeMM workload
+    // Note: the workload does not depend on the K dimension
     for(size_t y = 0; y < num_rows; ++y)
     {
-        float mc0 = configs[y][0];
-        float nc0 = configs[y][1];
-        float kc0 = configs[y][2];
-        float bc0 = configs[y][3];
-        float acc = 0;
+        size_t mc0 = static_cast<size_t>(configs[y][0]);
+        size_t nc0 = static_cast<size_t>(configs[y][1]);
+        size_t kc0 = static_cast<size_t>(configs[y][2]);
+        size_t bc0 = static_cast<size_t>(configs[y][3]);
+
+        size_t acc = 0;
         acc += (m - mc0) * (m - mc0);
         acc += (n - nc0) * (n - nc0);
-        acc += (k - kc0) * (n - kc0);
-        acc += (b - bc0) * (n - bc0);
+        acc += (k - kc0) * (k - kc0);
+        acc += (b - bc0) * (b - bc0);
         acc = std::sqrt(acc);
         if(acc < min_acc)
         {