From 879138ebe9dd4b9387bddbdf4af45ba390f14bf6 Mon Sep 17 00:00:00 2001 From: Nathan Bailey Date: Tue, 12 Mar 2024 13:48:20 +0000 Subject: fix: Check that training checkpoint feature works as expected Fixes the checkpoint feature in training and also completes unit tests for it Resolves: MLIA-1111 Signed-off-by: Nathan Bailey Change-Id: Ic2b84b4b045db5ba3cb299fcd137ae9d31df5298 --- src/mlia/nn/rewrite/core/rewrite.py | 24 +++++++++++++++++++----- src/mlia/nn/rewrite/core/train.py | 12 ++++++++---- tests/test_nn_rewrite_core_train.py | 21 +++++++++++++++++++-- 3 files changed, 46 insertions(+), 11 deletions(-) diff --git a/src/mlia/nn/rewrite/core/rewrite.py b/src/mlia/nn/rewrite/core/rewrite.py index c7d13ba..f5f5561 100644 --- a/src/mlia/nn/rewrite/core/rewrite.py +++ b/src/mlia/nn/rewrite/core/rewrite.py @@ -150,6 +150,7 @@ class RewritingOptimizer(Optimizer): "Input and output tensor names need to be set for rewrite." ) + self.optimizer_configuration.train_params.checkpoint_at = [5000, 10000] orig_vs_repl_stats, total_stats = train( source_model=tflite_model, unmodified_model=tflite_model if use_unmodified_model else None, @@ -162,9 +163,22 @@ class RewritingOptimizer(Optimizer): ) if orig_vs_repl_stats: - orig_vs_repl = ["Replaced sub-graph only"] + [ - f"{stat:.3f}" for stat in orig_vs_repl_stats - ] + model_stats: list = [] + cp_param = self.optimizer_configuration.train_params.checkpoint_at + checkpoints = ( + [ + "At checkpoint " + str(checkpoint) + " steps" + for checkpoint in cp_param + ] + if cp_param + else [] + ) + checkpoints.append("All Steps") + for checkpoint, orig_vs_repl_stat in zip(checkpoints, orig_vs_repl_stats): + model_stats.append( + ["Replaced sub-graph: " + checkpoint] + + [f"{stat:.3f}" for stat in orig_vs_repl_stat] + ) total = ["Total model"] + [f"{stat:.3f}" for stat in total_stats] notes = ( "These metrics show the difference between original model\n" @@ -178,14 +192,14 @@ class RewritingOptimizer(Optimizer): table = Table( columns=[ Column( - "Original vs. optimized", + "Original vs. Optimized", alias="metric", fmt=Format(wrap_width=40), ), Column("MAE", alias="value", fmt=Format(wrap_width=15)), Column("NRMSE", alias="value", fmt=Format(wrap_width=15)), ], - rows=[orig_vs_repl, total], + rows=[*model_stats, total], name="Rewrite performance metrics", alias="rewrite_performance_metrics", notes=notes, diff --git a/src/mlia/nn/rewrite/core/train.py b/src/mlia/nn/rewrite/core/train.py index 60c39ae..e0b3c75 100644 --- a/src/mlia/nn/rewrite/core/train.py +++ b/src/mlia/nn/rewrite/core/train.py @@ -159,10 +159,10 @@ def train( if unmodified_model_dir: cast(tempfile.TemporaryDirectory, unmodified_model_dir).cleanup() - return (results if train_params.checkpoint_at else results[0]), [ + return results, [ mae, nrmse, - ] # only return a list if multiple checkpoints are asked for + ] def eval_in_dir( @@ -452,8 +452,12 @@ def train_in_dir( ) if steps_so_far < train_params.steps: - filename, ext = Path(output_filename).parts[1:] - checkpoint_filename = filename + (f"_@{steps_so_far}") + ext + filename = Path(output_filename).stem + filename_dir = Path(output_filename).parent.as_posix() + ext = Path(output_filename).suffix + checkpoint_filename = ( + filename_dir + "/" + filename + (f"_@{steps_so_far}") + ext + ) else: checkpoint_filename = str(output_filename) with log_action( diff --git a/tests/test_nn_rewrite_core_train.py b/tests/test_nn_rewrite_core_train.py index 6d24133..624c5ed 100644 --- a/tests/test_nn_rewrite_core_train.py +++ b/tests/test_nn_rewrite_core_train.py @@ -63,8 +63,11 @@ def check_train( output_tensors=["StatefulPartitionedCall:0"], train_params=train_params, ) - assert len(result) == 2 - assert all(res >= 0.0 for res in result[0]), f"Results out of bound: {result}" + + assert len(result[0][0]) == 2 + assert all( + res >= 0.0 for res in result[0][0] + ), f"Results out of bound: {result}" assert output_file.is_file() if quantized: @@ -229,3 +232,17 @@ def test_augment_fn_twins(augmentations: tuple, expected_error: Any) -> None: with expected_error: fn_twins = augment_fn_twins(dataset, augmentations) # type: ignore assert len(fn_twins) == 2 + + +def test_train_checkpoint( + test_tflite_model: Path, + test_tfrecord: Path, +) -> None: + """Test the train() function with valid checkpoint parameters.""" + check_train( + tflite_model=test_tflite_model, + tfrecord=test_tfrecord, + train_params=MockTrainingParameters(steps=64, checkpoint_at=[24, 32]), + use_unmodified_model=False, + quantized=True, + ) -- cgit v1.2.1