aboutsummaryrefslogtreecommitdiff
path: root/tests_e2e/test_e2e.py
blob: 74ff51c41106a66912bcec2eb8257bd9a6fbfee2 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
# SPDX-FileCopyrightText: Copyright 2022-2024, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""End to end tests for MLIA CLI."""
from __future__ import annotations

import argparse
import glob
import itertools
import json
import os
import subprocess  # nosec
import tempfile
from contextlib import ExitStack
from dataclasses import dataclass
from pathlib import Path
from typing import Any
from typing import Generator
from typing import Iterable
from typing import Sequence

import pytest

from mlia.backend.config import System
from mlia.backend.manager import get_available_backends
from mlia.cli.main import get_commands
from mlia.cli.main import get_possible_command_names
from mlia.cli.main import init_parser
from mlia.target.config import get_builtin_supported_profile_names
from mlia.utils.types import is_list_of


# Every test collected from this module carries the "e2e" marker.
pytestmark = pytest.mark.e2e

# Command names known to the MLIA CLI; configuration entries whose "command"
# is not among them are rejected by ExecutionConfiguration.from_dict.
VALID_COMMANDS = get_possible_command_names(get_commands())


@dataclass
class ExecutionConfiguration:
    """One configured command together with its parameter groups.

    ``parameters`` maps a group name to a list of alternative argument lists;
    a concrete command line picks one alternative from every group.
    """

    command: str
    parameters: dict[str, list[list[str]]]

    @classmethod
    def from_dict(cls, exec_info: dict) -> ExecutionConfiguration:
        """Build a configuration from one entry of the ``executions`` list.

        Raises:
            ValueError: if the command is missing, is not a known MLIA
                command, or has no parameters.
            AssertionError: if the parameters are not a dictionary mapping
                strings to lists of lists of strings.
        """
        command = exec_info.get("command")
        if not command:
            raise ValueError("Command is not defined.")
        if command not in VALID_COMMANDS:
            raise ValueError(f"Command {command} is unknown.")

        params = exec_info.get("parameters")
        if not params:
            raise ValueError(f"Command {command} should have parameters.")

        assert isinstance(params, dict), "Parameters should be a dictionary"
        assert all(
            isinstance(group_name, str)
            and is_list_of(alternatives, list)
            and all(is_list_of(argv, str) for argv in alternatives)
            for group_name, alternatives in params.items()
        ), "Execution configuration should be a dictionary of list of list of strings"

        return cls(command=command, parameters=params)

    @property
    def all_combinations(self) -> Iterable[list[str]]:
        """Return a lazy iterable of every command line.

        Each command line is the command followed by one alternative from
        every parameter group (the cartesian product of the groups, taken in
        dictionary order).
        """
        choices = itertools.product(*self.parameters.values())
        return (
            [self.command] + list(itertools.chain(*choice)) for choice in choices
        )


def launch_and_wait(
    cmd: list[str],
    output_file: Path | None = None,
    print_output: bool = True,
    stdin: Any | None = None,
) -> None:
    """Run *cmd*, wait for it to finish and report its output.

    Args:
        cmd: Command line to execute (passed as a list; no shell involved).
        output_file: If given, the command's standard output is written here.
        print_output: If true, the command's standard output is printed, and
            so is its standard error when non-empty.
        stdin: Passed to ``subprocess.Popen`` as the child's standard input
            (e.g. an open file object); ``None`` inherits the parent's stdin.

    Raises:
        RuntimeError: if the command exits with a non-zero status. The
            captured standard error is appended to the message when present.
    """
    with subprocess.Popen(  # nosec
        cmd,
        stdin=stdin,
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
        universal_newlines=True,
        bufsize=1,
    ) as process:
        # communicate() drains stdout and stderr concurrently and waits for
        # the process to exit. Reading only stdout to EOF while stderr is also
        # a pipe can deadlock once the child fills the stderr pipe buffer.
        output, errors = process.communicate()

    # Save the output into a file
    if output_file:
        output_file.write_text(output)
        print(f"Output saved to {output_file}")

    # Show the output to stdout
    if print_output:
        print(output)
        if errors:
            print(errors)

    if (exit_code := process.returncode) != 0:
        message = f"Command failed with exit_code {exit_code}."
        if errors:
            # Include stderr so the failure is diagnosable from the exception.
            message = f"{message}\nstderr:\n{errors}"
        raise RuntimeError(message)


