aboutsummaryrefslogtreecommitdiff
path: root/ethosu/vela/rewrite_graph.py
diff options
context:
space:
mode:
Diffstat (limited to 'ethosu/vela/rewrite_graph.py')
-rw-r--r--ethosu/vela/rewrite_graph.py171
1 files changed, 171 insertions, 0 deletions
diff --git a/ethosu/vela/rewrite_graph.py b/ethosu/vela/rewrite_graph.py
new file mode 100644
index 00000000..e6e24e62
--- /dev/null
+++ b/ethosu/vela/rewrite_graph.py
@@ -0,0 +1,171 @@
+# Copyright (C) 2020 Arm Limited or its affiliates. All rights reserved.
+#
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the License); you may
+# not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an AS IS BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+# Description:
+# Functions for abstracting out the traversal and rewriting of graphs so that the optimisation passes can focus on the
+# correct operation.
+#
+# Requires two lists, one of functions that rewrite Tensors, and one of functions that rewrite Operations.
+#
+# Pre-order traversal, this supports rewrites. Therefore, functions can return something other than the original value.
+#
+# Post-order traversal, this does not support rewrites. Therefore, functions must return the original value.
+
+
+def rewrite_graph_pre_order(sg, arch, tensor_rewrite_list, op_rewrite_list, rewrite_unsupported=True):
+
+ op_visit_dict = dict()
+ tens_visit_dict = dict()
+
+ def visit_op(op):
+ if op in op_visit_dict:
+ return op_visit_dict[op]
+ res = op
+ prev_res = None
+ while prev_res != res:
+ prev_res = res
+ for rewrite in op_rewrite_list:
+ if res.run_on_npu or rewrite_unsupported:
+ res = rewrite(res, arch)
+
+ op_visit_dict[op] = res
+ op_visit_dict[res] = res
+
+ inputs = res.inputs
+ res.inputs = []
+ for tens in inputs:
+ res.inputs.append(visit_tens(tens))
+
+ outputs = res.outputs
+ res.outputs = []
+ for tens in outputs:
+ res.outputs.append(visit_tens(tens))
+
+ return res
+
+ def visit_tens(tens):
+ if tens in tens_visit_dict:
+ return tens_visit_dict[tens]
+
+ res = tens
+ prev_res = None
+ while prev_res != res:
+ prev_res = res
+ for rewrite in tensor_rewrite_list:
+ res = rewrite(res, arch)
+
+ tens_visit_dict[tens] = res
+ tens_visit_dict[res] = res
+
+ ops = res.ops
+ res.ops = []
+ for op in ops:
+ res.ops.append(visit_op(op))
+ return res
+
+ sg.output_tensors = [visit_tens(tens) for tens in sg.output_tensors]
+ sg.refresh_after_modification()
+
+ return sg
+
+
+def visit_graph_post_order(sg, arch, tensor_visit_list, op_visit_list):
+
+ op_visit_dict = dict()
+ tens_visit_dict = dict()
+
+ def visit_op(op):
+ if op in op_visit_dict:
+ return op_visit_dict[op]
+ op_visit_dict[op] = op
+
+ for tens in op.inputs:
+ visit_tens(tens)
+
+ for visit in op_visit_list:
+ visit(op, arch)
+
+ for tens in op.outputs:
+ visit_tens(tens)
+
+ return op
+
+ def visit_tens(tens):
+ if tens in tens_visit_dict:
+ return tens_visit_dict[tens]
+
+ tens_visit_dict[tens] = tens
+
+ for op in tens.ops:
+ visit_op(op)
+
+ for visit in tensor_visit_list:
+ visit(tens, arch)
+
+ return tens
+
+ for tens in sg.output_tensors:
+ visit_tens(tens)
+
+ sg.refresh_after_modification()
+
+ return sg
+
+
+def verify_graph_health(nng):
+
+ for sg in nng.subgraphs:
+ verify_subgraph_health(sg)
+
+ return True
+
+
+def verify_subgraph_health(sg):
+ op_visit_dict = dict()
+ tens_visit_dict = dict()
+
+ def visit_op(op):
+ if op in op_visit_dict:
+ return op_visit_dict[op]
+ op_visit_dict[op] = op
+
+ for tens in op.inputs:
+ assert op in tens.consumers()
+ visit_tens(tens)
+
+ for tens in op.outputs:
+ assert op in tens.ops
+ visit_tens(tens)
+
+ return op
+
+ def visit_tens(tens):
+ if tens in tens_visit_dict:
+ return tens_visit_dict[tens]
+
+ tens_visit_dict[tens] = tens
+
+ for op in tens.ops:
+ assert tens in op.outputs
+ visit_op(op)
+
+ return tens
+
+ for tens in sg.output_tensors:
+ visit_tens(tens)
+
+ return True