diff options
Diffstat (limited to 'src/mlia/nn/select.py')
-rw-r--r-- | src/mlia/nn/select.py | 33 |
1 files changed, 26 insertions, 7 deletions
diff --git a/src/mlia/nn/select.py b/src/mlia/nn/select.py index 6947206..20950cc 100644 --- a/src/mlia/nn/select.py +++ b/src/mlia/nn/select.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates. +# SPDX-FileCopyrightText: Copyright 2022-2024, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """Module for optimization selection.""" from __future__ import annotations @@ -117,6 +117,7 @@ class MultiStageOptimizer(Optimizer): def get_optimizer( model: tf.keras.Model | KerasModel | TFLiteModel, config: OptimizerConfiguration | OptimizationSettings | list[OptimizationSettings], + training_parameters: list[dict | None] | None = None, ) -> Optimizer: """Get optimizer for provided configuration.""" if isinstance(model, KerasModel): @@ -135,10 +136,14 @@ def get_optimizer( return RewritingOptimizer(model, config) if isinstance(config, OptimizationSettings): - return _get_optimizer(model, cast(OptimizationSettings, config)) + return _get_optimizer( + model, cast(OptimizationSettings, config), training_parameters + ) if is_list_of(config, OptimizationSettings): - return _get_optimizer(model, cast(List[OptimizationSettings], config)) + return _get_optimizer( + model, cast(List[OptimizationSettings], config), training_parameters + ) raise ConfigurationError(f"Unknown optimization configuration {config}") @@ -146,16 +151,18 @@ def get_optimizer( def _get_optimizer( model: tf.keras.Model | Path, optimization_settings: OptimizationSettings | list[OptimizationSettings], + training_parameters: list[dict | None] | None = None, ) -> Optimizer: if isinstance(optimization_settings, OptimizationSettings): optimization_settings = [optimization_settings] optimizer_configs = [] + for opt_type, opt_target, layers_to_optimize, dataset in optimization_settings: _check_optimizer_params(opt_type, opt_target) opt_config = _get_optimizer_configuration( - opt_type, opt_target, layers_to_optimize, dataset + opt_type, opt_target, layers_to_optimize, dataset, training_parameters ) optimizer_configs.append(opt_config) @@ -165,13 +172,23 @@ def _get_optimizer( return MultiStageOptimizer(model, optimizer_configs) -def _get_rewrite_train_params() -> TrainingParameters: +def _get_rewrite_params( + training_parameters: list[dict | None] | None = None, +) -> list: """Get the rewrite TrainingParameters. Return the default constructed TrainingParameters() per default, but can be overwritten in the unit tests. """ - return TrainingParameters() + if training_parameters is None: + return [TrainingParameters()] + + if training_parameters[0] is None: + train_params = TrainingParameters() + else: + train_params = TrainingParameters(**training_parameters[0]) + + return [train_params] def _get_optimizer_configuration( @@ -179,6 +196,7 @@ def _get_optimizer_configuration( optimization_target: int | float | str, layers_to_optimize: list[str] | None = None, dataset: Path | None = None, + training_parameters: list[dict | None] | None = None, ) -> OptimizerConfiguration: """Get optimizer configuration for provided parameters.""" _check_optimizer_params(optimization_type, optimization_target) @@ -199,11 +217,12 @@ def _get_optimizer_configuration( if opt_type == "rewrite": if isinstance(optimization_target, str): + rewrite_params = _get_rewrite_params(training_parameters) return RewriteConfiguration( optimization_target=str(optimization_target), layers_to_optimize=layers_to_optimize, dataset=dataset, - train_params=_get_rewrite_train_params(), + train_params=rewrite_params[0], ) raise ConfigurationError( |