aboutsummaryrefslogtreecommitdiff
path: root/src/mlia/nn/rewrite/core/train.py
diff options
context:
space:
mode:
authorBenjamin Klimczak <benjamin.klimczak@arm.com>2023-07-19 16:35:57 +0100
committerBenjamin Klimczak <benjamin.klimczak@arm.com>2023-10-11 16:06:17 +0100
commit3cd84481fa25e64c29e57396d4bf32d7a3ca490a (patch)
treead81fb520a965bd3a3c7c983833b7cd48f9b8dea /src/mlia/nn/rewrite/core/train.py
parentf3e6597dd50ec70f043d692b773f2d9fd31519ae (diff)
downloadmlia-3cd84481fa25e64c29e57396d4bf32d7a3ca490a.tar.gz
Bug-fixes and re-factoring for the rewrite module
- Fix input shape of rewrite replacement: During and after training of the replacement model for a rewrite the Keras model is converted and saved in TensorFlow Lite format. If the input shape does not match the teacher model exactly, e.g. if the batch size is undefined, the TFLiteConverter adds extra operators during conversion. - Fix rewritten model output - Save the model output with the rewritten operator in the output dir - Log MAE and NRMSE of the rewrite - Remove 'verbose' flag from rewrite module and rely on the logging mechanism to control verbose output. - Re-factor utility classes for rewrites - Merge the two TFLiteModel classes - Move functionality to load/save TensorFlow Lite flatbuffers to nn/tensorflow/tflite_graph - Fix issue with unknown shape in datasets After upgrading to TensorFlow 2.12 the unknown shape of the TFRecordDataset is causing problems when training the replacement models for rewrites. By explicitly setting the right shape of the tensors we can work around the issue. - Adapt default parameters for rewrites. The training steps especially had to be increased significantly to be effective. Resolves: MLIA-895, MLIA-907, MLIA-946, MLIA-979 Signed-off-by: Benjamin Klimczak <benjamin.klimczak@arm.com> Change-Id: I887ad165aed0f2c6e5a0041f64cec5e6c5ab5c5c
Diffstat (limited to 'src/mlia/nn/rewrite/core/train.py')
-rw-r--r--src/mlia/nn/rewrite/core/train.py306
1 files changed, 190 insertions, 116 deletions
diff --git a/src/mlia/nn/rewrite/core/train.py b/src/mlia/nn/rewrite/core/train.py
index c8497a4..42bf653 100644
--- a/src/mlia/nn/rewrite/core/train.py
+++ b/src/mlia/nn/rewrite/core/train.py
@@ -1,8 +1,8 @@
# SPDX-FileCopyrightText: Copyright 2023, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Sequential trainer."""
-# pylint: disable=too-many-arguments, too-many-instance-attributes,
-# pylint: disable=too-many-locals, too-many-branches, too-many-statements
+# pylint: disable=too-many-locals
+# pylint: disable=too-many-statements
from __future__ import annotations
import logging
@@ -10,10 +10,13 @@ import math
import os
import tempfile
from collections import defaultdict
+from contextlib import contextmanager
+from dataclasses import dataclass
from pathlib import Path
from typing import Any
from typing import Callable
from typing import cast
+from typing import Generator as GeneratorType
from typing import get_args
from typing import Literal
@@ -27,10 +30,10 @@ from mlia.nn.rewrite.core.graph_edit.join import join_models
from mlia.nn.rewrite.core.graph_edit.record import record_model
from mlia.nn.rewrite.core.utils.numpy_tfrecord import numpytf_count
from mlia.nn.rewrite.core.utils.numpy_tfrecord import numpytf_read
-from mlia.nn.rewrite.core.utils.numpy_tfrecord import TFLiteModel
from mlia.nn.rewrite.core.utils.parallel import ParallelTFLiteModel
-from mlia.nn.rewrite.core.utils.utils import load
-from mlia.nn.rewrite.core.utils.utils import save
+from mlia.nn.tensorflow.config import TFLiteModel
+from mlia.nn.tensorflow.tflite_graph import load_fb
+from mlia.nn.tensorflow.tflite_graph import save_fb
from mlia.utils.logging import log_action
@@ -38,7 +41,7 @@ os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
logger = logging.getLogger(__name__)
-augmentation_presets = {
+AUGMENTATION_PRESETS = {
"none": (None, None),
"gaussian": (None, 1.0),
"mixup": (1.0, None),
@@ -51,6 +54,21 @@ LearningRateSchedule = Literal["cosine", "late", "constant"]
LEARNING_RATE_SCHEDULES = get_args(LearningRateSchedule)
+@dataclass
+class TrainingParameters:
+ """Define default parameters for the training."""
+
+ augmentations: tuple[float | None, float | None] = AUGMENTATION_PRESETS["gaussian"]
+ batch_size: int = 32
+ steps: int = 48000
+ learning_rate: float = 1e-3
+ learning_rate_schedule: LearningRateSchedule = "cosine"
+ num_procs: int = 1
+ num_threads: int = 0
+ show_progress: bool = True
+ checkpoint_at: list | None = None
+
+
def train(
source_model: str,
unmodified_model: Any,
@@ -59,16 +77,7 @@ def train(
replace_fn: Callable,
input_tensors: list,
output_tensors: list,
- augment: tuple[float | None, float | None],
- steps: int,
- learning_rate: float,
- batch_size: int,
- verbose: bool,
- show_progress: bool,
- learning_rate_schedule: LearningRateSchedule = "cosine",
- checkpoint_at: list | None = None,
- num_procs: int = 1,
- num_threads: int = 0,
+ train_params: TrainingParameters = TrainingParameters(),
) -> Any:
"""Extract and train a model, and return the results."""
if unmodified_model:
@@ -95,29 +104,27 @@ def train(
input_tfrec,
input_tensors,
output_tensors,
- num_procs=num_procs,
- num_threads=num_threads,
+ num_procs=train_params.num_procs,
+ num_threads=train_params.num_threads,
)
tflite_filenames = train_in_dir(
- train_dir,
- unmodified_model_dir_path,
- Path(train_dir, "new.tflite"),
- replace_fn,
- augment,
- steps,
- learning_rate,
- batch_size,
- checkpoint_at=checkpoint_at,
- verbose=verbose,
- show_progress=show_progress,
- num_procs=num_procs,
- num_threads=num_threads,
- schedule=learning_rate_schedule,
+ train_dir=train_dir,
+ baseline_dir=unmodified_model_dir_path,
+ output_filename=Path(train_dir, "new.tflite"),
+ replace_fn=replace_fn,
+ train_params=train_params,
)
for i, filename in enumerate(tflite_filenames):
- results.append(eval_in_dir(train_dir, filename, num_procs, num_threads))
+ results.append(
+ eval_in_dir(
+ train_dir,
+ filename,
+ train_params.num_procs,
+ train_params.num_threads,
+ )
+ )
if output_model:
if i + 1 < len(tflite_filenames):
@@ -133,7 +140,7 @@ def train(
cast(tempfile.TemporaryDirectory, unmodified_model_dir).cleanup()
return (
- results if checkpoint_at else results[0]
+ results if train_params.checkpoint_at else results[0]
) # only return a list if multiple checkpoints are asked for
@@ -176,46 +183,24 @@ def join_in_dir(model_dir: str, new_part: str, output_model: str) -> None:
join_models(Path(model_dir, "start.tflite"), new_end, output_model)
-def train_in_dir(
- train_dir: str,
- baseline_dir: Any,
- output_filename: Path,
- replace_fn: Callable,
- augmentations: tuple[float | None, float | None],
- steps: int,
- learning_rate: float = 1e-3,
- batch_size: int = 32,
- checkpoint_at: list | None = None,
- schedule: str = "cosine",
- verbose: bool = False,
- show_progress: bool = False,
- num_procs: int = 0,
- num_threads: int = 1,
-) -> list:
- """Train a replacement for replace.tflite using the input.tfrec \
- and output.tfrec in train_dir.
-
- If baseline_dir is provided, train the replacement to match baseline
- outputs for train_dir inputs. Result saved as new.tflite in train_dir.
- """
- teacher_dir = baseline_dir if baseline_dir else train_dir
- teacher = ParallelTFLiteModel(
- f"{teacher_dir}/replace.tflite", num_procs, num_threads, batch_size=batch_size
- )
- replace = TFLiteModel(f"{train_dir}/replace.tflite")
+def _get_io_tensors(model: TFLiteModel) -> tuple[str, str]:
assert (
- len(teacher.input_tensors()) == 1
+ len(model.input_tensors()) == 1
), f"Can only train replacements with a single input tensor right now, \
- found {teacher.input_tensors()}"
+ found {model.input_tensors()}"
assert (
- len(teacher.output_tensors()) == 1
+ len(model.output_tensors()) == 1
), f"Can only train replacements with a single output tensor right now, \
- found {teacher.output_tensors()}"
+ found {model.output_tensors()}"
+
+ input_name = model.input_tensors()[0]
+ output_name = model.output_tensors()[0]
+ return (input_name, output_name)
- input_name = teacher.input_tensors()[0]
- output_name = teacher.output_tensors()[0]
+def _check_model_compatibility(teacher: TFLiteModel, replace: TFLiteModel) -> None:
+ """Assert that teacher and replaced sub-graph are compatible."""
assert len(teacher.shape_from_name) == len(
replace.shape_from_name
), f"Baseline and train models must have the same number of inputs and outputs. \
@@ -230,10 +215,37 @@ def train_in_dir(
subgraph being replaced. Teacher: {teacher.shape_from_name}\n \
Train dir: {replace.shape_from_name}"
+
+def set_up_data_pipeline(
+ teacher: TFLiteModel,
+ replace: TFLiteModel,
+ train_dir: str,
+ augmentations: tuple[float | None, float | None],
+ steps: int,
+ batch_size: int = 32,
+) -> tf.data.Dataset:
+ """Create a data pipeline for training of the replacement model."""
+ _check_model_compatibility(teacher, replace)
+
+ input_name, output_name = _get_io_tensors(teacher)
+
input_filename = Path(train_dir, "input.tfrec")
total = numpytf_count(str(input_filename))
dict_inputs = numpytf_read(str(input_filename))
+
inputs = dict_inputs.map(lambda d: tf.squeeze(d[input_name], axis=0))
+
+ steps_per_epoch = math.ceil(total / batch_size)
+ epochs = int(math.ceil(steps / steps_per_epoch))
+ logger.info(
+ "Training on %d items for %d steps (%d epochs with batch size %d)",
+ total,
+ epochs * steps_per_epoch,
+ epochs,
+ batch_size,
+ )
+
+ teacher_dir = Path(teacher.model_path).parent
if any(augmentations):
# Map the teacher inputs here because the augmentation stage passes these
# through a TFLite model to get the outputs
@@ -245,17 +257,6 @@ def train_in_dir(
lambda d: tf.squeeze(d[output_name], axis=0)
)
- steps_per_epoch = math.ceil(total / batch_size)
- epochs = int(math.ceil(steps / steps_per_epoch))
- if verbose:
- logger.info(
- "Training on %d items for %d steps (%d epochs with batch size %d)",
- total,
- epochs * steps_per_epoch,
- epochs,
- batch_size,
- )
-
dataset = tf.data.Dataset.zip((inputs, teacher_outputs))
if epochs > 1:
dataset = dataset.cache()
@@ -268,10 +269,9 @@ def train_in_dir(
train: Any, teach: Any # pylint: disable=redefined-outer-name
) -> tuple:
"""Return results of train and teach based on augmentations."""
- return (
- augment_train({input_name: train})[input_name],
- teacher(augment_teacher({input_name: teach}))[output_name],
- )
+ augmented_train = augment_train({input_name: train})[input_name]
+ augmented_teach = teacher(augment_teacher({input_name: teach}))[output_name]
+ return (augmented_train, augmented_teach)
dataset = dataset.map(
lambda augment_train, augment_teach: tf.py_function(
@@ -281,18 +281,67 @@ def train_in_dir(
)
)
+ # Restore data shapes of the dataset as they are set to unknown per default
+ # and get lost during augmentation with tf.py_function.
+ shape_in, shape_out = (
+ teacher.shape_from_name[name].tolist() for name in (input_name, output_name)
+ )
+ for shape in (shape_in, shape_out):
+ shape[0] = None # set dynamic batch size
+
+ def restore_shapes(input_: Any, output: Any) -> tuple[Any, Any]:
+ input_.set_shape(shape_in)
+ output.set_shape(shape_out)
+ return input_, output
+
+ dataset = dataset.map(restore_shapes)
dataset = dataset.prefetch(tf.data.AUTOTUNE)
+ return dataset
+
+
+def train_in_dir(
+ train_dir: str,
+ baseline_dir: Any,
+ output_filename: Path,
+ replace_fn: Callable,
+ train_params: TrainingParameters = TrainingParameters(),
+) -> list[str]:
+ """Train a replacement for replace.tflite using the input.tfrec \
+ and output.tfrec in train_dir.
+
+ If baseline_dir is provided, train the replacement to match baseline
+ outputs for train_dir inputs. Result saved as new.tflite in train_dir.
+ """
+ teacher_dir = baseline_dir if baseline_dir else train_dir
+ teacher = ParallelTFLiteModel(
+ f"{teacher_dir}/replace.tflite",
+ train_params.num_procs,
+ train_params.num_threads,
+ batch_size=train_params.batch_size,
+ )
+ replace = TFLiteModel(f"{train_dir}/replace.tflite")
+
+ input_name, output_name = _get_io_tensors(teacher)
+
+ dataset = set_up_data_pipeline(
+ teacher,
+ replace,
+ train_dir,
+ augmentations=train_params.augmentations,
+ steps=train_params.steps,
+ batch_size=train_params.batch_size,
+ )
input_shape = teacher.shape_from_name[input_name][1:]
output_shape = teacher.shape_from_name[output_name][1:]
+
model = replace_fn(input_shape, output_shape)
- optimizer = tf.keras.optimizers.Nadam(learning_rate=learning_rate)
+ optimizer = tf.keras.optimizers.Nadam(learning_rate=train_params.learning_rate)
loss_fn = tf.keras.losses.MeanSquaredError()
- model.compile(optimizer=optimizer, loss=loss_fn)
+ model.compile(optimizer=optimizer, loss=loss_fn, metrics=["mae"])
- if verbose:
- model.summary()
+ logger.info(model.summary())
steps_so_far = 0
@@ -302,7 +351,9 @@ def train_in_dir(
"""Cosine decay from learning rate at start of the run to zero at the end."""
current_step = epoch_step + steps_so_far
cd_learning_rate = (
- learning_rate * (math.cos(math.pi * current_step / steps) + 1) / 2.0
+ train_params.learning_rate
+ * (math.cos(math.pi * current_step / train_params.steps) + 1)
+ / 2.0
)
tf.keras.backend.set_value(optimizer.learning_rate, cd_learning_rate)
@@ -311,28 +362,29 @@ def train_in_dir(
) -> None:
"""Constant until the last 20% of the run, then linear decay to zero."""
current_step = epoch_step + steps_so_far
- steps_remaining = steps - current_step
- decay_length = steps // 5
+ steps_remaining = train_params.steps - current_step
+ decay_length = train_params.steps // 5
decay_fraction = min(steps_remaining, decay_length) / decay_length
- ld_learning_rate = learning_rate * decay_fraction
+ ld_learning_rate = train_params.learning_rate * decay_fraction
tf.keras.backend.set_value(optimizer.learning_rate, ld_learning_rate)
- if schedule == "cosine":
+ assert train_params.learning_rate_schedule in LEARNING_RATE_SCHEDULES, (
+ f'Learning rate schedule "{train_params.learning_rate_schedule}" '
+ f"not implemented - expected one of {LEARNING_RATE_SCHEDULES}."
+ )
+ if train_params.learning_rate_schedule == "cosine":
callbacks = [tf.keras.callbacks.LambdaCallback(on_batch_begin=cosine_decay)]
- elif schedule == "late":
+ elif train_params.learning_rate_schedule == "late":
callbacks = [tf.keras.callbacks.LambdaCallback(on_batch_begin=late_decay)]
- elif schedule == "constant":
+ elif train_params.learning_rate_schedule == "constant":
callbacks = []
- else:
- assert schedule not in LEARNING_RATE_SCHEDULES
- raise ValueError(
- f'Learning rate schedule "{schedule}" not implemented - '
- f"expected one of {LEARNING_RATE_SCHEDULES}."
- )
output_filenames = []
- checkpoints = (checkpoint_at if checkpoint_at else []) + [steps]
- while steps_so_far < steps:
+ checkpoints = (train_params.checkpoint_at if train_params.checkpoint_at else []) + [
+ train_params.steps
+ ]
+
+ while steps_so_far < train_params.steps:
steps_to_train = checkpoints.pop(0) - steps_so_far
lr_start = optimizer.learning_rate.numpy()
model.fit(
@@ -340,7 +392,7 @@ def train_in_dir(
epochs=1,
steps_per_epoch=steps_to_train,
callbacks=callbacks,
- verbose=show_progress,
+ verbose=train_params.show_progress,
)
steps_so_far += steps_to_train
logger.info(
@@ -350,12 +402,14 @@ def train_in_dir(
steps_to_train,
)
- if steps_so_far < steps:
+ if steps_so_far < train_params.steps:
filename, ext = Path(output_filename).parts[1:]
checkpoint_filename = filename + (f"_@{steps_so_far}") + ext
else:
checkpoint_filename = str(output_filename)
- with log_action(f"{steps_so_far}/{steps}: Saved as {checkpoint_filename}"):
+ with log_action(
+ f"{steps_so_far}/{train_params.steps}: Saved as {checkpoint_filename}"
+ ):
save_as_tflite(
model,
checkpoint_filename,
@@ -379,14 +433,30 @@ def save_as_tflite(
output_shape: list,
) -> None:
"""Save Keras model as TFLite file."""
- converter = tf.lite.TFLiteConverter.from_keras_model(keras_model)
+
+ @contextmanager
+ def fixed_input(keras_model: tf.keras.Model, tmp_shape: list) -> GeneratorType:
+ """Fix the input shape of the Keras model temporarily.
+
+ This avoids artifacts during conversion to TensorFlow Lite.
+ """
+ orig_shape = keras_model.input.shape
+ keras_model.input.set_shape(tf.TensorShape(tmp_shape))
+ try:
+ yield keras_model
+ finally:
+ # Restore original shape to not interfere with further training
+ keras_model.input.set_shape(orig_shape)
+
+ with fixed_input(keras_model, input_shape) as fixed_model:
+ converter = tf.lite.TFLiteConverter.from_keras_model(fixed_model)
tflite_model = converter.convert()
with open(filename, "wb") as file:
file.write(tflite_model)
# Now fix the shapes and names to match those we expect
- flatbuffer = load(filename)
+ flatbuffer = load_fb(filename)
i = flatbuffer.subgraphs[0].inputs[0]
flatbuffer.subgraphs[0].tensors[i].shape = np.array(input_shape, dtype=np.int32)
flatbuffer.subgraphs[0].tensors[i].name = input_name.encode("utf-8")
@@ -395,11 +465,11 @@ def save_as_tflite(
output_shape, dtype=np.int32
)
flatbuffer.subgraphs[0].tensors[output].name = output_name.encode("utf-8")
- save(flatbuffer, filename)
+ save_fb(flatbuffer, filename)
def augment_fn_twins(
- inputs: dict, augmentations: tuple[float | None, float | None]
+ inputs: tf.data.Dataset, augmentations: tuple[float | None, float | None]
) -> Any:
"""Return a pair of twinned augmentation functions with the same sequence \
of random numbers."""
@@ -415,6 +485,11 @@ def augment_fn(
inputs: Any, augmentations: tuple[float | None, float | None], rng: Generator
) -> Any:
"""Augmentation module."""
+ assert len(augmentations) == 2, (
+ f"Unexpected number of augmentation parameters: {len(augmentations)} "
+ "(must be 2)"
+ )
+
mixup_strength, gaussian_strength = augmentations
augments = []
@@ -449,17 +524,16 @@ def augment_fn(
augments.append(gaussian_strength_augment)
- if len(augments) == 0: # pylint: disable=no-else-return
+ if len(augments) == 0:
return lambda x: x
- elif len(augments) == 1:
+ if len(augments) == 1:
return augments[0]
- elif len(augments) == 2:
+ if len(augments) == 2:
return lambda x: augments[1](augments[0](x))
- else:
- assert (
- False
- ), f"Unexpected number of augmentation \
- functions ({len(augments)})"
+
+ raise RuntimeError(
+ "Unexpected number of augmentation functions (must be <=2): " f"{len(augments)}"
+ )
def mixup(rng: Generator, batch: Any, beta_range: tuple = (0.0, 1.0)) -> Any: