Diffstat (limited to 'tests')
-rw-r--r-- | tests/test_common_optimization.py     | 18
-rw-r--r-- | tests/test_nn_select.py               | 12
-rw-r--r-- | tests/test_target_cortex_a_advisor.py |  2
-rw-r--r-- | tests/test_target_tosa_advisor.py     |  2
4 files changed, 17 insertions, 17 deletions
diff --git a/tests/test_common_optimization.py b/tests/test_common_optimization.py
index 05a5b55..341e0d2 100644
--- a/tests/test_common_optimization.py
+++ b/tests/test_common_optimization.py
@@ -57,7 +57,7 @@ def test_optimizing_data_collector(
         config_parameters={
             "common_optimizations": {
                 "optimizations": optimizations,
-                "training_parameters": [training_parameters],
+                "training_parameters": training_parameters,
             }
         }
     )
@@ -94,7 +94,7 @@ def test_optimizing_data_collector(
     collector.set_context(context)
     collector.collect_data()
     assert optimize_model_mock.call_args.args[0] == opt_settings[0]
-    assert optimize_model_mock.call_args.args[1] == [training_parameters]
+    assert optimize_model_mock.call_args.args[1] == training_parameters
     assert fake_optimizer.invocation_count == 1


@@ -158,10 +158,12 @@ def test_add_common_optimization_params(extra_args: dict, error_to_raise: Any) -

     ]
     if not extra_args.get("optimization_profile"):
-        assert advisor_parameters["common_optimizations"][
-            "training_parameters"
-        ] == [None]
+        assert (
+            advisor_parameters["common_optimizations"]["training_parameters"]
+            is None
+        )
     else:
-        assert advisor_parameters["common_optimizations"][
-            "training_parameters"
-        ] == list(extra_args["optimization_profile"].values())
+        assert (
+            advisor_parameters["common_optimizations"]["training_parameters"]
+            == extra_args["optimization_profile"]["training"]
+        )
diff --git a/tests/test_nn_select.py b/tests/test_nn_select.py
index aac07b4..4095076 100644
--- a/tests/test_nn_select.py
+++ b/tests/test_nn_select.py
@@ -183,11 +183,11 @@ def test_get_optimizer(

 @pytest.mark.parametrize(
     "rewrite_parameters",
-    [[None], [{"batch_size": 64, "learning_rate": 0.003}]],
+    [None, {"batch_size": 64, "learning_rate": 0.003}],
 )
 @pytest.mark.skip_set_training_steps
 def test_get_optimizer_training_parameters(
-    rewrite_parameters: list[dict], test_tflite_model: Path
+    rewrite_parameters: dict | None, test_tflite_model: Path
 ) -> None:
     """Test function get_optimzer with various combinations of parameters."""
     config = OptimizationSettings(
@@ -198,20 +198,18 @@ def test_get_optimizer_training_parameters(
     )
     optimizer = cast(
         RewritingOptimizer,
-        get_optimizer(test_tflite_model, config, list(rewrite_parameters)),
+        get_optimizer(test_tflite_model, config, rewrite_parameters),
     )

-    assert len(rewrite_parameters) == 1
-
     assert isinstance(
         optimizer.optimizer_configuration.train_params, TrainingParameters
     )
-    if not rewrite_parameters[0]:
+    if not rewrite_parameters:
         assert asdict(TrainingParameters()) == asdict(
             optimizer.optimizer_configuration.train_params
         )
     else:
-        assert asdict(TrainingParameters()) | rewrite_parameters[0] == asdict(
+        assert asdict(TrainingParameters()) | rewrite_parameters == asdict(
             optimizer.optimizer_configuration.train_params
         )
diff --git a/tests/test_target_cortex_a_advisor.py b/tests/test_target_cortex_a_advisor.py
index 59d54b5..7bb57c3 100644
--- a/tests/test_target_cortex_a_advisor.py
+++ b/tests/test_target_cortex_a_advisor.py
@@ -47,7 +47,7 @@ def test_configure_and_get_cortex_a_advisor(test_tflite_model: Path) -> None:
                 },
             ]
         ],
-        "training_parameters": [None],
+        "training_parameters": None,
     },
 }

diff --git a/tests/test_target_tosa_advisor.py b/tests/test_target_tosa_advisor.py
index cc47321..020acc5 100644
--- a/tests/test_target_tosa_advisor.py
+++ b/tests/test_target_tosa_advisor.py
@@ -47,7 +47,7 @@ def test_configure_and_get_tosa_advisor(
                 },
             ]
         ],
-        "training_parameters": [None],
+        "training_parameters": None,
     },
     "tosa_inference_advisor": {
         "model": str(test_tflite_model),
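The net effect of the patch is that "training_parameters" now travels through the advisor configuration as a single mapping (or None) rather than a one-element list: [None] becomes None, and [training_parameters] becomes training_parameters. Below is a minimal sketch of the resulting merge semantics, using a hypothetical stand-in for the real TrainingParameters dataclass; the field names and defaults here are assumptions for illustration, not taken from the codebase under test.

    from __future__ import annotations

    from dataclasses import asdict, dataclass


    @dataclass
    class TrainingParameters:
        """Hypothetical stand-in; the real dataclass lives in the codebase under test."""

        batch_size: int = 32         # assumed default
        learning_rate: float = 1e-3  # assumed default


    def resolve_train_params(rewrite_parameters: dict | None) -> dict:
        """Mirror the updated assertions: defaults for None, else merge overrides."""
        if not rewrite_parameters:
            # "training_parameters": None -> plain defaults.
            return asdict(TrainingParameters())
        # A single dict of overrides wins over the defaults, matching the
        # test expression: asdict(TrainingParameters()) | rewrite_parameters
        return asdict(TrainingParameters()) | rewrite_parameters


    print(resolve_train_params(None))
    print(resolve_train_params({"batch_size": 64, "learning_rate": 0.003}))
    # -> {'batch_size': 32, 'learning_rate': 0.001}
    # -> {'batch_size': 64, 'learning_rate': 0.003}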