aboutsummaryrefslogtreecommitdiff
path: root/tests/aiet/conftest.py
blob: cab3dc2f25447cce94657d971a5bb571182c603f (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
# pylint: disable=redefined-outer-name
"""conftest for pytest."""
import shutil
import tarfile
from pathlib import Path
from typing import Any

import pytest
from click.testing import CliRunner

from aiet.backend.common import get_backend_configs


@pytest.fixture(scope="session")
def test_systems_path(test_resources_path: Path) -> Path:
    """Provide the ``systems`` directory located under the test resources path."""
    systems_dir = test_resources_path.joinpath("systems")
    return systems_dir


@pytest.fixture(scope="session")
def test_applications_path(test_resources_path: Path) -> Path:
    """Provide the ``applications`` directory located under the test resources path."""
    applications_dir = test_resources_path.joinpath("applications")
    return applications_dir


@pytest.fixture(scope="session")
def test_tools_path(test_resources_path: Path) -> Path:
    """Provide the ``tools`` directory located under the test resources path."""
    tools_dir = test_resources_path.joinpath("tools")
    return tools_dir


@pytest.fixture(scope="session")
def test_resources_path() -> Path:
    """Provide the absolute ``test_resources`` directory next to this conftest file."""
    conftest_dir = Path(__file__).parent
    return conftest_dir.absolute().joinpath("test_resources")


@pytest.fixture(scope="session")
def non_optimised_input_model_file(test_tflite_model: Path) -> Path:
    """Expose the ``test_tflite_model`` fixture path as the non-optimised input model."""
    model_path = test_tflite_model
    return model_path


@pytest.fixture(scope="session")
def optimised_input_model_file(test_tflite_vela_model: Path) -> Path:
    """Expose the ``test_tflite_vela_model`` fixture path as the optimised input model."""
    model_path = test_tflite_vela_model
    return model_path


@pytest.fixture(scope="session")
def invalid_input_model_file(test_tflite_invalid_model: Path) -> Path:
    """Expose the ``test_tflite_invalid_model`` fixture path as the invalid input model."""
    model_path = test_tflite_invalid_model
    return model_path


@pytest.fixture(autouse=True)
def test_resources(monkeypatch: pytest.MonkeyPatch, test_resources_path: Path) -> Any:
    """Redirect the resources lookup to the test resources directory.

    For the duration of each test, ``aiet.utils.fs.get_aiet_resources`` is
    replaced with a callable that returns ``test_resources_path``.
    """
    monkeypatch.setattr(
        "aiet.utils.fs.get_aiet_resources", lambda: test_resources_path
    )
    yield


@pytest.fixture(scope="session", autouse=True)
def add_tools(test_resources_path: Path) -> Any:
    """Expose the real tool directories inside the test resources via symlinks.

    Every parent directory of a config returned by
    ``get_backend_configs("tools")`` is linked into
    ``<test_resources_path>/tools`` under its own name, unless that
    destination already exists. After the session the symlinks are removed.
    """
    tools_root = test_resources_path / "tools"

    # Resolve every (origin -> link) pair before touching the file system.
    links = {}
    for config in get_backend_configs("tools"):
        origin = config.parent
        links[origin] = tools_root / origin.name

    for origin, link in links.items():
        if link.exists():
            continue
        link.symlink_to(origin, target_is_directory=True)

    yield

    # Teardown: drop every destination that is a symlink.
    for link in links.values():
        if not link.is_symlink():
            continue
        link.unlink()


def create_archive(
    archive_name: str, source: Path, destination: Path, with_root_folder: bool = False
) -> None:
    """Pack the direct children of ``source`` into a gzipped tar archive.

    The archive is written to ``destination / archive_name``. When
    ``with_root_folder`` is true, every member is stored under a top-level
    folder named after ``source``; otherwise members sit at the archive root.
    """
    prefix = f"{source.name}/" if with_root_folder else ""
    archive_file = destination / archive_name

    with tarfile.open(archive_file, mode="w:gz") as archive:
        for entry in source.iterdir():
            archive.add(entry, f"{prefix}{entry.name}")


def process_directory(source: Path, destination: Path) -> None:
    """Archive each subdirectory of ``source`` into a new ``destination`` directory.

    ``destination`` is created first. For every subdirectory ``<name>`` two
    archives are produced: ``<name>.tar.gz`` without a root folder and
    ``<name>_dir.tar.gz`` with one.
    """
    destination.mkdir()

    # (file name suffix, with_root_folder) for the two archive flavours.
    variants = (("", False), ("_dir", True))

    subdirectories = (entry for entry in source.iterdir() if entry.is_dir())
    for subdirectory in subdirectories:
        for suffix, rooted in variants:
            create_archive(
                f"{subdirectory.name}{suffix}.tar.gz",
                subdirectory,
                destination,
                rooted,
            )


@pytest.fixture(scope="session", autouse=True)
def add_archives(
    test_resources_path: Path, tmp_path_factory: pytest.TempPathFactory
) -> Any:
    """Generate archives of the test resources.

    The archives are built in a temporary directory, which is exposed as
    ``<test_resources_path>/archives`` through a symlink. For both the
    ``applications`` and ``systems`` resource directories, every subdirectory
    is archived via ``process_directory``. After the session the symlink and
    the temporary directory are removed.
    """
    tmp_path = tmp_path_factory.mktemp("archives")

    archives_path = tmp_path / "archives"
    archives_path.mkdir()

    # A symlink may already be present at the link location (e.g. left over
    # from an interrupted session). Remove that stale link -- not the freshly
    # created directory -- so the link can be recreated below.
    if (archives_path_link := test_resources_path / "archives").is_symlink():
        archives_path_link.unlink()

    archives_path_link.symlink_to(archives_path, target_is_directory=True)

    for item in ["applications", "systems"]:
        process_directory(test_resources_path / item, archives_path / item)

    yield

    # Teardown: remove the link first, then the temporary archive tree.
    archives_path_link.unlink()
    shutil.rmtree(tmp_path)


@pytest.fixture(scope="module")
def cli_runner() -> CliRunner:
    """Provide a freshly constructed ``CliRunner`` for each test module."""
    runner = CliRunner()
    return runner