aboutsummaryrefslogtreecommitdiff
path: root/src/mlia/nn/rewrite/core/rewrite.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/mlia/nn/rewrite/core/rewrite.py')
-rw-r--r--src/mlia/nn/rewrite/core/rewrite.py201
1 files changed, 164 insertions, 37 deletions
diff --git a/src/mlia/nn/rewrite/core/rewrite.py b/src/mlia/nn/rewrite/core/rewrite.py
index c7d13ba..e2c097c 100644
--- a/src/mlia/nn/rewrite/core/rewrite.py
+++ b/src/mlia/nn/rewrite/core/rewrite.py
@@ -3,15 +3,17 @@
"""Contains class RewritingOptimizer to replace a subgraph/layer of a model."""
from __future__ import annotations
-import importlib
import logging
import tempfile
+from abc import ABC
+from abc import abstractmethod
from dataclasses import dataclass
from pathlib import Path
from typing import Any
from typing import Callable
-from typing import cast
+import numpy as np
+import tensorflow_model_optimization as tfmot
from keras.api._v2 import keras # Temporary workaround for now: MLIA-1107
from mlia.core.errors import ConfigurationError
@@ -22,6 +24,13 @@ from mlia.nn.common import Optimizer
from mlia.nn.common import OptimizerConfiguration
from mlia.nn.rewrite.core.train import train
from mlia.nn.rewrite.core.train import TrainingParameters
+from mlia.nn.rewrite.library.fc_clustering_layer import (
+ get_keras_model_clus as fc_clustering_rewrite,
+)
+from mlia.nn.rewrite.library.fc_layer import get_keras_model as fc_rewrite
+from mlia.nn.rewrite.library.fc_sparsity24_layer import (
+ get_keras_model as fc_rewrite_sparsity24,
+)
from mlia.nn.tensorflow.config import TFLiteModel
from mlia.utils.registry import Registry
@@ -30,8 +39,8 @@ logger = logging.getLogger(__name__)
RewriteCallable = Callable[[Any, Any], keras.Model]
-class Rewrite:
- """Graph rewrite logic to be used by RewritingOptimizer."""
+class Rewrite(ABC):
+ """Abstract class for rewrite logic to be used by RewritingOptimizer."""
def __init__(self, name: str, rewrite_fn: RewriteCallable):
"""Initialize a Rewrite instance with a given name and an optional function."""
@@ -45,34 +54,138 @@ class Rewrite:
except Exception as ex:
raise RuntimeError(f"Rewrite '{self.name}' failed.") from ex
+    def quantize(self, model: keras.Model) -> keras.Model:
+        """Return the given model as is; this default applies no quantization."""
+        return model
-@dataclass
-class DynamicallyLoadedRewrite(Rewrite):
- """A rewrite which can load logic from a function loaded dynamically."""
+    @abstractmethod
+    def training_callbacks(self) -> list:
+        """Return the list of callbacks for this rewrite (set by subclasses)."""
- def __init__(self, name: str, function_name: str):
- """Initialize."""
+    @abstractmethod
+    def post_process(self, model: keras.Model) -> keras.Model:
+        """Return the model after rewrite-specific post-processing."""
- def load_and_run(input_shape: Any, output_shape: Any) -> keras.Model:
- """Load the function from a file dynamically."""
- self.load_function(function_name)
- return self.function(input_shape, output_shape)
+    @abstractmethod
+    def check_optimization(self, model: keras.Model, **kwargs: Any) -> bool:
+        """Check if the optimization has produced the correct result.
+
+        ``kwargs`` carries rewrite-specific options, so its values are typed
+        as ``Any``; annotating ``**kwargs`` as ``dict`` would declare every
+        keyword value to be a dictionary.
+        """
- super().__init__(name, load_and_run)
- def load_function(self, function_name: str) -> RewriteCallable:
- """Return the rewrite function. Import using the auto_load attr if necessary."""
- try:
- name_parts = function_name.split(".")
- module_name = ".".join(name_parts[:-1])
- fn_name = name_parts[-1]
- module = importlib.import_module(module_name)
- self.function = cast(RewriteCallable, getattr(module, fn_name))
- return self.function
- except Exception as ex:
- raise RuntimeError(
- f"Unable to load rewrite function '{function_name}' for '{self.name}'."
- ) from ex
+class GenericRewrite(Rewrite):
+    """Rewrite that quantizes via tfmot and adds no other processing."""
+
+    def training_callbacks(self) -> list:
+        """Return an empty list; this rewrite adds no training callbacks."""
+        return []
+
+    def quantize(self, model: keras.Model) -> keras.Model:
+        """Return ``model`` passed through tfmot's ``quantize_model``."""
+        quantized = tfmot.quantization.keras.quantize_model(model)
+        return quantized
+
+    def post_process(self, model: keras.Model) -> keras.Model:
+        """Return ``model`` as is; no post-processing is applied."""
+        return model
+
+    def check_optimization(self, model: keras.Model, **kwargs: Any) -> bool:
+        """Always return True; there is nothing to verify for this rewrite."""
+        return True
+
+
+class QuantizeAwareTrainingRewrite(Rewrite, ABC):
+    """Base class for rewrites that perform quantization-aware training (QAT).
+
+    Concrete subclasses must implement ``preserved_quantize``.
+    """
+
+    @abstractmethod
+    def preserved_quantize(self, model: keras.Model) -> keras.Model:
+        """Apply optimization-aware quantization to ``model``.
+
+        The base implementation hands the model back untouched.
+        """
+        return model
+
+
+class Sparsity24Rewrite(QuantizeAwareTrainingRewrite):
+    """Rewrite logic for the fully-connected-sparsity24 rewrite."""
+
+    # Callback class; a new instance is created per training_callbacks() call.
+    pruning_callback = tfmot.sparsity.keras.UpdatePruningStep
+
+    # staticmethod keeps the function from being bound to the instance.
+    strip_pruning_wrapper = staticmethod(tfmot.sparsity.keras.strip_pruning)
+
+    def quantize(self, model: keras.Model) -> keras.Model:
+        """Return ``model`` unchanged; the pruning rewrite skips this step."""
+        return model
+
+    def training_callbacks(self) -> list:
+        """Return a list holding a freshly created pruning callback."""
+        callback = self.pruning_callback()
+        return [callback]
+
+    def post_process(self, model: keras.Model) -> keras.Model:
+        """Return the result of stripping the pruning wrapper from ``model``."""
+        stripped = self.strip_pruning_wrapper(model)
+        return stripped
+
+    def preserved_quantize(self, model: keras.Model) -> keras.Model:
+        """Apply pruning-preserved quantization training to ``model``."""
+        annotated = tfmot.quantization.keras.quantize_annotate_model(model)
+        scheme = tfmot.experimental.combine.Default8BitPrunePreserveQuantizeScheme()
+        return tfmot.quantization.keras.quantize_apply(annotated, scheme)
+
+    def check_optimization(self, model: keras.Model, **kwargs: Any) -> bool:
+        """Always return True; no check is performed for this rewrite."""
+        return True
+
+
+class ClusteringRewrite(QuantizeAwareTrainingRewrite):
+    """Graph clustering rewrite logic to be used by RewritingOptimizer."""
+
+    # staticmethod keeps the function from being bound to the instance.
+    _strip_clustering_wrapper = staticmethod(tfmot.clustering.keras.strip_clustering)
+
+    def preserved_quantize(self, model: keras.Model) -> keras.Model:
+        """Apply clustering-preserved quantization to a given model."""
+        quant_aware_model = tfmot.quantization.keras.quantize_annotate_model(model)
+        cqat_model = tfmot.quantization.keras.quantize_apply(
+            quant_aware_model,
+            tfmot.experimental.combine.Default8BitClusterPreserveQuantizeScheme(),
+        )
+        return cqat_model
+
+    def check_optimization(self, model: keras.Model, **kwargs: Any) -> bool:
+        """Check that each kernel weight holds the expected number of clusters.
+
+        Args:
+            model: Model whose layer weights are inspected.
+            **kwargs: Must contain a non-zero ``number_of_clusters``.
+
+        Returns:
+            True if every inspected kernel has exactly ``number_of_clusters``
+            unique values. False at the first mismatch, after a warning
+            has been logged.
+
+        Raises:
+            ValueError: If ``number_of_clusters`` is missing or falsy.
+        """
+        number_of_clusters = kwargs.get("number_of_clusters")
+        if not number_of_clusters:
+            raise ValueError(
+                "Expected check_optimization to have argument number_of_clusters."
+            )
+
+        for layer in model.layers:
+            for weight in layer.weights:
+                weight_name = weight.name
+                if "kernel" not in weight_name:
+                    continue
+                # kernel_min / kernel_max also contain "kernel" but are
+                # excluded from the cluster count.
+                if "kernel_min" in weight_name or "kernel_max" in weight_name:
+                    continue
+                number_of_found_clusters = len(np.unique(weight))
+                if number_of_found_clusters != number_of_clusters:
+                    logger.warning(
+                        "\nWARNING: Expected %d cluster(s), found %d "
+                        "cluster(s) in layer %s for weight %s \n",
+                        number_of_clusters,
+                        number_of_found_clusters,
+                        layer.name,
+                        weight_name,
+                    )
+                    return False
+        return True
+
+    def training_callbacks(self) -> list:
+        """Return default rewrite callbacks (none for clustering)."""
+        return []
+
+    def post_process(self, model: keras.Model) -> keras.Model:
+        """Return the clustering stripped model."""
+        return self._strip_clustering_wrapper(model)
class RewriteRegistry(Registry[Rewrite]):
@@ -113,9 +226,9 @@ class RewritingOptimizer(Optimizer):
registry = RewriteRegistry(
[
- DynamicallyLoadedRewrite(
- "fully-connected", "mlia.nn.rewrite.library.fc_layer.get_keras_model"
- )
+ GenericRewrite("fully-connected", fc_rewrite),
+ Sparsity24Rewrite("fully-connected-sparsity24", fc_rewrite_sparsity24),
+ ClusteringRewrite("fully-connected-clustering", fc_clustering_rewrite),
]
)
@@ -149,22 +262,35 @@ class RewritingOptimizer(Optimizer):
raise ConfigurationError(
"Input and output tensor names need to be set for rewrite."
)
-
orig_vs_repl_stats, total_stats = train(
source_model=tflite_model,
unmodified_model=tflite_model if use_unmodified_model else None,
output_model=str(tmp_output),
input_tfrec=str(tfrecord),
- replace_fn=rewrite,
+ rewrite=rewrite,
+ is_qat=isinstance(rewrite, QuantizeAwareTrainingRewrite),
input_tensors=[self.optimizer_configuration.layers_to_optimize[0]],
output_tensors=[self.optimizer_configuration.layers_to_optimize[1]],
train_params=self.optimizer_configuration.train_params,
)
if orig_vs_repl_stats:
- orig_vs_repl = ["Replaced sub-graph only"] + [
- f"{stat:.3f}" for stat in orig_vs_repl_stats
- ]
+ model_stats: list = []
+ cp_param = self.optimizer_configuration.train_params.checkpoint_at
+ checkpoints = (
+ [
+ "At checkpoint " + str(checkpoint) + " steps"
+ for checkpoint in cp_param
+ ]
+ if cp_param
+ else []
+ )
+ checkpoints.append("All Steps")
+ for checkpoint, orig_vs_repl_stat in zip(checkpoints, orig_vs_repl_stats):
+ model_stats.append(
+ ["Replaced sub-graph: " + checkpoint]
+ + [f"{stat:.3f}" for stat in orig_vs_repl_stat]
+ )
total = ["Total model"] + [f"{stat:.3f}" for stat in total_stats]
notes = (
"These metrics show the difference between original model\n"
@@ -178,19 +304,20 @@ class RewritingOptimizer(Optimizer):
table = Table(
columns=[
Column(
- "Original vs. optimized",
+ "Original vs. Optimized",
alias="metric",
fmt=Format(wrap_width=40),
),
Column("MAE", alias="value", fmt=Format(wrap_width=15)),
Column("NRMSE", alias="value", fmt=Format(wrap_width=15)),
],
- rows=[orig_vs_repl, total],
+ rows=[*model_stats, total],
name="Rewrite performance metrics",
alias="rewrite_performance_metrics",
notes=notes,
)
logger.info(table.to_plain_text(show_title=True))
+ self.model = TFLiteModel(tmp_output)
def get_model(self) -> TFLiteModel:
"""Return optimized model."""