about | summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
authorMadeleine Dunn <madeleine.dunn@arm.com>2023-11-13 15:40:21 +0000
committerMadeleine Dunn <madeleine.dunn@arm.com>2024-04-03 16:33:39 +0100
commit17813ba5be09f0e11fc0748afa4ccf2da02881b6 (patch)
tree8ec5f3ce3501b86e9398cf5af6f7bd9876685512
parent2a2a910d6d7cc3e7555b0a3c1ba458a4065c41ae (diff)
downloadmlia-17813ba5be09f0e11fc0748afa4ccf2da02881b6.tar.gz
feat: Implement fp32 sparsity 2:4 rewrite
- Update the existing placeholder with code to prune the given model

Resolves: MLIA-1002

Signed-off-by: Madeleine Dunn <madeleine.dunn@arm.com>
Change-Id: I76b0e0bfe81be5e57d518cd7bb588eef76a11641
-rw-r--r--src/mlia/nn/rewrite/core/rewrite.py55
-rw-r--r--src/mlia/nn/rewrite/core/train.py19
-rw-r--r--src/mlia/nn/rewrite/library/fc_sparsity24_layer.py21
-rw-r--r--tests/test_nn_rewrite_core_rewrite.py16
-rw-r--r--tests/test_nn_rewrite_core_train.py9
5 files changed, 94 insertions, 26 deletions
diff --git a/src/mlia/nn/rewrite/core/rewrite.py b/src/mlia/nn/rewrite/core/rewrite.py
index a4d47c4..2a7b432 100644
--- a/src/mlia/nn/rewrite/core/rewrite.py
+++ b/src/mlia/nn/rewrite/core/rewrite.py
@@ -12,6 +12,7 @@ from typing import Any
from typing import Callable
from typing import cast
+import tensorflow_model_optimization as tfmot
from keras.api._v2 import keras # Temporary workaround for now: MLIA-1107
from mlia.core.errors import ConfigurationError
@@ -22,6 +23,10 @@ from mlia.nn.common import Optimizer
from mlia.nn.common import OptimizerConfiguration
from mlia.nn.rewrite.core.train import train
from mlia.nn.rewrite.core.train import TrainingParameters
+from mlia.nn.rewrite.library.fc_layer import get_keras_model as fc_rewrite
+from mlia.nn.rewrite.library.fc_sparsity24_layer import (
+ get_keras_model as fc_rewrite_sparsity24,
+)
from mlia.nn.tensorflow.config import TFLiteModel
from mlia.utils.registry import Registry
@@ -45,6 +50,43 @@ class Rewrite:
except Exception as ex:
raise RuntimeError(f"Rewrite '{self.name}' failed.") from ex
+ def quantize(self, model: keras.Model, model_is_quantized: bool) -> keras.Model:
+ """Return a quantized model if required."""
+ if model_is_quantized:
+ model = tfmot.quantization.keras.quantize_model(model)
+ return model
+
+ def training_callbacks(self) -> list:
+ """Return default rewrite callbacks."""
+ return []
+
+ def post_process(self, model: keras.Model) -> keras.Model:
+ """Return default post-processing rewrite options."""
+ return model
+
+
+class PruningRewrite(Rewrite):
+ """Derived Rewrite class with pruning-specific logic."""
+
+ pruning_callback = tfmot.sparsity.keras.UpdatePruningStep
+
+ strip_pruning_wrapper = staticmethod(tfmot.sparsity.keras.strip_pruning)
+
+ def quantize(self, model: keras.Model, model_is_quantized: bool) -> keras.Model:
+ """Return a quantized model if required."""
+ if model_is_quantized:
+ # placeholder for PQAT
+ pass
+ return model
+
+ def training_callbacks(self) -> list:
+ """Return pruning-specific rewrite callback."""
+ return [self.pruning_callback()]
+
+ def post_process(self, model: keras.Model) -> keras.Model:
+ """Pruning-specific post-processing rewrite options."""
+ return self.strip_pruning_wrapper(model)
+
@dataclass
class DynamicallyLoadedRewrite(Rewrite):
@@ -113,14 +155,8 @@ class RewritingOptimizer(Optimizer):
registry = RewriteRegistry(
[
- DynamicallyLoadedRewrite(
- "fully-connected",
- "mlia.nn.rewrite.library.fc_layer.get_keras_model",
- ),
- DynamicallyLoadedRewrite(
- "fully-connected-sparsity24",
- "mlia.nn.rewrite.library.fc_sparsity24_layer.get_keras_model24",
- ),
+ Rewrite("fully-connected", fc_rewrite),
+ PruningRewrite("fully-connected-sparsity24", fc_rewrite_sparsity24),
]
)
@@ -142,7 +178,6 @@ class RewritingOptimizer(Optimizer):
rewrite = RewritingOptimizer.registry.items[
self.optimizer_configuration.optimization_target
]
-
use_unmodified_model = True
tflite_model = self.model.model_path
tfrecord = str(self.optimizer_configuration.dataset)
@@ -161,7 +196,7 @@ class RewritingOptimizer(Optimizer):
unmodified_model=tflite_model if use_unmodified_model else None,
output_model=str(tmp_output),
input_tfrec=str(tfrecord),
- replace_fn=rewrite,
+ rewrite=rewrite,
input_tensors=[self.optimizer_configuration.layers_to_optimize[0]],
output_tensors=[self.optimizer_configuration.layers_to_optimize[1]],
train_params=self.optimizer_configuration.train_params,
diff --git a/src/mlia/nn/rewrite/core/train.py b/src/mlia/nn/rewrite/core/train.py
index e0b3c75..89de880 100644
--- a/src/mlia/nn/rewrite/core/train.py
+++ b/src/mlia/nn/rewrite/core/train.py
@@ -22,7 +22,6 @@ from typing import Literal
import numpy as np
import tensorflow as tf
-import tensorflow_model_optimization as tfmot
from keras.api._v2 import keras # Temporary workaround for now: MLIA-1107
from numpy.random import Generator
@@ -78,7 +77,7 @@ def train(
unmodified_model: Any,
output_model: str,
input_tfrec: str,
- replace_fn: Callable,
+ rewrite: Callable,
input_tensors: list,
output_tensors: list,
train_params: TrainingParameters = TrainingParameters(),
@@ -118,7 +117,7 @@ def train(
train_dir=train_dir,
baseline_dir=unmodified_model_dir_path,
output_filename=Path(train_dir, "new.tflite"),
- replace_fn=replace_fn,
+ rewrite=rewrite,
train_params=train_params,
)
@@ -345,7 +344,7 @@ def train_in_dir(
train_dir: str,
baseline_dir: Any,
output_filename: Path,
- replace_fn: Callable,
+ rewrite: Callable,
train_params: TrainingParameters = TrainingParameters(),
) -> list[str]:
"""Train a replacement for replace.tflite using the input.tfrec \
@@ -381,13 +380,12 @@ def train_in_dir(
input_shape = teacher.shape_from_name[input_name][1:]
output_shape = teacher.shape_from_name[output_name][1:]
-
- model = replace_fn(input_shape, output_shape)
+ model = rewrite(input_shape, output_shape)
optimizer = keras.optimizers.Nadam(learning_rate=train_params.learning_rate)
loss_fn = keras.losses.MeanSquaredError()
- if model_is_quantized:
- model = tfmot.quantization.keras.quantize_model(model)
+
+ model = rewrite.quantize(model, model_is_quantized) # type: ignore[attr-defined]
model.compile(optimizer=optimizer, loss=loss_fn, metrics=["mae"])
logger.info(model.summary())
@@ -428,6 +426,8 @@ def train_in_dir(
elif train_params.learning_rate_schedule == "constant":
callbacks = []
+ callbacks.extend(rewrite.training_callbacks()) # type: ignore[attr-defined]
+
output_filenames = []
checkpoints = (train_params.checkpoint_at if train_params.checkpoint_at else []) + [
train_params.steps
@@ -463,6 +463,9 @@ def train_in_dir(
with log_action(
f"{steps_so_far}/{train_params.steps}: Saved as {checkpoint_filename}"
):
+ if steps_so_far == train_params.steps:
+ model = rewrite.post_process(model) # type: ignore[attr-defined]
+
save_as_tflite(
model,
checkpoint_filename,
diff --git a/src/mlia/nn/rewrite/library/fc_sparsity24_layer.py b/src/mlia/nn/rewrite/library/fc_sparsity24_layer.py
index 1b17522..531b34a 100644
--- a/src/mlia/nn/rewrite/library/fc_sparsity24_layer.py
+++ b/src/mlia/nn/rewrite/library/fc_sparsity24_layer.py
@@ -1,14 +1,23 @@
-# SPDX-FileCopyrightText: Copyright 2023, Arm Limited and/or its affiliates.
+# SPDX-FileCopyrightText: Copyright 2024, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
-"""PLACEHOLDER for example rewrite with one fully connected 2:4 sparsity layer."""
+"""Example rewrite with one fully connected 2:4 sparsity layer."""
from typing import Any
+import tensorflow_model_optimization as tfmot
from keras.api._v2 import keras # Temporary workaround for now: MLIA-1107
-from .fc_layer import get_keras_model
-
-def get_keras_model24(input_shape: Any, output_shape: Any) -> keras.Model:
+def get_keras_model(input_shape: Any, output_shape: Any) -> keras.Model:
"""Generate TensorFlow Lite model for rewrite."""
- model = get_keras_model(input_shape, output_shape)
+ model = tfmot.sparsity.keras.prune_low_magnitude(
+ to_prune=keras.Sequential(
+ [
+ keras.layers.InputLayer(input_shape=input_shape),
+ keras.layers.Reshape([-1]),
+ keras.layers.Dense(output_shape),
+ ]
+ ),
+ sparsity_m_by_n=(2, 4),
+ )
+
return model
diff --git a/tests/test_nn_rewrite_core_rewrite.py b/tests/test_nn_rewrite_core_rewrite.py
index 1d0100a..e614cad 100644
--- a/tests/test_nn_rewrite_core_rewrite.py
+++ b/tests/test_nn_rewrite_core_rewrite.py
@@ -39,6 +39,20 @@ def test_rewrite() -> None:
@pytest.mark.parametrize(
+ "rewrite_name, callbacks_length",
+ [
+ ("fully-connected", 0),
+ ("fully-connected-sparsity24", 1),
+ ],
+)
+def test_rewrite_selection(rewrite_name: str, callbacks_length: int) -> None:
+ """Test that the correct rewrite class is instantiated."""
+ rewrite = RewritingOptimizer.registry.items[rewrite_name]
+ assert rewrite.name == rewrite_name
+ assert len(rewrite.training_callbacks()) == callbacks_length
+
+
+@pytest.mark.parametrize(
"rewrite_name, expected_error",
[
("fully-connected", does_not_raise()),
@@ -89,7 +103,7 @@ def test_rewriting_optimizer(
def test_register_rewrite_function() -> None:
- """Test adding rewrite functions and verify the are reported via the registry."""
+ """Test adding rewrite functions and verify they are reported via the registry."""
registry = RewriteRegistry()
rewrite1 = Rewrite("r1", cast(RewriteCallable, lambda: 1))
diff --git a/tests/test_nn_rewrite_core_train.py b/tests/test_nn_rewrite_core_train.py
index 624c5ed..34b9543 100644
--- a/tests/test_nn_rewrite_core_train.py
+++ b/tests/test_nn_rewrite_core_train.py
@@ -14,6 +14,7 @@ import pytest
import tensorflow as tf
from keras.api._v2 import keras # Temporary workaround for now: MLIA-1107
+from mlia.nn.rewrite.core.rewrite import DynamicallyLoadedRewrite
from mlia.nn.rewrite.core.train import augment_fn_twins
from mlia.nn.rewrite.core.train import AUGMENTATION_PRESETS
from mlia.nn.rewrite.core.train import LearningRateSchedule
@@ -53,12 +54,18 @@ def check_train(
"""Test the train() function."""
with TemporaryDirectory() as tmp_dir:
output_file = Path(tmp_dir, "out.tflite")
+ mock_rewrite = DynamicallyLoadedRewrite(
+ name="replace",
+ function_name=(
+ "tests.test_nn_rewrite_core_train.replace_fully_connected_with_conv"
+ ),
+ )
result = train(
source_model=str(tflite_model),
unmodified_model=str(tflite_model) if use_unmodified_model else None,
output_model=str(output_file),
input_tfrec=str(tfrecord),
- replace_fn=replace_fully_connected_with_conv,
+ rewrite=mock_rewrite,
input_tensors=["sequential/flatten/Reshape"],
output_tensors=["StatefulPartitionedCall:0"],
train_params=train_params,