aboutsummaryrefslogtreecommitdiff
path: root/ethosu/vela/tosa_graph_optimiser.py
diff options
context:
space:
mode:
Diffstat (limited to 'ethosu/vela/tosa_graph_optimiser.py')
-rw-r--r--ethosu/vela/tosa_graph_optimiser.py8
1 files changed, 4 insertions, 4 deletions
diff --git a/ethosu/vela/tosa_graph_optimiser.py b/ethosu/vela/tosa_graph_optimiser.py
index d32955d5..954ac68f 100644
--- a/ethosu/vela/tosa_graph_optimiser.py
+++ b/ethosu/vela/tosa_graph_optimiser.py
@@ -182,7 +182,7 @@ def insert_add_copy_op_after_tens(tens, ifm_ofm_shape):
def fix_sg_input_output_tosa(op, arch, nng):
- if not op.run_on_npu or op.type != Op.Reshape:
+ if not op.run_on_npu or op.type not in (Op.Reshape, Op.Identity):
return op
# For the Reshape operators we want to remove, tensors are removed.
@@ -306,8 +306,8 @@ def rewrite_concat(op):
assert op.ofm_shapes[0][axis_4D] == offset
-def remove_reshapes(op, arch):
- if op.run_on_npu and op.type == Op.Reshape:
+def remove_memory_ops(op, arch):
+ if op.run_on_npu and op.type in (Op.Reshape, Op.Identity):
bypass_memory_only_ops(op)
@@ -820,7 +820,7 @@ def tosa_optimise_graph(nng, arch):
# Removal of reshapes
for sg in nng.subgraphs:
- rewrite_graph.visit_graph_post_order(sg.output_tensors, arch, [], [remove_reshapes])
+ rewrite_graph.visit_graph_post_order(sg.output_tensors, arch, [], [remove_memory_ops])
sg.refresh_after_modification()
# Decomposing of elementwise