diff options
Diffstat (limited to 'tests')
-rw-r--r-- | tests/conftest.py | 4 | ||||
-rw-r--r-- | tests/test_cli_main.py | 75 | ||||
-rw-r--r-- | tests/test_core_context.py | 8 | ||||
-rw-r--r-- | tests/test_core_workflow.py | 6 | ||||
-rw-r--r-- | tests/test_target_cortex_a_advice_generation.py | 2 | ||||
-rw-r--r-- | tests/test_target_cortex_a_data_collection.py | 4 | ||||
-rw-r--r-- | tests/test_target_ethos_u_advice_generation.py | 4 | ||||
-rw-r--r-- | tests/test_target_tosa_advice_generation.py | 2 | ||||
-rw-r--r-- | tests/test_target_tosa_data_collection.py | 4 |
9 files changed, 46 insertions, 63 deletions
diff --git a/tests/conftest.py b/tests/conftest.py index e27acaf..b698a73 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """Pytest conf module.""" import shutil @@ -27,7 +27,7 @@ def fixture_test_resources_path() -> Path: @pytest.fixture(name="sample_context") def fixture_sample_context(tmpdir: str) -> ExecutionContext: """Return sample context fixture.""" - return ExecutionContext(working_dir=tmpdir) + return ExecutionContext(output_dir=tmpdir) @pytest.fixture(scope="session") diff --git a/tests/test_cli_main.py b/tests/test_cli_main.py index 9db5341..673031c 100644 --- a/tests/test_cli_main.py +++ b/tests/test_cli_main.py @@ -5,6 +5,7 @@ from __future__ import annotations import argparse from functools import wraps +from pathlib import Path from typing import Any from typing import Callable from unittest.mock import ANY @@ -18,6 +19,8 @@ from mlia.backend.errors import BackendUnavailableError from mlia.cli.main import backend_main from mlia.cli.main import CommandInfo from mlia.cli.main import main +from mlia.cli.options import add_output_directory +from mlia.core.context import ExecutionContext from mlia.core.errors import InternalError from tests.utils.logging import clear_loggers @@ -44,59 +47,16 @@ def test_option_version(capfd: pytest.CaptureFixture) -> None: assert stderr == "" -@pytest.mark.parametrize( - "is_default, expected_command_help", - [(True, "Test command [default]"), (False, "Test command")], -) -def test_command_info(is_default: bool, expected_command_help: str) -> None: +def test_command_info() -> None: """Test properties of CommandInfo object.""" def test_command() -> None: """Test command.""" - command_info = CommandInfo(test_command, ["test"], [], is_default) + command_info = CommandInfo(test_command, ["test"], []) assert command_info.command_name == "test_command" assert command_info.command_name_and_aliases == ["test_command", "test"] - assert command_info.command_help == expected_command_help - - -def test_default_command(monkeypatch: pytest.MonkeyPatch) -> None: - """Test adding default command.""" - - def mock_command(func_mock: MagicMock, name: str) -> Callable[..., None]: - """Mock cli command.""" - - def sample_cmd_1(*args: Any, **kwargs: Any) -> None: - """Sample command.""" - func_mock(*args, **kwargs) - - ret_func = sample_cmd_1 - ret_func.__name__ = name - - return ret_func - - non_default_command = MagicMock() - - def non_default_command_params(parser: argparse.ArgumentParser) -> None: - """Add parameters for non default command.""" - parser.add_argument("--param") - - monkeypatch.setattr( - "mlia.cli.main.get_commands", - MagicMock( - return_value=[ - CommandInfo( - func=mock_command(non_default_command, "non_default_command"), - aliases=["command2"], - opt_groups=[non_default_command_params], - is_default=False, - ), - ] - ), - ) - - main(["command2", "--param", "test"]) - non_default_command.assert_called_once_with(param="test") + assert command_info.command_help == "Test command" def wrap_mock_command(mock: MagicMock, command: Callable) -> Callable: @@ -293,6 +253,29 @@ def test_commands_execution( mock.assert_called_once_with(*expected_call.args, **expected_call.kwargs) +def test_passing_output_directory_parameter( + monkeypatch: pytest.MonkeyPatch, tmp_path: Path +) -> None: + """Test passing parameter --output-dir.""" + passed_context: ExecutionContext | None = None + + def sample_command(ctx: ExecutionContext) -> None: + """Sample command.""" + nonlocal passed_context + passed_context = ctx + + monkeypatch.setattr( + "mlia.cli.main.get_commands", + lambda: [CommandInfo(sample_command, [], [add_output_directory])], + ) + + output_dir = tmp_path / "output" + main(["sample_command", "--output-dir", output_dir.as_posix()]) + + assert passed_context is not None + assert passed_context.output_dir == output_dir + + @pytest.mark.parametrize( "params, expected_call", [ diff --git a/tests/test_core_context.py b/tests/test_core_context.py index 0e7145f..814cb6a 100644 --- a/tests/test_core_context.py +++ b/tests/test_core_context.py @@ -52,7 +52,7 @@ def test_execution_context(tmpdir: str) -> None: context = ExecutionContext( advice_category=category, config_parameters={"param": "value"}, - working_dir=tmpdir, + output_dir=tmpdir, event_handlers=[], event_publisher=publisher, verbose=True, @@ -72,14 +72,14 @@ def test_execution_context(tmpdir: str) -> None: assert context.output_format == "json" assert str(context) == ( f"ExecutionContext: " - f"working_dir={tmpdir}, " + f"output_dir={tmpdir}, " "advice_category={'COMPATIBILITY'}, " "config_parameters={'param': 'value'}, " "verbose=True, " "output_format=json" ) - context_with_default_params = ExecutionContext(working_dir=tmpdir) + context_with_default_params = ExecutionContext(output_dir=tmpdir) assert context_with_default_params.advice_category == {AdviceCategory.COMPATIBILITY} assert context_with_default_params.config_parameters is None assert context_with_default_params.event_handlers is None @@ -94,7 +94,7 @@ def test_execution_context(tmpdir: str) -> None: assert context_with_default_params.output_format == "plain_text" expected_str = ( - f"ExecutionContext: working_dir={tmpdir}, " + f"ExecutionContext: output_dir={tmpdir}, " "advice_category={'COMPATIBILITY'}, " "config_parameters=None, " "verbose=False, " diff --git a/tests/test_core_workflow.py b/tests/test_core_workflow.py index 470e572..d21ced8 100644 --- a/tests/test_core_workflow.py +++ b/tests/test_core_workflow.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """Tests for module workflow.""" from dataclasses import dataclass @@ -62,7 +62,7 @@ def test_workflow_executor(tmpdir: str) -> None: advice_producer_mock2.get_advice.return_value = [Advice(["Good advice!"])] context = ExecutionContext( - working_dir=tmpdir, + output_dir=tmpdir, event_handlers=[handler_mock], event_publisher=DefaultEventPublisher(), ) @@ -127,7 +127,7 @@ def test_workflow_executor_failed(tmpdir: str) -> None: handler_mock = MagicMock(spec=EventHandler) context = ExecutionContext( - working_dir=tmpdir, + output_dir=tmpdir, event_handlers=[handler_mock], event_publisher=DefaultEventPublisher(), ) diff --git a/tests/test_target_cortex_a_advice_generation.py b/tests/test_target_cortex_a_advice_generation.py index 1997c52..b9edbb5 100644 --- a/tests/test_target_cortex_a_advice_generation.py +++ b/tests/test_target_cortex_a_advice_generation.py @@ -189,7 +189,7 @@ def test_cortex_a_advice_producer( context = ExecutionContext( advice_category=advice_category, - working_dir=tmpdir, + output_dir=tmpdir, ) producer.set_context(context) diff --git a/tests/test_target_cortex_a_data_collection.py b/tests/test_target_cortex_a_data_collection.py index 7504166..d5f5a2d 100644 --- a/tests/test_target_cortex_a_data_collection.py +++ b/tests/test_target_cortex_a_data_collection.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """Tests for Cortex-A data collection module.""" from pathlib import Path @@ -22,7 +22,7 @@ def check_cortex_a_data_collection( MagicMock(return_value=CortexACompatibilityInfo(True, [])), ) - context = ExecutionContext(working_dir=tmpdir) + context = ExecutionContext(output_dir=tmpdir) collector = CortexAOperatorCompatibility(model) collector.set_context(context) diff --git a/tests/test_target_ethos_u_advice_generation.py b/tests/test_target_ethos_u_advice_generation.py index e93eeba..772fc56 100644 --- a/tests/test_target_ethos_u_advice_generation.py +++ b/tests/test_target_ethos_u_advice_generation.py @@ -370,7 +370,7 @@ def test_ethosu_advice_producer( context = ExecutionContext( advice_category=advice_category, - working_dir=tmpdir, + output_dir=tmpdir, action_resolver=action_resolver, ) @@ -475,7 +475,7 @@ def test_ethosu_static_advice_producer( context = ExecutionContext( advice_category=advice_category, - working_dir=tmpdir, + output_dir=tmpdir, action_resolver=action_resolver, ) producer.set_context(context) diff --git a/tests/test_target_tosa_advice_generation.py b/tests/test_target_tosa_advice_generation.py index d5ebbd7..61b74cd 100644 --- a/tests/test_target_tosa_advice_generation.py +++ b/tests/test_target_tosa_advice_generation.py @@ -47,7 +47,7 @@ def test_tosa_advice_producer( context = ExecutionContext( advice_category=advice_category, - working_dir=tmpdir, + output_dir=tmpdir, ) producer.set_context(context) diff --git a/tests/test_target_tosa_data_collection.py b/tests/test_target_tosa_data_collection.py index 9d590ca..75192fa 100644 --- a/tests/test_target_tosa_data_collection.py +++ b/tests/test_target_tosa_data_collection.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """Tests for TOSA data collection module.""" from pathlib import Path @@ -19,7 +19,7 @@ def test_tosa_data_collection( "mlia.target.tosa.data_collection.get_tosa_compatibility_info", MagicMock(return_value=TOSACompatibilityInfo(True, [])), ) - context = ExecutionContext(working_dir=tmpdir) + context = ExecutionContext(output_dir=tmpdir) collector = TOSAOperatorCompatibility(test_tflite_model) collector.set_context(context) |