From 54eec806272b7574a0757c77a913a369a9ecdc70 Mon Sep 17 00:00:00 2001
From: Gergely Nagy
Date: Tue, 21 Nov 2023 12:29:38 +0000
Subject: MLIA-835 Invalid JSON output

TFLiteConverter was producing log messages in the output that could
not be captured and redirected to logging. The workaround is to run
it as a subprocess. This change required some refactoring around
existing invocations of the converter.

Change-Id: I394bd0d49d36e6686cfcb9d658e4aad05326cb87
Signed-off-by: Gergely Nagy
---
 src/mlia/nn/rewrite/core/train.py                  |   8 +-
 src/mlia/nn/tensorflow/config.py                   |  20 +-
 src/mlia/nn/tensorflow/tflite_compat.py            |   7 +-
 src/mlia/nn/tensorflow/tflite_convert.py           | 167 ++++++++++++++
 src/mlia/nn/tensorflow/utils.py                    |  59 -----
 tests/conftest.py                                  |  11 +-
 .../test_nn_tensorflow_optimizations_clustering.py |   6 +-
 tests/test_nn_tensorflow_optimizations_pruning.py  |   7 +-
 tests/test_nn_tensorflow_tflite_compat.py          |   4 +-
 tests/test_nn_tensorflow_tflite_convert.py         | 244 +++++++++++++++++++++
 tests/test_nn_tensorflow_utils.py                  |  44 +---
 tests/test_target_cortex_a_operators.py            |   4 +-
 12 files changed, 445 insertions(+), 136 deletions(-)
 create mode 100644 src/mlia/nn/tensorflow/tflite_convert.py
 create mode 100644 tests/test_nn_tensorflow_tflite_convert.py
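
Note: a minimal usage sketch of the new entry point added by this
patch (the model paths below are hypothetical):

    from pathlib import Path

    from mlia.nn.tensorflow.tflite_convert import convert_to_tflite

    # Run TFLiteConverter in a separate process so that everything it
    # prints to stdout/stderr is captured by MLIA's logging instead of
    # leaking into the (JSON) output.
    convert_to_tflite(
        "sample_model.h5",                        # hypothetical input
        quantized=True,
        output_path=Path("sample_model.tflite"),  # hypothetical output
        subprocess=True,
    )

Subprocess mode runs the new module as a standalone script, so the
same conversion can be reproduced manually with:

    python src/mlia/nn/tensorflow/tflite_convert.py sample_model.h5 \
        --output sample_model.tflite --quantize
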
diff --git a/src/mlia/nn/rewrite/core/train.py b/src/mlia/nn/rewrite/core/train.py
index 6345f07..72b8f48 100644
--- a/src/mlia/nn/rewrite/core/train.py
+++ b/src/mlia/nn/rewrite/core/train.py
@@ -34,9 +34,9 @@ from mlia.nn.rewrite.core.utils.numpy_tfrecord import numpytf_count
 from mlia.nn.rewrite.core.utils.numpy_tfrecord import numpytf_read
 from mlia.nn.rewrite.core.utils.parallel import ParallelTFLiteModel
 from mlia.nn.tensorflow.config import TFLiteModel
+from mlia.nn.tensorflow.tflite_convert import convert_to_tflite
 from mlia.nn.tensorflow.tflite_graph import load_fb
 from mlia.nn.tensorflow.tflite_graph import save_fb
-from mlia.nn.tensorflow.utils import get_tflite_converter
 from mlia.utils.logging import log_action
 
 
@@ -499,11 +499,7 @@ def save_as_tflite(
         keras_model.input.set_shape(orig_shape)
 
     with fixed_input(keras_model, input_shape) as fixed_model:
-        converter = get_tflite_converter(fixed_model, quantized=model_is_quantized)
-        tflite_model = converter.convert()
-
-        with open(filename, "wb") as file:
-            file.write(tflite_model)
+        convert_to_tflite(fixed_model, model_is_quantized, Path(filename))
 
     # Now fix the shapes and names to match those we expect
     flatbuffer = load_fb(filename)
diff --git a/src/mlia/nn/tensorflow/config.py b/src/mlia/nn/tensorflow/config.py
index b94350a..0a17977 100644
--- a/src/mlia/nn/tensorflow/config.py
+++ b/src/mlia/nn/tensorflow/config.py
@@ -21,14 +21,13 @@ from mlia.nn.tensorflow.optimizations.quantization import dequantize
 from mlia.nn.tensorflow.optimizations.quantization import is_quantized
 from mlia.nn.tensorflow.optimizations.quantization import QuantizationParameters
 from mlia.nn.tensorflow.optimizations.quantization import quantize
+from mlia.nn.tensorflow.tflite_convert import convert_to_tflite
 from mlia.nn.tensorflow.tflite_graph import load_fb
 from mlia.nn.tensorflow.tflite_graph import save_fb
 from mlia.nn.tensorflow.utils import check_tflite_datatypes
-from mlia.nn.tensorflow.utils import convert_to_tflite
 from mlia.nn.tensorflow.utils import is_keras_model
 from mlia.nn.tensorflow.utils import is_saved_model
 from mlia.nn.tensorflow.utils import is_tflite_model
-from mlia.nn.tensorflow.utils import save_tflite_model
 from mlia.utils.logging import log_action
 
 logger = logging.getLogger(__name__)
@@ -67,9 +66,14 @@ class KerasModel(ModelConfiguration):
     ) -> TFLiteModel:
         """Convert model to TensorFlow Lite format."""
         with log_action("Converting Keras to TensorFlow Lite ..."):
-            converted_model = convert_to_tflite(self.get_keras_model(), quantized)
+            convert_to_tflite(
+                self.get_keras_model(),
+                quantized,
+                input_path=Path(self.model_path),
+                output_path=Path(tflite_model_path),
+                subprocess=True,
+            )
 
-        save_tflite_model(converted_model, tflite_model_path)
         logger.debug(
             "Model %s converted and saved to %s", self.model_path, tflite_model_path
         )
@@ -270,8 +274,12 @@ class TfModel(ModelConfiguration):  # pylint: disable=abstract-method
         self, tflite_model_path: str | Path, quantized: bool = False
     ) -> TFLiteModel:
         """Convert model to TensorFlow Lite format."""
-        converted_model = convert_to_tflite(self.model_path, quantized)
-        save_tflite_model(converted_model, tflite_model_path)
+        convert_to_tflite(
+            self.model_path,
+            quantized,
+            input_path=Path(self.model_path),
+            output_path=Path(tflite_model_path),
+        )
 
         return TFLiteModel(tflite_model_path)
 
diff --git a/src/mlia/nn/tensorflow/tflite_compat.py b/src/mlia/nn/tensorflow/tflite_compat.py
index 2b29879..497d5b1 100644
--- a/src/mlia/nn/tensorflow/tflite_compat.py
+++ b/src/mlia/nn/tensorflow/tflite_compat.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates.
 # SPDX-License-Identifier: Apache-2.0
 """Functions for checking TensorFlow Lite compatibility."""
 from __future__ import annotations
@@ -14,7 +14,7 @@ from typing import List
 import tensorflow as tf
 from tensorflow.lite.python import convert
 
-from mlia.nn.tensorflow.utils import get_tflite_converter
+from mlia.nn.tensorflow.tflite_convert import convert_to_tflite
 from mlia.utils.logging import redirect_raw_output
 
 TF_VERSION_MAJOR, TF_VERSION_MINOR, _ = (int(s) for s in tf.version.VERSION.split("."))
@@ -115,7 +115,6 @@ class TFLiteChecker:
         """Check TensorFlow Lite compatibility for the provided model."""
         try:
             logger.debug("Check TensorFlow Lite compatibility for %s", model)
-            converter = get_tflite_converter(model, quantized=self.quantized)
 
             # there is an issue with intercepting TensorFlow output
             # not all output could be captured, for now just intercept
@@ -123,7 +122,7 @@ class TFLiteChecker:
             with redirect_raw_output(
                 logging.getLogger("tensorflow"), stdout_level=None
             ):
-                converter.convert()
+                convert_to_tflite(model, self.quantized)
         except convert.ConverterError as err:
             return self._process_convert_error(err)
         except Exception as err:  # pylint: disable=broad-except
diff --git a/src/mlia/nn/tensorflow/tflite_convert.py b/src/mlia/nn/tensorflow/tflite_convert.py
new file mode 100644
index 0000000..d3a833a
--- /dev/null
+++ b/src/mlia/nn/tensorflow/tflite_convert.py
@@ -0,0 +1,167 @@
+# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Support module to call TFLiteConverter."""
+from __future__ import annotations
+
+import argparse
+import logging
+import sys
+from pathlib import Path
+from typing import Any
+from typing import Callable
+from typing import cast
+from typing import Iterable
+
+import numpy as np
+import tensorflow as tf
+
+from mlia.nn.tensorflow.utils import get_tf_tensor_shape
+from mlia.nn.tensorflow.utils import is_keras_model
+from mlia.nn.tensorflow.utils import is_saved_model
+from mlia.nn.tensorflow.utils import save_tflite_model
+from mlia.utils.logging import redirect_output
+from mlia.utils.proc import Command
+from mlia.utils.proc import command_output
+
+logger = logging.getLogger(__name__)
+
+
+def representative_dataset(
+    input_shape: Any, sample_count: int = 100, input_dtype: type = np.float32
+) -> Callable:
+    """Sample dataset used for quantization."""
+
+    def dataset() -> Iterable:
+        for _ in range(sample_count):
+            data = np.random.rand(1, *input_shape[1:])
+            yield [data.astype(input_dtype)]
+
+    return dataset
+
+
+def get_tflite_converter(
+    model: tf.keras.Model | str | Path, quantized: bool = False
+) -> tf.lite.TFLiteConverter:
+    """Configure TensorFlow Lite converter for the provided model."""
+    if isinstance(model, (str, Path)):
+        # converter's methods accept string as input parameter
+        model = str(model)
+
+    if isinstance(model, tf.keras.Model):
+        converter = tf.lite.TFLiteConverter.from_keras_model(model)
+        input_shape = model.input_shape
+    elif isinstance(model, str) and is_saved_model(model):
+        converter = tf.lite.TFLiteConverter.from_saved_model(model)
+        input_shape = get_tf_tensor_shape(model)
+    elif isinstance(model, str) and is_keras_model(model):
+        keras_model = tf.keras.models.load_model(model)
+        input_shape = keras_model.input_shape
+        converter = tf.lite.TFLiteConverter.from_keras_model(keras_model)
+    else:
+        raise ValueError(f"Unable to create TensorFlow Lite converter for {model}")
+
+    if quantized:
+        converter.optimizations = [tf.lite.Optimize.DEFAULT]
+        converter.representative_dataset = representative_dataset(input_shape)
+        converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS]
+        converter.inference_input_type = tf.int8
+        converter.inference_output_type = tf.int8
+
+    return converter
+
+
+def convert_to_tflite_bytes(
+    model: tf.keras.Model | str, quantized: bool = False
+) -> bytes:
+    """Convert a Keras model or SavedModel to TensorFlow Lite."""
+    converter = get_tflite_converter(model, quantized)
+
+    with redirect_output(logging.getLogger("tensorflow")):
+        output_bytes = cast(bytes, converter.convert())
+
+    return output_bytes
+
+
+def _convert_to_tflite(
+    model: tf.keras.Model | str,
+    quantized: bool = False,
+    output_path: Path | None = None,
+) -> bytes:
+    """Convert a Keras model or SavedModel to TensorFlow Lite."""
+    output_bytes = convert_to_tflite_bytes(model, quantized)
+
+    if output_path:
+        save_tflite_model(output_bytes, output_path)
+
+    return output_bytes
+
+
+def convert_to_tflite(
+    model: tf.keras.Model | str,
+    quantized: bool = False,
+    output_path: Path | None = None,
+    input_path: Path | None = None,
+    subprocess: bool = False,
+) -> None:
+    """Convert a Keras model or SavedModel to TensorFlow Lite.
+
+    Optionally runs TFLiteConverter in a subprocess. This is
+    mainly a workaround for TensorFlow output that cannot be
+    captured and redirected via SDK calls and would otherwise
+    leak into MLIA's output.
+
+    In subprocess mode, the model should be passed as a file
+    path, or via the dedicated 'input_path' parameter.
+
+    If 'output_path' is provided, the resulting model will be
+    saved under that path.
+    """
+    if not subprocess:
+        _convert_to_tflite(model, quantized, output_path)
+        return
+
+    if input_path is None:
+        if isinstance(model, str):
+            input_path = Path(model)
+        else:
+            raise RuntimeError(
+                f"Input path is required for {model}"
+                " when converter is called in subprocess."
+            )
+
+    args = ["python", __file__, str(input_path)]
+    if output_path:
+        args.append("--output")
+        args.append(str(output_path))
+    if quantized:
+        args.append("--quantize")
+
+    command = Command(args)
+
+    for line in command_output(command):
+        logger.debug("TFLiteConverter: %s", line)
+
+
+def main(argv: list[str] | None = None) -> int:
+    """Entry point to run this module as a standalone executable."""
+    parser = argparse.ArgumentParser()
+    parser.add_argument("input", type=Path)
+    parser.add_argument("--output", type=Path, default=None)
+    parser.add_argument("--quantize", default=False, action="store_true")
+    args = parser.parse_args(argv)
+
+    if not Path(args.input).exists():
+        raise ValueError(f"Input file doesn't exist: [{args.input}]")
+
+    logger.debug(
+        "Invoking TFLiteConverter on [%s] -> [%s], quantize: [%s]",
+        args.input,
+        args.output,
+        args.quantize,
+    )
+    _convert_to_tflite(args.input, args.quantize, args.output)
+    return 0
+
+
+if __name__ == "__main__":
+    sys.exit(main())
diff --git a/src/mlia/nn/tensorflow/utils.py b/src/mlia/nn/tensorflow/utils.py
index b8d45c6..1612447 100644
--- a/src/mlia/nn/tensorflow/utils.py
+++ b/src/mlia/nn/tensorflow/utils.py
@@ -4,31 +4,11 @@
 """Collection of useful functions for optimizations."""
 from __future__ import annotations
 
-import logging
 from pathlib import Path
 from typing import Any
-from typing import Callable
-from typing import cast
-from typing import Iterable
 
-import numpy as np
 import tensorflow as tf
 
-from mlia.utils.logging import redirect_output
-
-
-def representative_dataset(
-    input_shape: Any, sample_count: int = 100, input_dtype: type = np.float32
-) -> Callable:
-    """Sample dataset used for quantization."""
-
-    def dataset() -> Iterable:
-        for _ in range(sample_count):
-            data = np.random.rand(1, *input_shape[1:])
-            yield [data.astype(input_dtype)]
-
-    return dataset
-
 
 def get_tf_tensor_shape(model: str) -> list:
     """Get input shape for the TensorFlow tensor model."""
@@ -49,14 +29,6 @@ def get_tf_tensor_shape(model: str) -> list:
     ]
 
 
-def convert_to_tflite(model: tf.keras.Model | str, quantized: bool = False) -> bytes:
-    """Convert Keras model to TensorFlow Lite."""
-    converter = get_tflite_converter(model, quantized)
-
-    with redirect_output(logging.getLogger("tensorflow")):
-        return cast(bytes, converter.convert())
-
-
 def save_keras_model(
     model: tf.keras.Model, save_path: str | Path, include_optimizer: bool = True
 ) -> None:
@@ -94,37 +66,6 @@ def is_saved_model(model: str | Path) -> bool:
     return model_path.is_dir() and not is_keras_model(model)
 
 
-def get_tflite_converter(
-    model: tf.keras.Model | str | Path, quantized: bool = False
-) -> tf.lite.TFLiteConverter:
-    """Configure TensorFlow Lite converter for the provided model."""
-    if isinstance(model, (str, Path)):
-        # converter's methods accept string as input parameter
-        model = str(model)
-
-    if isinstance(model, tf.keras.Model):
-        converter = tf.lite.TFLiteConverter.from_keras_model(model)
-        input_shape = model.input_shape
-    elif isinstance(model, str) and is_saved_model(model):
-        converter = tf.lite.TFLiteConverter.from_saved_model(model)
-        input_shape = get_tf_tensor_shape(model)
-    elif isinstance(model, str) and is_keras_model(model):
-        keras_model = tf.keras.models.load_model(model)
-        input_shape = keras_model.input_shape
-        converter = tf.lite.TFLiteConverter.from_keras_model(keras_model)
-    else:
-        raise ValueError(f"Unable to create TensorFlow Lite converter for {model}")
-
-    if quantized:
-        converter.optimizations = [tf.lite.Optimize.DEFAULT]
-        converter.representative_dataset = representative_dataset(input_shape)
-        converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS]
-        converter.inference_input_type = tf.int8
-        converter.inference_output_type = tf.int8
-
-    return converter
-
-
 def get_tflite_model_type_map(model_filename: str | Path) -> dict[str, type]:
     """Get type map from tflite model."""
     model_type_map: dict[str, Any] = {}
diff --git a/tests/conftest.py b/tests/conftest.py
index d700206..345eb8d 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -14,9 +14,8 @@ import tensorflow as tf
 from mlia.backend.vela.compiler import optimize_model
 from mlia.core.context import ExecutionContext
 from mlia.nn.rewrite.core.utils.numpy_tfrecord import NumpyTFWriter
-from mlia.nn.tensorflow.utils import convert_to_tflite
+from mlia.nn.tensorflow.tflite_convert import convert_to_tflite
 from mlia.nn.tensorflow.utils import save_keras_model
-from mlia.nn.tensorflow.utils import save_tflite_model
 from mlia.target.ethos_u.config import EthosUConfiguration
 
 from tests.utils.rewrite import MockTrainingParameters
@@ -93,15 +92,13 @@ def fixture_test_models_path(
     save_keras_model(keras_model, tmp_path / TEST_MODEL_KERAS_FILE)
 
     # Un-quantized TensorFlow Lite model (fp32)
-    save_tflite_model(
-        convert_to_tflite(keras_model, quantized=False),
-        tmp_path / TEST_MODEL_TFLITE_FP32_FILE,
+    convert_to_tflite(
+        keras_model, quantized=False, output_path=tmp_path / TEST_MODEL_TFLITE_FP32_FILE
     )
 
     # Quantized TensorFlow Lite model (int8)
-    tflite_model = convert_to_tflite(keras_model, quantized=True)
     tflite_model_path = tmp_path / TEST_MODEL_TFLITE_INT8_FILE
-    save_tflite_model(tflite_model, tflite_model_path)
+    convert_to_tflite(keras_model, quantized=True, output_path=tflite_model_path)
 
     # Vela-optimized TensorFlow Lite model (int8)
     tflite_vela_model = tmp_path / TEST_MODEL_TFLITE_VELA_FILE
diff --git a/tests/test_nn_tensorflow_optimizations_clustering.py b/tests/test_nn_tensorflow_optimizations_clustering.py
index d3c0da6..58ffb3e 100644
--- a/tests/test_nn_tensorflow_optimizations_clustering.py
+++ b/tests/test_nn_tensorflow_optimizations_clustering.py
@@ -14,10 +14,9 @@ from mlia.nn.tensorflow.optimizations.clustering import Clusterer
 from mlia.nn.tensorflow.optimizations.clustering import ClusteringConfiguration
 from mlia.nn.tensorflow.optimizations.pruning import Pruner
 from mlia.nn.tensorflow.optimizations.pruning import PruningConfiguration
+from mlia.nn.tensorflow.tflite_convert import convert_to_tflite
 from mlia.nn.tensorflow.tflite_metrics import ReportClusterMode
 from mlia.nn.tensorflow.tflite_metrics import TFLiteMetrics
-from mlia.nn.tensorflow.utils import convert_to_tflite
-from mlia.nn.tensorflow.utils import save_tflite_model
 from tests.utils.common import get_dataset
 from tests.utils.common import train_model
 
@@ -118,8 +117,7 @@ def test_cluster_simple_model_fully(
     clustered_model = clusterer.get_model()
 
     temp_file = tmp_path / "test_cluster_simple_model_fully_after.tflite"
-    tflite_clustered_model = convert_to_tflite(clustered_model)
-    save_tflite_model(tflite_clustered_model, temp_file)
+    convert_to_tflite(clustered_model, output_path=temp_file)
     clustered_tflite_metrics = TFLiteMetrics(str(temp_file))
 
     _test_num_unique_weights(
diff --git a/tests/test_nn_tensorflow_optimizations_pruning.py b/tests/test_nn_tensorflow_optimizations_pruning.py
index d97b3d3..9afc3ff 100644
--- a/tests/test_nn_tensorflow_optimizations_pruning.py
+++ b/tests/test_nn_tensorflow_optimizations_pruning.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates.
 # SPDX-License-Identifier: Apache-2.0
 """Test for module optimizations/pruning."""
 from __future__ import annotations
@@ -11,9 +11,8 @@ from numpy.core.numeric import isclose
 
 from mlia.nn.tensorflow.optimizations.pruning import Pruner
 from mlia.nn.tensorflow.optimizations.pruning import PruningConfiguration
+from mlia.nn.tensorflow.tflite_convert import convert_to_tflite
 from mlia.nn.tensorflow.tflite_metrics import TFLiteMetrics
-from mlia.nn.tensorflow.utils import convert_to_tflite
-from mlia.nn.tensorflow.utils import save_tflite_model
 from tests.utils.common import get_dataset
 from tests.utils.common import train_model
 
@@ -52,7 +51,7 @@ def _get_tflite_metrics(
 ) -> TFLiteMetrics:
     """Save model as TFLiteModel and return metrics."""
     temp_file = path / tflite_fn
-    save_tflite_model(convert_to_tflite(model), temp_file)
+    convert_to_tflite(model, output_path=temp_file)
 
     return TFLiteMetrics(str(temp_file))
 
diff --git a/tests/test_nn_tensorflow_tflite_compat.py b/tests/test_nn_tensorflow_tflite_compat.py
index f203125..4ca387c 100644
--- a/tests/test_nn_tensorflow_tflite_compat.py
+++ b/tests/test_nn_tensorflow_tflite_compat.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates.
 # SPDX-License-Identifier: Apache-2.0
 """Tests for tflite_compat module."""
 from __future__ import annotations
@@ -219,7 +219,7 @@ def test_tflite_compatibility(
     converter_mock.convert.side_effect = conversion_error
 
     monkeypatch.setattr(
-        "mlia.nn.tensorflow.tflite_compat.get_tflite_converter",
+        "mlia.nn.tensorflow.tflite_convert.get_tflite_converter",
         lambda *args, **kwargs: converter_mock,
     )
 
diff --git a/tests/test_nn_tensorflow_tflite_convert.py b/tests/test_nn_tensorflow_tflite_convert.py
new file mode 100644
index 0000000..3125c04
--- /dev/null
+++ b/tests/test_nn_tensorflow_tflite_convert.py
@@ -0,0 +1,244 @@
+# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Tests for module tflite_convert."""
+import os
+from pathlib import Path
+from pathlib import PosixPath
+from unittest.mock import MagicMock
+
+import numpy as np
+import pytest
+import tensorflow as tf
+
+from mlia.nn.tensorflow import tflite_convert
+from mlia.nn.tensorflow.tflite_convert import convert_to_tflite
+from mlia.nn.tensorflow.tflite_convert import convert_to_tflite_bytes
+from mlia.nn.tensorflow.tflite_convert import main
+from mlia.nn.tensorflow.tflite_convert import representative_dataset
+
+
+def test_generate_representative_dataset() -> None:
+    """Test function for generating representative dataset."""
+    dataset = representative_dataset([1, 3, 3], 5)
+    data = list(dataset())
+
+    assert len(data) == 5
+    for elem in data:
+        assert isinstance(elem, list)
+        assert len(elem) == 1
+
+        ndarray = elem[0]
+        assert ndarray.dtype == np.float32
+        assert isinstance(ndarray, np.ndarray)
+
+
+def test_convert_saved_model_to_tflite(test_tf_model: Path) -> None:
+    """Test converting SavedModel to TensorFlow Lite."""
+    result = convert_to_tflite_bytes(test_tf_model.as_posix())
+    assert isinstance(result, bytes)
+
+
+def test_convert_keras_to_tflite(test_keras_model: Path) -> None:
+    """Test converting Keras model to TensorFlow Lite."""
+    keras_model = tf.keras.models.load_model(str(test_keras_model))
+    result = convert_to_tflite_bytes(keras_model)
+    assert isinstance(result, bytes)
+
+
+def test_save_tflite_model(tmp_path: Path, test_keras_model: Path) -> None:
+    """Test saving TensorFlow Lite model."""
+    keras_model = tf.keras.models.load_model(str(test_keras_model))
+
+    temp_file = tmp_path / "test_model_saving.tflite"
+    convert_to_tflite(keras_model, output_path=temp_file)
+
+    interpreter = tf.lite.Interpreter(model_path=str(temp_file))
+    assert interpreter
+
+
+def test_convert_unknown_model_to_tflite() -> None:
+    """Test that unknown model type cannot be converted to TensorFlow Lite."""
+    with pytest.raises(
+        ValueError, match="Unable to create TensorFlow Lite converter for 123"
+    ):
+        convert_to_tflite(123)
+
+
+@pytest.mark.parametrize(
+    "convert_options,expected_args,error",
+    [
+        [
+            {
+                "input_path": PosixPath("/in"),
+                "output_path": PosixPath("/out"),
+                "quantized": True,
+                "subprocess": True,
+            },
+            ["/in", "--output", "/out", "--quantize"],
+            None,
+        ],
+        [
+            {
+                "input_path": None,
+                "output_path": None,
+                "quantized": True,
+                "subprocess": False,
+            },
+            [True, None],
+            None,
+        ],
+        [
+            {
+                "input_path": None,
+                "output_path": PosixPath("/out"),
+                "quantized": False,
+                "subprocess": True,
+                "model": None,
+            },
+            ["/in", "/out"],
+            "Input path is required",
+        ],
+        [
+            {
+                "input_path": PosixPath("/in"),
+                "output_path": PosixPath("/out"),
+                "quantized": False,
+                "subprocess": False,
+            },
+            [False, PosixPath("/out")],
+            None,
+        ],
+        [
+            {
+                "input_path": PosixPath("/in"),
+                "output_path": PosixPath("/out"),
+                "quantized": True,
+                "subprocess": False,
+            },
+            [True, PosixPath("/out")],
+            None,
+        ],
+        [
+            {
+                "input_path": PosixPath("/in"),
+                "output_path": None,
+                "quantized": False,
+                "subprocess": True,
+            },
+            ["/in"],
+            None,
+        ],
+        [
+            {
+                "input_path": PosixPath("/in"),
+                "output_path": PosixPath("/out"),
+                "quantized": False,
+                "subprocess": True,
+            },
+            ["/in", "--output", "/out"],
+            None,
+        ],
+        [
+            {
+                "input_path": PosixPath("/in"),
+                "output_path": PosixPath("/out"),
+                "quantized": True,
+                "subprocess": True,
+            },
+            ["/in", "--output", "/out", "--quantize"],
+            None,
+        ],
+        [
+            {
+                "output_path": PosixPath("/out"),
"quantized": True, + "subprocess": True, + }, + ["/model_path", "--output", "/out", "--quantize"], + None, + ], + ], +) +def test_convert_to_tflite_subprocess( + convert_options: dict, + expected_args: str, + error: str, + monkeypatch: pytest.MonkeyPatch, +) -> None: + """Test if convert_to_tflite calls the subprocess with the correct args.""" + command_mock = MagicMock() + function_mock = MagicMock() + model_path_str = "/model_path" + monkeypatch.setattr( + "mlia.nn.tensorflow.tflite_convert.command_output", command_mock + ) + + monkeypatch.setattr( + "mlia.nn.tensorflow.tflite_convert._convert_to_tflite", function_mock + ) + + opts = {"model": model_path_str, **convert_options} + + if error: + with pytest.raises(Exception) as exc_info: + convert_to_tflite(**opts) + + assert error in str(exc_info.value) + command_mock.assert_not_called() + function_mock.assert_not_called() + return + + convert_to_tflite(**opts) + + if convert_options["subprocess"]: + command_mock.assert_called_once() + function_mock.assert_not_called() + pyfile = os.path.abspath(tflite_convert.__file__) + assert command_mock.mock_calls[0].args[0].cmd == [ + "python", + pyfile, + *expected_args, + ] + else: + command_mock.assert_not_called() + function_mock.assert_called_once() + args = function_mock.mock_calls[0].args + assert args == (model_path_str, *expected_args) + + +@pytest.mark.parametrize( + "args,expected_convert_args", + [ + ["{}", "{},False,None"], + ["{} --quantize", "{},True,None"], + ["{} --output {}", "{},False,{}"], + ["{} --output {} --quantize", "{},True,{}"], + ], +) +def test_main( + args: str, + expected_convert_args: str, + tmp_path: Path, + monkeypatch: pytest.MonkeyPatch, +) -> None: + """Test main function, the entry point to subprocess mode.""" + mock = MagicMock() + monkeypatch.setattr("mlia.nn.tensorflow.tflite_convert._convert_to_tflite", mock) + + input_path = tmp_path + output_path = tmp_path / "out" + argv = args.format(input_path, output_path).split() + main(argv) + + mock.assert_called_once() + convert_args = mock.mock_calls[0].args + actual = ",".join(str(arg) for arg in convert_args) + expected = expected_convert_args.format(input_path, output_path) + assert actual == expected + + +def test_main_nonexistent_input() -> None: + """Test main with missing input model.""" + with pytest.raises(ValueError) as excinfo: + main(["/missing"]) + assert "Input file doesn't exist: [/missing]" in str(excinfo.value) diff --git a/tests/test_nn_tensorflow_utils.py b/tests/test_nn_tensorflow_utils.py index dab8b4e..e356a49 100644 --- a/tests/test_nn_tensorflow_utils.py +++ b/tests/test_nn_tensorflow_utils.py @@ -8,43 +8,13 @@ import numpy as np import pytest import tensorflow as tf +from mlia.nn.tensorflow.tflite_convert import convert_to_tflite from mlia.nn.tensorflow.utils import check_tflite_datatypes -from mlia.nn.tensorflow.utils import convert_to_tflite from mlia.nn.tensorflow.utils import get_tf_tensor_shape from mlia.nn.tensorflow.utils import get_tflite_model_type_map from mlia.nn.tensorflow.utils import is_keras_model from mlia.nn.tensorflow.utils import is_tflite_model -from mlia.nn.tensorflow.utils import representative_dataset from mlia.nn.tensorflow.utils import save_keras_model -from mlia.nn.tensorflow.utils import save_tflite_model - - -def test_generate_representative_dataset() -> None: - """Test function for generating representative dataset.""" - dataset = representative_dataset([1, 3, 3], 5) - data = list(dataset()) - - assert len(data) == 5 - for elem in data: - assert 
-        assert len(elem) == 1
-
-        ndarray = elem[0]
-        assert ndarray.dtype == np.float32
-        assert isinstance(ndarray, np.ndarray)
-
-
-def test_convert_saved_model_to_tflite(test_tf_model: Path) -> None:
-    """Test converting SavedModel to TensorFlow Lite."""
-    result = convert_to_tflite(test_tf_model.as_posix())
-    assert isinstance(result, bytes)
-
-
-def test_convert_keras_to_tflite(test_keras_model: Path) -> None:
-    """Test converting Keras model to TensorFlow Lite."""
-    keras_model = tf.keras.models.load_model(str(test_keras_model))
-    result = convert_to_tflite(keras_model)
-    assert isinstance(result, bytes)
 
 
 def test_save_keras_model(tmp_path: Path, test_keras_model: Path) -> None:
@@ -62,23 +32,13 @@ def test_save_tflite_model(tmp_path: Path, test_keras_model: Path) -> None:
     """Test saving TensorFlow Lite model."""
     keras_model = tf.keras.models.load_model(str(test_keras_model))
 
-    tflite_model = convert_to_tflite(keras_model)
-
     temp_file = tmp_path / "test_model_saving.tflite"
-    save_tflite_model(tflite_model, temp_file)
+    convert_to_tflite(keras_model, output_path=temp_file)
 
     interpreter = tf.lite.Interpreter(model_path=str(temp_file))
     assert interpreter
 
 
-def test_convert_unknown_model_to_tflite() -> None:
-    """Test that unknown model type cannot be converted to TensorFlow Lite."""
-    with pytest.raises(
-        ValueError, match="Unable to create TensorFlow Lite converter for 123"
-    ):
-        convert_to_tflite(123)
-
-
 @pytest.mark.parametrize(
     "model_path, expected_result",
     [
diff --git a/tests/test_target_cortex_a_operators.py b/tests/test_target_cortex_a_operators.py
index 56d6c7b..16cdca5 100644
--- a/tests/test_target_cortex_a_operators.py
+++ b/tests/test_target_cortex_a_operators.py
@@ -6,7 +6,7 @@ from pathlib import Path
 import pytest
 import tensorflow as tf
 
-from mlia.nn.tensorflow.utils import convert_to_tflite
+from mlia.nn.tensorflow.tflite_convert import convert_to_tflite_bytes
 from mlia.target.cortex_a.config import CortexAConfiguration
 from mlia.target.cortex_a.operators import CortexACompatibilityInfo
 from mlia.target.cortex_a.operators import get_cortex_a_compatibility_info
@@ -52,7 +52,7 @@ def test_get_cortex_a_compatibility_info_not_compatible(
         ]
     )
     keras_model.compile(optimizer="sgd", loss="mean_squared_error")
-    tflite_model = convert_to_tflite(keras_model, quantized=False)
+    tflite_model = convert_to_tflite_bytes(keras_model, quantized=False)
 
     monkeypatch.setattr(
         "mlia.nn.tensorflow.tflite_graph.load_tflite", lambda _p: tflite_model
-- 
cgit v1.2.1