Diffstat (limited to 'src/mlia/nn/rewrite/core/rewrite.py')
-rw-r--r--   src/mlia/nn/rewrite/core/rewrite.py   132
1 file changed, 91 insertions, 41 deletions
diff --git a/src/mlia/nn/rewrite/core/rewrite.py b/src/mlia/nn/rewrite/core/rewrite.py
index 6a3695a..e2c097c 100644
--- a/src/mlia/nn/rewrite/core/rewrite.py
+++ b/src/mlia/nn/rewrite/core/rewrite.py
@@ -12,6 +12,7 @@ from pathlib import Path
from typing import Any
from typing import Callable
+import numpy as np
import tensorflow_model_optimization as tfmot
from keras.api._v2 import keras # Temporary workaround for now: MLIA-1107
@@ -53,9 +54,9 @@ class Rewrite(ABC):
except Exception as ex:
raise RuntimeError(f"Rewrite '{self.name}' failed.") from ex
- @abstractmethod
def quantize(self, model: keras.Model) -> keras.Model:
"""Return a quantized model if required."""
+ return model
@abstractmethod
def training_callbacks(self) -> list:
@@ -65,60 +66,41 @@ class Rewrite(ABC):
def post_process(self, model: keras.Model) -> keras.Model:
"""Return default post-processing rewrite options."""
+ @abstractmethod
+ def check_optimization(self, model: keras.Model, **kwargs: dict) -> bool:
+ """Check if the optimization has produced the correct result."""
-class ClusteringRewrite(Rewrite):
- """Graph clustering rewrite logic to be used by RewritingOptimizer."""
- strip_pruning_wrapper = staticmethod(tfmot.clustering.keras.strip_clustering)
+class GenericRewrite(Rewrite):
+ """Graph rewrite logic for fully-connected rewrite."""
def quantize(self, model: keras.Model) -> keras.Model:
- """Return a quantized model."""
- return model
-
- def post_process(self, model: keras.Model) -> keras.Model:
- """Return the clustering stripped model."""
- return self.strip_pruning_wrapper(model)
+ """Return a quantized model if required."""
+ return tfmot.quantization.keras.quantize_model(model)
def training_callbacks(self) -> list:
"""Return default rewrite callbacks."""
return []
-
-class QATRewrite(Rewrite):
- """Logic for rewrites requiring quantization-aware training."""
-
- def pruning_preserved_quantization(
- self,
- model: keras.Model,
- ) -> keras.Model:
- """Apply pruning-preserved quantization training to a given model."""
- model = tfmot.quantization.keras.quantize_annotate_model(model)
- model = tfmot.quantization.keras.quantize_apply(
- model,
- tfmot.experimental.combine.Default8BitPrunePreserveQuantizeScheme(),
- )
-
+ def post_process(self, model: keras.Model) -> keras.Model:
+ """Return default post-processing rewrite options."""
return model
+ def check_optimization(self, model: keras.Model, **kwargs: Any) -> bool:
+ """Not needed here."""
+ return True
-class FullyConnectedRewrite(Rewrite):
- """Graph rewrite logic for fully-connected rewrite."""
-
- def quantize(self, model: keras.Model) -> keras.Model:
- """Return a quantized model if required."""
- model = tfmot.quantization.keras.quantize_model(model)
- return model
- def training_callbacks(self) -> list:
- """Return default rewrite callbacks."""
- return []
+class QuantizeAwareTrainingRewrite(Rewrite, ABC):
+ """Abstract class for rewrites that perform QAT."""
- def post_process(self, model: keras.Model) -> keras.Model:
- """Return default post-processing rewrite options."""
+ @abstractmethod
+ def preserved_quantize(self, model: keras.Model) -> keras.Model:
+ """Apply optimization-aware quantization to a given model."""
return model
-class Sparsity24Rewrite(QATRewrite):
+class Sparsity24Rewrite(QuantizeAwareTrainingRewrite):
"""Graph rewrite logic for fully-connected-sparsity24 rewrite."""
pruning_callback = tfmot.sparsity.keras.UpdatePruningStep
@@ -137,6 +119,74 @@ class Sparsity24Rewrite(QATRewrite):
"""Pruning-specific post-processing rewrite options."""
return self.strip_pruning_wrapper(model)
+ def preserved_quantize(
+ self,
+ model: keras.Model,
+ ) -> keras.Model:
+ """Apply pruning-preserved quantization training to a given model."""
+ model = tfmot.quantization.keras.quantize_annotate_model(model)
+ model = tfmot.quantization.keras.quantize_apply(
+ model,
+ tfmot.experimental.combine.Default8BitPrunePreserveQuantizeScheme(),
+ )
+
+ return model
+
+ def check_optimization(self, model: keras.Model, **kwargs: Any) -> bool:
+ """Not needed here."""
+ return True
+
+
+class ClusteringRewrite(QuantizeAwareTrainingRewrite):
+ """Graph clustering rewrite logic to be used by RewritingOptimizer."""
+
+ _strip_clustering_wrapper = staticmethod(tfmot.clustering.keras.strip_clustering)
+
+ def preserved_quantize(self, model: keras.Model) -> keras.Model:
+ """Apply clustering-preserved quantization to a given model."""
+ quant_aware_model = tfmot.quantization.keras.quantize_annotate_model(model)
+ cqat_model = tfmot.quantization.keras.quantize_apply(
+ quant_aware_model,
+ tfmot.experimental.combine.Default8BitClusterPreserveQuantizeScheme(),
+ )
+ return cqat_model
+
+ def check_optimization(self, model: keras.Model, **kwargs: Any) -> bool:
+ """Check if clustering has produced the correct result."""
+ number_of_clusters = kwargs.get("number_of_clusters")
+ if not number_of_clusters:
+ raise ValueError(
+ """
+ Expected check_preserved_quantize to have argument number_of_clusters.
+ """
+ )
+
+ for layer in model.layers:
+ for weight in layer.weights:
+ if "kernel" in weight.name:
+ if "kernel_min" in weight.name or "kernel_max" in weight.name:
+ continue
+ number_of_found_clusters = len(np.unique(weight))
+ if number_of_found_clusters != number_of_clusters:
+ logger.warning(
+ "\nWARNING: Expected %d cluster(s), found %d "
+ "cluster(s) in layer %s for weight %s \n",
+ number_of_clusters,
+ number_of_found_clusters,
+ layer.name,
+ weight.name,
+ )
+ return False
+ return True
+
+ def training_callbacks(self) -> list:
+ """Return default rewrite callbacks."""
+ return []
+
+ def post_process(self, model: keras.Model) -> keras.Model:
+ """Return the clustering stripped model."""
+ return self._strip_clustering_wrapper(model)
+
class RewriteRegistry(Registry[Rewrite]):
"""Registry rewrite functions."""
@@ -176,7 +226,7 @@ class RewritingOptimizer(Optimizer):
registry = RewriteRegistry(
[
- FullyConnectedRewrite("fully-connected", fc_rewrite),
+ GenericRewrite("fully-connected", fc_rewrite),
Sparsity24Rewrite("fully-connected-sparsity24", fc_rewrite_sparsity24),
ClusteringRewrite("fully-connected-clustering", fc_clustering_rewrite),
]
@@ -200,7 +250,7 @@ class RewritingOptimizer(Optimizer):
rewrite = RewritingOptimizer.registry.items[
self.optimizer_configuration.optimization_target
]
- is_qat = isinstance(rewrite, QATRewrite)
+
use_unmodified_model = True
tflite_model = self.model.model_path
tfrecord = str(self.optimizer_configuration.dataset)
@@ -218,9 +268,9 @@ class RewritingOptimizer(Optimizer):
output_model=str(tmp_output),
input_tfrec=str(tfrecord),
rewrite=rewrite,
+ is_qat=isinstance(rewrite, QuantizeAwareTrainingRewrite),
input_tensors=[self.optimizer_configuration.layers_to_optimize[0]],
output_tensors=[self.optimizer_configuration.layers_to_optimize[1]],
- is_qat=is_qat,
train_params=self.optimizer_configuration.train_params,
)