From 7a0f1bdaf74cde263b2919c7d1652b0cb87a94f3 Mon Sep 17 00:00:00 2001 From: Gian Marco Iodice Date: Wed, 26 Apr 2023 14:55:02 +0100 Subject: =?UTF-8?q?Add=20fp16=20GeMM=20heuristic=20for=20Arm=C2=AE=20Mali?= =?UTF-8?q?=E2=84=A2-G710?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Performance improvements on various networks between 5-20% Resolves COMPMID-6030 Signed-off-by: Gian Marco Iodice Change-Id: Idcf7de57e6f5a94a6a94ec78229dd53c24de44f4 Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/c/VisualCompute/ComputeLibrary/+/514481 Tested-by: Viet-Hoa Do Reviewed-by: Viet-Hoa Do Comments-Addressed: bsgcomp Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/9524 Benchmark: Arm Jenkins Tested-by: Arm Jenkins Comments-Addressed: Arm Jenkins --- src/gpu/cl/kernels/gemm/ClGemmHelpers.cpp | 20 +- .../ClGemmDefaultConfigReshapedRhsOnlyValhall.cpp | 267 ++++++++++++++++++--- .../ClGemmDefaultConfigReshapedRhsOnlyValhall.h | 3 +- src/runtime/CL/gemm/CLGEMMDefaultTypeValhall.cpp | 34 ++- src/runtime/CL/gemm/CLGEMMDefaultTypeValhall.h | 3 +- 5 files changed, 280 insertions(+), 47 deletions(-) diff --git a/src/gpu/cl/kernels/gemm/ClGemmHelpers.cpp b/src/gpu/cl/kernels/gemm/ClGemmHelpers.cpp index b97ffedfe5..9350bf74bb 100644 --- a/src/gpu/cl/kernels/gemm/ClGemmHelpers.cpp +++ b/src/gpu/cl/kernels/gemm/ClGemmHelpers.cpp @@ -143,7 +143,7 @@ bool is_mmul_kernel_preferred(const unsigned int m, const unsigned int n, const std::pair find_lhs_rhs_info(const GeMMConfigsMatrix &configs, unsigned int m, unsigned int n, unsigned int k, unsigned int b) { - float min_acc = std::numeric_limits::max(); + size_t min_acc = std::numeric_limits::max(); size_t min_idx = 0; ARM_COMPUTE_ERROR_ON(configs.size() == 0); @@ -153,18 +153,20 @@ std::pair find_lhs_rhs_info(const GeMMConf ARM_COMPUTE_ERROR_ON_MSG(num_cols != 14U, "The entry should have 14 integer values representing: M, N, K, B, M0, N0. K0, V0, H0, INT_LHS, INT_RHS, TRA_LHS, TRA_RHS, IMG_RHS"); ARM_COMPUTE_UNUSED(num_cols); - // Find nearest GeMM shape + // Find nearest GeMM workload + // Note: the workload does not depend on the K dimension for(size_t y = 0; y < num_rows; ++y) { - float mc0 = configs[y][0]; - float nc0 = configs[y][1]; - float kc0 = configs[y][2]; - float bc0 = configs[y][3]; - float acc = 0; + size_t mc0 = static_cast(configs[y][0]); + size_t nc0 = static_cast(configs[y][1]); + size_t kc0 = static_cast(configs[y][2]); + size_t bc0 = static_cast(configs[y][3]); + + size_t acc = 0; acc += (m - mc0) * (m - mc0); acc += (n - nc0) * (n - nc0); - acc += (k - kc0) * (n - kc0); - acc += (b - bc0) * (n - bc0); + acc += (k - kc0) * (k - kc0); + acc += (b - bc0) * (b - bc0); acc = std::sqrt(acc); if(acc < min_acc) { diff --git a/src/gpu/cl/kernels/gemm/reshaped_only_rhs/ClGemmDefaultConfigReshapedRhsOnlyValhall.cpp b/src/gpu/cl/kernels/gemm/reshaped_only_rhs/ClGemmDefaultConfigReshapedRhsOnlyValhall.cpp index 5d666c03a5..76551b076a 100644 --- a/src/gpu/cl/kernels/gemm/reshaped_only_rhs/ClGemmDefaultConfigReshapedRhsOnlyValhall.cpp +++ b/src/gpu/cl/kernels/gemm/reshaped_only_rhs/ClGemmDefaultConfigReshapedRhsOnlyValhall.cpp @@ -63,6 +63,10 @@ std::pair ClGemmDefaultConfigReshapedRhsOn &ClGemmDefaultConfigReshapedRhsOnlyValhall::configure_G78_f16, &ClGemmDefaultConfigReshapedRhsOnlyValhall::configure_G77_u8); + CLGEMMConfigArray configs_G710(&ClGemmDefaultConfigReshapedRhsOnlyValhall::configure_G77_f32, + &ClGemmDefaultConfigReshapedRhsOnlyValhall::configure_G710_f16, + &ClGemmDefaultConfigReshapedRhsOnlyValhall::configure_G77_u8); + CLGEMMConfigArray configs_G715(&ClGemmDefaultConfigReshapedRhsOnlyValhall::configure_G715_f32, &ClGemmDefaultConfigReshapedRhsOnlyValhall::configure_G715_f16, &ClGemmDefaultConfigReshapedRhsOnlyValhall::configure_G77_u8); @@ -74,6 +78,10 @@ std::pair ClGemmDefaultConfigReshapedRhsOn case GPUTarget::G78: func = configs_G78.get_function(data_type); break; + case GPUTarget::G710: + case GPUTarget::G610: + func = configs_G710.get_function(data_type); + break; case GPUTarget::G715: case GPUTarget::G615: func = configs_G715.get_function(data_type); @@ -224,62 +232,78 @@ std::pair ClGemmDefaultConfigReshapedRhsOn { 16384, 2, 128, 1, 1, 2, 16, 1, 0, 1, 1, 1, 1, 0 } }; - const GeMMConfigsMatrix configs_mnkb_best = + const GeMMConfigsMatrix configs_mnkb_m_gt_n_best = { - { 24, 88, 236, 1, 2, 2, 8, 1, 64, 1, 1, 1, 1, 0 }, - { 24, 488, 88, 1, 2, 4, 16, 1, 4, 1, 1, 1, 0, 0 }, - { 24, 88, 488, 1, 2, 2, 8, 1, 64, 1, 1, 1, 1, 0 }, { 25584, 88, 16, 1, 4, 8, 4, 1, 8, 1, 1, 1, 0, 0 }, { 25584, 16, 68, 1, 4, 4, 8, 1, 16, 1, 1, 1, 0, 1 }, - { 72, 92, 136, 1, 2, 2, 8, 1, 128, 1, 1, 1, 1, 0 }, { 369664, 32, 28, 1, 5, 4, 4, 1, 64, 1, 1, 1, 0, 1 }, { 65792, 44, 24, 1, 4, 8, 4, 1, 128, 1, 1, 1, 0, 0 }, { 23036, 56, 736, 1, 4, 4, 8, 1, 64, 1, 1, 1, 0, 1 }, { 90968, 40, 600, 1, 4, 4, 8, 1, 64, 1, 1, 1, 0, 1 }, { 8944, 32, 776, 1, 4, 4, 8, 1, 64, 1, 1, 1, 0, 1 }, - { 2688, 136, 1492, 1, 8, 4, 4, 1, 128, 1, 1, 1, 0, 0 }, { 50176, 64, 300, 1, 4, 8, 4, 1, 128, 1, 1, 1, 0, 0 }, { 16544, 104, 160, 1, 4, 4, 8, 1, 64, 1, 1, 1, 0, 1 }, - { 268, 824, 5076, 1, 4, 8, 4, 1, 256, 1, 1, 1, 0, 0 }, - { 180, 420, 952, 1, 4, 4, 8, 1, 64, 1, 1, 1, 0, 1 }, - { 1000, 152, 304, 1, 4, 4, 8, 1, 128, 1, 1, 1, 0, 0 }, { 12604, 60, 160, 1, 4, 4, 8, 1, 64, 1, 1, 1, 0, 1 }, - { 3728, 96, 196, 1, 4, 8, 4, 1, 128, 1, 1, 1, 0, 0 }, - { 272, 400, 2116, 1, 4, 8, 4, 1, 64, 1, 1, 1, 0, 0 }, { 29584, 32, 28, 1, 4, 4, 4, 1, 128, 1, 1, 1, 0, 0 }, { 12544, 32, 27, 1, 2, 8, 8, 1, 128, 1, 1, 1, 0, 0 }, - { 196, 512, 512, 1, 5, 4, 4, 1, 64, 1, 1, 1, 0, 1 }, - { 49, 1024, 512, 1, 4, 4, 8, 1, 128, 1, 1, 1, 0, 1 }, - { 49, 1024, 1024, 1, 4, 4, 8, 1, 64, 1, 1, 1, 0, 1 } + { 2688, 136, 1492, 1, 8, 4, 4, 1, 128, 1, 1, 1, 0, 0 }, + { 3728, 96, 196, 1, 4, 8, 4, 1, 128, 1, 1, 1, 0, 0 } }; - const GeMMConfigsMatrix configs_mnkb_fallback = + const GeMMConfigsMatrix configs_mnkb_m_gt_n_fallback = { - { 24, 88, 236, 1, 2, 2, 8, 1, 64, 1, 1, 1, 1, 0 }, - { 24, 488, 88, 1, 2, 4, 16, 1, 4, 1, 1, 1, 0, 0 }, - { 24, 88, 488, 1, 2, 2, 8, 1, 64, 1, 1, 1, 1, 0 }, { 25584, 88, 16, 1, 4, 8, 4, 1, 8, 1, 1, 1, 0, 0 }, { 25584, 16, 68, 1, 2, 4, 8, 1, 4, 1, 1, 1, 0, 0 }, - { 72, 92, 136, 1, 2, 2, 8, 1, 128, 1, 1, 1, 1, 0 }, { 369664, 32, 28, 1, 5, 4, 4, 1, 256, 1, 1, 1, 0, 0 }, { 65792, 44, 24, 1, 4, 8, 4, 1, 128, 1, 1, 1, 0, 0 }, { 23036, 56, 736, 1, 4, 8, 4, 1, 64, 1, 1, 1, 0, 0 }, { 90968, 40, 600, 1, 4, 8, 4, 1, 64, 1, 1, 1, 0, 0 }, { 8944, 32, 776, 1, 4, 4, 8, 1, 64, 1, 1, 1, 0, 0 }, - { 2688, 136, 1492, 1, 8, 4, 4, 1, 128, 1, 1, 1, 0, 0 }, { 50176, 64, 300, 1, 4, 8, 4, 1, 128, 1, 1, 1, 0, 0 }, { 16544, 104, 160, 1, 4, 4, 8, 1, 64, 1, 1, 1, 0, 0 }, + { 12604, 60, 160, 1, 4, 4, 8, 1, 256, 1, 1, 1, 0, 0 }, + { 29584, 32, 28, 1, 4, 4, 4, 1, 128, 1, 1, 1, 0, 0 }, + { 12544, 32, 27, 1, 2, 8, 8, 1, 128, 1, 1, 1, 0, 0 }, + { 2688, 136, 1492, 1, 8, 4, 4, 1, 128, 1, 1, 1, 0, 0 }, + { 3728, 96, 196, 1, 4, 8, 4, 1, 128, 1, 1, 1, 0, 0 } + }; + + const GeMMConfigsMatrix configs_mnkb_n_gt_m_best = + { + { 24, 488, 88, 1, 2, 4, 16, 1, 4, 1, 1, 1, 0, 0 }, + { 49, 1024, 512, 1, 4, 4, 8, 1, 128, 1, 1, 1, 0, 1 }, + { 49, 1024, 1024, 1, 4, 4, 8, 1, 64, 1, 1, 1, 0, 1 }, + }; + + const GeMMConfigsMatrix configs_mnkb_n_gt_m_fallback = + { + { 24, 488, 88, 1, 2, 4, 16, 1, 4, 1, 1, 1, 0, 0 }, + { 49, 1024, 512, 1, 4, 4, 8, 1, 128, 1, 1, 1, 0, 0 }, + { 49, 1024, 1024, 1, 4, 4, 8, 1, 256, 1, 1, 1, 0, 0 }, + }; + + const GeMMConfigsMatrix configs_mnkb_squared_best = + { + { 72, 92, 136, 1, 2, 2, 8, 1, 128, 1, 1, 1, 1, 0 }, + { 268, 824, 5076, 1, 4, 8, 4, 1, 256, 1, 1, 1, 0, 0 }, + { 180, 420, 952, 1, 4, 4, 8, 1, 64, 1, 1, 1, 0, 1 }, + { 1000, 152, 304, 1, 4, 4, 8, 1, 128, 1, 1, 1, 0, 0 }, + { 272, 400, 2116, 1, 4, 8, 4, 1, 64, 1, 1, 1, 0, 0 }, + { 196, 512, 512, 1, 5, 4, 4, 1, 64, 1, 1, 1, 0, 1 }, + { 24, 88, 236, 1, 2, 2, 8, 1, 64, 1, 1, 1, 1, 0 }, + { 24, 88, 488, 1, 2, 2, 8, 1, 64, 1, 1, 1, 1, 0 } + }; + + const GeMMConfigsMatrix configs_mnkb_squared_fallback = + { + { 72, 92, 136, 1, 2, 2, 8, 1, 128, 1, 1, 1, 1, 0 }, { 268, 824, 5076, 1, 4, 8, 4, 1, 256, 1, 1, 1, 0, 0 }, { 180, 420, 952, 1, 4, 4, 8, 1, 128, 1, 1, 1, 0, 0 }, { 1000, 152, 304, 1, 4, 4, 8, 1, 128, 1, 1, 1, 0, 0 }, - { 12604, 60, 160, 1, 4, 4, 8, 1, 256, 1, 1, 1, 0, 0 }, - { 3728, 96, 196, 1, 4, 8, 4, 1, 128, 1, 1, 1, 0, 0 }, { 272, 400, 2116, 1, 4, 8, 4, 1, 64, 1, 1, 1, 0, 0 }, - { 29584, 32, 28, 1, 4, 4, 4, 1, 128, 1, 1, 1, 0, 0 }, - { 12544, 32, 27, 1, 2, 8, 8, 1, 128, 1, 1, 1, 0, 0 }, { 196, 512, 512, 1, 5, 4, 4, 1, 256, 1, 1, 1, 0, 0 }, - { 49, 1024, 512, 1, 4, 4, 8, 1, 128, 1, 1, 1, 0, 0 }, - { 49, 1024, 1024, 1, 4, 4, 8, 1, 256, 1, 1, 1, 0, 0 } + { 24, 88, 236, 1, 2, 2, 8, 1, 64, 1, 1, 1, 1, 0 }, + { 24, 88, 488, 1, 2, 2, 8, 1, 64, 1, 1, 1, 1, 0 } }; const GeMMConfigsMatrix configs_mnkb_best_batched = @@ -312,6 +336,7 @@ std::pair ClGemmDefaultConfigReshapedRhsOn if(b == 1) { constexpr float ratio_m_gt_n = 10.f; + constexpr float ratio_n_gt_m = 0.1f; constexpr unsigned int n_small_thr = 4; const float ratio = static_cast(m) / static_cast(n); @@ -326,10 +351,20 @@ std::pair ClGemmDefaultConfigReshapedRhsOn configs_best_to_use = &configs_mnkb_n_small_best; configs_fallback_to_use = &configs_mnkb_n_small_fallback; } + else if(ratio > ratio_m_gt_n) + { + configs_best_to_use = &configs_mnkb_m_gt_n_best; + configs_fallback_to_use = &configs_mnkb_m_gt_n_fallback; + } + else if(ratio < ratio_n_gt_m) + { + configs_best_to_use = &configs_mnkb_n_gt_m_best; + configs_fallback_to_use = &configs_mnkb_n_gt_m_fallback; + } else { - configs_best_to_use = &configs_mnkb_best; - configs_fallback_to_use = &configs_mnkb_fallback; + configs_best_to_use = &configs_mnkb_squared_best; + configs_fallback_to_use = &configs_mnkb_squared_fallback; } } else @@ -634,6 +669,182 @@ std::pair ClGemmDefaultConfigReshapedRhsOn } } +std::pair ClGemmDefaultConfigReshapedRhsOnlyValhall::configure_G710_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b) +{ + const GeMMConfigsMatrix configs_1nkb_best = + { + { 1, 8984, 640, 1, 1, 2, 2, 1, 0, 1, 0, 1, 0, 0 }, + { 1, 420, 392, 1, 1, 2, 8, 1, 0, 1, 0, 1, 0, 0 }, + { 1, 644, 5288, 1, 1, 2, 8, 1, 0, 1, 0, 1, 0, 0 }, + { 1, 6512, 6404, 1, 1, 2, 4, 1, 0, 1, 0, 1, 0, 0 }, + { 1, 5304, 640, 1, 1, 2, 4, 1, 0, 1, 0, 1, 0, 0 }, + { 1, 1352, 1520, 1, 1, 2, 4, 1, 0, 1, 0, 1, 0, 0 }, + { 1, 4096, 25088, 1, 1, 2, 8, 1, 0, 1, 0, 1, 1, 0 }, + { 1, 732, 8988, 1, 1, 2, 8, 1, 0, 1, 0, 1, 0, 0 } + }; + + const GeMMConfigsMatrix configs_mnkb_n_small_best = + { + { 102400, 4, 96, 1, 1, 2, 16, 1, 0, 1, 0, 1, 0, 0 }, + { 102400, 2, 96, 1, 1, 2, 16, 1, 0, 1, 0, 1, 0, 0 }, + { 16384, 4, 128, 1, 1, 2, 16, 1, 0, 1, 0, 1, 0, 0 }, + { 16384, 2, 128, 1, 1, 2, 16, 1, 0, 1, 0, 1, 0, 0 } + }; + + const GeMMConfigsMatrix configs_mnkb_m_gt_n_best = + { + { 25584, 88, 16, 1, 4, 8, 4, 1, 4, 1, 1, 1, 0, 0 }, + { 25584, 16, 68, 1, 2, 4, 16, 1, 8, 1, 1, 1, 0, 1 }, + { 369664, 32, 28, 1, 2, 8, 4, 1, 128, 1, 1, 1, 0, 0 }, + { 65792, 44, 24, 1, 4, 8, 4, 1, 8, 1, 1, 1, 0, 0 }, + { 23036, 56, 736, 1, 4, 4, 8, 1, 4, 1, 1, 1, 0, 1 }, + { 90968, 40, 600, 1, 4, 4, 8, 1, 4, 1, 1, 1, 0, 1 }, + { 8944, 32, 776, 1, 4, 4, 8, 1, 4, 1, 1, 1, 0, 1 }, + { 2688, 136, 1492, 1, 4, 4, 8, 1, 4, 1, 1, 1, 0, 1 }, + { 50176, 64, 300, 1, 4, 8, 4, 1, 8, 1, 1, 1, 0, 1 }, + { 16544, 104, 160, 1, 4, 4, 8, 1, 4, 1, 1, 1, 0, 1 }, + { 12604, 60, 160, 1, 4, 4, 8, 1, 4, 1, 1, 1, 0, 1 }, + { 3728, 96, 196, 1, 4, 4, 8, 1, 4, 1, 1, 1, 0, 1 }, + { 29584, 32, 28, 1, 2, 8, 4, 1, 16, 1, 1, 1, 0, 0 }, + { 12544, 32, 27, 1, 2, 8, 8, 1, 16, 1, 1, 1, 0, 0 }, + }; + + const GeMMConfigsMatrix configs_mnkb_m_gt_n_fallback = + { + { 25584, 88, 16, 1, 4, 8, 4, 1, 4, 1, 1, 1, 0, 0 }, + { 25584, 16, 68, 1, 2, 4, 8, 1, 4, 1, 1, 1, 1, 0 }, + { 369664, 32, 28, 1, 2, 8, 4, 1, 128, 1, 1, 1, 0, 0 }, + { 65792, 44, 24, 1, 4, 8, 4, 1, 8, 1, 1, 1, 0, 0 }, + { 23036, 56, 736, 1, 4, 8, 4, 1, 16, 1, 1, 1, 0, 0 }, + { 90968, 40, 600, 1, 4, 4, 8, 1, 4, 1, 1, 1, 0, 0 }, + { 8944, 32, 776, 1, 2, 8, 8, 1, 16, 1, 1, 1, 0, 0 }, + { 2688, 136, 1492, 1, 4, 4, 8, 1, 8, 1, 1, 1, 0, 0 }, + { 50176, 64, 300, 1, 4, 8, 4, 1, 128, 1, 1, 1, 0, 0 }, + { 16544, 104, 160, 1, 4, 8, 4, 1, 16, 1, 1, 1, 0, 0 }, + { 12604, 60, 160, 1, 2, 8, 8, 1, 8, 1, 1, 1, 0, 0 }, + { 3728, 96, 196, 1, 2, 8, 8, 1, 64, 1, 1, 1, 0, 0 }, + { 29584, 32, 28, 1, 2, 8, 4, 1, 16, 1, 1, 1, 0, 0 }, + { 12544, 32, 27, 1, 2, 8, 8, 1, 16, 1, 1, 1, 0, 0 }, + }; + + const GeMMConfigsMatrix configs_mnkb_n_gt_m_best = + { + { 24, 488, 88, 1, 2, 2, 8, 1, 8, 1, 1, 1, 1, 0 }, + { 49, 1024, 512, 1, 2, 4, 8, 1, 8, 1, 1, 1, 1, 0 }, + { 49, 1024, 1024, 1, 2, 4, 8, 1, 4, 1, 1, 1, 1, 0 } + }; + + const GeMMConfigsMatrix configs_mnkb_n_gt_m_fallback = + { + { 24, 488, 88, 1, 2, 2, 8, 1, 8, 1, 1, 1, 1, 0 }, + { 49, 1024, 512, 1, 2, 4, 8, 1, 8, 1, 1, 1, 1, 0 }, + { 49, 1024, 1024, 1, 2, 4, 8, 1, 4, 1, 1, 1, 1, 0 } + }; + + const GeMMConfigsMatrix configs_mnkb_squared_best = + { + { 24, 88, 236, 1, 2, 2, 8, 1, 4, 1, 1, 1, 1, 0 }, + { 24, 88, 488, 1, 2, 2, 8, 1, 4, 1, 1, 1, 1, 0 }, + { 72, 92, 136, 1, 2, 2, 8, 1, 32, 1, 1, 1, 1, 0 }, + { 268, 824, 5076, 1, 4, 4, 8, 1, 4, 1, 1, 1, 0, 1 }, + { 180, 420, 952, 1, 4, 4, 8, 1, 16, 1, 1, 1, 0, 1 }, + { 1000, 152, 304, 1, 4, 8, 4, 1, 32, 1, 1, 1, 0, 0 }, + { 272, 400, 2116, 1, 4, 4, 8, 1, 4, 1, 1, 1, 0, 1 }, + { 196, 512, 512, 1, 5, 2, 8, 1, 4, 1, 1, 1, 1, 1 }, + }; + + const GeMMConfigsMatrix configs_mnkb_squared_fallback = + { + { 24, 88, 236, 1, 2, 2, 8, 1, 4, 1, 1, 1, 1, 0 }, + { 24, 88, 488, 1, 2, 2, 8, 1, 4, 1, 1, 1, 1, 0 }, + { 72, 92, 136, 1, 2, 2, 8, 1, 32, 1, 1, 1, 1, 0 }, + { 268, 824, 5076, 1, 4, 8, 4, 1, 8, 1, 1, 1, 0, 0 }, + { 180, 420, 952, 1, 5, 2, 8, 1, 8, 1, 1, 1, 1, 0 }, + { 1000, 152, 304, 1, 4, 8, 4, 1, 32, 1, 1, 1, 0, 0 }, + { 272, 400, 2116, 1, 2, 8, 4, 1, 4, 1, 1, 1, 0, 0 }, + { 196, 512, 512, 1, 5, 2, 8, 1, 8, 1, 1, 1, 1, 0 }, + }; + + const GeMMConfigsMatrix configs_mnkb_best_batched = + { + { 3136, 64, 64, 36, 4, 8, 4, 1, 16, 1, 1, 1, 0, 1 }, + { 4096, 48, 32, 36, 4, 4, 8, 1, 4, 1, 1, 1, 0, 1 }, + { 688, 92, 68, 32, 4, 8, 4, 1, 32, 1, 1, 1, 0, 1 }, + { 24, 464, 412, 24, 4, 4, 8, 1, 4, 1, 1, 1, 0, 1 }, + { 112, 184, 144, 28, 4, 4, 8, 1, 4, 1, 1, 1, 0, 1 }, + { 5776, 64, 32, 36, 4, 4, 8, 1, 4, 1, 1, 1, 0, 1 }, + { 1568, 64, 40, 36, 4, 8, 4, 1, 8, 1, 1, 1, 0, 1 }, + { 2920, 64, 64, 24, 4, 8, 4, 1, 8, 1, 1, 1, 0, 1 } + }; + + const GeMMConfigsMatrix configs_mnkb_fallback_batched = + { + { 3136, 64, 64, 36, 4, 8, 4, 1, 8, 1, 1, 1, 0, 0 }, + { 4096, 48, 32, 36, 4, 4, 8, 1, 64, 1, 1, 1, 0, 0 }, + { 688, 92, 68, 32, 4, 8, 4, 1, 32, 1, 1, 1, 0, 0 }, + { 24, 464, 412, 24, 2, 8, 4, 1, 32, 1, 1, 1, 0, 0 }, + { 112, 184, 144, 28, 4, 4, 8, 1, 8, 1, 1, 1, 0, 0 }, + { 5776, 64, 32, 36, 2, 8, 8, 1, 32, 1, 1, 1, 0, 0 }, + { 1568, 64, 40, 36, 4, 8, 4, 1, 16, 1, 1, 1, 0, 0 }, + { 2920, 64, 64, 24, 4, 8, 4, 1, 8, 1, 1, 1, 0, 0 } + }; + + const GeMMConfigsMatrix *configs_best_to_use = nullptr; + const GeMMConfigsMatrix *configs_fallback_to_use = nullptr; + + if(b == 1) + { + constexpr float ratio_m_gt_n = 10.f; + constexpr float ratio_n_gt_m = 0.1f; + constexpr unsigned int n_small_thr = 4; + const float ratio = static_cast(m) / static_cast(n); + + if(m == 1) + { + // We do not need fallback in this case, as we never use cl_image for the rhs tensor + configs_best_to_use = &configs_1nkb_best; + configs_fallback_to_use = &configs_1nkb_best; + } + else if(n <= n_small_thr && ratio > ratio_m_gt_n) + { + configs_best_to_use = &configs_mnkb_n_small_best; + configs_fallback_to_use = &configs_mnkb_n_small_best; + } + else if(ratio > ratio_m_gt_n) + { + configs_best_to_use = &configs_mnkb_m_gt_n_best; + configs_fallback_to_use = &configs_mnkb_m_gt_n_fallback; + } + else if(ratio < ratio_n_gt_m) + { + configs_best_to_use = &configs_mnkb_n_gt_m_best; + configs_fallback_to_use = &configs_mnkb_n_gt_m_fallback; + } + else + { + configs_best_to_use = &configs_mnkb_squared_best; + configs_fallback_to_use = &configs_mnkb_squared_fallback; + } + } + else + { + configs_best_to_use = &configs_mnkb_best_batched; + configs_fallback_to_use = &configs_mnkb_fallback_batched; + } + + GEMMLHSMatrixInfo lhs_info0; + GEMMRHSMatrixInfo rhs_info0; + GEMMLHSMatrixInfo lhs_info1; + GEMMRHSMatrixInfo rhs_info1; + + std::tie(lhs_info0, rhs_info0) = find_lhs_rhs_info(*configs_best_to_use, m, n, k, b); + std::tie(lhs_info1, rhs_info1) = find_lhs_rhs_info(*configs_fallback_to_use, m, n, k, b); + + return select_lhs_rhs_info(std::make_pair(lhs_info0, rhs_info0), + std::make_pair(lhs_info1, rhs_info1), + n, k, b, DataType::F16); +} + std::pair ClGemmDefaultConfigReshapedRhsOnlyValhall::configure_G715_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b) { unsigned int best_m0; diff --git a/src/gpu/cl/kernels/gemm/reshaped_only_rhs/ClGemmDefaultConfigReshapedRhsOnlyValhall.h b/src/gpu/cl/kernels/gemm/reshaped_only_rhs/ClGemmDefaultConfigReshapedRhsOnlyValhall.h index 0ec068fffd..f2952a3d30 100644 --- a/src/gpu/cl/kernels/gemm/reshaped_only_rhs/ClGemmDefaultConfigReshapedRhsOnlyValhall.h +++ b/src/gpu/cl/kernels/gemm/reshaped_only_rhs/ClGemmDefaultConfigReshapedRhsOnlyValhall.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020-2022 Arm Limited. + * Copyright (c) 2020-2023 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -53,6 +53,7 @@ private: std::pair configure_G78_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b); std::pair configure_G78_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b); std::pair configure_G77_u8(unsigned int m, unsigned int n, unsigned int k, unsigned int b); + std::pair configure_G710_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b); std::pair configure_G715_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b); std::pair configure_G715_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b); }; diff --git a/src/runtime/CL/gemm/CLGEMMDefaultTypeValhall.cpp b/src/runtime/CL/gemm/CLGEMMDefaultTypeValhall.cpp index 29d3177424..9e779d3752 100644 --- a/src/runtime/CL/gemm/CLGEMMDefaultTypeValhall.cpp +++ b/src/runtime/CL/gemm/CLGEMMDefaultTypeValhall.cpp @@ -79,6 +79,17 @@ CLGEMMKernelType CLGEMMDefaultTypeValhall::select_kernel(const CLGEMMKernelSelec { DataType::QSYMM8_PER_CHANNEL, &CLGEMMDefaultTypeValhall::default_q8 } }; + // Mali-G710 and Mali-G610 configurations + static std::map gemm_g710_configs = + { + { DataType::F32, &CLGEMMDefaultTypeValhall::default_f32 }, + { DataType::F16, &CLGEMMDefaultTypeValhall::g710_f16 }, + { DataType::QASYMM8, &CLGEMMDefaultTypeValhall::default_q8 }, + { DataType::QASYMM8_SIGNED, &CLGEMMDefaultTypeValhall::default_q8 }, + { DataType::QSYMM8, &CLGEMMDefaultTypeValhall::default_q8 }, + { DataType::QSYMM8_PER_CHANNEL, &CLGEMMDefaultTypeValhall::default_q8 } + }; + // Mali-G715 and Mali-G615 configurations static std::map gemm_g715_configs = { @@ -94,6 +105,13 @@ CLGEMMKernelType CLGEMMDefaultTypeValhall::select_kernel(const CLGEMMKernelSelec switch(_target) { + case GPUTarget::G710: + case GPUTarget::G610: + if(gemm_g710_configs.find(data_type) != gemm_g710_configs.end()) + { + return (this->*gemm_g710_configs[data_type])(params.m, params.n, params.k, params.b, params.is_rhs_constant); + } + ARM_COMPUTE_ERROR("Not supported data type"); case GPUTarget::G715: case GPUTarget::G615: if(gemm_g715_configs.find(data_type) != gemm_g715_configs.end()) @@ -140,14 +158,14 @@ CLGEMMKernelType CLGEMMDefaultTypeValhall::g77_f16(unsigned int m, unsigned int { ARM_COMPUTE_UNUSED(m, n, k, b); - if(!is_rhs_constant) - { - return CLGEMMKernelType::NATIVE; - } - else - { - return CLGEMMKernelType::RESHAPED_ONLY_RHS; - } + return is_rhs_constant ? CLGEMMKernelType::RESHAPED_ONLY_RHS : CLGEMMKernelType::NATIVE; +} + +CLGEMMKernelType CLGEMMDefaultTypeValhall::g710_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b, bool is_rhs_constant) +{ + ARM_COMPUTE_UNUSED(m, n, k, b); + + return is_rhs_constant ? CLGEMMKernelType::RESHAPED_ONLY_RHS : CLGEMMKernelType::NATIVE; } CLGEMMKernelType CLGEMMDefaultTypeValhall::default_q8(unsigned int m, unsigned int n, unsigned int k, unsigned int b, bool is_rhs_constant) diff --git a/src/runtime/CL/gemm/CLGEMMDefaultTypeValhall.h b/src/runtime/CL/gemm/CLGEMMDefaultTypeValhall.h index 0893f11132..e190295ee4 100644 --- a/src/runtime/CL/gemm/CLGEMMDefaultTypeValhall.h +++ b/src/runtime/CL/gemm/CLGEMMDefaultTypeValhall.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020-2022 Arm Limited. + * Copyright (c) 2020-2023 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -50,6 +50,7 @@ private: CLGEMMKernelType g77_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b, bool is_rhs_constant); CLGEMMKernelType g78_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b, bool is_rhs_constant); CLGEMMKernelType g78_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b, bool is_rhs_constant); + CLGEMMKernelType g710_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b, bool is_rhs_constant); CLGEMMKernelType g715_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b, bool is_rhs_constant); CLGEMMKernelType g715_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b, bool is_rhs_constant); }; -- cgit v1.2.1