aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorRickard Bolin <rickard.bolin@arm.com>2024-06-26 13:17:17 +0000
committerRickard Bolin <rickard.bolin@arm.com>2024-06-26 14:46:27 +0000
commit86ba3fd1eef6b6ad9047ef4aaed3461e0b250a1a (patch)
treecf1f2b5e47828b5f13842754ec007de10c6c16ab
parentc8450913d7f557a3cf6d547caebe691f41d2a00d (diff)
downloadethos-u-vela-86ba3fd1eef6b6ad9047ef4aaed3461e0b250a1a.tar.gz
MLBEDSW-9222: Fix rescale precision errors
After updating to NumPy 2.0, some variables that were previously implicitly promoted to doubles instead remained floats, causing precision loss. They are now explicitly promoted to doubles instead. Change-Id: Ibc1838e2a3a05116e291a6320525b38972f0478e Signed-off-by: Rickard Bolin <rickard.bolin@arm.com>
-rw-r--r--ethosu/vela/register_command_stream_generator.py4
1 files changed, 2 insertions, 2 deletions
diff --git a/ethosu/vela/register_command_stream_generator.py b/ethosu/vela/register_command_stream_generator.py
index 01654697..d96072bf 100644
--- a/ethosu/vela/register_command_stream_generator.py
+++ b/ethosu/vela/register_command_stream_generator.py
@@ -730,7 +730,7 @@ def generate_ofm_scaling_for_pooling(emit: CommandStreamEmitter, pool_op: NpuPoo
rescale = pool_op.rescale
rescale_bits = len(bin(round_up_to_int(rescale))) - 2 + 1
scale, shift = scaling.quantise_pooling_scale(kernel.height * kernel.width, rescale_bits)
- scale = int(round_away_zero(scale * rescale))
+ scale = int(round_away_zero(scale * np.double(rescale)))
else:
# In case avg pool fused with concat or other memory operation, rescaling might be needed.
# kernel height == kernel width == 1 is always true in this case
@@ -745,7 +745,7 @@ def generate_ofm_scaling_for_pooling(emit: CommandStreamEmitter, pool_op: NpuPoo
elif rescale < 1:
rescale_bits = -(len(bin(round_up_to_int(1 / rescale))) - 2 - 1)
scale, shift = scaling.quantise_pooling_scale(kernel.height * kernel.width, rescale_bits)
- scale = int(round_away_zero(scale * rescale))
+ scale = int(round_away_zero(scale * np.double(rescale)))
else:
scale = 1
shift = 0