diff options
author | Tim Hall <tim.hall@arm.com> | 2020-04-27 18:20:16 +0100 |
---|---|---|
committer | Tim Hall <tim.hall@arm.com> | 2020-04-29 13:00:51 +0100 |
commit | 79d07d2cbf1c5013ab40bb46a6ccd4c569966536 (patch) | |
tree | 410d17239b417be5593b3e6800001b797f8d3f98 /ethosu/vela/rewrite_graph.py | |
parent | 47bca71566d4d10e48f5a4d66e1130b8bf60700d (diff) | |
download | ethos-u-vela-79d07d2cbf1c5013ab40bb46a6ccd4c569966536.tar.gz |
Add Vela codebase0.1.0
- Added modules ethosu.vela and ethosu.mlw_codec.
- Added README and various configuration files.
Change-Id: I3690f8c8f5966306ecddaeb2793c30ca9c6e2eee
Diffstat (limited to 'ethosu/vela/rewrite_graph.py')
-rw-r--r-- | ethosu/vela/rewrite_graph.py | 171 |
1 files changed, 171 insertions, 0 deletions
diff --git a/ethosu/vela/rewrite_graph.py b/ethosu/vela/rewrite_graph.py new file mode 100644 index 00000000..e6e24e62 --- /dev/null +++ b/ethosu/vela/rewrite_graph.py @@ -0,0 +1,171 @@ +# Copyright (C) 2020 Arm Limited or its affiliates. All rights reserved. +# +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the License); you may +# not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an AS IS BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +# Description: +# Functions for abstracting out the traversal and rewriting of graphs so that the optimisation passes can focus on the +# correct operation. +# +# Requires two lists, one of functions that rewrite Tensors, and one of functions that rewrite Operations. +# +# Pre-order traversal, this supports rewrites. Therefore, functions can return something other than the original value. +# +# Post-order traversal, this does not support rewrites. Therefore, functions must return the original value. + + +def rewrite_graph_pre_order(sg, arch, tensor_rewrite_list, op_rewrite_list, rewrite_unsupported=True): + + op_visit_dict = dict() + tens_visit_dict = dict() + + def visit_op(op): + if op in op_visit_dict: + return op_visit_dict[op] + res = op + prev_res = None + while prev_res != res: + prev_res = res + for rewrite in op_rewrite_list: + if res.run_on_npu or rewrite_unsupported: + res = rewrite(res, arch) + + op_visit_dict[op] = res + op_visit_dict[res] = res + + inputs = res.inputs + res.inputs = [] + for tens in inputs: + res.inputs.append(visit_tens(tens)) + + outputs = res.outputs + res.outputs = [] + for tens in outputs: + res.outputs.append(visit_tens(tens)) + + return res + + def visit_tens(tens): + if tens in tens_visit_dict: + return tens_visit_dict[tens] + + res = tens + prev_res = None + while prev_res != res: + prev_res = res + for rewrite in tensor_rewrite_list: + res = rewrite(res, arch) + + tens_visit_dict[tens] = res + tens_visit_dict[res] = res + + ops = res.ops + res.ops = [] + for op in ops: + res.ops.append(visit_op(op)) + return res + + sg.output_tensors = [visit_tens(tens) for tens in sg.output_tensors] + sg.refresh_after_modification() + + return sg + + +def visit_graph_post_order(sg, arch, tensor_visit_list, op_visit_list): + + op_visit_dict = dict() + tens_visit_dict = dict() + + def visit_op(op): + if op in op_visit_dict: + return op_visit_dict[op] + op_visit_dict[op] = op + + for tens in op.inputs: + visit_tens(tens) + + for visit in op_visit_list: + visit(op, arch) + + for tens in op.outputs: + visit_tens(tens) + + return op + + def visit_tens(tens): + if tens in tens_visit_dict: + return tens_visit_dict[tens] + + tens_visit_dict[tens] = tens + + for op in tens.ops: + visit_op(op) + + for visit in tensor_visit_list: + visit(tens, arch) + + return tens + + for tens in sg.output_tensors: + visit_tens(tens) + + sg.refresh_after_modification() + + return sg + + +def verify_graph_health(nng): + + for sg in nng.subgraphs: + verify_subgraph_health(sg) + + return True + + +def verify_subgraph_health(sg): + op_visit_dict = dict() + tens_visit_dict = dict() + + def visit_op(op): + if op in op_visit_dict: + return op_visit_dict[op] + op_visit_dict[op] = op + + for tens in op.inputs: + assert op in tens.consumers() + visit_tens(tens) + + for tens in op.outputs: + assert op in tens.ops + visit_tens(tens) + + return op + + def visit_tens(tens): + if tens in tens_visit_dict: + return tens_visit_dict[tens] + + tens_visit_dict[tens] = tens + + for op in tens.ops: + assert tens in op.outputs + visit_op(op) + + return tens + + for tens in sg.output_tensors: + visit_tens(tens) + + return True |