aboutsummaryrefslogtreecommitdiff
path: root/tests/test_cli_main.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/test_cli_main.py')
-rw-r--r--tests/test_cli_main.py228
1 files changed, 118 insertions, 110 deletions
diff --git a/tests/test_cli_main.py b/tests/test_cli_main.py
index 925f1e4..5a9c0c9 100644
--- a/tests/test_cli_main.py
+++ b/tests/test_cli_main.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Tests for main module."""
from __future__ import annotations
@@ -19,7 +19,6 @@ from mlia.backend.errors import BackendUnavailableError
from mlia.cli.main import backend_main
from mlia.cli.main import CommandInfo
from mlia.cli.main import main
-from mlia.core.context import ExecutionContext
from mlia.core.errors import InternalError
from tests.utils.logging import clear_loggers
@@ -62,35 +61,23 @@ def test_command_info(is_default: bool, expected_command_help: str) -> None:
assert command_info.command_help == expected_command_help
-def test_default_command(monkeypatch: pytest.MonkeyPatch, tmp_path: Path) -> None:
+def test_default_command(monkeypatch: pytest.MonkeyPatch) -> None:
"""Test adding default command."""
- def mock_command(
- func_mock: MagicMock, name: str, with_working_dir: bool
- ) -> Callable[..., None]:
+ def mock_command(func_mock: MagicMock, name: str) -> Callable[..., None]:
"""Mock cli command."""
def sample_cmd_1(*args: Any, **kwargs: Any) -> None:
"""Sample command."""
func_mock(*args, **kwargs)
- def sample_cmd_2(ctx: ExecutionContext, **kwargs: Any) -> None:
- """Another sample command."""
- func_mock(ctx=ctx, **kwargs)
-
- ret_func = sample_cmd_2 if with_working_dir else sample_cmd_1
+ ret_func = sample_cmd_1
ret_func.__name__ = name
- return ret_func # type: ignore
+ return ret_func
- default_command = MagicMock()
non_default_command = MagicMock()
- def default_command_params(parser: argparse.ArgumentParser) -> None:
- """Add parameters for default command."""
- parser.add_argument("--sample")
- parser.add_argument("--default_arg", default="123")
-
def non_default_command_params(parser: argparse.ArgumentParser) -> None:
"""Add parameters for non default command."""
parser.add_argument("--param")
@@ -100,15 +87,7 @@ def test_default_command(monkeypatch: pytest.MonkeyPatch, tmp_path: Path) -> Non
MagicMock(
return_value=[
CommandInfo(
- func=mock_command(default_command, "default_command", True),
- aliases=["command1"],
- opt_groups=[default_command_params],
- is_default=True,
- ),
- CommandInfo(
- func=mock_command(
- non_default_command, "non_default_command", False
- ),
+ func=mock_command(non_default_command, "non_default_command"),
aliases=["command2"],
opt_groups=[non_default_command_params],
is_default=False,
@@ -117,11 +96,7 @@ def test_default_command(monkeypatch: pytest.MonkeyPatch, tmp_path: Path) -> Non
),
)
- tmp_working_dir = str(tmp_path)
- main(["--working-dir", tmp_working_dir, "--sample", "1"])
main(["command2", "--param", "test"])
-
- default_command.assert_called_once_with(ctx=ANY, sample="1", default_arg="123")
non_default_command.assert_called_once_with(param="test")
@@ -140,134 +115,168 @@ def wrap_mock_command(mock: MagicMock, command: Callable) -> Callable:
"params, expected_call",
[
[
- ["operators", "sample_model.tflite"],
+ ["check", "sample_model.tflite", "--target-profile", "ethos-u55-256"],
call(
ctx=ANY,
target_profile="ethos-u55-256",
model="sample_model.tflite",
+ compatibility=False,
+ performance=False,
output=None,
- supported_ops_report=False,
+ json=False,
+ backend=None,
),
],
[
- ["ops", "sample_model.tflite"],
- call(
- ctx=ANY,
- target_profile="ethos-u55-256",
- model="sample_model.tflite",
- output=None,
- supported_ops_report=False,
- ),
- ],
- [
- ["operators", "sample_model.tflite", "--target-profile", "ethos-u55-128"],
+ ["check", "sample_model.tflite", "--target-profile", "ethos-u55-128"],
call(
ctx=ANY,
target_profile="ethos-u55-128",
model="sample_model.tflite",
+ compatibility=False,
+ performance=False,
output=None,
- supported_ops_report=False,
+ json=False,
+ backend=None,
),
],
[
- ["operators"],
- call(
- ctx=ANY,
- target_profile="ethos-u55-256",
- model=None,
- output=None,
- supported_ops_report=False,
- ),
- ],
- [
- ["operators", "--supported-ops-report"],
+ [
+ "check",
+ "sample_model.h5",
+ "--performance",
+ "--compatibility",
+ "--target-profile",
+ "ethos-u55-256",
+ ],
call(
ctx=ANY,
target_profile="ethos-u55-256",
- model=None,
+ model="sample_model.h5",
output=None,
- supported_ops_report=True,
+ json=False,
+ compatibility=True,
+ performance=True,
+ backend=None,
),
],
[
[
- "all_tests",
+ "check",
"sample_model.h5",
- "--optimization-type",
- "pruning",
- "--optimization-target",
- "0.5",
+ "--performance",
+ "--target-profile",
+ "ethos-u55-256",
+ "--output",
+ "result.json",
+ "--json",
],
call(
ctx=ANY,
target_profile="ethos-u55-256",
model="sample_model.h5",
- optimization_type="pruning",
- optimization_target="0.5",
- output=None,
- evaluate_on=["Vela"],
+ performance=True,
+ compatibility=False,
+ output=Path("result.json"),
+ json=True,
+ backend=None,
),
],
[
- ["sample_model.h5"],
+ [
+ "check",
+ "sample_model.h5",
+ "--performance",
+ "--target-profile",
+ "ethos-u55-128",
+ ],
call(
ctx=ANY,
- target_profile="ethos-u55-256",
+ target_profile="ethos-u55-128",
model="sample_model.h5",
- optimization_type="pruning,clustering",
- optimization_target="0.5,32",
+ compatibility=False,
+ performance=True,
output=None,
- evaluate_on=["Vela"],
+ json=False,
+ backend=None,
),
],
[
- ["performance", "sample_model.h5", "--output", "result.json"],
+ [
+ "optimize",
+ "sample_model.h5",
+ "--target-profile",
+ "ethos-u55-256",
+ "--pruning",
+ "--clustering",
+ ],
call(
ctx=ANY,
target_profile="ethos-u55-256",
model="sample_model.h5",
- output="result.json",
- evaluate_on=["Vela"],
- ),
- ],
- [
- ["perf", "sample_model.h5", "--target-profile", "ethos-u55-128"],
- call(
- ctx=ANY,
- target_profile="ethos-u55-128",
- model="sample_model.h5",
+ pruning=True,
+ clustering=True,
+ pruning_target=None,
+ clustering_target=None,
output=None,
- evaluate_on=["Vela"],
+ json=False,
+ backend=None,
),
],
[
- ["optimization", "sample_model.h5"],
+ [
+ "optimize",
+ "sample_model.h5",
+ "--target-profile",
+ "ethos-u55-256",
+ "--pruning",
+ "--clustering",
+ "--pruning-target",
+ "0.5",
+ "--clustering-target",
+ "32",
+ ],
call(
ctx=ANY,
target_profile="ethos-u55-256",
model="sample_model.h5",
- optimization_type="pruning,clustering",
- optimization_target="0.5,32",
+ pruning=True,
+ clustering=True,
+ pruning_target=0.5,
+ clustering_target=32,
output=None,
- evaluate_on=["Vela"],
+ json=False,
+ backend=None,
),
],
[
- ["optimization", "sample_model.h5", "--evaluate-on", "some_backend"],
+ [
+ "optimize",
+ "sample_model.h5",
+ "--target-profile",
+ "ethos-u55-256",
+ "--pruning",
+ "--backend",
+ "some_backend",
+ ],
call(
ctx=ANY,
target_profile="ethos-u55-256",
model="sample_model.h5",
- optimization_type="pruning,clustering",
- optimization_target="0.5,32",
+ pruning=True,
+ clustering=False,
+ pruning_target=None,
+ clustering_target=None,
output=None,
- evaluate_on=["some_backend"],
+ json=False,
+ backend=["some_backend"],
),
],
[
[
- "operators",
+ "check",
"sample_model.h5",
+ "--compatibility",
"--target-profile",
"cortex-a",
],
@@ -275,8 +284,11 @@ def wrap_mock_command(mock: MagicMock, command: Callable) -> Callable:
ctx=ANY,
target_profile="cortex-a",
model="sample_model.h5",
+ compatibility=True,
+ performance=False,
output=None,
- supported_ops_report=False,
+ json=False,
+ backend=None,
),
],
],
@@ -288,15 +300,11 @@ def test_commands_execution(
mock = MagicMock()
monkeypatch.setattr(
- "mlia.cli.options.get_default_backends", MagicMock(return_value=["Vela"])
- )
-
- monkeypatch.setattr(
"mlia.cli.options.get_available_backends",
MagicMock(return_value=["Vela", "some_backend"]),
)
- for command in ["all_tests", "operators", "performance", "optimization"]:
+ for command in ["check", "optimize"]:
monkeypatch.setattr(
f"mlia.cli.main.{command}",
wrap_mock_command(mock, getattr(mlia.cli.main, command)),
@@ -335,15 +343,15 @@ def test_commands_execution_backend_main(
@pytest.mark.parametrize(
- "verbose, exc_mock, expected_output",
+ "debug, exc_mock, expected_output",
[
[
True,
MagicMock(side_effect=Exception("Error")),
[
"Execution finished with error: Error",
- f"Please check the log files in the {Path.cwd()/'mlia_output/logs'} "
- "for more details",
+ "Please check the log files in the /tmp/mlia-",
+ "/logs for more details",
],
],
[
@@ -351,8 +359,8 @@ def test_commands_execution_backend_main(
MagicMock(side_effect=Exception("Error")),
[
"Execution finished with error: Error",
- f"Please check the log files in the {Path.cwd()/'mlia_output/logs'} "
- "for more details, or enable verbose mode (--verbose)",
+ "Please check the log files in the /tmp/mlia-",
+ "/logs for more details, or enable debug mode (--debug)",
],
],
[
@@ -389,18 +397,18 @@ def test_commands_execution_backend_main(
],
],
)
-def test_verbose_output(
+def test_debug_output(
monkeypatch: pytest.MonkeyPatch,
capsys: pytest.CaptureFixture,
- verbose: bool,
+ debug: bool,
exc_mock: MagicMock,
expected_output: list[str],
) -> None:
- """Test flag --verbose."""
+ """Test flag --debug."""
def command_params(parser: argparse.ArgumentParser) -> None:
"""Add parameters for non default command."""
- parser.add_argument("--verbose", action="store_true")
+ parser.add_argument("--debug", action="store_true")
def command() -> None:
"""Run test command."""
@@ -420,8 +428,8 @@ def test_verbose_output(
)
params = ["command"]
- if verbose:
- params.append("--verbose")
+ if debug:
+ params.append("--debug")
exit_code = main(params)
assert exit_code == 1