diff options
-rw-r--r-- | ethosu/vela/register_command_stream_generator.py | 24 |
1 files changed, 22 insertions, 2 deletions
diff --git a/ethosu/vela/register_command_stream_generator.py b/ethosu/vela/register_command_stream_generator.py index fb705b96..3b552e09 100644 --- a/ethosu/vela/register_command_stream_generator.py +++ b/ethosu/vela/register_command_stream_generator.py @@ -718,19 +718,39 @@ def generate_scaling_for_elementwise(emit: CommandStreamEmitter, npu_op: NpuElem ofm_scale, shift = scaling.elementwise_mul_scale(input_scale, input2_scale, output_scale) emit.cmd1_with_offset(cmd1.NPU_SET_OFM_SCALE, ofm_scale, shift) else: # Add/Sub + bitdepth = npu_op.ifm.data_type.size_in_bits() + use_advanced_scaling = False if None in (input_scale, input2_scale, output_scale): opa_scale = opb_scale = ofm_scale = 1 opa_shift = shift = 0 if npu_op.rescale is not None: ofm_scale, shift = npu_op.rescale + elif input_scale == input2_scale and bitdepth == 16: + opa_scale, opb_scale, ofm_scale, shift = scaling.simplified_elementwise_add_sub_scale( + input_scale, input2_scale, output_scale + ) + # align the double rounding with that of advanced scaling + opa_scale /= 2 + opb_scale /= 2 + shift -= 1 + opa_shift = 0 # Unused for this case elif input_scale == input2_scale: opa_scale, opb_scale, ofm_scale, shift = scaling.simplified_elementwise_add_sub_scale( input_scale, input2_scale, output_scale ) opa_shift = 0 # Unused for this case + # For 8 bit we can't guarantee double rounding with simplified scaling will always be + # the same as with advanced scaling due to different shifts. When the ofm scale fulfils + # the following we know that double rounding will have no effect for advanced scaling + # no matter the input, so we can safely use simplified scaling with double rounding disabled. + use_advanced_scaling = int(ofm_scale) & 0xFFF != 0 + if not use_advanced_scaling: + npu_op.rounding_mode = NpuRoundingMode.NATURAL else: - # Use advanced implementation only when input scales differ - bitdepth = npu_op.ifm.data_type.size_in_bits() + use_advanced_scaling = True + if use_advanced_scaling: + # Use advanced implementation only when input/output scales differ, + # or when we can't guarantee the absence of rounding errors (opa_scale, opa_shift, ofm_scale, shift, op_to_scale,) = scaling.advanced_elementwise_add_sub_scale( input_scale, input2_scale, output_scale, bitdepth ) |