aboutsummaryrefslogtreecommitdiff
path: root/ethosu/vela/rewrite_graph.py
blob: 42acaf9bc9e59e29c53c5d4438284827a5e91dae (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
# Copyright (C) 2020 Arm Limited or its affiliates. All rights reserved.
#
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the License); you may
# not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an AS IS BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Description:
# Functions for abstracting out the traversal and rewriting of graphs so that the optimisation passes can focus on the
# correct operation.
#
# Both traversals take two lists of functions: one applied to Tensors and one applied to Operations.
#
# Pre-order traversal supports rewrites: each function may return a replacement for the value it was given.
#
# Post-order traversal does not support rewrites: the functions are called only for their side effects, and their
# return values are ignored.


def rewrite_graph_pre_order(nng, sg, arch, tensor_rewrite_list, op_rewrite_list, rewrite_unsupported=True):
    """Rewrite the graph reachable from ``sg.output_tensors`` in pre-order.

    Starting from each output tensor, every reachable tensor and operation is
    passed through its rewrite functions *before* its neighbours are visited.
    A rewrite may therefore return a replacement object, and the replacement is
    what gets linked back into the graph (into ``op.inputs``, ``op.outputs``,
    ``tens.ops`` and ``sg.output_tensors``).

    Args:
        nng: passed unchanged as the third argument to every rewrite function.
        sg: the subgraph to rewrite. Its ``output_tensors`` list is replaced and
            ``refresh_after_modification()`` is then called on it.
        arch: passed unchanged as the second argument to every rewrite function.
        tensor_rewrite_list: callables invoked as ``rewrite(tens, arch, nng)``;
            each returns the tensor to use in place of ``tens`` (possibly
            ``tens`` itself).
        op_rewrite_list: callables invoked as ``rewrite(op, arch, nng)``; each
            returns the op to use in place of ``op`` (possibly ``op`` itself).
        rewrite_unsupported: when False, an op rewrite is only applied while the
            current op's ``run_on_npu`` attribute is truthy.

    Returns:
        ``sg``, modified in place.

    NOTE(review): the traversal is recursive (visit_op <-> visit_tens), so very
    deep graphs may exceed Python's recursion limit.
    """

    # Memo tables: each maps both the original object and its replacement to the
    # replacement, so every node is rewritten and expanded at most once.
    op_visit_dict = dict()
    tens_visit_dict = dict()

    def visit_op(op):
        if op in op_visit_dict:
            return op_visit_dict[op]
        res = op
        prev_res = None
        # Re-run the whole rewrite list until a full pass leaves the op
        # unchanged (compared with !=). A rewrite that keeps returning a new
        # object on every call would make this loop forever.
        while prev_res != res:
            prev_res = res
            for rewrite in op_rewrite_list:
                # Checked per rewrite, against the current (possibly already
                # rewritten) op.
                if res.run_on_npu or rewrite_unsupported:
                    res = rewrite(res, arch, nng)

        # Recorded before visiting neighbours, so reaching this op again through
        # one of its own tensors returns immediately instead of recursing.
        op_visit_dict[op] = res
        op_visit_dict[res] = res

        # Detach the old input list and rebuild it from the visited (possibly
        # rewritten) tensors. The list is emptied before recursing, so code that
        # runs during the recursion sees a partially rebuilt res.inputs.
        inputs = res.inputs
        res.inputs = []
        for tens in inputs:
            res.inputs.append(visit_tens(tens))

        # Same treatment for the outputs.
        outputs = res.outputs
        res.outputs = []
        for tens in outputs:
            res.outputs.append(visit_tens(tens))

        return res

    def visit_tens(tens):
        if tens in tens_visit_dict:
            return tens_visit_dict[tens]

        # There is no None check before this loop, so whatever appears in
        # op.inputs (including None, if present) is passed to the tensor rewrites.
        res = tens
        prev_res = None
        # Fixed-point iteration, as in visit_op.
        while prev_res != res:
            prev_res = res
            for rewrite in tensor_rewrite_list:
                res = rewrite(res, arch, nng)

        tens_visit_dict[tens] = res
        tens_visit_dict[res] = res

        # Only expand the ops in res.ops when the (rewritten) tensor is truthy,
        # e.g. not None. The ops list is rebuilt from the visited ops.
        if res:
            ops = res.ops
            res.ops = []
            for op in ops:
                res.ops.append(visit_op(op))
        return res

    sg.output_tensors = [visit_tens(tens) for tens in sg.output_tensors]
    sg.refresh_after_modification()

    return sg


