aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAnnie Tallund <annie.tallund@arm.com>2023-01-31 16:20:12 +0100
committerAnnie Tallund <annie.tallund@arm.com>2023-02-10 12:54:45 +0100
commitb863f8eaf0b2e2627fa9c5d2a51004fd8133cc68 (patch)
tree6ac932f81de1d326477f0716b1668e9bebae332e
parent3e3dcb9bd5abb88adcd85b4f89e8a81e7f6fa293 (diff)
downloadmlia-b863f8eaf0b2e2627fa9c5d2a51004fd8133cc68.tar.gz
MLIA-594 Save target profile configuration
Save the target profile file in the output directory. Change-Id: I886e52cb922c5425e749b154bd67a5d294ce0201
-rw-r--r--src/mlia/cli/main.py17
-rw-r--r--src/mlia/target/config.py14
-rw-r--r--tests/test_target_config.py12
3 files changed, 42 insertions, 1 deletions
diff --git a/src/mlia/cli/main.py b/src/mlia/cli/main.py
index 76f199e..793e155 100644
--- a/src/mlia/cli/main.py
+++ b/src/mlia/cli/main.py
@@ -34,6 +34,7 @@ from mlia.core.context import ExecutionContext
from mlia.core.errors import ConfigurationError
from mlia.core.errors import InternalError
from mlia.core.logging import setup_logging
+from mlia.target.config import copy_profile_file_to_output_dir
from mlia.target.registry import table as target_table
@@ -174,7 +175,6 @@ def setup_context(
if param_name not in skipped_params
and (param_name not in expected_params or param_name in func_params)
}
-
return (ctx, func_args)
@@ -191,6 +191,11 @@ def run_command(args: argparse.Namespace) -> int:
logger.info(
"\nThis execution of MLIA uses output directory: %s", ctx.output_dir
)
+ if copy_profile_file(ctx, func_args):
+ logger.info(
+ "Target profile information copied to %s/target_profile.toml",
+ ctx.output_dir,
+ )
args.func(**func_args)
return 0
except KeyboardInterrupt:
@@ -261,6 +266,16 @@ def init_and_run(commands: list[CommandInfo], argv: list[str] | None = None) ->
return run_command(args)
+def copy_profile_file(ctx: ExecutionContext, func_args: dict) -> bool:
+ """If present, copy the target profile file to the output directory."""
+ if func_args.get("target_profile"):
+ return copy_profile_file_to_output_dir(
+ func_args["target_profile"], ctx.output_dir
+ )
+
+ return False
+
+
def main(argv: list[str] | None = None) -> int:
"""Entry point of the main application."""
commands = get_commands()
diff --git a/src/mlia/target/config.py b/src/mlia/target/config.py
index ec3fb4c..bf603dd 100644
--- a/src/mlia/target/config.py
+++ b/src/mlia/target/config.py
@@ -7,6 +7,7 @@ from abc import ABC
from abc import abstractmethod
from dataclasses import dataclass
from pathlib import Path
+from shutil import copy
from typing import Any
from typing import cast
from typing import TypeVar
@@ -61,6 +62,19 @@ def get_target(target_profile: str | Path) -> str:
return cast(str, profile["target"])
+def copy_profile_file_to_output_dir(
+ target_profile: str | Path, output_dir: str | Path
+) -> bool:
+ """Copy the target profile file to output directory."""
+ profile_file_path = get_profile_file(target_profile)
+ output_file_path = f"{output_dir}/{profile_file_path.stem}.toml"
+ try:
+ copy(profile_file_path, output_file_path)
+ return True
+ except OSError as err:
+ raise RuntimeError("Failed to copy profile file:", err.strerror) from err
+
+
T = TypeVar("T", bound="TargetProfile")
diff --git a/tests/test_target_config.py b/tests/test_target_config.py
index 26f524e..c6235a5 100644
--- a/tests/test_target_config.py
+++ b/tests/test_target_config.py
@@ -3,12 +3,15 @@
"""Tests for the backend config module."""
from __future__ import annotations
+from pathlib import Path
+
import pytest
from mlia.backend.config import BackendConfiguration
from mlia.backend.config import BackendType
from mlia.backend.config import System
from mlia.core.common import AdviceCategory
+from mlia.target.config import copy_profile_file_to_output_dir
from mlia.target.config import get_builtin_supported_profile_names
from mlia.target.config import get_profile_file
from mlia.target.config import load_profile
@@ -17,6 +20,15 @@ from mlia.target.config import TargetProfile
from mlia.utils.registry import Registry
+def test_copy_profile_file_to_output_dir(tmp_path: Path) -> None:
+ """Test if the profile file is copied into the output directory."""
+ test_target_profile_name = "ethos-u55-128"
+ test_file_path = Path(f"{tmp_path}/{test_target_profile_name}.toml")
+
+ copy_profile_file_to_output_dir(test_target_profile_name, tmp_path)
+ assert Path.is_file(test_file_path)
+
+
def test_get_builtin_supported_profile_names() -> None:
"""Test profile names getter."""
assert get_builtin_supported_profile_names() == [