From 4ee8b1599dbaf7634d25607fa5ac96ba3dc6b0f2 Mon Sep 17 00:00:00 2001 From: Georgios Pinitas Date: Fri, 16 Jul 2021 16:16:43 +0100 Subject: Update GEMM assembly kernels - Introduce Fp32 kernels with internal calculations in Bfloat16 when fast_mode is enabled - Improve kernel selection heuristics Signed-off-by: Georgios Pinitas Change-Id: I68a9e7e862b6fd2721b46e0d7cc791091c4ab279 Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/5965 Tested-by: Arm Jenkins Comments-Addressed: Arm Jenkins --- src/core/NEON/kernels/arm_gemm/gemm_hybrid.hpp | 23 ++++++++++++++++++++--- 1 file changed, 20 insertions(+), 3 deletions(-) (limited to 'src/core/NEON/kernels/arm_gemm/gemm_hybrid.hpp') diff --git a/src/core/NEON/kernels/arm_gemm/gemm_hybrid.hpp b/src/core/NEON/kernels/arm_gemm/gemm_hybrid.hpp index d702cffce1..436316c0f7 100644 --- a/src/core/NEON/kernels/arm_gemm/gemm_hybrid.hpp +++ b/src/core/NEON/kernels/arm_gemm/gemm_hybrid.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017-2020 Arm Limited. + * Copyright (c) 2017-2021 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -74,7 +74,7 @@ class GemmHybrid : public GemmCommon { } if (args._cfg && args._cfg->inner_block_size) { - return args._cfg->inner_block_size; + return roundup(args._cfg->inner_block_size, strategy::k_unroll()); } // Target block size (512 for FP32, scaling for other types). Don't block until size reaches 1.5X this. @@ -97,7 +97,13 @@ class GemmHybrid : public GemmCommon { // single block. static unsigned int compute_n_block(const GemmArgs &args) { if (args._cfg && args._cfg->outer_block_size) { - return args._cfg->outer_block_size; + unsigned int n_block = args._cfg->outer_block_size; + + // Needs to be (at least a single) multiple of the kernel output width. + n_block /= strategy::out_width(); + n_block = std::max(n_block, 1u) * strategy::out_width(); + + return n_block; } if (args._Nsize <= 64) { @@ -264,6 +270,17 @@ public: return total_cycles; } + + GemmConfig get_config() override { + GemmConfig c; + + c.method = GemmMethod::GEMM_HYBRID; + c.inner_block_size = _k_block; + c.outer_block_size = _n_block; + c.filter = get_type_name(); + + return c; + } }; } // namespace arm_gemm -- cgit v1.2.1