aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSheri Zhang <sheri.zhang@arm.com>2020-11-09 15:12:32 +0000
committerSheri Zhang <sheri.zhang@arm.com>2020-11-09 20:41:48 +0000
commit2cb05d9ee91880179ad2537cbf66229c7c2a2356 (patch)
treef0fb43f4f506a0f7a56730eeecb7c98b4c47e41d
parent3327def321af49040ae1fbd6026234a0980e3289 (diff)
downloadComputeLibrary-2cb05d9ee91880179ad2537cbf66229c7c2a2356.tar.gz
COMPMID-3852: Fix the performance regression in complex multiplication introduced by padding removal
Signed-off-by: Sheri Zhang <sheri.zhang@arm.com>
Change-Id: I2605baba63c9cca0370328860313b8ec09e04fb6
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/4355
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Michele Di Giorgio <michele.digiorgio@arm.com>
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
-rw-r--r--src/core/NEON/kernels/NEPixelWiseMultiplicationKernel.cpp38
1 files changed, 30 insertions, 8 deletions
diff --git a/src/core/NEON/kernels/NEPixelWiseMultiplicationKernel.cpp b/src/core/NEON/kernels/NEPixelWiseMultiplicationKernel.cpp
index 8d17651f37..f646ea5db7 100644
--- a/src/core/NEON/kernels/NEPixelWiseMultiplicationKernel.cpp
+++ b/src/core/NEON/kernels/NEPixelWiseMultiplicationKernel.cpp
@@ -1035,6 +1035,8 @@ void c_mul_F32_F32_F32_n(const ITensor *in1, const ITensor *in2, ITensor *out, c
const auto window_end_x = static_cast<int>(window.x().end());
const bool is_broadcast_across_x = (input1_win.x().step() == 0) || (input2_win.x().step() == 0);
+ using ExactTagType = typename wrapper::traits::neon_vector<float, 2>::tag_type;
+
if(is_broadcast_across_x)
{
const bool is_broadcast_input_2 = input2_win.x().step() == 0;
@@ -1057,16 +1059,38 @@ void c_mul_F32_F32_F32_n(const ITensor *in1, const ITensor *in2, ITensor *out, c
const float broadcast_value = *reinterpret_cast<const float *>(broadcast_input.ptr());
+ // Compute window_step_x elements per iteration
int x = window_start_x;
+ for(; x <= (window_end_x - window_step_x); x += window_step_x)
+ {
+ const auto a = wrapper::vloadq(non_broadcast_input_ptr + 2 * x);
+ float32x4_t b = vdupq_n_f32(broadcast_value);
+
+ const float32x4_t mask = { -1.0f, 1.0f, -1.0f, 1.0f };
+ const float32x2_t tmp00 = wrapper::vdup_n(wrapper::vgetlane(a, 0), ExactTagType{});
+ const float32x2_t tmp01 = wrapper::vdup_n(wrapper::vgetlane(a, 1), ExactTagType{});
+ const float32x2_t tmp10 = wrapper::vdup_n(wrapper::vgetlane(a, 2), ExactTagType{});
+ const float32x2_t tmp11 = wrapper::vdup_n(wrapper::vgetlane(a, 3), ExactTagType{});
+
+ const float32x4_t tmp0 = wrapper::vcombine(tmp00, tmp10);
+ const float32x4_t tmp1 = wrapper::vcombine(tmp01, tmp11);
+
+ float32x4_t res = wrapper::vmul(tmp0, b);
+ b = wrapper::vmul(b, mask);
+
+ res = wrapper::vmla(res, tmp1, b);
+ wrapper::vstore(output_ptr + 2 * x, res);
+ }
+
// Compute left-over elements
for(; x < window_end_x; ++x)
{
- const auto broadcast_value0 = *(non_broadcast_input_ptr + 2 * x);
- const auto broadcast_value1 = *(non_broadcast_input_ptr + 2 * x + 1);
- auto res1 = broadcast_value * (broadcast_value0 - broadcast_value1);
- auto res2 = broadcast_value * (broadcast_value1 + broadcast_value0);
- *(output_ptr + 2 * x) = res1;
- *(output_ptr + 2 * x + 1) = res2;
+ const auto non_broadcast_value0 = *(non_broadcast_input_ptr + 2 * x);
+ const auto non_broadcast_value1 = *(non_broadcast_input_ptr + 2 * x + 1);
+ auto res1 = broadcast_value * (non_broadcast_value0 - non_broadcast_value1);
+ auto res2 = broadcast_value * (non_broadcast_value1 + non_broadcast_value0);
+ *(output_ptr + 2 * x) = res1;
+ *(output_ptr + 2 * x + 1) = res2;
}
},
broadcast_input, non_broadcast_input, output);
@@ -1087,8 +1111,6 @@ void c_mul_F32_F32_F32_n(const ITensor *in1, const ITensor *in2, ITensor *out, c
const auto input2_ptr = reinterpret_cast<const float *>(input2.ptr());
const auto output_ptr = reinterpret_cast<float *>(output.ptr());
- using ExactTagType = typename wrapper::traits::neon_vector<float, 2>::tag_type;
-
// Compute window_step_x elements per iteration
int x = window_start_x;
for(; x <= (window_end_x - window_step_x); x += window_step_x)