aboutsummaryrefslogtreecommitdiff
path: root/src/mlia/nn/rewrite/core/graph_edit/cut.py
blob: a323b7b47c3f7d1ae1a8f274d3fb6cfedb537d34 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
# SPDX-FileCopyrightText: Copyright 2023, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
import os
from collections import defaultdict

# Raise the minimum log level of TensorFlow's native (C++) runtime. This is
# deliberately set *before* importing tensorflow below, presumably because TF
# reads the variable when it loads — keep this ordering (do not let an import
# sorter hoist the tensorflow import above this line).
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
import tensorflow as tf

# Only let ERROR-level (and more severe) messages through TF's Python logger.
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)

# Project helpers used by ``cut_model`` to read and write model files.
from mlia.nn.rewrite.core.utils.utils import load, save


def cut_subgraph(subgraph, input_tensor_names, output_tensor_names):
    """Change the global inputs and outputs of a graph to the provided named tensors.

    Tensors are looked up by name: each requested name is UTF-8 encoded and
    compared against ``tensor.name``. The resulting index lists are ordered by
    position in ``subgraph.tensors``, not by the order of the requested names.

    Args:
        subgraph: Graph whose ``inputs``/``outputs`` index lists are replaced
            in place; ``subgraph.tensors`` is only read.
        input_tensor_names: Names of the new graph inputs, or None to keep
            ``subgraph.inputs`` unchanged.
        output_tensor_names: Names of the new graph outputs, or None to keep
            ``subgraph.outputs`` unchanged.

    Raises:
        AssertionError: If the number of tensors found differs from the number
            of names requested (e.g. a missing name, a repeated requested name,
            or several tensors sharing one name). The attribute being updated
            is left untouched in that case.
    """

    def tensors_by_name(names):
        """Return indices of the tensors whose name is in ``names``."""
        seek = frozenset(name.encode("utf-8") for name in names)
        return [
            i for i, tensor in enumerate(subgraph.tensors) if tensor.name in seek
        ]

    def display_name(index):
        """Return a printable str name for the tensor at ``index``."""
        # Tensor names are compared as bytes above; decode them so they can be
        # joined into the (str) error message instead of raising TypeError.
        name = subgraph.tensors[index].name
        if isinstance(name, bytes):
            return name.decode("utf-8", errors="replace")
        return str(name)

    def select(names, kind):
        """Resolve ``names`` to tensor indices, failing if the counts differ."""
        names = list(names)
        indices = tensors_by_name(names)
        if len(indices) != len(names):
            # Raised explicitly rather than via ``assert`` so the check is not
            # stripped when Python runs with -O.
            raise AssertionError(
                "Expected %d %s tensors: %s\nFound: %s"
                % (
                    len(names),
                    kind,
                    ", ".join(names),
                    ", ".join(display_name(i) for i in indices),
                )
            )
        return indices

    if input_tensor_names is not None:
        subgraph.inputs = select(input_tensor_names, "input")

    if output_tensor_names is not None:
        subgraph.outputs = select(output_tensor_names, "output")


def simplify(model):
    """Remove any unused operators, tensors and buffers from a model.

    Every subgraph is pruned with ``simplify_subgraph``. Afterwards, any buffer
    that is no longer referenced by a tensor (in any subgraph) or by a metadata
    entry is dropped, and the remaining buffer indices are renumbered
    everywhere they are referenced. The model is modified in place.
    """
    for subgraph in model.subgraphs:
        simplify_subgraph(subgraph)

    # NOTE(review): a model without metadata may expose ``metadata`` as None
    # (flatbuffers object-API style); treat that as "no metadata entries".
    metadata = model.metadata or []

    # Collect buffers from *all* subgraphs (outer loop over subgraphs, inner
    # loop over their tensors).
    used_buffers = {
        tensor.buffer for subgraph in model.subgraphs for tensor in subgraph.tensors
    }
    used_buffers.update(entry.buffer for entry in metadata)
    # Buffer zero is always expected to be a zero-sized nullptr buffer by the
    # TFLite runtime, so keep it even when nothing references it.
    used_buffers.add(0)
    model.buffers, buf_relabel = filter_relabel(model.buffers, used_buffers)

    for subgraph in model.subgraphs:
        for tensor in subgraph.tensors:
            tensor.buffer = buf_relabel[tensor.buffer]

    for entry in metadata:
        entry.buffer = buf_relabel[entry.buffer]


