Diffstat (limited to 'src/mlia/nn/rewrite/core/rewrite.py')
-rw-r--r--  src/mlia/nn/rewrite/core/rewrite.py  94
1 file changed, 48 insertions(+), 46 deletions(-)
diff --git a/src/mlia/nn/rewrite/core/rewrite.py b/src/mlia/nn/rewrite/core/rewrite.py
index 2a7b432..4fe1c26 100644
--- a/src/mlia/nn/rewrite/core/rewrite.py
+++ b/src/mlia/nn/rewrite/core/rewrite.py
@@ -3,14 +3,14 @@
 """Contains class RewritingOptimizer to replace a subgraph/layer of a model."""
 from __future__ import annotations

-import importlib
 import logging
 import tempfile
+from abc import ABC
+from abc import abstractmethod
 from dataclasses import dataclass
 from pathlib import Path
 from typing import Any
 from typing import Callable
-from typing import cast

 import tensorflow_model_optimization as tfmot
 from keras.api._v2 import keras  # Temporary workaround for now: MLIA-1107
@@ -35,8 +35,8 @@ logger = logging.getLogger(__name__)
 RewriteCallable = Callable[[Any, Any], keras.Model]


-class Rewrite:
-    """Graph rewrite logic to be used by RewritingOptimizer."""
+class Rewrite(ABC):
+    """Abstract class for rewrite logic to be used by RewritingOptimizer."""

     def __init__(self, name: str, rewrite_fn: RewriteCallable):
         """Initialize a Rewrite instance with a given name and an optional function."""
@@ -50,10 +50,42 @@ class Rewrite:
         except Exception as ex:
             raise RuntimeError(f"Rewrite '{self.name}' failed.") from ex

-    def quantize(self, model: keras.Model, model_is_quantized: bool) -> keras.Model:
+    @abstractmethod
+    def quantize(self, model: keras.Model) -> keras.Model:
         """Return a quantized model if required."""
-        if model_is_quantized:
-            model = tfmot.quantization.keras.quantize_model(model)
+
+    @abstractmethod
+    def training_callbacks(self) -> list:
+        """Return default rewrite callbacks."""
+
+    @abstractmethod
+    def post_process(self, model: keras.Model) -> keras.Model:
+        """Return default post-processing rewrite options."""
+
+
+class QATRewrite(Rewrite):
+    """Logic for rewrites requiring quantization-aware training."""
+
+    def pruning_preserved_quantization(
+        self,
+        model: keras.Model,
+    ) -> keras.Model:
+        """Apply pruning-preserved quantization training to a given model."""
+        model = tfmot.quantization.keras.quantize_annotate_model(model)
+        model = tfmot.quantization.keras.quantize_apply(
+            model,
+            tfmot.experimental.combine.Default8BitPrunePreserveQuantizeScheme(),
+        )
+
+        return model
+
+
+class FullyConnectedRewrite(Rewrite):
+    """Graph rewrite logic for fully-connected rewrite."""
+
+    def quantize(self, model: keras.Model) -> keras.Model:
+        """Return a quantized model if required."""
+        model = tfmot.quantization.keras.quantize_model(model)
         return model

     def training_callbacks(self) -> list:
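As a point of reference for the new QATRewrite helper, the following is a minimal standalone sketch of the prune-preserving QAT (PQAT) flow that pruning_preserved_quantization wraps; the toy model, sparsity level, and training data are illustrative assumptions, not part of this patch.

import numpy as np
import tensorflow as tf
import tensorflow_model_optimization as tfmot

# Illustrative base model; any pruned Keras model would do.
base = tf.keras.Sequential(
    [tf.keras.layers.Dense(16, activation="relu", input_shape=(8,)), tf.keras.layers.Dense(4)]
)

# Prune to 50% constant sparsity (values chosen only for this example).
pruned = tfmot.sparsity.keras.prune_low_magnitude(
    base, pruning_schedule=tfmot.sparsity.keras.ConstantSparsity(0.5, begin_step=0)
)
pruned.compile(optimizer="adam", loss="mse")
pruned.fit(
    np.random.rand(32, 8),
    np.random.rand(32, 4),
    epochs=1,
    callbacks=[tfmot.sparsity.keras.UpdatePruningStep()],
    verbose=0,
)
stripped = tfmot.sparsity.keras.strip_pruning(pruned)

# The same two calls as pruning_preserved_quantization(): annotate, then apply
# the prune-preserving 8-bit scheme so the trained zeros survive QAT.
annotated = tfmot.quantization.keras.quantize_annotate_model(stripped)
pqat_model = tfmot.quantization.keras.quantize_apply(
    annotated,
    tfmot.experimental.combine.Default8BitPrunePreserveQuantizeScheme(),
)
pqat_model.compile(optimizer="adam", loss="mse")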
@@ -65,18 +97,15 @@ class Rewrite:
         return model


-class PruningRewrite(Rewrite):
-    """Derived Rewrite class with pruning-specific logic."""
+class Sparsity24Rewrite(QATRewrite):
+    """Graph rewrite logic for fully-connected-sparsity24 rewrite."""

     pruning_callback = tfmot.sparsity.keras.UpdatePruningStep

     strip_pruning_wrapper = staticmethod(tfmot.sparsity.keras.strip_pruning)

-    def quantize(self, model: keras.Model, model_is_quantized: bool) -> keras.Model:
-        """Return a quantized model if required."""
-        if model_is_quantized:
-            # placeholder for PQAT
-            pass
+    def quantize(self, model: keras.Model) -> keras.Model:
+        """Skip quantization when using pruning rewrite."""
         return model

     def training_callbacks(self) -> list:
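The 2:4 sparsity pattern itself is produced by the pruning wrappers that fc_rewrite_sparsity24 applies (defined elsewhere, not in this diff). Purely as an illustration of the tfmot pieces Sparsity24Rewrite points at (UpdatePruningStep during training, strip_pruning afterwards), and assuming structural 2-by-4 pruning via prune_low_magnitude:

import numpy as np
import tensorflow as tf
import tensorflow_model_optimization as tfmot

# Illustrative only: 2:4 structured pruning of a small dense model.
dense = tf.keras.Sequential([tf.keras.layers.Dense(8, input_shape=(16,))])
pruned = tfmot.sparsity.keras.prune_low_magnitude(dense, sparsity_m_by_n=(2, 4))
pruned.compile(optimizer="adam", loss="mse")
pruned.fit(
    np.random.rand(64, 16),
    np.random.rand(64, 8),
    epochs=1,
    verbose=0,
    callbacks=[tfmot.sparsity.keras.UpdatePruningStep()],  # cf. pruning_callback
)
final = tfmot.sparsity.keras.strip_pruning(pruned)  # cf. strip_pruning_wrapper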
@@ -88,35 +117,6 @@ class PruningRewrite(Rewrite):
         return self.strip_pruning_wrapper(model)


-@dataclass
-class DynamicallyLoadedRewrite(Rewrite):
-    """A rewrite which can load logic from a function loaded dynamically."""
-
-    def __init__(self, name: str, function_name: str):
-        """Initialize."""
-
-        def load_and_run(input_shape: Any, output_shape: Any) -> keras.Model:
-            """Load the function from a file dynamically."""
-            self.load_function(function_name)
-            return self.function(input_shape, output_shape)
-
-        super().__init__(name, load_and_run)
-
-    def load_function(self, function_name: str) -> RewriteCallable:
-        """Return the rewrite function. Import using the auto_load attr if necessary."""
-        try:
-            name_parts = function_name.split(".")
-            module_name = ".".join(name_parts[:-1])
-            fn_name = name_parts[-1]
-            module = importlib.import_module(module_name)
-            self.function = cast(RewriteCallable, getattr(module, fn_name))
-            return self.function
-        except Exception as ex:
-            raise RuntimeError(
-                f"Unable to load rewrite function '{function_name}' for '{self.name}'."
-            ) from ex
-
-
 class RewriteRegistry(Registry[Rewrite]):
     """Registry rewrite functions."""
@@ -155,8 +155,8 @@ class RewritingOptimizer(Optimizer):

     registry = RewriteRegistry(
         [
-            Rewrite("fully-connected", fc_rewrite),
-            PruningRewrite("fully-connected-sparsity24", fc_rewrite_sparsity24),
+            FullyConnectedRewrite("fully-connected", fc_rewrite),
+            Sparsity24Rewrite("fully-connected-sparsity24", fc_rewrite_sparsity24),
         ]
     )
@@ -178,6 +178,7 @@ class RewritingOptimizer(Optimizer):
         rewrite = RewritingOptimizer.registry.items[
             self.optimizer_configuration.optimization_target
         ]
+        is_qat = isinstance(rewrite, QATRewrite)
         use_unmodified_model = True
         tflite_model = self.model.model_path
         tfrecord = str(self.optimizer_configuration.dataset)
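Outside of apply_optimization, the same lookup-and-dispatch pattern can be exercised directly. A rough sketch, assuming Rewrite.__call__ forwards its two shape arguments to the underlying rewrite function (as the removed DynamicallyLoadedRewrite did) and using placeholder shapes:

from mlia.nn.rewrite.core.rewrite import QATRewrite, RewritingOptimizer

input_shape, output_shape = (32,), (32,)  # placeholder values for illustration
rewrite = RewritingOptimizer.registry.items["fully-connected-sparsity24"]
replacement = rewrite(input_shape, output_shape)  # builds the replacement Keras sub-model
if isinstance(rewrite, QATRewrite):
    # the trainer enables its quantization-aware training path (is_qat=True)
    callbacks = rewrite.training_callbacks()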
@@ -190,7 +191,6 @@
                 "Input and output tensor names need to be set for rewrite."
             )

-        self.optimizer_configuration.train_params.checkpoint_at = [5000, 10000]
         orig_vs_repl_stats, total_stats = train(
             source_model=tflite_model,
             unmodified_model=tflite_model if use_unmodified_model else None,
@@ -199,6 +199,7 @@
             rewrite=rewrite,
             input_tensors=[self.optimizer_configuration.layers_to_optimize[0]],
             output_tensors=[self.optimizer_configuration.layers_to_optimize[1]],
+            is_qat=is_qat,
             train_params=self.optimizer_configuration.train_params,
         )
@@ -245,6 +246,7 @@
             notes=notes,
         )
         logger.info(table.to_plain_text(show_title=True))
+        self.model = TFLiteModel(tmp_output)

     def get_model(self) -> TFLiteModel:
         """Return optimized model."""