diff options
Diffstat (limited to 'src/mlia/nn/select.py')
-rw-r--r-- | src/mlia/nn/select.py | 15 |
1 files changed, 14 insertions, 1 deletions
diff --git a/src/mlia/nn/select.py b/src/mlia/nn/select.py index 5a7f289..983426b 100644 --- a/src/mlia/nn/select.py +++ b/src/mlia/nn/select.py @@ -17,6 +17,7 @@ from mlia.nn.common import Optimizer from mlia.nn.common import OptimizerConfiguration from mlia.nn.rewrite.core.rewrite import RewriteConfiguration from mlia.nn.rewrite.core.rewrite import Rewriter +from mlia.nn.rewrite.core.rewrite import TrainingParameters from mlia.nn.tensorflow.config import KerasModel from mlia.nn.tensorflow.config import TFLiteModel from mlia.nn.tensorflow.optimizations.clustering import Clusterer @@ -164,6 +165,15 @@ def _get_optimizer( return MultiStageOptimizer(model, optimizer_configs) +def _get_rewrite_train_params() -> TrainingParameters: + """Get the rewrite TrainingParameters. + + Return the default constructed TrainingParameters() per default, but can be + overwritten in the unit tests. + """ + return TrainingParameters() + + def _get_optimizer_configuration( optimization_type: str, optimization_target: int | float | str, @@ -190,7 +200,10 @@ def _get_optimizer_configuration( if opt_type == "rewrite": if isinstance(optimization_target, str): return RewriteConfiguration( - str(optimization_target), layers_to_optimize, dataset + optimization_target=str(optimization_target), + layers_to_optimize=layers_to_optimize, + dataset=dataset, + train_params=_get_rewrite_train_params(), ) raise ConfigurationError( |