Diffstat (limited to 'tests_e2e/test_e2e.py')
-rw-r--r--  tests_e2e/test_e2e.py  93
1 file changed, 25 insertions(+), 68 deletions(-)
diff --git a/tests_e2e/test_e2e.py b/tests_e2e/test_e2e.py
index 439723b..fb40735 100644
--- a/tests_e2e/test_e2e.py
+++ b/tests_e2e/test_e2e.py
@@ -14,12 +14,13 @@ from contextlib import ExitStack
from dataclasses import dataclass
from pathlib import Path
from typing import Any
-from typing import cast
from typing import Generator
from typing import Iterable
import pytest
+from mlia.cli.config import get_available_backends
+from mlia.cli.config import get_default_backends
from mlia.cli.main import get_commands
from mlia.cli.main import get_possible_command_names
from mlia.cli.main import init_commands
@@ -35,52 +36,6 @@ VALID_COMMANDS = get_possible_command_names(get_commands())
@dataclass
-class CommandExecution:
- """Command execution."""
-
- parsed_args: argparse.Namespace
- parameters: list[str]
-
- def __str__(self) -> str:
- """Return string representation."""
- command = self._get_param("command")
- target_profile = self._get_param("target_profile")
-
- model_path = Path(self._get_param("model"))
- model = model_path.name
-
- evaluate_on = self._get_param("evaluate_on", None)
- evalute_on_opts = f" evaluate_on={','.join(evaluate_on)}" if evaluate_on else ""
-
- opt_type = self._get_param("optimization_type", None)
- opt_target = self._get_param("optimization_target", None)
-
- opts = (
- f" optimization={opts}"
- if (opts := self._merge(opt_type, opt_target))
- else ""
- )
-
- return f"command {command}: {target_profile=} {model=}{evalute_on_opts}{opts}"
-
- def _get_param(self, param: str, default: str | None = "unknown") -> Any:
- return getattr(self.parsed_args, param, default)
-
- @staticmethod
- def _merge(value1: str, value2: str, sep: str = ",") -> str:
- """Split and merge values into a string."""
- if not value1 or not value2:
- return ""
-
- values = [
- f"{v1} {v2}"
- for v1, v2 in zip(str(value1).split(sep), str(value2).split(sep))
- ]
-
- return ",".join(values)
-
-
-@dataclass
class ExecutionConfiguration:
"""Execution configuration."""
@@ -271,38 +226,40 @@ def get_all_commands_combinations(executions: Any) -> Generator[list[str], None,
)
-def try_to_parse_args(combination: list[str]) -> argparse.Namespace:
- """Try to parse command."""
- try:
- # parser contains some static data and could not be reused
- # this is why it is being created for each combination
- args_parser = get_args_parser()
- return cast(argparse.Namespace, args_parser.parse_args(combination))
- except SystemExit as err:
- raise Exception(
- f"Configuration contains invalid parameters: {combination}"
- ) from err
+def check_args(args: list[str], no_skip: bool) -> None:
+ """Check the arguments and skip/fail test cases based on that."""
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "--evaluate-on",
+ help="Backends to use for evaluation (default: %(default)s)",
+ nargs="*",
+ default=get_default_backends(),
+ )
+
+ parsed_args, _ = parser.parse_known_args(args)
+ required_backends = set(parsed_args.evaluate_on)
+ available_backends = set(get_available_backends())
+ missing_backends = required_backends.difference(available_backends)
+
+ if missing_backends and not no_skip:
+ pytest.skip(f"Missing backend(s): {','.join(missing_backends)}")
-def get_execution_definitions() -> Generator[CommandExecution, None, None]:
+def get_execution_definitions() -> Generator[list[str], None, None]:
"""Collect all execution definitions from configuration file."""
config_file = get_config_file()
executions = get_config_content(config_file)
executions = resolve_parameters(executions)
- for combination in get_all_commands_combinations(executions):
- # parse parameters to generate meaningful test description
- args = try_to_parse_args(combination)
-
- yield CommandExecution(args, combination)
+ return get_all_commands_combinations(executions)
class TestEndToEnd:
"""End to end command tests."""
- @pytest.mark.parametrize("command_execution", get_execution_definitions(), ids=str)
- def test_command(self, command_execution: CommandExecution) -> None:
+ @pytest.mark.parametrize("command", get_execution_definitions(), ids=str)
+ def test_e2e(self, command: list[str], no_skip: bool) -> None:
"""Test MLIA command with the provided parameters."""
- mlia_command = ["mlia", *command_execution.parameters]
-
+ check_args(command, no_skip)
+ mlia_command = ["mlia", *command]
run_command(mlia_command)
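
Note: the new test_e2e signature takes a no_skip fixture that is not defined in this file. Below is a minimal sketch of how such a fixture could be provided from tests_e2e/conftest.py, assuming a --no-skip command-line option; the names here are illustrative and not part of this diff.

# Hypothetical conftest.py sketch -- not part of this diff. It shows one way
# the no_skip fixture used by TestEndToEnd.test_e2e could be wired up.
import pytest


def pytest_addoption(parser: pytest.Parser) -> None:
    """Register an assumed --no-skip flag for the end-to-end test run."""
    parser.addoption(
        "--no-skip",
        action="store_true",
        default=False,
        help="Do not skip tests when required backends are missing.",
    )


@pytest.fixture
def no_skip(request: pytest.FixtureRequest) -> bool:
    """Expose the assumed --no-skip option to tests as a boolean fixture."""
    return bool(request.config.getoption("--no-skip"))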