aboutsummaryrefslogtreecommitdiff
path: root/src/core/NEON/kernels/NEGEMMMatrixAdditionKernel.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/core/NEON/kernels/NEGEMMMatrixAdditionKernel.cpp')
-rw-r--r--src/core/NEON/kernels/NEGEMMMatrixAdditionKernel.cpp47
1 files changed, 6 insertions, 41 deletions
diff --git a/src/core/NEON/kernels/NEGEMMMatrixAdditionKernel.cpp b/src/core/NEON/kernels/NEGEMMMatrixAdditionKernel.cpp
index 71dd4c7aa1..7d659ab2e6 100644
--- a/src/core/NEON/kernels/NEGEMMMatrixAdditionKernel.cpp
+++ b/src/core/NEON/kernels/NEGEMMMatrixAdditionKernel.cpp
@@ -52,25 +52,8 @@ void matrix_addition_f32(const ITensor *input, ITensor *output, const Window &wi
const auto in_ptr = reinterpret_cast<const float *>(in.ptr());
const auto out_ptr = reinterpret_cast<float *>(out.ptr());
- float32x4x4_t alpha_ab =
- {
- {
- vld1q_f32(out_ptr + 0),
- vld1q_f32(out_ptr + 4),
- vld1q_f32(out_ptr + 8),
- vld1q_f32(out_ptr + 12)
- }
- };
-
- const float32x4x4_t c =
- {
- {
- vld1q_f32(in_ptr + 0),
- vld1q_f32(in_ptr + 4),
- vld1q_f32(in_ptr + 8),
- vld1q_f32(in_ptr + 12)
- }
- };
+ float32x4x4_t alpha_ab = vld4q_f32(out_ptr);
+ const float32x4x4_t c = vld4q_f32(in_ptr);
// Multiply matrix C by its weight and accumulate
alpha_ab.val[0] = vmlaq_f32(alpha_ab.val[0], c.val[0], beta_f32);
@@ -78,10 +61,7 @@ void matrix_addition_f32(const ITensor *input, ITensor *output, const Window &wi
alpha_ab.val[2] = vmlaq_f32(alpha_ab.val[2], c.val[2], beta_f32);
alpha_ab.val[3] = vmlaq_f32(alpha_ab.val[3], c.val[3], beta_f32);
- vst1q_f32(out_ptr + 0, alpha_ab.val[0]);
- vst1q_f32(out_ptr + 4, alpha_ab.val[1]);
- vst1q_f32(out_ptr + 8, alpha_ab.val[2]);
- vst1q_f32(out_ptr + 12, alpha_ab.val[3]);
+ vst4q_f32(out_ptr, alpha_ab);
},
in, out);
}
@@ -99,28 +79,13 @@ void matrix_addition_f16(const ITensor *input, ITensor *output, const Window &wi
const auto in_ptr = reinterpret_cast<const float16_t *>(in.ptr());
const auto out_ptr = reinterpret_cast<float16_t *>(out.ptr());
- float16x8x2_t alpha_ab =
- {
- {
- vld1q_f16(out_ptr + 0),
- vld1q_f16(out_ptr + 8)
- }
- };
-
- float16x8x2_t c =
- {
- {
- vld1q_f16(in_ptr + 0),
- vld1q_f16(in_ptr + 8)
- }
- };
-
+ float16x8x2_t alpha_ab = vld2q_f16(out_ptr);
+ const float16x8x2_t c = vld2q_f16(in_ptr);
// Multiply matrix C by its weight and accumulate
alpha_ab.val[0] = vaddq_f16(alpha_ab.val[0], vmulq_f16(c.val[0], beta_f16));
alpha_ab.val[1] = vaddq_f16(alpha_ab.val[1], vmulq_f16(c.val[1], beta_f16));
- vst1q_f16(out_ptr + 0, alpha_ab.val[0]);
- vst1q_f16(out_ptr + 8, alpha_ab.val[1]);
+ vst2q_f16(out_ptr + 0, alpha_ab);
},
in, out);
}