diff options
author | Nathan Bailey <nathan.bailey@arm.com> | 2024-03-08 14:08:06 +0000 |
---|---|---|
committer | Nathan Bailey <nathan.bailey@arm.com> | 2024-04-16 13:11:31 +0100 |
commit | 32405c279d2f98c2d40bdbbb7f7306ff12c86cd6 (patch) | |
tree | 42781ca219b822a9ec9f212a9ee516f65b184a27 /src | |
parent | 427e02696f1ede596ef6dce82787a37e122efa78 (diff) | |
download | mlia-32405c279d2f98c2d40bdbbb7f7306ff12c86cd6.tar.gz |
feat: Implement the clustering rewrite for int8
Implements a clustering rewrite for fully connected layers for int8 models
Resolves: MLIA-1080
Signed-off-by: Nathan Bailey <nathan.bailey@arm.com>
Change-Id: If48efb22764187a382e5b84bbb5c3b75a6e71b75
Diffstat (limited to 'src')
-rw-r--r-- | src/mlia/nn/rewrite/core/rewrite.py | 132
-rw-r--r-- | src/mlia/nn/rewrite/core/train.py | 118
-rw-r--r-- | src/mlia/nn/rewrite/library/fc_clustering_layer.py | 4
3 files changed, 182 insertions, 72 deletions
diff --git a/src/mlia/nn/rewrite/core/rewrite.py b/src/mlia/nn/rewrite/core/rewrite.py index 6a3695a..e2c097c 100644 --- a/src/mlia/nn/rewrite/core/rewrite.py +++ b/src/mlia/nn/rewrite/core/rewrite.py @@ -12,6 +12,7 @@ from pathlib import Path from typing import Any from typing import Callable +import numpy as np import tensorflow_model_optimization as tfmot from keras.api._v2 import keras # Temporary workaround for now: MLIA-1107 @@ -53,9 +54,9 @@ class Rewrite(ABC): except Exception as ex: raise RuntimeError(f"Rewrite '{self.name}' failed.") from ex - @abstractmethod def quantize(self, model: keras.Model) -> keras.Model: """Return a quantized model if required.""" + return model @abstractmethod def training_callbacks(self) -> list: @@ -65,60 +66,41 @@ class Rewrite(ABC): def post_process(self, model: keras.Model) -> keras.Model: """Return default post-processing rewrite options.""" + @abstractmethod + def check_optimization(self, model: keras.Model, **kwargs: dict) -> bool: + """Check if the optimization has produced the correct result.""" -class ClusteringRewrite(Rewrite): - """Graph clustering rewrite logic to be used by RewritingOptimizer.""" - strip_pruning_wrapper = staticmethod(tfmot.clustering.keras.strip_clustering) +class GenericRewrite(Rewrite): + """Graph rewrite logic for fully-connected rewrite.""" def quantize(self, model: keras.Model) -> keras.Model: - """Return a quantized model.""" - return model - - def post_process(self, model: keras.Model) -> keras.Model: - """Return the clustering stripped model.""" - return self.strip_pruning_wrapper(model) + """Return a quantized model if required.""" + return tfmot.quantization.keras.quantize_model(model) def training_callbacks(self) -> list: """Return default rewrite callbacks.""" return [] - -class QATRewrite(Rewrite): - """Logic for rewrites requiring quantization-aware training.""" - - def pruning_preserved_quantization( - self, - model: keras.Model, - ) -> keras.Model: - """Apply pruning-preserved 
quantization training to a given model.""" - model = tfmot.quantization.keras.quantize_annotate_model(model) - model = tfmot.quantization.keras.quantize_apply( - model, - tfmot.experimental.combine.Default8BitPrunePreserveQuantizeScheme(), - ) - + def post_process(self, model: keras.Model) -> keras.Model: + """Return default post-processing rewrite options.""" return model + def check_optimization(self, model: keras.Model, **kwargs: Any) -> bool: + """Not needed here.""" + return True -class FullyConnectedRewrite(Rewrite): - """Graph rewrite logic for fully-connected rewrite.""" - - def quantize(self, model: keras.Model) -> keras.Model: - """Return a quantized model if required.""" - model = tfmot.quantization.keras.quantize_model(model) - return model - def training_callbacks(self) -> list: - """Return default rewrite callbacks.""" - return [] +class QuantizeAwareTrainingRewrite(Rewrite, ABC): + """Abstract class for rewrites that perform QAT.""" - def post_process(self, model: keras.Model) -> keras.Model: - """Return default post-processing rewrite options.""" + @abstractmethod + def preserved_quantize(self, model: keras.Model) -> keras.Model: + """Apply optimization-aware quantization to a given model.""" return model -class Sparsity24Rewrite(QATRewrite): +class Sparsity24Rewrite(QuantizeAwareTrainingRewrite): """Graph rewrite logic for fully-connected-sparsity24 rewrite.""" pruning_callback = tfmot.sparsity.keras.UpdatePruningStep @@ -137,6 +119,74 @@ class Sparsity24Rewrite(QATRewrite): """Pruning-specific post-processing rewrite options.""" return self.strip_pruning_wrapper(model) + def preserved_quantize( + self, + model: keras.Model, + ) -> keras.Model: + """Apply pruning-preserved quantization training to a given model.""" + model = tfmot.quantization.keras.quantize_annotate_model(model) + model = tfmot.quantization.keras.quantize_apply( + model, + tfmot.experimental.combine.Default8BitPrunePreserveQuantizeScheme(), + ) + + return model + + def 
check_optimization(self, model: keras.Model, **kwargs: Any) -> bool: + """Not needed here.""" + return True + + +class ClusteringRewrite(QuantizeAwareTrainingRewrite): + """Graph clustering rewrite logic to be used by RewritingOptimizer.""" + + _strip_clustering_wrapper = staticmethod(tfmot.clustering.keras.strip_clustering) + + def preserved_quantize(self, model: keras.Model) -> keras.Model: + """Apply clustering-preserved quantization to a given model.""" + quant_aware_model = tfmot.quantization.keras.quantize_annotate_model(model) + cqat_model = tfmot.quantization.keras.quantize_apply( + quant_aware_model, + tfmot.experimental.combine.Default8BitClusterPreserveQuantizeScheme(), + ) + return cqat_model + + def check_optimization(self, model: keras.Model, **kwargs: Any) -> bool: + """Check if clustering has produced the correct result.""" + number_of_clusters = kwargs.get("number_of_clusters") + if not number_of_clusters: + raise ValueError( + """ + Expected check_preserved_quantize to have argument number_of_clusters. 
+ """ + ) + + for layer in model.layers: + for weight in layer.weights: + if "kernel" in weight.name: + if "kernel_min" in weight.name or "kernel_max" in weight.name: + continue + number_of_found_clusters = len(np.unique(weight)) + if number_of_found_clusters != number_of_clusters: + logger.warning( + "\nWARNING: Expected %d cluster(s), found %d " + "cluster(s) in layer %s for weight %s \n", + number_of_clusters, + number_of_found_clusters, + layer.name, + weight.name, + ) + return False + return True + + def training_callbacks(self) -> list: + """Return default rewrite callbacks.""" + return [] + + def post_process(self, model: keras.Model) -> keras.Model: + """Return the clustering stripped model.""" + return self._strip_clustering_wrapper(model) + class RewriteRegistry(Registry[Rewrite]): """Registry rewrite functions.""" @@ -176,7 +226,7 @@ class RewritingOptimizer(Optimizer): registry = RewriteRegistry( [ - FullyConnectedRewrite("fully-connected", fc_rewrite), + GenericRewrite("fully-connected", fc_rewrite), Sparsity24Rewrite("fully-connected-sparsity24", fc_rewrite_sparsity24), ClusteringRewrite("fully-connected-clustering", fc_clustering_rewrite), ] @@ -200,7 +250,7 @@ class RewritingOptimizer(Optimizer): rewrite = RewritingOptimizer.registry.items[ self.optimizer_configuration.optimization_target ] - is_qat = isinstance(rewrite, QATRewrite) + use_unmodified_model = True tflite_model = self.model.model_path tfrecord = str(self.optimizer_configuration.dataset) @@ -218,9 +268,9 @@ class RewritingOptimizer(Optimizer): output_model=str(tmp_output), input_tfrec=str(tfrecord), rewrite=rewrite, + is_qat=isinstance(rewrite, QuantizeAwareTrainingRewrite), input_tensors=[self.optimizer_configuration.layers_to_optimize[0]], output_tensors=[self.optimizer_configuration.layers_to_optimize[1]], - is_qat=is_qat, train_params=self.optimizer_configuration.train_params, ) diff --git a/src/mlia/nn/rewrite/core/train.py b/src/mlia/nn/rewrite/core/train.py index 4b9821c..88efa23 
100644 --- a/src/mlia/nn/rewrite/core/train.py +++ b/src/mlia/nn/rewrite/core/train.py @@ -73,15 +73,15 @@ class TrainingParameters: checkpoint_at: list | None = None -def train( +def train( # pylint: disable=too-many-arguments source_model: str, unmodified_model: Any, output_model: str, input_tfrec: str, rewrite: Callable, + is_qat: bool, input_tensors: list, output_tensors: list, - is_qat: bool, train_params: TrainingParameters = TrainingParameters(), ) -> Any: """Extract and train a model, and return the results.""" @@ -383,15 +383,15 @@ def train_in_dir( ) input_shape = teacher.shape_from_name[input_name][1:] + output_shape = teacher.shape_from_name[output_name][1:] - model = rewrite(input_shape, output_shape) optimizer = keras.optimizers.Nadam(learning_rate=train_params.learning_rate) loss_fn = keras.losses.MeanSquaredError() - if model_is_quantized: - model = rewrite.quantize(model) # type: ignore[attr-defined] - model = model_compile(model, optimizer, loss_fn) + model = create_model( + rewrite, input_shape, output_shape, optimizer, loss_fn, model_is_quantized + ) logger.info(model.summary()) @@ -432,16 +432,14 @@ def train_in_dir( callbacks = [] callbacks.extend(rewrite.training_callbacks()) # type: ignore[attr-defined] - - output_filenames = [] # type: list[str] + output_filenames: list = [] checkpoints = (train_params.checkpoint_at if train_params.checkpoint_at else []) + [ train_params.steps ] - model, output_filenames = model_fit( model, train_params, - checkpoints.copy(), + checkpoints, optimizer, dataset, callbacks, @@ -452,22 +450,35 @@ def train_in_dir( output_name, model_is_quantized, output_filenames, + input_shape, + output_shape, + loss_fn, + post_process=True, ) + # Placeholder for now, will be parametrized later (MLIA-1114) + # rewrite.check_optimization( # type: ignore[attr-defined] + # model, number_of_clusters=32 + # ) if model_is_quantized and is_qat: - model = rewrite.pruning_preserved_quantization( # type: ignore[attr-defined] - model, - 
) + model = rewrite.preserved_quantize(model) # type: ignore[attr-defined] + checkpoints = ( + train_params.checkpoint_at if train_params.checkpoint_at else [] + ) + [train_params.steps] + output_filenames = [] + + if len(rewrite.training_callbacks()) > 0 and set( # type: ignore[attr-defined] + rewrite.training_callbacks() # type: ignore[attr-defined] + ).issubset(callbacks): + callbacks.pop(-1) + optimizer = keras.optimizers.Nadam(learning_rate=train_params.learning_rate) model = model_compile(model, optimizer, loss_fn) - callbacks.pop(-1) - output_filenames = [] - model, output_filenames = model_fit( model, train_params, - checkpoints.copy(), + checkpoints, optimizer, dataset, callbacks, @@ -478,22 +489,50 @@ def train_in_dir( output_name, model_is_quantized, output_filenames, + input_shape, + output_shape, + loss_fn, ) + # Placeholder for now, will be parametrized later (MLIA-1114) + # rewrite.check_optimization( # type: ignore[attr-defined] + # model, number_of_clusters=32 + # ) teacher.close() return output_filenames def model_compile( - model: tf.keras.Model, optimizer: tf.keras.optimizers, loss_fn: tf.keras.losses -) -> tf.keras.Model: + model: keras.Model, + optimizer: keras.optimizers.Nadam, + loss_fn: keras.losses.Loss, +) -> keras.Model: """Compiles a tflite model.""" model.compile(optimizer=optimizer, loss=loss_fn, metrics=["mae"]) return model -def model_fit( - model: tf.keras.Model, +def create_model( # pylint: disable=too-many-arguments + rewrite: Callable, + input_shape: int, + output_shape: int, + optimizer: Callable, + loss_fn: Callable, + model_is_quantized: bool, + model_to_load_from: keras.model | None = None, +) -> keras.Model: + """Create a model, optionally from another.""" + model = rewrite(input_shape, output_shape) + if model_is_quantized: + model = rewrite.quantize(model) # type: ignore[attr-defined] + model = model_compile(model, optimizer=optimizer, loss_fn=loss_fn) + if model_to_load_from: + 
model.set_weights(model_to_load_from.get_weights()) + return model + + +def model_fit( # pylint: disable=too-many-arguments + model: keras.Model, train_params: TrainingParameters, checkpoints: list, optimizer: tf.optimizers.Nadam, @@ -506,8 +545,12 @@ def model_fit( output_name: str, model_is_quantized: bool, output_filenames: list, -) -> tuple[tf.keras.Model, list]: - """Train the model.""" + input_shape: int, + output_shape: int, + loss_fn: Callable, + post_process: bool = False, +) -> keras.Model: + """Train a tflite model.""" steps_so_far = 0 while steps_so_far < train_params.steps: steps_to_train = checkpoints.pop(0) - steps_so_far @@ -534,18 +577,34 @@ def model_fit( checkpoint_filename = ( filename_dir + "/" + filename + (f"_@{steps_so_far}") + ext ) + # If post processing we are stripping the clustering/pruning layers below + # Thus copy the model before saving, so training can continue + if post_process: + model_to_save = create_model( + rewrite, + input_shape, + output_shape, + optimizer, + loss_fn, + model_is_quantized, + model_to_load_from=model, + ) + else: + model_to_save = model else: checkpoint_filename = str(output_filename) - - if steps_so_far == train_params.steps: - model = rewrite.post_process(model) # type: ignore[attr-defined] + model_to_save = model with log_action( f"{steps_so_far}/{train_params.steps}: Saved as {checkpoint_filename}" ): + if post_process: + model_to_save = rewrite.post_process( # type: ignore[attr-defined] + model_to_save + ) save_as_tflite( - model, - str(checkpoint_filename), + model_to_save, + checkpoint_filename, input_name, replace.shape_from_name[input_name], output_name, @@ -553,7 +612,8 @@ def model_fit( model_is_quantized, ) output_filenames.append(checkpoint_filename) - return model, output_filenames + + return model_to_save, output_filenames def save_as_tflite( diff --git a/src/mlia/nn/rewrite/library/fc_clustering_layer.py b/src/mlia/nn/rewrite/library/fc_clustering_layer.py index 72931c0..7cc383e 100644 --- 
a/src/mlia/nn/rewrite/library/fc_clustering_layer.py +++ b/src/mlia/nn/rewrite/library/fc_clustering_layer.py @@ -9,7 +9,7 @@ from keras.api._v2 import keras # Temporary workaround for now: MLIA-1107 def get_keras_model_clus(input_shape: Any, output_shape: Any) -> keras.Model: """Generate TensorFlow Lite model for clustering rewrite.""" - clustering_params = { + rewrite_params = { "number_of_clusters": 32, "cluster_centroids_init": tfmot.clustering.keras.CentroidInitialization.LINEAR, } @@ -21,6 +21,6 @@ def get_keras_model_clus(input_shape: Any, output_shape: Any) -> keras.Model: keras.layers.Dense(units=output_shape), ] ), - **clustering_params + **rewrite_params ) return model |