diff options
Diffstat (limited to 'src/mlia/nn/rewrite/core/rewrite.py')
-rw-r--r-- | src/mlia/nn/rewrite/core/rewrite.py | 201 |
1 file changed, 164 insertions, 37 deletions
diff --git a/src/mlia/nn/rewrite/core/rewrite.py b/src/mlia/nn/rewrite/core/rewrite.py index c7d13ba..e2c097c 100644 --- a/src/mlia/nn/rewrite/core/rewrite.py +++ b/src/mlia/nn/rewrite/core/rewrite.py @@ -3,15 +3,17 @@ """Contains class RewritingOptimizer to replace a subgraph/layer of a model.""" from __future__ import annotations -import importlib import logging import tempfile +from abc import ABC +from abc import abstractmethod from dataclasses import dataclass from pathlib import Path from typing import Any from typing import Callable -from typing import cast +import numpy as np +import tensorflow_model_optimization as tfmot from keras.api._v2 import keras # Temporary workaround for now: MLIA-1107 from mlia.core.errors import ConfigurationError @@ -22,6 +24,13 @@ from mlia.nn.common import Optimizer from mlia.nn.common import OptimizerConfiguration from mlia.nn.rewrite.core.train import train from mlia.nn.rewrite.core.train import TrainingParameters +from mlia.nn.rewrite.library.fc_clustering_layer import ( + get_keras_model_clus as fc_clustering_rewrite, +) +from mlia.nn.rewrite.library.fc_layer import get_keras_model as fc_rewrite +from mlia.nn.rewrite.library.fc_sparsity24_layer import ( + get_keras_model as fc_rewrite_sparsity24, +) from mlia.nn.tensorflow.config import TFLiteModel from mlia.utils.registry import Registry @@ -30,8 +39,8 @@ logger = logging.getLogger(__name__) RewriteCallable = Callable[[Any, Any], keras.Model] -class Rewrite: - """Graph rewrite logic to be used by RewritingOptimizer.""" +class Rewrite(ABC): + """Abstract class for rewrite logic to be used by RewritingOptimizer.""" def __init__(self, name: str, rewrite_fn: RewriteCallable): """Initialize a Rewrite instance with a given name and an optional function.""" @@ -45,34 +54,138 @@ class Rewrite: except Exception as ex: raise RuntimeError(f"Rewrite '{self.name}' failed.") from ex + def quantize(self, model: keras.Model) -> keras.Model: + """Return a quantized model if required.""" + 
return model -@dataclass -class DynamicallyLoadedRewrite(Rewrite): - """A rewrite which can load logic from a function loaded dynamically.""" + @abstractmethod + def training_callbacks(self) -> list: + """Return default rewrite callbacks.""" - def __init__(self, name: str, function_name: str): - """Initialize.""" + @abstractmethod + def post_process(self, model: keras.Model) -> keras.Model: + """Return default post-processing rewrite options.""" - def load_and_run(input_shape: Any, output_shape: Any) -> keras.Model: - """Load the function from a file dynamically.""" - self.load_function(function_name) - return self.function(input_shape, output_shape) + @abstractmethod + def check_optimization(self, model: keras.Model, **kwargs: dict) -> bool: + """Check if the optimization has produced the correct result.""" - super().__init__(name, load_and_run) - def load_function(self, function_name: str) -> RewriteCallable: - """Return the rewrite function. Import using the auto_load attr if necessary.""" - try: - name_parts = function_name.split(".") - module_name = ".".join(name_parts[:-1]) - fn_name = name_parts[-1] - module = importlib.import_module(module_name) - self.function = cast(RewriteCallable, getattr(module, fn_name)) - return self.function - except Exception as ex: - raise RuntimeError( - f"Unable to load rewrite function '{function_name}' for '{self.name}'." 
- ) from ex +class GenericRewrite(Rewrite): + """Graph rewrite logic for fully-connected rewrite.""" + + def quantize(self, model: keras.Model) -> keras.Model: + """Return a quantized model if required.""" + return tfmot.quantization.keras.quantize_model(model) + + def training_callbacks(self) -> list: + """Return default rewrite callbacks.""" + return [] + + def post_process(self, model: keras.Model) -> keras.Model: + """Return default post-processing rewrite options.""" + return model + + def check_optimization(self, model: keras.Model, **kwargs: Any) -> bool: + """Not needed here.""" + return True + + +class QuantizeAwareTrainingRewrite(Rewrite, ABC): + """Abstract class for rewrites that perform QAT.""" + + @abstractmethod + def preserved_quantize(self, model: keras.Model) -> keras.Model: + """Apply optimization-aware quantization to a given model.""" + return model + + +class Sparsity24Rewrite(QuantizeAwareTrainingRewrite): + """Graph rewrite logic for fully-connected-sparsity24 rewrite.""" + + pruning_callback = tfmot.sparsity.keras.UpdatePruningStep + + strip_pruning_wrapper = staticmethod(tfmot.sparsity.keras.strip_pruning) + + def quantize(self, model: keras.Model) -> keras.Model: + """Skip quantization when using pruning rewrite.""" + return model + + def training_callbacks(self) -> list: + """Return pruning-specific rewrite callback.""" + return [self.pruning_callback()] + + def post_process(self, model: keras.Model) -> keras.Model: + """Pruning-specific post-processing rewrite options.""" + return self.strip_pruning_wrapper(model) + + def preserved_quantize( + self, + model: keras.Model, + ) -> keras.Model: + """Apply pruning-preserved quantization training to a given model.""" + model = tfmot.quantization.keras.quantize_annotate_model(model) + model = tfmot.quantization.keras.quantize_apply( + model, + tfmot.experimental.combine.Default8BitPrunePreserveQuantizeScheme(), + ) + + return model + + def check_optimization(self, model: keras.Model, **kwargs: 
Any) -> bool: + """Not needed here.""" + return True + + +class ClusteringRewrite(QuantizeAwareTrainingRewrite): + """Graph clustering rewrite logic to be used by RewritingOptimizer.""" + + _strip_clustering_wrapper = staticmethod(tfmot.clustering.keras.strip_clustering) + + def preserved_quantize(self, model: keras.Model) -> keras.Model: + """Apply clustering-preserved quantization to a given model.""" + quant_aware_model = tfmot.quantization.keras.quantize_annotate_model(model) + cqat_model = tfmot.quantization.keras.quantize_apply( + quant_aware_model, + tfmot.experimental.combine.Default8BitClusterPreserveQuantizeScheme(), + ) + return cqat_model + + def check_optimization(self, model: keras.Model, **kwargs: Any) -> bool: + """Check if clustering has produced the correct result.""" + number_of_clusters = kwargs.get("number_of_clusters") + if not number_of_clusters: + raise ValueError( + """ + Expected check_preserved_quantize to have argument number_of_clusters. + """ + ) + + for layer in model.layers: + for weight in layer.weights: + if "kernel" in weight.name: + if "kernel_min" in weight.name or "kernel_max" in weight.name: + continue + number_of_found_clusters = len(np.unique(weight)) + if number_of_found_clusters != number_of_clusters: + logger.warning( + "\nWARNING: Expected %d cluster(s), found %d " + "cluster(s) in layer %s for weight %s \n", + number_of_clusters, + number_of_found_clusters, + layer.name, + weight.name, + ) + return False + return True + + def training_callbacks(self) -> list: + """Return default rewrite callbacks.""" + return [] + + def post_process(self, model: keras.Model) -> keras.Model: + """Return the clustering stripped model.""" + return self._strip_clustering_wrapper(model) class RewriteRegistry(Registry[Rewrite]): @@ -113,9 +226,9 @@ class RewritingOptimizer(Optimizer): registry = RewriteRegistry( [ - DynamicallyLoadedRewrite( - "fully-connected", "mlia.nn.rewrite.library.fc_layer.get_keras_model" - ) + 
GenericRewrite("fully-connected", fc_rewrite), + Sparsity24Rewrite("fully-connected-sparsity24", fc_rewrite_sparsity24), + ClusteringRewrite("fully-connected-clustering", fc_clustering_rewrite), ] ) @@ -149,22 +262,35 @@ class RewritingOptimizer(Optimizer): raise ConfigurationError( "Input and output tensor names need to be set for rewrite." ) - orig_vs_repl_stats, total_stats = train( source_model=tflite_model, unmodified_model=tflite_model if use_unmodified_model else None, output_model=str(tmp_output), input_tfrec=str(tfrecord), - replace_fn=rewrite, + rewrite=rewrite, + is_qat=isinstance(rewrite, QuantizeAwareTrainingRewrite), input_tensors=[self.optimizer_configuration.layers_to_optimize[0]], output_tensors=[self.optimizer_configuration.layers_to_optimize[1]], train_params=self.optimizer_configuration.train_params, ) if orig_vs_repl_stats: - orig_vs_repl = ["Replaced sub-graph only"] + [ - f"{stat:.3f}" for stat in orig_vs_repl_stats - ] + model_stats: list = [] + cp_param = self.optimizer_configuration.train_params.checkpoint_at + checkpoints = ( + [ + "At checkpoint " + str(checkpoint) + " steps" + for checkpoint in cp_param + ] + if cp_param + else [] + ) + checkpoints.append("All Steps") + for checkpoint, orig_vs_repl_stat in zip(checkpoints, orig_vs_repl_stats): + model_stats.append( + ["Replaced sub-graph: " + checkpoint] + + [f"{stat:.3f}" for stat in orig_vs_repl_stat] + ) total = ["Total model"] + [f"{stat:.3f}" for stat in total_stats] notes = ( "These metrics show the difference between original model\n" @@ -178,19 +304,20 @@ class RewritingOptimizer(Optimizer): table = Table( columns=[ Column( - "Original vs. optimized", + "Original vs. 
Optimized", alias="metric", fmt=Format(wrap_width=40), ), Column("MAE", alias="value", fmt=Format(wrap_width=15)), Column("NRMSE", alias="value", fmt=Format(wrap_width=15)), ], - rows=[orig_vs_repl, total], + rows=[*model_stats, total], name="Rewrite performance metrics", alias="rewrite_performance_metrics", notes=notes, ) logger.info(table.to_plain_text(show_title=True)) + self.model = TFLiteModel(tmp_output) def get_model(self) -> TFLiteModel: """Return optimized model.""" |