diff options
author | Per Åstrand <per.astrand@arm.com> | 2024-03-21 12:58:50 +0100 |
---|---|---|
committer | Per Åstrand <per.astrand@arm.com> | 2024-04-12 12:19:19 +0200 |
commit | 931613df7c68fb1c7cb45c6f69783c86003d7583 (patch) | |
tree | f315e61dd402a00c8f85e316b5c0dcec6a375115 | |
parent | 31947ad1aec50b64508bf367cb3e87c93f8c4693 (diff) | |
download | ethos-u-vela-931613df7c68fb1c7cb45c6f69783c86003d7583.tar.gz |
Fuse rescales into Add and Conv2d operation
Remove the upscale to int32 before and after the add operation.
Re-enable fusing of conv2d and rescale that was removed earlier.
Signed-off-by: Per Åstrand <per.astrand@arm.com>
Change-Id: I5e7d9bd99bb3925588b507824d8eb3e6642cc7f0
-rw-r--r-- | ethosu/vela/tosa_graph_optimiser.py | 90 |
1 file changed, 86 insertions, 4 deletions
diff --git a/ethosu/vela/tosa_graph_optimiser.py b/ethosu/vela/tosa_graph_optimiser.py index 19244c2..09b2c52 100644 --- a/ethosu/vela/tosa_graph_optimiser.py +++ b/ethosu/vela/tosa_graph_optimiser.py @@ -365,6 +365,10 @@ def rewrite_rescale(op, arch, nng): # some error checking assert len(ifm.ops) == 1 + prev_op = ifm.ops[0] + + # TODO currently not supported + assert len(ifm.consumer_list) == 1 input_zp = op.attrs["input_zp"] output_zp = op.attrs["output_zp"] @@ -402,10 +406,88 @@ def rewrite_rescale(op, arch, nng): else: rounding_mode = RoundingMode.HalfUp - # Generate Rescale behaviour attached to a compatible NOP - avgpool_op = replace_rescale_with_avg_pool(op) - avgpool_op.rounding_mode = rounding_mode - avgpool_op.explicit_scaling = explicit_scaling + if prev_op.type.is_depthwise_conv2d_op() or prev_op.type.is_conv2d_op() or prev_op.type == Op.FullyConnected: + # Currently not supporting per_channel quantization + if ifm.dtype == DataType.int32 and not per_channel: + prev_op.explicit_scaling = explicit_scaling + prev_op.rounding_mode = rounding_mode + + # Bypass op + prev_op.set_output_tensor(ofm) + DebugDatabase.add_optimised(op, prev_op) + return op + else: + print("Warning, unsupported fusing of TOSA Rescale previous operator is of type:", prev_op.type) + assert False + elif ( + (ifm.dtype == DataType.int8 and ofm.dtype == DataType.int8) + or (ifm.dtype == DataType.uint8 and ofm.dtype == DataType.int8) + or (ifm.dtype == DataType.int8 and ofm.dtype == DataType.uint8) + ): + # Create NOP performing the RESCALE + avgpool_op = replace_rescale_with_avg_pool(op) + avgpool_op.rounding_mode = rounding_mode + + if per_channel: + # TODO + avgpool_op.explicit_scaling = explicit_scaling + print("Warning, unsupported TOSA Rescale") + assert False + else: + avgpool_op.explicit_scaling = explicit_scaling + elif prev_op.type == Op.Add: + # Check that the operations before the Add which creates the IFMs + # are Op.Rescale that we can fuse into the add + rescale_1 = 
prev_op.ifm.ops[0] + rescale_2 = prev_op.ifm2.ops[0] + + if rescale_1.type == Op.Rescale and rescale_2.type == Op.Rescale: + # We are assuming the quantization to be the same for IFMs + equal_attributes = ["multiplier", "shift", "double_round"] + for a in equal_attributes: + assert op.attrs[a] == rescale_1.attrs[a] == rescale_2.attrs[a], ( + f"Only handling equal {a} for all operands " + "({op.attrs[a]}, {rescale_1.attrs[a]}, {rescale_2.attrs[a]}) " + "for all the rescale operations to be fused with Add!" + ) + + assert rescale_1.attrs["input_zp"] == rescale_2.attrs["input_zp"], ( + f"Only handling equal input_zp ({rescale_1.attrs['input_zp']}!={rescale_2.attrs['input_zp']}) " + "for the rescale operations to be fused with Add!" + ) + for op in [rescale_1, rescale_2]: + assert op.attrs["output_zp"] == 0, "" + assert op.attrs["per_channel"] is False, "per channel quantization is not supported." + + # Create a new add op to set the rescaled ifms and ofm + add_op = create_add_nop(prev_op.name + "_fused_rescales") + add_op.type = Op.Add + + # set the IFMs and OFM for the cloned operation + add_op.set_output_tensor(ofm) + add_op.add_input_tensor(rescale_1.ifm) + add_op.add_input_tensor(rescale_2.ifm) + add_op.set_ifm_ofm_shapes() + + # Remove the consumption of the IFMs to the Add + # since we are pruning them from the graph + for i, c in enumerate(prev_op.ifm.consumers()): + if c == rescale_1: + prev_op.ifm.consumers().pop(i) + for i, c in enumerate(prev_op.ifm2.consumers()): + if c == rescale_2: + prev_op.ifm2.consumers().pop(i) + + DebugDatabase.add_optimised(prev_op, op) + DebugDatabase.add_optimised(prev_op, rescale_1) + DebugDatabase.add_optimised(prev_op, rescale_2) + op = add_op + else: + print("Warning, unsupported fusing of TOSA Rescale with Add.") + assert False + else: + print("Warning, unsupported fusing of TOSA Rescale previous operator is of type:", prev_op.type) + assert False return op |