aboutsummaryrefslogtreecommitdiff
path: root/src/mlia/nn/rewrite/core/train.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/mlia/nn/rewrite/core/train.py')
-rw-r--r--src/mlia/nn/rewrite/core/train.py19
1 files changed, 11 insertions, 8 deletions
diff --git a/src/mlia/nn/rewrite/core/train.py b/src/mlia/nn/rewrite/core/train.py
index e0b3c75..89de880 100644
--- a/src/mlia/nn/rewrite/core/train.py
+++ b/src/mlia/nn/rewrite/core/train.py
@@ -22,7 +22,6 @@ from typing import Literal
import numpy as np
import tensorflow as tf
-import tensorflow_model_optimization as tfmot
from keras.api._v2 import keras # Temporary workaround for now: MLIA-1107
from numpy.random import Generator
@@ -78,7 +77,7 @@ def train(
unmodified_model: Any,
output_model: str,
input_tfrec: str,
- replace_fn: Callable,
+ rewrite: Callable,
input_tensors: list,
output_tensors: list,
train_params: TrainingParameters = TrainingParameters(),
@@ -118,7 +117,7 @@ def train(
train_dir=train_dir,
baseline_dir=unmodified_model_dir_path,
output_filename=Path(train_dir, "new.tflite"),
- replace_fn=replace_fn,
+ rewrite=rewrite,
train_params=train_params,
)
@@ -345,7 +344,7 @@ def train_in_dir(
train_dir: str,
baseline_dir: Any,
output_filename: Path,
- replace_fn: Callable,
+ rewrite: Callable,
train_params: TrainingParameters = TrainingParameters(),
) -> list[str]:
"""Train a replacement for replace.tflite using the input.tfrec \
@@ -381,13 +380,12 @@ def train_in_dir(
input_shape = teacher.shape_from_name[input_name][1:]
output_shape = teacher.shape_from_name[output_name][1:]
-
- model = replace_fn(input_shape, output_shape)
+ model = rewrite(input_shape, output_shape)
optimizer = keras.optimizers.Nadam(learning_rate=train_params.learning_rate)
loss_fn = keras.losses.MeanSquaredError()
- if model_is_quantized:
- model = tfmot.quantization.keras.quantize_model(model)
+
+ model = rewrite.quantize(model, model_is_quantized) # type: ignore[attr-defined]
model.compile(optimizer=optimizer, loss=loss_fn, metrics=["mae"])
logger.info(model.summary())
@@ -428,6 +426,8 @@ def train_in_dir(
elif train_params.learning_rate_schedule == "constant":
callbacks = []
+ callbacks.extend(rewrite.training_callbacks()) # type: ignore[attr-defined]
+
output_filenames = []
checkpoints = (train_params.checkpoint_at if train_params.checkpoint_at else []) + [
train_params.steps
@@ -463,6 +463,9 @@ def train_in_dir(
with log_action(
f"{steps_so_far}/{train_params.steps}: Saved as {checkpoint_filename}"
):
+ if steps_so_far == train_params.steps:
+ model = rewrite.post_process(model) # type: ignore[attr-defined]
+
save_as_tflite(
model,
checkpoint_filename,