aboutsummaryrefslogtreecommitdiff
path: root/src/mlia/target/common/optimization.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/mlia/target/common/optimization.py')
-rw-r--r--src/mlia/target/common/optimization.py13
1 files changed, 6 insertions, 7 deletions
diff --git a/src/mlia/target/common/optimization.py b/src/mlia/target/common/optimization.py
index 8c5d184..1423189 100644
--- a/src/mlia/target/common/optimization.py
+++ b/src/mlia/target/common/optimization.py
@@ -86,7 +86,7 @@ class OptimizingDataCollector(ContextAwareDataCollector):
def optimize_model(
self,
opt_settings: list[OptimizationSettings],
- training_parameters: list[dict | None],
+ training_parameters: dict | None,
model: KerasModel | TFLiteModel,
) -> Any:
"""Run optimization."""
@@ -123,12 +123,12 @@ class OptimizingDataCollector(ContextAwareDataCollector):
context=context,
)
- def _get_training_settings(self, context: Context) -> list[dict]:
+ def _get_training_settings(self, context: Context) -> dict:
"""Get optimization settings."""
return self.get_parameter( # type: ignore
OptimizingDataCollector.name(),
"training_parameters",
- expected_type=list,
+ expected_type=dict,
expected=False,
context=context,
)
@@ -228,9 +228,8 @@ def add_common_optimization_params(advisor_parameters: dict, extra_args: dict) -
raise TypeError("Optimization targets value has wrong format.")
rewrite_parameters = extra_args.get("optimization_profile")
- if not rewrite_parameters:
- training_parameters = None
- else:
+ training_parameters = None
+ if rewrite_parameters:
if not isinstance(rewrite_parameters, dict):
raise TypeError("Training Parameter values has wrong format.")
training_parameters = extra_args["optimization_profile"].get("training")
@@ -239,7 +238,7 @@ def add_common_optimization_params(advisor_parameters: dict, extra_args: dict) -
{
"common_optimizations": {
"optimizations": [optimization_targets],
- "training_parameters": [training_parameters],
+ "training_parameters": training_parameters,
},
}
)