aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorGergely Nagy <gergely.nagy@arm.com>2023-11-21 12:29:38 +0000
committerGergely Nagy <gergely.nagy@arm.com>2023-12-07 17:09:31 +0000
commit54eec806272b7574a0757c77a913a369a9ecdc70 (patch)
tree2e6484b857b2a68279a2707dbb21e5c26685f4e2
parent7c50f1d6367186c03a282ac7ecb8fca0f905ba30 (diff)
downloadmlia-54eec806272b7574a0757c77a913a369a9ecdc70.tar.gz
MLIA-835 Invalid JSON output
TFLiteConverter was producing log messages in the output that was not possible to capture and redirect to logging. The solution/workaround is to run it as a subprocess. This change required some refactoring around existing invocations of the converter. Change-Id: I394bd0d49d36e6686cfcb9d658e4aad05326cb87 Signed-off-by: Gergely Nagy <gergely.nagy@arm.com>
-rw-r--r--src/mlia/nn/rewrite/core/train.py8
-rw-r--r--src/mlia/nn/tensorflow/config.py20
-rw-r--r--src/mlia/nn/tensorflow/tflite_compat.py7
-rw-r--r--src/mlia/nn/tensorflow/tflite_convert.py167
-rw-r--r--src/mlia/nn/tensorflow/utils.py59
-rw-r--r--tests/conftest.py11
-rw-r--r--tests/test_nn_tensorflow_optimizations_clustering.py6
-rw-r--r--tests/test_nn_tensorflow_optimizations_pruning.py7
-rw-r--r--tests/test_nn_tensorflow_tflite_compat.py4
-rw-r--r--tests/test_nn_tensorflow_tflite_convert.py244
-rw-r--r--tests/test_nn_tensorflow_utils.py44
-rw-r--r--tests/test_target_cortex_a_operators.py4
12 files changed, 445 insertions, 136 deletions
diff --git a/src/mlia/nn/rewrite/core/train.py b/src/mlia/nn/rewrite/core/train.py
index 6345f07..72b8f48 100644
--- a/src/mlia/nn/rewrite/core/train.py
+++ b/src/mlia/nn/rewrite/core/train.py
@@ -34,9 +34,9 @@ from mlia.nn.rewrite.core.utils.numpy_tfrecord import numpytf_count
from mlia.nn.rewrite.core.utils.numpy_tfrecord import numpytf_read
from mlia.nn.rewrite.core.utils.parallel import ParallelTFLiteModel
from mlia.nn.tensorflow.config import TFLiteModel
+from mlia.nn.tensorflow.tflite_convert import convert_to_tflite
from mlia.nn.tensorflow.tflite_graph import load_fb
from mlia.nn.tensorflow.tflite_graph import save_fb
-from mlia.nn.tensorflow.utils import get_tflite_converter
from mlia.utils.logging import log_action
@@ -499,11 +499,7 @@ def save_as_tflite(
keras_model.input.set_shape(orig_shape)
with fixed_input(keras_model, input_shape) as fixed_model:
- converter = get_tflite_converter(fixed_model, quantized=model_is_quantized)
- tflite_model = converter.convert()
-
- with open(filename, "wb") as file:
- file.write(tflite_model)
+ convert_to_tflite(fixed_model, model_is_quantized, Path(filename))
# Now fix the shapes and names to match those we expect
flatbuffer = load_fb(filename)
diff --git a/src/mlia/nn/tensorflow/config.py b/src/mlia/nn/tensorflow/config.py
index b94350a..0a17977 100644
--- a/src/mlia/nn/tensorflow/config.py
+++ b/src/mlia/nn/tensorflow/config.py
@@ -21,14 +21,13 @@ from mlia.nn.tensorflow.optimizations.quantization import dequantize
from mlia.nn.tensorflow.optimizations.quantization import is_quantized
from mlia.nn.tensorflow.optimizations.quantization import QuantizationParameters
from mlia.nn.tensorflow.optimizations.quantization import quantize
+from mlia.nn.tensorflow.tflite_convert import convert_to_tflite
from mlia.nn.tensorflow.tflite_graph import load_fb
from mlia.nn.tensorflow.tflite_graph import save_fb
from mlia.nn.tensorflow.utils import check_tflite_datatypes
-from mlia.nn.tensorflow.utils import convert_to_tflite
from mlia.nn.tensorflow.utils import is_keras_model
from mlia.nn.tensorflow.utils import is_saved_model
from mlia.nn.tensorflow.utils import is_tflite_model
-from mlia.nn.tensorflow.utils import save_tflite_model
from mlia.utils.logging import log_action
logger = logging.getLogger(__name__)
@@ -67,9 +66,14 @@ class KerasModel(ModelConfiguration):
) -> TFLiteModel:
"""Convert model to TensorFlow Lite format."""
with log_action("Converting Keras to TensorFlow Lite ..."):
- converted_model = convert_to_tflite(self.get_keras_model(), quantized)
+ convert_to_tflite(
+ self.get_keras_model(),
+ quantized,
+ input_path=Path(self.model_path),
+ output_path=Path(tflite_model_path),
+ subprocess=True,
+ )
- save_tflite_model(converted_model, tflite_model_path)
logger.debug(
"Model %s converted and saved to %s", self.model_path, tflite_model_path
)
@@ -270,8 +274,12 @@ class TfModel(ModelConfiguration): # pylint: disable=abstract-method
self, tflite_model_path: str | Path, quantized: bool = False
) -> TFLiteModel:
"""Convert model to TensorFlow Lite format."""
- converted_model = convert_to_tflite(self.model_path, quantized)
- save_tflite_model(converted_model, tflite_model_path)
+ convert_to_tflite(
+ self.model_path,
+ quantized,
+ input_path=Path(self.model_path),
+ output_path=Path(tflite_model_path),
+ )
return TFLiteModel(tflite_model_path)
diff --git a/src/mlia/nn/tensorflow/tflite_compat.py b/src/mlia/nn/tensorflow/tflite_compat.py
index 2b29879..497d5b1 100644
--- a/src/mlia/nn/tensorflow/tflite_compat.py
+++ b/src/mlia/nn/tensorflow/tflite_compat.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Functions for checking TensorFlow Lite compatibility."""
from __future__ import annotations
@@ -14,7 +14,7 @@ from typing import List
import tensorflow as tf
from tensorflow.lite.python import convert
-from mlia.nn.tensorflow.utils import get_tflite_converter
+from mlia.nn.tensorflow.tflite_convert import convert_to_tflite
from mlia.utils.logging import redirect_raw_output
TF_VERSION_MAJOR, TF_VERSION_MINOR, _ = (int(s) for s in tf.version.VERSION.split("."))
@@ -115,7 +115,6 @@ class TFLiteChecker:
"""Check TensorFlow Lite compatibility for the provided model."""
try:
logger.debug("Check TensorFlow Lite compatibility for %s", model)
- converter = get_tflite_converter(model, quantized=self.quantized)
# there is an issue with intercepting TensorFlow output
# not all output could be captured, for now just intercept
@@ -123,7 +122,7 @@ class TFLiteChecker:
with redirect_raw_output(
logging.getLogger("tensorflow"), stdout_level=None
):
- converter.convert()
+ convert_to_tflite(model, self.quantized)
except convert.ConverterError as err:
return self._process_convert_error(err)
except Exception as err: # pylint: disable=broad-except
diff --git a/src/mlia/nn/tensorflow/tflite_convert.py b/src/mlia/nn/tensorflow/tflite_convert.py
new file mode 100644
index 0000000..d3a833a
--- /dev/null
+++ b/src/mlia/nn/tensorflow/tflite_convert.py
@@ -0,0 +1,167 @@
+# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Support module to call TFLiteConverter."""
+from __future__ import annotations
+
+import argparse
+import logging
+import sys
+from pathlib import Path
+from typing import Any
+from typing import Callable
+from typing import cast
+from typing import Iterable
+
+import numpy as np
+import tensorflow as tf
+
+from mlia.nn.tensorflow.utils import get_tf_tensor_shape
+from mlia.nn.tensorflow.utils import is_keras_model
+from mlia.nn.tensorflow.utils import is_saved_model
+from mlia.nn.tensorflow.utils import save_tflite_model
+from mlia.utils.logging import redirect_output
+from mlia.utils.proc import Command
+from mlia.utils.proc import command_output
+
+logger = logging.getLogger(__name__)
+
+
+def representative_dataset(
+ input_shape: Any, sample_count: int = 100, input_dtype: type = np.float32
+) -> Callable:
+ """Sample dataset used for quantization."""
+
+ def dataset() -> Iterable:
+ for _ in range(sample_count):
+ data = np.random.rand(1, *input_shape[1:])
+ yield [data.astype(input_dtype)]
+
+ return dataset
+
+
+def get_tflite_converter(
+ model: tf.keras.Model | str | Path, quantized: bool = False
+) -> tf.lite.TFLiteConverter:
+ """Configure TensorFlow Lite converter for the provided model."""
+ if isinstance(model, (str, Path)):
+ # converter's methods accept string as input parameter
+ model = str(model)
+
+ if isinstance(model, tf.keras.Model):
+ converter = tf.lite.TFLiteConverter.from_keras_model(model)
+ input_shape = model.input_shape
+ elif isinstance(model, str) and is_saved_model(model):
+ converter = tf.lite.TFLiteConverter.from_saved_model(model)
+ input_shape = get_tf_tensor_shape(model)
+ elif isinstance(model, str) and is_keras_model(model):
+ keras_model = tf.keras.models.load_model(model)
+ input_shape = keras_model.input_shape
+ converter = tf.lite.TFLiteConverter.from_keras_model(keras_model)
+ else:
+ raise ValueError(f"Unable to create TensorFlow Lite converter for {model}")
+
+ if quantized:
+ converter.optimizations = [tf.lite.Optimize.DEFAULT]
+ converter.representative_dataset = representative_dataset(input_shape)
+ converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS]
+ converter.inference_input_type = tf.int8
+ converter.inference_output_type = tf.int8
+
+ return converter
+
+
+def convert_to_tflite_bytes(
+ model: tf.keras.Model | str, quantized: bool = False
+) -> bytes:
+ """Convert Keras model to TensorFlow Lite."""
+ converter = get_tflite_converter(model, quantized)
+
+ with redirect_output(logging.getLogger("tensorflow")):
+ output_bytes = cast(bytes, converter.convert())
+
+ return output_bytes
+
+
+def _convert_to_tflite(
+ model: tf.keras.Model | str,
+ quantized: bool = False,
+ output_path: Path | None = None,
+) -> bytes:
+ """Convert Keras model to TensorFlow Lite."""
+ output_bytes = convert_to_tflite_bytes(model, quantized)
+
+ if output_path:
+ save_tflite_model(output_bytes, output_path)
+
+ return output_bytes
+
+
+def convert_to_tflite(
+ model: tf.keras.Model | str,
+ quantized: bool = False,
+ output_path: Path | None = None,
+ input_path: Path | None = None,
+ subprocess: bool = False,
+) -> None:
+ """Convert Keras model to TensorFlow Lite.
+
+ Optionally runs TFLiteConverter in a subprocess,
+ this is added mainly to work around issues when redirecting
+ Tensorflow's output using SDK calls, didn't make an effect,
+ which would produce unwanted output for MLIA.
+
+ In the subprocess mode, the model should be passed as a
+ file path, or via a dedicated 'input_path' parameter.
+
+ If 'output_path' is provided, the result model be saved under
+ that path.
+ """
+ if not subprocess:
+ _convert_to_tflite(model, quantized, output_path)
+ return
+
+ if input_path is None:
+ if isinstance(model, str):
+ input_path = Path(model)
+ else:
+ raise RuntimeError(
+ f"Input path is required for {model}"
+ " when converter is called in subprocess."
+ )
+
+ args = ["python", __file__, str(input_path)]
+ if output_path:
+ args.append("--output")
+ args.append(str(output_path))
+ if quantized:
+ args.append("--quantize")
+
+ command = Command(args)
+
+ for line in command_output(command):
+ logger.debug("TFLiteConverter: %s", line)
+
+
+def main(argv: list[str] | None = None) -> int:
+ """Entry point to run this module as a standalone executable."""
+ parser = argparse.ArgumentParser()
+ parser.add_argument("input", type=Path)
+ parser.add_argument("--output", type=Path, default=None)
+ parser.add_argument("--quantize", default=False, action="store_true")
+ args = parser.parse_args(argv)
+
+ if not Path(args.input).exists():
+ raise ValueError(f"Input file doesn't exist: [{args.input}]")
+
+ logger.debug(
+ "Invoking TFLiteConverter on [%s] -> [%s], quantize: [%s]",
+ args.input,
+ args.output,
+ args.quantize,
+ )
+ _convert_to_tflite(args.input, args.quantize, args.output)
+ return 0
+
+
+if __name__ == "__main__":
+ sys.exit(main())
diff --git a/src/mlia/nn/tensorflow/utils.py b/src/mlia/nn/tensorflow/utils.py
index b8d45c6..1612447 100644
--- a/src/mlia/nn/tensorflow/utils.py
+++ b/src/mlia/nn/tensorflow/utils.py
@@ -4,31 +4,11 @@
"""Collection of useful functions for optimizations."""
from __future__ import annotations
-import logging
from pathlib import Path
from typing import Any
-from typing import Callable
-from typing import cast
-from typing import Iterable
-import numpy as np
import tensorflow as tf
-from mlia.utils.logging import redirect_output
-
-
-def representative_dataset(
- input_shape: Any, sample_count: int = 100, input_dtype: type = np.float32
-) -> Callable:
- """Sample dataset used for quantization."""
-
- def dataset() -> Iterable:
- for _ in range(sample_count):
- data = np.random.rand(1, *input_shape[1:])
- yield [data.astype(input_dtype)]
-
- return dataset
-
def get_tf_tensor_shape(model: str) -> list:
"""Get input shape for the TensorFlow tensor model."""
@@ -49,14 +29,6 @@ def get_tf_tensor_shape(model: str) -> list:
]
-def convert_to_tflite(model: tf.keras.Model | str, quantized: bool = False) -> bytes:
- """Convert Keras model to TensorFlow Lite."""
- converter = get_tflite_converter(model, quantized)
-
- with redirect_output(logging.getLogger("tensorflow")):
- return cast(bytes, converter.convert())
-
-
def save_keras_model(
model: tf.keras.Model, save_path: str | Path, include_optimizer: bool = True
) -> None:
@@ -94,37 +66,6 @@ def is_saved_model(model: str | Path) -> bool:
return model_path.is_dir() and not is_keras_model(model)
-def get_tflite_converter(
- model: tf.keras.Model | str | Path, quantized: bool = False
-) -> tf.lite.TFLiteConverter:
- """Configure TensorFlow Lite converter for the provided model."""
- if isinstance(model, (str, Path)):
- # converter's methods accept string as input parameter
- model = str(model)
-
- if isinstance(model, tf.keras.Model):
- converter = tf.lite.TFLiteConverter.from_keras_model(model)
- input_shape = model.input_shape
- elif isinstance(model, str) and is_saved_model(model):
- converter = tf.lite.TFLiteConverter.from_saved_model(model)
- input_shape = get_tf_tensor_shape(model)
- elif isinstance(model, str) and is_keras_model(model):
- keras_model = tf.keras.models.load_model(model)
- input_shape = keras_model.input_shape
- converter = tf.lite.TFLiteConverter.from_keras_model(keras_model)
- else:
- raise ValueError(f"Unable to create TensorFlow Lite converter for {model}")
-
- if quantized:
- converter.optimizations = [tf.lite.Optimize.DEFAULT]
- converter.representative_dataset = representative_dataset(input_shape)
- converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS]
- converter.inference_input_type = tf.int8
- converter.inference_output_type = tf.int8
-
- return converter
-
-
def get_tflite_model_type_map(model_filename: str | Path) -> dict[str, type]:
"""Get type map from tflite model."""
model_type_map: dict[str, Any] = {}
diff --git a/tests/conftest.py b/tests/conftest.py
index d700206..345eb8d 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -14,9 +14,8 @@ import tensorflow as tf
from mlia.backend.vela.compiler import optimize_model
from mlia.core.context import ExecutionContext
from mlia.nn.rewrite.core.utils.numpy_tfrecord import NumpyTFWriter
-from mlia.nn.tensorflow.utils import convert_to_tflite
+from mlia.nn.tensorflow.tflite_convert import convert_to_tflite
from mlia.nn.tensorflow.utils import save_keras_model
-from mlia.nn.tensorflow.utils import save_tflite_model
from mlia.target.ethos_u.config import EthosUConfiguration
from tests.utils.rewrite import MockTrainingParameters
@@ -93,15 +92,13 @@ def fixture_test_models_path(
save_keras_model(keras_model, tmp_path / TEST_MODEL_KERAS_FILE)
# Un-quantized TensorFlow Lite model (fp32)
- save_tflite_model(
- convert_to_tflite(keras_model, quantized=False),
- tmp_path / TEST_MODEL_TFLITE_FP32_FILE,
+ convert_to_tflite(
+ keras_model, quantized=False, output_path=tmp_path / TEST_MODEL_TFLITE_FP32_FILE
)
# Quantized TensorFlow Lite model (int8)
- tflite_model = convert_to_tflite(keras_model, quantized=True)
tflite_model_path = tmp_path / TEST_MODEL_TFLITE_INT8_FILE
- save_tflite_model(tflite_model, tflite_model_path)
+ convert_to_tflite(keras_model, quantized=True, output_path=tflite_model_path)
# Vela-optimized TensorFlow Lite model (int8)
tflite_vela_model = tmp_path / TEST_MODEL_TFLITE_VELA_FILE
diff --git a/tests/test_nn_tensorflow_optimizations_clustering.py b/tests/test_nn_tensorflow_optimizations_clustering.py
index d3c0da6..58ffb3e 100644
--- a/tests/test_nn_tensorflow_optimizations_clustering.py
+++ b/tests/test_nn_tensorflow_optimizations_clustering.py
@@ -14,10 +14,9 @@ from mlia.nn.tensorflow.optimizations.clustering import Clusterer
from mlia.nn.tensorflow.optimizations.clustering import ClusteringConfiguration
from mlia.nn.tensorflow.optimizations.pruning import Pruner
from mlia.nn.tensorflow.optimizations.pruning import PruningConfiguration
+from mlia.nn.tensorflow.tflite_convert import convert_to_tflite
from mlia.nn.tensorflow.tflite_metrics import ReportClusterMode
from mlia.nn.tensorflow.tflite_metrics import TFLiteMetrics
-from mlia.nn.tensorflow.utils import convert_to_tflite
-from mlia.nn.tensorflow.utils import save_tflite_model
from tests.utils.common import get_dataset
from tests.utils.common import train_model
@@ -118,8 +117,7 @@ def test_cluster_simple_model_fully(
clustered_model = clusterer.get_model()
temp_file = tmp_path / "test_cluster_simple_model_fully_after.tflite"
- tflite_clustered_model = convert_to_tflite(clustered_model)
- save_tflite_model(tflite_clustered_model, temp_file)
+ convert_to_tflite(clustered_model, output_path=temp_file)
clustered_tflite_metrics = TFLiteMetrics(str(temp_file))
_test_num_unique_weights(
diff --git a/tests/test_nn_tensorflow_optimizations_pruning.py b/tests/test_nn_tensorflow_optimizations_pruning.py
index d97b3d3..9afc3ff 100644
--- a/tests/test_nn_tensorflow_optimizations_pruning.py
+++ b/tests/test_nn_tensorflow_optimizations_pruning.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Test for module optimizations/pruning."""
from __future__ import annotations
@@ -11,9 +11,8 @@ from numpy.core.numeric import isclose
from mlia.nn.tensorflow.optimizations.pruning import Pruner
from mlia.nn.tensorflow.optimizations.pruning import PruningConfiguration
+from mlia.nn.tensorflow.tflite_convert import convert_to_tflite
from mlia.nn.tensorflow.tflite_metrics import TFLiteMetrics
-from mlia.nn.tensorflow.utils import convert_to_tflite
-from mlia.nn.tensorflow.utils import save_tflite_model
from tests.utils.common import get_dataset
from tests.utils.common import train_model
@@ -52,7 +51,7 @@ def _get_tflite_metrics(
) -> TFLiteMetrics:
"""Save model as TFLiteModel and return metrics."""
temp_file = path / tflite_fn
- save_tflite_model(convert_to_tflite(model), temp_file)
+ convert_to_tflite(model, output_path=temp_file)
return TFLiteMetrics(str(temp_file))
diff --git a/tests/test_nn_tensorflow_tflite_compat.py b/tests/test_nn_tensorflow_tflite_compat.py
index f203125..4ca387c 100644
--- a/tests/test_nn_tensorflow_tflite_compat.py
+++ b/tests/test_nn_tensorflow_tflite_compat.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Tests for tflite_compat module."""
from __future__ import annotations
@@ -219,7 +219,7 @@ def test_tflite_compatibility(
converter_mock.convert.side_effect = conversion_error
monkeypatch.setattr(
- "mlia.nn.tensorflow.tflite_compat.get_tflite_converter",
+ "mlia.nn.tensorflow.tflite_convert.get_tflite_converter",
lambda *args, **kwargs: converter_mock,
)
diff --git a/tests/test_nn_tensorflow_tflite_convert.py b/tests/test_nn_tensorflow_tflite_convert.py
new file mode 100644
index 0000000..3125c04
--- /dev/null
+++ b/tests/test_nn_tensorflow_tflite_convert.py
@@ -0,0 +1,244 @@
+# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Test for module utils/test_utils."""
+import os
+from pathlib import Path
+from pathlib import PosixPath
+from unittest.mock import MagicMock
+
+import numpy as np
+import pytest
+import tensorflow as tf
+
+from mlia.nn.tensorflow import tflite_convert
+from mlia.nn.tensorflow.tflite_convert import convert_to_tflite
+from mlia.nn.tensorflow.tflite_convert import convert_to_tflite_bytes
+from mlia.nn.tensorflow.tflite_convert import main
+from mlia.nn.tensorflow.tflite_convert import representative_dataset
+
+
+def test_generate_representative_dataset() -> None:
+ """Test function for generating representative dataset."""
+ dataset = representative_dataset([1, 3, 3], 5)
+ data = list(dataset())
+
+ assert len(data) == 5
+ for elem in data:
+ assert isinstance(elem, list)
+ assert len(elem) == 1
+
+ ndarray = elem[0]
+ assert ndarray.dtype == np.float32
+ assert isinstance(ndarray, np.ndarray)
+
+
+def test_convert_saved_model_to_tflite(test_tf_model: Path) -> None:
+ """Test converting SavedModel to TensorFlow Lite."""
+ result = convert_to_tflite_bytes(test_tf_model.as_posix())
+ assert isinstance(result, bytes)
+
+
+def test_convert_keras_to_tflite(test_keras_model: Path) -> None:
+ """Test converting Keras model to TensorFlow Lite."""
+ keras_model = tf.keras.models.load_model(str(test_keras_model))
+ result = convert_to_tflite_bytes(keras_model)
+ assert isinstance(result, bytes)
+
+
+def test_save_tflite_model(tmp_path: Path, test_keras_model: Path) -> None:
+ """Test saving TensorFlow Lite model."""
+ keras_model = tf.keras.models.load_model(str(test_keras_model))
+
+ temp_file = tmp_path / "test_model_saving.tflite"
+ convert_to_tflite(keras_model, output_path=temp_file)
+
+ interpreter = tf.lite.Interpreter(model_path=str(temp_file))
+ assert interpreter
+
+
+def test_convert_unknown_model_to_tflite() -> None:
+ """Test that unknown model type cannot be converted to TensorFlow Lite."""
+ with pytest.raises(
+ ValueError, match="Unable to create TensorFlow Lite converter for 123"
+ ):
+ convert_to_tflite(123)
+
+
+@pytest.mark.parametrize(
+ "convert_options,expected_args,error",
+ [
+ [
+ {
+ "input_path": PosixPath("/in"),
+ "output_path": PosixPath("/out"),
+ "quantized": True,
+ "subprocess": True,
+ },
+ ["/in", "--output", "/out", "--quantize"],
+ None,
+ ],
+ [
+ {
+ "input_path": None,
+ "output_path": None,
+ "quantized": True,
+ "subprocess": False,
+ },
+ [True, None],
+ None,
+ ],
+ [
+ {
+ "input_path": None,
+ "output_path": PosixPath("/out"),
+ "quantized": False,
+ "subprocess": True,
+ "model": None,
+ },
+ ["/in", "/out"],
+ "Input path is required",
+ ],
+ [
+ {
+ "input_path": PosixPath("/in"),
+ "output_path": PosixPath("/out"),
+ "quantized": False,
+ "subprocess": False,
+ },
+ [False, PosixPath("/out")],
+ None,
+ ],
+ [
+ {
+ "input_path": PosixPath("/in"),
+ "output_path": PosixPath("/out"),
+ "quantized": True,
+ "subprocess": False,
+ },
+ [True, PosixPath("/out")],
+ None,
+ ],
+ [
+ {
+ "input_path": PosixPath("/in"),
+ "output_path": None,
+ "quantized": False,
+ "subprocess": True,
+ },
+ ["/in"],
+ None,
+ ],
+ [
+ {
+ "input_path": PosixPath("/in"),
+ "output_path": PosixPath("/out"),
+ "quantized": False,
+ "subprocess": True,
+ },
+ ["/in", "--output", "/out"],
+ None,
+ ],
+ [
+ {
+ "input_path": PosixPath("/in"),
+ "output_path": PosixPath("/out"),
+ "quantized": True,
+ "subprocess": True,
+ },
+ ["/in", "--output", "/out", "--quantize"],
+ None,
+ ],
+ [
+ {
+ "output_path": PosixPath("/out"),
+ "quantized": True,
+ "subprocess": True,
+ },
+ ["/model_path", "--output", "/out", "--quantize"],
+ None,
+ ],
+ ],
+)
+def test_convert_to_tflite_subprocess(
+ convert_options: dict,
+ expected_args: str,
+ error: str,
+ monkeypatch: pytest.MonkeyPatch,
+) -> None:
+ """Test if convert_to_tflite calls the subprocess with the correct args."""
+ command_mock = MagicMock()
+ function_mock = MagicMock()
+ model_path_str = "/model_path"
+ monkeypatch.setattr(
+ "mlia.nn.tensorflow.tflite_convert.command_output", command_mock
+ )
+
+ monkeypatch.setattr(
+ "mlia.nn.tensorflow.tflite_convert._convert_to_tflite", function_mock
+ )
+
+ opts = {"model": model_path_str, **convert_options}
+
+ if error:
+ with pytest.raises(Exception) as exc_info:
+ convert_to_tflite(**opts)
+
+ assert error in str(exc_info.value)
+ command_mock.assert_not_called()
+ function_mock.assert_not_called()
+ return
+
+ convert_to_tflite(**opts)
+
+ if convert_options["subprocess"]:
+ command_mock.assert_called_once()
+ function_mock.assert_not_called()
+ pyfile = os.path.abspath(tflite_convert.__file__)
+ assert command_mock.mock_calls[0].args[0].cmd == [
+ "python",
+ pyfile,
+ *expected_args,
+ ]
+ else:
+ command_mock.assert_not_called()
+ function_mock.assert_called_once()
+ args = function_mock.mock_calls[0].args
+ assert args == (model_path_str, *expected_args)
+
+
+@pytest.mark.parametrize(
+ "args,expected_convert_args",
+ [
+ ["{}", "{},False,None"],
+ ["{} --quantize", "{},True,None"],
+ ["{} --output {}", "{},False,{}"],
+ ["{} --output {} --quantize", "{},True,{}"],
+ ],
+)
+def test_main(
+ args: str,
+ expected_convert_args: str,
+ tmp_path: Path,
+ monkeypatch: pytest.MonkeyPatch,
+) -> None:
+ """Test main function, the entry point to subprocess mode."""
+ mock = MagicMock()
+ monkeypatch.setattr("mlia.nn.tensorflow.tflite_convert._convert_to_tflite", mock)
+
+ input_path = tmp_path
+ output_path = tmp_path / "out"
+ argv = args.format(input_path, output_path).split()
+ main(argv)
+
+ mock.assert_called_once()
+ convert_args = mock.mock_calls[0].args
+ actual = ",".join(str(arg) for arg in convert_args)
+ expected = expected_convert_args.format(input_path, output_path)
+ assert actual == expected
+
+
+def test_main_nonexistent_input() -> None:
+ """Test main with missing input model."""
+ with pytest.raises(ValueError) as excinfo:
+ main(["/missing"])
+ assert "Input file doesn't exist: [/missing]" in str(excinfo.value)
diff --git a/tests/test_nn_tensorflow_utils.py b/tests/test_nn_tensorflow_utils.py
index dab8b4e..e356a49 100644
--- a/tests/test_nn_tensorflow_utils.py
+++ b/tests/test_nn_tensorflow_utils.py
@@ -8,43 +8,13 @@ import numpy as np
import pytest
import tensorflow as tf
+from mlia.nn.tensorflow.tflite_convert import convert_to_tflite
from mlia.nn.tensorflow.utils import check_tflite_datatypes
-from mlia.nn.tensorflow.utils import convert_to_tflite
from mlia.nn.tensorflow.utils import get_tf_tensor_shape
from mlia.nn.tensorflow.utils import get_tflite_model_type_map
from mlia.nn.tensorflow.utils import is_keras_model
from mlia.nn.tensorflow.utils import is_tflite_model
-from mlia.nn.tensorflow.utils import representative_dataset
from mlia.nn.tensorflow.utils import save_keras_model
-from mlia.nn.tensorflow.utils import save_tflite_model
-
-
-def test_generate_representative_dataset() -> None:
- """Test function for generating representative dataset."""
- dataset = representative_dataset([1, 3, 3], 5)
- data = list(dataset())
-
- assert len(data) == 5
- for elem in data:
- assert isinstance(elem, list)
- assert len(elem) == 1
-
- ndarray = elem[0]
- assert ndarray.dtype == np.float32
- assert isinstance(ndarray, np.ndarray)
-
-
-def test_convert_saved_model_to_tflite(test_tf_model: Path) -> None:
- """Test converting SavedModel to TensorFlow Lite."""
- result = convert_to_tflite(test_tf_model.as_posix())
- assert isinstance(result, bytes)
-
-
-def test_convert_keras_to_tflite(test_keras_model: Path) -> None:
- """Test converting Keras model to TensorFlow Lite."""
- keras_model = tf.keras.models.load_model(str(test_keras_model))
- result = convert_to_tflite(keras_model)
- assert isinstance(result, bytes)
def test_save_keras_model(tmp_path: Path, test_keras_model: Path) -> None:
@@ -62,23 +32,13 @@ def test_save_tflite_model(tmp_path: Path, test_keras_model: Path) -> None:
"""Test saving TensorFlow Lite model."""
keras_model = tf.keras.models.load_model(str(test_keras_model))
- tflite_model = convert_to_tflite(keras_model)
-
temp_file = tmp_path / "test_model_saving.tflite"
- save_tflite_model(tflite_model, temp_file)
+ convert_to_tflite(keras_model, output_path=temp_file)
interpreter = tf.lite.Interpreter(model_path=str(temp_file))
assert interpreter
-def test_convert_unknown_model_to_tflite() -> None:
- """Test that unknown model type cannot be converted to TensorFlow Lite."""
- with pytest.raises(
- ValueError, match="Unable to create TensorFlow Lite converter for 123"
- ):
- convert_to_tflite(123)
-
-
@pytest.mark.parametrize(
"model_path, expected_result",
[
diff --git a/tests/test_target_cortex_a_operators.py b/tests/test_target_cortex_a_operators.py
index 56d6c7b..16cdca5 100644
--- a/tests/test_target_cortex_a_operators.py
+++ b/tests/test_target_cortex_a_operators.py
@@ -6,7 +6,7 @@ from pathlib import Path
import pytest
import tensorflow as tf
-from mlia.nn.tensorflow.utils import convert_to_tflite
+from mlia.nn.tensorflow.tflite_convert import convert_to_tflite_bytes
from mlia.target.cortex_a.config import CortexAConfiguration
from mlia.target.cortex_a.operators import CortexACompatibilityInfo
from mlia.target.cortex_a.operators import get_cortex_a_compatibility_info
@@ -52,7 +52,7 @@ def test_get_cortex_a_compatibility_info_not_compatible(
]
)
keras_model.compile(optimizer="sgd", loss="mean_squared_error")
- tflite_model = convert_to_tflite(keras_model, quantized=False)
+ tflite_model = convert_to_tflite_bytes(keras_model, quantized=False)
monkeypatch.setattr(
"mlia.nn.tensorflow.tflite_graph.load_tflite", lambda _p: tflite_model