aboutsummaryrefslogtreecommitdiff
path: root/tests/test_nn_rewrite_core_train.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/test_nn_rewrite_core_train.py')
-rw-r--r--tests/test_nn_rewrite_core_train.py76
1 files changed, 42 insertions, 34 deletions
diff --git a/tests/test_nn_rewrite_core_train.py b/tests/test_nn_rewrite_core_train.py
index 3c2ef3e..4493671 100644
--- a/tests/test_nn_rewrite_core_train.py
+++ b/tests/test_nn_rewrite_core_train.py
@@ -4,6 +4,7 @@
# pylint: disable=too-many-arguments
from __future__ import annotations
+from contextlib import ExitStack as does_not_raise
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import Any
@@ -12,10 +13,13 @@ import numpy as np
import pytest
import tensorflow as tf
-from mlia.nn.rewrite.core.train import augmentation_presets
+from mlia.nn.rewrite.core.train import augment_fn_twins
+from mlia.nn.rewrite.core.train import AUGMENTATION_PRESETS
from mlia.nn.rewrite.core.train import LearningRateSchedule
from mlia.nn.rewrite.core.train import mixup
from mlia.nn.rewrite.core.train import train
+from mlia.nn.rewrite.core.train import TrainingParameters
+from tests.utils.rewrite import TestTrainingParameters
def replace_fully_connected_with_conv(
@@ -41,15 +45,8 @@ def replace_fully_connected_with_conv(
def check_train(
tflite_model: Path,
tfrecord: Path,
- batch_size: int = 1,
- verbose: bool = False,
- show_progress: bool = False,
- augmentation_preset: tuple[float | None, float | None] = augmentation_presets[
- "none"
- ],
- lr_schedule: LearningRateSchedule = "cosine",
+ train_params: TrainingParameters = TestTrainingParameters(),
use_unmodified_model: bool = False,
- num_procs: int = 1,
) -> None:
"""Test the train() function."""
with TemporaryDirectory() as tmp_dir:
@@ -62,14 +59,7 @@ def check_train(
replace_fn=replace_fully_connected_with_conv,
input_tensors=["sequential/flatten/Reshape"],
output_tensors=["StatefulPartitionedCall:0"],
- augment=augmentation_preset,
- steps=32,
- learning_rate=1e-3,
- batch_size=batch_size,
- verbose=verbose,
- show_progress=show_progress,
- learning_rate_schedule=lr_schedule,
- num_procs=num_procs,
+ train_params=train_params,
)
assert len(result) == 2
assert all(res >= 0.0 for res in result), f"Results out of bound: {result}"
@@ -79,7 +69,6 @@ def check_train(
@pytest.mark.parametrize(
(
"batch_size",
- "verbose",
"show_progress",
"augmentation_preset",
"lr_schedule",
@@ -87,14 +76,13 @@ def check_train(
"num_procs",
),
(
- (1, False, False, augmentation_presets["none"], "cosine", False, 2),
- (32, True, True, augmentation_presets["gaussian"], "late", True, 1),
- (2, False, False, augmentation_presets["mixup"], "constant", True, 0),
+ (1, False, AUGMENTATION_PRESETS["none"], "cosine", False, 2),
+ (32, True, AUGMENTATION_PRESETS["gaussian"], "late", True, 1),
+ (2, False, AUGMENTATION_PRESETS["mixup"], "constant", True, 0),
(
1,
False,
- False,
- augmentation_presets["mix_gaussian_large"],
+ AUGMENTATION_PRESETS["mix_gaussian_large"],
"cosine",
False,
2,
@@ -105,7 +93,6 @@ def test_train(
test_tflite_model_fp32: Path,
test_tfrecord_fp32: Path,
batch_size: int,
- verbose: bool,
show_progress: bool,
augmentation_preset: tuple[float | None, float | None],
lr_schedule: LearningRateSchedule,
@@ -116,13 +103,14 @@ def test_train(
check_train(
tflite_model=test_tflite_model_fp32,
tfrecord=test_tfrecord_fp32,
- batch_size=batch_size,
- verbose=verbose,
- show_progress=show_progress,
- augmentation_preset=augmentation_preset,
- lr_schedule=lr_schedule,
+ train_params=TestTrainingParameters(
+ batch_size=batch_size,
+ show_progress=show_progress,
+ augmentations=augmentation_preset,
+ learning_rate_schedule=lr_schedule,
+ num_procs=num_procs,
+ ),
use_unmodified_model=use_unmodified_model,
- num_procs=num_procs,
)
@@ -131,11 +119,13 @@ def test_train_invalid_schedule(
test_tfrecord_fp32: Path,
) -> None:
"""Test the train() function with an invalid schedule."""
- with pytest.raises(ValueError):
+ with pytest.raises(AssertionError):
check_train(
tflite_model=test_tflite_model_fp32,
tfrecord=test_tfrecord_fp32,
- lr_schedule="unknown_schedule", # type: ignore
+ train_params=TestTrainingParameters(
+ learning_rate_schedule="unknown_schedule",
+ ),
)
@@ -144,11 +134,13 @@ def test_train_invalid_augmentation(
test_tfrecord_fp32: Path,
) -> None:
"""Test the train() function with an invalid augmentation."""
- with pytest.raises(ValueError):
+ with pytest.raises(AssertionError):
check_train(
tflite_model=test_tflite_model_fp32,
tfrecord=test_tfrecord_fp32,
- augmentation_preset=(1.0, 2.0, 3.0), # type: ignore
+ train_params=TestTrainingParameters(
+ augmentations=(1.0, 2.0, 3.0),
+ ),
)
@@ -159,3 +151,19 @@ def test_mixup() -> None:
assert src.shape == dst.shape
assert np.all(dst >= 0.0)
assert np.all(dst <= 3.0)
+
+
+@pytest.mark.parametrize(
+ "augmentations, expected_error",
+ [
+ (AUGMENTATION_PRESETS["none"], does_not_raise()),
+ (AUGMENTATION_PRESETS["mix_gaussian_large"], does_not_raise()),
+ ((None,) * 3, pytest.raises(AssertionError)),
+ ],
+)
+def test_augment_fn_twins(augmentations: tuple, expected_error: Any) -> None:
+ """Test function augment_fn()."""
+ dataset = tf.data.Dataset.from_tensor_slices({"a": [1, 2, 3], "b": [4, 5, 6]})
+ with expected_error:
+ fn_twins = augment_fn_twins(dataset, augmentations) # type: ignore
+ assert len(fn_twins) == 2