def run_command(
    cmd: list[str],
    output_file: Path | None = None,
    print_output: bool = True,
    cmd_input: str | None = None,
) -> None:
    """Echo *cmd*, then execute it via ``launch_and_wait``.

    When *cmd_input* is given, it is written to a temporary file that is
    handed to the command as its standard input. The file is closed once the
    command has finished.
    """
    print(f"Run command: {' '.join(cmd)}")

    with ExitStack() as stack:
        stdin_file = None

        if cmd_input is not None:
            print(f"Command will receive next input: {repr(cmd_input)}")

            stdin_file = stack.enter_context(
                tempfile.NamedTemporaryFile(  # pylint: disable=consider-using-with
                    mode="w", prefix="mlia_", suffix="_test"
                )
            )
            stdin_file.write(cmd_input)
            # Rewind so the child process reads the input from the beginning.
            stdin_file.seek(0)

        launch_and_wait(cmd, output_file, print_output, stdin_file)


def get_config_file() -> Path:
    """Return the e2e configuration file named by ``MLIA_E2E_CONFIG_FILE``.

    Raises:
        ValueError: if the environment variable is unset or empty.
        FileNotFoundError: if the variable does not point to an existing file.
    """
    env_var_name = "MLIA_E2E_CONFIG_FILE"
    raw_path = os.environ.get(env_var_name)
    if not raw_path:
        raise ValueError(f"Config file env variable ({env_var_name}) is not set.")

    candidate = Path(raw_path)
    if candidate.is_file():
        return candidate

    raise FileNotFoundError(f"Invalid config file {raw_path}.")


def get_args_parser() -> Any:
    """Build a fresh MLIA argument parser for the commands from get_commands()."""
    return init_parser(get_commands())


def replace_element(params: list[str], idx: int, value: str) -> list[str]:
    """Return a new list equal to *params* with the item at *idx* set to *value*.

    The input list is not modified.
    """
    after = idx + 1
    return [*params[:idx], value, *params[after:]]


def resolve(params: list[str]) -> Generator[list[str], None, None]:
    """Expand wildcard parameters into concrete parameter lists.

    The first parameter containing ``*`` is expanded, and every expansion is
    passed back through ``resolve``, so all wildcards in *params* end up
    expanded. A list with no ``*`` is yielded unchanged. Two wildcard forms
    are supported:

    * ``*`` immediately after ``--target-profile`` expands to each name from
      ``get_builtin_supported_profile_names()``;
    * a parameter starting with ``e2e_config`` is used as a glob pattern
      (``recursive=True``, so ``**`` works) relative to the current working
      directory, and expands to each matching path, in sorted order.

    Raises:
        ValueError: if a wildcard parameter matches neither form, or if its
            glob pattern matches no files.
    """
    for idx, param in enumerate(params):
        if "*" not in param:
            continue

        prev = None if idx == 0 else params[idx - 1]

        if prev == "--target-profile" and param == "*":
            resolved = (
                replace_element(params, idx, profile)
                for profile in get_builtin_supported_profile_names()
            )
        elif param.startswith("e2e_config"):
            # Escape the working directory so glob metacharacters in its path
            # ("[", "?", "*") are matched literally; only *param* is a pattern.
            pattern = f"{glob.escape(str(Path.cwd()))}/{param}"
            # glob's result order depends on the file system; sort it so the
            # generated test cases come out in a stable order across runs.
            filenames = sorted(glob.glob(pattern, recursive=True))
            if not filenames:
                raise ValueError(f"Unable to resolve parameter {param}")
            resolved = (
                replace_element(params, idx, filename) for filename in filenames
            )
        else:
            raise ValueError(f"Unable to resolve parameter {param}")

        # Recurse to expand any wildcards remaining after this position.
        for item in resolved:
            yield from resolve(item)

        break
    else:
        # No wildcard at all: the list is already concrete.
        yield params


def resolve_parameters(executions: list[dict[str, Any]]) -> list[dict[str, Any]]:
    """Expand wildcard parameters of every execution, in place.

    Each execution's optional ``"parameters"`` entry maps a group name to a
    list of parameter lists. Every parameter list containing a ``*`` is
    replaced by all of its expansions from ``resolve``. Lists without
    wildcards are kept as they are.

    Args:
        executions: Execution definitions as loaded from the configuration
            file (a list of dictionaries).

    Returns:
        The same *executions* object, mutated in place.
    """
    for execution in executions:
        parameters = execution.get("parameters", {})

        for group_name, alternatives in parameters.items():
            expanded: list[list[str]] = []

            for argv in alternatives:
                if any("*" in item for item in argv):
                    expanded.extend(resolve(argv))
                else:
                    expanded.append(argv)

            # Replacing the value of an existing key is safe while iterating.
            parameters[group_name] = expanded

    return executions


