aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/core/NEON/kernels/NEPixelWiseMultiplicationKernel.cpp38
1 files changed, 30 insertions, 8 deletions
diff --git a/src/core/NEON/kernels/NEPixelWiseMultiplicationKernel.cpp b/src/core/NEON/kernels/NEPixelWiseMultiplicationKernel.cpp
index 8d17651f37..f646ea5db7 100644
--- a/src/core/NEON/kernels/NEPixelWiseMultiplicationKernel.cpp
+++ b/src/core/NEON/kernels/NEPixelWiseMultiplicationKernel.cpp
@@ -1035,6 +1035,8 @@ void c_mul_F32_F32_F32_n(const ITensor *in1, const ITensor *in2, ITensor *out, c
const auto window_end_x = static_cast<int>(window.x().end());
const bool is_broadcast_across_x = (input1_win.x().step() == 0) || (input2_win.x().step() == 0);
+ using ExactTagType = typename wrapper::traits::neon_vector<float, 2>::tag_type;
+
if(is_broadcast_across_x)
{
const bool is_broadcast_input_2 = input2_win.x().step() == 0;
@@ -1057,16 +1059,38 @@ void c_mul_F32_F32_F32_n(const ITensor *in1, const ITensor *in2, ITensor *out, c
const float broadcast_value = *reinterpret_cast<const float *>(broadcast_input.ptr());
+ // Compute window_step_x elements per iteration
int x = window_start_x;
+ for(; x <= (window_end_x - window_step_x); x += window_step_x)
+ {
+ const auto a = wrapper::vloadq(non_broadcast_input_ptr + 2 * x);
+ float32x4_t b = vdupq_n_f32(broadcast_value);
+
+ const float32x4_t mask = { -1.0f, 1.0f, -1.0f, 1.0f };
+ const float32x2_t tmp00 = wrapper::vdup_n(wrapper::vgetlane(a, 0), ExactTagType{});
+ const float32x2_t tmp01 = wrapper::vdup_n(wrapper::vgetlane(a, 1), ExactTagType{});
+ const float32x2_t tmp10 = wrapper::vdup_n(wrapper::vgetlane(a, 2), ExactTagType{});
+ const float32x2_t tmp11 = wrapper::vdup_n(wrapper::vgetlane(a, 3), ExactTagType{});
+
+ const float32x4_t tmp0 = wrapper::vcombine(tmp00, tmp10);
+ const float32x4_t tmp1 = wrapper::vcombine(tmp01, tmp11);
+
+ float32x4_t res = wrapper::vmul(tmp0, b);
+ b = wrapper::vmul(b, mask);
+
+ res = wrapper::vmla(res, tmp1, b);
+ wrapper::vstore(output_ptr + 2 * x, res);
+ }
+
// Compute left-over elements
for(; x < window_end_x; ++x)
{
- const auto broadcast_value0 = *(non_broadcast_input_ptr + 2 * x);
- const auto broadcast_value1 = *(non_broadcast_input_ptr + 2 * x + 1);
- auto res1 = broadcast_value * (broadcast_value0 - broadcast_value1);
- auto res2 = broadcast_value * (broadcast_value1 + broadcast_value0);
- *(output_ptr + 2 * x) = res1;
- *(output_ptr + 2 * x + 1) = res2;
+ const auto non_broadcast_value0 = *(non_broadcast_input_ptr + 2 * x);
+ const auto non_broadcast_value1 = *(non_broadcast_input_ptr + 2 * x + 1);
+ auto res1 = broadcast_value * (non_broadcast_value0 - non_broadcast_value1);
+ auto res2 = broadcast_value * (non_broadcast_value1 + non_broadcast_value0);
+ *(output_ptr + 2 * x) = res1;
+ *(output_ptr + 2 * x + 1) = res2;
}
},
broadcast_input, non_broadcast_input, output);
@@ -1087,8 +1111,6 @@ void c_mul_F32_F32_F32_n(const ITensor *in1, const ITensor *in2, ITensor *out, c
const auto input2_ptr = reinterpret_cast<const float *>(input2.ptr());
const auto output_ptr = reinterpret_cast<float *>(output.ptr());
- using ExactTagType = typename wrapper::traits::neon_vector<float, 2>::tag_type;
-
// Compute window_step_x elements per iteration
int x = window_start_x;
for(; x <= (window_end_x - window_step_x); x += window_step_x)