aboutsummaryrefslogtreecommitdiff
path: root/src/mlia/backend/repo.py
blob: 3dd2e57b775d1cf36cce5062b073069d53f1da21 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
# SPDX-FileCopyrightText: Copyright 2023, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Module for backend repository.

The backend repository manages backends (other than Python
package-based ones) that have been installed via the
"mlia-backend" command.

The repository has an associated directory (by default ~/.mlia) and
a configuration file (by default ~/.mlia/mlia_config.json). The
configuration file keeps track of the installed backends and their
settings, which MLIA can use to instantiate each backend correctly.

When a backend is removed, the repository deletes its record from
the configuration file and, if the backend files were copied into
the repository, those files as well.
"""
from __future__ import annotations

import json
import shutil
from pathlib import Path

from mlia.utils.filesystem import copy_all


class _ConfigFile:
    """Configuration file for backend repository."""

    def __init__(self, config_file: Path) -> None:
        """Init configuration file."""
        self.config_file = config_file
        self.config: dict = {"backends": []}

        if self.exists():
            content = self.config_file.read_text()
            self.config = json.loads(content)

    def exists(self) -> bool:
        """Check if configuration file exists."""
        return self.config_file.is_file()

    def save(self) -> None:
        """Save configuration."""
        content = json.dumps(self.config, indent=4)
        self.config_file.write_text(content)

    def add_backend(
        self,
        backend_name: str,
        settings: dict,
    ) -> None:
        """Add backend settings to configuration file."""
        item = {"name": backend_name, "settings": settings}
        self.config["backends"].append(item)

        self.save()

    def remove_backend(self, backend_name: str) -> None:
        """Remove backend settings."""
        backend = self._get_backend(backend_name)

        if backend:
            self.config["backends"].remove(backend)

        self.save()

    def backend_exists(self, backend_name: str) -> bool:
        """Check if backend exists in configuration file."""
        return self._get_backend(backend_name) is not None

    def _get_backend(self, backend_name: str) -> dict | None:
        """Find backend settings by backend name."""
        find_backend = (
            item for item in self.config["backends"] if item["name"] == backend_name
        )

        return next(find_backend, None)

    def get_backend_settings(self, backend_name: str) -> dict | None:
        """Get backend settings."""
        backend = self._get_backend(backend_name)

        return backend["settings"] if backend else None


class BackendRepository:
    """Repository for keeping track of the installed backends.

    A backend is either copied into the repository (its files then live
    under ``<repository>/backends/<backend_dir>``) or only registered by
    the absolute path of its location outside the repository.
    """

    def __init__(
        self,
        repository: Path,
        config_filename: str = "mlia_config.json",
    ) -> None:
        """Init repository instance.

        The repository directory, its "backends" subdirectory and the
        configuration file are created if the directory does not exist.
        An existing directory without a configuration file is rejected.
        """
        self.repository = repository
        self.config_file = _ConfigFile(repository / config_filename)

        self._init_repo()

    def copy_backend(
        self,
        backend_name: str,
        backend_path: Path,
        backend_dir_name: str,
        settings: dict | None = None,
    ) -> None:
        """Copy backend files into repository and register the backend.

        The files from ``backend_path`` are copied to
        ``<repository>/backends/<backend_dir_name>``. The caller's
        ``settings`` dict is not modified.

        Raises Exception if a backend with this name is already installed
        or if the target directory already exists.
        """
        # Same guard as add_backend, so names stay unique in the config.
        if self.is_backend_installed(backend_name):
            raise Exception(f"Backend {backend_name} already installed.")

        repo_backend_path = self._get_backend_path(backend_dir_name)

        if repo_backend_path.exists():
            raise Exception(f"Unable to copy backend files for {backend_name}.")

        copy_all(backend_path, dest=repo_backend_path)

        # Work on a copy instead of mutating the caller's dict.
        backend_settings = dict(settings or {})
        backend_settings["backend_dir"] = backend_dir_name

        self.config_file.add_backend(backend_name, backend_settings)

    def add_backend(
        self,
        backend_name: str,
        backend_path: Path,
        settings: dict | None = None,
    ) -> None:
        """Register a backend located outside of the repository.

        The absolute path of ``backend_path`` is stored in the settings.
        The caller's ``settings`` dict is not modified.

        Raises Exception if a backend with this name is already installed.
        """
        if self.is_backend_installed(backend_name):
            raise Exception(f"Backend {backend_name} already installed.")

        # Work on a copy instead of mutating the caller's dict.
        backend_settings = dict(settings or {})
        backend_settings["backend_path"] = backend_path.absolute().as_posix()

        self.config_file.add_backend(backend_name, backend_settings)

    def remove_backend(self, backend_name: str) -> None:
        """Remove backend from repository.

        Backend files that were copied into the repository are deleted
        as well; files of backends registered by path are left alone.

        Raises Exception if the backend is not installed.
        """
        settings = self.config_file.get_backend_settings(backend_name)

        # A recorded backend may legitimately have empty settings.
        if settings is None:
            raise Exception(f"Backend {backend_name} is not installed.")

        # An empty directory name would resolve to the "backends" directory
        # itself, so only non-empty names are considered.
        if backend_dir := settings.get("backend_dir"):
            repo_backend_path = self._get_backend_path(backend_dir)
            # The files may already have been deleted manually; the stale
            # record must still be removable.
            if repo_backend_path.exists():
                shutil.rmtree(repo_backend_path)

        self.config_file.remove_backend(backend_name)

    def is_backend_installed(self, backend_name: str) -> bool:
        """Check if backend is installed."""
        return self.config_file.backend_exists(backend_name)

    def get_backend_settings(self, backend_name: str) -> tuple[Path, dict]:
        """Return backend location and settings.

        Raises Exception if the backend is not installed or its location
        cannot be resolved from the settings.
        """
        settings = self.config_file.get_backend_settings(backend_name)

        if settings is None:
            raise Exception(f"Backend {backend_name} is not installed.")

        if backend_dir := settings.get("backend_dir", None):
            return self._get_backend_path(backend_dir), settings

        if backend_path := settings.get("backend_path", None):
            return Path(backend_path), settings

        raise Exception(f"Unable to resolve path of the backend {backend_name}.")

    def _get_backend_path(self, backend_dir_name: str) -> Path:
        """Return path to a backend copied into the repository."""
        return self.repository.joinpath("backends", backend_dir_name)

    def _init_repo(self) -> None:
        """Init repository directory structure and configuration file."""
        if self.repository.exists():
            if not self.config_file.exists():
                raise Exception(
                    f"Directory {self.repository} could not be used as MLIA repository."
                )
        else:
            # Create missing parent directories too, so custom repository
            # locations work without preparing the path first.
            self.repository.mkdir(parents=True)
            self.repository.joinpath("backends").mkdir()

            self.config_file.save()


def get_backend_repository(
    repository: Path | None = None,
) -> BackendRepository:
    """Return backend repository.

    When ``repository`` is not given, ``~/.mlia`` is used. The home
    directory is resolved at call time rather than at import time, so
    importing this module never depends on the home directory being
    resolvable, and later changes to it are respected.
    """
    if repository is None:
        repository = Path.home() / ".mlia"

    return BackendRepository(repository)