From 189f748e1a79ed88044efbe7137963bca830cbb5 Mon Sep 17 00:00:00 2001
From: Diqing Zhong
Date: Tue, 26 Jan 2021 12:12:51 +0100
Subject: MLBEDSW-3224: Support HardSwish

Change-Id: If49abc31f093f1bd3393bee86f821fd35972086f
Signed-off-by: Diqing Zhong
---
 ethosu/vela/fp_math.py                       | 95 +++++++++++++++++++++++-----
 ethosu/vela/graph_optimiser.py               | 53 ++++++++++++++++
 ethosu/vela/operation.py                     |  6 +-
 ethosu/vela/softmax.py                       |  2 +-
 ethosu/vela/supported_operators.py           | 13 +++-
 ethosu/vela/test/test_fp_math.py             | 78 +++++++++++++++--------
 ethosu/vela/test/test_supported_operators.py | 23 +++++++
 7 files changed, 225 insertions(+), 45 deletions(-)

diff --git a/ethosu/vela/fp_math.py b/ethosu/vela/fp_math.py
index 5228f031..21022c2a 100644
--- a/ethosu/vela/fp_math.py
+++ b/ethosu/vela/fp_math.py
@@ -35,13 +35,14 @@ def to_float(x, integer_bits=5):
     return x / (1 << fractional_bits)


-def saturating_rounding_mul(a, b):
+def saturating_rounding_mul32(a, b):
     assert np.int32(a) == a
     assert np.int32(b) == b
     if a == b and a == np.iinfo(np.int32).min:
         return np.int32(np.iinfo(np.int32).max)
     divider = 1 << 31
     ab = np.int64(a) * np.int64(b)
+
     if ab >= 0:
         nudge = 1 << 30
         return (ab + nudge) // divider
@@ -56,19 +57,81 @@ def saturating_rounding_mul(a, b):
     return result


-def shift_left(a, offset):
-    assert np.int32(a) == a
+def saturating_rounding_mul16(a, b):
+    assert np.int16(a) == a
+    assert np.int16(b) == b
+    if a == b and a == np.iinfo(np.int16).min:
+        return np.int16(np.iinfo(np.int16).max)
+    divider = 1 << 15
+    ab = np.int32(a) * np.int32(b)
+
+    if ab >= 0:
+        nudge = 1 << 14
+        return (ab + nudge) // divider
+    else:
+        nudge = 1 - (1 << 14)
+        ab_plus_nudge = ab + nudge
+        result = ab_plus_nudge // divider
+        # Python uses floor, the reference uses truncation
+        # so we need to compensate for that.
+        if result * divider < ab_plus_nudge:
+            result += 1
+        return result
+
+
+# Similar to saturating_rounding_mul16 except rounding to zero instead of to nearest
+# Only supports 16bit
+def saturating_mul16(a, b):
+    assert np.int16(a) == a
+    assert np.int16(b) == b
+    if a == b and a == np.iinfo(np.int16).min:
+        return np.int16(np.iinfo(np.int16).max)
+    ab = np.int32(a) * np.int32(b)
+    divider = 1 << 15
+    if ab >= 0:
+        return ab // divider
+    else:
+        result = ab // divider
+        # Python uses floor, the reference uses truncation
+        # so we need to compensate for that.
+        if result * divider < ab:
+            result += 1
+        return result
+
+
+def shift_left32(a, offset):
     assert offset >= 0
-    i32_info = np.iinfo(np.int32)
+    assert np.int32(a) == a
     shifted = a * (1 << offset)
-    if shifted < i32_info.min:
-        return np.int32(i32_info.min)
-    elif shifted > i32_info.max:
-        return np.int32(i32_info.max)
+    if shifted < np.iinfo(np.int32).min:
+        return np.int32(np.iinfo(np.int32).min)
+    elif shifted > np.iinfo(np.int32).max:
+        return np.int32(np.iinfo(np.int32).max)
     else:
         return np.int32(shifted)


+def shift_left16(a, offset):
+    assert offset >= 0
+    assert np.int16(a) == a
+    shifted = a * (1 << offset)
+    if shifted < np.iinfo(np.int16).min:
+        return np.int16(np.iinfo(np.int16).min)
+    elif shifted > np.iinfo(np.int16).max:
+        return np.int16(np.iinfo(np.int16).max)
+    else:
+        return np.int16(shifted)
+
+
+def downscale_multiplier_int32_to_int16(a):
+    assert np.int32(a) == a
+    rounding_offset = 1 << 15
+    if a >= np.iinfo(np.int32).max - rounding_offset:
+        return np.iinfo(np.int16).max
+    else:
+        return np.int16((a + rounding_offset) >> 16)
+
+
 def rounding_divide_by_pot(x, exponent):
     assert np.int32(x) == x
     assert np.int32(exponent) == exponent
@@ -92,7 +155,7 @@ def saturating_rounding_multiply_by_pot(x, exponent):
     elif x < -threshold:
         return np.iinfo(np.int32).min
     else:
-        return shift_left(x, exponent)
+        return shift_left32(x, exponent)


 def rescale(integer_bits_src, integer_bits_dst, x):
@@ -115,16 +178,16 @@ def exp_on_interval_between_negative_one_quarter_and_0_excl(a):
     constant_term = 1895147668
     constant_1_over_3 = 715827883
     x = a + (1 << offset)
-    x2 = saturating_rounding_mul(x, x)
-    x3 = saturating_rounding_mul(x2, x)
-    x4 = saturating_rounding_mul(x2, x2)
+    x2 = saturating_rounding_mul32(x, x)
+    x3 = saturating_rounding_mul32(x2, x)
+    x4 = saturating_rounding_mul32(x2, x2)
     x4_over_4 = rounding_divide_by_pot(x4, 2)
     x4_over_24_plus_x3_over_6_plus_x2_over_2 = rounding_divide_by_pot(
-        saturating_rounding_mul((x4_over_4 + x3), constant_1_over_3) + x2, 1
+        saturating_rounding_mul32((x4_over_4 + x3), constant_1_over_3) + x2, 1
     )
     return np.int32(
-        constant_term + saturating_rounding_mul(constant_term, x + x4_over_24_plus_x3_over_6_plus_x2_over_2)
+        constant_term + saturating_rounding_mul32(constant_term, x + x4_over_24_plus_x3_over_6_plus_x2_over_2)
     )


@@ -144,7 +207,7 @@ def exp_on_negative_values(a):
     integer_bits = 5
     shift = fractional_bits + exponent if integer_bits > exponent else 0
     if remainder & (1 << shift):
-        return saturating_rounding_mul(result, multiplier)
+        return saturating_rounding_mul32(result, multiplier)
     else:
         return result

@@ -168,5 +231,5 @@ def multiply_by_quantized_multiplier(x, scale, shift):
     shift = 31 - shift
     left_shift = shift if shift > 0 else 0
     right_shift = -shift if shift < 0 else 0
-    mul = saturating_rounding_mul(x * (1 << left_shift), scale)
+    mul = saturating_rounding_mul32(x * (1 << left_shift), scale)
    return rounding_divide_by_pot(mul, right_shift)
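
Note on the fixed-point helpers above: the negative-path compensation exists because
Python's floor division rounds towards negative infinity, while the gemmlowp-style
reference arithmetic truncates towards zero. A minimal standalone sketch of the
Q0.15 behaviour (NumPy only; the demo function mirrors saturating_rounding_mul16
from the patch but is not part of it):

    import numpy as np

    def saturating_rounding_mul16_demo(a, b):
        # Doubling-high multiply of two Q0.15 values with round-to-nearest.
        # Only int16 min * int16 min can overflow, so that case saturates.
        if a == b and a == np.iinfo(np.int16).min:
            return np.int16(np.iinfo(np.int16).max)
        ab = np.int32(a) * np.int32(b)
        divider = 1 << 15
        if ab >= 0:
            return (ab + (1 << 14)) // divider
        # Negative path: nudge towards zero, floor-divide, then add one back
        # whenever flooring undershot the truncated quotient.
        ab_plus_nudge = ab + (1 - (1 << 14))
        result = ab_plus_nudge // divider
        if result * divider < ab_plus_nudge:
            result += 1
        return result

    assert saturating_rounding_mul16_demo(-16384, -16384) == 8192   # -0.5 * -0.5 == 0.25
    assert saturating_rounding_mul16_demo(-16384, 16384) == -8192   # -0.5 *  0.5 == -0.25
    assert saturating_rounding_mul16_demo(-32768, -32768) == 32767  # saturates instead of +1.0
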
diff --git a/ethosu/vela/graph_optimiser.py b/ethosu/vela/graph_optimiser.py
index ab4d916e..7755cc3b 100644
--- a/ethosu/vela/graph_optimiser.py
+++ b/ethosu/vela/graph_optimiser.py
@@ -823,6 +823,58 @@ def convert_mul_max_to_abs_or_lrelu(op, arch, nng):
     return op


+def convert_hardswish_to_lut(op, arch, nng):
+    if op.type == Op.HardSwish:
+        ifm, ofm = op.get_ifm_ofm()
+        # Generate the LUT
+        ifm_scale = np.double(ifm.quantization.scale_f32)
+        ofm_scale = np.double(ofm.quantization.scale_f32)
+        zp_in = ifm.quantization.zero_point
+        zp_out = ofm.quantization.zero_point
+        ifm_scale_hires = (1 / 128) * ifm_scale
+        relu_multiplier = np.double(3 / 32768)
+        out_scale, out_shift = scaling.quantise_scale(ifm_scale_hires / ofm_scale)
+        relu_scale, relu_shift = scaling.quantise_scale(ifm_scale_hires / relu_multiplier)
+        # Use 16bit scale
+        out_scale_16 = fp_math.downscale_multiplier_int32_to_int16(out_scale)
+        relu_scale_16 = fp_math.downscale_multiplier_int32_to_int16(relu_scale)
+
+        values = []
+        ix = range(256) if ifm.dtype == DataType.uint8 else range(-128, 128)
+        quantized_min = min(ix)
+        quantized_max = max(ix)
+        for x in ix:
+            input_value = x - zp_in
+            input_value_hires = input_value * 128
+            # Compute the input value on essentially the output scale, not shifted yet
+            input_value_preshift = fp_math.saturating_rounding_mul16(input_value_hires, out_scale_16)
+            # Compute the "relu-ish multiplier". This matches the code in the TensorFlow Lite Micro kernel
+            relu_value = np.int16(input_value_hires)
+            if relu_shift < 31:
+                relu_value = fp_math.shift_left16(relu_value, 30 - relu_shift)
+
+            relu_value = fp_math.saturating_rounding_mul16(relu_value, relu_scale_16)
+
+            if relu_shift < 31:
+                relu_value = fp_math.shift_left16(relu_value, 1)
+
+            if relu_shift > 31:
+                relu_value = fp_math.rounding_divide_by_pot(relu_value, relu_shift - 31)
+
+            # relu_value is now a 16bit fixedpoint value in [-1, 1]
+            # Convert that to a 16bit fixedpoint value in [0, 1]
+            relu_value = (relu_value + (1 << 15)) >> 1
+            lut_result = fp_math.saturating_mul16(relu_value, input_value_preshift)
+            shift = 31 - out_shift
+            shift = -shift if shift < 0 else 0
+            # Finally apply the output shift
+            lut_result = fp_math.rounding_divide_by_pot(lut_result, shift) + zp_out
+            lut_result = min(quantized_max, max(quantized_min, lut_result))
+            values.append(lut_result)
+        return convert_to_lut(op, values, "hardswish")
+    return op
+
+
 def convert_lrelu_to_mul_max(op, arch):
     # Converts LeakyRelu to Max(alpha * IFM, identity * IFM)
     # (the opposite of convert_mul_max_to_abs_or_lrelu)
@@ -1245,6 +1297,7 @@ def optimise_graph_a(nng, arch, verbose_graph=False):
         convert_conv_to_fc,
         convert_softmax,
         optimise_strided_conv,
+        convert_hardswish_to_lut,
         rewrite_fully_connected_input,
         convert_batched_fc_shape,
         fixup_conv2d_backprop,
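
The LUT generation above is a fixed-point evaluation of the hard-swish reference
function, y = x * relu6(x + 3) / 6, applied to each of the 256 quantised input
codes. A float sketch of the quantity being approximated may make the flow easier
to follow (illustrative names, not part of the patch; the fixed-point version
replaces the float operations with the fp_math helpers):

    import numpy as np

    def hardswish_reference(x):
        # Hard-swish as defined for MobileNetV3: x * relu6(x + 3) / 6
        return x * np.clip(x + 3.0, 0.0, 6.0) / 6.0

    def hardswish_lut_float(ifm_scale, zp_in, ofm_scale, zp_out):
        # Dequantise every uint8 input code, apply hard-swish,
        # then requantise and clamp to the output range.
        codes = np.arange(256)
        real = (codes - zp_in) * ifm_scale
        out = np.round(hardswish_reference(real) / ofm_scale) + zp_out
        return np.clip(out, 0, 255).astype(np.uint8)

    lut = hardswish_lut_float(ifm_scale=12 / 255, zp_in=128, ofm_scale=12 / 255, zp_out=128)
    assert lut[128] == 128  # the real value zero maps back to the output zero point
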
diff --git a/ethosu/vela/operation.py b/ethosu/vela/operation.py
index 8d54d658..73953cec 100644
--- a/ethosu/vela/operation.py
+++ b/ethosu/vela/operation.py
@@ -165,7 +165,7 @@ class Op(Enum):
     GatherV2 = OperatorInfo()
     Greater = OperatorInfo()
     GreaterEqual = OperatorInfo()
-    HardSwish = OperatorInfo()
+    HardSwish = OperatorInfo(indices=IFM_INDICES)
     HashtableLookup = OperatorInfo()
     Identity = OperatorInfo()
     If = OperatorInfo()
@@ -305,7 +305,7 @@ class Op(Enum):
         return self in (Op.Relu, Op.Relu6, Op.ReluN1To1, Op.Clip)

     def is_activation_op(self):
-        return self.is_relu_op() or self in (Op.Tanh, Op.Sigmoid, Op.Softmax, Op.LUT)
+        return self.is_relu_op() or self in (Op.Tanh, Op.Sigmoid, Op.Softmax, Op.LUT, Op.HardSwish)

     def is_split_op(self):
         return self in (Op.Split, Op.SplitV, Op.StridedSlice, Op.Slice, Op.UnpackReshaped, Op.Unpack)
@@ -372,6 +372,8 @@ def create_activation_function(op_type: Op) -> ActivationFunction:
     elif op_type == Op.Sigmoid:
         act.min = 0.0
         act.max = 1.0
+    elif op_type == Op.HardSwish:
+        act.min = 0.0
     return act

diff --git a/ethosu/vela/softmax.py b/ethosu/vela/softmax.py
index 656a7e69..c3b0611a 100644
--- a/ethosu/vela/softmax.py
+++ b/ethosu/vela/softmax.py
@@ -203,7 +203,7 @@ class SoftMax:
         for x in range(256):
             input_diff = x - 255
             if input_diff >= diff_min:
-                rescale = fp_math.saturating_rounding_mul(input_diff * (1 << shift), scale)
+                rescale = fp_math.saturating_rounding_mul32(input_diff * (1 << shift), scale)
                 lut.append(fp_math.exp_on_negative_values(rescale))
             else:
                 lut.append(0)
diff --git a/ethosu/vela/supported_operators.py b/ethosu/vela/supported_operators.py
index 1bebe9af..99a4ba10 100644
--- a/ethosu/vela/supported_operators.py
+++ b/ethosu/vela/supported_operators.py
@@ -87,7 +87,7 @@ class SupportedOperators:
         set((Op.ReduceSum, Op.CLZ,)) | binary_elem_wise_add_mul_sub | binary_elem_wise_shift_ops
     )
     relu_ops = Op.op_set(Op.is_relu_op)
-    activation_ops = relu_ops | set((Op.Tanh, Op.Sigmoid, Op.Softmax,))
+    activation_ops = relu_ops | set((Op.Tanh, Op.Sigmoid, Op.Softmax, Op.HardSwish))
     npu_post_ops = (
         # activation functions
         activation_ops
@@ -261,6 +261,10 @@
         self.specific_constraints[Op.Pad].append(SupportedOperators.constraint_pad_constant)
         self.specific_constraints[Op.Pad].append(SupportedOperators.constraint_pad_ofm)

+        # HardSwish specific checks:
+        self.specific_constraints[Op.HardSwish].append(SupportedOperators.constraint_input_8bit)
+        self.specific_constraints[Op.HardSwish].append(SupportedOperators.constraint_matching_in_out_types)
+
     def is_operator_supported(self, op):
         ext_type = optype_to_builtintype(op.type)
         if op.type not in SupportedOperators.supported_operators:
@@ -933,6 +937,13 @@ class SupportedOperators:
         valid = ofm_dtype == DataType.int32
         return valid, f"Op has ofm_dtype={ofm_dtype}"

+    @staticmethod
+    def constraint_input_8bit(op):
+        "IFM must be int8 or uint8"
+        ifm_dtype = op.ifm.dtype
+        valid = (ifm_dtype == DataType.int8) or (ifm_dtype == DataType.uint8)
+        return valid, f"Op has ifm_dtype={ifm_dtype}"
+
     @staticmethod
     def constraint_matching_quantization_parameters(op):
         "Both Input quantization parameters must match OFM quantization parameters"
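
The new constraint follows the existing SupportedOperators convention: the one-line
docstring is the human-readable rule (these strings also feed vela's supported-ops
reporting), and the method returns a (valid, extra-detail) tuple used when the check
fails. A toy sketch of that contract with hypothetical stand-in objects, for
illustration only:

    from collections import namedtuple

    # Stand-ins for vela's tensor/op classes, just to show the call shape
    FakeTensor = namedtuple("FakeTensor", ["dtype"])
    FakeOp = namedtuple("FakeOp", ["ifm"])

    def constraint_input_8bit(op):
        "IFM must be int8 or uint8"
        valid = op.ifm.dtype in ("int8", "uint8")
        return valid, f"Op has ifm_dtype={op.ifm.dtype}"

    valid, extra = constraint_input_8bit(FakeOp(ifm=FakeTensor(dtype="int16")))
    assert not valid and "int16" in extra
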
diff --git a/ethosu/vela/test/test_fp_math.py b/ethosu/vela/test/test_fp_math.py
index 905826f4..355d3ae6 100644
--- a/ethosu/vela/test/test_fp_math.py
+++ b/ethosu/vela/test/test_fp_math.py
@@ -64,52 +64,80 @@ EXP_LUT = [


 def test_saturating_rounding_mul():
     i32info = np.iinfo(np.int32)
+    i16info = np.iinfo(np.int16)
+
     # Saturation
-    assert fp_math.saturating_rounding_mul(i32info.min, i32info.min) == i32info.max
-    assert fp_math.saturating_rounding_mul(i32info.min, i32info.max) == -i32info.max
-    assert fp_math.saturating_rounding_mul(i32info.max, i32info.min) == -i32info.max
+    assert fp_math.saturating_rounding_mul32(i32info.min, i32info.min) == i32info.max
+    assert fp_math.saturating_rounding_mul32(i32info.min, i32info.max) == -i32info.max
+    assert fp_math.saturating_rounding_mul32(i32info.max, i32info.min) == -i32info.max
+
+    assert fp_math.saturating_rounding_mul16(i16info.min, i16info.min) == i16info.max
+    assert fp_math.saturating_rounding_mul16(i16info.min, i16info.max) == -i16info.max
+    assert fp_math.saturating_rounding_mul16(i16info.max, i16info.min) == -i16info.max

     # Multiply by zero
-    assert fp_math.saturating_rounding_mul(0, fp_math.from_float(1.0)) == 0
-    assert fp_math.saturating_rounding_mul(0, fp_math.from_float(-1.0)) == 0
-    assert fp_math.saturating_rounding_mul(fp_math.from_float(1.0), 0) == 0
-    assert fp_math.saturating_rounding_mul(fp_math.from_float(-1.0), 0) == 0
+    assert fp_math.saturating_rounding_mul32(0, fp_math.from_float(1.0)) == 0
+    assert fp_math.saturating_rounding_mul32(0, fp_math.from_float(-1.0)) == 0
+    assert fp_math.saturating_rounding_mul32(fp_math.from_float(1.0), 0) == 0
+    assert fp_math.saturating_rounding_mul32(fp_math.from_float(-1.0), 0) == 0
+
+    assert fp_math.saturating_rounding_mul16(0, i16info.max) == 0
+    assert fp_math.saturating_rounding_mul16(0, i16info.min) == 0
+    assert fp_math.saturating_rounding_mul16(i16info.max, 0) == 0
+    assert fp_math.saturating_rounding_mul16(i16info.min, 0) == 0

     # Multiply positive/negative
-    assert fp_math.saturating_rounding_mul(fp_math.from_float(1.0), fp_math.from_float(1.0)) == fp_math.from_float(
+    assert fp_math.saturating_rounding_mul32(fp_math.from_float(1.0), fp_math.from_float(1.0)) == fp_math.from_float(
         1.0, 5 + 5
     )
-    assert fp_math.saturating_rounding_mul(fp_math.from_float(-1.0), fp_math.from_float(1.0)) == fp_math.from_float(
+    assert fp_math.saturating_rounding_mul32(fp_math.from_float(-1.0), fp_math.from_float(1.0)) == fp_math.from_float(
         -1.0, 5 + 5
     )
-    assert fp_math.saturating_rounding_mul(fp_math.from_float(1.0), fp_math.from_float(-1.0)) == fp_math.from_float(
+    assert fp_math.saturating_rounding_mul32(fp_math.from_float(1.0), fp_math.from_float(-1.0)) == fp_math.from_float(
         -1.0, 5 + 5
     )
-    assert fp_math.saturating_rounding_mul(fp_math.from_float(-1.0), fp_math.from_float(-1.0)) == fp_math.from_float(
+    assert fp_math.saturating_rounding_mul32(fp_math.from_float(-1.0), fp_math.from_float(-1.0)) == fp_math.from_float(
         1.0, 5 + 5
     )

     # Rounding
-    assert fp_math.saturating_rounding_mul(fp_math.from_float(16.0), 1) == 1
-    assert fp_math.saturating_rounding_mul(fp_math.from_float(-16.0), 1) == 0
-    assert fp_math.saturating_rounding_mul(fp_math.from_float(16.0) - 1, 1) == 0
-    assert fp_math.saturating_rounding_mul(fp_math.from_float(-16.0) - 1, 1) == -1
+    assert fp_math.saturating_rounding_mul32(fp_math.from_float(16.0), 1) == 1
+    assert fp_math.saturating_rounding_mul32(fp_math.from_float(-16.0), 1) == 0
+    assert fp_math.saturating_rounding_mul32(fp_math.from_float(16.0) - 1, 1) == 0
+    assert fp_math.saturating_rounding_mul32(fp_math.from_float(-16.0) - 1, 1) == -1
+
+    assert fp_math.saturating_rounding_mul16(fp_math.from_float(16.0, 21), 1) == 1
+    assert fp_math.saturating_rounding_mul16(fp_math.from_float(-16.0, 21), 1) == 0
+    assert fp_math.saturating_rounding_mul16(fp_math.from_float(16.0, 21) - 1, 1) == 0
+    assert fp_math.saturating_rounding_mul16(fp_math.from_float(-16.0, 21) - 1, 1) == -1


 def test_shift_left():
     i32info = np.iinfo(np.int32)
-    assert fp_math.shift_left(1, i32info.bits) == i32info.max
-    assert fp_math.shift_left(-1, i32info.bits) == i32info.min
-    assert fp_math.shift_left(1, i32info.bits - 2) == (i32info.max + 1) / 2
-    assert fp_math.shift_left(-1, i32info.bits - 2) == i32info.min // 2
-
-    assert fp_math.shift_left(fp_math.from_float(1.0), 5) == i32info.max
-    assert fp_math.shift_left(fp_math.from_float(-1.0), 5) == i32info.min
-    assert fp_math.shift_left(fp_math.from_float(1.0), 4) == 16 * fp_math.from_float(1.0)
-    assert fp_math.shift_left(fp_math.from_float(-1.0), 4) == 16 * fp_math.from_float(-1.0)
+    i16info = np.iinfo(np.int16)
+    assert fp_math.shift_left32(1, i32info.bits) == i32info.max
+    assert fp_math.shift_left32(-1, i32info.bits) == i32info.min
+    assert fp_math.shift_left32(1, i32info.bits - 2) == (i32info.max + 1) / 2
+    assert fp_math.shift_left32(-1, i32info.bits - 2) == i32info.min // 2
+
+    assert fp_math.shift_left16(1, i16info.bits) == i16info.max
+    assert fp_math.shift_left16(-1, i16info.bits) == i16info.min
+    assert fp_math.shift_left16(1, i16info.bits - 2) == (i16info.max + 1) / 2
+    assert fp_math.shift_left16(-1, i16info.bits - 2) == i16info.min // 2
+
+    assert fp_math.shift_left32(fp_math.from_float(1.0), 5) == i32info.max
+    assert fp_math.shift_left32(fp_math.from_float(-1.0), 5) == i32info.min
+    assert fp_math.shift_left32(fp_math.from_float(1.0), 4) == 16 * fp_math.from_float(1.0)
+    assert fp_math.shift_left32(fp_math.from_float(-1.0), 4) == 16 * fp_math.from_float(-1.0)
+
+    assert fp_math.shift_left16(fp_math.from_float(1.0, 21), 5) == i16info.max
+    assert fp_math.shift_left16(fp_math.from_float(-1.0, 21), 5) == i16info.min
+    assert fp_math.shift_left16(fp_math.from_float(1.0, 21), 4) == 16 * fp_math.from_float(1.0, 21)
+    assert fp_math.shift_left16(fp_math.from_float(-1.0, 21), 4) == 16 * fp_math.from_float(-1.0, 21)

     with pytest.raises(AssertionError):
-        fp_math.shift_left(1, -1)
+        fp_math.shift_left32(1, -1)
+        fp_math.shift_left16(1, -1)


 def test_rounding_divide_by_pot():
diff --git a/ethosu/vela/test/test_supported_operators.py b/ethosu/vela/test/test_supported_operators.py
index 36213b73..5c01027d 100644
--- a/ethosu/vela/test/test_supported_operators.py
+++ b/ethosu/vela/test/test_supported_operators.py
@@ -834,3 +834,26 @@ def test_constraint_alpha_valid():
     assert support.is_operator_supported(op)
     op.attrs["alpha"] = -1
     assert not support.is_operator_supported(op)
+
+
+def test_constraint_hardswish_dtype():
+    # HardSwish operator dtype should be int8 or uint8, and input dtype must match output
+    # UINT8
+    op = testutil.create_op_with_quant_tensors(Op.HardSwish, [1, 8, 8, 8], [1, 8, 8, 8])
+    assert support.is_operator_supported(op)
+    # INT8
+    op = testutil.create_op_with_quant_tensors(Op.HardSwish, [1, 8, 8, 8], [1, 8, 8, 8], datatype=DataType.int8)
+    assert support.is_operator_supported(op)
+
+    # Invalid
+    op = testutil.create_op_with_quant_tensors(Op.HardSwish, [1, 8, 8, 8], [1, 8, 8, 8], datatype=DataType.int16)
+    assert not support.is_operator_supported(op)
+    op = testutil.create_op_with_quant_tensors(Op.HardSwish, [1, 8, 8, 8], [1, 8, 8, 8], datatype=DataType.uint16)
+    assert not support.is_operator_supported(op)
+    op = testutil.create_op_with_quant_tensors(Op.HardSwish, [1, 8, 8, 8], [1, 8, 8, 8], datatype=DataType.int32)
+    assert not support.is_operator_supported(op)
+
+    in_tens = Tensor([1, 8, 8, 8], DataType.int8, "in")
+    out_tens = Tensor([1, 8, 8, 8], DataType.uint8, "out")
+    op = testutil.create_op(Op.HardSwish, [in_tens], out_tens)
+    assert not support.is_operator_supported(op)
-- 
cgit v1.2.1
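
A follow-up check one could add (not part of this patch): since saturating_mul16
emulates C-style truncating division with Python floor division, a property test
comparing it against trunc semantics across random int16 pairs would guard the
compensation logic directly:

    import numpy as np
    from ethosu.vela import fp_math

    def test_saturating_mul16_matches_truncation():
        rng = np.random.default_rng(42)
        for a, b in rng.integers(-32768, 32768, size=(1000, 2)):
            a, b = int(a), int(b)
            if a == b == -32768:
                continue  # the saturating case is checked separately
            # C-style division truncates towards zero; exact in float64
            # because |a * b| <= 2**30 and the divisor is a power of two
            expected = int(np.trunc((a * b) / (1 << 15)))
            assert fp_math.saturating_mul16(a, b) == expected
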