about summary refs log tree commit diff
path: root/reference_model/src/ops/type_conversion.cc
diff options
context:
space:
mode:
Diffstat (limited to 'reference_model/src/ops/type_conversion.cc')
-rw-r--r-- reference_model/src/ops/type_conversion.cc 151
1 files changed, 77 insertions, 74 deletions
diff --git a/reference_model/src/ops/type_conversion.cc b/reference_model/src/ops/type_conversion.cc
index d58cfeb..835b656 100644
--- a/reference_model/src/ops/type_conversion.cc
+++ b/reference_model/src/ops/type_conversion.cc
@@ -148,6 +148,9 @@ int OpRescale<Rank, InDtype, OutDtype>::eval()
bool input_unsigned = attribute->input_unsigned();
bool output_unsigned = attribute->output_unsigned();
+ int32_t QMin = output_unsigned ? getUnsignedMinimum<OutDtype>() : getSignedMinimum<OutDtype>();
+ int32_t QMax = output_unsigned ? getUnsignedMaximum<OutDtype>() : getSignedMaximum<OutDtype>();
+
// reshape [d0, d1, ..., dn] into [d0 * d1 ..., dn]
Eigen::array<Eigen::Index, 2> shape_2d;
shape_2d[0] = 1;
@@ -200,13 +203,12 @@ int OpRescale<Rank, InDtype, OutDtype>::eval()
{
for (int32_t i = 0; i < shape_2d[1]; i++)
{
- begin = Eigen::array<Eigen::Index, 2>({ 0, i });
- curr_channel_slice_prescaled = input_reshaped.slice(begin, size);
- channel_multiplier = multiplier[i];
- channel_shift = shift[i];
- curr_channel_slice_postscaled = curr_channel_slice_prescaled.unaryExpr(
- [input_zp, output_zp, channel_multiplier, channel_shift, double_round, scale32, input_unsigned,
- output_unsigned](InEigenType in_val) -> OutEigenType {
+ begin = Eigen::array<Eigen::Index, 2>({ 0, i });
+ curr_channel_slice_prescaled = input_reshaped.slice(begin, size);
+ channel_multiplier = multiplier[i];
+ channel_shift = shift[i];
+ curr_channel_slice_postscaled =
+ curr_channel_slice_prescaled.unaryExpr([=](InEigenType in_val) -> OutEigenType {
int64_t input_zp_shifted;
if (input_unsigned)
{
@@ -293,78 +295,79 @@ int OpRescale<Rank, InDtype, OutDtype>::eval()
int32_t tensor_shift = shift[0];
try
{
- output_2d =
- input_reshaped.unaryExpr([input_zp, output_zp, tensor_multiplier, tensor_shift, double_round, scale32,
- input_unsigned, output_unsigned](InEigenType in_val) -> OutEigenType {
- int64_t input_zp_shifted;
- if (input_unsigned)
- {
- int64_t in_val64;
- int64_t in_zp64;
- switch (GetNumBits<InDtype>::value)
- {
- case 8:
- in_val64 = zero_extend(static_cast<int8_t>(in_val));
- in_zp64 = zero_extend(static_cast<int8_t>(input_zp));
- break;
- case 16:
- in_val64 = zero_extend(static_cast<int16_t>(in_val));
- in_zp64 = zero_extend(static_cast<int16_t>(input_zp));
- break;
- default:
- in_val64 = static_cast<int64_t>(in_val);
- in_zp64 = static_cast<int64_t>(input_zp);
- break;
- }
- input_zp_shifted = in_val64 - in_zp64;
- }
- else
- {
- input_zp_shifted = in_val - input_zp;
- }
- int32_t scaled;
- if (scale32)
- scaled = TosaReference::QuantUtil::apply_scale_32(input_zp_shifted, tensor_multiplier,
- tensor_shift, double_round);
- else
- scaled =
- TosaReference::QuantUtil::apply_scale_16(input_zp_shifted, tensor_multiplier, tensor_shift);
-
- int64_t output_zp_extended;
- if (output_unsigned)
- {
- switch (GetNumBits<OutDtype>::value)
- {
- case 8:
- output_zp_extended = zero_extend(static_cast<int8_t>(output_zp));
- break;
- case 16:
- output_zp_extended = zero_extend(static_cast<int16_t>(output_zp));
- break;
- default:
- output_zp_extended = static_cast<int64_t>(output_zp);
- break;
- }
- }
- else
+ output_2d = input_reshaped.unaryExpr([=](InEigenType in_val) -> OutEigenType {
+ int64_t input_zp_shifted;
+ if (input_unsigned)
+ {
+ int64_t in_val64;
+ int64_t in_zp64;
+ switch (GetNumBits<InDtype>::value)
{
- output_zp_extended = static_cast<int64_t>(output_zp);
+ case 8:
+ in_val64 = zero_extend(static_cast<int8_t>(in_val));
+ in_zp64 = zero_extend(static_cast<int8_t>(input_zp));
+ break;
+ case 16:
+ in_val64 = zero_extend(static_cast<int16_t>(in_val));
+ in_zp64 = zero_extend(static_cast<int16_t>(input_zp));
+ break;
+ default:
+ in_val64 = static_cast<int64_t>(in_val);
+ in_zp64 = static_cast<int64_t>(input_zp);
+ break;
}
- int64_t res_in_64 = static_cast<int64_t>(scaled) + output_zp_extended;
- int64_t i32_max_in_64 = static_cast<int64_t>(std::numeric_limits<int32_t>::max());
- int64_t i32_min_in_64 = static_cast<int64_t>(std::numeric_limits<int32_t>::min());
- if (res_in_64 > i32_max_in_64 || res_in_64 < i32_min_in_64)
+ input_zp_shifted = in_val64 - in_zp64;
+ }
+ else
+ {
+ input_zp_shifted = in_val - input_zp;
+ }
+ int32_t scaled;
+ if (scale32)
+ scaled = TosaReference::QuantUtil::apply_scale_32(input_zp_shifted, tensor_multiplier, tensor_shift,
+ double_round);
+ else
+ scaled =
+ TosaReference::QuantUtil::apply_scale_16(input_zp_shifted, tensor_multiplier, tensor_shift);
+
+ int64_t output_zp_extended;
+ if (output_unsigned)
+ {
+ switch (GetNumBits<OutDtype>::value)
{
- std::string desc = "scaling result [" + std::to_string(scaled) + "] plus output_zp [" +
- std::to_string(output_zp) + "] not in i32 range";
- throw desc;
+ case 8:
+ output_zp_extended = zero_extend(static_cast<int8_t>(output_zp));
+ break;
+ case 16:
+ output_zp_extended = zero_extend(static_cast<int16_t>(output_zp));
+ break;
+ default:
+ output_zp_extended = static_cast<int64_t>(output_zp);
+ break;
}
+ }
+ else
+ {
+ output_zp_extended = static_cast<int64_t>(output_zp);
+ }
+ int64_t res_in_64 = static_cast<int64_t>(scaled) + output_zp_extended;
+ int64_t i32_max_in_64 = IsSignedInt<OutDtype>()
+ ? static_cast<int64_t>(std::numeric_limits<int32_t>::max())
+ : static_cast<int64_t>(std::numeric_limits<uint32_t>::max());
+ int64_t i32_min_in_64 = static_cast<int64_t>(std::numeric_limits<int32_t>::min());
+
+ if (res_in_64 > i32_max_in_64 || res_in_64 < i32_min_in_64)
+ {
+ std::string desc = "scaling result [" + std::to_string(scaled) + "] plus output_zp [" +
+ std::to_string(output_zp) + "] not in i32 range";
+ throw desc;
+ }
- OutEigenType out_val = static_cast<OutEigenType>(res_in_64);
- out_val = std::max<OutEigenType>(out_val, QMin);
- out_val = std::min<OutEigenType>(out_val, QMax);
- return out_val;
- });
+ OutEigenType out_val = static_cast<OutEigenType>(res_in_64);
+ out_val = std::max<OutEigenType>(out_val, QMin);
+ out_val = std::min<OutEigenType>(out_val, QMax);
+ return out_val;
+ });
}
catch (std::string desc)
{