diff options
Diffstat (limited to 'src/core/NEON/kernels/arm_gemm/gemm_hybrid.hpp')
-rw-r--r-- | src/core/NEON/kernels/arm_gemm/gemm_hybrid.hpp | 23 |
1 files changed, 20 insertions, 3 deletions
diff --git a/src/core/NEON/kernels/arm_gemm/gemm_hybrid.hpp b/src/core/NEON/kernels/arm_gemm/gemm_hybrid.hpp index d702cffce1..436316c0f7 100644 --- a/src/core/NEON/kernels/arm_gemm/gemm_hybrid.hpp +++ b/src/core/NEON/kernels/arm_gemm/gemm_hybrid.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017-2020 Arm Limited. + * Copyright (c) 2017-2021 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -74,7 +74,7 @@ class GemmHybrid : public GemmCommon<To, Tr> { } if (args._cfg && args._cfg->inner_block_size) { - return args._cfg->inner_block_size; + return roundup(args._cfg->inner_block_size, strategy::k_unroll()); } // Target block size (512 for FP32, scaling for other types). Don't block until size reaches 1.5X this. @@ -97,7 +97,13 @@ class GemmHybrid : public GemmCommon<To, Tr> { // single block. static unsigned int compute_n_block(const GemmArgs &args) { if (args._cfg && args._cfg->outer_block_size) { - return args._cfg->outer_block_size; + unsigned int n_block = args._cfg->outer_block_size; + + // Needs to be (at least a single) multiple of the kernel output width. + n_block /= strategy::out_width(); + n_block = std::max(n_block, 1u) * strategy::out_width(); + + return n_block; } if (args._Nsize <= 64) { @@ -264,6 +270,17 @@ public: return total_cycles; } + + GemmConfig get_config() override { + GemmConfig c; + + c.method = GemmMethod::GEMM_HYBRID; + c.inner_block_size = _k_block; + c.outer_block_size = _n_block; + c.filter = get_type_name<strategy>(); + + return c; + } }; } // namespace arm_gemm |