aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--ethosu/vela/register_command_stream_generator.py24
1 files changed, 22 insertions, 2 deletions
diff --git a/ethosu/vela/register_command_stream_generator.py b/ethosu/vela/register_command_stream_generator.py
index fb705b96..3b552e09 100644
--- a/ethosu/vela/register_command_stream_generator.py
+++ b/ethosu/vela/register_command_stream_generator.py
@@ -718,19 +718,39 @@ def generate_scaling_for_elementwise(emit: CommandStreamEmitter, npu_op: NpuElem
ofm_scale, shift = scaling.elementwise_mul_scale(input_scale, input2_scale, output_scale)
emit.cmd1_with_offset(cmd1.NPU_SET_OFM_SCALE, ofm_scale, shift)
else: # Add/Sub
+ bitdepth = npu_op.ifm.data_type.size_in_bits()
+ use_advanced_scaling = False
if None in (input_scale, input2_scale, output_scale):
opa_scale = opb_scale = ofm_scale = 1
opa_shift = shift = 0
if npu_op.rescale is not None:
ofm_scale, shift = npu_op.rescale
+ elif input_scale == input2_scale and bitdepth == 16:
+ opa_scale, opb_scale, ofm_scale, shift = scaling.simplified_elementwise_add_sub_scale(
+ input_scale, input2_scale, output_scale
+ )
+ # align the double rounding with that of advanced scaling
+ opa_scale /= 2
+ opb_scale /= 2
+ shift -= 1
+ opa_shift = 0 # Unused for this case
elif input_scale == input2_scale:
opa_scale, opb_scale, ofm_scale, shift = scaling.simplified_elementwise_add_sub_scale(
input_scale, input2_scale, output_scale
)
opa_shift = 0 # Unused for this case
+ # For 8 bit we can't guarantee double rounding with simplified scaling will always be
+ # the same as with advanced scaling due to different shifts. When the ofm scale fulfils
+ # the following we know that double rounding will have no effect for advanced scaling
+ # no matter the input, so we can safely use simplified scaling with double rounding disabled.
+ use_advanced_scaling = int(ofm_scale) & 0xFFF != 0
+ if not use_advanced_scaling:
+ npu_op.rounding_mode = NpuRoundingMode.NATURAL
else:
- # Use advanced implementation only when input scales differ
- bitdepth = npu_op.ifm.data_type.size_in_bits()
+ use_advanced_scaling = True
+ if use_advanced_scaling:
+ # Use advanced implementation only when input/output scales differ,
+ # or when we can't guarantee the absence of rounding errors
(opa_scale, opa_shift, ofm_scale, shift, op_to_scale,) = scaling.advanced_elementwise_add_sub_scale(
input_scale, input2_scale, output_scale, bitdepth
)