aboutsummaryrefslogtreecommitdiff
path: root/src/mlia/nn/select.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/mlia/nn/select.py')
-rw-r--r--src/mlia/nn/select.py15
1 files changed, 14 insertions, 1 deletions
diff --git a/src/mlia/nn/select.py b/src/mlia/nn/select.py
index 5a7f289..983426b 100644
--- a/src/mlia/nn/select.py
+++ b/src/mlia/nn/select.py
@@ -17,6 +17,7 @@ from mlia.nn.common import Optimizer
from mlia.nn.common import OptimizerConfiguration
from mlia.nn.rewrite.core.rewrite import RewriteConfiguration
from mlia.nn.rewrite.core.rewrite import Rewriter
+from mlia.nn.rewrite.core.rewrite import TrainingParameters
from mlia.nn.tensorflow.config import KerasModel
from mlia.nn.tensorflow.config import TFLiteModel
from mlia.nn.tensorflow.optimizations.clustering import Clusterer
@@ -164,6 +165,15 @@ def _get_optimizer(
return MultiStageOptimizer(model, optimizer_configs)
+def _get_rewrite_train_params() -> TrainingParameters:
+ """Get the rewrite TrainingParameters.
+
+ Return the default constructed TrainingParameters() per default, but can be
+ overwritten in the unit tests.
+ """
+ return TrainingParameters()
+
+
def _get_optimizer_configuration(
optimization_type: str,
optimization_target: int | float | str,
@@ -190,7 +200,10 @@ def _get_optimizer_configuration(
if opt_type == "rewrite":
if isinstance(optimization_target, str):
return RewriteConfiguration(
- str(optimization_target), layers_to_optimize, dataset
+ optimization_target=str(optimization_target),
+ layers_to_optimize=layers_to_optimize,
+ dataset=dataset,
+ train_params=_get_rewrite_train_params(),
)
raise ConfigurationError(