diff options
Diffstat (limited to 'src/core/NEON/kernels/arm_gemm/mergeresults.cpp')
-rw-r--r-- | src/core/NEON/kernels/arm_gemm/mergeresults.cpp | 19 |
1 files changed, 17 insertions, 2 deletions
diff --git a/src/core/NEON/kernels/arm_gemm/mergeresults.cpp b/src/core/NEON/kernels/arm_gemm/mergeresults.cpp index 563c31d7dc..e100d9fe46 100644 --- a/src/core/NEON/kernels/arm_gemm/mergeresults.cpp +++ b/src/core/NEON/kernels/arm_gemm/mergeresults.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017-2018 Arm Limited. + * Copyright (c) 2017-2021, 2024 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -30,15 +30,20 @@ #include "arm_gemm.hpp" #include "asmlib.hpp" +#include "bfloat.hpp" #include "utils.hpp" namespace arm_gemm { template<unsigned int twidth, unsigned int height, bool sve=false, typename Tin, typename Tout> void MergeResults(Tout * out, const Tin * in, int ldc, int y0, int ymax, int x0, int xmax, const Tout *bias, Activation act, bool append) { + // NOTE: The following code is disabled to avoid calling get_vector_length(), so templated MergeResults will not + // be correct for SVE cases. This is OK as we have specialisations for all needed SVE cases anyway. + // // For SVE cases, multiply the width up by the vector length. // Use the *input* type to determine this, since this will be what the kernel operated on. - const int width = twidth * (sve ? get_vector_length<Tin>() : 1); + // const int width = twidth * (sve ? get_vector_length<Tin>() : 1); + const int width = twidth; const int full_y_blocks = (ymax - y0) / height; const int y_remainder = (ymax - y0) % height; @@ -96,6 +101,12 @@ void MergeResults(Tout * out, const Tin * in, int ldc, int y0, int ymax, int x0, #include "merges/list.hpp" +/* Cortex-A53 8x6 SGEMM kernel uses a templated merge as the optimized merge + * generator cannot cope with the width (6) not being a multiple of VL (4). */ +#ifdef __aarch64__ +template void MergeResults<6u, 8u, false, float, float>(float *, float const*, int, int, int, int, int, float const *, Activation, bool); +#endif + #if defined(__aarch64__) && defined(__ARM_FP16_ARGS) template void MergeResults<12u, 8u, false, float, __fp16>(__fp16*, float const*, int, int, int, int, int, __fp16 const*, Activation, bool); #endif @@ -104,4 +115,8 @@ template void MergeResults<12u, 8u, false, float, __fp16>(__fp16*, float const*, template void MergeResults<8u, 6u, false, float, __fp16>(__fp16*, float const*, int, int, int, int, int, __fp16 const*, Activation, bool); #endif +#if defined(__arm__) && defined(ARM_COMPUTE_ENABLE_BF16) +template void MergeResults<8u, 6u, false, float, bfloat16>(bfloat16*, float const*, int, int, int, int, int, bfloat16 const*, Activation, bool); +#endif + } // namespace arm_gemm |