aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorHenrik G Olsson <henrik.olsson@arm.com>2021-03-19 15:50:28 +0100
committerpatrik.gustavsson <patrik.gustavsson@arm.com>2021-04-16 06:56:08 +0000
commitad656a820394580d9e1b6d79597905074d599464 (patch)
treea209af2573e3670ebc1cb9b5585fa70f0b36300d
parent3645d009628bbb00185132e70d257d2c038716e7 (diff)
downloadethos-u-vela-ad656a820394580d9e1b6d79597905074d599464.tar.gz
MLBEDSW-3550 Only use simple scaling when bitexact with TFLite
For 8 bit arithmetic we cannot guarantee reproducibility in the general case since precision differs, affecting rounding near half integers. It should be safe when the ratio between output and input scales has its 12 LSBs all set to 0, however. For 16 bit arithmetic it should be sufficient to adjust the input and output scalings with a factor of 2 to get the same rounding. Signed-off-by: Henrik G Olsson <henrik.olsson@arm.com> Change-Id: I809c0042615d16c5488d61f0c7d88e1a1315e6eb
-rw-r--r--ethosu/vela/register_command_stream_generator.py24
1 files changed, 22 insertions, 2 deletions
diff --git a/ethosu/vela/register_command_stream_generator.py b/ethosu/vela/register_command_stream_generator.py
index fb705b96..3b552e09 100644
--- a/ethosu/vela/register_command_stream_generator.py
+++ b/ethosu/vela/register_command_stream_generator.py
@@ -718,19 +718,39 @@ def generate_scaling_for_elementwise(emit: CommandStreamEmitter, npu_op: NpuElem
ofm_scale, shift = scaling.elementwise_mul_scale(input_scale, input2_scale, output_scale)
emit.cmd1_with_offset(cmd1.NPU_SET_OFM_SCALE, ofm_scale, shift)
else: # Add/Sub
+ bitdepth = npu_op.ifm.data_type.size_in_bits()
+ use_advanced_scaling = False
if None in (input_scale, input2_scale, output_scale):
opa_scale = opb_scale = ofm_scale = 1
opa_shift = shift = 0
if npu_op.rescale is not None:
ofm_scale, shift = npu_op.rescale
+ elif input_scale == input2_scale and bitdepth == 16:
+ opa_scale, opb_scale, ofm_scale, shift = scaling.simplified_elementwise_add_sub_scale(
+ input_scale, input2_scale, output_scale
+ )
+ # align the double rounding with that of advanced scaling
+ opa_scale /= 2
+ opb_scale /= 2
+ shift -= 1
+ opa_shift = 0 # Unused for this case
elif input_scale == input2_scale:
opa_scale, opb_scale, ofm_scale, shift = scaling.simplified_elementwise_add_sub_scale(
input_scale, input2_scale, output_scale
)
opa_shift = 0 # Unused for this case
+ # For 8 bit we can't guarantee double rounding with simplified scaling will always be
+ # the same as with advanced scaling due to different shifts. When the ofm scale fulfils
+ # the following we know that double rounding will have no effect for advanced scaling
+ # no matter the input, so we can safely use simplified scaling with double rounding disabled.
+ use_advanced_scaling = int(ofm_scale) & 0xFFF != 0
+ if not use_advanced_scaling:
+ npu_op.rounding_mode = NpuRoundingMode.NATURAL
else:
- # Use advanced implementation only when input scales differ
- bitdepth = npu_op.ifm.data_type.size_in_bits()
+ use_advanced_scaling = True
+ if use_advanced_scaling:
+ # Use advanced implementation only when input/output scales differ,
+ # or when we can't guarantee the absence of rounding errors
(opa_scale, opa_shift, ofm_scale, shift, op_to_scale,) = scaling.advanced_elementwise_add_sub_scale(
input_scale, input2_scale, output_scale, bitdepth
)