Diffstat (limited to 'src/mlia/nn/rewrite/core/train.py')
-rw-r--r-- | src/mlia/nn/rewrite/core/train.py | 433
1 file changed, 249 insertions(+), 184 deletions(-)
diff --git a/src/mlia/nn/rewrite/core/train.py b/src/mlia/nn/rewrite/core/train.py
index 096daf4..f837964 100644
--- a/src/mlia/nn/rewrite/core/train.py
+++ b/src/mlia/nn/rewrite/core/train.py
@@ -1,35 +1,41 @@
 # SPDX-FileCopyrightText: Copyright 2023, Arm Limited and/or its affiliates.
 # SPDX-License-Identifier: Apache-2.0
+"""Sequential trainer."""
+# pylint: disable=too-many-arguments, too-many-instance-attributes,
+# pylint: disable=too-many-locals, too-many-branches, too-many-statements
+from __future__ import annotations
+
+import logging
 import math
 import os
 import tempfile
 from collections import defaultdict
+from pathlib import Path
+from typing import Any
+from typing import Callable
+from typing import cast
+from typing import get_args
+from typing import Literal
 
-os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
-
+import numpy as np
 import tensorflow as tf
+from numpy.random import Generator
 
-tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
-
-try:
-    from tensorflow.keras.optimizers.schedules import CosineDecay
-except ImportError:
-    # In TF 2.4 CosineDecay was still experimental
-    from tensorflow.keras.experimental import CosineDecay
-
-import numpy as np
-from mlia.nn.rewrite.core.utils.numpy_tfrecord import (
-    NumpyTFReader,
-    NumpyTFWriter,
-    TFLiteModel,
-    numpytf_count,
-)
-from mlia.nn.rewrite.core.utils.parallel import ParallelTFLiteModel
-from mlia.nn.rewrite.core.graph_edit.record import record_model
-from mlia.nn.rewrite.core.utils.utils import load, save
 from mlia.nn.rewrite.core.extract import extract
-from mlia.nn.rewrite.core.graph_edit.join import join_models
 from mlia.nn.rewrite.core.graph_edit.diff import diff_stats
+from mlia.nn.rewrite.core.graph_edit.join import join_models
+from mlia.nn.rewrite.core.graph_edit.record import record_model
+from mlia.nn.rewrite.core.utils.numpy_tfrecord import numpytf_count
+from mlia.nn.rewrite.core.utils.numpy_tfrecord import numpytf_read
+from mlia.nn.rewrite.core.utils.numpy_tfrecord import TFLiteModel
+from mlia.nn.rewrite.core.utils.parallel import ParallelTFLiteModel
+from mlia.nn.rewrite.core.utils.utils import load
+from mlia.nn.rewrite.core.utils.utils import save
+from mlia.utils.logging import log_action
 
+os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
+tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
+logger = logging.getLogger(__name__)
 
 augmentation_presets = {
     "none": (None, None),
@@ -40,31 +46,34 @@ augmentation_presets = {
     "mix_gaussian_small": (1.6, 0.3),
 }
 
-learning_rate_schedules = {"cosine", "late", "constant"}
+LearningRateSchedule = Literal["cosine", "late", "constant"]
+LEARNING_RATE_SCHEDULES = get_args(LearningRateSchedule)
 
 
 def train(
-    source_model,
-    unmodified_model,
-    output_model,
-    input_tfrec,
-    replace_fn,
-    input_tensors,
-    output_tensors,
-    augment,
-    steps,
-    lr,
-    batch_size,
-    verbose,
-    show_progress,
-    learning_rate_schedule="cosine",
-    checkpoint_at=None,
-    checkpoint_decay_steps=0,
-    num_procs=1,
-    num_threads=0,
-):
+    source_model: str,
+    unmodified_model: Any,
+    output_model: str,
+    input_tfrec: str,
+    replace_fn: Callable,
+    input_tensors: list,
+    output_tensors: list,
+    augment: tuple[float | None, float | None],
+    steps: int,
+    learning_rate: float,
+    batch_size: int,
+    verbose: bool,
+    show_progress: bool,
+    learning_rate_schedule: LearningRateSchedule = "cosine",
+    checkpoint_at: list | None = None,
+    num_procs: int = 1,
+    num_threads: int = 0,
+) -> Any:
+    """Extract and train a model, and return the results."""
     if unmodified_model:
-        unmodified_model_dir = tempfile.TemporaryDirectory()
+        unmodified_model_dir = (
+            tempfile.TemporaryDirectory()  # pylint: disable=consider-using-with
+        )
         unmodified_model_dir_path = unmodified_model_dir.name
         extract(
             unmodified_model_dir_path,
@@ -79,8 +88,6 @@ def train(
 
     results = []
     with tempfile.TemporaryDirectory() as train_dir:
-        p = lambda file: os.path.join(train_dir, file)
-
         extract(
             train_dir,
             source_model,
@@ -94,14 +101,13 @@ def train(
         tflite_filenames = train_in_dir(
             train_dir,
             unmodified_model_dir_path,
-            p("new.tflite"),
+            Path(train_dir, "new.tflite"),
             replace_fn,
             augment,
             steps,
-            lr,
+            learning_rate,
             batch_size,
             checkpoint_at=checkpoint_at,
-            checkpoint_decay_steps=checkpoint_decay_steps,
             verbose=verbose,
             show_progress=show_progress,
             num_procs=num_procs,
@@ -114,7 +120,8 @@ def train(
         if output_model:
             if i + 1 < len(tflite_filenames):
-                # Append the same _@STEPS.tflite postfix used by intermediate checkpoints for all but the last output
+                # Append the same _@STEPS.tflite postfix used by intermediate
+                # checkpoints for all but the last output
                 postfix = filename.split("_@")[-1]
                 output_filename = output_model.split(".tflite")[0] + postfix
             else:
@@ -122,115 +129,130 @@ def train(
             join_in_dir(train_dir, filename, output_filename)
 
     if unmodified_model_dir:
-        unmodified_model_dir.cleanup()
+        cast(tempfile.TemporaryDirectory, unmodified_model_dir).cleanup()
 
     return (
        results if checkpoint_at else results[0]
    )  # only return a list if multiple checkpoints are asked for
 
 
-def eval_in_dir(dir, new_part, num_procs=1, num_threads=0):
-    p = lambda file: os.path.join(dir, file)
-    input = (
-        p("input_orig.tfrec")
-        if os.path.exists(p("input_orig.tfrec"))
-        else p("input.tfrec")
+def eval_in_dir(
+    target_dir: str, new_part: str, num_procs: int = 1, num_threads: int = 0
+) -> tuple:
+    """Evaluate a model in a given directory."""
+    model_input_path = Path(target_dir, "input_orig.tfrec")
+    model_output_path = Path(target_dir, "output_orig.tfrec")
+    model_input = (
+        model_input_path
+        if model_input_path.exists()
+        else Path(target_dir, "input.tfrec")
     )
     output = (
-        p("output_orig.tfrec")
-        if os.path.exists(p("output_orig.tfrec"))
-        else p("output.tfrec")
+        model_output_path
+        if model_output_path.exists()
+        else Path(target_dir, "output.tfrec")
    )
 
     with tempfile.TemporaryDirectory() as tmp_dir:
-        predict = os.path.join(tmp_dir, "predict.tfrec")
+        predict = Path(tmp_dir, "predict.tfrec")
         record_model(
-            input, new_part, predict, num_procs=num_procs, num_threads=num_threads
+            str(model_input),
+            new_part,
+            str(predict),
+            num_procs=num_procs,
+            num_threads=num_threads,
         )
 
-        mae, nrmse = diff_stats(output, predict)
+        mae, nrmse = diff_stats(str(output), str(predict))
         return mae, nrmse
 
 
-def join_in_dir(dir, new_part, output_model):
+def join_in_dir(model_dir: str, new_part: str, output_model: str) -> None:
+    """Join two models in a given directory."""
     with tempfile.TemporaryDirectory() as tmp_dir:
-        d = lambda file: os.path.join(dir, file)
-        new_end = os.path.join(tmp_dir, "new_end.tflite")
-        join_models(new_part, d("end.tflite"), new_end)
-        join_models(d("start.tflite"), new_end, output_model)
+        new_end = Path(tmp_dir, "new_end.tflite")
+        join_models(new_part, Path(model_dir, "end.tflite"), new_end)
+        join_models(Path(model_dir, "start.tflite"), new_end, output_model)
 
 
 def train_in_dir(
-    train_dir,
-    baseline_dir,
-    output_filename,
-    replace_fn,
-    augmentations,
-    steps,
-    lr=1e-3,
-    batch_size=32,
-    checkpoint_at=None,
-    checkpoint_decay_steps=0,
-    schedule="cosine",
-    verbose=False,
-    show_progress=False,
-    num_procs=None,
-    num_threads=1,
-):
-    """Train a replacement for replace.tflite using the input.tfrec and output.tfrec in train_dir.
-    If baseline_dir is provided, train the replacement to match baseline outputs for train_dir inputs.
-    Result saved as new.tflite in train_dir.
+    train_dir: str,
+    baseline_dir: Any,
+    output_filename: Path,
+    replace_fn: Callable,
+    augmentations: tuple[float | None, float | None],
+    steps: int,
+    learning_rate: float = 1e-3,
+    batch_size: int = 32,
+    checkpoint_at: list | None = None,
+    schedule: str = "cosine",
+    verbose: bool = False,
+    show_progress: bool = False,
+    num_procs: int = 0,
+    num_threads: int = 1,
+) -> list:
+    """Train a replacement for replace.tflite using the input.tfrec \
+    and output.tfrec in train_dir.
+
+    If baseline_dir is provided, train the replacement to match baseline
+    outputs for train_dir inputs. Result saved as new.tflite in train_dir.
     """
     teacher_dir = baseline_dir if baseline_dir else train_dir
     teacher = ParallelTFLiteModel(
-        "%s/replace.tflite" % teacher_dir, num_procs, num_threads, batch_size=batch_size
-    )
-    replace = TFLiteModel("%s/replace.tflite" % train_dir)
-    assert len(teacher.input_tensors()) == 1, (
-        "Can only train replacements with a single input tensor right now, found %s"
-        % teacher.input_tensors()
-    )
-    assert len(teacher.output_tensors()) == 1, (
-        "Can only train replacements with a single output tensor right now, found %s"
-        % teacher.output_tensors()
+        f"{teacher_dir}/replace.tflite", num_procs, num_threads, batch_size=batch_size
     )
+    replace = TFLiteModel(f"{train_dir}/replace.tflite")
+    assert (
+        len(teacher.input_tensors()) == 1
+    ), f"Can only train replacements with a single input tensor right now, \
+        found {teacher.input_tensors()}"
+
+    assert (
+        len(teacher.output_tensors()) == 1
+    ), f"Can only train replacements with a single output tensor right now, \
+        found {teacher.output_tensors()}"
+
     input_name = teacher.input_tensors()[0]
     output_name = teacher.output_tensors()[0]
 
     assert len(teacher.shape_from_name) == len(
         replace.shape_from_name
-    ), "Baseline and train models must have the same number of inputs and outputs. Teacher: {}\nTrain dir: {}".format(
-        teacher.shape_from_name, replace.shape_from_name
-    )
+    ), f"Baseline and train models must have the same number of inputs and outputs. \
+        Teacher: {teacher.shape_from_name}\nTrain dir: {replace.shape_from_name}"
+
     assert all(
         tn == rn and (ts[1:] == rs[1:]).all()
         for (tn, ts), (rn, rs) in zip(
             teacher.shape_from_name.items(), replace.shape_from_name.items()
         )
-    ), "Baseline and train models must have the same input and output shapes for the subgraph being replaced. Teacher: {}\nTrain dir: {}".format(
-        teacher.shape_from_name, replace.shape_from_name
-    )
+    ), "Baseline and train models must have the same input and output shapes for the \
+        subgraph being replaced. Teacher: {teacher.shape_from_name}\n \
+        Train dir: {replace.shape_from_name}"
 
-    input_filename = os.path.join(train_dir, "input.tfrec")
-    total = numpytf_count(input_filename)
-    dict_inputs = NumpyTFReader(input_filename)
+    input_filename = Path(train_dir, "input.tfrec")
+    total = numpytf_count(str(input_filename))
+    dict_inputs = numpytf_read(str(input_filename))
+
     inputs = dict_inputs.map(lambda d: tf.squeeze(d[input_name], axis=0))
+
     if any(augmentations):
-        # Map the teacher inputs here because the augmentation stage passes these through a TFLite model to get the outputs
-        teacher_outputs = NumpyTFReader(os.path.join(teacher_dir, "input.tfrec")).map(
+        # Map the teacher inputs here because the augmentation stage passes these
+        # through a TFLite model to get the outputs
+        teacher_outputs = numpytf_read(str(Path(teacher_dir, "input.tfrec"))).map(
             lambda d: tf.squeeze(d[input_name], axis=0)
         )
     else:
-        teacher_outputs = NumpyTFReader(os.path.join(teacher_dir, "output.tfrec")).map(
+        teacher_outputs = numpytf_read(str(Path(teacher_dir, "output.tfrec"))).map(
             lambda d: tf.squeeze(d[output_name], axis=0)
         )
 
     steps_per_epoch = math.ceil(total / batch_size)
     epochs = int(math.ceil(steps / steps_per_epoch))
     if verbose:
-        print(
-            "Training on %d items for %d steps (%d epochs with batch size %d)"
-            % (total, epochs * steps_per_epoch, epochs, batch_size)
+        logger.info(
+            "Training on %d items for %d steps (%d epochs with batch size %d)",
+            total,
+            epochs * steps_per_epoch,
+            epochs,
+            batch_size,
         )
 
     dataset = tf.data.Dataset.zip((inputs, teacher_outputs))
@@ -240,13 +262,21 @@ def train_in_dir(
     if any(augmentations):
         augment_train, augment_teacher = augment_fn_twins(dict_inputs, augmentations)
-        augment_fn = lambda train, teach: (
-            augment_train({input_name: train})[input_name],
-            teacher(augment_teacher({input_name: teach}))[output_name],
-        )
+
+        def get_augment_results(
+            train: Any, teach: Any  # pylint: disable=redefined-outer-name
+        ) -> tuple:
+            """Return results of train and teach based on augmentations."""
+            return (
+                augment_train({input_name: train})[input_name],
+                teacher(augment_teacher({input_name: teach}))[output_name],
+            )
+
         dataset = dataset.map(
-            lambda train, teach: tf.py_function(
-                augment_fn, inp=[train, teach], Tout=[tf.float32, tf.float32]
+            lambda augment_train, augment_teach: tf.py_function(
+                get_augment_results,
+                inp=[augment_train, augment_teach],
+                Tout=[tf.float32, tf.float32],
             )
         )
 
@@ -256,7 +286,7 @@ def train_in_dir(
     output_shape = teacher.shape_from_name[output_name][1:]
     model = replace_fn(input_shape, output_shape)
 
-    optimizer = tf.keras.optimizers.Nadam(learning_rate=lr)
+    optimizer = tf.keras.optimizers.Nadam(learning_rate=learning_rate)
     loss_fn = tf.keras.losses.MeanSquaredError()
     model.compile(optimizer=optimizer, loss=loss_fn)
 
@@ -265,20 +295,26 @@ def train_in_dir(
 
     steps_so_far = 0
 
-    def cosine_decay(epoch_step, logs):
-        """Cosine decay from lr at start of the run to zero at the end"""
+    def cosine_decay(
+        epoch_step: int, logs: Any  # pylint: disable=unused-argument
+    ) -> None:
+        """Cosine decay from learning rate at start of the run to zero at the end."""
         current_step = epoch_step + steps_so_far
-        learning_rate = lr * (math.cos(math.pi * current_step / steps) + 1) / 2.0
-        tf.keras.backend.set_value(optimizer.learning_rate, learning_rate)
+        cd_learning_rate = (
+            learning_rate * (math.cos(math.pi * current_step / steps) + 1) / 2.0
+        )
+        tf.keras.backend.set_value(optimizer.learning_rate, cd_learning_rate)
 
-    def late_decay(epoch_step, logs):
-        """Constant until the last 20% of the run, then linear decay to zero"""
+    def late_decay(
+        epoch_step: int, logs: Any  # pylint: disable=unused-argument
+    ) -> None:
+        """Constant until the last 20% of the run, then linear decay to zero."""
         current_step = epoch_step + steps_so_far
         steps_remaining = steps - current_step
         decay_length = steps // 5
         decay_fraction = min(steps_remaining, decay_length) / decay_length
-        learning_rate = lr * decay_fraction
-        tf.keras.backend.set_value(optimizer.learning_rate, learning_rate)
+        ld_learning_rate = learning_rate * decay_fraction
+        tf.keras.backend.set_value(optimizer.learning_rate, ld_learning_rate)
 
     if schedule == "cosine":
         callbacks = [tf.keras.callbacks.LambdaCallback(on_batch_begin=cosine_decay)]
@@ -287,9 +323,10 @@ def train_in_dir(
     elif schedule == "constant":
         callbacks = []
     else:
-        assert schedule not in learning_rate_schedules
+        assert schedule not in LEARNING_RATE_SCHEDULES
         raise ValueError(
-            f'LR schedule "{schedule}" not implemented - expected one of {learning_rate_schedules}.'
+            f'Learning rate schedule "{schedule}" not implemented - '
+            f"expected one of {LEARNING_RATE_SCHEDULES}."
        )
 
     output_filenames = []
@@ -305,53 +342,66 @@ def train_in_dir(
             verbose=show_progress,
         )
         steps_so_far += steps_to_train
-        print(
-            "lr decayed from %f to %f over %d steps"
-            % (lr_start, optimizer.learning_rate.numpy(), steps_to_train)
+        logger.info(
+            "lr decayed from %f to %f over %d steps",
+            lr_start,
+            optimizer.learning_rate.numpy(),
+            steps_to_train,
        )
 
         if steps_so_far < steps:
-            filename, ext = os.path.splitext(output_filename)
-            checkpoint_filename = filename + ("_@%d" % steps_so_far) + ext
+            filename, ext = Path(output_filename).parts[1:]
+            checkpoint_filename = filename + (f"_@{steps_so_far}") + ext
         else:
-            checkpoint_filename = output_filename
-        print("%d/%d: Saved as %s" % (steps_so_far, steps, checkpoint_filename))
-        save_as_tflite(
-            model,
-            checkpoint_filename,
-            input_name,
-            replace.shape_from_name[input_name],
-            output_name,
-            replace.shape_from_name[output_name],
-        )
-        output_filenames.append(checkpoint_filename)
+            checkpoint_filename = str(output_filename)
+        with log_action(f"{steps_so_far}/{steps}: Saved as {checkpoint_filename}"):
+            save_as_tflite(
+                model,
+                checkpoint_filename,
+                input_name,
+                replace.shape_from_name[input_name],
+                output_name,
+                replace.shape_from_name[output_name],
+            )
+            output_filenames.append(checkpoint_filename)
 
     teacher.close()
     return output_filenames
 
 
 def save_as_tflite(
-    keras_model, filename, input_name, input_shape, output_name, output_shape
-):
+    keras_model: tf.keras.Model,
+    filename: str,
+    input_name: str,
+    input_shape: list,
+    output_name: str,
+    output_shape: list,
+) -> None:
+    """Save Keras model as TFLite file."""
     converter = tf.lite.TFLiteConverter.from_keras_model(keras_model)
     tflite_model = converter.convert()
 
-    with open(filename, "wb") as f:
-        f.write(tflite_model)
+    with open(filename, "wb") as file:
+        file.write(tflite_model)
 
     # Now fix the shapes and names to match those we expect
-    fb = load(filename)
-    i = fb.subgraphs[0].inputs[0]
-    fb.subgraphs[0].tensors[i].shape = np.array(input_shape, dtype=np.int32)
-    fb.subgraphs[0].tensors[i].name = input_name.encode("utf-8")
-    o = fb.subgraphs[0].outputs[0]
-    fb.subgraphs[0].tensors[o].shape = np.array(output_shape, dtype=np.int32)
-    fb.subgraphs[0].tensors[o].name = output_name.encode("utf-8")
-    save(fb, filename)
-
-
-def augment_fn_twins(inputs, augmentations):
-    """Return a pair of twinned augmentation functions with the same sequence of random numbers"""
+    flatbuffer = load(filename)
+    i = flatbuffer.subgraphs[0].inputs[0]
+    flatbuffer.subgraphs[0].tensors[i].shape = np.array(input_shape, dtype=np.int32)
+    flatbuffer.subgraphs[0].tensors[i].name = input_name.encode("utf-8")
+    output = flatbuffer.subgraphs[0].outputs[0]
+    flatbuffer.subgraphs[0].tensors[output].shape = np.array(
+        output_shape, dtype=np.int32
+    )
+    flatbuffer.subgraphs[0].tensors[output].name = output_name.encode("utf-8")
+    save(flatbuffer, filename)
+
+
+def augment_fn_twins(
+    inputs: dict, augmentations: tuple[float | None, float | None]
+) -> Any:
+    """Return a pair of twinned augmentation functions with the same sequence \
+    of random numbers."""
     seed = np.random.randint(2**32 - 1)
     rng1 = np.random.default_rng(seed)
     rng2 = np.random.default_rng(seed)
@@ -360,52 +410,67 @@
     )
 
 
-def augment_fn(inputs, augmentations, rng):
+def augment_fn(
+    inputs: Any, augmentations: tuple[float | None, float | None], rng: Generator
+) -> Any:
+    """Augmentation module."""
     mixup_strength, gaussian_strength = augmentations
 
     augments = []
 
     if mixup_strength:
         mixup_range = (0.5 - mixup_strength / 2, 0.5 + mixup_strength / 2)
-        augment = lambda d: {
-            k: mixup(rng, v.numpy(), mixup_range) for k, v in d.items()
-        }
-        augments.append(augment)
+
+        def mixup_augment(augment_dict: dict) -> dict:
+            return {
+                k: mixup(rng, v.numpy(), mixup_range) for k, v in augment_dict.items()
+            }
+
+        augments.append(mixup_augment)
 
     if gaussian_strength:
         values = defaultdict(list)
-        for d in inputs.as_numpy_iterator():
-            for k, v in d.items():
-                values[k].append(v)
+        for numpy_dict in inputs.as_numpy_iterator():
+            for key, value in numpy_dict.items():
+                values[key].append(value)
         noise_scale = {
             k: np.std(v, axis=0).astype(np.float32) for k, v in values.items()
         }
-        augment = lambda d: {
-            k: v
-            + rng.standard_normal(v.shape).astype(np.float32)
-            * gaussian_strength
-            * noise_scale[k]
-            for k, v in d.items()
-        }
-        augments.append(augment)
-
-    if len(augments) == 0:
+
+        def gaussian_strength_augment(augment_dict: dict) -> dict:
+            return {
+                k: v
+                + rng.standard_normal(v.shape).astype(np.float32)
+                * gaussian_strength
+                * noise_scale[k]
+                for k, v in augment_dict.items()
+            }
+
+        augments.append(gaussian_strength_augment)
+
+    if len(augments) == 0:  # pylint: disable=no-else-return
         return lambda x: x
     elif len(augments) == 1:
         return augments[0]
     elif len(augments) == 2:
         return lambda x: augments[1](augments[0](x))
     else:
-        assert False, "Unexpected number of augmentation functions (%d)" % len(augments)
+        assert (
+            False
+        ), f"Unexpected number of augmentation \
+        functions ({len(augments)})"
 
 
-def mixup(rng, batch, beta_range=(0.0, 1.0)):
-    """Each tensor in the batch becomes a linear combination of it and one other tensor"""
-    a = batch
-    b = np.array(batch)
-    rng.shuffle(b)  # randomly pair up tensors in the batch
+def mixup(rng: Generator, batch: Any, beta_range: tuple = (0.0, 1.0)) -> Any:
+    """Each tensor in the batch becomes a linear combination of it \
+    and one other tensor."""
+    batch_a = batch
+    batch_b = np.array(batch)
+    rng.shuffle(batch_b)  # randomly pair up tensors in the batch
     # random mixing coefficient for each pair
     beta = rng.uniform(
         low=beta_range[0], high=beta_range[1], size=batch.shape[0]
    ).astype(np.float32)
-    return (a.T * beta).T + (b.T * (1.0 - beta)).T  # return linear combinations
+    return (batch_a.T * beta).T + (
+        batch_b.T * (1.0 - beta)
+    ).T  # return linear combinations
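
Note on the learning-rate schedules: the patch replaces the `learning_rate_schedules` set with a `Literal` type whose values are recovered at runtime via `typing.get_args`, so type checking and the runtime error message share one source of truth. The curves implemented by the `cosine_decay` and `late_decay` callbacks reduce to two small pure functions; a minimal standalone sketch of both (function names here are illustrative, not part of the module):

    import math
    from typing import Literal, get_args

    LearningRateSchedule = Literal["cosine", "late", "constant"]
    LEARNING_RATE_SCHEDULES = get_args(LearningRateSchedule)  # ("cosine", "late", "constant")

    def cosine_lr(base_lr: float, step: int, total_steps: int) -> float:
        """Cosine decay: base_lr at step 0, zero at total_steps."""
        return base_lr * (math.cos(math.pi * step / total_steps) + 1) / 2.0

    def late_lr(base_lr: float, step: int, total_steps: int) -> float:
        """Constant until the last 20% of the run, then linear decay to zero."""
        decay_length = total_steps // 5
        decay_fraction = min(total_steps - step, decay_length) / decay_length
        return base_lr * decay_fraction

    assert "cosine" in LEARNING_RATE_SCHEDULES
    assert math.isclose(cosine_lr(1e-3, 0, 1000), 1e-3)                 # starts at base_lr
    assert math.isclose(cosine_lr(1e-3, 1000, 1000), 0.0, abs_tol=1e-12)  # ends at zero
    assert late_lr(1e-3, 500, 1000) == 1e-3                             # still flat halfway through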
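`augment_fn_twins` builds two augmentation functions from two generators seeded identically, so the student inputs and the teacher inputs receive the same sequence of random draws even though they flow through separate pipelines. The underlying trick, sketched on its own:

    import numpy as np

    seed = np.random.randint(2**32 - 1)
    rng1 = np.random.default_rng(seed)
    rng2 = np.random.default_rng(seed)

    # The twins replay identical draws, so an augmentation built on rng1
    # applies the same "random" transform as its twin built on rng2.
    assert np.array_equal(rng1.standard_normal(4), rng2.standard_normal(4))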
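The `mixup` helper shuffles a copy of the batch to pair each tensor with a partner, then blends each pair with a coefficient drawn from `beta_range`; the double transpose broadcasts the per-sample coefficient over each tensor's trailing dimensions regardless of rank. A self-contained sketch of the same computation on a plain NumPy batch:

    import numpy as np

    def mixup(rng: np.random.Generator, batch: np.ndarray,
              beta_range: tuple = (0.0, 1.0)) -> np.ndarray:
        """Blend each tensor in the batch with a randomly chosen partner."""
        partners = np.array(batch)
        rng.shuffle(partners)  # randomly pair up tensors in the batch
        beta = rng.uniform(
            low=beta_range[0], high=beta_range[1], size=batch.shape[0]
        ).astype(np.float32)
        # (batch.T * beta).T scales tensor i by beta[i] whatever its rank
        return (batch.T * beta).T + (partners.T * (1.0 - beta)).T

    rng = np.random.default_rng(0)
    batch = rng.standard_normal((8, 3, 3)).astype(np.float32)
    mixed = mixup(rng, batch, beta_range=(0.4, 0.6))
    assert mixed.shape == batch.shape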
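The Gaussian presets scale their noise by the data itself: `noise_scale` is the per-element standard deviation measured over the whole input set, so a strength of 0.3 means noise at 30% of each feature's natural spread. A sketch of that scaling, assuming a plain NumPy batch (the real code iterates a tf.data dataset, and the helper name is illustrative):

    import numpy as np

    def gaussian_augment(batch: np.ndarray, noise_scale: np.ndarray,
                         strength: float, rng: np.random.Generator) -> np.ndarray:
        """Add zero-mean noise proportional to each feature's own std."""
        noise = rng.standard_normal(batch.shape).astype(np.float32)
        return batch + noise * strength * noise_scale

    rng = np.random.default_rng(1)
    data = rng.standard_normal((32, 5)).astype(np.float32)
    noise_scale = np.std(data, axis=0).astype(np.float32)  # per-feature std
    augmented = gaussian_augment(data, noise_scale, strength=0.3, rng=rng)
    assert augmented.shape == data.shape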
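Intermediate checkpoints are distinguished by an `_@STEPS` postfix inserted before the `.tflite` extension. A minimal sketch of that naming convention using `os.path.splitext`, as in the pre-patch code (the helper name is illustrative, not part of the module):

    import os

    def checkpoint_name(output_filename: str, steps_so_far: int) -> str:
        """Insert the _@STEPS postfix used by intermediate checkpoints."""
        stem, ext = os.path.splitext(output_filename)
        return f"{stem}_@{steps_so_far}{ext}"

    assert checkpoint_name("new.tflite", 100) == "new_@100.tflite"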