aboutsummaryrefslogtreecommitdiff
path: root/src/mlia/nn/tensorflow/utils.py
blob: 4abf6cd9b427c20d3b22e6160af2299206528f94 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
# SPDX-FileCopyrightText: Copyright The TensorFlow Authors. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
"""Collection of useful functions for optimizations."""
import logging
from pathlib import Path
from typing import Callable
from typing import Iterable
from typing import Union

import numpy as np
import tensorflow as tf
from tensorflow.lite.python.interpreter import Interpreter

from mlia.utils.logging import redirect_output


def representative_dataset(model: tf.keras.Model) -> Callable:
    """Build a factory of random calibration samples for a Keras model.

    ``model.input_shape`` is read immediately. The returned callable creates a
    generator yielding 100 samples; each sample is a one-element list holding a
    float32 array of shape ``model.input_shape`` filled by ``np.random.rand``.

    Only a batch dimension of exactly 1 is supported. The check is deferred: the
    generator raises ``Exception`` when it is first advanced, not when this
    function is called.
    """
    shape = model.input_shape
    sample_count = 100

    def dataset() -> Iterable:
        # Validated lazily so the error surfaces during calibration, as before.
        if shape[0] != 1:
            raise Exception("Only the input batch_size=1 is supported!")
        for _ in range(sample_count):
            yield [np.random.rand(*shape).astype(np.float32)]

    return dataset


def get_tf_tensor_shape(model: str) -> list:
    """Return the input dimensions of a TensorFlow SavedModel.

    The SavedModel at ``model`` is loaded and its ``serving_default``
    signature is inspected. The dimensions of every signature input are
    concatenated, in order, into one flat list. With several inputs, the
    result therefore mixes their dimensions together. Inputs whose shape
    evaluates as falsy are skipped.

    Raises:
        Exception: if the model has no ``serving_default`` signature.
    """
    saved_model = tf.saved_model.load(model)
    signatures = saved_model.signatures
    if "serving_default" not in signatures.keys():
        raise Exception(
            "Unsupported TensorFlow model signature, must have 'serving_default'"
        )

    serving_inputs = signatures["serving_default"].inputs
    # NOTE(review): a falsy TensorShape is presumably one of unknown rank;
    # such inputs contribute no dimensions here.
    return [
        dim
        for tensor in serving_inputs
        if tensor.get_shape()
        for dim in tensor.get_shape()
    ]


def representative_tf_dataset(model: str) -> Callable:
    """Build a factory of random calibration samples for a TensorFlow SavedModel.

    The sample shape comes from ``get_tf_tensor_shape(model)``. Any dimension
    reported as ``None`` (unknown, e.g. a dynamic batch size) is replaced with
    1, because ``np.random.rand`` needs concrete integer dimensions. The
    returned callable creates a generator yielding 100 samples; each is a
    one-element list holding a float32 array.

    Raises:
        Exception: if no input shape could be obtained from the model.
    """
    if not (input_shape := get_tf_tensor_shape(model)):
        raise Exception("Unable to get input shape")

    # np.random.rand(None, ...) raises TypeError, so unknown dims get size 1.
    sample_shape = [1 if dim is None else dim for dim in input_shape]

    def dataset() -> Iterable:
        for _ in range(100):
            data = np.random.rand(*sample_shape)
            yield [data.astype(np.float32)]

    return dataset


def convert_to_tflite(model: tf.keras.Model, quantized: bool = False) -> bytes:
    """Convert a Keras model to a serialized TFLite model.

    Args:
        model: Keras model to convert.
        quantized: when True, configure int8 post-training quantization.
            The converter is restricted to int8 builtin ops, model inputs and
            outputs become int8, and calibration uses random samples from
            ``representative_dataset(model)``.

    Returns:
        The value returned by ``TFLiteConverter.convert()``: the serialized
        TFLite model (bytes), not an ``Interpreter``.

    Raises:
        Exception: if ``model`` is not a ``tf.keras.Model`` instance.
    """
    if not isinstance(model, tf.keras.Model):
        raise Exception("Invalid model type")

    converter = tf.lite.TFLiteConverter.from_keras_model(model)

    if quantized:
        converter.optimizations = [tf.lite.Optimize.DEFAULT]
        converter.representative_dataset = representative_dataset(model)
        converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
        converter.inference_input_type = tf.int8
        converter.inference_output_type = tf.int8

    # Presumably routes output printed during conversion to the "tensorflow"
    # logger; see mlia.utils.logging.redirect_output.
    with redirect_output(logging.getLogger("tensorflow")):
        tflite_model = converter.convert()

    return tflite_model


