diff options
author | Nathan Bailey <nathan.bailey@arm.com> | 2024-03-12 13:48:20 +0000 |
---|---|---|
committer | Nathan Bailey <nathan.bailey@arm.com> | 2024-03-28 07:18:16 +0000 |
commit | 879138ebe9dd4b9387bddbdf4af45ba390f14bf6 (patch) | |
tree | e402c948e9ba63b28f3e1f4c7e5f9cebd903f9e0 /src/mlia/nn/rewrite/core/rewrite.py | |
parent | f3f3ab451968350b8f6df2de7c60b2c2b9320b59 (diff) | |
download | mlia-879138ebe9dd4b9387bddbdf4af45ba390f14bf6.tar.gz |
fix: Check that training checkpoint feature works as expected
Fixes the checkpoint feature in training and also completes unit tests for it
Resolves: MLIA-1111
Signed-off-by: Nathan Bailey <nathan.bailey@arm.com>
Change-Id: Ic2b84b4b045db5ba3cb299fcd137ae9d31df5298
Diffstat (limited to 'src/mlia/nn/rewrite/core/rewrite.py')
-rw-r--r-- | src/mlia/nn/rewrite/core/rewrite.py | 24 |
1 files changed, 19 insertions, 5 deletions
diff --git a/src/mlia/nn/rewrite/core/rewrite.py b/src/mlia/nn/rewrite/core/rewrite.py index c7d13ba..f5f5561 100644 --- a/src/mlia/nn/rewrite/core/rewrite.py +++ b/src/mlia/nn/rewrite/core/rewrite.py @@ -150,6 +150,7 @@ class RewritingOptimizer(Optimizer): "Input and output tensor names need to be set for rewrite." ) + self.optimizer_configuration.train_params.checkpoint_at = [5000, 10000] orig_vs_repl_stats, total_stats = train( source_model=tflite_model, unmodified_model=tflite_model if use_unmodified_model else None, @@ -162,9 +163,22 @@ class RewritingOptimizer(Optimizer): ) if orig_vs_repl_stats: - orig_vs_repl = ["Replaced sub-graph only"] + [ - f"{stat:.3f}" for stat in orig_vs_repl_stats - ] + model_stats: list = [] + cp_param = self.optimizer_configuration.train_params.checkpoint_at + checkpoints = ( + [ + "At checkpoint " + str(checkpoint) + " steps" + for checkpoint in cp_param + ] + if cp_param + else [] + ) + checkpoints.append("All Steps") + for checkpoint, orig_vs_repl_stat in zip(checkpoints, orig_vs_repl_stats): + model_stats.append( + ["Replaced sub-graph: " + checkpoint] + + [f"{stat:.3f}" for stat in orig_vs_repl_stat] + ) total = ["Total model"] + [f"{stat:.3f}" for stat in total_stats] notes = ( "These metrics show the difference between original model\n" @@ -178,14 +192,14 @@ class RewritingOptimizer(Optimizer): table = Table( columns=[ Column( - "Original vs. optimized", + "Original vs. Optimized", alias="metric", fmt=Format(wrap_width=40), ), Column("MAE", alias="value", fmt=Format(wrap_width=15)), Column("NRMSE", alias="value", fmt=Format(wrap_width=15)), ], - rows=[orig_vs_repl, total], + rows=[*model_stats, total], name="Rewrite performance metrics", alias="rewrite_performance_metrics", notes=notes, |