diff options
Diffstat (limited to 'tests/test_common_optimization.py')
-rw-r--r-- | tests/test_common_optimization.py | 18 |
1 files changed, 10 insertions, 8 deletions
diff --git a/tests/test_common_optimization.py b/tests/test_common_optimization.py index 05a5b55..341e0d2 100644 --- a/tests/test_common_optimization.py +++ b/tests/test_common_optimization.py @@ -57,7 +57,7 @@ def test_optimizing_data_collector( config_parameters={ "common_optimizations": { "optimizations": optimizations, - "training_parameters": [training_parameters], + "training_parameters": training_parameters, } } ) @@ -94,7 +94,7 @@ def test_optimizing_data_collector( collector.set_context(context) collector.collect_data() assert optimize_model_mock.call_args.args[0] == opt_settings[0] - assert optimize_model_mock.call_args.args[1] == [training_parameters] + assert optimize_model_mock.call_args.args[1] == training_parameters assert fake_optimizer.invocation_count == 1 @@ -158,10 +158,12 @@ def test_add_common_optimization_params(extra_args: dict, error_to_raise: Any) - ] if not extra_args.get("optimization_profile"): - assert advisor_parameters["common_optimizations"][ - "training_parameters" - ] == [None] + assert ( + advisor_parameters["common_optimizations"]["training_parameters"] + is None + ) else: - assert advisor_parameters["common_optimizations"][ - "training_parameters" - ] == list(extra_args["optimization_profile"].values()) + assert ( + advisor_parameters["common_optimizations"]["training_parameters"] + == extra_args["optimization_profile"]["training"] + ) |