aboutsummaryrefslogtreecommitdiff
path: root/src/mlia/nn
diff options
context:
space:
mode:
authorDmitrii Agibov <dmitrii.agibov@arm.com>2022-10-07 11:39:37 +0100
committerDmitrii Agibov <dmitrii.agibov@arm.com>2022-10-07 11:40:21 +0100
commit3083f7ee68ce08147db08fca2474e5f4712fc8d7 (patch)
treec52e668c01a6a1041c08190e52a15944fd65b453 /src/mlia/nn
parentbb7fb49484bb3687041061b2fdbbfae3959be54b (diff)
downloadmlia-3083f7ee68ce08147db08fca2474e5f4712fc8d7.tar.gz
MLIA-607 Update documentation and comments
Use "TensorFlow Lite" instead of "TFLite" in documentation and comments Change-Id: Ie4450d72fb2e5261d152d72ab8bd94c3da914c46
Diffstat (limited to 'src/mlia/nn')
-rw-r--r--src/mlia/nn/tensorflow/config.py16
-rw-r--r--src/mlia/nn/tensorflow/tflite_metrics.py10
-rw-r--r--src/mlia/nn/tensorflow/utils.py10
3 files changed, 18 insertions, 18 deletions
diff --git a/src/mlia/nn/tensorflow/config.py b/src/mlia/nn/tensorflow/config.py
index 6ee32e7..03d1d0f 100644
--- a/src/mlia/nn/tensorflow/config.py
+++ b/src/mlia/nn/tensorflow/config.py
@@ -31,7 +31,7 @@ class ModelConfiguration:
def convert_to_tflite(
self, tflite_model_path: str | Path, quantized: bool = False
) -> TFLiteModel:
- """Convert model to TFLite format."""
+ """Convert model to TensorFlow Lite format."""
raise NotImplementedError()
def convert_to_keras(self, keras_model_path: str | Path) -> KerasModel:
@@ -52,8 +52,8 @@ class KerasModel(ModelConfiguration):
def convert_to_tflite(
self, tflite_model_path: str | Path, quantized: bool = False
) -> TFLiteModel:
- """Convert model to TFLite format."""
- logger.info("Converting Keras to TFLite ...")
+ """Convert model to TensorFlow Lite format."""
+ logger.info("Converting Keras to TensorFlow Lite ...")
converted_model = convert_to_tflite(self.get_keras_model(), quantized)
logger.info("Done\n")
@@ -71,7 +71,7 @@ class KerasModel(ModelConfiguration):
class TFLiteModel(ModelConfiguration): # pylint: disable=abstract-method
- """TFLite model configuration."""
+ """TensorFlow Lite model configuration."""
def input_details(self) -> list[dict]:
"""Get model's input details."""
@@ -81,7 +81,7 @@ class TFLiteModel(ModelConfiguration): # pylint: disable=abstract-method
def convert_to_tflite(
self, tflite_model_path: str | Path, quantized: bool = False
) -> TFLiteModel:
- """Convert model to TFLite format."""
+ """Convert model to TensorFlow Lite format."""
return self
@@ -94,7 +94,7 @@ class TfModel(ModelConfiguration): # pylint: disable=abstract-method
def convert_to_tflite(
self, tflite_model_path: str | Path, quantized: bool = False
) -> TFLiteModel:
- """Convert model to TFLite format."""
+ """Convert model to TensorFlow Lite format."""
converted_model = convert_tf_to_tflite(self.model_path, quantized)
save_tflite_model(converted_model, tflite_model_path)
@@ -114,12 +114,12 @@ def get_model(model: str | Path) -> ModelConfiguration:
raise Exception(
"The input model format is not supported"
- "(supported formats: TFLite, Keras, TensorFlow saved model)!"
+ "(supported formats: TensorFlow Lite, Keras, TensorFlow saved model)!"
)
def get_tflite_model(model: str | Path, ctx: Context) -> TFLiteModel:
- """Convert input model to TFLite and returns TFLiteModel object."""
+ """Convert input model to TensorFlow Lite and returns TFLiteModel object."""
tflite_model_path = ctx.get_model_path("converted_model.tflite")
converted_model = get_model(model)
diff --git a/src/mlia/nn/tensorflow/tflite_metrics.py b/src/mlia/nn/tensorflow/tflite_metrics.py
index 0af7500..d7ae2a4 100644
--- a/src/mlia/nn/tensorflow/tflite_metrics.py
+++ b/src/mlia/nn/tensorflow/tflite_metrics.py
@@ -1,7 +1,7 @@
# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""
-Contains class TFLiteMetrics to calculate metrics from a TFLite file.
+Contains class TFLiteMetrics to calculate metrics from a TensorFlow Lite file.
These metrics include:
* Sparsity (per layer and overall)
@@ -102,7 +102,7 @@ class ReportClusterMode(Enum):
class TFLiteMetrics:
- """Helper class to calculate metrics from a TFLite file.
+ """Helper class to calculate metrics from a TensorFlow Lite file.
Metrics include:
* sparsity (per-layer and overall)
@@ -111,12 +111,12 @@ class TFLiteMetrics:
"""
def __init__(self, tflite_file: str, ignore_list: list[str] | None = None) -> None:
- """Load the TFLite file and filter layers."""
+ """Load the TensorFlow Lite file and filter layers."""
self.tflite_file = tflite_file
if ignore_list is None:
ignore_list = DEFAULT_IGNORE_LIST
self.ignore_list = [ignore.casefold() for ignore in ignore_list]
- # Initialize the TFLite interpreter with the model file
+ # Initialize the TensorFlow Lite interpreter with the model file
self.interpreter = tf.lite.Interpreter(
model_path=tflite_file, experimental_preserve_all_tensors=True
)
@@ -218,7 +218,7 @@ class TFLiteMetrics:
"""Print a summary of all the model information."""
print(f"Model file: {self.tflite_file}")
print("#" * 80)
- print(" " * 28 + "### TFLITE SUMMARY ###")
+ print(" " * 28 + "### TENSORFLOW LITE SUMMARY ###")
print(f"File: {os.path.abspath(self.tflite_file)}")
print("Input(s):")
self._print_in_outs(self.interpreter.get_input_details(), verbose)
diff --git a/src/mlia/nn/tensorflow/utils.py b/src/mlia/nn/tensorflow/utils.py
index 6250f56..7970329 100644
--- a/src/mlia/nn/tensorflow/utils.py
+++ b/src/mlia/nn/tensorflow/utils.py
@@ -63,7 +63,7 @@ def representative_tf_dataset(model: str) -> Callable:
def convert_to_tflite(model: tf.keras.Model, quantized: bool = False) -> Interpreter:
- """Convert Keras model to TFLite."""
+ """Convert Keras model to TensorFlow Lite."""
if not isinstance(model, tf.keras.Model):
raise Exception("Invalid model type")
@@ -83,7 +83,7 @@ def convert_to_tflite(model: tf.keras.Model, quantized: bool = False) -> Interpr
def convert_tf_to_tflite(model: str, quantized: bool = False) -> Interpreter:
- """Convert TensorFlow model to TFLite."""
+ """Convert TensorFlow model to TensorFlow Lite."""
if not isinstance(model, str):
raise Exception("Invalid model type")
@@ -109,15 +109,15 @@ def save_keras_model(model: tf.keras.Model, save_path: str | Path) -> None:
def save_tflite_model(model: tf.lite.TFLiteConverter, save_path: str | Path) -> None:
- """Save TFLite model at provided path."""
+ """Save TensorFlow Lite model at provided path."""
with open(save_path, "wb") as file:
file.write(model)
def is_tflite_model(model: str | Path) -> bool:
- """Check if model type is supported by TFLite API.
+ """Check if model type is supported by TensorFlow Lite API.
- TFLite model is indicated by the model file extension .tflite
+ TensorFlow Lite model is indicated by the model file extension .tflite
"""
model_path = Path(model)
return model_path.suffix == ".tflite"