diff options
author | Nathan Bailey <nathan.bailey@arm.com> | 2024-03-12 13:48:20 +0000 |
---|---|---|
committer | Nathan Bailey <nathan.bailey@arm.com> | 2024-03-28 07:18:16 +0000 |
commit | 879138ebe9dd4b9387bddbdf4af45ba390f14bf6 (patch) | |
tree | e402c948e9ba63b28f3e1f4c7e5f9cebd903f9e0 /src/mlia/nn/rewrite/core/train.py | |
parent | f3f3ab451968350b8f6df2de7c60b2c2b9320b59 (diff) | |
download | mlia-879138ebe9dd4b9387bddbdf4af45ba390f14bf6.tar.gz |
fix: Check that training checkpoint feature works as expected
Fixes the checkpoint feature in training and also completes unit tests for it
Resolves: MLIA-1111
Signed-off-by: Nathan Bailey <nathan.bailey@arm.com>
Change-Id: Ic2b84b4b045db5ba3cb299fcd137ae9d31df5298
Diffstat (limited to 'src/mlia/nn/rewrite/core/train.py')
-rw-r--r-- | src/mlia/nn/rewrite/core/train.py | 12 |
1 files changed, 8 insertions, 4 deletions
diff --git a/src/mlia/nn/rewrite/core/train.py b/src/mlia/nn/rewrite/core/train.py index 60c39ae..e0b3c75 100644 --- a/src/mlia/nn/rewrite/core/train.py +++ b/src/mlia/nn/rewrite/core/train.py @@ -159,10 +159,10 @@ def train( if unmodified_model_dir: cast(tempfile.TemporaryDirectory, unmodified_model_dir).cleanup() - return (results if train_params.checkpoint_at else results[0]), [ + return results, [ mae, nrmse, - ] # only return a list if multiple checkpoints are asked for + ] def eval_in_dir( @@ -452,8 +452,12 @@ def train_in_dir( ) if steps_so_far < train_params.steps: - filename, ext = Path(output_filename).parts[1:] - checkpoint_filename = filename + (f"_@{steps_so_far}") + ext + filename = Path(output_filename).stem + filename_dir = Path(output_filename).parent.as_posix() + ext = Path(output_filename).suffix + checkpoint_filename = ( + filename_dir + "/" + filename + (f"_@{steps_so_far}") + ext + ) else: checkpoint_filename = str(output_filename) with log_action( |