aboutsummaryrefslogtreecommitdiff
path: root/src/mlia/target/config.py
diff options
context:
space:
mode:
authorNathan Bailey <nathan.bailey@arm.com>2024-02-15 14:50:58 +0000
committerNathan Bailey <nathan.bailey@arm.com>2024-03-14 15:45:40 +0000
commit0b552d2ae47da4fb9c16d2a59d6ebe12c8307771 (patch)
tree09b40b939acbe0bcf02dcc77a7ed7ce4aba94322 /src/mlia/target/config.py
parent09b272be6e88d84a30cb89fb71f3fc3c64d20d2e (diff)
downloadmlia-0b552d2ae47da4fb9c16d2a59d6ebe12c8307771.tar.gz
feat: Enable rewrite parameterisation
Enables user to provide a toml or default profile to change training settings for rewrite optimization Resolves: MLIA-1004 Signed-off-by: Nathan Bailey <nathan.bailey@arm.com> Change-Id: I3bf9f44b9a2062fb71ef36eb32c9a69edcc48061
Diffstat (limited to 'src/mlia/target/config.py')
-rw-r--r--src/mlia/target/config.py28
1 files changed, 23 insertions, 5 deletions
diff --git a/src/mlia/target/config.py b/src/mlia/target/config.py
index 3bc74fa..8ccdad8 100644
--- a/src/mlia/target/config.py
+++ b/src/mlia/target/config.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates.
+# SPDX-FileCopyrightText: Copyright 2022-2024, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Target configuration module."""
from __future__ import annotations
@@ -22,9 +22,10 @@ from mlia.backend.registry import registry as backend_registry
from mlia.core.common import AdviceCategory
from mlia.core.advisor import InferenceAdvisor
from mlia.utils.filesystem import get_mlia_target_profiles_dir
+from mlia.utils.filesystem import get_mlia_target_optimization_dir
-def get_builtin_profile_path(target_profile: str) -> Path:
+def get_builtin_target_profile_path(target_profile: str) -> Path:
"""
Construct the path to the built-in target profile file.
@@ -33,6 +34,15 @@ def get_builtin_profile_path(target_profile: str) -> Path:
return get_mlia_target_profiles_dir() / f"{target_profile}.toml"
+def get_builtin_optimization_profile_path(optimization_profile: str) -> Path:
+ """
+ Construct the path to the built-in target profile file.
+
+ No checks are performed.
+ """
+ return get_mlia_target_optimization_dir() / f"{optimization_profile}.toml"
+
+
@lru_cache
def load_profile(path: str | Path) -> dict[str, Any]:
"""Get settings for the provided target profile."""
@@ -56,11 +66,19 @@ def get_builtin_supported_profile_names() -> list[str]:
BUILTIN_SUPPORTED_PROFILE_NAMES = get_builtin_supported_profile_names()
-def is_builtin_profile(profile_name: str | Path) -> bool:
+def is_builtin_target_profile(profile_name: str | Path) -> bool:
"""Check if the given profile name belongs to a built-in profile."""
return profile_name in BUILTIN_SUPPORTED_PROFILE_NAMES
+BUILTIN_SUPPORTED_OPTIMIZATION_NAMES = ["optimization"]
+
+
+def is_builtin_optimization_profile(optimization_name: str | Path) -> bool:
+ """Check if the given optimization name belongs to a built-in optimization."""
+ return optimization_name in BUILTIN_SUPPORTED_OPTIMIZATION_NAMES
+
+
T = TypeVar("T", bound="TargetProfile")
@@ -93,8 +111,8 @@ class TargetProfile(ABC):
@classmethod
def load_profile(cls: type[T], target_profile: str | Path) -> T:
"""Load a target profile from built-in target profile name or file path."""
- if is_builtin_profile(target_profile):
- profile_file = get_builtin_profile_path(cast(str, target_profile))
+ if is_builtin_target_profile(target_profile):
+ profile_file = get_builtin_target_profile_path(cast(str, target_profile))
else:
profile_file = Path(target_profile)
return cls.load(profile_file)