# SPDX-FileCopyrightText: Copyright 2023-2024, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Tests for module mlia.nn.rewrite.core.rewrite."""
from __future__ import annotations

from contextlib import ExitStack as does_not_raise
from pathlib import Path
from typing import Any
from typing import cast
from unittest.mock import MagicMock

import pytest
import tensorflow_model_optimization as tfmot
from keras.api._v2 import keras  # Temporary workaround for now: MLIA-1107
from tensorflow_model_optimization.python.core.clustering.keras.cluster_wrapper import (  # pylint: disable=no-name-in-module
    ClusterWeights,
)

from mlia.nn.rewrite.core.rewrite import ClusteringRewrite
from mlia.nn.rewrite.core.rewrite import GenericRewrite
from mlia.nn.rewrite.core.rewrite import Rewrite
from mlia.nn.rewrite.core.rewrite import RewriteCallable
from mlia.nn.rewrite.core.rewrite import RewriteConfiguration
from mlia.nn.rewrite.core.rewrite import RewriteRegistry
from mlia.nn.rewrite.core.rewrite import RewritingOptimizer
from mlia.nn.rewrite.core.rewrite import Sparsity24Rewrite
from mlia.nn.rewrite.core.rewrite import TrainingParameters
from mlia.nn.rewrite.core.train import train_in_dir
from mlia.nn.rewrite.library.fc_clustering_layer import (
    get_keras_model_clus as fc_clustering_rewrite,
)
from mlia.nn.tensorflow.config import TFLiteModel
from tests.utils.rewrite import MockTrainingParameters


class TestRewrite(Rewrite):
    """Minimal concrete Rewrite used as a helper by the tests in this module."""

    # The class name matches pytest's default "Test*" collection pattern, but it
    # is a helper and not a test case: opt out of collection explicitly so
    # pytest does not try to collect it (and warn about its constructor).
    __test__ = False

    def quantize(self, model: keras.Model) -> keras.Model:
        """Return the model wrapped by tfmot's quantize_model."""
        return tfmot.quantization.keras.quantize_model(model)

    def preserved_quantize(self, model: keras.Model) -> keras.Model:
        """Return the model unchanged; not needed for these tests."""
        return model

    def training_callbacks(self) -> list:
        """Return an empty list of training callbacks."""
        return []

    def post_process(self, model: keras.Model) -> keras.Model:
        """Return the model unchanged; no post-processing is applied."""
        return model

    def check_optimization(self, model: keras.Model, **kwargs: dict) -> bool:
        """Always report success; not needed for these tests."""
        return True


def mock_rewrite_function(*_: Any) -> Any:
    """Mock function to test autoloading of rewrite functions."""


def test_rewrite() -> None:
    """Test that calling a Rewrite with a failing function raises RuntimeError."""

    def bad_rewrite_func() -> Any:
        raise NotImplementedError()

    rewrite = TestRewrite(
        "BAD_REWRITE", rewrite_fn=cast(RewriteCallable, bad_rewrite_func)
    )
    with pytest.raises(RuntimeError):
        rewrite((1, 2), (1, 2))


@pytest.mark.parametrize(
    "rewrite_name, callbacks_length, instance",
    [
        ("fully-connected", 0, GenericRewrite),
        ("fully-connected-clustering", 0, ClusteringRewrite),
        ("fully-connected-sparsity24", 1, Sparsity24Rewrite),
    ],
)
def test_rewrite_selection(
    rewrite_name: str, callbacks_length: int, instance: Rewrite
) -> None:
    """Test that each registered rewrite has the expected class and callbacks.

    The rewrite is looked up by name in the RewritingOptimizer registry; its
    name, type and number of training callbacks are then checked.
    """
    rewrite = RewritingOptimizer.registry.items[rewrite_name]
    assert rewrite.name == rewrite_name
    assert isinstance(rewrite, instance)  # type: ignore
    assert len(rewrite.training_callbacks()) == callbacks_length


@pytest.mark.parametrize(
    "rewrite_name, expected_error",
    [
        ("fully-connected", does_not_raise()),
        ("fully-connected-sparsity24", does_not_raise()),
        ("fully-connected-clustering", does_not_raise()),
        ("random", does_not_raise()),
    ],
)
def test_rewrite_configuration(
    test_tflite_model_fp32: Path, rewrite_name: str, expected_error: Any
) -> None:
    """Test creating a RewriteConfiguration and RewritingOptimizer by name.

    None of the parametrized names is expected to raise at construction time,
    including "random", which is not one of the built-in rewrite names
    (fully-connected, fully-connected-clustering, fully-connected-sparsity24).
    """
    with expected_error:
        config_obj = RewriteConfiguration(
            rewrite_name,
            ["sample_node_start", "sample_node_end"],
            None,
        )

        assert config_obj.optimization_target in str(config_obj)

        rewriter_obj = RewritingOptimizer(test_tflite_model_fp32, config_obj)
        assert rewriter_obj.optimizer_configuration.optimization_target == rewrite_name
        assert isinstance(rewriter_obj, RewritingOptimizer)


