diff options
author | Kevin Cheng <kevin.cheng@arm.com> | 2020-10-19 12:35:05 -0700 |
---|---|---|
committer | Kevin Cheng <kevin.cheng@arm.com> | 2020-10-19 12:35:05 -0700 |
commit | 99bea145a050e12f1b5f8301979713d9a9b04e12 (patch) | |
tree | bc53dd8cf4566c22b75404dd5cc4ffb849b358d8 /reference_model/src | |
parent | e5e2676409a936431f87d31fb74d825257b20804 (diff) | |
download | reference_model-99bea145a050e12f1b5f8301979713d9a9b04e12.tar.gz |
Update apply_scale_32()
Signed-off-by: Kevin Cheng <kevin.cheng@arm.com>
Change-Id: Ida8e3a17d74e5d6379b2244896ddf9e295d0ecc9
Diffstat (limited to 'reference_model/src')
-rw-r--r-- | reference_model/src/ops/tensor_ops.cc | 4 | ||||
-rw-r--r-- | reference_model/src/ops/type_conversion.cc | 6 | ||||
-rw-r--r-- | reference_model/src/quant_util.h | 38 |
3 files changed, 18 insertions, 30 deletions
diff --git a/reference_model/src/ops/tensor_ops.cc b/reference_model/src/ops/tensor_ops.cc index a735334..82ce3d2 100644 --- a/reference_model/src/ops/tensor_ops.cc +++ b/reference_model/src/ops/tensor_ops.cc @@ -268,9 +268,9 @@ int OpAvgPool2d<Dtype>::eval() { this->out->getTensor() = sum.binaryExpr(div_map, [](AccEigenType value, int32_t div) -> OutEigenType { int32_t multiplier, shift; - TosaReference::QuantUtil<AccDtype>::reciprocal_scale(div, multiplier, shift); + TosaReference::QuantUtil::reciprocal_scale(div, multiplier, shift); - return (OutEigenType)TosaReference::QuantUtil<AccDtype>::apply_scale(value, multiplier, shift, false); + return (OutEigenType)TosaReference::QuantUtil::apply_scale_32(value, multiplier, shift, false); }); this->out->getTensor() = this->out->getTensor() + (OutEigenType)(this->qinfo->output_zp()); this->out->getTensor() = this->out->getTensor().cwiseMax((OutEigenType)QMin); diff --git a/reference_model/src/ops/type_conversion.cc b/reference_model/src/ops/type_conversion.cc index 61a19f4..a97bc0d 100644 --- a/reference_model/src/ops/type_conversion.cc +++ b/reference_model/src/ops/type_conversion.cc @@ -130,7 +130,7 @@ int OpRescale<Rank, InDtype, OutDtype>::eval() curr_channel_slice_prescaled.unaryExpr([input_zp, output_zp, channel_multiplier, channel_shift, double_round](InEigenType in_val) -> OutEigenType { InEigenType input_zp_shifted = in_val - (InEigenType)input_zp; - int32_t scaled = TosaReference::QuantUtil<InDtype>::apply_scale( + int32_t scaled = TosaReference::QuantUtil::apply_scale_32( input_zp_shifted, channel_multiplier, channel_shift, double_round); OutEigenType out_val = (OutEigenType)(scaled + output_zp); out_val = std::max<OutEigenType>(out_val, QMin); @@ -151,8 +151,8 @@ int OpRescale<Rank, InDtype, OutDtype>::eval() output_2d = input_reshaped.unaryExpr( [input_zp, output_zp, tensor_multiplier, tensor_shift, double_round](InEigenType in_val) -> OutEigenType { InEigenType input_zp_shifted = in_val - (InEigenType)input_zp; - int32_t scaled = TosaReference::QuantUtil<InDtype>::apply_scale(input_zp_shifted, tensor_multiplier, - tensor_shift, double_round); + int32_t scaled = TosaReference::QuantUtil::apply_scale_32(input_zp_shifted, tensor_multiplier, + tensor_shift, double_round); OutEigenType out_val = (OutEigenType)(scaled + output_zp); out_val = std::max<OutEigenType>(out_val, QMin); out_val = std::min<OutEigenType>(out_val, QMax); diff --git a/reference_model/src/quant_util.h b/reference_model/src/quant_util.h index 3638b3b..3b58b66 100644 --- a/reference_model/src/quant_util.h +++ b/reference_model/src/quant_util.h @@ -26,18 +26,16 @@ using namespace tosa; namespace TosaReference { -template <DType AccDType> class QuantUtil { public: - using T = typename GetEigenType<AccDType>::type; - static void reciprocal_scale(int32_t value, // Output int32_t& multiplier, int32_t& shift) { - ASSERT_MSG(value > 0, "AvgPool2d reciprocal_scale() error: # of elements should be > 1 but is %d", value); + ASSERT_MSG(value > 0, + "AvgPool2d reciprocal_scale() error: # of elements should be > 1 but is %d", value); uint32_t value_u32 = (uint32_t)value; int32_t k = 32 - LEADING_ZEROS_32(value_u32 - 1); // (1<<k)/2 < value <= (1<<k) int64_t numerator = ((1L << 30) + 1) << k; @@ -45,33 +43,23 @@ public: shift = 30 + k; } - static int32_t apply_scale(T value, int32_t multiplier, int32_t shift, bool enabled_adjusted_rounding = true) + static int32_t apply_scale_32(int32_t value, int32_t multiplier, int32_t shift, bool double_round = true) { - if (AccDType == DType_FLOAT) - { - return value; - } - ASSERT_MSG(multiplier >= 0, "apply_scale() error: multiplier should >= 0 but is %d", multiplier); - int64_t round = (shift > 0) ? (1L << (shift - 1)) : 0; - if (enabled_adjusted_rounding) + ASSERT_MSG(multiplier >= 0, "apply_scale_32() error: multiplier should >= 0 but is %d", multiplier); + ASSERT_MSG(shift >= 2 && shift <= 62, "apply_scale_32() error: shift should be within [2, 62] but is %d", + shift); + int64_t round = 1L << (shift - 1); + if (double_round) { - if (AccDType != DType_INT48) - { - if (shift > 31 && value >= 0) - round += (1L << 30); - if (shift > 31 && value < 0) - round -= (1L << 30); - } - else - { // input data could be int16, which leads to 48 bits accumulator - ASSERT_MSG(multiplier < (1 << 15), "apply_scale() error: multiplier should <= %d in 48 bit mode", - (1 << 15)); - } + if (shift > 31 && value >= 0) + round += (1L << 30); + if (shift > 31 && value < 0) + round -= (1L << 30); } int64_t result = (int64_t)value * multiplier + round; result = result >> shift; ASSERT_MSG(result >= -(1L << 31) && result < (1L << 31), - "apply_scale() error: scaled result exceed int32 numeric range"); + "apply_scale_32() error: scaled result exceed int32 numeric range"); return static_cast<int32_t>(result); } }; |