aboutsummaryrefslogtreecommitdiff
path: root/src/core/NEON/kernels/arm_gemm/mergeresults.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/core/NEON/kernels/arm_gemm/mergeresults.cpp')
-rw-r--r--src/core/NEON/kernels/arm_gemm/mergeresults.cpp14
1 files changed, 11 insertions, 3 deletions
diff --git a/src/core/NEON/kernels/arm_gemm/mergeresults.cpp b/src/core/NEON/kernels/arm_gemm/mergeresults.cpp
index 17566db375..e100d9fe46 100644
--- a/src/core/NEON/kernels/arm_gemm/mergeresults.cpp
+++ b/src/core/NEON/kernels/arm_gemm/mergeresults.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017-2018, 2021 Arm Limited.
+ * Copyright (c) 2017-2021, 2024 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -25,21 +25,25 @@
/* As some of the merges need these headers, but are all included in the
* arm_gemm namespace, put these headers here. */
#include <algorithm>
-#include <limits>
#include <arm_neon.h>
#include "arm_gemm.hpp"
#include "asmlib.hpp"
+#include "bfloat.hpp"
#include "utils.hpp"
namespace arm_gemm {
template<unsigned int twidth, unsigned int height, bool sve=false, typename Tin, typename Tout>
void MergeResults(Tout * out, const Tin * in, int ldc, int y0, int ymax, int x0, int xmax, const Tout *bias, Activation act, bool append) {
+ // NOTE: The following code is disabled to avoid calling get_vector_length(), so templated MergeResults will not
+ // be correct for SVE cases. This is OK as we have specialisations for all needed SVE cases anyway.
+ //
// For SVE cases, multiply the width up by the vector length.
// Use the *input* type to determine this, since this will be what the kernel operated on.
- const int width = twidth * (sve ? get_vector_length<Tin>() : 1);
+ // const int width = twidth * (sve ? get_vector_length<Tin>() : 1);
+ const int width = twidth;
const int full_y_blocks = (ymax - y0) / height;
const int y_remainder = (ymax - y0) % height;
@@ -111,4 +115,8 @@ template void MergeResults<12u, 8u, false, float, __fp16>(__fp16*, float const*,
template void MergeResults<8u, 6u, false, float, __fp16>(__fp16*, float const*, int, int, int, int, int, __fp16 const*, Activation, bool);
#endif
+#if defined(__arm__) && defined(ARM_COMPUTE_ENABLE_BF16)
+template void MergeResults<8u, 6u, false, float, bfloat16>(bfloat16*, float const*, int, int, int, int, int, bfloat16 const*, Activation, bool);
+#endif
+
} // namespace arm_gemm