From 189f748e1a79ed88044efbe7137963bca830cbb5 Mon Sep 17 00:00:00 2001 From: Diqing Zhong Date: Tue, 26 Jan 2021 12:12:51 +0100 Subject: MLBEDSW-3224: Support HardSwish Change-Id: If49abc31f093f1bd3393bee86f821fd35972086f Signed-off-by: Diqing Zhong --- ethosu/vela/fp_math.py | 95 +++++++++++++++++++++++++++++++++++++++++--------- 1 file changed, 79 insertions(+), 16 deletions(-) (limited to 'ethosu/vela/fp_math.py') diff --git a/ethosu/vela/fp_math.py b/ethosu/vela/fp_math.py index 5228f031..21022c2a 100644 --- a/ethosu/vela/fp_math.py +++ b/ethosu/vela/fp_math.py @@ -35,13 +35,14 @@ def to_float(x, integer_bits=5): return x / (1 << fractional_bits) -def saturating_rounding_mul(a, b): +def saturating_rounding_mul32(a, b): assert np.int32(a) == a assert np.int32(b) == b if a == b and a == np.iinfo(np.int32).min: return np.int32(np.iinfo(np.int32).max) divider = 1 << 31 ab = np.int64(a) * np.int64(b) + if ab >= 0: nudge = 1 << 30 return (ab + nudge) // divider @@ -56,19 +57,81 @@ def saturating_rounding_mul(a, b): return result -def shift_left(a, offset): - assert np.int32(a) == a +def saturating_rounding_mul16(a, b): + assert np.int16(a) == a + assert np.int16(b) == b + if a == b and a == np.iinfo(np.int16).min: + return np.int16(np.iinfo(np.int16).max) + divider = 1 << 15 + ab = np.int32(a) * np.int32(b) + + if ab >= 0: + nudge = 1 << 14 + return (ab + nudge) // divider + else: + nudge = 1 - (1 << 14) + ab_plus_nudge = ab + nudge + result = ab_plus_nudge // divider + # Python uses floor, the reference uses truncation + # so we need to compensate for that. + if result * divider < ab_plus_nudge: + result += 1 + return result + + +# Similar to saturating_rounding_mul16 except rounding to zero instead of to nearest +# Only supports 16bit +def saturating_mul16(a, b): + assert np.int16(a) == a + assert np.int16(b) == b + if a == b and a == np.iinfo(np.int16).min: + return np.int16(np.iinfo(np.int16).max) + ab = np.int32(a) * np.int32(b) + divider = 1 << 15 + if ab >= 0: + return ab // divider + else: + result = ab // divider + # Python uses floor, the reference uses truncation + # so we need to compensate for that. + if result * divider < ab: + result += 1 + return result + + +def shift_left32(a, offset): assert offset >= 0 - i32_info = np.iinfo(np.int32) + assert np.int32(a) == a shifted = a * (1 << offset) - if shifted < i32_info.min: - return np.int32(i32_info.min) - elif shifted > i32_info.max: - return np.int32(i32_info.max) + if shifted < np.iinfo(np.int32).min: + return np.int32(np.iinfo(np.int32).min) + elif shifted > np.iinfo(np.int32).max: + return np.int32(np.iinfo(np.int32).max) else: return np.int32(shifted) +def shift_left16(a, offset): + assert offset >= 0 + assert np.int16(a) == a + shifted = a * (1 << offset) + if shifted < np.iinfo(np.int16).min: + return np.int16(np.iinfo(np.int16).min) + elif shifted > np.iinfo(np.int16).max: + return np.int16(np.iinfo(np.int16).max) + else: + return np.int16(shifted) + + +def downscale_multiplier_int32_to_int16(a): + assert np.int32(a) == a + rounding_offset = 1 << 15 + if a >= np.iinfo(np.int32).max - rounding_offset: + return np.iinfo(np.int16).max + else: + return np.int16((a + rounding_offset) >> 16) + + def rounding_divide_by_pot(x, exponent): assert np.int32(x) == x assert np.int32(exponent) == exponent @@ -92,7 +155,7 @@ def saturating_rounding_multiply_by_pot(x, exponent): elif x < -threshold: return np.iinfo(np.int32).min else: - return shift_left(x, exponent) + return shift_left32(x, exponent) def rescale(integer_bits_src, integer_bits_dst, x): @@ -115,16 +178,16 @@ def exp_on_interval_between_negative_one_quarter_and_0_excl(a): constant_term = 1895147668 constant_1_over_3 = 715827883 x = a + (1 << offset) - x2 = saturating_rounding_mul(x, x) - x3 = saturating_rounding_mul(x2, x) - x4 = saturating_rounding_mul(x2, x2) + x2 = saturating_rounding_mul32(x, x) + x3 = saturating_rounding_mul32(x2, x) + x4 = saturating_rounding_mul32(x2, x2) x4_over_4 = rounding_divide_by_pot(x4, 2) x4_over_24_plus_x3_over_6_plus_x2_over_2 = rounding_divide_by_pot( - saturating_rounding_mul((x4_over_4 + x3), constant_1_over_3) + x2, 1 + saturating_rounding_mul32((x4_over_4 + x3), constant_1_over_3) + x2, 1 ) return np.int32( - constant_term + saturating_rounding_mul(constant_term, x + x4_over_24_plus_x3_over_6_plus_x2_over_2) + constant_term + saturating_rounding_mul32(constant_term, x + x4_over_24_plus_x3_over_6_plus_x2_over_2) ) @@ -144,7 +207,7 @@ def exp_on_negative_values(a): integer_bits = 5 shift = fractional_bits + exponent if integer_bits > exponent else 0 if remainder & (1 << shift): - return saturating_rounding_mul(result, multiplier) + return saturating_rounding_mul32(result, multiplier) else: return result @@ -168,5 +231,5 @@ def multiply_by_quantized_multiplier(x, scale, shift): shift = 31 - shift left_shift = shift if shift > 0 else 0 right_shift = -shift if shift < 0 else 0 - mul = saturating_rounding_mul(x * (1 << left_shift), scale) + mul = saturating_rounding_mul32(x * (1 << left_shift), scale) return rounding_divide_by_pot(mul, right_shift) -- cgit v1.2.1