diff options
Diffstat (limited to 'tests/test_nn_rewrite_core_rewrite.py')
-rw-r--r-- | tests/test_nn_rewrite_core_rewrite.py | 166 |
1 files changed, 123 insertions, 43 deletions
diff --git a/tests/test_nn_rewrite_core_rewrite.py b/tests/test_nn_rewrite_core_rewrite.py index b32fafd..e502842 100644 --- a/tests/test_nn_rewrite_core_rewrite.py +++ b/tests/test_nn_rewrite_core_rewrite.py @@ -10,45 +10,102 @@ from typing import cast from unittest.mock import MagicMock import pytest +import tensorflow_model_optimization as tfmot +from keras.api._v2 import keras # Temporary workaround for now: MLIA-1107 +from tensorflow_model_optimization.python.core.clustering.keras.cluster_wrapper import ( # pylint: disable=no-name-in-module + ClusterWeights, +) -from mlia.nn.rewrite.core.rewrite import DynamicallyLoadedRewrite +from mlia.nn.rewrite.core.rewrite import ClusteringRewrite +from mlia.nn.rewrite.core.rewrite import GenericRewrite from mlia.nn.rewrite.core.rewrite import Rewrite from mlia.nn.rewrite.core.rewrite import RewriteCallable from mlia.nn.rewrite.core.rewrite import RewriteConfiguration from mlia.nn.rewrite.core.rewrite import RewriteRegistry from mlia.nn.rewrite.core.rewrite import RewritingOptimizer +from mlia.nn.rewrite.core.rewrite import Sparsity24Rewrite from mlia.nn.rewrite.core.rewrite import TrainingParameters from mlia.nn.rewrite.core.train import train_in_dir +from mlia.nn.rewrite.library.fc_clustering_layer import ( + get_keras_model_clus as fc_clustering_rewrite, +) from mlia.nn.tensorflow.config import TFLiteModel from tests.utils.rewrite import MockTrainingParameters +class TestRewrite(Rewrite): + """Test rewrite class.""" + + def quantize(self, model: keras.Model) -> keras.Model: + """Return a quantized model if required.""" + return tfmot.quantization.keras.quantize_model(model) + + def preserved_quantize(self, model: keras.Model) -> keras.Model: + """Not needed.""" + return model + + def training_callbacks(self) -> list: + """Return default rewrite callbacks.""" + return [] + + def post_process(self, model: keras.Model) -> keras.Model: + """Return default post-processing rewrite options.""" + return model + + def check_optimization(self, model: keras.Model, **kwargs: dict) -> bool: + """Not needed here.""" + return True + + def mock_rewrite_function(*_: Any) -> Any: """Mock function to test autoloading of rewrite functions.""" def test_rewrite() -> None: - """Test the Rewrite class.""" + """Test a derived Rewrite class.""" def bad_rewrite_func() -> Any: raise NotImplementedError() - rewrite = Rewrite("BAD_REWRITE", rewrite_fn=cast(RewriteCallable, bad_rewrite_func)) + rewrite = TestRewrite( + "BAD_REWRITE", rewrite_fn=cast(RewriteCallable, bad_rewrite_func) + ) with pytest.raises(RuntimeError): rewrite((1, 2), (1, 2)) @pytest.mark.parametrize( + "rewrite_name, callbacks_length, instance", + [ + ("fully-connected", 0, GenericRewrite), + ("fully-connected-clustering", 0, ClusteringRewrite), + ("fully-connected-sparsity24", 1, Sparsity24Rewrite), + ], +) +def test_rewrite_selection( + rewrite_name: str, callbacks_length: int, instance: Rewrite +) -> None: + """Test that the correct rewrite class is instantiated.""" + rewrite = RewritingOptimizer.registry.items[rewrite_name] + assert rewrite.name == rewrite_name + assert isinstance(rewrite, instance) # type: ignore + assert len(rewrite.training_callbacks()) == callbacks_length + + +@pytest.mark.parametrize( "rewrite_name, expected_error", [ ("fully-connected", does_not_raise()), + ("fully-connected-sparsity24", does_not_raise()), + ("fully-connected-clustering", does_not_raise()), ("random", does_not_raise()), ], ) def test_rewrite_configuration( test_tflite_model_fp32: Path, rewrite_name: str, expected_error: Any ) -> None: - """Test get_rewrite function only supports rewrite type fully-connected.""" + """Test get_rewrite function only supports rewrite type fully-connected, + fully-connected-clustering and fully-connected-sparsity24.""" with expected_error: config_obj = RewriteConfiguration( rewrite_name, @@ -63,19 +120,69 @@ def test_rewrite_configuration( assert isinstance(rewriter_obj, RewritingOptimizer) -def test_rewriting_optimizer( +def test_rewrite_fully_connected_clustering() -> None: + """Check that model has the set number of clusters""" + + rewrite = ClusteringRewrite("fully-connected-clustering", fc_clustering_rewrite) + model = rewrite(input_shape=(28, 28), output_shape=10) + model = rewrite.post_process(model) + assert rewrite.check_optimization(model, number_of_clusters=32) + + +def test_rewrite_fully_connected_clustering_error_handling() -> None: + """Check that model has the set number of clusters + and that when quantized the number of clusters + remain.""" + + rewrite = ClusteringRewrite("fully-connected-clustering", fc_clustering_rewrite) + model = rewrite(input_shape=(28, 28), output_shape=10) + with pytest.raises( + ValueError, + match=( + r"Expected check_preserved_quantize to have argument number_of_clusters" + ), + ): + rewrite.check_optimization(model, bad_arg_name=25) + + +@pytest.mark.parametrize( + "rewrite_type, expected_layers, quant", + [ + ["fully-connected", [keras.layers.Reshape, keras.layers.Dense], False], + ["fully-connected-clustering", [ClusterWeights, ClusterWeights], False], + ["fully-connected-clustering", [ClusterWeights, ClusterWeights], True], + ], +) +def test_rewriting_optimizer( # pylint: disable=too-many-locals test_tflite_model_fp32: Path, test_tfrecord_fp32: Path, + test_tflite_model: Path, + test_tfrecord: Path, + rewrite_type: str, + expected_layers: list[object], + quant: bool, ) -> None: """Test fc_layer rewrite process with rewrite type fully-connected.""" + + tfrecord = test_tfrecord if quant else test_tfrecord_fp32 + tflite_model = test_tflite_model if quant else test_tflite_model_fp32 + config_obj = RewriteConfiguration( - "fully-connected", + rewrite_type, ["sequential/flatten/Reshape", "StatefulPartitionedCall:0"], - test_tfrecord_fp32, + tfrecord, train_params=MockTrainingParameters(), ) - test_obj = RewritingOptimizer(test_tflite_model_fp32, config_obj) + test_obj = RewritingOptimizer(tflite_model, config_obj) + rewrite_function = RewritingOptimizer.registry.items[ + test_obj.optimizer_configuration.optimization_target + ] + # Input, output shape does not matter, just need the test the layers are as expected + rewrite_model = rewrite_function(input_shape=(28, 28, 1), output_shape=12) + for idx, layer in enumerate(rewrite_model.layers): + assert isinstance(layer, expected_layers[idx]) # type: ignore + test_obj.apply_optimization() trained_model = test_obj.get_model() @@ -87,11 +194,11 @@ def test_rewriting_optimizer( def test_register_rewrite_function() -> None: - """Test adding rewrite functions and verify the are reported via the registry.""" + """Test adding rewrite functions and verify they are reported via the registry.""" registry = RewriteRegistry() - rewrite1 = Rewrite("r1", cast(RewriteCallable, lambda: 1)) - rewrite2 = Rewrite("r2", cast(RewriteCallable, lambda: 2)) + rewrite1 = TestRewrite("r1", cast(RewriteCallable, lambda: 1)) + rewrite2 = TestRewrite("r2", cast(RewriteCallable, lambda: 2)) registry.register_rewrite(rewrite1) registry.register_rewrite(rewrite2) @@ -100,38 +207,11 @@ def test_register_rewrite_function() -> None: def test_builtin_rewrite_names() -> None: """Test if all builtin rewrites are properly registered and returned.""" - assert RewritingOptimizer.builtin_rewrite_names() == ["fully-connected"] - - -def test_rewrite_function_autoload() -> None: - """Test rewrite function loading.""" - function_name = "tests.test_nn_rewrite_core_rewrite.mock_rewrite_function" - rewrite = DynamicallyLoadedRewrite(name="mock_rewrite", function_name=function_name) - assert rewrite.name == "mock_rewrite" - - assert rewrite.function is not mock_rewrite_function - assert rewrite.load_function(function_name) is mock_rewrite_function - assert rewrite.function is mock_rewrite_function - - -def test_rewrite_function_autoload_fail() -> None: - """Test rewrite function loading failure.""" - function_name = "invalid_module.invalid_function" - rewrite = DynamicallyLoadedRewrite( - name="mock_rewrite", - function_name="invalid_module.invalid_function", - ) - assert rewrite.name == "mock_rewrite" - - with pytest.raises(Exception) as exc_info: - rewrite.load_function(function_name) - - message = exc_info.value.args[0] - - assert message == ( - "Unable to load rewrite function 'invalid_module.invalid_function'" - " for 'mock_rewrite'." - ) + assert RewritingOptimizer.builtin_rewrite_names() == [ + "fully-connected", + "fully-connected-clustering", + "fully-connected-sparsity24", + ] def test_rewrite_configuration_train_params( |