diff options
Diffstat (limited to 'tests/test_cli_main.py')
-rw-r--r-- | tests/test_cli_main.py | 51 |
1 files changed, 40 insertions, 11 deletions
diff --git a/tests/test_cli_main.py b/tests/test_cli_main.py index 4b16ac5..d0f7152 100644 --- a/tests/test_cli_main.py +++ b/tests/test_cli_main.py @@ -15,6 +15,7 @@ from unittest.mock import MagicMock import pytest import mlia +from mlia.cli.main import backend_main from mlia.cli.main import CommandInfo from mlia.cli.main import main from mlia.core.context import ExecutionContext @@ -122,6 +123,17 @@ def test_default_command(monkeypatch: pytest.MonkeyPatch, tmp_path: Path) -> Non non_default_command.assert_called_once_with(param="test") +def wrap_mock_command(mock: MagicMock, command: Callable) -> Callable: + """Wrap the command with the mock.""" + + @wraps(command) + def mock_command(*args: Any, **kwargs: Any) -> Any: + """Mock the command.""" + mock(*args, **kwargs) + + return mock_command + + @pytest.mark.parametrize( "params, expected_call", [ @@ -273,16 +285,6 @@ def test_commands_execution( """Test calling commands from the main function.""" mock = MagicMock() - def wrap_mock_command(command: Callable) -> Callable: - """Wrap the command with the mock.""" - - @wraps(command) - def mock_command(*args: Any, **kwargs: Any) -> Any: - """Mock the command.""" - mock(*args, **kwargs) - - return mock_command - monkeypatch.setattr( "mlia.cli.options.get_default_backends", MagicMock(return_value=["Vela"]) ) @@ -295,7 +297,7 @@ def test_commands_execution( for command in ["all_tests", "operators", "performance", "optimization"]: monkeypatch.setattr( f"mlia.cli.main.{command}", - wrap_mock_command(getattr(mlia.cli.main, command)), + wrap_mock_command(mock, getattr(mlia.cli.main, command)), ) main(params) @@ -304,6 +306,33 @@ def test_commands_execution( @pytest.mark.parametrize( + "params, expected_call", + [ + [ + ["list"], + call(), + ], + ], +) +def test_commands_execution_backend_main( + monkeypatch: pytest.MonkeyPatch, + params: list[str], + expected_call: Any, +) -> None: + """Test calling commands from the backend_main function.""" + mock = MagicMock() + + monkeypatch.setattr( + "mlia.cli.main.backend_list", + wrap_mock_command(mock, getattr(mlia.cli.main, "backend_list")), + ) + + backend_main(params) + + mock.assert_called_once_with(*expected_call.args, **expected_call.kwargs) + + +@pytest.mark.parametrize( "verbose, exc_mock, expected_output", [ [ |