diff options
author | Sheri Zhang <sheri.zhang@arm.com> | 2020-11-09 15:12:32 +0000 |
---|---|---|
committer | Sheri Zhang <sheri.zhang@arm.com> | 2020-11-09 20:41:48 +0000 |
commit | 2cb05d9ee91880179ad2537cbf66229c7c2a2356 (patch) | |
tree | f0fb43f4f506a0f7a56730eeecb7c98b4c47e41d | |
parent | 3327def321af49040ae1fbd6026234a0980e3289 (diff) | |
download | ComputeLibrary-2cb05d9ee91880179ad2537cbf66229c7c2a2356.tar.gz |
COMPMID-3852: Fix the performance regression in complex multiplication introduced by removing padding
Signed-off-by: Sheri Zhang <sheri.zhang@arm.com>
Change-Id: I2605baba63c9cca0370328860313b8ec09e04fb6
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/4355
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Michele Di Giorgio <michele.digiorgio@arm.com>
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
-rw-r--r-- | src/core/NEON/kernels/NEPixelWiseMultiplicationKernel.cpp | 38 |
1 file changed, 30 insertions, 8 deletions
diff --git a/src/core/NEON/kernels/NEPixelWiseMultiplicationKernel.cpp b/src/core/NEON/kernels/NEPixelWiseMultiplicationKernel.cpp index 8d17651f37..f646ea5db7 100644 --- a/src/core/NEON/kernels/NEPixelWiseMultiplicationKernel.cpp +++ b/src/core/NEON/kernels/NEPixelWiseMultiplicationKernel.cpp @@ -1035,6 +1035,8 @@ void c_mul_F32_F32_F32_n(const ITensor *in1, const ITensor *in2, ITensor *out, c const auto window_end_x = static_cast<int>(window.x().end()); const bool is_broadcast_across_x = (input1_win.x().step() == 0) || (input2_win.x().step() == 0); + using ExactTagType = typename wrapper::traits::neon_vector<float, 2>::tag_type; + if(is_broadcast_across_x) { const bool is_broadcast_input_2 = input2_win.x().step() == 0; @@ -1057,16 +1059,38 @@ void c_mul_F32_F32_F32_n(const ITensor *in1, const ITensor *in2, ITensor *out, c const float broadcast_value = *reinterpret_cast<const float *>(broadcast_input.ptr()); + // Compute window_step_x elements per iteration int x = window_start_x; + for(; x <= (window_end_x - window_step_x); x += window_step_x) + { + const auto a = wrapper::vloadq(non_broadcast_input_ptr + 2 * x); + float32x4_t b = vdupq_n_f32(broadcast_value); + + const float32x4_t mask = { -1.0f, 1.0f, -1.0f, 1.0f }; + const float32x2_t tmp00 = wrapper::vdup_n(wrapper::vgetlane(a, 0), ExactTagType{}); + const float32x2_t tmp01 = wrapper::vdup_n(wrapper::vgetlane(a, 1), ExactTagType{}); + const float32x2_t tmp10 = wrapper::vdup_n(wrapper::vgetlane(a, 2), ExactTagType{}); + const float32x2_t tmp11 = wrapper::vdup_n(wrapper::vgetlane(a, 3), ExactTagType{}); + + const float32x4_t tmp0 = wrapper::vcombine(tmp00, tmp10); + const float32x4_t tmp1 = wrapper::vcombine(tmp01, tmp11); + + float32x4_t res = wrapper::vmul(tmp0, b); + b = wrapper::vmul(b, mask); + + res = wrapper::vmla(res, tmp1, b); + wrapper::vstore(output_ptr + 2 * x, res); + } + // Compute left-over elements for(; x < window_end_x; ++x) { - const auto broadcast_value0 = *(non_broadcast_input_ptr + 2 * x); 
- const auto broadcast_value1 = *(non_broadcast_input_ptr + 2 * x + 1); - auto res1 = broadcast_value * (broadcast_value0 - broadcast_value1); - auto res2 = broadcast_value * (broadcast_value1 + broadcast_value0); - *(output_ptr + 2 * x) = res1; - *(output_ptr + 2 * x + 1) = res2; + const auto non_broadcast_value0 = *(non_broadcast_input_ptr + 2 * x); + const auto non_broadcast_value1 = *(non_broadcast_input_ptr + 2 * x + 1); + auto res1 = broadcast_value * (non_broadcast_value0 - non_broadcast_value1); + auto res2 = broadcast_value * (non_broadcast_value1 + non_broadcast_value0); + *(output_ptr + 2 * x) = res1; + *(output_ptr + 2 * x + 1) = res2; } }, broadcast_input, non_broadcast_input, output); @@ -1087,8 +1111,6 @@ void c_mul_F32_F32_F32_n(const ITensor *in1, const ITensor *in2, ITensor *out, c const auto input2_ptr = reinterpret_cast<const float *>(input2.ptr()); const auto output_ptr = reinterpret_cast<float *>(output.ptr()); - using ExactTagType = typename wrapper::traits::neon_vector<float, 2>::tag_type; - // Compute window_step_x elements per iteration int x = window_start_x; for(; x <= (window_end_x - window_step_x); x += window_step_x) |