diff options
author | Nathan Bailey <nathan.bailey@arm.com> | 2024-03-08 14:08:06 +0000 |
---|---|---|
committer | Nathan Bailey <nathan.bailey@arm.com> | 2024-04-16 13:11:31 +0100 |
commit | 32405c279d2f98c2d40bdbbb7f7306ff12c86cd6 (patch) | |
tree | 42781ca219b822a9ec9f212a9ee516f65b184a27 /tests | |
parent | 427e02696f1ede596ef6dce82787a37e122efa78 (diff) | |
download | mlia-32405c279d2f98c2d40bdbbb7f7306ff12c86cd6.tar.gz |
feat: Implement the clustering rewrite for int8
Implements a clustering rewrite for fully connected layers for int8 models
Resolves: MLIA-1080
Signed-off-by: Nathan Bailey <nathan.bailey@arm.com>
Change-Id: If48efb22764187a382e5b84bbb5c3b75a6e71b75
Diffstat (limited to 'tests')
-rw-r--r-- | tests/test_nn_rewrite_core_rewrite.py | 91 | ||||
-rw-r--r-- | tests/test_nn_rewrite_core_train.py | 12 |
2 files changed, 81 insertions, 22 deletions
diff --git a/tests/test_nn_rewrite_core_rewrite.py b/tests/test_nn_rewrite_core_rewrite.py index ef4df6a..e502842 100644 --- a/tests/test_nn_rewrite_core_rewrite.py +++ b/tests/test_nn_rewrite_core_rewrite.py @@ -10,13 +10,14 @@ from typing import cast from unittest.mock import MagicMock import pytest +import tensorflow_model_optimization as tfmot from keras.api._v2 import keras # Temporary workaround for now: MLIA-1107 from tensorflow_model_optimization.python.core.clustering.keras.cluster_wrapper import ( # pylint: disable=no-name-in-module ClusterWeights, ) from mlia.nn.rewrite.core.rewrite import ClusteringRewrite -from mlia.nn.rewrite.core.rewrite import FullyConnectedRewrite +from mlia.nn.rewrite.core.rewrite import GenericRewrite from mlia.nn.rewrite.core.rewrite import Rewrite from mlia.nn.rewrite.core.rewrite import RewriteCallable from mlia.nn.rewrite.core.rewrite import RewriteConfiguration @@ -25,17 +26,48 @@ from mlia.nn.rewrite.core.rewrite import RewritingOptimizer from mlia.nn.rewrite.core.rewrite import Sparsity24Rewrite from mlia.nn.rewrite.core.rewrite import TrainingParameters from mlia.nn.rewrite.core.train import train_in_dir +from mlia.nn.rewrite.library.fc_clustering_layer import ( + get_keras_model_clus as fc_clustering_rewrite, +) from mlia.nn.tensorflow.config import TFLiteModel from tests.utils.rewrite import MockTrainingParameters +class TestRewrite(Rewrite): + """Test rewrite class.""" + + def quantize(self, model: keras.Model) -> keras.Model: + """Return a quantized model if required.""" + return tfmot.quantization.keras.quantize_model(model) + + def preserved_quantize(self, model: keras.Model) -> keras.Model: + """Not needed.""" + return model + + def training_callbacks(self) -> list: + """Return default rewrite callbacks.""" + return [] + + def post_process(self, model: keras.Model) -> keras.Model: + """Return default post-processing rewrite options.""" + return model + + def check_optimization(self, model: keras.Model, **kwargs: dict) -> bool: + """Not needed here.""" + return True + + +def mock_rewrite_function(*_: Any) -> Any: + """Mock function to test autoloading of rewrite functions.""" + + def test_rewrite() -> None: """Test a derived Rewrite class.""" def bad_rewrite_func() -> Any: raise NotImplementedError() - rewrite = Sparsity24Rewrite( + rewrite = TestRewrite( "BAD_REWRITE", rewrite_fn=cast(RewriteCallable, bad_rewrite_func) ) with pytest.raises(RuntimeError): @@ -45,7 +77,7 @@ def test_rewrite() -> None: @pytest.mark.parametrize( "rewrite_name, callbacks_length, instance", [ - ("fully-connected", 0, Rewrite), + ("fully-connected", 0, GenericRewrite), ("fully-connected-clustering", 0, ClusteringRewrite), ("fully-connected-sparsity24", 1, Sparsity24Rewrite), ], @@ -72,8 +104,8 @@ def test_rewrite_selection( def test_rewrite_configuration( test_tflite_model_fp32: Path, rewrite_name: str, expected_error: Any ) -> None: - """Test get_rewrite function only supports rewrite types - fully-connected, fully-connected-clustering and fully-connected-sparsity24.""" + """Test get_rewrite function only supports rewrite type fully-connected, + fully-connected-clustering and fully-connected-sparsity24.""" with expected_error: config_obj = RewriteConfiguration( rewrite_name, @@ -88,28 +120,61 @@ def test_rewrite_configuration( assert isinstance(rewriter_obj, RewritingOptimizer) +def test_rewrite_fully_connected_clustering() -> None: + """Check that model has the set number of clusters""" + + rewrite = ClusteringRewrite("fully-connected-clustering", fc_clustering_rewrite) + model = rewrite(input_shape=(28, 28), output_shape=10) + model = rewrite.post_process(model) + assert rewrite.check_optimization(model, number_of_clusters=32) + + +def test_rewrite_fully_connected_clustering_error_handling() -> None: + """Check that model has the set number of clusters + and that when quantized the number of clusters + remain.""" + + rewrite = ClusteringRewrite("fully-connected-clustering", fc_clustering_rewrite) + model = rewrite(input_shape=(28, 28), output_shape=10) + with pytest.raises( + ValueError, + match=( + r"Expected check_preserved_quantize to have argument number_of_clusters" + ), + ): + rewrite.check_optimization(model, bad_arg_name=25) + + @pytest.mark.parametrize( - "rewrite_type, expected_layers", + "rewrite_type, expected_layers, quant", [ - ["fully-connected", [keras.layers.Reshape, keras.layers.Dense]], - ["fully-connected-clustering", [ClusterWeights, ClusterWeights]], + ["fully-connected", [keras.layers.Reshape, keras.layers.Dense], False], + ["fully-connected-clustering", [ClusterWeights, ClusterWeights], False], + ["fully-connected-clustering", [ClusterWeights, ClusterWeights], True], ], ) -def test_rewriting_optimizer( +def test_rewriting_optimizer( # pylint: disable=too-many-locals test_tflite_model_fp32: Path, test_tfrecord_fp32: Path, + test_tflite_model: Path, + test_tfrecord: Path, rewrite_type: str, expected_layers: list[object], + quant: bool, ) -> None: """Test fc_layer rewrite process with rewrite type fully-connected.""" + + tfrecord = test_tfrecord if quant else test_tfrecord_fp32 + tflite_model = test_tflite_model if quant else test_tflite_model_fp32 + config_obj = RewriteConfiguration( rewrite_type, ["sequential/flatten/Reshape", "StatefulPartitionedCall:0"], - test_tfrecord_fp32, + tfrecord, train_params=MockTrainingParameters(), ) - test_obj = RewritingOptimizer(test_tflite_model_fp32, config_obj) + test_obj = RewritingOptimizer(tflite_model, config_obj) rewrite_function = RewritingOptimizer.registry.items[ test_obj.optimizer_configuration.optimization_target ] @@ -132,8 +197,8 @@ def test_register_rewrite_function() -> None: """Test adding rewrite functions and verify they are reported via the registry.""" registry = RewriteRegistry() - rewrite1 = FullyConnectedRewrite("r1", cast(RewriteCallable, lambda: 1)) - rewrite2 = Sparsity24Rewrite("r2", cast(RewriteCallable, lambda: 2)) + rewrite1 = TestRewrite("r1", cast(RewriteCallable, lambda: 1)) + rewrite2 = TestRewrite("r2", cast(RewriteCallable, lambda: 2)) registry.register_rewrite(rewrite1) registry.register_rewrite(rewrite2) diff --git a/tests/test_nn_rewrite_core_train.py b/tests/test_nn_rewrite_core_train.py index 371c79f..94c99ff 100644 --- a/tests/test_nn_rewrite_core_train.py +++ b/tests/test_nn_rewrite_core_train.py @@ -14,15 +14,13 @@ import pytest import tensorflow as tf from keras.api._v2 import keras # Temporary workaround for now: MLIA-1107 -from mlia.nn.rewrite.core.rewrite import FullyConnectedRewrite -from mlia.nn.rewrite.core.rewrite import QATRewrite from mlia.nn.rewrite.core.train import augment_fn_twins from mlia.nn.rewrite.core.train import AUGMENTATION_PRESETS from mlia.nn.rewrite.core.train import LearningRateSchedule from mlia.nn.rewrite.core.train import mixup from mlia.nn.rewrite.core.train import train from mlia.nn.rewrite.core.train import TrainingParameters -from mlia.nn.rewrite.library.fc_layer import get_keras_model as fc_rewrite +from tests.test_nn_rewrite_core_rewrite import TestRewrite from tests.utils.rewrite import MockTrainingParameters @@ -56,20 +54,16 @@ def check_train( """Test the train() function.""" with TemporaryDirectory() as tmp_dir: output_file = Path(tmp_dir, "out.tflite") - mock_rewrite = FullyConnectedRewrite( - name="replace", - rewrite_fn=fc_rewrite, - ) - is_qat = isinstance(mock_rewrite, QATRewrite) + mock_rewrite = TestRewrite("replace", replace_fully_connected_with_conv) result = train( source_model=str(tflite_model), unmodified_model=str(tflite_model) if use_unmodified_model else None, output_model=str(output_file), input_tfrec=str(tfrecord), rewrite=mock_rewrite, + is_qat=False, input_tensors=["sequential/flatten/Reshape"], output_tensors=["StatefulPartitionedCall:0"], - is_qat=is_qat, train_params=train_params, ) |