diff options
author | Nathan Bailey <nathan.bailey@arm.com> | 2024-04-15 15:34:03 +0100 |
---|---|---|
committer | Nathan Bailey <nathan.bailey@arm.com> | 2024-05-08 12:15:48 +0000 |
commit | 198ba5eed95677ddb9d1e8c4119062dd3412510a (patch) | |
tree | c37154ab0a30a0c76c8ad275aae8ee4e40e2a1b2 /tests | |
parent | 0999ba0f4381ce1e2e8b06a932bfe693692223e2 (diff) | |
download | mlia-198ba5eed95677ddb9d1e8c4119062dd3412510a.tar.gz |
feat: Enables augmentations via --optimization-profile
- Enables user to specify augmentations via the --optimization-profile switch
- Can specify from pre-given examples or can provide each parameter manually
- Updates README
Resolves: MLIA-1147
Signed-off-by: Nathan Bailey <nathan.bailey@arm.com>
Change-Id: I9cbe71d85def6a8db9dc974adc4bcc8d90625505
Diffstat (limited to 'tests')
-rw-r--r-- | tests/test_common_optimization.py | 58 |
1 file changed, 58 insertions, 0 deletions
@pytest.mark.parametrize(
    "augmentations, expected_output",
    [
        # Both strengths supplied explicitly -> both parsed.
        ({"gaussian_strength": 1.0, "mixup_strength": 1.0}, (1.0, 1.0)),
        # Only gaussian supplied -> mixup slot stays None.
        ({"gaussian_strength": 1.0}, (None, 1.0)),
        # Unknown key is ignored; valid mixup key is still parsed.
        ({"Wrong param": 1.0, "mixup_strength": 1.0}, (1.0, None)),
        # Nothing recognizable -> neither strength parsed.
        ({"Wrong param1": 1.0, "Wrong param2": 1.0}, (None, None)),
        # Named preset lookups.
        ("gaussian", (None, 1.0)),
        ("mix_gaussian_large", (2.0, 1.0)),
        # Unknown preset name -> no strengths.
        ("not in presets", (None, None)),
        # Extra unexpected keys alongside valid ones are ignored.
        ({"gaussian_strength": 1.0, "mixup_strength": 1.0, "mix2": 1.0}, (1.0, 1.0)),
        # Non-numeric strength is rejected for that parameter only.
        ({"gaussian_strength": "not a float", "mixup_strength": 1.0}, (1.0, None)),
        # No augmentations requested at all.
        (None, (None, None)),
    ],
)
def test_parse_augmentations(
    augmentations: dict | str | None, expected_output: tuple
) -> None:
    """Check that augmentation parameters in optimization_profiles are correctly parsed."""
    assert parse_augmentations(augmentations) == expected_output