diff options
Diffstat (limited to 'src/mlia/target/common/optimization.py')
-rw-r--r-- | src/mlia/target/common/optimization.py | 68 |
1 files changed, 60 insertions, 8 deletions
diff --git a/src/mlia/target/common/optimization.py b/src/mlia/target/common/optimization.py index 8c5d184..a139a7d 100644 --- a/src/mlia/target/common/optimization.py +++ b/src/mlia/target/common/optimization.py @@ -17,6 +17,7 @@ from mlia.core.errors import FunctionalityNotSupportedError from mlia.core.performance import estimate_performance from mlia.core.performance import P from mlia.core.performance import PerformanceEstimator +from mlia.nn.rewrite.core.train import AUGMENTATION_PRESETS from mlia.nn.select import get_optimizer from mlia.nn.select import OptimizationSettings from mlia.nn.tensorflow.config import get_keras_model @@ -86,7 +87,7 @@ class OptimizingDataCollector(ContextAwareDataCollector): def optimize_model( self, opt_settings: list[OptimizationSettings], - training_parameters: list[dict | None], + training_parameters: dict | None, model: KerasModel | TFLiteModel, ) -> Any: """Run optimization.""" @@ -123,12 +124,12 @@ class OptimizingDataCollector(ContextAwareDataCollector): context=context, ) - def _get_training_settings(self, context: Context) -> list[dict]: + def _get_training_settings(self, context: Context) -> dict: """Get optimization settings.""" return self.get_parameter( # type: ignore OptimizingDataCollector.name(), "training_parameters", - expected_type=list, + expected_type=dict, expected=False, context=context, ) @@ -218,7 +219,54 @@ _DEFAULT_OPTIMIZATION_TARGETS = [ ] -def add_common_optimization_params(advisor_parameters: dict, extra_args: dict) -> None: +def parse_augmentations( + augmentations: dict | str | None, +) -> tuple[float | None, float | None]: + """Parse augmentations from optimization-profile and return a valid tuple.""" + if isinstance(augmentations, str): + match_augmentation = AUGMENTATION_PRESETS.get(augmentations) + if not match_augmentation: + match_augmentation = AUGMENTATION_PRESETS["none"] + return match_augmentation + if isinstance(augmentations, dict): + augmentation_keys_test_for_valid = list(augmentations.keys()) + augmentation_keys_test_for_float = list(augmentations.keys()) + valid_keys = ["mixup_strength", "gaussian_strength"] + tuple_to_return = [] + for valid_key in valid_keys.copy(): + if augmentations.get(valid_key): + del augmentation_keys_test_for_valid[ + augmentation_keys_test_for_valid.index(valid_key) + ] + if isinstance(augmentations.get(valid_key), float): + tuple_to_return.append(augmentations[valid_key]) + del augmentation_keys_test_for_float[ + augmentation_keys_test_for_float.index(valid_key) + ] + else: + tuple_to_return.append(None) + else: + tuple_to_return.append(None) + + if len(augmentation_keys_test_for_valid) > 0: + logger.warning( + "Warning! Expected augmentation parameters to be 'gaussian_strength' " + "and/or 'mixup_strength' got %s. " + "Removing invalid augmentations", + str(list(augmentations.keys())), + ) + elif len(augmentation_keys_test_for_float) > 0: + logger.warning( + "Warning! Not all augmentation parameters were floats, " + "removing non-float augmentations" + ) + return (tuple_to_return[0], tuple_to_return[1]) + return AUGMENTATION_PRESETS["none"] + + +def add_common_optimization_params( # pylint: disable=too-many-branches + advisor_parameters: dict, extra_args: dict +) -> None: """Add common optimization parameters.""" optimization_targets = extra_args.get("optimization_targets") if not optimization_targets: @@ -228,18 +276,22 @@ def add_common_optimization_params(advisor_parameters: dict, extra_args: dict) - raise TypeError("Optimization targets value has wrong format.") rewrite_parameters = extra_args.get("optimization_profile") - if not rewrite_parameters: - training_parameters = None - else: + training_parameters = None + if rewrite_parameters: if not isinstance(rewrite_parameters, dict): raise TypeError("Training Parameter values has wrong format.") training_parameters = extra_args["optimization_profile"].get("training") + if training_parameters: + training_parameters["augmentations"] = parse_augmentations( + training_parameters.get("augmentations") + ) + advisor_parameters.update( { "common_optimizations": { "optimizations": [optimization_targets], - "training_parameters": [training_parameters], + "training_parameters": training_parameters, }, } ) |