def test_rewrite_fully_connected_clustering() -> None:
    """Check that the post-processed model has the set number of clusters."""
    rewrite = ClusteringRewrite("fully-connected-clustering", fc_clustering_rewrite)
    model = rewrite(input_shape=(28, 28), output_shape=10)
    model = rewrite.post_process(model)
    assert rewrite.check_optimization(model, number_of_clusters=32)


def test_rewrite_fully_connected_clustering_error_handling() -> None:
    """Check that check_optimization rejects a call without number_of_clusters.

    Passing an unrelated keyword argument instead of number_of_clusters must
    raise a ValueError naming the missing argument.
    """
    rewrite = ClusteringRewrite("fully-connected-clustering", fc_clustering_rewrite)
    model = rewrite(input_shape=(28, 28), output_shape=10)
    with pytest.raises(
        ValueError,
        match=(
            r"Expected check_preserved_quantize to have argument number_of_clusters"
        ),
    ):
        rewrite.check_optimization(model, bad_arg_name=25)


@pytest.mark.parametrize(
    "rewrite_type, expected_layers, quant",
    [
        ["fully-connected", [keras.layers.Reshape, keras.layers.Dense], False],
        ["fully-connected-clustering", [ClusterWeights, ClusterWeights], False],
        ["fully-connected-clustering", [ClusterWeights, ClusterWeights], True],
    ],
)
def test_rewriting_optimizer(  # pylint: disable=too-many-locals
    test_tflite_model_fp32: Path,
    test_tfrecord_fp32: Path,
    test_tflite_model: Path,
    test_tfrecord: Path,
    rewrite_type: str,
    expected_layers: list[object],
    quant: bool,
) -> None:
    """Test the rewrite process for each parametrized rewrite type.

    Checks the layer types of the model built by the registered rewrite, then
    applies the optimization and checks the resulting model and the reported
    optimization config. The quant flag selects which model/tfrecord fixtures
    are used (the *_fp32 fixtures when it is False).
    """
    tfrecord = test_tfrecord if quant else test_tfrecord_fp32
    tflite_model = test_tflite_model if quant else test_tflite_model_fp32

    config_obj = RewriteConfiguration(
        rewrite_type,
        ["sequential/flatten/Reshape", "StatefulPartitionedCall:0"],
        tfrecord,
        train_params=MockTrainingParameters(),
    )

    test_obj = RewritingOptimizer(tflite_model, config_obj)
    rewrite_function = RewritingOptimizer.registry.items[
        test_obj.optimizer_configuration.optimization_target
    ]
    # Input and output shapes do not matter here; we only need to test that the
    # layers are as expected.
    rewrite_model = rewrite_function(input_shape=(28, 28, 1), output_shape=12)
    # Compare the layer count explicitly: the per-layer loop below indexes
    # expected_layers by position, so on its own it would silently pass for a
    # model with fewer layers than expected.
    assert len(rewrite_model.layers) == len(expected_layers)
    for idx, layer in enumerate(rewrite_model.layers):
        assert isinstance(layer, expected_layers[idx])  # type: ignore

    test_obj.apply_optimization()
    trained_model = test_obj.get_model()

    assert isinstance(trained_model, TFLiteModel)

    cfg = test_obj.optimization_config()
    assert isinstance(cfg, str)
    assert cfg


def test_register_rewrite_function() -> None:
    """Test adding rewrite functions and verify they are reported via the registry."""
    registry = RewriteRegistry()

    rewrite1 = TestRewrite("r1", cast(RewriteCallable, lambda: 1))
    rewrite2 = TestRewrite("r2", cast(RewriteCallable, lambda: 2))

    registry.register_rewrite(rewrite1)
    registry.register_rewrite(rewrite2)
    assert registry.names() == ["r1", "r2"]


def test_builtin_rewrite_names() -> None:
    """Test if all builtin rewrites are properly registered and returned."""
    assert RewritingOptimizer.builtin_rewrite_names() == [
        "fully-connected",
        "fully-connected-clustering",
        "fully-connected-sparsity24",
    ]


def test_rewrite_configuration_train_params(
    test_tflite_model_fp32: Path,
    test_tfrecord_fp32: Path,
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """Test that training parameters reach train_in_dir.

    The training parameters given to the rewrite configuration must be passed
    on as the train_params keyword argument of train_in_dir, which is wrapped
    in a mock that still calls the real function.
    """
    train_params = TrainingParameters(
        batch_size=64, steps=24000, learning_rate=1e-5, show_progress=True
    )

    config_obj = RewriteConfiguration(
        "fully-connected",
        ["sequential/flatten/Reshape", "StatefulPartitionedCall:0"],
        test_tfrecord_fp32,
        train_params=train_params,
    )

    rewriter_obj = RewritingOptimizer(test_tflite_model_fp32, config_obj)
    train_mock = MagicMock(side_effect=train_in_dir)
    monkeypatch.setattr("mlia.nn.rewrite.core.train.train_in_dir", train_mock)
    rewriter_obj.apply_optimization()

    train_mock.assert_called_once()
    assert train_mock.call_args.kwargs["train_params"] == train_params