aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
author Kevin Cheng <kevin.cheng@arm.com> 2021-03-18 17:41:39 -0700
committer Kevin Cheng <kevin.cheng@arm.com> 2021-04-30 11:07:42 -0700
commit 0f87c953018cc90de18d1a083479b06fd7ce4a8c (patch)
tree 4074f5715a1fdf7fbc2bf70d01c427e91d681e94
parent ad15dfab0430b72015d13d19b8a696bb9bacd0a6 (diff)
download reference_model-0f87c953018cc90de18d1a083479b06fd7ce4a8c.tar.gz
Support 16-bit Rescale
Signed-off-by: Kevin Cheng <kevin.cheng@arm.com>
Change-Id: Ifc80b83c1abcd08e1b7f8e50f647b74c861bc933
-rw-r--r-- reference_model/src/ops/type_conversion.cc 42
-rw-r--r-- reference_model/src/quant_util.h 17
2 files changed, 40 insertions, 19 deletions
diff --git a/reference_model/src/ops/type_conversion.cc b/reference_model/src/ops/type_conversion.cc
index 3a610ea..d988c57 100644
--- a/reference_model/src/ops/type_conversion.cc
+++ b/reference_model/src/ops/type_conversion.cc
@@ -71,9 +71,9 @@ int OpRescale<Rank, InDtype, OutDtype>::eval()
int32_t output_zp = attribute->output_zp();
std::vector<int32_t> multiplier = attribute->multiplier();
std::vector<int32_t> shift = attribute->shift();
- //bool scale32 = attribute->scale32();
- bool double_round = attribute->double_round();
- bool per_channel = attribute->per_channel();
+ bool scale32 = attribute->scale32();
+ bool double_round = attribute->double_round();
+ bool per_channel = attribute->per_channel();
// reshape [d0, d1, ..., dn] into [d0 * d1 ..., dn]
Eigen::array<Eigen::Index, 2> shape_2d;
@@ -94,7 +94,6 @@ int OpRescale<Rank, InDtype, OutDtype>::eval()
ETensor2<OutEigenType> output_2d(shape_2d);
- // TODO: pass scale32 in when 16-bit mode implemented
if (per_channel)
{
ETensor2<InEigenType> curr_channel_slice_prescaled;
@@ -110,10 +109,15 @@ int OpRescale<Rank, InDtype, OutDtype>::eval()
channel_shift = shift[i];
curr_channel_slice_postscaled =
curr_channel_slice_prescaled.unaryExpr([input_zp, output_zp, channel_multiplier, channel_shift,
- double_round](InEigenType in_val) -> OutEigenType {
+ double_round, scale32](InEigenType in_val) -> OutEigenType {
InEigenType input_zp_shifted = in_val - (InEigenType)input_zp;
- int32_t scaled = TosaReference::QuantUtil::apply_scale_32(input_zp_shifted, channel_multiplier,
- channel_shift, double_round);
+ int32_t scaled;
+ if (scale32)
+ scaled = TosaReference::QuantUtil::apply_scale_32(input_zp_shifted, channel_multiplier,
+ channel_shift, double_round);
+ else
+ scaled = TosaReference::QuantUtil::apply_scale_16(input_zp_shifted, channel_multiplier,
+ channel_shift);
OutEigenType out_val = (OutEigenType)(scaled + output_zp);
out_val = std::max<OutEigenType>(out_val, QMin);
out_val = std::min<OutEigenType>(out_val, QMax);
@@ -130,16 +134,20 @@ int OpRescale<Rank, InDtype, OutDtype>::eval()
{
int32_t tensor_multiplier = multiplier[0];
int32_t tensor_shift = shift[0];
- output_2d = input_reshaped.unaryExpr(
- [input_zp, output_zp, tensor_multiplier, tensor_shift, double_round](InEigenType in_val) -> OutEigenType {
- InEigenType input_zp_shifted = in_val - (InEigenType)input_zp;
- int32_t scaled = TosaReference::QuantUtil::apply_scale_32(input_zp_shifted, tensor_multiplier,
- tensor_shift, double_round);
- OutEigenType out_val = (OutEigenType)(scaled + output_zp);
- out_val = std::max<OutEigenType>(out_val, QMin);
- out_val = std::min<OutEigenType>(out_val, QMax);
- return out_val;
- });
+ output_2d = input_reshaped.unaryExpr([input_zp, output_zp, tensor_multiplier, tensor_shift, double_round,
+ scale32](InEigenType in_val) -> OutEigenType {
+ InEigenType input_zp_shifted = in_val - (InEigenType)input_zp;
+ int32_t scaled;
+ if (scale32)
+ scaled = TosaReference::QuantUtil::apply_scale_32(input_zp_shifted, tensor_multiplier, tensor_shift,
+ double_round);
+ else
+ scaled = TosaReference::QuantUtil::apply_scale_16(input_zp_shifted, tensor_multiplier, tensor_shift);
+ OutEigenType out_val = (OutEigenType)(scaled + output_zp);
+ out_val = std::max<OutEigenType>(out_val, QMin);
+ out_val = std::min<OutEigenType>(out_val, QMax);
+ return out_val;
+ });
}
// reshape [d0 * d1 ..., dn] back to [d0, d1, ..., dn]
diff --git a/reference_model/src/quant_util.h b/reference_model/src/quant_util.h
index 1784493..f07dd10 100644
--- a/reference_model/src/quant_util.h
+++ b/reference_model/src/quant_util.h
@@ -61,6 +61,19 @@ public:
"apply_scale_32() error: scaled result exceed int32 numeric range");
return static_cast<int32_t>(result);
}
+
+ // Scales `value` by `multiplier / 2^shift` with round-half-up, for the rescale
+ // mode that uses a 16-bit multiplier (the counterpart of apply_scale_32()).
+ //
+ // value:      input to scale; asserted to lie in the signed 48-bit range [-(2^47), 2^47).
+ // multiplier: scale factor; asserted to be non-negative.
+ // shift:      right-shift applied after the multiply; asserted to lie in [2, 62].
+ //             NOTE(review): [2, 62] is the range the TOSA pseudocode for
+ //             apply_scale_16 requires - confirm against the spec revision targeted here.
+ // Returns the scaled value; asserts that it fits in int32_t.
+ static int32_t apply_scale_16(int64_t value, int16_t multiplier, int32_t shift)
+ {
+     ASSERT_MSG(multiplier >= 0, "apply_scale_16() error: multiplier should >= 0 but is %d", multiplier);
+     // Without this check `1 << (shift - 1)` is undefined for shift <= 0, and a shift
+     // of 63 or more can overflow the int64_t intermediate or over-shift it.
+     ASSERT_MSG(shift >= 2 && shift <= 62, "apply_scale_16() error: shift should be within [2, 62] but is %d", shift);
+     ASSERT_MSG(value >= -(static_cast<int64_t>(1) << 47) && value < (static_cast<int64_t>(1) << 47),
+                "apply_scale_16() error: value should be within [-(2^47), 2^47)");
+     // Derive every constant from an explicit 64-bit one: `1L` is only 32 bits wide
+     // on LLP64 targets, where `1L << 31` (and any larger shift) overflows.
+     const int64_t one   = 1;
+     const int64_t round = one << (shift - 1);
+     // |value| <= 2^47 and multiplier < 2^15, so the product is below 2^62 and
+     // adding round (<= 2^61) cannot overflow int64_t.
+     int64_t result = value * static_cast<int64_t>(multiplier) + round;
+     result         = result >> shift;
+     ASSERT_MSG(result >= -(one << 31) && result < (one << 31),
+                "apply_scale_16() error: scaled result exceed int32 numeric range");
+     return static_cast<int32_t>(result);
+ }
};
class TypeChecker
@@ -68,8 +81,8 @@ class TypeChecker
public:
static bool is_integer(DType dtype)
{
- if (dtype == DType_INT4 || dtype == DType_INT8 || dtype == DType_UINT8 ||
- dtype == DType_INT16 || dtype == DType_INT32 || dtype == DType_INT48)
+ if (dtype == DType_INT4 || dtype == DType_INT8 || dtype == DType_UINT8 || dtype == DType_INT16 ||
+ dtype == DType_INT32 || dtype == DType_INT48)
{
return true;
}