aboutsummaryrefslogtreecommitdiff
path: root/tests/test_nn_rewrite_core_train.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/test_nn_rewrite_core_train.py')
-rw-r--r--tests/test_nn_rewrite_core_train.py21
1 files changed, 19 insertions, 2 deletions
diff --git a/tests/test_nn_rewrite_core_train.py b/tests/test_nn_rewrite_core_train.py
index 6d24133..624c5ed 100644
--- a/tests/test_nn_rewrite_core_train.py
+++ b/tests/test_nn_rewrite_core_train.py
@@ -63,8 +63,11 @@ def check_train(
output_tensors=["StatefulPartitionedCall:0"],
train_params=train_params,
)
- assert len(result) == 2
- assert all(res >= 0.0 for res in result[0]), f"Results out of bound: {result}"
+
+ assert len(result[0][0]) == 2
+ assert all(
+ res >= 0.0 for res in result[0][0]
+ ), f"Results out of bound: {result}"
assert output_file.is_file()
if quantized:
@@ -229,3 +232,17 @@ def test_augment_fn_twins(augmentations: tuple, expected_error: Any) -> None:
with expected_error:
fn_twins = augment_fn_twins(dataset, augmentations) # type: ignore
assert len(fn_twins) == 2
+
+
+def test_train_checkpoint(
+ test_tflite_model: Path,
+ test_tfrecord: Path,
+) -> None:
+ """Test the train() function with valid checkpoint parameters."""
+ check_train(
+ tflite_model=test_tflite_model,
+ tfrecord=test_tfrecord,
+ train_params=MockTrainingParameters(steps=64, checkpoint_at=[24, 32]),
+ use_unmodified_model=False,
+ quantized=True,
+ )