diff options
author | Dmitrii Agibov <dmitrii.agibov@arm.com> | 2023-02-02 09:07:02 +0000 |
---|---|---|
committer | Dmitrii Agibov <dmitrii.agibov@arm.com> | 2023-02-06 11:51:52 +0000 |
commit | c6cfc78d5245c550016b9709686d2b32ab3fcd5b (patch) | |
tree | 42786a3755370d0fefea659d7e532a5483955239 /src/mlia/cli/main.py | |
parent | f1eaff3c9790464bed3183ff76555cf815166f50 (diff) | |
download | mlia-c6cfc78d5245c550016b9709686d2b32ab3fcd5b.tar.gz |
MLIA-461 Add parameter for the output directory
- Add CLI parameter --output-dir
- Rename ExecutionContext property working_dir into output_dir
- Remove logic for default command as it is no longer needed
Change-Id: I6387f6b688520ba1fc69a80167587297353620f6
Diffstat (limited to 'src/mlia/cli/main.py')
-rw-r--r-- | src/mlia/cli/main.py | 74 |
1 files changed, 19 insertions, 55 deletions
diff --git a/src/mlia/cli/main.py b/src/mlia/cli/main.py index 7ce7dc9..2b63124 100644 --- a/src/mlia/cli/main.py +++ b/src/mlia/cli/main.py @@ -27,6 +27,7 @@ from mlia.cli.options import add_debug_options from mlia.cli.options import add_keras_model_options from mlia.cli.options import add_model_options from mlia.cli.options import add_multi_optimization_options +from mlia.cli.options import add_output_directory from mlia.cli.options import add_output_options from mlia.cli.options import add_target_options from mlia.cli.options import get_output_format @@ -60,6 +61,7 @@ def get_commands() -> list[CommandInfo]: check, [], [ + add_output_directory, add_model_options, add_target_options, add_backend_options, @@ -72,6 +74,7 @@ def get_commands() -> list[CommandInfo]: optimize, [], [ + add_output_directory, add_keras_model_options, partial(add_target_options, profiles_to_skip=["tosa", "cortex-a"]), partial( @@ -86,7 +89,7 @@ def get_commands() -> list[CommandInfo]: ] -def backend_commands() -> list[CommandInfo]: +def get_backend_commands() -> list[CommandInfo]: """Return commands configuration.""" return [ CommandInfo( @@ -118,14 +121,6 @@ def backend_commands() -> list[CommandInfo]: ] -def get_default_command(commands: list[CommandInfo]) -> str | None: - """Get name of the default command.""" - marked_as_default = [cmd.command_name for cmd in commands if cmd.is_default] - assert len(marked_as_default) <= 1, "Only one command could be marked as default" - - return next(iter(marked_as_default), None) - - def get_possible_command_names(commands: list[CommandInfo]) -> list[str]: """Get all possible command names including aliases.""" return [ @@ -164,12 +159,13 @@ def setup_context( verbose="debug" in args and args.debug, action_resolver=CLIActionResolver(vars(args)), output_format=get_output_format(args), + output_dir=args.output_dir if "output_dir" in args else None, ) setup_logging(ctx.logs_path, ctx.verbose, ctx.output_format) # these parameters should not be passed into command function - skipped_params = ["func", "command", "debug", "json"] + skipped_params = ["func", "command", "debug", "json", "output_dir"] # pass these parameters only if command expects them expected_params = [context_var_name] @@ -198,7 +194,7 @@ def run_command(args: argparse.Namespace) -> int: try: logger.info(INFO_MESSAGE) logger.info( - "\nThis execution of MLIA uses working directory: %s", ctx.working_dir + "\nThis execution of MLIA uses output directory: %s", ctx.output_dir ) args.func(**func_args) return 0 @@ -231,25 +227,15 @@ def run_command(args: argparse.Namespace) -> int: logger.error(err_advice_message) finally: - logger.info( - "This execution of MLIA used working directory: %s", ctx.working_dir - ) + logger.info("This execution of MLIA used output directory: %s", ctx.output_dir) return 1 -def init_common_parser() -> argparse.ArgumentParser: - """Init common parser.""" - parser = argparse.ArgumentParser(add_help=False, allow_abbrev=False) - - return parser - - -def init_subcommand_parser(parent: argparse.ArgumentParser) -> argparse.ArgumentParser: +def init_parser(commands: list[CommandInfo]) -> argparse.ArgumentParser: """Init subcommand parser.""" parser = argparse.ArgumentParser( description=INFO_MESSAGE, formatter_class=argparse.RawDescriptionHelpFormatter, - parents=[parent], add_help=False, allow_abbrev=False, ) @@ -268,50 +254,28 @@ def init_subcommand_parser(parent: argparse.ArgumentParser) -> argparse.Argument help="Show program's version number and exit", ) + init_commands(parser, commands) return parser -def add_default_command_if_needed( - args: list[str], input_commands: list[CommandInfo] -) -> None: - """Add default command to the list of the arguments if needed.""" - default_command = get_default_command(input_commands) +def init_and_run(commands: list[CommandInfo], argv: list[str] | None = None) -> int: + """Init parser and run subcommand.""" + parser = init_parser(commands) + args = parser.parse_args(argv) - if default_command and len(args) > 0: - commands = get_possible_command_names(input_commands) - help_or_version = ["-h", "--help", "-v", "--version"] - - command_is_missing = args[0] not in [*commands, *help_or_version] - if command_is_missing: - args.insert(0, default_command) - - -def generic_main( - commands: list[CommandInfo], argv: list[str] | None = None -) -> argparse.Namespace: - """Enable multiple entry points.""" - common_parser = init_common_parser() - subcommand_parser = init_subcommand_parser(common_parser) - init_commands(subcommand_parser, commands) - - common_args, subcommand_args = common_parser.parse_known_args(argv) - - add_default_command_if_needed(subcommand_args, commands) - - args = subcommand_parser.parse_args(subcommand_args, common_args) - return args + return run_command(args) def main(argv: list[str] | None = None) -> int: """Entry point of the main application.""" - args = generic_main(get_commands(), argv) - return run_command(args) + commands = get_commands() + return init_and_run(commands, argv) def backend_main(argv: list[str] | None = None) -> int: """Entry point of the backend application.""" - args = generic_main(backend_commands(), argv) - return run_command(args) + commands = get_backend_commands() + return init_and_run(commands, argv) if __name__ == "__main__": |