From c727d5261f158c10f6c6dbd926b76c2b96e0c2c4 Mon Sep 17 00:00:00 2001 From: Pablo Marquez Tello Date: Wed, 27 Jan 2021 14:16:13 +0000 Subject: Add broadcasting support in NEPixelWiseMultiplicationKernel for FP16 * FP16 kernel missed the broadcast path * Resolves: COMPMID-4113 Change-Id: I8fd79030f2ae9c268dfeed672a57c6d0f64f58f4 Signed-off-by: Pablo Marquez Tello Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/4926 Tested-by: Arm Jenkins Reviewed-by: Sheri Zhang Comments-Addressed: Arm Jenkins --- .../kernels/NEPixelWiseMultiplicationKernel.cpp | 148 ++++++++++++++------- 1 file changed, 101 insertions(+), 47 deletions(-) diff --git a/src/core/NEON/kernels/NEPixelWiseMultiplicationKernel.cpp b/src/core/NEON/kernels/NEPixelWiseMultiplicationKernel.cpp index 39517f6ff6..6661326ea8 100644 --- a/src/core/NEON/kernels/NEPixelWiseMultiplicationKernel.cpp +++ b/src/core/NEON/kernels/NEPixelWiseMultiplicationKernel.cpp @@ -1157,68 +1157,122 @@ void c_mul_F32_F32_F32_n(const ITensor *in1, const ITensor *in2, ITensor *out, c void mul_F16_F16_F16(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window, float scale) { // Create input windows - Window win = window; Window input1_win = window.broadcast_if_dimension_le_one(in1->info()->tensor_shape()); Window input2_win = window.broadcast_if_dimension_le_one(in2->info()->tensor_shape()); // Clear X Dimension on execution window as we handle manually + Window win = window; win.set(Window::DimX, Window::Dimension(0, 1, 1)); - input1_win.set(Window::DimX, Window::Dimension(0, 1, 1)); - input2_win.set(Window::DimX, Window::Dimension(0, 1, 1)); - - Iterator input1(in1, input1_win); - Iterator input2(in2, input2_win); - Iterator output(out, win); - - const int window_step_x = 16; - const auto window_start_x = static_cast(window.x().start()); - const auto window_end_x = static_cast(window.x().end()); - - execute_window_loop(win, [&](const Coordinates &) + constexpr int window_step_x = 16; + const auto window_start_x = static_cast(window.x().start()); + const auto window_end_x = static_cast(window.x().end()); + const bool is_broadcast_across_x = in1->info()->tensor_shape().x() != in2->info()->tensor_shape().x(); + if(is_broadcast_across_x) { - const auto input1_ptr = reinterpret_cast(input1.ptr()); - const auto input2_ptr = reinterpret_cast(input2.ptr()); - const auto output_ptr = reinterpret_cast(output.ptr()); - - // Compute window_step_x elements per iteration - int x = window_start_x; - for(; x <= (window_end_x - window_step_x); x += window_step_x) + const bool is_broadcast_input_2 = input2_win.x().step() == 0; + Window broadcast_win = is_broadcast_input_2 ? input2_win : input1_win; + Window non_broadcast_win = !is_broadcast_input_2 ? input2_win : input1_win; + const ITensor *broadcast_tensor = is_broadcast_input_2 ? in2 : in1; + const ITensor *non_broadcast_tensor = !is_broadcast_input_2 ? in2 : in1; + // Clear X Dimension on execution window as we handle manually + non_broadcast_win.set(Window::DimX, Window::Dimension(0, 1, 1)); + Iterator broadcast_input(broadcast_tensor, broadcast_win); + Iterator non_broadcast_input(non_broadcast_tensor, non_broadcast_win); + Iterator output(out, win); + execute_window_loop(win, [&](const Coordinates &) { - const float16x8x2_t ta1 = + const auto non_broadcast_input_ptr = reinterpret_cast(non_broadcast_input.ptr()); + const auto output_ptr = reinterpret_cast(output.ptr()); + const auto broadcast_value = *reinterpret_cast(broadcast_input.ptr()); + const float16x8x2_t broadcast_value_vec = { { - vld1q_f16(input1_ptr + x), - vld1q_f16(input1_ptr + x + 8), + vdupq_n_f16(broadcast_value), + vdupq_n_f16(broadcast_value), } }; - const float16x8x2_t ta2 = + const auto scale_vec = vdupq_n_f16(scale); + // Compute window_step_x elements per iteration + int x = window_start_x; + for(; x <= (window_end_x - window_step_x); x += window_step_x) { + const float16x8x2_t non_broadcast_v = { - vld1q_f16(input2_ptr + x), - vld1q_f16(input2_ptr + x + 8), - } - }; - const float16x8_t scale_vec = vdupq_n_f16(scale); - const float16x8x2_t result = - { + { + vld1q_f16(non_broadcast_input_ptr + x), + vld1q_f16(non_broadcast_input_ptr + x + 8), + } + }; + const float16x8x2_t result = { - vmulq_f16(vmulq_f16(ta1.val[0], ta2.val[0]), scale_vec), - vmulq_f16(vmulq_f16(ta1.val[1], ta2.val[1]), scale_vec), - } - }; - vst1q_f16(output_ptr + x, result.val[0]); - vst1q_f16(output_ptr + x + 8, result.val[1]); - } - - // Compute left-over elements - for(; x < window_end_x; ++x) + { + vmulq_f16(vmulq_f16(broadcast_value_vec.val[0], non_broadcast_v.val[0]), scale_vec), + vmulq_f16(vmulq_f16(broadcast_value_vec.val[1], non_broadcast_v.val[1]), scale_vec), + } + }; + vst1q_f16(output_ptr + x, result.val[0]); + vst1q_f16(output_ptr + x + 8, result.val[1]); + } + // Compute left-over elements + for(; x < window_end_x; ++x) + { + const auto non_broadcast_v = *(non_broadcast_input_ptr + x); + *(output_ptr + x) = broadcast_value * non_broadcast_v * scale; + } + }, + broadcast_input, non_broadcast_input, output); + } + else + { + input1_win.set(Window::DimX, Window::Dimension(0, 1, 1)); + input2_win.set(Window::DimX, Window::Dimension(0, 1, 1)); + Iterator input1(in1, input1_win); + Iterator input2(in2, input2_win); + Iterator output(out, win); + execute_window_loop(win, [&](const Coordinates &) { - const auto ta1 = *(input1_ptr + x); - const auto ta2 = *(input2_ptr + x); - *(output_ptr + x) = ta1 * ta2 * scale; - } - }, - input1, input2, output); + const auto input1_ptr = reinterpret_cast(input1.ptr()); + const auto input2_ptr = reinterpret_cast(input2.ptr()); + const auto output_ptr = reinterpret_cast(output.ptr()); + // Compute window_step_x elements per iteration + int x = window_start_x; + for(; x <= (window_end_x - window_step_x); x += window_step_x) + { + const float16x8x2_t ta1 = + { + { + vld1q_f16(input1_ptr + x), + vld1q_f16(input1_ptr + x + 8), + } + }; + const float16x8x2_t ta2 = + { + { + vld1q_f16(input2_ptr + x), + vld1q_f16(input2_ptr + x + 8), + } + }; + const float16x8_t scale_vec = vdupq_n_f16(scale); + const float16x8x2_t result = + { + { + vmulq_f16(vmulq_f16(ta1.val[0], ta2.val[0]), scale_vec), + vmulq_f16(vmulq_f16(ta1.val[1], ta2.val[1]), scale_vec), + } + }; + vst1q_f16(output_ptr + x, result.val[0]); + vst1q_f16(output_ptr + x + 8, result.val[1]); + } + // Compute left-over elements + for(; x < window_end_x; ++x) + { + const auto ta1 = *(input1_ptr + x); + const auto ta2 = *(input2_ptr + x); + *(output_ptr + x) = ta1 * ta2 * scale; + } + }, + input1, input2, output); + } } #endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */ -- cgit v1.2.1