Diffstat (limited to 'tests/test_nn_select.py')
-rw-r--r--  tests/test_nn_select.py  |  12 +++++-------
1 file changed, 5 insertions(+), 7 deletions(-)
diff --git a/tests/test_nn_select.py b/tests/test_nn_select.py
index aac07b4..4095076 100644
--- a/tests/test_nn_select.py
+++ b/tests/test_nn_select.py
@@ -183,11 +183,11 @@ def test_get_optimizer(
 
 @pytest.mark.parametrize(
     "rewrite_parameters",
-    [[None], [{"batch_size": 64, "learning_rate": 0.003}]],
+    [None, {"batch_size": 64, "learning_rate": 0.003}],
 )
 @pytest.mark.skip_set_training_steps
 def test_get_optimizer_training_parameters(
-    rewrite_parameters: list[dict], test_tflite_model: Path
+    rewrite_parameters: dict | None, test_tflite_model: Path
 ) -> None:
"""Test function get_optimzer with various combinations of parameters."""
config = OptimizationSettings(
@@ -198,20 +198,18 @@ def test_get_optimizer_training_parameters(
     )
     optimizer = cast(
         RewritingOptimizer,
-        get_optimizer(test_tflite_model, config, list(rewrite_parameters)),
+        get_optimizer(test_tflite_model, config, rewrite_parameters),
     )
-    assert len(rewrite_parameters) == 1
-
     assert isinstance(
         optimizer.optimizer_configuration.train_params, TrainingParameters
     )
-    if not rewrite_parameters[0]:
+    if not rewrite_parameters:
         assert asdict(TrainingParameters()) == asdict(
             optimizer.optimizer_configuration.train_params
         )
     else:
-        assert asdict(TrainingParameters()) | rewrite_parameters[0] == asdict(
+        assert asdict(TrainingParameters()) | rewrite_parameters == asdict(
             optimizer.optimizer_configuration.train_params
         )
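Aside on the parametrize change above: pytest generates one test case per element of the list, so with the flattened list each run receives either None or the override dict directly as rewrite_parameters, with no single-element list to unwrap. A minimal self-contained sketch of that behavior (the test name and assertion here are illustrative, not part of this patch):

    from __future__ import annotations

    import pytest

    @pytest.mark.parametrize(
        "rewrite_parameters",
        [None, {"batch_size": 64, "learning_rate": 0.003}],
    )
    def test_parametrize_shape(rewrite_parameters: dict | None) -> None:
        # Each generated case receives one list element directly:
        # either None or the dict itself, never a wrapping list.
        assert rewrite_parameters is None or isinstance(rewrite_parameters, dict)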
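Aside on the updated assertion: asdict(TrainingParameters()) | rewrite_parameters uses the PEP 584 dict union, so the parametrized overrides are layered on top of the dataclass defaults to form the expected train_params. A sketch assuming TrainingParameters is a plain dataclass; the fields and defaults below are stand-ins, not the real ones:

    from dataclasses import asdict, dataclass

    @dataclass
    class TrainingParameters:  # stand-in fields for illustration only
        batch_size: int = 32
        learning_rate: float = 1e-3
        steps: int = 48000

    overrides = {"batch_size": 64, "learning_rate": 0.003}
    # Dict union (PEP 584): keys on the right win on collision.
    expected = asdict(TrainingParameters()) | overrides
    assert expected == {"batch_size": 64, "learning_rate": 0.003, "steps": 48000}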