From 2f6f3790fba2e81594acd7ed927515e0367c150e Mon Sep 17 00:00:00 2001 From: Fredrik Svedberg Date: Thu, 10 Sep 2020 16:12:33 +0200 Subject: [MLBEDSW-2845] Improve unit test coverage of fp_math Improved unit test coverage of fp_math.py Signed-off-by: Fredrik Svedberg Change-Id: I883fd984a1bfa67102826a400380e41a363fc59d --- ethosu/vela/fp_math.py | 49 ++++++++++++++++++++++++++++++++++++------------- 1 file changed, 36 insertions(+), 13 deletions(-) (limited to 'ethosu/vela/fp_math.py') diff --git a/ethosu/vela/fp_math.py b/ethosu/vela/fp_math.py index eaeb84a1..2515b771 100644 --- a/ethosu/vela/fp_math.py +++ b/ethosu/vela/fp_math.py @@ -21,29 +21,49 @@ # point implementation. import numpy as np +# Convert floating point to fixed point, default Q5.26 +def from_float(x, integer_bits=5): + i32info = np.iinfo(np.int32) + fractional_bits = i32info.bits - integer_bits - 1 + return min(max(round(x * (1 << fractional_bits)), i32info.min), i32info.max) + + +# Convert fixed point to floating point, default Q5.26 +def to_float(x, integer_bits=5): + fractional_bits = np.iinfo(np.int32).bits - integer_bits - 1 + return x / (1 << fractional_bits) + def saturating_rounding_mul(a, b): assert np.int32(a) == a assert np.int32(b) == b if a == b and a == np.iinfo(np.int32).min: return np.int32(np.iinfo(np.int32).max) - ab = np.int64(a) * np.int64(b) - nudge = (1 << 30) if ab >= 0 else (1 - (1 << 30)) - result = np.int32(np.right_shift(ab + nudge, 31)) - if result < 0: - result += 1 - return result + divider = 1 << 31 + ab = a * b + if ab >= 0: + nudge = 1 << 30 + return (ab + nudge) // divider + else: + nudge = 1 - (1 << 30) + ab_plus_nudge = ab + nudge + result = ab_plus_nudge // divider + # Python uses floor, the reference uses truncation + # so we need to compensate for that. + if result * divider < ab_plus_nudge: + result += 1 + return result def shift_left(a, offset): assert np.int32(a) == a assert offset >= 0 - a_info = np.iinfo(a) + i32_info = np.iinfo(np.int32) shifted = a * (1 << offset) - if shifted < a_info.min: - return np.int32(a_info.min) - elif shifted > a_info.max: - return np.int32(a_info.max) + if shifted < i32_info.min: + return np.int32(i32_info.min) + elif shifted > i32_info.max: + return np.int32(i32_info.max) else: return np.int32(shifted) @@ -62,7 +82,7 @@ def rounding_divide_by_pot(x, exponent): return result -def saturating_rounding_multiply_by_pot(exponent, x): +def saturating_rounding_multiply_by_pot(x, exponent): assert np.int32(x) == x assert np.int32(exponent) == exponent threshold = (1 << (np.iinfo(np.int32).bits - 1 - exponent)) - 1 @@ -79,7 +99,10 @@ def rescale(integer_bits_src, integer_bits_dst, x): assert np.int32(integer_bits_dst) == integer_bits_dst assert np.int32(x) == x exponent = integer_bits_src - integer_bits_dst - result = saturating_rounding_multiply_by_pot(exponent, x) + if exponent < 0: + result = rounding_divide_by_pot(x, -exponent) + else: + result = saturating_rounding_multiply_by_pot(x, exponent) return result -- cgit v1.2.1