Diffstat (limited to 'tests/test_common_optimization.py')
-rw-r--r--  tests/test_common_optimization.py  18
1 file changed, 10 insertions(+), 8 deletions(-)
diff --git a/tests/test_common_optimization.py b/tests/test_common_optimization.py
index 05a5b55..341e0d2 100644
--- a/tests/test_common_optimization.py
+++ b/tests/test_common_optimization.py
@@ -57,7 +57,7 @@ def test_optimizing_data_collector(
         config_parameters={
             "common_optimizations": {
                 "optimizations": optimizations,
-                "training_parameters": [training_parameters],
+                "training_parameters": training_parameters,
             }
         }
     )
@@ -94,7 +94,7 @@ def test_optimizing_data_collector(
     collector.set_context(context)
     collector.collect_data()
     assert optimize_model_mock.call_args.args[0] == opt_settings[0]
-    assert optimize_model_mock.call_args.args[1] == [training_parameters]
+    assert optimize_model_mock.call_args.args[1] == training_parameters
     assert fake_optimizer.invocation_count == 1
@@ -158,10 +158,12 @@ def test_add_common_optimization_params(extra_args: dict, error_to_raise: Any) -
     ]
     if not extra_args.get("optimization_profile"):
-        assert advisor_parameters["common_optimizations"][
-            "training_parameters"
-        ] == [None]
+        assert (
+            advisor_parameters["common_optimizations"]["training_parameters"]
+            is None
+        )
     else:
-        assert advisor_parameters["common_optimizations"][
-            "training_parameters"
-        ] == list(extra_args["optimization_profile"].values())
+        assert (
+            advisor_parameters["common_optimizations"]["training_parameters"]
+            == extra_args["optimization_profile"]["training"]
+        )
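
Below is a minimal sketch, not part of the diff itself, of the data shapes the updated assertions imply: "training_parameters" now flows through as a single mapping (or None) instead of a one-element list. The keys inside the "training" profile are illustrative assumptions only; the test does not pin them down.

# Hypothetical optimization profile passed via extra_args; only the "training"
# key is read by the updated assertion, and its contents are invented here.
extra_args = {
    "optimization_profile": {
        "training": {"batch_size": 32, "steps": 48000},
    }
}

# Shape the updated test expects in the advisor configuration: the "training"
# mapping itself (unwrapped), or None when no optimization profile is given.
advisor_parameters = {
    "common_optimizations": {
        "optimizations": [],  # populated elsewhere in the test
        "training_parameters": extra_args["optimization_profile"]["training"],
    }
}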