diff options
author | Nathan Bailey <nathan.bailey@arm.com> | 2024-03-12 13:48:20 +0000 |
---|---|---|
committer | Nathan Bailey <nathan.bailey@arm.com> | 2024-03-28 07:18:16 +0000 |
commit | 879138ebe9dd4b9387bddbdf4af45ba390f14bf6 (patch) | |
tree | e402c948e9ba63b28f3e1f4c7e5f9cebd903f9e0 /tests/test_nn_rewrite_core_train.py | |
parent | f3f3ab451968350b8f6df2de7c60b2c2b9320b59 (diff) | |
download | mlia-879138ebe9dd4b9387bddbdf4af45ba390f14bf6.tar.gz |
fix: Check that training checkpoint feature works as expected
Fixes the checkpoint feature in training and also completes unit tests for it
Resolves: MLIA-1111
Signed-off-by: Nathan Bailey <nathan.bailey@arm.com>
Change-Id: Ic2b84b4b045db5ba3cb299fcd137ae9d31df5298
Diffstat (limited to 'tests/test_nn_rewrite_core_train.py')
-rw-r--r-- | tests/test_nn_rewrite_core_train.py | 21 |
1 files changed, 19 insertions, 2 deletions
diff --git a/tests/test_nn_rewrite_core_train.py b/tests/test_nn_rewrite_core_train.py index 6d24133..624c5ed 100644 --- a/tests/test_nn_rewrite_core_train.py +++ b/tests/test_nn_rewrite_core_train.py @@ -63,8 +63,11 @@ def check_train( output_tensors=["StatefulPartitionedCall:0"], train_params=train_params, ) - assert len(result) == 2 - assert all(res >= 0.0 for res in result[0]), f"Results out of bound: {result}" + + assert len(result[0][0]) == 2 + assert all( + res >= 0.0 for res in result[0][0] + ), f"Results out of bound: {result}" assert output_file.is_file() if quantized: @@ -229,3 +232,17 @@ def test_augment_fn_twins(augmentations: tuple, expected_error: Any) -> None: with expected_error: fn_twins = augment_fn_twins(dataset, augmentations) # type: ignore assert len(fn_twins) == 2 + + +def test_train_checkpoint( + test_tflite_model: Path, + test_tfrecord: Path, +) -> None: + """Test the train() function with valid checkpoint parameters.""" + check_train( + tflite_model=test_tflite_model, + tfrecord=test_tfrecord, + train_params=MockTrainingParameters(steps=64, checkpoint_at=[24, 32]), + use_unmodified_model=False, + quantized=True, + ) |