aboutsummaryrefslogtreecommitdiff
path: root/src/aiet/backend/source.py
blob: dec175ac6645ef3e8a8c91334983a22f7a1c6b68 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Contain source related classes and functions."""
import os
import shutil
import tarfile
from abc import ABC
from abc import abstractmethod
from pathlib import Path
from tarfile import TarFile
from typing import Optional
from typing import Union

from aiet.backend.common import AIET_CONFIG_FILE
from aiet.backend.common import ConfigurationException
from aiet.backend.common import get_backend_config
from aiet.backend.common import is_backend_directory
from aiet.backend.common import load_config
from aiet.backend.config import BackendConfig
from aiet.utils.fs import copy_directory_content


class Source(ABC):
    """Abstract interface of an installation source.

    Concrete implementations in this module are DirectorySource (a directory
    on disk) and TarArchiveSource (a tar archive).
    """

    @abstractmethod
    def name(self) -> Optional[str]:
        """Return the name of the source, or None if there is no name."""

    @abstractmethod
    def config(self) -> Optional[BackendConfig]:
        """Return the content of the source's configuration file.

        The declared return type allows None when no configuration is
        available.
        """

    @abstractmethod
    def install_into(self, destination: Path) -> None:
        """Install the content of the source into the destination directory."""

    @abstractmethod
    def create_destination(self) -> bool:
        """Tell whether a destination folder must be created before installing.

        Return True if the destination folder should be created before
        installation, False if the source can be installed as it is.
        """


class DirectorySource(Source):
    """Installation source backed by a directory on the file system."""

    def __init__(self, directory_path: Path) -> None:
        """Remember the directory this source is installed from."""
        assert isinstance(directory_path, Path)
        self.directory_path = directory_path

    def name(self) -> str:
        """Return the last path component of the source directory."""
        source_dir = self.directory_path
        return source_dir.name

    def config(self) -> Optional[BackendConfig]:
        """Load and return the configuration found in the source directory.

        Raises ConfigurationException if the directory is not recognised as a
        backend directory.
        """
        if is_backend_directory(self.directory_path):
            return load_config(get_backend_config(self.directory_path))

        raise ConfigurationException("No configuration file found")

    def install_into(self, destination: Path) -> None:
        """Copy the content of the source directory into ``destination``.

        Raises ConfigurationException if either ``destination`` or the source
        directory is not an existing directory.
        """
        source_dir = self.directory_path

        # The destination is validated first, then the source itself.
        if not destination.is_dir():
            wrong_destination = "Wrong destination {}".format(destination)
            raise ConfigurationException(wrong_destination)

        if not source_dir.is_dir():
            missing_source = "Directory {} does not exist".format(source_dir)
            raise ConfigurationException(missing_source)

        copy_directory_content(source_dir, destination)

    def create_destination(self) -> bool:
        """Return True: a destination folder is always created beforehand."""
        return True


class TarArchiveSource(Source):
    """Installation source backed by a gzip compressed tar archive.

    Information about the archive (source name, configuration and whether the
    archive has a top level folder) is read from the archive on demand by
    name(), config() and create_destination().
    """

    def __init__(self, archive_path: Path) -> None:
        """Create the TarArchiveSource instance for the given archive path."""
        assert isinstance(archive_path, Path)
        self.archive_path = archive_path
        self._config: Optional[BackendConfig] = None
        self._has_top_level_folder: Optional[bool] = None
        self._name: Optional[str] = None

    def _read_archive_content(self) -> None:
        """Read the source name, layout and configuration from the archive.

        The configuration file is looked up in the archive root first and
        then inside the common top level directory of all entries. If the
        archive file does not exist only the name is set.

        Raises ConfigurationException if the archive exists but no
        configuration file can be found in it.
        """
        # Get the source name from the archive name (everything without
        # extensions). The extensions are cut off as one suffix: str.rstrip()
        # treats its argument as a set of characters and would turn
        # "target.tar.gz" into "targe".
        file_name = self.archive_path.name
        extensions = "".join(self.archive_path.suffixes)
        self._name = file_name[: -len(extensions)] if extensions else file_name

        if not self.archive_path.exists():
            return

        with self._open(self.archive_path) as archive:
            try:
                config_entry = archive.getmember(AIET_CONFIG_FILE)
                self._has_top_level_folder = False
            except KeyError as error_no_config:
                try:
                    # os.path.commonprefix compares names character by
                    # character; a prefix that is not a real directory simply
                    # fails the configuration lookup below.
                    archive_entries = archive.getnames()
                    entries_common_prefix = os.path.commonprefix(archive_entries)
                    top_level_dir = entries_common_prefix.rstrip("/")

                    if not top_level_dir:
                        raise RuntimeError(
                            "Archive has no top level directory"
                        ) from error_no_config

                    config_path = "{}/{}".format(top_level_dir, AIET_CONFIG_FILE)

                    config_entry = archive.getmember(config_path)
                    self._has_top_level_folder = True
                    # The top level folder name takes precedence over the
                    # name derived from the archive file name.
                    self._name = top_level_dir
                except (KeyError, RuntimeError) as error_no_root_dir_or_config:
                    raise ConfigurationException(
                        "No configuration file found"
                    ) from error_no_root_dir_or_config

            content = archive.extractfile(config_entry)
            self._config = load_config(content)

    def config(self) -> Optional[BackendConfig]:
        """Return configuration file content."""
        if self._config is None:
            self._read_archive_content()

        return self._config

    def name(self) -> Optional[str]:
        """Return name of the source."""
        if self._name is None:
            self._read_archive_content()

        return self._name

    def create_destination(self) -> bool:
        """Return True if destination folder must be created before installation.

        A destination folder is not needed only when the archive was found to
        contain a top level folder.
        """
        if self._has_top_level_folder is None:
            self._read_archive_content()

        return not self._has_top_level_folder

    def install_into(self, destination: Path) -> None:
        """Extract the archive into the destination directory.

        Raises ConfigurationException if ``destination`` is not a directory,
        if the archive cannot be opened, or if an archive entry would be
        extracted outside of ``destination``.
        """
        if not destination.is_dir():
            raise ConfigurationException("Wrong destination {}".format(destination))

        with self._open(self.archive_path) as archive:
            self._check_entries_inside(archive, destination)
            archive.extractall(destination)

    @staticmethod
    def _check_entries_inside(archive: TarFile, destination: Path) -> None:
        """Reject archives with entries that resolve outside of destination.

        This is a check of entry names only (absolute paths, ".." components);
        it guards extractall() against the most common path traversal cases.
        """
        root = os.path.realpath(destination)
        # Trailing separator so that "/dest-other" is not accepted for "/dest".
        root_prefix = os.path.join(root, "")

        for member in archive.getmembers():
            target = os.path.realpath(os.path.join(root, member.name))
            if target != root and not target.startswith(root_prefix):
                raise ConfigurationException(
                    "Archive entry {} is outside of destination {}".format(
                        member.name, destination
                    )
                )

    def _open(self, archive_path: Path) -> TarFile:
        """Open the given gzip compressed tar archive for reading.

        Raises ConfigurationException if the file does not exist or its name
        does not end with a supported extension.
        """
        if not archive_path.is_file():
            raise ConfigurationException("File {} does not exist".format(archive_path))

        if archive_path.name.endswith(("tar.gz", "tgz")):
            mode = "r:gz"
        else:
            raise ConfigurationException(
                "Unsupported archive type {}".format(archive_path)
            )

        # The returned TarFile object can be used as a context manager (using
        # 'with') by the calling instance.
        # Open the path that was validated above, not self.archive_path.
        return tarfile.open(  # pylint: disable=consider-using-with
            archive_path, mode=mode
        )


def get_source(source_path: Path) -> Union[TarArchiveSource, DirectorySource]:
    """Build the source object matching what ``source_path`` points to.

    A file is wrapped into TarArchiveSource and a directory into
    DirectorySource. Raises ConfigurationException for anything else.
    """
    source: Union[TarArchiveSource, DirectorySource]

    if source_path.is_file():
        source = TarArchiveSource(source_path)
    elif source_path.is_dir():
        source = DirectorySource(source_path)
    else:
        raise ConfigurationException("Unable to read {}".format(source_path))

    return source


def create_destination_and_install(source: Source, resource_path: Path) -> None:
    """Create destination directory and install source.

    This function is used for the actual installation of a system/backend.
    A new directory named after the source is created inside
    ``resource_path`` if the source asks for it. If, for example, an archive
    contains a top level folder then there is no need to create a new
    directory and the source is installed directly into ``resource_path``.

    Raises ConfigurationException if a directory must be created but the
    source has no name. Any error raised by the installation itself is
    propagated after the directory created here has been removed.
    """
    destination = resource_path
    create_destination = source.create_destination()

    if create_destination:
        name = source.name()
        if not name:
            raise ConfigurationException("Unable to get source name")

        destination = resource_path / name
        destination.mkdir()

    try:
        source.install_into(destination)
    except Exception:
        if create_destination:
            # Best-effort cleanup of the directory created above; a failure
            # to remove it must not hide the original installation error.
            shutil.rmtree(destination, ignore_errors=True)
        raise