diff options
Diffstat (limited to 'src/core/NEON/kernels/arm_gemm/mergeresults.hpp')
-rw-r--r-- | src/core/NEON/kernels/arm_gemm/mergeresults.hpp | 18 |
1 files changed, 11 insertions, 7 deletions
diff --git a/src/core/NEON/kernels/arm_gemm/mergeresults.hpp b/src/core/NEON/kernels/arm_gemm/mergeresults.hpp index b1e2ca1daa..04d1343b1c 100644 --- a/src/core/NEON/kernels/arm_gemm/mergeresults.hpp +++ b/src/core/NEON/kernels/arm_gemm/mergeresults.hpp @@ -32,15 +32,19 @@ namespace arm_gemm { -template<unsigned int width, unsigned int height, typename Tin, typename Tout> +template<unsigned int twidth, unsigned int height, bool sve=false, typename Tin, typename Tout> inline void MergeResults(Tout * out, const Tin * in, int ldc, int y0, int ymax, int x0, int xmax, const Tout alpha, const Tout beta) { - int full_y_blocks = (ymax - y0) / height; - int y_remainder = (ymax - y0) % height; - int y_blocks = full_y_blocks + (y_remainder ? 1 : 0); + // For SVE cases, multiply the width up by the vector length. + // Use the *input* type to determine this, since this will be what the kernel operated on. + const int width = twidth * (sve ? get_vector_length<Tin>() : 1); - int full_x_blocks = (xmax - x0) / width; - int x_remainder = (xmax - x0) % width; - int x_blocks = full_x_blocks + (x_remainder ? 1 : 0); + const int full_y_blocks = (ymax - y0) / height; + const int y_remainder = (ymax - y0) % height; + const int y_blocks = full_y_blocks + (y_remainder ? 1 : 0); + + const int full_x_blocks = (xmax - x0) / width; + const int x_remainder = (xmax - x0) % width; + const int x_blocks = full_x_blocks + (x_remainder ? 1 : 0); for (int y_block = 0; y_block < y_blocks; y_block++) { int ybase = y0 + (y_block * height); |