aboutsummaryrefslogtreecommitdiff
path: root/src/mlia/nn/rewrite/core/graph_edit/join.py
blob: 2530ec8958612462241e5357dbb6ae1a9d8f33f1 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
# SPDX-FileCopyrightText: Copyright 2023, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Join module."""
from __future__ import annotations

import os
from pathlib import Path

import tensorflow as tf
from tensorflow.lite.python.schema_py_generated import ModelT
from tensorflow.lite.python.schema_py_generated import OperatorCodeT
from tensorflow.lite.python.schema_py_generated import SubGraphT

from mlia.nn.rewrite.core.utils.utils import load
from mlia.nn.rewrite.core.utils.utils import save

# Quieten TensorFlow at import time of this module: the environment variable
# is presumably TensorFlow's native log-level filter, and the second call
# restricts the tf.compat.v1 Python logger to ERROR and above.
# NOTE(review): the variable is assigned after `import tensorflow` has already
# run in this file -- confirm it is still honoured when set at this point.
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)


def join_models(
    input_src: str | Path,
    input_dst: str | Path,
    output_file: str | Path,
    subgraph_src: int = 0,
    subgraph_dst: int = 0,
) -> None:
    """Load two model files, join one subgraph of each, and save the result.

    The subgraph at index ``subgraph_src`` of the model loaded from
    ``input_src`` is joined into the subgraph at index ``subgraph_dst`` of the
    model loaded from ``input_dst``. The destination model, which receives the
    joined graph, is then written to ``output_file``.
    """
    source = load(input_src)
    destination = load(input_dst)

    join_subgraphs(
        source,
        source.subgraphs[subgraph_src],
        destination,
        destination.subgraphs[subgraph_dst],
    )

    # Only the destination model is persisted; it now holds the joined graph.
    save(destination, output_file)


