diff options
-rw-r--r-- | ethosu/vela/nn_graph.py | 12 |
1 files changed, 11 insertions, 1 deletions
diff --git a/ethosu/vela/nn_graph.py b/ethosu/vela/nn_graph.py index 92c7e1b8..3c87f9be 100644 --- a/ethosu/vela/nn_graph.py +++ b/ethosu/vela/nn_graph.py @@ -182,12 +182,22 @@ class Subgraph: def update_consumers(self): visit_op_set = set() visit_tensor_set = set() + sg_passes_op_set = set() self.input_tensors = [] + for ps in self.passes: + for op in ps.ops: + sg_passes_op_set.add(op) + print_visit = False def visit_op(op): - if op in visit_op_set: + if op in visit_op_set or (sg_passes_op_set and op not in sg_passes_op_set): + # Op already visited or op is not part of a pass in this subgraph + # Typcial case when op is not part of this subgraph but is visited anyway are concat ops + # that are split up into different subgraphs (several avgpool). Since they share the same + # output the avgpool that do not belong to this subgraph will be traversed which + # should be avoided. return visit_op_set.add(op) |