diff options
Diffstat (limited to 'src/mlia/nn/rewrite/core/rewrite.py')
-rw-r--r-- | src/mlia/nn/rewrite/core/rewrite.py | 108 |
1 files changed, 70 insertions, 38 deletions
diff --git a/src/mlia/nn/rewrite/core/rewrite.py b/src/mlia/nn/rewrite/core/rewrite.py index e2c097c..78fa533 100644 --- a/src/mlia/nn/rewrite/core/rewrite.py +++ b/src/mlia/nn/rewrite/core/rewrite.py @@ -8,6 +8,7 @@ import tempfile from abc import ABC from abc import abstractmethod from dataclasses import dataclass +from inspect import getfullargspec from pathlib import Path from typing import Any from typing import Callable @@ -15,6 +16,9 @@ from typing import Callable import numpy as np import tensorflow_model_optimization as tfmot from keras.api._v2 import keras # Temporary workaround for now: MLIA-1107 +from tensorflow_model_optimization.python.core.sparsity.keras.pruning_utils import ( # pylint: disable=no-name-in-module + is_pruned_m_by_n, +) from mlia.core.errors import ConfigurationError from mlia.core.reporting import Column @@ -24,19 +28,16 @@ from mlia.nn.common import Optimizer from mlia.nn.common import OptimizerConfiguration from mlia.nn.rewrite.core.train import train from mlia.nn.rewrite.core.train import TrainingParameters -from mlia.nn.rewrite.library.fc_clustering_layer import ( - get_keras_model_clus as fc_clustering_rewrite, -) -from mlia.nn.rewrite.library.fc_layer import get_keras_model as fc_rewrite -from mlia.nn.rewrite.library.fc_sparsity24_layer import ( - get_keras_model as fc_rewrite_sparsity24, -) +from mlia.nn.rewrite.library.clustering import conv2d_clustering_rewrite +from mlia.nn.rewrite.library.clustering import fc_clustering_rewrite +from mlia.nn.rewrite.library.fc_layer import fc_rewrite +from mlia.nn.rewrite.library.sparsity import conv2d_sparsity_rewrite +from mlia.nn.rewrite.library.sparsity import fc_sparsity_rewrite from mlia.nn.tensorflow.config import TFLiteModel from mlia.utils.registry import Registry - logger = logging.getLogger(__name__) -RewriteCallable = Callable[[Any, Any], keras.Model] +RewriteCallable = Callable[..., keras.Model] class Rewrite(ABC): @@ -47,10 +48,23 @@ class Rewrite(ABC): self.name = name self.function = rewrite_fn - def __call__(self, input_shape: Any, output_shape: Any) -> keras.Model: + def __call__( + self, input_shape: Any, output_shape: Any, **kwargs: Any + ) -> keras.Model: """Perform the rewrite operation using the configured function.""" try: - return self.function(input_shape, output_shape) + return self.function(input_shape, output_shape, **kwargs) + except TypeError as ex: + expected_args = getfullargspec(self.function).args + if "input_shape" in expected_args: + expected_args.remove("input_shape") + if "output_shape" in expected_args: + expected_args.remove("output_shape") + raise KeyError( + f"Found unexpected parameters for rewrite. Expected (sub)set " + f"of {expected_args} found unexpected parameter(s) " + f"{list(set(list(kwargs.keys())) - set(expected_args))}" + ) from ex except Exception as ex: raise RuntimeError(f"Rewrite '{self.name}' failed.") from ex @@ -60,19 +74,19 @@ class Rewrite(ABC): @abstractmethod def training_callbacks(self) -> list: - """Return default rewrite callbacks.""" + """Return rewrite callbacks.""" @abstractmethod def post_process(self, model: keras.Model) -> keras.Model: - """Return default post-processing rewrite options.""" + """Return post-processing rewrite option.""" @abstractmethod - def check_optimization(self, model: keras.Model, **kwargs: dict) -> bool: + def check_optimization(self, model: keras.Model) -> bool: """Check if the optimization has produced the correct result.""" class GenericRewrite(Rewrite): - """Graph rewrite logic for fully-connected rewrite.""" + """Rewrite class for generic rewrites e.g. fully-connected.""" def quantize(self, model: keras.Model) -> keras.Model: """Return a quantized model if required.""" @@ -83,10 +97,10 @@ class GenericRewrite(Rewrite): return [] def post_process(self, model: keras.Model) -> keras.Model: - """Return default post-processing rewrite options.""" + """Return default post-processing rewrite option.""" return model - def check_optimization(self, model: keras.Model, **kwargs: Any) -> bool: + def check_optimization(self, model: keras.Model) -> bool: """Not needed here.""" return True @@ -100,15 +114,15 @@ class QuantizeAwareTrainingRewrite(Rewrite, ABC): return model -class Sparsity24Rewrite(QuantizeAwareTrainingRewrite): - """Graph rewrite logic for fully-connected-sparsity24 rewrite.""" +class SparsityRewrite(QuantizeAwareTrainingRewrite): + """Rewrite class for sparsity rewrite e.g. fully-connected-sparsity.""" pruning_callback = tfmot.sparsity.keras.UpdatePruningStep strip_pruning_wrapper = staticmethod(tfmot.sparsity.keras.strip_pruning) def quantize(self, model: keras.Model) -> keras.Model: - """Skip quantization when using pruning rewrite.""" + """Skip quantization when using sparsity rewrite.""" return model def training_callbacks(self) -> list: @@ -116,7 +130,7 @@ class Sparsity24Rewrite(QuantizeAwareTrainingRewrite): return [self.pruning_callback()] def post_process(self, model: keras.Model) -> keras.Model: - """Pruning-specific post-processing rewrite options.""" + """Pruning-specific post-processing rewrite option.""" return self.strip_pruning_wrapper(model) def preserved_quantize( @@ -132,13 +146,34 @@ class Sparsity24Rewrite(QuantizeAwareTrainingRewrite): return model - def check_optimization(self, model: keras.Model, **kwargs: Any) -> bool: - """Not needed here.""" + def check_optimization( + self, + model: keras.Model, + sparsity_m: int = 2, + sparsity_n: int = 4, + **_: Any, + ) -> bool: + """Check if sparity has produced the correct result.""" + for layer in model.layers: + for weight in layer.weights: + if "kernel" in weight.name: + if "kernel_min" in weight.name or "kernel_max" in weight.name: + continue + if not is_pruned_m_by_n(weight, m_by_n=(sparsity_m, sparsity_n)): + logger.warning( + "\nWARNING: Could not find (%d, %d) sparsity, " + "in layer %s for weight %s \n", + sparsity_m, + sparsity_n, + layer.name, + weight.name, + ) + return False return True class ClusteringRewrite(QuantizeAwareTrainingRewrite): - """Graph clustering rewrite logic to be used by RewritingOptimizer.""" + """Rewrite class for clustering rewrite e.g. fully-connected-clustering.""" _strip_clustering_wrapper = staticmethod(tfmot.clustering.keras.strip_clustering) @@ -151,27 +186,21 @@ class ClusteringRewrite(QuantizeAwareTrainingRewrite): ) return cqat_model - def check_optimization(self, model: keras.Model, **kwargs: Any) -> bool: + def check_optimization( + self, model: keras.Model, num_clusters: int = 2, **_: Any + ) -> bool: """Check if clustering has produced the correct result.""" - number_of_clusters = kwargs.get("number_of_clusters") - if not number_of_clusters: - raise ValueError( - """ - Expected check_preserved_quantize to have argument number_of_clusters. - """ - ) - for layer in model.layers: for weight in layer.weights: if "kernel" in weight.name: if "kernel_min" in weight.name or "kernel_max" in weight.name: continue number_of_found_clusters = len(np.unique(weight)) - if number_of_found_clusters != number_of_clusters: + if number_of_found_clusters != num_clusters: logger.warning( "\nWARNING: Expected %d cluster(s), found %d " "cluster(s) in layer %s for weight %s \n", - number_of_clusters, + num_clusters, number_of_found_clusters, layer.name, weight.name, @@ -184,7 +213,7 @@ class ClusteringRewrite(QuantizeAwareTrainingRewrite): return [] def post_process(self, model: keras.Model) -> keras.Model: - """Return the clustering stripped model.""" + """Clustering-specific post-processing rewrite option.""" return self._strip_clustering_wrapper(model) @@ -215,6 +244,7 @@ class RewriteConfiguration(OptimizerConfiguration): layers_to_optimize: list[str] | None = None dataset: Path | None = None train_params: TrainingParameters = TrainingParameters() + rewrite_specific_params: dict | None = None def __str__(self) -> str: """Return string representation of the configuration.""" @@ -227,8 +257,10 @@ class RewritingOptimizer(Optimizer): registry = RewriteRegistry( [ GenericRewrite("fully-connected", fc_rewrite), - Sparsity24Rewrite("fully-connected-sparsity24", fc_rewrite_sparsity24), + SparsityRewrite("fully-connected-sparsity", fc_sparsity_rewrite), ClusteringRewrite("fully-connected-clustering", fc_clustering_rewrite), + ClusteringRewrite("conv2d-clustering", conv2d_clustering_rewrite), + SparsityRewrite("conv2d-sparsity", conv2d_sparsity_rewrite), ] ) @@ -250,7 +282,6 @@ class RewritingOptimizer(Optimizer): rewrite = RewritingOptimizer.registry.items[ self.optimizer_configuration.optimization_target ] - use_unmodified_model = True tflite_model = self.model.model_path tfrecord = str(self.optimizer_configuration.dataset) @@ -272,6 +303,7 @@ class RewritingOptimizer(Optimizer): input_tensors=[self.optimizer_configuration.layers_to_optimize[0]], output_tensors=[self.optimizer_configuration.layers_to_optimize[1]], train_params=self.optimizer_configuration.train_params, + rewrite_specific_params=self.optimizer_configuration.rewrite_specific_params, # pylint: disable=line-too-long ) if orig_vs_repl_stats: |