diff options
author | Annie Tallund <annie.tallund@arm.com> | 2023-03-13 17:00:31 +0100 |
---|---|---|
committer | Benjamin Klimczak <benjamin.klimczak@arm.com> | 2023-10-11 15:41:48 +0100 |
commit | f0b8ed75fed9dc69ab1f6313339f9f7e38bfc725 (patch) | |
tree | bc353fad664040b44915b5cf7ae807894b0b87e8 /src/mlia/nn/rewrite/core/train.py | |
parent | b236127b9a18ec2668271c6b5baafa6a7c1dde51 (diff) | |
download | mlia-f0b8ed75fed9dc69ab1f6313339f9f7e38bfc725.tar.gz |
MLIA-845 Migrate rewrite code
- Add required files for rewriting of TensorFlow Lite graphs
- Adapt rewrite dependency paths and project name
- Add license headers
Change-Id: I19c5f63215fe2af2fa7d7d44af08144c6c5f911c
Signed-off-by: Benjamin Klimczak <benjamin.klimczak@arm.com>
Diffstat (limited to 'src/mlia/nn/rewrite/core/train.py')
-rw-r--r-- | src/mlia/nn/rewrite/core/train.py | 487 |
1 file changed, 487 insertions, 0 deletions
diff --git a/src/mlia/nn/rewrite/core/train.py b/src/mlia/nn/rewrite/core/train.py new file mode 100644 index 0000000..a929b14 --- /dev/null +++ b/src/mlia/nn/rewrite/core/train.py @@ -0,0 +1,487 @@ +# SPDX-FileCopyrightText: Copyright 2023, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +import math +import os +import tempfile +from collections import defaultdict + +os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" +import tensorflow as tf + +tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR) + +try: + from tensorflow.keras.optimizers.schedules import CosineDecay +except ImportError: + # In TF 2.4 CosineDecay was still experimental + from tensorflow.keras.experimental import CosineDecay + +import numpy as np +from mlia.nn.rewrite.core.utils.numpy_tfrecord import ( + NumpyTFReader, + NumpyTFWriter, + TFLiteModel, + numpytf_count, +) +from mlia.nn.rewrite.core.utils.parallel import ParallelTFLiteModel +from mlia.nn.rewrite.core.graph_edit.record import record_model +from mlia.nn.rewrite.core.utils.utils import load, save +from mlia.nn.rewrite.core.extract import extract +from mlia.nn.rewrite.core.graph_edit.join import join_models +from mlia.nn.rewrite.core.graph_edit.diff import diff_stats + + +augmentation_presets = { + "none": (None, None), + "gaussian": (None, 1.0), + "mixup": (1.0, None), + "mixout": (1.6, None), + "mix_gaussian_large": (2.0, 1.0), + "mix_gaussian_small": (1.6, 0.3), +} + + +class SequentialTrainer: + def __init__( + self, + source_model, + output_model, + input_tfrec, + augment="gaussian", + steps=6000, + lr=1e-3, + batch_size=32, + show_progress=True, + eval_fn=None, + num_procs=1, + num_threads=0, + ): + self.source_model = source_model + self.output_model = output_model + self.input_tfrec = input_tfrec + self.default_augment = augment + self.default_steps = steps + self.default_lr = lr + self.default_batch_size = batch_size + self.show_progress = show_progress + self.num_procs = num_procs + self.num_threads = 
num_threads + self.first_replace = True + self.eval_fn = eval_fn + + def replace( + self, + model_fn, + input_tensors, + output_tensors, + augment=None, + steps=None, + lr=None, + batch_size=None, + ): + augment = self.default_augment if augment is None else augment + steps = self.default_steps if steps is None else steps + lr = self.default_lr if lr is None else lr + batch_size = self.default_batch_size if batch_size is None else batch_size + + if isinstance(augment, str): + augment = augmentation_presets[augment] + + if self.first_replace: + source_model = self.source_model + unmodified_model = None + else: + source_model = self.output_model + unmodified_model = self.source_model + + mae, nrmse = train( + source_model, + unmodified_model, + self.output_model, + self.input_tfrec, + model_fn, + input_tensors, + output_tensors, + augment, + steps, + lr, + batch_size, + False, + self.show_progress, + None, + 0, + self.num_procs, + self.num_threads, + ) + + self.first_replace = False + if self.eval_fn: + return self.eval_fn(mae, nrmse, self.output_model) + else: + return mae, nrmse + + +def train( + source_model, + unmodified_model, + output_model, + input_tfrec, + replace_fn, + input_tensors, + output_tensors, + augment, + steps, + lr, + batch_size, + verbose, + show_progress, + checkpoint_at=None, + checkpoint_decay_steps=0, + num_procs=1, + num_threads=0, +): + if unmodified_model: + unmodified_model_dir = tempfile.TemporaryDirectory() + unmodified_model_dir_path = unmodified_model_dir.name + extract( + unmodified_model_dir_path, + source_model, + input_tfrec, + input_tensors, + output_tensors, + ) + else: + unmodified_model_dir = None + unmodified_model_dir_path = None + + results = [] + with tempfile.TemporaryDirectory() as train_dir: + p = lambda file: os.path.join(train_dir, file) + + extract( + train_dir, + source_model, + input_tfrec, + input_tensors, + output_tensors, + num_procs=num_procs, + num_threads=num_threads, + ) + + tflite_filenames = train_in_dir( 
+ train_dir, + unmodified_model_dir_path, + p("new.tflite"), + replace_fn, + augment, + steps, + lr, + batch_size, + checkpoint_at=checkpoint_at, + checkpoint_decay_steps=checkpoint_decay_steps, + verbose=verbose, + show_progress=show_progress, + num_procs=num_procs, + num_threads=num_threads, + ) + + for i, filename in enumerate(tflite_filenames): + results.append(eval_in_dir(train_dir, filename, num_procs, num_threads)) + + if output_model: + if i + 1 < len(tflite_filenames): + # Append the same _@STEPS.tflite postfix used by intermediate checkpoints for all but the last output + postfix = filename.split("_@")[-1] + output_filename = output_model.split(".tflite")[0] + postfix + else: + output_filename = output_model + join_in_dir(train_dir, filename, output_filename) + + if unmodified_model_dir: + unmodified_model_dir.cleanup() + + return ( + results if checkpoint_at else results[0] + ) # only return a list if multiple checkpoints are asked for + + +def eval_in_dir(dir, new_part, num_procs=1, num_threads=0): + p = lambda file: os.path.join(dir, file) + input = ( + p("input_orig.tfrec") + if os.path.exists(p("input_orig.tfrec")) + else p("input.tfrec") + ) + output = ( + p("output_orig.tfrec") + if os.path.exists(p("output_orig.tfrec")) + else p("output.tfrec") + ) + + with tempfile.TemporaryDirectory() as tmp_dir: + predict = os.path.join(tmp_dir, "predict.tfrec") + record_model( + input, new_part, predict, num_procs=num_procs, num_threads=num_threads + ) + mae, nrmse = diff_stats(output, predict) + + return mae, nrmse + + +def join_in_dir(dir, new_part, output_model): + with tempfile.TemporaryDirectory() as tmp_dir: + d = lambda file: os.path.join(dir, file) + new_end = os.path.join(tmp_dir, "new_end.tflite") + join_models(new_part, d("end.tflite"), new_end) + join_models(d("start.tflite"), new_end, output_model) + + +def train_in_dir( + train_dir, + baseline_dir, + output_filename, + replace_fn, + augmentations, + steps, + lr=1e-3, + batch_size=32, + 
checkpoint_at=None, + checkpoint_decay_steps=0, + schedule="cosine", + verbose=False, + show_progress=False, + num_procs=None, + num_threads=1, +): + """Train a replacement for replace.tflite using the input.tfrec and output.tfrec in train_dir. + If baseline_dir is provided, train the replacement to match baseline outputs for train_dir inputs. + Result saved as new.tflite in train_dir. + """ + teacher_dir = baseline_dir if baseline_dir else train_dir + teacher = ParallelTFLiteModel( + "%s/replace.tflite" % teacher_dir, num_procs, num_threads, batch_size=batch_size + ) + replace = TFLiteModel("%s/replace.tflite" % train_dir) + assert len(teacher.input_tensors()) == 1, ( + "Can only train replacements with a single input tensor right now, found %s" + % teacher.input_tensors() + ) + assert len(teacher.output_tensors()) == 1, ( + "Can only train replacements with a single output tensor right now, found %s" + % teacher.output_tensors() + ) + input_name = teacher.input_tensors()[0] + output_name = teacher.output_tensors()[0] + + assert len(teacher.shape_from_name) == len( + replace.shape_from_name + ), "Baseline and train models must have the same number of inputs and outputs. Teacher: {}\nTrain dir: {}".format( + teacher.shape_from_name, replace.shape_from_name + ) + assert all( + tn == rn and (ts[1:] == rs[1:]).all() + for (tn, ts), (rn, rs) in zip( + teacher.shape_from_name.items(), replace.shape_from_name.items() + ) + ), "Baseline and train models must have the same input and output shapes for the subgraph being replaced. 
Teacher: {}\nTrain dir: {}".format( + teacher.shape_from_name, replace.shape_from_name + ) + + input_filename = os.path.join(train_dir, "input.tfrec") + total = numpytf_count(input_filename) + dict_inputs = NumpyTFReader(input_filename) + inputs = dict_inputs.map(lambda d: tf.squeeze(d[input_name], axis=0)) + if any(augmentations): + # Map the teacher inputs here because the augmentation stage passes these through a TFLite model to get the outputs + teacher_outputs = NumpyTFReader(os.path.join(teacher_dir, "input.tfrec")).map( + lambda d: tf.squeeze(d[input_name], axis=0) + ) + else: + teacher_outputs = NumpyTFReader(os.path.join(teacher_dir, "output.tfrec")).map( + lambda d: tf.squeeze(d[output_name], axis=0) + ) + + steps_per_epoch = math.ceil(total / batch_size) + epochs = int(math.ceil(steps / steps_per_epoch)) + if verbose: + print( + "Training on %d items for %d steps (%d epochs with batch size %d)" + % (total, epochs * steps_per_epoch, epochs, batch_size) + ) + + dataset = tf.data.Dataset.zip((inputs, teacher_outputs)) + if epochs > 1: + dataset = dataset.cache() + dataset = dataset.shuffle(total).repeat().batch(batch_size) + + if any(augmentations): + augment_train, augment_teacher = augment_fn_twins(dict_inputs, augmentations) + augment_fn = lambda train, teach: ( + augment_train({input_name: train})[input_name], + teacher(augment_teacher({input_name: teach}))[output_name], + ) + dataset = dataset.map( + lambda train, teach: tf.py_function( + augment_fn, inp=[train, teach], Tout=[tf.float32, tf.float32] + ) + ) + + dataset = dataset.prefetch(tf.data.AUTOTUNE) + + input_shape = teacher.shape_from_name[input_name][1:] + output_shape = teacher.shape_from_name[output_name][1:] + model = replace_fn(input_shape, output_shape) + + optimizer = tf.keras.optimizers.Nadam(learning_rate=lr) + loss_fn = tf.keras.losses.MeanSquaredError() + model.compile(optimizer=optimizer, loss=loss_fn) + + if verbose: + model.summary() + + steps_so_far = 0 + + def 
cosine_decay(epoch_step, logs): + """Cosine decay from lr at start of the run to zero at the end""" + current_step = epoch_step + steps_so_far + learning_rate = lr * (math.cos(math.pi * current_step / steps) + 1) / 2.0 + tf.keras.backend.set_value(optimizer.learning_rate, learning_rate) + + def late_decay(epoch_step, logs): + """Constant until the last 20% of the run, then linear decay to zero""" + current_step = epoch_step + steps_so_far + steps_remaining = steps - current_step + decay_length = steps // 5 + decay_fraction = min(steps_remaining, decay_length) / decay_length + learning_rate = lr * decay_fraction + tf.keras.backend.set_value(optimizer.learning_rate, learning_rate) + + if schedule == "cosine": + callbacks = [tf.keras.callbacks.LambdaCallback(on_batch_begin=cosine_decay)] + elif schedule == "late": + callbacks = [tf.keras.callbacks.LambdaCallback(on_batch_begin=late_decay)] + elif schedule == "constant": + callbacks = [] + else: + assert False, ( + 'LR schedule "%s" not implemented - expected "cosine", "constant" or "late"' + % schedule + ) + + output_filenames = [] + checkpoints = (checkpoint_at if checkpoint_at else []) + [steps] + while steps_so_far < steps: + steps_to_train = checkpoints.pop(0) - steps_so_far + lr_start = optimizer.learning_rate.numpy() + model.fit( + dataset, + epochs=1, + steps_per_epoch=steps_to_train, + callbacks=callbacks, + verbose=show_progress, + ) + steps_so_far += steps_to_train + print( + "lr decayed from %f to %f over %d steps" + % (lr_start, optimizer.learning_rate.numpy(), steps_to_train) + ) + + if steps_so_far < steps: + filename, ext = os.path.splitext(output_filename) + checkpoint_filename = filename + ("_@%d" % steps_so_far) + ext + else: + checkpoint_filename = output_filename + print("%d/%d: Saved as %s" % (steps_so_far, steps, checkpoint_filename)) + save_as_tflite( + model, + checkpoint_filename, + input_name, + replace.shape_from_name[input_name], + output_name, + replace.shape_from_name[output_name], + ) + 
output_filenames.append(checkpoint_filename) + + teacher.close() + return output_filenames + + +def save_as_tflite( + keras_model, filename, input_name, input_shape, output_name, output_shape +): + converter = tf.lite.TFLiteConverter.from_keras_model(keras_model) + tflite_model = converter.convert() + + with open(filename, "wb") as f: + f.write(tflite_model) + + # Now fix the shapes and names to match those we expect + fb = load(filename) + i = fb.subgraphs[0].inputs[0] + fb.subgraphs[0].tensors[i].shape = np.array(input_shape, dtype=np.int32) + fb.subgraphs[0].tensors[i].name = input_name.encode("utf-8") + o = fb.subgraphs[0].outputs[0] + fb.subgraphs[0].tensors[o].shape = np.array(output_shape, dtype=np.int32) + fb.subgraphs[0].tensors[o].name = output_name.encode("utf-8") + save(fb, filename) + + +def augment_fn_twins(inputs, augmentations): + """Return a pair of twinned augmentation functions with the same sequence of random numbers""" + seed = np.random.randint(2**32 - 1) + rng1 = np.random.default_rng(seed) + rng2 = np.random.default_rng(seed) + return augment_fn(inputs, augmentations, rng1), augment_fn( + inputs, augmentations, rng2 + ) + + +def augment_fn(inputs, augmentations, rng): + mixup_strength, gaussian_strength = augmentations + + augments = [] + + if mixup_strength: + mixup_range = (0.5 - mixup_strength / 2, 0.5 + mixup_strength / 2) + augment = lambda d: { + k: mixup(rng, v.numpy(), mixup_range) for k, v in d.items() + } + augments.append(augment) + + if gaussian_strength: + values = defaultdict(list) + for d in inputs.as_numpy_iterator(): + for k, v in d.items(): + values[k].append(v) + noise_scale = { + k: np.std(v, axis=0).astype(np.float32) for k, v in values.items() + } + augment = lambda d: { + k: v + + rng.standard_normal(v.shape).astype(np.float32) + * gaussian_strength + * noise_scale[k] + for k, v in d.items() + } + augments.append(augment) + + if len(augments) == 0: + return lambda x: x + elif len(augments) == 1: + return augments[0] + 
elif len(augments) == 2: + return lambda x: augments[1](augments[0](x)) + else: + assert False, "Unexpected number of augmentation functions (%d)" % len(augments) + + +def mixup(rng, batch, beta_range=(0.0, 1.0)): + """Each tensor in the batch becomes a linear combination of it and one other tensor""" + a = batch + b = np.array(batch) + rng.shuffle(b) # randomly pair up tensors in the batch + # random mixing coefficient for each pair + beta = rng.uniform( + low=beta_range[0], high=beta_range[1], size=batch.shape[0] + ).astype(np.float32) + return (a.T * beta).T + (b.T * (1.0 - beta)).T # return linear combinations |