aboutsummaryrefslogtreecommitdiff
path: root/src/mlia/target/config.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/mlia/target/config.py')
-rw-r--r--src/mlia/target/config.py71
1 files changed, 35 insertions, 36 deletions
diff --git a/src/mlia/target/config.py b/src/mlia/target/config.py
index bf603dd..eb7ecff 100644
--- a/src/mlia/target/config.py
+++ b/src/mlia/target/config.py
@@ -6,9 +6,10 @@ from __future__ import annotations
from abc import ABC
from abc import abstractmethod
from dataclasses import dataclass
+from functools import lru_cache
from pathlib import Path
-from shutil import copy
from typing import Any
+from typing import Callable
from typing import cast
from typing import TypeVar
@@ -19,23 +20,20 @@ except ModuleNotFoundError:
from mlia.backend.registry import registry as backend_registry
from mlia.core.common import AdviceCategory
+from mlia.core.advisor import InferenceAdvisor
from mlia.utils.filesystem import get_mlia_target_profiles_dir
-def get_profile_file(target_profile: str | Path) -> Path:
- """Get the target profile toml file."""
- if not target_profile:
- raise Exception("Target profile is not provided.")
+def get_builtin_profile_path(target_profile: str) -> Path:
+ """
+ Construct the path to the built-in target profile file.
- profile_file = Path(get_mlia_target_profiles_dir() / f"{target_profile}.toml")
- if not profile_file.is_file():
- profile_file = Path(target_profile)
-
- if not profile_file.exists():
- raise Exception(f"File not found: {profile_file}.")
- return profile_file
+ No checks are performed.
+ """
+ return get_mlia_target_profiles_dir() / f"{target_profile}.toml"
+@lru_cache
def load_profile(path: str | Path) -> dict[str, Any]:
"""Get settings for the provided target profile."""
with open(path, "rb") as file:
@@ -55,24 +53,12 @@ def get_builtin_supported_profile_names() -> list[str]:
)
-def get_target(target_profile: str | Path) -> str:
- """Return target for the provided target_profile."""
- profile_file = get_profile_file(target_profile)
- profile = load_profile(profile_file)
- return cast(str, profile["target"])
+BUILTIN_SUPPORTED_PROFILE_NAMES = get_builtin_supported_profile_names()
-def copy_profile_file_to_output_dir(
- target_profile: str | Path, output_dir: str | Path
-) -> bool:
- """Copy the target profile file to output directory."""
- profile_file_path = get_profile_file(target_profile)
- output_file_path = f"{output_dir}/{profile_file_path.stem}.toml"
- try:
- copy(profile_file_path, output_file_path)
- return True
- except OSError as err:
- raise RuntimeError("Failed to copy profile file:", err.strerror) from err
+def is_builtin_profile(profile_name: str | Path) -> bool:
+ """Check if the given profile name belongs to a built-in profile."""
+ return profile_name in BUILTIN_SUPPORTED_PROFILE_NAMES
T = TypeVar("T", bound="TargetProfile")
@@ -88,21 +74,29 @@ class TargetProfile(ABC):
@classmethod
def load(cls: type[T], path: str | Path) -> T:
"""Load and verify a target profile from file and return new instance."""
- profile = load_profile(path)
+ profile_data = load_profile(path)
try:
- new_instance = cls(**profile)
+ new_instance = cls.load_json_data(profile_data)
except KeyError as ex:
raise KeyError(f"Missing key in file {path}.") from ex
- new_instance.verify()
+ return new_instance
+ @classmethod
+ def load_json_data(cls: type[T], profile_data: dict) -> T:
+ """Load a target profile from the JSON data."""
+ new_instance = cls(**profile_data)
+ new_instance.verify()
return new_instance
@classmethod
- def load_profile(cls: type[T], target_profile: str) -> T:
- """Load a target profile by name."""
- profile_file = get_profile_file(target_profile)
+ def load_profile(cls: type[T], target_profile: str | Path) -> T:
+ """Load a target profile from built-in target profile name or file path."""
+ if is_builtin_profile(target_profile):
+ profile_file = get_builtin_profile_path(cast(str, target_profile))
+ else:
+ profile_file = Path(target_profile)
return cls.load(profile_file)
def save(self, path: str | Path) -> None:
@@ -125,6 +119,9 @@ class TargetInfo:
"""Collect information about supported targets."""
supported_backends: list[str]
+ default_backends: list[str]
+ advisor_factory_func: Callable[..., InferenceAdvisor]
+ target_profile_cls: type[TargetProfile]
def __str__(self) -> str:
"""List supported backends."""
@@ -135,7 +132,8 @@ class TargetInfo:
) -> bool:
"""Check if any of the supported backends support this kind of advice."""
return any(
- backend_registry.items[name].is_supported(advice, check_system)
+ name in backend_registry.items
+ and backend_registry.items[name].is_supported(advice, check_system)
for name in self.supported_backends
)
@@ -146,5 +144,6 @@ class TargetInfo:
return [
name
for name in self.supported_backends
- if backend_registry.items[name].is_supported(advice, check_system)
+ if name in backend_registry.items
+ and backend_registry.items[name].is_supported(advice, check_system)
]