diff options
Diffstat (limited to 'ethosu')
-rw-r--r-- | ethosu/vela/fp_math.py | 17 | ||||
-rw-r--r-- | ethosu/vela/tflite_graph_optimiser.py | 2 |
2 files changed, 11 insertions, 8 deletions
diff --git a/ethosu/vela/fp_math.py b/ethosu/vela/fp_math.py index 20a989b9..dca2b692 100644 --- a/ethosu/vela/fp_math.py +++ b/ethosu/vela/fp_math.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright 2020-2021 Arm Limited and/or its affiliates <open-source-office@arm.com> +# SPDX-FileCopyrightText: Copyright 2020-2021, 2024 Arm Limited and/or its affiliates <open-source-office@arm.com> # # Copyright 2015 The Gemmlowp Authors. All Rights Reserved. # @@ -102,7 +102,9 @@ def saturating_mul16(a, b): def shift_left32(a, offset): assert offset >= 0 assert np.int32(a) == a - shifted = a * (1 << offset) + # Force a and offset to Python int to avoid potential overflows when its type is unable to represent the result of + # the multiplication + shifted = int(a) * (1 << int(offset)) if shifted < np.iinfo(np.int32).min: return np.int32(np.iinfo(np.int32).min) elif shifted > np.iinfo(np.int32).max: @@ -114,7 +116,9 @@ def shift_left32(a, offset): def shift_left16(a, offset): assert offset >= 0 assert np.int16(a) == a - shifted = a * (1 << offset) + # Force a and offset to Python int to avoid potential overflows when its type is unable to represent the result of + # the multiplication + shifted = int(a) * (1 << int(offset)) if shifted < np.iinfo(np.int16).min: return np.int16(np.iinfo(np.int16).min) elif shifted > np.iinfo(np.int16).max: @@ -230,8 +234,7 @@ def multiply_by_quantized_multiplier(x, scale, shift): shift = 31 - shift left_shift = shift if shift > 0 else 0 right_shift = -shift if shift < 0 else 0 - # Force x to a Python int to avoid potential overflows when its type is unable to represent the result of the - # multiplication - x_int = int(x) - mul = saturating_rounding_mul32(x_int * (1 << left_shift), scale) + # Force x and left_shift to Python int to avoid potential overflows when its type is unable to represent the result + # of the multiplication + mul = saturating_rounding_mul32(int(x) * (1 << int(left_shift)), scale) return rounding_divide_by_pot(mul, right_shift) diff --git a/ethosu/vela/tflite_graph_optimiser.py b/ethosu/vela/tflite_graph_optimiser.py index ccbb1f28..13692d2a 100644 --- a/ethosu/vela/tflite_graph_optimiser.py +++ b/ethosu/vela/tflite_graph_optimiser.py @@ -1544,7 +1544,7 @@ def convert_hardswish_to_lut(op: Operation, arch, nng) -> Operation: # Rescaled the value into a 16bit fixedpoint relu_value in [-1, 1] # Now convert that to a 16bit fixedpoint value in [0, 1] - relu_value = (relu_value + (1 << 15)) >> 1 + relu_value = (int(relu_value) + (1 << 15)) >> 1 lut_result = fp_math.saturating_mul16(relu_value, input_value_preshift) shift = 31 - out_shift shift = -shift if shift < 0 else 0 |