aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--ethosu/vela/register_command_stream_generator.py4
1 files changed, 2 insertions, 2 deletions
diff --git a/ethosu/vela/register_command_stream_generator.py b/ethosu/vela/register_command_stream_generator.py
index 01654697..d96072bf 100644
--- a/ethosu/vela/register_command_stream_generator.py
+++ b/ethosu/vela/register_command_stream_generator.py
@@ -730,7 +730,7 @@ def generate_ofm_scaling_for_pooling(emit: CommandStreamEmitter, pool_op: NpuPoo
rescale = pool_op.rescale
rescale_bits = len(bin(round_up_to_int(rescale))) - 2 + 1
scale, shift = scaling.quantise_pooling_scale(kernel.height * kernel.width, rescale_bits)
- scale = int(round_away_zero(scale * rescale))
+ scale = int(round_away_zero(scale * np.double(rescale)))
else:
# In case avg pool fused with concat or other memory operation, rescaling might be needed.
# kernel height == kernel width == 1 is always true in this case
@@ -745,7 +745,7 @@ def generate_ofm_scaling_for_pooling(emit: CommandStreamEmitter, pool_op: NpuPoo
elif rescale < 1:
rescale_bits = -(len(bin(round_up_to_int(1 / rescale))) - 2 - 1)
scale, shift = scaling.quantise_pooling_scale(kernel.height * kernel.width, rescale_bits)
- scale = int(round_away_zero(scale * rescale))
+ scale = int(round_away_zero(scale * np.double(rescale)))
else:
scale = 1
shift = 0