diff options
Diffstat (limited to 'tests/test_nn_select.py')
-rw-r--r-- | tests/test_nn_select.py | 69 |
1 files changed, 66 insertions, 3 deletions
diff --git a/tests/test_nn_select.py b/tests/test_nn_select.py index 31628d2..92b7a3d 100644 --- a/tests/test_nn_select.py +++ b/tests/test_nn_select.py @@ -1,16 +1,21 @@ -# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates. +# SPDX-FileCopyrightText: Copyright 2022-2024, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """Tests for module select.""" from __future__ import annotations from contextlib import ExitStack as does_not_raise +from dataclasses import asdict from pathlib import Path from typing import Any +from typing import cast import pytest import tensorflow as tf from mlia.core.errors import ConfigurationError +from mlia.nn.rewrite.core.rewrite import RewriteConfiguration +from mlia.nn.rewrite.core.rewrite import RewritingOptimizer +from mlia.nn.rewrite.core.rewrite import TrainingParameters from mlia.nn.select import get_optimizer from mlia.nn.select import MultiStageOptimizer from mlia.nn.select import OptimizationSettings @@ -135,6 +140,23 @@ from mlia.nn.tensorflow.optimizations.pruning import PruningConfiguration MultiStageOptimizer, "pruning: 0.5 - clustering: 32", ), + ( + OptimizationSettings( + optimization_type="rewrite", + optimization_target="fully_connected", # type: ignore + layers_to_optimize=None, + dataset=None, + ), + does_not_raise(), + RewritingOptimizer, + "rewrite: fully_connected", + ), + ( + RewriteConfiguration("fully_connected"), + does_not_raise(), + RewritingOptimizer, + "rewrite: fully_connected", + ), ], ) def test_get_optimizer( @@ -143,17 +165,58 @@ def test_get_optimizer( expected_type: type, expected_config: str, test_keras_model: Path, + test_tflite_model: Path, ) -> None: """Test function get_optimzer.""" - model = tf.keras.models.load_model(str(test_keras_model)) - with expected_error: + if ( + isinstance(config, OptimizationSettings) + and config.optimization_type == "rewrite" + ) or isinstance(config, RewriteConfiguration): + model = test_tflite_model + else: + model = tf.keras.models.load_model(str(test_keras_model)) optimizer = get_optimizer(model, config) assert isinstance(optimizer, expected_type) assert optimizer.optimization_config() == expected_config @pytest.mark.parametrize( + "rewrite_parameters", + [[None], [{"batch_size": 64, "learning_rate": 0.003}]], +) +@pytest.mark.skip_set_training_steps +def test_get_optimizer_training_parameters( + rewrite_parameters: list[dict], test_tflite_model: Path +) -> None: + """Test function get_optimzer with various combinations of parameters.""" + config = OptimizationSettings( + optimization_type="rewrite", + optimization_target="fully_connected", # type: ignore + layers_to_optimize=None, + dataset=None, + ) + optimizer = cast( + RewritingOptimizer, + get_optimizer(test_tflite_model, config, list(rewrite_parameters)), + ) + + assert len(rewrite_parameters) == 1 + + assert isinstance( + optimizer.optimizer_configuration.train_params, TrainingParameters + ) + if not rewrite_parameters[0]: + assert asdict(TrainingParameters()) == asdict( + optimizer.optimizer_configuration.train_params + ) + else: + assert asdict(TrainingParameters()) | rewrite_parameters[0] == asdict( + optimizer.optimizer_configuration.train_params + ) + + +@pytest.mark.parametrize( "params, expected_result", [ ( |