aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
author Per Åstrand <per.astrand@arm.com> 2024-03-21 12:58:50 +0100
committer Per Åstrand <per.astrand@arm.com> 2024-04-12 12:19:19 +0200
commit 931613df7c68fb1c7cb45c6f69783c86003d7583 (patch)
tree f315e61dd402a00c8f85e316b5c0dcec6a375115
parent 31947ad1aec50b64508bf367cb3e87c93f8c4693 (diff)
download ethos-u-vela-931613df7c68fb1c7cb45c6f69783c86003d7583.tar.gz
Fuse rescales into Add and Conv2d operation
Remove the upscale to int32 before and after the add operation. Re-enable fusing of conv2d and rescale that was removed earlier. Signed-off-by: Per Åstrand <per.astrand@arm.com> Change-Id: I5e7d9bd99bb3925588b507824d8eb3e6642cc7f0
-rw-r--r-- ethosu/vela/tosa_graph_optimiser.py 90
1 files changed, 86 insertions, 4 deletions
diff --git a/ethosu/vela/tosa_graph_optimiser.py b/ethosu/vela/tosa_graph_optimiser.py
index 19244c2..09b2c52 100644
--- a/ethosu/vela/tosa_graph_optimiser.py
+++ b/ethosu/vela/tosa_graph_optimiser.py
@@ -365,6 +365,10 @@ def rewrite_rescale(op, arch, nng):
# some error checking
assert len(ifm.ops) == 1
+ prev_op = ifm.ops[0]
+
+ # TODO currently not supported
+ assert len(ifm.consumer_list) == 1
input_zp = op.attrs["input_zp"]
output_zp = op.attrs["output_zp"]
@@ -402,10 +406,88 @@ def rewrite_rescale(op, arch, nng):
else:
rounding_mode = RoundingMode.HalfUp
- # Generate Rescale behaviour attached to a compatible NOP
- avgpool_op = replace_rescale_with_avg_pool(op)
- avgpool_op.rounding_mode = rounding_mode
- avgpool_op.explicit_scaling = explicit_scaling
+ if prev_op.type.is_depthwise_conv2d_op() or prev_op.type.is_conv2d_op() or prev_op.type == Op.FullyConnected:
+ # Currently not supporting per_channel quantization
+ if ifm.dtype == DataType.int32 and not per_channel:
+ prev_op.explicit_scaling = explicit_scaling
+ prev_op.rounding_mode = rounding_mode
+
+ # Bypass op
+ prev_op.set_output_tensor(ofm)
+ DebugDatabase.add_optimised(op, prev_op)
+ return op
+ else:
+ print("Warning, unsupported fusing of TOSA Rescale previous operator is of type:", prev_op.type)
+ assert False
+ elif (
+ (ifm.dtype == DataType.int8 and ofm.dtype == DataType.int8)
+ or (ifm.dtype == DataType.uint8 and ofm.dtype == DataType.int8)
+ or (ifm.dtype == DataType.int8 and ofm.dtype == DataType.uint8)
+ ):
+ # Create NOP performing the RESCALE
+ avgpool_op = replace_rescale_with_avg_pool(op)
+ avgpool_op.rounding_mode = rounding_mode
+
+ if per_channel:
+ # TODO
+ avgpool_op.explicit_scaling = explicit_scaling
+ print("Warning, unsupported TOSA Rescale")
+ assert False
+ else:
+ avgpool_op.explicit_scaling = explicit_scaling
+ elif prev_op.type == Op.Add:
+ # Check that the operations before the Add which creates the IFMs
+ # are Op.Rescale that we can fuse into the add
+ rescale_1 = prev_op.ifm.ops[0]
+ rescale_2 = prev_op.ifm2.ops[0]
+
+ if rescale_1.type == Op.Rescale and rescale_2.type == Op.Rescale:
+ # We are assuming the quantization to be the same for IFMs
+ equal_attributes = ["multiplier", "shift", "double_round"]
+ for a in equal_attributes:
+ assert op.attrs[a] == rescale_1.attrs[a] == rescale_2.attrs[a], (
+ f"Only handling equal {a} for all operands "
+ "({op.attrs[a]}, {rescale_1.attrs[a]}, {rescale_2.attrs[a]}) "
+ "for all the rescale operations to be fused with Add!"
+ )
+
+ assert rescale_1.attrs["input_zp"] == rescale_2.attrs["input_zp"], (
+ f"Only handling equal input_zp ({rescale_1.attrs['input_zp']}!={rescale_2.attrs['input_zp']}) "
+ "for the rescale operations to be fused with Add!"
+ )
+ for op in [rescale_1, rescale_2]:
+ assert op.attrs["output_zp"] == 0, ""
+ assert op.attrs["per_channel"] is False, "per channel quantization is not supported."
+
+ # Create a new add op to set the rescaled ifms and ofm
+ add_op = create_add_nop(prev_op.name + "_fused_rescales")
+ add_op.type = Op.Add
+
+ # set the IFMs and OFM for the cloned operation
+ add_op.set_output_tensor(ofm)
+ add_op.add_input_tensor(rescale_1.ifm)
+ add_op.add_input_tensor(rescale_2.ifm)
+ add_op.set_ifm_ofm_shapes()
+
+ # Remove the consumption of the IFMs to the Add
+ # since we are pruning them from the graph
+ for i, c in enumerate(prev_op.ifm.consumers()):
+ if c == rescale_1:
+ prev_op.ifm.consumers().pop(i)
+ for i, c in enumerate(prev_op.ifm2.consumers()):
+ if c == rescale_2:
+ prev_op.ifm2.consumers().pop(i)
+
+ DebugDatabase.add_optimised(prev_op, op)
+ DebugDatabase.add_optimised(prev_op, rescale_1)
+ DebugDatabase.add_optimised(prev_op, rescale_2)
+ op = add_op
+ else:
+ print("Warning, unsupported fusing of TOSA Rescale with Add.")
+ assert False
+ else:
+ print("Warning, unsupported fusing of TOSA Rescale previous operator is of type:", prev_op.type)
+ assert False
return op