path: root/src/mlia/nn/rewrite/core/train.py
Diffstat (limited to 'src/mlia/nn/rewrite/core/train.py')
-rw-r--r--  src/mlia/nn/rewrite/core/train.py  95
1 file changed, 85 insertions, 10 deletions
diff --git a/src/mlia/nn/rewrite/core/train.py b/src/mlia/nn/rewrite/core/train.py
index 89de880..4b9821c 100644
--- a/src/mlia/nn/rewrite/core/train.py
+++ b/src/mlia/nn/rewrite/core/train.py
@@ -1,6 +1,7 @@
# SPDX-FileCopyrightText: Copyright 2023-2024, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Sequential trainer."""
+# pylint: disable=too-many-arguments
# pylint: disable=too-many-locals
# pylint: disable=too-many-statements
from __future__ import annotations
@@ -80,6 +81,7 @@ def train(
rewrite: Callable,
input_tensors: list,
output_tensors: list,
+ is_qat: bool,
train_params: TrainingParameters = TrainingParameters(),
) -> Any:
"""Extract and train a model, and return the results."""
@@ -118,6 +120,7 @@ def train(
baseline_dir=unmodified_model_dir_path,
output_filename=Path(train_dir, "new.tflite"),
rewrite=rewrite,
+ is_qat=is_qat,
train_params=train_params,
)
@@ -345,6 +348,7 @@ def train_in_dir(
baseline_dir: Any,
output_filename: Path,
rewrite: Callable,
+ is_qat: bool,
train_params: TrainingParameters = TrainingParameters(),
) -> list[str]:
"""Train a replacement for replace.tflite using the input.tfrec \
@@ -385,8 +389,9 @@ def train_in_dir(
optimizer = keras.optimizers.Nadam(learning_rate=train_params.learning_rate)
loss_fn = keras.losses.MeanSquaredError()
- model = rewrite.quantize(model, model_is_quantized) # type: ignore[attr-defined]
- model.compile(optimizer=optimizer, loss=loss_fn, metrics=["mae"])
+ if model_is_quantized:
+ model = rewrite.quantize(model) # type: ignore[attr-defined]
+ model = model_compile(model, optimizer, loss_fn)
logger.info(model.summary())
@@ -428,11 +433,82 @@ def train_in_dir(
callbacks.extend(rewrite.training_callbacks()) # type: ignore[attr-defined]
- output_filenames = []
+ output_filenames = [] # type: list[str]
checkpoints = (train_params.checkpoint_at if train_params.checkpoint_at else []) + [
train_params.steps
]
+ model, output_filenames = model_fit(
+ model,
+ train_params,
+ checkpoints.copy(),
+ optimizer,
+ dataset,
+ callbacks,
+ output_filename,
+ rewrite,
+ replace,
+ input_name,
+ output_name,
+ model_is_quantized,
+ output_filenames,
+ )
+
+ if model_is_quantized and is_qat:
+ model = rewrite.pruning_preserved_quantization( # type: ignore[attr-defined]
+ model,
+ )
+ optimizer = keras.optimizers.Nadam(learning_rate=train_params.learning_rate)
+ model = model_compile(model, optimizer, loss_fn)
+
+ callbacks.pop(-1)
+ output_filenames = []
+
+ model, output_filenames = model_fit(
+ model,
+ train_params,
+ checkpoints.copy(),
+ optimizer,
+ dataset,
+ callbacks,
+ output_filename,
+ rewrite,
+ replace,
+ input_name,
+ output_name,
+ model_is_quantized,
+ output_filenames,
+ )
+
+ teacher.close()
+ return output_filenames
+
+
+def model_compile(
+ model: tf.keras.Model, optimizer: tf.keras.optimizers, loss_fn: tf.keras.losses
+) -> tf.keras.Model:
+ """Compiles a tflite model."""
+ model.compile(optimizer=optimizer, loss=loss_fn, metrics=["mae"])
+ return model
+
+
+def model_fit(
+ model: tf.keras.Model,
+ train_params: TrainingParameters,
+ checkpoints: list,
+ optimizer: tf.optimizers.Nadam,
+ dataset: tf.data.Dataset,
+ callbacks: list,
+ output_filename: Path,
+ rewrite: Callable,
+ replace: TFLiteModel,
+ input_name: str,
+ output_name: str,
+ model_is_quantized: bool,
+ output_filenames: list,
+) -> tuple[tf.keras.Model, list]:
+ """Train the model."""
+ steps_so_far = 0
while steps_so_far < train_params.steps:
steps_to_train = checkpoints.pop(0) - steps_so_far
lr_start = optimizer.learning_rate.numpy()
@@ -460,15 +536,16 @@ def train_in_dir(
)
else:
checkpoint_filename = str(output_filename)
+
+ if steps_so_far == train_params.steps:
+ model = rewrite.post_process(model) # type: ignore[attr-defined]
+
with log_action(
f"{steps_so_far}/{train_params.steps}: Saved as {checkpoint_filename}"
):
- if steps_so_far == train_params.steps:
- model = rewrite.post_process(model) # type: ignore[attr-defined]
-
save_as_tflite(
model,
- checkpoint_filename,
+ str(checkpoint_filename),
input_name,
replace.shape_from_name[input_name],
output_name,
@@ -476,9 +553,7 @@ def train_in_dir(
model_is_quantized,
)
output_filenames.append(checkpoint_filename)
-
- teacher.close()
- return output_filenames
+ return model, output_filenames
def save_as_tflite(
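
The sketch below is a minimal, self-contained illustration of the two-phase training flow this patch introduces in train_in_dir(): compile and fit once, then, when the model is quantized and the new is_qat flag is set, apply a QAT transform, rebuild the optimizer, re-compile, and fit again. The rewrite object, dataset, and QAT transform from the patch are replaced by hypothetical stand-ins (apply_qat_transform, the toy model and data); this is not the MLIA API itself, only the control flow under those assumptions.

    # Minimal sketch of the two-phase flow added by this change (stand-ins only).
    import numpy as np
    import tensorflow as tf
    from tensorflow import keras


    def model_compile(model, optimizer, loss_fn):
        """Compile the model with MAE reported as a metric, as in the patch."""
        model.compile(optimizer=optimizer, loss=loss_fn, metrics=["mae"])
        return model


    def apply_qat_transform(model):
        """Hypothetical stand-in for rewrite.pruning_preserved_quantization(); no-op here."""
        return model


    # Toy replacement model and data in place of the extracted TFLite sub-graph.
    model = keras.Sequential([keras.Input(shape=(8,)), keras.layers.Dense(4)])
    dataset = tf.data.Dataset.from_tensor_slices(
        (
            np.random.rand(32, 8).astype("float32"),
            np.random.rand(32, 4).astype("float32"),
        )
    ).batch(8)

    loss_fn = keras.losses.MeanSquaredError()
    optimizer = keras.optimizers.Nadam(learning_rate=1e-3)
    model = model_compile(model, optimizer, loss_fn)

    # Phase 1: corresponds to the first model_fit() call in the patch.
    model.fit(dataset, epochs=1)

    model_is_quantized = True  # assumed; detected elsewhere in train_in_dir()
    is_qat = True              # the new flag threaded through train()

    if model_is_quantized and is_qat:
        # Phase 2: QAT pass with a fresh optimizer, mirroring the second
        # model_fit() call after the quantization-preserving transform.
        model = apply_qat_transform(model)
        optimizer = keras.optimizers.Nadam(learning_rate=1e-3)
        model = model_compile(model, optimizer, loss_fn)
        model.fit(dataset, epochs=1)

In the patch itself the second pass additionally resets output_filenames and drops the last training callback before refitting, so only the checkpoints from the QAT pass are returned.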