aboutsummaryrefslogtreecommitdiff
path: root/src/mlia/nn/tensorflow/utils.py
blob: b8d45c6e1bbb5dc17f939a3fe8efa2846ed33984 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates.
# SPDX-FileCopyrightText: Copyright The TensorFlow Authors. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
"""Collection of useful functions for optimizations."""
from __future__ import annotations

import logging
from pathlib import Path
from typing import Any
from typing import Callable
from typing import cast
from typing import Iterable

import numpy as np
import tensorflow as tf

from mlia.utils.logging import redirect_output


def representative_dataset(
    input_shape: Any, sample_count: int = 100, input_dtype: type = np.float32
) -> Callable:
    """Build a sample-generator factory used for quantization calibration.

    The returned zero-argument function yields ``sample_count`` items when
    iterated. Each item is a one-element list containing an array drawn with
    ``np.random.rand`` (uniform on [0, 1)) and cast to ``input_dtype``. Its
    shape is ``(1, *input_shape[1:])``: the leading entry of ``input_shape``
    (presumably the batch dimension) is replaced by 1.
    """

    def dataset() -> Iterable:
        for _sample_index in range(sample_count):
            random_values = np.random.rand(1, *input_shape[1:])
            converted = random_values.astype(input_dtype)
            yield [converted]

    return dataset


def get_tf_tensor_shape(model: str) -> list:
    """Return the input dimensions of a SavedModel's default serving signature.

    The dimensions of every input of the serving signature are concatenated
    into one flat list, in iteration order. Inputs whose shape object is
    falsy are skipped; presumably these are shapes of unknown rank (TODO
    confirm against ``tf.TensorShape.__bool__``).

    :param model: path to the SavedModel directory
    :return: flat list of dimensions (entries may be ``None`` for unknown dims)
    :raises KeyError: if the default serving signature is not present
    """
    saved = tf.saved_model.load(model)

    signature_key = tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY
    try:
        serving_fn = saved.signatures[signature_key]
        input_tensors = serving_fn.inputs
    except KeyError as err:
        raise KeyError(f"Signature '{signature_key}' not found.") from err

    dims: list = []
    for tensor in input_tensors:
        tensor_shape = tensor.get_shape()
        if tensor_shape:
            dims.extend(tensor_shape)
    return dims


def convert_to_tflite(
    model: tf.keras.Model | str | Path, quantized: bool = False
) -> bytes:
    """Convert a model to TensorFlow Lite.

    Accepts everything ``get_tflite_converter`` accepts: an in-memory Keras
    model, a path to a SavedModel directory, or a path to a Keras model.

    :param model: model object or path (``str`` or ``Path``) to the model
    :param quantized: when True, apply the int8 quantization setup performed
        by ``get_tflite_converter``
    :return: the converted TensorFlow Lite model as bytes
    :raises ValueError: if no converter can be created for ``model``
    """
    converter = get_tflite_converter(model, quantized)

    # Conversion runs inside redirect_output for the "tensorflow" logger,
    # presumably so TensorFlow's console output is routed through logging.
    with redirect_output(logging.getLogger("tensorflow")):
        return cast(bytes, converter.convert())


def save_keras_model(
    model: tf.keras.Model, save_path: str | Path, include_optimizer: bool = True
) -> None:
    """Persist ``model`` at ``save_path`` via the model's own ``save`` method.

    :param model: Keras model to store
    :param save_path: destination path, forwarded unchanged to ``model.save``
    :param include_optimizer: forwarded to ``model.save`` as the
        ``include_optimizer`` keyword
    """
    save_options = {"include_optimizer": include_optimizer}
    model.save(save_path, **save_options)


def save_tflite_model(tflite_model: bytes, save_path: str | Path) -> None:
    """Write serialized TensorFlow Lite model bytes to ``save_path``.

    The file is opened in binary write mode, so any existing file at that
    path is overwritten.
    """
    with open(save_path, mode="wb") as output_file:
        output_file.write(tflite_model)


def is_tflite_model(model: str | Path) -> bool:
    """Return True when the path has the (case-sensitive) ``.tflite`` suffix.

    Only the file name is inspected; the path does not need to exist.
    """
    suffix = Path(model).suffix
    return suffix == ".tflite"


