aboutsummaryrefslogtreecommitdiff
path: root/src/mlia/target/registry.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/mlia/target/registry.py')
-rw-r--r--src/mlia/target/registry.py68
1 files changed, 67 insertions, 1 deletions
diff --git a/src/mlia/target/registry.py b/src/mlia/target/registry.py
index 4870fc8..9fccecb 100644
--- a/src/mlia/target/registry.py
+++ b/src/mlia/target/registry.py
@@ -3,17 +3,78 @@
"""Target module."""
from __future__ import annotations
+from functools import lru_cache
+from pathlib import Path
+from typing import cast
+
from mlia.backend.config import BackendType
from mlia.backend.manager import get_installation_manager
from mlia.backend.registry import registry as backend_registry
from mlia.core.common import AdviceCategory
from mlia.core.reporting import Column
from mlia.core.reporting import Table
+from mlia.target.config import BUILTIN_SUPPORTED_PROFILE_NAMES
+from mlia.target.config import get_builtin_profile_path
+from mlia.target.config import is_builtin_profile
+from mlia.target.config import load_profile
from mlia.target.config import TargetInfo
+from mlia.target.config import TargetProfile
from mlia.utils.registry import Registry
+
+class TargetRegistry(Registry[TargetInfo]):
+ """Registry for targets."""
+
+ def register(self, name: str, item: TargetInfo) -> bool:
+ """Register an item: returns `False` if already registered."""
+ assert all(
+ backend in backend_registry.items for backend in item.supported_backends
+ )
+ return super().register(name, item)
+
+
# All supported targets are required to be registered here.
-registry = Registry[TargetInfo]()
+registry = TargetRegistry()
+
+
+def builtin_profile_names() -> list[str]:
+ """Return a list of built-in profile names (not file paths)."""
+ return BUILTIN_SUPPORTED_PROFILE_NAMES
+
+
+@lru_cache
+def profile(target_profile: str | Path) -> TargetProfile:
+ """Get the target profile data (built-in or custom file)."""
+ if not target_profile:
+ raise ValueError("No valid target profile was provided.")
+ if is_builtin_profile(target_profile):
+ profile_file = get_builtin_profile_path(cast(str, target_profile))
+ profile_ = create_target_profile(profile_file)
+ else:
+ profile_file = Path(target_profile)
+ if profile_file.is_file():
+ profile_ = create_target_profile(profile_file)
+ else:
+ raise ValueError(
+ f"Profile '{target_profile}' is neither a valid built-in "
+ "target profile name or a valid file path."
+ )
+
+ return profile_
+
+
+def get_target(target_profile: str | Path) -> str:
+ """Return target for the provided target_profile."""
+ return profile(target_profile).target
+
+
+@lru_cache
+def create_target_profile(path: Path) -> TargetProfile:
+ """Create a new instance of a TargetProfile from the file."""
+ profile_data = load_profile(path)
+ target = profile_data["target"]
+ target_info = registry.items[target]
+ return target_info.target_profile_cls.load_json_data(profile_data)
def supported_advice(target: str) -> list[AdviceCategory]:
@@ -29,6 +90,11 @@ def supported_backends(target: str) -> list[str]:
return registry.items[target].filter_supported_backends(check_system=False)
+def default_backends(target: str) -> list[str]:
+ """Get a list of default backends for the given target."""
+ return registry.items[target].default_backends
+
+
def get_backend_to_supported_targets() -> dict[str, list]:
"""Get a dict that maps a list of supported targets given backend."""
targets = dict(registry.items)