From b863f8eaf0b2e2627fa9c5d2a51004fd8133cc68 Mon Sep 17 00:00:00 2001 From: Annie Tallund Date: Tue, 31 Jan 2023 16:20:12 +0100 Subject: MLIA-594 Save target profile configuration Save the target profile file in the output directory. Change-Id: I886e52cb922c5425e749b154bd67a5d294ce0201 --- src/mlia/cli/main.py | 17 ++++++++++++++++- src/mlia/target/config.py | 14 ++++++++++++++ tests/test_target_config.py | 12 ++++++++++++ 3 files changed, 42 insertions(+), 1 deletion(-) diff --git a/src/mlia/cli/main.py b/src/mlia/cli/main.py index 76f199e..793e155 100644 --- a/src/mlia/cli/main.py +++ b/src/mlia/cli/main.py @@ -34,6 +34,7 @@ from mlia.core.context import ExecutionContext from mlia.core.errors import ConfigurationError from mlia.core.errors import InternalError from mlia.core.logging import setup_logging +from mlia.target.config import copy_profile_file_to_output_dir from mlia.target.registry import table as target_table @@ -174,7 +175,6 @@ def setup_context( if param_name not in skipped_params and (param_name not in expected_params or param_name in func_params) } - return (ctx, func_args) @@ -191,6 +191,11 @@ def run_command(args: argparse.Namespace) -> int: logger.info( "\nThis execution of MLIA uses output directory: %s", ctx.output_dir ) + if copy_profile_file(ctx, func_args): + logger.info( + "Target profile information copied to %s/target_profile.toml", + ctx.output_dir, + ) args.func(**func_args) return 0 except KeyboardInterrupt: @@ -261,6 +266,16 @@ def init_and_run(commands: list[CommandInfo], argv: list[str] | None = None) -> return run_command(args) +def copy_profile_file(ctx: ExecutionContext, func_args: dict) -> bool: + """If present, copy the target profile file to the output directory.""" + if func_args.get("target_profile"): + return copy_profile_file_to_output_dir( + func_args["target_profile"], ctx.output_dir + ) + + return False + + def main(argv: list[str] | None = None) -> int: """Entry point of the main application.""" commands = get_commands() diff --git a/src/mlia/target/config.py b/src/mlia/target/config.py index ec3fb4c..bf603dd 100644 --- a/src/mlia/target/config.py +++ b/src/mlia/target/config.py @@ -7,6 +7,7 @@ from abc import ABC from abc import abstractmethod from dataclasses import dataclass from pathlib import Path +from shutil import copy from typing import Any from typing import cast from typing import TypeVar @@ -61,6 +62,19 @@ def get_target(target_profile: str | Path) -> str: return cast(str, profile["target"]) +def copy_profile_file_to_output_dir( + target_profile: str | Path, output_dir: str | Path +) -> bool: + """Copy the target profile file to output directory.""" + profile_file_path = get_profile_file(target_profile) + output_file_path = f"{output_dir}/{profile_file_path.stem}.toml" + try: + copy(profile_file_path, output_file_path) + return True + except OSError as err: + raise RuntimeError("Failed to copy profile file:", err.strerror) from err + + T = TypeVar("T", bound="TargetProfile") diff --git a/tests/test_target_config.py b/tests/test_target_config.py index 26f524e..c6235a5 100644 --- a/tests/test_target_config.py +++ b/tests/test_target_config.py @@ -3,12 +3,15 @@ """Tests for the backend config module.""" from __future__ import annotations +from pathlib import Path + import pytest from mlia.backend.config import BackendConfiguration from mlia.backend.config import BackendType from mlia.backend.config import System from mlia.core.common import AdviceCategory +from mlia.target.config import copy_profile_file_to_output_dir from mlia.target.config import get_builtin_supported_profile_names from mlia.target.config import get_profile_file from mlia.target.config import load_profile @@ -17,6 +20,15 @@ from mlia.target.config import TargetProfile from mlia.utils.registry import Registry +def test_copy_profile_file_to_output_dir(tmp_path: Path) -> None: + """Test if the profile file is copied into the output directory.""" + test_target_profile_name = "ethos-u55-128" + test_file_path = Path(f"{tmp_path}/{test_target_profile_name}.toml") + + copy_profile_file_to_output_dir(test_target_profile_name, tmp_path) + assert Path.is_file(test_file_path) + + def test_get_builtin_supported_profile_names() -> None: """Test profile names getter.""" assert get_builtin_supported_profile_names() == [ -- cgit v1.2.1