def visit_graph_post_order(start_tensors, arch, tensor_visit_list, op_visit_list):
    """Walk the graph depth-first from ``start_tensors``, calling visitors input-to-output.

    For a tensor, the ops in ``tens.ops`` are walked first and then the tensor's
    visitors run. For an op, its inputs are walked, then the op's visitors run,
    then its outputs are walked. Every op and tensor is visited at most once, and
    ``None`` tensors (e.g. unused entries of ``op.inputs``) are skipped. This
    traversal does not rewrite anything: visitor return values are ignored.

    Args:
        start_tensors: iterable of tensors to start from (typically a subgraph's
            output_tensors).
        arch: passed unchanged as the second argument to every visitor.
        tensor_visit_list: callables invoked as ``visit(tens, arch)``.
        op_visit_list: callables invoked as ``visit(op, arch)``.
    """
    seen_ops = set()
    seen_tensors = set()

    def _walk_tensor(tensor):
        if tensor is None:
            return
        if tensor in seen_tensors:
            return
        seen_tensors.add(tensor)

        for producer in tensor.ops:
            _walk_op(producer)

        for callback in tensor_visit_list:
            callback(tensor, arch)

    def _walk_op(operation):
        if operation in seen_ops:
            return
        # Mark before recursing so an op reached again via its own outputs stops here.
        seen_ops.add(operation)

        for inp in operation.inputs:
            _walk_tensor(inp)

        for callback in op_visit_list:
            callback(operation, arch)

        for out in operation.outputs:
            _walk_tensor(out)

    for start in start_tensors:
        _walk_tensor(start)


def verify_graph_health(nng):
    """Run verify_subgraph_health() on every subgraph in ``nng.subgraphs``.

    Inconsistencies surface as whatever verify_subgraph_health() raises. When
    this function returns, it returns True, presumably so that callers can wrap
    the call in an ``assert`` statement.
    """
    subgraphs = nng.subgraphs
    for subgraph in subgraphs:
        verify_subgraph_health(subgraph)
    return True


def verify_subgraph_health(sg):
    """Check that op/tensor links in ``sg`` are mutually consistent.

    Starting from ``sg.output_tensors``, every reachable op and tensor is checked
    once:

    * for each op, the op appears in ``tens.consumers()`` for every truthy input
      ``tens`` (falsy inputs such as None are skipped);
    * for each op, the op appears in ``tens.ops`` for every output ``tens``;
    * for each tensor, the tensor appears in ``op.outputs`` for every op in
      ``tens.ops``.

    The checks are plain ``assert`` statements, so they are skipped when Python
    runs with ``-O``.

    Returns:
        True. A failed check raises AssertionError instead.
    """
    checked_ops = set()
    checked_tensors = set()

    def _check_tensor(tensor):
        if tensor in checked_tensors:
            return
        checked_tensors.add(tensor)

        for producer in tensor.ops:
            assert tensor in producer.outputs
            _check_op(producer)

    def _check_op(operation):
        if operation in checked_ops:
            return
        checked_ops.add(operation)

        for inp in operation.inputs:
            if inp:
                assert operation in inp.consumers()
                _check_tensor(inp)

        for out in operation.outputs:
            assert operation in out.ops
            _check_tensor(out)

    for tensor in sg.output_tensors:
        _check_tensor(tensor)

    return True