diff options
author | Nathan Bailey <nathan.bailey@arm.com> | 2024-03-25 13:05:32 +0000 |
---|---|---|
committer | Nathan Bailey <nathan.bailey@arm.com> | 2024-04-12 14:08:07 +0000 |
commit | ec59b3c95106daebe2ce0e57592b2bf9e6562f54 (patch) | |
tree | 3a3861c73d5963cb8ef1d21dd6929e24123fc898 /src/mlia/nn/select.py | |
parent | 4de782fde8e38ec92bb5bc60e156de027f13bfba (diff) | |
download | mlia-ec59b3c95106daebe2ce0e57592b2bf9e6562f54.tar.gz |
fix: Change training_parameters to return empty list instead of list of None if needed.
Extension to MLIA-1004
Signed-off-by: Nathan Bailey <nathan.bailey@arm.com>
Change-Id: Ib40c2e5932c1210a1d141200815a76e33f5ab078
Diffstat (limited to 'src/mlia/nn/select.py')
-rw-r--r-- | src/mlia/nn/select.py | 23 |
1 files changed, 9 insertions, 14 deletions
diff --git a/src/mlia/nn/select.py b/src/mlia/nn/select.py index 81a614f..b61e713 100644 --- a/src/mlia/nn/select.py +++ b/src/mlia/nn/select.py @@ -117,7 +117,7 @@ class MultiStageOptimizer(Optimizer): def get_optimizer( model: keras.Model | KerasModel | TFLiteModel, config: OptimizerConfiguration | OptimizationSettings | list[OptimizationSettings], - training_parameters: list[dict | None] | None = None, + training_parameters: dict | None = None, ) -> Optimizer: """Get optimizer for provided configuration.""" if isinstance(model, KerasModel): @@ -151,7 +151,7 @@ def get_optimizer( def _get_optimizer( model: keras.Model | Path, optimization_settings: OptimizationSettings | list[OptimizationSettings], - training_parameters: list[dict | None] | None = None, + training_parameters: dict | None = None, ) -> Optimizer: if isinstance(optimization_settings, OptimizationSettings): optimization_settings = [optimization_settings] @@ -173,22 +173,17 @@ def _get_optimizer( def _get_rewrite_params( - training_parameters: list[dict | None] | None = None, -) -> list: + training_parameters: dict | None = None, +) -> TrainingParameters: """Get the rewrite TrainingParameters. Return the default constructed TrainingParameters() per default, but can be overwritten in the unit tests. """ - if training_parameters is None: - return [TrainingParameters()] + if not training_parameters: + return TrainingParameters() - if training_parameters[0] is None: - train_params = TrainingParameters() - else: - train_params = TrainingParameters(**training_parameters[0]) - - return [train_params] + return TrainingParameters(**training_parameters) def _get_optimizer_configuration( @@ -196,7 +191,7 @@ def _get_optimizer_configuration( optimization_target: int | float | str, layers_to_optimize: list[str] | None = None, dataset: Path | None = None, - training_parameters: list[dict | None] | None = None, + training_parameters: dict | None = None, ) -> OptimizerConfiguration: """Get optimizer configuration for provided parameters.""" _check_optimizer_params(optimization_type, optimization_target) @@ -222,7 +217,7 @@ def _get_optimizer_configuration( optimization_target=str(optimization_target), layers_to_optimize=layers_to_optimize, dataset=dataset, - train_params=rewrite_params[0], + train_params=rewrite_params, ) raise ConfigurationError( |