Diffstat (limited to 'src/mlia/nn/rewrite/core/train.py')
-rw-r--r-- | src/mlia/nn/rewrite/core/train.py | 306 |
1 file changed, 190 insertions, 116 deletions
diff --git a/src/mlia/nn/rewrite/core/train.py b/src/mlia/nn/rewrite/core/train.py
index c8497a4..42bf653 100644
--- a/src/mlia/nn/rewrite/core/train.py
+++ b/src/mlia/nn/rewrite/core/train.py
@@ -1,8 +1,8 @@
 # SPDX-FileCopyrightText: Copyright 2023, Arm Limited and/or its affiliates.
 # SPDX-License-Identifier: Apache-2.0
 """Sequential trainer."""
-# pylint: disable=too-many-arguments, too-many-instance-attributes,
-# pylint: disable=too-many-locals, too-many-branches, too-many-statements
+# pylint: disable=too-many-locals
+# pylint: disable=too-many-statements
 from __future__ import annotations
 
 import logging
@@ -10,10 +10,13 @@ import math
 import os
 import tempfile
 from collections import defaultdict
+from contextlib import contextmanager
+from dataclasses import dataclass
 from pathlib import Path
 from typing import Any
 from typing import Callable
 from typing import cast
+from typing import Generator as GeneratorType
 from typing import get_args
 from typing import Literal
 
@@ -27,10 +30,10 @@ from mlia.nn.rewrite.core.graph_edit.join import join_models
 from mlia.nn.rewrite.core.graph_edit.record import record_model
 from mlia.nn.rewrite.core.utils.numpy_tfrecord import numpytf_count
 from mlia.nn.rewrite.core.utils.numpy_tfrecord import numpytf_read
-from mlia.nn.rewrite.core.utils.numpy_tfrecord import TFLiteModel
 from mlia.nn.rewrite.core.utils.parallel import ParallelTFLiteModel
-from mlia.nn.rewrite.core.utils.utils import load
-from mlia.nn.rewrite.core.utils.utils import save
+from mlia.nn.tensorflow.config import TFLiteModel
+from mlia.nn.tensorflow.tflite_graph import load_fb
+from mlia.nn.tensorflow.tflite_graph import save_fb
 from mlia.utils.logging import log_action
 
@@ -38,7 +41,7 @@ os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
 tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
 logger = logging.getLogger(__name__)
 
-augmentation_presets = {
+AUGMENTATION_PRESETS = {
     "none": (None, None),
     "gaussian": (None, 1.0),
     "mixup": (1.0, None),
@@ -51,6 +54,21 @@ LearningRateSchedule = Literal["cosine", "late", "constant"]
 LEARNING_RATE_SCHEDULES = get_args(LearningRateSchedule)
 
 
+@dataclass
+class TrainingParameters:
+    """Define default parameters for the training."""
+
+    augmentations: tuple[float | None, float | None] = AUGMENTATION_PRESETS["gaussian"]
+    batch_size: int = 32
+    steps: int = 48000
+    learning_rate: float = 1e-3
+    learning_rate_schedule: LearningRateSchedule = "cosine"
+    num_procs: int = 1
+    num_threads: int = 0
+    show_progress: bool = True
+    checkpoint_at: list | None = None
+
+
 def train(
     source_model: str,
     unmodified_model: Any,
@@ -59,16 +77,7 @@ def train(
     replace_fn: Callable,
     input_tensors: list,
     output_tensors: list,
-    augment: tuple[float | None, float | None],
-    steps: int,
-    learning_rate: float,
-    batch_size: int,
-    verbose: bool,
-    show_progress: bool,
-    learning_rate_schedule: LearningRateSchedule = "cosine",
-    checkpoint_at: list | None = None,
-    num_procs: int = 1,
-    num_threads: int = 0,
+    train_params: TrainingParameters = TrainingParameters(),
 ) -> Any:
     """Extract and train a model, and return the results."""
     if unmodified_model:
@@ -95,29 +104,27 @@ def train(
         input_tfrec,
         input_tensors,
         output_tensors,
-        num_procs=num_procs,
-        num_threads=num_threads,
+        num_procs=train_params.num_procs,
+        num_threads=train_params.num_threads,
     )
 
     tflite_filenames = train_in_dir(
-        train_dir,
-        unmodified_model_dir_path,
-        Path(train_dir, "new.tflite"),
-        replace_fn,
-        augment,
-        steps,
-        learning_rate,
-        batch_size,
-        checkpoint_at=checkpoint_at,
-        verbose=verbose,
-        show_progress=show_progress,
-        num_procs=num_procs,
-        num_threads=num_threads,
-        schedule=learning_rate_schedule,
+        train_dir=train_dir,
+        baseline_dir=unmodified_model_dir_path,
+        output_filename=Path(train_dir, "new.tflite"),
+        replace_fn=replace_fn,
+        train_params=train_params,
     )
 
     for i, filename in enumerate(tflite_filenames):
-        results.append(eval_in_dir(train_dir, filename, num_procs, num_threads))
+        results.append(
+            eval_in_dir(
+                train_dir,
+                filename,
+                train_params.num_procs,
+                train_params.num_threads,
+            )
+        )
 
         if output_model:
             if i + 1 < len(tflite_filenames):
@@ -133,7 +140,7 @@ def train(
         cast(tempfile.TemporaryDirectory, unmodified_model_dir).cleanup()
 
     return (
-        results if checkpoint_at else results[0]
+        results if train_params.checkpoint_at else results[0]
     )  # only return a list if multiple checkpoints are asked for
 
 
@@ -176,46 +183,24 @@ def join_in_dir(model_dir: str, new_part: str, output_model: str) -> None:
     join_models(Path(model_dir, "start.tflite"), new_end, output_model)
 
 
-def train_in_dir(
-    train_dir: str,
-    baseline_dir: Any,
-    output_filename: Path,
-    replace_fn: Callable,
-    augmentations: tuple[float | None, float | None],
-    steps: int,
-    learning_rate: float = 1e-3,
-    batch_size: int = 32,
-    checkpoint_at: list | None = None,
-    schedule: str = "cosine",
-    verbose: bool = False,
-    show_progress: bool = False,
-    num_procs: int = 0,
-    num_threads: int = 1,
-) -> list:
-    """Train a replacement for replace.tflite using the input.tfrec \
-    and output.tfrec in train_dir.
-
-    If baseline_dir is provided, train the replacement to match baseline
-    outputs for train_dir inputs. Result saved as new.tflite in train_dir.
-    """
-    teacher_dir = baseline_dir if baseline_dir else train_dir
-    teacher = ParallelTFLiteModel(
-        f"{teacher_dir}/replace.tflite", num_procs, num_threads, batch_size=batch_size
-    )
-    replace = TFLiteModel(f"{train_dir}/replace.tflite")
-
+def _get_io_tensors(model: TFLiteModel) -> tuple[str, str]:
     assert (
-        len(teacher.input_tensors()) == 1
+        len(model.input_tensors()) == 1
     ), f"Can only train replacements with a single input tensor right now, \
-        found {teacher.input_tensors()}"
+        found {model.input_tensors()}"
     assert (
-        len(teacher.output_tensors()) == 1
+        len(model.output_tensors()) == 1
     ), f"Can only train replacements with a single output tensor right now, \
-        found {teacher.output_tensors()}"
+        found {model.output_tensors()}"
+
+    input_name = model.input_tensors()[0]
+    output_name = model.output_tensors()[0]
+    return (input_name, output_name)
 
-    input_name = teacher.input_tensors()[0]
-    output_name = teacher.output_tensors()[0]
 
+def _check_model_compatibility(teacher: TFLiteModel, replace: TFLiteModel) -> None:
+    """Assert that teacher and replaced sub-graph are compatible."""
     assert len(teacher.shape_from_name) == len(
         replace.shape_from_name
     ), f"Baseline and train models must have the same number of inputs and outputs. \
@@ -230,10 +215,37 @@
         subgraph being replaced. Teacher: {teacher.shape_from_name}\n \
         Train dir: {replace.shape_from_name}"
 
+
+def set_up_data_pipeline(
+    teacher: TFLiteModel,
+    replace: TFLiteModel,
+    train_dir: str,
+    augmentations: tuple[float | None, float | None],
+    steps: int,
+    batch_size: int = 32,
+) -> tf.data.Dataset:
+    """Create a data pipeline for training of the replacement model."""
+    _check_model_compatibility(teacher, replace)
+
+    input_name, output_name = _get_io_tensors(teacher)
+
     input_filename = Path(train_dir, "input.tfrec")
     total = numpytf_count(str(input_filename))
     dict_inputs = numpytf_read(str(input_filename))
+
     inputs = dict_inputs.map(lambda d: tf.squeeze(d[input_name], axis=0))
+
+    steps_per_epoch = math.ceil(total / batch_size)
+    epochs = int(math.ceil(steps / steps_per_epoch))
+    logger.info(
+        "Training on %d items for %d steps (%d epochs with batch size %d)",
+        total,
+        epochs * steps_per_epoch,
+        epochs,
+        batch_size,
+    )
+
+    teacher_dir = Path(teacher.model_path).parent
     if any(augmentations):
         # Map the teacher inputs here because the augmentation stage passes these
         # through a TFLite model to get the outputs
@@ -245,17 +257,6 @@
             lambda d: tf.squeeze(d[output_name], axis=0)
         )
 
-    steps_per_epoch = math.ceil(total / batch_size)
-    epochs = int(math.ceil(steps / steps_per_epoch))
-    if verbose:
-        logger.info(
-            "Training on %d items for %d steps (%d epochs with batch size %d)",
-            total,
-            epochs * steps_per_epoch,
-            epochs,
-            batch_size,
-        )
-
     dataset = tf.data.Dataset.zip((inputs, teacher_outputs))
     if epochs > 1:
         dataset = dataset.cache()
@@ -268,10 +269,9 @@
             train: Any, teach: Any  # pylint: disable=redefined-outer-name
         ) -> tuple:
             """Return results of train and teach based on augmentations."""
-            return (
-                augment_train({input_name: train})[input_name],
-                teacher(augment_teacher({input_name: teach}))[output_name],
-            )
+            augmented_train = augment_train({input_name: train})[input_name]
+            augmented_teach = teacher(augment_teacher({input_name: teach}))[output_name]
+            return (augmented_train, augmented_teach)
 
         dataset = dataset.map(
             lambda augment_train, augment_teach: tf.py_function(
@@ -281,18 +281,67 @@
             )
         )
 
+    # Restore data shapes of the dataset as they are set to unknown per default
+    # and get lost during augmentation with tf.py_function.
+    shape_in, shape_out = (
+        teacher.shape_from_name[name].tolist() for name in (input_name, output_name)
+    )
+    for shape in (shape_in, shape_out):
+        shape[0] = None  # set dynamic batch size
+
+    def restore_shapes(input_: Any, output: Any) -> tuple[Any, Any]:
+        input_.set_shape(shape_in)
+        output.set_shape(shape_out)
+        return input_, output
+
+    dataset = dataset.map(restore_shapes)
     dataset = dataset.prefetch(tf.data.AUTOTUNE)
+    return dataset
+
+
+def train_in_dir(
+    train_dir: str,
+    baseline_dir: Any,
+    output_filename: Path,
+    replace_fn: Callable,
+    train_params: TrainingParameters = TrainingParameters(),
+) -> list[str]:
+    """Train a replacement for replace.tflite using the input.tfrec \
+    and output.tfrec in train_dir.
+
+    If baseline_dir is provided, train the replacement to match baseline
+    outputs for train_dir inputs. Result saved as new.tflite in train_dir.
+    """
+    teacher_dir = baseline_dir if baseline_dir else train_dir
+    teacher = ParallelTFLiteModel(
+        f"{teacher_dir}/replace.tflite",
+        train_params.num_procs,
+        train_params.num_threads,
+        batch_size=train_params.batch_size,
+    )
+    replace = TFLiteModel(f"{train_dir}/replace.tflite")
+
+    input_name, output_name = _get_io_tensors(teacher)
+
+    dataset = set_up_data_pipeline(
+        teacher,
+        replace,
+        train_dir,
+        augmentations=train_params.augmentations,
+        steps=train_params.steps,
+        batch_size=train_params.batch_size,
+    )
 
     input_shape = teacher.shape_from_name[input_name][1:]
     output_shape = teacher.shape_from_name[output_name][1:]
+
     model = replace_fn(input_shape, output_shape)
 
-    optimizer = tf.keras.optimizers.Nadam(learning_rate=learning_rate)
+    optimizer = tf.keras.optimizers.Nadam(learning_rate=train_params.learning_rate)
     loss_fn = tf.keras.losses.MeanSquaredError()
-    model.compile(optimizer=optimizer, loss=loss_fn)
+    model.compile(optimizer=optimizer, loss=loss_fn, metrics=["mae"])
 
-    if verbose:
-        model.summary()
+    logger.info(model.summary())
 
     steps_so_far = 0
@@ -302,7 +351,9 @@
         """Cosine decay from learning rate at start of the run to zero at the end."""
         current_step = epoch_step + steps_so_far
         cd_learning_rate = (
-            learning_rate * (math.cos(math.pi * current_step / steps) + 1) / 2.0
+            train_params.learning_rate
+            * (math.cos(math.pi * current_step / train_params.steps) + 1)
+            / 2.0
        )
         tf.keras.backend.set_value(optimizer.learning_rate, cd_learning_rate)
@@ -311,28 +362,29 @@
     ) -> None:
         """Constant until the last 20% of the run, then linear decay to zero."""
         current_step = epoch_step + steps_so_far
-        steps_remaining = steps - current_step
-        decay_length = steps // 5
+        steps_remaining = train_params.steps - current_step
+        decay_length = train_params.steps // 5
         decay_fraction = min(steps_remaining, decay_length) / decay_length
-        ld_learning_rate = learning_rate * decay_fraction
+        ld_learning_rate = train_params.learning_rate * decay_fraction
         tf.keras.backend.set_value(optimizer.learning_rate, ld_learning_rate)
 
-    if schedule == "cosine":
+    assert train_params.learning_rate_schedule in LEARNING_RATE_SCHEDULES, (
+        f'Learning rate schedule "{train_params.learning_rate_schedule}" '
+        f"not implemented - expected one of {LEARNING_RATE_SCHEDULES}."
+    )
+    if train_params.learning_rate_schedule == "cosine":
         callbacks = [tf.keras.callbacks.LambdaCallback(on_batch_begin=cosine_decay)]
-    elif schedule == "late":
+    elif train_params.learning_rate_schedule == "late":
         callbacks = [tf.keras.callbacks.LambdaCallback(on_batch_begin=late_decay)]
-    elif schedule == "constant":
+    elif train_params.learning_rate_schedule == "constant":
         callbacks = []
-    else:
-        assert schedule not in LEARNING_RATE_SCHEDULES
-        raise ValueError(
-            f'Learning rate schedule "{schedule}" not implemented - '
-            f"expected one of {LEARNING_RATE_SCHEDULES}."
-        )
 
     output_filenames = []
-    checkpoints = (checkpoint_at if checkpoint_at else []) + [steps]
-    while steps_so_far < steps:
+    checkpoints = (train_params.checkpoint_at if train_params.checkpoint_at else []) + [
+        train_params.steps
+    ]
+
+    while steps_so_far < train_params.steps:
         steps_to_train = checkpoints.pop(0) - steps_so_far
         lr_start = optimizer.learning_rate.numpy()
         model.fit(
@@ -340,7 +392,7 @@
             epochs=1,
             steps_per_epoch=steps_to_train,
             callbacks=callbacks,
-            verbose=show_progress,
+            verbose=train_params.show_progress,
         )
         steps_so_far += steps_to_train
         logger.info(
@@ -350,12 +402,14 @@
             steps_to_train,
         )
 
-        if steps_so_far < steps:
+        if steps_so_far < train_params.steps:
             filename, ext = Path(output_filename).parts[1:]
             checkpoint_filename = filename + (f"_@{steps_so_far}") + ext
         else:
             checkpoint_filename = str(output_filename)
-        with log_action(f"{steps_so_far}/{steps}: Saved as {checkpoint_filename}"):
+        with log_action(
+            f"{steps_so_far}/{train_params.steps}: Saved as {checkpoint_filename}"
+        ):
             save_as_tflite(
                 model,
                 checkpoint_filename,
@@ -379,14 +433,30 @@
     output_shape: list,
 ) -> None:
     """Save Keras model as TFLite file."""
-    converter = tf.lite.TFLiteConverter.from_keras_model(keras_model)
+
+    @contextmanager
+    def fixed_input(keras_model: tf.keras.Model, tmp_shape: list) -> GeneratorType:
+        """Fix the input shape of the Keras model temporarily.
+
+        This avoids artifacts during conversion to TensorFlow Lite.
+        """
+        orig_shape = keras_model.input.shape
+        keras_model.input.set_shape(tf.TensorShape(tmp_shape))
+        try:
+            yield keras_model
+        finally:
+            # Restore original shape to not interfere with further training
+            keras_model.input.set_shape(orig_shape)
+
+    with fixed_input(keras_model, input_shape) as fixed_model:
+        converter = tf.lite.TFLiteConverter.from_keras_model(fixed_model)
     tflite_model = converter.convert()
 
     with open(filename, "wb") as file:
         file.write(tflite_model)
 
     # Now fix the shapes and names to match those we expect
-    flatbuffer = load(filename)
+    flatbuffer = load_fb(filename)
     i = flatbuffer.subgraphs[0].inputs[0]
     flatbuffer.subgraphs[0].tensors[i].shape = np.array(input_shape, dtype=np.int32)
     flatbuffer.subgraphs[0].tensors[i].name = input_name.encode("utf-8")
@@ -395,11 +465,11 @@
         output_shape, dtype=np.int32
     )
     flatbuffer.subgraphs[0].tensors[output].name = output_name.encode("utf-8")
-    save(flatbuffer, filename)
+    save_fb(flatbuffer, filename)
 
 
 def augment_fn_twins(
-    inputs: dict, augmentations: tuple[float | None, float | None]
+    inputs: tf.data.Dataset, augmentations: tuple[float | None, float | None]
 ) -> Any:
     """Return a pair of twinned augmentation functions with the same sequence \
     of random numbers."""
@@ -415,6 +485,11 @@
     inputs: Any, augmentations: tuple[float | None, float | None], rng: Generator
 ) -> Any:
     """Augmentation module."""
+    assert len(augmentations) == 2, (
+        f"Unexpected number of augmentation parameters: {len(augmentations)} "
+        "(must be 2)"
+    )
+
     mixup_strength, gaussian_strength = augmentations
 
     augments = []
@@ -449,17 +524,16 @@
         augments.append(gaussian_strength_augment)
 
-    if len(augments) == 0:  # pylint: disable=no-else-return
+    if len(augments) == 0:
         return lambda x: x
-    elif len(augments) == 1:
+    if len(augments) == 1:
         return augments[0]
-    elif len(augments) == 2:
+    if len(augments) == 2:
         return lambda x: augments[1](augments[0](x))
-    else:
-        assert (
-            False
-        ), f"Unexpected number of augmentation \
-            functions ({len(augments)})"
+
+    raise RuntimeError(
+        "Unexpected number of augmentation functions (must be <=2): " f"{len(augments)}"
+    )
 
 
 def mixup(rng: Generator, batch: Any, beta_range: tuple = (0.0, 1.0)) -> Any:
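
After this change, callers hand a single TrainingParameters value to train() / train_in_dir() instead of a dozen keyword arguments. Below is a minimal usage sketch against the new train_in_dir() signature; the directory name, paths, and the replacement-building function are hypothetical placeholders, not part of the commit:

    from pathlib import Path

    import numpy as np
    import tensorflow as tf

    from mlia.nn.rewrite.core.train import AUGMENTATION_PRESETS
    from mlia.nn.rewrite.core.train import train_in_dir
    from mlia.nn.rewrite.core.train import TrainingParameters


    def build_replacement(input_shape, output_shape):
        """Hypothetical replace_fn: swap the sub-graph for a single dense layer."""
        return tf.keras.Sequential(
            [
                tf.keras.layers.InputLayer(input_shape=input_shape),
                tf.keras.layers.Flatten(),
                tf.keras.layers.Dense(int(np.prod(output_shape))),
                tf.keras.layers.Reshape(tuple(output_shape)),
            ]
        )


    # Override only the fields that differ from the defaults; the rest
    # (48000 steps, batch size 32, cosine schedule, ...) are inherited.
    params = TrainingParameters(
        augmentations=AUGMENTATION_PRESETS["mixup"],
        steps=12000,
        checkpoint_at=[4000, 8000],  # also save intermediate checkpoints
    )

    checkpoint_files = train_in_dir(
        train_dir="extracted",  # hypothetical dir with replace.tflite and *.tfrec
        baseline_dir=None,  # no separate teacher: match the recorded outputs
        output_filename=Path("extracted", "new.tflite"),
        replace_fn=build_replacement,
        train_params=params,
    )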
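The two non-constant learning-rate rules now read their settings from train_params but are otherwise unchanged. Pulled out as standalone functions (the names here are mine; the formulas mirror cosine_decay and late_decay above), their behaviour is easy to check:

    import math


    def cosine_lr(step: int, total_steps: int, lr0: float = 1e-3) -> float:
        """Cosine decay from lr0 at step 0 down to zero at total_steps."""
        return lr0 * (math.cos(math.pi * step / total_steps) + 1) / 2.0


    def late_lr(step: int, total_steps: int, lr0: float = 1e-3) -> float:
        """Constant until the last 20% of the run, then linear decay to zero."""
        steps_remaining = total_steps - step
        decay_length = total_steps // 5
        return lr0 * min(steps_remaining, decay_length) / decay_length


    assert cosine_lr(0, 48000) == 1e-3  # starts at the configured rate
    assert abs(cosine_lr(24000, 48000) - 5e-4) < 1e-9  # halfway: half the rate
    assert late_lr(38400, 48000) == 1e-3  # still constant at the 80% mark
    assert late_lr(48000, 48000) == 0.0  # reaches zero on the final step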
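The new restore_shapes map in set_up_data_pipeline exists because tf.py_function (used for the augmentation step) erases the static shape of dataset elements. A self-contained illustration with a made-up element shape:

    import tensorflow as tf

    dataset = tf.data.Dataset.from_tensor_slices(tf.zeros((8, 28, 28, 1))).batch(4)
    print(dataset.element_spec.shape)  # (None, 28, 28, 1): static shape is known

    # After a py_function map, the element shape becomes completely unknown,
    # which breaks downstream Keras layers that need a defined input rank.
    dataset = dataset.map(
        lambda x: tf.py_function(lambda t: t + 1.0, [x], Tout=tf.float32)
    )
    print(dataset.element_spec.shape)  # <unknown>

    def restore_shapes(x):
        x.set_shape([None, 28, 28, 1])  # dynamic batch dim, as in train.py
        return x

    dataset = dataset.map(restore_shapes)
    print(dataset.element_spec.shape)  # (None, 28, 28, 1) again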