diff options
Diffstat (limited to 'src/mlia/nn/rewrite')
-rw-r--r-- | src/mlia/nn/rewrite/core/graph_edit/cut.py | 8 | ||||
-rw-r--r-- | src/mlia/nn/rewrite/core/graph_edit/join.py | 10 | ||||
-rw-r--r-- | src/mlia/nn/rewrite/core/rewrite.py | 77 | ||||
-rw-r--r-- | src/mlia/nn/rewrite/core/train.py | 306 | ||||
-rw-r--r-- | src/mlia/nn/rewrite/core/utils/numpy_tfrecord.py | 138 | ||||
-rw-r--r-- | src/mlia/nn/rewrite/core/utils/parallel.py | 4 | ||||
-rw-r--r-- | src/mlia/nn/rewrite/core/utils/utils.py | 32 | ||||
-rw-r--r-- | src/mlia/nn/rewrite/library/fc_layer.py | 14 |
8 files changed, 267 insertions, 322 deletions
diff --git a/src/mlia/nn/rewrite/core/graph_edit/cut.py b/src/mlia/nn/rewrite/core/graph_edit/cut.py index 2707eb1..13a5268 100644 --- a/src/mlia/nn/rewrite/core/graph_edit/cut.py +++ b/src/mlia/nn/rewrite/core/graph_edit/cut.py @@ -9,8 +9,8 @@ import tensorflow as tf from tensorflow.lite.python.schema_py_generated import ModelT from tensorflow.lite.python.schema_py_generated import SubGraphT -from mlia.nn.rewrite.core.utils.utils import load -from mlia.nn.rewrite.core.utils.utils import save +from mlia.nn.tensorflow.tflite_graph import load_fb +from mlia.nn.tensorflow.tflite_graph import save_fb os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR) @@ -138,8 +138,8 @@ def cut_model( output_file: str, ) -> None: """Cut subgraphs and simplify a given model.""" - model = load(model_file) + model = load_fb(model_file) subgraph = model.subgraphs[subgraph_index] cut_subgraph(subgraph, input_names, output_names) simplify(model) - save(model, output_file) + save_fb(model, output_file) diff --git a/src/mlia/nn/rewrite/core/graph_edit/join.py b/src/mlia/nn/rewrite/core/graph_edit/join.py index 2530ec8..70109d8 100644 --- a/src/mlia/nn/rewrite/core/graph_edit/join.py +++ b/src/mlia/nn/rewrite/core/graph_edit/join.py @@ -11,8 +11,8 @@ from tensorflow.lite.python.schema_py_generated import ModelT from tensorflow.lite.python.schema_py_generated import OperatorCodeT from tensorflow.lite.python.schema_py_generated import SubGraphT -from mlia.nn.rewrite.core.utils.utils import load -from mlia.nn.rewrite.core.utils.utils import save +from mlia.nn.tensorflow.tflite_graph import load_fb +from mlia.nn.tensorflow.tflite_graph import save_fb os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR) @@ -26,12 +26,12 @@ def join_models( subgraph_dst: int = 0, ) -> None: """Join two models and save the result into a given model file path.""" - src_model = load(input_src) - dst_model = load(input_dst) + src_model = load_fb(input_src) + dst_model = load_fb(input_dst) src_subgraph = src_model.subgraphs[subgraph_src] dst_subgraph = dst_model.subgraphs[subgraph_dst] join_subgraphs(src_model, src_subgraph, dst_model, dst_subgraph) - save(dst_model, output_file) + save_fb(dst_model, output_file) def join_subgraphs( diff --git a/src/mlia/nn/rewrite/core/rewrite.py b/src/mlia/nn/rewrite/core/rewrite.py index 0d182df..6b27984 100644 --- a/src/mlia/nn/rewrite/core/rewrite.py +++ b/src/mlia/nn/rewrite/core/rewrite.py @@ -4,6 +4,7 @@ from __future__ import annotations import importlib +import logging import tempfile from dataclasses import dataclass from pathlib import Path @@ -12,13 +13,14 @@ from typing import Any from mlia.core.errors import ConfigurationError from mlia.nn.common import Optimizer from mlia.nn.common import OptimizerConfiguration -from mlia.nn.rewrite.core.train import eval_in_dir -from mlia.nn.rewrite.core.train import join_in_dir from mlia.nn.rewrite.core.train import train -from mlia.nn.rewrite.core.train import train_in_dir +from mlia.nn.rewrite.core.train import TrainingParameters from mlia.nn.tensorflow.config import TFLiteModel +logger = logging.getLogger(__name__) + + @dataclass class RewriteConfiguration(OptimizerConfiguration): """Rewrite configuration.""" @@ -26,6 +28,7 @@ class RewriteConfiguration(OptimizerConfiguration): optimization_target: str layers_to_optimize: list[str] | None = None dataset: Path | None = None + train_params: TrainingParameters = TrainingParameters() def __str__(self) -> str: """Return string representation of the configuration.""" @@ -40,8 +43,8 @@ class Rewriter(Optimizer): ): """Init Rewriter instance.""" self.model = TFLiteModel(tflite_model_path) + self.model_path = tflite_model_path self.optimizer_configuration = optimizer_configuration - self.train_dir = "" def apply_optimization(self) -> None: """Apply the rewrite flow.""" @@ -61,50 +64,36 @@ class Rewriter(Optimizer): replace_fn = get_function(replace_function) - augmentation_preset = (None, None) use_unmodified_model = True tflite_model = self.model.model_path tfrecord = str(self.optimizer_configuration.dataset) - with tempfile.TemporaryDirectory() as tmp_dir: - tmp_output = Path(tmp_dir, "output.tflite") - - if self.train_dir: - tmp_new = Path(tmp_dir, "new.tflite") - new_part = train_in_dir( - train_dir=self.train_dir, - baseline_dir=None, - output_filename=tmp_new, - replace_fn=replace_fn, - augmentations=augmentation_preset, - steps=32, - learning_rate=1e-3, - batch_size=1, - verbose=True, - show_progress=True, - ) - eval_in_dir(self.train_dir, new_part[0]) - join_in_dir(self.train_dir, new_part[0], str(tmp_output)) - else: - if not self.optimizer_configuration.layers_to_optimize: - raise ConfigurationError( - "Input and output tensor names need to be set for rewrite." - ) - train( - source_model=tflite_model, - unmodified_model=tflite_model if use_unmodified_model else None, - output_model=str(tmp_output), - input_tfrec=str(tfrecord), - replace_fn=replace_fn, - input_tensors=[self.optimizer_configuration.layers_to_optimize[0]], - output_tensors=[self.optimizer_configuration.layers_to_optimize[1]], - augment=augmentation_preset, - steps=32, - learning_rate=1e-3, - batch_size=1, - verbose=True, - show_progress=True, - ) + tmp_dir = tempfile.mkdtemp() + tmp_output = Path(tmp_dir, "output.tflite") + + if not self.optimizer_configuration.layers_to_optimize: + raise ConfigurationError( + "Input and output tensor names need to be set for rewrite." + ) + result = train( + source_model=tflite_model, + unmodified_model=tflite_model if use_unmodified_model else None, + output_model=str(tmp_output), + input_tfrec=str(tfrecord), + replace_fn=replace_fn, + input_tensors=[self.optimizer_configuration.layers_to_optimize[0]], + output_tensors=[self.optimizer_configuration.layers_to_optimize[1]], + train_params=self.optimizer_configuration.train_params, + ) + + self.model = TFLiteModel(tmp_output) + + if result: + stats_as_str = ", ".join(str(stats) for stats in result) + logger.info( + "The MAE and NRMSE between original and replacement [%s]", + stats_as_str, + ) def get_model(self) -> TFLiteModel: """Return optimized model.""" diff --git a/src/mlia/nn/rewrite/core/train.py b/src/mlia/nn/rewrite/core/train.py index c8497a4..42bf653 100644 --- a/src/mlia/nn/rewrite/core/train.py +++ b/src/mlia/nn/rewrite/core/train.py @@ -1,8 +1,8 @@ # SPDX-FileCopyrightText: Copyright 2023, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """Sequential trainer.""" -# pylint: disable=too-many-arguments, too-many-instance-attributes, -# pylint: disable=too-many-locals, too-many-branches, too-many-statements +# pylint: disable=too-many-locals +# pylint: disable=too-many-statements from __future__ import annotations import logging @@ -10,10 +10,13 @@ import math import os import tempfile from collections import defaultdict +from contextlib import contextmanager +from dataclasses import dataclass from pathlib import Path from typing import Any from typing import Callable from typing import cast +from typing import Generator as GeneratorType from typing import get_args from typing import Literal @@ -27,10 +30,10 @@ from mlia.nn.rewrite.core.graph_edit.join import join_models from mlia.nn.rewrite.core.graph_edit.record import record_model from mlia.nn.rewrite.core.utils.numpy_tfrecord import numpytf_count from mlia.nn.rewrite.core.utils.numpy_tfrecord import numpytf_read -from mlia.nn.rewrite.core.utils.numpy_tfrecord import TFLiteModel from mlia.nn.rewrite.core.utils.parallel import ParallelTFLiteModel -from mlia.nn.rewrite.core.utils.utils import load -from mlia.nn.rewrite.core.utils.utils import save +from mlia.nn.tensorflow.config import TFLiteModel +from mlia.nn.tensorflow.tflite_graph import load_fb +from mlia.nn.tensorflow.tflite_graph import save_fb from mlia.utils.logging import log_action @@ -38,7 +41,7 @@ os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR) logger = logging.getLogger(__name__) -augmentation_presets = { +AUGMENTATION_PRESETS = { "none": (None, None), "gaussian": (None, 1.0), "mixup": (1.0, None), @@ -51,6 +54,21 @@ LearningRateSchedule = Literal["cosine", "late", "constant"] LEARNING_RATE_SCHEDULES = get_args(LearningRateSchedule) +@dataclass +class TrainingParameters: + """Define default parameters for the training.""" + + augmentations: tuple[float | None, float | None] = AUGMENTATION_PRESETS["gaussian"] + batch_size: int = 32 + steps: int = 48000 + learning_rate: float = 1e-3 + learning_rate_schedule: LearningRateSchedule = "cosine" + num_procs: int = 1 + num_threads: int = 0 + show_progress: bool = True + checkpoint_at: list | None = None + + def train( source_model: str, unmodified_model: Any, @@ -59,16 +77,7 @@ def train( replace_fn: Callable, input_tensors: list, output_tensors: list, - augment: tuple[float | None, float | None], - steps: int, - learning_rate: float, - batch_size: int, - verbose: bool, - show_progress: bool, - learning_rate_schedule: LearningRateSchedule = "cosine", - checkpoint_at: list | None = None, - num_procs: int = 1, - num_threads: int = 0, + train_params: TrainingParameters = TrainingParameters(), ) -> Any: """Extract and train a model, and return the results.""" if unmodified_model: @@ -95,29 +104,27 @@ def train( input_tfrec, input_tensors, output_tensors, - num_procs=num_procs, - num_threads=num_threads, + num_procs=train_params.num_procs, + num_threads=train_params.num_threads, ) tflite_filenames = train_in_dir( - train_dir, - unmodified_model_dir_path, - Path(train_dir, "new.tflite"), - replace_fn, - augment, - steps, - learning_rate, - batch_size, - checkpoint_at=checkpoint_at, - verbose=verbose, - show_progress=show_progress, - num_procs=num_procs, - num_threads=num_threads, - schedule=learning_rate_schedule, + train_dir=train_dir, + baseline_dir=unmodified_model_dir_path, + output_filename=Path(train_dir, "new.tflite"), + replace_fn=replace_fn, + train_params=train_params, ) for i, filename in enumerate(tflite_filenames): - results.append(eval_in_dir(train_dir, filename, num_procs, num_threads)) + results.append( + eval_in_dir( + train_dir, + filename, + train_params.num_procs, + train_params.num_threads, + ) + ) if output_model: if i + 1 < len(tflite_filenames): @@ -133,7 +140,7 @@ def train( cast(tempfile.TemporaryDirectory, unmodified_model_dir).cleanup() return ( - results if checkpoint_at else results[0] + results if train_params.checkpoint_at else results[0] ) # only return a list if multiple checkpoints are asked for @@ -176,46 +183,24 @@ def join_in_dir(model_dir: str, new_part: str, output_model: str) -> None: join_models(Path(model_dir, "start.tflite"), new_end, output_model) -def train_in_dir( - train_dir: str, - baseline_dir: Any, - output_filename: Path, - replace_fn: Callable, - augmentations: tuple[float | None, float | None], - steps: int, - learning_rate: float = 1e-3, - batch_size: int = 32, - checkpoint_at: list | None = None, - schedule: str = "cosine", - verbose: bool = False, - show_progress: bool = False, - num_procs: int = 0, - num_threads: int = 1, -) -> list: - """Train a replacement for replace.tflite using the input.tfrec \ - and output.tfrec in train_dir. - - If baseline_dir is provided, train the replacement to match baseline - outputs for train_dir inputs. Result saved as new.tflite in train_dir. - """ - teacher_dir = baseline_dir if baseline_dir else train_dir - teacher = ParallelTFLiteModel( - f"{teacher_dir}/replace.tflite", num_procs, num_threads, batch_size=batch_size - ) - replace = TFLiteModel(f"{train_dir}/replace.tflite") +def _get_io_tensors(model: TFLiteModel) -> tuple[str, str]: assert ( - len(teacher.input_tensors()) == 1 + len(model.input_tensors()) == 1 ), f"Can only train replacements with a single input tensor right now, \ - found {teacher.input_tensors()}" + found {model.input_tensors()}" assert ( - len(teacher.output_tensors()) == 1 + len(model.output_tensors()) == 1 ), f"Can only train replacements with a single output tensor right now, \ - found {teacher.output_tensors()}" + found {model.output_tensors()}" + + input_name = model.input_tensors()[0] + output_name = model.output_tensors()[0] + return (input_name, output_name) - input_name = teacher.input_tensors()[0] - output_name = teacher.output_tensors()[0] +def _check_model_compatibility(teacher: TFLiteModel, replace: TFLiteModel) -> None: + """Assert that teacher and replaced sub-graph are compatible.""" assert len(teacher.shape_from_name) == len( replace.shape_from_name ), f"Baseline and train models must have the same number of inputs and outputs. \ @@ -230,10 +215,37 @@ def train_in_dir( subgraph being replaced. Teacher: {teacher.shape_from_name}\n \ Train dir: {replace.shape_from_name}" + +def set_up_data_pipeline( + teacher: TFLiteModel, + replace: TFLiteModel, + train_dir: str, + augmentations: tuple[float | None, float | None], + steps: int, + batch_size: int = 32, +) -> tf.data.Dataset: + """Create a data pipeline for training of the replacement model.""" + _check_model_compatibility(teacher, replace) + + input_name, output_name = _get_io_tensors(teacher) + input_filename = Path(train_dir, "input.tfrec") total = numpytf_count(str(input_filename)) dict_inputs = numpytf_read(str(input_filename)) + inputs = dict_inputs.map(lambda d: tf.squeeze(d[input_name], axis=0)) + + steps_per_epoch = math.ceil(total / batch_size) + epochs = int(math.ceil(steps / steps_per_epoch)) + logger.info( + "Training on %d items for %d steps (%d epochs with batch size %d)", + total, + epochs * steps_per_epoch, + epochs, + batch_size, + ) + + teacher_dir = Path(teacher.model_path).parent if any(augmentations): # Map the teacher inputs here because the augmentation stage passes these # through a TFLite model to get the outputs @@ -245,17 +257,6 @@ def train_in_dir( lambda d: tf.squeeze(d[output_name], axis=0) ) - steps_per_epoch = math.ceil(total / batch_size) - epochs = int(math.ceil(steps / steps_per_epoch)) - if verbose: - logger.info( - "Training on %d items for %d steps (%d epochs with batch size %d)", - total, - epochs * steps_per_epoch, - epochs, - batch_size, - ) - dataset = tf.data.Dataset.zip((inputs, teacher_outputs)) if epochs > 1: dataset = dataset.cache() @@ -268,10 +269,9 @@ def train_in_dir( train: Any, teach: Any # pylint: disable=redefined-outer-name ) -> tuple: """Return results of train and teach based on augmentations.""" - return ( - augment_train({input_name: train})[input_name], - teacher(augment_teacher({input_name: teach}))[output_name], - ) + augmented_train = augment_train({input_name: train})[input_name] + augmented_teach = teacher(augment_teacher({input_name: teach}))[output_name] + return (augmented_train, augmented_teach) dataset = dataset.map( lambda augment_train, augment_teach: tf.py_function( @@ -281,18 +281,67 @@ def train_in_dir( ) ) + # Restore data shapes of the dataset as they are set to unknown per default + # and get lost during augmentation with tf.py_function. + shape_in, shape_out = ( + teacher.shape_from_name[name].tolist() for name in (input_name, output_name) + ) + for shape in (shape_in, shape_out): + shape[0] = None # set dynamic batch size + + def restore_shapes(input_: Any, output: Any) -> tuple[Any, Any]: + input_.set_shape(shape_in) + output.set_shape(shape_out) + return input_, output + + dataset = dataset.map(restore_shapes) dataset = dataset.prefetch(tf.data.AUTOTUNE) + return dataset + + +def train_in_dir( + train_dir: str, + baseline_dir: Any, + output_filename: Path, + replace_fn: Callable, + train_params: TrainingParameters = TrainingParameters(), +) -> list[str]: + """Train a replacement for replace.tflite using the input.tfrec \ + and output.tfrec in train_dir. + + If baseline_dir is provided, train the replacement to match baseline + outputs for train_dir inputs. Result saved as new.tflite in train_dir. + """ + teacher_dir = baseline_dir if baseline_dir else train_dir + teacher = ParallelTFLiteModel( + f"{teacher_dir}/replace.tflite", + train_params.num_procs, + train_params.num_threads, + batch_size=train_params.batch_size, + ) + replace = TFLiteModel(f"{train_dir}/replace.tflite") + + input_name, output_name = _get_io_tensors(teacher) + + dataset = set_up_data_pipeline( + teacher, + replace, + train_dir, + augmentations=train_params.augmentations, + steps=train_params.steps, + batch_size=train_params.batch_size, + ) input_shape = teacher.shape_from_name[input_name][1:] output_shape = teacher.shape_from_name[output_name][1:] + model = replace_fn(input_shape, output_shape) - optimizer = tf.keras.optimizers.Nadam(learning_rate=learning_rate) + optimizer = tf.keras.optimizers.Nadam(learning_rate=train_params.learning_rate) loss_fn = tf.keras.losses.MeanSquaredError() - model.compile(optimizer=optimizer, loss=loss_fn) + model.compile(optimizer=optimizer, loss=loss_fn, metrics=["mae"]) - if verbose: - model.summary() + logger.info(model.summary()) steps_so_far = 0 @@ -302,7 +351,9 @@ def train_in_dir( """Cosine decay from learning rate at start of the run to zero at the end.""" current_step = epoch_step + steps_so_far cd_learning_rate = ( - learning_rate * (math.cos(math.pi * current_step / steps) + 1) / 2.0 + train_params.learning_rate + * (math.cos(math.pi * current_step / train_params.steps) + 1) + / 2.0 ) tf.keras.backend.set_value(optimizer.learning_rate, cd_learning_rate) @@ -311,28 +362,29 @@ def train_in_dir( ) -> None: """Constant until the last 20% of the run, then linear decay to zero.""" current_step = epoch_step + steps_so_far - steps_remaining = steps - current_step - decay_length = steps // 5 + steps_remaining = train_params.steps - current_step + decay_length = train_params.steps // 5 decay_fraction = min(steps_remaining, decay_length) / decay_length - ld_learning_rate = learning_rate * decay_fraction + ld_learning_rate = train_params.learning_rate * decay_fraction tf.keras.backend.set_value(optimizer.learning_rate, ld_learning_rate) - if schedule == "cosine": + assert train_params.learning_rate_schedule in LEARNING_RATE_SCHEDULES, ( + f'Learning rate schedule "{train_params.learning_rate_schedule}" ' + f"not implemented - expected one of {LEARNING_RATE_SCHEDULES}." + ) + if train_params.learning_rate_schedule == "cosine": callbacks = [tf.keras.callbacks.LambdaCallback(on_batch_begin=cosine_decay)] - elif schedule == "late": + elif train_params.learning_rate_schedule == "late": callbacks = [tf.keras.callbacks.LambdaCallback(on_batch_begin=late_decay)] - elif schedule == "constant": + elif train_params.learning_rate_schedule == "constant": callbacks = [] - else: - assert schedule not in LEARNING_RATE_SCHEDULES - raise ValueError( - f'Learning rate schedule "{schedule}" not implemented - ' - f"expected one of {LEARNING_RATE_SCHEDULES}." - ) output_filenames = [] - checkpoints = (checkpoint_at if checkpoint_at else []) + [steps] - while steps_so_far < steps: + checkpoints = (train_params.checkpoint_at if train_params.checkpoint_at else []) + [ + train_params.steps + ] + + while steps_so_far < train_params.steps: steps_to_train = checkpoints.pop(0) - steps_so_far lr_start = optimizer.learning_rate.numpy() model.fit( @@ -340,7 +392,7 @@ def train_in_dir( epochs=1, steps_per_epoch=steps_to_train, callbacks=callbacks, - verbose=show_progress, + verbose=train_params.show_progress, ) steps_so_far += steps_to_train logger.info( @@ -350,12 +402,14 @@ def train_in_dir( steps_to_train, ) - if steps_so_far < steps: + if steps_so_far < train_params.steps: filename, ext = Path(output_filename).parts[1:] checkpoint_filename = filename + (f"_@{steps_so_far}") + ext else: checkpoint_filename = str(output_filename) - with log_action(f"{steps_so_far}/{steps}: Saved as {checkpoint_filename}"): + with log_action( + f"{steps_so_far}/{train_params.steps}: Saved as {checkpoint_filename}" + ): save_as_tflite( model, checkpoint_filename, @@ -379,14 +433,30 @@ def save_as_tflite( output_shape: list, ) -> None: """Save Keras model as TFLite file.""" - converter = tf.lite.TFLiteConverter.from_keras_model(keras_model) + + @contextmanager + def fixed_input(keras_model: tf.keras.Model, tmp_shape: list) -> GeneratorType: + """Fix the input shape of the Keras model temporarily. + + This avoids artifacts during conversion to TensorFlow Lite. + """ + orig_shape = keras_model.input.shape + keras_model.input.set_shape(tf.TensorShape(tmp_shape)) + try: + yield keras_model + finally: + # Restore original shape to not interfere with further training + keras_model.input.set_shape(orig_shape) + + with fixed_input(keras_model, input_shape) as fixed_model: + converter = tf.lite.TFLiteConverter.from_keras_model(fixed_model) tflite_model = converter.convert() with open(filename, "wb") as file: file.write(tflite_model) # Now fix the shapes and names to match those we expect - flatbuffer = load(filename) + flatbuffer = load_fb(filename) i = flatbuffer.subgraphs[0].inputs[0] flatbuffer.subgraphs[0].tensors[i].shape = np.array(input_shape, dtype=np.int32) flatbuffer.subgraphs[0].tensors[i].name = input_name.encode("utf-8") @@ -395,11 +465,11 @@ def save_as_tflite( output_shape, dtype=np.int32 ) flatbuffer.subgraphs[0].tensors[output].name = output_name.encode("utf-8") - save(flatbuffer, filename) + save_fb(flatbuffer, filename) def augment_fn_twins( - inputs: dict, augmentations: tuple[float | None, float | None] + inputs: tf.data.Dataset, augmentations: tuple[float | None, float | None] ) -> Any: """Return a pair of twinned augmentation functions with the same sequence \ of random numbers.""" @@ -415,6 +485,11 @@ def augment_fn( inputs: Any, augmentations: tuple[float | None, float | None], rng: Generator ) -> Any: """Augmentation module.""" + assert len(augmentations) == 2, ( + f"Unexpected number of augmentation parameters: {len(augmentations)} " + "(must be 2)" + ) + mixup_strength, gaussian_strength = augmentations augments = [] @@ -449,17 +524,16 @@ def augment_fn( augments.append(gaussian_strength_augment) - if len(augments) == 0: # pylint: disable=no-else-return + if len(augments) == 0: return lambda x: x - elif len(augments) == 1: + if len(augments) == 1: return augments[0] - elif len(augments) == 2: + if len(augments) == 2: return lambda x: augments[1](augments[0](x)) - else: - assert ( - False - ), f"Unexpected number of augmentation \ - functions ({len(augments)})" + + raise RuntimeError( + "Unexpected number of augmentation functions (must be <=2): " f"{len(augments)}" + ) def mixup(rng: Generator, batch: Any, beta_range: tuple = (0.0, 1.0)) -> Any: diff --git a/src/mlia/nn/rewrite/core/utils/numpy_tfrecord.py b/src/mlia/nn/rewrite/core/utils/numpy_tfrecord.py index 9229810..38ac1ed 100644 --- a/src/mlia/nn/rewrite/core/utils/numpy_tfrecord.py +++ b/src/mlia/nn/rewrite/core/utils/numpy_tfrecord.py @@ -6,55 +6,56 @@ from __future__ import annotations import json import os import random -import tempfile -from collections import defaultdict +from functools import lru_cache from pathlib import Path from typing import Any from typing import Callable -import numpy as np import tensorflow as tf -from tensorflow.lite.python import interpreter as interpreter_wrapper -from mlia.nn.rewrite.core.utils.utils import load -from mlia.nn.rewrite.core.utils.utils import save os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR) -def make_decode_fn(filename: str) -> Callable: - """Make decode filename.""" +def decode_fn(record_bytes: Any, type_map: dict) -> dict: + """Decode the given bytes into a name-tensor dict assuming the given type.""" + parse_dict = { + name: tf.io.FixedLenFeature([], tf.string) for name in type_map.keys() + } + example = tf.io.parse_single_example(record_bytes, parse_dict) + features = { + n: tf.io.parse_tensor(example[n], tf.as_dtype(t)) for n, t in type_map.items() + } + return features - def decode_fn(record_bytes: Any, type_map: dict) -> dict: - parse_dict = { - name: tf.io.FixedLenFeature([], tf.string) for name in type_map.keys() - } - example = tf.io.parse_single_example(record_bytes, parse_dict) - features = { - n: tf.io.parse_tensor(example[n], tf.as_dtype(t)) - for n, t in type_map.items() - } - return features +def make_decode_fn(filename: str, model_filename: str | Path | None = None) -> Callable: + """Make decode filename.""" meta_filename = filename + ".meta" - with open(meta_filename, encoding="utf-8") as file: - type_map = json.load(file)["type_map"] + try: + with open(meta_filename, encoding="utf-8") as file: + type_map = json.load(file)["type_map"] return lambda record_bytes: decode_fn(record_bytes, type_map) def numpytf_read(filename: str | Path) -> Any: """Read TFRecord dataset.""" - decode_fn = make_decode_fn(str(filename)) + decode = make_decode_fn(str(filename)) dataset = tf.data.TFRecordDataset(str(filename)) - return dataset.map(decode_fn) + return dataset.map(decode) -def numpytf_count(filename: str | Path) -> Any: +@lru_cache +def numpytf_count(filename: str | Path) -> int: """Return count from TFRecord file.""" meta_filename = f"{filename}.meta" - with open(meta_filename, encoding="utf-8") as file: - return json.load(file)["count"] + try: + with open(meta_filename, encoding="utf-8") as file: + return int(json.load(file)["count"]) + except FileNotFoundError: + raw_dataset = tf.data.TFRecordDataset(filename) + return sum(1 for _ in raw_dataset) class NumpyTFWriter: @@ -101,93 +102,6 @@ class NumpyTFWriter: self.writer.close() -class TFLiteModel: - """A representation of a TFLite Model.""" - - def __init__( - self, - filename: str, - batch_size: int | None = None, - num_threads: int | None = None, - ) -> None: - """Initiate a TFLite Model.""" - if not num_threads: - num_threads = None - if not batch_size: - self.interpreter = interpreter_wrapper.Interpreter( - model_path=filename, num_threads=num_threads - ) - else: # if a batch size is specified, modify the TFLite model to use this size - with tempfile.TemporaryDirectory() as tmp: - flatbuffer = load(filename) - for subgraph in flatbuffer.subgraphs: - for tensor in list(subgraph.inputs) + list(subgraph.outputs): - subgraph.tensors[tensor].shape = np.array( - [batch_size] + list(subgraph.tensors[tensor].shape[1:]), - dtype=np.int32, - ) - tempname = os.path.join(tmp, "rewrite_tmp.tflite") - save(flatbuffer, tempname) - self.interpreter = interpreter_wrapper.Interpreter( - model_path=tempname, num_threads=num_threads - ) - - try: - self.interpreter.allocate_tensors() - except RuntimeError: - self.interpreter = interpreter_wrapper.Interpreter( - model_path=filename, num_threads=num_threads - ) - self.interpreter.allocate_tensors() - - # Get input and output tensors. - self.input_details = self.interpreter.get_input_details() - self.output_details = self.interpreter.get_output_details() - details = list(self.input_details) + list(self.output_details) - self.handle_from_name = {d["name"]: d["index"] for d in details} - self.shape_from_name = {d["name"]: d["shape"] for d in details} - self.batch_size = next(iter(self.shape_from_name.values()))[0] - - def __call__(self, named_input: dict) -> dict: - """Execute the model on one or a batch of named inputs \ - (a dict of name: numpy array).""" - input_len = next(iter(named_input.values())).shape[0] - full_steps = input_len // self.batch_size - remainder = input_len % self.batch_size - - named_ys = defaultdict(list) - for i in range(full_steps): - for name, x_batch in named_input.items(): - x_tensor = x_batch[i : i + self.batch_size] # noqa: E203 - self.interpreter.set_tensor(self.handle_from_name[name], x_tensor) - self.interpreter.invoke() - for output_detail in self.output_details: - named_ys[output_detail["name"]].append( - self.interpreter.get_tensor(output_detail["index"]) - ) - if remainder: - for name, x_batch in named_input.items(): - x_tensor = np.zeros( # pylint: disable=invalid-name - self.shape_from_name[name] - ).astype(x_batch.dtype) - x_tensor[:remainder] = x_batch[-remainder:] - self.interpreter.set_tensor(self.handle_from_name[name], x_tensor) - self.interpreter.invoke() - for output_detail in self.output_details: - named_ys[output_detail["name"]].append( - self.interpreter.get_tensor(output_detail["index"])[:remainder] - ) - return {k: np.concatenate(v) for k, v in named_ys.items()} - - def input_tensors(self) -> list: - """Return name from input details.""" - return [d["name"] for d in self.input_details] - - def output_tensors(self) -> list: - """Return name from output details.""" - return [d["name"] for d in self.output_details] - - def sample_tfrec(input_file: str, k: int, output_file: str) -> None: """Count, read and write TFRecord input and output data.""" total = numpytf_count(input_file) diff --git a/src/mlia/nn/rewrite/core/utils/parallel.py b/src/mlia/nn/rewrite/core/utils/parallel.py index d930a1e..b7b390d 100644 --- a/src/mlia/nn/rewrite/core/utils/parallel.py +++ b/src/mlia/nn/rewrite/core/utils/parallel.py @@ -15,14 +15,14 @@ from typing import Any import numpy as np import tensorflow as tf -from mlia.nn.rewrite.core.utils.numpy_tfrecord import TFLiteModel +from mlia.nn.tensorflow.config import TFLiteModel os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR) logger = logging.getLogger(__name__) -class ParallelTFLiteModel(TFLiteModel): +class ParallelTFLiteModel(TFLiteModel): # pylint: disable=abstract-method """A parallel version of a TFLiteModel. num_procs: 0 => detect real cores on system diff --git a/src/mlia/nn/rewrite/core/utils/utils.py b/src/mlia/nn/rewrite/core/utils/utils.py deleted file mode 100644 index ddf0cc2..0000000 --- a/src/mlia/nn/rewrite/core/utils/utils.py +++ /dev/null @@ -1,32 +0,0 @@ -# SPDX-FileCopyrightText: Copyright 2023, Arm Limited and/or its affiliates. -# SPDX-License-Identifier: Apache-2.0 -"""Model and file system utilites.""" -from __future__ import annotations - -from pathlib import Path - -import flatbuffers -from tensorflow.lite.python.schema_py_generated import Model -from tensorflow.lite.python.schema_py_generated import ModelT - - -def load(input_tflite_file: str | Path) -> ModelT: - """Load a flatbuffer model from file.""" - if not Path(input_tflite_file).exists(): - raise FileNotFoundError(f"TFLite file not found at {input_tflite_file}\n") - with open(input_tflite_file, "rb") as file_handle: - file_data = bytearray(file_handle.read()) - model_obj = Model.GetRootAsModel(file_data, 0) - model = ModelT.InitFromObj(model_obj) - return model - - -def save(model: ModelT, output_tflite_file: str | Path) -> None: - """Save a flatbuffer model to a given file.""" - builder = flatbuffers.Builder(1024) # Initial size of the buffer, which - # will grow automatically if needed - model_offset = model.Pack(builder) - builder.Finish(model_offset, file_identifier=b"TFL3") - model_data = builder.Output() - with open(output_tflite_file, "wb") as out_file: - out_file.write(model_data) diff --git a/src/mlia/nn/rewrite/library/fc_layer.py b/src/mlia/nn/rewrite/library/fc_layer.py index 8704154..2480500 100644 --- a/src/mlia/nn/rewrite/library/fc_layer.py +++ b/src/mlia/nn/rewrite/library/fc_layer.py @@ -7,12 +7,12 @@ import tensorflow as tf def get_keras_model(input_shape: Any, output_shape: Any) -> tf.keras.Model: - """Generate tflite model for rewrite.""" - input_tensor = tf.keras.layers.Input( - shape=input_shape, name="MbileNet/avg_pool/AvgPool" + """Generate TensorFlow Lite model for rewrite.""" + model = tf.keras.Sequential( + ( + tf.keras.layers.InputLayer(input_shape=input_shape), + tf.keras.layers.Reshape([-1]), + tf.keras.layers.Dense(output_shape), + ) ) - output_tensor = tf.keras.layers.Dense(output_shape, name="MobileNet/fc1/BiasAdd")( - input_tensor - ) - model = tf.keras.Model(input_tensor, output_tensor) return model |