diff options
Diffstat (limited to 'tests/test_nn_rewrite_core_rewrite.py')
-rw-r--r-- | tests/test_nn_rewrite_core_rewrite.py | 91 |
1 files changed, 78 insertions, 13 deletions
diff --git a/tests/test_nn_rewrite_core_rewrite.py b/tests/test_nn_rewrite_core_rewrite.py index ef4df6a..e502842 100644 --- a/tests/test_nn_rewrite_core_rewrite.py +++ b/tests/test_nn_rewrite_core_rewrite.py @@ -10,13 +10,14 @@ from typing import cast from unittest.mock import MagicMock import pytest +import tensorflow_model_optimization as tfmot from keras.api._v2 import keras # Temporary workaround for now: MLIA-1107 from tensorflow_model_optimization.python.core.clustering.keras.cluster_wrapper import ( # pylint: disable=no-name-in-module ClusterWeights, ) from mlia.nn.rewrite.core.rewrite import ClusteringRewrite -from mlia.nn.rewrite.core.rewrite import FullyConnectedRewrite +from mlia.nn.rewrite.core.rewrite import GenericRewrite from mlia.nn.rewrite.core.rewrite import Rewrite from mlia.nn.rewrite.core.rewrite import RewriteCallable from mlia.nn.rewrite.core.rewrite import RewriteConfiguration @@ -25,17 +26,48 @@ from mlia.nn.rewrite.core.rewrite import RewritingOptimizer from mlia.nn.rewrite.core.rewrite import Sparsity24Rewrite from mlia.nn.rewrite.core.rewrite import TrainingParameters from mlia.nn.rewrite.core.train import train_in_dir +from mlia.nn.rewrite.library.fc_clustering_layer import ( + get_keras_model_clus as fc_clustering_rewrite, +) from mlia.nn.tensorflow.config import TFLiteModel from tests.utils.rewrite import MockTrainingParameters +class TestRewrite(Rewrite): + """Test rewrite class.""" + + def quantize(self, model: keras.Model) -> keras.Model: + """Return a quantized model if required.""" + return tfmot.quantization.keras.quantize_model(model) + + def preserved_quantize(self, model: keras.Model) -> keras.Model: + """Not needed.""" + return model + + def training_callbacks(self) -> list: + """Return default rewrite callbacks.""" + return [] + + def post_process(self, model: keras.Model) -> keras.Model: + """Return default post-processing rewrite options.""" + return model + + def check_optimization(self, model: keras.Model, **kwargs: dict) -> bool: + """Not needed here.""" + return True + + +def mock_rewrite_function(*_: Any) -> Any: + """Mock function to test autoloading of rewrite functions.""" + + def test_rewrite() -> None: """Test a derived Rewrite class.""" def bad_rewrite_func() -> Any: raise NotImplementedError() - rewrite = Sparsity24Rewrite( + rewrite = TestRewrite( "BAD_REWRITE", rewrite_fn=cast(RewriteCallable, bad_rewrite_func) ) with pytest.raises(RuntimeError): @@ -45,7 +77,7 @@ def test_rewrite() -> None: @pytest.mark.parametrize( "rewrite_name, callbacks_length, instance", [ - ("fully-connected", 0, Rewrite), + ("fully-connected", 0, GenericRewrite), ("fully-connected-clustering", 0, ClusteringRewrite), ("fully-connected-sparsity24", 1, Sparsity24Rewrite), ], @@ -72,8 +104,8 @@ def test_rewrite_selection( def test_rewrite_configuration( test_tflite_model_fp32: Path, rewrite_name: str, expected_error: Any ) -> None: - """Test get_rewrite function only supports rewrite types - fully-connected, fully-connected-clustering and fully-connected-sparsity24.""" + """Test get_rewrite function only supports rewrite type fully-connected, + fully-connected-clustering and fully-connected-sparsity24.""" with expected_error: config_obj = RewriteConfiguration( rewrite_name, @@ -88,28 +120,61 @@ def test_rewrite_configuration( assert isinstance(rewriter_obj, RewritingOptimizer) +def test_rewrite_fully_connected_clustering() -> None: + """Check that model has the set number of clusters""" + + rewrite = ClusteringRewrite("fully-connected-clustering", fc_clustering_rewrite) + model = rewrite(input_shape=(28, 28), output_shape=10) + model = rewrite.post_process(model) + assert rewrite.check_optimization(model, number_of_clusters=32) + + +def test_rewrite_fully_connected_clustering_error_handling() -> None: + """Check that model has the set number of clusters + and that when quantized the number of clusters + remain.""" + + rewrite = ClusteringRewrite("fully-connected-clustering", fc_clustering_rewrite) + model = rewrite(input_shape=(28, 28), output_shape=10) + with pytest.raises( + ValueError, + match=( + r"Expected check_preserved_quantize to have argument number_of_clusters" + ), + ): + rewrite.check_optimization(model, bad_arg_name=25) + + @pytest.mark.parametrize( - "rewrite_type, expected_layers", + "rewrite_type, expected_layers, quant", [ - ["fully-connected", [keras.layers.Reshape, keras.layers.Dense]], - ["fully-connected-clustering", [ClusterWeights, ClusterWeights]], + ["fully-connected", [keras.layers.Reshape, keras.layers.Dense], False], + ["fully-connected-clustering", [ClusterWeights, ClusterWeights], False], + ["fully-connected-clustering", [ClusterWeights, ClusterWeights], True], ], ) -def test_rewriting_optimizer( +def test_rewriting_optimizer( # pylint: disable=too-many-locals test_tflite_model_fp32: Path, test_tfrecord_fp32: Path, + test_tflite_model: Path, + test_tfrecord: Path, rewrite_type: str, expected_layers: list[object], + quant: bool, ) -> None: """Test fc_layer rewrite process with rewrite type fully-connected.""" + + tfrecord = test_tfrecord if quant else test_tfrecord_fp32 + tflite_model = test_tflite_model if quant else test_tflite_model_fp32 + config_obj = RewriteConfiguration( rewrite_type, ["sequential/flatten/Reshape", "StatefulPartitionedCall:0"], - test_tfrecord_fp32, + tfrecord, train_params=MockTrainingParameters(), ) - test_obj = RewritingOptimizer(test_tflite_model_fp32, config_obj) + test_obj = RewritingOptimizer(tflite_model, config_obj) rewrite_function = RewritingOptimizer.registry.items[ test_obj.optimizer_configuration.optimization_target ] @@ -132,8 +197,8 @@ def test_register_rewrite_function() -> None: """Test adding rewrite functions and verify they are reported via the registry.""" registry = RewriteRegistry() - rewrite1 = FullyConnectedRewrite("r1", cast(RewriteCallable, lambda: 1)) - rewrite2 = Sparsity24Rewrite("r2", cast(RewriteCallable, lambda: 2)) + rewrite1 = TestRewrite("r1", cast(RewriteCallable, lambda: 1)) + rewrite2 = TestRewrite("r2", cast(RewriteCallable, lambda: 2)) registry.register_rewrite(rewrite1) registry.register_rewrite(rewrite2) |