aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorNathan Bailey <nathan.bailey@arm.com>2024-02-27 12:46:52 +0000
committerNathan Bailey <nathan.bailey@arm.com>2024-04-16 12:57:10 +0100
commit427e02696f1ede596ef6dce82787a37e122efa78 (patch)
tree1fae7f7c8cb10af4f7c5119b73371b709c2c7caa
parent2973b6d52914023f9b8797aec8309957457d4189 (diff)
downloadmlia-427e02696f1ede596ef6dce82787a37e122efa78.tar.gz
feat: Implement the clustering rewrite for fp32
Implements a clustering rewrite for fully connected layers for fp32 models Resolves: MLIA-1079 Signed-off-by: Nathan Bailey <nathan.bailey@arm.com> Change-Id: I4c12f0bf911219b4066f0760976e424ebe900a0b
-rw-r--r--src/mlia/nn/rewrite/core/rewrite.py24
-rw-r--r--src/mlia/nn/rewrite/library/fc_clustering_layer.py19
-rw-r--r--tests/test_nn_rewrite_core_rewrite.py49
3 files changed, 72 insertions, 20 deletions
diff --git a/src/mlia/nn/rewrite/core/rewrite.py b/src/mlia/nn/rewrite/core/rewrite.py
index a8084e8..6a3695a 100644
--- a/src/mlia/nn/rewrite/core/rewrite.py
+++ b/src/mlia/nn/rewrite/core/rewrite.py
@@ -23,6 +23,9 @@ from mlia.nn.common import Optimizer
from mlia.nn.common import OptimizerConfiguration
from mlia.nn.rewrite.core.train import train
from mlia.nn.rewrite.core.train import TrainingParameters
+from mlia.nn.rewrite.library.fc_clustering_layer import (
+ get_keras_model_clus as fc_clustering_rewrite,
+)
from mlia.nn.rewrite.library.fc_layer import get_keras_model as fc_rewrite
from mlia.nn.rewrite.library.fc_sparsity24_layer import (
get_keras_model as fc_rewrite_sparsity24,
@@ -63,6 +66,24 @@ class Rewrite(ABC):
"""Return default post-processing rewrite options."""
+class ClusteringRewrite(Rewrite):
+ """Graph clustering rewrite logic to be used by RewritingOptimizer."""
+
+ strip_pruning_wrapper = staticmethod(tfmot.clustering.keras.strip_clustering)
+
+ def quantize(self, model: keras.Model) -> keras.Model:
+ """Return a quantized model."""
+ return model
+
+ def post_process(self, model: keras.Model) -> keras.Model:
+ """Return the clustering stripped model."""
+ return self.strip_pruning_wrapper(model)
+
+ def training_callbacks(self) -> list:
+ """Return default rewrite callbacks."""
+ return []
+
+
class QATRewrite(Rewrite):
"""Logic for rewrites requiring quantization-aware training."""
@@ -157,7 +178,7 @@ class RewritingOptimizer(Optimizer):
[
FullyConnectedRewrite("fully-connected", fc_rewrite),
Sparsity24Rewrite("fully-connected-sparsity24", fc_rewrite_sparsity24),
- FullyConnectedRewrite("fully-connected-clustering", fc_rewrite),
+ ClusteringRewrite("fully-connected-clustering", fc_clustering_rewrite),
]
)
@@ -191,7 +212,6 @@ class RewritingOptimizer(Optimizer):
raise ConfigurationError(
"Input and output tensor names need to be set for rewrite."
)
-
orig_vs_repl_stats, total_stats = train(
source_model=tflite_model,
unmodified_model=tflite_model if use_unmodified_model else None,
diff --git a/src/mlia/nn/rewrite/library/fc_clustering_layer.py b/src/mlia/nn/rewrite/library/fc_clustering_layer.py
index 07c07ac..72931c0 100644
--- a/src/mlia/nn/rewrite/library/fc_clustering_layer.py
+++ b/src/mlia/nn/rewrite/library/fc_clustering_layer.py
@@ -3,11 +3,24 @@
"""Example rewrite with one fully connected clustered layer."""
from typing import Any
+import tensorflow_model_optimization as tfmot
from keras.api._v2 import keras # Temporary workaround for now: MLIA-1107
-from mlia.nn.rewrite.library.fc_layer import get_keras_model
-
def get_keras_model_clus(input_shape: Any, output_shape: Any) -> keras.Model:
"""Generate TensorFlow Lite model for clustering rewrite."""
- return get_keras_model(input_shape, output_shape)
+ clustering_params = {
+ "number_of_clusters": 32,
+ "cluster_centroids_init": tfmot.clustering.keras.CentroidInitialization.LINEAR,
+ }
+ model = tfmot.clustering.keras.cluster_weights(
+ to_cluster=keras.Sequential(
+ [
+ keras.layers.InputLayer(input_shape=input_shape),
+ keras.layers.Flatten(),
+ keras.layers.Dense(units=output_shape),
+ ]
+ ),
+ **clustering_params
+ )
+ return model
diff --git a/tests/test_nn_rewrite_core_rewrite.py b/tests/test_nn_rewrite_core_rewrite.py
index 96e4160..ef4df6a 100644
--- a/tests/test_nn_rewrite_core_rewrite.py
+++ b/tests/test_nn_rewrite_core_rewrite.py
@@ -10,8 +10,14 @@ from typing import cast
from unittest.mock import MagicMock
import pytest
+from keras.api._v2 import keras # Temporary workaround for now: MLIA-1107
+from tensorflow_model_optimization.python.core.clustering.keras.cluster_wrapper import ( # pylint: disable=no-name-in-module
+ ClusterWeights,
+)
+from mlia.nn.rewrite.core.rewrite import ClusteringRewrite
from mlia.nn.rewrite.core.rewrite import FullyConnectedRewrite
+from mlia.nn.rewrite.core.rewrite import Rewrite
from mlia.nn.rewrite.core.rewrite import RewriteCallable
from mlia.nn.rewrite.core.rewrite import RewriteConfiguration
from mlia.nn.rewrite.core.rewrite import RewriteRegistry
@@ -37,25 +43,21 @@ def test_rewrite() -> None:
@pytest.mark.parametrize(
- "rewrite_name, rewrite_class",
+ "rewrite_name, callbacks_length, instance",
[
- ("fully-connected", FullyConnectedRewrite),
- ("fully-connected-sparsity24", Sparsity24Rewrite),
+ ("fully-connected", 0, Rewrite),
+ ("fully-connected-clustering", 0, ClusteringRewrite),
+ ("fully-connected-sparsity24", 1, Sparsity24Rewrite),
],
)
def test_rewrite_selection(
- rewrite_name: str,
- rewrite_class: Any,
+ rewrite_name: str, callbacks_length: int, instance: Rewrite
) -> None:
- """Check that the correct rewrite class is instantiated through the registry"""
- config_obj = RewriteConfiguration(
- rewrite_name,
- ["sample_node_start", "sample_node_end"],
- )
-
- rewrite = RewritingOptimizer.registry.items[config_obj.optimization_target]
+ """Test that the correct rewrite class is instantiated."""
+ rewrite = RewritingOptimizer.registry.items[rewrite_name]
assert rewrite.name == rewrite_name
- assert isinstance(rewrite, rewrite_class)
+ assert isinstance(rewrite, instance) # type: ignore
+ assert len(rewrite.training_callbacks()) == callbacks_length
@pytest.mark.parametrize(
@@ -71,7 +73,7 @@ def test_rewrite_configuration(
test_tflite_model_fp32: Path, rewrite_name: str, expected_error: Any
) -> None:
"""Test get_rewrite function only supports rewrite types
- fully-connected and fully-connected-sparsity24."""
+ fully-connected, fully-connected-clustering and fully-connected-sparsity24."""
with expected_error:
config_obj = RewriteConfiguration(
rewrite_name,
@@ -86,19 +88,36 @@ def test_rewrite_configuration(
assert isinstance(rewriter_obj, RewritingOptimizer)
+@pytest.mark.parametrize(
+ "rewrite_type, expected_layers",
+ [
+ ["fully-connected", [keras.layers.Reshape, keras.layers.Dense]],
+ ["fully-connected-clustering", [ClusterWeights, ClusterWeights]],
+ ],
+)
def test_rewriting_optimizer(
test_tflite_model_fp32: Path,
test_tfrecord_fp32: Path,
+ rewrite_type: str,
+ expected_layers: list[object],
) -> None:
"""Test fc_layer rewrite process with rewrite type fully-connected."""
config_obj = RewriteConfiguration(
- "fully-connected",
+ rewrite_type,
["sequential/flatten/Reshape", "StatefulPartitionedCall:0"],
test_tfrecord_fp32,
train_params=MockTrainingParameters(),
)
test_obj = RewritingOptimizer(test_tflite_model_fp32, config_obj)
+ rewrite_function = RewritingOptimizer.registry.items[
+ test_obj.optimizer_configuration.optimization_target
+ ]
+ # Input, output shape does not matter, just need the test the layers are as expected
+ rewrite_model = rewrite_function(input_shape=(28, 28, 1), output_shape=12)
+ for idx, layer in enumerate(rewrite_model.layers):
+ assert isinstance(layer, expected_layers[idx]) # type: ignore
+
test_obj.apply_optimization()
trained_model = test_obj.get_model()