diff options
Diffstat (limited to 'tests/test_cli_options.py')
-rw-r--r-- | tests/test_cli_options.py | 179 |
1 files changed, 107 insertions, 72 deletions
diff --git a/tests/test_cli_options.py b/tests/test_cli_options.py index d75f7c0..a889a93 100644 --- a/tests/test_cli_options.py +++ b/tests/test_cli_options.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """Tests for module options.""" from __future__ import annotations @@ -13,14 +13,19 @@ import pytest from mlia.cli.options import add_output_options from mlia.cli.options import get_target_profile_opts from mlia.cli.options import parse_optimization_parameters +from mlia.cli.options import parse_output_parameters +from mlia.core.common import FormattedFilePath @pytest.mark.parametrize( - "optimization_type, optimization_target, expected_error, expected_result", + "pruning, clustering, pruning_target, clustering_target, expected_error," + "expected_result", [ - ( - "pruning", - "0.5", + [ + False, + False, + None, + None, does_not_raise(), [ dict( @@ -29,39 +34,40 @@ from mlia.cli.options import parse_optimization_parameters layers_to_optimize=None, ) ], - ), - ( - "clustering", - "32", + ], + [ + True, + False, + None, + None, does_not_raise(), [ dict( - optimization_type="clustering", - optimization_target=32.0, + optimization_type="pruning", + optimization_target=0.5, layers_to_optimize=None, ) ], - ), - ( - "pruning,clustering", - "0.5,32", + ], + [ + False, + True, + None, + None, does_not_raise(), [ dict( - optimization_type="pruning", - optimization_target=0.5, - layers_to_optimize=None, - ), - dict( optimization_type="clustering", - optimization_target=32.0, + optimization_target=32, layers_to_optimize=None, - ), + ) ], - ), - ( - "pruning, clustering", - "0.5, 32", + ], + [ + True, + True, + None, + None, does_not_raise(), [ dict( @@ -71,50 +77,66 @@ from mlia.cli.options import parse_optimization_parameters ), dict( optimization_type="clustering", - optimization_target=32.0, + optimization_target=32, layers_to_optimize=None, ), ], - ), - ( - "pruning,clustering", - "0.5", - pytest.raises( - Exception, match="Wrong number of optimization targets and types" - ), - None, - ), - ( - "", - "0.5", - pytest.raises(Exception, match="Optimization type is not provided"), + ], + [ + False, + False, + 0.4, None, - ), - ( - "pruning,clustering", - "", - pytest.raises(Exception, match="Optimization target is not provided"), + does_not_raise(), + [ + dict( + optimization_type="pruning", + optimization_target=0.4, + layers_to_optimize=None, + ) + ], + ], + [ + False, + False, None, - ), - ( - "pruning,", - "0.5,abc", + 32, pytest.raises( - Exception, match="Non numeric value for the optimization target" + argparse.ArgumentError, + match="To enable clustering optimization you need to include " + "the `--clustering` flag in your command.", ), None, - ), + ], + [ + False, + True, + None, + 32.2, + does_not_raise(), + [ + dict( + optimization_type="clustering", + optimization_target=32.2, + layers_to_optimize=None, + ) + ], + ], ], ) def test_parse_optimization_parameters( - optimization_type: str, - optimization_target: str, + pruning: bool, + clustering: bool, + pruning_target: float | None, + clustering_target: int | None, expected_error: Any, expected_result: Any, ) -> None: """Test function parse_optimization_parameters.""" with expected_error: - result = parse_optimization_parameters(optimization_type, optimization_target) + result = parse_optimization_parameters( + pruning, clustering, pruning_target, clustering_target + ) assert result == expected_result @@ -155,28 +177,41 @@ def test_output_options(output_parameters: list[str], expected_path: str) -> Non add_output_options(parser) args = parser.parse_args(output_parameters) - assert args.output == expected_path + assert str(args.output) == expected_path @pytest.mark.parametrize( - "output_filename", + "path, json, expected_error, output", [ - "report.txt", - "report.TXT", - "report", - "report.pdf", + [ + None, + True, + pytest.raises( + argparse.ArgumentError, + match=r"To enable JSON output you need to specify the output path. " + r"\(e.g. --output out.json --json\)", + ), + None, + ], + [None, False, does_not_raise(), None], + [ + Path("test_path"), + False, + does_not_raise(), + FormattedFilePath(Path("test_path"), "plain_text"), + ], + [ + Path("test_path"), + True, + does_not_raise(), + FormattedFilePath(Path("test_path"), "json"), + ], ], ) -def test_output_options_bad_parameters( - output_filename: str, capsys: pytest.CaptureFixture +def test_parse_output_parameters( + path: Path | None, json: bool, expected_error: Any, output: FormattedFilePath | None ) -> None: - """Test that args parsing should fail if format is not supported.""" - parser = argparse.ArgumentParser() - add_output_options(parser) - - with pytest.raises(SystemExit): - parser.parse_args(["--output", output_filename]) - - err_output = capsys.readouterr().err - suffix = Path(output_filename).suffix[1:] - assert f"Unsupported format '{suffix}'" in err_output + """Test parsing for output parameters.""" + with expected_error: + formatted_output = parse_output_parameters(path, json) + assert formatted_output == output |