diff options
Diffstat (limited to 'src/cpu/kernels/gemm_matrix_add/generic/neon/fp16.cpp')
-rw-r--r-- | src/cpu/kernels/gemm_matrix_add/generic/neon/fp16.cpp | 44 |
1 files changed, 23 insertions, 21 deletions
diff --git a/src/cpu/kernels/gemm_matrix_add/generic/neon/fp16.cpp b/src/cpu/kernels/gemm_matrix_add/generic/neon/fp16.cpp index 505a37174e..4d7507a5da 100644 --- a/src/cpu/kernels/gemm_matrix_add/generic/neon/fp16.cpp +++ b/src/cpu/kernels/gemm_matrix_add/generic/neon/fp16.cpp @@ -48,30 +48,32 @@ void matrix_addition_f16(const ITensor *src, ITensor *dst, const Window &window, Iterator in(src, win); Iterator out(dst, win); - execute_window_loop(win, [&](const Coordinates &) - { - const auto in_ptr = reinterpret_cast<const float16_t *>(in.ptr()); - const auto out_ptr = reinterpret_cast<float16_t *>(out.ptr()); - - int x = window_start_x; - for(; x < (window_end_x - window_step_x); x += window_step_x) + execute_window_loop( + win, + [&](const Coordinates &) { - float16x8x2_t alpha_ab = vld2q_f16(out_ptr + x); - const float16x8x2_t c = vld2q_f16(in_ptr + x); - // Multiply matrix C by its weight and accumulate - alpha_ab.val[0] = vaddq_f16(alpha_ab.val[0], vmulq_f16(c.val[0], beta_f16)); - alpha_ab.val[1] = vaddq_f16(alpha_ab.val[1], vmulq_f16(c.val[1], beta_f16)); + const auto in_ptr = reinterpret_cast<const float16_t *>(in.ptr()); + const auto out_ptr = reinterpret_cast<float16_t *>(out.ptr()); - vst2q_f16(out_ptr + x, alpha_ab); - } + int x = window_start_x; + for (; x < (window_end_x - window_step_x); x += window_step_x) + { + float16x8x2_t alpha_ab = vld2q_f16(out_ptr + x); + const float16x8x2_t c = vld2q_f16(in_ptr + x); + // Multiply matrix C by its weight and accumulate + alpha_ab.val[0] = vaddq_f16(alpha_ab.val[0], vmulq_f16(c.val[0], beta_f16)); + alpha_ab.val[1] = vaddq_f16(alpha_ab.val[1], vmulq_f16(c.val[1], beta_f16)); - // Left-over loop - for(; x < window_end_x; ++x) - { - *(out_ptr + x) += *(in_ptr + x) * static_cast<float16_t>(beta); - } - }, - in, out); + vst2q_f16(out_ptr + x, alpha_ab); + } + + // Left-over loop + for (; x < window_end_x; ++x) + { + *(out_ptr + x) += *(in_ptr + x) * static_cast<float16_t>(beta); + } + }, + in, out); } } // namespace void neon_fp16_gemm_matrix_add(const ITensor *src, ITensor *dst, const Window &window, float beta) |