aboutsummaryrefslogtreecommitdiff
path: root/src/mlia/nn/select.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/mlia/nn/select.py')
-rw-r--r--src/mlia/nn/select.py23
1 file changed, 9 insertions, 14 deletions
diff --git a/src/mlia/nn/select.py b/src/mlia/nn/select.py
index 81a614f..b61e713 100644
--- a/src/mlia/nn/select.py
+++ b/src/mlia/nn/select.py
@@ -117,7 +117,7 @@ class MultiStageOptimizer(Optimizer):
def get_optimizer(
model: keras.Model | KerasModel | TFLiteModel,
config: OptimizerConfiguration | OptimizationSettings | list[OptimizationSettings],
- training_parameters: list[dict | None] | None = None,
+ training_parameters: dict | None = None,
) -> Optimizer:
"""Get optimizer for provided configuration."""
if isinstance(model, KerasModel):
@@ -151,7 +151,7 @@ def get_optimizer(
def _get_optimizer(
model: keras.Model | Path,
optimization_settings: OptimizationSettings | list[OptimizationSettings],
- training_parameters: list[dict | None] | None = None,
+ training_parameters: dict | None = None,
) -> Optimizer:
if isinstance(optimization_settings, OptimizationSettings):
optimization_settings = [optimization_settings]
@@ -173,22 +173,17 @@ def _get_optimizer(
def _get_rewrite_params(
- training_parameters: list[dict | None] | None = None,
-) -> list:
+ training_parameters: dict | None = None,
+) -> TrainingParameters:
"""Get the rewrite TrainingParameters.
Return the default constructed TrainingParameters() per default, but can be
overwritten in the unit tests.
"""
- if training_parameters is None:
- return [TrainingParameters()]
+ if not training_parameters:
+ return TrainingParameters()
- if training_parameters[0] is None:
- train_params = TrainingParameters()
- else:
- train_params = TrainingParameters(**training_parameters[0])
-
- return [train_params]
+ return TrainingParameters(**training_parameters)
def _get_optimizer_configuration(
@@ -196,7 +191,7 @@ def _get_optimizer_configuration(
optimization_target: int | float | str,
layers_to_optimize: list[str] | None = None,
dataset: Path | None = None,
- training_parameters: list[dict | None] | None = None,
+ training_parameters: dict | None = None,
) -> OptimizerConfiguration:
"""Get optimizer configuration for provided parameters."""
_check_optimizer_params(optimization_type, optimization_target)
@@ -222,7 +217,7 @@ def _get_optimizer_configuration(
optimization_target=str(optimization_target),
layers_to_optimize=layers_to_optimize,
dataset=dataset,
- train_params=rewrite_params[0],
+ train_params=rewrite_params,
)
raise ConfigurationError(