From ad656a820394580d9e1b6d79597905074d599464 Mon Sep 17 00:00:00 2001 From: Henrik G Olsson Date: Fri, 19 Mar 2021 15:50:28 +0100 Subject: MLBEDSW-3550 Only use simple scaling when bitexact with TFLite For 8 bit arithmetic we cannot guarantee reproducibility in the general case since precision differs, affecting rounding near half integers. It should be safe when the ratio between output and input scales has its 12 LSBs all set to 0, however. For 16 bit arithmetic it should be sufficient to adjust the input and output scalings with a factor of 2 to get the same rounding. Signed-off-by: Henrik G Olsson Change-Id: I809c0042615d16c5488d61f0c7d88e1a1315e6eb --- ethosu/vela/register_command_stream_generator.py | 24 ++++++++++++++++++++++-- 1 file changed, 22 insertions(+), 2 deletions(-) diff --git a/ethosu/vela/register_command_stream_generator.py b/ethosu/vela/register_command_stream_generator.py index fb705b96..3b552e09 100644 --- a/ethosu/vela/register_command_stream_generator.py +++ b/ethosu/vela/register_command_stream_generator.py @@ -718,19 +718,39 @@ def generate_scaling_for_elementwise(emit: CommandStreamEmitter, npu_op: NpuElem ofm_scale, shift = scaling.elementwise_mul_scale(input_scale, input2_scale, output_scale) emit.cmd1_with_offset(cmd1.NPU_SET_OFM_SCALE, ofm_scale, shift) else: # Add/Sub + bitdepth = npu_op.ifm.data_type.size_in_bits() + use_advanced_scaling = False if None in (input_scale, input2_scale, output_scale): opa_scale = opb_scale = ofm_scale = 1 opa_shift = shift = 0 if npu_op.rescale is not None: ofm_scale, shift = npu_op.rescale + elif input_scale == input2_scale and bitdepth == 16: + opa_scale, opb_scale, ofm_scale, shift = scaling.simplified_elementwise_add_sub_scale( + input_scale, input2_scale, output_scale + ) + # align the double rounding with that of advanced scaling + opa_scale /= 2 + opb_scale /= 2 + shift -= 1 + opa_shift = 0 # Unused for this case elif input_scale == input2_scale: opa_scale, opb_scale, ofm_scale, shift = scaling.simplified_elementwise_add_sub_scale( input_scale, input2_scale, output_scale ) opa_shift = 0 # Unused for this case + # For 8 bit we can't guarantee double rounding with simplified scaling will always be + # the same as with advanced scaling due to different shifts. When the ofm scale fulfils + # the following we know that double rounding will have no effect for advanced scaling + # no matter the input, so we can safely use simplified scaling with double rounding disabled. + use_advanced_scaling = int(ofm_scale) & 0xFFF != 0 + if not use_advanced_scaling: + npu_op.rounding_mode = NpuRoundingMode.NATURAL else: - # Use advanced implementation only when input scales differ - bitdepth = npu_op.ifm.data_type.size_in_bits() + use_advanced_scaling = True + if use_advanced_scaling: + # Use advanced implementation only when input/output scales differ, + # or when we can't guarantee the absence of rounding errors (opa_scale, opa_shift, ofm_scale, shift, op_to_scale,) = scaling.advanced_elementwise_add_sub_scale( input_scale, input2_scale, output_scale, bitdepth ) -- cgit v1.2.1