diff options
Diffstat (limited to 'src/mlia/target/common/optimization.py')
-rw-r--r-- | src/mlia/target/common/optimization.py | 36 |
1 files changed, 31 insertions, 5 deletions
diff --git a/src/mlia/target/common/optimization.py b/src/mlia/target/common/optimization.py index 5f359c5..8c5d184 100644 --- a/src/mlia/target/common/optimization.py +++ b/src/mlia/target/common/optimization.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates. +# SPDX-FileCopyrightText: Copyright 2022-2024, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """Data collector support for performance optimizations.""" from __future__ import annotations @@ -50,6 +50,8 @@ class OptimizingDataCollector(ContextAwareDataCollector): optimizations = self._get_optimization_settings(self.context) + training_parameters = self._get_training_settings(self.context) + if not optimizations or optimizations == [[]]: raise FunctionalityNotSupportedError( reason="No optimization targets provided", @@ -75,17 +77,22 @@ class OptimizingDataCollector(ContextAwareDataCollector): model = self.model # type: ignore optimizers: list[Callable] = [ - partial(self.optimize_model, opts) for opts in opt_settings + partial(self.optimize_model, opts, training_parameters) + for opts in opt_settings ] return self.optimize_and_estimate_performance(model, optimizers, opt_settings) def optimize_model( - self, opt_settings: list[OptimizationSettings], model: KerasModel | TFLiteModel + self, + opt_settings: list[OptimizationSettings], + training_parameters: list[dict | None], + model: KerasModel | TFLiteModel, ) -> Any: """Run optimization.""" - optimizer = get_optimizer(model, opt_settings) - + optimizer = get_optimizer( + model, opt_settings, training_parameters=training_parameters + ) opts_as_str = ", ".join(str(opt) for opt in opt_settings) logger.info("Applying model optimizations - [%s]", opts_as_str) optimizer.apply_optimization() @@ -116,6 +123,16 @@ class OptimizingDataCollector(ContextAwareDataCollector): context=context, ) + def _get_training_settings(self, context: Context) -> list[dict]: + """Get optimization settings.""" + return self.get_parameter( # type: ignore + OptimizingDataCollector.name(), + "training_parameters", + expected_type=list, + expected=False, + context=context, + ) + @staticmethod def _parse_optimization_params( optimizations: list[list[dict]], @@ -210,10 +227,19 @@ def add_common_optimization_params(advisor_parameters: dict, extra_args: dict) - if not is_list_of(optimization_targets, dict): raise TypeError("Optimization targets value has wrong format.") + rewrite_parameters = extra_args.get("optimization_profile") + if not rewrite_parameters: + training_parameters = None + else: + if not isinstance(rewrite_parameters, dict): + raise TypeError("Training Parameter values has wrong format.") + training_parameters = extra_args["optimization_profile"].get("training") + advisor_parameters.update( { "common_optimizations": { "optimizations": [optimization_targets], + "training_parameters": [training_parameters], }, } ) |