diff options
Diffstat (limited to 'src/mlia/nn/select.py')
-rw-r--r-- | src/mlia/nn/select.py | 24 |
1 files changed, 13 insertions, 11 deletions
diff --git a/src/mlia/nn/select.py b/src/mlia/nn/select.py index b61e713..d5470d1 100644 --- a/src/mlia/nn/select.py +++ b/src/mlia/nn/select.py @@ -17,7 +17,7 @@ from mlia.nn.common import Optimizer from mlia.nn.common import OptimizerConfiguration from mlia.nn.rewrite.core.rewrite import RewriteConfiguration from mlia.nn.rewrite.core.rewrite import RewritingOptimizer -from mlia.nn.rewrite.core.rewrite import TrainingParameters +from mlia.nn.rewrite.core.train import TrainingParameters from mlia.nn.tensorflow.config import KerasModel from mlia.nn.tensorflow.config import TFLiteModel from mlia.nn.tensorflow.optimizations.clustering import Clusterer @@ -109,7 +109,7 @@ class MultiStageOptimizer(Optimizer): def apply_optimization(self) -> None: """Apply optimization to the model.""" for config in self.optimizations: - optimizer = get_optimizer(self.model, config) + optimizer = get_optimizer(self.model, config, {}) optimizer.apply_optimization() self.model = optimizer.get_model() @@ -117,7 +117,7 @@ class MultiStageOptimizer(Optimizer): def get_optimizer( model: keras.Model | KerasModel | TFLiteModel, config: OptimizerConfiguration | OptimizationSettings | list[OptimizationSettings], - training_parameters: dict | None = None, + rewrite_parameters: dict, ) -> Optimizer: """Get optimizer for provided configuration.""" if isinstance(model, KerasModel): @@ -137,12 +137,12 @@ def get_optimizer( if isinstance(config, OptimizationSettings): return _get_optimizer( - model, cast(OptimizationSettings, config), training_parameters + model, cast(OptimizationSettings, config), rewrite_parameters ) if is_list_of(config, OptimizationSettings): return _get_optimizer( - model, cast(List[OptimizationSettings], config), training_parameters + model, cast(List[OptimizationSettings], config), rewrite_parameters ) raise ConfigurationError(f"Unknown optimization configuration {config}") @@ -151,7 +151,7 @@ def get_optimizer( def _get_optimizer( model: keras.Model | Path, optimization_settings: OptimizationSettings | list[OptimizationSettings], - training_parameters: dict | None = None, + rewrite_parameters: dict, ) -> Optimizer: if isinstance(optimization_settings, OptimizationSettings): optimization_settings = [optimization_settings] @@ -162,12 +162,12 @@ def _get_optimizer( _check_optimizer_params(opt_type, opt_target) opt_config = _get_optimizer_configuration( - opt_type, opt_target, layers_to_optimize, dataset, training_parameters + opt_type, opt_target, rewrite_parameters, layers_to_optimize, dataset ) optimizer_configs.append(opt_config) if len(optimizer_configs) == 1: - return get_optimizer(model, optimizer_configs[0]) + return get_optimizer(model, optimizer_configs[0], {}) return MultiStageOptimizer(model, optimizer_configs) @@ -189,9 +189,9 @@ def _get_rewrite_params( def _get_optimizer_configuration( optimization_type: str, optimization_target: int | float | str, + rewrite_parameters: dict, layers_to_optimize: list[str] | None = None, dataset: Path | None = None, - training_parameters: dict | None = None, ) -> OptimizerConfiguration: """Get optimizer configuration for provided parameters.""" _check_optimizer_params(optimization_type, optimization_target) @@ -212,12 +212,14 @@ def _get_optimizer_configuration( if opt_type == "rewrite": if isinstance(optimization_target, str): - rewrite_params = _get_rewrite_params(training_parameters) return RewriteConfiguration( optimization_target=str(optimization_target), layers_to_optimize=layers_to_optimize, dataset=dataset, - train_params=rewrite_params, + train_params=_get_rewrite_params(rewrite_parameters["train_params"]), + rewrite_specific_params=rewrite_parameters.get( + "rewrite_specific_params" + ), ) raise ConfigurationError( |