diff options
Diffstat (limited to 'src/mlia/nn/select.py')
-rw-r--r-- | src/mlia/nn/select.py | 23 |
1 files changed, 9 insertions, 14 deletions
diff --git a/src/mlia/nn/select.py b/src/mlia/nn/select.py index 81a614f..b61e713 100644 --- a/src/mlia/nn/select.py +++ b/src/mlia/nn/select.py @@ -117,7 +117,7 @@ class MultiStageOptimizer(Optimizer): def get_optimizer( model: keras.Model | KerasModel | TFLiteModel, config: OptimizerConfiguration | OptimizationSettings | list[OptimizationSettings], - training_parameters: list[dict | None] | None = None, + training_parameters: dict | None = None, ) -> Optimizer: """Get optimizer for provided configuration.""" if isinstance(model, KerasModel): @@ -151,7 +151,7 @@ def get_optimizer( def _get_optimizer( model: keras.Model | Path, optimization_settings: OptimizationSettings | list[OptimizationSettings], - training_parameters: list[dict | None] | None = None, + training_parameters: dict | None = None, ) -> Optimizer: if isinstance(optimization_settings, OptimizationSettings): optimization_settings = [optimization_settings] @@ -173,22 +173,17 @@ def _get_optimizer( def _get_rewrite_params( - training_parameters: list[dict | None] | None = None, -) -> list: + training_parameters: dict | None = None, +) -> TrainingParameters: """Get the rewrite TrainingParameters. Return the default constructed TrainingParameters() per default, but can be overwritten in the unit tests. """ - if training_parameters is None: - return [TrainingParameters()] + if not training_parameters: + return TrainingParameters() - if training_parameters[0] is None: - train_params = TrainingParameters() - else: - train_params = TrainingParameters(**training_parameters[0]) - - return [train_params] + return TrainingParameters(**training_parameters) def _get_optimizer_configuration( @@ -196,7 +191,7 @@ def _get_optimizer_configuration( optimization_target: int | float | str, layers_to_optimize: list[str] | None = None, dataset: Path | None = None, - training_parameters: list[dict | None] | None = None, + training_parameters: dict | None = None, ) -> OptimizerConfiguration: """Get optimizer configuration for provided parameters.""" _check_optimizer_params(optimization_type, optimization_target) @@ -222,7 +217,7 @@ def _get_optimizer_configuration( optimization_target=str(optimization_target), layers_to_optimize=layers_to_optimize, dataset=dataset, - train_params=rewrite_params[0], + train_params=rewrite_params, ) raise ConfigurationError( |