diff options
Diffstat (limited to 'src/mlia/nn/rewrite/core/train.py')
-rw-r--r-- | src/mlia/nn/rewrite/core/train.py | 118 |
1 files changed, 89 insertions, 29 deletions
diff --git a/src/mlia/nn/rewrite/core/train.py b/src/mlia/nn/rewrite/core/train.py index 4b9821c..88efa23 100644 --- a/src/mlia/nn/rewrite/core/train.py +++ b/src/mlia/nn/rewrite/core/train.py @@ -73,15 +73,15 @@ class TrainingParameters: checkpoint_at: list | None = None -def train( +def train( # pylint: disable=too-many-arguments source_model: str, unmodified_model: Any, output_model: str, input_tfrec: str, rewrite: Callable, + is_qat: bool, input_tensors: list, output_tensors: list, - is_qat: bool, train_params: TrainingParameters = TrainingParameters(), ) -> Any: """Extract and train a model, and return the results.""" @@ -383,15 +383,15 @@ def train_in_dir( ) input_shape = teacher.shape_from_name[input_name][1:] + output_shape = teacher.shape_from_name[output_name][1:] - model = rewrite(input_shape, output_shape) optimizer = keras.optimizers.Nadam(learning_rate=train_params.learning_rate) loss_fn = keras.losses.MeanSquaredError() - if model_is_quantized: - model = rewrite.quantize(model) # type: ignore[attr-defined] - model = model_compile(model, optimizer, loss_fn) + model = create_model( + rewrite, input_shape, output_shape, optimizer, loss_fn, model_is_quantized + ) logger.info(model.summary()) @@ -432,16 +432,14 @@ def train_in_dir( callbacks = [] callbacks.extend(rewrite.training_callbacks()) # type: ignore[attr-defined] - - output_filenames = [] # type: list[str] + output_filenames: list = [] checkpoints = (train_params.checkpoint_at if train_params.checkpoint_at else []) + [ train_params.steps ] - model, output_filenames = model_fit( model, train_params, - checkpoints.copy(), + checkpoints, optimizer, dataset, callbacks, @@ -452,22 +450,35 @@ def train_in_dir( output_name, model_is_quantized, output_filenames, + input_shape, + output_shape, + loss_fn, + post_process=True, ) + # Placeholder for now, will be parametrized later (MLIA-1114) + # rewrite.check_optimization( # type: ignore[attr-defined] + # model, number_of_clusters=32 + # ) if model_is_quantized and is_qat: - model = rewrite.pruning_preserved_quantization( # type: ignore[attr-defined] - model, - ) + model = rewrite.preserved_quantize(model) # type: ignore[attr-defined] + checkpoints = ( + train_params.checkpoint_at if train_params.checkpoint_at else [] + ) + [train_params.steps] + output_filenames = [] + + if len(rewrite.training_callbacks()) > 0 and set( # type: ignore[attr-defined] + rewrite.training_callbacks() # type: ignore[attr-defined] + ).issubset(callbacks): + callbacks.pop(-1) + optimizer = keras.optimizers.Nadam(learning_rate=train_params.learning_rate) model = model_compile(model, optimizer, loss_fn) - callbacks.pop(-1) - output_filenames = [] - model, output_filenames = model_fit( model, train_params, - checkpoints.copy(), + checkpoints, optimizer, dataset, callbacks, @@ -478,22 +489,50 @@ def train_in_dir( output_name, model_is_quantized, output_filenames, + input_shape, + output_shape, + loss_fn, ) + # Placeholder for now, will be parametrized later (MLIA-1114) + # rewrite.check_optimization( # type: ignore[attr-defined] + # model, number_of_clusters=32 + # ) teacher.close() return output_filenames def model_compile( - model: tf.keras.Model, optimizer: tf.keras.optimizers, loss_fn: tf.keras.losses -) -> tf.keras.Model: + model: keras.Model, + optimizer: keras.optimizers.Nadam, + loss_fn: keras.losses.Loss, +) -> keras.Model: """Compiles a tflite model.""" model.compile(optimizer=optimizer, loss=loss_fn, metrics=["mae"]) return model -def model_fit( - model: tf.keras.Model, +def create_model( # pylint: disable=too-many-arguments + rewrite: Callable, + input_shape: int, + output_shape: int, + optimizer: Callable, + loss_fn: Callable, + model_is_quantized: bool, + model_to_load_from: keras.model | None = None, +) -> keras.Model: + """Create a model, optionally from another.""" + model = rewrite(input_shape, output_shape) + if model_is_quantized: + model = rewrite.quantize(model) # type: ignore[attr-defined] + model = model_compile(model, optimizer=optimizer, loss_fn=loss_fn) + if model_to_load_from: + model.set_weights(model_to_load_from.get_weights()) + return model + + +def model_fit( # pylint: disable=too-many-arguments + model: keras.Model, train_params: TrainingParameters, checkpoints: list, optimizer: tf.optimizers.Nadam, @@ -506,8 +545,12 @@ def model_fit( output_name: str, model_is_quantized: bool, output_filenames: list, -) -> tuple[tf.keras.Model, list]: - """Train the model.""" + input_shape: int, + output_shape: int, + loss_fn: Callable, + post_process: bool = False, +) -> keras.Model: + """Train a tflite model.""" steps_so_far = 0 while steps_so_far < train_params.steps: steps_to_train = checkpoints.pop(0) - steps_so_far @@ -534,18 +577,34 @@ def model_fit( checkpoint_filename = ( filename_dir + "/" + filename + (f"_@{steps_so_far}") + ext ) + # If post processing we are stripping the clustering/pruning layers below + # Thus copy the model before saving, so training can continue + if post_process: + model_to_save = create_model( + rewrite, + input_shape, + output_shape, + optimizer, + loss_fn, + model_is_quantized, + model_to_load_from=model, + ) + else: + model_to_save = model else: checkpoint_filename = str(output_filename) - - if steps_so_far == train_params.steps: - model = rewrite.post_process(model) # type: ignore[attr-defined] + model_to_save = model with log_action( f"{steps_so_far}/{train_params.steps}: Saved as {checkpoint_filename}" ): + if post_process: + model_to_save = rewrite.post_process( # type: ignore[attr-defined] + model_to_save + ) save_as_tflite( - model, - str(checkpoint_filename), + model_to_save, + checkpoint_filename, input_name, replace.shape_from_name[input_name], output_name, @@ -553,7 +612,8 @@ def model_fit( model_is_quantized, ) output_filenames.append(checkpoint_filename) - return model, output_filenames + + return model_to_save, output_filenames def save_as_tflite( |