From 1ebb335cba516bcf973b041efa6a9878d1022b93 Mon Sep 17 00:00:00 2001 From: Madeleine Dunn Date: Wed, 21 Feb 2024 17:10:07 +0000 Subject: feat: Implement int8 sparsity 2:4 rewrite - Implement pruning-preserving quantisation aware training - Rework the training logic to avoid duplication - Remove the DynamicallyLoadedRewrite class as it is now unused Resolves: MLIA-1003 Signed-off-by: Madeleine Dunn Change-Id: Ia7a4acf5f477a27963cffa88180cca085b32ffe4 --- src/mlia/nn/rewrite/core/train.py | 95 ++++++++++++++++++++++++++++++++++----- 1 file changed, 85 insertions(+), 10 deletions(-) (limited to 'src/mlia/nn/rewrite/core/train.py') diff --git a/src/mlia/nn/rewrite/core/train.py b/src/mlia/nn/rewrite/core/train.py index 89de880..4b9821c 100644 --- a/src/mlia/nn/rewrite/core/train.py +++ b/src/mlia/nn/rewrite/core/train.py @@ -1,6 +1,7 @@ # SPDX-FileCopyrightText: Copyright 2023-2024, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """Sequential trainer.""" +# pylint: disable=too-many-arguments # pylint: disable=too-many-locals # pylint: disable=too-many-statements from __future__ import annotations @@ -80,6 +81,7 @@ def train( rewrite: Callable, input_tensors: list, output_tensors: list, + is_qat: bool, train_params: TrainingParameters = TrainingParameters(), ) -> Any: """Extract and train a model, and return the results.""" @@ -118,6 +120,7 @@ def train( baseline_dir=unmodified_model_dir_path, output_filename=Path(train_dir, "new.tflite"), rewrite=rewrite, + is_qat=is_qat, train_params=train_params, ) @@ -345,6 +348,7 @@ def train_in_dir( baseline_dir: Any, output_filename: Path, rewrite: Callable, + is_qat: bool, train_params: TrainingParameters = TrainingParameters(), ) -> list[str]: """Train a replacement for replace.tflite using the input.tfrec \ @@ -385,8 +389,9 @@ def train_in_dir( optimizer = keras.optimizers.Nadam(learning_rate=train_params.learning_rate) loss_fn = keras.losses.MeanSquaredError() - model = rewrite.quantize(model, model_is_quantized) # type: ignore[attr-defined] - model.compile(optimizer=optimizer, loss=loss_fn, metrics=["mae"]) + if model_is_quantized: + model = rewrite.quantize(model) # type: ignore[attr-defined] + model = model_compile(model, optimizer, loss_fn) logger.info(model.summary()) @@ -428,11 +433,82 @@ def train_in_dir( callbacks.extend(rewrite.training_callbacks()) # type: ignore[attr-defined] - output_filenames = [] + output_filenames = [] # type: list[str] checkpoints = (train_params.checkpoint_at if train_params.checkpoint_at else []) + [ train_params.steps ] + model, output_filenames = model_fit( + model, + train_params, + checkpoints.copy(), + optimizer, + dataset, + callbacks, + output_filename, + rewrite, + replace, + input_name, + output_name, + model_is_quantized, + output_filenames, + ) + + if model_is_quantized and is_qat: + model = rewrite.pruning_preserved_quantization( # type: ignore[attr-defined] + model, + ) + optimizer = keras.optimizers.Nadam(learning_rate=train_params.learning_rate) + model = model_compile(model, optimizer, loss_fn) + + callbacks.pop(-1) + output_filenames = [] + + model, output_filenames = model_fit( + model, + train_params, + checkpoints.copy(), + optimizer, + dataset, + callbacks, + output_filename, + rewrite, + replace, + input_name, + output_name, + model_is_quantized, + output_filenames, + ) + + teacher.close() + return output_filenames + + +def model_compile( + model: tf.keras.Model, optimizer: tf.keras.optimizers, loss_fn: tf.keras.losses +) -> tf.keras.Model: + """Compiles a tflite model.""" + model.compile(optimizer=optimizer, loss=loss_fn, metrics=["mae"]) + return model + + +def model_fit( + model: tf.keras.Model, + train_params: TrainingParameters, + checkpoints: list, + optimizer: tf.optimizers.Nadam, + dataset: tf.data.Dataset, + callbacks: list, + output_filename: Path, + rewrite: Callable, + replace: TFLiteModel, + input_name: str, + output_name: str, + model_is_quantized: bool, + output_filenames: list, +) -> tuple[tf.keras.Model, list]: + """Train the model.""" + steps_so_far = 0 while steps_so_far < train_params.steps: steps_to_train = checkpoints.pop(0) - steps_so_far lr_start = optimizer.learning_rate.numpy() @@ -460,15 +536,16 @@ def train_in_dir( ) else: checkpoint_filename = str(output_filename) + + if steps_so_far == train_params.steps: + model = rewrite.post_process(model) # type: ignore[attr-defined] + with log_action( f"{steps_so_far}/{train_params.steps}: Saved as {checkpoint_filename}" ): - if steps_so_far == train_params.steps: - model = rewrite.post_process(model) # type: ignore[attr-defined] - save_as_tflite( model, - checkpoint_filename, + str(checkpoint_filename), input_name, replace.shape_from_name[input_name], output_name, @@ -476,9 +553,7 @@ def train_in_dir( model_is_quantized, ) output_filenames.append(checkpoint_filename) - - teacher.close() - return output_filenames + return model, output_filenames def save_as_tflite( -- cgit v1.2.1