From d6d1b3682a2cdd54bae5498635b108a4b19a045a Mon Sep 17 00:00:00 2001 From: SiCong Li Date: Thu, 24 Sep 2020 17:34:23 +0100 Subject: COMPMID-3784 Add broadcast support to S32 NEPixelwiseMultiplication Signed-off-by: SiCong Li Change-Id: Ifae31c74eb46c561225394a387fc15332423bfa9 Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/4030 Tested-by: Arm Jenkins Comments-Addressed: Arm Jenkins Reviewed-by: Michele Di Giorgio --- .../kernels/NEPixelWiseMultiplicationKernel.cpp | 168 +++++++++++++++------ 1 file changed, 122 insertions(+), 46 deletions(-) (limited to 'src/core/NEON/kernels/NEPixelWiseMultiplicationKernel.cpp') diff --git a/src/core/NEON/kernels/NEPixelWiseMultiplicationKernel.cpp b/src/core/NEON/kernels/NEPixelWiseMultiplicationKernel.cpp index 302ee7694f..84683ea69f 100644 --- a/src/core/NEON/kernels/NEPixelWiseMultiplicationKernel.cpp +++ b/src/core/NEON/kernels/NEPixelWiseMultiplicationKernel.cpp @@ -773,75 +773,151 @@ template void mul_S32_S32_S32(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window, int n) { // Create input windows - Window win = window; Window input1_win = window.broadcast_if_dimension_le_one(in1->info()->tensor_shape()); Window input2_win = window.broadcast_if_dimension_le_one(in2->info()->tensor_shape()); // Clear X Dimension on execution window as we handle manually + Window win = window; win.set(Window::DimX, Window::Dimension(0, 1, 1)); - input1_win.set(Window::DimX, Window::Dimension(0, 1, 1)); - input2_win.set(Window::DimX, Window::Dimension(0, 1, 1)); - - Iterator input1(in1, input1_win); - Iterator input2(in2, input2_win); - Iterator output(out, win); - const int window_step_x = 8; - const auto window_start_x = static_cast(window.x().start()); - const auto window_end_x = static_cast(window.x().end()); + const int window_step_x = 8; + const auto window_start_x = static_cast(window.x().start()); + const auto window_end_x = static_cast(window.x().end()); + const bool is_broadcast_across_x = (input1_win.x().step() == 0) || (input2_win.x().step() == 0); - execute_window_loop(win, [&](const Coordinates &) + if(is_broadcast_across_x) { - const auto input1_ptr = reinterpret_cast(input1.ptr()); - const auto input2_ptr = reinterpret_cast(input2.ptr()); - const auto output_ptr = reinterpret_cast(output.ptr()); + const bool is_broadcast_input_2 = input2_win.x().step() == 0; + Window broadcast_win = is_broadcast_input_2 ? input2_win : input1_win; + Window non_broadcast_win = !is_broadcast_input_2 ? input2_win : input1_win; + const ITensor *broadcast_tensor = is_broadcast_input_2 ? in2 : in1; + const ITensor *non_broadcast_tensor = !is_broadcast_input_2 ? in2 : in1; - // Compute window_step_x elements per iteration - int x = window_start_x; - for(; x <= (window_end_x - window_step_x); x += window_step_x) + // Clear X Dimension on execution window as we handle manually + non_broadcast_win.set(Window::DimX, Window::Dimension(0, 1, 1)); + + Iterator broadcast_input(broadcast_tensor, broadcast_win); + Iterator non_broadcast_input(non_broadcast_tensor, non_broadcast_win); + Iterator output(out, win); + + execute_window_loop(win, [&](const Coordinates &) { - const int32x4x2_t ta1 = + const auto non_broadcast_input_ptr = reinterpret_cast(non_broadcast_input.ptr()); + const auto output_ptr = reinterpret_cast(output.ptr()); + + const int32_t broadcast_value = *reinterpret_cast(broadcast_input.ptr()); + const auto broadcast_value_vec = vdupq_n_s32(broadcast_value); + + // Compute window_step_x elements per iteration + int x = window_start_x; + for(; x <= (window_end_x - window_step_x); x += window_step_x) { + const int32x4x2_t broadcast_v = { - vld1q_s32(input1_ptr + x), - vld1q_s32(input1_ptr + x + 4), - } - }; - const int32x4x2_t ta2 = + { + broadcast_value_vec, + broadcast_value_vec, + } + }; + const int32x4x2_t non_broadcast_v = + { + { + vld1q_s32(non_broadcast_input_ptr + x), + vld1q_s32(non_broadcast_input_ptr + x + 4), + } + }; + const int32x4x2_t result = mul_S32_S32_S32_n_k(broadcast_v, non_broadcast_v, n); + + vst1q_s32(output_ptr + x, result.val[0]); + vst1q_s32(output_ptr + x + 4, result.val[1]); + } + + // Compute left-over elements + for(; x < window_end_x; ++x) { + int64_t tmp = static_cast(broadcast_value) * static_cast(*(non_broadcast_input_ptr + x)); + + if(tmp >= 0) { - vld1q_s32(input2_ptr + x), - vld1q_s32(input2_ptr + x + 4), + tmp >>= n; } - }; - const int32x4x2_t result = mul_S32_S32_S32_n_k(ta1, ta2, n); + else + { + uint64_t mask = (1u << n) - 1; + tmp = (tmp + static_cast(mask)) >> n; + } + if(is_sat) + { + tmp = utility::clamp(tmp); + } + *(output_ptr + x) = static_cast(tmp); + } + }, + broadcast_input, non_broadcast_input, output); + } + else + { + // Clear X Dimension on execution window as we handle manually + input1_win.set(Window::DimX, Window::Dimension(0, 1, 1)); + input2_win.set(Window::DimX, Window::Dimension(0, 1, 1)); - vst1q_s32(output_ptr + x, result.val[0]); - vst1q_s32(output_ptr + x + 4, result.val[1]); - } + Iterator input1(in1, input1_win); + Iterator input2(in2, input2_win); + Iterator output(out, win); - // Compute left-over elements - for(; x < window_end_x; ++x) + execute_window_loop(win, [&](const Coordinates &) { - int64_t tmp = static_cast(*(input1_ptr + x)) * static_cast(*(input2_ptr + x)); + const auto input1_ptr = reinterpret_cast(input1.ptr()); + const auto input2_ptr = reinterpret_cast(input2.ptr()); + const auto output_ptr = reinterpret_cast(output.ptr()); - if(tmp >= 0) - { - tmp >>= n; - } - else + // Compute window_step_x elements per iteration + int x = window_start_x; + for(; x <= (window_end_x - window_step_x); x += window_step_x) { - uint64_t mask = (1u << n) - 1; - tmp = (tmp + static_cast(mask)) >> n; + const int32x4x2_t ta1 = + { + { + vld1q_s32(input1_ptr + x), + vld1q_s32(input1_ptr + x + 4), + } + }; + const int32x4x2_t ta2 = + { + { + vld1q_s32(input2_ptr + x), + vld1q_s32(input2_ptr + x + 4), + } + }; + const int32x4x2_t result = mul_S32_S32_S32_n_k(ta1, ta2, n); + + vst1q_s32(output_ptr + x, result.val[0]); + vst1q_s32(output_ptr + x + 4, result.val[1]); } - if(is_sat) + + // Compute left-over elements + for(; x < window_end_x; ++x) { - tmp = (tmp > INT_MAX) ? INT_MAX : ((tmp < INT_MIN) ? INT_MIN : tmp); + int64_t tmp = static_cast(*(input1_ptr + x)) * static_cast(*(input2_ptr + x)); + + if(tmp >= 0) + { + tmp >>= n; + } + else + { + uint64_t mask = (1u << n) - 1; + tmp = (tmp + static_cast(mask)) >> n; + } + if(is_sat) + { + tmp = utility::clamp(tmp); + } + *(output_ptr + x) = static_cast(tmp); } - *(output_ptr + x) = static_cast(tmp); - } - }, - input1, input2, output); + }, + input1, input2, output); + } } void mul_F32_F32_F32(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window, float scale) -- cgit v1.2.1