diff options
author | Madeleine Dunn <madeleine.dunn@arm.com> | 2023-11-13 15:40:21 +0000 |
---|---|---|
committer | Madeleine Dunn <madeleine.dunn@arm.com> | 2024-04-03 16:33:39 +0100 |
commit | 17813ba5be09f0e11fc0748afa4ccf2da02881b6 (patch) | |
tree | 8ec5f3ce3501b86e9398cf5af6f7bd9876685512 /src/mlia/nn/rewrite/core/train.py | |
parent | 2a2a910d6d7cc3e7555b0a3c1ba458a4065c41ae (diff) | |
download | mlia-17813ba5be09f0e11fc0748afa4ccf2da02881b6.tar.gz |
feat: Implement fp32 sparsity 2:4 rewrite
- Update the existing placeholder with code to prune the given model
Resolves: MLIA-1002
Signed-off-by: Madeleine Dunn <madeleine.dunn@arm.com>
Change-Id: I76b0e0bfe81be5e57d518cd7bb588eef76a11641
Diffstat (limited to 'src/mlia/nn/rewrite/core/train.py')
-rw-r--r-- | src/mlia/nn/rewrite/core/train.py | 19 |
1 files changed, 11 insertions, 8 deletions
diff --git a/src/mlia/nn/rewrite/core/train.py b/src/mlia/nn/rewrite/core/train.py
index e0b3c75..89de880 100644
--- a/src/mlia/nn/rewrite/core/train.py
+++ b/src/mlia/nn/rewrite/core/train.py
@@ -22,7 +22,6 @@ from typing import Literal
 
 import numpy as np
 import tensorflow as tf
-import tensorflow_model_optimization as tfmot
 from keras.api._v2 import keras  # Temporary workaround for now: MLIA-1107
 from numpy.random import Generator
 
@@ -78,7 +77,7 @@ def train(
     unmodified_model: Any,
     output_model: str,
     input_tfrec: str,
-    replace_fn: Callable,
+    rewrite: Callable,
     input_tensors: list,
     output_tensors: list,
     train_params: TrainingParameters = TrainingParameters(),
@@ -118,7 +117,7 @@ def train(
             train_dir=train_dir,
             baseline_dir=unmodified_model_dir_path,
             output_filename=Path(train_dir, "new.tflite"),
-            replace_fn=replace_fn,
+            rewrite=rewrite,
             train_params=train_params,
         )
 
@@ -345,7 +344,7 @@ def train_in_dir(
     train_dir: str,
     baseline_dir: Any,
     output_filename: Path,
-    replace_fn: Callable,
+    rewrite: Callable,
     train_params: TrainingParameters = TrainingParameters(),
 ) -> list[str]:
     """Train a replacement for replace.tflite using the input.tfrec \
@@ -381,13 +380,12 @@ def train_in_dir(
 
     input_shape = teacher.shape_from_name[input_name][1:]
     output_shape = teacher.shape_from_name[output_name][1:]
-
-    model = replace_fn(input_shape, output_shape)
+    model = rewrite(input_shape, output_shape)
 
     optimizer = keras.optimizers.Nadam(learning_rate=train_params.learning_rate)
     loss_fn = keras.losses.MeanSquaredError()
-    if model_is_quantized:
-        model = tfmot.quantization.keras.quantize_model(model)
+
+    model = rewrite.quantize(model, model_is_quantized)  # type: ignore[attr-defined]
     model.compile(optimizer=optimizer, loss=loss_fn, metrics=["mae"])
 
     logger.info(model.summary())
@@ -428,6 +426,8 @@ def train_in_dir(
     elif train_params.learning_rate_schedule == "constant":
         callbacks = []
 
+    callbacks.extend(rewrite.training_callbacks())  # type: ignore[attr-defined]
+
     output_filenames = []
     checkpoints = (train_params.checkpoint_at if train_params.checkpoint_at else []) + [
         train_params.steps
@@ -463,6 +463,9 @@ def train_in_dir(
         with log_action(
             f"{steps_so_far}/{train_params.steps}: Saved as {checkpoint_filename}"
         ):
+            if steps_so_far == train_params.steps:
+                model = rewrite.post_process(model)  # type: ignore[attr-defined]
+
             save_as_tflite(
                 model,
                 checkpoint_filename,