diff options
author | Benjamin Klimczak <benjamin.klimczak@arm.com> | 2023-02-02 14:02:05 +0000 |
---|---|---|
committer | Benjamin Klimczak <benjamin.klimczak@arm.com> | 2023-02-10 13:45:18 +0000 |
commit | 7a661257b6adad0c8f53e32b42ced56a1e7d952f (patch) | |
tree | 938ad8578c5b9edc0573e810ce64ce0a5bda3d8c /src/mlia/target/config.py | |
parent | 50271dee0a84bfc481ce798184f07b5b0b4bc64d (diff) | |
download | mlia-7a661257b6adad0c8f53e32b42ced56a1e7d952f.tar.gz |
MLIA-769 Expand use of target/backend registries
- Use the target/backend registries to avoid hard-coded names.
- Cache target profiles to avoid re-loading them
Change-Id: I474b7c9ef23894e1d8a3ea06d13a37652054c62e
Diffstat (limited to 'src/mlia/target/config.py')
-rw-r--r-- | src/mlia/target/config.py | 71 |
1 files changed, 35 insertions, 36 deletions
diff --git a/src/mlia/target/config.py b/src/mlia/target/config.py index bf603dd..eb7ecff 100644 --- a/src/mlia/target/config.py +++ b/src/mlia/target/config.py @@ -6,9 +6,10 @@ from __future__ import annotations from abc import ABC from abc import abstractmethod from dataclasses import dataclass +from functools import lru_cache from pathlib import Path -from shutil import copy from typing import Any +from typing import Callable from typing import cast from typing import TypeVar @@ -19,23 +20,20 @@ except ModuleNotFoundError: from mlia.backend.registry import registry as backend_registry from mlia.core.common import AdviceCategory +from mlia.core.advisor import InferenceAdvisor from mlia.utils.filesystem import get_mlia_target_profiles_dir -def get_profile_file(target_profile: str | Path) -> Path: - """Get the target profile toml file.""" - if not target_profile: - raise Exception("Target profile is not provided.") +def get_builtin_profile_path(target_profile: str) -> Path: + """ + Construct the path to the built-in target profile file. - profile_file = Path(get_mlia_target_profiles_dir() / f"{target_profile}.toml") - if not profile_file.is_file(): - profile_file = Path(target_profile) - - if not profile_file.exists(): - raise Exception(f"File not found: {profile_file}.") - return profile_file + No checks are performed. + """ + return get_mlia_target_profiles_dir() / f"{target_profile}.toml" +@lru_cache def load_profile(path: str | Path) -> dict[str, Any]: """Get settings for the provided target profile.""" with open(path, "rb") as file: @@ -55,24 +53,12 @@ def get_builtin_supported_profile_names() -> list[str]: ) -def get_target(target_profile: str | Path) -> str: - """Return target for the provided target_profile.""" - profile_file = get_profile_file(target_profile) - profile = load_profile(profile_file) - return cast(str, profile["target"]) +BUILTIN_SUPPORTED_PROFILE_NAMES = get_builtin_supported_profile_names() -def copy_profile_file_to_output_dir( - target_profile: str | Path, output_dir: str | Path -) -> bool: - """Copy the target profile file to output directory.""" - profile_file_path = get_profile_file(target_profile) - output_file_path = f"{output_dir}/{profile_file_path.stem}.toml" - try: - copy(profile_file_path, output_file_path) - return True - except OSError as err: - raise RuntimeError("Failed to copy profile file:", err.strerror) from err +def is_builtin_profile(profile_name: str | Path) -> bool: + """Check if the given profile name belongs to a built-in profile.""" + return profile_name in BUILTIN_SUPPORTED_PROFILE_NAMES T = TypeVar("T", bound="TargetProfile") @@ -88,21 +74,29 @@ class TargetProfile(ABC): @classmethod def load(cls: type[T], path: str | Path) -> T: """Load and verify a target profile from file and return new instance.""" - profile = load_profile(path) + profile_data = load_profile(path) try: - new_instance = cls(**profile) + new_instance = cls.load_json_data(profile_data) except KeyError as ex: raise KeyError(f"Missing key in file {path}.") from ex - new_instance.verify() + return new_instance + @classmethod + def load_json_data(cls: type[T], profile_data: dict) -> T: + """Load a target profile from the JSON data.""" + new_instance = cls(**profile_data) + new_instance.verify() return new_instance @classmethod - def load_profile(cls: type[T], target_profile: str) -> T: - """Load a target profile by name.""" - profile_file = get_profile_file(target_profile) + def load_profile(cls: type[T], target_profile: str | Path) -> T: + """Load a target profile from built-in target profile name or file path.""" + if is_builtin_profile(target_profile): + profile_file = get_builtin_profile_path(cast(str, target_profile)) + else: + profile_file = Path(target_profile) return cls.load(profile_file) def save(self, path: str | Path) -> None: @@ -125,6 +119,9 @@ class TargetInfo: """Collect information about supported targets.""" supported_backends: list[str] + default_backends: list[str] + advisor_factory_func: Callable[..., InferenceAdvisor] + target_profile_cls: type[TargetProfile] def __str__(self) -> str: """List supported backends.""" @@ -135,7 +132,8 @@ class TargetInfo: ) -> bool: """Check if any of the supported backends support this kind of advice.""" return any( - backend_registry.items[name].is_supported(advice, check_system) + name in backend_registry.items + and backend_registry.items[name].is_supported(advice, check_system) for name in self.supported_backends ) @@ -146,5 +144,6 @@ class TargetInfo: return [ name for name in self.supported_backends - if backend_registry.items[name].is_supported(advice, check_system) + if name in backend_registry.items + and backend_registry.items[name].is_supported(advice, check_system) ] |