diff options
Diffstat (limited to 'src/mlia/nn/rewrite/core/train.py')
-rw-r--r-- | src/mlia/nn/rewrite/core/train.py | 12 |
1 files changed, 8 insertions, 4 deletions
diff --git a/src/mlia/nn/rewrite/core/train.py b/src/mlia/nn/rewrite/core/train.py index 60c39ae..e0b3c75 100644 --- a/src/mlia/nn/rewrite/core/train.py +++ b/src/mlia/nn/rewrite/core/train.py @@ -159,10 +159,10 @@ def train( if unmodified_model_dir: cast(tempfile.TemporaryDirectory, unmodified_model_dir).cleanup() - return (results if train_params.checkpoint_at else results[0]), [ + return results, [ mae, nrmse, - ] # only return a list if multiple checkpoints are asked for + ] def eval_in_dir( @@ -452,8 +452,12 @@ def train_in_dir( ) if steps_so_far < train_params.steps: - filename, ext = Path(output_filename).parts[1:] - checkpoint_filename = filename + (f"_@{steps_so_far}") + ext + filename = Path(output_filename).stem + filename_dir = Path(output_filename).parent.as_posix() + ext = Path(output_filename).suffix + checkpoint_filename = ( + filename_dir + "/" + filename + (f"_@{steps_so_far}") + ext + ) else: checkpoint_filename = str(output_filename) with log_action( |