From 32405c279d2f98c2d40bdbbb7f7306ff12c86cd6 Mon Sep 17 00:00:00 2001 From: Nathan Bailey Date: Fri, 8 Mar 2024 14:08:06 +0000 Subject: feat: Implement the clustering rewrite for int8 Implements a clustering rewrite for fully connected layers for int8 models Resolves: MLIA-1080 Signed-off-by: Nathan Bailey Change-Id: If48efb22764187a382e5b84bbb5c3b75a6e71b75 --- setup.cfg | 2 +- src/mlia/nn/rewrite/core/rewrite.py | 132 ++++++++++++++------- src/mlia/nn/rewrite/core/train.py | 118 +++++++++++++----- src/mlia/nn/rewrite/library/fc_clustering_layer.py | 4 +- tests/test_nn_rewrite_core_rewrite.py | 91 ++++++++++++-- tests/test_nn_rewrite_core_train.py | 12 +- 6 files changed, 264 insertions(+), 95 deletions(-) diff --git a/setup.cfg b/setup.cfg index 6ddb576..0714caf 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates. +# SPDX-FileCopyrightText: Copyright 2022-2024, Arm Limited and/or its affiliates. # SPDX-FileCopyrightText: Copyright (c) 2020 Troy Comi # SPDX-License-Identifier: Apache-2.0 AND MIT diff --git a/src/mlia/nn/rewrite/core/rewrite.py b/src/mlia/nn/rewrite/core/rewrite.py index 6a3695a..e2c097c 100644 --- a/src/mlia/nn/rewrite/core/rewrite.py +++ b/src/mlia/nn/rewrite/core/rewrite.py @@ -12,6 +12,7 @@ from pathlib import Path from typing import Any from typing import Callable +import numpy as np import tensorflow_model_optimization as tfmot from keras.api._v2 import keras # Temporary workaround for now: MLIA-1107 @@ -53,9 +54,9 @@ class Rewrite(ABC): except Exception as ex: raise RuntimeError(f"Rewrite '{self.name}' failed.") from ex - @abstractmethod def quantize(self, model: keras.Model) -> keras.Model: """Return a quantized model if required.""" + return model @abstractmethod def training_callbacks(self) -> list: @@ -65,60 +66,41 @@ class Rewrite(ABC): def post_process(self, model: keras.Model) -> keras.Model: """Return default post-processing rewrite options.""" + @abstractmethod + def check_optimization(self, model: keras.Model, **kwargs: dict) -> bool: + """Check if the optimization has produced the correct result.""" -class ClusteringRewrite(Rewrite): - """Graph clustering rewrite logic to be used by RewritingOptimizer.""" - strip_pruning_wrapper = staticmethod(tfmot.clustering.keras.strip_clustering) +class GenericRewrite(Rewrite): + """Graph rewrite logic for fully-connected rewrite.""" def quantize(self, model: keras.Model) -> keras.Model: - """Return a quantized model.""" - return model - - def post_process(self, model: keras.Model) -> keras.Model: - """Return the clustering stripped model.""" - return self.strip_pruning_wrapper(model) + """Return a quantized model if required.""" + return tfmot.quantization.keras.quantize_model(model) def training_callbacks(self) -> list: """Return default rewrite callbacks.""" return [] - -class QATRewrite(Rewrite): - """Logic for rewrites requiring quantization-aware training.""" - - def pruning_preserved_quantization( - self, - model: keras.Model, - ) -> keras.Model: - """Apply pruning-preserved quantization training to a given model.""" - model = tfmot.quantization.keras.quantize_annotate_model(model) - model = tfmot.quantization.keras.quantize_apply( - model, - tfmot.experimental.combine.Default8BitPrunePreserveQuantizeScheme(), - ) - + def post_process(self, model: keras.Model) -> keras.Model: + """Return default post-processing rewrite options.""" return model + def check_optimization(self, model: keras.Model, **kwargs: Any) -> bool: + """Not needed here.""" + return True -class FullyConnectedRewrite(Rewrite): - """Graph rewrite logic for fully-connected rewrite.""" - - def quantize(self, model: keras.Model) -> keras.Model: - """Return a quantized model if required.""" - model = tfmot.quantization.keras.quantize_model(model) - return model - def training_callbacks(self) -> list: - """Return default rewrite callbacks.""" - return [] +class QuantizeAwareTrainingRewrite(Rewrite, ABC): + """Abstract class for rewrites that perform QAT.""" - def post_process(self, model: keras.Model) -> keras.Model: - """Return default post-processing rewrite options.""" + @abstractmethod + def preserved_quantize(self, model: keras.Model) -> keras.Model: + """Apply optimization-aware quantization to a given model.""" return model -class Sparsity24Rewrite(QATRewrite): +class Sparsity24Rewrite(QuantizeAwareTrainingRewrite): """Graph rewrite logic for fully-connected-sparsity24 rewrite.""" pruning_callback = tfmot.sparsity.keras.UpdatePruningStep @@ -137,6 +119,74 @@ class Sparsity24Rewrite(QATRewrite): """Pruning-specific post-processing rewrite options.""" return self.strip_pruning_wrapper(model) + def preserved_quantize( + self, + model: keras.Model, + ) -> keras.Model: + """Apply pruning-preserved quantization training to a given model.""" + model = tfmot.quantization.keras.quantize_annotate_model(model) + model = tfmot.quantization.keras.quantize_apply( + model, + tfmot.experimental.combine.Default8BitPrunePreserveQuantizeScheme(), + ) + + return model + + def check_optimization(self, model: keras.Model, **kwargs: Any) -> bool: + """Not needed here.""" + return True + + +class ClusteringRewrite(QuantizeAwareTrainingRewrite): + """Graph clustering rewrite logic to be used by RewritingOptimizer.""" + + _strip_clustering_wrapper = staticmethod(tfmot.clustering.keras.strip_clustering) + + def preserved_quantize(self, model: keras.Model) -> keras.Model: + """Apply clustering-preserved quantization to a given model.""" + quant_aware_model = tfmot.quantization.keras.quantize_annotate_model(model) + cqat_model = tfmot.quantization.keras.quantize_apply( + quant_aware_model, + tfmot.experimental.combine.Default8BitClusterPreserveQuantizeScheme(), + ) + return cqat_model + + def check_optimization(self, model: keras.Model, **kwargs: Any) -> bool: + """Check if clustering has produced the correct result.""" + number_of_clusters = kwargs.get("number_of_clusters") + if not number_of_clusters: + raise ValueError( + """ + Expected check_preserved_quantize to have argument number_of_clusters. + """ + ) + + for layer in model.layers: + for weight in layer.weights: + if "kernel" in weight.name: + if "kernel_min" in weight.name or "kernel_max" in weight.name: + continue + number_of_found_clusters = len(np.unique(weight)) + if number_of_found_clusters != number_of_clusters: + logger.warning( + "\nWARNING: Expected %d cluster(s), found %d " + "cluster(s) in layer %s for weight %s \n", + number_of_clusters, + number_of_found_clusters, + layer.name, + weight.name, + ) + return False + return True + + def training_callbacks(self) -> list: + """Return default rewrite callbacks.""" + return [] + + def post_process(self, model: keras.Model) -> keras.Model: + """Return the clustering stripped model.""" + return self._strip_clustering_wrapper(model) + class RewriteRegistry(Registry[Rewrite]): """Registry rewrite functions.""" @@ -176,7 +226,7 @@ class RewritingOptimizer(Optimizer): registry = RewriteRegistry( [ - FullyConnectedRewrite("fully-connected", fc_rewrite), + GenericRewrite("fully-connected", fc_rewrite), Sparsity24Rewrite("fully-connected-sparsity24", fc_rewrite_sparsity24), ClusteringRewrite("fully-connected-clustering", fc_clustering_rewrite), ] @@ -200,7 +250,7 @@ class RewritingOptimizer(Optimizer): rewrite = RewritingOptimizer.registry.items[ self.optimizer_configuration.optimization_target ] - is_qat = isinstance(rewrite, QATRewrite) + use_unmodified_model = True tflite_model = self.model.model_path tfrecord = str(self.optimizer_configuration.dataset) @@ -218,9 +268,9 @@ class RewritingOptimizer(Optimizer): output_model=str(tmp_output), input_tfrec=str(tfrecord), rewrite=rewrite, + is_qat=isinstance(rewrite, QuantizeAwareTrainingRewrite), input_tensors=[self.optimizer_configuration.layers_to_optimize[0]], output_tensors=[self.optimizer_configuration.layers_to_optimize[1]], - is_qat=is_qat, train_params=self.optimizer_configuration.train_params, ) diff --git a/src/mlia/nn/rewrite/core/train.py b/src/mlia/nn/rewrite/core/train.py index 4b9821c..88efa23 100644 --- a/src/mlia/nn/rewrite/core/train.py +++ b/src/mlia/nn/rewrite/core/train.py @@ -73,15 +73,15 @@ class TrainingParameters: checkpoint_at: list | None = None -def train( +def train( # pylint: disable=too-many-arguments source_model: str, unmodified_model: Any, output_model: str, input_tfrec: str, rewrite: Callable, + is_qat: bool, input_tensors: list, output_tensors: list, - is_qat: bool, train_params: TrainingParameters = TrainingParameters(), ) -> Any: """Extract and train a model, and return the results.""" @@ -383,15 +383,15 @@ def train_in_dir( ) input_shape = teacher.shape_from_name[input_name][1:] + output_shape = teacher.shape_from_name[output_name][1:] - model = rewrite(input_shape, output_shape) optimizer = keras.optimizers.Nadam(learning_rate=train_params.learning_rate) loss_fn = keras.losses.MeanSquaredError() - if model_is_quantized: - model = rewrite.quantize(model) # type: ignore[attr-defined] - model = model_compile(model, optimizer, loss_fn) + model = create_model( + rewrite, input_shape, output_shape, optimizer, loss_fn, model_is_quantized + ) logger.info(model.summary()) @@ -432,16 +432,14 @@ def train_in_dir( callbacks = [] callbacks.extend(rewrite.training_callbacks()) # type: ignore[attr-defined] - - output_filenames = [] # type: list[str] + output_filenames: list = [] checkpoints = (train_params.checkpoint_at if train_params.checkpoint_at else []) + [ train_params.steps ] - model, output_filenames = model_fit( model, train_params, - checkpoints.copy(), + checkpoints, optimizer, dataset, callbacks, @@ -452,22 +450,35 @@ def train_in_dir( output_name, model_is_quantized, output_filenames, + input_shape, + output_shape, + loss_fn, + post_process=True, ) + # Placeholder for now, will be parametrized later (MLIA-1114) + # rewrite.check_optimization( # type: ignore[attr-defined] + # model, number_of_clusters=32 + # ) if model_is_quantized and is_qat: - model = rewrite.pruning_preserved_quantization( # type: ignore[attr-defined] - model, - ) + model = rewrite.preserved_quantize(model) # type: ignore[attr-defined] + checkpoints = ( + train_params.checkpoint_at if train_params.checkpoint_at else [] + ) + [train_params.steps] + output_filenames = [] + + if len(rewrite.training_callbacks()) > 0 and set( # type: ignore[attr-defined] + rewrite.training_callbacks() # type: ignore[attr-defined] + ).issubset(callbacks): + callbacks.pop(-1) + optimizer = keras.optimizers.Nadam(learning_rate=train_params.learning_rate) model = model_compile(model, optimizer, loss_fn) - callbacks.pop(-1) - output_filenames = [] - model, output_filenames = model_fit( model, train_params, - checkpoints.copy(), + checkpoints, optimizer, dataset, callbacks, @@ -478,22 +489,50 @@ def train_in_dir( output_name, model_is_quantized, output_filenames, + input_shape, + output_shape, + loss_fn, ) + # Placeholder for now, will be parametrized later (MLIA-1114) + # rewrite.check_optimization( # type: ignore[attr-defined] + # model, number_of_clusters=32 + # ) teacher.close() return output_filenames def model_compile( - model: tf.keras.Model, optimizer: tf.keras.optimizers, loss_fn: tf.keras.losses -) -> tf.keras.Model: + model: keras.Model, + optimizer: keras.optimizers.Nadam, + loss_fn: keras.losses.Loss, +) -> keras.Model: """Compiles a tflite model.""" model.compile(optimizer=optimizer, loss=loss_fn, metrics=["mae"]) return model -def model_fit( - model: tf.keras.Model, +def create_model( # pylint: disable=too-many-arguments + rewrite: Callable, + input_shape: int, + output_shape: int, + optimizer: Callable, + loss_fn: Callable, + model_is_quantized: bool, + model_to_load_from: keras.model | None = None, +) -> keras.Model: + """Create a model, optionally from another.""" + model = rewrite(input_shape, output_shape) + if model_is_quantized: + model = rewrite.quantize(model) # type: ignore[attr-defined] + model = model_compile(model, optimizer=optimizer, loss_fn=loss_fn) + if model_to_load_from: + model.set_weights(model_to_load_from.get_weights()) + return model + + +def model_fit( # pylint: disable=too-many-arguments + model: keras.Model, train_params: TrainingParameters, checkpoints: list, optimizer: tf.optimizers.Nadam, @@ -506,8 +545,12 @@ def model_fit( output_name: str, model_is_quantized: bool, output_filenames: list, -) -> tuple[tf.keras.Model, list]: - """Train the model.""" + input_shape: int, + output_shape: int, + loss_fn: Callable, + post_process: bool = False, +) -> keras.Model: + """Train a tflite model.""" steps_so_far = 0 while steps_so_far < train_params.steps: steps_to_train = checkpoints.pop(0) - steps_so_far @@ -534,18 +577,34 @@ def model_fit( checkpoint_filename = ( filename_dir + "/" + filename + (f"_@{steps_so_far}") + ext ) + # If post processing we are stripping the clustering/pruning layers below + # Thus copy the model before saving, so training can continue + if post_process: + model_to_save = create_model( + rewrite, + input_shape, + output_shape, + optimizer, + loss_fn, + model_is_quantized, + model_to_load_from=model, + ) + else: + model_to_save = model else: checkpoint_filename = str(output_filename) - - if steps_so_far == train_params.steps: - model = rewrite.post_process(model) # type: ignore[attr-defined] + model_to_save = model with log_action( f"{steps_so_far}/{train_params.steps}: Saved as {checkpoint_filename}" ): + if post_process: + model_to_save = rewrite.post_process( # type: ignore[attr-defined] + model_to_save + ) save_as_tflite( - model, - str(checkpoint_filename), + model_to_save, + checkpoint_filename, input_name, replace.shape_from_name[input_name], output_name, @@ -553,7 +612,8 @@ def model_fit( model_is_quantized, ) output_filenames.append(checkpoint_filename) - return model, output_filenames + + return model_to_save, output_filenames def save_as_tflite( diff --git a/src/mlia/nn/rewrite/library/fc_clustering_layer.py b/src/mlia/nn/rewrite/library/fc_clustering_layer.py index 72931c0..7cc383e 100644 --- a/src/mlia/nn/rewrite/library/fc_clustering_layer.py +++ b/src/mlia/nn/rewrite/library/fc_clustering_layer.py @@ -9,7 +9,7 @@ from keras.api._v2 import keras # Temporary workaround for now: MLIA-1107 def get_keras_model_clus(input_shape: Any, output_shape: Any) -> keras.Model: """Generate TensorFlow Lite model for clustering rewrite.""" - clustering_params = { + rewrite_params = { "number_of_clusters": 32, "cluster_centroids_init": tfmot.clustering.keras.CentroidInitialization.LINEAR, } @@ -21,6 +21,6 @@ def get_keras_model_clus(input_shape: Any, output_shape: Any) -> keras.Model: keras.layers.Dense(units=output_shape), ] ), - **clustering_params + **rewrite_params ) return model diff --git a/tests/test_nn_rewrite_core_rewrite.py b/tests/test_nn_rewrite_core_rewrite.py index ef4df6a..e502842 100644 --- a/tests/test_nn_rewrite_core_rewrite.py +++ b/tests/test_nn_rewrite_core_rewrite.py @@ -10,13 +10,14 @@ from typing import cast from unittest.mock import MagicMock import pytest +import tensorflow_model_optimization as tfmot from keras.api._v2 import keras # Temporary workaround for now: MLIA-1107 from tensorflow_model_optimization.python.core.clustering.keras.cluster_wrapper import ( # pylint: disable=no-name-in-module ClusterWeights, ) from mlia.nn.rewrite.core.rewrite import ClusteringRewrite -from mlia.nn.rewrite.core.rewrite import FullyConnectedRewrite +from mlia.nn.rewrite.core.rewrite import GenericRewrite from mlia.nn.rewrite.core.rewrite import Rewrite from mlia.nn.rewrite.core.rewrite import RewriteCallable from mlia.nn.rewrite.core.rewrite import RewriteConfiguration @@ -25,17 +26,48 @@ from mlia.nn.rewrite.core.rewrite import RewritingOptimizer from mlia.nn.rewrite.core.rewrite import Sparsity24Rewrite from mlia.nn.rewrite.core.rewrite import TrainingParameters from mlia.nn.rewrite.core.train import train_in_dir +from mlia.nn.rewrite.library.fc_clustering_layer import ( + get_keras_model_clus as fc_clustering_rewrite, +) from mlia.nn.tensorflow.config import TFLiteModel from tests.utils.rewrite import MockTrainingParameters +class TestRewrite(Rewrite): + """Test rewrite class.""" + + def quantize(self, model: keras.Model) -> keras.Model: + """Return a quantized model if required.""" + return tfmot.quantization.keras.quantize_model(model) + + def preserved_quantize(self, model: keras.Model) -> keras.Model: + """Not needed.""" + return model + + def training_callbacks(self) -> list: + """Return default rewrite callbacks.""" + return [] + + def post_process(self, model: keras.Model) -> keras.Model: + """Return default post-processing rewrite options.""" + return model + + def check_optimization(self, model: keras.Model, **kwargs: dict) -> bool: + """Not needed here.""" + return True + + +def mock_rewrite_function(*_: Any) -> Any: + """Mock function to test autoloading of rewrite functions.""" + + def test_rewrite() -> None: """Test a derived Rewrite class.""" def bad_rewrite_func() -> Any: raise NotImplementedError() - rewrite = Sparsity24Rewrite( + rewrite = TestRewrite( "BAD_REWRITE", rewrite_fn=cast(RewriteCallable, bad_rewrite_func) ) with pytest.raises(RuntimeError): @@ -45,7 +77,7 @@ def test_rewrite() -> None: @pytest.mark.parametrize( "rewrite_name, callbacks_length, instance", [ - ("fully-connected", 0, Rewrite), + ("fully-connected", 0, GenericRewrite), ("fully-connected-clustering", 0, ClusteringRewrite), ("fully-connected-sparsity24", 1, Sparsity24Rewrite), ], @@ -72,8 +104,8 @@ def test_rewrite_selection( def test_rewrite_configuration( test_tflite_model_fp32: Path, rewrite_name: str, expected_error: Any ) -> None: - """Test get_rewrite function only supports rewrite types - fully-connected, fully-connected-clustering and fully-connected-sparsity24.""" + """Test get_rewrite function only supports rewrite type fully-connected, + fully-connected-clustering and fully-connected-sparsity24.""" with expected_error: config_obj = RewriteConfiguration( rewrite_name, @@ -88,28 +120,61 @@ def test_rewrite_configuration( assert isinstance(rewriter_obj, RewritingOptimizer) +def test_rewrite_fully_connected_clustering() -> None: + """Check that model has the set number of clusters""" + + rewrite = ClusteringRewrite("fully-connected-clustering", fc_clustering_rewrite) + model = rewrite(input_shape=(28, 28), output_shape=10) + model = rewrite.post_process(model) + assert rewrite.check_optimization(model, number_of_clusters=32) + + +def test_rewrite_fully_connected_clustering_error_handling() -> None: + """Check that model has the set number of clusters + and that when quantized the number of clusters + remain.""" + + rewrite = ClusteringRewrite("fully-connected-clustering", fc_clustering_rewrite) + model = rewrite(input_shape=(28, 28), output_shape=10) + with pytest.raises( + ValueError, + match=( + r"Expected check_preserved_quantize to have argument number_of_clusters" + ), + ): + rewrite.check_optimization(model, bad_arg_name=25) + + @pytest.mark.parametrize( - "rewrite_type, expected_layers", + "rewrite_type, expected_layers, quant", [ - ["fully-connected", [keras.layers.Reshape, keras.layers.Dense]], - ["fully-connected-clustering", [ClusterWeights, ClusterWeights]], + ["fully-connected", [keras.layers.Reshape, keras.layers.Dense], False], + ["fully-connected-clustering", [ClusterWeights, ClusterWeights], False], + ["fully-connected-clustering", [ClusterWeights, ClusterWeights], True], ], ) -def test_rewriting_optimizer( +def test_rewriting_optimizer( # pylint: disable=too-many-locals test_tflite_model_fp32: Path, test_tfrecord_fp32: Path, + test_tflite_model: Path, + test_tfrecord: Path, rewrite_type: str, expected_layers: list[object], + quant: bool, ) -> None: """Test fc_layer rewrite process with rewrite type fully-connected.""" + + tfrecord = test_tfrecord if quant else test_tfrecord_fp32 + tflite_model = test_tflite_model if quant else test_tflite_model_fp32 + config_obj = RewriteConfiguration( rewrite_type, ["sequential/flatten/Reshape", "StatefulPartitionedCall:0"], - test_tfrecord_fp32, + tfrecord, train_params=MockTrainingParameters(), ) - test_obj = RewritingOptimizer(test_tflite_model_fp32, config_obj) + test_obj = RewritingOptimizer(tflite_model, config_obj) rewrite_function = RewritingOptimizer.registry.items[ test_obj.optimizer_configuration.optimization_target ] @@ -132,8 +197,8 @@ def test_register_rewrite_function() -> None: """Test adding rewrite functions and verify they are reported via the registry.""" registry = RewriteRegistry() - rewrite1 = FullyConnectedRewrite("r1", cast(RewriteCallable, lambda: 1)) - rewrite2 = Sparsity24Rewrite("r2", cast(RewriteCallable, lambda: 2)) + rewrite1 = TestRewrite("r1", cast(RewriteCallable, lambda: 1)) + rewrite2 = TestRewrite("r2", cast(RewriteCallable, lambda: 2)) registry.register_rewrite(rewrite1) registry.register_rewrite(rewrite2) diff --git a/tests/test_nn_rewrite_core_train.py b/tests/test_nn_rewrite_core_train.py index 371c79f..94c99ff 100644 --- a/tests/test_nn_rewrite_core_train.py +++ b/tests/test_nn_rewrite_core_train.py @@ -14,15 +14,13 @@ import pytest import tensorflow as tf from keras.api._v2 import keras # Temporary workaround for now: MLIA-1107 -from mlia.nn.rewrite.core.rewrite import FullyConnectedRewrite -from mlia.nn.rewrite.core.rewrite import QATRewrite from mlia.nn.rewrite.core.train import augment_fn_twins from mlia.nn.rewrite.core.train import AUGMENTATION_PRESETS from mlia.nn.rewrite.core.train import LearningRateSchedule from mlia.nn.rewrite.core.train import mixup from mlia.nn.rewrite.core.train import train from mlia.nn.rewrite.core.train import TrainingParameters -from mlia.nn.rewrite.library.fc_layer import get_keras_model as fc_rewrite +from tests.test_nn_rewrite_core_rewrite import TestRewrite from tests.utils.rewrite import MockTrainingParameters @@ -56,20 +54,16 @@ def check_train( """Test the train() function.""" with TemporaryDirectory() as tmp_dir: output_file = Path(tmp_dir, "out.tflite") - mock_rewrite = FullyConnectedRewrite( - name="replace", - rewrite_fn=fc_rewrite, - ) - is_qat = isinstance(mock_rewrite, QATRewrite) + mock_rewrite = TestRewrite("replace", replace_fully_connected_with_conv) result = train( source_model=str(tflite_model), unmodified_model=str(tflite_model) if use_unmodified_model else None, output_model=str(output_file), input_tfrec=str(tfrecord), rewrite=mock_rewrite, + is_qat=False, input_tensors=["sequential/flatten/Reshape"], output_tensors=["StatefulPartitionedCall:0"], - is_qat=is_qat, train_params=train_params, ) -- cgit v1.2.1