diff options
Diffstat (limited to 'src/mlia/nn/rewrite/core/train.py')
-rw-r--r-- | src/mlia/nn/rewrite/core/train.py | 88 |
1 files changed, 6 insertions, 82 deletions
diff --git a/src/mlia/nn/rewrite/core/train.py b/src/mlia/nn/rewrite/core/train.py index a929b14..096daf4 100644 --- a/src/mlia/nn/rewrite/core/train.py +++ b/src/mlia/nn/rewrite/core/train.py @@ -40,85 +40,7 @@ augmentation_presets = { "mix_gaussian_small": (1.6, 0.3), } - -class SequentialTrainer: - def __init__( - self, - source_model, - output_model, - input_tfrec, - augment="gaussian", - steps=6000, - lr=1e-3, - batch_size=32, - show_progress=True, - eval_fn=None, - num_procs=1, - num_threads=0, - ): - self.source_model = source_model - self.output_model = output_model - self.input_tfrec = input_tfrec - self.default_augment = augment - self.default_steps = steps - self.default_lr = lr - self.default_batch_size = batch_size - self.show_progress = show_progress - self.num_procs = num_procs - self.num_threads = num_threads - self.first_replace = True - self.eval_fn = eval_fn - - def replace( - self, - model_fn, - input_tensors, - output_tensors, - augment=None, - steps=None, - lr=None, - batch_size=None, - ): - augment = self.default_augment if augment is None else augment - steps = self.default_steps if steps is None else steps - lr = self.default_lr if lr is None else lr - batch_size = self.default_batch_size if batch_size is None else batch_size - - if isinstance(augment, str): - augment = augmentation_presets[augment] - - if self.first_replace: - source_model = self.source_model - unmodified_model = None - else: - source_model = self.output_model - unmodified_model = self.source_model - - mae, nrmse = train( - source_model, - unmodified_model, - self.output_model, - self.input_tfrec, - model_fn, - input_tensors, - output_tensors, - augment, - steps, - lr, - batch_size, - False, - self.show_progress, - None, - 0, - self.num_procs, - self.num_threads, - ) - - self.first_replace = False - if self.eval_fn: - return self.eval_fn(mae, nrmse, self.output_model) - else: - return mae, nrmse +learning_rate_schedules = {"cosine", "late", "constant"} def train( @@ -135,6 +57,7 @@ def train( batch_size, verbose, show_progress, + learning_rate_schedule="cosine", checkpoint_at=None, checkpoint_decay_steps=0, num_procs=1, @@ -183,6 +106,7 @@ def train( show_progress=show_progress, num_procs=num_procs, num_threads=num_threads, + schedule=learning_rate_schedule, ) for i, filename in enumerate(tflite_filenames): @@ -363,9 +287,9 @@ def train_in_dir( elif schedule == "constant": callbacks = [] else: - assert False, ( - 'LR schedule "%s" not implemented - expected "cosine", "constant" or "late"' - % schedule + assert schedule not in learning_rate_schedules + raise ValueError( + f'LR schedule "{schedule}" not implemented - expected one of {learning_rate_schedules}.' ) output_filenames = [] |