aboutsummaryrefslogtreecommitdiff
path: root/src/mlia/target/config.py
blob: 8ccdad897a474da616453be45a7eb2982f5fe3b0 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
# SPDX-FileCopyrightText: Copyright 2022-2024, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Target configuration module."""
from __future__ import annotations

from abc import ABC
from abc import abstractmethod
from dataclasses import dataclass
from functools import lru_cache
from pathlib import Path
from typing import Any
from typing import Callable
from typing import cast
from typing import TypeVar

try:
    import tomllib
except ModuleNotFoundError:
    import tomli as tomllib

from mlia.backend.registry import registry as backend_registry
from mlia.core.common import AdviceCategory
from mlia.core.advisor import InferenceAdvisor
from mlia.utils.filesystem import get_mlia_target_profiles_dir
from mlia.utils.filesystem import get_mlia_target_optimization_dir


def get_builtin_target_profile_path(target_profile: str) -> Path:
    """
    Return the location of a bundled target profile TOML file.

    The path is ``<target_profile>.toml`` inside the directory returned by
    ``get_mlia_target_profiles_dir()``; the file is not checked for existence.
    """
    profiles_dir = get_mlia_target_profiles_dir()
    file_name = f"{target_profile}.toml"
    return profiles_dir / file_name


def get_builtin_optimization_profile_path(optimization_profile: str) -> Path:
    """
    Construct the path to the built-in optimization profile file.

    The path is ``<optimization_profile>.toml`` inside the directory returned
    by ``get_mlia_target_optimization_dir()``. No checks are performed, so the
    returned path may not exist.
    """
    return get_mlia_target_optimization_dir() / f"{optimization_profile}.toml"


@lru_cache
def load_profile(path: str | Path) -> dict[str, Any]:
    """
    Parse the TOML profile at ``path`` and return its settings.

    Results are memoized per ``path`` argument by ``lru_cache``. Repeated calls
    with an equal argument return the very same dictionary object, so callers
    should treat the result as read-only.
    """
    with open(path, "rb") as toml_file:
        return cast(dict, tomllib.load(toml_file))


def get_builtin_supported_profile_names() -> list[str]:
    """
    List the names of the built-in target profiles.

    Every regular file with a ``.toml`` suffix in the target profiles directory
    contributes its file name without extension; the names are returned sorted.
    """
    profiles_dir = get_mlia_target_profiles_dir()
    toml_files = (
        entry
        for entry in profiles_dir.iterdir()
        if entry.is_file() and entry.suffix == ".toml"
    )
    return sorted(entry.stem for entry in toml_files)


# Names of the built-in target profiles, discovered once when this module is imported.
BUILTIN_SUPPORTED_PROFILE_NAMES: list[str] = get_builtin_supported_profile_names()


def is_builtin_target_profile(profile_name: str | Path) -> bool:
    """
    Tell whether ``profile_name`` is one of the built-in target profile names.

    The built-in names are plain strings, so a ``Path`` argument never matches.
    """
    known_names = BUILTIN_SUPPORTED_PROFILE_NAMES
    return profile_name in known_names


# Names accepted as built-in optimization profiles (hard-coded, not read from disk).
BUILTIN_SUPPORTED_OPTIMIZATION_NAMES: list[str] = ["optimization"]


def is_builtin_optimization_profile(optimization_name: str | Path) -> bool:
    """
    Tell whether ``optimization_name`` is one of the built-in optimization names.

    The built-in names are plain strings, so a ``Path`` argument never matches.
    """
    builtin_names = BUILTIN_SUPPORTED_OPTIMIZATION_NAMES
    return optimization_name in builtin_names


# Type variable bound to TargetProfile so the classmethods below are typed as
# returning an instance of the class they are called on.
T = TypeVar("T", bound="TargetProfile")


class TargetProfile(ABC):
    """
    Abstract base for target profiles.

    A profile holds at least a ``target`` name. Instances can be built from a
    parsed settings dictionary, a TOML file path, or a built-in profile name.
    Concrete subclasses must implement :meth:`verify`.
    """

    def __init__(self, target: str) -> None:
        """Store the name of the target this profile describes."""
        self.target = target

    @classmethod
    def load(cls: type[T], path: str | Path) -> T:
        """
        Read the profile file at ``path`` and build a verified instance.

        Raises:
            KeyError: if building the instance raises a ``KeyError``; the
                original error is chained and the file path is reported.
        """
        raw_settings = load_profile(path)
        try:
            return cls.load_json_data(raw_settings)
        except KeyError as missing:
            raise KeyError(f"Missing key in file {path}.") from missing

    @classmethod
    def load_json_data(cls: type[T], profile_data: dict) -> T:
        """
        Build an instance from already-parsed profile settings and verify it.

        Every key of ``profile_data`` is passed to the constructor as a keyword
        argument, then :meth:`verify` is called on the result.
        """
        profile = cls(**profile_data)
        profile.verify()
        return profile

    @classmethod
    def load_profile(cls: type[T], target_profile: str | Path) -> T:
        """
        Load a profile given either a built-in profile name or a file path.

        Names found among the built-in profiles are mapped to their bundled
        TOML file; anything else is treated as a path on disk.
        """
        builtin = is_builtin_target_profile(target_profile)
        profile_file = (
            get_builtin_target_profile_path(cast(str, target_profile))
            if builtin
            else Path(target_profile)
        )
        return cls.load(profile_file)

    def save(self, path: str | Path) -> None:
        """Persist the profile to ``path`` (not implemented; always raises)."""
        raise NotImplementedError("Saving target profiles is currently not supported.")

    @abstractmethod
    def verify(self) -> None:
        """
        Validate the profile's attributes.

        The base implementation only rejects a falsy ``target`` name. Subclasses
        must override this method and may call it via ``super()``.

        Raises:
            ValueError: if a problem is found.
        """
        if self.target:
            return
        raise ValueError(f"Invalid target name: {self.target}")


@dataclass
class TargetInfo:
    """Describe a target: its backends, advisor factory and profile class."""

    supported_backends: list[str]
    default_backends: list[str]
    advisor_factory_func: Callable[..., InferenceAdvisor]
    target_profile_cls: type[TargetProfile]

    def __str__(self) -> str:
        """Return the supported backend names, sorted and comma-separated."""
        ordered = sorted(self.supported_backends)
        return ", ".join(ordered)

    @staticmethod
    def _backend_supports(
        name: str, advice: AdviceCategory | None, check_system: bool
    ) -> bool:
        """Check that ``name`` is registered and that backend supports ``advice``."""
        backends = backend_registry.items
        return name in backends and backends[name].is_supported(advice, check_system)

    def is_supported(
        self, advice: AdviceCategory | None = None, check_system: bool = False
    ) -> bool:
        """
        Report whether at least one supported backend handles this advice.

        Stops at the first backend that qualifies.
        """
        return any(
            self._backend_supports(backend, advice, check_system)
            for backend in self.supported_backends
        )

    def filter_supported_backends(
        self, advice: AdviceCategory | None = None, check_system: bool = False
    ) -> list[str]:
        """Return, in original order, the supported backends that handle this advice."""
        return [
            backend
            for backend in self.supported_backends
            if self._backend_supports(backend, advice, check_system)
        ]