aboutsummaryrefslogtreecommitdiff
path: root/src/mlia/nn/tensorflow/utils.py
blob: 3ac506497da78e66421476b43d8d9b344ad68fbc (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
# SPDX-FileCopyrightText: Copyright 2022-2024, Arm Limited and/or its affiliates.
# SPDX-FileCopyrightText: Copyright The TensorFlow Authors. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
"""Collection of useful functions for optimizations."""
from __future__ import annotations

from pathlib import Path
from typing import Any

import tensorflow as tf
from keras.api._v2 import keras  # Temporary workaround for now: MLIA-1107


def get_tf_tensor_shape(model: str) -> list:
    """Get input shape for the TensorFlow tensor model."""
    loaded = tf.saved_model.load(model)

    try:
        default_signature_key = tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY
        default_signature = loaded.signatures[default_signature_key]
        inputs_tensor_info = default_signature.inputs
    except KeyError as err:
        raise KeyError(f"Signature '{default_signature_key}' not found.") from err

    return [
        dim
        for input_key in inputs_tensor_info
        if (shape := input_key.get_shape())
        for dim in shape
    ]


def save_keras_model(
    model: keras.Model, save_path: str | Path, include_optimizer: bool = True
) -> None:
    """Save Keras model at provided path."""
    model.save(save_path, include_optimizer=include_optimizer)


def save_tflite_model(tflite_model: bytes, save_path: str | Path) -> None:
    """Save TensorFlow Lite model at provided path."""
    with open(save_path, "wb") as file:
        file.write(tflite_model)


def is_tflite_model(model: str | Path) -> bool:
    """Check if path contains TensorFlow Lite model."""
    model_path = Path(model)

    return model_path.suffix == ".tflite"


def is_keras_model(model: str | Path) -> bool:
    """Check if path contains a Keras model."""
    model_path = Path(model)

    if model_path.is_dir():
        return model_path.joinpath("keras_metadata.pb").exists()

    return model_path.suffix in (".h5", ".hdf5")


def is_saved_model(model: str | Path) -> bool:
    """Check if path contains SavedModel model."""
    model_path = Path(model)

    return model_path.is_dir() and not is_keras_model(model)


def get_tflite_model_type_map(model_filename: str | Path) -> dict[str, type]:
    """Get type map from tflite model."""
    model_type_map: dict[str, Any] = {}
    interpreter = tf.lite.Interpreter(str(model_filename))
    interpreter.allocate_tensors()
    input_details = interpreter.get_input_details()
    model_type_map = {
        input_detail["name"]: input_detail["dtype"] for input_detail in input_details
    }
    return model_type_map


def check_tflite_datatypes(model_filename: str | Path, *allowed_types: type) -> None:
    """Check if the model only has the given allowed datatypes."""
    type_map = get_tflite_model_type_map(model_filename)
    types = set(type_map.values())
    allowed = set(allowed_types)
    unexpected = types - allowed

    def cls_to_str(types: set[type]) -> list[str]:
        return [t.__name__ for t in types]

    if len(unexpected) > 0:
        raise TypeError(
            f"Model {model_filename} has "
            f"unexpected data types: {cls_to_str(unexpected)}. "
            f"Only {cls_to_str(allowed)} are allowed."
        )