Diffstat (limited to 'tests/test_common_optimization.py')
-rw-r--r--  tests/test_common_optimization.py  106
1 file changed, 103 insertions, 3 deletions
diff --git a/tests/test_common_optimization.py b/tests/test_common_optimization.py
index 599610d..05a5b55 100644
--- a/tests/test_common_optimization.py
+++ b/tests/test_common_optimization.py
@@ -1,15 +1,21 @@
-# SPDX-FileCopyrightText: Copyright 2023, Arm Limited and/or its affiliates.
+# SPDX-FileCopyrightText: Copyright 2024, Arm Limited and/or its affiliates.
 # SPDX-License-Identifier: Apache-2.0
 """Tests for the common optimization module."""
+from contextlib import ExitStack as does_not_raises
 from pathlib import Path
+from typing import Any
 from unittest.mock import MagicMock
 
 import pytest
 
 from mlia.core.context import ExecutionContext
 from mlia.nn.common import Optimizer
+from mlia.nn.select import OptimizationSettings
 from mlia.nn.tensorflow.config import TFLiteModel
+from mlia.target.common.optimization import _DEFAULT_OPTIMIZATION_TARGETS
+from mlia.target.common.optimization import add_common_optimization_params
 from mlia.target.common.optimization import OptimizingDataCollector
+from mlia.target.config import load_profile
 from mlia.target.config import TargetProfile
@@ -46,8 +52,14 @@ def test_optimizing_data_collector(
             {"optimization_type": "fake", "optimization_target": 42},
         ]
     ]
+    training_parameters = {"batch_size": 32, "show_progress": False}
     context = ExecutionContext(
-        config_parameters={"common_optimizations": {"optimizations": optimizations}}
+        config_parameters={
+            "common_optimizations": {
+                "optimizations": optimizations,
+                "training_parameters": [training_parameters],
+            }
+        }
     )
 
     target_profile = MagicMock(spec=TargetProfile)
@@ -61,7 +73,95 @@ def test_optimizing_data_collector(
 
     collector = OptimizingDataCollector(test_keras_model, target_profile)
 
+    optimize_model_mock = MagicMock(side_effect=collector.optimize_model)
+    monkeypatch.setattr(
+        "mlia.target.common.optimization.OptimizingDataCollector.optimize_model",
+        optimize_model_mock,
+    )
+    opt_settings = [
+        [
+            OptimizationSettings(
+                item.get("optimization_type"),  # type: ignore
+                item.get("optimization_target"),  # type: ignore
+                item.get("layers_to_optimize"),  # type: ignore
+                item.get("dataset"),  # type: ignore
+            )
+            for item in opt_configuration
+        ]
+        for opt_configuration in optimizations
+    ]
+
     collector.set_context(context)
     collector.collect_data()
-
+    assert optimize_model_mock.call_args.args[0] == opt_settings[0]
+    assert optimize_model_mock.call_args.args[1] == [training_parameters]
     assert fake_optimizer.invocation_count == 1
+
+
+@pytest.mark.parametrize(
+    "extra_args, error_to_raise",
+    [
+        (
+            {
+                "optimization_targets": [
+                    {
+                        "optimization_type": "pruning",
+                        "optimization_target": 0.5,
+                        "layers_to_optimize": None,
+                    }
+                ],
+            },
+            does_not_raises(),
+        ),
+        (
+            {
+                "optimization_profile": load_profile(
+                    "src/mlia/resources/optimization_profiles/optimization.toml"
+                )
+            },
+            does_not_raises(),
+        ),
+        (
+            {
+                "optimization_targets": {
+                    "optimization_type": "pruning",
+                    "optimization_target": 0.5,
+                    "layers_to_optimize": None,
+                },
+            },
+            pytest.raises(
+                TypeError, match="Optimization targets value has wrong format."
+            ),
+        ),
+        (
+            {"optimization_profile": [32, 1e-3, True, 48000, "cosine", 1, 0]},
+            pytest.raises(
+                TypeError, match="Training Parameter values has wrong format."
+            ),
+        ),
+    ],
+)
+def test_add_common_optimization_params(extra_args: dict, error_to_raise: Any) -> None:
+    """Test to check that optimization_targets and optimization_profiles are
+    correctly parsed."""
+    advisor_parameters: dict = {}
+
+    with error_to_raise:
+        add_common_optimization_params(advisor_parameters, extra_args)
+        if not extra_args.get("optimization_targets"):
+            assert advisor_parameters["common_optimizations"]["optimizations"] == [
+                _DEFAULT_OPTIMIZATION_TARGETS
+            ]
+        else:
+            assert advisor_parameters["common_optimizations"]["optimizations"] == [
+                extra_args["optimization_targets"]
+            ]
+
+        if not extra_args.get("optimization_profile"):
+            assert advisor_parameters["common_optimizations"][
+                "training_parameters"
+            ] == [None]
+        else:
+            assert advisor_parameters["common_optimizations"][
+                "training_parameters"
+            ] == list(extra_args["optimization_profile"].values())
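For context, a minimal usage sketch of add_common_optimization_params, based only on what the new test asserts. The profile path is the one used in the test, so the sketch assumes it runs from an MLIA source checkout where that path resolves:

# Sketch: how the new test exercises add_common_optimization_params.
# Assumes an MLIA source checkout so the profile path below resolves.
from mlia.target.common.optimization import add_common_optimization_params
from mlia.target.config import load_profile

advisor_parameters: dict = {}
extra_args = {
    "optimization_targets": [
        {
            "optimization_type": "pruning",
            "optimization_target": 0.5,
            "layers_to_optimize": None,
        }
    ],
    "optimization_profile": load_profile(
        "src/mlia/resources/optimization_profiles/optimization.toml"
    ),
}
add_common_optimization_params(advisor_parameters, extra_args)

# According to the assertions in test_add_common_optimization_params,
# advisor_parameters now holds:
#   advisor_parameters["common_optimizations"]["optimizations"]
#       == [extra_args["optimization_targets"]]
#   advisor_parameters["common_optimizations"]["training_parameters"]
#       == list(extra_args["optimization_profile"].values())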