aboutsummaryrefslogtreecommitdiff
path: root/src/core/NEON/kernels/arm_gemm/kernels/a64_hybrid_u8u32_dot_16x4/a55.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/core/NEON/kernels/arm_gemm/kernels/a64_hybrid_u8u32_dot_16x4/a55.cpp')
-rw-r--r--src/core/NEON/kernels/arm_gemm/kernels/a64_hybrid_u8u32_dot_16x4/a55.cpp22
1 files changed, 15 insertions, 7 deletions
diff --git a/src/core/NEON/kernels/arm_gemm/kernels/a64_hybrid_u8u32_dot_16x4/a55.cpp b/src/core/NEON/kernels/arm_gemm/kernels/a64_hybrid_u8u32_dot_16x4/a55.cpp
index e8ed0c311e..91870e2e54 100644
--- a/src/core/NEON/kernels/arm_gemm/kernels/a64_hybrid_u8u32_dot_16x4/a55.cpp
+++ b/src/core/NEON/kernels/arm_gemm/kernels/a64_hybrid_u8u32_dot_16x4/a55.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2018-2019 Arm Limited.
+ * Copyright (c) 2018-2020 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -32,10 +32,7 @@
namespace arm_gemm {
-void a64_hybrid_u8u32_dot_16x4_a55(const uint8_t *A, int lda, const uint8_t *B, uint32_t *C, int ldc, int M, int N, int K, const uint32_t *bias, Activation act, bool append) {
- UNUSED(bias);
- UNUSED(act);
-
+void a64_hybrid_u8u32_dot_16x4_a55(const uint8_t *A, int lda, const uint8_t *B, uint32_t *C, int ldc, int M, int N, int K, const uint32_t *, Activation , bool append) {
const int K_stride = ((K + 3) / 4) * 4;
const long loops_count = ((K + 16) / 32) - 1;
K -= loops_count * 32;
@@ -44,12 +41,23 @@ void a64_hybrid_u8u32_dot_16x4_a55(const uint8_t *A, int lda, const uint8_t *B,
const long blocks_count = K / 4;
const long odds_count = K - (blocks_count * 4);
- for (int y=0; y<M; y+=4) {
+ int rows_to_compute;
+
+ for (int y=0; y<M; y+=rows_to_compute) {
const uint8_t * const a_ptr0_base = A + (y * lda);
const unsigned long ldab = lda * sizeof(uint8_t);
uint32_t *c_ptr0 = C + (y * ldc);
+ rows_to_compute = M-y;
+ if (rows_to_compute > 4) {
+ if (rows_to_compute % 4) {
+ rows_to_compute = 4 - 1;
+ } else {
+ rows_to_compute = 4;
+ }
+ }
+
for (int x0=0; x0<N; x0+=16ul) {
const long width = std::min((unsigned long)N-x0, 16ul);
long loops = loops_count;
@@ -73,7 +81,7 @@ void a64_hybrid_u8u32_dot_16x4_a55(const uint8_t *A, int lda, const uint8_t *B,
c_ptr0 = result_buffer;
}
- switch(M-y) {
+ switch(rows_to_compute) {
case 1:
__asm __volatile (
"temploadreg0 .req X0\n"