From 7a661257b6adad0c8f53e32b42ced56a1e7d952f Mon Sep 17 00:00:00 2001 From: Benjamin Klimczak Date: Thu, 2 Feb 2023 14:02:05 +0000 Subject: MLIA-769 Expand use of target/backend registries - Use the target/backend registries to avoid hard-coded names. - Cache target profiles to avoid re-loading them Change-Id: I474b7c9ef23894e1d8a3ea06d13a37652054c62e --- src/mlia/target/registry.py | 68 ++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 67 insertions(+), 1 deletion(-) (limited to 'src/mlia/target/registry.py') diff --git a/src/mlia/target/registry.py b/src/mlia/target/registry.py index 4870fc8..9fccecb 100644 --- a/src/mlia/target/registry.py +++ b/src/mlia/target/registry.py @@ -3,17 +3,78 @@ """Target module.""" from __future__ import annotations +from functools import lru_cache +from pathlib import Path +from typing import cast + from mlia.backend.config import BackendType from mlia.backend.manager import get_installation_manager from mlia.backend.registry import registry as backend_registry from mlia.core.common import AdviceCategory from mlia.core.reporting import Column from mlia.core.reporting import Table +from mlia.target.config import BUILTIN_SUPPORTED_PROFILE_NAMES +from mlia.target.config import get_builtin_profile_path +from mlia.target.config import is_builtin_profile +from mlia.target.config import load_profile from mlia.target.config import TargetInfo +from mlia.target.config import TargetProfile from mlia.utils.registry import Registry + +class TargetRegistry(Registry[TargetInfo]): + """Registry for targets.""" + + def register(self, name: str, item: TargetInfo) -> bool: + """Register an item: returns `False` if already registered.""" + assert all( + backend in backend_registry.items for backend in item.supported_backends + ) + return super().register(name, item) + + # All supported targets are required to be registered here. -registry = Registry[TargetInfo]() +registry = TargetRegistry() + + +def builtin_profile_names() -> list[str]: + """Return a list of built-in profile names (not file paths).""" + return BUILTIN_SUPPORTED_PROFILE_NAMES + + +@lru_cache +def profile(target_profile: str | Path) -> TargetProfile: + """Get the target profile data (built-in or custom file).""" + if not target_profile: + raise ValueError("No valid target profile was provided.") + if is_builtin_profile(target_profile): + profile_file = get_builtin_profile_path(cast(str, target_profile)) + profile_ = create_target_profile(profile_file) + else: + profile_file = Path(target_profile) + if profile_file.is_file(): + profile_ = create_target_profile(profile_file) + else: + raise ValueError( + f"Profile '{target_profile}' is neither a valid built-in " + "target profile name or a valid file path." + ) + + return profile_ + + +def get_target(target_profile: str | Path) -> str: + """Return target for the provided target_profile.""" + return profile(target_profile).target + + +@lru_cache +def create_target_profile(path: Path) -> TargetProfile: + """Create a new instance of a TargetProfile from the file.""" + profile_data = load_profile(path) + target = profile_data["target"] + target_info = registry.items[target] + return target_info.target_profile_cls.load_json_data(profile_data) def supported_advice(target: str) -> list[AdviceCategory]: @@ -29,6 +90,11 @@ def supported_backends(target: str) -> list[str]: return registry.items[target].filter_supported_backends(check_system=False) +def default_backends(target: str) -> list[str]: + """Get a list of default backends for the given target.""" + return registry.items[target].default_backends + + def get_backend_to_supported_targets() -> dict[str, list]: """Get a dict that maps a list of supported targets given backend.""" targets = dict(registry.items) -- cgit v1.2.1