From 5c76742e654880b147d24c61ded1aa0128320d76 Mon Sep 17 00:00:00 2001
From: David Mansell
Date: Fri, 15 Mar 2024 16:35:13 +0000
Subject: New SME2 heuristics.

Change-Id: I69aa973e61df950060807a31230a1edd91add498
Signed-off-by: David Mansell
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/11514
Reviewed-by: Gunes Bayir
Comments-Addressed: Arm Jenkins
Benchmark: Arm Jenkins
Tested-by: Arm Jenkins
---
 src/core/NEON/kernels/arm_gemm/gemm_bf16.cpp       |  2 +-
 src/core/NEON/kernels/arm_gemm/gemm_fp16.cpp       | 12 +++---
 src/core/NEON/kernels/arm_gemm/gemm_fp32.cpp       |  4 +-
 src/core/NEON/kernels/arm_gemm/gemm_int8.cpp       |  2 +-
 .../NEON/kernels/arm_gemm/gemm_interleaved.hpp     | 45 ++++++++++++++++++----
 src/core/NEON/kernels/arm_gemm/gemm_qint8.cpp      |  2 +-
 src/core/NEON/kernels/arm_gemm/gemm_quint8.cpp     |  2 +-
 7 files changed, 50 insertions(+), 19 deletions(-)

(limited to 'src/core/NEON')

diff --git a/src/core/NEON/kernels/arm_gemm/gemm_bf16.cpp b/src/core/NEON/kernels/arm_gemm/gemm_bf16.cpp
index 5c08e6137d..0ddca04846 100644
--- a/src/core/NEON/kernels/arm_gemm/gemm_bf16.cpp
+++ b/src/core/NEON/kernels/arm_gemm/gemm_bf16.cpp
@@ -86,7 +86,7 @@ static const GemmImplementation gemm_bf16_methods[] =
     "sme2_interleaved_nomerge_bf16fp32_mopa_1VLx4VL",
     [](const GemmArgs &args) { return args._ci->has_sme2(); },
     [](const GemmArgs &args) { const auto VL = sme::get_vector_length();
-                               return args._Msize <= VL || (2*VL < args._Msize && args._Msize <= 3*VL); },
+                               return args._Nsize >= 8*VL || args._Msize <= VL || (2*VL < args._Msize && args._Msize <= 3*VL); },
     [](const GemmArgs &args) { return new GemmInterleavedNoMerge(args); }
 },
 {
diff --git a/src/core/NEON/kernels/arm_gemm/gemm_fp16.cpp b/src/core/NEON/kernels/arm_gemm/gemm_fp16.cpp
index 3b444ae333..c7adf8e4ac 100644
--- a/src/core/NEON/kernels/arm_gemm/gemm_fp16.cpp
+++ b/src/core/NEON/kernels/arm_gemm/gemm_fp16.cpp
@@ -69,19 +69,19 @@ static const GemmImplementation<__fp16, __fp16> gemm_fp16_methods[] = {
 },
 {
     GemmMethod::GEMM_INTERLEAVED,
-    "sme2_interleaved_nomerge_fp16fp32fp16_mopa_4VLx1VL",
+    "sme2_interleaved_nomerge_fp16fp32fp16_mopa_1VLx4VL",
     [](const GemmArgs &args) { return args._ci->has_sme2(); },
     [](const GemmArgs &args) { const auto VL = sme::get_vector_length();
-                               return args._Nsize <= VL || (2*VL < args._Nsize && args._Nsize <= 3*VL); },
-    [](const GemmArgs &args) { return new GemmInterleaved(args); }
+                               return args._Nsize >= 8*VL || args._Msize <= VL || (2*VL < args._Msize && args._Msize <= 3*VL); },
+    [](const GemmArgs &args) { return new GemmInterleaved(args); }
 },
 {
     GemmMethod::GEMM_INTERLEAVED,
-    "sme2_interleaved_nomerge_fp16fp32fp16_mopa_1VLx4VL",
+    "sme2_interleaved_nomerge_fp16fp32fp16_mopa_4VLx1VL",
     [](const GemmArgs &args) { return args._ci->has_sme2(); },
     [](const GemmArgs &args) { const auto VL = sme::get_vector_length();
-                               return args._Msize <= VL || (2*VL < args._Msize && args._Msize <= 3*VL); },
-    [](const GemmArgs &args) { return new GemmInterleaved(args); }
+                               return args._Nsize <= VL || (2*VL < args._Nsize && args._Nsize <= 3*VL); },
+    [](const GemmArgs &args) { return new GemmInterleaved(args); }
 },
 {
     GemmMethod::GEMM_INTERLEAVED,
diff --git a/src/core/NEON/kernels/arm_gemm/gemm_fp32.cpp b/src/core/NEON/kernels/arm_gemm/gemm_fp32.cpp
index af0d38ec37..f223dea59e 100644
--- a/src/core/NEON/kernels/arm_gemm/gemm_fp32.cpp
+++ b/src/core/NEON/kernels/arm_gemm/gemm_fp32.cpp
@@ -141,7 +141,7 @@ GemmImplementation::with_estimate(
     "sme2_interleaved_nomerge_bf16fp32_mopa_1VLx4VL",
     [](const GemmArgs &args) { return args._fast_mode && args._ci->has_sme2() && !args._accumulate; },
     [](const GemmArgs &args) { const auto VL = sme::get_vector_length();
-                               return args._Msize <= VL || (2*VL < args._Msize && args._Msize <= 3*VL); },
+                               return args._Nsize >= 8*VL || args._Msize <= VL || (2*VL < args._Msize && args._Msize <= 3*VL); },
     [](const GemmArgs &args) { return new GemmInterleavedNoMerge(args); }
 },
 #endif // ARM_COMPUTE_ENABLE_BF16
@@ -150,7 +150,7 @@ GemmImplementation::with_estimate(
     "sme2_interleaved_nomerge_fp32_mopa_1VLx4VL",
     [](const GemmArgs &args) { return args._ci->has_sme2() && !args._accumulate; },
     [](const GemmArgs &args) { const auto VL = sme::get_vector_length();
-                               return args._Msize <= VL || (2*VL < args._Msize && args._Msize <= 3*VL); },
+                               return args._Nsize >= 8*VL || args._Msize <= VL || (2*VL < args._Msize && args._Msize <= 3*VL); },
     [](const GemmArgs &args) { return new GemmInterleavedNoMerge(args); }
 },
 #ifdef ARM_COMPUTE_ENABLE_BF16
diff --git a/src/core/NEON/kernels/arm_gemm/gemm_int8.cpp b/src/core/NEON/kernels/arm_gemm/gemm_int8.cpp
index 0dc0d55b27..fedda3a47a 100644
--- a/src/core/NEON/kernels/arm_gemm/gemm_int8.cpp
+++ b/src/core/NEON/kernels/arm_gemm/gemm_int8.cpp
@@ -63,7 +63,7 @@ static const GemmImplementation gemm_s8_methods[] = {
     "sme2_interleaved_nomerge_s8s32_mopa_1VLx4VL",
     [](const GemmArgs &args) { return args._ci->has_sme2(); },
     [](const GemmArgs &args) { const auto VL = sme::get_vector_length();
-                               return args._Msize <= VL || (2*VL < args._Msize && args._Msize <= 3*VL); },
+                               return args._Nsize >= 8*VL || args._Msize <= VL || (2*VL < args._Msize && args._Msize <= 3*VL); },
     [](const GemmArgs &args) { return new GemmInterleavedNoMerge(args); }
 },
 {
diff --git a/src/core/NEON/kernels/arm_gemm/gemm_interleaved.hpp b/src/core/NEON/kernels/arm_gemm/gemm_interleaved.hpp
index ae344f09b5..897ec9d05f 100644
--- a/src/core/NEON/kernels/arm_gemm/gemm_interleaved.hpp
+++ b/src/core/NEON/kernels/arm_gemm/gemm_interleaved.hpp
@@ -190,10 +190,19 @@ void kernel_and_merge::run(
         auto p=prof.ScopedProfiler(PROFILE_KERNEL, (m_max - m_0) * (n_max - n_0) * kern_k);
 #endif
 
+        // Offset C pointer in a similar way to non-quantized case above.
+        Tri *offset_c_ptr;
+
+        if (c_ptr == nullptr) {
+            offset_c_ptr = nullptr;
+        } else {
+            offset_c_ptr = c_ptr + m_0 * ldc + n_0;
+        }
+
         strat.kernel(// A and B pointers are just the packed panels.
                      a_ptr, b_panel,
                      // Provide relevant part of output array and row stride.
-                     c_ptr + m_0 * ldc + n_0, ldc,
+                     offset_c_ptr, ldc,
                      // M, N, K sizes
                      m_max-m_0, n_max - n_0, kern_k,
                      // Bias, activation, accumulation. Need to offset the bias as needed.
@@ -663,15 +672,27 @@ class GemmInterleaved : public GemmCommon {
             return roundup(args._cfg->inner_block_size, strategy::k_unroll());
         }
 
-        // K blocking not supported if we are requantizing.
-        if (std::is_same::value) {
+        // K blocking not supported if we are requantizing with the merging
+        // kernels.
+        if (std::is_same::value && MergeStep) {
             return get_ktotal(args);
         }
 
+        const unsigned int L1_size = args._ci->get_L1_cache_size();
+
         // Special blocking for SME
         if (is_sme::value) {
-            // Don't bother to block below this size threshold, experimentally determined to be 320 for FP32
-            unsigned int scaling_threshold = 1280 / sizeof(Toi);
+            // Target 512 bytes for 64kB L1, or 1024 bytes for 128kB L1.
+            unsigned int target_bytes_per_block = L1_size / 128;
+
+            // Default cache size in gemm-linux is 32kB though - so make
+            // sure minimum is 512
+            if (target_bytes_per_block < 512) {
+                target_bytes_per_block = 512;
+            }
+
+            // Don't bother to block below this size threshold (1.25X target size)
+            unsigned int scaling_threshold = ((target_bytes_per_block * 5) / 4) / sizeof(Toi);
 
             if (get_ktotal(args) <= scaling_threshold) {
                 return get_ktotal(args);
@@ -679,7 +700,7 @@ class GemmInterleaved : public GemmCommon {
 
             // Once we are blocking, this (lower) threshold determines when we should use more blocks
             // NOTE: Could be that some factor-based solution would work better here.
-            unsigned int max_block_size = 1024 / sizeof(Toi);
+            unsigned int max_block_size = target_bytes_per_block / sizeof(Toi);
 
             unsigned int num_k_blocks = iceildiv(get_ktotal(args), max_block_size);
 
@@ -688,7 +709,6 @@ class GemmInterleaved : public GemmCommon {
             return k_block;
         }
 
-        const unsigned int L1_size = args._ci->get_L1_cache_size();
         unsigned int k_block;
 
         // k_block: Find out how much of the larger array can be loaded into half the cache.
@@ -723,6 +743,17 @@ class GemmInterleaved : public GemmCommon {
             return roundup(args._cfg->outer_block_size, strategy::out_width());
         }
 
+        // Special blocking for SME
+        if (is_sme::value) {
+            // If total width is less than 4x kernel width, return the entire width.
+            if (args._Nsize < strategy::out_width()*4) {
+                return roundup(args._Nsize, strategy::out_width());
+            }
+
+            // Otherwise block to single kernel width.
+            return strategy::out_width();
+        }
+
         unsigned int x_block;
         const unsigned int L2_size = args._ci->get_L2_cache_size();
         const unsigned int k_block = get_k_block_size(args);
diff --git a/src/core/NEON/kernels/arm_gemm/gemm_qint8.cpp b/src/core/NEON/kernels/arm_gemm/gemm_qint8.cpp
index d1c4e49edb..321c97262f 100644
--- a/src/core/NEON/kernels/arm_gemm/gemm_qint8.cpp
+++ b/src/core/NEON/kernels/arm_gemm/gemm_qint8.cpp
@@ -82,7 +82,7 @@ static const GemmImplementation gemm_qint8_methods
     "sme2_interleaved_nomerge_s8q_mopa_1VLx4VL",
     [](const GemmArgs &args, const Requantize32 &qp) { return args._ci->has_sme2() && ((qp.per_channel_requant && (qp.per_channel_left_shifts == nullptr)) || (!qp.per_channel_requant && (qp.per_layer_left_shift == 0)));},
     [](const GemmArgs &args, const Requantize32 &) { const auto VL = sme::get_vector_length();
-                                                     return args._Msize <= VL || (2*VL < args._Msize && args._Msize <= 3*VL); },
+                                                     return args._Nsize >= 8*VL || args._Msize <= VL || (2*VL < args._Msize && args._Msize <= 3*VL); },
     [](const GemmArgs &args, const Requantize32 &qp) { return new GemmInterleavedPretransposedNoMergeQuantizedInline(args, qp); }
 },
 {
diff --git a/src/core/NEON/kernels/arm_gemm/gemm_quint8.cpp b/src/core/NEON/kernels/arm_gemm/gemm_quint8.cpp
index b85b1c4fcf..93eecf991e 100644
--- a/src/core/NEON/kernels/arm_gemm/gemm_quint8.cpp
+++ b/src/core/NEON/kernels/arm_gemm/gemm_quint8.cpp
@@ -78,7 +78,7 @@ static const GemmImplementation gemm_quint8_meth
     "sme2_interleaved_nomerge_u8q_mopa_1VLx4VL",
     [](const GemmArgs &args, const Requantize32 &qp) { return args._ci->has_sme2() && ((qp.per_channel_requant && (qp.per_channel_left_shifts == nullptr)) || (!qp.per_channel_requant && (qp.per_layer_left_shift == 0)));},
     [](const GemmArgs &args, const Requantize32 &) { const auto VL = sme::get_vector_length();
-                                                     return args._Msize <= VL || (2*VL < args._Msize && args._Msize <= 3*VL); },
+                                                     return args._Nsize >= 8*VL || args._Msize <= VL || (2*VL < args._Msize && args._Msize <= 3*VL); },
     [](const GemmArgs &args, const Requantize32 &qp) { return new GemmInterleavedPretransposedNoMergeQuantizedInline(args, qp); }
 },
 {
--
cgit v1.2.1
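
Note for readers skimming the patch: the two heuristics it introduces can be restated outside the diff context. The sketch below is illustrative only and is not code from the library. prefer_1VLx4VL restates the updated accept() predicate for the *_1VLx4VL SME2 kernels, and sme_k_block restates the new L1-driven K-blocking from get_k_block_size(). The standalone parameter names (m_size, n_size, vl, l1_bytes, element_size, k_unroll), the final even split of K across num_k_blocks, the rounding to the K unroll factor (the hunk does not show how k_block is derived from num_k_blocks), and the example values in main are all assumptions; in the library these quantities correspond to args._Msize, args._Nsize, sme::get_vector_length(), args._ci->get_L1_cache_size(), get_ktotal(args) and sizeof(Toi).

#include <cstdio>

// Restatement of the updated selection predicate for the *_1VLx4VL SME2
// kernels: take the 1VL-high x 4VL-wide tile shape when the output is very
// wide (N >= 8*VL), when M fits in a single vector length, or when M falls
// in the (2*VL, 3*VL] band.
static bool prefer_1VLx4VL(unsigned int m_size, unsigned int n_size, unsigned int vl) {
    return n_size >= 8 * vl ||
           m_size <= vl ||
           (2 * vl < m_size && m_size <= 3 * vl);
}

// Restatement of the new SME K-blocking: target L1_size/128 bytes per block
// (512 bytes for a 64kB L1, 1024 for 128kB), clamp to a 512-byte minimum,
// skip blocking when K is below 1.25x the target, otherwise split K across
// ceil(K / max_block_size) blocks.  The even split and the rounding to the
// kernel's K unroll factor at the end are assumptions made for this sketch.
static unsigned int sme_k_block(unsigned int k_total, unsigned int l1_bytes,
                                unsigned int element_size, unsigned int k_unroll) {
    unsigned int target_bytes_per_block = l1_bytes / 128;
    if (target_bytes_per_block < 512) {
        target_bytes_per_block = 512;
    }

    // Don't bother to block below this threshold (1.25X target size).
    unsigned int scaling_threshold = ((target_bytes_per_block * 5) / 4) / element_size;
    if (k_total <= scaling_threshold) {
        return k_total;
    }

    unsigned int max_block_size = target_bytes_per_block / element_size;
    unsigned int num_k_blocks   = (k_total + max_block_size - 1) / max_block_size;  // iceildiv
    unsigned int k_block        = (k_total + num_k_blocks - 1) / num_k_blocks;      // iceildiv

    // Round up to the strategy's K unroll factor.
    return ((k_block + k_unroll - 1) / k_unroll) * k_unroll;
}

int main() {
    // Hypothetical example: 512-bit SME implementation (VL = 16 fp32 elements), 64kB L1, fp32 data.
    const unsigned int vl = 16;
    std::printf("M=128, N=2048 -> 1VLx4VL preferred: %d\n", prefer_1VLx4VL(128, 2048, vl));
    std::printf("K=4096, fp32, 64kB L1 -> k_block = %u\n", sme_k_block(4096, 64 * 1024, 4, 2));
    return 0;
}

With these example numbers the second call prints a k_block of 128 elements, i.e. 512 bytes of fp32 per block, which matches the 64kB-L1 target stated in the new comment in get_k_block_size().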