about summary refs log tree commit diff
path: root/src/mlia/nn/select.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/mlia/nn/select.py')
-rw-r--r--src/mlia/nn/select.py24
1 file changed, 13 insertions, 11 deletions
diff --git a/src/mlia/nn/select.py b/src/mlia/nn/select.py
index b61e713..d5470d1 100644
--- a/src/mlia/nn/select.py
+++ b/src/mlia/nn/select.py
@@ -17,7 +17,7 @@ from mlia.nn.common import Optimizer
from mlia.nn.common import OptimizerConfiguration
from mlia.nn.rewrite.core.rewrite import RewriteConfiguration
from mlia.nn.rewrite.core.rewrite import RewritingOptimizer
-from mlia.nn.rewrite.core.rewrite import TrainingParameters
+from mlia.nn.rewrite.core.train import TrainingParameters
from mlia.nn.tensorflow.config import KerasModel
from mlia.nn.tensorflow.config import TFLiteModel
from mlia.nn.tensorflow.optimizations.clustering import Clusterer
@@ -109,7 +109,7 @@ class MultiStageOptimizer(Optimizer):
def apply_optimization(self) -> None:
"""Apply optimization to the model."""
for config in self.optimizations:
- optimizer = get_optimizer(self.model, config)
+ optimizer = get_optimizer(self.model, config, {})
optimizer.apply_optimization()
self.model = optimizer.get_model()
@@ -117,7 +117,7 @@ class MultiStageOptimizer(Optimizer):
def get_optimizer(
model: keras.Model | KerasModel | TFLiteModel,
config: OptimizerConfiguration | OptimizationSettings | list[OptimizationSettings],
- training_parameters: dict | None = None,
+ rewrite_parameters: dict,
) -> Optimizer:
"""Get optimizer for provided configuration."""
if isinstance(model, KerasModel):
@@ -137,12 +137,12 @@ def get_optimizer(
if isinstance(config, OptimizationSettings):
return _get_optimizer(
- model, cast(OptimizationSettings, config), training_parameters
+ model, cast(OptimizationSettings, config), rewrite_parameters
)
if is_list_of(config, OptimizationSettings):
return _get_optimizer(
- model, cast(List[OptimizationSettings], config), training_parameters
+ model, cast(List[OptimizationSettings], config), rewrite_parameters
)
raise ConfigurationError(f"Unknown optimization configuration {config}")
@@ -151,7 +151,7 @@ def get_optimizer(
def _get_optimizer(
model: keras.Model | Path,
optimization_settings: OptimizationSettings | list[OptimizationSettings],
- training_parameters: dict | None = None,
+ rewrite_parameters: dict,
) -> Optimizer:
if isinstance(optimization_settings, OptimizationSettings):
optimization_settings = [optimization_settings]
@@ -162,12 +162,12 @@ def _get_optimizer(
_check_optimizer_params(opt_type, opt_target)
opt_config = _get_optimizer_configuration(
- opt_type, opt_target, layers_to_optimize, dataset, training_parameters
+ opt_type, opt_target, rewrite_parameters, layers_to_optimize, dataset
)
optimizer_configs.append(opt_config)
if len(optimizer_configs) == 1:
- return get_optimizer(model, optimizer_configs[0])
+ return get_optimizer(model, optimizer_configs[0], {})
return MultiStageOptimizer(model, optimizer_configs)
@@ -189,9 +189,9 @@ def _get_rewrite_params(
def _get_optimizer_configuration(
optimization_type: str,
optimization_target: int | float | str,
+ rewrite_parameters: dict,
layers_to_optimize: list[str] | None = None,
dataset: Path | None = None,
- training_parameters: dict | None = None,
) -> OptimizerConfiguration:
"""Get optimizer configuration for provided parameters."""
_check_optimizer_params(optimization_type, optimization_target)
@@ -212,12 +212,14 @@ def _get_optimizer_configuration(
if opt_type == "rewrite":
if isinstance(optimization_target, str):
- rewrite_params = _get_rewrite_params(training_parameters)
return RewriteConfiguration(
optimization_target=str(optimization_target),
layers_to_optimize=layers_to_optimize,
dataset=dataset,
- train_params=rewrite_params,
+ train_params=_get_rewrite_params(rewrite_parameters["train_params"]),
+ rewrite_specific_params=rewrite_parameters.get(
+ "rewrite_specific_params"
+ ),
)
raise ConfigurationError(