Diffstat (limited to 'src/mlia/nn/rewrite/core/train.py')
-rw-r--r--  src/mlia/nn/rewrite/core/train.py | 209
1 file changed, 182 insertions(+), 27 deletions(-)
diff --git a/src/mlia/nn/rewrite/core/train.py b/src/mlia/nn/rewrite/core/train.py
index 60c39ae..4204978 100644
--- a/src/mlia/nn/rewrite/core/train.py
+++ b/src/mlia/nn/rewrite/core/train.py
@@ -1,6 +1,7 @@
# SPDX-FileCopyrightText: Copyright 2023-2024, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Sequential trainer."""
+# pylint: disable=too-many-arguments
# pylint: disable=too-many-locals
# pylint: disable=too-many-statements
from __future__ import annotations
@@ -22,7 +23,6 @@ from typing import Literal
import numpy as np
import tensorflow as tf
-import tensorflow_model_optimization as tfmot
from keras.api._v2 import keras # Temporary workaround for now: MLIA-1107
from numpy.random import Generator
@@ -62,7 +62,7 @@ LEARNING_RATE_SCHEDULES = get_args(LearningRateSchedule)
class TrainingParameters:
"""Define default parameters for the training."""
- augmentations: tuple[float | None, float | None] = AUGMENTATION_PRESETS["gaussian"]
+ augmentations: tuple[float | None, float | None] = AUGMENTATION_PRESETS["none"]
batch_size: int = 32
steps: int = 48000
learning_rate: float = 1e-3
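For orientation, AUGMENTATION_PRESETS maps a preset name to a (mixup_strength, gaussian_strength) pair consumed by the data pipeline's augmentation step; switching the default from "gaussian" to "none" therefore disables augmentation unless a preset is requested explicitly. The table itself sits outside this diff, so the entries below are an illustrative sketch, not the project's exact values:

# Hypothetical sketch of the preset table the new default refers to; the real
# mapping lives elsewhere in train.py and may contain different entries.
AUGMENTATION_PRESETS: dict[str, tuple[float | None, float | None]] = {
    "none": (None, None),      # new default: no mixup, no Gaussian noise
    "gaussian": (None, 1.0),   # previous default: additive Gaussian noise only
    "mixup": (1.0, None),      # blend pairs of training samples
}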
@@ -73,12 +73,13 @@ class TrainingParameters:
checkpoint_at: list | None = None
-def train(
+def train( # pylint: disable=too-many-arguments
source_model: str,
unmodified_model: Any,
output_model: str,
input_tfrec: str,
- replace_fn: Callable,
+ rewrite: Callable,
+ is_qat: bool,
input_tensors: list,
output_tensors: list,
train_params: TrainingParameters = TrainingParameters(),
@@ -118,7 +119,8 @@ def train(
train_dir=train_dir,
baseline_dir=unmodified_model_dir_path,
output_filename=Path(train_dir, "new.tflite"),
- replace_fn=replace_fn,
+ rewrite=rewrite,
+ is_qat=is_qat,
train_params=train_params,
)
@@ -145,7 +147,8 @@ def train(
# Assess the output diff between the parts after the rewrite subgraph
# in original and optimized model
optimized_end_path = Path(train_dir, "optimized_end.tfrec")
- end_path = Path(train_dir, "end.tfrec")
+ optimized_end_path_dequant = Path(train_dir, "optimized_end_dequant.tfrec")
+ end_path = Path(train_dir, "end_dequant.tfrec")
record_model(
str(input_tfrec),
@@ -153,16 +156,18 @@ def train(
optimized_end_path,
num_procs=train_params.num_procs,
num_threads=train_params.num_threads,
+ dequantize_output=True,
)
- mae, nrmse = diff_stats(end_path, str(optimized_end_path))
+
+ mae, nrmse = diff_stats(end_path, optimized_end_path_dequant)
if unmodified_model_dir:
cast(tempfile.TemporaryDirectory, unmodified_model_dir).cleanup()
- return (results if train_params.checkpoint_at else results[0]), [
+ return results, [
mae,
nrmse,
- ] # only return a list if multiple checkpoints are asked for
+ ]
def eval_in_dir(
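The comparisons above now run on dequantized outputs: record_model is asked to emit a float copy of the recorded tensors (dequantize_output=True), and diff_stats compares that copy against the float baseline. As a rough, self-contained sketch of the two statistics involved (the project's exact NRMSE normalisation may differ):

# Illustrative MAE/NRMSE computation over two aligned tensors; diff_stats
# itself streams TFRecord files, so this is a conceptual stand-in only.
import numpy as np

def mae_nrmse(reference: np.ndarray, candidate: np.ndarray) -> tuple[float, float]:
    diff = reference.astype(np.float32) - candidate.astype(np.float32)
    mae = float(np.mean(np.abs(diff)))
    rmse = float(np.sqrt(np.mean(diff ** 2)))
    scale = float(np.std(reference)) or 1.0  # guard against a constant reference
    return mae, rmse / scale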
@@ -177,24 +182,27 @@ def eval_in_dir(
model_input = (
model_input_path
if model_input_path.exists()
- else ExtractPaths.tfrec.input(target_dir, False)
+ else ExtractPaths.tfrec.input(target_dir, True)
)
output = (
model_output_path
if model_output_path.exists()
- else ExtractPaths.tfrec.output(target_dir, False)
+ else ExtractPaths.tfrec.output(target_dir, True)
)
with tempfile.TemporaryDirectory() as tmp_dir:
predict = Path(tmp_dir, "predict.tfrec")
+ predict_dequant = Path(tmp_dir, "predict_dequant.tfrec")
record_model(
str(model_input),
new_part,
str(predict),
num_procs=num_procs,
num_threads=num_threads,
+ dequantize_output=True,
+ quantize_input=True,
)
- mae, nrmse = diff_stats(str(output), str(predict))
+ mae, nrmse = diff_stats(str(output), predict_dequant)
return mae, nrmse
@@ -247,7 +255,7 @@ def set_up_data_pipeline(
augmentations: tuple[float | None, float | None],
steps: int,
batch_size: int = 32,
-) -> tf.data.Dataset:
+) -> tuple[tf.data.Dataset, int]:
"""Create a data pipeline for training of the replacement model."""
_check_model_compatibility(teacher, replace)
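set_up_data_pipeline now returns the steps-per-epoch figure it computes alongside the dataset, so callers can drive fit()/evaluate() without recomputing it. A minimal caller-side sketch; argument names follow the hunks above, and anything not visible in the diff is an assumption:

# Sketch of the new two-value contract; teacher/replace/train_params are the
# objects already in scope in train_in_dir.
dataset, steps_per_epoch = set_up_data_pipeline(
    teacher,
    replace,
    train_dir,
    augmentations=train_params.augmentations,
    steps=train_params.steps,
    batch_size=train_params.batch_size,
)
model.fit(dataset, epochs=1, steps_per_epoch=steps_per_epoch)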
@@ -338,14 +346,15 @@ def set_up_data_pipeline(
dataset = dataset.map(restore_shapes)
dataset = dataset.prefetch(tf.data.AUTOTUNE)
- return dataset
+ return dataset, steps_per_epoch
def train_in_dir(
train_dir: str,
baseline_dir: Any,
output_filename: Path,
- replace_fn: Callable,
+ rewrite: Callable,
+ is_qat: bool,
train_params: TrainingParameters = TrainingParameters(),
) -> list[str]:
"""Train a replacement for replace.tflite using the input.tfrec \
@@ -370,7 +379,7 @@ def train_in_dir(
if model_is_quantized:
replace.check_datatypes(np.int8)
- dataset = set_up_data_pipeline(
+ dataset, steps_per_epoch = set_up_data_pipeline(
teacher,
replace,
train_dir,
@@ -380,15 +389,15 @@ def train_in_dir(
)
input_shape = teacher.shape_from_name[input_name][1:]
- output_shape = teacher.shape_from_name[output_name][1:]
- model = replace_fn(input_shape, output_shape)
+ output_shape = teacher.shape_from_name[output_name][1:]
optimizer = keras.optimizers.Nadam(learning_rate=train_params.learning_rate)
loss_fn = keras.losses.MeanSquaredError()
- if model_is_quantized:
- model = tfmot.quantization.keras.quantize_model(model)
- model.compile(optimizer=optimizer, loss=loss_fn, metrics=["mae"])
+
+ model = create_model(
+ rewrite, input_shape, output_shape, optimizer, loss_fn, model_is_quantized
+ )
logger.info(model.summary())
@@ -428,11 +437,130 @@ def train_in_dir(
elif train_params.learning_rate_schedule == "constant":
callbacks = []
- output_filenames = []
+ callbacks.extend(rewrite.training_callbacks()) # type: ignore[attr-defined]
+ output_filenames: list = []
checkpoints = (train_params.checkpoint_at if train_params.checkpoint_at else []) + [
train_params.steps
]
+ model, output_filenames = model_fit(
+ model,
+ train_params,
+ checkpoints,
+ optimizer,
+ dataset,
+ callbacks,
+ output_filename,
+ rewrite,
+ replace,
+ input_name,
+ output_name,
+ model_is_quantized,
+ output_filenames,
+ input_shape,
+ output_shape,
+ loss_fn,
+ steps_per_epoch,
+ post_process=True,
+ )
+
+ # Placeholder for now, will be parametrized later (MLIA-1114)
+ # rewrite.check_optimization( # type: ignore[attr-defined]
+ # model, number_of_clusters=32
+ # )
+ if model_is_quantized and is_qat:
+ model = rewrite.preserved_quantize(model) # type: ignore[attr-defined]
+ checkpoints = (
+ train_params.checkpoint_at if train_params.checkpoint_at else []
+ ) + [train_params.steps]
+ output_filenames = []
+
+ if len(rewrite.training_callbacks()) > 0 and set( # type: ignore[attr-defined]
+ rewrite.training_callbacks() # type: ignore[attr-defined]
+ ).issubset(callbacks):
+ callbacks.pop(-1)
+
+ optimizer = keras.optimizers.Nadam(learning_rate=train_params.learning_rate)
+ model = model_compile(model, optimizer, loss_fn)
+
+ model, output_filenames = model_fit(
+ model,
+ train_params,
+ checkpoints,
+ optimizer,
+ dataset,
+ callbacks,
+ output_filename,
+ rewrite,
+ replace,
+ input_name,
+ output_name,
+ model_is_quantized,
+ output_filenames,
+ input_shape,
+ output_shape,
+ loss_fn,
+ steps_per_epoch,
+ )
+ # Placeholder for now, will be parametrized later (MLIA-1114)
+ # rewrite.check_optimization( # type: ignore[attr-defined]
+ # model, number_of_clusters=32
+ # )
+
+ teacher.close()
+ return output_filenames
+
+def model_compile(
+ model: keras.Model,
+ optimizer: keras.optimizers.Nadam,
+ loss_fn: keras.losses.Loss,
+) -> keras.Model:
+ """Compiles a tflite model."""
+ model.compile(optimizer=optimizer, loss=loss_fn, metrics=["mae"])
+ return model
+
+
+def create_model( # pylint: disable=too-many-arguments
+ rewrite: Callable,
+ input_shape: Any,
+ output_shape: Any,
+ optimizer: Callable,
+ loss_fn: Callable,
+ model_is_quantized: bool,
+ model_to_load_from: keras.Model | None = None,
+) -> keras.Model:
+ """Create a model, optionally from another."""
+ model = rewrite(input_shape, output_shape)
+ if model_is_quantized:
+ model = rewrite.quantize(model) # type: ignore[attr-defined]
+ model = model_compile(model, optimizer=optimizer, loss_fn=loss_fn)
+ if model_to_load_from:
+ model.set_weights(model_to_load_from.get_weights())
+ return model
+
+
+def model_fit( # pylint: disable=too-many-arguments
+ model: keras.Model,
+ train_params: TrainingParameters,
+ checkpoints: list,
+ optimizer: tf.optimizers.Nadam,
+ dataset: tf.data.Dataset,
+ callbacks: list,
+ output_filename: Path,
+ rewrite: Callable,
+ replace: TFLiteModel,
+ input_name: str,
+ output_name: str,
+ model_is_quantized: bool,
+ output_filenames: list,
+ input_shape: Any,
+ output_shape: Any,
+ loss_fn: Callable,
+ steps_per_epoch: int,
+ post_process: bool = False,
+) -> tuple[keras.Model, list]:
+ """Train a Keras model, saving checkpoints as TFLite along the way."""
+ steps_so_far = 0
while steps_so_far < train_params.steps:
steps_to_train = checkpoints.pop(0) - steps_so_far
lr_start = optimizer.learning_rate.numpy()
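The net effect of the new branch above is a two-phase schedule: fit the float rewrite first (with post_process=True so clustering/pruning wrappers can be stripped at save time), then, for quantized models with is_qat set, re-wrap the weights via rewrite.preserved_quantize and fine-tune again. A generic, self-contained illustration of that quantize-then-fine-tune pattern, using tensorflow_model_optimization directly (the project now hides this behind the rewrite object):

# Generic QAT fine-tuning sketch; tfmot stands in for rewrite.quantize()/
# rewrite.preserved_quantize(), which encapsulate it in the project code.
import tensorflow as tf
import tensorflow_model_optimization as tfmot
from keras.api._v2 import keras

def qat_finetune(model: keras.Model, dataset: tf.data.Dataset, steps: int) -> keras.Model:
    qat_model = tfmot.quantization.keras.quantize_model(model)  # insert fake-quant nodes
    qat_model.compile(
        optimizer=keras.optimizers.Nadam(learning_rate=1e-3),
        loss=keras.losses.MeanSquaredError(),
        metrics=["mae"],
    )
    qat_model.fit(dataset, steps_per_epoch=steps, epochs=1)  # fine-tune with fake-quant
    return qat_model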
@@ -452,15 +580,43 @@ def train_in_dir(
)
if steps_so_far < train_params.steps:
- filename, ext = Path(output_filename).parts[1:]
- checkpoint_filename = filename + (f"_@{steps_so_far}") + ext
+ filename = Path(output_filename).stem
+ filename_dir = Path(output_filename).parent.as_posix()
+ ext = Path(output_filename).suffix
+ checkpoint_filename = (
+ filename_dir + "/" + filename + (f"_@{steps_so_far}") + ext
+ )
+ # If post-processing, the clustering/pruning wrapper layers are stripped
+ # below, so copy the model before saving to let training continue.
+ if post_process:
+ model_to_save = create_model(
+ rewrite,
+ input_shape,
+ output_shape,
+ optimizer,
+ loss_fn,
+ model_is_quantized,
+ model_to_load_from=model,
+ )
+ else:
+ model_to_save = model
else:
checkpoint_filename = str(output_filename)
+ logger.info("Evaluate final Keras Model using %d steps", steps_per_epoch)
+ model.evaluate(
+ dataset,
+ steps=steps_per_epoch,
+ )
+ model_to_save = model
with log_action(
f"{steps_so_far}/{train_params.steps}: Saved as {checkpoint_filename}"
):
+ if post_process:
+ model_to_save = rewrite.post_process( # type: ignore[attr-defined]
+ model_to_save
+ )
save_as_tflite(
- model,
+ model_to_save,
checkpoint_filename,
input_name,
replace.shape_from_name[input_name],
@@ -470,8 +626,7 @@ def train_in_dir(
)
output_filenames.append(checkpoint_filename)
- teacher.close()
- return output_filenames
+ return model_to_save, output_filenames
def save_as_tflite(
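For reference, a hypothetical call site exercising the renamed and added parameters; all paths and the my_rewrite callable are placeholders, not project defaults:

# Hypothetical usage of the updated signature; `my_rewrite` must behave like
# the rewrite objects this module expects (a callable returning a keras.Model,
# plus quantize()/preserved_quantize()/training_callbacks()/post_process()).
results, (mae, nrmse) = train(
    source_model="model.tflite",
    unmodified_model=None,
    output_model="out.tflite",
    input_tfrec="input.tfrec",
    rewrite=my_rewrite,   # was: replace_fn
    is_qat=True,          # new flag: enable quantization-aware fine-tuning
    input_tensors=["input"],
    output_tensors=["output"],
    train_params=TrainingParameters(batch_size=32, steps=48000),
)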