aboutsummaryrefslogtreecommitdiff
path: root/src/mlia/target/registry.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/mlia/target/registry.py')
-rw-r--r--src/mlia/target/registry.py39
1 files changed, 34 insertions, 5 deletions
diff --git a/src/mlia/target/registry.py b/src/mlia/target/registry.py
index b7b6193..b850284 100644
--- a/src/mlia/target/registry.py
+++ b/src/mlia/target/registry.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates.
+# SPDX-FileCopyrightText: Copyright 2022-2024, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Target module."""
from __future__ import annotations
@@ -13,9 +13,12 @@ from mlia.backend.registry import registry as backend_registry
from mlia.core.common import AdviceCategory
from mlia.core.reporting import Column
from mlia.core.reporting import Table
+from mlia.target.config import BUILTIN_SUPPORTED_OPTIMIZATION_NAMES
from mlia.target.config import BUILTIN_SUPPORTED_PROFILE_NAMES
-from mlia.target.config import get_builtin_profile_path
-from mlia.target.config import is_builtin_profile
+from mlia.target.config import get_builtin_optimization_profile_path
+from mlia.target.config import get_builtin_target_profile_path
+from mlia.target.config import is_builtin_optimization_profile
+from mlia.target.config import is_builtin_target_profile
from mlia.target.config import load_profile
from mlia.target.config import TargetInfo
from mlia.target.config import TargetProfile
@@ -44,13 +47,18 @@ def builtin_profile_names() -> list[str]:
return BUILTIN_SUPPORTED_PROFILE_NAMES
+def builtin_optimization_names() -> list[str]:
+ """Return a list of built-in profile names (not file paths)."""
+ return BUILTIN_SUPPORTED_OPTIMIZATION_NAMES
+
+
@lru_cache
def profile(target_profile: str | Path) -> TargetProfile:
"""Get the target profile data (built-in or custom file)."""
if not target_profile:
raise ValueError("No valid target profile was provided.")
- if is_builtin_profile(target_profile):
- profile_file = get_builtin_profile_path(cast(str, target_profile))
+ if is_builtin_target_profile(target_profile):
+ profile_file = get_builtin_target_profile_path(cast(str, target_profile))
profile_ = create_target_profile(profile_file)
else:
profile_file = Path(target_profile)
@@ -65,6 +73,27 @@ def profile(target_profile: str | Path) -> TargetProfile:
return profile_
+def get_optimization_profile(optimization_profile: str | Path) -> dict:
+ """Get the optimization profile data (built-in or custom file)."""
+ if not optimization_profile:
+ raise ValueError("No valid optimization profile was provided.")
+ if is_builtin_optimization_profile(optimization_profile):
+ profile_file = get_builtin_optimization_profile_path(
+ cast(str, optimization_profile)
+ )
+ profile_dict = load_profile(profile_file)
+ else:
+ profile_file = Path(optimization_profile)
+ if profile_file.is_file():
+ profile_dict = load_profile(profile_file)
+ else:
+ raise ValueError(
+ f"optimization Profile '{optimization_profile}' is neither a valid "
+ "built-in optimization profile name or a valid file path."
+ )
+ return profile_dict
+
+
def get_target(target_profile: str | Path) -> str:
"""Return target for the provided target_profile."""
return profile(target_profile).target