diff options
author | Nathan Bailey <nathan.bailey@arm.com> | 2024-02-15 14:50:58 +0000 |
---|---|---|
committer | Nathan Bailey <nathan.bailey@arm.com> | 2024-03-14 15:45:40 +0000 |
commit | 0b552d2ae47da4fb9c16d2a59d6ebe12c8307771 (patch) | |
tree | 09b40b939acbe0bcf02dcc77a7ed7ce4aba94322 /src/mlia/cli/main.py | |
parent | 09b272be6e88d84a30cb89fb71f3fc3c64d20d2e (diff) | |
download | mlia-0b552d2ae47da4fb9c16d2a59d6ebe12c8307771.tar.gz |
feat: Enable rewrite parameterisation
Enables user to provide a toml or default profile to change training settings for rewrite optimization
Resolves: MLIA-1004
Signed-off-by: Nathan Bailey <nathan.bailey@arm.com>
Change-Id: I3bf9f44b9a2062fb71ef36eb32c9a69edcc48061
Diffstat (limited to 'src/mlia/cli/main.py')
-rw-r--r-- | src/mlia/cli/main.py | 20 |
1 files changed, 14 insertions, 6 deletions
diff --git a/src/mlia/cli/main.py b/src/mlia/cli/main.py index 9e1b7cd..32d46a6 100644 --- a/src/mlia/cli/main.py +++ b/src/mlia/cli/main.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates. +# SPDX-FileCopyrightText: Copyright 2022-2024, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """CLI main entry point.""" from __future__ import annotations @@ -203,11 +203,17 @@ def run_command(args: argparse.Namespace) -> int: try: logger.info("ML Inference Advisor %s", __version__) - if copy_profile_file(ctx, func_args): + if copy_profile_file(ctx, func_args, "target_profile"): logger.info( "\nThe target profile (.toml) is copied to the output directory: %s", ctx.output_dir, ) + if copy_profile_file(ctx, func_args, "optimization_profile"): + logger.info( + "\nThe optimization profile (.toml) is copied to " + "the output directory: %s", + ctx.output_dir, + ) args.func(**func_args) return 0 except KeyboardInterrupt: @@ -278,11 +284,13 @@ def init_and_run(commands: list[CommandInfo], argv: list[str] | None = None) -> return run_command(args) -def copy_profile_file(ctx: ExecutionContext, func_args: dict) -> bool: - """If present, copy the target profile file to the output directory.""" - if func_args.get("target_profile"): +def copy_profile_file( + ctx: ExecutionContext, func_args: dict, profile_to_copy: str +) -> bool: + """If present, copy the selected profile file to the output directory.""" + if func_args.get(profile_to_copy): return copy_profile_file_to_output_dir( - func_args["target_profile"], ctx.output_dir + func_args[profile_to_copy], ctx.output_dir, profile_to_copy ) return False |