def get_config_content(config_file: Path) -> Any:
    """Read the e2e configuration file.

    The file must contain a JSON object. Its optional ``"executions"`` key
    must hold a list of dictionaries (default: empty list), and its optional
    ``"settings"`` key must hold a dictionary (default: empty dictionary).

    Returns:
        A ``(settings, executions)`` tuple.

    Raises:
        AssertionError: if the content does not have that structure.
    """
    with open(config_file, encoding="utf-8") as handle:
        config = json.load(handle)

    assert isinstance(config, dict), "JSON configuration expected to be a dictionary"

    executions = config.get("executions", [])
    assert is_list_of(executions, dict), "List of the dictionaries expected"

    settings = config.get("settings", {})
    assert isinstance(settings, dict)

    return (settings, executions)


def get_all_commands_combinations(
    executions: Any,
) -> Generator[dict[str, Sequence[str]], None, None]:
    """Yield one test-case description per command-line combination.

    Every execution entry is validated through
    ``ExecutionConfiguration.from_dict``. Each of its combinations is then
    parsed with a freshly built MLIA argument parser, so the model path can
    be extracted.

    Yields:
        Dictionaries with ``"model_name"`` (stem of the ``model`` argument)
        and ``"command_combination"`` (the argument list without ``mlia``).
    """
    for exec_info in executions:
        exec_config = ExecutionConfiguration.from_dict(exec_info)

        for combination in exec_config.all_combinations:
            parsed = get_args_parser().parse_args(combination)
            model_stem = Path(parsed.model).stem

            yield {"model_name": model_stem, "command_combination": combination}


def check_args(args: list[str], no_skip: bool) -> None:
    """Skip the current test if its command cannot run in this environment.

    Only ``--backend`` (repeatable) and ``--target-profile`` are inspected;
    every other argument is ignored.

    * When a requested backend is not among ``get_available_backends()``,
      the test is skipped, unless *no_skip* is true. In that case it is not
      skipped here, and the missing backend presumably surfaces as a failure.
    * A ``tosa`` target profile is always skipped on Linux AArch64.
    """
    # add_help=False: this parser only sniffs two options. It must not act on
    # a "-h/--help" belonging to the MLIA command under test, because the
    # default help action would print usage and raise SystemExit here.
    parser = argparse.ArgumentParser(add_help=False)

    parser.add_argument(
        "--backend",
        help="Backends to use for evaluation.",
        action="append",
    )
    parser.add_argument(
        "--target-profile",
        help="Target profiles to use for evaluation.",
    )

    parsed_args, _ = parser.parse_known_args(args)
    if parsed_args.backend:
        required_backends = set(parsed_args.backend)
        available_backends = set(get_available_backends())
        missing_backends = required_backends.difference(available_backends)

        if missing_backends and not no_skip:
            # Sort so the skip reason is reproducible (set order is not).
            pytest.skip(f"Missing backend(s): {','.join(sorted(missing_backends))}")

    if parsed_args.target_profile == "tosa" and System.CURRENT == System.LINUX_AARCH64:
        pytest.skip("TOSA is not yet available for AArch64, skipping this test.")


def get_execution_definitions(
    executions: list[dict[str, Any]],
) -> Generator[dict[str, Sequence[str]], None, None]:
    """Turn the configured executions into pytest parameter sets.

    Wildcards are expanded first via ``resolve_parameters``, which mutates
    *executions* in place. The returned generator then yields one dictionary
    per command-line combination, as produced by
    ``get_all_commands_combinations``.
    """
    return get_all_commands_combinations(resolve_parameters(executions))


class TestEndToEnd:
    """Run every command combination described by the e2e configuration."""

    # Loaded when the class body executes (module import), so a missing or
    # invalid configuration file makes this module fail to import.
    configuration_file = get_config_file()
    settings, executions = get_config_content(configuration_file)

    @pytest.mark.parametrize(
        "command_data", get_execution_definitions(executions), ids=str
    )
    def test_e2e(self, command_data: dict[str, list[str]], no_skip: bool) -> None:
        """Execute one MLIA command line; the test fails if it exits non-zero."""
        command = command_data["command_combination"]
        model_name = command_data["model_name"]
        check_args(command, no_skip)

        should_print = self.settings.get("print_output", True)
        output_path = self.settings.get("output_file", None)
        if output_path:
            # "{model_name}" in the configured path is replaced by the model
            # name of this test case.
            output_path = Path(output_path.replace("{model_name}", model_name))

        run_command(["mlia", *command], output_path, should_print)