22 files changed, 591 insertions, 425 deletions
diff --git a/src/mlia/nn/rewrite/core/graph_edit/cut.py b/src/mlia/nn/rewrite/core/graph_edit/cut.py
index 2707eb1..13a5268 100644
--- a/src/mlia/nn/rewrite/core/graph_edit/cut.py
+++ b/src/mlia/nn/rewrite/core/graph_edit/cut.py
@@ -9,8 +9,8 @@ import tensorflow as tf
 from tensorflow.lite.python.schema_py_generated import ModelT
 from tensorflow.lite.python.schema_py_generated import SubGraphT
 
-from mlia.nn.rewrite.core.utils.utils import load
-from mlia.nn.rewrite.core.utils.utils import save
+from mlia.nn.tensorflow.tflite_graph import load_fb
+from mlia.nn.tensorflow.tflite_graph import save_fb
 
 os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
 tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
@@ -138,8 +138,8 @@ def cut_model(
     output_file: str,
 ) -> None:
     """Cut subgraphs and simplify a given model."""
-    model = load(model_file)
+    model = load_fb(model_file)
     subgraph = model.subgraphs[subgraph_index]
     cut_subgraph(subgraph, input_names, output_names)
     simplify(model)
-    save(model, output_file)
+    save_fb(model, output_file)
diff --git a/src/mlia/nn/rewrite/core/graph_edit/join.py b/src/mlia/nn/rewrite/core/graph_edit/join.py
index 2530ec8..70109d8 100644
--- a/src/mlia/nn/rewrite/core/graph_edit/join.py
+++ b/src/mlia/nn/rewrite/core/graph_edit/join.py
@@ -11,8 +11,8 @@ from tensorflow.lite.python.schema_py_generated import ModelT
 from tensorflow.lite.python.schema_py_generated import OperatorCodeT
 from tensorflow.lite.python.schema_py_generated import SubGraphT
 
-from mlia.nn.rewrite.core.utils.utils import load
-from mlia.nn.rewrite.core.utils.utils import save
+from mlia.nn.tensorflow.tflite_graph import load_fb
+from mlia.nn.tensorflow.tflite_graph import save_fb
 
 os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
 tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
@@ -26,12 +26,12 @@ def join_models(
     subgraph_dst: int = 0,
 ) -> None:
     """Join two models and save the result into a given model file path."""
-    src_model = load(input_src)
-    dst_model = load(input_dst)
+    src_model = load_fb(input_src)
+    dst_model = load_fb(input_dst)
     src_subgraph = src_model.subgraphs[subgraph_src]
     dst_subgraph = dst_model.subgraphs[subgraph_dst]
     join_subgraphs(src_model, src_subgraph, dst_model, dst_subgraph)
-    save(dst_model, output_file)
+    save_fb(dst_model, output_file)
 
 
 def join_subgraphs(
"""Return string representation of the configuration.""" @@ -40,8 +43,8 @@ class Rewriter(Optimizer): ): """Init Rewriter instance.""" self.model = TFLiteModel(tflite_model_path) + self.model_path = tflite_model_path self.optimizer_configuration = optimizer_configuration - self.train_dir = "" def apply_optimization(self) -> None: """Apply the rewrite flow.""" @@ -61,50 +64,36 @@ class Rewriter(Optimizer): replace_fn = get_function(replace_function) - augmentation_preset = (None, None) use_unmodified_model = True tflite_model = self.model.model_path tfrecord = str(self.optimizer_configuration.dataset) - with tempfile.TemporaryDirectory() as tmp_dir: - tmp_output = Path(tmp_dir, "output.tflite") - - if self.train_dir: - tmp_new = Path(tmp_dir, "new.tflite") - new_part = train_in_dir( - train_dir=self.train_dir, - baseline_dir=None, - output_filename=tmp_new, - replace_fn=replace_fn, - augmentations=augmentation_preset, - steps=32, - learning_rate=1e-3, - batch_size=1, - verbose=True, - show_progress=True, - ) - eval_in_dir(self.train_dir, new_part[0]) - join_in_dir(self.train_dir, new_part[0], str(tmp_output)) - else: - if not self.optimizer_configuration.layers_to_optimize: - raise ConfigurationError( - "Input and output tensor names need to be set for rewrite." - ) - train( - source_model=tflite_model, - unmodified_model=tflite_model if use_unmodified_model else None, - output_model=str(tmp_output), - input_tfrec=str(tfrecord), - replace_fn=replace_fn, - input_tensors=[self.optimizer_configuration.layers_to_optimize[0]], - output_tensors=[self.optimizer_configuration.layers_to_optimize[1]], - augment=augmentation_preset, - steps=32, - learning_rate=1e-3, - batch_size=1, - verbose=True, - show_progress=True, - ) + tmp_dir = tempfile.mkdtemp() + tmp_output = Path(tmp_dir, "output.tflite") + + if not self.optimizer_configuration.layers_to_optimize: + raise ConfigurationError( + "Input and output tensor names need to be set for rewrite." + ) + result = train( + source_model=tflite_model, + unmodified_model=tflite_model if use_unmodified_model else None, + output_model=str(tmp_output), + input_tfrec=str(tfrecord), + replace_fn=replace_fn, + input_tensors=[self.optimizer_configuration.layers_to_optimize[0]], + output_tensors=[self.optimizer_configuration.layers_to_optimize[1]], + train_params=self.optimizer_configuration.train_params, + ) + + self.model = TFLiteModel(tmp_output) + + if result: + stats_as_str = ", ".join(str(stats) for stats in result) + logger.info( + "The MAE and NRMSE between original and replacement [%s]", + stats_as_str, + ) def get_model(self) -> TFLiteModel: """Return optimized model.""" diff --git a/src/mlia/nn/rewrite/core/train.py b/src/mlia/nn/rewrite/core/train.py index c8497a4..42bf653 100644 --- a/src/mlia/nn/rewrite/core/train.py +++ b/src/mlia/nn/rewrite/core/train.py @@ -1,8 +1,8 @@ # SPDX-FileCopyrightText: Copyright 2023, Arm Limited and/or its affiliates. 
diff --git a/src/mlia/nn/rewrite/core/train.py b/src/mlia/nn/rewrite/core/train.py
index c8497a4..42bf653 100644
--- a/src/mlia/nn/rewrite/core/train.py
+++ b/src/mlia/nn/rewrite/core/train.py
@@ -1,8 +1,8 @@
 # SPDX-FileCopyrightText: Copyright 2023, Arm Limited and/or its affiliates.
 # SPDX-License-Identifier: Apache-2.0
 """Sequential trainer."""
-# pylint: disable=too-many-arguments, too-many-instance-attributes,
-# pylint: disable=too-many-locals, too-many-branches, too-many-statements
+# pylint: disable=too-many-locals
+# pylint: disable=too-many-statements
 from __future__ import annotations
 
 import logging
@@ -10,10 +10,13 @@ import logging
 import math
 import os
 import tempfile
 from collections import defaultdict
+from contextlib import contextmanager
+from dataclasses import dataclass
 from pathlib import Path
 from typing import Any
 from typing import Callable
 from typing import cast
+from typing import Generator as GeneratorType
 from typing import get_args
 from typing import Literal
@@ -27,10 +30,10 @@ from mlia.nn.rewrite.core.graph_edit.join import join_models
 from mlia.nn.rewrite.core.graph_edit.record import record_model
 from mlia.nn.rewrite.core.utils.numpy_tfrecord import numpytf_count
 from mlia.nn.rewrite.core.utils.numpy_tfrecord import numpytf_read
-from mlia.nn.rewrite.core.utils.numpy_tfrecord import TFLiteModel
 from mlia.nn.rewrite.core.utils.parallel import ParallelTFLiteModel
-from mlia.nn.rewrite.core.utils.utils import load
-from mlia.nn.rewrite.core.utils.utils import save
+from mlia.nn.tensorflow.config import TFLiteModel
+from mlia.nn.tensorflow.tflite_graph import load_fb
+from mlia.nn.tensorflow.tflite_graph import save_fb
 from mlia.utils.logging import log_action
 
@@ -38,7 +41,7 @@ os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
 tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
 logger = logging.getLogger(__name__)
 
-augmentation_presets = {
+AUGMENTATION_PRESETS = {
     "none": (None, None),
     "gaussian": (None, 1.0),
     "mixup": (1.0, None),
@@ -51,6 +54,21 @@ LearningRateSchedule = Literal["cosine", "late", "constant"]
 LEARNING_RATE_SCHEDULES = get_args(LearningRateSchedule)
 
 
+@dataclass
+class TrainingParameters:
+    """Define default parameters for the training."""
+
+    augmentations: tuple[float | None, float | None] = AUGMENTATION_PRESETS["gaussian"]
+    batch_size: int = 32
+    steps: int = 48000
+    learning_rate: float = 1e-3
+    learning_rate_schedule: LearningRateSchedule = "cosine"
+    num_procs: int = 1
+    num_threads: int = 0
+    show_progress: bool = True
+    checkpoint_at: list | None = None
+
+
 def train(
     source_model: str,
     unmodified_model: Any,
@@ -59,16 +77,7 @@ def train(
     replace_fn: Callable,
     input_tensors: list,
     output_tensors: list,
-    augment: tuple[float | None, float | None],
-    steps: int,
-    learning_rate: float,
-    batch_size: int,
-    verbose: bool,
-    show_progress: bool,
-    learning_rate_schedule: LearningRateSchedule = "cosine",
-    checkpoint_at: list | None = None,
-    num_procs: int = 1,
-    num_threads: int = 0,
+    train_params: TrainingParameters = TrainingParameters(),
 ) -> Any:
     """Extract and train a model, and return the results."""
     if unmodified_model:
@@ -95,29 +104,27 @@ def train(
         input_tfrec,
         input_tensors,
         output_tensors,
-        num_procs=num_procs,
-        num_threads=num_threads,
+        num_procs=train_params.num_procs,
+        num_threads=train_params.num_threads,
     )
 
     tflite_filenames = train_in_dir(
-        train_dir,
-        unmodified_model_dir_path,
-        Path(train_dir, "new.tflite"),
-        replace_fn,
-        augment,
-        steps,
-        learning_rate,
-        batch_size,
-        checkpoint_at=checkpoint_at,
-        verbose=verbose,
-        show_progress=show_progress,
-        num_procs=num_procs,
-        num_threads=num_threads,
-        schedule=learning_rate_schedule,
+        train_dir=train_dir,
+        baseline_dir=unmodified_model_dir_path,
+        output_filename=Path(train_dir, "new.tflite"),
+        replace_fn=replace_fn,
+        train_params=train_params,
     )
 
     for i, filename in enumerate(tflite_filenames):
-        results.append(eval_in_dir(train_dir, filename, num_procs, num_threads))
+        results.append(
+            eval_in_dir(
+                train_dir,
+                filename,
+                train_params.num_procs,
+                train_params.num_threads,
+            )
+        )
 
         if output_model:
             if i + 1 < len(tflite_filenames):
@@ -133,7 +140,7 @@ def train(
         cast(tempfile.TemporaryDirectory, unmodified_model_dir).cleanup()
 
     return (
-        results if checkpoint_at else results[0]
+        results if train_params.checkpoint_at else results[0]
     )  # only return a list if multiple checkpoints are asked for
 
@@ -176,46 +183,24 @@ def join_in_dir(model_dir: str, new_part: str, output_model: str) -> None:
     join_models(Path(model_dir, "start.tflite"), new_end, output_model)
 
 
-def train_in_dir(
-    train_dir: str,
-    baseline_dir: Any,
-    output_filename: Path,
-    replace_fn: Callable,
-    augmentations: tuple[float | None, float | None],
-    steps: int,
-    learning_rate: float = 1e-3,
-    batch_size: int = 32,
-    checkpoint_at: list | None = None,
-    schedule: str = "cosine",
-    verbose: bool = False,
-    show_progress: bool = False,
-    num_procs: int = 0,
-    num_threads: int = 1,
-) -> list:
-    """Train a replacement for replace.tflite using the input.tfrec \
-    and output.tfrec in train_dir.
-
-    If baseline_dir is provided, train the replacement to match baseline
-    outputs for train_dir inputs. Result saved as new.tflite in train_dir.
-    """
-    teacher_dir = baseline_dir if baseline_dir else train_dir
-    teacher = ParallelTFLiteModel(
-        f"{teacher_dir}/replace.tflite", num_procs, num_threads, batch_size=batch_size
-    )
-    replace = TFLiteModel(f"{train_dir}/replace.tflite")
+def _get_io_tensors(model: TFLiteModel) -> tuple[str, str]:
     assert (
-        len(teacher.input_tensors()) == 1
+        len(model.input_tensors()) == 1
     ), f"Can only train replacements with a single input tensor right now, \
-        found {teacher.input_tensors()}"
+        found {model.input_tensors()}"
 
     assert (
-        len(teacher.output_tensors()) == 1
+        len(model.output_tensors()) == 1
     ), f"Can only train replacements with a single output tensor right now, \
-        found {teacher.output_tensors()}"
+        found {model.output_tensors()}"
+
+    input_name = model.input_tensors()[0]
+    output_name = model.output_tensors()[0]
+    return (input_name, output_name)
 
-    input_name = teacher.input_tensors()[0]
-    output_name = teacher.output_tensors()[0]
 
+def _check_model_compatibility(teacher: TFLiteModel, replace: TFLiteModel) -> None:
+    """Assert that teacher and replaced sub-graph are compatible."""
     assert len(teacher.shape_from_name) == len(
         replace.shape_from_name
     ), f"Baseline and train models must have the same number of inputs and outputs. \
@@ -230,10 +215,37 @@ def train_in_dir(
         subgraph being replaced. \
         Teacher: {teacher.shape_from_name}\n \
         Train dir: {replace.shape_from_name}"
 
+
+def set_up_data_pipeline(
+    teacher: TFLiteModel,
+    replace: TFLiteModel,
+    train_dir: str,
+    augmentations: tuple[float | None, float | None],
+    steps: int,
+    batch_size: int = 32,
+) -> tf.data.Dataset:
+    """Create a data pipeline for training of the replacement model."""
+    _check_model_compatibility(teacher, replace)
+
+    input_name, output_name = _get_io_tensors(teacher)
+
     input_filename = Path(train_dir, "input.tfrec")
     total = numpytf_count(str(input_filename))
     dict_inputs = numpytf_read(str(input_filename))
+
     inputs = dict_inputs.map(lambda d: tf.squeeze(d[input_name], axis=0))
+
+    steps_per_epoch = math.ceil(total / batch_size)
+    epochs = int(math.ceil(steps / steps_per_epoch))
+    logger.info(
+        "Training on %d items for %d steps (%d epochs with batch size %d)",
+        total,
+        epochs * steps_per_epoch,
+        epochs,
+        batch_size,
+    )
+
+    teacher_dir = Path(teacher.model_path).parent
     if any(augmentations):
         # Map the teacher inputs here because the augmentation stage passes these
         # through a TFLite model to get the outputs
@@ -245,17 +257,6 @@ def set_up_data_pipeline(
             lambda d: tf.squeeze(d[output_name], axis=0)
         )
 
-    steps_per_epoch = math.ceil(total / batch_size)
-    epochs = int(math.ceil(steps / steps_per_epoch))
-    if verbose:
-        logger.info(
-            "Training on %d items for %d steps (%d epochs with batch size %d)",
-            total,
-            epochs * steps_per_epoch,
-            epochs,
-            batch_size,
-        )
-
     dataset = tf.data.Dataset.zip((inputs, teacher_outputs))
     if epochs > 1:
         dataset = dataset.cache()
@@ -268,10 +269,9 @@ def set_up_data_pipeline(
         train: Any, teach: Any  # pylint: disable=redefined-outer-name
     ) -> tuple:
         """Return results of train and teach based on augmentations."""
-        return (
-            augment_train({input_name: train})[input_name],
-            teacher(augment_teacher({input_name: teach}))[output_name],
-        )
+        augmented_train = augment_train({input_name: train})[input_name]
+        augmented_teach = teacher(augment_teacher({input_name: teach}))[output_name]
+        return (augmented_train, augmented_teach)
 
     dataset = dataset.map(
         lambda augment_train, augment_teach: tf.py_function(
@@ -281,18 +281,67 @@ def set_up_data_pipeline(
         )
     )
 
+    # Restore data shapes of the dataset as they are set to unknown per default
+    # and get lost during augmentation with tf.py_function.
+    shape_in, shape_out = (
+        teacher.shape_from_name[name].tolist() for name in (input_name, output_name)
+    )
+    for shape in (shape_in, shape_out):
+        shape[0] = None  # set dynamic batch size
+
+    def restore_shapes(input_: Any, output: Any) -> tuple[Any, Any]:
+        input_.set_shape(shape_in)
+        output.set_shape(shape_out)
+        return input_, output
+
+    dataset = dataset.map(restore_shapes)
     dataset = dataset.prefetch(tf.data.AUTOTUNE)
+    return dataset
+
+
+def train_in_dir(
+    train_dir: str,
+    baseline_dir: Any,
+    output_filename: Path,
+    replace_fn: Callable,
+    train_params: TrainingParameters = TrainingParameters(),
+) -> list[str]:
+    """Train a replacement for replace.tflite using the input.tfrec \
+    and output.tfrec in train_dir.
+
+    If baseline_dir is provided, train the replacement to match baseline
+    outputs for train_dir inputs. Result saved as new.tflite in train_dir.
+    """
+    teacher_dir = baseline_dir if baseline_dir else train_dir
+    teacher = ParallelTFLiteModel(
+        f"{teacher_dir}/replace.tflite",
+        train_params.num_procs,
+        train_params.num_threads,
+        batch_size=train_params.batch_size,
+    )
+    replace = TFLiteModel(f"{train_dir}/replace.tflite")
+
+    input_name, output_name = _get_io_tensors(teacher)
+
+    dataset = set_up_data_pipeline(
+        teacher,
+        replace,
+        train_dir,
+        augmentations=train_params.augmentations,
+        steps=train_params.steps,
+        batch_size=train_params.batch_size,
+    )
 
     input_shape = teacher.shape_from_name[input_name][1:]
     output_shape = teacher.shape_from_name[output_name][1:]
+
     model = replace_fn(input_shape, output_shape)
-    optimizer = tf.keras.optimizers.Nadam(learning_rate=learning_rate)
+
+    optimizer = tf.keras.optimizers.Nadam(learning_rate=train_params.learning_rate)
     loss_fn = tf.keras.losses.MeanSquaredError()
-    model.compile(optimizer=optimizer, loss=loss_fn)
+    model.compile(optimizer=optimizer, loss=loss_fn, metrics=["mae"])
 
-    if verbose:
-        model.summary()
+    logger.info(model.summary())
 
     steps_so_far = 0
 
@@ -302,7 +351,9 @@ def train_in_dir(
         """Cosine decay from learning rate at start of the run to zero at the end."""
         current_step = epoch_step + steps_so_far
         cd_learning_rate = (
-            learning_rate * (math.cos(math.pi * current_step / steps) + 1) / 2.0
+            train_params.learning_rate
+            * (math.cos(math.pi * current_step / train_params.steps) + 1)
+            / 2.0
         )
         tf.keras.backend.set_value(optimizer.learning_rate, cd_learning_rate)
 
@@ -311,28 +362,29 @@ def train_in_dir(
     ) -> None:
         """Constant until the last 20% of the run, then linear decay to zero."""
         current_step = epoch_step + steps_so_far
-        steps_remaining = steps - current_step
-        decay_length = steps // 5
+        steps_remaining = train_params.steps - current_step
+        decay_length = train_params.steps // 5
         decay_fraction = min(steps_remaining, decay_length) / decay_length
-        ld_learning_rate = learning_rate * decay_fraction
+        ld_learning_rate = train_params.learning_rate * decay_fraction
         tf.keras.backend.set_value(optimizer.learning_rate, ld_learning_rate)
 
-    if schedule == "cosine":
+    assert train_params.learning_rate_schedule in LEARNING_RATE_SCHEDULES, (
+        f'Learning rate schedule "{train_params.learning_rate_schedule}" '
+        f"not implemented - expected one of {LEARNING_RATE_SCHEDULES}."
+    )
+    if train_params.learning_rate_schedule == "cosine":
         callbacks = [tf.keras.callbacks.LambdaCallback(on_batch_begin=cosine_decay)]
-    elif schedule == "late":
+    elif train_params.learning_rate_schedule == "late":
         callbacks = [tf.keras.callbacks.LambdaCallback(on_batch_begin=late_decay)]
-    elif schedule == "constant":
+    elif train_params.learning_rate_schedule == "constant":
         callbacks = []
-    else:
-        assert schedule not in LEARNING_RATE_SCHEDULES
-        raise ValueError(
-            f'Learning rate schedule "{schedule}" not implemented - '
-            f"expected one of {LEARNING_RATE_SCHEDULES}."
-        )
 
     output_filenames = []
-    checkpoints = (checkpoint_at if checkpoint_at else []) + [steps]
-    while steps_so_far < steps:
+    checkpoints = (train_params.checkpoint_at if train_params.checkpoint_at else []) + [
+        train_params.steps
+    ]
+
+    while steps_so_far < train_params.steps:
         steps_to_train = checkpoints.pop(0) - steps_so_far
         lr_start = optimizer.learning_rate.numpy()
         model.fit(
@@ -340,7 +392,7 @@ def train_in_dir(
             epochs=1,
             steps_per_epoch=steps_to_train,
             callbacks=callbacks,
-            verbose=show_progress,
+            verbose=train_params.show_progress,
         )
         steps_so_far += steps_to_train
         logger.info(
@@ -350,12 +402,14 @@ def train_in_dir(
             steps_to_train,
         )
 
-        if steps_so_far < steps:
+        if steps_so_far < train_params.steps:
             filename, ext = Path(output_filename).parts[1:]
             checkpoint_filename = filename + (f"_@{steps_so_far}") + ext
         else:
             checkpoint_filename = str(output_filename)
-        with log_action(f"{steps_so_far}/{steps}: Saved as {checkpoint_filename}"):
+        with log_action(
+            f"{steps_so_far}/{train_params.steps}: Saved as {checkpoint_filename}"
+        ):
             save_as_tflite(
                 model,
                 checkpoint_filename,
@@ -379,14 +433,30 @@ def save_as_tflite(
     output_shape: list,
 ) -> None:
     """Save Keras model as TFLite file."""
-    converter = tf.lite.TFLiteConverter.from_keras_model(keras_model)
+
+    @contextmanager
+    def fixed_input(keras_model: tf.keras.Model, tmp_shape: list) -> GeneratorType:
+        """Fix the input shape of the Keras model temporarily.
+
+        This avoids artifacts during conversion to TensorFlow Lite.
+        """
+        orig_shape = keras_model.input.shape
+        keras_model.input.set_shape(tf.TensorShape(tmp_shape))
+        try:
+            yield keras_model
+        finally:
+            # Restore original shape to not interfere with further training
+            keras_model.input.set_shape(orig_shape)
+
+    with fixed_input(keras_model, input_shape) as fixed_model:
+        converter = tf.lite.TFLiteConverter.from_keras_model(fixed_model)
     tflite_model = converter.convert()
 
     with open(filename, "wb") as file:
         file.write(tflite_model)
 
     # Now fix the shapes and names to match those we expect
-    flatbuffer = load(filename)
+    flatbuffer = load_fb(filename)
     i = flatbuffer.subgraphs[0].inputs[0]
     flatbuffer.subgraphs[0].tensors[i].shape = np.array(input_shape, dtype=np.int32)
     flatbuffer.subgraphs[0].tensors[i].name = input_name.encode("utf-8")
@@ -395,11 +465,11 @@ def save_as_tflite(
         output_shape, dtype=np.int32
     )
     flatbuffer.subgraphs[0].tensors[output].name = output_name.encode("utf-8")
-    save(flatbuffer, filename)
+    save_fb(flatbuffer, filename)
 
 
 def augment_fn_twins(
-    inputs: dict, augmentations: tuple[float | None, float | None]
+    inputs: tf.data.Dataset, augmentations: tuple[float | None, float | None]
 ) -> Any:
     """Return a pair of twinned augmentation functions with the same sequence \
     of random numbers."""
@@ -415,6 +485,11 @@ def augment_fn(
     inputs: Any, augmentations: tuple[float | None, float | None], rng: Generator
 ) -> Any:
     """Augmentation module."""
+    assert len(augmentations) == 2, (
+        f"Unexpected number of augmentation parameters: {len(augmentations)} "
+        "(must be 2)"
+    )
+
     mixup_strength, gaussian_strength = augmentations
 
     augments = []
@@ -449,17 +524,16 @@ def augment_fn(
 
         augments.append(gaussian_strength_augment)
 
-    if len(augments) == 0:  # pylint: disable=no-else-return
+    if len(augments) == 0:
         return lambda x: x
-    elif len(augments) == 1:
+    if len(augments) == 1:
         return augments[0]
-    elif len(augments) == 2:
+    if len(augments) == 2:
         return lambda x: augments[1](augments[0](x))
-    else:
-        assert (
-            False
-        ), f"Unexpected number of augmentation \
-            functions ({len(augments)})"
+
+    raise RuntimeError(
+        "Unexpected number of augmentation functions (must be <=2): "
+        f"{len(augments)}"
+    )
 
 
 def mixup(rng: Generator, batch: Any, beta_range: tuple = (0.0, 1.0)) -> Any:
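Calling train() directly follows the same pattern: the ten loose keyword arguments are gone, and anything not set explicitly falls back to the TrainingParameters defaults above (48000 steps, batch size 32, gaussian augmentation, cosine schedule). A sketch with placeholder paths, reusing the library's fully-connected replacement as replace_fn and the tensor names from this patch's tests:

    from mlia.nn.rewrite.core.train import train
    from mlia.nn.rewrite.core.train import TrainingParameters
    from mlia.nn.rewrite.library.fc_layer import get_keras_model

    results = train(
        source_model="model.tflite",
        unmodified_model=None,
        output_model="output.tflite",
        input_tfrec="input.tfrec",
        replace_fn=get_keras_model,  # called as replace_fn(input_shape, output_shape)
        input_tensors=["sequential/flatten/Reshape"],
        output_tensors=["StatefulPartitionedCall:0"],
        train_params=TrainingParameters(steps=64, learning_rate_schedule="late"),
    )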
diff --git a/src/mlia/nn/rewrite/core/utils/numpy_tfrecord.py b/src/mlia/nn/rewrite/core/utils/numpy_tfrecord.py
index 9229810..38ac1ed 100644
--- a/src/mlia/nn/rewrite/core/utils/numpy_tfrecord.py
+++ b/src/mlia/nn/rewrite/core/utils/numpy_tfrecord.py
@@ -6,55 +6,56 @@ from __future__ import annotations
 import json
 import os
 import random
-import tempfile
-from collections import defaultdict
+from functools import lru_cache
 from pathlib import Path
 from typing import Any
 from typing import Callable
 
-import numpy as np
 import tensorflow as tf
-from tensorflow.lite.python import interpreter as interpreter_wrapper
 
-from mlia.nn.rewrite.core.utils.utils import load
-from mlia.nn.rewrite.core.utils.utils import save
 
 os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
 tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
 
 
-def make_decode_fn(filename: str) -> Callable:
-    """Make decode filename."""
+def decode_fn(record_bytes: Any, type_map: dict) -> dict:
+    """Decode the given bytes into a name-tensor dict assuming the given type."""
+    parse_dict = {
+        name: tf.io.FixedLenFeature([], tf.string) for name in type_map.keys()
+    }
+    example = tf.io.parse_single_example(record_bytes, parse_dict)
+    features = {
+        n: tf.io.parse_tensor(example[n], tf.as_dtype(t)) for n, t in type_map.items()
+    }
+    return features
 
-    def decode_fn(record_bytes: Any, type_map: dict) -> dict:
-        parse_dict = {
-            name: tf.io.FixedLenFeature([], tf.string) for name in type_map.keys()
-        }
-        example = tf.io.parse_single_example(record_bytes, parse_dict)
-        features = {
-            n: tf.io.parse_tensor(example[n], tf.as_dtype(t))
-            for n, t in type_map.items()
-        }
-        return features
 
+def make_decode_fn(filename: str, model_filename: str | Path | None = None) -> Callable:
+    """Make decode filename."""
     meta_filename = filename + ".meta"
-    with open(meta_filename, encoding="utf-8") as file:
-        type_map = json.load(file)["type_map"]
+    try:
+        with open(meta_filename, encoding="utf-8") as file:
+            type_map = json.load(file)["type_map"]
+    except FileNotFoundError:
+        raise
     return lambda record_bytes: decode_fn(record_bytes, type_map)
 
 
 def numpytf_read(filename: str | Path) -> Any:
     """Read TFRecord dataset."""
-    decode_fn = make_decode_fn(str(filename))
+    decode = make_decode_fn(str(filename))
     dataset = tf.data.TFRecordDataset(str(filename))
-    return dataset.map(decode_fn)
+    return dataset.map(decode)
 
 
-def numpytf_count(filename: str | Path) -> Any:
+@lru_cache
+def numpytf_count(filename: str | Path) -> int:
     """Return count from TFRecord file."""
     meta_filename = f"{filename}.meta"
-    with open(meta_filename, encoding="utf-8") as file:
-        return json.load(file)["count"]
+    try:
+        with open(meta_filename, encoding="utf-8") as file:
+            return int(json.load(file)["count"])
+    except FileNotFoundError:
+        raw_dataset = tf.data.TFRecordDataset(filename)
+        return sum(1 for _ in raw_dataset)
 
 
 class NumpyTFWriter:
@@ -101,93 +102,6 @@ class NumpyTFWriter:
         self.writer.close()
 
 
-class TFLiteModel:
-    """A representation of a TFLite Model."""
-
-    def __init__(
-        self,
-        filename: str,
-        batch_size: int | None = None,
-        num_threads: int | None = None,
-    ) -> None:
-        """Initiate a TFLite Model."""
-        if not num_threads:
-            num_threads = None
-        if not batch_size:
-            self.interpreter = interpreter_wrapper.Interpreter(
-                model_path=filename, num_threads=num_threads
-            )
-        else:  # if a batch size is specified, modify the TFLite model to use this size
-            with tempfile.TemporaryDirectory() as tmp:
-                flatbuffer = load(filename)
-                for subgraph in flatbuffer.subgraphs:
-                    for tensor in list(subgraph.inputs) + list(subgraph.outputs):
-                        subgraph.tensors[tensor].shape = np.array(
-                            [batch_size] + list(subgraph.tensors[tensor].shape[1:]),
-                            dtype=np.int32,
-                        )
-                tempname = os.path.join(tmp, "rewrite_tmp.tflite")
-                save(flatbuffer, tempname)
-                self.interpreter = interpreter_wrapper.Interpreter(
-                    model_path=tempname, num_threads=num_threads
-                )
-
-        try:
-            self.interpreter.allocate_tensors()
-        except RuntimeError:
-            self.interpreter = interpreter_wrapper.Interpreter(
-                model_path=filename, num_threads=num_threads
-            )
-            self.interpreter.allocate_tensors()
-
-        # Get input and output tensors.
-        self.input_details = self.interpreter.get_input_details()
-        self.output_details = self.interpreter.get_output_details()
-        details = list(self.input_details) + list(self.output_details)
-        self.handle_from_name = {d["name"]: d["index"] for d in details}
-        self.shape_from_name = {d["name"]: d["shape"] for d in details}
-        self.batch_size = next(iter(self.shape_from_name.values()))[0]
-
-    def __call__(self, named_input: dict) -> dict:
-        """Execute the model on one or a batch of named inputs \
-        (a dict of name: numpy array)."""
-        input_len = next(iter(named_input.values())).shape[0]
-        full_steps = input_len // self.batch_size
-        remainder = input_len % self.batch_size
-
-        named_ys = defaultdict(list)
-        for i in range(full_steps):
-            for name, x_batch in named_input.items():
-                x_tensor = x_batch[i : i + self.batch_size]  # noqa: E203
-                self.interpreter.set_tensor(self.handle_from_name[name], x_tensor)
-            self.interpreter.invoke()
-            for output_detail in self.output_details:
-                named_ys[output_detail["name"]].append(
-                    self.interpreter.get_tensor(output_detail["index"])
-                )
-        if remainder:
-            for name, x_batch in named_input.items():
-                x_tensor = np.zeros(  # pylint: disable=invalid-name
-                    self.shape_from_name[name]
-                ).astype(x_batch.dtype)
-                x_tensor[:remainder] = x_batch[-remainder:]
-                self.interpreter.set_tensor(self.handle_from_name[name], x_tensor)
-            self.interpreter.invoke()
-            for output_detail in self.output_details:
-                named_ys[output_detail["name"]].append(
-                    self.interpreter.get_tensor(output_detail["index"])[:remainder]
-                )
-        return {k: np.concatenate(v) for k, v in named_ys.items()}
-
-    def input_tensors(self) -> list:
-        """Return name from input details."""
-        return [d["name"] for d in self.input_details]
-
-    def output_tensors(self) -> list:
-        """Return name from output details."""
-        return [d["name"] for d in self.output_details]
-
-
 def sample_tfrec(input_file: str, k: int, output_file: str) -> None:
     """Count, read and write TFRecord input and output data."""
     total = numpytf_count(input_file)
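For reference, the reader utilities are used like this; numpytf_count no longer requires the .meta side-car file (it falls back to scanning the records) and is cached with lru_cache, so repeated counts of the same file are free. A sketch with a placeholder path:

    from mlia.nn.rewrite.core.utils.numpy_tfrecord import numpytf_count
    from mlia.nn.rewrite.core.utils.numpy_tfrecord import numpytf_read

    dataset = numpytf_read("input.tfrec")  # tf.data.Dataset of name -> tensor dicts
    for record in dataset.take(1):
        print({name: tensor.shape for name, tensor in record.items()})

    print(numpytf_count("input.tfrec"))  # meta file if present, else a full scan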
diff --git a/src/mlia/nn/rewrite/core/utils/parallel.py b/src/mlia/nn/rewrite/core/utils/parallel.py
index d930a1e..b7b390d 100644
--- a/src/mlia/nn/rewrite/core/utils/parallel.py
+++ b/src/mlia/nn/rewrite/core/utils/parallel.py
@@ -15,14 +15,14 @@ from typing import Any
 import numpy as np
 import tensorflow as tf
 
-from mlia.nn.rewrite.core.utils.numpy_tfrecord import TFLiteModel
+from mlia.nn.tensorflow.config import TFLiteModel
 
 os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
 tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
 
 logger = logging.getLogger(__name__)
 
 
-class ParallelTFLiteModel(TFLiteModel):
+class ParallelTFLiteModel(TFLiteModel):  # pylint: disable=abstract-method
     """A parallel version of a TFLiteModel.
 
     num_procs: 0 => detect real cores on system
diff --git a/src/mlia/nn/rewrite/core/utils/utils.py b/src/mlia/nn/rewrite/core/utils/utils.py
deleted file mode 100644
index ddf0cc2..0000000
--- a/src/mlia/nn/rewrite/core/utils/utils.py
+++ /dev/null
@@ -1,32 +0,0 @@
-# SPDX-FileCopyrightText: Copyright 2023, Arm Limited and/or its affiliates.
-# SPDX-License-Identifier: Apache-2.0
-"""Model and file system utilites."""
-from __future__ import annotations
-
-from pathlib import Path
-
-import flatbuffers
-from tensorflow.lite.python.schema_py_generated import Model
-from tensorflow.lite.python.schema_py_generated import ModelT
-
-
-def load(input_tflite_file: str | Path) -> ModelT:
-    """Load a flatbuffer model from file."""
-    if not Path(input_tflite_file).exists():
-        raise FileNotFoundError(f"TFLite file not found at {input_tflite_file}\n")
-    with open(input_tflite_file, "rb") as file_handle:
-        file_data = bytearray(file_handle.read())
-    model_obj = Model.GetRootAsModel(file_data, 0)
-    model = ModelT.InitFromObj(model_obj)
-    return model
-
-
-def save(model: ModelT, output_tflite_file: str | Path) -> None:
-    """Save a flatbuffer model to a given file."""
-    builder = flatbuffers.Builder(1024)  # Initial size of the buffer, which
-    # will grow automatically if needed
-    model_offset = model.Pack(builder)
-    builder.Finish(model_offset, file_identifier=b"TFL3")
-    model_data = builder.Output()
-    with open(output_tflite_file, "wb") as out_file:
-        out_file.write(model_data)
diff --git a/src/mlia/nn/rewrite/library/fc_layer.py b/src/mlia/nn/rewrite/library/fc_layer.py
index 8704154..2480500 100644
--- a/src/mlia/nn/rewrite/library/fc_layer.py
+++ b/src/mlia/nn/rewrite/library/fc_layer.py
@@ -7,12 +7,12 @@ import tensorflow as tf
 
 
 def get_keras_model(input_shape: Any, output_shape: Any) -> tf.keras.Model:
-    """Generate tflite model for rewrite."""
-    input_tensor = tf.keras.layers.Input(
-        shape=input_shape, name="MbileNet/avg_pool/AvgPool"
+    """Generate TensorFlow Lite model for rewrite."""
+    model = tf.keras.Sequential(
+        (
+            tf.keras.layers.InputLayer(input_shape=input_shape),
+            tf.keras.layers.Reshape([-1]),
+            tf.keras.layers.Dense(output_shape),
+        )
     )
-    output_tensor = tf.keras.layers.Dense(output_shape, name="MobileNet/fc1/BiasAdd")(
-        input_tensor
-    )
-    model = tf.keras.Model(input_tensor, output_tensor)
     return model
diff --git a/src/mlia/nn/select.py b/src/mlia/nn/select.py
index 5a7f289..983426b 100644
--- a/src/mlia/nn/select.py
+++ b/src/mlia/nn/select.py
@@ -17,6 +17,7 @@ from mlia.nn.common import Optimizer
 from mlia.nn.common import OptimizerConfiguration
 from mlia.nn.rewrite.core.rewrite import RewriteConfiguration
 from mlia.nn.rewrite.core.rewrite import Rewriter
+from mlia.nn.rewrite.core.rewrite import TrainingParameters
 from mlia.nn.tensorflow.config import KerasModel
 from mlia.nn.tensorflow.config import TFLiteModel
 from mlia.nn.tensorflow.optimizations.clustering import Clusterer
@@ -164,6 +165,15 @@ def _get_optimizer(
     return MultiStageOptimizer(model, optimizer_configs)
 
 
+def _get_rewrite_train_params() -> TrainingParameters:
+    """Get the rewrite TrainingParameters.
+
+    Return the default constructed TrainingParameters() per default, but can be
+    overwritten in the unit tests.
+    """
+    return TrainingParameters()
+
+
 def _get_optimizer_configuration(
     optimization_type: str,
     optimization_target: int | float | str,
@@ -190,7 +200,10 @@ def _get_optimizer_configuration(
     if opt_type == "rewrite":
         if isinstance(optimization_target, str):
             return RewriteConfiguration(
-                str(optimization_target), layers_to_optimize, dataset
+                optimization_target=str(optimization_target),
+                layers_to_optimize=layers_to_optimize,
+                dataset=dataset,
+                train_params=_get_rewrite_train_params(),
             )
 
     raise ConfigurationError(
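The reworked fc_layer replacement no longer hard-codes MobileNet tensor names, so it can stand in for any single-input, single-output sub-graph. A usage sketch with made-up shapes (the Dense size is the flattened output dimension):

    from mlia.nn.rewrite.library.fc_layer import get_keras_model

    # Hypothetical shapes: a (1, 1, 1024) feature map mapped to 1000 logits.
    model = get_keras_model(input_shape=(1, 1, 1024), output_shape=1000)
    model.summary()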
+ """ + return TrainingParameters() + + def _get_optimizer_configuration( optimization_type: str, optimization_target: int | float | str, @@ -190,7 +200,10 @@ def _get_optimizer_configuration( if opt_type == "rewrite": if isinstance(optimization_target, str): return RewriteConfiguration( - str(optimization_target), layers_to_optimize, dataset + optimization_target=str(optimization_target), + layers_to_optimize=layers_to_optimize, + dataset=dataset, + train_params=_get_rewrite_train_params(), ) raise ConfigurationError( diff --git a/src/mlia/nn/tensorflow/config.py b/src/mlia/nn/tensorflow/config.py index d7d430f..c6a7c88 100644 --- a/src/mlia/nn/tensorflow/config.py +++ b/src/mlia/nn/tensorflow/config.py @@ -4,13 +4,16 @@ from __future__ import annotations import logging +import tempfile +from collections import defaultdict from pathlib import Path -from typing import cast -from typing import List +import numpy as np import tensorflow as tf from mlia.core.context import Context +from mlia.nn.tensorflow.tflite_graph import load_fb +from mlia.nn.tensorflow.tflite_graph import save_fb from mlia.nn.tensorflow.utils import convert_to_tflite from mlia.nn.tensorflow.utils import is_keras_model from mlia.nn.tensorflow.utils import is_saved_model @@ -71,10 +74,89 @@ class KerasModel(ModelConfiguration): class TFLiteModel(ModelConfiguration): # pylint: disable=abstract-method """TensorFlow Lite model configuration.""" - def input_details(self) -> list[dict]: - """Get model's input details.""" - interpreter = tf.lite.Interpreter(model_path=self.model_path) - return cast(List[dict], interpreter.get_input_details()) + def __init__( + self, + model_path: str | Path, + batch_size: int | None = None, + num_threads: int | None = None, + ) -> None: + """Initiate a TFLite Model.""" + super().__init__(model_path) + if not num_threads: + num_threads = None + if not batch_size: + self.interpreter = tf.lite.Interpreter( + model_path=self.model_path, num_threads=num_threads + ) + else: # if a batch size is specified, modify the TFLite model to use this size + with tempfile.TemporaryDirectory() as tmp: + flatbuffer = load_fb(self.model_path) + for subgraph in flatbuffer.subgraphs: + for tensor in list(subgraph.inputs) + list(subgraph.outputs): + subgraph.tensors[tensor].shape = np.array( + [batch_size] + list(subgraph.tensors[tensor].shape[1:]), + dtype=np.int32, + ) + tempname = Path(tmp, "rewrite_tmp.tflite") + save_fb(flatbuffer, tempname) + self.interpreter = tf.lite.Interpreter( + model_path=str(tempname), num_threads=num_threads + ) + + try: + self.interpreter.allocate_tensors() + except RuntimeError: + self.interpreter = tf.lite.Interpreter( + model_path=self.model_path, num_threads=num_threads + ) + self.interpreter.allocate_tensors() + + # Get input and output tensors. 
diff --git a/src/mlia/nn/tensorflow/tflite_graph.py b/src/mlia/nn/tensorflow/tflite_graph.py
index 4f5e85f..7ca9337 100644
--- a/src/mlia/nn/tensorflow/tflite_graph.py
+++ b/src/mlia/nn/tensorflow/tflite_graph.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates.
 # SPDX-License-Identifier: Apache-2.0
 """Utilities for TensorFlow Lite graphs."""
 from __future__ import annotations
@@ -10,7 +10,10 @@ from pathlib import Path
 from typing import Any
 from typing import cast
 
+import flatbuffers
 from tensorflow.lite.python import schema_py_generated as schema_fb
+from tensorflow.lite.python.schema_py_generated import Model
+from tensorflow.lite.python.schema_py_generated import ModelT
 from tensorflow.lite.tools import visualize
 
 
@@ -137,3 +140,25 @@ def parse_subgraphs(tflite_file: Path) -> list[list[Op]]:
     ]
 
     return graphs
+
+
+def load_fb(input_tflite_file: str | Path) -> ModelT:
+    """Load a flatbuffer model from file."""
+    if not Path(input_tflite_file).exists():
+        raise FileNotFoundError(f"TFLite file not found at {input_tflite_file}\n")
+    with open(input_tflite_file, "rb") as file_handle:
+        file_data = bytearray(file_handle.read())
+    model_obj = Model.GetRootAsModel(file_data, 0)
+    model = ModelT.InitFromObj(model_obj)
+    return model
+
+
+def save_fb(model: ModelT, output_tflite_file: str | Path) -> None:
+    """Save a flatbuffer model to a given file."""
+    builder = flatbuffers.Builder(1024)  # Initial size of the buffer, which
+    # will grow automatically if needed
+    model_offset = model.Pack(builder)
+    builder.Finish(model_offset, file_identifier=b"TFL3")
+    model_data = builder.Output()
+    with open(output_tflite_file, "wb") as out_file:
+        out_file.write(model_data)
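load_fb/save_fb round-trip a .tflite file through the generated flatbuffer object tree (ModelT), which is what lets the graph editors mutate tensor names and shapes in place. A sketch of the round trip, mirroring the test added at the bottom of this patch (paths are placeholders):

    from mlia.nn.tensorflow.tflite_graph import load_fb
    from mlia.nn.tensorflow.tflite_graph import save_fb

    model = load_fb("model.tflite")  # ModelT object tree
    print(model.subgraphs[0].tensors[0].name)  # tensor fields are plain attributes
    save_fb(model, "model_copy.tflite")  # re-packed with the "TFL3" identifier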
diff --git a/src/mlia/target/ethos_u/data_collection.py b/src/mlia/target/ethos_u/data_collection.py
index ba8b0fe..4ea6120 100644
--- a/src/mlia/target/ethos_u/data_collection.py
+++ b/src/mlia/target/ethos_u/data_collection.py
@@ -106,15 +106,14 @@ class OptimizeModel:
         self.context = context
         self.opt_settings = opt_settings
 
-    def __call__(self, keras_model: KerasModel) -> Any:
+    def __call__(self, model: KerasModel | TFLiteModel) -> Any:
         """Run optimization."""
-        optimizer = get_optimizer(keras_model, self.opt_settings)
+        optimizer = get_optimizer(model, self.opt_settings)
 
         opts_as_str = ", ".join(str(opt) for opt in self.opt_settings)
         logger.info("Applying model optimizations - [%s]", opts_as_str)
         optimizer.apply_optimization()
-
-        model = optimizer.get_model()
+        model = optimizer.get_model()  # type: ignore
 
         if isinstance(model, Path):
             return model
@@ -178,6 +177,7 @@ class EthosUOptimizationPerformance(ContextAwareDataCollector):
             self.target,
             self.backends,
         )
+
         original_metrics, *optimized_metrics = estimate_performance(
             model, estimator, optimizers  # type: ignore
         )
diff --git a/tests/conftest.py b/tests/conftest.py
index c42b8cb..bb2423f 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -5,6 +5,7 @@ import shutil
 from pathlib import Path
 from typing import Callable
 from typing import Generator
+from unittest.mock import MagicMock
 
 import numpy as np
 import pytest
@@ -17,6 +18,7 @@ from mlia.nn.tensorflow.utils import convert_to_tflite
 from mlia.nn.tensorflow.utils import save_keras_model
 from mlia.nn.tensorflow.utils import save_tflite_model
 from mlia.target.ethos_u.config import EthosUConfiguration
+from tests.utils.rewrite import TestTrainingParameters
 
 
 @pytest.fixture(scope="session", name="test_resources_path")
@@ -168,16 +170,12 @@ def _write_tfrecord(
         writer.write({input_name: data_generator()})
 
 
-@pytest.fixture(scope="session", name="test_tfrecord")
-def fixture_test_tfrecord(
-    tmp_path_factory: pytest.TempPathFactory,
+def create_tfrecord(
+    tmp_path_factory: pytest.TempPathFactory, random_data: Callable
 ) -> Generator[Path, None, None]:
     """Create a tfrecord with random data matching fixture 'test_tflite_model'."""
     tmp_path = tmp_path_factory.mktemp("tfrecords")
-    tfrecord_file = tmp_path / "test_int8.tfrecord"
-
-    def random_data() -> np.ndarray:
-        return np.random.randint(low=-127, high=128, size=(1, 28, 28, 1), dtype=np.int8)
+    tfrecord_file = tmp_path / "test.tfrecord"
 
     _write_tfrecord(tfrecord_file, random_data)
 
@@ -186,19 +184,36 @@ def fixture_test_tfrecord(
     shutil.rmtree(tmp_path)
 
 
+@pytest.fixture(scope="session", name="test_tfrecord")
+def fixture_test_tfrecord(
+    tmp_path_factory: pytest.TempPathFactory,
+) -> Generator[Path, None, None]:
+    """Create a tfrecord with random data matching fixture 'test_tflite_model'."""
+
+    def random_data() -> np.ndarray:
+        return np.random.randint(low=-127, high=128, size=(1, 28, 28, 1), dtype=np.int8)
+
+    yield from create_tfrecord(tmp_path_factory, random_data)
+
+
 @pytest.fixture(scope="session", name="test_tfrecord_fp32")
 def fixture_test_tfrecord_fp32(
     tmp_path_factory: pytest.TempPathFactory,
 ) -> Generator[Path, None, None]:
     """Create tfrecord with random data matching fixture 'test_tflite_model_fp32'."""
-    tmp_path = tmp_path_factory.mktemp("tfrecords")
-    tfrecord_file = tmp_path / "test_fp32.tfrecord"
 
     def random_data() -> np.ndarray:
         return np.random.rand(1, 28, 28, 1).astype(np.float32)
 
-    _write_tfrecord(tfrecord_file, random_data)
+    yield from create_tfrecord(tmp_path_factory, random_data)
 
-    yield tfrecord_file
 
-    shutil.rmtree(tmp_path)
+@pytest.fixture(scope="session", autouse=True)
+def set_training_steps() -> Generator[None, None, None]:
+    """Speed up tests by using TestTrainingParameters."""
+    with pytest.MonkeyPatch.context() as monkeypatch:
+        monkeypatch.setattr(
+            "mlia.nn.select._get_rewrite_train_params",
+            MagicMock(return_value=TestTrainingParameters()),
+        )
+        yield
diff --git a/tests/test_nn_rewrite_core_graph_edit_join.py b/tests/test_nn_rewrite_core_graph_edit_join.py
index cb3e4e2..0cb121e 100644
--- a/tests/test_nn_rewrite_core_graph_edit_join.py
+++ b/tests/test_nn_rewrite_core_graph_edit_join.py
@@ -10,7 +10,7 @@ import pytest
 from mlia.nn.rewrite.core.graph_edit.cut import cut_model
 from mlia.nn.rewrite.core.graph_edit.join import append_relabel
 from mlia.nn.rewrite.core.graph_edit.join import join_models
-from mlia.nn.rewrite.core.utils.utils import load
+from mlia.nn.tensorflow.tflite_graph import load_fb
 from tests.utils.rewrite import models_are_equal
 
 
@@ -49,8 +49,8 @@ def test_join_model(test_tflite_model: Path, tmp_path: Path) -> None:
     )
     assert joined_file.is_file()
 
-    orig_model = load(str(test_tflite_model))
-    joined_model = load(str(joined_file))
+    orig_model = load_fb(str(test_tflite_model))
+    joined_model = load_fb(str(joined_file))
     assert models_are_equal(orig_model, joined_model)
diff --git a/tests/test_nn_rewrite_core_graph_edit_record.py b/tests/test_nn_rewrite_core_graph_edit_record.py
index cd728af..41b9c50 100644
--- a/tests/test_nn_rewrite_core_graph_edit_record.py
+++ b/tests/test_nn_rewrite_core_graph_edit_record.py
@@ -3,15 +3,13 @@
 """Tests for module mlia.nn.rewrite.graph_edit.record."""
 from pathlib import Path
 
-import pytest
 import tensorflow as tf
 
 from mlia.nn.rewrite.core.graph_edit.record import record_model
 from mlia.nn.rewrite.core.utils.numpy_tfrecord import numpytf_read
 
 
-@pytest.mark.parametrize("batch_size", (None, 1, 2))
-def test_record_model(
+def check_record_model(
     test_tflite_model: Path,
     tmp_path: Path,
     test_tfrecord: Path,
diff --git a/tests/test_nn_rewrite_core_rewrite.py b/tests/test_nn_rewrite_core_rewrite.py
index b98971e..2542db2 100644
--- a/tests/test_nn_rewrite_core_rewrite.py
+++ b/tests/test_nn_rewrite_core_rewrite.py
@@ -12,6 +12,7 @@ import pytest
 from mlia.nn.rewrite.core.rewrite import RewriteConfiguration
 from mlia.nn.rewrite.core.rewrite import Rewriter
 from mlia.nn.tensorflow.config import TFLiteModel
+from tests.utils.rewrite import TestTrainingParameters
 
 
 @pytest.mark.parametrize(
@@ -32,12 +33,14 @@ def test_rewrite_configuration(
         None,
     )
 
+    assert config_obj.optimization_target in str(config_obj)
+
     rewriter_obj = Rewriter(test_tflite_model_fp32, config_obj)
     assert rewriter_obj.optimizer_configuration.optimization_target == rewrite_name
     assert isinstance(rewriter_obj, Rewriter)
 
 
-def test_rewriter(
+def test_rewriting_optimizer(
     test_tflite_model_fp32: Path,
     test_tfrecord_fp32: Path,
 ) -> None:
@@ -46,6 +49,7 @@ def test_rewriting_optimizer(
         "fully_connected",
         ["sequential/flatten/Reshape", "StatefulPartitionedCall:0"],
         test_tfrecord_fp32,
+        train_params=TestTrainingParameters(),
     )
 
     test_obj = Rewriter(test_tflite_model_fp32, config_obj)
diff --git a/tests/test_nn_rewrite_core_train.py b/tests/test_nn_rewrite_core_train.py
index 3c2ef3e..4493671 100644
--- a/tests/test_nn_rewrite_core_train.py
+++ b/tests/test_nn_rewrite_core_train.py
@@ -4,6 +4,7 @@
 # pylint: disable=too-many-arguments
 from __future__ import annotations
 
+from contextlib import ExitStack as does_not_raise
 from pathlib import Path
 from tempfile import TemporaryDirectory
 from typing import Any
@@ -12,10 +13,13 @@ import numpy as np
 import pytest
 import tensorflow as tf
 
-from mlia.nn.rewrite.core.train import augmentation_presets
+from mlia.nn.rewrite.core.train import augment_fn_twins
+from mlia.nn.rewrite.core.train import AUGMENTATION_PRESETS
 from mlia.nn.rewrite.core.train import LearningRateSchedule
 from mlia.nn.rewrite.core.train import mixup
 from mlia.nn.rewrite.core.train import train
+from mlia.nn.rewrite.core.train import TrainingParameters
+from tests.utils.rewrite import TestTrainingParameters
 
 
 def replace_fully_connected_with_conv(
@@ -41,15 +45,8 @@ def replace_fully_connected_with_conv(
 def check_train(
     tflite_model: Path,
     tfrecord: Path,
-    batch_size: int = 1,
-    verbose: bool = False,
-    show_progress: bool = False,
-    augmentation_preset: tuple[float | None, float | None] = augmentation_presets[
-        "none"
-    ],
-    lr_schedule: LearningRateSchedule = "cosine",
+    train_params: TrainingParameters = TestTrainingParameters(),
     use_unmodified_model: bool = False,
-    num_procs: int = 1,
 ) -> None:
     """Test the train() function."""
     with TemporaryDirectory() as tmp_dir:
@@ -62,14 +59,7 @@ def check_train(
             replace_fn=replace_fully_connected_with_conv,
             input_tensors=["sequential/flatten/Reshape"],
             output_tensors=["StatefulPartitionedCall:0"],
-            augment=augmentation_preset,
-            steps=32,
-            learning_rate=1e-3,
-            batch_size=batch_size,
-            verbose=verbose,
-            show_progress=show_progress,
-            learning_rate_schedule=lr_schedule,
-            num_procs=num_procs,
+            train_params=train_params,
         )
         assert len(result) == 2
         assert all(res >= 0.0 for res in result), f"Results out of bound: {result}"
@@ -79,7 +69,6 @@ def check_train(
 @pytest.mark.parametrize(
     (
         "batch_size",
-        "verbose",
         "show_progress",
         "augmentation_preset",
         "lr_schedule",
         "use_unmodified_model",
         "num_procs",
     ),
     (
-        (1, False, False, augmentation_presets["none"], "cosine", False, 2),
-        (32, True, True, augmentation_presets["gaussian"], "late", True, 1),
-        (2, False, False, augmentation_presets["mixup"], "constant", True, 0),
+        (1, False, AUGMENTATION_PRESETS["none"], "cosine", False, 2),
+        (32, True, AUGMENTATION_PRESETS["gaussian"], "late", True, 1),
+        (2, False, AUGMENTATION_PRESETS["mixup"], "constant", True, 0),
         (
             1,
             False,
-            False,
-            augmentation_presets["mix_gaussian_large"],
+            AUGMENTATION_PRESETS["mix_gaussian_large"],
             "cosine",
             False,
             2,
@@ -105,7 +93,6 @@ def test_train(
     test_tflite_model_fp32: Path,
     test_tfrecord_fp32: Path,
     batch_size: int,
-    verbose: bool,
     show_progress: bool,
     augmentation_preset: tuple[float | None, float | None],
     lr_schedule: LearningRateSchedule,
@@ -116,13 +103,14 @@ def test_train(
     check_train(
         tflite_model=test_tflite_model_fp32,
         tfrecord=test_tfrecord_fp32,
-        batch_size=batch_size,
-        verbose=verbose,
-        show_progress=show_progress,
-        augmentation_preset=augmentation_preset,
-        lr_schedule=lr_schedule,
+        train_params=TestTrainingParameters(
+            batch_size=batch_size,
+            show_progress=show_progress,
+            augmentations=augmentation_preset,
+            learning_rate_schedule=lr_schedule,
+            num_procs=num_procs,
+        ),
         use_unmodified_model=use_unmodified_model,
-        num_procs=num_procs,
     )
 
 
@@ -131,11 +119,13 @@ def test_train_invalid_schedule(
     test_tfrecord_fp32: Path,
 ) -> None:
     """Test the train() function with an invalid schedule."""
-    with pytest.raises(ValueError):
+    with pytest.raises(AssertionError):
         check_train(
             tflite_model=test_tflite_model_fp32,
             tfrecord=test_tfrecord_fp32,
-            lr_schedule="unknown_schedule",  # type: ignore
+            train_params=TestTrainingParameters(
+                learning_rate_schedule="unknown_schedule",
+            ),
         )
 
 
@@ -144,11 +134,13 @@ def test_train_invalid_augmentation(
     test_tfrecord_fp32: Path,
 ) -> None:
     """Test the train() function with an invalid augmentation."""
-    with pytest.raises(ValueError):
+    with pytest.raises(AssertionError):
         check_train(
             tflite_model=test_tflite_model_fp32,
             tfrecord=test_tfrecord_fp32,
-            augmentation_preset=(1.0, 2.0, 3.0),  # type: ignore
+            train_params=TestTrainingParameters(
+                augmentations=(1.0, 2.0, 3.0),
+            ),
         )
 
 
@@ -159,3 +151,19 @@ def test_mixup() -> None:
     assert src.shape == dst.shape
     assert np.all(dst >= 0.0)
     assert np.all(dst <= 3.0)
+
+
+@pytest.mark.parametrize(
+    "augmentations, expected_error",
+    [
+        (AUGMENTATION_PRESETS["none"], does_not_raise()),
+        (AUGMENTATION_PRESETS["mix_gaussian_large"], does_not_raise()),
+        ((None,) * 3, pytest.raises(AssertionError)),
+    ],
+)
+def test_augment_fn_twins(augmentations: tuple, expected_error: Any) -> None:
+    """Test function augment_fn()."""
+    dataset = tf.data.Dataset.from_tensor_slices({"a": [1, 2, 3], "b": [4, 5, 6]})
+    with expected_error:
+        fn_twins = augment_fn_twins(dataset, augmentations)  # type: ignore
+        assert len(fn_twins) == 2
diff --git a/tests/test_nn_rewrite_core_utils.py b/tests/test_nn_rewrite_core_utils.py
deleted file mode 100644
index d806a7b..0000000
--- a/tests/test_nn_rewrite_core_utils.py
+++ /dev/null
@@ -1,33 +0,0 @@
-# SPDX-FileCopyrightText: Copyright 2023, Arm Limited and/or its affiliates.
-# SPDX-License-Identifier: Apache-2.0
-"""Tests for module mlia.nn.rewrite.utils."""
-from pathlib import Path
-
-import pytest
-import tensorflow as tf
-from tensorflow.lite.python.schema_py_generated import ModelT
-
-from mlia.nn.rewrite.core.utils.utils import load
-from mlia.nn.rewrite.core.utils.utils import save
-from tests.utils.rewrite import models_are_equal
-
-
-def test_load_save(test_tflite_model: Path, tmp_path: Path) -> None:
-    """Test the load/save functions for TensorFlow Lite models."""
-    with pytest.raises(FileNotFoundError):
-        load("THIS_IS_NOT_A_REAL_FILE")
-
-    model = load(test_tflite_model)
-    assert isinstance(model, ModelT)
-    assert model.subgraphs
-
-    output_file = tmp_path / "test.tflite"
-    assert not output_file.is_file()
-    save(model, output_file)
-    assert output_file.is_file()
-
-    model_copy = load(str(output_file))
-    assert models_are_equal(model, model_copy)
-
-    # Double check that the TensorFlow Lite Interpreter can still load the file.
-    tf.lite.Interpreter(model_path=str(output_file))
diff --git a/tests/test_nn_rewrite_core_utils_numpy_tfrecord.py b/tests/test_nn_rewrite_core_utils_numpy_tfrecord.py
index 7fc8048..d030350 100644
--- a/tests/test_nn_rewrite_core_utils_numpy_tfrecord.py
+++ b/tests/test_nn_rewrite_core_utils_numpy_tfrecord.py
@@ -5,6 +5,10 @@ from __future__ import annotations
 
 from pathlib import Path
 
+import pytest
+import tensorflow as tf
+
+from mlia.nn.rewrite.core.utils.numpy_tfrecord import make_decode_fn
 from mlia.nn.rewrite.core.utils.numpy_tfrecord import numpytf_count
 from mlia.nn.rewrite.core.utils.numpy_tfrecord import sample_tfrec
 
@@ -16,3 +20,24 @@ def test_sample_tfrec(test_tfrecord: Path, tmp_path: Path) -> None:
     sample_tfrec(input_file=str(test_tfrecord), k=1, output_file=str(output_file))
     assert output_file.is_file()
     assert numpytf_count(str(output_file)) == 1
+
+
+def test_make_decode_fn(test_tfrecord: Path) -> None:
+    """Test function make_decode_fn()."""
+    decode = make_decode_fn(str(test_tfrecord))
+    dataset = tf.data.TFRecordDataset(str(test_tfrecord))
+    features = decode(next(iter(dataset)))
+    assert isinstance(features, dict)
+    assert len(features) == 1
+    key, val = next(iter(features.items()))
+    assert isinstance(key, str)
+    assert isinstance(val, tf.Tensor)
+    assert val.dtype == tf.int8
+
+    with pytest.raises(FileNotFoundError):
+        make_decode_fn(str(test_tfrecord) + "_")
+
+
+def test_numpytf_count(test_tfrecord: Path) -> None:
+    """Test function numpytf_count()."""
+    assert numpytf_count(test_tfrecord) == 3
diff --git a/tests/test_nn_tensorflow_config.py b/tests/test_nn_tensorflow_config.py
index 656619d..48aec0a 100644
--- a/tests/test_nn_tensorflow_config.py
+++ b/tests/test_nn_tensorflow_config.py
@@ -4,13 +4,28 @@
 from contextlib import ExitStack as does_not_raise
 from pathlib import Path
 from typing import Any
+from typing import Generator
 
+import numpy as np
 import pytest
 
+from mlia.nn.rewrite.core.utils.numpy_tfrecord import numpytf_read
 from mlia.nn.tensorflow.config import get_model
 from mlia.nn.tensorflow.config import KerasModel
+from mlia.nn.tensorflow.config import ModelConfiguration
 from mlia.nn.tensorflow.config import TFLiteModel
 from mlia.nn.tensorflow.config import TfModel
+from tests.conftest import create_tfrecord
+
+
+def test_model_configuration(test_keras_model: Path) -> None:
+    """Test ModelConfiguration class."""
+    model = ModelConfiguration(model_path=test_keras_model)
+    assert test_keras_model.match(model.model_path)
+
+    with pytest.raises(NotImplementedError):
+        model.convert_to_keras("keras_model.h5")
+    with pytest.raises(NotImplementedError):
+        model.convert_to_tflite("model.tflite")
 
 
 def test_convert_keras_to_tflite(tmp_path: Path, test_keras_model: Path) -> None:
@@ -38,7 +53,7 @@ def test_convert_tf_to_tflite(tmp_path: Path, test_tf_model: Path) -> None:
 @pytest.mark.parametrize(
     "model_path, expected_type, expected_error",
     [
-        ("test.tflite", TFLiteModel, does_not_raise()),
+        ("test.tflite", TFLiteModel, pytest.raises(ValueError)),
         ("test.h5", KerasModel, does_not_raise()),
         ("test.hdf5", KerasModel, does_not_raise()),
         (
@@ -73,3 +88,26 @@ def test_get_model_dir(
     """Test TensorFlow Lite model type."""
     model = get_model(str(test_models_path / model_path))
     assert isinstance(model, expected_type)
+
+
+@pytest.fixture(scope="session", name="test_tfrecord_fp32_batch_3")
+def fixture_test_tfrecord_fp32_batch_3(
+    tmp_path_factory: pytest.TempPathFactory,
+) -> Generator[Path, None, None]:
+    """Create tfrecord (same as test_tfrecord_fp32) but with batch size 3."""
+
+    def random_data() -> np.ndarray:
+        return np.random.rand(3, 28, 28, 1).astype(np.float32)
+
+    yield from create_tfrecord(tmp_path_factory, random_data)
+
+
+def test_tflite_model_call(
+    test_tflite_model_fp32: Path, test_tfrecord_fp32_batch_3: Path
+) -> None:
+    """Test inference function of class TFLiteModel."""
+    model = TFLiteModel(test_tflite_model_fp32, batch_size=2)
+    data = numpytf_read(test_tfrecord_fp32_batch_3)
+
+    for named_input in data.as_numpy_iterator():
+        res = model(named_input)
+        assert res
diff --git a/tests/test_nn_tensorflow_tflite_graph.py b/tests/test_nn_tensorflow_tflite_graph.py
index cd1fad6..3512cdd 100644
--- a/tests/test_nn_tensorflow_tflite_graph.py
+++ b/tests/test_nn_tensorflow_tflite_graph.py
@@ -1,15 +1,22 @@
-# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates.
 # SPDX-License-Identifier: Apache-2.0
 """Tests for the tflite_graph module."""
 import json
 from pathlib import Path
 
+import pytest
+import tensorflow as tf
+from tensorflow.lite.python.schema_py_generated import ModelT
+
+from mlia.nn.tensorflow.tflite_graph import load_fb
 from mlia.nn.tensorflow.tflite_graph import Op
 from mlia.nn.tensorflow.tflite_graph import parse_subgraphs
+from mlia.nn.tensorflow.tflite_graph import save_fb
 from mlia.nn.tensorflow.tflite_graph import TensorInfo
 from mlia.nn.tensorflow.tflite_graph import TFL_ACTIVATION_FUNCTION
 from mlia.nn.tensorflow.tflite_graph import TFL_OP
 from mlia.nn.tensorflow.tflite_graph import TFL_TYPE
+from tests.utils.rewrite import models_are_equal
 
 
 def test_tensor_info() -> None:
@@ -79,3 +86,24 @@ def test_parse_subgraphs(test_tflite_model: Path) -> None:
             assert TFL_OP[oper.type] in TFL_OP
             assert len(oper.inputs) > 0
             assert len(oper.outputs) > 0
+
+
+def test_load_save(test_tflite_model: Path, tmp_path: Path) -> None:
+    """Test the load/save functions for TensorFlow Lite models."""
+    with pytest.raises(FileNotFoundError):
+        load_fb("THIS_IS_NOT_A_REAL_FILE")
+
+    model = load_fb(test_tflite_model)
+    assert isinstance(model, ModelT)
+    assert model.subgraphs
+
+    output_file = tmp_path / "test.tflite"
+    assert not output_file.is_file()
+    save_fb(model, output_file)
+    assert output_file.is_file()
+
+    model_copy = load_fb(str(output_file))
+    assert models_are_equal(model, model_copy)
+
+    # Double check that the TensorFlow Lite Interpreter can still load the file.
+    tf.lite.Interpreter(model_path=str(output_file))
diff --git a/tests/utils/rewrite.py b/tests/utils/rewrite.py
index 4264b4b..739bb11 100644
--- a/tests/utils/rewrite.py
+++ b/tests/utils/rewrite.py
@@ -3,8 +3,12 @@
 """Common test utils for the rewrite tests."""
 from __future__ import annotations
 
+from typing import Any
+
 from tensorflow.lite.python.schema_py_generated import ModelT
 
+from mlia.nn.rewrite.core.train import TrainingParameters
+
 
 def models_are_equal(model1: ModelT, model2: ModelT) -> bool:
     """Check that the two models are equal."""
@@ -25,3 +29,17 @@ def models_are_equal(model1: ModelT, model2: ModelT) -> bool:
                 return False  # Tensor from graph1 not found in other graph.")
 
     return True
+
+
+class TestTrainingParameters(
+    TrainingParameters
+):  # pylint: disable=too-few-public-methods
+    """
+    TrainingParameter class for rewrites with different default values.
+
+    To speed things up for the unit tests.
+    """
+
+    def __init__(self, *args: Any, steps: int = 32, **kwargs: Any) -> None:
+        """Initialize TrainingParameters with different defaults."""
+        super().__init__(*args, steps=steps, **kwargs)  # type: ignore