aboutsummaryrefslogtreecommitdiff
path: root/tests
diff options
context:
space:
mode:
authorNathan Bailey <nathan.bailey@arm.com>2024-04-15 15:34:03 +0100
committerNathan Bailey <nathan.bailey@arm.com>2024-05-08 12:15:48 +0000
commit198ba5eed95677ddb9d1e8c4119062dd3412510a (patch)
treec37154ab0a30a0c76c8ad275aae8ee4e40e2a1b2 /tests
parent0999ba0f4381ce1e2e8b06a932bfe693692223e2 (diff)
downloadmlia-198ba5eed95677ddb9d1e8c4119062dd3412510a.tar.gz
feat: Enables augmentations via --optimization-profile
- Enables user to specify augmentations via the --optimization-profile switch - Can specify from pre-given examples or can provide each parameter manually - Updates README Resolves: MLIA-1147 Signed-off-by: Nathan Bailey <nathan.bailey@arm.com> Change-Id: I9cbe71d85def6a8db9dc974adc4bcc8d90625505
Diffstat (limited to 'tests')
-rw-r--r--tests/test_common_optimization.py58
1 files changed, 58 insertions, 0 deletions
diff --git a/tests/test_common_optimization.py b/tests/test_common_optimization.py
index 341e0d2..58ea8af 100644
--- a/tests/test_common_optimization.py
+++ b/tests/test_common_optimization.py
@@ -1,6 +1,8 @@
# SPDX-FileCopyrightText: Copyright 2024, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Tests for the common optimization module."""
+from __future__ import annotations
+
from contextlib import ExitStack as does_not_raises
from pathlib import Path
from typing import Any
@@ -15,6 +17,7 @@ from mlia.nn.tensorflow.config import TFLiteModel
from mlia.target.common.optimization import _DEFAULT_OPTIMIZATION_TARGETS
from mlia.target.common.optimization import add_common_optimization_params
from mlia.target.common.optimization import OptimizingDataCollector
+from mlia.target.common.optimization import parse_augmentations
from mlia.target.config import load_profile
from mlia.target.config import TargetProfile
@@ -167,3 +170,58 @@ def test_add_common_optimization_params(extra_args: dict, error_to_raise: Any) -
advisor_parameters["common_optimizations"]["training_parameters"]
== extra_args["optimization_profile"]["training"]
)
+
+
+@pytest.mark.parametrize(
+ "augmentations, expected_output",
+ [
+ (
+ {"gaussian_strength": 1.0, "mixup_strength": 1.0},
+ (1.0, 1.0),
+ ),
+ (
+ {"gaussian_strength": 1.0},
+ (None, 1.0),
+ ),
+ (
+ {"Wrong param": 1.0, "mixup_strength": 1.0},
+ (1.0, None),
+ ),
+ (
+ {"Wrong param1": 1.0, "Wrong param2": 1.0},
+ (None, None),
+ ),
+ (
+ "gaussian",
+ (None, 1.0),
+ ),
+ (
+ "mix_gaussian_large",
+ (2.0, 1.0),
+ ),
+ (
+ "not in presets",
+ (None, None),
+ ),
+ (
+ {"gaussian_strength": 1.0, "mixup_strength": 1.0, "mix2": 1.0},
+ (1.0, 1.0),
+ ),
+ (
+ {"gaussian_strength": "not a float", "mixup_strength": 1.0},
+ (1.0, None),
+ ),
+ (
+ None,
+ (None, None),
+ ),
+ ],
+)
+def test_parse_augmentations(
+ augmentations: dict | str | None, expected_output: tuple
+) -> None:
+ """Check that augmentation parameters in optimization_profiles are
+ correctly parsed."""
+
+ augmentation_output = parse_augmentations(augmentations)
+ assert augmentation_output == expected_output