aboutsummaryrefslogtreecommitdiff
path: root/src/mlia/utils/filesystem.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/mlia/utils/filesystem.py')
-rw-r--r--src/mlia/utils/filesystem.py59
1 files changed, 36 insertions, 23 deletions
diff --git a/src/mlia/utils/filesystem.py b/src/mlia/utils/filesystem.py
index fcd09b5..4734a84 100644
--- a/src/mlia/utils/filesystem.py
+++ b/src/mlia/utils/filesystem.py
@@ -1,11 +1,10 @@
-# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Utils related to file management."""
from __future__ import annotations
import hashlib
import importlib.resources as pkg_resources
-import json
import os
import shutil
from contextlib import contextmanager
@@ -17,6 +16,11 @@ from typing import cast
from typing import Generator
from typing import Iterable
+try:
+ import tomllib
+except ModuleNotFoundError:
+ import tomli as tomllib # type: ignore
+
def get_mlia_resources() -> Path:
"""Get the path to the resources directory."""
@@ -30,44 +34,53 @@ def get_vela_config() -> Path:
return get_mlia_resources() / "vela/vela.ini"
-def get_profiles_file() -> Path:
+def get_mlia_target_profiles_dir() -> Path:
"""Get the profiles file."""
- return get_mlia_resources() / "profiles.json"
+ return get_mlia_resources() / "target_profiles"
-def get_profiles_data() -> dict[str, dict[str, Any]]:
- """Get the profile values as a dictionary."""
- with open(get_profiles_file(), encoding="utf-8") as json_file:
- profiles = json.load(json_file)
+def get_profile_toml_file(target_profile: str | Path) -> str | Path:
+ """Get the target profile toml file."""
+ if not target_profile:
+ raise Exception("Target profile is not provided")
- if not isinstance(profiles, dict):
- raise Exception("Profiles data format is not valid")
+ profile_toml_file = Path(get_mlia_target_profiles_dir() / f"{target_profile}.toml")
+ if not profile_toml_file.is_file():
+ profile_toml_file = Path(target_profile)
- return profiles
+ if not profile_toml_file.exists():
+ raise Exception(f"File not found: {profile_toml_file}.")
+ return profile_toml_file
-def get_profile(target_profile: str) -> dict[str, Any]:
+def get_profile(target_profile: str | Path) -> dict[str, Any]:
"""Get settings for the provided target profile."""
if not target_profile:
raise Exception("Target profile is not provided")
- profiles = get_profiles_data()
+ toml_file = get_profile_toml_file(target_profile)
- try:
- return profiles[target_profile]
- except KeyError as err:
- raise Exception(f"Unable to find target profile {target_profile}") from err
+ with open(toml_file, "rb") as file:
+ profile = tomllib.load(file)
+
+ return cast(dict, profile)
-def get_supported_profile_names() -> list[str]:
- """Get the supported Ethos-U profile names."""
- return list(get_profiles_data().keys())
+def get_builtin_supported_profile_names() -> list[str]:
+ """Return list of default profiles in the target profiles directory."""
+ return sorted(
+ [
+ item.stem
+ for item in get_mlia_target_profiles_dir().iterdir()
+ if item.is_file() and item.suffix == ".toml"
+ ]
+ )
-def get_target(target_profile: str) -> str:
+def get_target(target_profile: str | Path) -> str:
"""Return target for the provided target_profile."""
- profile_data = get_profile(target_profile)
- return cast(str, profile_data["target"])
+ profile = get_profile(target_profile)
+ return cast(str, profile["target"])
@contextmanager