aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--ethosu/vela/nn_graph.py12
1 files changed, 11 insertions, 1 deletions
diff --git a/ethosu/vela/nn_graph.py b/ethosu/vela/nn_graph.py
index 92c7e1b8..3c87f9be 100644
--- a/ethosu/vela/nn_graph.py
+++ b/ethosu/vela/nn_graph.py
@@ -182,12 +182,22 @@ class Subgraph:
def update_consumers(self):
visit_op_set = set()
visit_tensor_set = set()
+ sg_passes_op_set = set()
self.input_tensors = []
+ for ps in self.passes:
+ for op in ps.ops:
+ sg_passes_op_set.add(op)
+
print_visit = False
def visit_op(op):
- if op in visit_op_set:
+ if op in visit_op_set or (sg_passes_op_set and op not in sg_passes_op_set):
+ # Op already visited or op is not part of a pass in this subgraph
+ # Typcial case when op is not part of this subgraph but is visited anyway are concat ops
+ # that are split up into different subgraphs (several avgpool). Since they share the same
+ # output the avgpool that do not belong to this subgraph will be traversed which
+ # should be avoided.
return
visit_op_set.add(op)