Diffstat (limited to 'src/mlia')
 -rw-r--r--  src/mlia/resources/optimization_profiles/optimization.toml                      |  1 +
 -rw-r--r--  src/mlia/resources/optimization_profiles/optimization_custom_augmentation.toml  | 13 +++
 -rw-r--r--  src/mlia/target/common/optimization.py                                          | 45 ++++-
 3 files changed, 58 insertions(+), 1 deletion(-)
diff --git a/src/mlia/resources/optimization_profiles/optimization.toml b/src/mlia/resources/optimization_profiles/optimization.toml
index 623a763..42b64f0 100644
--- a/src/mlia/resources/optimization_profiles/optimization.toml
+++ b/src/mlia/resources/optimization_profiles/optimization.toml
@@ -7,5 +7,6 @@ learning_rate = 1e-3
show_progress = true
steps = 48000
learning_rate_schedule = "cosine"
+augmentations = "gaussian"
num_procs = 1
num_threads = 0
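
Note: the string form of augmentations selects a named preset from AUGMENTATION_PRESETS in mlia.nn.rewrite.core.train (imported by the Python change below). A minimal sketch of the lookup, with resolve_preset as a hypothetical helper and the preset strengths assumed for illustration; only the "none" and "gaussian" names appear in this patch, and the tuples are ordered (mixup_strength, gaussian_strength) as in parse_augmentations below:

    # Stand-in for AUGMENTATION_PRESETS; the concrete strengths here are
    # assumptions for illustration only.
    AUGMENTATION_PRESETS = {
        "none": (None, None),     # (mixup_strength, gaussian_strength)
        "gaussian": (None, 1.0),  # assumed gaussian strength
    }

    def resolve_preset(name: str) -> tuple[float | None, float | None]:
        # Unknown preset names fall back to "none", matching the fallback
        # behaviour in parse_augmentations below.
        return AUGMENTATION_PRESETS.get(name, AUGMENTATION_PRESETS["none"])

    print(resolve_preset("gaussian"))  # (None, 1.0) with the assumed table
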
diff --git a/src/mlia/resources/optimization_profiles/optimization_custom_augmentation.toml b/src/mlia/resources/optimization_profiles/optimization_custom_augmentation.toml
new file mode 100644
index 0000000..5d1f917
--- /dev/null
+++ b/src/mlia/resources/optimization_profiles/optimization_custom_augmentation.toml
@@ -0,0 +1,13 @@
+# SPDX-FileCopyrightText: Copyright 2024, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+
+[training]
+batch_size = 32
+learning_rate = 1e-3
+show_progress = true
+steps = 48000
+learning_rate_schedule = "cosine"
+num_procs = 1
+num_threads = 0
+augmentations.gaussian_strength = 0.1
+augmentations.mixup_strength = 0.1
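
Note: the dotted keys above load as a single nested TOML table, so the training section reaches parse_augmentations as a dict rather than a string. A quick illustration with Python's standard tomllib (Python 3.11+; shown for illustration only, the loader MLIA itself uses is not part of this patch):

    import tomllib

    toml_text = (
        "[training]\n"
        "augmentations.gaussian_strength = 0.1\n"
        "augmentations.mixup_strength = 0.1\n"
    )
    profile = tomllib.loads(toml_text)
    print(profile["training"]["augmentations"])
    # {'gaussian_strength': 0.1, 'mixup_strength': 0.1}
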
diff --git a/src/mlia/target/common/optimization.py b/src/mlia/target/common/optimization.py
index 1423189..a139a7d 100644
--- a/src/mlia/target/common/optimization.py
+++ b/src/mlia/target/common/optimization.py
@@ -17,6 +17,7 @@ from mlia.core.errors import FunctionalityNotSupportedError
from mlia.core.performance import estimate_performance
from mlia.core.performance import P
from mlia.core.performance import PerformanceEstimator
+from mlia.nn.rewrite.core.train import AUGMENTATION_PRESETS
from mlia.nn.select import get_optimizer
from mlia.nn.select import OptimizationSettings
from mlia.nn.tensorflow.config import get_keras_model
@@ -218,7 +219,45 @@ _DEFAULT_OPTIMIZATION_TARGETS = [
]
-def add_common_optimization_params(advisor_parameters: dict, extra_args: dict) -> None:
+def parse_augmentations(
+    augmentations: dict | str | None,
+) -> tuple[float | None, float | None]:
+    """Parse augmentations from the optimization profile and return a valid tuple."""
+    if isinstance(augmentations, str):
+        # A preset name selects a predefined (mixup_strength, gaussian_strength)
+        # pair; unknown names fall back to the "none" preset.
+        return AUGMENTATION_PRESETS.get(augmentations, AUGMENTATION_PRESETS["none"])
+    if isinstance(augmentations, dict):
+        valid_keys = ("mixup_strength", "gaussian_strength")
+        invalid_keys = [key for key in augmentations if key not in valid_keys]
+        non_float_keys = [
+            key
+            for key in valid_keys
+            if key in augmentations and not isinstance(augmentations[key], float)
+        ]
+        if invalid_keys:
+            logger.warning(
+                "Expected augmentation parameters to be 'gaussian_strength' "
+                "and/or 'mixup_strength', got %s. Ignoring invalid augmentations.",
+                list(augmentations.keys()),
+            )
+        elif non_float_keys:
+            logger.warning(
+                "Not all augmentation parameters were floats; "
+                "ignoring non-float augmentations."
+            )
+        mixup = augmentations.get("mixup_strength")
+        gaussian = augmentations.get("gaussian_strength")
+        return (
+            mixup if isinstance(mixup, float) else None,
+            gaussian if isinstance(gaussian, float) else None,
+        )
+    return AUGMENTATION_PRESETS["none"]
+
+
+def add_common_optimization_params( # pylint: disable=too-many-branches
+ advisor_parameters: dict, extra_args: dict
+) -> None:
"""Add common optimization parameters."""
optimization_targets = extra_args.get("optimization_targets")
if not optimization_targets:
@@ -234,6 +273,11 @@ def add_common_optimization_params(advisor_parameters: dict, extra_args: dict) -> None:
raise TypeError("Training parameter values have the wrong format.")
training_parameters = extra_args["optimization_profile"].get("training")
+ if training_parameters:
+ training_parameters["augmentations"] = parse_augmentations(
+ training_parameters.get("augmentations")
+ )
+
advisor_parameters.update(
{
"common_optimizations": {