diff options
Diffstat (limited to 'src/mlia/target/config.py')
-rw-r--r-- | src/mlia/target/config.py | 28 |
1 files changed, 23 insertions, 5 deletions
diff --git a/src/mlia/target/config.py b/src/mlia/target/config.py index 3bc74fa..8ccdad8 100644 --- a/src/mlia/target/config.py +++ b/src/mlia/target/config.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates. +# SPDX-FileCopyrightText: Copyright 2022-2024, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """Target configuration module.""" from __future__ import annotations @@ -22,9 +22,10 @@ from mlia.backend.registry import registry as backend_registry from mlia.core.common import AdviceCategory from mlia.core.advisor import InferenceAdvisor from mlia.utils.filesystem import get_mlia_target_profiles_dir +from mlia.utils.filesystem import get_mlia_target_optimization_dir -def get_builtin_profile_path(target_profile: str) -> Path: +def get_builtin_target_profile_path(target_profile: str) -> Path: """ Construct the path to the built-in target profile file. @@ -33,6 +34,15 @@ def get_builtin_profile_path(target_profile: str) -> Path: return get_mlia_target_profiles_dir() / f"{target_profile}.toml" +def get_builtin_optimization_profile_path(optimization_profile: str) -> Path: + """ + Construct the path to the built-in target profile file. + + No checks are performed. + """ + return get_mlia_target_optimization_dir() / f"{optimization_profile}.toml" + + @lru_cache def load_profile(path: str | Path) -> dict[str, Any]: """Get settings for the provided target profile.""" @@ -56,11 +66,19 @@ def get_builtin_supported_profile_names() -> list[str]: BUILTIN_SUPPORTED_PROFILE_NAMES = get_builtin_supported_profile_names() -def is_builtin_profile(profile_name: str | Path) -> bool: +def is_builtin_target_profile(profile_name: str | Path) -> bool: """Check if the given profile name belongs to a built-in profile.""" return profile_name in BUILTIN_SUPPORTED_PROFILE_NAMES +BUILTIN_SUPPORTED_OPTIMIZATION_NAMES = ["optimization"] + + +def is_builtin_optimization_profile(optimization_name: str | Path) -> bool: + """Check if the given optimization name belongs to a built-in optimization.""" + return optimization_name in BUILTIN_SUPPORTED_OPTIMIZATION_NAMES + + T = TypeVar("T", bound="TargetProfile") @@ -93,8 +111,8 @@ class TargetProfile(ABC): @classmethod def load_profile(cls: type[T], target_profile: str | Path) -> T: """Load a target profile from built-in target profile name or file path.""" - if is_builtin_profile(target_profile): - profile_file = get_builtin_profile_path(cast(str, target_profile)) + if is_builtin_target_profile(target_profile): + profile_file = get_builtin_target_profile_path(cast(str, target_profile)) else: profile_file = Path(target_profile) return cls.load(profile_file) |