diff options
Diffstat (limited to 'tests/test_common_optimization.py')
-rw-r--r-- | tests/test_common_optimization.py | 106 |
1 files changed, 103 insertions, 3 deletions
diff --git a/tests/test_common_optimization.py b/tests/test_common_optimization.py index 599610d..05a5b55 100644 --- a/tests/test_common_optimization.py +++ b/tests/test_common_optimization.py @@ -1,15 +1,21 @@ -# SPDX-FileCopyrightText: Copyright 2023, Arm Limited and/or its affiliates. +# SPDX-FileCopyrightText: Copyright 2024, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """Tests for the common optimization module.""" +from contextlib import ExitStack as does_not_raises from pathlib import Path +from typing import Any from unittest.mock import MagicMock import pytest from mlia.core.context import ExecutionContext from mlia.nn.common import Optimizer +from mlia.nn.select import OptimizationSettings from mlia.nn.tensorflow.config import TFLiteModel +from mlia.target.common.optimization import _DEFAULT_OPTIMIZATION_TARGETS +from mlia.target.common.optimization import add_common_optimization_params from mlia.target.common.optimization import OptimizingDataCollector +from mlia.target.config import load_profile from mlia.target.config import TargetProfile @@ -46,8 +52,14 @@ def test_optimizing_data_collector( {"optimization_type": "fake", "optimization_target": 42}, ] ] + training_parameters = {"batch_size": 32, "show_progress": False} context = ExecutionContext( - config_parameters={"common_optimizations": {"optimizations": optimizations}} + config_parameters={ + "common_optimizations": { + "optimizations": optimizations, + "training_parameters": [training_parameters], + } + } ) target_profile = MagicMock(spec=TargetProfile) @@ -61,7 +73,95 @@ def test_optimizing_data_collector( collector = OptimizingDataCollector(test_keras_model, target_profile) + optimize_model_mock = MagicMock(side_effect=collector.optimize_model) + monkeypatch.setattr( + "mlia.target.common.optimization.OptimizingDataCollector.optimize_model", + optimize_model_mock, + ) + opt_settings = [ + [ + OptimizationSettings( + item.get("optimization_type"), # type: ignore + item.get("optimization_target"), # type: ignore + item.get("layers_to_optimize"), # type: ignore + item.get("dataset"), # type: ignore + ) + for item in opt_configuration + ] + for opt_configuration in optimizations + ] + collector.set_context(context) collector.collect_data() - + assert optimize_model_mock.call_args.args[0] == opt_settings[0] + assert optimize_model_mock.call_args.args[1] == [training_parameters] assert fake_optimizer.invocation_count == 1 + + +@pytest.mark.parametrize( + "extra_args, error_to_raise", + [ + ( + { + "optimization_targets": [ + { + "optimization_type": "pruning", + "optimization_target": 0.5, + "layers_to_optimize": None, + } + ], + }, + does_not_raises(), + ), + ( + { + "optimization_profile": load_profile( + "src/mlia/resources/optimization_profiles/optimization.toml" + ) + }, + does_not_raises(), + ), + ( + { + "optimization_targets": { + "optimization_type": "pruning", + "optimization_target": 0.5, + "layers_to_optimize": None, + }, + }, + pytest.raises( + TypeError, match="Optimization targets value has wrong format." + ), + ), + ( + {"optimization_profile": [32, 1e-3, True, 48000, "cosine", 1, 0]}, + pytest.raises( + TypeError, match="Training Parameter values has wrong format." + ), + ), + ], +) +def test_add_common_optimization_params(extra_args: dict, error_to_raise: Any) -> None: + """Test to check that optimization_targets and optimization_profiles are + correctly parsed.""" + advisor_parameters: dict = {} + + with error_to_raise: + add_common_optimization_params(advisor_parameters, extra_args) + if not extra_args.get("optimization_targets"): + assert advisor_parameters["common_optimizations"]["optimizations"] == [ + _DEFAULT_OPTIMIZATION_TARGETS + ] + else: + assert advisor_parameters["common_optimizations"]["optimizations"] == [ + extra_args["optimization_targets"] + ] + + if not extra_args.get("optimization_profile"): + assert advisor_parameters["common_optimizations"][ + "training_parameters" + ] == [None] + else: + assert advisor_parameters["common_optimizations"][ + "training_parameters" + ] == list(extra_args["optimization_profile"].values()) |