diff options
Diffstat (limited to 'src/core/NEON/kernels/NEGEMMMatrixAdditionKernel.cpp')
-rw-r--r-- | src/core/NEON/kernels/NEGEMMMatrixAdditionKernel.cpp | 47 |
1 files changed, 6 insertions, 41 deletions
diff --git a/src/core/NEON/kernels/NEGEMMMatrixAdditionKernel.cpp b/src/core/NEON/kernels/NEGEMMMatrixAdditionKernel.cpp index 71dd4c7aa1..7d659ab2e6 100644 --- a/src/core/NEON/kernels/NEGEMMMatrixAdditionKernel.cpp +++ b/src/core/NEON/kernels/NEGEMMMatrixAdditionKernel.cpp @@ -52,25 +52,8 @@ void matrix_addition_f32(const ITensor *input, ITensor *output, const Window &wi const auto in_ptr = reinterpret_cast<const float *>(in.ptr()); const auto out_ptr = reinterpret_cast<float *>(out.ptr()); - float32x4x4_t alpha_ab = - { - { - vld1q_f32(out_ptr + 0), - vld1q_f32(out_ptr + 4), - vld1q_f32(out_ptr + 8), - vld1q_f32(out_ptr + 12) - } - }; - - const float32x4x4_t c = - { - { - vld1q_f32(in_ptr + 0), - vld1q_f32(in_ptr + 4), - vld1q_f32(in_ptr + 8), - vld1q_f32(in_ptr + 12) - } - }; + float32x4x4_t alpha_ab = vld4q_f32(out_ptr); + const float32x4x4_t c = vld4q_f32(in_ptr); // Multiply matrix C by its weight and accumulate alpha_ab.val[0] = vmlaq_f32(alpha_ab.val[0], c.val[0], beta_f32); @@ -78,10 +61,7 @@ void matrix_addition_f32(const ITensor *input, ITensor *output, const Window &wi alpha_ab.val[2] = vmlaq_f32(alpha_ab.val[2], c.val[2], beta_f32); alpha_ab.val[3] = vmlaq_f32(alpha_ab.val[3], c.val[3], beta_f32); - vst1q_f32(out_ptr + 0, alpha_ab.val[0]); - vst1q_f32(out_ptr + 4, alpha_ab.val[1]); - vst1q_f32(out_ptr + 8, alpha_ab.val[2]); - vst1q_f32(out_ptr + 12, alpha_ab.val[3]); + vst4q_f32(out_ptr, alpha_ab); }, in, out); } @@ -99,28 +79,13 @@ void matrix_addition_f16(const ITensor *input, ITensor *output, const Window &wi const auto in_ptr = reinterpret_cast<const float16_t *>(in.ptr()); const auto out_ptr = reinterpret_cast<float16_t *>(out.ptr()); - float16x8x2_t alpha_ab = - { - { - vld1q_f16(out_ptr + 0), - vld1q_f16(out_ptr + 8) - } - }; - - float16x8x2_t c = - { - { - vld1q_f16(in_ptr + 0), - vld1q_f16(in_ptr + 8) - } - }; - + float16x8x2_t alpha_ab = vld2q_f16(out_ptr); + const float16x8x2_t c = vld2q_f16(in_ptr); // Multiply matrix C by its weight and accumulate alpha_ab.val[0] = vaddq_f16(alpha_ab.val[0], vmulq_f16(c.val[0], beta_f16)); alpha_ab.val[1] = vaddq_f16(alpha_ab.val[1], vmulq_f16(c.val[1], beta_f16)); - vst1q_f16(out_ptr + 0, alpha_ab.val[0]); - vst1q_f16(out_ptr + 8, alpha_ab.val[1]); + vst2q_f16(out_ptr + 0, alpha_ab); }, in, out); } |