diff options
Diffstat (limited to 'tests')
-rw-r--r-- | tests/test_common_optimization.py | 58 |
1 files changed, 58 insertions, 0 deletions
diff --git a/tests/test_common_optimization.py b/tests/test_common_optimization.py index 341e0d2..58ea8af 100644 --- a/tests/test_common_optimization.py +++ b/tests/test_common_optimization.py @@ -1,6 +1,8 @@ # SPDX-FileCopyrightText: Copyright 2024, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """Tests for the common optimization module.""" +from __future__ import annotations + from contextlib import ExitStack as does_not_raises from pathlib import Path from typing import Any @@ -15,6 +17,7 @@ from mlia.nn.tensorflow.config import TFLiteModel from mlia.target.common.optimization import _DEFAULT_OPTIMIZATION_TARGETS from mlia.target.common.optimization import add_common_optimization_params from mlia.target.common.optimization import OptimizingDataCollector +from mlia.target.common.optimization import parse_augmentations from mlia.target.config import load_profile from mlia.target.config import TargetProfile @@ -167,3 +170,58 @@ def test_add_common_optimization_params(extra_args: dict, error_to_raise: Any) - advisor_parameters["common_optimizations"]["training_parameters"] == extra_args["optimization_profile"]["training"] ) + + +@pytest.mark.parametrize( + "augmentations, expected_output", + [ + ( + {"gaussian_strength": 1.0, "mixup_strength": 1.0}, + (1.0, 1.0), + ), + ( + {"gaussian_strength": 1.0}, + (None, 1.0), + ), + ( + {"Wrong param": 1.0, "mixup_strength": 1.0}, + (1.0, None), + ), + ( + {"Wrong param1": 1.0, "Wrong param2": 1.0}, + (None, None), + ), + ( + "gaussian", + (None, 1.0), + ), + ( + "mix_gaussian_large", + (2.0, 1.0), + ), + ( + "not in presets", + (None, None), + ), + ( + {"gaussian_strength": 1.0, "mixup_strength": 1.0, "mix2": 1.0}, + (1.0, 1.0), + ), + ( + {"gaussian_strength": "not a float", "mixup_strength": 1.0}, + (1.0, None), + ), + ( + None, + (None, None), + ), + ], +) +def test_parse_augmentations( + augmentations: dict | str | None, expected_output: tuple +) -> None: + """Check that augmentation parameters in optimization_profiles are + correctly parsed.""" + + augmentation_output = parse_augmentations(augmentations) + assert augmentation_output == expected_output |