aboutsummaryrefslogtreecommitdiff
path: root/tests/mlia/test_cli_main.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/mlia/test_cli_main.py')
-rw-r--r--tests/mlia/test_cli_main.py357
1 files changed, 357 insertions, 0 deletions
diff --git a/tests/mlia/test_cli_main.py b/tests/mlia/test_cli_main.py
new file mode 100644
index 0000000..a0937d5
--- /dev/null
+++ b/tests/mlia/test_cli_main.py
@@ -0,0 +1,357 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Tests for main module."""
+import argparse
+from functools import wraps
+from pathlib import Path
+from typing import Any
+from typing import Callable
+from typing import List
+from unittest.mock import ANY
+from unittest.mock import call
+from unittest.mock import MagicMock
+
+import pytest
+
+import mlia
+from mlia.cli.main import CommandInfo
+from mlia.cli.main import main
+from mlia.core.context import ExecutionContext
+from tests.mlia.utils.logging import clear_loggers
+
+
+def teardown_function() -> None:
+ """Perform action after test completion.
+
+ This function is launched automatically by pytest after each test
+ in this module.
+ """
+ clear_loggers()
+
+
+def test_option_version(capfd: pytest.CaptureFixture) -> None:
+ """Test --version."""
+ with pytest.raises(SystemExit) as ex:
+ main(["--version"])
+
+ assert ex.type == SystemExit
+ assert ex.value.code == 0
+
+ stdout, stderr = capfd.readouterr()
+ assert len(stdout.splitlines()) == 1
+ assert stderr == ""
+
+
+@pytest.mark.parametrize(
+ "is_default, expected_command_help",
+ [(True, "Test command [default]"), (False, "Test command")],
+)
+def test_command_info(is_default: bool, expected_command_help: str) -> None:
+ """Test properties of CommandInfo object."""
+
+ def test_command() -> None:
+ """Test command."""
+
+ command_info = CommandInfo(test_command, ["test"], [], is_default)
+ assert command_info.command_name == "test_command"
+ assert command_info.command_name_and_aliases == ["test_command", "test"]
+ assert command_info.command_help == expected_command_help
+
+
+def test_default_command(monkeypatch: pytest.MonkeyPatch, tmp_path: Path) -> None:
+ """Test adding default command."""
+
+ def mock_command(
+ func_mock: MagicMock, name: str, with_working_dir: bool
+ ) -> Callable[..., None]:
+ """Mock cli command."""
+
+ def sample_cmd_1(*args: Any, **kwargs: Any) -> None:
+ """Sample command."""
+ func_mock(*args, **kwargs)
+
+ def sample_cmd_2(ctx: ExecutionContext, **kwargs: Any) -> None:
+ """Another sample command."""
+ func_mock(ctx=ctx, **kwargs)
+
+ ret_func = sample_cmd_2 if with_working_dir else sample_cmd_1
+ ret_func.__name__ = name
+
+ return ret_func # type: ignore
+
+ default_command = MagicMock()
+ non_default_command = MagicMock()
+
+ def default_command_params(parser: argparse.ArgumentParser) -> None:
+ """Add parameters for default command."""
+ parser.add_argument("--sample")
+ parser.add_argument("--default_arg", default="123")
+
+ def non_default_command_params(parser: argparse.ArgumentParser) -> None:
+ """Add parameters for non default command."""
+ parser.add_argument("--param")
+
+ monkeypatch.setattr(
+ "mlia.cli.main.get_commands",
+ MagicMock(
+ return_value=[
+ CommandInfo(
+ func=mock_command(default_command, "default_command", True),
+ aliases=["command1"],
+ opt_groups=[default_command_params],
+ is_default=True,
+ ),
+ CommandInfo(
+ func=mock_command(
+ non_default_command, "non_default_command", False
+ ),
+ aliases=["command2"],
+ opt_groups=[non_default_command_params],
+ is_default=False,
+ ),
+ ]
+ ),
+ )
+
+ tmp_working_dir = str(tmp_path)
+ main(["--working-dir", tmp_working_dir, "--sample", "1"])
+ main(["command2", "--param", "test"])
+
+ default_command.assert_called_once_with(ctx=ANY, sample="1", default_arg="123")
+ non_default_command.assert_called_once_with(param="test")
+
+
+@pytest.mark.parametrize(
+ "params, expected_call",
+ [
+ [
+ ["operators", "sample_model.tflite"],
+ call(
+ ctx=ANY,
+ target_profile="ethos-u55-256",
+ model="sample_model.tflite",
+ output=None,
+ supported_ops_report=False,
+ ),
+ ],
+ [
+ ["ops", "sample_model.tflite"],
+ call(
+ ctx=ANY,
+ target_profile="ethos-u55-256",
+ model="sample_model.tflite",
+ output=None,
+ supported_ops_report=False,
+ ),
+ ],
+ [
+ ["operators", "sample_model.tflite", "--target-profile", "ethos-u55-128"],
+ call(
+ ctx=ANY,
+ target_profile="ethos-u55-128",
+ model="sample_model.tflite",
+ output=None,
+ supported_ops_report=False,
+ ),
+ ],
+ [
+ ["operators"],
+ call(
+ ctx=ANY,
+ target_profile="ethos-u55-256",
+ model=None,
+ output=None,
+ supported_ops_report=False,
+ ),
+ ],
+ [
+ ["operators", "--supported-ops-report"],
+ call(
+ ctx=ANY,
+ target_profile="ethos-u55-256",
+ model=None,
+ output=None,
+ supported_ops_report=True,
+ ),
+ ],
+ [
+ [
+ "all_tests",
+ "sample_model.h5",
+ "--optimization-type",
+ "pruning",
+ "--optimization-target",
+ "0.5",
+ ],
+ call(
+ ctx=ANY,
+ target_profile="ethos-u55-256",
+ model="sample_model.h5",
+ optimization_type="pruning",
+ optimization_target="0.5",
+ output=None,
+ evaluate_on=["Vela"],
+ ),
+ ],
+ [
+ ["sample_model.h5"],
+ call(
+ ctx=ANY,
+ target_profile="ethos-u55-256",
+ model="sample_model.h5",
+ optimization_type="pruning,clustering",
+ optimization_target="0.5,32",
+ output=None,
+ evaluate_on=["Vela"],
+ ),
+ ],
+ [
+ ["performance", "sample_model.h5", "--output", "result.json"],
+ call(
+ ctx=ANY,
+ target_profile="ethos-u55-256",
+ model="sample_model.h5",
+ output="result.json",
+ evaluate_on=["Vela"],
+ ),
+ ],
+ [
+ ["perf", "sample_model.h5", "--target-profile", "ethos-u55-128"],
+ call(
+ ctx=ANY,
+ target_profile="ethos-u55-128",
+ model="sample_model.h5",
+ output=None,
+ evaluate_on=["Vela"],
+ ),
+ ],
+ [
+ ["optimization", "sample_model.h5"],
+ call(
+ ctx=ANY,
+ target_profile="ethos-u55-256",
+ model="sample_model.h5",
+ optimization_type="pruning,clustering",
+ optimization_target="0.5,32",
+ output=None,
+ evaluate_on=["Vela"],
+ ),
+ ],
+ [
+ ["optimization", "sample_model.h5", "--evaluate-on", "some_backend"],
+ call(
+ ctx=ANY,
+ target_profile="ethos-u55-256",
+ model="sample_model.h5",
+ optimization_type="pruning,clustering",
+ optimization_target="0.5,32",
+ output=None,
+ evaluate_on=["some_backend"],
+ ),
+ ],
+ ],
+)
+def test_commands_execution(
+ monkeypatch: pytest.MonkeyPatch, params: List[str], expected_call: Any
+) -> None:
+ """Test calling commands from the main function."""
+ mock = MagicMock()
+
+ def wrap_mock_command(command: Callable) -> Callable:
+ """Wrap the command with the mock."""
+
+ @wraps(command)
+ def mock_command(*args: Any, **kwargs: Any) -> Any:
+ """Mock the command."""
+ mock(*args, **kwargs)
+
+ return mock_command
+
+ monkeypatch.setattr(
+ "mlia.cli.options.get_default_backends", MagicMock(return_value=["Vela"])
+ )
+
+ monkeypatch.setattr(
+ "mlia.cli.options.get_available_backends",
+ MagicMock(return_value=["Vela", "some_backend"]),
+ )
+
+ for command in ["all_tests", "operators", "performance", "optimization"]:
+ monkeypatch.setattr(
+ f"mlia.cli.main.{command}",
+ wrap_mock_command(getattr(mlia.cli.main, command)),
+ )
+
+ main(params)
+
+ mock.assert_called_once_with(*expected_call.args, **expected_call.kwargs)
+
+
+@pytest.mark.parametrize(
+ "verbose, exc_mock, expected_output",
+ [
+ [
+ True,
+ MagicMock(side_effect=Exception("Error")),
+ [
+ "Execution finished with error: Error",
+ f"Please check the log files in the {Path.cwd()/'mlia_output/logs'} "
+ "for more details",
+ ],
+ ],
+ [
+ False,
+ MagicMock(side_effect=Exception("Error")),
+ [
+ "Execution finished with error: Error",
+ f"Please check the log files in the {Path.cwd()/'mlia_output/logs'} "
+ "for more details, or enable verbose mode",
+ ],
+ ],
+ [
+ False,
+ MagicMock(side_effect=KeyboardInterrupt()),
+ ["Execution has been interrupted"],
+ ],
+ ],
+)
+def test_verbose_output(
+ monkeypatch: pytest.MonkeyPatch,
+ capsys: pytest.CaptureFixture,
+ verbose: bool,
+ exc_mock: MagicMock,
+ expected_output: List[str],
+) -> None:
+ """Test flag --verbose."""
+
+ def command_params(parser: argparse.ArgumentParser) -> None:
+ """Add parameters for non default command."""
+ parser.add_argument("--verbose", action="store_true")
+
+ def command() -> None:
+ """Run test command."""
+ exc_mock()
+
+ monkeypatch.setattr(
+ "mlia.cli.main.get_commands",
+ MagicMock(
+ return_value=[
+ CommandInfo(
+ func=command,
+ aliases=["command"],
+ opt_groups=[command_params],
+ ),
+ ]
+ ),
+ )
+
+ params = ["command"]
+ if verbose:
+ params.append("--verbose")
+
+ exit_code = main(params)
+ assert exit_code == 1
+
+ stdout, _ = capsys.readouterr()
+ for expected_message in expected_output:
+ assert expected_message in stdout