aboutsummaryrefslogtreecommitdiff
path: root/src/core/NEON/kernels/arm_gemm/gemm_hybrid_quantized.hpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/core/NEON/kernels/arm_gemm/gemm_hybrid_quantized.hpp')
-rw-r--r--src/core/NEON/kernels/arm_gemm/gemm_hybrid_quantized.hpp15
1 files changed, 12 insertions, 3 deletions
diff --git a/src/core/NEON/kernels/arm_gemm/gemm_hybrid_quantized.hpp b/src/core/NEON/kernels/arm_gemm/gemm_hybrid_quantized.hpp
index 915227fc29..7a5fa87ee6 100644
--- a/src/core/NEON/kernels/arm_gemm/gemm_hybrid_quantized.hpp
+++ b/src/core/NEON/kernels/arm_gemm/gemm_hybrid_quantized.hpp
@@ -118,18 +118,27 @@ class GemmHybridQuantized : public GemmCommon<To, Tr> {
// n_block: Work out how many rows (of length k_block) will fit in the L2
// Don't allocate more than 90% of the L2 to allow for overheads, and subtract off the L1 contents.
- unsigned int n_block = (((L2_size * 9) / 10) - (k_block * sizeof(Toi) * (strategy::out_width() + strategy::out_height()))) /
- (sizeof(Toi) * k_block);
+ const unsigned int scaled_l2_size = (L2_size * 9) / 10;
+ const unsigned int k_block_area = k_block * sizeof(Toi) * (strategy::out_width() + strategy::out_height());
+
+ // .. if the L1 contents is bigger than the L2, just return a minimal size block.
+ if (k_block_area > scaled_l2_size) {
+ return strategy::out_width();
+ }
+
+ unsigned int n_block = (scaled_l2_size - k_block_area) / (sizeof(Toi) * k_block);
// Needs to be (at least a single) multiple of the kernel output width.
n_block /= strategy::out_width();
- n_block = std::max(n_block, 1U) * strategy::out_width();
+ n_block = std::max(n_block, 1u) * strategy::out_width();
// And tune to the presented problem size.
unsigned int numblocks = iceildiv(args._Nsize, n_block);
n_block = iceildiv(args._Nsize, numblocks);
n_block = roundup(n_block, strategy::out_width());
+ assert(n_block > 0);
+
return n_block;
}