aboutsummaryrefslogtreecommitdiff
path: root/src/mlia/nn/rewrite/core/train.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/mlia/nn/rewrite/core/train.py')
-rw-r--r--src/mlia/nn/rewrite/core/train.py88
1 files changed, 6 insertions, 82 deletions
diff --git a/src/mlia/nn/rewrite/core/train.py b/src/mlia/nn/rewrite/core/train.py
index a929b14..096daf4 100644
--- a/src/mlia/nn/rewrite/core/train.py
+++ b/src/mlia/nn/rewrite/core/train.py
@@ -40,85 +40,7 @@ augmentation_presets = {
"mix_gaussian_small": (1.6, 0.3),
}
-
-class SequentialTrainer:
- def __init__(
- self,
- source_model,
- output_model,
- input_tfrec,
- augment="gaussian",
- steps=6000,
- lr=1e-3,
- batch_size=32,
- show_progress=True,
- eval_fn=None,
- num_procs=1,
- num_threads=0,
- ):
- self.source_model = source_model
- self.output_model = output_model
- self.input_tfrec = input_tfrec
- self.default_augment = augment
- self.default_steps = steps
- self.default_lr = lr
- self.default_batch_size = batch_size
- self.show_progress = show_progress
- self.num_procs = num_procs
- self.num_threads = num_threads
- self.first_replace = True
- self.eval_fn = eval_fn
-
- def replace(
- self,
- model_fn,
- input_tensors,
- output_tensors,
- augment=None,
- steps=None,
- lr=None,
- batch_size=None,
- ):
- augment = self.default_augment if augment is None else augment
- steps = self.default_steps if steps is None else steps
- lr = self.default_lr if lr is None else lr
- batch_size = self.default_batch_size if batch_size is None else batch_size
-
- if isinstance(augment, str):
- augment = augmentation_presets[augment]
-
- if self.first_replace:
- source_model = self.source_model
- unmodified_model = None
- else:
- source_model = self.output_model
- unmodified_model = self.source_model
-
- mae, nrmse = train(
- source_model,
- unmodified_model,
- self.output_model,
- self.input_tfrec,
- model_fn,
- input_tensors,
- output_tensors,
- augment,
- steps,
- lr,
- batch_size,
- False,
- self.show_progress,
- None,
- 0,
- self.num_procs,
- self.num_threads,
- )
-
- self.first_replace = False
- if self.eval_fn:
- return self.eval_fn(mae, nrmse, self.output_model)
- else:
- return mae, nrmse
+learning_rate_schedules = {"cosine", "late", "constant"}
def train(
@@ -135,6 +57,7 @@ def train(
batch_size,
verbose,
show_progress,
+ learning_rate_schedule="cosine",
checkpoint_at=None,
checkpoint_decay_steps=0,
num_procs=1,
@@ -183,6 +106,7 @@ def train(
show_progress=show_progress,
num_procs=num_procs,
num_threads=num_threads,
+ schedule=learning_rate_schedule,
)
for i, filename in enumerate(tflite_filenames):
@@ -363,9 +287,9 @@ def train_in_dir(
elif schedule == "constant":
callbacks = []
else:
- assert False, (
- 'LR schedule "%s" not implemented - expected "cosine", "constant" or "late"'
- % schedule
+ assert schedule not in learning_rate_schedules
+ raise ValueError(
+ f'LR schedule "{schedule}" not implemented - expected one of {learning_rate_schedules}.'
)
output_filenames = []