diff options
author | Benjamin Klimczak <benjamin.klimczak@arm.com> | 2023-01-11 12:32:02 +0000 |
---|---|---|
committer | Benjamin Klimczak <benjamin.klimczak@arm.com> | 2023-02-08 15:23:29 +0000 |
commit | a4fb8c72f15146c95df16c25e75f03344e9814fd (patch) | |
tree | ce6d9cf39951a0c85d2773d436cc5010ecf78a8f /src/mlia/target/config.py | |
parent | 09ecc5c8acb758e8def33155feb746a34dd7b560 (diff) | |
download | mlia-a4fb8c72f15146c95df16c25e75f03344e9814fd.tar.gz |
MLIA-591 Create interface for target profiles
New class 'TargetProfile' is used to load and verify target profiles.
Change-Id: I76373a923e2e5f55c4e95860635afe9fc5627a5d
Diffstat (limited to 'src/mlia/target/config.py')
-rw-r--r-- | src/mlia/target/config.py | 99 |
1 files changed, 94 insertions, 5 deletions
diff --git a/src/mlia/target/config.py b/src/mlia/target/config.py index f257784..ec3fb4c 100644 --- a/src/mlia/target/config.py +++ b/src/mlia/target/config.py @@ -1,21 +1,110 @@ -# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 -"""IP configuration module.""" +"""Target configuration module.""" from __future__ import annotations +from abc import ABC +from abc import abstractmethod from dataclasses import dataclass +from pathlib import Path +from typing import Any +from typing import cast +from typing import TypeVar + +try: + import tomllib +except ModuleNotFoundError: + import tomli as tomllib # type: ignore from mlia.backend.registry import registry as backend_registry from mlia.core.common import AdviceCategory +from mlia.utils.filesystem import get_mlia_target_profiles_dir + + +def get_profile_file(target_profile: str | Path) -> Path: + """Get the target profile toml file.""" + if not target_profile: + raise Exception("Target profile is not provided.") + + profile_file = Path(get_mlia_target_profiles_dir() / f"{target_profile}.toml") + if not profile_file.is_file(): + profile_file = Path(target_profile) + + if not profile_file.exists(): + raise Exception(f"File not found: {profile_file}.") + return profile_file + + +def load_profile(path: str | Path) -> dict[str, Any]: + """Get settings for the provided target profile.""" + with open(path, "rb") as file: + profile = tomllib.load(file) + + return cast(dict, profile) + + +def get_builtin_supported_profile_names() -> list[str]: + """Return list of default profiles in the target profiles directory.""" + return sorted( + [ + item.stem + for item in get_mlia_target_profiles_dir().iterdir() + if item.is_file() and item.suffix == ".toml" + ] + ) + +def get_target(target_profile: str | Path) -> str: + """Return target for the provided target_profile.""" + profile_file = get_profile_file(target_profile) + profile = load_profile(profile_file) + return cast(str, profile["target"]) -class IPConfiguration: # pylint: disable=too-few-public-methods - """Base class for IP configuration.""" + +T = TypeVar("T", bound="TargetProfile") + + +class TargetProfile(ABC): + """Base class for target profiles.""" def __init__(self, target: str) -> None: - """Init IP configuration instance.""" + """Init TargetProfile instance with the target name.""" self.target = target + @classmethod + def load(cls: type[T], path: str | Path) -> T: + """Load and verify a target profile from file and return new instance.""" + profile = load_profile(path) + + try: + new_instance = cls(**profile) + except KeyError as ex: + raise KeyError(f"Missing key in file {path}.") from ex + + new_instance.verify() + + return new_instance + + @classmethod + def load_profile(cls: type[T], target_profile: str) -> T: + """Load a target profile by name.""" + profile_file = get_profile_file(target_profile) + return cls.load(profile_file) + + def save(self, path: str | Path) -> None: + """Save this target profile to a file.""" + raise NotImplementedError("Saving target profiles is currently not supported.") + + @abstractmethod + def verify(self) -> None: + """ + Check that all attributes contain valid values etc. + + Raises a ValueError, if an issue is detected. + """ + if not self.target: + raise ValueError(f"Invalid target name: {self.target}") + @dataclass class TargetInfo: |