From 40aad9bbbae5308d7302e61e1372328c9b5daf99 Mon Sep 17 00:00:00 2001 From: Michele Di Giorgio Date: Wed, 22 Jul 2020 15:17:43 +0100 Subject: COMPMID-3600: Fix requantization in NEPixelWiseMultiplicationKernel Quantization wasn't done correctly and since we have helpers for that, the code has been modified to use them. Change-Id: Ia16577cea57dcb1864d91a06ab6aebf8ead67de5 Signed-off-by: Michele Di Giorgio Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/3608 Reviewed-by: TeresaARM Reviewed-by: Michalis Spyrou Comments-Addressed: Arm Jenkins Tested-by: Arm Jenkins --- .../kernels/NEPixelWiseMultiplicationKernel.cpp | 47 ++++++---------------- 1 file changed, 12 insertions(+), 35 deletions(-) diff --git a/src/core/NEON/kernels/NEPixelWiseMultiplicationKernel.cpp b/src/core/NEON/kernels/NEPixelWiseMultiplicationKernel.cpp index f8875324de..b5b4f841b4 100644 --- a/src/core/NEON/kernels/NEPixelWiseMultiplicationKernel.cpp +++ b/src/core/NEON/kernels/NEPixelWiseMultiplicationKernel.cpp @@ -71,7 +71,7 @@ inline Status validate_arguments(const ITensorInfo *input1, const ITensorInfo *i ARM_COMPUTE_RETURN_ERROR_ON_MSG(output->data_type() == DataType::QASYMM8 && (input1->data_type() != DataType::QASYMM8 || input2->data_type() != DataType::QASYMM8), "Output can only be QASYMM8 if both inputs are QASYMM8"); ARM_COMPUTE_RETURN_ERROR_ON_MSG(output->data_type() == DataType::QASYMM8_SIGNED && (input1->data_type() != DataType::QASYMM8_SIGNED || input2->data_type() != DataType::QASYMM8_SIGNED), - "Output can only be QASYMM8 if both inputs are QASYMM8"); + "Output can only be QASYMM8_SIGNED if both inputs are QASYMM8_SIGNED"); ARM_COMPUTE_RETURN_ERROR_ON_MSG(output->data_type() == DataType::QSYMM16 && (input1->data_type() != DataType::QSYMM16 || input2->data_type() != DataType::QSYMM16), "Output can only be QSYMM16 if both inputs are QSYMM16"); ARM_COMPUTE_RETURN_ERROR_ON_MSG(output->data_type() == DataType::S32 && (input1->data_type() != DataType::QSYMM16 || 
input2->data_type() != DataType::QSYMM16), @@ -137,32 +137,6 @@ vquantize(float32x4x4_t val, const UniformQuantizationInfo &info) return vquantize(val, info); } -template <typename T> -inline typename std::enable_if<std::is_same<T, int8_t>::value, int8_t>::type -quantize(float val, const UniformQuantizationInfo &info) -{ - const int32_t tmp = static_cast<int32_t>(val / info.scale) + info.offset; - - T tmp_qua = static_cast<T>(tmp > SCHAR_MAX) ? SCHAR_MAX : ((tmp < SCHAR_MIN) ? SCHAR_MIN : tmp); - return tmp_qua; -} - -template <typename T> -inline typename std::enable_if<std::is_same<T, uint8_t>::value, uint8_t>::type -quantize(float val, const UniformQuantizationInfo &info) -{ - const int32_t tmp = static_cast<int32_t>(val / info.scale) + info.offset; - - T tmp_qua = static_cast<T>(tmp > UCHAR_MAX) ? UCHAR_MAX : ((tmp < 0) ? 0 : tmp); - return tmp_qua; -} - -template <typename T> -inline float dequantize(const T *input, const UniformQuantizationInfo &info) -{ - return static_cast<float>((*input) - info.offset) * info.scale; -} - template <typename T> void mul_saturate_quantized_8(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window, float scale) { @@ -236,12 +210,13 @@ void mul_saturate_quantized_8(const ITensor *in1, const ITensor *in2, ITensor *o for(; x < window_end_x; ++x) { // Dequantize inputs - float tmp_in1 = dequantize(non_broadcast_input_ptr + x, non_broadcast_qinfo); - float tmp_in2 = dequantize(&broadcast_value, broadcast_qinfo); - float tmp_f = tmp_in1 * tmp_in2; + const T in1 = *(non_broadcast_input_ptr + x); + const float tmp_in1 = Qasymm8QuantizationHelper<T>::dequantize(in1, non_broadcast_qinfo); + const float tmp_in2 = Qasymm8QuantizationHelper<T>::dequantize(broadcast_value, broadcast_qinfo); + const float tmp_f = tmp_in1 * tmp_in2; // Quantize output - const auto tmp_qua = quantize<T>(tmp_f, tmp_qua_info); + const auto tmp_qua = Qasymm8QuantizationHelper<T>::quantize(tmp_f, tmp_qua_info); *(output_ptr + x) = tmp_qua; } }, @@ -294,12 +269,14 @@ void mul_saturate_quantized_8(const ITensor *in1, const ITensor *in2, ITensor *o for(; x < window_end_x; ++x) { //
Dequantize inputs - float tmp_in1 = dequantize(input1_ptr + x, input1_qua_info); - float tmp_in2 = dequantize(input2_ptr + x, input2_qua_info); - float tmp_f = tmp_in1 * tmp_in2; + const T in1 = *(input1_ptr + x); + const T in2 = *(input2_ptr + x); + const float tmp_in1 = Qasymm8QuantizationHelper<T>::dequantize(in1, input1_qua_info); + const float tmp_in2 = Qasymm8QuantizationHelper<T>::dequantize(in2, input2_qua_info); + const float tmp_f = tmp_in1 * tmp_in2; // Quantize output - const auto tmp_qua = quantize<T>(tmp_f, tmp_qua_info); + const auto tmp_qua = Qasymm8QuantizationHelper<T>::quantize(tmp_f, tmp_qua_info); *(output_ptr + x) = tmp_qua; } }, -- cgit v1.2.1