From baaf4de286762c1955c874f78cd802d4703a8ba5 Mon Sep 17 00:00:00 2001 From: Gergely Nagy Date: Thu, 22 Jun 2023 14:35:21 +0100 Subject: Re-factoring of rewrite management & added metrics - List available rewrites - Refactor/rename 'Rewrite' class to 'RewritingOptimizer' - Introduce a registry for rewrite functions - Refactor 'Rewriter' to use the registry to look up rewrite functions - Remove mentions of hardcoded "fully_connected" from CLI help and error messages, using the registry instead - Add unit tests - Enable rewrites for all targets: Extract optimization (including rewrite specific code) from the Ethos-U-specific data collector into OptimizingDataCollector. This is reused in other targets' collectors, such as TOSA and Cortex-A. - Add more logging for rewrite - add display of MAE and NRMSE values for the trained result - add total model MAE and NRMSE metric Resolves: MLIA-891, MLIA-899, MLIA-906 Change-Id: Ie798749e1ed60cab14fdb6d9c2271c833960e93f Signed-off-by: Benjamin Klimczak --- src/mlia/nn/rewrite/core/rewrite.py | 154 +++++++++++++++++++++++++++++------- src/mlia/nn/rewrite/core/train.py | 21 ++++- 2 files changed, 143 insertions(+), 32 deletions(-) (limited to 'src/mlia/nn/rewrite') diff --git a/src/mlia/nn/rewrite/core/rewrite.py b/src/mlia/nn/rewrite/core/rewrite.py index 6b27984..fdfd35c 100644 --- a/src/mlia/nn/rewrite/core/rewrite.py +++ b/src/mlia/nn/rewrite/core/rewrite.py @@ -1,6 +1,6 @@ # SPDX-FileCopyrightText: Copyright 2023, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 -"""Contains class Rewriter to replace a subgraph/layer of a model.""" +"""Contains class RewritingOptimizer to replace a subgraph/layer of a model.""" from __future__ import annotations import importlib @@ -9,16 +9,88 @@ import tempfile from dataclasses import dataclass from pathlib import Path from typing import Any +from typing import Callable +from typing import cast + +import tensorflow as tf from mlia.core.errors import ConfigurationError +from mlia.core.reporting import Column +from mlia.core.reporting import Format +from mlia.core.reporting import Table from mlia.nn.common import Optimizer from mlia.nn.common import OptimizerConfiguration from mlia.nn.rewrite.core.train import train from mlia.nn.rewrite.core.train import TrainingParameters from mlia.nn.tensorflow.config import TFLiteModel - +from mlia.utils.registry import Registry logger = logging.getLogger(__name__) +RewriteCallable = Callable[[Any, Any], tf.keras.Model] + + +class Rewrite: + """Graph rewrite logic to be used by RewritingOptimizer.""" + + def __init__(self, name: str, rewrite_fn: RewriteCallable): + """Initialize a Rewrite instance with a given name and an optional function.""" + self.name = name + self.function = rewrite_fn + + def __call__(self, input_shape: Any, output_shape: Any) -> tf.keras.Model: + """Perform the rewrite operation using the configured function.""" + try: + return self.function(input_shape, output_shape) + except Exception as ex: + raise RuntimeError(f"Rewrite '{self.name}' failed.") from ex + + +@dataclass +class DynamicallyLoadedRewrite(Rewrite): + """A rewrite which can load logic from a function loaded dynamically.""" + + def __init__(self, name: str, function_name: str): + """Initialize.""" + + def load_and_run(input_shape: Any, output_shape: Any) -> tf.keras.Model: + """Load the function from a file dynamically.""" + self.load_function(function_name) + return self.function(input_shape, output_shape) + + super().__init__(name, load_and_run) + + def load_function(self, function_name: str) -> RewriteCallable: + """Return the rewrite function. Import using the auto_load attr if necessary.""" + try: + name_parts = function_name.split(".") + module_name = ".".join(name_parts[:-1]) + fn_name = name_parts[-1] + module = importlib.import_module(module_name) + self.function = cast(RewriteCallable, getattr(module, fn_name)) + return self.function + except Exception as ex: + raise RuntimeError( + f"Unable to load rewrite function '{function_name}' for '{self.name}'." + ) from ex + + +class RewriteRegistry(Registry[Rewrite]): + """Registry rewrite functions.""" + + def __init__(self, rewrites: list[Rewrite] | None = None): + """Set up a rewrite registry. + + Can optionally initialise with name->function pairs + to be automatically loaded on demand + """ + super().__init__() + if rewrites: + for rewrite in rewrites: + self.register_rewrite(rewrite) + + def register_rewrite(self, rewrite: Rewrite) -> bool: + """Register a rewrite.""" + return super().register(rewrite.name, rewrite) @dataclass @@ -35,34 +107,35 @@ class RewriteConfiguration(OptimizerConfiguration): return f"rewrite: {self.optimization_target}" -class Rewriter(Optimizer): - """Rewriter class for basic rewrite flow.""" +class RewritingOptimizer(Optimizer): + """RewritingOptimizer class for basic rewrite flow.""" + + registry = RewriteRegistry( + [ + DynamicallyLoadedRewrite( + "fully_connected", "mlia.nn.rewrite.library.fc_layer.get_keras_model" + ) + ] + ) def __init__( self, tflite_model_path: Path, optimizer_configuration: RewriteConfiguration ): - """Init Rewriter instance.""" + """Init RewritingOptimizer instance.""" self.model = TFLiteModel(tflite_model_path) self.model_path = tflite_model_path self.optimizer_configuration = optimizer_configuration - def apply_optimization(self) -> None: - """Apply the rewrite flow.""" + @classmethod + def builtin_rewrite_names(cls) -> list: + """Return all registered rewrite names.""" + return cls.registry.names() - def get_function(arg: str) -> Any: - module_name = ".".join(arg.split(".")[:-1]) - fn_name = arg.split(".")[-1] - module = importlib.import_module(module_name) - return getattr(module, fn_name) - - if self.optimizer_configuration.optimization_target == "fully_connected": - replace_function = "mlia.nn.rewrite.library.fc_layer.get_keras_model" - else: - raise ConfigurationError( - "Only fully_connected replacement is supported in rewrite module." - ) - - replace_fn = get_function(replace_function) + def apply_optimization(self) -> None: # pylint: disable=too-many-locals + """Apply the rewrite flow.""" + rewrite = RewritingOptimizer.registry.items[ + self.optimizer_configuration.optimization_target + ] use_unmodified_model = True tflite_model = self.model.model_path @@ -75,25 +148,48 @@ class Rewriter(Optimizer): raise ConfigurationError( "Input and output tensor names need to be set for rewrite." ) - result = train( + + orig_vs_repl_stats, total_stats = train( source_model=tflite_model, unmodified_model=tflite_model if use_unmodified_model else None, output_model=str(tmp_output), input_tfrec=str(tfrecord), - replace_fn=replace_fn, + replace_fn=rewrite, input_tensors=[self.optimizer_configuration.layers_to_optimize[0]], output_tensors=[self.optimizer_configuration.layers_to_optimize[1]], train_params=self.optimizer_configuration.train_params, ) - self.model = TFLiteModel(tmp_output) + if orig_vs_repl_stats: + orig_vs_repl = ["Replaced sub-graph only"] + [ + f"{stat:.3f}" for stat in orig_vs_repl_stats + ] + total = ["Total model"] + [f"{stat:.3f}" for stat in total_stats] + notes = ( + "These metrics show the difference between original model\n" + "and the model optimized by the rewrite. The models are\n" + "compared at two positions: directly after the replaced\n" + "sub-graph and at the model output.\n" + "MAE = Mean Absolute Error\n" + "NRMSE = Normalized Root Mean Square Error" + ) - if result: - stats_as_str = ", ".join(str(stats) for stats in result) - logger.info( - "The MAE and NRMSE between original and replacement [%s]", - stats_as_str, + table = Table( + columns=[ + Column( + "Original vs. optimized", + alias="metric", + fmt=Format(wrap_width=40), + ), + Column("MAE", alias="value", fmt=Format(wrap_width=15)), + Column("NRMSE", alias="value", fmt=Format(wrap_width=15)), + ], + rows=[orig_vs_repl, total], + name="Rewrite performance metrics", + alias="rewrite_performance_metrics", + notes=notes, ) + logger.info(table.to_plain_text(show_title=True)) def get_model(self) -> TFLiteModel: """Return optimized model.""" diff --git a/src/mlia/nn/rewrite/core/train.py b/src/mlia/nn/rewrite/core/train.py index 42bf653..82af747 100644 --- a/src/mlia/nn/rewrite/core/train.py +++ b/src/mlia/nn/rewrite/core/train.py @@ -136,12 +136,27 @@ def train( output_filename = output_model join_in_dir(train_dir, filename, output_filename) + # Assess the output diff between the parts after the rewrite subgraph + # in original and optimized model + optimized_end_path = Path(train_dir, "optimized_end.tfrec") + end_path = Path(train_dir, "end.tfrec") + + record_model( + str(input_tfrec), + output_filename, + optimized_end_path, + num_procs=train_params.num_procs, + num_threads=train_params.num_threads, + ) + mae, nrmse = diff_stats(end_path, str(optimized_end_path)) + if unmodified_model_dir: cast(tempfile.TemporaryDirectory, unmodified_model_dir).cleanup() - return ( - results if train_params.checkpoint_at else results[0] - ) # only return a list if multiple checkpoints are asked for + return (results if train_params.checkpoint_at else results[0]), [ + mae, + nrmse, + ] # only return a list if multiple checkpoints are asked for def eval_in_dir( -- cgit v1.2.1