From c9b4089b3037b5943565d76242d3016b8776f8d2 Mon Sep 17 00:00:00 2001
From: Benjamin Klimczak
Date: Tue, 28 Jun 2022 10:29:35 +0100
Subject: MLIA-546 Merge AIET into MLIA

Merge the deprecated AIET interface for backend execution into MLIA:

- Execute backends directly (without subprocess and the aiet CLI)
- Fix issues with the unit tests
- Remove src/aiet and tests/aiet
- Re-factor code to replace 'aiet' with 'backend'
- Adapt and improve unit tests after re-factoring
- Remove dependencies that are not needed anymore (click and cloup)

Change-Id: I450734c6a3f705ba9afde41862b29e797e511f7c
---
 setup.cfg | 4 -
 src/aiet/__init__.py | 7 -
 src/aiet/backend/__init__.py | 3 -
 src/aiet/backend/application.py | 187 ----
 src/aiet/backend/common.py | 532 ---------
 src/aiet/backend/config.py | 107 --
 src/aiet/backend/controller.py | 134 ---
 src/aiet/backend/execution.py | 859 ---------------
 src/aiet/backend/output_parser.py | 176 ---
 src/aiet/backend/protocol.py | 325 ------
 src/aiet/backend/source.py | 209 ----
 src/aiet/backend/system.py | 289 -----
 src/aiet/backend/tool.py | 109 --
 src/aiet/cli/__init__.py | 28 -
 src/aiet/cli/application.py | 362 ------
 src/aiet/cli/common.py | 173 ---
 src/aiet/cli/completion.py | 72 --
 src/aiet/cli/system.py | 122 ---
 src/aiet/cli/tool.py | 143 ---
 src/aiet/main.py | 13 -
 src/aiet/resources/applications/.gitignore | 6 -
 src/aiet/resources/systems/.gitignore | 6 -
 src/aiet/resources/tools/vela/aiet-config.json | 73 --
 .../resources/tools/vela/aiet-config.json.license | 3 -
 src/aiet/resources/tools/vela/check_model.py | 75 --
 src/aiet/resources/tools/vela/run_vela.py | 65 --
 src/aiet/resources/tools/vela/vela.ini | 53 -
 src/aiet/utils/__init__.py | 3 -
 src/aiet/utils/fs.py | 116 --
 src/aiet/utils/helpers.py | 17 -
 src/aiet/utils/proc.py | 283 -----
 src/mlia/backend/__init__.py | 3 +
 src/mlia/backend/application.py | 187 ++++
 src/mlia/backend/common.py | 532 +++++++++
 src/mlia/backend/config.py | 93 ++
 src/mlia/backend/controller.py | 134 +++
 src/mlia/backend/execution.py | 779 +++++++++++++
 src/mlia/backend/fs.py | 115 ++
 src/mlia/backend/manager.py | 447 ++++++++
 src/mlia/backend/output_parser.py | 176 +++
 src/mlia/backend/proc.py | 283 +++++
 src/mlia/backend/protocol.py | 325 ++++++
 src/mlia/backend/source.py | 209 ++++
 src/mlia/backend/system.py | 289 +++++
 src/mlia/cli/config.py | 6 +-
 src/mlia/devices/ethosu/performance.py | 37 +-
 .../resources/aiet/applications/APPLICATIONS.txt | 5 +-
 src/mlia/resources/aiet/systems/SYSTEMS.txt | 3 +-
 .../resources/backends/applications/.gitignore | 6 +
 src/mlia/resources/backends/systems/.gitignore | 6 +
 src/mlia/tools/aiet_wrapper.py | 435 --------
 src/mlia/tools/metadata/corstone.py | 61 +-
 src/mlia/utils/proc.py | 20 +-
 tests/aiet/__init__.py | 3 -
 tests/aiet/conftest.py | 139 ---
 tests/aiet/test_backend_application.py | 452 --------
 tests/aiet/test_backend_common.py | 486 ---------
 tests/aiet/test_backend_controller.py | 160 ---
 tests/aiet/test_backend_execution.py | 526 ---------
 tests/aiet/test_backend_output_parser.py | 152 ---
 tests/aiet/test_backend_protocol.py | 231 ----
 tests/aiet/test_backend_source.py | 199 ----
 tests/aiet/test_backend_system.py | 536 ---------
 tests/aiet/test_backend_tool.py | 60 -
 tests/aiet/test_check_model.py | 162 ---
 tests/aiet/test_cli.py | 37 -
 tests/aiet/test_cli_application.py | 1153 --------------------
 tests/aiet/test_cli_common.py | 37 -
 tests/aiet/test_cli_system.py | 240 ----
 tests/aiet/test_cli_tool.py | 333 ------
 tests/aiet/test_main.py | 16 -
tests/aiet/test_resources/application_config.json | 96 -- .../test_resources/application_config.json.license | 3 - .../applications/application1/aiet-config.json | 30 - .../application1/aiet-config.json.license | 3 - .../applications/application2/aiet-config.json | 30 - .../application2/aiet-config.json.license | 3 - .../applications/application3/readme.txt | 4 - .../applications/application4/aiet-config.json | 35 - .../application4/aiet-config.json.license | 3 - .../applications/application4/hello_app.txt | 4 - .../applications/application5/aiet-config.json | 160 --- .../application5/aiet-config.json.license | 3 - tests/aiet/test_resources/applications/readme.txt | 4 - tests/aiet/test_resources/hello_world.json | 54 - tests/aiet/test_resources/hello_world.json.license | 3 - tests/aiet/test_resources/scripts/test_backend_run | 8 - .../scripts/test_backend_run_script.sh | 8 - .../systems/system1/aiet-config.json | 35 - .../systems/system1/aiet-config.json.license | 3 - .../systems/system1/system_artifact/dummy.txt | 2 - .../systems/system2/aiet-config.json | 32 - .../systems/system2/aiet-config.json.license | 3 - .../aiet/test_resources/systems/system3/readme.txt | 4 - .../systems/system4/aiet-config.json | 19 - .../systems/system4/aiet-config.json.license | 3 - .../test_resources/tools/tool1/aiet-config.json | 30 - .../tools/tool1/aiet-config.json.license | 3 - .../test_resources/tools/tool2/aiet-config.json | 26 - .../tools/tool2/aiet-config.json.license | 3 - .../application_with_empty_config/aiet-config.json | 1 - .../aiet-config.json.license | 3 - .../application_with_valid_config/aiet-config.json | 35 - .../aiet-config.json.license | 3 - .../aiet-config.json | 2 - .../aiet-config.json.license | 3 - .../aiet-config.json | 30 - .../aiet-config.json.license | 3 - .../aiet-config.json | 35 - .../aiet-config.json.license | 3 - .../system_with_empty_config/aiet-config.json | 1 - .../aiet-config.json.license | 3 - .../system_with_valid_config/aiet-config.json | 16 - .../aiet-config.json.license | 3 - tests/aiet/test_run_vela_script.py | 152 --- tests/aiet/test_utils_fs.py | 168 --- tests/aiet/test_utils_helpers.py | 27 - tests/aiet/test_utils_proc.py | 272 ----- tests/mlia/conftest.py | 91 ++ tests/mlia/test_backend_application.py | 460 ++++++++ tests/mlia/test_backend_common.py | 486 +++++++++ tests/mlia/test_backend_controller.py | 160 +++ tests/mlia/test_backend_execution.py | 518 +++++++++ tests/mlia/test_backend_fs.py | 168 +++ tests/mlia/test_backend_manager.py | 788 +++++++++++++ tests/mlia/test_backend_output_parser.py | 152 +++ tests/mlia/test_backend_proc.py | 272 +++++ tests/mlia/test_backend_protocol.py | 231 ++++ tests/mlia/test_backend_source.py | 203 ++++ tests/mlia/test_backend_system.py | 541 +++++++++ tests/mlia/test_cli_logging.py | 10 +- tests/mlia/test_devices_ethosu_performance.py | 2 +- tests/mlia/test_resources/application_config.json | 96 ++ .../test_resources/application_config.json.license | 3 + .../applications/application1/aiet-config.json | 30 + .../application1/aiet-config.json.license | 3 + .../applications/application2/aiet-config.json | 30 + .../application2/aiet-config.json.license | 3 + .../backends/applications/application3/readme.txt | 4 + .../applications/application4/aiet-config.json | 36 + .../application4/aiet-config.json.license | 3 + .../applications/application4/hello_app.txt | 4 + .../applications/application5/aiet-config.json | 160 +++ .../application5/aiet-config.json.license | 3 + .../applications/application6/aiet-config.json | 42 + 
.../application6/aiet-config.json.license | 3 + .../backends/applications/readme.txt | 4 + .../backends/systems/system1/aiet-config.json | 35 + .../systems/system1/aiet-config.json.license | 3 + .../systems/system1/system_artifact/dummy.txt | 2 + .../backends/systems/system2/aiet-config.json | 32 + .../systems/system2/aiet-config.json.license | 3 + .../backends/systems/system3/readme.txt | 4 + .../backends/systems/system4/aiet-config.json | 19 + .../systems/system4/aiet-config.json.license | 3 + .../backends/systems/system6/aiet-config.json | 34 + .../systems/system6/aiet-config.json.license | 3 + tests/mlia/test_resources/hello_world.json | 54 + tests/mlia/test_resources/hello_world.json.license | 3 + tests/mlia/test_resources/scripts/test_backend_run | 8 + .../scripts/test_backend_run_script.sh | 8 + .../application_with_empty_config/aiet-config.json | 1 + .../aiet-config.json.license | 3 + .../application_with_valid_config/aiet-config.json | 35 + .../aiet-config.json.license | 3 + .../aiet-config.json | 2 + .../aiet-config.json.license | 3 + .../aiet-config.json | 30 + .../aiet-config.json.license | 3 + .../aiet-config.json | 35 + .../aiet-config.json.license | 3 + .../system_with_empty_config/aiet-config.json | 1 + .../aiet-config.json.license | 3 + .../system_with_valid_config/aiet-config.json | 16 + .../aiet-config.json.license | 3 + tests/mlia/test_tools_aiet_wrapper.py | 760 ------------- tests/mlia/test_tools_metadata_corstone.py | 90 +- 177 files changed, 8542 insertions(+), 12167 deletions(-) delete mode 100644 src/aiet/__init__.py delete mode 100644 src/aiet/backend/__init__.py delete mode 100644 src/aiet/backend/application.py delete mode 100644 src/aiet/backend/common.py delete mode 100644 src/aiet/backend/config.py delete mode 100644 src/aiet/backend/controller.py delete mode 100644 src/aiet/backend/execution.py delete mode 100644 src/aiet/backend/output_parser.py delete mode 100644 src/aiet/backend/protocol.py delete mode 100644 src/aiet/backend/source.py delete mode 100644 src/aiet/backend/system.py delete mode 100644 src/aiet/backend/tool.py delete mode 100644 src/aiet/cli/__init__.py delete mode 100644 src/aiet/cli/application.py delete mode 100644 src/aiet/cli/common.py delete mode 100644 src/aiet/cli/completion.py delete mode 100644 src/aiet/cli/system.py delete mode 100644 src/aiet/cli/tool.py delete mode 100644 src/aiet/main.py delete mode 100644 src/aiet/resources/applications/.gitignore delete mode 100644 src/aiet/resources/systems/.gitignore delete mode 100644 src/aiet/resources/tools/vela/aiet-config.json delete mode 100644 src/aiet/resources/tools/vela/aiet-config.json.license delete mode 100644 src/aiet/resources/tools/vela/check_model.py delete mode 100644 src/aiet/resources/tools/vela/run_vela.py delete mode 100644 src/aiet/resources/tools/vela/vela.ini delete mode 100644 src/aiet/utils/__init__.py delete mode 100644 src/aiet/utils/fs.py delete mode 100644 src/aiet/utils/helpers.py delete mode 100644 src/aiet/utils/proc.py create mode 100644 src/mlia/backend/__init__.py create mode 100644 src/mlia/backend/application.py create mode 100644 src/mlia/backend/common.py create mode 100644 src/mlia/backend/config.py create mode 100644 src/mlia/backend/controller.py create mode 100644 src/mlia/backend/execution.py create mode 100644 src/mlia/backend/fs.py create mode 100644 src/mlia/backend/manager.py create mode 100644 src/mlia/backend/output_parser.py create mode 100644 src/mlia/backend/proc.py create mode 100644 src/mlia/backend/protocol.py create mode 100644 
src/mlia/backend/source.py create mode 100644 src/mlia/backend/system.py create mode 100644 src/mlia/resources/backends/applications/.gitignore create mode 100644 src/mlia/resources/backends/systems/.gitignore delete mode 100644 src/mlia/tools/aiet_wrapper.py delete mode 100644 tests/aiet/__init__.py delete mode 100644 tests/aiet/conftest.py delete mode 100644 tests/aiet/test_backend_application.py delete mode 100644 tests/aiet/test_backend_common.py delete mode 100644 tests/aiet/test_backend_controller.py delete mode 100644 tests/aiet/test_backend_execution.py delete mode 100644 tests/aiet/test_backend_output_parser.py delete mode 100644 tests/aiet/test_backend_protocol.py delete mode 100644 tests/aiet/test_backend_source.py delete mode 100644 tests/aiet/test_backend_system.py delete mode 100644 tests/aiet/test_backend_tool.py delete mode 100644 tests/aiet/test_check_model.py delete mode 100644 tests/aiet/test_cli.py delete mode 100644 tests/aiet/test_cli_application.py delete mode 100644 tests/aiet/test_cli_common.py delete mode 100644 tests/aiet/test_cli_system.py delete mode 100644 tests/aiet/test_cli_tool.py delete mode 100644 tests/aiet/test_main.py delete mode 100644 tests/aiet/test_resources/application_config.json delete mode 100644 tests/aiet/test_resources/application_config.json.license delete mode 100644 tests/aiet/test_resources/applications/application1/aiet-config.json delete mode 100644 tests/aiet/test_resources/applications/application1/aiet-config.json.license delete mode 100644 tests/aiet/test_resources/applications/application2/aiet-config.json delete mode 100644 tests/aiet/test_resources/applications/application2/aiet-config.json.license delete mode 100644 tests/aiet/test_resources/applications/application3/readme.txt delete mode 100644 tests/aiet/test_resources/applications/application4/aiet-config.json delete mode 100644 tests/aiet/test_resources/applications/application4/aiet-config.json.license delete mode 100644 tests/aiet/test_resources/applications/application4/hello_app.txt delete mode 100644 tests/aiet/test_resources/applications/application5/aiet-config.json delete mode 100644 tests/aiet/test_resources/applications/application5/aiet-config.json.license delete mode 100644 tests/aiet/test_resources/applications/readme.txt delete mode 100644 tests/aiet/test_resources/hello_world.json delete mode 100644 tests/aiet/test_resources/hello_world.json.license delete mode 100755 tests/aiet/test_resources/scripts/test_backend_run delete mode 100644 tests/aiet/test_resources/scripts/test_backend_run_script.sh delete mode 100644 tests/aiet/test_resources/systems/system1/aiet-config.json delete mode 100644 tests/aiet/test_resources/systems/system1/aiet-config.json.license delete mode 100644 tests/aiet/test_resources/systems/system1/system_artifact/dummy.txt delete mode 100644 tests/aiet/test_resources/systems/system2/aiet-config.json delete mode 100644 tests/aiet/test_resources/systems/system2/aiet-config.json.license delete mode 100644 tests/aiet/test_resources/systems/system3/readme.txt delete mode 100644 tests/aiet/test_resources/systems/system4/aiet-config.json delete mode 100644 tests/aiet/test_resources/systems/system4/aiet-config.json.license delete mode 100644 tests/aiet/test_resources/tools/tool1/aiet-config.json delete mode 100644 tests/aiet/test_resources/tools/tool1/aiet-config.json.license delete mode 100644 tests/aiet/test_resources/tools/tool2/aiet-config.json delete mode 100644 tests/aiet/test_resources/tools/tool2/aiet-config.json.license delete mode 
100644 tests/aiet/test_resources/various/applications/application_with_empty_config/aiet-config.json delete mode 100644 tests/aiet/test_resources/various/applications/application_with_empty_config/aiet-config.json.license delete mode 100644 tests/aiet/test_resources/various/applications/application_with_valid_config/aiet-config.json delete mode 100644 tests/aiet/test_resources/various/applications/application_with_valid_config/aiet-config.json.license delete mode 100644 tests/aiet/test_resources/various/applications/application_with_wrong_config1/aiet-config.json delete mode 100644 tests/aiet/test_resources/various/applications/application_with_wrong_config1/aiet-config.json.license delete mode 100644 tests/aiet/test_resources/various/applications/application_with_wrong_config2/aiet-config.json delete mode 100644 tests/aiet/test_resources/various/applications/application_with_wrong_config2/aiet-config.json.license delete mode 100644 tests/aiet/test_resources/various/applications/application_with_wrong_config3/aiet-config.json delete mode 100644 tests/aiet/test_resources/various/applications/application_with_wrong_config3/aiet-config.json.license delete mode 100644 tests/aiet/test_resources/various/systems/system_with_empty_config/aiet-config.json delete mode 100644 tests/aiet/test_resources/various/systems/system_with_empty_config/aiet-config.json.license delete mode 100644 tests/aiet/test_resources/various/systems/system_with_valid_config/aiet-config.json delete mode 100644 tests/aiet/test_resources/various/systems/system_with_valid_config/aiet-config.json.license delete mode 100644 tests/aiet/test_run_vela_script.py delete mode 100644 tests/aiet/test_utils_fs.py delete mode 100644 tests/aiet/test_utils_helpers.py delete mode 100644 tests/aiet/test_utils_proc.py create mode 100644 tests/mlia/test_backend_application.py create mode 100644 tests/mlia/test_backend_common.py create mode 100644 tests/mlia/test_backend_controller.py create mode 100644 tests/mlia/test_backend_execution.py create mode 100644 tests/mlia/test_backend_fs.py create mode 100644 tests/mlia/test_backend_manager.py create mode 100644 tests/mlia/test_backend_output_parser.py create mode 100644 tests/mlia/test_backend_proc.py create mode 100644 tests/mlia/test_backend_protocol.py create mode 100644 tests/mlia/test_backend_source.py create mode 100644 tests/mlia/test_backend_system.py create mode 100644 tests/mlia/test_resources/application_config.json create mode 100644 tests/mlia/test_resources/application_config.json.license create mode 100644 tests/mlia/test_resources/backends/applications/application1/aiet-config.json create mode 100644 tests/mlia/test_resources/backends/applications/application1/aiet-config.json.license create mode 100644 tests/mlia/test_resources/backends/applications/application2/aiet-config.json create mode 100644 tests/mlia/test_resources/backends/applications/application2/aiet-config.json.license create mode 100644 tests/mlia/test_resources/backends/applications/application3/readme.txt create mode 100644 tests/mlia/test_resources/backends/applications/application4/aiet-config.json create mode 100644 tests/mlia/test_resources/backends/applications/application4/aiet-config.json.license create mode 100644 tests/mlia/test_resources/backends/applications/application4/hello_app.txt create mode 100644 tests/mlia/test_resources/backends/applications/application5/aiet-config.json create mode 100644 tests/mlia/test_resources/backends/applications/application5/aiet-config.json.license create mode 100644 
tests/mlia/test_resources/backends/applications/application6/aiet-config.json create mode 100644 tests/mlia/test_resources/backends/applications/application6/aiet-config.json.license create mode 100644 tests/mlia/test_resources/backends/applications/readme.txt create mode 100644 tests/mlia/test_resources/backends/systems/system1/aiet-config.json create mode 100644 tests/mlia/test_resources/backends/systems/system1/aiet-config.json.license create mode 100644 tests/mlia/test_resources/backends/systems/system1/system_artifact/dummy.txt create mode 100644 tests/mlia/test_resources/backends/systems/system2/aiet-config.json create mode 100644 tests/mlia/test_resources/backends/systems/system2/aiet-config.json.license create mode 100644 tests/mlia/test_resources/backends/systems/system3/readme.txt create mode 100644 tests/mlia/test_resources/backends/systems/system4/aiet-config.json create mode 100644 tests/mlia/test_resources/backends/systems/system4/aiet-config.json.license create mode 100644 tests/mlia/test_resources/backends/systems/system6/aiet-config.json create mode 100644 tests/mlia/test_resources/backends/systems/system6/aiet-config.json.license create mode 100644 tests/mlia/test_resources/hello_world.json create mode 100644 tests/mlia/test_resources/hello_world.json.license create mode 100755 tests/mlia/test_resources/scripts/test_backend_run create mode 100644 tests/mlia/test_resources/scripts/test_backend_run_script.sh create mode 100644 tests/mlia/test_resources/various/applications/application_with_empty_config/aiet-config.json create mode 100644 tests/mlia/test_resources/various/applications/application_with_empty_config/aiet-config.json.license create mode 100644 tests/mlia/test_resources/various/applications/application_with_valid_config/aiet-config.json create mode 100644 tests/mlia/test_resources/various/applications/application_with_valid_config/aiet-config.json.license create mode 100644 tests/mlia/test_resources/various/applications/application_with_wrong_config1/aiet-config.json create mode 100644 tests/mlia/test_resources/various/applications/application_with_wrong_config1/aiet-config.json.license create mode 100644 tests/mlia/test_resources/various/applications/application_with_wrong_config2/aiet-config.json create mode 100644 tests/mlia/test_resources/various/applications/application_with_wrong_config2/aiet-config.json.license create mode 100644 tests/mlia/test_resources/various/applications/application_with_wrong_config3/aiet-config.json create mode 100644 tests/mlia/test_resources/various/applications/application_with_wrong_config3/aiet-config.json.license create mode 100644 tests/mlia/test_resources/various/systems/system_with_empty_config/aiet-config.json create mode 100644 tests/mlia/test_resources/various/systems/system_with_empty_config/aiet-config.json.license create mode 100644 tests/mlia/test_resources/various/systems/system_with_valid_config/aiet-config.json create mode 100644 tests/mlia/test_resources/various/systems/system_with_valid_config/aiet-config.json.license delete mode 100644 tests/mlia/test_tools_aiet_wrapper.py diff --git a/setup.cfg b/setup.cfg index 9391fa2..3021043 100644 --- a/setup.cfg +++ b/setup.cfg @@ -34,12 +34,10 @@ install_requires = ethos-u-vela~=3.3.0 requests rich - click sh paramiko filelock psutil - cloup>=0.12.0 [options.packages.find] where = src @@ -47,8 +45,6 @@ where = src [options.entry_points] console_scripts = mlia=mlia.cli.main:main - aiet=aiet.main:main - run_vela=aiet.resources.tools.vela.run_vela:main 
[options.extras_require] dev = diff --git a/src/aiet/__init__.py b/src/aiet/__init__.py deleted file mode 100644 index 49304b1..0000000 --- a/src/aiet/__init__.py +++ /dev/null @@ -1,7 +0,0 @@ -# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. -# SPDX-License-Identifier: Apache-2.0 -"""Init of aiet.""" -import pkg_resources - - -__version__ = pkg_resources.get_distribution("mlia").version diff --git a/src/aiet/backend/__init__.py b/src/aiet/backend/__init__.py deleted file mode 100644 index 3d60372..0000000 --- a/src/aiet/backend/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. -# SPDX-License-Identifier: Apache-2.0 -"""Backend module.""" diff --git a/src/aiet/backend/application.py b/src/aiet/backend/application.py deleted file mode 100644 index f6ef815..0000000 --- a/src/aiet/backend/application.py +++ /dev/null @@ -1,187 +0,0 @@ -# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. -# SPDX-License-Identifier: Apache-2.0 -"""Application backend module.""" -import re -from pathlib import Path -from typing import Any -from typing import cast -from typing import Dict -from typing import List -from typing import Optional - -from aiet.backend.common import Backend -from aiet.backend.common import ConfigurationException -from aiet.backend.common import DataPaths -from aiet.backend.common import get_backend_configs -from aiet.backend.common import get_backend_directories -from aiet.backend.common import load_application_or_tool_configs -from aiet.backend.common import load_config -from aiet.backend.common import remove_backend -from aiet.backend.config import ApplicationConfig -from aiet.backend.config import ExtendedApplicationConfig -from aiet.backend.source import create_destination_and_install -from aiet.backend.source import get_source -from aiet.utils.fs import get_resources - - -def get_available_application_directory_names() -> List[str]: - """Return a list of directory names for all available applications.""" - return [entry.name for entry in get_backend_directories("applications")] - - -def get_available_applications() -> List["Application"]: - """Return a list with all available applications.""" - available_applications = [] - for config_json in get_backend_configs("applications"): - config_entries = cast(List[ExtendedApplicationConfig], load_config(config_json)) - for config_entry in config_entries: - config_entry["config_location"] = config_json.parent.absolute() - applications = load_applications(config_entry) - available_applications += applications - - return sorted(available_applications, key=lambda application: application.name) - - -def get_application( - application_name: str, system_name: Optional[str] = None -) -> List["Application"]: - """Return a list of application instances with provided name.""" - return [ - application - for application in get_available_applications() - if application.name == application_name - and (not system_name or application.can_run_on(system_name)) - ] - - -def install_application(source_path: Path) -> None: - """Install application.""" - try: - source = get_source(source_path) - config = cast(List[ExtendedApplicationConfig], source.config()) - applications_to_install = [ - s for entry in config for s in load_applications(entry) - ] - except Exception as error: - raise ConfigurationException("Unable to read application definition") from error - - if not applications_to_install: - raise ConfigurationException("No application 
definition found") - - available_applications = get_available_applications() - already_installed = [ - s for s in applications_to_install if s in available_applications - ] - if already_installed: - names = {application.name for application in already_installed} - raise ConfigurationException( - "Applications [{}] are already installed".format(",".join(names)) - ) - - create_destination_and_install(source, get_resources("applications")) - - -def remove_application(directory_name: str) -> None: - """Remove application directory.""" - remove_backend(directory_name, "applications") - - -def get_unique_application_names(system_name: Optional[str] = None) -> List[str]: - """Extract a list of unique application names of all application available.""" - return list( - set( - application.name - for application in get_available_applications() - if not system_name or application.can_run_on(system_name) - ) - ) - - -class Application(Backend): - """Class for representing a single application component.""" - - def __init__(self, config: ApplicationConfig) -> None: - """Construct a Application instance from a dict.""" - super().__init__(config) - - self.supported_systems = config.get("supported_systems", []) - self.deploy_data = config.get("deploy_data", []) - - def __eq__(self, other: object) -> bool: - """Overload operator ==.""" - if not isinstance(other, Application): - return False - - return ( - super().__eq__(other) - and self.name == other.name - and set(self.supported_systems) == set(other.supported_systems) - ) - - def can_run_on(self, system_name: str) -> bool: - """Check if the application can run on the system passed as argument.""" - return system_name in self.supported_systems - - def get_deploy_data(self) -> List[DataPaths]: - """Validate and return data specified in the config file.""" - if self.config_location is None: - raise ConfigurationException( - "Unable to get application {} config location".format(self.name) - ) - - deploy_data = [] - for item in self.deploy_data: - src, dst = item - src_full_path = self.config_location / src - assert src_full_path.exists(), "{} does not exists".format(src_full_path) - deploy_data.append(DataPaths(src_full_path, dst)) - return deploy_data - - def get_details(self) -> Dict[str, Any]: - """Return dictionary with information about the Application instance.""" - output = { - "type": "application", - "name": self.name, - "description": self.description, - "supported_systems": self.supported_systems, - "commands": self._get_command_details(), - } - - return output - - def remove_unused_params(self) -> None: - """Remove unused params in commands. - - After merging default and system related configuration application - could have parameters that are not being used in commands. They - should be removed. - """ - for command in self.commands.values(): - indexes_or_aliases = [ - m - for cmd_str in command.command_strings - for m in re.findall(r"{user_params:(?P\w+)}", cmd_str) - ] - - only_aliases = all(not item.isnumeric() for item in indexes_or_aliases) - if only_aliases: - used_params = [ - param - for param in command.params - if param.alias in indexes_or_aliases - ] - command.params = used_params - - -def load_applications(config: ExtendedApplicationConfig) -> List[Application]: - """Load application. - - Application configuration could contain different parameters/commands for different - supported systems. For each supported system this function will return separate - Application instance with appropriate configuration. 
- """ - configs = load_application_or_tool_configs(config, ApplicationConfig) - applications = [Application(cfg) for cfg in configs] - for application in applications: - application.remove_unused_params() - return applications diff --git a/src/aiet/backend/common.py b/src/aiet/backend/common.py deleted file mode 100644 index b887ee7..0000000 --- a/src/aiet/backend/common.py +++ /dev/null @@ -1,532 +0,0 @@ -# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. -# SPDX-License-Identifier: Apache-2.0 -"""Contain all common functions for the backends.""" -import json -import logging -import re -from abc import ABC -from collections import Counter -from pathlib import Path -from typing import Any -from typing import Callable -from typing import cast -from typing import Dict -from typing import Final -from typing import IO -from typing import Iterable -from typing import List -from typing import Match -from typing import NamedTuple -from typing import Optional -from typing import Pattern -from typing import Tuple -from typing import Type -from typing import Union - -from aiet.backend.config import BackendConfig -from aiet.backend.config import BaseBackendConfig -from aiet.backend.config import NamedExecutionConfig -from aiet.backend.config import UserParamConfig -from aiet.backend.config import UserParamsConfig -from aiet.utils.fs import get_resources -from aiet.utils.fs import remove_resource -from aiet.utils.fs import ResourceType - - -AIET_CONFIG_FILE: Final[str] = "aiet-config.json" - - -class ConfigurationException(Exception): - """Configuration exception.""" - - -def get_backend_config(dir_path: Path) -> Path: - """Get path to backendir configuration file.""" - return dir_path / AIET_CONFIG_FILE - - -def get_backend_configs(resource_type: ResourceType) -> Iterable[Path]: - """Get path to the backend configs for provided resource_type.""" - return ( - get_backend_config(entry) for entry in get_backend_directories(resource_type) - ) - - -def get_backend_directories(resource_type: ResourceType) -> Iterable[Path]: - """Get path to the backend directories for provided resource_type.""" - return ( - entry - for entry in get_resources(resource_type).iterdir() - if is_backend_directory(entry) - ) - - -def is_backend_directory(dir_path: Path) -> bool: - """Check if path is backend's configuration directory.""" - return dir_path.is_dir() and get_backend_config(dir_path).is_file() - - -def remove_backend(directory_name: str, resource_type: ResourceType) -> None: - """Remove backend with provided type and directory_name.""" - if not directory_name: - raise Exception("No directory name provided") - - remove_resource(directory_name, resource_type) - - -def load_config(config: Union[None, Path, IO[bytes]]) -> BackendConfig: - """Return a loaded json file.""" - if config is None: - raise Exception("Unable to read config") - - if isinstance(config, Path): - with config.open() as json_file: - return cast(BackendConfig, json.load(json_file)) - - return cast(BackendConfig, json.load(config)) - - -def parse_raw_parameter(parameter: str) -> Tuple[str, Optional[str]]: - """Split the parameter string in name and optional value. 
- - It manages the following cases: - --param=1 -> --param, 1 - --param 1 -> --param, 1 - --flag -> --flag, None - """ - data = re.split(" |=", parameter) - if len(data) == 1: - param_name = data[0] - param_value = None - else: - param_name = " ".join(data[0:-1]) - param_value = data[-1] - return param_name, param_value - - -class DataPaths(NamedTuple): - """DataPaths class.""" - - src: Path - dst: str - - -class Backend(ABC): - """Backend class.""" - - # pylint: disable=too-many-instance-attributes - - def __init__(self, config: BaseBackendConfig): - """Initialize backend.""" - name = config.get("name") - if not name: - raise ConfigurationException("Name is empty") - - self.name = name - self.description = config.get("description", "") - self.config_location = config.get("config_location") - self.variables = config.get("variables", {}) - self.build_dir = config.get("build_dir") - self.lock = config.get("lock", False) - if self.build_dir: - self.build_dir = self._substitute_variables(self.build_dir) - self.annotations = config.get("annotations", {}) - - self._parse_commands_and_params(config) - - def validate_parameter(self, command_name: str, parameter: str) -> bool: - """Validate the parameter string against the application configuration. - - We take the parameter string, extract the parameter name/value and - check them against the current configuration. - """ - param_name, param_value = parse_raw_parameter(parameter) - valid_param_name = valid_param_value = False - - command = self.commands.get(command_name) - if not command: - raise AttributeError("Unknown command: '{}'".format(command_name)) - - # Iterate over all available parameters until we have a match. - for param in command.params: - if self._same_parameter(param_name, param): - valid_param_name = True - # This is a non-empty list - if param.values: - # We check if the value is allowed in the configuration - valid_param_value = param_value in param.values - else: - # In this case we don't validate the value and accept - # whatever we have set. - valid_param_value = True - break - - return valid_param_name and valid_param_value - - def __eq__(self, other: object) -> bool: - """Overload operator ==.""" - if not isinstance(other, Backend): - return False - - return ( - self.name == other.name - and self.description == other.description - and self.commands == other.commands - ) - - def __repr__(self) -> str: - """Represent the Backend instance by its name.""" - return self.name - - def _parse_commands_and_params(self, config: BaseBackendConfig) -> None: - """Parse commands and user parameters.""" - self.commands: Dict[str, Command] = {} - - commands = config.get("commands") - if commands: - params = config.get("user_params") - - for command_name in commands.keys(): - command_params = self._parse_params(params, command_name) - command_strings = [ - self._substitute_variables(cmd) - for cmd in commands.get(command_name, []) - ] - self.commands[command_name] = Command(command_strings, command_params) - - def _substitute_variables(self, str_val: str) -> str: - """Substitute variables in string. - - Variables is being substituted at backend's creation stage because - they could contain references to other params which will be - resolved later. 
- """ - if not str_val: - return str_val - - var_pattern: Final[Pattern] = re.compile(r"{variables:(?P\w+)}") - - def var_value(match: Match) -> str: - var_name = match["var_name"] - if var_name not in self.variables: - raise ConfigurationException("Unknown variable {}".format(var_name)) - - return self.variables[var_name] - - return var_pattern.sub(var_value, str_val) # type: ignore - - @classmethod - def _parse_params( - cls, params: Optional[UserParamsConfig], command: str - ) -> List["Param"]: - if not params: - return [] - - return [cls._parse_param(p) for p in params.get(command, [])] - - @classmethod - def _parse_param(cls, param: UserParamConfig) -> "Param": - """Parse a single parameter.""" - name = param.get("name") - if name is not None and not name: - raise ConfigurationException("Parameter has an empty 'name' attribute.") - values = param.get("values", None) - default_value = param.get("default_value", None) - description = param.get("description", "") - alias = param.get("alias") - - return Param( - name=name, - description=description, - values=values, - default_value=default_value, - alias=alias, - ) - - def _get_command_details(self) -> Dict: - command_details = { - command_name: command.get_details() - for command_name, command in self.commands.items() - } - return command_details - - def _get_user_param_value( - self, user_params: List[str], param: "Param" - ) -> Optional[str]: - """Get the user-specified value of a parameter.""" - for user_param in user_params: - user_param_name, user_param_value = parse_raw_parameter(user_param) - if user_param_name == param.name: - warn_message = ( - "The direct use of parameter name is deprecated" - " and might be removed in the future.\n" - f"Please use alias '{param.alias}' instead of " - "'{user_param_name}' to provide the parameter." - ) - logging.warning(warn_message) - - if self._same_parameter(user_param_name, param): - return user_param_value - - return None - - @staticmethod - def _same_parameter(user_param_name_or_alias: str, param: "Param") -> bool: - """Compare user parameter name with param name or alias.""" - # Strip the "=" sign in the param_name. This is needed just for - # comparison with the parameters passed by the user. - # The equal sign needs to be honoured when re-building the - # parameter back. - param_name = None if not param.name else param.name.rstrip("=") - return user_param_name_or_alias in [param_name, param.alias] - - def resolved_parameters( - self, command_name: str, user_params: List[str] - ) -> List[Tuple[Optional[str], "Param"]]: - """Return list of parameters with values.""" - result: List[Tuple[Optional[str], "Param"]] = [] - command = self.commands.get(command_name) - if not command: - return result - - for param in command.params: - value = self._get_user_param_value(user_params, param) - if not value: - value = param.default_value - result.append((value, param)) - - return result - - def build_command( - self, - command_name: str, - user_params: List[str], - param_resolver: Callable[[str, str, List[Tuple[Optional[str], "Param"]]], str], - ) -> List[str]: - """ - Return a list of executable command strings. - - Given a command and associated parameters, returns a list of executable command - strings. 
- """ - command = self.commands.get(command_name) - if not command: - raise ConfigurationException( - "Command '{}' could not be found.".format(command_name) - ) - - commands_to_run = [] - - params_values = self.resolved_parameters(command_name, user_params) - for cmd_str in command.command_strings: - cmd_str = resolve_all_parameters( - cmd_str, param_resolver, command_name, params_values - ) - commands_to_run.append(cmd_str) - - return commands_to_run - - -class Param: - """Class for representing a generic application parameter.""" - - def __init__( # pylint: disable=too-many-arguments - self, - name: Optional[str], - description: str, - values: Optional[List[str]] = None, - default_value: Optional[str] = None, - alias: Optional[str] = None, - ) -> None: - """Construct a Param instance.""" - if not name and not alias: - raise ConfigurationException( - "Either name, alias or both must be set to identify a parameter." - ) - self.name = name - self.values = values - self.description = description - self.default_value = default_value - self.alias = alias - - def get_details(self) -> Dict: - """Return a dictionary with all relevant information of a Param.""" - return {key: value for key, value in self.__dict__.items() if value} - - def __eq__(self, other: object) -> bool: - """Overload operator ==.""" - if not isinstance(other, Param): - return False - - return ( - self.name == other.name - and self.values == other.values - and self.default_value == other.default_value - and self.description == other.description - ) - - -class Command: - """Class for representing a command.""" - - def __init__( - self, command_strings: List[str], params: Optional[List[Param]] = None - ) -> None: - """Construct a Command instance.""" - self.command_strings = command_strings - - if params: - self.params = params - else: - self.params = [] - - self._validate() - - def _validate(self) -> None: - """Validate command.""" - if not self.params: - return - - aliases = [param.alias for param in self.params if param.alias is not None] - repeated_aliases = [ - alias for alias, count in Counter(aliases).items() if count > 1 - ] - - if repeated_aliases: - raise ConfigurationException( - "Non unique aliases {}".format(", ".join(repeated_aliases)) - ) - - both_name_and_alias = [ - param.name - for param in self.params - if param.name in aliases and param.name != param.alias - ] - if both_name_and_alias: - raise ConfigurationException( - "Aliases {} could not be used as parameter name".format( - ", ".join(both_name_and_alias) - ) - ) - - def get_details(self) -> Dict: - """Return a dictionary with all relevant information of a Command.""" - output = { - "command_strings": self.command_strings, - "user_params": [param.get_details() for param in self.params], - } - return output - - def __eq__(self, other: object) -> bool: - """Overload operator ==.""" - if not isinstance(other, Command): - return False - - return ( - self.command_strings == other.command_strings - and self.params == other.params - ) - - -def resolve_all_parameters( - str_val: str, - param_resolver: Callable[[str, str, List[Tuple[Optional[str], Param]]], str], - command_name: Optional[str] = None, - params_values: Optional[List[Tuple[Optional[str], Param]]] = None, -) -> str: - """Resolve all parameters in the string.""" - if not str_val: - return str_val - - param_pattern: Final[Pattern] = re.compile(r"{(?P[\w.:]+)}") - while param_pattern.findall(str_val): - str_val = param_pattern.sub( - lambda m: param_resolver( - m["param_name"], command_name or "", 
params_values or [] - ), - str_val, - ) - return str_val - - -def load_application_or_tool_configs( - config: Any, - config_type: Type[Any], - is_system_required: bool = True, -) -> Any: - """Get one config for each system supported by the application/tool. - - The configuration could contain different parameters/commands for different - supported systems. For each supported system this function will return separate - config with appropriate configuration. - """ - merged_configs = [] - supported_systems: Optional[List[NamedExecutionConfig]] = config.get( - "supported_systems" - ) - if not supported_systems: - if is_system_required: - raise ConfigurationException("No supported systems definition provided") - # Create an empty system to be used in the parsing below - supported_systems = [cast(NamedExecutionConfig, {})] - - default_user_params = config.get("user_params", {}) - - def merge_config(system: NamedExecutionConfig) -> Any: - system_name = system.get("name") - if not system_name and is_system_required: - raise ConfigurationException( - "Unable to read supported system definition, name is missed" - ) - - merged_config = config_type(**config) - merged_config["supported_systems"] = [system_name] if system_name else [] - # merge default configuration and specific to the system - merged_config["commands"] = { - **config.get("commands", {}), - **system.get("commands", {}), - } - - params = {} - tool_user_params = system.get("user_params", {}) - command_names = tool_user_params.keys() | default_user_params.keys() - for command_name in command_names: - if command_name not in merged_config["commands"]: - continue - - params_default = default_user_params.get(command_name, []) - params_tool = tool_user_params.get(command_name, []) - if not params_default or not params_tool: - params[command_name] = params_tool or params_default - if params_default and params_tool: - if any(not p.get("alias") for p in params_default): - raise ConfigurationException( - "Default parameters for command {} should have aliases".format( - command_name - ) - ) - if any(not p.get("alias") for p in params_tool): - raise ConfigurationException( - "{} parameters for command {} should have aliases".format( - system_name, command_name - ) - ) - - merged_by_alias = { - **{p.get("alias"): p for p in params_default}, - **{p.get("alias"): p for p in params_tool}, - } - params[command_name] = list(merged_by_alias.values()) - - merged_config["user_params"] = params - merged_config["build_dir"] = system.get("build_dir", config.get("build_dir")) - merged_config["lock"] = system.get("lock", config.get("lock", False)) - merged_config["variables"] = { - **config.get("variables", {}), - **system.get("variables", {}), - } - return merged_config - - merged_configs = [merge_config(system) for system in supported_systems] - - return merged_configs diff --git a/src/aiet/backend/config.py b/src/aiet/backend/config.py deleted file mode 100644 index dd42012..0000000 --- a/src/aiet/backend/config.py +++ /dev/null @@ -1,107 +0,0 @@ -# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. 
-# SPDX-License-Identifier: Apache-2.0 -"""Contain definition of backend configuration.""" -from pathlib import Path -from typing import Dict -from typing import List -from typing import Literal -from typing import Optional -from typing import Tuple -from typing import TypedDict -from typing import Union - - -class UserParamConfig(TypedDict, total=False): - """User parameter configuration.""" - - name: Optional[str] - default_value: str - values: List[str] - description: str - alias: str - - -UserParamsConfig = Dict[str, List[UserParamConfig]] - - -class ExecutionConfig(TypedDict, total=False): - """Execution configuration.""" - - commands: Dict[str, List[str]] - user_params: UserParamsConfig - build_dir: str - variables: Dict[str, str] - lock: bool - - -class NamedExecutionConfig(ExecutionConfig): - """Execution configuration with name.""" - - name: str - - -class BaseBackendConfig(ExecutionConfig, total=False): - """Base backend configuration.""" - - name: str - description: str - config_location: Path - annotations: Dict[str, Union[str, List[str]]] - - -class ApplicationConfig(BaseBackendConfig, total=False): - """Application configuration.""" - - supported_systems: List[str] - deploy_data: List[Tuple[str, str]] - - -class ExtendedApplicationConfig(BaseBackendConfig, total=False): - """Extended application configuration.""" - - supported_systems: List[NamedExecutionConfig] - deploy_data: List[Tuple[str, str]] - - -class ProtocolConfig(TypedDict, total=False): - """Protocol config.""" - - protocol: Literal["local", "ssh"] - - -class SSHConfig(ProtocolConfig, total=False): - """SSH configuration.""" - - username: str - password: str - hostname: str - port: str - - -class LocalProtocolConfig(ProtocolConfig, total=False): - """Local protocol config.""" - - -class SystemConfig(BaseBackendConfig, total=False): - """System configuration.""" - - data_transfer: Union[SSHConfig, LocalProtocolConfig] - reporting: Dict[str, Dict] - - -class ToolConfig(BaseBackendConfig, total=False): - """Tool configuration.""" - - supported_systems: List[str] - - -class ExtendedToolConfig(BaseBackendConfig, total=False): - """Extended tool configuration.""" - - supported_systems: List[NamedExecutionConfig] - - -BackendItemConfig = Union[ApplicationConfig, SystemConfig, ToolConfig] -BackendConfig = Union[ - List[ExtendedApplicationConfig], List[SystemConfig], List[ToolConfig] -] diff --git a/src/aiet/backend/controller.py b/src/aiet/backend/controller.py deleted file mode 100644 index 2650902..0000000 --- a/src/aiet/backend/controller.py +++ /dev/null @@ -1,134 +0,0 @@ -# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. 
-# SPDX-License-Identifier: Apache-2.0 -"""Controller backend module.""" -import time -from pathlib import Path -from typing import List -from typing import Optional -from typing import Tuple - -import psutil -import sh - -from aiet.backend.common import ConfigurationException -from aiet.utils.fs import read_file_as_string -from aiet.utils.proc import execute_command -from aiet.utils.proc import get_stdout_stderr_paths -from aiet.utils.proc import read_process_info -from aiet.utils.proc import save_process_info -from aiet.utils.proc import terminate_command -from aiet.utils.proc import terminate_external_process - - -class SystemController: - """System controller class.""" - - def __init__(self) -> None: - """Create new instance of service controller.""" - self.cmd: Optional[sh.RunningCommand] = None - self.out_path: Optional[Path] = None - self.err_path: Optional[Path] = None - - def before_start(self) -> None: - """Run actions before system start.""" - - def after_start(self) -> None: - """Run actions after system start.""" - - def start(self, commands: List[str], cwd: Path) -> None: - """Start system.""" - if not isinstance(cwd, Path) or not cwd.is_dir(): - raise ConfigurationException("Wrong working directory {}".format(cwd)) - - if len(commands) != 1: - raise ConfigurationException("System should have only one command to run") - - startup_command = commands[0] - if not startup_command: - raise ConfigurationException("No startup command provided") - - self.before_start() - - self.out_path, self.err_path = get_stdout_stderr_paths(startup_command) - - self.cmd = execute_command( - startup_command, - cwd, - bg=True, - out=str(self.out_path), - err=str(self.err_path), - ) - - self.after_start() - - def stop( - self, wait: bool = False, wait_period: float = 0.5, number_of_attempts: int = 20 - ) -> None: - """Stop system.""" - if self.cmd is not None and self.is_running(): - terminate_command(self.cmd, wait, wait_period, number_of_attempts) - - def is_running(self) -> bool: - """Check if underlying process is running.""" - return self.cmd is not None and self.cmd.is_alive() - - def get_output(self) -> Tuple[str, str]: - """Return application output.""" - if self.cmd is None or self.out_path is None or self.err_path is None: - return ("", "") - - return (read_file_as_string(self.out_path), read_file_as_string(self.err_path)) - - -class SystemControllerSingleInstance(SystemController): - """System controller with support of system's single instance.""" - - def __init__(self, pid_file_path: Optional[Path] = None) -> None: - """Create new instance of the service controller.""" - super().__init__() - self.pid_file_path = pid_file_path - - def before_start(self) -> None: - """Run actions before system start.""" - self._check_if_previous_instance_is_running() - - def after_start(self) -> None: - """Run actions after system start.""" - self._save_process_info() - - def _check_if_previous_instance_is_running(self) -> None: - """Check if another instance of the system is running.""" - process_info = read_process_info(self._pid_file()) - - for item in process_info: - try: - process = psutil.Process(item.pid) - same_process = ( - process.name() == item.name - and process.exe() == item.executable - and process.cwd() == item.cwd - ) - if same_process: - print( - "Stopping previous instance of the system [{}]".format(item.pid) - ) - terminate_external_process(process) - except psutil.NoSuchProcess: - pass - - def _save_process_info(self, wait_period: float = 2) -> None: - """Save information about system's 
processes.""" - if self.cmd is None or not self.is_running(): - return - - # give some time for the system to start - time.sleep(wait_period) - - save_process_info(self.cmd.process.pid, self._pid_file()) - - def _pid_file(self) -> Path: - """Return path to file which is used for saving process info.""" - if not self.pid_file_path: - raise Exception("No pid file path presented") - - return self.pid_file_path diff --git a/src/aiet/backend/execution.py b/src/aiet/backend/execution.py deleted file mode 100644 index 1653ee2..0000000 --- a/src/aiet/backend/execution.py +++ /dev/null @@ -1,859 +0,0 @@ -# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. -# SPDX-License-Identifier: Apache-2.0 -"""Application execution module.""" -import itertools -import json -import random -import re -import string -import sys -import time -import warnings -from collections import defaultdict -from contextlib import contextmanager -from contextlib import ExitStack -from pathlib import Path -from typing import Any -from typing import Callable -from typing import cast -from typing import ContextManager -from typing import Dict -from typing import Generator -from typing import List -from typing import Optional -from typing import Sequence -from typing import Tuple -from typing import TypedDict -from typing import Union - -from filelock import FileLock -from filelock import Timeout - -from aiet.backend.application import Application -from aiet.backend.application import get_application -from aiet.backend.common import Backend -from aiet.backend.common import ConfigurationException -from aiet.backend.common import DataPaths -from aiet.backend.common import Param -from aiet.backend.common import parse_raw_parameter -from aiet.backend.common import resolve_all_parameters -from aiet.backend.output_parser import Base64OutputParser -from aiet.backend.output_parser import OutputParser -from aiet.backend.output_parser import RegexOutputParser -from aiet.backend.system import ControlledSystem -from aiet.backend.system import get_system -from aiet.backend.system import StandaloneSystem -from aiet.backend.system import System -from aiet.backend.tool import get_tool -from aiet.backend.tool import Tool -from aiet.utils.fs import recreate_directory -from aiet.utils.fs import remove_directory -from aiet.utils.fs import valid_for_filename -from aiet.utils.proc import run_and_wait - - -class AnotherInstanceIsRunningException(Exception): - """Concurrent execution error.""" - - -class ConnectionException(Exception): - """Connection exception.""" - - -class ExecutionParams(TypedDict, total=False): - """Execution parameters.""" - - disable_locking: bool - unique_build_dir: bool - - -class ExecutionContext: - """Command execution context.""" - - # pylint: disable=too-many-arguments,too-many-instance-attributes - def __init__( - self, - app: Union[Application, Tool], - app_params: List[str], - system: Optional[System], - system_params: List[str], - custom_deploy_data: Optional[List[DataPaths]] = None, - execution_params: Optional[ExecutionParams] = None, - report_file: Optional[Path] = None, - ): - """Init execution context.""" - self.app = app - self.app_params = app_params - self.custom_deploy_data = custom_deploy_data or [] - self.system = system - self.system_params = system_params - self.execution_params = execution_params or ExecutionParams() - self.report_file = report_file - - self.reporter: Optional[Reporter] - if self.report_file: - # Create reporter with output parsers - parsers: List[OutputParser] = [] 
- if system and system.reporting: - # Add RegexOutputParser, if it is configured in the system - parsers.append(RegexOutputParser("system", system.reporting["regex"])) - # Add Base64 parser for applications - parsers.append(Base64OutputParser("application")) - self.reporter = Reporter(parsers=parsers) - else: - self.reporter = None # No reporter needed. - - self.param_resolver = ParamResolver(self) - self._resolved_build_dir: Optional[Path] = None - - @property - def is_deploy_needed(self) -> bool: - """Check if application requires data deployment.""" - if isinstance(self.app, Application): - return ( - len(self.app.get_deploy_data()) > 0 or len(self.custom_deploy_data) > 0 - ) - return False - - @property - def is_locking_required(self) -> bool: - """Return true if any form of locking required.""" - return not self._disable_locking() and ( - self.app.lock or (self.system is not None and self.system.lock) - ) - - @property - def is_build_required(self) -> bool: - """Return true if application build required.""" - return "build" in self.app.commands - - @property - def is_unique_build_dir_required(self) -> bool: - """Return true if unique build dir required.""" - return self.execution_params.get("unique_build_dir", False) - - def build_dir(self) -> Path: - """Return resolved application build dir.""" - if self._resolved_build_dir is not None: - return self._resolved_build_dir - - if ( - not isinstance(self.app.config_location, Path) - or not self.app.config_location.is_dir() - ): - raise ConfigurationException( - "Application {} has wrong config location".format(self.app.name) - ) - - _build_dir = self.app.build_dir - if _build_dir: - _build_dir = resolve_all_parameters(_build_dir, self.param_resolver) - - if not _build_dir: - raise ConfigurationException( - "No build directory defined for the app {}".format(self.app.name) - ) - - if self.is_unique_build_dir_required: - random_suffix = "".join( - random.choices(string.ascii_lowercase + string.digits, k=7) - ) - _build_dir = "{}_{}".format(_build_dir, random_suffix) - - self._resolved_build_dir = self.app.config_location / _build_dir - return self._resolved_build_dir - - def _disable_locking(self) -> bool: - """Return true if locking should be disabled.""" - return self.execution_params.get("disable_locking", False) - - -class ParamResolver: - """Parameter resolver.""" - - def __init__(self, context: ExecutionContext): - """Init parameter resolver.""" - self.ctx = context - - @staticmethod - def resolve_user_params( - cmd_name: Optional[str], - index_or_alias: str, - resolved_params: Optional[List[Tuple[Optional[str], Param]]], - ) -> str: - """Resolve user params.""" - if not cmd_name or resolved_params is None: - raise ConfigurationException("Unable to resolve user params") - - param_value: Optional[str] = None - param: Optional[Param] = None - - if index_or_alias.isnumeric(): - i = int(index_or_alias) - if i not in range(len(resolved_params)): - raise ConfigurationException( - "Invalid index {} for user params of command {}".format(i, cmd_name) - ) - param_value, param = resolved_params[i] - else: - for val, par in resolved_params: - if par.alias == index_or_alias: - param_value, param = val, par - break - - if param is None: - raise ConfigurationException( - "No user parameter for command '{}' with alias '{}'.".format( - cmd_name, index_or_alias - ) - ) - - if param_value: - # We need to handle to cases of parameters here: - # 1) Optional parameters (non-positional with a name and value) - # 2) Positional parameters (value only, no name 
needed) - # Default to empty strings for positional arguments - param_name = "" - separator = "" - if param.name is not None: - # A valid param name means we have an optional/non-positional argument: - # The separator is an empty string in case the param_name - # has an equal sign as we have to honour it. - # If the parameter doesn't end with an equal sign then a - # space character is injected to split the parameter name - # and its value - param_name = param.name - separator = "" if param.name.endswith("=") else " " - - return "{param_name}{separator}{param_value}".format( - param_name=param_name, - separator=separator, - param_value=param_value, - ) - - if param.name is None: - raise ConfigurationException( - "Missing user parameter with alias '{}' for command '{}'.".format( - index_or_alias, cmd_name - ) - ) - - return param.name # flag: just return the parameter name - - def resolve_commands_and_params( - self, backend_type: str, cmd_name: str, return_params: bool, index_or_alias: str - ) -> str: - """Resolve command or command's param value.""" - if backend_type == "system": - backend = cast(Backend, self.ctx.system) - backend_params = self.ctx.system_params - else: # Application or Tool backend - backend = cast(Backend, self.ctx.app) - backend_params = self.ctx.app_params - - if cmd_name not in backend.commands: - raise ConfigurationException("Command {} not found".format(cmd_name)) - - if return_params: - params = backend.resolved_parameters(cmd_name, backend_params) - if index_or_alias.isnumeric(): - i = int(index_or_alias) - if i not in range(len(params)): - raise ConfigurationException( - "Invalid parameter index {} for command {}".format(i, cmd_name) - ) - - param_value = params[i][0] - else: - param_value = None - for value, param in params: - if param.alias == index_or_alias: - param_value = value - break - - if not param_value: - raise ConfigurationException( - ( - "No value for parameter with index or alias {} of command {}" - ).format(index_or_alias, cmd_name) - ) - return param_value - - if not index_or_alias.isnumeric(): - raise ConfigurationException("Bad command index {}".format(index_or_alias)) - - i = int(index_or_alias) - commands = backend.build_command(cmd_name, backend_params, self.param_resolver) - if i not in range(len(commands)): - raise ConfigurationException( - "Invalid index {} for command {}".format(i, cmd_name) - ) - - return commands[i] - - def resolve_variables(self, backend_type: str, var_name: str) -> str: - """Resolve variable value.""" - if backend_type == "system": - backend = cast(Backend, self.ctx.system) - else: # Application or Tool backend - backend = cast(Backend, self.ctx.app) - - if var_name not in backend.variables: - raise ConfigurationException("Unknown variable {}".format(var_name)) - - return backend.variables[var_name] - - def param_matcher( - self, - param_name: str, - cmd_name: Optional[str], - resolved_params: Optional[List[Tuple[Optional[str], Param]]], - ) -> str: - """Regexp to resolve a param from the param_name.""" - # this pattern supports parameter names like "application.commands.run:0" and - # "system.commands.run.params:0" - # Note: 'software' is included for backward compatibility. 
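For illustration, a minimal standalone sketch of the parameter references this resolver understands; the pattern is the same one matched below, with its named groups written out, and the reference string is an example only:

```python
import re

# Same pattern as used by param_matcher, shown standalone for illustration.
PATTERN = re.compile(
    r"(?P<type>application|software|tool|system)[.]commands[.]"
    r"(?P<name>\w+)"
    r"(?P<params>[.]params|)[:]"
    r"(?P<index_or_alias>\w+)"
)

match = PATTERN.match("system.commands.run.params:0")
assert match is not None
# backend type "system", command "run", a parameter lookup, index/alias "0"
print(match["type"], match["name"], bool(match["params"]), match["index_or_alias"])
```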
- commands_and_params_match = re.match( - r"(?P<type>application|software|tool|system)[.]commands[.]" - r"(?P<name>\w+)" - r"(?P<params>[.]params|)[:]" - r"(?P<index_or_alias>\w+)", - param_name, - ) - - if commands_and_params_match: - backend_type, cmd_name, return_params, index_or_alias = ( - commands_and_params_match["type"], - commands_and_params_match["name"], - commands_and_params_match["params"], - commands_and_params_match["index_or_alias"], - ) - return self.resolve_commands_and_params( - backend_type, cmd_name, bool(return_params), index_or_alias - ) - - # Note: 'software' is included for backward compatibility. - variables_match = re.match( - r"(?P<type>application|software|tool|system)[.]variables:(?P<var_name>\w+)", - param_name, - ) - if variables_match: - backend_type, var_name = ( - variables_match["type"], - variables_match["var_name"], - ) - return self.resolve_variables(backend_type, var_name) - - user_params_match = re.match(r"user_params:(?P<index_or_alias>\w+)", param_name) - if user_params_match: - index_or_alias = user_params_match["index_or_alias"] - return self.resolve_user_params(cmd_name, index_or_alias, resolved_params) - - raise ConfigurationException( - "Unable to resolve parameter {}".format(param_name) - ) - - def param_resolver( - self, - param_name: str, - cmd_name: Optional[str] = None, - resolved_params: Optional[List[Tuple[Optional[str], Param]]] = None, - ) -> str: - """Resolve parameter value based on current execution context.""" - # Note: 'software.*' is included for backward compatibility. - resolved_param = None - if param_name in ["application.name", "tool.name", "software.name"]: - resolved_param = self.ctx.app.name - elif param_name in [ - "application.description", - "tool.description", - "software.description", - ]: - resolved_param = self.ctx.app.description - elif self.ctx.app.config_location and ( - param_name - in ["application.config_dir", "tool.config_dir", "software.config_dir"] - ): - resolved_param = str(self.ctx.app.config_location.absolute()) - elif self.ctx.app.build_dir and ( - param_name - in ["application.build_dir", "tool.build_dir", "software.build_dir"] - ): - resolved_param = str(self.ctx.build_dir().absolute()) - elif self.ctx.system is not None: - if param_name == "system.name": - resolved_param = self.ctx.system.name - elif param_name == "system.description": - resolved_param = self.ctx.system.description - elif param_name == "system.config_dir" and self.ctx.system.config_location: - resolved_param = str(self.ctx.system.config_location.absolute()) - - if not resolved_param: - resolved_param = self.param_matcher(param_name, cmd_name, resolved_params) - return resolved_param - - def __call__( - self, - param_name: str, - cmd_name: Optional[str] = None, - resolved_params: Optional[List[Tuple[Optional[str], Param]]] = None, - ) -> str: - """Resolve provided parameter.""" - return self.param_resolver(param_name, cmd_name, resolved_params) - - -class Reporter: - """Report metrics from the simulation output.""" - - def __init__(self, parsers: Optional[List[OutputParser]] = None) -> None: - """Create an empty reporter (i.e.
no parsers registered).""" - self.parsers: List[OutputParser] = parsers if parsers is not None else [] - self._report: Dict[str, Any] = defaultdict(lambda: defaultdict(dict)) - - def parse(self, output: bytearray) -> None: - """Parse output and append parsed metrics to internal report dict.""" - for parser in self.parsers: - # Merge metrics from different parsers (do not overwrite) - self._report[parser.name]["metrics"].update(parser(output)) - - def get_filtered_output(self, output: bytearray) -> bytearray: - """Filter the output according to each parser.""" - for parser in self.parsers: - output = parser.filter_out_parsed_content(output) - return output - - def report(self, ctx: ExecutionContext) -> Dict[str, Any]: - """Add static simulation info to parsed data and return the report.""" - report: Dict[str, Any] = defaultdict(dict) - # Add static simulation info - report.update(self._static_info(ctx)) - # Add metrics parsed from the output - for key, val in self._report.items(): - report[key].update(val) - return report - - @staticmethod - def save(report: Dict[str, Any], report_file: Path) -> None: - """Save the report to a JSON file.""" - with open(report_file, "w", encoding="utf-8") as file: - json.dump(report, file, indent=4) - - @staticmethod - def _compute_all_params(cli_params: List[str], backend: Backend) -> Dict[str, str]: - """ - Build a dict of all parameters, {name:value}. - - Param values taken from command line if specified, defaults otherwise. - """ - # map of params passed from the cli ["p1=v1","p2=v2"] -> {"p1":"v1", "p2":"v2"} - app_params_map = dict(parse_raw_parameter(expr) for expr in cli_params) - - # a map of params declared in the application, with values taken from the CLI, - # defaults otherwise - all_params = { - (p.alias or p.name): app_params_map.get( - cast(str, p.name), cast(str, p.default_value) - ) - for cmd in backend.commands.values() - for p in cmd.params - } - return cast(Dict[str, str], all_params) - - @staticmethod - def _static_info(ctx: ExecutionContext) -> Dict[str, Any]: - """Extract static simulation information from the context.""" - if ctx.system is None: - raise ValueError("No system available to report.") - - info = { - "system": { - "name": ctx.system.name, - "params": Reporter._compute_all_params(ctx.system_params, ctx.system), - }, - "application": { - "name": ctx.app.name, - "params": Reporter._compute_all_params(ctx.app_params, ctx.app), - }, - } - return info - - -def validate_parameters( - backend: Backend, command_names: List[str], params: List[str] -) -> None: - """Check parameters passed to backend.""" - for param in params: - acceptable = any( - backend.validate_parameter(command_name, param) - for command_name in command_names - if command_name in backend.commands - ) - - if not acceptable: - backend_type = "System" if isinstance(backend, System) else "Application" - raise ValueError( - "{} parameter '{}' not valid for command '{}'".format( - backend_type, param, " or ".join(command_names) - ) - ) - - -def get_application_by_name_and_system( - application_name: str, system_name: str -) -> Application: - """Get application.""" - applications = get_application(application_name, system_name) - if not applications: - raise ValueError( - "Application '{}' doesn't support the system '{}'".format( - application_name, system_name - ) - ) - - if len(applications) != 1: - raise ValueError( - "Error during getting application {} for the system {}".format( - application_name, system_name - ) - ) - - return applications[0] - - -def 
get_application_and_system( - application_name: str, system_name: str -) -> Tuple[Application, System]: - """Return application and system by provided names.""" - system = get_system(system_name) - if not system: - raise ValueError("System {} is not found".format(system_name)) - - application = get_application_by_name_and_system(application_name, system_name) - - return application, system - - -def execute_application_command( # pylint: disable=too-many-arguments - command_name: str, - application_name: str, - application_params: List[str], - system_name: str, - system_params: List[str], - custom_deploy_data: List[DataPaths], -) -> None: - """Execute application command. - - .. deprecated:: 21.12 - """ - warnings.warn( - "Use 'run_application()' instead. Use of 'execute_application_command()' is " - "deprecated and might be removed in a future release.", - DeprecationWarning, - ) - - if command_name not in ["build", "run"]: - raise ConfigurationException("Unsupported command {}".format(command_name)) - - application, system = get_application_and_system(application_name, system_name) - validate_parameters(application, [command_name], application_params) - validate_parameters(system, [command_name], system_params) - - ctx = ExecutionContext( - app=application, - app_params=application_params, - system=system, - system_params=system_params, - custom_deploy_data=custom_deploy_data, - ) - - if command_name == "run": - execute_application_command_run(ctx) - else: - execute_application_command_build(ctx) - - -# pylint: disable=too-many-arguments -def run_application( - application_name: str, - application_params: List[str], - system_name: str, - system_params: List[str], - custom_deploy_data: List[DataPaths], - report_file: Optional[Path] = None, -) -> None: - """Run application on the provided system.""" - application, system = get_application_and_system(application_name, system_name) - validate_parameters(application, ["build", "run"], application_params) - validate_parameters(system, ["build", "run"], system_params) - - execution_params = ExecutionParams() - if isinstance(system, StandaloneSystem): - execution_params["disable_locking"] = True - execution_params["unique_build_dir"] = True - - ctx = ExecutionContext( - app=application, - app_params=application_params, - system=system, - system_params=system_params, - custom_deploy_data=custom_deploy_data, - execution_params=execution_params, - report_file=report_file, - ) - - with build_dir_manager(ctx): - if ctx.is_build_required: - execute_application_command_build(ctx) - - execute_application_command_run(ctx) - - -def execute_application_command_build(ctx: ExecutionContext) -> None: - """Execute application command 'build'.""" - with ExitStack() as context_stack: - for manager in get_context_managers("build", ctx): - context_stack.enter_context(manager(ctx)) - - build_dir = ctx.build_dir() - recreate_directory(build_dir) - - build_commands = ctx.app.build_command( - "build", ctx.app_params, ctx.param_resolver - ) - execute_commands_locally(build_commands, build_dir) - - -def execute_commands_locally(commands: List[str], cwd: Path) -> None: - """Execute list of commands locally.""" - for command in commands: - print("Running: {}".format(command)) - run_and_wait( - command, cwd, terminate_on_error=True, out=sys.stdout, err=sys.stderr - ) - - -def execute_application_command_run(ctx: ExecutionContext) -> None: - """Execute application command.""" - assert ctx.system is not None, "System must be provided." 
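A minimal usage sketch of run_application(); the application and system names are hypothetical placeholders for installed backends, and parameters use the name=value form consumed by parse_raw_parameter:

```python
from pathlib import Path

from aiet.backend.execution import run_application

# Hypothetical backend names; both must already be installed as backends.
run_application(
    application_name="my_application",
    application_params=["param1=value1"],
    system_name="my_system",
    system_params=[],
    custom_deploy_data=[],
    report_file=Path("report.json"),  # optional: enables the Reporter output
)
```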
- if ctx.is_deploy_needed and not ctx.system.supports_deploy: - raise ConfigurationException( - "System {} does not support data deploy".format(ctx.system.name) - ) - - with ExitStack() as context_stack: - for manager in get_context_managers("run", ctx): - context_stack.enter_context(manager(ctx)) - - print("Generating commands to execute") - commands_to_run = build_run_commands(ctx) - - if ctx.system.connectable: - establish_connection(ctx) - - if ctx.system.supports_deploy: - deploy_data(ctx) - - for command in commands_to_run: - print("Running: {}".format(command)) - exit_code, std_output, std_err = ctx.system.run(command) - - if exit_code != 0: - print("Application exited with exit code {}".format(exit_code)) - - if ctx.reporter: - ctx.reporter.parse(std_output) - std_output = ctx.reporter.get_filtered_output(std_output) - - print(std_output.decode("utf8"), end="") - print(std_err.decode("utf8"), end="") - - if ctx.reporter: - report = ctx.reporter.report(ctx) - ctx.reporter.save(report, cast(Path, ctx.report_file)) - - -def establish_connection( - ctx: ExecutionContext, retries: int = 90, interval: float = 15.0 -) -> None: - """Establish connection with the system.""" - assert ctx.system is not None, "System is required." - host, port = ctx.system.connection_details() - print( - "Trying to establish connection with '{}:{}' - " - "{} retries every {} seconds ".format(host, port, retries, interval), - end="", - ) - - try: - for _ in range(retries): - print(".", end="", flush=True) - - if ctx.system.establish_connection(): - break - - if isinstance(ctx.system, ControlledSystem) and not ctx.system.is_running(): - print( - "\n\n---------- {} execution failed ----------".format( - ctx.system.name - ) - ) - stdout, stderr = ctx.system.get_output() - print(stdout) - print(stderr) - - raise Exception("System is not running") - - wait(interval) - else: - raise ConnectionException("Couldn't connect to '{}:{}'.".format(host, port)) - finally: - print() - - -def wait(interval: float) -> None: - """Wait for a period of time.""" - time.sleep(interval) - - -def deploy_data(ctx: ExecutionContext) -> None: - """Deploy data to the system.""" - if isinstance(ctx.app, Application): - # Only application can deploy data (tools can not) - assert ctx.system is not None, "System is required." - for item in itertools.chain(ctx.app.get_deploy_data(), ctx.custom_deploy_data): - print("Deploying {} onto {}".format(item.src, item.dst)) - ctx.system.deploy(item.src, item.dst) - - -def build_run_commands(ctx: ExecutionContext) -> List[str]: - """Build commands to run application.""" - if isinstance(ctx.system, StandaloneSystem): - return ctx.system.build_command("run", ctx.system_params, ctx.param_resolver) - - return ctx.app.build_command("run", ctx.app_params, ctx.param_resolver) - - -@contextmanager -def controlled_system_manager(ctx: ExecutionContext) -> Generator[None, None, None]: - """Context manager used for system initialisation before run.""" - system = cast(ControlledSystem, ctx.system) - commands = system.build_command("run", ctx.system_params, ctx.param_resolver) - pid_file_path: Optional[Path] = None - if ctx.is_locking_required: - file_lock_path = get_file_lock_path(ctx) - pid_file_path = file_lock_path.parent / "{}.pid".format(file_lock_path.stem) - - system.start(commands, ctx.is_locking_required, pid_file_path) - try: - yield - finally: - print("Shutting down sequence...") - print("Stopping {}... 
(It could take a few seconds)".format(system.name)) - system.stop(wait=True) - print("{} stopped successfully.".format(system.name)) - - -@contextmanager -def lock_execution_manager(ctx: ExecutionContext) -> Generator[None, None, None]: - """Lock execution manager.""" - file_lock_path = get_file_lock_path(ctx) - file_lock = FileLock(str(file_lock_path)) - - try: - file_lock.acquire(timeout=1) - except Timeout as error: - raise AnotherInstanceIsRunningException() from error - - try: - yield - finally: - file_lock.release() - - -def get_file_lock_path(ctx: ExecutionContext, lock_dir: Path = Path("/tmp")) -> Path: - """Get file lock path.""" - lock_modules = [] - if ctx.app.lock: - lock_modules.append(ctx.app.name) - if ctx.system is not None and ctx.system.lock: - lock_modules.append(ctx.system.name) - lock_filename = "" - if lock_modules: - lock_filename = "_".join(["middleware"] + lock_modules) + ".lock" - - if lock_filename: - lock_filename = resolve_all_parameters(lock_filename, ctx.param_resolver) - lock_filename = valid_for_filename(lock_filename) - - if not lock_filename: - raise ConfigurationException("No filename for lock provided") - - if not isinstance(lock_dir, Path) or not lock_dir.is_dir(): - raise ConfigurationException( - "Invalid directory {} for lock files provided".format(lock_dir) - ) - - return lock_dir / lock_filename - - -@contextmanager -def build_dir_manager(ctx: ExecutionContext) -> Generator[None, None, None]: - """Build directory manager.""" - try: - yield - finally: - if ( - ctx.is_build_required - and ctx.is_unique_build_dir_required - and ctx.build_dir().is_dir() - ): - remove_directory(ctx.build_dir()) - - -def get_context_managers( - command_name: str, ctx: ExecutionContext -) -> Sequence[Callable[[ExecutionContext], ContextManager[None]]]: - """Get context managers for the system.""" - managers = [] - - if ctx.is_locking_required: - managers.append(lock_execution_manager) - - if command_name == "run": - if isinstance(ctx.system, ControlledSystem): - managers.append(controlled_system_manager) - - return managers - - -def get_tool_by_system(tool_name: str, system_name: Optional[str]) -> Tool: - """Return tool (optionally by provided system name).""" - tools = get_tool(tool_name, system_name) - if not tools: - raise ConfigurationException( - "Tool '{}' not found or doesn't support the system '{}'".format( - tool_name, system_name - ) - ) - if len(tools) != 1: - raise ConfigurationException( - "Please specify the system for tool {}.".format(tool_name) - ) - tool = tools[0] - - return tool - - -def execute_tool_command( - tool_name: str, - tool_params: List[str], - system_name: Optional[str] = None, -) -> None: - """Execute the tool command locally calling the 'run' command.""" - tool = get_tool_by_system(tool_name, system_name) - ctx = ExecutionContext( - app=tool, app_params=tool_params, system=None, system_params=[] - ) - commands = tool.build_command("run", tool_params, ctx.param_resolver) - - execute_commands_locally(commands, Path.cwd()) diff --git a/src/aiet/backend/output_parser.py b/src/aiet/backend/output_parser.py deleted file mode 100644 index 111772a..0000000 --- a/src/aiet/backend/output_parser.py +++ /dev/null @@ -1,176 +0,0 @@ -# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
-# SPDX-License-Identifier: Apache-2.0 -"""Definition of output parsers (including base class OutputParser).""" -import base64 -import json -import re -from abc import ABC -from abc import abstractmethod -from typing import Any -from typing import Dict -from typing import Union - - -class OutputParser(ABC): - """Abstract base class for output parsers.""" - - def __init__(self, name: str) -> None: - """Set up the name of the parser.""" - super().__init__() - self.name = name - - @abstractmethod - def __call__(self, output: bytearray) -> Dict[str, Any]: - """Parse the output and return a map of names to metrics.""" - return {} - - # pylint: disable=no-self-use - def filter_out_parsed_content(self, output: bytearray) -> bytearray: - """ - Filter out the parsed content from the output. - - Does nothing by default. Can be overridden in subclasses. - """ - return output - - -class RegexOutputParser(OutputParser): - """Parser of standard output data using regular expressions.""" - - _TYPE_MAP = {"str": str, "float": float, "int": int} - - def __init__( - self, - name: str, - regex_config: Dict[str, Dict[str, str]], - ) -> None: - """ - Set up the parser with the regular expressions. - - The regex_config is mapping from a name to a dict with keys 'pattern' - and 'type': - - The 'pattern' holds the regular expression that must contain exactly - one capturing parenthesis - - The 'type' can be one of ['str', 'float', 'int']. - - Example: - ``` - {"Metric1": {"pattern": ".*= *(.*)", "type": "str"}} - ``` - - The different regular expressions from the config are combined using - non-capturing parenthesis, i.e. regular expressions must not overlap - if more than one match per line is expected. - """ - super().__init__(name) - - self._verify_config(regex_config) - self._regex_cfg = regex_config - - # Compile regular expression to match in the output - self._regex = re.compile( - "|".join("(?:{0})".format(x["pattern"]) for x in self._regex_cfg.values()) - ) - - def __call__(self, output: bytearray) -> Dict[str, Union[str, float, int]]: - """ - Parse the output and return a map of names to metrics. - - Example: - Assuming a regex_config as used as example in `__init__()` and the - following output: - ``` - Simulation finished: - SIMULATION_STATUS = SUCCESS - Simulation DONE - ``` - Then calling the parser should return the following dict: - ``` - { - "Metric1": "SUCCESS" - } - ``` - """ - metrics = {} - output_str = output.decode("utf-8") - results = self._regex.findall(output_str) - for line_result in results: - for idx, (name, cfg) in enumerate(self._regex_cfg.items()): - # The result(s) returned by findall() are either a single string - # or a tuple (depending on the number of groups etc.) - result = ( - line_result if isinstance(line_result, str) else line_result[idx] - ) - if result: - mapped_result = self._TYPE_MAP[cfg["type"]](result) - metrics[name] = mapped_result - return metrics - - def _verify_config(self, regex_config: Dict[str, Dict[str, str]]) -> None: - """Make sure we have a valid regex_config. - - I.e. - - Exactly one capturing parenthesis per pattern - - Correct types - """ - for name, cfg in regex_config.items(): - # Check that there is one capturing group defined in the pattern. - regex = re.compile(cfg["pattern"]) - if regex.groups != 1: - raise ValueError( - f"Pattern for metric '{name}' must have exactly one " - f"capturing parenthesis, but it has {regex.groups}." 
- ) - # Check if type is supported - if not cfg["type"] in self._TYPE_MAP: - raise TypeError( - f"Type '{cfg['type']}' for metric '{name}' is not " - f"supported. Choose from: {list(self._TYPE_MAP.keys())}." - ) - - -class Base64OutputParser(OutputParser): - """ - Parser to extract base64-encoded JSON from tagged standard output. - - Example of the tagged output: - ``` - # Encoded JSON: {"test": 1234} - <metrics>eyJ0ZXN0IjogMTIzNH0=</metrics> - ``` - """ - - TAG_NAME = "metrics" - - def __init__(self, name: str) -> None: - """Set up the regular expression to extract tagged strings.""" - super().__init__(name) - self._regex = re.compile(rf"<{self.TAG_NAME}>(.*)</{self.TAG_NAME}>") - - def __call__(self, output: bytearray) -> Dict[str, Any]: - """ - Parse the output and return a map of index (as string) to decoded JSON. - - Example: - Using the tagged output from the class docs the parser should return - the following dict: - ``` - { - "0": {"test": 1234} - } - ``` - """ - metrics = {} - output_str = output.decode("utf-8") - results = self._regex.findall(output_str) - for idx, result_base64 in enumerate(results): - result_json = base64.b64decode(result_base64, validate=True) - result = json.loads(result_json) - metrics[str(idx)] = result - - return metrics - - def filter_out_parsed_content(self, output: bytearray) -> bytearray: - """Filter out base64-encoded content from the output.""" - output_str = self._regex.sub("", output.decode("utf-8")) - return bytearray(output_str.encode("utf-8")) diff --git a/src/aiet/backend/protocol.py b/src/aiet/backend/protocol.py deleted file mode 100644 index c621436..0000000 --- a/src/aiet/backend/protocol.py +++ /dev/null @@ -1,325 +0,0 @@ -# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. -# SPDX-License-Identifier: Apache-2.0 -"""Contain protocol related classes and functions.""" -from abc import ABC -from abc import abstractmethod -from contextlib import closing -from pathlib import Path -from typing import Any -from typing import cast -from typing import Iterable -from typing import Optional -from typing import Tuple -from typing import Union - -import paramiko - -from aiet.backend.common import ConfigurationException -from aiet.backend.config import LocalProtocolConfig -from aiet.backend.config import SSHConfig -from aiet.utils.proc import run_and_wait - - -# Redirect all paramiko thread exceptions to a file otherwise these will be -# printed to stderr.
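A short, self-contained sketch of how these two parsers could be driven; the sample output and the metric name are made up for illustration (the base64 string encodes '{"test": 1234}'):

```python
from aiet.backend.output_parser import Base64OutputParser, RegexOutputParser

# Made-up simulator output: one regex-parsable line plus one tagged payload.
output = bytearray(
    b"SIMULATION_STATUS = SUCCESS\n"
    b"<metrics>eyJ0ZXN0IjogMTIzNH0=</metrics>\n"
)

regex_parser = RegexOutputParser(
    "system", {"status": {"pattern": r"SIMULATION_STATUS = (.*)", "type": "str"}}
)
base64_parser = Base64OutputParser("application")

print(regex_parser(output))    # {'status': 'SUCCESS'}
print(base64_parser(output))   # {'0': {'test': 1234}}
print(base64_parser.filter_out_parsed_content(output).decode("utf-8"))
```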
-paramiko.util.log_to_file("/tmp/main_paramiko_log.txt", level=paramiko.common.INFO) - - -class SSHConnectionException(Exception): - """SSH connection exception.""" - - -class SupportsClose(ABC): - """Class indicates support of close operation.""" - - @abstractmethod - def close(self) -> None: - """Close protocol session.""" - - -class SupportsDeploy(ABC): - """Class indicates support of deploy operation.""" - - @abstractmethod - def deploy(self, src: Path, dst: str, retry: bool = True) -> None: - """Abstract method for deploy data.""" - - -class SupportsConnection(ABC): - """Class indicates that protocol uses network connections.""" - - @abstractmethod - def establish_connection(self) -> bool: - """Establish connection with underlying system.""" - - @abstractmethod - def connection_details(self) -> Tuple[str, int]: - """Return connection details (host, port).""" - - -class Protocol(ABC): - """Abstract class for representing the protocol.""" - - def __init__(self, iterable: Iterable = (), **kwargs: Any) -> None: - """Initialize the class using a dict.""" - self.__dict__.update(iterable, **kwargs) - self._validate() - - @abstractmethod - def _validate(self) -> None: - """Abstract method for config data validation.""" - - @abstractmethod - def run( - self, command: str, retry: bool = False - ) -> Tuple[int, bytearray, bytearray]: - """ - Abstract method for running commands. - - Returns a tuple: (exit_code, stdout, stderr) - """ - - -class CustomSFTPClient(paramiko.SFTPClient): - """Class for creating a custom sftp client.""" - - def put_dir(self, source: Path, target: str) -> None: - """Upload the source directory to the target path. - - The target directory needs to exists and the last directory of the - source will be created under the target with all its content. - """ - # Create the target directory - self._mkdir(target, ignore_existing=True) - # Create the last directory in the source on the target - self._mkdir("{}/{}".format(target, source.name), ignore_existing=True) - # Go through the whole content of source - for item in sorted(source.glob("**/*")): - relative_path = item.relative_to(source.parent) - remote_target = target / relative_path - if item.is_file(): - self.put(str(item), str(remote_target)) - else: - self._mkdir(str(remote_target), ignore_existing=True) - - def _mkdir(self, path: str, mode: int = 511, ignore_existing: bool = False) -> None: - """Extend mkdir functionality. - - This version adds an option to not fail if the folder exists. - """ - try: - super().mkdir(path, mode) - except IOError as error: - if ignore_existing: - pass - else: - raise error - - -class LocalProtocol(Protocol): - """Class for local protocol.""" - - protocol: str - cwd: Path - - def run( - self, command: str, retry: bool = False - ) -> Tuple[int, bytearray, bytearray]: - """ - Run command locally. 
- - Returns a tuple: (exit_code, stdout, stderr) - """ - if not isinstance(self.cwd, Path) or not self.cwd.is_dir(): - raise ConfigurationException("Wrong working directory {}".format(self.cwd)) - - stdout = bytearray() - stderr = bytearray() - - return run_and_wait( - command, self.cwd, terminate_on_error=True, out=stdout, err=stderr - ) - - def _validate(self) -> None: - """Validate protocol configuration.""" - assert hasattr(self, "protocol") and self.protocol == "local" - assert hasattr(self, "cwd") - - -class SSHProtocol(Protocol, SupportsClose, SupportsDeploy, SupportsConnection): - """Class for SSH protocol.""" - - protocol: str - username: str - password: str - hostname: str - port: int - - def __init__(self, iterable: Iterable = (), **kwargs: Any) -> None: - """Initialize the class using a dict.""" - super().__init__(iterable, **kwargs) - # Internal state to store if the system is connectable. It will be set - # to true at the first connection instance - self.client: Optional[paramiko.client.SSHClient] = None - self.port = int(self.port) - - def run(self, command: str, retry: bool = True) -> Tuple[int, bytearray, bytearray]: - """ - Run command over SSH. - - Returns a tuple: (exit_code, stdout, stderr) - """ - transport = self._get_transport() - with closing(transport.open_session()) as channel: - # Enable shell's .profile settings and execute command - channel.exec_command("bash -l -c '{}'".format(command)) - exit_status = -1 - stdout = bytearray() - stderr = bytearray() - while True: - if channel.exit_status_ready(): - exit_status = channel.recv_exit_status() - # Call it one last time to read any leftover in the channel - self._recv_stdout_err(channel, stdout, stderr) - break - self._recv_stdout_err(channel, stdout, stderr) - - return exit_status, stdout, stderr - - def deploy(self, src: Path, dst: str, retry: bool = True) -> None: - """Deploy src to remote dst over SSH. - - src and dst should be path to a file or directory. - """ - transport = self._get_transport() - sftp = cast(CustomSFTPClient, CustomSFTPClient.from_transport(transport)) - - with closing(sftp): - if src.is_dir(): - sftp.put_dir(src, dst) - elif src.is_file(): - sftp.put(str(src), dst) - else: - raise Exception("Deploy error: file type not supported") - - # After the deployment of files, sync the remote filesystem to flush - # buffers to hard disk - self.run("sync") - - def close(self) -> None: - """Close protocol session.""" - if self.client is not None: - print("Try syncing remote file system...") - # Before stopping the system, we try to run sync to make sure all - # data are flushed on disk. 
- self.run("sync", retry=False) - self._close_client(self.client) - - def establish_connection(self) -> bool: - """Establish connection with underlying system.""" - if self.client is not None: - return True - - self.client = self._connect() - return self.client is not None - - def _get_transport(self) -> paramiko.transport.Transport: - """Get transport.""" - self.establish_connection() - - if self.client is None: - raise SSHConnectionException( - "Couldn't connect to '{}:{}'.".format(self.hostname, self.port) - ) - - transport = self.client.get_transport() - if not transport: - raise Exception("Unable to get transport") - - return transport - - def connection_details(self) -> Tuple[str, int]: - """Return connection details of underlying system.""" - return (self.hostname, self.port) - - def _connect(self) -> Optional[paramiko.client.SSHClient]: - """Try to establish connection.""" - client: Optional[paramiko.client.SSHClient] = None - try: - client = paramiko.client.SSHClient() - client.set_missing_host_key_policy(paramiko.AutoAddPolicy()) - client.connect( - self.hostname, - self.port, - self.username, - self.password, - # next parameters should be set to False to disable authentication - # using ssh keys - allow_agent=False, - look_for_keys=False, - ) - return client - except ( - # OSError raised on first attempt to connect when running inside Docker - OSError, - paramiko.ssh_exception.NoValidConnectionsError, - paramiko.ssh_exception.SSHException, - ): - # even if connection is not established socket could be still - # open, it should be closed - self._close_client(client) - - return None - - @staticmethod - def _close_client(client: Optional[paramiko.client.SSHClient]) -> None: - """Close ssh client.""" - try: - if client is not None: - client.close() - except Exception: # pylint: disable=broad-except - pass - - @classmethod - def _recv_stdout_err( - cls, channel: paramiko.channel.Channel, stdout: bytearray, stderr: bytearray - ) -> None: - """Read from channel to stdout/stder.""" - chunk_size = 512 - if channel.recv_ready(): - stdout_chunk = channel.recv(chunk_size) - stdout.extend(stdout_chunk) - if channel.recv_stderr_ready(): - stderr_chunk = channel.recv_stderr(chunk_size) - stderr.extend(stderr_chunk) - - def _validate(self) -> None: - """Check if there are all the info for establishing the connection.""" - assert hasattr(self, "protocol") and self.protocol == "ssh" - assert hasattr(self, "username") - assert hasattr(self, "password") - assert hasattr(self, "hostname") - assert hasattr(self, "port") - - -class ProtocolFactory: - """Factory class to return the appropriate Protocol class.""" - - @staticmethod - def get_protocol( - config: Optional[Union[SSHConfig, LocalProtocolConfig]], - **kwargs: Union[str, Path, None] - ) -> Union[SSHProtocol, LocalProtocol]: - """Return the right protocol instance based on the config.""" - if not config: - raise ValueError("No protocol config provided") - - protocol = config["protocol"] - if protocol == "ssh": - return SSHProtocol(config) - - if protocol == "local": - cwd = kwargs.get("cwd") - return LocalProtocol(config, cwd=cwd) - - raise ValueError("Protocol not supported: '{}'".format(protocol)) diff --git a/src/aiet/backend/source.py b/src/aiet/backend/source.py deleted file mode 100644 index dec175a..0000000 --- a/src/aiet/backend/source.py +++ /dev/null @@ -1,209 +0,0 @@ -# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. 
-# SPDX-License-Identifier: Apache-2.0 -"""Contain source related classes and functions.""" -import os -import shutil -import tarfile -from abc import ABC -from abc import abstractmethod -from pathlib import Path -from tarfile import TarFile -from typing import Optional -from typing import Union - -from aiet.backend.common import AIET_CONFIG_FILE -from aiet.backend.common import ConfigurationException -from aiet.backend.common import get_backend_config -from aiet.backend.common import is_backend_directory -from aiet.backend.common import load_config -from aiet.backend.config import BackendConfig -from aiet.utils.fs import copy_directory_content - - -class Source(ABC): - """Source class.""" - - @abstractmethod - def name(self) -> Optional[str]: - """Get source name.""" - - @abstractmethod - def config(self) -> Optional[BackendConfig]: - """Get configuration file content.""" - - @abstractmethod - def install_into(self, destination: Path) -> None: - """Install source into destination directory.""" - - @abstractmethod - def create_destination(self) -> bool: - """Return True if destination folder should be created before installation.""" - - -class DirectorySource(Source): - """DirectorySource class.""" - - def __init__(self, directory_path: Path) -> None: - """Create the DirectorySource instance.""" - assert isinstance(directory_path, Path) - self.directory_path = directory_path - - def name(self) -> str: - """Return name of source.""" - return self.directory_path.name - - def config(self) -> Optional[BackendConfig]: - """Return configuration file content.""" - if not is_backend_directory(self.directory_path): - raise ConfigurationException("No configuration file found") - - config_file = get_backend_config(self.directory_path) - return load_config(config_file) - - def install_into(self, destination: Path) -> None: - """Install source into destination directory.""" - if not destination.is_dir(): - raise ConfigurationException("Wrong destination {}".format(destination)) - - if not self.directory_path.is_dir(): - raise ConfigurationException( - "Directory {} does not exist".format(self.directory_path) - ) - - copy_directory_content(self.directory_path, destination) - - def create_destination(self) -> bool: - """Return True if destination folder should be created before installation.""" - return True - - -class TarArchiveSource(Source): - """TarArchiveSource class.""" - - def __init__(self, archive_path: Path) -> None: - """Create the TarArchiveSource class.""" - assert isinstance(archive_path, Path) - self.archive_path = archive_path - self._config: Optional[BackendConfig] = None - self._has_top_level_folder: Optional[bool] = None - self._name: Optional[str] = None - - def _read_archive_content(self) -> None: - """Read various information about archive.""" - # get source name from archive name (everything without extensions) - extensions = "".join(self.archive_path.suffixes) - self._name = self.archive_path.name.rstrip(extensions) - - if not self.archive_path.exists(): - return - - with self._open(self.archive_path) as archive: - try: - config_entry = archive.getmember(AIET_CONFIG_FILE) - self._has_top_level_folder = False - except KeyError as error_no_config: - try: - archive_entries = archive.getnames() - entries_common_prefix = os.path.commonprefix(archive_entries) - top_level_dir = entries_common_prefix.rstrip("/") - - if not top_level_dir: - raise RuntimeError( - "Archive has no top level directory" - ) from error_no_config - - config_path = "{}/{}".format(top_level_dir, AIET_CONFIG_FILE) - 
- config_entry = archive.getmember(config_path) - self._has_top_level_folder = True - self._name = top_level_dir - except (KeyError, RuntimeError) as error_no_root_dir_or_config: - raise ConfigurationException( - "No configuration file found" - ) from error_no_root_dir_or_config - - content = archive.extractfile(config_entry) - self._config = load_config(content) - - def config(self) -> Optional[BackendConfig]: - """Return configuration file content.""" - if self._config is None: - self._read_archive_content() - - return self._config - - def name(self) -> Optional[str]: - """Return name of the source.""" - if self._name is None: - self._read_archive_content() - - return self._name - - def create_destination(self) -> bool: - """Return True if destination folder must be created before installation.""" - if self._has_top_level_folder is None: - self._read_archive_content() - - return not self._has_top_level_folder - - def install_into(self, destination: Path) -> None: - """Install source into destination directory.""" - if not destination.is_dir(): - raise ConfigurationException("Wrong destination {}".format(destination)) - - with self._open(self.archive_path) as archive: - archive.extractall(destination) - - def _open(self, archive_path: Path) -> TarFile: - """Open archive file.""" - if not archive_path.is_file(): - raise ConfigurationException("File {} does not exist".format(archive_path)) - - if archive_path.name.endswith("tar.gz") or archive_path.name.endswith("tgz"): - mode = "r:gz" - else: - raise ConfigurationException( - "Unsupported archive type {}".format(archive_path) - ) - - # The returned TarFile object can be used as a context manager (using - # 'with') by the calling instance. - return tarfile.open( # pylint: disable=consider-using-with - self.archive_path, mode=mode - ) - - -def get_source(source_path: Path) -> Union[TarArchiveSource, DirectorySource]: - """Return appropriate source instance based on provided source path.""" - if source_path.is_file(): - return TarArchiveSource(source_path) - - if source_path.is_dir(): - return DirectorySource(source_path) - - raise ConfigurationException("Unable to read {}".format(source_path)) - - -def create_destination_and_install(source: Source, resource_path: Path) -> None: - """Create destination directory and install source. - - This function is used for the actual installation of a system/backend. A new - directory is created inside :resource_path: if needed; if, for example, the - archive contains a top-level folder, there is no need to create a new directory. - """ - destination = resource_path - create_destination = source.create_destination() - - if create_destination: - name = source.name() - if not name: - raise ConfigurationException("Unable to get source name") - - destination = resource_path / name - destination.mkdir() - try: - source.install_into(destination) - except Exception as error: - if create_destination: - shutil.rmtree(destination) - raise error diff --git a/src/aiet/backend/system.py b/src/aiet/backend/system.py deleted file mode 100644 index 48f1bb1..0000000 --- a/src/aiet/backend/system.py +++ /dev/null @@ -1,289 +0,0 @@ -# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
-# SPDX-License-Identifier: Apache-2.0 -"""System backend module.""" -from pathlib import Path -from typing import Any -from typing import cast -from typing import Dict -from typing import List -from typing import Optional -from typing import Tuple -from typing import Union - -from aiet.backend.common import Backend -from aiet.backend.common import ConfigurationException -from aiet.backend.common import get_backend_configs -from aiet.backend.common import get_backend_directories -from aiet.backend.common import load_config -from aiet.backend.common import remove_backend -from aiet.backend.config import SystemConfig -from aiet.backend.controller import SystemController -from aiet.backend.controller import SystemControllerSingleInstance -from aiet.backend.protocol import ProtocolFactory -from aiet.backend.protocol import SupportsClose -from aiet.backend.protocol import SupportsConnection -from aiet.backend.protocol import SupportsDeploy -from aiet.backend.source import create_destination_and_install -from aiet.backend.source import get_source -from aiet.utils.fs import get_resources - - -def get_available_systems_directory_names() -> List[str]: - """Return a list of directory names for all available systems.""" - return [entry.name for entry in get_backend_directories("systems")] - - -def get_available_systems() -> List["System"]: - """Return a list with all available systems.""" - available_systems = [] - for config_json in get_backend_configs("systems"): - config_entries = cast(List[SystemConfig], (load_config(config_json))) - for config_entry in config_entries: - config_entry["config_location"] = config_json.parent.absolute() - system = load_system(config_entry) - available_systems.append(system) - - return sorted(available_systems, key=lambda system: system.name) - - -def get_system(system_name: str) -> Optional["System"]: - """Return a system instance with the same name passed as argument.""" - available_systems = get_available_systems() - for system in available_systems: - if system_name == system.name: - return system - return None - - -def install_system(source_path: Path) -> None: - """Install new system.""" - try: - source = get_source(source_path) - config = cast(List[SystemConfig], source.config()) - systems_to_install = [load_system(entry) for entry in config] - except Exception as error: - raise ConfigurationException("Unable to read system definition") from error - - if not systems_to_install: - raise ConfigurationException("No system definition found") - - available_systems = get_available_systems() - already_installed = [s for s in systems_to_install if s in available_systems] - if already_installed: - names = [system.name for system in already_installed] - raise ConfigurationException( - "Systems [{}] are already installed".format(",".join(names)) - ) - - create_destination_and_install(source, get_resources("systems")) - - -def remove_system(directory_name: str) -> None: - """Remove system.""" - remove_backend(directory_name, "systems") - - -class System(Backend): - """System class.""" - - def __init__(self, config: SystemConfig) -> None: - """Construct the System class using the dictionary passed.""" - super().__init__(config) - - self._setup_data_transfer(config) - self._setup_reporting(config) - - def _setup_data_transfer(self, config: SystemConfig) -> None: - data_transfer_config = config.get("data_transfer") - protocol = ProtocolFactory().get_protocol( - data_transfer_config, cwd=self.config_location - ) - self.protocol = protocol - - def _setup_reporting(self, config:
SystemConfig) -> None: - self.reporting = config.get("reporting") - - def run(self, command: str, retry: bool = True) -> Tuple[int, bytearray, bytearray]: - """ - Run command on the system. - - Returns a tuple: (exit_code, stdout, stderr) - """ - return self.protocol.run(command, retry) - - def deploy(self, src: Path, dst: str, retry: bool = True) -> None: - """Deploy files to the system.""" - if isinstance(self.protocol, SupportsDeploy): - self.protocol.deploy(src, dst, retry) - - @property - def supports_deploy(self) -> bool: - """Check if protocol supports deploy operation.""" - return isinstance(self.protocol, SupportsDeploy) - - @property - def connectable(self) -> bool: - """Check if protocol supports connection.""" - return isinstance(self.protocol, SupportsConnection) - - def establish_connection(self) -> bool: - """Establish connection with the system.""" - if not isinstance(self.protocol, SupportsConnection): - raise ConfigurationException( - "System {} does not support connections".format(self.name) - ) - - return self.protocol.establish_connection() - - def connection_details(self) -> Tuple[str, int]: - """Return connection details.""" - if not isinstance(self.protocol, SupportsConnection): - raise ConfigurationException( - "System {} does not support connections".format(self.name) - ) - - return self.protocol.connection_details() - - def __eq__(self, other: object) -> bool: - """Overload operator ==.""" - if not isinstance(other, System): - return False - - return super().__eq__(other) and self.name == other.name - - def get_details(self) -> Dict[str, Any]: - """Return a dictionary with all relevant information of a System.""" - output = { - "type": "system", - "name": self.name, - "description": self.description, - "data_transfer_protocol": self.protocol.protocol, - "commands": self._get_command_details(), - "annotations": self.annotations, - } - - return output - - -class StandaloneSystem(System): - """StandaloneSystem class.""" - - -def get_controller( - single_instance: bool, pid_file_path: Optional[Path] = None -) -> SystemController: - """Get system controller.""" - if single_instance: - return SystemControllerSingleInstance(pid_file_path) - - return SystemController() - - -class ControlledSystem(System): - """ControlledSystem class.""" - - def __init__(self, config: SystemConfig): - """Construct the ControlledSystem class using the dictionary passed.""" - super().__init__(config) - self.controller: Optional[SystemController] = None - - def start( - self, - commands: List[str], - single_instance: bool = True, - pid_file_path: Optional[Path] = None, - ) -> None: - """Launch the system.""" - if ( - not isinstance(self.config_location, Path) - or not self.config_location.is_dir() - ): - raise ConfigurationException( - "System {} has wrong config location".format(self.name) - ) - - self.controller = get_controller(single_instance, pid_file_path) - self.controller.start(commands, self.config_location) - - def is_running(self) -> bool: - """Check if system is running.""" - if not self.controller: - return False - - return self.controller.is_running() - - def get_output(self) -> Tuple[str, str]: - """Return system output.""" - if not self.controller: - return "", "" - - return self.controller.get_output() - - def stop(self, wait: bool = False) -> None: - """Stop the system.""" - if not self.controller: - raise Exception("System has not been started") - - if isinstance(self.protocol, SupportsClose): - try: - self.protocol.close() - except Exception as error: # pylint: 
disable=broad-except - print(error) - self.controller.stop(wait) - - -def load_system(config: SystemConfig) -> Union[StandaloneSystem, ControlledSystem]: - """Load system based on its execution type.""" - data_transfer = config.get("data_transfer", {}) - protocol = data_transfer.get("protocol") - populate_shared_params(config) - - if protocol == "ssh": - return ControlledSystem(config) - - if protocol == "local": - return StandaloneSystem(config) - - raise ConfigurationException( - "Unsupported execution type for protocol {}".format(protocol) - ) - - -def populate_shared_params(config: SystemConfig) -> None: - """Populate command parameters with shared parameters.""" - user_params = config.get("user_params") - if not user_params or "shared" not in user_params: - return - - shared_user_params = user_params["shared"] - if not shared_user_params: - return - - only_aliases = all(p.get("alias") for p in shared_user_params) - if not only_aliases: - raise ConfigurationException("All shared parameters should have aliases") - - commands = config.get("commands", {}) - for cmd_name in ["build", "run"]: - command = commands.get(cmd_name) - if command is None: - commands[cmd_name] = [] - cmd_user_params = user_params.get(cmd_name) - if not cmd_user_params: - cmd_user_params = shared_user_params - else: - only_aliases = all(p.get("alias") for p in cmd_user_params) - if not only_aliases: - raise ConfigurationException( - "All parameters for command {} should have aliases".format(cmd_name) - ) - merged_by_alias = { - **{p.get("alias"): p for p in shared_user_params}, - **{p.get("alias"): p for p in cmd_user_params}, - } - cmd_user_params = list(merged_by_alias.values()) - - user_params[cmd_name] = cmd_user_params - - config["commands"] = commands - del user_params["shared"] diff --git a/src/aiet/backend/tool.py b/src/aiet/backend/tool.py deleted file mode 100644 index d643665..0000000 --- a/src/aiet/backend/tool.py +++ /dev/null @@ -1,109 +0,0 @@ -# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
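A brief sketch of the merge performed above; the config dict is a minimal, assumed shape containing only the fields populate_shared_params() touches:

```python
from aiet.backend.system import populate_shared_params

config = {
    "name": "example_system",
    "commands": {"run": ["run.sh {user_params:0}"]},
    "user_params": {
        "shared": [{"name": "--data", "values": [], "alias": "data"}],
        "run": [{"name": "--speed", "values": [], "alias": "speed"}],
    },
}

populate_shared_params(config)  # type: ignore[arg-type]

# The "shared" section is folded into each command: "run" now exposes both
# parameters (command-specific entries win on alias clashes), and "build"
# inherits the shared parameters only.
assert "shared" not in config["user_params"]
assert [p["alias"] for p in config["user_params"]["run"]] == ["data", "speed"]
assert [p["alias"] for p in config["user_params"]["build"]] == ["data"]
```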
-# SPDX-License-Identifier: Apache-2.0 -"""Tool backend module.""" -from typing import Any -from typing import cast -from typing import Dict -from typing import List -from typing import Optional - -from aiet.backend.common import Backend -from aiet.backend.common import ConfigurationException -from aiet.backend.common import get_backend_configs -from aiet.backend.common import get_backend_directories -from aiet.backend.common import load_application_or_tool_configs -from aiet.backend.common import load_config -from aiet.backend.config import ExtendedToolConfig -from aiet.backend.config import ToolConfig - - -def get_available_tool_directory_names() -> List[str]: - """Return a list of directory names for all available tools.""" - return [entry.name for entry in get_backend_directories("tools")] - - -def get_available_tools() -> List["Tool"]: - """Return a list with all available tools.""" - available_tools = [] - for config_json in get_backend_configs("tools"): - config_entries = cast(List[ExtendedToolConfig], load_config(config_json)) - for config_entry in config_entries: - config_entry["config_location"] = config_json.parent.absolute() - tools = load_tools(config_entry) - available_tools += tools - - return sorted(available_tools, key=lambda tool: tool.name) - - -def get_tool(tool_name: str, system_name: Optional[str] = None) -> List["Tool"]: - """Return a tool instance with the same name passed as argument.""" - return [ - tool - for tool in get_available_tools() - if tool.name == tool_name and (not system_name or tool.can_run_on(system_name)) - ] - - -def get_unique_tool_names(system_name: Optional[str] = None) -> List[str]: - """Extract a list of unique tool names of all tools available.""" - return list( - set( - tool.name - for tool in get_available_tools() - if not system_name or tool.can_run_on(system_name) - ) - ) - - -class Tool(Backend): - """Class for representing a single tool component.""" - - def __init__(self, config: ToolConfig) -> None: - """Construct a Tool instance from a dict.""" - super().__init__(config) - - self.supported_systems = config.get("supported_systems", []) - - if "run" not in self.commands: - raise ConfigurationException("A Tool must have a 'run' command.") - - def __eq__(self, other: object) -> bool: - """Overload operator ==.""" - if not isinstance(other, Tool): - return False - - return ( - super().__eq__(other) - and self.name == other.name - and set(self.supported_systems) == set(other.supported_systems) - ) - - def can_run_on(self, system_name: str) -> bool: - """Check if the tool can run on the system passed as argument.""" - return system_name in self.supported_systems - - def get_details(self) -> Dict[str, Any]: - """Return dictionary with all relevant information of the Tool instance.""" - output = { - "type": "tool", - "name": self.name, - "description": self.description, - "supported_systems": self.supported_systems, - "commands": self._get_command_details(), - } - - return output - - -def load_tools(config: ExtendedToolConfig) -> List[Tool]: - """Load tool. - - Tool configuration could contain different parameters/commands for different - supported systems. For each supported system this function will return separate - Tool instance with appropriate configuration. 
- """ - configs = load_application_or_tool_configs( - config, ToolConfig, is_system_required=False - ) - tools = [Tool(cfg) for cfg in configs] - return tools diff --git a/src/aiet/cli/__init__.py b/src/aiet/cli/__init__.py deleted file mode 100644 index bcd17c3..0000000 --- a/src/aiet/cli/__init__.py +++ /dev/null @@ -1,28 +0,0 @@ -# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. -# SPDX-License-Identifier: Apache-2.0 -"""Module to mange the CLI interface.""" -import click - -from aiet import __version__ -from aiet.cli.application import application_cmd -from aiet.cli.completion import completion_cmd -from aiet.cli.system import system_cmd -from aiet.cli.tool import tool_cmd -from aiet.utils.helpers import set_verbosity - - -@click.group() -@click.version_option(__version__) -@click.option( - "-v", "--verbose", default=0, count=True, callback=set_verbosity, expose_value=False -) -@click.pass_context -def cli(ctx: click.Context) -> None: # pylint: disable=unused-argument - """AIET: AI Evaluation Toolkit.""" - # Unused arguments must be present here in definition to pass click context. - - -cli.add_command(application_cmd) -cli.add_command(system_cmd) -cli.add_command(tool_cmd) -cli.add_command(completion_cmd) diff --git a/src/aiet/cli/application.py b/src/aiet/cli/application.py deleted file mode 100644 index 59b652d..0000000 --- a/src/aiet/cli/application.py +++ /dev/null @@ -1,362 +0,0 @@ -# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. -# SPDX-FileCopyrightText: Copyright (c) 2021, Gianluca Gippetto. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 AND BSD-3-Clause -"""Module to manage the CLI interface of applications.""" -import json -import logging -import re -from pathlib import Path -from typing import Any -from typing import IO -from typing import List -from typing import Optional -from typing import Tuple - -import click -import cloup - -from aiet.backend.application import get_application -from aiet.backend.application import get_available_application_directory_names -from aiet.backend.application import get_unique_application_names -from aiet.backend.application import install_application -from aiet.backend.application import remove_application -from aiet.backend.common import DataPaths -from aiet.backend.execution import execute_application_command -from aiet.backend.execution import run_application -from aiet.backend.system import get_available_systems -from aiet.cli.common import get_format -from aiet.cli.common import middleware_exception_handler -from aiet.cli.common import middleware_signal_handler -from aiet.cli.common import print_command_details -from aiet.cli.common import set_format - - -@click.group(name="application") -@click.option( - "-f", - "--format", - "format_", - type=click.Choice(["cli", "json"]), - default="cli", - show_default=True, -) -@click.pass_context -def application_cmd(ctx: click.Context, format_: str) -> None: - """Sub command to manage applications.""" - set_format(ctx, format_) - - -@application_cmd.command(name="list") -@click.pass_context -@click.option( - "-s", - "--system", - "system_name", - type=click.Choice([s.name for s in get_available_systems()]), - required=False, -) -def list_cmd(ctx: click.Context, system_name: str) -> None: - """List all available applications.""" - unique_application_names = get_unique_application_names(system_name) - unique_application_names.sort() - if get_format(ctx) == "json": - data = {"type": "application", "available": 
unique_application_names} - print(json.dumps(data)) - else: - print("Available applications:\n") - print(*unique_application_names, sep="\n") - - -@application_cmd.command(name="details") -@click.option( - "-n", - "--name", - "application_name", - type=click.Choice(get_unique_application_names()), - required=True, -) -@click.option( - "-s", - "--system", - "system_name", - type=click.Choice([s.name for s in get_available_systems()]), - required=False, -) -@click.pass_context -def details_cmd(ctx: click.Context, application_name: str, system_name: str) -> None: - """Details of a specific application.""" - applications = get_application(application_name, system_name) - if not applications: - raise click.UsageError( - "Application '{}' doesn't support the system '{}'".format( - application_name, system_name - ) - ) - - if get_format(ctx) == "json": - applications_details = [s.get_details() for s in applications] - print(json.dumps(applications_details)) - else: - for application in applications: - application_details = application.get_details() - application_details_template = ( - 'Application "{name}" details\nDescription: {description}' - ) - - print( - application_details_template.format( - name=application_details["name"], - description=application_details["description"], - ) - ) - - print( - "\nSupported systems: {}".format( - ", ".join(application_details["supported_systems"]) - ) - ) - - command_details = application_details["commands"] - - for command, details in command_details.items(): - print("\n{} commands:".format(command)) - print_command_details(details) - - -# pylint: disable=too-many-arguments -@application_cmd.command(name="execute") -@click.option( - "-n", - "--name", - "application_name", - type=click.Choice(get_unique_application_names()), - required=True, -) -@click.option( - "-s", - "--system", - "system_name", - type=click.Choice([s.name for s in get_available_systems()]), - required=True, -) -@click.option( - "-c", - "--command", - "command_name", - type=click.Choice(["build", "run"]), - required=True, -) -@click.option("-p", "--param", "application_params", multiple=True) -@click.option("--system-param", "system_params", multiple=True) -@click.option("-d", "--deploy", "deploy_params", multiple=True) -@middleware_signal_handler -@middleware_exception_handler -def execute_cmd( - application_name: str, - system_name: str, - command_name: str, - application_params: List[str], - system_params: List[str], - deploy_params: List[str], -) -> None: - """Execute application commands. DEPRECATED! Use 'aiet application run' instead.""" - logging.warning( - "Please use 'aiet application run' instead. Use of 'aiet application " - "execute' is deprecated and might be removed in a future release." 
- ) - - custom_deploy_data = get_custom_deploy_data(command_name, deploy_params) - - execute_application_command( - command_name, - application_name, - application_params, - system_name, - system_params, - custom_deploy_data, - ) - - -@cloup.command(name="run") -@cloup.option( - "-n", - "--name", - "application_name", - type=click.Choice(get_unique_application_names()), -) -@cloup.option( - "-s", - "--system", - "system_name", - type=click.Choice([s.name for s in get_available_systems()]), -) -@cloup.option("-p", "--param", "application_params", multiple=True) -@cloup.option("--system-param", "system_params", multiple=True) -@cloup.option("-d", "--deploy", "deploy_params", multiple=True) -@click.option( - "-r", - "--report", - "report_file", - type=Path, - help="Create a report file in JSON format containing metrics parsed from " - "the simulation output as specified in the aiet-config.json.", -) -@cloup.option( - "--config", - "config_file", - type=click.File("r"), - help="Read options from a config file rather than from the command line. " - "The config file is a json file.", -) -@cloup.constraint( - cloup.constraints.If( - cloup.constraints.conditions.Not( - cloup.constraints.conditions.IsSet("config_file") - ), - then=cloup.constraints.require_all, - ), - ["system_name", "application_name"], -) -@cloup.constraint( - cloup.constraints.If("config_file", then=cloup.constraints.accept_none), - [ - "system_name", - "application_name", - "application_params", - "system_params", - "deploy_params", - ], -) -@middleware_signal_handler -@middleware_exception_handler -def run_cmd( - application_name: str, - system_name: str, - application_params: List[str], - system_params: List[str], - deploy_params: List[str], - report_file: Optional[Path], - config_file: Optional[IO[str]], -) -> None: - """Execute application commands.""" - if config_file: - payload_data = json.load(config_file) - ( - system_name, - application_name, - application_params, - system_params, - deploy_params, - report_file, - ) = parse_payload_run_config(payload_data) - - custom_deploy_data = get_custom_deploy_data("run", deploy_params) - - run_application( - application_name, - application_params, - system_name, - system_params, - custom_deploy_data, - report_file, - ) - - -application_cmd.add_command(run_cmd) - - -def parse_payload_run_config( - payload_data: dict, -) -> Tuple[str, str, List[str], List[str], List[str], Optional[Path]]: - """Parse the payload into a tuple.""" - system_id = payload_data.get("id") - arguments: Optional[Any] = payload_data.get("arguments") - - if not isinstance(system_id, str): - raise click.ClickException("invalid payload json: no system 'id'") - if not isinstance(arguments, dict): - raise click.ClickException("invalid payload json: no arguments object") - - application_name = arguments.pop("application", None) - if not isinstance(application_name, str): - raise click.ClickException("invalid payload json: no application_id") - - report_path = arguments.pop("report_path", None) - - application_params = [] - system_params = [] - deploy_params = [] - - for (param_key, value) in arguments.items(): - (par, _) = re.subn("^application/", "", param_key) - (par, found_sys_param) = re.subn("^system/", "", par) - (par, found_deploy_param) = re.subn("^deploy/", "", par) - - param_expr = par + "=" + value - if found_sys_param: - system_params.append(param_expr) - elif found_deploy_param: - deploy_params.append(par) - else: - application_params.append(param_expr) - - return ( - system_id, - application_name, - 
application_params, - system_params, - deploy_params, - report_path, - ) - - -def get_custom_deploy_data( - command_name: str, deploy_params: List[str] -) -> List[DataPaths]: - """Get custom deploy data information.""" - custom_deploy_data: List[DataPaths] = [] - if not deploy_params: - return custom_deploy_data - - for param in deploy_params: - parts = param.split(":") - if not len(parts) == 2 or any(not part.strip() for part in parts): - raise click.ClickException( - "Invalid deploy parameter '{}' for command {}".format( - param, command_name - ) - ) - data_path = DataPaths(Path(parts[0]), parts[1]) - if not data_path.src.exists(): - raise click.ClickException("Path {} does not exist".format(data_path.src)) - custom_deploy_data.append(data_path) - - return custom_deploy_data - - -@application_cmd.command(name="install") -@click.option( - "-s", - "--source", - "source", - required=True, - help="Path to the directory or archive with application definition", -) -def install_cmd(source: str) -> None: - """Install new application.""" - source_path = Path(source) - install_application(source_path) - - -@application_cmd.command(name="remove") -@click.option( - "-d", - "--directory_name", - "directory_name", - type=click.Choice(get_available_application_directory_names()), - required=True, - help="Name of the directory with application", -) -def remove_cmd(directory_name: str) -> None: - """Remove application.""" - remove_application(directory_name) diff --git a/src/aiet/cli/common.py b/src/aiet/cli/common.py deleted file mode 100644 index 1d157b6..0000000 --- a/src/aiet/cli/common.py +++ /dev/null @@ -1,173 +0,0 @@ -# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. -# SPDX-License-Identifier: Apache-2.0 -"""Common functions for cli module.""" -import enum -import logging -from functools import wraps -from signal import SIG_IGN -from signal import SIGINT -from signal import signal as signal_handler -from signal import SIGTERM -from typing import Any -from typing import Callable -from typing import cast -from typing import Dict - -from click import ClickException -from click import Context -from click import UsageError - -from aiet.backend.common import ConfigurationException -from aiet.backend.execution import AnotherInstanceIsRunningException -from aiet.backend.execution import ConnectionException -from aiet.backend.protocol import SSHConnectionException -from aiet.utils.proc import CommandFailedException - - -class MiddlewareExitCode(enum.IntEnum): - """Middleware exit codes.""" - - SUCCESS = 0 - # exit codes 1 and 2 are used by click - SHUTDOWN_REQUESTED = 3 - BACKEND_ERROR = 4 - CONCURRENT_ERROR = 5 - CONNECTION_ERROR = 6 - CONFIGURATION_ERROR = 7 - MODEL_OPTIMISED_ERROR = 8 - INVALID_TFLITE_FILE_ERROR = 9 - - -class CustomClickException(ClickException): - """Custom click exception.""" - - def show(self, file: Any = None) -> None: - """Override show method.""" - super().show(file) - - logging.debug("Execution failed with following exception: ", exc_info=self) - - -class MiddlewareShutdownException(CustomClickException): - """Exception indicates that user requested middleware shutdown.""" - - exit_code = int(MiddlewareExitCode.SHUTDOWN_REQUESTED) - - -class BackendException(CustomClickException): - """Exception indicates that command failed.""" - - exit_code = int(MiddlewareExitCode.BACKEND_ERROR) - - -class ConcurrentErrorException(CustomClickException): - """Exception indicates concurrent execution error.""" - - exit_code = int(MiddlewareExitCode.CONCURRENT_ERROR) 
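As an aside on the pattern used above: a ClickException subclass carries its exit code, so click prints the message and the process exits with that code. A hypothetical sketch (the exception name and code are illustrative only):

    import click

    class ExampleBackendError(click.ClickException):
        # Hypothetical exit code; the classes above use 3-9.
        exit_code = 42

    @click.command()
    def failing_cmd() -> None:
        raise ExampleBackendError("something went wrong")

    # Invoking failing_cmd prints "Error: something went wrong" and exits with 42.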
- - -class BackendConnectionException(CustomClickException): - """Exception indicates that connection could not be established.""" - - exit_code = int(MiddlewareExitCode.CONNECTION_ERROR) - - -class BackendConfigurationException(CustomClickException): - """Exception indicates some configuration issue.""" - - exit_code = int(MiddlewareExitCode.CONFIGURATION_ERROR) - - -class ModelOptimisedException(CustomClickException): - """Exception indicates input file has previously been Vela optimised.""" - - exit_code = int(MiddlewareExitCode.MODEL_OPTIMISED_ERROR) - - -class InvalidTFLiteFileError(CustomClickException): - """Exception indicates input TFLite file is misformatted.""" - - exit_code = int(MiddlewareExitCode.INVALID_TFLITE_FILE_ERROR) - - -def print_command_details(command: Dict) -> None: - """Print command details including parameters.""" - command_strings = command["command_strings"] - print("Commands: {}".format(command_strings)) - user_params = command["user_params"] - for i, param in enumerate(user_params, 1): - print("User parameter #{}".format(i)) - print("\tName: {}".format(param.get("name", "-"))) - print("\tDescription: {}".format(param["description"])) - print("\tPossible values: {}".format(param.get("values", "-"))) - print("\tDefault value: {}".format(param.get("default_value", "-"))) - print("\tAlias: {}".format(param.get("alias", "-"))) - - -def raise_exception_at_signal( - signum: int, frame: Any # pylint: disable=unused-argument -) -> None: - """Handle signals.""" - # Disable both SIGINT and SIGTERM signals. Further SIGINT and SIGTERM - # signals will be ignored as we allow a graceful shutdown. - # Unused arguments must be present here in definition as used in signal handler - # callback - - signal_handler(SIGINT, SIG_IGN) - signal_handler(SIGTERM, SIG_IGN) - raise MiddlewareShutdownException("Middleware shutdown requested") - - -def middleware_exception_handler(func: Callable) -> Callable: - """Handle backend exceptions decorator.""" - - @wraps(func) - def wrapper(*args: Any, **kwargs: Any) -> Any: - try: - return func(*args, **kwargs) - except (MiddlewareShutdownException, UsageError, ClickException) as error: - # click should take care of these exceptions - raise error - except ValueError as error: - raise ClickException(str(error)) from error - except AnotherInstanceIsRunningException as error: - raise ConcurrentErrorException( - "Another instance of the system is running" - ) from error - except (SSHConnectionException, ConnectionException) as error: - raise BackendConnectionException(str(error)) from error - except ConfigurationException as error: - raise BackendConfigurationException(str(error)) from error - except (CommandFailedException, Exception) as error: - raise BackendException( - "Execution failed. Please check output for the details." 
- ) from error - - return wrapper - - -def middleware_signal_handler(func: Callable) -> Callable: - """Handle signals decorator.""" - - @wraps(func) - def wrapper(*args: Any, **kwargs: Any) -> Any: - # Set up signal handlers for SIGINT (ctrl-c) and SIGTERM (kill command) - # The handler ignores further signals and it raises an exception - signal_handler(SIGINT, raise_exception_at_signal) - signal_handler(SIGTERM, raise_exception_at_signal) - - return func(*args, **kwargs) - - return wrapper - - -def set_format(ctx: Context, format_: str) -> None: - """Save format in click context.""" - ctx_obj = ctx.ensure_object(dict) - ctx_obj["format"] = format_ - - -def get_format(ctx: Context) -> str: - """Get format from click context.""" - ctx_obj = cast(Dict[str, str], ctx.ensure_object(dict)) - return ctx_obj["format"] diff --git a/src/aiet/cli/completion.py b/src/aiet/cli/completion.py deleted file mode 100644 index 71f054f..0000000 --- a/src/aiet/cli/completion.py +++ /dev/null @@ -1,72 +0,0 @@ -# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. -# SPDX-License-Identifier: Apache-2.0 -""" -Add auto completion to different shells with these helpers. - -See: https://click.palletsprojects.com/en/8.0.x/shell-completion/ -""" -import click - - -def _get_package_name() -> str: - return __name__.split(".", maxsplit=1)[0] - - -# aiet completion bash -@click.group(name="completion") -def completion_cmd() -> None: - """Enable auto completion for your shell.""" - - -@completion_cmd.command(name="bash") -def bash_cmd() -> None: - """ - Enable auto completion for bash. - - Use this command to activate completion in the current bash: - - eval "`aiet completion bash`" - - Use this command to add auto completion to bash globally, if you have aiet - installed globally (requires starting a new shell afterwards): - - aiet completion bash >> ~/.bashrc - """ - package_name = _get_package_name() - print(f'eval "$(_{package_name.upper()}_COMPLETE=bash_source {package_name})"') - - -@completion_cmd.command(name="zsh") -def zsh_cmd() -> None: - """ - Enable auto completion for zsh. - - Use this command to activate completion in the current zsh: - - eval "`aiet completion zsh`" - - Use this command to add auto completion to zsh globally, if you have aiet - installed globally (requires starting a new shell afterwards): - - aiet completion zsh >> ~/.zshrc - """ - package_name = _get_package_name() - print(f'eval "$(_{package_name.upper()}_COMPLETE=zsh_source {package_name})"') - - -@completion_cmd.command(name="fish") -def fish_cmd() -> None: - """ - Enable auto completion for fish. - - Use this command to activate completion in the current fish: - - eval "`aiet completion fish`" - - Use this command to add auto completion to fish globally, if you have aiet - installed globally (requires starting a new shell afterwards): - - aiet completion fish >> ~/.config/fish/completions/aiet.fish - """ - package_name = _get_package_name() - print(f'eval "(env _{package_name.upper()}_COMPLETE=fish_source {package_name})"') diff --git a/src/aiet/cli/system.py b/src/aiet/cli/system.py deleted file mode 100644 index f1f7637..0000000 --- a/src/aiet/cli/system.py +++ /dev/null @@ -1,122 +0,0 @@ -# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. 
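The two decorators above are meant to be stacked on click command callbacks so that SIGINT/SIGTERM and backend failures are translated into the MiddlewareExitCode values; a minimal sketch with a hypothetical command:

    import click

    from aiet.cli.common import middleware_exception_handler
    from aiet.cli.common import middleware_signal_handler

    @click.command(name="noop")
    @middleware_signal_handler
    @middleware_exception_handler
    def noop_cmd() -> None:
        """Ctrl-C exits with SHUTDOWN_REQUESTED; unexpected errors map to BACKEND_ERROR."""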
-# SPDX-License-Identifier: Apache-2.0 -"""Module to manage the CLI interface of systems.""" -import json -from pathlib import Path -from typing import cast - -import click - -from aiet.backend.application import get_available_applications -from aiet.backend.system import get_available_systems -from aiet.backend.system import get_available_systems_directory_names -from aiet.backend.system import get_system -from aiet.backend.system import install_system -from aiet.backend.system import remove_system -from aiet.backend.system import System -from aiet.cli.common import get_format -from aiet.cli.common import print_command_details -from aiet.cli.common import set_format - - -@click.group(name="system") -@click.option( - "-f", - "--format", - "format_", - type=click.Choice(["cli", "json"]), - default="cli", - show_default=True, -) -@click.pass_context -def system_cmd(ctx: click.Context, format_: str) -> None: - """Sub command to manage systems.""" - set_format(ctx, format_) - - -@system_cmd.command(name="list") -@click.pass_context -def list_cmd(ctx: click.Context) -> None: - """List all available systems.""" - available_systems = get_available_systems() - system_names = [system.name for system in available_systems] - if get_format(ctx) == "json": - data = {"type": "system", "available": system_names} - print(json.dumps(data)) - else: - print("Available systems:\n") - print(*system_names, sep="\n") - - -@system_cmd.command(name="details") -@click.option( - "-n", - "--name", - "system_name", - type=click.Choice([s.name for s in get_available_systems()]), - required=True, -) -@click.pass_context -def details_cmd(ctx: click.Context, system_name: str) -> None: - """Details of a specific system.""" - system = cast(System, get_system(system_name)) - applications = [ - s.name for s in get_available_applications() if s.can_run_on(system.name) - ] - system_details = system.get_details() - if get_format(ctx) == "json": - system_details["available_application"] = applications - print(json.dumps(system_details)) - else: - system_details_template = ( - 'System "{name}" details\n' - "Description: {description}\n" - "Data Transfer Protocol: {protocol}\n" - "Available Applications: {available_application}" - ) - print( - system_details_template.format( - name=system_details["name"], - description=system_details["description"], - protocol=system_details["data_transfer_protocol"], - available_application=", ".join(applications), - ) - ) - - if system_details["annotations"]: - print("Annotations:") - for ann_name, ann_value in system_details["annotations"].items(): - print("\t{}: {}".format(ann_name, ann_value)) - - command_details = system_details["commands"] - for command, details in command_details.items(): - print("\n{} commands:".format(command)) - print_command_details(details) - - -@system_cmd.command(name="install") -@click.option( - "-s", - "--source", - "source", - required=True, - help="Path to the directory or archive with system definition", -) -def install_cmd(source: str) -> None: - """Install new system.""" - source_path = Path(source) - install_system(source_path) - - -@system_cmd.command(name="remove") -@click.option( - "-d", - "--directory_name", - "directory_name", - type=click.Choice(get_available_systems_directory_names()), - required=True, - help="Name of the directory with system", -) -def remove_cmd(directory_name: str) -> None: - """Remove system by given name.""" - remove_system(directory_name) diff --git a/src/aiet/cli/tool.py b/src/aiet/cli/tool.py deleted file mode 100644 index 
2c80821..0000000 --- a/src/aiet/cli/tool.py +++ /dev/null @@ -1,143 +0,0 @@ -# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. -# SPDX-License-Identifier: Apache-2.0 -"""Module to manage the CLI interface of tools.""" -import json -from typing import Any -from typing import List -from typing import Optional - -import click - -from aiet.backend.execution import execute_tool_command -from aiet.backend.tool import get_tool -from aiet.backend.tool import get_unique_tool_names -from aiet.cli.common import get_format -from aiet.cli.common import middleware_exception_handler -from aiet.cli.common import middleware_signal_handler -from aiet.cli.common import print_command_details -from aiet.cli.common import set_format - - -@click.group(name="tool") -@click.option( - "-f", - "--format", - "format_", - type=click.Choice(["cli", "json"]), - default="cli", - show_default=True, -) -@click.pass_context -def tool_cmd(ctx: click.Context, format_: str) -> None: - """Sub command to manage tools.""" - set_format(ctx, format_) - - -@tool_cmd.command(name="list") -@click.pass_context -def list_cmd(ctx: click.Context) -> None: - """List all available tools.""" - # raise NotImplementedError("TODO") - tool_names = get_unique_tool_names() - tool_names.sort() - if get_format(ctx) == "json": - data = {"type": "tool", "available": tool_names} - print(json.dumps(data)) - else: - print("Available tools:\n") - print(*tool_names, sep="\n") - - -def validate_system( - ctx: click.Context, - _: click.Parameter, # param is not used - value: Any, -) -> Any: - """Validate provided system name depending on the the tool name.""" - tool_name = ctx.params["tool_name"] - tools = get_tool(tool_name, value) - if not tools: - supported_systems = [tool.supported_systems[0] for tool in get_tool(tool_name)] - raise click.BadParameter( - message="'{}' is not one of {}.".format( - value, - ", ".join("'{}'".format(system) for system in supported_systems), - ), - ctx=ctx, - ) - return value - - -@tool_cmd.command(name="details") -@click.option( - "-n", - "--name", - "tool_name", - type=click.Choice(get_unique_tool_names()), - required=True, -) -@click.option( - "-s", - "--system", - "system_name", - callback=validate_system, - required=False, -) -@click.pass_context -@middleware_signal_handler -@middleware_exception_handler -def details_cmd(ctx: click.Context, tool_name: str, system_name: Optional[str]) -> None: - """Details of a specific tool.""" - tools = get_tool(tool_name, system_name) - if get_format(ctx) == "json": - tools_details = [s.get_details() for s in tools] - print(json.dumps(tools_details)) - else: - for tool in tools: - tool_details = tool.get_details() - tool_details_template = 'Tool "{name}" details\nDescription: {description}' - - print( - tool_details_template.format( - name=tool_details["name"], - description=tool_details["description"], - ) - ) - - print( - "\nSupported systems: {}".format( - ", ".join(tool_details["supported_systems"]) - ) - ) - - command_details = tool_details["commands"] - - for command, details in command_details.items(): - print("\n{} commands:".format(command)) - print_command_details(details) - - -# pylint: disable=too-many-arguments -@tool_cmd.command(name="execute") -@click.option( - "-n", - "--name", - "tool_name", - type=click.Choice(get_unique_tool_names()), - required=True, -) -@click.option("-p", "--param", "tool_params", multiple=True) -@click.option( - "-s", - "--system", - "system_name", - callback=validate_system, - required=False, -) 
-@middleware_signal_handler -@middleware_exception_handler -def execute_cmd( - tool_name: str, tool_params: List[str], system_name: Optional[str] -) -> None: - """Execute tool commands.""" - execute_tool_command(tool_name, tool_params, system_name) diff --git a/src/aiet/main.py b/src/aiet/main.py deleted file mode 100644 index 6898ad9..0000000 --- a/src/aiet/main.py +++ /dev/null @@ -1,13 +0,0 @@ -# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. -# SPDX-License-Identifier: Apache-2.0 -"""Entry point module of AIET.""" -from aiet.cli import cli - - -def main() -> None: - """Entry point of aiet application.""" - cli() # pylint: disable=no-value-for-parameter - - -if __name__ == "__main__": - main() diff --git a/src/aiet/resources/applications/.gitignore b/src/aiet/resources/applications/.gitignore deleted file mode 100644 index 0226166..0000000 --- a/src/aiet/resources/applications/.gitignore +++ /dev/null @@ -1,6 +0,0 @@ -# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. -# SPDX-License-Identifier: Apache-2.0 -# Ignore everything in this directory -* -# Except this file -!.gitignore diff --git a/src/aiet/resources/systems/.gitignore b/src/aiet/resources/systems/.gitignore deleted file mode 100644 index 0226166..0000000 --- a/src/aiet/resources/systems/.gitignore +++ /dev/null @@ -1,6 +0,0 @@ -# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. -# SPDX-License-Identifier: Apache-2.0 -# Ignore everything in this directory -* -# Except this file -!.gitignore diff --git a/src/aiet/resources/tools/vela/aiet-config.json b/src/aiet/resources/tools/vela/aiet-config.json deleted file mode 100644 index c12f291..0000000 --- a/src/aiet/resources/tools/vela/aiet-config.json +++ /dev/null @@ -1,73 +0,0 @@ -[ - { - "name": "vela", - "description": "Neural network model compiler for Arm Ethos-U NPUs", - "supported_systems": [ - { - "name": "Corstone-300: Cortex-M55+Ethos-U55" - }, - { - "name": "Corstone-310: Cortex-M85+Ethos-U55" - }, - { - "name": "Corstone-300: Cortex-M55+Ethos-U65", - "variables": { - "accelerator_config_prefix": "ethos-u65", - "system_config": "Ethos_U65_High_End", - "shared_sram": "U65_Shared_Sram" - }, - "user_params": { - "run": [ - { - "description": "MACs per cycle", - "values": [ - "256", - "512" - ], - "default_value": "512", - "alias": "mac" - } - ] - } - } - ], - "variables": { - "accelerator_config_prefix": "ethos-u55", - "system_config": "Ethos_U55_High_End_Embedded", - "shared_sram": "U55_Shared_Sram" - }, - "commands": { - "run": [ - "run_vela {user_params:input} {user_params:output} --config {tool.config_dir}/vela.ini --accelerator-config {variables:accelerator_config_prefix}-{user_params:mac} --system-config {variables:system_config} --memory-mode {variables:shared_sram} --optimise Performance" - ] - }, - "user_params": { - "run": [ - { - "description": "MACs per cycle", - "values": [ - "32", - "64", - "128", - "256" - ], - "default_value": "128", - "alias": "mac" - }, - { - "name": "--input-model", - "description": "Path to the TFLite model", - "values": [], - "alias": "input" - }, - { - "name": "--output-model", - "description": "Path to the output model file of the vela-optimisation step. 
The vela output is saved in the parent directory.", - "values": [], - "default_value": "output_model.tflite", - "alias": "output" - } - ] - } - } -] diff --git a/src/aiet/resources/tools/vela/aiet-config.json.license b/src/aiet/resources/tools/vela/aiet-config.json.license deleted file mode 100644 index 9b83bfc..0000000 --- a/src/aiet/resources/tools/vela/aiet-config.json.license +++ /dev/null @@ -1,3 +0,0 @@ -SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. - -SPDX-License-Identifier: Apache-2.0 diff --git a/src/aiet/resources/tools/vela/check_model.py b/src/aiet/resources/tools/vela/check_model.py deleted file mode 100644 index 7c700b1..0000000 --- a/src/aiet/resources/tools/vela/check_model.py +++ /dev/null @@ -1,75 +0,0 @@ -# SPDX-FileCopyrightText: Copyright 2020, 2022, Arm Limited and/or its affiliates. -# SPDX-License-Identifier: Apache-2.0 -"""Check if a TFLite model file is Vela-optimised.""" -import struct -from pathlib import Path - -from ethosu.vela.tflite.Model import Model - -from aiet.cli.common import InvalidTFLiteFileError -from aiet.cli.common import ModelOptimisedException -from aiet.utils.fs import read_file_as_bytearray - - -def get_model_from_file(input_model_file: Path) -> Model: - """Generate Model instance from TFLite file using flatc generated code.""" - buffer = read_file_as_bytearray(input_model_file) - try: - model = Model.GetRootAsModel(buffer, 0) - except (TypeError, RuntimeError, struct.error) as tflite_error: - raise InvalidTFLiteFileError( - f"Error reading in model from {input_model_file}." - ) from tflite_error - return model - - -def is_vela_optimised(tflite_model: Model) -> bool: - """Return True if 'ethos-u' custom operator found in the Model.""" - operators = get_operators_from_model(tflite_model) - - custom_codes = get_custom_codes_from_operators(operators) - - return check_custom_codes_for_ethosu(custom_codes) - - -def get_operators_from_model(tflite_model: Model) -> list: - """Return list of the unique operator codes used in the Model.""" - return [ - tflite_model.OperatorCodes(index) - for index in range(tflite_model.OperatorCodesLength()) - ] - - -def get_custom_codes_from_operators(operators: list) -> list: - """Return list of each operator's CustomCode() strings, if they exist.""" - return [ - operator.CustomCode() - for operator in operators - if operator.CustomCode() is not None - ] - - -def check_custom_codes_for_ethosu(custom_codes: list) -> bool: - """Check for existence of ethos-u string in the custom codes.""" - return any( - custom_code_name.decode("utf-8") == "ethos-u" - for custom_code_name in custom_codes - ) - - -def check_model(tflite_file_name: str) -> None: - """Raise an exception if model in given file is Vela optimised.""" - tflite_path = Path(tflite_file_name) - - tflite_model = get_model_from_file(tflite_path) - - if is_vela_optimised(tflite_model): - raise ModelOptimisedException( - f"TFLite model in {tflite_file_name} is already " - f"vela optimised ('ethos-u' custom op detected)." - ) - - print( - f"TFLite model in {tflite_file_name} is not vela optimised " - f"('ethos-u' custom op not detected)." - ) diff --git a/src/aiet/resources/tools/vela/run_vela.py b/src/aiet/resources/tools/vela/run_vela.py deleted file mode 100644 index 2c1b0be..0000000 --- a/src/aiet/resources/tools/vela/run_vela.py +++ /dev/null @@ -1,65 +0,0 @@ -# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. 
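The checker above can also be used on its own; a small sketch (hypothetical model path, assuming the ethos-u vela package that provides the flatbuffer bindings is installed):

    from pathlib import Path

    from aiet.resources.tools.vela.check_model import get_model_from_file
    from aiet.resources.tools.vela.check_model import is_vela_optimised

    model = get_model_from_file(Path("model.tflite"))  # hypothetical input file
    # True only if an 'ethos-u' custom operator is present in the model.
    print(is_vela_optimised(model))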
-# SPDX-License-Identifier: Apache-2.0 -"""Wrapper to only run Vela when the input is not already optimised.""" -import shutil -import subprocess -from pathlib import Path -from typing import Tuple - -import click - -from aiet.cli.common import ModelOptimisedException -from aiet.resources.tools.vela.check_model import check_model - - -def vela_output_model_path(input_model: str, output_dir: str) -> Path: - """Construct the path to the Vela output file.""" - in_path = Path(input_model) - tflite_vela = Path(output_dir) / f"{in_path.stem}_vela{in_path.suffix}" - return tflite_vela - - -def execute_vela(vela_args: Tuple, output_dir: Path, input_model: str) -> None: - """Execute vela as external call.""" - cmd = ["vela"] + list(vela_args) - cmd += ["--output-dir", str(output_dir)] # Re-add parsed out_dir to arguments - cmd += [input_model] - subprocess.run(cmd, check=True) - - -@click.command(context_settings=dict(ignore_unknown_options=True)) -@click.option( - "--input-model", - "-i", - type=click.Path(exists=True, file_okay=True, readable=True), - required=True, -) -@click.option("--output-model", "-o", type=click.Path(), required=True) -# Collect the remaining arguments to be directly forwarded to Vela -@click.argument("vela-args", nargs=-1, type=click.UNPROCESSED) -def run_vela(input_model: str, output_model: str, vela_args: Tuple) -> None: - """Check input, run Vela (if needed) and copy optimised file to destination.""" - output_dir = Path(output_model).parent - try: - check_model(input_model) # raises an exception if already Vela-optimised - execute_vela(vela_args, output_dir, input_model) - print("Vela optimisation complete.") - src_model = vela_output_model_path(input_model, str(output_dir)) - except ModelOptimisedException as ex: - # Input already optimized: copy input file to destination path and return - print(f"Input already vela-optimised.\n{ex}") - src_model = Path(input_model) - except subprocess.CalledProcessError as ex: - print(ex) - raise SystemExit(ex.returncode) from ex - - try: - shutil.copyfile(src_model, output_model) - except (shutil.SameFileError, OSError) as ex: - print(ex) - raise SystemExit(ex.errno) from ex - - -def main() -> None: - """Entry point of check_model application.""" - run_vela() # pylint: disable=no-value-for-parameter diff --git a/src/aiet/resources/tools/vela/vela.ini b/src/aiet/resources/tools/vela/vela.ini deleted file mode 100644 index 5996553..0000000 --- a/src/aiet/resources/tools/vela/vela.ini +++ /dev/null @@ -1,53 +0,0 @@ -; SPDX-FileCopyrightText: Copyright 2021-2022, Arm Limited and/or its affiliates. 
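Note the naming convention implemented by the wrapper above: Vela writes '<stem>_vela<suffix>' into the output directory and the wrapper then copies that file to the path given via --output-model. A quick sketch of the helper (hypothetical paths):

    from aiet.resources.tools.vela.run_vela import vela_output_model_path

    # 'model.tflite' optimised into the 'out' directory is produced as
    # 'out/model_vela.tflite' before being copied to the requested output path.
    print(vela_output_model_path("model.tflite", "out"))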
-; SPDX-License-Identifier: Apache-2.0 - -; ----------------------------------------------------------------------------- -; Vela configuration file - -; ----------------------------------------------------------------------------- -; System Configuration - -; Ethos-U55 High-End Embedded: SRAM (4 GB/s) and Flash (0.5 GB/s) -[System_Config.Ethos_U55_High_End_Embedded] -core_clock=500e6 -axi0_port=Sram -axi1_port=OffChipFlash -Sram_clock_scale=1.0 -Sram_burst_length=32 -Sram_read_latency=32 -Sram_write_latency=32 -OffChipFlash_clock_scale=0.125 -OffChipFlash_burst_length=128 -OffChipFlash_read_latency=64 -OffChipFlash_write_latency=64 - -; Ethos-U65 High-End: SRAM (16 GB/s) and DRAM (3.75 GB/s) -[System_Config.Ethos_U65_High_End] -core_clock=1e9 -axi0_port=Sram -axi1_port=Dram -Sram_clock_scale=1.0 -Sram_burst_length=32 -Sram_read_latency=32 -Sram_write_latency=32 -Dram_clock_scale=0.234375 -Dram_burst_length=128 -Dram_read_latency=500 -Dram_write_latency=250 - -; ----------------------------------------------------------------------------- -; Memory Mode - -; Shared SRAM: the SRAM is shared between the Ethos-U and the Cortex-M software -; The non-SRAM memory is assumed to be read-only -[Memory_Mode.U55_Shared_Sram] -const_mem_area=Axi1 -arena_mem_area=Axi0 -cache_mem_area=Axi0 -arena_cache_size=4194304 - -[Memory_Mode.U65_Shared_Sram] -const_mem_area=Axi1 -arena_mem_area=Axi0 -cache_mem_area=Axi0 -arena_cache_size=2097152 diff --git a/src/aiet/utils/__init__.py b/src/aiet/utils/__init__.py deleted file mode 100644 index fc7ef7c..0000000 --- a/src/aiet/utils/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. -# SPDX-License-Identifier: Apache-2.0 -"""This module contains all utils shared across aiet project.""" diff --git a/src/aiet/utils/fs.py b/src/aiet/utils/fs.py deleted file mode 100644 index ea99a69..0000000 --- a/src/aiet/utils/fs.py +++ /dev/null @@ -1,116 +0,0 @@ -# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. -# SPDX-License-Identifier: Apache-2.0 -"""Module to host all file system related functions.""" -import importlib.resources as pkg_resources -import re -import shutil -from pathlib import Path -from typing import Any -from typing import Literal -from typing import Optional - -ResourceType = Literal["applications", "systems", "tools"] - - -def get_aiet_resources() -> Path: - """Get resources folder path.""" - with pkg_resources.path("aiet", "__init__.py") as init_path: - project_root = init_path.parent - return project_root / "resources" - - -def get_resources(name: ResourceType) -> Path: - """Return the absolute path of the specified resource. - - It uses importlib to return resources packaged with MANIFEST.in. 
- """ - if not name: - raise ResourceWarning("Resource name is not provided") - - resource_path = get_aiet_resources() / name - if resource_path.is_dir(): - return resource_path - - raise ResourceWarning("Resource '{}' not found.".format(name)) - - -def copy_directory_content(source: Path, destination: Path) -> None: - """Copy content of the source directory into destination directory.""" - for item in source.iterdir(): - src = source / item.name - dest = destination / item.name - - if src.is_dir(): - shutil.copytree(src, dest) - else: - shutil.copy2(src, dest) - - -def remove_resource(resource_directory: str, resource_type: ResourceType) -> None: - """Remove resource data.""" - resources = get_resources(resource_type) - - resource_location = resources / resource_directory - if not resource_location.exists(): - raise Exception("Resource {} does not exist".format(resource_directory)) - - if not resource_location.is_dir(): - raise Exception("Wrong resource {}".format(resource_directory)) - - shutil.rmtree(resource_location) - - -def remove_directory(directory_path: Optional[Path]) -> None: - """Remove directory.""" - if not directory_path or not directory_path.is_dir(): - raise Exception("No directory path provided") - - shutil.rmtree(directory_path) - - -def recreate_directory(directory_path: Optional[Path]) -> None: - """Recreate directory.""" - if not directory_path: - raise Exception("No directory path provided") - - if directory_path.exists() and not directory_path.is_dir(): - raise Exception( - "Path {} does exist and it is not a directory".format(str(directory_path)) - ) - - if directory_path.is_dir(): - remove_directory(directory_path) - - directory_path.mkdir() - - -def read_file(file_path: Path, mode: Optional[str] = None) -> Any: - """Read file as string or bytearray.""" - if file_path.is_file(): - if mode is not None: - # Ignore pylint warning because mode can be 'binary' as well which - # is not compatible with specifying encodings. - with open(file_path, mode) as file: # pylint: disable=unspecified-encoding - return file.read() - else: - with open(file_path, encoding="utf-8") as file: - return file.read() - - if mode == "rb": - return b"" - return "" - - -def read_file_as_string(file_path: Path) -> str: - """Read file as string.""" - return str(read_file(file_path)) - - -def read_file_as_bytearray(file_path: Path) -> bytearray: - """Read a file as bytearray.""" - return bytearray(read_file(file_path, mode="rb")) - - -def valid_for_filename(value: str, replacement: str = "") -> str: - """Replace non alpha numeric characters.""" - return re.sub(r"[^\w.]", replacement, value, flags=re.ASCII) diff --git a/src/aiet/utils/helpers.py b/src/aiet/utils/helpers.py deleted file mode 100644 index 6d3cd22..0000000 --- a/src/aiet/utils/helpers.py +++ /dev/null @@ -1,17 +0,0 @@ -# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. 
-# SPDX-License-Identifier: Apache-2.0 -"""Helpers functions.""" -import logging -from typing import Any - - -def set_verbosity( - ctx: Any, option: Any, verbosity: Any # pylint: disable=unused-argument -) -> None: - """Set the logging level according to the verbosity.""" - # Unused arguments must be present here in definition as these are required in - # function definition when set as a callback - if verbosity == 1: - logging.getLogger().setLevel(logging.INFO) - elif verbosity > 1: - logging.getLogger().setLevel(logging.DEBUG) diff --git a/src/aiet/utils/proc.py b/src/aiet/utils/proc.py deleted file mode 100644 index b6f4357..0000000 --- a/src/aiet/utils/proc.py +++ /dev/null @@ -1,283 +0,0 @@ -# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. -# SPDX-License-Identifier: Apache-2.0 -"""Processes module. - -This module contains all classes and functions for dealing with Linux -processes. -""" -import csv -import datetime -import logging -import shlex -import signal -import time -from pathlib import Path -from typing import Any -from typing import List -from typing import NamedTuple -from typing import Optional -from typing import Tuple - -import psutil -from sh import Command -from sh import CommandNotFound -from sh import ErrorReturnCode -from sh import RunningCommand - -from aiet.utils.fs import valid_for_filename - - -class CommandFailedException(Exception): - """Exception for failed command execution.""" - - -class ShellCommand: - """Wrapper class for shell commands.""" - - def __init__(self, base_log_path: str = "/tmp") -> None: - """Initialise the class. - - base_log_path: it is the base directory where logs will be stored - """ - self.base_log_path = base_log_path - - def run( - self, - cmd: str, - *args: str, - _cwd: Optional[Path] = None, - _tee: bool = True, - _bg: bool = True, - _out: Any = None, - _err: Any = None, - _search_paths: Optional[List[Path]] = None - ) -> RunningCommand: - """Run the shell command with the given arguments. - - There are special arguments that modify the behaviour of the process. - _cwd: current working directory - _tee: it redirects the stdout both to console and file - _bg: if True, it runs the process in background and the command is not - blocking. 
- _out: use this object for stdout redirect, - _err: use this object for stderr redirect, - _search_paths: If presented used for searching executable - """ - try: - kwargs = {} - if _cwd: - kwargs["_cwd"] = str(_cwd) - command = Command(cmd, _search_paths).bake(args, **kwargs) - except CommandNotFound as error: - logging.error("Command '%s' not found", error.args[0]) - raise error - - out, err = _out, _err - if not _out and not _err: - out, err = [ - str(item) - for item in self.get_stdout_stderr_paths(self.base_log_path, cmd) - ] - - return command(_out=out, _err=err, _tee=_tee, _bg=_bg, _bg_exc=False) - - @classmethod - def get_stdout_stderr_paths(cls, base_log_path: str, cmd: str) -> Tuple[Path, Path]: - """Construct and returns the paths of stdout/stderr files.""" - timestamp = datetime.datetime.now().timestamp() - base_path = Path(base_log_path) - base = base_path / "{}_{}".format(valid_for_filename(cmd, "_"), timestamp) - stdout = base.with_suffix(".out") - stderr = base.with_suffix(".err") - try: - stdout.touch() - stderr.touch() - except FileNotFoundError as error: - logging.error("File not found: %s", error.filename) - raise error - return stdout, stderr - - -def parse_command(command: str, shell: str = "bash") -> List[str]: - """Parse command.""" - cmd, *args = shlex.split(command, posix=True) - - if is_shell_script(cmd): - args = [cmd] + args - cmd = shell - - return [cmd] + args - - -def get_stdout_stderr_paths( - command: str, base_log_path: str = "/tmp" -) -> Tuple[Path, Path]: - """Construct and returns the paths of stdout/stderr files.""" - cmd, *_ = parse_command(command) - - return ShellCommand.get_stdout_stderr_paths(base_log_path, cmd) - - -def execute_command( # pylint: disable=invalid-name - command: str, - cwd: Path, - bg: bool = False, - shell: str = "bash", - out: Any = None, - err: Any = None, -) -> RunningCommand: - """Execute shell command.""" - cmd, *args = parse_command(command, shell) - - search_paths = None - if cmd != shell and (cwd / cmd).is_file(): - search_paths = [cwd] - - return ShellCommand().run( - cmd, *args, _cwd=cwd, _bg=bg, _search_paths=search_paths, _out=out, _err=err - ) - - -def is_shell_script(cmd: str) -> bool: - """Check if command is shell script.""" - return cmd.endswith(".sh") - - -def run_and_wait( - command: str, - cwd: Path, - terminate_on_error: bool = False, - out: Any = None, - err: Any = None, -) -> Tuple[int, bytearray, bytearray]: - """ - Run command and wait while it is executing. - - Returns a tuple: (exit_code, stdout, stderr) - """ - running_cmd: Optional[RunningCommand] = None - try: - running_cmd = execute_command(command, cwd, bg=True, out=out, err=err) - return running_cmd.exit_code, running_cmd.stdout, running_cmd.stderr - except ErrorReturnCode as cmd_failed: - raise CommandFailedException() from cmd_failed - except Exception as error: - is_running = running_cmd is not None and running_cmd.is_alive() - if terminate_on_error and is_running: - print("Terminating ...") - terminate_command(running_cmd) - - raise error - - -def terminate_command( - running_cmd: RunningCommand, - wait: bool = True, - wait_period: float = 0.5, - number_of_attempts: int = 20, -) -> None: - """Terminate running command.""" - try: - running_cmd.process.signal_group(signal.SIGINT) - if wait: - for _ in range(number_of_attempts): - time.sleep(wait_period) - if not running_cmd.is_alive(): - return - print( - "Unable to terminate process {}. 
Sending SIGTERM...".format( - running_cmd.process.pid - ) - ) - running_cmd.process.signal_group(signal.SIGTERM) - except ProcessLookupError: - pass - - -def terminate_external_process( - process: psutil.Process, - wait_period: float = 0.5, - number_of_attempts: int = 20, - wait_for_termination: float = 5.0, -) -> None: - """Terminate external process.""" - try: - process.terminate() - for _ in range(number_of_attempts): - if not process.is_running(): - return - time.sleep(wait_period) - - if process.is_running(): - process.terminate() - time.sleep(wait_for_termination) - except psutil.Error: - print("Unable to terminate process") - - -class ProcessInfo(NamedTuple): - """Process information.""" - - name: str - executable: str - cwd: str - pid: int - - -def save_process_info(pid: int, pid_file: Path) -> None: - """Save process information to file.""" - try: - parent = psutil.Process(pid) - children = parent.children(recursive=True) - family = [parent] + children - - with open(pid_file, "w", encoding="utf-8") as file: - csv_writer = csv.writer(file) - for member in family: - process_info = ProcessInfo( - name=member.name(), - executable=member.exe(), - cwd=member.cwd(), - pid=member.pid, - ) - csv_writer.writerow(process_info) - except psutil.NoSuchProcess: - # if process does not exist or finishes before - # function call then nothing could be saved - # just ignore this exception and exit - pass - - -def read_process_info(pid_file: Path) -> List[ProcessInfo]: - """Read information about previous system processes.""" - if not pid_file.is_file(): - return [] - - result = [] - with open(pid_file, encoding="utf-8") as file: - csv_reader = csv.reader(file) - for row in csv_reader: - name, executable, cwd, pid = row - result.append( - ProcessInfo(name=name, executable=executable, cwd=cwd, pid=int(pid)) - ) - - return result - - -def print_command_stdout(command: RunningCommand) -> None: - """Print the stdout of a command. - - The command has 2 states: running and done. - If the command is running, the output is taken by the running process. - If the command has ended its execution, the stdout is taken from stdout - property - """ - if command.is_alive(): - while True: - try: - print(command.next(), end="") - except StopIteration: - break - else: - print(command.stdout) diff --git a/src/mlia/backend/__init__.py b/src/mlia/backend/__init__.py new file mode 100644 index 0000000..3d60372 --- /dev/null +++ b/src/mlia/backend/__init__.py @@ -0,0 +1,3 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Backend module.""" diff --git a/src/mlia/backend/application.py b/src/mlia/backend/application.py new file mode 100644 index 0000000..eb85212 --- /dev/null +++ b/src/mlia/backend/application.py @@ -0,0 +1,187 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. 
+# SPDX-License-Identifier: Apache-2.0 +"""Application backend module.""" +import re +from pathlib import Path +from typing import Any +from typing import cast +from typing import Dict +from typing import List +from typing import Optional + +from mlia.backend.common import Backend +from mlia.backend.common import ConfigurationException +from mlia.backend.common import DataPaths +from mlia.backend.common import get_backend_configs +from mlia.backend.common import get_backend_directories +from mlia.backend.common import load_application_or_tool_configs +from mlia.backend.common import load_config +from mlia.backend.common import remove_backend +from mlia.backend.config import ApplicationConfig +from mlia.backend.config import ExtendedApplicationConfig +from mlia.backend.fs import get_backends_path +from mlia.backend.source import create_destination_and_install +from mlia.backend.source import get_source + + +def get_available_application_directory_names() -> List[str]: + """Return a list of directory names for all available applications.""" + return [entry.name for entry in get_backend_directories("applications")] + + +def get_available_applications() -> List["Application"]: + """Return a list with all available applications.""" + available_applications = [] + for config_json in get_backend_configs("applications"): + config_entries = cast(List[ExtendedApplicationConfig], load_config(config_json)) + for config_entry in config_entries: + config_entry["config_location"] = config_json.parent.absolute() + applications = load_applications(config_entry) + available_applications += applications + + return sorted(available_applications, key=lambda application: application.name) + + +def get_application( + application_name: str, system_name: Optional[str] = None +) -> List["Application"]: + """Return a list of application instances with provided name.""" + return [ + application + for application in get_available_applications() + if application.name == application_name + and (not system_name or application.can_run_on(system_name)) + ] + + +def install_application(source_path: Path) -> None: + """Install application.""" + try: + source = get_source(source_path) + config = cast(List[ExtendedApplicationConfig], source.config()) + applications_to_install = [ + s for entry in config for s in load_applications(entry) + ] + except Exception as error: + raise ConfigurationException("Unable to read application definition") from error + + if not applications_to_install: + raise ConfigurationException("No application definition found") + + available_applications = get_available_applications() + already_installed = [ + s for s in applications_to_install if s in available_applications + ] + if already_installed: + names = {application.name for application in already_installed} + raise ConfigurationException( + "Applications [{}] are already installed".format(",".join(names)) + ) + + create_destination_and_install(source, get_backends_path("applications")) + + +def remove_application(directory_name: str) -> None: + """Remove application directory.""" + remove_backend(directory_name, "applications") + + +def get_unique_application_names(system_name: Optional[str] = None) -> List[str]: + """Extract a list of unique application names of all application available.""" + return list( + set( + application.name + for application in get_available_applications() + if not system_name or application.can_run_on(system_name) + ) + ) + + +class Application(Backend): + """Class for representing a single application component.""" + 
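A usage sketch of the module-level helpers above (assuming backend resources are installed; the application and system names are placeholders):

    from mlia.backend.application import get_application
    from mlia.backend.application import get_unique_application_names

    print(get_unique_application_names())

    # One Application instance is returned per supported system, so filter by
    # system name to narrow the result.
    for application in get_application("my_app", "My System"):
        print(application.get_details())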
+ def __init__(self, config: ApplicationConfig) -> None: + """Construct a Application instance from a dict.""" + super().__init__(config) + + self.supported_systems = config.get("supported_systems", []) + self.deploy_data = config.get("deploy_data", []) + + def __eq__(self, other: object) -> bool: + """Overload operator ==.""" + if not isinstance(other, Application): + return False + + return ( + super().__eq__(other) + and self.name == other.name + and set(self.supported_systems) == set(other.supported_systems) + ) + + def can_run_on(self, system_name: str) -> bool: + """Check if the application can run on the system passed as argument.""" + return system_name in self.supported_systems + + def get_deploy_data(self) -> List[DataPaths]: + """Validate and return data specified in the config file.""" + if self.config_location is None: + raise ConfigurationException( + "Unable to get application {} config location".format(self.name) + ) + + deploy_data = [] + for item in self.deploy_data: + src, dst = item + src_full_path = self.config_location / src + assert src_full_path.exists(), "{} does not exists".format(src_full_path) + deploy_data.append(DataPaths(src_full_path, dst)) + return deploy_data + + def get_details(self) -> Dict[str, Any]: + """Return dictionary with information about the Application instance.""" + output = { + "type": "application", + "name": self.name, + "description": self.description, + "supported_systems": self.supported_systems, + "commands": self._get_command_details(), + } + + return output + + def remove_unused_params(self) -> None: + """Remove unused params in commands. + + After merging default and system related configuration application + could have parameters that are not being used in commands. They + should be removed. + """ + for command in self.commands.values(): + indexes_or_aliases = [ + m + for cmd_str in command.command_strings + for m in re.findall(r"{user_params:(?P<index_or_alias>\w+)}", cmd_str) + ] + + only_aliases = all(not item.isnumeric() for item in indexes_or_aliases) + if only_aliases: + used_params = [ + param + for param in command.params + if param.alias in indexes_or_aliases + ] + command.params = used_params + + +def load_applications(config: ExtendedApplicationConfig) -> List[Application]: + """Load application. + + Application configuration could contain different parameters/commands for different + supported systems. For each supported system this function will return separate + Application instance with appropriate configuration. + """ + configs = load_application_or_tool_configs(config, ApplicationConfig) + applications = [Application(cfg) for cfg in configs] + for application in applications: + application.remove_unused_params() + return applications diff --git a/src/mlia/backend/common.py b/src/mlia/backend/common.py new file mode 100644 index 0000000..2bbb9d3 --- /dev/null +++ b/src/mlia/backend/common.py @@ -0,0 +1,532 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
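To make the pruning in remove_unused_params above concrete: command strings reference parameters by alias, as in the vela configuration earlier in this patch, so only Param entries whose alias appears in a command survive the merge. A standalone sketch of the same extraction:

    import re

    cmd_str = (
        "run_vela {user_params:input} {user_params:output} --optimise Performance"
    )
    # Collect the referenced aliases (or positional indexes) from a command string.
    print(re.findall(r"{user_params:(?P<index_or_alias>\w+)}", cmd_str))
    # -> ['input', 'output']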
+# SPDX-License-Identifier: Apache-2.0 +"""Contain all common functions for the backends.""" +import json +import logging +import re +from abc import ABC +from collections import Counter +from pathlib import Path +from typing import Any +from typing import Callable +from typing import cast +from typing import Dict +from typing import Final +from typing import IO +from typing import Iterable +from typing import List +from typing import Match +from typing import NamedTuple +from typing import Optional +from typing import Pattern +from typing import Tuple +from typing import Type +from typing import Union + +from mlia.backend.config import BackendConfig +from mlia.backend.config import BaseBackendConfig +from mlia.backend.config import NamedExecutionConfig +from mlia.backend.config import UserParamConfig +from mlia.backend.config import UserParamsConfig +from mlia.backend.fs import get_backends_path +from mlia.backend.fs import remove_resource +from mlia.backend.fs import ResourceType + + +BACKEND_CONFIG_FILE: Final[str] = "aiet-config.json" + + +class ConfigurationException(Exception): + """Configuration exception.""" + + +def get_backend_config(dir_path: Path) -> Path: + """Get path to backendir configuration file.""" + return dir_path / BACKEND_CONFIG_FILE + + +def get_backend_configs(resource_type: ResourceType) -> Iterable[Path]: + """Get path to the backend configs for provided resource_type.""" + return ( + get_backend_config(entry) for entry in get_backend_directories(resource_type) + ) + + +def get_backend_directories(resource_type: ResourceType) -> Iterable[Path]: + """Get path to the backend directories for provided resource_type.""" + return ( + entry + for entry in get_backends_path(resource_type).iterdir() + if is_backend_directory(entry) + ) + + +def is_backend_directory(dir_path: Path) -> bool: + """Check if path is backend's configuration directory.""" + return dir_path.is_dir() and get_backend_config(dir_path).is_file() + + +def remove_backend(directory_name: str, resource_type: ResourceType) -> None: + """Remove backend with provided type and directory_name.""" + if not directory_name: + raise Exception("No directory name provided") + + remove_resource(directory_name, resource_type) + + +def load_config(config: Union[None, Path, IO[bytes]]) -> BackendConfig: + """Return a loaded json file.""" + if config is None: + raise Exception("Unable to read config") + + if isinstance(config, Path): + with config.open() as json_file: + return cast(BackendConfig, json.load(json_file)) + + return cast(BackendConfig, json.load(config)) + + +def parse_raw_parameter(parameter: str) -> Tuple[str, Optional[str]]: + """Split the parameter string in name and optional value. 
+ + It manages the following cases: + --param=1 -> --param, 1 + --param 1 -> --param, 1 + --flag -> --flag, None + """ + data = re.split(" |=", parameter) + if len(data) == 1: + param_name = data[0] + param_value = None + else: + param_name = " ".join(data[0:-1]) + param_value = data[-1] + return param_name, param_value + + +class DataPaths(NamedTuple): + """DataPaths class.""" + + src: Path + dst: str + + +class Backend(ABC): + """Backend class.""" + + # pylint: disable=too-many-instance-attributes + + def __init__(self, config: BaseBackendConfig): + """Initialize backend.""" + name = config.get("name") + if not name: + raise ConfigurationException("Name is empty") + + self.name = name + self.description = config.get("description", "") + self.config_location = config.get("config_location") + self.variables = config.get("variables", {}) + self.build_dir = config.get("build_dir") + self.lock = config.get("lock", False) + if self.build_dir: + self.build_dir = self._substitute_variables(self.build_dir) + self.annotations = config.get("annotations", {}) + + self._parse_commands_and_params(config) + + def validate_parameter(self, command_name: str, parameter: str) -> bool: + """Validate the parameter string against the application configuration. + + We take the parameter string, extract the parameter name/value and + check them against the current configuration. + """ + param_name, param_value = parse_raw_parameter(parameter) + valid_param_name = valid_param_value = False + + command = self.commands.get(command_name) + if not command: + raise AttributeError("Unknown command: '{}'".format(command_name)) + + # Iterate over all available parameters until we have a match. + for param in command.params: + if self._same_parameter(param_name, param): + valid_param_name = True + # This is a non-empty list + if param.values: + # We check if the value is allowed in the configuration + valid_param_value = param_value in param.values + else: + # In this case we don't validate the value and accept + # whatever we have set. + valid_param_value = True + break + + return valid_param_name and valid_param_value + + def __eq__(self, other: object) -> bool: + """Overload operator ==.""" + if not isinstance(other, Backend): + return False + + return ( + self.name == other.name + and self.description == other.description + and self.commands == other.commands + ) + + def __repr__(self) -> str: + """Represent the Backend instance by its name.""" + return self.name + + def _parse_commands_and_params(self, config: BaseBackendConfig) -> None: + """Parse commands and user parameters.""" + self.commands: Dict[str, Command] = {} + + commands = config.get("commands") + if commands: + params = config.get("user_params") + + for command_name in commands.keys(): + command_params = self._parse_params(params, command_name) + command_strings = [ + self._substitute_variables(cmd) + for cmd in commands.get(command_name, []) + ] + self.commands[command_name] = Command(command_strings, command_params) + + def _substitute_variables(self, str_val: str) -> str: + """Substitute variables in string. + + Variables is being substituted at backend's creation stage because + they could contain references to other params which will be + resolved later. 
+        """
+        if not str_val:
+            return str_val
+
+        var_pattern: Final[Pattern] = re.compile(r"{variables:(?P<var_name>\w+)}")
+
+        def var_value(match: Match) -> str:
+            var_name = match["var_name"]
+            if var_name not in self.variables:
+                raise ConfigurationException("Unknown variable {}".format(var_name))
+
+            return self.variables[var_name]
+
+        return var_pattern.sub(var_value, str_val)  # type: ignore
+
+    @classmethod
+    def _parse_params(
+        cls, params: Optional[UserParamsConfig], command: str
+    ) -> List["Param"]:
+        if not params:
+            return []
+
+        return [cls._parse_param(p) for p in params.get(command, [])]
+
+    @classmethod
+    def _parse_param(cls, param: UserParamConfig) -> "Param":
+        """Parse a single parameter."""
+        name = param.get("name")
+        if name is not None and not name:
+            raise ConfigurationException("Parameter has an empty 'name' attribute.")
+        values = param.get("values", None)
+        default_value = param.get("default_value", None)
+        description = param.get("description", "")
+        alias = param.get("alias")
+
+        return Param(
+            name=name,
+            description=description,
+            values=values,
+            default_value=default_value,
+            alias=alias,
+        )
+
+    def _get_command_details(self) -> Dict:
+        command_details = {
+            command_name: command.get_details()
+            for command_name, command in self.commands.items()
+        }
+        return command_details
+
+    def _get_user_param_value(
+        self, user_params: List[str], param: "Param"
+    ) -> Optional[str]:
+        """Get the user-specified value of a parameter."""
+        for user_param in user_params:
+            user_param_name, user_param_value = parse_raw_parameter(user_param)
+            if user_param_name == param.name:
+                warn_message = (
+                    "The direct use of parameter name is deprecated"
+                    " and might be removed in the future.\n"
+                    f"Please use alias '{param.alias}' instead of "
+                    f"'{user_param_name}' to provide the parameter."
+                )
+                logging.warning(warn_message)
+
+            if self._same_parameter(user_param_name, param):
+                return user_param_value
+
+        return None
+
+    @staticmethod
+    def _same_parameter(user_param_name_or_alias: str, param: "Param") -> bool:
+        """Compare user parameter name with param name or alias."""
+        # Strip the "=" sign in the param_name. This is needed just for
+        # comparison with the parameters passed by the user.
+        # The equal sign needs to be honoured when re-building the
+        # parameter back.
+        param_name = None if not param.name else param.name.rstrip("=")
+        return user_param_name_or_alias in [param_name, param.alias]
+
+    def resolved_parameters(
+        self, command_name: str, user_params: List[str]
+    ) -> List[Tuple[Optional[str], "Param"]]:
+        """Return list of parameters with values."""
+        result: List[Tuple[Optional[str], "Param"]] = []
+        command = self.commands.get(command_name)
+        if not command:
+            return result
+
+        for param in command.params:
+            value = self._get_user_param_value(user_params, param)
+            if not value:
+                value = param.default_value
+            result.append((value, param))
+
+        return result
+
+    def build_command(
+        self,
+        command_name: str,
+        user_params: List[str],
+        param_resolver: Callable[[str, str, List[Tuple[Optional[str], "Param"]]], str],
+    ) -> List[str]:
+        """
+        Return a list of executable command strings.
+
+        Given a command and associated parameters, returns a list of executable command
+        strings.
+ """ + command = self.commands.get(command_name) + if not command: + raise ConfigurationException( + "Command '{}' could not be found.".format(command_name) + ) + + commands_to_run = [] + + params_values = self.resolved_parameters(command_name, user_params) + for cmd_str in command.command_strings: + cmd_str = resolve_all_parameters( + cmd_str, param_resolver, command_name, params_values + ) + commands_to_run.append(cmd_str) + + return commands_to_run + + +class Param: + """Class for representing a generic application parameter.""" + + def __init__( # pylint: disable=too-many-arguments + self, + name: Optional[str], + description: str, + values: Optional[List[str]] = None, + default_value: Optional[str] = None, + alias: Optional[str] = None, + ) -> None: + """Construct a Param instance.""" + if not name and not alias: + raise ConfigurationException( + "Either name, alias or both must be set to identify a parameter." + ) + self.name = name + self.values = values + self.description = description + self.default_value = default_value + self.alias = alias + + def get_details(self) -> Dict: + """Return a dictionary with all relevant information of a Param.""" + return {key: value for key, value in self.__dict__.items() if value} + + def __eq__(self, other: object) -> bool: + """Overload operator ==.""" + if not isinstance(other, Param): + return False + + return ( + self.name == other.name + and self.values == other.values + and self.default_value == other.default_value + and self.description == other.description + ) + + +class Command: + """Class for representing a command.""" + + def __init__( + self, command_strings: List[str], params: Optional[List[Param]] = None + ) -> None: + """Construct a Command instance.""" + self.command_strings = command_strings + + if params: + self.params = params + else: + self.params = [] + + self._validate() + + def _validate(self) -> None: + """Validate command.""" + if not self.params: + return + + aliases = [param.alias for param in self.params if param.alias is not None] + repeated_aliases = [ + alias for alias, count in Counter(aliases).items() if count > 1 + ] + + if repeated_aliases: + raise ConfigurationException( + "Non unique aliases {}".format(", ".join(repeated_aliases)) + ) + + both_name_and_alias = [ + param.name + for param in self.params + if param.name in aliases and param.name != param.alias + ] + if both_name_and_alias: + raise ConfigurationException( + "Aliases {} could not be used as parameter name".format( + ", ".join(both_name_and_alias) + ) + ) + + def get_details(self) -> Dict: + """Return a dictionary with all relevant information of a Command.""" + output = { + "command_strings": self.command_strings, + "user_params": [param.get_details() for param in self.params], + } + return output + + def __eq__(self, other: object) -> bool: + """Overload operator ==.""" + if not isinstance(other, Command): + return False + + return ( + self.command_strings == other.command_strings + and self.params == other.params + ) + + +def resolve_all_parameters( + str_val: str, + param_resolver: Callable[[str, str, List[Tuple[Optional[str], Param]]], str], + command_name: Optional[str] = None, + params_values: Optional[List[Tuple[Optional[str], Param]]] = None, +) -> str: + """Resolve all parameters in the string.""" + if not str_val: + return str_val + + param_pattern: Final[Pattern] = re.compile(r"{(?P[\w.:]+)}") + while param_pattern.findall(str_val): + str_val = param_pattern.sub( + lambda m: param_resolver( + m["param_name"], command_name or "", 
params_values or [] + ), + str_val, + ) + return str_val + + +def load_application_or_tool_configs( + config: Any, + config_type: Type[Any], + is_system_required: bool = True, +) -> Any: + """Get one config for each system supported by the application/tool. + + The configuration could contain different parameters/commands for different + supported systems. For each supported system this function will return separate + config with appropriate configuration. + """ + merged_configs = [] + supported_systems: Optional[List[NamedExecutionConfig]] = config.get( + "supported_systems" + ) + if not supported_systems: + if is_system_required: + raise ConfigurationException("No supported systems definition provided") + # Create an empty system to be used in the parsing below + supported_systems = [cast(NamedExecutionConfig, {})] + + default_user_params = config.get("user_params", {}) + + def merge_config(system: NamedExecutionConfig) -> Any: + system_name = system.get("name") + if not system_name and is_system_required: + raise ConfigurationException( + "Unable to read supported system definition, name is missed" + ) + + merged_config = config_type(**config) + merged_config["supported_systems"] = [system_name] if system_name else [] + # merge default configuration and specific to the system + merged_config["commands"] = { + **config.get("commands", {}), + **system.get("commands", {}), + } + + params = {} + tool_user_params = system.get("user_params", {}) + command_names = tool_user_params.keys() | default_user_params.keys() + for command_name in command_names: + if command_name not in merged_config["commands"]: + continue + + params_default = default_user_params.get(command_name, []) + params_tool = tool_user_params.get(command_name, []) + if not params_default or not params_tool: + params[command_name] = params_tool or params_default + if params_default and params_tool: + if any(not p.get("alias") for p in params_default): + raise ConfigurationException( + "Default parameters for command {} should have aliases".format( + command_name + ) + ) + if any(not p.get("alias") for p in params_tool): + raise ConfigurationException( + "{} parameters for command {} should have aliases".format( + system_name, command_name + ) + ) + + merged_by_alias = { + **{p.get("alias"): p for p in params_default}, + **{p.get("alias"): p for p in params_tool}, + } + params[command_name] = list(merged_by_alias.values()) + + merged_config["user_params"] = params + merged_config["build_dir"] = system.get("build_dir", config.get("build_dir")) + merged_config["lock"] = system.get("lock", config.get("lock", False)) + merged_config["variables"] = { + **config.get("variables", {}), + **system.get("variables", {}), + } + return merged_config + + merged_configs = [merge_config(system) for system in supported_systems] + + return merged_configs diff --git a/src/mlia/backend/config.py b/src/mlia/backend/config.py new file mode 100644 index 0000000..657adef --- /dev/null +++ b/src/mlia/backend/config.py @@ -0,0 +1,93 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. 
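To make the per-system merging above concrete, here is a minimal, hypothetical application configuration in the shape that `load_application_or_tool_configs` consumes. All names, commands and parameters are illustrative only (they do not come from this patch); the TypedDicts defined in `config.py` below describe the same structure.

```python
# Hypothetical config: an application supporting two systems, one of which
# overrides the default "run" command.
app_config = {
    "name": "demo app",
    "description": "Illustration only",
    "build_dir": "build",
    "commands": {"run": ["demo_runner {user_params:0}"]},
    "user_params": {
        "run": [
            {
                "name": "--input",
                "description": "Input model",
                "values": [],
                "alias": "input",
            }
        ]
    },
    "supported_systems": [
        {"name": "System A"},
        {"name": "System B", "commands": {"run": ["demo_runner_b {user_params:0}"]}},
    ],
}

# load_application_or_tool_configs(app_config, ApplicationConfig) would return
# two merged configs: one restricted to "System A" that keeps the default
# "run" command, and one restricted to "System B" that uses its override.
```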
+# SPDX-License-Identifier: Apache-2.0 +"""Contain definition of backend configuration.""" +from pathlib import Path +from typing import Dict +from typing import List +from typing import Literal +from typing import Optional +from typing import Tuple +from typing import TypedDict +from typing import Union + + +class UserParamConfig(TypedDict, total=False): + """User parameter configuration.""" + + name: Optional[str] + default_value: str + values: List[str] + description: str + alias: str + + +UserParamsConfig = Dict[str, List[UserParamConfig]] + + +class ExecutionConfig(TypedDict, total=False): + """Execution configuration.""" + + commands: Dict[str, List[str]] + user_params: UserParamsConfig + build_dir: str + variables: Dict[str, str] + lock: bool + + +class NamedExecutionConfig(ExecutionConfig): + """Execution configuration with name.""" + + name: str + + +class BaseBackendConfig(ExecutionConfig, total=False): + """Base backend configuration.""" + + name: str + description: str + config_location: Path + annotations: Dict[str, Union[str, List[str]]] + + +class ApplicationConfig(BaseBackendConfig, total=False): + """Application configuration.""" + + supported_systems: List[str] + deploy_data: List[Tuple[str, str]] + + +class ExtendedApplicationConfig(BaseBackendConfig, total=False): + """Extended application configuration.""" + + supported_systems: List[NamedExecutionConfig] + deploy_data: List[Tuple[str, str]] + + +class ProtocolConfig(TypedDict, total=False): + """Protocol config.""" + + protocol: Literal["local", "ssh"] + + +class SSHConfig(ProtocolConfig, total=False): + """SSH configuration.""" + + username: str + password: str + hostname: str + port: str + + +class LocalProtocolConfig(ProtocolConfig, total=False): + """Local protocol config.""" + + +class SystemConfig(BaseBackendConfig, total=False): + """System configuration.""" + + data_transfer: Union[SSHConfig, LocalProtocolConfig] + reporting: Dict[str, Dict] + + +BackendItemConfig = Union[ApplicationConfig, SystemConfig] +BackendConfig = Union[List[ExtendedApplicationConfig], List[SystemConfig]] diff --git a/src/mlia/backend/controller.py b/src/mlia/backend/controller.py new file mode 100644 index 0000000..f1b68a9 --- /dev/null +++ b/src/mlia/backend/controller.py @@ -0,0 +1,134 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. 
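As a rough illustration of the TypedDicts just defined, a `SystemConfig` for an SSH-reachable system might look like the sketch below. Every value is a placeholder rather than something defined by this patch, and the snippet assumes the `mlia` package from this change is importable.

```python
from mlia.backend.config import SystemConfig

system_config: SystemConfig = {
    "name": "Example system",
    "description": "Simulator started locally and reached over SSH",
    "commands": {"run": ["./start_simulator.sh"]},
    "data_transfer": {
        "protocol": "ssh",
        "username": "user",
        "password": "secret",
        "hostname": "localhost",
        "port": "8021",
    },
    # The reporting block feeds RegexOutputParser (see output_parser.py).
    "reporting": {
        "regex": {"sim_status": {"pattern": "SIMULATION_STATUS = (.*)", "type": "str"}}
    },
    "lock": True,
}
```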
+# SPDX-License-Identifier: Apache-2.0 +"""Controller backend module.""" +import time +from pathlib import Path +from typing import List +from typing import Optional +from typing import Tuple + +import psutil +import sh + +from mlia.backend.common import ConfigurationException +from mlia.backend.fs import read_file_as_string +from mlia.backend.proc import execute_command +from mlia.backend.proc import get_stdout_stderr_paths +from mlia.backend.proc import read_process_info +from mlia.backend.proc import save_process_info +from mlia.backend.proc import terminate_command +from mlia.backend.proc import terminate_external_process + + +class SystemController: + """System controller class.""" + + def __init__(self) -> None: + """Create new instance of service controller.""" + self.cmd: Optional[sh.RunningCommand] = None + self.out_path: Optional[Path] = None + self.err_path: Optional[Path] = None + + def before_start(self) -> None: + """Run actions before system start.""" + + def after_start(self) -> None: + """Run actions after system start.""" + + def start(self, commands: List[str], cwd: Path) -> None: + """Start system.""" + if not isinstance(cwd, Path) or not cwd.is_dir(): + raise ConfigurationException("Wrong working directory {}".format(cwd)) + + if len(commands) != 1: + raise ConfigurationException("System should have only one command to run") + + startup_command = commands[0] + if not startup_command: + raise ConfigurationException("No startup command provided") + + self.before_start() + + self.out_path, self.err_path = get_stdout_stderr_paths(startup_command) + + self.cmd = execute_command( + startup_command, + cwd, + bg=True, + out=str(self.out_path), + err=str(self.err_path), + ) + + self.after_start() + + def stop( + self, wait: bool = False, wait_period: float = 0.5, number_of_attempts: int = 20 + ) -> None: + """Stop system.""" + if self.cmd is not None and self.is_running(): + terminate_command(self.cmd, wait, wait_period, number_of_attempts) + + def is_running(self) -> bool: + """Check if underlying process is running.""" + return self.cmd is not None and self.cmd.is_alive() + + def get_output(self) -> Tuple[str, str]: + """Return application output.""" + if self.cmd is None or self.out_path is None or self.err_path is None: + return ("", "") + + return (read_file_as_string(self.out_path), read_file_as_string(self.err_path)) + + +class SystemControllerSingleInstance(SystemController): + """System controller with support of system's single instance.""" + + def __init__(self, pid_file_path: Optional[Path] = None) -> None: + """Create new instance of the service controller.""" + super().__init__() + self.pid_file_path = pid_file_path + + def before_start(self) -> None: + """Run actions before system start.""" + self._check_if_previous_instance_is_running() + + def after_start(self) -> None: + """Run actions after system start.""" + self._save_process_info() + + def _check_if_previous_instance_is_running(self) -> None: + """Check if another instance of the system is running.""" + process_info = read_process_info(self._pid_file()) + + for item in process_info: + try: + process = psutil.Process(item.pid) + same_process = ( + process.name() == item.name + and process.exe() == item.executable + and process.cwd() == item.cwd + ) + if same_process: + print( + "Stopping previous instance of the system [{}]".format(item.pid) + ) + terminate_external_process(process) + except psutil.NoSuchProcess: + pass + + def _save_process_info(self, wait_period: float = 2) -> None: + """Save information 
about system's processes.""" + if self.cmd is None or not self.is_running(): + return + + # give some time for the system to start + time.sleep(wait_period) + + save_process_info(self.cmd.process.pid, self._pid_file()) + + def _pid_file(self) -> Path: + """Return path to file which is used for saving process info.""" + if not self.pid_file_path: + raise Exception("No pid file path presented") + + return self.pid_file_path diff --git a/src/mlia/backend/execution.py b/src/mlia/backend/execution.py new file mode 100644 index 0000000..749ccdb --- /dev/null +++ b/src/mlia/backend/execution.py @@ -0,0 +1,779 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Application execution module.""" +import itertools +import json +import random +import re +import string +import sys +import time +from collections import defaultdict +from contextlib import contextmanager +from contextlib import ExitStack +from pathlib import Path +from typing import Any +from typing import Callable +from typing import cast +from typing import ContextManager +from typing import Dict +from typing import Generator +from typing import List +from typing import Optional +from typing import Sequence +from typing import Tuple +from typing import TypedDict + +from filelock import FileLock +from filelock import Timeout + +from mlia.backend.application import Application +from mlia.backend.application import get_application +from mlia.backend.common import Backend +from mlia.backend.common import ConfigurationException +from mlia.backend.common import DataPaths +from mlia.backend.common import Param +from mlia.backend.common import parse_raw_parameter +from mlia.backend.common import resolve_all_parameters +from mlia.backend.fs import recreate_directory +from mlia.backend.fs import remove_directory +from mlia.backend.fs import valid_for_filename +from mlia.backend.output_parser import Base64OutputParser +from mlia.backend.output_parser import OutputParser +from mlia.backend.output_parser import RegexOutputParser +from mlia.backend.proc import run_and_wait +from mlia.backend.system import ControlledSystem +from mlia.backend.system import get_system +from mlia.backend.system import StandaloneSystem +from mlia.backend.system import System + + +class AnotherInstanceIsRunningException(Exception): + """Concurrent execution error.""" + + +class ConnectionException(Exception): + """Connection exception.""" + + +class ExecutionParams(TypedDict, total=False): + """Execution parameters.""" + + disable_locking: bool + unique_build_dir: bool + + +class ExecutionContext: + """Command execution context.""" + + # pylint: disable=too-many-arguments,too-many-instance-attributes + def __init__( + self, + app: Application, + app_params: List[str], + system: Optional[System], + system_params: List[str], + custom_deploy_data: Optional[List[DataPaths]] = None, + execution_params: Optional[ExecutionParams] = None, + report_file: Optional[Path] = None, + ): + """Init execution context.""" + self.app = app + self.app_params = app_params + self.custom_deploy_data = custom_deploy_data or [] + self.system = system + self.system_params = system_params + self.execution_params = execution_params or ExecutionParams() + self.report_file = report_file + + self.reporter: Optional[Reporter] + if self.report_file: + # Create reporter with output parsers + parsers: List[OutputParser] = [] + if system and system.reporting: + # Add RegexOutputParser, if it is configured in the system + 
parsers.append(RegexOutputParser("system", system.reporting["regex"])) + # Add Base64 parser for applications + parsers.append(Base64OutputParser("application")) + self.reporter = Reporter(parsers=parsers) + else: + self.reporter = None # No reporter needed. + + self.param_resolver = ParamResolver(self) + self._resolved_build_dir: Optional[Path] = None + + self.stdout: Optional[bytearray] = None + self.stderr: Optional[bytearray] = None + + @property + def is_deploy_needed(self) -> bool: + """Check if application requires data deployment.""" + return len(self.app.get_deploy_data()) > 0 or len(self.custom_deploy_data) > 0 + + @property + def is_locking_required(self) -> bool: + """Return true if any form of locking required.""" + return not self._disable_locking() and ( + self.app.lock or (self.system is not None and self.system.lock) + ) + + @property + def is_build_required(self) -> bool: + """Return true if application build required.""" + return "build" in self.app.commands + + @property + def is_unique_build_dir_required(self) -> bool: + """Return true if unique build dir required.""" + return self.execution_params.get("unique_build_dir", False) + + def build_dir(self) -> Path: + """Return resolved application build dir.""" + if self._resolved_build_dir is not None: + return self._resolved_build_dir + + if ( + not isinstance(self.app.config_location, Path) + or not self.app.config_location.is_dir() + ): + raise ConfigurationException( + "Application {} has wrong config location".format(self.app.name) + ) + + _build_dir = self.app.build_dir + if _build_dir: + _build_dir = resolve_all_parameters(_build_dir, self.param_resolver) + + if not _build_dir: + raise ConfigurationException( + "No build directory defined for the app {}".format(self.app.name) + ) + + if self.is_unique_build_dir_required: + random_suffix = "".join( + random.choices(string.ascii_lowercase + string.digits, k=7) + ) + _build_dir = "{}_{}".format(_build_dir, random_suffix) + + self._resolved_build_dir = self.app.config_location / _build_dir + return self._resolved_build_dir + + def _disable_locking(self) -> bool: + """Return true if locking should be disabled.""" + return self.execution_params.get("disable_locking", False) + + +class ParamResolver: + """Parameter resolver.""" + + def __init__(self, context: ExecutionContext): + """Init parameter resolver.""" + self.ctx = context + + @staticmethod + def resolve_user_params( + cmd_name: Optional[str], + index_or_alias: str, + resolved_params: Optional[List[Tuple[Optional[str], Param]]], + ) -> str: + """Resolve user params.""" + if not cmd_name or resolved_params is None: + raise ConfigurationException("Unable to resolve user params") + + param_value: Optional[str] = None + param: Optional[Param] = None + + if index_or_alias.isnumeric(): + i = int(index_or_alias) + if i not in range(len(resolved_params)): + raise ConfigurationException( + "Invalid index {} for user params of command {}".format(i, cmd_name) + ) + param_value, param = resolved_params[i] + else: + for val, par in resolved_params: + if par.alias == index_or_alias: + param_value, param = val, par + break + + if param is None: + raise ConfigurationException( + "No user parameter for command '{}' with alias '{}'.".format( + cmd_name, index_or_alias + ) + ) + + if param_value: + # We need to handle to cases of parameters here: + # 1) Optional parameters (non-positional with a name and value) + # 2) Positional parameters (value only, no name needed) + # Default to empty strings for positional arguments + 
param_name = "" + separator = "" + if param.name is not None: + # A valid param name means we have an optional/non-positional argument: + # The separator is an empty string in case the param_name + # has an equal sign as we have to honour it. + # If the parameter doesn't end with an equal sign then a + # space character is injected to split the parameter name + # and its value + param_name = param.name + separator = "" if param.name.endswith("=") else " " + + return "{param_name}{separator}{param_value}".format( + param_name=param_name, + separator=separator, + param_value=param_value, + ) + + if param.name is None: + raise ConfigurationException( + "Missing user parameter with alias '{}' for command '{}'.".format( + index_or_alias, cmd_name + ) + ) + + return param.name # flag: just return the parameter name + + def resolve_commands_and_params( + self, backend_type: str, cmd_name: str, return_params: bool, index_or_alias: str + ) -> str: + """Resolve command or command's param value.""" + if backend_type == "system": + backend = cast(Backend, self.ctx.system) + backend_params = self.ctx.system_params + else: # Application or Tool backend + backend = cast(Backend, self.ctx.app) + backend_params = self.ctx.app_params + + if cmd_name not in backend.commands: + raise ConfigurationException("Command {} not found".format(cmd_name)) + + if return_params: + params = backend.resolved_parameters(cmd_name, backend_params) + if index_or_alias.isnumeric(): + i = int(index_or_alias) + if i not in range(len(params)): + raise ConfigurationException( + "Invalid parameter index {} for command {}".format(i, cmd_name) + ) + + param_value = params[i][0] + else: + param_value = None + for value, param in params: + if param.alias == index_or_alias: + param_value = value + break + + if not param_value: + raise ConfigurationException( + ( + "No value for parameter with index or alias {} of command {}" + ).format(index_or_alias, cmd_name) + ) + return param_value + + if not index_or_alias.isnumeric(): + raise ConfigurationException("Bad command index {}".format(index_or_alias)) + + i = int(index_or_alias) + commands = backend.build_command(cmd_name, backend_params, self.param_resolver) + if i not in range(len(commands)): + raise ConfigurationException( + "Invalid index {} for command {}".format(i, cmd_name) + ) + + return commands[i] + + def resolve_variables(self, backend_type: str, var_name: str) -> str: + """Resolve variable value.""" + if backend_type == "system": + backend = cast(Backend, self.ctx.system) + else: # Application or Tool backend + backend = cast(Backend, self.ctx.app) + + if var_name not in backend.variables: + raise ConfigurationException("Unknown variable {}".format(var_name)) + + return backend.variables[var_name] + + def param_matcher( + self, + param_name: str, + cmd_name: Optional[str], + resolved_params: Optional[List[Tuple[Optional[str], Param]]], + ) -> str: + """Regexp to resolve a param from the param_name.""" + # this pattern supports parameter names like "application.commands.run:0" and + # "system.commands.run.params:0" + # Note: 'software' is included for backward compatibility. 
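The placeholder syntax described in the comments above can be exercised directly with `resolve_all_parameters` from `common.py`; a toy resolver below stands in for `ParamResolver`, and the substituted value is made up purely for illustration.

```python
from mlia.backend.common import resolve_all_parameters


def toy_resolver(param_name, cmd_name, resolved_params):
    # Pretend the first user parameter of the current command resolves to
    # this string; in real runs ParamResolver performs the lookup.
    assert param_name == "user_params:0"
    return "--input model.tflite"


resolved = resolve_all_parameters("demo_runner {user_params:0}", toy_resolver)
print(resolved)  # demo_runner --input model.tflite
```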
+        commands_and_params_match = re.match(
+            r"(?P<type>application|software|tool|system)[.]commands[.]"
+            r"(?P<name>\w+)"
+            r"(?P<params>[.]params|)[:]"
+            r"(?P<index_or_alias>\w+)",
+            param_name,
+        )
+
+        if commands_and_params_match:
+            backend_type, cmd_name, return_params, index_or_alias = (
+                commands_and_params_match["type"],
+                commands_and_params_match["name"],
+                commands_and_params_match["params"],
+                commands_and_params_match["index_or_alias"],
+            )
+            return self.resolve_commands_and_params(
+                backend_type, cmd_name, bool(return_params), index_or_alias
+            )
+
+        # Note: 'software' is included for backward compatibility.
+        variables_match = re.match(
+            r"(?P<type>application|software|tool|system)[.]variables:(?P<var_name>\w+)",
+            param_name,
+        )
+        if variables_match:
+            backend_type, var_name = (
+                variables_match["type"],
+                variables_match["var_name"],
+            )
+            return self.resolve_variables(backend_type, var_name)
+
+        user_params_match = re.match(
+            r"user_params:(?P<index_or_alias>\w+)", param_name
+        )
+        if user_params_match:
+            index_or_alias = user_params_match["index_or_alias"]
+            return self.resolve_user_params(cmd_name, index_or_alias, resolved_params)
+
+        raise ConfigurationException(
+            "Unable to resolve parameter {}".format(param_name)
+        )
+
+    def param_resolver(
+        self,
+        param_name: str,
+        cmd_name: Optional[str] = None,
+        resolved_params: Optional[List[Tuple[Optional[str], Param]]] = None,
+    ) -> str:
+        """Resolve parameter value based on current execution context."""
+        # Note: 'software.*' is included for backward compatibility.
+        resolved_param = None
+        if param_name in ["application.name", "tool.name", "software.name"]:
+            resolved_param = self.ctx.app.name
+        elif param_name in [
+            "application.description",
+            "tool.description",
+            "software.description",
+        ]:
+            resolved_param = self.ctx.app.description
+        elif self.ctx.app.config_location and (
+            param_name
+            in ["application.config_dir", "tool.config_dir", "software.config_dir"]
+        ):
+            resolved_param = str(self.ctx.app.config_location.absolute())
+        elif self.ctx.app.build_dir and (
+            param_name
+            in ["application.build_dir", "tool.build_dir", "software.build_dir"]
+        ):
+            resolved_param = str(self.ctx.build_dir().absolute())
+        elif self.ctx.system is not None:
+            if param_name == "system.name":
+                resolved_param = self.ctx.system.name
+            elif param_name == "system.description":
+                resolved_param = self.ctx.system.description
+            elif param_name == "system.config_dir" and self.ctx.system.config_location:
+                resolved_param = str(self.ctx.system.config_location.absolute())
+
+        if not resolved_param:
+            resolved_param = self.param_matcher(param_name, cmd_name, resolved_params)
+        return resolved_param
+
+    def __call__(
+        self,
+        param_name: str,
+        cmd_name: Optional[str] = None,
+        resolved_params: Optional[List[Tuple[Optional[str], Param]]] = None,
+    ) -> str:
+        """Resolve provided parameter."""
+        return self.param_resolver(param_name, cmd_name, resolved_params)
+
+
+class Reporter:
+    """Report metrics from the simulation output."""
+
+    def __init__(self, parsers: Optional[List[OutputParser]] = None) -> None:
+        """Create an empty reporter (i.e.
no parsers registered).""" + self.parsers: List[OutputParser] = parsers if parsers is not None else [] + self._report: Dict[str, Any] = defaultdict(lambda: defaultdict(dict)) + + def parse(self, output: bytearray) -> None: + """Parse output and append parsed metrics to internal report dict.""" + for parser in self.parsers: + # Merge metrics from different parsers (do not overwrite) + self._report[parser.name]["metrics"].update(parser(output)) + + def get_filtered_output(self, output: bytearray) -> bytearray: + """Filter the output according to each parser.""" + for parser in self.parsers: + output = parser.filter_out_parsed_content(output) + return output + + def report(self, ctx: ExecutionContext) -> Dict[str, Any]: + """Add static simulation info to parsed data and return the report.""" + report: Dict[str, Any] = defaultdict(dict) + # Add static simulation info + report.update(self._static_info(ctx)) + # Add metrics parsed from the output + for key, val in self._report.items(): + report[key].update(val) + return report + + @staticmethod + def save(report: Dict[str, Any], report_file: Path) -> None: + """Save the report to a JSON file.""" + with open(report_file, "w", encoding="utf-8") as file: + json.dump(report, file, indent=4) + + @staticmethod + def _compute_all_params(cli_params: List[str], backend: Backend) -> Dict[str, str]: + """ + Build a dict of all parameters, {name:value}. + + Param values taken from command line if specified, defaults otherwise. + """ + # map of params passed from the cli ["p1=v1","p2=v2"] -> {"p1":"v1", "p2":"v2"} + app_params_map = dict(parse_raw_parameter(expr) for expr in cli_params) + + # a map of params declared in the application, with values taken from the CLI, + # defaults otherwise + all_params = { + (p.alias or p.name): app_params_map.get( + cast(str, p.name), cast(str, p.default_value) + ) + for cmd in backend.commands.values() + for p in cmd.params + } + return cast(Dict[str, str], all_params) + + @staticmethod + def _static_info(ctx: ExecutionContext) -> Dict[str, Any]: + """Extract static simulation information from the context.""" + if ctx.system is None: + raise ValueError("No system available to report.") + + info = { + "system": { + "name": ctx.system.name, + "params": Reporter._compute_all_params(ctx.system_params, ctx.system), + }, + "application": { + "name": ctx.app.name, + "params": Reporter._compute_all_params(ctx.app_params, ctx.app), + }, + } + return info + + +def validate_parameters( + backend: Backend, command_names: List[str], params: List[str] +) -> None: + """Check parameters passed to backend.""" + for param in params: + acceptable = any( + backend.validate_parameter(command_name, param) + for command_name in command_names + if command_name in backend.commands + ) + + if not acceptable: + backend_type = "System" if isinstance(backend, System) else "Application" + raise ValueError( + "{} parameter '{}' not valid for command '{}'".format( + backend_type, param, " or ".join(command_names) + ) + ) + + +def get_application_by_name_and_system( + application_name: str, system_name: str +) -> Application: + """Get application.""" + applications = get_application(application_name, system_name) + if not applications: + raise ValueError( + "Application '{}' doesn't support the system '{}'".format( + application_name, system_name + ) + ) + + if len(applications) != 1: + raise ValueError( + "Error during getting application {} for the system {}".format( + application_name, system_name + ) + ) + + return applications[0] + + +def 
get_application_and_system( + application_name: str, system_name: str +) -> Tuple[Application, System]: + """Return application and system by provided names.""" + system = get_system(system_name) + if not system: + raise ValueError("System {} is not found".format(system_name)) + + application = get_application_by_name_and_system(application_name, system_name) + + return application, system + + +# pylint: disable=too-many-arguments +def run_application( + application_name: str, + application_params: List[str], + system_name: str, + system_params: List[str], + custom_deploy_data: List[DataPaths], + report_file: Optional[Path] = None, +) -> ExecutionContext: + """Run application on the provided system.""" + application, system = get_application_and_system(application_name, system_name) + validate_parameters(application, ["build", "run"], application_params) + validate_parameters(system, ["build", "run"], system_params) + + execution_params = ExecutionParams() + if isinstance(system, StandaloneSystem): + execution_params["disable_locking"] = True + execution_params["unique_build_dir"] = True + + ctx = ExecutionContext( + app=application, + app_params=application_params, + system=system, + system_params=system_params, + custom_deploy_data=custom_deploy_data, + execution_params=execution_params, + report_file=report_file, + ) + + with build_dir_manager(ctx): + if ctx.is_build_required: + execute_application_command_build(ctx) + + execute_application_command_run(ctx) + + return ctx + + +def execute_application_command_build(ctx: ExecutionContext) -> None: + """Execute application command 'build'.""" + with ExitStack() as context_stack: + for manager in get_context_managers("build", ctx): + context_stack.enter_context(manager(ctx)) + + build_dir = ctx.build_dir() + recreate_directory(build_dir) + + build_commands = ctx.app.build_command( + "build", ctx.app_params, ctx.param_resolver + ) + execute_commands_locally(build_commands, build_dir) + + +def execute_commands_locally(commands: List[str], cwd: Path) -> None: + """Execute list of commands locally.""" + for command in commands: + print("Running: {}".format(command)) + run_and_wait( + command, cwd, terminate_on_error=True, out=sys.stdout, err=sys.stderr + ) + + +def execute_application_command_run(ctx: ExecutionContext) -> None: + """Execute application command.""" + assert ctx.system is not None, "System must be provided." 
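For orientation, a hedged sketch of how `run_application` above might be invoked. The application and system names come from the Corstone mappings added later in this patch (`manager.py`), the model path is a placeholder, and the call assumes the corresponding backend resources are installed.

```python
from mlia.backend.execution import run_application

ctx = run_application(
    application_name="Generic Inference Runner: Ethos-U55 SRAM",
    application_params=[],
    system_name="Corstone-300: Cortex-M55+Ethos-U55",
    system_params=["mac=128", "input_file=model.tflite"],
    custom_deploy_data=[],
)
if ctx.stdout:
    print(ctx.stdout.decode("utf8"))
```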
+ if ctx.is_deploy_needed and not ctx.system.supports_deploy: + raise ConfigurationException( + "System {} does not support data deploy".format(ctx.system.name) + ) + + with ExitStack() as context_stack: + for manager in get_context_managers("run", ctx): + context_stack.enter_context(manager(ctx)) + + print("Generating commands to execute") + commands_to_run = build_run_commands(ctx) + + if ctx.system.connectable: + establish_connection(ctx) + + if ctx.system.supports_deploy: + deploy_data(ctx) + + for command in commands_to_run: + print("Running: {}".format(command)) + exit_code, ctx.stdout, ctx.stderr = ctx.system.run(command) + + if exit_code != 0: + print("Application exited with exit code {}".format(exit_code)) + + if ctx.reporter: + ctx.reporter.parse(ctx.stdout) + ctx.stdout = ctx.reporter.get_filtered_output(ctx.stdout) + + if ctx.reporter: + report = ctx.reporter.report(ctx) + ctx.reporter.save(report, cast(Path, ctx.report_file)) + + +def establish_connection( + ctx: ExecutionContext, retries: int = 90, interval: float = 15.0 +) -> None: + """Establish connection with the system.""" + assert ctx.system is not None, "System is required." + host, port = ctx.system.connection_details() + print( + "Trying to establish connection with '{}:{}' - " + "{} retries every {} seconds ".format(host, port, retries, interval), + end="", + ) + + try: + for _ in range(retries): + print(".", end="", flush=True) + + if ctx.system.establish_connection(): + break + + if isinstance(ctx.system, ControlledSystem) and not ctx.system.is_running(): + print( + "\n\n---------- {} execution failed ----------".format( + ctx.system.name + ) + ) + stdout, stderr = ctx.system.get_output() + print(stdout) + print(stderr) + + raise Exception("System is not running") + + wait(interval) + else: + raise ConnectionException("Couldn't connect to '{}:{}'.".format(host, port)) + finally: + print() + + +def wait(interval: float) -> None: + """Wait for a period of time.""" + time.sleep(interval) + + +def deploy_data(ctx: ExecutionContext) -> None: + """Deploy data to the system.""" + assert ctx.system is not None, "System is required." + for item in itertools.chain(ctx.app.get_deploy_data(), ctx.custom_deploy_data): + print("Deploying {} onto {}".format(item.src, item.dst)) + ctx.system.deploy(item.src, item.dst) + + +def build_run_commands(ctx: ExecutionContext) -> List[str]: + """Build commands to run application.""" + if isinstance(ctx.system, StandaloneSystem): + return ctx.system.build_command("run", ctx.system_params, ctx.param_resolver) + + return ctx.app.build_command("run", ctx.app_params, ctx.param_resolver) + + +@contextmanager +def controlled_system_manager(ctx: ExecutionContext) -> Generator[None, None, None]: + """Context manager used for system initialisation before run.""" + system = cast(ControlledSystem, ctx.system) + commands = system.build_command("run", ctx.system_params, ctx.param_resolver) + pid_file_path: Optional[Path] = None + if ctx.is_locking_required: + file_lock_path = get_file_lock_path(ctx) + pid_file_path = file_lock_path.parent / "{}.pid".format(file_lock_path.stem) + + system.start(commands, ctx.is_locking_required, pid_file_path) + try: + yield + finally: + print("Shutting down sequence...") + print("Stopping {}... 
(It could take few seconds)".format(system.name)) + system.stop(wait=True) + print("{} stopped successfully.".format(system.name)) + + +@contextmanager +def lock_execution_manager(ctx: ExecutionContext) -> Generator[None, None, None]: + """Lock execution manager.""" + file_lock_path = get_file_lock_path(ctx) + file_lock = FileLock(str(file_lock_path)) + + try: + file_lock.acquire(timeout=1) + except Timeout as error: + raise AnotherInstanceIsRunningException() from error + + try: + yield + finally: + file_lock.release() + + +def get_file_lock_path(ctx: ExecutionContext, lock_dir: Path = Path("/tmp")) -> Path: + """Get file lock path.""" + lock_modules = [] + if ctx.app.lock: + lock_modules.append(ctx.app.name) + if ctx.system is not None and ctx.system.lock: + lock_modules.append(ctx.system.name) + lock_filename = "" + if lock_modules: + lock_filename = "_".join(["middleware"] + lock_modules) + ".lock" + + if lock_filename: + lock_filename = resolve_all_parameters(lock_filename, ctx.param_resolver) + lock_filename = valid_for_filename(lock_filename) + + if not lock_filename: + raise ConfigurationException("No filename for lock provided") + + if not isinstance(lock_dir, Path) or not lock_dir.is_dir(): + raise ConfigurationException( + "Invalid directory {} for lock files provided".format(lock_dir) + ) + + return lock_dir / lock_filename + + +@contextmanager +def build_dir_manager(ctx: ExecutionContext) -> Generator[None, None, None]: + """Build directory manager.""" + try: + yield + finally: + if ( + ctx.is_build_required + and ctx.is_unique_build_dir_required + and ctx.build_dir().is_dir() + ): + remove_directory(ctx.build_dir()) + + +def get_context_managers( + command_name: str, ctx: ExecutionContext +) -> Sequence[Callable[[ExecutionContext], ContextManager[None]]]: + """Get context manager for the system.""" + managers = [] + + if ctx.is_locking_required: + managers.append(lock_execution_manager) + + if command_name == "run": + if isinstance(ctx.system, ControlledSystem): + managers.append(controlled_system_manager) + + return managers diff --git a/src/mlia/backend/fs.py b/src/mlia/backend/fs.py new file mode 100644 index 0000000..9979fcb --- /dev/null +++ b/src/mlia/backend/fs.py @@ -0,0 +1,115 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Module to host all file system related functions.""" +import re +import shutil +from pathlib import Path +from typing import Any +from typing import Literal +from typing import Optional + +from mlia.utils.filesystem import get_mlia_resources + +ResourceType = Literal["applications", "systems"] + + +def get_backend_resources() -> Path: + """Get backend resources folder path.""" + return get_mlia_resources() / "backends" + + +def get_backends_path(name: ResourceType) -> Path: + """Return the absolute path of the specified resource. + + It uses importlib to return resources packaged with MANIFEST.in. 
+ """ + if not name: + raise ResourceWarning("Resource name is not provided") + + resource_path = get_backend_resources() / name + if resource_path.is_dir(): + return resource_path + + raise ResourceWarning("Resource '{}' not found.".format(name)) + + +def copy_directory_content(source: Path, destination: Path) -> None: + """Copy content of the source directory into destination directory.""" + for item in source.iterdir(): + src = source / item.name + dest = destination / item.name + + if src.is_dir(): + shutil.copytree(src, dest) + else: + shutil.copy2(src, dest) + + +def remove_resource(resource_directory: str, resource_type: ResourceType) -> None: + """Remove resource data.""" + resources = get_backends_path(resource_type) + + resource_location = resources / resource_directory + if not resource_location.exists(): + raise Exception("Resource {} does not exist".format(resource_directory)) + + if not resource_location.is_dir(): + raise Exception("Wrong resource {}".format(resource_directory)) + + shutil.rmtree(resource_location) + + +def remove_directory(directory_path: Optional[Path]) -> None: + """Remove directory.""" + if not directory_path or not directory_path.is_dir(): + raise Exception("No directory path provided") + + shutil.rmtree(directory_path) + + +def recreate_directory(directory_path: Optional[Path]) -> None: + """Recreate directory.""" + if not directory_path: + raise Exception("No directory path provided") + + if directory_path.exists() and not directory_path.is_dir(): + raise Exception( + "Path {} does exist and it is not a directory".format(str(directory_path)) + ) + + if directory_path.is_dir(): + remove_directory(directory_path) + + directory_path.mkdir() + + +def read_file(file_path: Path, mode: Optional[str] = None) -> Any: + """Read file as string or bytearray.""" + if file_path.is_file(): + if mode is not None: + # Ignore pylint warning because mode can be 'binary' as well which + # is not compatible with specifying encodings. + with open(file_path, mode) as file: # pylint: disable=unspecified-encoding + return file.read() + else: + with open(file_path, encoding="utf-8") as file: + return file.read() + + if mode == "rb": + return b"" + return "" + + +def read_file_as_string(file_path: Path) -> str: + """Read file as string.""" + return str(read_file(file_path)) + + +def read_file_as_bytearray(file_path: Path) -> bytearray: + """Read a file as bytearray.""" + return bytearray(read_file(file_path, mode="rb")) + + +def valid_for_filename(value: str, replacement: str = "") -> str: + """Replace non alpha numeric characters.""" + return re.sub(r"[^\w.]", replacement, value, flags=re.ASCII) diff --git a/src/mlia/backend/manager.py b/src/mlia/backend/manager.py new file mode 100644 index 0000000..3a1016c --- /dev/null +++ b/src/mlia/backend/manager.py @@ -0,0 +1,447 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. 
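A short sketch of the helpers defined in `fs.py` above: resource paths depend on the local MLIA installation, and the sanitised name shown in the comment is what the current regex produces.

```python
from mlia.backend.fs import get_backends_path
from mlia.backend.fs import valid_for_filename

# Location of installed system definitions; raises ResourceWarning if missing.
systems_dir = get_backends_path("systems")

# Sanitise a name before embedding it in a lock file name.
print(valid_for_filename("my system v1.0", "_"))  # my_system_v1.0
```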
+# SPDX-License-Identifier: Apache-2.0 +"""Module for backend integration.""" +import logging +import re +from abc import ABC +from abc import abstractmethod +from dataclasses import dataclass +from pathlib import Path +from typing import Any +from typing import Dict +from typing import List +from typing import Literal +from typing import Optional +from typing import Tuple + +from mlia.backend.application import get_available_applications +from mlia.backend.application import install_application +from mlia.backend.common import DataPaths +from mlia.backend.execution import ExecutionContext +from mlia.backend.execution import run_application +from mlia.backend.system import get_available_systems +from mlia.backend.system import install_system +from mlia.utils.proc import OutputConsumer +from mlia.utils.proc import RunningCommand + + +logger = logging.getLogger(__name__) + +# Mapping backend -> device_type -> system_name +_SUPPORTED_SYSTEMS = { + "Corstone-300": { + "ethos-u55": "Corstone-300: Cortex-M55+Ethos-U55", + "ethos-u65": "Corstone-300: Cortex-M55+Ethos-U65", + }, + "Corstone-310": { + "ethos-u55": "Corstone-310: Cortex-M85+Ethos-U55", + }, +} + +# Mapping system_name -> memory_mode -> application +_SYSTEM_TO_APP_MAP = { + "Corstone-300: Cortex-M55+Ethos-U55": { + "Sram": "Generic Inference Runner: Ethos-U55 SRAM", + "Shared_Sram": "Generic Inference Runner: Ethos-U55/65 Shared SRAM", + }, + "Corstone-300: Cortex-M55+Ethos-U65": { + "Shared_Sram": "Generic Inference Runner: Ethos-U55/65 Shared SRAM", + "Dedicated_Sram": "Generic Inference Runner: Ethos-U65 Dedicated SRAM", + }, + "Corstone-310: Cortex-M85+Ethos-U55": { + "Sram": "Generic Inference Runner: Ethos-U55 SRAM", + "Shared_Sram": "Generic Inference Runner: Ethos-U55/65 Shared SRAM", + }, +} + + +def get_system_name(backend: str, device_type: str) -> str: + """Get the system name for the given backend and device type.""" + return _SUPPORTED_SYSTEMS[backend][device_type] + + +def is_supported(backend: str, device_type: Optional[str] = None) -> bool: + """Check if the backend (and optionally device type) is supported.""" + if device_type is None: + return backend in _SUPPORTED_SYSTEMS + + try: + get_system_name(backend, device_type) + return True + except KeyError: + return False + + +def supported_backends() -> List[str]: + """Get a list of all backends supported by the backend manager.""" + return list(_SUPPORTED_SYSTEMS.keys()) + + +def get_all_system_names(backend: str) -> List[str]: + """Get all systems supported by the backend.""" + return list(_SUPPORTED_SYSTEMS.get(backend, {}).values()) + + +def get_all_application_names(backend: str) -> List[str]: + """Get all applications supported by the backend.""" + app_set = { + app + for sys in get_all_system_names(backend) + for app in _SYSTEM_TO_APP_MAP[sys].values() + } + return list(app_set) + + +@dataclass +class DeviceInfo: + """Device information.""" + + device_type: Literal["ethos-u55", "ethos-u65"] + mac: int + memory_mode: Literal["Sram", "Shared_Sram", "Dedicated_Sram"] + + +@dataclass +class ModelInfo: + """Model info.""" + + model_path: Path + + +@dataclass +class PerformanceMetrics: + """Performance metrics parsed from generic inference output.""" + + npu_active_cycles: int + npu_idle_cycles: int + npu_total_cycles: int + npu_axi0_rd_data_beat_received: int + npu_axi0_wr_data_beat_written: int + npu_axi1_rd_data_beat_received: int + + +@dataclass +class ExecutionParams: + """Application execution params.""" + + application: str + system: str + application_params: 
List[str] + system_params: List[str] + deploy_params: List[str] + + +class LogWriter(OutputConsumer): + """Redirect output to the logger.""" + + def feed(self, line: str) -> None: + """Process line from the output.""" + logger.debug(line.strip()) + + +class GenericInferenceOutputParser(OutputConsumer): + """Generic inference app output parser.""" + + PATTERNS = { + name: tuple(re.compile(pattern, re.IGNORECASE) for pattern in patterns) + for name, patterns in ( + ( + "npu_active_cycles", + ( + r"NPU ACTIVE cycles: (?P\d+)", + r"NPU ACTIVE: (?P\d+) cycles", + ), + ), + ( + "npu_idle_cycles", + ( + r"NPU IDLE cycles: (?P\d+)", + r"NPU IDLE: (?P\d+) cycles", + ), + ), + ( + "npu_total_cycles", + ( + r"NPU TOTAL cycles: (?P\d+)", + r"NPU TOTAL: (?P\d+) cycles", + ), + ), + ( + "npu_axi0_rd_data_beat_received", + ( + r"NPU AXI0_RD_DATA_BEAT_RECEIVED beats: (?P\d+)", + r"NPU AXI0_RD_DATA_BEAT_RECEIVED: (?P\d+) beats", + ), + ), + ( + "npu_axi0_wr_data_beat_written", + ( + r"NPU AXI0_WR_DATA_BEAT_WRITTEN beats: (?P\d+)", + r"NPU AXI0_WR_DATA_BEAT_WRITTEN: (?P\d+) beats", + ), + ), + ( + "npu_axi1_rd_data_beat_received", + ( + r"NPU AXI1_RD_DATA_BEAT_RECEIVED beats: (?P\d+)", + r"NPU AXI1_RD_DATA_BEAT_RECEIVED: (?P\d+) beats", + ), + ), + ) + } + + def __init__(self) -> None: + """Init generic inference output parser instance.""" + self.result: Dict = {} + + def feed(self, line: str) -> None: + """Feed new line to the parser.""" + for name, patterns in self.PATTERNS.items(): + for pattern in patterns: + match = pattern.search(line) + + if match: + self.result[name] = int(match["value"]) + return + + def is_ready(self) -> bool: + """Return true if all expected data has been parsed.""" + return self.result.keys() == self.PATTERNS.keys() + + def missed_keys(self) -> List[str]: + """Return list of the keys that have not been found in the output.""" + return sorted(self.PATTERNS.keys() - self.result.keys()) + + +class BackendRunner: + """Backend runner.""" + + def __init__(self) -> None: + """Init BackendRunner instance.""" + + @staticmethod + def get_installed_systems() -> List[str]: + """Get list of the installed systems.""" + return [system.name for system in get_available_systems()] + + @staticmethod + def get_installed_applications(system: Optional[str] = None) -> List[str]: + """Get list of the installed application.""" + return [ + app.name + for app in get_available_applications() + if system is None or app.can_run_on(system) + ] + + def is_application_installed(self, application: str, system: str) -> bool: + """Return true if requested application installed.""" + return application in self.get_installed_applications(system) + + def is_system_installed(self, system: str) -> bool: + """Return true if requested system installed.""" + return system in self.get_installed_systems() + + def systems_installed(self, systems: List[str]) -> bool: + """Check if all provided systems are installed.""" + if not systems: + return False + + installed_systems = self.get_installed_systems() + return all(system in installed_systems for system in systems) + + def applications_installed(self, applications: List[str]) -> bool: + """Check if all provided applications are installed.""" + if not applications: + return False + + installed_apps = self.get_installed_applications() + return all(app in installed_apps for app in applications) + + def all_installed(self, systems: List[str], apps: List[str]) -> bool: + """Check if all provided artifacts are installed.""" + return self.systems_installed(systems) and 
self.applications_installed(apps) + + @staticmethod + def install_system(system_path: Path) -> None: + """Install system.""" + install_system(system_path) + + @staticmethod + def install_application(app_path: Path) -> None: + """Install application.""" + install_application(app_path) + + @staticmethod + def run_application(execution_params: ExecutionParams) -> ExecutionContext: + """Run requested application.""" + + def to_data_paths(paths: str) -> DataPaths: + """Split input into two and create new DataPaths object.""" + src, dst = paths.split(sep=":", maxsplit=1) + return DataPaths(Path(src), dst) + + deploy_data_paths = [ + to_data_paths(paths) for paths in execution_params.deploy_params + ] + + ctx = run_application( + execution_params.application, + execution_params.application_params, + execution_params.system, + execution_params.system_params, + deploy_data_paths, + ) + + return ctx + + @staticmethod + def _params(name: str, params: List[str]) -> List[str]: + return [p for item in [(name, param) for param in params] for p in item] + + +class GenericInferenceRunner(ABC): + """Abstract class for generic inference runner.""" + + def __init__(self, backend_runner: BackendRunner): + """Init generic inference runner instance.""" + self.backend_runner = backend_runner + self.running_inference: Optional[RunningCommand] = None + + def run( + self, model_info: ModelInfo, output_consumers: List[OutputConsumer] + ) -> None: + """Run generic inference for the provided device/model.""" + execution_params = self.get_execution_params(model_info) + + ctx = self.backend_runner.run_application(execution_params) + if ctx.stdout is not None: + self.consume_output(ctx.stdout, output_consumers) + + def stop(self) -> None: + """Stop running inference.""" + if self.running_inference is None: + return + + self.running_inference.stop() + + @abstractmethod + def get_execution_params(self, model_info: ModelInfo) -> ExecutionParams: + """Get execution params for the provided model.""" + + def __enter__(self) -> "GenericInferenceRunner": + """Enter context.""" + return self + + def __exit__(self, *_args: Any) -> None: + """Exit context.""" + self.stop() + + def check_system_and_application(self, system_name: str, app_name: str) -> None: + """Check if requested system and application installed.""" + if not self.backend_runner.is_system_installed(system_name): + raise Exception(f"System {system_name} is not installed") + + if not self.backend_runner.is_application_installed(app_name, system_name): + raise Exception( + f"Application {app_name} for the system {system_name} " + "is not installed" + ) + + @staticmethod + def consume_output(output: bytearray, consumers: List[OutputConsumer]) -> None: + """Pass program's output to the consumers.""" + for line in output.decode("utf8").splitlines(): + for consumer in consumers: + consumer.feed(line) + + +class GenericInferenceRunnerEthosU(GenericInferenceRunner): + """Generic inference runner on U55/65.""" + + def __init__( + self, backend_runner: BackendRunner, device_info: DeviceInfo, backend: str + ) -> None: + """Init generic inference runner instance.""" + super().__init__(backend_runner) + + system_name, app_name = self.resolve_system_and_app(device_info, backend) + self.system_name = system_name + self.app_name = app_name + self.device_info = device_info + + @staticmethod + def resolve_system_and_app( + device_info: DeviceInfo, backend: str + ) -> Tuple[str, str]: + """Find appropriate system and application for the provided device/backend.""" + try: + system_name = 
get_system_name(backend, device_info.device_type) + except KeyError as ex: + raise RuntimeError( + f"Unsupported device {device_info.device_type} " + f"for backend {backend}" + ) from ex + + if system_name not in _SYSTEM_TO_APP_MAP: + raise RuntimeError(f"System {system_name} is not installed") + + try: + app_name = _SYSTEM_TO_APP_MAP[system_name][device_info.memory_mode] + except KeyError as err: + raise RuntimeError( + f"Unsupported memory mode {device_info.memory_mode}" + ) from err + + return system_name, app_name + + def get_execution_params(self, model_info: ModelInfo) -> ExecutionParams: + """Get execution params for Ethos-U55/65.""" + self.check_system_and_application(self.system_name, self.app_name) + + system_params = [ + f"mac={self.device_info.mac}", + f"input_file={model_info.model_path.absolute()}", + ] + + return ExecutionParams( + self.app_name, + self.system_name, + [], + system_params, + [], + ) + + +def get_generic_runner(device_info: DeviceInfo, backend: str) -> GenericInferenceRunner: + """Get generic runner for provided device and backend.""" + backend_runner = get_backend_runner() + return GenericInferenceRunnerEthosU(backend_runner, device_info, backend) + + +def estimate_performance( + model_info: ModelInfo, device_info: DeviceInfo, backend: str +) -> PerformanceMetrics: + """Get performance estimations.""" + with get_generic_runner(device_info, backend) as generic_runner: + output_parser = GenericInferenceOutputParser() + output_consumers = [output_parser, LogWriter()] + + generic_runner.run(model_info, output_consumers) + + if not output_parser.is_ready(): + missed_data = ",".join(output_parser.missed_keys()) + logger.debug( + "Unable to get performance metrics, missed data %s", missed_data + ) + raise Exception("Unable to get performance metrics, insufficient data") + + return PerformanceMetrics(**output_parser.result) + + +def get_backend_runner() -> BackendRunner: + """ + Return BackendRunner instance. + + Note: This is needed for the unit tests. + """ + return BackendRunner() diff --git a/src/mlia/backend/output_parser.py b/src/mlia/backend/output_parser.py new file mode 100644 index 0000000..111772a --- /dev/null +++ b/src/mlia/backend/output_parser.py @@ -0,0 +1,176 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Definition of output parsers (including base class OutputParser).""" +import base64 +import json +import re +from abc import ABC +from abc import abstractmethod +from typing import Any +from typing import Dict +from typing import Union + + +class OutputParser(ABC): + """Abstract base class for output parsers.""" + + def __init__(self, name: str) -> None: + """Set up the name of the parser.""" + super().__init__() + self.name = name + + @abstractmethod + def __call__(self, output: bytearray) -> Dict[str, Any]: + """Parse the output and return a map of names to metrics.""" + return {} + + # pylint: disable=no-self-use + def filter_out_parsed_content(self, output: bytearray) -> bytearray: + """ + Filter out the parsed content from the output. + + Does nothing by default. Can be overridden in subclasses. + """ + return output + + +class RegexOutputParser(OutputParser): + """Parser of standard output data using regular expressions.""" + + _TYPE_MAP = {"str": str, "float": float, "int": int} + + def __init__( + self, + name: str, + regex_config: Dict[str, Dict[str, str]], + ) -> None: + """ + Set up the parser with the regular expressions. 
+ + The regex_config is mapping from a name to a dict with keys 'pattern' + and 'type': + - The 'pattern' holds the regular expression that must contain exactly + one capturing parenthesis + - The 'type' can be one of ['str', 'float', 'int']. + + Example: + ``` + {"Metric1": {"pattern": ".*= *(.*)", "type": "str"}} + ``` + + The different regular expressions from the config are combined using + non-capturing parenthesis, i.e. regular expressions must not overlap + if more than one match per line is expected. + """ + super().__init__(name) + + self._verify_config(regex_config) + self._regex_cfg = regex_config + + # Compile regular expression to match in the output + self._regex = re.compile( + "|".join("(?:{0})".format(x["pattern"]) for x in self._regex_cfg.values()) + ) + + def __call__(self, output: bytearray) -> Dict[str, Union[str, float, int]]: + """ + Parse the output and return a map of names to metrics. + + Example: + Assuming a regex_config as used as example in `__init__()` and the + following output: + ``` + Simulation finished: + SIMULATION_STATUS = SUCCESS + Simulation DONE + ``` + Then calling the parser should return the following dict: + ``` + { + "Metric1": "SUCCESS" + } + ``` + """ + metrics = {} + output_str = output.decode("utf-8") + results = self._regex.findall(output_str) + for line_result in results: + for idx, (name, cfg) in enumerate(self._regex_cfg.items()): + # The result(s) returned by findall() are either a single string + # or a tuple (depending on the number of groups etc.) + result = ( + line_result if isinstance(line_result, str) else line_result[idx] + ) + if result: + mapped_result = self._TYPE_MAP[cfg["type"]](result) + metrics[name] = mapped_result + return metrics + + def _verify_config(self, regex_config: Dict[str, Dict[str, str]]) -> None: + """Make sure we have a valid regex_config. + + I.e. + - Exactly one capturing parenthesis per pattern + - Correct types + """ + for name, cfg in regex_config.items(): + # Check that there is one capturing group defined in the pattern. + regex = re.compile(cfg["pattern"]) + if regex.groups != 1: + raise ValueError( + f"Pattern for metric '{name}' must have exactly one " + f"capturing parenthesis, but it has {regex.groups}." + ) + # Check if type is supported + if not cfg["type"] in self._TYPE_MAP: + raise TypeError( + f"Type '{cfg['type']}' for metric '{name}' is not " + f"supported. Choose from: {list(self._TYPE_MAP.keys())}." + ) + + +class Base64OutputParser(OutputParser): + """ + Parser to extract base64-encoded JSON from tagged standard output. + + Example of the tagged output: + ``` + # Encoded JSON: {"test": 1234} + eyJ0ZXN0IjogMTIzNH0 + ``` + """ + + TAG_NAME = "metrics" + + def __init__(self, name: str) -> None: + """Set up the regular expression to extract tagged strings.""" + super().__init__(name) + self._regex = re.compile(rf"<{self.TAG_NAME}>(.*)") + + def __call__(self, output: bytearray) -> Dict[str, Any]: + """ + Parse the output and return a map of index (as string) to decoded JSON. 
+ + Example: + Using the tagged output from the class docs the parser should return + the following dict: + ``` + { + "0": {"test": 1234} + } + ``` + """ + metrics = {} + output_str = output.decode("utf-8") + results = self._regex.findall(output_str) + for idx, result_base64 in enumerate(results): + result_json = base64.b64decode(result_base64, validate=True) + result = json.loads(result_json) + metrics[str(idx)] = result + + return metrics + + def filter_out_parsed_content(self, output: bytearray) -> bytearray: + """Filter out base64-encoded content from the output.""" + output_str = self._regex.sub("", output.decode("utf-8")) + return bytearray(output_str.encode("utf-8")) diff --git a/src/mlia/backend/proc.py b/src/mlia/backend/proc.py new file mode 100644 index 0000000..90ff414 --- /dev/null +++ b/src/mlia/backend/proc.py @@ -0,0 +1,283 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Processes module. + +This module contains all classes and functions for dealing with Linux +processes. +""" +import csv +import datetime +import logging +import shlex +import signal +import time +from pathlib import Path +from typing import Any +from typing import List +from typing import NamedTuple +from typing import Optional +from typing import Tuple + +import psutil +from sh import Command +from sh import CommandNotFound +from sh import ErrorReturnCode +from sh import RunningCommand + +from mlia.backend.fs import valid_for_filename + + +class CommandFailedException(Exception): + """Exception for failed command execution.""" + + +class ShellCommand: + """Wrapper class for shell commands.""" + + def __init__(self, base_log_path: str = "/tmp") -> None: + """Initialise the class. + + base_log_path: it is the base directory where logs will be stored + """ + self.base_log_path = base_log_path + + def run( + self, + cmd: str, + *args: str, + _cwd: Optional[Path] = None, + _tee: bool = True, + _bg: bool = True, + _out: Any = None, + _err: Any = None, + _search_paths: Optional[List[Path]] = None + ) -> RunningCommand: + """Run the shell command with the given arguments. + + There are special arguments that modify the behaviour of the process. + _cwd: current working directory + _tee: it redirects the stdout both to console and file + _bg: if True, it runs the process in background and the command is not + blocking. 
+ _out: use this object for stdout redirect, + _err: use this object for stderr redirect, + _search_paths: If presented used for searching executable + """ + try: + kwargs = {} + if _cwd: + kwargs["_cwd"] = str(_cwd) + command = Command(cmd, _search_paths).bake(args, **kwargs) + except CommandNotFound as error: + logging.error("Command '%s' not found", error.args[0]) + raise error + + out, err = _out, _err + if not _out and not _err: + out, err = [ + str(item) + for item in self.get_stdout_stderr_paths(self.base_log_path, cmd) + ] + + return command(_out=out, _err=err, _tee=_tee, _bg=_bg, _bg_exc=False) + + @classmethod + def get_stdout_stderr_paths(cls, base_log_path: str, cmd: str) -> Tuple[Path, Path]: + """Construct and returns the paths of stdout/stderr files.""" + timestamp = datetime.datetime.now().timestamp() + base_path = Path(base_log_path) + base = base_path / "{}_{}".format(valid_for_filename(cmd, "_"), timestamp) + stdout = base.with_suffix(".out") + stderr = base.with_suffix(".err") + try: + stdout.touch() + stderr.touch() + except FileNotFoundError as error: + logging.error("File not found: %s", error.filename) + raise error + return stdout, stderr + + +def parse_command(command: str, shell: str = "bash") -> List[str]: + """Parse command.""" + cmd, *args = shlex.split(command, posix=True) + + if is_shell_script(cmd): + args = [cmd] + args + cmd = shell + + return [cmd] + args + + +def get_stdout_stderr_paths( + command: str, base_log_path: str = "/tmp" +) -> Tuple[Path, Path]: + """Construct and returns the paths of stdout/stderr files.""" + cmd, *_ = parse_command(command) + + return ShellCommand.get_stdout_stderr_paths(base_log_path, cmd) + + +def execute_command( # pylint: disable=invalid-name + command: str, + cwd: Path, + bg: bool = False, + shell: str = "bash", + out: Any = None, + err: Any = None, +) -> RunningCommand: + """Execute shell command.""" + cmd, *args = parse_command(command, shell) + + search_paths = None + if cmd != shell and (cwd / cmd).is_file(): + search_paths = [cwd] + + return ShellCommand().run( + cmd, *args, _cwd=cwd, _bg=bg, _search_paths=search_paths, _out=out, _err=err + ) + + +def is_shell_script(cmd: str) -> bool: + """Check if command is shell script.""" + return cmd.endswith(".sh") + + +def run_and_wait( + command: str, + cwd: Path, + terminate_on_error: bool = False, + out: Any = None, + err: Any = None, +) -> Tuple[int, bytearray, bytearray]: + """ + Run command and wait while it is executing. + + Returns a tuple: (exit_code, stdout, stderr) + """ + running_cmd: Optional[RunningCommand] = None + try: + running_cmd = execute_command(command, cwd, bg=True, out=out, err=err) + return running_cmd.exit_code, running_cmd.stdout, running_cmd.stderr + except ErrorReturnCode as cmd_failed: + raise CommandFailedException() from cmd_failed + except Exception as error: + is_running = running_cmd is not None and running_cmd.is_alive() + if terminate_on_error and is_running: + print("Terminating ...") + terminate_command(running_cmd) + + raise error + + +def terminate_command( + running_cmd: RunningCommand, + wait: bool = True, + wait_period: float = 0.5, + number_of_attempts: int = 20, +) -> None: + """Terminate running command.""" + try: + running_cmd.process.signal_group(signal.SIGINT) + if wait: + for _ in range(number_of_attempts): + time.sleep(wait_period) + if not running_cmd.is_alive(): + return + print( + "Unable to terminate process {}. 
Sending SIGTERM...".format( + running_cmd.process.pid + ) + ) + running_cmd.process.signal_group(signal.SIGTERM) + except ProcessLookupError: + pass + + +def terminate_external_process( + process: psutil.Process, + wait_period: float = 0.5, + number_of_attempts: int = 20, + wait_for_termination: float = 5.0, +) -> None: + """Terminate external process.""" + try: + process.terminate() + for _ in range(number_of_attempts): + if not process.is_running(): + return + time.sleep(wait_period) + + if process.is_running(): + process.terminate() + time.sleep(wait_for_termination) + except psutil.Error: + print("Unable to terminate process") + + +class ProcessInfo(NamedTuple): + """Process information.""" + + name: str + executable: str + cwd: str + pid: int + + +def save_process_info(pid: int, pid_file: Path) -> None: + """Save process information to file.""" + try: + parent = psutil.Process(pid) + children = parent.children(recursive=True) + family = [parent] + children + + with open(pid_file, "w", encoding="utf-8") as file: + csv_writer = csv.writer(file) + for member in family: + process_info = ProcessInfo( + name=member.name(), + executable=member.exe(), + cwd=member.cwd(), + pid=member.pid, + ) + csv_writer.writerow(process_info) + except psutil.NoSuchProcess: + # if process does not exist or finishes before + # function call then nothing could be saved + # just ignore this exception and exit + pass + + +def read_process_info(pid_file: Path) -> List[ProcessInfo]: + """Read information about previous system processes.""" + if not pid_file.is_file(): + return [] + + result = [] + with open(pid_file, encoding="utf-8") as file: + csv_reader = csv.reader(file) + for row in csv_reader: + name, executable, cwd, pid = row + result.append( + ProcessInfo(name=name, executable=executable, cwd=cwd, pid=int(pid)) + ) + + return result + + +def print_command_stdout(command: RunningCommand) -> None: + """Print the stdout of a command. + + The command has 2 states: running and done. + If the command is running, the output is taken by the running process. + If the command has ended its execution, the stdout is taken from stdout + property + """ + if command.is_alive(): + while True: + try: + print(command.next(), end="") + except StopIteration: + break + else: + print(command.stdout) diff --git a/src/mlia/backend/protocol.py b/src/mlia/backend/protocol.py new file mode 100644 index 0000000..ebfe69a --- /dev/null +++ b/src/mlia/backend/protocol.py @@ -0,0 +1,325 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Contain protocol related classes and functions.""" +from abc import ABC +from abc import abstractmethod +from contextlib import closing +from pathlib import Path +from typing import Any +from typing import cast +from typing import Iterable +from typing import Optional +from typing import Tuple +from typing import Union + +import paramiko + +from mlia.backend.common import ConfigurationException +from mlia.backend.config import LocalProtocolConfig +from mlia.backend.config import SSHConfig +from mlia.backend.proc import run_and_wait + + +# Redirect all paramiko thread exceptions to a file otherwise these will be +# printed to stderr. 
+paramiko.util.log_to_file("/tmp/main_paramiko_log.txt", level=paramiko.common.INFO) + + +class SSHConnectionException(Exception): + """SSH connection exception.""" + + +class SupportsClose(ABC): + """Class indicates support of close operation.""" + + @abstractmethod + def close(self) -> None: + """Close protocol session.""" + + +class SupportsDeploy(ABC): + """Class indicates support of deploy operation.""" + + @abstractmethod + def deploy(self, src: Path, dst: str, retry: bool = True) -> None: + """Abstract method for deploy data.""" + + +class SupportsConnection(ABC): + """Class indicates that protocol uses network connections.""" + + @abstractmethod + def establish_connection(self) -> bool: + """Establish connection with underlying system.""" + + @abstractmethod + def connection_details(self) -> Tuple[str, int]: + """Return connection details (host, port).""" + + +class Protocol(ABC): + """Abstract class for representing the protocol.""" + + def __init__(self, iterable: Iterable = (), **kwargs: Any) -> None: + """Initialize the class using a dict.""" + self.__dict__.update(iterable, **kwargs) + self._validate() + + @abstractmethod + def _validate(self) -> None: + """Abstract method for config data validation.""" + + @abstractmethod + def run( + self, command: str, retry: bool = False + ) -> Tuple[int, bytearray, bytearray]: + """ + Abstract method for running commands. + + Returns a tuple: (exit_code, stdout, stderr) + """ + + +class CustomSFTPClient(paramiko.SFTPClient): + """Class for creating a custom sftp client.""" + + def put_dir(self, source: Path, target: str) -> None: + """Upload the source directory to the target path. + + The target directory needs to exists and the last directory of the + source will be created under the target with all its content. + """ + # Create the target directory + self._mkdir(target, ignore_existing=True) + # Create the last directory in the source on the target + self._mkdir("{}/{}".format(target, source.name), ignore_existing=True) + # Go through the whole content of source + for item in sorted(source.glob("**/*")): + relative_path = item.relative_to(source.parent) + remote_target = target / relative_path + if item.is_file(): + self.put(str(item), str(remote_target)) + else: + self._mkdir(str(remote_target), ignore_existing=True) + + def _mkdir(self, path: str, mode: int = 511, ignore_existing: bool = False) -> None: + """Extend mkdir functionality. + + This version adds an option to not fail if the folder exists. + """ + try: + super().mkdir(path, mode) + except IOError as error: + if ignore_existing: + pass + else: + raise error + + +class LocalProtocol(Protocol): + """Class for local protocol.""" + + protocol: str + cwd: Path + + def run( + self, command: str, retry: bool = False + ) -> Tuple[int, bytearray, bytearray]: + """ + Run command locally. 
+ + Returns a tuple: (exit_code, stdout, stderr) + """ + if not isinstance(self.cwd, Path) or not self.cwd.is_dir(): + raise ConfigurationException("Wrong working directory {}".format(self.cwd)) + + stdout = bytearray() + stderr = bytearray() + + return run_and_wait( + command, self.cwd, terminate_on_error=True, out=stdout, err=stderr + ) + + def _validate(self) -> None: + """Validate protocol configuration.""" + assert hasattr(self, "protocol") and self.protocol == "local" + assert hasattr(self, "cwd") + + +class SSHProtocol(Protocol, SupportsClose, SupportsDeploy, SupportsConnection): + """Class for SSH protocol.""" + + protocol: str + username: str + password: str + hostname: str + port: int + + def __init__(self, iterable: Iterable = (), **kwargs: Any) -> None: + """Initialize the class using a dict.""" + super().__init__(iterable, **kwargs) + # Internal state to store if the system is connectable. It will be set + # to true at the first connection instance + self.client: Optional[paramiko.client.SSHClient] = None + self.port = int(self.port) + + def run(self, command: str, retry: bool = True) -> Tuple[int, bytearray, bytearray]: + """ + Run command over SSH. + + Returns a tuple: (exit_code, stdout, stderr) + """ + transport = self._get_transport() + with closing(transport.open_session()) as channel: + # Enable shell's .profile settings and execute command + channel.exec_command("bash -l -c '{}'".format(command)) + exit_status = -1 + stdout = bytearray() + stderr = bytearray() + while True: + if channel.exit_status_ready(): + exit_status = channel.recv_exit_status() + # Call it one last time to read any leftover in the channel + self._recv_stdout_err(channel, stdout, stderr) + break + self._recv_stdout_err(channel, stdout, stderr) + + return exit_status, stdout, stderr + + def deploy(self, src: Path, dst: str, retry: bool = True) -> None: + """Deploy src to remote dst over SSH. + + src and dst should be path to a file or directory. + """ + transport = self._get_transport() + sftp = cast(CustomSFTPClient, CustomSFTPClient.from_transport(transport)) + + with closing(sftp): + if src.is_dir(): + sftp.put_dir(src, dst) + elif src.is_file(): + sftp.put(str(src), dst) + else: + raise Exception("Deploy error: file type not supported") + + # After the deployment of files, sync the remote filesystem to flush + # buffers to hard disk + self.run("sync") + + def close(self) -> None: + """Close protocol session.""" + if self.client is not None: + print("Try syncing remote file system...") + # Before stopping the system, we try to run sync to make sure all + # data are flushed on disk. 
+ self.run("sync", retry=False) + self._close_client(self.client) + + def establish_connection(self) -> bool: + """Establish connection with underlying system.""" + if self.client is not None: + return True + + self.client = self._connect() + return self.client is not None + + def _get_transport(self) -> paramiko.transport.Transport: + """Get transport.""" + self.establish_connection() + + if self.client is None: + raise SSHConnectionException( + "Couldn't connect to '{}:{}'.".format(self.hostname, self.port) + ) + + transport = self.client.get_transport() + if not transport: + raise Exception("Unable to get transport") + + return transport + + def connection_details(self) -> Tuple[str, int]: + """Return connection details of underlying system.""" + return (self.hostname, self.port) + + def _connect(self) -> Optional[paramiko.client.SSHClient]: + """Try to establish connection.""" + client: Optional[paramiko.client.SSHClient] = None + try: + client = paramiko.client.SSHClient() + client.set_missing_host_key_policy(paramiko.AutoAddPolicy()) + client.connect( + self.hostname, + self.port, + self.username, + self.password, + # next parameters should be set to False to disable authentication + # using ssh keys + allow_agent=False, + look_for_keys=False, + ) + return client + except ( + # OSError raised on first attempt to connect when running inside Docker + OSError, + paramiko.ssh_exception.NoValidConnectionsError, + paramiko.ssh_exception.SSHException, + ): + # even if connection is not established socket could be still + # open, it should be closed + self._close_client(client) + + return None + + @staticmethod + def _close_client(client: Optional[paramiko.client.SSHClient]) -> None: + """Close ssh client.""" + try: + if client is not None: + client.close() + except Exception: # pylint: disable=broad-except + pass + + @classmethod + def _recv_stdout_err( + cls, channel: paramiko.channel.Channel, stdout: bytearray, stderr: bytearray + ) -> None: + """Read from channel to stdout/stder.""" + chunk_size = 512 + if channel.recv_ready(): + stdout_chunk = channel.recv(chunk_size) + stdout.extend(stdout_chunk) + if channel.recv_stderr_ready(): + stderr_chunk = channel.recv_stderr(chunk_size) + stderr.extend(stderr_chunk) + + def _validate(self) -> None: + """Check if there are all the info for establishing the connection.""" + assert hasattr(self, "protocol") and self.protocol == "ssh" + assert hasattr(self, "username") + assert hasattr(self, "password") + assert hasattr(self, "hostname") + assert hasattr(self, "port") + + +class ProtocolFactory: + """Factory class to return the appropriate Protocol class.""" + + @staticmethod + def get_protocol( + config: Optional[Union[SSHConfig, LocalProtocolConfig]], + **kwargs: Union[str, Path, None] + ) -> Union[SSHProtocol, LocalProtocol]: + """Return the right protocol instance based on the config.""" + if not config: + raise ValueError("No protocol config provided") + + protocol = config["protocol"] + if protocol == "ssh": + return SSHProtocol(config) + + if protocol == "local": + cwd = kwargs.get("cwd") + return LocalProtocol(config, cwd=cwd) + + raise ValueError("Protocol not supported: '{}'".format(protocol)) diff --git a/src/mlia/backend/source.py b/src/mlia/backend/source.py new file mode 100644 index 0000000..dcf6835 --- /dev/null +++ b/src/mlia/backend/source.py @@ -0,0 +1,209 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. 
+# SPDX-License-Identifier: Apache-2.0 +"""Contain source related classes and functions.""" +import os +import shutil +import tarfile +from abc import ABC +from abc import abstractmethod +from pathlib import Path +from tarfile import TarFile +from typing import Optional +from typing import Union + +from mlia.backend.common import BACKEND_CONFIG_FILE +from mlia.backend.common import ConfigurationException +from mlia.backend.common import get_backend_config +from mlia.backend.common import is_backend_directory +from mlia.backend.common import load_config +from mlia.backend.config import BackendConfig +from mlia.backend.fs import copy_directory_content + + +class Source(ABC): + """Source class.""" + + @abstractmethod + def name(self) -> Optional[str]: + """Get source name.""" + + @abstractmethod + def config(self) -> Optional[BackendConfig]: + """Get configuration file content.""" + + @abstractmethod + def install_into(self, destination: Path) -> None: + """Install source into destination directory.""" + + @abstractmethod + def create_destination(self) -> bool: + """Return True if destination folder should be created before installation.""" + + +class DirectorySource(Source): + """DirectorySource class.""" + + def __init__(self, directory_path: Path) -> None: + """Create the DirectorySource instance.""" + assert isinstance(directory_path, Path) + self.directory_path = directory_path + + def name(self) -> str: + """Return name of source.""" + return self.directory_path.name + + def config(self) -> Optional[BackendConfig]: + """Return configuration file content.""" + if not is_backend_directory(self.directory_path): + raise ConfigurationException("No configuration file found") + + config_file = get_backend_config(self.directory_path) + return load_config(config_file) + + def install_into(self, destination: Path) -> None: + """Install source into destination directory.""" + if not destination.is_dir(): + raise ConfigurationException("Wrong destination {}".format(destination)) + + if not self.directory_path.is_dir(): + raise ConfigurationException( + "Directory {} does not exist".format(self.directory_path) + ) + + copy_directory_content(self.directory_path, destination) + + def create_destination(self) -> bool: + """Return True if destination folder should be created before installation.""" + return True + + +class TarArchiveSource(Source): + """TarArchiveSource class.""" + + def __init__(self, archive_path: Path) -> None: + """Create the TarArchiveSource class.""" + assert isinstance(archive_path, Path) + self.archive_path = archive_path + self._config: Optional[BackendConfig] = None + self._has_top_level_folder: Optional[bool] = None + self._name: Optional[str] = None + + def _read_archive_content(self) -> None: + """Read various information about archive.""" + # get source name from archive name (everything without extensions) + extensions = "".join(self.archive_path.suffixes) + self._name = self.archive_path.name.rstrip(extensions) + + if not self.archive_path.exists(): + return + + with self._open(self.archive_path) as archive: + try: + config_entry = archive.getmember(BACKEND_CONFIG_FILE) + self._has_top_level_folder = False + except KeyError as error_no_config: + try: + archive_entries = archive.getnames() + entries_common_prefix = os.path.commonprefix(archive_entries) + top_level_dir = entries_common_prefix.rstrip("/") + + if not top_level_dir: + raise RuntimeError( + "Archive has no top level directory" + ) from error_no_config + + config_path = "{}/{}".format(top_level_dir, 
BACKEND_CONFIG_FILE) + + config_entry = archive.getmember(config_path) + self._has_top_level_folder = True + self._name = top_level_dir + except (KeyError, RuntimeError) as error_no_root_dir_or_config: + raise ConfigurationException( + "No configuration file found" + ) from error_no_root_dir_or_config + + content = archive.extractfile(config_entry) + self._config = load_config(content) + + def config(self) -> Optional[BackendConfig]: + """Return configuration file content.""" + if self._config is None: + self._read_archive_content() + + return self._config + + def name(self) -> Optional[str]: + """Return name of the source.""" + if self._name is None: + self._read_archive_content() + + return self._name + + def create_destination(self) -> bool: + """Return True if destination folder must be created before installation.""" + if self._has_top_level_folder is None: + self._read_archive_content() + + return not self._has_top_level_folder + + def install_into(self, destination: Path) -> None: + """Install source into destination directory.""" + if not destination.is_dir(): + raise ConfigurationException("Wrong destination {}".format(destination)) + + with self._open(self.archive_path) as archive: + archive.extractall(destination) + + def _open(self, archive_path: Path) -> TarFile: + """Open archive file.""" + if not archive_path.is_file(): + raise ConfigurationException("File {} does not exist".format(archive_path)) + + if archive_path.name.endswith("tar.gz") or archive_path.name.endswith("tgz"): + mode = "r:gz" + else: + raise ConfigurationException( + "Unsupported archive type {}".format(archive_path) + ) + + # The returned TarFile object can be used as a context manager (using + # 'with') by the calling instance. + return tarfile.open( # pylint: disable=consider-using-with + self.archive_path, mode=mode + ) + + +def get_source(source_path: Path) -> Union[TarArchiveSource, DirectorySource]: + """Return appropriate source instance based on provided source path.""" + if source_path.is_file(): + return TarArchiveSource(source_path) + + if source_path.is_dir(): + return DirectorySource(source_path) + + raise ConfigurationException("Unable to read {}".format(source_path)) + + +def create_destination_and_install(source: Source, resource_path: Path) -> None: + """Create destination directory and install source. + + This function is used for actual installation of system/backend New + directory will be created inside :resource_path: if needed If for example + archive contains top level folder then no need to create new directory + """ + destination = resource_path + create_destination = source.create_destination() + + if create_destination: + name = source.name() + if not name: + raise ConfigurationException("Unable to get source name") + + destination = resource_path / name + destination.mkdir() + try: + source.install_into(destination) + except Exception as error: + if create_destination: + shutil.rmtree(destination) + raise error diff --git a/src/mlia/backend/system.py b/src/mlia/backend/system.py new file mode 100644 index 0000000..469083e --- /dev/null +++ b/src/mlia/backend/system.py @@ -0,0 +1,289 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. 
+# SPDX-License-Identifier: Apache-2.0 +"""System backend module.""" +from pathlib import Path +from typing import Any +from typing import cast +from typing import Dict +from typing import List +from typing import Optional +from typing import Tuple +from typing import Union + +from mlia.backend.common import Backend +from mlia.backend.common import ConfigurationException +from mlia.backend.common import get_backend_configs +from mlia.backend.common import get_backend_directories +from mlia.backend.common import load_config +from mlia.backend.common import remove_backend +from mlia.backend.config import SystemConfig +from mlia.backend.controller import SystemController +from mlia.backend.controller import SystemControllerSingleInstance +from mlia.backend.fs import get_backends_path +from mlia.backend.protocol import ProtocolFactory +from mlia.backend.protocol import SupportsClose +from mlia.backend.protocol import SupportsConnection +from mlia.backend.protocol import SupportsDeploy +from mlia.backend.source import create_destination_and_install +from mlia.backend.source import get_source + + +def get_available_systems_directory_names() -> List[str]: + """Return a list of directory names for all avialable systems.""" + return [entry.name for entry in get_backend_directories("systems")] + + +def get_available_systems() -> List["System"]: + """Return a list with all available systems.""" + available_systems = [] + for config_json in get_backend_configs("systems"): + config_entries = cast(List[SystemConfig], (load_config(config_json))) + for config_entry in config_entries: + config_entry["config_location"] = config_json.parent.absolute() + system = load_system(config_entry) + available_systems.append(system) + + return sorted(available_systems, key=lambda system: system.name) + + +def get_system(system_name: str) -> Optional["System"]: + """Return a system instance with the same name passed as argument.""" + available_systems = get_available_systems() + for system in available_systems: + if system_name == system.name: + return system + return None + + +def install_system(source_path: Path) -> None: + """Install new system.""" + try: + source = get_source(source_path) + config = cast(List[SystemConfig], source.config()) + systems_to_install = [load_system(entry) for entry in config] + except Exception as error: + raise ConfigurationException("Unable to read system definition") from error + + if not systems_to_install: + raise ConfigurationException("No system definition found") + + available_systems = get_available_systems() + already_installed = [s for s in systems_to_install if s in available_systems] + if already_installed: + names = [system.name for system in already_installed] + raise ConfigurationException( + "Systems [{}] are already installed".format(",".join(names)) + ) + + create_destination_and_install(source, get_backends_path("systems")) + + +def remove_system(directory_name: str) -> None: + """Remove system.""" + remove_backend(directory_name, "systems") + + +class System(Backend): + """System class.""" + + def __init__(self, config: SystemConfig) -> None: + """Construct the System class using the dictionary passed.""" + super().__init__(config) + + self._setup_data_transfer(config) + self._setup_reporting(config) + + def _setup_data_transfer(self, config: SystemConfig) -> None: + data_transfer_config = config.get("data_transfer") + protocol = ProtocolFactory().get_protocol( + data_transfer_config, cwd=self.config_location + ) + self.protocol = protocol + + def 
_setup_reporting(self, config: SystemConfig) -> None: + self.reporting = config.get("reporting") + + def run(self, command: str, retry: bool = True) -> Tuple[int, bytearray, bytearray]: + """ + Run command on the system. + + Returns a tuple: (exit_code, stdout, stderr) + """ + return self.protocol.run(command, retry) + + def deploy(self, src: Path, dst: str, retry: bool = True) -> None: + """Deploy files to the system.""" + if isinstance(self.protocol, SupportsDeploy): + self.protocol.deploy(src, dst, retry) + + @property + def supports_deploy(self) -> bool: + """Check if protocol supports deploy operation.""" + return isinstance(self.protocol, SupportsDeploy) + + @property + def connectable(self) -> bool: + """Check if protocol supports connection.""" + return isinstance(self.protocol, SupportsConnection) + + def establish_connection(self) -> bool: + """Establish connection with the system.""" + if not isinstance(self.protocol, SupportsConnection): + raise ConfigurationException( + "System {} does not support connections".format(self.name) + ) + + return self.protocol.establish_connection() + + def connection_details(self) -> Tuple[str, int]: + """Return connection details.""" + if not isinstance(self.protocol, SupportsConnection): + raise ConfigurationException( + "System {} does not support connections".format(self.name) + ) + + return self.protocol.connection_details() + + def __eq__(self, other: object) -> bool: + """Overload operator ==.""" + if not isinstance(other, System): + return False + + return super().__eq__(other) and self.name == other.name + + def get_details(self) -> Dict[str, Any]: + """Return a dictionary with all relevant information of a System.""" + output = { + "type": "system", + "name": self.name, + "description": self.description, + "data_transfer_protocol": self.protocol.protocol, + "commands": self._get_command_details(), + "annotations": self.annotations, + } + + return output + + +class StandaloneSystem(System): + """StandaloneSystem class.""" + + +def get_controller( + single_instance: bool, pid_file_path: Optional[Path] = None +) -> SystemController: + """Get system controller.""" + if single_instance: + return SystemControllerSingleInstance(pid_file_path) + + return SystemController() + + +class ControlledSystem(System): + """ControlledSystem class.""" + + def __init__(self, config: SystemConfig): + """Construct the ControlledSystem class using the dictionary passed.""" + super().__init__(config) + self.controller: Optional[SystemController] = None + + def start( + self, + commands: List[str], + single_instance: bool = True, + pid_file_path: Optional[Path] = None, + ) -> None: + """Launch the system.""" + if ( + not isinstance(self.config_location, Path) + or not self.config_location.is_dir() + ): + raise ConfigurationException( + "System {} has wrong config location".format(self.name) + ) + + self.controller = get_controller(single_instance, pid_file_path) + self.controller.start(commands, self.config_location) + + def is_running(self) -> bool: + """Check if system is running.""" + if not self.controller: + return False + + return self.controller.is_running() + + def get_output(self) -> Tuple[str, str]: + """Return system output.""" + if not self.controller: + return "", "" + + return self.controller.get_output() + + def stop(self, wait: bool = False) -> None: + """Stop the system.""" + if not self.controller: + raise Exception("System has not been started") + + if isinstance(self.protocol, SupportsClose): + try: + self.protocol.close() + except Exception 
as error: # pylint: disable=broad-except + print(error) + self.controller.stop(wait) + + +def load_system(config: SystemConfig) -> Union[StandaloneSystem, ControlledSystem]: + """Load system based on it's execution type.""" + data_transfer = config.get("data_transfer", {}) + protocol = data_transfer.get("protocol") + populate_shared_params(config) + + if protocol == "ssh": + return ControlledSystem(config) + + if protocol == "local": + return StandaloneSystem(config) + + raise ConfigurationException( + "Unsupported execution type for protocol {}".format(protocol) + ) + + +def populate_shared_params(config: SystemConfig) -> None: + """Populate command parameters with shared parameters.""" + user_params = config.get("user_params") + if not user_params or "shared" not in user_params: + return + + shared_user_params = user_params["shared"] + if not shared_user_params: + return + + only_aliases = all(p.get("alias") for p in shared_user_params) + if not only_aliases: + raise ConfigurationException("All shared parameters should have aliases") + + commands = config.get("commands", {}) + for cmd_name in ["build", "run"]: + command = commands.get(cmd_name) + if command is None: + commands[cmd_name] = [] + cmd_user_params = user_params.get(cmd_name) + if not cmd_user_params: + cmd_user_params = shared_user_params + else: + only_aliases = all(p.get("alias") for p in cmd_user_params) + if not only_aliases: + raise ConfigurationException( + "All parameters for command {} should have aliases".format(cmd_name) + ) + merged_by_alias = { + **{p.get("alias"): p for p in shared_user_params}, + **{p.get("alias"): p for p in cmd_user_params}, + } + cmd_user_params = list(merged_by_alias.values()) + + user_params[cmd_name] = cmd_user_params + + config["commands"] = commands + del user_params["shared"] diff --git a/src/mlia/cli/config.py b/src/mlia/cli/config.py index 838b051..a673230 100644 --- a/src/mlia/cli/config.py +++ b/src/mlia/cli/config.py @@ -5,7 +5,7 @@ import logging from functools import lru_cache from typing import List -import mlia.tools.aiet_wrapper as aiet +import mlia.backend.manager as backend_manager from mlia.tools.metadata.common import DefaultInstallationManager from mlia.tools.metadata.common import InstallationManager from mlia.tools.metadata.corstone import get_corstone_installations @@ -25,12 +25,12 @@ def get_available_backends() -> List[str]: """Return list of the available backends.""" available_backends = ["Vela"] - # Add backends using AIET + # Add backends using backend manager manager = get_installation_manager() available_backends.extend( ( backend - for backend in aiet.supported_backends() + for backend in backend_manager.supported_backends() if manager.backend_installed(backend) ) ) diff --git a/src/mlia/devices/ethosu/performance.py b/src/mlia/devices/ethosu/performance.py index b0718a5..a73045a 100644 --- a/src/mlia/devices/ethosu/performance.py +++ b/src/mlia/devices/ethosu/performance.py @@ -10,7 +10,7 @@ from typing import Optional from typing import Tuple from typing import Union -import mlia.tools.aiet_wrapper as aiet +import mlia.backend.manager as backend_manager import mlia.tools.vela_wrapper as vela from mlia.core.context import Context from mlia.core.performance import PerformanceEstimator @@ -147,15 +147,15 @@ class VelaPerformanceEstimator( return memory_usage -class AIETPerformanceEstimator( +class CorstonePerformanceEstimator( PerformanceEstimator[Union[Path, ModelConfiguration], NPUCycles] ): - """AIET based performance estimator.""" + """Corstone-based 
performance estimator.""" def __init__( self, context: Context, device: EthosUConfiguration, backend: str ) -> None: - """Init AIET based performance estimator.""" + """Init Corstone-based performance estimator.""" self.context = context self.device = device self.backend = backend @@ -179,24 +179,24 @@ class AIETPerformanceEstimator( model_path, self.device.compiler_options, optimized_model_path ) - model_info = aiet.ModelInfo(model_path=optimized_model_path) - device_info = aiet.DeviceInfo( + model_info = backend_manager.ModelInfo(model_path=optimized_model_path) + device_info = backend_manager.DeviceInfo( device_type=self.device.target, # type: ignore mac=self.device.mac, memory_mode=self.device.compiler_options.memory_mode, # type: ignore ) - aiet_perf_metrics = aiet.estimate_performance( + corstone_perf_metrics = backend_manager.estimate_performance( model_info, device_info, self.backend ) npu_cycles = NPUCycles( - aiet_perf_metrics.npu_active_cycles, - aiet_perf_metrics.npu_idle_cycles, - aiet_perf_metrics.npu_total_cycles, - aiet_perf_metrics.npu_axi0_rd_data_beat_received, - aiet_perf_metrics.npu_axi0_wr_data_beat_written, - aiet_perf_metrics.npu_axi1_rd_data_beat_received, + corstone_perf_metrics.npu_active_cycles, + corstone_perf_metrics.npu_idle_cycles, + corstone_perf_metrics.npu_total_cycles, + corstone_perf_metrics.npu_axi0_rd_data_beat_received, + corstone_perf_metrics.npu_axi0_wr_data_beat_written, + corstone_perf_metrics.npu_axi1_rd_data_beat_received, ) logger.info("Done\n") @@ -220,10 +220,11 @@ class EthosUPerformanceEstimator( if backends is None: backends = ["Vela"] # Only Vela is always available as default for backend in backends: - if backend != "Vela" and not aiet.is_supported(backend): + if backend != "Vela" and not backend_manager.is_supported(backend): raise ValueError( f"Unsupported backend '{backend}'. " - f"Only 'Vela' and {aiet.supported_backends()} are supported." + f"Only 'Vela' and {backend_manager.supported_backends()} " + "are supported." ) self.backends = set(backends) @@ -242,11 +243,11 @@ class EthosUPerformanceEstimator( if backend == "Vela": vela_estimator = VelaPerformanceEstimator(self.context, self.device) memory_usage = vela_estimator.estimate(tflite_model) - elif backend in aiet.supported_backends(): - aiet_estimator = AIETPerformanceEstimator( + elif backend in backend_manager.supported_backends(): + corstone_estimator = CorstonePerformanceEstimator( self.context, self.device, backend ) - npu_cycles = aiet_estimator.estimate(tflite_model) + npu_cycles = corstone_estimator.estimate(tflite_model) else: logger.warning( "Backend '%s' is not supported for Ethos-U performance " diff --git a/src/mlia/resources/aiet/applications/APPLICATIONS.txt b/src/mlia/resources/aiet/applications/APPLICATIONS.txt index 09127f8..a702e19 100644 --- a/src/mlia/resources/aiet/applications/APPLICATIONS.txt +++ b/src/mlia/resources/aiet/applications/APPLICATIONS.txt @@ -1,6 +1,7 @@ SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. SPDX-License-Identifier: Apache-2.0 -This directory contains the Generic Inference Runner application packages for AIET +This directory contains the application packages for the Generic Inference +Runner. -Each package should contain its own aiet-config.json file +Each package should contain its own aiet-config.json file. 
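For illustration, a rough usage sketch of the refactored backend manager entry point used by the Corstone performance estimator above. This is only a sketch: the model path, MAC configuration and memory mode are placeholder values, and the named Corstone backend is assumed to be installed (otherwise `estimate_performance` raises).

```python
from pathlib import Path

from mlia.backend.manager import DeviceInfo, ModelInfo, estimate_performance

# Describe the model and target device; the values here are examples only.
model = ModelInfo(model_path=Path("model_optimised.tflite"))
device = DeviceInfo(device_type="ethos-u55", mac=256, memory_mode="Shared_Sram")

# Runs the Generic Inference Runner application on the selected backend
# (e.g. Corstone-300, which must already be installed) and returns the
# parsed NPU cycle counters as a PerformanceMetrics instance.
metrics = estimate_performance(model, device, backend="Corstone-300")
print(metrics.npu_total_cycles, metrics.npu_active_cycles)
```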
diff --git a/src/mlia/resources/aiet/systems/SYSTEMS.txt b/src/mlia/resources/aiet/systems/SYSTEMS.txt index bc27e73..3861769 100644 --- a/src/mlia/resources/aiet/systems/SYSTEMS.txt +++ b/src/mlia/resources/aiet/systems/SYSTEMS.txt @@ -1,8 +1,7 @@ SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. SPDX-License-Identifier: Apache-2.0 -This directory contains the configuration files of the systems for the AIET -middleware. +This directory contains the configuration files of the system backends. Supported systems: diff --git a/src/mlia/resources/backends/applications/.gitignore b/src/mlia/resources/backends/applications/.gitignore new file mode 100644 index 0000000..0226166 --- /dev/null +++ b/src/mlia/resources/backends/applications/.gitignore @@ -0,0 +1,6 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +# Ignore everything in this directory +* +# Except this file +!.gitignore diff --git a/src/mlia/resources/backends/systems/.gitignore b/src/mlia/resources/backends/systems/.gitignore new file mode 100644 index 0000000..0226166 --- /dev/null +++ b/src/mlia/resources/backends/systems/.gitignore @@ -0,0 +1,6 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +# Ignore everything in this directory +* +# Except this file +!.gitignore diff --git a/src/mlia/tools/aiet_wrapper.py b/src/mlia/tools/aiet_wrapper.py deleted file mode 100644 index 73e82ee..0000000 --- a/src/mlia/tools/aiet_wrapper.py +++ /dev/null @@ -1,435 +0,0 @@ -# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. -# SPDX-License-Identifier: Apache-2.0 -"""Module for AIET integration.""" -import logging -import re -from abc import ABC -from abc import abstractmethod -from dataclasses import dataclass -from pathlib import Path -from typing import Any -from typing import Dict -from typing import List -from typing import Literal -from typing import Optional -from typing import Tuple - -from aiet.backend.application import get_available_applications -from aiet.backend.application import install_application -from aiet.backend.system import get_available_systems -from aiet.backend.system import install_system -from mlia.utils.proc import CommandExecutor -from mlia.utils.proc import OutputConsumer -from mlia.utils.proc import RunningCommand - - -logger = logging.getLogger(__name__) - -# Mapping backend -> device_type -> system_name -_SUPPORTED_SYSTEMS = { - "Corstone-300": { - "ethos-u55": "Corstone-300: Cortex-M55+Ethos-U55", - "ethos-u65": "Corstone-300: Cortex-M55+Ethos-U65", - }, - "Corstone-310": { - "ethos-u55": "Corstone-310: Cortex-M85+Ethos-U55", - }, -} - -# Mapping system_name -> memory_mode -> application -_SYSTEM_TO_APP_MAP = { - "Corstone-300: Cortex-M55+Ethos-U55": { - "Sram": "Generic Inference Runner: Ethos-U55 SRAM", - "Shared_Sram": "Generic Inference Runner: Ethos-U55/65 Shared SRAM", - }, - "Corstone-300: Cortex-M55+Ethos-U65": { - "Shared_Sram": "Generic Inference Runner: Ethos-U55/65 Shared SRAM", - "Dedicated_Sram": "Generic Inference Runner: Ethos-U65 Dedicated SRAM", - }, - "Corstone-310: Cortex-M85+Ethos-U55": { - "Sram": "Generic Inference Runner: Ethos-U55 SRAM", - "Shared_Sram": "Generic Inference Runner: Ethos-U55/65 Shared SRAM", - }, -} - - -def get_system_name(backend: str, device_type: str) -> str: - """Get the AIET system name for the given backend and device type.""" - return _SUPPORTED_SYSTEMS[backend][device_type] - 
- -def is_supported(backend: str, device_type: Optional[str] = None) -> bool: - """Check if the backend (and optionally device type) is supported.""" - if device_type is None: - return backend in _SUPPORTED_SYSTEMS - - try: - get_system_name(backend, device_type) - return True - except KeyError: - return False - - -def supported_backends() -> List[str]: - """Get a list of all backends supported by the AIET wrapper.""" - return list(_SUPPORTED_SYSTEMS.keys()) - - -def get_all_system_names(backend: str) -> List[str]: - """Get all systems supported by the backend.""" - return list(_SUPPORTED_SYSTEMS.get(backend, {}).values()) - - -def get_all_application_names(backend: str) -> List[str]: - """Get all applications supported by the backend.""" - app_set = { - app - for sys in get_all_system_names(backend) - for app in _SYSTEM_TO_APP_MAP[sys].values() - } - return list(app_set) - - -@dataclass -class DeviceInfo: - """Device information.""" - - device_type: Literal["ethos-u55", "ethos-u65"] - mac: int - memory_mode: Literal["Sram", "Shared_Sram", "Dedicated_Sram"] - - -@dataclass -class ModelInfo: - """Model info.""" - - model_path: Path - - -@dataclass -class PerformanceMetrics: - """Performance metrics parsed from generic inference output.""" - - npu_active_cycles: int - npu_idle_cycles: int - npu_total_cycles: int - npu_axi0_rd_data_beat_received: int - npu_axi0_wr_data_beat_written: int - npu_axi1_rd_data_beat_received: int - - -@dataclass -class ExecutionParams: - """Application execution params.""" - - application: str - system: str - application_params: List[str] - system_params: List[str] - deploy_params: List[str] - - -class AIETLogWriter(OutputConsumer): - """Redirect AIET command output to the logger.""" - - def feed(self, line: str) -> None: - """Process line from the output.""" - logger.debug(line.strip()) - - -class GenericInferenceOutputParser(OutputConsumer): - """Generic inference app output parser.""" - - PATTERNS = { - name: tuple(re.compile(pattern, re.IGNORECASE) for pattern in patterns) - for name, patterns in ( - ( - "npu_active_cycles", - ( - r"NPU ACTIVE cycles: (?P\d+)", - r"NPU ACTIVE: (?P\d+) cycles", - ), - ), - ( - "npu_idle_cycles", - ( - r"NPU IDLE cycles: (?P\d+)", - r"NPU IDLE: (?P\d+) cycles", - ), - ), - ( - "npu_total_cycles", - ( - r"NPU TOTAL cycles: (?P\d+)", - r"NPU TOTAL: (?P\d+) cycles", - ), - ), - ( - "npu_axi0_rd_data_beat_received", - ( - r"NPU AXI0_RD_DATA_BEAT_RECEIVED beats: (?P\d+)", - r"NPU AXI0_RD_DATA_BEAT_RECEIVED: (?P\d+) beats", - ), - ), - ( - "npu_axi0_wr_data_beat_written", - ( - r"NPU AXI0_WR_DATA_BEAT_WRITTEN beats: (?P\d+)", - r"NPU AXI0_WR_DATA_BEAT_WRITTEN: (?P\d+) beats", - ), - ), - ( - "npu_axi1_rd_data_beat_received", - ( - r"NPU AXI1_RD_DATA_BEAT_RECEIVED beats: (?P\d+)", - r"NPU AXI1_RD_DATA_BEAT_RECEIVED: (?P\d+) beats", - ), - ), - ) - } - - def __init__(self) -> None: - """Init generic inference output parser instance.""" - self.result: Dict = {} - - def feed(self, line: str) -> None: - """Feed new line to the parser.""" - for name, patterns in self.PATTERNS.items(): - for pattern in patterns: - match = pattern.search(line) - - if match: - self.result[name] = int(match["value"]) - return - - def is_ready(self) -> bool: - """Return true if all expected data has been parsed.""" - return self.result.keys() == self.PATTERNS.keys() - - def missed_keys(self) -> List[str]: - """Return list of the keys that have not been found in the output.""" - return sorted(self.PATTERNS.keys() - self.result.keys()) - - -class AIETRunner: - 
"""AIET runner.""" - - def __init__(self, executor: CommandExecutor) -> None: - """Init AIET runner instance.""" - self.executor = executor - - @staticmethod - def get_installed_systems() -> List[str]: - """Get list of the installed systems.""" - return [system.name for system in get_available_systems()] - - @staticmethod - def get_installed_applications(system: Optional[str] = None) -> List[str]: - """Get list of the installed application.""" - return [ - app.name - for app in get_available_applications() - if system is None or app.can_run_on(system) - ] - - def is_application_installed(self, application: str, system: str) -> bool: - """Return true if requested application installed.""" - return application in self.get_installed_applications(system) - - def is_system_installed(self, system: str) -> bool: - """Return true if requested system installed.""" - return system in self.get_installed_systems() - - def systems_installed(self, systems: List[str]) -> bool: - """Check if all provided systems are installed.""" - if not systems: - return False - - installed_systems = self.get_installed_systems() - return all(system in installed_systems for system in systems) - - def applications_installed(self, applications: List[str]) -> bool: - """Check if all provided applications are installed.""" - if not applications: - return False - - installed_apps = self.get_installed_applications() - return all(app in installed_apps for app in applications) - - def all_installed(self, systems: List[str], apps: List[str]) -> bool: - """Check if all provided artifacts are installed.""" - return self.systems_installed(systems) and self.applications_installed(apps) - - @staticmethod - def install_system(system_path: Path) -> None: - """Install system.""" - install_system(system_path) - - @staticmethod - def install_application(app_path: Path) -> None: - """Install application.""" - install_application(app_path) - - def run_application(self, execution_params: ExecutionParams) -> RunningCommand: - """Run requested application.""" - command = [ - "aiet", - "application", - "run", - "-n", - execution_params.application, - "-s", - execution_params.system, - *self._params("-p", execution_params.application_params), - *self._params("--system-param", execution_params.system_params), - *self._params("--deploy", execution_params.deploy_params), - ] - - return self._submit(command) - - @staticmethod - def _params(name: str, params: List[str]) -> List[str]: - return [p for item in [(name, param) for param in params] for p in item] - - def _submit(self, command: List[str]) -> RunningCommand: - """Submit command for the execution.""" - logger.debug("Submit command %s", " ".join(command)) - return self.executor.submit(command) - - -class GenericInferenceRunner(ABC): - """Abstract class for generic inference runner.""" - - def __init__(self, aiet_runner: AIETRunner): - """Init generic inference runner instance.""" - self.aiet_runner = aiet_runner - self.running_inference: Optional[RunningCommand] = None - - def run( - self, model_info: ModelInfo, output_consumers: List[OutputConsumer] - ) -> None: - """Run generic inference for the provided device/model.""" - execution_params = self.get_execution_params(model_info) - - self.running_inference = self.aiet_runner.run_application(execution_params) - self.running_inference.output_consumers = output_consumers - self.running_inference.consume_output() - - def stop(self) -> None: - """Stop running inference.""" - if self.running_inference is None: - return - - 
self.running_inference.stop() - - @abstractmethod - def get_execution_params(self, model_info: ModelInfo) -> ExecutionParams: - """Get execution params for the provided model.""" - - def __enter__(self) -> "GenericInferenceRunner": - """Enter context.""" - return self - - def __exit__(self, *_args: Any) -> None: - """Exit context.""" - self.stop() - - def check_system_and_application(self, system_name: str, app_name: str) -> None: - """Check if requested system and application installed.""" - if not self.aiet_runner.is_system_installed(system_name): - raise Exception(f"System {system_name} is not installed") - - if not self.aiet_runner.is_application_installed(app_name, system_name): - raise Exception( - f"Application {app_name} for the system {system_name} " - "is not installed" - ) - - -class GenericInferenceRunnerEthosU(GenericInferenceRunner): - """Generic inference runner on U55/65.""" - - def __init__( - self, aiet_runner: AIETRunner, device_info: DeviceInfo, backend: str - ) -> None: - """Init generic inference runner instance.""" - super().__init__(aiet_runner) - - system_name, app_name = self.resolve_system_and_app(device_info, backend) - self.system_name = system_name - self.app_name = app_name - self.device_info = device_info - - @staticmethod - def resolve_system_and_app( - device_info: DeviceInfo, backend: str - ) -> Tuple[str, str]: - """Find appropriate system and application for the provided device/backend.""" - try: - system_name = get_system_name(backend, device_info.device_type) - except KeyError as ex: - raise RuntimeError( - f"Unsupported device {device_info.device_type} " - f"for backend {backend}" - ) from ex - - if system_name not in _SYSTEM_TO_APP_MAP: - raise RuntimeError(f"System {system_name} is not installed") - - try: - app_name = _SYSTEM_TO_APP_MAP[system_name][device_info.memory_mode] - except KeyError as err: - raise RuntimeError( - f"Unsupported memory mode {device_info.memory_mode}" - ) from err - - return system_name, app_name - - def get_execution_params(self, model_info: ModelInfo) -> ExecutionParams: - """Get execution params for Ethos-U55/65.""" - self.check_system_and_application(self.system_name, self.app_name) - - system_params = [ - f"mac={self.device_info.mac}", - f"input_file={model_info.model_path.absolute()}", - ] - - return ExecutionParams( - self.app_name, - self.system_name, - [], - system_params, - [], - ) - - -def get_generic_runner(device_info: DeviceInfo, backend: str) -> GenericInferenceRunner: - """Get generic runner for provided device and backend.""" - aiet_runner = get_aiet_runner() - return GenericInferenceRunnerEthosU(aiet_runner, device_info, backend) - - -def estimate_performance( - model_info: ModelInfo, device_info: DeviceInfo, backend: str -) -> PerformanceMetrics: - """Get performance estimations.""" - with get_generic_runner(device_info, backend) as generic_runner: - output_parser = GenericInferenceOutputParser() - output_consumers = [output_parser, AIETLogWriter()] - - generic_runner.run(model_info, output_consumers) - - if not output_parser.is_ready(): - missed_data = ",".join(output_parser.missed_keys()) - logger.debug( - "Unable to get performance metrics, missed data %s", missed_data - ) - raise Exception("Unable to get performance metrics, insufficient data") - - return PerformanceMetrics(**output_parser.result) - - -def get_aiet_runner() -> AIETRunner: - """Return AIET runner.""" - executor = CommandExecutor() - return AIETRunner(executor) diff --git a/src/mlia/tools/metadata/corstone.py 
b/src/mlia/tools/metadata/corstone.py index 7a9d113..a92f81c 100644 --- a/src/mlia/tools/metadata/corstone.py +++ b/src/mlia/tools/metadata/corstone.py @@ -12,7 +12,8 @@ from typing import Iterable from typing import List from typing import Optional -import mlia.tools.aiet_wrapper as aiet +import mlia.backend.manager as backend_manager +from mlia.backend.fs import get_backend_resources from mlia.tools.metadata.common import DownloadAndInstall from mlia.tools.metadata.common import Installation from mlia.tools.metadata.common import InstallationType @@ -41,8 +42,8 @@ PathChecker = Callable[[Path], Optional[BackendInfo]] BackendInstaller = Callable[[bool, Path], Path] -class AIETMetadata: - """AIET installation metadata.""" +class BackendMetadata: + """Backend installation metadata.""" def __init__( self, @@ -55,7 +56,7 @@ class AIETMetadata: supported_platforms: Optional[List[str]] = None, ) -> None: """ - Initialize AIETMetaData. + Initialize BackendMetadata. Members expected_systems and expected_apps are filled automatically. """ @@ -67,15 +68,15 @@ class AIETMetadata: self.download_artifact = download_artifact self.supported_platforms = supported_platforms - self.expected_systems = aiet.get_all_system_names(name) - self.expected_apps = aiet.get_all_application_names(name) + self.expected_systems = backend_manager.get_all_system_names(name) + self.expected_apps = backend_manager.get_all_application_names(name) @property def expected_resources(self) -> Iterable[Path]: """Return list of expected resources.""" resources = [self.system_config, *self.apps_resources] - return (get_mlia_resources() / resource for resource in resources) + return (get_backend_resources() / resource for resource in resources) @property def supported_platform(self) -> bool: @@ -86,49 +87,49 @@ class AIETMetadata: return platform.system() in self.supported_platforms -class AIETBasedInstallation(Installation): - """Backend installation based on AIET functionality.""" +class BackendInstallation(Installation): + """Backend installation.""" def __init__( self, - aiet_runner: aiet.AIETRunner, - metadata: AIETMetadata, + backend_runner: backend_manager.BackendRunner, + metadata: BackendMetadata, path_checker: PathChecker, backend_installer: Optional[BackendInstaller], ) -> None: - """Init the tool installation.""" - self.aiet_runner = aiet_runner + """Init the backend installation.""" + self.backend_runner = backend_runner self.metadata = metadata self.path_checker = path_checker self.backend_installer = backend_installer @property def name(self) -> str: - """Return name of the tool.""" + """Return name of the backend.""" return self.metadata.name @property def description(self) -> str: - """Return description of the tool.""" + """Return description of the backend.""" return self.metadata.description @property def already_installed(self) -> bool: - """Return true if tool already installed.""" - return self.aiet_runner.all_installed( + """Return true if backend already installed.""" + return self.backend_runner.all_installed( self.metadata.expected_systems, self.metadata.expected_apps ) @property def could_be_installed(self) -> bool: - """Return true if tool could be installed.""" + """Return true if backend could be installed.""" if not self.metadata.supported_platform: return False return all_paths_valid(self.metadata.expected_resources) def supports(self, install_type: InstallationType) -> bool: - """Return true if tools supported type of the installation.""" + """Return true if backends supported type of the 
installation.""" if isinstance(install_type, DownloadAndInstall): return self.metadata.download_artifact is not None @@ -138,7 +139,7 @@ class AIETBasedInstallation(Installation): return False # type: ignore def install(self, install_type: InstallationType) -> None: - """Install the tool.""" + """Install the backend.""" if isinstance(install_type, DownloadAndInstall): download_artifact = self.metadata.download_artifact assert download_artifact is not None, "No artifact provided" @@ -153,7 +154,7 @@ class AIETBasedInstallation(Installation): raise Exception(f"Unable to install {install_type}") def install_from(self, backend_info: BackendInfo) -> None: - """Install tool from the directory.""" + """Install backend from the directory.""" mlia_resources = get_mlia_resources() with temp_directory() as tmpdir: @@ -169,15 +170,15 @@ class AIETBasedInstallation(Installation): copy_all(*resources_to_copy, dest=fvp_dist_dir) - self.aiet_runner.install_system(fvp_dist_dir) + self.backend_runner.install_system(fvp_dist_dir) for app in self.metadata.apps_resources: - self.aiet_runner.install_application(mlia_resources / app) + self.backend_runner.install_application(mlia_resources / app) def download_and_install( self, download_artifact: DownloadArtifact, eula_agrement: bool ) -> None: - """Download and install the tool.""" + """Download and install the backend.""" with temp_directory() as tmpdir: try: downloaded_to = download_artifact.download_to(tmpdir) @@ -307,10 +308,10 @@ class Corstone300Installer: def get_corstone_300_installation() -> Installation: """Get Corstone-300 installation.""" - corstone_300 = AIETBasedInstallation( - aiet_runner=aiet.get_aiet_runner(), + corstone_300 = BackendInstallation( + backend_runner=backend_manager.BackendRunner(), # pylint: disable=line-too-long - metadata=AIETMetadata( + metadata=BackendMetadata( name="Corstone-300", description="Corstone-300 FVP", system_config="aiet/systems/corstone-300/aiet-config.json", @@ -356,10 +357,10 @@ def get_corstone_300_installation() -> Installation: def get_corstone_310_installation() -> Installation: """Get Corstone-310 installation.""" - corstone_310 = AIETBasedInstallation( - aiet_runner=aiet.get_aiet_runner(), + corstone_310 = BackendInstallation( + backend_runner=backend_manager.BackendRunner(), # pylint: disable=line-too-long - metadata=AIETMetadata( + metadata=BackendMetadata( name="Corstone-310", description="Corstone-310 FVP", system_config="aiet/systems/corstone-310/aiet-config.json", diff --git a/src/mlia/utils/proc.py b/src/mlia/utils/proc.py index 39aca43..18a4305 100644 --- a/src/mlia/utils/proc.py +++ b/src/mlia/utils/proc.py @@ -8,7 +8,6 @@ import time from abc import ABC from abc import abstractmethod from contextlib import contextmanager -from contextlib import suppress from pathlib import Path from typing import Any from typing import Generator @@ -23,7 +22,7 @@ class OutputConsumer(ABC): @abstractmethod def feed(self, line: str) -> None: - """Feed new line to the consumerr.""" + """Feed new line to the consumer.""" class RunningCommand: @@ -32,7 +31,7 @@ class RunningCommand: def __init__(self, process: subprocess.Popen) -> None: """Init running command instance.""" self.process = process - self._output_consumers: Optional[List[OutputConsumer]] = None + self.output_consumers: List[OutputConsumer] = [] def is_alive(self) -> bool: """Return true if process is still alive.""" @@ -57,25 +56,14 @@ class RunningCommand: """Send signal to the process.""" self.process.send_signal(signal_num) - @property - def 
output_consumers(self) -> Optional[List[OutputConsumer]]: - """Property output_consumers.""" - return self._output_consumers - - @output_consumers.setter - def output_consumers(self, output_consumers: List[OutputConsumer]) -> None: - """Set output consumers.""" - self._output_consumers = output_consumers - def consume_output(self) -> None: """Pass program's output to the consumers.""" - if self.process is None or self.output_consumers is None: + if self.process is None or not self.output_consumers: return for line in self.stdout(): for consumer in self.output_consumers: - with suppress(): - consumer.feed(line) + consumer.feed(line) def stop( self, wait: bool = True, num_of_attempts: int = 5, interval: float = 0.5 diff --git a/tests/aiet/__init__.py b/tests/aiet/__init__.py deleted file mode 100644 index 873a7df..0000000 --- a/tests/aiet/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. -# SPDX-License-Identifier: Apache-2.0 -"""AIET tests module.""" diff --git a/tests/aiet/conftest.py b/tests/aiet/conftest.py deleted file mode 100644 index cab3dc2..0000000 --- a/tests/aiet/conftest.py +++ /dev/null @@ -1,139 +0,0 @@ -# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. -# SPDX-License-Identifier: Apache-2.0 -# pylint: disable=redefined-outer-name -"""conftest for pytest.""" -import shutil -import tarfile -from pathlib import Path -from typing import Any - -import pytest -from click.testing import CliRunner - -from aiet.backend.common import get_backend_configs - - -@pytest.fixture(scope="session") -def test_systems_path(test_resources_path: Path) -> Path: - """Return test systems path in a pytest fixture.""" - return test_resources_path / "systems" - - -@pytest.fixture(scope="session") -def test_applications_path(test_resources_path: Path) -> Path: - """Return test applications path in a pytest fixture.""" - return test_resources_path / "applications" - - -@pytest.fixture(scope="session") -def test_tools_path(test_resources_path: Path) -> Path: - """Return test tools path in a pytest fixture.""" - return test_resources_path / "tools" - - -@pytest.fixture(scope="session") -def test_resources_path() -> Path: - """Return test resources path in a pytest fixture.""" - current_path = Path(__file__).parent.absolute() - return current_path / "test_resources" - - -@pytest.fixture(scope="session") -def non_optimised_input_model_file(test_tflite_model: Path) -> Path: - """Provide the path to a quantized dummy model file.""" - return test_tflite_model - - -@pytest.fixture(scope="session") -def optimised_input_model_file(test_tflite_vela_model: Path) -> Path: - """Provide path to Vela-optimised dummy model file.""" - return test_tflite_vela_model - - -@pytest.fixture(scope="session") -def invalid_input_model_file(test_tflite_invalid_model: Path) -> Path: - """Provide the path to an invalid dummy model file.""" - return test_tflite_invalid_model - - -@pytest.fixture(autouse=True) -def test_resources(monkeypatch: pytest.MonkeyPatch, test_resources_path: Path) -> Any: - """Force using test resources as middleware's repository.""" - - def get_test_resources() -> Path: - """Return path to the test resources.""" - return test_resources_path - - monkeypatch.setattr("aiet.utils.fs.get_aiet_resources", get_test_resources) - yield - - -@pytest.fixture(scope="session", autouse=True) -def add_tools(test_resources_path: Path) -> Any: - """Symlink the tools from the original resources path to the test resources path.""" - # 
tool_dirs = get_available_tool_directory_names() - tool_dirs = [cfg.parent for cfg in get_backend_configs("tools")] - - links = { - src_dir: (test_resources_path / "tools" / src_dir.name) for src_dir in tool_dirs - } - for src_dir, dst_dir in links.items(): - if not dst_dir.exists(): - dst_dir.symlink_to(src_dir, target_is_directory=True) - yield - # Remove symlinks - for dst_dir in links.values(): - if dst_dir.is_symlink(): - dst_dir.unlink() - - -def create_archive( - archive_name: str, source: Path, destination: Path, with_root_folder: bool = False -) -> None: - """Create archive from directory source.""" - with tarfile.open(destination / archive_name, mode="w:gz") as tar: - for item in source.iterdir(): - item_name = item.name - if with_root_folder: - item_name = f"{source.name}/{item_name}" - tar.add(item, item_name) - - -def process_directory(source: Path, destination: Path) -> None: - """Process resource directory.""" - destination.mkdir() - - for item in source.iterdir(): - if item.is_dir(): - create_archive(f"{item.name}.tar.gz", item, destination) - create_archive(f"{item.name}_dir.tar.gz", item, destination, True) - - -@pytest.fixture(scope="session", autouse=True) -def add_archives( - test_resources_path: Path, tmp_path_factory: pytest.TempPathFactory -) -> Any: - """Generate archives of the test resources.""" - tmp_path = tmp_path_factory.mktemp("archives") - - archives_path = tmp_path / "archives" - archives_path.mkdir() - - if (archives_path_link := test_resources_path / "archives").is_symlink(): - archives_path.unlink() - - archives_path_link.symlink_to(archives_path, target_is_directory=True) - - for item in ["applications", "systems"]: - process_directory(test_resources_path / item, archives_path / item) - - yield - - archives_path_link.unlink() - shutil.rmtree(tmp_path) - - -@pytest.fixture(scope="module") -def cli_runner() -> CliRunner: - """Return CliRunner instance in a pytest fixture.""" - return CliRunner() diff --git a/tests/aiet/test_backend_application.py b/tests/aiet/test_backend_application.py deleted file mode 100644 index abfab00..0000000 --- a/tests/aiet/test_backend_application.py +++ /dev/null @@ -1,452 +0,0 @@ -# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. 
-# SPDX-License-Identifier: Apache-2.0 -# pylint: disable=no-self-use -"""Tests for the application backend.""" -from collections import Counter -from contextlib import ExitStack as does_not_raise -from pathlib import Path -from typing import Any -from typing import List -from unittest.mock import MagicMock - -import pytest - -from aiet.backend.application import Application -from aiet.backend.application import get_application -from aiet.backend.application import get_available_application_directory_names -from aiet.backend.application import get_available_applications -from aiet.backend.application import get_unique_application_names -from aiet.backend.application import install_application -from aiet.backend.application import load_applications -from aiet.backend.application import remove_application -from aiet.backend.common import Command -from aiet.backend.common import DataPaths -from aiet.backend.common import Param -from aiet.backend.common import UserParamConfig -from aiet.backend.config import ApplicationConfig -from aiet.backend.config import ExtendedApplicationConfig -from aiet.backend.config import NamedExecutionConfig - - -def test_get_available_application_directory_names() -> None: - """Test get_available_applicationss mocking get_resources.""" - directory_names = get_available_application_directory_names() - assert Counter(directory_names) == Counter( - ["application1", "application2", "application4", "application5"] - ) - - -def test_get_available_applications() -> None: - """Test get_available_applicationss mocking get_resources.""" - available_applications = get_available_applications() - - assert all(isinstance(s, Application) for s in available_applications) - assert all(s != 42 for s in available_applications) - assert len(available_applications) == 9 - # application_5 has multiply items with multiply supported systems - assert [str(s) for s in available_applications] == [ - "application_1", - "application_2", - "application_4", - "application_5", - "application_5", - "application_5A", - "application_5A", - "application_5B", - "application_5B", - ] - - -def test_get_unique_application_names() -> None: - """Test get_unique_application_names.""" - unique_names = get_unique_application_names() - - assert all(isinstance(s, str) for s in unique_names) - assert all(s for s in unique_names) - assert sorted(unique_names) == [ - "application_1", - "application_2", - "application_4", - "application_5", - "application_5A", - "application_5B", - ] - - -def test_get_application() -> None: - """Test get_application mocking get_resoures.""" - application = get_application("application_1") - if len(application) != 1: - pytest.fail("Unable to get application") - assert application[0].name == "application_1" - - application = get_application("unknown application") - assert len(application) == 0 - - -@pytest.mark.parametrize( - "source, call_count, expected_exception", - ( - ( - "archives/applications/application1.tar.gz", - 0, - pytest.raises( - Exception, match=r"Applications \[application_1\] are already installed" - ), - ), - ( - "various/applications/application_with_empty_config", - 0, - pytest.raises(Exception, match="No application definition found"), - ), - ( - "various/applications/application_with_wrong_config1", - 0, - pytest.raises(Exception, match="Unable to read application definition"), - ), - ( - "various/applications/application_with_wrong_config2", - 0, - pytest.raises(Exception, match="Unable to read application definition"), - ), - ( - 
"various/applications/application_with_wrong_config3", - 0, - pytest.raises(Exception, match="Unable to read application definition"), - ), - ("various/applications/application_with_valid_config", 1, does_not_raise()), - ( - "archives/applications/application3.tar.gz", - 0, - pytest.raises(Exception, match="Unable to read application definition"), - ), - ( - "applications/application1", - 0, - pytest.raises( - Exception, match=r"Applications \[application_1\] are already installed" - ), - ), - ( - "applications/application3", - 0, - pytest.raises(Exception, match="Unable to read application definition"), - ), - ), -) -def test_install_application( - monkeypatch: Any, - test_resources_path: Path, - source: str, - call_count: int, - expected_exception: Any, -) -> None: - """Test application install from archive.""" - mock_create_destination_and_install = MagicMock() - monkeypatch.setattr( - "aiet.backend.application.create_destination_and_install", - mock_create_destination_and_install, - ) - - with expected_exception: - install_application(test_resources_path / source) - assert mock_create_destination_and_install.call_count == call_count - - -def test_remove_application(monkeypatch: Any) -> None: - """Test application removal.""" - mock_remove_backend = MagicMock() - monkeypatch.setattr("aiet.backend.application.remove_backend", mock_remove_backend) - - remove_application("some_application_directory") - mock_remove_backend.assert_called_once() - - -def test_application_config_without_commands() -> None: - """Test application config without commands.""" - config = ApplicationConfig(name="application") - application = Application(config) - # pylint: disable=use-implicit-booleaness-not-comparison - assert application.commands == {} - - -class TestApplication: - """Test for application class methods.""" - - def test___eq__(self) -> None: - """Test overloaded __eq__ method.""" - config = ApplicationConfig( - # Application - supported_systems=["system1", "system2"], - build_dir="build_dir", - # inherited from Backend - name="name", - description="description", - commands={}, - ) - application1 = Application(config) - application2 = Application(config) # Identical - assert application1 == application2 - - application3 = Application(config) # changed - # Change one single attribute so not equal, but same Type - setattr(application3, "supported_systems", ["somewhere/else"]) - assert application1 != application3 - - # different Type - application4 = "Not the Application you are looking for" - assert application1 != application4 - - application5 = Application(config) - # supported systems could be in any order - setattr(application5, "supported_systems", ["system2", "system1"]) - assert application1 == application5 - - def test_can_run_on(self) -> None: - """Test Application can run on.""" - config = ApplicationConfig(name="application", supported_systems=["System-A"]) - - application = Application(config) - assert application.can_run_on("System-A") - assert not application.can_run_on("System-B") - - applications = get_application("application_1", "System 1") - assert len(applications) == 1 - assert applications[0].can_run_on("System 1") - - def test_get_deploy_data(self, tmp_path: Path) -> None: - """Test Application can run on.""" - src, dest = "src", "dest" - config = ApplicationConfig( - name="application", deploy_data=[(src, dest)], config_location=tmp_path - ) - src_path = tmp_path / src - src_path.mkdir() - application = Application(config) - assert application.get_deploy_data() == 
[DataPaths(src_path, dest)] - - def test_get_deploy_data_no_config_location(self) -> None: - """Test that getting deploy data fails if no config location provided.""" - with pytest.raises( - Exception, match="Unable to get application .* config location" - ): - Application(ApplicationConfig(name="application")).get_deploy_data() - - def test_unable_to_create_application_without_name(self) -> None: - """Test that it is not possible to create application without name.""" - with pytest.raises(Exception, match="Name is empty"): - Application(ApplicationConfig()) - - def test_application_config_without_commands(self) -> None: - """Test application config without commands.""" - config = ApplicationConfig(name="application") - application = Application(config) - # pylint: disable=use-implicit-booleaness-not-comparison - assert application.commands == {} - - @pytest.mark.parametrize( - "config, expected_params", - ( - ( - ApplicationConfig( - name="application", - commands={"command": ["cmd {user_params:0} {user_params:1}"]}, - user_params={ - "command": [ - UserParamConfig( - name="--param1", description="param1", alias="param1" - ), - UserParamConfig( - name="--param2", description="param2", alias="param2" - ), - ] - }, - ), - [Param("--param1", "param1"), Param("--param2", "param2")], - ), - ( - ApplicationConfig( - name="application", - commands={"command": ["cmd {user_params:param1} {user_params:1}"]}, - user_params={ - "command": [ - UserParamConfig( - name="--param1", description="param1", alias="param1" - ), - UserParamConfig( - name="--param2", description="param2", alias="param2" - ), - ] - }, - ), - [Param("--param1", "param1"), Param("--param2", "param2")], - ), - ( - ApplicationConfig( - name="application", - commands={"command": ["cmd {user_params:param1}"]}, - user_params={ - "command": [ - UserParamConfig( - name="--param1", description="param1", alias="param1" - ), - UserParamConfig( - name="--param2", description="param2", alias="param2" - ), - ] - }, - ), - [Param("--param1", "param1")], - ), - ), - ) - def test_remove_unused_params( - self, config: ApplicationConfig, expected_params: List[Param] - ) -> None: - """Test mod remove_unused_parameter.""" - application = Application(config) - application.remove_unused_params() - assert application.commands["command"].params == expected_params - - -@pytest.mark.parametrize( - "config, expected_error", - ( - ( - ExtendedApplicationConfig(name="application"), - pytest.raises(Exception, match="No supported systems definition provided"), - ), - ( - ExtendedApplicationConfig( - name="application", supported_systems=[NamedExecutionConfig(name="")] - ), - pytest.raises( - Exception, - match="Unable to read supported system definition, name is missed", - ), - ), - ( - ExtendedApplicationConfig( - name="application", - supported_systems=[ - NamedExecutionConfig( - name="system", - commands={"command": ["cmd"]}, - user_params={"command": [UserParamConfig(name="param")]}, - ) - ], - commands={"command": ["cmd {user_params:0}"]}, - user_params={"command": [UserParamConfig(name="param")]}, - ), - pytest.raises( - Exception, match="Default parameters for command .* should have aliases" - ), - ), - ( - ExtendedApplicationConfig( - name="application", - supported_systems=[ - NamedExecutionConfig( - name="system", - commands={"command": ["cmd"]}, - user_params={"command": [UserParamConfig(name="param")]}, - ) - ], - commands={"command": ["cmd {user_params:0}"]}, - user_params={"command": [UserParamConfig(name="param", alias="param")]}, - ), - 
pytest.raises( - Exception, match="system parameters for command .* should have aliases" - ), - ), - ), -) -def test_load_application_exceptional_cases( - config: ExtendedApplicationConfig, expected_error: Any -) -> None: - """Test exceptional cases for application load function.""" - with expected_error: - load_applications(config) - - -def test_load_application() -> None: - """Test application load function. - - The main purpose of this test is to test configuration for application - for different systems. All configuration should be correctly - overridden if needed. - """ - application_5 = get_application("application_5") - assert len(application_5) == 2 - - default_commands = { - "build": Command(["default build command"]), - "run": Command(["default run command"]), - } - default_variables = {"var1": "value1", "var2": "value2"} - - application_5_0 = application_5[0] - assert application_5_0.build_dir == "default_build_dir" - assert application_5_0.supported_systems == ["System 1"] - assert application_5_0.commands == default_commands - assert application_5_0.variables == default_variables - assert application_5_0.lock is False - - application_5_1 = application_5[1] - assert application_5_1.build_dir == application_5_0.build_dir - assert application_5_1.supported_systems == ["System 2"] - assert application_5_1.commands == application_5_1.commands - assert application_5_1.variables == default_variables - - application_5a = get_application("application_5A") - assert len(application_5a) == 2 - - application_5a_0 = application_5a[0] - assert application_5a_0.supported_systems == ["System 1"] - assert application_5a_0.build_dir == "build_5A" - assert application_5a_0.commands == default_commands - assert application_5a_0.variables == {"var1": "new value1", "var2": "value2"} - assert application_5a_0.lock is False - - application_5a_1 = application_5a[1] - assert application_5a_1.supported_systems == ["System 2"] - assert application_5a_1.build_dir == "build" - assert application_5a_1.commands == { - "build": Command(["default build command"]), - "run": Command(["run command on system 2"]), - } - assert application_5a_1.variables == {"var1": "value1", "var2": "new value2"} - assert application_5a_1.lock is True - - application_5b = get_application("application_5B") - assert len(application_5b) == 2 - - application_5b_0 = application_5b[0] - assert application_5b_0.build_dir == "build_5B" - assert application_5b_0.supported_systems == ["System 1"] - assert application_5b_0.commands == { - "build": Command(["default build command with value for var1 System1"], []), - "run": Command(["default run command with value for var2 System1"]), - } - assert "non_used_command" not in application_5b_0.commands - - application_5b_1 = application_5b[1] - assert application_5b_1.build_dir == "build" - assert application_5b_1.supported_systems == ["System 2"] - assert application_5b_1.commands == { - "build": Command( - [ - "build command on system 2 with value" - " for var1 System2 {user_params:param1}" - ], - [ - Param( - "--param", - "Sample command param", - ["value1", "value2", "value3"], - "value1", - ) - ], - ), - "run": Command(["run command on system 2"], []), - } diff --git a/tests/aiet/test_backend_common.py b/tests/aiet/test_backend_common.py deleted file mode 100644 index 12c30ec..0000000 --- a/tests/aiet/test_backend_common.py +++ /dev/null @@ -1,486 +0,0 @@ -# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. 
-# SPDX-License-Identifier: Apache-2.0 -# pylint: disable=no-self-use,protected-access -"""Tests for the common backend module.""" -from contextlib import ExitStack as does_not_raise -from pathlib import Path -from typing import Any -from typing import cast -from typing import Dict -from typing import IO -from typing import List -from typing import Optional -from typing import Tuple -from typing import Union -from unittest.mock import MagicMock - -import pytest - -from aiet.backend.application import Application -from aiet.backend.common import Backend -from aiet.backend.common import BaseBackendConfig -from aiet.backend.common import Command -from aiet.backend.common import ConfigurationException -from aiet.backend.common import load_config -from aiet.backend.common import Param -from aiet.backend.common import parse_raw_parameter -from aiet.backend.common import remove_backend -from aiet.backend.config import ApplicationConfig -from aiet.backend.config import UserParamConfig -from aiet.backend.execution import ExecutionContext -from aiet.backend.execution import ParamResolver -from aiet.backend.system import System - - -@pytest.mark.parametrize( - "directory_name, expected_exception", - ( - ("some_dir", does_not_raise()), - (None, pytest.raises(Exception, match="No directory name provided")), - ), -) -def test_remove_backend( - monkeypatch: Any, directory_name: str, expected_exception: Any -) -> None: - """Test remove_backend function.""" - mock_remove_resource = MagicMock() - monkeypatch.setattr("aiet.backend.common.remove_resource", mock_remove_resource) - - with expected_exception: - remove_backend(directory_name, "applications") - - -@pytest.mark.parametrize( - "filename, expected_exception", - ( - ("application_config.json", does_not_raise()), - (None, pytest.raises(Exception, match="Unable to read config")), - ), -) -def test_load_config( - filename: str, expected_exception: Any, test_resources_path: Path, monkeypatch: Any -) -> None: - """Test load_config.""" - with expected_exception: - configs: List[Optional[Union[Path, IO[bytes]]]] = ( - [None] - if not filename - else [ - # Ignore pylint warning as 'with' can't be used inside of a - # generator expression. 
- # pylint: disable=consider-using-with - open(test_resources_path / filename, "rb"), - test_resources_path / filename, - ] - ) - for config in configs: - json_mock = MagicMock() - monkeypatch.setattr("aiet.backend.common.json.load", json_mock) - load_config(config) - json_mock.assert_called_once() - - -class TestBackend: - """Test Backend class.""" - - def test___repr__(self) -> None: - """Test the representation of Backend instance.""" - backend = Backend( - BaseBackendConfig(name="Testing name", description="Testing description") - ) - assert str(backend) == "Testing name" - - def test__eq__(self) -> None: - """Test equality method with different cases.""" - backend1 = Backend(BaseBackendConfig(name="name", description="description")) - backend1.commands = {"command": Command(["command"])} - - backend2 = Backend(BaseBackendConfig(name="name", description="description")) - backend2.commands = {"command": Command(["command"])} - - backend3 = Backend( - BaseBackendConfig( - name="Ben", description="This is not the Backend you are looking for" - ) - ) - backend3.commands = {"wave": Command(["wave hand"])} - - backend4 = "Foo" # checking not isinstance(backend4, Backend) - - assert backend1 == backend2 - assert backend1 != backend3 - assert backend1 != backend4 - - @pytest.mark.parametrize( - "parameter, valid", - [ - ("--choice-param dummy_value_1", True), - ("--choice-param wrong_value", False), - ("--open-param something", True), - ("--wrong-param value", False), - ], - ) - def test_validate_parameter( - self, parameter: str, valid: bool, test_resources_path: Path - ) -> None: - """Test validate_parameter.""" - config = cast( - List[ApplicationConfig], - load_config(test_resources_path / "hello_world.json"), - ) - # The application configuration is a list of configurations so we need - # only the first one - # Exercise the validate_parameter test using the Application classe which - # inherits from Backend. 
- application = Application(config[0]) - assert application.validate_parameter("run", parameter) == valid - - def test_validate_parameter_with_invalid_command( - self, test_resources_path: Path - ) -> None: - """Test validate_parameter with an invalid command_name.""" - config = cast( - List[ApplicationConfig], - load_config(test_resources_path / "hello_world.json"), - ) - application = Application(config[0]) - with pytest.raises(AttributeError) as err: - # command foo does not exist, so raise an error - application.validate_parameter("foo", "bar") - assert "Unknown command: 'foo'" in str(err.value) - - def test_build_command(self, monkeypatch: Any) -> None: - """Test command building.""" - config = { - "name": "test", - "commands": { - "build": ["build {user_params:0} {user_params:1}"], - "run": ["run {user_params:0}"], - "post_run": ["post_run {application_params:0} on {system_params:0}"], - "some_command": ["Command with {variables:var_A}"], - "empty_command": [""], - }, - "user_params": { - "build": [ - { - "name": "choice_param_0=", - "values": [1, 2, 3], - "default_value": 1, - }, - {"name": "choice_param_1", "values": [3, 4, 5], "default_value": 3}, - {"name": "choice_param_3", "values": [6, 7, 8]}, - ], - "run": [{"name": "flag_param_0"}], - }, - "variables": {"var_A": "value for variable A"}, - } - - monkeypatch.setattr("aiet.backend.system.ProtocolFactory", MagicMock()) - application, system = Application(config), System(config) # type: ignore - context = ExecutionContext( - app=application, - app_params=[], - system=system, - system_params=[], - custom_deploy_data=[], - ) - - param_resolver = ParamResolver(context) - - cmd = application.build_command( - "build", ["choice_param_0=2", "choice_param_1=4"], param_resolver - ) - assert cmd == ["build choice_param_0=2 choice_param_1 4"] - - cmd = application.build_command("build", ["choice_param_0=2"], param_resolver) - assert cmd == ["build choice_param_0=2 choice_param_1 3"] - - cmd = application.build_command( - "build", ["choice_param_0=2", "choice_param_3=7"], param_resolver - ) - assert cmd == ["build choice_param_0=2 choice_param_1 3"] - - with pytest.raises( - ConfigurationException, match="Command 'foo' could not be found." 
- ): - application.build_command("foo", [""], param_resolver) - - cmd = application.build_command("some_command", [], param_resolver) - assert cmd == ["Command with value for variable A"] - - cmd = application.build_command("empty_command", [], param_resolver) - assert cmd == [""] - - @pytest.mark.parametrize("class_", [Application, System]) - def test_build_command_unknown_variable(self, class_: type) -> None: - """Test that unable to construct backend with unknown variable.""" - with pytest.raises(Exception, match="Unknown variable var1"): - config = {"name": "test", "commands": {"run": ["run {variables:var1}"]}} - class_(config) - - @pytest.mark.parametrize( - "class_, config, expected_output", - [ - ( - Application, - { - "name": "test", - "commands": { - "build": ["build {user_params:0} {user_params:1}"], - "run": ["run {user_params:0}"], - }, - "user_params": { - "build": [ - { - "name": "choice_param_0=", - "values": ["a", "b", "c"], - "default_value": "a", - "alias": "param_1", - }, - { - "name": "choice_param_1", - "values": ["a", "b", "c"], - "default_value": "a", - "alias": "param_2", - }, - {"name": "choice_param_3", "values": ["a", "b", "c"]}, - ], - "run": [{"name": "flag_param_0"}], - }, - }, - [ - ( - "b", - Param( - name="choice_param_0=", - description="", - values=["a", "b", "c"], - default_value="a", - alias="param_1", - ), - ), - ( - "a", - Param( - name="choice_param_1", - description="", - values=["a", "b", "c"], - default_value="a", - alias="param_2", - ), - ), - ( - "c", - Param( - name="choice_param_3", - description="", - values=["a", "b", "c"], - ), - ), - ], - ), - (System, {"name": "test"}, []), - ], - ) - def test_resolved_parameters( - self, - monkeypatch: Any, - class_: type, - config: Dict, - expected_output: List[Tuple[Optional[str], Param]], - ) -> None: - """Test command building.""" - monkeypatch.setattr("aiet.backend.system.ProtocolFactory", MagicMock()) - backend = class_(config) - - params = backend.resolved_parameters( - "build", ["choice_param_0=b", "choice_param_3=c"] - ) - assert params == expected_output - - @pytest.mark.parametrize( - ["param_name", "user_param", "expected_value"], - [ - ( - "test_name", - "test_name=1234", - "1234", - ), # optional parameter using '=' - ( - "test_name", - "test_name 1234", - "1234", - ), # optional parameter using ' ' - ("test_name", "test_name", None), # flag - (None, "test_name=1234", "1234"), # positional parameter - ], - ) - def test_resolved_user_parameters( - self, param_name: str, user_param: str, expected_value: str - ) -> None: - """Test different variants to provide user parameters.""" - # A dummy config providing one backend config - config = { - "name": "test_backend", - "commands": { - "test": ["user_param:test_param"], - }, - "user_params": { - "test": [UserParamConfig(name=param_name, alias="test_name")], - }, - } - backend = Backend(cast(BaseBackendConfig, config)) - params = backend.resolved_parameters( - command_name="test", user_params=[user_param] - ) - assert len(params) == 1 - value, param = params[0] - assert param_name == param.name - assert expected_value == value - - @pytest.mark.parametrize( - "input_param,expected", - [ - ("--param=1", ("--param", "1")), - ("--param 1", ("--param", "1")), - ("--flag", ("--flag", None)), - ], - ) - def test__parse_raw_parameter( - self, input_param: str, expected: Tuple[str, Optional[str]] - ) -> None: - """Test internal method of parsing a single raw parameter.""" - assert parse_raw_parameter(input_param) == expected - - -class TestParam: - """Test 
Param class.""" - - def test__eq__(self) -> None: - """Test equality method with different cases.""" - param1 = Param(name="test", description="desc", values=["values"]) - param2 = Param(name="test", description="desc", values=["values"]) - param3 = Param(name="test1", description="desc", values=["values"]) - param4 = object() - - assert param1 == param2 - assert param1 != param3 - assert param1 != param4 - - def test_get_details(self) -> None: - """Test get_details() method.""" - param1 = Param(name="test", description="desc", values=["values"]) - assert param1.get_details() == { - "name": "test", - "values": ["values"], - "description": "desc", - } - - def test_invalid(self) -> None: - """Test invalid use cases for the Param class.""" - with pytest.raises( - ConfigurationException, - match="Either name, alias or both must be set to identify a parameter.", - ): - Param(name=None, description="desc", values=["values"]) - - -class TestCommand: - """Test Command class.""" - - def test_get_details(self) -> None: - """Test get_details() method.""" - param1 = Param(name="test", description="desc", values=["values"]) - command1 = Command(command_strings=["echo test"], params=[param1]) - assert command1.get_details() == { - "command_strings": ["echo test"], - "user_params": [ - {"name": "test", "values": ["values"], "description": "desc"} - ], - } - - def test__eq__(self) -> None: - """Test equality method with different cases.""" - param1 = Param("test", "desc", ["values"]) - param2 = Param("test1", "desc1", ["values1"]) - command1 = Command(command_strings=["echo test"], params=[param1]) - command2 = Command(command_strings=["echo test"], params=[param1]) - command3 = Command(command_strings=["echo test"]) - command4 = Command(command_strings=["echo test"], params=[param2]) - command5 = object() - - assert command1 == command2 - assert command1 != command3 - assert command1 != command4 - assert command1 != command5 - - @pytest.mark.parametrize( - "params, expected_error", - [ - [[], does_not_raise()], - [[Param("param", "param description", [])], does_not_raise()], - [ - [ - Param("param", "param description", [], None, "alias"), - Param("param", "param description", [], None), - ], - does_not_raise(), - ], - [ - [ - Param("param1", "param1 description", [], None, "alias1"), - Param("param2", "param2 description", [], None, "alias2"), - ], - does_not_raise(), - ], - [ - [ - Param("param", "param description", [], None, "alias"), - Param("param", "param description", [], None, "alias"), - ], - pytest.raises(ConfigurationException, match="Non unique aliases alias"), - ], - [ - [ - Param("alias", "param description", [], None, "alias1"), - Param("param", "param description", [], None, "alias"), - ], - pytest.raises( - ConfigurationException, - match="Aliases .* could not be used as parameter name", - ), - ], - [ - [ - Param("alias", "param description", [], None, "alias"), - Param("param1", "param1 description", [], None, "alias1"), - ], - does_not_raise(), - ], - [ - [ - Param("alias", "param description", [], None, "alias"), - Param("alias", "param1 description", [], None, "alias1"), - ], - pytest.raises( - ConfigurationException, - match="Aliases .* could not be used as parameter name", - ), - ], - [ - [ - Param("param1", "param1 description", [], None, "alias1"), - Param("param2", "param2 description", [], None, "alias1"), - Param("param3", "param3 description", [], None, "alias2"), - Param("param4", "param4 description", [], None, "alias2"), - ], - pytest.raises( - ConfigurationException, 
match="Non unique aliases alias1, alias2" - ), - ], - ], - ) - def test_validate_params(self, params: List[Param], expected_error: Any) -> None: - """Test command validation function.""" - with expected_error: - Command([], params) diff --git a/tests/aiet/test_backend_controller.py b/tests/aiet/test_backend_controller.py deleted file mode 100644 index 8836ec5..0000000 --- a/tests/aiet/test_backend_controller.py +++ /dev/null @@ -1,160 +0,0 @@ -# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. -# SPDX-License-Identifier: Apache-2.0 -"""Tests for system controller.""" -import csv -import os -import time -from pathlib import Path -from typing import Any - -import psutil -import pytest - -from aiet.backend.common import ConfigurationException -from aiet.backend.controller import SystemController -from aiet.backend.controller import SystemControllerSingleInstance -from aiet.utils.proc import ShellCommand - - -def get_system_controller(**kwargs: Any) -> SystemController: - """Get service controller.""" - single_instance = kwargs.get("single_instance", False) - if single_instance: - pid_file_path = kwargs.get("pid_file_path") - return SystemControllerSingleInstance(pid_file_path) - - return SystemController() - - -def test_service_controller() -> None: - """Test service controller functionality.""" - service_controller = get_system_controller() - - assert service_controller.get_output() == ("", "") - with pytest.raises(ConfigurationException, match="Wrong working directory"): - service_controller.start(["sleep 100"], Path("unknown")) - - service_controller.start(["sleep 100"], Path.cwd()) - assert service_controller.is_running() - - service_controller.stop(True) - assert not service_controller.is_running() - assert service_controller.get_output() == ("", "") - - service_controller.stop() - - with pytest.raises( - ConfigurationException, match="System should have only one command to run" - ): - service_controller.start(["sleep 100", "sleep 101"], Path.cwd()) - - with pytest.raises(ConfigurationException, match="No startup command provided"): - service_controller.start([""], Path.cwd()) - - -def test_service_controller_bad_configuration() -> None: - """Test service controller functionality for bad configuration.""" - with pytest.raises(Exception, match="No pid file path presented"): - service_controller = get_system_controller( - single_instance=True, pid_file_path=None - ) - service_controller.start(["sleep 100"], Path.cwd()) - - -def test_service_controller_writes_process_info_correctly(tmpdir: Any) -> None: - """Test that controller writes process info correctly.""" - pid_file = Path(tmpdir) / "test.pid" - - service_controller = get_system_controller( - single_instance=True, pid_file_path=Path(tmpdir) / "test.pid" - ) - - service_controller.start(["sleep 100"], Path.cwd()) - assert service_controller.is_running() - assert pid_file.is_file() - - with open(pid_file, "r", encoding="utf-8") as file: - csv_reader = csv.reader(file) - rows = list(csv_reader) - assert len(rows) == 1 - - name, *_ = rows[0] - assert name == "sleep" - - service_controller.stop() - assert pid_file.exists() - - -def test_service_controller_does_not_write_process_info_if_process_finishes( - tmpdir: Any, -) -> None: - """Test that controller does not write process info if process already finished.""" - pid_file = Path(tmpdir) / "test.pid" - service_controller = get_system_controller( - single_instance=True, pid_file_path=pid_file - ) - service_controller.is_running = lambda: False # type: ignore - 
service_controller.start(["echo hello"], Path.cwd()) - - assert not pid_file.exists() - - -def test_service_controller_searches_for_previous_instances_correctly( - tmpdir: Any, -) -> None: - """Test that controller searches for previous instances correctly.""" - pid_file = Path(tmpdir) / "test.pid" - command = ShellCommand().run("sleep", "100") - assert command.is_alive() - - pid = command.process.pid - process = psutil.Process(pid) - with open(pid_file, "w", encoding="utf-8") as file: - csv_writer = csv.writer(file) - csv_writer.writerow(("some_process", "some_program", "some_cwd", os.getpid())) - csv_writer.writerow((process.name(), process.exe(), process.cwd(), process.pid)) - csv_writer.writerow(("some_old_process", "not_running", "from_nowhere", 77777)) - - service_controller = get_system_controller( - single_instance=True, pid_file_path=pid_file - ) - service_controller.start(["sleep 100"], Path.cwd()) - # controller should stop this process as it is currently running and - # mentioned in pid file - assert not command.is_alive() - - service_controller.stop() - - -@pytest.mark.parametrize( - "executable", ["test_backend_run_script.sh", "test_backend_run"] -) -def test_service_controller_run_shell_script( - executable: str, test_resources_path: Path -) -> None: - """Test controller's ability to run shell scripts.""" - script_path = test_resources_path / "scripts" - - service_controller = get_system_controller() - - service_controller.start([executable], script_path) - - assert service_controller.is_running() - # give time for the command to produce output - time.sleep(2) - service_controller.stop(wait=True) - assert not service_controller.is_running() - stdout, stderr = service_controller.get_output() - assert stdout == "Hello from script\n" - assert stderr == "Oops!\n" - - -def test_service_controller_does_nothing_if_not_started(tmpdir: Any) -> None: - """Test that nothing happened if controller is not started.""" - service_controller = get_system_controller( - single_instance=True, pid_file_path=Path(tmpdir) / "test.pid" - ) - - assert not service_controller.is_running() - service_controller.stop() - assert not service_controller.is_running() diff --git a/tests/aiet/test_backend_execution.py b/tests/aiet/test_backend_execution.py deleted file mode 100644 index 8aa45f1..0000000 --- a/tests/aiet/test_backend_execution.py +++ /dev/null @@ -1,526 +0,0 @@ -# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. 
-# SPDX-License-Identifier: Apache-2.0 -# pylint: disable=no-self-use -"""Test backend context module.""" -from contextlib import ExitStack as does_not_raise -from pathlib import Path -from typing import Any -from typing import Dict -from typing import Optional -from unittest import mock -from unittest.mock import MagicMock - -import pytest -from sh import CommandNotFound - -from aiet.backend.application import Application -from aiet.backend.application import get_application -from aiet.backend.common import ConfigurationException -from aiet.backend.common import DataPaths -from aiet.backend.common import UserParamConfig -from aiet.backend.config import ApplicationConfig -from aiet.backend.config import LocalProtocolConfig -from aiet.backend.config import SystemConfig -from aiet.backend.execution import deploy_data -from aiet.backend.execution import execute_commands_locally -from aiet.backend.execution import ExecutionContext -from aiet.backend.execution import get_application_and_system -from aiet.backend.execution import get_application_by_name_and_system -from aiet.backend.execution import get_file_lock_path -from aiet.backend.execution import get_tool_by_system -from aiet.backend.execution import ParamResolver -from aiet.backend.execution import Reporter -from aiet.backend.execution import wait -from aiet.backend.output_parser import OutputParser -from aiet.backend.system import get_system -from aiet.backend.system import load_system -from aiet.backend.tool import get_tool -from aiet.utils.proc import CommandFailedException - - -def test_context_param_resolver(tmpdir: Any) -> None: - """Test parameter resolving.""" - system_config_location = Path(tmpdir) / "system" - system_config_location.mkdir() - - application_config_location = Path(tmpdir) / "application" - application_config_location.mkdir() - - ctx = ExecutionContext( - app=Application( - ApplicationConfig( - name="test_application", - description="Test application", - config_location=application_config_location, - build_dir="build-{application.name}-{system.name}", - commands={ - "run": [ - "run_command1 {user_params:0}", - "run_command2 {user_params:1}", - ] - }, - variables={"var_1": "value for var_1"}, - user_params={ - "run": [ - UserParamConfig( - name="--param1", - description="Param 1", - default_value="123", - alias="param_1", - ), - UserParamConfig( - name="--param2", description="Param 2", default_value="456" - ), - UserParamConfig( - name="--param3", description="Param 3", alias="param_3" - ), - UserParamConfig( - name="--param4=", - description="Param 4", - default_value="456", - alias="param_4", - ), - UserParamConfig( - description="Param 5", - default_value="789", - alias="param_5", - ), - ] - }, - ) - ), - app_params=["--param2=789"], - system=load_system( - SystemConfig( - name="test_system", - description="Test system", - config_location=system_config_location, - build_dir="build", - data_transfer=LocalProtocolConfig(protocol="local"), - commands={ - "build": ["build_command1 {user_params:0}"], - "run": ["run_command {application.commands.run:1}"], - }, - variables={"var_1": "value for var_1"}, - user_params={ - "build": [ - UserParamConfig( - name="--param1", description="Param 1", default_value="aaa" - ), - UserParamConfig(name="--param2", description="Param 2"), - ] - }, - ) - ), - system_params=["--param1=bbb"], - custom_deploy_data=[], - ) - - param_resolver = ParamResolver(ctx) - expected_values = { - "application.name": "test_application", - "application.description": "Test application", - 
"application.config_dir": str(application_config_location), - "application.build_dir": "{}/build-test_application-test_system".format( - application_config_location - ), - "application.commands.run:0": "run_command1 --param1 123", - "application.commands.run.params:0": "123", - "application.commands.run.params:param_1": "123", - "application.commands.run:1": "run_command2 --param2 789", - "application.commands.run.params:1": "789", - "application.variables:var_1": "value for var_1", - "system.name": "test_system", - "system.description": "Test system", - "system.config_dir": str(system_config_location), - "system.commands.build:0": "build_command1 --param1 bbb", - "system.commands.run:0": "run_command run_command2 --param2 789", - "system.commands.build.params:0": "bbb", - "system.variables:var_1": "value for var_1", - } - - for param, value in expected_values.items(): - assert param_resolver(param) == value - - assert ctx.build_dir() == Path( - "{}/build-test_application-test_system".format(application_config_location) - ) - - expected_errors = { - "application.variables:var_2": pytest.raises( - Exception, match="Unknown variable var_2" - ), - "application.commands.clean:0": pytest.raises( - Exception, match="Command clean not found" - ), - "application.commands.run:2": pytest.raises( - Exception, match="Invalid index 2 for command run" - ), - "application.commands.run.params:5": pytest.raises( - Exception, match="Invalid parameter index 5 for command run" - ), - "application.commands.run.params:param_2": pytest.raises( - Exception, - match="No value for parameter with index or alias param_2 of command run", - ), - "UNKNOWN": pytest.raises( - Exception, match="Unable to resolve parameter UNKNOWN" - ), - "system.commands.build.params:1": pytest.raises( - Exception, - match="No value for parameter with index or alias 1 of command build", - ), - "system.commands.build:A": pytest.raises( - Exception, match="Bad command index A" - ), - "system.variables:var_2": pytest.raises( - Exception, match="Unknown variable var_2" - ), - } - for param, error in expected_errors.items(): - with error: - param_resolver(param) - - resolved_params = ctx.app.resolved_parameters("run", []) - expected_user_params = { - "user_params:0": "--param1 123", - "user_params:param_1": "--param1 123", - "user_params:2": "--param3", - "user_params:param_3": "--param3", - "user_params:3": "--param4=456", - "user_params:param_4": "--param4=456", - "user_params:param_5": "789", - } - for param, expected_value in expected_user_params.items(): - assert param_resolver(param, "run", resolved_params) == expected_value - - with pytest.raises( - Exception, match="Invalid index 5 for user params of command run" - ): - param_resolver("user_params:5", "run", resolved_params) - - with pytest.raises( - Exception, match="No user parameter for command 'run' with alias 'param_2'." 
- ): - param_resolver("user_params:param_2", "run", resolved_params) - - with pytest.raises(Exception, match="Unable to resolve user params"): - param_resolver("user_params:0", "", resolved_params) - - bad_ctx = ExecutionContext( - app=Application( - ApplicationConfig( - name="test_application", - config_location=application_config_location, - build_dir="build-{user_params:0}", - ) - ), - app_params=["--param2=789"], - system=load_system( - SystemConfig( - name="test_system", - description="Test system", - config_location=system_config_location, - build_dir="build-{system.commands.run.params:123}", - data_transfer=LocalProtocolConfig(protocol="local"), - ) - ), - system_params=["--param1=bbb"], - custom_deploy_data=[], - ) - param_resolver = ParamResolver(bad_ctx) - with pytest.raises(Exception, match="Unable to resolve user params"): - bad_ctx.build_dir() - - -# pylint: disable=too-many-arguments -@pytest.mark.parametrize( - "application_name, soft_lock, sys_lock, lock_dir, expected_error, expected_path", - ( - ( - "test_application", - True, - True, - Path("/tmp"), - does_not_raise(), - Path("/tmp/middleware_test_application_test_system.lock"), - ), - ( - "$$test_application$!:", - True, - True, - Path("/tmp"), - does_not_raise(), - Path("/tmp/middleware_test_application_test_system.lock"), - ), - ( - "test_application", - True, - True, - Path("unknown"), - pytest.raises( - Exception, match="Invalid directory unknown for lock files provided" - ), - None, - ), - ( - "test_application", - False, - True, - Path("/tmp"), - does_not_raise(), - Path("/tmp/middleware_test_system.lock"), - ), - ( - "test_application", - True, - False, - Path("/tmp"), - does_not_raise(), - Path("/tmp/middleware_test_application.lock"), - ), - ( - "test_application", - False, - False, - Path("/tmp"), - pytest.raises(Exception, match="No filename for lock provided"), - None, - ), - ), -) -def test_get_file_lock_path( - application_name: str, - soft_lock: bool, - sys_lock: bool, - lock_dir: Path, - expected_error: Any, - expected_path: Path, -) -> None: - """Test get_file_lock_path function.""" - with expected_error: - ctx = ExecutionContext( - app=Application(ApplicationConfig(name=application_name, lock=soft_lock)), - app_params=[], - system=load_system( - SystemConfig( - name="test_system", - lock=sys_lock, - data_transfer=LocalProtocolConfig(protocol="local"), - ) - ), - system_params=[], - custom_deploy_data=[], - ) - path = get_file_lock_path(ctx, lock_dir) - assert path == expected_path - - -def test_get_application_by_name_and_system(monkeypatch: Any) -> None: - """Test exceptional case for get_application_by_name_and_system.""" - monkeypatch.setattr( - "aiet.backend.execution.get_application", - MagicMock(return_value=[MagicMock(), MagicMock()]), - ) - - with pytest.raises( - ValueError, - match="Error during getting application test_application for the " - "system test_system", - ): - get_application_by_name_and_system("test_application", "test_system") - - -def test_get_application_and_system(monkeypatch: Any) -> None: - """Test exceptional case for get_application_and_system.""" - monkeypatch.setattr( - "aiet.backend.execution.get_system", MagicMock(return_value=None) - ) - - with pytest.raises(ValueError, match="System test_system is not found"): - get_application_and_system("test_application", "test_system") - - -def test_wait_function(monkeypatch: Any) -> None: - """Test wait function.""" - sleep_mock = MagicMock() - monkeypatch.setattr("time.sleep", sleep_mock) - wait(0.1) - 
sleep_mock.assert_called_once() - - -def test_deployment_execution_context() -> None: - """Test property 'is_deploy_needed' of the ExecutionContext.""" - ctx = ExecutionContext( - app=get_application("application_1")[0], - app_params=[], - system=get_system("System 1"), - system_params=[], - ) - assert not ctx.is_deploy_needed - deploy_data(ctx) # should be a NOP - - ctx = ExecutionContext( - app=get_application("application_1")[0], - app_params=[], - system=get_system("System 1"), - system_params=[], - custom_deploy_data=[DataPaths(Path("README.md"), ".")], - ) - assert ctx.is_deploy_needed - - ctx = ExecutionContext( - app=get_application("application_1")[0], - app_params=[], - system=None, - system_params=[], - ) - assert not ctx.is_deploy_needed - with pytest.raises(AssertionError): - deploy_data(ctx) - - ctx = ExecutionContext( - app=get_tool("tool_1")[0], - app_params=[], - system=None, - system_params=[], - ) - assert not ctx.is_deploy_needed - deploy_data(ctx) # should be a NOP - - -@pytest.mark.parametrize( - ["tool_name", "system_name", "exception"], - [ - ("vela", "Corstone-300: Cortex-M55+Ethos-U65", None), - ("unknown tool", "Corstone-300: Cortex-M55+Ethos-U65", ConfigurationException), - ("vela", "unknown system", ConfigurationException), - ("vela", None, ConfigurationException), - ], -) -def test_get_tool_by_system( - tool_name: str, system_name: Optional[str], exception: Optional[Any] -) -> None: - """Test exceptions thrown by function get_tool_by_system().""" - - def test() -> None: - """Test call of get_tool_by_system().""" - tool = get_tool_by_system(tool_name, system_name) - assert tool is not None - - if exception is None: - test() - else: - with pytest.raises(exception): - test() - - -class TestExecuteCommandsLocally: - """Test execute_commands_locally() function.""" - - @pytest.mark.parametrize( - "first_command, exception, expected_output", - ( - ( - "echo 'hello'", - None, - "Running: echo 'hello'\nhello\nRunning: echo 'goodbye'\ngoodbye\n", - ), - ( - "non-existent-command", - CommandNotFound, - "Running: non-existent-command\n", - ), - ("false", CommandFailedException, "Running: false\n"), - ), - ids=( - "runs_multiple_commands", - "stops_executing_on_non_existent_command", - "stops_executing_when_command_exits_with_error_code", - ), - ) - def test_execution( - self, - first_command: str, - exception: Any, - expected_output: str, - test_resources_path: Path, - capsys: Any, - ) -> None: - """Test expected behaviour of the function.""" - commands = [first_command, "echo 'goodbye'"] - cwd = test_resources_path - if exception is None: - execute_commands_locally(commands, cwd) - else: - with pytest.raises(exception): - execute_commands_locally(commands, cwd) - - captured = capsys.readouterr() - assert captured.out == expected_output - - def test_stops_executing_on_exception( - self, monkeypatch: Any, test_resources_path: Path - ) -> None: - """Ensure commands following an error-exit-code command don't run.""" - # Mock execute_command() function - execute_command_mock = mock.MagicMock() - monkeypatch.setattr("aiet.utils.proc.execute_command", execute_command_mock) - - # Mock Command object and assign as return value to execute_command() - cmd_mock = mock.MagicMock() - execute_command_mock.return_value = cmd_mock - - # Mock the terminate_command (speed up test) - terminate_command_mock = mock.MagicMock() - monkeypatch.setattr("aiet.utils.proc.terminate_command", terminate_command_mock) - - # Mock a thrown Exception and assign to Command().exit_code - exit_code_mock = 
mock.PropertyMock(side_effect=Exception("Exception.")) - type(cmd_mock).exit_code = exit_code_mock - - with pytest.raises(Exception, match="Exception."): - execute_commands_locally( - ["command_1", "command_2"], cwd=test_resources_path - ) - - # Assert only "command_1" was executed - assert execute_command_mock.call_count == 1 - - -def test_reporter(tmpdir: Any) -> None: - """Test class 'Reporter'.""" - ctx = ExecutionContext( - app=get_application("application_4")[0], - app_params=["--app=TestApp"], - system=get_system("System 4"), - system_params=[], - ) - assert ctx.system is not None - - class MockParser(OutputParser): - """Mock implementation of an output parser.""" - - def __init__(self, metrics: Dict[str, Any]) -> None: - """Set up the MockParser.""" - super().__init__(name="test") - self.metrics = metrics - - def __call__(self, output: bytearray) -> Dict[str, Any]: - """Return mock metrics (ignoring the given output).""" - return self.metrics - - metrics = {"Metric": 123, "AnotherMetric": 456} - reporter = Reporter( - parsers=[MockParser(metrics={key: val}) for key, val in metrics.items()], - ) - reporter.parse(bytearray()) - report = reporter.report(ctx) - assert report["system"]["name"] == ctx.system.name - assert report["system"]["params"] == {} - assert report["application"]["name"] == ctx.app.name - assert report["application"]["params"] == {"--app": "TestApp"} - assert report["test"]["metrics"] == metrics - report_file = Path(tmpdir) / "report.json" - reporter.save(report, report_file) - assert report_file.is_file() diff --git a/tests/aiet/test_backend_output_parser.py b/tests/aiet/test_backend_output_parser.py deleted file mode 100644 index d659812..0000000 --- a/tests/aiet/test_backend_output_parser.py +++ /dev/null @@ -1,152 +0,0 @@ -# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. -# SPDX-License-Identifier: Apache-2.0 -"""Tests for the output parsing.""" -import base64 -import json -from typing import Any -from typing import Dict - -import pytest - -from aiet.backend.output_parser import Base64OutputParser -from aiet.backend.output_parser import OutputParser -from aiet.backend.output_parser import RegexOutputParser - - -OUTPUT_MATCH_ALL = bytearray( - """ -String1: My awesome string! -String2: STRINGS_ARE_GREAT!!! -Int: 12 -Float: 3.14 -""", - encoding="utf-8", -) - -OUTPUT_NO_MATCH = bytearray( - """ -This contains no matches... 
-Test1234567890!"£$%^&*()_+@~{}[]/.,<>?| -""", - encoding="utf-8", -) - -OUTPUT_PARTIAL_MATCH = bytearray( - "String1: My awesome string!", - encoding="utf-8", -) - -REGEX_CONFIG = { - "FirstString": {"pattern": r"String1.*: (.*)", "type": "str"}, - "SecondString": {"pattern": r"String2.*: (.*)!!!", "type": "str"}, - "IntegerValue": {"pattern": r"Int.*: (.*)", "type": "int"}, - "FloatValue": {"pattern": r"Float.*: (.*)", "type": "float"}, -} - -EMPTY_REGEX_CONFIG: Dict[str, Dict[str, Any]] = {} - -EXPECTED_METRICS_ALL = { - "FirstString": "My awesome string!", - "SecondString": "STRINGS_ARE_GREAT", - "IntegerValue": 12, - "FloatValue": 3.14, -} - -EXPECTED_METRICS_PARTIAL = { - "FirstString": "My awesome string!", -} - - -class TestRegexOutputParser: - """Collect tests for the RegexOutputParser.""" - - @staticmethod - @pytest.mark.parametrize( - ["output", "config", "expected_metrics"], - [ - (OUTPUT_MATCH_ALL, REGEX_CONFIG, EXPECTED_METRICS_ALL), - (OUTPUT_MATCH_ALL + OUTPUT_NO_MATCH, REGEX_CONFIG, EXPECTED_METRICS_ALL), - (OUTPUT_MATCH_ALL + OUTPUT_NO_MATCH, REGEX_CONFIG, EXPECTED_METRICS_ALL), - ( - OUTPUT_MATCH_ALL + OUTPUT_PARTIAL_MATCH, - REGEX_CONFIG, - EXPECTED_METRICS_ALL, - ), - (OUTPUT_NO_MATCH, REGEX_CONFIG, {}), - (OUTPUT_MATCH_ALL, EMPTY_REGEX_CONFIG, {}), - (bytearray(), EMPTY_REGEX_CONFIG, {}), - (bytearray(), REGEX_CONFIG, {}), - ], - ) - def test_parsing(output: bytearray, config: Dict, expected_metrics: Dict) -> None: - """ - Make sure the RegexOutputParser yields valid results. - - I.e. return an empty dict if either the input or the config is empty and - return the parsed metrics otherwise. - """ - parser = RegexOutputParser(name="Test", regex_config=config) - assert parser.name == "Test" - assert isinstance(parser, OutputParser) - res = parser(output) - assert res == expected_metrics - - @staticmethod - def test_unsupported_type() -> None: - """An unsupported type in the regex_config must raise an exception.""" - config = {"BrokenMetric": {"pattern": "(.*)", "type": "UNSUPPORTED_TYPE"}} - with pytest.raises(TypeError): - RegexOutputParser(name="Test", regex_config=config) - - @staticmethod - @pytest.mark.parametrize( - "config", - ( - {"TooManyGroups": {"pattern": r"(\w)(\d)", "type": "str"}}, - {"NoGroups": {"pattern": r"\W", "type": "str"}}, - ), - ) - def test_invalid_pattern(config: Dict) -> None: - """Exactly one capturing parenthesis is allowed in the regex pattern.""" - with pytest.raises(ValueError): - RegexOutputParser(name="Test", regex_config=config) - - -@pytest.mark.parametrize( - "expected_metrics", - [ - EXPECTED_METRICS_ALL, - EXPECTED_METRICS_PARTIAL, - ], -) -def test_base64_output_parser(expected_metrics: Dict) -> None: - """ - Make sure the Base64OutputParser yields valid results. - - I.e. return an empty dict if either the input or the config is empty and - return the parsed metrics otherwise. - """ - parser = Base64OutputParser(name="Test") - assert parser.name == "Test" - assert isinstance(parser, OutputParser) - - def create_base64_output(expected_metrics: Dict) -> bytearray: - json_str = json.dumps(expected_metrics, indent=4) - json_b64 = base64.b64encode(json_str.encode("utf-8")) - return ( - OUTPUT_MATCH_ALL # Should not be matched by the Base64OutputParser - + f"<{Base64OutputParser.TAG_NAME}>".encode("utf-8") - + bytearray(json_b64) - + f"".encode("utf-8") - + OUTPUT_NO_MATCH # Just to add some difficulty... 
- ) - - output = create_base64_output(expected_metrics) - res = parser(output) - assert len(res) == 1 - assert isinstance(res, dict) - for val in res.values(): - assert val == expected_metrics - - output = parser.filter_out_parsed_content(output) - assert output == (OUTPUT_MATCH_ALL + OUTPUT_NO_MATCH) diff --git a/tests/aiet/test_backend_protocol.py b/tests/aiet/test_backend_protocol.py deleted file mode 100644 index 2103238..0000000 --- a/tests/aiet/test_backend_protocol.py +++ /dev/null @@ -1,231 +0,0 @@ -# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. -# SPDX-License-Identifier: Apache-2.0 -# pylint: disable=no-self-use,attribute-defined-outside-init,protected-access -"""Tests for the protocol backend module.""" -from contextlib import ExitStack as does_not_raise -from pathlib import Path -from typing import Any -from unittest.mock import MagicMock - -import paramiko -import pytest - -from aiet.backend.common import ConfigurationException -from aiet.backend.config import LocalProtocolConfig -from aiet.backend.protocol import CustomSFTPClient -from aiet.backend.protocol import LocalProtocol -from aiet.backend.protocol import ProtocolFactory -from aiet.backend.protocol import SSHProtocol - - -class TestProtocolFactory: - """Test ProtocolFactory class.""" - - @pytest.mark.parametrize( - "config, expected_class, exception", - [ - ( - { - "protocol": "ssh", - "username": "user", - "password": "pass", - "hostname": "hostname", - "port": "22", - }, - SSHProtocol, - does_not_raise(), - ), - ({"protocol": "local"}, LocalProtocol, does_not_raise()), - ( - {"protocol": "something"}, - None, - pytest.raises(Exception, match="Protocol not supported"), - ), - (None, None, pytest.raises(Exception, match="No protocol config provided")), - ], - ) - def test_get_protocol( - self, config: Any, expected_class: type, exception: Any - ) -> None: - """Test get_protocol method.""" - factory = ProtocolFactory() - with exception: - protocol = factory.get_protocol(config) - assert isinstance(protocol, expected_class) - - -class TestLocalProtocol: - """Test local protocol.""" - - def test_local_protocol_run_command(self) -> None: - """Test local protocol run command.""" - config = LocalProtocolConfig(protocol="local") - protocol = LocalProtocol(config, cwd=Path("/tmp")) - ret, stdout, stderr = protocol.run("pwd") - assert ret == 0 - assert stdout.decode("utf-8").strip() == "/tmp" - assert stderr.decode("utf-8") == "" - - def test_local_protocol_run_wrong_cwd(self) -> None: - """Execution should fail if wrong working directory provided.""" - config = LocalProtocolConfig(protocol="local") - protocol = LocalProtocol(config, cwd=Path("unknown_directory")) - with pytest.raises( - ConfigurationException, match="Wrong working directory unknown_directory" - ): - protocol.run("pwd") - - -class TestSSHProtocol: - """Test SSH protocol.""" - - @pytest.fixture(autouse=True) - def setup_method(self, monkeypatch: Any) -> None: - """Set up protocol mocks.""" - self.mock_ssh_client = MagicMock(spec=paramiko.client.SSHClient) - - self.mock_ssh_channel = ( - self.mock_ssh_client.get_transport.return_value.open_session.return_value - ) - self.mock_ssh_channel.mock_add_spec(spec=paramiko.channel.Channel) - self.mock_ssh_channel.exit_status_ready.side_effect = [False, True] - self.mock_ssh_channel.recv_exit_status.return_value = True - self.mock_ssh_channel.recv_ready.side_effect = [False, True] - self.mock_ssh_channel.recv_stderr_ready.side_effect = [False, True] - - monkeypatch.setattr( - 
"aiet.backend.protocol.paramiko.client.SSHClient", - MagicMock(return_value=self.mock_ssh_client), - ) - - self.mock_sftp_client = MagicMock(spec=CustomSFTPClient) - monkeypatch.setattr( - "aiet.backend.protocol.CustomSFTPClient.from_transport", - MagicMock(return_value=self.mock_sftp_client), - ) - - ssh_config = { - "protocol": "ssh", - "username": "user", - "password": "pass", - "hostname": "hostname", - "port": "22", - } - self.protocol = SSHProtocol(ssh_config) - - def test_unable_create_ssh_client(self, monkeypatch: Any) -> None: - """Test that command should fail if unable to create ssh client instance.""" - monkeypatch.setattr( - "aiet.backend.protocol.paramiko.client.SSHClient", - MagicMock(side_effect=OSError("Error!")), - ) - - with pytest.raises(Exception, match="Couldn't connect to 'hostname:22'"): - self.protocol.run("command_example", retry=False) - - def test_ssh_protocol_run_command(self) -> None: - """Test that command run via ssh successfully.""" - self.protocol.run("command_example") - self.mock_ssh_channel.exec_command.assert_called_once() - - def test_ssh_protocol_run_command_connect_failed(self) -> None: - """Test that if connection is not possible then correct exception is raised.""" - self.mock_ssh_client.connect.side_effect = OSError("Unable to connect") - self.mock_ssh_client.close.side_effect = Exception("Error!") - - with pytest.raises(Exception, match="Couldn't connect to 'hostname:22'"): - self.protocol.run("command_example", retry=False) - - def test_ssh_protocol_run_command_bad_transport(self) -> None: - """Test that command should fail if unable to get transport.""" - self.mock_ssh_client.get_transport.return_value = None - - with pytest.raises(Exception, match="Unable to get transport"): - self.protocol.run("command_example", retry=False) - - def test_ssh_protocol_deploy_command_file( - self, test_applications_path: Path - ) -> None: - """Test that files could be deployed over ssh.""" - file_for_deploy = test_applications_path / "readme.txt" - dest = "/tmp/dest" - - self.protocol.deploy(file_for_deploy, dest) - self.mock_sftp_client.put.assert_called_once_with(str(file_for_deploy), dest) - - def test_ssh_protocol_deploy_command_unknown_file(self) -> None: - """Test that deploy will fail if file does not exist.""" - with pytest.raises(Exception, match="Deploy error: file type not supported"): - self.protocol.deploy(Path("unknown_file"), "/tmp/dest") - - def test_ssh_protocol_deploy_command_bad_transport(self) -> None: - """Test that deploy should fail if unable to get transport.""" - self.mock_ssh_client.get_transport.return_value = None - - with pytest.raises(Exception, match="Unable to get transport"): - self.protocol.deploy(Path("some_file"), "/tmp/dest") - - def test_ssh_protocol_deploy_command_directory( - self, test_resources_path: Path - ) -> None: - """Test that directory could be deployed over ssh.""" - directory_for_deploy = test_resources_path / "scripts" - dest = "/tmp/dest" - - self.protocol.deploy(directory_for_deploy, dest) - self.mock_sftp_client.put_dir.assert_called_once_with( - directory_for_deploy, dest - ) - - @pytest.mark.parametrize("establish_connection", (True, False)) - def test_ssh_protocol_close(self, establish_connection: bool) -> None: - """Test protocol close operation.""" - if establish_connection: - self.protocol.establish_connection() - self.protocol.close() - - call_count = 1 if establish_connection else 0 - assert self.mock_ssh_channel.exec_command.call_count == call_count - - def test_connection_details(self) -> None: - 
"""Test getting connection details.""" - assert self.protocol.connection_details() == ("hostname", 22) - - -class TestCustomSFTPClient: - """Test CustomSFTPClient class.""" - - @pytest.fixture(autouse=True) - def setup_method(self, monkeypatch: Any) -> None: - """Set up mocks for CustomSFTPClient instance.""" - self.mock_mkdir = MagicMock() - self.mock_put = MagicMock() - monkeypatch.setattr( - "aiet.backend.protocol.paramiko.SFTPClient.__init__", - MagicMock(return_value=None), - ) - monkeypatch.setattr( - "aiet.backend.protocol.paramiko.SFTPClient.mkdir", self.mock_mkdir - ) - monkeypatch.setattr( - "aiet.backend.protocol.paramiko.SFTPClient.put", self.mock_put - ) - - self.sftp_client = CustomSFTPClient(MagicMock()) - - def test_put_dir(self, test_systems_path: Path) -> None: - """Test deploying directory to remote host.""" - directory_for_deploy = test_systems_path / "system1" - - self.sftp_client.put_dir(directory_for_deploy, "/tmp/dest") - assert self.mock_put.call_count == 3 - assert self.mock_mkdir.call_count == 3 - - def test_mkdir(self) -> None: - """Test creating directory on remote host.""" - self.mock_mkdir.side_effect = IOError("Cannot create directory") - - self.sftp_client._mkdir("new_directory", ignore_existing=True) - - with pytest.raises(IOError, match="Cannot create directory"): - self.sftp_client._mkdir("new_directory", ignore_existing=False) diff --git a/tests/aiet/test_backend_source.py b/tests/aiet/test_backend_source.py deleted file mode 100644 index 13b2c6d..0000000 --- a/tests/aiet/test_backend_source.py +++ /dev/null @@ -1,199 +0,0 @@ -# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. -# SPDX-License-Identifier: Apache-2.0 -# pylint: disable=no-self-use -"""Tests for the source backend module.""" -from collections import Counter -from contextlib import ExitStack as does_not_raise -from pathlib import Path -from typing import Any -from unittest.mock import MagicMock -from unittest.mock import patch - -import pytest - -from aiet.backend.common import ConfigurationException -from aiet.backend.source import create_destination_and_install -from aiet.backend.source import DirectorySource -from aiet.backend.source import get_source -from aiet.backend.source import TarArchiveSource - - -def test_create_destination_and_install(test_systems_path: Path, tmpdir: Any) -> None: - """Test create_destination_and_install function.""" - system_directory = test_systems_path / "system1" - - dir_source = DirectorySource(system_directory) - resources = Path(tmpdir) - create_destination_and_install(dir_source, resources) - assert (resources / "system1").is_dir() - - -@patch("aiet.backend.source.DirectorySource.create_destination", return_value=False) -def test_create_destination_and_install_if_dest_creation_not_required( - mock_ds_create_destination: Any, tmpdir: Any -) -> None: - """Test create_destination_and_install function.""" - dir_source = DirectorySource(Path("unknown")) - resources = Path(tmpdir) - with pytest.raises(Exception): - create_destination_and_install(dir_source, resources) - - mock_ds_create_destination.assert_called_once() - - -def test_create_destination_and_install_if_installation_fails(tmpdir: Any) -> None: - """Test create_destination_and_install function if installation fails.""" - dir_source = DirectorySource(Path("unknown")) - resources = Path(tmpdir) - with pytest.raises(Exception, match="Directory .* does not exist"): - create_destination_and_install(dir_source, resources) - assert not (resources / "unknown").exists() - assert 
resources.exists() - - -def test_create_destination_and_install_if_name_is_empty() -> None: - """Test create_destination_and_install function fails if source name is empty.""" - source = MagicMock() - source.create_destination.return_value = True - source.name.return_value = None - - with pytest.raises(Exception, match="Unable to get source name"): - create_destination_and_install(source, Path("some_path")) - - source.install_into.assert_not_called() - - -@pytest.mark.parametrize( - "source_path, expected_class, expected_error", - [ - (Path("applications/application1/"), DirectorySource, does_not_raise()), - ( - Path("archives/applications/application1.tar.gz"), - TarArchiveSource, - does_not_raise(), - ), - ( - Path("doesnt/exist"), - None, - pytest.raises( - ConfigurationException, match="Unable to read .*doesnt/exist" - ), - ), - ], -) -def test_get_source( - source_path: Path, - expected_class: Any, - expected_error: Any, - test_resources_path: Path, -) -> None: - """Test get_source function.""" - with expected_error: - full_source_path = test_resources_path / source_path - source = get_source(full_source_path) - assert isinstance(source, expected_class) - - -class TestDirectorySource: - """Test DirectorySource class.""" - - @pytest.mark.parametrize( - "directory, name", - [ - (Path("/some/path/some_system"), "some_system"), - (Path("some_system"), "some_system"), - ], - ) - def test_name(self, directory: Path, name: str) -> None: - """Test getting source name.""" - assert DirectorySource(directory).name() == name - - def test_install_into(self, test_systems_path: Path, tmpdir: Any) -> None: - """Test install directory into destination.""" - system_directory = test_systems_path / "system1" - - dir_source = DirectorySource(system_directory) - with pytest.raises(Exception, match="Wrong destination .*"): - dir_source.install_into(Path("unknown_destination")) - - tmpdir_path = Path(tmpdir) - dir_source.install_into(tmpdir_path) - source_files = [f.name for f in system_directory.iterdir()] - dest_files = [f.name for f in tmpdir_path.iterdir()] - assert Counter(source_files) == Counter(dest_files) - - def test_install_into_unknown_source_directory(self, tmpdir: Any) -> None: - """Test install system from unknown directory.""" - with pytest.raises(Exception, match="Directory .* does not exist"): - DirectorySource(Path("unknown_directory")).install_into(Path(tmpdir)) - - -class TestTarArchiveSource: - """Test TarArchiveSource class.""" - - @pytest.mark.parametrize( - "archive, name", - [ - (Path("some_archive.tgz"), "some_archive"), - (Path("some_archive.tar.gz"), "some_archive"), - (Path("some_archive"), "some_archive"), - ("archives/systems/system1.tar.gz", "system1"), - ("archives/systems/system1_dir.tar.gz", "system1"), - ], - ) - def test_name(self, test_resources_path: Path, archive: Path, name: str) -> None: - """Test getting source name.""" - assert TarArchiveSource(test_resources_path / archive).name() == name - - def test_install_into(self, test_resources_path: Path, tmpdir: Any) -> None: - """Test install archive into destination.""" - system_archive = test_resources_path / "archives/systems/system1.tar.gz" - - tar_source = TarArchiveSource(system_archive) - with pytest.raises(Exception, match="Wrong destination .*"): - tar_source.install_into(Path("unknown_destination")) - - tmpdir_path = Path(tmpdir) - tar_source.install_into(tmpdir_path) - source_files = [ - "aiet-config.json.license", - "aiet-config.json", - "system_artifact", - ] - dest_files = [f.name for f in 
tmpdir_path.iterdir()] - assert Counter(source_files) == Counter(dest_files) - - def test_install_into_unknown_source_archive(self, tmpdir: Any) -> None: - """Test install unknown source archive.""" - with pytest.raises(Exception, match="File .* does not exist"): - TarArchiveSource(Path("unknown.tar.gz")).install_into(Path(tmpdir)) - - def test_install_into_unsupported_source_archive(self, tmpdir: Any) -> None: - """Test install unsupported file type.""" - plain_text_file = Path(tmpdir) / "test_file" - plain_text_file.write_text("Not a system config") - - with pytest.raises(Exception, match="Unsupported archive type .*"): - TarArchiveSource(plain_text_file).install_into(Path(tmpdir)) - - def test_lazy_property_init(self, test_resources_path: Path) -> None: - """Test that class properties initialized correctly.""" - system_archive = test_resources_path / "archives/systems/system1.tar.gz" - - tar_source = TarArchiveSource(system_archive) - assert tar_source.name() == "system1" - assert tar_source.config() is not None - assert tar_source.create_destination() - - tar_source = TarArchiveSource(system_archive) - assert tar_source.config() is not None - assert tar_source.create_destination() - assert tar_source.name() == "system1" - - def test_create_destination_property(self, test_resources_path: Path) -> None: - """Test create_destination property filled correctly for different archives.""" - system_archive1 = test_resources_path / "archives/systems/system1.tar.gz" - system_archive2 = test_resources_path / "archives/systems/system1_dir.tar.gz" - - assert TarArchiveSource(system_archive1).create_destination() - assert not TarArchiveSource(system_archive2).create_destination() diff --git a/tests/aiet/test_backend_system.py b/tests/aiet/test_backend_system.py deleted file mode 100644 index a581547..0000000 --- a/tests/aiet/test_backend_system.py +++ /dev/null @@ -1,536 +0,0 @@ -# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. 
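The DirectorySource/TarArchiveSource tests above exercise the intended install flow: get_source() resolves a path to either a DirectorySource or a TarArchiveSource, and create_destination_and_install() copies it into the resources tree. A minimal usage sketch only; the destination path is a placeholder and install_resource() is a hypothetical helper, not part of the backend:

from pathlib import Path

from aiet.backend.source import create_destination_and_install
from aiet.backend.source import get_source


def install_resource(source_path: Path, resources: Path) -> None:
    """Resolve a directory or tar archive and install it under 'resources'."""
    source = get_source(source_path)  # DirectorySource or TarArchiveSource
    # Creates '<resources>/<source.name()>' when the source requires it and
    # installs the source contents into that destination.
    create_destination_and_install(source, resources)


install_resource(Path("archives/systems/system1.tar.gz"), Path("/tmp/resources"))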
-# SPDX-License-Identifier: Apache-2.0 -"""Tests for system backend.""" -from contextlib import ExitStack as does_not_raise -from pathlib import Path -from typing import Any -from typing import Callable -from typing import Dict -from typing import List -from typing import Optional -from typing import Tuple -from unittest.mock import MagicMock - -import pytest - -from aiet.backend.common import Command -from aiet.backend.common import ConfigurationException -from aiet.backend.common import Param -from aiet.backend.common import UserParamConfig -from aiet.backend.config import LocalProtocolConfig -from aiet.backend.config import ProtocolConfig -from aiet.backend.config import SSHConfig -from aiet.backend.config import SystemConfig -from aiet.backend.controller import SystemController -from aiet.backend.controller import SystemControllerSingleInstance -from aiet.backend.protocol import LocalProtocol -from aiet.backend.protocol import SSHProtocol -from aiet.backend.protocol import SupportsClose -from aiet.backend.protocol import SupportsDeploy -from aiet.backend.system import ControlledSystem -from aiet.backend.system import get_available_systems -from aiet.backend.system import get_controller -from aiet.backend.system import get_system -from aiet.backend.system import install_system -from aiet.backend.system import load_system -from aiet.backend.system import remove_system -from aiet.backend.system import StandaloneSystem -from aiet.backend.system import System - - -def dummy_resolver( - values: Optional[Dict[str, str]] = None -) -> Callable[[str, str, List[Tuple[Optional[str], Param]]], str]: - """Return dummy parameter resolver implementation.""" - # pylint: disable=unused-argument - def resolver( - param: str, cmd: str, param_values: List[Tuple[Optional[str], Param]] - ) -> str: - """Implement dummy parameter resolver.""" - return values.get(param, "") if values else "" - - return resolver - - -def test_get_available_systems() -> None: - """Test get_available_systems mocking get_resources.""" - available_systems = get_available_systems() - assert all(isinstance(s, System) for s in available_systems) - assert len(available_systems) == 3 - assert [str(s) for s in available_systems] == ["System 1", "System 2", "System 4"] - - -def test_get_system() -> None: - """Test get_system.""" - system1 = get_system("System 1") - assert isinstance(system1, ControlledSystem) - assert system1.connectable is True - assert system1.connection_details() == ("localhost", 8021) - assert system1.name == "System 1" - - system2 = get_system("System 2") - # check that comparison with object of another type returns false - assert system1 != 42 - assert system1 != system2 - - system = get_system("Unknown system") - assert system is None - - -@pytest.mark.parametrize( - "source, call_count, exception_type", - ( - ( - "archives/systems/system1.tar.gz", - 0, - pytest.raises(Exception, match="Systems .* are already installed"), - ), - ( - "archives/systems/system3.tar.gz", - 0, - pytest.raises(Exception, match="Unable to read system definition"), - ), - ( - "systems/system1", - 0, - pytest.raises(Exception, match="Systems .* are already installed"), - ), - ( - "systems/system3", - 0, - pytest.raises(Exception, match="Unable to read system definition"), - ), - ("unknown_path", 0, pytest.raises(Exception, match="Unable to read")), - ( - "various/systems/system_with_empty_config", - 0, - pytest.raises(Exception, match="No system definition found"), - ), - ("various/systems/system_with_valid_config", 1, does_not_raise()), - 
), -) -def test_install_system( - monkeypatch: Any, - test_resources_path: Path, - source: str, - call_count: int, - exception_type: Any, -) -> None: - """Test system installation from archive.""" - mock_create_destination_and_install = MagicMock() - monkeypatch.setattr( - "aiet.backend.system.create_destination_and_install", - mock_create_destination_and_install, - ) - - with exception_type: - install_system(test_resources_path / source) - - assert mock_create_destination_and_install.call_count == call_count - - -def test_remove_system(monkeypatch: Any) -> None: - """Test system removal.""" - mock_remove_backend = MagicMock() - monkeypatch.setattr("aiet.backend.system.remove_backend", mock_remove_backend) - remove_system("some_system_dir") - mock_remove_backend.assert_called_once() - - -def test_system(monkeypatch: Any) -> None: - """Test the System class.""" - config = SystemConfig(name="System 1") - monkeypatch.setattr("aiet.backend.system.ProtocolFactory", MagicMock()) - system = System(config) - assert str(system) == "System 1" - assert system.name == "System 1" - - -def test_system_with_empty_parameter_name() -> None: - """Test that configuration fails if parameter name is empty.""" - bad_config = SystemConfig( - name="System 1", - commands={"run": ["run"]}, - user_params={"run": [{"name": "", "values": ["1", "2", "3"]}]}, - ) - with pytest.raises(Exception, match="Parameter has an empty 'name' attribute."): - System(bad_config) - - -def test_system_standalone_run() -> None: - """Test run operation for standalone system.""" - system = get_system("System 4") - assert isinstance(system, StandaloneSystem) - - with pytest.raises( - ConfigurationException, match="System .* does not support connections" - ): - system.connection_details() - - with pytest.raises( - ConfigurationException, match="System .* does not support connections" - ): - system.establish_connection() - - assert system.connectable is False - - system.run("echo 'application run'") - - -@pytest.mark.parametrize( - "system_name, expected_value", [("System 1", True), ("System 4", False)] -) -def test_system_supports_deploy(system_name: str, expected_value: bool) -> None: - """Test system property supports_deploy.""" - system = get_system(system_name) - if system is None: - pytest.fail("Unable to get system {}".format(system_name)) - assert system.supports_deploy == expected_value - - -@pytest.mark.parametrize( - "mock_protocol", - [ - MagicMock(spec=SSHProtocol), - MagicMock( - spec=SSHProtocol, - **{"close.side_effect": ValueError("Unable to close protocol")} - ), - MagicMock(spec=LocalProtocol), - ], -) -def test_system_start_and_stop(monkeypatch: Any, mock_protocol: MagicMock) -> None: - """Test system start, run commands and stop.""" - monkeypatch.setattr( - "aiet.backend.system.ProtocolFactory.get_protocol", - MagicMock(return_value=mock_protocol), - ) - - system = get_system("System 1") - if system is None: - pytest.fail("Unable to get system") - assert isinstance(system, ControlledSystem) - - with pytest.raises(Exception, match="System has not been started"): - system.stop() - - assert not system.is_running() - assert system.get_output() == ("", "") - system.start(["sleep 10"], False) - assert system.is_running() - system.stop(wait=True) - assert not system.is_running() - assert system.get_output() == ("", "") - - if isinstance(mock_protocol, SupportsClose): - mock_protocol.close.assert_called_once() - - if isinstance(mock_protocol, SSHProtocol): - system.establish_connection() - - -def 
test_system_start_no_config_location() -> None: - """Test that system without config location could not start.""" - system = load_system( - SystemConfig( - name="test", - data_transfer=SSHConfig( - protocol="ssh", - username="user", - password="user", - hostname="localhost", - port="123", - ), - ) - ) - - assert isinstance(system, ControlledSystem) - with pytest.raises( - ConfigurationException, match="System test has wrong config location" - ): - system.start(["sleep 100"]) - - -@pytest.mark.parametrize( - "config, expected_class, expected_error", - [ - ( - SystemConfig( - name="test", - data_transfer=SSHConfig( - protocol="ssh", - username="user", - password="user", - hostname="localhost", - port="123", - ), - ), - ControlledSystem, - does_not_raise(), - ), - ( - SystemConfig( - name="test", data_transfer=LocalProtocolConfig(protocol="local") - ), - StandaloneSystem, - does_not_raise(), - ), - ( - SystemConfig( - name="test", - data_transfer=ProtocolConfig(protocol="cool_protocol"), # type: ignore - ), - None, - pytest.raises( - Exception, match="Unsupported execution type for protocol cool_protocol" - ), - ), - ], -) -def test_load_system( - config: SystemConfig, expected_class: type, expected_error: Any -) -> None: - """Test load_system function.""" - if not expected_class: - with expected_error: - load_system(config) - else: - system = load_system(config) - assert isinstance(system, expected_class) - - -def test_load_system_populate_shared_params() -> None: - """Test shared parameters population.""" - with pytest.raises(Exception, match="All shared parameters should have aliases"): - load_system( - SystemConfig( - name="test_system", - data_transfer=LocalProtocolConfig(protocol="local"), - user_params={ - "shared": [ - UserParamConfig( - name="--shared_param1", - description="Shared parameter", - values=["1", "2", "3"], - default_value="1", - ) - ] - }, - ) - ) - - with pytest.raises( - Exception, match="All parameters for command run should have aliases" - ): - load_system( - SystemConfig( - name="test_system", - data_transfer=LocalProtocolConfig(protocol="local"), - user_params={ - "shared": [ - UserParamConfig( - name="--shared_param1", - description="Shared parameter", - values=["1", "2", "3"], - default_value="1", - alias="shared_param1", - ) - ], - "run": [ - UserParamConfig( - name="--run_param1", - description="Run specific parameter", - values=["1", "2", "3"], - default_value="2", - ) - ], - }, - ) - ) - system0 = load_system( - SystemConfig( - name="test_system", - data_transfer=LocalProtocolConfig(protocol="local"), - commands={"run": ["run_command"]}, - user_params={ - "shared": [], - "run": [ - UserParamConfig( - name="--run_param1", - description="Run specific parameter", - values=["1", "2", "3"], - default_value="2", - alias="run_param1", - ) - ], - }, - ) - ) - assert len(system0.commands) == 1 - run_command1 = system0.commands["run"] - assert run_command1 == Command( - ["run_command"], - [ - Param( - "--run_param1", - "Run specific parameter", - ["1", "2", "3"], - "2", - "run_param1", - ) - ], - ) - - system1 = load_system( - SystemConfig( - name="test_system", - data_transfer=LocalProtocolConfig(protocol="local"), - user_params={ - "shared": [ - UserParamConfig( - name="--shared_param1", - description="Shared parameter", - values=["1", "2", "3"], - default_value="1", - alias="shared_param1", - ) - ], - "run": [ - UserParamConfig( - name="--run_param1", - description="Run specific parameter", - values=["1", "2", "3"], - default_value="2", - alias="run_param1", - ) - 
], - }, - ) - ) - assert len(system1.commands) == 2 - build_command1 = system1.commands["build"] - assert build_command1 == Command( - [], - [ - Param( - "--shared_param1", - "Shared parameter", - ["1", "2", "3"], - "1", - "shared_param1", - ) - ], - ) - - run_command1 = system1.commands["run"] - assert run_command1 == Command( - [], - [ - Param( - "--shared_param1", - "Shared parameter", - ["1", "2", "3"], - "1", - "shared_param1", - ), - Param( - "--run_param1", - "Run specific parameter", - ["1", "2", "3"], - "2", - "run_param1", - ), - ], - ) - - system2 = load_system( - SystemConfig( - name="test_system", - data_transfer=LocalProtocolConfig(protocol="local"), - commands={"build": ["build_command"]}, - user_params={ - "shared": [ - UserParamConfig( - name="--shared_param1", - description="Shared parameter", - values=["1", "2", "3"], - default_value="1", - alias="shared_param1", - ) - ], - "run": [ - UserParamConfig( - name="--run_param1", - description="Run specific parameter", - values=["1", "2", "3"], - default_value="2", - alias="run_param1", - ) - ], - }, - ) - ) - assert len(system2.commands) == 2 - build_command2 = system2.commands["build"] - assert build_command2 == Command( - ["build_command"], - [ - Param( - "--shared_param1", - "Shared parameter", - ["1", "2", "3"], - "1", - "shared_param1", - ) - ], - ) - - run_command2 = system1.commands["run"] - assert run_command2 == Command( - [], - [ - Param( - "--shared_param1", - "Shared parameter", - ["1", "2", "3"], - "1", - "shared_param1", - ), - Param( - "--run_param1", - "Run specific parameter", - ["1", "2", "3"], - "2", - "run_param1", - ), - ], - ) - - -@pytest.mark.parametrize( - "mock_protocol, expected_call_count", - [(MagicMock(spec=SupportsDeploy), 1), (MagicMock(), 0)], -) -def test_system_deploy_data( - monkeypatch: Any, mock_protocol: MagicMock, expected_call_count: int -) -> None: - """Test deploy data functionality.""" - monkeypatch.setattr( - "aiet.backend.system.ProtocolFactory.get_protocol", - MagicMock(return_value=mock_protocol), - ) - - system = ControlledSystem(SystemConfig(name="test")) - system.deploy(Path("some_file"), "some_dest") - - assert mock_protocol.deploy.call_count == expected_call_count - - -@pytest.mark.parametrize( - "single_instance, controller_class", - ((False, SystemController), (True, SystemControllerSingleInstance)), -) -def test_get_controller(single_instance: bool, controller_class: type) -> None: - """Test function get_controller.""" - controller = get_controller(single_instance) - assert isinstance(controller, controller_class) diff --git a/tests/aiet/test_backend_tool.py b/tests/aiet/test_backend_tool.py deleted file mode 100644 index fd5960d..0000000 --- a/tests/aiet/test_backend_tool.py +++ /dev/null @@ -1,60 +0,0 @@ -# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. 
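The load_system() tests above rely on shared user parameters being merged into every generated command, which is why shared parameters must define an alias. A condensed sketch of that behaviour, assuming a made-up system name:

from aiet.backend.common import UserParamConfig
from aiet.backend.config import LocalProtocolConfig
from aiet.backend.config import SystemConfig
from aiet.backend.system import load_system

system = load_system(
    SystemConfig(
        name="demo_system",  # hypothetical name, not one of the test fixtures
        data_transfer=LocalProtocolConfig(protocol="local"),
        user_params={
            "shared": [
                UserParamConfig(
                    name="--shared_param1",
                    description="Shared parameter",
                    values=["1", "2", "3"],
                    default_value="1",
                    alias="shared_param1",  # omitting the alias raises an exception
                )
            ],
            "run": [
                UserParamConfig(
                    name="--run_param1",
                    description="Run specific parameter",
                    values=["1", "2", "3"],
                    default_value="2",
                    alias="run_param1",
                )
            ],
        },
    )
)
# As asserted above: a 'build' command carrying the shared parameter and a 'run'
# command carrying both the shared and the run-specific parameter.
assert len(system.commands) == 2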
-# SPDX-License-Identifier: Apache-2.0 -# pylint: disable=no-self-use -"""Tests for the tool backend.""" -from collections import Counter - -import pytest - -from aiet.backend.common import ConfigurationException -from aiet.backend.config import ToolConfig -from aiet.backend.tool import get_available_tool_directory_names -from aiet.backend.tool import get_available_tools -from aiet.backend.tool import get_tool -from aiet.backend.tool import Tool - - -def test_get_available_tool_directory_names() -> None: - """Test get_available_tools mocking get_resources.""" - directory_names = get_available_tool_directory_names() - assert Counter(directory_names) == Counter(["tool1", "tool2", "vela"]) - - -def test_get_available_tools() -> None: - """Test get_available_tools mocking get_resources.""" - available_tools = get_available_tools() - expected_tool_names = sorted( - [ - "tool_1", - "tool_2", - "vela", - "vela", - "vela", - ] - ) - - assert all(isinstance(s, Tool) for s in available_tools) - assert all(s != 42 for s in available_tools) - assert any(s == available_tools[0] for s in available_tools) - assert len(available_tools) == len(expected_tool_names) - available_tool_names = sorted(str(s) for s in available_tools) - assert available_tool_names == expected_tool_names - - -def test_get_tool() -> None: - """Test get_tool mocking get_resoures.""" - tools = get_tool("tool_1") - assert len(tools) == 1 - tool = tools[0] - assert tool is not None - assert isinstance(tool, Tool) - assert tool.name == "tool_1" - - tools = get_tool("unknown tool") - assert not tools - - -def test_tool_creation() -> None: - """Test edge cases when creating a Tool instance.""" - with pytest.raises(ConfigurationException): - Tool(ToolConfig(name="test", commands={"test": []})) # no 'run' command diff --git a/tests/aiet/test_check_model.py b/tests/aiet/test_check_model.py deleted file mode 100644 index 4eafe59..0000000 --- a/tests/aiet/test_check_model.py +++ /dev/null @@ -1,162 +0,0 @@ -# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. 
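The check_model tests below revolve around detecting whether a TFLite file has already been compiled by Vela, presumably flagged via an 'ethos-u' custom operator code as the helper tests suggest. A short sketch of that check; the model path is a placeholder:

from pathlib import Path

from aiet.resources.tools.vela.check_model import get_model_from_file
from aiet.resources.tools.vela.check_model import is_vela_optimised

# Raises InvalidTFLiteFileError if the file is not a valid TFLite model.
model = get_model_from_file(Path("model.tflite"))
if is_vela_optimised(model):
    print("Model is already Vela-optimised")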
-# SPDX-License-Identifier: Apache-2.0 -# pylint: disable=redefined-outer-name,no-self-use -"""Module for testing check_model.py script.""" -from pathlib import Path -from typing import Any - -import pytest -from ethosu.vela.tflite.Model import Model -from ethosu.vela.tflite.OperatorCode import OperatorCode - -from aiet.cli.common import InvalidTFLiteFileError -from aiet.cli.common import ModelOptimisedException -from aiet.resources.tools.vela.check_model import check_custom_codes_for_ethosu -from aiet.resources.tools.vela.check_model import check_model -from aiet.resources.tools.vela.check_model import get_custom_codes_from_operators -from aiet.resources.tools.vela.check_model import get_model_from_file -from aiet.resources.tools.vela.check_model import get_operators_from_model -from aiet.resources.tools.vela.check_model import is_vela_optimised - - -@pytest.fixture(scope="session") -def optimised_tflite_model( - optimised_input_model_file: Path, -) -> Model: - """Return Model instance read from a Vela-optimised TFLite file.""" - return get_model_from_file(optimised_input_model_file) - - -@pytest.fixture(scope="session") -def non_optimised_tflite_model( - non_optimised_input_model_file: Path, -) -> Model: - """Return Model instance read from a Vela-optimised TFLite file.""" - return get_model_from_file(non_optimised_input_model_file) - - -class TestIsVelaOptimised: - """Test class for is_vela_optimised() function.""" - - def test_return_true_when_input_is_optimised( - self, - optimised_tflite_model: Model, - ) -> None: - """Verify True returned when input is optimised model.""" - output = is_vela_optimised(optimised_tflite_model) - - assert output is True - - def test_return_false_when_input_is_not_optimised( - self, - non_optimised_tflite_model: Model, - ) -> None: - """Verify False returned when input is non-optimised model.""" - output = is_vela_optimised(non_optimised_tflite_model) - - assert output is False - - -def test_get_operator_list_returns_correct_instances( - optimised_tflite_model: Model, -) -> None: - """Verify list of OperatorCode instances returned by get_operator_list().""" - operator_list = get_operators_from_model(optimised_tflite_model) - - assert all(isinstance(operator, OperatorCode) for operator in operator_list) - - -class TestGetCustomCodesFromOperators: - """Test the get_custom_codes_from_operators() function.""" - - def test_returns_empty_list_when_input_operators_have_no_custom_codes( - self, monkeypatch: Any - ) -> None: - """Verify function returns empty list when operators have no custom codes.""" - # Mock OperatorCode.CustomCode() function to return None - monkeypatch.setattr( - "ethosu.vela.tflite.OperatorCode.OperatorCode.CustomCode", lambda _: None - ) - - operators = [OperatorCode()] * 3 - - custom_codes = get_custom_codes_from_operators(operators) - - assert custom_codes == [] - - def test_returns_custom_codes_when_input_operators_have_custom_codes( - self, monkeypatch: Any - ) -> None: - """Verify list of bytes objects returned representing the CustomCodes.""" - # Mock OperatorCode.CustomCode() function to return a byte string - monkeypatch.setattr( - "ethosu.vela.tflite.OperatorCode.OperatorCode.CustomCode", - lambda _: b"custom-code", - ) - - operators = [OperatorCode()] * 3 - - custom_codes = get_custom_codes_from_operators(operators) - - assert custom_codes == [b"custom-code", b"custom-code", b"custom-code"] - - -@pytest.mark.parametrize( - "custom_codes, expected_output", - [ - ([b"ethos-u", b"something else"], True), - ([b"custom-code-1", 
b"custom-code-2"], False), - ], -) -def test_check_list_for_ethosu(custom_codes: list, expected_output: bool) -> None: - """Verify function detects 'ethos-u' bytes in the input list.""" - output = check_custom_codes_for_ethosu(custom_codes) - assert output is expected_output - - -class TestGetModelFromFile: - """Test the get_model_from_file() function.""" - - def test_error_raised_when_input_is_invalid_model_file( - self, - invalid_input_model_file: Path, - ) -> None: - """Verify error thrown when an invalid model file is given.""" - with pytest.raises(InvalidTFLiteFileError): - get_model_from_file(invalid_input_model_file) - - def test_model_instance_returned_when_input_is_valid_model_file( - self, - optimised_input_model_file: Path, - ) -> None: - """Verify file is read successfully and returns model instance.""" - tflite_model = get_model_from_file(optimised_input_model_file) - - assert isinstance(tflite_model, Model) - - -class TestCheckModel: - """Test the check_model() function.""" - - def test_check_model_with_non_optimised_input( - self, - non_optimised_input_model_file: Path, - ) -> None: - """Verify no error occurs for a valid input file.""" - check_model(non_optimised_input_model_file) - - def test_check_model_with_optimised_input( - self, - optimised_input_model_file: Path, - ) -> None: - """Verify that the right exception is raised with already optimised input.""" - with pytest.raises(ModelOptimisedException): - check_model(optimised_input_model_file) - - def test_check_model_with_invalid_input( - self, - invalid_input_model_file: Path, - ) -> None: - """Verify that an exception is raised with invalid input.""" - with pytest.raises(Exception): - check_model(invalid_input_model_file) diff --git a/tests/aiet/test_cli.py b/tests/aiet/test_cli.py deleted file mode 100644 index e8589fa..0000000 --- a/tests/aiet/test_cli.py +++ /dev/null @@ -1,37 +0,0 @@ -# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. -# SPDX-License-Identifier: Apache-2.0 -"""Module for testing CLI top command.""" -from typing import Any -from unittest.mock import ANY -from unittest.mock import MagicMock - -from click.testing import CliRunner - -from aiet.cli import cli - - -def test_cli(cli_runner: CliRunner) -> None: - """Test CLI top level command.""" - result = cli_runner.invoke(cli) - assert result.exit_code == 0 - assert "system" in cli.commands - assert "application" in cli.commands - - -def test_cli_version(cli_runner: CliRunner) -> None: - """Test version option.""" - result = cli_runner.invoke(cli, ["--version"]) - assert result.exit_code == 0 - assert "version" in result.output - - -def test_cli_verbose(cli_runner: CliRunner, monkeypatch: Any) -> None: - """Test verbose option.""" - with monkeypatch.context() as mock_context: - mock = MagicMock() - # params[1] is the verbose option and we need to replace the - # callback with a mock object - mock_context.setattr(cli.params[1], "callback", mock) - cli_runner.invoke(cli, ["-vvvv"]) - # 4 is the number -v called earlier - mock.assert_called_once_with(ANY, ANY, 4) diff --git a/tests/aiet/test_cli_application.py b/tests/aiet/test_cli_application.py deleted file mode 100644 index f1ccc44..0000000 --- a/tests/aiet/test_cli_application.py +++ /dev/null @@ -1,1153 +0,0 @@ -# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. 
-# SPDX-License-Identifier: Apache-2.0 -# pylint: disable=attribute-defined-outside-init,no-member,line-too-long,too-many-arguments,too-many-locals,redefined-outer-name,too-many-lines -"""Module for testing CLI application subcommand.""" -import base64 -import json -import re -import time -from contextlib import contextmanager -from contextlib import ExitStack -from pathlib import Path -from typing import Any -from typing import Generator -from typing import IO -from typing import List -from typing import Optional -from typing import TypedDict -from unittest.mock import MagicMock - -import click -import pytest -from click.testing import CliRunner -from filelock import FileLock - -from aiet.backend.application import Application -from aiet.backend.config import ApplicationConfig -from aiet.backend.config import LocalProtocolConfig -from aiet.backend.config import SSHConfig -from aiet.backend.config import SystemConfig -from aiet.backend.config import UserParamConfig -from aiet.backend.output_parser import Base64OutputParser -from aiet.backend.protocol import SSHProtocol -from aiet.backend.system import load_system -from aiet.cli.application import application_cmd -from aiet.cli.application import details_cmd -from aiet.cli.application import execute_cmd -from aiet.cli.application import install_cmd -from aiet.cli.application import list_cmd -from aiet.cli.application import parse_payload_run_config -from aiet.cli.application import remove_cmd -from aiet.cli.application import run_cmd -from aiet.cli.common import MiddlewareExitCode - - -def test_application_cmd() -> None: - """Test application commands.""" - commands = ["list", "details", "install", "remove", "execute", "run"] - assert all(command in application_cmd.commands for command in commands) - - -@pytest.mark.parametrize("format_", ["json", "cli"]) -def test_application_cmd_context(cli_runner: CliRunner, format_: str) -> None: - """Test setting command context parameters.""" - result = cli_runner.invoke(application_cmd, ["--format", format_]) - # command should fail if no subcommand provided - assert result.exit_code == 2 - - result = cli_runner.invoke(application_cmd, ["--format", format_, "list"]) - assert result.exit_code == 0 - - -@pytest.mark.parametrize( - "format_, system_name, expected_output", - [ - ( - "json", - None, - '{"type": "application", "available": ["application_1", "application_2"]}\n', - ), - ( - "json", - "system_1", - '{"type": "application", "available": ["application_1"]}\n', - ), - ("cli", None, "Available applications:\n\napplication_1\napplication_2\n"), - ("cli", "system_1", "Available applications:\n\napplication_1\n"), - ], -) -def test_list_cmd( - cli_runner: CliRunner, - monkeypatch: Any, - format_: str, - system_name: str, - expected_output: str, -) -> None: - """Test available applications commands.""" - # Mock some applications - mock_application_1 = MagicMock(spec=Application) - mock_application_1.name = "application_1" - mock_application_1.can_run_on.return_value = system_name == "system_1" - mock_application_2 = MagicMock(spec=Application) - mock_application_2.name = "application_2" - mock_application_2.can_run_on.return_value = system_name == "system_2" - - # Monkey patch the call get_available_applications - mock_available_applications = MagicMock() - mock_available_applications.return_value = [mock_application_1, mock_application_2] - - monkeypatch.setattr( - "aiet.backend.application.get_available_applications", - mock_available_applications, - ) - - obj = {"format": format_} - args = [] - 
if system_name: - list_cmd.params[0].type = click.Choice([system_name]) - args = ["--system", system_name] - result = cli_runner.invoke(list_cmd, obj=obj, args=args) - assert result.output == expected_output - - -def get_test_application() -> Application: - """Return test system details.""" - config = ApplicationConfig( - name="application", - description="test", - build_dir="", - supported_systems=[], - deploy_data=[], - user_params={}, - commands={ - "clean": ["clean"], - "build": ["build"], - "run": ["run"], - "post_run": ["post_run"], - }, - ) - - return Application(config) - - -def get_details_cmd_json_output() -> str: - """Get JSON output for details command.""" - json_output = """ -[ - { - "type": "application", - "name": "application", - "description": "test", - "supported_systems": [], - "commands": { - "clean": { - "command_strings": [ - "clean" - ], - "user_params": [] - }, - "build": { - "command_strings": [ - "build" - ], - "user_params": [] - }, - "run": { - "command_strings": [ - "run" - ], - "user_params": [] - }, - "post_run": { - "command_strings": [ - "post_run" - ], - "user_params": [] - } - } - } -]""" - return json.dumps(json.loads(json_output)) + "\n" - - -def get_details_cmd_console_output() -> str: - """Get console output for details command.""" - return ( - 'Application "application" details' - + "\nDescription: test" - + "\n\nSupported systems: " - + "\n\nclean commands:" - + "\nCommands: ['clean']" - + "\n\nbuild commands:" - + "\nCommands: ['build']" - + "\n\nrun commands:" - + "\nCommands: ['run']" - + "\n\npost_run commands:" - + "\nCommands: ['post_run']" - + "\n" - ) - - -@pytest.mark.parametrize( - "application_name,format_, expected_output", - [ - ("application", "json", get_details_cmd_json_output()), - ("application", "cli", get_details_cmd_console_output()), - ], -) -def test_details_cmd( - cli_runner: CliRunner, - monkeypatch: Any, - application_name: str, - format_: str, - expected_output: str, -) -> None: - """Test application details command.""" - monkeypatch.setattr( - "aiet.cli.application.get_application", - MagicMock(return_value=[get_test_application()]), - ) - - details_cmd.params[0].type = click.Choice(["application"]) - result = cli_runner.invoke( - details_cmd, obj={"format": format_}, args=["--name", application_name] - ) - assert result.exception is None - assert result.output == expected_output - - -def test_details_cmd_wrong_system(cli_runner: CliRunner, monkeypatch: Any) -> None: - """Test details command fails if application is not supported by the system.""" - monkeypatch.setattr( - "aiet.backend.execution.get_application", MagicMock(return_value=[]) - ) - - details_cmd.params[0].type = click.Choice(["application"]) - details_cmd.params[1].type = click.Choice(["system"]) - result = cli_runner.invoke( - details_cmd, args=["--name", "application", "--system", "system"] - ) - assert result.exit_code == 2 - assert ( - "Application 'application' doesn't support the system 'system'" in result.stdout - ) - - -def test_install_cmd(cli_runner: CliRunner, monkeypatch: Any) -> None: - """Test install application command.""" - mock_install_application = MagicMock() - monkeypatch.setattr( - "aiet.cli.application.install_application", mock_install_application - ) - - args = ["--source", "test"] - cli_runner.invoke(install_cmd, args=args) - mock_install_application.assert_called_once_with(Path("test")) - - -def test_remove_cmd(cli_runner: CliRunner, monkeypatch: Any) -> None: - """Test remove application command.""" - mock_remove_application = 
MagicMock() - monkeypatch.setattr( - "aiet.cli.application.remove_application", mock_remove_application - ) - remove_cmd.params[0].type = click.Choice(["test"]) - - args = ["--directory_name", "test"] - cli_runner.invoke(remove_cmd, args=args) - mock_remove_application.assert_called_once_with("test") - - -class ExecutionCase(TypedDict, total=False): - """Execution case.""" - - args: List[str] - lock_path: str - can_establish_connection: bool - establish_connection_delay: int - app_exit_code: int - exit_code: int - output: str - - -@pytest.mark.parametrize( - "application_config, system_config, executions", - [ - [ - ApplicationConfig( - name="test_application", - description="Test application", - supported_systems=["test_system"], - config_location=Path("wrong_location"), - commands={"build": ["echo build {application.name}"]}, - ), - SystemConfig( - name="test_system", - description="Test system", - data_transfer=LocalProtocolConfig(protocol="local"), - config_location=Path("wrong_location"), - commands={"run": ["echo run {application.name} on {system.name}"]}, - ), - [ - ExecutionCase( - args=["-c", "build"], - exit_code=MiddlewareExitCode.CONFIGURATION_ERROR, - output="Error: Application test_application has wrong config location\n", - ) - ], - ], - [ - ApplicationConfig( - name="test_application", - description="Test application", - supported_systems=["test_system"], - build_dir="build", - deploy_data=[("sample_file", "/tmp/sample_file")], - commands={"build": ["echo build {application.name}"]}, - ), - SystemConfig( - name="test_system", - description="Test system", - data_transfer=LocalProtocolConfig(protocol="local"), - commands={"run": ["echo run {application.name} on {system.name}"]}, - ), - [ - ExecutionCase( - args=["-c", "run"], - exit_code=MiddlewareExitCode.CONFIGURATION_ERROR, - output="Error: System test_system does not support data deploy\n", - ) - ], - ], - [ - ApplicationConfig( - name="test_application", - description="Test application", - supported_systems=["test_system"], - commands={"build": ["echo build {application.name}"]}, - ), - SystemConfig( - name="test_system", - description="Test system", - data_transfer=LocalProtocolConfig(protocol="local"), - commands={"run": ["echo run {application.name} on {system.name}"]}, - ), - [ - ExecutionCase( - args=["-c", "build"], - exit_code=MiddlewareExitCode.CONFIGURATION_ERROR, - output="Error: No build directory defined for the app test_application\n", - ) - ], - ], - [ - ApplicationConfig( - name="test_application", - description="Test application", - supported_systems=["new_system"], - build_dir="build", - commands={ - "build": ["echo build {application.name} with {user_params:0}"] - }, - user_params={ - "build": [ - UserParamConfig( - name="param", - description="sample parameter", - default_value="default", - values=["val1", "val2", "val3"], - ) - ] - }, - ), - SystemConfig( - name="test_system", - description="Test system", - data_transfer=LocalProtocolConfig(protocol="local"), - commands={"run": ["echo run {application.name} on {system.name}"]}, - ), - [ - ExecutionCase( - args=["-c", "build"], - exit_code=1, - output="Error: Application 'test_application' doesn't support the system 'test_system'\n", - ) - ], - ], - [ - ApplicationConfig( - name="test_application", - description="Test application", - supported_systems=["test_system"], - build_dir="build", - commands={"build": ["false"]}, - ), - SystemConfig( - name="test_system", - description="Test system", - data_transfer=LocalProtocolConfig(protocol="local"), - 
commands={"run": ["echo run {application.name} on {system.name}"]}, - ), - [ - ExecutionCase( - args=["-c", "build"], - exit_code=MiddlewareExitCode.BACKEND_ERROR, - output="""Running: false -Error: Execution failed. Please check output for the details.\n""", - ) - ], - ], - [ - ApplicationConfig( - name="test_application", - description="Test application", - supported_systems=["test_system"], - lock=True, - build_dir="build", - commands={ - "build": ["echo build {application.name} with {user_params:0}"] - }, - user_params={ - "build": [ - UserParamConfig( - name="param", - description="sample parameter", - default_value="default", - values=["val1", "val2", "val3"], - ) - ] - }, - ), - SystemConfig( - name="test_system", - description="Test system", - lock=True, - data_transfer=LocalProtocolConfig(protocol="local"), - commands={"run": ["echo run {application.name} on {system.name}"]}, - ), - [ - ExecutionCase( - args=["-c", "build"], - exit_code=MiddlewareExitCode.SUCCESS, - output="""Running: echo build test_application with param default -build test_application with param default\n""", - ), - ExecutionCase( - args=["-c", "build"], - lock_path="/tmp/middleware_test_application_test_system.lock", - exit_code=MiddlewareExitCode.CONCURRENT_ERROR, - output="Error: Another instance of the system is running\n", - ), - ExecutionCase( - args=["-c", "build", "--param=param=val3"], - exit_code=MiddlewareExitCode.SUCCESS, - output="""Running: echo build test_application with param val3 -build test_application with param val3\n""", - ), - ExecutionCase( - args=["-c", "build", "--param=param=newval"], - exit_code=1, - output="Error: Application parameter 'param=newval' not valid for command 'build'\n", - ), - ExecutionCase( - args=["-c", "some_command"], - exit_code=MiddlewareExitCode.CONFIGURATION_ERROR, - output="Error: Unsupported command some_command\n", - ), - ExecutionCase( - args=["-c", "run"], - exit_code=MiddlewareExitCode.SUCCESS, - output="""Generating commands to execute -Running: echo run test_application on test_system -run test_application on test_system\n""", - ), - ], - ], - [ - ApplicationConfig( - name="test_application", - description="Test application", - supported_systems=["test_system"], - deploy_data=[("sample_file", "/tmp/sample_file")], - commands={ - "run": [ - "echo run {application.name} with {user_params:param} on {system.name}" - ] - }, - user_params={ - "run": [ - UserParamConfig( - name="param=", - description="sample parameter", - default_value="default", - values=["val1", "val2", "val3"], - alias="param", - ) - ] - }, - ), - SystemConfig( - name="test_system", - description="Test system", - lock=True, - data_transfer=SSHConfig( - protocol="ssh", - username="username", - password="password", - hostname="localhost", - port="8022", - ), - commands={"run": ["sleep 100"]}, - ), - [ - ExecutionCase( - args=["-c", "run"], - exit_code=MiddlewareExitCode.SUCCESS, - output="""Generating commands to execute -Trying to establish connection with 'localhost:8022' - 90 retries every 15.0 seconds . -Deploying {application.config_location}/sample_file onto /tmp/sample_file -Running: echo run test_application with param=default on test_system -Shutting down sequence... -Stopping test_system... 
(It could take few seconds) -test_system stopped successfully.\n""", - ), - ExecutionCase( - args=["-c", "run"], - lock_path="/tmp/middleware_test_system.lock", - exit_code=MiddlewareExitCode.CONCURRENT_ERROR, - output="Error: Another instance of the system is running\n", - ), - ExecutionCase( - args=[ - "-c", - "run", - "--deploy={application.config_location}/sample_file:/tmp/sample_file", - ], - exit_code=0, - output="""Generating commands to execute -Trying to establish connection with 'localhost:8022' - 90 retries every 15.0 seconds . -Deploying {application.config_location}/sample_file onto /tmp/sample_file -Deploying {application.config_location}/sample_file onto /tmp/sample_file -Running: echo run test_application with param=default on test_system -Shutting down sequence... -Stopping test_system... (It could take few seconds) -test_system stopped successfully.\n""", - ), - ExecutionCase( - args=["-c", "run"], - app_exit_code=1, - exit_code=0, - output="""Generating commands to execute -Trying to establish connection with 'localhost:8022' - 90 retries every 15.0 seconds . -Deploying {application.config_location}/sample_file onto /tmp/sample_file -Running: echo run test_application with param=default on test_system -Application exited with exit code 1 -Shutting down sequence... -Stopping test_system... (It could take few seconds) -test_system stopped successfully.\n""", - ), - ExecutionCase( - args=["-c", "run"], - exit_code=MiddlewareExitCode.CONNECTION_ERROR, - can_establish_connection=False, - output="""Generating commands to execute -Trying to establish connection with 'localhost:8022' - 90 retries every 15.0 seconds .......................................................................................... -Shutting down sequence... -Stopping test_system... (It could take few seconds) -test_system stopped successfully. 
-Error: Couldn't connect to 'localhost:8022'.\n""", - ), - ExecutionCase( - args=["-c", "run", "--deploy=bad_format"], - exit_code=1, - output="Error: Invalid deploy parameter 'bad_format' for command run\n", - ), - ExecutionCase( - args=["-c", "run", "--deploy=:"], - exit_code=1, - output="Error: Invalid deploy parameter ':' for command run\n", - ), - ExecutionCase( - args=["-c", "run", "--deploy= : "], - exit_code=1, - output="Error: Invalid deploy parameter ' : ' for command run\n", - ), - ExecutionCase( - args=["-c", "run", "--deploy=some_src_file:"], - exit_code=1, - output="Error: Invalid deploy parameter 'some_src_file:' for command run\n", - ), - ExecutionCase( - args=["-c", "run", "--deploy=:some_dst_file"], - exit_code=1, - output="Error: Invalid deploy parameter ':some_dst_file' for command run\n", - ), - ExecutionCase( - args=["-c", "run", "--deploy=unknown_file:/tmp/dest"], - exit_code=1, - output="Error: Path unknown_file does not exist\n", - ), - ], - ], - [ - ApplicationConfig( - name="test_application", - description="Test application", - supported_systems=["test_system"], - commands={ - "run": [ - "echo run {application.name} with {user_params:param} on {system.name}" - ] - }, - user_params={ - "run": [ - UserParamConfig( - name="param=", - description="sample parameter", - default_value="default", - values=["val1", "val2", "val3"], - alias="param", - ) - ] - }, - ), - SystemConfig( - name="test_system", - description="Test system", - data_transfer=SSHConfig( - protocol="ssh", - username="username", - password="password", - hostname="localhost", - port="8022", - ), - commands={"run": ["echo Unable to start system"]}, - ), - [ - ExecutionCase( - args=["-c", "run"], - exit_code=4, - can_establish_connection=False, - establish_connection_delay=1, - output="""Generating commands to execute -Trying to establish connection with 'localhost:8022' - 90 retries every 15.0 seconds . - ----------- test_system execution failed ---------- -Unable to start system - - - -Shutting down sequence... -Stopping test_system... (It could take few seconds) -test_system stopped successfully. -Error: Execution failed. 
Please check output for the details.\n""", - ) - ], - ], - ], -) -def test_application_command_execution( - application_config: ApplicationConfig, - system_config: SystemConfig, - executions: List[ExecutionCase], - tmpdir: Any, - cli_runner: CliRunner, - monkeypatch: Any, -) -> None: - """Test application command execution.""" - - @contextmanager - def lock_execution(lock_path: str) -> Generator[None, None, None]: - lock = FileLock(lock_path) - lock.acquire(timeout=1) - - try: - yield - finally: - lock.release() - - def replace_vars(str_val: str) -> str: - """Replace variables.""" - application_config_location = str( - application_config["config_location"].absolute() - ) - - return str_val.replace( - "{application.config_location}", application_config_location - ) - - for execution in executions: - init_execution_test( - monkeypatch, - tmpdir, - application_config, - system_config, - can_establish_connection=execution.get("can_establish_connection", True), - establish_conection_delay=execution.get("establish_connection_delay", 0), - remote_app_exit_code=execution.get("app_exit_code", 0), - ) - - lock_path = execution.get("lock_path") - - with ExitStack() as stack: - if lock_path: - stack.enter_context(lock_execution(lock_path)) - - args = [replace_vars(arg) for arg in execution["args"]] - - result = cli_runner.invoke( - execute_cmd, - args=["-n", application_config["name"], "-s", system_config["name"]] - + args, - ) - output = replace_vars(execution["output"]) - assert result.exit_code == execution["exit_code"] - assert result.stdout == output - - -@pytest.fixture(params=[False, True], ids=["run-cli", "run-json"]) -def payload_path_or_none(request: Any, tmp_path_factory: Any) -> Optional[Path]: - """Drives tests for run command so that it executes them both to use a json file, and to use CLI.""" - if request.param: - ret: Path = tmp_path_factory.getbasetemp() / "system_config_payload_file.json" - return ret - return None - - -def write_system_payload_config( - payload_file: IO[str], - application_config: ApplicationConfig, - system_config: SystemConfig, -) -> None: - """Write a json payload file for the given test configuration.""" - payload_dict = { - "id": system_config["name"], - "arguments": { - "application": application_config["name"], - }, - } - json.dump(payload_dict, payload_file) - - -@pytest.mark.parametrize( - "application_config, system_config, executions", - [ - [ - ApplicationConfig( - name="test_application", - description="Test application", - supported_systems=["test_system"], - build_dir="build", - commands={ - "build": ["echo build {application.name} with {user_params:0}"] - }, - user_params={ - "build": [ - UserParamConfig( - name="param", - description="sample parameter", - default_value="default", - values=["val1", "val2", "val3"], - ) - ] - }, - ), - SystemConfig( - name="test_system", - description="Test system", - data_transfer=LocalProtocolConfig(protocol="local"), - commands={"run": ["echo run {application.name} on {system.name}"]}, - ), - [ - ExecutionCase( - args=[], - exit_code=MiddlewareExitCode.SUCCESS, - output="""Running: echo build test_application with param default -build test_application with param default -Generating commands to execute -Running: echo run test_application on test_system -run test_application on test_system\n""", - ) - ], - ], - [ - ApplicationConfig( - name="test_application", - description="Test application", - supported_systems=["test_system"], - commands={ - "run": [ - "echo run {application.name} with {user_params:param} on 
{system.name}" - ] - }, - user_params={ - "run": [ - UserParamConfig( - name="param=", - description="sample parameter", - default_value="default", - values=["val1", "val2", "val3"], - alias="param", - ) - ] - }, - ), - SystemConfig( - name="test_system", - description="Test system", - data_transfer=SSHConfig( - protocol="ssh", - username="username", - password="password", - hostname="localhost", - port="8022", - ), - commands={"run": ["sleep 100"]}, - ), - [ - ExecutionCase( - args=[], - exit_code=MiddlewareExitCode.SUCCESS, - output="""Generating commands to execute -Trying to establish connection with 'localhost:8022' - 90 retries every 15.0 seconds . -Running: echo run test_application with param=default on test_system -Shutting down sequence... -Stopping test_system... (It could take few seconds) -test_system stopped successfully.\n""", - ) - ], - ], - ], -) -def test_application_run( - application_config: ApplicationConfig, - system_config: SystemConfig, - executions: List[ExecutionCase], - tmpdir: Any, - cli_runner: CliRunner, - monkeypatch: Any, - payload_path_or_none: Path, -) -> None: - """Test application command execution.""" - for execution in executions: - init_execution_test(monkeypatch, tmpdir, application_config, system_config) - - if payload_path_or_none: - with open(payload_path_or_none, "w", encoding="utf-8") as payload_file: - write_system_payload_config( - payload_file, application_config, system_config - ) - - result = cli_runner.invoke( - run_cmd, - args=["--config", str(payload_path_or_none)], - ) - else: - result = cli_runner.invoke( - run_cmd, - args=["-n", application_config["name"], "-s", system_config["name"]] - + execution["args"], - ) - - assert result.stdout == execution["output"] - assert result.exit_code == execution["exit_code"] - - -@pytest.mark.parametrize( - "cmdline,error_pattern", - [ - [ - "--config {payload} -s test_system", - "when --config is set, the following parameters should not be provided", - ], - [ - "--config {payload} -n test_application", - "when --config is set, the following parameters should not be provided", - ], - [ - "--config {payload} -p mypar:3", - "when --config is set, the following parameters should not be provided", - ], - [ - "-p mypar:3", - "when --config is not set, the following parameters are required", - ], - ["-s test_system", "when --config is not set, --name is required"], - ["-n test_application", "when --config is not set, --system is required"], - ], -) -def test_application_run_invalid_param_combinations( - cmdline: str, - error_pattern: str, - cli_runner: CliRunner, - monkeypatch: Any, - tmp_path: Any, - tmpdir: Any, -) -> None: - """Test that invalid combinations arguments result in error as expected.""" - application_config = ApplicationConfig( - name="test_application", - description="Test application", - supported_systems=["test_system"], - build_dir="build", - commands={"build": ["echo build {application.name} with {user_params:0}"]}, - user_params={ - "build": [ - UserParamConfig( - name="param", - description="sample parameter", - default_value="default", - values=["val1", "val2", "val3"], - ) - ] - }, - ) - system_config = SystemConfig( - name="test_system", - description="Test system", - data_transfer=LocalProtocolConfig(protocol="local"), - commands={"run": ["echo run {application.name} on {system.name}"]}, - ) - - init_execution_test(monkeypatch, tmpdir, application_config, system_config) - - payload_file = tmp_path / "payload.json" - payload_file.write_text("dummy") - result = cli_runner.invoke( - 
run_cmd, - args=cmdline.format(payload=payload_file).split(), - ) - found = re.search(error_pattern, result.stdout) - assert found, f"Cannot find pattern: [{error_pattern}] in \n[\n{result.stdout}\n]" - - -@pytest.mark.parametrize( - "payload,expected", - [ - pytest.param( - {"arguments": {}}, - None, - marks=pytest.mark.xfail(reason="no system 'id''", strict=True), - ), - pytest.param( - {"id": "testsystem"}, - None, - marks=pytest.mark.xfail(reason="no arguments object", strict=True), - ), - ( - {"id": "testsystem", "arguments": {"application": "testapp"}}, - ("testsystem", "testapp", [], [], [], None), - ), - ( - { - "id": "testsystem", - "arguments": {"application": "testapp", "par1": "val1"}, - }, - ("testsystem", "testapp", ["par1=val1"], [], [], None), - ), - ( - { - "id": "testsystem", - "arguments": {"application": "testapp", "application/par1": "val1"}, - }, - ("testsystem", "testapp", ["par1=val1"], [], [], None), - ), - ( - { - "id": "testsystem", - "arguments": {"application": "testapp", "system/par1": "val1"}, - }, - ("testsystem", "testapp", [], ["par1=val1"], [], None), - ), - ( - { - "id": "testsystem", - "arguments": {"application": "testapp", "deploy/par1": "val1"}, - }, - ("testsystem", "testapp", [], [], ["par1"], None), - ), - ( - { - "id": "testsystem", - "arguments": { - "application": "testapp", - "appar1": "val1", - "application/appar2": "val2", - "system/syspar1": "val3", - "deploy/depploypar1": "val4", - "application/appar3": "val5", - "system/syspar2": "val6", - "deploy/depploypar2": "val7", - }, - }, - ( - "testsystem", - "testapp", - ["appar1=val1", "appar2=val2", "appar3=val5"], - ["syspar1=val3", "syspar2=val6"], - ["depploypar1", "depploypar2"], - None, - ), - ), - ], -) -def test_parse_payload_run_config(payload: dict, expected: tuple) -> None: - """Test parsing of the JSON payload for the run_config command.""" - assert parse_payload_run_config(payload) == expected - - -def test_application_run_report( - tmpdir: Any, - cli_runner: CliRunner, - monkeypatch: Any, -) -> None: - """Test flag '--report' of command 'application run'.""" - app_metrics = {"app_metric": 3.14} - app_metrics_b64 = base64.b64encode(json.dumps(app_metrics).encode("utf-8")) - application_config = ApplicationConfig( - name="test_application", - description="Test application", - supported_systems=["test_system"], - build_dir="build", - commands={"build": ["echo build {application.name} with {user_params:0}"]}, - user_params={ - "build": [ - UserParamConfig( - name="param", - description="sample parameter", - default_value="default", - values=["val1", "val2", "val3"], - ), - UserParamConfig( - name="p2", - description="another parameter, not overridden", - default_value="the-right-choice", - values=["the-right-choice", "the-bad-choice"], - ), - ] - }, - ) - system_config = SystemConfig( - name="test_system", - description="Test system", - data_transfer=LocalProtocolConfig(protocol="local"), - commands={ - "run": [ - "echo run {application.name} on {system.name}", - f"echo build <{Base64OutputParser.TAG_NAME}>{app_metrics_b64.decode('utf-8')}", - ] - }, - reporting={ - "regex": { - "app_name": { - "pattern": r"run (.\S*) ", - "type": "str", - }, - "sys_name": { - "pattern": r"on (.\S*)", - "type": "str", - }, - } - }, - ) - report_file = Path(tmpdir) / "test_report.json" - param_val = "param=val1" - exit_code = MiddlewareExitCode.SUCCESS - - init_execution_test(monkeypatch, tmpdir, application_config, system_config) - - result = cli_runner.invoke( - run_cmd, - args=[ - "-n", - 
application_config["name"], - "-s", - system_config["name"], - "--report", - str(report_file), - "--param", - param_val, - ], - ) - assert result.exit_code == exit_code - assert report_file.is_file() - with open(report_file, "r", encoding="utf-8") as file: - report = json.load(file) - - assert report == { - "application": { - "metrics": {"0": {"app_metric": 3.14}}, - "name": "test_application", - "params": {"param": "val1", "p2": "the-right-choice"}, - }, - "system": { - "metrics": {"app_name": "test_application", "sys_name": "test_system"}, - "name": "test_system", - "params": {}, - }, - } - - -def init_execution_test( - monkeypatch: Any, - tmpdir: Any, - application_config: ApplicationConfig, - system_config: SystemConfig, - can_establish_connection: bool = True, - establish_conection_delay: float = 0, - remote_app_exit_code: int = 0, -) -> None: - """Init execution test.""" - application_name = application_config["name"] - system_name = system_config["name"] - - execute_cmd.params[0].type = click.Choice([application_name]) - execute_cmd.params[1].type = click.Choice([system_name]) - execute_cmd.params[2].type = click.Choice(["build", "run", "some_command"]) - - run_cmd.params[0].type = click.Choice([application_name]) - run_cmd.params[1].type = click.Choice([system_name]) - - if "config_location" not in application_config: - application_path = Path(tmpdir) / "application" - application_path.mkdir() - application_config["config_location"] = application_path - - # this file could be used as deploy parameter value or - # as deploy parameter in application configuration - sample_file = application_path / "sample_file" - sample_file.touch() - monkeypatch.setattr( - "aiet.backend.application.get_available_applications", - MagicMock(return_value=[Application(application_config)]), - ) - - ssh_protocol_mock = MagicMock(spec=SSHProtocol) - - def mock_establish_connection() -> bool: - """Mock establish connection function.""" - # give some time for the system to start - time.sleep(establish_conection_delay) - return can_establish_connection - - ssh_protocol_mock.establish_connection.side_effect = mock_establish_connection - ssh_protocol_mock.connection_details.return_value = ("localhost", 8022) - ssh_protocol_mock.run.return_value = ( - remote_app_exit_code, - bytearray(), - bytearray(), - ) - monkeypatch.setattr( - "aiet.backend.protocol.SSHProtocol", MagicMock(return_value=ssh_protocol_mock) - ) - - if "config_location" not in system_config: - system_path = Path(tmpdir) / "system" - system_path.mkdir() - system_config["config_location"] = system_path - monkeypatch.setattr( - "aiet.backend.system.get_available_systems", - MagicMock(return_value=[load_system(system_config)]), - ) - - monkeypatch.setattr("aiet.backend.execution.wait", MagicMock()) diff --git a/tests/aiet/test_cli_common.py b/tests/aiet/test_cli_common.py deleted file mode 100644 index d018e44..0000000 --- a/tests/aiet/test_cli_common.py +++ /dev/null @@ -1,37 +0,0 @@ -# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. 
-# SPDX-License-Identifier: Apache-2.0 -"""Test for cli common module.""" -from typing import Any - -import pytest - -from aiet.cli.common import print_command_details -from aiet.cli.common import raise_exception_at_signal - - -def test_print_command_details(capsys: Any) -> None: - """Test print_command_details function.""" - command = { - "command_strings": ["echo test"], - "user_params": [ - {"name": "param_name", "description": "param_description"}, - { - "name": "param_name2", - "description": "param_description2", - "alias": "alias2", - }, - ], - } - print_command_details(command) - captured = capsys.readouterr() - assert "echo test" in captured.out - assert "param_name" in captured.out - assert "alias2" in captured.out - - -def test_raise_exception_at_signal() -> None: - """Test raise_exception_at_signal graceful shutdown.""" - with pytest.raises(Exception) as err: - raise_exception_at_signal(1, "") - - assert str(err.value) == "Middleware shutdown requested" diff --git a/tests/aiet/test_cli_system.py b/tests/aiet/test_cli_system.py deleted file mode 100644 index fd39f31..0000000 --- a/tests/aiet/test_cli_system.py +++ /dev/null @@ -1,240 +0,0 @@ -# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. -# SPDX-License-Identifier: Apache-2.0 -"""Module for testing CLI system subcommand.""" -import json -from pathlib import Path -from typing import Any -from typing import Dict -from typing import List -from typing import Optional -from typing import Union -from unittest.mock import MagicMock - -import click -import pytest -from click.testing import CliRunner - -from aiet.backend.config import SystemConfig -from aiet.backend.system import load_system -from aiet.backend.system import System -from aiet.cli.system import details_cmd -from aiet.cli.system import install_cmd -from aiet.cli.system import list_cmd -from aiet.cli.system import remove_cmd -from aiet.cli.system import system_cmd - - -def test_system_cmd() -> None: - """Test system commands.""" - commands = ["list", "details", "install", "remove"] - assert all(command in system_cmd.commands for command in commands) - - -@pytest.mark.parametrize("format_", ["json", "cli"]) -def test_system_cmd_context(cli_runner: CliRunner, format_: str) -> None: - """Test setting command context parameters.""" - result = cli_runner.invoke(system_cmd, ["--format", format_]) - # command should fail if no subcommand provided - assert result.exit_code == 2 - - result = cli_runner.invoke(system_cmd, ["--format", format_, "list"]) - assert result.exit_code == 0 - - -@pytest.mark.parametrize( - "format_,expected_output", - [ - ("json", '{"type": "system", "available": ["system1", "system2"]}\n'), - ("cli", "Available systems:\n\nsystem1\nsystem2\n"), - ], -) -def test_list_cmd_with_format( - cli_runner: CliRunner, monkeypatch: Any, format_: str, expected_output: str -) -> None: - """Test available systems command with different formats output.""" - # Mock some systems - mock_system1 = MagicMock() - mock_system1.name = "system1" - mock_system2 = MagicMock() - mock_system2.name = "system2" - - # Monkey patch the call get_available_systems - mock_available_systems = MagicMock() - mock_available_systems.return_value = [mock_system1, mock_system2] - monkeypatch.setattr("aiet.cli.system.get_available_systems", mock_available_systems) - - obj = {"format": format_} - result = cli_runner.invoke(list_cmd, obj=obj) - assert result.output == expected_output - - -def get_test_system( - annotations: Optional[Dict[str, Union[str, List[str]]]] = 
None -) -> System: - """Return test system details.""" - config = SystemConfig( - name="system", - description="test", - data_transfer={ - "protocol": "ssh", - "username": "root", - "password": "root", - "hostname": "localhost", - "port": "8022", - }, - commands={ - "clean": ["clean"], - "build": ["build"], - "run": ["run"], - "post_run": ["post_run"], - }, - annotations=annotations or {}, - ) - - return load_system(config) - - -def get_details_cmd_json_output( - annotations: Optional[Dict[str, Union[str, List[str]]]] = None -) -> str: - """Test JSON output for details command.""" - ann_str = "" - if annotations is not None: - ann_str = '"annotations":{},'.format(json.dumps(annotations)) - - json_output = ( - """ -{ - "type": "system", - "name": "system", - "description": "test", - "data_transfer_protocol": "ssh", - "commands": { - "clean": - { - "command_strings": ["clean"], - "user_params": [] - }, - "build": - { - "command_strings": ["build"], - "user_params": [] - }, - "run": - { - "command_strings": ["run"], - "user_params": [] - }, - "post_run": - { - "command_strings": ["post_run"], - "user_params": [] - } - }, -""" - + ann_str - + """ - "available_application" : [] - } -""" - ) - return json.dumps(json.loads(json_output)) + "\n" - - -def get_details_cmd_console_output( - annotations: Optional[Dict[str, Union[str, List[str]]]] = None -) -> str: - """Test console output for details command.""" - ann_str = "" - if annotations: - val_str = "".join( - "\n\t{}: {}".format(ann_name, ann_value) - for ann_name, ann_value in annotations.items() - ) - ann_str = "\nAnnotations:{}".format(val_str) - return ( - 'System "system" details' - + "\nDescription: test" - + "\nData Transfer Protocol: ssh" - + "\nAvailable Applications: " - + ann_str - + "\n\nclean commands:" - + "\nCommands: ['clean']" - + "\n\nbuild commands:" - + "\nCommands: ['build']" - + "\n\nrun commands:" - + "\nCommands: ['run']" - + "\n\npost_run commands:" - + "\nCommands: ['post_run']" - + "\n" - ) - - -@pytest.mark.parametrize( - "format_,system,expected_output", - [ - ( - "json", - get_test_system(annotations={"ann1": "annotation1", "ann2": ["a1", "a2"]}), - get_details_cmd_json_output( - annotations={"ann1": "annotation1", "ann2": ["a1", "a2"]} - ), - ), - ( - "cli", - get_test_system(annotations={"ann1": "annotation1", "ann2": ["a1", "a2"]}), - get_details_cmd_console_output( - annotations={"ann1": "annotation1", "ann2": ["a1", "a2"]} - ), - ), - ( - "json", - get_test_system(annotations={}), - get_details_cmd_json_output(annotations={}), - ), - ( - "cli", - get_test_system(annotations={}), - get_details_cmd_console_output(annotations={}), - ), - ], -) -def test_details_cmd( - cli_runner: CliRunner, - monkeypatch: Any, - format_: str, - system: System, - expected_output: str, -) -> None: - """Test details command with different formats output.""" - mock_get_system = MagicMock() - mock_get_system.return_value = system - monkeypatch.setattr("aiet.cli.system.get_system", mock_get_system) - - args = ["--name", "system"] - obj = {"format": format_} - details_cmd.params[0].type = click.Choice(["system"]) - - result = cli_runner.invoke(details_cmd, args=args, obj=obj) - assert result.output == expected_output - - -def test_install_cmd(cli_runner: CliRunner, monkeypatch: Any) -> None: - """Test install system command.""" - mock_install_system = MagicMock() - monkeypatch.setattr("aiet.cli.system.install_system", mock_install_system) - - args = ["--source", "test"] - cli_runner.invoke(install_cmd, args=args) - 
mock_install_system.assert_called_once_with(Path("test")) - - -def test_remove_cmd(cli_runner: CliRunner, monkeypatch: Any) -> None: - """Test remove system command.""" - mock_remove_system = MagicMock() - monkeypatch.setattr("aiet.cli.system.remove_system", mock_remove_system) - remove_cmd.params[0].type = click.Choice(["test"]) - - args = ["--directory_name", "test"] - cli_runner.invoke(remove_cmd, args=args) - mock_remove_system.assert_called_once_with("test") diff --git a/tests/aiet/test_cli_tool.py b/tests/aiet/test_cli_tool.py deleted file mode 100644 index 45d45c8..0000000 --- a/tests/aiet/test_cli_tool.py +++ /dev/null @@ -1,333 +0,0 @@ -# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. -# SPDX-License-Identifier: Apache-2.0 -# pylint: disable=attribute-defined-outside-init,no-member,line-too-long,too-many-arguments,too-many-locals -"""Module for testing CLI tool subcommand.""" -import json -from pathlib import Path -from typing import Any -from typing import List -from typing import Optional -from typing import Sequence -from unittest.mock import MagicMock - -import click -import pytest -from click.testing import CliRunner -from click.testing import Result - -from aiet.backend.tool import get_unique_tool_names -from aiet.backend.tool import Tool -from aiet.cli.tool import details_cmd -from aiet.cli.tool import execute_cmd -from aiet.cli.tool import list_cmd -from aiet.cli.tool import tool_cmd - - -def test_tool_cmd() -> None: - """Test tool commands.""" - commands = ["list", "details", "execute"] - assert all(command in tool_cmd.commands for command in commands) - - -@pytest.mark.parametrize("format_", ["json", "cli"]) -def test_tool_cmd_context(cli_runner: CliRunner, format_: str) -> None: - """Test setting command context parameters.""" - result = cli_runner.invoke(tool_cmd, ["--format", format_]) - # command should fail if no subcommand provided - assert result.exit_code == 2 - - result = cli_runner.invoke(tool_cmd, ["--format", format_, "list"]) - assert result.exit_code == 0 - - -@pytest.mark.parametrize( - "format_, expected_output", - [ - ( - "json", - '{"type": "tool", "available": ["tool_1", "tool_2"]}\n', - ), - ("cli", "Available tools:\n\ntool_1\ntool_2\n"), - ], -) -def test_list_cmd( - cli_runner: CliRunner, - monkeypatch: Any, - format_: str, - expected_output: str, -) -> None: - """Test available tool commands.""" - # Mock some tools - mock_tool_1 = MagicMock(spec=Tool) - mock_tool_1.name = "tool_1" - mock_tool_2 = MagicMock(spec=Tool) - mock_tool_2.name = "tool_2" - - # Monkey patch the call get_available_tools - mock_available_tools = MagicMock() - mock_available_tools.return_value = [mock_tool_1, mock_tool_2] - - monkeypatch.setattr("aiet.backend.tool.get_available_tools", mock_available_tools) - - obj = {"format": format_} - args: Sequence[str] = [] - result = cli_runner.invoke(list_cmd, obj=obj, args=args) - assert result.output == expected_output - - -def get_details_cmd_json_output() -> List[dict]: - """Get JSON output for details command.""" - json_output = [ - { - "type": "tool", - "name": "tool_1", - "description": "This is tool 1", - "supported_systems": ["System 1"], - "commands": { - "clean": {"command_strings": ["echo 'clean'"], "user_params": []}, - "build": {"command_strings": ["echo 'build'"], "user_params": []}, - "run": {"command_strings": ["echo 'run'"], "user_params": []}, - "post_run": {"command_strings": ["echo 'post_run'"], "user_params": []}, - }, - } - ] - - return json_output - - -def 
get_details_cmd_console_output() -> str: - """Get console output for details command.""" - return ( - 'Tool "tool_1" details' - "\nDescription: This is tool 1" - "\n\nSupported systems: System 1" - "\n\nclean commands:" - "\nCommands: [\"echo 'clean'\"]" - "\n\nbuild commands:" - "\nCommands: [\"echo 'build'\"]" - "\n\nrun commands:\nCommands: [\"echo 'run'\"]" - "\n\npost_run commands:" - "\nCommands: [\"echo 'post_run'\"]" - "\n" - ) - - -@pytest.mark.parametrize( - [ - "tool_name", - "format_", - "expected_success", - "expected_output", - ], - [ - ("tool_1", "json", True, get_details_cmd_json_output()), - ("tool_1", "cli", True, get_details_cmd_console_output()), - ("non-existent tool", "json", False, None), - ("non-existent tool", "cli", False, None), - ], -) -def test_details_cmd( - cli_runner: CliRunner, - tool_name: str, - format_: str, - expected_success: bool, - expected_output: str, -) -> None: - """Test tool details command.""" - details_cmd.params[0].type = click.Choice(["tool_1", "tool_2", "vela"]) - result = cli_runner.invoke( - details_cmd, obj={"format": format_}, args=["--name", tool_name] - ) - success = result.exit_code == 0 - assert success == expected_success, result.output - if expected_success: - assert result.exception is None - output = json.loads(result.output) if format_ == "json" else result.output - assert output == expected_output - - -@pytest.mark.parametrize( - "system_name", - [ - "", - "Corstone-300: Cortex-M55+Ethos-U55", - "Corstone-300: Cortex-M55+Ethos-U65", - "Corstone-310: Cortex-M85+Ethos-U55", - ], -) -def test_details_cmd_vela(cli_runner: CliRunner, system_name: str) -> None: - """Test tool details command for Vela.""" - details_cmd.params[0].type = click.Choice(get_unique_tool_names()) - details_cmd.params[1].type = click.Choice([system_name]) - args = ["--name", "vela"] - if system_name: - args += ["--system", system_name] - result = cli_runner.invoke(details_cmd, obj={"format": "json"}, args=args) - success = result.exit_code == 0 - assert success, result.output - result_json = json.loads(result.output) - assert result_json - if system_name: - assert len(result_json) == 1 - tool = result_json[0] - assert len(tool["supported_systems"]) == 1 - assert system_name == tool["supported_systems"][0] - else: # no system specified => list details for all systems - assert len(result_json) == 3 - assert all(len(tool["supported_systems"]) == 1 for tool in result_json) - - -@pytest.fixture(scope="session") -def input_model_file(non_optimised_input_model_file: Path) -> Path: - """Provide the path to a quantized dummy model file in the test_resources_path.""" - return non_optimised_input_model_file - - -def execute_vela( - cli_runner: CliRunner, - tool_name: str = "vela", - system_name: Optional[str] = None, - input_model: Optional[Path] = None, - output_model: Optional[Path] = None, - mac: Optional[int] = None, - format_: str = "cli", -) -> Result: - """Run Vela with different parameters.""" - execute_cmd.params[0].type = click.Choice(get_unique_tool_names()) - execute_cmd.params[2].type = click.Choice([system_name or "dummy_system"]) - args = ["--name", tool_name] - if system_name is not None: - args += ["--system", system_name] - if input_model is not None: - args += ["--param", "input={}".format(input_model)] - if output_model is not None: - args += ["--param", "output={}".format(output_model)] - if mac is not None: - args += ["--param", "mac={}".format(mac)] - result = cli_runner.invoke( - execute_cmd, - args=args, - obj={"format": format_}, - ) - return 
result - - -@pytest.mark.parametrize("format_", ["cli, json"]) -@pytest.mark.parametrize( - ["tool_name", "system_name", "mac", "expected_success", "expected_output"], - [ - ("vela", "System 1", 32, False, None), # system not supported - ("vela", "NON-EXISTENT SYSTEM", 128, False, None), # system does not exist - ("vela", "Corstone-300: Cortex-M55+Ethos-U55", 32, True, None), - ("NON-EXISTENT TOOL", "Corstone-300: Cortex-M55+Ethos-U55", 32, False, None), - ("vela", "Corstone-300: Cortex-M55+Ethos-U55", 64, True, None), - ("vela", "Corstone-300: Cortex-M55+Ethos-U55", 128, True, None), - ("vela", "Corstone-300: Cortex-M55+Ethos-U55", 256, True, None), - ( - "vela", - "Corstone-300: Cortex-M55+Ethos-U55", - 512, - False, - None, - ), # mac not supported - ( - "vela", - "Corstone-300: Cortex-M55+Ethos-U65", - 32, - False, - None, - ), # mac not supported - ("vela", "Corstone-300: Cortex-M55+Ethos-U65", 256, True, None), - ("vela", "Corstone-300: Cortex-M55+Ethos-U65", 512, True, None), - ( - "vela", - None, - 512, - False, - "Error: Please specify the system for tool vela.", - ), # no system specified - ( - "NON-EXISTENT TOOL", - "Corstone-300: Cortex-M55+Ethos-U65", - 512, - False, - None, - ), # tool does not exist - ("vela", "Corstone-310: Cortex-M85+Ethos-U55", 128, True, None), - ], -) -def test_vela_run( - cli_runner: CliRunner, - format_: str, - input_model_file: Path, # pylint: disable=redefined-outer-name - tool_name: str, - system_name: Optional[str], - mac: int, - expected_success: bool, - expected_output: Optional[str], - monkeypatch: pytest.MonkeyPatch, - tmp_path: Path, -) -> None: - """Test the execution of the Vela command.""" - monkeypatch.chdir(tmp_path) - - output_file = Path("vela_output.tflite") - - result = execute_vela( - cli_runner, - tool_name=tool_name, - system_name=system_name, - input_model=input_model_file, - output_model=output_file, - mac=mac, - format_=format_, - ) - - success = result.exit_code == 0 - assert success == expected_success - if success: - # Check output file - output_file = output_file.resolve() - assert output_file.is_file() - if expected_output: - assert result.output.strip() == expected_output - - -@pytest.mark.parametrize("include_input_model", [True, False]) -@pytest.mark.parametrize("include_output_model", [True, False]) -@pytest.mark.parametrize("include_mac", [True, False]) -def test_vela_run_missing_params( - cli_runner: CliRunner, - input_model_file: Path, # pylint: disable=redefined-outer-name - include_input_model: bool, - include_output_model: bool, - include_mac: bool, - monkeypatch: pytest.MonkeyPatch, - tmp_path: Path, -) -> None: - """Test the execution of the Vela command with missing user parameters.""" - monkeypatch.chdir(tmp_path) - - output_model_file = Path("output_model.tflite") - system_name = "Corstone-300: Cortex-M55+Ethos-U65" - mac = 256 - # input_model is a required parameters, but mac and output_model have default values. - expected_success = include_input_model - - result = execute_vela( - cli_runner, - tool_name="vela", - system_name=system_name, - input_model=input_model_file if include_input_model else None, - output_model=output_model_file if include_output_model else None, - mac=mac if include_mac else None, - ) - - success = result.exit_code == 0 - assert success == expected_success, ( - f"Success is {success}, but expected {expected_success}. 
" - f"Included params: [" - f"input_model={include_input_model}, " - f"output_model={include_output_model}, " - f"mac={include_mac}]" - ) diff --git a/tests/aiet/test_main.py b/tests/aiet/test_main.py deleted file mode 100644 index f2ebae2..0000000 --- a/tests/aiet/test_main.py +++ /dev/null @@ -1,16 +0,0 @@ -# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. -# SPDX-License-Identifier: Apache-2.0 -"""Module for testing AIET main.py.""" -from typing import Any -from unittest.mock import MagicMock - -from aiet import main - - -def test_main(monkeypatch: Any) -> None: - """Test main entry point function.""" - with monkeypatch.context() as mock_context: - mock = MagicMock() - mock_context.setattr(main, "cli", mock) - main.main() - mock.assert_called_once() diff --git a/tests/aiet/test_resources/application_config.json b/tests/aiet/test_resources/application_config.json deleted file mode 100644 index 2dfcfec..0000000 --- a/tests/aiet/test_resources/application_config.json +++ /dev/null @@ -1,96 +0,0 @@ -[ - { - "name": "application_1", - "description": "application number one", - "supported_systems": [ - "system_1", - "system_2" - ], - "build_dir": "build_dir_11", - "commands": { - "clean": [ - "clean_cmd_11" - ], - "build": [ - "build_cmd_11" - ], - "run": [ - "run_cmd_11" - ], - "post_run": [ - "post_run_cmd_11" - ] - }, - "user_params": { - "run": [ - { - "name": "run_param_11", - "values": [], - "description": "run param number one" - } - ], - "build": [ - { - "name": "build_param_11", - "values": [], - "description": "build param number one" - }, - { - "name": "build_param_12", - "values": [], - "description": "build param number two" - }, - { - "name": "build_param_13", - "values": [ - "value_1" - ], - "description": "build param number three with some value" - } - ] - } - }, - { - "name": "application_2", - "description": "application number two", - "supported_systems": [ - "system_2" - ], - "build_dir": "build_dir_21", - "commands": { - "clean": [ - "clean_cmd_21" - ], - "build": [ - "build_cmd_21", - "build_cmd_22" - ], - "run": [ - "run_cmd_21" - ], - "post_run": [ - "post_run_cmd_21" - ] - }, - "user_params": { - "build": [ - { - "name": "build_param_21", - "values": [], - "description": "build param number one" - }, - { - "name": "build_param_22", - "values": [], - "description": "build param number two" - }, - { - "name": "build_param_23", - "values": [], - "description": "build param number three" - } - ], - "run": [] - } - } -] diff --git a/tests/aiet/test_resources/application_config.json.license b/tests/aiet/test_resources/application_config.json.license deleted file mode 100644 index 9b83bfc..0000000 --- a/tests/aiet/test_resources/application_config.json.license +++ /dev/null @@ -1,3 +0,0 @@ -SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. 
- -SPDX-License-Identifier: Apache-2.0 diff --git a/tests/aiet/test_resources/applications/application1/aiet-config.json b/tests/aiet/test_resources/applications/application1/aiet-config.json deleted file mode 100644 index 97f0401..0000000 --- a/tests/aiet/test_resources/applications/application1/aiet-config.json +++ /dev/null @@ -1,30 +0,0 @@ -[ - { - "name": "application_1", - "description": "This is application 1", - "supported_systems": [ - { - "name": "System 1" - } - ], - "build_dir": "build", - "commands": { - "clean": [ - "echo 'clean'" - ], - "build": [ - "echo 'build'" - ], - "run": [ - "echo 'run'" - ], - "post_run": [ - "echo 'post_run'" - ] - }, - "user_params": { - "build": [], - "run": [] - } - } -] diff --git a/tests/aiet/test_resources/applications/application1/aiet-config.json.license b/tests/aiet/test_resources/applications/application1/aiet-config.json.license deleted file mode 100644 index 9b83bfc..0000000 --- a/tests/aiet/test_resources/applications/application1/aiet-config.json.license +++ /dev/null @@ -1,3 +0,0 @@ -SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. - -SPDX-License-Identifier: Apache-2.0 diff --git a/tests/aiet/test_resources/applications/application2/aiet-config.json b/tests/aiet/test_resources/applications/application2/aiet-config.json deleted file mode 100644 index e9122d3..0000000 --- a/tests/aiet/test_resources/applications/application2/aiet-config.json +++ /dev/null @@ -1,30 +0,0 @@ -[ - { - "name": "application_2", - "description": "This is application 2", - "supported_systems": [ - { - "name": "System 2" - } - ], - "build_dir": "build", - "commands": { - "clean": [ - "echo 'clean'" - ], - "build": [ - "echo 'build'" - ], - "run": [ - "echo 'run'" - ], - "post_run": [ - "echo 'post_run'" - ] - }, - "user_params": { - "build": [], - "run": [] - } - } -] diff --git a/tests/aiet/test_resources/applications/application2/aiet-config.json.license b/tests/aiet/test_resources/applications/application2/aiet-config.json.license deleted file mode 100644 index 9b83bfc..0000000 --- a/tests/aiet/test_resources/applications/application2/aiet-config.json.license +++ /dev/null @@ -1,3 +0,0 @@ -SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. - -SPDX-License-Identifier: Apache-2.0 diff --git a/tests/aiet/test_resources/applications/application3/readme.txt b/tests/aiet/test_resources/applications/application3/readme.txt deleted file mode 100644 index 8c72c05..0000000 --- a/tests/aiet/test_resources/applications/application3/readme.txt +++ /dev/null @@ -1,4 +0,0 @@ -SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. -SPDX-License-Identifier: Apache-2.0 - -This application does not have json configuration file diff --git a/tests/aiet/test_resources/applications/application4/aiet-config.json b/tests/aiet/test_resources/applications/application4/aiet-config.json deleted file mode 100644 index 34dc780..0000000 --- a/tests/aiet/test_resources/applications/application4/aiet-config.json +++ /dev/null @@ -1,35 +0,0 @@ -[ - { - "name": "application_4", - "description": "This is application 4", - "build_dir": "build", - "supported_systems": [ - { - "name": "System 4" - } - ], - "commands": { - "build": [ - "cp ../hello_app.txt . 
# {user_params:0}" - ], - "run": [ - "{application.build_dir}/hello_app.txt" - ] - }, - "user_params": { - "build": [ - { - "name": "--app", - "description": "Sample command param", - "values": [ - "application1", - "application2", - "application3" - ], - "default_value": "application1" - } - ], - "run": [] - } - } -] diff --git a/tests/aiet/test_resources/applications/application4/aiet-config.json.license b/tests/aiet/test_resources/applications/application4/aiet-config.json.license deleted file mode 100644 index 9b83bfc..0000000 --- a/tests/aiet/test_resources/applications/application4/aiet-config.json.license +++ /dev/null @@ -1,3 +0,0 @@ -SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. - -SPDX-License-Identifier: Apache-2.0 diff --git a/tests/aiet/test_resources/applications/application4/hello_app.txt b/tests/aiet/test_resources/applications/application4/hello_app.txt deleted file mode 100644 index 2ec0d1d..0000000 --- a/tests/aiet/test_resources/applications/application4/hello_app.txt +++ /dev/null @@ -1,4 +0,0 @@ -SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. -SPDX-License-Identifier: Apache-2.0 - -Hello from APP! diff --git a/tests/aiet/test_resources/applications/application5/aiet-config.json b/tests/aiet/test_resources/applications/application5/aiet-config.json deleted file mode 100644 index 5269409..0000000 --- a/tests/aiet/test_resources/applications/application5/aiet-config.json +++ /dev/null @@ -1,160 +0,0 @@ -[ - { - "name": "application_5", - "description": "This is application 5", - "build_dir": "default_build_dir", - "supported_systems": [ - { - "name": "System 1", - "lock": false - }, - { - "name": "System 2" - } - ], - "variables": { - "var1": "value1", - "var2": "value2" - }, - "lock": true, - "commands": { - "build": [ - "default build command" - ], - "run": [ - "default run command" - ] - }, - "user_params": { - "build": [], - "run": [] - } - }, - { - "name": "application_5A", - "description": "This is application 5A", - "supported_systems": [ - { - "name": "System 1", - "build_dir": "build_5A", - "variables": { - "var1": "new value1" - } - }, - { - "name": "System 2", - "variables": { - "var2": "new value2" - }, - "lock": true, - "commands": { - "run": [ - "run command on system 2" - ] - } - } - ], - "variables": { - "var1": "value1", - "var2": "value2" - }, - "build_dir": "build", - "commands": { - "build": [ - "default build command" - ], - "run": [ - "default run command" - ] - }, - "user_params": { - "build": [], - "run": [] - } - }, - { - "name": "application_5B", - "description": "This is application 5B", - "supported_systems": [ - { - "name": "System 1", - "build_dir": "build_5B", - "variables": { - "var1": "value for var1 System1", - "var2": "value for var2 System1" - }, - "user_params": { - "build": [ - { - "name": "--param_5B", - "description": "Sample command param", - "values": [ - "value1", - "value2", - "value3" - ], - "default_value": "value1", - "alias": "param1" - } - ] - } - }, - { - "name": "System 2", - "variables": { - "var1": "value for var1 System2", - "var2": "value for var2 System2" - }, - "commands": { - "build": [ - "build command on system 2 with {variables:var1} {user_params:param1}" - ], - "run": [ - "run command on system 2" - ] - }, - "user_params": { - "run": [] - } - } - ], - "build_dir": "build", - "commands": { - "build": [ - "default build command with {variables:var1}" - ], - "run": [ - "default run command with {variables:var2}" - ] - }, - "user_params": { - "build": [ - { 
- "name": "--param", - "description": "Sample command param", - "values": [ - "value1", - "value2", - "value3" - ], - "default_value": "value1", - "alias": "param1" - } - ], - "run": [], - "non_used_command": [ - { - "name": "--not-used", - "description": "Not used param anywhere", - "values": [ - "value1", - "value2", - "value3" - ], - "default_value": "value1", - "alias": "param1" - } - ] - } - } -] diff --git a/tests/aiet/test_resources/applications/application5/aiet-config.json.license b/tests/aiet/test_resources/applications/application5/aiet-config.json.license deleted file mode 100644 index 9b83bfc..0000000 --- a/tests/aiet/test_resources/applications/application5/aiet-config.json.license +++ /dev/null @@ -1,3 +0,0 @@ -SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. - -SPDX-License-Identifier: Apache-2.0 diff --git a/tests/aiet/test_resources/applications/readme.txt b/tests/aiet/test_resources/applications/readme.txt deleted file mode 100644 index a1f8209..0000000 --- a/tests/aiet/test_resources/applications/readme.txt +++ /dev/null @@ -1,4 +0,0 @@ -SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. -SPDX-License-Identifier: Apache-2.0 - -Dummy file for test purposes diff --git a/tests/aiet/test_resources/hello_world.json b/tests/aiet/test_resources/hello_world.json deleted file mode 100644 index 8a9a448..0000000 --- a/tests/aiet/test_resources/hello_world.json +++ /dev/null @@ -1,54 +0,0 @@ -[ - { - "name": "Hello world", - "description": "Dummy application that displays 'Hello world!'", - "supported_systems": [ - "Dummy System" - ], - "build_dir": "build", - "deploy_data": [ - [ - "src", - "/tmp/" - ], - [ - "README", - "/tmp/README.md" - ] - ], - "commands": { - "clean": [], - "build": [], - "run": [ - "echo 'Hello world!'", - "ls -l /tmp" - ], - "post_run": [] - }, - "user_params": { - "run": [ - { - "name": "--choice-param", - "values": [ - "dummy_value_1", - "dummy_value_2" - ], - "default_value": "dummy_value_1", - "description": "Choice param" - }, - { - "name": "--open-param", - "values": [], - "default_value": "dummy_value_4", - "description": "Open param" - }, - { - "name": "--enable-flag", - "default_value": "dummy_value_4", - "description": "Flag param" - } - ], - "build": [] - } - } -] diff --git a/tests/aiet/test_resources/hello_world.json.license b/tests/aiet/test_resources/hello_world.json.license deleted file mode 100644 index 9b83bfc..0000000 --- a/tests/aiet/test_resources/hello_world.json.license +++ /dev/null @@ -1,3 +0,0 @@ -SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. - -SPDX-License-Identifier: Apache-2.0 diff --git a/tests/aiet/test_resources/scripts/test_backend_run b/tests/aiet/test_resources/scripts/test_backend_run deleted file mode 100755 index 548f577..0000000 --- a/tests/aiet/test_resources/scripts/test_backend_run +++ /dev/null @@ -1,8 +0,0 @@ -#!/bin/bash - -# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. -# SPDX-License-Identifier: Apache-2.0 - -echo "Hello from script" ->&2 echo "Oops!" -sleep 100 diff --git a/tests/aiet/test_resources/scripts/test_backend_run_script.sh b/tests/aiet/test_resources/scripts/test_backend_run_script.sh deleted file mode 100644 index 548f577..0000000 --- a/tests/aiet/test_resources/scripts/test_backend_run_script.sh +++ /dev/null @@ -1,8 +0,0 @@ -#!/bin/bash - -# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. 
-# SPDX-License-Identifier: Apache-2.0 - -echo "Hello from script" ->&2 echo "Oops!" -sleep 100 diff --git a/tests/aiet/test_resources/systems/system1/aiet-config.json b/tests/aiet/test_resources/systems/system1/aiet-config.json deleted file mode 100644 index 4b5dd19..0000000 --- a/tests/aiet/test_resources/systems/system1/aiet-config.json +++ /dev/null @@ -1,35 +0,0 @@ -[ - { - "name": "System 1", - "description": "This is system 1", - "build_dir": "build", - "data_transfer": { - "protocol": "ssh", - "username": "root", - "password": "root", - "hostname": "localhost", - "port": "8021" - }, - "commands": { - "clean": [ - "echo 'clean'" - ], - "build": [ - "echo 'build'" - ], - "run": [ - "echo 'run'" - ], - "post_run": [ - "echo 'post_run'" - ], - "deploy": [ - "echo 'deploy'" - ] - }, - "user_params": { - "build": [], - "run": [] - } - } -] diff --git a/tests/aiet/test_resources/systems/system1/aiet-config.json.license b/tests/aiet/test_resources/systems/system1/aiet-config.json.license deleted file mode 100644 index 9b83bfc..0000000 --- a/tests/aiet/test_resources/systems/system1/aiet-config.json.license +++ /dev/null @@ -1,3 +0,0 @@ -SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. - -SPDX-License-Identifier: Apache-2.0 diff --git a/tests/aiet/test_resources/systems/system1/system_artifact/dummy.txt b/tests/aiet/test_resources/systems/system1/system_artifact/dummy.txt deleted file mode 100644 index 487e9d8..0000000 --- a/tests/aiet/test_resources/systems/system1/system_artifact/dummy.txt +++ /dev/null @@ -1,2 +0,0 @@ -SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. -SPDX-License-Identifier: Apache-2.0 diff --git a/tests/aiet/test_resources/systems/system2/aiet-config.json b/tests/aiet/test_resources/systems/system2/aiet-config.json deleted file mode 100644 index a9e0eb3..0000000 --- a/tests/aiet/test_resources/systems/system2/aiet-config.json +++ /dev/null @@ -1,32 +0,0 @@ -[ - { - "name": "System 2", - "description": "This is system 2", - "build_dir": "build", - "data_transfer": { - "protocol": "ssh", - "username": "root", - "password": "root", - "hostname": "localhost", - "port": "8021" - }, - "commands": { - "clean": [ - "echo 'clean'" - ], - "build": [ - "echo 'build'" - ], - "run": [ - "echo 'run'" - ], - "post_run": [ - "echo 'post_run'" - ] - }, - "user_params": { - "build": [], - "run": [] - } - } -] diff --git a/tests/aiet/test_resources/systems/system2/aiet-config.json.license b/tests/aiet/test_resources/systems/system2/aiet-config.json.license deleted file mode 100644 index 9b83bfc..0000000 --- a/tests/aiet/test_resources/systems/system2/aiet-config.json.license +++ /dev/null @@ -1,3 +0,0 @@ -SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. - -SPDX-License-Identifier: Apache-2.0 diff --git a/tests/aiet/test_resources/systems/system3/readme.txt b/tests/aiet/test_resources/systems/system3/readme.txt deleted file mode 100644 index aba5a9c..0000000 --- a/tests/aiet/test_resources/systems/system3/readme.txt +++ /dev/null @@ -1,4 +0,0 @@ -SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. 
-SPDX-License-Identifier: Apache-2.0 - -This system does not have the json configuration file diff --git a/tests/aiet/test_resources/systems/system4/aiet-config.json b/tests/aiet/test_resources/systems/system4/aiet-config.json deleted file mode 100644 index 295e00f..0000000 --- a/tests/aiet/test_resources/systems/system4/aiet-config.json +++ /dev/null @@ -1,19 +0,0 @@ -[ - { - "name": "System 4", - "description": "This is system 4", - "build_dir": "build", - "data_transfer": { - "protocol": "local" - }, - "commands": { - "run": [ - "echo {application.name}", - "cat {application.commands.run:0}" - ] - }, - "user_params": { - "run": [] - } - } -] diff --git a/tests/aiet/test_resources/systems/system4/aiet-config.json.license b/tests/aiet/test_resources/systems/system4/aiet-config.json.license deleted file mode 100644 index 9b83bfc..0000000 --- a/tests/aiet/test_resources/systems/system4/aiet-config.json.license +++ /dev/null @@ -1,3 +0,0 @@ -SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. - -SPDX-License-Identifier: Apache-2.0 diff --git a/tests/aiet/test_resources/tools/tool1/aiet-config.json b/tests/aiet/test_resources/tools/tool1/aiet-config.json deleted file mode 100644 index 067ef7e..0000000 --- a/tests/aiet/test_resources/tools/tool1/aiet-config.json +++ /dev/null @@ -1,30 +0,0 @@ -[ - { - "name": "tool_1", - "description": "This is tool 1", - "build_dir": "build", - "supported_systems": [ - { - "name": "System 1" - } - ], - "commands": { - "clean": [ - "echo 'clean'" - ], - "build": [ - "echo 'build'" - ], - "run": [ - "echo 'run'" - ], - "post_run": [ - "echo 'post_run'" - ] - }, - "user_params": { - "build": [], - "run": [] - } - } -] diff --git a/tests/aiet/test_resources/tools/tool1/aiet-config.json.license b/tests/aiet/test_resources/tools/tool1/aiet-config.json.license deleted file mode 100644 index 9b83bfc..0000000 --- a/tests/aiet/test_resources/tools/tool1/aiet-config.json.license +++ /dev/null @@ -1,3 +0,0 @@ -SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. - -SPDX-License-Identifier: Apache-2.0 diff --git a/tests/aiet/test_resources/tools/tool2/aiet-config.json b/tests/aiet/test_resources/tools/tool2/aiet-config.json deleted file mode 100644 index 6eee9a6..0000000 --- a/tests/aiet/test_resources/tools/tool2/aiet-config.json +++ /dev/null @@ -1,26 +0,0 @@ -[ - { - "name": "tool_2", - "description": "This is tool 2 with no supported systems", - "build_dir": "build", - "supported_systems": [], - "commands": { - "clean": [ - "echo 'clean'" - ], - "build": [ - "echo 'build'" - ], - "run": [ - "echo 'run'" - ], - "post_run": [ - "echo 'post_run'" - ] - }, - "user_params": { - "build": [], - "run": [] - } - } -] diff --git a/tests/aiet/test_resources/tools/tool2/aiet-config.json.license b/tests/aiet/test_resources/tools/tool2/aiet-config.json.license deleted file mode 100644 index 9b83bfc..0000000 --- a/tests/aiet/test_resources/tools/tool2/aiet-config.json.license +++ /dev/null @@ -1,3 +0,0 @@ -SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. 
- -SPDX-License-Identifier: Apache-2.0 diff --git a/tests/aiet/test_resources/various/applications/application_with_empty_config/aiet-config.json b/tests/aiet/test_resources/various/applications/application_with_empty_config/aiet-config.json deleted file mode 100644 index fe51488..0000000 --- a/tests/aiet/test_resources/various/applications/application_with_empty_config/aiet-config.json +++ /dev/null @@ -1 +0,0 @@ -[] diff --git a/tests/aiet/test_resources/various/applications/application_with_empty_config/aiet-config.json.license b/tests/aiet/test_resources/various/applications/application_with_empty_config/aiet-config.json.license deleted file mode 100644 index 9b83bfc..0000000 --- a/tests/aiet/test_resources/various/applications/application_with_empty_config/aiet-config.json.license +++ /dev/null @@ -1,3 +0,0 @@ -SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. - -SPDX-License-Identifier: Apache-2.0 diff --git a/tests/aiet/test_resources/various/applications/application_with_valid_config/aiet-config.json b/tests/aiet/test_resources/various/applications/application_with_valid_config/aiet-config.json deleted file mode 100644 index ff1cf1a..0000000 --- a/tests/aiet/test_resources/various/applications/application_with_valid_config/aiet-config.json +++ /dev/null @@ -1,35 +0,0 @@ -[ - { - "name": "test_application", - "description": "This is test_application", - "build_dir": "build", - "supported_systems": [ - { - "name": "System 4" - } - ], - "commands": { - "build": [ - "cp ../hello_app.txt ." - ], - "run": [ - "{application.build_dir}/hello_app.txt" - ] - }, - "user_params": { - "build": [ - { - "name": "--app", - "description": "Sample command param", - "values": [ - "application1", - "application2", - "application3" - ], - "default_value": "application1" - } - ], - "run": [] - } - } -] diff --git a/tests/aiet/test_resources/various/applications/application_with_valid_config/aiet-config.json.license b/tests/aiet/test_resources/various/applications/application_with_valid_config/aiet-config.json.license deleted file mode 100644 index 9b83bfc..0000000 --- a/tests/aiet/test_resources/various/applications/application_with_valid_config/aiet-config.json.license +++ /dev/null @@ -1,3 +0,0 @@ -SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. - -SPDX-License-Identifier: Apache-2.0 diff --git a/tests/aiet/test_resources/various/applications/application_with_wrong_config1/aiet-config.json b/tests/aiet/test_resources/various/applications/application_with_wrong_config1/aiet-config.json deleted file mode 100644 index 724b31b..0000000 --- a/tests/aiet/test_resources/various/applications/application_with_wrong_config1/aiet-config.json +++ /dev/null @@ -1,2 +0,0 @@ -This is not valid json file -{ diff --git a/tests/aiet/test_resources/various/applications/application_with_wrong_config1/aiet-config.json.license b/tests/aiet/test_resources/various/applications/application_with_wrong_config1/aiet-config.json.license deleted file mode 100644 index 9b83bfc..0000000 --- a/tests/aiet/test_resources/various/applications/application_with_wrong_config1/aiet-config.json.license +++ /dev/null @@ -1,3 +0,0 @@ -SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. 
- -SPDX-License-Identifier: Apache-2.0 diff --git a/tests/aiet/test_resources/various/applications/application_with_wrong_config2/aiet-config.json b/tests/aiet/test_resources/various/applications/application_with_wrong_config2/aiet-config.json deleted file mode 100644 index 1ebb29c..0000000 --- a/tests/aiet/test_resources/various/applications/application_with_wrong_config2/aiet-config.json +++ /dev/null @@ -1,30 +0,0 @@ -[ - { - "name": "test_application", - "description": "This is test_application", - "build_dir": "build", - "commands": { - "build": [ - "cp ../hello_app.txt ." - ], - "run": [ - "{application.build_dir}/hello_app.txt" - ] - }, - "user_params": { - "build": [ - { - "name": "--app", - "description": "Sample command param", - "values": [ - "application1", - "application2", - "application3" - ], - "default_value": "application1" - } - ], - "run": [] - } - } -] diff --git a/tests/aiet/test_resources/various/applications/application_with_wrong_config2/aiet-config.json.license b/tests/aiet/test_resources/various/applications/application_with_wrong_config2/aiet-config.json.license deleted file mode 100644 index 9b83bfc..0000000 --- a/tests/aiet/test_resources/various/applications/application_with_wrong_config2/aiet-config.json.license +++ /dev/null @@ -1,3 +0,0 @@ -SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. - -SPDX-License-Identifier: Apache-2.0 diff --git a/tests/aiet/test_resources/various/applications/application_with_wrong_config3/aiet-config.json b/tests/aiet/test_resources/various/applications/application_with_wrong_config3/aiet-config.json deleted file mode 100644 index 410d12d..0000000 --- a/tests/aiet/test_resources/various/applications/application_with_wrong_config3/aiet-config.json +++ /dev/null @@ -1,35 +0,0 @@ -[ - { - "name": "test_application", - "description": "This is test_application", - "build_dir": "build", - "supported_systems": [ - { - "anme": "System 4" - } - ], - "commands": { - "build": [ - "cp ../hello_app.txt ." - ], - "run": [ - "{application.build_dir}/hello_app.txt" - ] - }, - "user_params": { - "build": [ - { - "name": "--app", - "description": "Sample command param", - "values": [ - "application1", - "application2", - "application3" - ], - "default_value": "application1" - } - ], - "run": [] - } - } -] diff --git a/tests/aiet/test_resources/various/applications/application_with_wrong_config3/aiet-config.json.license b/tests/aiet/test_resources/various/applications/application_with_wrong_config3/aiet-config.json.license deleted file mode 100644 index 9b83bfc..0000000 --- a/tests/aiet/test_resources/various/applications/application_with_wrong_config3/aiet-config.json.license +++ /dev/null @@ -1,3 +0,0 @@ -SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. 
- -SPDX-License-Identifier: Apache-2.0 diff --git a/tests/aiet/test_resources/various/systems/system_with_empty_config/aiet-config.json b/tests/aiet/test_resources/various/systems/system_with_empty_config/aiet-config.json deleted file mode 100644 index fe51488..0000000 --- a/tests/aiet/test_resources/various/systems/system_with_empty_config/aiet-config.json +++ /dev/null @@ -1 +0,0 @@ -[] diff --git a/tests/aiet/test_resources/various/systems/system_with_empty_config/aiet-config.json.license b/tests/aiet/test_resources/various/systems/system_with_empty_config/aiet-config.json.license deleted file mode 100644 index 9b83bfc..0000000 --- a/tests/aiet/test_resources/various/systems/system_with_empty_config/aiet-config.json.license +++ /dev/null @@ -1,3 +0,0 @@ -SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. - -SPDX-License-Identifier: Apache-2.0 diff --git a/tests/aiet/test_resources/various/systems/system_with_valid_config/aiet-config.json b/tests/aiet/test_resources/various/systems/system_with_valid_config/aiet-config.json deleted file mode 100644 index 20142e9..0000000 --- a/tests/aiet/test_resources/various/systems/system_with_valid_config/aiet-config.json +++ /dev/null @@ -1,16 +0,0 @@ -[ - { - "name": "Test system", - "description": "This is a test system", - "build_dir": "build", - "data_transfer": { - "protocol": "local" - }, - "commands": { - "run": [] - }, - "user_params": { - "run": [] - } - } -] diff --git a/tests/aiet/test_resources/various/systems/system_with_valid_config/aiet-config.json.license b/tests/aiet/test_resources/various/systems/system_with_valid_config/aiet-config.json.license deleted file mode 100644 index 9b83bfc..0000000 --- a/tests/aiet/test_resources/various/systems/system_with_valid_config/aiet-config.json.license +++ /dev/null @@ -1,3 +0,0 @@ -SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. - -SPDX-License-Identifier: Apache-2.0 diff --git a/tests/aiet/test_run_vela_script.py b/tests/aiet/test_run_vela_script.py deleted file mode 100644 index 971856e..0000000 --- a/tests/aiet/test_run_vela_script.py +++ /dev/null @@ -1,152 +0,0 @@ -# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. 
-# SPDX-License-Identifier: Apache-2.0 -# pylint: disable=redefined-outer-name,no-self-use -"""Module for testing run_vela.py script.""" -from pathlib import Path -from typing import Any -from typing import List - -import pytest -from click.testing import CliRunner - -from aiet.cli.common import MiddlewareExitCode -from aiet.resources.tools.vela.check_model import get_model_from_file -from aiet.resources.tools.vela.check_model import is_vela_optimised -from aiet.resources.tools.vela.run_vela import run_vela - - -@pytest.fixture(scope="session") -def vela_config_path(test_tools_path: Path) -> Path: - """Return test systems path in a pytest fixture.""" - return test_tools_path / "vela" / "vela.ini" - - -@pytest.fixture( - params=[ - ["ethos-u65-256", "Ethos_U65_High_End", "U65_Shared_Sram"], - ["ethos-u55-32", "Ethos_U55_High_End_Embedded", "U55_Shared_Sram"], - ] -) -def ethos_config(request: Any) -> Any: - """Fixture to provide different configuration for Ethos-U optimization with Vela.""" - return request.param - - -# pylint: disable=too-many-arguments -def generate_args( - input_: Path, - output: Path, - cfg: Path, - acc_config: str, - system_config: str, - memory_mode: str, -) -> List[str]: - """Generate arguments that can be passed to script 'run_vela'.""" - return [ - "-i", - str(input_), - "-o", - str(output), - "--config", - str(cfg), - "--accelerator-config", - acc_config, - "--system-config", - system_config, - "--memory-mode", - memory_mode, - "--optimise", - "Performance", - ] - - -def check_run_vela( - cli_runner: CliRunner, args: List, expected_success: bool, output_file: Path -) -> None: - """Run Vela with the given arguments and check the result.""" - result = cli_runner.invoke(run_vela, args) - success = result.exit_code == MiddlewareExitCode.SUCCESS - assert success == expected_success - if success: - model = get_model_from_file(output_file) - assert is_vela_optimised(model) - - -def run_vela_script( - cli_runner: CliRunner, - input_model_file: Path, - output_model_file: Path, - vela_config: Path, - expected_success: bool, - acc_config: str, - system_config: str, - memory_mode: str, -) -> None: - """Run the command 'run_vela' on the command line.""" - args = generate_args( - input_model_file, - output_model_file, - vela_config, - acc_config, - system_config, - memory_mode, - ) - check_run_vela(cli_runner, args, expected_success, output_model_file) - - -class TestRunVelaCli: - """Test the command-line execution of the run_vela command.""" - - def test_non_optimised_model( - self, - cli_runner: CliRunner, - non_optimised_input_model_file: Path, - tmp_path: Path, - vela_config_path: Path, - ethos_config: List, - ) -> None: - """Verify Vela is run correctly on an unoptimised model.""" - run_vela_script( - cli_runner, - non_optimised_input_model_file, - tmp_path / "test.tflite", - vela_config_path, - True, - *ethos_config, - ) - - def test_optimised_model( - self, - cli_runner: CliRunner, - optimised_input_model_file: Path, - tmp_path: Path, - vela_config_path: Path, - ethos_config: List, - ) -> None: - """Verify Vela is run correctly on an already optimised model.""" - run_vela_script( - cli_runner, - optimised_input_model_file, - tmp_path / "test.tflite", - vela_config_path, - True, - *ethos_config, - ) - - def test_invalid_model( - self, - cli_runner: CliRunner, - invalid_input_model_file: Path, - tmp_path: Path, - vela_config_path: Path, - ethos_config: List, - ) -> None: - """Verify an error is raised when the input model is not valid.""" - run_vela_script( - cli_runner, - 
invalid_input_model_file, - tmp_path / "test.tflite", - vela_config_path, - False, - *ethos_config, - ) diff --git a/tests/aiet/test_utils_fs.py b/tests/aiet/test_utils_fs.py deleted file mode 100644 index 46d276e..0000000 --- a/tests/aiet/test_utils_fs.py +++ /dev/null @@ -1,168 +0,0 @@ -# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. -# SPDX-License-Identifier: Apache-2.0 -# pylint: disable=no-self-use -"""Module for testing fs.py.""" -from contextlib import ExitStack as does_not_raise -from pathlib import Path -from typing import Any -from typing import Union -from unittest.mock import MagicMock - -import pytest - -from aiet.utils.fs import get_resources -from aiet.utils.fs import read_file_as_bytearray -from aiet.utils.fs import read_file_as_string -from aiet.utils.fs import recreate_directory -from aiet.utils.fs import remove_directory -from aiet.utils.fs import remove_resource -from aiet.utils.fs import ResourceType -from aiet.utils.fs import valid_for_filename - - -@pytest.mark.parametrize( - "resource_name,expected_path", - [ - ("systems", does_not_raise()), - ("applications", does_not_raise()), - ("whaaat", pytest.raises(ResourceWarning)), - (None, pytest.raises(ResourceWarning)), - ], -) -def test_get_resources(resource_name: ResourceType, expected_path: Any) -> None: - """Test get_resources() with multiple parameters.""" - with expected_path: - resource_path = get_resources(resource_name) - assert resource_path.exists() - - -def test_remove_resource_wrong_directory( - monkeypatch: Any, test_applications_path: Path -) -> None: - """Test removing resource with wrong directory.""" - mock_get_resources = MagicMock(return_value=test_applications_path) - monkeypatch.setattr("aiet.utils.fs.get_resources", mock_get_resources) - - mock_shutil_rmtree = MagicMock() - monkeypatch.setattr("aiet.utils.fs.shutil.rmtree", mock_shutil_rmtree) - - with pytest.raises(Exception, match="Resource .* does not exist"): - remove_resource("unknown", "applications") - mock_shutil_rmtree.assert_not_called() - - with pytest.raises(Exception, match="Wrong resource .*"): - remove_resource("readme.txt", "applications") - mock_shutil_rmtree.assert_not_called() - - -def test_remove_resource(monkeypatch: Any, test_applications_path: Path) -> None: - """Test removing resource data.""" - mock_get_resources = MagicMock(return_value=test_applications_path) - monkeypatch.setattr("aiet.utils.fs.get_resources", mock_get_resources) - - mock_shutil_rmtree = MagicMock() - monkeypatch.setattr("aiet.utils.fs.shutil.rmtree", mock_shutil_rmtree) - - remove_resource("application1", "applications") - mock_shutil_rmtree.assert_called_once() - - -def test_remove_directory(tmpdir: Any) -> None: - """Test directory removal.""" - tmpdir_path = Path(tmpdir) - tmpfile = tmpdir_path / "temp.txt" - - for item in [None, tmpfile]: - with pytest.raises(Exception, match="No directory path provided"): - remove_directory(item) - - newdir = tmpdir_path / "newdir" - newdir.mkdir() - - assert newdir.is_dir() - remove_directory(newdir) - assert not newdir.exists() - - -def test_recreate_directory(tmpdir: Any) -> None: - """Test directory recreation.""" - with pytest.raises(Exception, match="No directory path provided"): - recreate_directory(None) - - tmpdir_path = Path(tmpdir) - tmpfile = tmpdir_path / "temp.txt" - tmpfile.touch() - with pytest.raises(Exception, match="Path .* does exist and it is not a directory"): - recreate_directory(tmpfile) - - newdir = tmpdir_path / "newdir" - newdir.mkdir() - newfile = newdir / 
"newfile" - newfile.touch() - assert list(newdir.iterdir()) == [newfile] - recreate_directory(newdir) - assert not list(newdir.iterdir()) - - newdir2 = tmpdir_path / "newdir2" - assert not newdir2.exists() - recreate_directory(newdir2) - assert newdir2.is_dir() - - -def write_to_file( - write_directory: Any, write_mode: str, write_text: Union[str, bytes] -) -> Path: - """Write some text to a temporary test file.""" - tmpdir_path = Path(write_directory) - tmpfile = tmpdir_path / "file_name.txt" - with open(tmpfile, write_mode) as file: # pylint: disable=unspecified-encoding - file.write(write_text) - return tmpfile - - -class TestReadFileAsString: - """Test read_file_as_string() function.""" - - def test_returns_text_from_valid_file(self, tmpdir: Any) -> None: - """Ensure the string written to a file read correctly.""" - file_path = write_to_file(tmpdir, "w", "hello") - assert read_file_as_string(file_path) == "hello" - - def test_output_is_empty_string_when_input_file_non_existent( - self, tmpdir: Any - ) -> None: - """Ensure empty string returned when reading from non-existent file.""" - file_path = Path(tmpdir / "non-existent.txt") - assert read_file_as_string(file_path) == "" - - -class TestReadFileAsByteArray: - """Test read_file_as_bytearray() function.""" - - def test_returns_bytes_from_valid_file(self, tmpdir: Any) -> None: - """Ensure the bytes written to a file read correctly.""" - file_path = write_to_file(tmpdir, "wb", b"hello bytes") - assert read_file_as_bytearray(file_path) == b"hello bytes" - - def test_output_is_empty_bytearray_when_input_file_non_existent( - self, tmpdir: Any - ) -> None: - """Ensure empty bytearray returned when reading from non-existent file.""" - file_path = Path(tmpdir / "non-existent.txt") - assert read_file_as_bytearray(file_path) == bytearray() - - -@pytest.mark.parametrize( - "value, replacement, expected_result", - [ - ["", "", ""], - ["123", "", "123"], - ["123", "_", "123"], - ["/some_folder/some_script.sh", "", "some_foldersome_script.sh"], - ["/some_folder/some_script.sh", "_", "_some_folder_some_script.sh"], - ["!;'some_name$%^!", "_", "___some_name____"], - ], -) -def test_valid_for_filename(value: str, replacement: str, expected_result: str) -> None: - """Test function valid_for_filename.""" - assert valid_for_filename(value, replacement) == expected_result diff --git a/tests/aiet/test_utils_helpers.py b/tests/aiet/test_utils_helpers.py deleted file mode 100644 index bbe03fc..0000000 --- a/tests/aiet/test_utils_helpers.py +++ /dev/null @@ -1,27 +0,0 @@ -# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. 
-# SPDX-License-Identifier: Apache-2.0 -"""Module for testing helpers.py.""" -import logging -from typing import Any -from typing import List -from unittest.mock import call -from unittest.mock import MagicMock - -import pytest - -from aiet.utils.helpers import set_verbosity - - -@pytest.mark.parametrize( - "verbosity,expected_calls", - [(0, []), (1, [call(logging.INFO)]), (2, [call(logging.DEBUG)])], -) -def test_set_verbosity( - verbosity: int, expected_calls: List[Any], monkeypatch: Any -) -> None: - """Test set_verbosity() with different verbsosity levels.""" - with monkeypatch.context() as mock_context: - logging_mock = MagicMock() - mock_context.setattr(logging.getLogger(), "setLevel", logging_mock) - set_verbosity(None, None, verbosity) - logging_mock.assert_has_calls(expected_calls) diff --git a/tests/aiet/test_utils_proc.py b/tests/aiet/test_utils_proc.py deleted file mode 100644 index 9fb48dd..0000000 --- a/tests/aiet/test_utils_proc.py +++ /dev/null @@ -1,272 +0,0 @@ -# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. -# SPDX-License-Identifier: Apache-2.0 -# pylint: disable=attribute-defined-outside-init,no-self-use,not-callable -"""Pytests for testing aiet/utils/proc.py.""" -from pathlib import Path -from typing import Any -from unittest import mock - -import psutil -import pytest -from sh import ErrorReturnCode - -from aiet.utils.proc import Command -from aiet.utils.proc import CommandFailedException -from aiet.utils.proc import CommandNotFound -from aiet.utils.proc import parse_command -from aiet.utils.proc import print_command_stdout -from aiet.utils.proc import run_and_wait -from aiet.utils.proc import save_process_info -from aiet.utils.proc import ShellCommand -from aiet.utils.proc import terminate_command -from aiet.utils.proc import terminate_external_process - - -class TestShellCommand: - """Sample class for collecting tests.""" - - def test_shellcommand_default_value(self) -> None: - """Test the instantiation of the class ShellCommand with no parameter.""" - shell_command = ShellCommand() - assert shell_command.base_log_path == "/tmp" - - @pytest.mark.parametrize( - "base_log_path,expected", [("/test", "/test"), ("/asd", "/asd")] - ) - def test_shellcommand_with_param(self, base_log_path: str, expected: str) -> None: - """Test init ShellCommand with different parameters.""" - shell_command = ShellCommand(base_log_path) - assert shell_command.base_log_path == expected - - def test_run_ls(self, monkeypatch: Any) -> None: - """Test a simple ls command.""" - mock_command = mock.MagicMock() - monkeypatch.setattr(Command, "bake", mock_command) - - mock_get_stdout_stderr_paths = mock.MagicMock() - mock_get_stdout_stderr_paths.return_value = ("/tmp/std.out", "/tmp/std.err") - monkeypatch.setattr( - ShellCommand, "get_stdout_stderr_paths", mock_get_stdout_stderr_paths - ) - - shell_command = ShellCommand() - shell_command.run("ls", "-l") - assert mock_command.mock_calls[0] == mock.call(("-l",)) - assert mock_command.mock_calls[1] == mock.call()( - _bg=True, _err="/tmp/std.err", _out="/tmp/std.out", _tee=True, _bg_exc=False - ) - - def test_run_command_not_found(self) -> None: - """Test whe the command doesn't exist.""" - shell_command = ShellCommand() - with pytest.raises(CommandNotFound): - shell_command.run("lsl", "-l") - - def test_get_stdout_stderr_paths_valid_path(self) -> None: - """Test the method to get files to store stdout and stderr.""" - valid_path = "/tmp" - shell_command = ShellCommand(valid_path) - out, err = 
shell_command.get_stdout_stderr_paths(valid_path, "cmd") - assert out.exists() and out.is_file() - assert err.exists() and err.is_file() - assert "cmd" in out.name - assert "cmd" in err.name - - def test_get_stdout_stderr_paths_not_invalid_path(self) -> None: - """Test the method to get output files with an invalid path.""" - invalid_path = "/invalid/foo/bar" - shell_command = ShellCommand(invalid_path) - with pytest.raises(FileNotFoundError): - shell_command.get_stdout_stderr_paths(invalid_path, "cmd") - - -@mock.patch("builtins.print") -def test_print_command_stdout_alive(mock_print: Any) -> None: - """Test the print command stdout with an alive (running) process.""" - mock_command = mock.MagicMock() - mock_command.is_alive.return_value = True - mock_command.next.side_effect = ["test1", "test2", StopIteration] - - print_command_stdout(mock_command) - - mock_command.assert_has_calls( - [mock.call.is_alive(), mock.call.next(), mock.call.next()] - ) - mock_print.assert_has_calls( - [mock.call("test1", end=""), mock.call("test2", end="")] - ) - - -@mock.patch("builtins.print") -def test_print_command_stdout_not_alive(mock_print: Any) -> None: - """Test the print command stdout with a not alive (exited) process.""" - mock_command = mock.MagicMock() - mock_command.is_alive.return_value = False - mock_command.stdout = "test" - - print_command_stdout(mock_command) - mock_command.assert_has_calls([mock.call.is_alive()]) - mock_print.assert_called_once_with("test") - - -def test_terminate_external_process_no_process(capsys: Any) -> None: - """Test that non existed process could be terminated.""" - mock_command = mock.MagicMock() - mock_command.terminate.side_effect = psutil.Error("Error!") - - terminate_external_process(mock_command) - captured = capsys.readouterr() - assert captured.out == "Unable to terminate process\n" - - -def test_terminate_external_process_case1() -> None: - """Test when process terminated immediately.""" - mock_command = mock.MagicMock() - mock_command.is_running.return_value = False - - terminate_external_process(mock_command) - mock_command.terminate.assert_called_once() - mock_command.is_running.assert_called_once() - - -def test_terminate_external_process_case2() -> None: - """Test when process termination takes time.""" - mock_command = mock.MagicMock() - mock_command.is_running.side_effect = [True, True, False] - - terminate_external_process(mock_command) - mock_command.terminate.assert_called_once() - assert mock_command.is_running.call_count == 3 - - -def test_terminate_external_process_case3() -> None: - """Test when process termination takes more time.""" - mock_command = mock.MagicMock() - mock_command.is_running.side_effect = [True, True, True] - - terminate_external_process( - mock_command, number_of_attempts=2, wait_period=0.1, wait_for_termination=0.1 - ) - assert mock_command.is_running.call_count == 3 - assert mock_command.terminate.call_count == 2 - - -def test_terminate_external_process_case4() -> None: - """Test when process termination takes more time.""" - mock_command = mock.MagicMock() - mock_command.is_running.side_effect = [True, True, False] - - terminate_external_process( - mock_command, number_of_attempts=2, wait_period=0.1, wait_for_termination=0.1 - ) - mock_command.terminate.assert_called_once() - assert mock_command.is_running.call_count == 3 - assert mock_command.terminate.call_count == 1 - - -def test_terminate_command_no_process() -> None: - """Test command termination when process does not exist.""" - mock_command = mock.MagicMock() - 
mock_command.process.signal_group.side_effect = ProcessLookupError() - - terminate_command(mock_command) - mock_command.process.signal_group.assert_called_once() - mock_command.is_alive.assert_not_called() - - -def test_terminate_command() -> None: - """Test command termination.""" - mock_command = mock.MagicMock() - mock_command.is_alive.return_value = False - - terminate_command(mock_command) - mock_command.process.signal_group.assert_called_once() - - -def test_terminate_command_case1() -> None: - """Test command termination when it takes time..""" - mock_command = mock.MagicMock() - mock_command.is_alive.side_effect = [True, True, False] - - terminate_command(mock_command, wait_period=0.1) - mock_command.process.signal_group.assert_called_once() - assert mock_command.is_alive.call_count == 3 - - -def test_terminate_command_case2() -> None: - """Test command termination when it takes much time..""" - mock_command = mock.MagicMock() - mock_command.is_alive.side_effect = [True, True, True] - - terminate_command(mock_command, number_of_attempts=3, wait_period=0.1) - assert mock_command.is_alive.call_count == 3 - assert mock_command.process.signal_group.call_count == 2 - - -class TestRunAndWait: - """Test run_and_wait function.""" - - @pytest.fixture(autouse=True) - def setup_method(self, monkeypatch: Any) -> None: - """Init test method.""" - self.execute_command_mock = mock.MagicMock() - monkeypatch.setattr( - "aiet.utils.proc.execute_command", self.execute_command_mock - ) - - self.terminate_command_mock = mock.MagicMock() - monkeypatch.setattr( - "aiet.utils.proc.terminate_command", self.terminate_command_mock - ) - - def test_if_execute_command_raises_exception(self) -> None: - """Test if execute_command fails.""" - self.execute_command_mock.side_effect = Exception("Error!") - with pytest.raises(Exception, match="Error!"): - run_and_wait("command", Path.cwd()) - - def test_if_command_finishes_with_error(self) -> None: - """Test if command finishes with error.""" - cmd_mock = mock.MagicMock() - self.execute_command_mock.return_value = cmd_mock - exit_code_mock = mock.PropertyMock( - side_effect=ErrorReturnCode("cmd", bytearray(), bytearray()) - ) - type(cmd_mock).exit_code = exit_code_mock - - with pytest.raises(CommandFailedException): - run_and_wait("command", Path.cwd()) - - @pytest.mark.parametrize("terminate_on_error, call_count", ((False, 0), (True, 1))) - def test_if_command_finishes_with_exception( - self, terminate_on_error: bool, call_count: int - ) -> None: - """Test if command finishes with error.""" - cmd_mock = mock.MagicMock() - self.execute_command_mock.return_value = cmd_mock - exit_code_mock = mock.PropertyMock(side_effect=Exception("Error!")) - type(cmd_mock).exit_code = exit_code_mock - - with pytest.raises(Exception, match="Error!"): - run_and_wait("command", Path.cwd(), terminate_on_error=terminate_on_error) - - assert self.terminate_command_mock.call_count == call_count - - -def test_save_process_info_no_process(monkeypatch: Any, tmpdir: Any) -> None: - """Test save_process_info function.""" - mock_process = mock.MagicMock() - monkeypatch.setattr("psutil.Process", mock.MagicMock(return_value=mock_process)) - mock_process.children.side_effect = psutil.NoSuchProcess(555) - - pid_file_path = Path(tmpdir) / "test.pid" - save_process_info(555, pid_file_path) - assert not pid_file_path.exists() - - -def test_parse_command() -> None: - """Test parse_command function.""" - assert parse_command("1.sh") == ["bash", "1.sh"] - assert parse_command("1.sh", shell="sh") == ["sh", 
"1.sh"] - assert parse_command("command") == ["command"] - assert parse_command("command 123 --param=1") == ["command", "123", "--param=1"] diff --git a/tests/mlia/conftest.py b/tests/mlia/conftest.py index f683fca..0b4b2aa 100644 --- a/tests/mlia/conftest.py +++ b/tests/mlia/conftest.py @@ -1,7 +1,10 @@ # SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """Pytest conf module.""" +import shutil +import tarfile from pathlib import Path +from typing import Any import pytest @@ -18,3 +21,91 @@ def fixture_test_resources_path() -> Path: def fixture_dummy_context(tmpdir: str) -> ExecutionContext: """Return dummy context fixture.""" return ExecutionContext(working_dir=tmpdir) + + +@pytest.fixture(scope="session") +def test_systems_path(test_resources_path: Path) -> Path: + """Return test systems path in a pytest fixture.""" + return test_resources_path / "backends" / "systems" + + +@pytest.fixture(scope="session") +def test_applications_path(test_resources_path: Path) -> Path: + """Return test applications path in a pytest fixture.""" + return test_resources_path / "backends" / "applications" + + +@pytest.fixture(scope="session") +def non_optimised_input_model_file(test_tflite_model: Path) -> Path: + """Provide the path to a quantized dummy model file.""" + return test_tflite_model + + +@pytest.fixture(scope="session") +def optimised_input_model_file(test_tflite_vela_model: Path) -> Path: + """Provide path to Vela-optimised dummy model file.""" + return test_tflite_vela_model + + +@pytest.fixture(scope="session") +def invalid_input_model_file(test_tflite_invalid_model: Path) -> Path: + """Provide the path to an invalid dummy model file.""" + return test_tflite_invalid_model + + +@pytest.fixture(autouse=True) +def test_resources(monkeypatch: pytest.MonkeyPatch, test_resources_path: Path) -> Any: + """Force using test resources as middleware's repository.""" + + def get_test_resources() -> Path: + """Return path to the test resources.""" + return test_resources_path / "backends" + + monkeypatch.setattr("mlia.backend.fs.get_backend_resources", get_test_resources) + yield + + +def create_archive( + archive_name: str, source: Path, destination: Path, with_root_folder: bool = False +) -> None: + """Create archive from directory source.""" + with tarfile.open(destination / archive_name, mode="w:gz") as tar: + for item in source.iterdir(): + item_name = item.name + if with_root_folder: + item_name = f"{source.name}/{item_name}" + tar.add(item, item_name) + + +def process_directory(source: Path, destination: Path) -> None: + """Process resource directory.""" + destination.mkdir() + + for item in source.iterdir(): + if item.is_dir(): + create_archive(f"{item.name}.tar.gz", item, destination) + create_archive(f"{item.name}_dir.tar.gz", item, destination, True) + + +@pytest.fixture(scope="session", autouse=True) +def add_archives( + test_resources_path: Path, tmp_path_factory: pytest.TempPathFactory +) -> Any: + """Generate archives of the test resources.""" + tmp_path = tmp_path_factory.mktemp("archives") + + archives_path = tmp_path / "archives" + archives_path.mkdir() + + if (archives_path_link := test_resources_path / "archives").is_symlink(): + archives_path_link.unlink() + + archives_path_link.symlink_to(archives_path, target_is_directory=True) + + for item in ["applications", "systems"]: + process_directory(test_resources_path / "backends" / item, archives_path / item) + + yield + + archives_path_link.unlink() + shutil.rmtree(tmp_path) 
diff --git a/tests/mlia/test_backend_application.py b/tests/mlia/test_backend_application.py new file mode 100644 index 0000000..2cfb2ef --- /dev/null +++ b/tests/mlia/test_backend_application.py @@ -0,0 +1,460 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +# pylint: disable=no-self-use +"""Tests for the application backend.""" +from collections import Counter +from contextlib import ExitStack as does_not_raise +from pathlib import Path +from typing import Any +from typing import List +from unittest.mock import MagicMock + +import pytest + +from mlia.backend.application import Application +from mlia.backend.application import get_application +from mlia.backend.application import get_available_application_directory_names +from mlia.backend.application import get_available_applications +from mlia.backend.application import get_unique_application_names +from mlia.backend.application import install_application +from mlia.backend.application import load_applications +from mlia.backend.application import remove_application +from mlia.backend.common import Command +from mlia.backend.common import DataPaths +from mlia.backend.common import Param +from mlia.backend.common import UserParamConfig +from mlia.backend.config import ApplicationConfig +from mlia.backend.config import ExtendedApplicationConfig +from mlia.backend.config import NamedExecutionConfig + + +def test_get_available_application_directory_names() -> None: + """Test get_available_applicationss mocking get_resources.""" + directory_names = get_available_application_directory_names() + assert Counter(directory_names) == Counter( + [ + "application1", + "application2", + "application4", + "application5", + "application6", + ] + ) + + +def test_get_available_applications() -> None: + """Test get_available_applicationss mocking get_resources.""" + available_applications = get_available_applications() + + assert all(isinstance(s, Application) for s in available_applications) + assert all(s != 42 for s in available_applications) + assert len(available_applications) == 10 + # application_5 has multiply items with multiply supported systems + assert [str(s) for s in available_applications] == [ + "application_1", + "application_2", + "application_4", + "application_5", + "application_5", + "application_5A", + "application_5A", + "application_5B", + "application_5B", + "application_6", + ] + + +def test_get_unique_application_names() -> None: + """Test get_unique_application_names.""" + unique_names = get_unique_application_names() + + assert all(isinstance(s, str) for s in unique_names) + assert all(s for s in unique_names) + assert sorted(unique_names) == [ + "application_1", + "application_2", + "application_4", + "application_5", + "application_5A", + "application_5B", + "application_6", + ] + + +def test_get_application() -> None: + """Test get_application mocking get_resoures.""" + application = get_application("application_1") + if len(application) != 1: + pytest.fail("Unable to get application") + assert application[0].name == "application_1" + + application = get_application("unknown application") + assert len(application) == 0 + + +@pytest.mark.parametrize( + "source, call_count, expected_exception", + ( + ( + "archives/applications/application1.tar.gz", + 0, + pytest.raises( + Exception, match=r"Applications \[application_1\] are already installed" + ), + ), + ( + "various/applications/application_with_empty_config", + 0, + pytest.raises(Exception, match="No 
application definition found"), + ), + ( + "various/applications/application_with_wrong_config1", + 0, + pytest.raises(Exception, match="Unable to read application definition"), + ), + ( + "various/applications/application_with_wrong_config2", + 0, + pytest.raises(Exception, match="Unable to read application definition"), + ), + ( + "various/applications/application_with_wrong_config3", + 0, + pytest.raises(Exception, match="Unable to read application definition"), + ), + ("various/applications/application_with_valid_config", 1, does_not_raise()), + ( + "archives/applications/application3.tar.gz", + 0, + pytest.raises(Exception, match="Unable to read application definition"), + ), + ( + "backends/applications/application1", + 0, + pytest.raises( + Exception, match=r"Applications \[application_1\] are already installed" + ), + ), + ( + "backends/applications/application3", + 0, + pytest.raises(Exception, match="Unable to read application definition"), + ), + ), +) +def test_install_application( + monkeypatch: Any, + test_resources_path: Path, + source: str, + call_count: int, + expected_exception: Any, +) -> None: + """Test application install from archive.""" + mock_create_destination_and_install = MagicMock() + monkeypatch.setattr( + "mlia.backend.application.create_destination_and_install", + mock_create_destination_and_install, + ) + + with expected_exception: + install_application(test_resources_path / source) + assert mock_create_destination_and_install.call_count == call_count + + +def test_remove_application(monkeypatch: Any) -> None: + """Test application removal.""" + mock_remove_backend = MagicMock() + monkeypatch.setattr("mlia.backend.application.remove_backend", mock_remove_backend) + + remove_application("some_application_directory") + mock_remove_backend.assert_called_once() + + +def test_application_config_without_commands() -> None: + """Test application config without commands.""" + config = ApplicationConfig(name="application") + application = Application(config) + # pylint: disable=use-implicit-booleaness-not-comparison + assert application.commands == {} + + +class TestApplication: + """Test for application class methods.""" + + def test___eq__(self) -> None: + """Test overloaded __eq__ method.""" + config = ApplicationConfig( + # Application + supported_systems=["system1", "system2"], + build_dir="build_dir", + # inherited from Backend + name="name", + description="description", + commands={}, + ) + application1 = Application(config) + application2 = Application(config) # Identical + assert application1 == application2 + + application3 = Application(config) # changed + # Change one single attribute so not equal, but same Type + setattr(application3, "supported_systems", ["somewhere/else"]) + assert application1 != application3 + + # different Type + application4 = "Not the Application you are looking for" + assert application1 != application4 + + application5 = Application(config) + # supported systems could be in any order + setattr(application5, "supported_systems", ["system2", "system1"]) + assert application1 == application5 + + def test_can_run_on(self) -> None: + """Test Application can run on.""" + config = ApplicationConfig(name="application", supported_systems=["System-A"]) + + application = Application(config) + assert application.can_run_on("System-A") + assert not application.can_run_on("System-B") + + applications = get_application("application_1", "System 1") + assert len(applications) == 1 + assert applications[0].can_run_on("System 1") + + def 
test_get_deploy_data(self, tmp_path: Path) -> None: + """Test Application can run on.""" + src, dest = "src", "dest" + config = ApplicationConfig( + name="application", deploy_data=[(src, dest)], config_location=tmp_path + ) + src_path = tmp_path / src + src_path.mkdir() + application = Application(config) + assert application.get_deploy_data() == [DataPaths(src_path, dest)] + + def test_get_deploy_data_no_config_location(self) -> None: + """Test that getting deploy data fails if no config location provided.""" + with pytest.raises( + Exception, match="Unable to get application .* config location" + ): + Application(ApplicationConfig(name="application")).get_deploy_data() + + def test_unable_to_create_application_without_name(self) -> None: + """Test that it is not possible to create application without name.""" + with pytest.raises(Exception, match="Name is empty"): + Application(ApplicationConfig()) + + def test_application_config_without_commands(self) -> None: + """Test application config without commands.""" + config = ApplicationConfig(name="application") + application = Application(config) + # pylint: disable=use-implicit-booleaness-not-comparison + assert application.commands == {} + + @pytest.mark.parametrize( + "config, expected_params", + ( + ( + ApplicationConfig( + name="application", + commands={"command": ["cmd {user_params:0} {user_params:1}"]}, + user_params={ + "command": [ + UserParamConfig( + name="--param1", description="param1", alias="param1" + ), + UserParamConfig( + name="--param2", description="param2", alias="param2" + ), + ] + }, + ), + [Param("--param1", "param1"), Param("--param2", "param2")], + ), + ( + ApplicationConfig( + name="application", + commands={"command": ["cmd {user_params:param1} {user_params:1}"]}, + user_params={ + "command": [ + UserParamConfig( + name="--param1", description="param1", alias="param1" + ), + UserParamConfig( + name="--param2", description="param2", alias="param2" + ), + ] + }, + ), + [Param("--param1", "param1"), Param("--param2", "param2")], + ), + ( + ApplicationConfig( + name="application", + commands={"command": ["cmd {user_params:param1}"]}, + user_params={ + "command": [ + UserParamConfig( + name="--param1", description="param1", alias="param1" + ), + UserParamConfig( + name="--param2", description="param2", alias="param2" + ), + ] + }, + ), + [Param("--param1", "param1")], + ), + ), + ) + def test_remove_unused_params( + self, config: ApplicationConfig, expected_params: List[Param] + ) -> None: + """Test mod remove_unused_parameter.""" + application = Application(config) + application.remove_unused_params() + assert application.commands["command"].params == expected_params + + +@pytest.mark.parametrize( + "config, expected_error", + ( + ( + ExtendedApplicationConfig(name="application"), + pytest.raises(Exception, match="No supported systems definition provided"), + ), + ( + ExtendedApplicationConfig( + name="application", supported_systems=[NamedExecutionConfig(name="")] + ), + pytest.raises( + Exception, + match="Unable to read supported system definition, name is missed", + ), + ), + ( + ExtendedApplicationConfig( + name="application", + supported_systems=[ + NamedExecutionConfig( + name="system", + commands={"command": ["cmd"]}, + user_params={"command": [UserParamConfig(name="param")]}, + ) + ], + commands={"command": ["cmd {user_params:0}"]}, + user_params={"command": [UserParamConfig(name="param")]}, + ), + pytest.raises( + Exception, match="Default parameters for command .* should have aliases" + ), + ), + ( + 
ExtendedApplicationConfig( + name="application", + supported_systems=[ + NamedExecutionConfig( + name="system", + commands={"command": ["cmd"]}, + user_params={"command": [UserParamConfig(name="param")]}, + ) + ], + commands={"command": ["cmd {user_params:0}"]}, + user_params={"command": [UserParamConfig(name="param", alias="param")]}, + ), + pytest.raises( + Exception, match="system parameters for command .* should have aliases" + ), + ), + ), +) +def test_load_application_exceptional_cases( + config: ExtendedApplicationConfig, expected_error: Any +) -> None: + """Test exceptional cases for application load function.""" + with expected_error: + load_applications(config) + + +def test_load_application() -> None: + """Test application load function. + + The main purpose of this test is to test configuration for application + for different systems. All configuration should be correctly + overridden if needed. + """ + application_5 = get_application("application_5") + assert len(application_5) == 2 + + default_commands = { + "build": Command(["default build command"]), + "run": Command(["default run command"]), + } + default_variables = {"var1": "value1", "var2": "value2"} + + application_5_0 = application_5[0] + assert application_5_0.build_dir == "default_build_dir" + assert application_5_0.supported_systems == ["System 1"] + assert application_5_0.commands == default_commands + assert application_5_0.variables == default_variables + assert application_5_0.lock is False + + application_5_1 = application_5[1] + assert application_5_1.build_dir == application_5_0.build_dir + assert application_5_1.supported_systems == ["System 2"] + assert application_5_1.commands == application_5_1.commands + assert application_5_1.variables == default_variables + + application_5a = get_application("application_5A") + assert len(application_5a) == 2 + + application_5a_0 = application_5a[0] + assert application_5a_0.supported_systems == ["System 1"] + assert application_5a_0.build_dir == "build_5A" + assert application_5a_0.commands == default_commands + assert application_5a_0.variables == {"var1": "new value1", "var2": "value2"} + assert application_5a_0.lock is False + + application_5a_1 = application_5a[1] + assert application_5a_1.supported_systems == ["System 2"] + assert application_5a_1.build_dir == "build" + assert application_5a_1.commands == { + "build": Command(["default build command"]), + "run": Command(["run command on system 2"]), + } + assert application_5a_1.variables == {"var1": "value1", "var2": "new value2"} + assert application_5a_1.lock is True + + application_5b = get_application("application_5B") + assert len(application_5b) == 2 + + application_5b_0 = application_5b[0] + assert application_5b_0.build_dir == "build_5B" + assert application_5b_0.supported_systems == ["System 1"] + assert application_5b_0.commands == { + "build": Command(["default build command with value for var1 System1"], []), + "run": Command(["default run command with value for var2 System1"]), + } + assert "non_used_command" not in application_5b_0.commands + + application_5b_1 = application_5b[1] + assert application_5b_1.build_dir == "build" + assert application_5b_1.supported_systems == ["System 2"] + assert application_5b_1.commands == { + "build": Command( + [ + "build command on system 2 with value" + " for var1 System2 {user_params:param1}" + ], + [ + Param( + "--param", + "Sample command param", + ["value1", "value2", "value3"], + "value1", + ) + ], + ), + "run": Command(["run command on system 2"], []), + 
} diff --git a/tests/mlia/test_backend_common.py b/tests/mlia/test_backend_common.py new file mode 100644 index 0000000..82a985a --- /dev/null +++ b/tests/mlia/test_backend_common.py @@ -0,0 +1,486 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +# pylint: disable=no-self-use,protected-access +"""Tests for the common backend module.""" +from contextlib import ExitStack as does_not_raise +from pathlib import Path +from typing import Any +from typing import cast +from typing import Dict +from typing import IO +from typing import List +from typing import Optional +from typing import Tuple +from typing import Union +from unittest.mock import MagicMock + +import pytest + +from mlia.backend.application import Application +from mlia.backend.common import Backend +from mlia.backend.common import BaseBackendConfig +from mlia.backend.common import Command +from mlia.backend.common import ConfigurationException +from mlia.backend.common import load_config +from mlia.backend.common import Param +from mlia.backend.common import parse_raw_parameter +from mlia.backend.common import remove_backend +from mlia.backend.config import ApplicationConfig +from mlia.backend.config import UserParamConfig +from mlia.backend.execution import ExecutionContext +from mlia.backend.execution import ParamResolver +from mlia.backend.system import System + + +@pytest.mark.parametrize( + "directory_name, expected_exception", + ( + ("some_dir", does_not_raise()), + (None, pytest.raises(Exception, match="No directory name provided")), + ), +) +def test_remove_backend( + monkeypatch: Any, directory_name: str, expected_exception: Any +) -> None: + """Test remove_backend function.""" + mock_remove_resource = MagicMock() + monkeypatch.setattr("mlia.backend.common.remove_resource", mock_remove_resource) + + with expected_exception: + remove_backend(directory_name, "applications") + + +@pytest.mark.parametrize( + "filename, expected_exception", + ( + ("application_config.json", does_not_raise()), + (None, pytest.raises(Exception, match="Unable to read config")), + ), +) +def test_load_config( + filename: str, expected_exception: Any, test_resources_path: Path, monkeypatch: Any +) -> None: + """Test load_config.""" + with expected_exception: + configs: List[Optional[Union[Path, IO[bytes]]]] = ( + [None] + if not filename + else [ + # Ignore pylint warning as 'with' can't be used inside of a + # generator expression. 
+ # pylint: disable=consider-using-with + open(test_resources_path / filename, "rb"), + test_resources_path / filename, + ] + ) + for config in configs: + json_mock = MagicMock() + monkeypatch.setattr("mlia.backend.common.json.load", json_mock) + load_config(config) + json_mock.assert_called_once() + + +class TestBackend: + """Test Backend class.""" + + def test___repr__(self) -> None: + """Test the representation of Backend instance.""" + backend = Backend( + BaseBackendConfig(name="Testing name", description="Testing description") + ) + assert str(backend) == "Testing name" + + def test__eq__(self) -> None: + """Test equality method with different cases.""" + backend1 = Backend(BaseBackendConfig(name="name", description="description")) + backend1.commands = {"command": Command(["command"])} + + backend2 = Backend(BaseBackendConfig(name="name", description="description")) + backend2.commands = {"command": Command(["command"])} + + backend3 = Backend( + BaseBackendConfig( + name="Ben", description="This is not the Backend you are looking for" + ) + ) + backend3.commands = {"wave": Command(["wave hand"])} + + backend4 = "Foo" # checking not isinstance(backend4, Backend) + + assert backend1 == backend2 + assert backend1 != backend3 + assert backend1 != backend4 + + @pytest.mark.parametrize( + "parameter, valid", + [ + ("--choice-param dummy_value_1", True), + ("--choice-param wrong_value", False), + ("--open-param something", True), + ("--wrong-param value", False), + ], + ) + def test_validate_parameter( + self, parameter: str, valid: bool, test_resources_path: Path + ) -> None: + """Test validate_parameter.""" + config = cast( + List[ApplicationConfig], + load_config(test_resources_path / "hello_world.json"), + ) + # The application configuration is a list of configurations so we need + # only the first one + # Exercise the validate_parameter test using the Application classe which + # inherits from Backend. 
+ application = Application(config[0]) + assert application.validate_parameter("run", parameter) == valid + + def test_validate_parameter_with_invalid_command( + self, test_resources_path: Path + ) -> None: + """Test validate_parameter with an invalid command_name.""" + config = cast( + List[ApplicationConfig], + load_config(test_resources_path / "hello_world.json"), + ) + application = Application(config[0]) + with pytest.raises(AttributeError) as err: + # command foo does not exist, so raise an error + application.validate_parameter("foo", "bar") + assert "Unknown command: 'foo'" in str(err.value) + + def test_build_command(self, monkeypatch: Any) -> None: + """Test command building.""" + config = { + "name": "test", + "commands": { + "build": ["build {user_params:0} {user_params:1}"], + "run": ["run {user_params:0}"], + "post_run": ["post_run {application_params:0} on {system_params:0}"], + "some_command": ["Command with {variables:var_A}"], + "empty_command": [""], + }, + "user_params": { + "build": [ + { + "name": "choice_param_0=", + "values": [1, 2, 3], + "default_value": 1, + }, + {"name": "choice_param_1", "values": [3, 4, 5], "default_value": 3}, + {"name": "choice_param_3", "values": [6, 7, 8]}, + ], + "run": [{"name": "flag_param_0"}], + }, + "variables": {"var_A": "value for variable A"}, + } + + monkeypatch.setattr("mlia.backend.system.ProtocolFactory", MagicMock()) + application, system = Application(config), System(config) # type: ignore + context = ExecutionContext( + app=application, + app_params=[], + system=system, + system_params=[], + custom_deploy_data=[], + ) + + param_resolver = ParamResolver(context) + + cmd = application.build_command( + "build", ["choice_param_0=2", "choice_param_1=4"], param_resolver + ) + assert cmd == ["build choice_param_0=2 choice_param_1 4"] + + cmd = application.build_command("build", ["choice_param_0=2"], param_resolver) + assert cmd == ["build choice_param_0=2 choice_param_1 3"] + + cmd = application.build_command( + "build", ["choice_param_0=2", "choice_param_3=7"], param_resolver + ) + assert cmd == ["build choice_param_0=2 choice_param_1 3"] + + with pytest.raises( + ConfigurationException, match="Command 'foo' could not be found." 
+ ): + application.build_command("foo", [""], param_resolver) + + cmd = application.build_command("some_command", [], param_resolver) + assert cmd == ["Command with value for variable A"] + + cmd = application.build_command("empty_command", [], param_resolver) + assert cmd == [""] + + @pytest.mark.parametrize("class_", [Application, System]) + def test_build_command_unknown_variable(self, class_: type) -> None: + """Test that unable to construct backend with unknown variable.""" + with pytest.raises(Exception, match="Unknown variable var1"): + config = {"name": "test", "commands": {"run": ["run {variables:var1}"]}} + class_(config) + + @pytest.mark.parametrize( + "class_, config, expected_output", + [ + ( + Application, + { + "name": "test", + "commands": { + "build": ["build {user_params:0} {user_params:1}"], + "run": ["run {user_params:0}"], + }, + "user_params": { + "build": [ + { + "name": "choice_param_0=", + "values": ["a", "b", "c"], + "default_value": "a", + "alias": "param_1", + }, + { + "name": "choice_param_1", + "values": ["a", "b", "c"], + "default_value": "a", + "alias": "param_2", + }, + {"name": "choice_param_3", "values": ["a", "b", "c"]}, + ], + "run": [{"name": "flag_param_0"}], + }, + }, + [ + ( + "b", + Param( + name="choice_param_0=", + description="", + values=["a", "b", "c"], + default_value="a", + alias="param_1", + ), + ), + ( + "a", + Param( + name="choice_param_1", + description="", + values=["a", "b", "c"], + default_value="a", + alias="param_2", + ), + ), + ( + "c", + Param( + name="choice_param_3", + description="", + values=["a", "b", "c"], + ), + ), + ], + ), + (System, {"name": "test"}, []), + ], + ) + def test_resolved_parameters( + self, + monkeypatch: Any, + class_: type, + config: Dict, + expected_output: List[Tuple[Optional[str], Param]], + ) -> None: + """Test command building.""" + monkeypatch.setattr("mlia.backend.system.ProtocolFactory", MagicMock()) + backend = class_(config) + + params = backend.resolved_parameters( + "build", ["choice_param_0=b", "choice_param_3=c"] + ) + assert params == expected_output + + @pytest.mark.parametrize( + ["param_name", "user_param", "expected_value"], + [ + ( + "test_name", + "test_name=1234", + "1234", + ), # optional parameter using '=' + ( + "test_name", + "test_name 1234", + "1234", + ), # optional parameter using ' ' + ("test_name", "test_name", None), # flag + (None, "test_name=1234", "1234"), # positional parameter + ], + ) + def test_resolved_user_parameters( + self, param_name: str, user_param: str, expected_value: str + ) -> None: + """Test different variants to provide user parameters.""" + # A dummy config providing one backend config + config = { + "name": "test_backend", + "commands": { + "test": ["user_param:test_param"], + }, + "user_params": { + "test": [UserParamConfig(name=param_name, alias="test_name")], + }, + } + backend = Backend(cast(BaseBackendConfig, config)) + params = backend.resolved_parameters( + command_name="test", user_params=[user_param] + ) + assert len(params) == 1 + value, param = params[0] + assert param_name == param.name + assert expected_value == value + + @pytest.mark.parametrize( + "input_param,expected", + [ + ("--param=1", ("--param", "1")), + ("--param 1", ("--param", "1")), + ("--flag", ("--flag", None)), + ], + ) + def test__parse_raw_parameter( + self, input_param: str, expected: Tuple[str, Optional[str]] + ) -> None: + """Test internal method of parsing a single raw parameter.""" + assert parse_raw_parameter(input_param) == expected + + +class TestParam: + """Test 
Param class.""" + + def test__eq__(self) -> None: + """Test equality method with different cases.""" + param1 = Param(name="test", description="desc", values=["values"]) + param2 = Param(name="test", description="desc", values=["values"]) + param3 = Param(name="test1", description="desc", values=["values"]) + param4 = object() + + assert param1 == param2 + assert param1 != param3 + assert param1 != param4 + + def test_get_details(self) -> None: + """Test get_details() method.""" + param1 = Param(name="test", description="desc", values=["values"]) + assert param1.get_details() == { + "name": "test", + "values": ["values"], + "description": "desc", + } + + def test_invalid(self) -> None: + """Test invalid use cases for the Param class.""" + with pytest.raises( + ConfigurationException, + match="Either name, alias or both must be set to identify a parameter.", + ): + Param(name=None, description="desc", values=["values"]) + + +class TestCommand: + """Test Command class.""" + + def test_get_details(self) -> None: + """Test get_details() method.""" + param1 = Param(name="test", description="desc", values=["values"]) + command1 = Command(command_strings=["echo test"], params=[param1]) + assert command1.get_details() == { + "command_strings": ["echo test"], + "user_params": [ + {"name": "test", "values": ["values"], "description": "desc"} + ], + } + + def test__eq__(self) -> None: + """Test equality method with different cases.""" + param1 = Param("test", "desc", ["values"]) + param2 = Param("test1", "desc1", ["values1"]) + command1 = Command(command_strings=["echo test"], params=[param1]) + command2 = Command(command_strings=["echo test"], params=[param1]) + command3 = Command(command_strings=["echo test"]) + command4 = Command(command_strings=["echo test"], params=[param2]) + command5 = object() + + assert command1 == command2 + assert command1 != command3 + assert command1 != command4 + assert command1 != command5 + + @pytest.mark.parametrize( + "params, expected_error", + [ + [[], does_not_raise()], + [[Param("param", "param description", [])], does_not_raise()], + [ + [ + Param("param", "param description", [], None, "alias"), + Param("param", "param description", [], None), + ], + does_not_raise(), + ], + [ + [ + Param("param1", "param1 description", [], None, "alias1"), + Param("param2", "param2 description", [], None, "alias2"), + ], + does_not_raise(), + ], + [ + [ + Param("param", "param description", [], None, "alias"), + Param("param", "param description", [], None, "alias"), + ], + pytest.raises(ConfigurationException, match="Non unique aliases alias"), + ], + [ + [ + Param("alias", "param description", [], None, "alias1"), + Param("param", "param description", [], None, "alias"), + ], + pytest.raises( + ConfigurationException, + match="Aliases .* could not be used as parameter name", + ), + ], + [ + [ + Param("alias", "param description", [], None, "alias"), + Param("param1", "param1 description", [], None, "alias1"), + ], + does_not_raise(), + ], + [ + [ + Param("alias", "param description", [], None, "alias"), + Param("alias", "param1 description", [], None, "alias1"), + ], + pytest.raises( + ConfigurationException, + match="Aliases .* could not be used as parameter name", + ), + ], + [ + [ + Param("param1", "param1 description", [], None, "alias1"), + Param("param2", "param2 description", [], None, "alias1"), + Param("param3", "param3 description", [], None, "alias2"), + Param("param4", "param4 description", [], None, "alias2"), + ], + pytest.raises( + ConfigurationException, 
match="Non unique aliases alias1, alias2" + ), + ], + ], + ) + def test_validate_params(self, params: List[Param], expected_error: Any) -> None: + """Test command validation function.""" + with expected_error: + Command([], params) diff --git a/tests/mlia/test_backend_controller.py b/tests/mlia/test_backend_controller.py new file mode 100644 index 0000000..a047adf --- /dev/null +++ b/tests/mlia/test_backend_controller.py @@ -0,0 +1,160 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Tests for system controller.""" +import csv +import os +import time +from pathlib import Path +from typing import Any + +import psutil +import pytest + +from mlia.backend.common import ConfigurationException +from mlia.backend.controller import SystemController +from mlia.backend.controller import SystemControllerSingleInstance +from mlia.backend.proc import ShellCommand + + +def get_system_controller(**kwargs: Any) -> SystemController: + """Get service controller.""" + single_instance = kwargs.get("single_instance", False) + if single_instance: + pid_file_path = kwargs.get("pid_file_path") + return SystemControllerSingleInstance(pid_file_path) + + return SystemController() + + +def test_service_controller() -> None: + """Test service controller functionality.""" + service_controller = get_system_controller() + + assert service_controller.get_output() == ("", "") + with pytest.raises(ConfigurationException, match="Wrong working directory"): + service_controller.start(["sleep 100"], Path("unknown")) + + service_controller.start(["sleep 100"], Path.cwd()) + assert service_controller.is_running() + + service_controller.stop(True) + assert not service_controller.is_running() + assert service_controller.get_output() == ("", "") + + service_controller.stop() + + with pytest.raises( + ConfigurationException, match="System should have only one command to run" + ): + service_controller.start(["sleep 100", "sleep 101"], Path.cwd()) + + with pytest.raises(ConfigurationException, match="No startup command provided"): + service_controller.start([""], Path.cwd()) + + +def test_service_controller_bad_configuration() -> None: + """Test service controller functionality for bad configuration.""" + with pytest.raises(Exception, match="No pid file path presented"): + service_controller = get_system_controller( + single_instance=True, pid_file_path=None + ) + service_controller.start(["sleep 100"], Path.cwd()) + + +def test_service_controller_writes_process_info_correctly(tmpdir: Any) -> None: + """Test that controller writes process info correctly.""" + pid_file = Path(tmpdir) / "test.pid" + + service_controller = get_system_controller( + single_instance=True, pid_file_path=Path(tmpdir) / "test.pid" + ) + + service_controller.start(["sleep 100"], Path.cwd()) + assert service_controller.is_running() + assert pid_file.is_file() + + with open(pid_file, "r", encoding="utf-8") as file: + csv_reader = csv.reader(file) + rows = list(csv_reader) + assert len(rows) == 1 + + name, *_ = rows[0] + assert name == "sleep" + + service_controller.stop() + assert pid_file.exists() + + +def test_service_controller_does_not_write_process_info_if_process_finishes( + tmpdir: Any, +) -> None: + """Test that controller does not write process info if process already finished.""" + pid_file = Path(tmpdir) / "test.pid" + service_controller = get_system_controller( + single_instance=True, pid_file_path=pid_file + ) + service_controller.is_running = lambda: False # type: ignore + 
service_controller.start(["echo hello"], Path.cwd()) + + assert not pid_file.exists() + + +def test_service_controller_searches_for_previous_instances_correctly( + tmpdir: Any, +) -> None: + """Test that controller searches for previous instances correctly.""" + pid_file = Path(tmpdir) / "test.pid" + command = ShellCommand().run("sleep", "100") + assert command.is_alive() + + pid = command.process.pid + process = psutil.Process(pid) + with open(pid_file, "w", encoding="utf-8") as file: + csv_writer = csv.writer(file) + csv_writer.writerow(("some_process", "some_program", "some_cwd", os.getpid())) + csv_writer.writerow((process.name(), process.exe(), process.cwd(), process.pid)) + csv_writer.writerow(("some_old_process", "not_running", "from_nowhere", 77777)) + + service_controller = get_system_controller( + single_instance=True, pid_file_path=pid_file + ) + service_controller.start(["sleep 100"], Path.cwd()) + # controller should stop this process as it is currently running and + # mentioned in pid file + assert not command.is_alive() + + service_controller.stop() + + +@pytest.mark.parametrize( + "executable", ["test_backend_run_script.sh", "test_backend_run"] +) +def test_service_controller_run_shell_script( + executable: str, test_resources_path: Path +) -> None: + """Test controller's ability to run shell scripts.""" + script_path = test_resources_path / "scripts" + + service_controller = get_system_controller() + + service_controller.start([executable], script_path) + + assert service_controller.is_running() + # give time for the command to produce output + time.sleep(2) + service_controller.stop(wait=True) + assert not service_controller.is_running() + stdout, stderr = service_controller.get_output() + assert stdout == "Hello from script\n" + assert stderr == "Oops!\n" + + +def test_service_controller_does_nothing_if_not_started(tmpdir: Any) -> None: + """Test that nothing happened if controller is not started.""" + service_controller = get_system_controller( + single_instance=True, pid_file_path=Path(tmpdir) / "test.pid" + ) + + assert not service_controller.is_running() + service_controller.stop() + assert not service_controller.is_running() diff --git a/tests/mlia/test_backend_execution.py b/tests/mlia/test_backend_execution.py new file mode 100644 index 0000000..9395352 --- /dev/null +++ b/tests/mlia/test_backend_execution.py @@ -0,0 +1,518 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. 
+# SPDX-License-Identifier: Apache-2.0 +# pylint: disable=no-self-use +"""Test backend execution module.""" +from contextlib import ExitStack as does_not_raise +from pathlib import Path +from typing import Any +from typing import Dict +from unittest import mock +from unittest.mock import MagicMock + +import pytest +from sh import CommandNotFound + +from mlia.backend.application import Application +from mlia.backend.application import get_application +from mlia.backend.common import DataPaths +from mlia.backend.common import UserParamConfig +from mlia.backend.config import ApplicationConfig +from mlia.backend.config import LocalProtocolConfig +from mlia.backend.config import SystemConfig +from mlia.backend.execution import deploy_data +from mlia.backend.execution import execute_commands_locally +from mlia.backend.execution import ExecutionContext +from mlia.backend.execution import get_application_and_system +from mlia.backend.execution import get_application_by_name_and_system +from mlia.backend.execution import get_file_lock_path +from mlia.backend.execution import ParamResolver +from mlia.backend.execution import Reporter +from mlia.backend.execution import wait +from mlia.backend.output_parser import Base64OutputParser +from mlia.backend.output_parser import OutputParser +from mlia.backend.output_parser import RegexOutputParser +from mlia.backend.proc import CommandFailedException +from mlia.backend.system import get_system +from mlia.backend.system import load_system + + +def test_context_param_resolver(tmpdir: Any) -> None: + """Test parameter resolving.""" + system_config_location = Path(tmpdir) / "system" + system_config_location.mkdir() + + application_config_location = Path(tmpdir) / "application" + application_config_location.mkdir() + + ctx = ExecutionContext( + app=Application( + ApplicationConfig( + name="test_application", + description="Test application", + config_location=application_config_location, + build_dir="build-{application.name}-{system.name}", + commands={ + "run": [ + "run_command1 {user_params:0}", + "run_command2 {user_params:1}", + ] + }, + variables={"var_1": "value for var_1"}, + user_params={ + "run": [ + UserParamConfig( + name="--param1", + description="Param 1", + default_value="123", + alias="param_1", + ), + UserParamConfig( + name="--param2", description="Param 2", default_value="456" + ), + UserParamConfig( + name="--param3", description="Param 3", alias="param_3" + ), + UserParamConfig( + name="--param4=", + description="Param 4", + default_value="456", + alias="param_4", + ), + UserParamConfig( + description="Param 5", + default_value="789", + alias="param_5", + ), + ] + }, + ) + ), + app_params=["--param2=789"], + system=load_system( + SystemConfig( + name="test_system", + description="Test system", + config_location=system_config_location, + build_dir="build", + data_transfer=LocalProtocolConfig(protocol="local"), + commands={ + "build": ["build_command1 {user_params:0}"], + "run": ["run_command {application.commands.run:1}"], + }, + variables={"var_1": "value for var_1"}, + user_params={ + "build": [ + UserParamConfig( + name="--param1", description="Param 1", default_value="aaa" + ), + UserParamConfig(name="--param2", description="Param 2"), + ] + }, + ) + ), + system_params=["--param1=bbb"], + custom_deploy_data=[], + ) + + param_resolver = ParamResolver(ctx) + expected_values = { + "application.name": "test_application", + "application.description": "Test application", + "application.config_dir": str(application_config_location), + 
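+        # Placeholders such as {application.commands.run:0} resolve to the full
+        # command string with parameter values substituted; the *.params forms
+        # resolve to the bare value, e.g. "...run.params:param_1" -> "123".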
"application.build_dir": "{}/build-test_application-test_system".format( + application_config_location + ), + "application.commands.run:0": "run_command1 --param1 123", + "application.commands.run.params:0": "123", + "application.commands.run.params:param_1": "123", + "application.commands.run:1": "run_command2 --param2 789", + "application.commands.run.params:1": "789", + "application.variables:var_1": "value for var_1", + "system.name": "test_system", + "system.description": "Test system", + "system.config_dir": str(system_config_location), + "system.commands.build:0": "build_command1 --param1 bbb", + "system.commands.run:0": "run_command run_command2 --param2 789", + "system.commands.build.params:0": "bbb", + "system.variables:var_1": "value for var_1", + } + + for param, value in expected_values.items(): + assert param_resolver(param) == value + + assert ctx.build_dir() == Path( + "{}/build-test_application-test_system".format(application_config_location) + ) + + expected_errors = { + "application.variables:var_2": pytest.raises( + Exception, match="Unknown variable var_2" + ), + "application.commands.clean:0": pytest.raises( + Exception, match="Command clean not found" + ), + "application.commands.run:2": pytest.raises( + Exception, match="Invalid index 2 for command run" + ), + "application.commands.run.params:5": pytest.raises( + Exception, match="Invalid parameter index 5 for command run" + ), + "application.commands.run.params:param_2": pytest.raises( + Exception, + match="No value for parameter with index or alias param_2 of command run", + ), + "UNKNOWN": pytest.raises( + Exception, match="Unable to resolve parameter UNKNOWN" + ), + "system.commands.build.params:1": pytest.raises( + Exception, + match="No value for parameter with index or alias 1 of command build", + ), + "system.commands.build:A": pytest.raises( + Exception, match="Bad command index A" + ), + "system.variables:var_2": pytest.raises( + Exception, match="Unknown variable var_2" + ), + } + for param, error in expected_errors.items(): + with error: + param_resolver(param) + + resolved_params = ctx.app.resolved_parameters("run", []) + expected_user_params = { + "user_params:0": "--param1 123", + "user_params:param_1": "--param1 123", + "user_params:2": "--param3", + "user_params:param_3": "--param3", + "user_params:3": "--param4=456", + "user_params:param_4": "--param4=456", + "user_params:param_5": "789", + } + for param, expected_value in expected_user_params.items(): + assert param_resolver(param, "run", resolved_params) == expected_value + + with pytest.raises( + Exception, match="Invalid index 5 for user params of command run" + ): + param_resolver("user_params:5", "run", resolved_params) + + with pytest.raises( + Exception, match="No user parameter for command 'run' with alias 'param_2'." 
+ ): + param_resolver("user_params:param_2", "run", resolved_params) + + with pytest.raises(Exception, match="Unable to resolve user params"): + param_resolver("user_params:0", "", resolved_params) + + bad_ctx = ExecutionContext( + app=Application( + ApplicationConfig( + name="test_application", + config_location=application_config_location, + build_dir="build-{user_params:0}", + ) + ), + app_params=["--param2=789"], + system=load_system( + SystemConfig( + name="test_system", + description="Test system", + config_location=system_config_location, + build_dir="build-{system.commands.run.params:123}", + data_transfer=LocalProtocolConfig(protocol="local"), + ) + ), + system_params=["--param1=bbb"], + custom_deploy_data=[], + ) + param_resolver = ParamResolver(bad_ctx) + with pytest.raises(Exception, match="Unable to resolve user params"): + bad_ctx.build_dir() + + +# pylint: disable=too-many-arguments +@pytest.mark.parametrize( + "application_name, soft_lock, sys_lock, lock_dir, expected_error, expected_path", + ( + ( + "test_application", + True, + True, + Path("/tmp"), + does_not_raise(), + Path("/tmp/middleware_test_application_test_system.lock"), + ), + ( + "$$test_application$!:", + True, + True, + Path("/tmp"), + does_not_raise(), + Path("/tmp/middleware_test_application_test_system.lock"), + ), + ( + "test_application", + True, + True, + Path("unknown"), + pytest.raises( + Exception, match="Invalid directory unknown for lock files provided" + ), + None, + ), + ( + "test_application", + False, + True, + Path("/tmp"), + does_not_raise(), + Path("/tmp/middleware_test_system.lock"), + ), + ( + "test_application", + True, + False, + Path("/tmp"), + does_not_raise(), + Path("/tmp/middleware_test_application.lock"), + ), + ( + "test_application", + False, + False, + Path("/tmp"), + pytest.raises(Exception, match="No filename for lock provided"), + None, + ), + ), +) +def test_get_file_lock_path( + application_name: str, + soft_lock: bool, + sys_lock: bool, + lock_dir: Path, + expected_error: Any, + expected_path: Path, +) -> None: + """Test get_file_lock_path function.""" + with expected_error: + ctx = ExecutionContext( + app=Application(ApplicationConfig(name=application_name, lock=soft_lock)), + app_params=[], + system=load_system( + SystemConfig( + name="test_system", + lock=sys_lock, + data_transfer=LocalProtocolConfig(protocol="local"), + ) + ), + system_params=[], + custom_deploy_data=[], + ) + path = get_file_lock_path(ctx, lock_dir) + assert path == expected_path + + +def test_get_application_by_name_and_system(monkeypatch: Any) -> None: + """Test exceptional case for get_application_by_name_and_system.""" + monkeypatch.setattr( + "mlia.backend.execution.get_application", + MagicMock(return_value=[MagicMock(), MagicMock()]), + ) + + with pytest.raises( + ValueError, + match="Error during getting application test_application for the " + "system test_system", + ): + get_application_by_name_and_system("test_application", "test_system") + + +def test_get_application_and_system(monkeypatch: Any) -> None: + """Test exceptional case for get_application_and_system.""" + monkeypatch.setattr( + "mlia.backend.execution.get_system", MagicMock(return_value=None) + ) + + with pytest.raises(ValueError, match="System test_system is not found"): + get_application_and_system("test_application", "test_system") + + +def test_wait_function(monkeypatch: Any) -> None: + """Test wait function.""" + sleep_mock = MagicMock() + monkeypatch.setattr("time.sleep", sleep_mock) + wait(0.1) + 
sleep_mock.assert_called_once() + + +def test_deployment_execution_context() -> None: + """Test property 'is_deploy_needed' of the ExecutionContext.""" + ctx = ExecutionContext( + app=get_application("application_1")[0], + app_params=[], + system=get_system("System 1"), + system_params=[], + ) + assert not ctx.is_deploy_needed + deploy_data(ctx) # should be a NOP + + ctx = ExecutionContext( + app=get_application("application_1")[0], + app_params=[], + system=get_system("System 1"), + system_params=[], + custom_deploy_data=[DataPaths(Path("README.md"), ".")], + ) + assert ctx.is_deploy_needed + + ctx = ExecutionContext( + app=get_application("application_1")[0], + app_params=[], + system=None, + system_params=[], + ) + assert not ctx.is_deploy_needed + with pytest.raises(AssertionError): + deploy_data(ctx) + + +def test_reporter_execution_context(tmp_path: Path) -> None: + """Test ExecutionContext creates a reporter when a report file is provided.""" + # Configure regex parser for the system manually + sys = get_system("System 1") + assert sys is not None + sys.reporting = { + "regex": { + "simulated_time": {"pattern": "Simulated time.*: (.*)s", "type": "float"} + } + } + report_file_path = tmp_path / "test_report.json" + + ctx = ExecutionContext( + app=get_application("application_1")[0], + app_params=[], + system=sys, + system_params=[], + report_file=report_file_path, + ) + assert isinstance(ctx.reporter, Reporter) + assert len(ctx.reporter.parsers) == 2 + assert any(isinstance(parser, RegexOutputParser) for parser in ctx.reporter.parsers) + assert any( + isinstance(parser, Base64OutputParser) for parser in ctx.reporter.parsers + ) + + +class TestExecuteCommandsLocally: + """Test execute_commands_locally() function.""" + + @pytest.mark.parametrize( + "first_command, exception, expected_output", + ( + ( + "echo 'hello'", + None, + "Running: echo 'hello'\nhello\nRunning: echo 'goodbye'\ngoodbye\n", + ), + ( + "non-existent-command", + CommandNotFound, + "Running: non-existent-command\n", + ), + ("false", CommandFailedException, "Running: false\n"), + ), + ids=( + "runs_multiple_commands", + "stops_executing_on_non_existent_command", + "stops_executing_when_command_exits_with_error_code", + ), + ) + def test_execution( + self, + first_command: str, + exception: Any, + expected_output: str, + test_resources_path: Path, + capsys: Any, + ) -> None: + """Test expected behaviour of the function.""" + commands = [first_command, "echo 'goodbye'"] + cwd = test_resources_path + if exception is None: + execute_commands_locally(commands, cwd) + else: + with pytest.raises(exception): + execute_commands_locally(commands, cwd) + + captured = capsys.readouterr() + assert captured.out == expected_output + + def test_stops_executing_on_exception( + self, monkeypatch: Any, test_resources_path: Path + ) -> None: + """Ensure commands following an error-exit-code command don't run.""" + # Mock execute_command() function + execute_command_mock = mock.MagicMock() + monkeypatch.setattr("mlia.backend.proc.execute_command", execute_command_mock) + + # Mock Command object and assign as return value to execute_command() + cmd_mock = mock.MagicMock() + execute_command_mock.return_value = cmd_mock + + # Mock the terminate_command (speed up test) + terminate_command_mock = mock.MagicMock() + monkeypatch.setattr( + "mlia.backend.proc.terminate_command", terminate_command_mock + ) + + # Mock a thrown Exception and assign to Command().exit_code + exit_code_mock = mock.PropertyMock(side_effect=Exception("Exception.")) + 
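+        # The PropertyMock must be attached to type(cmd_mock) rather than the
+        # instance so that reading cmd_mock.exit_code raises the side effect.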
type(cmd_mock).exit_code = exit_code_mock + + with pytest.raises(Exception, match="Exception."): + execute_commands_locally( + ["command_1", "command_2"], cwd=test_resources_path + ) + + # Assert only "command_1" was executed + assert execute_command_mock.call_count == 1 + + +def test_reporter(tmpdir: Any) -> None: + """Test class 'Reporter'.""" + ctx = ExecutionContext( + app=get_application("application_4")[0], + app_params=["--app=TestApp"], + system=get_system("System 4"), + system_params=[], + ) + assert ctx.system is not None + + class MockParser(OutputParser): + """Mock implementation of an output parser.""" + + def __init__(self, metrics: Dict[str, Any]) -> None: + """Set up the MockParser.""" + super().__init__(name="test") + self.metrics = metrics + + def __call__(self, output: bytearray) -> Dict[str, Any]: + """Return mock metrics (ignoring the given output).""" + return self.metrics + + metrics = {"Metric": 123, "AnotherMetric": 456} + reporter = Reporter( + parsers=[MockParser(metrics={key: val}) for key, val in metrics.items()], + ) + reporter.parse(bytearray()) + report = reporter.report(ctx) + assert report["system"]["name"] == ctx.system.name + assert report["system"]["params"] == {} + assert report["application"]["name"] == ctx.app.name + assert report["application"]["params"] == {"--app": "TestApp"} + assert report["test"]["metrics"] == metrics + report_file = Path(tmpdir) / "report.json" + reporter.save(report, report_file) + assert report_file.is_file() diff --git a/tests/mlia/test_backend_fs.py b/tests/mlia/test_backend_fs.py new file mode 100644 index 0000000..ff9c2ae --- /dev/null +++ b/tests/mlia/test_backend_fs.py @@ -0,0 +1,168 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +# pylint: disable=no-self-use +"""Module for testing fs.py.""" +from contextlib import ExitStack as does_not_raise +from pathlib import Path +from typing import Any +from typing import Union +from unittest.mock import MagicMock + +import pytest + +from mlia.backend.fs import get_backends_path +from mlia.backend.fs import read_file_as_bytearray +from mlia.backend.fs import read_file_as_string +from mlia.backend.fs import recreate_directory +from mlia.backend.fs import remove_directory +from mlia.backend.fs import remove_resource +from mlia.backend.fs import ResourceType +from mlia.backend.fs import valid_for_filename + + +@pytest.mark.parametrize( + "resource_name,expected_path", + [ + ("systems", does_not_raise()), + ("applications", does_not_raise()), + ("whaaat", pytest.raises(ResourceWarning)), + (None, pytest.raises(ResourceWarning)), + ], +) +def test_get_backends_path(resource_name: ResourceType, expected_path: Any) -> None: + """Test get_resources() with multiple parameters.""" + with expected_path: + resource_path = get_backends_path(resource_name) + assert resource_path.exists() + + +def test_remove_resource_wrong_directory( + monkeypatch: Any, test_applications_path: Path +) -> None: + """Test removing resource with wrong directory.""" + mock_get_resources = MagicMock(return_value=test_applications_path) + monkeypatch.setattr("mlia.backend.fs.get_backends_path", mock_get_resources) + + mock_shutil_rmtree = MagicMock() + monkeypatch.setattr("mlia.backend.fs.shutil.rmtree", mock_shutil_rmtree) + + with pytest.raises(Exception, match="Resource .* does not exist"): + remove_resource("unknown", "applications") + mock_shutil_rmtree.assert_not_called() + + with pytest.raises(Exception, match="Wrong resource .*"): + 
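+        # A path that exists but is not a resource directory (readme.txt) must
+        # be rejected as a wrong resource.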
remove_resource("readme.txt", "applications") + mock_shutil_rmtree.assert_not_called() + + +def test_remove_resource(monkeypatch: Any, test_applications_path: Path) -> None: + """Test removing resource data.""" + mock_get_resources = MagicMock(return_value=test_applications_path) + monkeypatch.setattr("mlia.backend.fs.get_backends_path", mock_get_resources) + + mock_shutil_rmtree = MagicMock() + monkeypatch.setattr("mlia.backend.fs.shutil.rmtree", mock_shutil_rmtree) + + remove_resource("application1", "applications") + mock_shutil_rmtree.assert_called_once() + + +def test_remove_directory(tmpdir: Any) -> None: + """Test directory removal.""" + tmpdir_path = Path(tmpdir) + tmpfile = tmpdir_path / "temp.txt" + + for item in [None, tmpfile]: + with pytest.raises(Exception, match="No directory path provided"): + remove_directory(item) + + newdir = tmpdir_path / "newdir" + newdir.mkdir() + + assert newdir.is_dir() + remove_directory(newdir) + assert not newdir.exists() + + +def test_recreate_directory(tmpdir: Any) -> None: + """Test directory recreation.""" + with pytest.raises(Exception, match="No directory path provided"): + recreate_directory(None) + + tmpdir_path = Path(tmpdir) + tmpfile = tmpdir_path / "temp.txt" + tmpfile.touch() + with pytest.raises(Exception, match="Path .* does exist and it is not a directory"): + recreate_directory(tmpfile) + + newdir = tmpdir_path / "newdir" + newdir.mkdir() + newfile = newdir / "newfile" + newfile.touch() + assert list(newdir.iterdir()) == [newfile] + recreate_directory(newdir) + assert not list(newdir.iterdir()) + + newdir2 = tmpdir_path / "newdir2" + assert not newdir2.exists() + recreate_directory(newdir2) + assert newdir2.is_dir() + + +def write_to_file( + write_directory: Any, write_mode: str, write_text: Union[str, bytes] +) -> Path: + """Write some text to a temporary test file.""" + tmpdir_path = Path(write_directory) + tmpfile = tmpdir_path / "file_name.txt" + with open(tmpfile, write_mode) as file: # pylint: disable=unspecified-encoding + file.write(write_text) + return tmpfile + + +class TestReadFileAsString: + """Test read_file_as_string() function.""" + + def test_returns_text_from_valid_file(self, tmpdir: Any) -> None: + """Ensure the string written to a file read correctly.""" + file_path = write_to_file(tmpdir, "w", "hello") + assert read_file_as_string(file_path) == "hello" + + def test_output_is_empty_string_when_input_file_non_existent( + self, tmpdir: Any + ) -> None: + """Ensure empty string returned when reading from non-existent file.""" + file_path = Path(tmpdir / "non-existent.txt") + assert read_file_as_string(file_path) == "" + + +class TestReadFileAsByteArray: + """Test read_file_as_bytearray() function.""" + + def test_returns_bytes_from_valid_file(self, tmpdir: Any) -> None: + """Ensure the bytes written to a file read correctly.""" + file_path = write_to_file(tmpdir, "wb", b"hello bytes") + assert read_file_as_bytearray(file_path) == b"hello bytes" + + def test_output_is_empty_bytearray_when_input_file_non_existent( + self, tmpdir: Any + ) -> None: + """Ensure empty bytearray returned when reading from non-existent file.""" + file_path = Path(tmpdir / "non-existent.txt") + assert read_file_as_bytearray(file_path) == bytearray() + + +@pytest.mark.parametrize( + "value, replacement, expected_result", + [ + ["", "", ""], + ["123", "", "123"], + ["123", "_", "123"], + ["/some_folder/some_script.sh", "", "some_foldersome_script.sh"], + ["/some_folder/some_script.sh", "_", "_some_folder_some_script.sh"], + 
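+        # Characters that are unsafe for file names are replaced with the given
+        # replacement string, or simply dropped when the replacement is empty.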
["!;'some_name$%^!", "_", "___some_name____"], + ], +) +def test_valid_for_filename(value: str, replacement: str, expected_result: str) -> None: + """Test function valid_for_filename.""" + assert valid_for_filename(value, replacement) == expected_result diff --git a/tests/mlia/test_backend_manager.py b/tests/mlia/test_backend_manager.py new file mode 100644 index 0000000..c81366f --- /dev/null +++ b/tests/mlia/test_backend_manager.py @@ -0,0 +1,788 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Tests for module backend/manager.""" +import os +from contextlib import ExitStack as does_not_raise +from pathlib import Path +from typing import Any +from typing import Dict +from typing import List +from typing import Optional +from typing import Tuple +from unittest.mock import MagicMock +from unittest.mock import PropertyMock + +import pytest + +from mlia.backend.application import get_application +from mlia.backend.common import DataPaths +from mlia.backend.execution import ExecutionContext +from mlia.backend.execution import run_application +from mlia.backend.manager import BackendRunner +from mlia.backend.manager import DeviceInfo +from mlia.backend.manager import estimate_performance +from mlia.backend.manager import ExecutionParams +from mlia.backend.manager import GenericInferenceOutputParser +from mlia.backend.manager import GenericInferenceRunnerEthosU +from mlia.backend.manager import get_generic_runner +from mlia.backend.manager import get_system_name +from mlia.backend.manager import is_supported +from mlia.backend.manager import ModelInfo +from mlia.backend.manager import PerformanceMetrics +from mlia.backend.manager import supported_backends +from mlia.backend.system import get_system + + +@pytest.mark.parametrize( + "data, is_ready, result, missed_keys", + [ + ( + [], + False, + {}, + [ + "npu_active_cycles", + "npu_axi0_rd_data_beat_received", + "npu_axi0_wr_data_beat_written", + "npu_axi1_rd_data_beat_received", + "npu_idle_cycles", + "npu_total_cycles", + ], + ), + ( + ["sample text"], + False, + {}, + [ + "npu_active_cycles", + "npu_axi0_rd_data_beat_received", + "npu_axi0_wr_data_beat_written", + "npu_axi1_rd_data_beat_received", + "npu_idle_cycles", + "npu_total_cycles", + ], + ), + ( + [ + ["NPU AXI0_RD_DATA_BEAT_RECEIVED beats: 123"], + False, + {"npu_axi0_rd_data_beat_received": 123}, + [ + "npu_active_cycles", + "npu_axi0_wr_data_beat_written", + "npu_axi1_rd_data_beat_received", + "npu_idle_cycles", + "npu_total_cycles", + ], + ] + ), + ( + [ + "NPU AXI0_RD_DATA_BEAT_RECEIVED beats: 1", + "NPU AXI0_WR_DATA_BEAT_WRITTEN beats: 2", + "NPU AXI1_RD_DATA_BEAT_RECEIVED beats: 3", + "NPU ACTIVE cycles: 4", + "NPU IDLE cycles: 5", + "NPU TOTAL cycles: 6", + ], + True, + { + "npu_axi0_rd_data_beat_received": 1, + "npu_axi0_wr_data_beat_written": 2, + "npu_axi1_rd_data_beat_received": 3, + "npu_active_cycles": 4, + "npu_idle_cycles": 5, + "npu_total_cycles": 6, + }, + [], + ), + ], +) +def test_generic_inference_output_parser( + data: List[str], is_ready: bool, result: Dict, missed_keys: List[str] +) -> None: + """Test generic runner output parser.""" + parser = GenericInferenceOutputParser() + + for line in data: + parser.feed(line) + + assert parser.is_ready() == is_ready + assert parser.result == result + assert parser.missed_keys() == missed_keys + + +class TestBackendRunner: + """Tests for BackendRunner class.""" + + @staticmethod + def _setup_backends( + monkeypatch: pytest.MonkeyPatch, + 
available_systems: Optional[List[str]] = None, + available_apps: Optional[List[str]] = None, + ) -> None: + """Set up backend metadata.""" + + def mock_system(system: str) -> MagicMock: + """Mock the System instance.""" + mock = MagicMock() + type(mock).name = PropertyMock(return_value=system) + return mock + + def mock_app(app: str) -> MagicMock: + """Mock the Application instance.""" + mock = MagicMock() + type(mock).name = PropertyMock(return_value=app) + mock.can_run_on.return_value = True + return mock + + system_mocks = [mock_system(name) for name in (available_systems or [])] + monkeypatch.setattr( + "mlia.backend.manager.get_available_systems", + MagicMock(return_value=system_mocks), + ) + + apps_mock = [mock_app(name) for name in (available_apps or [])] + monkeypatch.setattr( + "mlia.backend.manager.get_available_applications", + MagicMock(return_value=apps_mock), + ) + + @pytest.mark.parametrize( + "available_systems, system, installed", + [ + ([], "system1", False), + (["system1", "system2"], "system1", True), + ], + ) + def test_is_system_installed( + self, + available_systems: List, + system: str, + installed: bool, + monkeypatch: pytest.MonkeyPatch, + ) -> None: + """Test method is_system_installed.""" + backend_runner = BackendRunner() + + self._setup_backends(monkeypatch, available_systems) + + assert backend_runner.is_system_installed(system) == installed + + @pytest.mark.parametrize( + "available_systems, systems", + [ + ([], []), + (["system1"], ["system1"]), + ], + ) + def test_installed_systems( + self, + available_systems: List[str], + systems: List[str], + monkeypatch: pytest.MonkeyPatch, + ) -> None: + """Test method installed_systems.""" + backend_runner = BackendRunner() + + self._setup_backends(monkeypatch, available_systems) + assert backend_runner.get_installed_systems() == systems + + @staticmethod + def test_install_system(monkeypatch: pytest.MonkeyPatch) -> None: + """Test system installation.""" + install_system_mock = MagicMock() + monkeypatch.setattr("mlia.backend.manager.install_system", install_system_mock) + + backend_runner = BackendRunner() + backend_runner.install_system(Path("test_system_path")) + + install_system_mock.assert_called_once_with(Path("test_system_path")) + + @pytest.mark.parametrize( + "available_systems, systems, expected_result", + [ + ([], [], False), + (["system1"], [], False), + (["system1"], ["system1"], True), + (["system1", "system2"], ["system1", "system3"], False), + (["system1", "system2"], ["system1", "system2"], True), + ], + ) + def test_systems_installed( + self, + available_systems: List[str], + systems: List[str], + expected_result: bool, + monkeypatch: pytest.MonkeyPatch, + ) -> None: + """Test method systems_installed.""" + self._setup_backends(monkeypatch, available_systems) + + backend_runner = BackendRunner() + + assert backend_runner.systems_installed(systems) is expected_result + + @pytest.mark.parametrize( + "available_apps, applications, expected_result", + [ + ([], [], False), + (["app1"], [], False), + (["app1"], ["app1"], True), + (["app1", "app2"], ["app1", "app3"], False), + (["app1", "app2"], ["app1", "app2"], True), + ], + ) + def test_applications_installed( + self, + available_apps: List[str], + applications: List[str], + expected_result: bool, + monkeypatch: pytest.MonkeyPatch, + ) -> None: + """Test method applications_installed.""" + self._setup_backends(monkeypatch, [], available_apps) + backend_runner = BackendRunner() + + assert backend_runner.applications_installed(applications) is 
expected_result + + @pytest.mark.parametrize( + "available_apps, applications", + [ + ([], []), + ( + ["application1", "application2"], + ["application1", "application2"], + ), + ], + ) + def test_get_installed_applications( + self, + available_apps: List[str], + applications: List[str], + monkeypatch: pytest.MonkeyPatch, + ) -> None: + """Test method get_installed_applications.""" + self._setup_backends(monkeypatch, [], available_apps) + + backend_runner = BackendRunner() + assert applications == backend_runner.get_installed_applications() + + @staticmethod + def test_install_application(monkeypatch: pytest.MonkeyPatch) -> None: + """Test application installation.""" + mock_install_application = MagicMock() + monkeypatch.setattr( + "mlia.backend.manager.install_application", mock_install_application + ) + + backend_runner = BackendRunner() + backend_runner.install_application(Path("test_application_path")) + mock_install_application.assert_called_once_with(Path("test_application_path")) + + @pytest.mark.parametrize( + "available_apps, application, installed", + [ + ([], "system1", False), + ( + ["application1", "application2"], + "application1", + True, + ), + ( + [], + "application1", + False, + ), + ], + ) + def test_is_application_installed( + self, + available_apps: List[str], + application: str, + installed: bool, + monkeypatch: pytest.MonkeyPatch, + ) -> None: + """Test method is_application_installed.""" + self._setup_backends(monkeypatch, [], available_apps) + + backend_runner = BackendRunner() + assert installed == backend_runner.is_application_installed( + application, "system1" + ) + + @staticmethod + @pytest.mark.parametrize( + "execution_params, expected_command", + [ + ( + ExecutionParams("application_4", "System 4", [], [], []), + ["application_4", [], "System 4", [], []], + ), + ( + ExecutionParams( + "application_6", + "System 6", + ["param1=value2"], + ["sys-param1=value2"], + [], + ), + [ + "application_6", + ["param1=value2"], + "System 6", + ["sys-param1=value2"], + [], + ], + ), + ], + ) + def test_run_application_local( + monkeypatch: pytest.MonkeyPatch, + execution_params: ExecutionParams, + expected_command: List[str], + ) -> None: + """Test method run_application with local systems.""" + run_app = MagicMock(wraps=run_application) + monkeypatch.setattr("mlia.backend.manager.run_application", run_app) + + backend_runner = BackendRunner() + backend_runner.run_application(execution_params) + + run_app.assert_called_once_with(*expected_command) + + @staticmethod + @pytest.mark.parametrize( + "execution_params, expected_command", + [ + ( + ExecutionParams( + "application_1", + "System 1", + [], + [], + ["source1.txt:dest1.txt", "source2.txt:dest2.txt"], + ), + [ + "application_1", + [], + "System 1", + [], + [ + DataPaths(Path("source1.txt"), "dest1.txt"), + DataPaths(Path("source2.txt"), "dest2.txt"), + ], + ], + ), + ], + ) + def test_run_application_connected( + monkeypatch: pytest.MonkeyPatch, + execution_params: ExecutionParams, + expected_command: List[str], + ) -> None: + """Test method run_application with connectable systems (SSH).""" + app = get_application(execution_params.application, execution_params.system)[0] + sys = get_system(execution_params.system) + + assert sys is not None + + connect_mock = MagicMock(return_value=True, name="connect_mock") + deploy_mock = MagicMock(return_value=True, name="deploy_mock") + run_mock = MagicMock( + return_value=(os.EX_OK, bytearray(), bytearray()), name="run_mock" + ) + sys.establish_connection = connect_mock # type: 
ignore + sys.deploy = deploy_mock # type: ignore + sys.run = run_mock # type: ignore + + monkeypatch.setattr( + "mlia.backend.execution.get_application_and_system", + MagicMock(return_value=(app, sys)), + ) + + run_app_mock = MagicMock(wraps=run_application) + monkeypatch.setattr("mlia.backend.manager.run_application", run_app_mock) + + backend_runner = BackendRunner() + backend_runner.run_application(execution_params) + + run_app_mock.assert_called_once_with(*expected_command) + + connect_mock.assert_called_once() + assert deploy_mock.call_count == 2 + + +@pytest.mark.parametrize( + "device, system, application, backend, expected_error", + [ + ( + DeviceInfo(device_type="ethos-u55", mac=32, memory_mode="Shared_Sram"), + ("Corstone-300: Cortex-M55+Ethos-U55", True), + ("Generic Inference Runner: Ethos-U55 SRAM", True), + "Corstone-300", + does_not_raise(), + ), + ( + DeviceInfo(device_type="ethos-u55", mac=32, memory_mode="Shared_Sram"), + ("Corstone-300: Cortex-M55+Ethos-U55", False), + ("Generic Inference Runner: Ethos-U55 SRAM", False), + "Corstone-300", + pytest.raises( + Exception, + match=r"System Corstone-300: Cortex-M55\+Ethos-U55 is not installed", + ), + ), + ( + DeviceInfo(device_type="ethos-u55", mac=32, memory_mode="Shared_Sram"), + ("Corstone-300: Cortex-M55+Ethos-U55", True), + ("Generic Inference Runner: Ethos-U55 SRAM", False), + "Corstone-300", + pytest.raises( + Exception, + match=r"Application Generic Inference Runner: Ethos-U55/65 Shared SRAM " + r"for the system Corstone-300: Cortex-M55\+Ethos-U55 is not installed", + ), + ), + ( + DeviceInfo(device_type="ethos-u55", mac=32, memory_mode="Shared_Sram"), + ("Corstone-310: Cortex-M85+Ethos-U55", True), + ("Generic Inference Runner: Ethos-U55 SRAM", True), + "Corstone-310", + does_not_raise(), + ), + ( + DeviceInfo(device_type="ethos-u55", mac=32, memory_mode="Shared_Sram"), + ("Corstone-310: Cortex-M85+Ethos-U55", False), + ("Generic Inference Runner: Ethos-U55 SRAM", False), + "Corstone-310", + pytest.raises( + Exception, + match=r"System Corstone-310: Cortex-M85\+Ethos-U55 is not installed", + ), + ), + ( + DeviceInfo(device_type="ethos-u55", mac=32, memory_mode="Shared_Sram"), + ("Corstone-310: Cortex-M85+Ethos-U55", True), + ("Generic Inference Runner: Ethos-U55 SRAM", False), + "Corstone-310", + pytest.raises( + Exception, + match=r"Application Generic Inference Runner: Ethos-U55/65 Shared SRAM " + r"for the system Corstone-310: Cortex-M85\+Ethos-U55 is not installed", + ), + ), + ( + DeviceInfo(device_type="ethos-u65", mac=512, memory_mode="Shared_Sram"), + ("Corstone-300: Cortex-M55+Ethos-U65", True), + ("Generic Inference Runner: Ethos-U65 Dedicated SRAM", True), + "Corstone-300", + does_not_raise(), + ), + ( + DeviceInfo(device_type="ethos-u65", mac=512, memory_mode="Shared_Sram"), + ("Corstone-300: Cortex-M55+Ethos-U65", False), + ("Generic Inference Runner: Ethos-U65 Dedicated SRAM", False), + "Corstone-300", + pytest.raises( + Exception, + match=r"System Corstone-300: Cortex-M55\+Ethos-U65 is not installed", + ), + ), + ( + DeviceInfo(device_type="ethos-u65", mac=512, memory_mode="Shared_Sram"), + ("Corstone-300: Cortex-M55+Ethos-U65", True), + ("Generic Inference Runner: Ethos-U65 Dedicated SRAM", False), + "Corstone-300", + pytest.raises( + Exception, + match=r"Application Generic Inference Runner: Ethos-U55/65 Shared SRAM " + r"for the system Corstone-300: Cortex-M55\+Ethos-U65 is not installed", + ), + ), + ( + DeviceInfo( + device_type="unknown_device", # type: ignore + mac=None, # type: ignore + 
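+                # An unrecognised device type must be rejected regardless of
+                # which backends are installed.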
memory_mode="Shared_Sram", + ), + ("some_system", False), + ("some_application", False), + "some backend", + pytest.raises(Exception, match="Unsupported device unknown_device"), + ), + ], +) +def test_estimate_performance( + device: DeviceInfo, + system: Tuple[str, bool], + application: Tuple[str, bool], + backend: str, + expected_error: Any, + test_tflite_model: Path, + backend_runner: MagicMock, +) -> None: + """Test getting performance estimations.""" + system_name, system_installed = system + application_name, application_installed = application + + backend_runner.is_system_installed.return_value = system_installed + backend_runner.is_application_installed.return_value = application_installed + + mock_context = create_mock_context( + [ + "NPU AXI0_RD_DATA_BEAT_RECEIVED beats: 1", + "NPU AXI0_WR_DATA_BEAT_WRITTEN beats: 2", + "NPU AXI1_RD_DATA_BEAT_RECEIVED beats: 3", + "NPU ACTIVE cycles: 4", + "NPU IDLE cycles: 5", + "NPU TOTAL cycles: 6", + ] + ) + + backend_runner.run_application.return_value = mock_context + + with expected_error: + perf_metrics = estimate_performance( + ModelInfo(test_tflite_model), device, backend + ) + + assert isinstance(perf_metrics, PerformanceMetrics) + assert perf_metrics == PerformanceMetrics( + npu_axi0_rd_data_beat_received=1, + npu_axi0_wr_data_beat_written=2, + npu_axi1_rd_data_beat_received=3, + npu_active_cycles=4, + npu_idle_cycles=5, + npu_total_cycles=6, + ) + + assert backend_runner.is_system_installed.called_once_with(system_name) + assert backend_runner.is_application_installed.called_once_with( + application_name, system_name + ) + + +@pytest.mark.parametrize("backend", ("Corstone-300", "Corstone-310")) +def test_estimate_performance_insufficient_data( + backend_runner: MagicMock, test_tflite_model: Path, backend: str +) -> None: + """Test that performance could not be estimated when not all data presented.""" + backend_runner.is_system_installed.return_value = True + backend_runner.is_application_installed.return_value = True + + no_total_cycles_output = [ + "NPU AXI0_RD_DATA_BEAT_RECEIVED beats: 1", + "NPU AXI0_WR_DATA_BEAT_WRITTEN beats: 2", + "NPU AXI1_RD_DATA_BEAT_RECEIVED beats: 3", + "NPU ACTIVE cycles: 4", + "NPU IDLE cycles: 5", + ] + mock_context = create_mock_context(no_total_cycles_output) + + backend_runner.run_application.return_value = mock_context + + with pytest.raises( + Exception, match="Unable to get performance metrics, insufficient data" + ): + device = DeviceInfo(device_type="ethos-u55", mac=32, memory_mode="Shared_Sram") + estimate_performance(ModelInfo(test_tflite_model), device, backend) + + +@pytest.mark.parametrize("backend", ("Corstone-300", "Corstone-310")) +def test_estimate_performance_invalid_output( + test_tflite_model: Path, backend_runner: MagicMock, backend: str +) -> None: + """Test estimation could not be done if inference produces unexpected output.""" + backend_runner.is_system_installed.return_value = True + backend_runner.is_application_installed.return_value = True + + mock_context = create_mock_context(["Something", "is", "wrong"]) + backend_runner.run_application.return_value = mock_context + + with pytest.raises(Exception, match="Unable to get performance metrics"): + estimate_performance( + ModelInfo(test_tflite_model), + DeviceInfo(device_type="ethos-u55", mac=256, memory_mode="Shared_Sram"), + backend=backend, + ) + + +def create_mock_process(stdout: List[str], stderr: List[str]) -> MagicMock: + """Mock underlying process.""" + mock_process = MagicMock() + mock_process.poll.return_value = 0 + 
type(mock_process).stdout = PropertyMock(return_value=iter(stdout)) + type(mock_process).stderr = PropertyMock(return_value=iter(stderr)) + return mock_process + + +def create_mock_context(stdout: List[str]) -> ExecutionContext: + """Mock ExecutionContext.""" + ctx = ExecutionContext( + app=get_application("application_1")[0], + app_params=[], + system=get_system("System 1"), + system_params=[], + ) + ctx.stdout = bytearray("\n".join(stdout).encode("utf-8")) + return ctx + + +@pytest.mark.parametrize("backend", ("Corstone-300", "Corstone-310")) +def test_get_generic_runner(backend: str) -> None: + """Test function get_generic_runner().""" + device_info = DeviceInfo("ethos-u55", 256, memory_mode="Shared_Sram") + + runner = get_generic_runner(device_info=device_info, backend=backend) + assert isinstance(runner, GenericInferenceRunnerEthosU) + + with pytest.raises(RuntimeError): + get_generic_runner(device_info=device_info, backend="UNKNOWN_BACKEND") + + +@pytest.mark.parametrize( + ("backend", "device_type"), + ( + ("Corstone-300", "ethos-u55"), + ("Corstone-300", "ethos-u65"), + ("Corstone-310", "ethos-u55"), + ), +) +def test_backend_support(backend: str, device_type: str) -> None: + """Test backend & device support.""" + assert is_supported(backend) + assert is_supported(backend, device_type) + + assert get_system_name(backend, device_type) + + assert backend in supported_backends() + + +class TestGenericInferenceRunnerEthosU: + """Test for the class GenericInferenceRunnerEthosU.""" + + @staticmethod + @pytest.mark.parametrize( + "device, backend, expected_system, expected_app", + [ + [ + DeviceInfo("ethos-u55", 256, memory_mode="Shared_Sram"), + "Corstone-300", + "Corstone-300: Cortex-M55+Ethos-U55", + "Generic Inference Runner: Ethos-U55/65 Shared SRAM", + ], + [ + DeviceInfo("ethos-u55", 256, memory_mode="Shared_Sram"), + "Corstone-310", + "Corstone-310: Cortex-M85+Ethos-U55", + "Generic Inference Runner: Ethos-U55/65 Shared SRAM", + ], + [ + DeviceInfo("ethos-u55", 256, memory_mode="Sram"), + "Corstone-310", + "Corstone-310: Cortex-M85+Ethos-U55", + "Generic Inference Runner: Ethos-U55 SRAM", + ], + [ + DeviceInfo("ethos-u55", 256, memory_mode="Sram"), + "Corstone-300", + "Corstone-300: Cortex-M55+Ethos-U55", + "Generic Inference Runner: Ethos-U55 SRAM", + ], + [ + DeviceInfo("ethos-u65", 256, memory_mode="Shared_Sram"), + "Corstone-300", + "Corstone-300: Cortex-M55+Ethos-U65", + "Generic Inference Runner: Ethos-U55/65 Shared SRAM", + ], + [ + DeviceInfo("ethos-u65", 256, memory_mode="Dedicated_Sram"), + "Corstone-300", + "Corstone-300: Cortex-M55+Ethos-U65", + "Generic Inference Runner: Ethos-U65 Dedicated SRAM", + ], + ], + ) + def test_artifact_resolver( + device: DeviceInfo, backend: str, expected_system: str, expected_app: str + ) -> None: + """Test artifact resolving based on the provided parameters.""" + generic_runner = get_generic_runner(device, backend) + assert isinstance(generic_runner, GenericInferenceRunnerEthosU) + + assert generic_runner.system_name == expected_system + assert generic_runner.app_name == expected_app + + @staticmethod + def test_artifact_resolver_unsupported_backend() -> None: + """Test that it should be not possible to use unsupported backends.""" + with pytest.raises( + RuntimeError, match="Unsupported device ethos-u65 for backend test_backend" + ): + get_generic_runner( + DeviceInfo("ethos-u65", 256, memory_mode="Shared_Sram"), "test_backend" + ) + + @staticmethod + def test_artifact_resolver_unsupported_memory_mode() -> None: + """Test that it 
should be not possible to use unsupported memory modes.""" + with pytest.raises( + RuntimeError, match="Unsupported memory mode test_memory_mode" + ): + get_generic_runner( + DeviceInfo( + "ethos-u65", + 256, + memory_mode="test_memory_mode", # type: ignore + ), + "Corstone-300", + ) + + @staticmethod + @pytest.mark.parametrize("backend", ("Corstone-300", "Corstone-310")) + def test_inference_should_fail_if_system_not_installed( + backend_runner: MagicMock, test_tflite_model: Path, backend: str + ) -> None: + """Test that inference should fail if system is not installed.""" + backend_runner.is_system_installed.return_value = False + + generic_runner = get_generic_runner( + DeviceInfo("ethos-u55", 256, memory_mode="Shared_Sram"), backend + ) + with pytest.raises( + Exception, + match=r"System Corstone-3[01]0: Cortex-M[58]5\+Ethos-U55 is not installed", + ): + generic_runner.run(ModelInfo(test_tflite_model), []) + + @staticmethod + @pytest.mark.parametrize("backend", ("Corstone-300", "Corstone-310")) + def test_inference_should_fail_is_apps_not_installed( + backend_runner: MagicMock, test_tflite_model: Path, backend: str + ) -> None: + """Test that inference should fail if apps are not installed.""" + backend_runner.is_system_installed.return_value = True + backend_runner.is_application_installed.return_value = False + + generic_runner = get_generic_runner( + DeviceInfo("ethos-u55", 256, memory_mode="Shared_Sram"), backend + ) + with pytest.raises( + Exception, + match="Application Generic Inference Runner: Ethos-U55/65 Shared SRAM" + r" for the system Corstone-3[01]0: Cortex-M[58]5\+Ethos-U55 is not " + r"installed", + ): + generic_runner.run(ModelInfo(test_tflite_model), []) + + +@pytest.fixture(name="backend_runner") +def fixture_backend_runner(monkeypatch: pytest.MonkeyPatch) -> MagicMock: + """Mock backend runner.""" + backend_runner_mock = MagicMock(spec=BackendRunner) + monkeypatch.setattr( + "mlia.backend.manager.get_backend_runner", + MagicMock(return_value=backend_runner_mock), + ) + return backend_runner_mock diff --git a/tests/mlia/test_backend_output_parser.py b/tests/mlia/test_backend_output_parser.py new file mode 100644 index 0000000..d86aac8 --- /dev/null +++ b/tests/mlia/test_backend_output_parser.py @@ -0,0 +1,152 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Tests for the output parsing.""" +import base64 +import json +from typing import Any +from typing import Dict + +import pytest + +from mlia.backend.output_parser import Base64OutputParser +from mlia.backend.output_parser import OutputParser +from mlia.backend.output_parser import RegexOutputParser + + +OUTPUT_MATCH_ALL = bytearray( + """ +String1: My awesome string! +String2: STRINGS_ARE_GREAT!!! +Int: 12 +Float: 3.14 +""", + encoding="utf-8", +) + +OUTPUT_NO_MATCH = bytearray( + """ +This contains no matches... 
+Test1234567890!"£$%^&*()_+@~{}[]/.,<>?| +""", + encoding="utf-8", +) + +OUTPUT_PARTIAL_MATCH = bytearray( + "String1: My awesome string!", + encoding="utf-8", +) + +REGEX_CONFIG = { + "FirstString": {"pattern": r"String1.*: (.*)", "type": "str"}, + "SecondString": {"pattern": r"String2.*: (.*)!!!", "type": "str"}, + "IntegerValue": {"pattern": r"Int.*: (.*)", "type": "int"}, + "FloatValue": {"pattern": r"Float.*: (.*)", "type": "float"}, +} + +EMPTY_REGEX_CONFIG: Dict[str, Dict[str, Any]] = {} + +EXPECTED_METRICS_ALL = { + "FirstString": "My awesome string!", + "SecondString": "STRINGS_ARE_GREAT", + "IntegerValue": 12, + "FloatValue": 3.14, +} + +EXPECTED_METRICS_PARTIAL = { + "FirstString": "My awesome string!", +} + + +class TestRegexOutputParser: + """Collect tests for the RegexOutputParser.""" + + @staticmethod + @pytest.mark.parametrize( + ["output", "config", "expected_metrics"], + [ + (OUTPUT_MATCH_ALL, REGEX_CONFIG, EXPECTED_METRICS_ALL), + (OUTPUT_MATCH_ALL + OUTPUT_NO_MATCH, REGEX_CONFIG, EXPECTED_METRICS_ALL), + (OUTPUT_MATCH_ALL + OUTPUT_NO_MATCH, REGEX_CONFIG, EXPECTED_METRICS_ALL), + ( + OUTPUT_MATCH_ALL + OUTPUT_PARTIAL_MATCH, + REGEX_CONFIG, + EXPECTED_METRICS_ALL, + ), + (OUTPUT_NO_MATCH, REGEX_CONFIG, {}), + (OUTPUT_MATCH_ALL, EMPTY_REGEX_CONFIG, {}), + (bytearray(), EMPTY_REGEX_CONFIG, {}), + (bytearray(), REGEX_CONFIG, {}), + ], + ) + def test_parsing(output: bytearray, config: Dict, expected_metrics: Dict) -> None: + """ + Make sure the RegexOutputParser yields valid results. + + I.e. return an empty dict if either the input or the config is empty and + return the parsed metrics otherwise. + """ + parser = RegexOutputParser(name="Test", regex_config=config) + assert parser.name == "Test" + assert isinstance(parser, OutputParser) + res = parser(output) + assert res == expected_metrics + + @staticmethod + def test_unsupported_type() -> None: + """An unsupported type in the regex_config must raise an exception.""" + config = {"BrokenMetric": {"pattern": "(.*)", "type": "UNSUPPORTED_TYPE"}} + with pytest.raises(TypeError): + RegexOutputParser(name="Test", regex_config=config) + + @staticmethod + @pytest.mark.parametrize( + "config", + ( + {"TooManyGroups": {"pattern": r"(\w)(\d)", "type": "str"}}, + {"NoGroups": {"pattern": r"\W", "type": "str"}}, + ), + ) + def test_invalid_pattern(config: Dict) -> None: + """Exactly one capturing parenthesis is allowed in the regex pattern.""" + with pytest.raises(ValueError): + RegexOutputParser(name="Test", regex_config=config) + + +@pytest.mark.parametrize( + "expected_metrics", + [ + EXPECTED_METRICS_ALL, + EXPECTED_METRICS_PARTIAL, + ], +) +def test_base64_output_parser(expected_metrics: Dict) -> None: + """ + Make sure the Base64OutputParser yields valid results. + + I.e. return an empty dict if either the input or the config is empty and + return the parsed metrics otherwise. + """ + parser = Base64OutputParser(name="Test") + assert parser.name == "Test" + assert isinstance(parser, OutputParser) + + def create_base64_output(expected_metrics: Dict) -> bytearray: + json_str = json.dumps(expected_metrics, indent=4) + json_b64 = base64.b64encode(json_str.encode("utf-8")) + return ( + OUTPUT_MATCH_ALL # Should not be matched by the Base64OutputParser + + f"<{Base64OutputParser.TAG_NAME}>".encode("utf-8") + + bytearray(json_b64) + + f"".encode("utf-8") + + OUTPUT_NO_MATCH # Just to add some difficulty... 
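+            # Only the base64-encoded JSON between the parser's tag markers should
+            # be decoded; filter_out_parsed_content() later strips exactly that part.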
+ ) + + output = create_base64_output(expected_metrics) + res = parser(output) + assert len(res) == 1 + assert isinstance(res, dict) + for val in res.values(): + assert val == expected_metrics + + output = parser.filter_out_parsed_content(output) + assert output == (OUTPUT_MATCH_ALL + OUTPUT_NO_MATCH) diff --git a/tests/mlia/test_backend_proc.py b/tests/mlia/test_backend_proc.py new file mode 100644 index 0000000..9ca4788 --- /dev/null +++ b/tests/mlia/test_backend_proc.py @@ -0,0 +1,272 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +# pylint: disable=attribute-defined-outside-init,no-self-use,not-callable +"""Pytests for testing mlia/backend/proc.py.""" +from pathlib import Path +from typing import Any +from unittest import mock + +import psutil +import pytest +from sh import ErrorReturnCode + +from mlia.backend.proc import Command +from mlia.backend.proc import CommandFailedException +from mlia.backend.proc import CommandNotFound +from mlia.backend.proc import parse_command +from mlia.backend.proc import print_command_stdout +from mlia.backend.proc import run_and_wait +from mlia.backend.proc import save_process_info +from mlia.backend.proc import ShellCommand +from mlia.backend.proc import terminate_command +from mlia.backend.proc import terminate_external_process + + +class TestShellCommand: + """Sample class for collecting tests.""" + + def test_shellcommand_default_value(self) -> None: + """Test the instantiation of the class ShellCommand with no parameter.""" + shell_command = ShellCommand() + assert shell_command.base_log_path == "/tmp" + + @pytest.mark.parametrize( + "base_log_path,expected", [("/test", "/test"), ("/asd", "/asd")] + ) + def test_shellcommand_with_param(self, base_log_path: str, expected: str) -> None: + """Test init ShellCommand with different parameters.""" + shell_command = ShellCommand(base_log_path) + assert shell_command.base_log_path == expected + + def test_run_ls(self, monkeypatch: Any) -> None: + """Test a simple ls command.""" + mock_command = mock.MagicMock() + monkeypatch.setattr(Command, "bake", mock_command) + + mock_get_stdout_stderr_paths = mock.MagicMock() + mock_get_stdout_stderr_paths.return_value = ("/tmp/std.out", "/tmp/std.err") + monkeypatch.setattr( + ShellCommand, "get_stdout_stderr_paths", mock_get_stdout_stderr_paths + ) + + shell_command = ShellCommand() + shell_command.run("ls", "-l") + assert mock_command.mock_calls[0] == mock.call(("-l",)) + assert mock_command.mock_calls[1] == mock.call()( + _bg=True, _err="/tmp/std.err", _out="/tmp/std.out", _tee=True, _bg_exc=False + ) + + def test_run_command_not_found(self) -> None: + """Test whe the command doesn't exist.""" + shell_command = ShellCommand() + with pytest.raises(CommandNotFound): + shell_command.run("lsl", "-l") + + def test_get_stdout_stderr_paths_valid_path(self) -> None: + """Test the method to get files to store stdout and stderr.""" + valid_path = "/tmp" + shell_command = ShellCommand(valid_path) + out, err = shell_command.get_stdout_stderr_paths(valid_path, "cmd") + assert out.exists() and out.is_file() + assert err.exists() and err.is_file() + assert "cmd" in out.name + assert "cmd" in err.name + + def test_get_stdout_stderr_paths_not_invalid_path(self) -> None: + """Test the method to get output files with an invalid path.""" + invalid_path = "/invalid/foo/bar" + shell_command = ShellCommand(invalid_path) + with pytest.raises(FileNotFoundError): + shell_command.get_stdout_stderr_paths(invalid_path, 
"cmd") + + +@mock.patch("builtins.print") +def test_print_command_stdout_alive(mock_print: Any) -> None: + """Test the print command stdout with an alive (running) process.""" + mock_command = mock.MagicMock() + mock_command.is_alive.return_value = True + mock_command.next.side_effect = ["test1", "test2", StopIteration] + + print_command_stdout(mock_command) + + mock_command.assert_has_calls( + [mock.call.is_alive(), mock.call.next(), mock.call.next()] + ) + mock_print.assert_has_calls( + [mock.call("test1", end=""), mock.call("test2", end="")] + ) + + +@mock.patch("builtins.print") +def test_print_command_stdout_not_alive(mock_print: Any) -> None: + """Test the print command stdout with a not alive (exited) process.""" + mock_command = mock.MagicMock() + mock_command.is_alive.return_value = False + mock_command.stdout = "test" + + print_command_stdout(mock_command) + mock_command.assert_has_calls([mock.call.is_alive()]) + mock_print.assert_called_once_with("test") + + +def test_terminate_external_process_no_process(capsys: Any) -> None: + """Test that non existed process could be terminated.""" + mock_command = mock.MagicMock() + mock_command.terminate.side_effect = psutil.Error("Error!") + + terminate_external_process(mock_command) + captured = capsys.readouterr() + assert captured.out == "Unable to terminate process\n" + + +def test_terminate_external_process_case1() -> None: + """Test when process terminated immediately.""" + mock_command = mock.MagicMock() + mock_command.is_running.return_value = False + + terminate_external_process(mock_command) + mock_command.terminate.assert_called_once() + mock_command.is_running.assert_called_once() + + +def test_terminate_external_process_case2() -> None: + """Test when process termination takes time.""" + mock_command = mock.MagicMock() + mock_command.is_running.side_effect = [True, True, False] + + terminate_external_process(mock_command) + mock_command.terminate.assert_called_once() + assert mock_command.is_running.call_count == 3 + + +def test_terminate_external_process_case3() -> None: + """Test when process termination takes more time.""" + mock_command = mock.MagicMock() + mock_command.is_running.side_effect = [True, True, True] + + terminate_external_process( + mock_command, number_of_attempts=2, wait_period=0.1, wait_for_termination=0.1 + ) + assert mock_command.is_running.call_count == 3 + assert mock_command.terminate.call_count == 2 + + +def test_terminate_external_process_case4() -> None: + """Test when process termination takes more time.""" + mock_command = mock.MagicMock() + mock_command.is_running.side_effect = [True, True, False] + + terminate_external_process( + mock_command, number_of_attempts=2, wait_period=0.1, wait_for_termination=0.1 + ) + mock_command.terminate.assert_called_once() + assert mock_command.is_running.call_count == 3 + assert mock_command.terminate.call_count == 1 + + +def test_terminate_command_no_process() -> None: + """Test command termination when process does not exist.""" + mock_command = mock.MagicMock() + mock_command.process.signal_group.side_effect = ProcessLookupError() + + terminate_command(mock_command) + mock_command.process.signal_group.assert_called_once() + mock_command.is_alive.assert_not_called() + + +def test_terminate_command() -> None: + """Test command termination.""" + mock_command = mock.MagicMock() + mock_command.is_alive.return_value = False + + terminate_command(mock_command) + mock_command.process.signal_group.assert_called_once() + + +def test_terminate_command_case1() -> None: 
+ """Test command termination when it takes time..""" + mock_command = mock.MagicMock() + mock_command.is_alive.side_effect = [True, True, False] + + terminate_command(mock_command, wait_period=0.1) + mock_command.process.signal_group.assert_called_once() + assert mock_command.is_alive.call_count == 3 + + +def test_terminate_command_case2() -> None: + """Test command termination when it takes much time..""" + mock_command = mock.MagicMock() + mock_command.is_alive.side_effect = [True, True, True] + + terminate_command(mock_command, number_of_attempts=3, wait_period=0.1) + assert mock_command.is_alive.call_count == 3 + assert mock_command.process.signal_group.call_count == 2 + + +class TestRunAndWait: + """Test run_and_wait function.""" + + @pytest.fixture(autouse=True) + def setup_method(self, monkeypatch: Any) -> None: + """Init test method.""" + self.execute_command_mock = mock.MagicMock() + monkeypatch.setattr( + "mlia.backend.proc.execute_command", self.execute_command_mock + ) + + self.terminate_command_mock = mock.MagicMock() + monkeypatch.setattr( + "mlia.backend.proc.terminate_command", self.terminate_command_mock + ) + + def test_if_execute_command_raises_exception(self) -> None: + """Test if execute_command fails.""" + self.execute_command_mock.side_effect = Exception("Error!") + with pytest.raises(Exception, match="Error!"): + run_and_wait("command", Path.cwd()) + + def test_if_command_finishes_with_error(self) -> None: + """Test if command finishes with error.""" + cmd_mock = mock.MagicMock() + self.execute_command_mock.return_value = cmd_mock + exit_code_mock = mock.PropertyMock( + side_effect=ErrorReturnCode("cmd", bytearray(), bytearray()) + ) + type(cmd_mock).exit_code = exit_code_mock + + with pytest.raises(CommandFailedException): + run_and_wait("command", Path.cwd()) + + @pytest.mark.parametrize("terminate_on_error, call_count", ((False, 0), (True, 1))) + def test_if_command_finishes_with_exception( + self, terminate_on_error: bool, call_count: int + ) -> None: + """Test if command finishes with error.""" + cmd_mock = mock.MagicMock() + self.execute_command_mock.return_value = cmd_mock + exit_code_mock = mock.PropertyMock(side_effect=Exception("Error!")) + type(cmd_mock).exit_code = exit_code_mock + + with pytest.raises(Exception, match="Error!"): + run_and_wait("command", Path.cwd(), terminate_on_error=terminate_on_error) + + assert self.terminate_command_mock.call_count == call_count + + +def test_save_process_info_no_process(monkeypatch: Any, tmpdir: Any) -> None: + """Test save_process_info function.""" + mock_process = mock.MagicMock() + monkeypatch.setattr("psutil.Process", mock.MagicMock(return_value=mock_process)) + mock_process.children.side_effect = psutil.NoSuchProcess(555) + + pid_file_path = Path(tmpdir) / "test.pid" + save_process_info(555, pid_file_path) + assert not pid_file_path.exists() + + +def test_parse_command() -> None: + """Test parse_command function.""" + assert parse_command("1.sh") == ["bash", "1.sh"] + assert parse_command("1.sh", shell="sh") == ["sh", "1.sh"] + assert parse_command("command") == ["command"] + assert parse_command("command 123 --param=1") == ["command", "123", "--param=1"] diff --git a/tests/mlia/test_backend_protocol.py b/tests/mlia/test_backend_protocol.py new file mode 100644 index 0000000..35e9986 --- /dev/null +++ b/tests/mlia/test_backend_protocol.py @@ -0,0 +1,231 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. 
+# SPDX-License-Identifier: Apache-2.0 +# pylint: disable=no-self-use,attribute-defined-outside-init,protected-access +"""Tests for the protocol backend module.""" +from contextlib import ExitStack as does_not_raise +from pathlib import Path +from typing import Any +from unittest.mock import MagicMock + +import paramiko +import pytest + +from mlia.backend.common import ConfigurationException +from mlia.backend.config import LocalProtocolConfig +from mlia.backend.protocol import CustomSFTPClient +from mlia.backend.protocol import LocalProtocol +from mlia.backend.protocol import ProtocolFactory +from mlia.backend.protocol import SSHProtocol + + +class TestProtocolFactory: + """Test ProtocolFactory class.""" + + @pytest.mark.parametrize( + "config, expected_class, exception", + [ + ( + { + "protocol": "ssh", + "username": "user", + "password": "pass", + "hostname": "hostname", + "port": "22", + }, + SSHProtocol, + does_not_raise(), + ), + ({"protocol": "local"}, LocalProtocol, does_not_raise()), + ( + {"protocol": "something"}, + None, + pytest.raises(Exception, match="Protocol not supported"), + ), + (None, None, pytest.raises(Exception, match="No protocol config provided")), + ], + ) + def test_get_protocol( + self, config: Any, expected_class: type, exception: Any + ) -> None: + """Test get_protocol method.""" + factory = ProtocolFactory() + with exception: + protocol = factory.get_protocol(config) + assert isinstance(protocol, expected_class) + + +class TestLocalProtocol: + """Test local protocol.""" + + def test_local_protocol_run_command(self) -> None: + """Test local protocol run command.""" + config = LocalProtocolConfig(protocol="local") + protocol = LocalProtocol(config, cwd=Path("/tmp")) + ret, stdout, stderr = protocol.run("pwd") + assert ret == 0 + assert stdout.decode("utf-8").strip() == "/tmp" + assert stderr.decode("utf-8") == "" + + def test_local_protocol_run_wrong_cwd(self) -> None: + """Execution should fail if wrong working directory provided.""" + config = LocalProtocolConfig(protocol="local") + protocol = LocalProtocol(config, cwd=Path("unknown_directory")) + with pytest.raises( + ConfigurationException, match="Wrong working directory unknown_directory" + ): + protocol.run("pwd") + + +class TestSSHProtocol: + """Test SSH protocol.""" + + @pytest.fixture(autouse=True) + def setup_method(self, monkeypatch: Any) -> None: + """Set up protocol mocks.""" + self.mock_ssh_client = MagicMock(spec=paramiko.client.SSHClient) + + self.mock_ssh_channel = ( + self.mock_ssh_client.get_transport.return_value.open_session.return_value + ) + self.mock_ssh_channel.mock_add_spec(spec=paramiko.channel.Channel) + self.mock_ssh_channel.exit_status_ready.side_effect = [False, True] + self.mock_ssh_channel.recv_exit_status.return_value = True + self.mock_ssh_channel.recv_ready.side_effect = [False, True] + self.mock_ssh_channel.recv_stderr_ready.side_effect = [False, True] + + monkeypatch.setattr( + "mlia.backend.protocol.paramiko.client.SSHClient", + MagicMock(return_value=self.mock_ssh_client), + ) + + self.mock_sftp_client = MagicMock(spec=CustomSFTPClient) + monkeypatch.setattr( + "mlia.backend.protocol.CustomSFTPClient.from_transport", + MagicMock(return_value=self.mock_sftp_client), + ) + + ssh_config = { + "protocol": "ssh", + "username": "user", + "password": "pass", + "hostname": "hostname", + "port": "22", + } + self.protocol = SSHProtocol(ssh_config) + + def test_unable_create_ssh_client(self, monkeypatch: Any) -> None: + """Test that command should fail if unable to create ssh 
client instance.""" + monkeypatch.setattr( + "mlia.backend.protocol.paramiko.client.SSHClient", + MagicMock(side_effect=OSError("Error!")), + ) + + with pytest.raises(Exception, match="Couldn't connect to 'hostname:22'"): + self.protocol.run("command_example", retry=False) + + def test_ssh_protocol_run_command(self) -> None: + """Test that command run via ssh successfully.""" + self.protocol.run("command_example") + self.mock_ssh_channel.exec_command.assert_called_once() + + def test_ssh_protocol_run_command_connect_failed(self) -> None: + """Test that if connection is not possible then correct exception is raised.""" + self.mock_ssh_client.connect.side_effect = OSError("Unable to connect") + self.mock_ssh_client.close.side_effect = Exception("Error!") + + with pytest.raises(Exception, match="Couldn't connect to 'hostname:22'"): + self.protocol.run("command_example", retry=False) + + def test_ssh_protocol_run_command_bad_transport(self) -> None: + """Test that command should fail if unable to get transport.""" + self.mock_ssh_client.get_transport.return_value = None + + with pytest.raises(Exception, match="Unable to get transport"): + self.protocol.run("command_example", retry=False) + + def test_ssh_protocol_deploy_command_file( + self, test_applications_path: Path + ) -> None: + """Test that files could be deployed over ssh.""" + file_for_deploy = test_applications_path / "readme.txt" + dest = "/tmp/dest" + + self.protocol.deploy(file_for_deploy, dest) + self.mock_sftp_client.put.assert_called_once_with(str(file_for_deploy), dest) + + def test_ssh_protocol_deploy_command_unknown_file(self) -> None: + """Test that deploy will fail if file does not exist.""" + with pytest.raises(Exception, match="Deploy error: file type not supported"): + self.protocol.deploy(Path("unknown_file"), "/tmp/dest") + + def test_ssh_protocol_deploy_command_bad_transport(self) -> None: + """Test that deploy should fail if unable to get transport.""" + self.mock_ssh_client.get_transport.return_value = None + + with pytest.raises(Exception, match="Unable to get transport"): + self.protocol.deploy(Path("some_file"), "/tmp/dest") + + def test_ssh_protocol_deploy_command_directory( + self, test_resources_path: Path + ) -> None: + """Test that directory could be deployed over ssh.""" + directory_for_deploy = test_resources_path / "scripts" + dest = "/tmp/dest" + + self.protocol.deploy(directory_for_deploy, dest) + self.mock_sftp_client.put_dir.assert_called_once_with( + directory_for_deploy, dest + ) + + @pytest.mark.parametrize("establish_connection", (True, False)) + def test_ssh_protocol_close(self, establish_connection: bool) -> None: + """Test protocol close operation.""" + if establish_connection: + self.protocol.establish_connection() + self.protocol.close() + + call_count = 1 if establish_connection else 0 + assert self.mock_ssh_channel.exec_command.call_count == call_count + + def test_connection_details(self) -> None: + """Test getting connection details.""" + assert self.protocol.connection_details() == ("hostname", 22) + + +class TestCustomSFTPClient: + """Test CustomSFTPClient class.""" + + @pytest.fixture(autouse=True) + def setup_method(self, monkeypatch: Any) -> None: + """Set up mocks for CustomSFTPClient instance.""" + self.mock_mkdir = MagicMock() + self.mock_put = MagicMock() + monkeypatch.setattr( + "mlia.backend.protocol.paramiko.SFTPClient.__init__", + MagicMock(return_value=None), + ) + monkeypatch.setattr( + "mlia.backend.protocol.paramiko.SFTPClient.mkdir", self.mock_mkdir + ) + 
monkeypatch.setattr( + "mlia.backend.protocol.paramiko.SFTPClient.put", self.mock_put + ) + + self.sftp_client = CustomSFTPClient(MagicMock()) + + def test_put_dir(self, test_systems_path: Path) -> None: + """Test deploying directory to remote host.""" + directory_for_deploy = test_systems_path / "system1" + + self.sftp_client.put_dir(directory_for_deploy, "/tmp/dest") + assert self.mock_put.call_count == 3 + assert self.mock_mkdir.call_count == 3 + + def test_mkdir(self) -> None: + """Test creating directory on remote host.""" + self.mock_mkdir.side_effect = IOError("Cannot create directory") + + self.sftp_client._mkdir("new_directory", ignore_existing=True) + + with pytest.raises(IOError, match="Cannot create directory"): + self.sftp_client._mkdir("new_directory", ignore_existing=False) diff --git a/tests/mlia/test_backend_source.py b/tests/mlia/test_backend_source.py new file mode 100644 index 0000000..84a6a77 --- /dev/null +++ b/tests/mlia/test_backend_source.py @@ -0,0 +1,203 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +# pylint: disable=no-self-use +"""Tests for the source backend module.""" +from collections import Counter +from contextlib import ExitStack as does_not_raise +from pathlib import Path +from typing import Any +from unittest.mock import MagicMock +from unittest.mock import patch + +import pytest + +from mlia.backend.common import ConfigurationException +from mlia.backend.source import create_destination_and_install +from mlia.backend.source import DirectorySource +from mlia.backend.source import get_source +from mlia.backend.source import TarArchiveSource + + +def test_create_destination_and_install(test_systems_path: Path, tmpdir: Any) -> None: + """Test create_destination_and_install function.""" + system_directory = test_systems_path / "system1" + + dir_source = DirectorySource(system_directory) + resources = Path(tmpdir) + create_destination_and_install(dir_source, resources) + assert (resources / "system1").is_dir() + + +@patch("mlia.backend.source.DirectorySource.create_destination", return_value=False) +def test_create_destination_and_install_if_dest_creation_not_required( + mock_ds_create_destination: Any, tmpdir: Any +) -> None: + """Test create_destination_and_install function.""" + dir_source = DirectorySource(Path("unknown")) + resources = Path(tmpdir) + with pytest.raises(Exception): + create_destination_and_install(dir_source, resources) + + mock_ds_create_destination.assert_called_once() + + +def test_create_destination_and_install_if_installation_fails(tmpdir: Any) -> None: + """Test create_destination_and_install function if installation fails.""" + dir_source = DirectorySource(Path("unknown")) + resources = Path(tmpdir) + with pytest.raises(Exception, match="Directory .* does not exist"): + create_destination_and_install(dir_source, resources) + assert not (resources / "unknown").exists() + assert resources.exists() + + +def test_create_destination_and_install_if_name_is_empty() -> None: + """Test create_destination_and_install function fails if source name is empty.""" + source = MagicMock() + source.create_destination.return_value = True + source.name.return_value = None + + with pytest.raises(Exception, match="Unable to get source name"): + create_destination_and_install(source, Path("some_path")) + + source.install_into.assert_not_called() + + +@pytest.mark.parametrize( + "source_path, expected_class, expected_error", + [ + ( + Path("backends/applications/application1/"), + 
DirectorySource, + does_not_raise(), + ), + ( + Path("archives/applications/application1.tar.gz"), + TarArchiveSource, + does_not_raise(), + ), + ( + Path("doesnt/exist"), + None, + pytest.raises( + ConfigurationException, match="Unable to read .*doesnt/exist" + ), + ), + ], +) +def test_get_source( + source_path: Path, + expected_class: Any, + expected_error: Any, + test_resources_path: Path, +) -> None: + """Test get_source function.""" + with expected_error: + full_source_path = test_resources_path / source_path + source = get_source(full_source_path) + assert isinstance(source, expected_class) + + +class TestDirectorySource: + """Test DirectorySource class.""" + + @pytest.mark.parametrize( + "directory, name", + [ + (Path("/some/path/some_system"), "some_system"), + (Path("some_system"), "some_system"), + ], + ) + def test_name(self, directory: Path, name: str) -> None: + """Test getting source name.""" + assert DirectorySource(directory).name() == name + + def test_install_into(self, test_systems_path: Path, tmpdir: Any) -> None: + """Test install directory into destination.""" + system_directory = test_systems_path / "system1" + + dir_source = DirectorySource(system_directory) + with pytest.raises(Exception, match="Wrong destination .*"): + dir_source.install_into(Path("unknown_destination")) + + tmpdir_path = Path(tmpdir) + dir_source.install_into(tmpdir_path) + source_files = [f.name for f in system_directory.iterdir()] + dest_files = [f.name for f in tmpdir_path.iterdir()] + assert Counter(source_files) == Counter(dest_files) + + def test_install_into_unknown_source_directory(self, tmpdir: Any) -> None: + """Test install system from unknown directory.""" + with pytest.raises(Exception, match="Directory .* does not exist"): + DirectorySource(Path("unknown_directory")).install_into(Path(tmpdir)) + + +class TestTarArchiveSource: + """Test TarArchiveSource class.""" + + @pytest.mark.parametrize( + "archive, name", + [ + (Path("some_archive.tgz"), "some_archive"), + (Path("some_archive.tar.gz"), "some_archive"), + (Path("some_archive"), "some_archive"), + ("archives/systems/system1.tar.gz", "system1"), + ("archives/systems/system1_dir.tar.gz", "system1"), + ], + ) + def test_name(self, test_resources_path: Path, archive: Path, name: str) -> None: + """Test getting source name.""" + assert TarArchiveSource(test_resources_path / archive).name() == name + + def test_install_into(self, test_resources_path: Path, tmpdir: Any) -> None: + """Test install archive into destination.""" + system_archive = test_resources_path / "archives/systems/system1.tar.gz" + + tar_source = TarArchiveSource(system_archive) + with pytest.raises(Exception, match="Wrong destination .*"): + tar_source.install_into(Path("unknown_destination")) + + tmpdir_path = Path(tmpdir) + tar_source.install_into(tmpdir_path) + source_files = [ + "aiet-config.json.license", + "aiet-config.json", + "system_artifact", + ] + dest_files = [f.name for f in tmpdir_path.iterdir()] + assert Counter(source_files) == Counter(dest_files) + + def test_install_into_unknown_source_archive(self, tmpdir: Any) -> None: + """Test install unknown source archive.""" + with pytest.raises(Exception, match="File .* does not exist"): + TarArchiveSource(Path("unknown.tar.gz")).install_into(Path(tmpdir)) + + def test_install_into_unsupported_source_archive(self, tmpdir: Any) -> None: + """Test install unsupported file type.""" + plain_text_file = Path(tmpdir) / "test_file" + plain_text_file.write_text("Not a system config") + + with 
pytest.raises(Exception, match="Unsupported archive type .*"): + TarArchiveSource(plain_text_file).install_into(Path(tmpdir)) + + def test_lazy_property_init(self, test_resources_path: Path) -> None: + """Test that class properties initialized correctly.""" + system_archive = test_resources_path / "archives/systems/system1.tar.gz" + + tar_source = TarArchiveSource(system_archive) + assert tar_source.name() == "system1" + assert tar_source.config() is not None + assert tar_source.create_destination() + + tar_source = TarArchiveSource(system_archive) + assert tar_source.config() is not None + assert tar_source.create_destination() + assert tar_source.name() == "system1" + + def test_create_destination_property(self, test_resources_path: Path) -> None: + """Test create_destination property filled correctly for different archives.""" + system_archive1 = test_resources_path / "archives/systems/system1.tar.gz" + system_archive2 = test_resources_path / "archives/systems/system1_dir.tar.gz" + + assert TarArchiveSource(system_archive1).create_destination() + assert not TarArchiveSource(system_archive2).create_destination() diff --git a/tests/mlia/test_backend_system.py b/tests/mlia/test_backend_system.py new file mode 100644 index 0000000..21187ff --- /dev/null +++ b/tests/mlia/test_backend_system.py @@ -0,0 +1,541 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Tests for system backend.""" +from contextlib import ExitStack as does_not_raise +from pathlib import Path +from typing import Any +from typing import Callable +from typing import Dict +from typing import List +from typing import Optional +from typing import Tuple +from unittest.mock import MagicMock + +import pytest + +from mlia.backend.common import Command +from mlia.backend.common import ConfigurationException +from mlia.backend.common import Param +from mlia.backend.common import UserParamConfig +from mlia.backend.config import LocalProtocolConfig +from mlia.backend.config import ProtocolConfig +from mlia.backend.config import SSHConfig +from mlia.backend.config import SystemConfig +from mlia.backend.controller import SystemController +from mlia.backend.controller import SystemControllerSingleInstance +from mlia.backend.protocol import LocalProtocol +from mlia.backend.protocol import SSHProtocol +from mlia.backend.protocol import SupportsClose +from mlia.backend.protocol import SupportsDeploy +from mlia.backend.system import ControlledSystem +from mlia.backend.system import get_available_systems +from mlia.backend.system import get_controller +from mlia.backend.system import get_system +from mlia.backend.system import install_system +from mlia.backend.system import load_system +from mlia.backend.system import remove_system +from mlia.backend.system import StandaloneSystem +from mlia.backend.system import System + + +def dummy_resolver( + values: Optional[Dict[str, str]] = None +) -> Callable[[str, str, List[Tuple[Optional[str], Param]]], str]: + """Return dummy parameter resolver implementation.""" + # pylint: disable=unused-argument + def resolver( + param: str, cmd: str, param_values: List[Tuple[Optional[str], Param]] + ) -> str: + """Implement dummy parameter resolver.""" + return values.get(param, "") if values else "" + + return resolver + + +def test_get_available_systems() -> None: + """Test get_available_systems mocking get_resources.""" + available_systems = get_available_systems() + assert all(isinstance(s, System) for s in available_systems) + assert 
len(available_systems) == 4 + assert [str(s) for s in available_systems] == [ + "System 1", + "System 2", + "System 4", + "System 6", + ] + + +def test_get_system() -> None: + """Test get_system.""" + system1 = get_system("System 1") + assert isinstance(system1, ControlledSystem) + assert system1.connectable is True + assert system1.connection_details() == ("localhost", 8021) + assert system1.name == "System 1" + + system2 = get_system("System 2") + # check that comparison with object of another type returns false + assert system1 != 42 + assert system1 != system2 + + system = get_system("Unknown system") + assert system is None + + +@pytest.mark.parametrize( + "source, call_count, exception_type", + ( + ( + "archives/systems/system1.tar.gz", + 0, + pytest.raises(Exception, match="Systems .* are already installed"), + ), + ( + "archives/systems/system3.tar.gz", + 0, + pytest.raises(Exception, match="Unable to read system definition"), + ), + ( + "backends/systems/system1", + 0, + pytest.raises(Exception, match="Systems .* are already installed"), + ), + ( + "backends/systems/system3", + 0, + pytest.raises(Exception, match="Unable to read system definition"), + ), + ("unknown_path", 0, pytest.raises(Exception, match="Unable to read")), + ( + "various/systems/system_with_empty_config", + 0, + pytest.raises(Exception, match="No system definition found"), + ), + ("various/systems/system_with_valid_config", 1, does_not_raise()), + ), +) +def test_install_system( + monkeypatch: Any, + test_resources_path: Path, + source: str, + call_count: int, + exception_type: Any, +) -> None: + """Test system installation from archive.""" + mock_create_destination_and_install = MagicMock() + monkeypatch.setattr( + "mlia.backend.system.create_destination_and_install", + mock_create_destination_and_install, + ) + + with exception_type: + install_system(test_resources_path / source) + + assert mock_create_destination_and_install.call_count == call_count + + +def test_remove_system(monkeypatch: Any) -> None: + """Test system removal.""" + mock_remove_backend = MagicMock() + monkeypatch.setattr("mlia.backend.system.remove_backend", mock_remove_backend) + remove_system("some_system_dir") + mock_remove_backend.assert_called_once() + + +def test_system(monkeypatch: Any) -> None: + """Test the System class.""" + config = SystemConfig(name="System 1") + monkeypatch.setattr("mlia.backend.system.ProtocolFactory", MagicMock()) + system = System(config) + assert str(system) == "System 1" + assert system.name == "System 1" + + +def test_system_with_empty_parameter_name() -> None: + """Test that configuration fails if parameter name is empty.""" + bad_config = SystemConfig( + name="System 1", + commands={"run": ["run"]}, + user_params={"run": [{"name": "", "values": ["1", "2", "3"]}]}, + ) + with pytest.raises(Exception, match="Parameter has an empty 'name' attribute."): + System(bad_config) + + +def test_system_standalone_run() -> None: + """Test run operation for standalone system.""" + system = get_system("System 4") + assert isinstance(system, StandaloneSystem) + + with pytest.raises( + ConfigurationException, match="System .* does not support connections" + ): + system.connection_details() + + with pytest.raises( + ConfigurationException, match="System .* does not support connections" + ): + system.establish_connection() + + assert system.connectable is False + + system.run("echo 'application run'") + + +@pytest.mark.parametrize( + "system_name, expected_value", [("System 1", True), ("System 4", False)] +) +def 
test_system_supports_deploy(system_name: str, expected_value: bool) -> None: + """Test system property supports_deploy.""" + system = get_system(system_name) + if system is None: + pytest.fail("Unable to get system {}".format(system_name)) + assert system.supports_deploy == expected_value + + +@pytest.mark.parametrize( + "mock_protocol", + [ + MagicMock(spec=SSHProtocol), + MagicMock( + spec=SSHProtocol, + **{"close.side_effect": ValueError("Unable to close protocol")} + ), + MagicMock(spec=LocalProtocol), + ], +) +def test_system_start_and_stop(monkeypatch: Any, mock_protocol: MagicMock) -> None: + """Test system start, run commands and stop.""" + monkeypatch.setattr( + "mlia.backend.system.ProtocolFactory.get_protocol", + MagicMock(return_value=mock_protocol), + ) + + system = get_system("System 1") + if system is None: + pytest.fail("Unable to get system") + assert isinstance(system, ControlledSystem) + + with pytest.raises(Exception, match="System has not been started"): + system.stop() + + assert not system.is_running() + assert system.get_output() == ("", "") + system.start(["sleep 10"], False) + assert system.is_running() + system.stop(wait=True) + assert not system.is_running() + assert system.get_output() == ("", "") + + if isinstance(mock_protocol, SupportsClose): + mock_protocol.close.assert_called_once() + + if isinstance(mock_protocol, SSHProtocol): + system.establish_connection() + + +def test_system_start_no_config_location() -> None: + """Test that system without config location could not start.""" + system = load_system( + SystemConfig( + name="test", + data_transfer=SSHConfig( + protocol="ssh", + username="user", + password="user", + hostname="localhost", + port="123", + ), + ) + ) + + assert isinstance(system, ControlledSystem) + with pytest.raises( + ConfigurationException, match="System test has wrong config location" + ): + system.start(["sleep 100"]) + + +@pytest.mark.parametrize( + "config, expected_class, expected_error", + [ + ( + SystemConfig( + name="test", + data_transfer=SSHConfig( + protocol="ssh", + username="user", + password="user", + hostname="localhost", + port="123", + ), + ), + ControlledSystem, + does_not_raise(), + ), + ( + SystemConfig( + name="test", data_transfer=LocalProtocolConfig(protocol="local") + ), + StandaloneSystem, + does_not_raise(), + ), + ( + SystemConfig( + name="test", + data_transfer=ProtocolConfig(protocol="cool_protocol"), # type: ignore + ), + None, + pytest.raises( + Exception, match="Unsupported execution type for protocol cool_protocol" + ), + ), + ], +) +def test_load_system( + config: SystemConfig, expected_class: type, expected_error: Any +) -> None: + """Test load_system function.""" + if not expected_class: + with expected_error: + load_system(config) + else: + system = load_system(config) + assert isinstance(system, expected_class) + + +def test_load_system_populate_shared_params() -> None: + """Test shared parameters population.""" + with pytest.raises(Exception, match="All shared parameters should have aliases"): + load_system( + SystemConfig( + name="test_system", + data_transfer=LocalProtocolConfig(protocol="local"), + user_params={ + "shared": [ + UserParamConfig( + name="--shared_param1", + description="Shared parameter", + values=["1", "2", "3"], + default_value="1", + ) + ] + }, + ) + ) + + with pytest.raises( + Exception, match="All parameters for command run should have aliases" + ): + load_system( + SystemConfig( + name="test_system", + data_transfer=LocalProtocolConfig(protocol="local"), + user_params={ + 
"shared": [ + UserParamConfig( + name="--shared_param1", + description="Shared parameter", + values=["1", "2", "3"], + default_value="1", + alias="shared_param1", + ) + ], + "run": [ + UserParamConfig( + name="--run_param1", + description="Run specific parameter", + values=["1", "2", "3"], + default_value="2", + ) + ], + }, + ) + ) + system0 = load_system( + SystemConfig( + name="test_system", + data_transfer=LocalProtocolConfig(protocol="local"), + commands={"run": ["run_command"]}, + user_params={ + "shared": [], + "run": [ + UserParamConfig( + name="--run_param1", + description="Run specific parameter", + values=["1", "2", "3"], + default_value="2", + alias="run_param1", + ) + ], + }, + ) + ) + assert len(system0.commands) == 1 + run_command1 = system0.commands["run"] + assert run_command1 == Command( + ["run_command"], + [ + Param( + "--run_param1", + "Run specific parameter", + ["1", "2", "3"], + "2", + "run_param1", + ) + ], + ) + + system1 = load_system( + SystemConfig( + name="test_system", + data_transfer=LocalProtocolConfig(protocol="local"), + user_params={ + "shared": [ + UserParamConfig( + name="--shared_param1", + description="Shared parameter", + values=["1", "2", "3"], + default_value="1", + alias="shared_param1", + ) + ], + "run": [ + UserParamConfig( + name="--run_param1", + description="Run specific parameter", + values=["1", "2", "3"], + default_value="2", + alias="run_param1", + ) + ], + }, + ) + ) + assert len(system1.commands) == 2 + build_command1 = system1.commands["build"] + assert build_command1 == Command( + [], + [ + Param( + "--shared_param1", + "Shared parameter", + ["1", "2", "3"], + "1", + "shared_param1", + ) + ], + ) + + run_command1 = system1.commands["run"] + assert run_command1 == Command( + [], + [ + Param( + "--shared_param1", + "Shared parameter", + ["1", "2", "3"], + "1", + "shared_param1", + ), + Param( + "--run_param1", + "Run specific parameter", + ["1", "2", "3"], + "2", + "run_param1", + ), + ], + ) + + system2 = load_system( + SystemConfig( + name="test_system", + data_transfer=LocalProtocolConfig(protocol="local"), + commands={"build": ["build_command"]}, + user_params={ + "shared": [ + UserParamConfig( + name="--shared_param1", + description="Shared parameter", + values=["1", "2", "3"], + default_value="1", + alias="shared_param1", + ) + ], + "run": [ + UserParamConfig( + name="--run_param1", + description="Run specific parameter", + values=["1", "2", "3"], + default_value="2", + alias="run_param1", + ) + ], + }, + ) + ) + assert len(system2.commands) == 2 + build_command2 = system2.commands["build"] + assert build_command2 == Command( + ["build_command"], + [ + Param( + "--shared_param1", + "Shared parameter", + ["1", "2", "3"], + "1", + "shared_param1", + ) + ], + ) + + run_command2 = system1.commands["run"] + assert run_command2 == Command( + [], + [ + Param( + "--shared_param1", + "Shared parameter", + ["1", "2", "3"], + "1", + "shared_param1", + ), + Param( + "--run_param1", + "Run specific parameter", + ["1", "2", "3"], + "2", + "run_param1", + ), + ], + ) + + +@pytest.mark.parametrize( + "mock_protocol, expected_call_count", + [(MagicMock(spec=SupportsDeploy), 1), (MagicMock(), 0)], +) +def test_system_deploy_data( + monkeypatch: Any, mock_protocol: MagicMock, expected_call_count: int +) -> None: + """Test deploy data functionality.""" + monkeypatch.setattr( + "mlia.backend.system.ProtocolFactory.get_protocol", + MagicMock(return_value=mock_protocol), + ) + + system = ControlledSystem(SystemConfig(name="test")) + 
system.deploy(Path("some_file"), "some_dest") + + assert mock_protocol.deploy.call_count == expected_call_count + + +@pytest.mark.parametrize( + "single_instance, controller_class", + ((False, SystemController), (True, SystemControllerSingleInstance)), +) +def test_get_controller(single_instance: bool, controller_class: type) -> None: + """Test function get_controller.""" + controller = get_controller(single_instance) + assert isinstance(controller, controller_class) diff --git a/tests/mlia/test_cli_logging.py b/tests/mlia/test_cli_logging.py index 7c5f299..3f59cb6 100644 --- a/tests/mlia/test_cli_logging.py +++ b/tests/mlia/test_cli_logging.py @@ -32,7 +32,7 @@ def teardown_function() -> None: ( None, True, - """mlia.tools.aiet_wrapper - aiet debug + """mlia.backend.manager - backends debug cli info mlia.cli - cli debug """, @@ -41,11 +41,11 @@ mlia.cli - cli debug ( "logs", True, - """mlia.tools.aiet_wrapper - aiet debug + """mlia.backend.manager - backends debug cli info mlia.cli - cli debug """, - """mlia.tools.aiet_wrapper - DEBUG - aiet debug + """mlia.backend.manager - DEBUG - backends debug mlia.cli - DEBUG - cli debug """, ), @@ -64,8 +64,8 @@ def test_setup_logging( setup_logging(logs_dir_path, verbose) - aiet_logger = logging.getLogger("mlia.tools.aiet_wrapper") - aiet_logger.debug("aiet debug") + backend_logger = logging.getLogger("mlia.backend.manager") + backend_logger.debug("backends debug") cli_logger = logging.getLogger("mlia.cli") cli_logger.info("cli info") diff --git a/tests/mlia/test_devices_ethosu_performance.py b/tests/mlia/test_devices_ethosu_performance.py index e27efa0..b3e5298 100644 --- a/tests/mlia/test_devices_ethosu_performance.py +++ b/tests/mlia/test_devices_ethosu_performance.py @@ -23,6 +23,6 @@ def test_memory_usage_conversion() -> None: def mock_performance_estimation(monkeypatch: pytest.MonkeyPatch) -> None: """Mock performance estimation.""" monkeypatch.setattr( - "mlia.tools.aiet_wrapper.estimate_performance", + "mlia.backend.manager.estimate_performance", MagicMock(return_value=MagicMock()), ) diff --git a/tests/mlia/test_resources/application_config.json b/tests/mlia/test_resources/application_config.json new file mode 100644 index 0000000..2dfcfec --- /dev/null +++ b/tests/mlia/test_resources/application_config.json @@ -0,0 +1,96 @@ +[ + { + "name": "application_1", + "description": "application number one", + "supported_systems": [ + "system_1", + "system_2" + ], + "build_dir": "build_dir_11", + "commands": { + "clean": [ + "clean_cmd_11" + ], + "build": [ + "build_cmd_11" + ], + "run": [ + "run_cmd_11" + ], + "post_run": [ + "post_run_cmd_11" + ] + }, + "user_params": { + "run": [ + { + "name": "run_param_11", + "values": [], + "description": "run param number one" + } + ], + "build": [ + { + "name": "build_param_11", + "values": [], + "description": "build param number one" + }, + { + "name": "build_param_12", + "values": [], + "description": "build param number two" + }, + { + "name": "build_param_13", + "values": [ + "value_1" + ], + "description": "build param number three with some value" + } + ] + } + }, + { + "name": "application_2", + "description": "application number two", + "supported_systems": [ + "system_2" + ], + "build_dir": "build_dir_21", + "commands": { + "clean": [ + "clean_cmd_21" + ], + "build": [ + "build_cmd_21", + "build_cmd_22" + ], + "run": [ + "run_cmd_21" + ], + "post_run": [ + "post_run_cmd_21" + ] + }, + "user_params": { + "build": [ + { + "name": "build_param_21", + "values": [], + "description": "build param 
number one" + }, + { + "name": "build_param_22", + "values": [], + "description": "build param number two" + }, + { + "name": "build_param_23", + "values": [], + "description": "build param number three" + } + ], + "run": [] + } + } +] diff --git a/tests/mlia/test_resources/application_config.json.license b/tests/mlia/test_resources/application_config.json.license new file mode 100644 index 0000000..9b83bfc --- /dev/null +++ b/tests/mlia/test_resources/application_config.json.license @@ -0,0 +1,3 @@ +SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. + +SPDX-License-Identifier: Apache-2.0 diff --git a/tests/mlia/test_resources/backends/applications/application1/aiet-config.json b/tests/mlia/test_resources/backends/applications/application1/aiet-config.json new file mode 100644 index 0000000..97f0401 --- /dev/null +++ b/tests/mlia/test_resources/backends/applications/application1/aiet-config.json @@ -0,0 +1,30 @@ +[ + { + "name": "application_1", + "description": "This is application 1", + "supported_systems": [ + { + "name": "System 1" + } + ], + "build_dir": "build", + "commands": { + "clean": [ + "echo 'clean'" + ], + "build": [ + "echo 'build'" + ], + "run": [ + "echo 'run'" + ], + "post_run": [ + "echo 'post_run'" + ] + }, + "user_params": { + "build": [], + "run": [] + } + } +] diff --git a/tests/mlia/test_resources/backends/applications/application1/aiet-config.json.license b/tests/mlia/test_resources/backends/applications/application1/aiet-config.json.license new file mode 100644 index 0000000..9b83bfc --- /dev/null +++ b/tests/mlia/test_resources/backends/applications/application1/aiet-config.json.license @@ -0,0 +1,3 @@ +SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. + +SPDX-License-Identifier: Apache-2.0 diff --git a/tests/mlia/test_resources/backends/applications/application2/aiet-config.json b/tests/mlia/test_resources/backends/applications/application2/aiet-config.json new file mode 100644 index 0000000..e9122d3 --- /dev/null +++ b/tests/mlia/test_resources/backends/applications/application2/aiet-config.json @@ -0,0 +1,30 @@ +[ + { + "name": "application_2", + "description": "This is application 2", + "supported_systems": [ + { + "name": "System 2" + } + ], + "build_dir": "build", + "commands": { + "clean": [ + "echo 'clean'" + ], + "build": [ + "echo 'build'" + ], + "run": [ + "echo 'run'" + ], + "post_run": [ + "echo 'post_run'" + ] + }, + "user_params": { + "build": [], + "run": [] + } + } +] diff --git a/tests/mlia/test_resources/backends/applications/application2/aiet-config.json.license b/tests/mlia/test_resources/backends/applications/application2/aiet-config.json.license new file mode 100644 index 0000000..9b83bfc --- /dev/null +++ b/tests/mlia/test_resources/backends/applications/application2/aiet-config.json.license @@ -0,0 +1,3 @@ +SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. + +SPDX-License-Identifier: Apache-2.0 diff --git a/tests/mlia/test_resources/backends/applications/application3/readme.txt b/tests/mlia/test_resources/backends/applications/application3/readme.txt new file mode 100644 index 0000000..8c72c05 --- /dev/null +++ b/tests/mlia/test_resources/backends/applications/application3/readme.txt @@ -0,0 +1,4 @@ +SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. 
+SPDX-License-Identifier: Apache-2.0 + +This application does not have json configuration file diff --git a/tests/mlia/test_resources/backends/applications/application4/aiet-config.json b/tests/mlia/test_resources/backends/applications/application4/aiet-config.json new file mode 100644 index 0000000..ffb5746 --- /dev/null +++ b/tests/mlia/test_resources/backends/applications/application4/aiet-config.json @@ -0,0 +1,36 @@ +[ + { + "name": "application_4", + "description": "This is application 4", + "build_dir": "build", + "supported_systems": [ + { + "name": "System 4" + } + ], + "commands": { + "build": [ + "cp ../hello_app.txt .", + "echo '{user_params:0}' > params.txt" + ], + "run": [ + "cat {application.build_dir}/hello_app.txt" + ] + }, + "user_params": { + "build": [ + { + "name": "--app", + "description": "Sample command param", + "values": [ + "application1", + "application2", + "application3" + ], + "default_value": "application1" + } + ], + "run": [] + } + } +] diff --git a/tests/mlia/test_resources/backends/applications/application4/aiet-config.json.license b/tests/mlia/test_resources/backends/applications/application4/aiet-config.json.license new file mode 100644 index 0000000..9b83bfc --- /dev/null +++ b/tests/mlia/test_resources/backends/applications/application4/aiet-config.json.license @@ -0,0 +1,3 @@ +SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. + +SPDX-License-Identifier: Apache-2.0 diff --git a/tests/mlia/test_resources/backends/applications/application4/hello_app.txt b/tests/mlia/test_resources/backends/applications/application4/hello_app.txt new file mode 100644 index 0000000..2ec0d1d --- /dev/null +++ b/tests/mlia/test_resources/backends/applications/application4/hello_app.txt @@ -0,0 +1,4 @@ +SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +SPDX-License-Identifier: Apache-2.0 + +Hello from APP! 
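Editor's note: the application4 aiet-config.json above shows the config shape the backend tests rely on (a JSON list of entries with "name", "commands" and "user_params", where commands may embed placeholders such as "{user_params:0}" or "{application.build_dir}"). The snippet below is a minimal, hypothetical sketch of reading and sanity-checking such a file with the standard library only; it is not MLIA's actual loader, and the path and helper name are illustrative, not part of the patch. Placeholder resolution is left to the backend code.

import json
from pathlib import Path
from typing import Any, List

def load_backend_configs(config_file: Path) -> List[Any]:
    """Return the list of entries stored in an aiet-config.json file."""
    with config_file.open(encoding="utf-8") as file:
        entries = json.load(file)
    if not isinstance(entries, list):
        raise ValueError(f"{config_file} must contain a JSON list")
    return entries

# Illustrative usage against the application4 test resource shown above.
config_path = Path(
    "tests/mlia/test_resources/backends/applications/application4/aiet-config.json"
)
for entry in load_backend_configs(config_path):
    # Every application entry in these test resources carries at least
    # a name, a commands mapping and a user_params mapping.
    assert {"name", "commands", "user_params"} <= entry.keys()
    # Commands are lists of shell strings; placeholders like
    # "{user_params:0}" stay unresolved at this point.
    print(entry["name"], list(entry["commands"]))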
diff --git a/tests/mlia/test_resources/backends/applications/application5/aiet-config.json b/tests/mlia/test_resources/backends/applications/application5/aiet-config.json new file mode 100644 index 0000000..5269409 --- /dev/null +++ b/tests/mlia/test_resources/backends/applications/application5/aiet-config.json @@ -0,0 +1,160 @@ +[ + { + "name": "application_5", + "description": "This is application 5", + "build_dir": "default_build_dir", + "supported_systems": [ + { + "name": "System 1", + "lock": false + }, + { + "name": "System 2" + } + ], + "variables": { + "var1": "value1", + "var2": "value2" + }, + "lock": true, + "commands": { + "build": [ + "default build command" + ], + "run": [ + "default run command" + ] + }, + "user_params": { + "build": [], + "run": [] + } + }, + { + "name": "application_5A", + "description": "This is application 5A", + "supported_systems": [ + { + "name": "System 1", + "build_dir": "build_5A", + "variables": { + "var1": "new value1" + } + }, + { + "name": "System 2", + "variables": { + "var2": "new value2" + }, + "lock": true, + "commands": { + "run": [ + "run command on system 2" + ] + } + } + ], + "variables": { + "var1": "value1", + "var2": "value2" + }, + "build_dir": "build", + "commands": { + "build": [ + "default build command" + ], + "run": [ + "default run command" + ] + }, + "user_params": { + "build": [], + "run": [] + } + }, + { + "name": "application_5B", + "description": "This is application 5B", + "supported_systems": [ + { + "name": "System 1", + "build_dir": "build_5B", + "variables": { + "var1": "value for var1 System1", + "var2": "value for var2 System1" + }, + "user_params": { + "build": [ + { + "name": "--param_5B", + "description": "Sample command param", + "values": [ + "value1", + "value2", + "value3" + ], + "default_value": "value1", + "alias": "param1" + } + ] + } + }, + { + "name": "System 2", + "variables": { + "var1": "value for var1 System2", + "var2": "value for var2 System2" + }, + "commands": { + "build": [ + "build command on system 2 with {variables:var1} {user_params:param1}" + ], + "run": [ + "run command on system 2" + ] + }, + "user_params": { + "run": [] + } + } + ], + "build_dir": "build", + "commands": { + "build": [ + "default build command with {variables:var1}" + ], + "run": [ + "default run command with {variables:var2}" + ] + }, + "user_params": { + "build": [ + { + "name": "--param", + "description": "Sample command param", + "values": [ + "value1", + "value2", + "value3" + ], + "default_value": "value1", + "alias": "param1" + } + ], + "run": [], + "non_used_command": [ + { + "name": "--not-used", + "description": "Not used param anywhere", + "values": [ + "value1", + "value2", + "value3" + ], + "default_value": "value1", + "alias": "param1" + } + ] + } + } +] diff --git a/tests/mlia/test_resources/backends/applications/application5/aiet-config.json.license b/tests/mlia/test_resources/backends/applications/application5/aiet-config.json.license new file mode 100644 index 0000000..9b83bfc --- /dev/null +++ b/tests/mlia/test_resources/backends/applications/application5/aiet-config.json.license @@ -0,0 +1,3 @@ +SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. 
+ +SPDX-License-Identifier: Apache-2.0 diff --git a/tests/mlia/test_resources/backends/applications/application6/aiet-config.json b/tests/mlia/test_resources/backends/applications/application6/aiet-config.json new file mode 100644 index 0000000..56ad807 --- /dev/null +++ b/tests/mlia/test_resources/backends/applications/application6/aiet-config.json @@ -0,0 +1,42 @@ +[ + { + "name": "application_6", + "description": "This is application 6", + "supported_systems": [ + { + "name": "System 6" + } + ], + "build_dir": "build", + "commands": { + "clean": [ + "echo 'clean'" + ], + "build": [ + "echo 'build'" + ], + "run": [ + "echo 'run {user_params:param1}'" + ], + "post_run": [ + "echo 'post_run'" + ] + }, + "user_params": { + "build": [], + "run": [ + { + "name": "--param1", + "description": "Test parameter", + "values": [ + "value1", + "value2", + "value3" + ], + "default_value": "value3", + "alias": "param1" + } + ] + } + } +] diff --git a/tests/mlia/test_resources/backends/applications/application6/aiet-config.json.license b/tests/mlia/test_resources/backends/applications/application6/aiet-config.json.license new file mode 100644 index 0000000..9b83bfc --- /dev/null +++ b/tests/mlia/test_resources/backends/applications/application6/aiet-config.json.license @@ -0,0 +1,3 @@ +SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. + +SPDX-License-Identifier: Apache-2.0 diff --git a/tests/mlia/test_resources/backends/applications/readme.txt b/tests/mlia/test_resources/backends/applications/readme.txt new file mode 100644 index 0000000..a1f8209 --- /dev/null +++ b/tests/mlia/test_resources/backends/applications/readme.txt @@ -0,0 +1,4 @@ +SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +SPDX-License-Identifier: Apache-2.0 + +Dummy file for test purposes diff --git a/tests/mlia/test_resources/backends/systems/system1/aiet-config.json b/tests/mlia/test_resources/backends/systems/system1/aiet-config.json new file mode 100644 index 0000000..4b5dd19 --- /dev/null +++ b/tests/mlia/test_resources/backends/systems/system1/aiet-config.json @@ -0,0 +1,35 @@ +[ + { + "name": "System 1", + "description": "This is system 1", + "build_dir": "build", + "data_transfer": { + "protocol": "ssh", + "username": "root", + "password": "root", + "hostname": "localhost", + "port": "8021" + }, + "commands": { + "clean": [ + "echo 'clean'" + ], + "build": [ + "echo 'build'" + ], + "run": [ + "echo 'run'" + ], + "post_run": [ + "echo 'post_run'" + ], + "deploy": [ + "echo 'deploy'" + ] + }, + "user_params": { + "build": [], + "run": [] + } + } +] diff --git a/tests/mlia/test_resources/backends/systems/system1/aiet-config.json.license b/tests/mlia/test_resources/backends/systems/system1/aiet-config.json.license new file mode 100644 index 0000000..9b83bfc --- /dev/null +++ b/tests/mlia/test_resources/backends/systems/system1/aiet-config.json.license @@ -0,0 +1,3 @@ +SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. + +SPDX-License-Identifier: Apache-2.0 diff --git a/tests/mlia/test_resources/backends/systems/system1/system_artifact/dummy.txt b/tests/mlia/test_resources/backends/systems/system1/system_artifact/dummy.txt new file mode 100644 index 0000000..487e9d8 --- /dev/null +++ b/tests/mlia/test_resources/backends/systems/system1/system_artifact/dummy.txt @@ -0,0 +1,2 @@ +SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. 
+SPDX-License-Identifier: Apache-2.0 diff --git a/tests/mlia/test_resources/backends/systems/system2/aiet-config.json b/tests/mlia/test_resources/backends/systems/system2/aiet-config.json new file mode 100644 index 0000000..a9e0eb3 --- /dev/null +++ b/tests/mlia/test_resources/backends/systems/system2/aiet-config.json @@ -0,0 +1,32 @@ +[ + { + "name": "System 2", + "description": "This is system 2", + "build_dir": "build", + "data_transfer": { + "protocol": "ssh", + "username": "root", + "password": "root", + "hostname": "localhost", + "port": "8021" + }, + "commands": { + "clean": [ + "echo 'clean'" + ], + "build": [ + "echo 'build'" + ], + "run": [ + "echo 'run'" + ], + "post_run": [ + "echo 'post_run'" + ] + }, + "user_params": { + "build": [], + "run": [] + } + } +] diff --git a/tests/mlia/test_resources/backends/systems/system2/aiet-config.json.license b/tests/mlia/test_resources/backends/systems/system2/aiet-config.json.license new file mode 100644 index 0000000..9b83bfc --- /dev/null +++ b/tests/mlia/test_resources/backends/systems/system2/aiet-config.json.license @@ -0,0 +1,3 @@ +SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. + +SPDX-License-Identifier: Apache-2.0 diff --git a/tests/mlia/test_resources/backends/systems/system3/readme.txt b/tests/mlia/test_resources/backends/systems/system3/readme.txt new file mode 100644 index 0000000..aba5a9c --- /dev/null +++ b/tests/mlia/test_resources/backends/systems/system3/readme.txt @@ -0,0 +1,4 @@ +SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +SPDX-License-Identifier: Apache-2.0 + +This system does not have the json configuration file diff --git a/tests/mlia/test_resources/backends/systems/system4/aiet-config.json b/tests/mlia/test_resources/backends/systems/system4/aiet-config.json new file mode 100644 index 0000000..7b13160 --- /dev/null +++ b/tests/mlia/test_resources/backends/systems/system4/aiet-config.json @@ -0,0 +1,19 @@ +[ + { + "name": "System 4", + "description": "This is system 4", + "build_dir": "build", + "data_transfer": { + "protocol": "local" + }, + "commands": { + "run": [ + "echo {application.name}", + "{application.commands.run:0}" + ] + }, + "user_params": { + "run": [] + } + } +] diff --git a/tests/mlia/test_resources/backends/systems/system4/aiet-config.json.license b/tests/mlia/test_resources/backends/systems/system4/aiet-config.json.license new file mode 100644 index 0000000..9b83bfc --- /dev/null +++ b/tests/mlia/test_resources/backends/systems/system4/aiet-config.json.license @@ -0,0 +1,3 @@ +SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. 
+ +SPDX-License-Identifier: Apache-2.0 diff --git a/tests/mlia/test_resources/backends/systems/system6/aiet-config.json b/tests/mlia/test_resources/backends/systems/system6/aiet-config.json new file mode 100644 index 0000000..4242f64 --- /dev/null +++ b/tests/mlia/test_resources/backends/systems/system6/aiet-config.json @@ -0,0 +1,34 @@ +[ + { + "name": "System 6", + "description": "This is system 6", + "build_dir": "build", + "data_transfer": { + "protocol": "local" + }, + "variables": { + "var1": "{user_params:sys-param1}" + }, + "commands": { + "run": [ + "echo {application.name}", + "{application.commands.run:0}" + ] + }, + "user_params": { + "run": [ + { + "name": "--sys-param1", + "description": "Test parameter", + "values": [ + "value1", + "value2", + "value3" + ], + "default_value": "value1", + "alias": "sys-param1" + } + ] + } + } +] diff --git a/tests/mlia/test_resources/backends/systems/system6/aiet-config.json.license b/tests/mlia/test_resources/backends/systems/system6/aiet-config.json.license new file mode 100644 index 0000000..9b83bfc --- /dev/null +++ b/tests/mlia/test_resources/backends/systems/system6/aiet-config.json.license @@ -0,0 +1,3 @@ +SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. + +SPDX-License-Identifier: Apache-2.0 diff --git a/tests/mlia/test_resources/hello_world.json b/tests/mlia/test_resources/hello_world.json new file mode 100644 index 0000000..8a9a448 --- /dev/null +++ b/tests/mlia/test_resources/hello_world.json @@ -0,0 +1,54 @@ +[ + { + "name": "Hello world", + "description": "Dummy application that displays 'Hello world!'", + "supported_systems": [ + "Dummy System" + ], + "build_dir": "build", + "deploy_data": [ + [ + "src", + "/tmp/" + ], + [ + "README", + "/tmp/README.md" + ] + ], + "commands": { + "clean": [], + "build": [], + "run": [ + "echo 'Hello world!'", + "ls -l /tmp" + ], + "post_run": [] + }, + "user_params": { + "run": [ + { + "name": "--choice-param", + "values": [ + "dummy_value_1", + "dummy_value_2" + ], + "default_value": "dummy_value_1", + "description": "Choice param" + }, + { + "name": "--open-param", + "values": [], + "default_value": "dummy_value_4", + "description": "Open param" + }, + { + "name": "--enable-flag", + "default_value": "dummy_value_4", + "description": "Flag param" + } + ], + "build": [] + } + } +] diff --git a/tests/mlia/test_resources/hello_world.json.license b/tests/mlia/test_resources/hello_world.json.license new file mode 100644 index 0000000..9b83bfc --- /dev/null +++ b/tests/mlia/test_resources/hello_world.json.license @@ -0,0 +1,3 @@ +SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. + +SPDX-License-Identifier: Apache-2.0 diff --git a/tests/mlia/test_resources/scripts/test_backend_run b/tests/mlia/test_resources/scripts/test_backend_run new file mode 100755 index 0000000..548f577 --- /dev/null +++ b/tests/mlia/test_resources/scripts/test_backend_run @@ -0,0 +1,8 @@ +#!/bin/bash + +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 + +echo "Hello from script" +>&2 echo "Oops!" +sleep 100 diff --git a/tests/mlia/test_resources/scripts/test_backend_run_script.sh b/tests/mlia/test_resources/scripts/test_backend_run_script.sh new file mode 100644 index 0000000..548f577 --- /dev/null +++ b/tests/mlia/test_resources/scripts/test_backend_run_script.sh @@ -0,0 +1,8 @@ +#!/bin/bash + +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. 
+# SPDX-License-Identifier: Apache-2.0 + +echo "Hello from script" +>&2 echo "Oops!" +sleep 100 diff --git a/tests/mlia/test_resources/various/applications/application_with_empty_config/aiet-config.json b/tests/mlia/test_resources/various/applications/application_with_empty_config/aiet-config.json new file mode 100644 index 0000000..fe51488 --- /dev/null +++ b/tests/mlia/test_resources/various/applications/application_with_empty_config/aiet-config.json @@ -0,0 +1 @@ +[] diff --git a/tests/mlia/test_resources/various/applications/application_with_empty_config/aiet-config.json.license b/tests/mlia/test_resources/various/applications/application_with_empty_config/aiet-config.json.license new file mode 100644 index 0000000..9b83bfc --- /dev/null +++ b/tests/mlia/test_resources/various/applications/application_with_empty_config/aiet-config.json.license @@ -0,0 +1,3 @@ +SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. + +SPDX-License-Identifier: Apache-2.0 diff --git a/tests/mlia/test_resources/various/applications/application_with_valid_config/aiet-config.json b/tests/mlia/test_resources/various/applications/application_with_valid_config/aiet-config.json new file mode 100644 index 0000000..ff1cf1a --- /dev/null +++ b/tests/mlia/test_resources/various/applications/application_with_valid_config/aiet-config.json @@ -0,0 +1,35 @@ +[ + { + "name": "test_application", + "description": "This is test_application", + "build_dir": "build", + "supported_systems": [ + { + "name": "System 4" + } + ], + "commands": { + "build": [ + "cp ../hello_app.txt ." + ], + "run": [ + "{application.build_dir}/hello_app.txt" + ] + }, + "user_params": { + "build": [ + { + "name": "--app", + "description": "Sample command param", + "values": [ + "application1", + "application2", + "application3" + ], + "default_value": "application1" + } + ], + "run": [] + } + } +] diff --git a/tests/mlia/test_resources/various/applications/application_with_valid_config/aiet-config.json.license b/tests/mlia/test_resources/various/applications/application_with_valid_config/aiet-config.json.license new file mode 100644 index 0000000..9b83bfc --- /dev/null +++ b/tests/mlia/test_resources/various/applications/application_with_valid_config/aiet-config.json.license @@ -0,0 +1,3 @@ +SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. + +SPDX-License-Identifier: Apache-2.0 diff --git a/tests/mlia/test_resources/various/applications/application_with_wrong_config1/aiet-config.json b/tests/mlia/test_resources/various/applications/application_with_wrong_config1/aiet-config.json new file mode 100644 index 0000000..724b31b --- /dev/null +++ b/tests/mlia/test_resources/various/applications/application_with_wrong_config1/aiet-config.json @@ -0,0 +1,2 @@ +This is not valid json file +{ diff --git a/tests/mlia/test_resources/various/applications/application_with_wrong_config1/aiet-config.json.license b/tests/mlia/test_resources/various/applications/application_with_wrong_config1/aiet-config.json.license new file mode 100644 index 0000000..9b83bfc --- /dev/null +++ b/tests/mlia/test_resources/various/applications/application_with_wrong_config1/aiet-config.json.license @@ -0,0 +1,3 @@ +SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. 
+ +SPDX-License-Identifier: Apache-2.0 diff --git a/tests/mlia/test_resources/various/applications/application_with_wrong_config2/aiet-config.json b/tests/mlia/test_resources/various/applications/application_with_wrong_config2/aiet-config.json new file mode 100644 index 0000000..1ebb29c --- /dev/null +++ b/tests/mlia/test_resources/various/applications/application_with_wrong_config2/aiet-config.json @@ -0,0 +1,30 @@ +[ + { + "name": "test_application", + "description": "This is test_application", + "build_dir": "build", + "commands": { + "build": [ + "cp ../hello_app.txt ." + ], + "run": [ + "{application.build_dir}/hello_app.txt" + ] + }, + "user_params": { + "build": [ + { + "name": "--app", + "description": "Sample command param", + "values": [ + "application1", + "application2", + "application3" + ], + "default_value": "application1" + } + ], + "run": [] + } + } +] diff --git a/tests/mlia/test_resources/various/applications/application_with_wrong_config2/aiet-config.json.license b/tests/mlia/test_resources/various/applications/application_with_wrong_config2/aiet-config.json.license new file mode 100644 index 0000000..9b83bfc --- /dev/null +++ b/tests/mlia/test_resources/various/applications/application_with_wrong_config2/aiet-config.json.license @@ -0,0 +1,3 @@ +SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. + +SPDX-License-Identifier: Apache-2.0 diff --git a/tests/mlia/test_resources/various/applications/application_with_wrong_config3/aiet-config.json b/tests/mlia/test_resources/various/applications/application_with_wrong_config3/aiet-config.json new file mode 100644 index 0000000..410d12d --- /dev/null +++ b/tests/mlia/test_resources/various/applications/application_with_wrong_config3/aiet-config.json @@ -0,0 +1,35 @@ +[ + { + "name": "test_application", + "description": "This is test_application", + "build_dir": "build", + "supported_systems": [ + { + "anme": "System 4" + } + ], + "commands": { + "build": [ + "cp ../hello_app.txt ." + ], + "run": [ + "{application.build_dir}/hello_app.txt" + ] + }, + "user_params": { + "build": [ + { + "name": "--app", + "description": "Sample command param", + "values": [ + "application1", + "application2", + "application3" + ], + "default_value": "application1" + } + ], + "run": [] + } + } +] diff --git a/tests/mlia/test_resources/various/applications/application_with_wrong_config3/aiet-config.json.license b/tests/mlia/test_resources/various/applications/application_with_wrong_config3/aiet-config.json.license new file mode 100644 index 0000000..9b83bfc --- /dev/null +++ b/tests/mlia/test_resources/various/applications/application_with_wrong_config3/aiet-config.json.license @@ -0,0 +1,3 @@ +SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. 
+ +SPDX-License-Identifier: Apache-2.0 diff --git a/tests/mlia/test_resources/various/systems/system_with_empty_config/aiet-config.json b/tests/mlia/test_resources/various/systems/system_with_empty_config/aiet-config.json new file mode 100644 index 0000000..fe51488 --- /dev/null +++ b/tests/mlia/test_resources/various/systems/system_with_empty_config/aiet-config.json @@ -0,0 +1 @@ +[] diff --git a/tests/mlia/test_resources/various/systems/system_with_empty_config/aiet-config.json.license b/tests/mlia/test_resources/various/systems/system_with_empty_config/aiet-config.json.license new file mode 100644 index 0000000..9b83bfc --- /dev/null +++ b/tests/mlia/test_resources/various/systems/system_with_empty_config/aiet-config.json.license @@ -0,0 +1,3 @@ +SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. + +SPDX-License-Identifier: Apache-2.0 diff --git a/tests/mlia/test_resources/various/systems/system_with_valid_config/aiet-config.json b/tests/mlia/test_resources/various/systems/system_with_valid_config/aiet-config.json new file mode 100644 index 0000000..20142e9 --- /dev/null +++ b/tests/mlia/test_resources/various/systems/system_with_valid_config/aiet-config.json @@ -0,0 +1,16 @@ +[ + { + "name": "Test system", + "description": "This is a test system", + "build_dir": "build", + "data_transfer": { + "protocol": "local" + }, + "commands": { + "run": [] + }, + "user_params": { + "run": [] + } + } +] diff --git a/tests/mlia/test_resources/various/systems/system_with_valid_config/aiet-config.json.license b/tests/mlia/test_resources/various/systems/system_with_valid_config/aiet-config.json.license new file mode 100644 index 0000000..9b83bfc --- /dev/null +++ b/tests/mlia/test_resources/various/systems/system_with_valid_config/aiet-config.json.license @@ -0,0 +1,3 @@ +SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. + +SPDX-License-Identifier: Apache-2.0 diff --git a/tests/mlia/test_tools_aiet_wrapper.py b/tests/mlia/test_tools_aiet_wrapper.py deleted file mode 100644 index ab55b71..0000000 --- a/tests/mlia/test_tools_aiet_wrapper.py +++ /dev/null @@ -1,760 +0,0 @@ -# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. 
-# SPDX-License-Identifier: Apache-2.0 -"""Tests for module tools/aiet_wrapper.""" -from contextlib import ExitStack as does_not_raise -from pathlib import Path -from typing import Any -from typing import Dict -from typing import List -from typing import Optional -from typing import Tuple -from unittest.mock import MagicMock -from unittest.mock import PropertyMock - -import pytest - -from mlia.tools.aiet_wrapper import AIETRunner -from mlia.tools.aiet_wrapper import DeviceInfo -from mlia.tools.aiet_wrapper import estimate_performance -from mlia.tools.aiet_wrapper import ExecutionParams -from mlia.tools.aiet_wrapper import GenericInferenceOutputParser -from mlia.tools.aiet_wrapper import GenericInferenceRunnerEthosU -from mlia.tools.aiet_wrapper import get_aiet_runner -from mlia.tools.aiet_wrapper import get_generic_runner -from mlia.tools.aiet_wrapper import get_system_name -from mlia.tools.aiet_wrapper import is_supported -from mlia.tools.aiet_wrapper import ModelInfo -from mlia.tools.aiet_wrapper import PerformanceMetrics -from mlia.tools.aiet_wrapper import supported_backends -from mlia.utils.proc import RunningCommand - - -@pytest.mark.parametrize( - "data, is_ready, result, missed_keys", - [ - ( - [], - False, - {}, - [ - "npu_active_cycles", - "npu_axi0_rd_data_beat_received", - "npu_axi0_wr_data_beat_written", - "npu_axi1_rd_data_beat_received", - "npu_idle_cycles", - "npu_total_cycles", - ], - ), - ( - ["sample text"], - False, - {}, - [ - "npu_active_cycles", - "npu_axi0_rd_data_beat_received", - "npu_axi0_wr_data_beat_written", - "npu_axi1_rd_data_beat_received", - "npu_idle_cycles", - "npu_total_cycles", - ], - ), - ( - [ - ["NPU AXI0_RD_DATA_BEAT_RECEIVED beats: 123"], - False, - {"npu_axi0_rd_data_beat_received": 123}, - [ - "npu_active_cycles", - "npu_axi0_wr_data_beat_written", - "npu_axi1_rd_data_beat_received", - "npu_idle_cycles", - "npu_total_cycles", - ], - ] - ), - ( - [ - "NPU AXI0_RD_DATA_BEAT_RECEIVED beats: 1", - "NPU AXI0_WR_DATA_BEAT_WRITTEN beats: 2", - "NPU AXI1_RD_DATA_BEAT_RECEIVED beats: 3", - "NPU ACTIVE cycles: 4", - "NPU IDLE cycles: 5", - "NPU TOTAL cycles: 6", - ], - True, - { - "npu_axi0_rd_data_beat_received": 1, - "npu_axi0_wr_data_beat_written": 2, - "npu_axi1_rd_data_beat_received": 3, - "npu_active_cycles": 4, - "npu_idle_cycles": 5, - "npu_total_cycles": 6, - }, - [], - ), - ], -) -def test_generic_inference_output_parser( - data: List[str], is_ready: bool, result: Dict, missed_keys: List[str] -) -> None: - """Test generic runner output parser.""" - parser = GenericInferenceOutputParser() - - for line in data: - parser.feed(line) - - assert parser.is_ready() == is_ready - assert parser.result == result - assert parser.missed_keys() == missed_keys - - -class TestAIETRunner: - """Tests for AIETRunner class.""" - - @staticmethod - def _setup_aiet( - monkeypatch: pytest.MonkeyPatch, - available_systems: Optional[List[str]] = None, - available_apps: Optional[List[str]] = None, - ) -> None: - """Set up AIET metadata.""" - - def mock_system(system: str) -> MagicMock: - """Mock the System instance.""" - mock = MagicMock() - type(mock).name = PropertyMock(return_value=system) - return mock - - def mock_app(app: str) -> MagicMock: - """Mock the Application instance.""" - mock = MagicMock() - type(mock).name = PropertyMock(return_value=app) - mock.can_run_on.return_value = True - return mock - - system_mocks = [mock_system(name) for name in (available_systems or [])] - monkeypatch.setattr( - "mlia.tools.aiet_wrapper.get_available_systems", - 
MagicMock(return_value=system_mocks), - ) - - apps_mock = [mock_app(name) for name in (available_apps or [])] - monkeypatch.setattr( - "mlia.tools.aiet_wrapper.get_available_applications", - MagicMock(return_value=apps_mock), - ) - - @pytest.mark.parametrize( - "available_systems, system, installed", - [ - ([], "system1", False), - (["system1", "system2"], "system1", True), - ], - ) - def test_is_system_installed( - self, - available_systems: List, - system: str, - installed: bool, - monkeypatch: pytest.MonkeyPatch, - ) -> None: - """Test method is_system_installed.""" - mock_executor = MagicMock() - aiet_runner = AIETRunner(mock_executor) - - self._setup_aiet(monkeypatch, available_systems) - - assert aiet_runner.is_system_installed(system) == installed - mock_executor.assert_not_called() - - @pytest.mark.parametrize( - "available_systems, systems", - [ - ([], []), - (["system1"], ["system1"]), - ], - ) - def test_installed_systems( - self, - available_systems: List[str], - systems: List[str], - monkeypatch: pytest.MonkeyPatch, - ) -> None: - """Test method installed_systems.""" - mock_executor = MagicMock() - aiet_runner = AIETRunner(mock_executor) - - self._setup_aiet(monkeypatch, available_systems) - assert aiet_runner.get_installed_systems() == systems - - mock_executor.assert_not_called() - - @staticmethod - def test_install_system(monkeypatch: pytest.MonkeyPatch) -> None: - """Test system installation.""" - install_system_mock = MagicMock() - monkeypatch.setattr( - "mlia.tools.aiet_wrapper.install_system", install_system_mock - ) - - mock_executor = MagicMock() - aiet_runner = AIETRunner(mock_executor) - aiet_runner.install_system(Path("test_system_path")) - - install_system_mock.assert_called_once_with(Path("test_system_path")) - mock_executor.assert_not_called() - - @pytest.mark.parametrize( - "available_systems, systems, expected_result", - [ - ([], [], False), - (["system1"], [], False), - (["system1"], ["system1"], True), - (["system1", "system2"], ["system1", "system3"], False), - (["system1", "system2"], ["system1", "system2"], True), - ], - ) - def test_systems_installed( - self, - available_systems: List[str], - systems: List[str], - expected_result: bool, - monkeypatch: pytest.MonkeyPatch, - ) -> None: - """Test method systems_installed.""" - self._setup_aiet(monkeypatch, available_systems) - - mock_executor = MagicMock() - aiet_runner = AIETRunner(mock_executor) - - assert aiet_runner.systems_installed(systems) is expected_result - - mock_executor.assert_not_called() - - @pytest.mark.parametrize( - "available_apps, applications, expected_result", - [ - ([], [], False), - (["app1"], [], False), - (["app1"], ["app1"], True), - (["app1", "app2"], ["app1", "app3"], False), - (["app1", "app2"], ["app1", "app2"], True), - ], - ) - def test_applications_installed( - self, - available_apps: List[str], - applications: List[str], - expected_result: bool, - monkeypatch: pytest.MonkeyPatch, - ) -> None: - """Test method applications_installed.""" - self._setup_aiet(monkeypatch, [], available_apps) - mock_executor = MagicMock() - aiet_runner = AIETRunner(mock_executor) - - assert aiet_runner.applications_installed(applications) is expected_result - mock_executor.assert_not_called() - - @pytest.mark.parametrize( - "available_apps, applications", - [ - ([], []), - ( - ["application1", "application2"], - ["application1", "application2"], - ), - ], - ) - def test_get_installed_applications( - self, - available_apps: List[str], - applications: List[str], - monkeypatch: pytest.MonkeyPatch, 
- ) -> None: - """Test method get_installed_applications.""" - mock_executor = MagicMock() - self._setup_aiet(monkeypatch, [], available_apps) - - aiet_runner = AIETRunner(mock_executor) - assert applications == aiet_runner.get_installed_applications() - - mock_executor.assert_not_called() - - @staticmethod - def test_install_application(monkeypatch: pytest.MonkeyPatch) -> None: - """Test application installation.""" - mock_install_application = MagicMock() - monkeypatch.setattr( - "mlia.tools.aiet_wrapper.install_application", mock_install_application - ) - - mock_executor = MagicMock() - - aiet_runner = AIETRunner(mock_executor) - aiet_runner.install_application(Path("test_application_path")) - mock_install_application.assert_called_once_with(Path("test_application_path")) - - mock_executor.assert_not_called() - - @pytest.mark.parametrize( - "available_apps, application, installed", - [ - ([], "system1", False), - ( - ["application1", "application2"], - "application1", - True, - ), - ( - [], - "application1", - False, - ), - ], - ) - def test_is_application_installed( - self, - available_apps: List[str], - application: str, - installed: bool, - monkeypatch: pytest.MonkeyPatch, - ) -> None: - """Test method is_application_installed.""" - self._setup_aiet(monkeypatch, [], available_apps) - - mock_executor = MagicMock() - aiet_runner = AIETRunner(mock_executor) - assert installed == aiet_runner.is_application_installed(application, "system1") - - mock_executor.assert_not_called() - - @staticmethod - @pytest.mark.parametrize( - "execution_params, expected_command", - [ - ( - ExecutionParams("application1", "system1", [], [], []), - ["aiet", "application", "run", "-n", "application1", "-s", "system1"], - ), - ( - ExecutionParams( - "application1", - "system1", - ["input_file=123.txt", "size=777"], - ["param1=456", "param2=789"], - ["source1.txt:dest1.txt", "source2.txt:dest2.txt"], - ), - [ - "aiet", - "application", - "run", - "-n", - "application1", - "-s", - "system1", - "-p", - "input_file=123.txt", - "-p", - "size=777", - "--system-param", - "param1=456", - "--system-param", - "param2=789", - "--deploy", - "source1.txt:dest1.txt", - "--deploy", - "source2.txt:dest2.txt", - ], - ), - ], - ) - def test_run_application( - execution_params: ExecutionParams, expected_command: List[str] - ) -> None: - """Test method run_application.""" - mock_executor = MagicMock() - mock_running_command = MagicMock() - mock_executor.submit.return_value = mock_running_command - - aiet_runner = AIETRunner(mock_executor) - aiet_runner.run_application(execution_params) - - mock_executor.submit.assert_called_once_with(expected_command) - - -@pytest.mark.parametrize( - "device, system, application, backend, expected_error", - [ - ( - DeviceInfo(device_type="ethos-u55", mac=32, memory_mode="Shared_Sram"), - ("Corstone-300: Cortex-M55+Ethos-U55", True), - ("Generic Inference Runner: Ethos-U55 SRAM", True), - "Corstone-300", - does_not_raise(), - ), - ( - DeviceInfo(device_type="ethos-u55", mac=32, memory_mode="Shared_Sram"), - ("Corstone-300: Cortex-M55+Ethos-U55", False), - ("Generic Inference Runner: Ethos-U55 SRAM", False), - "Corstone-300", - pytest.raises( - Exception, - match=r"System Corstone-300: Cortex-M55\+Ethos-U55 is not installed", - ), - ), - ( - DeviceInfo(device_type="ethos-u55", mac=32, memory_mode="Shared_Sram"), - ("Corstone-300: Cortex-M55+Ethos-U55", True), - ("Generic Inference Runner: Ethos-U55 SRAM", False), - "Corstone-300", - pytest.raises( - Exception, - match=r"Application Generic 
Inference Runner: Ethos-U55/65 Shared SRAM " - r"for the system Corstone-300: Cortex-M55\+Ethos-U55 is not installed", - ), - ), - ( - DeviceInfo(device_type="ethos-u55", mac=32, memory_mode="Shared_Sram"), - ("Corstone-310: Cortex-M85+Ethos-U55", True), - ("Generic Inference Runner: Ethos-U55 SRAM", True), - "Corstone-310", - does_not_raise(), - ), - ( - DeviceInfo(device_type="ethos-u55", mac=32, memory_mode="Shared_Sram"), - ("Corstone-310: Cortex-M85+Ethos-U55", False), - ("Generic Inference Runner: Ethos-U55 SRAM", False), - "Corstone-310", - pytest.raises( - Exception, - match=r"System Corstone-310: Cortex-M85\+Ethos-U55 is not installed", - ), - ), - ( - DeviceInfo(device_type="ethos-u55", mac=32, memory_mode="Shared_Sram"), - ("Corstone-310: Cortex-M85+Ethos-U55", True), - ("Generic Inference Runner: Ethos-U55 SRAM", False), - "Corstone-310", - pytest.raises( - Exception, - match=r"Application Generic Inference Runner: Ethos-U55/65 Shared SRAM " - r"for the system Corstone-310: Cortex-M85\+Ethos-U55 is not installed", - ), - ), - ( - DeviceInfo(device_type="ethos-u65", mac=512, memory_mode="Shared_Sram"), - ("Corstone-300: Cortex-M55+Ethos-U65", True), - ("Generic Inference Runner: Ethos-U65 Dedicated SRAM", True), - "Corstone-300", - does_not_raise(), - ), - ( - DeviceInfo(device_type="ethos-u65", mac=512, memory_mode="Shared_Sram"), - ("Corstone-300: Cortex-M55+Ethos-U65", False), - ("Generic Inference Runner: Ethos-U65 Dedicated SRAM", False), - "Corstone-300", - pytest.raises( - Exception, - match=r"System Corstone-300: Cortex-M55\+Ethos-U65 is not installed", - ), - ), - ( - DeviceInfo(device_type="ethos-u65", mac=512, memory_mode="Shared_Sram"), - ("Corstone-300: Cortex-M55+Ethos-U65", True), - ("Generic Inference Runner: Ethos-U65 Dedicated SRAM", False), - "Corstone-300", - pytest.raises( - Exception, - match=r"Application Generic Inference Runner: Ethos-U55/65 Shared SRAM " - r"for the system Corstone-300: Cortex-M55\+Ethos-U65 is not installed", - ), - ), - ( - DeviceInfo( - device_type="unknown_device", # type: ignore - mac=None, # type: ignore - memory_mode="Shared_Sram", - ), - ("some_system", False), - ("some_application", False), - "some backend", - pytest.raises(Exception, match="Unsupported device unknown_device"), - ), - ], -) -def test_estimate_performance( - device: DeviceInfo, - system: Tuple[str, bool], - application: Tuple[str, bool], - backend: str, - expected_error: Any, - test_tflite_model: Path, - aiet_runner: MagicMock, -) -> None: - """Test getting performance estimations.""" - system_name, system_installed = system - application_name, application_installed = application - - aiet_runner.is_system_installed.return_value = system_installed - aiet_runner.is_application_installed.return_value = application_installed - - mock_process = create_mock_process( - [ - "NPU AXI0_RD_DATA_BEAT_RECEIVED beats: 1", - "NPU AXI0_WR_DATA_BEAT_WRITTEN beats: 2", - "NPU AXI1_RD_DATA_BEAT_RECEIVED beats: 3", - "NPU ACTIVE cycles: 4", - "NPU IDLE cycles: 5", - "NPU TOTAL cycles: 6", - ], - [], - ) - - mock_generic_inference_run = RunningCommand(mock_process) - aiet_runner.run_application.return_value = mock_generic_inference_run - - with expected_error: - perf_metrics = estimate_performance( - ModelInfo(test_tflite_model), device, backend - ) - - assert isinstance(perf_metrics, PerformanceMetrics) - assert perf_metrics == PerformanceMetrics( - npu_axi0_rd_data_beat_received=1, - npu_axi0_wr_data_beat_written=2, - npu_axi1_rd_data_beat_received=3, - npu_active_cycles=4, - 
npu_idle_cycles=5, - npu_total_cycles=6, - ) - - assert aiet_runner.is_system_installed.called_once_with(system_name) - assert aiet_runner.is_application_installed.called_once_with( - application_name, system_name - ) - - -@pytest.mark.parametrize("backend", ("Corstone-300", "Corstone-310")) -def test_estimate_performance_insufficient_data( - aiet_runner: MagicMock, test_tflite_model: Path, backend: str -) -> None: - """Test that performance could not be estimated when not all data presented.""" - aiet_runner.is_system_installed.return_value = True - aiet_runner.is_application_installed.return_value = True - - no_total_cycles_output = [ - "NPU AXI0_RD_DATA_BEAT_RECEIVED beats: 1", - "NPU AXI0_WR_DATA_BEAT_WRITTEN beats: 2", - "NPU AXI1_RD_DATA_BEAT_RECEIVED beats: 3", - "NPU ACTIVE cycles: 4", - "NPU IDLE cycles: 5", - ] - mock_process = create_mock_process( - no_total_cycles_output, - [], - ) - - mock_generic_inference_run = RunningCommand(mock_process) - aiet_runner.run_application.return_value = mock_generic_inference_run - - with pytest.raises( - Exception, match="Unable to get performance metrics, insufficient data" - ): - device = DeviceInfo(device_type="ethos-u55", mac=32, memory_mode="Shared_Sram") - estimate_performance(ModelInfo(test_tflite_model), device, backend) - - -@pytest.mark.parametrize("backend", ("Corstone-300", "Corstone-310")) -def test_estimate_performance_invalid_output( - test_tflite_model: Path, aiet_runner: MagicMock, backend: str -) -> None: - """Test estimation could not be done if inference produces unexpected output.""" - aiet_runner.is_system_installed.return_value = True - aiet_runner.is_application_installed.return_value = True - - mock_process = create_mock_process( - ["Something", "is", "wrong"], ["What a nice error!"] - ) - aiet_runner.run_application.return_value = RunningCommand(mock_process) - - with pytest.raises(Exception, match="Unable to get performance metrics"): - estimate_performance( - ModelInfo(test_tflite_model), - DeviceInfo(device_type="ethos-u55", mac=256, memory_mode="Shared_Sram"), - backend=backend, - ) - - -def test_get_aiet_runner() -> None: - """Test getting aiet runner.""" - aiet_runner = get_aiet_runner() - assert isinstance(aiet_runner, AIETRunner) - - -def create_mock_process(stdout: List[str], stderr: List[str]) -> MagicMock: - """Mock underlying process.""" - mock_process = MagicMock() - mock_process.poll.return_value = 0 - type(mock_process).stdout = PropertyMock(return_value=iter(stdout)) - type(mock_process).stderr = PropertyMock(return_value=iter(stderr)) - return mock_process - - -@pytest.mark.parametrize("backend", ("Corstone-300", "Corstone-310")) -def test_get_generic_runner(backend: str) -> None: - """Test function get_generic_runner().""" - device_info = DeviceInfo("ethos-u55", 256, memory_mode="Shared_Sram") - - runner = get_generic_runner(device_info=device_info, backend=backend) - assert isinstance(runner, GenericInferenceRunnerEthosU) - - with pytest.raises(RuntimeError): - get_generic_runner(device_info=device_info, backend="UNKNOWN_BACKEND") - - -@pytest.mark.parametrize( - ("backend", "device_type"), - ( - ("Corstone-300", "ethos-u55"), - ("Corstone-300", "ethos-u65"), - ("Corstone-310", "ethos-u55"), - ), -) -def test_aiet_backend_support(backend: str, device_type: str) -> None: - """Test AIET backend & device support.""" - assert is_supported(backend) - assert is_supported(backend, device_type) - - assert get_system_name(backend, device_type) - - assert backend in supported_backends() - - -class 
TestGenericInferenceRunnerEthosU: - """Test for the class GenericInferenceRunnerEthosU.""" - - @staticmethod - @pytest.mark.parametrize( - "device, backend, expected_system, expected_app", - [ - [ - DeviceInfo("ethos-u55", 256, memory_mode="Shared_Sram"), - "Corstone-300", - "Corstone-300: Cortex-M55+Ethos-U55", - "Generic Inference Runner: Ethos-U55/65 Shared SRAM", - ], - [ - DeviceInfo("ethos-u55", 256, memory_mode="Shared_Sram"), - "Corstone-310", - "Corstone-310: Cortex-M85+Ethos-U55", - "Generic Inference Runner: Ethos-U55/65 Shared SRAM", - ], - [ - DeviceInfo("ethos-u55", 256, memory_mode="Sram"), - "Corstone-310", - "Corstone-310: Cortex-M85+Ethos-U55", - "Generic Inference Runner: Ethos-U55 SRAM", - ], - [ - DeviceInfo("ethos-u55", 256, memory_mode="Sram"), - "Corstone-300", - "Corstone-300: Cortex-M55+Ethos-U55", - "Generic Inference Runner: Ethos-U55 SRAM", - ], - [ - DeviceInfo("ethos-u65", 256, memory_mode="Shared_Sram"), - "Corstone-300", - "Corstone-300: Cortex-M55+Ethos-U65", - "Generic Inference Runner: Ethos-U55/65 Shared SRAM", - ], - [ - DeviceInfo("ethos-u65", 256, memory_mode="Dedicated_Sram"), - "Corstone-300", - "Corstone-300: Cortex-M55+Ethos-U65", - "Generic Inference Runner: Ethos-U65 Dedicated SRAM", - ], - ], - ) - def test_artifact_resolver( - device: DeviceInfo, backend: str, expected_system: str, expected_app: str - ) -> None: - """Test artifact resolving based on the provided parameters.""" - generic_runner = get_generic_runner(device, backend) - assert isinstance(generic_runner, GenericInferenceRunnerEthosU) - - assert generic_runner.system_name == expected_system - assert generic_runner.app_name == expected_app - - @staticmethod - def test_artifact_resolver_unsupported_backend() -> None: - """Test that it should be not possible to use unsupported backends.""" - with pytest.raises( - RuntimeError, match="Unsupported device ethos-u65 for backend test_backend" - ): - get_generic_runner( - DeviceInfo("ethos-u65", 256, memory_mode="Shared_Sram"), "test_backend" - ) - - @staticmethod - def test_artifact_resolver_unsupported_memory_mode() -> None: - """Test that it should be not possible to use unsupported memory modes.""" - with pytest.raises( - RuntimeError, match="Unsupported memory mode test_memory_mode" - ): - get_generic_runner( - DeviceInfo( - "ethos-u65", - 256, - memory_mode="test_memory_mode", # type: ignore - ), - "Corstone-300", - ) - - @staticmethod - @pytest.mark.parametrize("backend", ("Corstone-300", "Corstone-310")) - def test_inference_should_fail_if_system_not_installed( - aiet_runner: MagicMock, test_tflite_model: Path, backend: str - ) -> None: - """Test that inference should fail if system is not installed.""" - aiet_runner.is_system_installed.return_value = False - - generic_runner = get_generic_runner( - DeviceInfo("ethos-u55", 256, memory_mode="Shared_Sram"), backend - ) - with pytest.raises( - Exception, - match=r"System Corstone-3[01]0: Cortex-M[58]5\+Ethos-U55 is not installed", - ): - generic_runner.run(ModelInfo(test_tflite_model), []) - - @staticmethod - @pytest.mark.parametrize("backend", ("Corstone-300", "Corstone-310")) - def test_inference_should_fail_is_apps_not_installed( - aiet_runner: MagicMock, test_tflite_model: Path, backend: str - ) -> None: - """Test that inference should fail if apps are not installed.""" - aiet_runner.is_system_installed.return_value = True - aiet_runner.is_application_installed.return_value = False - - generic_runner = get_generic_runner( - DeviceInfo("ethos-u55", 256, memory_mode="Shared_Sram"), 
backend - ) - with pytest.raises( - Exception, - match="Application Generic Inference Runner: Ethos-U55/65 Shared SRAM" - r" for the system Corstone-3[01]0: Cortex-M[58]5\+Ethos-U55 is not " - r"installed", - ): - generic_runner.run(ModelInfo(test_tflite_model), []) - - -@pytest.fixture(name="aiet_runner") -def fixture_aiet_runner(monkeypatch: pytest.MonkeyPatch) -> MagicMock: - """Mock AIET runner.""" - aiet_runner_mock = MagicMock(spec=AIETRunner) - monkeypatch.setattr( - "mlia.tools.aiet_wrapper.get_aiet_runner", - MagicMock(return_value=aiet_runner_mock), - ) - return aiet_runner_mock diff --git a/tests/mlia/test_tools_metadata_corstone.py b/tests/mlia/test_tools_metadata_corstone.py index 2ce3610..017d0c7 100644 --- a/tests/mlia/test_tools_metadata_corstone.py +++ b/tests/mlia/test_tools_metadata_corstone.py @@ -9,13 +9,13 @@ from unittest.mock import MagicMock import pytest -from mlia.tools.aiet_wrapper import AIETRunner +from mlia.backend.manager import BackendRunner from mlia.tools.metadata.common import DownloadAndInstall from mlia.tools.metadata.common import InstallFromPath -from mlia.tools.metadata.corstone import AIETBasedInstallation -from mlia.tools.metadata.corstone import AIETMetadata from mlia.tools.metadata.corstone import BackendInfo +from mlia.tools.metadata.corstone import BackendInstallation from mlia.tools.metadata.corstone import BackendInstaller +from mlia.tools.metadata.corstone import BackendMetadata from mlia.tools.metadata.corstone import CompoundPathChecker from mlia.tools.metadata.corstone import Corstone300Installer from mlia.tools.metadata.corstone import get_corstone_installations @@ -40,8 +40,8 @@ def fixture_test_mlia_resources( return mlia_resources -def get_aiet_based_installation( # pylint: disable=too-many-arguments - aiet_runner_mock: MagicMock = MagicMock(), +def get_backend_installation( # pylint: disable=too-many-arguments + backend_runner_mock: MagicMock = MagicMock(), name: str = "test_name", description: str = "test_description", download_artifact: Optional[MagicMock] = None, @@ -50,11 +50,11 @@ def get_aiet_based_installation( # pylint: disable=too-many-arguments system_config: Optional[str] = None, backend_installer: BackendInstaller = MagicMock(), supported_platforms: Optional[List[str]] = None, -) -> AIETBasedInstallation: - """Get AIET based installation.""" - return AIETBasedInstallation( - aiet_runner=aiet_runner_mock, - metadata=AIETMetadata( +) -> BackendInstallation: + """Get backend installation.""" + return BackendInstallation( + backend_runner=backend_runner_mock, + metadata=BackendMetadata( name=name, description=description, system_config=system_config or "", @@ -90,10 +90,10 @@ def test_could_be_installed_depends_on_platform( monkeypatch.setattr( "mlia.tools.metadata.corstone.all_paths_valid", MagicMock(return_value=True) ) - aiet_runner_mock = MagicMock(spec=AIETRunner) + backend_runner_mock = MagicMock(spec=BackendRunner) - installation = get_aiet_based_installation( - aiet_runner_mock, + installation = get_backend_installation( + backend_runner_mock, supported_platforms=supported_platforms, ) assert installation.could_be_installed == expected_result @@ -103,53 +103,53 @@ def test_get_corstone_installations() -> None: """Test function get_corstone_installation.""" installs = get_corstone_installations() assert len(installs) == 2 - assert all(isinstance(install, AIETBasedInstallation) for install in installs) + assert all(isinstance(install, BackendInstallation) for install in installs) -def 
test_aiet_based_installation_metadata_resolving() -> None: - """Test AIET based installation metadata resolving.""" - aiet_runner_mock = MagicMock(spec=AIETRunner) - installation = get_aiet_based_installation(aiet_runner_mock) +def test_backend_installation_metadata_resolving() -> None: + """Test backend installation metadata resolving.""" + backend_runner_mock = MagicMock(spec=BackendRunner) + installation = get_backend_installation(backend_runner_mock) assert installation.name == "test_name" assert installation.description == "test_description" - aiet_runner_mock.all_installed.return_value = False + backend_runner_mock.all_installed.return_value = False assert installation.already_installed is False assert installation.could_be_installed is True -def test_aiet_based_installation_supported_install_types(tmp_path: Path) -> None: +def test_backend_installation_supported_install_types(tmp_path: Path) -> None: """Test supported installation types.""" - installation_no_download_artifact = get_aiet_based_installation() + installation_no_download_artifact = get_backend_installation() assert installation_no_download_artifact.supports(DownloadAndInstall()) is False - installation_with_download_artifact = get_aiet_based_installation( + installation_with_download_artifact = get_backend_installation( download_artifact=MagicMock() ) assert installation_with_download_artifact.supports(DownloadAndInstall()) is True path_checker_mock = MagicMock(return_value=BackendInfo(tmp_path)) - installation_can_install_from_dir = get_aiet_based_installation( + installation_can_install_from_dir = get_backend_installation( path_checker=path_checker_mock ) assert installation_can_install_from_dir.supports(InstallFromPath(tmp_path)) is True - any_installation = get_aiet_based_installation() + any_installation = get_backend_installation() assert any_installation.supports("unknown_install_type") is False # type: ignore -def test_aiet_based_installation_install_wrong_type() -> None: +def test_backend_installation_install_wrong_type() -> None: """Test that operation should fail if wrong install type provided.""" with pytest.raises(Exception, match="Unable to install wrong_install_type"): - aiet_runner_mock = MagicMock(spec=AIETRunner) - installation = get_aiet_based_installation(aiet_runner_mock) + backend_runner_mock = MagicMock(spec=BackendRunner) + installation = get_backend_installation(backend_runner_mock) installation.install("wrong_install_type") # type: ignore -def test_aiet_based_installation_install_from_path( +def test_backend_installation_install_from_path( tmp_path: Path, test_mlia_resources: Path ) -> None: """Test installation from the path.""" @@ -164,9 +164,9 @@ def test_aiet_based_installation_install_from_path( path_checker_mock = MagicMock(return_value=BackendInfo(dist_dir)) - aiet_runner_mock = MagicMock(spec=AIETRunner) - installation = get_aiet_based_installation( - aiet_runner_mock=aiet_runner_mock, + backend_runner_mock = MagicMock(spec=BackendRunner) + installation = get_backend_installation( + backend_runner_mock=backend_runner_mock, path_checker=path_checker_mock, apps_resources=[sample_app.name], system_config="example_config.json", @@ -175,12 +175,12 @@ def test_aiet_based_installation_install_from_path( assert installation.supports(InstallFromPath(dist_dir)) is True installation.install(InstallFromPath(dist_dir)) - aiet_runner_mock.install_system.assert_called_once() - aiet_runner_mock.install_application.assert_called_once_with(sample_app) + 
backend_runner_mock.install_system.assert_called_once() + backend_runner_mock.install_application.assert_called_once_with(sample_app) @pytest.mark.parametrize("copy_source", [True, False]) -def test_aiet_based_installation_install_from_static_path( +def test_backend_installation_install_from_static_path( tmp_path: Path, test_mlia_resources: Path, copy_source: bool ) -> None: """Test installation from the predefined path.""" @@ -204,7 +204,7 @@ def test_aiet_based_installation_install_from_static_path( nested_file = predefined_location_dir / "nested_file.txt" nested_file.touch() - aiet_runner_mock = MagicMock(spec=AIETRunner) + backend_runner_mock = MagicMock(spec=BackendRunner) def check_install_dir(install_dir: Path) -> None: """Check content of the install dir.""" @@ -220,10 +220,10 @@ def test_aiet_based_installation_install_from_static_path( assert install_dir / "custom_config.json" in files - aiet_runner_mock.install_system.side_effect = check_install_dir + backend_runner_mock.install_system.side_effect = check_install_dir - installation = get_aiet_based_installation( - aiet_runner_mock=aiet_runner_mock, + installation = get_backend_installation( + backend_runner_mock=backend_runner_mock, path_checker=StaticPathChecker( predefined_location, ["file.txt"], @@ -237,8 +237,8 @@ def test_aiet_based_installation_install_from_static_path( assert installation.supports(InstallFromPath(predefined_location)) is True installation.install(InstallFromPath(predefined_location)) - aiet_runner_mock.install_system.assert_called_once() - aiet_runner_mock.install_application.assert_called_once_with(sample_app) + backend_runner_mock.install_system.assert_called_once() + backend_runner_mock.install_application.assert_called_once_with(sample_app) def create_sample_fvp_archive(tmp_path: Path) -> Path: @@ -259,7 +259,7 @@ def create_sample_fvp_archive(tmp_path: Path) -> Path: return fvp_archive -def test_aiet_based_installation_download_and_install( +def test_backend_installation_download_and_install( test_mlia_resources: Path, tmp_path: Path ) -> None: """Test downloading and installation process.""" @@ -277,9 +277,9 @@ def test_aiet_based_installation_download_and_install( """Sample installer.""" return dist_dir - aiet_runner_mock = MagicMock(spec=AIETRunner) - installation = get_aiet_based_installation( - aiet_runner_mock, + backend_runner_mock = MagicMock(spec=BackendRunner) + installation = get_backend_installation( + backend_runner_mock, download_artifact=download_artifact_mock, backend_installer=installer, path_checker=path_checker, @@ -288,7 +288,7 @@ def test_aiet_based_installation_download_and_install( installation.install(DownloadAndInstall()) - aiet_runner_mock.install_system.assert_called_once() + backend_runner_mock.install_system.assert_called_once() @pytest.mark.parametrize( -- cgit v1.2.1
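
The relocated output-parser tests above feed log lines such as "NPU ACTIVE cycles: 4" into a parser and expect is_ready(), result and missed_keys() to report on six NPU counters. The snippet below is a minimal, self-contained sketch of that behaviour for illustration only: the counter names and log format are taken from the test fixtures in this patch, while the class name, regex and key-derivation logic are assumptions and do not reproduce the actual mlia.backend.output_parser implementation.

import re
from typing import Dict, List


class SketchOutputParser:
    """Illustrative parser collecting NPU counters from inference output.

    NOT the mlia implementation; a simplified stand-in based on the
    fixture data visible in the tests above.
    """

    # Matches lines like "NPU AXI0_RD_DATA_BEAT_RECEIVED beats: 123"
    # or "NPU ACTIVE cycles: 4" (format taken from the test fixtures).
    _pattern = re.compile(r"NPU ([A-Z0-9_]+) (?:beats|cycles): (\d+)")

    _expected_keys = [
        "npu_active_cycles",
        "npu_axi0_rd_data_beat_received",
        "npu_axi0_wr_data_beat_written",
        "npu_axi1_rd_data_beat_received",
        "npu_idle_cycles",
        "npu_total_cycles",
    ]

    def __init__(self) -> None:
        self.result: Dict[str, int] = {}

    def feed(self, line: str) -> None:
        """Parse one output line and record any counter it contains."""
        match = self._pattern.search(line)
        if not match:
            return
        name, value = match.groups()
        key = "npu_" + name.lower()
        # Cycle counters carry the unit in the line, not in the name.
        if name in ("ACTIVE", "IDLE", "TOTAL"):
            key += "_cycles"
        self.result[key] = int(value)

    def is_ready(self) -> bool:
        """Return True once all expected counters have been seen."""
        return not self.missed_keys()

    def missed_keys(self) -> List[str]:
        """Return the counters that have not been parsed yet."""
        return sorted(set(self._expected_keys) - set(self.result))


# Example usage mirroring the fixture data used in the tests:
parser = SketchOutputParser()
for line in (
    "NPU AXI0_RD_DATA_BEAT_RECEIVED beats: 1",
    "NPU ACTIVE cycles: 4",
):
    parser.feed(line)
assert parser.result["npu_active_cycles"] == 4
assert not parser.is_ready()  # four counters are still missing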