diff options
Diffstat (limited to 'src/mlia/target/registry.py')
-rw-r--r-- | src/mlia/target/registry.py | 39 |
1 files changed, 34 insertions, 5 deletions
diff --git a/src/mlia/target/registry.py b/src/mlia/target/registry.py index b7b6193..b850284 100644 --- a/src/mlia/target/registry.py +++ b/src/mlia/target/registry.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates. +# SPDX-FileCopyrightText: Copyright 2022-2024, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """Target module.""" from __future__ import annotations @@ -13,9 +13,12 @@ from mlia.backend.registry import registry as backend_registry from mlia.core.common import AdviceCategory from mlia.core.reporting import Column from mlia.core.reporting import Table +from mlia.target.config import BUILTIN_SUPPORTED_OPTIMIZATION_NAMES from mlia.target.config import BUILTIN_SUPPORTED_PROFILE_NAMES -from mlia.target.config import get_builtin_profile_path -from mlia.target.config import is_builtin_profile +from mlia.target.config import get_builtin_optimization_profile_path +from mlia.target.config import get_builtin_target_profile_path +from mlia.target.config import is_builtin_optimization_profile +from mlia.target.config import is_builtin_target_profile from mlia.target.config import load_profile from mlia.target.config import TargetInfo from mlia.target.config import TargetProfile @@ -44,13 +47,18 @@ def builtin_profile_names() -> list[str]: return BUILTIN_SUPPORTED_PROFILE_NAMES +def builtin_optimization_names() -> list[str]: + """Return a list of built-in profile names (not file paths).""" + return BUILTIN_SUPPORTED_OPTIMIZATION_NAMES + + @lru_cache def profile(target_profile: str | Path) -> TargetProfile: """Get the target profile data (built-in or custom file).""" if not target_profile: raise ValueError("No valid target profile was provided.") - if is_builtin_profile(target_profile): - profile_file = get_builtin_profile_path(cast(str, target_profile)) + if is_builtin_target_profile(target_profile): + profile_file = get_builtin_target_profile_path(cast(str, target_profile)) profile_ = create_target_profile(profile_file) else: profile_file = Path(target_profile) @@ -65,6 +73,27 @@ def profile(target_profile: str | Path) -> TargetProfile: return profile_ +def get_optimization_profile(optimization_profile: str | Path) -> dict: + """Get the optimization profile data (built-in or custom file).""" + if not optimization_profile: + raise ValueError("No valid optimization profile was provided.") + if is_builtin_optimization_profile(optimization_profile): + profile_file = get_builtin_optimization_profile_path( + cast(str, optimization_profile) + ) + profile_dict = load_profile(profile_file) + else: + profile_file = Path(optimization_profile) + if profile_file.is_file(): + profile_dict = load_profile(profile_file) + else: + raise ValueError( + f"optimization Profile '{optimization_profile}' is neither a valid " + "built-in optimization profile name or a valid file path." + ) + return profile_dict + + def get_target(target_profile: str | Path) -> str: """Return target for the provided target_profile.""" return profile(target_profile).target |