aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMadeleine Dunn <madeleine.dunn@arm.com>2024-02-21 17:10:07 +0000
committerMadeleine Dunn <madeleine.dunn@arm.com>2024-04-04 15:26:36 +0100
commit1ebb335cba516bcf973b041efa6a9878d1022b93 (patch)
tree9038cc30c9f32403b715506abbd76f59cbf3d6a6
parent17813ba5be09f0e11fc0748afa4ccf2da02881b6 (diff)
downloadmlia-1ebb335cba516bcf973b041efa6a9878d1022b93.tar.gz
feat: Implement int8 sparsity 2:4 rewrite
- Implement pruning-preserving quantisation aware training - Rework the training logic to avoid duplication - Remove the DynamicallyLoadedRewrite class as it is now unused Resolves: MLIA-1003 Signed-off-by: Madeleine Dunn <madeleine.dunn@arm.com> Change-Id: Ia7a4acf5f477a27963cffa88180cca085b32ffe4
-rw-r--r--src/mlia/nn/rewrite/core/rewrite.py94
-rw-r--r--src/mlia/nn/rewrite/core/train.py95
-rw-r--r--tests/test_nn_rewrite_core_rewrite.py71
-rw-r--r--tests/test_nn_rewrite_core_train.py12
4 files changed, 163 insertions, 109 deletions
diff --git a/src/mlia/nn/rewrite/core/rewrite.py b/src/mlia/nn/rewrite/core/rewrite.py
index 2a7b432..4fe1c26 100644
--- a/src/mlia/nn/rewrite/core/rewrite.py
+++ b/src/mlia/nn/rewrite/core/rewrite.py
@@ -3,14 +3,14 @@
"""Contains class RewritingOptimizer to replace a subgraph/layer of a model."""
from __future__ import annotations
-import importlib
import logging
import tempfile
+from abc import ABC
+from abc import abstractmethod
from dataclasses import dataclass
from pathlib import Path
from typing import Any
from typing import Callable
-from typing import cast
import tensorflow_model_optimization as tfmot
from keras.api._v2 import keras # Temporary workaround for now: MLIA-1107
@@ -35,8 +35,8 @@ logger = logging.getLogger(__name__)
RewriteCallable = Callable[[Any, Any], keras.Model]
-class Rewrite:
- """Graph rewrite logic to be used by RewritingOptimizer."""
+class Rewrite(ABC):
+ """Abstract class for rewrite logic to be used by RewritingOptimizer."""
def __init__(self, name: str, rewrite_fn: RewriteCallable):
"""Initialize a Rewrite instance with a given name and an optional function."""
@@ -50,10 +50,42 @@ class Rewrite:
except Exception as ex:
raise RuntimeError(f"Rewrite '{self.name}' failed.") from ex
- def quantize(self, model: keras.Model, model_is_quantized: bool) -> keras.Model:
+ @abstractmethod
+ def quantize(self, model: keras.Model) -> keras.Model:
"""Return a quantized model if required."""
- if model_is_quantized:
- model = tfmot.quantization.keras.quantize_model(model)
+
+ @abstractmethod
+ def training_callbacks(self) -> list:
+ """Return default rewrite callbacks."""
+
+ @abstractmethod
+ def post_process(self, model: keras.Model) -> keras.Model:
+ """Return default post-processing rewrite options."""
+
+
+class QATRewrite(Rewrite):
+ """Logic for rewrites requiring quantization-aware training."""
+
+ def pruning_preserved_quantization(
+ self,
+ model: keras.Model,
+ ) -> keras.Model:
+ """Apply pruning-preserved quantization training to a given model."""
+ model = tfmot.quantization.keras.quantize_annotate_model(model)
+ model = tfmot.quantization.keras.quantize_apply(
+ model,
+ tfmot.experimental.combine.Default8BitPrunePreserveQuantizeScheme(),
+ )
+
+ return model
+
+
+class FullyConnectedRewrite(Rewrite):
+ """Graph rewrite logic for fully-connected rewrite."""
+
+ def quantize(self, model: keras.Model) -> keras.Model:
+ """Return a quantized model if required."""
+ model = tfmot.quantization.keras.quantize_model(model)
return model
def training_callbacks(self) -> list:
@@ -65,18 +97,15 @@ class Rewrite:
return model
-class PruningRewrite(Rewrite):
- """Derived Rewrite class with pruning-specific logic."""
+class Sparsity24Rewrite(QATRewrite):
+ """Graph rewrite logic for fully-connected-sparsity24 rewrite."""
pruning_callback = tfmot.sparsity.keras.UpdatePruningStep
strip_pruning_wrapper = staticmethod(tfmot.sparsity.keras.strip_pruning)
- def quantize(self, model: keras.Model, model_is_quantized: bool) -> keras.Model:
- """Return a quantized model if required."""
- if model_is_quantized:
- # placeholder for PQAT
- pass
+ def quantize(self, model: keras.Model) -> keras.Model:
+ """Skip quantization when using pruning rewrite."""
return model
def training_callbacks(self) -> list:
@@ -88,35 +117,6 @@ class PruningRewrite(Rewrite):
return self.strip_pruning_wrapper(model)
-@dataclass
-class DynamicallyLoadedRewrite(Rewrite):
- """A rewrite which can load logic from a function loaded dynamically."""
-
- def __init__(self, name: str, function_name: str):
- """Initialize."""
-
- def load_and_run(input_shape: Any, output_shape: Any) -> keras.Model:
- """Load the function from a file dynamically."""
- self.load_function(function_name)
- return self.function(input_shape, output_shape)
-
- super().__init__(name, load_and_run)
-
- def load_function(self, function_name: str) -> RewriteCallable:
- """Return the rewrite function. Import using the auto_load attr if necessary."""
- try:
- name_parts = function_name.split(".")
- module_name = ".".join(name_parts[:-1])
- fn_name = name_parts[-1]
- module = importlib.import_module(module_name)
- self.function = cast(RewriteCallable, getattr(module, fn_name))
- return self.function
- except Exception as ex:
- raise RuntimeError(
- f"Unable to load rewrite function '{function_name}' for '{self.name}'."
- ) from ex
-
-
class RewriteRegistry(Registry[Rewrite]):
"""Registry rewrite functions."""
@@ -155,8 +155,8 @@ class RewritingOptimizer(Optimizer):
registry = RewriteRegistry(
[
- Rewrite("fully-connected", fc_rewrite),
- PruningRewrite("fully-connected-sparsity24", fc_rewrite_sparsity24),
+ FullyConnectedRewrite("fully-connected", fc_rewrite),
+ Sparsity24Rewrite("fully-connected-sparsity24", fc_rewrite_sparsity24),
]
)
@@ -178,6 +178,7 @@ class RewritingOptimizer(Optimizer):
rewrite = RewritingOptimizer.registry.items[
self.optimizer_configuration.optimization_target
]
+ is_qat = isinstance(rewrite, QATRewrite)
use_unmodified_model = True
tflite_model = self.model.model_path
tfrecord = str(self.optimizer_configuration.dataset)
@@ -190,7 +191,6 @@ class RewritingOptimizer(Optimizer):
"Input and output tensor names need to be set for rewrite."
)
- self.optimizer_configuration.train_params.checkpoint_at = [5000, 10000]
orig_vs_repl_stats, total_stats = train(
source_model=tflite_model,
unmodified_model=tflite_model if use_unmodified_model else None,
@@ -199,6 +199,7 @@ class RewritingOptimizer(Optimizer):
rewrite=rewrite,
input_tensors=[self.optimizer_configuration.layers_to_optimize[0]],
output_tensors=[self.optimizer_configuration.layers_to_optimize[1]],
+ is_qat=is_qat,
train_params=self.optimizer_configuration.train_params,
)
@@ -245,6 +246,7 @@ class RewritingOptimizer(Optimizer):
notes=notes,
)
logger.info(table.to_plain_text(show_title=True))
+ self.model = TFLiteModel(tmp_output)
def get_model(self) -> TFLiteModel:
"""Return optimized model."""
diff --git a/src/mlia/nn/rewrite/core/train.py b/src/mlia/nn/rewrite/core/train.py
index 89de880..4b9821c 100644
--- a/src/mlia/nn/rewrite/core/train.py
+++ b/src/mlia/nn/rewrite/core/train.py
@@ -1,6 +1,7 @@
# SPDX-FileCopyrightText: Copyright 2023-2024, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Sequential trainer."""
+# pylint: disable=too-many-arguments
# pylint: disable=too-many-locals
# pylint: disable=too-many-statements
from __future__ import annotations
@@ -80,6 +81,7 @@ def train(
rewrite: Callable,
input_tensors: list,
output_tensors: list,
+ is_qat: bool,
train_params: TrainingParameters = TrainingParameters(),
) -> Any:
"""Extract and train a model, and return the results."""
@@ -118,6 +120,7 @@ def train(
baseline_dir=unmodified_model_dir_path,
output_filename=Path(train_dir, "new.tflite"),
rewrite=rewrite,
+ is_qat=is_qat,
train_params=train_params,
)
@@ -345,6 +348,7 @@ def train_in_dir(
baseline_dir: Any,
output_filename: Path,
rewrite: Callable,
+ is_qat: bool,
train_params: TrainingParameters = TrainingParameters(),
) -> list[str]:
"""Train a replacement for replace.tflite using the input.tfrec \
@@ -385,8 +389,9 @@ def train_in_dir(
optimizer = keras.optimizers.Nadam(learning_rate=train_params.learning_rate)
loss_fn = keras.losses.MeanSquaredError()
- model = rewrite.quantize(model, model_is_quantized) # type: ignore[attr-defined]
- model.compile(optimizer=optimizer, loss=loss_fn, metrics=["mae"])
+ if model_is_quantized:
+ model = rewrite.quantize(model) # type: ignore[attr-defined]
+ model = model_compile(model, optimizer, loss_fn)
logger.info(model.summary())
@@ -428,11 +433,82 @@ def train_in_dir(
callbacks.extend(rewrite.training_callbacks()) # type: ignore[attr-defined]
- output_filenames = []
+ output_filenames = [] # type: list[str]
checkpoints = (train_params.checkpoint_at if train_params.checkpoint_at else []) + [
train_params.steps
]
+ model, output_filenames = model_fit(
+ model,
+ train_params,
+ checkpoints.copy(),
+ optimizer,
+ dataset,
+ callbacks,
+ output_filename,
+ rewrite,
+ replace,
+ input_name,
+ output_name,
+ model_is_quantized,
+ output_filenames,
+ )
+
+ if model_is_quantized and is_qat:
+ model = rewrite.pruning_preserved_quantization( # type: ignore[attr-defined]
+ model,
+ )
+ optimizer = keras.optimizers.Nadam(learning_rate=train_params.learning_rate)
+ model = model_compile(model, optimizer, loss_fn)
+
+ callbacks.pop(-1)
+ output_filenames = []
+
+ model, output_filenames = model_fit(
+ model,
+ train_params,
+ checkpoints.copy(),
+ optimizer,
+ dataset,
+ callbacks,
+ output_filename,
+ rewrite,
+ replace,
+ input_name,
+ output_name,
+ model_is_quantized,
+ output_filenames,
+ )
+
+ teacher.close()
+ return output_filenames
+
+
+def model_compile(
+ model: tf.keras.Model, optimizer: tf.keras.optimizers, loss_fn: tf.keras.losses
+) -> tf.keras.Model:
+ """Compiles a tflite model."""
+ model.compile(optimizer=optimizer, loss=loss_fn, metrics=["mae"])
+ return model
+
+
+def model_fit(
+ model: tf.keras.Model,
+ train_params: TrainingParameters,
+ checkpoints: list,
+ optimizer: tf.optimizers.Nadam,
+ dataset: tf.data.Dataset,
+ callbacks: list,
+ output_filename: Path,
+ rewrite: Callable,
+ replace: TFLiteModel,
+ input_name: str,
+ output_name: str,
+ model_is_quantized: bool,
+ output_filenames: list,
+) -> tuple[tf.keras.Model, list]:
+ """Train the model."""
+ steps_so_far = 0
while steps_so_far < train_params.steps:
steps_to_train = checkpoints.pop(0) - steps_so_far
lr_start = optimizer.learning_rate.numpy()
@@ -460,15 +536,16 @@ def train_in_dir(
)
else:
checkpoint_filename = str(output_filename)
+
+ if steps_so_far == train_params.steps:
+ model = rewrite.post_process(model) # type: ignore[attr-defined]
+
with log_action(
f"{steps_so_far}/{train_params.steps}: Saved as {checkpoint_filename}"
):
- if steps_so_far == train_params.steps:
- model = rewrite.post_process(model) # type: ignore[attr-defined]
-
save_as_tflite(
model,
- checkpoint_filename,
+ str(checkpoint_filename),
input_name,
replace.shape_from_name[input_name],
output_name,
@@ -476,9 +553,7 @@ def train_in_dir(
model_is_quantized,
)
output_filenames.append(checkpoint_filename)
-
- teacher.close()
- return output_filenames
+ return model, output_filenames
def save_as_tflite(
diff --git a/tests/test_nn_rewrite_core_rewrite.py b/tests/test_nn_rewrite_core_rewrite.py
index e614cad..8ef5bd2 100644
--- a/tests/test_nn_rewrite_core_rewrite.py
+++ b/tests/test_nn_rewrite_core_rewrite.py
@@ -11,45 +11,51 @@ from unittest.mock import MagicMock
import pytest
-from mlia.nn.rewrite.core.rewrite import DynamicallyLoadedRewrite
-from mlia.nn.rewrite.core.rewrite import Rewrite
+from mlia.nn.rewrite.core.rewrite import FullyConnectedRewrite
from mlia.nn.rewrite.core.rewrite import RewriteCallable
from mlia.nn.rewrite.core.rewrite import RewriteConfiguration
from mlia.nn.rewrite.core.rewrite import RewriteRegistry
from mlia.nn.rewrite.core.rewrite import RewritingOptimizer
+from mlia.nn.rewrite.core.rewrite import Sparsity24Rewrite
from mlia.nn.rewrite.core.rewrite import TrainingParameters
from mlia.nn.rewrite.core.train import train_in_dir
from mlia.nn.tensorflow.config import TFLiteModel
from tests.utils.rewrite import MockTrainingParameters
-def mock_rewrite_function(*_: Any) -> Any:
- """Mock function to test autoloading of rewrite functions."""
-
-
def test_rewrite() -> None:
- """Test the Rewrite class."""
+ """Test a derived Rewrite class."""
def bad_rewrite_func() -> Any:
raise NotImplementedError()
- rewrite = Rewrite("BAD_REWRITE", rewrite_fn=cast(RewriteCallable, bad_rewrite_func))
+ rewrite = Sparsity24Rewrite(
+ "BAD_REWRITE", rewrite_fn=cast(RewriteCallable, bad_rewrite_func)
+ )
with pytest.raises(RuntimeError):
rewrite((1, 2), (1, 2))
@pytest.mark.parametrize(
- "rewrite_name, callbacks_length",
+ "rewrite_name, rewrite_class",
[
- ("fully-connected", 0),
- ("fully-connected-sparsity24", 1),
+ ("fully-connected", FullyConnectedRewrite),
+ ("fully-connected-sparsity24", Sparsity24Rewrite),
],
)
-def test_rewrite_selection(rewrite_name: str, callbacks_length: int) -> None:
- """Test that the correct rewrite class is instantiated."""
- rewrite = RewritingOptimizer.registry.items[rewrite_name]
+def test_rewrite_selection(
+ rewrite_name: str,
+ rewrite_class: Any,
+) -> None:
+ """Check that the correct rewrite class is instantiated through the registry"""
+ config_obj = RewriteConfiguration(
+ rewrite_name,
+ ["sample_node_start", "sample_node_end"],
+ )
+
+ rewrite = RewritingOptimizer.registry.items[config_obj.optimization_target]
assert rewrite.name == rewrite_name
- assert len(rewrite.training_callbacks()) == callbacks_length
+ assert isinstance(rewrite, rewrite_class)
@pytest.mark.parametrize(
@@ -106,8 +112,8 @@ def test_register_rewrite_function() -> None:
"""Test adding rewrite functions and verify they are reported via the registry."""
registry = RewriteRegistry()
- rewrite1 = Rewrite("r1", cast(RewriteCallable, lambda: 1))
- rewrite2 = Rewrite("r2", cast(RewriteCallable, lambda: 2))
+ rewrite1 = FullyConnectedRewrite("r1", cast(RewriteCallable, lambda: 1))
+ rewrite2 = Sparsity24Rewrite("r2", cast(RewriteCallable, lambda: 2))
registry.register_rewrite(rewrite1)
registry.register_rewrite(rewrite2)
@@ -122,37 +128,6 @@ def test_builtin_rewrite_names() -> None:
]
-def test_rewrite_function_autoload() -> None:
- """Test rewrite function loading."""
- function_name = "tests.test_nn_rewrite_core_rewrite.mock_rewrite_function"
- rewrite = DynamicallyLoadedRewrite(name="mock_rewrite", function_name=function_name)
- assert rewrite.name == "mock_rewrite"
-
- assert rewrite.function is not mock_rewrite_function
- assert rewrite.load_function(function_name) is mock_rewrite_function
- assert rewrite.function is mock_rewrite_function
-
-
-def test_rewrite_function_autoload_fail() -> None:
- """Test rewrite function loading failure."""
- function_name = "invalid_module.invalid_function"
- rewrite = DynamicallyLoadedRewrite(
- name="mock_rewrite",
- function_name="invalid_module.invalid_function",
- )
- assert rewrite.name == "mock_rewrite"
-
- with pytest.raises(Exception) as exc_info:
- rewrite.load_function(function_name)
-
- message = exc_info.value.args[0]
-
- assert message == (
- "Unable to load rewrite function 'invalid_module.invalid_function'"
- " for 'mock_rewrite'."
- )
-
-
def test_rewrite_configuration_train_params(
test_tflite_model_fp32: Path,
test_tfrecord_fp32: Path,
diff --git a/tests/test_nn_rewrite_core_train.py b/tests/test_nn_rewrite_core_train.py
index 34b9543..371c79f 100644
--- a/tests/test_nn_rewrite_core_train.py
+++ b/tests/test_nn_rewrite_core_train.py
@@ -14,13 +14,15 @@ import pytest
import tensorflow as tf
from keras.api._v2 import keras # Temporary workaround for now: MLIA-1107
-from mlia.nn.rewrite.core.rewrite import DynamicallyLoadedRewrite
+from mlia.nn.rewrite.core.rewrite import FullyConnectedRewrite
+from mlia.nn.rewrite.core.rewrite import QATRewrite
from mlia.nn.rewrite.core.train import augment_fn_twins
from mlia.nn.rewrite.core.train import AUGMENTATION_PRESETS
from mlia.nn.rewrite.core.train import LearningRateSchedule
from mlia.nn.rewrite.core.train import mixup
from mlia.nn.rewrite.core.train import train
from mlia.nn.rewrite.core.train import TrainingParameters
+from mlia.nn.rewrite.library.fc_layer import get_keras_model as fc_rewrite
from tests.utils.rewrite import MockTrainingParameters
@@ -54,12 +56,11 @@ def check_train(
"""Test the train() function."""
with TemporaryDirectory() as tmp_dir:
output_file = Path(tmp_dir, "out.tflite")
- mock_rewrite = DynamicallyLoadedRewrite(
+ mock_rewrite = FullyConnectedRewrite(
name="replace",
- function_name=(
- "tests.test_nn_rewrite_core_train.replace_fully_connected_with_conv"
- ),
+ rewrite_fn=fc_rewrite,
)
+ is_qat = isinstance(mock_rewrite, QATRewrite)
result = train(
source_model=str(tflite_model),
unmodified_model=str(tflite_model) if use_unmodified_model else None,
@@ -68,6 +69,7 @@ def check_train(
rewrite=mock_rewrite,
input_tensors=["sequential/flatten/Reshape"],
output_tensors=["StatefulPartitionedCall:0"],
+ is_qat=is_qat,
train_params=train_params,
)