From 3083f7ee68ce08147db08fca2474e5f4712fc8d7 Mon Sep 17 00:00:00 2001 From: Dmitrii Agibov Date: Fri, 7 Oct 2022 11:39:37 +0100 Subject: MLIA-607 Update documentation and comments Use "TensorFlow Lite" instead of "TFLite" in documentation and comments Change-Id: Ie4450d72fb2e5261d152d72ab8bd94c3da914c46 --- src/mlia/nn/tensorflow/config.py | 16 ++++++++-------- src/mlia/nn/tensorflow/tflite_metrics.py | 10 +++++----- src/mlia/nn/tensorflow/utils.py | 10 +++++----- 3 files changed, 18 insertions(+), 18 deletions(-) (limited to 'src/mlia/nn') diff --git a/src/mlia/nn/tensorflow/config.py b/src/mlia/nn/tensorflow/config.py index 6ee32e7..03d1d0f 100644 --- a/src/mlia/nn/tensorflow/config.py +++ b/src/mlia/nn/tensorflow/config.py @@ -31,7 +31,7 @@ class ModelConfiguration: def convert_to_tflite( self, tflite_model_path: str | Path, quantized: bool = False ) -> TFLiteModel: - """Convert model to TFLite format.""" + """Convert model to TensorFlow Lite format.""" raise NotImplementedError() def convert_to_keras(self, keras_model_path: str | Path) -> KerasModel: @@ -52,8 +52,8 @@ class KerasModel(ModelConfiguration): def convert_to_tflite( self, tflite_model_path: str | Path, quantized: bool = False ) -> TFLiteModel: - """Convert model to TFLite format.""" - logger.info("Converting Keras to TFLite ...") + """Convert model to TensorFlow Lite format.""" + logger.info("Converting Keras to TensorFlow Lite ...") converted_model = convert_to_tflite(self.get_keras_model(), quantized) logger.info("Done\n") @@ -71,7 +71,7 @@ class KerasModel(ModelConfiguration): class TFLiteModel(ModelConfiguration): # pylint: disable=abstract-method - """TFLite model configuration.""" + """TensorFlow Lite model configuration.""" def input_details(self) -> list[dict]: """Get model's input details.""" @@ -81,7 +81,7 @@ class TFLiteModel(ModelConfiguration): # pylint: disable=abstract-method def convert_to_tflite( self, tflite_model_path: str | Path, quantized: bool = False ) -> TFLiteModel: - """Convert model to TFLite format.""" + """Convert model to TensorFlow Lite format.""" return self @@ -94,7 +94,7 @@ class TfModel(ModelConfiguration): # pylint: disable=abstract-method def convert_to_tflite( self, tflite_model_path: str | Path, quantized: bool = False ) -> TFLiteModel: - """Convert model to TFLite format.""" + """Convert model to TensorFlow Lite format.""" converted_model = convert_tf_to_tflite(self.model_path, quantized) save_tflite_model(converted_model, tflite_model_path) @@ -114,12 +114,12 @@ def get_model(model: str | Path) -> ModelConfiguration: raise Exception( "The input model format is not supported" - "(supported formats: TFLite, Keras, TensorFlow saved model)!" + "(supported formats: TensorFlow Lite, Keras, TensorFlow saved model)!" ) def get_tflite_model(model: str | Path, ctx: Context) -> TFLiteModel: - """Convert input model to TFLite and returns TFLiteModel object.""" + """Convert input model to TensorFlow Lite and returns TFLiteModel object.""" tflite_model_path = ctx.get_model_path("converted_model.tflite") converted_model = get_model(model) diff --git a/src/mlia/nn/tensorflow/tflite_metrics.py b/src/mlia/nn/tensorflow/tflite_metrics.py index 0af7500..d7ae2a4 100644 --- a/src/mlia/nn/tensorflow/tflite_metrics.py +++ b/src/mlia/nn/tensorflow/tflite_metrics.py @@ -1,7 +1,7 @@ # SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """ -Contains class TFLiteMetrics to calculate metrics from a TFLite file. +Contains class TFLiteMetrics to calculate metrics from a TensorFlow Lite file. These metrics include: * Sparsity (per layer and overall) @@ -102,7 +102,7 @@ class ReportClusterMode(Enum): class TFLiteMetrics: - """Helper class to calculate metrics from a TFLite file. + """Helper class to calculate metrics from a TensorFlow Lite file. Metrics include: * sparsity (per-layer and overall) @@ -111,12 +111,12 @@ class TFLiteMetrics: """ def __init__(self, tflite_file: str, ignore_list: list[str] | None = None) -> None: - """Load the TFLite file and filter layers.""" + """Load the TensorFlow Lite file and filter layers.""" self.tflite_file = tflite_file if ignore_list is None: ignore_list = DEFAULT_IGNORE_LIST self.ignore_list = [ignore.casefold() for ignore in ignore_list] - # Initialize the TFLite interpreter with the model file + # Initialize the TensorFlow Lite interpreter with the model file self.interpreter = tf.lite.Interpreter( model_path=tflite_file, experimental_preserve_all_tensors=True ) @@ -218,7 +218,7 @@ class TFLiteMetrics: """Print a summary of all the model information.""" print(f"Model file: {self.tflite_file}") print("#" * 80) - print(" " * 28 + "### TFLITE SUMMARY ###") + print(" " * 28 + "### TENSORFLOW LITE SUMMARY ###") print(f"File: {os.path.abspath(self.tflite_file)}") print("Input(s):") self._print_in_outs(self.interpreter.get_input_details(), verbose) diff --git a/src/mlia/nn/tensorflow/utils.py b/src/mlia/nn/tensorflow/utils.py index 6250f56..7970329 100644 --- a/src/mlia/nn/tensorflow/utils.py +++ b/src/mlia/nn/tensorflow/utils.py @@ -63,7 +63,7 @@ def representative_tf_dataset(model: str) -> Callable: def convert_to_tflite(model: tf.keras.Model, quantized: bool = False) -> Interpreter: - """Convert Keras model to TFLite.""" + """Convert Keras model to TensorFlow Lite.""" if not isinstance(model, tf.keras.Model): raise Exception("Invalid model type") @@ -83,7 +83,7 @@ def convert_to_tflite(model: tf.keras.Model, quantized: bool = False) -> Interpr def convert_tf_to_tflite(model: str, quantized: bool = False) -> Interpreter: - """Convert TensorFlow model to TFLite.""" + """Convert TensorFlow model to TensorFlow Lite.""" if not isinstance(model, str): raise Exception("Invalid model type") @@ -109,15 +109,15 @@ def save_keras_model(model: tf.keras.Model, save_path: str | Path) -> None: def save_tflite_model(model: tf.lite.TFLiteConverter, save_path: str | Path) -> None: - """Save TFLite model at provided path.""" + """Save TensorFlow Lite model at provided path.""" with open(save_path, "wb") as file: file.write(model) def is_tflite_model(model: str | Path) -> bool: - """Check if model type is supported by TFLite API. + """Check if model type is supported by TensorFlow Lite API. - TFLite model is indicated by the model file extension .tflite + TensorFlow Lite model is indicated by the model file extension .tflite """ model_path = Path(model) return model_path.suffix == ".tflite" -- cgit v1.2.1