From 1ebb335cba516bcf973b041efa6a9878d1022b93 Mon Sep 17 00:00:00 2001 From: Madeleine Dunn Date: Wed, 21 Feb 2024 17:10:07 +0000 Subject: feat: Implement int8 sparsity 2:4 rewrite - Implement pruning-preserving quantisation aware training - Rework the training logic to avoid duplication - Remove the DynamicallyLoadedRewrite class as it is now unused Resolves: MLIA-1003 Signed-off-by: Madeleine Dunn Change-Id: Ia7a4acf5f477a27963cffa88180cca085b32ffe4 --- src/mlia/nn/rewrite/core/rewrite.py | 94 +++++++++++++++++----------------- src/mlia/nn/rewrite/core/train.py | 95 +++++++++++++++++++++++++++++++---- tests/test_nn_rewrite_core_rewrite.py | 71 +++++++++----------------- tests/test_nn_rewrite_core_train.py | 12 +++-- 4 files changed, 163 insertions(+), 109 deletions(-) diff --git a/src/mlia/nn/rewrite/core/rewrite.py b/src/mlia/nn/rewrite/core/rewrite.py index 2a7b432..4fe1c26 100644 --- a/src/mlia/nn/rewrite/core/rewrite.py +++ b/src/mlia/nn/rewrite/core/rewrite.py @@ -3,14 +3,14 @@ """Contains class RewritingOptimizer to replace a subgraph/layer of a model.""" from __future__ import annotations -import importlib import logging import tempfile +from abc import ABC +from abc import abstractmethod from dataclasses import dataclass from pathlib import Path from typing import Any from typing import Callable -from typing import cast import tensorflow_model_optimization as tfmot from keras.api._v2 import keras # Temporary workaround for now: MLIA-1107 @@ -35,8 +35,8 @@ logger = logging.getLogger(__name__) RewriteCallable = Callable[[Any, Any], keras.Model] -class Rewrite: - """Graph rewrite logic to be used by RewritingOptimizer.""" +class Rewrite(ABC): + """Abstract class for rewrite logic to be used by RewritingOptimizer.""" def __init__(self, name: str, rewrite_fn: RewriteCallable): """Initialize a Rewrite instance with a given name and an optional function.""" @@ -50,10 +50,42 @@ class Rewrite: except Exception as ex: raise RuntimeError(f"Rewrite '{self.name}' failed.") from ex - def quantize(self, model: keras.Model, model_is_quantized: bool) -> keras.Model: + @abstractmethod + def quantize(self, model: keras.Model) -> keras.Model: """Return a quantized model if required.""" - if model_is_quantized: - model = tfmot.quantization.keras.quantize_model(model) + + @abstractmethod + def training_callbacks(self) -> list: + """Return default rewrite callbacks.""" + + @abstractmethod + def post_process(self, model: keras.Model) -> keras.Model: + """Return default post-processing rewrite options.""" + + +class QATRewrite(Rewrite): + """Logic for rewrites requiring quantization-aware training.""" + + def pruning_preserved_quantization( + self, + model: keras.Model, + ) -> keras.Model: + """Apply pruning-preserved quantization training to a given model.""" + model = tfmot.quantization.keras.quantize_annotate_model(model) + model = tfmot.quantization.keras.quantize_apply( + model, + tfmot.experimental.combine.Default8BitPrunePreserveQuantizeScheme(), + ) + + return model + + +class FullyConnectedRewrite(Rewrite): + """Graph rewrite logic for fully-connected rewrite.""" + + def quantize(self, model: keras.Model) -> keras.Model: + """Return a quantized model if required.""" + model = tfmot.quantization.keras.quantize_model(model) return model def training_callbacks(self) -> list: @@ -65,18 +97,15 @@ class Rewrite: return model -class PruningRewrite(Rewrite): - """Derived Rewrite class with pruning-specific logic.""" +class Sparsity24Rewrite(QATRewrite): + """Graph rewrite logic for 
fully-connected-sparsity24 rewrite.""" pruning_callback = tfmot.sparsity.keras.UpdatePruningStep strip_pruning_wrapper = staticmethod(tfmot.sparsity.keras.strip_pruning) - def quantize(self, model: keras.Model, model_is_quantized: bool) -> keras.Model: - """Return a quantized model if required.""" - if model_is_quantized: - # placeholder for PQAT - pass + def quantize(self, model: keras.Model) -> keras.Model: + """Skip quantization when using pruning rewrite.""" return model def training_callbacks(self) -> list: @@ -88,35 +117,6 @@ class PruningRewrite(Rewrite): return self.strip_pruning_wrapper(model) -@dataclass -class DynamicallyLoadedRewrite(Rewrite): - """A rewrite which can load logic from a function loaded dynamically.""" - - def __init__(self, name: str, function_name: str): - """Initialize.""" - - def load_and_run(input_shape: Any, output_shape: Any) -> keras.Model: - """Load the function from a file dynamically.""" - self.load_function(function_name) - return self.function(input_shape, output_shape) - - super().__init__(name, load_and_run) - - def load_function(self, function_name: str) -> RewriteCallable: - """Return the rewrite function. Import using the auto_load attr if necessary.""" - try: - name_parts = function_name.split(".") - module_name = ".".join(name_parts[:-1]) - fn_name = name_parts[-1] - module = importlib.import_module(module_name) - self.function = cast(RewriteCallable, getattr(module, fn_name)) - return self.function - except Exception as ex: - raise RuntimeError( - f"Unable to load rewrite function '{function_name}' for '{self.name}'." - ) from ex - - class RewriteRegistry(Registry[Rewrite]): """Registry rewrite functions.""" @@ -155,8 +155,8 @@ class RewritingOptimizer(Optimizer): registry = RewriteRegistry( [ - Rewrite("fully-connected", fc_rewrite), - PruningRewrite("fully-connected-sparsity24", fc_rewrite_sparsity24), + FullyConnectedRewrite("fully-connected", fc_rewrite), + Sparsity24Rewrite("fully-connected-sparsity24", fc_rewrite_sparsity24), ] ) @@ -178,6 +178,7 @@ class RewritingOptimizer(Optimizer): rewrite = RewritingOptimizer.registry.items[ self.optimizer_configuration.optimization_target ] + is_qat = isinstance(rewrite, QATRewrite) use_unmodified_model = True tflite_model = self.model.model_path tfrecord = str(self.optimizer_configuration.dataset) @@ -190,7 +191,6 @@ class RewritingOptimizer(Optimizer): "Input and output tensor names need to be set for rewrite." ) - self.optimizer_configuration.train_params.checkpoint_at = [5000, 10000] orig_vs_repl_stats, total_stats = train( source_model=tflite_model, unmodified_model=tflite_model if use_unmodified_model else None, @@ -199,6 +199,7 @@ class RewritingOptimizer(Optimizer): rewrite=rewrite, input_tensors=[self.optimizer_configuration.layers_to_optimize[0]], output_tensors=[self.optimizer_configuration.layers_to_optimize[1]], + is_qat=is_qat, train_params=self.optimizer_configuration.train_params, ) @@ -245,6 +246,7 @@ class RewritingOptimizer(Optimizer): notes=notes, ) logger.info(table.to_plain_text(show_title=True)) + self.model = TFLiteModel(tmp_output) def get_model(self) -> TFLiteModel: """Return optimized model.""" diff --git a/src/mlia/nn/rewrite/core/train.py b/src/mlia/nn/rewrite/core/train.py index 89de880..4b9821c 100644 --- a/src/mlia/nn/rewrite/core/train.py +++ b/src/mlia/nn/rewrite/core/train.py @@ -1,6 +1,7 @@ # SPDX-FileCopyrightText: Copyright 2023-2024, Arm Limited and/or its affiliates. 
# SPDX-License-Identifier: Apache-2.0 """Sequential trainer.""" +# pylint: disable=too-many-arguments # pylint: disable=too-many-locals # pylint: disable=too-many-statements from __future__ import annotations @@ -80,6 +81,7 @@ def train( rewrite: Callable, input_tensors: list, output_tensors: list, + is_qat: bool, train_params: TrainingParameters = TrainingParameters(), ) -> Any: """Extract and train a model, and return the results.""" @@ -118,6 +120,7 @@ def train( baseline_dir=unmodified_model_dir_path, output_filename=Path(train_dir, "new.tflite"), rewrite=rewrite, + is_qat=is_qat, train_params=train_params, ) @@ -345,6 +348,7 @@ def train_in_dir( baseline_dir: Any, output_filename: Path, rewrite: Callable, + is_qat: bool, train_params: TrainingParameters = TrainingParameters(), ) -> list[str]: """Train a replacement for replace.tflite using the input.tfrec \ @@ -385,8 +389,9 @@ def train_in_dir( optimizer = keras.optimizers.Nadam(learning_rate=train_params.learning_rate) loss_fn = keras.losses.MeanSquaredError() - model = rewrite.quantize(model, model_is_quantized) # type: ignore[attr-defined] - model.compile(optimizer=optimizer, loss=loss_fn, metrics=["mae"]) + if model_is_quantized: + model = rewrite.quantize(model) # type: ignore[attr-defined] + model = model_compile(model, optimizer, loss_fn) logger.info(model.summary()) @@ -428,11 +433,82 @@ def train_in_dir( callbacks.extend(rewrite.training_callbacks()) # type: ignore[attr-defined] - output_filenames = [] + output_filenames = [] # type: list[str] checkpoints = (train_params.checkpoint_at if train_params.checkpoint_at else []) + [ train_params.steps ] + model, output_filenames = model_fit( + model, + train_params, + checkpoints.copy(), + optimizer, + dataset, + callbacks, + output_filename, + rewrite, + replace, + input_name, + output_name, + model_is_quantized, + output_filenames, + ) + + if model_is_quantized and is_qat: + model = rewrite.pruning_preserved_quantization( # type: ignore[attr-defined] + model, + ) + optimizer = keras.optimizers.Nadam(learning_rate=train_params.learning_rate) + model = model_compile(model, optimizer, loss_fn) + + callbacks.pop(-1) + output_filenames = [] + + model, output_filenames = model_fit( + model, + train_params, + checkpoints.copy(), + optimizer, + dataset, + callbacks, + output_filename, + rewrite, + replace, + input_name, + output_name, + model_is_quantized, + output_filenames, + ) + + teacher.close() + return output_filenames + + +def model_compile( + model: tf.keras.Model, optimizer: tf.keras.optimizers, loss_fn: tf.keras.losses +) -> tf.keras.Model: + """Compiles a tflite model.""" + model.compile(optimizer=optimizer, loss=loss_fn, metrics=["mae"]) + return model + + +def model_fit( + model: tf.keras.Model, + train_params: TrainingParameters, + checkpoints: list, + optimizer: tf.optimizers.Nadam, + dataset: tf.data.Dataset, + callbacks: list, + output_filename: Path, + rewrite: Callable, + replace: TFLiteModel, + input_name: str, + output_name: str, + model_is_quantized: bool, + output_filenames: list, +) -> tuple[tf.keras.Model, list]: + """Train the model.""" + steps_so_far = 0 while steps_so_far < train_params.steps: steps_to_train = checkpoints.pop(0) - steps_so_far lr_start = optimizer.learning_rate.numpy() @@ -460,15 +536,16 @@ def train_in_dir( ) else: checkpoint_filename = str(output_filename) + + if steps_so_far == train_params.steps: + model = rewrite.post_process(model) # type: ignore[attr-defined] + with log_action( f"{steps_so_far}/{train_params.steps}: Saved as 
{checkpoint_filename}" ): - if steps_so_far == train_params.steps: - model = rewrite.post_process(model) # type: ignore[attr-defined] - save_as_tflite( model, - checkpoint_filename, + str(checkpoint_filename), input_name, replace.shape_from_name[input_name], output_name, @@ -476,9 +553,7 @@ def train_in_dir( model_is_quantized, ) output_filenames.append(checkpoint_filename) - - teacher.close() - return output_filenames + return model, output_filenames def save_as_tflite( diff --git a/tests/test_nn_rewrite_core_rewrite.py b/tests/test_nn_rewrite_core_rewrite.py index e614cad..8ef5bd2 100644 --- a/tests/test_nn_rewrite_core_rewrite.py +++ b/tests/test_nn_rewrite_core_rewrite.py @@ -11,45 +11,51 @@ from unittest.mock import MagicMock import pytest -from mlia.nn.rewrite.core.rewrite import DynamicallyLoadedRewrite -from mlia.nn.rewrite.core.rewrite import Rewrite +from mlia.nn.rewrite.core.rewrite import FullyConnectedRewrite from mlia.nn.rewrite.core.rewrite import RewriteCallable from mlia.nn.rewrite.core.rewrite import RewriteConfiguration from mlia.nn.rewrite.core.rewrite import RewriteRegistry from mlia.nn.rewrite.core.rewrite import RewritingOptimizer +from mlia.nn.rewrite.core.rewrite import Sparsity24Rewrite from mlia.nn.rewrite.core.rewrite import TrainingParameters from mlia.nn.rewrite.core.train import train_in_dir from mlia.nn.tensorflow.config import TFLiteModel from tests.utils.rewrite import MockTrainingParameters -def mock_rewrite_function(*_: Any) -> Any: - """Mock function to test autoloading of rewrite functions.""" - - def test_rewrite() -> None: - """Test the Rewrite class.""" + """Test a derived Rewrite class.""" def bad_rewrite_func() -> Any: raise NotImplementedError() - rewrite = Rewrite("BAD_REWRITE", rewrite_fn=cast(RewriteCallable, bad_rewrite_func)) + rewrite = Sparsity24Rewrite( + "BAD_REWRITE", rewrite_fn=cast(RewriteCallable, bad_rewrite_func) + ) with pytest.raises(RuntimeError): rewrite((1, 2), (1, 2)) @pytest.mark.parametrize( - "rewrite_name, callbacks_length", + "rewrite_name, rewrite_class", [ - ("fully-connected", 0), - ("fully-connected-sparsity24", 1), + ("fully-connected", FullyConnectedRewrite), + ("fully-connected-sparsity24", Sparsity24Rewrite), ], ) -def test_rewrite_selection(rewrite_name: str, callbacks_length: int) -> None: - """Test that the correct rewrite class is instantiated.""" - rewrite = RewritingOptimizer.registry.items[rewrite_name] +def test_rewrite_selection( + rewrite_name: str, + rewrite_class: Any, +) -> None: + """Check that the correct rewrite class is instantiated through the registry""" + config_obj = RewriteConfiguration( + rewrite_name, + ["sample_node_start", "sample_node_end"], + ) + + rewrite = RewritingOptimizer.registry.items[config_obj.optimization_target] assert rewrite.name == rewrite_name - assert len(rewrite.training_callbacks()) == callbacks_length + assert isinstance(rewrite, rewrite_class) @pytest.mark.parametrize( @@ -106,8 +112,8 @@ def test_register_rewrite_function() -> None: """Test adding rewrite functions and verify they are reported via the registry.""" registry = RewriteRegistry() - rewrite1 = Rewrite("r1", cast(RewriteCallable, lambda: 1)) - rewrite2 = Rewrite("r2", cast(RewriteCallable, lambda: 2)) + rewrite1 = FullyConnectedRewrite("r1", cast(RewriteCallable, lambda: 1)) + rewrite2 = Sparsity24Rewrite("r2", cast(RewriteCallable, lambda: 2)) registry.register_rewrite(rewrite1) registry.register_rewrite(rewrite2) @@ -122,37 +128,6 @@ def test_builtin_rewrite_names() -> None: ] -def 
test_rewrite_function_autoload() -> None: - """Test rewrite function loading.""" - function_name = "tests.test_nn_rewrite_core_rewrite.mock_rewrite_function" - rewrite = DynamicallyLoadedRewrite(name="mock_rewrite", function_name=function_name) - assert rewrite.name == "mock_rewrite" - - assert rewrite.function is not mock_rewrite_function - assert rewrite.load_function(function_name) is mock_rewrite_function - assert rewrite.function is mock_rewrite_function - - -def test_rewrite_function_autoload_fail() -> None: - """Test rewrite function loading failure.""" - function_name = "invalid_module.invalid_function" - rewrite = DynamicallyLoadedRewrite( - name="mock_rewrite", - function_name="invalid_module.invalid_function", - ) - assert rewrite.name == "mock_rewrite" - - with pytest.raises(Exception) as exc_info: - rewrite.load_function(function_name) - - message = exc_info.value.args[0] - - assert message == ( - "Unable to load rewrite function 'invalid_module.invalid_function'" - " for 'mock_rewrite'." - ) - - def test_rewrite_configuration_train_params( test_tflite_model_fp32: Path, test_tfrecord_fp32: Path, diff --git a/tests/test_nn_rewrite_core_train.py b/tests/test_nn_rewrite_core_train.py index 34b9543..371c79f 100644 --- a/tests/test_nn_rewrite_core_train.py +++ b/tests/test_nn_rewrite_core_train.py @@ -14,13 +14,15 @@ import pytest import tensorflow as tf from keras.api._v2 import keras # Temporary workaround for now: MLIA-1107 -from mlia.nn.rewrite.core.rewrite import DynamicallyLoadedRewrite +from mlia.nn.rewrite.core.rewrite import FullyConnectedRewrite +from mlia.nn.rewrite.core.rewrite import QATRewrite from mlia.nn.rewrite.core.train import augment_fn_twins from mlia.nn.rewrite.core.train import AUGMENTATION_PRESETS from mlia.nn.rewrite.core.train import LearningRateSchedule from mlia.nn.rewrite.core.train import mixup from mlia.nn.rewrite.core.train import train from mlia.nn.rewrite.core.train import TrainingParameters +from mlia.nn.rewrite.library.fc_layer import get_keras_model as fc_rewrite from tests.utils.rewrite import MockTrainingParameters @@ -54,12 +56,11 @@ def check_train( """Test the train() function.""" with TemporaryDirectory() as tmp_dir: output_file = Path(tmp_dir, "out.tflite") - mock_rewrite = DynamicallyLoadedRewrite( + mock_rewrite = FullyConnectedRewrite( name="replace", - function_name=( - "tests.test_nn_rewrite_core_train.replace_fully_connected_with_conv" - ), + rewrite_fn=fc_rewrite, ) + is_qat = isinstance(mock_rewrite, QATRewrite) result = train( source_model=str(tflite_model), unmodified_model=str(tflite_model) if use_unmodified_model else None, @@ -68,6 +69,7 @@ def check_train( rewrite=mock_rewrite, input_tensors=["sequential/flatten/Reshape"], output_tensors=["StatefulPartitionedCall:0"], + is_qat=is_qat, train_params=train_params, ) -- cgit v1.2.1
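
Note: the following is a minimal, self-contained sketch of the two-phase flow this patch wires into train_in_dir() — train with 2:4 pruning first, then pruning-preserving quantization-aware training (PQAT) via Default8BitPrunePreserveQuantizeScheme. The toy model, data, optimizer settings and the strip-before-annotate ordering follow the upstream tensorflow_model_optimization PQAT recipe and are illustrative assumptions, not code from this patch; only the tfmot calls mirror Sparsity24Rewrite / QATRewrite, and 2:4 structural sparsity support assumes a recent tensorflow_model_optimization release.

import numpy as np
import tensorflow as tf
import tensorflow_model_optimization as tfmot

# Toy regression data and a small fully-connected model (placeholders, not from the patch).
x = np.random.rand(256, 16).astype("float32")
y = np.random.rand(256, 4).astype("float32")
model = tf.keras.Sequential(
    [
        tf.keras.layers.Dense(32, activation="relu", input_shape=(16,)),
        tf.keras.layers.Dense(4),
    ]
)

# Phase 1: wrap the model with 2:4 structured pruning and train briefly.
# UpdatePruningStep is the callback returned by Sparsity24Rewrite.training_callbacks().
pruned = tfmot.sparsity.keras.prune_low_magnitude(model, sparsity_m_by_n=(2, 4))
pruned.compile(optimizer="nadam", loss="mse", metrics=["mae"])
pruned.fit(x, y, epochs=1, callbacks=[tfmot.sparsity.keras.UpdatePruningStep()])

# Phase 2: pruning-preserving QAT, as in QATRewrite.pruning_preserved_quantization().
# Default8BitPrunePreserveQuantizeScheme keeps the 2:4 zero pattern while learning
# int8-friendly weights and activations.
stripped = tfmot.sparsity.keras.strip_pruning(pruned)
annotated = tfmot.quantization.keras.quantize_annotate_model(stripped)
pqat_model = tfmot.quantization.keras.quantize_apply(
    annotated,
    tfmot.experimental.combine.Default8BitPrunePreserveQuantizeScheme(),
)
pqat_model.compile(optimizer="nadam", loss="mse", metrics=["mae"])
pqat_model.fit(x, y, epochs=1)

In the patch itself the second phase reuses model_fit() with the pruning callback popped off the callback list, and rewrite.post_process() (strip_pruning for the sparsity rewrite) runs on the final checkpoint before it is saved back to TFLite.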