Diffstat (limited to 'src/mlia/nn/select.py')
 src/mlia/nn/select.py | 33 ++++++++++++++++++++++++++++-------
 1 file changed, 26 insertions(+), 7 deletions(-)
diff --git a/src/mlia/nn/select.py b/src/mlia/nn/select.py
index 6947206..20950cc 100644
--- a/src/mlia/nn/select.py
+++ b/src/mlia/nn/select.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates.
+# SPDX-FileCopyrightText: Copyright 2022-2024, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Module for optimization selection."""
from __future__ import annotations
@@ -117,6 +117,7 @@ class MultiStageOptimizer(Optimizer):
def get_optimizer(
model: tf.keras.Model | KerasModel | TFLiteModel,
config: OptimizerConfiguration | OptimizationSettings | list[OptimizationSettings],
+ training_parameters: list[dict | None] | None = None,
) -> Optimizer:
"""Get optimizer for provided configuration."""
if isinstance(model, KerasModel):
@@ -135,10 +136,14 @@ def get_optimizer(
return RewritingOptimizer(model, config)
if isinstance(config, OptimizationSettings):
- return _get_optimizer(model, cast(OptimizationSettings, config))
+ return _get_optimizer(
+ model, cast(OptimizationSettings, config), training_parameters
+ )
if is_list_of(config, OptimizationSettings):
- return _get_optimizer(model, cast(List[OptimizationSettings], config))
+ return _get_optimizer(
+ model, cast(List[OptimizationSettings], config), training_parameters
+ )
raise ConfigurationError(f"Unknown optimization configuration {config}")
@@ -146,16 +151,18 @@ def get_optimizer(
def _get_optimizer(
model: tf.keras.Model | Path,
optimization_settings: OptimizationSettings | list[OptimizationSettings],
+ training_parameters: list[dict | None] | None = None,
) -> Optimizer:
if isinstance(optimization_settings, OptimizationSettings):
optimization_settings = [optimization_settings]
optimizer_configs = []
+
for opt_type, opt_target, layers_to_optimize, dataset in optimization_settings:
_check_optimizer_params(opt_type, opt_target)
opt_config = _get_optimizer_configuration(
- opt_type, opt_target, layers_to_optimize, dataset
+ opt_type, opt_target, layers_to_optimize, dataset, training_parameters
)
optimizer_configs.append(opt_config)
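Note that the whole training_parameters list is forwarded unchanged on every iteration of the loop, so all settings share the same overrides. A hedged two-stage sketch (the target values are illustrative):

    # Hypothetical multi-stage request: pruning then clustering.
    # Both _get_optimizer_configuration() calls receive the same
    # training_parameters argument.
    settings = [
        OptimizationSettings("pruning", 0.5, None, None),
        OptimizationSettings("clustering", 32, None, None),
    ]
    optimizer = _get_optimizer(keras_model, settings)
    # -> MultiStageOptimizer wrapping one configuration per setting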
@@ -165,13 +172,23 @@ def _get_optimizer(
return MultiStageOptimizer(model, optimizer_configs)
-def _get_rewrite_train_params() -> TrainingParameters:
+def _get_rewrite_params(
+ training_parameters: list[dict | None] | None = None,
+) -> list:
"""Get the rewrite TrainingParameters.
Return the default constructed TrainingParameters() per default, but can be
overwritten in the unit tests.
"""
- return TrainingParameters()
+ if training_parameters is None:
+ return [TrainingParameters()]
+
+ if training_parameters[0] is None:
+ train_params = TrainingParameters()
+ else:
+ train_params = TrainingParameters(**training_parameters[0])
+
+ return [train_params]
def _get_optimizer_configuration(
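The new helper only ever consumes the first entry of the list it is given. A minimal behavioural sketch (the steps field name is an assumption about TrainingParameters):

    _get_rewrite_params(None)              # -> [TrainingParameters()]
    _get_rewrite_params([None])            # -> [TrainingParameters()]
    _get_rewrite_params([{"steps": 100}])  # -> [TrainingParameters(steps=100)]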
@@ -179,6 +196,7 @@ def _get_optimizer_configuration(
optimization_target: int | float | str,
layers_to_optimize: list[str] | None = None,
dataset: Path | None = None,
+ training_parameters: list[dict | None] | None = None,
) -> OptimizerConfiguration:
"""Get optimizer configuration for provided parameters."""
_check_optimizer_params(optimization_type, optimization_target)
@@ -199,11 +217,12 @@ def _get_optimizer_configuration(
if opt_type == "rewrite":
if isinstance(optimization_target, str):
+ rewrite_params = _get_rewrite_params(training_parameters)
return RewriteConfiguration(
optimization_target=str(optimization_target),
layers_to_optimize=layers_to_optimize,
dataset=dataset,
- train_params=_get_rewrite_train_params(),
+ train_params=rewrite_params[0],
)
raise ConfigurationError(
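Putting the last hunk together, a hedged sketch of what the rewrite branch now constructs; the target name and the override values are illustrative assumptions:

    # Hypothetical: rewrite configuration with user-supplied training
    # overrides; only the first entry of the returned list is used.
    rewrite_params = _get_rewrite_params([{"learning_rate": 1e-3}])
    config = RewriteConfiguration(
        optimization_target="fully-connected",
        layers_to_optimize=None,
        dataset=None,
        train_params=rewrite_params[0],
    )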