From 7bb56c6337997281df10fa28ad7924c921b920eb Mon Sep 17 00:00:00 2001 From: Manuel Bottini Date: Wed, 26 Jun 2019 15:17:09 +0100 Subject: COMPMID-2409: Add QSYMM16 support for PixelWiseMultiplication for NEON Change-Id: Idfd3b45857201d5143242f9517d3353150b2c923 Signed-off-by: Manuel Bottini Reviewed-on: https://review.mlplatform.org/c/1422 Reviewed-by: Pablo Marquez Comments-Addressed: Arm Jenkins Tested-by: Arm Jenkins --- .../kernels/NEPixelWiseMultiplicationKernel.cpp | 76 ++++++++++++++++++---- 1 file changed, 63 insertions(+), 13 deletions(-) (limited to 'src/core/NEON/kernels/NEPixelWiseMultiplicationKernel.cpp') diff --git a/src/core/NEON/kernels/NEPixelWiseMultiplicationKernel.cpp b/src/core/NEON/kernels/NEPixelWiseMultiplicationKernel.cpp index c313b23ad3..6aaac818e9 100644 --- a/src/core/NEON/kernels/NEPixelWiseMultiplicationKernel.cpp +++ b/src/core/NEON/kernels/NEPixelWiseMultiplicationKernel.cpp @@ -30,6 +30,7 @@ #include "arm_compute/core/ITensor.h" #include "arm_compute/core/NEON/NEAsymm.h" #include "arm_compute/core/NEON/NEFixedPoint.h" +#include "arm_compute/core/NEON/NESymm.h" #include "arm_compute/core/NEON/wrapper/wrapper.h" #include "arm_compute/core/TensorInfo.h" #include "arm_compute/core/Types.h" @@ -63,21 +64,30 @@ inline Status validate_arguments(const ITensorInfo *input1, const ITensorInfo *i ARM_COMPUTE_UNUSED(rounding_policy); ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED(input1); - ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input1, 1, DataType::U8, DataType::QASYMM8, DataType::S16, DataType::F16, DataType::F32); - ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input2, 1, DataType::U8, DataType::QASYMM8, DataType::S16, DataType::F16, DataType::F32); - ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::U8, DataType::QASYMM8, DataType::S16, DataType::F16, DataType::F32); + ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input1, 1, DataType::U8, DataType::QASYMM8, DataType::S16, DataType::QSYMM16, DataType::F16, DataType::F32); + ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input2, 1, DataType::U8, DataType::QASYMM8, DataType::S16, DataType::QSYMM16, DataType::F16, DataType::F32); + ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::U8, DataType::QASYMM8, DataType::S16, DataType::QSYMM16, DataType::F16, DataType::F32); ARM_COMPUTE_RETURN_ERROR_ON_MSG(output->data_type() == DataType::U8 && (input1->data_type() != DataType::U8 || input2->data_type() != DataType::U8), "Output can only be U8 if both inputs are U8"); ARM_COMPUTE_RETURN_ERROR_ON_MSG(input1->data_type() == DataType::QASYMM8 && input2->data_type() != DataType::QASYMM8, - "Input2 must be QASYMM8 if both input1 is QASYMM8"); + "Input2 must be QASYMM8 if input1 is QASYMM8"); - ARM_COMPUTE_RETURN_ERROR_ON_MSG(input1->data_type() == DataType::QASYMM8 && input2->data_type() == DataType::QASYMM8 && overflow_policy == ConvertPolicy::WRAP, - "ConvertPolicy cannot be WRAP if datatype is QASYMM8"); + ARM_COMPUTE_RETURN_ERROR_ON_MSG(input1->data_type() != DataType::QASYMM8 && input2->data_type() == DataType::QASYMM8, + "Input1 must be QASYMM8 if input2 is QASYMM8"); + + ARM_COMPUTE_RETURN_ERROR_ON_MSG(input1->data_type() == DataType::QSYMM16 && input2->data_type() != DataType::QSYMM16, + "Input2 must be QSYMM16 if input1 is QSYMM16"); + + ARM_COMPUTE_RETURN_ERROR_ON_MSG(input1->data_type() != DataType::QSYMM16 && input2->data_type() == DataType::QSYMM16, + "Input1 must be QSYMM16 if input2 is QSYMM16"); + + ARM_COMPUTE_RETURN_ERROR_ON_MSG(is_data_type_quantized(input1->data_type()) && overflow_policy == ConvertPolicy::WRAP, + "ConvertPolicy cannot be WRAP if datatype is quantized"); if(output->total_size() > 0) { - if(output->data_type() == DataType::QASYMM8) + if(is_data_type_quantized(output->data_type())) { ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input1, input2, output); } @@ -128,6 +138,14 @@ inline std::pair validate_and_configure_window(ITensorInfo *inpu { set_format_if_unknown(*output, Format::F16); } + else if(input1->data_type() == DataType::QASYMM8) + { + set_data_type_if_unknown(*output, DataType::QASYMM8); + } + else if(input1->data_type() == DataType::QSYMM16) + { + set_data_type_if_unknown(*output, DataType::QSYMM16); + } } // Configure kernel window @@ -201,6 +219,34 @@ void mul_saturate_QASYMM8_QASYMM8_QASYMM8_n(const void *__restrict input1_ptr, c vst1q_u8(output, result); } +void mul_saturate_QSYMM16_QSYMM16_QSYMM16_n(const void *__restrict input1_ptr, const void *__restrict input2_ptr, void *__restrict output_ptr, float scale, + const UniformQuantizationInfo &input1_qua_info, const UniformQuantizationInfo &input2_qua_info, const UniformQuantizationInfo &output_qua_info) +{ + const auto input1 = static_cast(input1_ptr); + const auto input2 = static_cast(input2_ptr); + const auto output = static_cast(output_ptr); + + const qsymm16x8x2_t input1_q = vld2q_s16(input1); + const qsymm16x8x2_t input2_q = vld2q_s16(input2); + + // Dequantitize inputs + const float32x4x4_t in1_f32x4x4 = vdequantize(input1_q, input1_qua_info); + const float32x4x4_t in2_f32x4x4 = vdequantize(input2_q, input2_qua_info); + + const UniformQuantizationInfo tmp_qua_info = { output_qua_info.scale / scale, output_qua_info.offset }; + + const float32x4x4_t out_f32x4x4 = + { + vmulq_f32(in1_f32x4x4.val[0], in2_f32x4x4.val[0]), + vmulq_f32(in1_f32x4x4.val[1], in2_f32x4x4.val[1]), + vmulq_f32(in1_f32x4x4.val[2], in2_f32x4x4.val[2]), + vmulq_f32(in1_f32x4x4.val[3], in2_f32x4x4.val[3]), + }; + + const qsymm16x8x2_t result = vquantize_qsymm16(out_f32x4x4, tmp_qua_info); + vst2q_s16(output, result); +} + template void mul_U8_U8_U8_n(const void *__restrict input1_ptr, const void *__restrict input2_ptr, void *__restrict output_ptr, int n) { @@ -488,7 +534,7 @@ void mul_U8_S16_S16_n(const void *__restrict input1_ptr, const void *__restrict } // namespace NEPixelWiseMultiplicationKernel::NEPixelWiseMultiplicationKernel() - : _func_float(nullptr), _func_int(nullptr), _func_qasymm8(nullptr), _input1(nullptr), _input2(nullptr), _output(nullptr), _scale{ 0 }, _scale_exponent{ 0 } + : _func_float(nullptr), _func_int(nullptr), _func_quantized(nullptr), _input1(nullptr), _input2(nullptr), _output(nullptr), _scale{ 0 }, _scale_exponent{ 0 } { } @@ -508,7 +554,7 @@ void NEPixelWiseMultiplicationKernel::configure(const ITensor *input1, const ITe _output = output; _scale = scale; _scale_exponent = 0; - _func_qasymm8 = nullptr; + _func_quantized = nullptr; _func_int = nullptr; _func_float = nullptr; @@ -536,7 +582,11 @@ void NEPixelWiseMultiplicationKernel::configure(const ITensor *input1, const ITe if(dt_input1 == DataType::QASYMM8 && dt_input2 == DataType::QASYMM8) { - _func_qasymm8 = &mul_saturate_QASYMM8_QASYMM8_QASYMM8_n; + _func_quantized = &mul_saturate_QASYMM8_QASYMM8_QASYMM8_n; + } + else if(dt_input1 == DataType::QSYMM16 && dt_input2 == DataType::QSYMM16) + { + _func_quantized = &mul_saturate_QSYMM16_QSYMM16_QSYMM16_n; } else if(DataType::U8 == dt_input1 && DataType::U8 == dt_input2 && DataType::U8 == dt_output) { @@ -655,12 +705,12 @@ void NEPixelWiseMultiplicationKernel::run(const Window &window, const ThreadInfo Iterator input2(_input2, slice_input2); Iterator output(_output, slice); - if(_func_qasymm8 != nullptr) + if(is_data_type_quantized(_input1->info()->data_type())) { execute_window_loop(collapsed, [&](const Coordinates &) { - (*_func_qasymm8)(input1.ptr(), input2.ptr(), output.ptr(), _scale, - _input1->info()->quantization_info().uniform(), _input2->info()->quantization_info().uniform(), _output->info()->quantization_info().uniform()); + (*_func_quantized)(input1.ptr(), input2.ptr(), output.ptr(), _scale, + _input1->info()->quantization_info().uniform(), _input2->info()->quantization_info().uniform(), _output->info()->quantization_info().uniform()); collapsed.slide_window_slice_3D(slice_input1); collapsed.slide_window_slice_3D(slice_input2); }, -- cgit v1.2.1