aboutsummaryrefslogtreecommitdiff
path: root/src/mlia/nn/tensorflow/config.py
blob: d3235d72c256e2a15045d532e00c75e9d5cc3cf1 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Model configuration."""
import logging
from pathlib import Path
from typing import cast
from typing import Dict
from typing import List
from typing import Union

import tensorflow as tf

from mlia.core.context import Context
from mlia.nn.tensorflow.utils import convert_tf_to_tflite
from mlia.nn.tensorflow.utils import convert_to_tflite
from mlia.nn.tensorflow.utils import is_keras_model
from mlia.nn.tensorflow.utils import is_tf_model
from mlia.nn.tensorflow.utils import is_tflite_model
from mlia.nn.tensorflow.utils import save_tflite_model

logger = logging.getLogger(__name__)


class ModelConfiguration:
    """Common interface shared by every supported model format.

    A configuration wraps a model that lives on disk. Concrete subclasses
    override the conversion methods for the target formats they can produce;
    the base implementations only signal that a conversion is unsupported.
    """

    def __init__(self, model_path: Union[str, Path]) -> None:
        """Remember where the model is stored.

        :param model_path: location of the model; normalised to ``str``
            so that ``self.model_path`` is always a plain string.
        """
        normalised_path = str(model_path)
        self.model_path = normalised_path

    def convert_to_tflite(
        self, tflite_model_path: Union[str, Path], quantized: bool = False
    ) -> "TFLiteModel":
        """Produce a TFLite version of this model.

        :param tflite_model_path: destination for the converted model
        :param quantized: whether the subclass should quantize on conversion
        :raises NotImplementedError: always, unless a subclass overrides this
        """
        raise NotImplementedError

    def convert_to_keras(self, keras_model_path: Union[str, Path]) -> "KerasModel":
        """Produce a Keras version of this model.

        :param keras_model_path: destination for the converted model
        :raises NotImplementedError: always, unless a subclass overrides this
        """
        raise NotImplementedError


class KerasModel(ModelConfiguration):
    """Configuration for a model that is loadable through the Keras API.

    Supports all models supported by Keras API: saved model, H5, HDF5
    """

    def get_keras_model(self) -> tf.keras.Model:
        """Load the model at ``self.model_path`` with Keras and return it."""
        loaded = tf.keras.models.load_model(self.model_path)
        return loaded

    def convert_to_tflite(
        self, tflite_model_path: Union[str, Path], quantized: bool = False
    ) -> "TFLiteModel":
        """Convert the Keras model to TFLite and write the result to disk.

        :param tflite_model_path: where the converted model is saved
        :param quantized: passed through unchanged to the conversion helper
        :return: a ``TFLiteModel`` pointing at ``tflite_model_path``
        """
        logger.info("Converting Keras to TFLite ...")

        # ``convert_to_tflite`` here is the module-level helper imported from
        # mlia.nn.tensorflow.utils, not this method.
        keras_model = self.get_keras_model()
        tflite_model = convert_to_tflite(keras_model, quantized)
        logger.info("Done\n")

        save_tflite_model(tflite_model, tflite_model_path)
        logger.debug(
            "Model %s converted and saved to %s", self.model_path, tflite_model_path
        )

        result = TFLiteModel(tflite_model_path)
        return result

    def convert_to_keras(self, keras_model_path: Union[str, Path]) -> "KerasModel":
        """Return this instance unchanged; the model is already in Keras format.

        :param keras_model_path: ignored
        """
        return self


class TFLiteModel(ModelConfiguration):  # pylint: disable=abstract-method
    """Configuration for a model that is already in TFLite format."""

    def input_details(self) -> List[Dict]:
        """Return the input details a TFLite interpreter reports for the model."""
        tflite_interpreter = tf.lite.Interpreter(model_path=self.model_path)
        details = tflite_interpreter.get_input_details()
        return cast(List[Dict], details)

    def convert_to_tflite(
        self, tflite_model_path: Union[str, Path], quantized: bool = False
    ) -> "TFLiteModel":
        """Return this instance unchanged; no conversion is required.

        :param tflite_model_path: ignored
        :param quantized: ignored
        """
        return self


class TfModel(ModelConfiguration):  # pylint: disable=abstract-method
    """TensorFlow model configuration.

    Supports models supported by TensorFlow API (not Keras)
    """

    def convert_to_tflite(
        self, tflite_model_path: Union[str, Path], quantized: bool = False
    ) -> "TFLiteModel":
        """Convert the TensorFlow model to TFLite and write the result to disk.

        :param tflite_model_path: where the converted model is saved
        :param quantized: passed through unchanged to ``convert_tf_to_tflite``
        :return: a ``TFLiteModel`` pointing at ``tflite_model_path``
        """
        converted_model = convert_tf_to_tflite(self.model_path, quantized)
        save_tflite_model(converted_model, tflite_model_path)
        # Same trace as KerasModel.convert_to_tflite so every conversion path
        # records where its output ended up.
        logger.debug(
            "Model %s converted and saved to %s", self.model_path, tflite_model_path
        )

        return TFLiteModel(tflite_model_path)


def get_model(model: Union[Path, str]) -> "ModelConfiguration":
    """Return the configuration object matching the format of ``model``.

    Formats are checked in order: TFLite, Keras, then TensorFlow.

    :param model: path to the model
    :return: a ``TFLiteModel``, ``KerasModel`` or ``TfModel`` instance
    :raises ValueError: if the model matches none of the supported formats
        (``ValueError`` subclasses ``Exception``, so existing
        ``except Exception`` handlers still catch it)
    """
    if is_tflite_model(model):
        return TFLiteModel(model)

    if is_keras_model(model):
        return KerasModel(model)

    if is_tf_model(model):
        return TfModel(model)

    # The trailing space matters: adjacent literals are concatenated verbatim.
    raise ValueError(
        "The input model format is not supported "
        "(supported formats: TFLite, Keras, TensorFlow saved model)!"
    )


def get_tflite_model(model: Union[str, Path], ctx: Context) -> "TFLiteModel":
    """Return a TFLite configuration for ``model``, converting when needed.

    Models not already in TFLite format are converted with ``quantized=True``
    and saved at the path ``ctx.get_model_path("converted_model.tflite")``
    returns. A TFLite input is returned as-is.
    """
    destination = ctx.get_model_path("converted_model.tflite")
    model_config = get_model(model)
    return model_config.convert_to_tflite(destination, quantized=True)


def get_keras_model(model: Union[str, Path], ctx: Context) -> "KerasModel":
    """Return a Keras configuration for ``model``.

    The target path is ``ctx.get_model_path("converted_model.h5")``. A Keras
    input is returned as-is. Formats without a Keras conversion raise
    ``NotImplementedError`` from the base class.
    """
    destination = ctx.get_model_path("converted_model.h5")
    model_config = get_model(model)
    return model_config.convert_to_keras(destination)