Diffstat (limited to 'src/mlia/nn/rewrite/core/train.py')
-rw-r--r--  src/mlia/nn/rewrite/core/train.py | 306
 1 file changed, 190 insertions(+), 116 deletions(-)
diff --git a/src/mlia/nn/rewrite/core/train.py b/src/mlia/nn/rewrite/core/train.py
index c8497a4..42bf653 100644
--- a/src/mlia/nn/rewrite/core/train.py
+++ b/src/mlia/nn/rewrite/core/train.py
@@ -1,8 +1,8 @@
# SPDX-FileCopyrightText: Copyright 2023, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Sequential trainer."""
-# pylint: disable=too-many-arguments, too-many-instance-attributes,
-# pylint: disable=too-many-locals, too-many-branches, too-many-statements
+# pylint: disable=too-many-locals
+# pylint: disable=too-many-statements
from __future__ import annotations
import logging
@@ -10,10 +10,13 @@ import math
import os
import tempfile
from collections import defaultdict
+from contextlib import contextmanager
+from dataclasses import dataclass
from pathlib import Path
from typing import Any
from typing import Callable
from typing import cast
+from typing import Generator as GeneratorType
from typing import get_args
from typing import Literal
@@ -27,10 +30,10 @@ from mlia.nn.rewrite.core.graph_edit.join import join_models
from mlia.nn.rewrite.core.graph_edit.record import record_model
from mlia.nn.rewrite.core.utils.numpy_tfrecord import numpytf_count
from mlia.nn.rewrite.core.utils.numpy_tfrecord import numpytf_read
-from mlia.nn.rewrite.core.utils.numpy_tfrecord import TFLiteModel
from mlia.nn.rewrite.core.utils.parallel import ParallelTFLiteModel
-from mlia.nn.rewrite.core.utils.utils import load
-from mlia.nn.rewrite.core.utils.utils import save
+from mlia.nn.tensorflow.config import TFLiteModel
+from mlia.nn.tensorflow.tflite_graph import load_fb
+from mlia.nn.tensorflow.tflite_graph import save_fb
from mlia.utils.logging import log_action
@@ -38,7 +41,7 @@ os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
logger = logging.getLogger(__name__)
-augmentation_presets = {
+AUGMENTATION_PRESETS = {
"none": (None, None),
"gaussian": (None, 1.0),
"mixup": (1.0, None),
@@ -51,6 +54,21 @@ LearningRateSchedule = Literal["cosine", "late", "constant"]
LEARNING_RATE_SCHEDULES = get_args(LearningRateSchedule)
+@dataclass
+class TrainingParameters:
+ """Define default parameters for the training."""
+
+ augmentations: tuple[float | None, float | None] = AUGMENTATION_PRESETS["gaussian"]
+ batch_size: int = 32
+ steps: int = 48000
+ learning_rate: float = 1e-3
+ learning_rate_schedule: LearningRateSchedule = "cosine"
+ num_procs: int = 1
+ num_threads: int = 0
+ show_progress: bool = True
+ checkpoint_at: list | None = None
+
+
def train(
source_model: str,
unmodified_model: Any,
@@ -59,16 +77,7 @@ def train(
replace_fn: Callable,
input_tensors: list,
output_tensors: list,
- augment: tuple[float | None, float | None],
- steps: int,
- learning_rate: float,
- batch_size: int,
- verbose: bool,
- show_progress: bool,
- learning_rate_schedule: LearningRateSchedule = "cosine",
- checkpoint_at: list | None = None,
- num_procs: int = 1,
- num_threads: int = 0,
+ train_params: TrainingParameters = TrainingParameters(),
) -> Any:
"""Extract and train a model, and return the results."""
if unmodified_model:
@@ -95,29 +104,27 @@ def train(
input_tfrec,
input_tensors,
output_tensors,
- num_procs=num_procs,
- num_threads=num_threads,
+ num_procs=train_params.num_procs,
+ num_threads=train_params.num_threads,
)
tflite_filenames = train_in_dir(
- train_dir,
- unmodified_model_dir_path,
- Path(train_dir, "new.tflite"),
- replace_fn,
- augment,
- steps,
- learning_rate,
- batch_size,
- checkpoint_at=checkpoint_at,
- verbose=verbose,
- show_progress=show_progress,
- num_procs=num_procs,
- num_threads=num_threads,
- schedule=learning_rate_schedule,
+ train_dir=train_dir,
+ baseline_dir=unmodified_model_dir_path,
+ output_filename=Path(train_dir, "new.tflite"),
+ replace_fn=replace_fn,
+ train_params=train_params,
)
for i, filename in enumerate(tflite_filenames):
- results.append(eval_in_dir(train_dir, filename, num_procs, num_threads))
+ results.append(
+ eval_in_dir(
+ train_dir,
+ filename,
+ train_params.num_procs,
+ train_params.num_threads,
+ )
+ )
if output_model:
if i + 1 < len(tflite_filenames):
@@ -133,7 +140,7 @@ def train(
cast(tempfile.TemporaryDirectory, unmodified_model_dir).cleanup()
return (
- results if checkpoint_at else results[0]
+ results if train_params.checkpoint_at else results[0]
) # only return a list if multiple checkpoints are asked for
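
For reference, a minimal sketch of calling train() with the new TrainingParameters dataclass; the paths, tensor names and replace_fn below are illustrative placeholders, and the output_model/input_tfrec keyword names are taken from the surrounding code:

    params = TrainingParameters(
        augmentations=AUGMENTATION_PRESETS["mixup"],
        batch_size=64,
        steps=12000,
        learning_rate_schedule="late",
        checkpoint_at=[4000, 8000],
    )
    results = train(
        source_model="model.tflite",          # illustrative path
        unmodified_model=None,
        output_model="rewritten.tflite",      # illustrative path
        input_tfrec="input.tfrec",            # illustrative path
        replace_fn=build_replacement,         # hypothetical: (input_shape, output_shape) -> Keras model
        input_tensors=["conv1_out"],          # illustrative tensor names
        output_tensors=["conv2_out"],
        train_params=params,
    )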
@@ -176,46 +183,24 @@ def join_in_dir(model_dir: str, new_part: str, output_model: str) -> None:
join_models(Path(model_dir, "start.tflite"), new_end, output_model)
-def train_in_dir(
- train_dir: str,
- baseline_dir: Any,
- output_filename: Path,
- replace_fn: Callable,
- augmentations: tuple[float | None, float | None],
- steps: int,
- learning_rate: float = 1e-3,
- batch_size: int = 32,
- checkpoint_at: list | None = None,
- schedule: str = "cosine",
- verbose: bool = False,
- show_progress: bool = False,
- num_procs: int = 0,
- num_threads: int = 1,
-) -> list:
- """Train a replacement for replace.tflite using the input.tfrec \
- and output.tfrec in train_dir.
-
- If baseline_dir is provided, train the replacement to match baseline
- outputs for train_dir inputs. Result saved as new.tflite in train_dir.
- """
- teacher_dir = baseline_dir if baseline_dir else train_dir
- teacher = ParallelTFLiteModel(
- f"{teacher_dir}/replace.tflite", num_procs, num_threads, batch_size=batch_size
- )
- replace = TFLiteModel(f"{train_dir}/replace.tflite")
+def _get_io_tensors(model: TFLiteModel) -> tuple[str, str]:
assert (
- len(teacher.input_tensors()) == 1
+ len(model.input_tensors()) == 1
), f"Can only train replacements with a single input tensor right now, \
- found {teacher.input_tensors()}"
+ found {model.input_tensors()}"
assert (
- len(teacher.output_tensors()) == 1
+ len(model.output_tensors()) == 1
), f"Can only train replacements with a single output tensor right now, \
- found {teacher.output_tensors()}"
+ found {model.output_tensors()}"
+
+ input_name = model.input_tensors()[0]
+ output_name = model.output_tensors()[0]
+ return (input_name, output_name)
- input_name = teacher.input_tensors()[0]
- output_name = teacher.output_tensors()[0]
+def _check_model_compatibility(teacher: TFLiteModel, replace: TFLiteModel) -> None:
+ """Assert that teacher and replaced sub-graph are compatible."""
assert len(teacher.shape_from_name) == len(
replace.shape_from_name
), f"Baseline and train models must have the same number of inputs and outputs. \
@@ -230,10 +215,37 @@ def train_in_dir(
subgraph being replaced. Teacher: {teacher.shape_from_name}\n \
Train dir: {replace.shape_from_name}"
+
+def set_up_data_pipeline(
+ teacher: TFLiteModel,
+ replace: TFLiteModel,
+ train_dir: str,
+ augmentations: tuple[float | None, float | None],
+ steps: int,
+ batch_size: int = 32,
+) -> tf.data.Dataset:
+ """Create a data pipeline for training of the replacement model."""
+ _check_model_compatibility(teacher, replace)
+
+ input_name, output_name = _get_io_tensors(teacher)
+
input_filename = Path(train_dir, "input.tfrec")
total = numpytf_count(str(input_filename))
dict_inputs = numpytf_read(str(input_filename))
+
inputs = dict_inputs.map(lambda d: tf.squeeze(d[input_name], axis=0))
+
+ steps_per_epoch = math.ceil(total / batch_size)
+ epochs = int(math.ceil(steps / steps_per_epoch))
+ logger.info(
+ "Training on %d items for %d steps (%d epochs with batch size %d)",
+ total,
+ epochs * steps_per_epoch,
+ epochs,
+ batch_size,
+ )
+
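
Worked through with illustrative sizes, the epoch arithmetic above gives:

    import math

    total, batch_size, steps = 1000, 32, 48000           # illustrative values
    steps_per_epoch = math.ceil(total / batch_size)      # 32
    epochs = int(math.ceil(steps / steps_per_epoch))     # 1500
    # -> "Training on 1000 items for 48000 steps (1500 epochs with batch size 32)"
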
+ teacher_dir = Path(teacher.model_path).parent
if any(augmentations):
# Map the teacher inputs here because the augmentation stage passes these
# through a TFLite model to get the outputs
@@ -245,17 +257,6 @@ def train_in_dir(
lambda d: tf.squeeze(d[output_name], axis=0)
)
- steps_per_epoch = math.ceil(total / batch_size)
- epochs = int(math.ceil(steps / steps_per_epoch))
- if verbose:
- logger.info(
- "Training on %d items for %d steps (%d epochs with batch size %d)",
- total,
- epochs * steps_per_epoch,
- epochs,
- batch_size,
- )
-
dataset = tf.data.Dataset.zip((inputs, teacher_outputs))
if epochs > 1:
dataset = dataset.cache()
@@ -268,10 +269,9 @@ def train_in_dir(
train: Any, teach: Any # pylint: disable=redefined-outer-name
) -> tuple:
"""Return results of train and teach based on augmentations."""
- return (
- augment_train({input_name: train})[input_name],
- teacher(augment_teacher({input_name: teach}))[output_name],
- )
+ augmented_train = augment_train({input_name: train})[input_name]
+ augmented_teach = teacher(augment_teacher({input_name: teach}))[output_name]
+ return (augmented_train, augmented_teach)
dataset = dataset.map(
lambda augment_train, augment_teach: tf.py_function(
@@ -281,18 +281,67 @@ def train_in_dir(
)
)
+ # Restore the tensor shapes of the dataset: they default to unknown and
+ # get lost during augmentation with tf.py_function.
+ shape_in, shape_out = (
+ teacher.shape_from_name[name].tolist() for name in (input_name, output_name)
+ )
+ for shape in (shape_in, shape_out):
+ shape[0] = None # set dynamic batch size
+
+ def restore_shapes(input_: Any, output: Any) -> tuple[Any, Any]:
+ input_.set_shape(shape_in)
+ output.set_shape(shape_out)
+ return input_, output
+
+ dataset = dataset.map(restore_shapes)
dataset = dataset.prefetch(tf.data.AUTOTUNE)
+ return dataset
+
+
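
The restore_shapes mapping above is needed because tf.py_function discards static shape information; a self-contained sketch of the effect (shapes illustrative):

    import tensorflow as tf

    ds = tf.data.Dataset.from_tensor_slices(tf.zeros((8, 4)))
    ds = ds.map(lambda x: tf.py_function(lambda t: t + 1.0, [x], tf.float32))
    print(ds.element_spec.shape)  # <unknown> after tf.py_function
    ds = ds.map(lambda x: tf.ensure_shape(x, (4,)))  # same effect as set_shape above
    print(ds.element_spec.shape)  # (4,)
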
+def train_in_dir(
+ train_dir: str,
+ baseline_dir: Any,
+ output_filename: Path,
+ replace_fn: Callable,
+ train_params: TrainingParameters = TrainingParameters(),
+) -> list[str]:
+ """Train a replacement for replace.tflite using the input.tfrec \
+ and output.tfrec in train_dir.
+
+ If baseline_dir is provided, train the replacement to match baseline
+ outputs for train_dir inputs. Result saved as new.tflite in train_dir.
+ """
+ teacher_dir = baseline_dir if baseline_dir else train_dir
+ teacher = ParallelTFLiteModel(
+ f"{teacher_dir}/replace.tflite",
+ train_params.num_procs,
+ train_params.num_threads,
+ batch_size=train_params.batch_size,
+ )
+ replace = TFLiteModel(f"{train_dir}/replace.tflite")
+
+ input_name, output_name = _get_io_tensors(teacher)
+
+ dataset = set_up_data_pipeline(
+ teacher,
+ replace,
+ train_dir,
+ augmentations=train_params.augmentations,
+ steps=train_params.steps,
+ batch_size=train_params.batch_size,
+ )
input_shape = teacher.shape_from_name[input_name][1:]
output_shape = teacher.shape_from_name[output_name][1:]
+
model = replace_fn(input_shape, output_shape)
- optimizer = tf.keras.optimizers.Nadam(learning_rate=learning_rate)
+ optimizer = tf.keras.optimizers.Nadam(learning_rate=train_params.learning_rate)
loss_fn = tf.keras.losses.MeanSquaredError()
- model.compile(optimizer=optimizer, loss=loss_fn)
+ model.compile(optimizer=optimizer, loss=loss_fn, metrics=["mae"])
- if verbose:
- model.summary()
+ model.summary(print_fn=logger.info)  # summary() returns None, so route its text through the logger
steps_so_far = 0
@@ -302,7 +351,9 @@ def train_in_dir(
"""Cosine decay from learning rate at start of the run to zero at the end."""
current_step = epoch_step + steps_so_far
cd_learning_rate = (
- learning_rate * (math.cos(math.pi * current_step / steps) + 1) / 2.0
+ train_params.learning_rate
+ * (math.cos(math.pi * current_step / train_params.steps) + 1)
+ / 2.0
)
tf.keras.backend.set_value(optimizer.learning_rate, cd_learning_rate)
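
As a sanity check, the cosine schedule starts at learning_rate, halves mid-run and decays to zero at the final step:

    import math

    lr0, total_steps = 1e-3, 48000
    for step in (0, total_steps // 2, total_steps):
        print(lr0 * (math.cos(math.pi * step / total_steps) + 1) / 2)
    # ~0.001, ~0.0005, 0.0
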
@@ -311,28 +362,29 @@ def train_in_dir(
) -> None:
"""Constant until the last 20% of the run, then linear decay to zero."""
current_step = epoch_step + steps_so_far
- steps_remaining = steps - current_step
- decay_length = steps // 5
+ steps_remaining = train_params.steps - current_step
+ decay_length = train_params.steps // 5
decay_fraction = min(steps_remaining, decay_length) / decay_length
- ld_learning_rate = learning_rate * decay_fraction
+ ld_learning_rate = train_params.learning_rate * decay_fraction
tf.keras.backend.set_value(optimizer.learning_rate, ld_learning_rate)
- if schedule == "cosine":
+ assert train_params.learning_rate_schedule in LEARNING_RATE_SCHEDULES, (
+ f'Learning rate schedule "{train_params.learning_rate_schedule}" '
+ f"not implemented - expected one of {LEARNING_RATE_SCHEDULES}."
+ )
+ if train_params.learning_rate_schedule == "cosine":
callbacks = [tf.keras.callbacks.LambdaCallback(on_batch_begin=cosine_decay)]
- elif schedule == "late":
+ elif train_params.learning_rate_schedule == "late":
callbacks = [tf.keras.callbacks.LambdaCallback(on_batch_begin=late_decay)]
- elif schedule == "constant":
+ elif train_params.learning_rate_schedule == "constant":
callbacks = []
- else:
- assert schedule not in LEARNING_RATE_SCHEDULES
- raise ValueError(
- f'Learning rate schedule "{schedule}" not implemented - '
- f"expected one of {LEARNING_RATE_SCHEDULES}."
- )
output_filenames = []
- checkpoints = (checkpoint_at if checkpoint_at else []) + [steps]
- while steps_so_far < steps:
+ checkpoints = (train_params.checkpoint_at if train_params.checkpoint_at else []) + [
+ train_params.steps
+ ]
+
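
The checkpoint list drives the fit loop below in segments; with illustrative values:

    checkpoints = [1000, 2500] + [48000]      # checkpoint_at + final step count
    steps_so_far, segments = 0, []
    while steps_so_far < 48000:
        steps_to_train = checkpoints.pop(0) - steps_so_far
        segments.append(steps_to_train)
        steps_so_far += steps_to_train
    print(segments)  # [1000, 1500, 45500] -> one model.fit() call each
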
+ while steps_so_far < train_params.steps:
steps_to_train = checkpoints.pop(0) - steps_so_far
lr_start = optimizer.learning_rate.numpy()
model.fit(
@@ -340,7 +392,7 @@ def train_in_dir(
epochs=1,
steps_per_epoch=steps_to_train,
callbacks=callbacks,
- verbose=show_progress,
+ verbose=train_params.show_progress,
)
steps_so_far += steps_to_train
logger.info(
@@ -350,12 +402,14 @@ def train_in_dir(
steps_to_train,
)
- if steps_so_far < steps:
+ if steps_so_far < train_params.steps:
# Derive "<name>_@<steps>" for intermediate checkpoint filenames
filename, ext = os.path.splitext(str(output_filename))
checkpoint_filename = f"{filename}_@{steps_so_far}{ext}"
else:
checkpoint_filename = str(output_filename)
- with log_action(f"{steps_so_far}/{steps}: Saved as {checkpoint_filename}"):
+ with log_action(
+ f"{steps_so_far}/{train_params.steps}: Saved as {checkpoint_filename}"
+ ):
save_as_tflite(
model,
checkpoint_filename,
@@ -379,14 +433,30 @@ def save_as_tflite(
output_shape: list,
) -> None:
"""Save Keras model as TFLite file."""
- converter = tf.lite.TFLiteConverter.from_keras_model(keras_model)
+
+ @contextmanager
+ def fixed_input(keras_model: tf.keras.Model, tmp_shape: list) -> GeneratorType:
+ """Fix the input shape of the Keras model temporarily.
+
+ This avoids artifacts during conversion to TensorFlow Lite.
+ """
+ orig_shape = keras_model.input.shape
+ keras_model.input.set_shape(tf.TensorShape(tmp_shape))
+ try:
+ yield keras_model
+ finally:
+ # Restore original shape to not interfere with further training
+ keras_model.input.set_shape(orig_shape)
+
+ with fixed_input(keras_model, input_shape) as fixed_model:
+ converter = tf.lite.TFLiteConverter.from_keras_model(fixed_model)
tflite_model = converter.convert()
with open(filename, "wb") as file:
file.write(tflite_model)
# Now fix the shapes and names to match those we expect
- flatbuffer = load(filename)
+ flatbuffer = load_fb(filename)
i = flatbuffer.subgraphs[0].inputs[0]
flatbuffer.subgraphs[0].tensors[i].shape = np.array(input_shape, dtype=np.int32)
flatbuffer.subgraphs[0].tensors[i].name = input_name.encode("utf-8")
@@ -395,11 +465,11 @@ def save_as_tflite(
output_shape, dtype=np.int32
)
flatbuffer.subgraphs[0].tensors[output].name = output_name.encode("utf-8")
- save(flatbuffer, filename)
+ save_fb(flatbuffer, filename)
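
The fixed_input helper above is the standard try/finally restore pattern; a generic stand-alone sketch of the same idea (the helper name is hypothetical):

    from contextlib import contextmanager
    from typing import Any, Generator

    @contextmanager
    def temporarily(obj: Any, attr: str, value: Any) -> Generator:
        """Set obj.attr to value for the duration of the block, then restore it."""
        original = getattr(obj, attr)
        setattr(obj, attr, value)
        try:
            yield obj
        finally:
            setattr(obj, attr, original)  # restored even if the block raises
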
def augment_fn_twins(
- inputs: dict, augmentations: tuple[float | None, float | None]
+ inputs: tf.data.Dataset, augmentations: tuple[float | None, float | None]
) -> Any:
"""Return a pair of twinned augmentation functions with the same sequence \
of random numbers."""
@@ -415,6 +485,11 @@ def augment_fn(
inputs: Any, augmentations: tuple[float | None, float | None], rng: Generator
) -> Any:
"""Augmentation module."""
+ assert len(augmentations) == 2, (
+ f"Unexpected number of augmentation parameters: {len(augmentations)} "
+ "(must be 2)"
+ )
+
mixup_strength, gaussian_strength = augmentations
augments = []
@@ -449,17 +524,16 @@ def augment_fn(
augments.append(gaussian_strength_augment)
- if len(augments) == 0: # pylint: disable=no-else-return
+ if len(augments) == 0:
return lambda x: x
- elif len(augments) == 1:
+ if len(augments) == 1:
return augments[0]
- elif len(augments) == 2:
+ if len(augments) == 2:
return lambda x: augments[1](augments[0](x))
- else:
- assert (
- False
- ), f"Unexpected number of augmentation \
- functions ({len(augments)})"
+
+ raise RuntimeError(
+ "Unexpected number of augmentation functions (must be <=2): " f"{len(augments)}"
+ )
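
With both augmentations enabled, the returned callable composes them as augments[1](augments[0](x)). A hedged usage sketch of the twin variant the data pipeline relies on:

    # Twin augmenters share one random sequence, so student and teacher
    # inputs receive identical augmentation noise per batch.
    augment_train, augment_teacher = augment_fn_twins(
        dict_inputs,                       # tf.data.Dataset of named input tensors
        AUGMENTATION_PRESETS["gaussian"],  # (mixup_strength, gaussian_strength)
    )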
def mixup(rng: Generator, batch: Any, beta_range: tuple = (0.0, 1.0)) -> Any: