diff options
author | Diego Russo <diego.russo@arm.com> | 2022-05-30 13:34:14 +0100 |
---|---|---|
committer | Diego Russo <diego.russo@arm.com> | 2022-05-30 13:34:14 +0100 |
commit | 0efca3cadbad5517a59884576ddb90cfe7ac30f8 (patch) | |
tree | abed6cb6fbf3c439fc8d947f505b6a53d5daeb1e /tests/aiet | |
parent | 0777092695c143c3a54680b5748287d40c914c35 (diff) | |
download | mlia-0efca3cadbad5517a59884576ddb90cfe7ac30f8.tar.gz |
Add MLIA codebase (tag: 0.3.0-rc.1)
Add MLIA codebase including sources and tests.
Change-Id: Id41707559bd721edd114793618d12ccd188d8dbd
Diffstat (limited to 'tests/aiet')
65 files changed, 6296 insertions, 0 deletions
diff --git a/tests/aiet/__init__.py b/tests/aiet/__init__.py new file mode 100644 index 0000000..873a7df --- /dev/null +++ b/tests/aiet/__init__.py @@ -0,0 +1,3 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""AIET tests module.""" diff --git a/tests/aiet/conftest.py b/tests/aiet/conftest.py new file mode 100644 index 0000000..cab3dc2 --- /dev/null +++ b/tests/aiet/conftest.py @@ -0,0 +1,139 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +# pylint: disable=redefined-outer-name +"""conftest for pytest.""" +import shutil +import tarfile +from pathlib import Path +from typing import Any + +import pytest +from click.testing import CliRunner + +from aiet.backend.common import get_backend_configs + + +@pytest.fixture(scope="session") +def test_systems_path(test_resources_path: Path) -> Path: + """Return test systems path in a pytest fixture.""" + return test_resources_path / "systems" + + +@pytest.fixture(scope="session") +def test_applications_path(test_resources_path: Path) -> Path: + """Return test applications path in a pytest fixture.""" + return test_resources_path / "applications" + + +@pytest.fixture(scope="session") +def test_tools_path(test_resources_path: Path) -> Path: + """Return test tools path in a pytest fixture.""" + return test_resources_path / "tools" + + +@pytest.fixture(scope="session") +def test_resources_path() -> Path: + """Return test resources path in a pytest fixture.""" + current_path = Path(__file__).parent.absolute() + return current_path / "test_resources" + + +@pytest.fixture(scope="session") +def non_optimised_input_model_file(test_tflite_model: Path) -> Path: + """Provide the path to a quantized dummy model file.""" + return test_tflite_model + + +@pytest.fixture(scope="session") +def optimised_input_model_file(test_tflite_vela_model: Path) -> Path: + """Provide path to Vela-optimised dummy 
model file.""" + return test_tflite_vela_model + + +@pytest.fixture(scope="session") +def invalid_input_model_file(test_tflite_invalid_model: Path) -> Path: + """Provide the path to an invalid dummy model file.""" + return test_tflite_invalid_model + + +@pytest.fixture(autouse=True) +def test_resources(monkeypatch: pytest.MonkeyPatch, test_resources_path: Path) -> Any: + """Force using test resources as middleware's repository.""" + + def get_test_resources() -> Path: + """Return path to the test resources.""" + return test_resources_path + + monkeypatch.setattr("aiet.utils.fs.get_aiet_resources", get_test_resources) + yield + + +@pytest.fixture(scope="session", autouse=True) +def add_tools(test_resources_path: Path) -> Any: + """Symlink the tools from the original resources path to the test resources path.""" + # tool_dirs = get_available_tool_directory_names() + tool_dirs = [cfg.parent for cfg in get_backend_configs("tools")] + + links = { + src_dir: (test_resources_path / "tools" / src_dir.name) for src_dir in tool_dirs + } + for src_dir, dst_dir in links.items(): + if not dst_dir.exists(): + dst_dir.symlink_to(src_dir, target_is_directory=True) + yield + # Remove symlinks + for dst_dir in links.values(): + if dst_dir.is_symlink(): + dst_dir.unlink() + + +def create_archive( + archive_name: str, source: Path, destination: Path, with_root_folder: bool = False +) -> None: + """Create archive from directory source.""" + with tarfile.open(destination / archive_name, mode="w:gz") as tar: + for item in source.iterdir(): + item_name = item.name + if with_root_folder: + item_name = f"{source.name}/{item_name}" + tar.add(item, item_name) + + +def process_directory(source: Path, destination: Path) -> None: + """Process resource directory.""" + destination.mkdir() + + for item in source.iterdir(): + if item.is_dir(): + create_archive(f"{item.name}.tar.gz", item, destination) + create_archive(f"{item.name}_dir.tar.gz", item, destination, True) + + 
+@pytest.fixture(scope="session", autouse=True) +def add_archives( + test_resources_path: Path, tmp_path_factory: pytest.TempPathFactory +) -> Any: + """Generate archives of the test resources.""" + tmp_path = tmp_path_factory.mktemp("archives") + + archives_path = tmp_path / "archives" + archives_path.mkdir() + + if (archives_path_link := test_resources_path / "archives").is_symlink(): + archives_path.unlink() + + archives_path_link.symlink_to(archives_path, target_is_directory=True) + + for item in ["applications", "systems"]: + process_directory(test_resources_path / item, archives_path / item) + + yield + + archives_path_link.unlink() + shutil.rmtree(tmp_path) + + +@pytest.fixture(scope="module") +def cli_runner() -> CliRunner: + """Return CliRunner instance in a pytest fixture.""" + return CliRunner() diff --git a/tests/aiet/test_backend_application.py b/tests/aiet/test_backend_application.py new file mode 100644 index 0000000..abfab00 --- /dev/null +++ b/tests/aiet/test_backend_application.py @@ -0,0 +1,452 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. 
+# SPDX-License-Identifier: Apache-2.0 +# pylint: disable=no-self-use +"""Tests for the application backend.""" +from collections import Counter +from contextlib import ExitStack as does_not_raise +from pathlib import Path +from typing import Any +from typing import List +from unittest.mock import MagicMock + +import pytest + +from aiet.backend.application import Application +from aiet.backend.application import get_application +from aiet.backend.application import get_available_application_directory_names +from aiet.backend.application import get_available_applications +from aiet.backend.application import get_unique_application_names +from aiet.backend.application import install_application +from aiet.backend.application import load_applications +from aiet.backend.application import remove_application +from aiet.backend.common import Command +from aiet.backend.common import DataPaths +from aiet.backend.common import Param +from aiet.backend.common import UserParamConfig +from aiet.backend.config import ApplicationConfig +from aiet.backend.config import ExtendedApplicationConfig +from aiet.backend.config import NamedExecutionConfig + + +def test_get_available_application_directory_names() -> None: + """Test get_available_applicationss mocking get_resources.""" + directory_names = get_available_application_directory_names() + assert Counter(directory_names) == Counter( + ["application1", "application2", "application4", "application5"] + ) + + +def test_get_available_applications() -> None: + """Test get_available_applicationss mocking get_resources.""" + available_applications = get_available_applications() + + assert all(isinstance(s, Application) for s in available_applications) + assert all(s != 42 for s in available_applications) + assert len(available_applications) == 9 + # application_5 has multiply items with multiply supported systems + assert [str(s) for s in available_applications] == [ + "application_1", + "application_2", + "application_4", + 
"application_5", + "application_5", + "application_5A", + "application_5A", + "application_5B", + "application_5B", + ] + + +def test_get_unique_application_names() -> None: + """Test get_unique_application_names.""" + unique_names = get_unique_application_names() + + assert all(isinstance(s, str) for s in unique_names) + assert all(s for s in unique_names) + assert sorted(unique_names) == [ + "application_1", + "application_2", + "application_4", + "application_5", + "application_5A", + "application_5B", + ] + + +def test_get_application() -> None: + """Test get_application mocking get_resoures.""" + application = get_application("application_1") + if len(application) != 1: + pytest.fail("Unable to get application") + assert application[0].name == "application_1" + + application = get_application("unknown application") + assert len(application) == 0 + + +@pytest.mark.parametrize( + "source, call_count, expected_exception", + ( + ( + "archives/applications/application1.tar.gz", + 0, + pytest.raises( + Exception, match=r"Applications \[application_1\] are already installed" + ), + ), + ( + "various/applications/application_with_empty_config", + 0, + pytest.raises(Exception, match="No application definition found"), + ), + ( + "various/applications/application_with_wrong_config1", + 0, + pytest.raises(Exception, match="Unable to read application definition"), + ), + ( + "various/applications/application_with_wrong_config2", + 0, + pytest.raises(Exception, match="Unable to read application definition"), + ), + ( + "various/applications/application_with_wrong_config3", + 0, + pytest.raises(Exception, match="Unable to read application definition"), + ), + ("various/applications/application_with_valid_config", 1, does_not_raise()), + ( + "archives/applications/application3.tar.gz", + 0, + pytest.raises(Exception, match="Unable to read application definition"), + ), + ( + "applications/application1", + 0, + pytest.raises( + Exception, match=r"Applications 
\[application_1\] are already installed" + ), + ), + ( + "applications/application3", + 0, + pytest.raises(Exception, match="Unable to read application definition"), + ), + ), +) +def test_install_application( + monkeypatch: Any, + test_resources_path: Path, + source: str, + call_count: int, + expected_exception: Any, +) -> None: + """Test application install from archive.""" + mock_create_destination_and_install = MagicMock() + monkeypatch.setattr( + "aiet.backend.application.create_destination_and_install", + mock_create_destination_and_install, + ) + + with expected_exception: + install_application(test_resources_path / source) + assert mock_create_destination_and_install.call_count == call_count + + +def test_remove_application(monkeypatch: Any) -> None: + """Test application removal.""" + mock_remove_backend = MagicMock() + monkeypatch.setattr("aiet.backend.application.remove_backend", mock_remove_backend) + + remove_application("some_application_directory") + mock_remove_backend.assert_called_once() + + +def test_application_config_without_commands() -> None: + """Test application config without commands.""" + config = ApplicationConfig(name="application") + application = Application(config) + # pylint: disable=use-implicit-booleaness-not-comparison + assert application.commands == {} + + +class TestApplication: + """Test for application class methods.""" + + def test___eq__(self) -> None: + """Test overloaded __eq__ method.""" + config = ApplicationConfig( + # Application + supported_systems=["system1", "system2"], + build_dir="build_dir", + # inherited from Backend + name="name", + description="description", + commands={}, + ) + application1 = Application(config) + application2 = Application(config) # Identical + assert application1 == application2 + + application3 = Application(config) # changed + # Change one single attribute so not equal, but same Type + setattr(application3, "supported_systems", ["somewhere/else"]) + assert application1 != application3 
+ + # different Type + application4 = "Not the Application you are looking for" + assert application1 != application4 + + application5 = Application(config) + # supported systems could be in any order + setattr(application5, "supported_systems", ["system2", "system1"]) + assert application1 == application5 + + def test_can_run_on(self) -> None: + """Test Application can run on.""" + config = ApplicationConfig(name="application", supported_systems=["System-A"]) + + application = Application(config) + assert application.can_run_on("System-A") + assert not application.can_run_on("System-B") + + applications = get_application("application_1", "System 1") + assert len(applications) == 1 + assert applications[0].can_run_on("System 1") + + def test_get_deploy_data(self, tmp_path: Path) -> None: + """Test Application can run on.""" + src, dest = "src", "dest" + config = ApplicationConfig( + name="application", deploy_data=[(src, dest)], config_location=tmp_path + ) + src_path = tmp_path / src + src_path.mkdir() + application = Application(config) + assert application.get_deploy_data() == [DataPaths(src_path, dest)] + + def test_get_deploy_data_no_config_location(self) -> None: + """Test that getting deploy data fails if no config location provided.""" + with pytest.raises( + Exception, match="Unable to get application .* config location" + ): + Application(ApplicationConfig(name="application")).get_deploy_data() + + def test_unable_to_create_application_without_name(self) -> None: + """Test that it is not possible to create application without name.""" + with pytest.raises(Exception, match="Name is empty"): + Application(ApplicationConfig()) + + def test_application_config_without_commands(self) -> None: + """Test application config without commands.""" + config = ApplicationConfig(name="application") + application = Application(config) + # pylint: disable=use-implicit-booleaness-not-comparison + assert application.commands == {} + + @pytest.mark.parametrize( + "config, 
expected_params", + ( + ( + ApplicationConfig( + name="application", + commands={"command": ["cmd {user_params:0} {user_params:1}"]}, + user_params={ + "command": [ + UserParamConfig( + name="--param1", description="param1", alias="param1" + ), + UserParamConfig( + name="--param2", description="param2", alias="param2" + ), + ] + }, + ), + [Param("--param1", "param1"), Param("--param2", "param2")], + ), + ( + ApplicationConfig( + name="application", + commands={"command": ["cmd {user_params:param1} {user_params:1}"]}, + user_params={ + "command": [ + UserParamConfig( + name="--param1", description="param1", alias="param1" + ), + UserParamConfig( + name="--param2", description="param2", alias="param2" + ), + ] + }, + ), + [Param("--param1", "param1"), Param("--param2", "param2")], + ), + ( + ApplicationConfig( + name="application", + commands={"command": ["cmd {user_params:param1}"]}, + user_params={ + "command": [ + UserParamConfig( + name="--param1", description="param1", alias="param1" + ), + UserParamConfig( + name="--param2", description="param2", alias="param2" + ), + ] + }, + ), + [Param("--param1", "param1")], + ), + ), + ) + def test_remove_unused_params( + self, config: ApplicationConfig, expected_params: List[Param] + ) -> None: + """Test mod remove_unused_parameter.""" + application = Application(config) + application.remove_unused_params() + assert application.commands["command"].params == expected_params + + +@pytest.mark.parametrize( + "config, expected_error", + ( + ( + ExtendedApplicationConfig(name="application"), + pytest.raises(Exception, match="No supported systems definition provided"), + ), + ( + ExtendedApplicationConfig( + name="application", supported_systems=[NamedExecutionConfig(name="")] + ), + pytest.raises( + Exception, + match="Unable to read supported system definition, name is missed", + ), + ), + ( + ExtendedApplicationConfig( + name="application", + supported_systems=[ + NamedExecutionConfig( + name="system", + commands={"command": 
["cmd"]}, + user_params={"command": [UserParamConfig(name="param")]}, + ) + ], + commands={"command": ["cmd {user_params:0}"]}, + user_params={"command": [UserParamConfig(name="param")]}, + ), + pytest.raises( + Exception, match="Default parameters for command .* should have aliases" + ), + ), + ( + ExtendedApplicationConfig( + name="application", + supported_systems=[ + NamedExecutionConfig( + name="system", + commands={"command": ["cmd"]}, + user_params={"command": [UserParamConfig(name="param")]}, + ) + ], + commands={"command": ["cmd {user_params:0}"]}, + user_params={"command": [UserParamConfig(name="param", alias="param")]}, + ), + pytest.raises( + Exception, match="system parameters for command .* should have aliases" + ), + ), + ), +) +def test_load_application_exceptional_cases( + config: ExtendedApplicationConfig, expected_error: Any +) -> None: + """Test exceptional cases for application load function.""" + with expected_error: + load_applications(config) + + +def test_load_application() -> None: + """Test application load function. + + The main purpose of this test is to test configuration for application + for different systems. All configuration should be correctly + overridden if needed. 
+ """ + application_5 = get_application("application_5") + assert len(application_5) == 2 + + default_commands = { + "build": Command(["default build command"]), + "run": Command(["default run command"]), + } + default_variables = {"var1": "value1", "var2": "value2"} + + application_5_0 = application_5[0] + assert application_5_0.build_dir == "default_build_dir" + assert application_5_0.supported_systems == ["System 1"] + assert application_5_0.commands == default_commands + assert application_5_0.variables == default_variables + assert application_5_0.lock is False + + application_5_1 = application_5[1] + assert application_5_1.build_dir == application_5_0.build_dir + assert application_5_1.supported_systems == ["System 2"] + assert application_5_1.commands == application_5_1.commands + assert application_5_1.variables == default_variables + + application_5a = get_application("application_5A") + assert len(application_5a) == 2 + + application_5a_0 = application_5a[0] + assert application_5a_0.supported_systems == ["System 1"] + assert application_5a_0.build_dir == "build_5A" + assert application_5a_0.commands == default_commands + assert application_5a_0.variables == {"var1": "new value1", "var2": "value2"} + assert application_5a_0.lock is False + + application_5a_1 = application_5a[1] + assert application_5a_1.supported_systems == ["System 2"] + assert application_5a_1.build_dir == "build" + assert application_5a_1.commands == { + "build": Command(["default build command"]), + "run": Command(["run command on system 2"]), + } + assert application_5a_1.variables == {"var1": "value1", "var2": "new value2"} + assert application_5a_1.lock is True + + application_5b = get_application("application_5B") + assert len(application_5b) == 2 + + application_5b_0 = application_5b[0] + assert application_5b_0.build_dir == "build_5B" + assert application_5b_0.supported_systems == ["System 1"] + assert application_5b_0.commands == { + "build": Command(["default build command 
with value for var1 System1"], []), + "run": Command(["default run command with value for var2 System1"]), + } + assert "non_used_command" not in application_5b_0.commands + + application_5b_1 = application_5b[1] + assert application_5b_1.build_dir == "build" + assert application_5b_1.supported_systems == ["System 2"] + assert application_5b_1.commands == { + "build": Command( + [ + "build command on system 2 with value" + " for var1 System2 {user_params:param1}" + ], + [ + Param( + "--param", + "Sample command param", + ["value1", "value2", "value3"], + "value1", + ) + ], + ), + "run": Command(["run command on system 2"], []), + } diff --git a/tests/aiet/test_backend_common.py b/tests/aiet/test_backend_common.py new file mode 100644 index 0000000..12c30ec --- /dev/null +++ b/tests/aiet/test_backend_common.py @@ -0,0 +1,486 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +# pylint: disable=no-self-use,protected-access +"""Tests for the common backend module.""" +from contextlib import ExitStack as does_not_raise +from pathlib import Path +from typing import Any +from typing import cast +from typing import Dict +from typing import IO +from typing import List +from typing import Optional +from typing import Tuple +from typing import Union +from unittest.mock import MagicMock + +import pytest + +from aiet.backend.application import Application +from aiet.backend.common import Backend +from aiet.backend.common import BaseBackendConfig +from aiet.backend.common import Command +from aiet.backend.common import ConfigurationException +from aiet.backend.common import load_config +from aiet.backend.common import Param +from aiet.backend.common import parse_raw_parameter +from aiet.backend.common import remove_backend +from aiet.backend.config import ApplicationConfig +from aiet.backend.config import UserParamConfig +from aiet.backend.execution import ExecutionContext +from aiet.backend.execution import 
ParamResolver +from aiet.backend.system import System + + +@pytest.mark.parametrize( + "directory_name, expected_exception", + ( + ("some_dir", does_not_raise()), + (None, pytest.raises(Exception, match="No directory name provided")), + ), +) +def test_remove_backend( + monkeypatch: Any, directory_name: str, expected_exception: Any +) -> None: + """Test remove_backend function.""" + mock_remove_resource = MagicMock() + monkeypatch.setattr("aiet.backend.common.remove_resource", mock_remove_resource) + + with expected_exception: + remove_backend(directory_name, "applications") + + +@pytest.mark.parametrize( + "filename, expected_exception", + ( + ("application_config.json", does_not_raise()), + (None, pytest.raises(Exception, match="Unable to read config")), + ), +) +def test_load_config( + filename: str, expected_exception: Any, test_resources_path: Path, monkeypatch: Any +) -> None: + """Test load_config.""" + with expected_exception: + configs: List[Optional[Union[Path, IO[bytes]]]] = ( + [None] + if not filename + else [ + # Ignore pylint warning as 'with' can't be used inside of a + # generator expression. 
+ # pylint: disable=consider-using-with + open(test_resources_path / filename, "rb"), + test_resources_path / filename, + ] + ) + for config in configs: + json_mock = MagicMock() + monkeypatch.setattr("aiet.backend.common.json.load", json_mock) + load_config(config) + json_mock.assert_called_once() + + +class TestBackend: + """Test Backend class.""" + + def test___repr__(self) -> None: + """Test the representation of Backend instance.""" + backend = Backend( + BaseBackendConfig(name="Testing name", description="Testing description") + ) + assert str(backend) == "Testing name" + + def test__eq__(self) -> None: + """Test equality method with different cases.""" + backend1 = Backend(BaseBackendConfig(name="name", description="description")) + backend1.commands = {"command": Command(["command"])} + + backend2 = Backend(BaseBackendConfig(name="name", description="description")) + backend2.commands = {"command": Command(["command"])} + + backend3 = Backend( + BaseBackendConfig( + name="Ben", description="This is not the Backend you are looking for" + ) + ) + backend3.commands = {"wave": Command(["wave hand"])} + + backend4 = "Foo" # checking not isinstance(backend4, Backend) + + assert backend1 == backend2 + assert backend1 != backend3 + assert backend1 != backend4 + + @pytest.mark.parametrize( + "parameter, valid", + [ + ("--choice-param dummy_value_1", True), + ("--choice-param wrong_value", False), + ("--open-param something", True), + ("--wrong-param value", False), + ], + ) + def test_validate_parameter( + self, parameter: str, valid: bool, test_resources_path: Path + ) -> None: + """Test validate_parameter.""" + config = cast( + List[ApplicationConfig], + load_config(test_resources_path / "hello_world.json"), + ) + # The application configuration is a list of configurations so we need + # only the first one + # Exercise the validate_parameter test using the Application classe which + # inherits from Backend. 
+ application = Application(config[0]) + assert application.validate_parameter("run", parameter) == valid + + def test_validate_parameter_with_invalid_command( + self, test_resources_path: Path + ) -> None: + """Test validate_parameter with an invalid command_name.""" + config = cast( + List[ApplicationConfig], + load_config(test_resources_path / "hello_world.json"), + ) + application = Application(config[0]) + with pytest.raises(AttributeError) as err: + # command foo does not exist, so raise an error + application.validate_parameter("foo", "bar") + assert "Unknown command: 'foo'" in str(err.value) + + def test_build_command(self, monkeypatch: Any) -> None: + """Test command building.""" + config = { + "name": "test", + "commands": { + "build": ["build {user_params:0} {user_params:1}"], + "run": ["run {user_params:0}"], + "post_run": ["post_run {application_params:0} on {system_params:0}"], + "some_command": ["Command with {variables:var_A}"], + "empty_command": [""], + }, + "user_params": { + "build": [ + { + "name": "choice_param_0=", + "values": [1, 2, 3], + "default_value": 1, + }, + {"name": "choice_param_1", "values": [3, 4, 5], "default_value": 3}, + {"name": "choice_param_3", "values": [6, 7, 8]}, + ], + "run": [{"name": "flag_param_0"}], + }, + "variables": {"var_A": "value for variable A"}, + } + + monkeypatch.setattr("aiet.backend.system.ProtocolFactory", MagicMock()) + application, system = Application(config), System(config) # type: ignore + context = ExecutionContext( + app=application, + app_params=[], + system=system, + system_params=[], + custom_deploy_data=[], + ) + + param_resolver = ParamResolver(context) + + cmd = application.build_command( + "build", ["choice_param_0=2", "choice_param_1=4"], param_resolver + ) + assert cmd == ["build choice_param_0=2 choice_param_1 4"] + + cmd = application.build_command("build", ["choice_param_0=2"], param_resolver) + assert cmd == ["build choice_param_0=2 choice_param_1 3"] + + cmd = 
application.build_command( + "build", ["choice_param_0=2", "choice_param_3=7"], param_resolver + ) + assert cmd == ["build choice_param_0=2 choice_param_1 3"] + + with pytest.raises( + ConfigurationException, match="Command 'foo' could not be found." + ): + application.build_command("foo", [""], param_resolver) + + cmd = application.build_command("some_command", [], param_resolver) + assert cmd == ["Command with value for variable A"] + + cmd = application.build_command("empty_command", [], param_resolver) + assert cmd == [""] + + @pytest.mark.parametrize("class_", [Application, System]) + def test_build_command_unknown_variable(self, class_: type) -> None: + """Test that unable to construct backend with unknown variable.""" + with pytest.raises(Exception, match="Unknown variable var1"): + config = {"name": "test", "commands": {"run": ["run {variables:var1}"]}} + class_(config) + + @pytest.mark.parametrize( + "class_, config, expected_output", + [ + ( + Application, + { + "name": "test", + "commands": { + "build": ["build {user_params:0} {user_params:1}"], + "run": ["run {user_params:0}"], + }, + "user_params": { + "build": [ + { + "name": "choice_param_0=", + "values": ["a", "b", "c"], + "default_value": "a", + "alias": "param_1", + }, + { + "name": "choice_param_1", + "values": ["a", "b", "c"], + "default_value": "a", + "alias": "param_2", + }, + {"name": "choice_param_3", "values": ["a", "b", "c"]}, + ], + "run": [{"name": "flag_param_0"}], + }, + }, + [ + ( + "b", + Param( + name="choice_param_0=", + description="", + values=["a", "b", "c"], + default_value="a", + alias="param_1", + ), + ), + ( + "a", + Param( + name="choice_param_1", + description="", + values=["a", "b", "c"], + default_value="a", + alias="param_2", + ), + ), + ( + "c", + Param( + name="choice_param_3", + description="", + values=["a", "b", "c"], + ), + ), + ], + ), + (System, {"name": "test"}, []), + ], + ) + def test_resolved_parameters( + self, + monkeypatch: Any, + class_: type, + config: 
Dict, + expected_output: List[Tuple[Optional[str], Param]], + ) -> None: + """Test command building.""" + monkeypatch.setattr("aiet.backend.system.ProtocolFactory", MagicMock()) + backend = class_(config) + + params = backend.resolved_parameters( + "build", ["choice_param_0=b", "choice_param_3=c"] + ) + assert params == expected_output + + @pytest.mark.parametrize( + ["param_name", "user_param", "expected_value"], + [ + ( + "test_name", + "test_name=1234", + "1234", + ), # optional parameter using '=' + ( + "test_name", + "test_name 1234", + "1234", + ), # optional parameter using ' ' + ("test_name", "test_name", None), # flag + (None, "test_name=1234", "1234"), # positional parameter + ], + ) + def test_resolved_user_parameters( + self, param_name: str, user_param: str, expected_value: str + ) -> None: + """Test different variants to provide user parameters.""" + # A dummy config providing one backend config + config = { + "name": "test_backend", + "commands": { + "test": ["user_param:test_param"], + }, + "user_params": { + "test": [UserParamConfig(name=param_name, alias="test_name")], + }, + } + backend = Backend(cast(BaseBackendConfig, config)) + params = backend.resolved_parameters( + command_name="test", user_params=[user_param] + ) + assert len(params) == 1 + value, param = params[0] + assert param_name == param.name + assert expected_value == value + + @pytest.mark.parametrize( + "input_param,expected", + [ + ("--param=1", ("--param", "1")), + ("--param 1", ("--param", "1")), + ("--flag", ("--flag", None)), + ], + ) + def test__parse_raw_parameter( + self, input_param: str, expected: Tuple[str, Optional[str]] + ) -> None: + """Test internal method of parsing a single raw parameter.""" + assert parse_raw_parameter(input_param) == expected + + +class TestParam: + """Test Param class.""" + + def test__eq__(self) -> None: + """Test equality method with different cases.""" + param1 = Param(name="test", description="desc", values=["values"]) + param2 = 
Param(name="test", description="desc", values=["values"]) + param3 = Param(name="test1", description="desc", values=["values"]) + param4 = object() + + assert param1 == param2 + assert param1 != param3 + assert param1 != param4 + + def test_get_details(self) -> None: + """Test get_details() method.""" + param1 = Param(name="test", description="desc", values=["values"]) + assert param1.get_details() == { + "name": "test", + "values": ["values"], + "description": "desc", + } + + def test_invalid(self) -> None: + """Test invalid use cases for the Param class.""" + with pytest.raises( + ConfigurationException, + match="Either name, alias or both must be set to identify a parameter.", + ): + Param(name=None, description="desc", values=["values"]) + + +class TestCommand: + """Test Command class.""" + + def test_get_details(self) -> None: + """Test get_details() method.""" + param1 = Param(name="test", description="desc", values=["values"]) + command1 = Command(command_strings=["echo test"], params=[param1]) + assert command1.get_details() == { + "command_strings": ["echo test"], + "user_params": [ + {"name": "test", "values": ["values"], "description": "desc"} + ], + } + + def test__eq__(self) -> None: + """Test equality method with different cases.""" + param1 = Param("test", "desc", ["values"]) + param2 = Param("test1", "desc1", ["values1"]) + command1 = Command(command_strings=["echo test"], params=[param1]) + command2 = Command(command_strings=["echo test"], params=[param1]) + command3 = Command(command_strings=["echo test"]) + command4 = Command(command_strings=["echo test"], params=[param2]) + command5 = object() + + assert command1 == command2 + assert command1 != command3 + assert command1 != command4 + assert command1 != command5 + + @pytest.mark.parametrize( + "params, expected_error", + [ + [[], does_not_raise()], + [[Param("param", "param description", [])], does_not_raise()], + [ + [ + Param("param", "param description", [], None, "alias"), + Param("param", 
"param description", [], None), + ], + does_not_raise(), + ], + [ + [ + Param("param1", "param1 description", [], None, "alias1"), + Param("param2", "param2 description", [], None, "alias2"), + ], + does_not_raise(), + ], + [ + [ + Param("param", "param description", [], None, "alias"), + Param("param", "param description", [], None, "alias"), + ], + pytest.raises(ConfigurationException, match="Non unique aliases alias"), + ], + [ + [ + Param("alias", "param description", [], None, "alias1"), + Param("param", "param description", [], None, "alias"), + ], + pytest.raises( + ConfigurationException, + match="Aliases .* could not be used as parameter name", + ), + ], + [ + [ + Param("alias", "param description", [], None, "alias"), + Param("param1", "param1 description", [], None, "alias1"), + ], + does_not_raise(), + ], + [ + [ + Param("alias", "param description", [], None, "alias"), + Param("alias", "param1 description", [], None, "alias1"), + ], + pytest.raises( + ConfigurationException, + match="Aliases .* could not be used as parameter name", + ), + ], + [ + [ + Param("param1", "param1 description", [], None, "alias1"), + Param("param2", "param2 description", [], None, "alias1"), + Param("param3", "param3 description", [], None, "alias2"), + Param("param4", "param4 description", [], None, "alias2"), + ], + pytest.raises( + ConfigurationException, match="Non unique aliases alias1, alias2" + ), + ], + ], + ) + def test_validate_params(self, params: List[Param], expected_error: Any) -> None: + """Test command validation function.""" + with expected_error: + Command([], params) diff --git a/tests/aiet/test_backend_controller.py b/tests/aiet/test_backend_controller.py new file mode 100644 index 0000000..8836ec5 --- /dev/null +++ b/tests/aiet/test_backend_controller.py @@ -0,0 +1,160 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. 
# SPDX-License-Identifier: Apache-2.0
"""Tests for system controller."""
import csv
import os
import time
from pathlib import Path
from typing import Any

import psutil
import pytest

from aiet.backend.common import ConfigurationException
from aiet.backend.controller import SystemController
from aiet.backend.controller import SystemControllerSingleInstance
from aiet.utils.proc import ShellCommand


def get_system_controller(**kwargs: Any) -> SystemController:
    """Get service controller.

    When single_instance=True is passed, a SystemControllerSingleInstance is
    built around the given pid_file_path; otherwise a plain SystemController.
    """
    single_instance = kwargs.get("single_instance", False)
    if single_instance:
        pid_file_path = kwargs.get("pid_file_path")
        return SystemControllerSingleInstance(pid_file_path)

    return SystemController()


def test_service_controller() -> None:
    """Test service controller functionality."""
    service_controller = get_system_controller()

    # Before anything is started there is no captured output.
    assert service_controller.get_output() == ("", "")
    with pytest.raises(ConfigurationException, match="Wrong working directory"):
        service_controller.start(["sleep 100"], Path("unknown"))

    service_controller.start(["sleep 100"], Path.cwd())
    assert service_controller.is_running()

    service_controller.stop(True)
    assert not service_controller.is_running()
    # Output is reset once the process is stopped.
    assert service_controller.get_output() == ("", "")

    # Stopping an already stopped controller must be a no-op.
    service_controller.stop()

    with pytest.raises(
        ConfigurationException, match="System should have only one command to run"
    ):
        service_controller.start(["sleep 100", "sleep 101"], Path.cwd())

    with pytest.raises(ConfigurationException, match="No startup command provided"):
        service_controller.start([""], Path.cwd())


def test_service_controller_bad_configuration() -> None:
    """Test service controller functionality for bad configuration."""
    with pytest.raises(Exception, match="No pid file path presented"):
        service_controller = get_system_controller(
            single_instance=True, pid_file_path=None
        )
        service_controller.start(["sleep 100"], Path.cwd())


def test_service_controller_writes_process_info_correctly(tmpdir: Any) -> None:
    """Test that controller writes process info correctly."""
    pid_file = Path(tmpdir) / "test.pid"

    service_controller = get_system_controller(
        single_instance=True, pid_file_path=Path(tmpdir) / "test.pid"
    )

    service_controller.start(["sleep 100"], Path.cwd())
    assert service_controller.is_running()
    assert pid_file.is_file()

    # The pid file is a CSV with one row per tracked process; the first
    # column holds the process name.
    with open(pid_file, "r", encoding="utf-8") as file:
        csv_reader = csv.reader(file)
        rows = list(csv_reader)
        assert len(rows) == 1

        name, *_ = rows[0]
        assert name == "sleep"

    # The pid file is kept around after stop().
    service_controller.stop()
    assert pid_file.exists()


def test_service_controller_does_not_write_process_info_if_process_finishes(
    tmpdir: Any,
) -> None:
    """Test that controller does not write process info if process already finished."""
    pid_file = Path(tmpdir) / "test.pid"
    service_controller = get_system_controller(
        single_instance=True, pid_file_path=pid_file
    )
    # Force the controller to see the process as already finished.
    service_controller.is_running = lambda: False  # type: ignore
    service_controller.start(["echo hello"], Path.cwd())

    assert not pid_file.exists()


def test_service_controller_searches_for_previous_instances_correctly(
    tmpdir: Any,
) -> None:
    """Test that controller searches for previous instances correctly."""
    pid_file = Path(tmpdir) / "test.pid"
    command = ShellCommand().run("sleep", "100")
    assert command.is_alive()

    pid = command.process.pid
    process = psutil.Process(pid)
    # Seed the pid file with three rows: this test process itself, the live
    # sleep process, and a stale entry for a non-existent pid.
    with open(pid_file, "w", encoding="utf-8") as file:
        csv_writer = csv.writer(file)
        csv_writer.writerow(("some_process", "some_program", "some_cwd", os.getpid()))
        csv_writer.writerow((process.name(), process.exe(), process.cwd(), process.pid))
        csv_writer.writerow(("some_old_process", "not_running", "from_nowhere", 77777))

    service_controller = get_system_controller(
        single_instance=True, pid_file_path=pid_file
    )
    service_controller.start(["sleep 100"], Path.cwd())
    # controller should stop this process as it is currently running and
    # mentioned in pid file
    assert not command.is_alive()

    service_controller.stop()


@pytest.mark.parametrize(
    "executable", ["test_backend_run_script.sh", "test_backend_run"]
)
def test_service_controller_run_shell_script(
    executable: str, test_resources_path: Path
) -> None:
    """Test controller's ability to run shell scripts."""
    script_path = test_resources_path / "scripts"

    service_controller = get_system_controller()

    service_controller.start([executable], script_path)

    assert service_controller.is_running()
    # give time for the command to produce output
    time.sleep(2)
    service_controller.stop(wait=True)
    assert not service_controller.is_running()
    stdout, stderr = service_controller.get_output()
    assert stdout == "Hello from script\n"
    assert stderr == "Oops!\n"


def test_service_controller_does_nothing_if_not_started(tmpdir: Any) -> None:
    """Test that nothing happened if controller is not started."""
    service_controller = get_system_controller(
        single_instance=True, pid_file_path=Path(tmpdir) / "test.pid"
    )

    assert not service_controller.is_running()
    # stop() on a never-started controller must be a safe no-op.
    service_controller.stop()
    assert not service_controller.is_running()
diff --git a/tests/aiet/test_backend_execution.py b/tests/aiet/test_backend_execution.py new file mode 100644 index 0000000..8aa45f1 --- /dev/null +++ b/tests/aiet/test_backend_execution.py @@ -0,0 +1,526 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0 +# pylint: disable=no-self-use +"""Test backend context module.""" +from contextlib import ExitStack as does_not_raise +from pathlib import Path +from typing import Any +from typing import Dict +from typing import Optional +from unittest import mock +from unittest.mock import MagicMock + +import pytest +from sh import CommandNotFound + +from aiet.backend.application import Application +from aiet.backend.application import get_application +from aiet.backend.common import ConfigurationException +from aiet.backend.common import DataPaths +from aiet.backend.common import UserParamConfig +from aiet.backend.config import ApplicationConfig +from aiet.backend.config import LocalProtocolConfig +from aiet.backend.config import SystemConfig +from aiet.backend.execution import deploy_data +from aiet.backend.execution import execute_commands_locally +from aiet.backend.execution import ExecutionContext +from aiet.backend.execution import get_application_and_system +from aiet.backend.execution import get_application_by_name_and_system +from aiet.backend.execution import get_file_lock_path +from aiet.backend.execution import get_tool_by_system +from aiet.backend.execution import ParamResolver +from aiet.backend.execution import Reporter +from aiet.backend.execution import wait +from aiet.backend.output_parser import OutputParser +from aiet.backend.system import get_system +from aiet.backend.system import load_system +from aiet.backend.tool import get_tool +from aiet.utils.proc import CommandFailedException + + +def test_context_param_resolver(tmpdir: Any) -> None: + """Test parameter resolving.""" + system_config_location = Path(tmpdir) / "system" + system_config_location.mkdir() + + application_config_location = Path(tmpdir) / "application" + application_config_location.mkdir() + + ctx = ExecutionContext( + app=Application( + ApplicationConfig( + name="test_application", + description="Test application", + 
config_location=application_config_location, + build_dir="build-{application.name}-{system.name}", + commands={ + "run": [ + "run_command1 {user_params:0}", + "run_command2 {user_params:1}", + ] + }, + variables={"var_1": "value for var_1"}, + user_params={ + "run": [ + UserParamConfig( + name="--param1", + description="Param 1", + default_value="123", + alias="param_1", + ), + UserParamConfig( + name="--param2", description="Param 2", default_value="456" + ), + UserParamConfig( + name="--param3", description="Param 3", alias="param_3" + ), + UserParamConfig( + name="--param4=", + description="Param 4", + default_value="456", + alias="param_4", + ), + UserParamConfig( + description="Param 5", + default_value="789", + alias="param_5", + ), + ] + }, + ) + ), + app_params=["--param2=789"], + system=load_system( + SystemConfig( + name="test_system", + description="Test system", + config_location=system_config_location, + build_dir="build", + data_transfer=LocalProtocolConfig(protocol="local"), + commands={ + "build": ["build_command1 {user_params:0}"], + "run": ["run_command {application.commands.run:1}"], + }, + variables={"var_1": "value for var_1"}, + user_params={ + "build": [ + UserParamConfig( + name="--param1", description="Param 1", default_value="aaa" + ), + UserParamConfig(name="--param2", description="Param 2"), + ] + }, + ) + ), + system_params=["--param1=bbb"], + custom_deploy_data=[], + ) + + param_resolver = ParamResolver(ctx) + expected_values = { + "application.name": "test_application", + "application.description": "Test application", + "application.config_dir": str(application_config_location), + "application.build_dir": "{}/build-test_application-test_system".format( + application_config_location + ), + "application.commands.run:0": "run_command1 --param1 123", + "application.commands.run.params:0": "123", + "application.commands.run.params:param_1": "123", + "application.commands.run:1": "run_command2 --param2 789", + 
"application.commands.run.params:1": "789", + "application.variables:var_1": "value for var_1", + "system.name": "test_system", + "system.description": "Test system", + "system.config_dir": str(system_config_location), + "system.commands.build:0": "build_command1 --param1 bbb", + "system.commands.run:0": "run_command run_command2 --param2 789", + "system.commands.build.params:0": "bbb", + "system.variables:var_1": "value for var_1", + } + + for param, value in expected_values.items(): + assert param_resolver(param) == value + + assert ctx.build_dir() == Path( + "{}/build-test_application-test_system".format(application_config_location) + ) + + expected_errors = { + "application.variables:var_2": pytest.raises( + Exception, match="Unknown variable var_2" + ), + "application.commands.clean:0": pytest.raises( + Exception, match="Command clean not found" + ), + "application.commands.run:2": pytest.raises( + Exception, match="Invalid index 2 for command run" + ), + "application.commands.run.params:5": pytest.raises( + Exception, match="Invalid parameter index 5 for command run" + ), + "application.commands.run.params:param_2": pytest.raises( + Exception, + match="No value for parameter with index or alias param_2 of command run", + ), + "UNKNOWN": pytest.raises( + Exception, match="Unable to resolve parameter UNKNOWN" + ), + "system.commands.build.params:1": pytest.raises( + Exception, + match="No value for parameter with index or alias 1 of command build", + ), + "system.commands.build:A": pytest.raises( + Exception, match="Bad command index A" + ), + "system.variables:var_2": pytest.raises( + Exception, match="Unknown variable var_2" + ), + } + for param, error in expected_errors.items(): + with error: + param_resolver(param) + + resolved_params = ctx.app.resolved_parameters("run", []) + expected_user_params = { + "user_params:0": "--param1 123", + "user_params:param_1": "--param1 123", + "user_params:2": "--param3", + "user_params:param_3": "--param3", + 
"user_params:3": "--param4=456", + "user_params:param_4": "--param4=456", + "user_params:param_5": "789", + } + for param, expected_value in expected_user_params.items(): + assert param_resolver(param, "run", resolved_params) == expected_value + + with pytest.raises( + Exception, match="Invalid index 5 for user params of command run" + ): + param_resolver("user_params:5", "run", resolved_params) + + with pytest.raises( + Exception, match="No user parameter for command 'run' with alias 'param_2'." + ): + param_resolver("user_params:param_2", "run", resolved_params) + + with pytest.raises(Exception, match="Unable to resolve user params"): + param_resolver("user_params:0", "", resolved_params) + + bad_ctx = ExecutionContext( + app=Application( + ApplicationConfig( + name="test_application", + config_location=application_config_location, + build_dir="build-{user_params:0}", + ) + ), + app_params=["--param2=789"], + system=load_system( + SystemConfig( + name="test_system", + description="Test system", + config_location=system_config_location, + build_dir="build-{system.commands.run.params:123}", + data_transfer=LocalProtocolConfig(protocol="local"), + ) + ), + system_params=["--param1=bbb"], + custom_deploy_data=[], + ) + param_resolver = ParamResolver(bad_ctx) + with pytest.raises(Exception, match="Unable to resolve user params"): + bad_ctx.build_dir() + + +# pylint: disable=too-many-arguments +@pytest.mark.parametrize( + "application_name, soft_lock, sys_lock, lock_dir, expected_error, expected_path", + ( + ( + "test_application", + True, + True, + Path("/tmp"), + does_not_raise(), + Path("/tmp/middleware_test_application_test_system.lock"), + ), + ( + "$$test_application$!:", + True, + True, + Path("/tmp"), + does_not_raise(), + Path("/tmp/middleware_test_application_test_system.lock"), + ), + ( + "test_application", + True, + True, + Path("unknown"), + pytest.raises( + Exception, match="Invalid directory unknown for lock files provided" + ), + None, + ), + ( + 
"test_application", + False, + True, + Path("/tmp"), + does_not_raise(), + Path("/tmp/middleware_test_system.lock"), + ), + ( + "test_application", + True, + False, + Path("/tmp"), + does_not_raise(), + Path("/tmp/middleware_test_application.lock"), + ), + ( + "test_application", + False, + False, + Path("/tmp"), + pytest.raises(Exception, match="No filename for lock provided"), + None, + ), + ), +) +def test_get_file_lock_path( + application_name: str, + soft_lock: bool, + sys_lock: bool, + lock_dir: Path, + expected_error: Any, + expected_path: Path, +) -> None: + """Test get_file_lock_path function.""" + with expected_error: + ctx = ExecutionContext( + app=Application(ApplicationConfig(name=application_name, lock=soft_lock)), + app_params=[], + system=load_system( + SystemConfig( + name="test_system", + lock=sys_lock, + data_transfer=LocalProtocolConfig(protocol="local"), + ) + ), + system_params=[], + custom_deploy_data=[], + ) + path = get_file_lock_path(ctx, lock_dir) + assert path == expected_path + + +def test_get_application_by_name_and_system(monkeypatch: Any) -> None: + """Test exceptional case for get_application_by_name_and_system.""" + monkeypatch.setattr( + "aiet.backend.execution.get_application", + MagicMock(return_value=[MagicMock(), MagicMock()]), + ) + + with pytest.raises( + ValueError, + match="Error during getting application test_application for the " + "system test_system", + ): + get_application_by_name_and_system("test_application", "test_system") + + +def test_get_application_and_system(monkeypatch: Any) -> None: + """Test exceptional case for get_application_and_system.""" + monkeypatch.setattr( + "aiet.backend.execution.get_system", MagicMock(return_value=None) + ) + + with pytest.raises(ValueError, match="System test_system is not found"): + get_application_and_system("test_application", "test_system") + + +def test_wait_function(monkeypatch: Any) -> None: + """Test wait function.""" + sleep_mock = MagicMock() + 
monkeypatch.setattr("time.sleep", sleep_mock) + wait(0.1) + sleep_mock.assert_called_once() + + +def test_deployment_execution_context() -> None: + """Test property 'is_deploy_needed' of the ExecutionContext.""" + ctx = ExecutionContext( + app=get_application("application_1")[0], + app_params=[], + system=get_system("System 1"), + system_params=[], + ) + assert not ctx.is_deploy_needed + deploy_data(ctx) # should be a NOP + + ctx = ExecutionContext( + app=get_application("application_1")[0], + app_params=[], + system=get_system("System 1"), + system_params=[], + custom_deploy_data=[DataPaths(Path("README.md"), ".")], + ) + assert ctx.is_deploy_needed + + ctx = ExecutionContext( + app=get_application("application_1")[0], + app_params=[], + system=None, + system_params=[], + ) + assert not ctx.is_deploy_needed + with pytest.raises(AssertionError): + deploy_data(ctx) + + ctx = ExecutionContext( + app=get_tool("tool_1")[0], + app_params=[], + system=None, + system_params=[], + ) + assert not ctx.is_deploy_needed + deploy_data(ctx) # should be a NOP + + +@pytest.mark.parametrize( + ["tool_name", "system_name", "exception"], + [ + ("vela", "Corstone-300: Cortex-M55+Ethos-U65", None), + ("unknown tool", "Corstone-300: Cortex-M55+Ethos-U65", ConfigurationException), + ("vela", "unknown system", ConfigurationException), + ("vela", None, ConfigurationException), + ], +) +def test_get_tool_by_system( + tool_name: str, system_name: Optional[str], exception: Optional[Any] +) -> None: + """Test exceptions thrown by function get_tool_by_system().""" + + def test() -> None: + """Test call of get_tool_by_system().""" + tool = get_tool_by_system(tool_name, system_name) + assert tool is not None + + if exception is None: + test() + else: + with pytest.raises(exception): + test() + + +class TestExecuteCommandsLocally: + """Test execute_commands_locally() function.""" + + @pytest.mark.parametrize( + "first_command, exception, expected_output", + ( + ( + "echo 'hello'", + None, + 
"Running: echo 'hello'\nhello\nRunning: echo 'goodbye'\ngoodbye\n", + ), + ( + "non-existent-command", + CommandNotFound, + "Running: non-existent-command\n", + ), + ("false", CommandFailedException, "Running: false\n"), + ), + ids=( + "runs_multiple_commands", + "stops_executing_on_non_existent_command", + "stops_executing_when_command_exits_with_error_code", + ), + ) + def test_execution( + self, + first_command: str, + exception: Any, + expected_output: str, + test_resources_path: Path, + capsys: Any, + ) -> None: + """Test expected behaviour of the function.""" + commands = [first_command, "echo 'goodbye'"] + cwd = test_resources_path + if exception is None: + execute_commands_locally(commands, cwd) + else: + with pytest.raises(exception): + execute_commands_locally(commands, cwd) + + captured = capsys.readouterr() + assert captured.out == expected_output + + def test_stops_executing_on_exception( + self, monkeypatch: Any, test_resources_path: Path + ) -> None: + """Ensure commands following an error-exit-code command don't run.""" + # Mock execute_command() function + execute_command_mock = mock.MagicMock() + monkeypatch.setattr("aiet.utils.proc.execute_command", execute_command_mock) + + # Mock Command object and assign as return value to execute_command() + cmd_mock = mock.MagicMock() + execute_command_mock.return_value = cmd_mock + + # Mock the terminate_command (speed up test) + terminate_command_mock = mock.MagicMock() + monkeypatch.setattr("aiet.utils.proc.terminate_command", terminate_command_mock) + + # Mock a thrown Exception and assign to Command().exit_code + exit_code_mock = mock.PropertyMock(side_effect=Exception("Exception.")) + type(cmd_mock).exit_code = exit_code_mock + + with pytest.raises(Exception, match="Exception."): + execute_commands_locally( + ["command_1", "command_2"], cwd=test_resources_path + ) + + # Assert only "command_1" was executed + assert execute_command_mock.call_count == 1 + + +def test_reporter(tmpdir: Any) -> None: + 
"""Test class 'Reporter'.""" + ctx = ExecutionContext( + app=get_application("application_4")[0], + app_params=["--app=TestApp"], + system=get_system("System 4"), + system_params=[], + ) + assert ctx.system is not None + + class MockParser(OutputParser): + """Mock implementation of an output parser.""" + + def __init__(self, metrics: Dict[str, Any]) -> None: + """Set up the MockParser.""" + super().__init__(name="test") + self.metrics = metrics + + def __call__(self, output: bytearray) -> Dict[str, Any]: + """Return mock metrics (ignoring the given output).""" + return self.metrics + + metrics = {"Metric": 123, "AnotherMetric": 456} + reporter = Reporter( + parsers=[MockParser(metrics={key: val}) for key, val in metrics.items()], + ) + reporter.parse(bytearray()) + report = reporter.report(ctx) + assert report["system"]["name"] == ctx.system.name + assert report["system"]["params"] == {} + assert report["application"]["name"] == ctx.app.name + assert report["application"]["params"] == {"--app": "TestApp"} + assert report["test"]["metrics"] == metrics + report_file = Path(tmpdir) / "report.json" + reporter.save(report, report_file) + assert report_file.is_file() diff --git a/tests/aiet/test_backend_output_parser.py b/tests/aiet/test_backend_output_parser.py new file mode 100644 index 0000000..d659812 --- /dev/null +++ b/tests/aiet/test_backend_output_parser.py @@ -0,0 +1,152 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Tests for the output parsing.""" +import base64 +import json +from typing import Any +from typing import Dict + +import pytest + +from aiet.backend.output_parser import Base64OutputParser +from aiet.backend.output_parser import OutputParser +from aiet.backend.output_parser import RegexOutputParser + + +OUTPUT_MATCH_ALL = bytearray( + """ +String1: My awesome string! +String2: STRINGS_ARE_GREAT!!! 
+Int: 12 +Float: 3.14 +""", + encoding="utf-8", +) + +OUTPUT_NO_MATCH = bytearray( + """ +This contains no matches... +Test1234567890!"£$%^&*()_+@~{}[]/.,<>?| +""", + encoding="utf-8", +) + +OUTPUT_PARTIAL_MATCH = bytearray( + "String1: My awesome string!", + encoding="utf-8", +) + +REGEX_CONFIG = { + "FirstString": {"pattern": r"String1.*: (.*)", "type": "str"}, + "SecondString": {"pattern": r"String2.*: (.*)!!!", "type": "str"}, + "IntegerValue": {"pattern": r"Int.*: (.*)", "type": "int"}, + "FloatValue": {"pattern": r"Float.*: (.*)", "type": "float"}, +} + +EMPTY_REGEX_CONFIG: Dict[str, Dict[str, Any]] = {} + +EXPECTED_METRICS_ALL = { + "FirstString": "My awesome string!", + "SecondString": "STRINGS_ARE_GREAT", + "IntegerValue": 12, + "FloatValue": 3.14, +} + +EXPECTED_METRICS_PARTIAL = { + "FirstString": "My awesome string!", +} + + +class TestRegexOutputParser: + """Collect tests for the RegexOutputParser.""" + + @staticmethod + @pytest.mark.parametrize( + ["output", "config", "expected_metrics"], + [ + (OUTPUT_MATCH_ALL, REGEX_CONFIG, EXPECTED_METRICS_ALL), + (OUTPUT_MATCH_ALL + OUTPUT_NO_MATCH, REGEX_CONFIG, EXPECTED_METRICS_ALL), + (OUTPUT_MATCH_ALL + OUTPUT_NO_MATCH, REGEX_CONFIG, EXPECTED_METRICS_ALL), + ( + OUTPUT_MATCH_ALL + OUTPUT_PARTIAL_MATCH, + REGEX_CONFIG, + EXPECTED_METRICS_ALL, + ), + (OUTPUT_NO_MATCH, REGEX_CONFIG, {}), + (OUTPUT_MATCH_ALL, EMPTY_REGEX_CONFIG, {}), + (bytearray(), EMPTY_REGEX_CONFIG, {}), + (bytearray(), REGEX_CONFIG, {}), + ], + ) + def test_parsing(output: bytearray, config: Dict, expected_metrics: Dict) -> None: + """ + Make sure the RegexOutputParser yields valid results. + + I.e. return an empty dict if either the input or the config is empty and + return the parsed metrics otherwise. 
+ """ + parser = RegexOutputParser(name="Test", regex_config=config) + assert parser.name == "Test" + assert isinstance(parser, OutputParser) + res = parser(output) + assert res == expected_metrics + + @staticmethod + def test_unsupported_type() -> None: + """An unsupported type in the regex_config must raise an exception.""" + config = {"BrokenMetric": {"pattern": "(.*)", "type": "UNSUPPORTED_TYPE"}} + with pytest.raises(TypeError): + RegexOutputParser(name="Test", regex_config=config) + + @staticmethod + @pytest.mark.parametrize( + "config", + ( + {"TooManyGroups": {"pattern": r"(\w)(\d)", "type": "str"}}, + {"NoGroups": {"pattern": r"\W", "type": "str"}}, + ), + ) + def test_invalid_pattern(config: Dict) -> None: + """Exactly one capturing parenthesis is allowed in the regex pattern.""" + with pytest.raises(ValueError): + RegexOutputParser(name="Test", regex_config=config) + + +@pytest.mark.parametrize( + "expected_metrics", + [ + EXPECTED_METRICS_ALL, + EXPECTED_METRICS_PARTIAL, + ], +) +def test_base64_output_parser(expected_metrics: Dict) -> None: + """ + Make sure the Base64OutputParser yields valid results. + + I.e. return an empty dict if either the input or the config is empty and + return the parsed metrics otherwise. + """ + parser = Base64OutputParser(name="Test") + assert parser.name == "Test" + assert isinstance(parser, OutputParser) + + def create_base64_output(expected_metrics: Dict) -> bytearray: + json_str = json.dumps(expected_metrics, indent=4) + json_b64 = base64.b64encode(json_str.encode("utf-8")) + return ( + OUTPUT_MATCH_ALL # Should not be matched by the Base64OutputParser + + f"<{Base64OutputParser.TAG_NAME}>".encode("utf-8") + + bytearray(json_b64) + + f"</{Base64OutputParser.TAG_NAME}>".encode("utf-8") + + OUTPUT_NO_MATCH # Just to add some difficulty... 
+ ) + + output = create_base64_output(expected_metrics) + res = parser(output) + assert len(res) == 1 + assert isinstance(res, dict) + for val in res.values(): + assert val == expected_metrics + + output = parser.filter_out_parsed_content(output) + assert output == (OUTPUT_MATCH_ALL + OUTPUT_NO_MATCH) diff --git a/tests/aiet/test_backend_protocol.py b/tests/aiet/test_backend_protocol.py new file mode 100644 index 0000000..2103238 --- /dev/null +++ b/tests/aiet/test_backend_protocol.py @@ -0,0 +1,231 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +# pylint: disable=no-self-use,attribute-defined-outside-init,protected-access +"""Tests for the protocol backend module.""" +from contextlib import ExitStack as does_not_raise +from pathlib import Path +from typing import Any +from unittest.mock import MagicMock + +import paramiko +import pytest + +from aiet.backend.common import ConfigurationException +from aiet.backend.config import LocalProtocolConfig +from aiet.backend.protocol import CustomSFTPClient +from aiet.backend.protocol import LocalProtocol +from aiet.backend.protocol import ProtocolFactory +from aiet.backend.protocol import SSHProtocol + + +class TestProtocolFactory: + """Test ProtocolFactory class.""" + + @pytest.mark.parametrize( + "config, expected_class, exception", + [ + ( + { + "protocol": "ssh", + "username": "user", + "password": "pass", + "hostname": "hostname", + "port": "22", + }, + SSHProtocol, + does_not_raise(), + ), + ({"protocol": "local"}, LocalProtocol, does_not_raise()), + ( + {"protocol": "something"}, + None, + pytest.raises(Exception, match="Protocol not supported"), + ), + (None, None, pytest.raises(Exception, match="No protocol config provided")), + ], + ) + def test_get_protocol( + self, config: Any, expected_class: type, exception: Any + ) -> None: + """Test get_protocol method.""" + factory = ProtocolFactory() + with exception: + protocol = 
factory.get_protocol(config) + assert isinstance(protocol, expected_class) + + +class TestLocalProtocol: + """Test local protocol.""" + + def test_local_protocol_run_command(self) -> None: + """Test local protocol run command.""" + config = LocalProtocolConfig(protocol="local") + protocol = LocalProtocol(config, cwd=Path("/tmp")) + ret, stdout, stderr = protocol.run("pwd") + assert ret == 0 + assert stdout.decode("utf-8").strip() == "/tmp" + assert stderr.decode("utf-8") == "" + + def test_local_protocol_run_wrong_cwd(self) -> None: + """Execution should fail if wrong working directory provided.""" + config = LocalProtocolConfig(protocol="local") + protocol = LocalProtocol(config, cwd=Path("unknown_directory")) + with pytest.raises( + ConfigurationException, match="Wrong working directory unknown_directory" + ): + protocol.run("pwd") + + +class TestSSHProtocol: + """Test SSH protocol.""" + + @pytest.fixture(autouse=True) + def setup_method(self, monkeypatch: Any) -> None: + """Set up protocol mocks.""" + self.mock_ssh_client = MagicMock(spec=paramiko.client.SSHClient) + + self.mock_ssh_channel = ( + self.mock_ssh_client.get_transport.return_value.open_session.return_value + ) + self.mock_ssh_channel.mock_add_spec(spec=paramiko.channel.Channel) + self.mock_ssh_channel.exit_status_ready.side_effect = [False, True] + self.mock_ssh_channel.recv_exit_status.return_value = True + self.mock_ssh_channel.recv_ready.side_effect = [False, True] + self.mock_ssh_channel.recv_stderr_ready.side_effect = [False, True] + + monkeypatch.setattr( + "aiet.backend.protocol.paramiko.client.SSHClient", + MagicMock(return_value=self.mock_ssh_client), + ) + + self.mock_sftp_client = MagicMock(spec=CustomSFTPClient) + monkeypatch.setattr( + "aiet.backend.protocol.CustomSFTPClient.from_transport", + MagicMock(return_value=self.mock_sftp_client), + ) + + ssh_config = { + "protocol": "ssh", + "username": "user", + "password": "pass", + "hostname": "hostname", + "port": "22", + } + self.protocol 
= SSHProtocol(ssh_config) + + def test_unable_create_ssh_client(self, monkeypatch: Any) -> None: + """Test that command should fail if unable to create ssh client instance.""" + monkeypatch.setattr( + "aiet.backend.protocol.paramiko.client.SSHClient", + MagicMock(side_effect=OSError("Error!")), + ) + + with pytest.raises(Exception, match="Couldn't connect to 'hostname:22'"): + self.protocol.run("command_example", retry=False) + + def test_ssh_protocol_run_command(self) -> None: + """Test that command run via ssh successfully.""" + self.protocol.run("command_example") + self.mock_ssh_channel.exec_command.assert_called_once() + + def test_ssh_protocol_run_command_connect_failed(self) -> None: + """Test that if connection is not possible then correct exception is raised.""" + self.mock_ssh_client.connect.side_effect = OSError("Unable to connect") + self.mock_ssh_client.close.side_effect = Exception("Error!") + + with pytest.raises(Exception, match="Couldn't connect to 'hostname:22'"): + self.protocol.run("command_example", retry=False) + + def test_ssh_protocol_run_command_bad_transport(self) -> None: + """Test that command should fail if unable to get transport.""" + self.mock_ssh_client.get_transport.return_value = None + + with pytest.raises(Exception, match="Unable to get transport"): + self.protocol.run("command_example", retry=False) + + def test_ssh_protocol_deploy_command_file( + self, test_applications_path: Path + ) -> None: + """Test that files could be deployed over ssh.""" + file_for_deploy = test_applications_path / "readme.txt" + dest = "/tmp/dest" + + self.protocol.deploy(file_for_deploy, dest) + self.mock_sftp_client.put.assert_called_once_with(str(file_for_deploy), dest) + + def test_ssh_protocol_deploy_command_unknown_file(self) -> None: + """Test that deploy will fail if file does not exist.""" + with pytest.raises(Exception, match="Deploy error: file type not supported"): + self.protocol.deploy(Path("unknown_file"), "/tmp/dest") + + def 
test_ssh_protocol_deploy_command_bad_transport(self) -> None: + """Test that deploy should fail if unable to get transport.""" + self.mock_ssh_client.get_transport.return_value = None + + with pytest.raises(Exception, match="Unable to get transport"): + self.protocol.deploy(Path("some_file"), "/tmp/dest") + + def test_ssh_protocol_deploy_command_directory( + self, test_resources_path: Path + ) -> None: + """Test that directory could be deployed over ssh.""" + directory_for_deploy = test_resources_path / "scripts" + dest = "/tmp/dest" + + self.protocol.deploy(directory_for_deploy, dest) + self.mock_sftp_client.put_dir.assert_called_once_with( + directory_for_deploy, dest + ) + + @pytest.mark.parametrize("establish_connection", (True, False)) + def test_ssh_protocol_close(self, establish_connection: bool) -> None: + """Test protocol close operation.""" + if establish_connection: + self.protocol.establish_connection() + self.protocol.close() + + call_count = 1 if establish_connection else 0 + assert self.mock_ssh_channel.exec_command.call_count == call_count + + def test_connection_details(self) -> None: + """Test getting connection details.""" + assert self.protocol.connection_details() == ("hostname", 22) + + +class TestCustomSFTPClient: + """Test CustomSFTPClient class.""" + + @pytest.fixture(autouse=True) + def setup_method(self, monkeypatch: Any) -> None: + """Set up mocks for CustomSFTPClient instance.""" + self.mock_mkdir = MagicMock() + self.mock_put = MagicMock() + monkeypatch.setattr( + "aiet.backend.protocol.paramiko.SFTPClient.__init__", + MagicMock(return_value=None), + ) + monkeypatch.setattr( + "aiet.backend.protocol.paramiko.SFTPClient.mkdir", self.mock_mkdir + ) + monkeypatch.setattr( + "aiet.backend.protocol.paramiko.SFTPClient.put", self.mock_put + ) + + self.sftp_client = CustomSFTPClient(MagicMock()) + + def test_put_dir(self, test_systems_path: Path) -> None: + """Test deploying directory to remote host.""" + directory_for_deploy = 
test_systems_path / "system1" + + self.sftp_client.put_dir(directory_for_deploy, "/tmp/dest") + assert self.mock_put.call_count == 3 + assert self.mock_mkdir.call_count == 3 + + def test_mkdir(self) -> None: + """Test creating directory on remote host.""" + self.mock_mkdir.side_effect = IOError("Cannot create directory") + + self.sftp_client._mkdir("new_directory", ignore_existing=True) + + with pytest.raises(IOError, match="Cannot create directory"): + self.sftp_client._mkdir("new_directory", ignore_existing=False) diff --git a/tests/aiet/test_backend_source.py b/tests/aiet/test_backend_source.py new file mode 100644 index 0000000..13b2c6d --- /dev/null +++ b/tests/aiet/test_backend_source.py @@ -0,0 +1,199 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +# pylint: disable=no-self-use +"""Tests for the source backend module.""" +from collections import Counter +from contextlib import ExitStack as does_not_raise +from pathlib import Path +from typing import Any +from unittest.mock import MagicMock +from unittest.mock import patch + +import pytest + +from aiet.backend.common import ConfigurationException +from aiet.backend.source import create_destination_and_install +from aiet.backend.source import DirectorySource +from aiet.backend.source import get_source +from aiet.backend.source import TarArchiveSource + + +def test_create_destination_and_install(test_systems_path: Path, tmpdir: Any) -> None: + """Test create_destination_and_install function.""" + system_directory = test_systems_path / "system1" + + dir_source = DirectorySource(system_directory) + resources = Path(tmpdir) + create_destination_and_install(dir_source, resources) + assert (resources / "system1").is_dir() + + +@patch("aiet.backend.source.DirectorySource.create_destination", return_value=False) +def test_create_destination_and_install_if_dest_creation_not_required( + mock_ds_create_destination: Any, tmpdir: Any +) -> None: + 
"""Test create_destination_and_install function.""" + dir_source = DirectorySource(Path("unknown")) + resources = Path(tmpdir) + with pytest.raises(Exception): + create_destination_and_install(dir_source, resources) + + mock_ds_create_destination.assert_called_once() + + +def test_create_destination_and_install_if_installation_fails(tmpdir: Any) -> None: + """Test create_destination_and_install function if installation fails.""" + dir_source = DirectorySource(Path("unknown")) + resources = Path(tmpdir) + with pytest.raises(Exception, match="Directory .* does not exist"): + create_destination_and_install(dir_source, resources) + assert not (resources / "unknown").exists() + assert resources.exists() + + +def test_create_destination_and_install_if_name_is_empty() -> None: + """Test create_destination_and_install function fails if source name is empty.""" + source = MagicMock() + source.create_destination.return_value = True + source.name.return_value = None + + with pytest.raises(Exception, match="Unable to get source name"): + create_destination_and_install(source, Path("some_path")) + + source.install_into.assert_not_called() + + +@pytest.mark.parametrize( + "source_path, expected_class, expected_error", + [ + (Path("applications/application1/"), DirectorySource, does_not_raise()), + ( + Path("archives/applications/application1.tar.gz"), + TarArchiveSource, + does_not_raise(), + ), + ( + Path("doesnt/exist"), + None, + pytest.raises( + ConfigurationException, match="Unable to read .*doesnt/exist" + ), + ), + ], +) +def test_get_source( + source_path: Path, + expected_class: Any, + expected_error: Any, + test_resources_path: Path, +) -> None: + """Test get_source function.""" + with expected_error: + full_source_path = test_resources_path / source_path + source = get_source(full_source_path) + assert isinstance(source, expected_class) + + +class TestDirectorySource: + """Test DirectorySource class.""" + + @pytest.mark.parametrize( + "directory, name", + [ + 
(Path("/some/path/some_system"), "some_system"), + (Path("some_system"), "some_system"), + ], + ) + def test_name(self, directory: Path, name: str) -> None: + """Test getting source name.""" + assert DirectorySource(directory).name() == name + + def test_install_into(self, test_systems_path: Path, tmpdir: Any) -> None: + """Test install directory into destination.""" + system_directory = test_systems_path / "system1" + + dir_source = DirectorySource(system_directory) + with pytest.raises(Exception, match="Wrong destination .*"): + dir_source.install_into(Path("unknown_destination")) + + tmpdir_path = Path(tmpdir) + dir_source.install_into(tmpdir_path) + source_files = [f.name for f in system_directory.iterdir()] + dest_files = [f.name for f in tmpdir_path.iterdir()] + assert Counter(source_files) == Counter(dest_files) + + def test_install_into_unknown_source_directory(self, tmpdir: Any) -> None: + """Test install system from unknown directory.""" + with pytest.raises(Exception, match="Directory .* does not exist"): + DirectorySource(Path("unknown_directory")).install_into(Path(tmpdir)) + + +class TestTarArchiveSource: + """Test TarArchiveSource class.""" + + @pytest.mark.parametrize( + "archive, name", + [ + (Path("some_archive.tgz"), "some_archive"), + (Path("some_archive.tar.gz"), "some_archive"), + (Path("some_archive"), "some_archive"), + ("archives/systems/system1.tar.gz", "system1"), + ("archives/systems/system1_dir.tar.gz", "system1"), + ], + ) + def test_name(self, test_resources_path: Path, archive: Path, name: str) -> None: + """Test getting source name.""" + assert TarArchiveSource(test_resources_path / archive).name() == name + + def test_install_into(self, test_resources_path: Path, tmpdir: Any) -> None: + """Test install archive into destination.""" + system_archive = test_resources_path / "archives/systems/system1.tar.gz" + + tar_source = TarArchiveSource(system_archive) + with pytest.raises(Exception, match="Wrong destination .*"): + 
tar_source.install_into(Path("unknown_destination")) + + tmpdir_path = Path(tmpdir) + tar_source.install_into(tmpdir_path) + source_files = [ + "aiet-config.json.license", + "aiet-config.json", + "system_artifact", + ] + dest_files = [f.name for f in tmpdir_path.iterdir()] + assert Counter(source_files) == Counter(dest_files) + + def test_install_into_unknown_source_archive(self, tmpdir: Any) -> None: + """Test install unknown source archive.""" + with pytest.raises(Exception, match="File .* does not exist"): + TarArchiveSource(Path("unknown.tar.gz")).install_into(Path(tmpdir)) + + def test_install_into_unsupported_source_archive(self, tmpdir: Any) -> None: + """Test install unsupported file type.""" + plain_text_file = Path(tmpdir) / "test_file" + plain_text_file.write_text("Not a system config") + + with pytest.raises(Exception, match="Unsupported archive type .*"): + TarArchiveSource(plain_text_file).install_into(Path(tmpdir)) + + def test_lazy_property_init(self, test_resources_path: Path) -> None: + """Test that class properties initialized correctly.""" + system_archive = test_resources_path / "archives/systems/system1.tar.gz" + + tar_source = TarArchiveSource(system_archive) + assert tar_source.name() == "system1" + assert tar_source.config() is not None + assert tar_source.create_destination() + + tar_source = TarArchiveSource(system_archive) + assert tar_source.config() is not None + assert tar_source.create_destination() + assert tar_source.name() == "system1" + + def test_create_destination_property(self, test_resources_path: Path) -> None: + """Test create_destination property filled correctly for different archives.""" + system_archive1 = test_resources_path / "archives/systems/system1.tar.gz" + system_archive2 = test_resources_path / "archives/systems/system1_dir.tar.gz" + + assert TarArchiveSource(system_archive1).create_destination() + assert not TarArchiveSource(system_archive2).create_destination() diff --git a/tests/aiet/test_backend_system.py 
b/tests/aiet/test_backend_system.py new file mode 100644 index 0000000..a581547 --- /dev/null +++ b/tests/aiet/test_backend_system.py @@ -0,0 +1,536 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Tests for system backend.""" +from contextlib import ExitStack as does_not_raise +from pathlib import Path +from typing import Any +from typing import Callable +from typing import Dict +from typing import List +from typing import Optional +from typing import Tuple +from unittest.mock import MagicMock + +import pytest + +from aiet.backend.common import Command +from aiet.backend.common import ConfigurationException +from aiet.backend.common import Param +from aiet.backend.common import UserParamConfig +from aiet.backend.config import LocalProtocolConfig +from aiet.backend.config import ProtocolConfig +from aiet.backend.config import SSHConfig +from aiet.backend.config import SystemConfig +from aiet.backend.controller import SystemController +from aiet.backend.controller import SystemControllerSingleInstance +from aiet.backend.protocol import LocalProtocol +from aiet.backend.protocol import SSHProtocol +from aiet.backend.protocol import SupportsClose +from aiet.backend.protocol import SupportsDeploy +from aiet.backend.system import ControlledSystem +from aiet.backend.system import get_available_systems +from aiet.backend.system import get_controller +from aiet.backend.system import get_system +from aiet.backend.system import install_system +from aiet.backend.system import load_system +from aiet.backend.system import remove_system +from aiet.backend.system import StandaloneSystem +from aiet.backend.system import System + + +def dummy_resolver( + values: Optional[Dict[str, str]] = None +) -> Callable[[str, str, List[Tuple[Optional[str], Param]]], str]: + """Return dummy parameter resolver implementation.""" + # pylint: disable=unused-argument + def resolver( + param: str, cmd: str, param_values: 
List[Tuple[Optional[str], Param]] + ) -> str: + """Implement dummy parameter resolver.""" + return values.get(param, "") if values else "" + + return resolver + + +def test_get_available_systems() -> None: + """Test get_available_systems mocking get_resources.""" + available_systems = get_available_systems() + assert all(isinstance(s, System) for s in available_systems) + assert len(available_systems) == 3 + assert [str(s) for s in available_systems] == ["System 1", "System 2", "System 4"] + + +def test_get_system() -> None: + """Test get_system.""" + system1 = get_system("System 1") + assert isinstance(system1, ControlledSystem) + assert system1.connectable is True + assert system1.connection_details() == ("localhost", 8021) + assert system1.name == "System 1" + + system2 = get_system("System 2") + # check that comparison with object of another type returns false + assert system1 != 42 + assert system1 != system2 + + system = get_system("Unknown system") + assert system is None + + +@pytest.mark.parametrize( + "source, call_count, exception_type", + ( + ( + "archives/systems/system1.tar.gz", + 0, + pytest.raises(Exception, match="Systems .* are already installed"), + ), + ( + "archives/systems/system3.tar.gz", + 0, + pytest.raises(Exception, match="Unable to read system definition"), + ), + ( + "systems/system1", + 0, + pytest.raises(Exception, match="Systems .* are already installed"), + ), + ( + "systems/system3", + 0, + pytest.raises(Exception, match="Unable to read system definition"), + ), + ("unknown_path", 0, pytest.raises(Exception, match="Unable to read")), + ( + "various/systems/system_with_empty_config", + 0, + pytest.raises(Exception, match="No system definition found"), + ), + ("various/systems/system_with_valid_config", 1, does_not_raise()), + ), +) +def test_install_system( + monkeypatch: Any, + test_resources_path: Path, + source: str, + call_count: int, + exception_type: Any, +) -> None: + """Test system installation from archive.""" + 
mock_create_destination_and_install = MagicMock() + monkeypatch.setattr( + "aiet.backend.system.create_destination_and_install", + mock_create_destination_and_install, + ) + + with exception_type: + install_system(test_resources_path / source) + + assert mock_create_destination_and_install.call_count == call_count + + +def test_remove_system(monkeypatch: Any) -> None: + """Test system removal.""" + mock_remove_backend = MagicMock() + monkeypatch.setattr("aiet.backend.system.remove_backend", mock_remove_backend) + remove_system("some_system_dir") + mock_remove_backend.assert_called_once() + + +def test_system(monkeypatch: Any) -> None: + """Test the System class.""" + config = SystemConfig(name="System 1") + monkeypatch.setattr("aiet.backend.system.ProtocolFactory", MagicMock()) + system = System(config) + assert str(system) == "System 1" + assert system.name == "System 1" + + +def test_system_with_empty_parameter_name() -> None: + """Test that configuration fails if parameter name is empty.""" + bad_config = SystemConfig( + name="System 1", + commands={"run": ["run"]}, + user_params={"run": [{"name": "", "values": ["1", "2", "3"]}]}, + ) + with pytest.raises(Exception, match="Parameter has an empty 'name' attribute."): + System(bad_config) + + +def test_system_standalone_run() -> None: + """Test run operation for standalone system.""" + system = get_system("System 4") + assert isinstance(system, StandaloneSystem) + + with pytest.raises( + ConfigurationException, match="System .* does not support connections" + ): + system.connection_details() + + with pytest.raises( + ConfigurationException, match="System .* does not support connections" + ): + system.establish_connection() + + assert system.connectable is False + + system.run("echo 'application run'") + + +@pytest.mark.parametrize( + "system_name, expected_value", [("System 1", True), ("System 4", False)] +) +def test_system_supports_deploy(system_name: str, expected_value: bool) -> None: + """Test system property 
supports_deploy.""" + system = get_system(system_name) + if system is None: + pytest.fail("Unable to get system {}".format(system_name)) + assert system.supports_deploy == expected_value + + +@pytest.mark.parametrize( + "mock_protocol", + [ + MagicMock(spec=SSHProtocol), + MagicMock( + spec=SSHProtocol, + **{"close.side_effect": ValueError("Unable to close protocol")} + ), + MagicMock(spec=LocalProtocol), + ], +) +def test_system_start_and_stop(monkeypatch: Any, mock_protocol: MagicMock) -> None: + """Test system start, run commands and stop.""" + monkeypatch.setattr( + "aiet.backend.system.ProtocolFactory.get_protocol", + MagicMock(return_value=mock_protocol), + ) + + system = get_system("System 1") + if system is None: + pytest.fail("Unable to get system") + assert isinstance(system, ControlledSystem) + + with pytest.raises(Exception, match="System has not been started"): + system.stop() + + assert not system.is_running() + assert system.get_output() == ("", "") + system.start(["sleep 10"], False) + assert system.is_running() + system.stop(wait=True) + assert not system.is_running() + assert system.get_output() == ("", "") + + if isinstance(mock_protocol, SupportsClose): + mock_protocol.close.assert_called_once() + + if isinstance(mock_protocol, SSHProtocol): + system.establish_connection() + + +def test_system_start_no_config_location() -> None: + """Test that system without config location could not start.""" + system = load_system( + SystemConfig( + name="test", + data_transfer=SSHConfig( + protocol="ssh", + username="user", + password="user", + hostname="localhost", + port="123", + ), + ) + ) + + assert isinstance(system, ControlledSystem) + with pytest.raises( + ConfigurationException, match="System test has wrong config location" + ): + system.start(["sleep 100"]) + + +@pytest.mark.parametrize( + "config, expected_class, expected_error", + [ + ( + SystemConfig( + name="test", + data_transfer=SSHConfig( + protocol="ssh", + username="user", + password="user", 
+ hostname="localhost", + port="123", + ), + ), + ControlledSystem, + does_not_raise(), + ), + ( + SystemConfig( + name="test", data_transfer=LocalProtocolConfig(protocol="local") + ), + StandaloneSystem, + does_not_raise(), + ), + ( + SystemConfig( + name="test", + data_transfer=ProtocolConfig(protocol="cool_protocol"), # type: ignore + ), + None, + pytest.raises( + Exception, match="Unsupported execution type for protocol cool_protocol" + ), + ), + ], +) +def test_load_system( + config: SystemConfig, expected_class: type, expected_error: Any +) -> None: + """Test load_system function.""" + if not expected_class: + with expected_error: + load_system(config) + else: + system = load_system(config) + assert isinstance(system, expected_class) + + +def test_load_system_populate_shared_params() -> None: + """Test shared parameters population.""" + with pytest.raises(Exception, match="All shared parameters should have aliases"): + load_system( + SystemConfig( + name="test_system", + data_transfer=LocalProtocolConfig(protocol="local"), + user_params={ + "shared": [ + UserParamConfig( + name="--shared_param1", + description="Shared parameter", + values=["1", "2", "3"], + default_value="1", + ) + ] + }, + ) + ) + + with pytest.raises( + Exception, match="All parameters for command run should have aliases" + ): + load_system( + SystemConfig( + name="test_system", + data_transfer=LocalProtocolConfig(protocol="local"), + user_params={ + "shared": [ + UserParamConfig( + name="--shared_param1", + description="Shared parameter", + values=["1", "2", "3"], + default_value="1", + alias="shared_param1", + ) + ], + "run": [ + UserParamConfig( + name="--run_param1", + description="Run specific parameter", + values=["1", "2", "3"], + default_value="2", + ) + ], + }, + ) + ) + system0 = load_system( + SystemConfig( + name="test_system", + data_transfer=LocalProtocolConfig(protocol="local"), + commands={"run": ["run_command"]}, + user_params={ + "shared": [], + "run": [ + UserParamConfig( 
+ name="--run_param1", + description="Run specific parameter", + values=["1", "2", "3"], + default_value="2", + alias="run_param1", + ) + ], + }, + ) + ) + assert len(system0.commands) == 1 + run_command1 = system0.commands["run"] + assert run_command1 == Command( + ["run_command"], + [ + Param( + "--run_param1", + "Run specific parameter", + ["1", "2", "3"], + "2", + "run_param1", + ) + ], + ) + + system1 = load_system( + SystemConfig( + name="test_system", + data_transfer=LocalProtocolConfig(protocol="local"), + user_params={ + "shared": [ + UserParamConfig( + name="--shared_param1", + description="Shared parameter", + values=["1", "2", "3"], + default_value="1", + alias="shared_param1", + ) + ], + "run": [ + UserParamConfig( + name="--run_param1", + description="Run specific parameter", + values=["1", "2", "3"], + default_value="2", + alias="run_param1", + ) + ], + }, + ) + ) + assert len(system1.commands) == 2 + build_command1 = system1.commands["build"] + assert build_command1 == Command( + [], + [ + Param( + "--shared_param1", + "Shared parameter", + ["1", "2", "3"], + "1", + "shared_param1", + ) + ], + ) + + run_command1 = system1.commands["run"] + assert run_command1 == Command( + [], + [ + Param( + "--shared_param1", + "Shared parameter", + ["1", "2", "3"], + "1", + "shared_param1", + ), + Param( + "--run_param1", + "Run specific parameter", + ["1", "2", "3"], + "2", + "run_param1", + ), + ], + ) + + system2 = load_system( + SystemConfig( + name="test_system", + data_transfer=LocalProtocolConfig(protocol="local"), + commands={"build": ["build_command"]}, + user_params={ + "shared": [ + UserParamConfig( + name="--shared_param1", + description="Shared parameter", + values=["1", "2", "3"], + default_value="1", + alias="shared_param1", + ) + ], + "run": [ + UserParamConfig( + name="--run_param1", + description="Run specific parameter", + values=["1", "2", "3"], + default_value="2", + alias="run_param1", + ) + ], + }, + ) + ) + assert len(system2.commands) == 2 
+ build_command2 = system2.commands["build"] + assert build_command2 == Command( + ["build_command"], + [ + Param( + "--shared_param1", + "Shared parameter", + ["1", "2", "3"], + "1", + "shared_param1", + ) + ], + ) + + run_command2 = system1.commands["run"] + assert run_command2 == Command( + [], + [ + Param( + "--shared_param1", + "Shared parameter", + ["1", "2", "3"], + "1", + "shared_param1", + ), + Param( + "--run_param1", + "Run specific parameter", + ["1", "2", "3"], + "2", + "run_param1", + ), + ], + ) + + +@pytest.mark.parametrize( + "mock_protocol, expected_call_count", + [(MagicMock(spec=SupportsDeploy), 1), (MagicMock(), 0)], +) +def test_system_deploy_data( + monkeypatch: Any, mock_protocol: MagicMock, expected_call_count: int +) -> None: + """Test deploy data functionality.""" + monkeypatch.setattr( + "aiet.backend.system.ProtocolFactory.get_protocol", + MagicMock(return_value=mock_protocol), + ) + + system = ControlledSystem(SystemConfig(name="test")) + system.deploy(Path("some_file"), "some_dest") + + assert mock_protocol.deploy.call_count == expected_call_count + + +@pytest.mark.parametrize( + "single_instance, controller_class", + ((False, SystemController), (True, SystemControllerSingleInstance)), +) +def test_get_controller(single_instance: bool, controller_class: type) -> None: + """Test function get_controller.""" + controller = get_controller(single_instance) + assert isinstance(controller, controller_class) diff --git a/tests/aiet/test_backend_tool.py b/tests/aiet/test_backend_tool.py new file mode 100644 index 0000000..fd5960d --- /dev/null +++ b/tests/aiet/test_backend_tool.py @@ -0,0 +1,60 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. 
# SPDX-License-Identifier: Apache-2.0
# pylint: disable=no-self-use
"""Tests for the tool backend."""
from collections import Counter

import pytest

from aiet.backend.common import ConfigurationException
from aiet.backend.config import ToolConfig
from aiet.backend.tool import get_available_tool_directory_names
from aiet.backend.tool import get_available_tools
from aiet.backend.tool import get_tool
from aiet.backend.tool import Tool


def test_get_available_tool_directory_names() -> None:
    """Test get_available_tool_directory_names mocking get_resources."""
    directory_names = get_available_tool_directory_names()
    # Counter comparison: same names regardless of ordering.
    assert Counter(directory_names) == Counter(["tool1", "tool2", "vela"])


def test_get_available_tools() -> None:
    """Test get_available_tools mocking get_resources."""
    available_tools = get_available_tools()
    # "vela" appears three times: one per vela target configuration.
    expected_tool_names = sorted(
        [
            "tool_1",
            "tool_2",
            "vela",
            "vela",
            "vela",
        ]
    )

    assert all(isinstance(s, Tool) for s in available_tools)
    # Comparison with an object of another type must return False, not raise.
    assert all(s != 42 for s in available_tools)
    assert any(s == available_tools[0] for s in available_tools)
    assert len(available_tools) == len(expected_tool_names)
    available_tool_names = sorted(str(s) for s in available_tools)
    assert available_tool_names == expected_tool_names


def test_get_tool() -> None:
    """Test get_tool mocking get_resources."""
    tools = get_tool("tool_1")
    assert len(tools) == 1
    tool = tools[0]
    assert tool is not None
    assert isinstance(tool, Tool)
    assert tool.name == "tool_1"

    # Unknown tool name yields an empty result rather than an error.
    tools = get_tool("unknown tool")
    assert not tools


def test_tool_creation() -> None:
    """Test edge cases when creating a Tool instance."""
    with pytest.raises(ConfigurationException):
        Tool(ToolConfig(name="test", commands={"test": []}))  # no 'run' command
# --- cgit diff listing metadata for the next file in the commit ---
# diff --git a/tests/aiet/test_check_model.py b/tests/aiet/test_check_model.py
# new file mode 100644
# index 0000000..4eafe59
# --- /dev/null
# +++ b/tests/aiet/test_check_model.py
# @@ -0,0 +1,162 @@
# 
SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +# pylint: disable=redefined-outer-name,no-self-use +"""Module for testing check_model.py script.""" +from pathlib import Path +from typing import Any + +import pytest +from ethosu.vela.tflite.Model import Model +from ethosu.vela.tflite.OperatorCode import OperatorCode + +from aiet.cli.common import InvalidTFLiteFileError +from aiet.cli.common import ModelOptimisedException +from aiet.resources.tools.vela.check_model import check_custom_codes_for_ethosu +from aiet.resources.tools.vela.check_model import check_model +from aiet.resources.tools.vela.check_model import get_custom_codes_from_operators +from aiet.resources.tools.vela.check_model import get_model_from_file +from aiet.resources.tools.vela.check_model import get_operators_from_model +from aiet.resources.tools.vela.check_model import is_vela_optimised + + +@pytest.fixture(scope="session") +def optimised_tflite_model( + optimised_input_model_file: Path, +) -> Model: + """Return Model instance read from a Vela-optimised TFLite file.""" + return get_model_from_file(optimised_input_model_file) + + +@pytest.fixture(scope="session") +def non_optimised_tflite_model( + non_optimised_input_model_file: Path, +) -> Model: + """Return Model instance read from a Vela-optimised TFLite file.""" + return get_model_from_file(non_optimised_input_model_file) + + +class TestIsVelaOptimised: + """Test class for is_vela_optimised() function.""" + + def test_return_true_when_input_is_optimised( + self, + optimised_tflite_model: Model, + ) -> None: + """Verify True returned when input is optimised model.""" + output = is_vela_optimised(optimised_tflite_model) + + assert output is True + + def test_return_false_when_input_is_not_optimised( + self, + non_optimised_tflite_model: Model, + ) -> None: + """Verify False returned when input is non-optimised model.""" + output = is_vela_optimised(non_optimised_tflite_model) + + 
assert output is False + + +def test_get_operator_list_returns_correct_instances( + optimised_tflite_model: Model, +) -> None: + """Verify list of OperatorCode instances returned by get_operator_list().""" + operator_list = get_operators_from_model(optimised_tflite_model) + + assert all(isinstance(operator, OperatorCode) for operator in operator_list) + + +class TestGetCustomCodesFromOperators: + """Test the get_custom_codes_from_operators() function.""" + + def test_returns_empty_list_when_input_operators_have_no_custom_codes( + self, monkeypatch: Any + ) -> None: + """Verify function returns empty list when operators have no custom codes.""" + # Mock OperatorCode.CustomCode() function to return None + monkeypatch.setattr( + "ethosu.vela.tflite.OperatorCode.OperatorCode.CustomCode", lambda _: None + ) + + operators = [OperatorCode()] * 3 + + custom_codes = get_custom_codes_from_operators(operators) + + assert custom_codes == [] + + def test_returns_custom_codes_when_input_operators_have_custom_codes( + self, monkeypatch: Any + ) -> None: + """Verify list of bytes objects returned representing the CustomCodes.""" + # Mock OperatorCode.CustomCode() function to return a byte string + monkeypatch.setattr( + "ethosu.vela.tflite.OperatorCode.OperatorCode.CustomCode", + lambda _: b"custom-code", + ) + + operators = [OperatorCode()] * 3 + + custom_codes = get_custom_codes_from_operators(operators) + + assert custom_codes == [b"custom-code", b"custom-code", b"custom-code"] + + +@pytest.mark.parametrize( + "custom_codes, expected_output", + [ + ([b"ethos-u", b"something else"], True), + ([b"custom-code-1", b"custom-code-2"], False), + ], +) +def test_check_list_for_ethosu(custom_codes: list, expected_output: bool) -> None: + """Verify function detects 'ethos-u' bytes in the input list.""" + output = check_custom_codes_for_ethosu(custom_codes) + assert output is expected_output + + +class TestGetModelFromFile: + """Test the get_model_from_file() function.""" + + def 
test_error_raised_when_input_is_invalid_model_file( + self, + invalid_input_model_file: Path, + ) -> None: + """Verify error thrown when an invalid model file is given.""" + with pytest.raises(InvalidTFLiteFileError): + get_model_from_file(invalid_input_model_file) + + def test_model_instance_returned_when_input_is_valid_model_file( + self, + optimised_input_model_file: Path, + ) -> None: + """Verify file is read successfully and returns model instance.""" + tflite_model = get_model_from_file(optimised_input_model_file) + + assert isinstance(tflite_model, Model) + + +class TestCheckModel: + """Test the check_model() function.""" + + def test_check_model_with_non_optimised_input( + self, + non_optimised_input_model_file: Path, + ) -> None: + """Verify no error occurs for a valid input file.""" + check_model(non_optimised_input_model_file) + + def test_check_model_with_optimised_input( + self, + optimised_input_model_file: Path, + ) -> None: + """Verify that the right exception is raised with already optimised input.""" + with pytest.raises(ModelOptimisedException): + check_model(optimised_input_model_file) + + def test_check_model_with_invalid_input( + self, + invalid_input_model_file: Path, + ) -> None: + """Verify that an exception is raised with invalid input.""" + with pytest.raises(Exception): + check_model(invalid_input_model_file) diff --git a/tests/aiet/test_cli.py b/tests/aiet/test_cli.py new file mode 100644 index 0000000..e8589fa --- /dev/null +++ b/tests/aiet/test_cli.py @@ -0,0 +1,37 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. 
# SPDX-License-Identifier: Apache-2.0
"""Module for testing CLI top command."""
from typing import Any
from unittest.mock import ANY
from unittest.mock import MagicMock

from click.testing import CliRunner

from aiet.cli import cli


def test_cli(cli_runner: CliRunner) -> None:
    """Test CLI top level command."""
    invocation = cli_runner.invoke(cli)
    assert invocation.exit_code == 0
    # Both major subcommand groups must be registered on the top-level group.
    for expected_command in ("system", "application"):
        assert expected_command in cli.commands


def test_cli_version(cli_runner: CliRunner) -> None:
    """Test version option."""
    invocation = cli_runner.invoke(cli, ["--version"])
    assert invocation.exit_code == 0
    assert "version" in invocation.output


def test_cli_verbose(cli_runner: CliRunner, monkeypatch: Any) -> None:
    """Test verbose option."""
    with monkeypatch.context() as patch_ctx:
        callback_mock = MagicMock()
        # params[1] is the verbose option; replace its callback with a mock
        # so we can observe the count click passes to it.
        patch_ctx.setattr(cli.params[1], "callback", callback_mock)
        cli_runner.invoke(cli, ["-vvvv"])
        # The third argument is 4 — the number of -v flags supplied above.
        callback_mock.assert_called_once_with(ANY, ANY, 4)
# --- cgit diff listing metadata for the next file in the commit ---
# diff --git a/tests/aiet/test_cli_application.py b/tests/aiet/test_cli_application.py
# new file mode 100644
# index 0000000..f1ccc44
# --- /dev/null
# +++ b/tests/aiet/test_cli_application.py
# @@ -0,0 +1,1153 @@
# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
# pylint: disable=attribute-defined-outside-init,no-member,line-too-long,too-many-arguments,too-many-locals,redefined-outer-name,too-many-lines
"""Module for testing CLI application subcommand."""
import base64
import json
import re
import time
from contextlib import contextmanager
from contextlib import ExitStack
from pathlib import Path
from typing import Any
from typing import Generator
from typing import IO
from typing import List
from typing import Optional
from typing import TypedDict
from unittest.mock import MagicMock

import click
import pytest
from click.testing import CliRunner
from filelock import FileLock

from aiet.backend.application import Application
from aiet.backend.config import ApplicationConfig
from aiet.backend.config import LocalProtocolConfig
from aiet.backend.config import SSHConfig
from aiet.backend.config import SystemConfig
from aiet.backend.config import UserParamConfig
from aiet.backend.output_parser import Base64OutputParser
from aiet.backend.protocol import SSHProtocol
from aiet.backend.system import load_system
from aiet.cli.application import application_cmd
from aiet.cli.application import details_cmd
from aiet.cli.application import execute_cmd
from aiet.cli.application import install_cmd
from aiet.cli.application import list_cmd
from aiet.cli.application import parse_payload_run_config
from aiet.cli.application import remove_cmd
from aiet.cli.application import run_cmd
from aiet.cli.common import MiddlewareExitCode


def test_application_cmd() -> None:
    """Check that all expected application subcommands are registered."""
    for command in ("list", "details", "install", "remove", "execute", "run"):
        assert command in application_cmd.commands


@pytest.mark.parametrize("format_", ["json", "cli"])
def test_application_cmd_context(cli_runner: CliRunner, format_: str) -> None:
    """Test setting command context parameters."""
    no_subcommand = cli_runner.invoke(application_cmd, ["--format", format_])
    # Invoking the group without a subcommand is a usage error (exit code 2).
    assert no_subcommand.exit_code == 2

    with_subcommand = cli_runner.invoke(application_cmd, ["--format", format_, "list"])
    assert with_subcommand.exit_code == 0


@pytest.mark.parametrize(
    "format_, system_name, expected_output",
    [
        (
            "json",
            None,
            '{"type": "application", "available": ["application_1", "application_2"]}\n',
        ),
        (
            "json",
            "system_1",
            '{"type": "application", "available": ["application_1"]}\n',
        ),
        ("cli", None, "Available applications:\n\napplication_1\napplication_2\n"),
        ("cli", "system_1", "Available applications:\n\napplication_1\n"),
    ],
)
def test_list_cmd(
    cli_runner: CliRunner,
    monkeypatch: Any,
    format_: str,
    system_name: str,
    expected_output: str,
) -> None:
    """Test available applications commands."""
    # Fake two applications, each claiming support for a different system.
    app_one = MagicMock(spec=Application)
    app_one.name = "application_1"
    app_one.can_run_on.return_value = system_name == "system_1"

    app_two = MagicMock(spec=Application)
    app_two.name = "application_2"
    app_two.can_run_on.return_value = system_name == "system_2"

    # Replace the backend lookup so the command only sees the fakes.
    monkeypatch.setattr(
        "aiet.backend.application.get_available_applications",
        MagicMock(return_value=[app_one, app_two]),
    )

    args = []
    if system_name:
        # Restrict the --system choice so click accepts the fake system name.
        list_cmd.params[0].type = click.Choice([system_name])
        args = ["--system", system_name]

    result = cli_runner.invoke(list_cmd, obj={"format": format_}, args=args)
    assert result.output == expected_output
def get_test_application() -> Application:
    """Return an Application built from a minimal test configuration."""
    # Every stage simply echoes its own name as the command string.
    stages = ("clean", "build", "run", "post_run")
    config = ApplicationConfig(
        name="application",
        description="test",
        build_dir="",
        supported_systems=[],
        deploy_data=[],
        user_params={},
        commands={stage: [stage] for stage in stages},
    )
    return Application(config)


def get_details_cmd_json_output() -> str:
    """Get JSON output for details command."""
    details = {
        "type": "application",
        "name": "application",
        "description": "test",
        "supported_systems": [],
        "commands": {
            stage: {"command_strings": [stage], "user_params": []}
            for stage in ("clean", "build", "run", "post_run")
        },
    }
    # The command prints a JSON list with a single entry plus a newline.
    return json.dumps([details]) + "\n"


def get_details_cmd_console_output() -> str:
    """Get console output for details command."""
    lines = [
        'Application "application" details',
        "Description: test",
        "",
        "Supported systems: ",
    ]
    for stage in ("clean", "build", "run", "post_run"):
        lines += ["", f"{stage} commands:", f"Commands: ['{stage}']"]
    return "\n".join(lines) + "\n"


@pytest.mark.parametrize(
    "application_name,format_, expected_output",
    [
        ("application", "json", get_details_cmd_json_output()),
        ("application", "cli", get_details_cmd_console_output()),
    ],
)
def test_details_cmd(
    cli_runner: CliRunner,
    monkeypatch: Any,
    application_name: str,
    format_: str,
    expected_output: str,
) -> None:
    """Test application details command."""
    monkeypatch.setattr(
        "aiet.cli.application.get_application",
        MagicMock(return_value=[get_test_application()]),
    )

    # Restrict the --name choice so click accepts the test application.
    details_cmd.params[0].type = click.Choice(["application"])
    result = cli_runner.invoke(
        details_cmd, obj={"format": format_}, args=["--name", application_name]
    )

    assert result.exception is None
    assert result.output == expected_output
test_details_cmd_wrong_system(cli_runner: CliRunner, monkeypatch: Any) -> None: + """Test details command fails if application is not supported by the system.""" + monkeypatch.setattr( + "aiet.backend.execution.get_application", MagicMock(return_value=[]) + ) + + details_cmd.params[0].type = click.Choice(["application"]) + details_cmd.params[1].type = click.Choice(["system"]) + result = cli_runner.invoke( + details_cmd, args=["--name", "application", "--system", "system"] + ) + assert result.exit_code == 2 + assert ( + "Application 'application' doesn't support the system 'system'" in result.stdout + ) + + +def test_install_cmd(cli_runner: CliRunner, monkeypatch: Any) -> None: + """Test install application command.""" + mock_install_application = MagicMock() + monkeypatch.setattr( + "aiet.cli.application.install_application", mock_install_application + ) + + args = ["--source", "test"] + cli_runner.invoke(install_cmd, args=args) + mock_install_application.assert_called_once_with(Path("test")) + + +def test_remove_cmd(cli_runner: CliRunner, monkeypatch: Any) -> None: + """Test remove application command.""" + mock_remove_application = MagicMock() + monkeypatch.setattr( + "aiet.cli.application.remove_application", mock_remove_application + ) + remove_cmd.params[0].type = click.Choice(["test"]) + + args = ["--directory_name", "test"] + cli_runner.invoke(remove_cmd, args=args) + mock_remove_application.assert_called_once_with("test") + + +class ExecutionCase(TypedDict, total=False): + """Execution case.""" + + args: List[str] + lock_path: str + can_establish_connection: bool + establish_connection_delay: int + app_exit_code: int + exit_code: int + output: str + + +@pytest.mark.parametrize( + "application_config, system_config, executions", + [ + [ + ApplicationConfig( + name="test_application", + description="Test application", + supported_systems=["test_system"], + config_location=Path("wrong_location"), + commands={"build": ["echo build {application.name}"]}, + ), + 
SystemConfig( + name="test_system", + description="Test system", + data_transfer=LocalProtocolConfig(protocol="local"), + config_location=Path("wrong_location"), + commands={"run": ["echo run {application.name} on {system.name}"]}, + ), + [ + ExecutionCase( + args=["-c", "build"], + exit_code=MiddlewareExitCode.CONFIGURATION_ERROR, + output="Error: Application test_application has wrong config location\n", + ) + ], + ], + [ + ApplicationConfig( + name="test_application", + description="Test application", + supported_systems=["test_system"], + build_dir="build", + deploy_data=[("sample_file", "/tmp/sample_file")], + commands={"build": ["echo build {application.name}"]}, + ), + SystemConfig( + name="test_system", + description="Test system", + data_transfer=LocalProtocolConfig(protocol="local"), + commands={"run": ["echo run {application.name} on {system.name}"]}, + ), + [ + ExecutionCase( + args=["-c", "run"], + exit_code=MiddlewareExitCode.CONFIGURATION_ERROR, + output="Error: System test_system does not support data deploy\n", + ) + ], + ], + [ + ApplicationConfig( + name="test_application", + description="Test application", + supported_systems=["test_system"], + commands={"build": ["echo build {application.name}"]}, + ), + SystemConfig( + name="test_system", + description="Test system", + data_transfer=LocalProtocolConfig(protocol="local"), + commands={"run": ["echo run {application.name} on {system.name}"]}, + ), + [ + ExecutionCase( + args=["-c", "build"], + exit_code=MiddlewareExitCode.CONFIGURATION_ERROR, + output="Error: No build directory defined for the app test_application\n", + ) + ], + ], + [ + ApplicationConfig( + name="test_application", + description="Test application", + supported_systems=["new_system"], + build_dir="build", + commands={ + "build": ["echo build {application.name} with {user_params:0}"] + }, + user_params={ + "build": [ + UserParamConfig( + name="param", + description="sample parameter", + default_value="default", + values=["val1", 
"val2", "val3"], + ) + ] + }, + ), + SystemConfig( + name="test_system", + description="Test system", + data_transfer=LocalProtocolConfig(protocol="local"), + commands={"run": ["echo run {application.name} on {system.name}"]}, + ), + [ + ExecutionCase( + args=["-c", "build"], + exit_code=1, + output="Error: Application 'test_application' doesn't support the system 'test_system'\n", + ) + ], + ], + [ + ApplicationConfig( + name="test_application", + description="Test application", + supported_systems=["test_system"], + build_dir="build", + commands={"build": ["false"]}, + ), + SystemConfig( + name="test_system", + description="Test system", + data_transfer=LocalProtocolConfig(protocol="local"), + commands={"run": ["echo run {application.name} on {system.name}"]}, + ), + [ + ExecutionCase( + args=["-c", "build"], + exit_code=MiddlewareExitCode.BACKEND_ERROR, + output="""Running: false +Error: Execution failed. Please check output for the details.\n""", + ) + ], + ], + [ + ApplicationConfig( + name="test_application", + description="Test application", + supported_systems=["test_system"], + lock=True, + build_dir="build", + commands={ + "build": ["echo build {application.name} with {user_params:0}"] + }, + user_params={ + "build": [ + UserParamConfig( + name="param", + description="sample parameter", + default_value="default", + values=["val1", "val2", "val3"], + ) + ] + }, + ), + SystemConfig( + name="test_system", + description="Test system", + lock=True, + data_transfer=LocalProtocolConfig(protocol="local"), + commands={"run": ["echo run {application.name} on {system.name}"]}, + ), + [ + ExecutionCase( + args=["-c", "build"], + exit_code=MiddlewareExitCode.SUCCESS, + output="""Running: echo build test_application with param default +build test_application with param default\n""", + ), + ExecutionCase( + args=["-c", "build"], + lock_path="/tmp/middleware_test_application_test_system.lock", + exit_code=MiddlewareExitCode.CONCURRENT_ERROR, + output="Error: Another 
instance of the system is running\n", + ), + ExecutionCase( + args=["-c", "build", "--param=param=val3"], + exit_code=MiddlewareExitCode.SUCCESS, + output="""Running: echo build test_application with param val3 +build test_application with param val3\n""", + ), + ExecutionCase( + args=["-c", "build", "--param=param=newval"], + exit_code=1, + output="Error: Application parameter 'param=newval' not valid for command 'build'\n", + ), + ExecutionCase( + args=["-c", "some_command"], + exit_code=MiddlewareExitCode.CONFIGURATION_ERROR, + output="Error: Unsupported command some_command\n", + ), + ExecutionCase( + args=["-c", "run"], + exit_code=MiddlewareExitCode.SUCCESS, + output="""Generating commands to execute +Running: echo run test_application on test_system +run test_application on test_system\n""", + ), + ], + ], + [ + ApplicationConfig( + name="test_application", + description="Test application", + supported_systems=["test_system"], + deploy_data=[("sample_file", "/tmp/sample_file")], + commands={ + "run": [ + "echo run {application.name} with {user_params:param} on {system.name}" + ] + }, + user_params={ + "run": [ + UserParamConfig( + name="param=", + description="sample parameter", + default_value="default", + values=["val1", "val2", "val3"], + alias="param", + ) + ] + }, + ), + SystemConfig( + name="test_system", + description="Test system", + lock=True, + data_transfer=SSHConfig( + protocol="ssh", + username="username", + password="password", + hostname="localhost", + port="8022", + ), + commands={"run": ["sleep 100"]}, + ), + [ + ExecutionCase( + args=["-c", "run"], + exit_code=MiddlewareExitCode.SUCCESS, + output="""Generating commands to execute +Trying to establish connection with 'localhost:8022' - 90 retries every 15.0 seconds . +Deploying {application.config_location}/sample_file onto /tmp/sample_file +Running: echo run test_application with param=default on test_system +Shutting down sequence... +Stopping test_system... 
(It could take few seconds) +test_system stopped successfully.\n""", + ), + ExecutionCase( + args=["-c", "run"], + lock_path="/tmp/middleware_test_system.lock", + exit_code=MiddlewareExitCode.CONCURRENT_ERROR, + output="Error: Another instance of the system is running\n", + ), + ExecutionCase( + args=[ + "-c", + "run", + "--deploy={application.config_location}/sample_file:/tmp/sample_file", + ], + exit_code=0, + output="""Generating commands to execute +Trying to establish connection with 'localhost:8022' - 90 retries every 15.0 seconds . +Deploying {application.config_location}/sample_file onto /tmp/sample_file +Deploying {application.config_location}/sample_file onto /tmp/sample_file +Running: echo run test_application with param=default on test_system +Shutting down sequence... +Stopping test_system... (It could take few seconds) +test_system stopped successfully.\n""", + ), + ExecutionCase( + args=["-c", "run"], + app_exit_code=1, + exit_code=0, + output="""Generating commands to execute +Trying to establish connection with 'localhost:8022' - 90 retries every 15.0 seconds . +Deploying {application.config_location}/sample_file onto /tmp/sample_file +Running: echo run test_application with param=default on test_system +Application exited with exit code 1 +Shutting down sequence... +Stopping test_system... (It could take few seconds) +test_system stopped successfully.\n""", + ), + ExecutionCase( + args=["-c", "run"], + exit_code=MiddlewareExitCode.CONNECTION_ERROR, + can_establish_connection=False, + output="""Generating commands to execute +Trying to establish connection with 'localhost:8022' - 90 retries every 15.0 seconds .......................................................................................... +Shutting down sequence... +Stopping test_system... (It could take few seconds) +test_system stopped successfully. 
+Error: Couldn't connect to 'localhost:8022'.\n""", + ), + ExecutionCase( + args=["-c", "run", "--deploy=bad_format"], + exit_code=1, + output="Error: Invalid deploy parameter 'bad_format' for command run\n", + ), + ExecutionCase( + args=["-c", "run", "--deploy=:"], + exit_code=1, + output="Error: Invalid deploy parameter ':' for command run\n", + ), + ExecutionCase( + args=["-c", "run", "--deploy= : "], + exit_code=1, + output="Error: Invalid deploy parameter ' : ' for command run\n", + ), + ExecutionCase( + args=["-c", "run", "--deploy=some_src_file:"], + exit_code=1, + output="Error: Invalid deploy parameter 'some_src_file:' for command run\n", + ), + ExecutionCase( + args=["-c", "run", "--deploy=:some_dst_file"], + exit_code=1, + output="Error: Invalid deploy parameter ':some_dst_file' for command run\n", + ), + ExecutionCase( + args=["-c", "run", "--deploy=unknown_file:/tmp/dest"], + exit_code=1, + output="Error: Path unknown_file does not exist\n", + ), + ], + ], + [ + ApplicationConfig( + name="test_application", + description="Test application", + supported_systems=["test_system"], + commands={ + "run": [ + "echo run {application.name} with {user_params:param} on {system.name}" + ] + }, + user_params={ + "run": [ + UserParamConfig( + name="param=", + description="sample parameter", + default_value="default", + values=["val1", "val2", "val3"], + alias="param", + ) + ] + }, + ), + SystemConfig( + name="test_system", + description="Test system", + data_transfer=SSHConfig( + protocol="ssh", + username="username", + password="password", + hostname="localhost", + port="8022", + ), + commands={"run": ["echo Unable to start system"]}, + ), + [ + ExecutionCase( + args=["-c", "run"], + exit_code=4, + can_establish_connection=False, + establish_connection_delay=1, + output="""Generating commands to execute +Trying to establish connection with 'localhost:8022' - 90 retries every 15.0 seconds . 
+ +---------- test_system execution failed ---------- +Unable to start system + + + +Shutting down sequence... +Stopping test_system... (It could take few seconds) +test_system stopped successfully. +Error: Execution failed. Please check output for the details.\n""", + ) + ], + ], + ], +) +def test_application_command_execution( + application_config: ApplicationConfig, + system_config: SystemConfig, + executions: List[ExecutionCase], + tmpdir: Any, + cli_runner: CliRunner, + monkeypatch: Any, +) -> None: + """Test application command execution.""" + + @contextmanager + def lock_execution(lock_path: str) -> Generator[None, None, None]: + lock = FileLock(lock_path) + lock.acquire(timeout=1) + + try: + yield + finally: + lock.release() + + def replace_vars(str_val: str) -> str: + """Replace variables.""" + application_config_location = str( + application_config["config_location"].absolute() + ) + + return str_val.replace( + "{application.config_location}", application_config_location + ) + + for execution in executions: + init_execution_test( + monkeypatch, + tmpdir, + application_config, + system_config, + can_establish_connection=execution.get("can_establish_connection", True), + establish_conection_delay=execution.get("establish_connection_delay", 0), + remote_app_exit_code=execution.get("app_exit_code", 0), + ) + + lock_path = execution.get("lock_path") + + with ExitStack() as stack: + if lock_path: + stack.enter_context(lock_execution(lock_path)) + + args = [replace_vars(arg) for arg in execution["args"]] + + result = cli_runner.invoke( + execute_cmd, + args=["-n", application_config["name"], "-s", system_config["name"]] + + args, + ) + output = replace_vars(execution["output"]) + assert result.exit_code == execution["exit_code"] + assert result.stdout == output + + +@pytest.fixture(params=[False, True], ids=["run-cli", "run-json"]) +def payload_path_or_none(request: Any, tmp_path_factory: Any) -> Optional[Path]: + """Drives tests for run command so that it 
executes them both to use a json file, and to use CLI.""" + if request.param: + ret: Path = tmp_path_factory.getbasetemp() / "system_config_payload_file.json" + return ret + return None + + +def write_system_payload_config( + payload_file: IO[str], + application_config: ApplicationConfig, + system_config: SystemConfig, +) -> None: + """Write a json payload file for the given test configuration.""" + payload_dict = { + "id": system_config["name"], + "arguments": { + "application": application_config["name"], + }, + } + json.dump(payload_dict, payload_file) + + +@pytest.mark.parametrize( + "application_config, system_config, executions", + [ + [ + ApplicationConfig( + name="test_application", + description="Test application", + supported_systems=["test_system"], + build_dir="build", + commands={ + "build": ["echo build {application.name} with {user_params:0}"] + }, + user_params={ + "build": [ + UserParamConfig( + name="param", + description="sample parameter", + default_value="default", + values=["val1", "val2", "val3"], + ) + ] + }, + ), + SystemConfig( + name="test_system", + description="Test system", + data_transfer=LocalProtocolConfig(protocol="local"), + commands={"run": ["echo run {application.name} on {system.name}"]}, + ), + [ + ExecutionCase( + args=[], + exit_code=MiddlewareExitCode.SUCCESS, + output="""Running: echo build test_application with param default +build test_application with param default +Generating commands to execute +Running: echo run test_application on test_system +run test_application on test_system\n""", + ) + ], + ], + [ + ApplicationConfig( + name="test_application", + description="Test application", + supported_systems=["test_system"], + commands={ + "run": [ + "echo run {application.name} with {user_params:param} on {system.name}" + ] + }, + user_params={ + "run": [ + UserParamConfig( + name="param=", + description="sample parameter", + default_value="default", + values=["val1", "val2", "val3"], + alias="param", + ) + ] + }, + ), + 
SystemConfig( + name="test_system", + description="Test system", + data_transfer=SSHConfig( + protocol="ssh", + username="username", + password="password", + hostname="localhost", + port="8022", + ), + commands={"run": ["sleep 100"]}, + ), + [ + ExecutionCase( + args=[], + exit_code=MiddlewareExitCode.SUCCESS, + output="""Generating commands to execute +Trying to establish connection with 'localhost:8022' - 90 retries every 15.0 seconds . +Running: echo run test_application with param=default on test_system +Shutting down sequence... +Stopping test_system... (It could take few seconds) +test_system stopped successfully.\n""", + ) + ], + ], + ], +) +def test_application_run( + application_config: ApplicationConfig, + system_config: SystemConfig, + executions: List[ExecutionCase], + tmpdir: Any, + cli_runner: CliRunner, + monkeypatch: Any, + payload_path_or_none: Path, +) -> None: + """Test application command execution.""" + for execution in executions: + init_execution_test(monkeypatch, tmpdir, application_config, system_config) + + if payload_path_or_none: + with open(payload_path_or_none, "w", encoding="utf-8") as payload_file: + write_system_payload_config( + payload_file, application_config, system_config + ) + + result = cli_runner.invoke( + run_cmd, + args=["--config", str(payload_path_or_none)], + ) + else: + result = cli_runner.invoke( + run_cmd, + args=["-n", application_config["name"], "-s", system_config["name"]] + + execution["args"], + ) + + assert result.stdout == execution["output"] + assert result.exit_code == execution["exit_code"] + + +@pytest.mark.parametrize( + "cmdline,error_pattern", + [ + [ + "--config {payload} -s test_system", + "when --config is set, the following parameters should not be provided", + ], + [ + "--config {payload} -n test_application", + "when --config is set, the following parameters should not be provided", + ], + [ + "--config {payload} -p mypar:3", + "when --config is set, the following parameters should not be 
provided", + ], + [ + "-p mypar:3", + "when --config is not set, the following parameters are required", + ], + ["-s test_system", "when --config is not set, --name is required"], + ["-n test_application", "when --config is not set, --system is required"], + ], +) +def test_application_run_invalid_param_combinations( + cmdline: str, + error_pattern: str, + cli_runner: CliRunner, + monkeypatch: Any, + tmp_path: Any, + tmpdir: Any, +) -> None: + """Test that invalid combinations arguments result in error as expected.""" + application_config = ApplicationConfig( + name="test_application", + description="Test application", + supported_systems=["test_system"], + build_dir="build", + commands={"build": ["echo build {application.name} with {user_params:0}"]}, + user_params={ + "build": [ + UserParamConfig( + name="param", + description="sample parameter", + default_value="default", + values=["val1", "val2", "val3"], + ) + ] + }, + ) + system_config = SystemConfig( + name="test_system", + description="Test system", + data_transfer=LocalProtocolConfig(protocol="local"), + commands={"run": ["echo run {application.name} on {system.name}"]}, + ) + + init_execution_test(monkeypatch, tmpdir, application_config, system_config) + + payload_file = tmp_path / "payload.json" + payload_file.write_text("dummy") + result = cli_runner.invoke( + run_cmd, + args=cmdline.format(payload=payload_file).split(), + ) + found = re.search(error_pattern, result.stdout) + assert found, f"Cannot find pattern: [{error_pattern}] in \n[\n{result.stdout}\n]" + + +@pytest.mark.parametrize( + "payload,expected", + [ + pytest.param( + {"arguments": {}}, + None, + marks=pytest.mark.xfail(reason="no system 'id''", strict=True), + ), + pytest.param( + {"id": "testsystem"}, + None, + marks=pytest.mark.xfail(reason="no arguments object", strict=True), + ), + ( + {"id": "testsystem", "arguments": {"application": "testapp"}}, + ("testsystem", "testapp", [], [], [], None), + ), + ( + { + "id": "testsystem", + 
"arguments": {"application": "testapp", "par1": "val1"}, + }, + ("testsystem", "testapp", ["par1=val1"], [], [], None), + ), + ( + { + "id": "testsystem", + "arguments": {"application": "testapp", "application/par1": "val1"}, + }, + ("testsystem", "testapp", ["par1=val1"], [], [], None), + ), + ( + { + "id": "testsystem", + "arguments": {"application": "testapp", "system/par1": "val1"}, + }, + ("testsystem", "testapp", [], ["par1=val1"], [], None), + ), + ( + { + "id": "testsystem", + "arguments": {"application": "testapp", "deploy/par1": "val1"}, + }, + ("testsystem", "testapp", [], [], ["par1"], None), + ), + ( + { + "id": "testsystem", + "arguments": { + "application": "testapp", + "appar1": "val1", + "application/appar2": "val2", + "system/syspar1": "val3", + "deploy/depploypar1": "val4", + "application/appar3": "val5", + "system/syspar2": "val6", + "deploy/depploypar2": "val7", + }, + }, + ( + "testsystem", + "testapp", + ["appar1=val1", "appar2=val2", "appar3=val5"], + ["syspar1=val3", "syspar2=val6"], + ["depploypar1", "depploypar2"], + None, + ), + ), + ], +) +def test_parse_payload_run_config(payload: dict, expected: tuple) -> None: + """Test parsing of the JSON payload for the run_config command.""" + assert parse_payload_run_config(payload) == expected + + +def test_application_run_report( + tmpdir: Any, + cli_runner: CliRunner, + monkeypatch: Any, +) -> None: + """Test flag '--report' of command 'application run'.""" + app_metrics = {"app_metric": 3.14} + app_metrics_b64 = base64.b64encode(json.dumps(app_metrics).encode("utf-8")) + application_config = ApplicationConfig( + name="test_application", + description="Test application", + supported_systems=["test_system"], + build_dir="build", + commands={"build": ["echo build {application.name} with {user_params:0}"]}, + user_params={ + "build": [ + UserParamConfig( + name="param", + description="sample parameter", + default_value="default", + values=["val1", "val2", "val3"], + ), + UserParamConfig( + 
name="p2", + description="another parameter, not overridden", + default_value="the-right-choice", + values=["the-right-choice", "the-bad-choice"], + ), + ] + }, + ) + system_config = SystemConfig( + name="test_system", + description="Test system", + data_transfer=LocalProtocolConfig(protocol="local"), + commands={ + "run": [ + "echo run {application.name} on {system.name}", + f"echo build <{Base64OutputParser.TAG_NAME}>{app_metrics_b64.decode('utf-8')}</{Base64OutputParser.TAG_NAME}>", + ] + }, + reporting={ + "regex": { + "app_name": { + "pattern": r"run (.\S*) ", + "type": "str", + }, + "sys_name": { + "pattern": r"on (.\S*)", + "type": "str", + }, + } + }, + ) + report_file = Path(tmpdir) / "test_report.json" + param_val = "param=val1" + exit_code = MiddlewareExitCode.SUCCESS + + init_execution_test(monkeypatch, tmpdir, application_config, system_config) + + result = cli_runner.invoke( + run_cmd, + args=[ + "-n", + application_config["name"], + "-s", + system_config["name"], + "--report", + str(report_file), + "--param", + param_val, + ], + ) + assert result.exit_code == exit_code + assert report_file.is_file() + with open(report_file, "r", encoding="utf-8") as file: + report = json.load(file) + + assert report == { + "application": { + "metrics": {"0": {"app_metric": 3.14}}, + "name": "test_application", + "params": {"param": "val1", "p2": "the-right-choice"}, + }, + "system": { + "metrics": {"app_name": "test_application", "sys_name": "test_system"}, + "name": "test_system", + "params": {}, + }, + } + + +def init_execution_test( + monkeypatch: Any, + tmpdir: Any, + application_config: ApplicationConfig, + system_config: SystemConfig, + can_establish_connection: bool = True, + establish_conection_delay: float = 0, + remote_app_exit_code: int = 0, +) -> None: + """Init execution test.""" + application_name = application_config["name"] + system_name = system_config["name"] + + execute_cmd.params[0].type = click.Choice([application_name]) + 
execute_cmd.params[1].type = click.Choice([system_name]) + execute_cmd.params[2].type = click.Choice(["build", "run", "some_command"]) + + run_cmd.params[0].type = click.Choice([application_name]) + run_cmd.params[1].type = click.Choice([system_name]) + + if "config_location" not in application_config: + application_path = Path(tmpdir) / "application" + application_path.mkdir() + application_config["config_location"] = application_path + + # this file could be used as deploy parameter value or + # as deploy parameter in application configuration + sample_file = application_path / "sample_file" + sample_file.touch() + monkeypatch.setattr( + "aiet.backend.application.get_available_applications", + MagicMock(return_value=[Application(application_config)]), + ) + + ssh_protocol_mock = MagicMock(spec=SSHProtocol) + + def mock_establish_connection() -> bool: + """Mock establish connection function.""" + # give some time for the system to start + time.sleep(establish_conection_delay) + return can_establish_connection + + ssh_protocol_mock.establish_connection.side_effect = mock_establish_connection + ssh_protocol_mock.connection_details.return_value = ("localhost", 8022) + ssh_protocol_mock.run.return_value = ( + remote_app_exit_code, + bytearray(), + bytearray(), + ) + monkeypatch.setattr( + "aiet.backend.protocol.SSHProtocol", MagicMock(return_value=ssh_protocol_mock) + ) + + if "config_location" not in system_config: + system_path = Path(tmpdir) / "system" + system_path.mkdir() + system_config["config_location"] = system_path + monkeypatch.setattr( + "aiet.backend.system.get_available_systems", + MagicMock(return_value=[load_system(system_config)]), + ) + + monkeypatch.setattr("aiet.backend.execution.wait", MagicMock()) diff --git a/tests/aiet/test_cli_common.py b/tests/aiet/test_cli_common.py new file mode 100644 index 0000000..d018e44 --- /dev/null +++ b/tests/aiet/test_cli_common.py @@ -0,0 +1,37 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its 
# SPDX-License-Identifier: Apache-2.0
"""Test for cli common module."""
from typing import Any

import pytest

from aiet.cli.common import print_command_details
from aiet.cli.common import raise_exception_at_signal


def test_print_command_details(capsys: Any) -> None:
    """Verify command strings, parameter names and aliases are printed."""
    command = {
        "command_strings": ["echo test"],
        "user_params": [
            {"name": "param_name", "description": "param_description"},
            {
                "name": "param_name2",
                "description": "param_description2",
                "alias": "alias2",
            },
        ],
    }

    print_command_details(command)

    printed = capsys.readouterr().out
    # The command string, parameter names and aliases must all be reported.
    for expected in ("echo test", "param_name", "alias2"):
        assert expected in printed


def test_raise_exception_at_signal() -> None:
    """Test raise_exception_at_signal graceful shutdown."""
    with pytest.raises(Exception) as err:
        raise_exception_at_signal(1, "")

    assert str(err.value) == "Middleware shutdown requested"
# SPDX-License-Identifier: Apache-2.0
"""Module for testing CLI system subcommand."""
import json
from pathlib import Path
from typing import Any
from typing import Dict
from typing import List
from typing import Optional
from typing import Union
from unittest.mock import MagicMock

import click
import pytest
from click.testing import CliRunner

from aiet.backend.config import SystemConfig
from aiet.backend.system import load_system
from aiet.backend.system import System
from aiet.cli.system import details_cmd
from aiet.cli.system import install_cmd
from aiet.cli.system import list_cmd
from aiet.cli.system import remove_cmd
from aiet.cli.system import system_cmd


def test_system_cmd() -> None:
    """Check that all expected system subcommands are registered."""
    for command in ("list", "details", "install", "remove"):
        assert command in system_cmd.commands


@pytest.mark.parametrize("format_", ["json", "cli"])
def test_system_cmd_context(cli_runner: CliRunner, format_: str) -> None:
    """Test setting command context parameters."""
    no_subcommand = cli_runner.invoke(system_cmd, ["--format", format_])
    # Invoking the group without a subcommand is a usage error (exit code 2).
    assert no_subcommand.exit_code == 2

    with_subcommand = cli_runner.invoke(system_cmd, ["--format", format_, "list"])
    assert with_subcommand.exit_code == 0


@pytest.mark.parametrize(
    "format_,expected_output",
    [
        ("json", '{"type": "system", "available": ["system1", "system2"]}\n'),
        ("cli", "Available systems:\n\nsystem1\nsystem2\n"),
    ],
)
def test_list_cmd_with_format(
    cli_runner: CliRunner, monkeypatch: Any, format_: str, expected_output: str
) -> None:
    """Test available systems command with different formats output."""
    # Fake two systems for the listing.
    first = MagicMock()
    first.name = "system1"
    second = MagicMock()
    second.name = "system2"

    # Replace the backend lookup so the command only sees the fakes.
    monkeypatch.setattr(
        "aiet.cli.system.get_available_systems",
        MagicMock(return_value=[first, second]),
    )

    result = cli_runner.invoke(list_cmd, obj={"format": format_})
    assert result.output == expected_output


def get_test_system(
    annotations: Optional[Dict[str, Union[str, List[str]]]] = None
) -> System:
    """Return test system details."""
    config = SystemConfig(
        name="system",
        description="test",
        data_transfer={
            "protocol": "ssh",
            "username": "root",
            "password": "root",
            "hostname": "localhost",
            "port": "8022",
        },
        # Every stage simply echoes its own name as the command string.
        commands={stage: [stage] for stage in ("clean", "build", "run", "post_run")},
        annotations=annotations or {},
    )
    return load_system(config)


def get_details_cmd_json_output(
    annotations: Optional[Dict[str, Union[str, List[str]]]] = None
) -> str:
    """Test JSON output for details command."""
    details: Dict[str, Any] = {
        "type": "system",
        "name": "system",
        "description": "test",
        "data_transfer_protocol": "ssh",
        "commands": {
            stage: {"command_strings": [stage], "user_params": []}
            for stage in ("clean", "build", "run", "post_run")
        },
    }
    # The annotations key only appears when annotations were supplied
    # (an empty dict still counts); it precedes available_application.
    if annotations is not None:
        details["annotations"] = annotations
    details["available_application"] = []
    return json.dumps(details) + "\n"


def get_details_cmd_console_output(
    annotations: Optional[Dict[str, Union[str, List[str]]]] = None
) -> str:
    """Test console output for details command."""
    output = (
        'System "system" details'
        "\nDescription: test"
        "\nData Transfer Protocol: ssh"
        "\nAvailable Applications: "
    )
    # A non-empty annotations dict is rendered as a tab-indented list.
    if annotations:
        output += "\nAnnotations:" + "".join(
            f"\n\t{ann_name}: {ann_value}"
            for ann_name, ann_value in annotations.items()
        )
    for stage in ("clean", "build", "run", "post_run"):
        output += f"\n\n{stage} commands:" + f"\nCommands: ['{stage}']"
    return output + "\n"


@pytest.mark.parametrize(
    "format_,system,expected_output",
    [
        (
            "json",
            get_test_system(annotations={"ann1": "annotation1", "ann2": ["a1", "a2"]}),
            get_details_cmd_json_output(
                annotations={"ann1": "annotation1", "ann2": ["a1", "a2"]}
            ),
        ),
        (
            "cli",
            get_test_system(annotations={"ann1": "annotation1", "ann2": ["a1", "a2"]}),
            get_details_cmd_console_output(
                annotations={"ann1": "annotation1", "ann2": ["a1", "a2"]}
            ),
        ),
        (
            "json",
            get_test_system(annotations={}),
            get_details_cmd_json_output(annotations={}),
        ),
        (
            "cli",
            get_test_system(annotations={}),
            get_details_cmd_console_output(annotations={}),
        ),
    ],
)
def test_details_cmd(
    cli_runner: CliRunner,
    monkeypatch: Any,
    format_: str,
    system: System,
    expected_output: str,
) -> None:
    """Test details command with different formats output."""
    monkeypatch.setattr("aiet.cli.system.get_system", MagicMock(return_value=system))

    # Restrict the --name choice so click accepts the test system.
    details_cmd.params[0].type = click.Choice(["system"])
    result = cli_runner.invoke(
        details_cmd, args=["--name", "system"], obj={"format": format_}
    )

    assert result.output == expected_output


def test_install_cmd(cli_runner: CliRunner, monkeypatch: Any) -> None:
    """Test install system command."""
    install_mock = MagicMock()
    monkeypatch.setattr("aiet.cli.system.install_system", install_mock)

    cli_runner.invoke(install_cmd, args=["--source", "test"])
    # The CLI string argument must be converted to a Path before install.
    install_mock.assert_called_once_with(Path("test"))
system command.""" + mock_remove_system = MagicMock() + monkeypatch.setattr("aiet.cli.system.remove_system", mock_remove_system) + remove_cmd.params[0].type = click.Choice(["test"]) + + args = ["--directory_name", "test"] + cli_runner.invoke(remove_cmd, args=args) + mock_remove_system.assert_called_once_with("test") diff --git a/tests/aiet/test_cli_tool.py b/tests/aiet/test_cli_tool.py new file mode 100644 index 0000000..45d45c8 --- /dev/null +++ b/tests/aiet/test_cli_tool.py @@ -0,0 +1,333 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +# pylint: disable=attribute-defined-outside-init,no-member,line-too-long,too-many-arguments,too-many-locals +"""Module for testing CLI tool subcommand.""" +import json +from pathlib import Path +from typing import Any +from typing import List +from typing import Optional +from typing import Sequence +from unittest.mock import MagicMock + +import click +import pytest +from click.testing import CliRunner +from click.testing import Result + +from aiet.backend.tool import get_unique_tool_names +from aiet.backend.tool import Tool +from aiet.cli.tool import details_cmd +from aiet.cli.tool import execute_cmd +from aiet.cli.tool import list_cmd +from aiet.cli.tool import tool_cmd + + +def test_tool_cmd() -> None: + """Test tool commands.""" + commands = ["list", "details", "execute"] + assert all(command in tool_cmd.commands for command in commands) + + +@pytest.mark.parametrize("format_", ["json", "cli"]) +def test_tool_cmd_context(cli_runner: CliRunner, format_: str) -> None: + """Test setting command context parameters.""" + result = cli_runner.invoke(tool_cmd, ["--format", format_]) + # command should fail if no subcommand provided + assert result.exit_code == 2 + + result = cli_runner.invoke(tool_cmd, ["--format", format_, "list"]) + assert result.exit_code == 0 + + +@pytest.mark.parametrize( + "format_, expected_output", + [ + ( + "json", + '{"type": "tool", 
"available": ["tool_1", "tool_2"]}\n', + ), + ("cli", "Available tools:\n\ntool_1\ntool_2\n"), + ], +) +def test_list_cmd( + cli_runner: CliRunner, + monkeypatch: Any, + format_: str, + expected_output: str, +) -> None: + """Test available tool commands.""" + # Mock some tools + mock_tool_1 = MagicMock(spec=Tool) + mock_tool_1.name = "tool_1" + mock_tool_2 = MagicMock(spec=Tool) + mock_tool_2.name = "tool_2" + + # Monkey patch the call get_available_tools + mock_available_tools = MagicMock() + mock_available_tools.return_value = [mock_tool_1, mock_tool_2] + + monkeypatch.setattr("aiet.backend.tool.get_available_tools", mock_available_tools) + + obj = {"format": format_} + args: Sequence[str] = [] + result = cli_runner.invoke(list_cmd, obj=obj, args=args) + assert result.output == expected_output + + +def get_details_cmd_json_output() -> List[dict]: + """Get JSON output for details command.""" + json_output = [ + { + "type": "tool", + "name": "tool_1", + "description": "This is tool 1", + "supported_systems": ["System 1"], + "commands": { + "clean": {"command_strings": ["echo 'clean'"], "user_params": []}, + "build": {"command_strings": ["echo 'build'"], "user_params": []}, + "run": {"command_strings": ["echo 'run'"], "user_params": []}, + "post_run": {"command_strings": ["echo 'post_run'"], "user_params": []}, + }, + } + ] + + return json_output + + +def get_details_cmd_console_output() -> str: + """Get console output for details command.""" + return ( + 'Tool "tool_1" details' + "\nDescription: This is tool 1" + "\n\nSupported systems: System 1" + "\n\nclean commands:" + "\nCommands: [\"echo 'clean'\"]" + "\n\nbuild commands:" + "\nCommands: [\"echo 'build'\"]" + "\n\nrun commands:\nCommands: [\"echo 'run'\"]" + "\n\npost_run commands:" + "\nCommands: [\"echo 'post_run'\"]" + "\n" + ) + + +@pytest.mark.parametrize( + [ + "tool_name", + "format_", + "expected_success", + "expected_output", + ], + [ + ("tool_1", "json", True, get_details_cmd_json_output()), + 
("tool_1", "cli", True, get_details_cmd_console_output()), + ("non-existent tool", "json", False, None), + ("non-existent tool", "cli", False, None), + ], +) +def test_details_cmd( + cli_runner: CliRunner, + tool_name: str, + format_: str, + expected_success: bool, + expected_output: str, +) -> None: + """Test tool details command.""" + details_cmd.params[0].type = click.Choice(["tool_1", "tool_2", "vela"]) + result = cli_runner.invoke( + details_cmd, obj={"format": format_}, args=["--name", tool_name] + ) + success = result.exit_code == 0 + assert success == expected_success, result.output + if expected_success: + assert result.exception is None + output = json.loads(result.output) if format_ == "json" else result.output + assert output == expected_output + + +@pytest.mark.parametrize( + "system_name", + [ + "", + "Corstone-300: Cortex-M55+Ethos-U55", + "Corstone-300: Cortex-M55+Ethos-U65", + "Corstone-310: Cortex-M85+Ethos-U55", + ], +) +def test_details_cmd_vela(cli_runner: CliRunner, system_name: str) -> None: + """Test tool details command for Vela.""" + details_cmd.params[0].type = click.Choice(get_unique_tool_names()) + details_cmd.params[1].type = click.Choice([system_name]) + args = ["--name", "vela"] + if system_name: + args += ["--system", system_name] + result = cli_runner.invoke(details_cmd, obj={"format": "json"}, args=args) + success = result.exit_code == 0 + assert success, result.output + result_json = json.loads(result.output) + assert result_json + if system_name: + assert len(result_json) == 1 + tool = result_json[0] + assert len(tool["supported_systems"]) == 1 + assert system_name == tool["supported_systems"][0] + else: # no system specified => list details for all systems + assert len(result_json) == 3 + assert all(len(tool["supported_systems"]) == 1 for tool in result_json) + + +@pytest.fixture(scope="session") +def input_model_file(non_optimised_input_model_file: Path) -> Path: + """Provide the path to a quantized dummy model file in the 
test_resources_path.""" + return non_optimised_input_model_file + + +def execute_vela( + cli_runner: CliRunner, + tool_name: str = "vela", + system_name: Optional[str] = None, + input_model: Optional[Path] = None, + output_model: Optional[Path] = None, + mac: Optional[int] = None, + format_: str = "cli", +) -> Result: + """Run Vela with different parameters.""" + execute_cmd.params[0].type = click.Choice(get_unique_tool_names()) + execute_cmd.params[2].type = click.Choice([system_name or "dummy_system"]) + args = ["--name", tool_name] + if system_name is not None: + args += ["--system", system_name] + if input_model is not None: + args += ["--param", "input={}".format(input_model)] + if output_model is not None: + args += ["--param", "output={}".format(output_model)] + if mac is not None: + args += ["--param", "mac={}".format(mac)] + result = cli_runner.invoke( + execute_cmd, + args=args, + obj={"format": format_}, + ) + return result + + +@pytest.mark.parametrize("format_", ["cli", "json"]) +@pytest.mark.parametrize( + ["tool_name", "system_name", "mac", "expected_success", "expected_output"], + [ + ("vela", "System 1", 32, False, None), # system not supported + ("vela", "NON-EXISTENT SYSTEM", 128, False, None), # system does not exist + ("vela", "Corstone-300: Cortex-M55+Ethos-U55", 32, True, None), + ("NON-EXISTENT TOOL", "Corstone-300: Cortex-M55+Ethos-U55", 32, False, None), + ("vela", "Corstone-300: Cortex-M55+Ethos-U55", 64, True, None), + ("vela", "Corstone-300: Cortex-M55+Ethos-U55", 128, True, None), + ("vela", "Corstone-300: Cortex-M55+Ethos-U55", 256, True, None), + ( + "vela", + "Corstone-300: Cortex-M55+Ethos-U55", + 512, + False, + None, + ), # mac not supported + ( + "vela", + "Corstone-300: Cortex-M55+Ethos-U65", + 32, + False, + None, + ), # mac not supported + ("vela", "Corstone-300: Cortex-M55+Ethos-U65", 256, True, None), + ("vela", "Corstone-300: Cortex-M55+Ethos-U65", 512, True, None), + ( + "vela", + None, + 512, + False, + "Error: Please 
specify the system for tool vela.", + ), # no system specified + ( + "NON-EXISTENT TOOL", + "Corstone-300: Cortex-M55+Ethos-U65", + 512, + False, + None, + ), # tool does not exist + ("vela", "Corstone-310: Cortex-M85+Ethos-U55", 128, True, None), + ], +) +def test_vela_run( + cli_runner: CliRunner, + format_: str, + input_model_file: Path, # pylint: disable=redefined-outer-name + tool_name: str, + system_name: Optional[str], + mac: int, + expected_success: bool, + expected_output: Optional[str], + monkeypatch: pytest.MonkeyPatch, + tmp_path: Path, +) -> None: + """Test the execution of the Vela command.""" + monkeypatch.chdir(tmp_path) + + output_file = Path("vela_output.tflite") + + result = execute_vela( + cli_runner, + tool_name=tool_name, + system_name=system_name, + input_model=input_model_file, + output_model=output_file, + mac=mac, + format_=format_, + ) + + success = result.exit_code == 0 + assert success == expected_success + if success: + # Check output file + output_file = output_file.resolve() + assert output_file.is_file() + if expected_output: + assert result.output.strip() == expected_output + + +@pytest.mark.parametrize("include_input_model", [True, False]) +@pytest.mark.parametrize("include_output_model", [True, False]) +@pytest.mark.parametrize("include_mac", [True, False]) +def test_vela_run_missing_params( + cli_runner: CliRunner, + input_model_file: Path, # pylint: disable=redefined-outer-name + include_input_model: bool, + include_output_model: bool, + include_mac: bool, + monkeypatch: pytest.MonkeyPatch, + tmp_path: Path, +) -> None: + """Test the execution of the Vela command with missing user parameters.""" + monkeypatch.chdir(tmp_path) + + output_model_file = Path("output_model.tflite") + system_name = "Corstone-300: Cortex-M55+Ethos-U65" + mac = 256 + # input_model is a required parameters, but mac and output_model have default values. 
+ expected_success = include_input_model + + result = execute_vela( + cli_runner, + tool_name="vela", + system_name=system_name, + input_model=input_model_file if include_input_model else None, + output_model=output_model_file if include_output_model else None, + mac=mac if include_mac else None, + ) + + success = result.exit_code == 0 + assert success == expected_success, ( + f"Success is {success}, but expected {expected_success}. " + f"Included params: [" + f"input_model={include_input_model}, " + f"output_model={include_output_model}, " + f"mac={include_mac}]" + ) diff --git a/tests/aiet/test_main.py b/tests/aiet/test_main.py new file mode 100644 index 0000000..f2ebae2 --- /dev/null +++ b/tests/aiet/test_main.py @@ -0,0 +1,16 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Module for testing AIET main.py.""" +from typing import Any +from unittest.mock import MagicMock + +from aiet import main + + +def test_main(monkeypatch: Any) -> None: + """Test main entry point function.""" + with monkeypatch.context() as mock_context: + mock = MagicMock() + mock_context.setattr(main, "cli", mock) + main.main() + mock.assert_called_once() diff --git a/tests/aiet/test_resources/application_config.json b/tests/aiet/test_resources/application_config.json new file mode 100644 index 0000000..2dfcfec --- /dev/null +++ b/tests/aiet/test_resources/application_config.json @@ -0,0 +1,96 @@ +[ + { + "name": "application_1", + "description": "application number one", + "supported_systems": [ + "system_1", + "system_2" + ], + "build_dir": "build_dir_11", + "commands": { + "clean": [ + "clean_cmd_11" + ], + "build": [ + "build_cmd_11" + ], + "run": [ + "run_cmd_11" + ], + "post_run": [ + "post_run_cmd_11" + ] + }, + "user_params": { + "run": [ + { + "name": "run_param_11", + "values": [], + "description": "run param number one" + } + ], + "build": [ + { + "name": "build_param_11", + "values": [], + "description": 
"build param number one" + }, + { + "name": "build_param_12", + "values": [], + "description": "build param number two" + }, + { + "name": "build_param_13", + "values": [ + "value_1" + ], + "description": "build param number three with some value" + } + ] + } + }, + { + "name": "application_2", + "description": "application number two", + "supported_systems": [ + "system_2" + ], + "build_dir": "build_dir_21", + "commands": { + "clean": [ + "clean_cmd_21" + ], + "build": [ + "build_cmd_21", + "build_cmd_22" + ], + "run": [ + "run_cmd_21" + ], + "post_run": [ + "post_run_cmd_21" + ] + }, + "user_params": { + "build": [ + { + "name": "build_param_21", + "values": [], + "description": "build param number one" + }, + { + "name": "build_param_22", + "values": [], + "description": "build param number two" + }, + { + "name": "build_param_23", + "values": [], + "description": "build param number three" + } + ], + "run": [] + } + } +] diff --git a/tests/aiet/test_resources/application_config.json.license b/tests/aiet/test_resources/application_config.json.license new file mode 100644 index 0000000..9b83bfc --- /dev/null +++ b/tests/aiet/test_resources/application_config.json.license @@ -0,0 +1,3 @@ +SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. 
+ +SPDX-License-Identifier: Apache-2.0 diff --git a/tests/aiet/test_resources/applications/application1/aiet-config.json b/tests/aiet/test_resources/applications/application1/aiet-config.json new file mode 100644 index 0000000..97f0401 --- /dev/null +++ b/tests/aiet/test_resources/applications/application1/aiet-config.json @@ -0,0 +1,30 @@ +[ + { + "name": "application_1", + "description": "This is application 1", + "supported_systems": [ + { + "name": "System 1" + } + ], + "build_dir": "build", + "commands": { + "clean": [ + "echo 'clean'" + ], + "build": [ + "echo 'build'" + ], + "run": [ + "echo 'run'" + ], + "post_run": [ + "echo 'post_run'" + ] + }, + "user_params": { + "build": [], + "run": [] + } + } +] diff --git a/tests/aiet/test_resources/applications/application1/aiet-config.json.license b/tests/aiet/test_resources/applications/application1/aiet-config.json.license new file mode 100644 index 0000000..9b83bfc --- /dev/null +++ b/tests/aiet/test_resources/applications/application1/aiet-config.json.license @@ -0,0 +1,3 @@ +SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. 
+ +SPDX-License-Identifier: Apache-2.0 diff --git a/tests/aiet/test_resources/applications/application2/aiet-config.json b/tests/aiet/test_resources/applications/application2/aiet-config.json new file mode 100644 index 0000000..e9122d3 --- /dev/null +++ b/tests/aiet/test_resources/applications/application2/aiet-config.json @@ -0,0 +1,30 @@ +[ + { + "name": "application_2", + "description": "This is application 2", + "supported_systems": [ + { + "name": "System 2" + } + ], + "build_dir": "build", + "commands": { + "clean": [ + "echo 'clean'" + ], + "build": [ + "echo 'build'" + ], + "run": [ + "echo 'run'" + ], + "post_run": [ + "echo 'post_run'" + ] + }, + "user_params": { + "build": [], + "run": [] + } + } +] diff --git a/tests/aiet/test_resources/applications/application2/aiet-config.json.license b/tests/aiet/test_resources/applications/application2/aiet-config.json.license new file mode 100644 index 0000000..9b83bfc --- /dev/null +++ b/tests/aiet/test_resources/applications/application2/aiet-config.json.license @@ -0,0 +1,3 @@ +SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. + +SPDX-License-Identifier: Apache-2.0 diff --git a/tests/aiet/test_resources/applications/application3/readme.txt b/tests/aiet/test_resources/applications/application3/readme.txt new file mode 100644 index 0000000..8c72c05 --- /dev/null +++ b/tests/aiet/test_resources/applications/application3/readme.txt @@ -0,0 +1,4 @@ +SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. 
+SPDX-License-Identifier: Apache-2.0 + +This application does not have json configuration file diff --git a/tests/aiet/test_resources/applications/application4/aiet-config.json b/tests/aiet/test_resources/applications/application4/aiet-config.json new file mode 100644 index 0000000..34dc780 --- /dev/null +++ b/tests/aiet/test_resources/applications/application4/aiet-config.json @@ -0,0 +1,35 @@ +[ + { + "name": "application_4", + "description": "This is application 4", + "build_dir": "build", + "supported_systems": [ + { + "name": "System 4" + } + ], + "commands": { + "build": [ + "cp ../hello_app.txt . # {user_params:0}" + ], + "run": [ + "{application.build_dir}/hello_app.txt" + ] + }, + "user_params": { + "build": [ + { + "name": "--app", + "description": "Sample command param", + "values": [ + "application1", + "application2", + "application3" + ], + "default_value": "application1" + } + ], + "run": [] + } + } +] diff --git a/tests/aiet/test_resources/applications/application4/aiet-config.json.license b/tests/aiet/test_resources/applications/application4/aiet-config.json.license new file mode 100644 index 0000000..9b83bfc --- /dev/null +++ b/tests/aiet/test_resources/applications/application4/aiet-config.json.license @@ -0,0 +1,3 @@ +SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. + +SPDX-License-Identifier: Apache-2.0 diff --git a/tests/aiet/test_resources/applications/application4/hello_app.txt b/tests/aiet/test_resources/applications/application4/hello_app.txt new file mode 100644 index 0000000..2ec0d1d --- /dev/null +++ b/tests/aiet/test_resources/applications/application4/hello_app.txt @@ -0,0 +1,4 @@ +SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +SPDX-License-Identifier: Apache-2.0 + +Hello from APP! 
diff --git a/tests/aiet/test_resources/applications/application5/aiet-config.json b/tests/aiet/test_resources/applications/application5/aiet-config.json new file mode 100644 index 0000000..5269409 --- /dev/null +++ b/tests/aiet/test_resources/applications/application5/aiet-config.json @@ -0,0 +1,160 @@ +[ + { + "name": "application_5", + "description": "This is application 5", + "build_dir": "default_build_dir", + "supported_systems": [ + { + "name": "System 1", + "lock": false + }, + { + "name": "System 2" + } + ], + "variables": { + "var1": "value1", + "var2": "value2" + }, + "lock": true, + "commands": { + "build": [ + "default build command" + ], + "run": [ + "default run command" + ] + }, + "user_params": { + "build": [], + "run": [] + } + }, + { + "name": "application_5A", + "description": "This is application 5A", + "supported_systems": [ + { + "name": "System 1", + "build_dir": "build_5A", + "variables": { + "var1": "new value1" + } + }, + { + "name": "System 2", + "variables": { + "var2": "new value2" + }, + "lock": true, + "commands": { + "run": [ + "run command on system 2" + ] + } + } + ], + "variables": { + "var1": "value1", + "var2": "value2" + }, + "build_dir": "build", + "commands": { + "build": [ + "default build command" + ], + "run": [ + "default run command" + ] + }, + "user_params": { + "build": [], + "run": [] + } + }, + { + "name": "application_5B", + "description": "This is application 5B", + "supported_systems": [ + { + "name": "System 1", + "build_dir": "build_5B", + "variables": { + "var1": "value for var1 System1", + "var2": "value for var2 System1" + }, + "user_params": { + "build": [ + { + "name": "--param_5B", + "description": "Sample command param", + "values": [ + "value1", + "value2", + "value3" + ], + "default_value": "value1", + "alias": "param1" + } + ] + } + }, + { + "name": "System 2", + "variables": { + "var1": "value for var1 System2", + "var2": "value for var2 System2" + }, + "commands": { + "build": [ + "build command on 
system 2 with {variables:var1} {user_params:param1}" + ], + "run": [ + "run command on system 2" + ] + }, + "user_params": { + "run": [] + } + } + ], + "build_dir": "build", + "commands": { + "build": [ + "default build command with {variables:var1}" + ], + "run": [ + "default run command with {variables:var2}" + ] + }, + "user_params": { + "build": [ + { + "name": "--param", + "description": "Sample command param", + "values": [ + "value1", + "value2", + "value3" + ], + "default_value": "value1", + "alias": "param1" + } + ], + "run": [], + "non_used_command": [ + { + "name": "--not-used", + "description": "Not used param anywhere", + "values": [ + "value1", + "value2", + "value3" + ], + "default_value": "value1", + "alias": "param1" + } + ] + } + } +] diff --git a/tests/aiet/test_resources/applications/application5/aiet-config.json.license b/tests/aiet/test_resources/applications/application5/aiet-config.json.license new file mode 100644 index 0000000..9b83bfc --- /dev/null +++ b/tests/aiet/test_resources/applications/application5/aiet-config.json.license @@ -0,0 +1,3 @@ +SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. + +SPDX-License-Identifier: Apache-2.0 diff --git a/tests/aiet/test_resources/applications/readme.txt b/tests/aiet/test_resources/applications/readme.txt new file mode 100644 index 0000000..a1f8209 --- /dev/null +++ b/tests/aiet/test_resources/applications/readme.txt @@ -0,0 +1,4 @@ +SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. 
+SPDX-License-Identifier: Apache-2.0 + +Dummy file for test purposes diff --git a/tests/aiet/test_resources/hello_world.json b/tests/aiet/test_resources/hello_world.json new file mode 100644 index 0000000..8a9a448 --- /dev/null +++ b/tests/aiet/test_resources/hello_world.json @@ -0,0 +1,54 @@ +[ + { + "name": "Hello world", + "description": "Dummy application that displays 'Hello world!'", + "supported_systems": [ + "Dummy System" + ], + "build_dir": "build", + "deploy_data": [ + [ + "src", + "/tmp/" + ], + [ + "README", + "/tmp/README.md" + ] + ], + "commands": { + "clean": [], + "build": [], + "run": [ + "echo 'Hello world!'", + "ls -l /tmp" + ], + "post_run": [] + }, + "user_params": { + "run": [ + { + "name": "--choice-param", + "values": [ + "dummy_value_1", + "dummy_value_2" + ], + "default_value": "dummy_value_1", + "description": "Choice param" + }, + { + "name": "--open-param", + "values": [], + "default_value": "dummy_value_4", + "description": "Open param" + }, + { + "name": "--enable-flag", + "default_value": "dummy_value_4", + "description": "Flag param" + } + ], + "build": [] + } + } +] diff --git a/tests/aiet/test_resources/hello_world.json.license b/tests/aiet/test_resources/hello_world.json.license new file mode 100644 index 0000000..9b83bfc --- /dev/null +++ b/tests/aiet/test_resources/hello_world.json.license @@ -0,0 +1,3 @@ +SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. + +SPDX-License-Identifier: Apache-2.0 diff --git a/tests/aiet/test_resources/scripts/test_backend_run b/tests/aiet/test_resources/scripts/test_backend_run new file mode 100755 index 0000000..548f577 --- /dev/null +++ b/tests/aiet/test_resources/scripts/test_backend_run @@ -0,0 +1,8 @@ +#!/bin/bash + +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 + +echo "Hello from script" +>&2 echo "Oops!" 
+sleep 100 diff --git a/tests/aiet/test_resources/scripts/test_backend_run_script.sh b/tests/aiet/test_resources/scripts/test_backend_run_script.sh new file mode 100644 index 0000000..548f577 --- /dev/null +++ b/tests/aiet/test_resources/scripts/test_backend_run_script.sh @@ -0,0 +1,8 @@ +#!/bin/bash + +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 + +echo "Hello from script" +>&2 echo "Oops!" +sleep 100 diff --git a/tests/aiet/test_resources/systems/system1/aiet-config.json b/tests/aiet/test_resources/systems/system1/aiet-config.json new file mode 100644 index 0000000..4b5dd19 --- /dev/null +++ b/tests/aiet/test_resources/systems/system1/aiet-config.json @@ -0,0 +1,35 @@ +[ + { + "name": "System 1", + "description": "This is system 1", + "build_dir": "build", + "data_transfer": { + "protocol": "ssh", + "username": "root", + "password": "root", + "hostname": "localhost", + "port": "8021" + }, + "commands": { + "clean": [ + "echo 'clean'" + ], + "build": [ + "echo 'build'" + ], + "run": [ + "echo 'run'" + ], + "post_run": [ + "echo 'post_run'" + ], + "deploy": [ + "echo 'deploy'" + ] + }, + "user_params": { + "build": [], + "run": [] + } + } +] diff --git a/tests/aiet/test_resources/systems/system1/aiet-config.json.license b/tests/aiet/test_resources/systems/system1/aiet-config.json.license new file mode 100644 index 0000000..9b83bfc --- /dev/null +++ b/tests/aiet/test_resources/systems/system1/aiet-config.json.license @@ -0,0 +1,3 @@ +SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. 
+ +SPDX-License-Identifier: Apache-2.0 diff --git a/tests/aiet/test_resources/systems/system1/system_artifact/dummy.txt b/tests/aiet/test_resources/systems/system1/system_artifact/dummy.txt new file mode 100644 index 0000000..487e9d8 --- /dev/null +++ b/tests/aiet/test_resources/systems/system1/system_artifact/dummy.txt @@ -0,0 +1,2 @@ +SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +SPDX-License-Identifier: Apache-2.0 diff --git a/tests/aiet/test_resources/systems/system2/aiet-config.json b/tests/aiet/test_resources/systems/system2/aiet-config.json new file mode 100644 index 0000000..a9e0eb3 --- /dev/null +++ b/tests/aiet/test_resources/systems/system2/aiet-config.json @@ -0,0 +1,32 @@ +[ + { + "name": "System 2", + "description": "This is system 2", + "build_dir": "build", + "data_transfer": { + "protocol": "ssh", + "username": "root", + "password": "root", + "hostname": "localhost", + "port": "8021" + }, + "commands": { + "clean": [ + "echo 'clean'" + ], + "build": [ + "echo 'build'" + ], + "run": [ + "echo 'run'" + ], + "post_run": [ + "echo 'post_run'" + ] + }, + "user_params": { + "build": [], + "run": [] + } + } +] diff --git a/tests/aiet/test_resources/systems/system2/aiet-config.json.license b/tests/aiet/test_resources/systems/system2/aiet-config.json.license new file mode 100644 index 0000000..9b83bfc --- /dev/null +++ b/tests/aiet/test_resources/systems/system2/aiet-config.json.license @@ -0,0 +1,3 @@ +SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. + +SPDX-License-Identifier: Apache-2.0 diff --git a/tests/aiet/test_resources/systems/system3/readme.txt b/tests/aiet/test_resources/systems/system3/readme.txt new file mode 100644 index 0000000..aba5a9c --- /dev/null +++ b/tests/aiet/test_resources/systems/system3/readme.txt @@ -0,0 +1,4 @@ +SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. 
+SPDX-License-Identifier: Apache-2.0 + +This system does not have the json configuration file diff --git a/tests/aiet/test_resources/systems/system4/aiet-config.json b/tests/aiet/test_resources/systems/system4/aiet-config.json new file mode 100644 index 0000000..295e00f --- /dev/null +++ b/tests/aiet/test_resources/systems/system4/aiet-config.json @@ -0,0 +1,19 @@ +[ + { + "name": "System 4", + "description": "This is system 4", + "build_dir": "build", + "data_transfer": { + "protocol": "local" + }, + "commands": { + "run": [ + "echo {application.name}", + "cat {application.commands.run:0}" + ] + }, + "user_params": { + "run": [] + } + } +] diff --git a/tests/aiet/test_resources/systems/system4/aiet-config.json.license b/tests/aiet/test_resources/systems/system4/aiet-config.json.license new file mode 100644 index 0000000..9b83bfc --- /dev/null +++ b/tests/aiet/test_resources/systems/system4/aiet-config.json.license @@ -0,0 +1,3 @@ +SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. 
+ +SPDX-License-Identifier: Apache-2.0 diff --git a/tests/aiet/test_resources/tools/tool1/aiet-config.json b/tests/aiet/test_resources/tools/tool1/aiet-config.json new file mode 100644 index 0000000..067ef7e --- /dev/null +++ b/tests/aiet/test_resources/tools/tool1/aiet-config.json @@ -0,0 +1,30 @@ +[ + { + "name": "tool_1", + "description": "This is tool 1", + "build_dir": "build", + "supported_systems": [ + { + "name": "System 1" + } + ], + "commands": { + "clean": [ + "echo 'clean'" + ], + "build": [ + "echo 'build'" + ], + "run": [ + "echo 'run'" + ], + "post_run": [ + "echo 'post_run'" + ] + }, + "user_params": { + "build": [], + "run": [] + } + } +] diff --git a/tests/aiet/test_resources/tools/tool1/aiet-config.json.license b/tests/aiet/test_resources/tools/tool1/aiet-config.json.license new file mode 100644 index 0000000..9b83bfc --- /dev/null +++ b/tests/aiet/test_resources/tools/tool1/aiet-config.json.license @@ -0,0 +1,3 @@ +SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. 
+ +SPDX-License-Identifier: Apache-2.0 diff --git a/tests/aiet/test_resources/tools/tool2/aiet-config.json b/tests/aiet/test_resources/tools/tool2/aiet-config.json new file mode 100644 index 0000000..6eee9a6 --- /dev/null +++ b/tests/aiet/test_resources/tools/tool2/aiet-config.json @@ -0,0 +1,26 @@ +[ + { + "name": "tool_2", + "description": "This is tool 2 with no supported systems", + "build_dir": "build", + "supported_systems": [], + "commands": { + "clean": [ + "echo 'clean'" + ], + "build": [ + "echo 'build'" + ], + "run": [ + "echo 'run'" + ], + "post_run": [ + "echo 'post_run'" + ] + }, + "user_params": { + "build": [], + "run": [] + } + } +] diff --git a/tests/aiet/test_resources/tools/tool2/aiet-config.json.license b/tests/aiet/test_resources/tools/tool2/aiet-config.json.license new file mode 100644 index 0000000..9b83bfc --- /dev/null +++ b/tests/aiet/test_resources/tools/tool2/aiet-config.json.license @@ -0,0 +1,3 @@ +SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. + +SPDX-License-Identifier: Apache-2.0 diff --git a/tests/aiet/test_resources/various/applications/application_with_empty_config/aiet-config.json b/tests/aiet/test_resources/various/applications/application_with_empty_config/aiet-config.json new file mode 100644 index 0000000..fe51488 --- /dev/null +++ b/tests/aiet/test_resources/various/applications/application_with_empty_config/aiet-config.json @@ -0,0 +1 @@ +[] diff --git a/tests/aiet/test_resources/various/applications/application_with_empty_config/aiet-config.json.license b/tests/aiet/test_resources/various/applications/application_with_empty_config/aiet-config.json.license new file mode 100644 index 0000000..9b83bfc --- /dev/null +++ b/tests/aiet/test_resources/various/applications/application_with_empty_config/aiet-config.json.license @@ -0,0 +1,3 @@ +SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. 
+ +SPDX-License-Identifier: Apache-2.0 diff --git a/tests/aiet/test_resources/various/applications/application_with_valid_config/aiet-config.json b/tests/aiet/test_resources/various/applications/application_with_valid_config/aiet-config.json new file mode 100644 index 0000000..ff1cf1a --- /dev/null +++ b/tests/aiet/test_resources/various/applications/application_with_valid_config/aiet-config.json @@ -0,0 +1,35 @@ +[ + { + "name": "test_application", + "description": "This is test_application", + "build_dir": "build", + "supported_systems": [ + { + "name": "System 4" + } + ], + "commands": { + "build": [ + "cp ../hello_app.txt ." + ], + "run": [ + "{application.build_dir}/hello_app.txt" + ] + }, + "user_params": { + "build": [ + { + "name": "--app", + "description": "Sample command param", + "values": [ + "application1", + "application2", + "application3" + ], + "default_value": "application1" + } + ], + "run": [] + } + } +] diff --git a/tests/aiet/test_resources/various/applications/application_with_valid_config/aiet-config.json.license b/tests/aiet/test_resources/various/applications/application_with_valid_config/aiet-config.json.license new file mode 100644 index 0000000..9b83bfc --- /dev/null +++ b/tests/aiet/test_resources/various/applications/application_with_valid_config/aiet-config.json.license @@ -0,0 +1,3 @@ +SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. 
+ +SPDX-License-Identifier: Apache-2.0 diff --git a/tests/aiet/test_resources/various/applications/application_with_wrong_config1/aiet-config.json b/tests/aiet/test_resources/various/applications/application_with_wrong_config1/aiet-config.json new file mode 100644 index 0000000..724b31b --- /dev/null +++ b/tests/aiet/test_resources/various/applications/application_with_wrong_config1/aiet-config.json @@ -0,0 +1,2 @@ +This is not valid json file +{ diff --git a/tests/aiet/test_resources/various/applications/application_with_wrong_config1/aiet-config.json.license b/tests/aiet/test_resources/various/applications/application_with_wrong_config1/aiet-config.json.license new file mode 100644 index 0000000..9b83bfc --- /dev/null +++ b/tests/aiet/test_resources/various/applications/application_with_wrong_config1/aiet-config.json.license @@ -0,0 +1,3 @@ +SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. + +SPDX-License-Identifier: Apache-2.0 diff --git a/tests/aiet/test_resources/various/applications/application_with_wrong_config2/aiet-config.json b/tests/aiet/test_resources/various/applications/application_with_wrong_config2/aiet-config.json new file mode 100644 index 0000000..1ebb29c --- /dev/null +++ b/tests/aiet/test_resources/various/applications/application_with_wrong_config2/aiet-config.json @@ -0,0 +1,30 @@ +[ + { + "name": "test_application", + "description": "This is test_application", + "build_dir": "build", + "commands": { + "build": [ + "cp ../hello_app.txt ." 
+ ], + "run": [ + "{application.build_dir}/hello_app.txt" + ] + }, + "user_params": { + "build": [ + { + "name": "--app", + "description": "Sample command param", + "values": [ + "application1", + "application2", + "application3" + ], + "default_value": "application1" + } + ], + "run": [] + } + } +] diff --git a/tests/aiet/test_resources/various/applications/application_with_wrong_config2/aiet-config.json.license b/tests/aiet/test_resources/various/applications/application_with_wrong_config2/aiet-config.json.license new file mode 100644 index 0000000..9b83bfc --- /dev/null +++ b/tests/aiet/test_resources/various/applications/application_with_wrong_config2/aiet-config.json.license @@ -0,0 +1,3 @@ +SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. + +SPDX-License-Identifier: Apache-2.0 diff --git a/tests/aiet/test_resources/various/applications/application_with_wrong_config3/aiet-config.json b/tests/aiet/test_resources/various/applications/application_with_wrong_config3/aiet-config.json new file mode 100644 index 0000000..410d12d --- /dev/null +++ b/tests/aiet/test_resources/various/applications/application_with_wrong_config3/aiet-config.json @@ -0,0 +1,35 @@ +[ + { + "name": "test_application", + "description": "This is test_application", + "build_dir": "build", + "supported_systems": [ + { + "anme": "System 4" + } + ], + "commands": { + "build": [ + "cp ../hello_app.txt ." 
+ ], + "run": [ + "{application.build_dir}/hello_app.txt" + ] + }, + "user_params": { + "build": [ + { + "name": "--app", + "description": "Sample command param", + "values": [ + "application1", + "application2", + "application3" + ], + "default_value": "application1" + } + ], + "run": [] + } + } +] diff --git a/tests/aiet/test_resources/various/applications/application_with_wrong_config3/aiet-config.json.license b/tests/aiet/test_resources/various/applications/application_with_wrong_config3/aiet-config.json.license new file mode 100644 index 0000000..9b83bfc --- /dev/null +++ b/tests/aiet/test_resources/various/applications/application_with_wrong_config3/aiet-config.json.license @@ -0,0 +1,3 @@ +SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. + +SPDX-License-Identifier: Apache-2.0 diff --git a/tests/aiet/test_resources/various/systems/system_with_empty_config/aiet-config.json b/tests/aiet/test_resources/various/systems/system_with_empty_config/aiet-config.json new file mode 100644 index 0000000..fe51488 --- /dev/null +++ b/tests/aiet/test_resources/various/systems/system_with_empty_config/aiet-config.json @@ -0,0 +1 @@ +[] diff --git a/tests/aiet/test_resources/various/systems/system_with_empty_config/aiet-config.json.license b/tests/aiet/test_resources/various/systems/system_with_empty_config/aiet-config.json.license new file mode 100644 index 0000000..9b83bfc --- /dev/null +++ b/tests/aiet/test_resources/various/systems/system_with_empty_config/aiet-config.json.license @@ -0,0 +1,3 @@ +SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. 
+ +SPDX-License-Identifier: Apache-2.0 diff --git a/tests/aiet/test_resources/various/systems/system_with_valid_config/aiet-config.json b/tests/aiet/test_resources/various/systems/system_with_valid_config/aiet-config.json new file mode 100644 index 0000000..20142e9 --- /dev/null +++ b/tests/aiet/test_resources/various/systems/system_with_valid_config/aiet-config.json @@ -0,0 +1,16 @@ +[ + { + "name": "Test system", + "description": "This is a test system", + "build_dir": "build", + "data_transfer": { + "protocol": "local" + }, + "commands": { + "run": [] + }, + "user_params": { + "run": [] + } + } +] diff --git a/tests/aiet/test_resources/various/systems/system_with_valid_config/aiet-config.json.license b/tests/aiet/test_resources/various/systems/system_with_valid_config/aiet-config.json.license new file mode 100644 index 0000000..9b83bfc --- /dev/null +++ b/tests/aiet/test_resources/various/systems/system_with_valid_config/aiet-config.json.license @@ -0,0 +1,3 @@ +SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. + +SPDX-License-Identifier: Apache-2.0 diff --git a/tests/aiet/test_run_vela_script.py b/tests/aiet/test_run_vela_script.py new file mode 100644 index 0000000..971856e --- /dev/null +++ b/tests/aiet/test_run_vela_script.py @@ -0,0 +1,152 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. 
+# SPDX-License-Identifier: Apache-2.0 +# pylint: disable=redefined-outer-name,no-self-use +"""Module for testing run_vela.py script.""" +from pathlib import Path +from typing import Any +from typing import List + +import pytest +from click.testing import CliRunner + +from aiet.cli.common import MiddlewareExitCode +from aiet.resources.tools.vela.check_model import get_model_from_file +from aiet.resources.tools.vela.check_model import is_vela_optimised +from aiet.resources.tools.vela.run_vela import run_vela + + +@pytest.fixture(scope="session") +def vela_config_path(test_tools_path: Path) -> Path: + """Return test systems path in a pytest fixture.""" + return test_tools_path / "vela" / "vela.ini" + + +@pytest.fixture( + params=[ + ["ethos-u65-256", "Ethos_U65_High_End", "U65_Shared_Sram"], + ["ethos-u55-32", "Ethos_U55_High_End_Embedded", "U55_Shared_Sram"], + ] +) +def ethos_config(request: Any) -> Any: + """Fixture to provide different configuration for Ethos-U optimization with Vela.""" + return request.param + + +# pylint: disable=too-many-arguments +def generate_args( + input_: Path, + output: Path, + cfg: Path, + acc_config: str, + system_config: str, + memory_mode: str, +) -> List[str]: + """Generate arguments that can be passed to script 'run_vela'.""" + return [ + "-i", + str(input_), + "-o", + str(output), + "--config", + str(cfg), + "--accelerator-config", + acc_config, + "--system-config", + system_config, + "--memory-mode", + memory_mode, + "--optimise", + "Performance", + ] + + +def check_run_vela( + cli_runner: CliRunner, args: List, expected_success: bool, output_file: Path +) -> None: + """Run Vela with the given arguments and check the result.""" + result = cli_runner.invoke(run_vela, args) + success = result.exit_code == MiddlewareExitCode.SUCCESS + assert success == expected_success + if success: + model = get_model_from_file(output_file) + assert is_vela_optimised(model) + + +def run_vela_script( + cli_runner: CliRunner, + input_model_file: 
Path, + output_model_file: Path, + vela_config: Path, + expected_success: bool, + acc_config: str, + system_config: str, + memory_mode: str, +) -> None: + """Run the command 'run_vela' on the command line.""" + args = generate_args( + input_model_file, + output_model_file, + vela_config, + acc_config, + system_config, + memory_mode, + ) + check_run_vela(cli_runner, args, expected_success, output_model_file) + + +class TestRunVelaCli: + """Test the command-line execution of the run_vela command.""" + + def test_non_optimised_model( + self, + cli_runner: CliRunner, + non_optimised_input_model_file: Path, + tmp_path: Path, + vela_config_path: Path, + ethos_config: List, + ) -> None: + """Verify Vela is run correctly on an unoptimised model.""" + run_vela_script( + cli_runner, + non_optimised_input_model_file, + tmp_path / "test.tflite", + vela_config_path, + True, + *ethos_config, + ) + + def test_optimised_model( + self, + cli_runner: CliRunner, + optimised_input_model_file: Path, + tmp_path: Path, + vela_config_path: Path, + ethos_config: List, + ) -> None: + """Verify Vela is run correctly on an already optimised model.""" + run_vela_script( + cli_runner, + optimised_input_model_file, + tmp_path / "test.tflite", + vela_config_path, + True, + *ethos_config, + ) + + def test_invalid_model( + self, + cli_runner: CliRunner, + invalid_input_model_file: Path, + tmp_path: Path, + vela_config_path: Path, + ethos_config: List, + ) -> None: + """Verify an error is raised when the input model is not valid.""" + run_vela_script( + cli_runner, + invalid_input_model_file, + tmp_path / "test.tflite", + vela_config_path, + False, + *ethos_config, + ) diff --git a/tests/aiet/test_utils_fs.py b/tests/aiet/test_utils_fs.py new file mode 100644 index 0000000..46d276e --- /dev/null +++ b/tests/aiet/test_utils_fs.py @@ -0,0 +1,168 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. 
+# SPDX-License-Identifier: Apache-2.0 +# pylint: disable=no-self-use +"""Module for testing fs.py.""" +from contextlib import ExitStack as does_not_raise +from pathlib import Path +from typing import Any +from typing import Union +from unittest.mock import MagicMock + +import pytest + +from aiet.utils.fs import get_resources +from aiet.utils.fs import read_file_as_bytearray +from aiet.utils.fs import read_file_as_string +from aiet.utils.fs import recreate_directory +from aiet.utils.fs import remove_directory +from aiet.utils.fs import remove_resource +from aiet.utils.fs import ResourceType +from aiet.utils.fs import valid_for_filename + + +@pytest.mark.parametrize( + "resource_name,expected_path", + [ + ("systems", does_not_raise()), + ("applications", does_not_raise()), + ("whaaat", pytest.raises(ResourceWarning)), + (None, pytest.raises(ResourceWarning)), + ], +) +def test_get_resources(resource_name: ResourceType, expected_path: Any) -> None: + """Test get_resources() with multiple parameters.""" + with expected_path: + resource_path = get_resources(resource_name) + assert resource_path.exists() + + +def test_remove_resource_wrong_directory( + monkeypatch: Any, test_applications_path: Path +) -> None: + """Test removing resource with wrong directory.""" + mock_get_resources = MagicMock(return_value=test_applications_path) + monkeypatch.setattr("aiet.utils.fs.get_resources", mock_get_resources) + + mock_shutil_rmtree = MagicMock() + monkeypatch.setattr("aiet.utils.fs.shutil.rmtree", mock_shutil_rmtree) + + with pytest.raises(Exception, match="Resource .* does not exist"): + remove_resource("unknown", "applications") + mock_shutil_rmtree.assert_not_called() + + with pytest.raises(Exception, match="Wrong resource .*"): + remove_resource("readme.txt", "applications") + mock_shutil_rmtree.assert_not_called() + + +def test_remove_resource(monkeypatch: Any, test_applications_path: Path) -> None: + """Test removing resource data.""" + mock_get_resources = 
MagicMock(return_value=test_applications_path) + monkeypatch.setattr("aiet.utils.fs.get_resources", mock_get_resources) + + mock_shutil_rmtree = MagicMock() + monkeypatch.setattr("aiet.utils.fs.shutil.rmtree", mock_shutil_rmtree) + + remove_resource("application1", "applications") + mock_shutil_rmtree.assert_called_once() + + +def test_remove_directory(tmpdir: Any) -> None: + """Test directory removal.""" + tmpdir_path = Path(tmpdir) + tmpfile = tmpdir_path / "temp.txt" + + for item in [None, tmpfile]: + with pytest.raises(Exception, match="No directory path provided"): + remove_directory(item) + + newdir = tmpdir_path / "newdir" + newdir.mkdir() + + assert newdir.is_dir() + remove_directory(newdir) + assert not newdir.exists() + + +def test_recreate_directory(tmpdir: Any) -> None: + """Test directory recreation.""" + with pytest.raises(Exception, match="No directory path provided"): + recreate_directory(None) + + tmpdir_path = Path(tmpdir) + tmpfile = tmpdir_path / "temp.txt" + tmpfile.touch() + with pytest.raises(Exception, match="Path .* does exist and it is not a directory"): + recreate_directory(tmpfile) + + newdir = tmpdir_path / "newdir" + newdir.mkdir() + newfile = newdir / "newfile" + newfile.touch() + assert list(newdir.iterdir()) == [newfile] + recreate_directory(newdir) + assert not list(newdir.iterdir()) + + newdir2 = tmpdir_path / "newdir2" + assert not newdir2.exists() + recreate_directory(newdir2) + assert newdir2.is_dir() + + +def write_to_file( + write_directory: Any, write_mode: str, write_text: Union[str, bytes] +) -> Path: + """Write some text to a temporary test file.""" + tmpdir_path = Path(write_directory) + tmpfile = tmpdir_path / "file_name.txt" + with open(tmpfile, write_mode) as file: # pylint: disable=unspecified-encoding + file.write(write_text) + return tmpfile + + +class TestReadFileAsString: + """Test read_file_as_string() function.""" + + def test_returns_text_from_valid_file(self, tmpdir: Any) -> None: + """Ensure the string 
written to a file read correctly.""" + file_path = write_to_file(tmpdir, "w", "hello") + assert read_file_as_string(file_path) == "hello" + + def test_output_is_empty_string_when_input_file_non_existent( + self, tmpdir: Any + ) -> None: + """Ensure empty string returned when reading from non-existent file.""" + file_path = Path(tmpdir / "non-existent.txt") + assert read_file_as_string(file_path) == "" + + +class TestReadFileAsByteArray: + """Test read_file_as_bytearray() function.""" + + def test_returns_bytes_from_valid_file(self, tmpdir: Any) -> None: + """Ensure the bytes written to a file read correctly.""" + file_path = write_to_file(tmpdir, "wb", b"hello bytes") + assert read_file_as_bytearray(file_path) == b"hello bytes" + + def test_output_is_empty_bytearray_when_input_file_non_existent( + self, tmpdir: Any + ) -> None: + """Ensure empty bytearray returned when reading from non-existent file.""" + file_path = Path(tmpdir / "non-existent.txt") + assert read_file_as_bytearray(file_path) == bytearray() + + +@pytest.mark.parametrize( + "value, replacement, expected_result", + [ + ["", "", ""], + ["123", "", "123"], + ["123", "_", "123"], + ["/some_folder/some_script.sh", "", "some_foldersome_script.sh"], + ["/some_folder/some_script.sh", "_", "_some_folder_some_script.sh"], + ["!;'some_name$%^!", "_", "___some_name____"], + ], +) +def test_valid_for_filename(value: str, replacement: str, expected_result: str) -> None: + """Test function valid_for_filename.""" + assert valid_for_filename(value, replacement) == expected_result diff --git a/tests/aiet/test_utils_helpers.py b/tests/aiet/test_utils_helpers.py new file mode 100644 index 0000000..bbe03fc --- /dev/null +++ b/tests/aiet/test_utils_helpers.py @@ -0,0 +1,27 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. 
+# SPDX-License-Identifier: Apache-2.0 +"""Module for testing helpers.py.""" +import logging +from typing import Any +from typing import List +from unittest.mock import call +from unittest.mock import MagicMock + +import pytest + +from aiet.utils.helpers import set_verbosity + + +@pytest.mark.parametrize( + "verbosity,expected_calls", + [(0, []), (1, [call(logging.INFO)]), (2, [call(logging.DEBUG)])], +) +def test_set_verbosity( + verbosity: int, expected_calls: List[Any], monkeypatch: Any +) -> None: + """Test set_verbosity() with different verbsosity levels.""" + with monkeypatch.context() as mock_context: + logging_mock = MagicMock() + mock_context.setattr(logging.getLogger(), "setLevel", logging_mock) + set_verbosity(None, None, verbosity) + logging_mock.assert_has_calls(expected_calls) diff --git a/tests/aiet/test_utils_proc.py b/tests/aiet/test_utils_proc.py new file mode 100644 index 0000000..9fb48dd --- /dev/null +++ b/tests/aiet/test_utils_proc.py @@ -0,0 +1,272 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. 
+# SPDX-License-Identifier: Apache-2.0 +# pylint: disable=attribute-defined-outside-init,no-self-use,not-callable +"""Pytests for testing aiet/utils/proc.py.""" +from pathlib import Path +from typing import Any +from unittest import mock + +import psutil +import pytest +from sh import ErrorReturnCode + +from aiet.utils.proc import Command +from aiet.utils.proc import CommandFailedException +from aiet.utils.proc import CommandNotFound +from aiet.utils.proc import parse_command +from aiet.utils.proc import print_command_stdout +from aiet.utils.proc import run_and_wait +from aiet.utils.proc import save_process_info +from aiet.utils.proc import ShellCommand +from aiet.utils.proc import terminate_command +from aiet.utils.proc import terminate_external_process + + +class TestShellCommand: + """Sample class for collecting tests.""" + + def test_shellcommand_default_value(self) -> None: + """Test the instantiation of the class ShellCommand with no parameter.""" + shell_command = ShellCommand() + assert shell_command.base_log_path == "/tmp" + + @pytest.mark.parametrize( + "base_log_path,expected", [("/test", "/test"), ("/asd", "/asd")] + ) + def test_shellcommand_with_param(self, base_log_path: str, expected: str) -> None: + """Test init ShellCommand with different parameters.""" + shell_command = ShellCommand(base_log_path) + assert shell_command.base_log_path == expected + + def test_run_ls(self, monkeypatch: Any) -> None: + """Test a simple ls command.""" + mock_command = mock.MagicMock() + monkeypatch.setattr(Command, "bake", mock_command) + + mock_get_stdout_stderr_paths = mock.MagicMock() + mock_get_stdout_stderr_paths.return_value = ("/tmp/std.out", "/tmp/std.err") + monkeypatch.setattr( + ShellCommand, "get_stdout_stderr_paths", mock_get_stdout_stderr_paths + ) + + shell_command = ShellCommand() + shell_command.run("ls", "-l") + assert mock_command.mock_calls[0] == mock.call(("-l",)) + assert mock_command.mock_calls[1] == mock.call()( + _bg=True, 
_err="/tmp/std.err", _out="/tmp/std.out", _tee=True, _bg_exc=False + ) + + def test_run_command_not_found(self) -> None: + """Test whe the command doesn't exist.""" + shell_command = ShellCommand() + with pytest.raises(CommandNotFound): + shell_command.run("lsl", "-l") + + def test_get_stdout_stderr_paths_valid_path(self) -> None: + """Test the method to get files to store stdout and stderr.""" + valid_path = "/tmp" + shell_command = ShellCommand(valid_path) + out, err = shell_command.get_stdout_stderr_paths(valid_path, "cmd") + assert out.exists() and out.is_file() + assert err.exists() and err.is_file() + assert "cmd" in out.name + assert "cmd" in err.name + + def test_get_stdout_stderr_paths_not_invalid_path(self) -> None: + """Test the method to get output files with an invalid path.""" + invalid_path = "/invalid/foo/bar" + shell_command = ShellCommand(invalid_path) + with pytest.raises(FileNotFoundError): + shell_command.get_stdout_stderr_paths(invalid_path, "cmd") + + +@mock.patch("builtins.print") +def test_print_command_stdout_alive(mock_print: Any) -> None: + """Test the print command stdout with an alive (running) process.""" + mock_command = mock.MagicMock() + mock_command.is_alive.return_value = True + mock_command.next.side_effect = ["test1", "test2", StopIteration] + + print_command_stdout(mock_command) + + mock_command.assert_has_calls( + [mock.call.is_alive(), mock.call.next(), mock.call.next()] + ) + mock_print.assert_has_calls( + [mock.call("test1", end=""), mock.call("test2", end="")] + ) + + +@mock.patch("builtins.print") +def test_print_command_stdout_not_alive(mock_print: Any) -> None: + """Test the print command stdout with a not alive (exited) process.""" + mock_command = mock.MagicMock() + mock_command.is_alive.return_value = False + mock_command.stdout = "test" + + print_command_stdout(mock_command) + mock_command.assert_has_calls([mock.call.is_alive()]) + mock_print.assert_called_once_with("test") + + +def 
test_terminate_external_process_no_process(capsys: Any) -> None: + """Test that non existed process could be terminated.""" + mock_command = mock.MagicMock() + mock_command.terminate.side_effect = psutil.Error("Error!") + + terminate_external_process(mock_command) + captured = capsys.readouterr() + assert captured.out == "Unable to terminate process\n" + + +def test_terminate_external_process_case1() -> None: + """Test when process terminated immediately.""" + mock_command = mock.MagicMock() + mock_command.is_running.return_value = False + + terminate_external_process(mock_command) + mock_command.terminate.assert_called_once() + mock_command.is_running.assert_called_once() + + +def test_terminate_external_process_case2() -> None: + """Test when process termination takes time.""" + mock_command = mock.MagicMock() + mock_command.is_running.side_effect = [True, True, False] + + terminate_external_process(mock_command) + mock_command.terminate.assert_called_once() + assert mock_command.is_running.call_count == 3 + + +def test_terminate_external_process_case3() -> None: + """Test when process termination takes more time.""" + mock_command = mock.MagicMock() + mock_command.is_running.side_effect = [True, True, True] + + terminate_external_process( + mock_command, number_of_attempts=2, wait_period=0.1, wait_for_termination=0.1 + ) + assert mock_command.is_running.call_count == 3 + assert mock_command.terminate.call_count == 2 + + +def test_terminate_external_process_case4() -> None: + """Test when process termination takes more time.""" + mock_command = mock.MagicMock() + mock_command.is_running.side_effect = [True, True, False] + + terminate_external_process( + mock_command, number_of_attempts=2, wait_period=0.1, wait_for_termination=0.1 + ) + mock_command.terminate.assert_called_once() + assert mock_command.is_running.call_count == 3 + assert mock_command.terminate.call_count == 1 + + +def test_terminate_command_no_process() -> None: + """Test command termination when 
process does not exist.""" + mock_command = mock.MagicMock() + mock_command.process.signal_group.side_effect = ProcessLookupError() + + terminate_command(mock_command) + mock_command.process.signal_group.assert_called_once() + mock_command.is_alive.assert_not_called() + + +def test_terminate_command() -> None: + """Test command termination.""" + mock_command = mock.MagicMock() + mock_command.is_alive.return_value = False + + terminate_command(mock_command) + mock_command.process.signal_group.assert_called_once() + + +def test_terminate_command_case1() -> None: + """Test command termination when it takes time..""" + mock_command = mock.MagicMock() + mock_command.is_alive.side_effect = [True, True, False] + + terminate_command(mock_command, wait_period=0.1) + mock_command.process.signal_group.assert_called_once() + assert mock_command.is_alive.call_count == 3 + + +def test_terminate_command_case2() -> None: + """Test command termination when it takes much time..""" + mock_command = mock.MagicMock() + mock_command.is_alive.side_effect = [True, True, True] + + terminate_command(mock_command, number_of_attempts=3, wait_period=0.1) + assert mock_command.is_alive.call_count == 3 + assert mock_command.process.signal_group.call_count == 2 + + +class TestRunAndWait: + """Test run_and_wait function.""" + + @pytest.fixture(autouse=True) + def setup_method(self, monkeypatch: Any) -> None: + """Init test method.""" + self.execute_command_mock = mock.MagicMock() + monkeypatch.setattr( + "aiet.utils.proc.execute_command", self.execute_command_mock + ) + + self.terminate_command_mock = mock.MagicMock() + monkeypatch.setattr( + "aiet.utils.proc.terminate_command", self.terminate_command_mock + ) + + def test_if_execute_command_raises_exception(self) -> None: + """Test if execute_command fails.""" + self.execute_command_mock.side_effect = Exception("Error!") + with pytest.raises(Exception, match="Error!"): + run_and_wait("command", Path.cwd()) + + def 
test_if_command_finishes_with_error(self) -> None: + """Test if command finishes with error.""" + cmd_mock = mock.MagicMock() + self.execute_command_mock.return_value = cmd_mock + exit_code_mock = mock.PropertyMock( + side_effect=ErrorReturnCode("cmd", bytearray(), bytearray()) + ) + type(cmd_mock).exit_code = exit_code_mock + + with pytest.raises(CommandFailedException): + run_and_wait("command", Path.cwd()) + + @pytest.mark.parametrize("terminate_on_error, call_count", ((False, 0), (True, 1))) + def test_if_command_finishes_with_exception( + self, terminate_on_error: bool, call_count: int + ) -> None: + """Test if command finishes with error.""" + cmd_mock = mock.MagicMock() + self.execute_command_mock.return_value = cmd_mock + exit_code_mock = mock.PropertyMock(side_effect=Exception("Error!")) + type(cmd_mock).exit_code = exit_code_mock + + with pytest.raises(Exception, match="Error!"): + run_and_wait("command", Path.cwd(), terminate_on_error=terminate_on_error) + + assert self.terminate_command_mock.call_count == call_count + + +def test_save_process_info_no_process(monkeypatch: Any, tmpdir: Any) -> None: + """Test save_process_info function.""" + mock_process = mock.MagicMock() + monkeypatch.setattr("psutil.Process", mock.MagicMock(return_value=mock_process)) + mock_process.children.side_effect = psutil.NoSuchProcess(555) + + pid_file_path = Path(tmpdir) / "test.pid" + save_process_info(555, pid_file_path) + assert not pid_file_path.exists() + + +def test_parse_command() -> None: + """Test parse_command function.""" + assert parse_command("1.sh") == ["bash", "1.sh"] + assert parse_command("1.sh", shell="sh") == ["sh", "1.sh"] + assert parse_command("command") == ["command"] + assert parse_command("command 123 --param=1") == ["command", "123", "--param=1"] |