aboutsummaryrefslogtreecommitdiff
path: root/ethosu/vela/fp_math.py
diff options
context:
space:
mode:
Diffstat (limited to 'ethosu/vela/fp_math.py')
-rw-r--r--ethosu/vela/fp_math.py95
1 files changed, 79 insertions, 16 deletions
diff --git a/ethosu/vela/fp_math.py b/ethosu/vela/fp_math.py
index 5228f031..21022c2a 100644
--- a/ethosu/vela/fp_math.py
+++ b/ethosu/vela/fp_math.py
@@ -35,13 +35,14 @@ def to_float(x, integer_bits=5):
return x / (1 << fractional_bits)
-def saturating_rounding_mul(a, b):
+def saturating_rounding_mul32(a, b):
assert np.int32(a) == a
assert np.int32(b) == b
if a == b and a == np.iinfo(np.int32).min:
return np.int32(np.iinfo(np.int32).max)
divider = 1 << 31
ab = np.int64(a) * np.int64(b)
+
if ab >= 0:
nudge = 1 << 30
return (ab + nudge) // divider
@@ -56,19 +57,81 @@ def saturating_rounding_mul(a, b):
return result
-def shift_left(a, offset):
- assert np.int32(a) == a
+def saturating_rounding_mul16(a, b):
+ assert np.int16(a) == a
+ assert np.int16(b) == b
+ if a == b and a == np.iinfo(np.int16).min:
+ return np.int16(np.iinfo(np.int16).max)
+ divider = 1 << 15
+ ab = np.int32(a) * np.int32(b)
+
+ if ab >= 0:
+ nudge = 1 << 14
+ return (ab + nudge) // divider
+ else:
+ nudge = 1 - (1 << 14)
+ ab_plus_nudge = ab + nudge
+ result = ab_plus_nudge // divider
+ # Python uses floor, the reference uses truncation
+ # so we need to compensate for that.
+ if result * divider < ab_plus_nudge:
+ result += 1
+ return result
+
+
+# Similar to saturating_rounding_mul16 except rounding to zero instead of to nearest
+# Only supports 16bit
+def saturating_mul16(a, b):
+ assert np.int16(a) == a
+ assert np.int16(b) == b
+ if a == b and a == np.iinfo(np.int16).min:
+ return np.int16(np.iinfo(np.int16).max)
+ ab = np.int32(a) * np.int32(b)
+ divider = 1 << 15
+ if ab >= 0:
+ return ab // divider
+ else:
+ result = ab // divider
+ # Python uses floor, the reference uses truncation
+ # so we need to compensate for that.
+ if result * divider < ab:
+ result += 1
+ return result
+
+
+def shift_left32(a, offset):
assert offset >= 0
- i32_info = np.iinfo(np.int32)
+ assert np.int32(a) == a
shifted = a * (1 << offset)
- if shifted < i32_info.min:
- return np.int32(i32_info.min)
- elif shifted > i32_info.max:
- return np.int32(i32_info.max)
+ if shifted < np.iinfo(np.int32).min:
+ return np.int32(np.iinfo(np.int32).min)
+ elif shifted > np.iinfo(np.int32).max:
+ return np.int32(np.iinfo(np.int32).max)
else:
return np.int32(shifted)
+def shift_left16(a, offset):
+ assert offset >= 0
+ assert np.int16(a) == a
+ shifted = a * (1 << offset)
+ if shifted < np.iinfo(np.int16).min:
+ return np.int16(np.iinfo(np.int16).min)
+ elif shifted > np.iinfo(np.int16).max:
+ return np.int16(np.iinfo(np.int16).max)
+ else:
+ return np.int16(shifted)
+
+
+def downscale_multiplier_int32_to_int16(a):
+ assert np.int32(a) == a
+ rounding_offset = 1 << 15
+ if a >= np.iinfo(np.int32).max - rounding_offset:
+ return np.iinfo(np.int16).max
+ else:
+ return np.int16((a + rounding_offset) >> 16)
+
+
def rounding_divide_by_pot(x, exponent):
assert np.int32(x) == x
assert np.int32(exponent) == exponent
@@ -92,7 +155,7 @@ def saturating_rounding_multiply_by_pot(x, exponent):
elif x < -threshold:
return np.iinfo(np.int32).min
else:
- return shift_left(x, exponent)
+ return shift_left32(x, exponent)
def rescale(integer_bits_src, integer_bits_dst, x):
@@ -115,16 +178,16 @@ def exp_on_interval_between_negative_one_quarter_and_0_excl(a):
constant_term = 1895147668
constant_1_over_3 = 715827883
x = a + (1 << offset)
- x2 = saturating_rounding_mul(x, x)
- x3 = saturating_rounding_mul(x2, x)
- x4 = saturating_rounding_mul(x2, x2)
+ x2 = saturating_rounding_mul32(x, x)
+ x3 = saturating_rounding_mul32(x2, x)
+ x4 = saturating_rounding_mul32(x2, x2)
x4_over_4 = rounding_divide_by_pot(x4, 2)
x4_over_24_plus_x3_over_6_plus_x2_over_2 = rounding_divide_by_pot(
- saturating_rounding_mul((x4_over_4 + x3), constant_1_over_3) + x2, 1
+ saturating_rounding_mul32((x4_over_4 + x3), constant_1_over_3) + x2, 1
)
return np.int32(
- constant_term + saturating_rounding_mul(constant_term, x + x4_over_24_plus_x3_over_6_plus_x2_over_2)
+ constant_term + saturating_rounding_mul32(constant_term, x + x4_over_24_plus_x3_over_6_plus_x2_over_2)
)
@@ -144,7 +207,7 @@ def exp_on_negative_values(a):
integer_bits = 5
shift = fractional_bits + exponent if integer_bits > exponent else 0
if remainder & (1 << shift):
- return saturating_rounding_mul(result, multiplier)
+ return saturating_rounding_mul32(result, multiplier)
else:
return result
@@ -168,5 +231,5 @@ def multiply_by_quantized_multiplier(x, scale, shift):
shift = 31 - shift
left_shift = shift if shift > 0 else 0
right_shift = -shift if shift < 0 else 0
- mul = saturating_rounding_mul(x * (1 << left_shift), scale)
+ mul = saturating_rounding_mul32(x * (1 << left_shift), scale)
return rounding_divide_by_pot(mul, right_shift)