aboutsummaryrefslogtreecommitdiff
path: root/src/mlia/nn/rewrite/core/graph_edit/cut.py
blob: 13a5268ea374eb56d95cd70d34d53307f5524b2e (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
# SPDX-FileCopyrightText: Copyright 2023, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Cut module."""
import os
from collections import defaultdict
from typing import Optional

import tensorflow as tf
from tensorflow.lite.python.schema_py_generated import ModelT
from tensorflow.lite.python.schema_py_generated import SubGraphT

from mlia.nn.tensorflow.tflite_graph import load_fb
from mlia.nn.tensorflow.tflite_graph import save_fb

# Reduce TensorFlow log noise for everything that runs after this module is
# imported: C++-side messages via TF_CPP_MIN_LOG_LEVEL, and tf.compat.v1
# Python-side logging limited to ERROR severity.
# NOTE(review): this is an import-time side effect, and `tensorflow` is already
# imported above, so the environment variable may not affect logging that TF
# configured during that import — confirm it still has the intended effect.
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)


def tensors_by_name(subgraph: SubGraphT, names: list) -> list:
    """Return the indices of the subgraph tensors whose name is in ``names``.

    Each entry of ``names`` is UTF-8 encoded and compared against the tensor's
    ``name`` attribute, so tensor names are expected to be ``bytes``. The
    result follows the order of ``subgraph.tensors`` (not the order of
    ``names``) and has one entry per matching tensor.
    """
    wanted = {name.encode("utf-8") for name in names}
    matches = []
    for index, tensor in enumerate(subgraph.tensors):
        if tensor.name in wanted:
            matches.append(index)
    return matches


def _select_tensors(subgraph: SubGraphT, names: list, kind: str) -> list:
    """Look up ``names`` in ``subgraph`` and check the number of matches.

    Args:
        subgraph: subgraph whose tensors are searched.
        names: tensor names (``str``) to look for.
        kind: ``"input"`` or ``"output"``; only used in the error message.

    Returns:
        The matching tensor indices, as returned by ``tensors_by_name``.

    Raises:
        AssertionError: if the number of matching tensors differs from the
            number of requested names.
    """
    indices = tensors_by_name(subgraph, names)
    assert len(indices) == len(names), (
        f"Expected {len(names)} {kind} tensors: {', '.join(names)}\n"
        # Matched tensor names are the UTF-8 encoded request names (bytes),
        # so decode them before joining them into the str message.
        "Found: "
        + ", ".join(subgraph.tensors[i].name.decode("utf-8") for i in indices)
    )
    return indices


def cut_subgraph(
    subgraph: SubGraphT,
    input_tensor_names: Optional[list],
    output_tensor_names: Optional[list],
) -> None:
    """Change the global inputs and outputs of a graph to the provided named tensors.

    The new index lists follow the order of ``subgraph.tensors``, not the order
    of the given names.

    Args:
        subgraph: subgraph whose ``inputs`` / ``outputs`` lists are replaced in
            place.
        input_tensor_names: names of the tensors to use as inputs, or None to
            leave ``subgraph.inputs`` untouched.
        output_tensor_names: names of the tensors to use as outputs, or None to
            leave ``subgraph.outputs`` untouched.

    Raises:
        AssertionError: if the number of tensors found for a list of names does
            not match its length; the field being set is left unchanged.
    """
    if input_tensor_names is not None:
        subgraph.inputs = _select_tensors(subgraph, input_tensor_names, "input")
    if output_tensor_names is not None:
        subgraph.outputs = _select_tensors(subgraph, output_tensor_names, "output")


def simplify(model: ModelT) -> None:
    """Remove any unused operators, tensors and buffers from a model.

    Every subgraph is pruned with ``simplify_subgraph`` first. Afterwards the
    model's buffer list is compacted to the buffers still referenced by a
    tensor of *any* subgraph or by a metadata entry (plus buffer 0), and all
    buffer indices are rewritten to the compacted numbering.

    Args:
        model: flatbuffer object-API model; modified in place.
    """
    for subgraph in model.subgraphs:
        simplify_subgraph(subgraph)

    # Metadata is optional in the TFLite schema; the object API leaves the
    # attribute as None when a model has none.
    metadata_entries = model.metadata or []

    # Clause order matters: iterate the subgraphs first, then their tensors, so
    # buffers referenced from every subgraph are kept (not only the last one).
    used_buffers = {
        tensor.buffer for subgraph in model.subgraphs for tensor in subgraph.tensors
    }
    used_buffers.update(metadata.buffer for metadata in metadata_entries)
    # Buffer zero is always expected to be a zero-sized nullptr buffer by the
    # TFLite runtime, so it is kept even if nothing references it.
    used_buffers.add(0)
    model.buffers, buf_relabel = filter_relabel(model.buffers, used_buffers)

    for subgraph in model.subgraphs:
        for tensor in subgraph.tensors:
            tensor.buffer = buf_relabel[tensor.buffer]

    for metadata in metadata_entries:
        metadata.buffer = buf_relabel[metadata.buffer]


