diff options
Diffstat (limited to 'tests/test_nn_rewrite_core_train.py')
-rw-r--r-- | tests/test_nn_rewrite_core_train.py | 21 |
1 files changed, 19 insertions, 2 deletions
diff --git a/tests/test_nn_rewrite_core_train.py b/tests/test_nn_rewrite_core_train.py index 6d24133..624c5ed 100644 --- a/tests/test_nn_rewrite_core_train.py +++ b/tests/test_nn_rewrite_core_train.py @@ -63,8 +63,11 @@ def check_train( output_tensors=["StatefulPartitionedCall:0"], train_params=train_params, ) - assert len(result) == 2 - assert all(res >= 0.0 for res in result[0]), f"Results out of bound: {result}" + + assert len(result[0][0]) == 2 + assert all( + res >= 0.0 for res in result[0][0] + ), f"Results out of bound: {result}" assert output_file.is_file() if quantized: @@ -229,3 +232,17 @@ def test_augment_fn_twins(augmentations: tuple, expected_error: Any) -> None: with expected_error: fn_twins = augment_fn_twins(dataset, augmentations) # type: ignore assert len(fn_twins) == 2 + + +def test_train_checkpoint( + test_tflite_model: Path, + test_tfrecord: Path, +) -> None: + """Test the train() function with valid checkpoint parameters.""" + check_train( + tflite_model=test_tflite_model, + tfrecord=test_tfrecord, + train_params=MockTrainingParameters(steps=64, checkpoint_at=[24, 32]), + use_unmodified_model=False, + quantized=True, + ) |