From 08fe7a5b7e2c7c1a77968130e11267ef61490ac8 Mon Sep 17 00:00:00 2001 From: TatWai Chong Date: Thu, 21 Mar 2024 14:34:33 -0700 Subject: Take into account of `output_unsigned` in rescale operation Set QMin and QMax based on the value of attribute `output_unsigned`. Change-Id: I7f21f3edd7311295285fb3988b3c800de114777a Signed-off-by: TatWai Chong --- reference_model/include/dtype.h | 29 ++++++ reference_model/src/arith_util.h | 70 +++++++++++++ reference_model/src/ops/type_conversion.cc | 151 +++++++++++++++-------------- reference_model/src/ops/type_conversion.h | 3 - 4 files changed, 176 insertions(+), 77 deletions(-) diff --git a/reference_model/include/dtype.h b/reference_model/include/dtype.h index 3e8bdf5..a283f39 100644 --- a/reference_model/include/dtype.h +++ b/reference_model/include/dtype.h @@ -145,6 +145,35 @@ inline TOSA_REF_TYPE ConvertDType(const DType dtype) return TOSA_REF_TYPE_UNKNOWN; } +template +bool IsSignedInt() +{ + switch (Dtype) + { + case TOSA_REF_TYPE_INT4: + case TOSA_REF_TYPE_INT8: + case TOSA_REF_TYPE_INT16: + case TOSA_REF_TYPE_INT32: + case TOSA_REF_TYPE_INT48: + return true; + + case TOSA_REF_TYPE_UINT8: + case TOSA_REF_TYPE_UINT16: + return false; + + case TOSA_REF_TYPE_BOOL: + case TOSA_REF_TYPE_FP32: + case TOSA_REF_TYPE_FP16: + case TOSA_REF_TYPE_BF16: + case TOSA_REF_TYPE_SHAPE: + case TOSA_REF_TYPE_FP8E4M3: + case TOSA_REF_TYPE_FP8E5M2: + default: + FATAL_ERROR("dtype is not an integer type"); + break; + } +} + }; // namespace TosaReference #endif diff --git a/reference_model/src/arith_util.h b/reference_model/src/arith_util.h index f0d184c..fee9fef 100644 --- a/reference_model/src/arith_util.h +++ b/reference_model/src/arith_util.h @@ -22,6 +22,7 @@ * fix point arithmetic * fp16 type conversion(in binary translation) * fp16 arithmetic (disguised with fp32 now) + * and include the arithmetic helpers listed in Section 4.3.1. of the spec */ #ifndef ARITH_UTIL_H @@ -35,6 +36,7 @@ #include "func_debug.h" #include "half.hpp" #include "inttypes.h" +#include "ops/template_types.h" #include #include #include @@ -247,4 +249,72 @@ float fpTrunc(float f_in) return f_in; } +// return the maximum value when interpreting type T as a signed value. +template +int32_t getSignedMaximum() +{ + if (Dtype == TOSA_REF_TYPE_INT8 || Dtype == TOSA_REF_TYPE_UINT8) + return GetQMax::value; + + if (Dtype == TOSA_REF_TYPE_INT16 || Dtype == TOSA_REF_TYPE_UINT16) + return GetQMax::value; + + if (Dtype == TOSA_REF_TYPE_INT32) + return GetQMax::value; + + FATAL_ERROR("Get maximum_s for the dtype input is not supported"); + return 0; +} + +// return the minimum value when interpreting type T as a signed value. +template +int32_t getSignedMinimum() +{ + if (Dtype == TOSA_REF_TYPE_INT8 || Dtype == TOSA_REF_TYPE_UINT8) + return GetQMin::value; + + if (Dtype == TOSA_REF_TYPE_INT16 || Dtype == TOSA_REF_TYPE_UINT16) + return GetQMin::value; + + if (Dtype == TOSA_REF_TYPE_INT32) + return GetQMin::value; + + FATAL_ERROR("Get minimum_s for the dtype input is not supported"); + return 0; +} + +// return the maximum value when interpreting type T as an unsigned value. +template +int32_t getUnsignedMaximum() +{ + if (Dtype == TOSA_REF_TYPE_INT8 || Dtype == TOSA_REF_TYPE_UINT8) + return GetQMax::value; + + if (Dtype == TOSA_REF_TYPE_INT16 || Dtype == TOSA_REF_TYPE_UINT16) + return GetQMax::value; + + if (Dtype == TOSA_REF_TYPE_INT32) + return std::numeric_limits::max(); + + FATAL_ERROR("Get maximum_u for the dtype input is not supported"); + return 0; +} + +// return the minimum value when interpreting type T as an unsigned value. +template +int32_t getUnsignedMinimum() +{ + if (Dtype == TOSA_REF_TYPE_INT8 || Dtype == TOSA_REF_TYPE_UINT8) + return GetQMin::value; + + if (Dtype == TOSA_REF_TYPE_INT16 || Dtype == TOSA_REF_TYPE_UINT16) + return GetQMin::value; + + if (Dtype == TOSA_REF_TYPE_INT32) + return std::numeric_limits::min(); + + FATAL_ERROR("Get minimum_u for the dtype input is not supported"); + return 0; +} + #endif /* _ARITH_UTIL_H */ diff --git a/reference_model/src/ops/type_conversion.cc b/reference_model/src/ops/type_conversion.cc index d58cfeb..835b656 100644 --- a/reference_model/src/ops/type_conversion.cc +++ b/reference_model/src/ops/type_conversion.cc @@ -148,6 +148,9 @@ int OpRescale::eval() bool input_unsigned = attribute->input_unsigned(); bool output_unsigned = attribute->output_unsigned(); + int32_t QMin = output_unsigned ? getUnsignedMinimum() : getSignedMinimum(); + int32_t QMax = output_unsigned ? getUnsignedMaximum() : getSignedMaximum(); + // reshape [d0, d1, ..., dn] into [d0 * d1 ..., dn] Eigen::array shape_2d; shape_2d[0] = 1; @@ -200,13 +203,12 @@ int OpRescale::eval() { for (int32_t i = 0; i < shape_2d[1]; i++) { - begin = Eigen::array({ 0, i }); - curr_channel_slice_prescaled = input_reshaped.slice(begin, size); - channel_multiplier = multiplier[i]; - channel_shift = shift[i]; - curr_channel_slice_postscaled = curr_channel_slice_prescaled.unaryExpr( - [input_zp, output_zp, channel_multiplier, channel_shift, double_round, scale32, input_unsigned, - output_unsigned](InEigenType in_val) -> OutEigenType { + begin = Eigen::array({ 0, i }); + curr_channel_slice_prescaled = input_reshaped.slice(begin, size); + channel_multiplier = multiplier[i]; + channel_shift = shift[i]; + curr_channel_slice_postscaled = + curr_channel_slice_prescaled.unaryExpr([=](InEigenType in_val) -> OutEigenType { int64_t input_zp_shifted; if (input_unsigned) { @@ -293,78 +295,79 @@ int OpRescale::eval() int32_t tensor_shift = shift[0]; try { - output_2d = - input_reshaped.unaryExpr([input_zp, output_zp, tensor_multiplier, tensor_shift, double_round, scale32, - input_unsigned, output_unsigned](InEigenType in_val) -> OutEigenType { - int64_t input_zp_shifted; - if (input_unsigned) - { - int64_t in_val64; - int64_t in_zp64; - switch (GetNumBits::value) - { - case 8: - in_val64 = zero_extend(static_cast(in_val)); - in_zp64 = zero_extend(static_cast(input_zp)); - break; - case 16: - in_val64 = zero_extend(static_cast(in_val)); - in_zp64 = zero_extend(static_cast(input_zp)); - break; - default: - in_val64 = static_cast(in_val); - in_zp64 = static_cast(input_zp); - break; - } - input_zp_shifted = in_val64 - in_zp64; - } - else - { - input_zp_shifted = in_val - input_zp; - } - int32_t scaled; - if (scale32) - scaled = TosaReference::QuantUtil::apply_scale_32(input_zp_shifted, tensor_multiplier, - tensor_shift, double_round); - else - scaled = - TosaReference::QuantUtil::apply_scale_16(input_zp_shifted, tensor_multiplier, tensor_shift); - - int64_t output_zp_extended; - if (output_unsigned) - { - switch (GetNumBits::value) - { - case 8: - output_zp_extended = zero_extend(static_cast(output_zp)); - break; - case 16: - output_zp_extended = zero_extend(static_cast(output_zp)); - break; - default: - output_zp_extended = static_cast(output_zp); - break; - } - } - else + output_2d = input_reshaped.unaryExpr([=](InEigenType in_val) -> OutEigenType { + int64_t input_zp_shifted; + if (input_unsigned) + { + int64_t in_val64; + int64_t in_zp64; + switch (GetNumBits::value) { - output_zp_extended = static_cast(output_zp); + case 8: + in_val64 = zero_extend(static_cast(in_val)); + in_zp64 = zero_extend(static_cast(input_zp)); + break; + case 16: + in_val64 = zero_extend(static_cast(in_val)); + in_zp64 = zero_extend(static_cast(input_zp)); + break; + default: + in_val64 = static_cast(in_val); + in_zp64 = static_cast(input_zp); + break; } - int64_t res_in_64 = static_cast(scaled) + output_zp_extended; - int64_t i32_max_in_64 = static_cast(std::numeric_limits::max()); - int64_t i32_min_in_64 = static_cast(std::numeric_limits::min()); - if (res_in_64 > i32_max_in_64 || res_in_64 < i32_min_in_64) + input_zp_shifted = in_val64 - in_zp64; + } + else + { + input_zp_shifted = in_val - input_zp; + } + int32_t scaled; + if (scale32) + scaled = TosaReference::QuantUtil::apply_scale_32(input_zp_shifted, tensor_multiplier, tensor_shift, + double_round); + else + scaled = + TosaReference::QuantUtil::apply_scale_16(input_zp_shifted, tensor_multiplier, tensor_shift); + + int64_t output_zp_extended; + if (output_unsigned) + { + switch (GetNumBits::value) { - std::string desc = "scaling result [" + std::to_string(scaled) + "] plus output_zp [" + - std::to_string(output_zp) + "] not in i32 range"; - throw desc; + case 8: + output_zp_extended = zero_extend(static_cast(output_zp)); + break; + case 16: + output_zp_extended = zero_extend(static_cast(output_zp)); + break; + default: + output_zp_extended = static_cast(output_zp); + break; } + } + else + { + output_zp_extended = static_cast(output_zp); + } + int64_t res_in_64 = static_cast(scaled) + output_zp_extended; + int64_t i32_max_in_64 = IsSignedInt() + ? static_cast(std::numeric_limits::max()) + : static_cast(std::numeric_limits::max()); + int64_t i32_min_in_64 = static_cast(std::numeric_limits::min()); + + if (res_in_64 > i32_max_in_64 || res_in_64 < i32_min_in_64) + { + std::string desc = "scaling result [" + std::to_string(scaled) + "] plus output_zp [" + + std::to_string(output_zp) + "] not in i32 range"; + throw desc; + } - OutEigenType out_val = static_cast(res_in_64); - out_val = std::max(out_val, QMin); - out_val = std::min(out_val, QMax); - return out_val; - }); + OutEigenType out_val = static_cast(res_in_64); + out_val = std::max(out_val, QMin); + out_val = std::min(out_val, QMax); + return out_val; + }); } catch (std::string desc) { diff --git a/reference_model/src/ops/type_conversion.h b/reference_model/src/ops/type_conversion.h index a06dccc..da5537e 100644 --- a/reference_model/src/ops/type_conversion.h +++ b/reference_model/src/ops/type_conversion.h @@ -43,9 +43,6 @@ public: using TMultiplierI32 = Eigen::Tensor; using TShift = Eigen::Tensor; - static constexpr int32_t QMin = GetQMin::value; - static constexpr int32_t QMax = GetQMax::value; - protected: TosaRescaleAttribute* attribute; TosaReference::TensorTemplate* in; -- cgit v1.2.1