aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJohan Alfven <johan.alfven@arm.com>2024-04-04 10:08:05 +0200
committerJohan Alfven <johan.alfven@arm.com>2024-04-05 09:02:58 +0200
commitabed3c27c9e02a96017b497a17fe8641c31c0502 (patch)
tree1e7da1a0c0634b394085cc3398dc879d40b06387
parent190b63a6ae6908625dffab203a8137c27aaec5fd (diff)
downloadethos-u-vela-abed3c27c9e02a96017b497a17fe8641c31c0502.tar.gz
MLBEDSW-8885: MLCE: Fix assert in verify_subgraph_health
- Assert triggered due to that the tensor consumer list did not contain expected operators. - The problem happened because a concat op was split into two avgpool ops and these two ops run in separate subgraphs with a cpu node in between. Since the avgpool ops share the same output tensor this caused some corruption to the tensor consumer list when the last subgraph was traversed. - The fix is to ignore ops that do not belong in the subgraph's set of operators (the pass list) when updating the consumers. Change-Id: I4d94b54c77001f6447bec31ec62daeebc9b104f9 Signed-off-by: Johan Alfven <johan.alfven@arm.com>
-rw-r--r--ethosu/vela/nn_graph.py12
1 files changed, 11 insertions, 1 deletions
diff --git a/ethosu/vela/nn_graph.py b/ethosu/vela/nn_graph.py
index 92c7e1b..3c87f9b 100644
--- a/ethosu/vela/nn_graph.py
+++ b/ethosu/vela/nn_graph.py
@@ -182,12 +182,22 @@ class Subgraph:
def update_consumers(self):
visit_op_set = set()
visit_tensor_set = set()
+ sg_passes_op_set = set()
self.input_tensors = []
+ for ps in self.passes:
+ for op in ps.ops:
+ sg_passes_op_set.add(op)
+
print_visit = False
def visit_op(op):
- if op in visit_op_set:
+ if op in visit_op_set or (sg_passes_op_set and op not in sg_passes_op_set):
+ # Op already visited or op is not part of a pass in this subgraph
+ # Typcial case when op is not part of this subgraph but is visited anyway are concat ops
+ # that are split up into different subgraphs (several avgpool). Since they share the same
+ # output the avgpool that do not belong to this subgraph will be traversed which
+ # should be avoided.
return
visit_op_set.add(op)