diff options
author | Gergely Nagy <gergely.nagy@arm.com> | 2023-11-21 12:29:38 +0000 |
---|---|---|
committer | Gergely Nagy <gergely.nagy@arm.com> | 2023-12-07 17:09:31 +0000 |
commit | 54eec806272b7574a0757c77a913a369a9ecdc70 (patch) | |
tree | 2e6484b857b2a68279a2707dbb21e5c26685f4e2 /tests/test_nn_tensorflow_utils.py | |
parent | 7c50f1d6367186c03a282ac7ecb8fca0f905ba30 (diff) | |
download | mlia-54eec806272b7574a0757c77a913a369a9ecdc70.tar.gz |
MLIA-835 Invalid JSON output
TFLiteConverter was producing log messages in the output that was not
possible to capture and redirect to logging.
The solution/workaround is to run it as a subprocess.
This change required some refactoring around existing invocations of
the converter.
Change-Id: I394bd0d49d36e6686cfcb9d658e4aad05326cb87
Signed-off-by: Gergely Nagy <gergely.nagy@arm.com>
Diffstat (limited to 'tests/test_nn_tensorflow_utils.py')
-rw-r--r-- | tests/test_nn_tensorflow_utils.py | 44 |
1 files changed, 2 insertions, 42 deletions
diff --git a/tests/test_nn_tensorflow_utils.py b/tests/test_nn_tensorflow_utils.py index dab8b4e..e356a49 100644 --- a/tests/test_nn_tensorflow_utils.py +++ b/tests/test_nn_tensorflow_utils.py @@ -8,43 +8,13 @@ import numpy as np import pytest import tensorflow as tf +from mlia.nn.tensorflow.tflite_convert import convert_to_tflite from mlia.nn.tensorflow.utils import check_tflite_datatypes -from mlia.nn.tensorflow.utils import convert_to_tflite from mlia.nn.tensorflow.utils import get_tf_tensor_shape from mlia.nn.tensorflow.utils import get_tflite_model_type_map from mlia.nn.tensorflow.utils import is_keras_model from mlia.nn.tensorflow.utils import is_tflite_model -from mlia.nn.tensorflow.utils import representative_dataset from mlia.nn.tensorflow.utils import save_keras_model -from mlia.nn.tensorflow.utils import save_tflite_model - - -def test_generate_representative_dataset() -> None: - """Test function for generating representative dataset.""" - dataset = representative_dataset([1, 3, 3], 5) - data = list(dataset()) - - assert len(data) == 5 - for elem in data: - assert isinstance(elem, list) - assert len(elem) == 1 - - ndarray = elem[0] - assert ndarray.dtype == np.float32 - assert isinstance(ndarray, np.ndarray) - - -def test_convert_saved_model_to_tflite(test_tf_model: Path) -> None: - """Test converting SavedModel to TensorFlow Lite.""" - result = convert_to_tflite(test_tf_model.as_posix()) - assert isinstance(result, bytes) - - -def test_convert_keras_to_tflite(test_keras_model: Path) -> None: - """Test converting Keras model to TensorFlow Lite.""" - keras_model = tf.keras.models.load_model(str(test_keras_model)) - result = convert_to_tflite(keras_model) - assert isinstance(result, bytes) def test_save_keras_model(tmp_path: Path, test_keras_model: Path) -> None: @@ -62,23 +32,13 @@ def test_save_tflite_model(tmp_path: Path, test_keras_model: Path) -> None: """Test saving TensorFlow Lite model.""" keras_model = tf.keras.models.load_model(str(test_keras_model)) - tflite_model = convert_to_tflite(keras_model) - temp_file = tmp_path / "test_model_saving.tflite" - save_tflite_model(tflite_model, temp_file) + convert_to_tflite(keras_model, output_path=temp_file) interpreter = tf.lite.Interpreter(model_path=str(temp_file)) assert interpreter -def test_convert_unknown_model_to_tflite() -> None: - """Test that unknown model type cannot be converted to TensorFlow Lite.""" - with pytest.raises( - ValueError, match="Unable to create TensorFlow Lite converter for 123" - ): - convert_to_tflite(123) - - @pytest.mark.parametrize( "model_path, expected_result", [ |