def simplify_subgraph(subgraph):
    """Prune a subgraph down to what its outputs depend on, then renumber it.

    Operators and tensors that are not reached when walking backwards from
    ``subgraph.outputs`` (stopping at ``subgraph.inputs``) are removed, and all
    tensor indices held by operators and by the subgraph are rewritten to the
    new, compacted numbering. The subgraph is modified in place.
    """
    graph_inputs = subgraph.inputs

    # Map each tensor index to the operators producing it. Producers of graph
    # inputs are deliberately not recorded so the backward walk stops there.
    producers = defaultdict(set)
    for op_index, op in enumerate(subgraph.operators):
        for tensor_index in op.outputs:
            if tensor_index not in graph_inputs:
                producers[tensor_index].add(op_index)

    kept_ops, kept_tensors = find_required(subgraph, producers, subgraph.outputs)

    subgraph.operators, _ = filter_relabel(subgraph.operators, kept_ops)
    subgraph.tensors, tensor_map = filter_relabel(subgraph.tensors, kept_tensors)

    # Some files have ops with -1 (absent) input tensors; keep -1 as-is.
    tensor_map[-1] = -1

    def remap(indices):
        return [tensor_map[index] for index in indices]

    # NOTE(review): only op.inputs / op.outputs are remapped; any other
    # tensor-index fields an operator may carry are left as-is — confirm.
    for op in subgraph.operators:
        op.inputs = remap(op.inputs)
        op.outputs = remap(op.outputs)

    subgraph.inputs = remap(subgraph.inputs)
    subgraph.outputs = remap(subgraph.outputs)


def find_required(subgraph, requires, tensors):
    """Collect the operators and tensors needed to compute ``tensors``.

    Walks backwards, breadth-first, from ``tensors`` through the producing
    operators given by ``requires`` (tensor index -> iterable of operator
    indices; ``requires[t]`` must work for every visited tensor, e.g. a
    ``defaultdict(set)``). Tensors listed in ``subgraph.inputs`` are not
    expanded further. Every input and output tensor of a reached operator is
    recorded; extra ("stub") outputs are kept but not traversed.

    Args:
        subgraph: Graph providing ``inputs`` and ``operators``.
        requires: Mapping from tensor index to its producing operator indices.
        tensors: Tensor indices to start the walk from.

    Returns:
        A ``(operator_indices, tensor_indices)`` pair of sets.
    """
    visited_operators = set()
    visited_tensors = set(tensors)
    stop_tensors = set(subgraph.inputs)

    # Work on a copy so the in-place updates of ``visited_tensors`` below never
    # mutate the set currently being iterated.
    frontier = set(visited_tensors)
    while frontier:
        next_frontier = set()
        for tensor in frontier:
            new_operators = set(requires[tensor]) - visited_operators
            visited_operators |= new_operators
            for op_index in new_operators:
                operator = subgraph.operators[op_index]
                op_inputs = set(operator.inputs)
                # ``next_frontier`` is always a subset of ``visited_tensors``,
                # so subtracting the visited set alone is sufficient.
                next_frontier |= op_inputs - visited_tensors
                visited_tensors |= op_inputs
                # Include stub outputs but do not traverse them.
                visited_tensors.update(operator.outputs)
        frontier = next_frontier - stop_tensors

    return visited_operators, visited_tensors


def filter_relabel(src, filter):
    """Keep the items of ``src`` whose index is in ``filter`` and renumber them.

    Args:
        src: Sequence to filter.
        filter: Container of indices to keep (the name shadows the builtin but
            is kept for keyword-argument compatibility).

    Returns:
        ``(kept_items, relabel)`` where ``kept_items`` preserves the original
        order and ``relabel`` maps each kept old index to its new index.
    """
    survivors = [(old_index, item) for old_index, item in enumerate(src) if old_index in filter]
    kept_items = [item for _, item in survivors]
    relabel = {old_index: new_index for new_index, (old_index, _) in enumerate(survivors)}
    return kept_items, relabel


def cut_model(model_file, input_names, output_names, subgraph_index, output_file):
    """Cut one subgraph of a model file down to new inputs/outputs and save it.

    Loads ``model_file``, re-targets the inputs and outputs of subgraph
    ``subgraph_index`` to the named tensors (``cut_subgraph``), prunes whatever
    is no longer needed (``simplify``) and writes the result to ``output_file``.
    """
    model = load(model_file)
    cut_subgraph(
        subgraph=model.subgraphs[subgraph_index],
        input_tensor_names=input_names,
        output_tensor_names=output_names,
    )
    simplify(model)
    save(model, output_file)