aboutsummaryrefslogtreecommitdiff
path: root/src/mlia/nn/rewrite/core/train.py
diff options
context:
space:
mode:
authorNathan Bailey <nathan.bailey@arm.com>2024-03-12 13:48:20 +0000
committerNathan Bailey <nathan.bailey@arm.com>2024-03-28 07:18:16 +0000
commit879138ebe9dd4b9387bddbdf4af45ba390f14bf6 (patch)
treee402c948e9ba63b28f3e1f4c7e5f9cebd903f9e0 /src/mlia/nn/rewrite/core/train.py
parentf3f3ab451968350b8f6df2de7c60b2c2b9320b59 (diff)
downloadmlia-879138ebe9dd4b9387bddbdf4af45ba390f14bf6.tar.gz
fix: Check that training checkpoint feature works as expected
Fixes the checkpoint feature in training and also completes unit tests for it Resolves: MLIA-1111 Signed-off-by: Nathan Bailey <nathan.bailey@arm.com> Change-Id: Ic2b84b4b045db5ba3cb299fcd137ae9d31df5298
Diffstat (limited to 'src/mlia/nn/rewrite/core/train.py')
-rw-r--r--src/mlia/nn/rewrite/core/train.py12
1 files changed, 8 insertions, 4 deletions
diff --git a/src/mlia/nn/rewrite/core/train.py b/src/mlia/nn/rewrite/core/train.py
index 60c39ae..e0b3c75 100644
--- a/src/mlia/nn/rewrite/core/train.py
+++ b/src/mlia/nn/rewrite/core/train.py
@@ -159,10 +159,10 @@ def train(
if unmodified_model_dir:
cast(tempfile.TemporaryDirectory, unmodified_model_dir).cleanup()
- return (results if train_params.checkpoint_at else results[0]), [
+ return results, [
mae,
nrmse,
- ] # only return a list if multiple checkpoints are asked for
+ ]
def eval_in_dir(
@@ -452,8 +452,12 @@ def train_in_dir(
)
if steps_so_far < train_params.steps:
- filename, ext = Path(output_filename).parts[1:]
- checkpoint_filename = filename + (f"_@{steps_so_far}") + ext
+ filename = Path(output_filename).stem
+ filename_dir = Path(output_filename).parent.as_posix()
+ ext = Path(output_filename).suffix
+ checkpoint_filename = (
+ filename_dir + "/" + filename + (f"_@{steps_so_far}") + ext
+ )
else:
checkpoint_filename = str(output_filename)
with log_action(