aboutsummaryrefslogtreecommitdiff
path: root/src/mlia/nn/tensorflow/tflite_metrics.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/mlia/nn/tensorflow/tflite_metrics.py')
-rw-r--r--src/mlia/nn/tensorflow/tflite_metrics.py10
1 files changed, 5 insertions, 5 deletions
diff --git a/src/mlia/nn/tensorflow/tflite_metrics.py b/src/mlia/nn/tensorflow/tflite_metrics.py
index 0af7500..d7ae2a4 100644
--- a/src/mlia/nn/tensorflow/tflite_metrics.py
+++ b/src/mlia/nn/tensorflow/tflite_metrics.py
@@ -1,7 +1,7 @@
# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""
-Contains class TFLiteMetrics to calculate metrics from a TFLite file.
+Contains class TFLiteMetrics to calculate metrics from a TensorFlow Lite file.
These metrics include:
* Sparsity (per layer and overall)
@@ -102,7 +102,7 @@ class ReportClusterMode(Enum):
class TFLiteMetrics:
- """Helper class to calculate metrics from a TFLite file.
+ """Helper class to calculate metrics from a TensorFlow Lite file.
Metrics include:
* sparsity (per-layer and overall)
@@ -111,12 +111,12 @@ class TFLiteMetrics:
"""
def __init__(self, tflite_file: str, ignore_list: list[str] | None = None) -> None:
- """Load the TFLite file and filter layers."""
+ """Load the TensorFlow Lite file and filter layers."""
self.tflite_file = tflite_file
if ignore_list is None:
ignore_list = DEFAULT_IGNORE_LIST
self.ignore_list = [ignore.casefold() for ignore in ignore_list]
- # Initialize the TFLite interpreter with the model file
+ # Initialize the TensorFlow Lite interpreter with the model file
self.interpreter = tf.lite.Interpreter(
model_path=tflite_file, experimental_preserve_all_tensors=True
)
@@ -218,7 +218,7 @@ class TFLiteMetrics:
"""Print a summary of all the model information."""
print(f"Model file: {self.tflite_file}")
print("#" * 80)
- print(" " * 28 + "### TFLITE SUMMARY ###")
+ print(" " * 28 + "### TENSORFLOW LITE SUMMARY ###")
print(f"File: {os.path.abspath(self.tflite_file)}")
print("Input(s):")
self._print_in_outs(self.interpreter.get_input_details(), verbose)