Diffstat (limited to 'tests')
-rw-r--r--  tests/__init__.py | 3
-rw-r--r--  tests/aiet/__init__.py | 3
-rw-r--r--  tests/aiet/conftest.py | 139
-rw-r--r--  tests/aiet/test_backend_application.py | 452
-rw-r--r--  tests/aiet/test_backend_common.py | 486
-rw-r--r--  tests/aiet/test_backend_controller.py | 160
-rw-r--r--  tests/aiet/test_backend_execution.py | 526
-rw-r--r--  tests/aiet/test_backend_output_parser.py | 152
-rw-r--r--  tests/aiet/test_backend_protocol.py | 231
-rw-r--r--  tests/aiet/test_backend_source.py | 199
-rw-r--r--  tests/aiet/test_backend_system.py | 536
-rw-r--r--  tests/aiet/test_backend_tool.py | 60
-rw-r--r--  tests/aiet/test_check_model.py | 162
-rw-r--r--  tests/aiet/test_cli.py | 37
-rw-r--r--  tests/aiet/test_cli_application.py | 1153
-rw-r--r--  tests/aiet/test_cli_common.py | 37
-rw-r--r--  tests/aiet/test_cli_system.py | 240
-rw-r--r--  tests/aiet/test_cli_tool.py | 333
-rw-r--r--  tests/aiet/test_main.py | 16
-rw-r--r--  tests/aiet/test_resources/application_config.json | 96
-rw-r--r--  tests/aiet/test_resources/application_config.json.license | 3
-rw-r--r--  tests/aiet/test_resources/applications/application1/aiet-config.json | 30
-rw-r--r--  tests/aiet/test_resources/applications/application1/aiet-config.json.license | 3
-rw-r--r--  tests/aiet/test_resources/applications/application2/aiet-config.json | 30
-rw-r--r--  tests/aiet/test_resources/applications/application2/aiet-config.json.license | 3
-rw-r--r--  tests/aiet/test_resources/applications/application3/readme.txt | 4
-rw-r--r--  tests/aiet/test_resources/applications/application4/aiet-config.json | 35
-rw-r--r--  tests/aiet/test_resources/applications/application4/aiet-config.json.license | 3
-rw-r--r--  tests/aiet/test_resources/applications/application4/hello_app.txt | 4
-rw-r--r--  tests/aiet/test_resources/applications/application5/aiet-config.json | 160
-rw-r--r--  tests/aiet/test_resources/applications/application5/aiet-config.json.license | 3
-rw-r--r--  tests/aiet/test_resources/applications/readme.txt | 4
-rw-r--r--  tests/aiet/test_resources/hello_world.json | 54
-rw-r--r--  tests/aiet/test_resources/hello_world.json.license | 3
-rwxr-xr-x  tests/aiet/test_resources/scripts/test_backend_run | 8
-rw-r--r--  tests/aiet/test_resources/scripts/test_backend_run_script.sh | 8
-rw-r--r--  tests/aiet/test_resources/systems/system1/aiet-config.json | 35
-rw-r--r--  tests/aiet/test_resources/systems/system1/aiet-config.json.license | 3
-rw-r--r--  tests/aiet/test_resources/systems/system1/system_artifact/dummy.txt | 2
-rw-r--r--  tests/aiet/test_resources/systems/system2/aiet-config.json | 32
-rw-r--r--  tests/aiet/test_resources/systems/system2/aiet-config.json.license | 3
-rw-r--r--  tests/aiet/test_resources/systems/system3/readme.txt | 4
-rw-r--r--  tests/aiet/test_resources/systems/system4/aiet-config.json | 19
-rw-r--r--  tests/aiet/test_resources/systems/system4/aiet-config.json.license | 3
-rw-r--r--  tests/aiet/test_resources/tools/tool1/aiet-config.json | 30
-rw-r--r--  tests/aiet/test_resources/tools/tool1/aiet-config.json.license | 3
-rw-r--r--  tests/aiet/test_resources/tools/tool2/aiet-config.json | 26
-rw-r--r--  tests/aiet/test_resources/tools/tool2/aiet-config.json.license | 3
-rw-r--r--  tests/aiet/test_resources/various/applications/application_with_empty_config/aiet-config.json | 1
-rw-r--r--  tests/aiet/test_resources/various/applications/application_with_empty_config/aiet-config.json.license | 3
-rw-r--r--  tests/aiet/test_resources/various/applications/application_with_valid_config/aiet-config.json | 35
-rw-r--r--  tests/aiet/test_resources/various/applications/application_with_valid_config/aiet-config.json.license | 3
-rw-r--r--  tests/aiet/test_resources/various/applications/application_with_wrong_config1/aiet-config.json | 2
-rw-r--r--  tests/aiet/test_resources/various/applications/application_with_wrong_config1/aiet-config.json.license | 3
-rw-r--r--  tests/aiet/test_resources/various/applications/application_with_wrong_config2/aiet-config.json | 30
-rw-r--r--  tests/aiet/test_resources/various/applications/application_with_wrong_config2/aiet-config.json.license | 3
-rw-r--r--  tests/aiet/test_resources/various/applications/application_with_wrong_config3/aiet-config.json | 35
-rw-r--r--  tests/aiet/test_resources/various/applications/application_with_wrong_config3/aiet-config.json.license | 3
-rw-r--r--  tests/aiet/test_resources/various/systems/system_with_empty_config/aiet-config.json | 1
-rw-r--r--  tests/aiet/test_resources/various/systems/system_with_empty_config/aiet-config.json.license | 3
-rw-r--r--  tests/aiet/test_resources/various/systems/system_with_valid_config/aiet-config.json | 16
-rw-r--r--  tests/aiet/test_resources/various/systems/system_with_valid_config/aiet-config.json.license | 3
-rw-r--r--  tests/aiet/test_run_vela_script.py | 152
-rw-r--r--  tests/aiet/test_utils_fs.py | 168
-rw-r--r--  tests/aiet/test_utils_helpers.py | 27
-rw-r--r--  tests/aiet/test_utils_proc.py | 272
-rw-r--r--  tests/conftest.py | 95
-rw-r--r--  tests/mlia/__init__.py | 3
-rw-r--r--  tests/mlia/conftest.py | 20
-rw-r--r--  tests/mlia/test_api.py | 96
-rw-r--r--  tests/mlia/test_cli_commands.py | 204
-rw-r--r--  tests/mlia/test_cli_config.py | 49
-rw-r--r--  tests/mlia/test_cli_helpers.py | 165
-rw-r--r--  tests/mlia/test_cli_logging.py | 104
-rw-r--r--  tests/mlia/test_cli_main.py | 357
-rw-r--r--  tests/mlia/test_cli_options.py | 186
-rw-r--r--  tests/mlia/test_core_advice_generation.py | 71
-rw-r--r--  tests/mlia/test_core_advisor.py | 40
-rw-r--r--  tests/mlia/test_core_context.py | 62
-rw-r--r--  tests/mlia/test_core_data_analysis.py | 31
-rw-r--r--  tests/mlia/test_core_events.py | 155
-rw-r--r--  tests/mlia/test_core_helpers.py | 17
-rw-r--r--  tests/mlia/test_core_mixins.py | 99
-rw-r--r--  tests/mlia/test_core_performance.py | 29
-rw-r--r--  tests/mlia/test_core_reporting.py | 413
-rw-r--r--  tests/mlia/test_core_workflow.py | 164
-rw-r--r--  tests/mlia/test_devices_ethosu_advice_generation.py | 483
-rw-r--r--  tests/mlia/test_devices_ethosu_advisor.py | 9
-rw-r--r--  tests/mlia/test_devices_ethosu_config.py | 124
-rw-r--r--  tests/mlia/test_devices_ethosu_data_analysis.py | 147
-rw-r--r--  tests/mlia/test_devices_ethosu_data_collection.py | 151
-rw-r--r--  tests/mlia/test_devices_ethosu_performance.py | 28
-rw-r--r--  tests/mlia/test_devices_ethosu_reporters.py | 434
-rw-r--r--  tests/mlia/test_nn_tensorflow_config.py | 72
-rw-r--r--  tests/mlia/test_nn_tensorflow_optimizations_clustering.py | 131
-rw-r--r--  tests/mlia/test_nn_tensorflow_optimizations_pruning.py | 117
-rw-r--r--  tests/mlia/test_nn_tensorflow_optimizations_select.py | 240
-rw-r--r--  tests/mlia/test_nn_tensorflow_tflite_metrics.py | 137
-rw-r--r--  tests/mlia/test_nn_tensorflow_utils.py | 81
-rw-r--r--  tests/mlia/test_resources/vela/sample_vela.ini | 47
-rw-r--r--  tests/mlia/test_tools_aiet_wrapper.py | 760
-rw-r--r--  tests/mlia/test_tools_metadata_common.py | 196
-rw-r--r--  tests/mlia/test_tools_metadata_corstone.py | 419
-rw-r--r--  tests/mlia/test_tools_vela_wrapper.py | 285
-rw-r--r--  tests/mlia/test_utils_console.py | 100
-rw-r--r--  tests/mlia/test_utils_download.py | 147
-rw-r--r--  tests/mlia/test_utils_filesystem.py | 166
-rw-r--r--  tests/mlia/test_utils_logging.py | 63
-rw-r--r--  tests/mlia/test_utils_misc.py | 25
-rw-r--r--  tests/mlia/test_utils_proc.py | 149
-rw-r--r--  tests/mlia/test_utils_types.py | 77
-rw-r--r--  tests/mlia/utils/__init__.py | 3
-rw-r--r--  tests/mlia/utils/common.py | 32
-rw-r--r--  tests/mlia/utils/logging.py | 13
114 files changed, 13295 insertions, 0 deletions
diff --git a/tests/__init__.py b/tests/__init__.py
new file mode 100644
index 0000000..4a1e153
--- /dev/null
+++ b/tests/__init__.py
@@ -0,0 +1,3 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Tests module."""
diff --git a/tests/aiet/__init__.py b/tests/aiet/__init__.py
new file mode 100644
index 0000000..873a7df
--- /dev/null
+++ b/tests/aiet/__init__.py
@@ -0,0 +1,3 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""AIET tests module."""
diff --git a/tests/aiet/conftest.py b/tests/aiet/conftest.py
new file mode 100644
index 0000000..cab3dc2
--- /dev/null
+++ b/tests/aiet/conftest.py
@@ -0,0 +1,139 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+# pylint: disable=redefined-outer-name
+"""conftest for pytest."""
+import shutil
+import tarfile
+from pathlib import Path
+from typing import Any
+
+import pytest
+from click.testing import CliRunner
+
+from aiet.backend.common import get_backend_configs
+
+
+@pytest.fixture(scope="session")
+def test_systems_path(test_resources_path: Path) -> Path:
+ """Return test systems path in a pytest fixture."""
+ return test_resources_path / "systems"
+
+
+@pytest.fixture(scope="session")
+def test_applications_path(test_resources_path: Path) -> Path:
+ """Return test applications path in a pytest fixture."""
+ return test_resources_path / "applications"
+
+
+@pytest.fixture(scope="session")
+def test_tools_path(test_resources_path: Path) -> Path:
+ """Return test tools path in a pytest fixture."""
+ return test_resources_path / "tools"
+
+
+@pytest.fixture(scope="session")
+def test_resources_path() -> Path:
+ """Return test resources path in a pytest fixture."""
+ current_path = Path(__file__).parent.absolute()
+ return current_path / "test_resources"
+
+
+@pytest.fixture(scope="session")
+def non_optimised_input_model_file(test_tflite_model: Path) -> Path:
+ """Provide the path to a quantized dummy model file."""
+ return test_tflite_model
+
+
+@pytest.fixture(scope="session")
+def optimised_input_model_file(test_tflite_vela_model: Path) -> Path:
+ """Provide path to Vela-optimised dummy model file."""
+ return test_tflite_vela_model
+
+
+@pytest.fixture(scope="session")
+def invalid_input_model_file(test_tflite_invalid_model: Path) -> Path:
+ """Provide the path to an invalid dummy model file."""
+ return test_tflite_invalid_model
+
+
+@pytest.fixture(autouse=True)
+def test_resources(monkeypatch: pytest.MonkeyPatch, test_resources_path: Path) -> Any:
+ """Force using test resources as middleware's repository."""
+
+ def get_test_resources() -> Path:
+ """Return path to the test resources."""
+ return test_resources_path
+
+ monkeypatch.setattr("aiet.utils.fs.get_aiet_resources", get_test_resources)
+ yield
+
+
+@pytest.fixture(scope="session", autouse=True)
+def add_tools(test_resources_path: Path) -> Any:
+ """Symlink the tools from the original resources path to the test resources path."""
+ # tool_dirs = get_available_tool_directory_names()
+ tool_dirs = [cfg.parent for cfg in get_backend_configs("tools")]
+
+ links = {
+ src_dir: (test_resources_path / "tools" / src_dir.name) for src_dir in tool_dirs
+ }
+ for src_dir, dst_dir in links.items():
+ if not dst_dir.exists():
+ dst_dir.symlink_to(src_dir, target_is_directory=True)
+ yield
+ # Remove symlinks
+ for dst_dir in links.values():
+ if dst_dir.is_symlink():
+ dst_dir.unlink()
+
+
+def create_archive(
+ archive_name: str, source: Path, destination: Path, with_root_folder: bool = False
+) -> None:
+ """Create archive from directory source."""
+ with tarfile.open(destination / archive_name, mode="w:gz") as tar:
+ for item in source.iterdir():
+ item_name = item.name
+ if with_root_folder:
+ item_name = f"{source.name}/{item_name}"
+ tar.add(item, item_name)
+
+
+def process_directory(source: Path, destination: Path) -> None:
+ """Process resource directory."""
+ destination.mkdir()
+
+ for item in source.iterdir():
+ if item.is_dir():
+ create_archive(f"{item.name}.tar.gz", item, destination)
+ create_archive(f"{item.name}_dir.tar.gz", item, destination, True)
+
+
+@pytest.fixture(scope="session", autouse=True)
+def add_archives(
+ test_resources_path: Path, tmp_path_factory: pytest.TempPathFactory
+) -> Any:
+ """Generate archives of the test resources."""
+ tmp_path = tmp_path_factory.mktemp("archives")
+
+ archives_path = tmp_path / "archives"
+ archives_path.mkdir()
+
+ if (archives_path_link := test_resources_path / "archives").is_symlink():
+ archives_path_link.unlink()
+
+ archives_path_link.symlink_to(archives_path, target_is_directory=True)
+
+ for item in ["applications", "systems"]:
+ process_directory(test_resources_path / item, archives_path / item)
+
+ yield
+
+ archives_path_link.unlink()
+ shutil.rmtree(tmp_path)
+
+
+@pytest.fixture(scope="module")
+def cli_runner() -> CliRunner:
+ """Return CliRunner instance in a pytest fixture."""
+ return CliRunner()
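
The create_archive and process_directory helpers above pack every resource directory twice: once with the members stored at the archive root and once nested under a root folder. A minimal standalone sketch of the two resulting layouts, using a hypothetical sample_dir created in a temporary directory (the paths below are illustrative and not part of this patch):

# Illustrative sketch only; "sample_dir" and the temporary paths are hypothetical.
import tarfile
import tempfile
from pathlib import Path

with tempfile.TemporaryDirectory() as tmp:
    sample_dir = Path(tmp) / "sample_dir"
    sample_dir.mkdir()
    (sample_dir / "file.txt").write_text("hello")

    # Flat archive: members are stored as "file.txt"
    # (like create_archive(..., with_root_folder=False) above).
    with tarfile.open(Path(tmp) / "sample_dir.tar.gz", mode="w:gz") as tar:
        for item in sample_dir.iterdir():
            tar.add(item, item.name)

    # Rooted archive: members are stored as "sample_dir/file.txt"
    # (like create_archive(..., with_root_folder=True) above).
    with tarfile.open(Path(tmp) / "sample_dir_dir.tar.gz", mode="w:gz") as tar:
        for item in sample_dir.iterdir():
            tar.add(item, f"{sample_dir.name}/{item.name}")

    with tarfile.open(Path(tmp) / "sample_dir_dir.tar.gz") as tar:
        print(tar.getnames())  # ['sample_dir/file.txt']
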
diff --git a/tests/aiet/test_backend_application.py b/tests/aiet/test_backend_application.py
new file mode 100644
index 0000000..abfab00
--- /dev/null
+++ b/tests/aiet/test_backend_application.py
@@ -0,0 +1,452 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+# pylint: disable=no-self-use
+"""Tests for the application backend."""
+from collections import Counter
+from contextlib import ExitStack as does_not_raise
+from pathlib import Path
+from typing import Any
+from typing import List
+from unittest.mock import MagicMock
+
+import pytest
+
+from aiet.backend.application import Application
+from aiet.backend.application import get_application
+from aiet.backend.application import get_available_application_directory_names
+from aiet.backend.application import get_available_applications
+from aiet.backend.application import get_unique_application_names
+from aiet.backend.application import install_application
+from aiet.backend.application import load_applications
+from aiet.backend.application import remove_application
+from aiet.backend.common import Command
+from aiet.backend.common import DataPaths
+from aiet.backend.common import Param
+from aiet.backend.common import UserParamConfig
+from aiet.backend.config import ApplicationConfig
+from aiet.backend.config import ExtendedApplicationConfig
+from aiet.backend.config import NamedExecutionConfig
+
+
+def test_get_available_application_directory_names() -> None:
+ """Test get_available_applicationss mocking get_resources."""
+ directory_names = get_available_application_directory_names()
+ assert Counter(directory_names) == Counter(
+ ["application1", "application2", "application4", "application5"]
+ )
+
+
+def test_get_available_applications() -> None:
+ """Test get_available_applicationss mocking get_resources."""
+ available_applications = get_available_applications()
+
+ assert all(isinstance(s, Application) for s in available_applications)
+ assert all(s != 42 for s in available_applications)
+ assert len(available_applications) == 9
+ # application_5 has multiple items with multiple supported systems
+ assert [str(s) for s in available_applications] == [
+ "application_1",
+ "application_2",
+ "application_4",
+ "application_5",
+ "application_5",
+ "application_5A",
+ "application_5A",
+ "application_5B",
+ "application_5B",
+ ]
+
+
+def test_get_unique_application_names() -> None:
+ """Test get_unique_application_names."""
+ unique_names = get_unique_application_names()
+
+ assert all(isinstance(s, str) for s in unique_names)
+ assert all(s for s in unique_names)
+ assert sorted(unique_names) == [
+ "application_1",
+ "application_2",
+ "application_4",
+ "application_5",
+ "application_5A",
+ "application_5B",
+ ]
+
+
+def test_get_application() -> None:
+ """Test get_application mocking get_resoures."""
+ application = get_application("application_1")
+ if len(application) != 1:
+ pytest.fail("Unable to get application")
+ assert application[0].name == "application_1"
+
+ application = get_application("unknown application")
+ assert len(application) == 0
+
+
+@pytest.mark.parametrize(
+ "source, call_count, expected_exception",
+ (
+ (
+ "archives/applications/application1.tar.gz",
+ 0,
+ pytest.raises(
+ Exception, match=r"Applications \[application_1\] are already installed"
+ ),
+ ),
+ (
+ "various/applications/application_with_empty_config",
+ 0,
+ pytest.raises(Exception, match="No application definition found"),
+ ),
+ (
+ "various/applications/application_with_wrong_config1",
+ 0,
+ pytest.raises(Exception, match="Unable to read application definition"),
+ ),
+ (
+ "various/applications/application_with_wrong_config2",
+ 0,
+ pytest.raises(Exception, match="Unable to read application definition"),
+ ),
+ (
+ "various/applications/application_with_wrong_config3",
+ 0,
+ pytest.raises(Exception, match="Unable to read application definition"),
+ ),
+ ("various/applications/application_with_valid_config", 1, does_not_raise()),
+ (
+ "archives/applications/application3.tar.gz",
+ 0,
+ pytest.raises(Exception, match="Unable to read application definition"),
+ ),
+ (
+ "applications/application1",
+ 0,
+ pytest.raises(
+ Exception, match=r"Applications \[application_1\] are already installed"
+ ),
+ ),
+ (
+ "applications/application3",
+ 0,
+ pytest.raises(Exception, match="Unable to read application definition"),
+ ),
+ ),
+)
+def test_install_application(
+ monkeypatch: Any,
+ test_resources_path: Path,
+ source: str,
+ call_count: int,
+ expected_exception: Any,
+) -> None:
+ """Test application install from archive."""
+ mock_create_destination_and_install = MagicMock()
+ monkeypatch.setattr(
+ "aiet.backend.application.create_destination_and_install",
+ mock_create_destination_and_install,
+ )
+
+ with expected_exception:
+ install_application(test_resources_path / source)
+ assert mock_create_destination_and_install.call_count == call_count
+
+
+def test_remove_application(monkeypatch: Any) -> None:
+ """Test application removal."""
+ mock_remove_backend = MagicMock()
+ monkeypatch.setattr("aiet.backend.application.remove_backend", mock_remove_backend)
+
+ remove_application("some_application_directory")
+ mock_remove_backend.assert_called_once()
+
+
+def test_application_config_without_commands() -> None:
+ """Test application config without commands."""
+ config = ApplicationConfig(name="application")
+ application = Application(config)
+ # pylint: disable=use-implicit-booleaness-not-comparison
+ assert application.commands == {}
+
+
+class TestApplication:
+ """Test for application class methods."""
+
+ def test___eq__(self) -> None:
+ """Test overloaded __eq__ method."""
+ config = ApplicationConfig(
+ # Application
+ supported_systems=["system1", "system2"],
+ build_dir="build_dir",
+ # inherited from Backend
+ name="name",
+ description="description",
+ commands={},
+ )
+ application1 = Application(config)
+ application2 = Application(config) # Identical
+ assert application1 == application2
+
+ application3 = Application(config) # changed
+ # Change one single attribute so not equal, but same Type
+ setattr(application3, "supported_systems", ["somewhere/else"])
+ assert application1 != application3
+
+ # different Type
+ application4 = "Not the Application you are looking for"
+ assert application1 != application4
+
+ application5 = Application(config)
+ # supported systems could be in any order
+ setattr(application5, "supported_systems", ["system2", "system1"])
+ assert application1 == application5
+
+ def test_can_run_on(self) -> None:
+ """Test Application can run on."""
+ config = ApplicationConfig(name="application", supported_systems=["System-A"])
+
+ application = Application(config)
+ assert application.can_run_on("System-A")
+ assert not application.can_run_on("System-B")
+
+ applications = get_application("application_1", "System 1")
+ assert len(applications) == 1
+ assert applications[0].can_run_on("System 1")
+
+ def test_get_deploy_data(self, tmp_path: Path) -> None:
+ """Test Application can run on."""
+ src, dest = "src", "dest"
+ config = ApplicationConfig(
+ name="application", deploy_data=[(src, dest)], config_location=tmp_path
+ )
+ src_path = tmp_path / src
+ src_path.mkdir()
+ application = Application(config)
+ assert application.get_deploy_data() == [DataPaths(src_path, dest)]
+
+ def test_get_deploy_data_no_config_location(self) -> None:
+ """Test that getting deploy data fails if no config location provided."""
+ with pytest.raises(
+ Exception, match="Unable to get application .* config location"
+ ):
+ Application(ApplicationConfig(name="application")).get_deploy_data()
+
+ def test_unable_to_create_application_without_name(self) -> None:
+ """Test that it is not possible to create application without name."""
+ with pytest.raises(Exception, match="Name is empty"):
+ Application(ApplicationConfig())
+
+ def test_application_config_without_commands(self) -> None:
+ """Test application config without commands."""
+ config = ApplicationConfig(name="application")
+ application = Application(config)
+ # pylint: disable=use-implicit-booleaness-not-comparison
+ assert application.commands == {}
+
+ @pytest.mark.parametrize(
+ "config, expected_params",
+ (
+ (
+ ApplicationConfig(
+ name="application",
+ commands={"command": ["cmd {user_params:0} {user_params:1}"]},
+ user_params={
+ "command": [
+ UserParamConfig(
+ name="--param1", description="param1", alias="param1"
+ ),
+ UserParamConfig(
+ name="--param2", description="param2", alias="param2"
+ ),
+ ]
+ },
+ ),
+ [Param("--param1", "param1"), Param("--param2", "param2")],
+ ),
+ (
+ ApplicationConfig(
+ name="application",
+ commands={"command": ["cmd {user_params:param1} {user_params:1}"]},
+ user_params={
+ "command": [
+ UserParamConfig(
+ name="--param1", description="param1", alias="param1"
+ ),
+ UserParamConfig(
+ name="--param2", description="param2", alias="param2"
+ ),
+ ]
+ },
+ ),
+ [Param("--param1", "param1"), Param("--param2", "param2")],
+ ),
+ (
+ ApplicationConfig(
+ name="application",
+ commands={"command": ["cmd {user_params:param1}"]},
+ user_params={
+ "command": [
+ UserParamConfig(
+ name="--param1", description="param1", alias="param1"
+ ),
+ UserParamConfig(
+ name="--param2", description="param2", alias="param2"
+ ),
+ ]
+ },
+ ),
+ [Param("--param1", "param1")],
+ ),
+ ),
+ )
+ def test_remove_unused_params(
+ self, config: ApplicationConfig, expected_params: List[Param]
+ ) -> None:
+ """Test mod remove_unused_parameter."""
+ application = Application(config)
+ application.remove_unused_params()
+ assert application.commands["command"].params == expected_params
+
+
+@pytest.mark.parametrize(
+ "config, expected_error",
+ (
+ (
+ ExtendedApplicationConfig(name="application"),
+ pytest.raises(Exception, match="No supported systems definition provided"),
+ ),
+ (
+ ExtendedApplicationConfig(
+ name="application", supported_systems=[NamedExecutionConfig(name="")]
+ ),
+ pytest.raises(
+ Exception,
+ match="Unable to read supported system definition, name is missed",
+ ),
+ ),
+ (
+ ExtendedApplicationConfig(
+ name="application",
+ supported_systems=[
+ NamedExecutionConfig(
+ name="system",
+ commands={"command": ["cmd"]},
+ user_params={"command": [UserParamConfig(name="param")]},
+ )
+ ],
+ commands={"command": ["cmd {user_params:0}"]},
+ user_params={"command": [UserParamConfig(name="param")]},
+ ),
+ pytest.raises(
+ Exception, match="Default parameters for command .* should have aliases"
+ ),
+ ),
+ (
+ ExtendedApplicationConfig(
+ name="application",
+ supported_systems=[
+ NamedExecutionConfig(
+ name="system",
+ commands={"command": ["cmd"]},
+ user_params={"command": [UserParamConfig(name="param")]},
+ )
+ ],
+ commands={"command": ["cmd {user_params:0}"]},
+ user_params={"command": [UserParamConfig(name="param", alias="param")]},
+ ),
+ pytest.raises(
+ Exception, match="system parameters for command .* should have aliases"
+ ),
+ ),
+ ),
+)
+def test_load_application_exceptional_cases(
+ config: ExtendedApplicationConfig, expected_error: Any
+) -> None:
+ """Test exceptional cases for application load function."""
+ with expected_error:
+ load_applications(config)
+
+
+def test_load_application() -> None:
+ """Test application load function.
+
+ The main purpose of this test is to test configuration for application
+ for different systems. All configuration should be correctly
+ overridden if needed.
+ """
+ application_5 = get_application("application_5")
+ assert len(application_5) == 2
+
+ default_commands = {
+ "build": Command(["default build command"]),
+ "run": Command(["default run command"]),
+ }
+ default_variables = {"var1": "value1", "var2": "value2"}
+
+ application_5_0 = application_5[0]
+ assert application_5_0.build_dir == "default_build_dir"
+ assert application_5_0.supported_systems == ["System 1"]
+ assert application_5_0.commands == default_commands
+ assert application_5_0.variables == default_variables
+ assert application_5_0.lock is False
+
+ application_5_1 = application_5[1]
+ assert application_5_1.build_dir == application_5_0.build_dir
+ assert application_5_1.supported_systems == ["System 2"]
+ assert application_5_1.commands == application_5_0.commands
+ assert application_5_1.variables == default_variables
+
+ application_5a = get_application("application_5A")
+ assert len(application_5a) == 2
+
+ application_5a_0 = application_5a[0]
+ assert application_5a_0.supported_systems == ["System 1"]
+ assert application_5a_0.build_dir == "build_5A"
+ assert application_5a_0.commands == default_commands
+ assert application_5a_0.variables == {"var1": "new value1", "var2": "value2"}
+ assert application_5a_0.lock is False
+
+ application_5a_1 = application_5a[1]
+ assert application_5a_1.supported_systems == ["System 2"]
+ assert application_5a_1.build_dir == "build"
+ assert application_5a_1.commands == {
+ "build": Command(["default build command"]),
+ "run": Command(["run command on system 2"]),
+ }
+ assert application_5a_1.variables == {"var1": "value1", "var2": "new value2"}
+ assert application_5a_1.lock is True
+
+ application_5b = get_application("application_5B")
+ assert len(application_5b) == 2
+
+ application_5b_0 = application_5b[0]
+ assert application_5b_0.build_dir == "build_5B"
+ assert application_5b_0.supported_systems == ["System 1"]
+ assert application_5b_0.commands == {
+ "build": Command(["default build command with value for var1 System1"], []),
+ "run": Command(["default run command with value for var2 System1"]),
+ }
+ assert "non_used_command" not in application_5b_0.commands
+
+ application_5b_1 = application_5b[1]
+ assert application_5b_1.build_dir == "build"
+ assert application_5b_1.supported_systems == ["System 2"]
+ assert application_5b_1.commands == {
+ "build": Command(
+ [
+ "build command on system 2 with value"
+ " for var1 System2 {user_params:param1}"
+ ],
+ [
+ Param(
+ "--param",
+ "Sample command param",
+ ["value1", "value2", "value3"],
+ "value1",
+ )
+ ],
+ ),
+ "run": Command(["run command on system 2"], []),
+ }
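
The parametrized tests above pass either does_not_raise() (an alias for contextlib.ExitStack) or a pytest.raises(...) context as the expected-exception column, so the success and failure rows share a single test body. A minimal sketch of that idiom, using a hypothetical positive_sqrt helper that is not part of this patch:

# Minimal sketch of the "ExitStack as does_not_raise" idiom; positive_sqrt is hypothetical.
import math
from contextlib import ExitStack as does_not_raise
from typing import Any

import pytest


def positive_sqrt(value: float) -> float:
    """Return the square root of a non-negative number."""
    if value < 0:
        raise ValueError("negative input")
    return math.sqrt(value)


@pytest.mark.parametrize(
    "value, expectation",
    (
        (4.0, does_not_raise()),
        (-1.0, pytest.raises(ValueError, match="negative input")),
    ),
)
def test_positive_sqrt(value: float, expectation: Any) -> None:
    """Success and failure rows share one test body."""
    with expectation:
        assert positive_sqrt(value) >= 0
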
diff --git a/tests/aiet/test_backend_common.py b/tests/aiet/test_backend_common.py
new file mode 100644
index 0000000..12c30ec
--- /dev/null
+++ b/tests/aiet/test_backend_common.py
@@ -0,0 +1,486 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+# pylint: disable=no-self-use,protected-access
+"""Tests for the common backend module."""
+from contextlib import ExitStack as does_not_raise
+from pathlib import Path
+from typing import Any
+from typing import cast
+from typing import Dict
+from typing import IO
+from typing import List
+from typing import Optional
+from typing import Tuple
+from typing import Union
+from unittest.mock import MagicMock
+
+import pytest
+
+from aiet.backend.application import Application
+from aiet.backend.common import Backend
+from aiet.backend.common import BaseBackendConfig
+from aiet.backend.common import Command
+from aiet.backend.common import ConfigurationException
+from aiet.backend.common import load_config
+from aiet.backend.common import Param
+from aiet.backend.common import parse_raw_parameter
+from aiet.backend.common import remove_backend
+from aiet.backend.config import ApplicationConfig
+from aiet.backend.config import UserParamConfig
+from aiet.backend.execution import ExecutionContext
+from aiet.backend.execution import ParamResolver
+from aiet.backend.system import System
+
+
+@pytest.mark.parametrize(
+ "directory_name, expected_exception",
+ (
+ ("some_dir", does_not_raise()),
+ (None, pytest.raises(Exception, match="No directory name provided")),
+ ),
+)
+def test_remove_backend(
+ monkeypatch: Any, directory_name: str, expected_exception: Any
+) -> None:
+ """Test remove_backend function."""
+ mock_remove_resource = MagicMock()
+ monkeypatch.setattr("aiet.backend.common.remove_resource", mock_remove_resource)
+
+ with expected_exception:
+ remove_backend(directory_name, "applications")
+
+
+@pytest.mark.parametrize(
+ "filename, expected_exception",
+ (
+ ("application_config.json", does_not_raise()),
+ (None, pytest.raises(Exception, match="Unable to read config")),
+ ),
+)
+def test_load_config(
+ filename: str, expected_exception: Any, test_resources_path: Path, monkeypatch: Any
+) -> None:
+ """Test load_config."""
+ with expected_exception:
+ configs: List[Optional[Union[Path, IO[bytes]]]] = (
+ [None]
+ if not filename
+ else [
+ # Ignore pylint warning as 'with' can't be used inside of a
+ # generator expression.
+ # pylint: disable=consider-using-with
+ open(test_resources_path / filename, "rb"),
+ test_resources_path / filename,
+ ]
+ )
+ for config in configs:
+ json_mock = MagicMock()
+ monkeypatch.setattr("aiet.backend.common.json.load", json_mock)
+ load_config(config)
+ json_mock.assert_called_once()
+
+
+class TestBackend:
+ """Test Backend class."""
+
+ def test___repr__(self) -> None:
+ """Test the representation of Backend instance."""
+ backend = Backend(
+ BaseBackendConfig(name="Testing name", description="Testing description")
+ )
+ assert str(backend) == "Testing name"
+
+ def test__eq__(self) -> None:
+ """Test equality method with different cases."""
+ backend1 = Backend(BaseBackendConfig(name="name", description="description"))
+ backend1.commands = {"command": Command(["command"])}
+
+ backend2 = Backend(BaseBackendConfig(name="name", description="description"))
+ backend2.commands = {"command": Command(["command"])}
+
+ backend3 = Backend(
+ BaseBackendConfig(
+ name="Ben", description="This is not the Backend you are looking for"
+ )
+ )
+ backend3.commands = {"wave": Command(["wave hand"])}
+
+ backend4 = "Foo" # checking not isinstance(backend4, Backend)
+
+ assert backend1 == backend2
+ assert backend1 != backend3
+ assert backend1 != backend4
+
+ @pytest.mark.parametrize(
+ "parameter, valid",
+ [
+ ("--choice-param dummy_value_1", True),
+ ("--choice-param wrong_value", False),
+ ("--open-param something", True),
+ ("--wrong-param value", False),
+ ],
+ )
+ def test_validate_parameter(
+ self, parameter: str, valid: bool, test_resources_path: Path
+ ) -> None:
+ """Test validate_parameter."""
+ config = cast(
+ List[ApplicationConfig],
+ load_config(test_resources_path / "hello_world.json"),
+ )
+ # The application configuration is a list of configurations so we need
+ # only the first one
+ # Exercise the validate_parameter test using the Application class, which
+ # inherits from Backend.
+ application = Application(config[0])
+ assert application.validate_parameter("run", parameter) == valid
+
+ def test_validate_parameter_with_invalid_command(
+ self, test_resources_path: Path
+ ) -> None:
+ """Test validate_parameter with an invalid command_name."""
+ config = cast(
+ List[ApplicationConfig],
+ load_config(test_resources_path / "hello_world.json"),
+ )
+ application = Application(config[0])
+ with pytest.raises(AttributeError) as err:
+ # command foo does not exist, so raise an error
+ application.validate_parameter("foo", "bar")
+ assert "Unknown command: 'foo'" in str(err.value)
+
+ def test_build_command(self, monkeypatch: Any) -> None:
+ """Test command building."""
+ config = {
+ "name": "test",
+ "commands": {
+ "build": ["build {user_params:0} {user_params:1}"],
+ "run": ["run {user_params:0}"],
+ "post_run": ["post_run {application_params:0} on {system_params:0}"],
+ "some_command": ["Command with {variables:var_A}"],
+ "empty_command": [""],
+ },
+ "user_params": {
+ "build": [
+ {
+ "name": "choice_param_0=",
+ "values": [1, 2, 3],
+ "default_value": 1,
+ },
+ {"name": "choice_param_1", "values": [3, 4, 5], "default_value": 3},
+ {"name": "choice_param_3", "values": [6, 7, 8]},
+ ],
+ "run": [{"name": "flag_param_0"}],
+ },
+ "variables": {"var_A": "value for variable A"},
+ }
+
+ monkeypatch.setattr("aiet.backend.system.ProtocolFactory", MagicMock())
+ application, system = Application(config), System(config) # type: ignore
+ context = ExecutionContext(
+ app=application,
+ app_params=[],
+ system=system,
+ system_params=[],
+ custom_deploy_data=[],
+ )
+
+ param_resolver = ParamResolver(context)
+
+ cmd = application.build_command(
+ "build", ["choice_param_0=2", "choice_param_1=4"], param_resolver
+ )
+ assert cmd == ["build choice_param_0=2 choice_param_1 4"]
+
+ cmd = application.build_command("build", ["choice_param_0=2"], param_resolver)
+ assert cmd == ["build choice_param_0=2 choice_param_1 3"]
+
+ cmd = application.build_command(
+ "build", ["choice_param_0=2", "choice_param_3=7"], param_resolver
+ )
+ assert cmd == ["build choice_param_0=2 choice_param_1 3"]
+
+ with pytest.raises(
+ ConfigurationException, match="Command 'foo' could not be found."
+ ):
+ application.build_command("foo", [""], param_resolver)
+
+ cmd = application.build_command("some_command", [], param_resolver)
+ assert cmd == ["Command with value for variable A"]
+
+ cmd = application.build_command("empty_command", [], param_resolver)
+ assert cmd == [""]
+
+ @pytest.mark.parametrize("class_", [Application, System])
+ def test_build_command_unknown_variable(self, class_: type) -> None:
+ """Test that unable to construct backend with unknown variable."""
+ with pytest.raises(Exception, match="Unknown variable var1"):
+ config = {"name": "test", "commands": {"run": ["run {variables:var1}"]}}
+ class_(config)
+
+ @pytest.mark.parametrize(
+ "class_, config, expected_output",
+ [
+ (
+ Application,
+ {
+ "name": "test",
+ "commands": {
+ "build": ["build {user_params:0} {user_params:1}"],
+ "run": ["run {user_params:0}"],
+ },
+ "user_params": {
+ "build": [
+ {
+ "name": "choice_param_0=",
+ "values": ["a", "b", "c"],
+ "default_value": "a",
+ "alias": "param_1",
+ },
+ {
+ "name": "choice_param_1",
+ "values": ["a", "b", "c"],
+ "default_value": "a",
+ "alias": "param_2",
+ },
+ {"name": "choice_param_3", "values": ["a", "b", "c"]},
+ ],
+ "run": [{"name": "flag_param_0"}],
+ },
+ },
+ [
+ (
+ "b",
+ Param(
+ name="choice_param_0=",
+ description="",
+ values=["a", "b", "c"],
+ default_value="a",
+ alias="param_1",
+ ),
+ ),
+ (
+ "a",
+ Param(
+ name="choice_param_1",
+ description="",
+ values=["a", "b", "c"],
+ default_value="a",
+ alias="param_2",
+ ),
+ ),
+ (
+ "c",
+ Param(
+ name="choice_param_3",
+ description="",
+ values=["a", "b", "c"],
+ ),
+ ),
+ ],
+ ),
+ (System, {"name": "test"}, []),
+ ],
+ )
+ def test_resolved_parameters(
+ self,
+ monkeypatch: Any,
+ class_: type,
+ config: Dict,
+ expected_output: List[Tuple[Optional[str], Param]],
+ ) -> None:
+ """Test command building."""
+ monkeypatch.setattr("aiet.backend.system.ProtocolFactory", MagicMock())
+ backend = class_(config)
+
+ params = backend.resolved_parameters(
+ "build", ["choice_param_0=b", "choice_param_3=c"]
+ )
+ assert params == expected_output
+
+ @pytest.mark.parametrize(
+ ["param_name", "user_param", "expected_value"],
+ [
+ (
+ "test_name",
+ "test_name=1234",
+ "1234",
+ ), # optional parameter using '='
+ (
+ "test_name",
+ "test_name 1234",
+ "1234",
+ ), # optional parameter using ' '
+ ("test_name", "test_name", None), # flag
+ (None, "test_name=1234", "1234"), # positional parameter
+ ],
+ )
+ def test_resolved_user_parameters(
+ self, param_name: str, user_param: str, expected_value: str
+ ) -> None:
+ """Test different variants to provide user parameters."""
+ # A dummy config providing one backend config
+ config = {
+ "name": "test_backend",
+ "commands": {
+ "test": ["user_param:test_param"],
+ },
+ "user_params": {
+ "test": [UserParamConfig(name=param_name, alias="test_name")],
+ },
+ }
+ backend = Backend(cast(BaseBackendConfig, config))
+ params = backend.resolved_parameters(
+ command_name="test", user_params=[user_param]
+ )
+ assert len(params) == 1
+ value, param = params[0]
+ assert param_name == param.name
+ assert expected_value == value
+
+ @pytest.mark.parametrize(
+ "input_param,expected",
+ [
+ ("--param=1", ("--param", "1")),
+ ("--param 1", ("--param", "1")),
+ ("--flag", ("--flag", None)),
+ ],
+ )
+ def test__parse_raw_parameter(
+ self, input_param: str, expected: Tuple[str, Optional[str]]
+ ) -> None:
+ """Test internal method of parsing a single raw parameter."""
+ assert parse_raw_parameter(input_param) == expected
+
+
+class TestParam:
+ """Test Param class."""
+
+ def test__eq__(self) -> None:
+ """Test equality method with different cases."""
+ param1 = Param(name="test", description="desc", values=["values"])
+ param2 = Param(name="test", description="desc", values=["values"])
+ param3 = Param(name="test1", description="desc", values=["values"])
+ param4 = object()
+
+ assert param1 == param2
+ assert param1 != param3
+ assert param1 != param4
+
+ def test_get_details(self) -> None:
+ """Test get_details() method."""
+ param1 = Param(name="test", description="desc", values=["values"])
+ assert param1.get_details() == {
+ "name": "test",
+ "values": ["values"],
+ "description": "desc",
+ }
+
+ def test_invalid(self) -> None:
+ """Test invalid use cases for the Param class."""
+ with pytest.raises(
+ ConfigurationException,
+ match="Either name, alias or both must be set to identify a parameter.",
+ ):
+ Param(name=None, description="desc", values=["values"])
+
+
+class TestCommand:
+ """Test Command class."""
+
+ def test_get_details(self) -> None:
+ """Test get_details() method."""
+ param1 = Param(name="test", description="desc", values=["values"])
+ command1 = Command(command_strings=["echo test"], params=[param1])
+ assert command1.get_details() == {
+ "command_strings": ["echo test"],
+ "user_params": [
+ {"name": "test", "values": ["values"], "description": "desc"}
+ ],
+ }
+
+ def test__eq__(self) -> None:
+ """Test equality method with different cases."""
+ param1 = Param("test", "desc", ["values"])
+ param2 = Param("test1", "desc1", ["values1"])
+ command1 = Command(command_strings=["echo test"], params=[param1])
+ command2 = Command(command_strings=["echo test"], params=[param1])
+ command3 = Command(command_strings=["echo test"])
+ command4 = Command(command_strings=["echo test"], params=[param2])
+ command5 = object()
+
+ assert command1 == command2
+ assert command1 != command3
+ assert command1 != command4
+ assert command1 != command5
+
+ @pytest.mark.parametrize(
+ "params, expected_error",
+ [
+ [[], does_not_raise()],
+ [[Param("param", "param description", [])], does_not_raise()],
+ [
+ [
+ Param("param", "param description", [], None, "alias"),
+ Param("param", "param description", [], None),
+ ],
+ does_not_raise(),
+ ],
+ [
+ [
+ Param("param1", "param1 description", [], None, "alias1"),
+ Param("param2", "param2 description", [], None, "alias2"),
+ ],
+ does_not_raise(),
+ ],
+ [
+ [
+ Param("param", "param description", [], None, "alias"),
+ Param("param", "param description", [], None, "alias"),
+ ],
+ pytest.raises(ConfigurationException, match="Non unique aliases alias"),
+ ],
+ [
+ [
+ Param("alias", "param description", [], None, "alias1"),
+ Param("param", "param description", [], None, "alias"),
+ ],
+ pytest.raises(
+ ConfigurationException,
+ match="Aliases .* could not be used as parameter name",
+ ),
+ ],
+ [
+ [
+ Param("alias", "param description", [], None, "alias"),
+ Param("param1", "param1 description", [], None, "alias1"),
+ ],
+ does_not_raise(),
+ ],
+ [
+ [
+ Param("alias", "param description", [], None, "alias"),
+ Param("alias", "param1 description", [], None, "alias1"),
+ ],
+ pytest.raises(
+ ConfigurationException,
+ match="Aliases .* could not be used as parameter name",
+ ),
+ ],
+ [
+ [
+ Param("param1", "param1 description", [], None, "alias1"),
+ Param("param2", "param2 description", [], None, "alias1"),
+ Param("param3", "param3 description", [], None, "alias2"),
+ Param("param4", "param4 description", [], None, "alias2"),
+ ],
+ pytest.raises(
+ ConfigurationException, match="Non unique aliases alias1, alias2"
+ ),
+ ],
+ ],
+ )
+ def test_validate_params(self, params: List[Param], expected_error: Any) -> None:
+ """Test command validation function."""
+ with expected_error:
+ Command([], params)
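
test__parse_raw_parameter above pins down three input shapes: "--param=1", "--param 1", and a bare "--flag". One possible splitting routine consistent with those cases is sketched below; it is an illustration only, not the actual aiet.backend.common.parse_raw_parameter implementation:

# Illustrative splitter consistent with the parametrized cases above;
# not the aiet implementation.
from typing import Optional, Tuple


def split_raw_parameter(parameter: str) -> Tuple[str, Optional[str]]:
    """Split "--name=value", "--name value" or a bare flag into (name, value)."""
    data = parameter.replace("=", " ", 1).split(" ", 1)
    if len(data) == 1:
        return data[0], None           # "--flag"
    return data[0], data[1].strip()    # "--param=1" or "--param 1"


assert split_raw_parameter("--param=1") == ("--param", "1")
assert split_raw_parameter("--param 1") == ("--param", "1")
assert split_raw_parameter("--flag") == ("--flag", None)
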
diff --git a/tests/aiet/test_backend_controller.py b/tests/aiet/test_backend_controller.py
new file mode 100644
index 0000000..8836ec5
--- /dev/null
+++ b/tests/aiet/test_backend_controller.py
@@ -0,0 +1,160 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Tests for system controller."""
+import csv
+import os
+import time
+from pathlib import Path
+from typing import Any
+
+import psutil
+import pytest
+
+from aiet.backend.common import ConfigurationException
+from aiet.backend.controller import SystemController
+from aiet.backend.controller import SystemControllerSingleInstance
+from aiet.utils.proc import ShellCommand
+
+
+def get_system_controller(**kwargs: Any) -> SystemController:
+ """Get service controller."""
+ single_instance = kwargs.get("single_instance", False)
+ if single_instance:
+ pid_file_path = kwargs.get("pid_file_path")
+ return SystemControllerSingleInstance(pid_file_path)
+
+ return SystemController()
+
+
+def test_service_controller() -> None:
+ """Test service controller functionality."""
+ service_controller = get_system_controller()
+
+ assert service_controller.get_output() == ("", "")
+ with pytest.raises(ConfigurationException, match="Wrong working directory"):
+ service_controller.start(["sleep 100"], Path("unknown"))
+
+ service_controller.start(["sleep 100"], Path.cwd())
+ assert service_controller.is_running()
+
+ service_controller.stop(True)
+ assert not service_controller.is_running()
+ assert service_controller.get_output() == ("", "")
+
+ service_controller.stop()
+
+ with pytest.raises(
+ ConfigurationException, match="System should have only one command to run"
+ ):
+ service_controller.start(["sleep 100", "sleep 101"], Path.cwd())
+
+ with pytest.raises(ConfigurationException, match="No startup command provided"):
+ service_controller.start([""], Path.cwd())
+
+
+def test_service_controller_bad_configuration() -> None:
+ """Test service controller functionality for bad configuration."""
+ with pytest.raises(Exception, match="No pid file path presented"):
+ service_controller = get_system_controller(
+ single_instance=True, pid_file_path=None
+ )
+ service_controller.start(["sleep 100"], Path.cwd())
+
+
+def test_service_controller_writes_process_info_correctly(tmpdir: Any) -> None:
+ """Test that controller writes process info correctly."""
+ pid_file = Path(tmpdir) / "test.pid"
+
+ service_controller = get_system_controller(
+ single_instance=True, pid_file_path=Path(tmpdir) / "test.pid"
+ )
+
+ service_controller.start(["sleep 100"], Path.cwd())
+ assert service_controller.is_running()
+ assert pid_file.is_file()
+
+ with open(pid_file, "r", encoding="utf-8") as file:
+ csv_reader = csv.reader(file)
+ rows = list(csv_reader)
+ assert len(rows) == 1
+
+ name, *_ = rows[0]
+ assert name == "sleep"
+
+ service_controller.stop()
+ assert pid_file.exists()
+
+
+def test_service_controller_does_not_write_process_info_if_process_finishes(
+ tmpdir: Any,
+) -> None:
+ """Test that controller does not write process info if process already finished."""
+ pid_file = Path(tmpdir) / "test.pid"
+ service_controller = get_system_controller(
+ single_instance=True, pid_file_path=pid_file
+ )
+ service_controller.is_running = lambda: False # type: ignore
+ service_controller.start(["echo hello"], Path.cwd())
+
+ assert not pid_file.exists()
+
+
+def test_service_controller_searches_for_previous_instances_correctly(
+ tmpdir: Any,
+) -> None:
+ """Test that controller searches for previous instances correctly."""
+ pid_file = Path(tmpdir) / "test.pid"
+ command = ShellCommand().run("sleep", "100")
+ assert command.is_alive()
+
+ pid = command.process.pid
+ process = psutil.Process(pid)
+ with open(pid_file, "w", encoding="utf-8") as file:
+ csv_writer = csv.writer(file)
+ csv_writer.writerow(("some_process", "some_program", "some_cwd", os.getpid()))
+ csv_writer.writerow((process.name(), process.exe(), process.cwd(), process.pid))
+ csv_writer.writerow(("some_old_process", "not_running", "from_nowhere", 77777))
+
+ service_controller = get_system_controller(
+ single_instance=True, pid_file_path=pid_file
+ )
+ service_controller.start(["sleep 100"], Path.cwd())
+ # controller should stop this process as it is currently running and
+ # mentioned in pid file
+ assert not command.is_alive()
+
+ service_controller.stop()
+
+
+@pytest.mark.parametrize(
+ "executable", ["test_backend_run_script.sh", "test_backend_run"]
+)
+def test_service_controller_run_shell_script(
+ executable: str, test_resources_path: Path
+) -> None:
+ """Test controller's ability to run shell scripts."""
+ script_path = test_resources_path / "scripts"
+
+ service_controller = get_system_controller()
+
+ service_controller.start([executable], script_path)
+
+ assert service_controller.is_running()
+ # give time for the command to produce output
+ time.sleep(2)
+ service_controller.stop(wait=True)
+ assert not service_controller.is_running()
+ stdout, stderr = service_controller.get_output()
+ assert stdout == "Hello from script\n"
+ assert stderr == "Oops!\n"
+
+
+def test_service_controller_does_nothing_if_not_started(tmpdir: Any) -> None:
+ """Test that nothing happened if controller is not started."""
+ service_controller = get_system_controller(
+ single_instance=True, pid_file_path=Path(tmpdir) / "test.pid"
+ )
+
+ assert not service_controller.is_running()
+ service_controller.stop()
+ assert not service_controller.is_running()
diff --git a/tests/aiet/test_backend_execution.py b/tests/aiet/test_backend_execution.py
new file mode 100644
index 0000000..8aa45f1
--- /dev/null
+++ b/tests/aiet/test_backend_execution.py
@@ -0,0 +1,526 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+# pylint: disable=no-self-use
+"""Test backend context module."""
+from contextlib import ExitStack as does_not_raise
+from pathlib import Path
+from typing import Any
+from typing import Dict
+from typing import Optional
+from unittest import mock
+from unittest.mock import MagicMock
+
+import pytest
+from sh import CommandNotFound
+
+from aiet.backend.application import Application
+from aiet.backend.application import get_application
+from aiet.backend.common import ConfigurationException
+from aiet.backend.common import DataPaths
+from aiet.backend.common import UserParamConfig
+from aiet.backend.config import ApplicationConfig
+from aiet.backend.config import LocalProtocolConfig
+from aiet.backend.config import SystemConfig
+from aiet.backend.execution import deploy_data
+from aiet.backend.execution import execute_commands_locally
+from aiet.backend.execution import ExecutionContext
+from aiet.backend.execution import get_application_and_system
+from aiet.backend.execution import get_application_by_name_and_system
+from aiet.backend.execution import get_file_lock_path
+from aiet.backend.execution import get_tool_by_system
+from aiet.backend.execution import ParamResolver
+from aiet.backend.execution import Reporter
+from aiet.backend.execution import wait
+from aiet.backend.output_parser import OutputParser
+from aiet.backend.system import get_system
+from aiet.backend.system import load_system
+from aiet.backend.tool import get_tool
+from aiet.utils.proc import CommandFailedException
+
+
+def test_context_param_resolver(tmpdir: Any) -> None:
+ """Test parameter resolving."""
+ system_config_location = Path(tmpdir) / "system"
+ system_config_location.mkdir()
+
+ application_config_location = Path(tmpdir) / "application"
+ application_config_location.mkdir()
+
+ ctx = ExecutionContext(
+ app=Application(
+ ApplicationConfig(
+ name="test_application",
+ description="Test application",
+ config_location=application_config_location,
+ build_dir="build-{application.name}-{system.name}",
+ commands={
+ "run": [
+ "run_command1 {user_params:0}",
+ "run_command2 {user_params:1}",
+ ]
+ },
+ variables={"var_1": "value for var_1"},
+ user_params={
+ "run": [
+ UserParamConfig(
+ name="--param1",
+ description="Param 1",
+ default_value="123",
+ alias="param_1",
+ ),
+ UserParamConfig(
+ name="--param2", description="Param 2", default_value="456"
+ ),
+ UserParamConfig(
+ name="--param3", description="Param 3", alias="param_3"
+ ),
+ UserParamConfig(
+ name="--param4=",
+ description="Param 4",
+ default_value="456",
+ alias="param_4",
+ ),
+ UserParamConfig(
+ description="Param 5",
+ default_value="789",
+ alias="param_5",
+ ),
+ ]
+ },
+ )
+ ),
+ app_params=["--param2=789"],
+ system=load_system(
+ SystemConfig(
+ name="test_system",
+ description="Test system",
+ config_location=system_config_location,
+ build_dir="build",
+ data_transfer=LocalProtocolConfig(protocol="local"),
+ commands={
+ "build": ["build_command1 {user_params:0}"],
+ "run": ["run_command {application.commands.run:1}"],
+ },
+ variables={"var_1": "value for var_1"},
+ user_params={
+ "build": [
+ UserParamConfig(
+ name="--param1", description="Param 1", default_value="aaa"
+ ),
+ UserParamConfig(name="--param2", description="Param 2"),
+ ]
+ },
+ )
+ ),
+ system_params=["--param1=bbb"],
+ custom_deploy_data=[],
+ )
+
+ param_resolver = ParamResolver(ctx)
+ expected_values = {
+ "application.name": "test_application",
+ "application.description": "Test application",
+ "application.config_dir": str(application_config_location),
+ "application.build_dir": "{}/build-test_application-test_system".format(
+ application_config_location
+ ),
+ "application.commands.run:0": "run_command1 --param1 123",
+ "application.commands.run.params:0": "123",
+ "application.commands.run.params:param_1": "123",
+ "application.commands.run:1": "run_command2 --param2 789",
+ "application.commands.run.params:1": "789",
+ "application.variables:var_1": "value for var_1",
+ "system.name": "test_system",
+ "system.description": "Test system",
+ "system.config_dir": str(system_config_location),
+ "system.commands.build:0": "build_command1 --param1 bbb",
+ "system.commands.run:0": "run_command run_command2 --param2 789",
+ "system.commands.build.params:0": "bbb",
+ "system.variables:var_1": "value for var_1",
+ }
+
+ for param, value in expected_values.items():
+ assert param_resolver(param) == value
+
+ assert ctx.build_dir() == Path(
+ "{}/build-test_application-test_system".format(application_config_location)
+ )
+
+ expected_errors = {
+ "application.variables:var_2": pytest.raises(
+ Exception, match="Unknown variable var_2"
+ ),
+ "application.commands.clean:0": pytest.raises(
+ Exception, match="Command clean not found"
+ ),
+ "application.commands.run:2": pytest.raises(
+ Exception, match="Invalid index 2 for command run"
+ ),
+ "application.commands.run.params:5": pytest.raises(
+ Exception, match="Invalid parameter index 5 for command run"
+ ),
+ "application.commands.run.params:param_2": pytest.raises(
+ Exception,
+ match="No value for parameter with index or alias param_2 of command run",
+ ),
+ "UNKNOWN": pytest.raises(
+ Exception, match="Unable to resolve parameter UNKNOWN"
+ ),
+ "system.commands.build.params:1": pytest.raises(
+ Exception,
+ match="No value for parameter with index or alias 1 of command build",
+ ),
+ "system.commands.build:A": pytest.raises(
+ Exception, match="Bad command index A"
+ ),
+ "system.variables:var_2": pytest.raises(
+ Exception, match="Unknown variable var_2"
+ ),
+ }
+ for param, error in expected_errors.items():
+ with error:
+ param_resolver(param)
+
+ resolved_params = ctx.app.resolved_parameters("run", [])
+ expected_user_params = {
+ "user_params:0": "--param1 123",
+ "user_params:param_1": "--param1 123",
+ "user_params:2": "--param3",
+ "user_params:param_3": "--param3",
+ "user_params:3": "--param4=456",
+ "user_params:param_4": "--param4=456",
+ "user_params:param_5": "789",
+ }
+ for param, expected_value in expected_user_params.items():
+ assert param_resolver(param, "run", resolved_params) == expected_value
+
+ with pytest.raises(
+ Exception, match="Invalid index 5 for user params of command run"
+ ):
+ param_resolver("user_params:5", "run", resolved_params)
+
+ with pytest.raises(
+ Exception, match="No user parameter for command 'run' with alias 'param_2'."
+ ):
+ param_resolver("user_params:param_2", "run", resolved_params)
+
+ with pytest.raises(Exception, match="Unable to resolve user params"):
+ param_resolver("user_params:0", "", resolved_params)
+
+ bad_ctx = ExecutionContext(
+ app=Application(
+ ApplicationConfig(
+ name="test_application",
+ config_location=application_config_location,
+ build_dir="build-{user_params:0}",
+ )
+ ),
+ app_params=["--param2=789"],
+ system=load_system(
+ SystemConfig(
+ name="test_system",
+ description="Test system",
+ config_location=system_config_location,
+ build_dir="build-{system.commands.run.params:123}",
+ data_transfer=LocalProtocolConfig(protocol="local"),
+ )
+ ),
+ system_params=["--param1=bbb"],
+ custom_deploy_data=[],
+ )
+ param_resolver = ParamResolver(bad_ctx)
+ with pytest.raises(Exception, match="Unable to resolve user params"):
+ bad_ctx.build_dir()
+
+
+# pylint: disable=too-many-arguments
+@pytest.mark.parametrize(
+ "application_name, soft_lock, sys_lock, lock_dir, expected_error, expected_path",
+ (
+ (
+ "test_application",
+ True,
+ True,
+ Path("/tmp"),
+ does_not_raise(),
+ Path("/tmp/middleware_test_application_test_system.lock"),
+ ),
+ (
+ "$$test_application$!:",
+ True,
+ True,
+ Path("/tmp"),
+ does_not_raise(),
+ Path("/tmp/middleware_test_application_test_system.lock"),
+ ),
+ (
+ "test_application",
+ True,
+ True,
+ Path("unknown"),
+ pytest.raises(
+ Exception, match="Invalid directory unknown for lock files provided"
+ ),
+ None,
+ ),
+ (
+ "test_application",
+ False,
+ True,
+ Path("/tmp"),
+ does_not_raise(),
+ Path("/tmp/middleware_test_system.lock"),
+ ),
+ (
+ "test_application",
+ True,
+ False,
+ Path("/tmp"),
+ does_not_raise(),
+ Path("/tmp/middleware_test_application.lock"),
+ ),
+ (
+ "test_application",
+ False,
+ False,
+ Path("/tmp"),
+ pytest.raises(Exception, match="No filename for lock provided"),
+ None,
+ ),
+ ),
+)
+def test_get_file_lock_path(
+ application_name: str,
+ soft_lock: bool,
+ sys_lock: bool,
+ lock_dir: Path,
+ expected_error: Any,
+ expected_path: Path,
+) -> None:
+ """Test get_file_lock_path function."""
+ with expected_error:
+ ctx = ExecutionContext(
+ app=Application(ApplicationConfig(name=application_name, lock=soft_lock)),
+ app_params=[],
+ system=load_system(
+ SystemConfig(
+ name="test_system",
+ lock=sys_lock,
+ data_transfer=LocalProtocolConfig(protocol="local"),
+ )
+ ),
+ system_params=[],
+ custom_deploy_data=[],
+ )
+ path = get_file_lock_path(ctx, lock_dir)
+ assert path == expected_path
+
+
+def test_get_application_by_name_and_system(monkeypatch: Any) -> None:
+ """Test exceptional case for get_application_by_name_and_system."""
+ monkeypatch.setattr(
+ "aiet.backend.execution.get_application",
+ MagicMock(return_value=[MagicMock(), MagicMock()]),
+ )
+
+ with pytest.raises(
+ ValueError,
+ match="Error during getting application test_application for the "
+ "system test_system",
+ ):
+ get_application_by_name_and_system("test_application", "test_system")
+
+
+def test_get_application_and_system(monkeypatch: Any) -> None:
+ """Test exceptional case for get_application_and_system."""
+ monkeypatch.setattr(
+ "aiet.backend.execution.get_system", MagicMock(return_value=None)
+ )
+
+ with pytest.raises(ValueError, match="System test_system is not found"):
+ get_application_and_system("test_application", "test_system")
+
+
+def test_wait_function(monkeypatch: Any) -> None:
+ """Test wait function."""
+ sleep_mock = MagicMock()
+ monkeypatch.setattr("time.sleep", sleep_mock)
+ wait(0.1)
+ sleep_mock.assert_called_once()
+
+
+def test_deployment_execution_context() -> None:
+ """Test property 'is_deploy_needed' of the ExecutionContext."""
+ ctx = ExecutionContext(
+ app=get_application("application_1")[0],
+ app_params=[],
+ system=get_system("System 1"),
+ system_params=[],
+ )
+ assert not ctx.is_deploy_needed
+ deploy_data(ctx) # should be a NOP
+
+ ctx = ExecutionContext(
+ app=get_application("application_1")[0],
+ app_params=[],
+ system=get_system("System 1"),
+ system_params=[],
+ custom_deploy_data=[DataPaths(Path("README.md"), ".")],
+ )
+ assert ctx.is_deploy_needed
+
+ ctx = ExecutionContext(
+ app=get_application("application_1")[0],
+ app_params=[],
+ system=None,
+ system_params=[],
+ )
+ assert not ctx.is_deploy_needed
+ with pytest.raises(AssertionError):
+ deploy_data(ctx)
+
+ ctx = ExecutionContext(
+ app=get_tool("tool_1")[0],
+ app_params=[],
+ system=None,
+ system_params=[],
+ )
+ assert not ctx.is_deploy_needed
+ deploy_data(ctx) # should be a NOP
+
+
+@pytest.mark.parametrize(
+ ["tool_name", "system_name", "exception"],
+ [
+ ("vela", "Corstone-300: Cortex-M55+Ethos-U65", None),
+ ("unknown tool", "Corstone-300: Cortex-M55+Ethos-U65", ConfigurationException),
+ ("vela", "unknown system", ConfigurationException),
+ ("vela", None, ConfigurationException),
+ ],
+)
+def test_get_tool_by_system(
+ tool_name: str, system_name: Optional[str], exception: Optional[Any]
+) -> None:
+ """Test exceptions thrown by function get_tool_by_system()."""
+
+ def test() -> None:
+ """Test call of get_tool_by_system()."""
+ tool = get_tool_by_system(tool_name, system_name)
+ assert tool is not None
+
+ if exception is None:
+ test()
+ else:
+ with pytest.raises(exception):
+ test()
+
+
+class TestExecuteCommandsLocally:
+ """Test execute_commands_locally() function."""
+
+ @pytest.mark.parametrize(
+ "first_command, exception, expected_output",
+ (
+ (
+ "echo 'hello'",
+ None,
+ "Running: echo 'hello'\nhello\nRunning: echo 'goodbye'\ngoodbye\n",
+ ),
+ (
+ "non-existent-command",
+ CommandNotFound,
+ "Running: non-existent-command\n",
+ ),
+ ("false", CommandFailedException, "Running: false\n"),
+ ),
+ ids=(
+ "runs_multiple_commands",
+ "stops_executing_on_non_existent_command",
+ "stops_executing_when_command_exits_with_error_code",
+ ),
+ )
+ def test_execution(
+ self,
+ first_command: str,
+ exception: Any,
+ expected_output: str,
+ test_resources_path: Path,
+ capsys: Any,
+ ) -> None:
+ """Test expected behaviour of the function."""
+ commands = [first_command, "echo 'goodbye'"]
+ cwd = test_resources_path
+ if exception is None:
+ execute_commands_locally(commands, cwd)
+ else:
+ with pytest.raises(exception):
+ execute_commands_locally(commands, cwd)
+
+ captured = capsys.readouterr()
+ assert captured.out == expected_output
+
+ def test_stops_executing_on_exception(
+ self, monkeypatch: Any, test_resources_path: Path
+ ) -> None:
+ """Ensure commands following an error-exit-code command don't run."""
+ # Mock execute_command() function
+ execute_command_mock = mock.MagicMock()
+ monkeypatch.setattr("aiet.utils.proc.execute_command", execute_command_mock)
+
+ # Mock Command object and assign as return value to execute_command()
+ cmd_mock = mock.MagicMock()
+ execute_command_mock.return_value = cmd_mock
+
+ # Mock the terminate_command (speed up test)
+ terminate_command_mock = mock.MagicMock()
+ monkeypatch.setattr("aiet.utils.proc.terminate_command", terminate_command_mock)
+
+ # Mock a thrown Exception and assign to Command().exit_code
+ exit_code_mock = mock.PropertyMock(side_effect=Exception("Exception."))
+ type(cmd_mock).exit_code = exit_code_mock
+
+ with pytest.raises(Exception, match="Exception."):
+ execute_commands_locally(
+ ["command_1", "command_2"], cwd=test_resources_path
+ )
+
+ # Assert only "command_1" was executed
+ assert execute_command_mock.call_count == 1
+
+
+def test_reporter(tmpdir: Any) -> None:
+ """Test class 'Reporter'."""
+ ctx = ExecutionContext(
+ app=get_application("application_4")[0],
+ app_params=["--app=TestApp"],
+ system=get_system("System 4"),
+ system_params=[],
+ )
+ assert ctx.system is not None
+
+ class MockParser(OutputParser):
+ """Mock implementation of an output parser."""
+
+ def __init__(self, metrics: Dict[str, Any]) -> None:
+ """Set up the MockParser."""
+ super().__init__(name="test")
+ self.metrics = metrics
+
+ def __call__(self, output: bytearray) -> Dict[str, Any]:
+ """Return mock metrics (ignoring the given output)."""
+ return self.metrics
+
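+ # Both mock parsers are registered under the same name ("test"), so the
+ # report is expected to merge their metrics into a single "test" section.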
+ metrics = {"Metric": 123, "AnotherMetric": 456}
+ reporter = Reporter(
+ parsers=[MockParser(metrics={key: val}) for key, val in metrics.items()],
+ )
+ reporter.parse(bytearray())
+ report = reporter.report(ctx)
+ assert report["system"]["name"] == ctx.system.name
+ assert report["system"]["params"] == {}
+ assert report["application"]["name"] == ctx.app.name
+ assert report["application"]["params"] == {"--app": "TestApp"}
+ assert report["test"]["metrics"] == metrics
+ report_file = Path(tmpdir) / "report.json"
+ reporter.save(report, report_file)
+ assert report_file.is_file()
diff --git a/tests/aiet/test_backend_output_parser.py b/tests/aiet/test_backend_output_parser.py
new file mode 100644
index 0000000..d659812
--- /dev/null
+++ b/tests/aiet/test_backend_output_parser.py
@@ -0,0 +1,152 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Tests for the output parsing."""
+import base64
+import json
+from typing import Any
+from typing import Dict
+
+import pytest
+
+from aiet.backend.output_parser import Base64OutputParser
+from aiet.backend.output_parser import OutputParser
+from aiet.backend.output_parser import RegexOutputParser
+
+
+OUTPUT_MATCH_ALL = bytearray(
+ """
+String1: My awesome string!
+String2: STRINGS_ARE_GREAT!!!
+Int: 12
+Float: 3.14
+""",
+ encoding="utf-8",
+)
+
+OUTPUT_NO_MATCH = bytearray(
+ """
+This contains no matches...
+Test1234567890!"£$%^&*()_+@~{}[]/.,<>?|
+""",
+ encoding="utf-8",
+)
+
+OUTPUT_PARTIAL_MATCH = bytearray(
+ "String1: My awesome string!",
+ encoding="utf-8",
+)
+
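+# Each entry of the regex config maps a metric name to a pattern containing
+# exactly one capturing group plus the type the captured text is converted to
+# (invalid variants are exercised in TestRegexOutputParser below).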
+REGEX_CONFIG = {
+ "FirstString": {"pattern": r"String1.*: (.*)", "type": "str"},
+ "SecondString": {"pattern": r"String2.*: (.*)!!!", "type": "str"},
+ "IntegerValue": {"pattern": r"Int.*: (.*)", "type": "int"},
+ "FloatValue": {"pattern": r"Float.*: (.*)", "type": "float"},
+}
+
+EMPTY_REGEX_CONFIG: Dict[str, Dict[str, Any]] = {}
+
+EXPECTED_METRICS_ALL = {
+ "FirstString": "My awesome string!",
+ "SecondString": "STRINGS_ARE_GREAT",
+ "IntegerValue": 12,
+ "FloatValue": 3.14,
+}
+
+EXPECTED_METRICS_PARTIAL = {
+ "FirstString": "My awesome string!",
+}
+
+
+class TestRegexOutputParser:
+ """Collect tests for the RegexOutputParser."""
+
+ @staticmethod
+ @pytest.mark.parametrize(
+ ["output", "config", "expected_metrics"],
+ [
+ (OUTPUT_MATCH_ALL, REGEX_CONFIG, EXPECTED_METRICS_ALL),
+ (OUTPUT_MATCH_ALL + OUTPUT_NO_MATCH, REGEX_CONFIG, EXPECTED_METRICS_ALL),
+ (
+ OUTPUT_MATCH_ALL + OUTPUT_PARTIAL_MATCH,
+ REGEX_CONFIG,
+ EXPECTED_METRICS_ALL,
+ ),
+ (OUTPUT_NO_MATCH, REGEX_CONFIG, {}),
+ (OUTPUT_MATCH_ALL, EMPTY_REGEX_CONFIG, {}),
+ (bytearray(), EMPTY_REGEX_CONFIG, {}),
+ (bytearray(), REGEX_CONFIG, {}),
+ ],
+ )
+ def test_parsing(output: bytearray, config: Dict, expected_metrics: Dict) -> None:
+ """
+ Make sure the RegexOutputParser yields valid results.
+
+ I.e. return an empty dict if either the input or the config is empty and
+ return the parsed metrics otherwise.
+ """
+ parser = RegexOutputParser(name="Test", regex_config=config)
+ assert parser.name == "Test"
+ assert isinstance(parser, OutputParser)
+ res = parser(output)
+ assert res == expected_metrics
+
+ @staticmethod
+ def test_unsupported_type() -> None:
+ """An unsupported type in the regex_config must raise an exception."""
+ config = {"BrokenMetric": {"pattern": "(.*)", "type": "UNSUPPORTED_TYPE"}}
+ with pytest.raises(TypeError):
+ RegexOutputParser(name="Test", regex_config=config)
+
+ @staticmethod
+ @pytest.mark.parametrize(
+ "config",
+ (
+ {"TooManyGroups": {"pattern": r"(\w)(\d)", "type": "str"}},
+ {"NoGroups": {"pattern": r"\W", "type": "str"}},
+ ),
+ )
+ def test_invalid_pattern(config: Dict) -> None:
+ """Exactly one capturing parenthesis is allowed in the regex pattern."""
+ with pytest.raises(ValueError):
+ RegexOutputParser(name="Test", regex_config=config)
+
+
+@pytest.mark.parametrize(
+ "expected_metrics",
+ [
+ EXPECTED_METRICS_ALL,
+ EXPECTED_METRICS_PARTIAL,
+ ],
+)
+def test_base64_output_parser(expected_metrics: Dict) -> None:
+ """
+ Make sure the Base64OutputParser yields valid results.
+
+ I.e. decode the base64-encoded JSON payload embedded in the output and
+ return the parsed metrics, ignoring the surrounding text.
+ """
+ parser = Base64OutputParser(name="Test")
+ assert parser.name == "Test"
+ assert isinstance(parser, OutputParser)
+
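+ # Build output in the shape the Base64OutputParser is expected to handle: a
+ # base64-encoded JSON blob wrapped in <TAG_NAME>...</TAG_NAME> markers and
+ # surrounded by unrelated text that the parser should leave untouched.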
+ def create_base64_output(expected_metrics: Dict) -> bytearray:
+ json_str = json.dumps(expected_metrics, indent=4)
+ json_b64 = base64.b64encode(json_str.encode("utf-8"))
+ return (
+ OUTPUT_MATCH_ALL # Should not be matched by the Base64OutputParser
+ + f"<{Base64OutputParser.TAG_NAME}>".encode("utf-8")
+ + bytearray(json_b64)
+ + f"</{Base64OutputParser.TAG_NAME}>".encode("utf-8")
+ + OUTPUT_NO_MATCH # Just to add some difficulty...
+ )
+
+ output = create_base64_output(expected_metrics)
+ res = parser(output)
+ assert len(res) == 1
+ assert isinstance(res, dict)
+ for val in res.values():
+ assert val == expected_metrics
+
+ output = parser.filter_out_parsed_content(output)
+ assert output == (OUTPUT_MATCH_ALL + OUTPUT_NO_MATCH)
diff --git a/tests/aiet/test_backend_protocol.py b/tests/aiet/test_backend_protocol.py
new file mode 100644
index 0000000..2103238
--- /dev/null
+++ b/tests/aiet/test_backend_protocol.py
@@ -0,0 +1,231 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+# pylint: disable=no-self-use,attribute-defined-outside-init,protected-access
+"""Tests for the protocol backend module."""
+from contextlib import ExitStack as does_not_raise
+from pathlib import Path
+from typing import Any
+from unittest.mock import MagicMock
+
+import paramiko
+import pytest
+
+from aiet.backend.common import ConfigurationException
+from aiet.backend.config import LocalProtocolConfig
+from aiet.backend.protocol import CustomSFTPClient
+from aiet.backend.protocol import LocalProtocol
+from aiet.backend.protocol import ProtocolFactory
+from aiet.backend.protocol import SSHProtocol
+
+
+class TestProtocolFactory:
+ """Test ProtocolFactory class."""
+
+ @pytest.mark.parametrize(
+ "config, expected_class, exception",
+ [
+ (
+ {
+ "protocol": "ssh",
+ "username": "user",
+ "password": "pass",
+ "hostname": "hostname",
+ "port": "22",
+ },
+ SSHProtocol,
+ does_not_raise(),
+ ),
+ ({"protocol": "local"}, LocalProtocol, does_not_raise()),
+ (
+ {"protocol": "something"},
+ None,
+ pytest.raises(Exception, match="Protocol not supported"),
+ ),
+ (None, None, pytest.raises(Exception, match="No protocol config provided")),
+ ],
+ )
+ def test_get_protocol(
+ self, config: Any, expected_class: type, exception: Any
+ ) -> None:
+ """Test get_protocol method."""
+ factory = ProtocolFactory()
+ with exception:
+ protocol = factory.get_protocol(config)
+ assert isinstance(protocol, expected_class)
+
+
+class TestLocalProtocol:
+ """Test local protocol."""
+
+ def test_local_protocol_run_command(self) -> None:
+ """Test local protocol run command."""
+ config = LocalProtocolConfig(protocol="local")
+ protocol = LocalProtocol(config, cwd=Path("/tmp"))
+ ret, stdout, stderr = protocol.run("pwd")
+ assert ret == 0
+ assert stdout.decode("utf-8").strip() == "/tmp"
+ assert stderr.decode("utf-8") == ""
+
+ def test_local_protocol_run_wrong_cwd(self) -> None:
+ """Execution should fail if wrong working directory provided."""
+ config = LocalProtocolConfig(protocol="local")
+ protocol = LocalProtocol(config, cwd=Path("unknown_directory"))
+ with pytest.raises(
+ ConfigurationException, match="Wrong working directory unknown_directory"
+ ):
+ protocol.run("pwd")
+
+
+class TestSSHProtocol:
+ """Test SSH protocol."""
+
+ @pytest.fixture(autouse=True)
+ def setup_method(self, monkeypatch: Any) -> None:
+ """Set up protocol mocks."""
+ self.mock_ssh_client = MagicMock(spec=paramiko.client.SSHClient)
+
+ self.mock_ssh_channel = (
+ self.mock_ssh_client.get_transport.return_value.open_session.return_value
+ )
+ self.mock_ssh_channel.mock_add_spec(spec=paramiko.channel.Channel)
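+ # Simulate a command that completes on the second poll: every readiness
+ # check reports "not ready" once and then "ready".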
+ self.mock_ssh_channel.exit_status_ready.side_effect = [False, True]
+ self.mock_ssh_channel.recv_exit_status.return_value = True
+ self.mock_ssh_channel.recv_ready.side_effect = [False, True]
+ self.mock_ssh_channel.recv_stderr_ready.side_effect = [False, True]
+
+ monkeypatch.setattr(
+ "aiet.backend.protocol.paramiko.client.SSHClient",
+ MagicMock(return_value=self.mock_ssh_client),
+ )
+
+ self.mock_sftp_client = MagicMock(spec=CustomSFTPClient)
+ monkeypatch.setattr(
+ "aiet.backend.protocol.CustomSFTPClient.from_transport",
+ MagicMock(return_value=self.mock_sftp_client),
+ )
+
+ ssh_config = {
+ "protocol": "ssh",
+ "username": "user",
+ "password": "pass",
+ "hostname": "hostname",
+ "port": "22",
+ }
+ self.protocol = SSHProtocol(ssh_config)
+
+ def test_unable_create_ssh_client(self, monkeypatch: Any) -> None:
+ """Test that command should fail if unable to create ssh client instance."""
+ monkeypatch.setattr(
+ "aiet.backend.protocol.paramiko.client.SSHClient",
+ MagicMock(side_effect=OSError("Error!")),
+ )
+
+ with pytest.raises(Exception, match="Couldn't connect to 'hostname:22'"):
+ self.protocol.run("command_example", retry=False)
+
+ def test_ssh_protocol_run_command(self) -> None:
+ """Test that command run via ssh successfully."""
+ self.protocol.run("command_example")
+ self.mock_ssh_channel.exec_command.assert_called_once()
+
+ def test_ssh_protocol_run_command_connect_failed(self) -> None:
+ """Test that if connection is not possible then correct exception is raised."""
+ self.mock_ssh_client.connect.side_effect = OSError("Unable to connect")
+ self.mock_ssh_client.close.side_effect = Exception("Error!")
+
+ with pytest.raises(Exception, match="Couldn't connect to 'hostname:22'"):
+ self.protocol.run("command_example", retry=False)
+
+ def test_ssh_protocol_run_command_bad_transport(self) -> None:
+ """Test that command should fail if unable to get transport."""
+ self.mock_ssh_client.get_transport.return_value = None
+
+ with pytest.raises(Exception, match="Unable to get transport"):
+ self.protocol.run("command_example", retry=False)
+
+ def test_ssh_protocol_deploy_command_file(
+ self, test_applications_path: Path
+ ) -> None:
+ """Test that files could be deployed over ssh."""
+ file_for_deploy = test_applications_path / "readme.txt"
+ dest = "/tmp/dest"
+
+ self.protocol.deploy(file_for_deploy, dest)
+ self.mock_sftp_client.put.assert_called_once_with(str(file_for_deploy), dest)
+
+ def test_ssh_protocol_deploy_command_unknown_file(self) -> None:
+ """Test that deploy will fail if file does not exist."""
+ with pytest.raises(Exception, match="Deploy error: file type not supported"):
+ self.protocol.deploy(Path("unknown_file"), "/tmp/dest")
+
+ def test_ssh_protocol_deploy_command_bad_transport(self) -> None:
+ """Test that deploy should fail if unable to get transport."""
+ self.mock_ssh_client.get_transport.return_value = None
+
+ with pytest.raises(Exception, match="Unable to get transport"):
+ self.protocol.deploy(Path("some_file"), "/tmp/dest")
+
+ def test_ssh_protocol_deploy_command_directory(
+ self, test_resources_path: Path
+ ) -> None:
+ """Test that directory could be deployed over ssh."""
+ directory_for_deploy = test_resources_path / "scripts"
+ dest = "/tmp/dest"
+
+ self.protocol.deploy(directory_for_deploy, dest)
+ self.mock_sftp_client.put_dir.assert_called_once_with(
+ directory_for_deploy, dest
+ )
+
+ @pytest.mark.parametrize("establish_connection", (True, False))
+ def test_ssh_protocol_close(self, establish_connection: bool) -> None:
+ """Test protocol close operation."""
+ if establish_connection:
+ self.protocol.establish_connection()
+ self.protocol.close()
+
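+ # close() presumably issues a final command over the channel (hence the
+ # exec_command call), but only if a connection was established first.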
+ call_count = 1 if establish_connection else 0
+ assert self.mock_ssh_channel.exec_command.call_count == call_count
+
+ def test_connection_details(self) -> None:
+ """Test getting connection details."""
+ assert self.protocol.connection_details() == ("hostname", 22)
+
+
+class TestCustomSFTPClient:
+ """Test CustomSFTPClient class."""
+
+ @pytest.fixture(autouse=True)
+ def setup_method(self, monkeypatch: Any) -> None:
+ """Set up mocks for CustomSFTPClient instance."""
+ self.mock_mkdir = MagicMock()
+ self.mock_put = MagicMock()
+ monkeypatch.setattr(
+ "aiet.backend.protocol.paramiko.SFTPClient.__init__",
+ MagicMock(return_value=None),
+ )
+ monkeypatch.setattr(
+ "aiet.backend.protocol.paramiko.SFTPClient.mkdir", self.mock_mkdir
+ )
+ monkeypatch.setattr(
+ "aiet.backend.protocol.paramiko.SFTPClient.put", self.mock_put
+ )
+
+ self.sftp_client = CustomSFTPClient(MagicMock())
+
+ def test_put_dir(self, test_systems_path: Path) -> None:
+ """Test deploying directory to remote host."""
+ directory_for_deploy = test_systems_path / "system1"
+
+ self.sftp_client.put_dir(directory_for_deploy, "/tmp/dest")
+ assert self.mock_put.call_count == 3
+ assert self.mock_mkdir.call_count == 3
+
+ def test_mkdir(self) -> None:
+ """Test creating directory on remote host."""
+ self.mock_mkdir.side_effect = IOError("Cannot create directory")
+
+ self.sftp_client._mkdir("new_directory", ignore_existing=True)
+
+ with pytest.raises(IOError, match="Cannot create directory"):
+ self.sftp_client._mkdir("new_directory", ignore_existing=False)
diff --git a/tests/aiet/test_backend_source.py b/tests/aiet/test_backend_source.py
new file mode 100644
index 0000000..13b2c6d
--- /dev/null
+++ b/tests/aiet/test_backend_source.py
@@ -0,0 +1,199 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+# pylint: disable=no-self-use
+"""Tests for the source backend module."""
+from collections import Counter
+from contextlib import ExitStack as does_not_raise
+from pathlib import Path
+from typing import Any
+from unittest.mock import MagicMock
+from unittest.mock import patch
+
+import pytest
+
+from aiet.backend.common import ConfigurationException
+from aiet.backend.source import create_destination_and_install
+from aiet.backend.source import DirectorySource
+from aiet.backend.source import get_source
+from aiet.backend.source import TarArchiveSource
+
+
+def test_create_destination_and_install(test_systems_path: Path, tmpdir: Any) -> None:
+ """Test create_destination_and_install function."""
+ system_directory = test_systems_path / "system1"
+
+ dir_source = DirectorySource(system_directory)
+ resources = Path(tmpdir)
+ create_destination_and_install(dir_source, resources)
+ assert (resources / "system1").is_dir()
+
+
+@patch("aiet.backend.source.DirectorySource.create_destination", return_value=False)
+def test_create_destination_and_install_if_dest_creation_not_required(
+ mock_ds_create_destination: Any, tmpdir: Any
+) -> None:
+ """Test create_destination_and_install function."""
+ dir_source = DirectorySource(Path("unknown"))
+ resources = Path(tmpdir)
+ with pytest.raises(Exception):
+ create_destination_and_install(dir_source, resources)
+
+ mock_ds_create_destination.assert_called_once()
+
+
+def test_create_destination_and_install_if_installation_fails(tmpdir: Any) -> None:
+ """Test create_destination_and_install function if installation fails."""
+ dir_source = DirectorySource(Path("unknown"))
+ resources = Path(tmpdir)
+ with pytest.raises(Exception, match="Directory .* does not exist"):
+ create_destination_and_install(dir_source, resources)
+ assert not (resources / "unknown").exists()
+ assert resources.exists()
+
+
+def test_create_destination_and_install_if_name_is_empty() -> None:
+ """Test create_destination_and_install function fails if source name is empty."""
+ source = MagicMock()
+ source.create_destination.return_value = True
+ source.name.return_value = None
+
+ with pytest.raises(Exception, match="Unable to get source name"):
+ create_destination_and_install(source, Path("some_path"))
+
+ source.install_into.assert_not_called()
+
+
+@pytest.mark.parametrize(
+ "source_path, expected_class, expected_error",
+ [
+ (Path("applications/application1/"), DirectorySource, does_not_raise()),
+ (
+ Path("archives/applications/application1.tar.gz"),
+ TarArchiveSource,
+ does_not_raise(),
+ ),
+ (
+ Path("doesnt/exist"),
+ None,
+ pytest.raises(
+ ConfigurationException, match="Unable to read .*doesnt/exist"
+ ),
+ ),
+ ],
+)
+def test_get_source(
+ source_path: Path,
+ expected_class: Any,
+ expected_error: Any,
+ test_resources_path: Path,
+) -> None:
+ """Test get_source function."""
+ with expected_error:
+ full_source_path = test_resources_path / source_path
+ source = get_source(full_source_path)
+ assert isinstance(source, expected_class)
+
+
+class TestDirectorySource:
+ """Test DirectorySource class."""
+
+ @pytest.mark.parametrize(
+ "directory, name",
+ [
+ (Path("/some/path/some_system"), "some_system"),
+ (Path("some_system"), "some_system"),
+ ],
+ )
+ def test_name(self, directory: Path, name: str) -> None:
+ """Test getting source name."""
+ assert DirectorySource(directory).name() == name
+
+ def test_install_into(self, test_systems_path: Path, tmpdir: Any) -> None:
+ """Test install directory into destination."""
+ system_directory = test_systems_path / "system1"
+
+ dir_source = DirectorySource(system_directory)
+ with pytest.raises(Exception, match="Wrong destination .*"):
+ dir_source.install_into(Path("unknown_destination"))
+
+ tmpdir_path = Path(tmpdir)
+ dir_source.install_into(tmpdir_path)
+ source_files = [f.name for f in system_directory.iterdir()]
+ dest_files = [f.name for f in tmpdir_path.iterdir()]
+ assert Counter(source_files) == Counter(dest_files)
+
+ def test_install_into_unknown_source_directory(self, tmpdir: Any) -> None:
+ """Test install system from unknown directory."""
+ with pytest.raises(Exception, match="Directory .* does not exist"):
+ DirectorySource(Path("unknown_directory")).install_into(Path(tmpdir))
+
+
+class TestTarArchiveSource:
+ """Test TarArchiveSource class."""
+
+ @pytest.mark.parametrize(
+ "archive, name",
+ [
+ (Path("some_archive.tgz"), "some_archive"),
+ (Path("some_archive.tar.gz"), "some_archive"),
+ (Path("some_archive"), "some_archive"),
+ ("archives/systems/system1.tar.gz", "system1"),
+ ("archives/systems/system1_dir.tar.gz", "system1"),
+ ],
+ )
+ def test_name(self, test_resources_path: Path, archive: Path, name: str) -> None:
+ """Test getting source name."""
+ assert TarArchiveSource(test_resources_path / archive).name() == name
+
+ def test_install_into(self, test_resources_path: Path, tmpdir: Any) -> None:
+ """Test install archive into destination."""
+ system_archive = test_resources_path / "archives/systems/system1.tar.gz"
+
+ tar_source = TarArchiveSource(system_archive)
+ with pytest.raises(Exception, match="Wrong destination .*"):
+ tar_source.install_into(Path("unknown_destination"))
+
+ tmpdir_path = Path(tmpdir)
+ tar_source.install_into(tmpdir_path)
+ source_files = [
+ "aiet-config.json.license",
+ "aiet-config.json",
+ "system_artifact",
+ ]
+ dest_files = [f.name for f in tmpdir_path.iterdir()]
+ assert Counter(source_files) == Counter(dest_files)
+
+ def test_install_into_unknown_source_archive(self, tmpdir: Any) -> None:
+ """Test install unknown source archive."""
+ with pytest.raises(Exception, match="File .* does not exist"):
+ TarArchiveSource(Path("unknown.tar.gz")).install_into(Path(tmpdir))
+
+ def test_install_into_unsupported_source_archive(self, tmpdir: Any) -> None:
+ """Test install unsupported file type."""
+ plain_text_file = Path(tmpdir) / "test_file"
+ plain_text_file.write_text("Not a system config")
+
+ with pytest.raises(Exception, match="Unsupported archive type .*"):
+ TarArchiveSource(plain_text_file).install_into(Path(tmpdir))
+
+ def test_lazy_property_init(self, test_resources_path: Path) -> None:
+ """Test that class properties initialized correctly."""
+ system_archive = test_resources_path / "archives/systems/system1.tar.gz"
+
+ tar_source = TarArchiveSource(system_archive)
+ assert tar_source.name() == "system1"
+ assert tar_source.config() is not None
+ assert tar_source.create_destination()
+
+ tar_source = TarArchiveSource(system_archive)
+ assert tar_source.config() is not None
+ assert tar_source.create_destination()
+ assert tar_source.name() == "system1"
+
+ def test_create_destination_property(self, test_resources_path: Path) -> None:
+ """Test create_destination property filled correctly for different archives."""
+ system_archive1 = test_resources_path / "archives/systems/system1.tar.gz"
+ system_archive2 = test_resources_path / "archives/systems/system1_dir.tar.gz"
+
+ assert TarArchiveSource(system_archive1).create_destination()
+ assert not TarArchiveSource(system_archive2).create_destination()
diff --git a/tests/aiet/test_backend_system.py b/tests/aiet/test_backend_system.py
new file mode 100644
index 0000000..a581547
--- /dev/null
+++ b/tests/aiet/test_backend_system.py
@@ -0,0 +1,536 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Tests for system backend."""
+from contextlib import ExitStack as does_not_raise
+from pathlib import Path
+from typing import Any
+from typing import Callable
+from typing import Dict
+from typing import List
+from typing import Optional
+from typing import Tuple
+from unittest.mock import MagicMock
+
+import pytest
+
+from aiet.backend.common import Command
+from aiet.backend.common import ConfigurationException
+from aiet.backend.common import Param
+from aiet.backend.common import UserParamConfig
+from aiet.backend.config import LocalProtocolConfig
+from aiet.backend.config import ProtocolConfig
+from aiet.backend.config import SSHConfig
+from aiet.backend.config import SystemConfig
+from aiet.backend.controller import SystemController
+from aiet.backend.controller import SystemControllerSingleInstance
+from aiet.backend.protocol import LocalProtocol
+from aiet.backend.protocol import SSHProtocol
+from aiet.backend.protocol import SupportsClose
+from aiet.backend.protocol import SupportsDeploy
+from aiet.backend.system import ControlledSystem
+from aiet.backend.system import get_available_systems
+from aiet.backend.system import get_controller
+from aiet.backend.system import get_system
+from aiet.backend.system import install_system
+from aiet.backend.system import load_system
+from aiet.backend.system import remove_system
+from aiet.backend.system import StandaloneSystem
+from aiet.backend.system import System
+
+
+def dummy_resolver(
+ values: Optional[Dict[str, str]] = None
+) -> Callable[[str, str, List[Tuple[Optional[str], Param]]], str]:
+ """Return dummy parameter resolver implementation."""
+ # pylint: disable=unused-argument
+ def resolver(
+ param: str, cmd: str, param_values: List[Tuple[Optional[str], Param]]
+ ) -> str:
+ """Implement dummy parameter resolver."""
+ return values.get(param, "") if values else ""
+
+ return resolver
+
+
+def test_get_available_systems() -> None:
+ """Test get_available_systems mocking get_resources."""
+ available_systems = get_available_systems()
+ assert all(isinstance(s, System) for s in available_systems)
+ assert len(available_systems) == 3
+ assert [str(s) for s in available_systems] == ["System 1", "System 2", "System 4"]
+
+
+def test_get_system() -> None:
+ """Test get_system."""
+ system1 = get_system("System 1")
+ assert isinstance(system1, ControlledSystem)
+ assert system1.connectable is True
+ assert system1.connection_details() == ("localhost", 8021)
+ assert system1.name == "System 1"
+
+ system2 = get_system("System 2")
+ # check that comparison with object of another type returns false
+ assert system1 != 42
+ assert system1 != system2
+
+ system = get_system("Unknown system")
+ assert system is None
+
+
+@pytest.mark.parametrize(
+ "source, call_count, exception_type",
+ (
+ (
+ "archives/systems/system1.tar.gz",
+ 0,
+ pytest.raises(Exception, match="Systems .* are already installed"),
+ ),
+ (
+ "archives/systems/system3.tar.gz",
+ 0,
+ pytest.raises(Exception, match="Unable to read system definition"),
+ ),
+ (
+ "systems/system1",
+ 0,
+ pytest.raises(Exception, match="Systems .* are already installed"),
+ ),
+ (
+ "systems/system3",
+ 0,
+ pytest.raises(Exception, match="Unable to read system definition"),
+ ),
+ ("unknown_path", 0, pytest.raises(Exception, match="Unable to read")),
+ (
+ "various/systems/system_with_empty_config",
+ 0,
+ pytest.raises(Exception, match="No system definition found"),
+ ),
+ ("various/systems/system_with_valid_config", 1, does_not_raise()),
+ ),
+)
+def test_install_system(
+ monkeypatch: Any,
+ test_resources_path: Path,
+ source: str,
+ call_count: int,
+ exception_type: Any,
+) -> None:
+ """Test system installation from archive."""
+ mock_create_destination_and_install = MagicMock()
+ monkeypatch.setattr(
+ "aiet.backend.system.create_destination_and_install",
+ mock_create_destination_and_install,
+ )
+
+ with exception_type:
+ install_system(test_resources_path / source)
+
+ assert mock_create_destination_and_install.call_count == call_count
+
+
+def test_remove_system(monkeypatch: Any) -> None:
+ """Test system removal."""
+ mock_remove_backend = MagicMock()
+ monkeypatch.setattr("aiet.backend.system.remove_backend", mock_remove_backend)
+ remove_system("some_system_dir")
+ mock_remove_backend.assert_called_once()
+
+
+def test_system(monkeypatch: Any) -> None:
+ """Test the System class."""
+ config = SystemConfig(name="System 1")
+ monkeypatch.setattr("aiet.backend.system.ProtocolFactory", MagicMock())
+ system = System(config)
+ assert str(system) == "System 1"
+ assert system.name == "System 1"
+
+
+def test_system_with_empty_parameter_name() -> None:
+ """Test that configuration fails if parameter name is empty."""
+ bad_config = SystemConfig(
+ name="System 1",
+ commands={"run": ["run"]},
+ user_params={"run": [{"name": "", "values": ["1", "2", "3"]}]},
+ )
+ with pytest.raises(Exception, match="Parameter has an empty 'name' attribute."):
+ System(bad_config)
+
+
+def test_system_standalone_run() -> None:
+ """Test run operation for standalone system."""
+ system = get_system("System 4")
+ assert isinstance(system, StandaloneSystem)
+
+ with pytest.raises(
+ ConfigurationException, match="System .* does not support connections"
+ ):
+ system.connection_details()
+
+ with pytest.raises(
+ ConfigurationException, match="System .* does not support connections"
+ ):
+ system.establish_connection()
+
+ assert system.connectable is False
+
+ system.run("echo 'application run'")
+
+
+@pytest.mark.parametrize(
+ "system_name, expected_value", [("System 1", True), ("System 4", False)]
+)
+def test_system_supports_deploy(system_name: str, expected_value: bool) -> None:
+ """Test system property supports_deploy."""
+ system = get_system(system_name)
+ if system is None:
+ pytest.fail("Unable to get system {}".format(system_name))
+ assert system.supports_deploy == expected_value
+
+
+@pytest.mark.parametrize(
+ "mock_protocol",
+ [
+ MagicMock(spec=SSHProtocol),
+ MagicMock(
+ spec=SSHProtocol,
+ **{"close.side_effect": ValueError("Unable to close protocol")}
+ ),
+ MagicMock(spec=LocalProtocol),
+ ],
+)
+def test_system_start_and_stop(monkeypatch: Any, mock_protocol: MagicMock) -> None:
+ """Test system start, run commands and stop."""
+ monkeypatch.setattr(
+ "aiet.backend.system.ProtocolFactory.get_protocol",
+ MagicMock(return_value=mock_protocol),
+ )
+
+ system = get_system("System 1")
+ if system is None:
+ pytest.fail("Unable to get system")
+ assert isinstance(system, ControlledSystem)
+
+ with pytest.raises(Exception, match="System has not been started"):
+ system.stop()
+
+ assert not system.is_running()
+ assert system.get_output() == ("", "")
+ system.start(["sleep 10"], False)
+ assert system.is_running()
+ system.stop(wait=True)
+ assert not system.is_running()
+ assert system.get_output() == ("", "")
+
+ if isinstance(mock_protocol, SupportsClose):
+ mock_protocol.close.assert_called_once()
+
+ if isinstance(mock_protocol, SSHProtocol):
+ system.establish_connection()
+
+
+def test_system_start_no_config_location() -> None:
+ """Test that system without config location could not start."""
+ system = load_system(
+ SystemConfig(
+ name="test",
+ data_transfer=SSHConfig(
+ protocol="ssh",
+ username="user",
+ password="user",
+ hostname="localhost",
+ port="123",
+ ),
+ )
+ )
+
+ assert isinstance(system, ControlledSystem)
+ with pytest.raises(
+ ConfigurationException, match="System test has wrong config location"
+ ):
+ system.start(["sleep 100"])
+
+
+@pytest.mark.parametrize(
+ "config, expected_class, expected_error",
+ [
+ (
+ SystemConfig(
+ name="test",
+ data_transfer=SSHConfig(
+ protocol="ssh",
+ username="user",
+ password="user",
+ hostname="localhost",
+ port="123",
+ ),
+ ),
+ ControlledSystem,
+ does_not_raise(),
+ ),
+ (
+ SystemConfig(
+ name="test", data_transfer=LocalProtocolConfig(protocol="local")
+ ),
+ StandaloneSystem,
+ does_not_raise(),
+ ),
+ (
+ SystemConfig(
+ name="test",
+ data_transfer=ProtocolConfig(protocol="cool_protocol"), # type: ignore
+ ),
+ None,
+ pytest.raises(
+ Exception, match="Unsupported execution type for protocol cool_protocol"
+ ),
+ ),
+ ],
+)
+def test_load_system(
+ config: SystemConfig, expected_class: type, expected_error: Any
+) -> None:
+ """Test load_system function."""
+ if not expected_class:
+ with expected_error:
+ load_system(config)
+ else:
+ system = load_system(config)
+ assert isinstance(system, expected_class)
+
+
+def test_load_system_populate_shared_params() -> None:
+ """Test shared parameters population."""
+ with pytest.raises(Exception, match="All shared parameters should have aliases"):
+ load_system(
+ SystemConfig(
+ name="test_system",
+ data_transfer=LocalProtocolConfig(protocol="local"),
+ user_params={
+ "shared": [
+ UserParamConfig(
+ name="--shared_param1",
+ description="Shared parameter",
+ values=["1", "2", "3"],
+ default_value="1",
+ )
+ ]
+ },
+ )
+ )
+
+ with pytest.raises(
+ Exception, match="All parameters for command run should have aliases"
+ ):
+ load_system(
+ SystemConfig(
+ name="test_system",
+ data_transfer=LocalProtocolConfig(protocol="local"),
+ user_params={
+ "shared": [
+ UserParamConfig(
+ name="--shared_param1",
+ description="Shared parameter",
+ values=["1", "2", "3"],
+ default_value="1",
+ alias="shared_param1",
+ )
+ ],
+ "run": [
+ UserParamConfig(
+ name="--run_param1",
+ description="Run specific parameter",
+ values=["1", "2", "3"],
+ default_value="2",
+ )
+ ],
+ },
+ )
+ )
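+ # The systems below check how shared user params are merged into each
+ # command: run-only params, shared plus run params with no commands defined,
+ # and shared plus run params alongside an explicit 'build' command.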
+ system0 = load_system(
+ SystemConfig(
+ name="test_system",
+ data_transfer=LocalProtocolConfig(protocol="local"),
+ commands={"run": ["run_command"]},
+ user_params={
+ "shared": [],
+ "run": [
+ UserParamConfig(
+ name="--run_param1",
+ description="Run specific parameter",
+ values=["1", "2", "3"],
+ default_value="2",
+ alias="run_param1",
+ )
+ ],
+ },
+ )
+ )
+ assert len(system0.commands) == 1
+ run_command1 = system0.commands["run"]
+ assert run_command1 == Command(
+ ["run_command"],
+ [
+ Param(
+ "--run_param1",
+ "Run specific parameter",
+ ["1", "2", "3"],
+ "2",
+ "run_param1",
+ )
+ ],
+ )
+
+ system1 = load_system(
+ SystemConfig(
+ name="test_system",
+ data_transfer=LocalProtocolConfig(protocol="local"),
+ user_params={
+ "shared": [
+ UserParamConfig(
+ name="--shared_param1",
+ description="Shared parameter",
+ values=["1", "2", "3"],
+ default_value="1",
+ alias="shared_param1",
+ )
+ ],
+ "run": [
+ UserParamConfig(
+ name="--run_param1",
+ description="Run specific parameter",
+ values=["1", "2", "3"],
+ default_value="2",
+ alias="run_param1",
+ )
+ ],
+ },
+ )
+ )
+ assert len(system1.commands) == 2
+ build_command1 = system1.commands["build"]
+ assert build_command1 == Command(
+ [],
+ [
+ Param(
+ "--shared_param1",
+ "Shared parameter",
+ ["1", "2", "3"],
+ "1",
+ "shared_param1",
+ )
+ ],
+ )
+
+ run_command1 = system1.commands["run"]
+ assert run_command1 == Command(
+ [],
+ [
+ Param(
+ "--shared_param1",
+ "Shared parameter",
+ ["1", "2", "3"],
+ "1",
+ "shared_param1",
+ ),
+ Param(
+ "--run_param1",
+ "Run specific parameter",
+ ["1", "2", "3"],
+ "2",
+ "run_param1",
+ ),
+ ],
+ )
+
+ system2 = load_system(
+ SystemConfig(
+ name="test_system",
+ data_transfer=LocalProtocolConfig(protocol="local"),
+ commands={"build": ["build_command"]},
+ user_params={
+ "shared": [
+ UserParamConfig(
+ name="--shared_param1",
+ description="Shared parameter",
+ values=["1", "2", "3"],
+ default_value="1",
+ alias="shared_param1",
+ )
+ ],
+ "run": [
+ UserParamConfig(
+ name="--run_param1",
+ description="Run specific parameter",
+ values=["1", "2", "3"],
+ default_value="2",
+ alias="run_param1",
+ )
+ ],
+ },
+ )
+ )
+ assert len(system2.commands) == 2
+ build_command2 = system2.commands["build"]
+ assert build_command2 == Command(
+ ["build_command"],
+ [
+ Param(
+ "--shared_param1",
+ "Shared parameter",
+ ["1", "2", "3"],
+ "1",
+ "shared_param1",
+ )
+ ],
+ )
+
+ run_command2 = system2.commands["run"]
+ assert run_command2 == Command(
+ [],
+ [
+ Param(
+ "--shared_param1",
+ "Shared parameter",
+ ["1", "2", "3"],
+ "1",
+ "shared_param1",
+ ),
+ Param(
+ "--run_param1",
+ "Run specific parameter",
+ ["1", "2", "3"],
+ "2",
+ "run_param1",
+ ),
+ ],
+ )
+
+
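+# deploy() should only be forwarded to the protocol when it implements
+# SupportsDeploy; a plain mock protocol must not receive the call.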
+@pytest.mark.parametrize(
+ "mock_protocol, expected_call_count",
+ [(MagicMock(spec=SupportsDeploy), 1), (MagicMock(), 0)],
+)
+def test_system_deploy_data(
+ monkeypatch: Any, mock_protocol: MagicMock, expected_call_count: int
+) -> None:
+ """Test deploy data functionality."""
+ monkeypatch.setattr(
+ "aiet.backend.system.ProtocolFactory.get_protocol",
+ MagicMock(return_value=mock_protocol),
+ )
+
+ system = ControlledSystem(SystemConfig(name="test"))
+ system.deploy(Path("some_file"), "some_dest")
+
+ assert mock_protocol.deploy.call_count == expected_call_count
+
+
+@pytest.mark.parametrize(
+ "single_instance, controller_class",
+ ((False, SystemController), (True, SystemControllerSingleInstance)),
+)
+def test_get_controller(single_instance: bool, controller_class: type) -> None:
+ """Test function get_controller."""
+ controller = get_controller(single_instance)
+ assert isinstance(controller, controller_class)
diff --git a/tests/aiet/test_backend_tool.py b/tests/aiet/test_backend_tool.py
new file mode 100644
index 0000000..fd5960d
--- /dev/null
+++ b/tests/aiet/test_backend_tool.py
@@ -0,0 +1,60 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+# pylint: disable=no-self-use
+"""Tests for the tool backend."""
+from collections import Counter
+
+import pytest
+
+from aiet.backend.common import ConfigurationException
+from aiet.backend.config import ToolConfig
+from aiet.backend.tool import get_available_tool_directory_names
+from aiet.backend.tool import get_available_tools
+from aiet.backend.tool import get_tool
+from aiet.backend.tool import Tool
+
+
+def test_get_available_tool_directory_names() -> None:
+ """Test get_available_tools mocking get_resources."""
+ directory_names = get_available_tool_directory_names()
+ assert Counter(directory_names) == Counter(["tool1", "tool2", "vela"])
+
+
+def test_get_available_tools() -> None:
+ """Test get_available_tools mocking get_resources."""
+ available_tools = get_available_tools()
+ expected_tool_names = sorted(
+ [
+ "tool_1",
+ "tool_2",
+ "vela",
+ "vela",
+ "vela",
+ ]
+ )
+
+ assert all(isinstance(s, Tool) for s in available_tools)
+ assert all(s != 42 for s in available_tools)
+ assert any(s == available_tools[0] for s in available_tools)
+ assert len(available_tools) == len(expected_tool_names)
+ available_tool_names = sorted(str(s) for s in available_tools)
+ assert available_tool_names == expected_tool_names
+
+
+def test_get_tool() -> None:
+ """Test get_tool mocking get_resoures."""
+ tools = get_tool("tool_1")
+ assert len(tools) == 1
+ tool = tools[0]
+ assert tool is not None
+ assert isinstance(tool, Tool)
+ assert tool.name == "tool_1"
+
+ tools = get_tool("unknown tool")
+ assert not tools
+
+
+def test_tool_creation() -> None:
+ """Test edge cases when creating a Tool instance."""
+ with pytest.raises(ConfigurationException):
+ Tool(ToolConfig(name="test", commands={"test": []})) # no 'run' command
diff --git a/tests/aiet/test_check_model.py b/tests/aiet/test_check_model.py
new file mode 100644
index 0000000..4eafe59
--- /dev/null
+++ b/tests/aiet/test_check_model.py
@@ -0,0 +1,162 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+# pylint: disable=redefined-outer-name,no-self-use
+"""Module for testing check_model.py script."""
+from pathlib import Path
+from typing import Any
+
+import pytest
+from ethosu.vela.tflite.Model import Model
+from ethosu.vela.tflite.OperatorCode import OperatorCode
+
+from aiet.cli.common import InvalidTFLiteFileError
+from aiet.cli.common import ModelOptimisedException
+from aiet.resources.tools.vela.check_model import check_custom_codes_for_ethosu
+from aiet.resources.tools.vela.check_model import check_model
+from aiet.resources.tools.vela.check_model import get_custom_codes_from_operators
+from aiet.resources.tools.vela.check_model import get_model_from_file
+from aiet.resources.tools.vela.check_model import get_operators_from_model
+from aiet.resources.tools.vela.check_model import is_vela_optimised
+
+
+@pytest.fixture(scope="session")
+def optimised_tflite_model(
+ optimised_input_model_file: Path,
+) -> Model:
+ """Return Model instance read from a Vela-optimised TFLite file."""
+ return get_model_from_file(optimised_input_model_file)
+
+
+@pytest.fixture(scope="session")
+def non_optimised_tflite_model(
+ non_optimised_input_model_file: Path,
+) -> Model:
+ """Return Model instance read from a Vela-optimised TFLite file."""
+ return get_model_from_file(non_optimised_input_model_file)
+
+
+class TestIsVelaOptimised:
+ """Test class for is_vela_optimised() function."""
+
+ def test_return_true_when_input_is_optimised(
+ self,
+ optimised_tflite_model: Model,
+ ) -> None:
+ """Verify True returned when input is optimised model."""
+ output = is_vela_optimised(optimised_tflite_model)
+
+ assert output is True
+
+ def test_return_false_when_input_is_not_optimised(
+ self,
+ non_optimised_tflite_model: Model,
+ ) -> None:
+ """Verify False returned when input is non-optimised model."""
+ output = is_vela_optimised(non_optimised_tflite_model)
+
+ assert output is False
+
+
+def test_get_operator_list_returns_correct_instances(
+ optimised_tflite_model: Model,
+) -> None:
+ """Verify list of OperatorCode instances returned by get_operator_list()."""
+ operator_list = get_operators_from_model(optimised_tflite_model)
+
+ assert all(isinstance(operator, OperatorCode) for operator in operator_list)
+
+
+class TestGetCustomCodesFromOperators:
+ """Test the get_custom_codes_from_operators() function."""
+
+ def test_returns_empty_list_when_input_operators_have_no_custom_codes(
+ self, monkeypatch: Any
+ ) -> None:
+ """Verify function returns empty list when operators have no custom codes."""
+ # Mock OperatorCode.CustomCode() function to return None
+ monkeypatch.setattr(
+ "ethosu.vela.tflite.OperatorCode.OperatorCode.CustomCode", lambda _: None
+ )
+
+ operators = [OperatorCode()] * 3
+
+ custom_codes = get_custom_codes_from_operators(operators)
+
+ assert custom_codes == []
+
+ def test_returns_custom_codes_when_input_operators_have_custom_codes(
+ self, monkeypatch: Any
+ ) -> None:
+ """Verify list of bytes objects returned representing the CustomCodes."""
+ # Mock OperatorCode.CustomCode() function to return a byte string
+ monkeypatch.setattr(
+ "ethosu.vela.tflite.OperatorCode.OperatorCode.CustomCode",
+ lambda _: b"custom-code",
+ )
+
+ operators = [OperatorCode()] * 3
+
+ custom_codes = get_custom_codes_from_operators(operators)
+
+ assert custom_codes == [b"custom-code", b"custom-code", b"custom-code"]
+
+
+@pytest.mark.parametrize(
+ "custom_codes, expected_output",
+ [
+ ([b"ethos-u", b"something else"], True),
+ ([b"custom-code-1", b"custom-code-2"], False),
+ ],
+)
+def test_check_list_for_ethosu(custom_codes: list, expected_output: bool) -> None:
+ """Verify function detects 'ethos-u' bytes in the input list."""
+ output = check_custom_codes_for_ethosu(custom_codes)
+ assert output is expected_output
+
+
+class TestGetModelFromFile:
+ """Test the get_model_from_file() function."""
+
+ def test_error_raised_when_input_is_invalid_model_file(
+ self,
+ invalid_input_model_file: Path,
+ ) -> None:
+ """Verify error thrown when an invalid model file is given."""
+ with pytest.raises(InvalidTFLiteFileError):
+ get_model_from_file(invalid_input_model_file)
+
+ def test_model_instance_returned_when_input_is_valid_model_file(
+ self,
+ optimised_input_model_file: Path,
+ ) -> None:
+ """Verify file is read successfully and returns model instance."""
+ tflite_model = get_model_from_file(optimised_input_model_file)
+
+ assert isinstance(tflite_model, Model)
+
+
+class TestCheckModel:
+ """Test the check_model() function."""
+
+ def test_check_model_with_non_optimised_input(
+ self,
+ non_optimised_input_model_file: Path,
+ ) -> None:
+ """Verify no error occurs for a valid input file."""
+ check_model(non_optimised_input_model_file)
+
+ def test_check_model_with_optimised_input(
+ self,
+ optimised_input_model_file: Path,
+ ) -> None:
+ """Verify that the right exception is raised with already optimised input."""
+ with pytest.raises(ModelOptimisedException):
+ check_model(optimised_input_model_file)
+
+ def test_check_model_with_invalid_input(
+ self,
+ invalid_input_model_file: Path,
+ ) -> None:
+ """Verify that an exception is raised with invalid input."""
+ with pytest.raises(Exception):
+ check_model(invalid_input_model_file)
diff --git a/tests/aiet/test_cli.py b/tests/aiet/test_cli.py
new file mode 100644
index 0000000..e8589fa
--- /dev/null
+++ b/tests/aiet/test_cli.py
@@ -0,0 +1,37 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Module for testing CLI top command."""
+from typing import Any
+from unittest.mock import ANY
+from unittest.mock import MagicMock
+
+from click.testing import CliRunner
+
+from aiet.cli import cli
+
+
+def test_cli(cli_runner: CliRunner) -> None:
+ """Test CLI top level command."""
+ result = cli_runner.invoke(cli)
+ assert result.exit_code == 0
+ assert "system" in cli.commands
+ assert "application" in cli.commands
+
+
+def test_cli_version(cli_runner: CliRunner) -> None:
+ """Test version option."""
+ result = cli_runner.invoke(cli, ["--version"])
+ assert result.exit_code == 0
+ assert "version" in result.output
+
+
+def test_cli_verbose(cli_runner: CliRunner, monkeypatch: Any) -> None:
+ """Test verbose option."""
+ with monkeypatch.context() as mock_context:
+ mock = MagicMock()
+ # params[1] is the verbose option and we need to replace the
+ # callback with a mock object
+ mock_context.setattr(cli.params[1], "callback", mock)
+ cli_runner.invoke(cli, ["-vvvv"])
+ # 4 is the number of times -v was passed above
+ mock.assert_called_once_with(ANY, ANY, 4)
diff --git a/tests/aiet/test_cli_application.py b/tests/aiet/test_cli_application.py
new file mode 100644
index 0000000..f1ccc44
--- /dev/null
+++ b/tests/aiet/test_cli_application.py
@@ -0,0 +1,1153 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+# pylint: disable=attribute-defined-outside-init,no-member,line-too-long,too-many-arguments,too-many-locals,redefined-outer-name,too-many-lines
+"""Module for testing CLI application subcommand."""
+import base64
+import json
+import re
+import time
+from contextlib import contextmanager
+from contextlib import ExitStack
+from pathlib import Path
+from typing import Any
+from typing import Generator
+from typing import IO
+from typing import List
+from typing import Optional
+from typing import TypedDict
+from unittest.mock import MagicMock
+
+import click
+import pytest
+from click.testing import CliRunner
+from filelock import FileLock
+
+from aiet.backend.application import Application
+from aiet.backend.config import ApplicationConfig
+from aiet.backend.config import LocalProtocolConfig
+from aiet.backend.config import SSHConfig
+from aiet.backend.config import SystemConfig
+from aiet.backend.config import UserParamConfig
+from aiet.backend.output_parser import Base64OutputParser
+from aiet.backend.protocol import SSHProtocol
+from aiet.backend.system import load_system
+from aiet.cli.application import application_cmd
+from aiet.cli.application import details_cmd
+from aiet.cli.application import execute_cmd
+from aiet.cli.application import install_cmd
+from aiet.cli.application import list_cmd
+from aiet.cli.application import parse_payload_run_config
+from aiet.cli.application import remove_cmd
+from aiet.cli.application import run_cmd
+from aiet.cli.common import MiddlewareExitCode
+
+
+def test_application_cmd() -> None:
+ """Test application commands."""
+ commands = ["list", "details", "install", "remove", "execute", "run"]
+ assert all(command in application_cmd.commands for command in commands)
+
+
+@pytest.mark.parametrize("format_", ["json", "cli"])
+def test_application_cmd_context(cli_runner: CliRunner, format_: str) -> None:
+ """Test setting command context parameters."""
+ result = cli_runner.invoke(application_cmd, ["--format", format_])
+ # command should fail if no subcommand provided
+ assert result.exit_code == 2
+
+ result = cli_runner.invoke(application_cmd, ["--format", format_, "list"])
+ assert result.exit_code == 0
+
+
+@pytest.mark.parametrize(
+ "format_, system_name, expected_output",
+ [
+ (
+ "json",
+ None,
+ '{"type": "application", "available": ["application_1", "application_2"]}\n',
+ ),
+ (
+ "json",
+ "system_1",
+ '{"type": "application", "available": ["application_1"]}\n',
+ ),
+ ("cli", None, "Available applications:\n\napplication_1\napplication_2\n"),
+ ("cli", "system_1", "Available applications:\n\napplication_1\n"),
+ ],
+)
+def test_list_cmd(
+ cli_runner: CliRunner,
+ monkeypatch: Any,
+ format_: str,
+ system_name: str,
+ expected_output: str,
+) -> None:
+ """Test available applications commands."""
+ # Mock some applications
+ mock_application_1 = MagicMock(spec=Application)
+ mock_application_1.name = "application_1"
+ mock_application_1.can_run_on.return_value = system_name == "system_1"
+ mock_application_2 = MagicMock(spec=Application)
+ mock_application_2.name = "application_2"
+ mock_application_2.can_run_on.return_value = system_name == "system_2"
+
+ # Monkey patch the call get_available_applications
+ mock_available_applications = MagicMock()
+ mock_available_applications.return_value = [mock_application_1, mock_application_2]
+
+ monkeypatch.setattr(
+ "aiet.backend.application.get_available_applications",
+ mock_available_applications,
+ )
+
+ obj = {"format": format_}
+ args = []
+ if system_name:
+ list_cmd.params[0].type = click.Choice([system_name])
+ args = ["--system", system_name]
+ result = cli_runner.invoke(list_cmd, obj=obj, args=args)
+ assert result.output == expected_output
+
+
+def get_test_application() -> Application:
+ """Return test system details."""
+ config = ApplicationConfig(
+ name="application",
+ description="test",
+ build_dir="",
+ supported_systems=[],
+ deploy_data=[],
+ user_params={},
+ commands={
+ "clean": ["clean"],
+ "build": ["build"],
+ "run": ["run"],
+ "post_run": ["post_run"],
+ },
+ )
+
+ return Application(config)
+
+
+def get_details_cmd_json_output() -> str:
+ """Get JSON output for details command."""
+ json_output = """
+[
+ {
+ "type": "application",
+ "name": "application",
+ "description": "test",
+ "supported_systems": [],
+ "commands": {
+ "clean": {
+ "command_strings": [
+ "clean"
+ ],
+ "user_params": []
+ },
+ "build": {
+ "command_strings": [
+ "build"
+ ],
+ "user_params": []
+ },
+ "run": {
+ "command_strings": [
+ "run"
+ ],
+ "user_params": []
+ },
+ "post_run": {
+ "command_strings": [
+ "post_run"
+ ],
+ "user_params": []
+ }
+ }
+ }
+]"""
+ return json.dumps(json.loads(json_output)) + "\n"
+
+
+def get_details_cmd_console_output() -> str:
+ """Get console output for details command."""
+ return (
+ 'Application "application" details'
+ + "\nDescription: test"
+ + "\n\nSupported systems: "
+ + "\n\nclean commands:"
+ + "\nCommands: ['clean']"
+ + "\n\nbuild commands:"
+ + "\nCommands: ['build']"
+ + "\n\nrun commands:"
+ + "\nCommands: ['run']"
+ + "\n\npost_run commands:"
+ + "\nCommands: ['post_run']"
+ + "\n"
+ )
+
+
+@pytest.mark.parametrize(
+ "application_name,format_, expected_output",
+ [
+ ("application", "json", get_details_cmd_json_output()),
+ ("application", "cli", get_details_cmd_console_output()),
+ ],
+)
+def test_details_cmd(
+ cli_runner: CliRunner,
+ monkeypatch: Any,
+ application_name: str,
+ format_: str,
+ expected_output: str,
+) -> None:
+ """Test application details command."""
+ monkeypatch.setattr(
+ "aiet.cli.application.get_application",
+ MagicMock(return_value=[get_test_application()]),
+ )
+
+ details_cmd.params[0].type = click.Choice(["application"])
+ result = cli_runner.invoke(
+ details_cmd, obj={"format": format_}, args=["--name", application_name]
+ )
+ assert result.exception is None
+ assert result.output == expected_output
+
+
+def test_details_cmd_wrong_system(cli_runner: CliRunner, monkeypatch: Any) -> None:
+ """Test details command fails if application is not supported by the system."""
+ monkeypatch.setattr(
+ "aiet.backend.execution.get_application", MagicMock(return_value=[])
+ )
+
+ details_cmd.params[0].type = click.Choice(["application"])
+ details_cmd.params[1].type = click.Choice(["system"])
+ result = cli_runner.invoke(
+ details_cmd, args=["--name", "application", "--system", "system"]
+ )
+ assert result.exit_code == 2
+ assert (
+ "Application 'application' doesn't support the system 'system'" in result.stdout
+ )
+
+
+def test_install_cmd(cli_runner: CliRunner, monkeypatch: Any) -> None:
+ """Test install application command."""
+ mock_install_application = MagicMock()
+ monkeypatch.setattr(
+ "aiet.cli.application.install_application", mock_install_application
+ )
+
+ args = ["--source", "test"]
+ cli_runner.invoke(install_cmd, args=args)
+ mock_install_application.assert_called_once_with(Path("test"))
+
+
+def test_remove_cmd(cli_runner: CliRunner, monkeypatch: Any) -> None:
+ """Test remove application command."""
+ mock_remove_application = MagicMock()
+ monkeypatch.setattr(
+ "aiet.cli.application.remove_application", mock_remove_application
+ )
+ remove_cmd.params[0].type = click.Choice(["test"])
+
+ args = ["--directory_name", "test"]
+ cli_runner.invoke(remove_cmd, args=args)
+ mock_remove_application.assert_called_once_with("test")
+
+
+class ExecutionCase(TypedDict, total=False):
+ """Execution case."""
+
+ args: List[str]
+ lock_path: str
+ can_establish_connection: bool
+ establish_connection_delay: int
+ app_exit_code: int
+ exit_code: int
+ output: str
+
+
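+# Each entry below pairs an application config and a system config with a list
+# of ExecutionCases, i.e. the CLI arguments to run plus the exit code and
+# output expected for that combination.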
+@pytest.mark.parametrize(
+ "application_config, system_config, executions",
+ [
+ [
+ ApplicationConfig(
+ name="test_application",
+ description="Test application",
+ supported_systems=["test_system"],
+ config_location=Path("wrong_location"),
+ commands={"build": ["echo build {application.name}"]},
+ ),
+ SystemConfig(
+ name="test_system",
+ description="Test system",
+ data_transfer=LocalProtocolConfig(protocol="local"),
+ config_location=Path("wrong_location"),
+ commands={"run": ["echo run {application.name} on {system.name}"]},
+ ),
+ [
+ ExecutionCase(
+ args=["-c", "build"],
+ exit_code=MiddlewareExitCode.CONFIGURATION_ERROR,
+ output="Error: Application test_application has wrong config location\n",
+ )
+ ],
+ ],
+ [
+ ApplicationConfig(
+ name="test_application",
+ description="Test application",
+ supported_systems=["test_system"],
+ build_dir="build",
+ deploy_data=[("sample_file", "/tmp/sample_file")],
+ commands={"build": ["echo build {application.name}"]},
+ ),
+ SystemConfig(
+ name="test_system",
+ description="Test system",
+ data_transfer=LocalProtocolConfig(protocol="local"),
+ commands={"run": ["echo run {application.name} on {system.name}"]},
+ ),
+ [
+ ExecutionCase(
+ args=["-c", "run"],
+ exit_code=MiddlewareExitCode.CONFIGURATION_ERROR,
+ output="Error: System test_system does not support data deploy\n",
+ )
+ ],
+ ],
+ [
+ ApplicationConfig(
+ name="test_application",
+ description="Test application",
+ supported_systems=["test_system"],
+ commands={"build": ["echo build {application.name}"]},
+ ),
+ SystemConfig(
+ name="test_system",
+ description="Test system",
+ data_transfer=LocalProtocolConfig(protocol="local"),
+ commands={"run": ["echo run {application.name} on {system.name}"]},
+ ),
+ [
+ ExecutionCase(
+ args=["-c", "build"],
+ exit_code=MiddlewareExitCode.CONFIGURATION_ERROR,
+ output="Error: No build directory defined for the app test_application\n",
+ )
+ ],
+ ],
+ [
+ ApplicationConfig(
+ name="test_application",
+ description="Test application",
+ supported_systems=["new_system"],
+ build_dir="build",
+ commands={
+ "build": ["echo build {application.name} with {user_params:0}"]
+ },
+ user_params={
+ "build": [
+ UserParamConfig(
+ name="param",
+ description="sample parameter",
+ default_value="default",
+ values=["val1", "val2", "val3"],
+ )
+ ]
+ },
+ ),
+ SystemConfig(
+ name="test_system",
+ description="Test system",
+ data_transfer=LocalProtocolConfig(protocol="local"),
+ commands={"run": ["echo run {application.name} on {system.name}"]},
+ ),
+ [
+ ExecutionCase(
+ args=["-c", "build"],
+ exit_code=1,
+ output="Error: Application 'test_application' doesn't support the system 'test_system'\n",
+ )
+ ],
+ ],
+ [
+ ApplicationConfig(
+ name="test_application",
+ description="Test application",
+ supported_systems=["test_system"],
+ build_dir="build",
+ commands={"build": ["false"]},
+ ),
+ SystemConfig(
+ name="test_system",
+ description="Test system",
+ data_transfer=LocalProtocolConfig(protocol="local"),
+ commands={"run": ["echo run {application.name} on {system.name}"]},
+ ),
+ [
+ ExecutionCase(
+ args=["-c", "build"],
+ exit_code=MiddlewareExitCode.BACKEND_ERROR,
+ output="""Running: false
+Error: Execution failed. Please check output for the details.\n""",
+ )
+ ],
+ ],
+ [
+ ApplicationConfig(
+ name="test_application",
+ description="Test application",
+ supported_systems=["test_system"],
+ lock=True,
+ build_dir="build",
+ commands={
+ "build": ["echo build {application.name} with {user_params:0}"]
+ },
+ user_params={
+ "build": [
+ UserParamConfig(
+ name="param",
+ description="sample parameter",
+ default_value="default",
+ values=["val1", "val2", "val3"],
+ )
+ ]
+ },
+ ),
+ SystemConfig(
+ name="test_system",
+ description="Test system",
+ lock=True,
+ data_transfer=LocalProtocolConfig(protocol="local"),
+ commands={"run": ["echo run {application.name} on {system.name}"]},
+ ),
+ [
+ ExecutionCase(
+ args=["-c", "build"],
+ exit_code=MiddlewareExitCode.SUCCESS,
+ output="""Running: echo build test_application with param default
+build test_application with param default\n""",
+ ),
+ ExecutionCase(
+ args=["-c", "build"],
+ lock_path="/tmp/middleware_test_application_test_system.lock",
+ exit_code=MiddlewareExitCode.CONCURRENT_ERROR,
+ output="Error: Another instance of the system is running\n",
+ ),
+ ExecutionCase(
+ args=["-c", "build", "--param=param=val3"],
+ exit_code=MiddlewareExitCode.SUCCESS,
+ output="""Running: echo build test_application with param val3
+build test_application with param val3\n""",
+ ),
+ ExecutionCase(
+ args=["-c", "build", "--param=param=newval"],
+ exit_code=1,
+ output="Error: Application parameter 'param=newval' not valid for command 'build'\n",
+ ),
+ ExecutionCase(
+ args=["-c", "some_command"],
+ exit_code=MiddlewareExitCode.CONFIGURATION_ERROR,
+ output="Error: Unsupported command some_command\n",
+ ),
+ ExecutionCase(
+ args=["-c", "run"],
+ exit_code=MiddlewareExitCode.SUCCESS,
+ output="""Generating commands to execute
+Running: echo run test_application on test_system
+run test_application on test_system\n""",
+ ),
+ ],
+ ],
+ [
+ ApplicationConfig(
+ name="test_application",
+ description="Test application",
+ supported_systems=["test_system"],
+ deploy_data=[("sample_file", "/tmp/sample_file")],
+ commands={
+ "run": [
+ "echo run {application.name} with {user_params:param} on {system.name}"
+ ]
+ },
+ user_params={
+ "run": [
+ UserParamConfig(
+ name="param=",
+ description="sample parameter",
+ default_value="default",
+ values=["val1", "val2", "val3"],
+ alias="param",
+ )
+ ]
+ },
+ ),
+ SystemConfig(
+ name="test_system",
+ description="Test system",
+ lock=True,
+ data_transfer=SSHConfig(
+ protocol="ssh",
+ username="username",
+ password="password",
+ hostname="localhost",
+ port="8022",
+ ),
+ commands={"run": ["sleep 100"]},
+ ),
+ [
+ ExecutionCase(
+ args=["-c", "run"],
+ exit_code=MiddlewareExitCode.SUCCESS,
+ output="""Generating commands to execute
+Trying to establish connection with 'localhost:8022' - 90 retries every 15.0 seconds .
+Deploying {application.config_location}/sample_file onto /tmp/sample_file
+Running: echo run test_application with param=default on test_system
+Shutting down sequence...
+Stopping test_system... (It could take few seconds)
+test_system stopped successfully.\n""",
+ ),
+ ExecutionCase(
+ args=["-c", "run"],
+ lock_path="/tmp/middleware_test_system.lock",
+ exit_code=MiddlewareExitCode.CONCURRENT_ERROR,
+ output="Error: Another instance of the system is running\n",
+ ),
+ ExecutionCase(
+ args=[
+ "-c",
+ "run",
+ "--deploy={application.config_location}/sample_file:/tmp/sample_file",
+ ],
+ exit_code=0,
+ output="""Generating commands to execute
+Trying to establish connection with 'localhost:8022' - 90 retries every 15.0 seconds .
+Deploying {application.config_location}/sample_file onto /tmp/sample_file
+Deploying {application.config_location}/sample_file onto /tmp/sample_file
+Running: echo run test_application with param=default on test_system
+Shutting down sequence...
+Stopping test_system... (It could take few seconds)
+test_system stopped successfully.\n""",
+ ),
+ ExecutionCase(
+ args=["-c", "run"],
+ app_exit_code=1,
+ exit_code=0,
+ output="""Generating commands to execute
+Trying to establish connection with 'localhost:8022' - 90 retries every 15.0 seconds .
+Deploying {application.config_location}/sample_file onto /tmp/sample_file
+Running: echo run test_application with param=default on test_system
+Application exited with exit code 1
+Shutting down sequence...
+Stopping test_system... (It could take few seconds)
+test_system stopped successfully.\n""",
+ ),
+ ExecutionCase(
+ args=["-c", "run"],
+ exit_code=MiddlewareExitCode.CONNECTION_ERROR,
+ can_establish_connection=False,
+ output="""Generating commands to execute
+Trying to establish connection with 'localhost:8022' - 90 retries every 15.0 seconds ..........................................................................................
+Shutting down sequence...
+Stopping test_system... (It could take few seconds)
+test_system stopped successfully.
+Error: Couldn't connect to 'localhost:8022'.\n""",
+ ),
+ ExecutionCase(
+ args=["-c", "run", "--deploy=bad_format"],
+ exit_code=1,
+ output="Error: Invalid deploy parameter 'bad_format' for command run\n",
+ ),
+ ExecutionCase(
+ args=["-c", "run", "--deploy=:"],
+ exit_code=1,
+ output="Error: Invalid deploy parameter ':' for command run\n",
+ ),
+ ExecutionCase(
+ args=["-c", "run", "--deploy= : "],
+ exit_code=1,
+ output="Error: Invalid deploy parameter ' : ' for command run\n",
+ ),
+ ExecutionCase(
+ args=["-c", "run", "--deploy=some_src_file:"],
+ exit_code=1,
+ output="Error: Invalid deploy parameter 'some_src_file:' for command run\n",
+ ),
+ ExecutionCase(
+ args=["-c", "run", "--deploy=:some_dst_file"],
+ exit_code=1,
+ output="Error: Invalid deploy parameter ':some_dst_file' for command run\n",
+ ),
+ ExecutionCase(
+ args=["-c", "run", "--deploy=unknown_file:/tmp/dest"],
+ exit_code=1,
+ output="Error: Path unknown_file does not exist\n",
+ ),
+ ],
+ ],
+ [
+ ApplicationConfig(
+ name="test_application",
+ description="Test application",
+ supported_systems=["test_system"],
+ commands={
+ "run": [
+ "echo run {application.name} with {user_params:param} on {system.name}"
+ ]
+ },
+ user_params={
+ "run": [
+ UserParamConfig(
+ name="param=",
+ description="sample parameter",
+ default_value="default",
+ values=["val1", "val2", "val3"],
+ alias="param",
+ )
+ ]
+ },
+ ),
+ SystemConfig(
+ name="test_system",
+ description="Test system",
+ data_transfer=SSHConfig(
+ protocol="ssh",
+ username="username",
+ password="password",
+ hostname="localhost",
+ port="8022",
+ ),
+ commands={"run": ["echo Unable to start system"]},
+ ),
+ [
+ ExecutionCase(
+ args=["-c", "run"],
+ exit_code=4,
+ can_establish_connection=False,
+ establish_connection_delay=1,
+ output="""Generating commands to execute
+Trying to establish connection with 'localhost:8022' - 90 retries every 15.0 seconds .
+
+---------- test_system execution failed ----------
+Unable to start system
+
+
+
+Shutting down sequence...
+Stopping test_system... (It could take few seconds)
+test_system stopped successfully.
+Error: Execution failed. Please check output for the details.\n""",
+ )
+ ],
+ ],
+ ],
+)
+def test_application_command_execution(
+ application_config: ApplicationConfig,
+ system_config: SystemConfig,
+ executions: List[ExecutionCase],
+ tmpdir: Any,
+ cli_runner: CliRunner,
+ monkeypatch: Any,
+) -> None:
+ """Test application command execution."""
+
+ @contextmanager
+ def lock_execution(lock_path: str) -> Generator[None, None, None]:
+ lock = FileLock(lock_path)
+ lock.acquire(timeout=1)
+
+ try:
+ yield
+ finally:
+ lock.release()
+
+ def replace_vars(str_val: str) -> str:
+ """Replace variables."""
+ application_config_location = str(
+ application_config["config_location"].absolute()
+ )
+
+ return str_val.replace(
+ "{application.config_location}", application_config_location
+ )
+
+ for execution in executions:
+ init_execution_test(
+ monkeypatch,
+ tmpdir,
+ application_config,
+ system_config,
+ can_establish_connection=execution.get("can_establish_connection", True),
+            establish_connection_delay=execution.get("establish_connection_delay", 0),
+ remote_app_exit_code=execution.get("app_exit_code", 0),
+ )
+
+ lock_path = execution.get("lock_path")
+
+ with ExitStack() as stack:
+ if lock_path:
+ stack.enter_context(lock_execution(lock_path))
+
+ args = [replace_vars(arg) for arg in execution["args"]]
+
+ result = cli_runner.invoke(
+ execute_cmd,
+ args=["-n", application_config["name"], "-s", system_config["name"]]
+ + args,
+ )
+ output = replace_vars(execution["output"])
+ assert result.exit_code == execution["exit_code"]
+ assert result.stdout == output
+
+
+@pytest.fixture(params=[False, True], ids=["run-cli", "run-json"])
+def payload_path_or_none(request: Any, tmp_path_factory: Any) -> Optional[Path]:
+    """Drive the run command tests so they are executed both via a JSON config file and via the CLI."""
+ if request.param:
+ ret: Path = tmp_path_factory.getbasetemp() / "system_config_payload_file.json"
+ return ret
+ return None
+
+
+def write_system_payload_config(
+ payload_file: IO[str],
+ application_config: ApplicationConfig,
+ system_config: SystemConfig,
+) -> None:
+ """Write a json payload file for the given test configuration."""
+ payload_dict = {
+ "id": system_config["name"],
+ "arguments": {
+ "application": application_config["name"],
+ },
+ }
+ json.dump(payload_dict, payload_file)
+
+
+@pytest.mark.parametrize(
+ "application_config, system_config, executions",
+ [
+ [
+ ApplicationConfig(
+ name="test_application",
+ description="Test application",
+ supported_systems=["test_system"],
+ build_dir="build",
+ commands={
+ "build": ["echo build {application.name} with {user_params:0}"]
+ },
+ user_params={
+ "build": [
+ UserParamConfig(
+ name="param",
+ description="sample parameter",
+ default_value="default",
+ values=["val1", "val2", "val3"],
+ )
+ ]
+ },
+ ),
+ SystemConfig(
+ name="test_system",
+ description="Test system",
+ data_transfer=LocalProtocolConfig(protocol="local"),
+ commands={"run": ["echo run {application.name} on {system.name}"]},
+ ),
+ [
+ ExecutionCase(
+ args=[],
+ exit_code=MiddlewareExitCode.SUCCESS,
+ output="""Running: echo build test_application with param default
+build test_application with param default
+Generating commands to execute
+Running: echo run test_application on test_system
+run test_application on test_system\n""",
+ )
+ ],
+ ],
+ [
+ ApplicationConfig(
+ name="test_application",
+ description="Test application",
+ supported_systems=["test_system"],
+ commands={
+ "run": [
+ "echo run {application.name} with {user_params:param} on {system.name}"
+ ]
+ },
+ user_params={
+ "run": [
+ UserParamConfig(
+ name="param=",
+ description="sample parameter",
+ default_value="default",
+ values=["val1", "val2", "val3"],
+ alias="param",
+ )
+ ]
+ },
+ ),
+ SystemConfig(
+ name="test_system",
+ description="Test system",
+ data_transfer=SSHConfig(
+ protocol="ssh",
+ username="username",
+ password="password",
+ hostname="localhost",
+ port="8022",
+ ),
+ commands={"run": ["sleep 100"]},
+ ),
+ [
+ ExecutionCase(
+ args=[],
+ exit_code=MiddlewareExitCode.SUCCESS,
+ output="""Generating commands to execute
+Trying to establish connection with 'localhost:8022' - 90 retries every 15.0 seconds .
+Running: echo run test_application with param=default on test_system
+Shutting down sequence...
+Stopping test_system... (It could take few seconds)
+test_system stopped successfully.\n""",
+ )
+ ],
+ ],
+ ],
+)
+def test_application_run(
+ application_config: ApplicationConfig,
+ system_config: SystemConfig,
+ executions: List[ExecutionCase],
+ tmpdir: Any,
+ cli_runner: CliRunner,
+ monkeypatch: Any,
+ payload_path_or_none: Path,
+) -> None:
+    """Test the application run command."""
+ for execution in executions:
+ init_execution_test(monkeypatch, tmpdir, application_config, system_config)
+
+ if payload_path_or_none:
+ with open(payload_path_or_none, "w", encoding="utf-8") as payload_file:
+ write_system_payload_config(
+ payload_file, application_config, system_config
+ )
+
+ result = cli_runner.invoke(
+ run_cmd,
+ args=["--config", str(payload_path_or_none)],
+ )
+ else:
+ result = cli_runner.invoke(
+ run_cmd,
+ args=["-n", application_config["name"], "-s", system_config["name"]]
+ + execution["args"],
+ )
+
+ assert result.stdout == execution["output"]
+ assert result.exit_code == execution["exit_code"]
+
+
+@pytest.mark.parametrize(
+ "cmdline,error_pattern",
+ [
+ [
+ "--config {payload} -s test_system",
+ "when --config is set, the following parameters should not be provided",
+ ],
+ [
+ "--config {payload} -n test_application",
+ "when --config is set, the following parameters should not be provided",
+ ],
+ [
+ "--config {payload} -p mypar:3",
+ "when --config is set, the following parameters should not be provided",
+ ],
+ [
+ "-p mypar:3",
+ "when --config is not set, the following parameters are required",
+ ],
+ ["-s test_system", "when --config is not set, --name is required"],
+ ["-n test_application", "when --config is not set, --system is required"],
+ ],
+)
+def test_application_run_invalid_param_combinations(
+ cmdline: str,
+ error_pattern: str,
+ cli_runner: CliRunner,
+ monkeypatch: Any,
+ tmp_path: Any,
+ tmpdir: Any,
+) -> None:
+ """Test that invalid combinations arguments result in error as expected."""
+ application_config = ApplicationConfig(
+ name="test_application",
+ description="Test application",
+ supported_systems=["test_system"],
+ build_dir="build",
+ commands={"build": ["echo build {application.name} with {user_params:0}"]},
+ user_params={
+ "build": [
+ UserParamConfig(
+ name="param",
+ description="sample parameter",
+ default_value="default",
+ values=["val1", "val2", "val3"],
+ )
+ ]
+ },
+ )
+ system_config = SystemConfig(
+ name="test_system",
+ description="Test system",
+ data_transfer=LocalProtocolConfig(protocol="local"),
+ commands={"run": ["echo run {application.name} on {system.name}"]},
+ )
+
+ init_execution_test(monkeypatch, tmpdir, application_config, system_config)
+
+ payload_file = tmp_path / "payload.json"
+ payload_file.write_text("dummy")
+ result = cli_runner.invoke(
+ run_cmd,
+ args=cmdline.format(payload=payload_file).split(),
+ )
+ found = re.search(error_pattern, result.stdout)
+ assert found, f"Cannot find pattern: [{error_pattern}] in \n[\n{result.stdout}\n]"
+
+
+@pytest.mark.parametrize(
+ "payload,expected",
+ [
+ pytest.param(
+ {"arguments": {}},
+ None,
+            marks=pytest.mark.xfail(reason="no system 'id'", strict=True),
+ ),
+ pytest.param(
+ {"id": "testsystem"},
+ None,
+ marks=pytest.mark.xfail(reason="no arguments object", strict=True),
+ ),
+ (
+ {"id": "testsystem", "arguments": {"application": "testapp"}},
+ ("testsystem", "testapp", [], [], [], None),
+ ),
+ (
+ {
+ "id": "testsystem",
+ "arguments": {"application": "testapp", "par1": "val1"},
+ },
+ ("testsystem", "testapp", ["par1=val1"], [], [], None),
+ ),
+ (
+ {
+ "id": "testsystem",
+ "arguments": {"application": "testapp", "application/par1": "val1"},
+ },
+ ("testsystem", "testapp", ["par1=val1"], [], [], None),
+ ),
+ (
+ {
+ "id": "testsystem",
+ "arguments": {"application": "testapp", "system/par1": "val1"},
+ },
+ ("testsystem", "testapp", [], ["par1=val1"], [], None),
+ ),
+ (
+ {
+ "id": "testsystem",
+ "arguments": {"application": "testapp", "deploy/par1": "val1"},
+ },
+ ("testsystem", "testapp", [], [], ["par1"], None),
+ ),
+ (
+ {
+ "id": "testsystem",
+ "arguments": {
+ "application": "testapp",
+ "appar1": "val1",
+ "application/appar2": "val2",
+ "system/syspar1": "val3",
+ "deploy/depploypar1": "val4",
+ "application/appar3": "val5",
+ "system/syspar2": "val6",
+ "deploy/depploypar2": "val7",
+ },
+ },
+ (
+ "testsystem",
+ "testapp",
+ ["appar1=val1", "appar2=val2", "appar3=val5"],
+ ["syspar1=val3", "syspar2=val6"],
+ ["depploypar1", "depploypar2"],
+ None,
+ ),
+ ),
+ ],
+)
+def test_parse_payload_run_config(payload: dict, expected: tuple) -> None:
+ """Test parsing of the JSON payload for the run_config command."""
+ assert parse_payload_run_config(payload) == expected
+
+
+def test_application_run_report(
+ tmpdir: Any,
+ cli_runner: CliRunner,
+ monkeypatch: Any,
+) -> None:
+ """Test flag '--report' of command 'application run'."""
+ app_metrics = {"app_metric": 3.14}
+ app_metrics_b64 = base64.b64encode(json.dumps(app_metrics).encode("utf-8"))
+ application_config = ApplicationConfig(
+ name="test_application",
+ description="Test application",
+ supported_systems=["test_system"],
+ build_dir="build",
+ commands={"build": ["echo build {application.name} with {user_params:0}"]},
+ user_params={
+ "build": [
+ UserParamConfig(
+ name="param",
+ description="sample parameter",
+ default_value="default",
+ values=["val1", "val2", "val3"],
+ ),
+ UserParamConfig(
+ name="p2",
+ description="another parameter, not overridden",
+ default_value="the-right-choice",
+ values=["the-right-choice", "the-bad-choice"],
+ ),
+ ]
+ },
+ )
+ system_config = SystemConfig(
+ name="test_system",
+ description="Test system",
+ data_transfer=LocalProtocolConfig(protocol="local"),
+ commands={
+ "run": [
+ "echo run {application.name} on {system.name}",
+ f"echo build <{Base64OutputParser.TAG_NAME}>{app_metrics_b64.decode('utf-8')}</{Base64OutputParser.TAG_NAME}>",
+ ]
+ },
+ reporting={
+ "regex": {
+ "app_name": {
+ "pattern": r"run (.\S*) ",
+ "type": "str",
+ },
+ "sys_name": {
+ "pattern": r"on (.\S*)",
+ "type": "str",
+ },
+ }
+ },
+ )
+ report_file = Path(tmpdir) / "test_report.json"
+ param_val = "param=val1"
+ exit_code = MiddlewareExitCode.SUCCESS
+
+ init_execution_test(monkeypatch, tmpdir, application_config, system_config)
+
+ result = cli_runner.invoke(
+ run_cmd,
+ args=[
+ "-n",
+ application_config["name"],
+ "-s",
+ system_config["name"],
+ "--report",
+ str(report_file),
+ "--param",
+ param_val,
+ ],
+ )
+ assert result.exit_code == exit_code
+ assert report_file.is_file()
+ with open(report_file, "r", encoding="utf-8") as file:
+ report = json.load(file)
+
+ assert report == {
+ "application": {
+ "metrics": {"0": {"app_metric": 3.14}},
+ "name": "test_application",
+ "params": {"param": "val1", "p2": "the-right-choice"},
+ },
+ "system": {
+ "metrics": {"app_name": "test_application", "sys_name": "test_system"},
+ "name": "test_system",
+ "params": {},
+ },
+ }
+
+
+def init_execution_test(
+ monkeypatch: Any,
+ tmpdir: Any,
+ application_config: ApplicationConfig,
+ system_config: SystemConfig,
+ can_establish_connection: bool = True,
+    establish_connection_delay: float = 0,
+ remote_app_exit_code: int = 0,
+) -> None:
+ """Init execution test."""
+ application_name = application_config["name"]
+ system_name = system_config["name"]
+
+ execute_cmd.params[0].type = click.Choice([application_name])
+ execute_cmd.params[1].type = click.Choice([system_name])
+ execute_cmd.params[2].type = click.Choice(["build", "run", "some_command"])
+
+ run_cmd.params[0].type = click.Choice([application_name])
+ run_cmd.params[1].type = click.Choice([system_name])
+
+ if "config_location" not in application_config:
+ application_path = Path(tmpdir) / "application"
+ application_path.mkdir()
+ application_config["config_location"] = application_path
+
+        # This file can be used either as a --deploy parameter value or
+        # as a deploy_data entry in the application configuration.
+ sample_file = application_path / "sample_file"
+ sample_file.touch()
+ monkeypatch.setattr(
+ "aiet.backend.application.get_available_applications",
+ MagicMock(return_value=[Application(application_config)]),
+ )
+
+ ssh_protocol_mock = MagicMock(spec=SSHProtocol)
+
+ def mock_establish_connection() -> bool:
+ """Mock establish connection function."""
+ # give some time for the system to start
+        time.sleep(establish_connection_delay)
+ return can_establish_connection
+
+ ssh_protocol_mock.establish_connection.side_effect = mock_establish_connection
+ ssh_protocol_mock.connection_details.return_value = ("localhost", 8022)
+ ssh_protocol_mock.run.return_value = (
+ remote_app_exit_code,
+ bytearray(),
+ bytearray(),
+ )
+ monkeypatch.setattr(
+ "aiet.backend.protocol.SSHProtocol", MagicMock(return_value=ssh_protocol_mock)
+ )
+
+ if "config_location" not in system_config:
+ system_path = Path(tmpdir) / "system"
+ system_path.mkdir()
+ system_config["config_location"] = system_path
+ monkeypatch.setattr(
+ "aiet.backend.system.get_available_systems",
+ MagicMock(return_value=[load_system(system_config)]),
+ )
+
+ monkeypatch.setattr("aiet.backend.execution.wait", MagicMock())
diff --git a/tests/aiet/test_cli_common.py b/tests/aiet/test_cli_common.py
new file mode 100644
index 0000000..d018e44
--- /dev/null
+++ b/tests/aiet/test_cli_common.py
@@ -0,0 +1,37 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Test for cli common module."""
+from typing import Any
+
+import pytest
+
+from aiet.cli.common import print_command_details
+from aiet.cli.common import raise_exception_at_signal
+
+
+def test_print_command_details(capsys: Any) -> None:
+ """Test print_command_details function."""
+ command = {
+ "command_strings": ["echo test"],
+ "user_params": [
+ {"name": "param_name", "description": "param_description"},
+ {
+ "name": "param_name2",
+ "description": "param_description2",
+ "alias": "alias2",
+ },
+ ],
+ }
+ print_command_details(command)
+ captured = capsys.readouterr()
+ assert "echo test" in captured.out
+ assert "param_name" in captured.out
+ assert "alias2" in captured.out
+
+
+def test_raise_exception_at_signal() -> None:
+ """Test raise_exception_at_signal graceful shutdown."""
+ with pytest.raises(Exception) as err:
+ raise_exception_at_signal(1, "")
+
+ assert str(err.value) == "Middleware shutdown requested"
diff --git a/tests/aiet/test_cli_system.py b/tests/aiet/test_cli_system.py
new file mode 100644
index 0000000..fd39f31
--- /dev/null
+++ b/tests/aiet/test_cli_system.py
@@ -0,0 +1,240 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Module for testing CLI system subcommand."""
+import json
+from pathlib import Path
+from typing import Any
+from typing import Dict
+from typing import List
+from typing import Optional
+from typing import Union
+from unittest.mock import MagicMock
+
+import click
+import pytest
+from click.testing import CliRunner
+
+from aiet.backend.config import SystemConfig
+from aiet.backend.system import load_system
+from aiet.backend.system import System
+from aiet.cli.system import details_cmd
+from aiet.cli.system import install_cmd
+from aiet.cli.system import list_cmd
+from aiet.cli.system import remove_cmd
+from aiet.cli.system import system_cmd
+
+
+def test_system_cmd() -> None:
+ """Test system commands."""
+ commands = ["list", "details", "install", "remove"]
+ assert all(command in system_cmd.commands for command in commands)
+
+
+@pytest.mark.parametrize("format_", ["json", "cli"])
+def test_system_cmd_context(cli_runner: CliRunner, format_: str) -> None:
+ """Test setting command context parameters."""
+ result = cli_runner.invoke(system_cmd, ["--format", format_])
+ # command should fail if no subcommand provided
+ assert result.exit_code == 2
+
+ result = cli_runner.invoke(system_cmd, ["--format", format_, "list"])
+ assert result.exit_code == 0
+
+
+@pytest.mark.parametrize(
+ "format_,expected_output",
+ [
+ ("json", '{"type": "system", "available": ["system1", "system2"]}\n'),
+ ("cli", "Available systems:\n\nsystem1\nsystem2\n"),
+ ],
+)
+def test_list_cmd_with_format(
+ cli_runner: CliRunner, monkeypatch: Any, format_: str, expected_output: str
+) -> None:
+ """Test available systems command with different formats output."""
+ # Mock some systems
+ mock_system1 = MagicMock()
+ mock_system1.name = "system1"
+ mock_system2 = MagicMock()
+ mock_system2.name = "system2"
+
+ # Monkey patch the call get_available_systems
+ mock_available_systems = MagicMock()
+ mock_available_systems.return_value = [mock_system1, mock_system2]
+ monkeypatch.setattr("aiet.cli.system.get_available_systems", mock_available_systems)
+
+ obj = {"format": format_}
+ result = cli_runner.invoke(list_cmd, obj=obj)
+ assert result.output == expected_output
+
+
+def get_test_system(
+ annotations: Optional[Dict[str, Union[str, List[str]]]] = None
+) -> System:
+ """Return test system details."""
+ config = SystemConfig(
+ name="system",
+ description="test",
+ data_transfer={
+ "protocol": "ssh",
+ "username": "root",
+ "password": "root",
+ "hostname": "localhost",
+ "port": "8022",
+ },
+ commands={
+ "clean": ["clean"],
+ "build": ["build"],
+ "run": ["run"],
+ "post_run": ["post_run"],
+ },
+ annotations=annotations or {},
+ )
+
+ return load_system(config)
+
+
+def get_details_cmd_json_output(
+ annotations: Optional[Dict[str, Union[str, List[str]]]] = None
+) -> str:
+    """Return the expected JSON output for the details command."""
+ ann_str = ""
+ if annotations is not None:
+ ann_str = '"annotations":{},'.format(json.dumps(annotations))
+
+ json_output = (
+ """
+{
+ "type": "system",
+ "name": "system",
+ "description": "test",
+ "data_transfer_protocol": "ssh",
+ "commands": {
+ "clean":
+ {
+ "command_strings": ["clean"],
+ "user_params": []
+ },
+ "build":
+ {
+ "command_strings": ["build"],
+ "user_params": []
+ },
+ "run":
+ {
+ "command_strings": ["run"],
+ "user_params": []
+ },
+ "post_run":
+ {
+ "command_strings": ["post_run"],
+ "user_params": []
+ }
+ },
+"""
+ + ann_str
+ + """
+ "available_application" : []
+ }
+"""
+ )
+ return json.dumps(json.loads(json_output)) + "\n"
+
+
+def get_details_cmd_console_output(
+ annotations: Optional[Dict[str, Union[str, List[str]]]] = None
+) -> str:
+    """Return the expected console output for the details command."""
+ ann_str = ""
+ if annotations:
+ val_str = "".join(
+ "\n\t{}: {}".format(ann_name, ann_value)
+ for ann_name, ann_value in annotations.items()
+ )
+ ann_str = "\nAnnotations:{}".format(val_str)
+ return (
+ 'System "system" details'
+ + "\nDescription: test"
+ + "\nData Transfer Protocol: ssh"
+ + "\nAvailable Applications: "
+ + ann_str
+ + "\n\nclean commands:"
+ + "\nCommands: ['clean']"
+ + "\n\nbuild commands:"
+ + "\nCommands: ['build']"
+ + "\n\nrun commands:"
+ + "\nCommands: ['run']"
+ + "\n\npost_run commands:"
+ + "\nCommands: ['post_run']"
+ + "\n"
+ )
+
+
+@pytest.mark.parametrize(
+ "format_,system,expected_output",
+ [
+ (
+ "json",
+ get_test_system(annotations={"ann1": "annotation1", "ann2": ["a1", "a2"]}),
+ get_details_cmd_json_output(
+ annotations={"ann1": "annotation1", "ann2": ["a1", "a2"]}
+ ),
+ ),
+ (
+ "cli",
+ get_test_system(annotations={"ann1": "annotation1", "ann2": ["a1", "a2"]}),
+ get_details_cmd_console_output(
+ annotations={"ann1": "annotation1", "ann2": ["a1", "a2"]}
+ ),
+ ),
+ (
+ "json",
+ get_test_system(annotations={}),
+ get_details_cmd_json_output(annotations={}),
+ ),
+ (
+ "cli",
+ get_test_system(annotations={}),
+ get_details_cmd_console_output(annotations={}),
+ ),
+ ],
+)
+def test_details_cmd(
+ cli_runner: CliRunner,
+ monkeypatch: Any,
+ format_: str,
+ system: System,
+ expected_output: str,
+) -> None:
+ """Test details command with different formats output."""
+ mock_get_system = MagicMock()
+ mock_get_system.return_value = system
+ monkeypatch.setattr("aiet.cli.system.get_system", mock_get_system)
+
+ args = ["--name", "system"]
+ obj = {"format": format_}
+ details_cmd.params[0].type = click.Choice(["system"])
+
+ result = cli_runner.invoke(details_cmd, args=args, obj=obj)
+ assert result.output == expected_output
+
+
+def test_install_cmd(cli_runner: CliRunner, monkeypatch: Any) -> None:
+ """Test install system command."""
+ mock_install_system = MagicMock()
+ monkeypatch.setattr("aiet.cli.system.install_system", mock_install_system)
+
+ args = ["--source", "test"]
+ cli_runner.invoke(install_cmd, args=args)
+ mock_install_system.assert_called_once_with(Path("test"))
+
+
+def test_remove_cmd(cli_runner: CliRunner, monkeypatch: Any) -> None:
+ """Test remove system command."""
+ mock_remove_system = MagicMock()
+ monkeypatch.setattr("aiet.cli.system.remove_system", mock_remove_system)
+ remove_cmd.params[0].type = click.Choice(["test"])
+
+ args = ["--directory_name", "test"]
+ cli_runner.invoke(remove_cmd, args=args)
+ mock_remove_system.assert_called_once_with("test")
diff --git a/tests/aiet/test_cli_tool.py b/tests/aiet/test_cli_tool.py
new file mode 100644
index 0000000..45d45c8
--- /dev/null
+++ b/tests/aiet/test_cli_tool.py
@@ -0,0 +1,333 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+# pylint: disable=attribute-defined-outside-init,no-member,line-too-long,too-many-arguments,too-many-locals
+"""Module for testing CLI tool subcommand."""
+import json
+from pathlib import Path
+from typing import Any
+from typing import List
+from typing import Optional
+from typing import Sequence
+from unittest.mock import MagicMock
+
+import click
+import pytest
+from click.testing import CliRunner
+from click.testing import Result
+
+from aiet.backend.tool import get_unique_tool_names
+from aiet.backend.tool import Tool
+from aiet.cli.tool import details_cmd
+from aiet.cli.tool import execute_cmd
+from aiet.cli.tool import list_cmd
+from aiet.cli.tool import tool_cmd
+
+
+def test_tool_cmd() -> None:
+ """Test tool commands."""
+ commands = ["list", "details", "execute"]
+ assert all(command in tool_cmd.commands for command in commands)
+
+
+@pytest.mark.parametrize("format_", ["json", "cli"])
+def test_tool_cmd_context(cli_runner: CliRunner, format_: str) -> None:
+ """Test setting command context parameters."""
+ result = cli_runner.invoke(tool_cmd, ["--format", format_])
+ # command should fail if no subcommand provided
+ assert result.exit_code == 2
+
+ result = cli_runner.invoke(tool_cmd, ["--format", format_, "list"])
+ assert result.exit_code == 0
+
+
+@pytest.mark.parametrize(
+ "format_, expected_output",
+ [
+ (
+ "json",
+ '{"type": "tool", "available": ["tool_1", "tool_2"]}\n',
+ ),
+ ("cli", "Available tools:\n\ntool_1\ntool_2\n"),
+ ],
+)
+def test_list_cmd(
+ cli_runner: CliRunner,
+ monkeypatch: Any,
+ format_: str,
+ expected_output: str,
+) -> None:
+ """Test available tool commands."""
+ # Mock some tools
+ mock_tool_1 = MagicMock(spec=Tool)
+ mock_tool_1.name = "tool_1"
+ mock_tool_2 = MagicMock(spec=Tool)
+ mock_tool_2.name = "tool_2"
+
+ # Monkey patch the call get_available_tools
+ mock_available_tools = MagicMock()
+ mock_available_tools.return_value = [mock_tool_1, mock_tool_2]
+
+ monkeypatch.setattr("aiet.backend.tool.get_available_tools", mock_available_tools)
+
+ obj = {"format": format_}
+ args: Sequence[str] = []
+ result = cli_runner.invoke(list_cmd, obj=obj, args=args)
+ assert result.output == expected_output
+
+
+def get_details_cmd_json_output() -> List[dict]:
+ """Get JSON output for details command."""
+ json_output = [
+ {
+ "type": "tool",
+ "name": "tool_1",
+ "description": "This is tool 1",
+ "supported_systems": ["System 1"],
+ "commands": {
+ "clean": {"command_strings": ["echo 'clean'"], "user_params": []},
+ "build": {"command_strings": ["echo 'build'"], "user_params": []},
+ "run": {"command_strings": ["echo 'run'"], "user_params": []},
+ "post_run": {"command_strings": ["echo 'post_run'"], "user_params": []},
+ },
+ }
+ ]
+
+ return json_output
+
+
+def get_details_cmd_console_output() -> str:
+ """Get console output for details command."""
+ return (
+ 'Tool "tool_1" details'
+ "\nDescription: This is tool 1"
+ "\n\nSupported systems: System 1"
+ "\n\nclean commands:"
+ "\nCommands: [\"echo 'clean'\"]"
+ "\n\nbuild commands:"
+ "\nCommands: [\"echo 'build'\"]"
+ "\n\nrun commands:\nCommands: [\"echo 'run'\"]"
+ "\n\npost_run commands:"
+ "\nCommands: [\"echo 'post_run'\"]"
+ "\n"
+ )
+
+
+@pytest.mark.parametrize(
+ [
+ "tool_name",
+ "format_",
+ "expected_success",
+ "expected_output",
+ ],
+ [
+ ("tool_1", "json", True, get_details_cmd_json_output()),
+ ("tool_1", "cli", True, get_details_cmd_console_output()),
+ ("non-existent tool", "json", False, None),
+ ("non-existent tool", "cli", False, None),
+ ],
+)
+def test_details_cmd(
+ cli_runner: CliRunner,
+ tool_name: str,
+ format_: str,
+ expected_success: bool,
+ expected_output: str,
+) -> None:
+ """Test tool details command."""
+ details_cmd.params[0].type = click.Choice(["tool_1", "tool_2", "vela"])
+ result = cli_runner.invoke(
+ details_cmd, obj={"format": format_}, args=["--name", tool_name]
+ )
+ success = result.exit_code == 0
+ assert success == expected_success, result.output
+ if expected_success:
+ assert result.exception is None
+ output = json.loads(result.output) if format_ == "json" else result.output
+ assert output == expected_output
+
+
+@pytest.mark.parametrize(
+ "system_name",
+ [
+ "",
+ "Corstone-300: Cortex-M55+Ethos-U55",
+ "Corstone-300: Cortex-M55+Ethos-U65",
+ "Corstone-310: Cortex-M85+Ethos-U55",
+ ],
+)
+def test_details_cmd_vela(cli_runner: CliRunner, system_name: str) -> None:
+ """Test tool details command for Vela."""
+ details_cmd.params[0].type = click.Choice(get_unique_tool_names())
+ details_cmd.params[1].type = click.Choice([system_name])
+ args = ["--name", "vela"]
+ if system_name:
+ args += ["--system", system_name]
+ result = cli_runner.invoke(details_cmd, obj={"format": "json"}, args=args)
+ success = result.exit_code == 0
+ assert success, result.output
+ result_json = json.loads(result.output)
+ assert result_json
+ if system_name:
+ assert len(result_json) == 1
+ tool = result_json[0]
+ assert len(tool["supported_systems"]) == 1
+ assert system_name == tool["supported_systems"][0]
+ else: # no system specified => list details for all systems
+ assert len(result_json) == 3
+ assert all(len(tool["supported_systems"]) == 1 for tool in result_json)
+
+
+@pytest.fixture(scope="session")
+def input_model_file(non_optimised_input_model_file: Path) -> Path:
+ """Provide the path to a quantized dummy model file in the test_resources_path."""
+ return non_optimised_input_model_file
+
+
+def execute_vela(
+ cli_runner: CliRunner,
+ tool_name: str = "vela",
+ system_name: Optional[str] = None,
+ input_model: Optional[Path] = None,
+ output_model: Optional[Path] = None,
+ mac: Optional[int] = None,
+ format_: str = "cli",
+) -> Result:
+ """Run Vela with different parameters."""
+ execute_cmd.params[0].type = click.Choice(get_unique_tool_names())
+ execute_cmd.params[2].type = click.Choice([system_name or "dummy_system"])
+ args = ["--name", tool_name]
+ if system_name is not None:
+ args += ["--system", system_name]
+ if input_model is not None:
+ args += ["--param", "input={}".format(input_model)]
+ if output_model is not None:
+ args += ["--param", "output={}".format(output_model)]
+ if mac is not None:
+ args += ["--param", "mac={}".format(mac)]
+ result = cli_runner.invoke(
+ execute_cmd,
+ args=args,
+ obj={"format": format_},
+ )
+ return result
+
+
+@pytest.mark.parametrize("format_", ["cli", "json"])
+@pytest.mark.parametrize(
+ ["tool_name", "system_name", "mac", "expected_success", "expected_output"],
+ [
+ ("vela", "System 1", 32, False, None), # system not supported
+ ("vela", "NON-EXISTENT SYSTEM", 128, False, None), # system does not exist
+ ("vela", "Corstone-300: Cortex-M55+Ethos-U55", 32, True, None),
+ ("NON-EXISTENT TOOL", "Corstone-300: Cortex-M55+Ethos-U55", 32, False, None),
+ ("vela", "Corstone-300: Cortex-M55+Ethos-U55", 64, True, None),
+ ("vela", "Corstone-300: Cortex-M55+Ethos-U55", 128, True, None),
+ ("vela", "Corstone-300: Cortex-M55+Ethos-U55", 256, True, None),
+ (
+ "vela",
+ "Corstone-300: Cortex-M55+Ethos-U55",
+ 512,
+ False,
+ None,
+ ), # mac not supported
+ (
+ "vela",
+ "Corstone-300: Cortex-M55+Ethos-U65",
+ 32,
+ False,
+ None,
+ ), # mac not supported
+ ("vela", "Corstone-300: Cortex-M55+Ethos-U65", 256, True, None),
+ ("vela", "Corstone-300: Cortex-M55+Ethos-U65", 512, True, None),
+ (
+ "vela",
+ None,
+ 512,
+ False,
+ "Error: Please specify the system for tool vela.",
+ ), # no system specified
+ (
+ "NON-EXISTENT TOOL",
+ "Corstone-300: Cortex-M55+Ethos-U65",
+ 512,
+ False,
+ None,
+ ), # tool does not exist
+ ("vela", "Corstone-310: Cortex-M85+Ethos-U55", 128, True, None),
+ ],
+)
+def test_vela_run(
+ cli_runner: CliRunner,
+ format_: str,
+ input_model_file: Path, # pylint: disable=redefined-outer-name
+ tool_name: str,
+ system_name: Optional[str],
+ mac: int,
+ expected_success: bool,
+ expected_output: Optional[str],
+ monkeypatch: pytest.MonkeyPatch,
+ tmp_path: Path,
+) -> None:
+ """Test the execution of the Vela command."""
+ monkeypatch.chdir(tmp_path)
+
+ output_file = Path("vela_output.tflite")
+
+ result = execute_vela(
+ cli_runner,
+ tool_name=tool_name,
+ system_name=system_name,
+ input_model=input_model_file,
+ output_model=output_file,
+ mac=mac,
+ format_=format_,
+ )
+
+ success = result.exit_code == 0
+ assert success == expected_success
+ if success:
+ # Check output file
+ output_file = output_file.resolve()
+ assert output_file.is_file()
+ if expected_output:
+ assert result.output.strip() == expected_output
+
+
+@pytest.mark.parametrize("include_input_model", [True, False])
+@pytest.mark.parametrize("include_output_model", [True, False])
+@pytest.mark.parametrize("include_mac", [True, False])
+def test_vela_run_missing_params(
+ cli_runner: CliRunner,
+ input_model_file: Path, # pylint: disable=redefined-outer-name
+ include_input_model: bool,
+ include_output_model: bool,
+ include_mac: bool,
+ monkeypatch: pytest.MonkeyPatch,
+ tmp_path: Path,
+) -> None:
+ """Test the execution of the Vela command with missing user parameters."""
+ monkeypatch.chdir(tmp_path)
+
+ output_model_file = Path("output_model.tflite")
+ system_name = "Corstone-300: Cortex-M55+Ethos-U65"
+ mac = 256
+    # input_model is a required parameter, but mac and output_model have default values.
+ expected_success = include_input_model
+
+ result = execute_vela(
+ cli_runner,
+ tool_name="vela",
+ system_name=system_name,
+ input_model=input_model_file if include_input_model else None,
+ output_model=output_model_file if include_output_model else None,
+ mac=mac if include_mac else None,
+ )
+
+ success = result.exit_code == 0
+ assert success == expected_success, (
+ f"Success is {success}, but expected {expected_success}. "
+ f"Included params: ["
+ f"input_model={include_input_model}, "
+ f"output_model={include_output_model}, "
+ f"mac={include_mac}]"
+ )
diff --git a/tests/aiet/test_main.py b/tests/aiet/test_main.py
new file mode 100644
index 0000000..f2ebae2
--- /dev/null
+++ b/tests/aiet/test_main.py
@@ -0,0 +1,16 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Module for testing AIET main.py."""
+from typing import Any
+from unittest.mock import MagicMock
+
+from aiet import main
+
+
+def test_main(monkeypatch: Any) -> None:
+ """Test main entry point function."""
+ with monkeypatch.context() as mock_context:
+ mock = MagicMock()
+ mock_context.setattr(main, "cli", mock)
+ main.main()
+ mock.assert_called_once()
diff --git a/tests/aiet/test_resources/application_config.json b/tests/aiet/test_resources/application_config.json
new file mode 100644
index 0000000..2dfcfec
--- /dev/null
+++ b/tests/aiet/test_resources/application_config.json
@@ -0,0 +1,96 @@
+[
+ {
+ "name": "application_1",
+ "description": "application number one",
+ "supported_systems": [
+ "system_1",
+ "system_2"
+ ],
+ "build_dir": "build_dir_11",
+ "commands": {
+ "clean": [
+ "clean_cmd_11"
+ ],
+ "build": [
+ "build_cmd_11"
+ ],
+ "run": [
+ "run_cmd_11"
+ ],
+ "post_run": [
+ "post_run_cmd_11"
+ ]
+ },
+ "user_params": {
+ "run": [
+ {
+ "name": "run_param_11",
+ "values": [],
+ "description": "run param number one"
+ }
+ ],
+ "build": [
+ {
+ "name": "build_param_11",
+ "values": [],
+ "description": "build param number one"
+ },
+ {
+ "name": "build_param_12",
+ "values": [],
+ "description": "build param number two"
+ },
+ {
+ "name": "build_param_13",
+ "values": [
+ "value_1"
+ ],
+ "description": "build param number three with some value"
+ }
+ ]
+ }
+ },
+ {
+ "name": "application_2",
+ "description": "application number two",
+ "supported_systems": [
+ "system_2"
+ ],
+ "build_dir": "build_dir_21",
+ "commands": {
+ "clean": [
+ "clean_cmd_21"
+ ],
+ "build": [
+ "build_cmd_21",
+ "build_cmd_22"
+ ],
+ "run": [
+ "run_cmd_21"
+ ],
+ "post_run": [
+ "post_run_cmd_21"
+ ]
+ },
+ "user_params": {
+ "build": [
+ {
+ "name": "build_param_21",
+ "values": [],
+ "description": "build param number one"
+ },
+ {
+ "name": "build_param_22",
+ "values": [],
+ "description": "build param number two"
+ },
+ {
+ "name": "build_param_23",
+ "values": [],
+ "description": "build param number three"
+ }
+ ],
+ "run": []
+ }
+ }
+]
diff --git a/tests/aiet/test_resources/application_config.json.license b/tests/aiet/test_resources/application_config.json.license
new file mode 100644
index 0000000..9b83bfc
--- /dev/null
+++ b/tests/aiet/test_resources/application_config.json.license
@@ -0,0 +1,3 @@
+SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+
+SPDX-License-Identifier: Apache-2.0
diff --git a/tests/aiet/test_resources/applications/application1/aiet-config.json b/tests/aiet/test_resources/applications/application1/aiet-config.json
new file mode 100644
index 0000000..97f0401
--- /dev/null
+++ b/tests/aiet/test_resources/applications/application1/aiet-config.json
@@ -0,0 +1,30 @@
+[
+ {
+ "name": "application_1",
+ "description": "This is application 1",
+ "supported_systems": [
+ {
+ "name": "System 1"
+ }
+ ],
+ "build_dir": "build",
+ "commands": {
+ "clean": [
+ "echo 'clean'"
+ ],
+ "build": [
+ "echo 'build'"
+ ],
+ "run": [
+ "echo 'run'"
+ ],
+ "post_run": [
+ "echo 'post_run'"
+ ]
+ },
+ "user_params": {
+ "build": [],
+ "run": []
+ }
+ }
+]
diff --git a/tests/aiet/test_resources/applications/application1/aiet-config.json.license b/tests/aiet/test_resources/applications/application1/aiet-config.json.license
new file mode 100644
index 0000000..9b83bfc
--- /dev/null
+++ b/tests/aiet/test_resources/applications/application1/aiet-config.json.license
@@ -0,0 +1,3 @@
+SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+
+SPDX-License-Identifier: Apache-2.0
diff --git a/tests/aiet/test_resources/applications/application2/aiet-config.json b/tests/aiet/test_resources/applications/application2/aiet-config.json
new file mode 100644
index 0000000..e9122d3
--- /dev/null
+++ b/tests/aiet/test_resources/applications/application2/aiet-config.json
@@ -0,0 +1,30 @@
+[
+ {
+ "name": "application_2",
+ "description": "This is application 2",
+ "supported_systems": [
+ {
+ "name": "System 2"
+ }
+ ],
+ "build_dir": "build",
+ "commands": {
+ "clean": [
+ "echo 'clean'"
+ ],
+ "build": [
+ "echo 'build'"
+ ],
+ "run": [
+ "echo 'run'"
+ ],
+ "post_run": [
+ "echo 'post_run'"
+ ]
+ },
+ "user_params": {
+ "build": [],
+ "run": []
+ }
+ }
+]
diff --git a/tests/aiet/test_resources/applications/application2/aiet-config.json.license b/tests/aiet/test_resources/applications/application2/aiet-config.json.license
new file mode 100644
index 0000000..9b83bfc
--- /dev/null
+++ b/tests/aiet/test_resources/applications/application2/aiet-config.json.license
@@ -0,0 +1,3 @@
+SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+
+SPDX-License-Identifier: Apache-2.0
diff --git a/tests/aiet/test_resources/applications/application3/readme.txt b/tests/aiet/test_resources/applications/application3/readme.txt
new file mode 100644
index 0000000..8c72c05
--- /dev/null
+++ b/tests/aiet/test_resources/applications/application3/readme.txt
@@ -0,0 +1,4 @@
+SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+SPDX-License-Identifier: Apache-2.0
+
+This application does not have json configuration file
diff --git a/tests/aiet/test_resources/applications/application4/aiet-config.json b/tests/aiet/test_resources/applications/application4/aiet-config.json
new file mode 100644
index 0000000..34dc780
--- /dev/null
+++ b/tests/aiet/test_resources/applications/application4/aiet-config.json
@@ -0,0 +1,35 @@
+[
+ {
+ "name": "application_4",
+ "description": "This is application 4",
+ "build_dir": "build",
+ "supported_systems": [
+ {
+ "name": "System 4"
+ }
+ ],
+ "commands": {
+ "build": [
+ "cp ../hello_app.txt . # {user_params:0}"
+ ],
+ "run": [
+ "{application.build_dir}/hello_app.txt"
+ ]
+ },
+ "user_params": {
+ "build": [
+ {
+ "name": "--app",
+ "description": "Sample command param",
+ "values": [
+ "application1",
+ "application2",
+ "application3"
+ ],
+ "default_value": "application1"
+ }
+ ],
+ "run": []
+ }
+ }
+]
diff --git a/tests/aiet/test_resources/applications/application4/aiet-config.json.license b/tests/aiet/test_resources/applications/application4/aiet-config.json.license
new file mode 100644
index 0000000..9b83bfc
--- /dev/null
+++ b/tests/aiet/test_resources/applications/application4/aiet-config.json.license
@@ -0,0 +1,3 @@
+SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+
+SPDX-License-Identifier: Apache-2.0
diff --git a/tests/aiet/test_resources/applications/application4/hello_app.txt b/tests/aiet/test_resources/applications/application4/hello_app.txt
new file mode 100644
index 0000000..2ec0d1d
--- /dev/null
+++ b/tests/aiet/test_resources/applications/application4/hello_app.txt
@@ -0,0 +1,4 @@
+SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+SPDX-License-Identifier: Apache-2.0
+
+Hello from APP!
diff --git a/tests/aiet/test_resources/applications/application5/aiet-config.json b/tests/aiet/test_resources/applications/application5/aiet-config.json
new file mode 100644
index 0000000..5269409
--- /dev/null
+++ b/tests/aiet/test_resources/applications/application5/aiet-config.json
@@ -0,0 +1,160 @@
+[
+ {
+ "name": "application_5",
+ "description": "This is application 5",
+ "build_dir": "default_build_dir",
+ "supported_systems": [
+ {
+ "name": "System 1",
+ "lock": false
+ },
+ {
+ "name": "System 2"
+ }
+ ],
+ "variables": {
+ "var1": "value1",
+ "var2": "value2"
+ },
+ "lock": true,
+ "commands": {
+ "build": [
+ "default build command"
+ ],
+ "run": [
+ "default run command"
+ ]
+ },
+ "user_params": {
+ "build": [],
+ "run": []
+ }
+ },
+ {
+ "name": "application_5A",
+ "description": "This is application 5A",
+ "supported_systems": [
+ {
+ "name": "System 1",
+ "build_dir": "build_5A",
+ "variables": {
+ "var1": "new value1"
+ }
+ },
+ {
+ "name": "System 2",
+ "variables": {
+ "var2": "new value2"
+ },
+ "lock": true,
+ "commands": {
+ "run": [
+ "run command on system 2"
+ ]
+ }
+ }
+ ],
+ "variables": {
+ "var1": "value1",
+ "var2": "value2"
+ },
+ "build_dir": "build",
+ "commands": {
+ "build": [
+ "default build command"
+ ],
+ "run": [
+ "default run command"
+ ]
+ },
+ "user_params": {
+ "build": [],
+ "run": []
+ }
+ },
+ {
+ "name": "application_5B",
+ "description": "This is application 5B",
+ "supported_systems": [
+ {
+ "name": "System 1",
+ "build_dir": "build_5B",
+ "variables": {
+ "var1": "value for var1 System1",
+ "var2": "value for var2 System1"
+ },
+ "user_params": {
+ "build": [
+ {
+ "name": "--param_5B",
+ "description": "Sample command param",
+ "values": [
+ "value1",
+ "value2",
+ "value3"
+ ],
+ "default_value": "value1",
+ "alias": "param1"
+ }
+ ]
+ }
+ },
+ {
+ "name": "System 2",
+ "variables": {
+ "var1": "value for var1 System2",
+ "var2": "value for var2 System2"
+ },
+ "commands": {
+ "build": [
+ "build command on system 2 with {variables:var1} {user_params:param1}"
+ ],
+ "run": [
+ "run command on system 2"
+ ]
+ },
+ "user_params": {
+ "run": []
+ }
+ }
+ ],
+ "build_dir": "build",
+ "commands": {
+ "build": [
+ "default build command with {variables:var1}"
+ ],
+ "run": [
+ "default run command with {variables:var2}"
+ ]
+ },
+ "user_params": {
+ "build": [
+ {
+ "name": "--param",
+ "description": "Sample command param",
+ "values": [
+ "value1",
+ "value2",
+ "value3"
+ ],
+ "default_value": "value1",
+ "alias": "param1"
+ }
+ ],
+ "run": [],
+ "non_used_command": [
+ {
+ "name": "--not-used",
+ "description": "Not used param anywhere",
+ "values": [
+ "value1",
+ "value2",
+ "value3"
+ ],
+ "default_value": "value1",
+ "alias": "param1"
+ }
+ ]
+ }
+ }
+]
diff --git a/tests/aiet/test_resources/applications/application5/aiet-config.json.license b/tests/aiet/test_resources/applications/application5/aiet-config.json.license
new file mode 100644
index 0000000..9b83bfc
--- /dev/null
+++ b/tests/aiet/test_resources/applications/application5/aiet-config.json.license
@@ -0,0 +1,3 @@
+SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+
+SPDX-License-Identifier: Apache-2.0
diff --git a/tests/aiet/test_resources/applications/readme.txt b/tests/aiet/test_resources/applications/readme.txt
new file mode 100644
index 0000000..a1f8209
--- /dev/null
+++ b/tests/aiet/test_resources/applications/readme.txt
@@ -0,0 +1,4 @@
+SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+SPDX-License-Identifier: Apache-2.0
+
+Dummy file for test purposes
diff --git a/tests/aiet/test_resources/hello_world.json b/tests/aiet/test_resources/hello_world.json
new file mode 100644
index 0000000..8a9a448
--- /dev/null
+++ b/tests/aiet/test_resources/hello_world.json
@@ -0,0 +1,54 @@
+[
+ {
+ "name": "Hello world",
+ "description": "Dummy application that displays 'Hello world!'",
+ "supported_systems": [
+ "Dummy System"
+ ],
+ "build_dir": "build",
+ "deploy_data": [
+ [
+ "src",
+ "/tmp/"
+ ],
+ [
+ "README",
+ "/tmp/README.md"
+ ]
+ ],
+ "commands": {
+ "clean": [],
+ "build": [],
+ "run": [
+ "echo 'Hello world!'",
+ "ls -l /tmp"
+ ],
+ "post_run": []
+ },
+ "user_params": {
+ "run": [
+ {
+ "name": "--choice-param",
+ "values": [
+ "dummy_value_1",
+ "dummy_value_2"
+ ],
+ "default_value": "dummy_value_1",
+ "description": "Choice param"
+ },
+ {
+ "name": "--open-param",
+ "values": [],
+ "default_value": "dummy_value_4",
+ "description": "Open param"
+ },
+ {
+ "name": "--enable-flag",
+ "default_value": "dummy_value_4",
+ "description": "Flag param"
+ }
+ ],
+ "build": []
+ }
+ }
+]
diff --git a/tests/aiet/test_resources/hello_world.json.license b/tests/aiet/test_resources/hello_world.json.license
new file mode 100644
index 0000000..9b83bfc
--- /dev/null
+++ b/tests/aiet/test_resources/hello_world.json.license
@@ -0,0 +1,3 @@
+SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+
+SPDX-License-Identifier: Apache-2.0
diff --git a/tests/aiet/test_resources/scripts/test_backend_run b/tests/aiet/test_resources/scripts/test_backend_run
new file mode 100755
index 0000000..548f577
--- /dev/null
+++ b/tests/aiet/test_resources/scripts/test_backend_run
@@ -0,0 +1,8 @@
+#!/bin/bash
+
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+
+echo "Hello from script"
+>&2 echo "Oops!"
+sleep 100
diff --git a/tests/aiet/test_resources/scripts/test_backend_run_script.sh b/tests/aiet/test_resources/scripts/test_backend_run_script.sh
new file mode 100644
index 0000000..548f577
--- /dev/null
+++ b/tests/aiet/test_resources/scripts/test_backend_run_script.sh
@@ -0,0 +1,8 @@
+#!/bin/bash
+
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+
+echo "Hello from script"
+>&2 echo "Oops!"
+sleep 100
diff --git a/tests/aiet/test_resources/systems/system1/aiet-config.json b/tests/aiet/test_resources/systems/system1/aiet-config.json
new file mode 100644
index 0000000..4b5dd19
--- /dev/null
+++ b/tests/aiet/test_resources/systems/system1/aiet-config.json
@@ -0,0 +1,35 @@
+[
+ {
+ "name": "System 1",
+ "description": "This is system 1",
+ "build_dir": "build",
+ "data_transfer": {
+ "protocol": "ssh",
+ "username": "root",
+ "password": "root",
+ "hostname": "localhost",
+ "port": "8021"
+ },
+ "commands": {
+ "clean": [
+ "echo 'clean'"
+ ],
+ "build": [
+ "echo 'build'"
+ ],
+ "run": [
+ "echo 'run'"
+ ],
+ "post_run": [
+ "echo 'post_run'"
+ ],
+ "deploy": [
+ "echo 'deploy'"
+ ]
+ },
+ "user_params": {
+ "build": [],
+ "run": []
+ }
+ }
+]
diff --git a/tests/aiet/test_resources/systems/system1/aiet-config.json.license b/tests/aiet/test_resources/systems/system1/aiet-config.json.license
new file mode 100644
index 0000000..9b83bfc
--- /dev/null
+++ b/tests/aiet/test_resources/systems/system1/aiet-config.json.license
@@ -0,0 +1,3 @@
+SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+
+SPDX-License-Identifier: Apache-2.0
diff --git a/tests/aiet/test_resources/systems/system1/system_artifact/dummy.txt b/tests/aiet/test_resources/systems/system1/system_artifact/dummy.txt
new file mode 100644
index 0000000..487e9d8
--- /dev/null
+++ b/tests/aiet/test_resources/systems/system1/system_artifact/dummy.txt
@@ -0,0 +1,2 @@
+SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+SPDX-License-Identifier: Apache-2.0
diff --git a/tests/aiet/test_resources/systems/system2/aiet-config.json b/tests/aiet/test_resources/systems/system2/aiet-config.json
new file mode 100644
index 0000000..a9e0eb3
--- /dev/null
+++ b/tests/aiet/test_resources/systems/system2/aiet-config.json
@@ -0,0 +1,32 @@
+[
+ {
+ "name": "System 2",
+ "description": "This is system 2",
+ "build_dir": "build",
+ "data_transfer": {
+ "protocol": "ssh",
+ "username": "root",
+ "password": "root",
+ "hostname": "localhost",
+ "port": "8021"
+ },
+ "commands": {
+ "clean": [
+ "echo 'clean'"
+ ],
+ "build": [
+ "echo 'build'"
+ ],
+ "run": [
+ "echo 'run'"
+ ],
+ "post_run": [
+ "echo 'post_run'"
+ ]
+ },
+ "user_params": {
+ "build": [],
+ "run": []
+ }
+ }
+]
diff --git a/tests/aiet/test_resources/systems/system2/aiet-config.json.license b/tests/aiet/test_resources/systems/system2/aiet-config.json.license
new file mode 100644
index 0000000..9b83bfc
--- /dev/null
+++ b/tests/aiet/test_resources/systems/system2/aiet-config.json.license
@@ -0,0 +1,3 @@
+SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+
+SPDX-License-Identifier: Apache-2.0
diff --git a/tests/aiet/test_resources/systems/system3/readme.txt b/tests/aiet/test_resources/systems/system3/readme.txt
new file mode 100644
index 0000000..aba5a9c
--- /dev/null
+++ b/tests/aiet/test_resources/systems/system3/readme.txt
@@ -0,0 +1,4 @@
+SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+SPDX-License-Identifier: Apache-2.0
+
+This system does not have the json configuration file
diff --git a/tests/aiet/test_resources/systems/system4/aiet-config.json b/tests/aiet/test_resources/systems/system4/aiet-config.json
new file mode 100644
index 0000000..295e00f
--- /dev/null
+++ b/tests/aiet/test_resources/systems/system4/aiet-config.json
@@ -0,0 +1,19 @@
+[
+ {
+ "name": "System 4",
+ "description": "This is system 4",
+ "build_dir": "build",
+ "data_transfer": {
+ "protocol": "local"
+ },
+ "commands": {
+ "run": [
+ "echo {application.name}",
+ "cat {application.commands.run:0}"
+ ]
+ },
+ "user_params": {
+ "run": []
+ }
+ }
+]
diff --git a/tests/aiet/test_resources/systems/system4/aiet-config.json.license b/tests/aiet/test_resources/systems/system4/aiet-config.json.license
new file mode 100644
index 0000000..9b83bfc
--- /dev/null
+++ b/tests/aiet/test_resources/systems/system4/aiet-config.json.license
@@ -0,0 +1,3 @@
+SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+
+SPDX-License-Identifier: Apache-2.0
diff --git a/tests/aiet/test_resources/tools/tool1/aiet-config.json b/tests/aiet/test_resources/tools/tool1/aiet-config.json
new file mode 100644
index 0000000..067ef7e
--- /dev/null
+++ b/tests/aiet/test_resources/tools/tool1/aiet-config.json
@@ -0,0 +1,30 @@
+[
+ {
+ "name": "tool_1",
+ "description": "This is tool 1",
+ "build_dir": "build",
+ "supported_systems": [
+ {
+ "name": "System 1"
+ }
+ ],
+ "commands": {
+ "clean": [
+ "echo 'clean'"
+ ],
+ "build": [
+ "echo 'build'"
+ ],
+ "run": [
+ "echo 'run'"
+ ],
+ "post_run": [
+ "echo 'post_run'"
+ ]
+ },
+ "user_params": {
+ "build": [],
+ "run": []
+ }
+ }
+]
diff --git a/tests/aiet/test_resources/tools/tool1/aiet-config.json.license b/tests/aiet/test_resources/tools/tool1/aiet-config.json.license
new file mode 100644
index 0000000..9b83bfc
--- /dev/null
+++ b/tests/aiet/test_resources/tools/tool1/aiet-config.json.license
@@ -0,0 +1,3 @@
+SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+
+SPDX-License-Identifier: Apache-2.0
diff --git a/tests/aiet/test_resources/tools/tool2/aiet-config.json b/tests/aiet/test_resources/tools/tool2/aiet-config.json
new file mode 100644
index 0000000..6eee9a6
--- /dev/null
+++ b/tests/aiet/test_resources/tools/tool2/aiet-config.json
@@ -0,0 +1,26 @@
+[
+ {
+ "name": "tool_2",
+ "description": "This is tool 2 with no supported systems",
+ "build_dir": "build",
+ "supported_systems": [],
+ "commands": {
+ "clean": [
+ "echo 'clean'"
+ ],
+ "build": [
+ "echo 'build'"
+ ],
+ "run": [
+ "echo 'run'"
+ ],
+ "post_run": [
+ "echo 'post_run'"
+ ]
+ },
+ "user_params": {
+ "build": [],
+ "run": []
+ }
+ }
+]
diff --git a/tests/aiet/test_resources/tools/tool2/aiet-config.json.license b/tests/aiet/test_resources/tools/tool2/aiet-config.json.license
new file mode 100644
index 0000000..9b83bfc
--- /dev/null
+++ b/tests/aiet/test_resources/tools/tool2/aiet-config.json.license
@@ -0,0 +1,3 @@
+SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+
+SPDX-License-Identifier: Apache-2.0
diff --git a/tests/aiet/test_resources/various/applications/application_with_empty_config/aiet-config.json b/tests/aiet/test_resources/various/applications/application_with_empty_config/aiet-config.json
new file mode 100644
index 0000000..fe51488
--- /dev/null
+++ b/tests/aiet/test_resources/various/applications/application_with_empty_config/aiet-config.json
@@ -0,0 +1 @@
+[]
diff --git a/tests/aiet/test_resources/various/applications/application_with_empty_config/aiet-config.json.license b/tests/aiet/test_resources/various/applications/application_with_empty_config/aiet-config.json.license
new file mode 100644
index 0000000..9b83bfc
--- /dev/null
+++ b/tests/aiet/test_resources/various/applications/application_with_empty_config/aiet-config.json.license
@@ -0,0 +1,3 @@
+SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+
+SPDX-License-Identifier: Apache-2.0
diff --git a/tests/aiet/test_resources/various/applications/application_with_valid_config/aiet-config.json b/tests/aiet/test_resources/various/applications/application_with_valid_config/aiet-config.json
new file mode 100644
index 0000000..ff1cf1a
--- /dev/null
+++ b/tests/aiet/test_resources/various/applications/application_with_valid_config/aiet-config.json
@@ -0,0 +1,35 @@
+[
+ {
+ "name": "test_application",
+ "description": "This is test_application",
+ "build_dir": "build",
+ "supported_systems": [
+ {
+ "name": "System 4"
+ }
+ ],
+ "commands": {
+ "build": [
+ "cp ../hello_app.txt ."
+ ],
+ "run": [
+ "{application.build_dir}/hello_app.txt"
+ ]
+ },
+ "user_params": {
+ "build": [
+ {
+ "name": "--app",
+ "description": "Sample command param",
+ "values": [
+ "application1",
+ "application2",
+ "application3"
+ ],
+ "default_value": "application1"
+ }
+ ],
+ "run": []
+ }
+ }
+]
diff --git a/tests/aiet/test_resources/various/applications/application_with_valid_config/aiet-config.json.license b/tests/aiet/test_resources/various/applications/application_with_valid_config/aiet-config.json.license
new file mode 100644
index 0000000..9b83bfc
--- /dev/null
+++ b/tests/aiet/test_resources/various/applications/application_with_valid_config/aiet-config.json.license
@@ -0,0 +1,3 @@
+SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+
+SPDX-License-Identifier: Apache-2.0
diff --git a/tests/aiet/test_resources/various/applications/application_with_wrong_config1/aiet-config.json b/tests/aiet/test_resources/various/applications/application_with_wrong_config1/aiet-config.json
new file mode 100644
index 0000000..724b31b
--- /dev/null
+++ b/tests/aiet/test_resources/various/applications/application_with_wrong_config1/aiet-config.json
@@ -0,0 +1,2 @@
+This is not valid json file
+{
diff --git a/tests/aiet/test_resources/various/applications/application_with_wrong_config1/aiet-config.json.license b/tests/aiet/test_resources/various/applications/application_with_wrong_config1/aiet-config.json.license
new file mode 100644
index 0000000..9b83bfc
--- /dev/null
+++ b/tests/aiet/test_resources/various/applications/application_with_wrong_config1/aiet-config.json.license
@@ -0,0 +1,3 @@
+SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+
+SPDX-License-Identifier: Apache-2.0
diff --git a/tests/aiet/test_resources/various/applications/application_with_wrong_config2/aiet-config.json b/tests/aiet/test_resources/various/applications/application_with_wrong_config2/aiet-config.json
new file mode 100644
index 0000000..1ebb29c
--- /dev/null
+++ b/tests/aiet/test_resources/various/applications/application_with_wrong_config2/aiet-config.json
@@ -0,0 +1,30 @@
+[
+ {
+ "name": "test_application",
+ "description": "This is test_application",
+ "build_dir": "build",
+ "commands": {
+ "build": [
+ "cp ../hello_app.txt ."
+ ],
+ "run": [
+ "{application.build_dir}/hello_app.txt"
+ ]
+ },
+ "user_params": {
+ "build": [
+ {
+ "name": "--app",
+ "description": "Sample command param",
+ "values": [
+ "application1",
+ "application2",
+ "application3"
+ ],
+ "default_value": "application1"
+ }
+ ],
+ "run": []
+ }
+ }
+]
diff --git a/tests/aiet/test_resources/various/applications/application_with_wrong_config2/aiet-config.json.license b/tests/aiet/test_resources/various/applications/application_with_wrong_config2/aiet-config.json.license
new file mode 100644
index 0000000..9b83bfc
--- /dev/null
+++ b/tests/aiet/test_resources/various/applications/application_with_wrong_config2/aiet-config.json.license
@@ -0,0 +1,3 @@
+SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+
+SPDX-License-Identifier: Apache-2.0
diff --git a/tests/aiet/test_resources/various/applications/application_with_wrong_config3/aiet-config.json b/tests/aiet/test_resources/various/applications/application_with_wrong_config3/aiet-config.json
new file mode 100644
index 0000000..410d12d
--- /dev/null
+++ b/tests/aiet/test_resources/various/applications/application_with_wrong_config3/aiet-config.json
@@ -0,0 +1,35 @@
+[
+ {
+ "name": "test_application",
+ "description": "This is test_application",
+ "build_dir": "build",
+ "supported_systems": [
+ {
+ "anme": "System 4"
+ }
+ ],
+ "commands": {
+ "build": [
+ "cp ../hello_app.txt ."
+ ],
+ "run": [
+ "{application.build_dir}/hello_app.txt"
+ ]
+ },
+ "user_params": {
+ "build": [
+ {
+ "name": "--app",
+ "description": "Sample command param",
+ "values": [
+ "application1",
+ "application2",
+ "application3"
+ ],
+ "default_value": "application1"
+ }
+ ],
+ "run": []
+ }
+ }
+]
diff --git a/tests/aiet/test_resources/various/applications/application_with_wrong_config3/aiet-config.json.license b/tests/aiet/test_resources/various/applications/application_with_wrong_config3/aiet-config.json.license
new file mode 100644
index 0000000..9b83bfc
--- /dev/null
+++ b/tests/aiet/test_resources/various/applications/application_with_wrong_config3/aiet-config.json.license
@@ -0,0 +1,3 @@
+SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+
+SPDX-License-Identifier: Apache-2.0
diff --git a/tests/aiet/test_resources/various/systems/system_with_empty_config/aiet-config.json b/tests/aiet/test_resources/various/systems/system_with_empty_config/aiet-config.json
new file mode 100644
index 0000000..fe51488
--- /dev/null
+++ b/tests/aiet/test_resources/various/systems/system_with_empty_config/aiet-config.json
@@ -0,0 +1 @@
+[]
diff --git a/tests/aiet/test_resources/various/systems/system_with_empty_config/aiet-config.json.license b/tests/aiet/test_resources/various/systems/system_with_empty_config/aiet-config.json.license
new file mode 100644
index 0000000..9b83bfc
--- /dev/null
+++ b/tests/aiet/test_resources/various/systems/system_with_empty_config/aiet-config.json.license
@@ -0,0 +1,3 @@
+SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+
+SPDX-License-Identifier: Apache-2.0
diff --git a/tests/aiet/test_resources/various/systems/system_with_valid_config/aiet-config.json b/tests/aiet/test_resources/various/systems/system_with_valid_config/aiet-config.json
new file mode 100644
index 0000000..20142e9
--- /dev/null
+++ b/tests/aiet/test_resources/various/systems/system_with_valid_config/aiet-config.json
@@ -0,0 +1,16 @@
+[
+ {
+ "name": "Test system",
+ "description": "This is a test system",
+ "build_dir": "build",
+ "data_transfer": {
+ "protocol": "local"
+ },
+ "commands": {
+ "run": []
+ },
+ "user_params": {
+ "run": []
+ }
+ }
+]
diff --git a/tests/aiet/test_resources/various/systems/system_with_valid_config/aiet-config.json.license b/tests/aiet/test_resources/various/systems/system_with_valid_config/aiet-config.json.license
new file mode 100644
index 0000000..9b83bfc
--- /dev/null
+++ b/tests/aiet/test_resources/various/systems/system_with_valid_config/aiet-config.json.license
@@ -0,0 +1,3 @@
+SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+
+SPDX-License-Identifier: Apache-2.0
diff --git a/tests/aiet/test_run_vela_script.py b/tests/aiet/test_run_vela_script.py
new file mode 100644
index 0000000..971856e
--- /dev/null
+++ b/tests/aiet/test_run_vela_script.py
@@ -0,0 +1,152 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+# pylint: disable=redefined-outer-name,no-self-use
+"""Module for testing run_vela.py script."""
+from pathlib import Path
+from typing import Any
+from typing import List
+
+import pytest
+from click.testing import CliRunner
+
+from aiet.cli.common import MiddlewareExitCode
+from aiet.resources.tools.vela.check_model import get_model_from_file
+from aiet.resources.tools.vela.check_model import is_vela_optimised
+from aiet.resources.tools.vela.run_vela import run_vela
+
+
+@pytest.fixture(scope="session")
+def vela_config_path(test_tools_path: Path) -> Path:
+ """Return test systems path in a pytest fixture."""
+ return test_tools_path / "vela" / "vela.ini"
+
+
+@pytest.fixture(
+ params=[
+ ["ethos-u65-256", "Ethos_U65_High_End", "U65_Shared_Sram"],
+ ["ethos-u55-32", "Ethos_U55_High_End_Embedded", "U55_Shared_Sram"],
+ ]
+)
+def ethos_config(request: Any) -> Any:
+ """Fixture to provide different configuration for Ethos-U optimization with Vela."""
+ return request.param
+
+
+# pylint: disable=too-many-arguments
+def generate_args(
+ input_: Path,
+ output: Path,
+ cfg: Path,
+ acc_config: str,
+ system_config: str,
+ memory_mode: str,
+) -> List[str]:
+ """Generate arguments that can be passed to script 'run_vela'."""
+ return [
+ "-i",
+ str(input_),
+ "-o",
+ str(output),
+ "--config",
+ str(cfg),
+ "--accelerator-config",
+ acc_config,
+ "--system-config",
+ system_config,
+ "--memory-mode",
+ memory_mode,
+ "--optimise",
+ "Performance",
+ ]
+
+
+def check_run_vela(
+ cli_runner: CliRunner, args: List, expected_success: bool, output_file: Path
+) -> None:
+ """Run Vela with the given arguments and check the result."""
+ result = cli_runner.invoke(run_vela, args)
+ success = result.exit_code == MiddlewareExitCode.SUCCESS
+ assert success == expected_success
+ if success:
+ model = get_model_from_file(output_file)
+ assert is_vela_optimised(model)
+
+
+def run_vela_script(
+ cli_runner: CliRunner,
+ input_model_file: Path,
+ output_model_file: Path,
+ vela_config: Path,
+ expected_success: bool,
+ acc_config: str,
+ system_config: str,
+ memory_mode: str,
+) -> None:
+ """Run the command 'run_vela' on the command line."""
+ args = generate_args(
+ input_model_file,
+ output_model_file,
+ vela_config,
+ acc_config,
+ system_config,
+ memory_mode,
+ )
+ check_run_vela(cli_runner, args, expected_success, output_model_file)
+
+
+class TestRunVelaCli:
+ """Test the command-line execution of the run_vela command."""
+
+ def test_non_optimised_model(
+ self,
+ cli_runner: CliRunner,
+ non_optimised_input_model_file: Path,
+ tmp_path: Path,
+ vela_config_path: Path,
+ ethos_config: List,
+ ) -> None:
+ """Verify Vela is run correctly on an unoptimised model."""
+ run_vela_script(
+ cli_runner,
+ non_optimised_input_model_file,
+ tmp_path / "test.tflite",
+ vela_config_path,
+ True,
+ *ethos_config,
+ )
+
+ def test_optimised_model(
+ self,
+ cli_runner: CliRunner,
+ optimised_input_model_file: Path,
+ tmp_path: Path,
+ vela_config_path: Path,
+ ethos_config: List,
+ ) -> None:
+ """Verify Vela is run correctly on an already optimised model."""
+ run_vela_script(
+ cli_runner,
+ optimised_input_model_file,
+ tmp_path / "test.tflite",
+ vela_config_path,
+ True,
+ *ethos_config,
+ )
+
+ def test_invalid_model(
+ self,
+ cli_runner: CliRunner,
+ invalid_input_model_file: Path,
+ tmp_path: Path,
+ vela_config_path: Path,
+ ethos_config: List,
+ ) -> None:
+ """Verify an error is raised when the input model is not valid."""
+ run_vela_script(
+ cli_runner,
+ invalid_input_model_file,
+ tmp_path / "test.tflite",
+ vela_config_path,
+ False,
+ *ethos_config,
+ )
diff --git a/tests/aiet/test_utils_fs.py b/tests/aiet/test_utils_fs.py
new file mode 100644
index 0000000..46d276e
--- /dev/null
+++ b/tests/aiet/test_utils_fs.py
@@ -0,0 +1,168 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+# pylint: disable=no-self-use
+"""Module for testing fs.py."""
+from contextlib import ExitStack as does_not_raise
+from pathlib import Path
+from typing import Any
+from typing import Union
+from unittest.mock import MagicMock
+
+import pytest
+
+from aiet.utils.fs import get_resources
+from aiet.utils.fs import read_file_as_bytearray
+from aiet.utils.fs import read_file_as_string
+from aiet.utils.fs import recreate_directory
+from aiet.utils.fs import remove_directory
+from aiet.utils.fs import remove_resource
+from aiet.utils.fs import ResourceType
+from aiet.utils.fs import valid_for_filename
+
+
+@pytest.mark.parametrize(
+ "resource_name,expected_path",
+ [
+ ("systems", does_not_raise()),
+ ("applications", does_not_raise()),
+ ("whaaat", pytest.raises(ResourceWarning)),
+ (None, pytest.raises(ResourceWarning)),
+ ],
+)
+def test_get_resources(resource_name: ResourceType, expected_path: Any) -> None:
+ """Test get_resources() with multiple parameters."""
+ with expected_path:
+ resource_path = get_resources(resource_name)
+ assert resource_path.exists()
+
+
+def test_remove_resource_wrong_directory(
+ monkeypatch: Any, test_applications_path: Path
+) -> None:
+ """Test removing resource with wrong directory."""
+ mock_get_resources = MagicMock(return_value=test_applications_path)
+ monkeypatch.setattr("aiet.utils.fs.get_resources", mock_get_resources)
+
+ mock_shutil_rmtree = MagicMock()
+ monkeypatch.setattr("aiet.utils.fs.shutil.rmtree", mock_shutil_rmtree)
+
+ with pytest.raises(Exception, match="Resource .* does not exist"):
+ remove_resource("unknown", "applications")
+ mock_shutil_rmtree.assert_not_called()
+
+ with pytest.raises(Exception, match="Wrong resource .*"):
+ remove_resource("readme.txt", "applications")
+ mock_shutil_rmtree.assert_not_called()
+
+
+def test_remove_resource(monkeypatch: Any, test_applications_path: Path) -> None:
+ """Test removing resource data."""
+ mock_get_resources = MagicMock(return_value=test_applications_path)
+ monkeypatch.setattr("aiet.utils.fs.get_resources", mock_get_resources)
+
+ mock_shutil_rmtree = MagicMock()
+ monkeypatch.setattr("aiet.utils.fs.shutil.rmtree", mock_shutil_rmtree)
+
+ remove_resource("application1", "applications")
+ mock_shutil_rmtree.assert_called_once()
+
+
+def test_remove_directory(tmpdir: Any) -> None:
+ """Test directory removal."""
+ tmpdir_path = Path(tmpdir)
+ tmpfile = tmpdir_path / "temp.txt"
+
+ for item in [None, tmpfile]:
+ with pytest.raises(Exception, match="No directory path provided"):
+ remove_directory(item)
+
+ newdir = tmpdir_path / "newdir"
+ newdir.mkdir()
+
+ assert newdir.is_dir()
+ remove_directory(newdir)
+ assert not newdir.exists()
+
+
+def test_recreate_directory(tmpdir: Any) -> None:
+ """Test directory recreation."""
+ with pytest.raises(Exception, match="No directory path provided"):
+ recreate_directory(None)
+
+ tmpdir_path = Path(tmpdir)
+ tmpfile = tmpdir_path / "temp.txt"
+ tmpfile.touch()
+ with pytest.raises(Exception, match="Path .* does exist and it is not a directory"):
+ recreate_directory(tmpfile)
+
+ newdir = tmpdir_path / "newdir"
+ newdir.mkdir()
+ newfile = newdir / "newfile"
+ newfile.touch()
+ assert list(newdir.iterdir()) == [newfile]
+ recreate_directory(newdir)
+ assert not list(newdir.iterdir())
+
+ newdir2 = tmpdir_path / "newdir2"
+ assert not newdir2.exists()
+ recreate_directory(newdir2)
+ assert newdir2.is_dir()
+
+
+def write_to_file(
+ write_directory: Any, write_mode: str, write_text: Union[str, bytes]
+) -> Path:
+ """Write some text to a temporary test file."""
+ tmpdir_path = Path(write_directory)
+ tmpfile = tmpdir_path / "file_name.txt"
+ with open(tmpfile, write_mode) as file: # pylint: disable=unspecified-encoding
+ file.write(write_text)
+ return tmpfile
+
+
+class TestReadFileAsString:
+ """Test read_file_as_string() function."""
+
+ def test_returns_text_from_valid_file(self, tmpdir: Any) -> None:
+ """Ensure the string written to a file read correctly."""
+ file_path = write_to_file(tmpdir, "w", "hello")
+ assert read_file_as_string(file_path) == "hello"
+
+ def test_output_is_empty_string_when_input_file_non_existent(
+ self, tmpdir: Any
+ ) -> None:
+ """Ensure empty string returned when reading from non-existent file."""
+ file_path = Path(tmpdir / "non-existent.txt")
+ assert read_file_as_string(file_path) == ""
+
+
+class TestReadFileAsByteArray:
+ """Test read_file_as_bytearray() function."""
+
+ def test_returns_bytes_from_valid_file(self, tmpdir: Any) -> None:
+ """Ensure the bytes written to a file read correctly."""
+ file_path = write_to_file(tmpdir, "wb", b"hello bytes")
+ assert read_file_as_bytearray(file_path) == b"hello bytes"
+
+ def test_output_is_empty_bytearray_when_input_file_non_existent(
+ self, tmpdir: Any
+ ) -> None:
+ """Ensure empty bytearray returned when reading from non-existent file."""
+ file_path = Path(tmpdir / "non-existent.txt")
+ assert read_file_as_bytearray(file_path) == bytearray()
+
+
+@pytest.mark.parametrize(
+ "value, replacement, expected_result",
+ [
+ ["", "", ""],
+ ["123", "", "123"],
+ ["123", "_", "123"],
+ ["/some_folder/some_script.sh", "", "some_foldersome_script.sh"],
+ ["/some_folder/some_script.sh", "_", "_some_folder_some_script.sh"],
+ ["!;'some_name$%^!", "_", "___some_name____"],
+ ],
+)
+def test_valid_for_filename(value: str, replacement: str, expected_result: str) -> None:
+ """Test function valid_for_filename."""
+ assert valid_for_filename(value, replacement) == expected_result
diff --git a/tests/aiet/test_utils_helpers.py b/tests/aiet/test_utils_helpers.py
new file mode 100644
index 0000000..bbe03fc
--- /dev/null
+++ b/tests/aiet/test_utils_helpers.py
@@ -0,0 +1,27 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Module for testing helpers.py."""
+import logging
+from typing import Any
+from typing import List
+from unittest.mock import call
+from unittest.mock import MagicMock
+
+import pytest
+
+from aiet.utils.helpers import set_verbosity
+
+
+@pytest.mark.parametrize(
+ "verbosity,expected_calls",
+ [(0, []), (1, [call(logging.INFO)]), (2, [call(logging.DEBUG)])],
+)
+def test_set_verbosity(
+ verbosity: int, expected_calls: List[Any], monkeypatch: Any
+) -> None:
+ """Test set_verbosity() with different verbsosity levels."""
+ with monkeypatch.context() as mock_context:
+ logging_mock = MagicMock()
+ mock_context.setattr(logging.getLogger(), "setLevel", logging_mock)
+ set_verbosity(None, None, verbosity)
+ logging_mock.assert_has_calls(expected_calls)
diff --git a/tests/aiet/test_utils_proc.py b/tests/aiet/test_utils_proc.py
new file mode 100644
index 0000000..9fb48dd
--- /dev/null
+++ b/tests/aiet/test_utils_proc.py
@@ -0,0 +1,272 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+# pylint: disable=attribute-defined-outside-init,no-self-use,not-callable
+"""Pytests for testing aiet/utils/proc.py."""
+from pathlib import Path
+from typing import Any
+from unittest import mock
+
+import psutil
+import pytest
+from sh import ErrorReturnCode
+
+from aiet.utils.proc import Command
+from aiet.utils.proc import CommandFailedException
+from aiet.utils.proc import CommandNotFound
+from aiet.utils.proc import parse_command
+from aiet.utils.proc import print_command_stdout
+from aiet.utils.proc import run_and_wait
+from aiet.utils.proc import save_process_info
+from aiet.utils.proc import ShellCommand
+from aiet.utils.proc import terminate_command
+from aiet.utils.proc import terminate_external_process
+
+
+class TestShellCommand:
+ """Sample class for collecting tests."""
+
+ def test_shellcommand_default_value(self) -> None:
+ """Test the instantiation of the class ShellCommand with no parameter."""
+ shell_command = ShellCommand()
+ assert shell_command.base_log_path == "/tmp"
+
+ @pytest.mark.parametrize(
+ "base_log_path,expected", [("/test", "/test"), ("/asd", "/asd")]
+ )
+ def test_shellcommand_with_param(self, base_log_path: str, expected: str) -> None:
+ """Test init ShellCommand with different parameters."""
+ shell_command = ShellCommand(base_log_path)
+ assert shell_command.base_log_path == expected
+
+ def test_run_ls(self, monkeypatch: Any) -> None:
+ """Test a simple ls command."""
+ mock_command = mock.MagicMock()
+ monkeypatch.setattr(Command, "bake", mock_command)
+
+ mock_get_stdout_stderr_paths = mock.MagicMock()
+ mock_get_stdout_stderr_paths.return_value = ("/tmp/std.out", "/tmp/std.err")
+ monkeypatch.setattr(
+ ShellCommand, "get_stdout_stderr_paths", mock_get_stdout_stderr_paths
+ )
+
+ shell_command = ShellCommand()
+ shell_command.run("ls", "-l")
+ assert mock_command.mock_calls[0] == mock.call(("-l",))
+ assert mock_command.mock_calls[1] == mock.call()(
+ _bg=True, _err="/tmp/std.err", _out="/tmp/std.out", _tee=True, _bg_exc=False
+ )
+
+ def test_run_command_not_found(self) -> None:
+ """Test whe the command doesn't exist."""
+ shell_command = ShellCommand()
+ with pytest.raises(CommandNotFound):
+ shell_command.run("lsl", "-l")
+
+ def test_get_stdout_stderr_paths_valid_path(self) -> None:
+ """Test the method to get files to store stdout and stderr."""
+ valid_path = "/tmp"
+ shell_command = ShellCommand(valid_path)
+ out, err = shell_command.get_stdout_stderr_paths(valid_path, "cmd")
+ assert out.exists() and out.is_file()
+ assert err.exists() and err.is_file()
+ assert "cmd" in out.name
+ assert "cmd" in err.name
+
+ def test_get_stdout_stderr_paths_invalid_path(self) -> None:
+ """Test the method to get output files with an invalid path."""
+ invalid_path = "/invalid/foo/bar"
+ shell_command = ShellCommand(invalid_path)
+ with pytest.raises(FileNotFoundError):
+ shell_command.get_stdout_stderr_paths(invalid_path, "cmd")
+
+
+@mock.patch("builtins.print")
+def test_print_command_stdout_alive(mock_print: Any) -> None:
+ """Test the print command stdout with an alive (running) process."""
+ mock_command = mock.MagicMock()
+ mock_command.is_alive.return_value = True
+ mock_command.next.side_effect = ["test1", "test2", StopIteration]
+
+ print_command_stdout(mock_command)
+
+ mock_command.assert_has_calls(
+ [mock.call.is_alive(), mock.call.next(), mock.call.next()]
+ )
+ mock_print.assert_has_calls(
+ [mock.call("test1", end=""), mock.call("test2", end="")]
+ )
+
+
+@mock.patch("builtins.print")
+def test_print_command_stdout_not_alive(mock_print: Any) -> None:
+ """Test the print command stdout with a not alive (exited) process."""
+ mock_command = mock.MagicMock()
+ mock_command.is_alive.return_value = False
+ mock_command.stdout = "test"
+
+ print_command_stdout(mock_command)
+ mock_command.assert_has_calls([mock.call.is_alive()])
+ mock_print.assert_called_once_with("test")
+
+
+def test_terminate_external_process_no_process(capsys: Any) -> None:
+ """Test that non existed process could be terminated."""
+ mock_command = mock.MagicMock()
+ mock_command.terminate.side_effect = psutil.Error("Error!")
+
+ terminate_external_process(mock_command)
+ captured = capsys.readouterr()
+ assert captured.out == "Unable to terminate process\n"
+
+
+def test_terminate_external_process_case1() -> None:
+ """Test when process terminated immediately."""
+ mock_command = mock.MagicMock()
+ mock_command.is_running.return_value = False
+
+ terminate_external_process(mock_command)
+ mock_command.terminate.assert_called_once()
+ mock_command.is_running.assert_called_once()
+
+
+def test_terminate_external_process_case2() -> None:
+ """Test when process termination takes time."""
+ mock_command = mock.MagicMock()
+ mock_command.is_running.side_effect = [True, True, False]
+
+ terminate_external_process(mock_command)
+ mock_command.terminate.assert_called_once()
+ assert mock_command.is_running.call_count == 3
+
+
+def test_terminate_external_process_case3() -> None:
+ """Test when process termination takes more time."""
+ mock_command = mock.MagicMock()
+ mock_command.is_running.side_effect = [True, True, True]
+
+ terminate_external_process(
+ mock_command, number_of_attempts=2, wait_period=0.1, wait_for_termination=0.1
+ )
+ assert mock_command.is_running.call_count == 3
+ assert mock_command.terminate.call_count == 2
+
+
+def test_terminate_external_process_case4() -> None:
+ """Test when process termination takes more time."""
+ mock_command = mock.MagicMock()
+ mock_command.is_running.side_effect = [True, True, False]
+
+ terminate_external_process(
+ mock_command, number_of_attempts=2, wait_period=0.1, wait_for_termination=0.1
+ )
+ mock_command.terminate.assert_called_once()
+ assert mock_command.is_running.call_count == 3
+ assert mock_command.terminate.call_count == 1
+
+
+def test_terminate_command_no_process() -> None:
+ """Test command termination when process does not exist."""
+ mock_command = mock.MagicMock()
+ mock_command.process.signal_group.side_effect = ProcessLookupError()
+
+ terminate_command(mock_command)
+ mock_command.process.signal_group.assert_called_once()
+ mock_command.is_alive.assert_not_called()
+
+
+def test_terminate_command() -> None:
+ """Test command termination."""
+ mock_command = mock.MagicMock()
+ mock_command.is_alive.return_value = False
+
+ terminate_command(mock_command)
+ mock_command.process.signal_group.assert_called_once()
+
+
+def test_terminate_command_case1() -> None:
+ """Test command termination when it takes time.."""
+ mock_command = mock.MagicMock()
+ mock_command.is_alive.side_effect = [True, True, False]
+
+ terminate_command(mock_command, wait_period=0.1)
+ mock_command.process.signal_group.assert_called_once()
+ assert mock_command.is_alive.call_count == 3
+
+
+def test_terminate_command_case2() -> None:
+ """Test command termination when it takes much time.."""
+ mock_command = mock.MagicMock()
+ mock_command.is_alive.side_effect = [True, True, True]
+
+ terminate_command(mock_command, number_of_attempts=3, wait_period=0.1)
+ assert mock_command.is_alive.call_count == 3
+ assert mock_command.process.signal_group.call_count == 2
+
+
+class TestRunAndWait:
+ """Test run_and_wait function."""
+
+ @pytest.fixture(autouse=True)
+ def setup_method(self, monkeypatch: Any) -> None:
+ """Init test method."""
+ self.execute_command_mock = mock.MagicMock()
+ monkeypatch.setattr(
+ "aiet.utils.proc.execute_command", self.execute_command_mock
+ )
+
+ self.terminate_command_mock = mock.MagicMock()
+ monkeypatch.setattr(
+ "aiet.utils.proc.terminate_command", self.terminate_command_mock
+ )
+
+ def test_if_execute_command_raises_exception(self) -> None:
+ """Test if execute_command fails."""
+ self.execute_command_mock.side_effect = Exception("Error!")
+ with pytest.raises(Exception, match="Error!"):
+ run_and_wait("command", Path.cwd())
+
+ def test_if_command_finishes_with_error(self) -> None:
+ """Test if command finishes with error."""
+ cmd_mock = mock.MagicMock()
+ self.execute_command_mock.return_value = cmd_mock
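+ # Accessing cmd_mock.exit_code raises ErrorReturnCode, simulating a failed command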
+ exit_code_mock = mock.PropertyMock(
+ side_effect=ErrorReturnCode("cmd", bytearray(), bytearray())
+ )
+ type(cmd_mock).exit_code = exit_code_mock
+
+ with pytest.raises(CommandFailedException):
+ run_and_wait("command", Path.cwd())
+
+ @pytest.mark.parametrize("terminate_on_error, call_count", ((False, 0), (True, 1)))
+ def test_if_command_finishes_with_exception(
+ self, terminate_on_error: bool, call_count: int
+ ) -> None:
+ """Test if command finishes with error."""
+ cmd_mock = mock.MagicMock()
+ self.execute_command_mock.return_value = cmd_mock
+ exit_code_mock = mock.PropertyMock(side_effect=Exception("Error!"))
+ type(cmd_mock).exit_code = exit_code_mock
+
+ with pytest.raises(Exception, match="Error!"):
+ run_and_wait("command", Path.cwd(), terminate_on_error=terminate_on_error)
+
+ assert self.terminate_command_mock.call_count == call_count
+
+
+def test_save_process_info_no_process(monkeypatch: Any, tmpdir: Any) -> None:
+ """Test save_process_info function."""
+ mock_process = mock.MagicMock()
+ monkeypatch.setattr("psutil.Process", mock.MagicMock(return_value=mock_process))
+ mock_process.children.side_effect = psutil.NoSuchProcess(555)
+
+ pid_file_path = Path(tmpdir) / "test.pid"
+ save_process_info(555, pid_file_path)
+ assert not pid_file_path.exists()
+
+
+def test_parse_command() -> None:
+ """Test parse_command function."""
+ assert parse_command("1.sh") == ["bash", "1.sh"]
+ assert parse_command("1.sh", shell="sh") == ["sh", "1.sh"]
+ assert parse_command("command") == ["command"]
+ assert parse_command("command 123 --param=1") == ["command", "123", "--param=1"]
diff --git a/tests/conftest.py b/tests/conftest.py
new file mode 100644
index 0000000..5c6156c
--- /dev/null
+++ b/tests/conftest.py
@@ -0,0 +1,95 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Pytest conf module."""
+import shutil
+from pathlib import Path
+from typing import Generator
+
+import pytest
+import tensorflow as tf
+
+from mlia.devices.ethosu.config import EthosUConfiguration
+from mlia.nn.tensorflow.utils import convert_to_tflite
+from mlia.nn.tensorflow.utils import save_keras_model
+from mlia.nn.tensorflow.utils import save_tflite_model
+from mlia.tools.vela_wrapper import optimize_model
+
+
+def get_test_keras_model() -> tf.keras.Model:
+ """Return test Keras model."""
+ model = tf.keras.Sequential(
+ [
+ tf.keras.Input(shape=(28, 28, 1), batch_size=1, name="input"),
+ tf.keras.layers.Reshape((28, 28, 1)),
+ tf.keras.layers.Conv2D(
+ filters=12, kernel_size=(3, 3), activation="relu", name="conv1"
+ ),
+ tf.keras.layers.Conv2D(
+ filters=12, kernel_size=(3, 3), activation="relu", name="conv2"
+ ),
+ tf.keras.layers.MaxPool2D(2, 2),
+ tf.keras.layers.Flatten(),
+ tf.keras.layers.Dense(10, name="output"),
+ ]
+ )
+
+ model.compile(optimizer="sgd", loss="mean_squared_error")
+ return model
+
+
+@pytest.fixture(scope="session", name="test_models_path")
+def fixture_test_models_path(
+ tmp_path_factory: pytest.TempPathFactory,
+) -> Generator[Path, None, None]:
+ """Provide path to the test models."""
+ tmp_path = tmp_path_factory.mktemp("models")
+
+ keras_model = get_test_keras_model()
+ save_keras_model(keras_model, tmp_path / "test_model.h5")
+
+ tflite_model = convert_to_tflite(keras_model, quantized=True)
+ tflite_model_path = tmp_path / "test_model.tflite"
+ save_tflite_model(tflite_model, tflite_model_path)
+
+ tflite_vela_model = tmp_path / "test_model_vela.tflite"
+ device = EthosUConfiguration("ethos-u55-256")
+ optimize_model(tflite_model_path, device.compiler_options, tflite_vela_model)
+
+ tf.saved_model.save(keras_model, str(tmp_path / "tf_model_test_model"))
+
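+ # An empty file is enough to act as an invalid TFLite model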
+ invalid_tflite_model = tmp_path / "invalid.tflite"
+ invalid_tflite_model.touch()
+
+ yield tmp_path
+
+ shutil.rmtree(tmp_path)
+
+
+@pytest.fixture(scope="session", name="test_keras_model")
+def fixture_test_keras_model(test_models_path: Path) -> Path:
+ """Return test Keras model."""
+ return test_models_path / "test_model.h5"
+
+
+@pytest.fixture(scope="session", name="test_tflite_model")
+def fixture_test_tflite_model(test_models_path: Path) -> Path:
+ """Return test TFLite model."""
+ return test_models_path / "test_model.tflite"
+
+
+@pytest.fixture(scope="session", name="test_tflite_vela_model")
+def fixture_test_tflite_vela_model(test_models_path: Path) -> Path:
+ """Return test Vela-optimized TFLite model."""
+ return test_models_path / "test_model_vela.tflite"
+
+
+@pytest.fixture(scope="session", name="test_tf_model")
+def fixture_test_tf_model(test_models_path: Path) -> Path:
+ """Return test TFLite model."""
+ return test_models_path / "tf_model_test_model"
+
+
+@pytest.fixture(scope="session", name="test_tflite_invalid_model")
+def fixture_test_tflite_invalid_model(test_models_path: Path) -> Path:
+ """Return test invalid TFLite model."""
+ return test_models_path / "invalid.tflite"
diff --git a/tests/mlia/__init__.py b/tests/mlia/__init__.py
new file mode 100644
index 0000000..0687f14
--- /dev/null
+++ b/tests/mlia/__init__.py
@@ -0,0 +1,3 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""MLIA tests module."""
diff --git a/tests/mlia/conftest.py b/tests/mlia/conftest.py
new file mode 100644
index 0000000..f683fca
--- /dev/null
+++ b/tests/mlia/conftest.py
@@ -0,0 +1,20 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Pytest conf module."""
+from pathlib import Path
+
+import pytest
+
+from mlia.core.context import ExecutionContext
+
+
+@pytest.fixture(scope="session", name="test_resources_path")
+def fixture_test_resources_path() -> Path:
+ """Return test resources path."""
+ return Path(__file__).parent / "test_resources"
+
+
+@pytest.fixture(name="dummy_context")
+def fixture_dummy_context(tmpdir: str) -> ExecutionContext:
+ """Return dummy context fixture."""
+ return ExecutionContext(working_dir=tmpdir)
diff --git a/tests/mlia/test_api.py b/tests/mlia/test_api.py
new file mode 100644
index 0000000..54d4796
--- /dev/null
+++ b/tests/mlia/test_api.py
@@ -0,0 +1,96 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Tests for the API functions."""
+from pathlib import Path
+from unittest.mock import MagicMock
+
+import pytest
+
+from mlia.api import get_advice
+from mlia.core.common import AdviceCategory
+from mlia.core.context import Context
+from mlia.core.context import ExecutionContext
+
+
+def test_get_advice_no_target_provided(test_keras_model: Path) -> None:
+ """Test getting advice when no target provided."""
+ with pytest.raises(Exception, match="Target is not provided"):
+ get_advice(None, test_keras_model, "all") # type: ignore
+
+
+def test_get_advice_wrong_category(test_keras_model: Path) -> None:
+ """Test getting advice when wrong advice category provided."""
+ with pytest.raises(Exception, match="Invalid advice category unknown"):
+ get_advice("ethos-u55-256", test_keras_model, "unknown") # type: ignore
+
+
+@pytest.mark.parametrize(
+ "category, context, expected_category",
+ [
+ [
+ "all",
+ None,
+ AdviceCategory.ALL,
+ ],
+ [
+ "optimization",
+ None,
+ AdviceCategory.OPTIMIZATION,
+ ],
+ [
+ "operators",
+ None,
+ AdviceCategory.OPERATORS,
+ ],
+ [
+ "performance",
+ None,
+ AdviceCategory.PERFORMANCE,
+ ],
+ [
+ "all",
+ ExecutionContext(),
+ AdviceCategory.ALL,
+ ],
+ [
+ "all",
+ ExecutionContext(advice_category=AdviceCategory.PERFORMANCE),
+ AdviceCategory.PERFORMANCE,
+ ],
+ [
+ "all",
+ ExecutionContext(config_parameters={"param": "value"}),
+ AdviceCategory.ALL,
+ ],
+ [
+ "all",
+ ExecutionContext(event_handlers=[MagicMock()]),
+ AdviceCategory.ALL,
+ ],
+ ],
+)
+def test_get_advice(
+ monkeypatch: pytest.MonkeyPatch,
+ category: str,
+ context: ExecutionContext,
+ expected_category: AdviceCategory,
+ test_keras_model: Path,
+) -> None:
+ """Test getting advice with valid parameters."""
+ advisor_mock = MagicMock()
+ monkeypatch.setattr("mlia.api._get_advisor", MagicMock(return_value=advisor_mock))
+
+ get_advice(
+ "ethos-u55-256",
+ test_keras_model,
+ category, # type: ignore
+ context=context,
+ )
+
+ advisor_mock.run.assert_called_once()
+ context = advisor_mock.run.mock_calls[0].args[0]
+ assert isinstance(context, Context)
+ assert context.advice_category == expected_category
+
+ assert context.event_handlers is not None
+ assert context.config_parameters is not None
diff --git a/tests/mlia/test_cli_commands.py b/tests/mlia/test_cli_commands.py
new file mode 100644
index 0000000..bf17339
--- /dev/null
+++ b/tests/mlia/test_cli_commands.py
@@ -0,0 +1,204 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Tests for cli.commands module."""
+from pathlib import Path
+from typing import Any
+from typing import Optional
+from unittest.mock import call
+from unittest.mock import MagicMock
+
+import pytest
+
+from mlia.cli.commands import backend
+from mlia.cli.commands import operators
+from mlia.cli.commands import optimization
+from mlia.cli.commands import performance
+from mlia.core.context import ExecutionContext
+from mlia.devices.ethosu.config import EthosUConfiguration
+from mlia.devices.ethosu.performance import MemoryUsage
+from mlia.devices.ethosu.performance import NPUCycles
+from mlia.devices.ethosu.performance import PerformanceMetrics
+from mlia.tools.metadata.common import InstallationManager
+
+
+def test_operators_expected_parameters(dummy_context: ExecutionContext) -> None:
+ """Test operators command wrong parameters."""
+ with pytest.raises(Exception, match="Model is not provided"):
+ operators(dummy_context, "ethos-u55-256")
+
+
+def test_performance_unknown_target(
+ dummy_context: ExecutionContext, test_tflite_model: Path
+) -> None:
+ """Test that command should fail if unknown target passed."""
+ with pytest.raises(Exception, match="Unable to find target profile unknown"):
+ performance(
+ dummy_context, model=str(test_tflite_model), target_profile="unknown"
+ )
+
+
+@pytest.mark.parametrize(
+ "target_profile, optimization_type, optimization_target, expected_error",
+ [
+ [
+ "ethos-u55-256",
+ None,
+ "0.5",
+ pytest.raises(Exception, match="Optimization type is not provided"),
+ ],
+ [
+ "ethos-u65-512",
+ "unknown",
+ "16",
+ pytest.raises(Exception, match="Unsupported optimization type: unknown"),
+ ],
+ [
+ "ethos-u55-256",
+ "pruning",
+ None,
+ pytest.raises(Exception, match="Optimization target is not provided"),
+ ],
+ [
+ "ethos-u65-512",
+ "clustering",
+ None,
+ pytest.raises(Exception, match="Optimization target is not provided"),
+ ],
+ [
+ "unknown",
+ "clustering",
+ "16",
+ pytest.raises(Exception, match="Unable to find target profile unknown"),
+ ],
+ ],
+)
+def test_opt_expected_parameters(
+ dummy_context: ExecutionContext,
+ target_profile: str,
+ monkeypatch: pytest.MonkeyPatch,
+ optimization_type: str,
+ optimization_target: str,
+ expected_error: Any,
+ test_keras_model: Path,
+) -> None:
+ """Test that command should fail if no or unknown optimization type provided."""
+ mock_performance_estimation(monkeypatch)
+
+ with expected_error:
+ optimization(
+ ctx=dummy_context,
+ target_profile=target_profile,
+ model=str(test_keras_model),
+ optimization_type=optimization_type,
+ optimization_target=optimization_target,
+ )
+
+
+@pytest.mark.parametrize(
+ "target_profile, optimization_type, optimization_target",
+ [
+ ["ethos-u55-256", "pruning", "0.5"],
+ ["ethos-u65-512", "clustering", "32"],
+ ["ethos-u55-256", "pruning,clustering", "0.5,32"],
+ ],
+)
+def test_opt_valid_optimization_target(
+ target_profile: str,
+ dummy_context: ExecutionContext,
+ optimization_type: str,
+ optimization_target: str,
+ monkeypatch: pytest.MonkeyPatch,
+ test_keras_model: Path,
+) -> None:
+ """Test that command should not fail with valid optimization targets."""
+ mock_performance_estimation(monkeypatch)
+
+ optimization(
+ ctx=dummy_context,
+ target_profile=target_profile,
+ model=str(test_keras_model),
+ optimization_type=optimization_type,
+ optimization_target=optimization_target,
+ )
+
+
+def mock_performance_estimation(monkeypatch: pytest.MonkeyPatch) -> None:
+ """Mock performance estimation."""
+ metrics = PerformanceMetrics(
+ EthosUConfiguration("ethos-u55-256"),
+ NPUCycles(1, 2, 3, 4, 5, 6),
+ MemoryUsage(1, 2, 3, 4, 5),
+ )
+ monkeypatch.setattr(
+ "mlia.devices.ethosu.data_collection.EthosUPerformanceEstimator.estimate",
+ MagicMock(return_value=metrics),
+ )
+
+
+@pytest.fixture(name="installation_manager_mock")
+def fixture_mock_installation_manager(monkeypatch: pytest.MonkeyPatch) -> MagicMock:
+ """Mock installation manager."""
+ install_manager_mock = MagicMock(spec=InstallationManager)
+ monkeypatch.setattr(
+ "mlia.cli.commands.get_installation_manager",
+ MagicMock(return_value=install_manager_mock),
+ )
+ return install_manager_mock
+
+
+def test_backend_command_action_status(installation_manager_mock: MagicMock) -> None:
+ """Test backend command "status"."""
+ backend(backend_action="status")
+
+ installation_manager_mock.show_env_details.assert_called_once()
+
+
+@pytest.mark.parametrize(
+ "i_agree_to_the_contained_eula, backend_name, expected_calls",
+ [
+ [False, None, [call(None, True)]],
+ [True, None, [call(None, False)]],
+ [False, "backend_name", [call("backend_name", True)]],
+ [True, "backend_name", [call("backend_name", False)]],
+ ],
+)
+def test_backend_command_action_add_download(
+ installation_manager_mock: MagicMock,
+ i_agree_to_the_contained_eula: bool,
+ backend_name: Optional[str],
+ expected_calls: Any,
+) -> None:
+ """Test backend command "install" with download option."""
+ backend(
+ backend_action="install",
+ download=True,
+ name=backend_name,
+ i_agree_to_the_contained_eula=i_agree_to_the_contained_eula,
+ )
+
+ assert installation_manager_mock.download_and_install.mock_calls == expected_calls
+
+
+@pytest.mark.parametrize("backend_name", [None, "backend_name"])
+def test_backend_command_action_install_from_path(
+ installation_manager_mock: MagicMock,
+ tmp_path: Path,
+ backend_name: Optional[str],
+) -> None:
+ """Test backend command "install" with backend path."""
+ backend(backend_action="install", path=tmp_path, name=backend_name)
+
+ installation_manager_mock.install_from.assert_called_once_with(tmp_path, backend_name)
+
+
+def test_backend_command_action_install_only_one_action(
+ installation_manager_mock: MagicMock, # pylint: disable=unused-argument
+ tmp_path: Path,
+) -> None:
+ """Test that only one of action type allowed."""
+ with pytest.raises(
+ Exception,
+ match="Please select only one action: download or "
+ "provide path to the backend installation",
+ ):
+ backend(backend_action="install", download=True, path=tmp_path)
diff --git a/tests/mlia/test_cli_config.py b/tests/mlia/test_cli_config.py
new file mode 100644
index 0000000..6d19eec
--- /dev/null
+++ b/tests/mlia/test_cli_config.py
@@ -0,0 +1,49 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Tests for cli.config module."""
+from typing import List
+from unittest.mock import MagicMock
+
+import pytest
+
+from mlia.cli.config import get_default_backends
+from mlia.cli.config import is_corstone_backend
+
+
+@pytest.mark.parametrize(
+ "available_backends, expected_default_backends",
+ [
+ [["Vela"], ["Vela"]],
+ [["Corstone-300"], ["Corstone-300"]],
+ [["Corstone-310"], ["Corstone-310"]],
+ [["Corstone-300", "Corstone-310"], ["Corstone-310"]],
+ [["Vela", "Corstone-300", "Corstone-310"], ["Vela", "Corstone-310"]],
+ [
+ ["Vela", "Corstone-300", "Corstone-310", "New backend"],
+ ["Vela", "Corstone-310", "New backend"],
+ ],
+ [
+ ["Vela", "Corstone-300", "New backend"],
+ ["Vela", "Corstone-300", "New backend"],
+ ],
+ ],
+)
+def test_get_default_backends(
+ monkeypatch: pytest.MonkeyPatch,
+ available_backends: List[str],
+ expected_default_backends: List[str],
+) -> None:
+ """Test function get_default backends."""
+ monkeypatch.setattr(
+ "mlia.cli.config.get_available_backends",
+ MagicMock(return_value=available_backends),
+ )
+
+ assert get_default_backends() == expected_default_backends
+
+
+def test_is_corstone_backend() -> None:
+ """Test function is_corstone_backend."""
+ assert is_corstone_backend("Corstone-300") is True
+ assert is_corstone_backend("Corstone-310") is True
+ assert is_corstone_backend("New backend") is False
diff --git a/tests/mlia/test_cli_helpers.py b/tests/mlia/test_cli_helpers.py
new file mode 100644
index 0000000..2c52885
--- /dev/null
+++ b/tests/mlia/test_cli_helpers.py
@@ -0,0 +1,165 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Tests for the helper classes."""
+from typing import Any
+from typing import Dict
+from typing import List
+
+import pytest
+
+from mlia.cli.helpers import CLIActionResolver
+from mlia.nn.tensorflow.optimizations.select import OptimizationSettings
+
+
+class TestCliActionResolver:
+ """Test cli action resolver."""
+
+ @staticmethod
+ @pytest.mark.parametrize(
+ "args, params, expected_result",
+ [
+ [
+ {},
+ {"opt_settings": "some_setting"},
+ [],
+ ],
+ [
+ {},
+ {},
+ [
+ "Note: you will need a Keras model for that.",
+ "For example: mlia optimization --optimization-type "
+ "pruning,clustering --optimization-target 0.5,32 "
+ "/path/to/keras_model",
+ "For more info: mlia optimization --help",
+ ],
+ ],
+ [
+ {"model": "model.h5"},
+ {},
+ [
+ "For example: mlia optimization --optimization-type "
+ "pruning,clustering --optimization-target 0.5,32 model.h5",
+ "For more info: mlia optimization --help",
+ ],
+ ],
+ [
+ {"model": "model.h5"},
+ {"opt_settings": [OptimizationSettings("pruning", 0.5, None)]},
+ [
+ "For more info: mlia optimization --help",
+ "Optimization command: "
+ "mlia optimization --optimization-type pruning "
+ "--optimization-target 0.5 model.h5",
+ ],
+ ],
+ [
+ {"model": "model.h5", "target_profile": "target_profile"},
+ {"opt_settings": [OptimizationSettings("pruning", 0.5, None)]},
+ [
+ "For more info: mlia optimization --help",
+ "Optimization command: "
+ "mlia optimization --optimization-type pruning "
+ "--optimization-target 0.5 "
+ "--target-profile target_profile model.h5",
+ ],
+ ],
+ ],
+ )
+ def test_apply_optimizations(
+ args: Dict[str, Any],
+ params: Dict[str, Any],
+ expected_result: List[str],
+ ) -> None:
+ """Test action resolving for applying optimizations."""
+ resolver = CLIActionResolver(args)
+ assert resolver.apply_optimizations(**params) == expected_result
+
+ @staticmethod
+ def test_supported_operators_info() -> None:
+ """Test supported operators info."""
+ resolver = CLIActionResolver({})
+ assert resolver.supported_operators_info() == [
+ "For guidance on supported operators, run: mlia operators "
+ "--supported-ops-report",
+ ]
+
+ @staticmethod
+ def test_operator_compatibility_details() -> None:
+ """Test operator compatibility details info."""
+ resolver = CLIActionResolver({})
+ assert resolver.operator_compatibility_details() == [
+ "For more details, run: mlia operators --help"
+ ]
+
+ @staticmethod
+ def test_optimization_details() -> None:
+ """Test optimization details info."""
+ resolver = CLIActionResolver({})
+ assert resolver.optimization_details() == [
+ "For more info, see: mlia optimization --help"
+ ]
+
+ @staticmethod
+ @pytest.mark.parametrize(
+ "args, expected_result",
+ [
+ [
+ {},
+ [],
+ ],
+ [
+ {"model": "model.tflite"},
+ [
+ "Check the estimated performance by running the "
+ "following command: ",
+ "mlia performance model.tflite",
+ ],
+ ],
+ [
+ {"model": "model.tflite", "target_profile": "target_profile"},
+ [
+ "Check the estimated performance by running the "
+ "following command: ",
+ "mlia performance --target-profile target_profile model.tflite",
+ ],
+ ],
+ ],
+ )
+ def test_check_performance(
+ args: Dict[str, Any], expected_result: List[str]
+ ) -> None:
+ """Test check performance info."""
+ resolver = CLIActionResolver(args)
+ assert resolver.check_performance() == expected_result
+
+ @staticmethod
+ @pytest.mark.parametrize(
+ "args, expected_result",
+ [
+ [
+ {},
+ [],
+ ],
+ [
+ {"model": "model.tflite"},
+ [
+ "Try running the following command to verify that:",
+ "mlia operators model.tflite",
+ ],
+ ],
+ [
+ {"model": "model.tflite", "target_profile": "target_profile"},
+ [
+ "Try running the following command to verify that:",
+ "mlia operators --target-profile target_profile model.tflite",
+ ],
+ ],
+ ],
+ )
+ def test_check_operator_compatibility(
+ args: Dict[str, Any], expected_result: List[str]
+ ) -> None:
+ """Test checking operator compatibility info."""
+ resolver = CLIActionResolver(args)
+ assert resolver.check_operator_compatibility() == expected_result
diff --git a/tests/mlia/test_cli_logging.py b/tests/mlia/test_cli_logging.py
new file mode 100644
index 0000000..7c5f299
--- /dev/null
+++ b/tests/mlia/test_cli_logging.py
@@ -0,0 +1,104 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Tests for the module cli.logging."""
+import logging
+from pathlib import Path
+from typing import Optional
+
+import pytest
+
+from mlia.cli.logging import setup_logging
+from tests.mlia.utils.logging import clear_loggers
+
+
+def teardown_function() -> None:
+ """Perform action after test completion.
+
+ This function is launched automatically by pytest after each test
+ in this module.
+ """
+ clear_loggers()
+
+
+@pytest.mark.parametrize(
+ "logs_dir, verbose, expected_output, expected_log_file_content",
+ [
+ (
+ None,
+ None,
+ "cli info\n",
+ None,
+ ),
+ (
+ None,
+ True,
+ """mlia.tools.aiet_wrapper - aiet debug
+cli info
+mlia.cli - cli debug
+""",
+ None,
+ ),
+ (
+ "logs",
+ True,
+ """mlia.tools.aiet_wrapper - aiet debug
+cli info
+mlia.cli - cli debug
+""",
+ """mlia.tools.aiet_wrapper - DEBUG - aiet debug
+mlia.cli - DEBUG - cli debug
+""",
+ ),
+ ],
+)
+def test_setup_logging(
+ tmp_path: Path,
+ capfd: pytest.CaptureFixture,
+ logs_dir: str,
+ verbose: bool,
+ expected_output: str,
+ expected_log_file_content: str,
+) -> None:
+ """Test function setup_logging."""
+ logs_dir_path = tmp_path / logs_dir if logs_dir else None
+
+ setup_logging(logs_dir_path, verbose)
+
+ aiet_logger = logging.getLogger("mlia.tools.aiet_wrapper")
+ aiet_logger.debug("aiet debug")
+
+ cli_logger = logging.getLogger("mlia.cli")
+ cli_logger.info("cli info")
+ cli_logger.debug("cli debug")
+
+ stdout, _ = capfd.readouterr()
+ assert stdout == expected_output
+
+ check_log_assertions(logs_dir_path, expected_log_file_content)
+
+
+def check_log_assertions(
+ logs_dir_path: Optional[Path], expected_log_file_content: str
+) -> None:
+ """Test assertions for log file."""
+ if logs_dir_path is not None:
+ assert logs_dir_path.is_dir()
+
+ items = list(logs_dir_path.iterdir())
+ assert len(items) == 1
+
+ log_file_path = items[0]
+ assert log_file_path.is_file()
+
+ log_file_name = log_file_path.name
+ assert log_file_name == "mlia.log"
+
+ with open(log_file_path, encoding="utf-8") as log_file:
+ log_content = log_file.read()
+
+ expected_lines = expected_log_file_content.split("\n")
+ produced_lines = log_content.split("\n")
+
+ assert len(expected_lines) == len(produced_lines)
+ for expected, produced in zip(expected_lines, produced_lines):
+ assert expected in produced
diff --git a/tests/mlia/test_cli_main.py b/tests/mlia/test_cli_main.py
new file mode 100644
index 0000000..a0937d5
--- /dev/null
+++ b/tests/mlia/test_cli_main.py
@@ -0,0 +1,357 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Tests for main module."""
+import argparse
+from functools import wraps
+from pathlib import Path
+from typing import Any
+from typing import Callable
+from typing import List
+from unittest.mock import ANY
+from unittest.mock import call
+from unittest.mock import MagicMock
+
+import pytest
+
+import mlia
+from mlia.cli.main import CommandInfo
+from mlia.cli.main import main
+from mlia.core.context import ExecutionContext
+from tests.mlia.utils.logging import clear_loggers
+
+
+def teardown_function() -> None:
+ """Perform action after test completion.
+
+ This function is launched automatically by pytest after each test
+ in this module.
+ """
+ clear_loggers()
+
+
+def test_option_version(capfd: pytest.CaptureFixture) -> None:
+ """Test --version."""
+ with pytest.raises(SystemExit) as ex:
+ main(["--version"])
+
+ assert ex.type == SystemExit
+ assert ex.value.code == 0
+
+ stdout, stderr = capfd.readouterr()
+ assert len(stdout.splitlines()) == 1
+ assert stderr == ""
+
+
+@pytest.mark.parametrize(
+ "is_default, expected_command_help",
+ [(True, "Test command [default]"), (False, "Test command")],
+)
+def test_command_info(is_default: bool, expected_command_help: str) -> None:
+ """Test properties of CommandInfo object."""
+
+ def test_command() -> None:
+ """Test command."""
+
+ command_info = CommandInfo(test_command, ["test"], [], is_default)
+ assert command_info.command_name == "test_command"
+ assert command_info.command_name_and_aliases == ["test_command", "test"]
+ assert command_info.command_help == expected_command_help
+
+
+def test_default_command(monkeypatch: pytest.MonkeyPatch, tmp_path: Path) -> None:
+ """Test adding default command."""
+
+ def mock_command(
+ func_mock: MagicMock, name: str, with_working_dir: bool
+ ) -> Callable[..., None]:
+ """Mock cli command."""
+
+ def sample_cmd_1(*args: Any, **kwargs: Any) -> None:
+ """Sample command."""
+ func_mock(*args, **kwargs)
+
+ def sample_cmd_2(ctx: ExecutionContext, **kwargs: Any) -> None:
+ """Another sample command."""
+ func_mock(ctx=ctx, **kwargs)
+
+ ret_func = sample_cmd_2 if with_working_dir else sample_cmd_1
+ ret_func.__name__ = name
+
+ return ret_func # type: ignore
+
+ default_command = MagicMock()
+ non_default_command = MagicMock()
+
+ def default_command_params(parser: argparse.ArgumentParser) -> None:
+ """Add parameters for default command."""
+ parser.add_argument("--sample")
+ parser.add_argument("--default_arg", default="123")
+
+ def non_default_command_params(parser: argparse.ArgumentParser) -> None:
+ """Add parameters for non default command."""
+ parser.add_argument("--param")
+
+ monkeypatch.setattr(
+ "mlia.cli.main.get_commands",
+ MagicMock(
+ return_value=[
+ CommandInfo(
+ func=mock_command(default_command, "default_command", True),
+ aliases=["command1"],
+ opt_groups=[default_command_params],
+ is_default=True,
+ ),
+ CommandInfo(
+ func=mock_command(
+ non_default_command, "non_default_command", False
+ ),
+ aliases=["command2"],
+ opt_groups=[non_default_command_params],
+ is_default=False,
+ ),
+ ]
+ ),
+ )
+
+ tmp_working_dir = str(tmp_path)
+ main(["--working-dir", tmp_working_dir, "--sample", "1"])
+ main(["command2", "--param", "test"])
+
+ default_command.assert_called_once_with(ctx=ANY, sample="1", default_arg="123")
+ non_default_command.assert_called_once_with(param="test")
+
+
+@pytest.mark.parametrize(
+ "params, expected_call",
+ [
+ [
+ ["operators", "sample_model.tflite"],
+ call(
+ ctx=ANY,
+ target_profile="ethos-u55-256",
+ model="sample_model.tflite",
+ output=None,
+ supported_ops_report=False,
+ ),
+ ],
+ [
+ ["ops", "sample_model.tflite"],
+ call(
+ ctx=ANY,
+ target_profile="ethos-u55-256",
+ model="sample_model.tflite",
+ output=None,
+ supported_ops_report=False,
+ ),
+ ],
+ [
+ ["operators", "sample_model.tflite", "--target-profile", "ethos-u55-128"],
+ call(
+ ctx=ANY,
+ target_profile="ethos-u55-128",
+ model="sample_model.tflite",
+ output=None,
+ supported_ops_report=False,
+ ),
+ ],
+ [
+ ["operators"],
+ call(
+ ctx=ANY,
+ target_profile="ethos-u55-256",
+ model=None,
+ output=None,
+ supported_ops_report=False,
+ ),
+ ],
+ [
+ ["operators", "--supported-ops-report"],
+ call(
+ ctx=ANY,
+ target_profile="ethos-u55-256",
+ model=None,
+ output=None,
+ supported_ops_report=True,
+ ),
+ ],
+ [
+ [
+ "all_tests",
+ "sample_model.h5",
+ "--optimization-type",
+ "pruning",
+ "--optimization-target",
+ "0.5",
+ ],
+ call(
+ ctx=ANY,
+ target_profile="ethos-u55-256",
+ model="sample_model.h5",
+ optimization_type="pruning",
+ optimization_target="0.5",
+ output=None,
+ evaluate_on=["Vela"],
+ ),
+ ],
+ [
+ ["sample_model.h5"],
+ call(
+ ctx=ANY,
+ target_profile="ethos-u55-256",
+ model="sample_model.h5",
+ optimization_type="pruning,clustering",
+ optimization_target="0.5,32",
+ output=None,
+ evaluate_on=["Vela"],
+ ),
+ ],
+ [
+ ["performance", "sample_model.h5", "--output", "result.json"],
+ call(
+ ctx=ANY,
+ target_profile="ethos-u55-256",
+ model="sample_model.h5",
+ output="result.json",
+ evaluate_on=["Vela"],
+ ),
+ ],
+ [
+ ["perf", "sample_model.h5", "--target-profile", "ethos-u55-128"],
+ call(
+ ctx=ANY,
+ target_profile="ethos-u55-128",
+ model="sample_model.h5",
+ output=None,
+ evaluate_on=["Vela"],
+ ),
+ ],
+ [
+ ["optimization", "sample_model.h5"],
+ call(
+ ctx=ANY,
+ target_profile="ethos-u55-256",
+ model="sample_model.h5",
+ optimization_type="pruning,clustering",
+ optimization_target="0.5,32",
+ output=None,
+ evaluate_on=["Vela"],
+ ),
+ ],
+ [
+ ["optimization", "sample_model.h5", "--evaluate-on", "some_backend"],
+ call(
+ ctx=ANY,
+ target_profile="ethos-u55-256",
+ model="sample_model.h5",
+ optimization_type="pruning,clustering",
+ optimization_target="0.5,32",
+ output=None,
+ evaluate_on=["some_backend"],
+ ),
+ ],
+ ],
+)
+def test_commands_execution(
+ monkeypatch: pytest.MonkeyPatch, params: List[str], expected_call: Any
+) -> None:
+ """Test calling commands from the main function."""
+ mock = MagicMock()
+
+ def wrap_mock_command(command: Callable) -> Callable:
+ """Wrap the command with the mock."""
+
+ @wraps(command)
+ def mock_command(*args: Any, **kwargs: Any) -> Any:
+ """Mock the command."""
+ mock(*args, **kwargs)
+
+ return mock_command
+
+ monkeypatch.setattr(
+ "mlia.cli.options.get_default_backends", MagicMock(return_value=["Vela"])
+ )
+
+ monkeypatch.setattr(
+ "mlia.cli.options.get_available_backends",
+ MagicMock(return_value=["Vela", "some_backend"]),
+ )
+
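+    # Wrap the real command functions so that argument parsing still happens,
+    # while the shared mock records the keyword arguments each command receives.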
+ for command in ["all_tests", "operators", "performance", "optimization"]:
+ monkeypatch.setattr(
+ f"mlia.cli.main.{command}",
+ wrap_mock_command(getattr(mlia.cli.main, command)),
+ )
+
+ main(params)
+
+ mock.assert_called_once_with(*expected_call.args, **expected_call.kwargs)
+
+
+@pytest.mark.parametrize(
+ "verbose, exc_mock, expected_output",
+ [
+ [
+ True,
+ MagicMock(side_effect=Exception("Error")),
+ [
+ "Execution finished with error: Error",
+ f"Please check the log files in the {Path.cwd()/'mlia_output/logs'} "
+ "for more details",
+ ],
+ ],
+ [
+ False,
+ MagicMock(side_effect=Exception("Error")),
+ [
+ "Execution finished with error: Error",
+ f"Please check the log files in the {Path.cwd()/'mlia_output/logs'} "
+ "for more details, or enable verbose mode",
+ ],
+ ],
+ [
+ False,
+ MagicMock(side_effect=KeyboardInterrupt()),
+ ["Execution has been interrupted"],
+ ],
+ ],
+)
+def test_verbose_output(
+ monkeypatch: pytest.MonkeyPatch,
+ capsys: pytest.CaptureFixture,
+ verbose: bool,
+ exc_mock: MagicMock,
+ expected_output: List[str],
+) -> None:
+    """Test the --verbose flag."""
+
+ def command_params(parser: argparse.ArgumentParser) -> None:
+        """Add the --verbose option for the test command."""
+ parser.add_argument("--verbose", action="store_true")
+
+ def command() -> None:
+ """Run test command."""
+ exc_mock()
+
+ monkeypatch.setattr(
+ "mlia.cli.main.get_commands",
+ MagicMock(
+ return_value=[
+ CommandInfo(
+ func=command,
+ aliases=["command"],
+ opt_groups=[command_params],
+ ),
+ ]
+ ),
+ )
+
+ params = ["command"]
+ if verbose:
+ params.append("--verbose")
+
+ exit_code = main(params)
+ assert exit_code == 1
+
+ stdout, _ = capsys.readouterr()
+ for expected_message in expected_output:
+ assert expected_message in stdout
diff --git a/tests/mlia/test_cli_options.py b/tests/mlia/test_cli_options.py
new file mode 100644
index 0000000..a441e58
--- /dev/null
+++ b/tests/mlia/test_cli_options.py
@@ -0,0 +1,186 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Tests for module options."""
+import argparse
+from contextlib import ExitStack as does_not_raise
+from pathlib import Path
+from typing import Any
+from typing import Dict
+from typing import List
+from typing import Optional
+
+import pytest
+
+from mlia.cli.options import add_output_options
+from mlia.cli.options import get_target_profile_opts
+from mlia.cli.options import parse_optimization_parameters
+
+
+@pytest.mark.parametrize(
+ "optimization_type, optimization_target, expected_error, expected_result",
+ [
+ (
+ "pruning",
+ "0.5",
+ does_not_raise(),
+ [
+ dict(
+ optimization_type="pruning",
+ optimization_target=0.5,
+ layers_to_optimize=None,
+ )
+ ],
+ ),
+ (
+ "clustering",
+ "32",
+ does_not_raise(),
+ [
+ dict(
+ optimization_type="clustering",
+ optimization_target=32.0,
+ layers_to_optimize=None,
+ )
+ ],
+ ),
+ (
+ "pruning,clustering",
+ "0.5,32",
+ does_not_raise(),
+ [
+ dict(
+ optimization_type="pruning",
+ optimization_target=0.5,
+ layers_to_optimize=None,
+ ),
+ dict(
+ optimization_type="clustering",
+ optimization_target=32.0,
+ layers_to_optimize=None,
+ ),
+ ],
+ ),
+ (
+ "pruning, clustering",
+ "0.5, 32",
+ does_not_raise(),
+ [
+ dict(
+ optimization_type="pruning",
+ optimization_target=0.5,
+ layers_to_optimize=None,
+ ),
+ dict(
+ optimization_type="clustering",
+ optimization_target=32.0,
+ layers_to_optimize=None,
+ ),
+ ],
+ ),
+ (
+ "pruning,clustering",
+ "0.5",
+ pytest.raises(
+ Exception, match="Wrong number of optimization targets and types"
+ ),
+ None,
+ ),
+ (
+ "",
+ "0.5",
+ pytest.raises(Exception, match="Optimization type is not provided"),
+ None,
+ ),
+ (
+ "pruning,clustering",
+ "",
+ pytest.raises(Exception, match="Optimization target is not provided"),
+ None,
+ ),
+ (
+ "pruning,",
+ "0.5,abc",
+ pytest.raises(
+ Exception, match="Non numeric value for the optimization target"
+ ),
+ None,
+ ),
+ ],
+)
+def test_parse_optimization_parameters(
+ optimization_type: str,
+ optimization_target: str,
+ expected_error: Any,
+ expected_result: Any,
+) -> None:
+ """Test function parse_optimization_parameters."""
+ with expected_error:
+ result = parse_optimization_parameters(optimization_type, optimization_target)
+ assert result == expected_result
+
+
+@pytest.mark.parametrize(
+ "args, expected_opts",
+ [
+ [
+ {},
+ [],
+ ],
+ [
+ {"target_profile": "profile"},
+ ["--target-profile", "profile"],
+ ],
+ [
+            # an empty list should be returned for the default profile
+ {"target": "ethos-u55-256"},
+ [],
+ ],
+ ],
+)
+def test_get_target_opts(args: Optional[Dict], expected_opts: List[str]) -> None:
+ """Test getting target options."""
+ assert get_target_profile_opts(args) == expected_opts
+
+
+@pytest.mark.parametrize(
+ "output_parameters, expected_path",
+ [
+ [["--output", "report.json"], "report.json"],
+ [["--output", "REPORT.JSON"], "REPORT.JSON"],
+ [["--output", "some_folder/report.json"], "some_folder/report.json"],
+ [["--output", "report.csv"], "report.csv"],
+ [["--output", "REPORT.CSV"], "REPORT.CSV"],
+ [["--output", "some_folder/report.csv"], "some_folder/report.csv"],
+ ],
+)
+def test_output_options(output_parameters: List[str], expected_path: str) -> None:
+ """Test output options resolving."""
+ parser = argparse.ArgumentParser()
+ add_output_options(parser)
+
+ args = parser.parse_args(output_parameters)
+ assert args.output == expected_path
+
+
+@pytest.mark.parametrize(
+ "output_filename",
+ [
+ "report.txt",
+ "report.TXT",
+ "report",
+ "report.pdf",
+ ],
+)
+def test_output_options_bad_parameters(
+ output_filename: str, capsys: pytest.CaptureFixture
+) -> None:
+    """Test that argument parsing fails if the output format is not supported."""
+ parser = argparse.ArgumentParser()
+ add_output_options(parser)
+
+ with pytest.raises(SystemExit):
+ parser.parse_args(["--output", output_filename])
+
+ err_output = capsys.readouterr().err
+ suffix = Path(output_filename).suffix[1:]
+ assert f"Unsupported format '{suffix}'" in err_output
diff --git a/tests/mlia/test_core_advice_generation.py b/tests/mlia/test_core_advice_generation.py
new file mode 100644
index 0000000..05db698
--- /dev/null
+++ b/tests/mlia/test_core_advice_generation.py
@@ -0,0 +1,71 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Tests for module advice_generation."""
+from typing import List
+
+import pytest
+
+from mlia.core.advice_generation import Advice
+from mlia.core.advice_generation import advice_category
+from mlia.core.advice_generation import FactBasedAdviceProducer
+from mlia.core.common import AdviceCategory
+from mlia.core.common import DataItem
+from mlia.core.context import Context
+
+
+def test_advice_generation() -> None:
+ """Test advice generation."""
+
+ class SampleProducer(FactBasedAdviceProducer):
+ """Sample producer."""
+
+ def produce_advice(self, data_item: DataItem) -> None:
+ """Process data."""
+ self.add_advice([f"Advice for {data_item}"])
+
+ producer = SampleProducer()
+ producer.produce_advice(123)
+ producer.produce_advice("hello")
+
+ advice = producer.get_advice()
+ assert advice == [Advice(["Advice for 123"]), Advice(["Advice for hello"])]
+
+
+@pytest.mark.parametrize(
+ "category, expected_advice",
+ [
+ [
+ AdviceCategory.OPERATORS,
+ [Advice(["Good advice!"])],
+ ],
+ [
+ AdviceCategory.PERFORMANCE,
+ [],
+ ],
+ ],
+)
+def test_advice_category_decorator(
+ category: AdviceCategory,
+ expected_advice: List[Advice],
+ dummy_context: Context,
+) -> None:
+ """Test for advice_category decorator."""
+
+ class SampleAdviceProducer(FactBasedAdviceProducer):
+ """Sample advice producer."""
+
+ @advice_category(AdviceCategory.OPERATORS)
+ def produce_advice(self, data_item: DataItem) -> None:
+ """Produce the advice."""
+ self.add_advice(["Good advice!"])
+
+ producer = SampleAdviceProducer()
+ dummy_context.update(
+ advice_category=category, event_handlers=[], config_parameters={}
+ )
+ producer.set_context(dummy_context)
+
+ producer.produce_advice("some_data")
+ advice = producer.get_advice()
+
+ assert advice == expected_advice
diff --git a/tests/mlia/test_core_advisor.py b/tests/mlia/test_core_advisor.py
new file mode 100644
index 0000000..375ff62
--- /dev/null
+++ b/tests/mlia/test_core_advisor.py
@@ -0,0 +1,40 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Tests for module advisor."""
+from unittest.mock import MagicMock
+
+from mlia.core.advisor import InferenceAdvisor
+from mlia.core.context import Context
+from mlia.core.workflow import WorkflowExecutor
+
+
+def test_inference_advisor_run() -> None:
+ """Test running sample inference advisor."""
+ executor_mock = MagicMock(spec=WorkflowExecutor)
+ context_mock = MagicMock(spec=Context)
+
+ class SampleAdvisor(InferenceAdvisor):
+ """Sample inference advisor."""
+
+ @classmethod
+ def name(cls) -> str:
+ """Return name of the advisor."""
+ return "sample_advisor"
+
+ @classmethod
+ def description(cls) -> str:
+ """Return description of the advisor."""
+ return "Sample advisor"
+
+ @classmethod
+ def info(cls) -> None:
+ """Print advisor info."""
+
+ def configure(self, context: Context) -> WorkflowExecutor:
+ """Configure advisor."""
+ return executor_mock
+
+ advisor = SampleAdvisor()
+ advisor.run(context_mock)
+
+ executor_mock.run.assert_called_once()
diff --git a/tests/mlia/test_core_context.py b/tests/mlia/test_core_context.py
new file mode 100644
index 0000000..10015aa
--- /dev/null
+++ b/tests/mlia/test_core_context.py
@@ -0,0 +1,62 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Tests for the module context."""
+from pathlib import Path
+
+from mlia.core.common import AdviceCategory
+from mlia.core.context import ExecutionContext
+from mlia.core.events import DefaultEventPublisher
+
+
+def test_execution_context(tmpdir: str) -> None:
+ """Test execution context."""
+ publisher = DefaultEventPublisher()
+ category = AdviceCategory.OPERATORS
+
+ context = ExecutionContext(
+ advice_category=category,
+ config_parameters={"param": "value"},
+ working_dir=tmpdir,
+ event_handlers=[],
+ event_publisher=publisher,
+ verbose=True,
+ logs_dir="logs_directory",
+ models_dir="models_directory",
+ )
+
+ assert context.advice_category == category
+ assert context.config_parameters == {"param": "value"}
+ assert context.event_handlers == []
+ assert context.event_publisher == publisher
+ assert context.logs_path == Path(tmpdir) / "logs_directory"
+ expected_model_path = Path(tmpdir) / "models_directory/sample.model"
+ assert context.get_model_path("sample.model") == expected_model_path
+ assert context.verbose is True
+ assert str(context) == (
+ f"ExecutionContext: "
+ f"working_dir={tmpdir}, "
+ "advice_category=OPERATORS, "
+ "config_parameters={'param': 'value'}, "
+ "verbose=True"
+ )
+
+ context_with_default_params = ExecutionContext(working_dir=tmpdir)
+ assert context_with_default_params.advice_category is None
+ assert context_with_default_params.config_parameters is None
+ assert context_with_default_params.event_handlers is None
+ assert isinstance(
+ context_with_default_params.event_publisher, DefaultEventPublisher
+ )
+ assert context_with_default_params.logs_path == Path(tmpdir) / "logs"
+
+ default_model_path = context_with_default_params.get_model_path("sample.model")
+ expected_default_model_path = Path(tmpdir) / "models/sample.model"
+ assert default_model_path == expected_default_model_path
+
+ expected_str = (
+ f"ExecutionContext: working_dir={tmpdir}, "
+ "advice_category=<not set>, "
+ "config_parameters=None, "
+ "verbose=False"
+ )
+ assert str(context_with_default_params) == expected_str
diff --git a/tests/mlia/test_core_data_analysis.py b/tests/mlia/test_core_data_analysis.py
new file mode 100644
index 0000000..a782159
--- /dev/null
+++ b/tests/mlia/test_core_data_analysis.py
@@ -0,0 +1,31 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Tests for module data_analysis."""
+from dataclasses import dataclass
+
+from mlia.core.common import DataItem
+from mlia.core.data_analysis import Fact
+from mlia.core.data_analysis import FactExtractor
+
+
+def test_fact_extractor() -> None:
+ """Test fact extractor."""
+
+ @dataclass
+ class SampleFact(Fact):
+ """Sample fact."""
+
+ msg: str
+
+ class SampleExtractor(FactExtractor):
+ """Sample extractor."""
+
+ def analyze_data(self, data_item: DataItem) -> None:
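+            """Record a fact for the analyzed data item."""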
+ self.add_fact(SampleFact(f"Fact for {data_item}"))
+
+ extractor = SampleExtractor()
+ extractor.analyze_data(42)
+ extractor.analyze_data("some data")
+
+ facts = extractor.get_analyzed_data()
+ assert facts == [SampleFact("Fact for 42"), SampleFact("Fact for some data")]
diff --git a/tests/mlia/test_core_events.py b/tests/mlia/test_core_events.py
new file mode 100644
index 0000000..faaab7c
--- /dev/null
+++ b/tests/mlia/test_core_events.py
@@ -0,0 +1,155 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Tests for the module events."""
+from dataclasses import dataclass
+from unittest.mock import call
+from unittest.mock import MagicMock
+
+import pytest
+
+from mlia.core.events import action
+from mlia.core.events import ActionFinishedEvent
+from mlia.core.events import ActionStartedEvent
+from mlia.core.events import DebugEventHandler
+from mlia.core.events import DefaultEventPublisher
+from mlia.core.events import Event
+from mlia.core.events import EventDispatcher
+from mlia.core.events import EventHandler
+from mlia.core.events import ExecutionFinishedEvent
+from mlia.core.events import ExecutionStartedEvent
+from mlia.core.events import stage
+from mlia.core.events import SystemEventsHandler
+
+
+@dataclass
+class SampleEvent(Event):
+ """Sample event."""
+
+ msg: str
+
+
+def test_event_publisher() -> None:
+ """Test event publishing."""
+ publisher = DefaultEventPublisher()
+ handler_mock1 = MagicMock(spec=EventHandler)
+ handler_mock2 = MagicMock(spec=EventHandler)
+
+ publisher.register_event_handlers([handler_mock1, handler_mock2])
+
+ event = SampleEvent("hello, event!")
+ publisher.publish_event(event)
+
+ handler_mock1.handle_event.assert_called_once_with(event)
+ handler_mock2.handle_event.assert_called_once_with(event)
+
+
+def test_stage_context_manager() -> None:
+ """Test stage context manager."""
+ publisher = DefaultEventPublisher()
+
+ handler_mock = MagicMock(spec=EventHandler)
+ publisher.register_event_handler(handler_mock)
+
+ events = (SampleEvent("hello"), SampleEvent("goodbye"))
+ with stage(publisher, events):
+ print("perform actions")
+
+ assert handler_mock.handle_event.call_count == 2
+ calls = [call(event) for event in events]
+ handler_mock.handle_event.assert_has_calls(calls)
+
+
+def test_action_context_manager() -> None:
+ """Test action stage context manager."""
+ publisher = DefaultEventPublisher()
+
+ handler_mock = MagicMock(spec=EventHandler)
+ publisher.register_event_handler(handler_mock)
+
+ with action(publisher, "Sample action"):
+ print("perform actions")
+
+ assert handler_mock.handle_event.call_count == 2
+ calls = handler_mock.handle_event.mock_calls
+
+ action_started = calls[0].args[0]
+ action_finished = calls[1].args[0]
+
+ assert isinstance(action_started, ActionStartedEvent)
+ assert isinstance(action_finished, ActionFinishedEvent)
+
+ assert action_finished.parent_event_id == action_started.event_id
+
+
+def test_debug_event_handler(capsys: pytest.CaptureFixture) -> None:
+ """Test debugging event handler."""
+ publisher = DefaultEventPublisher()
+
+ publisher.register_event_handler(DebugEventHandler())
+ publisher.register_event_handler(DebugEventHandler(with_stacktrace=True))
+
+ messages = ["Sample event 1", "Sample event 2"]
+ for message in messages:
+ publisher.publish_event(SampleEvent(message))
+
+ captured = capsys.readouterr()
+ for message in messages:
+ assert message in captured.out
+
+ assert "traceback.print_stack" in captured.err
+
+
+def test_event_dispatcher(capsys: pytest.CaptureFixture) -> None:
+ """Test event dispatcher."""
+
+ class SampleEventHandler(EventDispatcher):
+ """Sample event handler."""
+
+ def on_sample_event( # pylint: disable=no-self-use
+ self, _event: SampleEvent
+ ) -> None:
+ """Event handler for SampleEvent."""
+ print("Got sample event")
+
+ publisher = DefaultEventPublisher()
+ publisher.register_event_handler(SampleEventHandler())
+ publisher.publish_event(SampleEvent("Sample event"))
+
+ captured = capsys.readouterr()
+ assert captured.out.strip() == "Got sample event"
+
+
+def test_system_events_handler(capsys: pytest.CaptureFixture) -> None:
+ """Test system events handler."""
+
+ class CustomSystemEventHandler(SystemEventsHandler):
+ """Custom system event handler."""
+
+ def on_execution_started(self, event: ExecutionStartedEvent) -> None:
+ """Handle ExecutionStarted event."""
+ print("Execution started")
+
+ def on_execution_finished(self, event: ExecutionFinishedEvent) -> None:
+ """Handle ExecutionFinished event."""
+ print("Execution finished")
+
+ publisher = DefaultEventPublisher()
+ publisher.register_event_handler(CustomSystemEventHandler())
+
+ publisher.publish_event(ExecutionStartedEvent())
+ publisher.publish_event(SampleEvent("Hello world!"))
+ publisher.publish_event(ExecutionFinishedEvent())
+
+ captured = capsys.readouterr()
+ assert captured.out.strip() == "Execution started\nExecution finished"
+
+
+def test_compare_without_id() -> None:
+ """Test event comparison without event_id."""
+ event1 = SampleEvent("message")
+ event2 = SampleEvent("message")
+
+ assert event1 != event2
+ assert event1.compare_without_id(event2)
+
+ assert not event1.compare_without_id("message") # type: ignore
diff --git a/tests/mlia/test_core_helpers.py b/tests/mlia/test_core_helpers.py
new file mode 100644
index 0000000..8577617
--- /dev/null
+++ b/tests/mlia/test_core_helpers.py
@@ -0,0 +1,17 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Tests for the helper classes."""
+from mlia.core.helpers import APIActionResolver
+
+
+def test_api_action_resolver() -> None:
+ """Test APIActionResolver class."""
+ helper = APIActionResolver()
+
+ # pylint: disable=use-implicit-booleaness-not-comparison
+ assert helper.apply_optimizations() == []
+ assert helper.supported_operators_info() == []
+ assert helper.check_performance() == []
+ assert helper.check_operator_compatibility() == []
+ assert helper.operator_compatibility_details() == []
+ assert helper.optimization_details() == []
diff --git a/tests/mlia/test_core_mixins.py b/tests/mlia/test_core_mixins.py
new file mode 100644
index 0000000..d66213d
--- /dev/null
+++ b/tests/mlia/test_core_mixins.py
@@ -0,0 +1,99 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Tests for the module mixins."""
+import pytest
+
+from mlia.core.common import AdviceCategory
+from mlia.core.context import Context
+from mlia.core.context import ExecutionContext
+from mlia.core.mixins import ContextMixin
+from mlia.core.mixins import ParameterResolverMixin
+
+
+def test_context_mixin(dummy_context: Context) -> None:
+ """Test ContextMixin."""
+
+ class SampleClass(ContextMixin):
+ """Sample class."""
+
+ sample_object = SampleClass()
+ sample_object.set_context(dummy_context)
+ assert sample_object.context == dummy_context
+
+
+class TestParameterResolverMixin:
+ """Tests for parameter resolver mixin."""
+
+ @staticmethod
+ def test_parameter_resolver_mixin(dummy_context: ExecutionContext) -> None:
+ """Test ParameterResolverMixin."""
+
+ class SampleClass(ParameterResolverMixin):
+ """Sample class."""
+
+ def __init__(self) -> None:
+ """Init sample object."""
+ self.context = dummy_context
+
+ self.context.update(
+ advice_category=AdviceCategory.OPERATORS,
+ event_handlers=[],
+ config_parameters={"section": {"param": 123}},
+ )
+
+ sample_object = SampleClass()
+ value = sample_object.get_parameter("section", "param")
+ assert value == 123
+
+ with pytest.raises(
+ Exception, match="Parameter param expected to have type <class 'str'>"
+ ):
+ value = sample_object.get_parameter("section", "param", expected_type=str)
+
+ with pytest.raises(Exception, match="Parameter no_param is not set"):
+ value = sample_object.get_parameter("section", "no_param")
+
+ @staticmethod
+ def test_parameter_resolver_mixin_no_config(
+ dummy_context: ExecutionContext,
+ ) -> None:
+ """Test ParameterResolverMixin without config params."""
+
+ class SampleClassNoConfig(ParameterResolverMixin):
+ """Sample context without config params."""
+
+ def __init__(self) -> None:
+ """Init sample object."""
+ self.context = dummy_context
+
+ with pytest.raises(Exception, match="Configuration parameters are not set"):
+ sample_object_no_config = SampleClassNoConfig()
+ sample_object_no_config.get_parameter("section", "param", expected_type=str)
+
+ @staticmethod
+ def test_parameter_resolver_mixin_bad_section(
+ dummy_context: ExecutionContext,
+ ) -> None:
+        """Test ParameterResolverMixin with a badly formatted config section."""
+
+ class SampleClassBadSection(ParameterResolverMixin):
+ """Sample context with bad section in config."""
+
+ def __init__(self) -> None:
+ """Init sample object."""
+ self.context = dummy_context
+ self.context.update(
+ advice_category=AdviceCategory.OPERATORS,
+ event_handlers=[],
+ config_parameters={"section": ["param"]},
+ )
+
+ with pytest.raises(
+ Exception,
+ match="Parameter section section has wrong format, "
+ "expected to be a dictionary",
+ ):
+ sample_object_bad_section = SampleClassBadSection()
+ sample_object_bad_section.get_parameter(
+ "section", "param", expected_type=str
+ )
diff --git a/tests/mlia/test_core_performance.py b/tests/mlia/test_core_performance.py
new file mode 100644
index 0000000..0d28fe8
--- /dev/null
+++ b/tests/mlia/test_core_performance.py
@@ -0,0 +1,29 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Tests for the module performance."""
+from pathlib import Path
+
+from mlia.core.performance import estimate_performance
+from mlia.core.performance import PerformanceEstimator
+
+
+def test_estimate_performance(tmp_path: Path) -> None:
+ """Test function estimate_performance."""
+ model_path = tmp_path / "original.tflite"
+
+ class SampleEstimator(PerformanceEstimator[Path, int]):
+ """Sample estimator."""
+
+ def estimate(self, model: Path) -> int:
+ """Estimate performance."""
+ if model.name == "original.tflite":
+ return 1
+
+ return 2
+
+ def optimized_model(_original: Path) -> Path:
+ """Return path to the 'optimized' model."""
+ return tmp_path / "optimized.tflite"
+
+ results = estimate_performance(model_path, SampleEstimator(), [optimized_model])
+ assert results == [1, 2]
diff --git a/tests/mlia/test_core_reporting.py b/tests/mlia/test_core_reporting.py
new file mode 100644
index 0000000..2f7ec22
--- /dev/null
+++ b/tests/mlia/test_core_reporting.py
@@ -0,0 +1,413 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Tests for reporting module."""
+from typing import List
+
+import pytest
+
+from mlia.core.reporting import BytesCell
+from mlia.core.reporting import Cell
+from mlia.core.reporting import ClockCell
+from mlia.core.reporting import Column
+from mlia.core.reporting import CyclesCell
+from mlia.core.reporting import Format
+from mlia.core.reporting import NestedReport
+from mlia.core.reporting import ReportItem
+from mlia.core.reporting import SingleRow
+from mlia.core.reporting import Table
+from mlia.utils.console import remove_ascii_codes
+
+
+@pytest.mark.parametrize(
+ "cell, expected_repr",
+ [
+ (BytesCell(None), ""),
+ (BytesCell(0), "0 bytes"),
+ (BytesCell(1), "1 byte"),
+ (BytesCell(100000), "100,000 bytes"),
+ (ClockCell(None), ""),
+ (ClockCell(0), "0 Hz"),
+ (ClockCell(1), "1 Hz"),
+ (ClockCell(100000), "100,000 Hz"),
+ (CyclesCell(None), ""),
+ (CyclesCell(0), "0 cycles"),
+ (CyclesCell(1), "1 cycle"),
+ (CyclesCell(100000), "100,000 cycles"),
+ ],
+)
+def test_predefined_cell_types(cell: Cell, expected_repr: str) -> None:
+ """Test predefined cell types."""
+ assert str(cell) == expected_repr
+
+
+@pytest.mark.parametrize(
+ "with_notes, expected_text_report",
+ [
+ [
+ True,
+ """
+Sample table:
+┌──────────┬──────────┬──────────┐
+│ Header 1 │ Header 2 │ Header 3 │
+╞══════════╪══════════╪══════════╡
+│ 1 │ 2 │ 3 │
+├──────────┼──────────┼──────────┤
+│ 4 │ 5 │ 123,123 │
+└──────────┴──────────┴──────────┘
+Sample notes
+ """.strip(),
+ ],
+ [
+ False,
+ """
+Sample table:
+┌──────────┬──────────┬──────────┐
+│ Header 1 │ Header 2 │ Header 3 │
+╞══════════╪══════════╪══════════╡
+│ 1 │ 2 │ 3 │
+├──────────┼──────────┼──────────┤
+│ 4 │ 5 │ 123,123 │
+└──────────┴──────────┴──────────┘
+ """.strip(),
+ ],
+ ],
+)
+def test_table_representation(with_notes: bool, expected_text_report: str) -> None:
+ """Test table report representation."""
+
+ def sample_table(with_notes: bool) -> Table:
+ columns = [
+ Column("Header 1", alias="header1", only_for=["plain_text"]),
+ Column("Header 2", alias="header2", fmt=Format(wrap_width=5)),
+ Column("Header 3", alias="header3"),
+ ]
+ rows = [(1, 2, 3), (4, 5, Cell(123123, fmt=Format(str_fmt="0,d")))]
+
+ return Table(
+ columns,
+ rows,
+ name="Sample table",
+ alias="sample_table",
+ notes="Sample notes" if with_notes else None,
+ )
+
+ table = sample_table(with_notes)
+ csv_repr = table.to_csv()
+ assert csv_repr == [["Header 2", "Header 3"], [2, 3], [5, 123123]]
+
+ json_repr = table.to_json()
+ assert json_repr == {
+ "sample_table": [
+ {"header2": 2, "header3": 3},
+ {"header2": 5, "header3": 123123},
+ ]
+ }
+
+ text_report = remove_ascii_codes(table.to_plain_text())
+ assert text_report == expected_text_report
+
+
+def test_csv_nested_table_representation() -> None:
+ """Test representation of the nested tables in csv format."""
+
+ def sample_table(num_of_cols: int) -> Table:
+ columns = [
+ Column("Header 1", alias="header1"),
+ Column("Header 2", alias="header2"),
+ ]
+
+ rows = [
+ (
+ 1,
+ Table(
+ columns=[
+ Column(f"Nested column {i+1}") for i in range(num_of_cols)
+ ],
+ rows=[[f"value{i+1}" for i in range(num_of_cols)]],
+ name="Nested table",
+ ),
+ )
+ ]
+
+ return Table(columns, rows, name="Sample table", alias="sample_table")
+
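+    # Nested table values are flattened into a single ';'-separated CSV cell.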
+ assert sample_table(num_of_cols=2).to_csv() == [
+ ["Header 1", "Header 2"],
+ [1, "value1;value2"],
+ ]
+
+ assert sample_table(num_of_cols=1).to_csv() == [
+ ["Header 1", "Header 2"],
+ [1, "value1"],
+ ]
+
+ assert sample_table(num_of_cols=0).to_csv() == [
+ ["Header 1", "Header 2"],
+ [1, ""],
+ ]
+
+
+@pytest.mark.parametrize(
+ "report, expected_plain_text, expected_json_data, expected_csv_data",
+ [
+ (
+ NestedReport(
+ "Sample report",
+ "sample_report",
+ [
+ ReportItem("Item", "item", "item_value"),
+ ],
+ ),
+ """
+Sample report:
+ Item item_value
+""".strip(),
+ {
+ "sample_report": {"item": "item_value"},
+ },
+ [
+ ("item",),
+ ("item_value",),
+ ],
+ ),
+ (
+ NestedReport(
+ "Sample report",
+ "sample_report",
+ [
+ ReportItem(
+ "Item",
+ "item",
+ "item_value",
+ [ReportItem("Nested item", "nested_item", "nested_item_value")],
+ ),
+ ],
+ ),
+ """
+Sample report:
+ Item item_value
+ Nested item nested_item_value
+""".strip(),
+ {
+ "sample_report": {
+ "item": {"nested_item": "nested_item_value"},
+ },
+ },
+ [
+ ("item", "nested_item"),
+ ("item_value", "nested_item_value"),
+ ],
+ ),
+ (
+ NestedReport(
+ "Sample report",
+ "sample_report",
+ [
+ ReportItem(
+ "Item",
+ "item",
+ "item_value",
+ [ReportItem("Nested item", "nested_item", BytesCell(10))],
+ ),
+ ],
+ ),
+ """
+Sample report:
+ Item item_value
+ Nested item 10 bytes
+""".strip(),
+ {
+ "sample_report": {
+ "item": {"nested_item": {"unit": "bytes", "value": 10}},
+ },
+ },
+ [
+ ("item", "nested_item_value", "nested_item_unit"),
+ ("item_value", 10, "bytes"),
+ ],
+ ),
+ (
+ NestedReport(
+ "Sample report",
+ "sample_report",
+ [
+ ReportItem(
+ "Item",
+ "item",
+ "item_value",
+ [
+ ReportItem(
+ "Nested item",
+ "nested_item",
+ Cell(
+ 10, fmt=Format(str_fmt=lambda x: f"{x} cell value")
+ ),
+ )
+ ],
+ ),
+ ],
+ ),
+ """
+Sample report:
+ Item item_value
+ Nested item 10 cell value
+""".strip(),
+ {
+ "sample_report": {
+ "item": {"nested_item": 10},
+ },
+ },
+ [
+ ("item", "nested_item"),
+ ("item_value", 10),
+ ],
+ ),
+ (
+ NestedReport(
+ "Sample report",
+ "sample_report",
+ [
+ ReportItem(
+ "Item",
+ "item",
+ "item_value",
+ [
+ ReportItem(
+ "Nested item",
+ "nested_item",
+ Cell(
+ 10, fmt=Format(str_fmt=lambda x: f"{x} cell value")
+ ),
+ )
+ ],
+ ),
+ ],
+ ),
+ """
+Sample report:
+ Item item_value
+ Nested item 10 cell value
+""".strip(),
+ {
+ "sample_report": {
+ "item": {"nested_item": 10},
+ },
+ },
+ [
+ ("item", "nested_item"),
+ ("item_value", 10),
+ ],
+ ),
+ (
+ NestedReport(
+ "Sample report",
+ "sample_report",
+ [
+ ReportItem(
+ "Item",
+ "item",
+ "item_value",
+ [
+ ReportItem("Nested item", "nested_item", Cell(10)),
+ ],
+ ),
+ ],
+ ),
+ """
+Sample report:
+ Item item_value
+ Nested item 10
+""".strip(),
+ {
+ "sample_report": {
+ "item": {"nested_item": 10},
+ },
+ },
+ [
+ ("item", "nested_item"),
+ ("item_value", 10),
+ ],
+ ),
+ (
+ NestedReport(
+ "Sample report",
+ "sample_report",
+ [
+ ReportItem(
+ "Item",
+ "item",
+ "item_value",
+ [
+ ReportItem(
+ "Nested item", "nested_item", Cell(10, fmt=Format())
+ ),
+ ],
+ ),
+ ],
+ ),
+ """
+Sample report:
+ Item item_value
+ Nested item 10
+""".strip(),
+ {
+ "sample_report": {
+ "item": {"nested_item": 10},
+ },
+ },
+ [
+ ("item", "nested_item"),
+ ("item_value", 10),
+ ],
+ ),
+ ],
+)
+def test_nested_report_representation(
+ report: NestedReport,
+ expected_plain_text: str,
+ expected_json_data: dict,
+ expected_csv_data: List,
+) -> None:
+ """Test representation of the NestedReport."""
+ plain_text = report.to_plain_text()
+ assert plain_text == expected_plain_text
+
+ json_data = report.to_json()
+ assert json_data == expected_json_data
+
+ csv_data = report.to_csv()
+ assert csv_data == expected_csv_data
+
+
+def test_single_row_representation() -> None:
+ """Test representation of the SingleRow."""
+ single_row = SingleRow(
+ columns=[
+ Column("column1", "column1"),
+ ],
+ rows=[("value1", "value2")],
+ name="Single row example",
+ alias="simple_row_example",
+ )
+
+ expected_text = """
+Single row example:
+ column1 value1
+""".strip()
+ assert single_row.to_plain_text() == expected_text
+ assert single_row.to_csv() == [["column1"], ["value1"]]
+ assert single_row.to_json() == {"simple_row_example": [{"column1": "value1"}]}
+
+ with pytest.raises(Exception, match="Table should have only one row"):
+ wrong_single_row = SingleRow(
+ columns=[
+ Column("column1", "column1"),
+ ],
+ rows=[
+ ("value1", "value2"),
+ ("value1", "value2"),
+ ],
+ name="Single row example",
+ alias="simple_row_example",
+ )
+ wrong_single_row.to_plain_text()
diff --git a/tests/mlia/test_core_workflow.py b/tests/mlia/test_core_workflow.py
new file mode 100644
index 0000000..470e572
--- /dev/null
+++ b/tests/mlia/test_core_workflow.py
@@ -0,0 +1,164 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Tests for module workflow."""
+from dataclasses import dataclass
+from unittest.mock import call
+from unittest.mock import MagicMock
+
+from mlia.core.advice_generation import Advice
+from mlia.core.advice_generation import AdviceEvent
+from mlia.core.advice_generation import ContextAwareAdviceProducer
+from mlia.core.context import ExecutionContext
+from mlia.core.data_analysis import ContextAwareDataAnalyzer
+from mlia.core.data_collection import ContextAwareDataCollector
+from mlia.core.errors import FunctionalityNotSupportedError
+from mlia.core.events import AdviceStageFinishedEvent
+from mlia.core.events import AdviceStageStartedEvent
+from mlia.core.events import AnalyzedDataEvent
+from mlia.core.events import CollectedDataEvent
+from mlia.core.events import DataAnalysisStageFinishedEvent
+from mlia.core.events import DataAnalysisStageStartedEvent
+from mlia.core.events import DataCollectionStageFinishedEvent
+from mlia.core.events import DataCollectionStageStartedEvent
+from mlia.core.events import DataCollectorSkippedEvent
+from mlia.core.events import DefaultEventPublisher
+from mlia.core.events import Event
+from mlia.core.events import EventHandler
+from mlia.core.events import ExecutionFailedEvent
+from mlia.core.events import ExecutionFinishedEvent
+from mlia.core.events import ExecutionStartedEvent
+from mlia.core.workflow import DefaultWorkflowExecutor
+
+
+@dataclass
+class SampleEvent(Event):
+ """Sample event."""
+
+ msg: str
+
+
+def test_workflow_executor(tmpdir: str) -> None:
+ """Test workflow executor."""
+ handler_mock = MagicMock(spec=EventHandler)
+ data_collector_mock = MagicMock(spec=ContextAwareDataCollector)
+ data_collector_mock.collect_data.return_value = 42
+
+ data_collector_mock_no_value = MagicMock(spec=ContextAwareDataCollector)
+ data_collector_mock_no_value.collect_data.return_value = None
+
+ data_collector_mock_skipped = MagicMock(spec=ContextAwareDataCollector)
+ data_collector_mock_skipped.name.return_value = "skipped_collector"
+ data_collector_mock_skipped.collect_data.side_effect = (
+ FunctionalityNotSupportedError("Error!", "Error!")
+ )
+
+ data_analyzer_mock = MagicMock(spec=ContextAwareDataAnalyzer)
+ data_analyzer_mock.get_analyzed_data.return_value = ["Really good number!"]
+
+ advice_producer_mock1 = MagicMock(spec=ContextAwareAdviceProducer)
+ advice_producer_mock1.get_advice.return_value = Advice(["All good!"])
+
+ advice_producer_mock2 = MagicMock(spec=ContextAwareAdviceProducer)
+ advice_producer_mock2.get_advice.return_value = [Advice(["Good advice!"])]
+
+ context = ExecutionContext(
+ working_dir=tmpdir,
+ event_handlers=[handler_mock],
+ event_publisher=DefaultEventPublisher(),
+ )
+
+ executor = DefaultWorkflowExecutor(
+ context,
+ [
+ data_collector_mock,
+ data_collector_mock_no_value,
+ data_collector_mock_skipped,
+ ],
+ [data_analyzer_mock],
+ [
+ advice_producer_mock1,
+ advice_producer_mock2,
+ ],
+ [SampleEvent("Hello from advisor!")],
+ )
+
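+    # Running the executor should drive the collection, analysis and advice
+    # stages in order, publishing the corresponding events to the handler.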
+ executor.run()
+
+ data_collector_mock.collect_data.assert_called_once()
+ data_collector_mock_no_value.collect_data.assert_called_once()
+ data_collector_mock_skipped.collect_data.assert_called_once()
+
+ data_analyzer_mock.analyze_data.assert_called_once_with(42)
+
+ advice_producer_mock1.produce_advice.assert_called_once_with("Really good number!")
+ advice_producer_mock1.get_advice.assert_called_once()
+
+    advice_producer_mock2.produce_advice.assert_called_once_with("Really good number!")
+ advice_producer_mock2.get_advice.assert_called_once()
+
+ expected_mock_calls = [
+ call(ExecutionStartedEvent()),
+ call(SampleEvent("Hello from advisor!")),
+ call(DataCollectionStageStartedEvent()),
+ call(CollectedDataEvent(data_item=42)),
+ call(DataCollectorSkippedEvent("skipped_collector", "Error!: Error!")),
+ call(DataCollectionStageFinishedEvent()),
+ call(DataAnalysisStageStartedEvent()),
+ call(AnalyzedDataEvent(data_item="Really good number!")),
+ call(DataAnalysisStageFinishedEvent()),
+ call(AdviceStageStartedEvent()),
+ call(AdviceEvent(advice=Advice(messages=["All good!"]))),
+ call(AdviceEvent(advice=Advice(messages=["Good advice!"]))),
+ call(AdviceStageFinishedEvent()),
+ call(ExecutionFinishedEvent()),
+ ]
+
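+    # Published events get unique ids, so compare them field by field
+    # ignoring the generated event_id.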
+ for expected_call, actual_call in zip(
+ expected_mock_calls, handler_mock.handle_event.mock_calls
+ ):
+ expected_event = expected_call.args[0]
+ actual_event = actual_call.args[0]
+
+ assert actual_event.compare_without_id(expected_event)
+
+
+def test_workflow_executor_failed(tmpdir: str) -> None:
+ """Test scenario when one of the components raises exception."""
+ handler_mock = MagicMock(spec=EventHandler)
+
+ context = ExecutionContext(
+ working_dir=tmpdir,
+ event_handlers=[handler_mock],
+ event_publisher=DefaultEventPublisher(),
+ )
+
+ collection_exception = Exception("Collection failed")
+
+ data_collector_mock = MagicMock(spec=ContextAwareDataCollector)
+ data_collector_mock.collect_data.side_effect = collection_exception
+
+ executor = DefaultWorkflowExecutor(context, [data_collector_mock], [], [])
+ executor.run()
+
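+    # The exception from the collector should be captured and published as an
+    # ExecutionFailedEvent after the data collection stage has started.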
+ expected_mock_calls = [
+ call(ExecutionStartedEvent()),
+ call(DataCollectionStageStartedEvent()),
+ call(ExecutionFailedEvent(collection_exception)),
+ ]
+
+ for expected_call, actual_call in zip(
+ expected_mock_calls, handler_mock.handle_event.mock_calls
+ ):
+ expected_event = expected_call.args[0]
+ actual_event = actual_call.args[0]
+
+ if isinstance(actual_event, ExecutionFailedEvent):
+            # dataclass comparison does not handle the wrapped exception
+            # reliably, so compare the exception objects directly
+ actual_exception = actual_event.err
+ expected_exception = expected_event.err
+
+ assert actual_exception == expected_exception
+ continue
+
+ assert actual_event.compare_without_id(expected_event)
diff --git a/tests/mlia/test_devices_ethosu_advice_generation.py b/tests/mlia/test_devices_ethosu_advice_generation.py
new file mode 100644
index 0000000..98c8a57
--- /dev/null
+++ b/tests/mlia/test_devices_ethosu_advice_generation.py
@@ -0,0 +1,483 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Tests for Ethos-U advice generation."""
+from typing import List
+from typing import Optional
+
+import pytest
+
+from mlia.cli.helpers import CLIActionResolver
+from mlia.core.advice_generation import Advice
+from mlia.core.common import AdviceCategory
+from mlia.core.common import DataItem
+from mlia.core.context import ExecutionContext
+from mlia.core.helpers import ActionResolver
+from mlia.core.helpers import APIActionResolver
+from mlia.devices.ethosu.advice_generation import EthosUAdviceProducer
+from mlia.devices.ethosu.advice_generation import EthosUStaticAdviceProducer
+from mlia.devices.ethosu.data_analysis import AllOperatorsSupportedOnNPU
+from mlia.devices.ethosu.data_analysis import HasCPUOnlyOperators
+from mlia.devices.ethosu.data_analysis import HasUnsupportedOnNPUOperators
+from mlia.devices.ethosu.data_analysis import OptimizationDiff
+from mlia.devices.ethosu.data_analysis import OptimizationResults
+from mlia.devices.ethosu.data_analysis import PerfMetricDiff
+from mlia.nn.tensorflow.optimizations.select import OptimizationSettings
+
+
+@pytest.mark.parametrize(
+ "input_data, advice_category, action_resolver, expected_advice",
+ [
+ [
+ AllOperatorsSupportedOnNPU(),
+ AdviceCategory.OPERATORS,
+ APIActionResolver(),
+ [
+ Advice(
+ [
+ "You don't have any unsupported operators, your model will "
+ "run completely on NPU."
+ ]
+ )
+ ],
+ ],
+ [
+ AllOperatorsSupportedOnNPU(),
+ AdviceCategory.OPERATORS,
+ CLIActionResolver(
+ {
+ "target_profile": "sample_target",
+ "model": "sample_model.tflite",
+ }
+ ),
+ [
+ Advice(
+ [
+ "You don't have any unsupported operators, your model will "
+ "run completely on NPU.",
+ "Check the estimated performance by running the "
+ "following command: ",
+ "mlia performance --target-profile sample_target "
+ "sample_model.tflite",
+ ]
+ )
+ ],
+ ],
+ [
+ HasCPUOnlyOperators(cpu_only_ops=["OP1", "OP2", "OP3"]),
+ AdviceCategory.OPERATORS,
+ APIActionResolver(),
+ [
+ Advice(
+ [
+ "You have at least 3 operators that is CPU only: "
+ "OP1,OP2,OP3.",
+ "Using operators that are supported by the NPU will "
+ "improve performance.",
+ ]
+ )
+ ],
+ ],
+ [
+ HasCPUOnlyOperators(cpu_only_ops=["OP1", "OP2", "OP3"]),
+ AdviceCategory.OPERATORS,
+ CLIActionResolver({}),
+ [
+ Advice(
+ [
+ "You have at least 3 operators that is CPU only: "
+ "OP1,OP2,OP3.",
+ "Using operators that are supported by the NPU will "
+ "improve performance.",
+ "For guidance on supported operators, run: mlia operators "
+ "--supported-ops-report",
+ ]
+ )
+ ],
+ ],
+ [
+ HasUnsupportedOnNPUOperators(npu_unsupported_ratio=0.4),
+ AdviceCategory.OPERATORS,
+ APIActionResolver(),
+ [
+ Advice(
+ [
+ "You have 40% of operators that cannot be placed on the NPU.",
+ "For better performance, please review the reasons reported "
+ "in the table, and adjust the model accordingly "
+ "where possible.",
+ ]
+ )
+ ],
+ ],
+ [
+ HasUnsupportedOnNPUOperators(npu_unsupported_ratio=0.4),
+ AdviceCategory.OPERATORS,
+ CLIActionResolver({}),
+ [
+ Advice(
+ [
+ "You have 40% of operators that cannot be placed on the NPU.",
+ "For better performance, please review the reasons reported "
+ "in the table, and adjust the model accordingly "
+ "where possible.",
+ ]
+ )
+ ],
+ ],
+ [
+ OptimizationResults(
+ [
+ OptimizationDiff(
+ opt_type=[OptimizationSettings("pruning", 0.5, None)],
+ opt_diffs={
+ "sram": PerfMetricDiff(100, 150),
+ "dram": PerfMetricDiff(100, 50),
+ "on_chip_flash": PerfMetricDiff(100, 100),
+ "off_chip_flash": PerfMetricDiff(100, 100),
+ "npu_total_cycles": PerfMetricDiff(10, 5),
+ },
+ ),
+ ]
+ ),
+ AdviceCategory.OPTIMIZATION,
+ APIActionResolver(),
+ [
+ Advice(
+ [
+ "With the selected optimization (pruning: 0.5)",
+ "- You have achieved 50.00% performance improvement in "
+ "DRAM used (KB)",
+ "- You have achieved 50.00% performance improvement in "
+ "NPU total cycles",
+ "- SRAM used (KB) have degraded by 50.00%",
+ "You can try to push the optimization target higher "
+ "(e.g. pruning: 0.6) "
+ "to check if those results can be further improved.",
+ ]
+ ),
+ Advice(
+ [
+ "The applied tooling techniques have an impact "
+ "on accuracy. Additional hyperparameter tuning may be required "
+ "after any optimization."
+ ]
+ ),
+ ],
+ ],
+ [
+ OptimizationResults(
+ [
+ OptimizationDiff(
+ opt_type=[OptimizationSettings("pruning", 0.5, None)],
+ opt_diffs={
+ "sram": PerfMetricDiff(100, 150),
+ "dram": PerfMetricDiff(100, 50),
+ "on_chip_flash": PerfMetricDiff(100, 100),
+ "off_chip_flash": PerfMetricDiff(100, 100),
+ "npu_total_cycles": PerfMetricDiff(10, 5),
+ },
+ ),
+ ]
+ ),
+ AdviceCategory.OPTIMIZATION,
+ CLIActionResolver({"model": "sample_model.h5"}),
+ [
+ Advice(
+ [
+ "With the selected optimization (pruning: 0.5)",
+ "- You have achieved 50.00% performance improvement in "
+ "DRAM used (KB)",
+ "- You have achieved 50.00% performance improvement in "
+ "NPU total cycles",
+ "- SRAM used (KB) have degraded by 50.00%",
+ "You can try to push the optimization target higher "
+ "(e.g. pruning: 0.6) "
+ "to check if those results can be further improved.",
+ "For more info: mlia optimization --help",
+ "Optimization command: "
+ "mlia optimization --optimization-type pruning "
+ "--optimization-target 0.6 sample_model.h5",
+ ]
+ ),
+ Advice(
+ [
+ "The applied tooling techniques have an impact "
+ "on accuracy. Additional hyperparameter tuning may be required "
+ "after any optimization."
+ ]
+ ),
+ ],
+ ],
+ [
+ OptimizationResults(
+ [
+ OptimizationDiff(
+ opt_type=[
+ OptimizationSettings("pruning", 0.5, None),
+ OptimizationSettings("clustering", 32, None),
+ ],
+ opt_diffs={
+ "sram": PerfMetricDiff(100, 150),
+ "dram": PerfMetricDiff(100, 50),
+ "on_chip_flash": PerfMetricDiff(100, 100),
+ "off_chip_flash": PerfMetricDiff(100, 100),
+ "npu_total_cycles": PerfMetricDiff(10, 5),
+ },
+ ),
+ ]
+ ),
+ AdviceCategory.OPTIMIZATION,
+ APIActionResolver(),
+ [
+ Advice(
+ [
+ "With the selected optimization (pruning: 0.5, clustering: 32)",
+ "- You have achieved 50.00% performance improvement in "
+ "DRAM used (KB)",
+ "- You have achieved 50.00% performance improvement in "
+ "NPU total cycles",
+ "- SRAM used (KB) have degraded by 50.00%",
+ "You can try to push the optimization target higher "
+ "(e.g. pruning: 0.6 and/or clustering: 16) "
+ "to check if those results can be further improved.",
+ ]
+ ),
+ Advice(
+ [
+ "The applied tooling techniques have an impact "
+ "on accuracy. Additional hyperparameter tuning may be required "
+ "after any optimization."
+ ]
+ ),
+ ],
+ ],
+ [
+ OptimizationResults(
+ [
+ OptimizationDiff(
+ opt_type=[
+ OptimizationSettings("clustering", 2, None),
+ ],
+ opt_diffs={
+ "sram": PerfMetricDiff(100, 150),
+ "dram": PerfMetricDiff(100, 50),
+ "on_chip_flash": PerfMetricDiff(100, 100),
+ "off_chip_flash": PerfMetricDiff(100, 100),
+ "npu_total_cycles": PerfMetricDiff(10, 5),
+ },
+ ),
+ ]
+ ),
+ AdviceCategory.OPTIMIZATION,
+ APIActionResolver(),
+ [
+ Advice(
+ [
+ "With the selected optimization (clustering: 2)",
+ "- You have achieved 50.00% performance improvement in "
+ "DRAM used (KB)",
+ "- You have achieved 50.00% performance improvement in "
+ "NPU total cycles",
+ "- SRAM used (KB) have degraded by 50.00%",
+ ]
+ ),
+ Advice(
+ [
+ "The applied tooling techniques have an impact "
+ "on accuracy. Additional hyperparameter tuning may be required "
+ "after any optimization."
+ ]
+ ),
+ ],
+ ],
+ [
+ OptimizationResults(
+ [
+ OptimizationDiff(
+ opt_type=[OptimizationSettings("pruning", 0.5, None)],
+ opt_diffs={
+ "sram": PerfMetricDiff(100, 150),
+ "dram": PerfMetricDiff(100, 150),
+ "on_chip_flash": PerfMetricDiff(100, 150),
+ "off_chip_flash": PerfMetricDiff(100, 150),
+ "npu_total_cycles": PerfMetricDiff(10, 100),
+ },
+ ),
+ ]
+ ),
+ AdviceCategory.OPTIMIZATION,
+ APIActionResolver(),
+ [
+ Advice(
+ [
+ "With the selected optimization (pruning: 0.5)",
+ "- DRAM used (KB) have degraded by 50.00%",
+ "- SRAM used (KB) have degraded by 50.00%",
+ "- On chip flash used (KB) have degraded by 50.00%",
+ "- Off chip flash used (KB) have degraded by 50.00%",
+ "- NPU total cycles have degraded by 900.00%",
+ "The performance seems to have degraded after "
+ "applying the selected optimizations, "
+ "try exploring different optimization types/targets.",
+ ]
+ ),
+ Advice(
+ [
+ "The applied tooling techniques have an impact "
+ "on accuracy. Additional hyperparameter tuning may be required "
+ "after any optimization."
+ ]
+ ),
+ ],
+ ],
+ [
+ OptimizationResults(
+ [
+ OptimizationDiff(
+ opt_type=[OptimizationSettings("pruning", 0.5, None)],
+ opt_diffs={
+ "sram": PerfMetricDiff(100, 150),
+ "dram": PerfMetricDiff(100, 150),
+ "on_chip_flash": PerfMetricDiff(100, 150),
+ "off_chip_flash": PerfMetricDiff(100, 150),
+ "npu_total_cycles": PerfMetricDiff(10, 100),
+ },
+ ),
+ OptimizationDiff(
+ opt_type=[OptimizationSettings("pruning", 0.6, None)],
+ opt_diffs={
+ "sram": PerfMetricDiff(100, 150),
+ "dram": PerfMetricDiff(100, 150),
+ "on_chip_flash": PerfMetricDiff(100, 150),
+ "off_chip_flash": PerfMetricDiff(100, 150),
+ "npu_total_cycles": PerfMetricDiff(10, 100),
+ },
+ ),
+ ]
+ ),
+ AdviceCategory.OPTIMIZATION,
+ APIActionResolver(),
+ [], # no advice for more than one optimization result
+ ],
+ ],
+)
+def test_ethosu_advice_producer(
+ tmpdir: str,
+ input_data: DataItem,
+ expected_advice: List[Advice],
+ advice_category: AdviceCategory,
+ action_resolver: ActionResolver,
+) -> None:
+    """Test Ethos-U advice producer."""
+ producer = EthosUAdviceProducer()
+
+ context = ExecutionContext(
+ advice_category=advice_category,
+ working_dir=tmpdir,
+ action_resolver=action_resolver,
+ )
+
+ producer.set_context(context)
+ producer.produce_advice(input_data)
+
+ assert producer.get_advice() == expected_advice
+
+
+@pytest.mark.parametrize(
+ "advice_category, action_resolver, expected_advice",
+ [
+ [
+ None,
+ None,
+ [],
+ ],
+ [
+ AdviceCategory.OPERATORS,
+ None,
+ [],
+ ],
+ [
+ AdviceCategory.PERFORMANCE,
+ APIActionResolver(),
+ [
+ Advice(
+ [
+ "You can improve the inference time by using only operators "
+ "that are supported by the NPU.",
+ ]
+ ),
+ Advice(
+ [
+ "Check if you can improve the performance by applying "
+ "tooling techniques to your model."
+ ]
+ ),
+ ],
+ ],
+ [
+ AdviceCategory.PERFORMANCE,
+ CLIActionResolver({"model": "test_model.h5"}),
+ [
+ Advice(
+ [
+ "You can improve the inference time by using only operators "
+ "that are supported by the NPU.",
+ "Try running the following command to verify that:",
+ "mlia operators test_model.h5",
+ ]
+ ),
+ Advice(
+ [
+ "Check if you can improve the performance by applying "
+ "tooling techniques to your model.",
+ "For example: mlia optimization --optimization-type "
+ "pruning,clustering --optimization-target 0.5,32 "
+ "test_model.h5",
+ "For more info: mlia optimization --help",
+ ]
+ ),
+ ],
+ ],
+ [
+ AdviceCategory.OPTIMIZATION,
+ APIActionResolver(),
+ [
+ Advice(
+ [
+ "For better performance, make sure that all the operators "
+ "of your final TFLite model are supported by the NPU.",
+ ]
+ )
+ ],
+ ],
+ [
+ AdviceCategory.OPTIMIZATION,
+ CLIActionResolver({"model": "test_model.h5"}),
+ [
+ Advice(
+ [
+ "For better performance, make sure that all the operators "
+ "of your final TFLite model are supported by the NPU.",
+ "For more details, run: mlia operators --help",
+ ]
+ )
+ ],
+ ],
+ ],
+)
+def test_ethosu_static_advice_producer(
+ tmpdir: str,
+ advice_category: Optional[AdviceCategory],
+ action_resolver: ActionResolver,
+ expected_advice: List[Advice],
+) -> None:
+ """Test static advice generation."""
+ producer = EthosUStaticAdviceProducer()
+
+ context = ExecutionContext(
+ advice_category=advice_category,
+ working_dir=tmpdir,
+ action_resolver=action_resolver,
+ )
+ producer.set_context(context)
+ assert producer.get_advice() == expected_advice
diff --git a/tests/mlia/test_devices_ethosu_advisor.py b/tests/mlia/test_devices_ethosu_advisor.py
new file mode 100644
index 0000000..74d2408
--- /dev/null
+++ b/tests/mlia/test_devices_ethosu_advisor.py
@@ -0,0 +1,9 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Tests for Ethos-U MLIA module."""
+from mlia.devices.ethosu.advisor import EthosUInferenceAdvisor
+
+
+def test_advisor_metadata() -> None:
+ """Test advisor metadata."""
+ assert EthosUInferenceAdvisor.name() == "ethos_u_inference_advisor"
diff --git a/tests/mlia/test_devices_ethosu_config.py b/tests/mlia/test_devices_ethosu_config.py
new file mode 100644
index 0000000..49c999a
--- /dev/null
+++ b/tests/mlia/test_devices_ethosu_config.py
@@ -0,0 +1,124 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Tests for config module."""
+from contextlib import ExitStack as does_not_raise
+from typing import Any
+from typing import Dict
+from unittest.mock import MagicMock
+
+import pytest
+
+from mlia.devices.ethosu.config import EthosUConfiguration
+from mlia.devices.ethosu.config import get_target
+from mlia.tools.vela_wrapper import VelaCompilerOptions
+from mlia.utils.filesystem import get_vela_config
+
+
+def test_compiler_options_default_init() -> None:
+ """Test compiler options default init."""
+ opts = VelaCompilerOptions()
+
+ assert opts.config_files is None
+ assert opts.system_config == "internal-default"
+ assert opts.memory_mode == "internal-default"
+ assert opts.accelerator_config is None
+ assert opts.max_block_dependency == 3
+ assert opts.arena_cache_size is None
+ assert opts.tensor_allocator == "HillClimb"
+ assert opts.cpu_tensor_alignment == 16
+ assert opts.optimization_strategy == "Performance"
+ assert opts.output_dir is None
+
+
+def test_ethosu_target() -> None:
+ """Test Ethos-U target configuration init."""
+ default_config = EthosUConfiguration("ethos-u55-256")
+
+ assert default_config.target == "ethos-u55"
+ assert default_config.mac == 256
+ assert default_config.compiler_options is not None
+
+
+def test_get_target() -> None:
+ """Test function get_target."""
+ with pytest.raises(Exception, match="No target profile given"):
+ get_target(None) # type: ignore
+
+ with pytest.raises(Exception, match="Unable to find target profile unknown"):
+ get_target("unknown")
+
+ u65_device = get_target("ethos-u65-512")
+
+ assert isinstance(u65_device, EthosUConfiguration)
+ assert u65_device.target == "ethos-u65"
+ assert u65_device.mac == 512
+ assert u65_device.compiler_options.accelerator_config == "ethos-u65-512"
+ assert u65_device.compiler_options.memory_mode == "Dedicated_Sram"
+ assert u65_device.compiler_options.config_files == str(get_vela_config())
+
+
+@pytest.mark.parametrize(
+ "profile_data, expected_error",
+ [
+ [
+ {},
+ pytest.raises(
+ Exception,
+ match="Mandatory fields missing from target profile: "
+ r"\['mac', 'memory_mode', 'system_config', 'target'\]",
+ ),
+ ],
+ [
+ {"target": "ethos-u65", "mac": 512},
+ pytest.raises(
+ Exception,
+ match="Mandatory fields missing from target profile: "
+ r"\['memory_mode', 'system_config'\]",
+ ),
+ ],
+ [
+ {
+ "target": "ethos-u65",
+ "mac": 2,
+ "system_config": "Ethos_U65_Embedded",
+ "memory_mode": "Shared_Sram",
+ },
+ pytest.raises(
+ Exception,
+ match=r"Mac value for selected device should be in \[256, 512\]",
+ ),
+ ],
+ [
+ {
+ "target": "ethos-u55",
+ "mac": 1,
+ "system_config": "Ethos_U55_High_End_Embedded",
+ "memory_mode": "Shared_Sram",
+ },
+ pytest.raises(
+ Exception,
+ match="Mac value for selected device should be "
+ r"in \[32, 64, 128, 256\]",
+ ),
+ ],
+ [
+ {
+ "target": "ethos-u65",
+ "mac": 512,
+ "system_config": "Ethos_U65_Embedded",
+ "memory_mode": "Shared_Sram",
+ },
+ does_not_raise(),
+ ],
+ ],
+)
+def test_ethosu_configuration(
+ monkeypatch: pytest.MonkeyPatch, profile_data: Dict[str, Any], expected_error: Any
+) -> None:
+ """Test creating Ethos-U configuration."""
+ monkeypatch.setattr(
+ "mlia.devices.ethosu.config.get_profile", MagicMock(return_value=profile_data)
+ )
+
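+    # get_profile is mocked so the configuration is built from the profile data above.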
+ with expected_error:
+ EthosUConfiguration("target")
diff --git a/tests/mlia/test_devices_ethosu_data_analysis.py b/tests/mlia/test_devices_ethosu_data_analysis.py
new file mode 100644
index 0000000..4b1d38b
--- /dev/null
+++ b/tests/mlia/test_devices_ethosu_data_analysis.py
@@ -0,0 +1,147 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Tests for Ethos-U data analysis module."""
+from typing import List
+
+import pytest
+
+from mlia.core.common import DataItem
+from mlia.core.data_analysis import Fact
+from mlia.devices.ethosu.config import EthosUConfiguration
+from mlia.devices.ethosu.data_analysis import AllOperatorsSupportedOnNPU
+from mlia.devices.ethosu.data_analysis import EthosUDataAnalyzer
+from mlia.devices.ethosu.data_analysis import HasCPUOnlyOperators
+from mlia.devices.ethosu.data_analysis import HasUnsupportedOnNPUOperators
+from mlia.devices.ethosu.data_analysis import OptimizationDiff
+from mlia.devices.ethosu.data_analysis import OptimizationResults
+from mlia.devices.ethosu.data_analysis import PerfMetricDiff
+from mlia.devices.ethosu.performance import MemoryUsage
+from mlia.devices.ethosu.performance import NPUCycles
+from mlia.devices.ethosu.performance import OptimizationPerformanceMetrics
+from mlia.devices.ethosu.performance import PerformanceMetrics
+from mlia.nn.tensorflow.optimizations.select import OptimizationSettings
+from mlia.tools.vela_wrapper import NpuSupported
+from mlia.tools.vela_wrapper import Operator
+from mlia.tools.vela_wrapper import Operators
+
+
+def test_perf_metrics_diff() -> None:
+ """Test PerfMetricsDiff class."""
+ diff_same = PerfMetricDiff(1, 1)
+ assert diff_same.same is True
+ assert diff_same.improved is False
+ assert diff_same.degraded is False
+ assert diff_same.diff == 0
+
+ diff_improved = PerfMetricDiff(10, 5)
+ assert diff_improved.same is False
+ assert diff_improved.improved is True
+ assert diff_improved.degraded is False
+ assert diff_improved.diff == 50.0
+
+ diff_degraded = PerfMetricDiff(5, 10)
+ assert diff_degraded.same is False
+ assert diff_degraded.improved is False
+ assert diff_degraded.degraded is True
+ assert diff_degraded.diff == -100.0
+
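+ # a zero original value yields a diff of 0, presumably to avoid division by zero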
+ diff_original_zero = PerfMetricDiff(0, 1)
+ assert diff_original_zero.diff == 0
+
+
+@pytest.mark.parametrize(
+ "input_data, expected_facts",
+ [
+ [
+ Operators(
+ [
+ Operator(
+ "CPU operator",
+ "CPU operator type",
+ NpuSupported(False, [("CPU only operator", "")]),
+ )
+ ]
+ ),
+ [
+ HasCPUOnlyOperators(["CPU operator type"]),
+ HasUnsupportedOnNPUOperators(1.0),
+ ],
+ ],
+ [
+ Operators(
+ [
+ Operator(
+ "NPU operator",
+ "NPU operator type",
+ NpuSupported(True, []),
+ )
+ ]
+ ),
+ [
+ AllOperatorsSupportedOnNPU(),
+ ],
+ ],
+ [
+ OptimizationPerformanceMetrics(
+ PerformanceMetrics(
+ EthosUConfiguration("ethos-u55-256"),
+ NPUCycles(1, 2, 3, 4, 5, 6),
+ # memory metrics are in kilobytes
+ MemoryUsage(*[i * 1024 for i in range(1, 6)]), # type: ignore
+ ),
+ [
+ [
+ [
+ OptimizationSettings("pruning", 0.5, None),
+ ],
+ PerformanceMetrics(
+ EthosUConfiguration("ethos-u55-256"),
+ NPUCycles(1, 2, 3, 4, 5, 6),
+ # memory metrics are in kilobytes
+ MemoryUsage(
+ *[i * 1024 for i in range(1, 6)] # type: ignore
+ ),
+ ),
+ ],
+ ],
+ ),
+ [
+ OptimizationResults(
+ [
+ OptimizationDiff(
+ opt_type=[
+ OptimizationSettings("pruning", 0.5, None),
+ ],
+ opt_diffs={
+ "sram": PerfMetricDiff(1.0, 1.0),
+ "dram": PerfMetricDiff(2.0, 2.0),
+ "on_chip_flash": PerfMetricDiff(4.0, 4.0),
+ "off_chip_flash": PerfMetricDiff(5.0, 5.0),
+ "npu_total_cycles": PerfMetricDiff(3, 3),
+ },
+ )
+ ]
+ )
+ ],
+ ],
+ [
+ OptimizationPerformanceMetrics(
+ PerformanceMetrics(
+ EthosUConfiguration("ethos-u55-256"),
+ NPUCycles(1, 2, 3, 4, 5, 6),
+ # memory metrics are in kilobytes
+ MemoryUsage(*[i * 1024 for i in range(1, 6)]), # type: ignore
+ ),
+ [],
+ ),
+ [],
+ ],
+ ],
+)
+def test_ethos_u_data_analyzer(
+ input_data: DataItem, expected_facts: List[Fact]
+) -> None:
+ """Test Ethos-U data analyzer."""
+ analyzer = EthosUDataAnalyzer()
+ analyzer.analyze_data(input_data)
+ assert analyzer.get_analyzed_data() == expected_facts
diff --git a/tests/mlia/test_devices_ethosu_data_collection.py b/tests/mlia/test_devices_ethosu_data_collection.py
new file mode 100644
index 0000000..897cf41
--- /dev/null
+++ b/tests/mlia/test_devices_ethosu_data_collection.py
@@ -0,0 +1,151 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Tests for the data collection module for Ethos-U."""
+from pathlib import Path
+from unittest.mock import MagicMock
+
+import pytest
+
+from mlia.core.context import Context
+from mlia.core.data_collection import DataCollector
+from mlia.core.errors import FunctionalityNotSupportedError
+from mlia.devices.ethosu.config import EthosUConfiguration
+from mlia.devices.ethosu.data_collection import EthosUOperatorCompatibility
+from mlia.devices.ethosu.data_collection import EthosUOptimizationPerformance
+from mlia.devices.ethosu.data_collection import EthosUPerformance
+from mlia.devices.ethosu.performance import MemoryUsage
+from mlia.devices.ethosu.performance import NPUCycles
+from mlia.devices.ethosu.performance import OptimizationPerformanceMetrics
+from mlia.devices.ethosu.performance import PerformanceMetrics
+from mlia.nn.tensorflow.optimizations.select import OptimizationSettings
+from mlia.tools.vela_wrapper import Operators
+
+
+@pytest.mark.parametrize(
+ "collector, expected_name",
+ [
+ (
+ EthosUOperatorCompatibility,
+ "ethos_u_operator_compatibility",
+ ),
+ (
+ EthosUPerformance,
+ "ethos_u_performance",
+ ),
+ (
+ EthosUOptimizationPerformance,
+ "ethos_u_model_optimizations",
+ ),
+ ],
+)
+def test_collectors_metadata(
+ collector: DataCollector,
+ expected_name: str,
+) -> None:
+ """Test collectors metadata."""
+ assert collector.name() == expected_name
+
+
+def test_operator_compatibility_collector(
+ dummy_context: Context, test_tflite_model: Path
+) -> None:
+ """Test operator compatibility data collector."""
+ device = EthosUConfiguration("ethos-u55-256")
+
+ collector = EthosUOperatorCompatibility(test_tflite_model, device)
+ collector.set_context(dummy_context)
+
+ result = collector.collect_data()
+ assert isinstance(result, Operators)
+
+
+def test_performance_collector(
+ monkeypatch: pytest.MonkeyPatch, dummy_context: Context, test_tflite_model: Path
+) -> None:
+ """Test performance data collector."""
+ device = EthosUConfiguration("ethos-u55-256")
+
+ mock_performance_estimation(monkeypatch, device)
+
+ collector = EthosUPerformance(test_tflite_model, device)
+ collector.set_context(dummy_context)
+
+ result = collector.collect_data()
+ assert isinstance(result, PerformanceMetrics)
+
+
+def test_optimization_performance_collector(
+ monkeypatch: pytest.MonkeyPatch,
+ dummy_context: Context,
+ test_keras_model: Path,
+ test_tflite_model: Path,
+) -> None:
+ """Test optimization performance data collector."""
+ device = EthosUConfiguration("ethos-u55-256")
+
+ mock_performance_estimation(monkeypatch, device)
+ collector = EthosUOptimizationPerformance(
+ test_keras_model,
+ device,
+ [
+ [
+ {"optimization_type": "pruning", "optimization_target": 0.5},
+ ]
+ ],
+ )
+ collector.set_context(dummy_context)
+ result = collector.collect_data()
+
+ assert isinstance(result, OptimizationPerformanceMetrics)
+ assert isinstance(result.original_perf_metrics, PerformanceMetrics)
+ assert isinstance(result.optimizations_perf_metrics, list)
+ assert len(result.optimizations_perf_metrics) == 1
+
+ opt, metrics = result.optimizations_perf_metrics[0]
+ assert opt == [OptimizationSettings("pruning", 0.5, None)]
+ assert isinstance(metrics, PerformanceMetrics)
+
+ collector_no_optimizations = EthosUOptimizationPerformance(
+ test_keras_model,
+ device,
+ [],
+ )
+ with pytest.raises(FunctionalityNotSupportedError):
+ collector_no_optimizations.collect_data()
+
+ collector_tflite = EthosUOptimizationPerformance(
+ test_tflite_model,
+ device,
+ [
+ [
+ {"optimization_type": "pruning", "optimization_target": 0.5},
+ ]
+ ],
+ )
+ collector_tflite.set_context(dummy_context)
+ with pytest.raises(FunctionalityNotSupportedError):
+ collector_tflite.collect_data()
+
+ with pytest.raises(
+ Exception, match="Optimization parameters expected to be a list"
+ ):
+ collector_bad_config = EthosUOptimizationPerformance(
+ test_keras_model, device, {"optimization_type": "pruning"} # type: ignore
+ )
+ collector_bad_config.set_context(dummy_context)
+ collector_bad_config.collect_data()
+
+
+def mock_performance_estimation(
+ monkeypatch: pytest.MonkeyPatch, device: EthosUConfiguration
+) -> None:
+ """Mock performance estimation."""
+ metrics = PerformanceMetrics(
+ device,
+ NPUCycles(1, 2, 3, 4, 5, 6),
+ MemoryUsage(1, 2, 3, 4, 5),
+ )
+ monkeypatch.setattr(
+ "mlia.devices.ethosu.data_collection.EthosUPerformanceEstimator.estimate",
+ MagicMock(return_value=metrics),
+ )
diff --git a/tests/mlia/test_devices_ethosu_performance.py b/tests/mlia/test_devices_ethosu_performance.py
new file mode 100644
index 0000000..e27efa0
--- /dev/null
+++ b/tests/mlia/test_devices_ethosu_performance.py
@@ -0,0 +1,28 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Performance estimation tests."""
+from unittest.mock import MagicMock
+
+import pytest
+
+from mlia.devices.ethosu.performance import MemorySizeType
+from mlia.devices.ethosu.performance import MemoryUsage
+
+
+def test_memory_usage_conversion() -> None:
+ """Test MemoryUsage objects conversion."""
+ memory_usage_in_kb = MemoryUsage(1, 2, 3, 4, 5, MemorySizeType.KILOBYTES)
+ assert memory_usage_in_kb.in_kilobytes() == memory_usage_in_kb
+
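+ # without an explicit MemorySizeType the values appear to be treated as bytes,
+ # so in_kilobytes() is expected to divide them by 1024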
+ memory_usage_in_bytes = MemoryUsage(
+ 1 * 1024, 2 * 1024, 3 * 1024, 4 * 1024, 5 * 1024
+ )
+ assert memory_usage_in_bytes.in_kilobytes() == memory_usage_in_kb
+
+
+def mock_performance_estimation(monkeypatch: pytest.MonkeyPatch) -> None:
+ """Mock performance estimation."""
+ monkeypatch.setattr(
+ "mlia.tools.aiet_wrapper.estimate_performance",
+ MagicMock(return_value=MagicMock()),
+ )
diff --git a/tests/mlia/test_devices_ethosu_reporters.py b/tests/mlia/test_devices_ethosu_reporters.py
new file mode 100644
index 0000000..2d5905c
--- /dev/null
+++ b/tests/mlia/test_devices_ethosu_reporters.py
@@ -0,0 +1,434 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Tests for reports module."""
+import json
+import sys
+from contextlib import ExitStack as doesnt_raise
+from pathlib import Path
+from typing import Any
+from typing import Callable
+from typing import Dict
+from typing import List
+from typing import Literal
+
+import pytest
+
+from mlia.core.reporting import get_reporter
+from mlia.core.reporting import produce_report
+from mlia.core.reporting import Report
+from mlia.core.reporting import Reporter
+from mlia.core.reporting import Table
+from mlia.devices.ethosu.config import EthosUConfiguration
+from mlia.devices.ethosu.performance import MemoryUsage
+from mlia.devices.ethosu.performance import NPUCycles
+from mlia.devices.ethosu.performance import PerformanceMetrics
+from mlia.devices.ethosu.reporters import find_appropriate_formatter
+from mlia.devices.ethosu.reporters import report_device_details
+from mlia.devices.ethosu.reporters import report_operators
+from mlia.devices.ethosu.reporters import report_perf_metrics
+from mlia.tools.vela_wrapper import NpuSupported
+from mlia.tools.vela_wrapper import Operator
+from mlia.tools.vela_wrapper import Operators
+from mlia.utils.console import remove_ascii_codes
+
+
+@pytest.mark.parametrize(
+ "data, formatters",
+ [
+ (
+ [Operator("test_operator", "test_type", NpuSupported(False, []))],
+ [report_operators],
+ ),
+ (
+ PerformanceMetrics(
+ EthosUConfiguration("ethos-u55-256"),
+ NPUCycles(0, 0, 0, 0, 0, 0),
+ MemoryUsage(0, 0, 0, 0, 0),
+ ),
+ [report_perf_metrics],
+ ),
+ ],
+)
+@pytest.mark.parametrize(
+ "fmt, output, expected_error",
+ [
+ [
+ "unknown_format",
+ sys.stdout,
+ pytest.raises(Exception, match="Unknown format unknown_format"),
+ ],
+ [
+ "plain_text",
+ sys.stdout,
+ doesnt_raise(),
+ ],
+ [
+ "json",
+ sys.stdout,
+ doesnt_raise(),
+ ],
+ [
+ "csv",
+ sys.stdout,
+ doesnt_raise(),
+ ],
+ [
+ "plain_text",
+ "report.txt",
+ doesnt_raise(),
+ ],
+ [
+ "json",
+ "report.json",
+ doesnt_raise(),
+ ],
+ [
+ "csv",
+ "report.csv",
+ doesnt_raise(),
+ ],
+ ],
+)
+def test_report(
+ data: Any,
+ formatters: List[Callable],
+ fmt: Literal["plain_text", "json", "csv"],
+ output: Any,
+ expected_error: Any,
+ tmp_path: Path,
+) -> None:
+ """Test report function."""
+ if is_file := isinstance(output, str):
+ output = tmp_path / output
+
+ for formatter in formatters:
+ with expected_error:
+ produce_report(data, formatter, fmt, output)
+
+ if is_file:
+ assert output.is_file()
+ assert output.stat().st_size > 0
+
+
+@pytest.mark.parametrize(
+ "ops, expected_plain_text, expected_json_dict, expected_csv_list",
+ [
+ (
+ [
+ Operator(
+ "npu_supported",
+ "test_type",
+ NpuSupported(True, []),
+ ),
+ Operator(
+ "cpu_only",
+ "test_type",
+ NpuSupported(
+ False,
+ [
+ (
+ "CPU only operator",
+ "",
+ ),
+ ],
+ ),
+ ),
+ Operator(
+ "npu_unsupported",
+ "test_type",
+ NpuSupported(
+ False,
+ [
+ (
+ "Not supported operator",
+ "Reason why operator is not supported",
+ )
+ ],
+ ),
+ ),
+ ],
+ """
+Operators:
+┌───┬─────────────────┬───────────────┬───────────┬───────────────────────────────┐
+│ # │ Operator name │ Operator type │ Placement │ Notes │
+╞═══╪═════════════════╪═══════════════╪═══════════╪═══════════════════════════════╡
+│ 1 │ npu_supported │ test_type │ NPU │ │
+├───┼─────────────────┼───────────────┼───────────┼───────────────────────────────┤
+│ 2 │ cpu_only │ test_type │ CPU │ * CPU only operator │
+├───┼─────────────────┼───────────────┼───────────┼───────────────────────────────┤
+│ 3 │ npu_unsupported │ test_type │ CPU │ * Not supported operator │
+│ │ │ │ │ │
+│ │ │ │ │ * Reason why operator is not │
+│ │ │ │ │ supported │
+└───┴─────────────────┴───────────────┴───────────┴───────────────────────────────┘
+""".strip(),
+ {
+ "operators": [
+ {
+ "operator_name": "npu_supported",
+ "operator_type": "test_type",
+ "placement": "NPU",
+ "notes": [],
+ },
+ {
+ "operator_name": "cpu_only",
+ "operator_type": "test_type",
+ "placement": "CPU",
+ "notes": [{"note": "CPU only operator"}],
+ },
+ {
+ "operator_name": "npu_unsupported",
+ "operator_type": "test_type",
+ "placement": "CPU",
+ "notes": [
+ {"note": "Not supported operator"},
+ {"note": "Reason why operator is not supported"},
+ ],
+ },
+ ]
+ },
+ [
+ ["Operator name", "Operator type", "Placement", "Notes"],
+ ["npu_supported", "test_type", "NPU", ""],
+ ["cpu_only", "test_type", "CPU", "CPU only operator"],
+ [
+ "npu_unsupported",
+ "test_type",
+ "CPU",
+ "Not supported operator;Reason why operator is not supported",
+ ],
+ ],
+ ),
+ ],
+)
+def test_report_operators(
+ ops: List[Operator],
+ expected_plain_text: str,
+ expected_json_dict: Dict,
+ expected_csv_list: List,
+ monkeypatch: pytest.MonkeyPatch,
+) -> None:
+ """Test report_operatos formatter."""
+ # make terminal wide enough to print whole table
+ monkeypatch.setenv("COLUMNS", "100")
+
+ report = report_operators(ops)
+ assert isinstance(report, Table)
+
+ plain_text = remove_ascii_codes(report.to_plain_text())
+ assert plain_text == expected_plain_text
+
+ json_dict = report.to_json()
+ assert json_dict == expected_json_dict
+
+ csv_list = report.to_csv()
+ assert csv_list == expected_csv_list
+
+
+@pytest.mark.parametrize(
+ "device, expected_plain_text, expected_json_dict, expected_csv_list",
+ [
+ [
+ EthosUConfiguration("ethos-u55-256"),
+ """Device information:
+ Target ethos-u55
+ MAC 256
+
+ Memory mode Shared_Sram
+ Const mem area Axi1
+ Arena mem area Axi0
+ Cache mem area Axi0
+ Arena cache size 4,294,967,296 bytes
+
+ System config Ethos_U55_High_End_Embedded
+ Accelerator clock 500,000,000 Hz
+ AXI0 port Sram
+ AXI1 port OffChipFlash
+
+ Memory area settings:
+ Sram:
+ Clock scales 1.0
+ Burst length 32 bytes
+ Read latency 32 cycles
+ Write latency 32 cycles
+
+ Dram:
+ Clock scales 1.0
+ Burst length 1 byte
+ Read latency 0 cycles
+ Write latency 0 cycles
+
+ OnChipFlash:
+ Clock scales 1.0
+ Burst length 1 byte
+ Read latency 0 cycles
+ Write latency 0 cycles
+
+ OffChipFlash:
+ Clock scales 0.125
+ Burst length 128 bytes
+ Read latency 64 cycles
+ Write latency 64 cycles
+
+ Architecture settings:
+ Permanent storage mem area OffChipFlash
+ Feature map storage mem area Sram
+ Fast storage mem area Sram""",
+ {
+ "device": {
+ "target": "ethos-u55",
+ "mac": 256,
+ "memory_mode": {
+ "const_mem_area": "Axi1",
+ "arena_mem_area": "Axi0",
+ "cache_mem_area": "Axi0",
+ "arena_cache_size": {"value": 4294967296, "unit": "bytes"},
+ },
+ "system_config": {
+ "accelerator_clock": {"value": 500000000.0, "unit": "Hz"},
+ "axi0_port": "Sram",
+ "axi1_port": "OffChipFlash",
+ "memory_area": {
+ "Sram": {
+ "clock_scales": 1.0,
+ "burst_length": {"value": 32, "unit": "bytes"},
+ "read_latency": {"value": 32, "unit": "cycles"},
+ "write_latency": {"value": 32, "unit": "cycles"},
+ },
+ "Dram": {
+ "clock_scales": 1.0,
+ "burst_length": {"value": 1, "unit": "byte"},
+ "read_latency": {"value": 0, "unit": "cycles"},
+ "write_latency": {"value": 0, "unit": "cycles"},
+ },
+ "OnChipFlash": {
+ "clock_scales": 1.0,
+ "burst_length": {"value": 1, "unit": "byte"},
+ "read_latency": {"value": 0, "unit": "cycles"},
+ "write_latency": {"value": 0, "unit": "cycles"},
+ },
+ "OffChipFlash": {
+ "clock_scales": 0.125,
+ "burst_length": {"value": 128, "unit": "bytes"},
+ "read_latency": {"value": 64, "unit": "cycles"},
+ "write_latency": {"value": 64, "unit": "cycles"},
+ },
+ },
+ },
+ "arch_settings": {
+ "permanent_storage_mem_area": "OffChipFlash",
+ "feature_map_storage_mem_area": "Sram",
+ "fast_storage_mem_area": "Sram",
+ },
+ }
+ },
+ [
+ (
+ "target",
+ "mac",
+ "memory_mode",
+ "const_mem_area",
+ "arena_mem_area",
+ "cache_mem_area",
+ "arena_cache_size_value",
+ "arena_cache_size_unit",
+ "system_config",
+ "accelerator_clock_value",
+ "accelerator_clock_unit",
+ "axi0_port",
+ "axi1_port",
+ "clock_scales",
+ "burst_length_value",
+ "burst_length_unit",
+ "read_latency_value",
+ "read_latency_unit",
+ "write_latency_value",
+ "write_latency_unit",
+ "permanent_storage_mem_area",
+ "feature_map_storage_mem_area",
+ "fast_storage_mem_area",
+ ),
+ (
+ "ethos-u55",
+ 256,
+ "Shared_Sram",
+ "Axi1",
+ "Axi0",
+ "Axi0",
+ 4294967296,
+ "bytes",
+ "Ethos_U55_High_End_Embedded",
+ 500000000.0,
+ "Hz",
+ "Sram",
+ "OffChipFlash",
+ 0.125,
+ 128,
+ "bytes",
+ 64,
+ "cycles",
+ 64,
+ "cycles",
+ "OffChipFlash",
+ "Sram",
+ "Sram",
+ ),
+ ],
+ ],
+ ],
+)
+def test_report_device_details(
+ device: EthosUConfiguration,
+ expected_plain_text: str,
+ expected_json_dict: Dict,
+ expected_csv_list: List,
+) -> None:
+ """Test report_operatos formatter."""
+ report = report_device_details(device)
+ assert isinstance(report, Report)
+
+ plain_text = report.to_plain_text()
+ assert plain_text == expected_plain_text
+
+ json_dict = report.to_json()
+ assert json_dict == expected_json_dict
+
+ csv_list = report.to_csv()
+ assert csv_list == expected_csv_list
+
+
+def test_get_reporter(tmp_path: Path) -> None:
+ """Test reporter functionality."""
+ ops = Operators(
+ [
+ Operator(
+ "npu_supported",
+ "op_type",
+ NpuSupported(True, []),
+ ),
+ ]
+ )
+
+ output = tmp_path / "output.json"
+ with get_reporter("json", output, find_appropriate_formatter) as reporter:
+ assert isinstance(reporter, Reporter)
+
+ with pytest.raises(
+ Exception, match="Unable to find appropriate formatter for some_data"
+ ):
+ reporter.submit("some_data")
+
+ reporter.submit(ops)
+
+ with open(output, encoding="utf-8") as file:
+ json_data = json.load(file)
+
+ assert json_data == {
+ "operators_stats": [
+ {
+ "npu_unsupported_ratio": 0.0,
+ "num_of_npu_supported_operators": 1,
+ "num_of_operators": 1,
+ }
+ ]
+ }
diff --git a/tests/mlia/test_nn_tensorflow_config.py b/tests/mlia/test_nn_tensorflow_config.py
new file mode 100644
index 0000000..1ac9f97
--- /dev/null
+++ b/tests/mlia/test_nn_tensorflow_config.py
@@ -0,0 +1,72 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Tests for config module."""
+from contextlib import ExitStack as does_not_raise
+from pathlib import Path
+from typing import Any
+
+import pytest
+
+from mlia.nn.tensorflow.config import get_model
+from mlia.nn.tensorflow.config import KerasModel
+from mlia.nn.tensorflow.config import TFLiteModel
+from mlia.nn.tensorflow.config import TfModel
+
+
+def test_convert_keras_to_tflite(tmp_path: Path, test_keras_model: Path) -> None:
+ """Test Keras to TFLite conversion."""
+ keras_model = KerasModel(test_keras_model)
+
+ tflite_model_path = tmp_path / "test.tflite"
+ keras_model.convert_to_tflite(tflite_model_path)
+
+ assert tflite_model_path.is_file()
+ assert tflite_model_path.stat().st_size > 0
+
+
+def test_convert_tf_to_tflite(tmp_path: Path, test_tf_model: Path) -> None:
+ """Test TensorFlow saved model to TFLite conversion."""
+ tf_model = TfModel(test_tf_model)
+
+ tflite_model_path = tmp_path / "test.tflite"
+ tf_model.convert_to_tflite(tflite_model_path)
+
+ assert tflite_model_path.is_file()
+ assert tflite_model_path.stat().st_size > 0
+
+
+@pytest.mark.parametrize(
+ "model_path, expected_type, expected_error",
+ [
+ ("test.tflite", TFLiteModel, does_not_raise()),
+ ("test.h5", KerasModel, does_not_raise()),
+ ("test.hdf5", KerasModel, does_not_raise()),
+ (
+ "test.model",
+ None,
+ pytest.raises(
+ Exception,
+ match="The input model format is not supported"
+ r"\(supported formats: TFLite, Keras, TensorFlow saved model\)!",
+ ),
+ ),
+ ],
+)
+def test_get_model_file(
+ model_path: str, expected_type: type, expected_error: Any
+) -> None:
+ """Test TFLite model type."""
+ with expected_error:
+ model = get_model(model_path)
+ assert isinstance(model, expected_type)
+
+
+@pytest.mark.parametrize(
+ "model_path, expected_type", [("tf_model_test_model", TfModel)]
+)
+def test_get_model_dir(
+ test_models_path: Path, model_path: str, expected_type: type
+) -> None:
+ """Test TFLite model type."""
+ model = get_model(str(test_models_path / model_path))
+ assert isinstance(model, expected_type)
diff --git a/tests/mlia/test_nn_tensorflow_optimizations_clustering.py b/tests/mlia/test_nn_tensorflow_optimizations_clustering.py
new file mode 100644
index 0000000..9bcf918
--- /dev/null
+++ b/tests/mlia/test_nn_tensorflow_optimizations_clustering.py
@@ -0,0 +1,131 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Test for module optimizations/clustering."""
+from pathlib import Path
+from typing import List
+from typing import Optional
+
+import pytest
+import tensorflow as tf
+
+from mlia.nn.tensorflow.optimizations.clustering import Clusterer
+from mlia.nn.tensorflow.optimizations.clustering import ClusteringConfiguration
+from mlia.nn.tensorflow.optimizations.pruning import Pruner
+from mlia.nn.tensorflow.optimizations.pruning import PruningConfiguration
+from mlia.nn.tensorflow.tflite_metrics import ReportClusterMode
+from mlia.nn.tensorflow.tflite_metrics import TFLiteMetrics
+from mlia.nn.tensorflow.utils import convert_to_tflite
+from mlia.nn.tensorflow.utils import save_tflite_model
+from tests.mlia.utils.common import get_dataset
+from tests.mlia.utils.common import train_model
+
+
+def _prune_model(
+ model: tf.keras.Model, target_sparsity: float, layers_to_prune: Optional[List[str]]
+) -> tf.keras.Model:
+ x_train, y_train = get_dataset()
+ batch_size = 1
+ num_epochs = 1
+
+ pruner = Pruner(
+ model,
+ PruningConfiguration(
+ target_sparsity,
+ layers_to_prune,
+ x_train,
+ y_train,
+ batch_size,
+ num_epochs,
+ ),
+ )
+ pruner.apply_optimization()
+ pruned_model = pruner.get_model()
+
+ return pruned_model
+
+
+def _test_num_unique_weights(
+ metrics: TFLiteMetrics,
+ target_num_clusters: int,
+ layers_to_cluster: Optional[List[str]],
+) -> None:
+ clustered_uniqueness_dict = metrics.num_unique_weights(
+ ReportClusterMode.NUM_CLUSTERS_PER_AXIS
+ )
+ num_clustered_layers = 0
+ num_optimizable_layers = len(clustered_uniqueness_dict)
+ if layers_to_cluster:
+ expected_num_clustered_layers = len(layers_to_cluster)
+ else:
+ expected_num_clustered_layers = num_optimizable_layers
+ for layer_name in clustered_uniqueness_dict:
+ # The +1 is a temporary workaround for a bug that has already been fixed
+ # but not merged yet; remove it once the fix is available.
+ if clustered_uniqueness_dict[layer_name][0] <= (target_num_clusters + 1):
+ num_clustered_layers = num_clustered_layers + 1
+ # make sure there are exactly as many clustered layers as expected
+ assert num_clustered_layers == expected_num_clustered_layers
+
+
+def _test_sparsity(
+ metrics: TFLiteMetrics,
+ target_sparsity: float,
+ layers_to_cluster: Optional[List[str]],
+) -> None:
+ pruned_sparsity_dict = metrics.sparsity_per_layer()
+ num_sparse_layers = 0
+ num_optimizable_layers = len(pruned_sparsity_dict)
+ error_margin = 0.03
+ if layers_to_cluster:
+ expected_num_sparse_layers = len(layers_to_cluster)
+ else:
+ expected_num_sparse_layers = num_optimizable_layers
+ for layer_name in pruned_sparsity_dict:
+ if abs(pruned_sparsity_dict[layer_name] - target_sparsity) < error_margin:
+ num_sparse_layers = num_sparse_layers + 1
+ # make sure there are exactly as many sparse layers as expected
+ assert num_sparse_layers == expected_num_sparse_layers
+
+
+@pytest.mark.skip(reason="Test fails randomly, further investigation is needed")
+@pytest.mark.parametrize("target_num_clusters", (32, 4))
+@pytest.mark.parametrize("sparsity_aware", (False, True))
+@pytest.mark.parametrize("layers_to_cluster", (["conv1"], ["conv1", "conv2"], None))
+def test_cluster_simple_model_fully(
+ target_num_clusters: int,
+ sparsity_aware: bool,
+ layers_to_cluster: Optional[List[str]],
+ tmp_path: Path,
+ test_keras_model: Path,
+) -> None:
+ """Simple MNIST test to see if clustering works correctly."""
+ target_sparsity = 0.5
+
+ base_model = tf.keras.models.load_model(str(test_keras_model))
+ train_model(base_model)
+
+ if sparsity_aware:
+ base_model = _prune_model(base_model, target_sparsity, layers_to_cluster)
+
+ clusterer = Clusterer(
+ base_model,
+ ClusteringConfiguration(
+ target_num_clusters,
+ layers_to_cluster,
+ ),
+ )
+ clusterer.apply_optimization()
+ clustered_model = clusterer.get_model()
+
+ temp_file = tmp_path / "test_cluster_simple_model_fully_after.tflite"
+ tflite_clustered_model = convert_to_tflite(clustered_model)
+ save_tflite_model(tflite_clustered_model, temp_file)
+ clustered_tflite_metrics = TFLiteMetrics(str(temp_file))
+
+ _test_num_unique_weights(
+ clustered_tflite_metrics, target_num_clusters, layers_to_cluster
+ )
+
+ if sparsity_aware:
+ _test_sparsity(clustered_tflite_metrics, target_sparsity, layers_to_cluster)
diff --git a/tests/mlia/test_nn_tensorflow_optimizations_pruning.py b/tests/mlia/test_nn_tensorflow_optimizations_pruning.py
new file mode 100644
index 0000000..64030a6
--- /dev/null
+++ b/tests/mlia/test_nn_tensorflow_optimizations_pruning.py
@@ -0,0 +1,117 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Test for module optimizations/pruning."""
+from pathlib import Path
+from typing import List
+from typing import Optional
+
+import pytest
+import tensorflow as tf
+from numpy.core.numeric import isclose
+
+from mlia.nn.tensorflow.optimizations.pruning import Pruner
+from mlia.nn.tensorflow.optimizations.pruning import PruningConfiguration
+from mlia.nn.tensorflow.tflite_metrics import TFLiteMetrics
+from mlia.nn.tensorflow.utils import convert_to_tflite
+from mlia.nn.tensorflow.utils import save_tflite_model
+from tests.mlia.utils.common import get_dataset
+from tests.mlia.utils.common import train_model
+
+
+def _test_sparsity(
+ metrics: TFLiteMetrics,
+ target_sparsity: float,
+ layers_to_prune: Optional[List[str]],
+) -> None:
+ pruned_sparsity_dict = metrics.sparsity_per_layer()
+ num_sparse_layers = 0
+ num_optimizable_layers = len(pruned_sparsity_dict)
+ error_margin = 0.03
+ if layers_to_prune:
+ expected_num_sparse_layers = len(layers_to_prune)
+ else:
+ expected_num_sparse_layers = num_optimizable_layers
+ for layer_name in pruned_sparsity_dict:
+ if abs(pruned_sparsity_dict[layer_name] - target_sparsity) < error_margin:
+ num_sparse_layers = num_sparse_layers + 1
+ # make sure there are exactly as many sparse layers as expected
+ assert num_sparse_layers == expected_num_sparse_layers
+
+
+def _test_check_sparsity(base_tflite_metrics: TFLiteMetrics) -> None:
+ """Assert the sparsity of a model is zero."""
+ base_sparsity_dict = base_tflite_metrics.sparsity_per_layer()
+ for layer_name, sparsity in base_sparsity_dict.items():
+ assert isclose(
+ sparsity, 0, atol=1e-2
+ ), f"Sparsity for layer '{layer_name}' is {sparsity}, but should be zero."
+
+
+def _get_tflite_metrics(
+ path: Path, tflite_fn: str, model: tf.keras.Model
+) -> TFLiteMetrics:
+ """Save model as TFLiteModel and return metrics."""
+ temp_file = path / tflite_fn
+ save_tflite_model(convert_to_tflite(model), temp_file)
+ return TFLiteMetrics(str(temp_file))
+
+
+@pytest.mark.parametrize("target_sparsity", (0.5, 0.9))
+@pytest.mark.parametrize("mock_data", (False, True))
+@pytest.mark.parametrize("layers_to_prune", (["conv1"], ["conv1", "conv2"], None))
+def test_prune_simple_model_fully(
+ target_sparsity: float,
+ mock_data: bool,
+ layers_to_prune: Optional[List[str]],
+ tmp_path: Path,
+ test_keras_model: Path,
+) -> None:
+ """Simple MNIST test to see if pruning works correctly."""
+ x_train, y_train = get_dataset()
+ batch_size = 1
+ num_epochs = 1
+
+ base_model = tf.keras.models.load_model(str(test_keras_model))
+ train_model(base_model)
+
+ base_tflite_metrics = _get_tflite_metrics(
+ path=tmp_path,
+ tflite_fn="test_prune_simple_model_fully_before.tflite",
+ model=base_model,
+ )
+
+ # Make sure sparsity is zero before pruning
+ _test_check_sparsity(base_tflite_metrics)
+
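+ # with mock_data the PruningConfiguration is created without a training dataset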
+ if mock_data:
+ pruner = Pruner(
+ base_model,
+ PruningConfiguration(
+ target_sparsity,
+ layers_to_prune,
+ ),
+ )
+
+ else:
+ pruner = Pruner(
+ base_model,
+ PruningConfiguration(
+ target_sparsity,
+ layers_to_prune,
+ x_train,
+ y_train,
+ batch_size,
+ num_epochs,
+ ),
+ )
+
+ pruner.apply_optimization()
+ pruned_model = pruner.get_model()
+
+ pruned_tflite_metrics = _get_tflite_metrics(
+ path=tmp_path,
+ tflite_fn="test_prune_simple_model_fully_after.tflite",
+ model=pruned_model,
+ )
+
+ _test_sparsity(pruned_tflite_metrics, target_sparsity, layers_to_prune)
diff --git a/tests/mlia/test_nn_tensorflow_optimizations_select.py b/tests/mlia/test_nn_tensorflow_optimizations_select.py
new file mode 100644
index 0000000..5cac8ba
--- /dev/null
+++ b/tests/mlia/test_nn_tensorflow_optimizations_select.py
@@ -0,0 +1,240 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Tests for module select."""
+from contextlib import ExitStack as does_not_raise
+from pathlib import Path
+from typing import Any
+from typing import List
+from typing import Tuple
+
+import pytest
+import tensorflow as tf
+
+from mlia.nn.tensorflow.optimizations.clustering import Clusterer
+from mlia.nn.tensorflow.optimizations.clustering import ClusteringConfiguration
+from mlia.nn.tensorflow.optimizations.pruning import Pruner
+from mlia.nn.tensorflow.optimizations.pruning import PruningConfiguration
+from mlia.nn.tensorflow.optimizations.select import get_optimizer
+from mlia.nn.tensorflow.optimizations.select import MultiStageOptimizer
+from mlia.nn.tensorflow.optimizations.select import OptimizationSettings
+
+
+@pytest.mark.parametrize(
+ "config, expected_error, expected_type, expected_config",
+ [
+ (
+ OptimizationSettings(
+ optimization_type="pruning",
+ optimization_target=0.5,
+ layers_to_optimize=None,
+ ),
+ does_not_raise(),
+ Pruner,
+ "pruning: 0.5",
+ ),
+ (
+ PruningConfiguration(0.5),
+ does_not_raise(),
+ Pruner,
+ "pruning: 0.5",
+ ),
+ (
+ OptimizationSettings(
+ optimization_type="clustering",
+ optimization_target=32,
+ layers_to_optimize=None,
+ ),
+ does_not_raise(),
+ Clusterer,
+ "clustering: 32",
+ ),
+ (
+ OptimizationSettings(
+ optimization_type="clustering",
+ optimization_target=0.5,
+ layers_to_optimize=None,
+ ),
+ pytest.raises(
+ Exception,
+ match="Optimization target should be a "
+ "positive integer. "
+ "Optimization target provided: 0.5",
+ ),
+ None,
+ None,
+ ),
+ (
+ ClusteringConfiguration(32),
+ does_not_raise(),
+ Clusterer,
+ "clustering: 32",
+ ),
+ (
+ OptimizationSettings(
+ optimization_type="superoptimization",
+ optimization_target="supertarget", # type: ignore
+ layers_to_optimize="all", # type: ignore
+ ),
+ pytest.raises(
+ Exception,
+ match="Unsupported optimization type: superoptimization",
+ ),
+ None,
+ None,
+ ),
+ (
+ OptimizationSettings(
+ optimization_type="",
+ optimization_target=0.5,
+ layers_to_optimize=None,
+ ),
+ pytest.raises(
+ Exception,
+ match="Optimization type is not provided",
+ ),
+ None,
+ None,
+ ),
+ (
+ "wrong_config",
+ pytest.raises(
+ Exception,
+ match="Unknown optimization configuration wrong_config",
+ ),
+ None,
+ None,
+ ),
+ (
+ OptimizationSettings(
+ optimization_type="pruning",
+ optimization_target=None, # type: ignore
+ layers_to_optimize=None,
+ ),
+ pytest.raises(
+ Exception,
+ match="Optimization target is not provided",
+ ),
+ None,
+ None,
+ ),
+ (
+ [
+ OptimizationSettings(
+ optimization_type="pruning",
+ optimization_target=0.5,
+ layers_to_optimize=None,
+ ),
+ OptimizationSettings(
+ optimization_type="clustering",
+ optimization_target=32,
+ layers_to_optimize=None,
+ ),
+ ],
+ does_not_raise(),
+ MultiStageOptimizer,
+ "pruning: 0.5 - clustering: 32",
+ ),
+ ],
+)
+def test_get_optimizer(
+ config: Any,
+ expected_error: Any,
+ expected_type: type,
+ expected_config: str,
+ test_keras_model: Path,
+) -> None:
+ """Test function get_optimzer."""
+ model = tf.keras.models.load_model(str(test_keras_model))
+
+ with expected_error:
+ optimizer = get_optimizer(model, config)
+ assert isinstance(optimizer, expected_type)
+ assert optimizer.optimization_config() == expected_config
+
+
+@pytest.mark.parametrize(
+ "params, expected_result",
+ [
+ (
+ [],
+ [],
+ ),
+ (
+ [("pruning", 0.5)],
+ [
+ OptimizationSettings(
+ optimization_type="pruning",
+ optimization_target=0.5,
+ layers_to_optimize=None,
+ )
+ ],
+ ),
+ (
+ [("pruning", 0.5), ("clustering", 32)],
+ [
+ OptimizationSettings(
+ optimization_type="pruning",
+ optimization_target=0.5,
+ layers_to_optimize=None,
+ ),
+ OptimizationSettings(
+ optimization_type="clustering",
+ optimization_target=32,
+ layers_to_optimize=None,
+ ),
+ ],
+ ),
+ ],
+)
+def test_optimization_settings_create_from(
+ params: List[Tuple[str, float]], expected_result: List[OptimizationSettings]
+) -> None:
+ """Test creating settings from parsed params."""
+ assert OptimizationSettings.create_from(params) == expected_result
+
+
+@pytest.mark.parametrize(
+ "settings, expected_next_target, expected_error",
+ [
+ [
+ OptimizationSettings("clustering", 32, None),
+ OptimizationSettings("clustering", 16, None),
+ does_not_raise(),
+ ],
+ [
+ OptimizationSettings("clustering", 4, None),
+ OptimizationSettings("clustering", 4, None),
+ does_not_raise(),
+ ],
+ [
+ OptimizationSettings("clustering", 10, None),
+ OptimizationSettings("clustering", 8, None),
+ does_not_raise(),
+ ],
+ [
+ OptimizationSettings("pruning", 0.5, None),
+ OptimizationSettings("pruning", 0.6, None),
+ does_not_raise(),
+ ],
+ [
+ OptimizationSettings("pruning", 0.9, None),
+ OptimizationSettings("pruning", 0.9, None),
+ does_not_raise(),
+ ],
+ [
+ OptimizationSettings("super_optimization", 42, None),
+ None,
+ pytest.raises(
+ Exception, match="Unknown optimization type super_optimization"
+ ),
+ ],
+ ],
+)
+def test_optimization_settings_next_target(
+ settings: OptimizationSettings,
+ expected_next_target: OptimizationSettings,
+ expected_error: Any,
+) -> None:
+ """Test getting next optimization target."""
+ with expected_error:
+ assert settings.next_target() == expected_next_target
diff --git a/tests/mlia/test_nn_tensorflow_tflite_metrics.py b/tests/mlia/test_nn_tensorflow_tflite_metrics.py
new file mode 100644
index 0000000..805f7d1
--- /dev/null
+++ b/tests/mlia/test_nn_tensorflow_tflite_metrics.py
@@ -0,0 +1,137 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Test for module utils/tflite_metrics."""
+import os
+import tempfile
+from math import isclose
+from pathlib import Path
+from typing import Generator
+from typing import List
+
+import numpy as np
+import pytest
+import tensorflow as tf
+
+from mlia.nn.tensorflow.tflite_metrics import ReportClusterMode
+from mlia.nn.tensorflow.tflite_metrics import TFLiteMetrics
+
+
+def _dummy_keras_model() -> tf.keras.Model:
+ # Create a dummy model
+ keras_model = tf.keras.Sequential(
+ [
+ tf.keras.Input(shape=(8, 8, 3)),
+ tf.keras.layers.Conv2D(4, 3),
+ tf.keras.layers.DepthwiseConv2D(3),
+ tf.keras.layers.Flatten(),
+ tf.keras.layers.Dense(8),
+ ]
+ )
+ return keras_model
+
+
+def _sparse_binary_keras_model() -> tf.keras.Model:
+ def get_sparse_weights(shape: List[int]) -> np.ndarray:
+ weights = np.zeros(shape)
+ with np.nditer(weights, op_flags=["writeonly"]) as weight_iterator:
+ for idx, value in enumerate(weight_iterator):
+ if idx % 2 == 0:
+ value[...] = 1.0
+ return weights
+
+ keras_model = _dummy_keras_model()
+ # Assign weights to have 0.5 sparsity
+ for layer in keras_model.layers:
+ if not isinstance(layer, tf.keras.layers.Flatten):
+ weight = layer.weights[0]
+ weight.assign(get_sparse_weights(weight.shape))
+ print(layer)
+ print(weight.numpy())
+ return keras_model
+
+
+@pytest.fixture(scope="class", name="tflite_file")
+def fixture_tflite_file() -> Generator:
+ """Generate temporary TFLite file for tests."""
+ converter = tf.lite.TFLiteConverter.from_keras_model(_sparse_binary_keras_model())
+ tflite_model = converter.convert()
+ with tempfile.TemporaryDirectory() as tmp_dir:
+ file = os.path.join(tmp_dir, "test.tflite")
+ Path(file).write_bytes(tflite_model)
+ yield file
+
+
+@pytest.fixture(scope="function", name="metrics")
+def fixture_metrics(tflite_file: str) -> TFLiteMetrics:
+ """Generate metrics file for a given TFLite model."""
+ return TFLiteMetrics(tflite_file)
+
+
+class TestTFLiteMetrics:
+ """Tests for module TFLite_metrics."""
+
+ @staticmethod
+ def test_sparsity(metrics: TFLiteMetrics) -> None:
+ """Test sparsity."""
+ # Create new instance with a dummy TFLite file
+ # Check sparsity calculation
+ sparsity_per_layer = metrics.sparsity_per_layer()
+ for name, sparsity in sparsity_per_layer.items():
+ assert isclose(sparsity, 0.5), "Layer '{}' has incorrect sparsity.".format(
+ name
+ )
+ assert isclose(metrics.sparsity_overall(), 0.5)
+
+ @staticmethod
+ def test_clusters(metrics: TFLiteMetrics) -> None:
+ """Test clusters."""
+ # NUM_CLUSTERS_PER_AXIS and NUM_CLUSTERS_MIN_MAX can be handled together
+ for mode in [
+ ReportClusterMode.NUM_CLUSTERS_PER_AXIS,
+ ReportClusterMode.NUM_CLUSTERS_MIN_MAX,
+ ]:
+ num_unique_weights = metrics.num_unique_weights(mode)
+ for name, num_unique_per_axis in num_unique_weights.items():
+ for num_unique in num_unique_per_axis:
+ assert (
+ num_unique == 2
+ ), "Layer '{}' has incorrect number of clusters.".format(name)
+ # NUM_CLUSTERS_HISTOGRAM
+ hists = metrics.num_unique_weights(ReportClusterMode.NUM_CLUSTERS_HISTOGRAM)
+ assert hists
+ for name, hist in hists.items():
+ assert hist
+ for idx, num_axes in enumerate(hist):
+ # The histogram starts with the bin for num_clusters == 1
+ num_clusters = idx + 1
+ msg = (
+ "Histogram of layer '{}': There are {} axes with {} "
+ "clusters".format(name, num_axes, num_clusters)
+ )
+ if num_clusters == 2:
+ assert num_axes > 0, "{}, but there should be at least one.".format(
+ msg
+ )
+ else:
+ assert num_axes == 0, "{}, but there should be none.".format(msg)
+
+ @staticmethod
+ @pytest.mark.parametrize("report_sparsity", (False, True))
+ @pytest.mark.parametrize("report_cluster_mode", ReportClusterMode)
+ @pytest.mark.parametrize("max_num_clusters", (-1, 8))
+ @pytest.mark.parametrize("verbose", (False, True))
+ def test_summary(
+ tflite_file: str,
+ report_sparsity: bool,
+ report_cluster_mode: ReportClusterMode,
+ max_num_clusters: int,
+ verbose: bool,
+ ) -> None:
+ """Test the summary function."""
+ for metrics in [TFLiteMetrics(tflite_file), TFLiteMetrics(tflite_file, [])]:
+ metrics.summary(
+ report_sparsity=report_sparsity,
+ report_cluster_mode=report_cluster_mode,
+ max_num_clusters=max_num_clusters,
+ verbose=verbose,
+ )
diff --git a/tests/mlia/test_nn_tensorflow_utils.py b/tests/mlia/test_nn_tensorflow_utils.py
new file mode 100644
index 0000000..6d27299
--- /dev/null
+++ b/tests/mlia/test_nn_tensorflow_utils.py
@@ -0,0 +1,81 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Test for module utils/test_utils."""
+from pathlib import Path
+
+import pytest
+import tensorflow as tf
+
+from mlia.nn.tensorflow.utils import convert_to_tflite
+from mlia.nn.tensorflow.utils import get_tf_tensor_shape
+from mlia.nn.tensorflow.utils import is_keras_model
+from mlia.nn.tensorflow.utils import is_tflite_model
+from mlia.nn.tensorflow.utils import save_keras_model
+from mlia.nn.tensorflow.utils import save_tflite_model
+
+
+def test_convert_to_tflite(test_keras_model: Path) -> None:
+ """Test converting Keras model to TFLite."""
+ keras_model = tf.keras.models.load_model(str(test_keras_model))
+ tflite_model = convert_to_tflite(keras_model)
+
+ assert tflite_model
+
+
+def test_save_keras_model(tmp_path: Path, test_keras_model: Path) -> None:
+ """Test saving Keras model."""
+ keras_model = tf.keras.models.load_model(str(test_keras_model))
+
+ temp_file = tmp_path / "test_model_saving.h5"
+ save_keras_model(keras_model, temp_file)
+ loaded_model = tf.keras.models.load_model(temp_file)
+
+ assert loaded_model.summary() == keras_model.summary()
+
+
+def test_save_tflite_model(tmp_path: Path, test_keras_model: Path) -> None:
+ """Test saving TFLite model."""
+ keras_model = tf.keras.models.load_model(str(test_keras_model))
+
+ tflite_model = convert_to_tflite(keras_model)
+
+ temp_file = tmp_path / "test_model_saving.tflite"
+ save_tflite_model(tflite_model, temp_file)
+
+ interpreter = tf.lite.Interpreter(model_path=str(temp_file))
+ assert interpreter
+
+
+@pytest.mark.parametrize(
+ "model_path, expected_result",
+ [
+ [Path("sample_model.tflite"), True],
+ [Path("strange_model.tflite.tfl"), False],
+ [Path("sample_model.h5"), False],
+ [Path("sample_model"), False],
+ ],
+)
+def test_is_tflite_model(model_path: Path, expected_result: bool) -> None:
+ """Test function is_tflite_model."""
+ result = is_tflite_model(model_path)
+ assert result == expected_result
+
+
+@pytest.mark.parametrize(
+ "model_path, expected_result",
+ [
+ [Path("sample_model.h5"), True],
+ [Path("strange_model.h5.keras"), False],
+ [Path("sample_model.tflite"), False],
+ [Path("sample_model"), False],
+ ],
+)
+def test_is_keras_model(model_path: Path, expected_result: bool) -> None:
+ """Test function is_keras_model."""
+ result = is_keras_model(model_path)
+ assert result == expected_result
+
+
+def test_get_tf_tensor_shape(test_tf_model: Path) -> None:
+ """Test get_tf_tensor_shape with test model."""
+ assert get_tf_tensor_shape(str(test_tf_model)) == [1, 28, 28, 1]
diff --git a/tests/mlia/test_resources/vela/sample_vela.ini b/tests/mlia/test_resources/vela/sample_vela.ini
new file mode 100644
index 0000000..c992458
--- /dev/null
+++ b/tests/mlia/test_resources/vela/sample_vela.ini
@@ -0,0 +1,47 @@
+; SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+; SPDX-License-Identifier: Apache-2.0
+; Ethos-U55 High-End Embedded: SRAM (4 GB/s) and Flash (0.5 GB/s)
+[System_Config.Ethos_U55_High_End_Embedded]
+core_clock=500e6
+axi0_port=Sram
+axi1_port=OffChipFlash
+Sram_clock_scale=1.0
+Sram_burst_length=32
+Sram_read_latency=32
+Sram_write_latency=32
+OffChipFlash_clock_scale=0.125
+OffChipFlash_burst_length=128
+OffChipFlash_read_latency=64
+OffChipFlash_write_latency=64
+
+; Ethos-U65 High-End: SRAM (16 GB/s) and DRAM (3.75 GB/s)
+[System_Config.Ethos_U65_High_End]
+core_clock=1e9
+axi0_port=Sram
+axi1_port=Dram
+Sram_clock_scale=1.0
+Sram_burst_length=32
+Sram_read_latency=32
+Sram_write_latency=32
+Dram_clock_scale=0.234375
+Dram_burst_length=128
+Dram_read_latency=500
+Dram_write_latency=250
+
+; -----------------------------------------------------------------------------
+; Memory Mode
+
+; Shared SRAM: the SRAM is shared between the Ethos-U and the Cortex-M software
+; The non-SRAM memory is assumed to be read-only
+[Memory_Mode.Shared_Sram]
+const_mem_area=Axi1
+arena_mem_area=Axi0
+cache_mem_area=Axi0
+
+; The SRAM (384KB) is only for use by the Ethos-U
+; The non-SRAM memory is assumed to be read-writeable
+[Memory_Mode.Dedicated_Sram]
+const_mem_area=Axi1
+arena_mem_area=Axi1
+cache_mem_area=Axi0
+arena_cache_size=393216
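+; 393216 bytes = 384 KB, matching the dedicated SRAM size noted above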
diff --git a/tests/mlia/test_tools_aiet_wrapper.py b/tests/mlia/test_tools_aiet_wrapper.py
new file mode 100644
index 0000000..ab55b71
--- /dev/null
+++ b/tests/mlia/test_tools_aiet_wrapper.py
@@ -0,0 +1,760 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Tests for module tools/aiet_wrapper."""
+from contextlib import ExitStack as does_not_raise
+from pathlib import Path
+from typing import Any
+from typing import Dict
+from typing import List
+from typing import Optional
+from typing import Tuple
+from unittest.mock import MagicMock
+from unittest.mock import PropertyMock
+
+import pytest
+
+from mlia.tools.aiet_wrapper import AIETRunner
+from mlia.tools.aiet_wrapper import DeviceInfo
+from mlia.tools.aiet_wrapper import estimate_performance
+from mlia.tools.aiet_wrapper import ExecutionParams
+from mlia.tools.aiet_wrapper import GenericInferenceOutputParser
+from mlia.tools.aiet_wrapper import GenericInferenceRunnerEthosU
+from mlia.tools.aiet_wrapper import get_aiet_runner
+from mlia.tools.aiet_wrapper import get_generic_runner
+from mlia.tools.aiet_wrapper import get_system_name
+from mlia.tools.aiet_wrapper import is_supported
+from mlia.tools.aiet_wrapper import ModelInfo
+from mlia.tools.aiet_wrapper import PerformanceMetrics
+from mlia.tools.aiet_wrapper import supported_backends
+from mlia.utils.proc import RunningCommand
+
+
+@pytest.mark.parametrize(
+ "data, is_ready, result, missed_keys",
+ [
+ (
+ [],
+ False,
+ {},
+ [
+ "npu_active_cycles",
+ "npu_axi0_rd_data_beat_received",
+ "npu_axi0_wr_data_beat_written",
+ "npu_axi1_rd_data_beat_received",
+ "npu_idle_cycles",
+ "npu_total_cycles",
+ ],
+ ),
+ (
+ ["sample text"],
+ False,
+ {},
+ [
+ "npu_active_cycles",
+ "npu_axi0_rd_data_beat_received",
+ "npu_axi0_wr_data_beat_written",
+ "npu_axi1_rd_data_beat_received",
+ "npu_idle_cycles",
+ "npu_total_cycles",
+ ],
+ ),
+ (
+ ["NPU AXI0_RD_DATA_BEAT_RECEIVED beats: 123"],
+ False,
+ {"npu_axi0_rd_data_beat_received": 123},
+ [
+ "npu_active_cycles",
+ "npu_axi0_wr_data_beat_written",
+ "npu_axi1_rd_data_beat_received",
+ "npu_idle_cycles",
+ "npu_total_cycles",
+ ],
+ ),
+ (
+ [
+ "NPU AXI0_RD_DATA_BEAT_RECEIVED beats: 1",
+ "NPU AXI0_WR_DATA_BEAT_WRITTEN beats: 2",
+ "NPU AXI1_RD_DATA_BEAT_RECEIVED beats: 3",
+ "NPU ACTIVE cycles: 4",
+ "NPU IDLE cycles: 5",
+ "NPU TOTAL cycles: 6",
+ ],
+ True,
+ {
+ "npu_axi0_rd_data_beat_received": 1,
+ "npu_axi0_wr_data_beat_written": 2,
+ "npu_axi1_rd_data_beat_received": 3,
+ "npu_active_cycles": 4,
+ "npu_idle_cycles": 5,
+ "npu_total_cycles": 6,
+ },
+ [],
+ ),
+ ],
+)
+def test_generic_inference_output_parser(
+ data: List[str], is_ready: bool, result: Dict, missed_keys: List[str]
+) -> None:
+ """Test generic runner output parser."""
+ parser = GenericInferenceOutputParser()
+
+ for line in data:
+ parser.feed(line)
+
+ assert parser.is_ready() == is_ready
+ assert parser.result == result
+ assert parser.missed_keys() == missed_keys
+
+
+class TestAIETRunner:
+ """Tests for AIETRunner class."""
+
+ @staticmethod
+ def _setup_aiet(
+ monkeypatch: pytest.MonkeyPatch,
+ available_systems: Optional[List[str]] = None,
+ available_apps: Optional[List[str]] = None,
+ ) -> None:
+ """Set up AIET metadata."""
+
+ def mock_system(system: str) -> MagicMock:
+ """Mock the System instance."""
+ mock = MagicMock()
+ type(mock).name = PropertyMock(return_value=system)
+ return mock
+
+ def mock_app(app: str) -> MagicMock:
+ """Mock the Application instance."""
+ mock = MagicMock()
+ type(mock).name = PropertyMock(return_value=app)
+ mock.can_run_on.return_value = True
+ return mock
+
+ system_mocks = [mock_system(name) for name in (available_systems or [])]
+ monkeypatch.setattr(
+ "mlia.tools.aiet_wrapper.get_available_systems",
+ MagicMock(return_value=system_mocks),
+ )
+
+ apps_mock = [mock_app(name) for name in (available_apps or [])]
+ monkeypatch.setattr(
+ "mlia.tools.aiet_wrapper.get_available_applications",
+ MagicMock(return_value=apps_mock),
+ )
+
+ @pytest.mark.parametrize(
+ "available_systems, system, installed",
+ [
+ ([], "system1", False),
+ (["system1", "system2"], "system1", True),
+ ],
+ )
+ def test_is_system_installed(
+ self,
+ available_systems: List,
+ system: str,
+ installed: bool,
+ monkeypatch: pytest.MonkeyPatch,
+ ) -> None:
+ """Test method is_system_installed."""
+ mock_executor = MagicMock()
+ aiet_runner = AIETRunner(mock_executor)
+
+ self._setup_aiet(monkeypatch, available_systems)
+
+ assert aiet_runner.is_system_installed(system) == installed
+ mock_executor.assert_not_called()
+
+ @pytest.mark.parametrize(
+ "available_systems, systems",
+ [
+ ([], []),
+ (["system1"], ["system1"]),
+ ],
+ )
+ def test_installed_systems(
+ self,
+ available_systems: List[str],
+ systems: List[str],
+ monkeypatch: pytest.MonkeyPatch,
+ ) -> None:
+ """Test method installed_systems."""
+ mock_executor = MagicMock()
+ aiet_runner = AIETRunner(mock_executor)
+
+ self._setup_aiet(monkeypatch, available_systems)
+ assert aiet_runner.get_installed_systems() == systems
+
+ mock_executor.assert_not_called()
+
+ @staticmethod
+ def test_install_system(monkeypatch: pytest.MonkeyPatch) -> None:
+ """Test system installation."""
+ install_system_mock = MagicMock()
+ monkeypatch.setattr(
+ "mlia.tools.aiet_wrapper.install_system", install_system_mock
+ )
+
+ mock_executor = MagicMock()
+ aiet_runner = AIETRunner(mock_executor)
+ aiet_runner.install_system(Path("test_system_path"))
+
+ install_system_mock.assert_called_once_with(Path("test_system_path"))
+ mock_executor.assert_not_called()
+
+ @pytest.mark.parametrize(
+ "available_systems, systems, expected_result",
+ [
+ ([], [], False),
+ (["system1"], [], False),
+ (["system1"], ["system1"], True),
+ (["system1", "system2"], ["system1", "system3"], False),
+ (["system1", "system2"], ["system1", "system2"], True),
+ ],
+ )
+ def test_systems_installed(
+ self,
+ available_systems: List[str],
+ systems: List[str],
+ expected_result: bool,
+ monkeypatch: pytest.MonkeyPatch,
+ ) -> None:
+ """Test method systems_installed."""
+ self._setup_aiet(monkeypatch, available_systems)
+
+ mock_executor = MagicMock()
+ aiet_runner = AIETRunner(mock_executor)
+
+ assert aiet_runner.systems_installed(systems) is expected_result
+
+ mock_executor.assert_not_called()
+
+ @pytest.mark.parametrize(
+ "available_apps, applications, expected_result",
+ [
+ ([], [], False),
+ (["app1"], [], False),
+ (["app1"], ["app1"], True),
+ (["app1", "app2"], ["app1", "app3"], False),
+ (["app1", "app2"], ["app1", "app2"], True),
+ ],
+ )
+ def test_applications_installed(
+ self,
+ available_apps: List[str],
+ applications: List[str],
+ expected_result: bool,
+ monkeypatch: pytest.MonkeyPatch,
+ ) -> None:
+ """Test method applications_installed."""
+ self._setup_aiet(monkeypatch, [], available_apps)
+ mock_executor = MagicMock()
+ aiet_runner = AIETRunner(mock_executor)
+
+ assert aiet_runner.applications_installed(applications) is expected_result
+ mock_executor.assert_not_called()
+
+ @pytest.mark.parametrize(
+ "available_apps, applications",
+ [
+ ([], []),
+ (
+ ["application1", "application2"],
+ ["application1", "application2"],
+ ),
+ ],
+ )
+ def test_get_installed_applications(
+ self,
+ available_apps: List[str],
+ applications: List[str],
+ monkeypatch: pytest.MonkeyPatch,
+ ) -> None:
+ """Test method get_installed_applications."""
+ mock_executor = MagicMock()
+ self._setup_aiet(monkeypatch, [], available_apps)
+
+ aiet_runner = AIETRunner(mock_executor)
+ assert applications == aiet_runner.get_installed_applications()
+
+ mock_executor.assert_not_called()
+
+ @staticmethod
+ def test_install_application(monkeypatch: pytest.MonkeyPatch) -> None:
+ """Test application installation."""
+ mock_install_application = MagicMock()
+ monkeypatch.setattr(
+ "mlia.tools.aiet_wrapper.install_application", mock_install_application
+ )
+
+ mock_executor = MagicMock()
+
+ aiet_runner = AIETRunner(mock_executor)
+ aiet_runner.install_application(Path("test_application_path"))
+ mock_install_application.assert_called_once_with(Path("test_application_path"))
+
+ mock_executor.assert_not_called()
+
+ @pytest.mark.parametrize(
+ "available_apps, application, installed",
+ [
+ ([], "system1", False),
+ (
+ ["application1", "application2"],
+ "application1",
+ True,
+ ),
+ (
+ [],
+ "application1",
+ False,
+ ),
+ ],
+ )
+ def test_is_application_installed(
+ self,
+ available_apps: List[str],
+ application: str,
+ installed: bool,
+ monkeypatch: pytest.MonkeyPatch,
+ ) -> None:
+ """Test method is_application_installed."""
+ self._setup_aiet(monkeypatch, [], available_apps)
+
+ mock_executor = MagicMock()
+ aiet_runner = AIETRunner(mock_executor)
+ assert installed == aiet_runner.is_application_installed(application, "system1")
+
+ mock_executor.assert_not_called()
+
+ @staticmethod
+ @pytest.mark.parametrize(
+ "execution_params, expected_command",
+ [
+ (
+ ExecutionParams("application1", "system1", [], [], []),
+ ["aiet", "application", "run", "-n", "application1", "-s", "system1"],
+ ),
+ (
+ ExecutionParams(
+ "application1",
+ "system1",
+ ["input_file=123.txt", "size=777"],
+ ["param1=456", "param2=789"],
+ ["source1.txt:dest1.txt", "source2.txt:dest2.txt"],
+ ),
+ [
+ "aiet",
+ "application",
+ "run",
+ "-n",
+ "application1",
+ "-s",
+ "system1",
+ "-p",
+ "input_file=123.txt",
+ "-p",
+ "size=777",
+ "--system-param",
+ "param1=456",
+ "--system-param",
+ "param2=789",
+ "--deploy",
+ "source1.txt:dest1.txt",
+ "--deploy",
+ "source2.txt:dest2.txt",
+ ],
+ ),
+ ],
+ )
+ def test_run_application(
+ execution_params: ExecutionParams, expected_command: List[str]
+ ) -> None:
+ """Test method run_application."""
+ mock_executor = MagicMock()
+ mock_running_command = MagicMock()
+ mock_executor.submit.return_value = mock_running_command
+
+ aiet_runner = AIETRunner(mock_executor)
+ aiet_runner.run_application(execution_params)
+
+ mock_executor.submit.assert_called_once_with(expected_command)
+
+
+@pytest.mark.parametrize(
+ "device, system, application, backend, expected_error",
+ [
+ (
+ DeviceInfo(device_type="ethos-u55", mac=32, memory_mode="Shared_Sram"),
+ ("Corstone-300: Cortex-M55+Ethos-U55", True),
+ ("Generic Inference Runner: Ethos-U55 SRAM", True),
+ "Corstone-300",
+ does_not_raise(),
+ ),
+ (
+ DeviceInfo(device_type="ethos-u55", mac=32, memory_mode="Shared_Sram"),
+ ("Corstone-300: Cortex-M55+Ethos-U55", False),
+ ("Generic Inference Runner: Ethos-U55 SRAM", False),
+ "Corstone-300",
+ pytest.raises(
+ Exception,
+ match=r"System Corstone-300: Cortex-M55\+Ethos-U55 is not installed",
+ ),
+ ),
+ (
+ DeviceInfo(device_type="ethos-u55", mac=32, memory_mode="Shared_Sram"),
+ ("Corstone-300: Cortex-M55+Ethos-U55", True),
+ ("Generic Inference Runner: Ethos-U55 SRAM", False),
+ "Corstone-300",
+ pytest.raises(
+ Exception,
+ match=r"Application Generic Inference Runner: Ethos-U55/65 Shared SRAM "
+ r"for the system Corstone-300: Cortex-M55\+Ethos-U55 is not installed",
+ ),
+ ),
+ (
+ DeviceInfo(device_type="ethos-u55", mac=32, memory_mode="Shared_Sram"),
+ ("Corstone-310: Cortex-M85+Ethos-U55", True),
+ ("Generic Inference Runner: Ethos-U55 SRAM", True),
+ "Corstone-310",
+ does_not_raise(),
+ ),
+ (
+ DeviceInfo(device_type="ethos-u55", mac=32, memory_mode="Shared_Sram"),
+ ("Corstone-310: Cortex-M85+Ethos-U55", False),
+ ("Generic Inference Runner: Ethos-U55 SRAM", False),
+ "Corstone-310",
+ pytest.raises(
+ Exception,
+ match=r"System Corstone-310: Cortex-M85\+Ethos-U55 is not installed",
+ ),
+ ),
+ (
+ DeviceInfo(device_type="ethos-u55", mac=32, memory_mode="Shared_Sram"),
+ ("Corstone-310: Cortex-M85+Ethos-U55", True),
+ ("Generic Inference Runner: Ethos-U55 SRAM", False),
+ "Corstone-310",
+ pytest.raises(
+ Exception,
+ match=r"Application Generic Inference Runner: Ethos-U55/65 Shared SRAM "
+ r"for the system Corstone-310: Cortex-M85\+Ethos-U55 is not installed",
+ ),
+ ),
+ (
+ DeviceInfo(device_type="ethos-u65", mac=512, memory_mode="Shared_Sram"),
+ ("Corstone-300: Cortex-M55+Ethos-U65", True),
+ ("Generic Inference Runner: Ethos-U65 Dedicated SRAM", True),
+ "Corstone-300",
+ does_not_raise(),
+ ),
+ (
+ DeviceInfo(device_type="ethos-u65", mac=512, memory_mode="Shared_Sram"),
+ ("Corstone-300: Cortex-M55+Ethos-U65", False),
+ ("Generic Inference Runner: Ethos-U65 Dedicated SRAM", False),
+ "Corstone-300",
+ pytest.raises(
+ Exception,
+ match=r"System Corstone-300: Cortex-M55\+Ethos-U65 is not installed",
+ ),
+ ),
+ (
+ DeviceInfo(device_type="ethos-u65", mac=512, memory_mode="Shared_Sram"),
+ ("Corstone-300: Cortex-M55+Ethos-U65", True),
+ ("Generic Inference Runner: Ethos-U65 Dedicated SRAM", False),
+ "Corstone-300",
+ pytest.raises(
+ Exception,
+ match=r"Application Generic Inference Runner: Ethos-U55/65 Shared SRAM "
+ r"for the system Corstone-300: Cortex-M55\+Ethos-U65 is not installed",
+ ),
+ ),
+ (
+ DeviceInfo(
+ device_type="unknown_device", # type: ignore
+ mac=None, # type: ignore
+ memory_mode="Shared_Sram",
+ ),
+ ("some_system", False),
+ ("some_application", False),
+ "some backend",
+ pytest.raises(Exception, match="Unsupported device unknown_device"),
+ ),
+ ],
+)
+def test_estimate_performance(
+ device: DeviceInfo,
+ system: Tuple[str, bool],
+ application: Tuple[str, bool],
+ backend: str,
+ expected_error: Any,
+ test_tflite_model: Path,
+ aiet_runner: MagicMock,
+) -> None:
+ """Test getting performance estimations."""
+ system_name, system_installed = system
+ application_name, application_installed = application
+
+ aiet_runner.is_system_installed.return_value = system_installed
+ aiet_runner.is_application_installed.return_value = application_installed
+
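+    # The mocked stdout below mimics the counter lines printed by the generic
+    # inference runner; the values are expected to be parsed into the
+    # PerformanceMetrics asserted further down.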
+ mock_process = create_mock_process(
+ [
+ "NPU AXI0_RD_DATA_BEAT_RECEIVED beats: 1",
+ "NPU AXI0_WR_DATA_BEAT_WRITTEN beats: 2",
+ "NPU AXI1_RD_DATA_BEAT_RECEIVED beats: 3",
+ "NPU ACTIVE cycles: 4",
+ "NPU IDLE cycles: 5",
+ "NPU TOTAL cycles: 6",
+ ],
+ [],
+ )
+
+ mock_generic_inference_run = RunningCommand(mock_process)
+ aiet_runner.run_application.return_value = mock_generic_inference_run
+
+ with expected_error:
+ perf_metrics = estimate_performance(
+ ModelInfo(test_tflite_model), device, backend
+ )
+
+ assert isinstance(perf_metrics, PerformanceMetrics)
+ assert perf_metrics == PerformanceMetrics(
+ npu_axi0_rd_data_beat_received=1,
+ npu_axi0_wr_data_beat_written=2,
+ npu_axi1_rd_data_beat_received=3,
+ npu_active_cycles=4,
+ npu_idle_cycles=5,
+ npu_total_cycles=6,
+ )
+
+        aiet_runner.is_system_installed.assert_called_once_with(system_name)
+        aiet_runner.is_application_installed.assert_called_once_with(
+            application_name, system_name
+        )
+
+
+@pytest.mark.parametrize("backend", ("Corstone-300", "Corstone-310"))
+def test_estimate_performance_insufficient_data(
+ aiet_runner: MagicMock, test_tflite_model: Path, backend: str
+) -> None:
+ """Test that performance could not be estimated when not all data presented."""
+ aiet_runner.is_system_installed.return_value = True
+ aiet_runner.is_application_installed.return_value = True
+
+ no_total_cycles_output = [
+ "NPU AXI0_RD_DATA_BEAT_RECEIVED beats: 1",
+ "NPU AXI0_WR_DATA_BEAT_WRITTEN beats: 2",
+ "NPU AXI1_RD_DATA_BEAT_RECEIVED beats: 3",
+ "NPU ACTIVE cycles: 4",
+ "NPU IDLE cycles: 5",
+ ]
+ mock_process = create_mock_process(
+ no_total_cycles_output,
+ [],
+ )
+
+ mock_generic_inference_run = RunningCommand(mock_process)
+ aiet_runner.run_application.return_value = mock_generic_inference_run
+
+ with pytest.raises(
+ Exception, match="Unable to get performance metrics, insufficient data"
+ ):
+ device = DeviceInfo(device_type="ethos-u55", mac=32, memory_mode="Shared_Sram")
+ estimate_performance(ModelInfo(test_tflite_model), device, backend)
+
+
+@pytest.mark.parametrize("backend", ("Corstone-300", "Corstone-310"))
+def test_estimate_performance_invalid_output(
+ test_tflite_model: Path, aiet_runner: MagicMock, backend: str
+) -> None:
+ """Test estimation could not be done if inference produces unexpected output."""
+ aiet_runner.is_system_installed.return_value = True
+ aiet_runner.is_application_installed.return_value = True
+
+ mock_process = create_mock_process(
+ ["Something", "is", "wrong"], ["What a nice error!"]
+ )
+ aiet_runner.run_application.return_value = RunningCommand(mock_process)
+
+ with pytest.raises(Exception, match="Unable to get performance metrics"):
+ estimate_performance(
+ ModelInfo(test_tflite_model),
+ DeviceInfo(device_type="ethos-u55", mac=256, memory_mode="Shared_Sram"),
+ backend=backend,
+ )
+
+
+def test_get_aiet_runner() -> None:
+ """Test getting aiet runner."""
+ aiet_runner = get_aiet_runner()
+ assert isinstance(aiet_runner, AIETRunner)
+
+
+def create_mock_process(stdout: List[str], stderr: List[str]) -> MagicMock:
+ """Mock underlying process."""
+ mock_process = MagicMock()
+ mock_process.poll.return_value = 0
+ type(mock_process).stdout = PropertyMock(return_value=iter(stdout))
+ type(mock_process).stderr = PropertyMock(return_value=iter(stderr))
+ return mock_process
+
+
+@pytest.mark.parametrize("backend", ("Corstone-300", "Corstone-310"))
+def test_get_generic_runner(backend: str) -> None:
+ """Test function get_generic_runner()."""
+ device_info = DeviceInfo("ethos-u55", 256, memory_mode="Shared_Sram")
+
+ runner = get_generic_runner(device_info=device_info, backend=backend)
+ assert isinstance(runner, GenericInferenceRunnerEthosU)
+
+ with pytest.raises(RuntimeError):
+ get_generic_runner(device_info=device_info, backend="UNKNOWN_BACKEND")
+
+
+@pytest.mark.parametrize(
+ ("backend", "device_type"),
+ (
+ ("Corstone-300", "ethos-u55"),
+ ("Corstone-300", "ethos-u65"),
+ ("Corstone-310", "ethos-u55"),
+ ),
+)
+def test_aiet_backend_support(backend: str, device_type: str) -> None:
+ """Test AIET backend & device support."""
+ assert is_supported(backend)
+ assert is_supported(backend, device_type)
+
+ assert get_system_name(backend, device_type)
+
+ assert backend in supported_backends()
+
+
+class TestGenericInferenceRunnerEthosU:
+ """Test for the class GenericInferenceRunnerEthosU."""
+
+ @staticmethod
+ @pytest.mark.parametrize(
+ "device, backend, expected_system, expected_app",
+ [
+ [
+ DeviceInfo("ethos-u55", 256, memory_mode="Shared_Sram"),
+ "Corstone-300",
+ "Corstone-300: Cortex-M55+Ethos-U55",
+ "Generic Inference Runner: Ethos-U55/65 Shared SRAM",
+ ],
+ [
+ DeviceInfo("ethos-u55", 256, memory_mode="Shared_Sram"),
+ "Corstone-310",
+ "Corstone-310: Cortex-M85+Ethos-U55",
+ "Generic Inference Runner: Ethos-U55/65 Shared SRAM",
+ ],
+ [
+ DeviceInfo("ethos-u55", 256, memory_mode="Sram"),
+ "Corstone-310",
+ "Corstone-310: Cortex-M85+Ethos-U55",
+ "Generic Inference Runner: Ethos-U55 SRAM",
+ ],
+ [
+ DeviceInfo("ethos-u55", 256, memory_mode="Sram"),
+ "Corstone-300",
+ "Corstone-300: Cortex-M55+Ethos-U55",
+ "Generic Inference Runner: Ethos-U55 SRAM",
+ ],
+ [
+ DeviceInfo("ethos-u65", 256, memory_mode="Shared_Sram"),
+ "Corstone-300",
+ "Corstone-300: Cortex-M55+Ethos-U65",
+ "Generic Inference Runner: Ethos-U55/65 Shared SRAM",
+ ],
+ [
+ DeviceInfo("ethos-u65", 256, memory_mode="Dedicated_Sram"),
+ "Corstone-300",
+ "Corstone-300: Cortex-M55+Ethos-U65",
+ "Generic Inference Runner: Ethos-U65 Dedicated SRAM",
+ ],
+ ],
+ )
+ def test_artifact_resolver(
+ device: DeviceInfo, backend: str, expected_system: str, expected_app: str
+ ) -> None:
+ """Test artifact resolving based on the provided parameters."""
+ generic_runner = get_generic_runner(device, backend)
+ assert isinstance(generic_runner, GenericInferenceRunnerEthosU)
+
+ assert generic_runner.system_name == expected_system
+ assert generic_runner.app_name == expected_app
+
+ @staticmethod
+ def test_artifact_resolver_unsupported_backend() -> None:
+ """Test that it should be not possible to use unsupported backends."""
+ with pytest.raises(
+ RuntimeError, match="Unsupported device ethos-u65 for backend test_backend"
+ ):
+ get_generic_runner(
+ DeviceInfo("ethos-u65", 256, memory_mode="Shared_Sram"), "test_backend"
+ )
+
+ @staticmethod
+ def test_artifact_resolver_unsupported_memory_mode() -> None:
+ """Test that it should be not possible to use unsupported memory modes."""
+ with pytest.raises(
+ RuntimeError, match="Unsupported memory mode test_memory_mode"
+ ):
+ get_generic_runner(
+ DeviceInfo(
+ "ethos-u65",
+ 256,
+ memory_mode="test_memory_mode", # type: ignore
+ ),
+ "Corstone-300",
+ )
+
+ @staticmethod
+ @pytest.mark.parametrize("backend", ("Corstone-300", "Corstone-310"))
+ def test_inference_should_fail_if_system_not_installed(
+ aiet_runner: MagicMock, test_tflite_model: Path, backend: str
+ ) -> None:
+ """Test that inference should fail if system is not installed."""
+ aiet_runner.is_system_installed.return_value = False
+
+ generic_runner = get_generic_runner(
+ DeviceInfo("ethos-u55", 256, memory_mode="Shared_Sram"), backend
+ )
+ with pytest.raises(
+ Exception,
+ match=r"System Corstone-3[01]0: Cortex-M[58]5\+Ethos-U55 is not installed",
+ ):
+ generic_runner.run(ModelInfo(test_tflite_model), [])
+
+ @staticmethod
+ @pytest.mark.parametrize("backend", ("Corstone-300", "Corstone-310"))
+    def test_inference_should_fail_if_apps_not_installed(
+ aiet_runner: MagicMock, test_tflite_model: Path, backend: str
+ ) -> None:
+ """Test that inference should fail if apps are not installed."""
+ aiet_runner.is_system_installed.return_value = True
+ aiet_runner.is_application_installed.return_value = False
+
+ generic_runner = get_generic_runner(
+ DeviceInfo("ethos-u55", 256, memory_mode="Shared_Sram"), backend
+ )
+ with pytest.raises(
+ Exception,
+ match="Application Generic Inference Runner: Ethos-U55/65 Shared SRAM"
+ r" for the system Corstone-3[01]0: Cortex-M[58]5\+Ethos-U55 is not "
+ r"installed",
+ ):
+ generic_runner.run(ModelInfo(test_tflite_model), [])
+
+
+@pytest.fixture(name="aiet_runner")
+def fixture_aiet_runner(monkeypatch: pytest.MonkeyPatch) -> MagicMock:
+ """Mock AIET runner."""
+ aiet_runner_mock = MagicMock(spec=AIETRunner)
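+    # Replace get_aiet_runner so the code under test picks up this mock
+    # instead of a real AIET runner.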
+ monkeypatch.setattr(
+ "mlia.tools.aiet_wrapper.get_aiet_runner",
+ MagicMock(return_value=aiet_runner_mock),
+ )
+ return aiet_runner_mock
diff --git a/tests/mlia/test_tools_metadata_common.py b/tests/mlia/test_tools_metadata_common.py
new file mode 100644
index 0000000..7663b83
--- /dev/null
+++ b/tests/mlia/test_tools_metadata_common.py
@@ -0,0 +1,196 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Tests for commmon installation related functions."""
+from pathlib import Path
+from typing import Any
+from typing import List
+from typing import Optional
+from unittest.mock import call
+from unittest.mock import MagicMock
+from unittest.mock import PropertyMock
+
+import pytest
+
+from mlia.tools.metadata.common import DefaultInstallationManager
+from mlia.tools.metadata.common import DownloadAndInstall
+from mlia.tools.metadata.common import Installation
+from mlia.tools.metadata.common import InstallationType
+from mlia.tools.metadata.common import InstallFromPath
+
+
+def get_installation_mock(
+ name: str,
+ already_installed: bool = False,
+ could_be_installed: bool = False,
+ supported_install_type: Optional[type] = None,
+) -> MagicMock:
+ """Get mock instance for the installation."""
+ mock = MagicMock(spec=Installation)
+
+ def supports(install_type: InstallationType) -> bool:
+ if supported_install_type is None:
+ return False
+
+ return isinstance(install_type, supported_install_type)
+
+ mock.supports.side_effect = supports
+
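+    # Attach read-only attributes via PropertyMock on the mock's type so they
+    # behave like properties of a real Installation.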
+ props = {
+ "name": name,
+ "already_installed": already_installed,
+ "could_be_installed": could_be_installed,
+ }
+ for prop, value in props.items():
+ setattr(type(mock), prop, PropertyMock(return_value=value))
+
+ return mock
+
+
+def _already_installed_mock() -> MagicMock:
+ return get_installation_mock(
+ name="already_installed",
+ already_installed=True,
+ )
+
+
+def _ready_for_installation_mock() -> MagicMock:
+ return get_installation_mock(
+ name="ready_for_installation",
+ already_installed=False,
+ could_be_installed=True,
+ )
+
+
+def _could_be_downloaded_and_installed_mock() -> MagicMock:
+ return get_installation_mock(
+ name="could_be_downloaded_and_installed",
+ already_installed=False,
+ could_be_installed=True,
+ supported_install_type=DownloadAndInstall,
+ )
+
+
+def _could_be_installed_from_mock() -> MagicMock:
+ return get_installation_mock(
+ name="could_be_installed_from",
+ already_installed=False,
+ could_be_installed=True,
+ supported_install_type=InstallFromPath,
+ )
+
+
+def get_installation_manager(
+ noninteractive: bool,
+ installations: List[Any],
+ monkeypatch: pytest.MonkeyPatch,
+ yes_response: bool = True,
+) -> DefaultInstallationManager:
+ """Get installation manager instance."""
+ if not noninteractive:
+ monkeypatch.setattr(
+ "mlia.tools.metadata.common.yes", MagicMock(return_value=yes_response)
+ )
+
+ return DefaultInstallationManager(installations, noninteractive=noninteractive)
+
+
+def test_installation_manager_filtering() -> None:
+ """Test default installation manager."""
+ already_installed = _already_installed_mock()
+ ready_for_installation = _ready_for_installation_mock()
+ could_be_downloaded_and_installed = _could_be_downloaded_and_installed_mock()
+
+ manager = DefaultInstallationManager(
+ [
+ already_installed,
+ ready_for_installation,
+ could_be_downloaded_and_installed,
+ ]
+ )
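+    # Installations are grouped by state: already installed, ready to be
+    # installed, and downloadable; filtering by an unknown name yields nothing.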
+ assert manager.already_installed() == [already_installed]
+ assert manager.ready_for_installation() == [
+ ready_for_installation,
+ could_be_downloaded_and_installed,
+ ]
+ assert manager.could_be_downloaded_and_installed() == [
+ could_be_downloaded_and_installed
+ ]
+ assert manager.could_be_downloaded_and_installed("some_installation") == []
+
+
+@pytest.mark.parametrize("noninteractive", [True, False])
+@pytest.mark.parametrize(
+ "install_mock, eula_agreement, backend_name, expected_call",
+ [
+ [
+ _could_be_downloaded_and_installed_mock(),
+ True,
+ None,
+ [call(DownloadAndInstall(eula_agreement=True))],
+ ],
+ [
+ _could_be_downloaded_and_installed_mock(),
+ False,
+ None,
+ [call(DownloadAndInstall(eula_agreement=False))],
+ ],
+ [
+ _could_be_downloaded_and_installed_mock(),
+ False,
+ "unknown",
+ [],
+ ],
+ ],
+)
+def test_installation_manager_download_and_install(
+ install_mock: MagicMock,
+ noninteractive: bool,
+ eula_agreement: bool,
+ backend_name: Optional[str],
+ expected_call: Any,
+ monkeypatch: pytest.MonkeyPatch,
+) -> None:
+ """Test installation process."""
+ install_mock.reset_mock()
+
+ manager = get_installation_manager(noninteractive, [install_mock], monkeypatch)
+
+ manager.download_and_install(backend_name, eula_agreement=eula_agreement)
+ assert install_mock.install.mock_calls == expected_call
+
+
+@pytest.mark.parametrize("noninteractive", [True, False])
+@pytest.mark.parametrize(
+ "install_mock, backend_name, expected_call",
+ [
+ [
+ _could_be_installed_from_mock(),
+ None,
+ [call(InstallFromPath(Path("some_path")))],
+ ],
+ [
+ _could_be_installed_from_mock(),
+ "unknown",
+ [],
+ ],
+ [
+ _already_installed_mock(),
+ "already_installed",
+ [],
+ ],
+ ],
+)
+def test_installation_manager_install_from(
+ install_mock: MagicMock,
+ noninteractive: bool,
+ backend_name: Optional[str],
+ expected_call: Any,
+ monkeypatch: pytest.MonkeyPatch,
+) -> None:
+ """Test installation process."""
+ install_mock.reset_mock()
+
+ manager = get_installation_manager(noninteractive, [install_mock], monkeypatch)
+ manager.install_from(Path("some_path"), backend_name)
+
+ assert install_mock.install.mock_calls == expected_call
diff --git a/tests/mlia/test_tools_metadata_corstone.py b/tests/mlia/test_tools_metadata_corstone.py
new file mode 100644
index 0000000..2ce3610
--- /dev/null
+++ b/tests/mlia/test_tools_metadata_corstone.py
@@ -0,0 +1,419 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Tests for Corstone related installation functions.."""
+import tarfile
+from pathlib import Path
+from typing import List
+from typing import Optional
+from unittest.mock import MagicMock
+
+import pytest
+
+from mlia.tools.aiet_wrapper import AIETRunner
+from mlia.tools.metadata.common import DownloadAndInstall
+from mlia.tools.metadata.common import InstallFromPath
+from mlia.tools.metadata.corstone import AIETBasedInstallation
+from mlia.tools.metadata.corstone import AIETMetadata
+from mlia.tools.metadata.corstone import BackendInfo
+from mlia.tools.metadata.corstone import BackendInstaller
+from mlia.tools.metadata.corstone import CompoundPathChecker
+from mlia.tools.metadata.corstone import Corstone300Installer
+from mlia.tools.metadata.corstone import get_corstone_installations
+from mlia.tools.metadata.corstone import PackagePathChecker
+from mlia.tools.metadata.corstone import PathChecker
+from mlia.tools.metadata.corstone import StaticPathChecker
+
+
+@pytest.fixture(name="test_mlia_resources")
+def fixture_test_mlia_resources(
+ tmp_path: Path, monkeypatch: pytest.MonkeyPatch
+) -> Path:
+ """Redirect MLIA resources resolution to the temp directory."""
+ mlia_resources = tmp_path / "resources"
+ mlia_resources.mkdir()
+
+ monkeypatch.setattr(
+ "mlia.tools.metadata.corstone.get_mlia_resources",
+ MagicMock(return_value=mlia_resources),
+ )
+
+ return mlia_resources
+
+
+def get_aiet_based_installation( # pylint: disable=too-many-arguments
+ aiet_runner_mock: MagicMock = MagicMock(),
+ name: str = "test_name",
+ description: str = "test_description",
+ download_artifact: Optional[MagicMock] = None,
+ path_checker: PathChecker = MagicMock(),
+ apps_resources: Optional[List[str]] = None,
+ system_config: Optional[str] = None,
+ backend_installer: BackendInstaller = MagicMock(),
+ supported_platforms: Optional[List[str]] = None,
+) -> AIETBasedInstallation:
+ """Get AIET based installation."""
+ return AIETBasedInstallation(
+ aiet_runner=aiet_runner_mock,
+ metadata=AIETMetadata(
+ name=name,
+ description=description,
+ system_config=system_config or "",
+ apps_resources=apps_resources or [],
+ fvp_dir_name="sample_dir",
+ download_artifact=download_artifact,
+ supported_platforms=supported_platforms,
+ ),
+ path_checker=path_checker,
+ backend_installer=backend_installer,
+ )
+
+
+@pytest.mark.parametrize(
+ "platform, supported_platforms, expected_result",
+ [
+ ["Linux", ["Linux"], True],
+ ["Linux", [], True],
+ ["Linux", None, True],
+ ["Windows", ["Linux"], False],
+ ],
+)
+def test_could_be_installed_depends_on_platform(
+ platform: str,
+ supported_platforms: Optional[List[str]],
+ expected_result: bool,
+ monkeypatch: pytest.MonkeyPatch,
+) -> None:
+ """Test that installation could not be installed on unsupported platform."""
+ monkeypatch.setattr(
+ "mlia.tools.metadata.corstone.platform.system", MagicMock(return_value=platform)
+ )
+ monkeypatch.setattr(
+ "mlia.tools.metadata.corstone.all_paths_valid", MagicMock(return_value=True)
+ )
+ aiet_runner_mock = MagicMock(spec=AIETRunner)
+
+ installation = get_aiet_based_installation(
+ aiet_runner_mock,
+ supported_platforms=supported_platforms,
+ )
+ assert installation.could_be_installed == expected_result
+
+
+def test_get_corstone_installations() -> None:
+ """Test function get_corstone_installation."""
+ installs = get_corstone_installations()
+ assert len(installs) == 2
+ assert all(isinstance(install, AIETBasedInstallation) for install in installs)
+
+
+def test_aiet_based_installation_metadata_resolving() -> None:
+ """Test AIET based installation metadata resolving."""
+ aiet_runner_mock = MagicMock(spec=AIETRunner)
+ installation = get_aiet_based_installation(aiet_runner_mock)
+
+ assert installation.name == "test_name"
+ assert installation.description == "test_description"
+
+ aiet_runner_mock.all_installed.return_value = False
+ assert installation.already_installed is False
+
+ assert installation.could_be_installed is True
+
+
+def test_aiet_based_installation_supported_install_types(tmp_path: Path) -> None:
+ """Test supported installation types."""
+ installation_no_download_artifact = get_aiet_based_installation()
+ assert installation_no_download_artifact.supports(DownloadAndInstall()) is False
+
+ installation_with_download_artifact = get_aiet_based_installation(
+ download_artifact=MagicMock()
+ )
+ assert installation_with_download_artifact.supports(DownloadAndInstall()) is True
+
+ path_checker_mock = MagicMock(return_value=BackendInfo(tmp_path))
+ installation_can_install_from_dir = get_aiet_based_installation(
+ path_checker=path_checker_mock
+ )
+ assert installation_can_install_from_dir.supports(InstallFromPath(tmp_path)) is True
+
+ any_installation = get_aiet_based_installation()
+ assert any_installation.supports("unknown_install_type") is False # type: ignore
+
+
+def test_aiet_based_installation_install_wrong_type() -> None:
+ """Test that operation should fail if wrong install type provided."""
+ with pytest.raises(Exception, match="Unable to install wrong_install_type"):
+ aiet_runner_mock = MagicMock(spec=AIETRunner)
+ installation = get_aiet_based_installation(aiet_runner_mock)
+
+ installation.install("wrong_install_type") # type: ignore
+
+
+def test_aiet_based_installation_install_from_path(
+ tmp_path: Path, test_mlia_resources: Path
+) -> None:
+ """Test installation from the path."""
+ system_config = test_mlia_resources / "example_config.json"
+ system_config.touch()
+
+ sample_app = test_mlia_resources / "sample_app"
+ sample_app.mkdir()
+
+ dist_dir = tmp_path / "dist"
+ dist_dir.mkdir()
+
+ path_checker_mock = MagicMock(return_value=BackendInfo(dist_dir))
+
+ aiet_runner_mock = MagicMock(spec=AIETRunner)
+ installation = get_aiet_based_installation(
+ aiet_runner_mock=aiet_runner_mock,
+ path_checker=path_checker_mock,
+ apps_resources=[sample_app.name],
+ system_config="example_config.json",
+ )
+
+ assert installation.supports(InstallFromPath(dist_dir)) is True
+ installation.install(InstallFromPath(dist_dir))
+
+ aiet_runner_mock.install_system.assert_called_once()
+ aiet_runner_mock.install_application.assert_called_once_with(sample_app)
+
+
+@pytest.mark.parametrize("copy_source", [True, False])
+def test_aiet_based_installation_install_from_static_path(
+ tmp_path: Path, test_mlia_resources: Path, copy_source: bool
+) -> None:
+ """Test installation from the predefined path."""
+ system_config = test_mlia_resources / "example_config.json"
+ system_config.touch()
+
+ custom_system_config = test_mlia_resources / "custom_config.json"
+ custom_system_config.touch()
+
+ sample_app = test_mlia_resources / "sample_app"
+ sample_app.mkdir()
+
+ predefined_location = tmp_path / "backend"
+ predefined_location.mkdir()
+
+ predefined_location_file = predefined_location / "file.txt"
+ predefined_location_file.touch()
+
+ predefined_location_dir = predefined_location / "folder"
+ predefined_location_dir.mkdir()
+ nested_file = predefined_location_dir / "nested_file.txt"
+ nested_file.touch()
+
+ aiet_runner_mock = MagicMock(spec=AIETRunner)
+
+ def check_install_dir(install_dir: Path) -> None:
+ """Check content of the install dir."""
+ assert install_dir.is_dir()
+ files = list(install_dir.iterdir())
+
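+        # With copy_source the backend files are copied alongside the custom
+        # system config; otherwise only the config ends up in the install dir.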
+ if copy_source:
+ assert len(files) == 3
+ assert all(install_dir / item in files for item in ["file.txt", "folder"])
+ assert (install_dir / "folder/nested_file.txt").is_file()
+ else:
+ assert len(files) == 1
+
+ assert install_dir / "custom_config.json" in files
+
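+    # Hook the check into install_system so the staged directory is inspected
+    # at the moment the backend would be installed.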
+ aiet_runner_mock.install_system.side_effect = check_install_dir
+
+ installation = get_aiet_based_installation(
+ aiet_runner_mock=aiet_runner_mock,
+ path_checker=StaticPathChecker(
+ predefined_location,
+ ["file.txt"],
+ copy_source=copy_source,
+ system_config=str(custom_system_config),
+ ),
+ apps_resources=[sample_app.name],
+ system_config="example_config.json",
+ )
+
+ assert installation.supports(InstallFromPath(predefined_location)) is True
+ installation.install(InstallFromPath(predefined_location))
+
+ aiet_runner_mock.install_system.assert_called_once()
+ aiet_runner_mock.install_application.assert_called_once_with(sample_app)
+
+
+def create_sample_fvp_archive(tmp_path: Path) -> Path:
+ """Create sample FVP tar archive."""
+ fvp_archive_dir = tmp_path / "archive"
+ fvp_archive_dir.mkdir()
+
+ sample_file = fvp_archive_dir / "sample.txt"
+ sample_file.write_text("Sample file")
+
+ sample_dir = fvp_archive_dir / "sample_dir"
+ sample_dir.mkdir()
+
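+    # Pack the directory into a .tgz so it resembles a downloaded FVP bundle.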
+ fvp_archive = tmp_path / "archive.tgz"
+ with tarfile.open(fvp_archive, "w:gz") as fvp_archive_tar:
+ fvp_archive_tar.add(fvp_archive_dir, arcname=fvp_archive_dir.name)
+
+ return fvp_archive
+
+
+def test_aiet_based_installation_download_and_install(
+ test_mlia_resources: Path, tmp_path: Path
+) -> None:
+ """Test downloading and installation process."""
+ fvp_archive = create_sample_fvp_archive(tmp_path)
+
+ system_config = test_mlia_resources / "example_config.json"
+ system_config.touch()
+
+ download_artifact_mock = MagicMock()
+ download_artifact_mock.download_to.return_value = fvp_archive
+
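+    # The checker expects archive/sample.txt in the extracted package and
+    # treats archive/sample_dir as the backend directory, matching the layout
+    # produced by create_sample_fvp_archive above.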
+ path_checker = PackagePathChecker(["archive/sample.txt"], "archive/sample_dir")
+
+ def installer(_eula_agreement: bool, dist_dir: Path) -> Path:
+ """Sample installer."""
+ return dist_dir
+
+ aiet_runner_mock = MagicMock(spec=AIETRunner)
+ installation = get_aiet_based_installation(
+ aiet_runner_mock,
+ download_artifact=download_artifact_mock,
+ backend_installer=installer,
+ path_checker=path_checker,
+ system_config="example_config.json",
+ )
+
+ installation.install(DownloadAndInstall())
+
+ aiet_runner_mock.install_system.assert_called_once()
+
+
+@pytest.mark.parametrize(
+ "dir_content, expected_result",
+ [
+ [
+ ["models/", "file1.txt", "file2.txt"],
+ "models",
+ ],
+ [
+ ["file1.txt", "file2.txt"],
+ None,
+ ],
+ [
+ ["models/", "file2.txt"],
+ None,
+ ],
+ ],
+)
+def test_corstone_path_checker_valid_path(
+ tmp_path: Path, dir_content: List[str], expected_result: Optional[str]
+) -> None:
+ """Test Corstone path checker valid scenario."""
+ path_checker = PackagePathChecker(["file1.txt", "file2.txt"], "models")
+
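+    # Entries ending with "/" are created as directories, the rest as files.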
+ for item in dir_content:
+ if item.endswith("/"):
+ item_dir = tmp_path / item
+ item_dir.mkdir()
+ else:
+ item_file = tmp_path / item
+ item_file.touch()
+
+ result = path_checker(tmp_path)
+ expected = (
+ None if expected_result is None else BackendInfo(tmp_path / expected_result)
+ )
+
+ assert result == expected
+
+
+@pytest.mark.parametrize("system_config", [None, "system_config"])
+@pytest.mark.parametrize("copy_source", [True, False])
+def test_static_path_checker(
+ tmp_path: Path, copy_source: bool, system_config: Optional[str]
+) -> None:
+ """Test static path checker."""
+ static_checker = StaticPathChecker(
+ tmp_path, [], copy_source=copy_source, system_config=system_config
+ )
+ assert static_checker(tmp_path) == BackendInfo(
+ tmp_path, copy_source=copy_source, system_config=system_config
+ )
+
+
+def test_static_path_checker_not_valid_path(tmp_path: Path) -> None:
+ """Test static path checker should return None if path is not valid."""
+ static_checker = StaticPathChecker(tmp_path, ["file.txt"])
+ assert static_checker(tmp_path / "backend") is None
+
+
+def test_static_path_checker_not_valid_structure(tmp_path: Path) -> None:
+ """Test static path checker should return None if files are missing."""
+ static_checker = StaticPathChecker(tmp_path, ["file.txt"])
+ assert static_checker(tmp_path) is None
+
+ missing_file = tmp_path / "file.txt"
+ missing_file.touch()
+
+ assert static_checker(tmp_path) == BackendInfo(tmp_path, copy_source=False)
+
+
+def test_compound_path_checker(tmp_path: Path) -> None:
+ """Test compound path checker."""
+ path_checker_path_valid_path = MagicMock(return_value=BackendInfo(tmp_path))
+ path_checker_path_not_valid_path = MagicMock(return_value=None)
+
+ checker = CompoundPathChecker(
+ path_checker_path_valid_path, path_checker_path_not_valid_path
+ )
+ assert checker(tmp_path) == BackendInfo(tmp_path)
+
+ checker = CompoundPathChecker(path_checker_path_not_valid_path)
+ assert checker(tmp_path) is None
+
+
+@pytest.mark.parametrize(
+ "eula_agreement, expected_command",
+ [
+ [
+ True,
+ [
+ "./FVP_Corstone_SSE-300.sh",
+ "-q",
+ "-d",
+ "corstone-300",
+ ],
+ ],
+ [
+ False,
+ [
+ "./FVP_Corstone_SSE-300.sh",
+ "-q",
+ "-d",
+ "corstone-300",
+ "--nointeractive",
+ "--i-agree-to-the-contained-eula",
+ ],
+ ],
+ ],
+)
+def test_corstone_300_installer(
+ tmp_path: Path,
+ monkeypatch: pytest.MonkeyPatch,
+ eula_agreement: bool,
+ expected_command: List[str],
+) -> None:
+ """Test Corstone-300 installer."""
+ command_mock = MagicMock()
+
+ monkeypatch.setattr(
+ "mlia.tools.metadata.corstone.subprocess.check_call", command_mock
+ )
+ installer = Corstone300Installer()
+ result = installer(eula_agreement, tmp_path)
+
+ command_mock.assert_called_once_with(expected_command)
+ assert result == tmp_path / "corstone-300"
diff --git a/tests/mlia/test_tools_vela_wrapper.py b/tests/mlia/test_tools_vela_wrapper.py
new file mode 100644
index 0000000..875d2ff
--- /dev/null
+++ b/tests/mlia/test_tools_vela_wrapper.py
@@ -0,0 +1,285 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Tests for module tools/vela_wrapper."""
+from pathlib import Path
+from unittest.mock import MagicMock
+
+import pytest
+from ethosu.vela.compiler_driver import TensorAllocator
+from ethosu.vela.scheduler import OptimizationStrategy
+
+from mlia.devices.ethosu.config import EthosUConfiguration
+from mlia.tools.vela_wrapper import estimate_performance
+from mlia.tools.vela_wrapper import generate_supported_operators_report
+from mlia.tools.vela_wrapper import NpuSupported
+from mlia.tools.vela_wrapper import Operator
+from mlia.tools.vela_wrapper import Operators
+from mlia.tools.vela_wrapper import optimize_model
+from mlia.tools.vela_wrapper import OptimizedModel
+from mlia.tools.vela_wrapper import PerformanceMetrics
+from mlia.tools.vela_wrapper import supported_operators
+from mlia.tools.vela_wrapper import VelaCompiler
+from mlia.tools.vela_wrapper import VelaCompilerOptions
+from mlia.utils.proc import working_directory
+
+
+def test_default_vela_compiler() -> None:
+ """Test default Vela compiler instance."""
+ default_compiler_options = VelaCompilerOptions(accelerator_config="ethos-u55-256")
+ default_compiler = VelaCompiler(default_compiler_options)
+
+ assert default_compiler.config_files is None
+ assert default_compiler.system_config == "internal-default"
+ assert default_compiler.memory_mode == "internal-default"
+ assert default_compiler.accelerator_config == "ethos-u55-256"
+ assert default_compiler.max_block_dependency == 3
+ assert default_compiler.arena_cache_size is None
+ assert default_compiler.tensor_allocator == TensorAllocator.HillClimb
+ assert default_compiler.cpu_tensor_alignment == 16
+ assert default_compiler.optimization_strategy == OptimizationStrategy.Performance
+ assert default_compiler.output_dir is None
+
+ assert default_compiler.get_config() == {
+ "accelerator_config": "ethos-u55-256",
+ "system_config": "internal-default",
+ "core_clock": 500000000.0,
+ "axi0_port": "Sram",
+ "axi1_port": "OffChipFlash",
+ "memory_mode": "internal-default",
+ "const_mem_area": "Axi1",
+ "arena_mem_area": "Axi0",
+ "cache_mem_area": "Axi0",
+ "arena_cache_size": 4294967296,
+ "permanent_storage_mem_area": "OffChipFlash",
+ "feature_map_storage_mem_area": "Sram",
+ "fast_storage_mem_area": "Sram",
+ "memory_area": {
+ "Sram": {
+ "clock_scales": 1.0,
+ "burst_length": 32,
+ "read_latency": 32,
+ "write_latency": 32,
+ },
+ "Dram": {
+ "clock_scales": 1.0,
+ "burst_length": 1,
+ "read_latency": 0,
+ "write_latency": 0,
+ },
+ "OnChipFlash": {
+ "clock_scales": 1.0,
+ "burst_length": 1,
+ "read_latency": 0,
+ "write_latency": 0,
+ },
+ "OffChipFlash": {
+ "clock_scales": 0.125,
+ "burst_length": 128,
+ "read_latency": 64,
+ "write_latency": 64,
+ },
+ },
+ }
+
+
+def test_vela_compiler_with_parameters(test_resources_path: Path) -> None:
+ """Test creation of Vela compiler instance with non-default params."""
+ vela_ini_path = str(test_resources_path / "vela/sample_vela.ini")
+
+ compiler_options = VelaCompilerOptions(
+ config_files=vela_ini_path,
+ system_config="Ethos_U65_High_End",
+ memory_mode="Shared_Sram",
+ accelerator_config="ethos-u65-256",
+ max_block_dependency=1,
+ arena_cache_size=10,
+ tensor_allocator="Greedy",
+ cpu_tensor_alignment=4,
+ optimization_strategy="Size",
+ output_dir="output",
+ )
+ compiler = VelaCompiler(compiler_options)
+
+ assert compiler.config_files == vela_ini_path
+ assert compiler.system_config == "Ethos_U65_High_End"
+ assert compiler.memory_mode == "Shared_Sram"
+ assert compiler.accelerator_config == "ethos-u65-256"
+ assert compiler.max_block_dependency == 1
+ assert compiler.arena_cache_size == 10
+ assert compiler.tensor_allocator == TensorAllocator.Greedy
+ assert compiler.cpu_tensor_alignment == 4
+ assert compiler.optimization_strategy == OptimizationStrategy.Size
+ assert compiler.output_dir == "output"
+
+ assert compiler.get_config() == {
+ "accelerator_config": "ethos-u65-256",
+ "system_config": "Ethos_U65_High_End",
+ "core_clock": 1000000000.0,
+ "axi0_port": "Sram",
+ "axi1_port": "Dram",
+ "memory_mode": "Shared_Sram",
+ "const_mem_area": "Axi1",
+ "arena_mem_area": "Axi0",
+ "cache_mem_area": "Axi0",
+ "arena_cache_size": 10,
+ "permanent_storage_mem_area": "Dram",
+ "feature_map_storage_mem_area": "Sram",
+ "fast_storage_mem_area": "Sram",
+ "memory_area": {
+ "Sram": {
+ "clock_scales": 1.0,
+ "burst_length": 32,
+ "read_latency": 32,
+ "write_latency": 32,
+ },
+ "Dram": {
+ "clock_scales": 0.234375,
+ "burst_length": 128,
+ "read_latency": 500,
+ "write_latency": 250,
+ },
+ "OnChipFlash": {
+ "clock_scales": 1.0,
+ "burst_length": 1,
+ "read_latency": 0,
+ "write_latency": 0,
+ },
+ "OffChipFlash": {
+ "clock_scales": 1.0,
+ "burst_length": 1,
+ "read_latency": 0,
+ "write_latency": 0,
+ },
+ },
+ }
+
+
+def test_compile_model(test_tflite_model: Path) -> None:
+ """Test model optimization."""
+ compiler = VelaCompiler(EthosUConfiguration("ethos-u55-256").compiler_options)
+
+ optimized_model = compiler.compile_model(test_tflite_model)
+ assert isinstance(optimized_model, OptimizedModel)
+
+
+def test_optimize_model(tmp_path: Path, test_tflite_model: Path) -> None:
+ """Test model optimization and saving into file."""
+ tmp_file = tmp_path / "temp.tflite"
+
+ device = EthosUConfiguration("ethos-u55-256")
+ optimize_model(test_tflite_model, device.compiler_options, tmp_file.absolute())
+
+ assert tmp_file.is_file()
+ assert tmp_file.stat().st_size > 0
+
+
+@pytest.mark.parametrize(
+ "model, expected_ops",
+ [
+ (
+ "test_model.tflite",
+ Operators(
+ ops=[
+ Operator(
+ name="sequential/conv1/Relu;sequential/conv1/BiasAdd;"
+ "sequential/conv2/Conv2D;sequential/conv1/Conv2D",
+ op_type="CONV_2D",
+ run_on_npu=NpuSupported(supported=True, reasons=[]),
+ ),
+ Operator(
+ name="sequential/conv2/Relu;sequential/conv2/BiasAdd;"
+ "sequential/conv2/Conv2D",
+ op_type="CONV_2D",
+ run_on_npu=NpuSupported(supported=True, reasons=[]),
+ ),
+ Operator(
+ name="sequential/max_pooling2d/MaxPool",
+ op_type="MAX_POOL_2D",
+ run_on_npu=NpuSupported(supported=True, reasons=[]),
+ ),
+ Operator(
+ name="sequential/flatten/Reshape",
+ op_type="RESHAPE",
+ run_on_npu=NpuSupported(supported=True, reasons=[]),
+ ),
+ Operator(
+ name="Identity",
+ op_type="FULLY_CONNECTED",
+ run_on_npu=NpuSupported(supported=True, reasons=[]),
+ ),
+ ]
+ ),
+ )
+ ],
+)
+def test_operators(
+    test_models_path: Path, model: str, expected_ops: Operators
+) -> None:
+    """Test supported_operators function."""
+ device = EthosUConfiguration("ethos-u55-256")
+
+ operators = supported_operators(test_models_path / model, device.compiler_options)
+ for expected, actual in zip(expected_ops.ops, operators.ops):
+ # do not compare names as they could be different on each model generation
+ assert expected.op_type == actual.op_type
+ assert expected.run_on_npu == actual.run_on_npu
+
+
+def test_estimate_performance(test_tflite_model: Path) -> None:
+ """Test getting performance estimations."""
+ device = EthosUConfiguration("ethos-u55-256")
+ perf_metrics = estimate_performance(test_tflite_model, device.compiler_options)
+
+ assert isinstance(perf_metrics, PerformanceMetrics)
+
+
+def test_estimate_performance_already_optimized(
+ tmp_path: Path, test_tflite_model: Path
+) -> None:
+ """Test that performance estimation should fail for already optimized model."""
+ device = EthosUConfiguration("ethos-u55-256")
+
+ optimized_model_path = tmp_path / "optimized_model.tflite"
+
+ optimize_model(test_tflite_model, device.compiler_options, optimized_model_path)
+
+ with pytest.raises(
+ Exception, match="Unable to estimate performance for the given optimized model"
+ ):
+ estimate_performance(optimized_model_path, device.compiler_options)
+
+
+def test_generate_supported_operators_report(tmp_path: Path) -> None:
+ """Test generating supported operators report."""
+ with working_directory(tmp_path):
+ generate_supported_operators_report()
+
+ md_file = tmp_path / "SUPPORTED_OPS.md"
+ assert md_file.is_file()
+ assert md_file.stat().st_size > 0
+
+
+def test_read_invalid_model(test_tflite_invalid_model: Path) -> None:
+ """Test that reading invalid model should fail with exception."""
+ with pytest.raises(
+ Exception, match=f"Unable to read model {test_tflite_invalid_model}"
+ ):
+ device = EthosUConfiguration("ethos-u55-256")
+ estimate_performance(test_tflite_invalid_model, device.compiler_options)
+
+
+def test_compile_invalid_model(
+ test_tflite_model: Path, monkeypatch: pytest.MonkeyPatch, tmp_path: Path
+) -> None:
+ """Test that if model could not be compiled then correct exception raised."""
+ mock_compiler = MagicMock()
+ mock_compiler.side_effect = Exception("Bad model!")
+
+ monkeypatch.setattr("mlia.tools.vela_wrapper.compiler_driver", mock_compiler)
+
+ model_path = tmp_path / "optimized_model.tflite"
+ with pytest.raises(
+ Exception, match="Model could not be optimized with Vela compiler"
+ ):
+ device = EthosUConfiguration("ethos-u55-256")
+ optimize_model(test_tflite_model, device.compiler_options, model_path)
+
+ assert not model_path.exists()
diff --git a/tests/mlia/test_utils_console.py b/tests/mlia/test_utils_console.py
new file mode 100644
index 0000000..36975f8
--- /dev/null
+++ b/tests/mlia/test_utils_console.py
@@ -0,0 +1,100 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Tests for console utility functions."""
+from typing import Iterable
+from typing import List
+from typing import Optional
+
+import pytest
+
+from mlia.utils.console import apply_style
+from mlia.utils.console import create_section_header
+from mlia.utils.console import produce_table
+from mlia.utils.console import remove_ascii_codes
+
+
+@pytest.mark.parametrize(
+ "rows, headers, table_style, expected_result",
+ [
+ [[], [], "no_borders", ""],
+ [
+ [["1", "2", "3"]],
+ ["Col 1", "Col 2", "Col 3"],
+ "default",
+ """
+┌───────┬───────┬───────┐
+│ Col 1 │ Col 2 │ Col 3 │
+╞═══════╪═══════╪═══════╡
+│ 1 │ 2 │ 3 │
+└───────┴───────┴───────┘
+""".strip(),
+ ],
+ [
+ [["1", "2", "3"]],
+ ["Col 1", "Col 2", "Col 3"],
+ "nested",
+ "Col 1 Col 2 Col 3 \n \n1 2 3",
+ ],
+ [
+ [["1", "2", "3"]],
+ ["Col 1", "Col 2", "Col 3"],
+ "no_borders",
+ " Col 1 Col 2 Col 3 \n 1 2 3",
+ ],
+ ],
+)
+def test_produce_table(
+ rows: Iterable, headers: Optional[List[str]], table_style: str, expected_result: str
+) -> None:
+ """Test produce_table function."""
+ result = produce_table(rows, headers, table_style)
+ assert remove_ascii_codes(result) == expected_result
+
+
+def test_produce_table_unknown_style() -> None:
+ """Test that function should fail if unknown style provided."""
+ with pytest.raises(Exception, match="Unsupported table style unknown_style"):
+ produce_table([["1", "2", "3"]], [], "unknown_style")
+
+
+@pytest.mark.parametrize(
+ "value, expected_result",
+ [
+ ["some text", "some text"],
+ ["\033[32msome text\033[0m", "some text"],
+ ],
+)
+def test_remove_ascii_codes(value: str, expected_result: str) -> None:
+ """Test remove_ascii_codes function."""
+ assert remove_ascii_codes(value) == expected_result
+
+
+def test_apply_style() -> None:
+ """Test function apply_style."""
+ assert apply_style("some text", "green") == "[green]some text"
+
+
+@pytest.mark.parametrize(
+ "section_header, expected_result",
+ [
+ [
+ "Section header",
+ "\n--- Section header -------------------------------"
+ "------------------------------\n",
+ ],
+ [
+ "",
+ f"\n{'-' * 80}\n",
+ ],
+ ],
+)
+def test_create_section_header(section_header: str, expected_result: str) -> None:
+ """Test function test_create_section."""
+ assert create_section_header(section_header) == expected_result
+
+
+def test_create_section_header_too_long_value() -> None:
+ """Test that header could not be created for the too long section names."""
+ section_name = "section name" * 100
+ with pytest.raises(ValueError, match="Section name too long"):
+ create_section_header(section_name)
diff --git a/tests/mlia/test_utils_download.py b/tests/mlia/test_utils_download.py
new file mode 100644
index 0000000..4f8e2dc
--- /dev/null
+++ b/tests/mlia/test_utils_download.py
@@ -0,0 +1,147 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Tests for download functionality."""
+from contextlib import ExitStack as does_not_raise
+from pathlib import Path
+from typing import Any
+from typing import Iterable
+from typing import Optional
+from unittest.mock import MagicMock
+from unittest.mock import PropertyMock
+
+import pytest
+import requests
+
+from mlia.utils.download import download
+from mlia.utils.download import DownloadArtifact
+
+
+def response_mock(
+ content_length: Optional[str], content_chunks: Iterable[bytes]
+) -> MagicMock:
+ """Mock response object."""
+ mock = MagicMock(spec=requests.Response)
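+    # Returning the mock from __enter__ lets it be used in a `with` statement
+    # just like a real requests.Response.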
+ mock.__enter__.return_value = mock
+
+ type(mock).headers = PropertyMock(return_value={"Content-Length": content_length})
+ mock.iter_content.return_value = content_chunks
+
+ return mock
+
+
+@pytest.mark.parametrize("show_progress", [True, False])
+@pytest.mark.parametrize(
+ "content_length, content_chunks, label",
+ [
+ [
+ "5",
+ [bytes(range(5))],
+ "Downloading artifact",
+ ],
+ [
+ "10",
+ [bytes(range(5)), bytes(range(5))],
+ None,
+ ],
+ [
+ None,
+ [bytes(range(5))],
+ "Downlading no size",
+ ],
+ [
+ "abc",
+ [bytes(range(5))],
+ "Downloading wrong size",
+ ],
+ ],
+)
+def test_download(
+ show_progress: bool,
+ tmp_path: Path,
+ monkeypatch: pytest.MonkeyPatch,
+ content_length: Optional[str],
+ content_chunks: Iterable[bytes],
+ label: Optional[str],
+) -> None:
+ """Test function download."""
+ monkeypatch.setattr(
+ "mlia.utils.download.requests.get",
+ MagicMock(return_value=response_mock(content_length, content_chunks)),
+ )
+
+ dest = tmp_path / "sample.bin"
+ download("some_url", dest, show_progress=show_progress, label=label)
+
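+    # The downloaded file should contain all content chunks concatenated.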
+ assert dest.is_file()
+ assert dest.read_bytes() == bytes(
+ byte for chunk in content_chunks for byte in chunk
+ )
+
+
+@pytest.mark.parametrize(
+ "content_length, content_chunks, sha256_hash, expected_error",
+ [
+ [
+ "10",
+ [bytes(range(10))],
+ "1f825aa2f0020ef7cf91dfa30da4668d791c5d4824fc8e41354b89ec05795ab3",
+ does_not_raise(),
+ ],
+ [
+ "10",
+ [bytes(range(10))],
+ "bad_hash",
+ pytest.raises(ValueError, match="Digests do not match"),
+ ],
+ ],
+)
+def test_download_artifact_download_to(
+ monkeypatch: pytest.MonkeyPatch,
+ content_length: Optional[str],
+ content_chunks: Iterable[bytes],
+ sha256_hash: str,
+ expected_error: Any,
+ tmp_path: Path,
+) -> None:
+ """Test artifact downloading."""
+ monkeypatch.setattr(
+ "mlia.utils.download.requests.get",
+ MagicMock(return_value=response_mock(content_length, content_chunks)),
+ )
+
+ with expected_error:
+ artifact = DownloadArtifact(
+ "test_artifact",
+ "some_url",
+ "artifact_filename",
+ "1.0",
+ sha256_hash,
+ )
+
+ dest = artifact.download_to(tmp_path)
+ assert isinstance(dest, Path)
+ assert dest.name == "artifact_filename"
+
+
+def test_download_artifact_unable_to_overwrite(
+ monkeypatch: pytest.MonkeyPatch, tmp_path: Path
+) -> None:
+ """Test that download process cannot overwrite file."""
+ monkeypatch.setattr(
+ "mlia.utils.download.requests.get",
+ MagicMock(return_value=response_mock("10", [bytes(range(10))])),
+ )
+
+ artifact = DownloadArtifact(
+ "test_artifact",
+ "some_url",
+ "artifact_filename",
+ "1.0",
+ "sha256_hash",
+ )
+
+ existing_file = tmp_path / "artifact_filename"
+ existing_file.touch()
+
+ with pytest.raises(ValueError, match=f"{existing_file} already exists"):
+ artifact.download_to(tmp_path)
diff --git a/tests/mlia/test_utils_filesystem.py b/tests/mlia/test_utils_filesystem.py
new file mode 100644
index 0000000..4d8d955
--- /dev/null
+++ b/tests/mlia/test_utils_filesystem.py
@@ -0,0 +1,166 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Tests for the filesystem module."""
+import contextlib
+import json
+from pathlib import Path
+from unittest.mock import MagicMock
+
+import pytest
+
+from mlia.utils.filesystem import all_files_exist
+from mlia.utils.filesystem import all_paths_valid
+from mlia.utils.filesystem import copy_all
+from mlia.utils.filesystem import get_mlia_resources
+from mlia.utils.filesystem import get_profile
+from mlia.utils.filesystem import get_profiles_data
+from mlia.utils.filesystem import get_profiles_file
+from mlia.utils.filesystem import get_supported_profile_names
+from mlia.utils.filesystem import get_vela_config
+from mlia.utils.filesystem import sha256
+from mlia.utils.filesystem import temp_directory
+from mlia.utils.filesystem import temp_file
+
+
+def test_get_mlia_resources() -> None:
+ """Test resources getter."""
+ assert get_mlia_resources().is_dir()
+
+
+def test_get_vela_config() -> None:
+ """Test Vela config files getter."""
+ assert get_vela_config().is_file()
+ assert get_vela_config().name == "vela.ini"
+
+
+def test_profiles_file() -> None:
+ """Test profiles file getter."""
+ assert get_profiles_file().is_file()
+ assert get_profiles_file().name == "profiles.json"
+
+
+def test_profiles_data() -> None:
+ """Test profiles data getter."""
+ assert list(get_profiles_data().keys()) == [
+ "ethos-u55-256",
+ "ethos-u55-128",
+ "ethos-u65-512",
+ ]
+
+
+def test_profiles_data_wrong_format(
+ monkeypatch: pytest.MonkeyPatch, tmp_path: Path
+) -> None:
+ """Test if profile data has wrong format."""
+ wrong_profile_data = tmp_path / "bad.json"
+ with open(wrong_profile_data, "w", encoding="utf-8") as file:
+ json.dump([], file)
+
+ monkeypatch.setattr(
+ "mlia.utils.filesystem.get_profiles_file",
+ MagicMock(return_value=wrong_profile_data),
+ )
+
+ with pytest.raises(Exception, match="Profiles data format is not valid"):
+ get_profiles_data()
+
+
+def test_get_supported_profile_names() -> None:
+ """Test profile names getter."""
+ assert list(get_supported_profile_names()) == [
+ "ethos-u55-256",
+ "ethos-u55-128",
+ "ethos-u65-512",
+ ]
+
+
+def test_get_profile() -> None:
+ """Test getting profile data."""
+ assert get_profile("ethos-u55-256") == {
+ "target": "ethos-u55",
+ "mac": 256,
+ "system_config": "Ethos_U55_High_End_Embedded",
+ "memory_mode": "Shared_Sram",
+ }
+
+ with pytest.raises(Exception, match="Unable to find target profile unknown"):
+ get_profile("unknown")
+
+
+@pytest.mark.parametrize("raise_exception", [True, False])
+def test_temp_file(raise_exception: bool) -> None:
+ """Test temp_file context manager."""
+ with contextlib.suppress(Exception):
+ with temp_file() as tmp_path:
+ assert tmp_path.is_file()
+
+ if raise_exception:
+ raise Exception("Error!")
+
+ assert not tmp_path.exists()
+
+
+def test_sha256(tmp_path: Path) -> None:
+ """Test getting sha256 hash."""
+ sample = tmp_path / "sample.txt"
+
+ with open(sample, "w", encoding="utf-8") as file:
+ file.write("123")
+
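+    # SHA-256 digest of the string "123".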
+ expected_hash = "a665a45920422f9d417e4867efdc4fb8a04a1f3fff1fa07e998e86f7f7a27ae3"
+ assert sha256(sample) == expected_hash
+
+
+def test_temp_dir_context_manager() -> None:
+ """Test context manager for temporary directories."""
+ with temp_directory() as tmpdir:
+ assert isinstance(tmpdir, Path)
+ assert tmpdir.is_dir()
+
+ assert not tmpdir.exists()
+
+
+def test_all_files_exist(tmp_path: Path) -> None:
+ """Test function all_files_exist."""
+ sample1 = tmp_path / "sample1.txt"
+ sample1.touch()
+
+ sample2 = tmp_path / "sample2.txt"
+ sample2.touch()
+
+ sample3 = tmp_path / "sample3.txt"
+
+ assert all_files_exist([sample1, sample2]) is True
+ assert all_files_exist([sample1, sample2, sample3]) is False
+
+
+def test_all_paths_valid(tmp_path: Path) -> None:
+ """Test function all_paths_valid."""
+ sample = tmp_path / "sample.txt"
+ sample.touch()
+
+ sample_dir = tmp_path / "sample_dir"
+ sample_dir.mkdir()
+
+ unknown = tmp_path / "unknown.txt"
+
+ assert all_paths_valid([sample, sample_dir]) is True
+ assert all_paths_valid([sample, sample_dir, unknown]) is False
+
+
+def test_copy_all(tmp_path: Path) -> None:
+ """Test function copy_all."""
+ sample = tmp_path / "sample1.txt"
+ sample.touch()
+
+ sample_dir = tmp_path / "sample_dir"
+ sample_dir.mkdir()
+
+ sample_nested_file = sample_dir / "sample_nested.txt"
+ sample_nested_file.touch()
+
+ dest_dir = tmp_path / "dest"
+ copy_all(sample, sample_dir, dest=dest_dir)
+
+ assert (dest_dir / sample.name).is_file()
+ assert (dest_dir / sample_nested_file.name).is_file()
diff --git a/tests/mlia/test_utils_logging.py b/tests/mlia/test_utils_logging.py
new file mode 100644
index 0000000..75ebceb
--- /dev/null
+++ b/tests/mlia/test_utils_logging.py
@@ -0,0 +1,63 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Tests for the logging utility functions."""
+import logging
+import sys
+from contextlib import ExitStack as does_not_raise
+from pathlib import Path
+from typing import Any
+from typing import Optional
+
+import pytest
+
+from mlia.cli.logging import create_log_handler
+
+
+@pytest.mark.parametrize(
+ "file_path, stream, log_filter, delay, expected_error, expected_class",
+ [
+ (
+ "test.log",
+ None,
+ None,
+ True,
+ does_not_raise(),
+ logging.FileHandler,
+ ),
+ (
+ None,
+ sys.stdout,
+ None,
+ None,
+ does_not_raise(),
+ logging.StreamHandler,
+ ),
+ (
+ None,
+ None,
+ None,
+ None,
+ pytest.raises(Exception, match="Unable to create logging handler"),
+ None,
+ ),
+ ],
+)
+def test_create_log_handler(
+ file_path: Optional[Path],
+ stream: Optional[Any],
+ log_filter: Optional[logging.Filter],
+ delay: bool,
+ expected_error: Any,
+ expected_class: type,
+) -> None:
+ """Test function test_create_log_handler."""
+ with expected_error:
+ handler = create_log_handler(
+ file_path=file_path,
+ stream=stream,
+ log_level=logging.INFO,
+ log_format="%(name)s - %(message)s",
+ log_filter=log_filter,
+ delay=delay,
+ )
+ assert isinstance(handler, expected_class)
diff --git a/tests/mlia/test_utils_misc.py b/tests/mlia/test_utils_misc.py
new file mode 100644
index 0000000..011d09e
--- /dev/null
+++ b/tests/mlia/test_utils_misc.py
@@ -0,0 +1,25 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Tests for misc util functions."""
+from unittest.mock import MagicMock
+
+import pytest
+
+from mlia.utils.misc import yes
+
+
+@pytest.mark.parametrize(
+ "response, expected_result",
+ [
+ ["Y", True],
+ ["y", True],
+ ["N", False],
+ ["n", False],
+ ],
+)
+def test_yes(
+ monkeypatch: pytest.MonkeyPatch, expected_result: bool, response: str
+) -> None:
+ """Test yes function."""
+ monkeypatch.setattr("builtins.input", MagicMock(return_value=response))
+ assert yes("some_prompt") == expected_result
diff --git a/tests/mlia/test_utils_proc.py b/tests/mlia/test_utils_proc.py
new file mode 100644
index 0000000..8316ca5
--- /dev/null
+++ b/tests/mlia/test_utils_proc.py
@@ -0,0 +1,149 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Tests for the module utils/proc."""
+import signal
+import subprocess
+import time
+from pathlib import Path
+from unittest.mock import MagicMock
+
+import pytest
+
+from mlia.utils.proc import CommandExecutor
+from mlia.utils.proc import working_directory
+
+
+class TestCommandExecutor:
+ """Tests for class CommandExecutor."""
+
+ @staticmethod
+ def test_execute() -> None:
+ """Test command execution."""
+ executor = CommandExecutor()
+
+ retcode, stdout, stderr = executor.execute(["echo", "hello world!"])
+ assert retcode == 0
+ assert stdout.decode().strip() == "hello world!"
+ assert stderr.decode() == ""
+
+ @staticmethod
+ def test_submit() -> None:
+ """Test command submittion."""
+ executor = CommandExecutor()
+
+ running_command = executor.submit(["sleep", "10"])
+ assert running_command.is_alive() is True
+ assert running_command.exit_code() is None
+
+ running_command.kill()
+ for _ in range(3):
+ time.sleep(0.5)
+ if not running_command.is_alive():
+ break
+
+ assert running_command.is_alive() is False
+ assert running_command.exit_code() == -9
+
+ with pytest.raises(subprocess.CalledProcessError):
+ executor.execute(["sleep", "-1"])
+
+ @staticmethod
+ @pytest.mark.parametrize("wait", [True, False])
+ def test_stop(wait: bool) -> None:
+ """Test command termination."""
+ executor = CommandExecutor()
+
+ running_command = executor.submit(["sleep", "10"])
+ running_command.stop(wait=wait)
+
+ if wait:
+ assert running_command.is_alive() is False
+
+ @staticmethod
+ def test_unable_to_stop(monkeypatch: pytest.MonkeyPatch) -> None:
+ """Test when command could not be stopped."""
+ running_command_mock = MagicMock()
+ running_command_mock.poll.return_value = None
+
+ monkeypatch.setattr(
+ "mlia.utils.proc.subprocess.Popen",
+ MagicMock(return_value=running_command_mock),
+ )
+
+ with pytest.raises(Exception, match="Unable to stop running command"):
+ executor = CommandExecutor()
+ running_command = executor.submit(["sleep", "10"])
+
+ running_command.stop(num_of_attempts=1, interval=0.1)
+
+ running_command_mock.send_signal.assert_called_once_with(signal.SIGINT)
+
+ @staticmethod
+ def test_stop_after_several_attempts(monkeypatch: pytest.MonkeyPatch) -> None:
+ """Test when command could be stopped after several attempts."""
+ running_command_mock = MagicMock()
+ running_command_mock.poll.side_effect = [None, 0]
+
+ monkeypatch.setattr(
+ "mlia.utils.proc.subprocess.Popen",
+ MagicMock(return_value=running_command_mock),
+ )
+
+ executor = CommandExecutor()
+ running_command = executor.submit(["sleep", "10"])
+
+ running_command.stop(num_of_attempts=1, interval=0.1)
+ running_command_mock.send_signal.assert_called_once_with(signal.SIGINT)
+
+ @staticmethod
+ def test_send_signal() -> None:
+ """Test sending signal."""
+ executor = CommandExecutor()
+ running_command = executor.submit(["sleep", "10"])
+ running_command.send_signal(signal.SIGINT)
+
+        # wait a bit for signal processing
+ time.sleep(1)
+
+ assert running_command.is_alive() is False
+ assert running_command.exit_code() == -2
+
+ @staticmethod
+ @pytest.mark.parametrize(
+ "redirect_output, expected_output", [[True, "hello\n"], [False, ""]]
+ )
+ def test_wait(
+ capsys: pytest.CaptureFixture, redirect_output: bool, expected_output: str
+ ) -> None:
+ """Test wait completion functionality."""
+ executor = CommandExecutor()
+
+ running_command = executor.submit(["echo", "hello"])
+ running_command.wait(redirect_output=redirect_output)
+
+ out, _ = capsys.readouterr()
+ assert out == expected_output
+
+
+@pytest.mark.parametrize(
+ "should_exist, create_dir",
+ [
+ [True, False],
+ [False, True],
+ ],
+)
+def test_working_directory_context_manager(
+ tmp_path: Path, should_exist: bool, create_dir: bool
+) -> None:
+ """Test working_directory context manager."""
+ prev_wd = Path.cwd()
+
+ working_dir = tmp_path / "work_dir"
+ if should_exist:
+ working_dir.mkdir()
+
+ with working_directory(working_dir, create_dir=create_dir) as current_working_dir:
+ assert current_working_dir.is_dir()
+ assert Path.cwd() == current_working_dir
+
+ assert Path.cwd() == prev_wd
diff --git a/tests/mlia/test_utils_types.py b/tests/mlia/test_utils_types.py
new file mode 100644
index 0000000..4909efe
--- /dev/null
+++ b/tests/mlia/test_utils_types.py
@@ -0,0 +1,77 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Tests for the types related utility functions."""
+from typing import Any
+from typing import Iterable
+from typing import Optional
+
+import pytest
+
+from mlia.utils.types import is_list_of
+from mlia.utils.types import is_number
+from mlia.utils.types import only_one_selected
+from mlia.utils.types import parse_int
+
+
+@pytest.mark.parametrize(
+ "value, expected_result",
+ [
+ ["", False],
+ ["abc", False],
+ ["123", True],
+ ["123.1", True],
+ ["-123", True],
+ ["-123.1", True],
+ ["0", True],
+ ["1.e10", True],
+ ],
+)
+def test_is_number(value: str, expected_result: bool) -> None:
+ """Test function is_number."""
+ assert is_number(value) == expected_result
+
+
+@pytest.mark.parametrize(
+ "data, cls, elem_num, expected_result",
+ [
+ [(1, 2), int, 2, True],
+ [[1, 2], int, 2, True],
+ [[1, 2], int, 3, False],
+ [["1", "2", "3"], str, None, True],
+ [["1", "2", "3"], int, None, False],
+ ],
+)
+def test_is_list(
+ data: Any, cls: type, elem_num: Optional[int], expected_result: bool
+) -> None:
+ """Test function is_list."""
+ assert is_list_of(data, cls, elem_num) == expected_result
+
+
+@pytest.mark.parametrize(
+ "options, expected_result",
+ [
+ [[True], True],
+ [[False], False],
+ [[True, True, False, False], False],
+ ],
+)
+def test_only_one_selected(options: Iterable[bool], expected_result: bool) -> None:
+ """Test function only_one_selected."""
+ assert only_one_selected(*options) == expected_result
+
+
+@pytest.mark.parametrize(
+ "value, default, expected_int",
+ [
+ ["1", None, 1],
+ ["abc", 123, 123],
+ [None, None, None],
+ [None, 11, 11],
+ ],
+)
+def test_parse_int(
+ value: Any, default: Optional[int], expected_int: Optional[int]
+) -> None:
+ """Test function parse_int."""
+ assert parse_int(value, default) == expected_int
diff --git a/tests/mlia/utils/__init__.py b/tests/mlia/utils/__init__.py
new file mode 100644
index 0000000..27166ef
--- /dev/null
+++ b/tests/mlia/utils/__init__.py
@@ -0,0 +1,3 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Test utils module."""
diff --git a/tests/mlia/utils/common.py b/tests/mlia/utils/common.py
new file mode 100644
index 0000000..4313cde
--- /dev/null
+++ b/tests/mlia/utils/common.py
@@ -0,0 +1,32 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Common test utils module."""
+from typing import Tuple
+
+import numpy as np
+import tensorflow as tf
+
+
+def get_dataset() -> Tuple[np.ndarray, np.ndarray]:
+ """Return sample dataset."""
+ mnist = tf.keras.datasets.mnist
+ (x_train, y_train), _ = mnist.load_data()
+ x_train = x_train / 255.0
+
+    # Use a subset of the 60000 examples to keep the unit test fast.
+ x_train = x_train[0:1]
+ y_train = y_train[0:1]
+
+ return x_train, y_train
+
+
+def train_model(model: tf.keras.Model) -> None:
+ """Train model using sample dataset."""
+ num_epochs = 1
+
+ loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
+ model.compile(optimizer="adam", loss=loss_fn, metrics=["accuracy"])
+
+ x_train, y_train = get_dataset()
+
+ model.fit(x_train, y_train, epochs=num_epochs)
diff --git a/tests/mlia/utils/logging.py b/tests/mlia/utils/logging.py
new file mode 100644
index 0000000..d223fb2
--- /dev/null
+++ b/tests/mlia/utils/logging.py
@@ -0,0 +1,13 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Utils for logging."""
+import logging
+
+
+def clear_loggers() -> None:
+ """Close the log handlers."""
+ for _, logger in logging.Logger.manager.loggerDict.items():
+ if not isinstance(logger, logging.PlaceHolder):
+ for handler in logger.handlers:
+ handler.close()
+ logger.removeHandler(handler)