def join_subgraphs(
    src_model: ModelT,
    src_subgraph: SubGraphT,
    dst_model: ModelT,
    dst_subgraph: SubGraphT,
) -> None:
    """Join two subgraphs, connecting tensors with the same names.

    The source subgraph is merged into ``dst_subgraph`` in place. Inputs of one
    subgraph whose tensor name equals the name of an output tensor of the other
    subgraph are treated as the same tensor and become internal to the joined
    graph. Source tensors, operator codes and buffers without a counterpart in
    the destination are appended to the destination lists, and the source
    operators are re-indexed to refer to the destination entries.

    Both models are mutated: besides the destination lists growing, the source
    operators, source tensors and the source subgraph's input/output lists are
    relabelled with destination indices.

    Raises:
        AssertionError: if tensors match in both directions (source outputs
            feeding destination inputs *and* destination outputs feeding
            source inputs), i.e. the two subgraphs would form a loop.
    """
    # Find inputs that match outputs in the other graph and vice versa.
    # Both maps are keyed by a tensor index of the source subgraph and hold the
    # index of the same-named tensor in the destination subgraph.
    dst_to_src = {
        src_in: dst_out
        for src_in in src_subgraph.inputs
        for dst_out in dst_subgraph.outputs
        if src_subgraph.tensors[src_in].name == dst_subgraph.tensors[dst_out].name
    }

    src_to_dst = {
        src_out: dst_in
        for dst_in in dst_subgraph.inputs
        for src_out in src_subgraph.outputs
        if dst_subgraph.tensors[dst_in].name == src_subgraph.tensors[src_out].name
    }

    # Raised explicitly (instead of using ``assert``) so the check is not
    # stripped when Python runs with -O; the exception type is unchanged.
    if src_to_dst and dst_to_src:
        raise AssertionError(
            "Source and destination subgraphs appear to connect in a loop: "
            f"{len(src_to_dst)} tensors from src to dst, "
            f"{len(dst_to_src)} tensors from dst to src"
        )

    # Relabel matched input/output tensors between graphs
    tensor_relabel = src_to_dst or dst_to_src

    # Snapshot the matched destination indices before the map is extended with
    # the unmatched tensors further down.
    matched_dst = set(tensor_relabel.values())

    # Remove matched inputs/outputs as these will now become internal tensors
    if src_to_dst:
        src_subgraph.outputs = [
            index for index in src_subgraph.outputs if index not in tensor_relabel
        ]
        dst_subgraph.inputs = [
            index for index in dst_subgraph.inputs if index not in matched_dst
        ]
    else:
        src_subgraph.inputs = [
            index for index in src_subgraph.inputs if index not in tensor_relabel
        ]
        dst_subgraph.outputs = [
            index for index in dst_subgraph.outputs if index not in matched_dst
        ]

    # A matched source tensor reuses the buffer of its destination counterpart.
    buffer_relabel = {
        src_subgraph.tensors[src_idx].buffer: dst_subgraph.tensors[dst_idx].buffer
        for src_idx, dst_idx in tensor_relabel.items()
    }

    # Source tensors without a destination counterpart; these are the ones
    # that get appended to the destination subgraph.
    used_tensors = [
        tensor
        for index, tensor in enumerate(src_subgraph.tensors)
        if index not in tensor_relabel
    ]

    used_buffer_ids = [tensor.buffer for tensor in used_tensors]

    def opcode_data(code: OperatorCodeT) -> tuple:
        """Return the fields that decide whether two operator codes are equal."""
        return (
            code.builtinCode,
            code.deprecatedBuiltinCode,
            code.customCode,
            code.version,
        )

    # Map each source operator code onto an equal destination operator code,
    # if one exists (the last equal destination entry wins).
    opcode_relabel = {
        src_idx: dst_idx
        for src_idx, src_code in enumerate(src_model.operatorCodes)
        for dst_idx, dst_code in enumerate(dst_model.operatorCodes)
        if opcode_data(src_code) == opcode_data(dst_code)
    }

    # operator order defines execution schedule so must reflect
    # the inputs/outputs dependencies
    if dst_to_src:
        dst_subgraph.operators += src_subgraph.operators
    else:
        dst_subgraph.operators = src_subgraph.operators + dst_subgraph.operators

    # Append everything that has no destination counterpart and record the
    # new destination index of each appended entry in the relabel maps.
    append_relabel(src_subgraph.tensors, dst_subgraph.tensors, tensor_relabel)
    append_relabel(src_model.operatorCodes, dst_model.operatorCodes, opcode_relabel)

    # Some files have ops with -1 input tensors; leave unchanged
    tensor_relabel[-1] = -1

    # Copy the buffers of the appended tensors across, once per buffer id.
    for buffer_id in used_buffer_ids:
        if buffer_id not in buffer_relabel:
            buffer_relabel[buffer_id] = len(dst_model.buffers)
            dst_model.buffers.append(src_model.buffers[buffer_id])

    # Re-index the source operators (now part of the destination operator
    # list) so they refer to destination tensors and operator codes.
    for operator in src_subgraph.operators:
        operator.inputs = [tensor_relabel[tensor] for tensor in operator.inputs]
        operator.outputs = [tensor_relabel[tensor] for tensor in operator.outputs]
        operator.opcodeIndex = opcode_relabel[operator.opcodeIndex]

    for tensor in used_tensors:
        tensor.buffer = buffer_relabel[tensor.buffer]

    src_subgraph.inputs = [tensor_relabel[t] for t in src_subgraph.inputs]
    src_subgraph.outputs = [tensor_relabel[t] for t in src_subgraph.outputs]

    # NOTE(review): the set union drops duplicates but does not keep the
    # original ordering of the input/output lists -- confirm that callers do
    # not rely on a particular order.
    dst_subgraph.inputs = list(set(src_subgraph.inputs).union(dst_subgraph.inputs))
    dst_subgraph.outputs = list(set(src_subgraph.outputs).union(dst_subgraph.outputs))


def append_relabel(src: list, dst: list, operator_map: dict) -> None:
    """Append the unmapped items of ``src`` to ``dst`` and record their index.

    Every position of ``src`` that is not already a key of ``operator_map`` has
    its item appended to ``dst``; the position is then mapped to the index the
    item received in ``dst``. Positions that are already mapped are skipped.
    Both ``dst`` and ``operator_map`` are modified in place.

    Raises:
        ValueError: if ``operator_map`` is None.
    """
    if operator_map is None:
        raise ValueError("The input operator map cannot be None!")

    for position, item in enumerate(src):
        if position in operator_map:
            # Already has a counterpart in dst; nothing to append.
            continue
        operator_map[position] = len(dst)
        dst.append(item)