diff options
Diffstat (limited to 'src/core/NEON/kernels/arm_gemm/kernels/a64_hybrid_u8u32_dot_16x4/a55.cpp')
-rw-r--r-- | src/core/NEON/kernels/arm_gemm/kernels/a64_hybrid_u8u32_dot_16x4/a55.cpp | 20 |
1 files changed, 10 insertions, 10 deletions
diff --git a/src/core/NEON/kernels/arm_gemm/kernels/a64_hybrid_u8u32_dot_16x4/a55.cpp b/src/core/NEON/kernels/arm_gemm/kernels/a64_hybrid_u8u32_dot_16x4/a55.cpp index 91870e2e54..735e5fd45a 100644 --- a/src/core/NEON/kernels/arm_gemm/kernels/a64_hybrid_u8u32_dot_16x4/a55.cpp +++ b/src/core/NEON/kernels/arm_gemm/kernels/a64_hybrid_u8u32_dot_16x4/a55.cpp @@ -32,7 +32,7 @@ namespace arm_gemm { -void a64_hybrid_u8u32_dot_16x4_a55(const uint8_t *A, int lda, const uint8_t *B, uint32_t *C, int ldc, int M, int N, int K, const uint32_t *, Activation , bool append) { +void a64_hybrid_u8u32_dot_16x4_a55(const uint8_t *A, int lda, const uint8_t *B, uint32_t *C, int ldc, int M, int N, int K, const uint32_t *, Activation , bool accumulate) { const int K_stride = ((K + 3) / 4) * 4; const long loops_count = ((K + 16) / 32) - 1; K -= loops_count * 32; @@ -70,7 +70,7 @@ void a64_hybrid_u8u32_dot_16x4_a55(const uint8_t *A, int lda, const uint8_t *B, uint32_t result_buffer[64]; const unsigned long ldcb = (use_result_buffer ? 16 : ldc) * sizeof(uint32_t); uint32_t *c_ptr_real = c_ptr0; - if (use_result_buffer && append) { + if (use_result_buffer && accumulate) { for(int cy=0; cy<std::min(M-y, 4); cy++) { for(unsigned int cx=0; cx<width; cx++) { result_buffer[cy * 16 + cx] = c_ptr_real[cy * ldc + cx]; @@ -88,7 +88,7 @@ void a64_hybrid_u8u32_dot_16x4_a55(const uint8_t *A, int lda, const uint8_t *B, "temploadreg1 .req X1\n" "temploadreg2 .req X2\n" "temploadreg3 .req X3\n" - "cbnz %[append], 1f\n" + "cbnz %[accumulate], 1f\n" "movi v16.4s, #0\n" "ldr q0, [%[a_ptr0]]\n" "movi v17.4s, #0\n" @@ -469,7 +469,7 @@ void a64_hybrid_u8u32_dot_16x4_a55(const uint8_t *A, int lda, const uint8_t *B, ".unreq temploadreg2\n" ".unreq temploadreg3\n" : [a_ptr0] "+r" (a_ptr0), [b_ptr0] "+r" (b_ptr0), [c_ptr0] "+r" (c_ptr0), [loops] "+r" (loops), [regs] "+r" (regs), [blocks] "+r" (blocks), [odds] "+r" (odds) - : [width] "r" (width), [append] "r" (static_cast<uint64_t>(append)), [lda] "r" (ldab), [ldc] "r" (ldcb) + : [width] "r" (width), [accumulate] "r" (static_cast<uint64_t>(accumulate)), [lda] "r" (ldab), [ldc] "r" (ldcb) : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "x0", "x1", "x2", "x3", "cc", "memory" ); break; @@ -483,7 +483,7 @@ void a64_hybrid_u8u32_dot_16x4_a55(const uint8_t *A, int lda, const uint8_t *B, "temploadreg3 .req X5\n" "add a_ptr1, %[a_ptr0], %[lda]\n" "add c_ptr1, %[c_ptr0], %[ldc]\n" - "cbnz %[append], 1f\n" + "cbnz %[accumulate], 1f\n" "movi v16.4s, #0\n" "ldr q0, [%[a_ptr0]]\n" "movi v17.4s, #0\n" @@ -988,7 +988,7 @@ void a64_hybrid_u8u32_dot_16x4_a55(const uint8_t *A, int lda, const uint8_t *B, ".unreq temploadreg2\n" ".unreq temploadreg3\n" : [a_ptr0] "+r" (a_ptr0), [b_ptr0] "+r" (b_ptr0), [c_ptr0] "+r" (c_ptr0), [loops] "+r" (loops), [regs] "+r" (regs), [blocks] "+r" (blocks), [odds] "+r" (odds) - : [width] "r" (width), [append] "r" (static_cast<uint64_t>(append)), [lda] "r" (ldab), [ldc] "r" (ldcb) + : [width] "r" (width), [accumulate] "r" (static_cast<uint64_t>(accumulate)), [lda] "r" (ldab), [ldc] "r" (ldcb) : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "x0", "x1", "x2", "x3", "x4", "x5", "cc", "memory" ); break; @@ -1006,7 +1006,7 @@ void a64_hybrid_u8u32_dot_16x4_a55(const uint8_t *A, int lda, const uint8_t *B, "add c_ptr1, %[c_ptr0], %[ldc]\n" "add a_ptr2, a_ptr1, %[lda]\n" "add c_ptr2, c_ptr1, %[ldc]\n" - "cbnz %[append], 1f\n" + "cbnz %[accumulate], 1f\n" "movi v16.4s, #0\n" "ldr q0, [%[a_ptr0]]\n" "movi v17.4s, #0\n" @@ -1636,7 +1636,7 @@ void a64_hybrid_u8u32_dot_16x4_a55(const uint8_t *A, int lda, const uint8_t *B, ".unreq temploadreg2\n" ".unreq temploadreg3\n" : [a_ptr0] "+r" (a_ptr0), [b_ptr0] "+r" (b_ptr0), [c_ptr0] "+r" (c_ptr0), [loops] "+r" (loops), [regs] "+r" (regs), [blocks] "+r" (blocks), [odds] "+r" (odds) - : [width] "r" (width), [append] "r" (static_cast<uint64_t>(append)), [lda] "r" (ldab), [ldc] "r" (ldcb) + : [width] "r" (width), [accumulate] "r" (static_cast<uint64_t>(accumulate)), [lda] "r" (ldab), [ldc] "r" (ldcb) : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "x0", "x1", "x2", "x3", "x4", "x5", "x6", "x7", "cc", "memory" ); break; @@ -1659,7 +1659,7 @@ void a64_hybrid_u8u32_dot_16x4_a55(const uint8_t *A, int lda, const uint8_t *B, "add c_ptr2, c_ptr1, %[ldc]\n" "add a_ptr3, a_ptr2, %[lda]\n" "add c_ptr3, c_ptr2, %[ldc]\n" - "cbnz %[append], 1f\n" + "cbnz %[accumulate], 1f\n" "movi v16.4s, #0\n" "ldr q0, [%[a_ptr0]]\n" "movi v17.4s, #0\n" @@ -2413,7 +2413,7 @@ void a64_hybrid_u8u32_dot_16x4_a55(const uint8_t *A, int lda, const uint8_t *B, ".unreq temploadreg2\n" ".unreq temploadreg3\n" : [a_ptr0] "+r" (a_ptr0), [b_ptr0] "+r" (b_ptr0), [c_ptr0] "+r" (c_ptr0), [loops] "+r" (loops), [regs] "+r" (regs), [blocks] "+r" (blocks), [odds] "+r" (odds) - : [width] "r" (width), [append] "r" (static_cast<uint64_t>(append)), [lda] "r" (ldab), [ldc] "r" (ldcb) + : [width] "r" (width), [accumulate] "r" (static_cast<uint64_t>(accumulate)), [lda] "r" (ldab), [ldc] "r" (ldcb) : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "x0", "x1", "x2", "x3", "x4", "x5", "x6", "x7", "x8", "x9", "cc", "memory" ); break; |