author     Nathan Bailey <nathan.bailey@arm.com>  2024-04-15 15:34:03 +0100
committer  Nathan Bailey <nathan.bailey@arm.com>  2024-05-08 12:15:48 +0000
commit     198ba5eed95677ddb9d1e8c4119062dd3412510a (patch)
tree       c37154ab0a30a0c76c8ad275aae8ee4e40e2a1b2
parent     0999ba0f4381ce1e2e8b06a932bfe693692223e2 (diff)
download   mlia-198ba5eed95677ddb9d1e8c4119062dd3412510a.tar.gz
feat: Enables augmentations via --optimization-profile
- Enables user to specify augmentations via the --optimization-profile switch
- Can specify from pre-given examples or can provide each parameter manually
- Updates README

Resolves: MLIA-1147

Signed-off-by: Nathan Bailey <nathan.bailey@arm.com>
Change-Id: I9cbe71d85def6a8db9dc974adc4bcc8d90625505
-rw-r--r--  README.md                                                                        | 23
-rw-r--r--  src/mlia/resources/optimization_profiles/optimization.toml                      |  1
-rw-r--r--  src/mlia/resources/optimization_profiles/optimization_custom_augmentation.toml  | 13
-rw-r--r--  src/mlia/target/common/optimization.py                                          | 55
-rw-r--r--  tests/test_common_optimization.py                                               | 58
5 files changed, 146 insertions(+), 4 deletions(-)
diff --git a/README.md b/README.md
index 7d08a16..6c145d1 100644
--- a/README.md
+++ b/README.md
@@ -215,9 +215,26 @@ Training parameters for rewrites can be specified.
There are a number of predefined profiles:
-| Name | Batch Size | LR | Show Progress | Steps | LR Schedule | Num Procs | Num Threads | Checkpoints |
-| :----------: | :--------: | :--: | :-----------: | :---: | :---------: | :-------: | :---------: | :---------: |
-| optimization | 32 | 1e-3 | True | 48000 | "cosine" | 1 | 0 | None |
+| Name | Batch Size | LR | Show Progress | Steps | LR Schedule | Num Procs | Num Threads | Checkpoints | Augmentations |
+| :----------: | :--------: | :--: | :-----------: | :---: | :---------: | :-------: | :---------: | :---------: | :-------------: |
+| optimization | 32 | 1e-3 | True | 48000 | "cosine" | 1 | 0 | None | "gaussian" |
+
+| Name | Batch Size | LR | Show Progress | Steps | LR Schedule | Num Procs | Num Threads | Checkpoints | Augmentations - gaussian_strength | Augmentations - mixup_strength |
+| :------------------------------: | :--------: | :--: | :-----------: | :---: | :---------: | :-------: | :---------: | :---------: | :-------------------------------: | :----------------------------: |
+| optimization_custom_augmentation | 32 | 1e-3 | True | 48000 | "cosine" | 1 | 0 | None | 0.1 | 0.1 |
+
+The augmentations consist of two parameters: `mixup_strength` and `gaussian_strength`.
+
+Augmentations can be selected from a number of predefined profiles (see the table
+below), or each individual parameter can be set manually (see
+optimization_custom_augmentation above for an example):
+
+| Name | MixUp Strength | Gaussian Strength |
+| :------------------: | :------------: | :---------------: |
+| "none" | None | None |
+| "gaussian" | None | 1.0 |
+| "mixup" | 1.0 | None |
+| "mixout" | 1.6 | None |
+| "mix_gaussian_large" | 2.0 | 1.0 |
+| "mix_gaussian_small" | 1.6 | 0.3 |
```bash
##### An example for using optimization Profiles
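
The preset names in the table above map to the `AUGMENTATION_PRESETS` dictionary that the Python change below imports from `mlia.nn.rewrite.core.train`. A minimal sketch of that mapping, reconstructed here from the README table and the test expectations at the end of this commit rather than copied from the source file; each entry is a `(mixup_strength, gaussian_strength)` tuple:

```python
# Sketch only: values come from the README table above; the tuple ordering
# (mixup_strength, gaussian_strength) is inferred from test_parse_augmentations.
AUGMENTATION_PRESETS: dict[str, tuple[float | None, float | None]] = {
    "none": (None, None),
    "gaussian": (None, 1.0),
    "mixup": (1.0, None),
    "mixout": (1.6, None),
    "mix_gaussian_large": (2.0, 1.0),
    "mix_gaussian_small": (1.6, 0.3),
}
```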
diff --git a/src/mlia/resources/optimization_profiles/optimization.toml b/src/mlia/resources/optimization_profiles/optimization.toml
index 623a763..42b64f0 100644
--- a/src/mlia/resources/optimization_profiles/optimization.toml
+++ b/src/mlia/resources/optimization_profiles/optimization.toml
@@ -7,5 +7,6 @@ learning_rate = 1e-3
show_progress = true
steps = 48000
learning_rate_schedule = "cosine"
+augmentations = "gaussian"
num_procs = 1
num_threads = 0
diff --git a/src/mlia/resources/optimization_profiles/optimization_custom_augmentation.toml b/src/mlia/resources/optimization_profiles/optimization_custom_augmentation.toml
new file mode 100644
index 0000000..5d1f917
--- /dev/null
+++ b/src/mlia/resources/optimization_profiles/optimization_custom_augmentation.toml
@@ -0,0 +1,13 @@
+# SPDX-FileCopyrightText: Copyright 2024, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+
+[training]
+batch_size = 32
+learning_rate = 1e-3
+show_progress = true
+steps = 48000
+learning_rate_schedule = "cosine"
+num_procs = 1
+num_threads = 0
+augmentations.gaussian_strength = 0.1
+augmentations.mixup_strength = 0.1
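
Note that the dotted `augmentations.*` keys in this profile parse into a single nested table, so the parsing code added below receives one plain dict rather than two top-level values. A quick sketch, assuming Python 3.11+ for the standard-library `tomllib`:

```python
import tomllib

# The two dotted keys collapse into one nested "augmentations" table.
training = tomllib.loads(
    "augmentations.gaussian_strength = 0.1\n"
    "augmentations.mixup_strength = 0.1\n"
)
assert training["augmentations"] == {
    "gaussian_strength": 0.1,
    "mixup_strength": 0.1,
}
```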
diff --git a/src/mlia/target/common/optimization.py b/src/mlia/target/common/optimization.py
index 1423189..a139a7d 100644
--- a/src/mlia/target/common/optimization.py
+++ b/src/mlia/target/common/optimization.py
@@ -17,6 +17,7 @@ from mlia.core.errors import FunctionalityNotSupportedError
from mlia.core.performance import estimate_performance
from mlia.core.performance import P
from mlia.core.performance import PerformanceEstimator
+from mlia.nn.rewrite.core.train import AUGMENTATION_PRESETS
from mlia.nn.select import get_optimizer
from mlia.nn.select import OptimizationSettings
from mlia.nn.tensorflow.config import get_keras_model
@@ -218,7 +219,54 @@ _DEFAULT_OPTIMIZATION_TARGETS = [
]
-def add_common_optimization_params(advisor_parameters: dict, extra_args: dict) -> None:
+def parse_augmentations(
+    augmentations: dict | str | None,
+) -> tuple[float | None, float | None]:
+    """Parse augmentations from optimization-profile and return a valid tuple."""
+    if isinstance(augmentations, str):
+        match_augmentation = AUGMENTATION_PRESETS.get(augmentations)
+        if not match_augmentation:
+            match_augmentation = AUGMENTATION_PRESETS["none"]
+        return match_augmentation
+    if isinstance(augmentations, dict):
+        augmentation_keys_test_for_valid = list(augmentations.keys())
+        augmentation_keys_test_for_float = list(augmentations.keys())
+        valid_keys = ["mixup_strength", "gaussian_strength"]
+        tuple_to_return = []
+        for valid_key in valid_keys.copy():
+            if augmentations.get(valid_key):
+                del augmentation_keys_test_for_valid[
+                    augmentation_keys_test_for_valid.index(valid_key)
+                ]
+                if isinstance(augmentations.get(valid_key), float):
+                    tuple_to_return.append(augmentations[valid_key])
+                    del augmentation_keys_test_for_float[
+                        augmentation_keys_test_for_float.index(valid_key)
+                    ]
+                else:
+                    tuple_to_return.append(None)
+            else:
+                tuple_to_return.append(None)
+
+        if len(augmentation_keys_test_for_valid) > 0:
+            logger.warning(
+                "Warning! Expected augmentation parameters to be 'gaussian_strength' "
+                "and/or 'mixup_strength', got %s. "
+                "Removing invalid augmentations",
+                str(list(augmentations.keys())),
+            )
+        elif len(augmentation_keys_test_for_float) > 0:
+            logger.warning(
+                "Warning! Not all augmentation parameters were floats, "
+                "removing non-float augmentations"
+            )
+        return (tuple_to_return[0], tuple_to_return[1])
+    return AUGMENTATION_PRESETS["none"]
+
+
+def add_common_optimization_params(  # pylint: disable=too-many-branches
+    advisor_parameters: dict, extra_args: dict
+) -> None:
"""Add common optimization parameters."""
optimization_targets = extra_args.get("optimization_targets")
if not optimization_targets:
@@ -234,6 +282,11 @@ def add_common_optimization_params(advisor_parameters: dict, extra_args: dict) -
raise TypeError("Training Parameter values has wrong format.")
training_parameters = extra_args["optimization_profile"].get("training")
+ if training_parameters:
+ training_parameters["augmentations"] = parse_augmentations(
+ training_parameters.get("augmentations")
+ )
+
advisor_parameters.update(
{
"common_optimizations": {
diff --git a/tests/test_common_optimization.py b/tests/test_common_optimization.py
index 341e0d2..58ea8af 100644
--- a/tests/test_common_optimization.py
+++ b/tests/test_common_optimization.py
@@ -1,6 +1,8 @@
# SPDX-FileCopyrightText: Copyright 2024, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Tests for the common optimization module."""
+from __future__ import annotations
+
from contextlib import ExitStack as does_not_raises
from pathlib import Path
from typing import Any
@@ -15,6 +17,7 @@ from mlia.nn.tensorflow.config import TFLiteModel
from mlia.target.common.optimization import _DEFAULT_OPTIMIZATION_TARGETS
from mlia.target.common.optimization import add_common_optimization_params
from mlia.target.common.optimization import OptimizingDataCollector
+from mlia.target.common.optimization import parse_augmentations
from mlia.target.config import load_profile
from mlia.target.config import TargetProfile
@@ -167,3 +170,58 @@ def test_add_common_optimization_params(extra_args: dict, error_to_raise: Any) -
        advisor_parameters["common_optimizations"]["training_parameters"]
        == extra_args["optimization_profile"]["training"]
    )
+
+
+@pytest.mark.parametrize(
+    "augmentations, expected_output",
+    [
+        (
+            {"gaussian_strength": 1.0, "mixup_strength": 1.0},
+            (1.0, 1.0),
+        ),
+        (
+            {"gaussian_strength": 1.0},
+            (None, 1.0),
+        ),
+        (
+            {"Wrong param": 1.0, "mixup_strength": 1.0},
+            (1.0, None),
+        ),
+        (
+            {"Wrong param1": 1.0, "Wrong param2": 1.0},
+            (None, None),
+        ),
+        (
+            "gaussian",
+            (None, 1.0),
+        ),
+        (
+            "mix_gaussian_large",
+            (2.0, 1.0),
+        ),
+        (
+            "not in presets",
+            (None, None),
+        ),
+        (
+            {"gaussian_strength": 1.0, "mixup_strength": 1.0, "mix2": 1.0},
+            (1.0, 1.0),
+        ),
+        (
+            {"gaussian_strength": "not a float", "mixup_strength": 1.0},
+            (1.0, None),
+        ),
+        (
+            None,
+            (None, None),
+        ),
+    ],
+)
+def test_parse_augmentations(
+    augmentations: dict | str | None, expected_output: tuple
+) -> None:
+    """Check that augmentation parameters in optimization_profiles are
+    correctly parsed."""
+
+    augmentation_output = parse_augmentations(augmentations)
+    assert augmentation_output == expected_output
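
For completeness, a hypothetical end-to-end sketch of how the new parsing slots into `add_common_optimization_params`; the input dict shape follows the code above, and the shape of `advisor_parameters` is taken from the `"common_optimizations"` / `"training_parameters"` keys visible in the existing test:

```python
# Assumed inputs for illustration; only the keys shown in this diff are used.
advisor_parameters: dict = {}
extra_args = {
    "optimization_profile": {
        "training": {"batch_size": 32, "augmentations": "gaussian"},
    },
}
add_common_optimization_params(advisor_parameters, extra_args)

training = advisor_parameters["common_optimizations"]["training_parameters"]
assert training["augmentations"] == (None, 1.0)  # preset expanded to a tuple
```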