aboutsummaryrefslogtreecommitdiff
path: root/tests/test_cli_options.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/test_cli_options.py')
-rw-r--r--tests/test_cli_options.py179
1 files changed, 107 insertions, 72 deletions
diff --git a/tests/test_cli_options.py b/tests/test_cli_options.py
index d75f7c0..a889a93 100644
--- a/tests/test_cli_options.py
+++ b/tests/test_cli_options.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Tests for module options."""
from __future__ import annotations
@@ -13,14 +13,19 @@ import pytest
from mlia.cli.options import add_output_options
from mlia.cli.options import get_target_profile_opts
from mlia.cli.options import parse_optimization_parameters
+from mlia.cli.options import parse_output_parameters
+from mlia.core.common import FormattedFilePath
@pytest.mark.parametrize(
- "optimization_type, optimization_target, expected_error, expected_result",
+ "pruning, clustering, pruning_target, clustering_target, expected_error,"
+ "expected_result",
[
- (
- "pruning",
- "0.5",
+ [
+ False,
+ False,
+ None,
+ None,
does_not_raise(),
[
dict(
@@ -29,39 +34,40 @@ from mlia.cli.options import parse_optimization_parameters
layers_to_optimize=None,
)
],
- ),
- (
- "clustering",
- "32",
+ ],
+ [
+ True,
+ False,
+ None,
+ None,
does_not_raise(),
[
dict(
- optimization_type="clustering",
- optimization_target=32.0,
+ optimization_type="pruning",
+ optimization_target=0.5,
layers_to_optimize=None,
)
],
- ),
- (
- "pruning,clustering",
- "0.5,32",
+ ],
+ [
+ False,
+ True,
+ None,
+ None,
does_not_raise(),
[
dict(
- optimization_type="pruning",
- optimization_target=0.5,
- layers_to_optimize=None,
- ),
- dict(
optimization_type="clustering",
- optimization_target=32.0,
+ optimization_target=32,
layers_to_optimize=None,
- ),
+ )
],
- ),
- (
- "pruning, clustering",
- "0.5, 32",
+ ],
+ [
+ True,
+ True,
+ None,
+ None,
does_not_raise(),
[
dict(
@@ -71,50 +77,66 @@ from mlia.cli.options import parse_optimization_parameters
),
dict(
optimization_type="clustering",
- optimization_target=32.0,
+ optimization_target=32,
layers_to_optimize=None,
),
],
- ),
- (
- "pruning,clustering",
- "0.5",
- pytest.raises(
- Exception, match="Wrong number of optimization targets and types"
- ),
- None,
- ),
- (
- "",
- "0.5",
- pytest.raises(Exception, match="Optimization type is not provided"),
+ ],
+ [
+ False,
+ False,
+ 0.4,
None,
- ),
- (
- "pruning,clustering",
- "",
- pytest.raises(Exception, match="Optimization target is not provided"),
+ does_not_raise(),
+ [
+ dict(
+ optimization_type="pruning",
+ optimization_target=0.4,
+ layers_to_optimize=None,
+ )
+ ],
+ ],
+ [
+ False,
+ False,
None,
- ),
- (
- "pruning,",
- "0.5,abc",
+ 32,
pytest.raises(
- Exception, match="Non numeric value for the optimization target"
+ argparse.ArgumentError,
+ match="To enable clustering optimization you need to include "
+ "the `--clustering` flag in your command.",
),
None,
- ),
+ ],
+ [
+ False,
+ True,
+ None,
+ 32.2,
+ does_not_raise(),
+ [
+ dict(
+ optimization_type="clustering",
+ optimization_target=32.2,
+ layers_to_optimize=None,
+ )
+ ],
+ ],
],
)
def test_parse_optimization_parameters(
- optimization_type: str,
- optimization_target: str,
+ pruning: bool,
+ clustering: bool,
+ pruning_target: float | None,
+ clustering_target: int | None,
expected_error: Any,
expected_result: Any,
) -> None:
"""Test function parse_optimization_parameters."""
with expected_error:
- result = parse_optimization_parameters(optimization_type, optimization_target)
+ result = parse_optimization_parameters(
+ pruning, clustering, pruning_target, clustering_target
+ )
assert result == expected_result
@@ -155,28 +177,41 @@ def test_output_options(output_parameters: list[str], expected_path: str) -> Non
add_output_options(parser)
args = parser.parse_args(output_parameters)
- assert args.output == expected_path
+ assert str(args.output) == expected_path
@pytest.mark.parametrize(
- "output_filename",
+ "path, json, expected_error, output",
[
- "report.txt",
- "report.TXT",
- "report",
- "report.pdf",
+ [
+ None,
+ True,
+ pytest.raises(
+ argparse.ArgumentError,
+ match=r"To enable JSON output you need to specify the output path. "
+ r"\(e.g. --output out.json --json\)",
+ ),
+ None,
+ ],
+ [None, False, does_not_raise(), None],
+ [
+ Path("test_path"),
+ False,
+ does_not_raise(),
+ FormattedFilePath(Path("test_path"), "plain_text"),
+ ],
+ [
+ Path("test_path"),
+ True,
+ does_not_raise(),
+ FormattedFilePath(Path("test_path"), "json"),
+ ],
],
)
-def test_output_options_bad_parameters(
- output_filename: str, capsys: pytest.CaptureFixture
+def test_parse_output_parameters(
+ path: Path | None, json: bool, expected_error: Any, output: FormattedFilePath | None
) -> None:
- """Test that args parsing should fail if format is not supported."""
- parser = argparse.ArgumentParser()
- add_output_options(parser)
-
- with pytest.raises(SystemExit):
- parser.parse_args(["--output", output_filename])
-
- err_output = capsys.readouterr().err
- suffix = Path(output_filename).suffix[1:]
- assert f"Unsupported format '{suffix}'" in err_output
+ """Test parsing for output parameters."""
+ with expected_error:
+ formatted_output = parse_output_parameters(path, json)
+ assert formatted_output == output