def convert_tf_to_tflite(model: Union[str, Path], quantized: bool = False) -> bytes:
    """Convert a TensorFlow SavedModel directory to a serialized TFLite model.

    Args:
        model: path to the SavedModel directory, as ``str`` or ``Path``.
            It is normalized to ``str`` before being handed to TensorFlow.
        quantized: when True, configure int8 post-training quantization.
            The converter is restricted to int8 builtin ops, model inputs and
            outputs become int8, and calibration uses random samples from
            ``representative_tf_dataset``.

    Returns:
        The value returned by ``TFLiteConverter.convert()``: the serialized
        TFLite model (bytes), not an ``Interpreter``.

    Raises:
        Exception: if ``model`` is neither a ``str`` nor a ``Path``.
    """
    if not isinstance(model, (str, Path)):
        raise Exception("Invalid model type")

    model_dir = str(model)
    converter = tf.lite.TFLiteConverter.from_saved_model(model_dir)

    if quantized:
        converter.optimizations = [tf.lite.Optimize.DEFAULT]
        converter.representative_dataset = representative_tf_dataset(model_dir)
        converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
        converter.inference_input_type = tf.int8
        converter.inference_output_type = tf.int8

    # Presumably routes output printed during conversion to the "tensorflow"
    # logger; see mlia.utils.logging.redirect_output.
    with redirect_output(logging.getLogger("tensorflow")):
        tflite_model = converter.convert()

    return tflite_model


def save_keras_model(model: tf.keras.Model, save_path: Union[str, Path]) -> None:
    """Write ``model`` to ``save_path`` via ``model.save``, optimizer included.

    Keeping the optimizer state lets the saved file serve as a training
    checkpoint.
    """
    keep_optimizer = True
    model.save(save_path, include_optimizer=keep_optimizer)


def save_tflite_model(model: bytes, save_path: Union[str, Path]) -> None:
    """Write a serialized TFLite model to ``save_path``.

    Args:
        model: the serialized model, e.g. the bytes returned by
            ``TFLiteConverter.convert()``. Any existing file at ``save_path``
            is overwritten.
        save_path: destination file path.
    """
    Path(save_path).write_bytes(model)


def is_tflite_model(model: Union[Path, str]) -> bool:
    """Tell whether ``model`` names a TFLite model.

    Detection is by file name only. The last suffix must be exactly
    ``.tflite`` (case-sensitive), and the file does not need to exist.
    """
    suffix = Path(model).suffix
    return suffix == ".tflite"


def is_keras_model(model: Union[Path, str]) -> bool:
    """Tell whether ``model`` looks like a model loadable through the Keras API.

    Two cases are recognized:
        1. an existing directory (a SavedModel) that contains a
           ``keras_metadata.pb`` file;
        2. any non-directory path ending in ``.h5`` or ``.hdf5``
           (checked by name only).
    """
    path = Path(model)

    if not path.is_dir():
        return path.suffix in (".h5", ".hdf5")

    return path.joinpath("keras_metadata.pb").exists()


def is_tf_model(model: Union[Path, str]) -> bool:
    """Tell whether ``model`` looks like a plain TensorFlow SavedModel.

    The path must be an existing directory that ``is_keras_model`` does not
    claim, i.e. it lacks a ``keras_metadata.pb`` file. Non-directories are
    rejected without consulting ``is_keras_model``.
    """
    if not Path(model).is_dir():
        return False
    return not is_keras_model(model)