diff options
Diffstat (limited to 'src/mlia/nn/rewrite/core/train.py')
-rw-r--r-- | src/mlia/nn/rewrite/core/train.py | 19 |
1 files changed, 11 insertions, 8 deletions
diff --git a/src/mlia/nn/rewrite/core/train.py b/src/mlia/nn/rewrite/core/train.py index e0b3c75..89de880 100644 --- a/src/mlia/nn/rewrite/core/train.py +++ b/src/mlia/nn/rewrite/core/train.py @@ -22,7 +22,6 @@ from typing import Literal import numpy as np import tensorflow as tf -import tensorflow_model_optimization as tfmot from keras.api._v2 import keras # Temporary workaround for now: MLIA-1107 from numpy.random import Generator @@ -78,7 +77,7 @@ def train( unmodified_model: Any, output_model: str, input_tfrec: str, - replace_fn: Callable, + rewrite: Callable, input_tensors: list, output_tensors: list, train_params: TrainingParameters = TrainingParameters(), @@ -118,7 +117,7 @@ def train( train_dir=train_dir, baseline_dir=unmodified_model_dir_path, output_filename=Path(train_dir, "new.tflite"), - replace_fn=replace_fn, + rewrite=rewrite, train_params=train_params, ) @@ -345,7 +344,7 @@ def train_in_dir( train_dir: str, baseline_dir: Any, output_filename: Path, - replace_fn: Callable, + rewrite: Callable, train_params: TrainingParameters = TrainingParameters(), ) -> list[str]: """Train a replacement for replace.tflite using the input.tfrec \ @@ -381,13 +380,12 @@ def train_in_dir( input_shape = teacher.shape_from_name[input_name][1:] output_shape = teacher.shape_from_name[output_name][1:] - - model = replace_fn(input_shape, output_shape) + model = rewrite(input_shape, output_shape) optimizer = keras.optimizers.Nadam(learning_rate=train_params.learning_rate) loss_fn = keras.losses.MeanSquaredError() - if model_is_quantized: - model = tfmot.quantization.keras.quantize_model(model) + + model = rewrite.quantize(model, model_is_quantized) # type: ignore[attr-defined] model.compile(optimizer=optimizer, loss=loss_fn, metrics=["mae"]) logger.info(model.summary()) @@ -428,6 +426,8 @@ def train_in_dir( elif train_params.learning_rate_schedule == "constant": callbacks = [] + callbacks.extend(rewrite.training_callbacks()) # type: ignore[attr-defined] + output_filenames = [] checkpoints = (train_params.checkpoint_at if train_params.checkpoint_at else []) + [ train_params.steps @@ -463,6 +463,9 @@ def train_in_dir( with log_action( f"{steps_so_far}/{train_params.steps}: Saved as {checkpoint_filename}" ): + if steps_so_far == train_params.steps: + model = rewrite.post_process(model) # type: ignore[attr-defined] + save_as_tflite( model, checkpoint_filename, |