From bb88f89b7a12e83eea2fc701f1f82aabf7dfcf7a Mon Sep 17 00:00:00 2001 From: SiCong Li Date: Fri, 28 Aug 2020 11:18:47 +0100 Subject: COMPMID-3581 Add S32 support to NEPixelWiseMultiplication * Add S32 support to NEPixelWiseMultiplication and NEPixelWiseMultiplicationKernel * Scale == 1/255 is not supported for S32, as on non-aarch64 the precision requirement is not met, and scale is a non-standard parameter anyway. * Fix the data types validation logics to also test for all invalid data type combinations. * Add validation tests for S32 NEON PixelWiseMultiplication * The wrap tolerance for ScaleOther (scale == 1/2^n) cases is set to 1 instead of 0 because the reference uses floating point division followed by rounding, which is isn't bit accurate. Change-Id: I28839afda7a4f98c985d1763620e08d98f740142 Signed-off-by: SiCong Li Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/3923 Tested-by: Arm Jenkins Reviewed-by: Michele Di Giorgio Reviewed-by: Georgios Pinitas Comments-Addressed: Arm Jenkins --- .../kernels/NEPixelWiseMultiplicationKernel.cpp | 168 +++++++++++++++++++-- 1 file changed, 154 insertions(+), 14 deletions(-) (limited to 'src/core/NEON/kernels/NEPixelWiseMultiplicationKernel.cpp') diff --git a/src/core/NEON/kernels/NEPixelWiseMultiplicationKernel.cpp b/src/core/NEON/kernels/NEPixelWiseMultiplicationKernel.cpp index 907a7f197b..302ee7694f 100644 --- a/src/core/NEON/kernels/NEPixelWiseMultiplicationKernel.cpp +++ b/src/core/NEON/kernels/NEPixelWiseMultiplicationKernel.cpp @@ -49,8 +49,10 @@ inline Status validate_arguments(const ITensorInfo *input1, const ITensorInfo *i ARM_COMPUTE_UNUSED(rounding_policy); ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED(input1); - ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input1, 1, DataType::U8, DataType::QASYMM8, DataType::QASYMM8_SIGNED, DataType::S16, DataType::QSYMM16, DataType::F16, DataType::F32); - ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input2, 1, DataType::U8, DataType::QASYMM8, DataType::QASYMM8_SIGNED, DataType::S16, DataType::QSYMM16, DataType::F16, DataType::F32); + ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input1, 1, DataType::U8, DataType::QASYMM8, DataType::QASYMM8_SIGNED, DataType::S16, DataType::S32, DataType::QSYMM16, DataType::F16, + DataType::F32); + ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input2, 1, DataType::U8, DataType::QASYMM8, DataType::QASYMM8_SIGNED, DataType::S16, DataType::S32, DataType::QSYMM16, DataType::F16, + DataType::F32); ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::U8, DataType::QASYMM8, DataType::QASYMM8_SIGNED, DataType::S16, DataType::QSYMM16, DataType::S32, DataType::F16, DataType::F32); @@ -65,23 +67,24 @@ inline Status validate_arguments(const ITensorInfo *input1, const ITensorInfo *i const TensorShape &out_shape = TensorShape::broadcast_shape(input1->tensor_shape(), input2->tensor_shape()); ARM_COMPUTE_RETURN_ERROR_ON_MSG(detail::have_different_dimensions(out_shape, output->tensor_shape(), 0), "Wrong shape for output"); ARM_COMPUTE_RETURN_ERROR_ON_MSG(out_shape.total_size() == 0, "Inputs are not broadcast compatible"); - - ARM_COMPUTE_RETURN_ERROR_ON_MSG(output->data_type() == DataType::U8 && (input1->data_type() != DataType::U8 || input2->data_type() != DataType::U8), - "Output can only be U8 if both inputs are U8"); - ARM_COMPUTE_RETURN_ERROR_ON_MSG(output->data_type() == DataType::QASYMM8 && (input1->data_type() != DataType::QASYMM8 || input2->data_type() != DataType::QASYMM8), - "Output can only be QASYMM8 if both inputs are QASYMM8"); - ARM_COMPUTE_RETURN_ERROR_ON_MSG(output->data_type() == DataType::QASYMM8_SIGNED && (input1->data_type() != DataType::QASYMM8_SIGNED || input2->data_type() != DataType::QASYMM8_SIGNED), - "Output can only be QASYMM8_SIGNED if both inputs are QASYMM8_SIGNED"); - ARM_COMPUTE_RETURN_ERROR_ON_MSG(output->data_type() == DataType::QSYMM16 && (input1->data_type() != DataType::QSYMM16 || input2->data_type() != DataType::QSYMM16), - "Output can only be QSYMM16 if both inputs are QSYMM16"); - ARM_COMPUTE_RETURN_ERROR_ON_MSG(output->data_type() == DataType::S32 && (input1->data_type() != DataType::QSYMM16 || input2->data_type() != DataType::QSYMM16), - "Output can only be S32 if both inputs are QSYMM16"); - ARM_COMPUTE_RETURN_ERROR_ON_MSG(output->data_type() == DataType::S32 && scale != 1.f, "Unsupported scale for QSYMM16 inputs and S32 output"); + // clang-format off + ARM_COMPUTE_RETURN_ERROR_ON_MSG( + !(input1->data_type() == input2->data_type() && input2->data_type() == output->data_type()) && + !(input1->data_type() == DataType::U8 && input2->data_type() == DataType::U8 && output->data_type() == DataType::S16) && + !(input1->data_type() == DataType::U8 && input2->data_type() == DataType::S16 && output->data_type() == DataType::S16) && + !(input1->data_type() == DataType::S16 && input2->data_type() == DataType::U8 && output->data_type() == DataType::S16) && + !(input1->data_type() == DataType::S16 && input2->data_type() == DataType::U8 && output->data_type() == DataType::S16) && + !(input1->data_type() == DataType::QSYMM16 && input2->data_type() == DataType::QSYMM16 && output->data_type() == DataType::S32) + , "Invalid data type combination"); + // clang-format on + ARM_COMPUTE_RETURN_ERROR_ON_MSG(input1->data_type() == DataType::S16 && output->data_type() == DataType::S32 && scale != 1.f, "Unsupported scale for QSYMM16 inputs and S32 output"); } if(std::abs(scale - scale255_constant) < 0.00001f) { ARM_COMPUTE_RETURN_ERROR_ON(rounding_policy != RoundingPolicy::TO_NEAREST_UP && rounding_policy != RoundingPolicy::TO_NEAREST_EVEN); + ARM_COMPUTE_RETURN_ERROR_ON_MSG(input1->data_type() == DataType::S32 && input2->data_type() == DataType::S32 && output->data_type() == DataType::S32, + "Scale == 1/255 is not supported if input and output are of data type S32"); } else { @@ -710,6 +713,137 @@ void mul_S16_S16_S16(const ITensor *in1, const ITensor *in2, ITensor *out, const input1, input2, output); } +template +inline int32x4_t mul_S32_S32_S32_n_loop(const int32x4_t &input1, const int32x4_t &input2, int n) +{ + const int32x2_t input1_1 = vget_low_s32(input1); + const int32x2_t input2_1 = vget_low_s32(input2); + const int32x2_t input1_2 = vget_high_s32(input1); + const int32x2_t input2_2 = vget_high_s32(input2); + + int64x2_t tmp_1 = vmull_s32(input1_1, input2_1); + int64x2_t tmp_2 = vmull_s32(input1_2, input2_2); + + // Apply scaling, conversion and rounding (round to zero) + // Right shift amount + const int64x2_t vn = vdupq_n_s64(-n); + // Left shift amount + const int64x2_t vnl = vdupq_n_s64(n); + // Calculate conversion bit + const uint64x2_t tmp_1_u = vreinterpretq_u64_s64(tmp_1); + const uint64x2_t sign_1 = vshrq_n_u64(tmp_1_u, 63); + const int64x2_t sign_1_s = vreinterpretq_s64_u64(sign_1); + const int64x2_t convert_1 = vsubq_s64(vshlq_s64(sign_1_s, vnl), sign_1_s); + + const uint64x2_t tmp_2_u = vreinterpretq_u64_s64(tmp_2); + const uint64x2_t sign_2 = vshrq_n_u64(tmp_2_u, 63); + const int64x2_t sign_2_s = vreinterpretq_s64_u64(sign_2); + const int64x2_t convert_2 = vsubq_s64(vshlq_s64(sign_2_s, vnl), sign_2_s); + if(is_sat) + { + tmp_1 = vqshlq_s64(vaddq_s64(tmp_1, convert_1), vn); + tmp_2 = vqshlq_s64(vaddq_s64(tmp_2, convert_2), vn); + return vcombine_s32(vqmovn_s64(tmp_1), vqmovn_s64(tmp_2)); + } + else + { + tmp_1 = vshlq_s64(vaddq_s64(tmp_1, convert_1), vn); + tmp_2 = vshlq_s64(vaddq_s64(tmp_2, convert_2), vn); + return vcombine_s32(vmovn_s64(tmp_1), vmovn_s64(tmp_2)); + } +} + +template +inline int32x4x2_t mul_S32_S32_S32_n_k(const int32x4x2_t &input1, const int32x4x2_t &input2, int n) +{ + const int32x4x2_t result = + { + { + // First 4 elements + mul_S32_S32_S32_n_loop(input1.val[0], input2.val[0], n), + // Second 4 elements + mul_S32_S32_S32_n_loop(input1.val[1], input2.val[1], n) + } + }; + + return result; +} + +template +void mul_S32_S32_S32(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window, int n) +{ + // Create input windows + Window win = window; + Window input1_win = window.broadcast_if_dimension_le_one(in1->info()->tensor_shape()); + Window input2_win = window.broadcast_if_dimension_le_one(in2->info()->tensor_shape()); + + // Clear X Dimension on execution window as we handle manually + win.set(Window::DimX, Window::Dimension(0, 1, 1)); + input1_win.set(Window::DimX, Window::Dimension(0, 1, 1)); + input2_win.set(Window::DimX, Window::Dimension(0, 1, 1)); + + Iterator input1(in1, input1_win); + Iterator input2(in2, input2_win); + Iterator output(out, win); + + const int window_step_x = 8; + const auto window_start_x = static_cast(window.x().start()); + const auto window_end_x = static_cast(window.x().end()); + + execute_window_loop(win, [&](const Coordinates &) + { + const auto input1_ptr = reinterpret_cast(input1.ptr()); + const auto input2_ptr = reinterpret_cast(input2.ptr()); + const auto output_ptr = reinterpret_cast(output.ptr()); + + // Compute window_step_x elements per iteration + int x = window_start_x; + for(; x <= (window_end_x - window_step_x); x += window_step_x) + { + const int32x4x2_t ta1 = + { + { + vld1q_s32(input1_ptr + x), + vld1q_s32(input1_ptr + x + 4), + } + }; + const int32x4x2_t ta2 = + { + { + vld1q_s32(input2_ptr + x), + vld1q_s32(input2_ptr + x + 4), + } + }; + const int32x4x2_t result = mul_S32_S32_S32_n_k(ta1, ta2, n); + + vst1q_s32(output_ptr + x, result.val[0]); + vst1q_s32(output_ptr + x + 4, result.val[1]); + } + + // Compute left-over elements + for(; x < window_end_x; ++x) + { + int64_t tmp = static_cast(*(input1_ptr + x)) * static_cast(*(input2_ptr + x)); + + if(tmp >= 0) + { + tmp >>= n; + } + else + { + uint64_t mask = (1u << n) - 1; + tmp = (tmp + static_cast(mask)) >> n; + } + if(is_sat) + { + tmp = (tmp > INT_MAX) ? INT_MAX : ((tmp < INT_MIN) ? INT_MIN : tmp); + } + *(output_ptr + x) = static_cast(tmp); + } + }, + input1, input2, output); +} + void mul_F32_F32_F32(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window, float scale) { // Create input windows @@ -1200,6 +1334,12 @@ void NEPixelWiseMultiplicationKernel::configure(ITensorInfo *input1, ITensorInfo } } break; + case DataType::S32: + if(DataType::S32 == dt_input2 && DataType::S32 == dt_output) + { + _func_int = is_sat ? &mul_S32_S32_S32 : &mul_S32_S32_S32; + } + break; case DataType::U8: if(DataType::U8 == dt_input2 && DataType::U8 == dt_output) { -- cgit v1.2.1