aboutsummaryrefslogtreecommitdiff
path: root/src/core/NEON/kernels/arm_gemm/kernels/sve_hybrid_u8u32_dot_4VLx4/generic.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/core/NEON/kernels/arm_gemm/kernels/sve_hybrid_u8u32_dot_4VLx4/generic.cpp')
-rw-r--r--src/core/NEON/kernels/arm_gemm/kernels/sve_hybrid_u8u32_dot_4VLx4/generic.cpp17
1 files changed, 14 insertions, 3 deletions
diff --git a/src/core/NEON/kernels/arm_gemm/kernels/sve_hybrid_u8u32_dot_4VLx4/generic.cpp b/src/core/NEON/kernels/arm_gemm/kernels/sve_hybrid_u8u32_dot_4VLx4/generic.cpp
index 4fb7e825b5..13614700e3 100644
--- a/src/core/NEON/kernels/arm_gemm/kernels/sve_hybrid_u8u32_dot_4VLx4/generic.cpp
+++ b/src/core/NEON/kernels/arm_gemm/kernels/sve_hybrid_u8u32_dot_4VLx4/generic.cpp
@@ -32,7 +32,7 @@
namespace arm_gemm {
-void sve_hybrid_u8u32_dot_4VLx4(const uint8_t *A, int lda, const uint8_t *B, uint32_t *C, int ldc, int M, int N, int K, const uint32_t *, Activation, bool append) {
+void sve_hybrid_u8u32_dot_4VLx4(const uint8_t *A, int lda, const uint8_t *B, uint32_t *C, int ldc, int M, int N, int K, const uint32_t *, Activation , bool append) {
const int K_stride = ((K + 3) / 4) * 4;
const long loops_count = ((K + 16) / 32) - 1;
K -= loops_count * 32;
@@ -41,12 +41,23 @@ void sve_hybrid_u8u32_dot_4VLx4(const uint8_t *A, int lda, const uint8_t *B, uin
const long leftovers = K;
const long blocks_count = (K + 3) / 4;
- for (int y=0; y<M; y+=4) {
+ int rows_to_compute;
+
+ for (int y=0; y<M; y+=rows_to_compute) {
const uint8_t * const a_ptr0_base = A + (y * lda);
const unsigned long ldab = lda * sizeof(uint8_t);
uint32_t *c_ptr0 = C + (y * ldc);
+ rows_to_compute = M-y;
+ if (rows_to_compute > 4) {
+ if (rows_to_compute % 4) {
+ rows_to_compute = 4 - 1;
+ } else {
+ rows_to_compute = 4;
+ }
+ }
+
for (int x0=0; x0<N; x0+=(4 * get_vector_length<uint32_t>())) {
const long width = std::min((unsigned long)N-x0, (4 * get_vector_length<uint32_t>()));
long loops = loops_count;
@@ -57,7 +68,7 @@ void sve_hybrid_u8u32_dot_4VLx4(const uint8_t *A, int lda, const uint8_t *B, uin
const uint8_t *b_ptr0 = B + (K_stride * x0);
const unsigned long ldcb = ldc * sizeof(uint32_t);
- switch(M-y) {
+ switch(rows_to_compute) {
case 1:
__asm __volatile (
"whilelt p6.b, %[temp], %[leftovers]\n"