aboutsummaryrefslogtreecommitdiff
path: root/src/mlia/target/config.py
blob: bf603dde8b5eab337ea6a7f4e027fd34e4f164e6 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Target configuration module."""
from __future__ import annotations

from abc import ABC
from abc import abstractmethod
from dataclasses import dataclass
from pathlib import Path
from shutil import copy
from typing import Any
from typing import cast
from typing import TypeVar

try:
    import tomllib
except ModuleNotFoundError:
    import tomli as tomllib  # type: ignore

from mlia.backend.registry import registry as backend_registry
from mlia.core.common import AdviceCategory
from mlia.utils.filesystem import get_mlia_target_profiles_dir


def get_profile_file(target_profile: str | Path) -> Path:
    """Get the target profile toml file.

    A file named ``<target_profile>.toml`` in the MLIA target profiles
    directory is preferred; if there is no such file, ``target_profile`` is
    interpreted as a path to a profile file.

    Args:
        target_profile: Name of a built-in profile or path to a profile file.

    Returns:
        Path to an existing regular file.

    Raises:
        Exception: If ``target_profile`` is empty or does not resolve to an
            existing file.
    """
    if not target_profile:
        raise Exception("Target profile is not provided.")

    profile_file = get_mlia_target_profiles_dir() / f"{target_profile}.toml"
    if not profile_file.is_file():
        profile_file = Path(target_profile)

    # is_file() rather than exists(): a directory (e.g. Path("") -> ".") must
    # be rejected here instead of failing later when the caller opens it.
    if not profile_file.is_file():
        raise Exception(f"File not found: {profile_file}.")
    return profile_file


def load_profile(path: str | Path) -> dict[str, Any]:
    """Parse the TOML file at ``path`` and return its contents as a dict.

    The file is opened in binary mode, as required by ``tomllib.load``.
    """
    with open(path, "rb") as toml_file:
        return cast(dict, tomllib.load(toml_file))


def get_builtin_supported_profile_names() -> list[str]:
    """List the names of the built-in target profiles, sorted alphabetically.

    A name is the stem of a regular ``.toml`` file found directly inside the
    MLIA target profiles directory.
    """
    profiles_dir = get_mlia_target_profiles_dir()
    names = [
        entry.stem
        for entry in profiles_dir.iterdir()
        if entry.is_file() and entry.suffix == ".toml"
    ]
    names.sort()
    return names


def get_target(target_profile: str | Path) -> str:
    """Resolve ``target_profile`` to its file and return its ``target`` entry."""
    profile_settings = load_profile(get_profile_file(target_profile))
    target_name = profile_settings["target"]
    return cast(str, target_name)


def copy_profile_file_to_output_dir(
    target_profile: str | Path, output_dir: str | Path
) -> bool:
    """Copy the target profile file to output directory.

    The profile is resolved with ``get_profile_file`` and copied into
    ``output_dir`` as ``<profile stem>.toml``.

    Args:
        target_profile: Name of a built-in profile or path to a profile file.
        output_dir: Directory that receives the copy.

    Returns:
        True once the file has been copied.

    Raises:
        RuntimeError: If copying fails with an ``OSError``.
    """
    profile_file_path = get_profile_file(target_profile)
    # Join with pathlib: string concatenation turned an empty output_dir
    # into "/<stem>.toml", i.e. a file at the filesystem root.
    output_file_path = Path(output_dir) / f"{profile_file_path.stem}.toml"
    try:
        copy(profile_file_path, output_file_path)
    except OSError as err:
        # strerror is None for some OSErrors (e.g. shutil.SameFileError), so
        # fall back to the exception text. Build one message string: passing
        # several positional args to RuntimeError renders str() as a tuple.
        reason = err.strerror or str(err)
        raise RuntimeError(f"Failed to copy profile file: {reason}") from err
    return True


T = TypeVar("T", bound="TargetProfile")


class TargetProfile(ABC):
    """Base class for target profiles."""

    def __init__(self, target: str) -> None:
        """Init TargetProfile instance with the target name."""
        self.target = target

    @classmethod
    def load(cls: type[T], path: str | Path) -> T:
        """Load and verify a target profile from file and return new instance.

        The parsed TOML content is passed to the constructor as keyword
        arguments, then ``verify()`` is called on the new instance.

        Raises:
            KeyError: If the constructor raises ``KeyError``.
            TypeError: If the constructor rejects the file's contents, e.g.
                a required key is missing or an unexpected key is present.
            ValueError: If ``verify()`` rejects the loaded values.
        """
        profile = load_profile(path)

        try:
            new_instance = cls(**profile)
        except KeyError as ex:
            raise KeyError(f"Missing key in file {path}.") from ex
        except TypeError as ex:
            # Missing/unexpected keyword arguments raise TypeError, not
            # KeyError; keep the exception type but name the offending file.
            raise TypeError(
                f"Failed to create target profile from file {path}: {ex}"
            ) from ex

        new_instance.verify()

        return new_instance

    @classmethod
    def load_profile(cls: type[T], target_profile: str | Path) -> T:
        """Load a target profile by name (or path to a profile file)."""
        profile_file = get_profile_file(target_profile)
        return cls.load(profile_file)

    def save(self, path: str | Path) -> None:
        """Save this target profile to a file."""
        raise NotImplementedError("Saving target profiles is currently not supported.")

    @abstractmethod
    def verify(self) -> None:
        """
        Check that all attributes contain valid values etc.

        Subclasses must override this; the base implementation checks the
        target name and can be reused via ``super().verify()``.

        Raises a ValueError, if an issue is detected.
        """
        if not self.target:
            raise ValueError(f"Invalid target name: {self.target}")


@dataclass
class TargetInfo:
    """Describe a target by the names of the backends it can use."""

    supported_backends: list[str]

    def __str__(self) -> str:
        """Return the backend names in alphabetical order, comma separated."""
        ordered_names = sorted(self.supported_backends)
        return ", ".join(ordered_names)

    @staticmethod
    def _backend_supports(
        backend_name: str, advice: AdviceCategory | None, check_system: bool
    ) -> bool:
        """Ask the registered backend ``backend_name`` whether it supports advice."""
        backend = backend_registry.items[backend_name]
        return backend.is_supported(advice, check_system)

    def is_supported(
        self, advice: AdviceCategory | None = None, check_system: bool = False
    ) -> bool:
        """Return True as soon as one supported backend accepts the advice."""
        for backend_name in self.supported_backends:
            if self._backend_supports(backend_name, advice, check_system):
                return True
        return False

    def filter_supported_backends(
        self, advice: AdviceCategory | None = None, check_system: bool = False
    ) -> list[str]:
        """Return, in original order, the backends that accept the advice."""
        accepted: list[str] = []
        for backend_name in self.supported_backends:
            if self._backend_supports(backend_name, advice, check_system):
                accepted.append(backend_name)
        return accepted