From c0b6f76561580414f08633a804fc548ccad65659 Mon Sep 17 00:00:00 2001 From: Georgios Pinitas Date: Mon, 2 Nov 2020 01:37:17 +0000 Subject: COMPMID-3776: Indirect GEMM Signed-off-by: Georgios Pinitas Change-Id: I51a1b0f098bc3a8c408c50c92221e4df3061e12c Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/4343 Tested-by: Arm Jenkins Reviewed-by: Sang-Hoon Park Reviewed-by: Michele Di Giorgio Comments-Addressed: Arm Jenkins --- src/core/NEON/kernels/arm_gemm/gemm_hybrid_quantized.hpp | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) (limited to 'src/core/NEON/kernels/arm_gemm/gemm_hybrid_quantized.hpp') diff --git a/src/core/NEON/kernels/arm_gemm/gemm_hybrid_quantized.hpp b/src/core/NEON/kernels/arm_gemm/gemm_hybrid_quantized.hpp index 915227fc29..7a5fa87ee6 100644 --- a/src/core/NEON/kernels/arm_gemm/gemm_hybrid_quantized.hpp +++ b/src/core/NEON/kernels/arm_gemm/gemm_hybrid_quantized.hpp @@ -118,18 +118,27 @@ class GemmHybridQuantized : public GemmCommon { // n_block: Work out how many rows (of length k_block) will fit in the L2 // Don't allocate more than 90% of the L2 to allow for overheads, and subtract off the L1 contents. - unsigned int n_block = (((L2_size * 9) / 10) - (k_block * sizeof(Toi) * (strategy::out_width() + strategy::out_height()))) / - (sizeof(Toi) * k_block); + const unsigned int scaled_l2_size = (L2_size * 9) / 10; + const unsigned int k_block_area = k_block * sizeof(Toi) * (strategy::out_width() + strategy::out_height()); + + // .. if the L1 contents is bigger than the L2, just return a minimal size block. + if (k_block_area > scaled_l2_size) { + return strategy::out_width(); + } + + unsigned int n_block = (scaled_l2_size - k_block_area) / (sizeof(Toi) * k_block); // Needs to be (at least a single) multiple of the kernel output width. n_block /= strategy::out_width(); - n_block = std::max(n_block, 1U) * strategy::out_width(); + n_block = std::max(n_block, 1u) * strategy::out_width(); // And tune to the presented problem size. unsigned int numblocks = iceildiv(args._Nsize, n_block); n_block = iceildiv(args._Nsize, numblocks); n_block = roundup(n_block, strategy::out_width()); + assert(n_block > 0); + return n_block; } -- cgit v1.2.1