diff options
Diffstat (limited to 'tests/test_nn_rewrite_core_train.py')
-rw-r--r-- | tests/test_nn_rewrite_core_train.py | 76 |
1 files changed, 42 insertions, 34 deletions
diff --git a/tests/test_nn_rewrite_core_train.py b/tests/test_nn_rewrite_core_train.py index 3c2ef3e..4493671 100644 --- a/tests/test_nn_rewrite_core_train.py +++ b/tests/test_nn_rewrite_core_train.py @@ -4,6 +4,7 @@ # pylint: disable=too-many-arguments from __future__ import annotations +from contextlib import ExitStack as does_not_raise from pathlib import Path from tempfile import TemporaryDirectory from typing import Any @@ -12,10 +13,13 @@ import numpy as np import pytest import tensorflow as tf -from mlia.nn.rewrite.core.train import augmentation_presets +from mlia.nn.rewrite.core.train import augment_fn_twins +from mlia.nn.rewrite.core.train import AUGMENTATION_PRESETS from mlia.nn.rewrite.core.train import LearningRateSchedule from mlia.nn.rewrite.core.train import mixup from mlia.nn.rewrite.core.train import train +from mlia.nn.rewrite.core.train import TrainingParameters +from tests.utils.rewrite import TestTrainingParameters def replace_fully_connected_with_conv( @@ -41,15 +45,8 @@ def replace_fully_connected_with_conv( def check_train( tflite_model: Path, tfrecord: Path, - batch_size: int = 1, - verbose: bool = False, - show_progress: bool = False, - augmentation_preset: tuple[float | None, float | None] = augmentation_presets[ - "none" - ], - lr_schedule: LearningRateSchedule = "cosine", + train_params: TrainingParameters = TestTrainingParameters(), use_unmodified_model: bool = False, - num_procs: int = 1, ) -> None: """Test the train() function.""" with TemporaryDirectory() as tmp_dir: @@ -62,14 +59,7 @@ def check_train( replace_fn=replace_fully_connected_with_conv, input_tensors=["sequential/flatten/Reshape"], output_tensors=["StatefulPartitionedCall:0"], - augment=augmentation_preset, - steps=32, - learning_rate=1e-3, - batch_size=batch_size, - verbose=verbose, - show_progress=show_progress, - learning_rate_schedule=lr_schedule, - num_procs=num_procs, + train_params=train_params, ) assert len(result) == 2 assert all(res >= 0.0 for res in result), f"Results out of bound: {result}" @@ -79,7 +69,6 @@ def check_train( @pytest.mark.parametrize( ( "batch_size", - "verbose", "show_progress", "augmentation_preset", "lr_schedule", @@ -87,14 +76,13 @@ def check_train( "num_procs", ), ( - (1, False, False, augmentation_presets["none"], "cosine", False, 2), - (32, True, True, augmentation_presets["gaussian"], "late", True, 1), - (2, False, False, augmentation_presets["mixup"], "constant", True, 0), + (1, False, AUGMENTATION_PRESETS["none"], "cosine", False, 2), + (32, True, AUGMENTATION_PRESETS["gaussian"], "late", True, 1), + (2, False, AUGMENTATION_PRESETS["mixup"], "constant", True, 0), ( 1, False, - False, - augmentation_presets["mix_gaussian_large"], + AUGMENTATION_PRESETS["mix_gaussian_large"], "cosine", False, 2, @@ -105,7 +93,6 @@ def test_train( test_tflite_model_fp32: Path, test_tfrecord_fp32: Path, batch_size: int, - verbose: bool, show_progress: bool, augmentation_preset: tuple[float | None, float | None], lr_schedule: LearningRateSchedule, @@ -116,13 +103,14 @@ def test_train( check_train( tflite_model=test_tflite_model_fp32, tfrecord=test_tfrecord_fp32, - batch_size=batch_size, - verbose=verbose, - show_progress=show_progress, - augmentation_preset=augmentation_preset, - lr_schedule=lr_schedule, + train_params=TestTrainingParameters( + batch_size=batch_size, + show_progress=show_progress, + augmentations=augmentation_preset, + learning_rate_schedule=lr_schedule, + num_procs=num_procs, + ), use_unmodified_model=use_unmodified_model, - num_procs=num_procs, ) @@ -131,11 +119,13 @@ def test_train_invalid_schedule( test_tfrecord_fp32: Path, ) -> None: """Test the train() function with an invalid schedule.""" - with pytest.raises(ValueError): + with pytest.raises(AssertionError): check_train( tflite_model=test_tflite_model_fp32, tfrecord=test_tfrecord_fp32, - lr_schedule="unknown_schedule", # type: ignore + train_params=TestTrainingParameters( + learning_rate_schedule="unknown_schedule", + ), ) @@ -144,11 +134,13 @@ def test_train_invalid_augmentation( test_tfrecord_fp32: Path, ) -> None: """Test the train() function with an invalid augmentation.""" - with pytest.raises(ValueError): + with pytest.raises(AssertionError): check_train( tflite_model=test_tflite_model_fp32, tfrecord=test_tfrecord_fp32, - augmentation_preset=(1.0, 2.0, 3.0), # type: ignore + train_params=TestTrainingParameters( + augmentations=(1.0, 2.0, 3.0), + ), ) @@ -159,3 +151,19 @@ def test_mixup() -> None: assert src.shape == dst.shape assert np.all(dst >= 0.0) assert np.all(dst <= 3.0) + + +@pytest.mark.parametrize( + "augmentations, expected_error", + [ + (AUGMENTATION_PRESETS["none"], does_not_raise()), + (AUGMENTATION_PRESETS["mix_gaussian_large"], does_not_raise()), + ((None,) * 3, pytest.raises(AssertionError)), + ], +) +def test_augment_fn_twins(augmentations: tuple, expected_error: Any) -> None: + """Test function augment_fn().""" + dataset = tf.data.Dataset.from_tensor_slices({"a": [1, 2, 3], "b": [4, 5, 6]}) + with expected_error: + fn_twins = augment_fn_twins(dataset, augmentations) # type: ignore + assert len(fn_twins) == 2 |