diff options
Diffstat (limited to 'src/cpu/kernels/gemm_matrix_add/generic/neon/impl.cpp')
-rw-r--r-- | src/cpu/kernels/gemm_matrix_add/generic/neon/impl.cpp | 49 |
1 files changed, 26 insertions, 23 deletions
diff --git a/src/cpu/kernels/gemm_matrix_add/generic/neon/impl.cpp b/src/cpu/kernels/gemm_matrix_add/generic/neon/impl.cpp index dd0384ca13..47de0f3928 100644 --- a/src/cpu/kernels/gemm_matrix_add/generic/neon/impl.cpp +++ b/src/cpu/kernels/gemm_matrix_add/generic/neon/impl.cpp @@ -23,6 +23,7 @@ */ #include "src/cpu/kernels/gemm_matrix_add/generic/neon/impl.h" + #include <arm_neon.h> namespace arm_compute @@ -44,33 +45,35 @@ void matrix_addition_f32(const ITensor *src, ITensor *dst, const Window &window, Iterator in(src, win); Iterator out(dst, win); - execute_window_loop(win, [&](const Coordinates &) - { - const auto in_ptr = reinterpret_cast<const float *>(in.ptr()); - const auto out_ptr = reinterpret_cast<float *>(out.ptr()); - - int x = window_start_x; - for(; x < (window_end_x - window_step_x); x += window_step_x) + execute_window_loop( + win, + [&](const Coordinates &) { - float32x4x4_t alpha_ab = vld4q_f32(out_ptr + x); - const float32x4x4_t c = vld4q_f32(in_ptr + x); + const auto in_ptr = reinterpret_cast<const float *>(in.ptr()); + const auto out_ptr = reinterpret_cast<float *>(out.ptr()); - // Multiply matrix C by its weight and accumulate - alpha_ab.val[0] = vmlaq_f32(alpha_ab.val[0], c.val[0], beta_f32); - alpha_ab.val[1] = vmlaq_f32(alpha_ab.val[1], c.val[1], beta_f32); - alpha_ab.val[2] = vmlaq_f32(alpha_ab.val[2], c.val[2], beta_f32); - alpha_ab.val[3] = vmlaq_f32(alpha_ab.val[3], c.val[3], beta_f32); + int x = window_start_x; + for (; x < (window_end_x - window_step_x); x += window_step_x) + { + float32x4x4_t alpha_ab = vld4q_f32(out_ptr + x); + const float32x4x4_t c = vld4q_f32(in_ptr + x); - vst4q_f32(out_ptr + x, alpha_ab); - } + // Multiply matrix C by its weight and accumulate + alpha_ab.val[0] = vmlaq_f32(alpha_ab.val[0], c.val[0], beta_f32); + alpha_ab.val[1] = vmlaq_f32(alpha_ab.val[1], c.val[1], beta_f32); + alpha_ab.val[2] = vmlaq_f32(alpha_ab.val[2], c.val[2], beta_f32); + alpha_ab.val[3] = vmlaq_f32(alpha_ab.val[3], c.val[3], beta_f32); - // Left-over loop - for(; x < window_end_x; ++x) - { - *(out_ptr + x) += *(in_ptr + x) * beta; - } - }, - in, out); + vst4q_f32(out_ptr + x, alpha_ab); + } + + // Left-over loop + for (; x < window_end_x; ++x) + { + *(out_ptr + x) += *(in_ptr + x) * beta; + } + }, + in, out); } } // namespace cpu } // namespace arm_compute |