diff options
Diffstat (limited to 'tests')
-rw-r--r-- | tests/conftest.py | 10 | ||||
-rw-r--r-- | tests/test_cli_commands.py | 56 | ||||
-rw-r--r-- | tests/test_common_optimization.py | 76 | ||||
-rw-r--r-- | tests/test_nn_rewrite_core_rewrite.py | 323 | ||||
-rw-r--r-- | tests/test_nn_rewrite_core_train.py | 26 | ||||
-rw-r--r-- | tests/test_nn_rewrite_library_helper_functions.py | 51 | ||||
-rw-r--r-- | tests/test_nn_select.py | 12 | ||||
-rw-r--r-- | tests/test_target_cortex_a_advisor.py | 2 | ||||
-rw-r--r-- | tests/test_target_tosa_advisor.py | 2 |
9 files changed, 486 insertions, 72 deletions
diff --git a/tests/conftest.py b/tests/conftest.py index 3d0b832..981bf3b 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -257,17 +257,15 @@ def fixture_test_tfrecord_fp32( yield from create_tfrecord(tmp_path_factory, random_data) -@pytest.fixture(scope="session", autouse=True) +@pytest.fixture(scope="function", autouse=True) def set_training_steps( request: _pytest.fixtures.SubRequest, ) -> Generator[None, None, None]: """Speed up tests by using MockTrainingParameters.""" - if "set_training_steps" == request.fixturename: - yield - else: + if "skip_set_training_steps" not in request.keywords: with pytest.MonkeyPatch.context() as monkeypatch: monkeypatch.setattr( "mlia.nn.select._get_rewrite_params", - MagicMock(return_value=[MockTrainingParameters(), None, None]), + MagicMock(return_value=MockTrainingParameters()), ) - yield + yield diff --git a/tests/test_cli_commands.py b/tests/test_cli_commands.py index 9cda27c..ce9e144 100644 --- a/tests/test_cli_commands.py +++ b/tests/test_cli_commands.py @@ -84,6 +84,19 @@ def test_performance_unknown_target( ], [ "ethos-u55-256", + False, + False, + None, + None, + None, + True, + "fully-connected-sparsity24", + "sequential/flatten/Reshape", + "StatefulPartitionedCall:0", + does_not_raise(), + ], + [ + "ethos-u55-256", True, False, None, @@ -126,7 +139,9 @@ def test_performance_unknown_target( Exception, match=re.escape( "Invalid rewrite target: 'random'. " - "Supported rewrites: ['fully-connected']" + "Supported rewrites: ['conv2d-clustering', 'conv2d-sparsity24', " + "'fully-connected', 'fully-connected-clustering', " + "'fully-connected-sparsity24']" ), ), ], @@ -168,6 +183,45 @@ def test_performance_unknown_target( ), ), ], + [ + "ethos-u55-256", + False, + False, + None, + None, + None, + True, + "fully-connected-clustering", + "sequential/flatten/Reshape", + "StatefulPartitionedCall:0", + does_not_raise(), + ], + [ + "ethos-u55-256", + False, + False, + None, + None, + None, + True, + "conv2d-sparsity24", + "sequential/conv1/Relu;sequential/conv1/Conv2D", + "sequential/conv2/Relu;sequential/conv2/Conv2D", + does_not_raise(), + ], + [ + "ethos-u55-256", + False, + False, + None, + None, + None, + True, + "conv2d-clustering", + "sequential/conv1/Relu;sequential/conv1/Conv2D", + "sequential/conv2/Relu;sequential/conv2/Conv2D", + does_not_raise(), + ], ], ) def test_opt_valid_optimization_target( # pylint: disable=too-many-locals,too-many-arguments diff --git a/tests/test_common_optimization.py b/tests/test_common_optimization.py index 05a5b55..58ea8af 100644 --- a/tests/test_common_optimization.py +++ b/tests/test_common_optimization.py @@ -1,6 +1,8 @@ # SPDX-FileCopyrightText: Copyright 2024, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """Tests for the common optimization module.""" +from __future__ import annotations + from contextlib import ExitStack as does_not_raises from pathlib import Path from typing import Any @@ -15,6 +17,7 @@ from mlia.nn.tensorflow.config import TFLiteModel from mlia.target.common.optimization import _DEFAULT_OPTIMIZATION_TARGETS from mlia.target.common.optimization import add_common_optimization_params from mlia.target.common.optimization import OptimizingDataCollector +from mlia.target.common.optimization import parse_augmentations from mlia.target.config import load_profile from mlia.target.config import TargetProfile @@ -57,7 +60,7 @@ def test_optimizing_data_collector( config_parameters={ "common_optimizations": { "optimizations": optimizations, - "training_parameters": [training_parameters], + "training_parameters": training_parameters, } } ) @@ -94,7 +97,7 @@ def test_optimizing_data_collector( collector.set_context(context) collector.collect_data() assert optimize_model_mock.call_args.args[0] == opt_settings[0] - assert optimize_model_mock.call_args.args[1] == [training_parameters] + assert optimize_model_mock.call_args.args[1] == training_parameters assert fake_optimizer.invocation_count == 1 @@ -158,10 +161,67 @@ def test_add_common_optimization_params(extra_args: dict, error_to_raise: Any) - ] if not extra_args.get("optimization_profile"): - assert advisor_parameters["common_optimizations"][ - "training_parameters" - ] == [None] + assert ( + advisor_parameters["common_optimizations"]["training_parameters"] + is None + ) else: - assert advisor_parameters["common_optimizations"][ - "training_parameters" - ] == list(extra_args["optimization_profile"].values()) + assert ( + advisor_parameters["common_optimizations"]["training_parameters"] + == extra_args["optimization_profile"]["training"] + ) + + +@pytest.mark.parametrize( + "augmentations, expected_output", + [ + ( + {"gaussian_strength": 1.0, "mixup_strength": 1.0}, + (1.0, 1.0), + ), + ( + {"gaussian_strength": 1.0}, + (None, 1.0), + ), + ( + {"Wrong param": 1.0, "mixup_strength": 1.0}, + (1.0, None), + ), + ( + {"Wrong param1": 1.0, "Wrong param2": 1.0}, + (None, None), + ), + ( + "gaussian", + (None, 1.0), + ), + ( + "mix_gaussian_large", + (2.0, 1.0), + ), + ( + "not in presets", + (None, None), + ), + ( + {"gaussian_strength": 1.0, "mixup_strength": 1.0, "mix2": 1.0}, + (1.0, 1.0), + ), + ( + {"gaussian_strength": "not a float", "mixup_strength": 1.0}, + (1.0, None), + ), + ( + None, + (None, None), + ), + ], +) +def test_parse_augmentations( + augmentations: dict | str | None, expected_output: tuple +) -> None: + """Check that augmentation parameters in optimization_profiles are + correctly parsed.""" + + augmentation_output = parse_augmentations(augmentations) + assert augmentation_output == expected_output diff --git a/tests/test_nn_rewrite_core_rewrite.py b/tests/test_nn_rewrite_core_rewrite.py index b32fafd..97b0b96 100644 --- a/tests/test_nn_rewrite_core_rewrite.py +++ b/tests/test_nn_rewrite_core_rewrite.py @@ -3,52 +3,120 @@ """Tests for module mlia.nn.rewrite.core.rewrite.""" from __future__ import annotations +import re from contextlib import ExitStack as does_not_raise from pathlib import Path from typing import Any from typing import cast from unittest.mock import MagicMock +import numpy as np import pytest +import tensorflow_model_optimization as tfmot +from keras.api._v2 import keras # Temporary workaround for now: MLIA-1107 +from tensorflow_model_optimization.python.core.clustering.keras.cluster_wrapper import ( # pylint: disable=no-name-in-module + ClusterWeights, +) +from tensorflow_model_optimization.python.core.sparsity.keras.pruning_wrapper import ( # pylint: disable=no-name-in-module + PruneLowMagnitude, +) -from mlia.nn.rewrite.core.rewrite import DynamicallyLoadedRewrite +from mlia.nn.rewrite.core.rewrite import ClusteringRewrite +from mlia.nn.rewrite.core.rewrite import GenericRewrite from mlia.nn.rewrite.core.rewrite import Rewrite from mlia.nn.rewrite.core.rewrite import RewriteCallable from mlia.nn.rewrite.core.rewrite import RewriteConfiguration from mlia.nn.rewrite.core.rewrite import RewriteRegistry from mlia.nn.rewrite.core.rewrite import RewritingOptimizer +from mlia.nn.rewrite.core.rewrite import Sparsity24Rewrite from mlia.nn.rewrite.core.rewrite import TrainingParameters from mlia.nn.rewrite.core.train import train_in_dir +from mlia.nn.rewrite.library.clustering import conv2d_clustering_rewrite +from mlia.nn.rewrite.library.clustering import fc_clustering_rewrite +from mlia.nn.rewrite.library.sparsity import conv2d_sparsity_rewrite +from mlia.nn.rewrite.library.sparsity import fc_sparsity_rewrite from mlia.nn.tensorflow.config import TFLiteModel from tests.utils.rewrite import MockTrainingParameters +class TestRewrite(Rewrite): + """Test rewrite class.""" + + def quantize(self, model: keras.Model) -> keras.Model: + """Return a quantized model if required.""" + return tfmot.quantization.keras.quantize_model(model) + + def preserved_quantize(self, model: keras.Model) -> keras.Model: + """Not needed.""" + return model + + def training_callbacks(self) -> list: + """Return default rewrite callbacks.""" + return [] + + def post_process(self, model: keras.Model) -> keras.Model: + """Return default post-processing rewrite options.""" + return model + + def check_optimization(self, model: keras.Model, **kwargs: dict) -> bool: + """Not needed here.""" + return True + + def mock_rewrite_function(*_: Any) -> Any: """Mock function to test autoloading of rewrite functions.""" def test_rewrite() -> None: - """Test the Rewrite class.""" + """Test a derived Rewrite class.""" def bad_rewrite_func() -> Any: raise NotImplementedError() - rewrite = Rewrite("BAD_REWRITE", rewrite_fn=cast(RewriteCallable, bad_rewrite_func)) + rewrite = TestRewrite( + "BAD_REWRITE", rewrite_fn=cast(RewriteCallable, bad_rewrite_func) + ) with pytest.raises(RuntimeError): rewrite((1, 2), (1, 2)) @pytest.mark.parametrize( + "rewrite_name, callbacks_length, instance", + [ + ("fully-connected", 0, GenericRewrite), + ("fully-connected-clustering", 0, ClusteringRewrite), + ("fully-connected-sparsity24", 1, Sparsity24Rewrite), + ("conv2d-clustering", 0, ClusteringRewrite), + ("conv2d-sparsity24", 1, Sparsity24Rewrite), + ], +) +def test_rewrite_selection( + rewrite_name: str, callbacks_length: int, instance: Rewrite +) -> None: + """Test that the correct rewrite class is instantiated.""" + rewrite = RewritingOptimizer.registry.items[rewrite_name] + assert rewrite.name == rewrite_name + assert isinstance(rewrite, instance) # type: ignore + assert len(rewrite.training_callbacks()) == callbacks_length + + +@pytest.mark.parametrize( "rewrite_name, expected_error", [ ("fully-connected", does_not_raise()), + ("fully-connected-sparsity24", does_not_raise()), + ("fully-connected-clustering", does_not_raise()), + ("conv2d-clustering", does_not_raise()), + ("conv2d-sparsity24", does_not_raise()), ("random", does_not_raise()), ], ) def test_rewrite_configuration( test_tflite_model_fp32: Path, rewrite_name: str, expected_error: Any ) -> None: - """Test get_rewrite function only supports rewrite type fully-connected.""" + """Test get_rewrite function only supports rewrite type fully-connected, + fully-connected-clustering, fully-connected-sparsity24, conv2d-clustering + and conv2d-sparsity24.""" with expected_error: config_obj = RewriteConfiguration( rewrite_name, @@ -63,19 +131,209 @@ def test_rewrite_configuration( assert isinstance(rewriter_obj, RewritingOptimizer) -def test_rewriting_optimizer( +def train_rewrite_model( + input_shape: tuple | np.ndarray, + output_shape: int | np.ndarray, + rewrite_model: keras.Model, +) -> keras.Model: + """Helper function to quickly train a rewrite model.""" + rewrite_model.compile( + optimizer=keras.optimizers.Nadam(learning_rate=0.01), + loss=keras.losses.MeanSquaredError(), + metrics=["mae"], + ) + if isinstance(output_shape, int): + output_shape_list = [output_shape] + else: + output_shape_list = output_shape.tolist() + rewrite_model.fit( + x=np.random.rand(16, *input_shape), + y=np.random.rand(16, *output_shape_list), + batch_size=1, + epochs=1, + callbacks=[tfmot.sparsity.keras.UpdatePruningStep()], + ) + return rewrite_model + + +def test_rewrite_fully_connected_clustering(caplog: pytest.LogCaptureFixture) -> None: + """Check that fully connected clustering rewrite model + has the set number of clusters.""" + + rewrite = ClusteringRewrite("fully-connected-clustering", fc_clustering_rewrite) + model = rewrite(input_shape=(28, 28), output_shape=10) + model = rewrite.post_process(model) + assert rewrite.check_optimization(model, number_of_clusters=4) + rewrite.check_optimization(model, number_of_clusters=2) + log_records = caplog.records + warning_messages = [x.message for x in log_records if x.levelno == 30] + assert ( + re.search( + r"\nWARNING: Expected 2 cluster\(s\), found \d+ cluster\(s\) " + r"in layer dense_?\d? for weight kernel:0 \n", + warning_messages[0], + ) + is not None + ) + + +def test_rewrite_conv2d_clustering(caplog: pytest.LogCaptureFixture) -> None: + """Check that conv2d clustering rewrite model has the set number of clusters.""" + + rewrite = ClusteringRewrite("conv2d-clustering", conv2d_clustering_rewrite) + model = rewrite( + input_shape=np.array([28, 28, 3]), output_shape=np.array([14, 14, 3]) + ) + model = rewrite.post_process(model) + assert rewrite.check_optimization(model, number_of_clusters=4) + rewrite.check_optimization(model, number_of_clusters=2) + log_records = caplog.records + warning_messages = [x.message for x in log_records if x.levelno == 30] + assert ( + re.search( + r"\nWARNING: Expected 2 cluster\(s\), found \d+ cluster\(s\) " + r"in layer conv2d_?\d? for weight kernel:0 \n", + warning_messages[0], + ) + is not None + ) + + +def test_rewrite_clustering_error_handling() -> None: + """ + Check that the clustering rewrite check_optimization + function returns the current error. + """ + + rewrite = ClusteringRewrite("fully-connected-clustering", fc_clustering_rewrite) + model = rewrite(input_shape=(28, 28), output_shape=10) + with pytest.raises( + ValueError, + match=(r"Expected check_optimization to have argument number_of_clusters"), + ): + rewrite.check_optimization(model, bad_arg_name=25) + + +def test_rewrite_fully_connected_sparsity(caplog: pytest.LogCaptureFixture) -> None: + """ + Check that sparse fully connected + rewrite model is correctly sparse. + """ + + rewrite = Sparsity24Rewrite("fully-connected-sparsity24", fc_sparsity_rewrite) + input_shape = (28, 28) + output_shape = 10 + model = rewrite(input_shape=tuple(input_shape), output_shape=output_shape) + model = rewrite.post_process(model) + assert not rewrite.check_optimization(model) + log_records = caplog.records + warning_messages = [x.message for x in log_records if x.levelno == 30] + assert ( + re.search( + r"\nWARNING: Could not find \(2,4\) sparsity, in " + r"layer dense_?\d? for weight dense_?\d?\/kernel:0 \n", + warning_messages[0], + ) + is not None + ) + model = rewrite(input_shape=input_shape, output_shape=output_shape) + train_rewrite_model( + input_shape=input_shape, output_shape=output_shape, rewrite_model=model + ) + model = rewrite.post_process(model) + assert rewrite.check_optimization(model) + + +def test_rewrite_conv2d_sparsity(caplog: pytest.LogCaptureFixture) -> None: + """Check that sparse conv2d rewrite model is correctly sparse.""" + + rewrite = Sparsity24Rewrite("conv2d-sparsity24", conv2d_sparsity_rewrite) + input_shape = np.array([28, 28, 3]) + output_shape = np.array([14, 14, 3]) + model = rewrite(input_shape=input_shape, output_shape=output_shape) + model = rewrite.post_process(model) + assert not rewrite.check_optimization(model) + log_records = caplog.records + warning_messages = [x.message for x in log_records if x.levelno == 30] + assert ( + re.search( + r"\nWARNING: Could not find \(2,4\) sparsity, in " + r"layer conv2d_?\d? for weight conv2d_?\d?\/kernel:0 \n", + warning_messages[0], + ) + is not None + ) + model = rewrite(input_shape=input_shape, output_shape=output_shape) + train_rewrite_model( + input_shape=input_shape, output_shape=output_shape, rewrite_model=model + ) + model = rewrite.post_process(model) + assert rewrite.check_optimization(model) + + +@pytest.mark.parametrize( + "rewrite_type, expected_layers, quant", + [ + ["fully-connected", [keras.layers.Reshape, keras.layers.Dense], False], + ["fully-connected-clustering", [ClusterWeights, ClusterWeights], False], + ["fully-connected-clustering", [ClusterWeights, ClusterWeights], True], + ["fully-connected-sparsity24", [PruneLowMagnitude, PruneLowMagnitude], False], + ["fully-connected-sparsity24", [PruneLowMagnitude, PruneLowMagnitude], True], + ["conv2d-clustering", [ClusterWeights, ClusterWeights, ClusterWeights], False], + ["conv2d-clustering", [ClusterWeights, ClusterWeights, ClusterWeights], True], + [ + "conv2d-sparsity24", + [PruneLowMagnitude, PruneLowMagnitude, PruneLowMagnitude], + False, + ], + [ + "conv2d-sparsity24", + [PruneLowMagnitude, PruneLowMagnitude, PruneLowMagnitude], + True, + ], + ], +) +def test_rewriting_optimizer( # pylint: disable=too-many-locals test_tflite_model_fp32: Path, test_tfrecord_fp32: Path, + test_tflite_model: Path, + test_tfrecord: Path, + rewrite_type: str, + expected_layers: list[object], + quant: bool, ) -> None: - """Test fc_layer rewrite process with rewrite type fully-connected.""" + """Test the rewrite process with all rewrite types.""" + + tfrecord = test_tfrecord if quant else test_tfrecord_fp32 + tflite_model = test_tflite_model if quant else test_tflite_model_fp32 + config_obj = RewriteConfiguration( - "fully-connected", - ["sequential/flatten/Reshape", "StatefulPartitionedCall:0"], - test_tfrecord_fp32, + rewrite_type, + ["sequential/flatten/Reshape", "StatefulPartitionedCall:0"] + if "fully-connected" in rewrite_type + else [ + "sequential/conv1/Relu;sequential/conv1/Conv2D", + "sequential/conv2/Relu;sequential/conv2/Conv2D", + ], + tfrecord, train_params=MockTrainingParameters(), ) - test_obj = RewritingOptimizer(test_tflite_model_fp32, config_obj) + test_obj = RewritingOptimizer(tflite_model, config_obj) + rewrite_function = RewritingOptimizer.registry.items[ + test_obj.optimizer_configuration.optimization_target + ] + # Input, output shape does not matter, just need the test the layers are as expected + rewrite_model = ( + rewrite_function(input_shape=(28, 28, 1), output_shape=12) + if "fully-connected" in rewrite_type + else rewrite_function( + input_shape=np.array([28, 28, 3]), output_shape=np.array([14, 14, 3]) + ) + ) + for idx, layer in enumerate(rewrite_model.layers): + assert isinstance(layer, expected_layers[idx]) # type: ignore + test_obj.apply_optimization() trained_model = test_obj.get_model() @@ -87,11 +345,11 @@ def test_rewriting_optimizer( def test_register_rewrite_function() -> None: - """Test adding rewrite functions and verify the are reported via the registry.""" + """Test adding rewrite functions and verify they are reported via the registry.""" registry = RewriteRegistry() - rewrite1 = Rewrite("r1", cast(RewriteCallable, lambda: 1)) - rewrite2 = Rewrite("r2", cast(RewriteCallable, lambda: 2)) + rewrite1 = TestRewrite("r1", cast(RewriteCallable, lambda: 1)) + rewrite2 = TestRewrite("r2", cast(RewriteCallable, lambda: 2)) registry.register_rewrite(rewrite1) registry.register_rewrite(rewrite2) @@ -100,38 +358,13 @@ def test_register_rewrite_function() -> None: def test_builtin_rewrite_names() -> None: """Test if all builtin rewrites are properly registered and returned.""" - assert RewritingOptimizer.builtin_rewrite_names() == ["fully-connected"] - - -def test_rewrite_function_autoload() -> None: - """Test rewrite function loading.""" - function_name = "tests.test_nn_rewrite_core_rewrite.mock_rewrite_function" - rewrite = DynamicallyLoadedRewrite(name="mock_rewrite", function_name=function_name) - assert rewrite.name == "mock_rewrite" - - assert rewrite.function is not mock_rewrite_function - assert rewrite.load_function(function_name) is mock_rewrite_function - assert rewrite.function is mock_rewrite_function - - -def test_rewrite_function_autoload_fail() -> None: - """Test rewrite function loading failure.""" - function_name = "invalid_module.invalid_function" - rewrite = DynamicallyLoadedRewrite( - name="mock_rewrite", - function_name="invalid_module.invalid_function", - ) - assert rewrite.name == "mock_rewrite" - - with pytest.raises(Exception) as exc_info: - rewrite.load_function(function_name) - - message = exc_info.value.args[0] - - assert message == ( - "Unable to load rewrite function 'invalid_module.invalid_function'" - " for 'mock_rewrite'." - ) + assert RewritingOptimizer.builtin_rewrite_names() == [ + "conv2d-clustering", + "conv2d-sparsity24", + "fully-connected", + "fully-connected-clustering", + "fully-connected-sparsity24", + ] def test_rewrite_configuration_train_params( diff --git a/tests/test_nn_rewrite_core_train.py b/tests/test_nn_rewrite_core_train.py index 6d24133..94c99ff 100644 --- a/tests/test_nn_rewrite_core_train.py +++ b/tests/test_nn_rewrite_core_train.py @@ -20,6 +20,7 @@ from mlia.nn.rewrite.core.train import LearningRateSchedule from mlia.nn.rewrite.core.train import mixup from mlia.nn.rewrite.core.train import train from mlia.nn.rewrite.core.train import TrainingParameters +from tests.test_nn_rewrite_core_rewrite import TestRewrite from tests.utils.rewrite import MockTrainingParameters @@ -53,18 +54,23 @@ def check_train( """Test the train() function.""" with TemporaryDirectory() as tmp_dir: output_file = Path(tmp_dir, "out.tflite") + mock_rewrite = TestRewrite("replace", replace_fully_connected_with_conv) result = train( source_model=str(tflite_model), unmodified_model=str(tflite_model) if use_unmodified_model else None, output_model=str(output_file), input_tfrec=str(tfrecord), - replace_fn=replace_fully_connected_with_conv, + rewrite=mock_rewrite, + is_qat=False, input_tensors=["sequential/flatten/Reshape"], output_tensors=["StatefulPartitionedCall:0"], train_params=train_params, ) - assert len(result) == 2 - assert all(res >= 0.0 for res in result[0]), f"Results out of bound: {result}" + + assert len(result[0][0]) == 2 + assert all( + res >= 0.0 for res in result[0][0] + ), f"Results out of bound: {result}" assert output_file.is_file() if quantized: @@ -229,3 +235,17 @@ def test_augment_fn_twins(augmentations: tuple, expected_error: Any) -> None: with expected_error: fn_twins = augment_fn_twins(dataset, augmentations) # type: ignore assert len(fn_twins) == 2 + + +def test_train_checkpoint( + test_tflite_model: Path, + test_tfrecord: Path, +) -> None: + """Test the train() function with valid checkpoint parameters.""" + check_train( + tflite_model=test_tflite_model, + tfrecord=test_tfrecord, + train_params=MockTrainingParameters(steps=64, checkpoint_at=[24, 32]), + use_unmodified_model=False, + quantized=True, + ) diff --git a/tests/test_nn_rewrite_library_helper_functions.py b/tests/test_nn_rewrite_library_helper_functions.py new file mode 100644 index 0000000..c3117f0 --- /dev/null +++ b/tests/test_nn_rewrite_library_helper_functions.py @@ -0,0 +1,51 @@ +# SPDX-FileCopyrightText: Copyright 2024, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Tests for module mlia.nn.rewrite.library.helper_functions.""" +from __future__ import annotations + +from typing import Any + +import numpy as np +import pytest +from keras.api._v2 import keras # Temporary workaround for now: MLIA-1107 + +from mlia.nn.rewrite.library.helper_functions import compute_conv2d_parameters + + +def compute_conv_output( + input_data: np.ndarray, input_shape: np.ndarray, conv_parameters: dict[str, Any] +) -> np.ndarray: + """Compute the output of a conv layer for testing.""" + test_model = keras.Sequential( + [ + keras.layers.InputLayer(input_shape=input_shape), + keras.layers.Conv2D(**conv_parameters), + ] + ) + output = test_model(input_data) + return np.array(output.shape[1:]) + + +@pytest.mark.parametrize( + "input_shape, output_shape", + [ + (np.array([32, 32, 3]), np.array([16, 16, 3])), + (np.array([32, 32, 3]), np.array([8, 8, 3])), + (np.array([32, 32, 3]), np.array([8, 16, 3])), + (np.array([25, 10, 3]), np.array([13, 5, 3])), + (np.array([25, 10, 3]), np.array([7, 5, 3])), + (np.array([25, 10, 3]), np.array([6, 4, 3])), + (np.array([25, 10, 3]), np.array([5, 5, 3])), + ], +) +def test_compute_conv2d_parameters( + input_shape: np.ndarray, output_shape: np.ndarray +) -> None: + """Test to check compute_conv2d_parameters works as expected.""" + conv_parameters = compute_conv2d_parameters( + input_shape=input_shape, output_shape=output_shape + ) + computed_output_shape = compute_conv_output( + np.random.rand(1, *input_shape), input_shape, conv_parameters + ) + assert np.equal(computed_output_shape, output_shape).all() diff --git a/tests/test_nn_select.py b/tests/test_nn_select.py index aac07b4..4095076 100644 --- a/tests/test_nn_select.py +++ b/tests/test_nn_select.py @@ -183,11 +183,11 @@ def test_get_optimizer( @pytest.mark.parametrize( "rewrite_parameters", - [[None], [{"batch_size": 64, "learning_rate": 0.003}]], + [None, {"batch_size": 64, "learning_rate": 0.003}], ) @pytest.mark.skip_set_training_steps def test_get_optimizer_training_parameters( - rewrite_parameters: list[dict], test_tflite_model: Path + rewrite_parameters: dict | None, test_tflite_model: Path ) -> None: """Test function get_optimzer with various combinations of parameters.""" config = OptimizationSettings( @@ -198,20 +198,18 @@ def test_get_optimizer_training_parameters( ) optimizer = cast( RewritingOptimizer, - get_optimizer(test_tflite_model, config, list(rewrite_parameters)), + get_optimizer(test_tflite_model, config, rewrite_parameters), ) - assert len(rewrite_parameters) == 1 - assert isinstance( optimizer.optimizer_configuration.train_params, TrainingParameters ) - if not rewrite_parameters[0]: + if not rewrite_parameters: assert asdict(TrainingParameters()) == asdict( optimizer.optimizer_configuration.train_params ) else: - assert asdict(TrainingParameters()) | rewrite_parameters[0] == asdict( + assert asdict(TrainingParameters()) | rewrite_parameters == asdict( optimizer.optimizer_configuration.train_params ) diff --git a/tests/test_target_cortex_a_advisor.py b/tests/test_target_cortex_a_advisor.py index 59d54b5..7bb57c3 100644 --- a/tests/test_target_cortex_a_advisor.py +++ b/tests/test_target_cortex_a_advisor.py @@ -47,7 +47,7 @@ def test_configure_and_get_cortex_a_advisor(test_tflite_model: Path) -> None: }, ] ], - "training_parameters": [None], + "training_parameters": None, }, } diff --git a/tests/test_target_tosa_advisor.py b/tests/test_target_tosa_advisor.py index cc47321..020acc5 100644 --- a/tests/test_target_tosa_advisor.py +++ b/tests/test_target_tosa_advisor.py @@ -47,7 +47,7 @@ def test_configure_and_get_tosa_advisor( }, ] ], - "training_parameters": [None], + "training_parameters": None, }, "tosa_inference_advisor": { "model": str(test_tflite_model), |