aboutsummaryrefslogtreecommitdiff
path: root/src/mlia/nn/tensorflow/config.py
blob: d7d430fb174d3017a7df2371fba92d98bda8d33a (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Model configuration."""
from __future__ import annotations

import logging
from pathlib import Path
from typing import cast
from typing import List

import tensorflow as tf

from mlia.core.context import Context
from mlia.nn.tensorflow.utils import convert_to_tflite
from mlia.nn.tensorflow.utils import is_keras_model
from mlia.nn.tensorflow.utils import is_saved_model
from mlia.nn.tensorflow.utils import is_tflite_model
from mlia.nn.tensorflow.utils import save_tflite_model
from mlia.utils.logging import log_action

logger = logging.getLogger(__name__)


class ModelConfiguration:
    """Base class for model configuration."""

    def __init__(self, model_path: str | Path) -> None:
        """Init model configuration instance."""
        self.model_path = str(model_path)

    def convert_to_tflite(
        self, tflite_model_path: str | Path, quantized: bool = False
    ) -> TFLiteModel:
        """Convert model to TensorFlow Lite format."""
        raise NotImplementedError()

    def convert_to_keras(self, keras_model_path: str | Path) -> KerasModel:
        """Convert model to Keras format."""
        raise NotImplementedError()


class KerasModel(ModelConfiguration):
    """Keras model configuration.

    Supports all models supported by Keras API: saved model, H5, HDF5
    """

    def get_keras_model(self) -> tf.keras.Model:
        """Return associated Keras model."""
        return tf.keras.models.load_model(self.model_path)

    def convert_to_tflite(
        self, tflite_model_path: str | Path, quantized: bool = False
    ) -> TFLiteModel:
        """Convert model to TensorFlow Lite format."""
        with log_action("Converting Keras to TensorFlow Lite ..."):
            converted_model = convert_to_tflite(self.get_keras_model(), quantized)

        save_tflite_model(converted_model, tflite_model_path)
        logger.debug(
            "Model %s converted and saved to %s", self.model_path, tflite_model_path
        )

        return TFLiteModel(tflite_model_path)

    def convert_to_keras(self, keras_model_path: str | Path) -> KerasModel:
        """Convert model to Keras format."""
        return self


class TFLiteModel(ModelConfiguration):  # pylint: disable=abstract-method
    """TensorFlow Lite model configuration."""

    def input_details(self) -> list[dict]:
        """Get model's input details."""
        interpreter = tf.lite.Interpreter(model_path=self.model_path)
        return cast(List[dict], interpreter.get_input_details())

    def convert_to_tflite(
        self, tflite_model_path: str | Path, quantized: bool = False
    ) -> TFLiteModel:
        """Convert model to TensorFlow Lite format."""
        return self


class TfModel(ModelConfiguration):  # pylint: disable=abstract-method
    """TensorFlow model configuration.

    Supports models supported by TensorFlow API (not Keras)
    """

    def convert_to_tflite(
        self, tflite_model_path: str | Path, quantized: bool = False
    ) -> TFLiteModel:
        """Convert model to TensorFlow Lite format."""
        converted_model = convert_to_tflite(self.model_path, quantized)
        save_tflite_model(converted_model, tflite_model_path)

        return TFLiteModel(tflite_model_path)


def get_model(model: str | Path) -> ModelConfiguration:
    """Return the model object."""
    if is_tflite_model(model):
        return TFLiteModel(model)

    if is_keras_model(model):
        return KerasModel(model)

    if is_saved_model(model):
        return TfModel(model)

    raise ValueError(
        "The input model format is not supported "
        "(supported formats: TensorFlow Lite, Keras, TensorFlow saved model)!"
    )


def get_tflite_model(model: str | Path, ctx: Context) -> TFLiteModel:
    """Convert input model to TensorFlow Lite and returns TFLiteModel object."""
    tflite_model_path = ctx.get_model_path("converted_model.tflite")
    converted_model = get_model(model)

    return converted_model.convert_to_tflite(tflite_model_path, True)


def get_keras_model(model: str | Path, ctx: Context) -> KerasModel:
    """Convert input model to Keras and returns KerasModel object."""
    keras_model_path = ctx.get_model_path("converted_model.h5")
    converted_model = get_model(model)

    return converted_model.convert_to_keras(keras_model_path)