aboutsummaryrefslogtreecommitdiff
path: root/tests/test_cli_main.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/test_cli_main.py')
-rw-r--r--tests/test_cli_main.py51
1 files changed, 40 insertions, 11 deletions
diff --git a/tests/test_cli_main.py b/tests/test_cli_main.py
index 4b16ac5..d0f7152 100644
--- a/tests/test_cli_main.py
+++ b/tests/test_cli_main.py
@@ -15,6 +15,7 @@ from unittest.mock import MagicMock
import pytest
import mlia
+from mlia.cli.main import backend_main
from mlia.cli.main import CommandInfo
from mlia.cli.main import main
from mlia.core.context import ExecutionContext
@@ -122,6 +123,17 @@ def test_default_command(monkeypatch: pytest.MonkeyPatch, tmp_path: Path) -> Non
non_default_command.assert_called_once_with(param="test")
+def wrap_mock_command(mock: MagicMock, command: Callable) -> Callable:
+ """Wrap the command with the mock."""
+
+ @wraps(command)
+ def mock_command(*args: Any, **kwargs: Any) -> Any:
+ """Mock the command."""
+ mock(*args, **kwargs)
+
+ return mock_command
+
+
@pytest.mark.parametrize(
"params, expected_call",
[
@@ -273,16 +285,6 @@ def test_commands_execution(
"""Test calling commands from the main function."""
mock = MagicMock()
- def wrap_mock_command(command: Callable) -> Callable:
- """Wrap the command with the mock."""
-
- @wraps(command)
- def mock_command(*args: Any, **kwargs: Any) -> Any:
- """Mock the command."""
- mock(*args, **kwargs)
-
- return mock_command
-
monkeypatch.setattr(
"mlia.cli.options.get_default_backends", MagicMock(return_value=["Vela"])
)
@@ -295,7 +297,7 @@ def test_commands_execution(
for command in ["all_tests", "operators", "performance", "optimization"]:
monkeypatch.setattr(
f"mlia.cli.main.{command}",
- wrap_mock_command(getattr(mlia.cli.main, command)),
+ wrap_mock_command(mock, getattr(mlia.cli.main, command)),
)
main(params)
@@ -304,6 +306,33 @@ def test_commands_execution(
@pytest.mark.parametrize(
+ "params, expected_call",
+ [
+ [
+ ["list"],
+ call(),
+ ],
+ ],
+)
+def test_commands_execution_backend_main(
+ monkeypatch: pytest.MonkeyPatch,
+ params: list[str],
+ expected_call: Any,
+) -> None:
+ """Test calling commands from the backend_main function."""
+ mock = MagicMock()
+
+ monkeypatch.setattr(
+ "mlia.cli.main.backend_list",
+ wrap_mock_command(mock, getattr(mlia.cli.main, "backend_list")),
+ )
+
+ backend_main(params)
+
+ mock.assert_called_once_with(*expected_call.args, **expected_call.kwargs)
+
+
+@pytest.mark.parametrize(
"verbose, exc_mock, expected_output",
[
[