path: root/src/mlia/nn/rewrite/core/graph_edit/record.py
# SPDX-FileCopyrightText: Copyright 2023, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Save subgraph data."""
# pylint: disable=too-many-locals
from __future__ import annotations

import math
import os
from contextlib import ExitStack
from pathlib import Path

import tensorflow as tf
from rich.progress import track

from mlia.nn.rewrite.core.utils.numpy_tfrecord import numpytf_count
from mlia.nn.rewrite.core.utils.numpy_tfrecord import numpytf_read
from mlia.nn.rewrite.core.utils.numpy_tfrecord import NumpyTFWriter
from mlia.nn.rewrite.core.utils.parallel import ParallelTFLiteModel
from mlia.nn.tensorflow.config import NameToTensorMap

os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)


DEQUANT_SUFFIX = "_dequant"


def dequantized_path(filename: str | Path) -> Path:
    """Append the de-quantization suffix to the given filename."""
    path = Path(filename)
    path = Path(path.parent, f"{path.stem}{DEQUANT_SUFFIX}{path.suffix}")
    return path


def record_model(
    input_filename: str | Path,
    model_filename: str | Path,
    output_filename: str | Path,
    batch_size: int = 0,
    show_progress: bool = False,
    num_procs: int = 1,
    num_threads: int = 0,
    dequantize_output: bool = False,
) -> None:
    """Model recorder.

    num_procs: 0 => detect real cores on system
    num_threads: 0 => TFLite impl. specific setting, usually 3

    dequantize: True => de-quantize the recorded output before saving
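
    Example (illustrative paths; assumes a TFLite model and a matching
    input TFRecord exist at these locations):

        record_model("input.tfrec", "model.tflite", "output.tfrec",
                     dequantize_output=True)

    With dequantize_output=True, a de-quantized copy of the output is
    also written to "output_dequant.tfrec".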
    """
    model = ParallelTFLiteModel(model_filename, num_procs, num_threads, batch_size)
    if not batch_size:
        batch_size = (
            model.num_procs * model.batch_size
        )  # automatically batch to the minimum effective size if not specified
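        # e.g. (illustrative figures) 4 worker processes each running the
        # model with a per-model batch size of 1 gives an effective
        # batch_size of 4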

    total = numpytf_count(input_filename)
    dataset = numpytf_read(input_filename)

    if batch_size > 1:
        # Collapse batch-size 1 items into batch-size n.
        dataset = dataset.map(
            lambda d: {k: tf.squeeze(v, axis=0) for k, v in d.items()}
        )
        dataset = dataset.batch(batch_size, drop_remainder=False)
        total = int(math.ceil(total / batch_size))
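        # e.g. 10 input records with batch_size 4 -> ceil(10 / 4) = 3
        # batches, where the final batch holds only 2 items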

    with ExitStack() as stack:
        writer = stack.enter_context(NumpyTFWriter(output_filename))
        writer_dequant = None
        if dequantize_output:
            dequant_path = dequantized_path(output_filename)
            writer_dequant = stack.enter_context(NumpyTFWriter(dequant_path))

        def write(writer: NumpyTFWriter, data: NameToTensorMap) -> None:
            """Write the data using the given NumpyTFWriter instance."""
            if batch_size > 1:
                for i in range(batch_size):
                    # Expand the batches and recreate each dict as a
                    # batch-size 1 item for the tfrec output
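                    # The final batch may be partial, so skip any index
                    # beyond v.shape[0].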
                    recreated_dict = {
                        k: v[i : i + 1]  # noqa: E203
                        for k, v in data.items()
                        if i < v.shape[0]
                    }
                    if recreated_dict:
                        writer.write(recreated_dict)
            else:
                writer.write(data)

        for named_x in track(
            dataset.as_numpy_iterator(), total=total, disable=not show_progress
        ):
            named_y = model(named_x)
            write(writer, named_y)

            if dequantize_output:
                assert writer_dequant
                named_y_dequant = model.dequantize_outputs(named_y)
                write(writer_dequant, named_y_dequant)

        model.close()