Diffstat (limited to 'src/mlia/nn/rewrite/core/train.py')
-rw-r--r--  src/mlia/nn/rewrite/core/train.py  433
1 file changed, 249 insertions(+), 184 deletions(-)
diff --git a/src/mlia/nn/rewrite/core/train.py b/src/mlia/nn/rewrite/core/train.py
index 096daf4..f837964 100644
--- a/src/mlia/nn/rewrite/core/train.py
+++ b/src/mlia/nn/rewrite/core/train.py
@@ -1,35 +1,41 @@
# SPDX-FileCopyrightText: Copyright 2023, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
+"""Sequential trainer."""
+# pylint: disable=too-many-arguments, too-many-instance-attributes,
+# pylint: disable=too-many-locals, too-many-branches, too-many-statements
+from __future__ import annotations
+
+import logging
import math
import os
import tempfile
from collections import defaultdict
+from pathlib import Path
+from typing import Any
+from typing import Callable
+from typing import cast
+from typing import get_args
+from typing import Literal
-os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
+import numpy as np
+
+# Must be set before TensorFlow is imported for the C++ log level to take effect.
+os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
import tensorflow as tf
+from numpy.random import Generator
-tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
-
-try:
- from tensorflow.keras.optimizers.schedules import CosineDecay
-except ImportError:
- # In TF 2.4 CosineDecay was still experimental
- from tensorflow.keras.experimental import CosineDecay
-
-import numpy as np
-from mlia.nn.rewrite.core.utils.numpy_tfrecord import (
- NumpyTFReader,
- NumpyTFWriter,
- TFLiteModel,
- numpytf_count,
-)
-from mlia.nn.rewrite.core.utils.parallel import ParallelTFLiteModel
-from mlia.nn.rewrite.core.graph_edit.record import record_model
-from mlia.nn.rewrite.core.utils.utils import load, save
from mlia.nn.rewrite.core.extract import extract
-from mlia.nn.rewrite.core.graph_edit.join import join_models
from mlia.nn.rewrite.core.graph_edit.diff import diff_stats
+from mlia.nn.rewrite.core.graph_edit.join import join_models
+from mlia.nn.rewrite.core.graph_edit.record import record_model
+from mlia.nn.rewrite.core.utils.numpy_tfrecord import numpytf_count
+from mlia.nn.rewrite.core.utils.numpy_tfrecord import numpytf_read
+from mlia.nn.rewrite.core.utils.numpy_tfrecord import TFLiteModel
+from mlia.nn.rewrite.core.utils.parallel import ParallelTFLiteModel
+from mlia.nn.rewrite.core.utils.utils import load
+from mlia.nn.rewrite.core.utils.utils import save
+from mlia.utils.logging import log_action
+tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
+logger = logging.getLogger(__name__)
augmentation_presets = {
"none": (None, None),
@@ -40,31 +46,34 @@ augmentation_presets = {
"mix_gaussian_small": (1.6, 0.3),
}
-learning_rate_schedules = {"cosine", "late", "constant"}
+LearningRateSchedule = Literal["cosine", "late", "constant"]
+LEARNING_RATE_SCHEDULES = get_args(LearningRateSchedule)
def train(
- source_model,
- unmodified_model,
- output_model,
- input_tfrec,
- replace_fn,
- input_tensors,
- output_tensors,
- augment,
- steps,
- lr,
- batch_size,
- verbose,
- show_progress,
- learning_rate_schedule="cosine",
- checkpoint_at=None,
- checkpoint_decay_steps=0,
- num_procs=1,
- num_threads=0,
-):
+ source_model: str,
+ unmodified_model: Any,
+ output_model: str,
+ input_tfrec: str,
+ replace_fn: Callable,
+ input_tensors: list[str],
+ output_tensors: list[str],
+ augment: tuple[float | None, float | None],
+ steps: int,
+ learning_rate: float,
+ batch_size: int,
+ verbose: bool,
+ show_progress: bool,
+ learning_rate_schedule: LearningRateSchedule = "cosine",
+ checkpoint_at: list[int] | None = None,
+ num_procs: int = 1,
+ num_threads: int = 0,
+) -> Any:
+ """Extract and train a model, and return the results."""
if unmodified_model:
- unmodified_model_dir = tempfile.TemporaryDirectory()
+ unmodified_model_dir = (
+ tempfile.TemporaryDirectory() # pylint: disable=consider-using-with
+ )
unmodified_model_dir_path = unmodified_model_dir.name
extract(
unmodified_model_dir_path,
@@ -79,8 +88,6 @@ def train(
results = []
with tempfile.TemporaryDirectory() as train_dir:
- p = lambda file: os.path.join(train_dir, file)
-
extract(
train_dir,
source_model,
@@ -94,14 +101,13 @@ def train(
tflite_filenames = train_in_dir(
train_dir,
unmodified_model_dir_path,
- p("new.tflite"),
+ Path(train_dir, "new.tflite"),
replace_fn,
augment,
steps,
- lr,
+ learning_rate,
batch_size,
checkpoint_at=checkpoint_at,
- checkpoint_decay_steps=checkpoint_decay_steps,
verbose=verbose,
show_progress=show_progress,
num_procs=num_procs,
@@ -114,7 +120,8 @@ def train(
if output_model:
if i + 1 < len(tflite_filenames):
- # Append the same _@STEPS.tflite postfix used by intermediate checkpoints for all but the last output
+ # Append the same _@STEPS.tflite postfix used by intermediate
+ # checkpoints for all but the last output
postfix = filename.split("_@")[-1]
output_filename = output_model.split(".tflite")[0] + postfix
else:
@@ -122,115 +129,130 @@ def train(
join_in_dir(train_dir, filename, output_filename)
if unmodified_model_dir:
- unmodified_model_dir.cleanup()
+ cast(tempfile.TemporaryDirectory, unmodified_model_dir).cleanup()
return (
    results if checkpoint_at else results[0]
)  # only return a list if multiple checkpoints were requested
-def eval_in_dir(dir, new_part, num_procs=1, num_threads=0):
- p = lambda file: os.path.join(dir, file)
- input = (
- p("input_orig.tfrec")
- if os.path.exists(p("input_orig.tfrec"))
- else p("input.tfrec")
+def eval_in_dir(
+ target_dir: str, new_part: str, num_procs: int = 1, num_threads: int = 0
+) -> tuple:
+ """Evaluate a model in a given directory."""
+ model_input_path = Path(target_dir, "input_orig.tfrec")
+ model_output_path = Path(target_dir, "output_orig.tfrec")
+ model_input = (
+ model_input_path
+ if model_input_path.exists()
+ else Path(target_dir, "input.tfrec")
)
output = (
- p("output_orig.tfrec")
- if os.path.exists(p("output_orig.tfrec"))
- else p("output.tfrec")
+ model_output_path
+ if model_output_path.exists()
+ else Path(target_dir, "output.tfrec")
)
with tempfile.TemporaryDirectory() as tmp_dir:
- predict = os.path.join(tmp_dir, "predict.tfrec")
+ predict = Path(tmp_dir, "predict.tfrec")
record_model(
- input, new_part, predict, num_procs=num_procs, num_threads=num_threads
+ str(model_input),
+ new_part,
+ str(predict),
+ num_procs=num_procs,
+ num_threads=num_threads,
)
- mae, nrmse = diff_stats(output, predict)
+ mae, nrmse = diff_stats(str(output), str(predict))
return mae, nrmse
-def join_in_dir(dir, new_part, output_model):
+def join_in_dir(model_dir: str, new_part: str, output_model: str) -> None:
+ """Join two models in a given directory."""
with tempfile.TemporaryDirectory() as tmp_dir:
- d = lambda file: os.path.join(dir, file)
- new_end = os.path.join(tmp_dir, "new_end.tflite")
- join_models(new_part, d("end.tflite"), new_end)
- join_models(d("start.tflite"), new_end, output_model)
+ new_end = Path(tmp_dir, "new_end.tflite")
+ join_models(new_part, Path(model_dir, "end.tflite"), new_end)
+ join_models(Path(model_dir, "start.tflite"), new_end, output_model)
def train_in_dir(
- train_dir,
- baseline_dir,
- output_filename,
- replace_fn,
- augmentations,
- steps,
- lr=1e-3,
- batch_size=32,
- checkpoint_at=None,
- checkpoint_decay_steps=0,
- schedule="cosine",
- verbose=False,
- show_progress=False,
- num_procs=None,
- num_threads=1,
-):
- """Train a replacement for replace.tflite using the input.tfrec and output.tfrec in train_dir.
- If baseline_dir is provided, train the replacement to match baseline outputs for train_dir inputs.
- Result saved as new.tflite in train_dir.
+ train_dir: str,
+ baseline_dir: Any,
+ output_filename: Path,
+ replace_fn: Callable,
+ augmentations: tuple[float | None, float | None],
+ steps: int,
+ learning_rate: float = 1e-3,
+ batch_size: int = 32,
+ checkpoint_at: list[int] | None = None,
+ schedule: str = "cosine",
+ verbose: bool = False,
+ show_progress: bool = False,
+ num_procs: int = 0,
+ num_threads: int = 1,
+) -> list:
+ """Train a replacement for replace.tflite using the input.tfrec \
+ and output.tfrec in train_dir.
+
+ If baseline_dir is provided, train the replacement to match baseline
+ outputs for train_dir inputs. Result saved as new.tflite in train_dir.
"""
teacher_dir = baseline_dir if baseline_dir else train_dir
teacher = ParallelTFLiteModel(
- "%s/replace.tflite" % teacher_dir, num_procs, num_threads, batch_size=batch_size
- )
- replace = TFLiteModel("%s/replace.tflite" % train_dir)
- assert len(teacher.input_tensors()) == 1, (
- "Can only train replacements with a single input tensor right now, found %s"
- % teacher.input_tensors()
- )
- assert len(teacher.output_tensors()) == 1, (
- "Can only train replacements with a single output tensor right now, found %s"
- % teacher.output_tensors()
+ f"{teacher_dir}/replace.tflite", num_procs, num_threads, batch_size=batch_size
)
+ replace = TFLiteModel(f"{train_dir}/replace.tflite")
+    assert len(teacher.input_tensors()) == 1, (
+        "Can only train replacements with a single input tensor right now, "
+        f"found {teacher.input_tensors()}"
+    )
+
+    assert len(teacher.output_tensors()) == 1, (
+        "Can only train replacements with a single output tensor right now, "
+        f"found {teacher.output_tensors()}"
+    )
+
input_name = teacher.input_tensors()[0]
output_name = teacher.output_tensors()[0]
assert len(teacher.shape_from_name) == len(
replace.shape_from_name
- ), "Baseline and train models must have the same number of inputs and outputs. Teacher: {}\nTrain dir: {}".format(
- teacher.shape_from_name, replace.shape_from_name
- )
+ ), f"Baseline and train models must have the same number of inputs and outputs. \
+ Teacher: {teacher.shape_from_name}\nTrain dir: {replace.shape_from_name}"
+
assert all(
tn == rn and (ts[1:] == rs[1:]).all()
for (tn, ts), (rn, rs) in zip(
teacher.shape_from_name.items(), replace.shape_from_name.items()
)
- ), "Baseline and train models must have the same input and output shapes for the subgraph being replaced. Teacher: {}\nTrain dir: {}".format(
- teacher.shape_from_name, replace.shape_from_name
- )
+ ), "Baseline and train models must have the same input and output shapes for the \
+ subgraph being replaced. Teacher: {teacher.shape_from_name}\n \
+ Train dir: {replace.shape_from_name}"
- input_filename = os.path.join(train_dir, "input.tfrec")
- total = numpytf_count(input_filename)
- dict_inputs = NumpyTFReader(input_filename)
+ input_filename = Path(train_dir, "input.tfrec")
+ total = numpytf_count(str(input_filename))
+ dict_inputs = numpytf_read(str(input_filename))
inputs = dict_inputs.map(lambda d: tf.squeeze(d[input_name], axis=0))
if any(augmentations):
- # Map the teacher inputs here because the augmentation stage passes these through a TFLite model to get the outputs
- teacher_outputs = NumpyTFReader(os.path.join(teacher_dir, "input.tfrec")).map(
+ # Map the teacher inputs here because the augmentation stage passes these
+ # through a TFLite model to get the outputs
+ teacher_outputs = numpytf_read(str(Path(teacher_dir, "input.tfrec"))).map(
lambda d: tf.squeeze(d[input_name], axis=0)
)
else:
- teacher_outputs = NumpyTFReader(os.path.join(teacher_dir, "output.tfrec")).map(
+ teacher_outputs = numpytf_read(str(Path(teacher_dir, "output.tfrec"))).map(
lambda d: tf.squeeze(d[output_name], axis=0)
)
steps_per_epoch = math.ceil(total / batch_size)
epochs = int(math.ceil(steps / steps_per_epoch))
if verbose:
- print(
- "Training on %d items for %d steps (%d epochs with batch size %d)"
- % (total, epochs * steps_per_epoch, epochs, batch_size)
+ logger.info(
+ "Training on %d items for %d steps (%d epochs with batch size %d)",
+ total,
+ epochs * steps_per_epoch,
+ epochs,
+ batch_size,
)
dataset = tf.data.Dataset.zip((inputs, teacher_outputs))
@@ -240,13 +262,21 @@ def train_in_dir(
if any(augmentations):
augment_train, augment_teacher = augment_fn_twins(dict_inputs, augmentations)
- augment_fn = lambda train, teach: (
- augment_train({input_name: train})[input_name],
- teacher(augment_teacher({input_name: teach}))[output_name],
- )
+
+ def get_augment_results(
+ train: Any, teach: Any # pylint: disable=redefined-outer-name
+ ) -> tuple:
+ """Return results of train and teach based on augmentations."""
+ return (
+ augment_train({input_name: train})[input_name],
+ teacher(augment_teacher({input_name: teach}))[output_name],
+ )
+
dataset = dataset.map(
- lambda train, teach: tf.py_function(
- augment_fn, inp=[train, teach], Tout=[tf.float32, tf.float32]
+ lambda train_sample, teacher_sample: tf.py_function(
+     get_augment_results,
+     inp=[train_sample, teacher_sample],
+     Tout=[tf.float32, tf.float32],
)
)
@@ -256,7 +286,7 @@ def train_in_dir(
output_shape = teacher.shape_from_name[output_name][1:]
model = replace_fn(input_shape, output_shape)
- optimizer = tf.keras.optimizers.Nadam(learning_rate=lr)
+ optimizer = tf.keras.optimizers.Nadam(learning_rate=learning_rate)
loss_fn = tf.keras.losses.MeanSquaredError()
model.compile(optimizer=optimizer, loss=loss_fn)
@@ -265,20 +295,26 @@ def train_in_dir(
steps_so_far = 0
- def cosine_decay(epoch_step, logs):
- """Cosine decay from lr at start of the run to zero at the end"""
+ def cosine_decay(
+ epoch_step: int, logs: Any # pylint: disable=unused-argument
+ ) -> None:
+ """Cosine decay from learning rate at start of the run to zero at the end."""
current_step = epoch_step + steps_so_far
- learning_rate = lr * (math.cos(math.pi * current_step / steps) + 1) / 2.0
- tf.keras.backend.set_value(optimizer.learning_rate, learning_rate)
+ cd_learning_rate = (
+ learning_rate * (math.cos(math.pi * current_step / steps) + 1) / 2.0
+ )
+ tf.keras.backend.set_value(optimizer.learning_rate, cd_learning_rate)
- def late_decay(epoch_step, logs):
- """Constant until the last 20% of the run, then linear decay to zero"""
+ def late_decay(
+ epoch_step: int, logs: Any # pylint: disable=unused-argument
+ ) -> None:
+ """Constant until the last 20% of the run, then linear decay to zero."""
current_step = epoch_step + steps_so_far
steps_remaining = steps - current_step
decay_length = steps // 5
decay_fraction = min(steps_remaining, decay_length) / decay_length
- learning_rate = lr * decay_fraction
- tf.keras.backend.set_value(optimizer.learning_rate, learning_rate)
+ ld_learning_rate = learning_rate * decay_fraction
+ tf.keras.backend.set_value(optimizer.learning_rate, ld_learning_rate)
if schedule == "cosine":
callbacks = [tf.keras.callbacks.LambdaCallback(on_batch_begin=cosine_decay)]
@@ -287,9 +323,10 @@ def train_in_dir(
elif schedule == "constant":
callbacks = []
else:
- assert schedule not in learning_rate_schedules
+ assert schedule not in LEARNING_RATE_SCHEDULES
raise ValueError(
- f'LR schedule "{schedule}" not implemented - expected one of {learning_rate_schedules}.'
+ f'Learning rate schedule "{schedule}" not implemented - '
+ f"expected one of {LEARNING_RATE_SCHEDULES}."
)
output_filenames = []
@@ -305,53 +342,66 @@ def train_in_dir(
verbose=show_progress,
)
steps_so_far += steps_to_train
- print(
- "lr decayed from %f to %f over %d steps"
- % (lr_start, optimizer.learning_rate.numpy(), steps_to_train)
+ logger.info(
+ "lr decayed from %f to %f over %d steps",
+ lr_start,
+ optimizer.learning_rate.numpy(),
+ steps_to_train,
)
if steps_so_far < steps:
- filename, ext = os.path.splitext(output_filename)
- checkpoint_filename = filename + ("_@%d" % steps_so_far) + ext
+                path = Path(output_filename)
+                checkpoint_filename = str(
+                    path.with_name(f"{path.stem}_@{steps_so_far}{path.suffix}")
+                )
else:
- checkpoint_filename = output_filename
- print("%d/%d: Saved as %s" % (steps_so_far, steps, checkpoint_filename))
- save_as_tflite(
- model,
- checkpoint_filename,
- input_name,
- replace.shape_from_name[input_name],
- output_name,
- replace.shape_from_name[output_name],
- )
- output_filenames.append(checkpoint_filename)
+ checkpoint_filename = str(output_filename)
+ with log_action(f"{steps_so_far}/{steps}: Saved as {checkpoint_filename}"):
+ save_as_tflite(
+ model,
+ checkpoint_filename,
+ input_name,
+ replace.shape_from_name[input_name],
+ output_name,
+ replace.shape_from_name[output_name],
+ )
+ output_filenames.append(checkpoint_filename)
teacher.close()
return output_filenames
def save_as_tflite(
- keras_model, filename, input_name, input_shape, output_name, output_shape
-):
+ keras_model: tf.keras.Model,
+ filename: str,
+ input_name: str,
+ input_shape: list,
+ output_name: str,
+ output_shape: list,
+) -> None:
+ """Save Keras model as TFLite file."""
converter = tf.lite.TFLiteConverter.from_keras_model(keras_model)
tflite_model = converter.convert()
- with open(filename, "wb") as f:
- f.write(tflite_model)
+ with open(filename, "wb") as file:
+ file.write(tflite_model)
# Now fix the shapes and names to match those we expect
- fb = load(filename)
- i = fb.subgraphs[0].inputs[0]
- fb.subgraphs[0].tensors[i].shape = np.array(input_shape, dtype=np.int32)
- fb.subgraphs[0].tensors[i].name = input_name.encode("utf-8")
- o = fb.subgraphs[0].outputs[0]
- fb.subgraphs[0].tensors[o].shape = np.array(output_shape, dtype=np.int32)
- fb.subgraphs[0].tensors[o].name = output_name.encode("utf-8")
- save(fb, filename)
-
-
-def augment_fn_twins(inputs, augmentations):
- """Return a pair of twinned augmentation functions with the same sequence of random numbers"""
+ flatbuffer = load(filename)
+ i = flatbuffer.subgraphs[0].inputs[0]
+ flatbuffer.subgraphs[0].tensors[i].shape = np.array(input_shape, dtype=np.int32)
+ flatbuffer.subgraphs[0].tensors[i].name = input_name.encode("utf-8")
+ output = flatbuffer.subgraphs[0].outputs[0]
+ flatbuffer.subgraphs[0].tensors[output].shape = np.array(
+ output_shape, dtype=np.int32
+ )
+ flatbuffer.subgraphs[0].tensors[output].name = output_name.encode("utf-8")
+ save(flatbuffer, filename)
+
+
+def augment_fn_twins(
+ inputs: dict, augmentations: tuple[float | None, float | None]
+) -> Any:
+ """Return a pair of twinned augmentation functions with the same sequence \
+ of random numbers."""
seed = np.random.randint(2**32 - 1)
rng1 = np.random.default_rng(seed)
rng2 = np.random.default_rng(seed)
@@ -360,52 +410,67 @@ def augment_fn_twins(inputs, augmentations):
)
-def augment_fn(inputs, augmentations, rng):
+def augment_fn(
+ inputs: Any, augmentations: tuple[float | None, float | None], rng: Generator
+) -> Any:
+ """Augmentation module."""
mixup_strength, gaussian_strength = augmentations
augments = []
if mixup_strength:
mixup_range = (0.5 - mixup_strength / 2, 0.5 + mixup_strength / 2)
- augment = lambda d: {
- k: mixup(rng, v.numpy(), mixup_range) for k, v in d.items()
- }
- augments.append(augment)
+
+ def mixup_augment(augment_dict: dict) -> dict:
+ return {
+ k: mixup(rng, v.numpy(), mixup_range) for k, v in augment_dict.items()
+ }
+
+ augments.append(mixup_augment)
if gaussian_strength:
values = defaultdict(list)
- for d in inputs.as_numpy_iterator():
- for k, v in d.items():
- values[k].append(v)
+ for numpy_dict in inputs.as_numpy_iterator():
+ for key, value in numpy_dict.items():
+ values[key].append(value)
noise_scale = {
k: np.std(v, axis=0).astype(np.float32) for k, v in values.items()
}
- augment = lambda d: {
- k: v
- + rng.standard_normal(v.shape).astype(np.float32)
- * gaussian_strength
- * noise_scale[k]
- for k, v in d.items()
- }
- augments.append(augment)
- if len(augments) == 0:
+ def gaussian_strength_augment(augment_dict: dict) -> dict:
+ return {
+ k: v
+ + rng.standard_normal(v.shape).astype(np.float32)
+ * gaussian_strength
+ * noise_scale[k]
+ for k, v in augment_dict.items()
+ }
+
+ augments.append(gaussian_strength_augment)
+
+ if len(augments) == 0: # pylint: disable=no-else-return
return lambda x: x
elif len(augments) == 1:
return augments[0]
elif len(augments) == 2:
return lambda x: augments[1](augments[0](x))
else:
- assert False, "Unexpected number of augmentation functions (%d)" % len(augments)
-
-
-def mixup(rng, batch, beta_range=(0.0, 1.0)):
- """Each tensor in the batch becomes a linear combination of it and one other tensor"""
- a = batch
- b = np.array(batch)
- rng.shuffle(b) # randomly pair up tensors in the batch
+    assert False, (
+        f"Unexpected number of augmentation functions ({len(augments)})"
+    )
+
+
+def mixup(rng: Generator, batch: Any, beta_range: tuple = (0.0, 1.0)) -> Any:
+ """Each tensor in the batch becomes a linear combination of it \
+ and one other tensor."""
+ batch_a = batch
+ batch_b = np.array(batch)
+ rng.shuffle(batch_b) # randomly pair up tensors in the batch
# random mixing coefficient for each pair
beta = rng.uniform(
low=beta_range[0], high=beta_range[1], size=batch.shape[0]
).astype(np.float32)
- return (a.T * beta).T + (b.T * (1.0 - beta)).T # return linear combinations
+    # return linear combinations
+    return (batch_a.T * beta).T + (batch_b.T * (1.0 - beta)).T