From b90fd38588cfd7fc0b94786fde196ad5a27965d3 Mon Sep 17 00:00:00 2001 From: Nathan Bailey Date: Thu, 7 Mar 2024 14:27:54 +0000 Subject: feat: Enable rewrite parameterisation for specific rewrites Adds support for rewrite-specific parameters Resolves: MLIA-1114 Signed-off-by: Nathan Bailey Change-Id: I290c326af3356033a916a43b28027819c876c3dd --- src/mlia/nn/rewrite/core/rewrite.py | 67 +++++++++++++++++++------------ src/mlia/nn/rewrite/core/train.py | 34 +++++++++++----- src/mlia/nn/rewrite/library/clustering.py | 26 +++++++++--- src/mlia/nn/rewrite/library/sparsity.py | 18 +++++++-- src/mlia/nn/select.py | 24 ++++++----- 5 files changed, 112 insertions(+), 57 deletions(-) (limited to 'src/mlia/nn') diff --git a/src/mlia/nn/rewrite/core/rewrite.py b/src/mlia/nn/rewrite/core/rewrite.py index 6d915c6..78fa533 100644 --- a/src/mlia/nn/rewrite/core/rewrite.py +++ b/src/mlia/nn/rewrite/core/rewrite.py @@ -8,6 +8,7 @@ import tempfile from abc import ABC from abc import abstractmethod from dataclasses import dataclass +from inspect import getfullargspec from pathlib import Path from typing import Any from typing import Callable @@ -36,7 +37,7 @@ from mlia.nn.tensorflow.config import TFLiteModel from mlia.utils.registry import Registry logger = logging.getLogger(__name__) -RewriteCallable = Callable[[Any, Any], keras.Model] +RewriteCallable = Callable[..., keras.Model] class Rewrite(ABC): @@ -47,10 +48,23 @@ class Rewrite(ABC): self.name = name self.function = rewrite_fn - def __call__(self, input_shape: Any, output_shape: Any) -> keras.Model: - """Return an instance of the rewrite model.""" + def __call__( + self, input_shape: Any, output_shape: Any, **kwargs: Any + ) -> keras.Model: + """Perform the rewrite operation using the configured function.""" try: - return self.function(input_shape, output_shape) + return self.function(input_shape, output_shape, **kwargs) + except TypeError as ex: + expected_args = getfullargspec(self.function).args + if "input_shape" in expected_args: + expected_args.remove("input_shape") + if "output_shape" in expected_args: + expected_args.remove("output_shape") + raise KeyError( + f"Found unexpected parameters for rewrite. Expected (sub)set " + f"of {expected_args} found unexpected parameter(s) " + f"{list(set(list(kwargs.keys())) - set(expected_args))}" + ) from ex except Exception as ex: raise RuntimeError(f"Rewrite '{self.name}' failed.") from ex @@ -67,7 +81,7 @@ class Rewrite(ABC): """Return post-processing rewrite option.""" @abstractmethod - def check_optimization(self, model: keras.Model, **kwargs: dict) -> bool: + def check_optimization(self, model: keras.Model) -> bool: """Check if the optimization has produced the correct result.""" @@ -86,7 +100,7 @@ class GenericRewrite(Rewrite): """Return default post-processing rewrite option.""" return model - def check_optimization(self, model: keras.Model, **kwargs: Any) -> bool: + def check_optimization(self, model: keras.Model) -> bool: """Not needed here.""" return True @@ -100,8 +114,8 @@ class QuantizeAwareTrainingRewrite(Rewrite, ABC): return model -class Sparsity24Rewrite(QuantizeAwareTrainingRewrite): - """Rewrite class for sparsity rewrite e.g. fully-connected-sparsity24.""" +class SparsityRewrite(QuantizeAwareTrainingRewrite): + """Rewrite class for sparsity rewrite e.g. fully-connected-sparsity.""" pruning_callback = tfmot.sparsity.keras.UpdatePruningStep @@ -132,17 +146,25 @@ class Sparsity24Rewrite(QuantizeAwareTrainingRewrite): return model - def check_optimization(self, model: keras.Model, **kwargs: Any) -> bool: + def check_optimization( + self, + model: keras.Model, + sparsity_m: int = 2, + sparsity_n: int = 4, + **_: Any, + ) -> bool: """Check if sparity has produced the correct result.""" for layer in model.layers: for weight in layer.weights: if "kernel" in weight.name: if "kernel_min" in weight.name or "kernel_max" in weight.name: continue - if not is_pruned_m_by_n(weight, m_by_n=(2, 4)): + if not is_pruned_m_by_n(weight, m_by_n=(sparsity_m, sparsity_n)): logger.warning( - "\nWARNING: Could not find (2,4) sparsity, " + "\nWARNING: Could not find (%d, %d) sparsity, " "in layer %s for weight %s \n", + sparsity_m, + sparsity_n, layer.name, weight.name, ) @@ -164,27 +186,21 @@ class ClusteringRewrite(QuantizeAwareTrainingRewrite): ) return cqat_model - def check_optimization(self, model: keras.Model, **kwargs: Any) -> bool: + def check_optimization( + self, model: keras.Model, num_clusters: int = 2, **_: Any + ) -> bool: """Check if clustering has produced the correct result.""" - number_of_clusters = kwargs.get("number_of_clusters") - if not number_of_clusters: - raise ValueError( - """ - Expected check_optimization to have argument number_of_clusters. - """ - ) - for layer in model.layers: for weight in layer.weights: if "kernel" in weight.name: if "kernel_min" in weight.name or "kernel_max" in weight.name: continue number_of_found_clusters = len(np.unique(weight)) - if number_of_found_clusters != number_of_clusters: + if number_of_found_clusters != num_clusters: logger.warning( "\nWARNING: Expected %d cluster(s), found %d " "cluster(s) in layer %s for weight %s \n", - number_of_clusters, + num_clusters, number_of_found_clusters, layer.name, weight.name, @@ -228,6 +244,7 @@ class RewriteConfiguration(OptimizerConfiguration): layers_to_optimize: list[str] | None = None dataset: Path | None = None train_params: TrainingParameters = TrainingParameters() + rewrite_specific_params: dict | None = None def __str__(self) -> str: """Return string representation of the configuration.""" @@ -240,10 +257,10 @@ class RewritingOptimizer(Optimizer): registry = RewriteRegistry( [ GenericRewrite("fully-connected", fc_rewrite), - Sparsity24Rewrite("fully-connected-sparsity24", fc_sparsity_rewrite), + SparsityRewrite("fully-connected-sparsity", fc_sparsity_rewrite), ClusteringRewrite("fully-connected-clustering", fc_clustering_rewrite), ClusteringRewrite("conv2d-clustering", conv2d_clustering_rewrite), - Sparsity24Rewrite("conv2d-sparsity24", conv2d_sparsity_rewrite), + SparsityRewrite("conv2d-sparsity", conv2d_sparsity_rewrite), ] ) @@ -265,7 +282,6 @@ class RewritingOptimizer(Optimizer): rewrite = RewritingOptimizer.registry.items[ self.optimizer_configuration.optimization_target ] - use_unmodified_model = True tflite_model = self.model.model_path tfrecord = str(self.optimizer_configuration.dataset) @@ -287,6 +303,7 @@ class RewritingOptimizer(Optimizer): input_tensors=[self.optimizer_configuration.layers_to_optimize[0]], output_tensors=[self.optimizer_configuration.layers_to_optimize[1]], train_params=self.optimizer_configuration.train_params, + rewrite_specific_params=self.optimizer_configuration.rewrite_specific_params, # pylint: disable=line-too-long ) if orig_vs_repl_stats: diff --git a/src/mlia/nn/rewrite/core/train.py b/src/mlia/nn/rewrite/core/train.py index 4204978..e99c7e9 100644 --- a/src/mlia/nn/rewrite/core/train.py +++ b/src/mlia/nn/rewrite/core/train.py @@ -83,6 +83,7 @@ def train( # pylint: disable=too-many-arguments input_tensors: list, output_tensors: list, train_params: TrainingParameters = TrainingParameters(), + rewrite_specific_params: dict | None = None, ) -> Any: """Extract and train a model, and return the results.""" if unmodified_model: @@ -122,6 +123,7 @@ def train( # pylint: disable=too-many-arguments rewrite=rewrite, is_qat=is_qat, train_params=train_params, + rewrite_specific_params=rewrite_specific_params, ) for i, filename in enumerate(tflite_filenames): @@ -356,6 +358,7 @@ def train_in_dir( rewrite: Callable, is_qat: bool, train_params: TrainingParameters = TrainingParameters(), + rewrite_specific_params: dict | None = None, ) -> list[str]: """Train a replacement for replace.tflite using the input.tfrec \ and output.tfrec in train_dir. @@ -396,7 +399,13 @@ def train_in_dir( loss_fn = keras.losses.MeanSquaredError() model = create_model( - rewrite, input_shape, output_shape, optimizer, loss_fn, model_is_quantized + rewrite, + input_shape, + output_shape, + optimizer, + loss_fn, + model_is_quantized, + rewrite_specific_params=rewrite_specific_params, ) logger.info(model.summary()) @@ -462,11 +471,9 @@ def train_in_dir( steps_per_epoch, post_process=True, ) - - # Placeholder for now, will be parametrized later (MLIA-1114) - # rewrite.check_optimization( # type: ignore[attr-defined] - # model, number_of_clusters=32 - # ) + rewrite.check_optimization( # type: ignore[attr-defined] + model, **rewrite_specific_params if rewrite_specific_params else {} + ) if model_is_quantized and is_qat: model = rewrite.preserved_quantize(model) # type: ignore[attr-defined] checkpoints = ( @@ -501,11 +508,10 @@ def train_in_dir( loss_fn, steps_per_epoch, ) - # Placeholder for now, will be parametrized later (MLIA-1114) - # rewrite.check_optimization( # type: ignore[attr-defined] - # model, number_of_clusters=32 - # ) + rewrite.check_optimization( # type: ignore[attr-defined] + model, **rewrite_specific_params if rewrite_specific_params else {} + ) teacher.close() return output_filenames @@ -528,9 +534,13 @@ def create_model( # pylint: disable=too-many-arguments loss_fn: Callable, model_is_quantized: bool, model_to_load_from: keras.model | None = None, + rewrite_specific_params: dict | None = None, ) -> keras.Model: """Create a model, optionally from another.""" - model = rewrite(input_shape, output_shape) + if rewrite_specific_params: + model = rewrite(input_shape, output_shape, **rewrite_specific_params) + else: + model = rewrite(input_shape, output_shape) if model_is_quantized: model = rewrite.quantize(model) # type: ignore[attr-defined] model = model_compile(model, optimizer=optimizer, loss_fn=loss_fn) @@ -558,6 +568,7 @@ def model_fit( # pylint: disable=too-many-arguments loss_fn: Callable, steps_per_epoch: int, post_process: bool = False, + rewrite_specific_params: dict | None = None, ) -> keras.Model: """Train a tflite model.""" steps_so_far = 0 @@ -597,6 +608,7 @@ def model_fit( # pylint: disable=too-many-arguments loss_fn, model_is_quantized, model_to_load_from=model, + rewrite_specific_params=rewrite_specific_params, ) else: model_to_save = model diff --git a/src/mlia/nn/rewrite/library/clustering.py b/src/mlia/nn/rewrite/library/clustering.py index 81bfd90..a81d2d4 100644 --- a/src/mlia/nn/rewrite/library/clustering.py +++ b/src/mlia/nn/rewrite/library/clustering.py @@ -9,11 +9,18 @@ from keras.api._v2 import keras # Temporary workaround for now: MLIA-1107 from mlia.nn.rewrite.library.helper_functions import compute_conv2d_parameters -def fc_clustering_rewrite(input_shape: Any, output_shape: Any) -> keras.Model: +def fc_clustering_rewrite( + input_shape: Any, + output_shape: Any, + num_clusters: int = 2, + cluster_centroids_init: tfmot.clustering.keras.CentroidInitialization = tfmot.clustering.keras.CentroidInitialization( # pylint: disable=line-too-long + "CentroidInitialization.LINEAR" + ), +) -> keras.Model: """Fully connected TensorFlow Lite model ready for clustering.""" rewrite_params = { - "number_of_clusters": 4, - "cluster_centroids_init": tfmot.clustering.keras.CentroidInitialization.LINEAR, + "number_of_clusters": num_clusters, + "cluster_centroids_init": cluster_centroids_init, } model = tfmot.clustering.keras.cluster_weights( to_cluster=keras.Sequential( @@ -28,11 +35,18 @@ def fc_clustering_rewrite(input_shape: Any, output_shape: Any) -> keras.Model: return model -def conv2d_clustering_rewrite(input_shape: Any, output_shape: Any) -> keras.Model: +def conv2d_clustering_rewrite( + input_shape: Any, + output_shape: Any, + num_clusters: int = 2, + cluster_centroids_init: tfmot.clustering.keras.CentroidInitialization = tfmot.clustering.keras.CentroidInitialization( # pylint: disable=line-too-long + "CentroidInitialization.LINEAR" + ), +) -> keras.Model: """Conv2d TensorFlow Lite model ready for clustering.""" rewrite_params = { - "number_of_clusters": 4, - "cluster_centroids_init": tfmot.clustering.keras.CentroidInitialization.LINEAR, + "number_of_clusters": num_clusters, + "cluster_centroids_init": cluster_centroids_init, } conv2d_parameters = compute_conv2d_parameters( input_shape=input_shape, output_shape=output_shape diff --git a/src/mlia/nn/rewrite/library/sparsity.py b/src/mlia/nn/rewrite/library/sparsity.py index 745fa8b..2342e3d 100644 --- a/src/mlia/nn/rewrite/library/sparsity.py +++ b/src/mlia/nn/rewrite/library/sparsity.py @@ -9,7 +9,9 @@ from keras.api._v2 import keras # Temporary workaround for now: MLIA-1107 from mlia.nn.rewrite.library.helper_functions import compute_conv2d_parameters -def fc_sparsity_rewrite(input_shape: Any, output_shape: Any) -> keras.Model: +def fc_sparsity_rewrite( + input_shape: Any, output_shape: Any, sparsity_m: int = 2, sparsity_n: int = 4 +) -> keras.Model: """Fully connected TensorFlow Lite model ready for sparse pruning.""" model = tfmot.sparsity.keras.prune_low_magnitude( to_prune=keras.Sequential( @@ -19,13 +21,18 @@ def fc_sparsity_rewrite(input_shape: Any, output_shape: Any) -> keras.Model: keras.layers.Dense(output_shape), ] ), - sparsity_m_by_n=(2, 4), + sparsity_m_by_n=( + sparsity_m, + sparsity_n, + ), ) return model -def conv2d_sparsity_rewrite(input_shape: Any, output_shape: Any) -> keras.Model: +def conv2d_sparsity_rewrite( + input_shape: Any, output_shape: Any, sparsity_m: int = 2, sparsity_n: int = 4 +) -> keras.Model: """Conv2d TensorFlow Lite model ready for sparse pruning.""" conv2d_parameters = compute_conv2d_parameters( input_shape=input_shape, output_shape=output_shape @@ -39,7 +46,10 @@ def conv2d_sparsity_rewrite(input_shape: Any, output_shape: Any) -> keras.Model: keras.layers.ReLU(), ] ), - sparsity_m_by_n=(2, 4), + sparsity_m_by_n=( + sparsity_m, + sparsity_n, + ), ) return model diff --git a/src/mlia/nn/select.py b/src/mlia/nn/select.py index b61e713..d5470d1 100644 --- a/src/mlia/nn/select.py +++ b/src/mlia/nn/select.py @@ -17,7 +17,7 @@ from mlia.nn.common import Optimizer from mlia.nn.common import OptimizerConfiguration from mlia.nn.rewrite.core.rewrite import RewriteConfiguration from mlia.nn.rewrite.core.rewrite import RewritingOptimizer -from mlia.nn.rewrite.core.rewrite import TrainingParameters +from mlia.nn.rewrite.core.train import TrainingParameters from mlia.nn.tensorflow.config import KerasModel from mlia.nn.tensorflow.config import TFLiteModel from mlia.nn.tensorflow.optimizations.clustering import Clusterer @@ -109,7 +109,7 @@ class MultiStageOptimizer(Optimizer): def apply_optimization(self) -> None: """Apply optimization to the model.""" for config in self.optimizations: - optimizer = get_optimizer(self.model, config) + optimizer = get_optimizer(self.model, config, {}) optimizer.apply_optimization() self.model = optimizer.get_model() @@ -117,7 +117,7 @@ class MultiStageOptimizer(Optimizer): def get_optimizer( model: keras.Model | KerasModel | TFLiteModel, config: OptimizerConfiguration | OptimizationSettings | list[OptimizationSettings], - training_parameters: dict | None = None, + rewrite_parameters: dict, ) -> Optimizer: """Get optimizer for provided configuration.""" if isinstance(model, KerasModel): @@ -137,12 +137,12 @@ def get_optimizer( if isinstance(config, OptimizationSettings): return _get_optimizer( - model, cast(OptimizationSettings, config), training_parameters + model, cast(OptimizationSettings, config), rewrite_parameters ) if is_list_of(config, OptimizationSettings): return _get_optimizer( - model, cast(List[OptimizationSettings], config), training_parameters + model, cast(List[OptimizationSettings], config), rewrite_parameters ) raise ConfigurationError(f"Unknown optimization configuration {config}") @@ -151,7 +151,7 @@ def get_optimizer( def _get_optimizer( model: keras.Model | Path, optimization_settings: OptimizationSettings | list[OptimizationSettings], - training_parameters: dict | None = None, + rewrite_parameters: dict, ) -> Optimizer: if isinstance(optimization_settings, OptimizationSettings): optimization_settings = [optimization_settings] @@ -162,12 +162,12 @@ def _get_optimizer( _check_optimizer_params(opt_type, opt_target) opt_config = _get_optimizer_configuration( - opt_type, opt_target, layers_to_optimize, dataset, training_parameters + opt_type, opt_target, rewrite_parameters, layers_to_optimize, dataset ) optimizer_configs.append(opt_config) if len(optimizer_configs) == 1: - return get_optimizer(model, optimizer_configs[0]) + return get_optimizer(model, optimizer_configs[0], {}) return MultiStageOptimizer(model, optimizer_configs) @@ -189,9 +189,9 @@ def _get_rewrite_params( def _get_optimizer_configuration( optimization_type: str, optimization_target: int | float | str, + rewrite_parameters: dict, layers_to_optimize: list[str] | None = None, dataset: Path | None = None, - training_parameters: dict | None = None, ) -> OptimizerConfiguration: """Get optimizer configuration for provided parameters.""" _check_optimizer_params(optimization_type, optimization_target) @@ -212,12 +212,14 @@ def _get_optimizer_configuration( if opt_type == "rewrite": if isinstance(optimization_target, str): - rewrite_params = _get_rewrite_params(training_parameters) return RewriteConfiguration( optimization_target=str(optimization_target), layers_to_optimize=layers_to_optimize, dataset=dataset, - train_params=rewrite_params, + train_params=_get_rewrite_params(rewrite_parameters["train_params"]), + rewrite_specific_params=rewrite_parameters.get( + "rewrite_specific_params" + ), ) raise ConfigurationError( -- cgit v1.2.1