author     Diqing Zhong <diqing.zhong@arm.com>  2021-01-26 12:12:51 +0100
committer  Diqing Zhong <diqing.zhong@arm.com>  2021-01-29 16:17:40 +0100
commit     189f748e1a79ed88044efbe7137963bca830cbb5 (patch)
tree       4d3db8614574b5aedcf952941c2194e2bf7f8285
parent     2c2522dd44229a03d3d778cd239478fedc19ee57 (diff)
MLBEDSW-3224: Support HardSwish
Change-Id: If49abc31f093f1bd3393bee86f821fd35972086f
Signed-off-by: Diqing Zhong <diqing.zhong@arm.com>
-rw-r--r--  ethosu/vela/fp_math.py                        95
-rw-r--r--  ethosu/vela/graph_optimiser.py                53
-rw-r--r--  ethosu/vela/operation.py                       6
-rw-r--r--  ethosu/vela/softmax.py                         2
-rw-r--r--  ethosu/vela/supported_operators.py            13
-rw-r--r--  ethosu/vela/test/test_fp_math.py              78
-rw-r--r--  ethosu/vela/test/test_supported_operators.py  23
7 files changed, 225 insertions, 45 deletions
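
For orientation: HardSwish is defined in TensorFlow Lite as x * relu6(x + 3) / 6, and the
convert_hardswish_to_lut pass added below approximates it with a 256-entry LUT computed in
16-bit fixed point. A minimal float sketch of the reference function (not part of this commit,
useful only for cross-checking generated LUT values):

    import numpy as np

    def hardswish_reference(x):
        # Float reference: x * relu6(x + 3) / 6
        return x * np.clip(x + 3.0, 0.0, 6.0) / 6.0

    # e.g. hardswish_reference(np.array([-4.0, -1.0, 0.0, 1.0, 4.0]))
    # -> [ 0.0, -0.3333..., 0.0, 0.6666..., 4.0 ]
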
diff --git a/ethosu/vela/fp_math.py b/ethosu/vela/fp_math.py
index 5228f03..21022c2 100644
--- a/ethosu/vela/fp_math.py
+++ b/ethosu/vela/fp_math.py
@@ -35,13 +35,14 @@ def to_float(x, integer_bits=5):
return x / (1 << fractional_bits)
-def saturating_rounding_mul(a, b):
+def saturating_rounding_mul32(a, b):
assert np.int32(a) == a
assert np.int32(b) == b
if a == b and a == np.iinfo(np.int32).min:
return np.int32(np.iinfo(np.int32).max)
divider = 1 << 31
ab = np.int64(a) * np.int64(b)
+
if ab >= 0:
nudge = 1 << 30
return (ab + nudge) // divider
@@ -56,19 +57,81 @@ def saturating_rounding_mul(a, b):
return result
-def shift_left(a, offset):
- assert np.int32(a) == a
+def saturating_rounding_mul16(a, b):
+ assert np.int16(a) == a
+ assert np.int16(b) == b
+ if a == b and a == np.iinfo(np.int16).min:
+ return np.int16(np.iinfo(np.int16).max)
+ divider = 1 << 15
+ ab = np.int32(a) * np.int32(b)
+
+ if ab >= 0:
+ nudge = 1 << 14
+ return (ab + nudge) // divider
+ else:
+ nudge = 1 - (1 << 14)
+ ab_plus_nudge = ab + nudge
+ result = ab_plus_nudge // divider
+ # Python uses floor, the reference uses truncation
+ # so we need to compensate for that.
+ if result * divider < ab_plus_nudge:
+ result += 1
+ return result
+
+
+# Similar to saturating_rounding_mul16 except rounding to zero instead of to nearest
+# Only supports 16bit
+def saturating_mul16(a, b):
+ assert np.int16(a) == a
+ assert np.int16(b) == b
+ if a == b and a == np.iinfo(np.int16).min:
+ return np.int16(np.iinfo(np.int16).max)
+ ab = np.int32(a) * np.int32(b)
+ divider = 1 << 15
+ if ab >= 0:
+ return ab // divider
+ else:
+ result = ab // divider
+ # Python uses floor, the reference uses truncation
+ # so we need to compensate for that.
+ if result * divider < ab:
+ result += 1
+ return result
+
+
+def shift_left32(a, offset):
assert offset >= 0
- i32_info = np.iinfo(np.int32)
+ assert np.int32(a) == a
shifted = a * (1 << offset)
- if shifted < i32_info.min:
- return np.int32(i32_info.min)
- elif shifted > i32_info.max:
- return np.int32(i32_info.max)
+ if shifted < np.iinfo(np.int32).min:
+ return np.int32(np.iinfo(np.int32).min)
+ elif shifted > np.iinfo(np.int32).max:
+ return np.int32(np.iinfo(np.int32).max)
else:
return np.int32(shifted)
+def shift_left16(a, offset):
+ assert offset >= 0
+ assert np.int16(a) == a
+ shifted = a * (1 << offset)
+ if shifted < np.iinfo(np.int16).min:
+ return np.int16(np.iinfo(np.int16).min)
+ elif shifted > np.iinfo(np.int16).max:
+ return np.int16(np.iinfo(np.int16).max)
+ else:
+ return np.int16(shifted)
+
+
+def downscale_multiplier_int32_to_int16(a):
+ assert np.int32(a) == a
+ rounding_offset = 1 << 15
+ if a >= np.iinfo(np.int32).max - rounding_offset:
+ return np.iinfo(np.int16).max
+ else:
+ return np.int16((a + rounding_offset) >> 16)
+
+
def rounding_divide_by_pot(x, exponent):
assert np.int32(x) == x
assert np.int32(exponent) == exponent
@@ -92,7 +155,7 @@ def saturating_rounding_multiply_by_pot(x, exponent):
elif x < -threshold:
return np.iinfo(np.int32).min
else:
- return shift_left(x, exponent)
+ return shift_left32(x, exponent)
def rescale(integer_bits_src, integer_bits_dst, x):
@@ -115,16 +178,16 @@ def exp_on_interval_between_negative_one_quarter_and_0_excl(a):
constant_term = 1895147668
constant_1_over_3 = 715827883
x = a + (1 << offset)
- x2 = saturating_rounding_mul(x, x)
- x3 = saturating_rounding_mul(x2, x)
- x4 = saturating_rounding_mul(x2, x2)
+ x2 = saturating_rounding_mul32(x, x)
+ x3 = saturating_rounding_mul32(x2, x)
+ x4 = saturating_rounding_mul32(x2, x2)
x4_over_4 = rounding_divide_by_pot(x4, 2)
x4_over_24_plus_x3_over_6_plus_x2_over_2 = rounding_divide_by_pot(
- saturating_rounding_mul((x4_over_4 + x3), constant_1_over_3) + x2, 1
+ saturating_rounding_mul32((x4_over_4 + x3), constant_1_over_3) + x2, 1
)
return np.int32(
- constant_term + saturating_rounding_mul(constant_term, x + x4_over_24_plus_x3_over_6_plus_x2_over_2)
+ constant_term + saturating_rounding_mul32(constant_term, x + x4_over_24_plus_x3_over_6_plus_x2_over_2)
)
@@ -144,7 +207,7 @@ def exp_on_negative_values(a):
integer_bits = 5
shift = fractional_bits + exponent if integer_bits > exponent else 0
if remainder & (1 << shift):
- return saturating_rounding_mul(result, multiplier)
+ return saturating_rounding_mul32(result, multiplier)
else:
return result
@@ -168,5 +231,5 @@ def multiply_by_quantized_multiplier(x, scale, shift):
shift = 31 - shift
left_shift = shift if shift > 0 else 0
right_shift = -shift if shift < 0 else 0
- mul = saturating_rounding_mul(x * (1 << left_shift), scale)
+ mul = saturating_rounding_mul32(x * (1 << left_shift), scale)
return rounding_divide_by_pot(mul, right_shift)
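
A small usage sketch of the new 16-bit helpers (not part of this commit; assumes the package
import path ethosu.vela.fp_math): saturating_rounding_mul16 rounds the Q15 product to nearest,
saturating_mul16 truncates toward zero, and both saturate the overflowing int16 min * min case.

    import numpy as np
    from ethosu.vela import fp_math

    i16 = np.iinfo(np.int16)

    # Both multiplies saturate int16 min * min to int16 max
    assert fp_math.saturating_rounding_mul16(i16.min, i16.min) == i16.max
    assert fp_math.saturating_mul16(i16.min, i16.min) == i16.max

    # 181 * 181 = 32761, just below the Q15 divider 1 << 15 = 32768:
    assert fp_math.saturating_rounding_mul16(181, 181) == 1  # rounds to nearest
    assert fp_math.saturating_mul16(181, 181) == 0           # truncates toward zero
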
diff --git a/ethosu/vela/graph_optimiser.py b/ethosu/vela/graph_optimiser.py
index ab4d916..7755cc3 100644
--- a/ethosu/vela/graph_optimiser.py
+++ b/ethosu/vela/graph_optimiser.py
@@ -823,6 +823,58 @@ def convert_mul_max_to_abs_or_lrelu(op, arch, nng):
return op
+def convert_hardswish_to_lut(op, arch, nng):
+ if op.type == Op.HardSwish:
+ ifm, ofm = op.get_ifm_ofm()
+ # Generate the LUT
+ ifm_scale = np.double(ifm.quantization.scale_f32)
+ ofm_scale = np.double(ofm.quantization.scale_f32)
+ zp_in = ifm.quantization.zero_point
+ zp_out = ofm.quantization.zero_point
+ ifm_scale_hires = (1 / 128) * ifm_scale
+ relu_multiplier = np.double(3 / 32768)
+ out_scale, out_shift = scaling.quantise_scale(ifm_scale_hires / ofm_scale)
+ relu_scale, relu_shift = scaling.quantise_scale(ifm_scale_hires / relu_multiplier)
+ # Use 16bit scale
+ out_scale_16 = fp_math.downscale_multiplier_int32_to_int16(out_scale)
+ relu_scale_16 = fp_math.downscale_multiplier_int32_to_int16(relu_scale)
+
+ values = []
+ ix = range(256) if ifm.dtype == DataType.uint8 else range(-128, 128)
+ quantized_min = min(ix)
+ quantized_max = max(ix)
+ for x in ix:
+ input_value = x - zp_in
+ input_value_hires = input_value * 128
+ # Compute the input value on essentially the output scale, not shifted yet
+ input_value_preshift = fp_math.saturating_rounding_mul16(input_value_hires, out_scale_16)
+ # Compute the "relu-ish multiplier". This matches the code in TensorFlow Lite Micro kernel
+ relu_value = np.int16(input_value_hires)
+ if relu_shift < 31:
+ relu_value = fp_math.shift_left16(relu_value, 30 - relu_shift)
+
+ relu_value = fp_math.saturating_rounding_mul16(relu_value, relu_scale_16)
+
+ if relu_shift < 31:
+ relu_value = fp_math.shift_left16(relu_value, 1)
+
+ if relu_shift > 31:
+ relu_value = fp_math.rounding_divide_by_pot(relu_value, relu_shift - 31)
+
+ # Rescaled the value into a 16bit fixedpoint relu_value in [-1, 1]
+ # Now convert that to a 16bit fixedpoint value in [0, 1]
+ relu_value = (relu_value + (1 << 15)) >> 1
+ lut_result = fp_math.saturating_mul16(relu_value, input_value_preshift)
+ shift = 31 - out_shift
+ shift = -shift if shift < 0 else 0
+ # Finally apply the output shift
+ lut_result = fp_math.rounding_divide_by_pot(lut_result, shift) + zp_out
+ lut_result = min(quantized_max, max(quantized_min, lut_result))
+ values.append(lut_result)
+ return convert_to_lut(op, values, "hardswish")
+ return op
+
+
def convert_lrelu_to_mul_max(op, arch):
# Converts LeakyRelu to Max(alpha * IFM, identity * IFM)
# (the opposite of convert_mul_max_to_abs_or_lrelu)
@@ -1245,6 +1297,7 @@ def optimise_graph_a(nng, arch, verbose_graph=False):
convert_conv_to_fc,
convert_softmax,
optimise_strided_conv,
+ convert_hardswish_to_lut,
rewrite_fully_connected_input,
convert_batched_fc_shape,
fixup_conv2d_backprop,
diff --git a/ethosu/vela/operation.py b/ethosu/vela/operation.py
index 8d54d65..73953ce 100644
--- a/ethosu/vela/operation.py
+++ b/ethosu/vela/operation.py
@@ -165,7 +165,7 @@ class Op(Enum):
GatherV2 = OperatorInfo()
Greater = OperatorInfo()
GreaterEqual = OperatorInfo()
- HardSwish = OperatorInfo()
+ HardSwish = OperatorInfo(indices=IFM_INDICES)
HashtableLookup = OperatorInfo()
Identity = OperatorInfo()
If = OperatorInfo()
@@ -305,7 +305,7 @@ class Op(Enum):
return self in (Op.Relu, Op.Relu6, Op.ReluN1To1, Op.Clip)
def is_activation_op(self):
- return self.is_relu_op() or self in (Op.Tanh, Op.Sigmoid, Op.Softmax, Op.LUT)
+ return self.is_relu_op() or self in (Op.Tanh, Op.Sigmoid, Op.Softmax, Op.LUT, Op.HardSwish)
def is_split_op(self):
return self in (Op.Split, Op.SplitV, Op.StridedSlice, Op.Slice, Op.UnpackReshaped, Op.Unpack)
@@ -372,6 +372,8 @@ def create_activation_function(op_type: Op) -> ActivationFunction:
elif op_type == Op.Sigmoid:
act.min = 0.0
act.max = 1.0
+ elif op_type == Op.HardSwish:
+ act.min = 0.0
return act
diff --git a/ethosu/vela/softmax.py b/ethosu/vela/softmax.py
index 656a7e6..c3b0611 100644
--- a/ethosu/vela/softmax.py
+++ b/ethosu/vela/softmax.py
@@ -203,7 +203,7 @@ class SoftMax:
for x in range(256):
input_diff = x - 255
if input_diff >= diff_min:
- rescale = fp_math.saturating_rounding_mul(input_diff * (1 << shift), scale)
+ rescale = fp_math.saturating_rounding_mul32(input_diff * (1 << shift), scale)
lut.append(fp_math.exp_on_negative_values(rescale))
else:
lut.append(0)
diff --git a/ethosu/vela/supported_operators.py b/ethosu/vela/supported_operators.py
index 1bebe9a..99a4ba1 100644
--- a/ethosu/vela/supported_operators.py
+++ b/ethosu/vela/supported_operators.py
@@ -87,7 +87,7 @@ class SupportedOperators:
set((Op.ReduceSum, Op.CLZ,)) | binary_elem_wise_add_mul_sub | binary_elem_wise_shift_ops
)
relu_ops = Op.op_set(Op.is_relu_op)
- activation_ops = relu_ops | set((Op.Tanh, Op.Sigmoid, Op.Softmax,))
+ activation_ops = relu_ops | set((Op.Tanh, Op.Sigmoid, Op.Softmax, Op.HardSwish))
npu_post_ops = (
# activation functions
activation_ops
@@ -261,6 +261,10 @@ class SupportedOperators:
self.specific_constraints[Op.Pad].append(SupportedOperators.constraint_pad_constant)
self.specific_constraints[Op.Pad].append(SupportedOperators.constraint_pad_ofm)
+ # HardSwish specific checks:
+ self.specific_constraints[Op.HardSwish].append(SupportedOperators.constraint_input_8bit)
+ self.specific_constraints[Op.HardSwish].append(SupportedOperators.constraint_matching_in_out_types)
+
def is_operator_supported(self, op):
ext_type = optype_to_builtintype(op.type)
if op.type not in SupportedOperators.supported_operators:
@@ -934,6 +938,13 @@ class SupportedOperators:
return valid, f"Op has ofm_dtype={ofm_dtype}"
@staticmethod
+ def constraint_input_8bit(op):
+ "IFM must be int8 or uint8"
+ ifm_dtype = op.ifm.dtype
+ valid = (ifm_dtype == DataType.int8) or (ifm_dtype == DataType.uint8)
+ return valid, f"Op has ifm_dtype={ifm_dtype}"
+
+ @staticmethod
def constraint_matching_quantization_parameters(op):
"Both Input quantization parameters must match OFM quantization parameters"
valid = True
diff --git a/ethosu/vela/test/test_fp_math.py b/ethosu/vela/test/test_fp_math.py
index 905826f..355d3ae 100644
--- a/ethosu/vela/test/test_fp_math.py
+++ b/ethosu/vela/test/test_fp_math.py
@@ -64,52 +64,80 @@ EXP_LUT = [
def test_saturating_rounding_mul():
i32info = np.iinfo(np.int32)
+ i16info = np.iinfo(np.int16)
+
# Saturation
- assert fp_math.saturating_rounding_mul(i32info.min, i32info.min) == i32info.max
- assert fp_math.saturating_rounding_mul(i32info.min, i32info.max) == -i32info.max
- assert fp_math.saturating_rounding_mul(i32info.max, i32info.min) == -i32info.max
+ assert fp_math.saturating_rounding_mul32(i32info.min, i32info.min) == i32info.max
+ assert fp_math.saturating_rounding_mul32(i32info.min, i32info.max) == -i32info.max
+ assert fp_math.saturating_rounding_mul32(i32info.max, i32info.min) == -i32info.max
+
+ assert fp_math.saturating_rounding_mul16(i16info.min, i16info.min) == i16info.max
+ assert fp_math.saturating_rounding_mul16(i16info.min, i16info.max) == -i16info.max
+ assert fp_math.saturating_rounding_mul16(i16info.max, i16info.min) == -i16info.max
# Multiply by zero
- assert fp_math.saturating_rounding_mul(0, fp_math.from_float(1.0)) == 0
- assert fp_math.saturating_rounding_mul(0, fp_math.from_float(-1.0)) == 0
- assert fp_math.saturating_rounding_mul(fp_math.from_float(1.0), 0) == 0
- assert fp_math.saturating_rounding_mul(fp_math.from_float(-1.0), 0) == 0
+ assert fp_math.saturating_rounding_mul32(0, fp_math.from_float(1.0)) == 0
+ assert fp_math.saturating_rounding_mul32(0, fp_math.from_float(-1.0)) == 0
+ assert fp_math.saturating_rounding_mul32(fp_math.from_float(1.0), 0) == 0
+ assert fp_math.saturating_rounding_mul32(fp_math.from_float(-1.0), 0) == 0
+
+ assert fp_math.saturating_rounding_mul16(0, i16info.max) == 0
+ assert fp_math.saturating_rounding_mul16(0, i16info.min) == 0
+ assert fp_math.saturating_rounding_mul16(i16info.max, 0) == 0
+ assert fp_math.saturating_rounding_mul16(i16info.min, 0) == 0
# Multiply positive/negative
- assert fp_math.saturating_rounding_mul(fp_math.from_float(1.0), fp_math.from_float(1.0)) == fp_math.from_float(
+ assert fp_math.saturating_rounding_mul32(fp_math.from_float(1.0), fp_math.from_float(1.0)) == fp_math.from_float(
1.0, 5 + 5
)
- assert fp_math.saturating_rounding_mul(fp_math.from_float(-1.0), fp_math.from_float(1.0)) == fp_math.from_float(
+ assert fp_math.saturating_rounding_mul32(fp_math.from_float(-1.0), fp_math.from_float(1.0)) == fp_math.from_float(
-1.0, 5 + 5
)
- assert fp_math.saturating_rounding_mul(fp_math.from_float(1.0), fp_math.from_float(-1.0)) == fp_math.from_float(
+ assert fp_math.saturating_rounding_mul32(fp_math.from_float(1.0), fp_math.from_float(-1.0)) == fp_math.from_float(
-1.0, 5 + 5
)
- assert fp_math.saturating_rounding_mul(fp_math.from_float(-1.0), fp_math.from_float(-1.0)) == fp_math.from_float(
+ assert fp_math.saturating_rounding_mul32(fp_math.from_float(-1.0), fp_math.from_float(-1.0)) == fp_math.from_float(
1.0, 5 + 5
)
# Rounding
- assert fp_math.saturating_rounding_mul(fp_math.from_float(16.0), 1) == 1
- assert fp_math.saturating_rounding_mul(fp_math.from_float(-16.0), 1) == 0
- assert fp_math.saturating_rounding_mul(fp_math.from_float(16.0) - 1, 1) == 0
- assert fp_math.saturating_rounding_mul(fp_math.from_float(-16.0) - 1, 1) == -1
+ assert fp_math.saturating_rounding_mul32(fp_math.from_float(16.0), 1) == 1
+ assert fp_math.saturating_rounding_mul32(fp_math.from_float(-16.0), 1) == 0
+ assert fp_math.saturating_rounding_mul32(fp_math.from_float(16.0) - 1, 1) == 0
+ assert fp_math.saturating_rounding_mul32(fp_math.from_float(-16.0) - 1, 1) == -1
+
+ assert fp_math.saturating_rounding_mul16(fp_math.from_float(16.0, 21), 1) == 1
+ assert fp_math.saturating_rounding_mul16(fp_math.from_float(-16.0, 21), 1) == 0
+ assert fp_math.saturating_rounding_mul16(fp_math.from_float(16.0, 21) - 1, 1) == 0
+ assert fp_math.saturating_rounding_mul16(fp_math.from_float(-16.0, 21) - 1, 1) == -1
def test_shift_left():
i32info = np.iinfo(np.int32)
- assert fp_math.shift_left(1, i32info.bits) == i32info.max
- assert fp_math.shift_left(-1, i32info.bits) == i32info.min
- assert fp_math.shift_left(1, i32info.bits - 2) == (i32info.max + 1) / 2
- assert fp_math.shift_left(-1, i32info.bits - 2) == i32info.min // 2
-
- assert fp_math.shift_left(fp_math.from_float(1.0), 5) == i32info.max
- assert fp_math.shift_left(fp_math.from_float(-1.0), 5) == i32info.min
- assert fp_math.shift_left(fp_math.from_float(1.0), 4) == 16 * fp_math.from_float(1.0)
- assert fp_math.shift_left(fp_math.from_float(-1.0), 4) == 16 * fp_math.from_float(-1.0)
+ i16info = np.iinfo(np.int16)
+ assert fp_math.shift_left32(1, i32info.bits) == i32info.max
+ assert fp_math.shift_left32(-1, i32info.bits) == i32info.min
+ assert fp_math.shift_left32(1, i32info.bits - 2) == (i32info.max + 1) / 2
+ assert fp_math.shift_left32(-1, i32info.bits - 2) == i32info.min // 2
+
+ assert fp_math.shift_left16(1, i16info.bits) == i16info.max
+ assert fp_math.shift_left16(-1, i16info.bits) == i16info.min
+ assert fp_math.shift_left16(1, i16info.bits - 2) == (i16info.max + 1) / 2
+ assert fp_math.shift_left16(-1, i16info.bits - 2) == i16info.min // 2
+
+ assert fp_math.shift_left32(fp_math.from_float(1.0), 5) == i32info.max
+ assert fp_math.shift_left32(fp_math.from_float(-1.0), 5) == i32info.min
+ assert fp_math.shift_left32(fp_math.from_float(1.0), 4) == 16 * fp_math.from_float(1.0)
+ assert fp_math.shift_left32(fp_math.from_float(-1.0), 4) == 16 * fp_math.from_float(-1.0)
+
+ assert fp_math.shift_left16(fp_math.from_float(1.0, 21), 5) == i16info.max
+ assert fp_math.shift_left16(fp_math.from_float(-1.0, 21), 5) == i16info.min
+ assert fp_math.shift_left16(fp_math.from_float(1.0, 21), 4) == 16 * fp_math.from_float(1.0, 21)
+ assert fp_math.shift_left16(fp_math.from_float(-1.0, 21), 4) == 16 * fp_math.from_float(-1.0, 21)
with pytest.raises(AssertionError):
- fp_math.shift_left(1, -1)
+ fp_math.shift_left32(1, -1)
+ fp_math.shift_left16(1, -1)
def test_rounding_divide_by_pot():
diff --git a/ethosu/vela/test/test_supported_operators.py b/ethosu/vela/test/test_supported_operators.py
index 36213b7..5c01027 100644
--- a/ethosu/vela/test/test_supported_operators.py
+++ b/ethosu/vela/test/test_supported_operators.py
@@ -834,3 +834,26 @@ def test_constraint_alpha_valid():
assert support.is_operator_supported(op)
op.attrs["alpha"] = -1
assert not support.is_operator_supported(op)
+
+
+def test_constraint_hardswish_dtype():
+ # HardSwish operator dtype should be int8 or uint8, and input dtype must match output
+ # UINT8
+ op = testutil.create_op_with_quant_tensors(Op.HardSwish, [1, 8, 8, 8], [1, 8, 8, 8])
+ assert support.is_operator_supported(op)
+ # INT8
+ op = testutil.create_op_with_quant_tensors(Op.HardSwish, [1, 8, 8, 8], [1, 8, 8, 8], datatype=DataType.int8)
+ assert support.is_operator_supported(op)
+
+ # Invalid
+ op = testutil.create_op_with_quant_tensors(Op.HardSwish, [1, 8, 8, 8], [1, 8, 8, 8], datatype=DataType.int16)
+ assert not support.is_operator_supported(op)
+ op = testutil.create_op_with_quant_tensors(Op.HardSwish, [1, 8, 8, 8], [1, 8, 8, 8], datatype=DataType.uint16)
+ assert not support.is_operator_supported(op)
+ op = testutil.create_op_with_quant_tensors(Op.HardSwish, [1, 8, 8, 8], [1, 8, 8, 8], datatype=DataType.int32)
+ assert not support.is_operator_supported(op)
+
+ in_tens = Tensor([1, 8, 8, 8], DataType.int8, "in")
+ out_tens = Tensor([1, 8, 8, 8], DataType.uint8, "out")
+ op = testutil.create_op(Op.HardSwish, [in_tens], out_tens)
+ assert not support.is_operator_supported(op)