diff options
Diffstat (limited to 'tests_e2e/test_e2e.py')
-rw-r--r-- | tests_e2e/test_e2e.py | 93 |
1 files changed, 25 insertions, 68 deletions
diff --git a/tests_e2e/test_e2e.py b/tests_e2e/test_e2e.py index 439723b..fb40735 100644 --- a/tests_e2e/test_e2e.py +++ b/tests_e2e/test_e2e.py @@ -14,12 +14,13 @@ from contextlib import ExitStack from dataclasses import dataclass from pathlib import Path from typing import Any -from typing import cast from typing import Generator from typing import Iterable import pytest +from mlia.cli.config import get_available_backends +from mlia.cli.config import get_default_backends from mlia.cli.main import get_commands from mlia.cli.main import get_possible_command_names from mlia.cli.main import init_commands @@ -35,52 +36,6 @@ VALID_COMMANDS = get_possible_command_names(get_commands()) @dataclass -class CommandExecution: - """Command execution.""" - - parsed_args: argparse.Namespace - parameters: list[str] - - def __str__(self) -> str: - """Return string representation.""" - command = self._get_param("command") - target_profile = self._get_param("target_profile") - - model_path = Path(self._get_param("model")) - model = model_path.name - - evaluate_on = self._get_param("evaluate_on", None) - evalute_on_opts = f" evaluate_on={','.join(evaluate_on)}" if evaluate_on else "" - - opt_type = self._get_param("optimization_type", None) - opt_target = self._get_param("optimization_target", None) - - opts = ( - f" optimization={opts}" - if (opts := self._merge(opt_type, opt_target)) - else "" - ) - - return f"command {command}: {target_profile=} {model=}{evalute_on_opts}{opts}" - - def _get_param(self, param: str, default: str | None = "unknown") -> Any: - return getattr(self.parsed_args, param, default) - - @staticmethod - def _merge(value1: str, value2: str, sep: str = ",") -> str: - """Split and merge values into a string.""" - if not value1 or not value2: - return "" - - values = [ - f"{v1} {v2}" - for v1, v2 in zip(str(value1).split(sep), str(value2).split(sep)) - ] - - return ",".join(values) - - -@dataclass class ExecutionConfiguration: """Execution configuration.""" @@ -271,38 +226,40 @@ def get_all_commands_combinations(executions: Any) -> Generator[list[str], None, ) -def try_to_parse_args(combination: list[str]) -> argparse.Namespace: - """Try to parse command.""" - try: - # parser contains some static data and could not be reused - # this is why it is being created for each combination - args_parser = get_args_parser() - return cast(argparse.Namespace, args_parser.parse_args(combination)) - except SystemExit as err: - raise Exception( - f"Configuration contains invalid parameters: {combination}" - ) from err +def check_args(args: list[str], no_skip: bool) -> None: + """Check the arguments and skip/fail test cases based on that.""" + parser = argparse.ArgumentParser() + parser.add_argument( + "--evaluate-on", + help="Backends to use for evaluation (default: %(default)s)", + nargs="*", + default=get_default_backends(), + ) + + parsed_args, _ = parser.parse_known_args(args) + required_backends = set(parsed_args.evaluate_on) + available_backends = set(get_available_backends()) + missing_backends = required_backends.difference(available_backends) + + if missing_backends and not no_skip: + pytest.skip(f"Missing backend(s): {','.join(missing_backends)}") -def get_execution_definitions() -> Generator[CommandExecution, None, None]: +def get_execution_definitions() -> Generator[list[str], None, None]: """Collect all execution definitions from configuration file.""" config_file = get_config_file() executions = get_config_content(config_file) executions = resolve_parameters(executions) - for combination in get_all_commands_combinations(executions): - # parse parameters to generate meaningful test description - args = try_to_parse_args(combination) - - yield CommandExecution(args, combination) + return get_all_commands_combinations(executions) class TestEndToEnd: """End to end command tests.""" - @pytest.mark.parametrize("command_execution", get_execution_definitions(), ids=str) - def test_command(self, command_execution: CommandExecution) -> None: + @pytest.mark.parametrize("command", get_execution_definitions(), ids=str) + def test_e2e(self, command: list[str], no_skip: bool) -> None: """Test MLIA command with the provided parameters.""" - mlia_command = ["mlia", *command_execution.parameters] - + check_args(command, no_skip) + mlia_command = ["mlia", *command] run_command(mlia_command) |