diff options
Diffstat (limited to 'tests/test_nn_rewrite_core_train.py')
-rw-r--r-- | tests/test_nn_rewrite_core_train.py | 26 |
1 files changed, 23 insertions, 3 deletions
diff --git a/tests/test_nn_rewrite_core_train.py b/tests/test_nn_rewrite_core_train.py index 6d24133..94c99ff 100644 --- a/tests/test_nn_rewrite_core_train.py +++ b/tests/test_nn_rewrite_core_train.py @@ -20,6 +20,7 @@ from mlia.nn.rewrite.core.train import LearningRateSchedule from mlia.nn.rewrite.core.train import mixup from mlia.nn.rewrite.core.train import train from mlia.nn.rewrite.core.train import TrainingParameters +from tests.test_nn_rewrite_core_rewrite import TestRewrite from tests.utils.rewrite import MockTrainingParameters @@ -53,18 +54,23 @@ def check_train( """Test the train() function.""" with TemporaryDirectory() as tmp_dir: output_file = Path(tmp_dir, "out.tflite") + mock_rewrite = TestRewrite("replace", replace_fully_connected_with_conv) result = train( source_model=str(tflite_model), unmodified_model=str(tflite_model) if use_unmodified_model else None, output_model=str(output_file), input_tfrec=str(tfrecord), - replace_fn=replace_fully_connected_with_conv, + rewrite=mock_rewrite, + is_qat=False, input_tensors=["sequential/flatten/Reshape"], output_tensors=["StatefulPartitionedCall:0"], train_params=train_params, ) - assert len(result) == 2 - assert all(res >= 0.0 for res in result[0]), f"Results out of bound: {result}" + + assert len(result[0][0]) == 2 + assert all( + res >= 0.0 for res in result[0][0] + ), f"Results out of bound: {result}" assert output_file.is_file() if quantized: @@ -229,3 +235,17 @@ def test_augment_fn_twins(augmentations: tuple, expected_error: Any) -> None: with expected_error: fn_twins = augment_fn_twins(dataset, augmentations) # type: ignore assert len(fn_twins) == 2 + + +def test_train_checkpoint( + test_tflite_model: Path, + test_tfrecord: Path, +) -> None: + """Test the train() function with valid checkpoint parameters.""" + check_train( + tflite_model=test_tflite_model, + tfrecord=test_tfrecord, + train_params=MockTrainingParameters(steps=64, checkpoint_at=[24, 32]), + use_unmodified_model=False, + quantized=True, + ) |