aboutsummaryrefslogtreecommitdiff
path: root/src/mlia/target/config.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/mlia/target/config.py')
-rw-r--r--src/mlia/target/config.py28
1 files changed, 23 insertions, 5 deletions
diff --git a/src/mlia/target/config.py b/src/mlia/target/config.py
index 3bc74fa..8ccdad8 100644
--- a/src/mlia/target/config.py
+++ b/src/mlia/target/config.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates.
+# SPDX-FileCopyrightText: Copyright 2022-2024, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Target configuration module."""
from __future__ import annotations
@@ -22,9 +22,10 @@ from mlia.backend.registry import registry as backend_registry
from mlia.core.common import AdviceCategory
from mlia.core.advisor import InferenceAdvisor
from mlia.utils.filesystem import get_mlia_target_profiles_dir
+from mlia.utils.filesystem import get_mlia_target_optimization_dir
-def get_builtin_profile_path(target_profile: str) -> Path:
+def get_builtin_target_profile_path(target_profile: str) -> Path:
"""
Construct the path to the built-in target profile file.
@@ -33,6 +34,15 @@ def get_builtin_profile_path(target_profile: str) -> Path:
return get_mlia_target_profiles_dir() / f"{target_profile}.toml"
+def get_builtin_optimization_profile_path(optimization_profile: str) -> Path:
+ """
+ Construct the path to the built-in target profile file.
+
+ No checks are performed.
+ """
+ return get_mlia_target_optimization_dir() / f"{optimization_profile}.toml"
+
+
@lru_cache
def load_profile(path: str | Path) -> dict[str, Any]:
"""Get settings for the provided target profile."""
@@ -56,11 +66,19 @@ def get_builtin_supported_profile_names() -> list[str]:
BUILTIN_SUPPORTED_PROFILE_NAMES = get_builtin_supported_profile_names()
-def is_builtin_profile(profile_name: str | Path) -> bool:
+def is_builtin_target_profile(profile_name: str | Path) -> bool:
"""Check if the given profile name belongs to a built-in profile."""
return profile_name in BUILTIN_SUPPORTED_PROFILE_NAMES
+BUILTIN_SUPPORTED_OPTIMIZATION_NAMES = ["optimization"]
+
+
+def is_builtin_optimization_profile(optimization_name: str | Path) -> bool:
+ """Check if the given optimization name belongs to a built-in optimization."""
+ return optimization_name in BUILTIN_SUPPORTED_OPTIMIZATION_NAMES
+
+
T = TypeVar("T", bound="TargetProfile")
@@ -93,8 +111,8 @@ class TargetProfile(ABC):
@classmethod
def load_profile(cls: type[T], target_profile: str | Path) -> T:
"""Load a target profile from built-in target profile name or file path."""
- if is_builtin_profile(target_profile):
- profile_file = get_builtin_profile_path(cast(str, target_profile))
+ if is_builtin_target_profile(target_profile):
+ profile_file = get_builtin_target_profile_path(cast(str, target_profile))
else:
profile_file = Path(target_profile)
return cls.load(profile_file)