diff options
Diffstat (limited to 'src/mlia/nn/tensorflow/config.py')
-rw-r--r-- | src/mlia/nn/tensorflow/config.py | 14 |
1 files changed, 6 insertions, 8 deletions
diff --git a/src/mlia/nn/tensorflow/config.py b/src/mlia/nn/tensorflow/config.py index 03d1d0f..0c3133a 100644 --- a/src/mlia/nn/tensorflow/config.py +++ b/src/mlia/nn/tensorflow/config.py @@ -11,12 +11,12 @@ from typing import List import tensorflow as tf from mlia.core.context import Context -from mlia.nn.tensorflow.utils import convert_tf_to_tflite from mlia.nn.tensorflow.utils import convert_to_tflite from mlia.nn.tensorflow.utils import is_keras_model -from mlia.nn.tensorflow.utils import is_tf_model +from mlia.nn.tensorflow.utils import is_saved_model from mlia.nn.tensorflow.utils import is_tflite_model from mlia.nn.tensorflow.utils import save_tflite_model +from mlia.utils.logging import log_action logger = logging.getLogger(__name__) @@ -53,10 +53,8 @@ class KerasModel(ModelConfiguration): self, tflite_model_path: str | Path, quantized: bool = False ) -> TFLiteModel: """Convert model to TensorFlow Lite format.""" - logger.info("Converting Keras to TensorFlow Lite ...") - - converted_model = convert_to_tflite(self.get_keras_model(), quantized) - logger.info("Done\n") + with log_action("Converting Keras to TensorFlow Lite ..."): + converted_model = convert_to_tflite(self.get_keras_model(), quantized) save_tflite_model(converted_model, tflite_model_path) logger.debug( @@ -95,7 +93,7 @@ class TfModel(ModelConfiguration): # pylint: disable=abstract-method self, tflite_model_path: str | Path, quantized: bool = False ) -> TFLiteModel: """Convert model to TensorFlow Lite format.""" - converted_model = convert_tf_to_tflite(self.model_path, quantized) + converted_model = convert_to_tflite(self.model_path, quantized) save_tflite_model(converted_model, tflite_model_path) return TFLiteModel(tflite_model_path) @@ -109,7 +107,7 @@ def get_model(model: str | Path) -> ModelConfiguration: if is_keras_model(model): return KerasModel(model) - if is_tf_model(model): + if is_saved_model(model): return TfModel(model) raise Exception( |