Diffstat (limited to 'src/mlia/target/common/optimization.py')
 src/mlia/target/common/optimization.py | 68
 1 file changed, 60 insertions(+), 8 deletions(-)
diff --git a/src/mlia/target/common/optimization.py b/src/mlia/target/common/optimization.py
index 8c5d184..a139a7d 100644
--- a/src/mlia/target/common/optimization.py
+++ b/src/mlia/target/common/optimization.py
@@ -17,6 +17,7 @@ from mlia.core.errors import FunctionalityNotSupportedError
 from mlia.core.performance import estimate_performance
 from mlia.core.performance import P
 from mlia.core.performance import PerformanceEstimator
+from mlia.nn.rewrite.core.train import AUGMENTATION_PRESETS
 from mlia.nn.select import get_optimizer
 from mlia.nn.select import OptimizationSettings
 from mlia.nn.tensorflow.config import get_keras_model
@@ -86,7 +87,7 @@ class OptimizingDataCollector(ContextAwareDataCollector):
     def optimize_model(
         self,
         opt_settings: list[OptimizationSettings],
-        training_parameters: list[dict | None],
+        training_parameters: dict | None,
         model: KerasModel | TFLiteModel,
     ) -> Any:
         """Run optimization."""
@@ -123,12 +124,12 @@
             context=context,
         )

-    def _get_training_settings(self, context: Context) -> list[dict]:
+    def _get_training_settings(self, context: Context) -> dict:
         """Get optimization settings."""
         return self.get_parameter(  # type: ignore
             OptimizingDataCollector.name(),
             "training_parameters",
-            expected_type=list,
+            expected_type=dict,
             expected=False,
             context=context,
         )
@@ -218,7 +219,54 @@ _DEFAULT_OPTIMIZATION_TARGETS = [
 ]


-def add_common_optimization_params(advisor_parameters: dict, extra_args: dict) -> None:
+def parse_augmentations(
+    augmentations: dict | str | None,
+) -> tuple[float | None, float | None]:
+    """Parse augmentations from optimization-profile and return a valid tuple."""
+    if isinstance(augmentations, str):
+        match_augmentation = AUGMENTATION_PRESETS.get(augmentations)
+        if not match_augmentation:
+            match_augmentation = AUGMENTATION_PRESETS["none"]
+        return match_augmentation
+    if isinstance(augmentations, dict):
+        augmentation_keys_test_for_valid = list(augmentations.keys())
+        augmentation_keys_test_for_float = list(augmentations.keys())
+        valid_keys = ["mixup_strength", "gaussian_strength"]
+        tuple_to_return = []
+        for valid_key in valid_keys.copy():
+            if augmentations.get(valid_key):
+                del augmentation_keys_test_for_valid[
+                    augmentation_keys_test_for_valid.index(valid_key)
+                ]
+                if isinstance(augmentations.get(valid_key), float):
+                    tuple_to_return.append(augmentations[valid_key])
+                    del augmentation_keys_test_for_float[
+                        augmentation_keys_test_for_float.index(valid_key)
+                    ]
+                else:
+                    tuple_to_return.append(None)
+            else:
+                tuple_to_return.append(None)
+
+        if len(augmentation_keys_test_for_valid) > 0:
+            logger.warning(
+                "Warning! Expected augmentation parameters to be 'gaussian_strength' "
+                "and/or 'mixup_strength' got %s. "
+                "Removing invalid augmentations",
+                str(list(augmentations.keys())),
+            )
+        elif len(augmentation_keys_test_for_float) > 0:
+            logger.warning(
+                "Warning! Not all augmentation parameters were floats, "
+                "removing non-float augmentations"
+            )
+        return (tuple_to_return[0], tuple_to_return[1])
+    return AUGMENTATION_PRESETS["none"]
+
+
+def add_common_optimization_params(  # pylint: disable=too-many-branches
+    advisor_parameters: dict, extra_args: dict
+) -> None:
     """Add common optimization parameters."""
     optimization_targets = extra_args.get("optimization_targets")
     if not optimization_targets:
@@ -228,18 +276,22 @@ def add_common_optimization_params(advisor_parameters: dict, extra_args: dict) -> None:
         raise TypeError("Optimization targets value has wrong format.")

     rewrite_parameters = extra_args.get("optimization_profile")
-    if not rewrite_parameters:
-        training_parameters = None
-    else:
+    training_parameters = None
+    if rewrite_parameters:
         if not isinstance(rewrite_parameters, dict):
             raise TypeError("Training Parameter values has wrong format.")
         training_parameters = extra_args["optimization_profile"].get("training")

+    if training_parameters:
+        training_parameters["augmentations"] = parse_augmentations(
+            training_parameters.get("augmentations")
+        )
+
     advisor_parameters.update(
         {
             "common_optimizations": {
                 "optimizations": [optimization_targets],
-                "training_parameters": [training_parameters],
+                "training_parameters": training_parameters,
             },
         }
     )
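
For reference, below is a minimal usage sketch of the parse_augmentations() helper added by this patch. It assumes the patched mlia package is importable; the concrete tuples behind the preset names live in AUGMENTATION_PRESETS (mlia/nn/rewrite/core/train.py) and are not shown in this diff, so only behaviour that follows from the code above is asserted.

# Usage sketch for parse_augmentations() (added by this patch).
from mlia.target.common.optimization import parse_augmentations

# Strings are looked up in AUGMENTATION_PRESETS; unknown names, and any
# value that is neither str nor dict, fall back to the "none" preset.
print(parse_augmentations("mixup"))           # preset tuple (values not in this diff)
print(parse_augmentations("no_such_preset"))  # AUGMENTATION_PRESETS["none"]
print(parse_augmentations(None))              # AUGMENTATION_PRESETS["none"]

# Dicts keep only float values for the two valid keys, returned in the
# order (mixup_strength, gaussian_strength); anything else becomes None.
assert parse_augmentations(
    {"mixup_strength": 0.5, "gaussian_strength": 1.0}
) == (0.5, 1.0)

# Unknown keys are dropped with a warning; missing keys yield None.
assert parse_augmentations({"mixup_strength": 0.5, "typo": 1.0}) == (0.5, None)

# Non-float values (here an int) are discarded with a warning. Note that a
# value of 0.0 is falsy and is therefore also treated as missing.
assert parse_augmentations({"gaussian_strength": 1}) == (None, None)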
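
And a sketch of the end-to-end effect on add_common_optimization_params(): training parameters are now passed through as a single dict rather than a one-element list, with the "augmentations" entry normalised by parse_augmentations(). The profile below is a made-up placeholder; only the key names used by the code above ("optimization_profile", "training", "augmentations", "common_optimizations", "optimizations", "training_parameters") are taken from the patch.

# Sketch of the post-patch flow; the profile contents are illustrative.
from mlia.target.common.optimization import add_common_optimization_params

advisor_parameters: dict = {}
extra_args = {
    # "optimization_targets" omitted: _DEFAULT_OPTIMIZATION_TARGETS is used.
    "optimization_profile": {
        "training": {
            # Either a preset name or a strengths dict, as parsed above.
            "augmentations": {"mixup_strength": 0.5, "gaussian_strength": 1.0},
        },
    },
}

add_common_optimization_params(advisor_parameters, extra_args)

training = advisor_parameters["common_optimizations"]["training_parameters"]
assert isinstance(training, dict)  # a plain dict, no longer [training_parameters]
assert training["augmentations"] == (0.5, 1.0)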