def simplify_subgraph(subgraph: SubGraphT) -> None:
    """Prune operators and tensors that do not contribute to the subgraph outputs.

    Operators needed to compute ``subgraph.outputs`` (walking back no further
    than ``subgraph.inputs``) are kept, together with the tensors they touch
    and every declared subgraph input. All tensor indices in operators and in
    the subgraph's input/output lists are then renumbered to match the
    compacted tensor list.

    Args:
        subgraph: subgraph to simplify; modified in place.
    """
    input_set = set(subgraph.inputs)

    # Map each tensor index to the operator indices that produce it. Subgraph
    # inputs get no producer entry, so the backwards search stops at them.
    requires = defaultdict(set)
    for op_index, operator in enumerate(subgraph.operators):
        for tensor in operator.outputs:
            if tensor not in input_set:
                requires[tensor].add(op_index)

    op_set, ten_set = find_required(subgraph, requires, subgraph.outputs)
    # Keep every declared input, even one that no required operator reads;
    # otherwise relabelling subgraph.inputs below would raise KeyError.
    ten_set = ten_set | input_set

    subgraph.operators, _ = filter_relabel(subgraph.operators, op_set)
    subgraph.tensors, ten_relabel = filter_relabel(subgraph.tensors, ten_set)

    ten_relabel[-1] = -1  # Some files have ops with -1 input tensors; leave unchanged

    for operator in subgraph.operators:
        operator.inputs = [ten_relabel[tensor] for tensor in operator.inputs]
        operator.outputs = [ten_relabel[tensor] for tensor in operator.outputs]

    subgraph.inputs = [ten_relabel[tensor] for tensor in subgraph.inputs]
    subgraph.outputs = [ten_relabel[tensor] for tensor in subgraph.outputs]


def find_required(subgraph: SubGraphT, requires: dict, tensors: list) -> tuple:
    """Find the operators and tensors needed to compute ``tensors``.

    Walks backwards from ``tensors`` level by level. For each tensor on the
    frontier, every operator in ``requires[tensor]`` becomes required, and that
    operator's not-yet-seen input tensors form the next frontier. Tensors in
    ``subgraph.inputs`` are recorded but never expanded; the starting
    ``tensors`` themselves are always expanded once.

    Args:
        subgraph: subgraph providing ``inputs`` (the stop tensors) and
            ``operators`` (indexed by operator number).
        requires: mapping from tensor index to the operator indices that produce
            it. Every frontier tensor is looked up with ``[]``, so a
            ``defaultdict(set)`` is expected for tensors without producers.
        tensors: indices of the tensors to compute, e.g. the subgraph outputs.

    Returns:
        ``(operators, tensors)``: the set of required operator indices, and the
        set of tensor indices made up of the starting tensors plus every input
        and output of each required operator.
    """
    stop_tensors = set(subgraph.inputs)
    visited_operators: set = set()
    visited_tensors = set(tensors)

    # Iterate over a copy: visited_tensors is updated in place during the walk.
    frontier = set(visited_tensors)
    while frontier:
        next_frontier: set = set()
        for tensor in frontier:
            new_operators = set(requires[tensor]) - visited_operators
            visited_operators |= new_operators
            for operator in new_operators:
                op_inputs = set(subgraph.operators[operator].inputs)
                # Only tensors never seen before are queued for expansion.
                next_frontier |= op_inputs - visited_tensors
                visited_tensors |= op_inputs
                # include stub outputs but do not traverse them
                visited_tensors.update(subgraph.operators[operator].outputs)
        frontier = next_frontier - stop_tensors

    return visited_operators, visited_tensors


def filter_relabel(src_subgraph: list, relabel_filter: set) -> tuple:
    """Keep the items whose index is in ``relabel_filter`` and renumber them.

    Despite the parameter name, this works on any iterable of items. The
    callers in this module pass operator, tensor and buffer lists.

    Args:
        src_subgraph: items to filter; only their positions are inspected.
        relabel_filter: indices of the items to keep. Indices that do not
            address an item (for example -1) are ignored.

    Returns:
        ``(kept, relabel)``: ``kept`` is the list of surviving items in their
        original order, and ``relabel`` maps each surviving old index to its
        new index in ``kept``.
    """
    relabel: dict = {}
    kept: list = []
    for old_index, item in enumerate(src_subgraph):
        if old_index in relabel_filter:
            relabel[old_index] = len(kept)
            kept.append(item)
    return kept, relabel


def cut_model(
    model_file: str,
    input_names: Optional[list],
    output_names: Optional[list],
    subgraph_index: int,
    output_file: str,
) -> None:
    """Load a model, re-cut one of its subgraphs, prune it and save the result.

    Args:
        model_file: path of the flatbuffer model to read (via ``load_fb``).
        input_names: names of the tensors that become the subgraph inputs, or
            None to keep the current inputs.
        output_names: names of the tensors that become the subgraph outputs, or
            None to keep the current outputs.
        subgraph_index: index into ``model.subgraphs`` of the subgraph to cut.
        output_file: path the simplified model is written to (via ``save_fb``).
    """
    model = load_fb(model_file)
    cut_subgraph(model.subgraphs[subgraph_index], input_names, output_names)
    # Pruning runs over the whole model so that unused buffers are dropped too.
    simplify(model)
    save_fb(model, output_file)