aboutsummaryrefslogtreecommitdiff
path: root/src/mlia/nn/rewrite/core/train.py
diff options
context:
space:
mode:
authorMadeleine Dunn <madeleine.dunn@arm.com>2023-11-13 15:40:21 +0000
committerMadeleine Dunn <madeleine.dunn@arm.com>2024-04-03 16:33:39 +0100
commit17813ba5be09f0e11fc0748afa4ccf2da02881b6 (patch)
tree8ec5f3ce3501b86e9398cf5af6f7bd9876685512 /src/mlia/nn/rewrite/core/train.py
parent2a2a910d6d7cc3e7555b0a3c1ba458a4065c41ae (diff)
downloadmlia-17813ba5be09f0e11fc0748afa4ccf2da02881b6.tar.gz
feat: Implement fp32 sparsity 2:4 rewrite
- Update the existing placeholder with code to prune the given model Resolves: MLIA-1002 Signed-off-by: Madeleine Dunn <madeleine.dunn@arm.com> Change-Id: I76b0e0bfe81be5e57d518cd7bb588eef76a11641
Diffstat (limited to 'src/mlia/nn/rewrite/core/train.py')
-rw-r--r--src/mlia/nn/rewrite/core/train.py19
1 files changed, 11 insertions, 8 deletions
diff --git a/src/mlia/nn/rewrite/core/train.py b/src/mlia/nn/rewrite/core/train.py
index e0b3c75..89de880 100644
--- a/src/mlia/nn/rewrite/core/train.py
+++ b/src/mlia/nn/rewrite/core/train.py
@@ -22,7 +22,6 @@ from typing import Literal
import numpy as np
import tensorflow as tf
-import tensorflow_model_optimization as tfmot
from keras.api._v2 import keras # Temporary workaround for now: MLIA-1107
from numpy.random import Generator
@@ -78,7 +77,7 @@ def train(
unmodified_model: Any,
output_model: str,
input_tfrec: str,
- replace_fn: Callable,
+ rewrite: Callable,
input_tensors: list,
output_tensors: list,
train_params: TrainingParameters = TrainingParameters(),
@@ -118,7 +117,7 @@ def train(
train_dir=train_dir,
baseline_dir=unmodified_model_dir_path,
output_filename=Path(train_dir, "new.tflite"),
- replace_fn=replace_fn,
+ rewrite=rewrite,
train_params=train_params,
)
@@ -345,7 +344,7 @@ def train_in_dir(
train_dir: str,
baseline_dir: Any,
output_filename: Path,
- replace_fn: Callable,
+ rewrite: Callable,
train_params: TrainingParameters = TrainingParameters(),
) -> list[str]:
"""Train a replacement for replace.tflite using the input.tfrec \
@@ -381,13 +380,12 @@ def train_in_dir(
input_shape = teacher.shape_from_name[input_name][1:]
output_shape = teacher.shape_from_name[output_name][1:]
-
- model = replace_fn(input_shape, output_shape)
+ model = rewrite(input_shape, output_shape)
optimizer = keras.optimizers.Nadam(learning_rate=train_params.learning_rate)
loss_fn = keras.losses.MeanSquaredError()
- if model_is_quantized:
- model = tfmot.quantization.keras.quantize_model(model)
+
+ model = rewrite.quantize(model, model_is_quantized) # type: ignore[attr-defined]
model.compile(optimizer=optimizer, loss=loss_fn, metrics=["mae"])
logger.info(model.summary())
@@ -428,6 +426,8 @@ def train_in_dir(
elif train_params.learning_rate_schedule == "constant":
callbacks = []
+ callbacks.extend(rewrite.training_callbacks()) # type: ignore[attr-defined]
+
output_filenames = []
checkpoints = (train_params.checkpoint_at if train_params.checkpoint_at else []) + [
train_params.steps
@@ -463,6 +463,9 @@ def train_in_dir(
with log_action(
f"{steps_so_far}/{train_params.steps}: Saved as {checkpoint_filename}"
):
+ if steps_so_far == train_params.steps:
+ model = rewrite.post_process(model) # type: ignore[attr-defined]
+
save_as_tflite(
model,
checkpoint_filename,