From 54eec806272b7574a0757c77a913a369a9ecdc70 Mon Sep 17 00:00:00 2001
From: Gergely Nagy
Date: Tue, 21 Nov 2023 12:29:38 +0000
Subject: MLIA-835 Invalid JSON output

TFLiteConverter was producing log messages in the output that
were not possible to capture and redirect to logging.

The solution/workaround is to run it as a subprocess. This change
required some refactoring around existing invocations of the
converter.

Change-Id: I394bd0d49d36e6686cfcb9d658e4aad05326cb87
Signed-off-by: Gergely Nagy
---
 src/mlia/nn/rewrite/core/train.py        |   8 +-
 src/mlia/nn/tensorflow/config.py         |  20 ++--
 src/mlia/nn/tensorflow/tflite_compat.py  |   7 +-
 src/mlia/nn/tensorflow/tflite_convert.py | 167 +++++++++++++++++++++++++++++++
 src/mlia/nn/tensorflow/utils.py          |  59 -----------
 5 files changed, 186 insertions(+), 75 deletions(-)
 create mode 100644 src/mlia/nn/tensorflow/tflite_convert.py

diff --git a/src/mlia/nn/rewrite/core/train.py b/src/mlia/nn/rewrite/core/train.py
index 6345f07..72b8f48 100644
--- a/src/mlia/nn/rewrite/core/train.py
+++ b/src/mlia/nn/rewrite/core/train.py
@@ -34,9 +34,9 @@ from mlia.nn.rewrite.core.utils.numpy_tfrecord import numpytf_count
 from mlia.nn.rewrite.core.utils.numpy_tfrecord import numpytf_read
 from mlia.nn.rewrite.core.utils.parallel import ParallelTFLiteModel
 from mlia.nn.tensorflow.config import TFLiteModel
+from mlia.nn.tensorflow.tflite_convert import convert_to_tflite
 from mlia.nn.tensorflow.tflite_graph import load_fb
 from mlia.nn.tensorflow.tflite_graph import save_fb
-from mlia.nn.tensorflow.utils import get_tflite_converter
 from mlia.utils.logging import log_action
 
 
@@ -499,11 +499,7 @@ def save_as_tflite(
         keras_model.input.set_shape(orig_shape)
 
     with fixed_input(keras_model, input_shape) as fixed_model:
-        converter = get_tflite_converter(fixed_model, quantized=model_is_quantized)
-        tflite_model = converter.convert()
-
-        with open(filename, "wb") as file:
-            file.write(tflite_model)
+        convert_to_tflite(fixed_model, model_is_quantized, Path(filename))
 
     # Now fix the shapes and names to match those we expect
     flatbuffer = load_fb(filename)
diff --git a/src/mlia/nn/tensorflow/config.py b/src/mlia/nn/tensorflow/config.py
index b94350a..0a17977 100644
--- a/src/mlia/nn/tensorflow/config.py
+++ b/src/mlia/nn/tensorflow/config.py
@@ -21,14 +21,13 @@ from mlia.nn.tensorflow.optimizations.quantization import dequantize
 from mlia.nn.tensorflow.optimizations.quantization import is_quantized
 from mlia.nn.tensorflow.optimizations.quantization import QuantizationParameters
 from mlia.nn.tensorflow.optimizations.quantization import quantize
+from mlia.nn.tensorflow.tflite_convert import convert_to_tflite
 from mlia.nn.tensorflow.tflite_graph import load_fb
 from mlia.nn.tensorflow.tflite_graph import save_fb
 from mlia.nn.tensorflow.utils import check_tflite_datatypes
-from mlia.nn.tensorflow.utils import convert_to_tflite
 from mlia.nn.tensorflow.utils import is_keras_model
 from mlia.nn.tensorflow.utils import is_saved_model
 from mlia.nn.tensorflow.utils import is_tflite_model
-from mlia.nn.tensorflow.utils import save_tflite_model
 from mlia.utils.logging import log_action
 
 logger = logging.getLogger(__name__)
@@ -67,9 +66,14 @@ class KerasModel(ModelConfiguration):
     ) -> TFLiteModel:
         """Convert model to TensorFlow Lite format."""
         with log_action("Converting Keras to TensorFlow Lite ..."):
-            converted_model = convert_to_tflite(self.get_keras_model(), quantized)
+            convert_to_tflite(
+                self.get_keras_model(),
+                quantized,
+                input_path=Path(self.model_path),
+                output_path=Path(tflite_model_path),
+                subprocess=True,
+            )
 
-        save_tflite_model(converted_model, tflite_model_path)
         logger.debug(
             "Model %s converted and saved to %s", self.model_path, tflite_model_path
         )
@@ -270,8 +274,12 @@ class TfModel(ModelConfiguration):  # pylint: disable=abstract-method
         self, tflite_model_path: str | Path, quantized: bool = False
     ) -> TFLiteModel:
         """Convert model to TensorFlow Lite format."""
-        converted_model = convert_to_tflite(self.model_path, quantized)
-        save_tflite_model(converted_model, tflite_model_path)
+        convert_to_tflite(
+            self.model_path,
+            quantized,
+            input_path=Path(self.model_path),
+            output_path=Path(tflite_model_path),
+        )
 
         return TFLiteModel(tflite_model_path)
 
diff --git a/src/mlia/nn/tensorflow/tflite_compat.py b/src/mlia/nn/tensorflow/tflite_compat.py
index 2b29879..497d5b1 100644
--- a/src/mlia/nn/tensorflow/tflite_compat.py
+++ b/src/mlia/nn/tensorflow/tflite_compat.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates.
 # SPDX-License-Identifier: Apache-2.0
 """Functions for checking TensorFlow Lite compatibility."""
 from __future__ import annotations
@@ -14,7 +14,7 @@ from typing import List
 import tensorflow as tf
 from tensorflow.lite.python import convert
 
-from mlia.nn.tensorflow.utils import get_tflite_converter
+from mlia.nn.tensorflow.tflite_convert import convert_to_tflite
 from mlia.utils.logging import redirect_raw_output
 
 TF_VERSION_MAJOR, TF_VERSION_MINOR, _ = (int(s) for s in tf.version.VERSION.split("."))
@@ -115,7 +115,6 @@ class TFLiteChecker:
         """Check TensorFlow Lite compatibility for the provided model."""
         try:
             logger.debug("Check TensorFlow Lite compatibility for %s", model)
-            converter = get_tflite_converter(model, quantized=self.quantized)
 
             # there is an issue with intercepting TensorFlow output
             # not all output could be captured, for now just intercept
@@ -123,7 +122,7 @@
             with redirect_raw_output(
                 logging.getLogger("tensorflow"), stdout_level=None
             ):
-                converter.convert()
+                convert_to_tflite(model, self.quantized)
         except convert.ConverterError as err:
             return self._process_convert_error(err)
         except Exception as err:  # pylint: disable=broad-except
diff --git a/src/mlia/nn/tensorflow/tflite_convert.py b/src/mlia/nn/tensorflow/tflite_convert.py
new file mode 100644
index 0000000..d3a833a
--- /dev/null
+++ b/src/mlia/nn/tensorflow/tflite_convert.py
@@ -0,0 +1,167 @@
+# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Support module to call TFLiteConverter."""
+from __future__ import annotations
+
+import argparse
+import logging
+import sys
+from pathlib import Path
+from typing import Any
+from typing import Callable
+from typing import cast
+from typing import Iterable
+
+import numpy as np
+import tensorflow as tf
+
+from mlia.nn.tensorflow.utils import get_tf_tensor_shape
+from mlia.nn.tensorflow.utils import is_keras_model
+from mlia.nn.tensorflow.utils import is_saved_model
+from mlia.nn.tensorflow.utils import save_tflite_model
+from mlia.utils.logging import redirect_output
+from mlia.utils.proc import Command
+from mlia.utils.proc import command_output
+
+logger = logging.getLogger(__name__)
+
+
+def representative_dataset(
+    input_shape: Any, sample_count: int = 100, input_dtype: type = np.float32
+) -> Callable:
+    """Sample dataset used for quantization."""
+
+    def dataset() -> Iterable:
+        for _ in range(sample_count):
+            data = np.random.rand(1, *input_shape[1:])
+            yield [data.astype(input_dtype)]
+
+    return dataset
+
+
+def get_tflite_converter(
+    model: tf.keras.Model | str | Path, quantized: bool = False
+) -> tf.lite.TFLiteConverter:
+    """Configure TensorFlow Lite converter for the provided model."""
+    if isinstance(model, (str, Path)):
+        # converter's methods accept string as input parameter
+        model = str(model)
+
+    if isinstance(model, tf.keras.Model):
+        converter = tf.lite.TFLiteConverter.from_keras_model(model)
+        input_shape = model.input_shape
+    elif isinstance(model, str) and is_saved_model(model):
+        converter = tf.lite.TFLiteConverter.from_saved_model(model)
+        input_shape = get_tf_tensor_shape(model)
+    elif isinstance(model, str) and is_keras_model(model):
+        keras_model = tf.keras.models.load_model(model)
+        input_shape = keras_model.input_shape
+        converter = tf.lite.TFLiteConverter.from_keras_model(keras_model)
+    else:
+        raise ValueError(f"Unable to create TensorFlow Lite converter for {model}")
+
+    if quantized:
+        converter.optimizations = [tf.lite.Optimize.DEFAULT]
+        converter.representative_dataset = representative_dataset(input_shape)
+        converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS]
+        converter.inference_input_type = tf.int8
+        converter.inference_output_type = tf.int8
+
+    return converter
+
+
+def convert_to_tflite_bytes(
+    model: tf.keras.Model | str, quantized: bool = False
+) -> bytes:
+    """Convert Keras model to TensorFlow Lite."""
+    converter = get_tflite_converter(model, quantized)
+
+    with redirect_output(logging.getLogger("tensorflow")):
+        output_bytes = cast(bytes, converter.convert())
+
+    return output_bytes
+
+
+def _convert_to_tflite(
+    model: tf.keras.Model | str,
+    quantized: bool = False,
+    output_path: Path | None = None,
+) -> bytes:
+    """Convert Keras model to TensorFlow Lite."""
+    output_bytes = convert_to_tflite_bytes(model, quantized)
+
+    if output_path:
+        save_tflite_model(output_bytes, output_path)
+
+    return output_bytes
+
+
+def convert_to_tflite(
+    model: tf.keras.Model | str,
+    quantized: bool = False,
+    output_path: Path | None = None,
+    input_path: Path | None = None,
+    subprocess: bool = False,
+) -> None:
+    """Convert Keras model to TensorFlow Lite.
+
+    Optionally runs TFLiteConverter in a subprocess. This is done
+    mainly to work around issues with redirecting TensorFlow's
+    output: redirecting via SDK calls had no effect, and the raw
+    converter messages ended up polluting MLIA's output.
+
+    In subprocess mode, the model should be passed as a file
+    path, or via the dedicated 'input_path' parameter.
+
+    If 'output_path' is provided, the resulting model is saved
+    under that path.
+    """
+    if not subprocess:
+        _convert_to_tflite(model, quantized, output_path)
+        return
+
+    if input_path is None:
+        if isinstance(model, str):
+            input_path = Path(model)
+        else:
+            raise RuntimeError(
+                f"Input path is required for {model}"
+                " when converter is called in subprocess."
+            )
+
+    args = ["python", __file__, str(input_path)]
+    if output_path:
+        args.append("--output")
+        args.append(str(output_path))
+    if quantized:
+        args.append("--quantize")
+
+    command = Command(args)
+
+    for line in command_output(command):
+        logger.debug("TFLiteConverter: %s", line)
+
+
+def main(argv: list[str] | None = None) -> int:
+    """Entry point to run this module as a standalone executable."""
+    parser = argparse.ArgumentParser()
+    parser.add_argument("input", type=Path)
+    parser.add_argument("--output", type=Path, default=None)
+    parser.add_argument("--quantize", default=False, action="store_true")
+    args = parser.parse_args(argv)
+
+    if not Path(args.input).exists():
+        raise ValueError(f"Input file doesn't exist: [{args.input}]")
+
+    logger.debug(
+        "Invoking TFLiteConverter on [%s] -> [%s], quantize: [%s]",
+        args.input,
+        args.output,
+        args.quantize,
+    )
+    _convert_to_tflite(args.input, args.quantize, args.output)
+    return 0
+
+
+if __name__ == "__main__":
+    sys.exit(main())
diff --git a/src/mlia/nn/tensorflow/utils.py b/src/mlia/nn/tensorflow/utils.py
index b8d45c6..1612447 100644
--- a/src/mlia/nn/tensorflow/utils.py
+++ b/src/mlia/nn/tensorflow/utils.py
@@ -4,31 +4,11 @@
 """Collection of useful functions for optimizations."""
 from __future__ import annotations
 
-import logging
 from pathlib import Path
 from typing import Any
-from typing import Callable
-from typing import cast
-from typing import Iterable
 
-import numpy as np
 import tensorflow as tf
 
-from mlia.utils.logging import redirect_output
-
-
-def representative_dataset(
-    input_shape: Any, sample_count: int = 100, input_dtype: type = np.float32
-) -> Callable:
-    """Sample dataset used for quantization."""
-
-    def dataset() -> Iterable:
-        for _ in range(sample_count):
-            data = np.random.rand(1, *input_shape[1:])
-            yield [data.astype(input_dtype)]
-
-    return dataset
-
 
 def get_tf_tensor_shape(model: str) -> list:
     """Get input shape for the TensorFlow tensor model."""
@@ -49,14 +29,6 @@
     ]
 
 
-def convert_to_tflite(model: tf.keras.Model | str, quantized: bool = False) -> bytes:
-    """Convert Keras model to TensorFlow Lite."""
-    converter = get_tflite_converter(model, quantized)
-
-    with redirect_output(logging.getLogger("tensorflow")):
-        return cast(bytes, converter.convert())
-
-
 def save_keras_model(
     model: tf.keras.Model, save_path: str | Path, include_optimizer: bool = True
 ) -> None:
@@ -94,37 +66,6 @@ def is_saved_model(model: str | Path) -> bool:
     return model_path.is_dir() and not is_keras_model(model)
 
 
-def get_tflite_converter(
-    model: tf.keras.Model | str | Path, quantized: bool = False
-) -> tf.lite.TFLiteConverter:
-    """Configure TensorFlow Lite converter for the provided model."""
-    if isinstance(model, (str, Path)):
-        # converter's methods accept string as input parameter
-        model = str(model)
-
-    if isinstance(model, tf.keras.Model):
-        converter = tf.lite.TFLiteConverter.from_keras_model(model)
-        input_shape = model.input_shape
-    elif isinstance(model, str) and is_saved_model(model):
-        converter = tf.lite.TFLiteConverter.from_saved_model(model)
-        input_shape = get_tf_tensor_shape(model)
-    elif isinstance(model, str) and is_keras_model(model):
-        keras_model = tf.keras.models.load_model(model)
-        input_shape = keras_model.input_shape
-        converter = tf.lite.TFLiteConverter.from_keras_model(keras_model)
-    else:
-        raise ValueError(f"Unable to create TensorFlow Lite converter for {model}")
-
-    if quantized:
-        converter.optimizations = [tf.lite.Optimize.DEFAULT]
-        converter.representative_dataset = representative_dataset(input_shape)
-        converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS]
-        converter.inference_input_type = tf.int8
-        converter.inference_output_type = tf.int8
-
-    return converter
-
-
 def get_tflite_model_type_map(model_filename: str | Path) -> dict[str, type]:
     """Get type map from tflite model."""
    model_type_map: dict[str, Any] = {}
-- 
cgit v1.2.1
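
For context, a minimal usage sketch of the convert_to_tflite() API introduced in
tflite_convert.py above. The model and output file names below are hypothetical
placeholders; the function, its parameters and the CLI flags are the ones added by
this patch.

# Minimal usage sketch (hypothetical paths; API as introduced by this patch).
from pathlib import Path

from mlia.nn.tensorflow.tflite_convert import convert_to_tflite

# In-process conversion: TensorFlow log output is redirected via redirect_output.
convert_to_tflite("model.h5", quantized=False, output_path=Path("model.tflite"))

# Subprocess conversion: TFLiteConverter runs in a child process, so any raw
# stdout/stderr it produces cannot leak into MLIA's own (e.g. JSON) output;
# the converter's messages are re-emitted through logger.debug instead.
convert_to_tflite(
    "model.h5",
    quantized=True,
    input_path=Path("model.h5"),
    output_path=Path("model_int8.tflite"),
    subprocess=True,
)

# Subprocess mode internally builds and runs a command equivalent to:
#     python .../tflite_convert.py model.h5 --output model_int8.tflite --quantize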