aboutsummaryrefslogtreecommitdiff
path: root/src/mlia/nn/tensorflow/config.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/mlia/nn/tensorflow/config.py')
-rw-r--r--src/mlia/nn/tensorflow/config.py16
1 files changed, 8 insertions, 8 deletions
diff --git a/src/mlia/nn/tensorflow/config.py b/src/mlia/nn/tensorflow/config.py
index 6ee32e7..03d1d0f 100644
--- a/src/mlia/nn/tensorflow/config.py
+++ b/src/mlia/nn/tensorflow/config.py
@@ -31,7 +31,7 @@ class ModelConfiguration:
def convert_to_tflite(
self, tflite_model_path: str | Path, quantized: bool = False
) -> TFLiteModel:
- """Convert model to TFLite format."""
+ """Convert model to TensorFlow Lite format."""
raise NotImplementedError()
def convert_to_keras(self, keras_model_path: str | Path) -> KerasModel:
@@ -52,8 +52,8 @@ class KerasModel(ModelConfiguration):
def convert_to_tflite(
self, tflite_model_path: str | Path, quantized: bool = False
) -> TFLiteModel:
- """Convert model to TFLite format."""
- logger.info("Converting Keras to TFLite ...")
+ """Convert model to TensorFlow Lite format."""
+ logger.info("Converting Keras to TensorFlow Lite ...")
converted_model = convert_to_tflite(self.get_keras_model(), quantized)
logger.info("Done\n")
@@ -71,7 +71,7 @@ class KerasModel(ModelConfiguration):
class TFLiteModel(ModelConfiguration): # pylint: disable=abstract-method
- """TFLite model configuration."""
+ """TensorFlow Lite model configuration."""
def input_details(self) -> list[dict]:
"""Get model's input details."""
@@ -81,7 +81,7 @@ class TFLiteModel(ModelConfiguration): # pylint: disable=abstract-method
def convert_to_tflite(
self, tflite_model_path: str | Path, quantized: bool = False
) -> TFLiteModel:
- """Convert model to TFLite format."""
+ """Convert model to TensorFlow Lite format."""
return self
@@ -94,7 +94,7 @@ class TfModel(ModelConfiguration): # pylint: disable=abstract-method
def convert_to_tflite(
self, tflite_model_path: str | Path, quantized: bool = False
) -> TFLiteModel:
- """Convert model to TFLite format."""
+ """Convert model to TensorFlow Lite format."""
converted_model = convert_tf_to_tflite(self.model_path, quantized)
save_tflite_model(converted_model, tflite_model_path)
@@ -114,12 +114,12 @@ def get_model(model: str | Path) -> ModelConfiguration:
raise Exception(
"The input model format is not supported"
- "(supported formats: TFLite, Keras, TensorFlow saved model)!"
+ "(supported formats: TensorFlow Lite, Keras, TensorFlow saved model)!"
)
def get_tflite_model(model: str | Path, ctx: Context) -> TFLiteModel:
- """Convert input model to TFLite and returns TFLiteModel object."""
+ """Convert input model to TensorFlow Lite and returns TFLiteModel object."""
tflite_model_path = ctx.get_model_path("converted_model.tflite")
converted_model = get_model(model)