aboutsummaryrefslogtreecommitdiff
path: root/ethosu
diff options
context:
space:
mode:
Diffstat (limited to 'ethosu')
-rw-r--r--ethosu/vela/fp_math.py17
-rw-r--r--ethosu/vela/tflite_graph_optimiser.py2
2 files changed, 11 insertions, 8 deletions
diff --git a/ethosu/vela/fp_math.py b/ethosu/vela/fp_math.py
index 20a989b9..dca2b692 100644
--- a/ethosu/vela/fp_math.py
+++ b/ethosu/vela/fp_math.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright 2020-2021 Arm Limited and/or its affiliates <open-source-office@arm.com>
+# SPDX-FileCopyrightText: Copyright 2020-2021, 2024 Arm Limited and/or its affiliates <open-source-office@arm.com>
#
# Copyright 2015 The Gemmlowp Authors. All Rights Reserved.
#
@@ -102,7 +102,9 @@ def saturating_mul16(a, b):
def shift_left32(a, offset):
assert offset >= 0
assert np.int32(a) == a
- shifted = a * (1 << offset)
+ # Force a and offset to Python int to avoid potential overflows when its type is unable to represent the result of
+ # the multiplication
+ shifted = int(a) * (1 << int(offset))
if shifted < np.iinfo(np.int32).min:
return np.int32(np.iinfo(np.int32).min)
elif shifted > np.iinfo(np.int32).max:
@@ -114,7 +116,9 @@ def shift_left32(a, offset):
def shift_left16(a, offset):
assert offset >= 0
assert np.int16(a) == a
- shifted = a * (1 << offset)
+ # Force a and offset to Python int to avoid potential overflows when its type is unable to represent the result of
+ # the multiplication
+ shifted = int(a) * (1 << int(offset))
if shifted < np.iinfo(np.int16).min:
return np.int16(np.iinfo(np.int16).min)
elif shifted > np.iinfo(np.int16).max:
@@ -230,8 +234,7 @@ def multiply_by_quantized_multiplier(x, scale, shift):
shift = 31 - shift
left_shift = shift if shift > 0 else 0
right_shift = -shift if shift < 0 else 0
- # Force x to a Python int to avoid potential overflows when its type is unable to represent the result of the
- # multiplication
- x_int = int(x)
- mul = saturating_rounding_mul32(x_int * (1 << left_shift), scale)
+ # Force x and left_shift to Python int to avoid potential overflows when its type is unable to represent the result
+ # of the multiplication
+ mul = saturating_rounding_mul32(int(x) * (1 << int(left_shift)), scale)
return rounding_divide_by_pot(mul, right_shift)
diff --git a/ethosu/vela/tflite_graph_optimiser.py b/ethosu/vela/tflite_graph_optimiser.py
index ccbb1f28..13692d2a 100644
--- a/ethosu/vela/tflite_graph_optimiser.py
+++ b/ethosu/vela/tflite_graph_optimiser.py
@@ -1544,7 +1544,7 @@ def convert_hardswish_to_lut(op: Operation, arch, nng) -> Operation:
# Rescaled the value into a 16bit fixedpoint relu_value in [-1, 1]
# Now convert that to a 16bit fixedpoint value in [0, 1]
- relu_value = (relu_value + (1 << 15)) >> 1
+ relu_value = (int(relu_value) + (1 << 15)) >> 1
lut_result = fp_math.saturating_mul16(relu_value, input_value_preshift)
shift = 31 - out_shift
shift = -shift if shift < 0 else 0