diff options
author | Fredrik Svedberg <fredrik.svedberg@arm.com> | 2020-09-10 16:12:33 +0200 |
---|---|---|
committer | patrik.gustavsson <patrik.gustavsson@arm.com> | 2020-09-14 15:28:05 +0000 |
commit | 2f6f3790fba2e81594acd7ed927515e0367c150e (patch) | |
tree | a6923c98a6e63efd76e4181e8158cf2fdc07b577 /ethosu | |
parent | 55d9e33c77589d61cdcfda5fedb57fb67ff0c55a (diff) | |
download | ethos-u-vela-2f6f3790fba2e81594acd7ed927515e0367c150e.tar.gz |
[MLBEDSW-2845] Improve unit test coverage of fp_math
Improved unit test coverage of fp_math.py
Signed-off-by: Fredrik Svedberg <fredrik.svedberg@arm.com>
Change-Id: I883fd984a1bfa67102826a400380e41a363fc59d
Diffstat (limited to 'ethosu')
-rw-r--r-- | ethosu/vela/fp_math.py | 49 | ||||
-rw-r--r-- | ethosu/vela/test/test_fp_math.py | 118 |
2 files changed, 122 insertions, 45 deletions
diff --git a/ethosu/vela/fp_math.py b/ethosu/vela/fp_math.py index eaeb84a1..2515b771 100644 --- a/ethosu/vela/fp_math.py +++ b/ethosu/vela/fp_math.py @@ -21,29 +21,49 @@ # point implementation. import numpy as np +# Convert floating point to fixed point, default Q5.26 +def from_float(x, integer_bits=5): + i32info = np.iinfo(np.int32) + fractional_bits = i32info.bits - integer_bits - 1 + return min(max(round(x * (1 << fractional_bits)), i32info.min), i32info.max) + + +# Convert fixed point to floating point, default Q5.26 +def to_float(x, integer_bits=5): + fractional_bits = np.iinfo(np.int32).bits - integer_bits - 1 + return x / (1 << fractional_bits) + def saturating_rounding_mul(a, b): assert np.int32(a) == a assert np.int32(b) == b if a == b and a == np.iinfo(np.int32).min: return np.int32(np.iinfo(np.int32).max) - ab = np.int64(a) * np.int64(b) - nudge = (1 << 30) if ab >= 0 else (1 - (1 << 30)) - result = np.int32(np.right_shift(ab + nudge, 31)) - if result < 0: - result += 1 - return result + divider = 1 << 31 + ab = a * b + if ab >= 0: + nudge = 1 << 30 + return (ab + nudge) // divider + else: + nudge = 1 - (1 << 30) + ab_plus_nudge = ab + nudge + result = ab_plus_nudge // divider + # Python uses floor, the reference uses truncation + # so we need to compensate for that. + if result * divider < ab_plus_nudge: + result += 1 + return result def shift_left(a, offset): assert np.int32(a) == a assert offset >= 0 - a_info = np.iinfo(a) + i32_info = np.iinfo(np.int32) shifted = a * (1 << offset) - if shifted < a_info.min: - return np.int32(a_info.min) - elif shifted > a_info.max: - return np.int32(a_info.max) + if shifted < i32_info.min: + return np.int32(i32_info.min) + elif shifted > i32_info.max: + return np.int32(i32_info.max) else: return np.int32(shifted) @@ -62,7 +82,7 @@ def rounding_divide_by_pot(x, exponent): return result -def saturating_rounding_multiply_by_pot(exponent, x): +def saturating_rounding_multiply_by_pot(x, exponent): assert np.int32(x) == x assert np.int32(exponent) == exponent threshold = (1 << (np.iinfo(np.int32).bits - 1 - exponent)) - 1 @@ -79,7 +99,10 @@ def rescale(integer_bits_src, integer_bits_dst, x): assert np.int32(integer_bits_dst) == integer_bits_dst assert np.int32(x) == x exponent = integer_bits_src - integer_bits_dst - result = saturating_rounding_multiply_by_pot(exponent, x) + if exponent < 0: + result = rounding_divide_by_pot(x, -exponent) + else: + result = saturating_rounding_multiply_by_pot(x, exponent) return result diff --git a/ethosu/vela/test/test_fp_math.py b/ethosu/vela/test/test_fp_math.py index 8c1ed679..905826f4 100644 --- a/ethosu/vela/test/test_fp_math.py +++ b/ethosu/vela/test/test_fp_math.py @@ -64,53 +64,107 @@ EXP_LUT = [ def test_saturating_rounding_mul(): i32info = np.iinfo(np.int32) - shift = 22 - multiplier = 1760306048 + # Saturation assert fp_math.saturating_rounding_mul(i32info.min, i32info.min) == i32info.max - assert fp_math.saturating_rounding_mul(-255 * 1 << shift, multiplier) == -876714926 - assert fp_math.saturating_rounding_mul(-128 * 1 << shift, multiplier) == -440076512 - assert fp_math.saturating_rounding_mul(0, multiplier) == 0 - assert fp_math.saturating_rounding_mul(128 * 1 << shift, multiplier) == 440076512 - assert fp_math.saturating_rounding_mul(255 * 1 << shift, multiplier) == 876714926 + assert fp_math.saturating_rounding_mul(i32info.min, i32info.max) == -i32info.max + assert fp_math.saturating_rounding_mul(i32info.max, i32info.min) == -i32info.max + + # Multiply by zero + assert fp_math.saturating_rounding_mul(0, fp_math.from_float(1.0)) == 0 + assert fp_math.saturating_rounding_mul(0, fp_math.from_float(-1.0)) == 0 + assert fp_math.saturating_rounding_mul(fp_math.from_float(1.0), 0) == 0 + assert fp_math.saturating_rounding_mul(fp_math.from_float(-1.0), 0) == 0 + + # Multiply positive/negative + assert fp_math.saturating_rounding_mul(fp_math.from_float(1.0), fp_math.from_float(1.0)) == fp_math.from_float( + 1.0, 5 + 5 + ) + assert fp_math.saturating_rounding_mul(fp_math.from_float(-1.0), fp_math.from_float(1.0)) == fp_math.from_float( + -1.0, 5 + 5 + ) + assert fp_math.saturating_rounding_mul(fp_math.from_float(1.0), fp_math.from_float(-1.0)) == fp_math.from_float( + -1.0, 5 + 5 + ) + assert fp_math.saturating_rounding_mul(fp_math.from_float(-1.0), fp_math.from_float(-1.0)) == fp_math.from_float( + 1.0, 5 + 5 + ) + + # Rounding + assert fp_math.saturating_rounding_mul(fp_math.from_float(16.0), 1) == 1 + assert fp_math.saturating_rounding_mul(fp_math.from_float(-16.0), 1) == 0 + assert fp_math.saturating_rounding_mul(fp_math.from_float(16.0) - 1, 1) == 0 + assert fp_math.saturating_rounding_mul(fp_math.from_float(-16.0) - 1, 1) == -1 def test_shift_left(): i32info = np.iinfo(np.int32) - assert fp_math.shift_left(np.int32(1), i32info.bits) == i32info.max - assert fp_math.shift_left(np.int32(-1), i32info.bits) == i32info.min - assert fp_math.shift_left(np.int32(1), i32info.bits - 2) == (i32info.max + 1) / 2 - assert fp_math.shift_left(np.int32(-1), i32info.bits - 2) == i32info.min // 2 + assert fp_math.shift_left(1, i32info.bits) == i32info.max + assert fp_math.shift_left(-1, i32info.bits) == i32info.min + assert fp_math.shift_left(1, i32info.bits - 2) == (i32info.max + 1) / 2 + assert fp_math.shift_left(-1, i32info.bits - 2) == i32info.min // 2 + + assert fp_math.shift_left(fp_math.from_float(1.0), 5) == i32info.max + assert fp_math.shift_left(fp_math.from_float(-1.0), 5) == i32info.min + assert fp_math.shift_left(fp_math.from_float(1.0), 4) == 16 * fp_math.from_float(1.0) + assert fp_math.shift_left(fp_math.from_float(-1.0), 4) == 16 * fp_math.from_float(-1.0) + + with pytest.raises(AssertionError): + fp_math.shift_left(1, -1) def test_rounding_divide_by_pot(): - assert fp_math.rounding_divide_by_pot(1024, 4) == 64 - assert fp_math.rounding_divide_by_pot(1031, 4) == 64 - assert fp_math.rounding_divide_by_pot(1032, 4) == 65 - assert fp_math.rounding_divide_by_pot(1047, 4) == 65 - assert fp_math.rounding_divide_by_pot(1048, 4) == 66 - assert fp_math.rounding_divide_by_pot(1056, 4) == 66 - assert fp_math.rounding_divide_by_pot(-1024, 4) == -64 - assert fp_math.rounding_divide_by_pot(-1031, 4) == -64 - assert fp_math.rounding_divide_by_pot(-1032, 4) == -65 - assert fp_math.rounding_divide_by_pot(-1047, 4) == -65 - assert fp_math.rounding_divide_by_pot(-1048, 4) == -66 - assert fp_math.rounding_divide_by_pot(-1056, 4) == -66 + # No remainder division + assert fp_math.rounding_divide_by_pot(fp_math.from_float(1.0), 26) == 1 + assert fp_math.rounding_divide_by_pot(fp_math.from_float(-1.0), 26) == -1 + + # Remainder rounding the result away from zero + assert fp_math.rounding_divide_by_pot(fp_math.from_float(-1.0), 27) == -1 + assert fp_math.rounding_divide_by_pot(fp_math.from_float(1.0), 27) == 1 + + # Remainder smaller than threshold to round the result away from zero + # Positive and negative edge cases + assert fp_math.rounding_divide_by_pot(fp_math.from_float(1.0) - 1, 27) == 0 + assert fp_math.rounding_divide_by_pot(fp_math.from_float(-1.0) + 1, 27) == 0 + # Far from the edge + assert fp_math.rounding_divide_by_pot(fp_math.from_float(1.0), 28) == 0 + assert fp_math.rounding_divide_by_pot(fp_math.from_float(-1.0), 28) == 0 + + # Regular division - no remainder + assert fp_math.rounding_divide_by_pot(fp_math.from_float(1.0), 4) == fp_math.from_float(1.0 / 16) + assert fp_math.rounding_divide_by_pot(fp_math.from_float(-1.0), 4) == fp_math.from_float(-1.0 / 16) + + # Rounding/no rounding edge cases + assert fp_math.rounding_divide_by_pot(fp_math.from_float(1.0) + (1 << 3) - 1, 4) == fp_math.from_float(1.0 / 16) + assert fp_math.rounding_divide_by_pot(fp_math.from_float(1.0) + (1 << 3), 4) == fp_math.from_float(1.0 / 16) + 1 + assert fp_math.rounding_divide_by_pot(fp_math.from_float(-1.0) - (1 << 3) + 1, 4) == fp_math.from_float(-1.0 / 16) + assert fp_math.rounding_divide_by_pot(fp_math.from_float(-1.0) - (1 << 3), 4) == fp_math.from_float(-1.0 / 16) - 1 def test_saturating_rounding_multiply_by_pot(): i32info = np.iinfo(np.int32) - assert fp_math.saturating_rounding_multiply_by_pot(4, np.int32(1025)) == 16400 - assert fp_math.saturating_rounding_multiply_by_pot(5, np.int32(67108865)) == i32info.max - assert fp_math.saturating_rounding_multiply_by_pot(5, np.int32(-67108865)) == i32info.min + assert fp_math.saturating_rounding_multiply_by_pot(fp_math.from_float(1.0), 5) == i32info.max + assert fp_math.saturating_rounding_multiply_by_pot(fp_math.from_float(-1.0), 5) == i32info.min + assert fp_math.saturating_rounding_multiply_by_pot(fp_math.from_float(1.0) - 1, 5) == i32info.max - 32 + 1 + assert fp_math.saturating_rounding_multiply_by_pot(fp_math.from_float(-1.0) + 1, 5) == -i32info.max + 32 - 1 + assert fp_math.saturating_rounding_multiply_by_pot(fp_math.from_float(1.0), 4) == fp_math.from_float(1.0 * 16) + assert fp_math.saturating_rounding_multiply_by_pot(fp_math.from_float(-1.0), 4) == fp_math.from_float(-1.0 * 16) def test_rescale(): - assert fp_math.rescale(5, 0, np.int32(1025)) == 32800 - assert fp_math.rescale(3, 0, np.int32(1025)) == 8200 - assert fp_math.rescale(5, 1, np.int32(1025)) == 16400 - assert fp_math.rescale(3, 1, np.int32(1025)) == 4100 - with pytest.raises(AssertionError): - fp_math.rescale(1, 3, np.int32(1024)) + assert fp_math.rescale(5, 0, fp_math.from_float(1.0)) == fp_math.from_float(1.0, 0) + assert fp_math.rescale(5, 10, fp_math.from_float(1.0)) == fp_math.from_float(1.0, 10) + assert fp_math.rescale(5, 0, fp_math.from_float(-1.0)) == fp_math.from_float(-1.0, 0) + assert fp_math.rescale(5, 10, fp_math.from_float(-1.0)) == fp_math.from_float(-1.0, 10) + + assert fp_math.rescale(5, 4, fp_math.from_float(32.0)) == fp_math.from_float(32.0, 4) + assert fp_math.rescale(5, 6, fp_math.from_float(32.0)) == fp_math.from_float(32.0, 6) + assert fp_math.rescale(5, 4, fp_math.from_float(-32.0)) == fp_math.from_float(-32.0, 4) + assert fp_math.rescale(5, 6, fp_math.from_float(-32.0)) == fp_math.from_float(-32.0, 6) + + assert fp_math.rescale(5, 4, fp_math.from_float(31.9)) == fp_math.from_float(31.9, 4) + assert fp_math.rescale(5, 6, fp_math.from_float(31.9)) == fp_math.from_float(31.9, 6) + assert fp_math.rescale(5, 4, fp_math.from_float(-31.9)) == fp_math.from_float(-31.9, 4) + assert fp_math.rescale(5, 6, fp_math.from_float(-31.9)) == fp_math.from_float(-31.9, 6) def test_exp(): |