author		Gergely Nagy <gergely.nagy@arm.com>	2023-11-21 12:29:38 +0000
committer	Gergely Nagy <gergely.nagy@arm.com>	2023-12-07 17:09:31 +0000
commit		54eec806272b7574a0757c77a913a369a9ecdc70 (patch)
tree		2e6484b857b2a68279a2707dbb21e5c26685f4e2 /src
parent		7c50f1d6367186c03a282ac7ecb8fca0f905ba30 (diff)
download	mlia-54eec806272b7574a0757c77a913a369a9ecdc70.tar.gz
MLIA-835 Invalid JSON output
TFLiteConverter was producing log messages in its output that could not be captured and redirected to logging. The solution/workaround is to run the converter as a subprocess. This change required some refactoring around the existing invocations of the converter.

Change-Id: I394bd0d49d36e6686cfcb9d658e4aad05326cb87
Signed-off-by: Gergely Nagy <gergely.nagy@arm.com>
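For reference, a minimal sketch (hypothetical model paths) of how the reworked API introduced below is meant to be called, with 'subprocess=True' enabling the new workaround mode:

    from pathlib import Path

    from mlia.nn.tensorflow.tflite_convert import convert_to_tflite

    # Hypothetical paths. With subprocess=True the converter runs in a child
    # Python process, so its raw output can be read line by line and forwarded
    # to MLIA's logging instead of leaking into the tool's (JSON) output.
    convert_to_tflite(
        "model.h5",                        # Keras model file or tf.keras.Model
        quantized=True,                    # int8 post-training quantization
        input_path=Path("model.h5"),       # optional when model is a file path
        output_path=Path("model.tflite"),  # save the converted model here
        subprocess=True,
    )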
Diffstat (limited to 'src')
-rw-r--r--	src/mlia/nn/rewrite/core/train.py	8
-rw-r--r--	src/mlia/nn/tensorflow/config.py	20
-rw-r--r--	src/mlia/nn/tensorflow/tflite_compat.py	7
-rw-r--r--	src/mlia/nn/tensorflow/tflite_convert.py	167
-rw-r--r--	src/mlia/nn/tensorflow/utils.py	59
5 files changed, 186 insertions, 75 deletions
diff --git a/src/mlia/nn/rewrite/core/train.py b/src/mlia/nn/rewrite/core/train.py
index 6345f07..72b8f48 100644
--- a/src/mlia/nn/rewrite/core/train.py
+++ b/src/mlia/nn/rewrite/core/train.py
@@ -34,9 +34,9 @@ from mlia.nn.rewrite.core.utils.numpy_tfrecord import numpytf_count
from mlia.nn.rewrite.core.utils.numpy_tfrecord import numpytf_read
from mlia.nn.rewrite.core.utils.parallel import ParallelTFLiteModel
from mlia.nn.tensorflow.config import TFLiteModel
+from mlia.nn.tensorflow.tflite_convert import convert_to_tflite
from mlia.nn.tensorflow.tflite_graph import load_fb
from mlia.nn.tensorflow.tflite_graph import save_fb
-from mlia.nn.tensorflow.utils import get_tflite_converter
from mlia.utils.logging import log_action
@@ -499,11 +499,7 @@ def save_as_tflite(
keras_model.input.set_shape(orig_shape)
with fixed_input(keras_model, input_shape) as fixed_model:
- converter = get_tflite_converter(fixed_model, quantized=model_is_quantized)
- tflite_model = converter.convert()
-
- with open(filename, "wb") as file:
- file.write(tflite_model)
+ convert_to_tflite(fixed_model, model_is_quantized, Path(filename))
# Now fix the shapes and names to match those we expect
flatbuffer = load_fb(filename)
diff --git a/src/mlia/nn/tensorflow/config.py b/src/mlia/nn/tensorflow/config.py
index b94350a..0a17977 100644
--- a/src/mlia/nn/tensorflow/config.py
+++ b/src/mlia/nn/tensorflow/config.py
@@ -21,14 +21,13 @@ from mlia.nn.tensorflow.optimizations.quantization import dequantize
from mlia.nn.tensorflow.optimizations.quantization import is_quantized
from mlia.nn.tensorflow.optimizations.quantization import QuantizationParameters
from mlia.nn.tensorflow.optimizations.quantization import quantize
+from mlia.nn.tensorflow.tflite_convert import convert_to_tflite
from mlia.nn.tensorflow.tflite_graph import load_fb
from mlia.nn.tensorflow.tflite_graph import save_fb
from mlia.nn.tensorflow.utils import check_tflite_datatypes
-from mlia.nn.tensorflow.utils import convert_to_tflite
from mlia.nn.tensorflow.utils import is_keras_model
from mlia.nn.tensorflow.utils import is_saved_model
from mlia.nn.tensorflow.utils import is_tflite_model
-from mlia.nn.tensorflow.utils import save_tflite_model
from mlia.utils.logging import log_action
logger = logging.getLogger(__name__)
@@ -67,9 +66,14 @@ class KerasModel(ModelConfiguration):
) -> TFLiteModel:
"""Convert model to TensorFlow Lite format."""
with log_action("Converting Keras to TensorFlow Lite ..."):
- converted_model = convert_to_tflite(self.get_keras_model(), quantized)
+ convert_to_tflite(
+ self.get_keras_model(),
+ quantized,
+ input_path=Path(self.model_path),
+ output_path=Path(tflite_model_path),
+ subprocess=True,
+ )
 
-            save_tflite_model(converted_model, tflite_model_path)
logger.debug(
"Model %s converted and saved to %s", self.model_path, tflite_model_path
)
@@ -270,8 +274,12 @@ class TfModel(ModelConfiguration): # pylint: disable=abstract-method
self, tflite_model_path: str | Path, quantized: bool = False
) -> TFLiteModel:
"""Convert model to TensorFlow Lite format."""
- converted_model = convert_to_tflite(self.model_path, quantized)
- save_tflite_model(converted_model, tflite_model_path)
+ convert_to_tflite(
+ self.model_path,
+ quantized,
+ input_path=Path(self.model_path),
+ output_path=Path(tflite_model_path),
+ )
return TFLiteModel(tflite_model_path)
diff --git a/src/mlia/nn/tensorflow/tflite_compat.py b/src/mlia/nn/tensorflow/tflite_compat.py
index 2b29879..497d5b1 100644
--- a/src/mlia/nn/tensorflow/tflite_compat.py
+++ b/src/mlia/nn/tensorflow/tflite_compat.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Functions for checking TensorFlow Lite compatibility."""
from __future__ import annotations
@@ -14,7 +14,7 @@ from typing import List
import tensorflow as tf
from tensorflow.lite.python import convert
-from mlia.nn.tensorflow.utils import get_tflite_converter
+from mlia.nn.tensorflow.tflite_convert import convert_to_tflite
from mlia.utils.logging import redirect_raw_output
TF_VERSION_MAJOR, TF_VERSION_MINOR, _ = (int(s) for s in tf.version.VERSION.split("."))
@@ -115,7 +115,6 @@ class TFLiteChecker:
"""Check TensorFlow Lite compatibility for the provided model."""
try:
logger.debug("Check TensorFlow Lite compatibility for %s", model)
- converter = get_tflite_converter(model, quantized=self.quantized)
# there is an issue with intercepting TensorFlow output
# not all output could be captured, for now just intercept
@@ -123,7 +122,7 @@ class TFLiteChecker:
with redirect_raw_output(
logging.getLogger("tensorflow"), stdout_level=None
):
- converter.convert()
+ convert_to_tflite(model, self.quantized)
except convert.ConverterError as err:
return self._process_convert_error(err)
except Exception as err: # pylint: disable=broad-except
diff --git a/src/mlia/nn/tensorflow/tflite_convert.py b/src/mlia/nn/tensorflow/tflite_convert.py
new file mode 100644
index 0000000..d3a833a
--- /dev/null
+++ b/src/mlia/nn/tensorflow/tflite_convert.py
@@ -0,0 +1,167 @@
+# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Support module to call TFLiteConverter."""
+from __future__ import annotations
+
+import argparse
+import logging
+import sys
+from pathlib import Path
+from typing import Any
+from typing import Callable
+from typing import cast
+from typing import Iterable
+
+import numpy as np
+import tensorflow as tf
+
+from mlia.nn.tensorflow.utils import get_tf_tensor_shape
+from mlia.nn.tensorflow.utils import is_keras_model
+from mlia.nn.tensorflow.utils import is_saved_model
+from mlia.nn.tensorflow.utils import save_tflite_model
+from mlia.utils.logging import redirect_output
+from mlia.utils.proc import Command
+from mlia.utils.proc import command_output
+
+logger = logging.getLogger(__name__)
+
+
+def representative_dataset(
+ input_shape: Any, sample_count: int = 100, input_dtype: type = np.float32
+) -> Callable:
+ """Sample dataset used for quantization."""
+
+ def dataset() -> Iterable:
+ for _ in range(sample_count):
+ data = np.random.rand(1, *input_shape[1:])
+ yield [data.astype(input_dtype)]
+
+ return dataset
+
+
+def get_tflite_converter(
+ model: tf.keras.Model | str | Path, quantized: bool = False
+) -> tf.lite.TFLiteConverter:
+ """Configure TensorFlow Lite converter for the provided model."""
+ if isinstance(model, (str, Path)):
+ # converter's methods accept string as input parameter
+ model = str(model)
+
+ if isinstance(model, tf.keras.Model):
+ converter = tf.lite.TFLiteConverter.from_keras_model(model)
+ input_shape = model.input_shape
+ elif isinstance(model, str) and is_saved_model(model):
+ converter = tf.lite.TFLiteConverter.from_saved_model(model)
+ input_shape = get_tf_tensor_shape(model)
+ elif isinstance(model, str) and is_keras_model(model):
+ keras_model = tf.keras.models.load_model(model)
+ input_shape = keras_model.input_shape
+ converter = tf.lite.TFLiteConverter.from_keras_model(keras_model)
+ else:
+ raise ValueError(f"Unable to create TensorFlow Lite converter for {model}")
+
+ if quantized:
+ converter.optimizations = [tf.lite.Optimize.DEFAULT]
+ converter.representative_dataset = representative_dataset(input_shape)
+ converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS]
+ converter.inference_input_type = tf.int8
+ converter.inference_output_type = tf.int8
+
+ return converter
+
+
+def convert_to_tflite_bytes(
+ model: tf.keras.Model | str, quantized: bool = False
+) -> bytes:
+ """Convert Keras model to TensorFlow Lite."""
+ converter = get_tflite_converter(model, quantized)
+
+ with redirect_output(logging.getLogger("tensorflow")):
+ output_bytes = cast(bytes, converter.convert())
+
+ return output_bytes
+
+
+def _convert_to_tflite(
+ model: tf.keras.Model | str,
+ quantized: bool = False,
+ output_path: Path | None = None,
+) -> bytes:
+ """Convert Keras model to TensorFlow Lite."""
+ output_bytes = convert_to_tflite_bytes(model, quantized)
+
+ if output_path:
+ save_tflite_model(output_bytes, output_path)
+
+ return output_bytes
+
+
+def convert_to_tflite(
+ model: tf.keras.Model | str,
+ quantized: bool = False,
+ output_path: Path | None = None,
+ input_path: Path | None = None,
+ subprocess: bool = False,
+) -> None:
+ """Convert Keras model to TensorFlow Lite.
+
+    Optionally runs TFLiteConverter in a subprocess. This is
+    mainly a workaround: TensorFlow log output could not be
+    captured and redirected via SDK calls and would otherwise
+    end up in MLIA's output.
+
+    In subprocess mode, the model should be passed as a file
+    path, or via the dedicated 'input_path' parameter.
+
+    If 'output_path' is provided, the resulting model is saved
+    under that path.
+ """
+ if not subprocess:
+ _convert_to_tflite(model, quantized, output_path)
+ return
+
+ if input_path is None:
+ if isinstance(model, str):
+ input_path = Path(model)
+ else:
+ raise RuntimeError(
+ f"Input path is required for {model}"
+ " when converter is called in subprocess."
+ )
+
+ args = ["python", __file__, str(input_path)]
+ if output_path:
+ args.append("--output")
+ args.append(str(output_path))
+ if quantized:
+ args.append("--quantize")
+
+ command = Command(args)
+
+ for line in command_output(command):
+ logger.debug("TFLiteConverter: %s", line)
+
+
+def main(argv: list[str] | None = None) -> int:
+ """Entry point to run this module as a standalone executable."""
+ parser = argparse.ArgumentParser()
+ parser.add_argument("input", type=Path)
+ parser.add_argument("--output", type=Path, default=None)
+ parser.add_argument("--quantize", default=False, action="store_true")
+ args = parser.parse_args(argv)
+
+ if not Path(args.input).exists():
+ raise ValueError(f"Input file doesn't exist: [{args.input}]")
+
+ logger.debug(
+ "Invoking TFLiteConverter on [%s] -> [%s], quantize: [%s]",
+ args.input,
+ args.output,
+ args.quantize,
+ )
+ _convert_to_tflite(args.input, args.quantize, args.output)
+ return 0
+
+
+if __name__ == "__main__":
+ sys.exit(main())
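The new module can also be run standalone through its 'main' entry point. A hypothetical invocation (placeholder paths), equivalent to running 'python tflite_convert.py model.h5 --output model.tflite --quantize' from the command line:

    from mlia.nn.tensorflow.tflite_convert import main

    # Placeholder paths: converts the input model, applying int8 quantization,
    # and saves the result to the --output path.
    main(["model.h5", "--output", "model.tflite", "--quantize"])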
diff --git a/src/mlia/nn/tensorflow/utils.py b/src/mlia/nn/tensorflow/utils.py
index b8d45c6..1612447 100644
--- a/src/mlia/nn/tensorflow/utils.py
+++ b/src/mlia/nn/tensorflow/utils.py
@@ -4,31 +4,11 @@
"""Collection of useful functions for optimizations."""
from __future__ import annotations
-import logging
from pathlib import Path
from typing import Any
-from typing import Callable
-from typing import cast
-from typing import Iterable
-import numpy as np
import tensorflow as tf
-from mlia.utils.logging import redirect_output
-
-
-def representative_dataset(
- input_shape: Any, sample_count: int = 100, input_dtype: type = np.float32
-) -> Callable:
- """Sample dataset used for quantization."""
-
- def dataset() -> Iterable:
- for _ in range(sample_count):
- data = np.random.rand(1, *input_shape[1:])
- yield [data.astype(input_dtype)]
-
- return dataset
-
def get_tf_tensor_shape(model: str) -> list:
"""Get input shape for the TensorFlow tensor model."""
@@ -49,14 +29,6 @@ def get_tf_tensor_shape(model: str) -> list:
]
-def convert_to_tflite(model: tf.keras.Model | str, quantized: bool = False) -> bytes:
- """Convert Keras model to TensorFlow Lite."""
- converter = get_tflite_converter(model, quantized)
-
- with redirect_output(logging.getLogger("tensorflow")):
- return cast(bytes, converter.convert())
-
-
def save_keras_model(
model: tf.keras.Model, save_path: str | Path, include_optimizer: bool = True
) -> None:
@@ -94,37 +66,6 @@ def is_saved_model(model: str | Path) -> bool:
return model_path.is_dir() and not is_keras_model(model)
-def get_tflite_converter(
- model: tf.keras.Model | str | Path, quantized: bool = False
-) -> tf.lite.TFLiteConverter:
- """Configure TensorFlow Lite converter for the provided model."""
- if isinstance(model, (str, Path)):
- # converter's methods accept string as input parameter
- model = str(model)
-
- if isinstance(model, tf.keras.Model):
- converter = tf.lite.TFLiteConverter.from_keras_model(model)
- input_shape = model.input_shape
- elif isinstance(model, str) and is_saved_model(model):
- converter = tf.lite.TFLiteConverter.from_saved_model(model)
- input_shape = get_tf_tensor_shape(model)
- elif isinstance(model, str) and is_keras_model(model):
- keras_model = tf.keras.models.load_model(model)
- input_shape = keras_model.input_shape
- converter = tf.lite.TFLiteConverter.from_keras_model(keras_model)
- else:
- raise ValueError(f"Unable to create TensorFlow Lite converter for {model}")
-
- if quantized:
- converter.optimizations = [tf.lite.Optimize.DEFAULT]
- converter.representative_dataset = representative_dataset(input_shape)
- converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS]
- converter.inference_input_type = tf.int8
- converter.inference_output_type = tf.int8
-
- return converter
-
-
def get_tflite_model_type_map(model_filename: str | Path) -> dict[str, type]:
"""Get type map from tflite model."""
model_type_map: dict[str, Any] = {}