From 99bea145a050e12f1b5f8301979713d9a9b04e12 Mon Sep 17 00:00:00 2001 From: Kevin Cheng Date: Mon, 19 Oct 2020 12:35:05 -0700 Subject: Update apply_scale_32() Signed-off-by: Kevin Cheng Change-Id: Ida8e3a17d74e5d6379b2244896ddf9e295d0ecc9 --- reference_model/src/quant_util.h | 38 +++++++++++++------------------------- 1 file changed, 13 insertions(+), 25 deletions(-) (limited to 'reference_model/src/quant_util.h') diff --git a/reference_model/src/quant_util.h b/reference_model/src/quant_util.h index 3638b3b..3b58b66 100644 --- a/reference_model/src/quant_util.h +++ b/reference_model/src/quant_util.h @@ -26,18 +26,16 @@ using namespace tosa; namespace TosaReference { -template class QuantUtil { public: - using T = typename GetEigenType::type; - static void reciprocal_scale(int32_t value, // Output int32_t& multiplier, int32_t& shift) { - ASSERT_MSG(value > 0, "AvgPool2d reciprocal_scale() error: # of elements should be > 1 but is %d", value); + ASSERT_MSG(value > 0, + "AvgPool2d reciprocal_scale() error: # of elements should be > 1 but is %d", value); uint32_t value_u32 = (uint32_t)value; int32_t k = 32 - LEADING_ZEROS_32(value_u32 - 1); // (1<= 0, "apply_scale() error: multiplier should >= 0 but is %d", multiplier); - int64_t round = (shift > 0) ? (1L << (shift - 1)) : 0; - if (enabled_adjusted_rounding) + ASSERT_MSG(multiplier >= 0, "apply_scale_32() error: multiplier should >= 0 but is %d", multiplier); + ASSERT_MSG(shift >= 2 && shift <= 62, "apply_scale_32() error: shift should be within [2, 62] but is %d", + shift); + int64_t round = 1L << (shift - 1); + if (double_round) { - if (AccDType != DType_INT48) - { - if (shift > 31 && value >= 0) - round += (1L << 30); - if (shift > 31 && value < 0) - round -= (1L << 30); - } - else - { // input data could be int16, which leads to 48 bits accumulator - ASSERT_MSG(multiplier < (1 << 15), "apply_scale() error: multiplier should <= %d in 48 bit mode", - (1 << 15)); - } + if (shift > 31 && value >= 0) + round += (1L << 30); + if (shift > 31 && value < 0) + round -= (1L << 30); } int64_t result = (int64_t)value * multiplier + round; result = result >> shift; ASSERT_MSG(result >= -(1L << 31) && result < (1L << 31), - "apply_scale() error: scaled result exceed int32 numeric range"); + "apply_scale_32() error: scaled result exceed int32 numeric range"); return static_cast(result); } }; -- cgit v1.2.1