From 931613df7c68fb1c7cb45c6f69783c86003d7583 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Per=20=C3=85strand?= Date: Thu, 21 Mar 2024 12:58:50 +0100 Subject: Fuse rescales into Add and Conv2d operation MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Remove the upscale to int32 before and after the add operation. Re-enable fusing of conv2d and rescale that was removed earlier. Signed-off-by: Per Åstrand Change-Id: I5e7d9bd99bb3925588b507824d8eb3e6642cc7f0 --- ethosu/vela/tosa_graph_optimiser.py | 90 +++++++++++++++++++++++++++++++++++-- 1 file changed, 86 insertions(+), 4 deletions(-) diff --git a/ethosu/vela/tosa_graph_optimiser.py b/ethosu/vela/tosa_graph_optimiser.py index 19244c27..09b2c526 100644 --- a/ethosu/vela/tosa_graph_optimiser.py +++ b/ethosu/vela/tosa_graph_optimiser.py @@ -365,6 +365,10 @@ def rewrite_rescale(op, arch, nng): # some error checking assert len(ifm.ops) == 1 + prev_op = ifm.ops[0] + + # TODO currently not supported + assert len(ifm.consumer_list) == 1 input_zp = op.attrs["input_zp"] output_zp = op.attrs["output_zp"] @@ -402,10 +406,88 @@ def rewrite_rescale(op, arch, nng): else: rounding_mode = RoundingMode.HalfUp - # Generate Rescale behaviour attached to a compatible NOP - avgpool_op = replace_rescale_with_avg_pool(op) - avgpool_op.rounding_mode = rounding_mode - avgpool_op.explicit_scaling = explicit_scaling + if prev_op.type.is_depthwise_conv2d_op() or prev_op.type.is_conv2d_op() or prev_op.type == Op.FullyConnected: + # Currently not supporting per_channel quantization + if ifm.dtype == DataType.int32 and not per_channel: + prev_op.explicit_scaling = explicit_scaling + prev_op.rounding_mode = rounding_mode + + # Bypass op + prev_op.set_output_tensor(ofm) + DebugDatabase.add_optimised(op, prev_op) + return op + else: + print("Warning, unsupported fusing of TOSA Rescale previous operator is of type:", prev_op.type) + assert False + elif ( + (ifm.dtype == DataType.int8 and ofm.dtype == 
DataType.int8) + or (ifm.dtype == DataType.uint8 and ofm.dtype == DataType.int8) + or (ifm.dtype == DataType.int8 and ofm.dtype == DataType.uint8) + ): + # Create NOP performing the RESCALE + avgpool_op = replace_rescale_with_avg_pool(op) + avgpool_op.rounding_mode = rounding_mode + + if per_channel: + # TODO + avgpool_op.explicit_scaling = explicit_scaling + print("Warning, unsupported TOSA Rescale") + assert False + else: + avgpool_op.explicit_scaling = explicit_scaling + elif prev_op.type == Op.Add: + # Check that the operations before the Add which creates the IFMs + # are Op.Rescale that we can fuse into the add + rescale_1 = prev_op.ifm.ops[0] + rescale_2 = prev_op.ifm2.ops[0] + + if rescale_1.type == Op.Rescale and rescale_2.type == Op.Rescale: + # We are assuming the quantization to be the same for IFMs + equal_attributes = ["multiplier", "shift", "double_round"] + for a in equal_attributes: + assert op.attrs[a] == rescale_1.attrs[a] == rescale_2.attrs[a], ( + f"Only handling equal {a} for all operands " + "({op.attrs[a]}, {rescale_1.attrs[a]}, {rescale_2.attrs[a]}) " + "for all the rescale operations to be fused with Add!" + ) + + assert rescale_1.attrs["input_zp"] == rescale_2.attrs["input_zp"], ( + f"Only handling equal input_zp ({rescale_1.attrs['input_zp']}!={rescale_2.attrs['input_zp']}) " + "for the rescale operations to be fused with Add!" + ) + for op in [rescale_1, rescale_2]: + assert op.attrs["output_zp"] == 0, "" + assert op.attrs["per_channel"] is False, "per channel quantization is not supported." 
+ + # Create a new add op to set the rescaled ifms and ofm + add_op = create_add_nop(prev_op.name + "_fused_rescales") + add_op.type = Op.Add + + # set the IFMs and OFM for the cloned operation + add_op.set_output_tensor(ofm) + add_op.add_input_tensor(rescale_1.ifm) + add_op.add_input_tensor(rescale_2.ifm) + add_op.set_ifm_ofm_shapes() + + # Remove the consumption of the IFMs to the Add + # since we are pruning them from the graph + for i, c in enumerate(prev_op.ifm.consumers()): + if c == rescale_1: + prev_op.ifm.consumers().pop(i) + for i, c in enumerate(prev_op.ifm2.consumers()): + if c == rescale_2: + prev_op.ifm2.consumers().pop(i) + + DebugDatabase.add_optimised(prev_op, op) + DebugDatabase.add_optimised(prev_op, rescale_1) + DebugDatabase.add_optimised(prev_op, rescale_2) + op = add_op + else: + print("Warning, unsupported fusing of TOSA Rescale with Add.") + assert False + else: + print("Warning, unsupported fusing of TOSA Rescale previous operator is of type:", prev_op.type) + assert False return op -- cgit v1.2.1