# SPDX-FileCopyrightText: Copyright 2023, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Join two TFLite models by connecting tensors that share the same names."""
import os

# Suppress TensorFlow's C++ logging; must be set before TensorFlow is imported.
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"

import tensorflow as tf

tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)

from mlia.nn.rewrite.core.utils.utils import load, save


def join_models(input_src, input_dst, output_file, subgraph_src=0, subgraph_dst=0):
    """Load two TFLite models, join the selected subgraphs and save the result."""
    src_model = load(input_src)
    dst_model = load(input_dst)
    src_subgraph = src_model.subgraphs[subgraph_src]
    dst_subgraph = dst_model.subgraphs[subgraph_dst]
    join_subgraphs(src_model, src_subgraph, dst_model, dst_subgraph)
    save(dst_model, output_file)


def join_subgraphs(src_model, src_subgraph, dst_model, dst_subgraph):
    """Copy subgraph src into subgraph dst, connecting tensors with the same names."""
    # Find inputs that match outputs in the other graph and vice versa
    dst_to_src = {
        i: o
        for i in src_subgraph.inputs
        for o in dst_subgraph.outputs
        if src_subgraph.tensors[i].name == dst_subgraph.tensors[o].name
    }

    src_to_dst = {
        o: i
        for i in dst_subgraph.inputs
        for o in src_subgraph.outputs
        if dst_subgraph.tensors[i].name == src_subgraph.tensors[o].name
    }

    assert not (src_to_dst and dst_to_src), (
        "Source and destination subgraphs appear to connect in a loop: "
        f"{len(src_to_dst)} tensors from src to dst, "
        f"{len(dst_to_src)} tensors from dst to src"
    )
    # Relabel matched input/output tensors between graphs
    tensor_relabel = src_to_dst if src_to_dst else dst_to_src

    # Remove matched inputs/outputs as these will now become internal tensors
    if src_to_dst:
        src_subgraph.outputs = [
            o for o in src_subgraph.outputs if o not in tensor_relabel.keys()
        ]
        dst_subgraph.inputs = [
            i for i in dst_subgraph.inputs if i not in tensor_relabel.values()
        ]
    else:
        src_subgraph.inputs = [
            i for i in src_subgraph.inputs if i not in tensor_relabel.keys()
        ]
        dst_subgraph.outputs = [
            o for o in dst_subgraph.outputs if o not in tensor_relabel.values()
        ]

    buffer_relabel = {
        src_subgraph.tensors[i].buffer: dst_subgraph.tensors[o].buffer
        for i, o in tensor_relabel.items()
    }

    used_tensors = [
        t for i, t in enumerate(src_subgraph.tensors) if i not in tensor_relabel
    ]
    used_buffer_ids = [t.buffer for t in used_tensors]

    def opcode_data(code):
        return (
            code.builtinCode,
            code.deprecatedBuiltinCode,
            code.customCode,
            code.version,
        )

    opcode_relabel = {
        s: d
        for s in range(len(src_model.operatorCodes))
        for d in range(len(dst_model.operatorCodes))
        if opcode_data(src_model.operatorCodes[s])
        == opcode_data(dst_model.operatorCodes[d])
    }
    # Operator order defines the execution schedule, so it must reflect the
    # input/output dependencies between the two subgraphs
    if dst_to_src:
        dst_subgraph.operators += src_subgraph.operators
    else:
        dst_subgraph.operators = src_subgraph.operators + dst_subgraph.operators

    append_relabel(src_subgraph.tensors, dst_subgraph.tensors, tensor_relabel)
    append_relabel(src_model.operatorCodes, dst_model.operatorCodes, opcode_relabel)

    # Some files have ops with -1 input tensors; leave those unchanged
    tensor_relabel[-1] = -1

    for i in used_buffer_ids:
        if i not in buffer_relabel:
            buffer_relabel[i] = len(dst_model.buffers)
            dst_model.buffers.append(src_model.buffers[i])

    for operator in src_subgraph.operators:
        operator.inputs = [tensor_relabel[t] for t in operator.inputs]
        operator.outputs = [tensor_relabel[t] for t in operator.outputs]
        operator.opcodeIndex = opcode_relabel[operator.opcodeIndex]

    for tensor in used_tensors:
        tensor.buffer = buffer_relabel[tensor.buffer]

    src_subgraph.inputs = [tensor_relabel[t] for t in src_subgraph.inputs]
    src_subgraph.outputs = [tensor_relabel[t] for t in src_subgraph.outputs]

    dst_subgraph.inputs = list(set(src_subgraph.inputs).union(dst_subgraph.inputs))
    dst_subgraph.outputs = list(set(src_subgraph.outputs).union(dst_subgraph.outputs))


def append_relabel(src, dst, mapping=None):
    """Append items of src that are not already mapped to dst.

    Return a dict mapping each index in src to its new index in dst; for
    example, append_relabel(["c"], ["a", "b"]) extends dst to ["a", "b", "c"]
    and returns {0: 2}.
    """
    if mapping is None:
        mapping = {}
    for i, x in enumerate(src):
        if i not in mapping:
            mapping[i] = len(dst)
            dst.append(x)
    return mapping
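

# Illustrative sketch only (not part of the original module): a minimal
# command-line entry point showing how join_models might be invoked. It
# assumes two .tflite files whose subgraphs share tensor names, e.g. an
# output tensor of the source model is named like an input tensor of the
# destination model; the file paths are supplied by the caller.
if __name__ == "__main__":
    import sys

    if len(sys.argv) != 4:
        sys.exit(f"usage: {sys.argv[0]} <src.tflite> <dst.tflite> <output.tflite>")
    join_models(sys.argv[1], sys.argv[2], sys.argv[3])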