aboutsummaryrefslogtreecommitdiff
path: root/src/mlia/cli/main.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/mlia/cli/main.py')
-rw-r--r--src/mlia/cli/main.py74
1 files changed, 19 insertions, 55 deletions
diff --git a/src/mlia/cli/main.py b/src/mlia/cli/main.py
index 7ce7dc9..2b63124 100644
--- a/src/mlia/cli/main.py
+++ b/src/mlia/cli/main.py
@@ -27,6 +27,7 @@ from mlia.cli.options import add_debug_options
from mlia.cli.options import add_keras_model_options
from mlia.cli.options import add_model_options
from mlia.cli.options import add_multi_optimization_options
+from mlia.cli.options import add_output_directory
from mlia.cli.options import add_output_options
from mlia.cli.options import add_target_options
from mlia.cli.options import get_output_format
@@ -60,6 +61,7 @@ def get_commands() -> list[CommandInfo]:
check,
[],
[
+ add_output_directory,
add_model_options,
add_target_options,
add_backend_options,
@@ -72,6 +74,7 @@ def get_commands() -> list[CommandInfo]:
optimize,
[],
[
+ add_output_directory,
add_keras_model_options,
partial(add_target_options, profiles_to_skip=["tosa", "cortex-a"]),
partial(
@@ -86,7 +89,7 @@ def get_commands() -> list[CommandInfo]:
]
-def backend_commands() -> list[CommandInfo]:
+def get_backend_commands() -> list[CommandInfo]:
"""Return commands configuration."""
return [
CommandInfo(
@@ -118,14 +121,6 @@ def backend_commands() -> list[CommandInfo]:
]
-def get_default_command(commands: list[CommandInfo]) -> str | None:
- """Get name of the default command."""
- marked_as_default = [cmd.command_name for cmd in commands if cmd.is_default]
- assert len(marked_as_default) <= 1, "Only one command could be marked as default"
-
- return next(iter(marked_as_default), None)
-
-
def get_possible_command_names(commands: list[CommandInfo]) -> list[str]:
"""Get all possible command names including aliases."""
return [
@@ -164,12 +159,13 @@ def setup_context(
verbose="debug" in args and args.debug,
action_resolver=CLIActionResolver(vars(args)),
output_format=get_output_format(args),
+ output_dir=args.output_dir if "output_dir" in args else None,
)
setup_logging(ctx.logs_path, ctx.verbose, ctx.output_format)
# these parameters should not be passed into command function
- skipped_params = ["func", "command", "debug", "json"]
+ skipped_params = ["func", "command", "debug", "json", "output_dir"]
# pass these parameters only if command expects them
expected_params = [context_var_name]
@@ -198,7 +194,7 @@ def run_command(args: argparse.Namespace) -> int:
try:
logger.info(INFO_MESSAGE)
logger.info(
- "\nThis execution of MLIA uses working directory: %s", ctx.working_dir
+ "\nThis execution of MLIA uses output directory: %s", ctx.output_dir
)
args.func(**func_args)
return 0
@@ -231,25 +227,15 @@ def run_command(args: argparse.Namespace) -> int:
logger.error(err_advice_message)
finally:
- logger.info(
- "This execution of MLIA used working directory: %s", ctx.working_dir
- )
+ logger.info("This execution of MLIA used output directory: %s", ctx.output_dir)
return 1
-def init_common_parser() -> argparse.ArgumentParser:
- """Init common parser."""
- parser = argparse.ArgumentParser(add_help=False, allow_abbrev=False)
-
- return parser
-
-
-def init_subcommand_parser(parent: argparse.ArgumentParser) -> argparse.ArgumentParser:
+def init_parser(commands: list[CommandInfo]) -> argparse.ArgumentParser:
"""Init subcommand parser."""
parser = argparse.ArgumentParser(
description=INFO_MESSAGE,
formatter_class=argparse.RawDescriptionHelpFormatter,
- parents=[parent],
add_help=False,
allow_abbrev=False,
)
@@ -268,50 +254,28 @@ def init_subcommand_parser(parent: argparse.ArgumentParser) -> argparse.Argument
help="Show program's version number and exit",
)
+ init_commands(parser, commands)
return parser
-def add_default_command_if_needed(
- args: list[str], input_commands: list[CommandInfo]
-) -> None:
- """Add default command to the list of the arguments if needed."""
- default_command = get_default_command(input_commands)
+def init_and_run(commands: list[CommandInfo], argv: list[str] | None = None) -> int:
+ """Init parser and run subcommand."""
+ parser = init_parser(commands)
+ args = parser.parse_args(argv)
- if default_command and len(args) > 0:
- commands = get_possible_command_names(input_commands)
- help_or_version = ["-h", "--help", "-v", "--version"]
-
- command_is_missing = args[0] not in [*commands, *help_or_version]
- if command_is_missing:
- args.insert(0, default_command)
-
-
-def generic_main(
- commands: list[CommandInfo], argv: list[str] | None = None
-) -> argparse.Namespace:
- """Enable multiple entry points."""
- common_parser = init_common_parser()
- subcommand_parser = init_subcommand_parser(common_parser)
- init_commands(subcommand_parser, commands)
-
- common_args, subcommand_args = common_parser.parse_known_args(argv)
-
- add_default_command_if_needed(subcommand_args, commands)
-
- args = subcommand_parser.parse_args(subcommand_args, common_args)
- return args
+ return run_command(args)
def main(argv: list[str] | None = None) -> int:
"""Entry point of the main application."""
- args = generic_main(get_commands(), argv)
- return run_command(args)
+ commands = get_commands()
+ return init_and_run(commands, argv)
def backend_main(argv: list[str] | None = None) -> int:
"""Entry point of the backend application."""
- args = generic_main(backend_commands(), argv)
- return run_command(args)
+ commands = get_backend_commands()
+ return init_and_run(commands, argv)
if __name__ == "__main__":