aboutsummaryrefslogtreecommitdiff
path: root/tests
diff options
context:
space:
mode:
Diffstat (limited to 'tests')
-rw-r--r--tests/test_common_optimization.py58
1 files changed, 58 insertions, 0 deletions
diff --git a/tests/test_common_optimization.py b/tests/test_common_optimization.py
index 341e0d2..58ea8af 100644
--- a/tests/test_common_optimization.py
+++ b/tests/test_common_optimization.py
@@ -1,6 +1,8 @@
# SPDX-FileCopyrightText: Copyright 2024, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Tests for the common optimization module."""
+from __future__ import annotations
+
from contextlib import ExitStack as does_not_raises
from pathlib import Path
from typing import Any
@@ -15,6 +17,7 @@ from mlia.nn.tensorflow.config import TFLiteModel
from mlia.target.common.optimization import _DEFAULT_OPTIMIZATION_TARGETS
from mlia.target.common.optimization import add_common_optimization_params
from mlia.target.common.optimization import OptimizingDataCollector
+from mlia.target.common.optimization import parse_augmentations
from mlia.target.config import load_profile
from mlia.target.config import TargetProfile
@@ -167,3 +170,58 @@ def test_add_common_optimization_params(extra_args: dict, error_to_raise: Any) -
advisor_parameters["common_optimizations"]["training_parameters"]
== extra_args["optimization_profile"]["training"]
)
+
+
+@pytest.mark.parametrize(
+ "augmentations, expected_output",
+ [
+ (
+ {"gaussian_strength": 1.0, "mixup_strength": 1.0},
+ (1.0, 1.0),
+ ),
+ (
+ {"gaussian_strength": 1.0},
+ (None, 1.0),
+ ),
+ (
+ {"Wrong param": 1.0, "mixup_strength": 1.0},
+ (1.0, None),
+ ),
+ (
+ {"Wrong param1": 1.0, "Wrong param2": 1.0},
+ (None, None),
+ ),
+ (
+ "gaussian",
+ (None, 1.0),
+ ),
+ (
+ "mix_gaussian_large",
+ (2.0, 1.0),
+ ),
+ (
+ "not in presets",
+ (None, None),
+ ),
+ (
+ {"gaussian_strength": 1.0, "mixup_strength": 1.0, "mix2": 1.0},
+ (1.0, 1.0),
+ ),
+ (
+ {"gaussian_strength": "not a float", "mixup_strength": 1.0},
+ (1.0, None),
+ ),
+ (
+ None,
+ (None, None),
+ ),
+ ],
+)
+def test_parse_augmentations(
+ augmentations: dict | str | None, expected_output: tuple
+) -> None:
+ """Check that augmentation parameters in optimization_profiles are
+ correctly parsed."""
+
+ augmentation_output = parse_augmentations(augmentations)
+ assert augmentation_output == expected_output