about | summary | refs | log | tree | commit | diff
path: root/src/core/NEON/kernels/arm_gemm/mergeresults.hpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/core/NEON/kernels/arm_gemm/mergeresults.hpp')
-rw-r--r-- src/core/NEON/kernels/arm_gemm/mergeresults.hpp 18
1 files changed, 11 insertions, 7 deletions
diff --git a/src/core/NEON/kernels/arm_gemm/mergeresults.hpp b/src/core/NEON/kernels/arm_gemm/mergeresults.hpp
index b1e2ca1daa..04d1343b1c 100644
--- a/src/core/NEON/kernels/arm_gemm/mergeresults.hpp
+++ b/src/core/NEON/kernels/arm_gemm/mergeresults.hpp
@@ -32,15 +32,19 @@
namespace arm_gemm {
-template<unsigned int width, unsigned int height, typename Tin, typename Tout>
+template<unsigned int twidth, unsigned int height, bool sve=false, typename Tin, typename Tout>
inline void MergeResults(Tout * out, const Tin * in, int ldc, int y0, int ymax, int x0, int xmax, const Tout alpha, const Tout beta) {
- int full_y_blocks = (ymax - y0) / height;
- int y_remainder = (ymax - y0) % height;
- int y_blocks = full_y_blocks + (y_remainder ? 1 : 0);
+ // For SVE cases, multiply the width up by the vector length.
+ // Use the *input* type to determine this, since this will be what the kernel operated on.
+ const int width = twidth * (sve ? get_vector_length<Tin>() : 1);
- int full_x_blocks = (xmax - x0) / width;
- int x_remainder = (xmax - x0) % width;
- int x_blocks = full_x_blocks + (x_remainder ? 1 : 0);
+ const int full_y_blocks = (ymax - y0) / height;
+ const int y_remainder = (ymax - y0) % height;
+ const int y_blocks = full_y_blocks + (y_remainder ? 1 : 0);
+
+ const int full_x_blocks = (xmax - x0) / width;
+ const int x_remainder = (xmax - x0) % width;
+ const int x_blocks = full_x_blocks + (x_remainder ? 1 : 0);
for (int y_block = 0; y_block < y_blocks; y_block++) {
int ybase = y0 + (y_block * height);