diff options
author | Benjamin Klimczak <benjamin.klimczak@arm.com> | 2022-07-11 12:33:42 +0100 |
---|---|---|
committer | Benjamin Klimczak <benjamin.klimczak@arm.com> | 2022-07-26 14:08:21 +0100 |
commit | 5d81f37de09efe10f90512e50252be9c36925fcf (patch) | |
tree | b4d7cdfd051da0a6e882bdfcf280fd7ca7b39e57 /tests/test_cli_options.py | |
parent | 7899b908c1fe6d86b92a80f3827ddd0ac05b674b (diff) | |
download | mlia-5d81f37de09efe10f90512e50252be9c36925fcf.tar.gz |
MLIA-551 Rework remains of AIET architecture
Re-factoring the code base to further merge the old AIET code into MLIA.
- Remove last traces of the backend type 'tool'
- Controlled systems removed, including SSH protocol, controller,
RunningCommand, locks etc.
- Build command / build dir and deploy functionality removed from
Applications and Systems
- Moving working_dir()
- Replace module 'output_parser' with new module 'output_consumer' and
merge Base64 parsing into it
- Change the output consumption to optionally remove (i.e. actually
consume) lines
- Use Base64 parsing in GenericInferenceOutputParser, replacing the
regex-based parsing and remove the now unused regex parsing
- Remove AIET reporting
- Pre-install applications by moving them to src/mlia/resources/backends
- Rename aiet-config.json to backend-config.json
- Move tests from tests/mlia/ to tests/
- Adapt unit tests to code changes
- Dependencies removed: paramiko, filelock, psutil
- Fix bug in corstone.py: The wrong resource directory was used which
broke the functionality to download backends.
- Use f-string formatting.
- Use logging instead of print.
Change-Id: I768bc3bb6b2eda57d219ad01be4a8e0a74167d76
Diffstat (limited to 'tests/test_cli_options.py')
-rw-r--r-- | tests/test_cli_options.py | 186 |
1 files changed, 186 insertions, 0 deletions
diff --git a/tests/test_cli_options.py b/tests/test_cli_options.py new file mode 100644 index 0000000..a441e58 --- /dev/null +++ b/tests/test_cli_options.py @@ -0,0 +1,186 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Tests for module options.""" +import argparse +from contextlib import ExitStack as does_not_raise +from pathlib import Path +from typing import Any +from typing import Dict +from typing import List +from typing import Optional + +import pytest + +from mlia.cli.options import add_output_options +from mlia.cli.options import get_target_profile_opts +from mlia.cli.options import parse_optimization_parameters + + +@pytest.mark.parametrize( + "optimization_type, optimization_target, expected_error, expected_result", + [ + ( + "pruning", + "0.5", + does_not_raise(), + [ + dict( + optimization_type="pruning", + optimization_target=0.5, + layers_to_optimize=None, + ) + ], + ), + ( + "clustering", + "32", + does_not_raise(), + [ + dict( + optimization_type="clustering", + optimization_target=32.0, + layers_to_optimize=None, + ) + ], + ), + ( + "pruning,clustering", + "0.5,32", + does_not_raise(), + [ + dict( + optimization_type="pruning", + optimization_target=0.5, + layers_to_optimize=None, + ), + dict( + optimization_type="clustering", + optimization_target=32.0, + layers_to_optimize=None, + ), + ], + ), + ( + "pruning, clustering", + "0.5, 32", + does_not_raise(), + [ + dict( + optimization_type="pruning", + optimization_target=0.5, + layers_to_optimize=None, + ), + dict( + optimization_type="clustering", + optimization_target=32.0, + layers_to_optimize=None, + ), + ], + ), + ( + "pruning,clustering", + "0.5", + pytest.raises( + Exception, match="Wrong number of optimization targets and types" + ), + None, + ), + ( + "", + "0.5", + pytest.raises(Exception, match="Optimization type is not provided"), + None, + ), + ( + "pruning,clustering", + "", + pytest.raises(Exception, match="Optimization target is not provided"), + None, + ), + ( + "pruning,", + "0.5,abc", + pytest.raises( + Exception, match="Non numeric value for the optimization target" + ), + None, + ), + ], +) +def test_parse_optimization_parameters( + optimization_type: str, + optimization_target: str, + expected_error: Any, + expected_result: Any, +) -> None: + """Test function parse_optimization_parameters.""" + with expected_error: + result = parse_optimization_parameters(optimization_type, optimization_target) + assert result == expected_result + + +@pytest.mark.parametrize( + "args, expected_opts", + [ + [ + {}, + [], + ], + [ + {"target_profile": "profile"}, + ["--target-profile", "profile"], + ], + [ + # for the default profile empty list should be returned + {"target": "ethos-u55-256"}, + [], + ], + ], +) +def test_get_target_opts(args: Optional[Dict], expected_opts: List[str]) -> None: + """Test getting target options.""" + assert get_target_profile_opts(args) == expected_opts + + +@pytest.mark.parametrize( + "output_parameters, expected_path", + [ + [["--output", "report.json"], "report.json"], + [["--output", "REPORT.JSON"], "REPORT.JSON"], + [["--output", "some_folder/report.json"], "some_folder/report.json"], + [["--output", "report.csv"], "report.csv"], + [["--output", "REPORT.CSV"], "REPORT.CSV"], + [["--output", "some_folder/report.csv"], "some_folder/report.csv"], + ], +) +def test_output_options(output_parameters: List[str], expected_path: str) -> None: + """Test output options resolving.""" + parser = argparse.ArgumentParser() + add_output_options(parser) + + args = parser.parse_args(output_parameters) + assert args.output == expected_path + + +@pytest.mark.parametrize( + "output_filename", + [ + "report.txt", + "report.TXT", + "report", + "report.pdf", + ], +) +def test_output_options_bad_parameters( + output_filename: str, capsys: pytest.CaptureFixture +) -> None: + """Test that args parsing should fail if format is not supported.""" + parser = argparse.ArgumentParser() + add_output_options(parser) + + with pytest.raises(SystemExit): + parser.parse_args(["--output", output_filename]) + + err_output = capsys.readouterr().err + suffix = Path(output_filename).suffix[1:] + assert f"Unsupported format '{suffix}'" in err_output |