Diffstat (limited to 'src/mlia/nn/rewrite/core/train.py')
-rw-r--r--   src/mlia/nn/rewrite/core/train.py   209
1 file changed, 182 insertions(+), 27 deletions(-)
diff --git a/src/mlia/nn/rewrite/core/train.py b/src/mlia/nn/rewrite/core/train.py
index 60c39ae..4204978 100644
--- a/src/mlia/nn/rewrite/core/train.py
+++ b/src/mlia/nn/rewrite/core/train.py
@@ -1,6 +1,7 @@
 # SPDX-FileCopyrightText: Copyright 2023-2024, Arm Limited and/or its affiliates.
 # SPDX-License-Identifier: Apache-2.0
 """Sequential trainer."""
+# pylint: disable=too-many-arguments
 # pylint: disable=too-many-locals
 # pylint: disable=too-many-statements
 from __future__ import annotations
@@ -22,7 +23,6 @@ from typing import Literal
 
 import numpy as np
 import tensorflow as tf
-import tensorflow_model_optimization as tfmot
 from keras.api._v2 import keras  # Temporary workaround for now: MLIA-1107
 from numpy.random import Generator
 
@@ -62,7 +62,7 @@ LEARNING_RATE_SCHEDULES = get_args(LearningRateSchedule)
 class TrainingParameters:
     """Define default parameters for the training."""
 
-    augmentations: tuple[float | None, float | None] = AUGMENTATION_PRESETS["gaussian"]
+    augmentations: tuple[float | None, float | None] = AUGMENTATION_PRESETS["none"]
     batch_size: int = 32
     steps: int = 48000
     learning_rate: float = 1e-3
@@ -73,12 +73,13 @@ class TrainingParameters:
     checkpoint_at: list | None = None
 
 
-def train(
+def train(  # pylint: disable=too-many-arguments
     source_model: str,
     unmodified_model: Any,
     output_model: str,
     input_tfrec: str,
-    replace_fn: Callable,
+    rewrite: Callable,
+    is_qat: bool,
     input_tensors: list,
     output_tensors: list,
     train_params: TrainingParameters = TrainingParameters(),
@@ -118,7 +119,8 @@ def train(
             train_dir=train_dir,
             baseline_dir=unmodified_model_dir_path,
             output_filename=Path(train_dir, "new.tflite"),
-            replace_fn=replace_fn,
+            rewrite=rewrite,
+            is_qat=is_qat,
             train_params=train_params,
         )
 
@@ -145,7 +147,8 @@ def train(
         # Assess the output diff between the parts after the rewrite subgraph
         # in original and optimized model
         optimized_end_path = Path(train_dir, "optimized_end.tfrec")
-        end_path = Path(train_dir, "end.tfrec")
+        optimized_end_path_dequant = Path(train_dir, "optimized_end_dequant.tfrec")
+        end_path = Path(train_dir, "end_dequant.tfrec")
 
         record_model(
             str(input_tfrec),
@@ -153,16 +156,18 @@ def train(
             optimized_end_path,
             num_procs=train_params.num_procs,
             num_threads=train_params.num_threads,
+            dequantize_output=True,
         )
-        mae, nrmse = diff_stats(end_path, str(optimized_end_path))
+
+        mae, nrmse = diff_stats(end_path, optimized_end_path_dequant)
 
     if unmodified_model_dir:
         cast(tempfile.TemporaryDirectory, unmodified_model_dir).cleanup()
 
-    return (results if train_params.checkpoint_at else results[0]), [
+    return results, [
         mae,
         nrmse,
-    ]  # only return a list if multiple checkpoints are asked for
+    ]
 
 
 def eval_in_dir(
@@ -177,24 +182,27 @@ def eval_in_dir(
     model_input = (
         model_input_path
         if model_input_path.exists()
-        else ExtractPaths.tfrec.input(target_dir, False)
+        else ExtractPaths.tfrec.input(target_dir, True)
     )
     output = (
         model_output_path
         if model_output_path.exists()
-        else ExtractPaths.tfrec.output(target_dir, False)
+        else ExtractPaths.tfrec.output(target_dir, True)
     )
 
     with tempfile.TemporaryDirectory() as tmp_dir:
         predict = Path(tmp_dir, "predict.tfrec")
+        predict_dequant = Path(tmp_dir, "predict_dequant.tfrec")
         record_model(
             str(model_input),
             new_part,
             str(predict),
             num_procs=num_procs,
             num_threads=num_threads,
+            dequantize_output=True,
+            quantize_input=True,
         )
-        mae, nrmse = diff_stats(str(output), str(predict))
+        mae, nrmse = diff_stats(str(output), predict_dequant)
 
     return mae, nrmse
 
@@ -247,7 +255,7 @@ def set_up_data_pipeline(
     augmentations: tuple[float | None, float | None],
     steps: int,
     batch_size: int = 32,
-) -> tf.data.Dataset:
+) -> tuple[tf.data.Dataset, int]:
     """Create a data pipeline for training of the replacement model."""
     _check_model_compatibility(teacher, replace)
 
@@ -338,14 +346,15 @@ def set_up_data_pipeline(
     dataset = dataset.map(restore_shapes)
     dataset = dataset.prefetch(tf.data.AUTOTUNE)
 
-    return dataset
+    return dataset, steps_per_epoch
 
 
 def train_in_dir(
     train_dir: str,
     baseline_dir: Any,
     output_filename: Path,
-    replace_fn: Callable,
+    rewrite: Callable,
+    is_qat: bool,
     train_params: TrainingParameters = TrainingParameters(),
 ) -> list[str]:
     """Train a replacement for replace.tflite using the input.tfrec \
@@ -370,7 +379,7 @@ def train_in_dir(
     if model_is_quantized:
         replace.check_datatypes(np.int8)
 
-    dataset = set_up_data_pipeline(
+    dataset, steps_per_epoch = set_up_data_pipeline(
         teacher,
         replace,
         train_dir,
@@ -380,15 +389,15 @@ def train_in_dir(
     )
 
     input_shape = teacher.shape_from_name[input_name][1:]
-    output_shape = teacher.shape_from_name[output_name][1:]
-    model = replace_fn(input_shape, output_shape)
+    output_shape = teacher.shape_from_name[output_name][1:]
 
     optimizer = keras.optimizers.Nadam(learning_rate=train_params.learning_rate)
     loss_fn = keras.losses.MeanSquaredError()
-    if model_is_quantized:
-        model = tfmot.quantization.keras.quantize_model(model)
-    model.compile(optimizer=optimizer, loss=loss_fn, metrics=["mae"])
+
+    model = create_model(
+        rewrite, input_shape, output_shape, optimizer, loss_fn, model_is_quantized
+    )
 
     logger.info(model.summary())
 
@@ -428,11 +437,130 @@ def train_in_dir(
     elif train_params.learning_rate_schedule == "constant":
         callbacks = []
 
-    output_filenames = []
+    callbacks.extend(rewrite.training_callbacks())  # type: ignore[attr-defined]
+    output_filenames: list = []
     checkpoints = (train_params.checkpoint_at if train_params.checkpoint_at else []) + [
         train_params.steps
     ]
+    model, output_filenames = model_fit(
+        model,
+        train_params,
+        checkpoints,
+        optimizer,
+        dataset,
+        callbacks,
+        output_filename,
+        rewrite,
+        replace,
+        input_name,
+        output_name,
+        model_is_quantized,
+        output_filenames,
+        input_shape,
+        output_shape,
+        loss_fn,
+        steps_per_epoch,
+        post_process=True,
+    )
 
+    # Placeholder for now, will be parametrized later (MLIA-1114)
+    # rewrite.check_optimization(  # type: ignore[attr-defined]
+    #     model, number_of_clusters=32
+    # )
+    if model_is_quantized and is_qat:
+        model = rewrite.preserved_quantize(model)  # type: ignore[attr-defined]
+        checkpoints = (
+            train_params.checkpoint_at if train_params.checkpoint_at else []
+        ) + [train_params.steps]
+        output_filenames = []
+
+        if len(rewrite.training_callbacks()) > 0 and set(  # type: ignore[attr-defined]
+            rewrite.training_callbacks()  # type: ignore[attr-defined]
+        ).issubset(callbacks):
+            callbacks.pop(-1)
+
+        optimizer = keras.optimizers.Nadam(learning_rate=train_params.learning_rate)
+        model = model_compile(model, optimizer, loss_fn)
+
+        model, output_filenames = model_fit(
+            model,
+            train_params,
+            checkpoints,
+            optimizer,
+            dataset,
+            callbacks,
+            output_filename,
+            rewrite,
+            replace,
+            input_name,
+            output_name,
+            model_is_quantized,
+            output_filenames,
+            input_shape,
+            output_shape,
+            loss_fn,
+            steps_per_epoch,
+        )
+        # Placeholder for now, will be parametrized later (MLIA-1114)
+        # rewrite.check_optimization(  # type: ignore[attr-defined]
+        #     model, number_of_clusters=32
+        # )
+
+    teacher.close()
+    return output_filenames
+
+
+def model_compile(
+    model: keras.Model,
+    optimizer: keras.optimizers.Nadam,
+    loss_fn: keras.losses.Loss,
+) -> keras.Model:
+    """Compile a Keras model."""
+    model.compile(optimizer=optimizer, loss=loss_fn, metrics=["mae"])
+    return model
+
+
+def create_model(  # pylint: disable=too-many-arguments
+    rewrite: Callable,
+    input_shape: int,
+    output_shape: int,
+    optimizer: Callable,
+    loss_fn: Callable,
+    model_is_quantized: bool,
+    model_to_load_from: keras.Model | None = None,
+) -> keras.Model:
+    """Create a model, optionally loading its weights from another."""
+    model = rewrite(input_shape, output_shape)
+    if model_is_quantized:
+        model = rewrite.quantize(model)  # type: ignore[attr-defined]
+    model = model_compile(model, optimizer=optimizer, loss_fn=loss_fn)
+    if model_to_load_from:
+        model.set_weights(model_to_load_from.get_weights())
+    return model
+
+
+def model_fit(  # pylint: disable=too-many-arguments
+    model: keras.Model,
+    train_params: TrainingParameters,
+    checkpoints: list,
+    optimizer: tf.optimizers.Nadam,
+    dataset: tf.data.Dataset,
+    callbacks: list,
+    output_filename: Path,
+    rewrite: Callable,
+    replace: TFLiteModel,
+    input_name: str,
+    output_name: str,
+    model_is_quantized: bool,
+    output_filenames: list,
+    input_shape: int,
+    output_shape: int,
+    loss_fn: Callable,
+    steps_per_epoch: int,
+    post_process: bool = False,
+) -> tuple[keras.Model, list]:
+    """Train a Keras model, saving checkpoints as TFLite models."""
+    steps_so_far = 0
     while steps_so_far < train_params.steps:
         steps_to_train = checkpoints.pop(0) - steps_so_far
         lr_start = optimizer.learning_rate.numpy()
@@ -452,15 +580,43 @@ def train_in_dir(
         )
 
         if steps_so_far < train_params.steps:
-            filename, ext = Path(output_filename).parts[1:]
-            checkpoint_filename = filename + (f"_@{steps_so_far}") + ext
+            filename = Path(output_filename).stem
+            filename_dir = Path(output_filename).parent.as_posix()
+            ext = Path(output_filename).suffix
+            checkpoint_filename = (
+                filename_dir + "/" + filename + (f"_@{steps_so_far}") + ext
+            )
+            # If post-processing, the clustering/pruning layers are stripped
+            # below, so copy the model before saving to let training continue.
+            if post_process:
+                model_to_save = create_model(
+                    rewrite,
+                    input_shape,
+                    output_shape,
+                    optimizer,
+                    loss_fn,
+                    model_is_quantized,
+                    model_to_load_from=model,
+                )
+            else:
+                model_to_save = model
         else:
             checkpoint_filename = str(output_filename)
+            logger.info("Evaluate final Keras Model using %d steps", steps_per_epoch)
+            model.evaluate(
+                dataset,
+                steps=steps_per_epoch,
+            )
+            model_to_save = model
 
         with log_action(
             f"{steps_so_far}/{train_params.steps}: Saved as {checkpoint_filename}"
         ):
+            if post_process:
+                model_to_save = rewrite.post_process(  # type: ignore[attr-defined]
+                    model_to_save
+                )
             save_as_tflite(
-                model,
+                model_to_save,
                 checkpoint_filename,
                 input_name,
                 replace.shape_from_name[input_name],
@@ -470,8 +626,7 @@ def train_in_dir(
             )
             output_filenames.append(checkpoint_filename)
 
-    teacher.close()
-    return output_filenames
+    return model_to_save, output_filenames
 
 
 def save_as_tflite(
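
Notes on the change, for readers extending this code:

The commit replaces the bare `replace_fn` factory with a richer `rewrite` object that the trainer calls both to build the replacement model and to drive quantization and post-processing (`quantize`, `preserved_quantize`, `post_process`, `training_callbacks`, plus the still-commented `check_optimization`). The following is a minimal sketch of that contract inferred purely from the call sites in this diff; the class names, default method bodies and tfmot usage are illustrative assumptions, not the actual MLIA rewrite implementation.

    # Hypothetical sketch of the contract train_in_dir() now expects from
    # `rewrite`. Method names are taken from the call sites in this diff;
    # class names, defaults and tfmot usage are illustrative assumptions.
    from __future__ import annotations

    from abc import ABC, abstractmethod
    from typing import Any

    import numpy as np
    import tensorflow_model_optimization as tfmot
    from keras.api._v2 import keras  # same MLIA-1107 workaround as train.py


    class RewriteSketch(ABC):
        """Replacement-model factory plus training hooks."""

        @abstractmethod
        def __call__(self, input_shape: Any, output_shape: Any) -> keras.Model:
            """Build the replacement sub-model for the given tensor shapes."""

        def quantize(self, model: keras.Model) -> keras.Model:
            # First-pass QAT wrapper; replaces the direct tfmot call that
            # train_in_dir() made before this commit.
            return tfmot.quantization.keras.quantize_model(model)

        def preserved_quantize(self, model: keras.Model) -> keras.Model:
            # Second QAT pass, run when is_qat is set; default: re-quantize.
            return self.quantize(model)

        def post_process(self, model: keras.Model) -> keras.Model:
            # Strip training-only wrappers (e.g. clustering/pruning layers)
            # before the model is saved as TFLite.
            return model

        def training_callbacks(self) -> list:
            # Extra Keras callbacks appended to the LR-schedule callbacks.
            return []


    class DenseRewrite(RewriteSketch):
        """Toy concrete rewrite: flatten -> dense -> reshape (illustrative)."""

        def __call__(self, input_shape: Any, output_shape: Any) -> keras.Model:
            return keras.Sequential(
                [
                    keras.layers.InputLayer(input_shape=input_shape),
                    keras.layers.Flatten(),
                    keras.layers.Dense(int(np.prod(output_shape))),
                    keras.layers.Reshape(output_shape),
                ]
            )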
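
Given such a rewrite, a call of the updated `train()` entry point might look as follows. This is a hypothetical sketch: every path and tensor name is a placeholder, and it reuses the toy `DenseRewrite` class from the sketch above. Note the simplified return value: `train()` now always returns the checkpoint list together with `[mae, nrmse]`, instead of unwrapping the single-checkpoint case.

    # Hypothetical invocation of the new train() signature; all paths and
    # tensor names below are placeholders.
    from mlia.nn.rewrite.core.train import TrainingParameters, train

    checkpoint_files, (mae, nrmse) = train(
        source_model="model.tflite",            # placeholder
        unmodified_model="baseline.tflite",     # placeholder
        output_model="rewritten.tflite",        # placeholder
        input_tfrec="input.tfrec",              # placeholder
        rewrite=DenseRewrite(),                 # was: replace_fn=...
        is_qat=True,                            # run the preserved_quantize pass
        input_tensors=["input_tensor_name"],    # placeholder
        output_tensors=["output_tensor_name"],  # placeholder
        train_params=TrainingParameters(steps=12000),
    )
    print(f"MAE={mae}, NRMSE={nrmse}, checkpoints={checkpoint_files}")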
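
One self-contained fix is worth highlighting: the old checkpoint naming unpacked `Path(output_filename).parts[1:]`, which only "works" for a path of exactly three components and mangles the name even then, while the new code derives directory, stem and suffix explicitly. A small standalone demonstration with a made-up path:

    from pathlib import Path

    output_filename = Path("some/train_dir/new.tflite")  # made-up example path
    steps_so_far = 1000

    # Old code: filename, ext = Path(output_filename).parts[1:]
    # -> parts[1:] == ("train_dir", "new.tflite"): the directory component is
    #    mistaken for the file name, the checkpoint lands outside train_dir,
    #    and any path without exactly three components raises ValueError.

    # New code: derive directory, stem and suffix explicitly.
    filename = output_filename.stem
    filename_dir = output_filename.parent.as_posix()
    ext = output_filename.suffix
    checkpoint_filename = filename_dir + "/" + filename + f"_@{steps_so_far}" + ext
    print(checkpoint_filename)  # some/train_dir/new_@1000.tflite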