diff options
author | Kevin Cheng <kevin.cheng@arm.com> | 2021-03-18 17:41:39 -0700 |
---|---|---|
committer | Kevin Cheng <kevin.cheng@arm.com> | 2021-04-30 11:07:42 -0700 |
commit | 0f87c953018cc90de18d1a083479b06fd7ce4a8c (patch) | |
tree | 4074f5715a1fdf7fbc2bf70d01c427e91d681e94 | |
parent | ad15dfab0430b72015d13d19b8a696bb9bacd0a6 (diff) | |
download | reference_model-0f87c953018cc90de18d1a083479b06fd7ce4a8c.tar.gz |
Support 16-bit Rescale
Signed-off-by: Kevin Cheng <kevin.cheng@arm.com>
Change-Id: Ifc80b83c1abcd08e1b7f8e50f647b74c861bc933
-rw-r--r-- | reference_model/src/ops/type_conversion.cc | 42 | ||||
-rw-r--r-- | reference_model/src/quant_util.h | 17 |
2 files changed, 40 insertions, 19 deletions
diff --git a/reference_model/src/ops/type_conversion.cc b/reference_model/src/ops/type_conversion.cc
index 3a610ea..d988c57 100644
--- a/reference_model/src/ops/type_conversion.cc
+++ b/reference_model/src/ops/type_conversion.cc
@@ -71,9 +71,9 @@ int OpRescale<Rank, InDtype, OutDtype>::eval()
     int32_t output_zp = attribute->output_zp();
     std::vector<int32_t> multiplier = attribute->multiplier();
     std::vector<int32_t> shift = attribute->shift();
-    //bool scale32 = attribute->scale32();
-    bool double_round = attribute->double_round();
-    bool per_channel = attribute->per_channel();
+    bool scale32 = attribute->scale32();
+    bool double_round = attribute->double_round();
+    bool per_channel = attribute->per_channel();
 
     // reshape [d0, d1, ..., dn] into [d0 * d1 ..., dn]
     Eigen::array<Eigen::Index, 2> shape_2d;
@@ -94,7 +94,6 @@ int OpRescale<Rank, InDtype, OutDtype>::eval()
 
     ETensor2<OutEigenType> output_2d(shape_2d);
 
-    // TODO: pass scale32 in when 16-bit mode implemented
     if (per_channel)
     {
         ETensor2<InEigenType> curr_channel_slice_prescaled;
@@ -110,10 +109,15 @@ int OpRescale<Rank, InDtype, OutDtype>::eval()
             channel_shift = shift[i];
             curr_channel_slice_postscaled =
                 curr_channel_slice_prescaled.unaryExpr([input_zp, output_zp, channel_multiplier, channel_shift,
-                                                        double_round](InEigenType in_val) -> OutEigenType {
+                                                        double_round, scale32](InEigenType in_val) -> OutEigenType {
                     InEigenType input_zp_shifted = in_val - (InEigenType)input_zp;
-                    int32_t scaled = TosaReference::QuantUtil::apply_scale_32(input_zp_shifted, channel_multiplier,
-                                                                              channel_shift, double_round);
+                    int32_t scaled;
+                    if (scale32)
+                        scaled = TosaReference::QuantUtil::apply_scale_32(input_zp_shifted, channel_multiplier,
+                                                                          channel_shift, double_round);
+                    else
+                        scaled = TosaReference::QuantUtil::apply_scale_16(input_zp_shifted, channel_multiplier,
+                                                                          channel_shift);
                     OutEigenType out_val = (OutEigenType)(scaled + output_zp);
                     out_val = std::max<OutEigenType>(out_val, QMin);
                     out_val = std::min<OutEigenType>(out_val, QMax);
@@ -130,16 +134,20 @@ int OpRescale<Rank, InDtype, OutDtype>::eval()
     {
         int32_t tensor_multiplier = multiplier[0];
         int32_t tensor_shift = shift[0];
-        output_2d = input_reshaped.unaryExpr(
-            [input_zp, output_zp, tensor_multiplier, tensor_shift, double_round](InEigenType in_val) -> OutEigenType {
-                InEigenType input_zp_shifted = in_val - (InEigenType)input_zp;
-                int32_t scaled = TosaReference::QuantUtil::apply_scale_32(input_zp_shifted, tensor_multiplier,
-                                                                          tensor_shift, double_round);
-                OutEigenType out_val = (OutEigenType)(scaled + output_zp);
-                out_val = std::max<OutEigenType>(out_val, QMin);
-                out_val = std::min<OutEigenType>(out_val, QMax);
-                return out_val;
-            });
+        output_2d = input_reshaped.unaryExpr([input_zp, output_zp, tensor_multiplier, tensor_shift, double_round,
+                                              scale32](InEigenType in_val) -> OutEigenType {
+            InEigenType input_zp_shifted = in_val - (InEigenType)input_zp;
+            int32_t scaled;
+            if (scale32)
+                scaled = TosaReference::QuantUtil::apply_scale_32(input_zp_shifted, tensor_multiplier, tensor_shift,
+                                                                  double_round);
+            else
+                scaled = TosaReference::QuantUtil::apply_scale_16(input_zp_shifted, tensor_multiplier, tensor_shift);
+            OutEigenType out_val = (OutEigenType)(scaled + output_zp);
+            out_val = std::max<OutEigenType>(out_val, QMin);
+            out_val = std::min<OutEigenType>(out_val, QMax);
+            return out_val;
+        });
     }
 
     // reshape [d0 * d1 ..., dn] back to [d0, d1, ..., dn]
diff --git a/reference_model/src/quant_util.h b/reference_model/src/quant_util.h
index 1784493..f07dd10 100644
--- a/reference_model/src/quant_util.h
+++ b/reference_model/src/quant_util.h
@@ -61,6 +61,19 @@ public:
                    "apply_scale_32() error: scaled result exceed int32 numeric range");
         return static_cast<int32_t>(result);
     }
+
+    static int32_t apply_scale_16(int64_t value, int16_t multiplier, int32_t shift)
+    {
+        ASSERT_MSG(multiplier >= 0, "apply_scale_16() error: multiplier should >= 0 but is %d", multiplier);
+        ASSERT_MSG(value >= -(static_cast<int64_t>(1) << 47) && value < (static_cast<int64_t>(1) << 47),
+                   "apply_scale_16() error: value should be within [-(1^47), 1^47]");
+        int64_t round = 1L << (shift - 1);
+        int64_t result = value * (int64_t)multiplier + round;
+        result = result >> shift;
+        ASSERT_MSG(result >= -(1L << 31) && result < (1L << 31),
+                   "apply_scale_16() error: scaled result exceed int32 numeric range");
+        return static_cast<int32_t>(result);
+    }
 };
 
 class TypeChecker
@@ -68,8 +81,8 @@ class TypeChecker
 public:
     static bool is_integer(DType dtype)
     {
-        if (dtype == DType_INT4 || dtype == DType_INT8 || dtype == DType_UINT8 ||
-            dtype == DType_INT16 || dtype == DType_INT32 || dtype == DType_INT48)
+        if (dtype == DType_INT4 || dtype == DType_INT8 || dtype == DType_UINT8 || dtype == DType_INT16 ||
+            dtype == DType_INT32 || dtype == DType_INT48)
         {
             return true;
         }