diff options
Diffstat (limited to 'src/mlia/nn/rewrite/core/rewrite.py')
-rw-r--r-- | src/mlia/nn/rewrite/core/rewrite.py | 94 |
1 file changed, 48 insertions, 46 deletions
diff --git a/src/mlia/nn/rewrite/core/rewrite.py b/src/mlia/nn/rewrite/core/rewrite.py index 2a7b432..4fe1c26 100644 --- a/src/mlia/nn/rewrite/core/rewrite.py +++ b/src/mlia/nn/rewrite/core/rewrite.py @@ -3,14 +3,14 @@ """Contains class RewritingOptimizer to replace a subgraph/layer of a model.""" from __future__ import annotations -import importlib import logging import tempfile +from abc import ABC +from abc import abstractmethod from dataclasses import dataclass from pathlib import Path from typing import Any from typing import Callable -from typing import cast import tensorflow_model_optimization as tfmot from keras.api._v2 import keras # Temporary workaround for now: MLIA-1107 @@ -35,8 +35,8 @@ logger = logging.getLogger(__name__) RewriteCallable = Callable[[Any, Any], keras.Model] -class Rewrite: - """Graph rewrite logic to be used by RewritingOptimizer.""" +class Rewrite(ABC): + """Abstract class for rewrite logic to be used by RewritingOptimizer.""" def __init__(self, name: str, rewrite_fn: RewriteCallable): """Initialize a Rewrite instance with a given name and an optional function.""" @@ -50,10 +50,42 @@ class Rewrite: except Exception as ex: raise RuntimeError(f"Rewrite '{self.name}' failed.") from ex - def quantize(self, model: keras.Model, model_is_quantized: bool) -> keras.Model: + @abstractmethod + def quantize(self, model: keras.Model) -> keras.Model: """Return a quantized model if required.""" - if model_is_quantized: - model = tfmot.quantization.keras.quantize_model(model) + + @abstractmethod + def training_callbacks(self) -> list: + """Return default rewrite callbacks.""" + + @abstractmethod + def post_process(self, model: keras.Model) -> keras.Model: + """Return default post-processing rewrite options.""" + + +class QATRewrite(Rewrite): + """Logic for rewrites requiring quantization-aware training.""" + + def pruning_preserved_quantization( + self, + model: keras.Model, + ) -> keras.Model: + """Apply pruning-preserved quantization training to a 
given model.""" + model = tfmot.quantization.keras.quantize_annotate_model(model) + model = tfmot.quantization.keras.quantize_apply( + model, + tfmot.experimental.combine.Default8BitPrunePreserveQuantizeScheme(), + ) + + return model + + +class FullyConnectedRewrite(Rewrite): + """Graph rewrite logic for fully-connected rewrite.""" + + def quantize(self, model: keras.Model) -> keras.Model: + """Return a quantized model if required.""" + model = tfmot.quantization.keras.quantize_model(model) return model def training_callbacks(self) -> list: @@ -65,18 +97,15 @@ class Rewrite: return model -class PruningRewrite(Rewrite): - """Derived Rewrite class with pruning-specific logic.""" +class Sparsity24Rewrite(QATRewrite): + """Graph rewrite logic for fully-connected-sparsity24 rewrite.""" pruning_callback = tfmot.sparsity.keras.UpdatePruningStep strip_pruning_wrapper = staticmethod(tfmot.sparsity.keras.strip_pruning) - def quantize(self, model: keras.Model, model_is_quantized: bool) -> keras.Model: - """Return a quantized model if required.""" - if model_is_quantized: - # placeholder for PQAT - pass + def quantize(self, model: keras.Model) -> keras.Model: + """Skip quantization when using pruning rewrite.""" return model def training_callbacks(self) -> list: @@ -88,35 +117,6 @@ class PruningRewrite(Rewrite): return self.strip_pruning_wrapper(model) -@dataclass -class DynamicallyLoadedRewrite(Rewrite): - """A rewrite which can load logic from a function loaded dynamically.""" - - def __init__(self, name: str, function_name: str): - """Initialize.""" - - def load_and_run(input_shape: Any, output_shape: Any) -> keras.Model: - """Load the function from a file dynamically.""" - self.load_function(function_name) - return self.function(input_shape, output_shape) - - super().__init__(name, load_and_run) - - def load_function(self, function_name: str) -> RewriteCallable: - """Return the rewrite function. 
Import using the auto_load attr if necessary.""" - try: - name_parts = function_name.split(".") - module_name = ".".join(name_parts[:-1]) - fn_name = name_parts[-1] - module = importlib.import_module(module_name) - self.function = cast(RewriteCallable, getattr(module, fn_name)) - return self.function - except Exception as ex: - raise RuntimeError( - f"Unable to load rewrite function '{function_name}' for '{self.name}'." - ) from ex - - class RewriteRegistry(Registry[Rewrite]): """Registry rewrite functions.""" @@ -155,8 +155,8 @@ class RewritingOptimizer(Optimizer): registry = RewriteRegistry( [ - Rewrite("fully-connected", fc_rewrite), - PruningRewrite("fully-connected-sparsity24", fc_rewrite_sparsity24), + FullyConnectedRewrite("fully-connected", fc_rewrite), + Sparsity24Rewrite("fully-connected-sparsity24", fc_rewrite_sparsity24), ] ) @@ -178,6 +178,7 @@ class RewritingOptimizer(Optimizer): rewrite = RewritingOptimizer.registry.items[ self.optimizer_configuration.optimization_target ] + is_qat = isinstance(rewrite, QATRewrite) use_unmodified_model = True tflite_model = self.model.model_path tfrecord = str(self.optimizer_configuration.dataset) @@ -190,7 +191,6 @@ class RewritingOptimizer(Optimizer): "Input and output tensor names need to be set for rewrite." 
) - self.optimizer_configuration.train_params.checkpoint_at = [5000, 10000] orig_vs_repl_stats, total_stats = train( source_model=tflite_model, unmodified_model=tflite_model if use_unmodified_model else None, @@ -199,6 +199,7 @@ class RewritingOptimizer(Optimizer): rewrite=rewrite, input_tensors=[self.optimizer_configuration.layers_to_optimize[0]], output_tensors=[self.optimizer_configuration.layers_to_optimize[1]], + is_qat=is_qat, train_params=self.optimizer_configuration.train_params, ) @@ -245,6 +246,7 @@ class RewritingOptimizer(Optimizer): notes=notes, ) logger.info(table.to_plain_text(show_title=True)) + self.model = TFLiteModel(tmp_output) def get_model(self) -> TFLiteModel: """Return optimized model.""" |