diff options
Diffstat (limited to 'tests')
-rw-r--r-- | tests/conftest.py | 39 | ||||
-rw-r--r-- | tests/test_nn_rewrite_core_graph_edit_join.py | 6 | ||||
-rw-r--r-- | tests/test_nn_rewrite_core_graph_edit_record.py | 4 | ||||
-rw-r--r-- | tests/test_nn_rewrite_core_rewrite.py | 6 | ||||
-rw-r--r-- | tests/test_nn_rewrite_core_train.py | 76 | ||||
-rw-r--r-- | tests/test_nn_rewrite_core_utils.py | 33 | ||||
-rw-r--r-- | tests/test_nn_rewrite_core_utils_numpy_tfrecord.py | 25 | ||||
-rw-r--r-- | tests/test_nn_tensorflow_config.py | 40 | ||||
-rw-r--r-- | tests/test_nn_tensorflow_tflite_graph.py | 30 | ||||
-rw-r--r-- | tests/utils/rewrite.py | 18 |
10 files changed, 189 insertions, 88 deletions
diff --git a/tests/conftest.py b/tests/conftest.py index c42b8cb..bb2423f 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -5,6 +5,7 @@ import shutil from pathlib import Path from typing import Callable from typing import Generator +from unittest.mock import MagicMock import numpy as np import pytest @@ -17,6 +18,7 @@ from mlia.nn.tensorflow.utils import convert_to_tflite from mlia.nn.tensorflow.utils import save_keras_model from mlia.nn.tensorflow.utils import save_tflite_model from mlia.target.ethos_u.config import EthosUConfiguration +from tests.utils.rewrite import TestTrainingParameters @pytest.fixture(scope="session", name="test_resources_path") @@ -168,16 +170,12 @@ def _write_tfrecord( writer.write({input_name: data_generator()}) -@pytest.fixture(scope="session", name="test_tfrecord") -def fixture_test_tfrecord( - tmp_path_factory: pytest.TempPathFactory, +def create_tfrecord( + tmp_path_factory: pytest.TempPathFactory, random_data: Callable ) -> Generator[Path, None, None]: """Create a tfrecord with random data matching fixture 'test_tflite_model'.""" tmp_path = tmp_path_factory.mktemp("tfrecords") - tfrecord_file = tmp_path / "test_int8.tfrecord" - - def random_data() -> np.ndarray: - return np.random.randint(low=-127, high=128, size=(1, 28, 28, 1), dtype=np.int8) + tfrecord_file = tmp_path / "test.tfrecord" _write_tfrecord(tfrecord_file, random_data) @@ -186,19 +184,36 @@ def fixture_test_tfrecord( shutil.rmtree(tmp_path) +@pytest.fixture(scope="session", name="test_tfrecord") +def fixture_test_tfrecord( + tmp_path_factory: pytest.TempPathFactory, +) -> Generator[Path, None, None]: + """Create a tfrecord with random data matching fixture 'test_tflite_model'.""" + + def random_data() -> np.ndarray: + return np.random.randint(low=-127, high=128, size=(1, 28, 28, 1), dtype=np.int8) + + yield from create_tfrecord(tmp_path_factory, random_data) + + @pytest.fixture(scope="session", name="test_tfrecord_fp32") def fixture_test_tfrecord_fp32( tmp_path_factory: pytest.TempPathFactory, ) -> Generator[Path, None, None]: """Create tfrecord with random data matching fixture 'test_tflite_model_fp32'.""" - tmp_path = tmp_path_factory.mktemp("tfrecords") - tfrecord_file = tmp_path / "test_fp32.tfrecord" def random_data() -> np.ndarray: return np.random.rand(1, 28, 28, 1).astype(np.float32) - _write_tfrecord(tfrecord_file, random_data) + yield from create_tfrecord(tmp_path_factory, random_data) - yield tfrecord_file - shutil.rmtree(tmp_path) +@pytest.fixture(scope="session", autouse=True) +def set_training_steps() -> Generator[None, None, None]: + """Speed up tests by using TestTrainingParameters.""" + with pytest.MonkeyPatch.context() as monkeypatch: + monkeypatch.setattr( + "mlia.nn.select._get_rewrite_train_params", + MagicMock(return_value=TestTrainingParameters()), + ) + yield diff --git a/tests/test_nn_rewrite_core_graph_edit_join.py b/tests/test_nn_rewrite_core_graph_edit_join.py index cb3e4e2..0cb121e 100644 --- a/tests/test_nn_rewrite_core_graph_edit_join.py +++ b/tests/test_nn_rewrite_core_graph_edit_join.py @@ -10,7 +10,7 @@ import pytest from mlia.nn.rewrite.core.graph_edit.cut import cut_model from mlia.nn.rewrite.core.graph_edit.join import append_relabel from mlia.nn.rewrite.core.graph_edit.join import join_models -from mlia.nn.rewrite.core.utils.utils import load +from mlia.nn.tensorflow.tflite_graph import load_fb from tests.utils.rewrite import models_are_equal @@ -49,8 +49,8 @@ def test_join_model(test_tflite_model: Path, tmp_path: Path) -> None: ) assert joined_file.is_file() - orig_model = load(str(test_tflite_model)) - joined_model = load(str(joined_file)) + orig_model = load_fb(str(test_tflite_model)) + joined_model = load_fb(str(joined_file)) assert models_are_equal(orig_model, joined_model) diff --git a/tests/test_nn_rewrite_core_graph_edit_record.py b/tests/test_nn_rewrite_core_graph_edit_record.py index cd728af..41b9c50 100644 --- a/tests/test_nn_rewrite_core_graph_edit_record.py +++ b/tests/test_nn_rewrite_core_graph_edit_record.py @@ -3,15 +3,13 @@ """Tests for module mlia.nn.rewrite.graph_edit.record.""" from pathlib import Path -import pytest import tensorflow as tf from mlia.nn.rewrite.core.graph_edit.record import record_model from mlia.nn.rewrite.core.utils.numpy_tfrecord import numpytf_read -@pytest.mark.parametrize("batch_size", (None, 1, 2)) -def test_record_model( +def check_record_model( test_tflite_model: Path, tmp_path: Path, test_tfrecord: Path, diff --git a/tests/test_nn_rewrite_core_rewrite.py b/tests/test_nn_rewrite_core_rewrite.py index b98971e..2542db2 100644 --- a/tests/test_nn_rewrite_core_rewrite.py +++ b/tests/test_nn_rewrite_core_rewrite.py @@ -12,6 +12,7 @@ import pytest from mlia.nn.rewrite.core.rewrite import RewriteConfiguration from mlia.nn.rewrite.core.rewrite import Rewriter from mlia.nn.tensorflow.config import TFLiteModel +from tests.utils.rewrite import TestTrainingParameters @pytest.mark.parametrize( @@ -32,12 +33,14 @@ def test_rewrite_configuration( None, ) + assert config_obj.optimization_target in str(config_obj) + rewriter_obj = Rewriter(test_tflite_model_fp32, config_obj) assert rewriter_obj.optimizer_configuration.optimization_target == rewrite_name assert isinstance(rewriter_obj, Rewriter) -def test_rewriter( +def test_rewriting_optimizer( test_tflite_model_fp32: Path, test_tfrecord_fp32: Path, ) -> None: @@ -46,6 +49,7 @@ def test_rewriter( "fully_connected", ["sequential/flatten/Reshape", "StatefulPartitionedCall:0"], test_tfrecord_fp32, + train_params=TestTrainingParameters(), ) test_obj = Rewriter(test_tflite_model_fp32, config_obj) diff --git a/tests/test_nn_rewrite_core_train.py b/tests/test_nn_rewrite_core_train.py index 3c2ef3e..4493671 100644 --- a/tests/test_nn_rewrite_core_train.py +++ b/tests/test_nn_rewrite_core_train.py @@ -4,6 +4,7 @@ # pylint: disable=too-many-arguments from __future__ import annotations +from contextlib import ExitStack as does_not_raise from pathlib import Path from tempfile import TemporaryDirectory from typing import Any @@ -12,10 +13,13 @@ import numpy as np import pytest import tensorflow as tf -from mlia.nn.rewrite.core.train import augmentation_presets +from mlia.nn.rewrite.core.train import augment_fn_twins +from mlia.nn.rewrite.core.train import AUGMENTATION_PRESETS from mlia.nn.rewrite.core.train import LearningRateSchedule from mlia.nn.rewrite.core.train import mixup from mlia.nn.rewrite.core.train import train +from mlia.nn.rewrite.core.train import TrainingParameters +from tests.utils.rewrite import TestTrainingParameters def replace_fully_connected_with_conv( @@ -41,15 +45,8 @@ def replace_fully_connected_with_conv( def check_train( tflite_model: Path, tfrecord: Path, - batch_size: int = 1, - verbose: bool = False, - show_progress: bool = False, - augmentation_preset: tuple[float | None, float | None] = augmentation_presets[ - "none" - ], - lr_schedule: LearningRateSchedule = "cosine", + train_params: TrainingParameters = TestTrainingParameters(), use_unmodified_model: bool = False, - num_procs: int = 1, ) -> None: """Test the train() function.""" with TemporaryDirectory() as tmp_dir: @@ -62,14 +59,7 @@ def check_train( replace_fn=replace_fully_connected_with_conv, input_tensors=["sequential/flatten/Reshape"], output_tensors=["StatefulPartitionedCall:0"], - augment=augmentation_preset, - steps=32, - learning_rate=1e-3, - batch_size=batch_size, - verbose=verbose, - show_progress=show_progress, - learning_rate_schedule=lr_schedule, - num_procs=num_procs, + train_params=train_params, ) assert len(result) == 2 assert all(res >= 0.0 for res in result), f"Results out of bound: {result}" @@ -79,7 +69,6 @@ def check_train( @pytest.mark.parametrize( ( "batch_size", - "verbose", "show_progress", "augmentation_preset", "lr_schedule", @@ -87,14 +76,13 @@ def check_train( "num_procs", ), ( - (1, False, False, augmentation_presets["none"], "cosine", False, 2), - (32, True, True, augmentation_presets["gaussian"], "late", True, 1), - (2, False, False, augmentation_presets["mixup"], "constant", True, 0), + (1, False, AUGMENTATION_PRESETS["none"], "cosine", False, 2), + (32, True, AUGMENTATION_PRESETS["gaussian"], "late", True, 1), + (2, False, AUGMENTATION_PRESETS["mixup"], "constant", True, 0), ( 1, False, - False, - augmentation_presets["mix_gaussian_large"], + AUGMENTATION_PRESETS["mix_gaussian_large"], "cosine", False, 2, @@ -105,7 +93,6 @@ def test_train( test_tflite_model_fp32: Path, test_tfrecord_fp32: Path, batch_size: int, - verbose: bool, show_progress: bool, augmentation_preset: tuple[float | None, float | None], lr_schedule: LearningRateSchedule, @@ -116,13 +103,14 @@ def test_train( check_train( tflite_model=test_tflite_model_fp32, tfrecord=test_tfrecord_fp32, - batch_size=batch_size, - verbose=verbose, - show_progress=show_progress, - augmentation_preset=augmentation_preset, - lr_schedule=lr_schedule, + train_params=TestTrainingParameters( + batch_size=batch_size, + show_progress=show_progress, + augmentations=augmentation_preset, + learning_rate_schedule=lr_schedule, + num_procs=num_procs, + ), use_unmodified_model=use_unmodified_model, - num_procs=num_procs, ) @@ -131,11 +119,13 @@ def test_train_invalid_schedule( test_tfrecord_fp32: Path, ) -> None: """Test the train() function with an invalid schedule.""" - with pytest.raises(ValueError): + with pytest.raises(AssertionError): check_train( tflite_model=test_tflite_model_fp32, tfrecord=test_tfrecord_fp32, - lr_schedule="unknown_schedule", # type: ignore + train_params=TestTrainingParameters( + learning_rate_schedule="unknown_schedule", + ), ) @@ -144,11 +134,13 @@ def test_train_invalid_augmentation( test_tfrecord_fp32: Path, ) -> None: """Test the train() function with an invalid augmentation.""" - with pytest.raises(ValueError): + with pytest.raises(AssertionError): check_train( tflite_model=test_tflite_model_fp32, tfrecord=test_tfrecord_fp32, - augmentation_preset=(1.0, 2.0, 3.0), # type: ignore + train_params=TestTrainingParameters( + augmentations=(1.0, 2.0, 3.0), + ), ) @@ -159,3 +151,19 @@ def test_mixup() -> None: assert src.shape == dst.shape assert np.all(dst >= 0.0) assert np.all(dst <= 3.0) + + +@pytest.mark.parametrize( + "augmentations, expected_error", + [ + (AUGMENTATION_PRESETS["none"], does_not_raise()), + (AUGMENTATION_PRESETS["mix_gaussian_large"], does_not_raise()), + ((None,) * 3, pytest.raises(AssertionError)), + ], +) +def test_augment_fn_twins(augmentations: tuple, expected_error: Any) -> None: + """Test function augment_fn().""" + dataset = tf.data.Dataset.from_tensor_slices({"a": [1, 2, 3], "b": [4, 5, 6]}) + with expected_error: + fn_twins = augment_fn_twins(dataset, augmentations) # type: ignore + assert len(fn_twins) == 2 diff --git a/tests/test_nn_rewrite_core_utils.py b/tests/test_nn_rewrite_core_utils.py deleted file mode 100644 index d806a7b..0000000 --- a/tests/test_nn_rewrite_core_utils.py +++ /dev/null @@ -1,33 +0,0 @@ -# SPDX-FileCopyrightText: Copyright 2023, Arm Limited and/or its affiliates. -# SPDX-License-Identifier: Apache-2.0 -"""Tests for module mlia.nn.rewrite.utils.""" -from pathlib import Path - -import pytest -import tensorflow as tf -from tensorflow.lite.python.schema_py_generated import ModelT - -from mlia.nn.rewrite.core.utils.utils import load -from mlia.nn.rewrite.core.utils.utils import save -from tests.utils.rewrite import models_are_equal - - -def test_load_save(test_tflite_model: Path, tmp_path: Path) -> None: - """Test the load/save functions for TensorFlow Lite models.""" - with pytest.raises(FileNotFoundError): - load("THIS_IS_NOT_A_REAL_FILE") - - model = load(test_tflite_model) - assert isinstance(model, ModelT) - assert model.subgraphs - - output_file = tmp_path / "test.tflite" - assert not output_file.is_file() - save(model, output_file) - assert output_file.is_file() - - model_copy = load(str(output_file)) - assert models_are_equal(model, model_copy) - - # Double check that the TensorFlow Lite Interpreter can still load the file. - tf.lite.Interpreter(model_path=str(output_file)) diff --git a/tests/test_nn_rewrite_core_utils_numpy_tfrecord.py b/tests/test_nn_rewrite_core_utils_numpy_tfrecord.py index 7fc8048..d030350 100644 --- a/tests/test_nn_rewrite_core_utils_numpy_tfrecord.py +++ b/tests/test_nn_rewrite_core_utils_numpy_tfrecord.py @@ -5,6 +5,10 @@ from __future__ import annotations from pathlib import Path +import pytest +import tensorflow as tf + +from mlia.nn.rewrite.core.utils.numpy_tfrecord import make_decode_fn from mlia.nn.rewrite.core.utils.numpy_tfrecord import numpytf_count from mlia.nn.rewrite.core.utils.numpy_tfrecord import sample_tfrec @@ -16,3 +20,24 @@ def test_sample_tfrec(test_tfrecord: Path, tmp_path: Path) -> None: sample_tfrec(input_file=str(test_tfrecord), k=1, output_file=str(output_file)) assert output_file.is_file() assert numpytf_count(str(output_file)) == 1 + + +def test_make_decode_fn(test_tfrecord: Path) -> None: + """Test function make_decode_fn().""" + decode = make_decode_fn(str(test_tfrecord)) + dataset = tf.data.TFRecordDataset(str(test_tfrecord)) + features = decode(next(iter(dataset))) + assert isinstance(features, dict) + assert len(features) == 1 + key, val = next(iter(features.items())) + assert isinstance(key, str) + assert isinstance(val, tf.Tensor) + assert val.dtype == tf.int8 + + with pytest.raises(FileNotFoundError): + make_decode_fn(str(test_tfrecord) + "_") + + +def test_numpytf_count(test_tfrecord: Path) -> None: + """Test function numpytf_count().""" + assert numpytf_count(test_tfrecord) == 3 diff --git a/tests/test_nn_tensorflow_config.py b/tests/test_nn_tensorflow_config.py index 656619d..48aec0a 100644 --- a/tests/test_nn_tensorflow_config.py +++ b/tests/test_nn_tensorflow_config.py @@ -4,13 +4,28 @@ from contextlib import ExitStack as does_not_raise from pathlib import Path from typing import Any +from typing import Generator +import numpy as np import pytest +from mlia.nn.rewrite.core.utils.numpy_tfrecord import numpytf_read from mlia.nn.tensorflow.config import get_model from mlia.nn.tensorflow.config import KerasModel +from mlia.nn.tensorflow.config import ModelConfiguration from mlia.nn.tensorflow.config import TFLiteModel from mlia.nn.tensorflow.config import TfModel +from tests.conftest import create_tfrecord + + +def test_model_configuration(test_keras_model: Path) -> None: + """Test ModelConfiguration class.""" + model = ModelConfiguration(model_path=test_keras_model) + assert test_keras_model.match(model.model_path) + with pytest.raises(NotImplementedError): + model.convert_to_keras("keras_model.h5") + with pytest.raises(NotImplementedError): + model.convert_to_tflite("model.tflite") def test_convert_keras_to_tflite(tmp_path: Path, test_keras_model: Path) -> None: @@ -38,7 +53,7 @@ def test_convert_tf_to_tflite(tmp_path: Path, test_tf_model: Path) -> None: @pytest.mark.parametrize( "model_path, expected_type, expected_error", [ - ("test.tflite", TFLiteModel, does_not_raise()), + ("test.tflite", TFLiteModel, pytest.raises(ValueError)), ("test.h5", KerasModel, does_not_raise()), ("test.hdf5", KerasModel, does_not_raise()), ( @@ -73,3 +88,26 @@ def test_get_model_dir( """Test TensorFlow Lite model type.""" model = get_model(str(test_models_path / model_path)) assert isinstance(model, expected_type) + + +@pytest.fixture(scope="session", name="test_tfrecord_fp32_batch_3") +def fixture_test_tfrecord_fp32_batch_3( + tmp_path_factory: pytest.TempPathFactory, +) -> Generator[Path, None, None]: + """Create tfrecord (same as test_tfrecord_fp32) but with batch size 3.""" + + def random_data() -> np.ndarray: + return np.random.rand(3, 28, 28, 1).astype(np.float32) + + yield from create_tfrecord(tmp_path_factory, random_data) + + +def test_tflite_model_call( + test_tflite_model_fp32: Path, test_tfrecord_fp32_batch_3: Path +) -> None: + """Test inference function of class TFLiteModel.""" + model = TFLiteModel(test_tflite_model_fp32, batch_size=2) + data = numpytf_read(test_tfrecord_fp32_batch_3) + for named_input in data.as_numpy_iterator(): + res = model(named_input) + assert res diff --git a/tests/test_nn_tensorflow_tflite_graph.py b/tests/test_nn_tensorflow_tflite_graph.py index cd1fad6..3512cdd 100644 --- a/tests/test_nn_tensorflow_tflite_graph.py +++ b/tests/test_nn_tensorflow_tflite_graph.py @@ -1,15 +1,22 @@ -# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """Tests for the tflite_graph module.""" import json from pathlib import Path +import pytest +import tensorflow as tf +from tensorflow.lite.python.schema_py_generated import ModelT + +from mlia.nn.tensorflow.tflite_graph import load_fb from mlia.nn.tensorflow.tflite_graph import Op from mlia.nn.tensorflow.tflite_graph import parse_subgraphs +from mlia.nn.tensorflow.tflite_graph import save_fb from mlia.nn.tensorflow.tflite_graph import TensorInfo from mlia.nn.tensorflow.tflite_graph import TFL_ACTIVATION_FUNCTION from mlia.nn.tensorflow.tflite_graph import TFL_OP from mlia.nn.tensorflow.tflite_graph import TFL_TYPE +from tests.utils.rewrite import models_are_equal def test_tensor_info() -> None: @@ -79,3 +86,24 @@ def test_parse_subgraphs(test_tflite_model: Path) -> None: assert TFL_OP[oper.type] in TFL_OP assert len(oper.inputs) > 0 assert len(oper.outputs) > 0 + + +def test_load_save(test_tflite_model: Path, tmp_path: Path) -> None: + """Test the load/save functions for TensorFlow Lite models.""" + with pytest.raises(FileNotFoundError): + load_fb("THIS_IS_NOT_A_REAL_FILE") + + model = load_fb(test_tflite_model) + assert isinstance(model, ModelT) + assert model.subgraphs + + output_file = tmp_path / "test.tflite" + assert not output_file.is_file() + save_fb(model, output_file) + assert output_file.is_file() + + model_copy = load_fb(str(output_file)) + assert models_are_equal(model, model_copy) + + # Double check that the TensorFlow Lite Interpreter can still load the file. + tf.lite.Interpreter(model_path=str(output_file)) diff --git a/tests/utils/rewrite.py b/tests/utils/rewrite.py index 4264b4b..739bb11 100644 --- a/tests/utils/rewrite.py +++ b/tests/utils/rewrite.py @@ -3,8 +3,12 @@ """Common test utils for the rewrite tests.""" from __future__ import annotations +from typing import Any + from tensorflow.lite.python.schema_py_generated import ModelT +from mlia.nn.rewrite.core.train import TrainingParameters + def models_are_equal(model1: ModelT, model2: ModelT) -> bool: """Check that the two models are equal.""" @@ -25,3 +29,17 @@ def models_are_equal(model1: ModelT, model2: ModelT) -> bool: return False # Tensor from graph1 not found in other graph.") return True + + +class TestTrainingParameters( + TrainingParameters +): # pylint: disable=too-few-public-methods + """ + TrainingParameter class for rewrites with different default values. + + To speed things up for the unit tests. + """ + + def __init__(self, *args: Any, steps: int = 32, **kwargs: Any) -> None: + """Initialize TrainingParameters with different defaults.""" + super().__init__(*args, steps=steps, **kwargs) # type: ignore |