From 52ea9c24607648acce37cda960e4fbaa59c9a263 Mon Sep 17 00:00:00 2001 From: Pablo Tello Date: Tue, 10 Dec 2019 11:28:53 +0000 Subject: COMPMID-2811: QASYMM8_SIGNED support in NEPixelwiseMultiplication. Change-Id: I4e52bd55fc9804796f47fab04859961d846f4ceb Signed-off-by: Pablo Tello Reviewed-on: https://review.mlplatform.org/c/2459 Comments-Addressed: Arm Jenkins Reviewed-by: Michele Di Giorgio Tested-by: Arm Jenkins --- .../kernels/NEPixelWiseMultiplicationKernel.cpp | 60 ++++++++++++++++------ 1 file changed, 43 insertions(+), 17 deletions(-) (limited to 'src/core/NEON/kernels/NEPixelWiseMultiplicationKernel.cpp') diff --git a/src/core/NEON/kernels/NEPixelWiseMultiplicationKernel.cpp b/src/core/NEON/kernels/NEPixelWiseMultiplicationKernel.cpp index 4bd03e959e..7ec52f788b 100644 --- a/src/core/NEON/kernels/NEPixelWiseMultiplicationKernel.cpp +++ b/src/core/NEON/kernels/NEPixelWiseMultiplicationKernel.cpp @@ -64,26 +64,18 @@ inline Status validate_arguments(const ITensorInfo *input1, const ITensorInfo *i ARM_COMPUTE_UNUSED(rounding_policy); ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED(input1); - ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input1, 1, DataType::U8, DataType::QASYMM8, DataType::S16, DataType::QSYMM16, DataType::F16, DataType::F32); - ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input2, 1, DataType::U8, DataType::QASYMM8, DataType::S16, DataType::QSYMM16, DataType::F16, DataType::F32); - ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::U8, DataType::QASYMM8, DataType::S16, DataType::QSYMM16, DataType::F16, DataType::F32); + ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input1, 1, DataType::U8, DataType::QASYMM8, DataType::QASYMM8_SIGNED, DataType::S16, DataType::QSYMM16, DataType::F16, DataType::F32); + ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input2, 1, DataType::U8, DataType::QASYMM8, DataType::QASYMM8_SIGNED, DataType::S16, DataType::QSYMM16, DataType::F16, DataType::F32); + ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::U8, DataType::QASYMM8, DataType::QASYMM8_SIGNED, DataType::S16, DataType::QSYMM16, DataType::F16, DataType::F32); ARM_COMPUTE_RETURN_ERROR_ON_MSG(output->data_type() == DataType::U8 && (input1->data_type() != DataType::U8 || input2->data_type() != DataType::U8), "Output can only be U8 if both inputs are U8"); - ARM_COMPUTE_RETURN_ERROR_ON_MSG(input1->data_type() == DataType::QASYMM8 && input2->data_type() != DataType::QASYMM8, - "Input2 must be QASYMM8 if input1 is QASYMM8"); - - ARM_COMPUTE_RETURN_ERROR_ON_MSG(input1->data_type() != DataType::QASYMM8 && input2->data_type() == DataType::QASYMM8, - "Input1 must be QASYMM8 if input2 is QASYMM8"); - - ARM_COMPUTE_RETURN_ERROR_ON_MSG(input1->data_type() == DataType::QSYMM16 && input2->data_type() != DataType::QSYMM16, - "Input2 must be QSYMM16 if input1 is QSYMM16"); - - ARM_COMPUTE_RETURN_ERROR_ON_MSG(input1->data_type() != DataType::QSYMM16 && input2->data_type() == DataType::QSYMM16, - "Input1 must be QSYMM16 if input2 is QSYMM16"); - - ARM_COMPUTE_RETURN_ERROR_ON_MSG(is_data_type_quantized(input1->data_type()) && overflow_policy == ConvertPolicy::WRAP, - "ConvertPolicy cannot be WRAP if datatype is quantized"); + if(is_data_type_quantized(input1->data_type())|| + is_data_type_quantized(input2->data_type())) + { + ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input1, input2); + ARM_COMPUTE_RETURN_ERROR_ON_MSG(overflow_policy == ConvertPolicy::WRAP,"ConvertPolicy cannot be WRAP if datatype is quantized"); + } if(output->total_size() > 0) { @@ -142,6 +134,10 @@ inline std::pair validate_and_configure_window(ITensorInfo *inpu { set_data_type_if_unknown(*output, DataType::QASYMM8); } + else if(input1->data_type() == DataType::QASYMM8_SIGNED) + { + set_data_type_if_unknown(*output, DataType::QASYMM8_SIGNED); + } else if(input1->data_type() == DataType::QSYMM16) { set_data_type_if_unknown(*output, DataType::QSYMM16); @@ -238,6 +234,32 @@ inline void mul_saturate_QASYMM8_QASYMM8_QASYMM8_n_opt(const void *__restrict in vst1q_u8(output, vcombine_u8(pa, pb)); } +inline void mul_saturate_QASYMM8_SIGNED_QASYMM8_SIGNED_QASYMM8_SIGNED_n( + const void *__restrict input1_ptr, const void *__restrict input2_ptr, void *__restrict output_ptr, + float scale, const UniformQuantizationInfo &input1_qua_info, const UniformQuantizationInfo &input2_qua_info, + const UniformQuantizationInfo &output_qua_info) + +{ + const auto input1 = static_cast(input1_ptr); + const auto input2 = static_cast(input2_ptr); + const auto output = static_cast(output_ptr); + const qasymm8x16_signed_t input1_q = vld1q_s8(input1); + const qasymm8x16_signed_t input2_q = vld1q_s8(input2); + // Dequantitize inputs + const float32x4x4_t in1_f32x4x4 = vdequantize(input1_q, input1_qua_info); + const float32x4x4_t in2_f32x4x4 = vdequantize(input2_q, input2_qua_info); + const UniformQuantizationInfo tmp_qua_info = { output_qua_info.scale / scale, output_qua_info.offset }; + const float32x4x4_t out_f32x4x4 = + { + vmulq_f32(in1_f32x4x4.val[0], in2_f32x4x4.val[0]), + vmulq_f32(in1_f32x4x4.val[1], in2_f32x4x4.val[1]), + vmulq_f32(in1_f32x4x4.val[2], in2_f32x4x4.val[2]), + vmulq_f32(in1_f32x4x4.val[3], in2_f32x4x4.val[3]), + }; + const int8x16_t result = vquantize_signed(out_f32x4x4, tmp_qua_info); + vst1q_s8(output, result); +} + void mul_saturate_QSYMM16_QSYMM16_QSYMM16_n(const void *__restrict input1_ptr, const void *__restrict input2_ptr, void *__restrict output_ptr, float scale, const UniformQuantizationInfo &input1_qua_info, const UniformQuantizationInfo &input2_qua_info, const UniformQuantizationInfo &output_qua_info) { @@ -604,6 +626,10 @@ void NEPixelWiseMultiplicationKernel::configure(const ITensor *input1, const ITe { _run_optimized_qasymm8 = true; } + else if(dt_input1 == DataType::QASYMM8_SIGNED && dt_input2 == DataType::QASYMM8_SIGNED) + { + _func_quantized = &mul_saturate_QASYMM8_SIGNED_QASYMM8_SIGNED_QASYMM8_SIGNED_n; + } else if(dt_input1 == DataType::QSYMM16 && dt_input2 == DataType::QSYMM16) { _func_quantized = &mul_saturate_QSYMM16_QSYMM16_QSYMM16_n; -- cgit v1.2.1