aboutsummaryrefslogtreecommitdiff
path: root/tests/test_nn_rewrite_core_train.py
diff options
context:
space:
mode:
authorNathan Bailey <nathan.bailey@arm.com>2024-03-12 13:48:20 +0000
committerNathan Bailey <nathan.bailey@arm.com>2024-03-28 07:18:16 +0000
commit879138ebe9dd4b9387bddbdf4af45ba390f14bf6 (patch)
treee402c948e9ba63b28f3e1f4c7e5f9cebd903f9e0 /tests/test_nn_rewrite_core_train.py
parentf3f3ab451968350b8f6df2de7c60b2c2b9320b59 (diff)
downloadmlia-879138ebe9dd4b9387bddbdf4af45ba390f14bf6.tar.gz
fix: Check that training checkpoint feature works as expected
Fixes the checkpoint feature in training and also completes unit tests for it Resolves: MLIA-1111 Signed-off-by: Nathan Bailey <nathan.bailey@arm.com> Change-Id: Ic2b84b4b045db5ba3cb299fcd137ae9d31df5298
Diffstat (limited to 'tests/test_nn_rewrite_core_train.py')
-rw-r--r--tests/test_nn_rewrite_core_train.py21
1 files changed, 19 insertions, 2 deletions
diff --git a/tests/test_nn_rewrite_core_train.py b/tests/test_nn_rewrite_core_train.py
index 6d24133..624c5ed 100644
--- a/tests/test_nn_rewrite_core_train.py
+++ b/tests/test_nn_rewrite_core_train.py
@@ -63,8 +63,11 @@ def check_train(
output_tensors=["StatefulPartitionedCall:0"],
train_params=train_params,
)
- assert len(result) == 2
- assert all(res >= 0.0 for res in result[0]), f"Results out of bound: {result}"
+
+ assert len(result[0][0]) == 2
+ assert all(
+ res >= 0.0 for res in result[0][0]
+ ), f"Results out of bound: {result}"
assert output_file.is_file()
if quantized:
@@ -229,3 +232,17 @@ def test_augment_fn_twins(augmentations: tuple, expected_error: Any) -> None:
with expected_error:
fn_twins = augment_fn_twins(dataset, augmentations) # type: ignore
assert len(fn_twins) == 2
+
+
+def test_train_checkpoint(
+ test_tflite_model: Path,
+ test_tfrecord: Path,
+) -> None:
+ """Test the train() function with valid checkpoint parameters."""
+ check_train(
+ tflite_model=test_tflite_model,
+ tfrecord=test_tfrecord,
+ train_params=MockTrainingParameters(steps=64, checkpoint_at=[24, 32]),
+ use_unmodified_model=False,
+ quantized=True,
+ )