diff options
author | Dmitrii Agibov <dmitrii.agibov@arm.com> | 2022-10-24 15:08:08 +0100 |
---|---|---|
committer | Dmitrii Agibov <dmitrii.agibov@arm.com> | 2022-10-26 17:08:13 +0100 |
commit | 58a65fee574c00329cf92b387a6d2513dcbf6100 (patch) | |
tree | 47e3185f78b4298ab029785ddee68456e44cac10 /src/mlia/nn/tensorflow/config.py | |
parent | 9d34cb72d45a6d0a2ec1063ebf32536c1efdba75 (diff) | |
download | mlia-58a65fee574c00329cf92b387a6d2513dcbf6100.tar.gz |
MLIA-433 Add TensorFlow Lite compatibility check
- Add ability to intercept low level TensorFlow output
- Produce advice for the models that could not be
converted to the TensorFlow Lite format
- Refactor utility functions for TensorFlow Lite
conversion
- Add TensorFlow Lite compatibility checker
Change-Id: I47d120d2619ced7b143bc92c5184515b81c0220d
Diffstat (limited to 'src/mlia/nn/tensorflow/config.py')
-rw-r--r-- | src/mlia/nn/tensorflow/config.py | 14 |
1 files changed, 6 insertions, 8 deletions
diff --git a/src/mlia/nn/tensorflow/config.py b/src/mlia/nn/tensorflow/config.py index 03d1d0f..0c3133a 100644 --- a/src/mlia/nn/tensorflow/config.py +++ b/src/mlia/nn/tensorflow/config.py @@ -11,12 +11,12 @@ from typing import List import tensorflow as tf from mlia.core.context import Context -from mlia.nn.tensorflow.utils import convert_tf_to_tflite from mlia.nn.tensorflow.utils import convert_to_tflite from mlia.nn.tensorflow.utils import is_keras_model -from mlia.nn.tensorflow.utils import is_tf_model +from mlia.nn.tensorflow.utils import is_saved_model from mlia.nn.tensorflow.utils import is_tflite_model from mlia.nn.tensorflow.utils import save_tflite_model +from mlia.utils.logging import log_action logger = logging.getLogger(__name__) @@ -53,10 +53,8 @@ class KerasModel(ModelConfiguration): self, tflite_model_path: str | Path, quantized: bool = False ) -> TFLiteModel: """Convert model to TensorFlow Lite format.""" - logger.info("Converting Keras to TensorFlow Lite ...") - - converted_model = convert_to_tflite(self.get_keras_model(), quantized) - logger.info("Done\n") + with log_action("Converting Keras to TensorFlow Lite ..."): + converted_model = convert_to_tflite(self.get_keras_model(), quantized) save_tflite_model(converted_model, tflite_model_path) logger.debug( @@ -95,7 +93,7 @@ class TfModel(ModelConfiguration): # pylint: disable=abstract-method self, tflite_model_path: str | Path, quantized: bool = False ) -> TFLiteModel: """Convert model to TensorFlow Lite format.""" - converted_model = convert_tf_to_tflite(self.model_path, quantized) + converted_model = convert_to_tflite(self.model_path, quantized) save_tflite_model(converted_model, tflite_model_path) return TFLiteModel(tflite_model_path) @@ -109,7 +107,7 @@ def get_model(model: str | Path) -> ModelConfiguration: if is_keras_model(model): return KerasModel(model) - if is_tf_model(model): + if is_saved_model(model): return TfModel(model) raise Exception( |