diff options
Diffstat (limited to 'reference_model/src/ops/type_conversion.cc')
-rw-r--r-- | reference_model/src/ops/type_conversion.cc | 28 |
1 files changed, 19 insertions, 9 deletions
diff --git a/reference_model/src/ops/type_conversion.cc b/reference_model/src/ops/type_conversion.cc index 85f8c58..7bca697 100644 --- a/reference_model/src/ops/type_conversion.cc +++ b/reference_model/src/ops/type_conversion.cc @@ -37,6 +37,11 @@ OpRescale<Rank, InDtype, OutDtype>::OpRescale(SubgraphTraverser* sgt_, TosaAttri { setRequiredOperands(3, 1); INIT_ATTRIBUTE(Rescale); + + QMax_s = getSignedMaximum<OutDtype>(); + QMin_s = getSignedMinimum<OutDtype>(); + QMax_u = getUnsignedMaximum<OutDtype>(); + QMin_u = getUnsignedMinimum<OutDtype>(); } template <int Rank, TOSA_REF_TYPE InDtype, TOSA_REF_TYPE OutDtype> @@ -155,9 +160,6 @@ int OpRescale<Rank, InDtype, OutDtype>::eval() bool input_unsigned = attribute->input_unsigned(); bool output_unsigned = attribute->output_unsigned(); - int32_t QMin = output_unsigned ? getUnsignedMinimum<OutDtype>() : getSignedMinimum<OutDtype>(); - int32_t QMax = output_unsigned ? getUnsignedMaximum<OutDtype>() : getSignedMaximum<OutDtype>(); - // reshape [d0, d1, ..., dn] into [d0 * d1 ..., dn] Eigen::array<Eigen::Index, 2> shape_2d; shape_2d[0] = 1; @@ -270,18 +272,24 @@ int OpRescale<Rank, InDtype, OutDtype>::eval() { output_zp_extended = static_cast<int64_t>(output_zp); } + int64_t res_in_64 = static_cast<int64_t>(scaled) + output_zp_extended; int64_t i32_max_in_64 = static_cast<int64_t>(std::numeric_limits<int32_t>::max()); int64_t i32_min_in_64 = static_cast<int64_t>(std::numeric_limits<int32_t>::min()); + if (res_in_64 > i32_max_in_64 || res_in_64 < i32_min_in_64) { std::string desc = "scaling result [" + std::to_string(scaled) + "] plus output_zp [" + std::to_string(output_zp) + "] not in i32 range"; throw desc; } - OutEigenType out_val = static_cast<OutEigenType>(res_in_64); - out_val = std::max<OutEigenType>(out_val, QMin); - out_val = std::min<OutEigenType>(out_val, QMax); + + // Treat the output values as unsigned if `output_unsigned` is true. + int32_t clipped_val = (output_unsigned) + ? applyClip<int32_t, uint32_t>(res_in_64, QMin_u, QMax_u) + : applyClip<int32_t, int32_t>(res_in_64, QMin_s, QMax_s); + + OutEigenType out_val = static_cast<OutEigenType>(clipped_val); return out_val; }); @@ -370,9 +378,11 @@ int OpRescale<Rank, InDtype, OutDtype>::eval() throw desc; } - OutEigenType out_val = static_cast<OutEigenType>(res_in_64); - out_val = std::max<OutEigenType>(out_val, QMin); - out_val = std::min<OutEigenType>(out_val, QMax); + // Treat the output values as unsigned if `output_unsigned` is true. + int32_t clipped_val = (output_unsigned) ? applyClip<int32_t, uint32_t>(res_in_64, QMin_u, QMax_u) + : applyClip<int32_t, int32_t>(res_in_64, QMin_s, QMax_s); + + OutEigenType out_val = static_cast<OutEigenType>(clipped_val); return out_val; }); } |