diff options
Diffstat (limited to 'src/mlia/nn/rewrite/core/rewrite.py')
-rw-r--r-- | src/mlia/nn/rewrite/core/rewrite.py | 132 |
1 files changed, 91 insertions, 41 deletions
diff --git a/src/mlia/nn/rewrite/core/rewrite.py b/src/mlia/nn/rewrite/core/rewrite.py index 6a3695a..e2c097c 100644 --- a/src/mlia/nn/rewrite/core/rewrite.py +++ b/src/mlia/nn/rewrite/core/rewrite.py @@ -12,6 +12,7 @@ from pathlib import Path from typing import Any from typing import Callable +import numpy as np import tensorflow_model_optimization as tfmot from keras.api._v2 import keras # Temporary workaround for now: MLIA-1107 @@ -53,9 +54,9 @@ class Rewrite(ABC): except Exception as ex: raise RuntimeError(f"Rewrite '{self.name}' failed.") from ex - @abstractmethod def quantize(self, model: keras.Model) -> keras.Model: """Return a quantized model if required.""" + return model @abstractmethod def training_callbacks(self) -> list: @@ -65,60 +66,41 @@ class Rewrite(ABC): def post_process(self, model: keras.Model) -> keras.Model: """Return default post-processing rewrite options.""" + @abstractmethod + def check_optimization(self, model: keras.Model, **kwargs: dict) -> bool: + """Check if the optimization has produced the correct result.""" -class ClusteringRewrite(Rewrite): - """Graph clustering rewrite logic to be used by RewritingOptimizer.""" - strip_pruning_wrapper = staticmethod(tfmot.clustering.keras.strip_clustering) +class GenericRewrite(Rewrite): + """Graph rewrite logic for fully-connected rewrite.""" def quantize(self, model: keras.Model) -> keras.Model: - """Return a quantized model.""" - return model - - def post_process(self, model: keras.Model) -> keras.Model: - """Return the clustering stripped model.""" - return self.strip_pruning_wrapper(model) + """Return a quantized model if required.""" + return tfmot.quantization.keras.quantize_model(model) def training_callbacks(self) -> list: """Return default rewrite callbacks.""" return [] - -class QATRewrite(Rewrite): - """Logic for rewrites requiring quantization-aware training.""" - - def pruning_preserved_quantization( - self, - model: keras.Model, - ) -> keras.Model: - """Apply pruning-preserved quantization training to a given model.""" - model = tfmot.quantization.keras.quantize_annotate_model(model) - model = tfmot.quantization.keras.quantize_apply( - model, - tfmot.experimental.combine.Default8BitPrunePreserveQuantizeScheme(), - ) - + def post_process(self, model: keras.Model) -> keras.Model: + """Return default post-processing rewrite options.""" return model + def check_optimization(self, model: keras.Model, **kwargs: Any) -> bool: + """Not needed here.""" + return True -class FullyConnectedRewrite(Rewrite): - """Graph rewrite logic for fully-connected rewrite.""" - - def quantize(self, model: keras.Model) -> keras.Model: - """Return a quantized model if required.""" - model = tfmot.quantization.keras.quantize_model(model) - return model - def training_callbacks(self) -> list: - """Return default rewrite callbacks.""" - return [] +class QuantizeAwareTrainingRewrite(Rewrite, ABC): + """Abstract class for rewrites that perform QAT.""" - def post_process(self, model: keras.Model) -> keras.Model: - """Return default post-processing rewrite options.""" + @abstractmethod + def preserved_quantize(self, model: keras.Model) -> keras.Model: + """Apply optimization-aware quantization to a given model.""" return model -class Sparsity24Rewrite(QATRewrite): +class Sparsity24Rewrite(QuantizeAwareTrainingRewrite): """Graph rewrite logic for fully-connected-sparsity24 rewrite.""" pruning_callback = tfmot.sparsity.keras.UpdatePruningStep @@ -137,6 +119,74 @@ class Sparsity24Rewrite(QATRewrite): """Pruning-specific post-processing rewrite options.""" return self.strip_pruning_wrapper(model) + def preserved_quantize( + self, + model: keras.Model, + ) -> keras.Model: + """Apply pruning-preserved quantization training to a given model.""" + model = tfmot.quantization.keras.quantize_annotate_model(model) + model = tfmot.quantization.keras.quantize_apply( + model, + tfmot.experimental.combine.Default8BitPrunePreserveQuantizeScheme(), + ) + + return model + + def check_optimization(self, model: keras.Model, **kwargs: Any) -> bool: + """Not needed here.""" + return True + + +class ClusteringRewrite(QuantizeAwareTrainingRewrite): + """Graph clustering rewrite logic to be used by RewritingOptimizer.""" + + _strip_clustering_wrapper = staticmethod(tfmot.clustering.keras.strip_clustering) + + def preserved_quantize(self, model: keras.Model) -> keras.Model: + """Apply clustering-preserved quantization to a given model.""" + quant_aware_model = tfmot.quantization.keras.quantize_annotate_model(model) + cqat_model = tfmot.quantization.keras.quantize_apply( + quant_aware_model, + tfmot.experimental.combine.Default8BitClusterPreserveQuantizeScheme(), + ) + return cqat_model + + def check_optimization(self, model: keras.Model, **kwargs: Any) -> bool: + """Check if clustering has produced the correct result.""" + number_of_clusters = kwargs.get("number_of_clusters") + if not number_of_clusters: + raise ValueError( + """ + Expected check_preserved_quantize to have argument number_of_clusters. + """ + ) + + for layer in model.layers: + for weight in layer.weights: + if "kernel" in weight.name: + if "kernel_min" in weight.name or "kernel_max" in weight.name: + continue + number_of_found_clusters = len(np.unique(weight)) + if number_of_found_clusters != number_of_clusters: + logger.warning( + "\nWARNING: Expected %d cluster(s), found %d " + "cluster(s) in layer %s for weight %s \n", + number_of_clusters, + number_of_found_clusters, + layer.name, + weight.name, + ) + return False + return True + + def training_callbacks(self) -> list: + """Return default rewrite callbacks.""" + return [] + + def post_process(self, model: keras.Model) -> keras.Model: + """Return the clustering stripped model.""" + return self._strip_clustering_wrapper(model) + class RewriteRegistry(Registry[Rewrite]): """Registry rewrite functions.""" @@ -176,7 +226,7 @@ class RewritingOptimizer(Optimizer): registry = RewriteRegistry( [ - FullyConnectedRewrite("fully-connected", fc_rewrite), + GenericRewrite("fully-connected", fc_rewrite), Sparsity24Rewrite("fully-connected-sparsity24", fc_rewrite_sparsity24), ClusteringRewrite("fully-connected-clustering", fc_clustering_rewrite), ] @@ -200,7 +250,7 @@ class RewritingOptimizer(Optimizer): rewrite = RewritingOptimizer.registry.items[ self.optimizer_configuration.optimization_target ] - is_qat = isinstance(rewrite, QATRewrite) + use_unmodified_model = True tflite_model = self.model.model_path tfrecord = str(self.optimizer_configuration.dataset) @@ -218,9 +268,9 @@ class RewritingOptimizer(Optimizer): output_model=str(tmp_output), input_tfrec=str(tfrecord), rewrite=rewrite, + is_qat=isinstance(rewrite, QuantizeAwareTrainingRewrite), input_tensors=[self.optimizer_configuration.layers_to_optimize[0]], output_tensors=[self.optimizer_configuration.layers_to_optimize[1]], - is_qat=is_qat, train_params=self.optimizer_configuration.train_params, ) |