aboutsummaryrefslogtreecommitdiff
path: root/ethosu/vela/rewrite_graph.py
diff options
context:
space:
mode:
authorLouis Verhaard <louis.verhaard@arm.com>2020-10-14 08:32:41 +0200
committerLouis Verhaard <louis.verhaard@arm.com>2020-10-20 08:50:29 +0200
commit17afa2837ad366f2da32e2bc0e2659ebb35bd1d5 (patch)
tree7329fe546be4e2a95e205daf83637c7927bf7684 /ethosu/vela/rewrite_graph.py
parent6e827082524af57bf04833c30754384b46216e59 (diff)
downloadethos-u-vela-17afa2837ad366f2da32e2bc0e2659ebb35bd1d5.tar.gz
MLBEDSW-3268: Refactor mark_tensors
- Refactored mark_tensor_purpose - Initial weight compression is now always done in insert_dma - Removed mark_tensor_format Change-Id: Ic719b9bcd1d27e1390d7b9ce8cd21795139ec814 Signed-off-by: Louis Verhaard <louis.verhaard@arm.com>
Diffstat (limited to 'ethosu/vela/rewrite_graph.py')
-rw-r--r--ethosu/vela/rewrite_graph.py22
1 files changed, 8 insertions, 14 deletions
diff --git a/ethosu/vela/rewrite_graph.py b/ethosu/vela/rewrite_graph.py
index e71b228a..42acaf9b 100644
--- a/ethosu/vela/rewrite_graph.py
+++ b/ethosu/vela/rewrite_graph.py
@@ -82,14 +82,16 @@ def rewrite_graph_pre_order(nng, sg, arch, tensor_rewrite_list, op_rewrite_list,
return sg
-def visit_graph_post_order(sg, arch, tensor_visit_list, op_visit_list):
-
+def visit_graph_post_order(start_tensors, arch, tensor_visit_list, op_visit_list):
+ # Depth-first graph traversal, starting from the given list of tensors
+ # (typically a subgraph's output_tensors).
+ # Visits ops and tensors in input to output order.
op_visit_dict = dict()
tens_visit_dict = dict()
def visit_op(op):
if op in op_visit_dict:
- return op_visit_dict[op]
+ return
op_visit_dict[op] = op
for tens in op.inputs:
@@ -101,11 +103,9 @@ def visit_graph_post_order(sg, arch, tensor_visit_list, op_visit_list):
for tens in op.outputs:
visit_tens(tens)
- return op
-
def visit_tens(tens):
- if tens in tens_visit_dict:
- return tens_visit_dict[tens]
+ if tens is None or tens in tens_visit_dict:
+ return
tens_visit_dict[tens] = tens
@@ -115,15 +115,9 @@ def visit_graph_post_order(sg, arch, tensor_visit_list, op_visit_list):
for visit in tensor_visit_list:
visit(tens, arch)
- return tens
-
- for tens in sg.output_tensors:
+ for tens in start_tensors:
visit_tens(tens)
- sg.refresh_after_modification()
-
- return sg
-
def verify_graph_health(nng):