From c6cfc78d5245c550016b9709686d2b32ab3fcd5b Mon Sep 17 00:00:00 2001 From: Dmitrii Agibov Date: Thu, 2 Feb 2023 09:07:02 +0000 Subject: MLIA-461 Add parameter for the output directory - Add CLI parameter --output-dir - Rename ExecutionContext property working_dir into output_dir - Remove logic for default command as it is no longer needed Change-Id: I6387f6b688520ba1fc69a80167587297353620f6 --- src/mlia/cli/common.py | 10 +--- src/mlia/cli/main.py | 74 +++++++----------------- src/mlia/cli/options.py | 12 ++++ src/mlia/core/context.py | 37 ++++++------ tests/conftest.py | 4 +- tests/test_cli_main.py | 75 ++++++++++--------------- tests/test_core_context.py | 8 +-- tests/test_core_workflow.py | 6 +- tests/test_target_cortex_a_advice_generation.py | 2 +- tests/test_target_cortex_a_data_collection.py | 4 +- tests/test_target_ethos_u_advice_generation.py | 4 +- tests/test_target_tosa_advice_generation.py | 2 +- tests/test_target_tosa_data_collection.py | 4 +- tests_e2e/test_e2e.py | 11 +--- 14 files changed, 101 insertions(+), 152 deletions(-) diff --git a/src/mlia/cli/common.py b/src/mlia/cli/common.py index 077f456..f45dc65 100644 --- a/src/mlia/cli/common.py +++ b/src/mlia/cli/common.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """CLI common module.""" from __future__ import annotations @@ -15,7 +15,6 @@ class CommandInfo: func: Callable aliases: list[str] opt_groups: list[Callable[[argparse.ArgumentParser], None]] - is_default: bool = False name: str | None = None @property @@ -32,9 +31,4 @@ class CommandInfo: def command_help(self) -> str: """Return help message for the command.""" assert self.func.__doc__, "Command function does not have a docstring" - func_help = self.func.__doc__.splitlines()[0].rstrip(".") - - if self.is_default: - func_help = f"{func_help} [default]" - - return func_help + return self.func.__doc__.splitlines()[0].rstrip(".") diff --git a/src/mlia/cli/main.py b/src/mlia/cli/main.py index 7ce7dc9..2b63124 100644 --- a/src/mlia/cli/main.py +++ b/src/mlia/cli/main.py @@ -27,6 +27,7 @@ from mlia.cli.options import add_debug_options from mlia.cli.options import add_keras_model_options from mlia.cli.options import add_model_options from mlia.cli.options import add_multi_optimization_options +from mlia.cli.options import add_output_directory from mlia.cli.options import add_output_options from mlia.cli.options import add_target_options from mlia.cli.options import get_output_format @@ -60,6 +61,7 @@ def get_commands() -> list[CommandInfo]: check, [], [ + add_output_directory, add_model_options, add_target_options, add_backend_options, @@ -72,6 +74,7 @@ def get_commands() -> list[CommandInfo]: optimize, [], [ + add_output_directory, add_keras_model_options, partial(add_target_options, profiles_to_skip=["tosa", "cortex-a"]), partial( @@ -86,7 +89,7 @@ def get_commands() -> list[CommandInfo]: ] -def backend_commands() -> list[CommandInfo]: +def get_backend_commands() -> list[CommandInfo]: """Return commands configuration.""" return [ CommandInfo( @@ -118,14 +121,6 @@ def backend_commands() -> list[CommandInfo]: ] -def get_default_command(commands: list[CommandInfo]) -> str | None: - """Get name of the default command.""" - marked_as_default = [cmd.command_name for cmd in commands if cmd.is_default] - assert len(marked_as_default) <= 1, "Only one command could be marked as default" - - return next(iter(marked_as_default), None) - - def get_possible_command_names(commands: list[CommandInfo]) -> list[str]: """Get all possible command names including aliases.""" return [ @@ -164,12 +159,13 @@ def setup_context( verbose="debug" in args and args.debug, action_resolver=CLIActionResolver(vars(args)), output_format=get_output_format(args), + output_dir=args.output_dir if "output_dir" in args else None, ) setup_logging(ctx.logs_path, ctx.verbose, ctx.output_format) # these parameters should not be passed into command function - skipped_params = ["func", "command", "debug", "json"] + skipped_params = ["func", "command", "debug", "json", "output_dir"] # pass these parameters only if command expects them expected_params = [context_var_name] @@ -198,7 +194,7 @@ def run_command(args: argparse.Namespace) -> int: try: logger.info(INFO_MESSAGE) logger.info( - "\nThis execution of MLIA uses working directory: %s", ctx.working_dir + "\nThis execution of MLIA uses output directory: %s", ctx.output_dir ) args.func(**func_args) return 0 @@ -231,25 +227,15 @@ def run_command(args: argparse.Namespace) -> int: logger.error(err_advice_message) finally: - logger.info( - "This execution of MLIA used working directory: %s", ctx.working_dir - ) + logger.info("This execution of MLIA used output directory: %s", ctx.output_dir) return 1 -def init_common_parser() -> argparse.ArgumentParser: - """Init common parser.""" - parser = argparse.ArgumentParser(add_help=False, allow_abbrev=False) - - return parser - - -def init_subcommand_parser(parent: argparse.ArgumentParser) -> argparse.ArgumentParser: +def init_parser(commands: list[CommandInfo]) -> argparse.ArgumentParser: """Init subcommand parser.""" parser = argparse.ArgumentParser( description=INFO_MESSAGE, formatter_class=argparse.RawDescriptionHelpFormatter, - parents=[parent], add_help=False, allow_abbrev=False, ) @@ -268,50 +254,28 @@ def init_subcommand_parser(parent: argparse.ArgumentParser) -> argparse.Argument help="Show program's version number and exit", ) + init_commands(parser, commands) return parser -def add_default_command_if_needed( - args: list[str], input_commands: list[CommandInfo] -) -> None: - """Add default command to the list of the arguments if needed.""" - default_command = get_default_command(input_commands) +def init_and_run(commands: list[CommandInfo], argv: list[str] | None = None) -> int: + """Init parser and run subcommand.""" + parser = init_parser(commands) + args = parser.parse_args(argv) - if default_command and len(args) > 0: - commands = get_possible_command_names(input_commands) - help_or_version = ["-h", "--help", "-v", "--version"] - - command_is_missing = args[0] not in [*commands, *help_or_version] - if command_is_missing: - args.insert(0, default_command) - - -def generic_main( - commands: list[CommandInfo], argv: list[str] | None = None -) -> argparse.Namespace: - """Enable multiple entry points.""" - common_parser = init_common_parser() - subcommand_parser = init_subcommand_parser(common_parser) - init_commands(subcommand_parser, commands) - - common_args, subcommand_args = common_parser.parse_known_args(argv) - - add_default_command_if_needed(subcommand_args, commands) - - args = subcommand_parser.parse_args(subcommand_args, common_args) - return args + return run_command(args) def main(argv: list[str] | None = None) -> int: """Entry point of the main application.""" - args = generic_main(get_commands(), argv) - return run_command(args) + commands = get_commands() + return init_and_run(commands, argv) def backend_main(argv: list[str] | None = None) -> int: """Entry point of the backend application.""" - args = generic_main(backend_commands(), argv) - return run_command(args) + commands = get_backend_commands() + return init_and_run(commands, argv) if __name__ == "__main__": diff --git a/src/mlia/cli/options.py b/src/mlia/cli/options.py index e01f107..d154646 100644 --- a/src/mlia/cli/options.py +++ b/src/mlia/cli/options.py @@ -202,6 +202,18 @@ def add_backend_options( ) +def add_output_directory(parser: argparse.ArgumentParser) -> None: + """Add parameter for the output directory.""" + parser.add_argument( + "--output-dir", + type=Path, + help="Path to the directory where MLIA " + "stores artifacts, e.g. logs, target profiles and model files. " + "If not specified then MLIA will use temporary " + "directory instead.", + ) + + def parse_optimization_parameters( pruning: bool = False, clustering: bool = False, diff --git a/src/mlia/core/context.py b/src/mlia/core/context.py index f8442a3..c0267f1 100644 --- a/src/mlia/core/context.py +++ b/src/mlia/core/context.py @@ -105,7 +105,7 @@ class ExecutionContext(Context): *, advice_category: set[AdviceCategory] = None, config_parameters: Mapping[str, Any] | None = None, - working_dir: str | Path | None = None, + output_dir: str | Path | None = None, event_handlers: list[EventHandler] | None = None, event_publisher: EventPublisher | None = None, verbose: bool = False, @@ -118,16 +118,16 @@ class ExecutionContext(Context): :param advice_category: requested advice categories :param config_parameters: dictionary like object with input parameters - :param working_dir: path to the directory that will be used as a place + :param output_dir: path to the directory that will be used as a place to store temporary files, logs, models. If not provided then - current working directory will be used instead + temporary directory will be used instead :param event_handlers: optional list of event handlers :param event_publisher: optional event publisher instance. If not provided then default implementation of event publisher will be used :param verbose: enable verbose output - :param logs_dir: name of the directory inside working directory where + :param logs_dir: name of the directory inside output directory where log files will be stored - :param models_dir: name of the directory inside working directory where + :param models_dir: name of the directory inside output directory where temporary models will be stored :param action_resolver: instance of the action resolver that could make advice actionable @@ -135,10 +135,11 @@ class ExecutionContext(Context): self._advice_category = advice_category or {AdviceCategory.COMPATIBILITY} self._config_parameters = config_parameters - if working_dir: - self._working_dir_path = Path(working_dir) + if output_dir: + self._output_dir_path = Path(output_dir) + self._output_dir_path.mkdir(exist_ok=True) else: - self._working_dir_path = generate_temp_workdir() + self._output_dir_path = generate_temp_output_dir() self._event_handlers = event_handlers self._event_publisher = event_publisher or DefaultEventPublisher() @@ -149,9 +150,9 @@ class ExecutionContext(Context): self._output_format = output_format @property - def working_dir(self) -> Path: - """Return working dir path.""" - return self._working_dir_path + def output_dir(self) -> Path: + """Return output dir path.""" + return self._output_dir_path @property def advice_category(self) -> set[AdviceCategory]: @@ -195,7 +196,7 @@ class ExecutionContext(Context): def get_model_path(self, model_filename: str) -> Path: """Return path for the model.""" - models_dir_path = self._working_dir_path / self.models_dir + models_dir_path = self._output_dir_path / self.models_dir models_dir_path.mkdir(exist_ok=True) return models_dir_path / model_filename @@ -203,7 +204,7 @@ class ExecutionContext(Context): @property def logs_path(self) -> Path: """Return path to the logs directory.""" - return self._working_dir_path / self.logs_dir + return self._output_dir_path / self.logs_dir @property def output_format(self) -> OutputFormat: @@ -231,7 +232,7 @@ class ExecutionContext(Context): ) return ( - f"ExecutionContext: working_dir={self._working_dir_path}, " + f"ExecutionContext: output_dir={self._output_dir_path}, " f"advice_category={category}, " f"config_parameters={self.config_parameters}, " f"verbose={self.verbose}, " @@ -239,7 +240,7 @@ class ExecutionContext(Context): ) -def generate_temp_workdir() -> Path: - """Generate a temporary working dir and returns the path.""" - working_dir = tempfile.mkdtemp(suffix=None, prefix="mlia-", dir=None) - return Path(working_dir) +def generate_temp_output_dir() -> Path: + """Generate a temporary output dir and returns the path.""" + output_dir = tempfile.mkdtemp(suffix=None, prefix="mlia-", dir=None) + return Path(output_dir) diff --git a/tests/conftest.py b/tests/conftest.py index e27acaf..b698a73 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """Pytest conf module.""" import shutil @@ -27,7 +27,7 @@ def fixture_test_resources_path() -> Path: @pytest.fixture(name="sample_context") def fixture_sample_context(tmpdir: str) -> ExecutionContext: """Return sample context fixture.""" - return ExecutionContext(working_dir=tmpdir) + return ExecutionContext(output_dir=tmpdir) @pytest.fixture(scope="session") diff --git a/tests/test_cli_main.py b/tests/test_cli_main.py index 9db5341..673031c 100644 --- a/tests/test_cli_main.py +++ b/tests/test_cli_main.py @@ -5,6 +5,7 @@ from __future__ import annotations import argparse from functools import wraps +from pathlib import Path from typing import Any from typing import Callable from unittest.mock import ANY @@ -18,6 +19,8 @@ from mlia.backend.errors import BackendUnavailableError from mlia.cli.main import backend_main from mlia.cli.main import CommandInfo from mlia.cli.main import main +from mlia.cli.options import add_output_directory +from mlia.core.context import ExecutionContext from mlia.core.errors import InternalError from tests.utils.logging import clear_loggers @@ -44,59 +47,16 @@ def test_option_version(capfd: pytest.CaptureFixture) -> None: assert stderr == "" -@pytest.mark.parametrize( - "is_default, expected_command_help", - [(True, "Test command [default]"), (False, "Test command")], -) -def test_command_info(is_default: bool, expected_command_help: str) -> None: +def test_command_info() -> None: """Test properties of CommandInfo object.""" def test_command() -> None: """Test command.""" - command_info = CommandInfo(test_command, ["test"], [], is_default) + command_info = CommandInfo(test_command, ["test"], []) assert command_info.command_name == "test_command" assert command_info.command_name_and_aliases == ["test_command", "test"] - assert command_info.command_help == expected_command_help - - -def test_default_command(monkeypatch: pytest.MonkeyPatch) -> None: - """Test adding default command.""" - - def mock_command(func_mock: MagicMock, name: str) -> Callable[..., None]: - """Mock cli command.""" - - def sample_cmd_1(*args: Any, **kwargs: Any) -> None: - """Sample command.""" - func_mock(*args, **kwargs) - - ret_func = sample_cmd_1 - ret_func.__name__ = name - - return ret_func - - non_default_command = MagicMock() - - def non_default_command_params(parser: argparse.ArgumentParser) -> None: - """Add parameters for non default command.""" - parser.add_argument("--param") - - monkeypatch.setattr( - "mlia.cli.main.get_commands", - MagicMock( - return_value=[ - CommandInfo( - func=mock_command(non_default_command, "non_default_command"), - aliases=["command2"], - opt_groups=[non_default_command_params], - is_default=False, - ), - ] - ), - ) - - main(["command2", "--param", "test"]) - non_default_command.assert_called_once_with(param="test") + assert command_info.command_help == "Test command" def wrap_mock_command(mock: MagicMock, command: Callable) -> Callable: @@ -293,6 +253,29 @@ def test_commands_execution( mock.assert_called_once_with(*expected_call.args, **expected_call.kwargs) +def test_passing_output_directory_parameter( + monkeypatch: pytest.MonkeyPatch, tmp_path: Path +) -> None: + """Test passing parameter --output-dir.""" + passed_context: ExecutionContext | None = None + + def sample_command(ctx: ExecutionContext) -> None: + """Sample command.""" + nonlocal passed_context + passed_context = ctx + + monkeypatch.setattr( + "mlia.cli.main.get_commands", + lambda: [CommandInfo(sample_command, [], [add_output_directory])], + ) + + output_dir = tmp_path / "output" + main(["sample_command", "--output-dir", output_dir.as_posix()]) + + assert passed_context is not None + assert passed_context.output_dir == output_dir + + @pytest.mark.parametrize( "params, expected_call", [ diff --git a/tests/test_core_context.py b/tests/test_core_context.py index 0e7145f..814cb6a 100644 --- a/tests/test_core_context.py +++ b/tests/test_core_context.py @@ -52,7 +52,7 @@ def test_execution_context(tmpdir: str) -> None: context = ExecutionContext( advice_category=category, config_parameters={"param": "value"}, - working_dir=tmpdir, + output_dir=tmpdir, event_handlers=[], event_publisher=publisher, verbose=True, @@ -72,14 +72,14 @@ def test_execution_context(tmpdir: str) -> None: assert context.output_format == "json" assert str(context) == ( f"ExecutionContext: " - f"working_dir={tmpdir}, " + f"output_dir={tmpdir}, " "advice_category={'COMPATIBILITY'}, " "config_parameters={'param': 'value'}, " "verbose=True, " "output_format=json" ) - context_with_default_params = ExecutionContext(working_dir=tmpdir) + context_with_default_params = ExecutionContext(output_dir=tmpdir) assert context_with_default_params.advice_category == {AdviceCategory.COMPATIBILITY} assert context_with_default_params.config_parameters is None assert context_with_default_params.event_handlers is None @@ -94,7 +94,7 @@ def test_execution_context(tmpdir: str) -> None: assert context_with_default_params.output_format == "plain_text" expected_str = ( - f"ExecutionContext: working_dir={tmpdir}, " + f"ExecutionContext: output_dir={tmpdir}, " "advice_category={'COMPATIBILITY'}, " "config_parameters=None, " "verbose=False, " diff --git a/tests/test_core_workflow.py b/tests/test_core_workflow.py index 470e572..d21ced8 100644 --- a/tests/test_core_workflow.py +++ b/tests/test_core_workflow.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """Tests for module workflow.""" from dataclasses import dataclass @@ -62,7 +62,7 @@ def test_workflow_executor(tmpdir: str) -> None: advice_producer_mock2.get_advice.return_value = [Advice(["Good advice!"])] context = ExecutionContext( - working_dir=tmpdir, + output_dir=tmpdir, event_handlers=[handler_mock], event_publisher=DefaultEventPublisher(), ) @@ -127,7 +127,7 @@ def test_workflow_executor_failed(tmpdir: str) -> None: handler_mock = MagicMock(spec=EventHandler) context = ExecutionContext( - working_dir=tmpdir, + output_dir=tmpdir, event_handlers=[handler_mock], event_publisher=DefaultEventPublisher(), ) diff --git a/tests/test_target_cortex_a_advice_generation.py b/tests/test_target_cortex_a_advice_generation.py index 1997c52..b9edbb5 100644 --- a/tests/test_target_cortex_a_advice_generation.py +++ b/tests/test_target_cortex_a_advice_generation.py @@ -189,7 +189,7 @@ def test_cortex_a_advice_producer( context = ExecutionContext( advice_category=advice_category, - working_dir=tmpdir, + output_dir=tmpdir, ) producer.set_context(context) diff --git a/tests/test_target_cortex_a_data_collection.py b/tests/test_target_cortex_a_data_collection.py index 7504166..d5f5a2d 100644 --- a/tests/test_target_cortex_a_data_collection.py +++ b/tests/test_target_cortex_a_data_collection.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """Tests for Cortex-A data collection module.""" from pathlib import Path @@ -22,7 +22,7 @@ def check_cortex_a_data_collection( MagicMock(return_value=CortexACompatibilityInfo(True, [])), ) - context = ExecutionContext(working_dir=tmpdir) + context = ExecutionContext(output_dir=tmpdir) collector = CortexAOperatorCompatibility(model) collector.set_context(context) diff --git a/tests/test_target_ethos_u_advice_generation.py b/tests/test_target_ethos_u_advice_generation.py index e93eeba..772fc56 100644 --- a/tests/test_target_ethos_u_advice_generation.py +++ b/tests/test_target_ethos_u_advice_generation.py @@ -370,7 +370,7 @@ def test_ethosu_advice_producer( context = ExecutionContext( advice_category=advice_category, - working_dir=tmpdir, + output_dir=tmpdir, action_resolver=action_resolver, ) @@ -475,7 +475,7 @@ def test_ethosu_static_advice_producer( context = ExecutionContext( advice_category=advice_category, - working_dir=tmpdir, + output_dir=tmpdir, action_resolver=action_resolver, ) producer.set_context(context) diff --git a/tests/test_target_tosa_advice_generation.py b/tests/test_target_tosa_advice_generation.py index d5ebbd7..61b74cd 100644 --- a/tests/test_target_tosa_advice_generation.py +++ b/tests/test_target_tosa_advice_generation.py @@ -47,7 +47,7 @@ def test_tosa_advice_producer( context = ExecutionContext( advice_category=advice_category, - working_dir=tmpdir, + output_dir=tmpdir, ) producer.set_context(context) diff --git a/tests/test_target_tosa_data_collection.py b/tests/test_target_tosa_data_collection.py index 9d590ca..75192fa 100644 --- a/tests/test_target_tosa_data_collection.py +++ b/tests/test_target_tosa_data_collection.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """Tests for TOSA data collection module.""" from pathlib import Path @@ -19,7 +19,7 @@ def test_tosa_data_collection( "mlia.target.tosa.data_collection.get_tosa_compatibility_info", MagicMock(return_value=TOSACompatibilityInfo(True, [])), ) - context = ExecutionContext(working_dir=tmpdir) + context = ExecutionContext(output_dir=tmpdir) collector = TOSAOperatorCompatibility(test_tflite_model) collector.set_context(context) diff --git a/tests_e2e/test_e2e.py b/tests_e2e/test_e2e.py index beddaed..4de640b 100644 --- a/tests_e2e/test_e2e.py +++ b/tests_e2e/test_e2e.py @@ -23,9 +23,7 @@ import pytest from mlia.cli.config import get_available_backends from mlia.cli.main import get_commands from mlia.cli.main import get_possible_command_names -from mlia.cli.main import init_commands -from mlia.cli.main import init_common_parser -from mlia.cli.main import init_subcommand_parser +from mlia.cli.main import init_parser from mlia.utils.filesystem import get_supported_profile_names from mlia.utils.types import is_list_of @@ -155,11 +153,8 @@ def get_config_file() -> Path: def get_args_parser() -> Any: """Return MLIA argument parser.""" - common_parser = init_common_parser() - subcommand_parser = init_subcommand_parser(common_parser) - init_commands(subcommand_parser, get_commands()) - - return subcommand_parser + commands = get_commands() + return init_parser(commands) def replace_element(params: list[str], idx: int, value: str) -> list[str]: -- cgit v1.2.1