aboutsummaryrefslogtreecommitdiff
path: root/reference_model/src/ops/type_conversion.cc
diff options
context:
space:
mode:
Diffstat (limited to 'reference_model/src/ops/type_conversion.cc')
-rw-r--r--reference_model/src/ops/type_conversion.cc28
1 files changed, 19 insertions, 9 deletions
diff --git a/reference_model/src/ops/type_conversion.cc b/reference_model/src/ops/type_conversion.cc
index 85f8c58..7bca697 100644
--- a/reference_model/src/ops/type_conversion.cc
+++ b/reference_model/src/ops/type_conversion.cc
@@ -37,6 +37,11 @@ OpRescale<Rank, InDtype, OutDtype>::OpRescale(SubgraphTraverser* sgt_, TosaAttri
{
setRequiredOperands(3, 1);
INIT_ATTRIBUTE(Rescale);
+
+ QMax_s = getSignedMaximum<OutDtype>();
+ QMin_s = getSignedMinimum<OutDtype>();
+ QMax_u = getUnsignedMaximum<OutDtype>();
+ QMin_u = getUnsignedMinimum<OutDtype>();
}
template <int Rank, TOSA_REF_TYPE InDtype, TOSA_REF_TYPE OutDtype>
@@ -155,9 +160,6 @@ int OpRescale<Rank, InDtype, OutDtype>::eval()
bool input_unsigned = attribute->input_unsigned();
bool output_unsigned = attribute->output_unsigned();
- int32_t QMin = output_unsigned ? getUnsignedMinimum<OutDtype>() : getSignedMinimum<OutDtype>();
- int32_t QMax = output_unsigned ? getUnsignedMaximum<OutDtype>() : getSignedMaximum<OutDtype>();
-
// reshape [d0, d1, ..., dn] into [d0 * d1 ..., dn]
Eigen::array<Eigen::Index, 2> shape_2d;
shape_2d[0] = 1;
@@ -270,18 +272,24 @@ int OpRescale<Rank, InDtype, OutDtype>::eval()
{
output_zp_extended = static_cast<int64_t>(output_zp);
}
+
int64_t res_in_64 = static_cast<int64_t>(scaled) + output_zp_extended;
int64_t i32_max_in_64 = static_cast<int64_t>(std::numeric_limits<int32_t>::max());
int64_t i32_min_in_64 = static_cast<int64_t>(std::numeric_limits<int32_t>::min());
+
if (res_in_64 > i32_max_in_64 || res_in_64 < i32_min_in_64)
{
std::string desc = "scaling result [" + std::to_string(scaled) + "] plus output_zp [" +
std::to_string(output_zp) + "] not in i32 range";
throw desc;
}
- OutEigenType out_val = static_cast<OutEigenType>(res_in_64);
- out_val = std::max<OutEigenType>(out_val, QMin);
- out_val = std::min<OutEigenType>(out_val, QMax);
+
+ // Treat the output values as unsigned if `output_unsigned` is true.
+ int32_t clipped_val = (output_unsigned)
+ ? applyClip<int32_t, uint32_t>(res_in_64, QMin_u, QMax_u)
+ : applyClip<int32_t, int32_t>(res_in_64, QMin_s, QMax_s);
+
+ OutEigenType out_val = static_cast<OutEigenType>(clipped_val);
return out_val;
});
@@ -370,9 +378,11 @@ int OpRescale<Rank, InDtype, OutDtype>::eval()
throw desc;
}
- OutEigenType out_val = static_cast<OutEigenType>(res_in_64);
- out_val = std::max<OutEigenType>(out_val, QMin);
- out_val = std::min<OutEigenType>(out_val, QMax);
+ // Treat the output values as unsigned if `output_unsigned` is true.
+ int32_t clipped_val = (output_unsigned) ? applyClip<int32_t, uint32_t>(res_in_64, QMin_u, QMax_u)
+ : applyClip<int32_t, int32_t>(res_in_64, QMin_s, QMax_s);
+
+ OutEigenType out_val = static_cast<OutEigenType>(clipped_val);
return out_val;
});
}