aboutsummaryrefslogtreecommitdiff
path: root/src/mlia/nn/tensorflow/config.py
diff options
context:
space:
mode:
authorDmitrii Agibov <dmitrii.agibov@arm.com>2022-10-24 15:08:08 +0100
committerDmitrii Agibov <dmitrii.agibov@arm.com>2022-10-26 17:08:13 +0100
commit58a65fee574c00329cf92b387a6d2513dcbf6100 (patch)
tree47e3185f78b4298ab029785ddee68456e44cac10 /src/mlia/nn/tensorflow/config.py
parent9d34cb72d45a6d0a2ec1063ebf32536c1efdba75 (diff)
downloadmlia-58a65fee574c00329cf92b387a6d2513dcbf6100.tar.gz
MLIA-433 Add TensorFlow Lite compatibility check
- Add ability to intercept low level TensorFlow output - Produce advice for the models that could not be converted to the TensorFlow Lite format - Refactor utility functions for TensorFlow Lite conversion - Add TensorFlow Lite compatibility checker Change-Id: I47d120d2619ced7b143bc92c5184515b81c0220d
Diffstat (limited to 'src/mlia/nn/tensorflow/config.py')
-rw-r--r--src/mlia/nn/tensorflow/config.py14
1 files changed, 6 insertions, 8 deletions
diff --git a/src/mlia/nn/tensorflow/config.py b/src/mlia/nn/tensorflow/config.py
index 03d1d0f..0c3133a 100644
--- a/src/mlia/nn/tensorflow/config.py
+++ b/src/mlia/nn/tensorflow/config.py
@@ -11,12 +11,12 @@ from typing import List
import tensorflow as tf
from mlia.core.context import Context
-from mlia.nn.tensorflow.utils import convert_tf_to_tflite
from mlia.nn.tensorflow.utils import convert_to_tflite
from mlia.nn.tensorflow.utils import is_keras_model
-from mlia.nn.tensorflow.utils import is_tf_model
+from mlia.nn.tensorflow.utils import is_saved_model
from mlia.nn.tensorflow.utils import is_tflite_model
from mlia.nn.tensorflow.utils import save_tflite_model
+from mlia.utils.logging import log_action
logger = logging.getLogger(__name__)
@@ -53,10 +53,8 @@ class KerasModel(ModelConfiguration):
self, tflite_model_path: str | Path, quantized: bool = False
) -> TFLiteModel:
"""Convert model to TensorFlow Lite format."""
- logger.info("Converting Keras to TensorFlow Lite ...")
-
- converted_model = convert_to_tflite(self.get_keras_model(), quantized)
- logger.info("Done\n")
+ with log_action("Converting Keras to TensorFlow Lite ..."):
+ converted_model = convert_to_tflite(self.get_keras_model(), quantized)
save_tflite_model(converted_model, tflite_model_path)
logger.debug(
@@ -95,7 +93,7 @@ class TfModel(ModelConfiguration): # pylint: disable=abstract-method
self, tflite_model_path: str | Path, quantized: bool = False
) -> TFLiteModel:
"""Convert model to TensorFlow Lite format."""
- converted_model = convert_tf_to_tflite(self.model_path, quantized)
+ converted_model = convert_to_tflite(self.model_path, quantized)
save_tflite_model(converted_model, tflite_model_path)
return TFLiteModel(tflite_model_path)
@@ -109,7 +107,7 @@ def get_model(model: str | Path) -> ModelConfiguration:
if is_keras_model(model):
return KerasModel(model)
- if is_tf_model(model):
+ if is_saved_model(model):
return TfModel(model)
raise Exception(