aboutsummaryrefslogtreecommitdiff
path: root/src/mlia/nn/tensorflow/tflite_convert.py
blob: d3a833ae09f51343d380b23316bdbaa732cf77d5 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Support module to call TFLiteConverter."""
from __future__ import annotations

import argparse
import logging
import sys
from pathlib import Path
from typing import Any
from typing import Callable
from typing import cast
from typing import Iterable

import numpy as np
import tensorflow as tf

from mlia.nn.tensorflow.utils import get_tf_tensor_shape
from mlia.nn.tensorflow.utils import is_keras_model
from mlia.nn.tensorflow.utils import is_saved_model
from mlia.nn.tensorflow.utils import save_tflite_model
from mlia.utils.logging import redirect_output
from mlia.utils.proc import Command
from mlia.utils.proc import command_output

logger = logging.getLogger(name=__name__)


def representative_dataset(
    input_shape: Any, sample_count: int = 100, input_dtype: type = np.float32
) -> Callable:
    """Return a factory of random samples for quantization calibration.

    Args:
        input_shape: Model input shape; its first entry (the batch
            dimension) is ignored and replaced with 1.
        sample_count: Number of samples each call of the factory yields.
        input_dtype: Type every sample is cast to.

    Returns:
        A zero-argument callable producing a fresh generator on each call.
        Each yielded item is a one-element list holding an array of shape
        ``(1, *input_shape[1:])`` filled by ``np.random.rand``.
    """

    def dataset() -> Iterable:
        yield from (
            [np.random.rand(1, *input_shape[1:]).astype(input_dtype)]
            for _ in range(sample_count)
        )

    return dataset


def get_tflite_converter(
    model: tf.keras.Model | str | Path, quantized: bool = False
) -> tf.lite.TFLiteConverter:
    """Create and configure a TensorFlow Lite converter for ``model``.

    Args:
        model: An in-memory Keras model, or a path (str or Path) to either a
            SavedModel (detected with is_saved_model) or a Keras model file
            (detected with is_keras_model). The SavedModel check is tried
            first.
        quantized: When True, configure int8 quantization calibrated with
            random samples from representative_dataset.

    Returns:
        The configured converter. Conversion itself is not run here.

    Raises:
        ValueError: if ``model`` is none of the supported inputs.
    """
    if isinstance(model, (str, Path)):
        # The converter factory methods take the location as a plain string.
        location = str(model)
        if is_saved_model(location):
            converter = tf.lite.TFLiteConverter.from_saved_model(location)
            input_shape = get_tf_tensor_shape(location)
        elif is_keras_model(location):
            loaded_model = tf.keras.models.load_model(location)
            input_shape = loaded_model.input_shape
            converter = tf.lite.TFLiteConverter.from_keras_model(loaded_model)
        else:
            raise ValueError(
                f"Unable to create TensorFlow Lite converter for {location}"
            )
    elif isinstance(model, tf.keras.Model):
        converter = tf.lite.TFLiteConverter.from_keras_model(model)
        input_shape = model.input_shape
    else:
        raise ValueError(f"Unable to create TensorFlow Lite converter for {model}")

    if not quantized:
        return converter

    # int8 quantization with builtin ops only. The calibration samples come
    # from representative_dataset, shaped after the model's input shape,
    # and the model's inputs/outputs are int8 as well.
    converter.optimizations = [tf.lite.Optimize.DEFAULT]
    converter.representative_dataset = representative_dataset(input_shape)
    converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS]
    converter.inference_input_type = tf.int8
    converter.inference_output_type = tf.int8
    return converter


def convert_to_tflite_bytes(
    model: tf.keras.Model | str | Path, quantized: bool = False
) -> bytes:
    """Convert a model to TensorFlow Lite and return the result in memory.

    Args:
        model: Keras model object, or path (str or Path) to a SavedModel or
            Keras model file. This is the same set of inputs that
            get_tflite_converter accepts.
        quantized: Whether to apply int8 quantization.

    Returns:
        The converted TensorFlow Lite model as bytes.

    Raises:
        ValueError: if no converter can be created for ``model``.
    """
    converter = get_tflite_converter(model, quantized)

    # Conversion is wrapped in redirect_output for the "tensorflow" logger.
    # NOTE(review): presumably this captures TensorFlow's console noise
    # into that logger; confirm against mlia.utils.logging.
    with redirect_output(logging.getLogger("tensorflow")):
        output_bytes = cast(bytes, converter.convert())

    return output_bytes


def _convert_to_tflite(
    model: tf.keras.Model | str | Path,
    quantized: bool = False,
    output_path: Path | None = None,
) -> bytes:
    """Convert a model to TensorFlow Lite and optionally save it to disk.

    Args:
        model: Keras model object, or path (str or Path) to a SavedModel or
            Keras model file. main() passes a Path here.
        quantized: Whether to apply int8 quantization.
        output_path: If given, the converted model is also written there
            via save_tflite_model.

    Returns:
        The converted TensorFlow Lite model as bytes.
    """
    output_bytes = convert_to_tflite_bytes(model, quantized)

    if output_path:
        save_tflite_model(output_bytes, output_path)

    return output_bytes


def convert_to_tflite(
    model: tf.keras.Model | str | Path,
    quantized: bool = False,
    output_path: Path | None = None,
    input_path: Path | None = None,
    subprocess: bool = False,
) -> None:
    """Convert a model to TensorFlow Lite, optionally in a subprocess.

    The subprocess mode exists mainly to work around TensorFlow output that
    could not be silenced by redirecting it through SDK calls, and that
    would otherwise produce unwanted output for MLIA.

    In subprocess mode the model must be available as a file. Either pass
    ``model`` itself as a path (str or Path), or use the dedicated
    ``input_path`` parameter, which takes precedence.

    Args:
        model: Keras model object, or path to a SavedModel or Keras model
            file.
        quantized: Whether to apply int8 quantization.
        output_path: If provided, the converted model will be saved under
            that path.
        input_path: Model file to use in subprocess mode.
        subprocess: Run the conversion by executing this module as a
            separate Python process.

    Raises:
        RuntimeError: in subprocess mode, when neither ``input_path`` nor a
            path-like ``model`` is given.
    """
    if not subprocess:
        _convert_to_tflite(model, quantized, output_path)
        return

    if input_path is None:
        if not isinstance(model, (str, Path)):
            raise RuntimeError(
                f"Input path is required for {model}"
                " when converter is called in subprocess."
            )
        input_path = Path(model)

    # Re-run this module as a script (its __main__ entry point) with the
    # same interpreter as the caller. A bare "python" is resolved via PATH
    # and may pick a different interpreter or environment. sys.executable
    # can be empty in embedded interpreters, hence the fallback.
    interpreter = sys.executable or "python"
    args = [interpreter, __file__, str(input_path)]
    if output_path:
        args.extend(["--output", str(output_path)])
    if quantized:
        args.append("--quantize")

    command = Command(args)

    for line in command_output(command):
        logger.debug("TFLiteConverter: %s", line)


def main(argv: list[str] | None = None) -> int:
    """Command-line entry point: convert one model file to TensorFlow Lite.

    Usage: ``<script> INPUT [--output PATH] [--quantize]``.

    Args:
        argv: Arguments to parse. ``None`` lets argparse read ``sys.argv``.

    Returns:
        0 once the conversion has finished.

    Raises:
        ValueError: if INPUT does not exist.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument("input", type=Path)
    cli.add_argument("--output", type=Path, default=None)
    cli.add_argument("--quantize", default=False, action="store_true")
    options = cli.parse_args(argv)

    # argparse already converted INPUT (and --output, when given) to Path.
    source = options.input
    target = options.output
    quantize = options.quantize

    if not source.exists():
        raise ValueError(f"Input file doesn't exist: [{source}]")

    logger.debug(
        "Invoking TFLiteConverter on [%s] -> [%s], quantize: [%s]",
        source,
        target,
        quantize,
    )
    _convert_to_tflite(source, quantize, target)
    return 0


if __name__ == "__main__":
    # Script mode (convert_to_tflite's subprocess mode runs this file
    # directly). main()'s return value becomes the process exit status.
    sys.exit(main(sys.argv[1:]))