diff options
Diffstat (limited to 'src/core/NEON/kernels/arm_gemm/mergeresults.cpp')
-rw-r--r-- | src/core/NEON/kernels/arm_gemm/mergeresults.cpp | 14 |
1 file changed, 11 insertions, 3 deletions
diff --git a/src/core/NEON/kernels/arm_gemm/mergeresults.cpp b/src/core/NEON/kernels/arm_gemm/mergeresults.cpp
index 17566db375..e100d9fe46 100644
--- a/src/core/NEON/kernels/arm_gemm/mergeresults.cpp
+++ b/src/core/NEON/kernels/arm_gemm/mergeresults.cpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2017-2018, 2021 Arm Limited.
+ * Copyright (c) 2017-2021, 2024 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -25,21 +25,25 @@
 /* As some of the merges need these headers, but are all included in the
  * arm_gemm namespace, put these headers here. */
 #include <algorithm>
-#include <limits>
 
 #include <arm_neon.h>
 
 #include "arm_gemm.hpp"
 #include "asmlib.hpp"
+#include "bfloat.hpp"
 #include "utils.hpp"
 
 namespace arm_gemm {
 
 template<unsigned int twidth, unsigned int height, bool sve=false, typename Tin, typename Tout>
 void MergeResults(Tout * out, const Tin * in, int ldc, int y0, int ymax, int x0, int xmax, const Tout *bias, Activation act, bool append) {
+    // NOTE: The following code is disabled to avoid calling get_vector_length(), so templated MergeResults will not
+    // be correct for SVE cases. This is OK as we have specialisations for all needed SVE cases anyway.
+    //
     // For SVE cases, multiply the width up by the vector length.
     // Use the *input* type to determine this, since this will be what the kernel operated on.
-    const int width = twidth * (sve ? get_vector_length<Tin>() : 1);
+    // const int width = twidth * (sve ? get_vector_length<Tin>() : 1);
+    const int width = twidth;
 
     const int full_y_blocks = (ymax - y0) / height;
     const int y_remainder = (ymax - y0) % height;
@@ -111,4 +115,8 @@ template void MergeResults<12u, 8u, false, float, __fp16>(__fp16*, float const*,
 template void MergeResults<8u, 6u, false, float, __fp16>(__fp16*, float const*, int, int, int, int, int, __fp16 const*, Activation, bool);
 #endif
 
+#if defined(__arm__) && defined(ARM_COMPUTE_ENABLE_BF16)
+template void MergeResults<8u, 6u, false, float, bfloat16>(bfloat16*, float const*, int, int, int, int, int, bfloat16 const*, Activation, bool);
+#endif
+
 } // namespace arm_gemm