diff options
author | Nathan Bailey <nathan.bailey@arm.com> | 2024-02-15 14:50:58 +0000 |
---|---|---|
committer | Nathan Bailey <nathan.bailey@arm.com> | 2024-03-14 15:45:40 +0000 |
commit | 0b552d2ae47da4fb9c16d2a59d6ebe12c8307771 (patch) | |
tree | 09b40b939acbe0bcf02dcc77a7ed7ce4aba94322 /src/mlia/nn/select.py | |
parent | 09b272be6e88d84a30cb89fb71f3fc3c64d20d2e (diff) | |
download | mlia-0b552d2ae47da4fb9c16d2a59d6ebe12c8307771.tar.gz |
feat: Enable rewrite parameterisation
Enables the user to provide a TOML file or a default profile to change training settings for rewrite optimization
Resolves: MLIA-1004
Signed-off-by: Nathan Bailey <nathan.bailey@arm.com>
Change-Id: I3bf9f44b9a2062fb71ef36eb32c9a69edcc48061
Diffstat (limited to 'src/mlia/nn/select.py')
-rw-r--r-- | src/mlia/nn/select.py | 33 |
1 file changed, 26 insertions, 7 deletions
diff --git a/src/mlia/nn/select.py b/src/mlia/nn/select.py index 6947206..20950cc 100644 --- a/src/mlia/nn/select.py +++ b/src/mlia/nn/select.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates. +# SPDX-FileCopyrightText: Copyright 2022-2024, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """Module for optimization selection.""" from __future__ import annotations @@ -117,6 +117,7 @@ class MultiStageOptimizer(Optimizer): def get_optimizer( model: tf.keras.Model | KerasModel | TFLiteModel, config: OptimizerConfiguration | OptimizationSettings | list[OptimizationSettings], + training_parameters: list[dict | None] | None = None, ) -> Optimizer: """Get optimizer for provided configuration.""" if isinstance(model, KerasModel): @@ -135,10 +136,14 @@ def get_optimizer( return RewritingOptimizer(model, config) if isinstance(config, OptimizationSettings): - return _get_optimizer(model, cast(OptimizationSettings, config)) + return _get_optimizer( + model, cast(OptimizationSettings, config), training_parameters + ) if is_list_of(config, OptimizationSettings): - return _get_optimizer(model, cast(List[OptimizationSettings], config)) + return _get_optimizer( + model, cast(List[OptimizationSettings], config), training_parameters + ) raise ConfigurationError(f"Unknown optimization configuration {config}") @@ -146,16 +151,18 @@ def get_optimizer( def _get_optimizer( model: tf.keras.Model | Path, optimization_settings: OptimizationSettings | list[OptimizationSettings], + training_parameters: list[dict | None] | None = None, ) -> Optimizer: if isinstance(optimization_settings, OptimizationSettings): optimization_settings = [optimization_settings] optimizer_configs = [] + for opt_type, opt_target, layers_to_optimize, dataset in optimization_settings: _check_optimizer_params(opt_type, opt_target) opt_config = _get_optimizer_configuration( - opt_type, opt_target, layers_to_optimize, dataset + opt_type, opt_target, layers_to_optimize, dataset, training_parameters ) optimizer_configs.append(opt_config) @@ -165,13 +172,23 @@ def _get_optimizer( return MultiStageOptimizer(model, optimizer_configs) -def _get_rewrite_train_params() -> TrainingParameters: +def _get_rewrite_params( + training_parameters: list[dict | None] | None = None, +) -> list: """Get the rewrite TrainingParameters. Return the default constructed TrainingParameters() per default, but can be overwritten in the unit tests. """ - return TrainingParameters() + if training_parameters is None: + return [TrainingParameters()] + + if training_parameters[0] is None: + train_params = TrainingParameters() + else: + train_params = TrainingParameters(**training_parameters[0]) + + return [train_params] def _get_optimizer_configuration( @@ -179,6 +196,7 @@ def _get_optimizer_configuration( optimization_target: int | float | str, layers_to_optimize: list[str] | None = None, dataset: Path | None = None, + training_parameters: list[dict | None] | None = None, ) -> OptimizerConfiguration: """Get optimizer configuration for provided parameters.""" _check_optimizer_params(optimization_type, optimization_target) @@ -199,11 +217,12 @@ def _get_optimizer_configuration( if opt_type == "rewrite": if isinstance(optimization_target, str): + rewrite_params = _get_rewrite_params(training_parameters) return RewriteConfiguration( optimization_target=str(optimization_target), layers_to_optimize=layers_to_optimize, dataset=dataset, - train_params=_get_rewrite_train_params(), + train_params=rewrite_params[0], ) raise ConfigurationError( |