diff options
author | Louis Verhaard <louis.verhaard@arm.com> | 2020-10-14 08:32:41 +0200 |
---|---|---|
committer | Louis Verhaard <louis.verhaard@arm.com> | 2020-10-20 08:50:29 +0200 |
commit | 17afa2837ad366f2da32e2bc0e2659ebb35bd1d5 (patch) | |
tree | 7329fe546be4e2a95e205daf83637c7927bf7684 /ethosu/vela/rewrite_graph.py | |
parent | 6e827082524af57bf04833c30754384b46216e59 (diff) | |
download | ethos-u-vela-17afa2837ad366f2da32e2bc0e2659ebb35bd1d5.tar.gz |
MLBEDSW-3268: Refactor mark_tensors
- Refactored mark_tensor_purpose
- Initial weight compression is now always done in insert_dma
- Removed mark_tensor_format
Change-Id: Ic719b9bcd1d27e1390d7b9ce8cd21795139ec814
Signed-off-by: Louis Verhaard <louis.verhaard@arm.com>
Diffstat (limited to 'ethosu/vela/rewrite_graph.py')
-rw-r--r-- | ethosu/vela/rewrite_graph.py | 22 |
1 files changed, 8 insertions, 14 deletions
diff --git a/ethosu/vela/rewrite_graph.py b/ethosu/vela/rewrite_graph.py index e71b228a..42acaf9b 100644 --- a/ethosu/vela/rewrite_graph.py +++ b/ethosu/vela/rewrite_graph.py @@ -82,14 +82,16 @@ def rewrite_graph_pre_order(nng, sg, arch, tensor_rewrite_list, op_rewrite_list, return sg -def visit_graph_post_order(sg, arch, tensor_visit_list, op_visit_list): - +def visit_graph_post_order(start_tensors, arch, tensor_visit_list, op_visit_list): + # Depth-first graph traversal, starting from the given list of tensors + # (typically a subgraph's output_tensors). + # Visits ops and tensors in input to output order. op_visit_dict = dict() tens_visit_dict = dict() def visit_op(op): if op in op_visit_dict: - return op_visit_dict[op] + return op_visit_dict[op] = op for tens in op.inputs: @@ -101,11 +103,9 @@ def visit_graph_post_order(sg, arch, tensor_visit_list, op_visit_list): for tens in op.outputs: visit_tens(tens) - return op - def visit_tens(tens): - if tens in tens_visit_dict: - return tens_visit_dict[tens] + if tens is None or tens in tens_visit_dict: + return tens_visit_dict[tens] = tens @@ -115,15 +115,9 @@ def visit_graph_post_order(sg, arch, tensor_visit_list, op_visit_list): for visit in tensor_visit_list: visit(tens, arch) - return tens - - for tens in sg.output_tensors: + for tens in start_tensors: visit_tens(tens) - sg.refresh_after_modification() - - return sg - def verify_graph_health(nng): |