From 1ebb335cba516bcf973b041efa6a9878d1022b93 Mon Sep 17 00:00:00 2001 From: Madeleine Dunn Date: Wed, 21 Feb 2024 17:10:07 +0000 Subject: feat: Implement int8 sparsity 2:4 rewrite - Implement pruning-preserving quantisation aware training - Rework the training logic to avoid duplication - Remove the DynamicallyLoadedRewrite class as it is now unused Resolves: MLIA-1003 Signed-off-by: Madeleine Dunn Change-Id: Ia7a4acf5f477a27963cffa88180cca085b32ffe4 --- tests/test_nn_rewrite_core_rewrite.py | 71 ++++++++++++----------------------- tests/test_nn_rewrite_core_train.py | 12 +++--- 2 files changed, 30 insertions(+), 53 deletions(-) (limited to 'tests') diff --git a/tests/test_nn_rewrite_core_rewrite.py b/tests/test_nn_rewrite_core_rewrite.py index e614cad..8ef5bd2 100644 --- a/tests/test_nn_rewrite_core_rewrite.py +++ b/tests/test_nn_rewrite_core_rewrite.py @@ -11,45 +11,51 @@ from unittest.mock import MagicMock import pytest -from mlia.nn.rewrite.core.rewrite import DynamicallyLoadedRewrite -from mlia.nn.rewrite.core.rewrite import Rewrite +from mlia.nn.rewrite.core.rewrite import FullyConnectedRewrite from mlia.nn.rewrite.core.rewrite import RewriteCallable from mlia.nn.rewrite.core.rewrite import RewriteConfiguration from mlia.nn.rewrite.core.rewrite import RewriteRegistry from mlia.nn.rewrite.core.rewrite import RewritingOptimizer +from mlia.nn.rewrite.core.rewrite import Sparsity24Rewrite from mlia.nn.rewrite.core.rewrite import TrainingParameters from mlia.nn.rewrite.core.train import train_in_dir from mlia.nn.tensorflow.config import TFLiteModel from tests.utils.rewrite import MockTrainingParameters -def mock_rewrite_function(*_: Any) -> Any: - """Mock function to test autoloading of rewrite functions.""" - - def test_rewrite() -> None: - """Test the Rewrite class.""" + """Test a derived Rewrite class.""" def bad_rewrite_func() -> Any: raise NotImplementedError() - rewrite = Rewrite("BAD_REWRITE", rewrite_fn=cast(RewriteCallable, bad_rewrite_func)) + rewrite = Sparsity24Rewrite( + "BAD_REWRITE", rewrite_fn=cast(RewriteCallable, bad_rewrite_func) + ) with pytest.raises(RuntimeError): rewrite((1, 2), (1, 2)) @pytest.mark.parametrize( - "rewrite_name, callbacks_length", + "rewrite_name, rewrite_class", [ - ("fully-connected", 0), - ("fully-connected-sparsity24", 1), + ("fully-connected", FullyConnectedRewrite), + ("fully-connected-sparsity24", Sparsity24Rewrite), ], ) -def test_rewrite_selection(rewrite_name: str, callbacks_length: int) -> None: - """Test that the correct rewrite class is instantiated.""" - rewrite = RewritingOptimizer.registry.items[rewrite_name] +def test_rewrite_selection( + rewrite_name: str, + rewrite_class: Any, +) -> None: + """Check that the correct rewrite class is instantiated through the registry""" + config_obj = RewriteConfiguration( + rewrite_name, + ["sample_node_start", "sample_node_end"], + ) + + rewrite = RewritingOptimizer.registry.items[config_obj.optimization_target] assert rewrite.name == rewrite_name - assert len(rewrite.training_callbacks()) == callbacks_length + assert isinstance(rewrite, rewrite_class) @pytest.mark.parametrize( @@ -106,8 +112,8 @@ def test_register_rewrite_function() -> None: """Test adding rewrite functions and verify they are reported via the registry.""" registry = RewriteRegistry() - rewrite1 = Rewrite("r1", cast(RewriteCallable, lambda: 1)) - rewrite2 = Rewrite("r2", cast(RewriteCallable, lambda: 2)) + rewrite1 = FullyConnectedRewrite("r1", cast(RewriteCallable, lambda: 1)) + rewrite2 = Sparsity24Rewrite("r2", cast(RewriteCallable, lambda: 2)) registry.register_rewrite(rewrite1) registry.register_rewrite(rewrite2) @@ -122,37 +128,6 @@ def test_builtin_rewrite_names() -> None: ] -def test_rewrite_function_autoload() -> None: - """Test rewrite function loading.""" - function_name = "tests.test_nn_rewrite_core_rewrite.mock_rewrite_function" - rewrite = DynamicallyLoadedRewrite(name="mock_rewrite", function_name=function_name) - assert rewrite.name == "mock_rewrite" - - assert rewrite.function is not mock_rewrite_function - assert rewrite.load_function(function_name) is mock_rewrite_function - assert rewrite.function is mock_rewrite_function - - -def test_rewrite_function_autoload_fail() -> None: - """Test rewrite function loading failure.""" - function_name = "invalid_module.invalid_function" - rewrite = DynamicallyLoadedRewrite( - name="mock_rewrite", - function_name="invalid_module.invalid_function", - ) - assert rewrite.name == "mock_rewrite" - - with pytest.raises(Exception) as exc_info: - rewrite.load_function(function_name) - - message = exc_info.value.args[0] - - assert message == ( - "Unable to load rewrite function 'invalid_module.invalid_function'" - " for 'mock_rewrite'." - ) - - def test_rewrite_configuration_train_params( test_tflite_model_fp32: Path, test_tfrecord_fp32: Path, diff --git a/tests/test_nn_rewrite_core_train.py b/tests/test_nn_rewrite_core_train.py index 34b9543..371c79f 100644 --- a/tests/test_nn_rewrite_core_train.py +++ b/tests/test_nn_rewrite_core_train.py @@ -14,13 +14,15 @@ import pytest import tensorflow as tf from keras.api._v2 import keras # Temporary workaround for now: MLIA-1107 -from mlia.nn.rewrite.core.rewrite import DynamicallyLoadedRewrite +from mlia.nn.rewrite.core.rewrite import FullyConnectedRewrite +from mlia.nn.rewrite.core.rewrite import QATRewrite from mlia.nn.rewrite.core.train import augment_fn_twins from mlia.nn.rewrite.core.train import AUGMENTATION_PRESETS from mlia.nn.rewrite.core.train import LearningRateSchedule from mlia.nn.rewrite.core.train import mixup from mlia.nn.rewrite.core.train import train from mlia.nn.rewrite.core.train import TrainingParameters +from mlia.nn.rewrite.library.fc_layer import get_keras_model as fc_rewrite from tests.utils.rewrite import MockTrainingParameters @@ -54,12 +56,11 @@ def check_train( """Test the train() function.""" with TemporaryDirectory() as tmp_dir: output_file = Path(tmp_dir, "out.tflite") - mock_rewrite = DynamicallyLoadedRewrite( + mock_rewrite = FullyConnectedRewrite( name="replace", - function_name=( - "tests.test_nn_rewrite_core_train.replace_fully_connected_with_conv" - ), + rewrite_fn=fc_rewrite, ) + is_qat = isinstance(mock_rewrite, QATRewrite) result = train( source_model=str(tflite_model), unmodified_model=str(tflite_model) if use_unmodified_model else None, @@ -68,6 +69,7 @@ def check_train( rewrite=mock_rewrite, input_tensors=["sequential/flatten/Reshape"], output_tensors=["StatefulPartitionedCall:0"], + is_qat=is_qat, train_params=train_params, ) -- cgit v1.2.1