aboutsummaryrefslogtreecommitdiff
path: root/reference_model/src/quant_util.h
diff options
context:
space:
mode:
authorKevin Cheng <kevin.cheng@arm.com>2021-03-18 17:41:39 -0700
committerKevin Cheng <kevin.cheng@arm.com>2021-04-30 11:07:42 -0700
commit0f87c953018cc90de18d1a083479b06fd7ce4a8c (patch)
tree4074f5715a1fdf7fbc2bf70d01c427e91d681e94 /reference_model/src/quant_util.h
parentad15dfab0430b72015d13d19b8a696bb9bacd0a6 (diff)
downloadreference_model-0f87c953018cc90de18d1a083479b06fd7ce4a8c.tar.gz
Support 16-bit Rescale
Signed-off-by: Kevin Cheng <kevin.cheng@arm.com> Change-Id: Ifc80b83c1abcd08e1b7f8e50f647b74c861bc933
Diffstat (limited to 'reference_model/src/quant_util.h')
-rw-r--r--reference_model/src/quant_util.h17
1 files changed, 15 insertions, 2 deletions
diff --git a/reference_model/src/quant_util.h b/reference_model/src/quant_util.h
index 1784493..f07dd10 100644
--- a/reference_model/src/quant_util.h
+++ b/reference_model/src/quant_util.h
@@ -61,6 +61,19 @@ public:
"apply_scale_32() error: scaled result exceed int32 numeric range");
return static_cast<int32_t>(result);
}
+
+ static int32_t apply_scale_16(int64_t value, int16_t multiplier, int32_t shift)
+ {
+ ASSERT_MSG(multiplier >= 0, "apply_scale_16() error: multiplier should >= 0 but is %d", multiplier);
+ ASSERT_MSG(value >= -(static_cast<int64_t>(1) << 47) && value < (static_cast<int64_t>(1) << 47),
+ "apply_scale_16() error: value should be within [-(1^47), 1^47]");
+ int64_t round = 1L << (shift - 1);
+ int64_t result = value * (int64_t)multiplier + round;
+ result = result >> shift;
+ ASSERT_MSG(result >= -(1L << 31) && result < (1L << 31),
+ "apply_scale_16() error: scaled result exceed int32 numeric range");
+ return static_cast<int32_t>(result);
+ }
};
class TypeChecker
@@ -68,8 +81,8 @@ class TypeChecker
public:
static bool is_integer(DType dtype)
{
- if (dtype == DType_INT4 || dtype == DType_INT8 || dtype == DType_UINT8 ||
- dtype == DType_INT16 || dtype == DType_INT32 || dtype == DType_INT48)
+ if (dtype == DType_INT4 || dtype == DType_INT8 || dtype == DType_UINT8 || dtype == DType_INT16 ||
+ dtype == DType_INT32 || dtype == DType_INT48)
{
return true;
}