aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorNathan Bailey <nathan.bailey@arm.com>2024-03-12 13:48:20 +0000
committerNathan Bailey <nathan.bailey@arm.com>2024-03-28 07:18:16 +0000
commit879138ebe9dd4b9387bddbdf4af45ba390f14bf6 (patch)
treee402c948e9ba63b28f3e1f4c7e5f9cebd903f9e0
parentf3f3ab451968350b8f6df2de7c60b2c2b9320b59 (diff)
downloadmlia-main.tar.gz
fix: Check that training checkpoint feature works as expectedHEADmain
Fixes the checkpoint feature in training and also completes unit tests for it Resolves: MLIA-1111 Signed-off-by: Nathan Bailey <nathan.bailey@arm.com> Change-Id: Ic2b84b4b045db5ba3cb299fcd137ae9d31df5298
-rw-r--r--src/mlia/nn/rewrite/core/rewrite.py24
-rw-r--r--src/mlia/nn/rewrite/core/train.py12
-rw-r--r--tests/test_nn_rewrite_core_train.py21
3 files changed, 46 insertions, 11 deletions
diff --git a/src/mlia/nn/rewrite/core/rewrite.py b/src/mlia/nn/rewrite/core/rewrite.py
index c7d13ba..f5f5561 100644
--- a/src/mlia/nn/rewrite/core/rewrite.py
+++ b/src/mlia/nn/rewrite/core/rewrite.py
@@ -150,6 +150,7 @@ class RewritingOptimizer(Optimizer):
"Input and output tensor names need to be set for rewrite."
)
+ self.optimizer_configuration.train_params.checkpoint_at = [5000, 10000]
orig_vs_repl_stats, total_stats = train(
source_model=tflite_model,
unmodified_model=tflite_model if use_unmodified_model else None,
@@ -162,9 +163,22 @@ class RewritingOptimizer(Optimizer):
)
if orig_vs_repl_stats:
- orig_vs_repl = ["Replaced sub-graph only"] + [
- f"{stat:.3f}" for stat in orig_vs_repl_stats
- ]
+ model_stats: list = []
+ cp_param = self.optimizer_configuration.train_params.checkpoint_at
+ checkpoints = (
+ [
+ "At checkpoint " + str(checkpoint) + " steps"
+ for checkpoint in cp_param
+ ]
+ if cp_param
+ else []
+ )
+ checkpoints.append("All Steps")
+ for checkpoint, orig_vs_repl_stat in zip(checkpoints, orig_vs_repl_stats):
+ model_stats.append(
+ ["Replaced sub-graph: " + checkpoint]
+ + [f"{stat:.3f}" for stat in orig_vs_repl_stat]
+ )
total = ["Total model"] + [f"{stat:.3f}" for stat in total_stats]
notes = (
"These metrics show the difference between original model\n"
@@ -178,14 +192,14 @@ class RewritingOptimizer(Optimizer):
table = Table(
columns=[
Column(
- "Original vs. optimized",
+ "Original vs. Optimized",
alias="metric",
fmt=Format(wrap_width=40),
),
Column("MAE", alias="value", fmt=Format(wrap_width=15)),
Column("NRMSE", alias="value", fmt=Format(wrap_width=15)),
],
- rows=[orig_vs_repl, total],
+ rows=[*model_stats, total],
name="Rewrite performance metrics",
alias="rewrite_performance_metrics",
notes=notes,
diff --git a/src/mlia/nn/rewrite/core/train.py b/src/mlia/nn/rewrite/core/train.py
index 60c39ae..e0b3c75 100644
--- a/src/mlia/nn/rewrite/core/train.py
+++ b/src/mlia/nn/rewrite/core/train.py
@@ -159,10 +159,10 @@ def train(
if unmodified_model_dir:
cast(tempfile.TemporaryDirectory, unmodified_model_dir).cleanup()
- return (results if train_params.checkpoint_at else results[0]), [
+ return results, [
mae,
nrmse,
- ] # only return a list if multiple checkpoints are asked for
+ ]
def eval_in_dir(
@@ -452,8 +452,12 @@ def train_in_dir(
)
if steps_so_far < train_params.steps:
- filename, ext = Path(output_filename).parts[1:]
- checkpoint_filename = filename + (f"_@{steps_so_far}") + ext
+ filename = Path(output_filename).stem
+ filename_dir = Path(output_filename).parent.as_posix()
+ ext = Path(output_filename).suffix
+ checkpoint_filename = (
+ filename_dir + "/" + filename + (f"_@{steps_so_far}") + ext
+ )
else:
checkpoint_filename = str(output_filename)
with log_action(
diff --git a/tests/test_nn_rewrite_core_train.py b/tests/test_nn_rewrite_core_train.py
index 6d24133..624c5ed 100644
--- a/tests/test_nn_rewrite_core_train.py
+++ b/tests/test_nn_rewrite_core_train.py
@@ -63,8 +63,11 @@ def check_train(
output_tensors=["StatefulPartitionedCall:0"],
train_params=train_params,
)
- assert len(result) == 2
- assert all(res >= 0.0 for res in result[0]), f"Results out of bound: {result}"
+
+ assert len(result[0][0]) == 2
+ assert all(
+ res >= 0.0 for res in result[0][0]
+ ), f"Results out of bound: {result}"
assert output_file.is_file()
if quantized:
@@ -229,3 +232,17 @@ def test_augment_fn_twins(augmentations: tuple, expected_error: Any) -> None:
with expected_error:
fn_twins = augment_fn_twins(dataset, augmentations) # type: ignore
assert len(fn_twins) == 2
+
+
+def test_train_checkpoint(
+ test_tflite_model: Path,
+ test_tfrecord: Path,
+) -> None:
+ """Test the train() function with valid checkpoint parameters."""
+ check_train(
+ tflite_model=test_tflite_model,
+ tfrecord=test_tfrecord,
+ train_params=MockTrainingParameters(steps=64, checkpoint_at=[24, 32]),
+ use_unmodified_model=False,
+ quantized=True,
+ )