aboutsummaryrefslogtreecommitdiff
path: root/src/mlia/nn/tensorflow/config.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/mlia/nn/tensorflow/config.py')
-rw-r--r--src/mlia/nn/tensorflow/config.py134
1 files changed, 134 insertions, 0 deletions
diff --git a/src/mlia/nn/tensorflow/config.py b/src/mlia/nn/tensorflow/config.py
new file mode 100644
index 0000000..d3235d7
--- /dev/null
+++ b/src/mlia/nn/tensorflow/config.py
@@ -0,0 +1,134 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Model configuration."""
+import logging
+from pathlib import Path
+from typing import cast
+from typing import Dict
+from typing import List
+from typing import Union
+
+import tensorflow as tf
+
+from mlia.core.context import Context
+from mlia.nn.tensorflow.utils import convert_tf_to_tflite
+from mlia.nn.tensorflow.utils import convert_to_tflite
+from mlia.nn.tensorflow.utils import is_keras_model
+from mlia.nn.tensorflow.utils import is_tf_model
+from mlia.nn.tensorflow.utils import is_tflite_model
+from mlia.nn.tensorflow.utils import save_tflite_model
+
+logger = logging.getLogger(__name__)
+
+
+class ModelConfiguration:
+ """Base class for model configuration."""
+
+ def __init__(self, model_path: Union[str, Path]) -> None:
+ """Init model configuration instance."""
+ self.model_path = str(model_path)
+
+ def convert_to_tflite(
+ self, tflite_model_path: Union[str, Path], quantized: bool = False
+ ) -> "TFLiteModel":
+ """Convert model to TFLite format."""
+ raise NotImplementedError()
+
+ def convert_to_keras(self, keras_model_path: Union[str, Path]) -> "KerasModel":
+ """Convert model to Keras format."""
+ raise NotImplementedError()
+
+
+class KerasModel(ModelConfiguration):
+ """Keras model configuration.
+
+ Supports all models supported by Keras API: saved model, H5, HDF5
+ """
+
+ def get_keras_model(self) -> tf.keras.Model:
+ """Return associated Keras model."""
+ return tf.keras.models.load_model(self.model_path)
+
+ def convert_to_tflite(
+ self, tflite_model_path: Union[str, Path], quantized: bool = False
+ ) -> "TFLiteModel":
+ """Convert model to TFLite format."""
+ logger.info("Converting Keras to TFLite ...")
+
+ converted_model = convert_to_tflite(self.get_keras_model(), quantized)
+ logger.info("Done\n")
+
+ save_tflite_model(converted_model, tflite_model_path)
+ logger.debug(
+ "Model %s converted and saved to %s", self.model_path, tflite_model_path
+ )
+
+ return TFLiteModel(tflite_model_path)
+
+ def convert_to_keras(self, keras_model_path: Union[str, Path]) -> "KerasModel":
+ """Convert model to Keras format."""
+ return self
+
+
+class TFLiteModel(ModelConfiguration): # pylint: disable=abstract-method
+ """TFLite model configuration."""
+
+ def input_details(self) -> List[Dict]:
+ """Get model's input details."""
+ interpreter = tf.lite.Interpreter(model_path=self.model_path)
+ return cast(List[Dict], interpreter.get_input_details())
+
+ def convert_to_tflite(
+ self, tflite_model_path: Union[str, Path], quantized: bool = False
+ ) -> "TFLiteModel":
+ """Convert model to TFLite format."""
+ return self
+
+
+class TfModel(ModelConfiguration): # pylint: disable=abstract-method
+ """TensorFlow model configuration.
+
+ Supports models supported by TensorFlow API (not Keras)
+ """
+
+ def convert_to_tflite(
+ self, tflite_model_path: Union[str, Path], quantized: bool = False
+ ) -> "TFLiteModel":
+ """Convert model to TFLite format."""
+ converted_model = convert_tf_to_tflite(self.model_path, quantized)
+ save_tflite_model(converted_model, tflite_model_path)
+
+ return TFLiteModel(tflite_model_path)
+
+
+def get_model(model: Union[Path, str]) -> "ModelConfiguration":
+ """Return the model object."""
+ if is_tflite_model(model):
+ return TFLiteModel(model)
+
+ if is_keras_model(model):
+ return KerasModel(model)
+
+ if is_tf_model(model):
+ return TfModel(model)
+
+ raise Exception(
+ "The input model format is not supported"
+ "(supported formats: TFLite, Keras, TensorFlow saved model)!"
+ )
+
+
+def get_tflite_model(model: Union[str, Path], ctx: Context) -> "TFLiteModel":
+ """Convert input model to TFLite and returns TFLiteModel object."""
+ tflite_model_path = ctx.get_model_path("converted_model.tflite")
+ converted_model = get_model(model)
+
+ return converted_model.convert_to_tflite(tflite_model_path, True)
+
+
+def get_keras_model(model: Union[str, Path], ctx: Context) -> "KerasModel":
+ """Convert input model to Keras and returns KerasModel object."""
+ keras_model_path = ctx.get_model_path("converted_model.h5")
+ converted_model = get_model(model)
+
+ return converted_model.convert_to_keras(keras_model_path)