aboutsummaryrefslogtreecommitdiff
path: root/tests
diff options
context:
space:
mode:
authorNathan Bailey <nathan.bailey@arm.com>2024-03-08 14:08:06 +0000
committerNathan Bailey <nathan.bailey@arm.com>2024-04-16 13:11:31 +0100
commit32405c279d2f98c2d40bdbbb7f7306ff12c86cd6 (patch)
tree42781ca219b822a9ec9f212a9ee516f65b184a27 /tests
parent427e02696f1ede596ef6dce82787a37e122efa78 (diff)
downloadmlia-32405c279d2f98c2d40bdbbb7f7306ff12c86cd6.tar.gz
feat: Implement the clustering rewrite for int8
Implements a clustering rewrite for fully connected layers for int8 models Resolves: MLIA-1080 Signed-off-by: Nathan Bailey <nathan.bailey@arm.com> Change-Id: If48efb22764187a382e5b84bbb5c3b75a6e71b75
Diffstat (limited to 'tests')
-rw-r--r--tests/test_nn_rewrite_core_rewrite.py91
-rw-r--r--tests/test_nn_rewrite_core_train.py12
2 files changed, 81 insertions, 22 deletions
diff --git a/tests/test_nn_rewrite_core_rewrite.py b/tests/test_nn_rewrite_core_rewrite.py
index ef4df6a..e502842 100644
--- a/tests/test_nn_rewrite_core_rewrite.py
+++ b/tests/test_nn_rewrite_core_rewrite.py
@@ -10,13 +10,14 @@ from typing import cast
from unittest.mock import MagicMock
import pytest
+import tensorflow_model_optimization as tfmot
from keras.api._v2 import keras # Temporary workaround for now: MLIA-1107
from tensorflow_model_optimization.python.core.clustering.keras.cluster_wrapper import ( # pylint: disable=no-name-in-module
ClusterWeights,
)
from mlia.nn.rewrite.core.rewrite import ClusteringRewrite
-from mlia.nn.rewrite.core.rewrite import FullyConnectedRewrite
+from mlia.nn.rewrite.core.rewrite import GenericRewrite
from mlia.nn.rewrite.core.rewrite import Rewrite
from mlia.nn.rewrite.core.rewrite import RewriteCallable
from mlia.nn.rewrite.core.rewrite import RewriteConfiguration
@@ -25,17 +26,48 @@ from mlia.nn.rewrite.core.rewrite import RewritingOptimizer
from mlia.nn.rewrite.core.rewrite import Sparsity24Rewrite
from mlia.nn.rewrite.core.rewrite import TrainingParameters
from mlia.nn.rewrite.core.train import train_in_dir
+from mlia.nn.rewrite.library.fc_clustering_layer import (
+ get_keras_model_clus as fc_clustering_rewrite,
+)
from mlia.nn.tensorflow.config import TFLiteModel
from tests.utils.rewrite import MockTrainingParameters
+class TestRewrite(Rewrite):
+ """Test rewrite class."""
+
+ def quantize(self, model: keras.Model) -> keras.Model:
+ """Return a quantized model if required."""
+ return tfmot.quantization.keras.quantize_model(model)
+
+ def preserved_quantize(self, model: keras.Model) -> keras.Model:
+ """Not needed."""
+ return model
+
+ def training_callbacks(self) -> list:
+ """Return default rewrite callbacks."""
+ return []
+
+ def post_process(self, model: keras.Model) -> keras.Model:
+ """Return default post-processing rewrite options."""
+ return model
+
+ def check_optimization(self, model: keras.Model, **kwargs: dict) -> bool:
+ """Not needed here."""
+ return True
+
+
+def mock_rewrite_function(*_: Any) -> Any:
+ """Mock function to test autoloading of rewrite functions."""
+
+
def test_rewrite() -> None:
"""Test a derived Rewrite class."""
def bad_rewrite_func() -> Any:
raise NotImplementedError()
- rewrite = Sparsity24Rewrite(
+ rewrite = TestRewrite(
"BAD_REWRITE", rewrite_fn=cast(RewriteCallable, bad_rewrite_func)
)
with pytest.raises(RuntimeError):
@@ -45,7 +77,7 @@ def test_rewrite() -> None:
@pytest.mark.parametrize(
"rewrite_name, callbacks_length, instance",
[
- ("fully-connected", 0, Rewrite),
+ ("fully-connected", 0, GenericRewrite),
("fully-connected-clustering", 0, ClusteringRewrite),
("fully-connected-sparsity24", 1, Sparsity24Rewrite),
],
@@ -72,8 +104,8 @@ def test_rewrite_selection(
def test_rewrite_configuration(
test_tflite_model_fp32: Path, rewrite_name: str, expected_error: Any
) -> None:
- """Test get_rewrite function only supports rewrite types
- fully-connected, fully-connected-clustering and fully-connected-sparsity24."""
+ """Test get_rewrite function only supports rewrite type fully-connected,
+ fully-connected-clustering and fully-connected-sparsity24."""
with expected_error:
config_obj = RewriteConfiguration(
rewrite_name,
@@ -88,28 +120,61 @@ def test_rewrite_configuration(
assert isinstance(rewriter_obj, RewritingOptimizer)
+def test_rewrite_fully_connected_clustering() -> None:
+ """Check that model has the set number of clusters"""
+
+ rewrite = ClusteringRewrite("fully-connected-clustering", fc_clustering_rewrite)
+ model = rewrite(input_shape=(28, 28), output_shape=10)
+ model = rewrite.post_process(model)
+ assert rewrite.check_optimization(model, number_of_clusters=32)
+
+
+def test_rewrite_fully_connected_clustering_error_handling() -> None:
+ """Check that model has the set number of clusters
+ and that when quantized the number of clusters
+ remain."""
+
+ rewrite = ClusteringRewrite("fully-connected-clustering", fc_clustering_rewrite)
+ model = rewrite(input_shape=(28, 28), output_shape=10)
+ with pytest.raises(
+ ValueError,
+ match=(
+ r"Expected check_preserved_quantize to have argument number_of_clusters"
+ ),
+ ):
+ rewrite.check_optimization(model, bad_arg_name=25)
+
+
@pytest.mark.parametrize(
- "rewrite_type, expected_layers",
+ "rewrite_type, expected_layers, quant",
[
- ["fully-connected", [keras.layers.Reshape, keras.layers.Dense]],
- ["fully-connected-clustering", [ClusterWeights, ClusterWeights]],
+ ["fully-connected", [keras.layers.Reshape, keras.layers.Dense], False],
+ ["fully-connected-clustering", [ClusterWeights, ClusterWeights], False],
+ ["fully-connected-clustering", [ClusterWeights, ClusterWeights], True],
],
)
-def test_rewriting_optimizer(
+def test_rewriting_optimizer( # pylint: disable=too-many-locals
test_tflite_model_fp32: Path,
test_tfrecord_fp32: Path,
+ test_tflite_model: Path,
+ test_tfrecord: Path,
rewrite_type: str,
expected_layers: list[object],
+ quant: bool,
) -> None:
"""Test fc_layer rewrite process with rewrite type fully-connected."""
+
+ tfrecord = test_tfrecord if quant else test_tfrecord_fp32
+ tflite_model = test_tflite_model if quant else test_tflite_model_fp32
+
config_obj = RewriteConfiguration(
rewrite_type,
["sequential/flatten/Reshape", "StatefulPartitionedCall:0"],
- test_tfrecord_fp32,
+ tfrecord,
train_params=MockTrainingParameters(),
)
- test_obj = RewritingOptimizer(test_tflite_model_fp32, config_obj)
+ test_obj = RewritingOptimizer(tflite_model, config_obj)
rewrite_function = RewritingOptimizer.registry.items[
test_obj.optimizer_configuration.optimization_target
]
@@ -132,8 +197,8 @@ def test_register_rewrite_function() -> None:
"""Test adding rewrite functions and verify they are reported via the registry."""
registry = RewriteRegistry()
- rewrite1 = FullyConnectedRewrite("r1", cast(RewriteCallable, lambda: 1))
- rewrite2 = Sparsity24Rewrite("r2", cast(RewriteCallable, lambda: 2))
+ rewrite1 = TestRewrite("r1", cast(RewriteCallable, lambda: 1))
+ rewrite2 = TestRewrite("r2", cast(RewriteCallable, lambda: 2))
registry.register_rewrite(rewrite1)
registry.register_rewrite(rewrite2)
diff --git a/tests/test_nn_rewrite_core_train.py b/tests/test_nn_rewrite_core_train.py
index 371c79f..94c99ff 100644
--- a/tests/test_nn_rewrite_core_train.py
+++ b/tests/test_nn_rewrite_core_train.py
@@ -14,15 +14,13 @@ import pytest
import tensorflow as tf
from keras.api._v2 import keras # Temporary workaround for now: MLIA-1107
-from mlia.nn.rewrite.core.rewrite import FullyConnectedRewrite
-from mlia.nn.rewrite.core.rewrite import QATRewrite
from mlia.nn.rewrite.core.train import augment_fn_twins
from mlia.nn.rewrite.core.train import AUGMENTATION_PRESETS
from mlia.nn.rewrite.core.train import LearningRateSchedule
from mlia.nn.rewrite.core.train import mixup
from mlia.nn.rewrite.core.train import train
from mlia.nn.rewrite.core.train import TrainingParameters
-from mlia.nn.rewrite.library.fc_layer import get_keras_model as fc_rewrite
+from tests.test_nn_rewrite_core_rewrite import TestRewrite
from tests.utils.rewrite import MockTrainingParameters
@@ -56,20 +54,16 @@ def check_train(
"""Test the train() function."""
with TemporaryDirectory() as tmp_dir:
output_file = Path(tmp_dir, "out.tflite")
- mock_rewrite = FullyConnectedRewrite(
- name="replace",
- rewrite_fn=fc_rewrite,
- )
- is_qat = isinstance(mock_rewrite, QATRewrite)
+ mock_rewrite = TestRewrite("replace", replace_fully_connected_with_conv)
result = train(
source_model=str(tflite_model),
unmodified_model=str(tflite_model) if use_unmodified_model else None,
output_model=str(output_file),
input_tfrec=str(tfrecord),
rewrite=mock_rewrite,
+ is_qat=False,
input_tensors=["sequential/flatten/Reshape"],
output_tensors=["StatefulPartitionedCall:0"],
- is_qat=is_qat,
train_params=train_params,
)