diff options
Diffstat (limited to 'reference_model/src/quant_util.h')
-rw-r--r-- | reference_model/src/quant_util.h | 17 |
1 files changed, 15 insertions, 2 deletions
diff --git a/reference_model/src/quant_util.h b/reference_model/src/quant_util.h index 1784493..f07dd10 100644 --- a/reference_model/src/quant_util.h +++ b/reference_model/src/quant_util.h @@ -61,6 +61,19 @@ public: "apply_scale_32() error: scaled result exceed int32 numeric range"); return static_cast<int32_t>(result); } + + static int32_t apply_scale_16(int64_t value, int16_t multiplier, int32_t shift) + { + ASSERT_MSG(multiplier >= 0, "apply_scale_16() error: multiplier should >= 0 but is %d", multiplier); + ASSERT_MSG(value >= -(static_cast<int64_t>(1) << 47) && value < (static_cast<int64_t>(1) << 47), + "apply_scale_16() error: value should be within [-(1^47), 1^47]"); + int64_t round = 1L << (shift - 1); + int64_t result = value * (int64_t)multiplier + round; + result = result >> shift; + ASSERT_MSG(result >= -(1L << 31) && result < (1L << 31), + "apply_scale_16() error: scaled result exceed int32 numeric range"); + return static_cast<int32_t>(result); + } }; class TypeChecker @@ -68,8 +81,8 @@ class TypeChecker public: static bool is_integer(DType dtype) { - if (dtype == DType_INT4 || dtype == DType_INT8 || dtype == DType_UINT8 || - dtype == DType_INT16 || dtype == DType_INT32 || dtype == DType_INT48) + if (dtype == DType_INT4 || dtype == DType_INT8 || dtype == DType_UINT8 || dtype == DType_INT16 || + dtype == DType_INT32 || dtype == DType_INT48) { return true; } |