diff options
Diffstat (limited to 'tests')
-rw-r--r-- | tests/conftest.py | 42 | ||||
-rw-r--r-- | tests/test_cli_commands.py | 33 | ||||
-rw-r--r-- | tests/test_cli_helpers.py | 2 | ||||
-rw-r--r-- | tests/test_common_optimization.py | 91 | ||||
-rw-r--r-- | tests/test_nn_rewrite_core_rewrite.py | 221 | ||||
-rw-r--r-- | tests/test_nn_rewrite_core_train.py | 44 | ||||
-rw-r--r-- | tests/test_nn_rewrite_library_helper_functions.py | 95 | ||||
-rw-r--r-- | tests/test_nn_select.py | 57 | ||||
-rw-r--r-- | tests/test_target_cortex_a_advisor.py | 21 | ||||
-rw-r--r-- | tests/test_target_tosa_advisor.py | 21 |
10 files changed, 498 insertions, 129 deletions
diff --git a/tests/conftest.py b/tests/conftest.py index 3d0b832..a64f320 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -126,9 +126,28 @@ def get_test_keras_model() -> keras.Model: return model +def get_test_keras_model_no_activation() -> keras.Model: + """Return test Keras model.""" + model = keras.Sequential( + [ + keras.Input(shape=(28, 28, 1), batch_size=1, name="input"), + keras.layers.Reshape((28, 28, 1)), + keras.layers.Conv2D(filters=12, kernel_size=(3, 3), name="conv1"), + keras.layers.Conv2D(filters=12, kernel_size=(3, 3), name="conv2"), + keras.layers.MaxPool2D(2, 2), + keras.layers.Flatten(), + keras.layers.Dense(10, name="output"), + ] + ) + + model.compile(optimizer="sgd", loss="mean_squared_error") + return model + + TEST_MODEL_KERAS_FILE = "test_model.h5" TEST_MODEL_TFLITE_FP32_FILE = "test_model_fp32.tflite" TEST_MODEL_TFLITE_INT8_FILE = "test_model_int8.tflite" +TEST_MODEL_TFLITE_NO_ACT_FILE = "test_model_no_act.tflite" TEST_MODEL_TFLITE_VELA_FILE = "test_model_vela.tflite" TEST_MODEL_TF_SAVED_MODEL_FILE = "tf_model_test_model" TEST_MODEL_INVALID_FILE = "invalid.tflite" @@ -153,6 +172,13 @@ def fixture_test_models_path( keras_model, quantized=False, output_path=tmp_path / TEST_MODEL_TFLITE_FP32_FILE ) + # Un-quantized TensorFlow Lite model with ReLU activation (fp32) + convert_to_tflite( + get_test_keras_model_no_activation(), + quantized=False, + output_path=tmp_path / TEST_MODEL_TFLITE_NO_ACT_FILE, + ) + # Quantized TensorFlow Lite model (int8) tflite_model_path = tmp_path / TEST_MODEL_TFLITE_INT8_FILE convert_to_tflite(keras_model, quantized=True, output_path=tflite_model_path) @@ -195,6 +221,12 @@ def fixture_test_tflite_vela_model(test_models_path: Path) -> Path: return test_models_path / TEST_MODEL_TFLITE_VELA_FILE +@pytest.fixture(scope="session", name="test_tflite_no_act_model") +def fixture_test_tflite_no_act_model(test_models_path: Path) -> Path: + """Return test TensorFlow Lite model with relu activation.""" + return test_models_path / TEST_MODEL_TFLITE_NO_ACT_FILE + + @pytest.fixture(scope="session", name="test_tf_model") def fixture_test_tf_model(test_models_path: Path) -> Path: """Return test TensorFlow Lite model.""" @@ -257,17 +289,17 @@ def fixture_test_tfrecord_fp32( yield from create_tfrecord(tmp_path_factory, random_data) -@pytest.fixture(scope="session", autouse=True) +@pytest.fixture(scope="function", autouse=True) def set_training_steps( request: _pytest.fixtures.SubRequest, ) -> Generator[None, None, None]: """Speed up tests by using MockTrainingParameters.""" - if "set_training_steps" == request.fixturename: - yield - else: + if "skip_set_training_steps" not in request.keywords: with pytest.MonkeyPatch.context() as monkeypatch: monkeypatch.setattr( "mlia.nn.select._get_rewrite_params", - MagicMock(return_value=[MockTrainingParameters(), None, None]), + MagicMock(return_value=MockTrainingParameters()), ) yield + else: + yield diff --git a/tests/test_cli_commands.py b/tests/test_cli_commands.py index 93a05bd..6c54a73 100644 --- a/tests/test_cli_commands.py +++ b/tests/test_cli_commands.py @@ -90,7 +90,7 @@ def test_performance_unknown_target( None, None, True, - "fully-connected-sparsity24", + "fully-connected-sparsity", "sequential/flatten/Reshape", "StatefulPartitionedCall:0", does_not_raise(), @@ -139,8 +139,9 @@ def test_performance_unknown_target( Exception, match=re.escape( "Invalid rewrite target: 'random'. " - "Supported rewrites: ['fully-connected'," - " 'fully-connected-clustering', 'fully-connected-sparsity24']" + "Supported rewrites: ['conv2d-clustering', 'conv2d-sparsity', " + "'fully-connected', 'fully-connected-clustering', " + "'fully-connected-sparsity']" ), ), ], @@ -195,6 +196,32 @@ def test_performance_unknown_target( "StatefulPartitionedCall:0", does_not_raise(), ], + [ + "ethos-u55-256", + False, + False, + None, + None, + None, + True, + "conv2d-sparsity", + "sequential/conv1/Relu;sequential/conv1/Conv2D", + "sequential/conv2/Relu;sequential/conv2/Conv2D", + does_not_raise(), + ], + [ + "ethos-u55-256", + False, + False, + None, + None, + None, + True, + "conv2d-clustering", + "sequential/conv1/Relu;sequential/conv1/Conv2D", + "sequential/conv2/Relu;sequential/conv2/Conv2D", + does_not_raise(), + ], ], ) def test_opt_valid_optimization_target( # pylint: disable=too-many-locals,too-many-arguments diff --git a/tests/test_cli_helpers.py b/tests/test_cli_helpers.py index 0e9f0d6..69e6ffe 100644 --- a/tests/test_cli_helpers.py +++ b/tests/test_cli_helpers.py @@ -156,7 +156,7 @@ def test_copy_profile_file_to_output_dir(tmp_path: Path) -> None: def test_copy_optimization_file_to_output_dir(tmp_path: Path) -> None: - """Test if the optimization profile file is copied into the output directory.""" + """Test if the profile file is copied into the output directory.""" test_target_profile_name = "optimization" test_file_path = Path(f"{tmp_path}/{test_target_profile_name}.toml") diff --git a/tests/test_common_optimization.py b/tests/test_common_optimization.py index 58ea8af..bdcf034 100644 --- a/tests/test_common_optimization.py +++ b/tests/test_common_optimization.py @@ -60,7 +60,10 @@ def test_optimizing_data_collector( config_parameters={ "common_optimizations": { "optimizations": optimizations, - "training_parameters": training_parameters, + "rewrite_parameters": { + "train_params": training_parameters, + "rewrite_specific_params": None, + }, } } ) @@ -97,12 +100,15 @@ def test_optimizing_data_collector( collector.set_context(context) collector.collect_data() assert optimize_model_mock.call_args.args[0] == opt_settings[0] - assert optimize_model_mock.call_args.args[1] == training_parameters + assert optimize_model_mock.call_args.args[1] == { + "train_params": training_parameters, + "rewrite_specific_params": None, + } assert fake_optimizer.invocation_count == 1 @pytest.mark.parametrize( - "extra_args, error_to_raise", + "extra_args, error_to_raise, rewrite_parameter_type", [ ( { @@ -115,14 +121,39 @@ def test_optimizing_data_collector( ], }, does_not_raises(), + type(None), ), ( { + "optimization_targets": [ + { + "optimization_type": "rewrite", + "optimization_target": "fully-connected-clustering", + } + ], "optimization_profile": load_profile( - "src/mlia/resources/optimization_profiles/optimization.toml" - ) + "src/mlia/resources/optimization_profiles/" + "optimization-fully-connected-clustering.toml" + ), + }, + does_not_raises(), + dict, + ), + ( + { + "optimization_targets": [ + { + "optimization_type": "rewrite", + "optimization_target": "fully-connected-sparsity", + } + ], + "optimization_profile": load_profile( + "src/mlia/resources/optimization_profiles/" + "optimization-fully-connected-pruning.toml" + ), }, does_not_raises(), + dict, ), ( { @@ -135,16 +166,22 @@ def test_optimizing_data_collector( pytest.raises( TypeError, match="Optimization targets value has wrong format." ), + type(None), ), ( {"optimization_profile": [32, 1e-3, True, 48000, "cosine", 1, 0]}, pytest.raises( - TypeError, match="Training Parameter values has wrong format." + TypeError, match="Optimization Parameter values has wrong format." ), + type(None), ), ], ) -def test_add_common_optimization_params(extra_args: dict, error_to_raise: Any) -> None: +def test_add_common_optimization_params( + extra_args: dict, + error_to_raise: Any, + rewrite_parameter_type: dict | None, +) -> None: """Test to check that optimization_targets and optimization_profiles are correctly parsed.""" advisor_parameters: dict = {} @@ -161,14 +198,40 @@ def test_add_common_optimization_params(extra_args: dict, error_to_raise: Any) - ] if not extra_args.get("optimization_profile"): - assert ( - advisor_parameters["common_optimizations"]["training_parameters"] - is None - ) + assert advisor_parameters["common_optimizations"]["rewrite_parameters"] == { + "train_params": None, + "rewrite_specific_params": None, + } else: - assert ( - advisor_parameters["common_optimizations"]["training_parameters"] - == extra_args["optimization_profile"]["training"] + if not extra_args["optimization_profile"].get("rewrite"): + assert isinstance( + advisor_parameters["common_optimizations"]["rewrite_parameters"][ + "train_params" + ], + type(None), + ) + elif not extra_args["optimization_profile"]["rewrite"].get( + "training_parameters" + ): + assert isinstance( + advisor_parameters["common_optimizations"]["rewrite_parameters"][ + "train_params" + ], + type(None), + ) + else: + assert isinstance( + advisor_parameters["common_optimizations"]["rewrite_parameters"][ + "train_params" + ], + dict, + ) + + assert isinstance( + advisor_parameters["common_optimizations"]["rewrite_parameters"][ + "rewrite_specific_params" + ], + rewrite_parameter_type, # type: ignore ) diff --git a/tests/test_nn_rewrite_core_rewrite.py b/tests/test_nn_rewrite_core_rewrite.py index e502842..c874017 100644 --- a/tests/test_nn_rewrite_core_rewrite.py +++ b/tests/test_nn_rewrite_core_rewrite.py @@ -3,18 +3,23 @@ """Tests for module mlia.nn.rewrite.core.rewrite.""" from __future__ import annotations +import re from contextlib import ExitStack as does_not_raise from pathlib import Path from typing import Any from typing import cast from unittest.mock import MagicMock +import numpy as np import pytest import tensorflow_model_optimization as tfmot from keras.api._v2 import keras # Temporary workaround for now: MLIA-1107 from tensorflow_model_optimization.python.core.clustering.keras.cluster_wrapper import ( # pylint: disable=no-name-in-module ClusterWeights, ) +from tensorflow_model_optimization.python.core.sparsity.keras.pruning_wrapper import ( # pylint: disable=no-name-in-module + PruneLowMagnitude, +) from mlia.nn.rewrite.core.rewrite import ClusteringRewrite from mlia.nn.rewrite.core.rewrite import GenericRewrite @@ -23,40 +28,16 @@ from mlia.nn.rewrite.core.rewrite import RewriteCallable from mlia.nn.rewrite.core.rewrite import RewriteConfiguration from mlia.nn.rewrite.core.rewrite import RewriteRegistry from mlia.nn.rewrite.core.rewrite import RewritingOptimizer -from mlia.nn.rewrite.core.rewrite import Sparsity24Rewrite +from mlia.nn.rewrite.core.rewrite import SparsityRewrite from mlia.nn.rewrite.core.rewrite import TrainingParameters from mlia.nn.rewrite.core.train import train_in_dir -from mlia.nn.rewrite.library.fc_clustering_layer import ( - get_keras_model_clus as fc_clustering_rewrite, -) +from mlia.nn.rewrite.library.clustering import fc_clustering_rewrite +from mlia.nn.rewrite.library.sparsity import conv2d_sparsity_rewrite +from mlia.nn.rewrite.library.sparsity import fc_sparsity_rewrite from mlia.nn.tensorflow.config import TFLiteModel from tests.utils.rewrite import MockTrainingParameters -class TestRewrite(Rewrite): - """Test rewrite class.""" - - def quantize(self, model: keras.Model) -> keras.Model: - """Return a quantized model if required.""" - return tfmot.quantization.keras.quantize_model(model) - - def preserved_quantize(self, model: keras.Model) -> keras.Model: - """Not needed.""" - return model - - def training_callbacks(self) -> list: - """Return default rewrite callbacks.""" - return [] - - def post_process(self, model: keras.Model) -> keras.Model: - """Return default post-processing rewrite options.""" - return model - - def check_optimization(self, model: keras.Model, **kwargs: dict) -> bool: - """Not needed here.""" - return True - - def mock_rewrite_function(*_: Any) -> Any: """Mock function to test autoloading of rewrite functions.""" @@ -67,10 +48,10 @@ def test_rewrite() -> None: def bad_rewrite_func() -> Any: raise NotImplementedError() - rewrite = TestRewrite( + rewrite = GenericRewrite( "BAD_REWRITE", rewrite_fn=cast(RewriteCallable, bad_rewrite_func) ) - with pytest.raises(RuntimeError): + with pytest.raises(KeyError): rewrite((1, 2), (1, 2)) @@ -79,7 +60,9 @@ def test_rewrite() -> None: [ ("fully-connected", 0, GenericRewrite), ("fully-connected-clustering", 0, ClusteringRewrite), - ("fully-connected-sparsity24", 1, Sparsity24Rewrite), + ("fully-connected-sparsity", 1, SparsityRewrite), + ("conv2d-clustering", 0, ClusteringRewrite), + ("conv2d-sparsity", 1, SparsityRewrite), ], ) def test_rewrite_selection( @@ -96,8 +79,10 @@ def test_rewrite_selection( "rewrite_name, expected_error", [ ("fully-connected", does_not_raise()), - ("fully-connected-sparsity24", does_not_raise()), + ("fully-connected-sparsity", does_not_raise()), ("fully-connected-clustering", does_not_raise()), + ("conv2d-clustering", does_not_raise()), + ("conv2d-sparsity", does_not_raise()), ("random", does_not_raise()), ], ) @@ -105,7 +90,8 @@ def test_rewrite_configuration( test_tflite_model_fp32: Path, rewrite_name: str, expected_error: Any ) -> None: """Test get_rewrite function only supports rewrite type fully-connected, - fully-connected-clustering and fully-connected-sparsity24.""" + fully-connected-clustering, fully-connected-sparsity, conv2d-clustering + and conv2d-sparsity.""" with expected_error: config_obj = RewriteConfiguration( rewrite_name, @@ -120,29 +106,114 @@ def test_rewrite_configuration( assert isinstance(rewriter_obj, RewritingOptimizer) +def train_rewrite_model( + input_shape: tuple | np.ndarray, + output_shape: int | np.ndarray, + rewrite_model: keras.Model, +) -> keras.Model: + """Helper function to quickly train a rewrite model.""" + rewrite_model.compile( + optimizer=keras.optimizers.Nadam(learning_rate=0.01), + loss=keras.losses.MeanSquaredError(), + metrics=["mae"], + ) + if isinstance(output_shape, int): + output_shape_list = [output_shape] + else: + output_shape_list = output_shape.tolist() + rewrite_model.fit( + x=np.random.rand(16, *input_shape), + y=np.random.rand(16, *output_shape_list), + batch_size=1, + epochs=1, + callbacks=[tfmot.sparsity.keras.UpdatePruningStep()], + ) + return rewrite_model + + def test_rewrite_fully_connected_clustering() -> None: - """Check that model has the set number of clusters""" + """Check that fully connected clustering rewrite model + has the set number of clusters.""" + + rewrite = ClusteringRewrite( + "fully-connected-clustering", + fc_clustering_rewrite, + ) + + model = rewrite(input_shape=(28, 28), output_shape=10, num_clusters=2) + model = rewrite.post_process(model) + assert rewrite.check_optimization( + model, + num_clusters=2, + ) + - rewrite = ClusteringRewrite("fully-connected-clustering", fc_clustering_rewrite) - model = rewrite(input_shape=(28, 28), output_shape=10) +def test_rewrite_fully_connected_sparsity(caplog: pytest.LogCaptureFixture) -> None: + """ + Check that sparse fully connected + rewrite model is correctly sparse. + """ + + rewrite = SparsityRewrite("fully-connected-sparsity", fc_sparsity_rewrite) + input_shape = (28, 28) + output_shape = 10 + model = rewrite( + input_shape=tuple(input_shape), + output_shape=output_shape, + sparsity_m=2, + sparsity_n=4, + ) + model = rewrite.post_process(model) + assert not rewrite.check_optimization(model) + log_records = caplog.records + warning_messages = [x.message for x in log_records if x.levelno == 30] + assert ( + re.search( + r"\nWARNING: Could not find \(2, 4\) sparsity, in " + r"layer dense_?\d? for weight dense_?\d?\/kernel:0 \n", + warning_messages[0], + ) + is not None + ) + model = rewrite( + input_shape=input_shape, output_shape=output_shape, sparsity_m=2, sparsity_n=4 + ) + train_rewrite_model( + input_shape=input_shape, output_shape=output_shape, rewrite_model=model + ) model = rewrite.post_process(model) - assert rewrite.check_optimization(model, number_of_clusters=32) + assert rewrite.check_optimization(model) -def test_rewrite_fully_connected_clustering_error_handling() -> None: - """Check that model has the set number of clusters - and that when quantized the number of clusters - remain.""" +def test_rewrite_conv2d_sparsity(caplog: pytest.LogCaptureFixture) -> None: + """Check that sparse conv2d rewrite model is correctly sparse.""" - rewrite = ClusteringRewrite("fully-connected-clustering", fc_clustering_rewrite) - model = rewrite(input_shape=(28, 28), output_shape=10) - with pytest.raises( - ValueError, - match=( - r"Expected check_preserved_quantize to have argument number_of_clusters" - ), - ): - rewrite.check_optimization(model, bad_arg_name=25) + rewrite = SparsityRewrite("conv2d-sparsity", conv2d_sparsity_rewrite) + input_shape = np.array([28, 28, 3]) + output_shape = np.array([14, 14, 3]) + model = rewrite( + input_shape=input_shape, output_shape=output_shape, sparsity_m=2, sparsity_n=4 + ) + model = rewrite.post_process(model) + assert not rewrite.check_optimization(model) + log_records = caplog.records + warning_messages = [x.message for x in log_records if x.levelno == 30] + assert ( + re.search( + r"\nWARNING: Could not find \(2, 4\) sparsity, in " + r"layer conv2d_?\d? for weight conv2d_?\d?\/kernel:0 \n", + warning_messages[0], + ) + is not None + ) + model = rewrite( + input_shape=input_shape, output_shape=output_shape, sparsity_m=2, sparsity_n=4 + ) + train_rewrite_model( + input_shape=input_shape, output_shape=output_shape, rewrite_model=model + ) + model = rewrite.post_process(model) + assert rewrite.check_optimization(model) @pytest.mark.parametrize( @@ -151,6 +222,20 @@ def test_rewrite_fully_connected_clustering_error_handling() -> None: ["fully-connected", [keras.layers.Reshape, keras.layers.Dense], False], ["fully-connected-clustering", [ClusterWeights, ClusterWeights], False], ["fully-connected-clustering", [ClusterWeights, ClusterWeights], True], + ["fully-connected-sparsity", [PruneLowMagnitude, PruneLowMagnitude], False], + ["fully-connected-sparsity", [PruneLowMagnitude, PruneLowMagnitude], True], + ["conv2d-clustering", [ClusterWeights, ClusterWeights, ClusterWeights], False], + ["conv2d-clustering", [ClusterWeights, ClusterWeights, ClusterWeights], True], + [ + "conv2d-sparsity", + [PruneLowMagnitude, PruneLowMagnitude, PruneLowMagnitude], + False, + ], + [ + "conv2d-sparsity", + [PruneLowMagnitude, PruneLowMagnitude, PruneLowMagnitude], + True, + ], ], ) def test_rewriting_optimizer( # pylint: disable=too-many-locals @@ -162,24 +247,32 @@ def test_rewriting_optimizer( # pylint: disable=too-many-locals expected_layers: list[object], quant: bool, ) -> None: - """Test fc_layer rewrite process with rewrite type fully-connected.""" + """Test the rewrite process with all rewrite types.""" tfrecord = test_tfrecord if quant else test_tfrecord_fp32 tflite_model = test_tflite_model if quant else test_tflite_model_fp32 + rewrite_function = RewritingOptimizer.registry.items[rewrite_type] config_obj = RewriteConfiguration( rewrite_type, - ["sequential/flatten/Reshape", "StatefulPartitionedCall:0"], + ["sequential/flatten/Reshape", "StatefulPartitionedCall:0"] + if "fully-connected" in rewrite_type + else [ + "sequential/conv1/Relu;sequential/conv1/Conv2D", + "sequential/conv2/Relu;sequential/conv2/Conv2D", + ], tfrecord, train_params=MockTrainingParameters(), ) - test_obj = RewritingOptimizer(tflite_model, config_obj) - rewrite_function = RewritingOptimizer.registry.items[ - test_obj.optimizer_configuration.optimization_target - ] # Input, output shape does not matter, just need the test the layers are as expected - rewrite_model = rewrite_function(input_shape=(28, 28, 1), output_shape=12) + rewrite_model = ( + rewrite_function(input_shape=(28, 28, 1), output_shape=12) + if "fully-connected" in rewrite_type + else rewrite_function( + input_shape=np.array([28, 28, 3]), output_shape=np.array([14, 14, 3]) + ) + ) for idx, layer in enumerate(rewrite_model.layers): assert isinstance(layer, expected_layers[idx]) # type: ignore @@ -197,8 +290,14 @@ def test_register_rewrite_function() -> None: """Test adding rewrite functions and verify they are reported via the registry.""" registry = RewriteRegistry() - rewrite1 = TestRewrite("r1", cast(RewriteCallable, lambda: 1)) - rewrite2 = TestRewrite("r2", cast(RewriteCallable, lambda: 2)) + rewrite1 = GenericRewrite( + "r1", + cast(RewriteCallable, lambda: 1), + ) + rewrite2 = GenericRewrite( + "r2", + cast(RewriteCallable, lambda: 2), + ) registry.register_rewrite(rewrite1) registry.register_rewrite(rewrite2) @@ -208,9 +307,11 @@ def test_register_rewrite_function() -> None: def test_builtin_rewrite_names() -> None: """Test if all builtin rewrites are properly registered and returned.""" assert RewritingOptimizer.builtin_rewrite_names() == [ + "conv2d-clustering", + "conv2d-sparsity", "fully-connected", "fully-connected-clustering", - "fully-connected-sparsity24", + "fully-connected-sparsity", ] diff --git a/tests/test_nn_rewrite_core_train.py b/tests/test_nn_rewrite_core_train.py index 94c99ff..03b230f 100644 --- a/tests/test_nn_rewrite_core_train.py +++ b/tests/test_nn_rewrite_core_train.py @@ -16,11 +16,12 @@ from keras.api._v2 import keras # Temporary workaround for now: MLIA-1107 from mlia.nn.rewrite.core.train import augment_fn_twins from mlia.nn.rewrite.core.train import AUGMENTATION_PRESETS +from mlia.nn.rewrite.core.train import detect_activation_from_rewrite_function from mlia.nn.rewrite.core.train import LearningRateSchedule from mlia.nn.rewrite.core.train import mixup from mlia.nn.rewrite.core.train import train from mlia.nn.rewrite.core.train import TrainingParameters -from tests.test_nn_rewrite_core_rewrite import TestRewrite +from tests.test_nn_rewrite_core_rewrite import GenericRewrite from tests.utils.rewrite import MockTrainingParameters @@ -54,7 +55,7 @@ def check_train( """Test the train() function.""" with TemporaryDirectory() as tmp_dir: output_file = Path(tmp_dir, "out.tflite") - mock_rewrite = TestRewrite("replace", replace_fully_connected_with_conv) + mock_rewrite = GenericRewrite("replace", replace_fully_connected_with_conv) result = train( source_model=str(tflite_model), unmodified_model=str(tflite_model) if use_unmodified_model else None, @@ -65,6 +66,7 @@ def check_train( input_tensors=["sequential/flatten/Reshape"], output_tensors=["StatefulPartitionedCall:0"], train_params=train_params, + rewrite_specific_params={}, ) assert len(result[0][0]) == 2 @@ -249,3 +251,41 @@ def test_train_checkpoint( use_unmodified_model=False, quantized=True, ) + + +def test_detect_activation_from_rewrite_function_no_activation( + caplog: pytest.LogCaptureFixture, test_tflite_no_act_model: Path +) -> None: + """ + Test function detect_activation_from_rewrite_function() + with a model with no activation functions. + """ + caplog.set_level(level=20) + activation = detect_activation_from_rewrite_function( + test_tflite_no_act_model.as_posix() + ) + log_records = caplog.get_records(when="call") + logging_messages = [x.message for x in log_records if x.levelno == 20] + assert activation == "relu" + assert ( + "No activation function specified, setting activation function to ReLU" + in logging_messages + ) + + +def test_detect_activation_from_rewrite_function_relu_activation( + caplog: pytest.LogCaptureFixture, test_tflite_model: Path +) -> None: + """ + Test function detect_activation_from_rewrite_function() + with a model with ReLU activation functions. + """ + caplog.set_level(level=20) + activation = detect_activation_from_rewrite_function(test_tflite_model.as_posix()) + log_records = caplog.get_records(when="call") + logging_messages = [x.message for x in log_records if x.levelno == 20] + assert activation == "relu" + assert ( + "No activation function specified, setting activation function " + "to most common activation detected in rewrite graph: relu" in logging_messages + ) diff --git a/tests/test_nn_rewrite_library_helper_functions.py b/tests/test_nn_rewrite_library_helper_functions.py new file mode 100644 index 0000000..9e880aa --- /dev/null +++ b/tests/test_nn_rewrite_library_helper_functions.py @@ -0,0 +1,95 @@ +# SPDX-FileCopyrightText: Copyright 2024, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Tests for module mlia.nn.rewrite.library.helper_functions.""" +from __future__ import annotations + +from contextlib import ExitStack as does_not_raise +from typing import Any + +import numpy as np +import pytest +from keras.api._v2 import keras # Temporary workaround for now: MLIA-1107 + +from mlia.nn.rewrite.library.helper_functions import ACTIVATION_FUNCTION_LIST +from mlia.nn.rewrite.library.helper_functions import compute_conv2d_parameters +from mlia.nn.rewrite.library.helper_functions import get_activation_function + + +def compute_conv_output( + input_data: np.ndarray, input_shape: np.ndarray, conv_parameters: dict[str, Any] +) -> np.ndarray: + """Compute the output of a conv layer for testing.""" + test_model = keras.Sequential( + [ + keras.layers.InputLayer(input_shape=input_shape), + keras.layers.Conv2D(**conv_parameters), + ] + ) + output = test_model(input_data) + return np.array(output.shape[1:]) + + +@pytest.mark.parametrize( + "input_shape, output_shape", + [ + (np.array([32, 32, 3]), np.array([16, 16, 3])), + (np.array([32, 32, 3]), np.array([8, 8, 3])), + (np.array([32, 32, 3]), np.array([8, 16, 3])), + (np.array([25, 10, 3]), np.array([13, 5, 3])), + (np.array([25, 10, 3]), np.array([7, 5, 3])), + (np.array([25, 10, 3]), np.array([6, 4, 3])), + (np.array([25, 10, 3]), np.array([5, 5, 3])), + ], +) +def test_compute_conv2d_parameters( + input_shape: np.ndarray, output_shape: np.ndarray +) -> None: + """Test to check compute_conv2d_parameters works as expected.""" + conv_parameters = compute_conv2d_parameters( + input_shape=input_shape, output_shape=output_shape + ) + computed_output_shape = compute_conv_output( + np.random.rand(1, *input_shape), input_shape, conv_parameters + ) + assert np.equal(computed_output_shape, output_shape).all() + + +@pytest.mark.parametrize( + "activation, expected_function_type, expected_extra_args, expected_error", + [ + ("relu", keras.layers.ReLU, {}, does_not_raise()), + ("relu6", keras.layers.ReLU, {"max_value": 6}, does_not_raise()), + ("none", keras.layers.Identity, {}, does_not_raise()), + ( + "wrong_key", + keras.layers.Identity, + {}, + pytest.raises( + KeyError, + match=( + "Expected activation function to be " + rf"in \{ACTIVATION_FUNCTION_LIST}\, found wrong_key" + ), + ), + ), + ], +) +def test_get_activation_functions( + activation: str, + expected_function_type: type[keras.layers.Layer], + expected_extra_args: dict, + expected_error: Any, +) -> None: + """ + Check the get_activation_function returns + the expected layer and extra arguments. + """ + with expected_error: + activation_function, activation_function_extra_args = get_activation_function( + activation + ) + assert isinstance( + activation_function(**activation_function_extra_args), + expected_function_type, + ) + assert expected_extra_args == activation_function_extra_args diff --git a/tests/test_nn_select.py b/tests/test_nn_select.py index 4095076..08752bd 100644 --- a/tests/test_nn_select.py +++ b/tests/test_nn_select.py @@ -4,12 +4,12 @@ from __future__ import annotations from contextlib import ExitStack as does_not_raise -from dataclasses import asdict from pathlib import Path from typing import Any from typing import cast import pytest +import tensorflow_model_optimization as tfmot from keras.api._v2 import keras # Temporary workaround for now: MLIA-1107 from mlia.core.errors import ConfigurationError @@ -176,23 +176,50 @@ def test_get_optimizer( model = test_tflite_model else: model = keras.models.load_model(str(test_keras_model)) - optimizer = get_optimizer(model, config) + optimizer = get_optimizer( + model, config, {"train_params": None, "rewrite_specific_params": None} + ) assert isinstance(optimizer, expected_type) assert optimizer.optimization_config() == expected_config +# pylint: disable=line-too-long @pytest.mark.parametrize( - "rewrite_parameters", - [None, {"batch_size": 64, "learning_rate": 0.003}], + "rewrite_parameters, optimization_target", + [ + [ + {"train_params": None, "rewrite_specific_params": None}, + "fully-connected-clustering", + ], + [ + { + "train_params": None, + "rewrite_specific_params": { + "num_clusters": 5, + "cluster_centroids_init": tfmot.clustering.keras.CentroidInitialization( + "CentroidInitialization.LINEAR" + ), + }, + }, + "fully-connected-clustering", + ], + [ + {"train_params": None, "rewrite_specific_params": None}, + "fully-connected", + ], + ], ) +# pylint: enable=line-too-long @pytest.mark.skip_set_training_steps def test_get_optimizer_training_parameters( - rewrite_parameters: dict | None, test_tflite_model: Path + rewrite_parameters: dict, + optimization_target: str, + test_tflite_model: Path, ) -> None: """Test function get_optimzer with various combinations of parameters.""" config = OptimizationSettings( optimization_type="rewrite", - optimization_target="fully-connected", # type: ignore + optimization_target=optimization_target, # type: ignore layers_to_optimize=None, dataset=None, ) @@ -200,18 +227,20 @@ def test_get_optimizer_training_parameters( RewritingOptimizer, get_optimizer(test_tflite_model, config, rewrite_parameters), ) + assert len(list(rewrite_parameters.items())) == 2 + if rewrite_parameters.get("rewrite_specific_params"): + assert isinstance( + rewrite_parameters["rewrite_specific_params"], + type(optimizer.optimizer_configuration.rewrite_specific_params), + ) + assert ( + optimizer.optimizer_configuration.rewrite_specific_params + == rewrite_parameters["rewrite_specific_params"] + ) assert isinstance( optimizer.optimizer_configuration.train_params, TrainingParameters ) - if not rewrite_parameters: - assert asdict(TrainingParameters()) == asdict( - optimizer.optimizer_configuration.train_params - ) - else: - assert asdict(TrainingParameters()) | rewrite_parameters == asdict( - optimizer.optimizer_configuration.train_params - ) @pytest.mark.parametrize( diff --git a/tests/test_target_cortex_a_advisor.py b/tests/test_target_cortex_a_advisor.py index 7bb57c3..2f06f54 100644 --- a/tests/test_target_cortex_a_advisor.py +++ b/tests/test_target_cortex_a_advisor.py @@ -8,6 +8,7 @@ import pytest from mlia.core.common import AdviceCategory from mlia.core.context import ExecutionContext from mlia.core.workflow import DefaultWorkflowExecutor +from mlia.target.common.optimization import _DEFAULT_OPTIMIZATION_TARGETS from mlia.target.cortex_a.advisor import configure_and_get_cortexa_advisor from mlia.target.cortex_a.advisor import CortexAInferenceAdvisor @@ -33,21 +34,11 @@ def test_configure_and_get_cortex_a_advisor(test_tflite_model: Path) -> None: "target_profile": "cortex-a", }, "common_optimizations": { - "optimizations": [ - [ - { - "layers_to_optimize": None, - "optimization_target": 0.5, - "optimization_type": "pruning", - }, - { - "layers_to_optimize": None, - "optimization_target": 32, - "optimization_type": "clustering", - }, - ] - ], - "training_parameters": None, + "optimizations": [_DEFAULT_OPTIMIZATION_TARGETS], + "rewrite_parameters": { + "train_params": None, + "rewrite_specific_params": None, + }, }, } diff --git a/tests/test_target_tosa_advisor.py b/tests/test_target_tosa_advisor.py index 020acc5..d0b42b9 100644 --- a/tests/test_target_tosa_advisor.py +++ b/tests/test_target_tosa_advisor.py @@ -9,6 +9,7 @@ import pytest from mlia.core.common import AdviceCategory from mlia.core.context import ExecutionContext from mlia.core.workflow import DefaultWorkflowExecutor +from mlia.target.common.optimization import _DEFAULT_OPTIMIZATION_TARGETS from mlia.target.tosa.advisor import configure_and_get_tosa_advisor from mlia.target.tosa.advisor import TOSAInferenceAdvisor @@ -33,21 +34,11 @@ def test_configure_and_get_tosa_advisor( assert ctx.event_handlers is not None assert ctx.config_parameters == { "common_optimizations": { - "optimizations": [ - [ - { - "layers_to_optimize": None, - "optimization_target": 0.5, - "optimization_type": "pruning", - }, - { - "layers_to_optimize": None, - "optimization_target": 32, - "optimization_type": "clustering", - }, - ] - ], - "training_parameters": None, + "optimizations": [_DEFAULT_OPTIMIZATION_TARGETS], + "rewrite_parameters": { + "train_params": None, + "rewrite_specific_params": None, + }, }, "tosa_inference_advisor": { "model": str(test_tflite_model), |