diff options
Diffstat (limited to 'tests/test_cli_commands.py')
-rw-r--r-- | tests/test_cli_commands.py | 144 |
1 files changed, 128 insertions, 16 deletions
diff --git a/tests/test_cli_commands.py b/tests/test_cli_commands.py index f3213c4..6765a53 100644 --- a/tests/test_cli_commands.py +++ b/tests/test_cli_commands.py @@ -3,6 +3,7 @@ """Tests for cli.commands module.""" from __future__ import annotations +from contextlib import ExitStack as does_not_raise from pathlib import Path from typing import Any from unittest.mock import call @@ -49,37 +50,148 @@ def test_performance_unknown_target( @pytest.mark.parametrize( - "target_profile, pruning, clustering, pruning_target, clustering_target", + "target_profile, pruning, clustering, pruning_target, clustering_target, " + "rewrite, rewrite_target, rewrite_start, rewrite_end, expected_error", [ - ["ethos-u55-256", True, False, 0.5, None], - ["ethos-u65-512", False, True, 0.5, 32], - ["ethos-u55-256", True, True, 0.5, None], - ["ethos-u55-256", False, False, 0.5, None], - ["ethos-u55-256", False, True, "invalid", 32], + [ + "ethos-u55-256", + True, + False, + 0.5, + None, + False, + None, + "node_a", + "node_b", + does_not_raise(), + ], + [ + "ethos-u55-256", + False, + False, + None, + None, + True, + "fully_connected", + "node_a", + "node_b", + does_not_raise(), + ], + [ + "ethos-u55-256", + True, + False, + 0.5, + None, + True, + "fully_connected", + "node_a", + "node_b", + pytest.raises( + Exception, + match=(r"Only 'rewrite' is supported for TensorFlow Lite files."), + ), + ], + [ + "ethos-u65-512", + False, + True, + 0.5, + 32, + False, + None, + None, + None, + does_not_raise(), + ], + [ + "ethos-u55-256", + False, + False, + 0.5, + None, + True, + "random", + "node_x", + "node_y", + pytest.raises( + Exception, + match=(r"Currently only remove and fully_connected are supported."), + ), + ], + [ + "ethos-u55-256", + False, + False, + 0.5, + None, + True, + None, + "node_m", + "node_n", + pytest.raises( + Exception, + match=( + r"To perform rewrite, rewrite-target, " + r"rewrite-start and rewrite-end must be set." + ), + ), + ], + [ + "ethos-u55-256", + False, + False, + "invalid", + None, + True, + "remove", + None, + "node_end", + pytest.raises( + Exception, + match=( + r"To perform rewrite, rewrite-target, " + r"rewrite-start and rewrite-end must be set." + ), + ), + ], ], ) -def test_opt_valid_optimization_target( +def test_opt_valid_optimization_target( # pylint: disable=too-many-arguments target_profile: str, sample_context: ExecutionContext, pruning: bool, clustering: bool, pruning_target: float | None, clustering_target: int | None, + rewrite: bool, + rewrite_target: str | None, + rewrite_start: str | None, + rewrite_end: str | None, + expected_error: Any, monkeypatch: pytest.MonkeyPatch, test_keras_model: Path, + test_tflite_model: Path, ) -> None: """Test that command should not fail with valid optimization targets.""" mock_performance_estimation(monkeypatch) - optimize( - ctx=sample_context, - target_profile=target_profile, - model=str(test_keras_model), - pruning=pruning, - clustering=clustering, - pruning_target=pruning_target, - clustering_target=clustering_target, - ) + model_type = test_tflite_model if rewrite else test_keras_model + + with expected_error: + optimize( + ctx=sample_context, + target_profile=target_profile, + model=str(model_type), + pruning=pruning, + clustering=clustering, + pruning_target=pruning_target, + clustering_target=clustering_target, + rewrite=rewrite, + rewrite_target=rewrite_target, + rewrite_start=rewrite_start, + rewrite_end=rewrite_end, + ) def mock_performance_estimation(monkeypatch: pytest.MonkeyPatch) -> None: |