def is_keras_model(model: str | Path) -> bool:
    """Check if path contains a Keras model.

    A directory counts as a Keras model when it contains
    ``keras_metadata.pb``. Any other path is recognised by its ``.h5`` or
    ``.hdf5`` suffix.
    """
    candidate = Path(model)

    if not candidate.is_dir():
        return candidate.suffix in (".h5", ".hdf5")

    return (candidate / "keras_metadata.pb").exists()


def is_saved_model(model: str | Path) -> bool:
    """Check if path contains SavedModel model.

    Any existing directory that ``is_keras_model`` does not recognise is
    treated as a SavedModel. The directory contents are not inspected further.
    """
    if not Path(model).is_dir():
        return False

    return not is_keras_model(model)


def get_tflite_converter(
    model: tf.keras.Model | str | Path, quantized: bool = False
) -> tf.lite.TFLiteConverter:
    """Configure TensorFlow Lite converter for the provided model.

    Supported inputs:

    - an in-memory ``tf.keras.Model``;
    - a path to a SavedModel directory;
    - a path to a Keras model, i.e. an ``.h5``/``.hdf5`` file or a directory
      containing ``keras_metadata.pb``.

    :param model: model object, or path (``str`` or ``Path``) to the model
    :param quantized: when True, set up int8 post-training quantization
        (int8 model inputs/outputs, builtin ops only), calibrated with a
        randomly generated representative dataset
    :return: the configured converter; ``convert()`` is not called here
    :raises ValueError: if ``model`` is not one of the supported inputs
    """
    if isinstance(model, (str, Path)):
        # converter's methods accept string as input parameter
        model = str(model)

    # The input shape is only consumed when building the representative
    # dataset for quantization.
    input_shape: Any = None

    if isinstance(model, tf.keras.Model):
        converter = tf.lite.TFLiteConverter.from_keras_model(model)
        input_shape = model.input_shape
    elif isinstance(model, str) and is_saved_model(model):
        converter = tf.lite.TFLiteConverter.from_saved_model(model)
        if quantized:
            # get_tf_tensor_shape loads the SavedModel a second time, so
            # skip it when the shape will not be used.
            input_shape = get_tf_tensor_shape(model)
    elif isinstance(model, str) and is_keras_model(model):
        keras_model = tf.keras.models.load_model(model)
        input_shape = keras_model.input_shape
        converter = tf.lite.TFLiteConverter.from_keras_model(keras_model)
    else:
        raise ValueError(f"Unable to create TensorFlow Lite converter for {model}")

    if quantized:
        converter.optimizations = [tf.lite.Optimize.DEFAULT]
        converter.representative_dataset = representative_dataset(input_shape)
        converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS]
        converter.inference_input_type = tf.int8
        converter.inference_output_type = tf.int8

    return converter


def get_tflite_model_type_map(model_filename: str | Path) -> dict[str, type]:
    """Map each input tensor name of a TensorFlow Lite model to its dtype.

    Despite the generic name, only the model's *input* tensors (as reported
    by ``Interpreter.get_input_details``) are included. Output and
    intermediate tensors are not.

    :param model_filename: path to the TensorFlow Lite model file
    :return: dictionary mapping input tensor name to the dtype reported by
        the interpreter
    """
    interpreter = tf.lite.Interpreter(str(model_filename))
    interpreter.allocate_tensors()
    input_details = interpreter.get_input_details()
    return {detail["name"]: detail["dtype"] for detail in input_details}


def check_tflite_datatypes(model_filename: str | Path, *allowed_types: type) -> None:
    """Check if the model only has the given allowed datatypes.

    The dtypes checked are those returned by ``get_tflite_model_type_map``,
    i.e. the dtypes of the model's input tensors.

    :param model_filename: path to the TensorFlow Lite model file
    :param allowed_types: dtypes the model is permitted to use
    :raises TypeError: if any checked dtype is not in ``allowed_types``
    """
    type_map = get_tflite_model_type_map(model_filename)
    allowed = set(allowed_types)
    unexpected = set(type_map.values()) - allowed

    if not unexpected:
        return

    def type_names(type_set: set[type]) -> list[str]:
        # Sorted so the error message is deterministic; set iteration order
        # is not.
        return sorted(t.__name__ for t in type_set)

    raise TypeError(
        f"Model {model_filename} has "
        f"unexpected data types: {type_names(unexpected)}. "
        f"Only {type_names(allowed)} are allowed."
    )