aboutsummaryrefslogtreecommitdiff
path: root/src/mlia/nn/tensorflow/utils.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/mlia/nn/tensorflow/utils.py')
-rw-r--r--src/mlia/nn/tensorflow/utils.py10
1 files changed, 5 insertions, 5 deletions
diff --git a/src/mlia/nn/tensorflow/utils.py b/src/mlia/nn/tensorflow/utils.py
index 6250f56..7970329 100644
--- a/src/mlia/nn/tensorflow/utils.py
+++ b/src/mlia/nn/tensorflow/utils.py
@@ -63,7 +63,7 @@ def representative_tf_dataset(model: str) -> Callable:
def convert_to_tflite(model: tf.keras.Model, quantized: bool = False) -> Interpreter:
- """Convert Keras model to TFLite."""
+ """Convert Keras model to TensorFlow Lite."""
if not isinstance(model, tf.keras.Model):
raise Exception("Invalid model type")
@@ -83,7 +83,7 @@ def convert_to_tflite(model: tf.keras.Model, quantized: bool = False) -> Interpr
def convert_tf_to_tflite(model: str, quantized: bool = False) -> Interpreter:
- """Convert TensorFlow model to TFLite."""
+ """Convert TensorFlow model to TensorFlow Lite."""
if not isinstance(model, str):
raise Exception("Invalid model type")
@@ -109,15 +109,15 @@ def save_keras_model(model: tf.keras.Model, save_path: str | Path) -> None:
def save_tflite_model(model: tf.lite.TFLiteConverter, save_path: str | Path) -> None:
- """Save TFLite model at provided path."""
+ """Save TensorFlow Lite model at provided path."""
with open(save_path, "wb") as file:
file.write(model)
def is_tflite_model(model: str | Path) -> bool:
- """Check if model type is supported by TFLite API.
+ """Check if model type is supported by TensorFlow Lite API.
- TFLite model is indicated by the model file extension .tflite
+ TensorFlow Lite model is indicated by the model file extension .tflite
"""
model_path = Path(model)
return model_path.suffix == ".tflite"