aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--ethosu/vela/graph_optimiser.py10
1 files changed, 10 insertions, 0 deletions
diff --git a/ethosu/vela/graph_optimiser.py b/ethosu/vela/graph_optimiser.py
index 1a6aaf10..a57ac82e 100644
--- a/ethosu/vela/graph_optimiser.py
+++ b/ethosu/vela/graph_optimiser.py
@@ -307,6 +307,15 @@ def fixup_resizebilinear(op, arch):
return op
+def convert_nop_split_to_identity(op, arch):
+ if op.type == "Split" and op.attrs.get("num_splits") == 1:
+ # the list comprehension should return a list with a single tensor
+ # if it shouldn't, remove_passthrough_tensor will fail appropriately
+ op.inputs = [i for i in op.inputs if i.shape == op.outputs[0].shape]
+ op.type = "Identity"
+ return op
+
+
def fixup_fully_connected_input(op, arch):
if op.type == "FullyConnectedAct":
inp = op.inputs[0]
@@ -956,6 +965,7 @@ def optimise_graph_a(nng, arch, verbose_graph=False):
reorder_depthwise_weights,
fixup_resizebilinear,
fixup_bias_tensors,
+ convert_nop_split_to_identity,
convert_mul_max_to_abs_or_lrelu,
remove_unwanted_reshapes,
convert_lrelu,