From 09ecc5c8acb758e8def33155feb746a34dd7b560 Mon Sep 17 00:00:00 2001 From: Annie Tallund Date: Wed, 14 Dec 2022 15:55:19 +0100 Subject: MLIA-590 Support path to custom target profiles - Start using TOML format for target profile - Add support for loading custom target profile files Change-Id: I6be019d4341e93115440ccdbdb6dafdc1c85b966 --- src/mlia/utils/filesystem.py | 59 +++++++++++++++++++++++++++----------------- 1 file changed, 36 insertions(+), 23 deletions(-) (limited to 'src/mlia/utils/filesystem.py') diff --git a/src/mlia/utils/filesystem.py b/src/mlia/utils/filesystem.py index fcd09b5..4734a84 100644 --- a/src/mlia/utils/filesystem.py +++ b/src/mlia/utils/filesystem.py @@ -1,11 +1,10 @@ -# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """Utils related to file management.""" from __future__ import annotations import hashlib import importlib.resources as pkg_resources -import json import os import shutil from contextlib import contextmanager @@ -17,6 +16,11 @@ from typing import cast from typing import Generator from typing import Iterable +try: + import tomllib +except ModuleNotFoundError: + import tomli as tomllib # type: ignore + def get_mlia_resources() -> Path: """Get the path to the resources directory.""" @@ -30,44 +34,53 @@ def get_vela_config() -> Path: return get_mlia_resources() / "vela/vela.ini" -def get_profiles_file() -> Path: +def get_mlia_target_profiles_dir() -> Path: """Get the profiles file.""" - return get_mlia_resources() / "profiles.json" + return get_mlia_resources() / "target_profiles" -def get_profiles_data() -> dict[str, dict[str, Any]]: - """Get the profile values as a dictionary.""" - with open(get_profiles_file(), encoding="utf-8") as json_file: - profiles = json.load(json_file) +def get_profile_toml_file(target_profile: str | Path) -> str | Path: + """Get the target profile toml file.""" + if not target_profile: + raise Exception("Target profile is not provided") - if not isinstance(profiles, dict): - raise Exception("Profiles data format is not valid") + profile_toml_file = Path(get_mlia_target_profiles_dir() / f"{target_profile}.toml") + if not profile_toml_file.is_file(): + profile_toml_file = Path(target_profile) - return profiles + if not profile_toml_file.exists(): + raise Exception(f"File not found: {profile_toml_file}.") + return profile_toml_file -def get_profile(target_profile: str) -> dict[str, Any]: +def get_profile(target_profile: str | Path) -> dict[str, Any]: """Get settings for the provided target profile.""" if not target_profile: raise Exception("Target profile is not provided") - profiles = get_profiles_data() + toml_file = get_profile_toml_file(target_profile) - try: - return profiles[target_profile] - except KeyError as err: - raise Exception(f"Unable to find target profile {target_profile}") from err + with open(toml_file, "rb") as file: + profile = tomllib.load(file) + + return cast(dict, profile) -def get_supported_profile_names() -> list[str]: - """Get the supported Ethos-U profile names.""" - return list(get_profiles_data().keys()) +def get_builtin_supported_profile_names() -> list[str]: + """Return list of default profiles in the target profiles directory.""" + return sorted( + [ + item.stem + for item in get_mlia_target_profiles_dir().iterdir() + if item.is_file() and item.suffix == ".toml" + ] + ) -def get_target(target_profile: str) -> str: +def get_target(target_profile: str | Path) -> str: """Return target for the provided target_profile.""" - profile_data = get_profile(target_profile) - return cast(str, profile_data["target"]) + profile = get_profile(target_profile) + return cast(str, profile["target"]) @contextmanager -- cgit v1.2.1