From 5d81f37de09efe10f90512e50252be9c36925fcf Mon Sep 17 00:00:00 2001 From: Benjamin Klimczak Date: Mon, 11 Jul 2022 12:33:42 +0100 Subject: MLIA-551 Rework remains of AIET architecture Re-factoring the code base to further merge the old AIET code into MLIA. - Remove last traces of the backend type 'tool' - Controlled systems removed, including SSH protocol, controller, RunningCommand, locks etc. - Build command / build dir and deploy functionality removed from Applications and Systems - Moving working_dir() - Replace module 'output_parser' with new module 'output_consumer' and merge Base64 parsing into it - Change the output consumption to optionally remove (i.e. actually consume) lines - Use Base64 parsing in GenericInferenceOutputParser, replacing the regex-based parsing and remove the now unused regex parsing - Remove AIET reporting - Pre-install applications by moving them to src/mlia/resources/backends - Rename aiet-config.json to backend-config.json - Move tests from tests/mlia/ to tests/ - Adapt unit tests to code changes - Dependencies removed: paramiko, filelock, psutil - Fix bug in corstone.py: The wrong resource directory was used which broke the functionality to download backends. - Use f-string formatting. - Use logging instead of print. Change-Id: I768bc3bb6b2eda57d219ad01be4a8e0a74167d76 --- tests/test_cli_options.py | 186 ++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 186 insertions(+) create mode 100644 tests/test_cli_options.py (limited to 'tests/test_cli_options.py') diff --git a/tests/test_cli_options.py b/tests/test_cli_options.py new file mode 100644 index 0000000..a441e58 --- /dev/null +++ b/tests/test_cli_options.py @@ -0,0 +1,186 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Tests for module options.""" +import argparse +from contextlib import ExitStack as does_not_raise +from pathlib import Path +from typing import Any +from typing import Dict +from typing import List +from typing import Optional + +import pytest + +from mlia.cli.options import add_output_options +from mlia.cli.options import get_target_profile_opts +from mlia.cli.options import parse_optimization_parameters + + +@pytest.mark.parametrize( + "optimization_type, optimization_target, expected_error, expected_result", + [ + ( + "pruning", + "0.5", + does_not_raise(), + [ + dict( + optimization_type="pruning", + optimization_target=0.5, + layers_to_optimize=None, + ) + ], + ), + ( + "clustering", + "32", + does_not_raise(), + [ + dict( + optimization_type="clustering", + optimization_target=32.0, + layers_to_optimize=None, + ) + ], + ), + ( + "pruning,clustering", + "0.5,32", + does_not_raise(), + [ + dict( + optimization_type="pruning", + optimization_target=0.5, + layers_to_optimize=None, + ), + dict( + optimization_type="clustering", + optimization_target=32.0, + layers_to_optimize=None, + ), + ], + ), + ( + "pruning, clustering", + "0.5, 32", + does_not_raise(), + [ + dict( + optimization_type="pruning", + optimization_target=0.5, + layers_to_optimize=None, + ), + dict( + optimization_type="clustering", + optimization_target=32.0, + layers_to_optimize=None, + ), + ], + ), + ( + "pruning,clustering", + "0.5", + pytest.raises( + Exception, match="Wrong number of optimization targets and types" + ), + None, + ), + ( + "", + "0.5", + pytest.raises(Exception, match="Optimization type is not provided"), + None, + ), + ( + "pruning,clustering", + "", + pytest.raises(Exception, match="Optimization target is not provided"), + None, + ), + ( + "pruning,", + "0.5,abc", + pytest.raises( + Exception, match="Non numeric value for the optimization target" + ), + None, + ), + ], +) +def test_parse_optimization_parameters( + optimization_type: str, + optimization_target: str, + expected_error: Any, + expected_result: Any, +) -> None: + """Test function parse_optimization_parameters.""" + with expected_error: + result = parse_optimization_parameters(optimization_type, optimization_target) + assert result == expected_result + + +@pytest.mark.parametrize( + "args, expected_opts", + [ + [ + {}, + [], + ], + [ + {"target_profile": "profile"}, + ["--target-profile", "profile"], + ], + [ + # for the default profile empty list should be returned + {"target": "ethos-u55-256"}, + [], + ], + ], +) +def test_get_target_opts(args: Optional[Dict], expected_opts: List[str]) -> None: + """Test getting target options.""" + assert get_target_profile_opts(args) == expected_opts + + +@pytest.mark.parametrize( + "output_parameters, expected_path", + [ + [["--output", "report.json"], "report.json"], + [["--output", "REPORT.JSON"], "REPORT.JSON"], + [["--output", "some_folder/report.json"], "some_folder/report.json"], + [["--output", "report.csv"], "report.csv"], + [["--output", "REPORT.CSV"], "REPORT.CSV"], + [["--output", "some_folder/report.csv"], "some_folder/report.csv"], + ], +) +def test_output_options(output_parameters: List[str], expected_path: str) -> None: + """Test output options resolving.""" + parser = argparse.ArgumentParser() + add_output_options(parser) + + args = parser.parse_args(output_parameters) + assert args.output == expected_path + + +@pytest.mark.parametrize( + "output_filename", + [ + "report.txt", + "report.TXT", + "report", + "report.pdf", + ], +) +def test_output_options_bad_parameters( + output_filename: str, capsys: pytest.CaptureFixture +) -> None: + """Test that args parsing should fail if format is not supported.""" + parser = argparse.ArgumentParser() + add_output_options(parser) + + with pytest.raises(SystemExit): + parser.parse_args(["--output", output_filename]) + + err_output = capsys.readouterr().err + suffix = Path(output_filename).suffix[1:] + assert f"Unsupported format '{suffix}'" in err_output -- cgit v1.2.1