From a8ee1aee3e674c78a77801d1bf2256881ab6b4b9 Mon Sep 17 00:00:00 2001 From: Dmitrii Agibov Date: Thu, 21 Jul 2022 14:06:03 +0100 Subject: MLIA-549 Refactor API module to support several target profiles - Move target specific details out of API module - Move common logic for workflow event handler into a separate class Change-Id: Ic4a22657b722af1c1fead1d478f606ac57325788 --- src/mlia/utils/filesystem.py | 24 +++++++++++++++++------- 1 file changed, 17 insertions(+), 7 deletions(-) (limited to 'src/mlia/utils') diff --git a/src/mlia/utils/filesystem.py b/src/mlia/utils/filesystem.py index 7975905..0c28d35 100644 --- a/src/mlia/utils/filesystem.py +++ b/src/mlia/utils/filesystem.py @@ -11,6 +11,7 @@ from pathlib import Path from tempfile import mkstemp from tempfile import TemporaryDirectory from typing import Any +from typing import cast from typing import Dict from typing import Generator from typing import Iterable @@ -32,12 +33,12 @@ def get_vela_config() -> Path: def get_profiles_file() -> Path: - """Get the Ethos-U profiles file.""" + """Get the profiles file.""" return get_mlia_resources() / "profiles.json" def get_profiles_data() -> Dict[str, Dict[str, Any]]: - """Get the Ethos-U profile values as a dictionary.""" + """Get the profile values as a dictionary.""" with open(get_profiles_file(), encoding="utf-8") as json_file: profiles = json.load(json_file) @@ -47,14 +48,17 @@ def get_profiles_data() -> Dict[str, Dict[str, Any]]: return profiles -def get_profile(target: str) -> Dict[str, Any]: +def get_profile(target_profile: str) -> Dict[str, Any]: """Get settings for the provided target profile.""" - profiles = get_profiles_data() + if not target_profile: + raise Exception("Target profile is not provided") - if target not in profiles: - raise Exception(f"Unable to find target profile {target}") + profiles = get_profiles_data() - return profiles[target] + try: + return profiles[target_profile] + except KeyError as err: + raise Exception(f"Unable to find target profile {target_profile}") from err def get_supported_profile_names() -> List[str]: @@ -62,6 +66,12 @@ def get_supported_profile_names() -> List[str]: return list(get_profiles_data().keys()) +def get_target(target_profile: str) -> str: + """Return target for the provided target_profile.""" + profile_data = get_profile(target_profile) + return cast(str, profile_data["target"]) + + @contextmanager def temp_file(suffix: Optional[str] = None) -> Generator[Path, None, None]: """Create temp file and remove it after.""" -- cgit v1.2.1