aboutsummaryrefslogtreecommitdiff
path: root/src/mlia/nn/rewrite/core/graph_edit/join.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/mlia/nn/rewrite/core/graph_edit/join.py')
-rw-r--r--src/mlia/nn/rewrite/core/graph_edit/join.py128
1 file changed, 128 insertions, 0 deletions
diff --git a/src/mlia/nn/rewrite/core/graph_edit/join.py b/src/mlia/nn/rewrite/core/graph_edit/join.py
new file mode 100644
index 0000000..758f4cf
--- /dev/null
+++ b/src/mlia/nn/rewrite/core/graph_edit/join.py
@@ -0,0 +1,128 @@
+# SPDX-FileCopyrightText: Copyright 2023, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+import os
+
+os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
+import tensorflow as tf
+
+tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
+
+from mlia.nn.rewrite.core.utils.utils import load, save
+
+
def join_models(input_src, input_dst, output_file, subgraph_src=0, subgraph_dst=0):
    """Join one model file into another and write the combined model.

    Subgraph ``subgraph_src`` of the model at ``input_src`` is merged into
    subgraph ``subgraph_dst`` of the model at ``input_dst``; the joined model
    is then saved to ``output_file``.
    """
    source = load(input_src)
    destination = load(input_dst)
    join_subgraphs(
        source,
        source.subgraphs[subgraph_src],
        destination,
        destination.subgraphs[subgraph_dst],
    )
    save(destination, output_file)
+
+
def join_subgraphs(src_model, src_subgraph, dst_model, dst_subgraph):
    """Copy subgraph src into subgraph dst, connecting tensors with the same names.

    Inputs of one subgraph that match outputs of the other (matched by tensor
    name) become internal tensors of the joined graph. ``dst_model`` /
    ``dst_subgraph`` are mutated in place to hold the combined result; the
    source model/subgraph are also mutated (indices relabelled) and should not
    be reused afterwards.
    """
    # Find inputs that match outputs in the other graph and vice versa.
    # dst_to_src: src inputs fed by dst outputs; src_to_dst: dst inputs fed
    # by src outputs.
    dst_to_src = {
        i: o
        for i in src_subgraph.inputs
        for o in dst_subgraph.outputs
        if src_subgraph.tensors[i].name == dst_subgraph.tensors[o].name
    }

    src_to_dst = {
        o: i
        for i in dst_subgraph.inputs
        for o in src_subgraph.outputs
        if dst_subgraph.tensors[i].name == src_subgraph.tensors[o].name
    }

    # Connections in both directions would make the joined graph cyclic.
    assert not (src_to_dst and dst_to_src), (
        "Source and destination subgraphs appear to connect in a loop: "
        "%d tensors from src to dst, %d tensors from dst to src"
        % (len(src_to_dst), len(dst_to_src))
    )

    # Relabel matched input/output tensors between graphs
    tensor_relabel = src_to_dst if src_to_dst else dst_to_src

    # Remove matched inputs/outputs as these will now become internal tensors
    if src_to_dst:
        src_subgraph.outputs = [
            o for o in src_subgraph.outputs if o not in tensor_relabel.keys()
        ]
        dst_subgraph.inputs = [
            i for i in dst_subgraph.inputs if i not in tensor_relabel.values()
        ]
    else:
        src_subgraph.inputs = [
            i for i in src_subgraph.inputs if i not in tensor_relabel.keys()
        ]
        dst_subgraph.outputs = [
            o for o in dst_subgraph.outputs if o not in tensor_relabel.values()
        ]

    # Matched tensors share storage: alias the src tensor's buffer to the
    # corresponding dst tensor's buffer.
    buffer_relabel = {
        src_subgraph.tensors[i].buffer: dst_subgraph.tensors[o].buffer
        for i, o in tensor_relabel.items()
    }

    # Src tensors that were not matched must be copied across.
    used_tensors = [
        t for i, t in enumerate(src_subgraph.tensors) if i not in tensor_relabel
    ]

    used_buffer_ids = [t.buffer for t in used_tensors]

    def opcode_data(code):
        # Identity key for de-duplicating operator codes across the two models.
        return (
            code.builtinCode,
            code.deprecatedBuiltinCode,
            code.customCode,
            code.version,
        )

    # Map each src opcode index onto an identical dst opcode where one exists.
    opcode_relabel = {
        s: d
        for s in range(len(src_model.operatorCodes))
        for d in range(len(dst_model.operatorCodes))
        if opcode_data(src_model.operatorCodes[s])
        == opcode_data(dst_model.operatorCodes[d])
    }

    # operator order defines execution schedule so must reflect the
    # inputs/outputs dependencies: the producing subgraph's operators go first
    if dst_to_src:
        dst_subgraph.operators += src_subgraph.operators
    else:
        dst_subgraph.operators = src_subgraph.operators + dst_subgraph.operators

    # Append the unmatched src tensors/opcodes to dst, extending both maps so
    # every src index now has a dst index.
    append_relabel(src_subgraph.tensors, dst_subgraph.tensors, tensor_relabel)
    append_relabel(src_model.operatorCodes, dst_model.operatorCodes, opcode_relabel)

    # Some files have ops with -1 input tensors; leave unchanged
    tensor_relabel[-1] = -1

    # Copy src buffers that did not merge with an existing dst buffer.
    for buffer_id in used_buffer_ids:
        if buffer_id not in buffer_relabel:
            buffer_relabel[buffer_id] = len(dst_model.buffers)
            dst_model.buffers.append(src_model.buffers[buffer_id])

    # Rewrite all src operator references through the relabel maps.
    for operator in src_subgraph.operators:
        operator.inputs = [tensor_relabel[t] for t in operator.inputs]
        operator.outputs = [tensor_relabel[t] for t in operator.outputs]
        operator.opcodeIndex = opcode_relabel[operator.opcodeIndex]

    for tensor in used_tensors:
        tensor.buffer = buffer_relabel[tensor.buffer]

    src_subgraph.inputs = [tensor_relabel[t] for t in src_subgraph.inputs]
    src_subgraph.outputs = [tensor_relabel[t] for t in src_subgraph.outputs]

    # The joined graph exposes the union of the remaining inputs/outputs.
    dst_subgraph.inputs = list(set(src_subgraph.inputs).union(dst_subgraph.inputs))
    dst_subgraph.outputs = list(set(src_subgraph.outputs).union(dst_subgraph.outputs))
+
+
def append_relabel(src, dst, map=None):
    """Append unmapped items of ``src`` to ``dst``, recording their new indices.

    Args:
        src: sequence whose items may need copying into ``dst``.
        dst: list extended in place with each item of ``src`` not already
            covered by ``map``.
        map: optional dict of src-index -> dst-index; pre-existing entries are
            left untouched and their items are NOT appended. The parameter
            name shadows the ``map`` builtin but is kept for backward
            compatibility with existing keyword callers.

    Returns:
        The mapping dict (the one passed in, if any), in which every index of
        ``src`` maps to an index into ``dst``.
    """
    if map is None:
        map = {}
    for index, item in enumerate(src):
        if index not in map:
            # Item is new to dst: append it and remember where it landed.
            map[index] = len(dst)
            dst.append(item)
    return map