diff options
author | TatWai Chong <tatwai.chong@arm.com> | 2024-03-21 14:34:33 -0700 |
---|---|---|
committer | Eric Kunze <eric.kunze@arm.com> | 2024-03-28 00:35:36 +0000 |
commit | 08fe7a5b7e2c7c1a77968130e11267ef61490ac8 (patch) | |
tree | 29e974b12bc5540a4a364aaa7b0d6aa47c30c923 /reference_model | |
parent | d5b1512b1d2cea3b87e52a0ecc123db2a7a7cad3 (diff) | |
download | reference_model-08fe7a5b7e2c7c1a77968130e11267ef61490ac8.tar.gz |
Take into account of `output_unsigned` in rescale operation
Set QMin and QMax based on the value of attribute `output_unsigned`.
Change-Id: I7f21f3edd7311295285fb3988b3c800de114777a
Signed-off-by: TatWai Chong <tatwai.chong@arm.com>
Diffstat (limited to 'reference_model')
-rw-r--r-- | reference_model/include/dtype.h | 29 | ||||
-rw-r--r-- | reference_model/src/arith_util.h | 70 | ||||
-rw-r--r-- | reference_model/src/ops/type_conversion.cc | 151 | ||||
-rw-r--r-- | reference_model/src/ops/type_conversion.h | 3 |
4 files changed, 176 insertions, 77 deletions
diff --git a/reference_model/include/dtype.h b/reference_model/include/dtype.h index 3e8bdf5..a283f39 100644 --- a/reference_model/include/dtype.h +++ b/reference_model/include/dtype.h @@ -145,6 +145,35 @@ inline TOSA_REF_TYPE ConvertDType(const DType dtype) return TOSA_REF_TYPE_UNKNOWN; } +template <TOSA_REF_TYPE Dtype> +bool IsSignedInt() +{ + switch (Dtype) + { + case TOSA_REF_TYPE_INT4: + case TOSA_REF_TYPE_INT8: + case TOSA_REF_TYPE_INT16: + case TOSA_REF_TYPE_INT32: + case TOSA_REF_TYPE_INT48: + return true; + + case TOSA_REF_TYPE_UINT8: + case TOSA_REF_TYPE_UINT16: + return false; + + case TOSA_REF_TYPE_BOOL: + case TOSA_REF_TYPE_FP32: + case TOSA_REF_TYPE_FP16: + case TOSA_REF_TYPE_BF16: + case TOSA_REF_TYPE_SHAPE: + case TOSA_REF_TYPE_FP8E4M3: + case TOSA_REF_TYPE_FP8E5M2: + default: + FATAL_ERROR("dtype is not an integer type"); + break; + } +} + }; // namespace TosaReference #endif diff --git a/reference_model/src/arith_util.h b/reference_model/src/arith_util.h index f0d184c..fee9fef 100644 --- a/reference_model/src/arith_util.h +++ b/reference_model/src/arith_util.h @@ -22,6 +22,7 @@ * fix point arithmetic * fp16 type conversion(in binary translation) * fp16 arithmetic (disguised with fp32 now) + * and include the arithmetic helpers listed in Section 4.3.1. of the spec */ #ifndef ARITH_UTIL_H @@ -35,6 +36,7 @@ #include "func_debug.h" #include "half.hpp" #include "inttypes.h" +#include "ops/template_types.h" #include <bitset> #include <cassert> #include <limits> @@ -247,4 +249,72 @@ float fpTrunc(float f_in) return f_in; } +// return the maximum value when interpreting type T as a signed value. +template <TOSA_REF_TYPE Dtype> +int32_t getSignedMaximum() +{ + if (Dtype == TOSA_REF_TYPE_INT8 || Dtype == TOSA_REF_TYPE_UINT8) + return GetQMax<TOSA_REF_TYPE_INT8>::value; + + if (Dtype == TOSA_REF_TYPE_INT16 || Dtype == TOSA_REF_TYPE_UINT16) + return GetQMax<TOSA_REF_TYPE_INT16>::value; + + if (Dtype == TOSA_REF_TYPE_INT32) + return GetQMax<TOSA_REF_TYPE_INT32>::value; + + FATAL_ERROR("Get maximum_s for the dtype input is not supported"); + return 0; +} + +// return the minimum value when interpreting type T as a signed value. +template <TOSA_REF_TYPE Dtype> +int32_t getSignedMinimum() +{ + if (Dtype == TOSA_REF_TYPE_INT8 || Dtype == TOSA_REF_TYPE_UINT8) + return GetQMin<TOSA_REF_TYPE_INT8>::value; + + if (Dtype == TOSA_REF_TYPE_INT16 || Dtype == TOSA_REF_TYPE_UINT16) + return GetQMin<TOSA_REF_TYPE_INT16>::value; + + if (Dtype == TOSA_REF_TYPE_INT32) + return GetQMin<TOSA_REF_TYPE_INT32>::value; + + FATAL_ERROR("Get minimum_s for the dtype input is not supported"); + return 0; +} + +// return the maximum value when interpreting type T as an unsigned value. +template <TOSA_REF_TYPE Dtype> +int32_t getUnsignedMaximum() +{ + if (Dtype == TOSA_REF_TYPE_INT8 || Dtype == TOSA_REF_TYPE_UINT8) + return GetQMax<TOSA_REF_TYPE_UINT8>::value; + + if (Dtype == TOSA_REF_TYPE_INT16 || Dtype == TOSA_REF_TYPE_UINT16) + return GetQMax<TOSA_REF_TYPE_UINT16>::value; + + if (Dtype == TOSA_REF_TYPE_INT32) + return std::numeric_limits<uint32_t>::max(); + + FATAL_ERROR("Get maximum_u for the dtype input is not supported"); + return 0; +} + +// return the minimum value when interpreting type T as an unsigned value. +template <TOSA_REF_TYPE Dtype> +int32_t getUnsignedMinimum() +{ + if (Dtype == TOSA_REF_TYPE_INT8 || Dtype == TOSA_REF_TYPE_UINT8) + return GetQMin<TOSA_REF_TYPE_UINT8>::value; + + if (Dtype == TOSA_REF_TYPE_INT16 || Dtype == TOSA_REF_TYPE_UINT16) + return GetQMin<TOSA_REF_TYPE_UINT16>::value; + + if (Dtype == TOSA_REF_TYPE_INT32) + return std::numeric_limits<uint32_t>::min(); + + FATAL_ERROR("Get minimum_u for the dtype input is not supported"); + return 0; +} + #endif /* _ARITH_UTIL_H */ diff --git a/reference_model/src/ops/type_conversion.cc b/reference_model/src/ops/type_conversion.cc index d58cfeb..835b656 100644 --- a/reference_model/src/ops/type_conversion.cc +++ b/reference_model/src/ops/type_conversion.cc @@ -148,6 +148,9 @@ int OpRescale<Rank, InDtype, OutDtype>::eval() bool input_unsigned = attribute->input_unsigned(); bool output_unsigned = attribute->output_unsigned(); + int32_t QMin = output_unsigned ? getUnsignedMinimum<OutDtype>() : getSignedMinimum<OutDtype>(); + int32_t QMax = output_unsigned ? getUnsignedMaximum<OutDtype>() : getSignedMaximum<OutDtype>(); + // reshape [d0, d1, ..., dn] into [d0 * d1 ..., dn] Eigen::array<Eigen::Index, 2> shape_2d; shape_2d[0] = 1; @@ -200,13 +203,12 @@ int OpRescale<Rank, InDtype, OutDtype>::eval() { for (int32_t i = 0; i < shape_2d[1]; i++) { - begin = Eigen::array<Eigen::Index, 2>({ 0, i }); - curr_channel_slice_prescaled = input_reshaped.slice(begin, size); - channel_multiplier = multiplier[i]; - channel_shift = shift[i]; - curr_channel_slice_postscaled = curr_channel_slice_prescaled.unaryExpr( - [input_zp, output_zp, channel_multiplier, channel_shift, double_round, scale32, input_unsigned, - output_unsigned](InEigenType in_val) -> OutEigenType { + begin = Eigen::array<Eigen::Index, 2>({ 0, i }); + curr_channel_slice_prescaled = input_reshaped.slice(begin, size); + channel_multiplier = multiplier[i]; + channel_shift = shift[i]; + curr_channel_slice_postscaled = + curr_channel_slice_prescaled.unaryExpr([=](InEigenType in_val) -> OutEigenType { int64_t input_zp_shifted; if (input_unsigned) { @@ -293,78 +295,79 @@ int OpRescale<Rank, InDtype, OutDtype>::eval() int32_t tensor_shift = shift[0]; try { - output_2d = - input_reshaped.unaryExpr([input_zp, output_zp, tensor_multiplier, tensor_shift, double_round, scale32, - input_unsigned, output_unsigned](InEigenType in_val) -> OutEigenType { - int64_t input_zp_shifted; - if (input_unsigned) - { - int64_t in_val64; - int64_t in_zp64; - switch (GetNumBits<InDtype>::value) - { - case 8: - in_val64 = zero_extend(static_cast<int8_t>(in_val)); - in_zp64 = zero_extend(static_cast<int8_t>(input_zp)); - break; - case 16: - in_val64 = zero_extend(static_cast<int16_t>(in_val)); - in_zp64 = zero_extend(static_cast<int16_t>(input_zp)); - break; - default: - in_val64 = static_cast<int64_t>(in_val); - in_zp64 = static_cast<int64_t>(input_zp); - break; - } - input_zp_shifted = in_val64 - in_zp64; - } - else - { - input_zp_shifted = in_val - input_zp; - } - int32_t scaled; - if (scale32) - scaled = TosaReference::QuantUtil::apply_scale_32(input_zp_shifted, tensor_multiplier, - tensor_shift, double_round); - else - scaled = - TosaReference::QuantUtil::apply_scale_16(input_zp_shifted, tensor_multiplier, tensor_shift); - - int64_t output_zp_extended; - if (output_unsigned) - { - switch (GetNumBits<OutDtype>::value) - { - case 8: - output_zp_extended = zero_extend(static_cast<int8_t>(output_zp)); - break; - case 16: - output_zp_extended = zero_extend(static_cast<int16_t>(output_zp)); - break; - default: - output_zp_extended = static_cast<int64_t>(output_zp); - break; - } - } - else + output_2d = input_reshaped.unaryExpr([=](InEigenType in_val) -> OutEigenType { + int64_t input_zp_shifted; + if (input_unsigned) + { + int64_t in_val64; + int64_t in_zp64; + switch (GetNumBits<InDtype>::value) { - output_zp_extended = static_cast<int64_t>(output_zp); + case 8: + in_val64 = zero_extend(static_cast<int8_t>(in_val)); + in_zp64 = zero_extend(static_cast<int8_t>(input_zp)); + break; + case 16: + in_val64 = zero_extend(static_cast<int16_t>(in_val)); + in_zp64 = zero_extend(static_cast<int16_t>(input_zp)); + break; + default: + in_val64 = static_cast<int64_t>(in_val); + in_zp64 = static_cast<int64_t>(input_zp); + break; } - int64_t res_in_64 = static_cast<int64_t>(scaled) + output_zp_extended; - int64_t i32_max_in_64 = static_cast<int64_t>(std::numeric_limits<int32_t>::max()); - int64_t i32_min_in_64 = static_cast<int64_t>(std::numeric_limits<int32_t>::min()); - if (res_in_64 > i32_max_in_64 || res_in_64 < i32_min_in_64) + input_zp_shifted = in_val64 - in_zp64; + } + else + { + input_zp_shifted = in_val - input_zp; + } + int32_t scaled; + if (scale32) + scaled = TosaReference::QuantUtil::apply_scale_32(input_zp_shifted, tensor_multiplier, tensor_shift, + double_round); + else + scaled = + TosaReference::QuantUtil::apply_scale_16(input_zp_shifted, tensor_multiplier, tensor_shift); + + int64_t output_zp_extended; + if (output_unsigned) + { + switch (GetNumBits<OutDtype>::value) { - std::string desc = "scaling result [" + std::to_string(scaled) + "] plus output_zp [" + - std::to_string(output_zp) + "] not in i32 range"; - throw desc; + case 8: + output_zp_extended = zero_extend(static_cast<int8_t>(output_zp)); + break; + case 16: + output_zp_extended = zero_extend(static_cast<int16_t>(output_zp)); + break; + default: + output_zp_extended = static_cast<int64_t>(output_zp); + break; } + } + else + { + output_zp_extended = static_cast<int64_t>(output_zp); + } + int64_t res_in_64 = static_cast<int64_t>(scaled) + output_zp_extended; + int64_t i32_max_in_64 = IsSignedInt<OutDtype>() + ? static_cast<int64_t>(std::numeric_limits<int32_t>::max()) + : static_cast<int64_t>(std::numeric_limits<uint32_t>::max()); + int64_t i32_min_in_64 = static_cast<int64_t>(std::numeric_limits<int32_t>::min()); + + if (res_in_64 > i32_max_in_64 || res_in_64 < i32_min_in_64) + { + std::string desc = "scaling result [" + std::to_string(scaled) + "] plus output_zp [" + + std::to_string(output_zp) + "] not in i32 range"; + throw desc; + } - OutEigenType out_val = static_cast<OutEigenType>(res_in_64); - out_val = std::max<OutEigenType>(out_val, QMin); - out_val = std::min<OutEigenType>(out_val, QMax); - return out_val; - }); + OutEigenType out_val = static_cast<OutEigenType>(res_in_64); + out_val = std::max<OutEigenType>(out_val, QMin); + out_val = std::min<OutEigenType>(out_val, QMax); + return out_val; + }); } catch (std::string desc) { diff --git a/reference_model/src/ops/type_conversion.h b/reference_model/src/ops/type_conversion.h index a06dccc..da5537e 100644 --- a/reference_model/src/ops/type_conversion.h +++ b/reference_model/src/ops/type_conversion.h @@ -43,9 +43,6 @@ public: using TMultiplierI32 = Eigen::Tensor<I32EigenType, 1>; using TShift = Eigen::Tensor<I8EigenType, 1>; - static constexpr int32_t QMin = GetQMin<OutDtype>::value; - static constexpr int32_t QMax = GetQMax<OutDtype>::value; - protected: TosaRescaleAttribute* attribute; TosaReference::TensorTemplate<TIn>* in; |