author     Diego Russo <diego.russo@arm.com>  2022-05-30 13:34:14 +0100
committer  Diego Russo <diego.russo@arm.com>  2022-05-30 13:34:14 +0100
commit     0efca3cadbad5517a59884576ddb90cfe7ac30f8 (patch)
tree       abed6cb6fbf3c439fc8d947f505b6a53d5daeb1e
parent     0777092695c143c3a54680b5748287d40c914c35 (diff)
download   mlia-0efca3cadbad5517a59884576ddb90cfe7ac30f8.tar.gz

Add MLIA codebase (tag: 0.3.0-rc.1)

Add MLIA codebase including sources and tests.

Change-Id: Id41707559bd721edd114793618d12ccd188d8dbd
-rw-r--r--  .gitignore  15
-rw-r--r--  LICENSES/Apache-2.0.txt  177
-rw-r--r--  LICENSES/BSD-3-Clause.txt  26
-rw-r--r--  LICENSES/CC-PDDC.txt  28
-rw-r--r--  LICENSES/MIT.txt  21
-rw-r--r--  MANIFEST.in  5
-rw-r--r--  README.md  312
-rw-r--r--  pyproject.toml  78
-rw-r--r--  setup.cfg  79
-rw-r--r--  setup.py  7
-rw-r--r--  src/aiet/__init__.py  7
-rw-r--r--  src/aiet/backend/__init__.py  3
-rw-r--r--  src/aiet/backend/application.py  187
-rw-r--r--  src/aiet/backend/common.py  532
-rw-r--r--  src/aiet/backend/config.py  107
-rw-r--r--  src/aiet/backend/controller.py  134
-rw-r--r--  src/aiet/backend/execution.py  859
-rw-r--r--  src/aiet/backend/output_parser.py  176
-rw-r--r--  src/aiet/backend/protocol.py  325
-rw-r--r--  src/aiet/backend/source.py  209
-rw-r--r--  src/aiet/backend/system.py  289
-rw-r--r--  src/aiet/backend/tool.py  109
-rw-r--r--  src/aiet/cli/__init__.py  28
-rw-r--r--  src/aiet/cli/application.py  362
-rw-r--r--  src/aiet/cli/common.py  173
-rw-r--r--  src/aiet/cli/completion.py  72
-rw-r--r--  src/aiet/cli/system.py  122
-rw-r--r--  src/aiet/cli/tool.py  143
-rw-r--r--  src/aiet/main.py  13
-rw-r--r--  src/aiet/resources/applications/.gitignore  6
-rw-r--r--  src/aiet/resources/systems/.gitignore  6
-rw-r--r--  src/aiet/resources/tools/vela/aiet-config.json  73
-rw-r--r--  src/aiet/resources/tools/vela/aiet-config.json.license  3
-rw-r--r--  src/aiet/resources/tools/vela/check_model.py  75
-rw-r--r--  src/aiet/resources/tools/vela/run_vela.py  65
-rw-r--r--  src/aiet/resources/tools/vela/vela.ini  53
-rw-r--r--  src/aiet/utils/__init__.py  3
-rw-r--r--  src/aiet/utils/fs.py  116
-rw-r--r--  src/aiet/utils/helpers.py  17
-rw-r--r--  src/aiet/utils/proc.py  283
-rw-r--r--  src/mlia/__init__.py  22
-rw-r--r--  src/mlia/api.py  162
-rw-r--r--  src/mlia/cli/__init__.py  3
-rw-r--r--  src/mlia/cli/commands.py  276
-rw-r--r--  src/mlia/cli/common.py  38
-rw-r--r--  src/mlia/cli/config.py  64
-rw-r--r--  src/mlia/cli/helpers.py  116
-rw-r--r--  src/mlia/cli/logging.py  117
-rw-r--r--  src/mlia/cli/main.py  280
-rw-r--r--  src/mlia/cli/options.py  280
-rw-r--r--  src/mlia/core/__init__.py  21
-rw-r--r--  src/mlia/core/_typing.py  12
-rw-r--r--  src/mlia/core/advice_generation.py  106
-rw-r--r--  src/mlia/core/advisor.py  21
-rw-r--r--  src/mlia/core/common.py  47
-rw-r--r--  src/mlia/core/context.py  218
-rw-r--r--  src/mlia/core/data_analysis.py  70
-rw-r--r--  src/mlia/core/data_collection.py  37
-rw-r--r--  src/mlia/core/errors.py  18
-rw-r--r--  src/mlia/core/events.py  455
-rw-r--r--  src/mlia/core/helpers.py  38
-rw-r--r--  src/mlia/core/mixins.py  54
-rw-r--r--  src/mlia/core/performance.py  47
-rw-r--r--  src/mlia/core/reporting.py  762
-rw-r--r--  src/mlia/core/workflow.py  216
-rw-r--r--  src/mlia/devices/__init__.py  3
-rw-r--r--  src/mlia/devices/config.py  11
-rw-r--r--  src/mlia/devices/ethosu/__init__.py  3
-rw-r--r--  src/mlia/devices/ethosu/advice_generation.py  209
-rw-r--r--  src/mlia/devices/ethosu/advisor.py  151
-rw-r--r--  src/mlia/devices/ethosu/config.py  89
-rw-r--r--  src/mlia/devices/ethosu/data_analysis.py  154
-rw-r--r--  src/mlia/devices/ethosu/data_collection.py  188
-rw-r--r--  src/mlia/devices/ethosu/events.py  24
-rw-r--r--  src/mlia/devices/ethosu/handlers.py  146
-rw-r--r--  src/mlia/devices/ethosu/operators.py  14
-rw-r--r--  src/mlia/devices/ethosu/performance.py  257
-rw-r--r--  src/mlia/devices/ethosu/reporters.py  398
-rw-r--r--  src/mlia/nn/__init__.py  3
-rw-r--r--  src/mlia/nn/tensorflow/__init__.py  3
-rw-r--r--  src/mlia/nn/tensorflow/config.py  134
-rw-r--r--  src/mlia/nn/tensorflow/optimizations/__init__.py  3
-rw-r--r--  src/mlia/nn/tensorflow/optimizations/clustering.py  109
-rw-r--r--  src/mlia/nn/tensorflow/optimizations/common.py  29
-rw-r--r--  src/mlia/nn/tensorflow/optimizations/pruning.py  168
-rw-r--r--  src/mlia/nn/tensorflow/optimizations/select.py  179
-rw-r--r--  src/mlia/nn/tensorflow/tflite_metrics.py  296
-rw-r--r--  src/mlia/nn/tensorflow/utils.py  149
-rw-r--r--  src/mlia/resources/aiet/applications/APPLICATIONS.txt  6
-rw-r--r--  src/mlia/resources/aiet/applications/inference_runner-sse-300-22.05.01-ethos-U55-Shared_Sram-TA/aiet-config.json  18
-rw-r--r--  src/mlia/resources/aiet/applications/inference_runner-sse-300-22.05.01-ethos-U55-Shared_Sram-TA/aiet-config.json.license  3
-rw-r--r--  src/mlia/resources/aiet/applications/inference_runner-sse-300-22.05.01-ethos-U55-Shared_Sram-TA/ethos-u-inference_runner.axf  bin 0 -> 426496 bytes
-rw-r--r--  src/mlia/resources/aiet/applications/inference_runner-sse-300-22.05.01-ethos-U55-Shared_Sram-TA/ethos-u-inference_runner.axf.license  31
-rw-r--r--  src/mlia/resources/aiet/applications/inference_runner-sse-300-22.05.01-ethos-U55-Sram_Only-TA/aiet-config.json  15
-rw-r--r--  src/mlia/resources/aiet/applications/inference_runner-sse-300-22.05.01-ethos-U55-Sram_Only-TA/aiet-config.json.license  3
-rw-r--r--  src/mlia/resources/aiet/applications/inference_runner-sse-300-22.05.01-ethos-U55-Sram_Only-TA/ethos-u-inference_runner.axf  bin 0 -> 426544 bytes
-rw-r--r--  src/mlia/resources/aiet/applications/inference_runner-sse-300-22.05.01-ethos-U55-Sram_Only-TA/ethos-u-inference_runner.axf.license  31
-rw-r--r--  src/mlia/resources/aiet/applications/inference_runner-sse-300-22.05.01-ethos-U65-Dedicated_Sram-TA/aiet-config.json  15
-rw-r--r--  src/mlia/resources/aiet/applications/inference_runner-sse-300-22.05.01-ethos-U65-Dedicated_Sram-TA/aiet-config.json.license  3
-rw-r--r--  src/mlia/resources/aiet/applications/inference_runner-sse-300-22.05.01-ethos-U65-Dedicated_Sram-TA/ethos-u-inference_runner.axf  bin 0 -> 2524028 bytes
-rw-r--r--  src/mlia/resources/aiet/applications/inference_runner-sse-300-22.05.01-ethos-U65-Dedicated_Sram-TA/ethos-u-inference_runner.axf.license  31
-rw-r--r--  src/mlia/resources/aiet/applications/inference_runner-sse-310-22.05.01-ethos-U55-Shared_Sram-TA/aiet-config.json  15
-rw-r--r--  src/mlia/resources/aiet/applications/inference_runner-sse-310-22.05.01-ethos-U55-Shared_Sram-TA/aiet-config.json.license  3
-rw-r--r--  src/mlia/resources/aiet/applications/inference_runner-sse-310-22.05.01-ethos-U55-Shared_Sram-TA/ethos-u-inference_runner.axf  bin 0 -> 426488 bytes
-rw-r--r--  src/mlia/resources/aiet/applications/inference_runner-sse-310-22.05.01-ethos-U55-Shared_Sram-TA/ethos-u-inference_runner.axf.license  31
-rw-r--r--  src/mlia/resources/aiet/applications/inference_runner-sse-310-22.05.01-ethos-U55-Sram_Only-TA/aiet-config.json  15
-rw-r--r--  src/mlia/resources/aiet/applications/inference_runner-sse-310-22.05.01-ethos-U55-Sram_Only-TA/aiet-config.json.license  3
-rw-r--r--  src/mlia/resources/aiet/applications/inference_runner-sse-310-22.05.01-ethos-U55-Sram_Only-TA/ethos-u-inference_runner.axf  bin 0 -> 426536 bytes
-rw-r--r--  src/mlia/resources/aiet/applications/inference_runner-sse-310-22.05.01-ethos-U55-Sram_Only-TA/ethos-u-inference_runner.axf.license  31
-rw-r--r--  src/mlia/resources/aiet/systems/SYSTEMS.txt  10
-rw-r--r--  src/mlia/resources/aiet/systems/corstone-300-vht/aiet-config.json  80
-rw-r--r--  src/mlia/resources/aiet/systems/corstone-300-vht/aiet-config.json.license  3
-rw-r--r--  src/mlia/resources/aiet/systems/corstone-300/aiet-config.json  80
-rw-r--r--  src/mlia/resources/aiet/systems/corstone-300/aiet-config.json.license  3
-rw-r--r--  src/mlia/resources/aiet/systems/corstone-310-vht/aiet-config.json  42
-rw-r--r--  src/mlia/resources/aiet/systems/corstone-310-vht/aiet-config.json.license  3
-rw-r--r--  src/mlia/resources/aiet/systems/corstone-310/aiet-config.json  42
-rw-r--r--  src/mlia/resources/aiet/systems/corstone-310/aiet-config.json.license  3
-rw-r--r--  src/mlia/resources/profiles.json  20
-rw-r--r--  src/mlia/resources/profiles.json.license  3
-rw-r--r--  src/mlia/resources/vela/vela.ini  75
-rw-r--r--  src/mlia/tools/__init__.py  3
-rw-r--r--  src/mlia/tools/aiet_wrapper.py  435
-rw-r--r--  src/mlia/tools/metadata/__init__.py  3
-rw-r--r--  src/mlia/tools/metadata/common.py  290
-rw-r--r--  src/mlia/tools/metadata/corstone.py  402
-rw-r--r--  src/mlia/tools/vela_wrapper.py  500
-rw-r--r--  src/mlia/utils/__init__.py  3
-rw-r--r--  src/mlia/utils/console.py  97
-rw-r--r--  src/mlia/utils/download.py  89
-rw-r--r--  src/mlia/utils/filesystem.py  124
-rw-r--r--  src/mlia/utils/logging.py  120
-rw-r--r--  src/mlia/utils/misc.py  9
-rw-r--r--  src/mlia/utils/proc.py  164
-rw-r--r--  src/mlia/utils/types.py  37
-rw-r--r--  tests/__init__.py  3
-rw-r--r--  tests/aiet/__init__.py  3
-rw-r--r--  tests/aiet/conftest.py  139
-rw-r--r--  tests/aiet/test_backend_application.py  452
-rw-r--r--  tests/aiet/test_backend_common.py  486
-rw-r--r--  tests/aiet/test_backend_controller.py  160
-rw-r--r--  tests/aiet/test_backend_execution.py  526
-rw-r--r--  tests/aiet/test_backend_output_parser.py  152
-rw-r--r--  tests/aiet/test_backend_protocol.py  231
-rw-r--r--  tests/aiet/test_backend_source.py  199
-rw-r--r--  tests/aiet/test_backend_system.py  536
-rw-r--r--  tests/aiet/test_backend_tool.py  60
-rw-r--r--  tests/aiet/test_check_model.py  162
-rw-r--r--  tests/aiet/test_cli.py  37
-rw-r--r--  tests/aiet/test_cli_application.py  1153
-rw-r--r--  tests/aiet/test_cli_common.py  37
-rw-r--r--  tests/aiet/test_cli_system.py  240
-rw-r--r--  tests/aiet/test_cli_tool.py  333
-rw-r--r--  tests/aiet/test_main.py  16
-rw-r--r--  tests/aiet/test_resources/application_config.json  96
-rw-r--r--  tests/aiet/test_resources/application_config.json.license  3
-rw-r--r--  tests/aiet/test_resources/applications/application1/aiet-config.json  30
-rw-r--r--  tests/aiet/test_resources/applications/application1/aiet-config.json.license  3
-rw-r--r--  tests/aiet/test_resources/applications/application2/aiet-config.json  30
-rw-r--r--  tests/aiet/test_resources/applications/application2/aiet-config.json.license  3
-rw-r--r--  tests/aiet/test_resources/applications/application3/readme.txt  4
-rw-r--r--  tests/aiet/test_resources/applications/application4/aiet-config.json  35
-rw-r--r--  tests/aiet/test_resources/applications/application4/aiet-config.json.license  3
-rw-r--r--  tests/aiet/test_resources/applications/application4/hello_app.txt  4
-rw-r--r--  tests/aiet/test_resources/applications/application5/aiet-config.json  160
-rw-r--r--  tests/aiet/test_resources/applications/application5/aiet-config.json.license  3
-rw-r--r--  tests/aiet/test_resources/applications/readme.txt  4
-rw-r--r--  tests/aiet/test_resources/hello_world.json  54
-rw-r--r--  tests/aiet/test_resources/hello_world.json.license  3
-rwxr-xr-x  tests/aiet/test_resources/scripts/test_backend_run  8
-rw-r--r--  tests/aiet/test_resources/scripts/test_backend_run_script.sh  8
-rw-r--r--  tests/aiet/test_resources/systems/system1/aiet-config.json  35
-rw-r--r--  tests/aiet/test_resources/systems/system1/aiet-config.json.license  3
-rw-r--r--  tests/aiet/test_resources/systems/system1/system_artifact/dummy.txt  2
-rw-r--r--  tests/aiet/test_resources/systems/system2/aiet-config.json  32
-rw-r--r--  tests/aiet/test_resources/systems/system2/aiet-config.json.license  3
-rw-r--r--  tests/aiet/test_resources/systems/system3/readme.txt  4
-rw-r--r--  tests/aiet/test_resources/systems/system4/aiet-config.json  19
-rw-r--r--  tests/aiet/test_resources/systems/system4/aiet-config.json.license  3
-rw-r--r--  tests/aiet/test_resources/tools/tool1/aiet-config.json  30
-rw-r--r--  tests/aiet/test_resources/tools/tool1/aiet-config.json.license  3
-rw-r--r--  tests/aiet/test_resources/tools/tool2/aiet-config.json  26
-rw-r--r--  tests/aiet/test_resources/tools/tool2/aiet-config.json.license  3
-rw-r--r--  tests/aiet/test_resources/various/applications/application_with_empty_config/aiet-config.json  1
-rw-r--r--  tests/aiet/test_resources/various/applications/application_with_empty_config/aiet-config.json.license  3
-rw-r--r--  tests/aiet/test_resources/various/applications/application_with_valid_config/aiet-config.json  35
-rw-r--r--  tests/aiet/test_resources/various/applications/application_with_valid_config/aiet-config.json.license  3
-rw-r--r--  tests/aiet/test_resources/various/applications/application_with_wrong_config1/aiet-config.json  2
-rw-r--r--  tests/aiet/test_resources/various/applications/application_with_wrong_config1/aiet-config.json.license  3
-rw-r--r--  tests/aiet/test_resources/various/applications/application_with_wrong_config2/aiet-config.json  30
-rw-r--r--  tests/aiet/test_resources/various/applications/application_with_wrong_config2/aiet-config.json.license  3
-rw-r--r--  tests/aiet/test_resources/various/applications/application_with_wrong_config3/aiet-config.json  35
-rw-r--r--  tests/aiet/test_resources/various/applications/application_with_wrong_config3/aiet-config.json.license  3
-rw-r--r--  tests/aiet/test_resources/various/systems/system_with_empty_config/aiet-config.json  1
-rw-r--r--  tests/aiet/test_resources/various/systems/system_with_empty_config/aiet-config.json.license  3
-rw-r--r--  tests/aiet/test_resources/various/systems/system_with_valid_config/aiet-config.json  16
-rw-r--r--  tests/aiet/test_resources/various/systems/system_with_valid_config/aiet-config.json.license  3
-rw-r--r--  tests/aiet/test_run_vela_script.py  152
-rw-r--r--  tests/aiet/test_utils_fs.py  168
-rw-r--r--  tests/aiet/test_utils_helpers.py  27
-rw-r--r--  tests/aiet/test_utils_proc.py  272
-rw-r--r--  tests/conftest.py  95
-rw-r--r--  tests/mlia/__init__.py  3
-rw-r--r--  tests/mlia/conftest.py  20
-rw-r--r--  tests/mlia/test_api.py  96
-rw-r--r--  tests/mlia/test_cli_commands.py  204
-rw-r--r--  tests/mlia/test_cli_config.py  49
-rw-r--r--  tests/mlia/test_cli_helpers.py  165
-rw-r--r--  tests/mlia/test_cli_logging.py  104
-rw-r--r--  tests/mlia/test_cli_main.py  357
-rw-r--r--  tests/mlia/test_cli_options.py  186
-rw-r--r--  tests/mlia/test_core_advice_generation.py  71
-rw-r--r--  tests/mlia/test_core_advisor.py  40
-rw-r--r--  tests/mlia/test_core_context.py  62
-rw-r--r--  tests/mlia/test_core_data_analysis.py  31
-rw-r--r--  tests/mlia/test_core_events.py  155
-rw-r--r--  tests/mlia/test_core_helpers.py  17
-rw-r--r--  tests/mlia/test_core_mixins.py  99
-rw-r--r--  tests/mlia/test_core_performance.py  29
-rw-r--r--  tests/mlia/test_core_reporting.py  413
-rw-r--r--  tests/mlia/test_core_workflow.py  164
-rw-r--r--  tests/mlia/test_devices_ethosu_advice_generation.py  483
-rw-r--r--  tests/mlia/test_devices_ethosu_advisor.py  9
-rw-r--r--  tests/mlia/test_devices_ethosu_config.py  124
-rw-r--r--  tests/mlia/test_devices_ethosu_data_analysis.py  147
-rw-r--r--  tests/mlia/test_devices_ethosu_data_collection.py  151
-rw-r--r--  tests/mlia/test_devices_ethosu_performance.py  28
-rw-r--r--  tests/mlia/test_devices_ethosu_reporters.py  434
-rw-r--r--  tests/mlia/test_nn_tensorflow_config.py  72
-rw-r--r--  tests/mlia/test_nn_tensorflow_optimizations_clustering.py  131
-rw-r--r--  tests/mlia/test_nn_tensorflow_optimizations_pruning.py  117
-rw-r--r--  tests/mlia/test_nn_tensorflow_optimizations_select.py  240
-rw-r--r--  tests/mlia/test_nn_tensorflow_tflite_metrics.py  137
-rw-r--r--  tests/mlia/test_nn_tensorflow_utils.py  81
-rw-r--r--  tests/mlia/test_resources/vela/sample_vela.ini  47
-rw-r--r--  tests/mlia/test_tools_aiet_wrapper.py  760
-rw-r--r--  tests/mlia/test_tools_metadata_common.py  196
-rw-r--r--  tests/mlia/test_tools_metadata_corstone.py  419
-rw-r--r--  tests/mlia/test_tools_vela_wrapper.py  285
-rw-r--r--  tests/mlia/test_utils_console.py  100
-rw-r--r--  tests/mlia/test_utils_download.py  147
-rw-r--r--  tests/mlia/test_utils_filesystem.py  166
-rw-r--r--  tests/mlia/test_utils_logging.py  63
-rw-r--r--  tests/mlia/test_utils_misc.py  25
-rw-r--r--  tests/mlia/test_utils_proc.py  149
-rw-r--r--  tests/mlia/test_utils_types.py  77
-rw-r--r--  tests/mlia/utils/__init__.py  3
-rw-r--r--  tests/mlia/utils/common.py  32
-rw-r--r--  tests/mlia/utils/logging.py  13
249 files changed, 27687 insertions, 0 deletions
diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..e1557d9
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,15 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+*.pyc
+*~
+\.coverage
+\.eggs
+build
+dist
+src/*.egg-info
+.vscode
+venv
+e2e_config
+mlia_output
+report
+.ipynb_checkpoints
diff --git a/LICENSES/Apache-2.0.txt b/LICENSES/Apache-2.0.txt
new file mode 100644
index 0000000..f433b1a
--- /dev/null
+++ b/LICENSES/Apache-2.0.txt
@@ -0,0 +1,177 @@
+
+ Apache License
+ Version 2.0, January 2004
+ http://www.apache.org/licenses/
+
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+ 1. Definitions.
+
+ "License" shall mean the terms and conditions for use, reproduction,
+ and distribution as defined by Sections 1 through 9 of this document.
+
+ "Licensor" shall mean the copyright owner or entity authorized by
+ the copyright owner that is granting the License.
+
+ "Legal Entity" shall mean the union of the acting entity and all
+ other entities that control, are controlled by, or are under common
+ control with that entity. For the purposes of this definition,
+ "control" means (i) the power, direct or indirect, to cause the
+ direction or management of such entity, whether by contract or
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
+ outstanding shares, or (iii) beneficial ownership of such entity.
+
+ "You" (or "Your") shall mean an individual or Legal Entity
+ exercising permissions granted by this License.
+
+ "Source" form shall mean the preferred form for making modifications,
+ including but not limited to software source code, documentation
+ source, and configuration files.
+
+ "Object" form shall mean any form resulting from mechanical
+ transformation or translation of a Source form, including but
+ not limited to compiled object code, generated documentation,
+ and conversions to other media types.
+
+ "Work" shall mean the work of authorship, whether in Source or
+ Object form, made available under the License, as indicated by a
+ copyright notice that is included in or attached to the work
+ (an example is provided in the Appendix below).
+
+ "Derivative Works" shall mean any work, whether in Source or Object
+ form, that is based on (or derived from) the Work and for which the
+ editorial revisions, annotations, elaborations, or other modifications
+ represent, as a whole, an original work of authorship. For the purposes
+ of this License, Derivative Works shall not include works that remain
+ separable from, or merely link (or bind by name) to the interfaces of,
+ the Work and Derivative Works thereof.
+
+ "Contribution" shall mean any work of authorship, including
+ the original version of the Work and any modifications or additions
+ to that Work or Derivative Works thereof, that is intentionally
+ submitted to Licensor for inclusion in the Work by the copyright owner
+ or by an individual or Legal Entity authorized to submit on behalf of
+ the copyright owner. For the purposes of this definition, "submitted"
+ means any form of electronic, verbal, or written communication sent
+ to the Licensor or its representatives, including but not limited to
+ communication on electronic mailing lists, source code control systems,
+ and issue tracking systems that are managed by, or on behalf of, the
+ Licensor for the purpose of discussing and improving the Work, but
+ excluding communication that is conspicuously marked or otherwise
+ designated in writing by the copyright owner as "Not a Contribution."
+
+ "Contributor" shall mean Licensor and any individual or Legal Entity
+ on behalf of whom a Contribution has been received by Licensor and
+ subsequently incorporated within the Work.
+
+ 2. Grant of Copyright License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ copyright license to reproduce, prepare Derivative Works of,
+ publicly display, publicly perform, sublicense, and distribute the
+ Work and such Derivative Works in Source or Object form.
+
+ 3. Grant of Patent License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ (except as stated in this section) patent license to make, have made,
+ use, offer to sell, sell, import, and otherwise transfer the Work,
+ where such license applies only to those patent claims licensable
+ by such Contributor that are necessarily infringed by their
+ Contribution(s) alone or by combination of their Contribution(s)
+ with the Work to which such Contribution(s) was submitted. If You
+ institute patent litigation against any entity (including a
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
+ or a Contribution incorporated within the Work constitutes direct
+ or contributory patent infringement, then any patent licenses
+ granted to You under this License for that Work shall terminate
+ as of the date such litigation is filed.
+
+ 4. Redistribution. You may reproduce and distribute copies of the
+ Work or Derivative Works thereof in any medium, with or without
+ modifications, and in Source or Object form, provided that You
+ meet the following conditions:
+
+ (a) You must give any other recipients of the Work or
+ Derivative Works a copy of this License; and
+
+ (b) You must cause any modified files to carry prominent notices
+ stating that You changed the files; and
+
+ (c) You must retain, in the Source form of any Derivative Works
+ that You distribute, all copyright, patent, trademark, and
+ attribution notices from the Source form of the Work,
+ excluding those notices that do not pertain to any part of
+ the Derivative Works; and
+
+ (d) If the Work includes a "NOTICE" text file as part of its
+ distribution, then any Derivative Works that You distribute must
+ include a readable copy of the attribution notices contained
+ within such NOTICE file, excluding those notices that do not
+ pertain to any part of the Derivative Works, in at least one
+ of the following places: within a NOTICE text file distributed
+ as part of the Derivative Works; within the Source form or
+ documentation, if provided along with the Derivative Works; or,
+ within a display generated by the Derivative Works, if and
+ wherever such third-party notices normally appear. The contents
+ of the NOTICE file are for informational purposes only and
+ do not modify the License. You may add Your own attribution
+ notices within Derivative Works that You distribute, alongside
+ or as an addendum to the NOTICE text from the Work, provided
+ that such additional attribution notices cannot be construed
+ as modifying the License.
+
+ You may add Your own copyright statement to Your modifications and
+ may provide additional or different license terms and conditions
+ for use, reproduction, or distribution of Your modifications, or
+ for any such Derivative Works as a whole, provided Your use,
+ reproduction, and distribution of the Work otherwise complies with
+ the conditions stated in this License.
+
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
+ any Contribution intentionally submitted for inclusion in the Work
+ by You to the Licensor shall be under the terms and conditions of
+ this License, without any additional terms or conditions.
+ Notwithstanding the above, nothing herein shall supersede or modify
+ the terms of any separate license agreement you may have executed
+ with Licensor regarding such Contributions.
+
+ 6. Trademarks. This License does not grant permission to use the trade
+ names, trademarks, service marks, or product names of the Licensor,
+ except as required for reasonable and customary use in describing the
+ origin of the Work and reproducing the content of the NOTICE file.
+
+ 7. Disclaimer of Warranty. Unless required by applicable law or
+ agreed to in writing, Licensor provides the Work (and each
+ Contributor provides its Contributions) on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ implied, including, without limitation, any warranties or conditions
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+ PARTICULAR PURPOSE. You are solely responsible for determining the
+ appropriateness of using or redistributing the Work and assume any
+ risks associated with Your exercise of permissions under this License.
+
+ 8. Limitation of Liability. In no event and under no legal theory,
+ whether in tort (including negligence), contract, or otherwise,
+ unless required by applicable law (such as deliberate and grossly
+ negligent acts) or agreed to in writing, shall any Contributor be
+ liable to You for damages, including any direct, indirect, special,
+ incidental, or consequential damages of any character arising as a
+ result of this License or out of the use or inability to use the
+ Work (including but not limited to damages for loss of goodwill,
+ work stoppage, computer failure or malfunction, or any and all
+ other commercial damages or losses), even if such Contributor
+ has been advised of the possibility of such damages.
+
+ 9. Accepting Warranty or Additional Liability. While redistributing
+ the Work or Derivative Works thereof, You may choose to offer,
+ and charge a fee for, acceptance of support, warranty, indemnity,
+ or other liability obligations and/or rights consistent with this
+ License. However, in accepting such obligations, You may act only
+ on Your own behalf and on Your sole responsibility, not on behalf
+ of any other Contributor, and only if You agree to indemnify,
+ defend, and hold each Contributor harmless for any liability
+ incurred by, or claims asserted against, such Contributor by reason
+ of your accepting any such warranty or additional liability.
+
+ END OF TERMS AND CONDITIONS
diff --git a/LICENSES/BSD-3-Clause.txt b/LICENSES/BSD-3-Clause.txt
new file mode 100644
index 0000000..e7674f7
--- /dev/null
+++ b/LICENSES/BSD-3-Clause.txt
@@ -0,0 +1,26 @@
+Copyright <YEAR> <COPYRIGHT HOLDER>
+
+Redistribution and use in source and binary forms, with or without
+modification, are permitted provided that the following conditions are met:
+
+1. Redistributions of source code must retain the above copyright notice, this
+ list of conditions and the following disclaimer.
+
+2. Redistributions in binary form must reproduce the above copyright notice,
+ this list of conditions and the following disclaimer in the documentation
+ and/or other materials provided with the distribution.
+
+3. Neither the name of the copyright holder nor the names of its
+ contributors may be used to endorse or promote products derived from
+ this software without specific prior written permission.
+
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
diff --git a/LICENSES/CC-PDDC.txt b/LICENSES/CC-PDDC.txt
new file mode 100644
index 0000000..b4ecb41
--- /dev/null
+++ b/LICENSES/CC-PDDC.txt
@@ -0,0 +1,28 @@
+Creative Commons Public Domain Dedication and Certification
+
+The person or persons who have associated work with this document (the
+"Dedicator" or "Certifier") hereby either (a) certifies that, to the best of
+his knowledge, the work of authorship identified is in the public domain of the
+country from which the work is published, or (b) hereby dedicates whatever
+copyright the dedicators holds in the work of authorship identified below (the
+"Work") to the public domain. A certifier, moreover, dedicates any copyright
+interest he may have in the associated work, and for these purposes, is
+described as a "dedicator" below.
+
+A certifier has taken reasonable steps to verify the copyright status of this
+work. Certifier recognizes that his good faith efforts may not shield him from
+liability if in fact the work certified is not in the public domain.
+
+Dedicator makes this dedication for the benefit of the public at large and to
+the detriment of the Dedicator's heirs and successors. Dedicator intends this
+dedication to be an overt act of relinquishment in perpetuity of all present
+and future rights under copyright law, whether vested or contingent, in the
+Work. Dedicator understands that such relinquishment of all rights includes
+the relinquishment of all rights to enforce (by lawsuit or otherwise) those
+copyrights in the Work.
+
+Dedicator recognizes that, once placed in the public domain, the Work may be
+freely reproduced, distributed, transmitted, used, modified, built upon, or
+otherwise exploited by anyone for any purpose, commercial or non-commercial,
+and in any way, including by methods that have not yet been invented or
+conceived.
diff --git a/LICENSES/MIT.txt b/LICENSES/MIT.txt
new file mode 100644
index 0000000..8aa2645
--- /dev/null
+++ b/LICENSES/MIT.txt
@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) [year] [fullname]
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
diff --git a/MANIFEST.in b/MANIFEST.in
new file mode 100644
index 0000000..0f012d2
--- /dev/null
+++ b/MANIFEST.in
@@ -0,0 +1,5 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+
+include README.md
+graft src/mlia/resources
diff --git a/README.md b/README.md
new file mode 100644
index 0000000..dd1fd05
--- /dev/null
+++ b/README.md
@@ -0,0 +1,312 @@
+<!---
+SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+SPDX-License-Identifier: Apache-2.0
+--->
+# ML Inference Advisor
+
+## Introduction
+
+This tool helps AI developers design and optimize neural network models for
+efficient inference on Arm® targets by enabling performance analysis and
+providing actionable advice early in the model development cycle. The final
+advice can cover the operator list, performance analysis and suggestions for
+running model inference on certain hardware before/after applying model
+optimizations (e.g. pruning, clustering, etc.).
+
+## Prerequisites and dependencies
+
+It is recommended to use virtual environments for MLIA installation, and a
+typical setup for MLIA requires:
+
+* Ubuntu® 20.04.03 LTS (other OSs may work, the ML Inference Advisor has been
+ tested on this one specifically)
+* Python® >= 3.8
+* Ethos™-U Vela dependencies (Linux® only)
+ * For more details, please refer to the
+ [prerequisites of Vela](https://pypi.org/project/ethos-u-vela/)
+
+## Backend installation
+
+### Generic case using Corstone™-300 as an example
+
+The ML Inference Advisor is designed to support multiple performance
+estimators (backends) that can generate performance analysis for individual
+types of hardware. In this guide, we use the backend for
+Ethos™-U (Corstone™-300) as an example.
+
+The install command can automatically download the necessary components and
+dependencies, install them and configure them properly.
+
+The usage is:
+
+```bash
+mlia backend install --help
+```
+
+and the result looks like:
+
+positional arguments:
+
+* name: Name of the backend to install
+
+optional arguments:
+
+* -h/--help: Show this help message and exit
+* --path PATH: Path to the installed backend
+* --download: Download and install a backend
+* --noninteractive: Non-interactive mode with automatic confirmation of every action
+
+Some examples of the installation process are:
+
+```bash
+# reply 'y' or 'n' when prompted to download and install a Corstone-300
+mlia backend install --download
+# for downloading and installing a specific backend
+mlia backend install Corstone-300 --download
+# for installing backend from the path of your downloaded backend
+mlia backend install --path your_local_path_for_the_installed_backend
+```
+
+Please note: Corstone™-300 used in the example above is available only
+on the Linux® platform.
+
+After a successful installation of the backend(s), start using mlia in your
+virtual environment. Please note: backends cannot be removed once installed.
+Consider creating a new environment and reinstalling the backends when needed.
+
+### Using Corstone™-310
+
+For instructions on installing Corstone™-310, please refer to
+<https://github.com/ARM-software/open-iot-sdk>
+
+## Usage
+
+After the initial setup, you can start the program by opening your terminal and
+typing the following command:
+
+```bash
+mlia [command] [arguments]
+```
+
+where [command] is to be substituted by one of the supported options, discussed in
+the next section.
+
+To get a list of all available options, use:
+
+```bash
+mlia --help
+```
+
+To get help on a specific command, use:
+
+```bash
+mlia [command] --help
+```
+
+Choices of commands: you can use the ["operators"](#operators-ops) command for
+the model's operator list, run the specified optimizations using the
+["model optimization"](#model-optimization-opt) command, or measure the
+performance of inference on hardware using the ["performance"](#performance-perf)
+command. Finally, you can use the ["all tests"](#all-tests-all) command to
+get a full report.
+
+Most commands accept the name of a target profile as an input parameter.
+There are a number of predefined profiles with the following attributes:
+
+```
++---------------+-----+-----------------------------+----------------+
+| Profile name  | MAC | System config               | Memory mode    |
++===============+=====+=============================+================+
+| ethos-u55-256 | 256 | Ethos_U55_High_End_Embedded | Shared_Sram    |
++---------------+-----+-----------------------------+----------------+
+| ethos-u55-128 | 128 | Ethos_U55_High_End_Embedded | Shared_Sram    |
++---------------+-----+-----------------------------+----------------+
+| ethos-u65-512 | 512 | Ethos_U65_High_End          | Dedicated_Sram |
++---------------+-----+-----------------------------+----------------+
+```
+
+### **Operators** (ops)
+
+#### *Description*
+
+Prints the model's operator list.
+
+#### *Arguments*
+
+##### Optional arguments
+
+* -h/--help: Show the general help document and exit.
+* --supported-ops-report: Generate the SUPPORTED_OPS.md file in the current working
+ directory and exit.
+
+##### Target profile options
+
+* --target-profile: Target profile that will set the target options such as
+ target, MAC value, memory mode, etc.
+ * default: ethos-u55-256
+ * options:
+ * ethos-u55-256
+ * ethos-u55-128
+ * ethos-u65-512
+
+##### TFLite model options
+
+* model: Input model in TFLite format [required].
+
+##### Output options
+
+* --output: Name of the file where the report will be saved.
+ The report is also displayed on the standard output, as plain text.
+ Valid file extensions (formats) are {.txt,.json,.csv};
+ anything else will be formatted as plain text.
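+
+For example, a hypothetical invocation could look like the following (the
+command name is taken from this section's title, and `model.tflite` is an
+illustrative path):
+
+```bash
+mlia operators --target-profile ethos-u55-128 model.tflite
+```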
+
+### **Performance** (perf)
+
+#### *Description*
+
+Prints the model's performance statistics.
+
+#### *Arguments*
+
+##### Optional arguments
+
+* -h/--help: Show the general help document and exit.
+* --supported-ops-report: Generate the SUPPORTED_OPS.md file in the current
+ working directory and exit.
+
+##### Target profile options
+
+* --target-profile: Target profile that will set the target options such as
+ target, MAC value, memory mode, etc.
+ * default: ethos-u55-256
+ * options:
+ * ethos-u55-256
+ * ethos-u55-128
+ * ethos-u65-512
+
+##### TFLite model options
+
+* model: Input model in TFLite format [required].
+
+##### Output options
+
+* --output: Name of the file where the report will be saved.
+ The report is also displayed on the standard output, as plain text.
+ Valid file extensions (formats) are {.txt,.json,.csv};
+ anything else will be formatted as plain text.
+
+##### Debug options
+
+* --verbose: Produce verbose output (for debugging purposes).
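+
+For example, a hypothetical invocation could look like the following (the
+command name is taken from this section's title; `model.tflite` and
+`report.json` are illustrative paths):
+
+```bash
+mlia performance --target-profile ethos-u65-512 --output report.json model.tflite
+```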
+
+### **Model optimization** (opt)
+
+#### *Description*
+
+Shows the performance improvements after applying optimizations to the model.
+
+#### *Arguments*
+
+##### Optional arguments
+
+* -h/--help: Show the general help document and exit.
+* --supported-ops-report: Generate the SUPPORTED_OPS.md file in the current
+ working directory and exit.
+
+##### Target profile options
+
+* --target-profile: Target profile that will set the target options such as
+ target, MAC value, memory mode, etc.
+ * default: ethos-u55-256
+ * options:
+ * ethos-u55-256
+ * ethos-u55-128
+ * ethos-u65-512
+
+##### Keras™ model options
+
+* model: Input model in Keras™ (.h5 or SavedModel) format [required].
+
+##### Optimization options
+
+* --optimization-type: Type of optimization to apply to the model [required].
+ * options:
+ * pruning
+ * clustering
+* --optimization-target: Target for optimization (for pruning this is sparsity
+ between (0,1), for clustering this is the number of clusters
+ (positive integer)) [required].
+* --layers-to-optimize: Names of the layers to optimize (separated by spaces),
+ for example: conv1 conv2 conv3
+ * default: every layer
+
+##### Debug options
+
+* --verbose: Produce verbose output (for debugging purposes).
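+
+For example, a hypothetical invocation that prunes a model to 50% sparsity
+could look like the following (the command name is assumed from this section's
+title, and `model.h5` is an illustrative path):
+
+```bash
+mlia optimization --optimization-type pruning --optimization-target 0.5 model.h5
+```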
+
+### **All tests** (all)
+
+#### *Description*
+
+Generates a full report on the input model's operator list,
+runs the specified optimizations and lists the performance improvements.
+
+#### *Arguments*
+
+##### Optional arguments
+
+* -h/--help: show this help message and exit
+
+##### Target profile options
+
+* --target-profile: Target profile that will set the target options such as
+ target, MAC value, memory mode, etc.
+ * default: ethos-u55-256
+ * options:
+ * ethos-u55-256
+ * ethos-u55-128
+ * ethos-u65-512
+
+##### Keras™ model options
+
+* model: Input model in Keras™ (.h5 or SavedModel) format [required].
+
+##### Optimization options
+
+* --optimization-type: List of the optimization types, separated by commas
+ * default: pruning, clustering
+* --optimization-target: List of the optimization targets, separated by commas
+ (for pruning this is the sparsity between (0,1), for clustering this is the
+ number of clusters (positive integer))
+ * default: 0.5, 32
+
+##### Output options
+
+* --output: Name of the file where the report will be saved.
+ The report is also displayed on the standard output, as plain text.
+ Valid file extensions (formats) are {.txt,.json,.csv};
+ anything else will be formatted as plain text.
+
+##### Debug options
+
+* --verbose: Produce verbose output (for debugging purposes).
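+
+For example, a hypothetical invocation relying on the default optimization
+settings could look like the following (the `all_tests` command name is
+assumed from this section's title; `model.h5` and `report.txt` are
+illustrative paths):
+
+```bash
+mlia all_tests --target-profile ethos-u55-256 --output report.txt model.h5
+```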
+
+## Resources
+
+Additional useful information:
+
+* [Corstone™-300](https://developer.arm.com/Processors/Corstone-300)
+
+## License
+
+ML Inference Advisor is licensed under [Apache License 2.0](LICENSE.txt).
+
+## Trademarks and Copyrights
+
+Arm®, Ethos™-U, Cortex®-M, Corstone™ are registered trademarks or trademarks
+of Arm® Limited (or its subsidiaries) in the U.S. and/or elsewhere.
+TensorFlow™ is a trademark of Google® LLC.
+Keras™ is a trademark of François Chollet.
+Linux® is the registered trademark of Linus Torvalds in the U.S. and elsewhere.
+Python® is a registered trademark of the PSF.
+Ubuntu® is a registered trademark of Canonical.
diff --git a/pyproject.toml b/pyproject.toml
new file mode 100644
index 0000000..05363d8
--- /dev/null
+++ b/pyproject.toml
@@ -0,0 +1,78 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-FileCopyrightText: Copyright (c) 2012-2022 Jukka Lehtosalo and contributors
+# SPDX-FileCopyrightText: Copyright (c) 2015-2022 Dropbox, Inc.
+# SPDX-License-Identifier: Apache-2.0 AND MIT
+
+
+[build-system]
+requires = [
+ "setuptools>=42",
+ "wheel",
+ "setuptools_scm[toml]>=6.2"
+]
+build-backend = "setuptools.build_meta"
+
+# Enable setuptools_scm
+[tool.setuptools_scm]
+
+[tool.pytest.ini_options]
+testpaths = "tests"
+markers = [
+ "e2e", # e2e tests
+ "install", # installation tests
+ "command", # command tests
+ "model_gen" # model generation tests
+]
+junit_logging = "all"
+
+[tool.pylint.messages_control]
+min-similarity-lines = 10
+min-public-methods = 1
+max-line-length = 88
+max-args = 8
+max-attributes=10
+
+# Provide basic compatibility with black
+disable = [
+ "wrong-import-order",
+ "consider-using-f-string" # C0209
+]
+
+enable = [
+ "dangerous-default-value", # W0102
+ # black will reflow code lines, but won't touch comments, error on those
+ "line-too-long" # C0301
+]
+
+[tool.mypy]
+# Suppresses error messages about imports that cannot be resolved
+ignore_missing_imports = true
+# Shows a warning when encountering any code inferred to be unreachable or redundant after performing type analysis
+warn_unreachable = true
+# Shows errors for missing return statements on some execution paths
+warn_no_return = true
+# Shows a warning when returning a value with type Any from a function declared with a non- Any return type
+warn_return_any = true
+# Warns about unneeded # type: ignore comments
+warn_unused_ignores = true
+# Warns about casting an expression to its inferred type
+warn_redundant_casts = true
+# Disallows calling functions without type annotations from functions with type annotations
+disallow_untyped_calls = true
+# Disallows defining functions without type annotations or with incomplete type annotations
+disallow_untyped_defs = true
+# Disallows defining functions with incomplete type annotations
+disallow_incomplete_defs = true
+# Reports an error whenever a function with type annotations is decorated with a decorator without annotations
+disallow_untyped_decorators = true
+# Type-checks the interior of functions without type annotations
+check_untyped_defs = true
+
+[[tool.mypy.overrides]]
+module = [
+ "pkg_resources",
+ "paramiko",
+ "requests",
+ "filelock"
+]
+ignore_missing_imports = true
diff --git a/setup.cfg b/setup.cfg
new file mode 100644
index 0000000..5a9202a
--- /dev/null
+++ b/setup.cfg
@@ -0,0 +1,79 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-FileCopyrightText: Copyright (c) 2020 Troy Comi
+# SPDX-License-Identifier: Apache-2.0 AND MIT
+
+[metadata]
+name = mlia
+description = ML Inference Advisor
+long_description = file: README.md
+url = https://git.mlplatform.org/ml/mlia.git
+author = Arm Ltd
+author_email = mlia@arm.com
+license = Apache License 2.0
+license_files = LICENSES/*.txt
+classifiers =
+ Development Status :: 4 - Beta
+ License :: OSI Approved :: Apache Software License
+ Intended Audience :: Developers
+ Operating System :: POSIX :: Linux
+ Programming Language :: Python :: 3
+ Programming Language :: Python :: 3.8
+ Topic :: Scientific/Engineering :: Artificial Intelligence
+keywords = ml, arm, ethos-u, tflite
+
+[options]
+include_package_data = True
+python_requires = >=3.8
+package_dir =
+ = src
+packages = find:
+install_requires =
+ tensorflow~=2.7.1
+ tensorflow-model-optimization~=0.7.2
+ ethos-u-vela~=3.3.0
+ requests
+ rich
+ click
+ sh
+ paramiko
+ filelock
+ psutil
+ cloup>=0.12.0
+
+[options.packages.find]
+where = src
+
+[options.entry_points]
+console_scripts =
+ mlia=mlia.cli.main:main
+ aiet=aiet.main:main
+ run_vela=aiet.resources.tools.vela.run_vela:main
+
+[options.extras_require]
+dev =
+ pytest==7.1.1
+ pytest-cov==3.0.0
+ mypy==0.942
+ pylint==2.13.7
+
+[flake8]
+# ignored errors
+# E501 line too long
+# W503 line break before binary operator
+ignore = E501, W503
+max-complexity = 18
+select = B,C,E,F,W,T4
+
+[blocklint]
+# Do not allow any non-inclusive language
+max_issue_threshold=1
+# Blocklist: Words to lint in any context, with possibly special characters
+# between, case insensitive
+blocklist=master,slave,blacklist,whitelist
+# Word list: Words to lint as whole words, with possibly special characters
+# between, case insensitive
+wordlist=he,she,him,her,his,hers
+# Exact list: Words to lint as whole words exactly as entered
+# exactlist=
+# Files that should not be checked by blocklint.
+skip_files=LICENSES/CC-PDDC.txt
diff --git a/setup.py b/setup.py
new file mode 100644
index 0000000..4707ea5
--- /dev/null
+++ b/setup.py
@@ -0,0 +1,7 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Module to setup the python package."""
+from setuptools import setup
+
+
+setup()
diff --git a/src/aiet/__init__.py b/src/aiet/__init__.py
new file mode 100644
index 0000000..49304b1
--- /dev/null
+++ b/src/aiet/__init__.py
@@ -0,0 +1,7 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Init of aiet."""
+import pkg_resources
+
+
+__version__ = pkg_resources.get_distribution("mlia").version
diff --git a/src/aiet/backend/__init__.py b/src/aiet/backend/__init__.py
new file mode 100644
index 0000000..3d60372
--- /dev/null
+++ b/src/aiet/backend/__init__.py
@@ -0,0 +1,3 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Backend module."""
diff --git a/src/aiet/backend/application.py b/src/aiet/backend/application.py
new file mode 100644
index 0000000..f6ef815
--- /dev/null
+++ b/src/aiet/backend/application.py
@@ -0,0 +1,187 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Application backend module."""
+import re
+from pathlib import Path
+from typing import Any
+from typing import cast
+from typing import Dict
+from typing import List
+from typing import Optional
+
+from aiet.backend.common import Backend
+from aiet.backend.common import ConfigurationException
+from aiet.backend.common import DataPaths
+from aiet.backend.common import get_backend_configs
+from aiet.backend.common import get_backend_directories
+from aiet.backend.common import load_application_or_tool_configs
+from aiet.backend.common import load_config
+from aiet.backend.common import remove_backend
+from aiet.backend.config import ApplicationConfig
+from aiet.backend.config import ExtendedApplicationConfig
+from aiet.backend.source import create_destination_and_install
+from aiet.backend.source import get_source
+from aiet.utils.fs import get_resources
+
+
+def get_available_application_directory_names() -> List[str]:
+ """Return a list of directory names for all available applications."""
+ return [entry.name for entry in get_backend_directories("applications")]
+
+
+def get_available_applications() -> List["Application"]:
+ """Return a list with all available applications."""
+ available_applications = []
+ for config_json in get_backend_configs("applications"):
+ config_entries = cast(List[ExtendedApplicationConfig], load_config(config_json))
+ for config_entry in config_entries:
+ config_entry["config_location"] = config_json.parent.absolute()
+ applications = load_applications(config_entry)
+ available_applications += applications
+
+ return sorted(available_applications, key=lambda application: application.name)
+
+
+def get_application(
+ application_name: str, system_name: Optional[str] = None
+) -> List["Application"]:
+ """Return a list of application instances with provided name."""
+ return [
+ application
+ for application in get_available_applications()
+ if application.name == application_name
+ and (not system_name or application.can_run_on(system_name))
+ ]
+
+
+def install_application(source_path: Path) -> None:
+ """Install application."""
+ try:
+ source = get_source(source_path)
+ config = cast(List[ExtendedApplicationConfig], source.config())
+ applications_to_install = [
+ s for entry in config for s in load_applications(entry)
+ ]
+ except Exception as error:
+ raise ConfigurationException("Unable to read application definition") from error
+
+ if not applications_to_install:
+ raise ConfigurationException("No application definition found")
+
+ available_applications = get_available_applications()
+ already_installed = [
+ s for s in applications_to_install if s in available_applications
+ ]
+ if already_installed:
+ names = {application.name for application in already_installed}
+ raise ConfigurationException(
+ "Applications [{}] are already installed".format(",".join(names))
+ )
+
+ create_destination_and_install(source, get_resources("applications"))
+
+
+def remove_application(directory_name: str) -> None:
+ """Remove application directory."""
+ remove_backend(directory_name, "applications")
+
+
+def get_unique_application_names(system_name: Optional[str] = None) -> List[str]:
+ """Extract a list of unique application names of all application available."""
+ return list(
+ set(
+ application.name
+ for application in get_available_applications()
+ if not system_name or application.can_run_on(system_name)
+ )
+ )
+
+
+class Application(Backend):
+ """Class for representing a single application component."""
+
+ def __init__(self, config: ApplicationConfig) -> None:
+ """Construct a Application instance from a dict."""
+ super().__init__(config)
+
+ self.supported_systems = config.get("supported_systems", [])
+ self.deploy_data = config.get("deploy_data", [])
+
+ def __eq__(self, other: object) -> bool:
+ """Overload operator ==."""
+ if not isinstance(other, Application):
+ return False
+
+ return (
+ super().__eq__(other)
+ and self.name == other.name
+ and set(self.supported_systems) == set(other.supported_systems)
+ )
+
+ def can_run_on(self, system_name: str) -> bool:
+ """Check if the application can run on the system passed as argument."""
+ return system_name in self.supported_systems
+
+ def get_deploy_data(self) -> List[DataPaths]:
+ """Validate and return data specified in the config file."""
+ if self.config_location is None:
+ raise ConfigurationException(
+ "Unable to get application {} config location".format(self.name)
+ )
+
+ deploy_data = []
+ for item in self.deploy_data:
+ src, dst = item
+ src_full_path = self.config_location / src
+ assert src_full_path.exists(), "{} does not exist".format(src_full_path)
+ deploy_data.append(DataPaths(src_full_path, dst))
+ return deploy_data
+
+ def get_details(self) -> Dict[str, Any]:
+ """Return dictionary with information about the Application instance."""
+ output = {
+ "type": "application",
+ "name": self.name,
+ "description": self.description,
+ "supported_systems": self.supported_systems,
+ "commands": self._get_command_details(),
+ }
+
+ return output
+
+ def remove_unused_params(self) -> None:
+ """Remove unused params in commands.
+
+ After merging the default and system-related configuration, the
+ application could have parameters that are not used in any command.
+ They should be removed.
+ """
+ for command in self.commands.values():
+ indexes_or_aliases = [
+ m
+ for cmd_str in command.command_strings
+ for m in re.findall(r"{user_params:(?P<index_or_alias>\w+)}", cmd_str)
+ ]
+
+ only_aliases = all(not item.isnumeric() for item in indexes_or_aliases)
+ if only_aliases:
+ used_params = [
+ param
+ for param in command.params
+ if param.alias in indexes_or_aliases
+ ]
+ command.params = used_params
+
+
+def load_applications(config: ExtendedApplicationConfig) -> List[Application]:
+ """Load application.
+
+ Application configuration could contain different parameters/commands for different
+ supported systems. For each supported system this function will return separate
+ Application instance with appropriate configuration.
+ """
+ configs = load_application_or_tool_configs(config, ApplicationConfig)
+ applications = [Application(cfg) for cfg in configs]
+ for application in applications:
+ application.remove_unused_params()
+ return applications
diff --git a/src/aiet/backend/common.py b/src/aiet/backend/common.py
new file mode 100644
index 0000000..b887ee7
--- /dev/null
+++ b/src/aiet/backend/common.py
@@ -0,0 +1,532 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Contain all common functions for the backends."""
+import json
+import logging
+import re
+from abc import ABC
+from collections import Counter
+from pathlib import Path
+from typing import Any
+from typing import Callable
+from typing import cast
+from typing import Dict
+from typing import Final
+from typing import IO
+from typing import Iterable
+from typing import List
+from typing import Match
+from typing import NamedTuple
+from typing import Optional
+from typing import Pattern
+from typing import Tuple
+from typing import Type
+from typing import Union
+
+from aiet.backend.config import BackendConfig
+from aiet.backend.config import BaseBackendConfig
+from aiet.backend.config import NamedExecutionConfig
+from aiet.backend.config import UserParamConfig
+from aiet.backend.config import UserParamsConfig
+from aiet.utils.fs import get_resources
+from aiet.utils.fs import remove_resource
+from aiet.utils.fs import ResourceType
+
+
+AIET_CONFIG_FILE: Final[str] = "aiet-config.json"
+
+
+class ConfigurationException(Exception):
+ """Configuration exception."""
+
+
+def get_backend_config(dir_path: Path) -> Path:
+ """Get path to backendir configuration file."""
+ return dir_path / AIET_CONFIG_FILE
+
+
+def get_backend_configs(resource_type: ResourceType) -> Iterable[Path]:
+ """Get path to the backend configs for provided resource_type."""
+ return (
+ get_backend_config(entry) for entry in get_backend_directories(resource_type)
+ )
+
+
+def get_backend_directories(resource_type: ResourceType) -> Iterable[Path]:
+ """Get path to the backend directories for provided resource_type."""
+ return (
+ entry
+ for entry in get_resources(resource_type).iterdir()
+ if is_backend_directory(entry)
+ )
+
+
+def is_backend_directory(dir_path: Path) -> bool:
+ """Check if path is backend's configuration directory."""
+ return dir_path.is_dir() and get_backend_config(dir_path).is_file()
+
+
+def remove_backend(directory_name: str, resource_type: ResourceType) -> None:
+ """Remove backend with provided type and directory_name."""
+ if not directory_name:
+ raise Exception("No directory name provided")
+
+ remove_resource(directory_name, resource_type)
+
+
+def load_config(config: Union[None, Path, IO[bytes]]) -> BackendConfig:
+ """Return a loaded json file."""
+ if config is None:
+ raise Exception("Unable to read config")
+
+ if isinstance(config, Path):
+ with config.open() as json_file:
+ return cast(BackendConfig, json.load(json_file))
+
+ return cast(BackendConfig, json.load(config))
+
+
+def parse_raw_parameter(parameter: str) -> Tuple[str, Optional[str]]:
+ """Split the parameter string in name and optional value.
+
+ It manages the following cases:
+ --param=1 -> --param, 1
+ --param 1 -> --param, 1
+ --flag -> --flag, None
+ """
+ data = re.split(" |=", parameter)
+ if len(data) == 1:
+ param_name = data[0]
+ param_value = None
+ else:
+ param_name = " ".join(data[0:-1])
+ param_value = data[-1]
+ return param_name, param_value
+
+
+class DataPaths(NamedTuple):
+ """DataPaths class."""
+
+ src: Path
+ dst: str
+
+
+class Backend(ABC):
+ """Backend class."""
+
+ # pylint: disable=too-many-instance-attributes
+
+ def __init__(self, config: BaseBackendConfig):
+ """Initialize backend."""
+ name = config.get("name")
+ if not name:
+ raise ConfigurationException("Name is empty")
+
+ self.name = name
+ self.description = config.get("description", "")
+ self.config_location = config.get("config_location")
+ self.variables = config.get("variables", {})
+ self.build_dir = config.get("build_dir")
+ self.lock = config.get("lock", False)
+ if self.build_dir:
+ self.build_dir = self._substitute_variables(self.build_dir)
+ self.annotations = config.get("annotations", {})
+
+ self._parse_commands_and_params(config)
+
+ def validate_parameter(self, command_name: str, parameter: str) -> bool:
+ """Validate the parameter string against the application configuration.
+
+ We take the parameter string, extract the parameter name/value and
+ check them against the current configuration.
+ """
+ param_name, param_value = parse_raw_parameter(parameter)
+ valid_param_name = valid_param_value = False
+
+ command = self.commands.get(command_name)
+ if not command:
+ raise AttributeError("Unknown command: '{}'".format(command_name))
+
+ # Iterate over all available parameters until we have a match.
+ for param in command.params:
+ if self._same_parameter(param_name, param):
+ valid_param_name = True
+ # This is a non-empty list
+ if param.values:
+ # We check if the value is allowed in the configuration
+ valid_param_value = param_value in param.values
+ else:
+                    # In this case we don't validate the value and accept
+                    # whatever value was provided.
+ valid_param_value = True
+ break
+
+ return valid_param_name and valid_param_value
+
+ def __eq__(self, other: object) -> bool:
+ """Overload operator ==."""
+ if not isinstance(other, Backend):
+ return False
+
+ return (
+ self.name == other.name
+ and self.description == other.description
+ and self.commands == other.commands
+ )
+
+ def __repr__(self) -> str:
+ """Represent the Backend instance by its name."""
+ return self.name
+
+ def _parse_commands_and_params(self, config: BaseBackendConfig) -> None:
+ """Parse commands and user parameters."""
+ self.commands: Dict[str, Command] = {}
+
+ commands = config.get("commands")
+ if commands:
+ params = config.get("user_params")
+
+ for command_name in commands.keys():
+ command_params = self._parse_params(params, command_name)
+ command_strings = [
+ self._substitute_variables(cmd)
+ for cmd in commands.get(command_name, [])
+ ]
+ self.commands[command_name] = Command(command_strings, command_params)
+
+ def _substitute_variables(self, str_val: str) -> str:
+ """Substitute variables in string.
+
+        Variables are substituted at the backend's creation stage because
+        they could contain references to other params which will be
+        resolved later.
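+
+        Example (illustrative): with variables {"model": "test.tflite"},
+        the string "run {variables:model}" becomes "run test.tflite".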
+ """
+ if not str_val:
+ return str_val
+
+ var_pattern: Final[Pattern] = re.compile(r"{variables:(?P<var_name>\w+)}")
+
+ def var_value(match: Match) -> str:
+ var_name = match["var_name"]
+ if var_name not in self.variables:
+ raise ConfigurationException("Unknown variable {}".format(var_name))
+
+ return self.variables[var_name]
+
+ return var_pattern.sub(var_value, str_val) # type: ignore
+
+ @classmethod
+ def _parse_params(
+ cls, params: Optional[UserParamsConfig], command: str
+ ) -> List["Param"]:
+ if not params:
+ return []
+
+ return [cls._parse_param(p) for p in params.get(command, [])]
+
+ @classmethod
+ def _parse_param(cls, param: UserParamConfig) -> "Param":
+ """Parse a single parameter."""
+ name = param.get("name")
+ if name is not None and not name:
+ raise ConfigurationException("Parameter has an empty 'name' attribute.")
+ values = param.get("values", None)
+ default_value = param.get("default_value", None)
+ description = param.get("description", "")
+ alias = param.get("alias")
+
+ return Param(
+ name=name,
+ description=description,
+ values=values,
+ default_value=default_value,
+ alias=alias,
+ )
+
+ def _get_command_details(self) -> Dict:
+ command_details = {
+ command_name: command.get_details()
+ for command_name, command in self.commands.items()
+ }
+ return command_details
+
+ def _get_user_param_value(
+ self, user_params: List[str], param: "Param"
+ ) -> Optional[str]:
+ """Get the user-specified value of a parameter."""
+ for user_param in user_params:
+ user_param_name, user_param_value = parse_raw_parameter(user_param)
+            if param.alias and user_param_name == param.name:
+ warn_message = (
+ "The direct use of parameter name is deprecated"
+ " and might be removed in the future.\n"
+ f"Please use alias '{param.alias}' instead of "
+ "'{user_param_name}' to provide the parameter."
+ )
+ logging.warning(warn_message)
+
+ if self._same_parameter(user_param_name, param):
+ return user_param_value
+
+ return None
+
+ @staticmethod
+ def _same_parameter(user_param_name_or_alias: str, param: "Param") -> bool:
+ """Compare user parameter name with param name or alias."""
+ # Strip the "=" sign in the param_name. This is needed just for
+ # comparison with the parameters passed by the user.
+ # The equal sign needs to be honoured when re-building the
+ # parameter back.
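+        # Illustrative example: a parameter named "--param=" matches the user
+        # input "--param=1", because parse_raw_parameter() reports the name
+        # "--param" for that input.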
+ param_name = None if not param.name else param.name.rstrip("=")
+ return user_param_name_or_alias in [param_name, param.alias]
+
+ def resolved_parameters(
+ self, command_name: str, user_params: List[str]
+ ) -> List[Tuple[Optional[str], "Param"]]:
+ """Return list of parameters with values."""
+ result: List[Tuple[Optional[str], "Param"]] = []
+ command = self.commands.get(command_name)
+ if not command:
+ return result
+
+ for param in command.params:
+ value = self._get_user_param_value(user_params, param)
+ if not value:
+ value = param.default_value
+ result.append((value, param))
+
+ return result
+
+ def build_command(
+ self,
+ command_name: str,
+ user_params: List[str],
+ param_resolver: Callable[[str, str, List[Tuple[Optional[str], "Param"]]], str],
+ ) -> List[str]:
+ """
+ Return a list of executable command strings.
+
+        Given a command name and the user-provided parameters, resolves every
+        parameter placeholder and returns the resulting command strings.
+ """
+ command = self.commands.get(command_name)
+ if not command:
+ raise ConfigurationException(
+ "Command '{}' could not be found.".format(command_name)
+ )
+
+ commands_to_run = []
+
+ params_values = self.resolved_parameters(command_name, user_params)
+ for cmd_str in command.command_strings:
+ cmd_str = resolve_all_parameters(
+ cmd_str, param_resolver, command_name, params_values
+ )
+ commands_to_run.append(cmd_str)
+
+ return commands_to_run
+
+
+class Param:
+ """Class for representing a generic application parameter."""
+
+ def __init__( # pylint: disable=too-many-arguments
+ self,
+ name: Optional[str],
+ description: str,
+ values: Optional[List[str]] = None,
+ default_value: Optional[str] = None,
+ alias: Optional[str] = None,
+ ) -> None:
+ """Construct a Param instance."""
+ if not name and not alias:
+ raise ConfigurationException(
+ "Either name, alias or both must be set to identify a parameter."
+ )
+ self.name = name
+ self.values = values
+ self.description = description
+ self.default_value = default_value
+ self.alias = alias
+
+ def get_details(self) -> Dict:
+ """Return a dictionary with all relevant information of a Param."""
+ return {key: value for key, value in self.__dict__.items() if value}
+
+ def __eq__(self, other: object) -> bool:
+ """Overload operator ==."""
+ if not isinstance(other, Param):
+ return False
+
+ return (
+ self.name == other.name
+ and self.values == other.values
+ and self.default_value == other.default_value
+ and self.description == other.description
+ )
+
+
+class Command:
+ """Class for representing a command."""
+
+ def __init__(
+ self, command_strings: List[str], params: Optional[List[Param]] = None
+ ) -> None:
+ """Construct a Command instance."""
+ self.command_strings = command_strings
+
+ if params:
+ self.params = params
+ else:
+ self.params = []
+
+ self._validate()
+
+ def _validate(self) -> None:
+ """Validate command."""
+ if not self.params:
+ return
+
+ aliases = [param.alias for param in self.params if param.alias is not None]
+ repeated_aliases = [
+ alias for alias, count in Counter(aliases).items() if count > 1
+ ]
+
+ if repeated_aliases:
+ raise ConfigurationException(
+ "Non unique aliases {}".format(", ".join(repeated_aliases))
+ )
+
+ both_name_and_alias = [
+ param.name
+ for param in self.params
+ if param.name in aliases and param.name != param.alias
+ ]
+ if both_name_and_alias:
+ raise ConfigurationException(
+ "Aliases {} could not be used as parameter name".format(
+ ", ".join(both_name_and_alias)
+ )
+ )
+
+ def get_details(self) -> Dict:
+ """Return a dictionary with all relevant information of a Command."""
+ output = {
+ "command_strings": self.command_strings,
+ "user_params": [param.get_details() for param in self.params],
+ }
+ return output
+
+ def __eq__(self, other: object) -> bool:
+ """Overload operator ==."""
+ if not isinstance(other, Command):
+ return False
+
+ return (
+ self.command_strings == other.command_strings
+ and self.params == other.params
+ )
+
+
+def resolve_all_parameters(
+ str_val: str,
+ param_resolver: Callable[[str, str, List[Tuple[Optional[str], Param]]], str],
+ command_name: Optional[str] = None,
+ params_values: Optional[List[Tuple[Optional[str], Param]]] = None,
+) -> str:
+ """Resolve all parameters in the string."""
+ if not str_val:
+ return str_val
+
+ param_pattern: Final[Pattern] = re.compile(r"{(?P<param_name>[\w.:]+)}")
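+    # Illustrative example: "{application.name} {user_params:0}" is resolved
+    # by repeatedly substituting every "{...}" placeholder via param_resolver
+    # until none remain (a resolved value may itself contain placeholders).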
+ while param_pattern.findall(str_val):
+ str_val = param_pattern.sub(
+ lambda m: param_resolver(
+ m["param_name"], command_name or "", params_values or []
+ ),
+ str_val,
+ )
+ return str_val
+
+
+def load_application_or_tool_configs(
+ config: Any,
+ config_type: Type[Any],
+ is_system_required: bool = True,
+) -> Any:
+ """Get one config for each system supported by the application/tool.
+
+    The configuration can define different parameters/commands for different
+    supported systems. For each supported system this function returns a
+    separate config with the merged settings.
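+
+    Example (illustrative): a config with supported_systems
+    [{"name": "A"}, {"name": "B"}] yields two merged configs, one with
+    supported_systems == ["A"] and one with ["B"], each combining the default
+    commands/user_params with the system-specific ones.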
+ """
+ merged_configs = []
+ supported_systems: Optional[List[NamedExecutionConfig]] = config.get(
+ "supported_systems"
+ )
+ if not supported_systems:
+ if is_system_required:
+ raise ConfigurationException("No supported systems definition provided")
+ # Create an empty system to be used in the parsing below
+ supported_systems = [cast(NamedExecutionConfig, {})]
+
+ default_user_params = config.get("user_params", {})
+
+ def merge_config(system: NamedExecutionConfig) -> Any:
+ system_name = system.get("name")
+ if not system_name and is_system_required:
+ raise ConfigurationException(
+ "Unable to read supported system definition, name is missed"
+ )
+
+ merged_config = config_type(**config)
+ merged_config["supported_systems"] = [system_name] if system_name else []
+        # merge the default configuration with the system-specific one
+ merged_config["commands"] = {
+ **config.get("commands", {}),
+ **system.get("commands", {}),
+ }
+
+ params = {}
+ tool_user_params = system.get("user_params", {})
+ command_names = tool_user_params.keys() | default_user_params.keys()
+ for command_name in command_names:
+ if command_name not in merged_config["commands"]:
+ continue
+
+ params_default = default_user_params.get(command_name, [])
+ params_tool = tool_user_params.get(command_name, [])
+ if not params_default or not params_tool:
+ params[command_name] = params_tool or params_default
+ if params_default and params_tool:
+ if any(not p.get("alias") for p in params_default):
+ raise ConfigurationException(
+ "Default parameters for command {} should have aliases".format(
+ command_name
+ )
+ )
+ if any(not p.get("alias") for p in params_tool):
+ raise ConfigurationException(
+ "{} parameters for command {} should have aliases".format(
+ system_name, command_name
+ )
+ )
+
+ merged_by_alias = {
+ **{p.get("alias"): p for p in params_default},
+ **{p.get("alias"): p for p in params_tool},
+ }
+ params[command_name] = list(merged_by_alias.values())
+
+ merged_config["user_params"] = params
+ merged_config["build_dir"] = system.get("build_dir", config.get("build_dir"))
+ merged_config["lock"] = system.get("lock", config.get("lock", False))
+ merged_config["variables"] = {
+ **config.get("variables", {}),
+ **system.get("variables", {}),
+ }
+ return merged_config
+
+ merged_configs = [merge_config(system) for system in supported_systems]
+
+ return merged_configs
diff --git a/src/aiet/backend/config.py b/src/aiet/backend/config.py
new file mode 100644
index 0000000..dd42012
--- /dev/null
+++ b/src/aiet/backend/config.py
@@ -0,0 +1,107 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Contain definition of backend configuration."""
+from pathlib import Path
+from typing import Dict
+from typing import List
+from typing import Literal
+from typing import Optional
+from typing import Tuple
+from typing import TypedDict
+from typing import Union
+
+
+class UserParamConfig(TypedDict, total=False):
+ """User parameter configuration."""
+
+ name: Optional[str]
+ default_value: str
+ values: List[str]
+ description: str
+ alias: str
+
+
+UserParamsConfig = Dict[str, List[UserParamConfig]]
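+# Illustrative example of the expected shape:
+#   {"run": [{"name": "--param", "values": ["1", "2"], "alias": "param"}]}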
+
+
+class ExecutionConfig(TypedDict, total=False):
+ """Execution configuration."""
+
+ commands: Dict[str, List[str]]
+ user_params: UserParamsConfig
+ build_dir: str
+ variables: Dict[str, str]
+ lock: bool
+
+
+class NamedExecutionConfig(ExecutionConfig):
+ """Execution configuration with name."""
+
+ name: str
+
+
+class BaseBackendConfig(ExecutionConfig, total=False):
+ """Base backend configuration."""
+
+ name: str
+ description: str
+ config_location: Path
+ annotations: Dict[str, Union[str, List[str]]]
+
+
+class ApplicationConfig(BaseBackendConfig, total=False):
+ """Application configuration."""
+
+ supported_systems: List[str]
+ deploy_data: List[Tuple[str, str]]
+
+
+class ExtendedApplicationConfig(BaseBackendConfig, total=False):
+ """Extended application configuration."""
+
+ supported_systems: List[NamedExecutionConfig]
+ deploy_data: List[Tuple[str, str]]
+
+
+class ProtocolConfig(TypedDict, total=False):
+ """Protocol config."""
+
+ protocol: Literal["local", "ssh"]
+
+
+class SSHConfig(ProtocolConfig, total=False):
+ """SSH configuration."""
+
+ username: str
+ password: str
+ hostname: str
+ port: str
+
+
+class LocalProtocolConfig(ProtocolConfig, total=False):
+ """Local protocol config."""
+
+
+class SystemConfig(BaseBackendConfig, total=False):
+ """System configuration."""
+
+ data_transfer: Union[SSHConfig, LocalProtocolConfig]
+ reporting: Dict[str, Dict]
+
+
+class ToolConfig(BaseBackendConfig, total=False):
+ """Tool configuration."""
+
+ supported_systems: List[str]
+
+
+class ExtendedToolConfig(BaseBackendConfig, total=False):
+ """Extended tool configuration."""
+
+ supported_systems: List[NamedExecutionConfig]
+
+
+BackendItemConfig = Union[ApplicationConfig, SystemConfig, ToolConfig]
+BackendConfig = Union[
+ List[ExtendedApplicationConfig], List[SystemConfig], List[ToolConfig]
+]
diff --git a/src/aiet/backend/controller.py b/src/aiet/backend/controller.py
new file mode 100644
index 0000000..2650902
--- /dev/null
+++ b/src/aiet/backend/controller.py
@@ -0,0 +1,134 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Controller backend module."""
+import time
+from pathlib import Path
+from typing import List
+from typing import Optional
+from typing import Tuple
+
+import psutil
+import sh
+
+from aiet.backend.common import ConfigurationException
+from aiet.utils.fs import read_file_as_string
+from aiet.utils.proc import execute_command
+from aiet.utils.proc import get_stdout_stderr_paths
+from aiet.utils.proc import read_process_info
+from aiet.utils.proc import save_process_info
+from aiet.utils.proc import terminate_command
+from aiet.utils.proc import terminate_external_process
+
+
+class SystemController:
+ """System controller class."""
+
+ def __init__(self) -> None:
+ """Create new instance of service controller."""
+ self.cmd: Optional[sh.RunningCommand] = None
+ self.out_path: Optional[Path] = None
+ self.err_path: Optional[Path] = None
+
+ def before_start(self) -> None:
+ """Run actions before system start."""
+
+ def after_start(self) -> None:
+ """Run actions after system start."""
+
+ def start(self, commands: List[str], cwd: Path) -> None:
+ """Start system."""
+ if not isinstance(cwd, Path) or not cwd.is_dir():
+ raise ConfigurationException("Wrong working directory {}".format(cwd))
+
+ if len(commands) != 1:
+ raise ConfigurationException("System should have only one command to run")
+
+ startup_command = commands[0]
+ if not startup_command:
+ raise ConfigurationException("No startup command provided")
+
+ self.before_start()
+
+ self.out_path, self.err_path = get_stdout_stderr_paths(startup_command)
+
+ self.cmd = execute_command(
+ startup_command,
+ cwd,
+ bg=True,
+ out=str(self.out_path),
+ err=str(self.err_path),
+ )
+
+ self.after_start()
+
+ def stop(
+ self, wait: bool = False, wait_period: float = 0.5, number_of_attempts: int = 20
+ ) -> None:
+ """Stop system."""
+ if self.cmd is not None and self.is_running():
+ terminate_command(self.cmd, wait, wait_period, number_of_attempts)
+
+ def is_running(self) -> bool:
+ """Check if underlying process is running."""
+ return self.cmd is not None and self.cmd.is_alive()
+
+ def get_output(self) -> Tuple[str, str]:
+ """Return application output."""
+ if self.cmd is None or self.out_path is None or self.err_path is None:
+ return ("", "")
+
+ return (read_file_as_string(self.out_path), read_file_as_string(self.err_path))
+
+
+class SystemControllerSingleInstance(SystemController):
+ """System controller with support of system's single instance."""
+
+ def __init__(self, pid_file_path: Optional[Path] = None) -> None:
+ """Create new instance of the service controller."""
+ super().__init__()
+ self.pid_file_path = pid_file_path
+
+ def before_start(self) -> None:
+ """Run actions before system start."""
+ self._check_if_previous_instance_is_running()
+
+ def after_start(self) -> None:
+ """Run actions after system start."""
+ self._save_process_info()
+
+ def _check_if_previous_instance_is_running(self) -> None:
+ """Check if another instance of the system is running."""
+ process_info = read_process_info(self._pid_file())
+
+ for item in process_info:
+ try:
+ process = psutil.Process(item.pid)
+ same_process = (
+ process.name() == item.name
+ and process.exe() == item.executable
+ and process.cwd() == item.cwd
+ )
+ if same_process:
+ print(
+ "Stopping previous instance of the system [{}]".format(item.pid)
+ )
+ terminate_external_process(process)
+ except psutil.NoSuchProcess:
+ pass
+
+ def _save_process_info(self, wait_period: float = 2) -> None:
+ """Save information about system's processes."""
+ if self.cmd is None or not self.is_running():
+ return
+
+ # give some time for the system to start
+ time.sleep(wait_period)
+
+ save_process_info(self.cmd.process.pid, self._pid_file())
+
+ def _pid_file(self) -> Path:
+ """Return path to file which is used for saving process info."""
+ if not self.pid_file_path:
+ raise Exception("No pid file path presented")
+
+ return self.pid_file_path
diff --git a/src/aiet/backend/execution.py b/src/aiet/backend/execution.py
new file mode 100644
index 0000000..1653ee2
--- /dev/null
+++ b/src/aiet/backend/execution.py
@@ -0,0 +1,859 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Application execution module."""
+import itertools
+import json
+import random
+import re
+import string
+import sys
+import time
+import warnings
+from collections import defaultdict
+from contextlib import contextmanager
+from contextlib import ExitStack
+from pathlib import Path
+from typing import Any
+from typing import Callable
+from typing import cast
+from typing import ContextManager
+from typing import Dict
+from typing import Generator
+from typing import List
+from typing import Optional
+from typing import Sequence
+from typing import Tuple
+from typing import TypedDict
+from typing import Union
+
+from filelock import FileLock
+from filelock import Timeout
+
+from aiet.backend.application import Application
+from aiet.backend.application import get_application
+from aiet.backend.common import Backend
+from aiet.backend.common import ConfigurationException
+from aiet.backend.common import DataPaths
+from aiet.backend.common import Param
+from aiet.backend.common import parse_raw_parameter
+from aiet.backend.common import resolve_all_parameters
+from aiet.backend.output_parser import Base64OutputParser
+from aiet.backend.output_parser import OutputParser
+from aiet.backend.output_parser import RegexOutputParser
+from aiet.backend.system import ControlledSystem
+from aiet.backend.system import get_system
+from aiet.backend.system import StandaloneSystem
+from aiet.backend.system import System
+from aiet.backend.tool import get_tool
+from aiet.backend.tool import Tool
+from aiet.utils.fs import recreate_directory
+from aiet.utils.fs import remove_directory
+from aiet.utils.fs import valid_for_filename
+from aiet.utils.proc import run_and_wait
+
+
+class AnotherInstanceIsRunningException(Exception):
+ """Concurrent execution error."""
+
+
+class ConnectionException(Exception):
+ """Connection exception."""
+
+
+class ExecutionParams(TypedDict, total=False):
+ """Execution parameters."""
+
+ disable_locking: bool
+ unique_build_dir: bool
+
+
+class ExecutionContext:
+ """Command execution context."""
+
+ # pylint: disable=too-many-arguments,too-many-instance-attributes
+ def __init__(
+ self,
+ app: Union[Application, Tool],
+ app_params: List[str],
+ system: Optional[System],
+ system_params: List[str],
+ custom_deploy_data: Optional[List[DataPaths]] = None,
+ execution_params: Optional[ExecutionParams] = None,
+ report_file: Optional[Path] = None,
+ ):
+ """Init execution context."""
+ self.app = app
+ self.app_params = app_params
+ self.custom_deploy_data = custom_deploy_data or []
+ self.system = system
+ self.system_params = system_params
+ self.execution_params = execution_params or ExecutionParams()
+ self.report_file = report_file
+
+ self.reporter: Optional[Reporter]
+ if self.report_file:
+ # Create reporter with output parsers
+ parsers: List[OutputParser] = []
+ if system and system.reporting:
+ # Add RegexOutputParser, if it is configured in the system
+ parsers.append(RegexOutputParser("system", system.reporting["regex"]))
+ # Add Base64 parser for applications
+ parsers.append(Base64OutputParser("application"))
+ self.reporter = Reporter(parsers=parsers)
+ else:
+ self.reporter = None # No reporter needed.
+
+ self.param_resolver = ParamResolver(self)
+ self._resolved_build_dir: Optional[Path] = None
+
+ @property
+ def is_deploy_needed(self) -> bool:
+ """Check if application requires data deployment."""
+ if isinstance(self.app, Application):
+ return (
+ len(self.app.get_deploy_data()) > 0 or len(self.custom_deploy_data) > 0
+ )
+ return False
+
+ @property
+ def is_locking_required(self) -> bool:
+ """Return true if any form of locking required."""
+ return not self._disable_locking() and (
+ self.app.lock or (self.system is not None and self.system.lock)
+ )
+
+ @property
+ def is_build_required(self) -> bool:
+ """Return true if application build required."""
+ return "build" in self.app.commands
+
+ @property
+ def is_unique_build_dir_required(self) -> bool:
+ """Return true if unique build dir required."""
+ return self.execution_params.get("unique_build_dir", False)
+
+ def build_dir(self) -> Path:
+ """Return resolved application build dir."""
+ if self._resolved_build_dir is not None:
+ return self._resolved_build_dir
+
+ if (
+ not isinstance(self.app.config_location, Path)
+ or not self.app.config_location.is_dir()
+ ):
+ raise ConfigurationException(
+ "Application {} has wrong config location".format(self.app.name)
+ )
+
+ _build_dir = self.app.build_dir
+ if _build_dir:
+ _build_dir = resolve_all_parameters(_build_dir, self.param_resolver)
+
+ if not _build_dir:
+ raise ConfigurationException(
+ "No build directory defined for the app {}".format(self.app.name)
+ )
+
+ if self.is_unique_build_dir_required:
+ random_suffix = "".join(
+ random.choices(string.ascii_lowercase + string.digits, k=7)
+ )
+ _build_dir = "{}_{}".format(_build_dir, random_suffix)
+
+ self._resolved_build_dir = self.app.config_location / _build_dir
+ return self._resolved_build_dir
+
+ def _disable_locking(self) -> bool:
+ """Return true if locking should be disabled."""
+ return self.execution_params.get("disable_locking", False)
+
+
+class ParamResolver:
+ """Parameter resolver."""
+
+ def __init__(self, context: ExecutionContext):
+ """Init parameter resolver."""
+ self.ctx = context
+
+ @staticmethod
+ def resolve_user_params(
+ cmd_name: Optional[str],
+ index_or_alias: str,
+ resolved_params: Optional[List[Tuple[Optional[str], Param]]],
+ ) -> str:
+ """Resolve user params."""
+ if not cmd_name or resolved_params is None:
+ raise ConfigurationException("Unable to resolve user params")
+
+ param_value: Optional[str] = None
+ param: Optional[Param] = None
+
+ if index_or_alias.isnumeric():
+ i = int(index_or_alias)
+ if i not in range(len(resolved_params)):
+ raise ConfigurationException(
+ "Invalid index {} for user params of command {}".format(i, cmd_name)
+ )
+ param_value, param = resolved_params[i]
+ else:
+ for val, par in resolved_params:
+ if par.alias == index_or_alias:
+ param_value, param = val, par
+ break
+
+ if param is None:
+ raise ConfigurationException(
+ "No user parameter for command '{}' with alias '{}'.".format(
+ cmd_name, index_or_alias
+ )
+ )
+
+ if param_value:
+            # We need to handle two cases of parameters here:
+ # 1) Optional parameters (non-positional with a name and value)
+ # 2) Positional parameters (value only, no name needed)
+ # Default to empty strings for positional arguments
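+            # Illustrative examples: name "--level" with value "3" yields
+            # "--level 3"; name "--level=" yields "--level=3"; a positional
+            # parameter (no name) yields just "3".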
+ param_name = ""
+ separator = ""
+ if param.name is not None:
+                # A valid param name means we have an optional/non-positional argument:
+                # the separator is an empty string if param_name ends with an
+                # equal sign, as we have to honour it. Otherwise a space
+                # character is injected to separate the parameter name from
+                # its value.
+ param_name = param.name
+ separator = "" if param.name.endswith("=") else " "
+
+ return "{param_name}{separator}{param_value}".format(
+ param_name=param_name,
+ separator=separator,
+ param_value=param_value,
+ )
+
+ if param.name is None:
+ raise ConfigurationException(
+ "Missing user parameter with alias '{}' for command '{}'.".format(
+ index_or_alias, cmd_name
+ )
+ )
+
+ return param.name # flag: just return the parameter name
+
+ def resolve_commands_and_params(
+ self, backend_type: str, cmd_name: str, return_params: bool, index_or_alias: str
+ ) -> str:
+ """Resolve command or command's param value."""
+ if backend_type == "system":
+ backend = cast(Backend, self.ctx.system)
+ backend_params = self.ctx.system_params
+ else: # Application or Tool backend
+ backend = cast(Backend, self.ctx.app)
+ backend_params = self.ctx.app_params
+
+ if cmd_name not in backend.commands:
+ raise ConfigurationException("Command {} not found".format(cmd_name))
+
+ if return_params:
+ params = backend.resolved_parameters(cmd_name, backend_params)
+ if index_or_alias.isnumeric():
+ i = int(index_or_alias)
+ if i not in range(len(params)):
+ raise ConfigurationException(
+ "Invalid parameter index {} for command {}".format(i, cmd_name)
+ )
+
+ param_value = params[i][0]
+ else:
+ param_value = None
+ for value, param in params:
+ if param.alias == index_or_alias:
+ param_value = value
+ break
+
+ if not param_value:
+ raise ConfigurationException(
+ (
+ "No value for parameter with index or alias {} of command {}"
+ ).format(index_or_alias, cmd_name)
+ )
+ return param_value
+
+ if not index_or_alias.isnumeric():
+ raise ConfigurationException("Bad command index {}".format(index_or_alias))
+
+ i = int(index_or_alias)
+ commands = backend.build_command(cmd_name, backend_params, self.param_resolver)
+ if i not in range(len(commands)):
+ raise ConfigurationException(
+ "Invalid index {} for command {}".format(i, cmd_name)
+ )
+
+ return commands[i]
+
+ def resolve_variables(self, backend_type: str, var_name: str) -> str:
+ """Resolve variable value."""
+ if backend_type == "system":
+ backend = cast(Backend, self.ctx.system)
+ else: # Application or Tool backend
+ backend = cast(Backend, self.ctx.app)
+
+ if var_name not in backend.variables:
+ raise ConfigurationException("Unknown variable {}".format(var_name))
+
+ return backend.variables[var_name]
+
+ def param_matcher(
+ self,
+ param_name: str,
+ cmd_name: Optional[str],
+ resolved_params: Optional[List[Tuple[Optional[str], Param]]],
+ ) -> str:
+ """Regexp to resolve a param from the param_name."""
+ # this pattern supports parameter names like "application.commands.run:0" and
+ # "system.commands.run.params:0"
+ # Note: 'software' is included for backward compatibility.
+ commands_and_params_match = re.match(
+ r"(?P<type>application|software|tool|system)[.]commands[.]"
+ r"(?P<name>\w+)"
+ r"(?P<params>[.]params|)[:]"
+ r"(?P<index_or_alias>\w+)",
+ param_name,
+ )
+
+ if commands_and_params_match:
+ backend_type, cmd_name, return_params, index_or_alias = (
+ commands_and_params_match["type"],
+ commands_and_params_match["name"],
+ commands_and_params_match["params"],
+ commands_and_params_match["index_or_alias"],
+ )
+ return self.resolve_commands_and_params(
+ backend_type, cmd_name, bool(return_params), index_or_alias
+ )
+
+ # Note: 'software' is included for backward compatibility.
+ variables_match = re.match(
+ r"(?P<type>application|software|tool|system)[.]variables:(?P<var_name>\w+)",
+ param_name,
+ )
+ if variables_match:
+ backend_type, var_name = (
+ variables_match["type"],
+ variables_match["var_name"],
+ )
+ return self.resolve_variables(backend_type, var_name)
+
+ user_params_match = re.match(r"user_params:(?P<index_or_alias>\w+)", param_name)
+ if user_params_match:
+ index_or_alias = user_params_match["index_or_alias"]
+ return self.resolve_user_params(cmd_name, index_or_alias, resolved_params)
+
+ raise ConfigurationException(
+ "Unable to resolve parameter {}".format(param_name)
+ )
+
+ def param_resolver(
+ self,
+ param_name: str,
+ cmd_name: Optional[str] = None,
+ resolved_params: Optional[List[Tuple[Optional[str], Param]]] = None,
+ ) -> str:
+ """Resolve parameter value based on current execution context."""
+ # Note: 'software.*' is included for backward compatibility.
+ resolved_param = None
+ if param_name in ["application.name", "tool.name", "software.name"]:
+ resolved_param = self.ctx.app.name
+ elif param_name in [
+ "application.description",
+ "tool.description",
+ "software.description",
+ ]:
+ resolved_param = self.ctx.app.description
+ elif self.ctx.app.config_location and (
+ param_name
+ in ["application.config_dir", "tool.config_dir", "software.config_dir"]
+ ):
+ resolved_param = str(self.ctx.app.config_location.absolute())
+ elif self.ctx.app.build_dir and (
+ param_name
+ in ["application.build_dir", "tool.build_dir", "software.build_dir"]
+ ):
+ resolved_param = str(self.ctx.build_dir().absolute())
+ elif self.ctx.system is not None:
+ if param_name == "system.name":
+ resolved_param = self.ctx.system.name
+ elif param_name == "system.description":
+ resolved_param = self.ctx.system.description
+ elif param_name == "system.config_dir" and self.ctx.system.config_location:
+ resolved_param = str(self.ctx.system.config_location.absolute())
+
+ if not resolved_param:
+ resolved_param = self.param_matcher(param_name, cmd_name, resolved_params)
+ return resolved_param
+
+ def __call__(
+ self,
+ param_name: str,
+ cmd_name: Optional[str] = None,
+ resolved_params: Optional[List[Tuple[Optional[str], Param]]] = None,
+ ) -> str:
+ """Resolve provided parameter."""
+ return self.param_resolver(param_name, cmd_name, resolved_params)
+
+
+class Reporter:
+ """Report metrics from the simulation output."""
+
+ def __init__(self, parsers: Optional[List[OutputParser]] = None) -> None:
+ """Create an empty reporter (i.e. no parsers registered)."""
+ self.parsers: List[OutputParser] = parsers if parsers is not None else []
+ self._report: Dict[str, Any] = defaultdict(lambda: defaultdict(dict))
+
+ def parse(self, output: bytearray) -> None:
+ """Parse output and append parsed metrics to internal report dict."""
+ for parser in self.parsers:
+ # Merge metrics from different parsers (do not overwrite)
+ self._report[parser.name]["metrics"].update(parser(output))
+
+ def get_filtered_output(self, output: bytearray) -> bytearray:
+ """Filter the output according to each parser."""
+ for parser in self.parsers:
+ output = parser.filter_out_parsed_content(output)
+ return output
+
+ def report(self, ctx: ExecutionContext) -> Dict[str, Any]:
+ """Add static simulation info to parsed data and return the report."""
+ report: Dict[str, Any] = defaultdict(dict)
+ # Add static simulation info
+ report.update(self._static_info(ctx))
+ # Add metrics parsed from the output
+ for key, val in self._report.items():
+ report[key].update(val)
+ return report
+
+ @staticmethod
+ def save(report: Dict[str, Any], report_file: Path) -> None:
+ """Save the report to a JSON file."""
+ with open(report_file, "w", encoding="utf-8") as file:
+ json.dump(report, file, indent=4)
+
+ @staticmethod
+ def _compute_all_params(cli_params: List[str], backend: Backend) -> Dict[str, str]:
+ """
+ Build a dict of all parameters, {name:value}.
+
+ Param values taken from command line if specified, defaults otherwise.
+ """
+ # map of params passed from the cli ["p1=v1","p2=v2"] -> {"p1":"v1", "p2":"v2"}
+ app_params_map = dict(parse_raw_parameter(expr) for expr in cli_params)
+
+ # a map of params declared in the application, with values taken from the CLI,
+ # defaults otherwise
+ all_params = {
+ (p.alias or p.name): app_params_map.get(
+ cast(str, p.name), cast(str, p.default_value)
+ )
+ for cmd in backend.commands.values()
+ for p in cmd.params
+ }
+ return cast(Dict[str, str], all_params)
+
+ @staticmethod
+ def _static_info(ctx: ExecutionContext) -> Dict[str, Any]:
+ """Extract static simulation information from the context."""
+ if ctx.system is None:
+ raise ValueError("No system available to report.")
+
+ info = {
+ "system": {
+ "name": ctx.system.name,
+ "params": Reporter._compute_all_params(ctx.system_params, ctx.system),
+ },
+ "application": {
+ "name": ctx.app.name,
+ "params": Reporter._compute_all_params(ctx.app_params, ctx.app),
+ },
+ }
+ return info
+
+
+def validate_parameters(
+ backend: Backend, command_names: List[str], params: List[str]
+) -> None:
+ """Check parameters passed to backend."""
+ for param in params:
+ acceptable = any(
+ backend.validate_parameter(command_name, param)
+ for command_name in command_names
+ if command_name in backend.commands
+ )
+
+ if not acceptable:
+ backend_type = "System" if isinstance(backend, System) else "Application"
+ raise ValueError(
+ "{} parameter '{}' not valid for command '{}'".format(
+ backend_type, param, " or ".join(command_names)
+ )
+ )
+
+
+def get_application_by_name_and_system(
+ application_name: str, system_name: str
+) -> Application:
+ """Get application."""
+ applications = get_application(application_name, system_name)
+ if not applications:
+ raise ValueError(
+ "Application '{}' doesn't support the system '{}'".format(
+ application_name, system_name
+ )
+ )
+
+ if len(applications) != 1:
+ raise ValueError(
+ "Error during getting application {} for the system {}".format(
+ application_name, system_name
+ )
+ )
+
+ return applications[0]
+
+
+def get_application_and_system(
+ application_name: str, system_name: str
+) -> Tuple[Application, System]:
+ """Return application and system by provided names."""
+ system = get_system(system_name)
+ if not system:
+ raise ValueError("System {} is not found".format(system_name))
+
+ application = get_application_by_name_and_system(application_name, system_name)
+
+ return application, system
+
+
+def execute_application_command( # pylint: disable=too-many-arguments
+ command_name: str,
+ application_name: str,
+ application_params: List[str],
+ system_name: str,
+ system_params: List[str],
+ custom_deploy_data: List[DataPaths],
+) -> None:
+ """Execute application command.
+
+ .. deprecated:: 21.12
+ """
+ warnings.warn(
+ "Use 'run_application()' instead. Use of 'execute_application_command()' is "
+ "deprecated and might be removed in a future release.",
+ DeprecationWarning,
+ )
+
+ if command_name not in ["build", "run"]:
+ raise ConfigurationException("Unsupported command {}".format(command_name))
+
+ application, system = get_application_and_system(application_name, system_name)
+ validate_parameters(application, [command_name], application_params)
+ validate_parameters(system, [command_name], system_params)
+
+ ctx = ExecutionContext(
+ app=application,
+ app_params=application_params,
+ system=system,
+ system_params=system_params,
+ custom_deploy_data=custom_deploy_data,
+ )
+
+ if command_name == "run":
+ execute_application_command_run(ctx)
+ else:
+ execute_application_command_build(ctx)
+
+
+# pylint: disable=too-many-arguments
+def run_application(
+ application_name: str,
+ application_params: List[str],
+ system_name: str,
+ system_params: List[str],
+ custom_deploy_data: List[DataPaths],
+ report_file: Optional[Path] = None,
+) -> None:
+ """Run application on the provided system."""
+ application, system = get_application_and_system(application_name, system_name)
+ validate_parameters(application, ["build", "run"], application_params)
+ validate_parameters(system, ["build", "run"], system_params)
+
+ execution_params = ExecutionParams()
+ if isinstance(system, StandaloneSystem):
+ execution_params["disable_locking"] = True
+ execution_params["unique_build_dir"] = True
+
+ ctx = ExecutionContext(
+ app=application,
+ app_params=application_params,
+ system=system,
+ system_params=system_params,
+ custom_deploy_data=custom_deploy_data,
+ execution_params=execution_params,
+ report_file=report_file,
+ )
+
+ with build_dir_manager(ctx):
+ if ctx.is_build_required:
+ execute_application_command_build(ctx)
+
+ execute_application_command_run(ctx)
+
+
+def execute_application_command_build(ctx: ExecutionContext) -> None:
+ """Execute application command 'build'."""
+ with ExitStack() as context_stack:
+ for manager in get_context_managers("build", ctx):
+ context_stack.enter_context(manager(ctx))
+
+ build_dir = ctx.build_dir()
+ recreate_directory(build_dir)
+
+ build_commands = ctx.app.build_command(
+ "build", ctx.app_params, ctx.param_resolver
+ )
+ execute_commands_locally(build_commands, build_dir)
+
+
+def execute_commands_locally(commands: List[str], cwd: Path) -> None:
+ """Execute list of commands locally."""
+ for command in commands:
+ print("Running: {}".format(command))
+ run_and_wait(
+ command, cwd, terminate_on_error=True, out=sys.stdout, err=sys.stderr
+ )
+
+
+def execute_application_command_run(ctx: ExecutionContext) -> None:
+ """Execute application command."""
+ assert ctx.system is not None, "System must be provided."
+ if ctx.is_deploy_needed and not ctx.system.supports_deploy:
+ raise ConfigurationException(
+ "System {} does not support data deploy".format(ctx.system.name)
+ )
+
+ with ExitStack() as context_stack:
+ for manager in get_context_managers("run", ctx):
+ context_stack.enter_context(manager(ctx))
+
+ print("Generating commands to execute")
+ commands_to_run = build_run_commands(ctx)
+
+ if ctx.system.connectable:
+ establish_connection(ctx)
+
+ if ctx.system.supports_deploy:
+ deploy_data(ctx)
+
+ for command in commands_to_run:
+ print("Running: {}".format(command))
+ exit_code, std_output, std_err = ctx.system.run(command)
+
+ if exit_code != 0:
+ print("Application exited with exit code {}".format(exit_code))
+
+ if ctx.reporter:
+ ctx.reporter.parse(std_output)
+ std_output = ctx.reporter.get_filtered_output(std_output)
+
+ print(std_output.decode("utf8"), end="")
+ print(std_err.decode("utf8"), end="")
+
+ if ctx.reporter:
+ report = ctx.reporter.report(ctx)
+ ctx.reporter.save(report, cast(Path, ctx.report_file))
+
+
+def establish_connection(
+ ctx: ExecutionContext, retries: int = 90, interval: float = 15.0
+) -> None:
+ """Establish connection with the system."""
+ assert ctx.system is not None, "System is required."
+ host, port = ctx.system.connection_details()
+ print(
+ "Trying to establish connection with '{}:{}' - "
+ "{} retries every {} seconds ".format(host, port, retries, interval),
+ end="",
+ )
+
+ try:
+ for _ in range(retries):
+ print(".", end="", flush=True)
+
+ if ctx.system.establish_connection():
+ break
+
+ if isinstance(ctx.system, ControlledSystem) and not ctx.system.is_running():
+ print(
+ "\n\n---------- {} execution failed ----------".format(
+ ctx.system.name
+ )
+ )
+ stdout, stderr = ctx.system.get_output()
+ print(stdout)
+ print(stderr)
+
+ raise Exception("System is not running")
+
+ wait(interval)
+ else:
+ raise ConnectionException("Couldn't connect to '{}:{}'.".format(host, port))
+ finally:
+ print()
+
+
+def wait(interval: float) -> None:
+ """Wait for a period of time."""
+ time.sleep(interval)
+
+
+def deploy_data(ctx: ExecutionContext) -> None:
+ """Deploy data to the system."""
+ if isinstance(ctx.app, Application):
+ # Only application can deploy data (tools can not)
+ assert ctx.system is not None, "System is required."
+ for item in itertools.chain(ctx.app.get_deploy_data(), ctx.custom_deploy_data):
+ print("Deploying {} onto {}".format(item.src, item.dst))
+ ctx.system.deploy(item.src, item.dst)
+
+
+def build_run_commands(ctx: ExecutionContext) -> List[str]:
+ """Build commands to run application."""
+ if isinstance(ctx.system, StandaloneSystem):
+ return ctx.system.build_command("run", ctx.system_params, ctx.param_resolver)
+
+ return ctx.app.build_command("run", ctx.app_params, ctx.param_resolver)
+
+
+@contextmanager
+def controlled_system_manager(ctx: ExecutionContext) -> Generator[None, None, None]:
+ """Context manager used for system initialisation before run."""
+ system = cast(ControlledSystem, ctx.system)
+ commands = system.build_command("run", ctx.system_params, ctx.param_resolver)
+ pid_file_path: Optional[Path] = None
+ if ctx.is_locking_required:
+ file_lock_path = get_file_lock_path(ctx)
+ pid_file_path = file_lock_path.parent / "{}.pid".format(file_lock_path.stem)
+
+ system.start(commands, ctx.is_locking_required, pid_file_path)
+ try:
+ yield
+ finally:
+ print("Shutting down sequence...")
+ print("Stopping {}... (It could take few seconds)".format(system.name))
+ system.stop(wait=True)
+ print("{} stopped successfully.".format(system.name))
+
+
+@contextmanager
+def lock_execution_manager(ctx: ExecutionContext) -> Generator[None, None, None]:
+ """Lock execution manager."""
+ file_lock_path = get_file_lock_path(ctx)
+ file_lock = FileLock(str(file_lock_path))
+
+ try:
+ file_lock.acquire(timeout=1)
+ except Timeout as error:
+ raise AnotherInstanceIsRunningException() from error
+
+ try:
+ yield
+ finally:
+ file_lock.release()
+
+
+def get_file_lock_path(ctx: ExecutionContext, lock_dir: Path = Path("/tmp")) -> Path:
+ """Get file lock path."""
+ lock_modules = []
+ if ctx.app.lock:
+ lock_modules.append(ctx.app.name)
+ if ctx.system is not None and ctx.system.lock:
+ lock_modules.append(ctx.system.name)
+ lock_filename = ""
+ if lock_modules:
+ lock_filename = "_".join(["middleware"] + lock_modules) + ".lock"
+
+ if lock_filename:
+ lock_filename = resolve_all_parameters(lock_filename, ctx.param_resolver)
+ lock_filename = valid_for_filename(lock_filename)
+
+ if not lock_filename:
+ raise ConfigurationException("No filename for lock provided")
+
+ if not isinstance(lock_dir, Path) or not lock_dir.is_dir():
+ raise ConfigurationException(
+ "Invalid directory {} for lock files provided".format(lock_dir)
+ )
+
+ return lock_dir / lock_filename
+
+
+@contextmanager
+def build_dir_manager(ctx: ExecutionContext) -> Generator[None, None, None]:
+ """Build directory manager."""
+ try:
+ yield
+ finally:
+ if (
+ ctx.is_build_required
+ and ctx.is_unique_build_dir_required
+ and ctx.build_dir().is_dir()
+ ):
+ remove_directory(ctx.build_dir())
+
+
+def get_context_managers(
+ command_name: str, ctx: ExecutionContext
+) -> Sequence[Callable[[ExecutionContext], ContextManager[None]]]:
+ """Get context manager for the system."""
+ managers = []
+
+ if ctx.is_locking_required:
+ managers.append(lock_execution_manager)
+
+ if command_name == "run":
+ if isinstance(ctx.system, ControlledSystem):
+ managers.append(controlled_system_manager)
+
+ return managers
+
+
+def get_tool_by_system(tool_name: str, system_name: Optional[str]) -> Tool:
+ """Return tool (optionally by provided system name."""
+ tools = get_tool(tool_name, system_name)
+ if not tools:
+ raise ConfigurationException(
+ "Tool '{}' not found or doesn't support the system '{}'".format(
+ tool_name, system_name
+ )
+ )
+ if len(tools) != 1:
+ raise ConfigurationException(
+ "Please specify the system for tool {}.".format(tool_name)
+ )
+ tool = tools[0]
+
+ return tool
+
+
+def execute_tool_command(
+ tool_name: str,
+ tool_params: List[str],
+ system_name: Optional[str] = None,
+) -> None:
+ """Execute the tool command locally calling the 'run' command."""
+ tool = get_tool_by_system(tool_name, system_name)
+ ctx = ExecutionContext(
+ app=tool, app_params=tool_params, system=None, system_params=[]
+ )
+ commands = tool.build_command("run", tool_params, ctx.param_resolver)
+
+ execute_commands_locally(commands, Path.cwd())
diff --git a/src/aiet/backend/output_parser.py b/src/aiet/backend/output_parser.py
new file mode 100644
index 0000000..111772a
--- /dev/null
+++ b/src/aiet/backend/output_parser.py
@@ -0,0 +1,176 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Definition of output parsers (including base class OutputParser)."""
+import base64
+import json
+import re
+from abc import ABC
+from abc import abstractmethod
+from typing import Any
+from typing import Dict
+from typing import Union
+
+
+class OutputParser(ABC):
+ """Abstract base class for output parsers."""
+
+ def __init__(self, name: str) -> None:
+ """Set up the name of the parser."""
+ super().__init__()
+ self.name = name
+
+ @abstractmethod
+ def __call__(self, output: bytearray) -> Dict[str, Any]:
+ """Parse the output and return a map of names to metrics."""
+ return {}
+
+ # pylint: disable=no-self-use
+ def filter_out_parsed_content(self, output: bytearray) -> bytearray:
+ """
+ Filter out the parsed content from the output.
+
+ Does nothing by default. Can be overridden in subclasses.
+ """
+ return output
+
+
+class RegexOutputParser(OutputParser):
+ """Parser of standard output data using regular expressions."""
+
+ _TYPE_MAP = {"str": str, "float": float, "int": int}
+
+ def __init__(
+ self,
+ name: str,
+ regex_config: Dict[str, Dict[str, str]],
+ ) -> None:
+ """
+ Set up the parser with the regular expressions.
+
+ The regex_config is mapping from a name to a dict with keys 'pattern'
+ and 'type':
+ - The 'pattern' holds the regular expression that must contain exactly
+ one capturing parenthesis
+ - The 'type' can be one of ['str', 'float', 'int'].
+
+ Example:
+ ```
+ {"Metric1": {"pattern": ".*= *(.*)", "type": "str"}}
+ ```
+
+ The different regular expressions from the config are combined using
+ non-capturing parenthesis, i.e. regular expressions must not overlap
+ if more than one match per line is expected.
+ """
+ super().__init__(name)
+
+ self._verify_config(regex_config)
+ self._regex_cfg = regex_config
+
+ # Compile regular expression to match in the output
+ self._regex = re.compile(
+ "|".join("(?:{0})".format(x["pattern"]) for x in self._regex_cfg.values())
+ )
+
+ def __call__(self, output: bytearray) -> Dict[str, Union[str, float, int]]:
+ """
+ Parse the output and return a map of names to metrics.
+
+ Example:
+ Assuming a regex_config as used as example in `__init__()` and the
+ following output:
+ ```
+ Simulation finished:
+ SIMULATION_STATUS = SUCCESS
+ Simulation DONE
+ ```
+ Then calling the parser should return the following dict:
+ ```
+ {
+ "Metric1": "SUCCESS"
+ }
+ ```
+ """
+ metrics = {}
+ output_str = output.decode("utf-8")
+ results = self._regex.findall(output_str)
+ for line_result in results:
+ for idx, (name, cfg) in enumerate(self._regex_cfg.items()):
+ # The result(s) returned by findall() are either a single string
+ # or a tuple (depending on the number of groups etc.)
+ result = (
+ line_result if isinstance(line_result, str) else line_result[idx]
+ )
+ if result:
+ mapped_result = self._TYPE_MAP[cfg["type"]](result)
+ metrics[name] = mapped_result
+ return metrics
+
+ def _verify_config(self, regex_config: Dict[str, Dict[str, str]]) -> None:
+ """Make sure we have a valid regex_config.
+
+ I.e.
+ - Exactly one capturing parenthesis per pattern
+ - Correct types
+ """
+ for name, cfg in regex_config.items():
+ # Check that there is one capturing group defined in the pattern.
+ regex = re.compile(cfg["pattern"])
+ if regex.groups != 1:
+ raise ValueError(
+ f"Pattern for metric '{name}' must have exactly one "
+ f"capturing parenthesis, but it has {regex.groups}."
+ )
+ # Check if type is supported
+ if not cfg["type"] in self._TYPE_MAP:
+ raise TypeError(
+ f"Type '{cfg['type']}' for metric '{name}' is not "
+ f"supported. Choose from: {list(self._TYPE_MAP.keys())}."
+ )
+
+
+class Base64OutputParser(OutputParser):
+ """
+ Parser to extract base64-encoded JSON from tagged standard output.
+
+ Example of the tagged output:
+ ```
+ # Encoded JSON: {"test": 1234}
+ <metrics>eyJ0ZXN0IjogMTIzNH0</metrics>
+ ```
+ """
+
+ TAG_NAME = "metrics"
+
+ def __init__(self, name: str) -> None:
+ """Set up the regular expression to extract tagged strings."""
+ super().__init__(name)
+ self._regex = re.compile(rf"<{self.TAG_NAME}>(.*)</{self.TAG_NAME}>")
+
+ def __call__(self, output: bytearray) -> Dict[str, Any]:
+ """
+ Parse the output and return a map of index (as string) to decoded JSON.
+
+ Example:
+ Using the tagged output from the class docs the parser should return
+ the following dict:
+ ```
+ {
+ "0": {"test": 1234}
+ }
+ ```
+ """
+ metrics = {}
+ output_str = output.decode("utf-8")
+ results = self._regex.findall(output_str)
+ for idx, result_base64 in enumerate(results):
+ result_json = base64.b64decode(result_base64, validate=True)
+ result = json.loads(result_json)
+ metrics[str(idx)] = result
+
+ return metrics
+
+ def filter_out_parsed_content(self, output: bytearray) -> bytearray:
+ """Filter out base64-encoded content from the output."""
+ output_str = self._regex.sub("", output.decode("utf-8"))
+ return bytearray(output_str.encode("utf-8"))
diff --git a/src/aiet/backend/protocol.py b/src/aiet/backend/protocol.py
new file mode 100644
index 0000000..c621436
--- /dev/null
+++ b/src/aiet/backend/protocol.py
@@ -0,0 +1,325 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Contain protocol related classes and functions."""
+from abc import ABC
+from abc import abstractmethod
+from contextlib import closing
+from pathlib import Path
+from typing import Any
+from typing import cast
+from typing import Iterable
+from typing import Optional
+from typing import Tuple
+from typing import Union
+
+import paramiko
+
+from aiet.backend.common import ConfigurationException
+from aiet.backend.config import LocalProtocolConfig
+from aiet.backend.config import SSHConfig
+from aiet.utils.proc import run_and_wait
+
+
+# Redirect all paramiko thread exceptions to a file otherwise these will be
+# printed to stderr.
+paramiko.util.log_to_file("/tmp/main_paramiko_log.txt", level=paramiko.common.INFO)
+
+
+class SSHConnectionException(Exception):
+ """SSH connection exception."""
+
+
+class SupportsClose(ABC):
+ """Class indicates support of close operation."""
+
+ @abstractmethod
+ def close(self) -> None:
+ """Close protocol session."""
+
+
+class SupportsDeploy(ABC):
+ """Class indicates support of deploy operation."""
+
+ @abstractmethod
+ def deploy(self, src: Path, dst: str, retry: bool = True) -> None:
+ """Abstract method for deploy data."""
+
+
+class SupportsConnection(ABC):
+ """Class indicates that protocol uses network connections."""
+
+ @abstractmethod
+ def establish_connection(self) -> bool:
+ """Establish connection with underlying system."""
+
+ @abstractmethod
+ def connection_details(self) -> Tuple[str, int]:
+ """Return connection details (host, port)."""
+
+
+class Protocol(ABC):
+ """Abstract class for representing the protocol."""
+
+ def __init__(self, iterable: Iterable = (), **kwargs: Any) -> None:
+ """Initialize the class using a dict."""
+ self.__dict__.update(iterable, **kwargs)
+ self._validate()
+
+ @abstractmethod
+ def _validate(self) -> None:
+ """Abstract method for config data validation."""
+
+ @abstractmethod
+ def run(
+ self, command: str, retry: bool = False
+ ) -> Tuple[int, bytearray, bytearray]:
+ """
+ Abstract method for running commands.
+
+ Returns a tuple: (exit_code, stdout, stderr)
+ """
+
+
+class CustomSFTPClient(paramiko.SFTPClient):
+ """Class for creating a custom sftp client."""
+
+ def put_dir(self, source: Path, target: str) -> None:
+ """Upload the source directory to the target path.
+
+        The target directory needs to exist, and the last directory of the
+        source will be created under the target with all its content.
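+
+        Example (illustrative): put_dir(Path("/tmp/data"), "/home/user")
+        recreates /home/user/data on the remote host with all its content.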
+ """
+ # Create the target directory
+ self._mkdir(target, ignore_existing=True)
+ # Create the last directory in the source on the target
+ self._mkdir("{}/{}".format(target, source.name), ignore_existing=True)
+ # Go through the whole content of source
+ for item in sorted(source.glob("**/*")):
+ relative_path = item.relative_to(source.parent)
+ remote_target = target / relative_path
+ if item.is_file():
+ self.put(str(item), str(remote_target))
+ else:
+ self._mkdir(str(remote_target), ignore_existing=True)
+
+ def _mkdir(self, path: str, mode: int = 511, ignore_existing: bool = False) -> None:
+ """Extend mkdir functionality.
+
+ This version adds an option to not fail if the folder exists.
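+        The default mode 511 (decimal) equals 0o777, matching paramiko's default.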
+ """
+ try:
+ super().mkdir(path, mode)
+ except IOError as error:
+ if ignore_existing:
+ pass
+ else:
+ raise error
+
+
+class LocalProtocol(Protocol):
+ """Class for local protocol."""
+
+ protocol: str
+ cwd: Path
+
+ def run(
+ self, command: str, retry: bool = False
+ ) -> Tuple[int, bytearray, bytearray]:
+ """
+ Run command locally.
+
+ Returns a tuple: (exit_code, stdout, stderr)
+ """
+ if not isinstance(self.cwd, Path) or not self.cwd.is_dir():
+ raise ConfigurationException("Wrong working directory {}".format(self.cwd))
+
+ stdout = bytearray()
+ stderr = bytearray()
+
+ return run_and_wait(
+ command, self.cwd, terminate_on_error=True, out=stdout, err=stderr
+ )
+
+ def _validate(self) -> None:
+ """Validate protocol configuration."""
+ assert hasattr(self, "protocol") and self.protocol == "local"
+ assert hasattr(self, "cwd")
+
+
+class SSHProtocol(Protocol, SupportsClose, SupportsDeploy, SupportsConnection):
+ """Class for SSH protocol."""
+
+ protocol: str
+ username: str
+ password: str
+ hostname: str
+ port: int
+
+ def __init__(self, iterable: Iterable = (), **kwargs: Any) -> None:
+ """Initialize the class using a dict."""
+ super().__init__(iterable, **kwargs)
+        # Internal state: the SSH client is created lazily on the first
+        # successful connection attempt.
+ self.client: Optional[paramiko.client.SSHClient] = None
+ self.port = int(self.port)
+
+ def run(self, command: str, retry: bool = True) -> Tuple[int, bytearray, bytearray]:
+ """
+ Run command over SSH.
+
+ Returns a tuple: (exit_code, stdout, stderr)
+ """
+ transport = self._get_transport()
+ with closing(transport.open_session()) as channel:
+ # Enable shell's .profile settings and execute command
+ channel.exec_command("bash -l -c '{}'".format(command))
+ exit_status = -1
+ stdout = bytearray()
+ stderr = bytearray()
+ while True:
+ if channel.exit_status_ready():
+ exit_status = channel.recv_exit_status()
+ # Call it one last time to read any leftover in the channel
+ self._recv_stdout_err(channel, stdout, stderr)
+ break
+ self._recv_stdout_err(channel, stdout, stderr)
+
+ return exit_status, stdout, stderr
+
+ def deploy(self, src: Path, dst: str, retry: bool = True) -> None:
+ """Deploy src to remote dst over SSH.
+
+ src and dst should be path to a file or directory.
+ """
+ transport = self._get_transport()
+ sftp = cast(CustomSFTPClient, CustomSFTPClient.from_transport(transport))
+
+ with closing(sftp):
+ if src.is_dir():
+ sftp.put_dir(src, dst)
+ elif src.is_file():
+ sftp.put(str(src), dst)
+ else:
+ raise Exception("Deploy error: file type not supported")
+
+ # After the deployment of files, sync the remote filesystem to flush
+ # buffers to hard disk
+ self.run("sync")
+
+ def close(self) -> None:
+ """Close protocol session."""
+ if self.client is not None:
+ print("Try syncing remote file system...")
+ # Before stopping the system, we try to run sync to make sure all
+ # data are flushed on disk.
+ self.run("sync", retry=False)
+ self._close_client(self.client)
+
+ def establish_connection(self) -> bool:
+ """Establish connection with underlying system."""
+ if self.client is not None:
+ return True
+
+ self.client = self._connect()
+ return self.client is not None
+
+ def _get_transport(self) -> paramiko.transport.Transport:
+ """Get transport."""
+ self.establish_connection()
+
+ if self.client is None:
+ raise SSHConnectionException(
+ "Couldn't connect to '{}:{}'.".format(self.hostname, self.port)
+ )
+
+ transport = self.client.get_transport()
+ if not transport:
+ raise Exception("Unable to get transport")
+
+ return transport
+
+ def connection_details(self) -> Tuple[str, int]:
+ """Return connection details of underlying system."""
+ return (self.hostname, self.port)
+
+ def _connect(self) -> Optional[paramiko.client.SSHClient]:
+ """Try to establish connection."""
+ client: Optional[paramiko.client.SSHClient] = None
+ try:
+ client = paramiko.client.SSHClient()
+ client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
+ client.connect(
+ self.hostname,
+ self.port,
+ self.username,
+ self.password,
+ # the next parameters must be set to False to disable
+ # authentication via SSH keys
+ allow_agent=False,
+ look_for_keys=False,
+ )
+ return client
+ except (
+ # OSError raised on first attempt to connect when running inside Docker
+ OSError,
+ paramiko.ssh_exception.NoValidConnectionsError,
+ paramiko.ssh_exception.SSHException,
+ ):
+ # even if the connection is not established, the socket could
+ # still be open and should be closed
+ self._close_client(client)
+
+ return None
+
+ @staticmethod
+ def _close_client(client: Optional[paramiko.client.SSHClient]) -> None:
+ """Close ssh client."""
+ try:
+ if client is not None:
+ client.close()
+ except Exception: # pylint: disable=broad-except
+ pass
+
+ @classmethod
+ def _recv_stdout_err(
+ cls, channel: paramiko.channel.Channel, stdout: bytearray, stderr: bytearray
+ ) -> None:
+ """Read from channel to stdout/stder."""
+ chunk_size = 512
+ if channel.recv_ready():
+ stdout_chunk = channel.recv(chunk_size)
+ stdout.extend(stdout_chunk)
+ if channel.recv_stderr_ready():
+ stderr_chunk = channel.recv_stderr(chunk_size)
+ stderr.extend(stderr_chunk)
+
+ def _validate(self) -> None:
+ """Check if there are all the info for establishing the connection."""
+ assert hasattr(self, "protocol") and self.protocol == "ssh"
+ assert hasattr(self, "username")
+ assert hasattr(self, "password")
+ assert hasattr(self, "hostname")
+ assert hasattr(self, "port")
+
+
+class ProtocolFactory:
+ """Factory class to return the appropriate Protocol class."""
+
+ @staticmethod
+ def get_protocol(
+ config: Optional[Union[SSHConfig, LocalProtocolConfig]],
+ **kwargs: Union[str, Path, None]
+ ) -> Union[SSHProtocol, LocalProtocol]:
+ """Return the right protocol instance based on the config."""
+ if not config:
+ raise ValueError("No protocol config provided")
+
+ protocol = config["protocol"]
+ if protocol == "ssh":
+ return SSHProtocol(config)
+
+ if protocol == "local":
+ cwd = kwargs.get("cwd")
+ return LocalProtocol(config, cwd=cwd)
+
+ raise ValueError("Protocol not supported: '{}'".format(protocol))
diff --git a/src/aiet/backend/source.py b/src/aiet/backend/source.py
new file mode 100644
index 0000000..dec175a
--- /dev/null
+++ b/src/aiet/backend/source.py
@@ -0,0 +1,209 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Contain source related classes and functions."""
+import os
+import shutil
+import tarfile
+from abc import ABC
+from abc import abstractmethod
+from pathlib import Path
+from tarfile import TarFile
+from typing import Optional
+from typing import Union
+
+from aiet.backend.common import AIET_CONFIG_FILE
+from aiet.backend.common import ConfigurationException
+from aiet.backend.common import get_backend_config
+from aiet.backend.common import is_backend_directory
+from aiet.backend.common import load_config
+from aiet.backend.config import BackendConfig
+from aiet.utils.fs import copy_directory_content
+
+
+class Source(ABC):
+ """Source class."""
+
+ @abstractmethod
+ def name(self) -> Optional[str]:
+ """Get source name."""
+
+ @abstractmethod
+ def config(self) -> Optional[BackendConfig]:
+ """Get configuration file content."""
+
+ @abstractmethod
+ def install_into(self, destination: Path) -> None:
+ """Install source into destination directory."""
+
+ @abstractmethod
+ def create_destination(self) -> bool:
+ """Return True if destination folder should be created before installation."""
+
+
+class DirectorySource(Source):
+ """DirectorySource class."""
+
+ def __init__(self, directory_path: Path) -> None:
+ """Create the DirectorySource instance."""
+ assert isinstance(directory_path, Path)
+ self.directory_path = directory_path
+
+ def name(self) -> str:
+ """Return name of source."""
+ return self.directory_path.name
+
+ def config(self) -> Optional[BackendConfig]:
+ """Return configuration file content."""
+ if not is_backend_directory(self.directory_path):
+ raise ConfigurationException("No configuration file found")
+
+ config_file = get_backend_config(self.directory_path)
+ return load_config(config_file)
+
+ def install_into(self, destination: Path) -> None:
+ """Install source into destination directory."""
+ if not destination.is_dir():
+ raise ConfigurationException("Wrong destination {}".format(destination))
+
+ if not self.directory_path.is_dir():
+ raise ConfigurationException(
+ "Directory {} does not exist".format(self.directory_path)
+ )
+
+ copy_directory_content(self.directory_path, destination)
+
+ def create_destination(self) -> bool:
+ """Return True if destination folder should be created before installation."""
+ return True
+
+
+class TarArchiveSource(Source):
+ """TarArchiveSource class."""
+
+ def __init__(self, archive_path: Path) -> None:
+ """Create the TarArchiveSource class."""
+ assert isinstance(archive_path, Path)
+ self.archive_path = archive_path
+ self._config: Optional[BackendConfig] = None
+ self._has_top_level_folder: Optional[bool] = None
+ self._name: Optional[str] = None
+
+ def _read_archive_content(self) -> None:
+ """Read various information about archive."""
+ # get source name from archive name (everything without extensions);
+ # slicing is used because str.rstrip() strips a character set, not a suffix
+ extensions = "".join(self.archive_path.suffixes)
+ self._name = self.archive_path.name
+ if extensions:
+ self._name = self.archive_path.name[: -len(extensions)]
+
+ if not self.archive_path.exists():
+ return
+
+ with self._open(self.archive_path) as archive:
+ try:
+ config_entry = archive.getmember(AIET_CONFIG_FILE)
+ self._has_top_level_folder = False
+ except KeyError as error_no_config:
+ try:
+ archive_entries = archive.getnames()
+ entries_common_prefix = os.path.commonprefix(archive_entries)
+ top_level_dir = entries_common_prefix.rstrip("/")
+
+ if not top_level_dir:
+ raise RuntimeError(
+ "Archive has no top level directory"
+ ) from error_no_config
+
+ config_path = "{}/{}".format(top_level_dir, AIET_CONFIG_FILE)
+
+ config_entry = archive.getmember(config_path)
+ self._has_top_level_folder = True
+ self._name = top_level_dir
+ except (KeyError, RuntimeError) as error_no_root_dir_or_config:
+ raise ConfigurationException(
+ "No configuration file found"
+ ) from error_no_root_dir_or_config
+
+ content = archive.extractfile(config_entry)
+ self._config = load_config(content)
+
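+ # The two archive layouts handled by _read_archive_content (names are
+ # hypothetical): either aiet-config.json sits at the archive root
+ # (backend.tar.gz -> aiet-config.json, ...), or everything is wrapped
+ # in a single top level folder (backend.tar.gz ->
+ # backend/aiet-config.json, ...). In the latter case the folder name
+ # becomes the source name and no extra destination directory is
+ # created during installation.
+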
+ def config(self) -> Optional[BackendConfig]:
+ """Return configuration file content."""
+ if self._config is None:
+ self._read_archive_content()
+
+ return self._config
+
+ def name(self) -> Optional[str]:
+ """Return name of the source."""
+ if self._name is None:
+ self._read_archive_content()
+
+ return self._name
+
+ def create_destination(self) -> bool:
+ """Return True if destination folder must be created before installation."""
+ if self._has_top_level_folder is None:
+ self._read_archive_content()
+
+ return not self._has_top_level_folder
+
+ def install_into(self, destination: Path) -> None:
+ """Install source into destination directory."""
+ if not destination.is_dir():
+ raise ConfigurationException("Wrong destination {}".format(destination))
+
+ with self._open(self.archive_path) as archive:
+ archive.extractall(destination)
+
+ def _open(self, archive_path: Path) -> TarFile:
+ """Open archive file."""
+ if not archive_path.is_file():
+ raise ConfigurationException("File {} does not exist".format(archive_path))
+
+ if archive_path.name.endswith("tar.gz") or archive_path.name.endswith("tgz"):
+ mode = "r:gz"
+ else:
+ raise ConfigurationException(
+ "Unsupported archive type {}".format(archive_path)
+ )
+
+ # The returned TarFile object can be used as a context manager (using
+ # 'with') by the calling instance.
+ return tarfile.open( # pylint: disable=consider-using-with
+ self.archive_path, mode=mode
+ )
+
+
+def get_source(source_path: Path) -> Union[TarArchiveSource, DirectorySource]:
+ """Return appropriate source instance based on provided source path."""
+ if source_path.is_file():
+ return TarArchiveSource(source_path)
+
+ if source_path.is_dir():
+ return DirectorySource(source_path)
+
+ raise ConfigurationException("Unable to read {}".format(source_path))
+
+
+def create_destination_and_install(source: Source, resource_path: Path) -> None:
+ """Create destination directory and install source.
+
+ This function is used for the actual installation of a system/backend.
+ A new directory is created inside :resource_path: if needed; if, for
+ example, the archive contains a top level folder, no new directory is
+ created.
+ """
+ destination = resource_path
+ create_destination = source.create_destination()
+
+ if create_destination:
+ name = source.name()
+ if not name:
+ raise ConfigurationException("Unable to get source name")
+
+ destination = resource_path / name
+ destination.mkdir()
+ try:
+ source.install_into(destination)
+ except Exception as error:
+ if create_destination:
+ shutil.rmtree(destination)
+ raise error
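+
+
+# A sketch of the expected flow (paths are illustrative; get_resources
+# lives in aiet.utils.fs):
+#
+# source = get_source(Path("downloads/my-backend.tar.gz"))
+# create_destination_and_install(source, get_resources("systems"))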
diff --git a/src/aiet/backend/system.py b/src/aiet/backend/system.py
new file mode 100644
index 0000000..48f1bb1
--- /dev/null
+++ b/src/aiet/backend/system.py
@@ -0,0 +1,289 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""System backend module."""
+from pathlib import Path
+from typing import Any
+from typing import cast
+from typing import Dict
+from typing import List
+from typing import Optional
+from typing import Tuple
+from typing import Union
+
+from aiet.backend.common import Backend
+from aiet.backend.common import ConfigurationException
+from aiet.backend.common import get_backend_configs
+from aiet.backend.common import get_backend_directories
+from aiet.backend.common import load_config
+from aiet.backend.common import remove_backend
+from aiet.backend.config import SystemConfig
+from aiet.backend.controller import SystemController
+from aiet.backend.controller import SystemControllerSingleInstance
+from aiet.backend.protocol import ProtocolFactory
+from aiet.backend.protocol import SupportsClose
+from aiet.backend.protocol import SupportsConnection
+from aiet.backend.protocol import SupportsDeploy
+from aiet.backend.source import create_destination_and_install
+from aiet.backend.source import get_source
+from aiet.utils.fs import get_resources
+
+
+def get_available_systems_directory_names() -> List[str]:
+ """Return a list of directory names for all avialable systems."""
+ return [entry.name for entry in get_backend_directories("systems")]
+
+
+def get_available_systems() -> List["System"]:
+ """Return a list with all available systems."""
+ available_systems = []
+ for config_json in get_backend_configs("systems"):
+ config_entries = cast(List[SystemConfig], (load_config(config_json)))
+ for config_entry in config_entries:
+ config_entry["config_location"] = config_json.parent.absolute()
+ system = load_system(config_entry)
+ available_systems.append(system)
+
+ return sorted(available_systems, key=lambda system: system.name)
+
+
+def get_system(system_name: str) -> Optional["System"]:
+ """Return a system instance with the same name passed as argument."""
+ available_systems = get_available_systems()
+ for system in available_systems:
+ if system_name == system.name:
+ return system
+ return None
+
+
+def install_system(source_path: Path) -> None:
+ """Install new system."""
+ try:
+ source = get_source(source_path)
+ config = cast(List[SystemConfig], source.config())
+ systems_to_install = [load_system(entry) for entry in config]
+ except Exception as error:
+ raise ConfigurationException("Unable to read system definition") from error
+
+ if not systems_to_install:
+ raise ConfigurationException("No system definition found")
+
+ available_systems = get_available_systems()
+ already_installed = [s for s in systems_to_install if s in available_systems]
+ if already_installed:
+ names = [system.name for system in already_installed]
+ raise ConfigurationException(
+ "Systems [{}] are already installed".format(",".join(names))
+ )
+
+ create_destination_and_install(source, get_resources("systems"))
+
+
+def remove_system(directory_name: str) -> None:
+ """Remove system."""
+ remove_backend(directory_name, "systems")
+
+
+class System(Backend):
+ """System class."""
+
+ def __init__(self, config: SystemConfig) -> None:
+ """Construct the System class using the dictionary passed."""
+ super().__init__(config)
+
+ self._setup_data_transfer(config)
+ self._setup_reporting(config)
+
+ def _setup_data_transfer(self, config: SystemConfig) -> None:
+ data_transfer_config = config.get("data_transfer")
+ protocol = ProtocolFactory().get_protocol(
+ data_transfer_config, cwd=self.config_location
+ )
+ self.protocol = protocol
+
+ def _setup_reporting(self, config: SystemConfig) -> None:
+ self.reporting = config.get("reporting")
+
+ def run(self, command: str, retry: bool = True) -> Tuple[int, bytearray, bytearray]:
+ """
+ Run command on the system.
+
+ Returns a tuple: (exit_code, stdout, stderr)
+ """
+ return self.protocol.run(command, retry)
+
+ def deploy(self, src: Path, dst: str, retry: bool = True) -> None:
+ """Deploy files to the system."""
+ if isinstance(self.protocol, SupportsDeploy):
+ self.protocol.deploy(src, dst, retry)
+
+ @property
+ def supports_deploy(self) -> bool:
+ """Check if protocol supports deploy operation."""
+ return isinstance(self.protocol, SupportsDeploy)
+
+ @property
+ def connectable(self) -> bool:
+ """Check if protocol supports connection."""
+ return isinstance(self.protocol, SupportsConnection)
+
+ def establish_connection(self) -> bool:
+ """Establish connection with the system."""
+ if not isinstance(self.protocol, SupportsConnection):
+ raise ConfigurationException(
+ "System {} does not support connections".format(self.name)
+ )
+
+ return self.protocol.establish_connection()
+
+ def connection_details(self) -> Tuple[str, int]:
+ """Return connection details."""
+ if not isinstance(self.protocol, SupportsConnection):
+ raise ConfigurationException(
+ "System {} does not support connections".format(self.name)
+ )
+
+ return self.protocol.connection_details()
+
+ def __eq__(self, other: object) -> bool:
+ """Overload operator ==."""
+ if not isinstance(other, System):
+ return False
+
+ return super().__eq__(other) and self.name == other.name
+
+ def get_details(self) -> Dict[str, Any]:
+ """Return a dictionary with all relevant information of a System."""
+ output = {
+ "type": "system",
+ "name": self.name,
+ "description": self.description,
+ "data_transfer_protocol": self.protocol.protocol,
+ "commands": self._get_command_details(),
+ "annotations": self.annotations,
+ }
+
+ return output
+
+
+class StandaloneSystem(System):
+ """StandaloneSystem class."""
+
+
+def get_controller(
+ single_instance: bool, pid_file_path: Optional[Path] = None
+) -> SystemController:
+ """Get system controller."""
+ if single_instance:
+ return SystemControllerSingleInstance(pid_file_path)
+
+ return SystemController()
+
+
+class ControlledSystem(System):
+ """ControlledSystem class."""
+
+ def __init__(self, config: SystemConfig):
+ """Construct the ControlledSystem class using the dictionary passed."""
+ super().__init__(config)
+ self.controller: Optional[SystemController] = None
+
+ def start(
+ self,
+ commands: List[str],
+ single_instance: bool = True,
+ pid_file_path: Optional[Path] = None,
+ ) -> None:
+ """Launch the system."""
+ if (
+ not isinstance(self.config_location, Path)
+ or not self.config_location.is_dir()
+ ):
+ raise ConfigurationException(
+ "System {} has wrong config location".format(self.name)
+ )
+
+ self.controller = get_controller(single_instance, pid_file_path)
+ self.controller.start(commands, self.config_location)
+
+ def is_running(self) -> bool:
+ """Check if system is running."""
+ if not self.controller:
+ return False
+
+ return self.controller.is_running()
+
+ def get_output(self) -> Tuple[str, str]:
+ """Return system output."""
+ if not self.controller:
+ return "", ""
+
+ return self.controller.get_output()
+
+ def stop(self, wait: bool = False) -> None:
+ """Stop the system."""
+ if not self.controller:
+ raise Exception("System has not been started")
+
+ if isinstance(self.protocol, SupportsClose):
+ try:
+ self.protocol.close()
+ except Exception as error: # pylint: disable=broad-except
+ print(error)
+ self.controller.stop(wait)
+
+
+def load_system(config: SystemConfig) -> Union[StandaloneSystem, ControlledSystem]:
+ """Load system based on it's execution type."""
+ data_transfer = config.get("data_transfer", {})
+ protocol = data_transfer.get("protocol")
+ populate_shared_params(config)
+
+ if protocol == "ssh":
+ return ControlledSystem(config)
+
+ if protocol == "local":
+ return StandaloneSystem(config)
+
+ raise ConfigurationException(
+ "Unsupported execution type for protocol {}".format(protocol)
+ )
+
+
+def populate_shared_params(config: SystemConfig) -> None:
+ """Populate command parameters with shared parameters."""
+ user_params = config.get("user_params")
+ if not user_params or "shared" not in user_params:
+ return
+
+ shared_user_params = user_params["shared"]
+ if not shared_user_params:
+ return
+
+ only_aliases = all(p.get("alias") for p in shared_user_params)
+ if not only_aliases:
+ raise ConfigurationException("All shared parameters should have aliases")
+
+ commands = config.get("commands", {})
+ for cmd_name in ["build", "run"]:
+ command = commands.get(cmd_name)
+ if command is None:
+ commands[cmd_name] = []
+ cmd_user_params = user_params.get(cmd_name)
+ if not cmd_user_params:
+ cmd_user_params = shared_user_params
+ else:
+ only_aliases = all(p.get("alias") for p in cmd_user_params)
+ if not only_aliases:
+ raise ConfigurationException(
+ "All parameters for command {} should have aliases".format(cmd_name)
+ )
+ merged_by_alias = {
+ **{p.get("alias"): p for p in shared_user_params},
+ **{p.get("alias"): p for p in cmd_user_params},
+ }
+ cmd_user_params = list(merged_by_alias.values())
+
+ user_params[cmd_name] = cmd_user_params
+
+ config["commands"] = commands
+ del user_params["shared"]
diff --git a/src/aiet/backend/tool.py b/src/aiet/backend/tool.py
new file mode 100644
index 0000000..d643665
--- /dev/null
+++ b/src/aiet/backend/tool.py
@@ -0,0 +1,109 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Tool backend module."""
+from typing import Any
+from typing import cast
+from typing import Dict
+from typing import List
+from typing import Optional
+
+from aiet.backend.common import Backend
+from aiet.backend.common import ConfigurationException
+from aiet.backend.common import get_backend_configs
+from aiet.backend.common import get_backend_directories
+from aiet.backend.common import load_application_or_tool_configs
+from aiet.backend.common import load_config
+from aiet.backend.config import ExtendedToolConfig
+from aiet.backend.config import ToolConfig
+
+
+def get_available_tool_directory_names() -> List[str]:
+ """Return a list of directory names for all available tools."""
+ return [entry.name for entry in get_backend_directories("tools")]
+
+
+def get_available_tools() -> List["Tool"]:
+ """Return a list with all available tools."""
+ available_tools = []
+ for config_json in get_backend_configs("tools"):
+ config_entries = cast(List[ExtendedToolConfig], load_config(config_json))
+ for config_entry in config_entries:
+ config_entry["config_location"] = config_json.parent.absolute()
+ tools = load_tools(config_entry)
+ available_tools += tools
+
+ return sorted(available_tools, key=lambda tool: tool.name)
+
+
+def get_tool(tool_name: str, system_name: Optional[str] = None) -> List["Tool"]:
+ """Return a tool instance with the same name passed as argument."""
+ return [
+ tool
+ for tool in get_available_tools()
+ if tool.name == tool_name and (not system_name or tool.can_run_on(system_name))
+ ]
+
+
+def get_unique_tool_names(system_name: Optional[str] = None) -> List[str]:
+ """Extract a list of unique tool names of all tools available."""
+ return list(
+ set(
+ tool.name
+ for tool in get_available_tools()
+ if not system_name or tool.can_run_on(system_name)
+ )
+ )
+
+
+class Tool(Backend):
+ """Class for representing a single tool component."""
+
+ def __init__(self, config: ToolConfig) -> None:
+ """Construct a Tool instance from a dict."""
+ super().__init__(config)
+
+ self.supported_systems = config.get("supported_systems", [])
+
+ if "run" not in self.commands:
+ raise ConfigurationException("A Tool must have a 'run' command.")
+
+ def __eq__(self, other: object) -> bool:
+ """Overload operator ==."""
+ if not isinstance(other, Tool):
+ return False
+
+ return (
+ super().__eq__(other)
+ and self.name == other.name
+ and set(self.supported_systems) == set(other.supported_systems)
+ )
+
+ def can_run_on(self, system_name: str) -> bool:
+ """Check if the tool can run on the system passed as argument."""
+ return system_name in self.supported_systems
+
+ def get_details(self) -> Dict[str, Any]:
+ """Return dictionary with all relevant information of the Tool instance."""
+ output = {
+ "type": "tool",
+ "name": self.name,
+ "description": self.description,
+ "supported_systems": self.supported_systems,
+ "commands": self._get_command_details(),
+ }
+
+ return output
+
+
+def load_tools(config: ExtendedToolConfig) -> List[Tool]:
+ """Load tool.
+
+ Tool configuration could contain different parameters/commands for different
+ supported systems. For each supported system this function will return separate
+ Tool instance with appropriate configuration.
+ """
+ configs = load_application_or_tool_configs(
+ config, ToolConfig, is_system_required=False
+ )
+ tools = [Tool(cfg) for cfg in configs]
+ return tools
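+
+
+# A sketch (system names are hypothetical): a tool config listing two
+# supported systems is expanded by load_tools() into separate Tool
+# instances, one per system, each with the merged per-system configuration:
+#
+# tools = load_tools(config)  # config supports "System-A" and "System-B"
+# [tool.supported_systems for tool in tools]  # [["System-A"], ["System-B"]]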
diff --git a/src/aiet/cli/__init__.py b/src/aiet/cli/__init__.py
new file mode 100644
index 0000000..bcd17c3
--- /dev/null
+++ b/src/aiet/cli/__init__.py
@@ -0,0 +1,28 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Module to mange the CLI interface."""
+import click
+
+from aiet import __version__
+from aiet.cli.application import application_cmd
+from aiet.cli.completion import completion_cmd
+from aiet.cli.system import system_cmd
+from aiet.cli.tool import tool_cmd
+from aiet.utils.helpers import set_verbosity
+
+
+@click.group()
+@click.version_option(__version__)
+@click.option(
+ "-v", "--verbose", default=0, count=True, callback=set_verbosity, expose_value=False
+)
+@click.pass_context
+def cli(ctx: click.Context) -> None: # pylint: disable=unused-argument
+ """AIET: AI Evaluation Toolkit."""
+ # Unused arguments must be present here in definition to pass click context.
+
+
+cli.add_command(application_cmd)
+cli.add_command(system_cmd)
+cli.add_command(tool_cmd)
+cli.add_command(completion_cmd)
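+
+# Example invocations, assuming the aiet entry point is installed:
+#
+# aiet --version
+# aiet -vv system list
+# aiet application details -n <application_name>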
diff --git a/src/aiet/cli/application.py b/src/aiet/cli/application.py
new file mode 100644
index 0000000..59b652d
--- /dev/null
+++ b/src/aiet/cli/application.py
@@ -0,0 +1,362 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-FileCopyrightText: Copyright (c) 2021, Gianluca Gippetto. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0 AND BSD-3-Clause
+"""Module to manage the CLI interface of applications."""
+import json
+import logging
+import re
+from pathlib import Path
+from typing import Any
+from typing import IO
+from typing import List
+from typing import Optional
+from typing import Tuple
+
+import click
+import cloup
+
+from aiet.backend.application import get_application
+from aiet.backend.application import get_available_application_directory_names
+from aiet.backend.application import get_unique_application_names
+from aiet.backend.application import install_application
+from aiet.backend.application import remove_application
+from aiet.backend.common import DataPaths
+from aiet.backend.execution import execute_application_command
+from aiet.backend.execution import run_application
+from aiet.backend.system import get_available_systems
+from aiet.cli.common import get_format
+from aiet.cli.common import middleware_exception_handler
+from aiet.cli.common import middleware_signal_handler
+from aiet.cli.common import print_command_details
+from aiet.cli.common import set_format
+
+
+@click.group(name="application")
+@click.option(
+ "-f",
+ "--format",
+ "format_",
+ type=click.Choice(["cli", "json"]),
+ default="cli",
+ show_default=True,
+)
+@click.pass_context
+def application_cmd(ctx: click.Context, format_: str) -> None:
+ """Sub command to manage applications."""
+ set_format(ctx, format_)
+
+
+@application_cmd.command(name="list")
+@click.pass_context
+@click.option(
+ "-s",
+ "--system",
+ "system_name",
+ type=click.Choice([s.name for s in get_available_systems()]),
+ required=False,
+)
+def list_cmd(ctx: click.Context, system_name: str) -> None:
+ """List all available applications."""
+ unique_application_names = get_unique_application_names(system_name)
+ unique_application_names.sort()
+ if get_format(ctx) == "json":
+ data = {"type": "application", "available": unique_application_names}
+ print(json.dumps(data))
+ else:
+ print("Available applications:\n")
+ print(*unique_application_names, sep="\n")
+
+
+@application_cmd.command(name="details")
+@click.option(
+ "-n",
+ "--name",
+ "application_name",
+ type=click.Choice(get_unique_application_names()),
+ required=True,
+)
+@click.option(
+ "-s",
+ "--system",
+ "system_name",
+ type=click.Choice([s.name for s in get_available_systems()]),
+ required=False,
+)
+@click.pass_context
+def details_cmd(ctx: click.Context, application_name: str, system_name: str) -> None:
+ """Details of a specific application."""
+ applications = get_application(application_name, system_name)
+ if not applications:
+ raise click.UsageError(
+ "Application '{}' doesn't support the system '{}'".format(
+ application_name, system_name
+ )
+ )
+
+ if get_format(ctx) == "json":
+ applications_details = [s.get_details() for s in applications]
+ print(json.dumps(applications_details))
+ else:
+ for application in applications:
+ application_details = application.get_details()
+ application_details_template = (
+ 'Application "{name}" details\nDescription: {description}'
+ )
+
+ print(
+ application_details_template.format(
+ name=application_details["name"],
+ description=application_details["description"],
+ )
+ )
+
+ print(
+ "\nSupported systems: {}".format(
+ ", ".join(application_details["supported_systems"])
+ )
+ )
+
+ command_details = application_details["commands"]
+
+ for command, details in command_details.items():
+ print("\n{} commands:".format(command))
+ print_command_details(details)
+
+
+# pylint: disable=too-many-arguments
+@application_cmd.command(name="execute")
+@click.option(
+ "-n",
+ "--name",
+ "application_name",
+ type=click.Choice(get_unique_application_names()),
+ required=True,
+)
+@click.option(
+ "-s",
+ "--system",
+ "system_name",
+ type=click.Choice([s.name for s in get_available_systems()]),
+ required=True,
+)
+@click.option(
+ "-c",
+ "--command",
+ "command_name",
+ type=click.Choice(["build", "run"]),
+ required=True,
+)
+@click.option("-p", "--param", "application_params", multiple=True)
+@click.option("--system-param", "system_params", multiple=True)
+@click.option("-d", "--deploy", "deploy_params", multiple=True)
+@middleware_signal_handler
+@middleware_exception_handler
+def execute_cmd(
+ application_name: str,
+ system_name: str,
+ command_name: str,
+ application_params: List[str],
+ system_params: List[str],
+ deploy_params: List[str],
+) -> None:
+ """Execute application commands. DEPRECATED! Use 'aiet application run' instead."""
+ logging.warning(
+ "Please use 'aiet application run' instead. Use of 'aiet application "
+ "execute' is deprecated and might be removed in a future release."
+ )
+
+ custom_deploy_data = get_custom_deploy_data(command_name, deploy_params)
+
+ execute_application_command(
+ command_name,
+ application_name,
+ application_params,
+ system_name,
+ system_params,
+ custom_deploy_data,
+ )
+
+
+@cloup.command(name="run")
+@cloup.option(
+ "-n",
+ "--name",
+ "application_name",
+ type=click.Choice(get_unique_application_names()),
+)
+@cloup.option(
+ "-s",
+ "--system",
+ "system_name",
+ type=click.Choice([s.name for s in get_available_systems()]),
+)
+@cloup.option("-p", "--param", "application_params", multiple=True)
+@cloup.option("--system-param", "system_params", multiple=True)
+@cloup.option("-d", "--deploy", "deploy_params", multiple=True)
+@click.option(
+ "-r",
+ "--report",
+ "report_file",
+ type=Path,
+ help="Create a report file in JSON format containing metrics parsed from "
+ "the simulation output as specified in the aiet-config.json.",
+)
+@cloup.option(
+ "--config",
+ "config_file",
+ type=click.File("r"),
+ help="Read options from a config file rather than from the command line. "
+ "The config file is a json file.",
+)
+@cloup.constraint(
+ cloup.constraints.If(
+ cloup.constraints.conditions.Not(
+ cloup.constraints.conditions.IsSet("config_file")
+ ),
+ then=cloup.constraints.require_all,
+ ),
+ ["system_name", "application_name"],
+)
+@cloup.constraint(
+ cloup.constraints.If("config_file", then=cloup.constraints.accept_none),
+ [
+ "system_name",
+ "application_name",
+ "application_params",
+ "system_params",
+ "deploy_params",
+ ],
+)
+@middleware_signal_handler
+@middleware_exception_handler
+def run_cmd(
+ application_name: str,
+ system_name: str,
+ application_params: List[str],
+ system_params: List[str],
+ deploy_params: List[str],
+ report_file: Optional[Path],
+ config_file: Optional[IO[str]],
+) -> None:
+ """Execute application commands."""
+ if config_file:
+ payload_data = json.load(config_file)
+ (
+ system_name,
+ application_name,
+ application_params,
+ system_params,
+ deploy_params,
+ report_file,
+ ) = parse_payload_run_config(payload_data)
+
+ custom_deploy_data = get_custom_deploy_data("run", deploy_params)
+
+ run_application(
+ application_name,
+ application_params,
+ system_name,
+ system_params,
+ custom_deploy_data,
+ report_file,
+ )
+
+
+application_cmd.add_command(run_cmd)
+
+
+def parse_payload_run_config(
+ payload_data: dict,
+) -> Tuple[str, str, List[str], List[str], List[str], Optional[Path]]:
+ """Parse the payload into a tuple."""
+ system_id = payload_data.get("id")
+ arguments: Optional[Any] = payload_data.get("arguments")
+
+ if not isinstance(system_id, str):
+ raise click.ClickException("invalid payload json: no system 'id'")
+ if not isinstance(arguments, dict):
+ raise click.ClickException("invalid payload json: no arguments object")
+
+ application_name = arguments.pop("application", None)
+ if not isinstance(application_name, str):
+ raise click.ClickException("invalid payload json: no application_id")
+
+ report_path = arguments.pop("report_path", None)
+
+ application_params = []
+ system_params = []
+ deploy_params = []
+
+ for (param_key, value) in arguments.items():
+ (par, _) = re.subn("^application/", "", param_key)
+ (par, found_sys_param) = re.subn("^system/", "", par)
+ (par, found_deploy_param) = re.subn("^deploy/", "", par)
+
+ param_expr = par + "=" + value
+ if found_sys_param:
+ system_params.append(param_expr)
+ elif found_deploy_param:
+ deploy_params.append(par)
+ else:
+ application_params.append(param_expr)
+
+ return (
+ system_id,
+ application_name,
+ application_params,
+ system_params,
+ deploy_params,
+ report_path,
+ )
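+
+# An illustrative payload accepted above (names are hypothetical; for
+# "deploy/<src>:<dst>" keys only the key is used, the value is ignored):
+#
+# {
+#     "id": "My-System",
+#     "arguments": {
+#         "application": "my_app",
+#         "application/input": "model.tflite",
+#         "system/mac": "256",
+#         "deploy/data.bin:/tmp/data.bin": "",
+#         "report_path": "report.json"
+#     }
+# }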
+
+
+def get_custom_deploy_data(
+ command_name: str, deploy_params: List[str]
+) -> List[DataPaths]:
+ """Get custom deploy data information."""
+ custom_deploy_data: List[DataPaths] = []
+ if not deploy_params:
+ return custom_deploy_data
+
+ for param in deploy_params:
+ parts = param.split(":")
+ if len(parts) != 2 or any(not part.strip() for part in parts):
+ raise click.ClickException(
+ "Invalid deploy parameter '{}' for command {}".format(
+ param, command_name
+ )
+ )
+ data_path = DataPaths(Path(parts[0]), parts[1])
+ if not data_path.src.exists():
+ raise click.ClickException("Path {} does not exist".format(data_path.src))
+ custom_deploy_data.append(data_path)
+
+ return custom_deploy_data
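+
+# For example, a deploy parameter "resources/data.bin:/tmp/data.bin"
+# (illustrative paths) deploys the local file resources/data.bin to
+# /tmp/data.bin on the target; the source path must exist locally.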
+
+
+@application_cmd.command(name="install")
+@click.option(
+ "-s",
+ "--source",
+ "source",
+ required=True,
+ help="Path to the directory or archive with application definition",
+)
+def install_cmd(source: str) -> None:
+ """Install new application."""
+ source_path = Path(source)
+ install_application(source_path)
+
+
+@application_cmd.command(name="remove")
+@click.option(
+ "-d",
+ "--directory_name",
+ "directory_name",
+ type=click.Choice(get_available_application_directory_names()),
+ required=True,
+ help="Name of the directory with application",
+)
+def remove_cmd(directory_name: str) -> None:
+ """Remove application."""
+ remove_application(directory_name)
diff --git a/src/aiet/cli/common.py b/src/aiet/cli/common.py
new file mode 100644
index 0000000..1d157b6
--- /dev/null
+++ b/src/aiet/cli/common.py
@@ -0,0 +1,173 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Common functions for cli module."""
+import enum
+import logging
+from functools import wraps
+from signal import SIG_IGN
+from signal import SIGINT
+from signal import signal as signal_handler
+from signal import SIGTERM
+from typing import Any
+from typing import Callable
+from typing import cast
+from typing import Dict
+
+from click import ClickException
+from click import Context
+from click import UsageError
+
+from aiet.backend.common import ConfigurationException
+from aiet.backend.execution import AnotherInstanceIsRunningException
+from aiet.backend.execution import ConnectionException
+from aiet.backend.protocol import SSHConnectionException
+from aiet.utils.proc import CommandFailedException
+
+
+class MiddlewareExitCode(enum.IntEnum):
+ """Middleware exit codes."""
+
+ SUCCESS = 0
+ # exit codes 1 and 2 are used by click
+ SHUTDOWN_REQUESTED = 3
+ BACKEND_ERROR = 4
+ CONCURRENT_ERROR = 5
+ CONNECTION_ERROR = 6
+ CONFIGURATION_ERROR = 7
+ MODEL_OPTIMISED_ERROR = 8
+ INVALID_TFLITE_FILE_ERROR = 9
+
+
+class CustomClickException(ClickException):
+ """Custom click exception."""
+
+ def show(self, file: Any = None) -> None:
+ """Override show method."""
+ super().show(file)
+
+ logging.debug("Execution failed with following exception: ", exc_info=self)
+
+
+class MiddlewareShutdownException(CustomClickException):
+ """Exception indicates that user requested middleware shutdown."""
+
+ exit_code = int(MiddlewareExitCode.SHUTDOWN_REQUESTED)
+
+
+class BackendException(CustomClickException):
+ """Exception indicates that command failed."""
+
+ exit_code = int(MiddlewareExitCode.BACKEND_ERROR)
+
+
+class ConcurrentErrorException(CustomClickException):
+ """Exception indicates concurrent execution error."""
+
+ exit_code = int(MiddlewareExitCode.CONCURRENT_ERROR)
+
+
+class BackendConnectionException(CustomClickException):
+ """Exception indicates that connection could not be established."""
+
+ exit_code = int(MiddlewareExitCode.CONNECTION_ERROR)
+
+
+class BackendConfigurationException(CustomClickException):
+ """Exception indicates some configuration issue."""
+
+ exit_code = int(MiddlewareExitCode.CONFIGURATION_ERROR)
+
+
+class ModelOptimisedException(CustomClickException):
+ """Exception indicates input file has previously been Vela optimised."""
+
+ exit_code = int(MiddlewareExitCode.MODEL_OPTIMISED_ERROR)
+
+
+class InvalidTFLiteFileError(CustomClickException):
+ """Exception indicates input TFLite file is misformatted."""
+
+ exit_code = int(MiddlewareExitCode.INVALID_TFLITE_FILE_ERROR)
+
+
+def print_command_details(command: Dict) -> None:
+ """Print command details including parameters."""
+ command_strings = command["command_strings"]
+ print("Commands: {}".format(command_strings))
+ user_params = command["user_params"]
+ for i, param in enumerate(user_params, 1):
+ print("User parameter #{}".format(i))
+ print("\tName: {}".format(param.get("name", "-")))
+ print("\tDescription: {}".format(param["description"]))
+ print("\tPossible values: {}".format(param.get("values", "-")))
+ print("\tDefault value: {}".format(param.get("default_value", "-")))
+ print("\tAlias: {}".format(param.get("alias", "-")))
+
+
+def raise_exception_at_signal(
+ signum: int, frame: Any # pylint: disable=unused-argument
+) -> None:
+ """Handle signals."""
+ # Disable both SIGINT and SIGTERM signals. Further SIGINT and SIGTERM
+ # signals will be ignored as we allow a graceful shutdown.
+ # Unused arguments must be present here in definition as used in signal handler
+ # callback
+
+ signal_handler(SIGINT, SIG_IGN)
+ signal_handler(SIGTERM, SIG_IGN)
+ raise MiddlewareShutdownException("Middleware shutdown requested")
+
+
+def middleware_exception_handler(func: Callable) -> Callable:
+ """Handle backend exceptions decorator."""
+
+ @wraps(func)
+ def wrapper(*args: Any, **kwargs: Any) -> Any:
+ try:
+ return func(*args, **kwargs)
+ except (MiddlewareShutdownException, UsageError, ClickException) as error:
+ # click should take care of these exceptions
+ raise error
+ except ValueError as error:
+ raise ClickException(str(error)) from error
+ except AnotherInstanceIsRunningException as error:
+ raise ConcurrentErrorException(
+ "Another instance of the system is running"
+ ) from error
+ except (SSHConnectionException, ConnectionException) as error:
+ raise BackendConnectionException(str(error)) from error
+ except ConfigurationException as error:
+ raise BackendConfigurationException(str(error)) from error
+ except (CommandFailedException, Exception) as error:
+ raise BackendException(
+ "Execution failed. Please check output for the details."
+ ) from error
+
+ return wrapper
+
+
+def middleware_signal_handler(func: Callable) -> Callable:
+ """Handle signals decorator."""
+
+ @wraps(func)
+ def wrapper(*args: Any, **kwargs: Any) -> Any:
+ # Set up signal handlers for SIGINT (ctrl-c) and SIGTERM (kill command)
+ # The handler ignores further signals and it raises an exception
+ signal_handler(SIGINT, raise_exception_at_signal)
+ signal_handler(SIGTERM, raise_exception_at_signal)
+
+ return func(*args, **kwargs)
+
+ return wrapper
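+
+# Both decorators are meant to be stacked on click commands, e.g. (group
+# and command names are hypothetical):
+#
+# @some_group.command(name="run")
+# @middleware_signal_handler
+# @middleware_exception_handler
+# def run_cmd() -> None:
+#     ...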
+
+
+def set_format(ctx: Context, format_: str) -> None:
+ """Save format in click context."""
+ ctx_obj = ctx.ensure_object(dict)
+ ctx_obj["format"] = format_
+
+
+def get_format(ctx: Context) -> str:
+ """Get format from click context."""
+ ctx_obj = cast(Dict[str, str], ctx.ensure_object(dict))
+ return ctx_obj["format"]
diff --git a/src/aiet/cli/completion.py b/src/aiet/cli/completion.py
new file mode 100644
index 0000000..71f054f
--- /dev/null
+++ b/src/aiet/cli/completion.py
@@ -0,0 +1,72 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""
+Add auto completion to different shells with these helpers.
+
+See: https://click.palletsprojects.com/en/8.0.x/shell-completion/
+"""
+import click
+
+
+def _get_package_name() -> str:
+ return __name__.split(".", maxsplit=1)[0]
+
+
+# aiet completion bash
+@click.group(name="completion")
+def completion_cmd() -> None:
+ """Enable auto completion for your shell."""
+
+
+@completion_cmd.command(name="bash")
+def bash_cmd() -> None:
+ """
+ Enable auto completion for bash.
+
+ Use this command to activate completion in the current bash:
+
+ eval "`aiet completion bash`"
+
+ Use this command to add auto completion to bash globally, if you have aiet
+ installed globally (requires starting a new shell afterwards):
+
+ aiet completion bash >> ~/.bashrc
+ """
+ package_name = _get_package_name()
+ print(f'eval "$(_{package_name.upper()}_COMPLETE=bash_source {package_name})"')
+
+
+@completion_cmd.command(name="zsh")
+def zsh_cmd() -> None:
+ """
+ Enable auto completion for zsh.
+
+ Use this command to activate completion in the current zsh:
+
+ eval "`aiet completion zsh`"
+
+ Use this command to add auto completion to zsh globally, if you have aiet
+ installed globally (requires starting a new shell afterwards):
+
+ aiet completion zsh >> ~/.zshrc
+ """
+ package_name = _get_package_name()
+ print(f'eval "$(_{package_name.upper()}_COMPLETE=zsh_source {package_name})"')
+
+
+@completion_cmd.command(name="fish")
+def fish_cmd() -> None:
+ """
+ Enable auto completion for fish.
+
+ Use this command to activate completion in the current fish:
+
+ eval "`aiet completion fish`"
+
+ Use this command to add auto completion to fish globally, if you have aiet
+ installed globally (requires starting a new shell afterwards):
+
+ aiet completion fish >> ~/.config/fish/completions/aiet.fish
+ """
+ package_name = _get_package_name()
+ print(f'eval "(env _{package_name.upper()}_COMPLETE=fish_source {package_name})"')
diff --git a/src/aiet/cli/system.py b/src/aiet/cli/system.py
new file mode 100644
index 0000000..f1f7637
--- /dev/null
+++ b/src/aiet/cli/system.py
@@ -0,0 +1,122 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Module to manage the CLI interface of systems."""
+import json
+from pathlib import Path
+from typing import cast
+
+import click
+
+from aiet.backend.application import get_available_applications
+from aiet.backend.system import get_available_systems
+from aiet.backend.system import get_available_systems_directory_names
+from aiet.backend.system import get_system
+from aiet.backend.system import install_system
+from aiet.backend.system import remove_system
+from aiet.backend.system import System
+from aiet.cli.common import get_format
+from aiet.cli.common import print_command_details
+from aiet.cli.common import set_format
+
+
+@click.group(name="system")
+@click.option(
+ "-f",
+ "--format",
+ "format_",
+ type=click.Choice(["cli", "json"]),
+ default="cli",
+ show_default=True,
+)
+@click.pass_context
+def system_cmd(ctx: click.Context, format_: str) -> None:
+ """Sub command to manage systems."""
+ set_format(ctx, format_)
+
+
+@system_cmd.command(name="list")
+@click.pass_context
+def list_cmd(ctx: click.Context) -> None:
+ """List all available systems."""
+ available_systems = get_available_systems()
+ system_names = [system.name for system in available_systems]
+ if get_format(ctx) == "json":
+ data = {"type": "system", "available": system_names}
+ print(json.dumps(data))
+ else:
+ print("Available systems:\n")
+ print(*system_names, sep="\n")
+
+
+@system_cmd.command(name="details")
+@click.option(
+ "-n",
+ "--name",
+ "system_name",
+ type=click.Choice([s.name for s in get_available_systems()]),
+ required=True,
+)
+@click.pass_context
+def details_cmd(ctx: click.Context, system_name: str) -> None:
+ """Details of a specific system."""
+ system = cast(System, get_system(system_name))
+ applications = [
+ s.name for s in get_available_applications() if s.can_run_on(system.name)
+ ]
+ system_details = system.get_details()
+ if get_format(ctx) == "json":
+ system_details["available_application"] = applications
+ print(json.dumps(system_details))
+ else:
+ system_details_template = (
+ 'System "{name}" details\n'
+ "Description: {description}\n"
+ "Data Transfer Protocol: {protocol}\n"
+ "Available Applications: {available_application}"
+ )
+ print(
+ system_details_template.format(
+ name=system_details["name"],
+ description=system_details["description"],
+ protocol=system_details["data_transfer_protocol"],
+ available_application=", ".join(applications),
+ )
+ )
+
+ if system_details["annotations"]:
+ print("Annotations:")
+ for ann_name, ann_value in system_details["annotations"].items():
+ print("\t{}: {}".format(ann_name, ann_value))
+
+ command_details = system_details["commands"]
+ for command, details in command_details.items():
+ print("\n{} commands:".format(command))
+ print_command_details(details)
+
+
+@system_cmd.command(name="install")
+@click.option(
+ "-s",
+ "--source",
+ "source",
+ required=True,
+ help="Path to the directory or archive with system definition",
+)
+def install_cmd(source: str) -> None:
+ """Install new system."""
+ source_path = Path(source)
+ install_system(source_path)
+
+
+@system_cmd.command(name="remove")
+@click.option(
+ "-d",
+ "--directory_name",
+ "directory_name",
+ type=click.Choice(get_available_systems_directory_names()),
+ required=True,
+ help="Name of the directory with system",
+)
+def remove_cmd(directory_name: str) -> None:
+ """Remove system by given name."""
+ remove_system(directory_name)
diff --git a/src/aiet/cli/tool.py b/src/aiet/cli/tool.py
new file mode 100644
index 0000000..2c80821
--- /dev/null
+++ b/src/aiet/cli/tool.py
@@ -0,0 +1,143 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Module to manage the CLI interface of tools."""
+import json
+from typing import Any
+from typing import List
+from typing import Optional
+
+import click
+
+from aiet.backend.execution import execute_tool_command
+from aiet.backend.tool import get_tool
+from aiet.backend.tool import get_unique_tool_names
+from aiet.cli.common import get_format
+from aiet.cli.common import middleware_exception_handler
+from aiet.cli.common import middleware_signal_handler
+from aiet.cli.common import print_command_details
+from aiet.cli.common import set_format
+
+
+@click.group(name="tool")
+@click.option(
+ "-f",
+ "--format",
+ "format_",
+ type=click.Choice(["cli", "json"]),
+ default="cli",
+ show_default=True,
+)
+@click.pass_context
+def tool_cmd(ctx: click.Context, format_: str) -> None:
+ """Sub command to manage tools."""
+ set_format(ctx, format_)
+
+
+@tool_cmd.command(name="list")
+@click.pass_context
+def list_cmd(ctx: click.Context) -> None:
+ """List all available tools."""
+ tool_names = get_unique_tool_names()
+ tool_names.sort()
+ if get_format(ctx) == "json":
+ data = {"type": "tool", "available": tool_names}
+ print(json.dumps(data))
+ else:
+ print("Available tools:\n")
+ print(*tool_names, sep="\n")
+
+
+def validate_system(
+ ctx: click.Context,
+ _: click.Parameter, # param is not used
+ value: Any,
+) -> Any:
+ """Validate provided system name depending on the the tool name."""
+ tool_name = ctx.params["tool_name"]
+ tools = get_tool(tool_name, value)
+ if not tools:
+ supported_systems = [tool.supported_systems[0] for tool in get_tool(tool_name)]
+ raise click.BadParameter(
+ message="'{}' is not one of {}.".format(
+ value,
+ ", ".join("'{}'".format(system) for system in supported_systems),
+ ),
+ ctx=ctx,
+ )
+ return value
+
+
+@tool_cmd.command(name="details")
+@click.option(
+ "-n",
+ "--name",
+ "tool_name",
+ type=click.Choice(get_unique_tool_names()),
+ required=True,
+)
+@click.option(
+ "-s",
+ "--system",
+ "system_name",
+ callback=validate_system,
+ required=False,
+)
+@click.pass_context
+@middleware_signal_handler
+@middleware_exception_handler
+def details_cmd(ctx: click.Context, tool_name: str, system_name: Optional[str]) -> None:
+ """Details of a specific tool."""
+ tools = get_tool(tool_name, system_name)
+ if get_format(ctx) == "json":
+ tools_details = [s.get_details() for s in tools]
+ print(json.dumps(tools_details))
+ else:
+ for tool in tools:
+ tool_details = tool.get_details()
+ tool_details_template = 'Tool "{name}" details\nDescription: {description}'
+
+ print(
+ tool_details_template.format(
+ name=tool_details["name"],
+ description=tool_details["description"],
+ )
+ )
+
+ print(
+ "\nSupported systems: {}".format(
+ ", ".join(tool_details["supported_systems"])
+ )
+ )
+
+ command_details = tool_details["commands"]
+
+ for command, details in command_details.items():
+ print("\n{} commands:".format(command))
+ print_command_details(details)
+
+
+# pylint: disable=too-many-arguments
+@tool_cmd.command(name="execute")
+@click.option(
+ "-n",
+ "--name",
+ "tool_name",
+ type=click.Choice(get_unique_tool_names()),
+ required=True,
+)
+@click.option("-p", "--param", "tool_params", multiple=True)
+@click.option(
+ "-s",
+ "--system",
+ "system_name",
+ callback=validate_system,
+ required=False,
+)
+@middleware_signal_handler
+@middleware_exception_handler
+def execute_cmd(
+ tool_name: str, tool_params: List[str], system_name: Optional[str]
+) -> None:
+ """Execute tool commands."""
+ execute_tool_command(tool_name, tool_params, system_name)
diff --git a/src/aiet/main.py b/src/aiet/main.py
new file mode 100644
index 0000000..6898ad9
--- /dev/null
+++ b/src/aiet/main.py
@@ -0,0 +1,13 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Entry point module of AIET."""
+from aiet.cli import cli
+
+
+def main() -> None:
+ """Entry point of aiet application."""
+ cli() # pylint: disable=no-value-for-parameter
+
+
+if __name__ == "__main__":
+ main()
diff --git a/src/aiet/resources/applications/.gitignore b/src/aiet/resources/applications/.gitignore
new file mode 100644
index 0000000..0226166
--- /dev/null
+++ b/src/aiet/resources/applications/.gitignore
@@ -0,0 +1,6 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+# Ignore everything in this directory
+*
+# Except this file
+!.gitignore
diff --git a/src/aiet/resources/systems/.gitignore b/src/aiet/resources/systems/.gitignore
new file mode 100644
index 0000000..0226166
--- /dev/null
+++ b/src/aiet/resources/systems/.gitignore
@@ -0,0 +1,6 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+# Ignore everything in this directory
+*
+# Except this file
+!.gitignore
diff --git a/src/aiet/resources/tools/vela/aiet-config.json b/src/aiet/resources/tools/vela/aiet-config.json
new file mode 100644
index 0000000..c12f291
--- /dev/null
+++ b/src/aiet/resources/tools/vela/aiet-config.json
@@ -0,0 +1,73 @@
+[
+ {
+ "name": "vela",
+ "description": "Neural network model compiler for Arm Ethos-U NPUs",
+ "supported_systems": [
+ {
+ "name": "Corstone-300: Cortex-M55+Ethos-U55"
+ },
+ {
+ "name": "Corstone-310: Cortex-M85+Ethos-U55"
+ },
+ {
+ "name": "Corstone-300: Cortex-M55+Ethos-U65",
+ "variables": {
+ "accelerator_config_prefix": "ethos-u65",
+ "system_config": "Ethos_U65_High_End",
+ "shared_sram": "U65_Shared_Sram"
+ },
+ "user_params": {
+ "run": [
+ {
+ "description": "MACs per cycle",
+ "values": [
+ "256",
+ "512"
+ ],
+ "default_value": "512",
+ "alias": "mac"
+ }
+ ]
+ }
+ }
+ ],
+ "variables": {
+ "accelerator_config_prefix": "ethos-u55",
+ "system_config": "Ethos_U55_High_End_Embedded",
+ "shared_sram": "U55_Shared_Sram"
+ },
+ "commands": {
+ "run": [
+ "run_vela {user_params:input} {user_params:output} --config {tool.config_dir}/vela.ini --accelerator-config {variables:accelerator_config_prefix}-{user_params:mac} --system-config {variables:system_config} --memory-mode {variables:shared_sram} --optimise Performance"
+ ]
+ },
+ "user_params": {
+ "run": [
+ {
+ "description": "MACs per cycle",
+ "values": [
+ "32",
+ "64",
+ "128",
+ "256"
+ ],
+ "default_value": "128",
+ "alias": "mac"
+ },
+ {
+ "name": "--input-model",
+ "description": "Path to the TFLite model",
+ "values": [],
+ "alias": "input"
+ },
+ {
+ "name": "--output-model",
+ "description": "Path to the output model file of the vela-optimisation step. The vela output is saved in the parent directory.",
+ "values": [],
+ "default_value": "output_model.tflite",
+ "alias": "output"
+ }
+ ]
+ }
+ }
+]
diff --git a/src/aiet/resources/tools/vela/aiet-config.json.license b/src/aiet/resources/tools/vela/aiet-config.json.license
new file mode 100644
index 0000000..9b83bfc
--- /dev/null
+++ b/src/aiet/resources/tools/vela/aiet-config.json.license
@@ -0,0 +1,3 @@
+SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+
+SPDX-License-Identifier: Apache-2.0
diff --git a/src/aiet/resources/tools/vela/check_model.py b/src/aiet/resources/tools/vela/check_model.py
new file mode 100644
index 0000000..7c700b1
--- /dev/null
+++ b/src/aiet/resources/tools/vela/check_model.py
@@ -0,0 +1,75 @@
+# SPDX-FileCopyrightText: Copyright 2020, 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Check if a TFLite model file is Vela-optimised."""
+import struct
+from pathlib import Path
+
+from ethosu.vela.tflite.Model import Model
+
+from aiet.cli.common import InvalidTFLiteFileError
+from aiet.cli.common import ModelOptimisedException
+from aiet.utils.fs import read_file_as_bytearray
+
+
+def get_model_from_file(input_model_file: Path) -> Model:
+ """Generate Model instance from TFLite file using flatc generated code."""
+ buffer = read_file_as_bytearray(input_model_file)
+ try:
+ model = Model.GetRootAsModel(buffer, 0)
+ except (TypeError, RuntimeError, struct.error) as tflite_error:
+ raise InvalidTFLiteFileError(
+ f"Error reading in model from {input_model_file}."
+ ) from tflite_error
+ return model
+
+
+def is_vela_optimised(tflite_model: Model) -> bool:
+ """Return True if 'ethos-u' custom operator found in the Model."""
+ operators = get_operators_from_model(tflite_model)
+
+ custom_codes = get_custom_codes_from_operators(operators)
+
+ return check_custom_codes_for_ethosu(custom_codes)
+
+
+def get_operators_from_model(tflite_model: Model) -> list:
+ """Return list of the unique operator codes used in the Model."""
+ return [
+ tflite_model.OperatorCodes(index)
+ for index in range(tflite_model.OperatorCodesLength())
+ ]
+
+
+def get_custom_codes_from_operators(operators: list) -> list:
+ """Return list of each operator's CustomCode() strings, if they exist."""
+ return [
+ operator.CustomCode()
+ for operator in operators
+ if operator.CustomCode() is not None
+ ]
+
+
+def check_custom_codes_for_ethosu(custom_codes: list) -> bool:
+ """Check for existence of ethos-u string in the custom codes."""
+ return any(
+ custom_code_name.decode("utf-8") == "ethos-u"
+ for custom_code_name in custom_codes
+ )
+
+
+def check_model(tflite_file_name: str) -> None:
+ """Raise an exception if model in given file is Vela optimised."""
+ tflite_path = Path(tflite_file_name)
+
+ tflite_model = get_model_from_file(tflite_path)
+
+ if is_vela_optimised(tflite_model):
+ raise ModelOptimisedException(
+ f"TFLite model in {tflite_file_name} is already "
+ f"vela optimised ('ethos-u' custom op detected)."
+ )
+
+ print(
+ f"TFLite model in {tflite_file_name} is not vela optimised "
+ f"('ethos-u' custom op not detected)."
+ )
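+
+
+# Example (illustrative path): check_model("model.tflite") raises
+# ModelOptimisedException if the model already contains the "ethos-u"
+# custom operator and InvalidTFLiteFileError if the file cannot be parsed.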
diff --git a/src/aiet/resources/tools/vela/run_vela.py b/src/aiet/resources/tools/vela/run_vela.py
new file mode 100644
index 0000000..2c1b0be
--- /dev/null
+++ b/src/aiet/resources/tools/vela/run_vela.py
@@ -0,0 +1,65 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Wrapper to only run Vela when the input is not already optimised."""
+import shutil
+import subprocess
+from pathlib import Path
+from typing import Tuple
+
+import click
+
+from aiet.cli.common import ModelOptimisedException
+from aiet.resources.tools.vela.check_model import check_model
+
+
+def vela_output_model_path(input_model: str, output_dir: str) -> Path:
+ """Construct the path to the Vela output file."""
+ in_path = Path(input_model)
+ tflite_vela = Path(output_dir) / f"{in_path.stem}_vela{in_path.suffix}"
+ return tflite_vela
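+
+# For example (illustrative arguments):
+# vela_output_model_path("model.tflite", "out") -> Path("out/model_vela.tflite")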
+
+
+def execute_vela(vela_args: Tuple, output_dir: Path, input_model: str) -> None:
+ """Execute vela as external call."""
+ cmd = ["vela"] + list(vela_args)
+ cmd += ["--output-dir", str(output_dir)] # Re-add parsed out_dir to arguments
+ cmd += [input_model]
+ subprocess.run(cmd, check=True)
+
+
+@click.command(context_settings=dict(ignore_unknown_options=True))
+@click.option(
+ "--input-model",
+ "-i",
+ type=click.Path(exists=True, file_okay=True, readable=True),
+ required=True,
+)
+@click.option("--output-model", "-o", type=click.Path(), required=True)
+# Collect the remaining arguments to be directly forwarded to Vela
+@click.argument("vela-args", nargs=-1, type=click.UNPROCESSED)
+def run_vela(input_model: str, output_model: str, vela_args: Tuple) -> None:
+ """Check input, run Vela (if needed) and copy optimised file to destination."""
+ output_dir = Path(output_model).parent
+ try:
+ check_model(input_model) # raises an exception if already Vela-optimised
+ execute_vela(vela_args, output_dir, input_model)
+ print("Vela optimisation complete.")
+ src_model = vela_output_model_path(input_model, str(output_dir))
+ except ModelOptimisedException as ex:
+        # Input already optimised: copy the input file to the destination path
+        print(f"Input already Vela-optimised.\n{ex}")
+ src_model = Path(input_model)
+ except subprocess.CalledProcessError as ex:
+ print(ex)
+ raise SystemExit(ex.returncode) from ex
+
+ try:
+ shutil.copyfile(src_model, output_model)
+ except (shutil.SameFileError, OSError) as ex:
+ print(ex)
+ raise SystemExit(ex.errno) from ex
+
+
+def main() -> None:
+ """Entry point of check_model application."""
+ run_vela() # pylint: disable=no-value-for-parameter
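+
+# Illustrative invocation sketch; the trailing arguments are forwarded to
+# Vela and the options chosen here are only examples:
+#
+#     python run_vela.py -i model.tflite -o out/model.tflite \
+#         --config vela.ini --system-config Ethos_U55_High_End_Embedded \
+#         --memory-mode U55_Shared_Sram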
diff --git a/src/aiet/resources/tools/vela/vela.ini b/src/aiet/resources/tools/vela/vela.ini
new file mode 100644
index 0000000..5996553
--- /dev/null
+++ b/src/aiet/resources/tools/vela/vela.ini
@@ -0,0 +1,53 @@
+; SPDX-FileCopyrightText: Copyright 2021-2022, Arm Limited and/or its affiliates.
+; SPDX-License-Identifier: Apache-2.0
+
+; -----------------------------------------------------------------------------
+; Vela configuration file
+
+; -----------------------------------------------------------------------------
+; System Configuration
+
+; Ethos-U55 High-End Embedded: SRAM (4 GB/s) and Flash (0.5 GB/s)
+[System_Config.Ethos_U55_High_End_Embedded]
+core_clock=500e6
+axi0_port=Sram
+axi1_port=OffChipFlash
+Sram_clock_scale=1.0
+Sram_burst_length=32
+Sram_read_latency=32
+Sram_write_latency=32
+OffChipFlash_clock_scale=0.125
+OffChipFlash_burst_length=128
+OffChipFlash_read_latency=64
+OffChipFlash_write_latency=64
+
+; Ethos-U65 High-End: SRAM (16 GB/s) and DRAM (3.75 GB/s)
+[System_Config.Ethos_U65_High_End]
+core_clock=1e9
+axi0_port=Sram
+axi1_port=Dram
+Sram_clock_scale=1.0
+Sram_burst_length=32
+Sram_read_latency=32
+Sram_write_latency=32
+Dram_clock_scale=0.234375
+Dram_burst_length=128
+Dram_read_latency=500
+Dram_write_latency=250
+
+; -----------------------------------------------------------------------------
+; Memory Mode
+
+; Shared SRAM: the SRAM is shared between the Ethos-U and the Cortex-M software
+; The non-SRAM memory is assumed to be read-only
+[Memory_Mode.U55_Shared_Sram]
+const_mem_area=Axi1
+arena_mem_area=Axi0
+cache_mem_area=Axi0
+arena_cache_size=4194304
+
+[Memory_Mode.U65_Shared_Sram]
+const_mem_area=Axi1
+arena_mem_area=Axi0
+cache_mem_area=Axi0
+arena_cache_size=2097152
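+
+; Example (illustrative): the sections above can be selected on the Vela
+; command line, e.g.
+;   vela --config vela.ini --system-config Ethos_U55_High_End_Embedded \
+;        --memory-mode U55_Shared_Sram model.tflite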
diff --git a/src/aiet/utils/__init__.py b/src/aiet/utils/__init__.py
new file mode 100644
index 0000000..fc7ef7c
--- /dev/null
+++ b/src/aiet/utils/__init__.py
@@ -0,0 +1,3 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""This module contains all utils shared across aiet project."""
diff --git a/src/aiet/utils/fs.py b/src/aiet/utils/fs.py
new file mode 100644
index 0000000..ea99a69
--- /dev/null
+++ b/src/aiet/utils/fs.py
@@ -0,0 +1,116 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Module to host all file system related functions."""
+import importlib.resources as pkg_resources
+import re
+import shutil
+from pathlib import Path
+from typing import Any
+from typing import Literal
+from typing import Optional
+
+ResourceType = Literal["applications", "systems", "tools"]
+
+
+def get_aiet_resources() -> Path:
+ """Get resources folder path."""
+ with pkg_resources.path("aiet", "__init__.py") as init_path:
+ project_root = init_path.parent
+ return project_root / "resources"
+
+
+def get_resources(name: ResourceType) -> Path:
+ """Return the absolute path of the specified resource.
+
+ It uses importlib to return resources packaged with MANIFEST.in.
+ """
+ if not name:
+ raise ResourceWarning("Resource name is not provided")
+
+ resource_path = get_aiet_resources() / name
+ if resource_path.is_dir():
+ return resource_path
+
+ raise ResourceWarning("Resource '{}' not found.".format(name))
+
+
+def copy_directory_content(source: Path, destination: Path) -> None:
+ """Copy content of the source directory into destination directory."""
+ for item in source.iterdir():
+ src = source / item.name
+ dest = destination / item.name
+
+ if src.is_dir():
+ shutil.copytree(src, dest)
+ else:
+ shutil.copy2(src, dest)
+
+
+def remove_resource(resource_directory: str, resource_type: ResourceType) -> None:
+ """Remove resource data."""
+ resources = get_resources(resource_type)
+
+ resource_location = resources / resource_directory
+ if not resource_location.exists():
+ raise Exception("Resource {} does not exist".format(resource_directory))
+
+ if not resource_location.is_dir():
+ raise Exception("Wrong resource {}".format(resource_directory))
+
+ shutil.rmtree(resource_location)
+
+
+def remove_directory(directory_path: Optional[Path]) -> None:
+ """Remove directory."""
+ if not directory_path or not directory_path.is_dir():
+ raise Exception("No directory path provided")
+
+ shutil.rmtree(directory_path)
+
+
+def recreate_directory(directory_path: Optional[Path]) -> None:
+ """Recreate directory."""
+ if not directory_path:
+ raise Exception("No directory path provided")
+
+ if directory_path.exists() and not directory_path.is_dir():
+ raise Exception(
+ "Path {} does exist and it is not a directory".format(str(directory_path))
+ )
+
+ if directory_path.is_dir():
+ remove_directory(directory_path)
+
+ directory_path.mkdir()
+
+
+def read_file(file_path: Path, mode: Optional[str] = None) -> Any:
+ """Read file as string or bytearray."""
+ if file_path.is_file():
+ if mode is not None:
+ # Ignore pylint warning because mode can be 'binary' as well which
+ # is not compatible with specifying encodings.
+ with open(file_path, mode) as file: # pylint: disable=unspecified-encoding
+ return file.read()
+ else:
+ with open(file_path, encoding="utf-8") as file:
+ return file.read()
+
+ if mode == "rb":
+ return b""
+ return ""
+
+
+def read_file_as_string(file_path: Path) -> str:
+ """Read file as string."""
+ return str(read_file(file_path))
+
+
+def read_file_as_bytearray(file_path: Path) -> bytearray:
+ """Read a file as bytearray."""
+ return bytearray(read_file(file_path, mode="rb"))
+
+
+def valid_for_filename(value: str, replacement: str = "") -> str:
+ """Replace non alpha numeric characters."""
+ return re.sub(r"[^\w.]", replacement, value, flags=re.ASCII)
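+
+# Illustrative usage sketch (the resource and file names are assumptions):
+#
+#     apps_dir = get_resources("applications")
+#     config_text = read_file_as_string(apps_dir / "some_app" / "config.json")
+#     safe_name = valid_for_filename("my system!", "_")  # -> "my_system_"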
diff --git a/src/aiet/utils/helpers.py b/src/aiet/utils/helpers.py
new file mode 100644
index 0000000..6d3cd22
--- /dev/null
+++ b/src/aiet/utils/helpers.py
@@ -0,0 +1,17 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Helpers functions."""
+import logging
+from typing import Any
+
+
+def set_verbosity(
+ ctx: Any, option: Any, verbosity: Any # pylint: disable=unused-argument
+) -> None:
+ """Set the logging level according to the verbosity."""
+ # Unused arguments must be present here in definition as these are required in
+ # function definition when set as a callback
+ if verbosity == 1:
+ logging.getLogger().setLevel(logging.INFO)
+ elif verbosity > 1:
+ logging.getLogger().setLevel(logging.DEBUG)
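+
+# Illustrative usage sketch: 'set_verbosity' is designed to be used as a
+# click option callback (assuming click is imported), e.g.
+#
+#     @click.option(
+#         "-v", "--verbose", count=True, callback=set_verbosity,
+#         expose_value=False, is_eager=True,
+#     )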
diff --git a/src/aiet/utils/proc.py b/src/aiet/utils/proc.py
new file mode 100644
index 0000000..b6f4357
--- /dev/null
+++ b/src/aiet/utils/proc.py
@@ -0,0 +1,283 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Processes module.
+
+This module contains all classes and functions for dealing with Linux
+processes.
+"""
+import csv
+import datetime
+import logging
+import shlex
+import signal
+import time
+from pathlib import Path
+from typing import Any
+from typing import List
+from typing import NamedTuple
+from typing import Optional
+from typing import Tuple
+
+import psutil
+from sh import Command
+from sh import CommandNotFound
+from sh import ErrorReturnCode
+from sh import RunningCommand
+
+from aiet.utils.fs import valid_for_filename
+
+
+class CommandFailedException(Exception):
+ """Exception for failed command execution."""
+
+
+class ShellCommand:
+ """Wrapper class for shell commands."""
+
+ def __init__(self, base_log_path: str = "/tmp") -> None:
+ """Initialise the class.
+
+        base_log_path: base directory where the logs will be stored
+ """
+ self.base_log_path = base_log_path
+
+ def run(
+ self,
+ cmd: str,
+ *args: str,
+ _cwd: Optional[Path] = None,
+ _tee: bool = True,
+ _bg: bool = True,
+ _out: Any = None,
+ _err: Any = None,
+ _search_paths: Optional[List[Path]] = None
+ ) -> RunningCommand:
+ """Run the shell command with the given arguments.
+
+        There are special arguments that modify the behaviour of the process:
+        _cwd: current working directory
+        _tee: redirect the stdout both to the console and to a file
+        _bg: if True, run the process in the background so that the command
+        does not block
+        _out: object to use for the stdout redirect
+        _err: object to use for the stderr redirect
+        _search_paths: if provided, used to search for the executable
+ """
+ try:
+ kwargs = {}
+ if _cwd:
+ kwargs["_cwd"] = str(_cwd)
+ command = Command(cmd, _search_paths).bake(args, **kwargs)
+ except CommandNotFound as error:
+ logging.error("Command '%s' not found", error.args[0])
+ raise error
+
+ out, err = _out, _err
+ if not _out and not _err:
+ out, err = [
+ str(item)
+ for item in self.get_stdout_stderr_paths(self.base_log_path, cmd)
+ ]
+
+ return command(_out=out, _err=err, _tee=_tee, _bg=_bg, _bg_exc=False)
+
+ @classmethod
+ def get_stdout_stderr_paths(cls, base_log_path: str, cmd: str) -> Tuple[Path, Path]:
+ """Construct and returns the paths of stdout/stderr files."""
+ timestamp = datetime.datetime.now().timestamp()
+ base_path = Path(base_log_path)
+ base = base_path / "{}_{}".format(valid_for_filename(cmd, "_"), timestamp)
+ stdout = base.with_suffix(".out")
+ stderr = base.with_suffix(".err")
+ try:
+ stdout.touch()
+ stderr.touch()
+ except FileNotFoundError as error:
+ logging.error("File not found: %s", error.filename)
+ raise error
+ return stdout, stderr
+
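+# Illustrative usage sketch: run 'ls -l' in the background with the output
+# redirected to log files under /tmp:
+#
+#     running = ShellCommand().run("ls", "-l", _cwd=Path("/tmp"))
+#     running.wait()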
+
+def parse_command(command: str, shell: str = "bash") -> List[str]:
+ """Parse command."""
+ cmd, *args = shlex.split(command, posix=True)
+
+ if is_shell_script(cmd):
+ args = [cmd] + args
+ cmd = shell
+
+ return [cmd] + args
+
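+# Illustrative examples of the parsing behaviour:
+#
+#     parse_command("ls -l")        # -> ["ls", "-l"]
+#     parse_command("./run.sh -v")  # -> ["bash", "./run.sh", "-v"]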
+
+def get_stdout_stderr_paths(
+ command: str, base_log_path: str = "/tmp"
+) -> Tuple[Path, Path]:
+ """Construct and returns the paths of stdout/stderr files."""
+ cmd, *_ = parse_command(command)
+
+ return ShellCommand.get_stdout_stderr_paths(base_log_path, cmd)
+
+
+def execute_command( # pylint: disable=invalid-name
+ command: str,
+ cwd: Path,
+ bg: bool = False,
+ shell: str = "bash",
+ out: Any = None,
+ err: Any = None,
+) -> RunningCommand:
+ """Execute shell command."""
+ cmd, *args = parse_command(command, shell)
+
+ search_paths = None
+ if cmd != shell and (cwd / cmd).is_file():
+ search_paths = [cwd]
+
+ return ShellCommand().run(
+ cmd, *args, _cwd=cwd, _bg=bg, _search_paths=search_paths, _out=out, _err=err
+ )
+
+
+def is_shell_script(cmd: str) -> bool:
+ """Check if command is shell script."""
+ return cmd.endswith(".sh")
+
+
+def run_and_wait(
+ command: str,
+ cwd: Path,
+ terminate_on_error: bool = False,
+ out: Any = None,
+ err: Any = None,
+) -> Tuple[int, bytearray, bytearray]:
+ """
+ Run command and wait while it is executing.
+
+ Returns a tuple: (exit_code, stdout, stderr)
+ """
+ running_cmd: Optional[RunningCommand] = None
+ try:
+ running_cmd = execute_command(command, cwd, bg=True, out=out, err=err)
+ return running_cmd.exit_code, running_cmd.stdout, running_cmd.stderr
+ except ErrorReturnCode as cmd_failed:
+ raise CommandFailedException() from cmd_failed
+ except Exception as error:
+ is_running = running_cmd is not None and running_cmd.is_alive()
+ if terminate_on_error and is_running:
+ print("Terminating ...")
+ terminate_command(running_cmd)
+
+ raise error
+
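+# Illustrative usage sketch: run a command and collect its output:
+#
+#     exit_code, stdout, stderr = run_and_wait("echo hello", Path("/tmp"))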
+
+def terminate_command(
+ running_cmd: RunningCommand,
+ wait: bool = True,
+ wait_period: float = 0.5,
+ number_of_attempts: int = 20,
+) -> None:
+ """Terminate running command."""
+ try:
+ running_cmd.process.signal_group(signal.SIGINT)
+ if wait:
+ for _ in range(number_of_attempts):
+ time.sleep(wait_period)
+ if not running_cmd.is_alive():
+ return
+ print(
+ "Unable to terminate process {}. Sending SIGTERM...".format(
+ running_cmd.process.pid
+ )
+ )
+ running_cmd.process.signal_group(signal.SIGTERM)
+ except ProcessLookupError:
+ pass
+
+
+def terminate_external_process(
+ process: psutil.Process,
+ wait_period: float = 0.5,
+ number_of_attempts: int = 20,
+ wait_for_termination: float = 5.0,
+) -> None:
+ """Terminate external process."""
+ try:
+ process.terminate()
+ for _ in range(number_of_attempts):
+ if not process.is_running():
+ return
+ time.sleep(wait_period)
+
+ if process.is_running():
+ process.terminate()
+ time.sleep(wait_for_termination)
+ except psutil.Error:
+ print("Unable to terminate process")
+
+
+class ProcessInfo(NamedTuple):
+ """Process information."""
+
+ name: str
+ executable: str
+ cwd: str
+ pid: int
+
+
+def save_process_info(pid: int, pid_file: Path) -> None:
+ """Save process information to file."""
+ try:
+ parent = psutil.Process(pid)
+ children = parent.children(recursive=True)
+ family = [parent] + children
+
+ with open(pid_file, "w", encoding="utf-8") as file:
+ csv_writer = csv.writer(file)
+ for member in family:
+ process_info = ProcessInfo(
+ name=member.name(),
+ executable=member.exe(),
+ cwd=member.cwd(),
+ pid=member.pid,
+ )
+ csv_writer.writerow(process_info)
+ except psutil.NoSuchProcess:
+        # If the process does not exist or finished before the function
+        # call, then nothing can be saved; just ignore this exception
+        # and exit
+ pass
+
+
+def read_process_info(pid_file: Path) -> List[ProcessInfo]:
+ """Read information about previous system processes."""
+ if not pid_file.is_file():
+ return []
+
+ result = []
+ with open(pid_file, encoding="utf-8") as file:
+ csv_reader = csv.reader(file)
+ for row in csv_reader:
+ name, executable, cwd, pid = row
+ result.append(
+ ProcessInfo(name=name, executable=executable, cwd=cwd, pid=int(pid))
+ )
+
+ return result
+
+
+def print_command_stdout(command: RunningCommand) -> None:
+ """Print the stdout of a command.
+
+    The command has two states: running and done.
+    If the command is running, the output is read from the running process.
+    If the command has finished its execution, the stdout is taken from the
+    stdout property.
+ """
+ if command.is_alive():
+ while True:
+ try:
+ print(command.next(), end="")
+ except StopIteration:
+ break
+ else:
+ print(command.stdout)
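+
+# Illustrative usage sketch: stream the output of a background command:
+#
+#     running = execute_command("echo hello", Path("/tmp"), bg=True)
+#     print_command_stdout(running)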
diff --git a/src/mlia/__init__.py b/src/mlia/__init__.py
new file mode 100644
index 0000000..ed9ae87
--- /dev/null
+++ b/src/mlia/__init__.py
@@ -0,0 +1,22 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Init of MLIA."""
+import logging
+import os
+
+import pkg_resources
+
+# redirect warnings to logging
+logging.captureWarnings(True)
+
+
+# as TensorFlow tries to configure root logger
+# it should be configured before importing TensorFlow
+root_logger = logging.getLogger()
+root_logger.addHandler(logging.NullHandler())
+
+
+# disable TensorFlow warning messages
+os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
+
+__version__ = pkg_resources.get_distribution("mlia").version
diff --git a/src/mlia/api.py b/src/mlia/api.py
new file mode 100644
index 0000000..53ea4c8
--- /dev/null
+++ b/src/mlia/api.py
@@ -0,0 +1,162 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Module for the API functions."""
+import logging
+from pathlib import Path
+from typing import Any
+from typing import Dict
+from typing import List
+from typing import Literal
+from typing import Optional
+from typing import Union
+
+from mlia.core._typing import PathOrFileLike
+from mlia.core.advisor import InferenceAdvisor
+from mlia.core.common import AdviceCategory
+from mlia.core.context import ExecutionContext
+from mlia.core.events import EventHandler
+from mlia.devices.ethosu.advisor import EthosUInferenceAdvisor
+from mlia.devices.ethosu.handlers import EthosUEventHandler
+
+
+logger = logging.getLogger(__name__)
+
+
+_DEFAULT_OPTIMIZATION_TARGETS = [
+ {
+ "optimization_type": "pruning",
+ "optimization_target": 0.5,
+ "layers_to_optimize": None,
+ },
+ {
+ "optimization_type": "clustering",
+ "optimization_target": 32,
+ "layers_to_optimize": None,
+ },
+]
+
+
+def get_advice(
+ target_profile: str,
+ model: Union[Path, str],
+ category: Literal["all", "operators", "performance", "optimization"] = "all",
+ optimization_targets: Optional[List[Dict[str, Any]]] = None,
+ working_dir: Union[str, Path] = "mlia_output",
+ output: Optional[PathOrFileLike] = None,
+ context: Optional[ExecutionContext] = None,
+ backends: Optional[List[str]] = None,
+) -> None:
+ """Get the advice.
+
+ This function represents an entry point to the library API.
+
+    Based on the provided parameters, it will collect and analyze the data
+ and produce the advice.
+
+ :param target_profile: target profile identifier
+ :param model: path to the NN model
+ :param category: category of the advice. MLIA supports four categories:
+ "all", "operators", "performance", "optimization". If not provided
+ category "all" is used by default.
+ :param optimization_targets: optional model optimization targets that
+ could be used for generating advice in categories
+ "all" and "optimization."
+ :param working_dir: path to the directory that will be used for storing
+ intermediate files during execution (e.g. converted models)
+    :param output: path to the report file. If provided, MLIA will save the
+        report in this location. The format of the report is automatically
+        detected based on the file extension.
+ :param context: optional parameter which represents execution context,
+ could be used for advanced use cases
+ :param backends: A list of backends that should be used for the given
+ target. Default settings will be used if None.
+
+
+ Examples:
+ NB: Before launching MLIA, the logging functionality should be configured!
+
+ Getting the advice for the provided target profile and the model
+
+ >>> get_advice("ethos-u55-256", "path/to/the/model")
+
+ Getting the advice for the category "performance" and save result report in file
+ "report.json"
+
+ >>> get_advice("ethos-u55-256", "path/to/the/model", "performance",
+ output="report.json")
+
+ """
+ advice_category = AdviceCategory.from_string(category)
+ config_parameters = _get_config_parameters(
+ model, target_profile, backends, optimization_targets
+ )
+ event_handlers = _get_event_handlers(output)
+
+ if context is not None:
+ if context.advice_category is None:
+ context.advice_category = advice_category
+
+ if context.config_parameters is None:
+ context.config_parameters = config_parameters
+
+ if context.event_handlers is None:
+ context.event_handlers = event_handlers
+
+    else:
+ context = ExecutionContext(
+ advice_category=advice_category,
+ working_dir=working_dir,
+ config_parameters=config_parameters,
+ event_handlers=event_handlers,
+ )
+
+ advisor = _get_advisor(target_profile)
+ advisor.run(context)
+
+
+def _get_advisor(target: Optional[str]) -> InferenceAdvisor:
+ """Find appropriate advisor for the target."""
+ if not target:
+ raise Exception("Target is not provided")
+
+ return EthosUInferenceAdvisor()
+
+
+def _get_config_parameters(
+ model: Union[Path, str],
+ target_profile: str,
+ backends: Optional[List[str]],
+ optimization_targets: Optional[List[Dict[str, Any]]],
+) -> Dict[str, Any]:
+ """Get configuration parameters for the advisor."""
+ advisor_parameters: Dict[str, Any] = {
+ "ethos_u_inference_advisor": {
+ "model": model,
+ "device": {
+ "target_profile": target_profile,
+ },
+ },
+ }
+ # Specifying backends is optional (default is used)
+ if backends is not None:
+ advisor_parameters["ethos_u_inference_advisor"]["backends"] = backends
+
+ if not optimization_targets:
+ optimization_targets = _DEFAULT_OPTIMIZATION_TARGETS
+
+ advisor_parameters.update(
+ {
+ "ethos_u_model_optimizations": {
+ "optimizations": [
+ optimization_targets,
+ ],
+ },
+ }
+ )
+
+ return advisor_parameters
+
+
+def _get_event_handlers(output: Optional[PathOrFileLike]) -> List[EventHandler]:
+ """Return list of the event handlers."""
+ return [EthosUEventHandler(output)]
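+
+# Illustrative sketch of the configuration shape produced by
+# _get_config_parameters (the values are examples only):
+#
+#     {
+#         "ethos_u_inference_advisor": {
+#             "model": "path/to/model",
+#             "device": {"target_profile": "ethos-u55-256"},
+#         },
+#         "ethos_u_model_optimizations": {
+#             "optimizations": [_DEFAULT_OPTIMIZATION_TARGETS],
+#         },
+#     }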
diff --git a/src/mlia/cli/__init__.py b/src/mlia/cli/__init__.py
new file mode 100644
index 0000000..f50778e
--- /dev/null
+++ b/src/mlia/cli/__init__.py
@@ -0,0 +1,3 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""CLI module."""
diff --git a/src/mlia/cli/commands.py b/src/mlia/cli/commands.py
new file mode 100644
index 0000000..45c7c32
--- /dev/null
+++ b/src/mlia/cli/commands.py
@@ -0,0 +1,276 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""CLI commands module.
+
+This module contains functions which implement the main app
+functionality.
+
+Before running them from scripts, the 'logging' module should
+be configured. The function 'setup_logging' from the module
+'mlia.cli.logging' can be used for that, e.g.
+
+>>> from mlia.api import ExecutionContext
+>>> from mlia.cli.logging import setup_logging
+>>> setup_logging(verbose=True)
+>>> import mlia.cli.commands as mlia
+>>> mlia.all_tests(ExecutionContext(working_dir="mlia_output"), "ethos-u55-256",
+ "path/to/model")
+"""
+import logging
+from pathlib import Path
+from typing import cast
+from typing import List
+from typing import Optional
+
+from mlia.api import ExecutionContext
+from mlia.api import get_advice
+from mlia.api import PathOrFileLike
+from mlia.cli.config import get_installation_manager
+from mlia.cli.options import parse_optimization_parameters
+from mlia.devices.ethosu.operators import generate_supported_operators_report
+from mlia.utils.console import create_section_header
+from mlia.utils.types import only_one_selected
+
+logger = logging.getLogger(__name__)
+
+CONFIG = create_section_header("ML Inference Advisor configuration")
+
+
+def all_tests(
+ ctx: ExecutionContext,
+ target_profile: str,
+ model: str,
+ optimization_type: str = "pruning,clustering",
+ optimization_target: str = "0.5,32",
+ output: Optional[PathOrFileLike] = None,
+ evaluate_on: Optional[List[str]] = None,
+) -> None:
+ """Generate a full report on the input model.
+
+ This command runs a series of tests in order to generate a
+ comprehensive report/advice:
+
+ - converts the input Keras model into TFLite format
+ - checks the model for operator compatibility on the specified device
+ - applies optimizations to the model and estimates the resulting performance
+ on both the original and the optimized models
+ - generates a final report on the steps above
+ - provides advice on how to (possibly) improve the inference performance
+
+ :param ctx: execution context
+ :param target_profile: target profile identifier. Will load appropriate parameters
+ from the profile.json file based on this argument.
+ :param model: path to the Keras model
+ :param optimization_type: list of the optimization techniques separated
+ by comma, e.g. 'pruning,clustering'
+ :param optimization_target: list of the corresponding targets for
+ the provided optimization techniques, e.g. '0.5,32'
+ :param output: path to the file where the report will be saved
+ :param evaluate_on: list of the backends to use for evaluation
+
+ Example:
+ Run command for the target profile ethos-u55-256 with two model optimizations
+ and save report in json format locally in the file report.json
+
+ >>> from mlia.api import ExecutionContext
+ >>> from mlia.cli.logging import setup_logging
+ >>> setup_logging()
+ >>> from mlia.cli.commands import all_tests
+ >>> all_tests(ExecutionContext(working_dir="mlia_output"), "ethos-u55-256",
+ "model.h5", "pruning,clustering", "0.5,32",
+ output="report.json")
+ """
+ opt_params = parse_optimization_parameters(
+ optimization_type,
+ optimization_target,
+ )
+
+ get_advice(
+ target_profile,
+ model,
+ "all",
+ optimization_targets=opt_params,
+ output=output,
+ context=ctx,
+ backends=evaluate_on,
+ )
+
+
+def operators(
+ ctx: ExecutionContext,
+ target_profile: str,
+ model: Optional[str] = None,
+ output: Optional[PathOrFileLike] = None,
+ supported_ops_report: bool = False,
+) -> None:
+ """Print the model's operator list.
+
+ This command checks the operator compatibility of the input model with
+ the specific target profile. Generates a report of the operator placement
+ (NPU or CPU fallback) and advice on how to improve it (if necessary).
+
+ :param ctx: execution context
+ :param target_profile: target profile identifier. Will load appropriate parameters
+ from the profile.json file based on this argument.
+ :param model: path to the model, which can be TFLite or Keras
+ :param output: path to the file where the report will be saved
+    :param supported_ops_report: if True, then generate the supported operators
+        report in the current directory and exit
+
+ Example:
+ Run command for the target profile ethos-u55-256 and the provided
+ TFLite model and print report on the standard output
+
+ >>> from mlia.api import ExecutionContext
+ >>> from mlia.cli.logging import setup_logging
+ >>> setup_logging()
+ >>> from mlia.cli.commands import operators
+ >>> operators(ExecutionContext(working_dir="mlia_output"), "ethos-u55-256",
+ "model.tflite")
+ """
+ if supported_ops_report:
+ generate_supported_operators_report()
+ logger.info("Report saved into SUPPORTED_OPS.md")
+ return
+
+ if not model:
+ raise Exception("Model is not provided")
+
+ get_advice(
+ target_profile,
+ model,
+ "operators",
+ output=output,
+ context=ctx,
+ )
+
+
+def performance(
+ ctx: ExecutionContext,
+ target_profile: str,
+ model: str,
+ output: Optional[PathOrFileLike] = None,
+ evaluate_on: Optional[List[str]] = None,
+) -> None:
+ """Print the model's performance stats.
+
+ This command estimates the inference performance of the input model
+ on the specified target profile, and generates a report with advice on how
+ to improve it.
+
+ :param ctx: execution context
+ :param target_profile: target profile identifier. Will load appropriate parameters
+ from the profile.json file based on this argument.
+ :param model: path to the model, which can be TFLite or Keras
+ :param output: path to the file where the report will be saved
+ :param evaluate_on: list of the backends to use for evaluation
+
+ Example:
+ Run command for the target profile ethos-u55-256 and
+ the provided TFLite model and print report on the standard output
+
+ >>> from mlia.api import ExecutionContext
+ >>> from mlia.cli.logging import setup_logging
+ >>> setup_logging()
+ >>> from mlia.cli.commands import performance
+ >>> performance(ExecutionContext(working_dir="mlia_output"), "ethos-u55-256",
+ "model.tflite")
+ """
+ get_advice(
+ target_profile,
+ model,
+ "performance",
+ output=output,
+ context=ctx,
+ backends=evaluate_on,
+ )
+
+
+def optimization(
+ ctx: ExecutionContext,
+ target_profile: str,
+ model: str,
+ optimization_type: str,
+ optimization_target: str,
+ layers_to_optimize: Optional[List[str]] = None,
+ output: Optional[PathOrFileLike] = None,
+ evaluate_on: Optional[List[str]] = None,
+) -> None:
+ """Show the performance improvements (if any) after applying the optimizations.
+
+ This command applies the selected optimization techniques (up to the
+ indicated targets) and generates a report with advice on how to improve
+ the inference performance (if possible).
+
+ :param ctx: execution context
+    :param target_profile: target profile identifier. Will load appropriate parameters
+ from the profile.json file based on this argument.
+ :param model: path to the TFLite model
+ :param optimization_type: list of the optimization techniques separated
+ by comma, e.g. 'pruning,clustering'
+ :param optimization_target: list of the corresponding targets for
+ the provided optimization techniques, e.g. '0.5,32'
+ :param layers_to_optimize: list of the layers of the model which should be
+ optimized, if None then all layers are used
+ :param output: path to the file where the report will be saved
+ :param evaluate_on: list of the backends to use for evaluation
+
+ Example:
+ Run command for the target profile ethos-u55-256 and
+ the provided TFLite model and print report on the standard output
+
+ >>> from mlia.cli.logging import setup_logging
+ >>> setup_logging()
+ >>> from mlia.cli.commands import optimization
+ >>> optimization(ExecutionContext(working_dir="mlia_output"),
+                         "ethos-u55-256", "model.tflite",
+                         "pruning", "0.5")
+ """
+ opt_params = parse_optimization_parameters(
+ optimization_type,
+ optimization_target,
+ layers_to_optimize=layers_to_optimize,
+ )
+
+ get_advice(
+ target_profile,
+ model,
+ "optimization",
+ optimization_targets=opt_params,
+ output=output,
+ context=ctx,
+ backends=evaluate_on,
+ )
+
+
+def backend(
+ backend_action: str,
+ path: Optional[Path] = None,
+ download: bool = False,
+ name: Optional[str] = None,
+ i_agree_to_the_contained_eula: bool = False,
+ noninteractive: bool = False,
+) -> None:
+ """Backends configuration."""
+ logger.info(CONFIG)
+
+ manager = get_installation_manager(noninteractive)
+
+ if backend_action == "status":
+ manager.show_env_details()
+
+ if backend_action == "install":
+ install_from_path = path is not None
+
+ if not only_one_selected(install_from_path, download):
+ raise Exception(
+ "Please select only one action: download or "
+ "provide path to the backend installation"
+ )
+
+ if install_from_path:
+ manager.install_from(cast(Path, path), name)
+
+ if download:
+ eula_agreement = not i_agree_to_the_contained_eula
+ manager.download_and_install(name, eula_agreement)
diff --git a/src/mlia/cli/common.py b/src/mlia/cli/common.py
new file mode 100644
index 0000000..54bd457
--- /dev/null
+++ b/src/mlia/cli/common.py
@@ -0,0 +1,38 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""CLI common module."""
+import argparse
+from dataclasses import dataclass
+from typing import Callable
+from typing import List
+
+
+@dataclass
+class CommandInfo:
+ """Command description."""
+
+ func: Callable
+ aliases: List[str]
+ opt_groups: List[Callable[[argparse.ArgumentParser], None]]
+ is_default: bool = False
+
+ @property
+ def command_name(self) -> str:
+ """Return command name."""
+ return self.func.__name__
+
+ @property
+ def command_name_and_aliases(self) -> List[str]:
+ """Return list of command name and aliases."""
+ return [self.command_name, *self.aliases]
+
+ @property
+ def command_help(self) -> str:
+ """Return help message for the command."""
+ assert self.func.__doc__, "Command function does not have a docstring"
+ func_help = self.func.__doc__.splitlines()[0].rstrip(".")
+
+ if self.is_default:
+ func_help = f"{func_help} [default]"
+
+ return func_help
diff --git a/src/mlia/cli/config.py b/src/mlia/cli/config.py
new file mode 100644
index 0000000..838b051
--- /dev/null
+++ b/src/mlia/cli/config.py
@@ -0,0 +1,64 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Environment configuration functions."""
+import logging
+from functools import lru_cache
+from typing import List
+
+import mlia.tools.aiet_wrapper as aiet
+from mlia.tools.metadata.common import DefaultInstallationManager
+from mlia.tools.metadata.common import InstallationManager
+from mlia.tools.metadata.corstone import get_corstone_installations
+
+logger = logging.getLogger(__name__)
+
+
+def get_installation_manager(noninteractive: bool = False) -> InstallationManager:
+ """Return installation manager."""
+ backends = get_corstone_installations()
+
+ return DefaultInstallationManager(backends, noninteractive=noninteractive)
+
+
+@lru_cache
+def get_available_backends() -> List[str]:
+ """Return list of the available backends."""
+ available_backends = ["Vela"]
+
+ # Add backends using AIET
+ manager = get_installation_manager()
+ available_backends.extend(
+ (
+ backend
+ for backend in aiet.supported_backends()
+ if manager.backend_installed(backend)
+ )
+ )
+
+ return available_backends
+
+
+# List of mutually exclusive Corstone backends ordered by priority
+_CORSTONE_EXCLUSIVE_PRIORITY = ("Corstone-310", "Corstone-300")
+
+
+def get_default_backends() -> List[str]:
+ """Get default backends for evaluation."""
+ backends = get_available_backends()
+
+ # Filter backends to only include one Corstone backend
+ for corstone in _CORSTONE_EXCLUSIVE_PRIORITY:
+ if corstone in backends:
+ backends = [
+ backend
+ for backend in backends
+ if backend == corstone or backend not in _CORSTONE_EXCLUSIVE_PRIORITY
+ ]
+ break
+
+ return backends
+
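+# Illustrative example: if both Corstone backends are available, only the
+# highest-priority one is kept:
+#
+#     ["Vela", "Corstone-300", "Corstone-310"] -> ["Vela", "Corstone-310"]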
+
+def is_corstone_backend(backend: str) -> bool:
+ """Check if the given backend is a Corstone backend."""
+ return backend in _CORSTONE_EXCLUSIVE_PRIORITY
diff --git a/src/mlia/cli/helpers.py b/src/mlia/cli/helpers.py
new file mode 100644
index 0000000..81d5a15
--- /dev/null
+++ b/src/mlia/cli/helpers.py
@@ -0,0 +1,116 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Module for various helper classes."""
+from typing import Any
+from typing import Dict
+from typing import List
+from typing import Optional
+from typing import Tuple
+
+from mlia.cli.options import get_target_profile_opts
+from mlia.core.helpers import ActionResolver
+from mlia.nn.tensorflow.optimizations.select import OptimizationSettings
+from mlia.nn.tensorflow.utils import is_keras_model
+from mlia.utils.types import is_list_of
+
+
+class CLIActionResolver(ActionResolver):
+ """Helper class for generating cli commands."""
+
+ def __init__(self, args: Dict[str, Any]) -> None:
+ """Init action resolver."""
+ self.args = args
+
+ @staticmethod
+ def _general_optimization_command(model_path: Optional[str]) -> List[str]:
+ """Return general optimization command description."""
+ keras_note = []
+ if model_path is None or not is_keras_model(model_path):
+ model_path = "/path/to/keras_model"
+ keras_note = ["Note: you will need a Keras model for that."]
+
+ return [
+ *keras_note,
+ "For example: mlia optimization --optimization-type "
+ f"pruning,clustering --optimization-target 0.5,32 {model_path}",
+ "For more info: mlia optimization --help",
+ ]
+
+ @staticmethod
+ def _specific_optimization_command(
+ model_path: str,
+ device_opts: str,
+ opt_settings: List[OptimizationSettings],
+ ) -> List[str]:
+ """Return specific optimization command description."""
+ opt_types = ",".join(opt.optimization_type for opt in opt_settings)
+ opt_targs = ",".join(str(opt.optimization_target) for opt in opt_settings)
+
+ return [
+ "For more info: mlia optimization --help",
+ "Optimization command: "
+ f"mlia optimization --optimization-type {opt_types} "
+ f"--optimization-target {opt_targs}{device_opts} {model_path}",
+ ]
+
+ def apply_optimizations(self, **kwargs: Any) -> List[str]:
+ """Return command details for applying optimizations."""
+ model_path, device_opts = self._get_model_and_device_opts()
+
+ if (opt_settings := kwargs.pop("opt_settings", None)) is None:
+ return self._general_optimization_command(model_path)
+
+ if is_list_of(opt_settings, OptimizationSettings) and model_path:
+ return self._specific_optimization_command(
+ model_path, device_opts, opt_settings
+ )
+
+ return []
+
+ def supported_operators_info(self) -> List[str]:
+ """Return command details for generating supported ops report."""
+ return [
+ "For guidance on supported operators, run: mlia operators "
+ "--supported-ops-report",
+ ]
+
+ def check_performance(self) -> List[str]:
+ """Return command details for checking performance."""
+ model_path, device_opts = self._get_model_and_device_opts()
+ if not model_path:
+ return []
+
+ return [
+ "Check the estimated performance by running the following command: ",
+ f"mlia performance{device_opts} {model_path}",
+ ]
+
+ def check_operator_compatibility(self) -> List[str]:
+ """Return command details for op compatibility."""
+ model_path, device_opts = self._get_model_and_device_opts()
+ if not model_path:
+ return []
+
+ return [
+ "Try running the following command to verify that:",
+ f"mlia operators{device_opts} {model_path}",
+ ]
+
+ def operator_compatibility_details(self) -> List[str]:
+ """Return command details for op compatibility."""
+ return ["For more details, run: mlia operators --help"]
+
+ def optimization_details(self) -> List[str]:
+ """Return command details for optimization."""
+ return ["For more info, see: mlia optimization --help"]
+
+ def _get_model_and_device_opts(
+ self, separate_device_opts: bool = True
+ ) -> Tuple[Optional[str], str]:
+ """Get model and device options."""
+ device_opts = " ".join(get_target_profile_opts(self.args))
+ if separate_device_opts and device_opts:
+ device_opts = f" {device_opts}"
+
+ model_path = self.args.get("model")
+ return model_path, device_opts
diff --git a/src/mlia/cli/logging.py b/src/mlia/cli/logging.py
new file mode 100644
index 0000000..c5fc7bd
--- /dev/null
+++ b/src/mlia/cli/logging.py
@@ -0,0 +1,117 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""CLI logging configuration."""
+import logging
+import sys
+from pathlib import Path
+from typing import List
+from typing import Optional
+from typing import Union
+
+from mlia.utils.logging import attach_handlers
+from mlia.utils.logging import create_log_handler
+from mlia.utils.logging import LogFilter
+
+
+_CONSOLE_DEBUG_FORMAT = "%(name)s - %(message)s"
+_FILE_DEBUG_FORMAT = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
+
+
+def setup_logging(
+ logs_dir: Optional[Union[str, Path]] = None,
+ verbose: bool = False,
+ log_filename: str = "mlia.log",
+) -> None:
+ """Set up logging.
+
+    MLIA uses the 'logging' module when it needs to produce output.
+
+    :param logs_dir: path to the directory where the application will save logs
+        with debug information. If the path is not provided, then no log files
+        will be created during execution
+ :param verbose: enable extended logging for the tools loggers
+ :param log_filename: name of the log file in the logs directory
+ """
+ mlia_logger, *tools_loggers = [
+ logging.getLogger(logger_name)
+ for logger_name in ["mlia", "tensorflow", "py.warnings"]
+ ]
+
+    # enable debug output; the actual message filtering depends on
+    # the provided parameters and is done at the handler level
+ mlia_logger.setLevel(logging.DEBUG)
+
+ mlia_handlers = _get_mlia_handlers(logs_dir, log_filename, verbose)
+ attach_handlers(mlia_handlers, [mlia_logger])
+
+ tools_handlers = _get_tools_handlers(logs_dir, log_filename, verbose)
+ attach_handlers(tools_handlers, tools_loggers)
+
+
+def _get_mlia_handlers(
+ logs_dir: Optional[Union[str, Path]],
+ log_filename: str,
+ verbose: bool,
+) -> List[logging.Handler]:
+ """Get handlers for the MLIA loggers."""
+ result = []
+ stdout_handler = create_log_handler(
+ stream=sys.stdout,
+ log_level=logging.INFO,
+ )
+ result.append(stdout_handler)
+
+ if verbose:
+ mlia_verbose_handler = create_log_handler(
+ stream=sys.stdout,
+ log_level=logging.DEBUG,
+ log_format=_CONSOLE_DEBUG_FORMAT,
+ log_filter=LogFilter.equals(logging.DEBUG),
+ )
+ result.append(mlia_verbose_handler)
+
+ if logs_dir:
+ mlia_file_handler = create_log_handler(
+ file_path=_get_log_file(logs_dir, log_filename),
+ log_level=logging.DEBUG,
+ log_format=_FILE_DEBUG_FORMAT,
+ log_filter=LogFilter.skip(logging.INFO),
+ delay=True,
+ )
+ result.append(mlia_file_handler)
+
+ return result
+
+
+def _get_tools_handlers(
+ logs_dir: Optional[Union[str, Path]],
+ log_filename: str,
+ verbose: bool,
+) -> List[logging.Handler]:
+ """Get handler for the tools loggers."""
+ result = []
+ if verbose:
+ verbose_stdout_handler = create_log_handler(
+ stream=sys.stdout,
+ log_level=logging.DEBUG,
+ log_format=_CONSOLE_DEBUG_FORMAT,
+ )
+ result.append(verbose_stdout_handler)
+
+ if logs_dir:
+ file_handler = create_log_handler(
+ file_path=_get_log_file(logs_dir, log_filename),
+ log_level=logging.DEBUG,
+ log_format=_FILE_DEBUG_FORMAT,
+ delay=True,
+ )
+ result.append(file_handler)
+
+ return result
+
+
+def _get_log_file(logs_dir: Union[str, Path], log_filename: str) -> Path:
+ """Get the log file path."""
+ logs_dir_path = Path(logs_dir)
+ logs_dir_path.mkdir(exist_ok=True)
+ return logs_dir_path / log_filename
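+
+# Illustrative usage sketch: INFO output on the console plus a debug log
+# file (the directory name is an example):
+#
+#     setup_logging(logs_dir="logs", verbose=True)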
diff --git a/src/mlia/cli/main.py b/src/mlia/cli/main.py
new file mode 100644
index 0000000..33fcdeb
--- /dev/null
+++ b/src/mlia/cli/main.py
@@ -0,0 +1,280 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""CLI main entry point."""
+import argparse
+import logging
+import sys
+from inspect import signature
+from pathlib import Path
+from typing import Dict
+from typing import List
+from typing import Optional
+from typing import Tuple
+
+from mlia import __version__
+from mlia.cli.commands import all_tests
+from mlia.cli.commands import backend
+from mlia.cli.commands import operators
+from mlia.cli.commands import optimization
+from mlia.cli.commands import performance
+from mlia.cli.common import CommandInfo
+from mlia.cli.helpers import CLIActionResolver
+from mlia.cli.logging import setup_logging
+from mlia.cli.options import add_backend_options
+from mlia.cli.options import add_custom_supported_operators_options
+from mlia.cli.options import add_debug_options
+from mlia.cli.options import add_evaluation_options
+from mlia.cli.options import add_keras_model_options
+from mlia.cli.options import add_multi_optimization_options
+from mlia.cli.options import add_optional_tflite_model_options
+from mlia.cli.options import add_output_options
+from mlia.cli.options import add_target_options
+from mlia.cli.options import add_tflite_model_options
+from mlia.core.context import ExecutionContext
+
+
+logger = logging.getLogger(__name__)
+
+INFO_MESSAGE = f"""
+ML Inference Advisor {__version__}
+
+Helps with the design and optimization of neural network models for efficient inference on a target CPU, GPU and NPU
+
+Supported targets:
+
+ - Ethos-U55 <op compatibility, perf estimation, model opt>
+ - Ethos-U65 <op compatibility, perf estimation, model opt>
+
+""".strip()
+
+
+def get_commands() -> List[CommandInfo]:
+ """Return commands configuration."""
+ return [
+ CommandInfo(
+ all_tests,
+ ["all"],
+ [
+ add_target_options,
+ add_keras_model_options,
+ add_multi_optimization_options,
+ add_output_options,
+ add_debug_options,
+ add_evaluation_options,
+ ],
+ True,
+ ),
+ CommandInfo(
+ operators,
+ ["ops"],
+ [
+ add_target_options,
+ add_optional_tflite_model_options,
+ add_output_options,
+ add_custom_supported_operators_options,
+ add_debug_options,
+ ],
+ ),
+ CommandInfo(
+ performance,
+ ["perf"],
+ [
+ add_target_options,
+ add_tflite_model_options,
+ add_output_options,
+ add_debug_options,
+ add_evaluation_options,
+ ],
+ ),
+ CommandInfo(
+ optimization,
+ ["opt"],
+ [
+ add_target_options,
+ add_keras_model_options,
+ add_multi_optimization_options,
+ add_output_options,
+ add_debug_options,
+ add_evaluation_options,
+ ],
+ ),
+ CommandInfo(
+ backend,
+ [],
+ [
+ add_backend_options,
+ add_debug_options,
+ ],
+ ),
+ ]
+
+
+def get_default_command() -> Optional[str]:
+ """Get name of the default command."""
+ commands = get_commands()
+
+ marked_as_default = [cmd.command_name for cmd in commands if cmd.is_default]
+    assert len(marked_as_default) <= 1, "Only one command can be marked as default"
+
+ return next(iter(marked_as_default), None)
+
+
+def get_possible_command_names() -> List[str]:
+ """Get all possible command names including aliases."""
+ return [
+ name_or_alias
+ for cmd in get_commands()
+ for name_or_alias in cmd.command_name_and_aliases
+ ]
+
+
+def init_commands(parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
+ """Init cli subcommands."""
+ subparsers = parser.add_subparsers(title="Commands", dest="command")
+ subparsers.required = True
+
+ for command in get_commands():
+ command_parser = subparsers.add_parser(
+ command.command_name,
+ aliases=command.aliases,
+ help=command.command_help,
+ allow_abbrev=False,
+ )
+ command_parser.set_defaults(func=command.func)
+ for opt_group in command.opt_groups:
+ opt_group(command_parser)
+
+ return parser
+
+
+def setup_context(
+ args: argparse.Namespace, context_var_name: str = "ctx"
+) -> Tuple[ExecutionContext, Dict]:
+ """Set up context and resolve function parameters."""
+ ctx = ExecutionContext(
+ working_dir=args.working_dir,
+ verbose="verbose" in args and args.verbose,
+ action_resolver=CLIActionResolver(vars(args)),
+ )
+
+ # these parameters should not be passed into command function
+ skipped_params = ["func", "command", "working_dir", "verbose"]
+
+ # pass these parameters only if command expects them
+ expected_params = [context_var_name]
+ func_params = signature(args.func).parameters
+
+ params = {context_var_name: ctx, **vars(args)}
+
+ func_args = {
+ param_name: param_value
+ for param_name, param_value in params.items()
+ if param_name not in skipped_params
+ and (param_name not in expected_params or param_name in func_params)
+ }
+
+ return (ctx, func_args)
+
+
+def run_command(args: argparse.Namespace) -> int:
+ """Run command."""
+ ctx, func_args = setup_context(args)
+ setup_logging(ctx.logs_path, ctx.verbose)
+
+ logger.debug(
+ "*** This is the beginning of the command '%s' execution ***", args.command
+ )
+
+ try:
+ logger.info(INFO_MESSAGE)
+
+ args.func(**func_args)
+ return 0
+ except KeyboardInterrupt:
+ logger.error("Execution has been interrupted")
+ except Exception as err: # pylint: disable=broad-except
+ logger.error(
+ "\nExecution finished with error: %s",
+ err,
+ exc_info=err if ctx.verbose else None,
+ )
+
+ err_advice_message = (
+ f"Please check the log files in the {ctx.logs_path} for more details"
+ )
+ if not ctx.verbose:
+ err_advice_message += ", or enable verbose mode"
+
+ logger.error(err_advice_message)
+
+ return 1
+
+
+def init_common_parser() -> argparse.ArgumentParser:
+ """Init common parser."""
+ parser = argparse.ArgumentParser(add_help=False, allow_abbrev=False)
+ parser.add_argument(
+ "--working-dir",
+ default=f"{Path.cwd() / 'mlia_output'}",
+ help="Path to the directory where MLIA will store logs, "
+ "models, etc. (default: %(default)s)",
+ )
+
+ return parser
+
+
+def init_subcommand_parser(parent: argparse.ArgumentParser) -> argparse.ArgumentParser:
+ """Init subcommand parser."""
+ parser = argparse.ArgumentParser(
+ description=INFO_MESSAGE,
+ formatter_class=argparse.RawDescriptionHelpFormatter,
+ parents=[parent],
+ add_help=False,
+ allow_abbrev=False,
+ )
+ parser.add_argument(
+ "-h",
+ "--help",
+ action="help",
+ default=argparse.SUPPRESS,
+ help="Show this help message and exit",
+ )
+ parser.add_argument(
+ "-v",
+ "--version",
+ action="version",
+ version=f"%(prog)s {__version__}",
+ help="Show program's version number and exit",
+ )
+
+ return parser
+
+
+def add_default_command_if_needed(args: List[str]) -> None:
+ """Add default command to the list of the arguments if needed."""
+ default_command = get_default_command()
+
+ if default_command and len(args) > 0:
+ commands = get_possible_command_names()
+ help_or_version = ["-h", "--help", "-v", "--version"]
+
+ command_is_missing = args[0] not in [*commands, *help_or_version]
+ if command_is_missing:
+ args.insert(0, default_command)
+
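+# Illustrative example: with 'all_tests' marked as the default command,
+# 'mlia model.h5' is treated as 'mlia all_tests model.h5'.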
+
+def main(argv: Optional[List[str]] = None) -> int:
+ """Entry point of the application."""
+ common_parser = init_common_parser()
+ subcommand_parser = init_subcommand_parser(common_parser)
+ init_commands(subcommand_parser)
+
+ common_args, subcommand_args = common_parser.parse_known_args(argv)
+ add_default_command_if_needed(subcommand_args)
+
+ args = subcommand_parser.parse_args(subcommand_args, common_args)
+ return run_command(args)
+
+
+if __name__ == "__main__":
+ sys.exit(main())
diff --git a/src/mlia/cli/options.py b/src/mlia/cli/options.py
new file mode 100644
index 0000000..dc5cb73
--- /dev/null
+++ b/src/mlia/cli/options.py
@@ -0,0 +1,280 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Module for the CLI options."""
+import argparse
+from pathlib import Path
+from typing import Any
+from typing import Callable
+from typing import Dict
+from typing import List
+from typing import Optional
+
+from mlia.cli.config import get_available_backends
+from mlia.cli.config import get_default_backends
+from mlia.cli.config import is_corstone_backend
+from mlia.utils.filesystem import get_supported_profile_names
+from mlia.utils.types import is_number
+
+
+def add_target_options(parser: argparse.ArgumentParser) -> None:
+ """Add target specific options."""
+ target_profiles = get_supported_profile_names()
+
+ default_target_profile = None
+ default_help = ""
+ if target_profiles:
+ default_target_profile = target_profiles[0]
+ default_help = " (default: %(default)s)"
+
+ target_group = parser.add_argument_group("target options")
+ target_group.add_argument(
+ "--target-profile",
+ choices=target_profiles,
+ default=default_target_profile,
+ help="Target profile that will set the target options "
+ "such as target, mac value, memory mode, etc. "
+ f"For the values associated with each target profile "
+ f" please refer to the documenation {default_help}.",
+ )
+
+
+def add_multi_optimization_options(parser: argparse.ArgumentParser) -> None:
+ """Add optimization specific options."""
+ multi_optimization_group = parser.add_argument_group("optimization options")
+
+ multi_optimization_group.add_argument(
+ "--optimization-type",
+ default="pruning,clustering",
+ help="List of the optimization types separated by comma (default: %(default)s)",
+ )
+ multi_optimization_group.add_argument(
+ "--optimization-target",
+ default="0.5,32",
+ help="""List of the optimization targets separated by comma,
+ (for pruning this is sparsity between (0,1),
+ for clustering this is the number of clusters (positive integer))
+ (default: %(default)s)""",
+ )
+
+
+def add_optional_tflite_model_options(parser: argparse.ArgumentParser) -> None:
+ """Add optional model specific options."""
+ model_group = parser.add_argument_group("TFLite model options")
+ # make model parameter optional
+ model_group.add_argument("model", nargs="?", help="TFLite model (optional)")
+
+
+def add_tflite_model_options(parser: argparse.ArgumentParser) -> None:
+ """Add model specific options."""
+ model_group = parser.add_argument_group("TFLite model options")
+ model_group.add_argument("model", help="TFLite model")
+
+
+def add_output_options(parser: argparse.ArgumentParser) -> None:
+ """Add output specific options."""
+ valid_extensions = ["csv", "json"]
+
+ def check_extension(filename: str) -> str:
+ """Check extension of the provided file."""
+ suffix = Path(filename).suffix
+ if suffix.startswith("."):
+ suffix = suffix[1:]
+
+ if suffix.lower() not in valid_extensions:
+ parser.error(f"Unsupported format '{suffix}'")
+
+ return filename
+
+ output_group = parser.add_argument_group("output options")
+ output_group.add_argument(
+ "--output",
+ type=check_extension,
+ help=(
+ "Name of the file where report will be saved. "
+ "Report format is automatically detected based on the file extension. "
+ f"Supported formats are: {', '.join(valid_extensions)}"
+ ),
+ )
+
+
+def add_debug_options(parser: argparse.ArgumentParser) -> None:
+ """Add debug options."""
+ debug_group = parser.add_argument_group("debug options")
+ debug_group.add_argument(
+ "--verbose", default=False, action="store_true", help="Produce verbose output"
+ )
+
+
+def add_keras_model_options(parser: argparse.ArgumentParser) -> None:
+ """Add model specific options."""
+ model_group = parser.add_argument_group("Keras model options")
+ model_group.add_argument("model", help="Keras model")
+
+
+def add_custom_supported_operators_options(parser: argparse.ArgumentParser) -> None:
+ """Add custom options for the command 'operators'."""
+ parser.add_argument(
+ "--supported-ops-report",
+ action="store_true",
+ default=False,
+ help=(
+ "Generate the SUPPORTED_OPS.md file in the "
+ "current working directory and exit"
+ ),
+ )
+
+
+def add_backend_options(parser: argparse.ArgumentParser) -> None:
+ """Add options for the backends configuration."""
+
+ def valid_directory(param: str) -> Path:
+ """Check if passed string is a valid directory path."""
+ if not (dir_path := Path(param)).is_dir():
+ parser.error(f"Invalid directory path {param}")
+
+ return dir_path
+
+ subparsers = parser.add_subparsers(title="Backend actions", dest="backend_action")
+ subparsers.required = True
+
+ install_subparser = subparsers.add_parser(
+ "install", help="Install backend", allow_abbrev=False
+ )
+ install_type_group = install_subparser.add_mutually_exclusive_group()
+ install_type_group.required = True
+ install_type_group.add_argument(
+ "--path", type=valid_directory, help="Path to the installed backend"
+ )
+ install_type_group.add_argument(
+ "--download",
+ default=False,
+ action="store_true",
+ help="Download and install backend",
+ )
+ install_subparser.add_argument(
+ "--i-agree-to-the-contained-eula",
+ default=False,
+ action="store_true",
+ help=argparse.SUPPRESS,
+ )
+ install_subparser.add_argument(
+ "--noninteractive",
+ default=False,
+ action="store_true",
+ help="Non interactive mode with automatic confirmation of every action",
+ )
+ install_subparser.add_argument(
+ "name",
+ nargs="?",
+ help="Name of the backend to install",
+ )
+
+ subparsers.add_parser("status", help="Show backends status")
+
+
+def add_evaluation_options(parser: argparse.ArgumentParser) -> None:
+ """Add evaluation options."""
+ available_backends = get_available_backends()
+ default_backends = get_default_backends()
+
+ def only_one_corstone_checker() -> Callable:
+ """
+ Return a callable to check that only one Corstone backend is passed.
+
+ Raises an exception when more than one Corstone backend is passed.
+ """
+ num_corstones = 0
+
+ def check(backend: str) -> str:
+ """Count Corstone backends and raise an exception if more than one."""
+ nonlocal num_corstones
+ if is_corstone_backend(backend):
+ num_corstones = num_corstones + 1
+ if num_corstones > 1:
+ raise argparse.ArgumentTypeError(
+ "There must be only one Corstone backend in the argument list."
+ )
+ return backend
+
+ return check
+
+ evaluation_group = parser.add_argument_group("evaluation options")
+ evaluation_group.add_argument(
+ "--evaluate-on",
+ help="Backends to use for evaluation (default: %(default)s)",
+ nargs="*",
+ choices=available_backends,
+ default=default_backends,
+ type=only_one_corstone_checker(),
+ )
+
+
+def parse_optimization_parameters(
+ optimization_type: str,
+ optimization_target: str,
+ sep: str = ",",
+ layers_to_optimize: Optional[List[str]] = None,
+) -> List[Dict[str, Any]]:
+ """Parse provided optimization parameters."""
+ if not optimization_type:
+ raise Exception("Optimization type is not provided")
+
+ if not optimization_target:
+ raise Exception("Optimization target is not provided")
+
+ opt_types = optimization_type.split(sep)
+ opt_targets = optimization_target.split(sep)
+
+ if len(opt_types) != len(opt_targets):
+ raise Exception("Wrong number of optimization targets and types")
+
+ non_numeric_targets = [
+ opt_target for opt_target in opt_targets if not is_number(opt_target)
+ ]
+ if len(non_numeric_targets) > 0:
+ raise Exception("Non numeric value for the optimization target")
+
+ optimizer_params = [
+ {
+ "optimization_type": opt_type.strip(),
+ "optimization_target": float(opt_target),
+ "layers_to_optimize": layers_to_optimize,
+ }
+ for opt_type, opt_target in zip(opt_types, opt_targets)
+ ]
+
+ return optimizer_params
+
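+# Illustrative example of the parsed parameters:
+#
+#     parse_optimization_parameters("pruning,clustering", "0.5,32")
+#     # -> [{"optimization_type": "pruning", "optimization_target": 0.5,
+#     #      "layers_to_optimize": None},
+#     #     {"optimization_type": "clustering", "optimization_target": 32.0,
+#     #      "layers_to_optimize": None}]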
+
+def get_target_profile_opts(device_args: Optional[Dict]) -> List[str]:
+ """Get non default values passed as parameters for the target profile."""
+ if not device_args:
+ return []
+
+ dummy_parser = argparse.ArgumentParser()
+ add_target_options(dummy_parser)
+ args = dummy_parser.parse_args([])
+
+ params_name = {
+ action.dest: param_name
+ for param_name, action in dummy_parser._option_string_actions.items() # pylint: disable=protected-access
+ }
+
+ non_default = [
+ arg_name
+ for arg_name, arg_value in device_args.items()
+ if arg_name in args and vars(args)[arg_name] != arg_value
+ ]
+
+ def construct_param(name: str, value: Any) -> List[str]:
+ """Construct parameter."""
+ if isinstance(value, list):
+ return [str(item) for v in value for item in [name, v]]
+
+ return [name, str(value)]
+
+ return [
+ item
+ for name in non_default
+ for item in construct_param(params_name[name], device_args[name])
+ ]
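+
+
+# Example (a sketch; "--mac" is an illustrative option name): for a device
+# argument {"mac": 32} that differs from the parser default, this returns
+# ["--mac", "32"]. List values expand pairwise, e.g.
+# construct_param("--opt", ["a", "b"]) == ["--opt", "a", "--opt", "b"].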
diff --git a/src/mlia/core/__init__.py b/src/mlia/core/__init__.py
new file mode 100644
index 0000000..49b1830
--- /dev/null
+++ b/src/mlia/core/__init__.py
@@ -0,0 +1,21 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Core module.
+
+The core module contains the main components used in the workflow of
+the ML Inference Advisor:
+ - data collectors
+ - data analyzers
+ - advice producers
+ - event publishers
+ - event handlers
+
+The workflow of the ML Inference Advisor consists of three stages:
+ - data collection
+ - data analysis
+ - advice generation
+
+Data is passed from one stage to another via the workflow executor.
+Results (collected data, analyzed data, advice, etc.) are published via
+a publish/subscribe mechanism.
+"""
diff --git a/src/mlia/core/_typing.py b/src/mlia/core/_typing.py
new file mode 100644
index 0000000..bda995c
--- /dev/null
+++ b/src/mlia/core/_typing.py
@@ -0,0 +1,12 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Module for custom type hints."""
+from pathlib import Path
+from typing import Literal
+from typing import TextIO
+from typing import Union
+
+
+FileLike = TextIO
+PathOrFileLike = Union[str, Path, FileLike]
+OutputFormat = Literal["plain_text", "csv", "json"]
diff --git a/src/mlia/core/advice_generation.py b/src/mlia/core/advice_generation.py
new file mode 100644
index 0000000..76cc1f2
--- /dev/null
+++ b/src/mlia/core/advice_generation.py
@@ -0,0 +1,106 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Module for advice generation."""
+from abc import ABC
+from abc import abstractmethod
+from dataclasses import dataclass
+from functools import wraps
+from typing import Any
+from typing import Callable
+from typing import List
+from typing import Union
+
+from mlia.core.common import AdviceCategory
+from mlia.core.common import DataItem
+from mlia.core.events import SystemEvent
+from mlia.core.mixins import ContextMixin
+
+
+@dataclass
+class Advice:
+ """Base class for the advice."""
+
+ messages: List[str]
+
+
+@dataclass
+class AdviceEvent(SystemEvent):
+ """Advice event.
+
+ This event is published for every produced advice.
+
+ :param advice: Advice instance
+ """
+
+ advice: Advice
+
+
+class AdviceProducer(ABC):
+ """Base class for the advice producer.
+
+ A producer has two methods for advice generation:
+ - produce_advice - generates advice based on the provided
+ data (an analyzed data item from the analysis stage)
+ - get_advice - returns the generated advice
+
+ Advice producers that have predefined advice can skip
+ implementing the produce_advice method.
+ """
+
+ @abstractmethod
+ def produce_advice(self, data_item: DataItem) -> None:
+ """Process data item and produce advice.
+
+ :param data_item: piece of data that could be used
+ for advice generation
+ """
+
+ @abstractmethod
+ def get_advice(self) -> Union[Advice, List[Advice]]:
+ """Get produced advice."""
+
+
+class ContextAwareAdviceProducer(AdviceProducer, ContextMixin):
+ """Context aware advice producer.
+
+ This class makes it easier to access the Context object. The Context
+ object can be injected automatically during workflow configuration.
+ """
+
+
+class FactBasedAdviceProducer(ContextAwareAdviceProducer):
+ """Advice producer based on provided facts.
+
+ This is a utility class that maintains a list of generated Advice instances.
+ """
+
+ def __init__(self) -> None:
+ """Init advice producer."""
+ self.advice: List[Advice] = []
+
+ def get_advice(self) -> Union[Advice, List[Advice]]:
+ """Get produced advice."""
+ return self.advice
+
+ def add_advice(self, messages: List[str]) -> None:
+ """Add advice."""
+ self.advice.append(Advice(messages))
+
+
+def advice_category(*categories: AdviceCategory) -> Callable:
+ """Filter advice generation handler by advice category."""
+
+ def wrapper(handler: Callable) -> Callable:
+ """Wrap data handler."""
+
+ @wraps(handler)
+ def check_category(self: Any, *args: Any, **kwargs: Any) -> Any:
+ """Check if handler can produce advice for the requested category."""
+ if not self.context.any_category_enabled(*categories):
+ return
+
+ handler(self, *args, **kwargs)
+
+ return check_category
+
+ return wrapper
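+
+
+# Example producer (a sketch; the category and message are illustrative):
+#
+#     class MyProducer(FactBasedAdviceProducer):
+#         @advice_category(AdviceCategory.PERFORMANCE)
+#         def produce_advice(self, data_item: DataItem) -> None:
+#             self.add_advice(["Try running the model on the NPU."])
+#
+# The decorated handler is skipped unless one of the given categories is
+# enabled in the context (see Context.any_category_enabled).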
diff --git a/src/mlia/core/advisor.py b/src/mlia/core/advisor.py
new file mode 100644
index 0000000..868d0c7
--- /dev/null
+++ b/src/mlia/core/advisor.py
@@ -0,0 +1,21 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Inference advisor module."""
+from abc import abstractmethod
+
+from mlia.core.common import NamedEntity
+from mlia.core.context import Context
+from mlia.core.workflow import WorkflowExecutor
+
+
+class InferenceAdvisor(NamedEntity):
+ """Base class for inference advisors."""
+
+ @abstractmethod
+ def configure(self, context: Context) -> WorkflowExecutor:
+ """Configure advisor execution."""
+
+ def run(self, context: Context) -> None:
+ """Run inference advisor."""
+ executor = self.configure(context)
+ executor.run()
diff --git a/src/mlia/core/common.py b/src/mlia/core/common.py
new file mode 100644
index 0000000..5fbad42
--- /dev/null
+++ b/src/mlia/core/common.py
@@ -0,0 +1,47 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Common module.
+
+This module contains common interfaces/classes shared across
+the core module.
+"""
+from abc import ABC
+from abc import abstractmethod
+from enum import Enum
+from typing import Any
+
+# This type is used as a type alias for the items that are passed around
+# in the advisor workflow. There are no restrictions on the type of the
+# object. The alias is used only to emphasize the nature of the
+# input/output arguments.
+DataItem = Any
+
+
+class AdviceCategory(Enum):
+ """Advice category.
+
+ Enumeration of advice categories supported by ML Inference Advisor.
+ """
+
+ OPERATORS = 1
+ PERFORMANCE = 2
+ OPTIMIZATION = 3
+ ALL = 4
+
+ @classmethod
+ def from_string(cls, value: str) -> "AdviceCategory":
+ """Resolve enum value from string value."""
+ category_names = [item.name for item in AdviceCategory]
+ if not value or value.upper() not in category_names:
+ raise Exception(f"Invalid advice category {value}")
+
+ return AdviceCategory[value.upper()]
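+
+ # Example: AdviceCategory.from_string("performance") returns
+ # AdviceCategory.PERFORMANCE; empty or unknown values raise an Exception.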
+
+
+class NamedEntity(ABC):
+ """Entity with a name and description."""
+
+ @classmethod
+ @abstractmethod
+ def name(cls) -> str:
+ """Return name of the entity."""
diff --git a/src/mlia/core/context.py b/src/mlia/core/context.py
new file mode 100644
index 0000000..8b3dd2c
--- /dev/null
+++ b/src/mlia/core/context.py
@@ -0,0 +1,218 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Context module.
+
+This module contains functionality related to the Context.
+Context is an object that describes advisor working environment
+and requested behavior (advice categories, input configuration
+parameters).
+"""
+import logging
+from abc import ABC
+from abc import abstractmethod
+from pathlib import Path
+from typing import Any
+from typing import List
+from typing import Mapping
+from typing import Optional
+from typing import Union
+
+from mlia.core.common import AdviceCategory
+from mlia.core.events import DefaultEventPublisher
+from mlia.core.events import EventHandler
+from mlia.core.events import EventPublisher
+from mlia.core.helpers import ActionResolver
+from mlia.core.helpers import APIActionResolver
+
+logger = logging.getLogger(__name__)
+
+
+class Context(ABC):
+ """Abstract class for the execution context."""
+
+ @abstractmethod
+ def get_model_path(self, model_filename: str) -> Path:
+ """Return path for the intermediate/optimized models.
+
+ During workflow execution, different parts of the advisor
+ require creating intermediate files for models.
+
+ This method provides paths where those models
+ can be saved.
+
+ :param model_filename: filename of the model
+ """
+
+ @property
+ @abstractmethod
+ def event_publisher(self) -> EventPublisher:
+ """Return event publisher."""
+
+ @property
+ @abstractmethod
+ def event_handlers(self) -> Optional[List[EventHandler]]:
+ """Return list of the event_handlers."""
+
+ @property
+ @abstractmethod
+ def advice_category(self) -> Optional[AdviceCategory]:
+ """Return advice category."""
+
+ @property
+ @abstractmethod
+ def config_parameters(self) -> Optional[Mapping[str, Any]]:
+ """Return configuration parameters."""
+
+ @property
+ @abstractmethod
+ def action_resolver(self) -> ActionResolver:
+ """Return action resolver."""
+
+ @abstractmethod
+ def update(
+ self,
+ *,
+ advice_category: AdviceCategory,
+ event_handlers: List[EventHandler],
+ config_parameters: Mapping[str, Any],
+ ) -> None:
+ """Update context parameters."""
+
+ def category_enabled(self, category: AdviceCategory) -> bool:
+ """Check if category enabled."""
+ return category == self.advice_category
+
+ def any_category_enabled(self, *categories: AdviceCategory) -> bool:
+ """Return true if any category is enabled."""
+ return self.advice_category in categories
+
+ def register_event_handlers(self) -> None:
+ """Register event handlers."""
+ self.event_publisher.register_event_handlers(self.event_handlers)
+
+
+class ExecutionContext(Context):
+ """Execution context."""
+
+ def __init__(
+ self,
+ *,
+ advice_category: Optional[AdviceCategory] = None,
+ config_parameters: Optional[Mapping[str, Any]] = None,
+ working_dir: Optional[Union[str, Path]] = None,
+ event_handlers: Optional[List[EventHandler]] = None,
+ event_publisher: Optional[EventPublisher] = None,
+ verbose: bool = False,
+ logs_dir: str = "logs",
+ models_dir: str = "models",
+ action_resolver: Optional[ActionResolver] = None,
+ ) -> None:
+ """Init execution context.
+
+ :param advice_category: requested advice category
+ :param config_parameters: dictionary-like object with input parameters
+ :param working_dir: path to the directory that will be used as a place
+ to store temporary files, logs and models. If not provided, the
+ current working directory is used instead
+ :param event_handlers: optional list of event handlers
+ :param event_publisher: optional event publisher instance. If not
+ provided, the default implementation of the event publisher is used
+ :param verbose: enable verbose output
+ :param logs_dir: name of the directory inside working directory where
+ log files will be stored
+ :param models_dir: name of the directory inside working directory where
+ temporary models will be stored
+ :param action_resolver: instance of the action resolver that can make
+ the advice actionable
+ """
+ self._advice_category = advice_category
+ self._config_parameters = config_parameters
+
+ self._working_dir_path = Path.cwd()
+ if working_dir:
+ self._working_dir_path = Path(working_dir)
+ self._working_dir_path.mkdir(exist_ok=True)
+
+ self._event_handlers = event_handlers
+ self._event_publisher = event_publisher or DefaultEventPublisher()
+ self.verbose = verbose
+ self.logs_dir = logs_dir
+ self.models_dir = models_dir
+ self._action_resolver = action_resolver or APIActionResolver()
+
+ @property
+ def advice_category(self) -> Optional[AdviceCategory]:
+ """Return advice category."""
+ return self._advice_category
+
+ @advice_category.setter
+ def advice_category(self, advice_category: AdviceCategory) -> None:
+ """Setter for the advice category."""
+ self._advice_category = advice_category
+
+ @property
+ def config_parameters(self) -> Optional[Mapping[str, Any]]:
+ """Return configuration parameters."""
+ return self._config_parameters
+
+ @config_parameters.setter
+ def config_parameters(self, config_parameters: Optional[Mapping[str, Any]]) -> None:
+ """Setter for the configuration parameters."""
+ self._config_parameters = config_parameters
+
+ @property
+ def event_handlers(self) -> Optional[List[EventHandler]]:
+ """Return list of the event handlers."""
+ return self._event_handlers
+
+ @event_handlers.setter
+ def event_handlers(self, event_handlers: List[EventHandler]) -> None:
+ """Setter for the event handlers."""
+ self._event_handlers = event_handlers
+
+ @property
+ def event_publisher(self) -> EventPublisher:
+ """Return event publisher."""
+ return self._event_publisher
+
+ @property
+ def action_resolver(self) -> ActionResolver:
+ """Return action resolver."""
+ return self._action_resolver
+
+ def get_model_path(self, model_filename: str) -> Path:
+ """Return path for the model."""
+ models_dir_path = self._working_dir_path / self.models_dir
+ models_dir_path.mkdir(exist_ok=True)
+
+ return models_dir_path / model_filename
+
+ @property
+ def logs_path(self) -> Path:
+ """Return path to the logs directory."""
+ return self._working_dir_path / self.logs_dir
+
+ def update(
+ self,
+ *,
+ advice_category: AdviceCategory,
+ event_handlers: List[EventHandler],
+ config_parameters: Mapping[str, Any],
+ ) -> None:
+ """Update context parameters."""
+ self._advice_category = advice_category
+ self._event_handlers = event_handlers
+ self._config_parameters = config_parameters
+
+ def __str__(self) -> str:
+ """Return string representation."""
+ category = (
+ "<not set>" if self.advice_category is None else self.advice_category.name
+ )
+
+ return (
+ f"ExecutionContext: working_dir={self._working_dir_path}, "
+ f"advice_category={category}, "
+ f"config_parameters={self.config_parameters}, "
+ f"verbose={self.verbose}"
+ )
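+
+
+# Example usage (a sketch; the paths are illustrative):
+#
+#     ctx = ExecutionContext(advice_category=AdviceCategory.PERFORMANCE,
+#                            working_dir="mlia_output")
+#     ctx.get_model_path("optimized.tflite")
+#     # -> Path("mlia_output/models/optimized.tflite"); the models
+#     # directory is created on demand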
diff --git a/src/mlia/core/data_analysis.py b/src/mlia/core/data_analysis.py
new file mode 100644
index 0000000..6adb41e
--- /dev/null
+++ b/src/mlia/core/data_analysis.py
@@ -0,0 +1,70 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Module for data analysis."""
+from abc import ABC
+from abc import abstractmethod
+from dataclasses import dataclass
+from typing import List
+
+from mlia.core.common import DataItem
+from mlia.core.mixins import ContextMixin
+
+
+class DataAnalyzer(ABC):
+ """Base class for the data analysis.
+
+ The purpose of this class is to extract, out of the collected data,
+ valuable information that can be used for advice generation.
+
+ This process consists of two steps:
+ - analyze every item of the collected data
+ - get analyzed data
+ """
+
+ @abstractmethod
+ def analyze_data(self, data_item: DataItem) -> None:
+ """Analyze data.
+
+ :param data_item: item of the collected data
+ """
+
+ @abstractmethod
+ def get_analyzed_data(self) -> List[DataItem]:
+ """Get analyzed data."""
+
+
+class ContextAwareDataAnalyzer(DataAnalyzer, ContextMixin):
+ """Context aware data analyzer.
+
+ This class makes it easier to access the Context object. The Context
+ object can be injected automatically during workflow configuration.
+ """
+
+
+@dataclass
+class Fact:
+ """Base class for the facts.
+
+ Fact represents some piece of knowledge about collected
+ data.
+ """
+
+
+class FactExtractor(ContextAwareDataAnalyzer):
+ """Data analyzer based on extracting facts.
+
+ Utility class that makes fact extraction easier.
+ It maintains a list of the extracted facts.
+ """
+
+ def __init__(self) -> None:
+ """Init fact extractor."""
+ self.facts: List[Fact] = []
+
+ def get_analyzed_data(self) -> List[DataItem]:
+ """Return list of the collected facts."""
+ return self.facts
+
+ def add_fact(self, fact: Fact) -> None:
+ """Add fact."""
+ self.facts.append(fact)
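+
+
+# Example analyzer (a sketch; the fact type and threshold are illustrative):
+#
+#     @dataclass
+#     class ModelIsLarge(Fact):
+#         size: int
+#
+#     class MyAnalyzer(FactExtractor):
+#         def analyze_data(self, data_item: DataItem) -> None:
+#             if data_item.get("size", 0) > 10_000_000:
+#                 self.add_fact(ModelIsLarge(data_item["size"]))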
diff --git a/src/mlia/core/data_collection.py b/src/mlia/core/data_collection.py
new file mode 100644
index 0000000..43b6d1c
--- /dev/null
+++ b/src/mlia/core/data_collection.py
@@ -0,0 +1,37 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Module for data collection.
+
+This module contains base classes for the first stage
+of the ML Inference Advisor workflow - data collection.
+"""
+from abc import abstractmethod
+
+from mlia.core.common import DataItem
+from mlia.core.common import NamedEntity
+from mlia.core.mixins import ContextMixin
+from mlia.core.mixins import ParameterResolverMixin
+
+
+class DataCollector(NamedEntity):
+ """Base class for the data collection.
+
+ Data collection is the first step in the process of the advice
+ generation.
+
+ Different implementations of this class can provide various
+ information about model or device. This information is being used
+ at later stages.
+ """
+
+ @abstractmethod
+ def collect_data(self) -> DataItem:
+ """Collect data."""
+
+
+class ContextAwareDataCollector(DataCollector, ContextMixin, ParameterResolverMixin):
+ """Context aware data collector.
+
+ This class makes it easier to access the Context object. The Context
+ object can be injected automatically during workflow configuration.
+ """
diff --git a/src/mlia/core/errors.py b/src/mlia/core/errors.py
new file mode 100644
index 0000000..7d6beb1
--- /dev/null
+++ b/src/mlia/core/errors.py
@@ -0,0 +1,18 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""MLIA exceptions module."""
+
+
+class ConfigurationError(Exception):
+ """Configuration error."""
+
+
+class FunctionalityNotSupportedError(Exception):
+ """Functionality is not supported error."""
+
+ def __init__(self, reason: str, description: str) -> None:
+ """Init exception."""
+ super().__init__(f"{reason}: {description}")
+
+ self.reason = reason
+ self.description = description
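+
+
+# Example (a sketch): raising this error from a data collector makes the
+# workflow executor skip that collector and publish a
+# DataCollectorSkippedEvent instead of failing the whole run:
+#
+#     raise FunctionalityNotSupportedError(
+#         reason="Performance estimation is not supported",
+#         description="Backend is not installed",
+#     )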
diff --git a/src/mlia/core/events.py b/src/mlia/core/events.py
new file mode 100644
index 0000000..10aec86
--- /dev/null
+++ b/src/mlia/core/events.py
@@ -0,0 +1,455 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Module for the events and related functionality.
+
+This module represents one of the main components of the workflow -
+event publishing - and provides a way to deliver results to the
+calling application.
+
+Each component of the workflow can generate events of a specific type.
+The application can subscribe and react to those events.
+"""
+import traceback
+import uuid
+from abc import ABC
+from abc import abstractmethod
+from contextlib import contextmanager
+from dataclasses import asdict
+from dataclasses import dataclass
+from dataclasses import field
+from functools import singledispatchmethod
+from typing import Any
+from typing import Dict
+from typing import Generator
+from typing import List
+from typing import Optional
+from typing import Tuple
+
+from mlia.core.common import DataItem
+
+
+@dataclass
+class Event:
+ """Base class for the events.
+
+ This class is used as a root node of the events class hierarchy.
+ """
+
+ event_id: str = field(init=False)
+
+ def __post_init__(self) -> None:
+ """Generate unique ID for the event."""
+ self.event_id = str(uuid.uuid4())
+
+ def compare_without_id(self, other: "Event") -> bool:
+ """Compare two events without event_id field."""
+ if not isinstance(other, Event) or self.__class__ != other.__class__:
+ return False
+
+ self_as_dict = asdict(self)
+ self_as_dict.pop("event_id")
+
+ other_as_dict = asdict(other)
+ other_as_dict.pop("event_id")
+
+ return self_as_dict == other_as_dict
+
+
+@dataclass
+class ChildEvent(Event):
+ """Child event.
+
+ This class can be used to link an event with its parent event.
+ """
+
+ parent_event_id: str
+
+
+@dataclass
+class ActionStartedEvent(Event):
+ """Action started event.
+
+ This event is published when some action has been started.
+ """
+
+ action_type: str
+ params: Optional[Dict] = None
+
+
+@dataclass
+class SubActionEvent(ChildEvent):
+ """SubAction event.
+
+ This event can be used to represent an action performed within a parent action.
+ """
+
+ action_type: str
+ params: Optional[Dict] = None
+
+
+@dataclass
+class ActionFinishedEvent(ChildEvent):
+ """Action finished event.
+
+ This event is published when some action has been finished.
+ """
+
+
+@dataclass
+class SystemEvent(Event):
+ """System event.
+
+ The system event class represents events published by components
+ of the core module. The most common example is the workflow executor,
+ which publishes a number of system events on start/completion
+ of different stages and of the workflow itself.
+
+ Events published by components outside of the core module should not
+ use this class as a base class.
+ """
+
+
+@dataclass
+class ExecutionStartedEvent(SystemEvent):
+ """Execution started event.
+
+ This event is published when the workflow execution has started.
+ """
+
+
+@dataclass
+class ExecutionFinishedEvent(SystemEvent):
+ """Execution finished event.
+
+ This event is published when the workflow execution has finished.
+ """
+
+
+@dataclass
+class ExecutionFailedEvent(SystemEvent):
+ """Execution failed event."""
+
+ err: Exception
+
+
+@dataclass
+class DataCollectionStageStartedEvent(SystemEvent):
+ """Data collection stage started.
+
+ This event is published when the data collection stage has started.
+ """
+
+
+@dataclass
+class DataCollectorSkippedEvent(SystemEvent):
+ """Data collector skipped event.
+
+ This event is published when a particular data collector cannot
+ provide data for the given parameters.
+ """
+
+ data_collector: str
+ reason: str
+
+
+@dataclass
+class DataCollectionStageFinishedEvent(SystemEvent):
+ """Data collection stage finished.
+
+ This event is published when the data collection stage has finished.
+ """
+
+
+@dataclass
+class DataAnalysisStageStartedEvent(SystemEvent):
+ """Data analysis stage started.
+
+ This event is published when the data analysis stage has started.
+ """
+
+
+@dataclass
+class DataAnalysisStageFinishedEvent(SystemEvent):
+ """Data analysis stage finished.
+
+ This event is published when the data analysis stage has finished.
+ """
+
+
+@dataclass
+class AdviceStageStartedEvent(SystemEvent):
+ """Advace producing stage started.
+
+ This event is published when advice generation stage started.
+ """
+
+
+@dataclass
+class AdviceStageFinishedEvent(SystemEvent):
+ """Advace producing stage finished.
+
+ This event is published when advice generation stage finished.
+ """
+
+
+@dataclass
+class CollectedDataEvent(SystemEvent):
+ """Collected data event.
+
+ This event is published for every collected data item.
+
+ :param data_item: collected data item
+ """
+
+ data_item: DataItem
+
+
+@dataclass
+class AnalyzedDataEvent(SystemEvent):
+ """Analyzed data event.
+
+ This event is published for every analyzed data item.
+
+ :param data_item: analyzed data item
+ """
+
+ data_item: DataItem
+
+
+class EventHandler:
+ """Base class for the event handlers.
+
+ Each event handler should derive from this base class.
+ """
+
+ def handle_event(self, event: Event) -> None:
+ """Handle the event.
+
+ By default, all published events are passed to each
+ registered event handler. It is the handler's responsibility
+ to filter the events it is interested in.
+ """
+
+
+class DebugEventHandler(EventHandler):
+ """Event handler for debugging purposes.
+
+ This handler prints every published event to
+ standard output.
+ """
+
+ def __init__(self, with_stacktrace: bool = False) -> None:
+ """Init event handler.
+
+ :param with_stacktrace: enable printing the stack trace of the
+ place where the event was published.
+ """
+ self.with_stacktrace = with_stacktrace
+
+ def handle_event(self, event: Event) -> None:
+ """Handle event."""
+ print(f"Got event {event}")
+
+ if self.with_stacktrace:
+ traceback.print_stack()
+
+
+class EventDispatcherMetaclass(type):
+ """Metaclass for event dispatching.
+
+ It can be tedious to check the type of a published event
+ inside an event handler. Instead, the following convention can be
+ established: if a method name of the class starts with some
+ prefix then it is considered to be an event handler for events
+ of a particular type.
+
+ This metaclass goes through the list of class methods and
+ links all methods with the prefix "on_" to the common dispatcher
+ method.
+ """
+
+ def __new__(
+ cls,
+ clsname: str,
+ bases: Tuple,
+ namespace: Dict[str, Any],
+ event_handler_method_prefix: str = "on_",
+ ) -> Any:
+ """Create event dispatcher and link event handlers."""
+ new_class = super().__new__(cls, clsname, bases, namespace)
+
+ @singledispatchmethod
+ def dispatcher(_self: Any, _event: Event) -> Any:
+ """Event dispatcher."""
+
+ # get all class methods that start with the particular prefix
+ event_handler_methods = (
+ (item_name, item)
+ for item_name in dir(new_class)
+ if callable((item := getattr(new_class, item_name)))
+ and item_name.startswith(event_handler_method_prefix)
+ )
+
+ # link all collected event handlers to one dispatcher method
+ for method_name, method in event_handler_methods:
+ event_handler = dispatcher.register(method)
+ setattr(new_class, method_name, event_handler)
+
+ # override default handle_event method, replace it with the
+ # dispatcher
+ setattr(new_class, "handle_event", dispatcher)
+
+ return new_class
+
+
+class EventDispatcher(EventHandler, metaclass=EventDispatcherMetaclass):
+ """Event dispatcher."""
+
+
+class EventPublisher(ABC):
+ """Base class for the event publisher.
+
+ The event publisher is an intermediate component between event
+ emitters and event consumers.
+ """
+
+ @abstractmethod
+ def register_event_handler(self, event_handler: EventHandler) -> None:
+ """Register event handler.
+
+ :param event_handler: instance of the event handler
+ """
+
+ def register_event_handlers(
+ self, event_handlers: Optional[List[EventHandler]]
+ ) -> None:
+ """Register event handlers.
+
+ Can be used for batch registration of the event handlers.
+
+ :param event_handlers: list of the event handler instances
+ """
+ if not event_handlers:
+ return
+
+ for handler in event_handlers:
+ self.register_event_handler(handler)
+
+ @abstractmethod
+ def publish_event(self, event: Event) -> None:
+ """Publish the event.
+
+ Deliver the event to all registered event handlers.
+
+ :param event: event instance
+ """
+
+
+class DefaultEventPublisher(EventPublisher):
+ """Default event publishing implementation.
+
+ Simple implementation that maintains a list of the registered
+ event handlers.
+ """
+
+ def __init__(self) -> None:
+ """Init the event publisher."""
+ self.handlers: List[EventHandler] = []
+
+ def register_event_handler(self, event_handler: EventHandler) -> None:
+ """Register the event handler.
+
+ :param event_handler: instance of the event handler
+ """
+ self.handlers.append(event_handler)
+
+ def publish_event(self, event: Event) -> None:
+ """Publish the event.
+
+ Publisher does not catch exceptions that could be raised by event handlers.
+ """
+ for handler in self.handlers:
+ handler.handle_event(event)
+
+
+@contextmanager
+def stage(
+ publisher: EventPublisher, events: Tuple[Event, Event]
+) -> Generator[None, None, None]:
+ """Generate events before and after stage.
+
+ This context manager can be used to mark the start/finish of
+ execution of a particular logical part of the workflow.
+ """
+ started, finished = events
+
+ publisher.publish_event(started)
+ yield
+ publisher.publish_event(finished)
+
+
+@contextmanager
+def action(
+ publisher: EventPublisher, action_type: str, params: Optional[Dict] = None
+) -> Generator[None, None, None]:
+ """Generate events before and after action."""
+ action_started = ActionStartedEvent(action_type, params)
+ action_finished = ActionFinishedEvent(action_started.event_id)
+
+ publisher.publish_event(action_started)
+ yield
+ publisher.publish_event(action_finished)
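+
+# Example (a sketch; the action type and params are illustrative):
+#
+#     with action(publisher, "compile_model", {"model": "model.tflite"}):
+#         ...  # ActionStartedEvent before, ActionFinishedEvent after
+#
+# The finished event carries the id of the started event as its
+# parent_event_id.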
+
+
+class SystemEventsHandler(EventDispatcher):
+ """System events handler."""
+
+ def on_execution_started(self, event: ExecutionStartedEvent) -> None:
+ """Handle ExecutionStarted event."""
+
+ def on_execution_finished(self, event: ExecutionFinishedEvent) -> None:
+ """Handle ExecutionFinished event."""
+
+ def on_execution_failed(self, event: ExecutionFailedEvent) -> None:
+ """Handle ExecutionFailed event."""
+
+ def on_data_collection_stage_started(
+ self, event: DataCollectionStageStartedEvent
+ ) -> None:
+ """Handle DataCollectionStageStarted event."""
+
+ def on_data_collection_stage_finished(
+ self, event: DataCollectionStageFinishedEvent
+ ) -> None:
+ """Handle DataCollectionStageFinished event."""
+
+ def on_data_collector_skipped(self, event: DataCollectorSkippedEvent) -> None:
+ """Handle DataCollectorSkipped event."""
+
+ def on_data_analysis_stage_started(
+ self, event: DataAnalysisStageStartedEvent
+ ) -> None:
+ """Handle DataAnalysisStageStartedEvent event."""
+
+ def on_data_analysis_stage_finished(
+ self, event: DataAnalysisStageFinishedEvent
+ ) -> None:
+ """Handle DataAnalysisStageFinishedEvent event."""
+
+ def on_advice_stage_started(self, event: AdviceStageStartedEvent) -> None:
+ """Handle AdviceStageStarted event."""
+
+ def on_advice_stage_finished(self, event: AdviceStageFinishedEvent) -> None:
+ """Handle AdviceStageFinished event."""
+
+ def on_collected_data(self, event: CollectedDataEvent) -> None:
+ """Handle CollectedData event."""
+
+ def on_analyzed_data(self, event: AnalyzedDataEvent) -> None:
+ """Handle AnalyzedData event."""
+
+ def on_action_started(self, event: ActionStartedEvent) -> None:
+ """Handle ActionStarted event."""
+
+ def on_action_finished(self, event: ActionFinishedEvent) -> None:
+ """Handle ActionFinished event."""
diff --git a/src/mlia/core/helpers.py b/src/mlia/core/helpers.py
new file mode 100644
index 0000000..d10ea5d
--- /dev/null
+++ b/src/mlia/core/helpers.py
@@ -0,0 +1,38 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Module for various helper classes."""
+# pylint: disable=no-self-use, unused-argument
+from typing import Any
+from typing import List
+
+
+class ActionResolver:
+ """Helper class for generating actions (e.g. commands with parameters)."""
+
+ def apply_optimizations(self, **kwargs: Any) -> List[str]:
+ """Return action details for applying optimizations."""
+ return []
+
+ def supported_operators_info(self) -> List[str]:
+ """Return action details for generating supported ops report."""
+ return []
+
+ def check_performance(self) -> List[str]:
+ """Return action details for checking performance."""
+ return []
+
+ def check_operator_compatibility(self) -> List[str]:
+ """Return action details for checking op compatibility."""
+ return []
+
+ def operator_compatibility_details(self) -> List[str]:
+ """Return action details for getting more information about op compatibility."""
+ return []
+
+ def optimization_details(self) -> List[str]:
+ """Return action detail for getting information about optimizations."""
+ return []
+
+
+class APIActionResolver(ActionResolver):
+ """Helper class for the actions performed through API."""
diff --git a/src/mlia/core/mixins.py b/src/mlia/core/mixins.py
new file mode 100644
index 0000000..ee03100
--- /dev/null
+++ b/src/mlia/core/mixins.py
@@ -0,0 +1,54 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Mixins module."""
+from typing import Any
+from typing import Optional
+
+from mlia.core.context import Context
+
+
+class ContextMixin:
+ """Mixin for injecting context object."""
+
+ context: Context
+
+ def set_context(self, context: Context) -> None:
+ """Context setter."""
+ self.context = context
+
+
+class ParameterResolverMixin:
+ """Mixin for parameter resolving."""
+
+ context: Context
+
+ def get_parameter(
+ self,
+ section: str,
+ name: str,
+ expected: bool = True,
+ expected_type: Optional[type] = None,
+ context: Optional[Context] = None,
+ ) -> Any:
+ """Get parameter value."""
+ ctx = context or self.context
+
+ if ctx.config_parameters is None:
+ raise Exception("Configuration parameters are not set")
+
+ section_params = ctx.config_parameters.get(section)
+ if section_params is None or not isinstance(section_params, dict):
+ raise Exception(
+ f"Parameter section {section} has wrong format, "
+ "expected to be a dictionary"
+ )
+
+ value = section_params.get(name)
+
+ if not value and expected:
+ raise Exception(f"Parameter {name} is not set")
+
+ if value and expected_type is not None and not isinstance(value, expected_type):
+ raise Exception(f"Parameter {name} expected to have type {expected_type}")
+
+ return value
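+
+
+# Example (a sketch; the section and parameter names are illustrative):
+# with context.config_parameters == {"common": {"model": "model.tflite"}},
+# get_parameter("common", "model") returns "model.tflite"; a missing (or
+# falsy) parameter raises an Exception when expected=True.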
diff --git a/src/mlia/core/performance.py b/src/mlia/core/performance.py
new file mode 100644
index 0000000..5433d5c
--- /dev/null
+++ b/src/mlia/core/performance.py
@@ -0,0 +1,47 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Module for performance estimation."""
+from abc import abstractmethod
+from typing import Callable
+from typing import Generic
+from typing import List
+from typing import TypeVar
+
+
+ModelType = TypeVar("ModelType") # pylint: disable=invalid-name
+PerfMetricsType = TypeVar("PerfMetricsType") # pylint: disable=invalid-name
+
+
+class PerformanceEstimator(Generic[ModelType, PerfMetricsType]):
+ """Base class for the performance estimation."""
+
+ @abstractmethod
+ def estimate(self, model: ModelType) -> PerfMetricsType:
+ """Estimate performance."""
+
+
+def estimate_performance(
+ original_model: ModelType,
+ estimator: PerformanceEstimator[ModelType, PerfMetricsType],
+ model_transformations: List[Callable[[ModelType], ModelType]],
+) -> List[PerfMetricsType]:
+ """Estimate performance impact.
+
+ This function estimates the impact on model performance of
+ applying the provided transformations/optimizations.
+
+ :param original_model: object that represents a model; could be
+ an instance of the model or a path to the model, depending on the
+ provided performance estimator
+ :param estimator: performance estimator
+ :param model_transformations: list of callables, each of which
+ returns an object that represents the optimized model
+ """
+ original_metrics = estimator.estimate(original_model)
+
+ optimized_metrics = [
+ estimator.estimate(transform(original_model))
+ for transform in model_transformations
+ ]
+
+ return [original_metrics, *optimized_metrics]
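+
+
+# Example (a sketch with a dummy estimator; real estimators are provided
+# per backend elsewhere in the codebase):
+#
+#     class SizeEstimator(PerformanceEstimator[str, int]):
+#         def estimate(self, model: str) -> int:
+#             return len(model)
+#
+#     estimate_performance("model", SizeEstimator(), [str.upper])
+#     # -> [5, 5]: metrics for the original model first, then one entry
+#     #    per transformation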
diff --git a/src/mlia/core/reporting.py b/src/mlia/core/reporting.py
new file mode 100644
index 0000000..1b75bb4
--- /dev/null
+++ b/src/mlia/core/reporting.py
@@ -0,0 +1,762 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Reporting module."""
+import csv
+import json
+import logging
+from abc import ABC
+from abc import abstractmethod
+from collections import defaultdict
+from contextlib import contextmanager
+from contextlib import ExitStack
+from dataclasses import dataclass
+from functools import partial
+from io import TextIOWrapper
+from pathlib import Path
+from textwrap import fill
+from textwrap import indent
+from typing import Any
+from typing import Callable
+from typing import cast
+from typing import Collection
+from typing import Dict
+from typing import Generator
+from typing import Iterable
+from typing import List
+from typing import Optional
+from typing import Tuple
+from typing import Union
+
+import numpy as np
+
+from mlia.core._typing import FileLike
+from mlia.core._typing import OutputFormat
+from mlia.core._typing import PathOrFileLike
+from mlia.utils.console import apply_style
+from mlia.utils.console import produce_table
+from mlia.utils.logging import LoggerWriter
+from mlia.utils.types import is_list_of
+
+logger = logging.getLogger(__name__)
+
+
+class Report(ABC):
+ """Abstract class for the report."""
+
+ @abstractmethod
+ def to_json(self, **kwargs: Any) -> Any:
+ """Convert to json serializible format."""
+
+ @abstractmethod
+ def to_csv(self, **kwargs: Any) -> List[Any]:
+ """Convert to csv serializible format."""
+
+ @abstractmethod
+ def to_plain_text(self, **kwargs: Any) -> str:
+ """Convert to human readable format."""
+
+
+class ReportItem:
+ """Item of the report."""
+
+ def __init__(
+ self,
+ name: str,
+ alias: Optional[str] = None,
+ value: Optional[Union[str, int, "Cell"]] = None,
+ nested_items: Optional[List["ReportItem"]] = None,
+ ) -> None:
+ """Init the report item."""
+ self.name = name
+ self.alias = alias
+ self.value = value
+ self.nested_items = nested_items or []
+
+ @property
+ def compound(self) -> bool:
+ """Return true if item has nested items."""
+ return self.nested_items is not None and len(self.nested_items) > 0
+
+ @property
+ def raw_value(self) -> Any:
+ """Get actual item value."""
+ val = self.value
+ if isinstance(val, Cell):
+ return val.value
+
+ return val
+
+
+@dataclass
+class Format:
+ """Column or cell format.
+
+ Format could be applied either to a column or an individual cell.
+
+ :param wrap_width: width of the wrapped text value
+ :param str_fmt: string format to be applied to the value
+ :param style: text style
+ """
+
+ wrap_width: Optional[int] = None
+ str_fmt: Optional[Union[str, Callable[[Any], str]]] = None
+ style: Optional[str] = None
+
+
+@dataclass
+class Cell:
+ """Cell definition.
+
+ This is a wrapper class for a particular value in the table. It can be
+ used to apply a specific format to this value.
+ """
+
+ value: Any
+ fmt: Optional[Format] = None
+
+ def _apply_style(self, value: str) -> str:
+ """Apply style to the value."""
+ if self.fmt and self.fmt.style:
+ value = apply_style(value, self.fmt.style)
+
+ return value
+
+ def _get_value(self) -> str:
+ """Return cell value."""
+ if self.fmt:
+ if isinstance(self.fmt.str_fmt, str):
+ return "{:{fmt}}".format(self.value, fmt=self.fmt.str_fmt)
+
+ if callable(self.fmt.str_fmt):
+ return self.fmt.str_fmt(self.value)
+
+ return str(self.value)
+
+ def __str__(self) -> str:
+ """Return string representation."""
+ val = self._get_value()
+ return self._apply_style(val)
+
+ def to_csv(self) -> Any:
+ """Cell definition for csv."""
+ return self.value
+
+ def to_json(self) -> Any:
+ """Cell definition for json."""
+ return self.value
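+
+ # Example (a sketch): str(Cell(0.1234, Format(str_fmt=".2f"))) renders
+ # as "0.12"; a callable str_fmt or a console style can be applied in
+ # the same way.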
+
+
+class CountAwareCell(Cell):
+ """Count aware cell."""
+
+ def __init__(
+ self,
+ value: Optional[Union[int, float]],
+ singular: str,
+ plural: str,
+ format_string: str = ",d",
+ ):
+ """Init cell instance."""
+ self.unit = singular if value == 1 else plural
+
+ def format_value(val: Optional[Union[int, float]]) -> str:
+ """Provide string representation for the value."""
+ if val is None:
+ return ""
+
+ if val == 1:
+ return f"1 {singular}"
+
+ return f"{val:{format_string}} {plural}"
+
+ super().__init__(value, Format(str_fmt=format_value))
+
+ def to_csv(self) -> Any:
+ """Cell definition for csv."""
+ return {"value": self.value, "unit": self.unit}
+
+ def to_json(self) -> Any:
+ """Cell definition for json."""
+ return {"value": self.value, "unit": self.unit}
+
+
+class BytesCell(CountAwareCell):
+ """Cell that represents memory size."""
+
+ def __init__(self, value: Optional[int]) -> None:
+ """Init cell instance."""
+ super().__init__(value, "byte", "bytes")
+
+
+class CyclesCell(CountAwareCell):
+ """Cell that represents cycles."""
+
+ def __init__(self, value: Optional[Union[int, float]]) -> None:
+ """Init cell instance."""
+ super().__init__(value, "cycle", "cycles", ",.0f")
+
+
+class ClockCell(CountAwareCell):
+ """Cell that represents clock value."""
+
+ def __init__(self, value: Optional[Union[int, float]]) -> None:
+ """Init cell instance."""
+ super().__init__(value, "Hz", "Hz", ",.0f")
+
+
+class Column:
+ """Column definition."""
+
+ def __init__(
+ self,
+ header: str,
+ alias: Optional[str] = None,
+ fmt: Optional[Format] = None,
+ only_for: Optional[List[str]] = None,
+ ) -> None:
+ """Init column definition.
+
+ :param header: column's header
+ :param alias: column's alias; can be used as the column's name
+ :param fmt: format that will be applied for all column's values
+ :param only_for: list of the formats where this column should be
+ represented. May be used to differentiate data representation in
+ different formats
+ """
+ self.header = header
+ self.alias = alias
+ self.fmt = fmt
+ self.only_for = only_for
+
+ def supports_format(self, fmt: str) -> bool:
+ """Return true if column should be shown."""
+ return not self.only_for or fmt in self.only_for
+
+
+class NestedReport(Report):
+ """Report with nested items."""
+
+ def __init__(self, name: str, alias: str, items: List[ReportItem]) -> None:
+ """Init nested report."""
+ self.name = name
+ self.alias = alias
+ self.items = items
+
+ def to_csv(self, **kwargs: Any) -> List[Any]:
+ """Convert to csv serializible format."""
+ result = {}
+
+ def collect_item_values(
+ item: ReportItem,
+ _parent: Optional[ReportItem],
+ _prev: Optional[ReportItem],
+ _level: int,
+ ) -> None:
+ """Collect item values into a dictionary.."""
+ if item.value is None:
+ return
+
+ if not isinstance(item.value, Cell):
+ result[item.alias] = item.raw_value
+ return
+
+ csv_value = item.value.to_csv()
+ if isinstance(csv_value, dict):
+ csv_value = {
+ f"{item.alias}_{key}": value for key, value in csv_value.items()
+ }
+ else:
+ csv_value = {item.alias: csv_value}
+
+ result.update(csv_value)
+
+ self._traverse(self.items, collect_item_values)
+
+ # make list out of the result dictionary
+ # first element - keys of the dictionary as headers
+ # second element - list of the dictionary values
+ return list(zip(*result.items()))
+
+ def to_json(self, **kwargs: Any) -> Any:
+ """Convert to json serializible format."""
+ per_parent: Dict[Optional[ReportItem], Dict] = defaultdict(dict)
+ result = per_parent[None]
+
+ def collect_as_dicts(
+ item: ReportItem,
+ parent: Optional[ReportItem],
+ _prev: Optional[ReportItem],
+ _level: int,
+ ) -> None:
+ """Collect item values as nested dictionaries."""
+ parent_dict = per_parent[parent]
+
+ if item.compound:
+ item_dict = per_parent[item]
+ parent_dict[item.alias] = item_dict
+ else:
+ out_value = (
+ item.value.to_json()
+ if isinstance(item.value, Cell)
+ else item.raw_value
+ )
+ parent_dict[item.alias] = out_value
+
+ self._traverse(self.items, collect_as_dicts)
+
+ return {self.alias: result}
+
+ def to_plain_text(self, **kwargs: Any) -> str:
+ """Convert to human readable format."""
+ header = f"{self.name}:\n"
+ processed_items = []
+
+ def convert_to_text(
+ item: ReportItem,
+ _parent: Optional[ReportItem],
+ prev: Optional[ReportItem],
+ level: int,
+ ) -> None:
+ """Convert item to text representation."""
+ if level >= 1 and prev is not None and (item.compound or prev.compound):
+ processed_items.append("")
+
+ val = self._item_value(item, level)
+ processed_items.append(val)
+
+ self._traverse(self.items, convert_to_text)
+ body = "\n".join(processed_items)
+
+ return header + body
+
+ @staticmethod
+ def _item_value(
+ item: ReportItem, level: int, tab_size: int = 2, column_width: int = 35
+ ) -> str:
+ """Get report item value."""
+ shift = " " * tab_size * level
+ if item.value is None:
+ return f"{shift}{item.name}:"
+
+ col1 = f"{shift}{item.name}".ljust(column_width)
+ col2 = f"{item.value}".rjust(column_width)
+
+ return col1 + col2
+
+ def _traverse(
+ self,
+ items: List[ReportItem],
+ visit_item: Callable[
+ [ReportItem, Optional[ReportItem], Optional[ReportItem], int], None
+ ],
+ level: int = 1,
+ parent: Optional[ReportItem] = None,
+ ) -> None:
+ """Traverse through items."""
+ prev = None
+ for item in items:
+ visit_item(item, parent, prev, level)
+
+ self._traverse(item.nested_items, visit_item, level + 1, item)
+ prev = item
+
+
+class Table(Report):
+ """Table definition.
+
+ This class can be used to represent tabular data.
+ """
+
+ def __init__(
+ self,
+ columns: List[Column],
+ rows: Collection,
+ name: str,
+ alias: Optional[str] = None,
+ notes: Optional[str] = None,
+ ) -> None:
+ """Init table definition.
+
+ :param columns: list of the table's columns
+ :param rows: list of the table's rows
+ :param name: name of the table
+ :param alias: alias for the table
+ """
+ self.columns = columns
+ self.rows = rows
+ self.name = name
+ self.alias = alias
+ self.notes = notes
+
+ def to_json(self, **kwargs: Any) -> Iterable:
+ """Convert table to dict object."""
+
+ def item_to_json(item: Any) -> Any:
+ value = item
+ if isinstance(item, Cell):
+ value = item.value
+
+ if isinstance(value, Table):
+ return value.to_json()
+
+ return value
+
+ json_data = [
+ {
+ col.alias or col.header: item_to_json(item)
+ for (item, col) in zip(row, self.columns)
+ if col.supports_format("json")
+ }
+ for row in self.rows
+ ]
+
+ if not self.alias:
+ return json_data
+
+ return {self.alias: json_data}
+
+ def to_plain_text(self, **kwargs: Any) -> str:
+ """Produce report in human readable format."""
+ nested = kwargs.get("nested", False)
+ show_headers = kwargs.get("show_headers", True)
+ show_title = kwargs.get("show_title", True)
+ table_style = kwargs.get("table_style", "default")
+ space = kwargs.get("space", False)
+
+ headers = (
+ [] if (nested or not show_headers) else [c.header for c in self.columns]
+ )
+
+ def item_to_plain_text(item: Any, col: Column) -> str:
+ """Convert item to text."""
+ if isinstance(item, Table):
+ return item.to_plain_text(nested=True, **kwargs)
+
+ if is_list_of(item, str):
+ as_text = "\n".join(item)
+ else:
+ as_text = str(item)
+
+ if col.fmt and col.fmt.wrap_width:
+ as_text = fill(as_text, col.fmt.wrap_width)
+
+ return as_text
+
+ title = ""
+ if show_title and not nested:
+ title = f"{self.name}:\n"
+
+ if space in (True, "top"):
+ title = "\n" + title
+
+ footer = ""
+ if space in (True, "bottom"):
+ footer = "\n"
+ if self.notes:
+ footer = "\n" + self.notes
+
+ formatted_rows = (
+ (
+ item_to_plain_text(item, col)
+ for item, col in zip(row, self.columns)
+ if col.supports_format("plain_text")
+ )
+ for row in self.rows
+ )
+
+ if space == "between":
+ formatted_table = "\n\n".join(
+ produce_table([row], table_style=table_style) for row in formatted_rows
+ )
+ else:
+ formatted_table = produce_table(
+ formatted_rows,
+ headers=headers,
+ table_style="nested" if nested else table_style,
+ )
+
+ return title + formatted_table + footer
+
+ def to_csv(self, **kwargs: Any) -> List[Any]:
+ """Convert table to csv format."""
+ headers = [[c.header for c in self.columns if c.supports_format("csv")]]
+
+ def item_data(item: Any) -> Any:
+ if isinstance(item, Cell):
+ return item.value
+
+ if isinstance(item, Table):
+ return ";".join(
+ str(item_data(cell)) for row in item.rows for cell in row
+ )
+
+ return item
+
+ rows = [
+ [
+ item_data(item)
+ for (item, col) in zip(row, self.columns)
+ if col.supports_format("csv")
+ ]
+ for row in self.rows
+ ]
+
+ return headers + rows
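+
+ # Example table (a sketch; the data is illustrative):
+ #
+ #     table = Table(
+ #         columns=[Column("Name", alias="name"),
+ #                  Column("Size", alias="size")],
+ #         rows=[("weights", 1024)],
+ #         name="Model info",
+ #         alias="model_info",
+ #     )
+ #     table.to_json()
+ #     # -> {"model_info": [{"name": "weights", "size": 1024}]}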
+
+
+class SingleRow(Table):
+ """Table with a single row."""
+
+ def to_plain_text(self, **kwargs: Any) -> str:
+ """Produce report in human readable format."""
+ if len(self.rows) != 1:
+ raise Exception("Table should have only one row")
+
+ items = "\n".join(
+ column.header.ljust(35) + str(item).rjust(25)
+ for row in self.rows
+ for item, column in zip(row, self.columns)
+ if column.supports_format("plain_text")
+ )
+
+ return "\n".join([f"{self.name}:", indent(items, " ")])
+
+
+class CompoundReport(Report):
+ """Compound report.
+
+ This class can be used to produce multiple reports at once.
+ """
+
+ def __init__(self, reports: List[Report]) -> None:
+ """Init compound report instance."""
+ self.reports = reports
+
+ def to_json(self, **kwargs: Any) -> Any:
+ """Convert to json serializible format.
+
+ The method attempts to create a compound dictionary based on the
+ provided parts.
+ """
+ result: Dict[str, Any] = {}
+ for item in self.reports:
+ result.update(item.to_json(**kwargs))
+
+ return result
+
+ def to_csv(self, **kwargs: Any) -> List[Any]:
+ """Convert to csv serializible format.
+
+ CSV format does support only one table. In order to be able to export
+ multiply tables they should be merged before that. This method tries to
+ do next:
+
+ - if all tables have the same length then just concatenate them
+ - if one table has many rows and the others have just one (two rows
+ including headers), then for each row of the large table duplicate
+ the values from the single-row tables
+ """
+ csv_data = [item.to_csv() for item in self.reports]
+ lengths = [len(csv_item_data) for csv_item_data in csv_data]
+
+ same_length = len(set(lengths)) == 1
+ if same_length:
+ # all lists are of the same length, merge them into one
+ return [[cell for item in row for cell in item] for row in zip(*csv_data)]
+
+ main_obj_indexes = [i for i, item in enumerate(csv_data) if len(item) > 2]
+ one_main_obj = len(main_obj_indexes) == 1
+
+ reference_obj_indexes = [i for i, item in enumerate(csv_data) if len(item) == 2]
+ other_only_ref_objs = len(reference_obj_indexes) == len(csv_data) - 1
+
+ if one_main_obj and other_only_ref_objs:
+ main_obj = csv_data[main_obj_indexes[0]]
+ return [
+ item
+ + [
+ ref_item
+ for ref_table_index in reference_obj_indexes
+ for ref_item in csv_data[ref_table_index][0 if i == 0 else 1]
+ ]
+ for i, item in enumerate(main_obj)
+ ]
+
+ # write tables one after another if there are no other options
+ return [row for item in csv_data for row in item]
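+
+ # Worked example (a sketch): merging a header + 2 data rows table with
+ # a header + 1 data row table repeats the single row next to every
+ # data row:
+ #
+ #     [["a"], [1], [2]] and [["b"], [9]] -> [["a", "b"], [1, 9], [2, 9]]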
+
+ def to_plain_text(self, **kwargs: Any) -> str:
+ """Convert to human readable format."""
+ return "\n".join(item.to_plain_text(**kwargs) for item in self.reports)
+
+
+class CompoundFormatter:
+ """Compound data formatter."""
+
+ def __init__(self, formatters: List[Callable]) -> None:
+ """Init compound formatter."""
+ self.formatters = formatters
+
+ def __call__(self, data: Any) -> Report:
+ """Produce report."""
+ reports = [formatter(item) for item, formatter in zip(data, self.formatters)]
+ return CompoundReport(reports)
+
+
+class CustomJSONEncoder(json.JSONEncoder):
+ """Custom JSON encoder."""
+
+ def default(self, o: Any) -> Any:
+ """Support numpy types."""
+ if isinstance(o, np.integer):
+ return int(o)
+
+ if isinstance(o, np.floating):
+ return float(o)
+
+ return json.JSONEncoder.default(self, o)
+
+
+def json_reporter(report: Report, output: FileLike, **kwargs: Any) -> None:
+ """Produce report in json format."""
+ json_str = json.dumps(report.to_json(**kwargs), indent=4, cls=CustomJSONEncoder)
+ print(json_str, file=output)
+
+
+def text_reporter(report: Report, output: FileLike, **kwargs: Any) -> None:
+ """Produce report in text format."""
+ print(report.to_plain_text(**kwargs), file=output)
+
+
+def csv_reporter(report: Report, output: FileLike, **kwargs: Any) -> None:
+ """Produce report in csv format."""
+ csv_writer = csv.writer(output)
+ csv_writer.writerows(report.to_csv(**kwargs))
+
+
+def produce_report(
+ data: Any,
+ formatter: Callable[[Any], Report],
+ fmt: OutputFormat = "plain_text",
+ output: Optional[PathOrFileLike] = None,
+ **kwargs: Any,
+) -> None:
+ """Produce report based on provided data."""
+ # check if provided format value is supported
+ formats = {"json": json_reporter, "plain_text": text_reporter, "csv": csv_reporter}
+ if fmt not in formats:
+ raise Exception(f"Unknown format {fmt}")
+
+ if output is None:
+ output = cast(TextIOWrapper, LoggerWriter(logger, logging.INFO))
+
+ with ExitStack() as exit_stack:
+ if isinstance(output, (str, Path)):
+ # open file and add it to the ExitStack context manager
+ # in that case it will be automatically closed
+ stream = exit_stack.enter_context(open(output, "w", encoding="utf-8"))
+ else:
+ stream = cast(TextIOWrapper, output)
+
+ # convert data into serializable form
+ formatted_data = formatter(data)
+ # find handler for the format
+ format_handler = formats[fmt]
+ # produce report in requested format
+ format_handler(formatted_data, stream, **kwargs)
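+
+
+# Example (a sketch; the formatter is illustrative):
+#
+#     produce_report(
+#         {"answer": 42},
+#         formatter=lambda data: NestedReport(
+#             "Summary", "summary",
+#             [ReportItem("Answer", "answer", data["answer"])],
+#         ),
+#         fmt="json",
+#         output="report.json",
+#     )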
+
+
+class Reporter:
+ """Reporter class."""
+
+ def __init__(
+ self,
+ formatter_resolver: Callable[[Any], Callable[[Any], Report]],
+ output_format: OutputFormat = "plain_text",
+ print_as_submitted: bool = True,
+ ) -> None:
+ """Init reporter instance."""
+ self.formatter_resolver = formatter_resolver
+ self.output_format = output_format
+ self.print_as_submitted = print_as_submitted
+
+ self.data: List[Tuple[Any, Callable[[Any], Report]]] = []
+ self.delayed: List[Tuple[Any, Callable[[Any], Report]]] = []
+
+ def submit(self, data_item: Any, delay_print: bool = False, **kwargs: Any) -> None:
+ """Submit data for the report."""
+ if self.print_as_submitted and not delay_print:
+ produce_report(
+ data_item,
+ self.formatter_resolver(data_item),
+ fmt="plain_text",
+ **kwargs,
+ )
+
+ formatter = _apply_format_parameters(
+ self.formatter_resolver(data_item), self.output_format, **kwargs
+ )
+ self.data.append((data_item, formatter))
+
+ if delay_print:
+ self.delayed.append((data_item, formatter))
+
+ def print_delayed(self) -> None:
+ """Print delayed reports."""
+ if not self.delayed:
+ return
+
+ data, formatters = zip(*self.delayed)
+ produce_report(
+ data,
+ formatter=CompoundFormatter(formatters),
+ fmt="plain_text",
+ )
+ self.delayed = []
+
+ def generate_report(self, output: Optional[PathOrFileLike]) -> None:
+ """Generate report."""
+ already_printed = (
+ self.print_as_submitted
+ and self.output_format == "plain_text"
+ and output is None
+ )
+ if not self.data or already_printed:
+ return
+
+ data, formatters = zip(*self.data)
+ produce_report(
+ data,
+ formatter=CompoundFormatter(formatters),
+ fmt=self.output_format,
+ output=output,
+ )
+
+
+@contextmanager
+def get_reporter(
+ output_format: OutputFormat,
+ output: Optional[PathOrFileLike],
+ formatter_resolver: Callable[[Any], Callable[[Any], Report]],
+) -> Generator[Reporter, None, None]:
+ """Get reporter and generate report."""
+ reporter = Reporter(formatter_resolver, output_format)
+
+ yield reporter
+
+ reporter.generate_report(output)
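+
+
+# Example (a sketch; resolve_formatter stands for a user-provided resolver):
+#
+#     with get_reporter("json", "report.json", resolve_formatter) as reporter:
+#         reporter.submit(some_data)
+#     # plain text is printed as data is submitted; the JSON report is
+#     # written when the context manager exits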
+
+
+def _apply_format_parameters(
+ formatter: Callable[[Any], Report], output_format: OutputFormat, **kwargs: Any
+) -> Callable[[Any], Report]:
+ """Wrap report method."""
+
+ def wrapper(data: Any) -> Report:
+ report = formatter(data)
+ method_name = f"to_{output_format}"
+ method = getattr(report, method_name)
+ setattr(report, method_name, partial(method, **kwargs))
+
+ return report
+
+ return wrapper
diff --git a/src/mlia/core/workflow.py b/src/mlia/core/workflow.py
new file mode 100644
index 0000000..0245087
--- /dev/null
+++ b/src/mlia/core/workflow.py
@@ -0,0 +1,216 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Module for executors.
+
+This module contains the implementation of the workflow
+executors.
+"""
+import itertools
+from abc import ABC
+from abc import abstractmethod
+from functools import wraps
+from typing import Any
+from typing import Callable
+from typing import List
+from typing import Optional
+from typing import Sequence
+from typing import Tuple
+
+from mlia.core.advice_generation import Advice
+from mlia.core.advice_generation import AdviceEvent
+from mlia.core.advice_generation import AdviceProducer
+from mlia.core.common import DataItem
+from mlia.core.context import Context
+from mlia.core.data_analysis import DataAnalyzer
+from mlia.core.data_collection import DataCollector
+from mlia.core.errors import FunctionalityNotSupportedError
+from mlia.core.events import AdviceStageFinishedEvent
+from mlia.core.events import AdviceStageStartedEvent
+from mlia.core.events import AnalyzedDataEvent
+from mlia.core.events import CollectedDataEvent
+from mlia.core.events import DataAnalysisStageFinishedEvent
+from mlia.core.events import DataAnalysisStageStartedEvent
+from mlia.core.events import DataCollectionStageFinishedEvent
+from mlia.core.events import DataCollectionStageStartedEvent
+from mlia.core.events import DataCollectorSkippedEvent
+from mlia.core.events import Event
+from mlia.core.events import ExecutionFailedEvent
+from mlia.core.events import ExecutionFinishedEvent
+from mlia.core.events import ExecutionStartedEvent
+from mlia.core.events import stage
+from mlia.core.mixins import ContextMixin
+
+
+class WorkflowExecutor(ABC):
+ """Base workflow executor."""
+
+ @abstractmethod
+ def run(self) -> None:
+ """Run the module."""
+
+
+STAGE_COLLECTION = (
+ DataCollectionStageStartedEvent(),
+ DataCollectionStageFinishedEvent(),
+)
+STAGE_ANALYSIS = (DataAnalysisStageStartedEvent(), DataAnalysisStageFinishedEvent())
+STAGE_ADVICE = (AdviceStageStartedEvent(), AdviceStageFinishedEvent())
+
+
+def on_stage(stage_events: Tuple[Event, Event]) -> Callable:
+ """Mark start/finish of the stage with appropriate events."""
+
+ def wrapper(method: Callable) -> Callable:
+ """Wrap method."""
+
+ @wraps(method)
+ def publish_events(self: Any, *args: Any, **kwargs: Any) -> Any:
+ """Publish events before and after execution."""
+ with stage(self.context.event_publisher, stage_events):
+ return method(self, *args, **kwargs)
+
+ return publish_events
+
+ return wrapper
+
+
+class DefaultWorkflowExecutor(WorkflowExecutor):
+ """Default module executor.
+
+ This is a default implementation of the workflow executor.
+ All components are launched sequentually in the same process.
+ """
+
+ def __init__(
+ self,
+ context: Context,
+ collectors: Sequence[DataCollector],
+ analyzers: Sequence[DataAnalyzer],
+ producers: Sequence[AdviceProducer],
+ before_start_events: Optional[Sequence[Event]] = None,
+ ):
+ """Init default workflow executor.
+
+ :param context: Context instance
+ :param collectors: List of the data collectors
+ :param analyzers: List of the data analyzers
+ :param producers: List of the advice producers
+ :param before_start_events: Optional list of custom events that
+ should be published before the start of the workflow execution.
+ """
+ self.context = context
+ self.collectors = collectors
+ self.analyzers = analyzers
+ self.producers = producers
+ self.before_start_events = before_start_events
+
+ def run(self) -> None:
+ """Run the workflow."""
+ self.inject_context()
+ self.context.register_event_handlers()
+
+ try:
+ self.publish(ExecutionStartedEvent())
+
+ self.before_start()
+
+ collected_data = self.collect_data()
+ analyzed_data = self.analyze_data(collected_data)
+
+ self.produce_advice(analyzed_data)
+ except Exception as err: # pylint: disable=broad-except
+ self.publish(ExecutionFailedEvent(err))
+ else:
+ self.publish(ExecutionFinishedEvent())
+
+ def before_start(self) -> None:
+ """Run actions before start of the workflow execution."""
+ events = self.before_start_events or []
+ for event in events:
+ self.publish(event)
+
+ @on_stage(STAGE_COLLECTION)
+ def collect_data(self) -> List[DataItem]:
+ """Collect data.
+
+ Run each of data collector components and return list of
+ the collected data items.
+ """
+ collected_data = []
+ for collector in self.collectors:
+ try:
+ if (data_item := collector.collect_data()) is not None:
+ collected_data.append(data_item)
+ self.publish(CollectedDataEvent(data_item))
+ except FunctionalityNotSupportedError as err:
+ self.publish(DataCollectorSkippedEvent(collector.name(), str(err)))
+
+ return collected_data
+
+ @on_stage(STAGE_ANALYSIS)
+ def analyze_data(self, collected_data: List[DataItem]) -> List[DataItem]:
+ """Analyze data.
+
+ Pass each collected data item into each data analyzer and
+ return analyzed data.
+
+ :param collected_data: list of collected data items
+ """
+ analyzed_data = []
+ for analyzer in self.analyzers:
+ for item in collected_data:
+ analyzer.analyze_data(item)
+
+ for data_item in analyzer.get_analyzed_data():
+ analyzed_data.append(data_item)
+
+ self.publish(AnalyzedDataEvent(data_item))
+ return analyzed_data
+
+ @on_stage(STAGE_ADVICE)
+ def produce_advice(self, analyzed_data: List[DataItem]) -> None:
+ """Produce advice.
+
+ Pass each analyzed data item into each advice producer and
+ publish generated advice.
+
+ :param analyzed_data: list of analyzed data items
+ """
+ for producer in self.producers:
+ for data_item in analyzed_data:
+ producer.produce_advice(data_item)
+
+ advice = producer.get_advice()
+ if isinstance(advice, Advice):
+ advice = [advice]
+
+ for item in advice:
+ self.publish(AdviceEvent(item))
+
+ def inject_context(self) -> None:
+ """Inject context object into components.
+
+ Inject the context object into components that support context
+ injection.
+ """
+ context_aware_components = (
+ comp
+ for comp in itertools.chain(
+ self.collectors,
+ self.analyzers,
+ self.producers,
+ )
+ if isinstance(comp, ContextMixin)
+ )
+
+ for component in context_aware_components:
+ component.set_context(self.context)
+
+ def publish(self, event: Event) -> None:
+ """Publish event.
+
+ Helper method for event publishing.
+
+ :param event: event instance
+ """
+ self.context.event_publisher.publish_event(event)
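+
+
+# Illustrative wiring sketch (`ctx` and the component instances are assumed
+# to be configured elsewhere): run() executes the three stages in order,
+# collect_data() -> analyze_data() -> produce_advice(), publishing the
+# corresponding stage events around each of them.
+#
+#     executor = DefaultWorkflowExecutor(
+#         ctx,
+#         collectors=[my_collector],
+#         analyzers=[my_analyzer],
+#         producers=[my_producer],
+#     )
+#     executor.run()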
diff --git a/src/mlia/devices/__init__.py b/src/mlia/devices/__init__.py
new file mode 100644
index 0000000..d533f4a
--- /dev/null
+++ b/src/mlia/devices/__init__.py
@@ -0,0 +1,3 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Devices module."""
diff --git a/src/mlia/devices/config.py b/src/mlia/devices/config.py
new file mode 100644
index 0000000..7ab6b43
--- /dev/null
+++ b/src/mlia/devices/config.py
@@ -0,0 +1,11 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""IP configuration module."""
+
+
+class IPConfiguration: # pylint: disable=too-few-public-methods
+ """Base class for IP configuration."""
+
+ def __init__(self, target: str) -> None:
+ """Init IP configuration instance."""
+ self.target = target
diff --git a/src/mlia/devices/ethosu/__init__.py b/src/mlia/devices/ethosu/__init__.py
new file mode 100644
index 0000000..73925e1
--- /dev/null
+++ b/src/mlia/devices/ethosu/__init__.py
@@ -0,0 +1,3 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Ethos-U devices module."""
diff --git a/src/mlia/devices/ethosu/advice_generation.py b/src/mlia/devices/ethosu/advice_generation.py
new file mode 100644
index 0000000..7a818c9
--- /dev/null
+++ b/src/mlia/devices/ethosu/advice_generation.py
@@ -0,0 +1,209 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Ethos-U advice generation."""
+from functools import singledispatchmethod
+from typing import List
+from typing import Union
+
+from mlia.core.advice_generation import Advice
+from mlia.core.advice_generation import advice_category
+from mlia.core.advice_generation import ContextAwareAdviceProducer
+from mlia.core.advice_generation import FactBasedAdviceProducer
+from mlia.core.common import AdviceCategory
+from mlia.core.common import DataItem
+from mlia.devices.ethosu.data_analysis import AllOperatorsSupportedOnNPU
+from mlia.devices.ethosu.data_analysis import HasCPUOnlyOperators
+from mlia.devices.ethosu.data_analysis import HasUnsupportedOnNPUOperators
+from mlia.devices.ethosu.data_analysis import OptimizationResults
+from mlia.nn.tensorflow.optimizations.select import OptimizationSettings
+
+
+class EthosUAdviceProducer(FactBasedAdviceProducer):
+ """Ethos-U advice producer."""
+
+ @singledispatchmethod
+ def produce_advice(self, data_item: DataItem) -> None:
+ """Produce advice."""
+
+ @produce_advice.register
+ @advice_category(AdviceCategory.OPERATORS, AdviceCategory.ALL)
+ def handle_cpu_only_ops(self, data_item: HasCPUOnlyOperators) -> None:
+ """Advice for CPU only operators."""
+ cpu_only_ops = ",".join(sorted(set(data_item.cpu_only_ops)))
+ cpu_only_ops_num = len(data_item.cpu_only_ops)
+
+ self.add_advice(
+ [
+ f"You have at least {cpu_only_ops_num} "
+ f"operator{'s' if cpu_only_ops_num > 1 else ''} that is CPU "
+ f"only: {cpu_only_ops}.",
+ "Using operators that are supported by the NPU will "
+ "improve performance.",
+ ]
+ + self.context.action_resolver.supported_operators_info()
+ )
+
+ @produce_advice.register
+ @advice_category(AdviceCategory.OPERATORS, AdviceCategory.ALL)
+ def handle_unsupported_operators(
+ self, data_item: HasUnsupportedOnNPUOperators
+ ) -> None:
+ """Advice for the unsupported operators."""
+ self.add_advice(
+ [
+ f"You have {data_item.npu_unsupported_ratio*100:.0f}% of operators "
+ "that cannot be placed on the NPU.",
+ "For better performance, please review the reasons reported "
+ "in the table, and adjust the model accordingly "
+ "where possible.",
+ ]
+ )
+
+ @produce_advice.register
+ @advice_category(AdviceCategory.OPERATORS, AdviceCategory.ALL)
+ def handle_all_operators_supported(
+ self, _data_item: AllOperatorsSupportedOnNPU
+ ) -> None:
+ """Advice if all operators supported."""
+ self.add_advice(
+ [
+ "You don't have any unsupported operators, your model will "
+ "run completely on NPU."
+ ]
+ + self.context.action_resolver.check_performance()
+ )
+
+ @produce_advice.register
+ @advice_category(AdviceCategory.OPTIMIZATION, AdviceCategory.ALL)
+ def handle_optimization_results(self, data_item: OptimizationResults) -> None:
+ """Advice based on optimization results."""
+ if not data_item.diffs or len(data_item.diffs) != 1:
+ return
+
+ optim_details = data_item.diffs[0]
+ metrics = [
+ (metric_name, optim_details.opt_diffs[metric_key])
+ for (metric_name, metric_key) in (
+ ("DRAM used (KB)", "dram"),
+ ("SRAM used (KB)", "sram"),
+ ("On chip flash used (KB)", "on_chip_flash"),
+ ("Off chip flash used (KB)", "off_chip_flash"),
+ ("NPU total cycles", "npu_total_cycles"),
+ )
+ if metric_key in optim_details.opt_diffs
+ and not optim_details.opt_diffs[metric_key].same
+ ]
+
+ improved = [
+ f"- You have achieved {abs(metric_value.diff):.2f}% performance "
+ f"improvement in {metric_name}"
+ for metric_name, metric_value in metrics
+ if metric_value.improved
+ ]
+
+ degraded = [
+ f"- {metric_name} have degraded by {abs(metric_value.diff):.2f}%"
+ for metric_name, metric_value in metrics
+ if metric_value.degraded
+ ]
+
+ opts = ", ".join(str(s) for s in optim_details.opt_type)
+ messages = [f"With the selected optimization ({opts})", *improved, *degraded]
+
+ if improved:
+ if next_optimization_target := self.get_next_optimization_targets(
+ optim_details.opt_type
+ ):
+ next_optimization_target_as_str = " and/or ".join(
+ str(item) for item in next_optimization_target
+ )
+
+ messages.append(
+ "You can try to push the optimization target higher "
+ f"(e.g. {next_optimization_target_as_str}) "
+ "to check if those results can be further improved."
+ )
+ messages += self.context.action_resolver.apply_optimizations(
+ opt_settings=next_optimization_target
+ )
+
+ elif degraded:
+ messages.append(
+ "The performance seems to have degraded after "
+ "applying the selected optimizations, "
+ "try exploring different optimization types/targets."
+ )
+
+ self.add_advice(messages)
+
+ self.add_advice(
+ [
+ "The applied tooling techniques have an impact "
+ "on accuracy. Additional hyperparameter tuning may be required "
+ "after any optimization."
+ ]
+ )
+
+ @staticmethod
+ def get_next_optimization_targets(
+ opt_type: List[OptimizationSettings],
+ ) -> List[OptimizationSettings]:
+ """Get next optimization targets."""
+ next_targets = (item.next_target() for item in opt_type)
+
+ # keep only targets that moved in the expected direction
+ # (pruning target up, clustering target down)
+ valid_targets = [
+ next_
+ for next_, old in zip(next_targets, opt_type)
+ if (
+ old.optimization_type == "pruning"
+ and old.optimization_target < next_.optimization_target
+ )
+ or (
+ old.optimization_type == "clustering"
+ and old.optimization_target > next_.optimization_target
+ )
+ ]
+ return valid_targets
+
+
+class EthosUStaticAdviceProducer(ContextAwareAdviceProducer):
+ """Advice producer that not depends on input data."""
+
+ def produce_advice(self, data_item: DataItem) -> None:
+ """Do not process passed data items."""
+
+ def get_advice(self) -> Union[Advice, List[Advice]]:
+ """Return predefined advice based on category."""
+ if self.context.advice_category is None:
+ return []
+
+ advice_per_category = {
+ AdviceCategory.PERFORMANCE: [
+ Advice(
+ [
+ "You can improve the inference time by using only operators "
+ "that are supported by the NPU.",
+ ]
+ + self.context.action_resolver.check_operator_compatibility()
+ ),
+ Advice(
+ [
+ "Check if you can improve the performance by applying "
+ "tooling techniques to your model."
+ ]
+ + self.context.action_resolver.apply_optimizations()
+ ),
+ ],
+ AdviceCategory.OPTIMIZATION: [
+ Advice(
+ [
+ "For better performance, make sure that all the operators "
+ "of your final TFLite model are supported by the NPU.",
+ ]
+ + self.context.action_resolver.operator_compatibility_details()
+ )
+ ],
+ }
+
+ return advice_per_category.get(self.context.advice_category, [])
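+
+
+# Note on get_next_optimization_targets() above (illustrative values): a
+# pruning step is kept only if the target increased (e.g. 0.5 -> 0.6), and
+# a clustering step only if the target decreased (e.g. 32 -> 16 clusters);
+# steps that do not move in the expected direction are filtered out.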
diff --git a/src/mlia/devices/ethosu/advisor.py b/src/mlia/devices/ethosu/advisor.py
new file mode 100644
index 0000000..802826b
--- /dev/null
+++ b/src/mlia/devices/ethosu/advisor.py
@@ -0,0 +1,151 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Ethos-U MLIA module."""
+from pathlib import Path
+from typing import List
+from typing import Optional
+
+from mlia.core.advice_generation import AdviceProducer
+from mlia.core.advisor import InferenceAdvisor
+from mlia.core.common import AdviceCategory
+from mlia.core.context import Context
+from mlia.core.data_analysis import DataAnalyzer
+from mlia.core.data_collection import DataCollector
+from mlia.core.mixins import ParameterResolverMixin
+from mlia.core.workflow import DefaultWorkflowExecutor
+from mlia.core.workflow import WorkflowExecutor
+from mlia.devices.ethosu.advice_generation import EthosUAdviceProducer
+from mlia.devices.ethosu.advice_generation import EthosUStaticAdviceProducer
+from mlia.devices.ethosu.config import EthosUConfiguration
+from mlia.devices.ethosu.config import get_target
+from mlia.devices.ethosu.data_analysis import EthosUDataAnalyzer
+from mlia.devices.ethosu.data_collection import EthosUOperatorCompatibility
+from mlia.devices.ethosu.data_collection import EthosUOptimizationPerformance
+from mlia.devices.ethosu.data_collection import EthosUPerformance
+from mlia.devices.ethosu.events import EthosUAdvisorStartedEvent
+
+
+class EthosUInferenceAdvisor(InferenceAdvisor, ParameterResolverMixin):
+ """Ethos-U Inference Advisor."""
+
+ @classmethod
+ def name(cls) -> str:
+ """Return name of the advisor."""
+ return "ethos_u_inference_advisor"
+
+ def configure(self, context: Context) -> WorkflowExecutor:
+ """Configure advisor execution."""
+ model = self._get_model(context)
+ device = self._get_device(context)
+ backends = self._get_backends(context)
+
+ collectors = self._get_collectors(context, model, device, backends)
+ analyzers = self._get_analyzers()
+ producers = self._get_advice_producers()
+
+ return DefaultWorkflowExecutor(
+ context,
+ collectors,
+ analyzers,
+ producers,
+ before_start_events=[
+ EthosUAdvisorStartedEvent(device=device, model=model),
+ ],
+ )
+
+ def _get_collectors(
+ self,
+ context: Context,
+ model: Path,
+ device: EthosUConfiguration,
+ backends: Optional[List[str]],
+ ) -> List[DataCollector]:
+ """Get collectors."""
+ collectors: List[DataCollector] = []
+
+ if context.any_category_enabled(
+ AdviceCategory.OPERATORS,
+ AdviceCategory.ALL,
+ ):
+ collectors.append(EthosUOperatorCompatibility(model, device))
+
+ if context.category_enabled(AdviceCategory.PERFORMANCE):
+ collectors.append(EthosUPerformance(model, device, backends))
+
+ if context.any_category_enabled(
+ AdviceCategory.OPTIMIZATION,
+ AdviceCategory.ALL,
+ ):
+ optimization_settings = self._get_optimization_settings(context)
+ collectors.append(
+ EthosUOptimizationPerformance(
+ model, device, optimization_settings, backends
+ )
+ )
+
+ return collectors
+
+ @staticmethod
+ def _get_analyzers() -> List[DataAnalyzer]:
+ """Return data analyzers."""
+ return [
+ EthosUDataAnalyzer(),
+ ]
+
+ @staticmethod
+ def _get_advice_producers() -> List[AdviceProducer]:
+ """Return advice producers."""
+ return [
+ EthosUAdviceProducer(),
+ EthosUStaticAdviceProducer(),
+ ]
+
+ def _get_device(self, context: Context) -> EthosUConfiguration:
+ """Get device."""
+ device_params = self.get_parameter(
+ self.name(),
+ "device",
+ expected_type=dict,
+ context=context,
+ )
+
+ try:
+ target_profile = device_params["target_profile"]
+ except KeyError as err:
+ raise Exception("Unable to get device details") from err
+
+ return get_target(target_profile)
+
+ def _get_model(self, context: Context) -> Path:
+ """Get path to the model."""
+ model_param = self.get_parameter(
+ self.name(),
+ "model",
+ expected_type=str,
+ context=context,
+ )
+
+ if not (model := Path(model_param)).exists():
+ raise Exception(f"Path {model} does not exist")
+
+ return model
+
+ def _get_optimization_settings(self, context: Context) -> List[List[dict]]:
+ """Get optimization settings."""
+ return self.get_parameter( # type: ignore
+ EthosUOptimizationPerformance.name(),
+ "optimizations",
+ expected_type=list,
+ expected=False,
+ context=context,
+ )
+
+ def _get_backends(self, context: Context) -> Optional[List[str]]:
+ """Get list of backends."""
+ return self.get_parameter( # type: ignore
+ self.name(),
+ "backends",
+ expected_type=list,
+ expected=False,
+ context=context,
+ )
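+
+
+# Expected parameter layout (sketch; the values shown, including the
+# "ethos-u55-256" profile name, are examples rather than guaranteed keys):
+#
+#     {
+#         "ethos_u_inference_advisor": {
+#             "model": "path/to/model.h5",
+#             "device": {"target_profile": "ethos-u55-256"},
+#             "backends": ["Vela"],
+#         },
+#         "ethos_u_model_optimizations": {
+#             "optimizations": [
+#                 [{"optimization_type": "pruning", "optimization_target": 0.5}]
+#             ],
+#         },
+#     }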
diff --git a/src/mlia/devices/ethosu/config.py b/src/mlia/devices/ethosu/config.py
new file mode 100644
index 0000000..cecbb27
--- /dev/null
+++ b/src/mlia/devices/ethosu/config.py
@@ -0,0 +1,89 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Ethos-U configuration."""
+import logging
+from typing import Any
+from typing import Dict
+
+from mlia.devices.config import IPConfiguration
+from mlia.tools.vela_wrapper import resolve_compiler_config
+from mlia.tools.vela_wrapper import VelaCompilerOptions
+from mlia.utils.filesystem import get_profile
+from mlia.utils.filesystem import get_vela_config
+
+
+logger = logging.getLogger(__name__)
+
+
+class EthosUConfiguration(IPConfiguration):
+ """Ethos-U configuration."""
+
+ def __init__(self, target_profile: str) -> None:
+ """Init Ethos-U target configuration."""
+ target_data = get_profile(target_profile)
+ _check_target_data_complete(target_data)
+
+ target = target_data["target"]
+ super().__init__(target)
+
+ mac = target_data["mac"]
+ _check_device_options_valid(target, mac)
+
+ self.mac = mac
+ self.compiler_options = VelaCompilerOptions(
+ system_config=target_data["system_config"],
+ memory_mode=target_data["memory_mode"],
+ config_files=str(get_vela_config()),
+ accelerator_config=f"{self.target}-{mac}", # type: ignore
+ )
+
+ @property
+ def resolved_compiler_config(self) -> Dict[str, Any]:
+ """Resolve compiler configuration."""
+ return resolve_compiler_config(self.compiler_options)
+
+ def __str__(self) -> str:
+ """Return string representation."""
+ return (
+ f"Ethos-U target={self.target} "
+ f"mac={self.mac} "
+ f"compiler_options={self.compiler_options}"
+ )
+
+ def __repr__(self) -> str:
+ """Return string representation."""
+ return f"<Ethos-U configuration target={self.target}>"
+
+
+def get_target(target_profile: str) -> EthosUConfiguration:
+ """Get target instance based on provided params."""
+ if not target_profile:
+ raise Exception("No target profile given")
+
+ return EthosUConfiguration(target_profile)
+
+
+def _check_target_data_complete(target_data: Dict[str, Any]) -> None:
+ """Check if profile contains all needed data."""
+ mandatory_keys = {"target", "mac", "system_config", "memory_mode"}
+ missing_keys = sorted(mandatory_keys - target_data.keys())
+
+ if missing_keys:
+ raise Exception(f"Mandatory fields missing from target profile: {missing_keys}")
+
+
+def _check_device_options_valid(target: str, mac: int) -> None:
+ """Check if mac is valid for selected device."""
+ target_mac_ranges = {
+ "ethos-u55": [32, 64, 128, 256],
+ "ethos-u65": [256, 512],
+ }
+
+ if target not in target_mac_ranges:
+ raise Exception(f"Unsupported target: {target}")
+
+ target_mac_range = target_mac_ranges[target]
+ if mac not in target_mac_range:
+ raise Exception(
+ f"Mac value for selected device should be in {target_mac_range}"
+ )
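+
+
+if __name__ == "__main__":
+    # Minimal sketch exercising the validation helpers above with made-up
+    # values; the invalid combination is expected to raise.
+    _check_device_options_valid("ethos-u55", 128)  # valid, returns silently
+    try:
+        _check_device_options_valid("ethos-u65", 128)  # 128 invalid for u65
+    except Exception as exc:  # pylint: disable=broad-except
+        print(exc)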
diff --git a/src/mlia/devices/ethosu/data_analysis.py b/src/mlia/devices/ethosu/data_analysis.py
new file mode 100644
index 0000000..9ed32ff
--- /dev/null
+++ b/src/mlia/devices/ethosu/data_analysis.py
@@ -0,0 +1,154 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Ethos-U data analysis module."""
+from dataclasses import dataclass
+from functools import singledispatchmethod
+from typing import Dict
+from typing import List
+from typing import Union
+
+from mlia.core.common import DataItem
+from mlia.core.data_analysis import Fact
+from mlia.core.data_analysis import FactExtractor
+from mlia.devices.ethosu.performance import OptimizationPerformanceMetrics
+from mlia.nn.tensorflow.optimizations.select import OptimizationSettings
+from mlia.tools.vela_wrapper import Operators
+
+
+@dataclass
+class HasCPUOnlyOperators(Fact):
+ """Model has CPU only operators."""
+
+ cpu_only_ops: List[str]
+
+
+@dataclass
+class HasUnsupportedOnNPUOperators(Fact):
+ """Model has unsupported on NPU operators."""
+
+ npu_unsupported_ratio: float
+
+
+@dataclass
+class AllOperatorsSupportedOnNPU(Fact):
+ """All model's operators supported on NPU."""
+
+
+@dataclass
+class PerfMetricDiff:
+ """Performance metric difference."""
+
+ original_value: Union[int, float]
+ optimized_value: Union[int, float]
+
+ @property
+ def diff(self) -> float:
+ """Difference between metrics."""
+ if self.original_value == 0:
+ return 0
+
+ return 100 - ((self.optimized_value / self.original_value) * 100)
+
+ @property
+ def improved(self) -> bool:
+ """Return true if metric improved."""
+ return self.diff > 0
+
+ @property
+ def degraded(self) -> bool:
+ """Return true if metric degraded."""
+ return self.diff < 0
+
+ @property
+ def same(self) -> bool:
+ """Return true if metric stays the same."""
+ return self.diff == 0
+
+
+@dataclass
+class OptimizationDiff:
+ """Optimization performance impact."""
+
+ opt_type: List[OptimizationSettings]
+ opt_diffs: Dict[str, PerfMetricDiff]
+
+
+@dataclass
+class OptimizationResults(Fact):
+ """Optimization results."""
+
+ diffs: List[OptimizationDiff]
+
+
+class EthosUDataAnalyzer(FactExtractor):
+ """Ethos-U data analyzer."""
+
+ @singledispatchmethod
+ def analyze_data(self, data_item: DataItem) -> None:
+ """Analyse the data."""
+
+ @analyze_data.register
+ def analyze_operator_compatibility(self, operators: Operators) -> None:
+ """Analyse operator compatibility information."""
+ cpu_only = [op.op_type for op in operators.ops if op.cpu_only]
+ if cpu_only:
+ self.add_fact(HasCPUOnlyOperators(cpu_only))
+
+ if operators.npu_unsupported_ratio != 0:
+ self.add_fact(HasUnsupportedOnNPUOperators(operators.npu_unsupported_ratio))
+
+ if operators.npu_unsupported_ratio == 0:
+ self.add_fact(AllOperatorsSupportedOnNPU())
+
+ @analyze_data.register
+ def analyze_optimization_results(
+ self, optimization_results: OptimizationPerformanceMetrics
+ ) -> None:
+ """Analyse optimization performance metrics."""
+ optimizations = optimization_results.optimizations_perf_metrics
+ if not optimizations:
+ return
+
+ orig = optimization_results.original_perf_metrics.in_kilobytes()
+ orig_memory = orig.memory_usage
+ orig_cycles = orig.npu_cycles
+
+ diffs: List[OptimizationDiff] = []
+ for opt_type, opt_perf_metrics in optimizations:
+ opt = opt_perf_metrics.in_kilobytes()
+ opt_memory = opt.memory_usage
+ opt_cycles = opt.npu_cycles
+
+ opt_diffs: Dict[str, PerfMetricDiff] = {}
+
+ if orig_memory and opt_memory:
+ opt_diffs.update(
+ {
+ "sram": PerfMetricDiff(
+ orig_memory.sram_memory_area_size,
+ opt_memory.sram_memory_area_size,
+ ),
+ "dram": PerfMetricDiff(
+ orig_memory.dram_memory_area_size,
+ opt_memory.dram_memory_area_size,
+ ),
+ "on_chip_flash": PerfMetricDiff(
+ orig_memory.on_chip_flash_memory_area_size,
+ opt_memory.on_chip_flash_memory_area_size,
+ ),
+ "off_chip_flash": PerfMetricDiff(
+ orig_memory.off_chip_flash_memory_area_size,
+ opt_memory.off_chip_flash_memory_area_size,
+ ),
+ }
+ )
+ if orig_cycles and opt_cycles:
+ opt_diffs["npu_total_cycles"] = PerfMetricDiff(
+ orig_cycles.npu_total_cycles,
+ opt_cycles.npu_total_cycles,
+ )
+
+ diff = OptimizationDiff(opt_type=opt_type, opt_diffs=opt_diffs)
+ diffs.append(diff)
+
+ self.add_fact(OptimizationResults(diffs))
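+
+
+if __name__ == "__main__":
+    # Minimal sketch of PerfMetricDiff semantics with made-up numbers: an
+    # optimized value of 80 against an original of 100 is a 20% improvement.
+    example = PerfMetricDiff(original_value=100, optimized_value=80)
+    assert example.diff == 20.0
+    assert example.improved and not example.degraded and not example.same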
diff --git a/src/mlia/devices/ethosu/data_collection.py b/src/mlia/devices/ethosu/data_collection.py
new file mode 100644
index 0000000..291f1b8
--- /dev/null
+++ b/src/mlia/devices/ethosu/data_collection.py
@@ -0,0 +1,188 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Data collection module for Ethos-U."""
+import logging
+from pathlib import Path
+from typing import List
+from typing import Optional
+
+from mlia.core.context import Context
+from mlia.core.data_collection import ContextAwareDataCollector
+from mlia.core.errors import FunctionalityNotSupportedError
+from mlia.core.performance import estimate_performance
+from mlia.devices.ethosu.config import EthosUConfiguration
+from mlia.devices.ethosu.performance import EthosUPerformanceEstimator
+from mlia.devices.ethosu.performance import OptimizationPerformanceMetrics
+from mlia.devices.ethosu.performance import PerformanceMetrics
+from mlia.nn.tensorflow.config import get_keras_model
+from mlia.nn.tensorflow.config import get_tflite_model
+from mlia.nn.tensorflow.config import KerasModel
+from mlia.nn.tensorflow.optimizations.select import get_optimizer
+from mlia.nn.tensorflow.optimizations.select import OptimizationSettings
+from mlia.nn.tensorflow.utils import save_keras_model
+from mlia.tools.vela_wrapper import Operators
+from mlia.tools.vela_wrapper import supported_operators
+from mlia.utils.types import is_list_of
+
+logger = logging.getLogger(__name__)
+
+
+class EthosUOperatorCompatibility(ContextAwareDataCollector):
+ """Collect operator compatibility information."""
+
+ def __init__(self, model: Path, device: EthosUConfiguration) -> None:
+ """Init operator compatibility data collector."""
+ self.model = model
+ self.device = device
+
+ def collect_data(self) -> Operators:
+ """Collect operator compatibility information."""
+ tflite_model = get_tflite_model(self.model, self.context)
+
+ logger.info("Checking operator compatibility ...")
+ ops = supported_operators(
+ Path(tflite_model.model_path), self.device.compiler_options
+ )
+ logger.info("Done\n")
+ return ops
+
+ @classmethod
+ def name(cls) -> str:
+ """Return name of the collector."""
+ return "ethos_u_operator_compatibility"
+
+
+class EthosUPerformance(ContextAwareDataCollector):
+ """Collect performance metrics."""
+
+ def __init__(
+ self,
+ model: Path,
+ device: EthosUConfiguration,
+ backends: Optional[List[str]] = None,
+ ) -> None:
+ """Init performance data collector."""
+ self.model = model
+ self.device = device
+ self.backends = backends
+
+ def collect_data(self) -> PerformanceMetrics:
+ """Collect model performance metrics."""
+ tflite_model = get_tflite_model(self.model, self.context)
+ estimator = EthosUPerformanceEstimator(
+ self.context,
+ self.device,
+ self.backends,
+ )
+
+ return estimator.estimate(tflite_model)
+
+ @classmethod
+ def name(cls) -> str:
+ """Return name of the collector."""
+ return "ethos_u_performance"
+
+
+class OptimizeModel:
+ """Helper class for model optimization."""
+
+ def __init__(
+ self, context: Context, opt_settings: List[OptimizationSettings]
+ ) -> None:
+ """Init helper."""
+ self.context = context
+ self.opt_settings = opt_settings
+
+ def __call__(self, keras_model: KerasModel) -> KerasModel:
+ """Run optimization."""
+ optimizer = get_optimizer(keras_model, self.opt_settings)
+
+ opts_as_str = ", ".join(str(opt) for opt in self.opt_settings)
+ logger.info("Applying model optimizations - [%s]", opts_as_str)
+ optimizer.apply_optimization()
+
+ model = optimizer.get_model()
+ model_path = self.context.get_model_path("optimized_model.h5")
+ save_keras_model(model, model_path)
+
+ return KerasModel(model_path)
+
+
+class EthosUOptimizationPerformance(ContextAwareDataCollector):
+ """Collect performance metrics for the optimizations."""
+
+ def __init__(
+ self,
+ model: Path,
+ device: EthosUConfiguration,
+ optimizations: List[List[dict]],
+ backends: Optional[List[str]] = None,
+ ) -> None:
+ """Init performance optimizations data collector."""
+ self.model = model
+ self.device = device
+ self.optimizations = optimizations
+ self.backends = backends
+
+ def collect_data(self) -> Optional[OptimizationPerformanceMetrics]:
+ """Collect performance metrics for the optimizations."""
+ logger.info("Estimate performance ...")
+
+ if not self.optimizations:
+ raise FunctionalityNotSupportedError(
+ reason="Unable to estimate model optimizations impact",
+ description="No optimization targets provided",
+ )
+
+ opt_settings = self._parse_optimization_params(self.optimizations)
+
+ try:
+ keras_model = get_keras_model(self.model, self.context)
+ except NotImplementedError as err:
+ raise FunctionalityNotSupportedError(
+ reason="Unable to run model optimizations",
+ description=f"{self.model} is not a Keras model and "
+ "could not be converted to a Keras model",
+ ) from err
+
+ optimizers = [OptimizeModel(self.context, opts) for opts in opt_settings]
+
+ estimator = EthosUPerformanceEstimator(
+ self.context,
+ self.device,
+ self.backends,
+ )
+ original_metrics, *optimized_metrics = estimate_performance(
+ keras_model, estimator, optimizers # type: ignore
+ )
+
+ result = OptimizationPerformanceMetrics(
+ original_perf_metrics=original_metrics,
+ optimizations_perf_metrics=list(zip(opt_settings, optimized_metrics)),
+ )
+ return result
+
+ @staticmethod
+ def _parse_optimization_params(
+ optimizations: List[List[dict]],
+ ) -> List[List[OptimizationSettings]]:
+ """Parse optimization parameters."""
+ if not is_list_of(optimizations, list):
+ raise Exception("Optimization parameters expected to be a list")
+
+ return [
+ [
+ OptimizationSettings(
+ item.get("optimization_type"), # type: ignore
+ item.get("optimization_target"), # type: ignore
+ item.get("layers_to_optimized"),
+ )
+ for item in opt_configuration
+ ]
+ for opt_configuration in optimizations
+ ]
+
+ @classmethod
+ def name(cls) -> str:
+ """Return name of the collector."""
+ return "ethos_u_model_optimizations"
diff --git a/src/mlia/devices/ethosu/events.py b/src/mlia/devices/ethosu/events.py
new file mode 100644
index 0000000..d5408b0
--- /dev/null
+++ b/src/mlia/devices/ethosu/events.py
@@ -0,0 +1,24 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Ethos-U MLIA module events."""
+from dataclasses import dataclass
+from pathlib import Path
+
+from mlia.core.events import Event
+from mlia.core.events import EventDispatcher
+from mlia.devices.ethosu.config import EthosUConfiguration
+
+
+@dataclass
+class EthosUAdvisorStartedEvent(Event):
+ """Event with Ethos-U advisor parameters."""
+
+ model: Path
+ device: EthosUConfiguration
+
+
+class EthosUAdvisorEventHandler(EventDispatcher):
+ """Event handler for the Ethos-U inference advisor."""
+
+ def on_ethos_u_advisor_started(self, event: EthosUAdvisorStartedEvent) -> None:
+ """Handle EthosUAdvisorStarted event."""
diff --git a/src/mlia/devices/ethosu/handlers.py b/src/mlia/devices/ethosu/handlers.py
new file mode 100644
index 0000000..7a0c31c
--- /dev/null
+++ b/src/mlia/devices/ethosu/handlers.py
@@ -0,0 +1,146 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Event handler."""
+import logging
+from pathlib import Path
+from typing import Dict
+from typing import List
+from typing import Optional
+
+from mlia.core._typing import OutputFormat
+from mlia.core._typing import PathOrFileLike
+from mlia.core.advice_generation import Advice
+from mlia.core.advice_generation import AdviceEvent
+from mlia.core.events import AdviceStageFinishedEvent
+from mlia.core.events import AdviceStageStartedEvent
+from mlia.core.events import CollectedDataEvent
+from mlia.core.events import DataAnalysisStageFinishedEvent
+from mlia.core.events import DataCollectionStageStartedEvent
+from mlia.core.events import DataCollectorSkippedEvent
+from mlia.core.events import ExecutionFailedEvent
+from mlia.core.events import ExecutionStartedEvent
+from mlia.core.events import SystemEventsHandler
+from mlia.core.reporting import Reporter
+from mlia.devices.ethosu.events import EthosUAdvisorEventHandler
+from mlia.devices.ethosu.events import EthosUAdvisorStartedEvent
+from mlia.devices.ethosu.performance import OptimizationPerformanceMetrics
+from mlia.devices.ethosu.performance import PerformanceMetrics
+from mlia.devices.ethosu.reporters import find_appropriate_formatter
+from mlia.tools.vela_wrapper import Operators
+from mlia.utils.console import create_section_header
+
+logger = logging.getLogger(__name__)
+
+ADV_EXECUTION_STARTED = create_section_header("ML Inference Advisor started")
+MODEL_ANALYSIS_MSG = create_section_header("Model Analysis")
+MODEL_ANALYSIS_RESULTS_MSG = create_section_header("Model Analysis Results")
+ADV_GENERATION_MSG = create_section_header("Advice Generation")
+REPORT_GENERATION_MSG = create_section_header("Report Generation")
+
+
+class WorkflowEventsHandler(SystemEventsHandler):
+ """Event handler for the system events."""
+
+ def on_execution_started(self, event: ExecutionStartedEvent) -> None:
+ """Handle ExecutionStarted event."""
+ logger.info(ADV_EXECUTION_STARTED)
+
+ def on_execution_failed(self, event: ExecutionFailedEvent) -> None:
+ """Handle ExecutionFailed event."""
+ raise event.err
+
+ def on_data_collection_stage_started(
+ self, event: DataCollectionStageStartedEvent
+ ) -> None:
+ """Handle DataCollectionStageStarted event."""
+ logger.info(MODEL_ANALYSIS_MSG)
+
+ def on_advice_stage_started(self, event: AdviceStageStartedEvent) -> None:
+ """Handle AdviceStageStarted event."""
+ logger.info(ADV_GENERATION_MSG)
+
+ def on_data_collector_skipped(self, event: DataCollectorSkippedEvent) -> None:
+ """Handle DataCollectorSkipped event."""
+ logger.info("Skipped: %s", event.reason)
+
+
+class EthosUEventHandler(WorkflowEventsHandler, EthosUAdvisorEventHandler):
+ """CLI event handler."""
+
+ def __init__(self, output: Optional[PathOrFileLike] = None) -> None:
+ """Init event handler."""
+ output_format = self.resolve_output_format(output)
+
+ self.reporter = Reporter(find_appropriate_formatter, output_format)
+ self.output = output
+ self.advice: List[Advice] = []
+
+ def on_advice_stage_finished(self, event: AdviceStageFinishedEvent) -> None:
+ """Handle AdviceStageFinishedEvent event."""
+ self.reporter.submit(
+ self.advice,
+ show_title=False,
+ show_headers=False,
+ space="between",
+ table_style="no_borders",
+ )
+
+ self.reporter.generate_report(self.output)
+
+ if self.output is not None:
+ logger.info(REPORT_GENERATION_MSG)
+ logger.info("Report(s) and advice list saved to: %s", self.output)
+
+ def on_data_analysis_stage_finished(
+ self, event: DataAnalysisStageFinishedEvent
+ ) -> None:
+ """Handle DataAnalysisStageFinished event."""
+ logger.info(MODEL_ANALYSIS_RESULTS_MSG)
+ self.reporter.print_delayed()
+
+ def on_collected_data(self, event: CollectedDataEvent) -> None:
+ """Handle CollectedDataEvent event."""
+ data_item = event.data_item
+
+ if isinstance(data_item, Operators):
+ self.reporter.submit([data_item.ops, data_item], delay_print=True)
+
+ if isinstance(data_item, PerformanceMetrics):
+ self.reporter.submit(data_item, delay_print=True)
+
+ if isinstance(data_item, OptimizationPerformanceMetrics):
+ original_metrics = data_item.original_perf_metrics
+ if not data_item.optimizations_perf_metrics:
+ return
+
+ _opt_settings, optimized_metrics = data_item.optimizations_perf_metrics[0]
+
+ self.reporter.submit(
+ [original_metrics, optimized_metrics],
+ delay_print=True,
+ columns_name="Metrics",
+ title="Performance metrics",
+ space=True,
+ )
+
+ def on_advice_event(self, event: AdviceEvent) -> None:
+ """Handle Advice event."""
+ self.advice.append(event.advice)
+
+ def on_ethos_u_advisor_started(self, event: EthosUAdvisorStartedEvent) -> None:
+ """Handle EthosUAdvisorStarted event."""
+ self.reporter.submit(event.device)
+
+ @staticmethod
+ def resolve_output_format(output: Optional[PathOrFileLike]) -> OutputFormat:
+ """Resolve output format based on the output name."""
+ output_format: OutputFormat = "plain_text"
+
+ if isinstance(output, str):
+ output_path = Path(output)
+ output_formats: Dict[str, OutputFormat] = {".csv": "csv", ".json": "json"}
+
+ if (suffix := output_path.suffix) in output_formats:
+ return output_formats[suffix]
+
+ return output_format
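+
+
+if __name__ == "__main__":
+    # Minimal sketch of the suffix-based output format resolution above.
+    assert EthosUEventHandler.resolve_output_format("report.json") == "json"
+    assert EthosUEventHandler.resolve_output_format("report.csv") == "csv"
+    assert EthosUEventHandler.resolve_output_format(None) == "plain_text"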
diff --git a/src/mlia/devices/ethosu/operators.py b/src/mlia/devices/ethosu/operators.py
new file mode 100644
index 0000000..ff0d99f
--- /dev/null
+++ b/src/mlia/devices/ethosu/operators.py
@@ -0,0 +1,14 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Operators module."""
+import logging
+
+from mlia.tools import vela_wrapper
+
+
+logger = logging.getLogger(__name__)
+
+
+def generate_supported_operators_report() -> None:
+ """Generate supported operators report."""
+ vela_wrapper.generate_supported_operators_report()
diff --git a/src/mlia/devices/ethosu/performance.py b/src/mlia/devices/ethosu/performance.py
new file mode 100644
index 0000000..b0718a5
--- /dev/null
+++ b/src/mlia/devices/ethosu/performance.py
@@ -0,0 +1,257 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Performance estimation."""
+import logging
+from dataclasses import dataclass
+from enum import Enum
+from pathlib import Path
+from typing import List
+from typing import Optional
+from typing import Tuple
+from typing import Union
+
+import mlia.tools.aiet_wrapper as aiet
+import mlia.tools.vela_wrapper as vela
+from mlia.core.context import Context
+from mlia.core.performance import PerformanceEstimator
+from mlia.devices.ethosu.config import EthosUConfiguration
+from mlia.nn.tensorflow.config import get_tflite_model
+from mlia.nn.tensorflow.config import ModelConfiguration
+from mlia.nn.tensorflow.optimizations.select import OptimizationSettings
+
+
+logger = logging.getLogger(__name__)
+
+
+@dataclass
+class NPUCycles:
+ """NPU cycles metrics."""
+
+ npu_active_cycles: int
+ npu_idle_cycles: int
+ npu_total_cycles: int
+ npu_axi0_rd_data_beat_received: int
+ npu_axi0_wr_data_beat_written: int
+ npu_axi1_rd_data_beat_received: int
+
+
+BYTES_PER_KILOBYTE = 1024
+
+
+class MemorySizeType(Enum):
+ """Memory size type enumeration."""
+
+ BYTES = 0
+ KILOBYTES = 1
+
+
+@dataclass
+class MemoryUsage:
+ """Memory usage metrics."""
+
+ sram_memory_area_size: Union[int, float]
+ dram_memory_area_size: Union[int, float]
+ unknown_memory_area_size: Union[int, float]
+ on_chip_flash_memory_area_size: Union[int, float]
+ off_chip_flash_memory_area_size: Union[int, float]
+ memory_size_type: MemorySizeType = MemorySizeType.BYTES
+
+ _default_columns = [
+ "SRAM used",
+ "DRAM used",
+ "Unknown memory used",
+ "On chip flash used",
+ "Off chip flash used",
+ ]
+
+ def in_kilobytes(self) -> "MemoryUsage":
+ """Return memory usage with values in kilobytes."""
+ if self.memory_size_type == MemorySizeType.KILOBYTES:
+ return self
+
+ kilobytes = [
+ value / BYTES_PER_KILOBYTE
+ for value in [
+ self.sram_memory_area_size,
+ self.dram_memory_area_size,
+ self.unknown_memory_area_size,
+ self.on_chip_flash_memory_area_size,
+ self.off_chip_flash_memory_area_size,
+ ]
+ ]
+
+ return MemoryUsage(
+ *kilobytes, # type: ignore
+ memory_size_type=MemorySizeType.KILOBYTES,
+ )
+
+
+@dataclass
+class PerformanceMetrics:
+ """Performance metrics."""
+
+ device: EthosUConfiguration
+ npu_cycles: Optional[NPUCycles]
+ memory_usage: Optional[MemoryUsage]
+
+ def in_kilobytes(self) -> "PerformanceMetrics":
+ """Return metrics with memory usage in KiB."""
+ if self.memory_usage is None:
+ return PerformanceMetrics(self.device, self.npu_cycles, self.memory_usage)
+
+ return PerformanceMetrics(
+ self.device, self.npu_cycles, self.memory_usage.in_kilobytes()
+ )
+
+
+@dataclass
+class OptimizationPerformanceMetrics:
+ """Optimization performance metrics."""
+
+ original_perf_metrics: PerformanceMetrics
+ optimizations_perf_metrics: List[
+ Tuple[List[OptimizationSettings], PerformanceMetrics]
+ ]
+
+
+class VelaPerformanceEstimator(
+ PerformanceEstimator[Union[Path, ModelConfiguration], MemoryUsage]
+):
+ """Vela based performance estimator."""
+
+ def __init__(self, context: Context, device: EthosUConfiguration) -> None:
+ """Init Vela based performance estimator."""
+ self.context = context
+ self.device = device
+
+ def estimate(self, model: Union[Path, ModelConfiguration]) -> MemoryUsage:
+ """Estimate performance."""
+ logger.info("Getting the memory usage metrics ...")
+
+ model_path = (
+ Path(model.model_path) if isinstance(model, ModelConfiguration) else model
+ )
+
+ vela_perf_metrics = vela.estimate_performance(
+ model_path, self.device.compiler_options
+ )
+
+ memory_usage = MemoryUsage(
+ vela_perf_metrics.sram_memory_area_size,
+ vela_perf_metrics.dram_memory_area_size,
+ vela_perf_metrics.unknown_memory_area_size,
+ vela_perf_metrics.on_chip_flash_memory_area_size,
+ vela_perf_metrics.off_chip_flash_memory_area_size,
+ )
+ logger.info("Done\n")
+ return memory_usage
+
+
+class AIETPerformanceEstimator(
+ PerformanceEstimator[Union[Path, ModelConfiguration], NPUCycles]
+):
+ """AIET based performance estimator."""
+
+ def __init__(
+ self, context: Context, device: EthosUConfiguration, backend: str
+ ) -> None:
+ """Init AIET based performance estimator."""
+ self.context = context
+ self.device = device
+ self.backend = backend
+
+ def estimate(self, model: Union[Path, ModelConfiguration]) -> NPUCycles:
+ """Estimate performance."""
+ logger.info("Getting the performance metrics for '%s' ...", self.backend)
+ logger.info(
+ "WARNING: This task may require several minutes (press ctrl-c to interrupt)"
+ )
+
+ model_path = (
+ Path(model.model_path) if isinstance(model, ModelConfiguration) else model
+ )
+
+ optimized_model_path = self.context.get_model_path(
+ f"{model_path.stem}_vela.tflite"
+ )
+
+ vela.optimize_model(
+ model_path, self.device.compiler_options, optimized_model_path
+ )
+
+ model_info = aiet.ModelInfo(model_path=optimized_model_path)
+ device_info = aiet.DeviceInfo(
+ device_type=self.device.target, # type: ignore
+ mac=self.device.mac,
+ memory_mode=self.device.compiler_options.memory_mode, # type: ignore
+ )
+
+ aiet_perf_metrics = aiet.estimate_performance(
+ model_info, device_info, self.backend
+ )
+
+ npu_cycles = NPUCycles(
+ aiet_perf_metrics.npu_active_cycles,
+ aiet_perf_metrics.npu_idle_cycles,
+ aiet_perf_metrics.npu_total_cycles,
+ aiet_perf_metrics.npu_axi0_rd_data_beat_received,
+ aiet_perf_metrics.npu_axi0_wr_data_beat_written,
+ aiet_perf_metrics.npu_axi1_rd_data_beat_received,
+ )
+
+ logger.info("Done\n")
+ return npu_cycles
+
+
+class EthosUPerformanceEstimator(
+ PerformanceEstimator[Union[Path, ModelConfiguration], PerformanceMetrics]
+):
+ """Ethos-U performance estimator."""
+
+ def __init__(
+ self,
+ context: Context,
+ device: EthosUConfiguration,
+ backends: Optional[List[str]] = None,
+ ) -> None:
+ """Init performance estimator."""
+ self.context = context
+ self.device = device
+ if backends is None:
+ backends = ["Vela"] # Only Vela is always available as default
+ for backend in backends:
+ if backend != "Vela" and not aiet.is_supported(backend):
+ raise ValueError(
+ f"Unsupported backend '{backend}'. "
+ f"Only 'Vela' and {aiet.supported_backends()} are supported."
+ )
+ self.backends = set(backends)
+
+ def estimate(self, model: Union[Path, ModelConfiguration]) -> PerformanceMetrics:
+ """Estimate performance."""
+ model_path = (
+ Path(model.model_path) if isinstance(model, ModelConfiguration) else model
+ )
+
+ tflite_model = get_tflite_model(model_path, self.context)
+
+ memory_usage = None
+ npu_cycles = None
+
+ for backend in self.backends:
+ if backend == "Vela":
+ vela_estimator = VelaPerformanceEstimator(self.context, self.device)
+ memory_usage = vela_estimator.estimate(tflite_model)
+ elif backend in aiet.supported_backends():
+ aiet_estimator = AIETPerformanceEstimator(
+ self.context, self.device, backend
+ )
+ npu_cycles = aiet_estimator.estimate(tflite_model)
+ else:
+ logger.warning(
+ "Backend '%s' is not supported for Ethos-U performance "
+ "estimation.",
+ backend,
+ )
+
+ return PerformanceMetrics(self.device, npu_cycles, memory_usage)
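+
+
+if __name__ == "__main__":
+    # Minimal sketch of the bytes-to-kilobytes conversion above: 1 MiB of
+    # SRAM (a made-up value) is reported as 1024 KiB.
+    usage = MemoryUsage(1024 * 1024, 0, 0, 0, 0)
+    assert usage.in_kilobytes().sram_memory_area_size == 1024
+    assert usage.in_kilobytes().memory_size_type == MemorySizeType.KILOBYTES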
diff --git a/src/mlia/devices/ethosu/reporters.py b/src/mlia/devices/ethosu/reporters.py
new file mode 100644
index 0000000..d28c68f
--- /dev/null
+++ b/src/mlia/devices/ethosu/reporters.py
@@ -0,0 +1,398 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Reports module."""
+from collections import defaultdict
+from typing import Any
+from typing import Callable
+from typing import List
+from typing import Tuple
+from typing import Union
+
+from mlia.core.advice_generation import Advice
+from mlia.core.reporting import BytesCell
+from mlia.core.reporting import Cell
+from mlia.core.reporting import ClockCell
+from mlia.core.reporting import Column
+from mlia.core.reporting import CompoundFormatter
+from mlia.core.reporting import CyclesCell
+from mlia.core.reporting import Format
+from mlia.core.reporting import NestedReport
+from mlia.core.reporting import Report
+from mlia.core.reporting import ReportItem
+from mlia.core.reporting import SingleRow
+from mlia.core.reporting import Table
+from mlia.devices.ethosu.config import EthosUConfiguration
+from mlia.devices.ethosu.performance import PerformanceMetrics
+from mlia.tools.vela_wrapper import Operator
+from mlia.tools.vela_wrapper import Operators
+from mlia.utils.console import style_improvement
+from mlia.utils.types import is_list_of
+
+
+def report_operators_stat(operators: Operators) -> Report:
+ """Return table representation for the ops stats."""
+ columns = [
+ Column("Number of operators", alias="num_of_operators"),
+ Column("Number of NPU supported operators", "num_of_npu_supported_operators"),
+ Column("Unsupported ops ratio", "npu_unsupported_ratio"),
+ ]
+ rows = [
+ (
+ operators.total_number,
+ operators.npu_supported_number,
+ Cell(
+ operators.npu_unsupported_ratio * 100,
+ fmt=Format(str_fmt="{0:.0f}%".format),
+ ),
+ )
+ ]
+
+ return SingleRow(
+ columns, rows, name="Operators statistics", alias="operators_stats"
+ )
+
+
+def report_operators(ops: List[Operator]) -> Report:
+ """Return table representation for the list of operators."""
+ columns = [
+ Column("#", only_for=["plain_text"]),
+ Column(
+ "Operator name",
+ alias="operator_name",
+ fmt=Format(wrap_width=30),
+ ),
+ Column(
+ "Operator type",
+ alias="operator_type",
+ fmt=Format(wrap_width=25),
+ ),
+ Column(
+ "Placement",
+ alias="placement",
+ fmt=Format(wrap_width=20),
+ ),
+ Column(
+ "Notes",
+ alias="notes",
+ fmt=Format(wrap_width=35),
+ ),
+ ]
+
+ rows = [
+ (
+ i + 1,
+ op.name,
+ op.op_type,
+ Cell(
+ "NPU" if (npu := op.run_on_npu.supported) else "CPU",
+ Format(style=style_improvement(npu)),
+ ),
+ Table(
+ columns=[
+ Column(
+ "Note",
+ alias="note",
+ fmt=Format(wrap_width=35),
+ )
+ ],
+ rows=[
+ (Cell(item, Format(str_fmt=lambda x: f"* {x}")),)
+ for reason in op.run_on_npu.reasons
+ for item in reason
+ if item
+ ],
+ name="Notes",
+ ),
+ )
+ for i, op in enumerate(ops)
+ ]
+
+ return Table(columns, rows, name="Operators", alias="operators")
+
+
+def report_device_details(device: EthosUConfiguration) -> Report:
+ """Return table representation for the device."""
+ compiler_config = device.resolved_compiler_config
+
+ memory_settings = [
+ ReportItem(
+ "Const mem area",
+ "const_mem_area",
+ compiler_config["const_mem_area"],
+ ),
+ ReportItem(
+ "Arena mem area",
+ "arena_mem_area",
+ compiler_config["arena_mem_area"],
+ ),
+ ReportItem(
+ "Cache mem area",
+ "cache_mem_area",
+ compiler_config["cache_mem_area"],
+ ),
+ ReportItem(
+ "Arena cache size",
+ "arena_cache_size",
+ BytesCell(compiler_config["arena_cache_size"]),
+ ),
+ ]
+
+ mem_areas_settings = [
+ ReportItem(
+ f"{mem_area_name}",
+ mem_area_name,
+ None,
+ nested_items=[
+ ReportItem(
+ "Clock scales",
+ "clock_scales",
+ mem_area_settings["clock_scales"],
+ ),
+ ReportItem(
+ "Burst length",
+ "burst_length",
+ BytesCell(mem_area_settings["burst_length"]),
+ ),
+ ReportItem(
+ "Read latency",
+ "read_latency",
+ CyclesCell(mem_area_settings["read_latency"]),
+ ),
+ ReportItem(
+ "Write latency",
+ "write_latency",
+ CyclesCell(mem_area_settings["write_latency"]),
+ ),
+ ],
+ )
+ for mem_area_name, mem_area_settings in compiler_config["memory_area"].items()
+ ]
+
+ system_settings = [
+ ReportItem(
+ "Accelerator clock",
+ "accelerator_clock",
+ ClockCell(compiler_config["core_clock"]),
+ ),
+ ReportItem(
+ "AXI0 port",
+ "axi0_port",
+ compiler_config["axi0_port"],
+ ),
+ ReportItem(
+ "AXI1 port",
+ "axi1_port",
+ compiler_config["axi1_port"],
+ ),
+ ReportItem(
+ "Memory area settings", "memory_area", None, nested_items=mem_areas_settings
+ ),
+ ]
+
+ arch_settings = [
+ ReportItem(
+ "Permanent storage mem area",
+ "permanent_storage_mem_area",
+ compiler_config["permanent_storage_mem_area"],
+ ),
+ ReportItem(
+ "Feature map storage mem area",
+ "feature_map_storage_mem_area",
+ compiler_config["feature_map_storage_mem_area"],
+ ),
+ ReportItem(
+ "Fast storage mem area",
+ "fast_storage_mem_area",
+ compiler_config["fast_storage_mem_area"],
+ ),
+ ]
+
+ return NestedReport(
+ "Device information",
+ "device",
+ [
+ ReportItem("Target", alias="target", value=device.target),
+ ReportItem("MAC", alias="mac", value=device.mac),
+ ReportItem(
+ "Memory mode",
+ alias="memory_mode",
+ value=compiler_config["memory_mode"],
+ nested_items=memory_settings,
+ ),
+ ReportItem(
+ "System config",
+ alias="system_config",
+ value=compiler_config["system_config"],
+ nested_items=system_settings,
+ ),
+ ReportItem(
+ "Architecture settings",
+ "arch_settings",
+ None,
+ nested_items=arch_settings,
+ ),
+ ],
+ )
+
+
+def metrics_as_records(perf_metrics: List[PerformanceMetrics]) -> List[Tuple]:
+ """Convert perf metrics object into list of records."""
+ perf_metrics = [item.in_kilobytes() for item in perf_metrics]
+
+ def _cycles_as_records(perf_metrics: List[PerformanceMetrics]) -> List[Tuple]:
+ metric_map = defaultdict(list)
+ for metrics in perf_metrics:
+ if not metrics.npu_cycles:
+ return []
+ metric_map["NPU active cycles"].append(metrics.npu_cycles.npu_active_cycles)
+ metric_map["NPU idle cycles"].append(metrics.npu_cycles.npu_idle_cycles)
+ metric_map["NPU total cycles"].append(metrics.npu_cycles.npu_total_cycles)
+
+ return [
+ (name, *(Cell(value, Format(str_fmt="12,d")) for value in values), "cycles")
+ for name, values in metric_map.items()
+ ]
+
+ def _memory_usage_as_records(perf_metrics: List[PerformanceMetrics]) -> List[Tuple]:
+ metric_map = defaultdict(list)
+ for metrics in perf_metrics:
+ if not metrics.memory_usage:
+ return []
+ metric_map["SRAM used"].append(metrics.memory_usage.sram_memory_area_size)
+ metric_map["DRAM used"].append(metrics.memory_usage.dram_memory_area_size)
+ metric_map["Unknown memory area used"].append(
+ metrics.memory_usage.unknown_memory_area_size
+ )
+ metric_map["On-chip flash used"].append(
+ metrics.memory_usage.on_chip_flash_memory_area_size
+ )
+ metric_map["Off-chip flash used"].append(
+ metrics.memory_usage.off_chip_flash_memory_area_size
+ )
+
+ return [
+ (name, *(Cell(value, Format(str_fmt="12.2f")) for value in values), "KiB")
+ for name, values in metric_map.items()
+ if all(val > 0 for val in values)
+ ]
+
+ def _data_beats_as_records(perf_metrics: List[PerformanceMetrics]) -> List[Tuple]:
+ metric_map = defaultdict(list)
+ for metrics in perf_metrics:
+ if not metrics.npu_cycles:
+ return []
+ metric_map["NPU AXI0 RD data beat received"].append(
+ metrics.npu_cycles.npu_axi0_rd_data_beat_received
+ )
+ metric_map["NPU AXI0 WR data beat written"].append(
+ metrics.npu_cycles.npu_axi0_wr_data_beat_written
+ )
+ metric_map["NPU AXI1 RD data beat received"].append(
+ metrics.npu_cycles.npu_axi1_rd_data_beat_received
+ )
+
+ return [
+ (name, *(Cell(value, Format(str_fmt="12,d")) for value in values), "beats")
+ for name, values in metric_map.items()
+ ]
+
+ return [
+ metrics
+ for metrics_func in (
+ _memory_usage_as_records,
+ _cycles_as_records,
+ _data_beats_as_records,
+ )
+ for metrics in metrics_func(perf_metrics)
+ ]
+
+
+def report_perf_metrics(
+ perf_metrics: Union[PerformanceMetrics, List[PerformanceMetrics]]
+) -> Report:
+ """Return comparison table for the performance metrics."""
+ if isinstance(perf_metrics, PerformanceMetrics):
+ perf_metrics = [perf_metrics]
+
+ rows = metrics_as_records(perf_metrics)
+
+ if len(perf_metrics) == 2:
+ return Table(
+ columns=[
+ Column("Metric", alias="metric", fmt=Format(wrap_width=30)),
+ Column("Original", alias="original", fmt=Format(wrap_width=15)),
+ Column("Optimized", alias="optimized", fmt=Format(wrap_width=15)),
+ Column("Unit", alias="unit", fmt=Format(wrap_width=15)),
+ Column("Improvement (%)", alias="improvement"),
+ ],
+ rows=[
+ (
+ metric,
+ original_value,
+ optimized_value,
+ unit,
+ Cell(
+ (
+ diff := 100
+ - (optimized_value.value / original_value.value * 100)
+ ),
+ Format(str_fmt="15.2f", style=style_improvement(diff > 0)),
+ )
+ if original_value.value != 0
+ else None,
+ )
+ for metric, original_value, optimized_value, unit in rows
+ ],
+ name="Performance metrics",
+ alias="performance_metrics",
+ notes="IMPORTANT: The performance figures above refer to NPU only",
+ )
+
+ return Table(
+ columns=[
+ Column("Metric", alias="metric", fmt=Format(wrap_width=30)),
+ Column("Value", alias="value", fmt=Format(wrap_width=15)),
+ Column("Unit", alias="unit", fmt=Format(wrap_width=15)),
+ ],
+ rows=rows,
+ name="Performance metrics",
+ alias="performance_metrics",
+ notes="IMPORTANT: The performance figures above refer to NPU only",
+ )
+
+
+def report_advice(advice: List[Advice]) -> Report:
+ """Generate report for the advice."""
+ return Table(
+ columns=[
+ Column("#", only_for=["plain_text"]),
+ Column("Advice", alias="advice_message"),
+ ],
+ rows=[(i + 1, a.messages) for i, a in enumerate(advice)],
+ name="Advice",
+ alias="advice",
+ )
+
+
+def find_appropriate_formatter(data: Any) -> Callable[[Any], Report]:
+ """Find appropriate formatter for the provided data."""
+ if isinstance(data, PerformanceMetrics) or is_list_of(data, PerformanceMetrics, 2):
+ return report_perf_metrics
+
+ if is_list_of(data, Advice):
+ return report_advice
+
+ if is_list_of(data, Operator):
+ return report_operators
+
+ if isinstance(data, Operators):
+ return report_operators_stat
+
+ if isinstance(data, EthosUConfiguration):
+ return report_device_details
+
+ if isinstance(data, (list, tuple)):
+ formatters = [find_appropriate_formatter(item) for item in data]
+ return CompoundFormatter(formatters)
+
+ raise Exception(f"Unable to find appropriate formatter for {data}")
diff --git a/src/mlia/nn/__init__.py b/src/mlia/nn/__init__.py
new file mode 100644
index 0000000..aac2830
--- /dev/null
+++ b/src/mlia/nn/__init__.py
@@ -0,0 +1,3 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""NN related module."""
diff --git a/src/mlia/nn/tensorflow/__init__.py b/src/mlia/nn/tensorflow/__init__.py
new file mode 100644
index 0000000..ff061c1
--- /dev/null
+++ b/src/mlia/nn/tensorflow/__init__.py
@@ -0,0 +1,3 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""TensorFlow related module."""
diff --git a/src/mlia/nn/tensorflow/config.py b/src/mlia/nn/tensorflow/config.py
new file mode 100644
index 0000000..d3235d7
--- /dev/null
+++ b/src/mlia/nn/tensorflow/config.py
@@ -0,0 +1,134 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Model configuration."""
+import logging
+from pathlib import Path
+from typing import cast
+from typing import Dict
+from typing import List
+from typing import Union
+
+import tensorflow as tf
+
+from mlia.core.context import Context
+from mlia.nn.tensorflow.utils import convert_tf_to_tflite
+from mlia.nn.tensorflow.utils import convert_to_tflite
+from mlia.nn.tensorflow.utils import is_keras_model
+from mlia.nn.tensorflow.utils import is_tf_model
+from mlia.nn.tensorflow.utils import is_tflite_model
+from mlia.nn.tensorflow.utils import save_tflite_model
+
+logger = logging.getLogger(__name__)
+
+
+class ModelConfiguration:
+ """Base class for model configuration."""
+
+ def __init__(self, model_path: Union[str, Path]) -> None:
+ """Init model configuration instance."""
+ self.model_path = str(model_path)
+
+ def convert_to_tflite(
+ self, tflite_model_path: Union[str, Path], quantized: bool = False
+ ) -> "TFLiteModel":
+ """Convert model to TFLite format."""
+ raise NotImplementedError()
+
+ def convert_to_keras(self, keras_model_path: Union[str, Path]) -> "KerasModel":
+ """Convert model to Keras format."""
+ raise NotImplementedError()
+
+
+class KerasModel(ModelConfiguration):
+ """Keras model configuration.
+
+    Supports all model formats that the Keras API can load: saved model, H5, HDF5.
+ """
+
+ def get_keras_model(self) -> tf.keras.Model:
+ """Return associated Keras model."""
+ return tf.keras.models.load_model(self.model_path)
+
+ def convert_to_tflite(
+ self, tflite_model_path: Union[str, Path], quantized: bool = False
+ ) -> "TFLiteModel":
+ """Convert model to TFLite format."""
+ logger.info("Converting Keras to TFLite ...")
+
+ converted_model = convert_to_tflite(self.get_keras_model(), quantized)
+ logger.info("Done\n")
+
+ save_tflite_model(converted_model, tflite_model_path)
+ logger.debug(
+ "Model %s converted and saved to %s", self.model_path, tflite_model_path
+ )
+
+ return TFLiteModel(tflite_model_path)
+
+ def convert_to_keras(self, keras_model_path: Union[str, Path]) -> "KerasModel":
+ """Convert model to Keras format."""
+ return self
+
+
+class TFLiteModel(ModelConfiguration): # pylint: disable=abstract-method
+ """TFLite model configuration."""
+
+ def input_details(self) -> List[Dict]:
+ """Get model's input details."""
+ interpreter = tf.lite.Interpreter(model_path=self.model_path)
+ return cast(List[Dict], interpreter.get_input_details())
+
+ def convert_to_tflite(
+ self, tflite_model_path: Union[str, Path], quantized: bool = False
+ ) -> "TFLiteModel":
+ """Convert model to TFLite format."""
+ return self
+
+
+class TfModel(ModelConfiguration): # pylint: disable=abstract-method
+ """TensorFlow model configuration.
+
+    Supports models in the TensorFlow SavedModel format (not Keras).
+ """
+
+ def convert_to_tflite(
+ self, tflite_model_path: Union[str, Path], quantized: bool = False
+ ) -> "TFLiteModel":
+ """Convert model to TFLite format."""
+ converted_model = convert_tf_to_tflite(self.model_path, quantized)
+ save_tflite_model(converted_model, tflite_model_path)
+
+ return TFLiteModel(tflite_model_path)
+
+
+def get_model(model: Union[Path, str]) -> "ModelConfiguration":
+ """Return the model object."""
+ if is_tflite_model(model):
+ return TFLiteModel(model)
+
+ if is_keras_model(model):
+ return KerasModel(model)
+
+ if is_tf_model(model):
+ return TfModel(model)
+
+    raise Exception(
+        "The input model format is not supported "
+        "(supported formats: TFLite, Keras, TensorFlow saved model)!"
+    )
+
+
+def get_tflite_model(model: Union[str, Path], ctx: Context) -> "TFLiteModel":
+    """Convert the input model to TFLite and return a TFLiteModel object."""
+ tflite_model_path = ctx.get_model_path("converted_model.tflite")
+ converted_model = get_model(model)
+
+ return converted_model.convert_to_tflite(tflite_model_path, True)
+
+
+def get_keras_model(model: Union[str, Path], ctx: Context) -> "KerasModel":
+    """Convert the input model to Keras and return a KerasModel object."""
+ keras_model_path = ctx.get_model_path("converted_model.h5")
+ converted_model = get_model(model)
+
+ return converted_model.convert_to_keras(keras_model_path)
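+
+
+# Illustrative usage sketch (not part of the module): given a Context
+# implementation `ctx` and a hypothetical model file, the helpers above
+# dispatch on the detected format and convert if needed.
+#
+#     tflite_model = get_tflite_model("my_model.h5", ctx)
+#     print(tflite_model.model_path)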
diff --git a/src/mlia/nn/tensorflow/optimizations/__init__.py b/src/mlia/nn/tensorflow/optimizations/__init__.py
new file mode 100644
index 0000000..201c130
--- /dev/null
+++ b/src/mlia/nn/tensorflow/optimizations/__init__.py
@@ -0,0 +1,3 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Optimizations module."""
diff --git a/src/mlia/nn/tensorflow/optimizations/clustering.py b/src/mlia/nn/tensorflow/optimizations/clustering.py
new file mode 100644
index 0000000..16d9e4b
--- /dev/null
+++ b/src/mlia/nn/tensorflow/optimizations/clustering.py
@@ -0,0 +1,109 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""
+Contains class Clusterer that clusters each layer's weights to a specified number of unique values.
+
+In order to do this, we need to have a base model and corresponding training data.
+We also have to specify a subset of layers we want to cluster. For more details,
+please refer to the documentation for TensorFlow Model Optimization Toolkit.
+"""
+from dataclasses import dataclass
+from typing import Any
+from typing import Dict
+from typing import List
+from typing import Optional
+
+import tensorflow as tf
+import tensorflow_model_optimization as tfmot
+from tensorflow_model_optimization.python.core.clustering.keras.experimental import ( # pylint: disable=no-name-in-module
+ cluster as experimental_cluster,
+)
+
+from mlia.nn.tensorflow.optimizations.common import Optimizer
+from mlia.nn.tensorflow.optimizations.common import OptimizerConfiguration
+
+
+@dataclass
+class ClusteringConfiguration(OptimizerConfiguration):
+ """Clustering configuration."""
+
+ optimization_target: int
+ layers_to_optimize: Optional[List[str]] = None
+
+ def __str__(self) -> str:
+ """Return string representation of the configuration."""
+ return f"clustering: {self.optimization_target}"
+
+
+class Clusterer(Optimizer):
+ """
+ Clusterer class.
+
+ Used to cluster a model to a specified number of unique weights per layer.
+
+ Sample usage:
+ clusterer = Clusterer(
+ base_model,
+ optimizer_configuration)
+
+        clusterer.apply_optimization()
+ clustered_model = clusterer.get_model()
+ """
+
+ def __init__(
+ self, model: tf.keras.Model, optimizer_configuration: ClusteringConfiguration
+ ):
+ """Init Clusterer instance."""
+ self.model = model
+ self.optimizer_configuration = optimizer_configuration
+
+ def optimization_config(self) -> str:
+ """Return string representation of the optimization config."""
+ return str(self.optimizer_configuration)
+
+ def _setup_clustering_params(self) -> Dict[str, Any]:
+ CentroidInitialization = tfmot.clustering.keras.CentroidInitialization
+ return {
+ "number_of_clusters": self.optimizer_configuration.optimization_target,
+ "cluster_centroids_init": CentroidInitialization.LINEAR,
+ "preserve_sparsity": True,
+ }
+
+ def _apply_clustering_to_layer(
+ self, layer: tf.keras.layers.Layer
+ ) -> tf.keras.layers.Layer:
+ layers_to_optimize = self.optimizer_configuration.layers_to_optimize
+ assert layers_to_optimize, "List of the layers to optimize is empty"
+
+ if layer.name not in layers_to_optimize:
+ return layer
+
+ clustering_params = self._setup_clustering_params()
+ return experimental_cluster.cluster_weights(layer, **clustering_params)
+
+ def _init_for_clustering(self) -> None:
+ # Use `tf.keras.models.clone_model` to apply `apply_clustering_to_layer`
+ # to the layers of the model
+ if not self.optimizer_configuration.layers_to_optimize:
+ clustering_params = self._setup_clustering_params()
+ clustered_model = experimental_cluster.cluster_weights(
+ self.model, **clustering_params
+ )
+ else:
+ clustered_model = tf.keras.models.clone_model(
+ self.model, clone_function=self._apply_clustering_to_layer
+ )
+
+ self.model = clustered_model
+
+ def _strip_clustering(self) -> None:
+ self.model = tfmot.clustering.keras.strip_clustering(self.model)
+
+ def apply_optimization(self) -> None:
+ """Apply all steps of clustering at once."""
+ self._init_for_clustering()
+ self._strip_clustering()
+
+ def get_model(self) -> tf.keras.Model:
+ """Get model."""
+ return self.model
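+
+
+# Illustrative usage sketch (assumes `base_model` is a trained tf.keras.Model):
+# cluster every layer's weights down to 32 unique values.
+#
+#     config = ClusteringConfiguration(optimization_target=32)
+#     clusterer = Clusterer(base_model, config)
+#     clusterer.apply_optimization()
+#     clustered_model = clusterer.get_model()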
diff --git a/src/mlia/nn/tensorflow/optimizations/common.py b/src/mlia/nn/tensorflow/optimizations/common.py
new file mode 100644
index 0000000..1dce0b2
--- /dev/null
+++ b/src/mlia/nn/tensorflow/optimizations/common.py
@@ -0,0 +1,29 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Common items for the optimizations module."""
+from abc import ABC
+from abc import abstractmethod
+from dataclasses import dataclass
+
+import tensorflow as tf
+
+
+@dataclass
+class OptimizerConfiguration:
+ """Abstract optimizer configuration."""
+
+
+class Optimizer(ABC):
+ """Abstract class for the optimizer."""
+
+ @abstractmethod
+ def get_model(self) -> tf.keras.Model:
+ """Abstract method to return the model instance from the optimizer."""
+
+ @abstractmethod
+ def apply_optimization(self) -> None:
+ """Abstract method to apply optimization to the model."""
+
+ @abstractmethod
+ def optimization_config(self) -> str:
+ """Return string representation of the optimization config."""
diff --git a/src/mlia/nn/tensorflow/optimizations/pruning.py b/src/mlia/nn/tensorflow/optimizations/pruning.py
new file mode 100644
index 0000000..f629ba1
--- /dev/null
+++ b/src/mlia/nn/tensorflow/optimizations/pruning.py
@@ -0,0 +1,168 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""
+Contains class Pruner to prune a model to a specified sparsity.
+
+In order to do this, we need to have a base model and corresponding training data.
+We also have to specify a subset of layers we want to prune. For more details,
+please refer to the documentation for TensorFlow Model Optimization Toolkit.
+"""
+from dataclasses import dataclass
+from typing import List
+from typing import Optional
+from typing import Tuple
+
+import numpy as np
+import tensorflow as tf
+import tensorflow_model_optimization as tfmot
+from tensorflow_model_optimization.python.core.sparsity.keras import ( # pylint: disable=no-name-in-module
+ pruning_wrapper,
+)
+
+from mlia.nn.tensorflow.optimizations.common import Optimizer
+from mlia.nn.tensorflow.optimizations.common import OptimizerConfiguration
+
+
+@dataclass
+class PruningConfiguration(OptimizerConfiguration):
+ """Pruning configuration."""
+
+ optimization_target: float
+ layers_to_optimize: Optional[List[str]] = None
+    x_train: Optional[np.ndarray] = None
+    y_train: Optional[np.ndarray] = None
+ batch_size: int = 1
+ num_epochs: int = 1
+
+ def __str__(self) -> str:
+ """Return string representation of the configuration."""
+ return f"pruning: {self.optimization_target}"
+
+ def has_training_data(self) -> bool:
+ """Return True if training data provided."""
+ return self.x_train is not None and self.y_train is not None
+
+
+class Pruner(Optimizer):
+ """
+ Pruner class. Used to prune a model to a specified sparsity.
+
+ Sample usage:
+ pruner = Pruner(
+ base_model,
+ optimizer_configuration)
+
+        pruner.apply_optimization()
+ pruned_model = pruner.get_model()
+ """
+
+ def __init__(
+ self, model: tf.keras.Model, optimizer_configuration: PruningConfiguration
+ ):
+ """Init Pruner instance."""
+ self.model = model
+ self.optimizer_configuration = optimizer_configuration
+
+ if not optimizer_configuration.has_training_data():
+ mock_x_train, mock_y_train = self._mock_train_data()
+
+ self.optimizer_configuration.x_train = mock_x_train
+ self.optimizer_configuration.y_train = mock_y_train
+
+ def optimization_config(self) -> str:
+ """Return string representation of the optimization config."""
+ return str(self.optimizer_configuration)
+
+    def _mock_train_data(self) -> Tuple[np.ndarray, np.ndarray]:
+ # get rid of the batch_size dimension in input and output shape
+ input_shape = tuple(x for x in self.model.input_shape if x is not None)
+ output_shape = tuple(x for x in self.model.output_shape if x is not None)
+
+ return (
+ np.random.rand(*input_shape),
+ np.random.randint(0, output_shape[-1], (output_shape[:-1])),
+ )
+
+ def _setup_pruning_params(self) -> dict:
+ return {
+ "pruning_schedule": tfmot.sparsity.keras.PolynomialDecay(
+ initial_sparsity=0,
+ final_sparsity=self.optimizer_configuration.optimization_target,
+ begin_step=0,
+ end_step=self.optimizer_configuration.num_epochs,
+ frequency=1,
+ ),
+ }
+
+ def _apply_pruning_to_layer(
+ self, layer: tf.keras.layers.Layer
+ ) -> tf.keras.layers.Layer:
+ layers_to_optimize = self.optimizer_configuration.layers_to_optimize
+ assert layers_to_optimize, "List of the layers to optimize is empty"
+
+ if layer.name not in layers_to_optimize:
+ return layer
+
+ pruning_params = self._setup_pruning_params()
+ return tfmot.sparsity.keras.prune_low_magnitude(layer, **pruning_params)
+
+ def _init_for_pruning(self) -> None:
+ # Use `tf.keras.models.clone_model` to apply `apply_pruning_to_layer`
+ # to the layers of the model
+ if not self.optimizer_configuration.layers_to_optimize:
+ pruning_params = self._setup_pruning_params()
+ prunable_model = tfmot.sparsity.keras.prune_low_magnitude(
+ self.model, **pruning_params
+ )
+ else:
+ prunable_model = tf.keras.models.clone_model(
+ self.model, clone_function=self._apply_pruning_to_layer
+ )
+
+ self.model = prunable_model
+
+ def _train_pruning(self) -> None:
+ loss_fn = tf.keras.losses.SparseCategoricalCrossentropy()
+ self.model.compile(optimizer="adam", loss=loss_fn, metrics=["accuracy"])
+
+ # Model callbacks
+ callbacks = [tfmot.sparsity.keras.UpdatePruningStep()]
+
+ # Fitting data
+ self.model.fit(
+ self.optimizer_configuration.x_train,
+ self.optimizer_configuration.y_train,
+ batch_size=self.optimizer_configuration.batch_size,
+ epochs=self.optimizer_configuration.num_epochs,
+ callbacks=callbacks,
+ verbose=0,
+ )
+
+ def _assert_sparsity_reached(self) -> None:
+ for layer in self.model.layers:
+ if not isinstance(layer, pruning_wrapper.PruneLowMagnitude):
+ continue
+
+ for weight in layer.layer.get_prunable_weights():
+ nonzero_weights = np.count_nonzero(tf.keras.backend.get_value(weight))
+ all_weights = tf.keras.backend.get_value(weight).size
+
+ np.testing.assert_approx_equal(
+ self.optimizer_configuration.optimization_target,
+ 1 - nonzero_weights / all_weights,
+ significant=2,
+ )
+
+ def _strip_pruning(self) -> None:
+ self.model = tfmot.sparsity.keras.strip_pruning(self.model)
+
+ def apply_optimization(self) -> None:
+ """Apply all steps of pruning sequentially."""
+ self._init_for_pruning()
+ self._train_pruning()
+ self._assert_sparsity_reached()
+ self._strip_pruning()
+
+ def get_model(self) -> tf.keras.Model:
+ """Get model."""
+ return self.model
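+
+
+# Illustrative usage sketch (assumes `base_model` is a trained tf.keras.Model):
+# prune to 50% sparsity. Without x_train/y_train, randomly generated mock
+# data is used for the pruning fine-tuning step.
+#
+#     config = PruningConfiguration(optimization_target=0.5)
+#     pruner = Pruner(base_model, config)
+#     pruner.apply_optimization()
+#     pruned_model = pruner.get_model()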
diff --git a/src/mlia/nn/tensorflow/optimizations/select.py b/src/mlia/nn/tensorflow/optimizations/select.py
new file mode 100644
index 0000000..1b0c755
--- /dev/null
+++ b/src/mlia/nn/tensorflow/optimizations/select.py
@@ -0,0 +1,179 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Module for optimization selection."""
+import math
+from typing import List
+from typing import NamedTuple
+from typing import Optional
+from typing import Tuple
+from typing import Union
+
+import tensorflow as tf
+
+from mlia.core.errors import ConfigurationError
+from mlia.nn.tensorflow.config import KerasModel
+from mlia.nn.tensorflow.optimizations.clustering import Clusterer
+from mlia.nn.tensorflow.optimizations.clustering import ClusteringConfiguration
+from mlia.nn.tensorflow.optimizations.common import Optimizer
+from mlia.nn.tensorflow.optimizations.common import OptimizerConfiguration
+from mlia.nn.tensorflow.optimizations.pruning import Pruner
+from mlia.nn.tensorflow.optimizations.pruning import PruningConfiguration
+from mlia.utils.types import is_list_of
+
+
+class OptimizationSettings(NamedTuple):
+ """Optimization settings."""
+
+ optimization_type: str
+ optimization_target: Union[int, float]
+ layers_to_optimize: Optional[List[str]]
+
+ @staticmethod
+ def create_from(
+ optimizer_params: List[Tuple[str, float]],
+ layers_to_optimize: Optional[List[str]] = None,
+ ) -> List["OptimizationSettings"]:
+ """Create optimization settings from the provided parameters."""
+ return [
+ OptimizationSettings(
+ optimization_type=opt_type,
+ optimization_target=opt_target,
+ layers_to_optimize=layers_to_optimize,
+ )
+ for opt_type, opt_target in optimizer_params
+ ]
+
+ def __str__(self) -> str:
+ """Return string representation."""
+ return f"{self.optimization_type}: {self.optimization_target}"
+
+ def next_target(self) -> "OptimizationSettings":
+ """Return next optimization target."""
+ if self.optimization_type == "pruning":
+ next_target = round(min(self.optimization_target + 0.1, 0.9), 2)
+ return OptimizationSettings(
+ self.optimization_type, next_target, self.layers_to_optimize
+ )
+
+ if self.optimization_type == "clustering":
+ # return next lowest power of two for clustering
+ next_target = math.log(self.optimization_target, 2)
+ if next_target.is_integer():
+ next_target -= 1
+
+ next_target = max(int(2 ** int(next_target)), 4)
+ return OptimizationSettings(
+ self.optimization_type, next_target, self.layers_to_optimize
+ )
+
+ raise Exception(f"Unknown optimization type {self.optimization_type}")
+
+
+class MultiStageOptimizer(Optimizer):
+    """Optimizer with multiple stages."""
+
+ def __init__(
+ self,
+ model: tf.keras.Model,
+ optimizations: List[OptimizerConfiguration],
+ ) -> None:
+ """Init MultiStageOptimizer instance."""
+ self.model = model
+ self.optimizations = optimizations
+
+ def optimization_config(self) -> str:
+ """Return string representation of the optimization config."""
+ return " - ".join(str(opt) for opt in self.optimizations)
+
+ def get_model(self) -> tf.keras.Model:
+ """Return optimized model."""
+ return self.model
+
+ def apply_optimization(self) -> None:
+ """Apply optimization to the model."""
+ for config in self.optimizations:
+ optimizer = get_optimizer(self.model, config)
+ optimizer.apply_optimization()
+ self.model = optimizer.get_model()
+
+
+def get_optimizer(
+ model: Union[tf.keras.Model, KerasModel],
+ config: Union[
+ OptimizerConfiguration, OptimizationSettings, List[OptimizationSettings]
+ ],
+) -> Optimizer:
+ """Get optimizer for provided configuration."""
+ if isinstance(model, KerasModel):
+ model = model.get_keras_model()
+
+ if isinstance(config, PruningConfiguration):
+ return Pruner(model, config)
+
+ if isinstance(config, ClusteringConfiguration):
+ return Clusterer(model, config)
+
+ if isinstance(config, OptimizationSettings) or is_list_of(
+ config, OptimizationSettings
+ ):
+ return _get_optimizer(model, config) # type: ignore
+
+ raise ConfigurationError(f"Unknown optimization configuration {config}")
+
+
+def _get_optimizer(
+ model: tf.keras.Model,
+ optimization_settings: Union[OptimizationSettings, List[OptimizationSettings]],
+) -> Optimizer:
+ if isinstance(optimization_settings, OptimizationSettings):
+ optimization_settings = [optimization_settings]
+
+ optimizer_configs = []
+ for opt_type, opt_target, layers_to_optimize in optimization_settings:
+ _check_optimizer_params(opt_type, opt_target)
+
+ opt_config = _get_optimizer_configuration(
+ opt_type, opt_target, layers_to_optimize
+ )
+ optimizer_configs.append(opt_config)
+
+ if len(optimizer_configs) == 1:
+ return get_optimizer(model, optimizer_configs[0])
+
+ return MultiStageOptimizer(model, optimizer_configs)
+
+
+def _get_optimizer_configuration(
+ optimization_type: str,
+ optimization_target: Union[int, float],
+ layers_to_optimize: Optional[List[str]] = None,
+) -> OptimizerConfiguration:
+ """Get optimizer configuration for provided parameters."""
+ _check_optimizer_params(optimization_type, optimization_target)
+
+ opt_type = optimization_type.lower()
+ if opt_type == "pruning":
+ return PruningConfiguration(optimization_target, layers_to_optimize)
+
+ if opt_type == "clustering":
+ # make sure an integer is given as clustering target
+ if optimization_target == int(optimization_target):
+ return ClusteringConfiguration(int(optimization_target), layers_to_optimize)
+
+ raise ConfigurationError(
+ "Optimization target should be a positive integer. "
+ f"Optimization target provided: {optimization_target}"
+ )
+
+ raise ConfigurationError(f"Unsupported optimization type: {optimization_type}")
+
+
+def _check_optimizer_params(
+ optimization_type: str, optimization_target: Union[int, float]
+) -> None:
+ """Check optimizer params."""
+ if not optimization_target:
+ raise ConfigurationError("Optimization target is not provided")
+
+ if not optimization_type:
+ raise ConfigurationError("Optimization type is not provided")
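+
+
+# Illustrative usage sketch (assumes `keras_model` is a tf.keras.Model):
+# a pruning stage followed by a clustering stage is wrapped in a
+# MultiStageOptimizer by get_optimizer.
+#
+#     settings = OptimizationSettings.create_from(
+#         [("pruning", 0.5), ("clustering", 32)]
+#     )
+#     optimizer = get_optimizer(keras_model, settings)
+#     optimizer.apply_optimization()
+#     optimized_model = optimizer.get_model()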
diff --git a/src/mlia/nn/tensorflow/tflite_metrics.py b/src/mlia/nn/tensorflow/tflite_metrics.py
new file mode 100644
index 0000000..b29fab3
--- /dev/null
+++ b/src/mlia/nn/tensorflow/tflite_metrics.py
@@ -0,0 +1,296 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""
+Contains class TFLiteMetrics to calculate metrics from a TFLite file.
+
+These metrics include:
+* Sparsity (per layer and overall)
+* Unique weights (clusters) (per layer)
+* gzip compression ratio
+"""
+import os
+from enum import Enum
+from pprint import pprint
+from typing import Any
+from typing import List
+from typing import Optional
+
+import numpy as np
+import tensorflow as tf
+
+DEFAULT_IGNORE_LIST = [
+ "relu",
+ "pooling",
+ "reshape",
+ "identity",
+ "input",
+ "add",
+ "flatten",
+ "StatefulPartitionedCall",
+ "bias",
+]
+
+
+def calculate_num_unique_weights(weights: np.ndarray) -> int:
+ """Calculate the number of unique weights in the given weights."""
+ num_unique_weights = len(np.unique(weights))
+ return num_unique_weights
+
+
+def calculate_num_unique_weights_per_axis(weights: np.ndarray, axis: int) -> List[int]:
+    """Calculate the number of unique weights per quantization axis."""
+    # Make the quantized dimension the first dimension
+    weights_trans = np.swapaxes(weights, 0, axis)
+    num_unique_weights = [
+        calculate_num_unique_weights(weights_trans[i])
+        for i in range(weights_trans.shape[0])
+    ]
+    assert num_unique_weights
+    return num_unique_weights
+
+
+class SparsityAccumulator:
+ """Helper class to accumulate sparsity over several layers."""
+
+ def __init__(self) -> None:
+ """Create an empty accumulator."""
+ self.total_non_zero_weights: int = 0
+ self.total_weights: int = 0
+
+    def __call__(self, weights: np.ndarray) -> None:
+ """Update the accumulator with the given weights."""
+ non_zero_weights = np.count_nonzero(weights)
+ self.total_non_zero_weights += non_zero_weights
+ self.total_weights += weights.size
+
+ def sparsity(self) -> float:
+ """Calculate the sparsity for all added weights."""
+ return 1.0 - self.total_non_zero_weights / float(self.total_weights)
+
+
+def calculate_sparsity(
+    weights: np.ndarray, accumulator: Optional[SparsityAccumulator] = None
+) -> float:
+ """
+ Calculate the sparsity for the given weights.
+
+ If the accumulator is passed, it is updated as well.
+ """
+ non_zero_weights = np.count_nonzero(weights)
+ sparsity = 1.0 - float(non_zero_weights) / float(weights.size)
+ if accumulator is not None:
+ accumulator(weights)
+ return sparsity
+
+
+class ReportClusterMode(Enum):
+ """Specifies the way cluster values are aggregated and reported."""
+
+ NUM_CLUSTERS_HISTOGRAM = (
+ "A histogram of the number of clusters per axis. "
+ "I.e. the number of clusters is the index of the list (the bin) and "
+ "the value is the number of axes that have this number of clusters. "
+ "The first bin is 1."
+ )
+ NUM_CLUSTERS_PER_AXIS = "Number of clusters (unique weights) per axis."
+ NUM_CLUSTERS_MIN_MAX = "Min/max number of clusters over all axes."
+
+
+class TFLiteMetrics:
+ """Helper class to calculate metrics from a TFLite file.
+
+ Metrics include:
+ * sparsity (per-layer and overall)
+ * number of unique weights (clusters) per layer
+ * File compression via gzip
+ """
+
+ def __init__(
+ self, tflite_file: str, ignore_list: Optional[List[str]] = None
+ ) -> None:
+ """Load the TFLite file and filter layers."""
+ self.tflite_file = tflite_file
+ if ignore_list is None:
+ ignore_list = DEFAULT_IGNORE_LIST
+ self.ignore_list = [ignore.casefold() for ignore in ignore_list]
+ # Initialize the TFLite interpreter with the model file
+ self.interpreter = tf.lite.Interpreter(model_path=tflite_file)
+ self.interpreter.allocate_tensors()
+ self.details: dict = {}
+
+ def ignore(details: dict) -> bool:
+ name = details["name"].casefold()
+ if not name:
+ return True
+ for to_ignore in self.ignore_list:
+ if to_ignore in name:
+ return True
+ return False
+
+ self.filtered_details = {
+ details["name"]: details
+ for details in self.interpreter.get_tensor_details()
+ if not ignore(details)
+ }
+
+ def get_tensor(self, details: dict) -> Any:
+ """Return the weights/tensor specified in the given details map."""
+ return self.interpreter.tensor(details["index"])()
+
+ def sparsity_per_layer(self) -> dict:
+ """Return a dict of layer name and sparsity value."""
+ sparsity = {
+ name: calculate_sparsity(self.get_tensor(details))
+ for name, details in self.filtered_details.items()
+ }
+ return sparsity
+
+ def sparsity_overall(self) -> float:
+        """Return the overall sparsity across all filtered layers."""
+ acc = SparsityAccumulator()
+ for details in self.filtered_details.values():
+ acc(self.get_tensor(details))
+ return acc.sparsity()
+
+ def calc_num_clusters_per_axis(self, details: dict) -> List[int]:
+ """Calculate number of clusters per axis."""
+ quant_params = details["quantization_parameters"]
+ per_axis = len(quant_params["zero_points"]) > 1
+ if per_axis:
+ # Calculate unique weights along quantization axis
+ axis = quant_params["quantized_dimension"]
+ return calculate_num_unique_weights_per_axis(self.get_tensor(details), axis)
+
+ # Calculate unique weights over all axes/dimensions
+ return [calculate_num_unique_weights(self.get_tensor(details))]
+
+ def num_unique_weights(self, mode: ReportClusterMode) -> dict:
+ """Return a dict of layer name and number of unique weights."""
+ aggregation_func = None
+ if mode == ReportClusterMode.NUM_CLUSTERS_PER_AXIS:
+ aggregation_func = self.calc_num_clusters_per_axis
+ elif mode == ReportClusterMode.NUM_CLUSTERS_MIN_MAX:
+
+ def cluster_min_max(details: dict) -> List[int]:
+ num_clusters = self.calc_num_clusters_per_axis(details)
+ return [min(num_clusters), max(num_clusters)]
+
+ aggregation_func = cluster_min_max
+ elif mode == ReportClusterMode.NUM_CLUSTERS_HISTOGRAM:
+
+ def cluster_hist(details: dict) -> List[int]:
+ num_clusters = self.calc_num_clusters_per_axis(details)
+ max_num = max(num_clusters)
+                hist = [0] * max_num
+ for num in num_clusters:
+ idx = num - 1
+ hist[idx] += 1
+ return hist
+
+ aggregation_func = cluster_hist
+ else:
+ raise NotImplementedError(
+ "ReportClusterMode '{}' not implemented.".format(mode)
+ )
+ uniques = {
+ name: aggregation_func(details)
+ for name, details in self.filtered_details.items()
+ }
+ return uniques
+
+ @staticmethod
+ def _prettify_name(name: str) -> str:
+ if name.startswith("model"):
+ return name.split("/", 1)[1]
+ return name
+
+ def summary(
+ self,
+ report_sparsity: bool,
+        report_cluster_mode: Optional[ReportClusterMode] = None,
+ max_num_clusters: int = 32,
+ verbose: bool = False,
+ ) -> None:
+ """Print a summary of all the model information."""
+ print("Model file: {}".format(self.tflite_file))
+ print("#" * 80)
+ print(" " * 28 + "### TFLITE SUMMARY ###")
+ print("File: {}".format(os.path.abspath(self.tflite_file)))
+ print("Input(s):")
+ self._print_in_outs(self.interpreter.get_input_details(), verbose)
+ print("Output(s):")
+ self._print_in_outs(self.interpreter.get_output_details(), verbose)
+ print()
+ header = ["Layer", "Index", "Type", "Num weights"]
+ if report_sparsity:
+ header.append("Sparsity")
+ rows = []
+ sparsity_accumulator = SparsityAccumulator()
+ for details in self.filtered_details.values():
+ name = details["name"]
+ weights = self.get_tensor(details)
+ row = [
+ self._prettify_name(name),
+ details["index"],
+ weights.dtype,
+ weights.size,
+ ]
+ if report_sparsity:
+ sparsity = calculate_sparsity(weights, sparsity_accumulator)
+ row.append("{:.2f}".format(sparsity))
+ rows.append(row)
+ if verbose:
+ # Print cluster centroids
+ print("{} cluster centroids:".format(name))
+ pprint(np.unique(weights))
+ # Add summary/overall values
+ empty_row = ["" for _ in range(len(header))]
+ summary_row = empty_row
+ summary_row[header.index("Layer")] = "=> OVERALL"
+ summary_row[header.index("Num weights")] = str(
+ sparsity_accumulator.total_weights
+ )
+ if report_sparsity:
+ summary_row[header.index("Sparsity")] = "{:.2f}".format(
+ sparsity_accumulator.sparsity()
+ )
+ rows.append(summary_row)
+ # Report detailed cluster info
+ if report_cluster_mode is not None:
+ print()
+ self._print_cluster_details(report_cluster_mode, max_num_clusters)
+ print("#" * 80)
+
+ def _print_cluster_details(
+ self, report_cluster_mode: ReportClusterMode, max_num_clusters: int
+ ) -> None:
+ print("{}:\n{}".format(report_cluster_mode.name, report_cluster_mode.value))
+ num_clusters = self.num_unique_weights(report_cluster_mode)
+ if (
+ report_cluster_mode == ReportClusterMode.NUM_CLUSTERS_HISTOGRAM
+ and max_num_clusters > 0
+ ):
+ # Only show cluster histogram if there are not more than
+ # max_num_clusters. This is a workaround for not showing a huge
+ # histogram for unclustered layers.
+ for name, value in num_clusters.items():
+ if len(value) > max_num_clusters:
+ num_clusters[name] = "More than {} unique values.".format(
+ max_num_clusters
+ )
+ for name, nums in num_clusters.items():
+ print("- {}: {}".format(self._prettify_name(name), nums))
+
+ @staticmethod
+ def _print_in_outs(ios: List[dict], verbose: bool = False) -> None:
+ for item in ios:
+ if verbose:
+ pprint(item)
+ else:
+ print(
+ "- {} ({}): {}".format(
+ item["name"],
+ np.dtype(item["dtype"]).name,
+ item["shape"],
+ )
+ )
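+
+
+# Illustrative usage sketch (the model file name is hypothetical):
+#
+#     metrics = TFLiteMetrics("model.tflite")
+#     print(metrics.sparsity_overall())
+#     metrics.summary(
+#         report_sparsity=True,
+#         report_cluster_mode=ReportClusterMode.NUM_CLUSTERS_MIN_MAX,
+#     )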
diff --git a/src/mlia/nn/tensorflow/utils.py b/src/mlia/nn/tensorflow/utils.py
new file mode 100644
index 0000000..4abf6cd
--- /dev/null
+++ b/src/mlia/nn/tensorflow/utils.py
@@ -0,0 +1,149 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-FileCopyrightText: Copyright The TensorFlow Authors. All Rights Reserved.
+# SPDX-License-Identifier: Apache-2.0
+"""Collection of useful functions for optimizations."""
+import logging
+from pathlib import Path
+from typing import Callable
+from typing import Iterable
+from typing import Union
+
+import numpy as np
+import tensorflow as tf
+
+from mlia.utils.logging import redirect_output
+
+
+def representative_dataset(model: tf.keras.Model) -> Callable:
+ """Sample dataset used for quantization."""
+ input_shape = model.input_shape
+
+ def dataset() -> Iterable:
+ for _ in range(100):
+ if input_shape[0] != 1:
+                raise Exception("Only batch size 1 is supported!")
+ data = np.random.rand(*input_shape)
+ yield [data.astype(np.float32)]
+
+ return dataset
+
+
+def get_tf_tensor_shape(model: str) -> list:
+    """Get the input shape of a TensorFlow SavedModel."""
+ # Loading the model
+ loaded = tf.saved_model.load(model)
+ # The model signature must have 'serving_default' as a key
+ if "serving_default" not in loaded.signatures.keys():
+ raise Exception(
+ "Unsupported TensorFlow model signature, must have 'serving_default'"
+ )
+ # Get the signature inputs
+ inputs_tensor_info = loaded.signatures["serving_default"].inputs
+ dims = []
+ # Build a list of all inputs shape sizes
+ for input_key in inputs_tensor_info:
+ if input_key.get_shape():
+ dims.extend(list(input_key.get_shape()))
+ return dims
+
+
+def representative_tf_dataset(model: str) -> Callable:
+ """Sample dataset used for quantization."""
+ if not (input_shape := get_tf_tensor_shape(model)):
+ raise Exception("Unable to get input shape")
+
+ def dataset() -> Iterable:
+ for _ in range(100):
+ data = np.random.rand(*input_shape)
+ yield [data.astype(np.float32)]
+
+ return dataset
+
+
+def convert_to_tflite(model: tf.keras.Model, quantized: bool = False) -> bytes:
+ """Convert Keras model to TFLite."""
+ if not isinstance(model, tf.keras.Model):
+ raise Exception("Invalid model type")
+
+ converter = tf.lite.TFLiteConverter.from_keras_model(model)
+
+ if quantized:
+ converter.optimizations = [tf.lite.Optimize.DEFAULT]
+ converter.representative_dataset = representative_dataset(model)
+ converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
+ converter.inference_input_type = tf.int8
+ converter.inference_output_type = tf.int8
+
+ with redirect_output(logging.getLogger("tensorflow")):
+ tflite_model = converter.convert()
+
+ return tflite_model
+
+
+def convert_tf_to_tflite(model: str, quantized: bool = False) -> bytes:
+ """Convert TensorFlow model to TFLite."""
+ if not isinstance(model, str):
+ raise Exception("Invalid model type")
+
+ converter = tf.lite.TFLiteConverter.from_saved_model(model)
+
+ if quantized:
+ converter.optimizations = [tf.lite.Optimize.DEFAULT]
+ converter.representative_dataset = representative_tf_dataset(model)
+ converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
+ converter.inference_input_type = tf.int8
+ converter.inference_output_type = tf.int8
+
+ with redirect_output(logging.getLogger("tensorflow")):
+ tflite_model = converter.convert()
+
+ return tflite_model
+
+
+def save_keras_model(model: tf.keras.Model, save_path: Union[str, Path]) -> None:
+ """Save Keras model at provided path."""
+    # Saving the optimizer state is necessary to resume training from a checkpoint.
+ model.save(save_path, include_optimizer=True)
+
+
+def save_tflite_model(model: bytes, save_path: Union[str, Path]) -> None:
+ """Save TFLite model at provided path."""
+ with open(save_path, "wb") as file:
+ file.write(model)
+
+
+def is_tflite_model(model: Union[Path, str]) -> bool:
+ """Check if model type is supported by TFLite API.
+
+    A TFLite model is indicated by the .tflite file extension.
+ """
+ model_path = Path(model)
+ return model_path.suffix == ".tflite"
+
+
+def is_keras_model(model: Union[Path, str]) -> bool:
+ """Check if model type is supported by Keras API.
+
+    A Keras model is indicated by either:
+    1. a directory (a saved model) that contains a keras_metadata.pb file, or
+    2. a file with the .h5/.hdf5 extension.
+ """
+ model_path = Path(model)
+
+ if model_path.is_dir():
+ return (model_path / "keras_metadata.pb").exists()
+ return model_path.suffix in (".h5", ".hdf5")
+
+
+def is_tf_model(model: Union[Path, str]) -> bool:
+ """Check if model type is supported by TensorFlow API.
+
+    A TensorFlow SavedModel is indicated by a directory that does not
+    contain a keras_metadata.pb file.
+ """
+ model_path = Path(model)
+ return model_path.is_dir() and not is_keras_model(model)
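+
+
+# Illustrative usage sketch (assumes `keras_model` is a tf.keras.Model; the
+# output path is hypothetical): full int8 quantization driven by the random
+# representative dataset defined above.
+#
+#     tflite_bytes = convert_to_tflite(keras_model, quantized=True)
+#     save_tflite_model(tflite_bytes, "model_int8.tflite")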
diff --git a/src/mlia/resources/aiet/applications/APPLICATIONS.txt b/src/mlia/resources/aiet/applications/APPLICATIONS.txt
new file mode 100644
index 0000000..09127f8
--- /dev/null
+++ b/src/mlia/resources/aiet/applications/APPLICATIONS.txt
@@ -0,0 +1,6 @@
+SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+SPDX-License-Identifier: Apache-2.0
+
+This directory contains the Generic Inference Runner application packages for AIET.
+
+Each package should contain its own aiet-config.json file.
diff --git a/src/mlia/resources/aiet/applications/inference_runner-sse-300-22.05.01-ethos-U55-Shared_Sram-TA/aiet-config.json b/src/mlia/resources/aiet/applications/inference_runner-sse-300-22.05.01-ethos-U55-Shared_Sram-TA/aiet-config.json
new file mode 100644
index 0000000..757ccd1
--- /dev/null
+++ b/src/mlia/resources/aiet/applications/inference_runner-sse-300-22.05.01-ethos-U55-Shared_Sram-TA/aiet-config.json
@@ -0,0 +1,18 @@
+[
+ {
+ "name": "Generic Inference Runner: Ethos-U55/65 Shared SRAM",
+ "description": "This application allows running inferences using custom NN TFLite models on Ethos-U. No data pre-/post-processing is executed.",
+ "supported_systems": [
+ {
+ "name": "Corstone-300: Cortex-M55+Ethos-U55"
+ },
+ {
+ "name": "Corstone-300: Cortex-M55+Ethos-U65"
+ }
+ ],
+ "lock": true,
+ "variables": {
+ "eval_app": "{software.config_dir}/ethos-u-inference_runner.axf"
+ }
+ }
+]
diff --git a/src/mlia/resources/aiet/applications/inference_runner-sse-300-22.05.01-ethos-U55-Shared_Sram-TA/aiet-config.json.license b/src/mlia/resources/aiet/applications/inference_runner-sse-300-22.05.01-ethos-U55-Shared_Sram-TA/aiet-config.json.license
new file mode 100644
index 0000000..9b83bfc
--- /dev/null
+++ b/src/mlia/resources/aiet/applications/inference_runner-sse-300-22.05.01-ethos-U55-Shared_Sram-TA/aiet-config.json.license
@@ -0,0 +1,3 @@
+SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+
+SPDX-License-Identifier: Apache-2.0
diff --git a/src/mlia/resources/aiet/applications/inference_runner-sse-300-22.05.01-ethos-U55-Shared_Sram-TA/ethos-u-inference_runner.axf b/src/mlia/resources/aiet/applications/inference_runner-sse-300-22.05.01-ethos-U55-Shared_Sram-TA/ethos-u-inference_runner.axf
new file mode 100644
index 0000000..4c50e1f
--- /dev/null
+++ b/src/mlia/resources/aiet/applications/inference_runner-sse-300-22.05.01-ethos-U55-Shared_Sram-TA/ethos-u-inference_runner.axf
Binary files differ
diff --git a/src/mlia/resources/aiet/applications/inference_runner-sse-300-22.05.01-ethos-U55-Shared_Sram-TA/ethos-u-inference_runner.axf.license b/src/mlia/resources/aiet/applications/inference_runner-sse-300-22.05.01-ethos-U55-Shared_Sram-TA/ethos-u-inference_runner.axf.license
new file mode 100644
index 0000000..8896f92
--- /dev/null
+++ b/src/mlia/resources/aiet/applications/inference_runner-sse-300-22.05.01-ethos-U55-Shared_Sram-TA/ethos-u-inference_runner.axf.license
@@ -0,0 +1,31 @@
+SPDX-FileCopyrightText: Copyright 2009-2022, Arm Limited and/or its affiliates.
+SPDX-FileCopyrightText: Copyright (c) 2006, 2008 Junio C Hamano
+SPDX-FileCopyrightText: Copyright 2011, The Dojo Foundation
+SPDX-FileCopyrightText: Copyright (c) 1999-2009 KEIL, 2009-2016 ARM Germany GmbH. All rights reserved.
+SPDX-FileCopyrightText: Copyright (c) 2006, 2007 CodeSourcery Inc
+SPDX-FileCopyrightText: Copyright (c) 2010 "Cowboy" Ben Alman
+SPDX-FileCopyrightText: Copyright 1997-2016 Freescale Semiconductor, Inc.
+SPDX-FileCopyrightText: Copyright 2011, AUTHORS.txt (http://jqueryui.com/about)
+SPDX-FileCopyrightText: Copyright 2011, John Resig
+SPDX-FileCopyrightText: Copyright 2016-2021 NXP
+SPDX-FileCopyrightText: Copyright (c) 2012 mbed.org
+SPDX-FileCopyrightText: Copyright (c) 2012-2017 Keil Software. All rights reserved.
+SPDX-FileCopyrightText: Copyright (C) 2009 by Dimitri van Heesch.
+SPDX-FileCopyrightText: Copyright (c) 2017-2021 IAR Systems
+SPDX-FileCopyrightText: Copyright (C) 2003-2008 Greg Valure
+SPDX-FileCopyrightText: Copyright 2015 gRPC authors.
+SPDX-FileCopyrightText: Copyright 2018 Dan Field. All rights reserved.
+SPDX-FileCopyrightText: Copyright 2014 Stefan.Eilemann@epfl.ch
+SPDX-FileCopyrightText: Copyright (c) 2016, the Dart project authors
+SPDX-FileCopyrightText: Copyright 2015 The Chromium Authors
+SPDX-FileCopyrightText: Copyright (C) 2019 Free Software Foundation, Inc.
+SPDX-FileCopyrightText: Copyright (c) 2021, Vasyl Gello.
+SPDX-FileCopyrightText: Copyright 2020 Jan Tojnar
+SPDX-FileCopyrightText: Copyright 2017-2022 The TensorFlow Authors. All Rights Reserved.
+SPDX-FileCopyrightText: Copyright 2014-2022 Google Inc. All rights reserved.
+SPDX-FileCopyrightText: Copyright 2015-2018 The Gemmlowp Authors. All rights reserved.
+SPDX-FileCopyrightText: Copyright (c) 2003-2019, Mark Borgerding. All rights reserved.
+SPDX-FileCopyrightText: Copyright 2019-2022 The Pigweed Authors
+SPDX-FileCopyrightText: Copyright 2019-2021 Google LLC. All Rights Reserved.
+
+SPDX-License-Identifier: Apache-2.0 AND BSD-3-Clause and CC-PDDC
diff --git a/src/mlia/resources/aiet/applications/inference_runner-sse-300-22.05.01-ethos-U55-Sram_Only-TA/aiet-config.json b/src/mlia/resources/aiet/applications/inference_runner-sse-300-22.05.01-ethos-U55-Sram_Only-TA/aiet-config.json
new file mode 100644
index 0000000..cb7e113
--- /dev/null
+++ b/src/mlia/resources/aiet/applications/inference_runner-sse-300-22.05.01-ethos-U55-Sram_Only-TA/aiet-config.json
@@ -0,0 +1,15 @@
+[
+ {
+ "name": "Generic Inference Runner: Ethos-U55 SRAM",
+ "description": "This application allows running inferences using custom NN TFLite models on Ethos-U. No data pre-/post-processing is executed.",
+ "supported_systems": [
+ {
+ "name": "Corstone-300: Cortex-M55+Ethos-U55"
+ }
+ ],
+ "lock": true,
+ "variables": {
+ "eval_app": "{software.config_dir}/ethos-u-inference_runner.axf"
+ }
+ }
+]
diff --git a/src/mlia/resources/aiet/applications/inference_runner-sse-300-22.05.01-ethos-U55-Sram_Only-TA/aiet-config.json.license b/src/mlia/resources/aiet/applications/inference_runner-sse-300-22.05.01-ethos-U55-Sram_Only-TA/aiet-config.json.license
new file mode 100644
index 0000000..9b83bfc
--- /dev/null
+++ b/src/mlia/resources/aiet/applications/inference_runner-sse-300-22.05.01-ethos-U55-Sram_Only-TA/aiet-config.json.license
@@ -0,0 +1,3 @@
+SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+
+SPDX-License-Identifier: Apache-2.0
diff --git a/src/mlia/resources/aiet/applications/inference_runner-sse-300-22.05.01-ethos-U55-Sram_Only-TA/ethos-u-inference_runner.axf b/src/mlia/resources/aiet/applications/inference_runner-sse-300-22.05.01-ethos-U55-Sram_Only-TA/ethos-u-inference_runner.axf
new file mode 100644
index 0000000..850e2eb
--- /dev/null
+++ b/src/mlia/resources/aiet/applications/inference_runner-sse-300-22.05.01-ethos-U55-Sram_Only-TA/ethos-u-inference_runner.axf
Binary files differ
diff --git a/src/mlia/resources/aiet/applications/inference_runner-sse-300-22.05.01-ethos-U55-Sram_Only-TA/ethos-u-inference_runner.axf.license b/src/mlia/resources/aiet/applications/inference_runner-sse-300-22.05.01-ethos-U55-Sram_Only-TA/ethos-u-inference_runner.axf.license
new file mode 100644
index 0000000..8896f92
--- /dev/null
+++ b/src/mlia/resources/aiet/applications/inference_runner-sse-300-22.05.01-ethos-U55-Sram_Only-TA/ethos-u-inference_runner.axf.license
@@ -0,0 +1,31 @@
+SPDX-FileCopyrightText: Copyright 2009-2022, Arm Limited and/or its affiliates.
+SPDX-FileCopyrightText: Copyright (c) 2006, 2008 Junio C Hamano
+SPDX-FileCopyrightText: Copyright 2011, The Dojo Foundation
+SPDX-FileCopyrightText: Copyright (c) 1999-2009 KEIL, 2009-2016 ARM Germany GmbH. All rights reserved.
+SPDX-FileCopyrightText: Copyright (c) 2006, 2007 CodeSourcery Inc
+SPDX-FileCopyrightText: Copyright (c) 2010 "Cowboy" Ben Alman
+SPDX-FileCopyrightText: Copyright 1997-2016 Freescale Semiconductor, Inc.
+SPDX-FileCopyrightText: Copyright 2011, AUTHORS.txt (http://jqueryui.com/about)
+SPDX-FileCopyrightText: Copyright 2011, John Resig
+SPDX-FileCopyrightText: Copyright 2016-2021 NXP
+SPDX-FileCopyrightText: Copyright (c) 2012 mbed.org
+SPDX-FileCopyrightText: Copyright (c) 2012-2017 Keil Software. All rights reserved.
+SPDX-FileCopyrightText: Copyright (C) 2009 by Dimitri van Heesch.
+SPDX-FileCopyrightText: Copyright (c) 2017-2021 IAR Systems
+SPDX-FileCopyrightText: Copyright (C) 2003-2008 Greg Valure
+SPDX-FileCopyrightText: Copyright 2015 gRPC authors.
+SPDX-FileCopyrightText: Copyright 2018 Dan Field. All rights reserved.
+SPDX-FileCopyrightText: Copyright 2014 Stefan.Eilemann@epfl.ch
+SPDX-FileCopyrightText: Copyright (c) 2016, the Dart project authors
+SPDX-FileCopyrightText: Copyright 2015 The Chromium Authors
+SPDX-FileCopyrightText: Copyright (C) 2019 Free Software Foundation, Inc.
+SPDX-FileCopyrightText: Copyright (c) 2021, Vasyl Gello.
+SPDX-FileCopyrightText: Copyright 2020 Jan Tojnar
+SPDX-FileCopyrightText: Copyright 2017-2022 The TensorFlow Authors. All Rights Reserved.
+SPDX-FileCopyrightText: Copyright 2014-2022 Google Inc. All rights reserved.
+SPDX-FileCopyrightText: Copyright 2015-2018 The Gemmlowp Authors. All rights reserved.
+SPDX-FileCopyrightText: Copyright (c) 2003-2019, Mark Borgerding. All rights reserved.
+SPDX-FileCopyrightText: Copyright 2019-2022 The Pigweed Authors
+SPDX-FileCopyrightText: Copyright 2019-2021 Google LLC. All Rights Reserved.
+
+SPDX-License-Identifier: Apache-2.0 AND BSD-3-Clause and CC-PDDC
diff --git a/src/mlia/resources/aiet/applications/inference_runner-sse-300-22.05.01-ethos-U65-Dedicated_Sram-TA/aiet-config.json b/src/mlia/resources/aiet/applications/inference_runner-sse-300-22.05.01-ethos-U65-Dedicated_Sram-TA/aiet-config.json
new file mode 100644
index 0000000..d524f64
--- /dev/null
+++ b/src/mlia/resources/aiet/applications/inference_runner-sse-300-22.05.01-ethos-U65-Dedicated_Sram-TA/aiet-config.json
@@ -0,0 +1,15 @@
+[
+ {
+ "name": "Generic Inference Runner: Ethos-U65 Dedicated SRAM",
+ "description": "This application allows running inferences using custom NN TFLite models on Ethos-U. No data pre-/post-processing is executed.",
+ "supported_systems": [
+ {
+ "name": "Corstone-300: Cortex-M55+Ethos-U65"
+ }
+ ],
+ "lock": true,
+ "variables": {
+ "eval_app": "{software.config_dir}/ethos-u-inference_runner.axf"
+ }
+ }
+]
diff --git a/src/mlia/resources/aiet/applications/inference_runner-sse-300-22.05.01-ethos-U65-Dedicated_Sram-TA/aiet-config.json.license b/src/mlia/resources/aiet/applications/inference_runner-sse-300-22.05.01-ethos-U65-Dedicated_Sram-TA/aiet-config.json.license
new file mode 100644
index 0000000..9b83bfc
--- /dev/null
+++ b/src/mlia/resources/aiet/applications/inference_runner-sse-300-22.05.01-ethos-U65-Dedicated_Sram-TA/aiet-config.json.license
@@ -0,0 +1,3 @@
+SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+
+SPDX-License-Identifier: Apache-2.0
diff --git a/src/mlia/resources/aiet/applications/inference_runner-sse-300-22.05.01-ethos-U65-Dedicated_Sram-TA/ethos-u-inference_runner.axf b/src/mlia/resources/aiet/applications/inference_runner-sse-300-22.05.01-ethos-U65-Dedicated_Sram-TA/ethos-u-inference_runner.axf
new file mode 100644
index 0000000..f881bb8
--- /dev/null
+++ b/src/mlia/resources/aiet/applications/inference_runner-sse-300-22.05.01-ethos-U65-Dedicated_Sram-TA/ethos-u-inference_runner.axf
Binary files differ
diff --git a/src/mlia/resources/aiet/applications/inference_runner-sse-300-22.05.01-ethos-U65-Dedicated_Sram-TA/ethos-u-inference_runner.axf.license b/src/mlia/resources/aiet/applications/inference_runner-sse-300-22.05.01-ethos-U65-Dedicated_Sram-TA/ethos-u-inference_runner.axf.license
new file mode 100644
index 0000000..8896f92
--- /dev/null
+++ b/src/mlia/resources/aiet/applications/inference_runner-sse-300-22.05.01-ethos-U65-Dedicated_Sram-TA/ethos-u-inference_runner.axf.license
@@ -0,0 +1,31 @@
+SPDX-FileCopyrightText: Copyright 2009-2022, Arm Limited and/or its affiliates.
+SPDX-FileCopyrightText: Copyright (c) 2006, 2008 Junio C Hamano
+SPDX-FileCopyrightText: Copyright 2011, The Dojo Foundation
+SPDX-FileCopyrightText: Copyright (c) 1999-2009 KEIL, 2009-2016 ARM Germany GmbH. All rights reserved.
+SPDX-FileCopyrightText: Copyright (c) 2006, 2007 CodeSourcery Inc
+SPDX-FileCopyrightText: Copyright (c) 2010 "Cowboy" Ben Alman
+SPDX-FileCopyrightText: Copyright 1997-2016 Freescale Semiconductor, Inc.
+SPDX-FileCopyrightText: Copyright 2011, AUTHORS.txt (http://jqueryui.com/about)
+SPDX-FileCopyrightText: Copyright 2011, John Resig
+SPDX-FileCopyrightText: Copyright 2016-2021 NXP
+SPDX-FileCopyrightText: Copyright (c) 2012 mbed.org
+SPDX-FileCopyrightText: Copyright (c) 2012-2017 Keil Software. All rights reserved.
+SPDX-FileCopyrightText: Copyright (C) 2009 by Dimitri van Heesch.
+SPDX-FileCopyrightText: Copyright (c) 2017-2021 IAR Systems
+SPDX-FileCopyrightText: Copyright (C) 2003-2008 Greg Valure
+SPDX-FileCopyrightText: Copyright 2015 gRPC authors.
+SPDX-FileCopyrightText: Copyright 2018 Dan Field. All rights reserved.
+SPDX-FileCopyrightText: Copyright 2014 Stefan.Eilemann@epfl.ch
+SPDX-FileCopyrightText: Copyright (c) 2016, the Dart project authors
+SPDX-FileCopyrightText: Copyright 2015 The Chromium Authors
+SPDX-FileCopyrightText: Copyright (C) 2019 Free Software Foundation, Inc.
+SPDX-FileCopyrightText: Copyright (c) 2021, Vasyl Gello.
+SPDX-FileCopyrightText: Copyright 2020 Jan Tojnar
+SPDX-FileCopyrightText: Copyright 2017-2022 The TensorFlow Authors. All Rights Reserved.
+SPDX-FileCopyrightText: Copyright 2014-2022 Google Inc. All rights reserved.
+SPDX-FileCopyrightText: Copyright 2015-2018 The Gemmlowp Authors. All rights reserved.
+SPDX-FileCopyrightText: Copyright (c) 2003-2019, Mark Borgerding. All rights reserved.
+SPDX-FileCopyrightText: Copyright 2019-2022 The Pigweed Authors
+SPDX-FileCopyrightText: Copyright 2019-2021 Google LLC. All Rights Reserved.
+
+SPDX-License-Identifier: Apache-2.0 AND BSD-3-Clause and CC-PDDC
diff --git a/src/mlia/resources/aiet/applications/inference_runner-sse-310-22.05.01-ethos-U55-Shared_Sram-TA/aiet-config.json b/src/mlia/resources/aiet/applications/inference_runner-sse-310-22.05.01-ethos-U55-Shared_Sram-TA/aiet-config.json
new file mode 100644
index 0000000..2cbab70
--- /dev/null
+++ b/src/mlia/resources/aiet/applications/inference_runner-sse-310-22.05.01-ethos-U55-Shared_Sram-TA/aiet-config.json
@@ -0,0 +1,15 @@
+[
+ {
+ "name": "Generic Inference Runner: Ethos-U55/65 Shared SRAM",
+ "description": "This application allows running inferences using custom NN TFLite models on Ethos-U. No data pre-/post-processing is executed.",
+ "supported_systems": [
+ {
+ "name": "Corstone-310: Cortex-M85+Ethos-U55"
+ }
+ ],
+ "lock": true,
+ "variables": {
+ "eval_app": "{software.config_dir}/ethos-u-inference_runner.axf"
+ }
+ }
+]
diff --git a/src/mlia/resources/aiet/applications/inference_runner-sse-310-22.05.01-ethos-U55-Shared_Sram-TA/aiet-config.json.license b/src/mlia/resources/aiet/applications/inference_runner-sse-310-22.05.01-ethos-U55-Shared_Sram-TA/aiet-config.json.license
new file mode 100644
index 0000000..9b83bfc
--- /dev/null
+++ b/src/mlia/resources/aiet/applications/inference_runner-sse-310-22.05.01-ethos-U55-Shared_Sram-TA/aiet-config.json.license
@@ -0,0 +1,3 @@
+SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+
+SPDX-License-Identifier: Apache-2.0
diff --git a/src/mlia/resources/aiet/applications/inference_runner-sse-310-22.05.01-ethos-U55-Shared_Sram-TA/ethos-u-inference_runner.axf b/src/mlia/resources/aiet/applications/inference_runner-sse-310-22.05.01-ethos-U55-Shared_Sram-TA/ethos-u-inference_runner.axf
new file mode 100644
index 0000000..846ee33
--- /dev/null
+++ b/src/mlia/resources/aiet/applications/inference_runner-sse-310-22.05.01-ethos-U55-Shared_Sram-TA/ethos-u-inference_runner.axf
Binary files differ
diff --git a/src/mlia/resources/aiet/applications/inference_runner-sse-310-22.05.01-ethos-U55-Shared_Sram-TA/ethos-u-inference_runner.axf.license b/src/mlia/resources/aiet/applications/inference_runner-sse-310-22.05.01-ethos-U55-Shared_Sram-TA/ethos-u-inference_runner.axf.license
new file mode 100644
index 0000000..8896f92
--- /dev/null
+++ b/src/mlia/resources/aiet/applications/inference_runner-sse-310-22.05.01-ethos-U55-Shared_Sram-TA/ethos-u-inference_runner.axf.license
@@ -0,0 +1,31 @@
+SPDX-FileCopyrightText: Copyright 2009-2022, Arm Limited and/or its affiliates.
+SPDX-FileCopyrightText: Copyright (c) 2006, 2008 Junio C Hamano
+SPDX-FileCopyrightText: Copyright 2011, The Dojo Foundation
+SPDX-FileCopyrightText: Copyright (c) 1999-2009 KEIL, 2009-2016 ARM Germany GmbH. All rights reserved.
+SPDX-FileCopyrightText: Copyright (c) 2006, 2007 CodeSourcery Inc
+SPDX-FileCopyrightText: Copyright (c) 2010 "Cowboy" Ben Alman
+SPDX-FileCopyrightText: Copyright 1997-2016 Freescale Semiconductor, Inc.
+SPDX-FileCopyrightText: Copyright 2011, AUTHORS.txt (http://jqueryui.com/about)
+SPDX-FileCopyrightText: Copyright 2011, John Resig
+SPDX-FileCopyrightText: Copyright 2016-2021 NXP
+SPDX-FileCopyrightText: Copyright (c) 2012 mbed.org
+SPDX-FileCopyrightText: Copyright (c) 2012-2017 Keil Software. All rights reserved.
+SPDX-FileCopyrightText: Copyright (C) 2009 by Dimitri van Heesch.
+SPDX-FileCopyrightText: Copyright (c) 2017-2021 IAR Systems
+SPDX-FileCopyrightText: Copyright (C) 2003-2008 Greg Valure
+SPDX-FileCopyrightText: Copyright 2015 gRPC authors.
+SPDX-FileCopyrightText: Copyright 2018 Dan Field. All rights reserved.
+SPDX-FileCopyrightText: Copyright 2014 Stefan.Eilemann@epfl.ch
+SPDX-FileCopyrightText: Copyright (c) 2016, the Dart project authors
+SPDX-FileCopyrightText: Copyright 2015 The Chromium Authors
+SPDX-FileCopyrightText: Copyright (C) 2019 Free Software Foundation, Inc.
+SPDX-FileCopyrightText: Copyright (c) 2021, Vasyl Gello.
+SPDX-FileCopyrightText: Copyright 2020 Jan Tojnar
+SPDX-FileCopyrightText: Copyright 2017-2022 The TensorFlow Authors. All Rights Reserved.
+SPDX-FileCopyrightText: Copyright 2014-2022 Google Inc. All rights reserved.
+SPDX-FileCopyrightText: Copyright 2015-2018 The Gemmlowp Authors. All rights reserved.
+SPDX-FileCopyrightText: Copyright (c) 2003-2019, Mark Borgerding. All rights reserved.
+SPDX-FileCopyrightText: Copyright 2019-2022 The Pigweed Authors
+SPDX-FileCopyrightText: Copyright 2019-2021 Google LLC. All Rights Reserved.
+
+SPDX-License-Identifier: Apache-2.0 AND BSD-3-Clause and CC-PDDC
diff --git a/src/mlia/resources/aiet/applications/inference_runner-sse-310-22.05.01-ethos-U55-Sram_Only-TA/aiet-config.json b/src/mlia/resources/aiet/applications/inference_runner-sse-310-22.05.01-ethos-U55-Sram_Only-TA/aiet-config.json
new file mode 100644
index 0000000..01bec74
--- /dev/null
+++ b/src/mlia/resources/aiet/applications/inference_runner-sse-310-22.05.01-ethos-U55-Sram_Only-TA/aiet-config.json
@@ -0,0 +1,15 @@
+[
+ {
+ "name": "Generic Inference Runner: Ethos-U55 SRAM",
+ "description": "This application allows running inferences using custom NN TFLite models on Ethos-U. No data pre-/post-processing is executed.",
+ "supported_systems": [
+ {
+ "name": "Corstone-310: Cortex-M85+Ethos-U55"
+ }
+ ],
+ "lock": true,
+ "variables": {
+ "eval_app": "{software.config_dir}/ethos-u-inference_runner.axf"
+ }
+ }
+]
diff --git a/src/mlia/resources/aiet/applications/inference_runner-sse-310-22.05.01-ethos-U55-Sram_Only-TA/aiet-config.json.license b/src/mlia/resources/aiet/applications/inference_runner-sse-310-22.05.01-ethos-U55-Sram_Only-TA/aiet-config.json.license
new file mode 100644
index 0000000..9b83bfc
--- /dev/null
+++ b/src/mlia/resources/aiet/applications/inference_runner-sse-310-22.05.01-ethos-U55-Sram_Only-TA/aiet-config.json.license
@@ -0,0 +1,3 @@
+SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+
+SPDX-License-Identifier: Apache-2.0
diff --git a/src/mlia/resources/aiet/applications/inference_runner-sse-310-22.05.01-ethos-U55-Sram_Only-TA/ethos-u-inference_runner.axf b/src/mlia/resources/aiet/applications/inference_runner-sse-310-22.05.01-ethos-U55-Sram_Only-TA/ethos-u-inference_runner.axf
new file mode 100644
index 0000000..e3eab97
--- /dev/null
+++ b/src/mlia/resources/aiet/applications/inference_runner-sse-310-22.05.01-ethos-U55-Sram_Only-TA/ethos-u-inference_runner.axf
Binary files differ
diff --git a/src/mlia/resources/aiet/applications/inference_runner-sse-310-22.05.01-ethos-U55-Sram_Only-TA/ethos-u-inference_runner.axf.license b/src/mlia/resources/aiet/applications/inference_runner-sse-310-22.05.01-ethos-U55-Sram_Only-TA/ethos-u-inference_runner.axf.license
new file mode 100644
index 0000000..8896f92
--- /dev/null
+++ b/src/mlia/resources/aiet/applications/inference_runner-sse-310-22.05.01-ethos-U55-Sram_Only-TA/ethos-u-inference_runner.axf.license
@@ -0,0 +1,31 @@
+SPDX-FileCopyrightText: Copyright 2009-2022, Arm Limited and/or its affiliates.
+SPDX-FileCopyrightText: Copyright (c) 2006, 2008 Junio C Hamano
+SPDX-FileCopyrightText: Copyright 2011, The Dojo Foundation
+SPDX-FileCopyrightText: Copyright (c) 1999-2009 KEIL, 2009-2016 ARM Germany GmbH. All rights reserved.
+SPDX-FileCopyrightText: Copyright (c) 2006, 2007 CodeSourcery Inc
+SPDX-FileCopyrightText: Copyright (c) 2010 "Cowboy" Ben Alman
+SPDX-FileCopyrightText: Copyright 1997-2016 Freescale Semiconductor, Inc.
+SPDX-FileCopyrightText: Copyright 2011, AUTHORS.txt (http://jqueryui.com/about)
+SPDX-FileCopyrightText: Copyright 2011, John Resig
+SPDX-FileCopyrightText: Copyright 2016-2021 NXP
+SPDX-FileCopyrightText: Copyright (c) 2012 mbed.org
+SPDX-FileCopyrightText: Copyright (c) 2012-2017 Keil Software. All rights reserved.
+SPDX-FileCopyrightText: Copyright (C) 2009 by Dimitri van Heesch.
+SPDX-FileCopyrightText: Copyright (c) 2017-2021 IAR Systems
+SPDX-FileCopyrightText: Copyright (C) 2003-2008 Greg Valure
+SPDX-FileCopyrightText: Copyright 2015 gRPC authors.
+SPDX-FileCopyrightText: Copyright 2018 Dan Field. All rights reserved.
+SPDX-FileCopyrightText: Copyright 2014 Stefan.Eilemann@epfl.ch
+SPDX-FileCopyrightText: Copyright (c) 2016, the Dart project authors
+SPDX-FileCopyrightText: Copyright 2015 The Chromium Authors
+SPDX-FileCopyrightText: Copyright (C) 2019 Free Software Foundation, Inc.
+SPDX-FileCopyrightText: Copyright (c) 2021, Vasyl Gello.
+SPDX-FileCopyrightText: Copyright 2020 Jan Tojnar
+SPDX-FileCopyrightText: Copyright 2017-2022 The TensorFlow Authors. All Rights Reserved.
+SPDX-FileCopyrightText: Copyright 2014-2022 Google Inc. All rights reserved.
+SPDX-FileCopyrightText: Copyright 2015-2018 The Gemmlowp Authors. All rights reserved.
+SPDX-FileCopyrightText: Copyright (c) 2003-2019, Mark Borgerding. All rights reserved.
+SPDX-FileCopyrightText: Copyright 2019-2022 The Pigweed Authors
+SPDX-FileCopyrightText: Copyright 2019-2021 Google LLC. All Rights Reserved.
+
+SPDX-License-Identifier: Apache-2.0 AND BSD-3-Clause AND CC-PDDC
diff --git a/src/mlia/resources/aiet/systems/SYSTEMS.txt b/src/mlia/resources/aiet/systems/SYSTEMS.txt
new file mode 100644
index 0000000..bc27e73
--- /dev/null
+++ b/src/mlia/resources/aiet/systems/SYSTEMS.txt
@@ -0,0 +1,10 @@
+SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+SPDX-License-Identifier: Apache-2.0
+
+This directory contains the configuration files of the systems for the AIET
+middleware.
+
+Supported systems:
+
+ * FVP Corstone-300 Ecosystem
+ * FVP Corstone-310 Ecosystem
diff --git a/src/mlia/resources/aiet/systems/corstone-300-vht/aiet-config.json b/src/mlia/resources/aiet/systems/corstone-300-vht/aiet-config.json
new file mode 100644
index 0000000..3ffa548
--- /dev/null
+++ b/src/mlia/resources/aiet/systems/corstone-300-vht/aiet-config.json
@@ -0,0 +1,80 @@
+[
+ {
+ "name": "Corstone-300: Cortex-M55+Ethos-U55",
+ "description": "Cortex-M55 and Ethos-U55 functional model implementations based on Corstone-300 design for MPS3 board.",
+ "annotations": {
+ "ip_class": "Ethos-U55",
+ "sim_type": "FM",
+ "variant": "Cortex-M55+Ethos-U55"
+ },
+ "data_transfer": {
+ "protocol": "local"
+ },
+ "lock": true,
+ "commands": {
+ "run": [
+ "/opt/VHT/VHT_Corstone_SSE-300_Ethos-U55 -a {software.variables:eval_app} {user_params:input_file}@0x90000000 -C {user_params:mac} -C mps3_board.telnetterminal0.start_telnet=0 -C mps3_board.uart0.out_file='-' -C mps3_board.uart0.shutdown_on_eot=1 -C mps3_board.visualisation.disable-visualisation=1 --stat"
+ ]
+ },
+ "user_params": {
+ "run": [
+ {
+ "name": "--data",
+ "description": "Full file name for a custom model. Model must be in TFLite format compiled with Vela.",
+ "values": [],
+ "alias": "input_file"
+ },
+ {
+ "name": "ethosu.num_macs=",
+ "description": "Arm Ethos-U55 configuration - the number represents MACs per cycle.",
+ "values": [
+ "32",
+ "64",
+ "128",
+ "256"
+ ],
+ "default_value": "256",
+ "alias": "mac"
+ }
+ ]
+ }
+ },
+ {
+ "name": "Corstone-300: Cortex-M55+Ethos-U65",
+ "description": "Cortex-M55 and Ethos-U65 functional model implementations based on Corstone-300 design for MPS3 board.",
+ "annotations": {
+ "ip_class": "Ethos-U65",
+ "sim_type": "FM",
+ "variant": "Cortex-M55+Ethos-U65"
+ },
+ "data_transfer": {
+ "protocol": "local"
+ },
+ "lock": true,
+ "commands": {
+ "run": [
+ "/opt/VHT/VHT_Corstone_SSE-300_Ethos-U65 -a {software.variables:eval_app} {user_params:input_file}@0x90000000 -C {user_params:mac} -C mps3_board.telnetterminal0.start_telnet=0 -C mps3_board.uart0.out_file='-' -C mps3_board.uart0.shutdown_on_eot=1 -C mps3_board.visualisation.disable-visualisation=1 --stat"
+ ]
+ },
+ "user_params": {
+ "run": [
+ {
+ "name": "--data",
+ "description": "Full file name for a custom model. Model must be in TFLite format compiled with Vela.",
+ "values": [],
+ "alias": "input_file"
+ },
+ {
+ "name": "ethosu.num_macs=",
+ "description": "Arm Ethos-U65 configuration - the number represents MACs per cycle.",
+ "values": [
+ "256",
+ "512"
+ ],
+ "default_value": "512",
+ "alias": "mac"
+ }
+ ]
+ }
+ }
+]
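Note on how these entries are consumed: when AIET launches a system, the placeholders in the run command are expanded from the selected application and the user_params above. A sketch of the expansion for the Ethos-U55 entry with the default MAC setting (paths are illustrative; the exact substitution rules live in the AIET backend):

    /opt/VHT/VHT_Corstone_SSE-300_Ethos-U55 \
        -a <config_dir>/ethos-u-inference_runner.axf \
        --data /tmp/model_vela.tflite@0x90000000 \
        -C ethosu.num_macs=256 \
        ... (remaining -C options as listed above) --stat

Here {software.variables:eval_app} resolves to the application's eval_app variable, while {user_params:input_file} and {user_params:mac} resolve to the --data and ethosu.num_macs= parameters defined in user_params.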
diff --git a/src/mlia/resources/aiet/systems/corstone-300-vht/aiet-config.json.license b/src/mlia/resources/aiet/systems/corstone-300-vht/aiet-config.json.license
new file mode 100644
index 0000000..9b83bfc
--- /dev/null
+++ b/src/mlia/resources/aiet/systems/corstone-300-vht/aiet-config.json.license
@@ -0,0 +1,3 @@
+SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+
+SPDX-License-Identifier: Apache-2.0
diff --git a/src/mlia/resources/aiet/systems/corstone-300/aiet-config.json b/src/mlia/resources/aiet/systems/corstone-300/aiet-config.json
new file mode 100644
index 0000000..6d6785d
--- /dev/null
+++ b/src/mlia/resources/aiet/systems/corstone-300/aiet-config.json
@@ -0,0 +1,80 @@
+[
+ {
+ "name": "Corstone-300: Cortex-M55+Ethos-U55",
+ "description": "Cortex-M55 and Ethos-U55 functional model implementations based on Corstone-300 design for MPS3 board.",
+ "annotations": {
+ "ip_class": "Ethos-U55",
+ "sim_type": "FM",
+ "variant": "Cortex-M55+Ethos-U55"
+ },
+ "data_transfer": {
+ "protocol": "local"
+ },
+ "lock": true,
+ "commands": {
+ "run": [
+ "FVP_Corstone_SSE-300_Ethos-U55 -a {software.variables:eval_app} {user_params:input_file}@0x90000000 -C {user_params:mac} -C mps3_board.telnetterminal0.start_telnet=0 -C mps3_board.uart0.out_file='-' -C mps3_board.uart0.shutdown_on_eot=1 -C mps3_board.visualisation.disable-visualisation=1 --stat"
+ ]
+ },
+ "user_params": {
+ "run": [
+ {
+ "name": "--data",
+ "description": "Full file name for a custom model. Model must be in TFLite format compiled with Vela.",
+ "values": [],
+ "alias": "input_file"
+ },
+ {
+ "name": "ethosu.num_macs=",
+ "description": "Arm Ethos-U55 configuration - the number represents MACs per cycle.",
+ "values": [
+ "32",
+ "64",
+ "128",
+ "256"
+ ],
+ "default_value": "256",
+ "alias": "mac"
+ }
+ ]
+ }
+ },
+ {
+ "name": "Corstone-300: Cortex-M55+Ethos-U65",
+ "description": "Cortex-M55 and Ethos-U65 functional model implementations based on Corstone-300 design for MPS3 board.",
+ "annotations": {
+ "ip_class": "Ethos-U65",
+ "sim_type": "FM",
+ "variant": "Cortex-M55+Ethos-U65"
+ },
+ "data_transfer": {
+ "protocol": "local"
+ },
+ "lock": true,
+ "commands": {
+ "run": [
+ "FVP_Corstone_SSE-300_Ethos-U65 -a {software.variables:eval_app} {user_params:input_file}@0x90000000 -C {user_params:mac} -C mps3_board.telnetterminal0.start_telnet=0 -C mps3_board.uart0.out_file='-' -C mps3_board.uart0.shutdown_on_eot=1 -C mps3_board.visualisation.disable-visualisation=1 --stat"
+ ]
+ },
+ "user_params": {
+ "run": [
+ {
+ "name": "--data",
+ "description": "Full file name for a custom model. Model must be in TFLite format compiled with Vela.",
+ "values": [],
+ "alias": "input_file"
+ },
+ {
+ "name": "ethosu.num_macs=",
+ "description": "Arm Ethos-U65 configuration - the number represents MACs per cycle.",
+ "values": [
+ "256",
+ "512"
+ ],
+ "default_value": "512",
+ "alias": "mac"
+ }
+ ]
+ }
+ }
+]
diff --git a/src/mlia/resources/aiet/systems/corstone-300/aiet-config.json.license b/src/mlia/resources/aiet/systems/corstone-300/aiet-config.json.license
new file mode 100644
index 0000000..9b83bfc
--- /dev/null
+++ b/src/mlia/resources/aiet/systems/corstone-300/aiet-config.json.license
@@ -0,0 +1,3 @@
+SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+
+SPDX-License-Identifier: Apache-2.0
diff --git a/src/mlia/resources/aiet/systems/corstone-310-vht/aiet-config.json b/src/mlia/resources/aiet/systems/corstone-310-vht/aiet-config.json
new file mode 100644
index 0000000..dbc2622
--- /dev/null
+++ b/src/mlia/resources/aiet/systems/corstone-310-vht/aiet-config.json
@@ -0,0 +1,42 @@
+[
+ {
+ "name": "Corstone-310: Cortex-M85+Ethos-U55",
+ "description": "Cortex-M85 and Ethos-U55 functional model implementations based on Corstone-310 design for MPS3 board.",
+ "annotations": {
+ "ip_class": "Ethos-U55",
+ "sim_type": "FM",
+ "variant": "Cortex-M85+Ethos-U55"
+ },
+ "data_transfer": {
+ "protocol": "local"
+ },
+ "lock": true,
+ "commands": {
+ "run": [
+ "/opt/VHT/VHT_Corstone_SSE-310 -a {software.variables:eval_app} {user_params:input_file}@0x90000000 -C {user_params:mac} -C mps3_board.telnetterminal0.start_telnet=0 -C mps3_board.uart0.out_file='-' -C mps3_board.uart0.shutdown_on_eot=1 -C mps3_board.visualisation.disable-visualisation=1 --stat"
+ ]
+ },
+ "user_params": {
+ "run": [
+ {
+ "name": "--data",
+ "description": "Full file name for a custom model. Model must be in TFLite format compiled with Vela.",
+ "values": [],
+ "alias": "input_file"
+ },
+ {
+ "name": "ethosu.num_macs=",
+ "description": "Arm Ethos-U55 configuration - the number represents MACs per cycle.",
+ "values": [
+ "32",
+ "64",
+ "128",
+ "256"
+ ],
+ "default_value": "256",
+ "alias": "mac"
+ }
+ ]
+ }
+ }
+]
diff --git a/src/mlia/resources/aiet/systems/corstone-310-vht/aiet-config.json.license b/src/mlia/resources/aiet/systems/corstone-310-vht/aiet-config.json.license
new file mode 100644
index 0000000..9b83bfc
--- /dev/null
+++ b/src/mlia/resources/aiet/systems/corstone-310-vht/aiet-config.json.license
@@ -0,0 +1,3 @@
+SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+
+SPDX-License-Identifier: Apache-2.0
diff --git a/src/mlia/resources/aiet/systems/corstone-310/aiet-config.json b/src/mlia/resources/aiet/systems/corstone-310/aiet-config.json
new file mode 100644
index 0000000..7aa3b0a
--- /dev/null
+++ b/src/mlia/resources/aiet/systems/corstone-310/aiet-config.json
@@ -0,0 +1,42 @@
+[
+ {
+ "name": "Corstone-310: Cortex-M85+Ethos-U55",
+ "description": "Cortex-M85 and Ethos-U55 functional model implementations based on Corstone-310 design for MPS3 board.",
+ "annotations": {
+ "ip_class": "Ethos-U55",
+ "sim_type": "FM",
+ "variant": "Cortex-M85+Ethos-U55"
+ },
+ "data_transfer": {
+ "protocol": "local"
+ },
+ "lock": true,
+ "commands": {
+ "run": [
+ "FVP_Corstone_SSE-310 -a {software.variables:eval_app} {user_params:input_file}@0x90000000 -C {user_params:mac} -C mps3_board.telnetterminal0.start_telnet=0 -C mps3_board.uart0.out_file='-' -C mps3_board.uart0.shutdown_on_eot=1 -C mps3_board.visualisation.disable-visualisation=1 --stat"
+ ]
+ },
+ "user_params": {
+ "run": [
+ {
+ "name": "--data",
+ "description": "Full file name for a custom model. Model must be in TFLite format compiled with Vela.",
+ "values": [],
+ "alias": "input_file"
+ },
+ {
+ "name": "ethosu.num_macs=",
+ "description": "Arm Ethos-U55 configuration - the number represents MACs per cycle.",
+ "values": [
+ "32",
+ "64",
+ "128",
+ "256"
+ ],
+ "default_value": "256",
+ "alias": "mac"
+ }
+ ]
+ }
+ }
+]
diff --git a/src/mlia/resources/aiet/systems/corstone-310/aiet-config.json.license b/src/mlia/resources/aiet/systems/corstone-310/aiet-config.json.license
new file mode 100644
index 0000000..9b83bfc
--- /dev/null
+++ b/src/mlia/resources/aiet/systems/corstone-310/aiet-config.json.license
@@ -0,0 +1,3 @@
+SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+
+SPDX-License-Identifier: Apache-2.0
diff --git a/src/mlia/resources/profiles.json b/src/mlia/resources/profiles.json
new file mode 100644
index 0000000..4493d7b
--- /dev/null
+++ b/src/mlia/resources/profiles.json
@@ -0,0 +1,20 @@
+{
+ "ethos-u55-256": {
+ "target": "ethos-u55",
+ "mac": 256,
+ "system_config": "Ethos_U55_High_End_Embedded",
+ "memory_mode": "Shared_Sram"
+ },
+ "ethos-u55-128": {
+ "target": "ethos-u55",
+ "mac": 128,
+ "system_config": "Ethos_U55_High_End_Embedded",
+ "memory_mode": "Shared_Sram"
+ },
+ "ethos-u65-512": {
+ "target": "ethos-u65",
+ "mac": 512,
+ "system_config": "Ethos_U65_High_End",
+ "memory_mode": "Dedicated_Sram"
+ }
+}
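These profiles are plain JSON and map one-to-one onto Vela compiler settings. A minimal sketch of reading one (the file path is illustrative; in the package the file is looked up as a resource):

    import json
    from pathlib import Path

    profiles = json.loads(Path("profiles.json").read_text())
    profile = profiles["ethos-u55-256"]

    # target + mac form the Vela accelerator config; the other two
    # fields select sections of the bundled vela.ini.
    accelerator_config = f"{profile['target']}-{profile['mac']}"
    print(accelerator_config, profile["system_config"], profile["memory_mode"])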
diff --git a/src/mlia/resources/profiles.json.license b/src/mlia/resources/profiles.json.license
new file mode 100644
index 0000000..9b83bfc
--- /dev/null
+++ b/src/mlia/resources/profiles.json.license
@@ -0,0 +1,3 @@
+SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+
+SPDX-License-Identifier: Apache-2.0
diff --git a/src/mlia/resources/vela/vela.ini b/src/mlia/resources/vela/vela.ini
new file mode 100644
index 0000000..382820d
--- /dev/null
+++ b/src/mlia/resources/vela/vela.ini
@@ -0,0 +1,75 @@
+; SPDX-FileCopyrightText: Copyright 2020, 2022, Arm Limited and/or its affiliates.
+; SPDX-License-Identifier: Apache-2.0
+
+; -----------------------------------------------------------------------------
+; Vela configuration file
+; -----------------------------------------------------------------------------
+
+; System Configuration
+
+; Ethos-U55 High-End Embedded: SRAM (4 GB/s) and Flash (0.5 GB/s)
+[System_Config.Ethos_U55_High_End_Embedded]
+core_clock=500e6
+axi0_port=Sram
+axi1_port=OffChipFlash
+Sram_clock_scale=1.0
+Sram_burst_length=32
+Sram_read_latency=32
+Sram_write_latency=32
+OffChipFlash_clock_scale=0.125
+OffChipFlash_burst_length=128
+OffChipFlash_read_latency=64
+OffChipFlash_write_latency=64
+
+; Ethos-U65 Embedded: SRAM (8 GB/s) and Flash (0.5 GB/s)
+[System_Config.Ethos_U65_Embedded]
+core_clock=500e6
+axi0_port=Sram
+axi1_port=OffChipFlash
+Sram_clock_scale=1.0
+Sram_burst_length=32
+Sram_read_latency=32
+Sram_write_latency=32
+OffChipFlash_clock_scale=0.0625
+OffChipFlash_burst_length=128
+OffChipFlash_read_latency=64
+OffChipFlash_write_latency=64
+
+; Ethos-U65 High-End: SRAM (16 GB/s) and DRAM (3.75 GB/s)
+[System_Config.Ethos_U65_High_End]
+core_clock=1e9
+axi0_port=Sram
+axi1_port=Dram
+Sram_clock_scale=1.0
+Sram_burst_length=32
+Sram_read_latency=32
+Sram_write_latency=32
+Dram_clock_scale=0.234375
+Dram_burst_length=128
+Dram_read_latency=500
+Dram_write_latency=250
+
+; -----------------------------------------------------------------------------
+
+; Memory Mode
+
+; SRAM Only: only one AXI port is used and the SRAM is used for all storage
+[Memory_Mode.Sram_Only]
+const_mem_area=Axi0
+arena_mem_area=Axi0
+cache_mem_area=Axi0
+
+; Shared SRAM: the SRAM is shared between the Ethos-U and the Cortex-M software
+; The non-SRAM memory is assumed to be read-only
+[Memory_Mode.Shared_Sram]
+const_mem_area=Axi1
+arena_mem_area=Axi0
+cache_mem_area=Axi0
+
+; Dedicated SRAM: the SRAM (384KB) is only for use by the Ethos-U
+; The non-SRAM memory is assumed to be read-writeable
+[Memory_Mode.Dedicated_Sram]
+const_mem_area=Axi1
+arena_mem_area=Axi1
+cache_mem_area=Axi0
+arena_cache_size=393216
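Sanity check of the bandwidth figures quoted in the comments above, assuming the usual 64-bit AXI interface on Ethos-U55 and 128-bit on Ethos-U65 (an assumption, not stated in this file): bandwidth = core_clock × clock_scale × bus width, so 500e6 × 1.0 × 8 B = 4 GB/s (U55 SRAM), 500e6 × 0.125 × 8 B = 0.5 GB/s (U55 flash), 1e9 × 1.0 × 16 B = 16 GB/s (U65 High-End SRAM) and 1e9 × 0.234375 × 16 B = 3.75 GB/s (U65 High-End DRAM).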
diff --git a/src/mlia/tools/__init__.py b/src/mlia/tools/__init__.py
new file mode 100644
index 0000000..184e966
--- /dev/null
+++ b/src/mlia/tools/__init__.py
@@ -0,0 +1,3 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Tools module."""
diff --git a/src/mlia/tools/aiet_wrapper.py b/src/mlia/tools/aiet_wrapper.py
new file mode 100644
index 0000000..73e82ee
--- /dev/null
+++ b/src/mlia/tools/aiet_wrapper.py
@@ -0,0 +1,435 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Module for AIET integration."""
+import logging
+import re
+from abc import ABC
+from abc import abstractmethod
+from dataclasses import dataclass
+from pathlib import Path
+from typing import Any
+from typing import Dict
+from typing import List
+from typing import Literal
+from typing import Optional
+from typing import Tuple
+
+from aiet.backend.application import get_available_applications
+from aiet.backend.application import install_application
+from aiet.backend.system import get_available_systems
+from aiet.backend.system import install_system
+from mlia.utils.proc import CommandExecutor
+from mlia.utils.proc import OutputConsumer
+from mlia.utils.proc import RunningCommand
+
+
+logger = logging.getLogger(__name__)
+
+# Mapping backend -> device_type -> system_name
+_SUPPORTED_SYSTEMS = {
+ "Corstone-300": {
+ "ethos-u55": "Corstone-300: Cortex-M55+Ethos-U55",
+ "ethos-u65": "Corstone-300: Cortex-M55+Ethos-U65",
+ },
+ "Corstone-310": {
+ "ethos-u55": "Corstone-310: Cortex-M85+Ethos-U55",
+ },
+}
+
+# Mapping system_name -> memory_mode -> application
+_SYSTEM_TO_APP_MAP = {
+ "Corstone-300: Cortex-M55+Ethos-U55": {
+ "Sram": "Generic Inference Runner: Ethos-U55 SRAM",
+ "Shared_Sram": "Generic Inference Runner: Ethos-U55/65 Shared SRAM",
+ },
+ "Corstone-300: Cortex-M55+Ethos-U65": {
+ "Shared_Sram": "Generic Inference Runner: Ethos-U55/65 Shared SRAM",
+ "Dedicated_Sram": "Generic Inference Runner: Ethos-U65 Dedicated SRAM",
+ },
+ "Corstone-310: Cortex-M85+Ethos-U55": {
+ "Sram": "Generic Inference Runner: Ethos-U55 SRAM",
+ "Shared_Sram": "Generic Inference Runner: Ethos-U55/65 Shared SRAM",
+ },
+}
+
+
+def get_system_name(backend: str, device_type: str) -> str:
+ """Get the AIET system name for the given backend and device type."""
+ return _SUPPORTED_SYSTEMS[backend][device_type]
+
+
+def is_supported(backend: str, device_type: Optional[str] = None) -> bool:
+ """Check if the backend (and optionally device type) is supported."""
+ if device_type is None:
+ return backend in _SUPPORTED_SYSTEMS
+
+ try:
+ get_system_name(backend, device_type)
+ return True
+ except KeyError:
+ return False
+
+
+def supported_backends() -> List[str]:
+ """Get a list of all backends supported by the AIET wrapper."""
+ return list(_SUPPORTED_SYSTEMS.keys())
+
+
+def get_all_system_names(backend: str) -> List[str]:
+ """Get all systems supported by the backend."""
+ return list(_SUPPORTED_SYSTEMS.get(backend, {}).values())
+
+
+def get_all_application_names(backend: str) -> List[str]:
+ """Get all applications supported by the backend."""
+ app_set = {
+ app
+ for sys in get_all_system_names(backend)
+ for app in _SYSTEM_TO_APP_MAP[sys].values()
+ }
+ return list(app_set)
+
+
+@dataclass
+class DeviceInfo:
+ """Device information."""
+
+ device_type: Literal["ethos-u55", "ethos-u65"]
+ mac: int
+ memory_mode: Literal["Sram", "Shared_Sram", "Dedicated_Sram"]
+
+
+@dataclass
+class ModelInfo:
+ """Model info."""
+
+ model_path: Path
+
+
+@dataclass
+class PerformanceMetrics:
+ """Performance metrics parsed from generic inference output."""
+
+ npu_active_cycles: int
+ npu_idle_cycles: int
+ npu_total_cycles: int
+ npu_axi0_rd_data_beat_received: int
+ npu_axi0_wr_data_beat_written: int
+ npu_axi1_rd_data_beat_received: int
+
+
+@dataclass
+class ExecutionParams:
+ """Application execution params."""
+
+ application: str
+ system: str
+ application_params: List[str]
+ system_params: List[str]
+ deploy_params: List[str]
+
+
+class AIETLogWriter(OutputConsumer):
+ """Redirect AIET command output to the logger."""
+
+ def feed(self, line: str) -> None:
+ """Process line from the output."""
+ logger.debug(line.strip())
+
+
+class GenericInferenceOutputParser(OutputConsumer):
+ """Generic inference app output parser."""
+
+ PATTERNS = {
+ name: tuple(re.compile(pattern, re.IGNORECASE) for pattern in patterns)
+ for name, patterns in (
+ (
+ "npu_active_cycles",
+ (
+ r"NPU ACTIVE cycles: (?P<value>\d+)",
+ r"NPU ACTIVE: (?P<value>\d+) cycles",
+ ),
+ ),
+ (
+ "npu_idle_cycles",
+ (
+ r"NPU IDLE cycles: (?P<value>\d+)",
+ r"NPU IDLE: (?P<value>\d+) cycles",
+ ),
+ ),
+ (
+ "npu_total_cycles",
+ (
+ r"NPU TOTAL cycles: (?P<value>\d+)",
+ r"NPU TOTAL: (?P<value>\d+) cycles",
+ ),
+ ),
+ (
+ "npu_axi0_rd_data_beat_received",
+ (
+ r"NPU AXI0_RD_DATA_BEAT_RECEIVED beats: (?P<value>\d+)",
+ r"NPU AXI0_RD_DATA_BEAT_RECEIVED: (?P<value>\d+) beats",
+ ),
+ ),
+ (
+ "npu_axi0_wr_data_beat_written",
+ (
+ r"NPU AXI0_WR_DATA_BEAT_WRITTEN beats: (?P<value>\d+)",
+ r"NPU AXI0_WR_DATA_BEAT_WRITTEN: (?P<value>\d+) beats",
+ ),
+ ),
+ (
+ "npu_axi1_rd_data_beat_received",
+ (
+ r"NPU AXI1_RD_DATA_BEAT_RECEIVED beats: (?P<value>\d+)",
+ r"NPU AXI1_RD_DATA_BEAT_RECEIVED: (?P<value>\d+) beats",
+ ),
+ ),
+ )
+ }
+
+ def __init__(self) -> None:
+ """Init generic inference output parser instance."""
+ self.result: Dict = {}
+
+ def feed(self, line: str) -> None:
+ """Feed new line to the parser."""
+ for name, patterns in self.PATTERNS.items():
+ for pattern in patterns:
+ match = pattern.search(line)
+
+ if match:
+ self.result[name] = int(match["value"])
+ return
+
+ def is_ready(self) -> bool:
+ """Return true if all expected data has been parsed."""
+ return self.result.keys() == self.PATTERNS.keys()
+
+ def missed_keys(self) -> List[str]:
+ """Return list of the keys that have not been found in the output."""
+ return sorted(self.PATTERNS.keys() - self.result.keys())
+
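A minimal usage sketch for the parser (the sample lines are fabricated to match PATTERNS, not captured from a real run):

    parser = GenericInferenceOutputParser()
    for line in (
        "NPU ACTIVE: 365952 cycles",
        "NPU IDLE: 701 cycles",
        "NPU TOTAL: 366653 cycles",
        "NPU AXI0_RD_DATA_BEAT_RECEIVED: 39520 beats",
        "NPU AXI0_WR_DATA_BEAT_WRITTEN: 26430 beats",
        "NPU AXI1_RD_DATA_BEAT_RECEIVED: 3798 beats",
    ):
        parser.feed(line)

    assert parser.is_ready() and not parser.missed_keys()
    metrics = PerformanceMetrics(**parser.result)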
+
+class AIETRunner:
+ """AIET runner."""
+
+ def __init__(self, executor: CommandExecutor) -> None:
+ """Init AIET runner instance."""
+ self.executor = executor
+
+ @staticmethod
+ def get_installed_systems() -> List[str]:
+ """Get list of the installed systems."""
+ return [system.name for system in get_available_systems()]
+
+ @staticmethod
+ def get_installed_applications(system: Optional[str] = None) -> List[str]:
+ """Get list of the installed application."""
+ return [
+ app.name
+ for app in get_available_applications()
+ if system is None or app.can_run_on(system)
+ ]
+
+ def is_application_installed(self, application: str, system: str) -> bool:
+ """Return true if requested application installed."""
+ return application in self.get_installed_applications(system)
+
+ def is_system_installed(self, system: str) -> bool:
+ """Return true if requested system installed."""
+ return system in self.get_installed_systems()
+
+ def systems_installed(self, systems: List[str]) -> bool:
+ """Check if all provided systems are installed."""
+ if not systems:
+ return False
+
+ installed_systems = self.get_installed_systems()
+ return all(system in installed_systems for system in systems)
+
+ def applications_installed(self, applications: List[str]) -> bool:
+ """Check if all provided applications are installed."""
+ if not applications:
+ return False
+
+ installed_apps = self.get_installed_applications()
+ return all(app in installed_apps for app in applications)
+
+ def all_installed(self, systems: List[str], apps: List[str]) -> bool:
+ """Check if all provided artifacts are installed."""
+ return self.systems_installed(systems) and self.applications_installed(apps)
+
+ @staticmethod
+ def install_system(system_path: Path) -> None:
+ """Install system."""
+ install_system(system_path)
+
+ @staticmethod
+ def install_application(app_path: Path) -> None:
+ """Install application."""
+ install_application(app_path)
+
+ def run_application(self, execution_params: ExecutionParams) -> RunningCommand:
+ """Run requested application."""
+ command = [
+ "aiet",
+ "application",
+ "run",
+ "-n",
+ execution_params.application,
+ "-s",
+ execution_params.system,
+ *self._params("-p", execution_params.application_params),
+ *self._params("--system-param", execution_params.system_params),
+ *self._params("--deploy", execution_params.deploy_params),
+ ]
+
+ return self._submit(command)
+
+ @staticmethod
+ def _params(name: str, params: List[str]) -> List[str]:
+ return [p for item in [(name, param) for param in params] for p in item]
+
+ def _submit(self, command: List[str]) -> RunningCommand:
+ """Submit command for the execution."""
+ logger.debug("Submit command %s", " ".join(command))
+ return self.executor.submit(command)
+
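For reference, _params flattens repeated options, so an ExecutionParams like the following (values illustrative) is submitted as a flat aiet command line:

    params = ExecutionParams(
        application="Generic Inference Runner: Ethos-U55/65 Shared SRAM",
        system="Corstone-300: Cortex-M55+Ethos-U55",
        application_params=[],
        system_params=["mac=256", "input_file=/tmp/model_vela.tflite"],
        deploy_params=[],
    )
    # run_application(params) submits:
    #   aiet application run -n <application> -s <system>
    #       --system-param mac=256
    #       --system-param input_file=/tmp/model_vela.tflite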
+
+class GenericInferenceRunner(ABC):
+ """Abstract class for generic inference runner."""
+
+ def __init__(self, aiet_runner: AIETRunner):
+ """Init generic inference runner instance."""
+ self.aiet_runner = aiet_runner
+ self.running_inference: Optional[RunningCommand] = None
+
+ def run(
+ self, model_info: ModelInfo, output_consumers: List[OutputConsumer]
+ ) -> None:
+ """Run generic inference for the provided device/model."""
+ execution_params = self.get_execution_params(model_info)
+
+ self.running_inference = self.aiet_runner.run_application(execution_params)
+ self.running_inference.output_consumers = output_consumers
+ self.running_inference.consume_output()
+
+ def stop(self) -> None:
+ """Stop running inference."""
+ if self.running_inference is None:
+ return
+
+ self.running_inference.stop()
+
+ @abstractmethod
+ def get_execution_params(self, model_info: ModelInfo) -> ExecutionParams:
+ """Get execution params for the provided model."""
+
+ def __enter__(self) -> "GenericInferenceRunner":
+ """Enter context."""
+ return self
+
+ def __exit__(self, *_args: Any) -> None:
+ """Exit context."""
+ self.stop()
+
+ def check_system_and_application(self, system_name: str, app_name: str) -> None:
+ """Check if requested system and application installed."""
+ if not self.aiet_runner.is_system_installed(system_name):
+ raise Exception(f"System {system_name} is not installed")
+
+ if not self.aiet_runner.is_application_installed(app_name, system_name):
+ raise Exception(
+ f"Application {app_name} for the system {system_name} "
+ "is not installed"
+ )
+
+
+class GenericInferenceRunnerEthosU(GenericInferenceRunner):
+ """Generic inference runner on U55/65."""
+
+ def __init__(
+ self, aiet_runner: AIETRunner, device_info: DeviceInfo, backend: str
+ ) -> None:
+ """Init generic inference runner instance."""
+ super().__init__(aiet_runner)
+
+ system_name, app_name = self.resolve_system_and_app(device_info, backend)
+ self.system_name = system_name
+ self.app_name = app_name
+ self.device_info = device_info
+
+ @staticmethod
+ def resolve_system_and_app(
+ device_info: DeviceInfo, backend: str
+ ) -> Tuple[str, str]:
+ """Find appropriate system and application for the provided device/backend."""
+ try:
+ system_name = get_system_name(backend, device_info.device_type)
+ except KeyError as ex:
+ raise RuntimeError(
+ f"Unsupported device {device_info.device_type} "
+ f"for backend {backend}"
+ ) from ex
+
+ if system_name not in _SYSTEM_TO_APP_MAP:
+ raise RuntimeError(f"System {system_name} is not installed")
+
+ try:
+ app_name = _SYSTEM_TO_APP_MAP[system_name][device_info.memory_mode]
+ except KeyError as err:
+ raise RuntimeError(
+ f"Unsupported memory mode {device_info.memory_mode}"
+ ) from err
+
+ return system_name, app_name
+
+ def get_execution_params(self, model_info: ModelInfo) -> ExecutionParams:
+ """Get execution params for Ethos-U55/65."""
+ self.check_system_and_application(self.system_name, self.app_name)
+
+ system_params = [
+ f"mac={self.device_info.mac}",
+ f"input_file={model_info.model_path.absolute()}",
+ ]
+
+ return ExecutionParams(
+ self.app_name,
+ self.system_name,
+ [],
+ system_params,
+ [],
+ )
+
+
+def get_generic_runner(device_info: DeviceInfo, backend: str) -> GenericInferenceRunner:
+ """Get generic runner for provided device and backend."""
+ aiet_runner = get_aiet_runner()
+ return GenericInferenceRunnerEthosU(aiet_runner, device_info, backend)
+
+
+def estimate_performance(
+ model_info: ModelInfo, device_info: DeviceInfo, backend: str
+) -> PerformanceMetrics:
+ """Get performance estimations."""
+ with get_generic_runner(device_info, backend) as generic_runner:
+ output_parser = GenericInferenceOutputParser()
+ output_consumers = [output_parser, AIETLogWriter()]
+
+ generic_runner.run(model_info, output_consumers)
+
+ if not output_parser.is_ready():
+ missed_data = ",".join(output_parser.missed_keys())
+ logger.debug(
+ "Unable to get performance metrics, missed data %s", missed_data
+ )
+ raise Exception("Unable to get performance metrics, insufficient data")
+
+ return PerformanceMetrics(**output_parser.result)
+
+
+def get_aiet_runner() -> AIETRunner:
+ """Return AIET runner."""
+ executor = CommandExecutor()
+ return AIETRunner(executor)
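Taken together, a typical call into this module looks roughly like the sketch below (the model path is illustrative):

    from pathlib import Path

    device = DeviceInfo(device_type="ethos-u55", mac=256, memory_mode="Shared_Sram")
    model = ModelInfo(model_path=Path("/tmp/model_vela.tflite"))

    metrics = estimate_performance(model, device, backend="Corstone-300")
    print(metrics.npu_total_cycles)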
diff --git a/src/mlia/tools/metadata/__init__.py b/src/mlia/tools/metadata/__init__.py
new file mode 100644
index 0000000..f877e4f
--- /dev/null
+++ b/src/mlia/tools/metadata/__init__.py
@@ -0,0 +1,3 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Module for the tools metadata."""
diff --git a/src/mlia/tools/metadata/common.py b/src/mlia/tools/metadata/common.py
new file mode 100644
index 0000000..c17a738
--- /dev/null
+++ b/src/mlia/tools/metadata/common.py
@@ -0,0 +1,290 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Module for installation process."""
+import logging
+from abc import ABC
+from abc import abstractmethod
+from dataclasses import dataclass
+from pathlib import Path
+from typing import Callable
+from typing import List
+from typing import Optional
+from typing import Union
+
+from mlia.utils.misc import yes
+
+
+logger = logging.getLogger(__name__)
+
+
+@dataclass
+class InstallFromPath:
+ """Installation from the local path."""
+
+ backend_path: Path
+
+
+@dataclass
+class DownloadAndInstall:
+ """Download and install."""
+
+ eula_agreement: bool = True
+
+
+InstallationType = Union[InstallFromPath, DownloadAndInstall]
+
+
+class Installation(ABC):
+ """Base class for the installation process of the backends."""
+
+ @property
+ @abstractmethod
+ def name(self) -> str:
+ """Return name of the backend."""
+
+ @property
+ @abstractmethod
+ def description(self) -> str:
+ """Return description of the backend."""
+
+ @property
+ @abstractmethod
+ def could_be_installed(self) -> bool:
+ """Return true if backend could be installed in current environment."""
+
+ @property
+ @abstractmethod
+ def already_installed(self) -> bool:
+ """Return true if backend is already installed."""
+
+ @abstractmethod
+ def supports(self, install_type: InstallationType) -> bool:
+ """Return true if installation supports requested installation type."""
+
+ @abstractmethod
+ def install(self, install_type: InstallationType) -> None:
+ """Install the backend."""
+
+
+InstallationFilter = Callable[[Installation], bool]
+
+
+class AlreadyInstalledFilter:
+ """Filter for already installed backends."""
+
+ def __call__(self, installation: Installation) -> bool:
+ """Installation filter."""
+ return installation.already_installed
+
+
+class ReadyForInstallationFilter:
+ """Filter for ready to be installed backends."""
+
+ def __call__(self, installation: Installation) -> bool:
+ """Installation filter."""
+ return installation.could_be_installed and not installation.already_installed
+
+
+class SupportsInstallTypeFilter:
+ """Filter backends that support certain type of the installation."""
+
+ def __init__(self, installation_type: InstallationType) -> None:
+ """Init filter."""
+ self.installation_type = installation_type
+
+ def __call__(self, installation: Installation) -> bool:
+ """Installation filter."""
+ return installation.supports(self.installation_type)
+
+
+class SearchByNameFilter:
+ """Filter installation by name."""
+
+ def __init__(self, backend_name: Optional[str]) -> None:
+ """Init filter."""
+ self.backend_name = backend_name
+
+ def __call__(self, installation: Installation) -> bool:
+ """Installation filter."""
+ return not self.backend_name or installation.name == self.backend_name
+
+
+class InstallationManager(ABC):
+ """Helper class for managing installations."""
+
+ @abstractmethod
+ def install_from(self, backend_path: Path, backend_name: Optional[str]) -> None:
+ """Install backend from the local directory."""
+
+ @abstractmethod
+ def download_and_install(
+ self, backend_name: Optional[str], eula_agreement: bool
+ ) -> None:
+ """Download and install backends."""
+
+ @abstractmethod
+ def show_env_details(self) -> None:
+ """Show environment details."""
+
+ @abstractmethod
+ def backend_installed(self, backend_name: str) -> bool:
+ """Return true if requested backend installed."""
+
+
+class InstallationFiltersMixin:
+ """Mixin for filtering installation based on different conditions."""
+
+ installations: List[Installation]
+
+ def filter_by(self, *filters: InstallationFilter) -> List[Installation]:
+ """Filter installations."""
+ return [
+ installation
+ for installation in self.installations
+ if all(filter_(installation) for filter_ in filters)
+ ]
+
+ def could_be_installed_from(
+ self, backend_path: Path, backend_name: Optional[str]
+ ) -> List[Installation]:
+ """Return installations that could be installed from provided directory."""
+ return self.filter_by(
+ SupportsInstallTypeFilter(InstallFromPath(backend_path)),
+ SearchByNameFilter(backend_name),
+ )
+
+ def could_be_downloaded_and_installed(
+ self, backend_name: Optional[str] = None
+ ) -> List[Installation]:
+ """Return installations that could be downloaded and installed."""
+ return self.filter_by(
+ SupportsInstallTypeFilter(DownloadAndInstall()),
+ SearchByNameFilter(backend_name),
+ ReadyForInstallationFilter(),
+ )
+
+ def already_installed(
+ self, backend_name: Optional[str] = None
+ ) -> List[Installation]:
+ """Return list of backends that are already installed."""
+ return self.filter_by(
+ AlreadyInstalledFilter(), SearchByNameFilter(backend_name)
+ )
+
+ def ready_for_installation(self) -> List[Installation]:
+ """Return list of the backends that could be installed."""
+ return self.filter_by(ReadyForInstallationFilter())
+
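The mixin reduces backend queries to filter composition; an ad-hoc query might look like this sketch (it relies on the DefaultInstallationManager defined just below and on the installations from corstone.py later in this patch):

    manager = DefaultInstallationManager(get_corstone_installations())
    matches = manager.filter_by(
        SupportsInstallTypeFilter(InstallFromPath(Path("/opt/VHT"))),
        SearchByNameFilter("Corstone-300"),
    )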
+
+class DefaultInstallationManager(InstallationManager, InstallationFiltersMixin):
+ """Interactive installation manager."""
+
+ def __init__(
+ self, installations: List[Installation], noninteractive: bool = False
+ ) -> None:
+ """Init the manager."""
+ self.installations = installations
+ self.noninteractive = noninteractive
+
+ def choose_installation_for_path(
+ self, backend_path: Path, backend_name: Optional[str]
+ ) -> Optional[Installation]:
+ """Check available installation and select one if possible."""
+ installs = self.could_be_installed_from(backend_path, backend_name)
+
+ if not installs:
+ logger.info(
+ "Unfortunatelly, it was not possible to automatically "
+ "detect type of the installed FVP. "
+ "Please, check provided path to the installed FVP."
+ )
+ return None
+
+ if len(installs) != 1:
+ names = ",".join((install.name for install in installs))
+ logger.info(
+ "Unable to correctly detect type of the installed FVP."
+ "The following FVPs are detected %s. Installation skipped.",
+ names,
+ )
+ return None
+
+ installation = installs[0]
+ if installation.already_installed:
+ logger.info(
+ "%s was found in %s, but it has been already installed.",
+ installation.name,
+ backend_path,
+ )
+ return None
+
+ return installation
+
+ def install_from(self, backend_path: Path, backend_name: Optional[str]) -> None:
+ """Install from the provided directory."""
+ installation = self.choose_installation_for_path(backend_path, backend_name)
+
+ if not installation:
+ return
+
+ prompt = (
+ f"{installation.name} was found in {backend_path}. "
+ "Would you like to install it?"
+ )
+ self._install(installation, InstallFromPath(backend_path), prompt)
+
+ def download_and_install(
+ self, backend_name: Optional[str] = None, eula_agreement: bool = True
+ ) -> None:
+ """Download and install available backends."""
+ installations = self.could_be_downloaded_and_installed(backend_name)
+
+ if not installations:
+ logger.info("No backends available for the installation.")
+ return
+
+ names = ",".join((installation.name for installation in installations))
+ logger.info("Following backends are available for downloading: %s", names)
+
+ for installation in installations:
+ prompt = f"Would you like to download and install {installation.name}?"
+ self._install(
+ installation, DownloadAndInstall(eula_agreement=eula_agreement), prompt
+ )
+
+ def show_env_details(self) -> None:
+ """Print current state of the execution environment."""
+ if installed := self.already_installed():
+ logger.info("Installed backends:\n")
+
+ for installation in installed:
+ logger.info(" - %s", installation.name)
+
+ if could_be_installed := self.ready_for_installation():
+ logger.info("Following backends could be installed:")
+
+ for installation in could_be_installed:
+ logger.info(" - %s", installation.name)
+
+ if not installed and not could_be_installed:
+ logger.info("No backends installed")
+
+ def _install(
+ self,
+ installation: Installation,
+ installation_type: InstallationType,
+ prompt: str,
+ ) -> None:
+ proceed = self.noninteractive or yes(prompt)
+
+ if proceed:
+ installation.install(installation_type)
+ logger.info("%s successfully installed.", installation.name)
+ else:
+ logger.info("%s installation canceled.", installation.name)
+
+ def backend_installed(self, backend_name: str) -> bool:
+ """Return true if requested backend installed."""
+ installations = self.already_installed(backend_name)
+
+ return len(installations) == 1
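A minimal non-interactive driver for this manager, assuming the Corstone installations defined in corstone.py below (noninteractive=True skips the yes/no prompts in _install):

    from pathlib import Path
    from mlia.tools.metadata.common import DefaultInstallationManager
    from mlia.tools.metadata.corstone import get_corstone_installations

    manager = DefaultInstallationManager(
        get_corstone_installations(), noninteractive=True
    )
    manager.show_env_details()
    manager.install_from(Path("/opt/VHT"), "Corstone-300")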
diff --git a/src/mlia/tools/metadata/corstone.py b/src/mlia/tools/metadata/corstone.py
new file mode 100644
index 0000000..7a9d113
--- /dev/null
+++ b/src/mlia/tools/metadata/corstone.py
@@ -0,0 +1,402 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Module for Corstone based FVPs."""
+import logging
+import platform
+import subprocess
+import tarfile
+from dataclasses import dataclass
+from pathlib import Path
+from typing import Callable
+from typing import Iterable
+from typing import List
+from typing import Optional
+
+import mlia.tools.aiet_wrapper as aiet
+from mlia.tools.metadata.common import DownloadAndInstall
+from mlia.tools.metadata.common import Installation
+from mlia.tools.metadata.common import InstallationType
+from mlia.tools.metadata.common import InstallFromPath
+from mlia.utils.download import DownloadArtifact
+from mlia.utils.filesystem import all_files_exist
+from mlia.utils.filesystem import all_paths_valid
+from mlia.utils.filesystem import copy_all
+from mlia.utils.filesystem import get_mlia_resources
+from mlia.utils.filesystem import temp_directory
+from mlia.utils.proc import working_directory
+
+logger = logging.getLogger(__name__)
+
+
+@dataclass
+class BackendInfo:
+ """Backend information."""
+
+ backend_path: Path
+ copy_source: bool = True
+ system_config: Optional[str] = None
+
+
+PathChecker = Callable[[Path], Optional[BackendInfo]]
+BackendInstaller = Callable[[bool, Path], Path]
+
+
+class AIETMetadata:
+ """AIET installation metadata."""
+
+ def __init__(
+ self,
+ name: str,
+ description: str,
+ system_config: str,
+ apps_resources: List[str],
+ fvp_dir_name: str,
+ download_artifact: Optional[DownloadArtifact],
+ supported_platforms: Optional[List[str]] = None,
+ ) -> None:
+ """
+ Initialize AIETMetadata.
+
+ Members expected_systems and expected_apps are filled automatically.
+ """
+ self.name = name
+ self.description = description
+ self.system_config = system_config
+ self.apps_resources = apps_resources
+ self.fvp_dir_name = fvp_dir_name
+ self.download_artifact = download_artifact
+ self.supported_platforms = supported_platforms
+
+ self.expected_systems = aiet.get_all_system_names(name)
+ self.expected_apps = aiet.get_all_application_names(name)
+
+ @property
+ def expected_resources(self) -> Iterable[Path]:
+ """Return list of expected resources."""
+ resources = [self.system_config, *self.apps_resources]
+
+ return (get_mlia_resources() / resource for resource in resources)
+
+ @property
+ def supported_platform(self) -> bool:
+ """Return true if current platform supported."""
+ if not self.supported_platforms:
+ return True
+
+ return platform.system() in self.supported_platforms
+
+
+class AIETBasedInstallation(Installation):
+ """Backend installation based on AIET functionality."""
+
+ def __init__(
+ self,
+ aiet_runner: aiet.AIETRunner,
+ metadata: AIETMetadata,
+ path_checker: PathChecker,
+ backend_installer: Optional[BackendInstaller],
+ ) -> None:
+ """Init the tool installation."""
+ self.aiet_runner = aiet_runner
+ self.metadata = metadata
+ self.path_checker = path_checker
+ self.backend_installer = backend_installer
+
+ @property
+ def name(self) -> str:
+ """Return name of the tool."""
+ return self.metadata.name
+
+ @property
+ def description(self) -> str:
+ """Return description of the tool."""
+ return self.metadata.description
+
+ @property
+ def already_installed(self) -> bool:
+ """Return true if tool already installed."""
+ return self.aiet_runner.all_installed(
+ self.metadata.expected_systems, self.metadata.expected_apps
+ )
+
+ @property
+ def could_be_installed(self) -> bool:
+ """Return true if tool could be installed."""
+ if not self.metadata.supported_platform:
+ return False
+
+ return all_paths_valid(self.metadata.expected_resources)
+
+ def supports(self, install_type: InstallationType) -> bool:
+ """Return true if tools supported type of the installation."""
+ if isinstance(install_type, DownloadAndInstall):
+ return self.metadata.download_artifact is not None
+
+ if isinstance(install_type, InstallFromPath):
+ return self.path_checker(install_type.backend_path) is not None
+
+ return False # type: ignore
+
+ def install(self, install_type: InstallationType) -> None:
+ """Install the tool."""
+ if isinstance(install_type, DownloadAndInstall):
+ download_artifact = self.metadata.download_artifact
+ assert download_artifact is not None, "No artifact provided"
+
+ self.download_and_install(download_artifact, install_type.eula_agreement)
+ elif isinstance(install_type, InstallFromPath):
+ backend_path = self.path_checker(install_type.backend_path)
+ assert backend_path is not None, "Unable to resolve backend path"
+
+ self.install_from(backend_path)
+ else:
+ raise Exception(f"Unable to install {install_type}")
+
+ def install_from(self, backend_info: BackendInfo) -> None:
+ """Install tool from the directory."""
+ mlia_resources = get_mlia_resources()
+
+ with temp_directory() as tmpdir:
+ fvp_dist_dir = tmpdir / self.metadata.fvp_dir_name
+
+ system_config = self.metadata.system_config
+ if backend_info.system_config:
+ system_config = backend_info.system_config
+
+ resources_to_copy = [mlia_resources / system_config]
+ if backend_info.copy_source:
+ resources_to_copy.append(backend_info.backend_path)
+
+ copy_all(*resources_to_copy, dest=fvp_dist_dir)
+
+ self.aiet_runner.install_system(fvp_dist_dir)
+
+ for app in self.metadata.apps_resources:
+ self.aiet_runner.install_application(mlia_resources / app)
+
+ def download_and_install(
+ self, download_artifact: DownloadArtifact, eula_agreement: bool
+ ) -> None:
+ """Download and install the tool."""
+ with temp_directory() as tmpdir:
+ try:
+ downloaded_to = download_artifact.download_to(tmpdir)
+ except Exception as err:
+ raise Exception("Unable to download backend artifact") from err
+
+ with working_directory(tmpdir / "dist", create_dir=True) as dist_dir:
+ with tarfile.open(downloaded_to) as archive:
+ archive.extractall(dist_dir)
+
+ assert self.backend_installer, (
+ f"Backend '{self.metadata.name}' does not support "
+ "download and installation."
+ )
+ backend_path = self.backend_installer(eula_agreement, dist_dir)
+ if self.path_checker(backend_path) is None:
+ raise Exception("Downloaded artifact has invalid structure")
+
+ self.install(InstallFromPath(backend_path))
+
+
+class PackagePathChecker:
+ """Package path checker."""
+
+ def __init__(
+ self, expected_files: List[str], backend_subfolder: Optional[str] = None
+ ) -> None:
+ """Init the path checker."""
+ self.expected_files = expected_files
+ self.backend_subfolder = backend_subfolder
+
+ def __call__(self, backend_path: Path) -> Optional[BackendInfo]:
+ """Check if directory contains all expected files."""
+ resolved_paths = (backend_path / file for file in self.expected_files)
+ if not all_files_exist(resolved_paths):
+ return None
+
+ if self.backend_subfolder:
+ subfolder = backend_path / self.backend_subfolder
+
+ if not subfolder.is_dir():
+ return None
+
+ return BackendInfo(subfolder)
+
+ return BackendInfo(backend_path)
+
+
+class StaticPathChecker:
+ """Static path checker."""
+
+ def __init__(
+ self,
+ static_backend_path: Path,
+ expected_files: List[str],
+ copy_source: bool = False,
+ system_config: Optional[str] = None,
+ ) -> None:
+ """Init static path checker."""
+ self.static_backend_path = static_backend_path
+ self.expected_files = expected_files
+ self.copy_source = copy_source
+ self.system_config = system_config
+
+ def __call__(self, backend_path: Path) -> Optional[BackendInfo]:
+ """Check if directory equals static backend path with all expected files."""
+ if backend_path != self.static_backend_path:
+ return None
+
+ resolved_paths = (backend_path / file for file in self.expected_files)
+ if not all_files_exist(resolved_paths):
+ return None
+
+ return BackendInfo(
+ backend_path,
+ copy_source=self.copy_source,
+ system_config=self.system_config,
+ )
+
+
+class CompoundPathChecker:
+ """Compound path checker."""
+
+ def __init__(self, *path_checkers: PathChecker) -> None:
+ """Init compound path checker."""
+ self.path_checkers = path_checkers
+
+ def __call__(self, backend_path: Path) -> Optional[BackendInfo]:
+ """Iterate over checkers and return first non empty backend info."""
+ first_resolved_backend_info = (
+ backend_info
+ for path_checker in self.path_checkers
+ if (backend_info := path_checker(backend_path)) is not None
+ )
+
+ return next(first_resolved_backend_info, None)
+
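Sketch of the checkers composing (the first checker that recognises the path wins; the file names below mirror the Corstone-300 definitions further down):

    checker = CompoundPathChecker(
        PackagePathChecker(
            expected_files=["models/Linux64_GCC-6.4/FVP_Corstone_SSE-300_Ethos-U55"],
            backend_subfolder="models/Linux64_GCC-6.4",
        ),
        StaticPathChecker(
            static_backend_path=Path("/opt/VHT"),
            expected_files=["VHT_Corstone_SSE-300_Ethos-U55"],
        ),
    )
    info = checker(Path("/opt/VHT"))  # BackendInfo for /opt/VHT if the files exist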
+
+class Corstone300Installer:
+ """Helper class that wraps Corstone 300 installation logic."""
+
+ def __call__(self, eula_agreement: bool, dist_dir: Path) -> Path:
+ """Install Corstone-300 and return path to the models."""
+ with working_directory(dist_dir):
+ install_dir = "corstone-300"
+ try:
+ fvp_install_cmd = [
+ "./FVP_Corstone_SSE-300.sh",
+ "-q",
+ "-d",
+ install_dir,
+ ]
+ if not eula_agreement:
+ fvp_install_cmd += [
+ "--nointeractive",
+ "--i-agree-to-the-contained-eula",
+ ]
+
+ subprocess.check_call(fvp_install_cmd)
+ except subprocess.CalledProcessError as err:
+ raise Exception(
+ "Error occurred during Corstone-300 installation"
+ ) from err
+
+ return dist_dir / install_dir
+
+
+def get_corstone_300_installation() -> Installation:
+ """Get Corstone-300 installation."""
+ corstone_300 = AIETBasedInstallation(
+ aiet_runner=aiet.get_aiet_runner(),
+ # pylint: disable=line-too-long
+ metadata=AIETMetadata(
+ name="Corstone-300",
+ description="Corstone-300 FVP",
+ system_config="aiet/systems/corstone-300/aiet-config.json",
+ apps_resources=[
+ "aiet/applications/inference_runner-sse-300-22.05.01-ethos-U55-Shared_Sram-TA",
+ "aiet/applications/inference_runner-sse-300-22.05.01-ethos-U55-Sram_Only-TA",
+ "aiet/applications/inference_runner-sse-300-22.05.01-ethos-U65-Dedicated_Sram-TA",
+ ],
+ fvp_dir_name="corstone_300",
+ download_artifact=DownloadArtifact(
+ name="Corstone-300 FVP",
+ url="https://developer.arm.com/-/media/Arm%20Developer%20Community/Downloads/OSS/FVP/Corstone-300/FVP_Corstone_SSE-300_11.16_26.tgz",
+ filename="FVP_Corstone_SSE-300_11.16_26.tgz",
+ version="11.16_26",
+ sha256_hash="e26139be756b5003a30d978c629de638aed1934d597dc24a17043d4708e934d7",
+ ),
+ supported_platforms=["Linux"],
+ ),
+ # pylint: enable=line-too-long
+ path_checker=CompoundPathChecker(
+ PackagePathChecker(
+ expected_files=[
+ "models/Linux64_GCC-6.4/FVP_Corstone_SSE-300_Ethos-U55",
+ "models/Linux64_GCC-6.4/FVP_Corstone_SSE-300_Ethos-U65",
+ ],
+ backend_subfolder="models/Linux64_GCC-6.4",
+ ),
+ StaticPathChecker(
+ static_backend_path=Path("/opt/VHT"),
+ expected_files=[
+ "VHT_Corstone_SSE-300_Ethos-U55",
+ "VHT_Corstone_SSE-300_Ethos-U65",
+ ],
+ copy_source=False,
+ system_config="aiet/systems/corstone-300-vht/aiet-config.json",
+ ),
+ ),
+ backend_installer=Corstone300Installer(),
+ )
+
+ return corstone_300
+
+
+def get_corstone_310_installation() -> Installation:
+ """Get Corstone-310 installation."""
+ corstone_310 = AIETBasedInstallation(
+ aiet_runner=aiet.get_aiet_runner(),
+ # pylint: disable=line-too-long
+ metadata=AIETMetadata(
+ name="Corstone-310",
+ description="Corstone-310 FVP",
+ system_config="aiet/systems/corstone-310/aiet-config.json",
+ apps_resources=[
+ "aiet/applications/inference_runner-sse-310-22.05.01-ethos-U55-Shared_Sram-TA",
+ "aiet/applications/inference_runner-sse-310-22.05.01-ethos-U55-Sram_Only-TA",
+ ],
+ fvp_dir_name="corstone_310",
+ download_artifact=None,
+ supported_platforms=["Linux"],
+ ),
+ # pylint: enable=line-too-long
+ path_checker=CompoundPathChecker(
+ PackagePathChecker(
+ expected_files=[
+ "models/Linux64_GCC-9.3/FVP_Corstone_SSE-310",
+ ],
+ backend_subfolder="models/Linux64_GCC-9.3",
+ ),
+ StaticPathChecker(
+ static_backend_path=Path("/opt/VHT"),
+ expected_files=[
+ "VHT_Corstone_SSE-310",
+ ],
+ copy_source=False,
+ system_config="aiet/systems/corstone-310-vht/aiet-config.json",
+ ),
+ ),
+ backend_installer=None,
+ )
+
+ return corstone_310
+
+
+def get_corstone_installations() -> List[Installation]:
+ """Get Corstone installations."""
+ return [
+ get_corstone_300_installation(),
+ get_corstone_310_installation(),
+ ]
diff --git a/src/mlia/tools/vela_wrapper.py b/src/mlia/tools/vela_wrapper.py
new file mode 100644
index 0000000..7225797
--- /dev/null
+++ b/src/mlia/tools/vela_wrapper.py
@@ -0,0 +1,500 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Vela wrapper module."""
+import itertools
+import logging
+import sys
+from dataclasses import dataclass
+from pathlib import Path
+from typing import Any
+from typing import Dict
+from typing import List
+from typing import Literal
+from typing import Optional
+from typing import Tuple
+from typing import Union
+
+import numpy as np
+from ethosu.vela.architecture_features import ArchitectureFeatures
+from ethosu.vela.compiler_driver import compiler_driver
+from ethosu.vela.compiler_driver import CompilerOptions
+from ethosu.vela.compiler_driver import TensorAllocator
+from ethosu.vela.model_reader import ModelReaderOptions
+from ethosu.vela.model_reader import read_model
+from ethosu.vela.nn_graph import Graph
+from ethosu.vela.nn_graph import NetworkType
+from ethosu.vela.npu_performance import PassCycles
+from ethosu.vela.operation import CustomType
+from ethosu.vela.operation import Op
+from ethosu.vela.scheduler import OptimizationStrategy
+from ethosu.vela.scheduler import SchedulerOptions
+from ethosu.vela.tensor import BandwidthDirection
+from ethosu.vela.tensor import MemArea
+from ethosu.vela.tensor import Tensor
+from ethosu.vela.tflite_mapping import optype_to_builtintype
+from ethosu.vela.tflite_model_semantic import TFLiteSemantic
+from ethosu.vela.tflite_supported_operators import TFLiteSupportedOperators
+from ethosu.vela.tflite_writer import write_tflite
+from ethosu.vela.vela import generate_supported_ops
+
+from mlia.utils.logging import redirect_output
+
+
+logger = logging.getLogger(__name__)
+
+VELA_INTERNAL_OPS = (Op.Placeholder, Op.SubgraphInput, Op.Const)
+
+
+@dataclass
+class PerformanceMetrics: # pylint: disable=too-many-instance-attributes
+ """Contains all the performance metrics Vela generates in a run."""
+
+ npu_cycles: int
+ sram_access_cycles: int
+ dram_access_cycles: int
+ on_chip_flash_access_cycles: int
+ off_chip_flash_access_cycles: int
+ total_cycles: int
+ batch_inference_time: float
+ inferences_per_second: float
+ batch_size: int
+ unknown_memory_area_size: int
+ sram_memory_area_size: int
+ dram_memory_area_size: int
+ on_chip_flash_memory_area_size: int
+ off_chip_flash_memory_area_size: int
+
+
+@dataclass
+class NpuSupported:
+ """Operator's npu supported attribute."""
+
+ supported: bool
+ reasons: List[Tuple[str, str]]
+
+
+@dataclass
+class Operator:
+ """Model operator."""
+
+ name: str
+ op_type: str
+ run_on_npu: NpuSupported
+
+ @property
+ def cpu_only(self) -> bool:
+ """Return true if operator is CPU only."""
+ cpu_only_reasons = [("CPU only operator", "")]
+ return (
+ not self.run_on_npu.supported
+ and self.run_on_npu.reasons == cpu_only_reasons
+ )
+
+
+@dataclass
+class Operators:
+ """Model's operators."""
+
+ ops: List[Operator]
+
+ @property
+ def npu_supported_ratio(self) -> float:
+ """Return NPU supported ratio."""
+ total = self.total_number
+ npu_supported = self.npu_supported_number
+
+ if total == 0 or npu_supported == 0:
+ return 0
+
+ return npu_supported / total
+
+ @property
+ def npu_unsupported_ratio(self) -> float:
+ """Return NPU unsupported ratio."""
+ return 1 - self.npu_supported_ratio
+
+ @property
+ def total_number(self) -> int:
+ """Return total number of operators."""
+ return len(self.ops)
+
+ @property
+ def npu_supported_number(self) -> int:
+ """Return number of npu supported operators."""
+ return sum(op.run_on_npu.supported for op in self.ops)
+
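For example, a model with 10 operators of which 8 can be placed on the NPU reports npu_supported_ratio 0.8 and npu_unsupported_ratio 0.2; when nothing maps to the NPU the ratio is defined as 0 instead of dividing by zero.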
+
+@dataclass
+class Model:
+ """Model metadata."""
+
+ nng: Graph
+ network_type: NetworkType
+
+ @property
+ def optimized(self) -> bool:
+ """Return true if model is already optimized."""
+ return any(
+ op.attrs.get("custom_type") == CustomType.ExistingNpuOp
+ for sg in self.nng.subgraphs
+ for op in sg.get_all_ops()
+ )
+
+
+@dataclass
+class OptimizedModel:
+ """Instance of the Vela optimized model."""
+
+ nng: Graph
+ arch: ArchitectureFeatures
+ compiler_options: CompilerOptions
+ scheduler_options: SchedulerOptions
+
+ def save(self, output_filename: Union[str, Path]) -> None:
+ """Save instance of the optimized model to the file."""
+ write_tflite(self.nng, output_filename)
+
+
+AcceleratorConfigType = Literal[
+ "ethos-u55-32",
+ "ethos-u55-64",
+ "ethos-u55-128",
+ "ethos-u55-256",
+ "ethos-u65-256",
+ "ethos-u65-512",
+]
+
+TensorAllocatorType = Literal["LinearAlloc", "Greedy", "HillClimb"]
+
+OptimizationStrategyType = Literal["Performance", "Size"]
+
+
+@dataclass
+class VelaCompilerOptions: # pylint: disable=too-many-instance-attributes
+ """Vela compiler options."""
+
+ config_files: Optional[Union[str, List[str]]] = None
+ system_config: str = ArchitectureFeatures.DEFAULT_CONFIG
+ memory_mode: str = ArchitectureFeatures.DEFAULT_CONFIG
+ accelerator_config: Optional[AcceleratorConfigType] = None
+ max_block_dependency: int = ArchitectureFeatures.MAX_BLOCKDEP
+ arena_cache_size: Optional[int] = None
+ tensor_allocator: TensorAllocatorType = "HillClimb"
+ cpu_tensor_alignment: int = Tensor.AllocationQuantum
+ optimization_strategy: OptimizationStrategyType = "Performance"
+ output_dir: Optional[str] = None
+ recursion_limit: int = 1000
+
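A sketch connecting these options to the shipped resources, using the VelaCompiler defined just below (the literal values are illustrative; in MLIA they come from profiles.json and the bundled vela.ini):

    options = VelaCompilerOptions(
        config_files="src/mlia/resources/vela/vela.ini",
        system_config="Ethos_U55_High_End_Embedded",
        memory_mode="Shared_Sram",
        accelerator_config="ethos-u55-256",
    )
    compiler = VelaCompiler(options)
    print(compiler.get_config()["core_clock"])  # 500e6, from vela.ini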
+
+class VelaCompiler: # pylint: disable=too-many-instance-attributes
+ """Vela compiler wrapper."""
+
+ def __init__(self, compiler_options: VelaCompilerOptions):
+ """Init Vela wrapper instance."""
+ self.config_files = compiler_options.config_files
+ self.system_config = compiler_options.system_config
+ self.memory_mode = compiler_options.memory_mode
+ self.accelerator_config = compiler_options.accelerator_config
+ self.max_block_dependency = compiler_options.max_block_dependency
+ self.arena_cache_size = compiler_options.arena_cache_size
+ self.tensor_allocator = TensorAllocator[compiler_options.tensor_allocator]
+ self.cpu_tensor_alignment = compiler_options.cpu_tensor_alignment
+ self.optimization_strategy = OptimizationStrategy[
+ compiler_options.optimization_strategy
+ ]
+ self.output_dir = compiler_options.output_dir
+ self.recursion_limit = compiler_options.recursion_limit
+
+ sys.setrecursionlimit(self.recursion_limit)
+
+ def read_model(self, model: Union[str, Path]) -> Model:
+ """Read model."""
+ logger.debug("Read model %s", model)
+
+ nng, network_type = self._read_model(model)
+ return Model(nng, network_type)
+
+ def compile_model(self, model: Union[str, Path, Model]) -> OptimizedModel:
+ """Compile the model."""
+ if isinstance(model, (str, Path)):
+ nng, network_type = self._read_model(model)
+ else:
+ nng, network_type = model.nng, NetworkType.TFLite
+
+ if not nng:
+ raise Exception("Unable to read model")
+
+ try:
+ arch = self._architecture_features()
+ compiler_options = self._compiler_options()
+ scheduler_options = self._scheduler_options()
+
+ with redirect_output(
+ logger, stdout_level=logging.DEBUG, stderr_level=logging.DEBUG
+ ):
+ compiler_driver(
+ nng, arch, compiler_options, scheduler_options, network_type
+ )
+
+ return OptimizedModel(nng, arch, compiler_options, scheduler_options)
+ except (SystemExit, Exception) as err:
+ raise Exception("Model could not be optimized with Vela compiler") from err
+
+ def get_config(self) -> Dict[str, Any]:
+ """Get compiler configuration."""
+ arch = self._architecture_features()
+
+ memory_area = {
+ mem.name: {
+ "clock_scales": arch.memory_clock_scales[mem],
+ "burst_length": arch.memory_burst_length[mem],
+ "read_latency": arch.memory_latency[mem][BandwidthDirection.Read],
+ "write_latency": arch.memory_latency[mem][BandwidthDirection.Write],
+ }
+ for mem in (
+ MemArea.Sram,
+ MemArea.Dram,
+ MemArea.OnChipFlash,
+ MemArea.OffChipFlash,
+ )
+ }
+
+ return {
+ "accelerator_config": arch.accelerator_config.value,
+ "system_config": arch.system_config,
+ "core_clock": arch.core_clock,
+ "axi0_port": arch.axi0_port.name,
+ "axi1_port": arch.axi1_port.name,
+ "memory_mode": arch.memory_mode,
+ "const_mem_area": arch.const_mem_area.name,
+ "arena_mem_area": arch.arena_mem_area.name,
+ "cache_mem_area": arch.cache_mem_area.name,
+ "arena_cache_size": arch.arena_cache_size,
+ "permanent_storage_mem_area": arch.permanent_storage_mem_area.name,
+ "feature_map_storage_mem_area": arch.feature_map_storage_mem_area.name,
+ "fast_storage_mem_area": arch.fast_storage_mem_area.name,
+ "memory_area": memory_area,
+ }
+
+ @staticmethod
+ def _read_model(model: Union[str, Path]) -> Tuple[Graph, NetworkType]:
+ """Read TFLite model."""
+ try:
+ model_path = str(model) if isinstance(model, Path) else model
+
+ with redirect_output(
+ logger, stdout_level=logging.DEBUG, stderr_level=logging.DEBUG
+ ):
+ return read_model(model_path, ModelReaderOptions()) # type: ignore
+ except (SystemExit, Exception) as err:
+ raise Exception(f"Unable to read model {model_path}") from err
+
+ def _architecture_features(self) -> ArchitectureFeatures:
+ """Return ArchitectureFeatures instance."""
+ return ArchitectureFeatures(
+ vela_config_files=self.config_files,
+ accelerator_config=self.accelerator_config,
+ system_config=self.system_config,
+ memory_mode=self.memory_mode,
+ max_blockdep=self.max_block_dependency,
+ verbose_config=False,
+ arena_cache_size=self.arena_cache_size,
+ )
+
+ def _scheduler_options(self) -> SchedulerOptions:
+ """Return SchedulerOptions instance."""
+ arch = self._architecture_features()
+
+ return SchedulerOptions(
+ optimization_strategy=self.optimization_strategy,
+ sram_target=arch.arena_cache_size,
+ verbose_schedule=False,
+ )
+
+ def _compiler_options(self) -> CompilerOptions:
+ """Return CompilerOptions instance."""
+ return CompilerOptions(
+ verbose_graph=False,
+ verbose_quantization=False,
+ verbose_packing=False,
+ verbose_tensor_purpose=False,
+ verbose_tensor_format=False,
+ verbose_allocation=False,
+ verbose_high_level_command_stream=False,
+ verbose_register_command_stream=False,
+ verbose_operators=False,
+ verbose_weights=False,
+ show_cpu_operations=False,
+ tensor_allocator=self.tensor_allocator,
+ timing=False,
+ output_dir=self.output_dir,
+ cpu_tensor_alignment=self.cpu_tensor_alignment,
+ )
+
+
+def resolve_compiler_config(
+ vela_compiler_options: VelaCompilerOptions,
+) -> Dict[str, Any]:
+ """Resolve passed compiler options.
+
+ Vela has number of configuration parameters that being
+ resolved during passing compiler options. E.g. Vela
+ reads configuration parameters from vela.ini and fills
+ it's internal structures with resolved values (memory mode,
+ system mode, etc.).
+
+ In order to get this information we need to create
+ instance of the Vela compiler first.
+ """
+ vela_compiler = VelaCompiler(vela_compiler_options)
+ return vela_compiler.get_config()
+
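+# A minimal usage sketch (editor's note, not part of the original change;
+# the option values are illustrative, and the VelaCompilerOptions fields are
+# assumed from the attributes read in VelaCompiler.__init__ above):
+#
+#     options = VelaCompilerOptions(
+#         system_config="Ethos_U55_High_End_Embedded",
+#         accelerator_config="ethos-u55-256",
+#     )
+#     config = resolve_compiler_config(options)
+#     print(config["memory_mode"], config["core_clock"])
+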
+
+def estimate_performance(
+ model_path: Path, compiler_options: VelaCompilerOptions
+) -> PerformanceMetrics:
+ """Return performance estimations for the model/device.
+
+ Logic for this function comes from Vela module stats_writer.py
+ """
+ logger.debug(
+ "Estimate performance for the model %s on %s",
+ model_path,
+ compiler_options.accelerator_config,
+ )
+
+ vela_compiler = VelaCompiler(compiler_options)
+
+ initial_model = vela_compiler.read_model(model_path)
+ if initial_model.optimized:
+ raise Exception("Unable to estimate performance for the given optimized model")
+
+ optimized_model = vela_compiler.compile_model(initial_model)
+
+ return _performance_metrics(optimized_model)
+
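+# Editor's sketch (not part of the original change; the model path is an
+# illustrative assumption and `options` is a VelaCompilerOptions instance):
+#
+#     metrics = estimate_performance(Path("model.tflite"), options)
+#     print(metrics.npu_cycles, metrics.inferences_per_second)
+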
+
+def optimize_model(
+ model_path: Path, compiler_options: VelaCompilerOptions, output_model_path: Path
+) -> None:
+ """Optimize model and return it's path after optimization."""
+ logger.debug(
+ "Optimize model %s for device %s",
+ model_path,
+ compiler_options.accelerator_config,
+ )
+
+ vela_compiler = VelaCompiler(compiler_options)
+ optimized_model = vela_compiler.compile_model(model_path)
+
+ logger.debug("Save optimized model into %s", output_model_path)
+ optimized_model.save(output_model_path)
+
+
+def _performance_metrics(optimized_model: OptimizedModel) -> PerformanceMetrics:
+ """Return performance metrics for optimized model."""
+ cycles = optimized_model.nng.cycles
+
+ def memory_usage(mem_area: MemArea) -> int:
+ """Get memory usage for the proviced memory area type."""
+ memory_used: Dict[MemArea, int] = optimized_model.nng.memory_used
+ bandwidths = optimized_model.nng.bandwidths
+
+ return memory_used.get(mem_area, 0) if np.sum(bandwidths[mem_area]) > 0 else 0
+
+ midpoint_fps = np.nan
+ midpoint_inference_time = cycles[PassCycles.Total] / optimized_model.arch.core_clock
+ if midpoint_inference_time > 0:
+ midpoint_fps = 1 / midpoint_inference_time
+
+ return PerformanceMetrics(
+ npu_cycles=int(cycles[PassCycles.Npu]),
+ sram_access_cycles=int(cycles[PassCycles.SramAccess]),
+ dram_access_cycles=int(cycles[PassCycles.DramAccess]),
+ on_chip_flash_access_cycles=int(cycles[PassCycles.OnChipFlashAccess]),
+ off_chip_flash_access_cycles=int(cycles[PassCycles.OffChipFlashAccess]),
+ total_cycles=int(cycles[PassCycles.Total]),
+ batch_inference_time=midpoint_inference_time * 1000,
+ inferences_per_second=midpoint_fps,
+ batch_size=optimized_model.nng.batch_size,
+ unknown_memory_area_size=memory_usage(MemArea.Unknown),
+ sram_memory_area_size=memory_usage(MemArea.Sram),
+ dram_memory_area_size=memory_usage(MemArea.Dram),
+ on_chip_flash_memory_area_size=memory_usage(MemArea.OnChipFlash),
+ off_chip_flash_memory_area_size=memory_usage(MemArea.OffChipFlash),
+ )
+
+
+def supported_operators(
+ model_path: Path, compiler_options: VelaCompilerOptions
+) -> Operators:
+ """Return list of model's operators."""
+ logger.debug("Check supported operators for the model %s", model_path)
+
+ vela_compiler = VelaCompiler(compiler_options)
+ initial_model = vela_compiler.read_model(model_path)
+
+ return Operators(
+ [
+ Operator(op.name, optype_to_builtintype(op.type), run_on_npu(op))
+ for sg in initial_model.nng.subgraphs
+ for op in sg.get_all_ops()
+ if op.type not in VELA_INTERNAL_OPS
+ ]
+ )
+
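+# Editor's sketch (not part of the original change; the model path and
+# `options` are illustrative assumptions):
+#
+#     operators = supported_operators(Path("model.tflite"), options)
+#     # each Operator entry pairs the operator name and type with its
+#     # NpuSupported status as determined by run_on_npu() below
+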
+
+def run_on_npu(operator: Op) -> NpuSupported:
+ """Return information if operator can run on NPU.
+
+ Vela does a number of checks that can help establish whether
+ a particular operator is supported to run on NPU.
+
+ There are two groups of checks:
+ - general TFLite constraints
+ - operator specific constraints
+
+ If an operator is not supported on NPU then this function
+ will return the reason of that.
+
+ The reason is split in two parts:
+ - general description of why the operator cannot be placed on NPU
+ - details on the particular operator
+ """
+ semantic_checker = TFLiteSemantic()
+ semantic_constraints = itertools.chain(
+ semantic_checker.generic_constraints,
+ semantic_checker.specific_constraints[operator.type],
+ )
+
+ for constraint in semantic_constraints:
+ op_valid, op_reason = constraint(operator)
+ if not op_valid:
+ return NpuSupported(False, [(constraint.__doc__, op_reason)])
+
+ if operator.type not in TFLiteSupportedOperators.supported_operators:
+ reasons = (
+ [("CPU only operator", "")]
+ if operator.type not in VELA_INTERNAL_OPS
+ else []
+ )
+
+ return NpuSupported(False, reasons)
+
+ tflite_supported_operators = TFLiteSupportedOperators()
+ operation_constraints = itertools.chain(
+ tflite_supported_operators.generic_constraints,
+ tflite_supported_operators.specific_constraints[operator.type],
+ )
+ for constraint in operation_constraints:
+ op_valid, op_reason = constraint(operator)
+ if not op_valid:
+ return NpuSupported(False, [(constraint.__doc__, op_reason)])
+
+ return NpuSupported(True, [])
+
+
+def generate_supported_operators_report() -> None:
+ """Generate supported operators report in current working directory."""
+ with redirect_output(logger):
+ generate_supported_ops()
diff --git a/src/mlia/utils/__init__.py b/src/mlia/utils/__init__.py
new file mode 100644
index 0000000..ecb5ca1
--- /dev/null
+++ b/src/mlia/utils/__init__.py
@@ -0,0 +1,3 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Utils module."""
diff --git a/src/mlia/utils/console.py b/src/mlia/utils/console.py
new file mode 100644
index 0000000..7cb3d83
--- /dev/null
+++ b/src/mlia/utils/console.py
@@ -0,0 +1,97 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Console output utility functions."""
+from typing import Iterable
+from typing import List
+from typing import Optional
+
+from rich.console import Console
+from rich.console import RenderableType
+from rich.table import box
+from rich.table import Table
+from rich.text import Text
+
+
+def create_section_header(
+ section_name: Optional[str] = None, length: int = 80, sep: str = "-"
+) -> str:
+ """Return section header."""
+ if not section_name:
+ content = sep * length
+ else:
+ before = 3
+ spaces = 2
+ after = length - (len(section_name) + before + spaces)
+ if after < 0:
+ raise ValueError("Section name too long")
+ content = f"{sep * before} {section_name} {sep * after}"
+
+ return f"\n{content}\n"
+
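+# Editor's note: create_section_header("Operators") produces a line like
+# "--- Operators ---...---" padded with the separator to `length` characters,
+# while create_section_header() with no name returns a full separator line.
+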
+
+def apply_style(value: str, style: str) -> str:
+ """Apply style to the value."""
+ return f"[{style}]{value}"
+
+
+def style_improvement(result: bool) -> str:
+ """Return different text style based on result."""
+ return "green" if result else "yellow"
+
+
+def produce_table(
+ rows: Iterable,
+ headers: Optional[List[str]] = None,
+ table_style: str = "default",
+) -> str:
+ """Represent data in tabular form."""
+ table = _get_table(table_style)
+
+ if headers:
+ table.show_header = True
+ for header in headers:
+ table.add_column(header)
+
+ for row in rows:
+ table.add_row(*row)
+
+ return _convert_to_text(table)
+
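+# A minimal usage sketch (editor's note, not part of the original change;
+# the row values and headers are illustrative):
+#
+#     text = produce_table(
+#         [("conv2d", "yes"), ("softmax", "no")],
+#         headers=["Operator", "Runs on NPU"],
+#     )
+#     print(text)
+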
+
+def _get_table(table_style: str) -> Table:
+ """Get Table instance for the provided style."""
+ if table_style == "default":
+ return Table(
+ show_header=False,
+ show_lines=True,
+ box=box.SQUARE_DOUBLE_HEAD,
+ )
+
+ if table_style == "nested":
+ return Table(
+ show_header=False,
+ box=None,
+ padding=(0, 1, 1, 0),
+ )
+
+ if table_style == "no_borders":
+ return Table(show_header=False, box=None)
+
+ raise Exception(f"Unsupported table style {table_style}")
+
+
+def _convert_to_text(*renderables: RenderableType) -> str:
+ """Convert renderable object to text."""
+ console = Console()
+ with console.capture() as capture:
+ for item in renderables:
+ console.print(item)
+
+ text = capture.get()
+ return text.rstrip()
+
+
+def remove_ascii_codes(value: str) -> str:
+ """Decode and remove ASCII codes."""
+ text = Text.from_ansi(value)
+ return text.plain
diff --git a/src/mlia/utils/download.py b/src/mlia/utils/download.py
new file mode 100644
index 0000000..4658738
--- /dev/null
+++ b/src/mlia/utils/download.py
@@ -0,0 +1,89 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Utils for files downloading."""
+from dataclasses import dataclass
+from pathlib import Path
+from typing import Iterable
+from typing import List
+from typing import Optional
+
+import requests
+from rich.progress import BarColumn
+from rich.progress import DownloadColumn
+from rich.progress import FileSizeColumn
+from rich.progress import Progress
+from rich.progress import ProgressColumn
+from rich.progress import TextColumn
+
+from mlia.utils.filesystem import sha256
+from mlia.utils.types import parse_int
+
+
+def download_progress(
+ content_chunks: Iterable[bytes], content_length: Optional[int], label: Optional[str]
+) -> Iterable[bytes]:
+ """Show progress info while reading content."""
+ columns: List[ProgressColumn] = [TextColumn("{task.description}")]
+
+ if content_length is None:
+ total = float("inf")
+ columns.append(FileSizeColumn())
+ else:
+ total = content_length
+ columns.extend([BarColumn(), DownloadColumn(binary_units=True)])
+
+ with Progress(*columns) as progress:
+ task = progress.add_task(label or "Downloading", total=total)
+
+ for chunk in content_chunks:
+ progress.update(task, advance=len(chunk))
+ yield chunk
+
+
+def download(
+ url: str,
+ dest: Path,
+ show_progress: bool = False,
+ label: Optional[str] = None,
+ chunk_size: int = 8192,
+) -> None:
+ """Download the file."""
+ with requests.get(url, stream=True) as resp:
+ resp.raise_for_status()
+ content_chunks = resp.iter_content(chunk_size=chunk_size)
+
+ if show_progress:
+ content_length = parse_int(resp.headers.get("Content-Length"))
+ content_chunks = download_progress(content_chunks, content_length, label)
+
+ with open(dest, "wb") as file:
+ for chunk in content_chunks:
+ file.write(chunk)
+
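+# Editor's sketch (not part of the original change; the URL and destination
+# are illustrative assumptions):
+#
+#     download(
+#         "https://example.com/model.tflite",
+#         Path("model.tflite"),
+#         show_progress=True,
+#         label="Downloading model",
+#     )
+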
+
+@dataclass
+class DownloadArtifact:
+ """Download artifact attributes."""
+
+ name: str
+ url: str
+ filename: str
+ version: str
+ sha256_hash: str
+
+ def download_to(self, dest_dir: Path, show_progress: bool = True) -> Path:
+ """Download artifact into destination directory."""
+ if (dest := dest_dir / self.filename).exists():
+ raise ValueError(f"{dest} already exists")
+
+ download(
+ self.url,
+ dest,
+ show_progress=show_progress,
+ label=f"Downloading {self.name} ver. {self.version}",
+ )
+
+ if sha256(dest) != self.sha256_hash:
+ raise ValueError("Digests do not match")
+
+ return dest
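+
+# Editor's sketch (not part of the original change; all field values below
+# are illustrative assumptions):
+#
+#     artifact = DownloadArtifact(
+#         name="example package",
+#         url="https://example.com/pkg.tar.gz",
+#         filename="pkg.tar.gz",
+#         version="1.0.0",
+#         sha256_hash="<expected sha256 hex digest>",
+#     )
+#     dest = artifact.download_to(Path("downloads"))
+#     # raises ValueError if the destination exists or the digest mismatches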
diff --git a/src/mlia/utils/filesystem.py b/src/mlia/utils/filesystem.py
new file mode 100644
index 0000000..73a88d9
--- /dev/null
+++ b/src/mlia/utils/filesystem.py
@@ -0,0 +1,124 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Utils related to file management."""
+import hashlib
+import importlib.resources as pkg_resources
+import json
+import os
+import shutil
+from contextlib import contextmanager
+from pathlib import Path
+from tempfile import mkstemp
+from tempfile import TemporaryDirectory
+from typing import Any
+from typing import Dict
+from typing import Generator
+from typing import Iterable
+from typing import List
+from typing import Optional
+from typing import Union
+
+
+def get_mlia_resources() -> Path:
+ """Get the path to the resources directory."""
+ with pkg_resources.path("mlia", "__init__.py") as init_path:
+ project_root = init_path.parent
+ return project_root / "resources"
+
+
+def get_vela_config() -> Path:
+ """Get the path to the default Vela config file."""
+ return get_mlia_resources() / "vela/vela.ini"
+
+
+def get_profiles_file() -> Path:
+ """Get the Ethos-U profiles file."""
+ return get_mlia_resources() / "profiles.json"
+
+
+def get_profiles_data() -> Dict[str, Dict[str, Any]]:
+ """Get the Ethos-U profile values as a dictionary."""
+ with open(get_profiles_file(), encoding="utf-8") as json_file:
+ profiles = json.load(json_file)
+
+ if not isinstance(profiles, dict):
+ raise Exception("Profiles data format is not valid")
+
+ return profiles
+
+
+def get_profile(target: str) -> Dict[str, Any]:
+ """Get settings for the provided target profile."""
+ profiles = get_profiles_data()
+
+ if target not in profiles:
+ raise Exception(f"Unable to find target profile {target}")
+
+ return profiles[target]
+
+
+def get_supported_profile_names() -> List[str]:
+ """Get the supported Ethos-U profile names."""
+ return list(get_profiles_data().keys())
+
+
+@contextmanager
+def temp_file(suffix: Optional[str] = None) -> Generator[Path, None, None]:
+ """Create temp file and remove it after."""
+ _, tmp_file = mkstemp(suffix=suffix)
+
+ try:
+ yield Path(tmp_file)
+ finally:
+ os.remove(tmp_file)
+
+
+@contextmanager
+def temp_directory(suffix: Optional[str] = None) -> Generator[Path, None, None]:
+ """Create temp directory and remove it after."""
+ with TemporaryDirectory(suffix=suffix) as tmpdir:
+ yield Path(tmpdir)
+
+
+def file_chunks(
+ filepath: Union[Path, str], chunk_size: int = 4096
+) -> Generator[bytes, None, None]:
+ """Return sequence of the file chunks."""
+ with open(filepath, "rb") as file:
+ while data := file.read(chunk_size):
+ yield data
+
+
+def hexdigest(filepath: Union[Path, str], hash_obj: "hashlib._Hash") -> str:
+ """Return hex digest of the file."""
+ for chunk in file_chunks(filepath):
+ hash_obj.update(chunk)
+
+ return hash_obj.hexdigest()
+
+
+def sha256(filepath: Path) -> str:
+ """Return SHA256 hash of the file."""
+ return hexdigest(filepath, hashlib.sha256())
+
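+# Editor's note: hexdigest() streams the file in chunks via file_chunks(),
+# so even large files can be hashed with constant memory, e.g.:
+#
+#     digest = sha256(Path("model.tflite"))  # path is an illustrative assumption
+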
+
+def all_files_exist(paths: Iterable[Path]) -> bool:
+ """Check if all files are exist."""
+ return all(item.is_file() for item in paths)
+
+
+def all_paths_valid(paths: Iterable[Path]) -> bool:
+ """Check if all paths are valid."""
+ return all(item.exists() for item in paths)
+
+
+def copy_all(*paths: Path, dest: Path) -> None:
+ """Copy files/directories into destination folder."""
+ dest.mkdir(exist_ok=True)
+
+ for path in paths:
+ if path.is_file():
+ shutil.copy2(path, dest)
+
+ if path.is_dir():
+ shutil.copytree(path, dest, dirs_exist_ok=True)
diff --git a/src/mlia/utils/logging.py b/src/mlia/utils/logging.py
new file mode 100644
index 0000000..86d7567
--- /dev/null
+++ b/src/mlia/utils/logging.py
@@ -0,0 +1,120 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Logging utility functions."""
+import logging
+from contextlib import contextmanager
+from contextlib import ExitStack
+from contextlib import redirect_stderr
+from contextlib import redirect_stdout
+from pathlib import Path
+from typing import Any
+from typing import Callable
+from typing import Generator
+from typing import List
+from typing import Optional
+
+
+class LoggerWriter:
+ """Redirect printed messages to the logger."""
+
+ def __init__(self, logger: logging.Logger, level: int):
+ """Init logger writer."""
+ self.logger = logger
+ self.level = level
+
+ def write(self, message: str) -> None:
+ """Write message."""
+ if message.strip() != "":
+ self.logger.log(self.level, message)
+
+ def flush(self) -> None:
+ """Flush buffers."""
+
+
+@contextmanager
+def redirect_output(
+ logger: logging.Logger,
+ stdout_level: int = logging.INFO,
+ stderr_level: int = logging.INFO,
+) -> Generator[None, None, None]:
+ """Redirect standard output to the logger."""
+ stdout_to_log = LoggerWriter(logger, stdout_level)
+ stderr_to_log = LoggerWriter(logger, stderr_level)
+
+ with ExitStack() as exit_stack:
+ exit_stack.enter_context(redirect_stdout(stdout_to_log)) # type: ignore
+ exit_stack.enter_context(redirect_stderr(stderr_to_log)) # type: ignore
+
+ yield
+
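+# A minimal usage sketch (editor's note, not part of the original change;
+# the logger name is illustrative):
+#
+#     example_logger = logging.getLogger("example")
+#     with redirect_output(example_logger, stdout_level=logging.DEBUG):
+#         print("captured by the logger instead of the console")
+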
+
+class LogFilter(logging.Filter):
+ """Configurable log filter."""
+
+ def __init__(self, log_record_filter: Callable[[logging.LogRecord], bool]) -> None:
+ """Init log filter instance."""
+ super().__init__()
+ self.log_record_filter = log_record_filter
+
+ def filter(self, record: logging.LogRecord) -> bool:
+ """Filter log messages."""
+ return self.log_record_filter(record)
+
+ @classmethod
+ def equals(cls, log_level: int) -> "LogFilter":
+ """Return log filter that filters messages by log level."""
+
+ def filter_by_level(log_record: logging.LogRecord) -> bool:
+ return log_record.levelno == log_level
+
+ return cls(filter_by_level)
+
+ @classmethod
+ def skip(cls, log_level: int) -> "LogFilter":
+ """Return log filter that skips messages with particular level."""
+
+ def skip_by_level(log_record: logging.LogRecord) -> bool:
+ return log_record.levelno != log_level
+
+ return cls(skip_by_level)
+
+
+def create_log_handler(
+ *,
+ file_path: Optional[Path] = None,
+ stream: Optional[Any] = None,
+ log_level: Optional[int] = None,
+ log_format: Optional[str] = None,
+ log_filter: Optional[logging.Filter] = None,
+ delay: bool = True,
+) -> logging.Handler:
+ """Create logger handler."""
+ handler: Optional[logging.Handler] = None
+
+ if file_path is not None:
+ handler = logging.FileHandler(file_path, delay=delay)
+ elif stream is not None:
+ handler = logging.StreamHandler(stream)
+
+ if handler is None:
+ raise Exception("Unable to create logging handler")
+
+ if log_level:
+ handler.setLevel(log_level)
+
+ if log_format:
+ handler.setFormatter(logging.Formatter(log_format))
+
+ if log_filter:
+ handler.addFilter(log_filter)
+
+ return handler
+
+
+def attach_handlers(
+ handlers: List[logging.Handler], loggers: List[logging.Logger]
+) -> None:
+ """Attach handlers to the loggers."""
+ for handler in handlers:
+ for logger in loggers:
+ logger.addHandler(handler)
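+
+# Editor's sketch (not part of the original change; assumes `import sys`,
+# and the logger name and file path are illustrative):
+#
+#     handlers = [
+#         create_log_handler(stream=sys.stdout, log_level=logging.INFO),
+#         create_log_handler(file_path=Path("mlia.log"), log_level=logging.DEBUG),
+#     ]
+#     attach_handlers(handlers, [logging.getLogger("mlia")])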
diff --git a/src/mlia/utils/misc.py b/src/mlia/utils/misc.py
new file mode 100644
index 0000000..de95448
--- /dev/null
+++ b/src/mlia/utils/misc.py
@@ -0,0 +1,9 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Various util functions."""
+
+
+def yes(prompt: str) -> bool:
+ """Return true if user confirms the action."""
+ response = input(f"{prompt} [y/n]: ")
+ return response in ["y", "Y"]
diff --git a/src/mlia/utils/proc.py b/src/mlia/utils/proc.py
new file mode 100644
index 0000000..39aca43
--- /dev/null
+++ b/src/mlia/utils/proc.py
@@ -0,0 +1,164 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Utils related to process management."""
+import os
+import signal
+import subprocess
+import time
+from abc import ABC
+from abc import abstractmethod
+from contextlib import contextmanager
+from contextlib import suppress
+from pathlib import Path
+from typing import Any
+from typing import Generator
+from typing import Iterable
+from typing import List
+from typing import Optional
+from typing import Tuple
+
+
+class OutputConsumer(ABC):
+ """Base class for the output consumers."""
+
+ @abstractmethod
+ def feed(self, line: str) -> None:
+ """Feed new line to the consumerr."""
+
+
+class RunningCommand:
+ """Running command."""
+
+ def __init__(self, process: subprocess.Popen) -> None:
+ """Init running command instance."""
+ self.process = process
+ self._output_consumers: Optional[List[OutputConsumer]] = None
+
+ def is_alive(self) -> bool:
+ """Return true if process is still alive."""
+ return self.process.poll() is None
+
+ def exit_code(self) -> Optional[int]:
+ """Return process's return code."""
+ return self.process.poll()
+
+ def stdout(self) -> Iterable[str]:
+ """Return std output of the process."""
+ assert self.process.stdout is not None
+
+ for line in self.process.stdout:
+ yield line
+
+ def kill(self) -> None:
+ """Kill the process."""
+ self.process.kill()
+
+ def send_signal(self, signal_num: int) -> None:
+ """Send signal to the process."""
+ self.process.send_signal(signal_num)
+
+ @property
+ def output_consumers(self) -> Optional[List[OutputConsumer]]:
+ """Property output_consumers."""
+ return self._output_consumers
+
+ @output_consumers.setter
+ def output_consumers(self, output_consumers: List[OutputConsumer]) -> None:
+ """Set output consumers."""
+ self._output_consumers = output_consumers
+
+ def consume_output(self) -> None:
+ """Pass program's output to the consumers."""
+ if self.process is None or self.output_consumers is None:
+ return
+
+ for line in self.stdout():
+ for consumer in self.output_consumers:
+ # ignore errors raised by an individual consumer
+ with suppress(Exception):
+ consumer.feed(line)
+
+ def stop(
+ self, wait: bool = True, num_of_attempts: int = 5, interval: float = 0.5
+ ) -> None:
+ """Stop execution."""
+ try:
+ if not self.is_alive():
+ return
+
+ self.process.send_signal(signal.SIGINT)
+ self.consume_output()
+
+ if not wait:
+ return
+
+ for _ in range(num_of_attempts):
+ time.sleep(interval)
+ if not self.is_alive():
+ break
+ else:
+ raise Exception("Unable to stop running command")
+ finally:
+ self._close_fd()
+
+ def _close_fd(self) -> None:
+ """Close file descriptors."""
+
+ def close(file_descriptor: Any) -> None:
+ """Check and close file."""
+ if file_descriptor is not None and hasattr(file_descriptor, "close"):
+ file_descriptor.close()
+
+ close(self.process.stdout)
+ close(self.process.stderr)
+
+ def wait(self, redirect_output: bool = False) -> None:
+ """Redirect process output to stdout and wait for completion."""
+ if redirect_output:
+ for line in self.stdout():
+ print(line, end="")
+
+ self.process.wait()
+
+
+class CommandExecutor:
+ """Command executor."""
+
+ @staticmethod
+ def execute(command: List[str]) -> Tuple[int, bytes, bytes]:
+ """Execute the command."""
+ result = subprocess.run(
+ command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, check=True
+ )
+
+ return (result.returncode, result.stdout, result.stderr)
+
+ @staticmethod
+ def submit(command: List[str]) -> RunningCommand:
+ """Submit command for the execution."""
+ process = subprocess.Popen( # pylint: disable=consider-using-with
+ command,
+ stdout=subprocess.PIPE,
+ stderr=subprocess.STDOUT, # redirect command stderr to stdout
+ universal_newlines=True,
+ bufsize=1,
+ )
+
+ return RunningCommand(process)
+
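+# Editor's sketch (not part of the original change; the command is
+# illustrative):
+#
+#     running = CommandExecutor.submit(["echo", "hello"])
+#     running.wait(redirect_output=True)  # echoes output and waits for completion
+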
+
+@contextmanager
+def working_directory(
+ working_dir: Path, create_dir: bool = False
+) -> Generator[Path, None, None]:
+ """Temporary change working directory."""
+ current_working_dir = Path.cwd()
+
+ if create_dir:
+ working_dir.mkdir()
+
+ os.chdir(working_dir)
+
+ try:
+ yield working_dir
+ finally:
+ os.chdir(current_working_dir)
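+
+# Editor's sketch (not part of the original change; the directory and
+# command are illustrative):
+#
+#     with working_directory(Path("build"), create_dir=True):
+#         CommandExecutor.execute(["make"])  # runs with "build" as the cwd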
diff --git a/src/mlia/utils/types.py b/src/mlia/utils/types.py
new file mode 100644
index 0000000..9b63928
--- /dev/null
+++ b/src/mlia/utils/types.py
@@ -0,0 +1,37 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Types related utility functions."""
+from typing import Any
+from typing import Optional
+
+
+def is_list_of(data: Any, cls: type, elem_num: Optional[int] = None) -> bool:
+ """Check if data is a list of object of the same class."""
+ return (
+ isinstance(data, (tuple, list))
+ and all(isinstance(item, cls) for item in data)
+ and (elem_num is None or len(data) == elem_num)
+ )
+
+
+def is_number(value: str) -> bool:
+ """Return true if string contains a number."""
+ try:
+ float(value)
+ except ValueError:
+ return False
+
+ return True
+
+
+def parse_int(value: Any, default: Optional[int] = None) -> Optional[int]:
+ """Parse integer value."""
+ try:
+ return int(value)
+ except (TypeError, ValueError):
+ return default
+
+
+def only_one_selected(*options: bool) -> bool:
+ """Return true if only one True value found."""
+ return sum(options) == 1
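+
+# Editor's note, illustrative examples (not part of the original change):
+#
+#     is_list_of([1, 2, 3], int)        # True
+#     is_list_of([1, "2"], int)         # False: mixed element types
+#     is_number("3.14")                 # True
+#     parse_int("42")                   # 42
+#     parse_int(None, default=0)        # 0
+#     only_one_selected(True, False)    # True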
diff --git a/tests/__init__.py b/tests/__init__.py
new file mode 100644
index 0000000..4a1e153
--- /dev/null
+++ b/tests/__init__.py
@@ -0,0 +1,3 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Tests module."""
diff --git a/tests/aiet/__init__.py b/tests/aiet/__init__.py
new file mode 100644
index 0000000..873a7df
--- /dev/null
+++ b/tests/aiet/__init__.py
@@ -0,0 +1,3 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""AIET tests module."""
diff --git a/tests/aiet/conftest.py b/tests/aiet/conftest.py
new file mode 100644
index 0000000..cab3dc2
--- /dev/null
+++ b/tests/aiet/conftest.py
@@ -0,0 +1,139 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+# pylint: disable=redefined-outer-name
+"""conftest for pytest."""
+import shutil
+import tarfile
+from pathlib import Path
+from typing import Any
+
+import pytest
+from click.testing import CliRunner
+
+from aiet.backend.common import get_backend_configs
+
+
+@pytest.fixture(scope="session")
+def test_systems_path(test_resources_path: Path) -> Path:
+ """Return test systems path in a pytest fixture."""
+ return test_resources_path / "systems"
+
+
+@pytest.fixture(scope="session")
+def test_applications_path(test_resources_path: Path) -> Path:
+ """Return test applications path in a pytest fixture."""
+ return test_resources_path / "applications"
+
+
+@pytest.fixture(scope="session")
+def test_tools_path(test_resources_path: Path) -> Path:
+ """Return test tools path in a pytest fixture."""
+ return test_resources_path / "tools"
+
+
+@pytest.fixture(scope="session")
+def test_resources_path() -> Path:
+ """Return test resources path in a pytest fixture."""
+ current_path = Path(__file__).parent.absolute()
+ return current_path / "test_resources"
+
+
+@pytest.fixture(scope="session")
+def non_optimised_input_model_file(test_tflite_model: Path) -> Path:
+ """Provide the path to a quantized dummy model file."""
+ return test_tflite_model
+
+
+@pytest.fixture(scope="session")
+def optimised_input_model_file(test_tflite_vela_model: Path) -> Path:
+ """Provide path to Vela-optimised dummy model file."""
+ return test_tflite_vela_model
+
+
+@pytest.fixture(scope="session")
+def invalid_input_model_file(test_tflite_invalid_model: Path) -> Path:
+ """Provide the path to an invalid dummy model file."""
+ return test_tflite_invalid_model
+
+
+@pytest.fixture(autouse=True)
+def test_resources(monkeypatch: pytest.MonkeyPatch, test_resources_path: Path) -> Any:
+ """Force using test resources as middleware's repository."""
+
+ def get_test_resources() -> Path:
+ """Return path to the test resources."""
+ return test_resources_path
+
+ monkeypatch.setattr("aiet.utils.fs.get_aiet_resources", get_test_resources)
+ yield
+
+
+@pytest.fixture(scope="session", autouse=True)
+def add_tools(test_resources_path: Path) -> Any:
+ """Symlink the tools from the original resources path to the test resources path."""
+ # tool_dirs = get_available_tool_directory_names()
+ tool_dirs = [cfg.parent for cfg in get_backend_configs("tools")]
+
+ links = {
+ src_dir: (test_resources_path / "tools" / src_dir.name) for src_dir in tool_dirs
+ }
+ for src_dir, dst_dir in links.items():
+ if not dst_dir.exists():
+ dst_dir.symlink_to(src_dir, target_is_directory=True)
+ yield
+ # Remove symlinks
+ for dst_dir in links.values():
+ if dst_dir.is_symlink():
+ dst_dir.unlink()
+
+
+def create_archive(
+ archive_name: str, source: Path, destination: Path, with_root_folder: bool = False
+) -> None:
+ """Create archive from directory source."""
+ with tarfile.open(destination / archive_name, mode="w:gz") as tar:
+ for item in source.iterdir():
+ item_name = item.name
+ if with_root_folder:
+ item_name = f"{source.name}/{item_name}"
+ tar.add(item, item_name)
+
+
+def process_directory(source: Path, destination: Path) -> None:
+ """Process resource directory."""
+ destination.mkdir()
+
+ for item in source.iterdir():
+ if item.is_dir():
+ create_archive(f"{item.name}.tar.gz", item, destination)
+ create_archive(f"{item.name}_dir.tar.gz", item, destination, True)
+
+
+@pytest.fixture(scope="session", autouse=True)
+def add_archives(
+ test_resources_path: Path, tmp_path_factory: pytest.TempPathFactory
+) -> Any:
+ """Generate archives of the test resources."""
+ tmp_path = tmp_path_factory.mktemp("archives")
+
+ archives_path = tmp_path / "archives"
+ archives_path.mkdir()
+
+ if (archives_path_link := test_resources_path / "archives").is_symlink():
+ archives_path_link.unlink()
+
+ archives_path_link.symlink_to(archives_path, target_is_directory=True)
+
+ for item in ["applications", "systems"]:
+ process_directory(test_resources_path / item, archives_path / item)
+
+ yield
+
+ archives_path_link.unlink()
+ shutil.rmtree(tmp_path)
+
+
+@pytest.fixture(scope="module")
+def cli_runner() -> CliRunner:
+ """Return CliRunner instance in a pytest fixture."""
+ return CliRunner()
diff --git a/tests/aiet/test_backend_application.py b/tests/aiet/test_backend_application.py
new file mode 100644
index 0000000..abfab00
--- /dev/null
+++ b/tests/aiet/test_backend_application.py
@@ -0,0 +1,452 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+# pylint: disable=no-self-use
+"""Tests for the application backend."""
+from collections import Counter
+from contextlib import ExitStack as does_not_raise
+from pathlib import Path
+from typing import Any
+from typing import List
+from unittest.mock import MagicMock
+
+import pytest
+
+from aiet.backend.application import Application
+from aiet.backend.application import get_application
+from aiet.backend.application import get_available_application_directory_names
+from aiet.backend.application import get_available_applications
+from aiet.backend.application import get_unique_application_names
+from aiet.backend.application import install_application
+from aiet.backend.application import load_applications
+from aiet.backend.application import remove_application
+from aiet.backend.common import Command
+from aiet.backend.common import DataPaths
+from aiet.backend.common import Param
+from aiet.backend.common import UserParamConfig
+from aiet.backend.config import ApplicationConfig
+from aiet.backend.config import ExtendedApplicationConfig
+from aiet.backend.config import NamedExecutionConfig
+
+
+def test_get_available_application_directory_names() -> None:
+ """Test get_available_applicationss mocking get_resources."""
+ directory_names = get_available_application_directory_names()
+ assert Counter(directory_names) == Counter(
+ ["application1", "application2", "application4", "application5"]
+ )
+
+
+def test_get_available_applications() -> None:
+ """Test get_available_applicationss mocking get_resources."""
+ available_applications = get_available_applications()
+
+ assert all(isinstance(s, Application) for s in available_applications)
+ assert all(s != 42 for s in available_applications)
+ assert len(available_applications) == 9
+ # application_5 has multiple items with multiple supported systems
+ assert [str(s) for s in available_applications] == [
+ "application_1",
+ "application_2",
+ "application_4",
+ "application_5",
+ "application_5",
+ "application_5A",
+ "application_5A",
+ "application_5B",
+ "application_5B",
+ ]
+
+
+def test_get_unique_application_names() -> None:
+ """Test get_unique_application_names."""
+ unique_names = get_unique_application_names()
+
+ assert all(isinstance(s, str) for s in unique_names)
+ assert all(s for s in unique_names)
+ assert sorted(unique_names) == [
+ "application_1",
+ "application_2",
+ "application_4",
+ "application_5",
+ "application_5A",
+ "application_5B",
+ ]
+
+
+def test_get_application() -> None:
+ """Test get_application mocking get_resoures."""
+ application = get_application("application_1")
+ if len(application) != 1:
+ pytest.fail("Unable to get application")
+ assert application[0].name == "application_1"
+
+ application = get_application("unknown application")
+ assert len(application) == 0
+
+
+@pytest.mark.parametrize(
+ "source, call_count, expected_exception",
+ (
+ (
+ "archives/applications/application1.tar.gz",
+ 0,
+ pytest.raises(
+ Exception, match=r"Applications \[application_1\] are already installed"
+ ),
+ ),
+ (
+ "various/applications/application_with_empty_config",
+ 0,
+ pytest.raises(Exception, match="No application definition found"),
+ ),
+ (
+ "various/applications/application_with_wrong_config1",
+ 0,
+ pytest.raises(Exception, match="Unable to read application definition"),
+ ),
+ (
+ "various/applications/application_with_wrong_config2",
+ 0,
+ pytest.raises(Exception, match="Unable to read application definition"),
+ ),
+ (
+ "various/applications/application_with_wrong_config3",
+ 0,
+ pytest.raises(Exception, match="Unable to read application definition"),
+ ),
+ ("various/applications/application_with_valid_config", 1, does_not_raise()),
+ (
+ "archives/applications/application3.tar.gz",
+ 0,
+ pytest.raises(Exception, match="Unable to read application definition"),
+ ),
+ (
+ "applications/application1",
+ 0,
+ pytest.raises(
+ Exception, match=r"Applications \[application_1\] are already installed"
+ ),
+ ),
+ (
+ "applications/application3",
+ 0,
+ pytest.raises(Exception, match="Unable to read application definition"),
+ ),
+ ),
+)
+def test_install_application(
+ monkeypatch: Any,
+ test_resources_path: Path,
+ source: str,
+ call_count: int,
+ expected_exception: Any,
+) -> None:
+ """Test application install from archive."""
+ mock_create_destination_and_install = MagicMock()
+ monkeypatch.setattr(
+ "aiet.backend.application.create_destination_and_install",
+ mock_create_destination_and_install,
+ )
+
+ with expected_exception:
+ install_application(test_resources_path / source)
+ assert mock_create_destination_and_install.call_count == call_count
+
+
+def test_remove_application(monkeypatch: Any) -> None:
+ """Test application removal."""
+ mock_remove_backend = MagicMock()
+ monkeypatch.setattr("aiet.backend.application.remove_backend", mock_remove_backend)
+
+ remove_application("some_application_directory")
+ mock_remove_backend.assert_called_once()
+
+
+def test_application_config_without_commands() -> None:
+ """Test application config without commands."""
+ config = ApplicationConfig(name="application")
+ application = Application(config)
+ # pylint: disable=use-implicit-booleaness-not-comparison
+ assert application.commands == {}
+
+
+class TestApplication:
+ """Test for application class methods."""
+
+ def test___eq__(self) -> None:
+ """Test overloaded __eq__ method."""
+ config = ApplicationConfig(
+ # Application
+ supported_systems=["system1", "system2"],
+ build_dir="build_dir",
+ # inherited from Backend
+ name="name",
+ description="description",
+ commands={},
+ )
+ application1 = Application(config)
+ application2 = Application(config) # Identical
+ assert application1 == application2
+
+ application3 = Application(config) # changed
+ # Change one single attribute so not equal, but same Type
+ setattr(application3, "supported_systems", ["somewhere/else"])
+ assert application1 != application3
+
+ # different Type
+ application4 = "Not the Application you are looking for"
+ assert application1 != application4
+
+ application5 = Application(config)
+ # supported systems could be in any order
+ setattr(application5, "supported_systems", ["system2", "system1"])
+ assert application1 == application5
+
+ def test_can_run_on(self) -> None:
+ """Test Application can run on."""
+ config = ApplicationConfig(name="application", supported_systems=["System-A"])
+
+ application = Application(config)
+ assert application.can_run_on("System-A")
+ assert not application.can_run_on("System-B")
+
+ applications = get_application("application_1", "System 1")
+ assert len(applications) == 1
+ assert applications[0].can_run_on("System 1")
+
+ def test_get_deploy_data(self, tmp_path: Path) -> None:
+ """Test Application can run on."""
+ src, dest = "src", "dest"
+ config = ApplicationConfig(
+ name="application", deploy_data=[(src, dest)], config_location=tmp_path
+ )
+ src_path = tmp_path / src
+ src_path.mkdir()
+ application = Application(config)
+ assert application.get_deploy_data() == [DataPaths(src_path, dest)]
+
+ def test_get_deploy_data_no_config_location(self) -> None:
+ """Test that getting deploy data fails if no config location provided."""
+ with pytest.raises(
+ Exception, match="Unable to get application .* config location"
+ ):
+ Application(ApplicationConfig(name="application")).get_deploy_data()
+
+ def test_unable_to_create_application_without_name(self) -> None:
+ """Test that it is not possible to create application without name."""
+ with pytest.raises(Exception, match="Name is empty"):
+ Application(ApplicationConfig())
+
+ def test_application_config_without_commands(self) -> None:
+ """Test application config without commands."""
+ config = ApplicationConfig(name="application")
+ application = Application(config)
+ # pylint: disable=use-implicit-booleaness-not-comparison
+ assert application.commands == {}
+
+ @pytest.mark.parametrize(
+ "config, expected_params",
+ (
+ (
+ ApplicationConfig(
+ name="application",
+ commands={"command": ["cmd {user_params:0} {user_params:1}"]},
+ user_params={
+ "command": [
+ UserParamConfig(
+ name="--param1", description="param1", alias="param1"
+ ),
+ UserParamConfig(
+ name="--param2", description="param2", alias="param2"
+ ),
+ ]
+ },
+ ),
+ [Param("--param1", "param1"), Param("--param2", "param2")],
+ ),
+ (
+ ApplicationConfig(
+ name="application",
+ commands={"command": ["cmd {user_params:param1} {user_params:1}"]},
+ user_params={
+ "command": [
+ UserParamConfig(
+ name="--param1", description="param1", alias="param1"
+ ),
+ UserParamConfig(
+ name="--param2", description="param2", alias="param2"
+ ),
+ ]
+ },
+ ),
+ [Param("--param1", "param1"), Param("--param2", "param2")],
+ ),
+ (
+ ApplicationConfig(
+ name="application",
+ commands={"command": ["cmd {user_params:param1}"]},
+ user_params={
+ "command": [
+ UserParamConfig(
+ name="--param1", description="param1", alias="param1"
+ ),
+ UserParamConfig(
+ name="--param2", description="param2", alias="param2"
+ ),
+ ]
+ },
+ ),
+ [Param("--param1", "param1")],
+ ),
+ ),
+ )
+ def test_remove_unused_params(
+ self, config: ApplicationConfig, expected_params: List[Param]
+ ) -> None:
+ """Test mod remove_unused_parameter."""
+ application = Application(config)
+ application.remove_unused_params()
+ assert application.commands["command"].params == expected_params
+
+
+@pytest.mark.parametrize(
+ "config, expected_error",
+ (
+ (
+ ExtendedApplicationConfig(name="application"),
+ pytest.raises(Exception, match="No supported systems definition provided"),
+ ),
+ (
+ ExtendedApplicationConfig(
+ name="application", supported_systems=[NamedExecutionConfig(name="")]
+ ),
+ pytest.raises(
+ Exception,
+ match="Unable to read supported system definition, name is missed",
+ ),
+ ),
+ (
+ ExtendedApplicationConfig(
+ name="application",
+ supported_systems=[
+ NamedExecutionConfig(
+ name="system",
+ commands={"command": ["cmd"]},
+ user_params={"command": [UserParamConfig(name="param")]},
+ )
+ ],
+ commands={"command": ["cmd {user_params:0}"]},
+ user_params={"command": [UserParamConfig(name="param")]},
+ ),
+ pytest.raises(
+ Exception, match="Default parameters for command .* should have aliases"
+ ),
+ ),
+ (
+ ExtendedApplicationConfig(
+ name="application",
+ supported_systems=[
+ NamedExecutionConfig(
+ name="system",
+ commands={"command": ["cmd"]},
+ user_params={"command": [UserParamConfig(name="param")]},
+ )
+ ],
+ commands={"command": ["cmd {user_params:0}"]},
+ user_params={"command": [UserParamConfig(name="param", alias="param")]},
+ ),
+ pytest.raises(
+ Exception, match="system parameters for command .* should have aliases"
+ ),
+ ),
+ ),
+)
+def test_load_application_exceptional_cases(
+ config: ExtendedApplicationConfig, expected_error: Any
+) -> None:
+ """Test exceptional cases for application load function."""
+ with expected_error:
+ load_applications(config)
+
+
+def test_load_application() -> None:
+ """Test application load function.
+
+ The main purpose of this test is to check the application configuration
+ for different systems. All configuration values should be correctly
+ overridden where needed.
+ """
+ application_5 = get_application("application_5")
+ assert len(application_5) == 2
+
+ default_commands = {
+ "build": Command(["default build command"]),
+ "run": Command(["default run command"]),
+ }
+ default_variables = {"var1": "value1", "var2": "value2"}
+
+ application_5_0 = application_5[0]
+ assert application_5_0.build_dir == "default_build_dir"
+ assert application_5_0.supported_systems == ["System 1"]
+ assert application_5_0.commands == default_commands
+ assert application_5_0.variables == default_variables
+ assert application_5_0.lock is False
+
+ application_5_1 = application_5[1]
+ assert application_5_1.build_dir == application_5_0.build_dir
+ assert application_5_1.supported_systems == ["System 2"]
+ assert application_5_1.commands == application_5_0.commands
+ assert application_5_1.variables == default_variables
+
+ application_5a = get_application("application_5A")
+ assert len(application_5a) == 2
+
+ application_5a_0 = application_5a[0]
+ assert application_5a_0.supported_systems == ["System 1"]
+ assert application_5a_0.build_dir == "build_5A"
+ assert application_5a_0.commands == default_commands
+ assert application_5a_0.variables == {"var1": "new value1", "var2": "value2"}
+ assert application_5a_0.lock is False
+
+ application_5a_1 = application_5a[1]
+ assert application_5a_1.supported_systems == ["System 2"]
+ assert application_5a_1.build_dir == "build"
+ assert application_5a_1.commands == {
+ "build": Command(["default build command"]),
+ "run": Command(["run command on system 2"]),
+ }
+ assert application_5a_1.variables == {"var1": "value1", "var2": "new value2"}
+ assert application_5a_1.lock is True
+
+ application_5b = get_application("application_5B")
+ assert len(application_5b) == 2
+
+ application_5b_0 = application_5b[0]
+ assert application_5b_0.build_dir == "build_5B"
+ assert application_5b_0.supported_systems == ["System 1"]
+ assert application_5b_0.commands == {
+ "build": Command(["default build command with value for var1 System1"], []),
+ "run": Command(["default run command with value for var2 System1"]),
+ }
+ assert "non_used_command" not in application_5b_0.commands
+
+ application_5b_1 = application_5b[1]
+ assert application_5b_1.build_dir == "build"
+ assert application_5b_1.supported_systems == ["System 2"]
+ assert application_5b_1.commands == {
+ "build": Command(
+ [
+ "build command on system 2 with value"
+ " for var1 System2 {user_params:param1}"
+ ],
+ [
+ Param(
+ "--param",
+ "Sample command param",
+ ["value1", "value2", "value3"],
+ "value1",
+ )
+ ],
+ ),
+ "run": Command(["run command on system 2"], []),
+ }
diff --git a/tests/aiet/test_backend_common.py b/tests/aiet/test_backend_common.py
new file mode 100644
index 0000000..12c30ec
--- /dev/null
+++ b/tests/aiet/test_backend_common.py
@@ -0,0 +1,486 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+# pylint: disable=no-self-use,protected-access
+"""Tests for the common backend module."""
+from contextlib import ExitStack as does_not_raise
+from pathlib import Path
+from typing import Any
+from typing import cast
+from typing import Dict
+from typing import IO
+from typing import List
+from typing import Optional
+from typing import Tuple
+from typing import Union
+from unittest.mock import MagicMock
+
+import pytest
+
+from aiet.backend.application import Application
+from aiet.backend.common import Backend
+from aiet.backend.common import BaseBackendConfig
+from aiet.backend.common import Command
+from aiet.backend.common import ConfigurationException
+from aiet.backend.common import load_config
+from aiet.backend.common import Param
+from aiet.backend.common import parse_raw_parameter
+from aiet.backend.common import remove_backend
+from aiet.backend.config import ApplicationConfig
+from aiet.backend.config import UserParamConfig
+from aiet.backend.execution import ExecutionContext
+from aiet.backend.execution import ParamResolver
+from aiet.backend.system import System
+
+
+@pytest.mark.parametrize(
+ "directory_name, expected_exception",
+ (
+ ("some_dir", does_not_raise()),
+ (None, pytest.raises(Exception, match="No directory name provided")),
+ ),
+)
+def test_remove_backend(
+ monkeypatch: Any, directory_name: str, expected_exception: Any
+) -> None:
+ """Test remove_backend function."""
+ mock_remove_resource = MagicMock()
+ monkeypatch.setattr("aiet.backend.common.remove_resource", mock_remove_resource)
+
+ with expected_exception:
+ remove_backend(directory_name, "applications")
+
+
+@pytest.mark.parametrize(
+ "filename, expected_exception",
+ (
+ ("application_config.json", does_not_raise()),
+ (None, pytest.raises(Exception, match="Unable to read config")),
+ ),
+)
+def test_load_config(
+ filename: str, expected_exception: Any, test_resources_path: Path, monkeypatch: Any
+) -> None:
+ """Test load_config."""
+ with expected_exception:
+ configs: List[Optional[Union[Path, IO[bytes]]]] = (
+ [None]
+ if not filename
+ else [
+ # Ignore pylint warning as 'with' can't be used inside of a
+ # generator expression.
+ # pylint: disable=consider-using-with
+ open(test_resources_path / filename, "rb"),
+ test_resources_path / filename,
+ ]
+ )
+ for config in configs:
+ json_mock = MagicMock()
+ monkeypatch.setattr("aiet.backend.common.json.load", json_mock)
+ load_config(config)
+ json_mock.assert_called_once()
+
+
+class TestBackend:
+ """Test Backend class."""
+
+ def test___repr__(self) -> None:
+ """Test the representation of Backend instance."""
+ backend = Backend(
+ BaseBackendConfig(name="Testing name", description="Testing description")
+ )
+ assert str(backend) == "Testing name"
+
+ def test__eq__(self) -> None:
+ """Test equality method with different cases."""
+ backend1 = Backend(BaseBackendConfig(name="name", description="description"))
+ backend1.commands = {"command": Command(["command"])}
+
+ backend2 = Backend(BaseBackendConfig(name="name", description="description"))
+ backend2.commands = {"command": Command(["command"])}
+
+ backend3 = Backend(
+ BaseBackendConfig(
+ name="Ben", description="This is not the Backend you are looking for"
+ )
+ )
+ backend3.commands = {"wave": Command(["wave hand"])}
+
+ backend4 = "Foo" # checking not isinstance(backend4, Backend)
+
+ assert backend1 == backend2
+ assert backend1 != backend3
+ assert backend1 != backend4
+
+ @pytest.mark.parametrize(
+ "parameter, valid",
+ [
+ ("--choice-param dummy_value_1", True),
+ ("--choice-param wrong_value", False),
+ ("--open-param something", True),
+ ("--wrong-param value", False),
+ ],
+ )
+ def test_validate_parameter(
+ self, parameter: str, valid: bool, test_resources_path: Path
+ ) -> None:
+ """Test validate_parameter."""
+ config = cast(
+ List[ApplicationConfig],
+ load_config(test_resources_path / "hello_world.json"),
+ )
+ # The application configuration is a list of configurations, so we
+ # only need the first one.
+ # Exercise validate_parameter using the Application class, which
+ # inherits from Backend.
+ application = Application(config[0])
+ assert application.validate_parameter("run", parameter) == valid
+
+ def test_validate_parameter_with_invalid_command(
+ self, test_resources_path: Path
+ ) -> None:
+ """Test validate_parameter with an invalid command_name."""
+ config = cast(
+ List[ApplicationConfig],
+ load_config(test_resources_path / "hello_world.json"),
+ )
+ application = Application(config[0])
+ with pytest.raises(AttributeError) as err:
+ # command foo does not exist, so raise an error
+ application.validate_parameter("foo", "bar")
+ assert "Unknown command: 'foo'" in str(err.value)
+
+ def test_build_command(self, monkeypatch: Any) -> None:
+ """Test command building."""
+ config = {
+ "name": "test",
+ "commands": {
+ "build": ["build {user_params:0} {user_params:1}"],
+ "run": ["run {user_params:0}"],
+ "post_run": ["post_run {application_params:0} on {system_params:0}"],
+ "some_command": ["Command with {variables:var_A}"],
+ "empty_command": [""],
+ },
+ "user_params": {
+ "build": [
+ {
+ "name": "choice_param_0=",
+ "values": [1, 2, 3],
+ "default_value": 1,
+ },
+ {"name": "choice_param_1", "values": [3, 4, 5], "default_value": 3},
+ {"name": "choice_param_3", "values": [6, 7, 8]},
+ ],
+ "run": [{"name": "flag_param_0"}],
+ },
+ "variables": {"var_A": "value for variable A"},
+ }
+
+ monkeypatch.setattr("aiet.backend.system.ProtocolFactory", MagicMock())
+ application, system = Application(config), System(config) # type: ignore
+ context = ExecutionContext(
+ app=application,
+ app_params=[],
+ system=system,
+ system_params=[],
+ custom_deploy_data=[],
+ )
+
+ param_resolver = ParamResolver(context)
+
+ cmd = application.build_command(
+ "build", ["choice_param_0=2", "choice_param_1=4"], param_resolver
+ )
+ assert cmd == ["build choice_param_0=2 choice_param_1 4"]
+
+ cmd = application.build_command("build", ["choice_param_0=2"], param_resolver)
+ assert cmd == ["build choice_param_0=2 choice_param_1 3"]
+
+ cmd = application.build_command(
+ "build", ["choice_param_0=2", "choice_param_3=7"], param_resolver
+ )
+ assert cmd == ["build choice_param_0=2 choice_param_1 3"]
+
+ with pytest.raises(
+ ConfigurationException, match="Command 'foo' could not be found."
+ ):
+ application.build_command("foo", [""], param_resolver)
+
+ cmd = application.build_command("some_command", [], param_resolver)
+ assert cmd == ["Command with value for variable A"]
+
+ cmd = application.build_command("empty_command", [], param_resolver)
+ assert cmd == [""]
+
+ @pytest.mark.parametrize("class_", [Application, System])
+ def test_build_command_unknown_variable(self, class_: type) -> None:
+ """Test that unable to construct backend with unknown variable."""
+ with pytest.raises(Exception, match="Unknown variable var1"):
+ config = {"name": "test", "commands": {"run": ["run {variables:var1}"]}}
+ class_(config)
+
+ @pytest.mark.parametrize(
+ "class_, config, expected_output",
+ [
+ (
+ Application,
+ {
+ "name": "test",
+ "commands": {
+ "build": ["build {user_params:0} {user_params:1}"],
+ "run": ["run {user_params:0}"],
+ },
+ "user_params": {
+ "build": [
+ {
+ "name": "choice_param_0=",
+ "values": ["a", "b", "c"],
+ "default_value": "a",
+ "alias": "param_1",
+ },
+ {
+ "name": "choice_param_1",
+ "values": ["a", "b", "c"],
+ "default_value": "a",
+ "alias": "param_2",
+ },
+ {"name": "choice_param_3", "values": ["a", "b", "c"]},
+ ],
+ "run": [{"name": "flag_param_0"}],
+ },
+ },
+ [
+ (
+ "b",
+ Param(
+ name="choice_param_0=",
+ description="",
+ values=["a", "b", "c"],
+ default_value="a",
+ alias="param_1",
+ ),
+ ),
+ (
+ "a",
+ Param(
+ name="choice_param_1",
+ description="",
+ values=["a", "b", "c"],
+ default_value="a",
+ alias="param_2",
+ ),
+ ),
+ (
+ "c",
+ Param(
+ name="choice_param_3",
+ description="",
+ values=["a", "b", "c"],
+ ),
+ ),
+ ],
+ ),
+ (System, {"name": "test"}, []),
+ ],
+ )
+ def test_resolved_parameters(
+ self,
+ monkeypatch: Any,
+ class_: type,
+ config: Dict,
+ expected_output: List[Tuple[Optional[str], Param]],
+ ) -> None:
+ """Test command building."""
+ monkeypatch.setattr("aiet.backend.system.ProtocolFactory", MagicMock())
+ backend = class_(config)
+
+ params = backend.resolved_parameters(
+ "build", ["choice_param_0=b", "choice_param_3=c"]
+ )
+ assert params == expected_output
+
+ @pytest.mark.parametrize(
+ ["param_name", "user_param", "expected_value"],
+ [
+ (
+ "test_name",
+ "test_name=1234",
+ "1234",
+ ), # optional parameter using '='
+ (
+ "test_name",
+ "test_name 1234",
+ "1234",
+ ), # optional parameter using ' '
+ ("test_name", "test_name", None), # flag
+ (None, "test_name=1234", "1234"), # positional parameter
+ ],
+ )
+ def test_resolved_user_parameters(
+ self, param_name: str, user_param: str, expected_value: str
+ ) -> None:
+ """Test different variants to provide user parameters."""
+ # A dummy config providing one backend config
+ config = {
+ "name": "test_backend",
+ "commands": {
+ "test": ["user_param:test_param"],
+ },
+ "user_params": {
+ "test": [UserParamConfig(name=param_name, alias="test_name")],
+ },
+ }
+ backend = Backend(cast(BaseBackendConfig, config))
+ params = backend.resolved_parameters(
+ command_name="test", user_params=[user_param]
+ )
+ assert len(params) == 1
+ value, param = params[0]
+ assert param_name == param.name
+ assert expected_value == value
+
+ @pytest.mark.parametrize(
+ "input_param,expected",
+ [
+ ("--param=1", ("--param", "1")),
+ ("--param 1", ("--param", "1")),
+ ("--flag", ("--flag", None)),
+ ],
+ )
+ def test__parse_raw_parameter(
+ self, input_param: str, expected: Tuple[str, Optional[str]]
+ ) -> None:
+ """Test internal method of parsing a single raw parameter."""
+ assert parse_raw_parameter(input_param) == expected
+
+
+class TestParam:
+ """Test Param class."""
+
+ def test__eq__(self) -> None:
+ """Test equality method with different cases."""
+ param1 = Param(name="test", description="desc", values=["values"])
+ param2 = Param(name="test", description="desc", values=["values"])
+ param3 = Param(name="test1", description="desc", values=["values"])
+ param4 = object()
+
+ assert param1 == param2
+ assert param1 != param3
+ assert param1 != param4
+
+ def test_get_details(self) -> None:
+ """Test get_details() method."""
+ param1 = Param(name="test", description="desc", values=["values"])
+ assert param1.get_details() == {
+ "name": "test",
+ "values": ["values"],
+ "description": "desc",
+ }
+
+ def test_invalid(self) -> None:
+ """Test invalid use cases for the Param class."""
+ with pytest.raises(
+ ConfigurationException,
+ match="Either name, alias or both must be set to identify a parameter.",
+ ):
+ Param(name=None, description="desc", values=["values"])
+
+
+class TestCommand:
+ """Test Command class."""
+
+ def test_get_details(self) -> None:
+ """Test get_details() method."""
+ param1 = Param(name="test", description="desc", values=["values"])
+ command1 = Command(command_strings=["echo test"], params=[param1])
+ assert command1.get_details() == {
+ "command_strings": ["echo test"],
+ "user_params": [
+ {"name": "test", "values": ["values"], "description": "desc"}
+ ],
+ }
+
+ def test__eq__(self) -> None:
+ """Test equality method with different cases."""
+ param1 = Param("test", "desc", ["values"])
+ param2 = Param("test1", "desc1", ["values1"])
+ command1 = Command(command_strings=["echo test"], params=[param1])
+ command2 = Command(command_strings=["echo test"], params=[param1])
+ command3 = Command(command_strings=["echo test"])
+ command4 = Command(command_strings=["echo test"], params=[param2])
+ command5 = object()
+
+ assert command1 == command2
+ assert command1 != command3
+ assert command1 != command4
+ assert command1 != command5
+
+ @pytest.mark.parametrize(
+ "params, expected_error",
+ [
+ [[], does_not_raise()],
+ [[Param("param", "param description", [])], does_not_raise()],
+ [
+ [
+ Param("param", "param description", [], None, "alias"),
+ Param("param", "param description", [], None),
+ ],
+ does_not_raise(),
+ ],
+ [
+ [
+ Param("param1", "param1 description", [], None, "alias1"),
+ Param("param2", "param2 description", [], None, "alias2"),
+ ],
+ does_not_raise(),
+ ],
+ [
+ [
+ Param("param", "param description", [], None, "alias"),
+ Param("param", "param description", [], None, "alias"),
+ ],
+ pytest.raises(ConfigurationException, match="Non unique aliases alias"),
+ ],
+ [
+ [
+ Param("alias", "param description", [], None, "alias1"),
+ Param("param", "param description", [], None, "alias"),
+ ],
+ pytest.raises(
+ ConfigurationException,
+ match="Aliases .* could not be used as parameter name",
+ ),
+ ],
+ [
+ [
+ Param("alias", "param description", [], None, "alias"),
+ Param("param1", "param1 description", [], None, "alias1"),
+ ],
+ does_not_raise(),
+ ],
+ [
+ [
+ Param("alias", "param description", [], None, "alias"),
+ Param("alias", "param1 description", [], None, "alias1"),
+ ],
+ pytest.raises(
+ ConfigurationException,
+ match="Aliases .* could not be used as parameter name",
+ ),
+ ],
+ [
+ [
+ Param("param1", "param1 description", [], None, "alias1"),
+ Param("param2", "param2 description", [], None, "alias1"),
+ Param("param3", "param3 description", [], None, "alias2"),
+ Param("param4", "param4 description", [], None, "alias2"),
+ ],
+ pytest.raises(
+ ConfigurationException, match="Non unique aliases alias1, alias2"
+ ),
+ ],
+ ],
+ )
+ def test_validate_params(self, params: List[Param], expected_error: Any) -> None:
+ """Test command validation function."""
+ with expected_error:
+ Command([], params)
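+
+# The cases above encode the validation rules: aliases must be unique across
+# a command's parameters, and an alias may not reuse another parameter's
+# name (a parameter whose own name matches its alias is accepted).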
diff --git a/tests/aiet/test_backend_controller.py b/tests/aiet/test_backend_controller.py
new file mode 100644
index 0000000..8836ec5
--- /dev/null
+++ b/tests/aiet/test_backend_controller.py
@@ -0,0 +1,160 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Tests for system controller."""
+import csv
+import os
+import time
+from pathlib import Path
+from typing import Any
+
+import psutil
+import pytest
+
+from aiet.backend.common import ConfigurationException
+from aiet.backend.controller import SystemController
+from aiet.backend.controller import SystemControllerSingleInstance
+from aiet.utils.proc import ShellCommand
+
+
+def get_system_controller(**kwargs: Any) -> SystemController:
+ """Get service controller."""
+ single_instance = kwargs.get("single_instance", False)
+ if single_instance:
+ pid_file_path = kwargs.get("pid_file_path")
+ return SystemControllerSingleInstance(pid_file_path)
+
+ return SystemController()
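+
+# Illustrative usage of the helper above (assumption: a pid file is needed
+# only in single-instance mode, as the tests below exercise):
+#   get_system_controller(single_instance=True, pid_file_path=Path("test.pid"))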
+
+
+def test_service_controller() -> None:
+ """Test service controller functionality."""
+ service_controller = get_system_controller()
+
+ assert service_controller.get_output() == ("", "")
+ with pytest.raises(ConfigurationException, match="Wrong working directory"):
+ service_controller.start(["sleep 100"], Path("unknown"))
+
+ service_controller.start(["sleep 100"], Path.cwd())
+ assert service_controller.is_running()
+
+ service_controller.stop(True)
+ assert not service_controller.is_running()
+ assert service_controller.get_output() == ("", "")
+
+ service_controller.stop()
+
+ with pytest.raises(
+ ConfigurationException, match="System should have only one command to run"
+ ):
+ service_controller.start(["sleep 100", "sleep 101"], Path.cwd())
+
+ with pytest.raises(ConfigurationException, match="No startup command provided"):
+ service_controller.start([""], Path.cwd())
+
+
+def test_service_controller_bad_configuration() -> None:
+ """Test service controller functionality for bad configuration."""
+ with pytest.raises(Exception, match="No pid file path presented"):
+ service_controller = get_system_controller(
+ single_instance=True, pid_file_path=None
+ )
+ service_controller.start(["sleep 100"], Path.cwd())
+
+
+def test_service_controller_writes_process_info_correctly(tmpdir: Any) -> None:
+ """Test that controller writes process info correctly."""
+ pid_file = Path(tmpdir) / "test.pid"
+
+ service_controller = get_system_controller(
+ single_instance=True, pid_file_path=Path(tmpdir) / "test.pid"
+ )
+
+ service_controller.start(["sleep 100"], Path.cwd())
+ assert service_controller.is_running()
+ assert pid_file.is_file()
+
+ with open(pid_file, "r", encoding="utf-8") as file:
+ csv_reader = csv.reader(file)
+ rows = list(csv_reader)
+ assert len(rows) == 1
+
+ name, *_ = rows[0]
+ assert name == "sleep"
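+ # Each pid-file row appears to hold (name, executable, cwd, pid); the
+ # writer-side test further below constructs rows with exactly this shape.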
+
+ service_controller.stop()
+ assert pid_file.exists()
+
+
+def test_service_controller_does_not_write_process_info_if_process_finishes(
+ tmpdir: Any,
+) -> None:
+ """Test that controller does not write process info if process already finished."""
+ pid_file = Path(tmpdir) / "test.pid"
+ service_controller = get_system_controller(
+ single_instance=True, pid_file_path=pid_file
+ )
+ service_controller.is_running = lambda: False # type: ignore
+ service_controller.start(["echo hello"], Path.cwd())
+
+ assert not pid_file.exists()
+
+
+def test_service_controller_searches_for_previous_instances_correctly(
+ tmpdir: Any,
+) -> None:
+ """Test that controller searches for previous instances correctly."""
+ pid_file = Path(tmpdir) / "test.pid"
+ command = ShellCommand().run("sleep", "100")
+ assert command.is_alive()
+
+ pid = command.process.pid
+ process = psutil.Process(pid)
+ with open(pid_file, "w", encoding="utf-8") as file:
+ csv_writer = csv.writer(file)
+ csv_writer.writerow(("some_process", "some_program", "some_cwd", os.getpid()))
+ csv_writer.writerow((process.name(), process.exe(), process.cwd(), process.pid))
+ csv_writer.writerow(("some_old_process", "not_running", "from_nowhere", 77777))
+
+ service_controller = get_system_controller(
+ single_instance=True, pid_file_path=pid_file
+ )
+ service_controller.start(["sleep 100"], Path.cwd())
+ # controller should stop this process as it is currently running and
+ # mentioned in pid file
+ assert not command.is_alive()
+
+ service_controller.stop()
+
+
+@pytest.mark.parametrize(
+ "executable", ["test_backend_run_script.sh", "test_backend_run"]
+)
+def test_service_controller_run_shell_script(
+ executable: str, test_resources_path: Path
+) -> None:
+ """Test controller's ability to run shell scripts."""
+ script_path = test_resources_path / "scripts"
+
+ service_controller = get_system_controller()
+
+ service_controller.start([executable], script_path)
+
+ assert service_controller.is_running()
+ # give time for the command to produce output
+ time.sleep(2)
+ service_controller.stop(wait=True)
+ assert not service_controller.is_running()
+ stdout, stderr = service_controller.get_output()
+ assert stdout == "Hello from script\n"
+ assert stderr == "Oops!\n"
+
+
+def test_service_controller_does_nothing_if_not_started(tmpdir: Any) -> None:
+ """Test that nothing happened if controller is not started."""
+ service_controller = get_system_controller(
+ single_instance=True, pid_file_path=Path(tmpdir) / "test.pid"
+ )
+
+ assert not service_controller.is_running()
+ service_controller.stop()
+ assert not service_controller.is_running()
diff --git a/tests/aiet/test_backend_execution.py b/tests/aiet/test_backend_execution.py
new file mode 100644
index 0000000..8aa45f1
--- /dev/null
+++ b/tests/aiet/test_backend_execution.py
@@ -0,0 +1,526 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+# pylint: disable=no-self-use
+"""Test backend context module."""
+from contextlib import ExitStack as does_not_raise
+from pathlib import Path
+from typing import Any
+from typing import Dict
+from typing import Optional
+from unittest import mock
+from unittest.mock import MagicMock
+
+import pytest
+from sh import CommandNotFound
+
+from aiet.backend.application import Application
+from aiet.backend.application import get_application
+from aiet.backend.common import ConfigurationException
+from aiet.backend.common import DataPaths
+from aiet.backend.common import UserParamConfig
+from aiet.backend.config import ApplicationConfig
+from aiet.backend.config import LocalProtocolConfig
+from aiet.backend.config import SystemConfig
+from aiet.backend.execution import deploy_data
+from aiet.backend.execution import execute_commands_locally
+from aiet.backend.execution import ExecutionContext
+from aiet.backend.execution import get_application_and_system
+from aiet.backend.execution import get_application_by_name_and_system
+from aiet.backend.execution import get_file_lock_path
+from aiet.backend.execution import get_tool_by_system
+from aiet.backend.execution import ParamResolver
+from aiet.backend.execution import Reporter
+from aiet.backend.execution import wait
+from aiet.backend.output_parser import OutputParser
+from aiet.backend.system import get_system
+from aiet.backend.system import load_system
+from aiet.backend.tool import get_tool
+from aiet.utils.proc import CommandFailedException
+
+
+def test_context_param_resolver(tmpdir: Any) -> None:
+ """Test parameter resolving."""
+ system_config_location = Path(tmpdir) / "system"
+ system_config_location.mkdir()
+
+ application_config_location = Path(tmpdir) / "application"
+ application_config_location.mkdir()
+
+ ctx = ExecutionContext(
+ app=Application(
+ ApplicationConfig(
+ name="test_application",
+ description="Test application",
+ config_location=application_config_location,
+ build_dir="build-{application.name}-{system.name}",
+ commands={
+ "run": [
+ "run_command1 {user_params:0}",
+ "run_command2 {user_params:1}",
+ ]
+ },
+ variables={"var_1": "value for var_1"},
+ user_params={
+ "run": [
+ UserParamConfig(
+ name="--param1",
+ description="Param 1",
+ default_value="123",
+ alias="param_1",
+ ),
+ UserParamConfig(
+ name="--param2", description="Param 2", default_value="456"
+ ),
+ UserParamConfig(
+ name="--param3", description="Param 3", alias="param_3"
+ ),
+ UserParamConfig(
+ name="--param4=",
+ description="Param 4",
+ default_value="456",
+ alias="param_4",
+ ),
+ UserParamConfig(
+ description="Param 5",
+ default_value="789",
+ alias="param_5",
+ ),
+ ]
+ },
+ )
+ ),
+ app_params=["--param2=789"],
+ system=load_system(
+ SystemConfig(
+ name="test_system",
+ description="Test system",
+ config_location=system_config_location,
+ build_dir="build",
+ data_transfer=LocalProtocolConfig(protocol="local"),
+ commands={
+ "build": ["build_command1 {user_params:0}"],
+ "run": ["run_command {application.commands.run:1}"],
+ },
+ variables={"var_1": "value for var_1"},
+ user_params={
+ "build": [
+ UserParamConfig(
+ name="--param1", description="Param 1", default_value="aaa"
+ ),
+ UserParamConfig(name="--param2", description="Param 2"),
+ ]
+ },
+ )
+ ),
+ system_params=["--param1=bbb"],
+ custom_deploy_data=[],
+ )
+
+ param_resolver = ParamResolver(ctx)
+ expected_values = {
+ "application.name": "test_application",
+ "application.description": "Test application",
+ "application.config_dir": str(application_config_location),
+ "application.build_dir": "{}/build-test_application-test_system".format(
+ application_config_location
+ ),
+ "application.commands.run:0": "run_command1 --param1 123",
+ "application.commands.run.params:0": "123",
+ "application.commands.run.params:param_1": "123",
+ "application.commands.run:1": "run_command2 --param2 789",
+ "application.commands.run.params:1": "789",
+ "application.variables:var_1": "value for var_1",
+ "system.name": "test_system",
+ "system.description": "Test system",
+ "system.config_dir": str(system_config_location),
+ "system.commands.build:0": "build_command1 --param1 bbb",
+ "system.commands.run:0": "run_command run_command2 --param2 789",
+ "system.commands.build.params:0": "bbb",
+ "system.variables:var_1": "value for var_1",
+ }
+
+ for param, value in expected_values.items():
+ assert param_resolver(param) == value
+
+ assert ctx.build_dir() == Path(
+ "{}/build-test_application-test_system".format(application_config_location)
+ )
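+ # The resolver accepts dotted references of the form
+ # "<application|system>.<attribute>" with an optional ":<index or alias>"
+ # suffix for commands, parameters and variables, as enumerated in
+ # expected_values above.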
+
+ expected_errors = {
+ "application.variables:var_2": pytest.raises(
+ Exception, match="Unknown variable var_2"
+ ),
+ "application.commands.clean:0": pytest.raises(
+ Exception, match="Command clean not found"
+ ),
+ "application.commands.run:2": pytest.raises(
+ Exception, match="Invalid index 2 for command run"
+ ),
+ "application.commands.run.params:5": pytest.raises(
+ Exception, match="Invalid parameter index 5 for command run"
+ ),
+ "application.commands.run.params:param_2": pytest.raises(
+ Exception,
+ match="No value for parameter with index or alias param_2 of command run",
+ ),
+ "UNKNOWN": pytest.raises(
+ Exception, match="Unable to resolve parameter UNKNOWN"
+ ),
+ "system.commands.build.params:1": pytest.raises(
+ Exception,
+ match="No value for parameter with index or alias 1 of command build",
+ ),
+ "system.commands.build:A": pytest.raises(
+ Exception, match="Bad command index A"
+ ),
+ "system.variables:var_2": pytest.raises(
+ Exception, match="Unknown variable var_2"
+ ),
+ }
+ for param, error in expected_errors.items():
+ with error:
+ param_resolver(param)
+
+ resolved_params = ctx.app.resolved_parameters("run", [])
+ expected_user_params = {
+ "user_params:0": "--param1 123",
+ "user_params:param_1": "--param1 123",
+ "user_params:2": "--param3",
+ "user_params:param_3": "--param3",
+ "user_params:3": "--param4=456",
+ "user_params:param_4": "--param4=456",
+ "user_params:param_5": "789",
+ }
+ for param, expected_value in expected_user_params.items():
+ assert param_resolver(param, "run", resolved_params) == expected_value
+
+ with pytest.raises(
+ Exception, match="Invalid index 5 for user params of command run"
+ ):
+ param_resolver("user_params:5", "run", resolved_params)
+
+ with pytest.raises(
+ Exception, match="No user parameter for command 'run' with alias 'param_2'."
+ ):
+ param_resolver("user_params:param_2", "run", resolved_params)
+
+ with pytest.raises(Exception, match="Unable to resolve user params"):
+ param_resolver("user_params:0", "", resolved_params)
+
+ bad_ctx = ExecutionContext(
+ app=Application(
+ ApplicationConfig(
+ name="test_application",
+ config_location=application_config_location,
+ build_dir="build-{user_params:0}",
+ )
+ ),
+ app_params=["--param2=789"],
+ system=load_system(
+ SystemConfig(
+ name="test_system",
+ description="Test system",
+ config_location=system_config_location,
+ build_dir="build-{system.commands.run.params:123}",
+ data_transfer=LocalProtocolConfig(protocol="local"),
+ )
+ ),
+ system_params=["--param1=bbb"],
+ custom_deploy_data=[],
+ )
+ param_resolver = ParamResolver(bad_ctx)
+ with pytest.raises(Exception, match="Unable to resolve user params"):
+ bad_ctx.build_dir()
+
+
+# pylint: disable=too-many-arguments
+@pytest.mark.parametrize(
+ "application_name, soft_lock, sys_lock, lock_dir, expected_error, expected_path",
+ (
+ (
+ "test_application",
+ True,
+ True,
+ Path("/tmp"),
+ does_not_raise(),
+ Path("/tmp/middleware_test_application_test_system.lock"),
+ ),
+ (
+ "$$test_application$!:",
+ True,
+ True,
+ Path("/tmp"),
+ does_not_raise(),
+ Path("/tmp/middleware_test_application_test_system.lock"),
+ ),
+ (
+ "test_application",
+ True,
+ True,
+ Path("unknown"),
+ pytest.raises(
+ Exception, match="Invalid directory unknown for lock files provided"
+ ),
+ None,
+ ),
+ (
+ "test_application",
+ False,
+ True,
+ Path("/tmp"),
+ does_not_raise(),
+ Path("/tmp/middleware_test_system.lock"),
+ ),
+ (
+ "test_application",
+ True,
+ False,
+ Path("/tmp"),
+ does_not_raise(),
+ Path("/tmp/middleware_test_application.lock"),
+ ),
+ (
+ "test_application",
+ False,
+ False,
+ Path("/tmp"),
+ pytest.raises(Exception, match="No filename for lock provided"),
+ None,
+ ),
+ ),
+)
+def test_get_file_lock_path(
+ application_name: str,
+ soft_lock: bool,
+ sys_lock: bool,
+ lock_dir: Path,
+ expected_error: Any,
+ expected_path: Path,
+) -> None:
+ """Test get_file_lock_path function."""
+ with expected_error:
+ ctx = ExecutionContext(
+ app=Application(ApplicationConfig(name=application_name, lock=soft_lock)),
+ app_params=[],
+ system=load_system(
+ SystemConfig(
+ name="test_system",
+ lock=sys_lock,
+ data_transfer=LocalProtocolConfig(protocol="local"),
+ )
+ ),
+ system_params=[],
+ custom_deploy_data=[],
+ )
+ path = get_file_lock_path(ctx, lock_dir)
+ assert path == expected_path
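+ # The expectations above imply the lock file name is composed as
+ # "middleware[_<application>][_<system>].lock", with non-alphanumeric
+ # characters stripped from the names (see the "$$test_application$!:" case).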
+
+
+def test_get_application_by_name_and_system(monkeypatch: Any) -> None:
+ """Test exceptional case for get_application_by_name_and_system."""
+ monkeypatch.setattr(
+ "aiet.backend.execution.get_application",
+ MagicMock(return_value=[MagicMock(), MagicMock()]),
+ )
+
+ with pytest.raises(
+ ValueError,
+ match="Error during getting application test_application for the "
+ "system test_system",
+ ):
+ get_application_by_name_and_system("test_application", "test_system")
+
+
+def test_get_application_and_system(monkeypatch: Any) -> None:
+ """Test exceptional case for get_application_and_system."""
+ monkeypatch.setattr(
+ "aiet.backend.execution.get_system", MagicMock(return_value=None)
+ )
+
+ with pytest.raises(ValueError, match="System test_system is not found"):
+ get_application_and_system("test_application", "test_system")
+
+
+def test_wait_function(monkeypatch: Any) -> None:
+ """Test wait function."""
+ sleep_mock = MagicMock()
+ monkeypatch.setattr("time.sleep", sleep_mock)
+ wait(0.1)
+ sleep_mock.assert_called_once()
+
+
+def test_deployment_execution_context() -> None:
+ """Test property 'is_deploy_needed' of the ExecutionContext."""
+ ctx = ExecutionContext(
+ app=get_application("application_1")[0],
+ app_params=[],
+ system=get_system("System 1"),
+ system_params=[],
+ )
+ assert not ctx.is_deploy_needed
+ deploy_data(ctx) # should be a NOP
+
+ ctx = ExecutionContext(
+ app=get_application("application_1")[0],
+ app_params=[],
+ system=get_system("System 1"),
+ system_params=[],
+ custom_deploy_data=[DataPaths(Path("README.md"), ".")],
+ )
+ assert ctx.is_deploy_needed
+
+ ctx = ExecutionContext(
+ app=get_application("application_1")[0],
+ app_params=[],
+ system=None,
+ system_params=[],
+ )
+ assert not ctx.is_deploy_needed
+ with pytest.raises(AssertionError):
+ deploy_data(ctx)
+
+ ctx = ExecutionContext(
+ app=get_tool("tool_1")[0],
+ app_params=[],
+ system=None,
+ system_params=[],
+ )
+ assert not ctx.is_deploy_needed
+ deploy_data(ctx) # should be a NOP
+
+
+@pytest.mark.parametrize(
+ ["tool_name", "system_name", "exception"],
+ [
+ ("vela", "Corstone-300: Cortex-M55+Ethos-U65", None),
+ ("unknown tool", "Corstone-300: Cortex-M55+Ethos-U65", ConfigurationException),
+ ("vela", "unknown system", ConfigurationException),
+ ("vela", None, ConfigurationException),
+ ],
+)
+def test_get_tool_by_system(
+ tool_name: str, system_name: Optional[str], exception: Optional[Any]
+) -> None:
+ """Test exceptions thrown by function get_tool_by_system()."""
+
+ def test() -> None:
+ """Test call of get_tool_by_system()."""
+ tool = get_tool_by_system(tool_name, system_name)
+ assert tool is not None
+
+ if exception is None:
+ test()
+ else:
+ with pytest.raises(exception):
+ test()
+
+
+class TestExecuteCommandsLocally:
+ """Test execute_commands_locally() function."""
+
+ @pytest.mark.parametrize(
+ "first_command, exception, expected_output",
+ (
+ (
+ "echo 'hello'",
+ None,
+ "Running: echo 'hello'\nhello\nRunning: echo 'goodbye'\ngoodbye\n",
+ ),
+ (
+ "non-existent-command",
+ CommandNotFound,
+ "Running: non-existent-command\n",
+ ),
+ ("false", CommandFailedException, "Running: false\n"),
+ ),
+ ids=(
+ "runs_multiple_commands",
+ "stops_executing_on_non_existent_command",
+ "stops_executing_when_command_exits_with_error_code",
+ ),
+ )
+ def test_execution(
+ self,
+ first_command: str,
+ exception: Any,
+ expected_output: str,
+ test_resources_path: Path,
+ capsys: Any,
+ ) -> None:
+ """Test expected behaviour of the function."""
+ commands = [first_command, "echo 'goodbye'"]
+ cwd = test_resources_path
+ if exception is None:
+ execute_commands_locally(commands, cwd)
+ else:
+ with pytest.raises(exception):
+ execute_commands_locally(commands, cwd)
+
+ captured = capsys.readouterr()
+ assert captured.out == expected_output
+
+ def test_stops_executing_on_exception(
+ self, monkeypatch: Any, test_resources_path: Path
+ ) -> None:
+ """Ensure commands following an error-exit-code command don't run."""
+ # Mock execute_command() function
+ execute_command_mock = mock.MagicMock()
+ monkeypatch.setattr("aiet.utils.proc.execute_command", execute_command_mock)
+
+ # Mock Command object and assign as return value to execute_command()
+ cmd_mock = mock.MagicMock()
+ execute_command_mock.return_value = cmd_mock
+
+ # Mock the terminate_command (speed up test)
+ terminate_command_mock = mock.MagicMock()
+ monkeypatch.setattr("aiet.utils.proc.terminate_command", terminate_command_mock)
+
+ # Mock a thrown Exception and assign to Command().exit_code
+ exit_code_mock = mock.PropertyMock(side_effect=Exception("Exception."))
+ type(cmd_mock).exit_code = exit_code_mock
+
+ with pytest.raises(Exception, match="Exception."):
+ execute_commands_locally(
+ ["command_1", "command_2"], cwd=test_resources_path
+ )
+
+ # Assert only "command_1" was executed
+ assert execute_command_mock.call_count == 1
+
+
+def test_reporter(tmpdir: Any) -> None:
+ """Test class 'Reporter'."""
+ ctx = ExecutionContext(
+ app=get_application("application_4")[0],
+ app_params=["--app=TestApp"],
+ system=get_system("System 4"),
+ system_params=[],
+ )
+ assert ctx.system is not None
+
+ class MockParser(OutputParser):
+ """Mock implementation of an output parser."""
+
+ def __init__(self, metrics: Dict[str, Any]) -> None:
+ """Set up the MockParser."""
+ super().__init__(name="test")
+ self.metrics = metrics
+
+ def __call__(self, output: bytearray) -> Dict[str, Any]:
+ """Return mock metrics (ignoring the given output)."""
+ return self.metrics
+
+ metrics = {"Metric": 123, "AnotherMetric": 456}
+ reporter = Reporter(
+ parsers=[MockParser(metrics={key: val}) for key, val in metrics.items()],
+ )
+ reporter.parse(bytearray())
+ report = reporter.report(ctx)
+ assert report["system"]["name"] == ctx.system.name
+ assert report["system"]["params"] == {}
+ assert report["application"]["name"] == ctx.app.name
+ assert report["application"]["params"] == {"--app": "TestApp"}
+ assert report["test"]["metrics"] == metrics
+ report_file = Path(tmpdir) / "report.json"
+ reporter.save(report, report_file)
+ assert report_file.is_file()
diff --git a/tests/aiet/test_backend_output_parser.py b/tests/aiet/test_backend_output_parser.py
new file mode 100644
index 0000000..d659812
--- /dev/null
+++ b/tests/aiet/test_backend_output_parser.py
@@ -0,0 +1,152 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Tests for the output parsing."""
+import base64
+import json
+from typing import Any
+from typing import Dict
+
+import pytest
+
+from aiet.backend.output_parser import Base64OutputParser
+from aiet.backend.output_parser import OutputParser
+from aiet.backend.output_parser import RegexOutputParser
+
+
+OUTPUT_MATCH_ALL = bytearray(
+ """
+String1: My awesome string!
+String2: STRINGS_ARE_GREAT!!!
+Int: 12
+Float: 3.14
+""",
+ encoding="utf-8",
+)
+
+OUTPUT_NO_MATCH = bytearray(
+ """
+This contains no matches...
+Test1234567890!"£$%^&*()_+@~{}[]/.,<>?|
+""",
+ encoding="utf-8",
+)
+
+OUTPUT_PARTIAL_MATCH = bytearray(
+ "String1: My awesome string!",
+ encoding="utf-8",
+)
+
+REGEX_CONFIG = {
+ "FirstString": {"pattern": r"String1.*: (.*)", "type": "str"},
+ "SecondString": {"pattern": r"String2.*: (.*)!!!", "type": "str"},
+ "IntegerValue": {"pattern": r"Int.*: (.*)", "type": "int"},
+ "FloatValue": {"pattern": r"Float.*: (.*)", "type": "float"},
+}
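+
+# Each REGEX_CONFIG entry maps a metric name to a pattern with exactly one
+# capturing group and a target type; the tests below exercise "str", "int"
+# and "float".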
+
+EMPTY_REGEX_CONFIG: Dict[str, Dict[str, Any]] = {}
+
+EXPECTED_METRICS_ALL = {
+ "FirstString": "My awesome string!",
+ "SecondString": "STRINGS_ARE_GREAT",
+ "IntegerValue": 12,
+ "FloatValue": 3.14,
+}
+
+EXPECTED_METRICS_PARTIAL = {
+ "FirstString": "My awesome string!",
+}
+
+
+class TestRegexOutputParser:
+ """Collect tests for the RegexOutputParser."""
+
+ @staticmethod
+ @pytest.mark.parametrize(
+ ["output", "config", "expected_metrics"],
+ [
+ (OUTPUT_MATCH_ALL, REGEX_CONFIG, EXPECTED_METRICS_ALL),
+ (OUTPUT_MATCH_ALL + OUTPUT_NO_MATCH, REGEX_CONFIG, EXPECTED_METRICS_ALL),
+ (
+ OUTPUT_MATCH_ALL + OUTPUT_PARTIAL_MATCH,
+ REGEX_CONFIG,
+ EXPECTED_METRICS_ALL,
+ ),
+ (OUTPUT_NO_MATCH, REGEX_CONFIG, {}),
+ (OUTPUT_MATCH_ALL, EMPTY_REGEX_CONFIG, {}),
+ (bytearray(), EMPTY_REGEX_CONFIG, {}),
+ (bytearray(), REGEX_CONFIG, {}),
+ ],
+ )
+ def test_parsing(output: bytearray, config: Dict, expected_metrics: Dict) -> None:
+ """
+ Make sure the RegexOutputParser yields valid results.
+
+ I.e. return an empty dict if either the input or the config is empty and
+ return the parsed metrics otherwise.
+ """
+ parser = RegexOutputParser(name="Test", regex_config=config)
+ assert parser.name == "Test"
+ assert isinstance(parser, OutputParser)
+ res = parser(output)
+ assert res == expected_metrics
+
+ @staticmethod
+ def test_unsupported_type() -> None:
+ """An unsupported type in the regex_config must raise an exception."""
+ config = {"BrokenMetric": {"pattern": "(.*)", "type": "UNSUPPORTED_TYPE"}}
+ with pytest.raises(TypeError):
+ RegexOutputParser(name="Test", regex_config=config)
+
+ @staticmethod
+ @pytest.mark.parametrize(
+ "config",
+ (
+ {"TooManyGroups": {"pattern": r"(\w)(\d)", "type": "str"}},
+ {"NoGroups": {"pattern": r"\W", "type": "str"}},
+ ),
+ )
+ def test_invalid_pattern(config: Dict) -> None:
+ """Exactly one capturing parenthesis is allowed in the regex pattern."""
+ with pytest.raises(ValueError):
+ RegexOutputParser(name="Test", regex_config=config)
+
+
+@pytest.mark.parametrize(
+ "expected_metrics",
+ [
+ EXPECTED_METRICS_ALL,
+ EXPECTED_METRICS_PARTIAL,
+ ],
+)
+def test_base64_output_parser(expected_metrics: Dict) -> None:
+ """
+ Make sure the Base64OutputParser yields valid results.
+
+ I.e. return an empty dict if either the input or the config is empty and
+ return the parsed metrics otherwise.
+ """
+ parser = Base64OutputParser(name="Test")
+ assert parser.name == "Test"
+ assert isinstance(parser, OutputParser)
+
+ def create_base64_output(expected_metrics: Dict) -> bytearray:
+ json_str = json.dumps(expected_metrics, indent=4)
+ json_b64 = base64.b64encode(json_str.encode("utf-8"))
+ return (
+ OUTPUT_MATCH_ALL # Should not be matched by the Base64OutputParser
+ + f"<{Base64OutputParser.TAG_NAME}>".encode("utf-8")
+ + bytearray(json_b64)
+ + f"</{Base64OutputParser.TAG_NAME}>".encode("utf-8")
+ + OUTPUT_NO_MATCH # Just to add some difficulty...
+ )
+
+ output = create_base64_output(expected_metrics)
+ res = parser(output)
+ assert len(res) == 1
+ assert isinstance(res, dict)
+ for val in res.values():
+ assert val == expected_metrics
+
+ output = parser.filter_out_parsed_content(output)
+ assert output == (OUTPUT_MATCH_ALL + OUTPUT_NO_MATCH)
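+ # As constructed above, the payload is framed as
+ # "<TAG_NAME>base64(json)</TAG_NAME>"; filter_out_parsed_content() strips
+ # exactly that span and leaves the surrounding output untouched.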
diff --git a/tests/aiet/test_backend_protocol.py b/tests/aiet/test_backend_protocol.py
new file mode 100644
index 0000000..2103238
--- /dev/null
+++ b/tests/aiet/test_backend_protocol.py
@@ -0,0 +1,231 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+# pylint: disable=no-self-use,attribute-defined-outside-init,protected-access
+"""Tests for the protocol backend module."""
+from contextlib import ExitStack as does_not_raise
+from pathlib import Path
+from typing import Any
+from unittest.mock import MagicMock
+
+import paramiko
+import pytest
+
+from aiet.backend.common import ConfigurationException
+from aiet.backend.config import LocalProtocolConfig
+from aiet.backend.protocol import CustomSFTPClient
+from aiet.backend.protocol import LocalProtocol
+from aiet.backend.protocol import ProtocolFactory
+from aiet.backend.protocol import SSHProtocol
+
+
+class TestProtocolFactory:
+ """Test ProtocolFactory class."""
+
+ @pytest.mark.parametrize(
+ "config, expected_class, exception",
+ [
+ (
+ {
+ "protocol": "ssh",
+ "username": "user",
+ "password": "pass",
+ "hostname": "hostname",
+ "port": "22",
+ },
+ SSHProtocol,
+ does_not_raise(),
+ ),
+ ({"protocol": "local"}, LocalProtocol, does_not_raise()),
+ (
+ {"protocol": "something"},
+ None,
+ pytest.raises(Exception, match="Protocol not supported"),
+ ),
+ (None, None, pytest.raises(Exception, match="No protocol config provided")),
+ ],
+ )
+ def test_get_protocol(
+ self, config: Any, expected_class: type, exception: Any
+ ) -> None:
+ """Test get_protocol method."""
+ factory = ProtocolFactory()
+ with exception:
+ protocol = factory.get_protocol(config)
+ assert isinstance(protocol, expected_class)
+
+
+class TestLocalProtocol:
+ """Test local protocol."""
+
+ def test_local_protocol_run_command(self) -> None:
+ """Test local protocol run command."""
+ config = LocalProtocolConfig(protocol="local")
+ protocol = LocalProtocol(config, cwd=Path("/tmp"))
+ ret, stdout, stderr = protocol.run("pwd")
+ assert ret == 0
+ assert stdout.decode("utf-8").strip() == "/tmp"
+ assert stderr.decode("utf-8") == ""
+
+ def test_local_protocol_run_wrong_cwd(self) -> None:
+ """Execution should fail if wrong working directory provided."""
+ config = LocalProtocolConfig(protocol="local")
+ protocol = LocalProtocol(config, cwd=Path("unknown_directory"))
+ with pytest.raises(
+ ConfigurationException, match="Wrong working directory unknown_directory"
+ ):
+ protocol.run("pwd")
+
+
+class TestSSHProtocol:
+ """Test SSH protocol."""
+
+ @pytest.fixture(autouse=True)
+ def setup_method(self, monkeypatch: Any) -> None:
+ """Set up protocol mocks."""
+ self.mock_ssh_client = MagicMock(spec=paramiko.client.SSHClient)
+
+ self.mock_ssh_channel = (
+ self.mock_ssh_client.get_transport.return_value.open_session.return_value
+ )
+ self.mock_ssh_channel.mock_add_spec(spec=paramiko.channel.Channel)
+ self.mock_ssh_channel.exit_status_ready.side_effect = [False, True]
+ self.mock_ssh_channel.recv_exit_status.return_value = True
+ self.mock_ssh_channel.recv_ready.side_effect = [False, True]
+ self.mock_ssh_channel.recv_stderr_ready.side_effect = [False, True]
+
+ monkeypatch.setattr(
+ "aiet.backend.protocol.paramiko.client.SSHClient",
+ MagicMock(return_value=self.mock_ssh_client),
+ )
+
+ self.mock_sftp_client = MagicMock(spec=CustomSFTPClient)
+ monkeypatch.setattr(
+ "aiet.backend.protocol.CustomSFTPClient.from_transport",
+ MagicMock(return_value=self.mock_sftp_client),
+ )
+
+ ssh_config = {
+ "protocol": "ssh",
+ "username": "user",
+ "password": "pass",
+ "hostname": "hostname",
+ "port": "22",
+ }
+ self.protocol = SSHProtocol(ssh_config)
+
+ def test_unable_create_ssh_client(self, monkeypatch: Any) -> None:
+ """Test that command should fail if unable to create ssh client instance."""
+ monkeypatch.setattr(
+ "aiet.backend.protocol.paramiko.client.SSHClient",
+ MagicMock(side_effect=OSError("Error!")),
+ )
+
+ with pytest.raises(Exception, match="Couldn't connect to 'hostname:22'"):
+ self.protocol.run("command_example", retry=False)
+
+ def test_ssh_protocol_run_command(self) -> None:
+ """Test that command run via ssh successfully."""
+ self.protocol.run("command_example")
+ self.mock_ssh_channel.exec_command.assert_called_once()
+
+ def test_ssh_protocol_run_command_connect_failed(self) -> None:
+ """Test that if connection is not possible then correct exception is raised."""
+ self.mock_ssh_client.connect.side_effect = OSError("Unable to connect")
+ self.mock_ssh_client.close.side_effect = Exception("Error!")
+
+ with pytest.raises(Exception, match="Couldn't connect to 'hostname:22'"):
+ self.protocol.run("command_example", retry=False)
+
+ def test_ssh_protocol_run_command_bad_transport(self) -> None:
+ """Test that command should fail if unable to get transport."""
+ self.mock_ssh_client.get_transport.return_value = None
+
+ with pytest.raises(Exception, match="Unable to get transport"):
+ self.protocol.run("command_example", retry=False)
+
+ def test_ssh_protocol_deploy_command_file(
+ self, test_applications_path: Path
+ ) -> None:
+ """Test that files could be deployed over ssh."""
+ file_for_deploy = test_applications_path / "readme.txt"
+ dest = "/tmp/dest"
+
+ self.protocol.deploy(file_for_deploy, dest)
+ self.mock_sftp_client.put.assert_called_once_with(str(file_for_deploy), dest)
+
+ def test_ssh_protocol_deploy_command_unknown_file(self) -> None:
+ """Test that deploy will fail if file does not exist."""
+ with pytest.raises(Exception, match="Deploy error: file type not supported"):
+ self.protocol.deploy(Path("unknown_file"), "/tmp/dest")
+
+ def test_ssh_protocol_deploy_command_bad_transport(self) -> None:
+ """Test that deploy should fail if unable to get transport."""
+ self.mock_ssh_client.get_transport.return_value = None
+
+ with pytest.raises(Exception, match="Unable to get transport"):
+ self.protocol.deploy(Path("some_file"), "/tmp/dest")
+
+ def test_ssh_protocol_deploy_command_directory(
+ self, test_resources_path: Path
+ ) -> None:
+ """Test that directory could be deployed over ssh."""
+ directory_for_deploy = test_resources_path / "scripts"
+ dest = "/tmp/dest"
+
+ self.protocol.deploy(directory_for_deploy, dest)
+ self.mock_sftp_client.put_dir.assert_called_once_with(
+ directory_for_deploy, dest
+ )
+
+ @pytest.mark.parametrize("establish_connection", (True, False))
+ def test_ssh_protocol_close(self, establish_connection: bool) -> None:
+ """Test protocol close operation."""
+ if establish_connection:
+ self.protocol.establish_connection()
+ self.protocol.close()
+
+ call_count = 1 if establish_connection else 0
+ assert self.mock_ssh_channel.exec_command.call_count == call_count
+
+ def test_connection_details(self) -> None:
+ """Test getting connection details."""
+ assert self.protocol.connection_details() == ("hostname", 22)
+
+
+class TestCustomSFTPClient:
+ """Test CustomSFTPClient class."""
+
+ @pytest.fixture(autouse=True)
+ def setup_method(self, monkeypatch: Any) -> None:
+ """Set up mocks for CustomSFTPClient instance."""
+ self.mock_mkdir = MagicMock()
+ self.mock_put = MagicMock()
+ monkeypatch.setattr(
+ "aiet.backend.protocol.paramiko.SFTPClient.__init__",
+ MagicMock(return_value=None),
+ )
+ monkeypatch.setattr(
+ "aiet.backend.protocol.paramiko.SFTPClient.mkdir", self.mock_mkdir
+ )
+ monkeypatch.setattr(
+ "aiet.backend.protocol.paramiko.SFTPClient.put", self.mock_put
+ )
+
+ self.sftp_client = CustomSFTPClient(MagicMock())
+
+ def test_put_dir(self, test_systems_path: Path) -> None:
+ """Test deploying directory to remote host."""
+ directory_for_deploy = test_systems_path / "system1"
+
+ self.sftp_client.put_dir(directory_for_deploy, "/tmp/dest")
+ assert self.mock_put.call_count == 3
+ assert self.mock_mkdir.call_count == 3
+
+ def test_mkdir(self) -> None:
+ """Test creating directory on remote host."""
+ self.mock_mkdir.side_effect = IOError("Cannot create directory")
+
+ self.sftp_client._mkdir("new_directory", ignore_existing=True)
+
+ with pytest.raises(IOError, match="Cannot create directory"):
+ self.sftp_client._mkdir("new_directory", ignore_existing=False)
diff --git a/tests/aiet/test_backend_source.py b/tests/aiet/test_backend_source.py
new file mode 100644
index 0000000..13b2c6d
--- /dev/null
+++ b/tests/aiet/test_backend_source.py
@@ -0,0 +1,199 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+# pylint: disable=no-self-use
+"""Tests for the source backend module."""
+from collections import Counter
+from contextlib import ExitStack as does_not_raise
+from pathlib import Path
+from typing import Any
+from unittest.mock import MagicMock
+from unittest.mock import patch
+
+import pytest
+
+from aiet.backend.common import ConfigurationException
+from aiet.backend.source import create_destination_and_install
+from aiet.backend.source import DirectorySource
+from aiet.backend.source import get_source
+from aiet.backend.source import TarArchiveSource
+
+
+def test_create_destination_and_install(test_systems_path: Path, tmpdir: Any) -> None:
+ """Test create_destination_and_install function."""
+ system_directory = test_systems_path / "system1"
+
+ dir_source = DirectorySource(system_directory)
+ resources = Path(tmpdir)
+ create_destination_and_install(dir_source, resources)
+ assert (resources / "system1").is_dir()
+
+
+@patch("aiet.backend.source.DirectorySource.create_destination", return_value=False)
+def test_create_destination_and_install_if_dest_creation_not_required(
+ mock_ds_create_destination: Any, tmpdir: Any
+) -> None:
+ """Test create_destination_and_install function."""
+ dir_source = DirectorySource(Path("unknown"))
+ resources = Path(tmpdir)
+ with pytest.raises(Exception):
+ create_destination_and_install(dir_source, resources)
+
+ mock_ds_create_destination.assert_called_once()
+
+
+def test_create_destination_and_install_if_installation_fails(tmpdir: Any) -> None:
+ """Test create_destination_and_install function if installation fails."""
+ dir_source = DirectorySource(Path("unknown"))
+ resources = Path(tmpdir)
+ with pytest.raises(Exception, match="Directory .* does not exist"):
+ create_destination_and_install(dir_source, resources)
+ assert not (resources / "unknown").exists()
+ assert resources.exists()
+
+
+def test_create_destination_and_install_if_name_is_empty() -> None:
+ """Test create_destination_and_install function fails if source name is empty."""
+ source = MagicMock()
+ source.create_destination.return_value = True
+ source.name.return_value = None
+
+ with pytest.raises(Exception, match="Unable to get source name"):
+ create_destination_and_install(source, Path("some_path"))
+
+ source.install_into.assert_not_called()
+
+
+@pytest.mark.parametrize(
+ "source_path, expected_class, expected_error",
+ [
+ (Path("applications/application1/"), DirectorySource, does_not_raise()),
+ (
+ Path("archives/applications/application1.tar.gz"),
+ TarArchiveSource,
+ does_not_raise(),
+ ),
+ (
+ Path("doesnt/exist"),
+ None,
+ pytest.raises(
+ ConfigurationException, match="Unable to read .*doesnt/exist"
+ ),
+ ),
+ ],
+)
+def test_get_source(
+ source_path: Path,
+ expected_class: Any,
+ expected_error: Any,
+ test_resources_path: Path,
+) -> None:
+ """Test get_source function."""
+ with expected_error:
+ full_source_path = test_resources_path / source_path
+ source = get_source(full_source_path)
+ assert isinstance(source, expected_class)
+
+
+class TestDirectorySource:
+ """Test DirectorySource class."""
+
+ @pytest.mark.parametrize(
+ "directory, name",
+ [
+ (Path("/some/path/some_system"), "some_system"),
+ (Path("some_system"), "some_system"),
+ ],
+ )
+ def test_name(self, directory: Path, name: str) -> None:
+ """Test getting source name."""
+ assert DirectorySource(directory).name() == name
+
+ def test_install_into(self, test_systems_path: Path, tmpdir: Any) -> None:
+ """Test install directory into destination."""
+ system_directory = test_systems_path / "system1"
+
+ dir_source = DirectorySource(system_directory)
+ with pytest.raises(Exception, match="Wrong destination .*"):
+ dir_source.install_into(Path("unknown_destination"))
+
+ tmpdir_path = Path(tmpdir)
+ dir_source.install_into(tmpdir_path)
+ source_files = [f.name for f in system_directory.iterdir()]
+ dest_files = [f.name for f in tmpdir_path.iterdir()]
+ assert Counter(source_files) == Counter(dest_files)
+
+ def test_install_into_unknown_source_directory(self, tmpdir: Any) -> None:
+ """Test install system from unknown directory."""
+ with pytest.raises(Exception, match="Directory .* does not exist"):
+ DirectorySource(Path("unknown_directory")).install_into(Path(tmpdir))
+
+
+class TestTarArchiveSource:
+ """Test TarArchiveSource class."""
+
+ @pytest.mark.parametrize(
+ "archive, name",
+ [
+ (Path("some_archive.tgz"), "some_archive"),
+ (Path("some_archive.tar.gz"), "some_archive"),
+ (Path("some_archive"), "some_archive"),
+ ("archives/systems/system1.tar.gz", "system1"),
+ ("archives/systems/system1_dir.tar.gz", "system1"),
+ ],
+ )
+ def test_name(self, test_resources_path: Path, archive: Path, name: str) -> None:
+ """Test getting source name."""
+ assert TarArchiveSource(test_resources_path / archive).name() == name
+
+ def test_install_into(self, test_resources_path: Path, tmpdir: Any) -> None:
+ """Test install archive into destination."""
+ system_archive = test_resources_path / "archives/systems/system1.tar.gz"
+
+ tar_source = TarArchiveSource(system_archive)
+ with pytest.raises(Exception, match="Wrong destination .*"):
+ tar_source.install_into(Path("unknown_destination"))
+
+ tmpdir_path = Path(tmpdir)
+ tar_source.install_into(tmpdir_path)
+ source_files = [
+ "aiet-config.json.license",
+ "aiet-config.json",
+ "system_artifact",
+ ]
+ dest_files = [f.name for f in tmpdir_path.iterdir()]
+ assert Counter(source_files) == Counter(dest_files)
+
+ def test_install_into_unknown_source_archive(self, tmpdir: Any) -> None:
+ """Test install unknown source archive."""
+ with pytest.raises(Exception, match="File .* does not exist"):
+ TarArchiveSource(Path("unknown.tar.gz")).install_into(Path(tmpdir))
+
+ def test_install_into_unsupported_source_archive(self, tmpdir: Any) -> None:
+ """Test install unsupported file type."""
+ plain_text_file = Path(tmpdir) / "test_file"
+ plain_text_file.write_text("Not a system config")
+
+ with pytest.raises(Exception, match="Unsupported archive type .*"):
+ TarArchiveSource(plain_text_file).install_into(Path(tmpdir))
+
+ def test_lazy_property_init(self, test_resources_path: Path) -> None:
+ """Test that class properties initialized correctly."""
+ system_archive = test_resources_path / "archives/systems/system1.tar.gz"
+
+ tar_source = TarArchiveSource(system_archive)
+ assert tar_source.name() == "system1"
+ assert tar_source.config() is not None
+ assert tar_source.create_destination()
+
+ tar_source = TarArchiveSource(system_archive)
+ assert tar_source.config() is not None
+ assert tar_source.create_destination()
+ assert tar_source.name() == "system1"
+
+ def test_create_destination_property(self, test_resources_path: Path) -> None:
+ """Test create_destination property filled correctly for different archives."""
+ system_archive1 = test_resources_path / "archives/systems/system1.tar.gz"
+ system_archive2 = test_resources_path / "archives/systems/system1_dir.tar.gz"
+
+ assert TarArchiveSource(system_archive1).create_destination()
+ assert not TarArchiveSource(system_archive2).create_destination()
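+ # The naming suggests system1_dir.tar.gz already contains a top-level
+ # directory, so no destination directory needs to be created for it.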
diff --git a/tests/aiet/test_backend_system.py b/tests/aiet/test_backend_system.py
new file mode 100644
index 0000000..a581547
--- /dev/null
+++ b/tests/aiet/test_backend_system.py
@@ -0,0 +1,536 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Tests for system backend."""
+from contextlib import ExitStack as does_not_raise
+from pathlib import Path
+from typing import Any
+from typing import Callable
+from typing import Dict
+from typing import List
+from typing import Optional
+from typing import Tuple
+from unittest.mock import MagicMock
+
+import pytest
+
+from aiet.backend.common import Command
+from aiet.backend.common import ConfigurationException
+from aiet.backend.common import Param
+from aiet.backend.common import UserParamConfig
+from aiet.backend.config import LocalProtocolConfig
+from aiet.backend.config import ProtocolConfig
+from aiet.backend.config import SSHConfig
+from aiet.backend.config import SystemConfig
+from aiet.backend.controller import SystemController
+from aiet.backend.controller import SystemControllerSingleInstance
+from aiet.backend.protocol import LocalProtocol
+from aiet.backend.protocol import SSHProtocol
+from aiet.backend.protocol import SupportsClose
+from aiet.backend.protocol import SupportsDeploy
+from aiet.backend.system import ControlledSystem
+from aiet.backend.system import get_available_systems
+from aiet.backend.system import get_controller
+from aiet.backend.system import get_system
+from aiet.backend.system import install_system
+from aiet.backend.system import load_system
+from aiet.backend.system import remove_system
+from aiet.backend.system import StandaloneSystem
+from aiet.backend.system import System
+
+
+def dummy_resolver(
+ values: Optional[Dict[str, str]] = None
+) -> Callable[[str, str, List[Tuple[Optional[str], Param]]], str]:
+ """Return dummy parameter resolver implementation."""
+ # pylint: disable=unused-argument
+ def resolver(
+ param: str, cmd: str, param_values: List[Tuple[Optional[str], Param]]
+ ) -> str:
+ """Implement dummy parameter resolver."""
+ return values.get(param, "") if values else ""
+
+ return resolver
+
+
+def test_get_available_systems() -> None:
+ """Test get_available_systems mocking get_resources."""
+ available_systems = get_available_systems()
+ assert all(isinstance(s, System) for s in available_systems)
+ assert len(available_systems) == 3
+ assert [str(s) for s in available_systems] == ["System 1", "System 2", "System 4"]
+
+
+def test_get_system() -> None:
+ """Test get_system."""
+ system1 = get_system("System 1")
+ assert isinstance(system1, ControlledSystem)
+ assert system1.connectable is True
+ assert system1.connection_details() == ("localhost", 8021)
+ assert system1.name == "System 1"
+
+ system2 = get_system("System 2")
+ # check that comparison with object of another type returns false
+ assert system1 != 42
+ assert system1 != system2
+
+ system = get_system("Unknown system")
+ assert system is None
+
+
+@pytest.mark.parametrize(
+ "source, call_count, exception_type",
+ (
+ (
+ "archives/systems/system1.tar.gz",
+ 0,
+ pytest.raises(Exception, match="Systems .* are already installed"),
+ ),
+ (
+ "archives/systems/system3.tar.gz",
+ 0,
+ pytest.raises(Exception, match="Unable to read system definition"),
+ ),
+ (
+ "systems/system1",
+ 0,
+ pytest.raises(Exception, match="Systems .* are already installed"),
+ ),
+ (
+ "systems/system3",
+ 0,
+ pytest.raises(Exception, match="Unable to read system definition"),
+ ),
+ ("unknown_path", 0, pytest.raises(Exception, match="Unable to read")),
+ (
+ "various/systems/system_with_empty_config",
+ 0,
+ pytest.raises(Exception, match="No system definition found"),
+ ),
+ ("various/systems/system_with_valid_config", 1, does_not_raise()),
+ ),
+)
+def test_install_system(
+ monkeypatch: Any,
+ test_resources_path: Path,
+ source: str,
+ call_count: int,
+ exception_type: Any,
+) -> None:
+ """Test system installation from archive."""
+ mock_create_destination_and_install = MagicMock()
+ monkeypatch.setattr(
+ "aiet.backend.system.create_destination_and_install",
+ mock_create_destination_and_install,
+ )
+
+ with exception_type:
+ install_system(test_resources_path / source)
+
+ assert mock_create_destination_and_install.call_count == call_count
+
+
+def test_remove_system(monkeypatch: Any) -> None:
+ """Test system removal."""
+ mock_remove_backend = MagicMock()
+ monkeypatch.setattr("aiet.backend.system.remove_backend", mock_remove_backend)
+ remove_system("some_system_dir")
+ mock_remove_backend.assert_called_once()
+
+
+def test_system(monkeypatch: Any) -> None:
+ """Test the System class."""
+ config = SystemConfig(name="System 1")
+ monkeypatch.setattr("aiet.backend.system.ProtocolFactory", MagicMock())
+ system = System(config)
+ assert str(system) == "System 1"
+ assert system.name == "System 1"
+
+
+def test_system_with_empty_parameter_name() -> None:
+ """Test that configuration fails if parameter name is empty."""
+ bad_config = SystemConfig(
+ name="System 1",
+ commands={"run": ["run"]},
+ user_params={"run": [{"name": "", "values": ["1", "2", "3"]}]},
+ )
+ with pytest.raises(Exception, match="Parameter has an empty 'name' attribute."):
+ System(bad_config)
+
+
+def test_system_standalone_run() -> None:
+ """Test run operation for standalone system."""
+ system = get_system("System 4")
+ assert isinstance(system, StandaloneSystem)
+
+ with pytest.raises(
+ ConfigurationException, match="System .* does not support connections"
+ ):
+ system.connection_details()
+
+ with pytest.raises(
+ ConfigurationException, match="System .* does not support connections"
+ ):
+ system.establish_connection()
+
+ assert system.connectable is False
+
+ system.run("echo 'application run'")
+
+
+@pytest.mark.parametrize(
+ "system_name, expected_value", [("System 1", True), ("System 4", False)]
+)
+def test_system_supports_deploy(system_name: str, expected_value: bool) -> None:
+ """Test system property supports_deploy."""
+ system = get_system(system_name)
+ if system is None:
+ pytest.fail("Unable to get system {}".format(system_name))
+ assert system.supports_deploy == expected_value
+
+
+@pytest.mark.parametrize(
+ "mock_protocol",
+ [
+ MagicMock(spec=SSHProtocol),
+ MagicMock(
+ spec=SSHProtocol,
+ **{"close.side_effect": ValueError("Unable to close protocol")}
+ ),
+ MagicMock(spec=LocalProtocol),
+ ],
+)
+def test_system_start_and_stop(monkeypatch: Any, mock_protocol: MagicMock) -> None:
+ """Test system start, run commands and stop."""
+ monkeypatch.setattr(
+ "aiet.backend.system.ProtocolFactory.get_protocol",
+ MagicMock(return_value=mock_protocol),
+ )
+
+ system = get_system("System 1")
+ if system is None:
+ pytest.fail("Unable to get system")
+ assert isinstance(system, ControlledSystem)
+
+ with pytest.raises(Exception, match="System has not been started"):
+ system.stop()
+
+ assert not system.is_running()
+ assert system.get_output() == ("", "")
+ system.start(["sleep 10"], False)
+ assert system.is_running()
+ system.stop(wait=True)
+ assert not system.is_running()
+ assert system.get_output() == ("", "")
+
+ if isinstance(mock_protocol, SupportsClose):
+ mock_protocol.close.assert_called_once()
+
+ if isinstance(mock_protocol, SSHProtocol):
+ system.establish_connection()
+
+
+def test_system_start_no_config_location() -> None:
+ """Test that system without config location could not start."""
+ system = load_system(
+ SystemConfig(
+ name="test",
+ data_transfer=SSHConfig(
+ protocol="ssh",
+ username="user",
+ password="user",
+ hostname="localhost",
+ port="123",
+ ),
+ )
+ )
+
+ assert isinstance(system, ControlledSystem)
+ with pytest.raises(
+ ConfigurationException, match="System test has wrong config location"
+ ):
+ system.start(["sleep 100"])
+
+
+@pytest.mark.parametrize(
+ "config, expected_class, expected_error",
+ [
+ (
+ SystemConfig(
+ name="test",
+ data_transfer=SSHConfig(
+ protocol="ssh",
+ username="user",
+ password="user",
+ hostname="localhost",
+ port="123",
+ ),
+ ),
+ ControlledSystem,
+ does_not_raise(),
+ ),
+ (
+ SystemConfig(
+ name="test", data_transfer=LocalProtocolConfig(protocol="local")
+ ),
+ StandaloneSystem,
+ does_not_raise(),
+ ),
+ (
+ SystemConfig(
+ name="test",
+ data_transfer=ProtocolConfig(protocol="cool_protocol"), # type: ignore
+ ),
+ None,
+ pytest.raises(
+ Exception, match="Unsupported execution type for protocol cool_protocol"
+ ),
+ ),
+ ],
+)
+def test_load_system(
+ config: SystemConfig, expected_class: type, expected_error: Any
+) -> None:
+ """Test load_system function."""
+ if not expected_class:
+ with expected_error:
+ load_system(config)
+ else:
+ system = load_system(config)
+ assert isinstance(system, expected_class)
+
+
+def test_load_system_populate_shared_params() -> None:
+ """Test shared parameters population."""
+ with pytest.raises(Exception, match="All shared parameters should have aliases"):
+ load_system(
+ SystemConfig(
+ name="test_system",
+ data_transfer=LocalProtocolConfig(protocol="local"),
+ user_params={
+ "shared": [
+ UserParamConfig(
+ name="--shared_param1",
+ description="Shared parameter",
+ values=["1", "2", "3"],
+ default_value="1",
+ )
+ ]
+ },
+ )
+ )
+
+ with pytest.raises(
+ Exception, match="All parameters for command run should have aliases"
+ ):
+ load_system(
+ SystemConfig(
+ name="test_system",
+ data_transfer=LocalProtocolConfig(protocol="local"),
+ user_params={
+ "shared": [
+ UserParamConfig(
+ name="--shared_param1",
+ description="Shared parameter",
+ values=["1", "2", "3"],
+ default_value="1",
+ alias="shared_param1",
+ )
+ ],
+ "run": [
+ UserParamConfig(
+ name="--run_param1",
+ description="Run specific parameter",
+ values=["1", "2", "3"],
+ default_value="2",
+ )
+ ],
+ },
+ )
+ )
+ system0 = load_system(
+ SystemConfig(
+ name="test_system",
+ data_transfer=LocalProtocolConfig(protocol="local"),
+ commands={"run": ["run_command"]},
+ user_params={
+ "shared": [],
+ "run": [
+ UserParamConfig(
+ name="--run_param1",
+ description="Run specific parameter",
+ values=["1", "2", "3"],
+ default_value="2",
+ alias="run_param1",
+ )
+ ],
+ },
+ )
+ )
+ assert len(system0.commands) == 1
+ run_command1 = system0.commands["run"]
+ assert run_command1 == Command(
+ ["run_command"],
+ [
+ Param(
+ "--run_param1",
+ "Run specific parameter",
+ ["1", "2", "3"],
+ "2",
+ "run_param1",
+ )
+ ],
+ )
+
+ system1 = load_system(
+ SystemConfig(
+ name="test_system",
+ data_transfer=LocalProtocolConfig(protocol="local"),
+ user_params={
+ "shared": [
+ UserParamConfig(
+ name="--shared_param1",
+ description="Shared parameter",
+ values=["1", "2", "3"],
+ default_value="1",
+ alias="shared_param1",
+ )
+ ],
+ "run": [
+ UserParamConfig(
+ name="--run_param1",
+ description="Run specific parameter",
+ values=["1", "2", "3"],
+ default_value="2",
+ alias="run_param1",
+ )
+ ],
+ },
+ )
+ )
+ assert len(system1.commands) == 2
+ build_command1 = system1.commands["build"]
+ assert build_command1 == Command(
+ [],
+ [
+ Param(
+ "--shared_param1",
+ "Shared parameter",
+ ["1", "2", "3"],
+ "1",
+ "shared_param1",
+ )
+ ],
+ )
+
+ run_command1 = system1.commands["run"]
+ assert run_command1 == Command(
+ [],
+ [
+ Param(
+ "--shared_param1",
+ "Shared parameter",
+ ["1", "2", "3"],
+ "1",
+ "shared_param1",
+ ),
+ Param(
+ "--run_param1",
+ "Run specific parameter",
+ ["1", "2", "3"],
+ "2",
+ "run_param1",
+ ),
+ ],
+ )
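+ # As asserted above, shared user parameters are merged into every command
+ # and precede the command-specific parameters.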
+
+ system2 = load_system(
+ SystemConfig(
+ name="test_system",
+ data_transfer=LocalProtocolConfig(protocol="local"),
+ commands={"build": ["build_command"]},
+ user_params={
+ "shared": [
+ UserParamConfig(
+ name="--shared_param1",
+ description="Shared parameter",
+ values=["1", "2", "3"],
+ default_value="1",
+ alias="shared_param1",
+ )
+ ],
+ "run": [
+ UserParamConfig(
+ name="--run_param1",
+ description="Run specific parameter",
+ values=["1", "2", "3"],
+ default_value="2",
+ alias="run_param1",
+ )
+ ],
+ },
+ )
+ )
+ assert len(system2.commands) == 2
+ build_command2 = system2.commands["build"]
+ assert build_command2 == Command(
+ ["build_command"],
+ [
+ Param(
+ "--shared_param1",
+ "Shared parameter",
+ ["1", "2", "3"],
+ "1",
+ "shared_param1",
+ )
+ ],
+ )
+
+ run_command2 = system2.commands["run"]
+ assert run_command2 == Command(
+ [],
+ [
+ Param(
+ "--shared_param1",
+ "Shared parameter",
+ ["1", "2", "3"],
+ "1",
+ "shared_param1",
+ ),
+ Param(
+ "--run_param1",
+ "Run specific parameter",
+ ["1", "2", "3"],
+ "2",
+ "run_param1",
+ ),
+ ],
+ )
+
+
+@pytest.mark.parametrize(
+ "mock_protocol, expected_call_count",
+ [(MagicMock(spec=SupportsDeploy), 1), (MagicMock(), 0)],
+)
+def test_system_deploy_data(
+ monkeypatch: Any, mock_protocol: MagicMock, expected_call_count: int
+) -> None:
+ """Test deploy data functionality."""
+ monkeypatch.setattr(
+ "aiet.backend.system.ProtocolFactory.get_protocol",
+ MagicMock(return_value=mock_protocol),
+ )
+
+ system = ControlledSystem(SystemConfig(name="test"))
+ system.deploy(Path("some_file"), "some_dest")
+
+ assert mock_protocol.deploy.call_count == expected_call_count
+
+
+@pytest.mark.parametrize(
+ "single_instance, controller_class",
+ ((False, SystemController), (True, SystemControllerSingleInstance)),
+)
+def test_get_controller(single_instance: bool, controller_class: type) -> None:
+ """Test function get_controller."""
+ controller = get_controller(single_instance)
+ assert isinstance(controller, controller_class)
diff --git a/tests/aiet/test_backend_tool.py b/tests/aiet/test_backend_tool.py
new file mode 100644
index 0000000..fd5960d
--- /dev/null
+++ b/tests/aiet/test_backend_tool.py
@@ -0,0 +1,60 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+# pylint: disable=no-self-use
+"""Tests for the tool backend."""
+from collections import Counter
+
+import pytest
+
+from aiet.backend.common import ConfigurationException
+from aiet.backend.config import ToolConfig
+from aiet.backend.tool import get_available_tool_directory_names
+from aiet.backend.tool import get_available_tools
+from aiet.backend.tool import get_tool
+from aiet.backend.tool import Tool
+
+
+def test_get_available_tool_directory_names() -> None:
+ """Test get_available_tools mocking get_resources."""
+ directory_names = get_available_tool_directory_names()
+ assert Counter(directory_names) == Counter(["tool1", "tool2", "vela"])
+
+
+def test_get_available_tools() -> None:
+ """Test get_available_tools mocking get_resources."""
+ available_tools = get_available_tools()
+ expected_tool_names = sorted(
+ [
+ "tool_1",
+ "tool_2",
+ "vela",
+ "vela",
+ "vela",
+ ]
+ )
+
+ assert all(isinstance(s, Tool) for s in available_tools)
+ assert all(s != 42 for s in available_tools)
+ assert any(s == available_tools[0] for s in available_tools)
+ assert len(available_tools) == len(expected_tool_names)
+ available_tool_names = sorted(str(s) for s in available_tools)
+ assert available_tool_names == expected_tool_names
+
+
+def test_get_tool() -> None:
+ """Test get_tool mocking get_resoures."""
+ tools = get_tool("tool_1")
+ assert len(tools) == 1
+ tool = tools[0]
+ assert tool is not None
+ assert isinstance(tool, Tool)
+ assert tool.name == "tool_1"
+
+ tools = get_tool("unknown tool")
+ assert not tools
+
+
+def test_tool_creation() -> None:
+ """Test edge cases when creating a Tool instance."""
+ with pytest.raises(ConfigurationException):
+ Tool(ToolConfig(name="test", commands={"test": []})) # no 'run' command
diff --git a/tests/aiet/test_check_model.py b/tests/aiet/test_check_model.py
new file mode 100644
index 0000000..4eafe59
--- /dev/null
+++ b/tests/aiet/test_check_model.py
@@ -0,0 +1,162 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+# pylint: disable=redefined-outer-name,no-self-use
+"""Module for testing check_model.py script."""
+from pathlib import Path
+from typing import Any
+
+import pytest
+from ethosu.vela.tflite.Model import Model
+from ethosu.vela.tflite.OperatorCode import OperatorCode
+
+from aiet.cli.common import InvalidTFLiteFileError
+from aiet.cli.common import ModelOptimisedException
+from aiet.resources.tools.vela.check_model import check_custom_codes_for_ethosu
+from aiet.resources.tools.vela.check_model import check_model
+from aiet.resources.tools.vela.check_model import get_custom_codes_from_operators
+from aiet.resources.tools.vela.check_model import get_model_from_file
+from aiet.resources.tools.vela.check_model import get_operators_from_model
+from aiet.resources.tools.vela.check_model import is_vela_optimised
+
+
+@pytest.fixture(scope="session")
+def optimised_tflite_model(
+ optimised_input_model_file: Path,
+) -> Model:
+ """Return Model instance read from a Vela-optimised TFLite file."""
+ return get_model_from_file(optimised_input_model_file)
+
+
+@pytest.fixture(scope="session")
+def non_optimised_tflite_model(
+ non_optimised_input_model_file: Path,
+) -> Model:
+ """Return Model instance read from a Vela-optimised TFLite file."""
+ return get_model_from_file(non_optimised_input_model_file)
+
+
+class TestIsVelaOptimised:
+ """Test class for is_vela_optimised() function."""
+
+ def test_return_true_when_input_is_optimised(
+ self,
+ optimised_tflite_model: Model,
+ ) -> None:
+ """Verify True returned when input is optimised model."""
+ output = is_vela_optimised(optimised_tflite_model)
+
+ assert output is True
+
+ def test_return_false_when_input_is_not_optimised(
+ self,
+ non_optimised_tflite_model: Model,
+ ) -> None:
+ """Verify False returned when input is non-optimised model."""
+ output = is_vela_optimised(non_optimised_tflite_model)
+
+ assert output is False
+
+
+def test_get_operators_from_model_returns_correct_instances(
+    optimised_tflite_model: Model,
+) -> None:
+    """Verify list of OperatorCode instances returned by get_operators_from_model()."""
+ operator_list = get_operators_from_model(optimised_tflite_model)
+
+ assert all(isinstance(operator, OperatorCode) for operator in operator_list)
+
+
+class TestGetCustomCodesFromOperators:
+ """Test the get_custom_codes_from_operators() function."""
+
+ def test_returns_empty_list_when_input_operators_have_no_custom_codes(
+ self, monkeypatch: Any
+ ) -> None:
+ """Verify function returns empty list when operators have no custom codes."""
+ # Mock OperatorCode.CustomCode() function to return None
+ monkeypatch.setattr(
+ "ethosu.vela.tflite.OperatorCode.OperatorCode.CustomCode", lambda _: None
+ )
+
+ operators = [OperatorCode()] * 3
+
+ custom_codes = get_custom_codes_from_operators(operators)
+
+ assert custom_codes == []
+
+ def test_returns_custom_codes_when_input_operators_have_custom_codes(
+ self, monkeypatch: Any
+ ) -> None:
+ """Verify list of bytes objects returned representing the CustomCodes."""
+ # Mock OperatorCode.CustomCode() function to return a byte string
+ monkeypatch.setattr(
+ "ethosu.vela.tflite.OperatorCode.OperatorCode.CustomCode",
+ lambda _: b"custom-code",
+ )
+
+ operators = [OperatorCode()] * 3
+
+ custom_codes = get_custom_codes_from_operators(operators)
+
+ assert custom_codes == [b"custom-code", b"custom-code", b"custom-code"]
+
+
+@pytest.mark.parametrize(
+ "custom_codes, expected_output",
+ [
+ ([b"ethos-u", b"something else"], True),
+ ([b"custom-code-1", b"custom-code-2"], False),
+ ],
+)
+def test_check_list_for_ethosu(custom_codes: list, expected_output: bool) -> None:
+ """Verify function detects 'ethos-u' bytes in the input list."""
+ output = check_custom_codes_for_ethosu(custom_codes)
+ assert output is expected_output
+
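+# One way check_custom_codes_for_ethosu could satisfy the cases above (a
+# sketch, not necessarily the real implementation): Vela marks generated
+# operators with the custom code b"ethos-u".
+#
+#     def check_custom_codes_for_ethosu(custom_codes: list) -> bool:
+#         return any(code == b"ethos-u" for code in custom_codes)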
+
+class TestGetModelFromFile:
+ """Test the get_model_from_file() function."""
+
+ def test_error_raised_when_input_is_invalid_model_file(
+ self,
+ invalid_input_model_file: Path,
+ ) -> None:
+ """Verify error thrown when an invalid model file is given."""
+ with pytest.raises(InvalidTFLiteFileError):
+ get_model_from_file(invalid_input_model_file)
+
+ def test_model_instance_returned_when_input_is_valid_model_file(
+ self,
+ optimised_input_model_file: Path,
+ ) -> None:
+ """Verify file is read successfully and returns model instance."""
+ tflite_model = get_model_from_file(optimised_input_model_file)
+
+ assert isinstance(tflite_model, Model)
+
+
+class TestCheckModel:
+ """Test the check_model() function."""
+
+ def test_check_model_with_non_optimised_input(
+ self,
+ non_optimised_input_model_file: Path,
+ ) -> None:
+ """Verify no error occurs for a valid input file."""
+ check_model(non_optimised_input_model_file)
+
+ def test_check_model_with_optimised_input(
+ self,
+ optimised_input_model_file: Path,
+ ) -> None:
+ """Verify that the right exception is raised with already optimised input."""
+ with pytest.raises(ModelOptimisedException):
+ check_model(optimised_input_model_file)
+
+ def test_check_model_with_invalid_input(
+ self,
+ invalid_input_model_file: Path,
+ ) -> None:
+ """Verify that an exception is raised with invalid input."""
+ with pytest.raises(Exception):
+ check_model(invalid_input_model_file)
diff --git a/tests/aiet/test_cli.py b/tests/aiet/test_cli.py
new file mode 100644
index 0000000..e8589fa
--- /dev/null
+++ b/tests/aiet/test_cli.py
@@ -0,0 +1,37 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Module for testing CLI top command."""
+from typing import Any
+from unittest.mock import ANY
+from unittest.mock import MagicMock
+
+from click.testing import CliRunner
+
+from aiet.cli import cli
+
+
+def test_cli(cli_runner: CliRunner) -> None:
+ """Test CLI top level command."""
+ result = cli_runner.invoke(cli)
+ assert result.exit_code == 0
+ assert "system" in cli.commands
+ assert "application" in cli.commands
+
+
+def test_cli_version(cli_runner: CliRunner) -> None:
+ """Test version option."""
+ result = cli_runner.invoke(cli, ["--version"])
+ assert result.exit_code == 0
+ assert "version" in result.output
+
+
+def test_cli_verbose(cli_runner: CliRunner, monkeypatch: Any) -> None:
+ """Test verbose option."""
+ with monkeypatch.context() as mock_context:
+ mock = MagicMock()
+ # params[1] is the verbose option and we need to replace the
+ # callback with a mock object
+ mock_context.setattr(cli.params[1], "callback", mock)
+ cli_runner.invoke(cli, ["-vvvv"])
+        # 4 is the number of times -v was passed above
+ mock.assert_called_once_with(ANY, ANY, 4)
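+# The "-vvvv" invocation relies on a click count option. A hypothetical sketch
+# of such an option (names are illustrative, not the actual aiet CLI
+# definition):
+#
+#     @click.option("-v", "--verbose", count=True, callback=set_verbosity)
+#
+# click then invokes the callback as set_verbosity(ctx, param, 4), matching
+# the (ANY, ANY, 4) assertion above.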
diff --git a/tests/aiet/test_cli_application.py b/tests/aiet/test_cli_application.py
new file mode 100644
index 0000000..f1ccc44
--- /dev/null
+++ b/tests/aiet/test_cli_application.py
@@ -0,0 +1,1153 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+# pylint: disable=attribute-defined-outside-init,no-member,line-too-long,too-many-arguments,too-many-locals,redefined-outer-name,too-many-lines
+"""Module for testing CLI application subcommand."""
+import base64
+import json
+import re
+import time
+from contextlib import contextmanager
+from contextlib import ExitStack
+from pathlib import Path
+from typing import Any
+from typing import Generator
+from typing import IO
+from typing import List
+from typing import Optional
+from typing import TypedDict
+from unittest.mock import MagicMock
+
+import click
+import pytest
+from click.testing import CliRunner
+from filelock import FileLock
+
+from aiet.backend.application import Application
+from aiet.backend.config import ApplicationConfig
+from aiet.backend.config import LocalProtocolConfig
+from aiet.backend.config import SSHConfig
+from aiet.backend.config import SystemConfig
+from aiet.backend.config import UserParamConfig
+from aiet.backend.output_parser import Base64OutputParser
+from aiet.backend.protocol import SSHProtocol
+from aiet.backend.system import load_system
+from aiet.cli.application import application_cmd
+from aiet.cli.application import details_cmd
+from aiet.cli.application import execute_cmd
+from aiet.cli.application import install_cmd
+from aiet.cli.application import list_cmd
+from aiet.cli.application import parse_payload_run_config
+from aiet.cli.application import remove_cmd
+from aiet.cli.application import run_cmd
+from aiet.cli.common import MiddlewareExitCode
+
+
+def test_application_cmd() -> None:
+ """Test application commands."""
+ commands = ["list", "details", "install", "remove", "execute", "run"]
+ assert all(command in application_cmd.commands for command in commands)
+
+
+@pytest.mark.parametrize("format_", ["json", "cli"])
+def test_application_cmd_context(cli_runner: CliRunner, format_: str) -> None:
+ """Test setting command context parameters."""
+ result = cli_runner.invoke(application_cmd, ["--format", format_])
+ # command should fail if no subcommand provided
+ assert result.exit_code == 2
+
+ result = cli_runner.invoke(application_cmd, ["--format", format_, "list"])
+ assert result.exit_code == 0
+
+
+@pytest.mark.parametrize(
+ "format_, system_name, expected_output",
+ [
+ (
+ "json",
+ None,
+ '{"type": "application", "available": ["application_1", "application_2"]}\n',
+ ),
+ (
+ "json",
+ "system_1",
+ '{"type": "application", "available": ["application_1"]}\n',
+ ),
+ ("cli", None, "Available applications:\n\napplication_1\napplication_2\n"),
+ ("cli", "system_1", "Available applications:\n\napplication_1\n"),
+ ],
+)
+def test_list_cmd(
+ cli_runner: CliRunner,
+ monkeypatch: Any,
+ format_: str,
+ system_name: str,
+ expected_output: str,
+) -> None:
+ """Test available applications commands."""
+ # Mock some applications
+ mock_application_1 = MagicMock(spec=Application)
+ mock_application_1.name = "application_1"
+ mock_application_1.can_run_on.return_value = system_name == "system_1"
+ mock_application_2 = MagicMock(spec=Application)
+ mock_application_2.name = "application_2"
+ mock_application_2.can_run_on.return_value = system_name == "system_2"
+
+ # Monkey patch the call get_available_applications
+ mock_available_applications = MagicMock()
+ mock_available_applications.return_value = [mock_application_1, mock_application_2]
+
+ monkeypatch.setattr(
+ "aiet.backend.application.get_available_applications",
+ mock_available_applications,
+ )
+
+ obj = {"format": format_}
+ args = []
+ if system_name:
+ list_cmd.params[0].type = click.Choice([system_name])
+ args = ["--system", system_name]
+ result = cli_runner.invoke(list_cmd, obj=obj, args=args)
+ assert result.output == expected_output
+
+
+def get_test_application() -> Application:
+ """Return test system details."""
+ config = ApplicationConfig(
+ name="application",
+ description="test",
+ build_dir="",
+ supported_systems=[],
+ deploy_data=[],
+ user_params={},
+ commands={
+ "clean": ["clean"],
+ "build": ["build"],
+ "run": ["run"],
+ "post_run": ["post_run"],
+ },
+ )
+
+ return Application(config)
+
+
+def get_details_cmd_json_output() -> str:
+ """Get JSON output for details command."""
+ json_output = """
+[
+ {
+ "type": "application",
+ "name": "application",
+ "description": "test",
+ "supported_systems": [],
+ "commands": {
+ "clean": {
+ "command_strings": [
+ "clean"
+ ],
+ "user_params": []
+ },
+ "build": {
+ "command_strings": [
+ "build"
+ ],
+ "user_params": []
+ },
+ "run": {
+ "command_strings": [
+ "run"
+ ],
+ "user_params": []
+ },
+ "post_run": {
+ "command_strings": [
+ "post_run"
+ ],
+ "user_params": []
+ }
+ }
+ }
+]"""
+ return json.dumps(json.loads(json_output)) + "\n"
+
+
+def get_details_cmd_console_output() -> str:
+ """Get console output for details command."""
+ return (
+ 'Application "application" details'
+ + "\nDescription: test"
+ + "\n\nSupported systems: "
+ + "\n\nclean commands:"
+ + "\nCommands: ['clean']"
+ + "\n\nbuild commands:"
+ + "\nCommands: ['build']"
+ + "\n\nrun commands:"
+ + "\nCommands: ['run']"
+ + "\n\npost_run commands:"
+ + "\nCommands: ['post_run']"
+ + "\n"
+ )
+
+
+@pytest.mark.parametrize(
+ "application_name,format_, expected_output",
+ [
+ ("application", "json", get_details_cmd_json_output()),
+ ("application", "cli", get_details_cmd_console_output()),
+ ],
+)
+def test_details_cmd(
+ cli_runner: CliRunner,
+ monkeypatch: Any,
+ application_name: str,
+ format_: str,
+ expected_output: str,
+) -> None:
+ """Test application details command."""
+ monkeypatch.setattr(
+ "aiet.cli.application.get_application",
+ MagicMock(return_value=[get_test_application()]),
+ )
+
+ details_cmd.params[0].type = click.Choice(["application"])
+ result = cli_runner.invoke(
+ details_cmd, obj={"format": format_}, args=["--name", application_name]
+ )
+ assert result.exception is None
+ assert result.output == expected_output
+
+
+def test_details_cmd_wrong_system(cli_runner: CliRunner, monkeypatch: Any) -> None:
+ """Test details command fails if application is not supported by the system."""
+ monkeypatch.setattr(
+ "aiet.backend.execution.get_application", MagicMock(return_value=[])
+ )
+
+ details_cmd.params[0].type = click.Choice(["application"])
+ details_cmd.params[1].type = click.Choice(["system"])
+ result = cli_runner.invoke(
+ details_cmd, args=["--name", "application", "--system", "system"]
+ )
+ assert result.exit_code == 2
+ assert (
+ "Application 'application' doesn't support the system 'system'" in result.stdout
+ )
+
+
+def test_install_cmd(cli_runner: CliRunner, monkeypatch: Any) -> None:
+ """Test install application command."""
+ mock_install_application = MagicMock()
+ monkeypatch.setattr(
+ "aiet.cli.application.install_application", mock_install_application
+ )
+
+ args = ["--source", "test"]
+ cli_runner.invoke(install_cmd, args=args)
+ mock_install_application.assert_called_once_with(Path("test"))
+
+
+def test_remove_cmd(cli_runner: CliRunner, monkeypatch: Any) -> None:
+ """Test remove application command."""
+ mock_remove_application = MagicMock()
+ monkeypatch.setattr(
+ "aiet.cli.application.remove_application", mock_remove_application
+ )
+ remove_cmd.params[0].type = click.Choice(["test"])
+
+ args = ["--directory_name", "test"]
+ cli_runner.invoke(remove_cmd, args=args)
+ mock_remove_application.assert_called_once_with("test")
+
+
+class ExecutionCase(TypedDict, total=False):
+ """Execution case."""
+
+ args: List[str]
+ lock_path: str
+ can_establish_connection: bool
+ establish_connection_delay: int
+ app_exit_code: int
+ exit_code: int
+ output: str
+
+
+@pytest.mark.parametrize(
+ "application_config, system_config, executions",
+ [
+ [
+ ApplicationConfig(
+ name="test_application",
+ description="Test application",
+ supported_systems=["test_system"],
+ config_location=Path("wrong_location"),
+ commands={"build": ["echo build {application.name}"]},
+ ),
+ SystemConfig(
+ name="test_system",
+ description="Test system",
+ data_transfer=LocalProtocolConfig(protocol="local"),
+ config_location=Path("wrong_location"),
+ commands={"run": ["echo run {application.name} on {system.name}"]},
+ ),
+ [
+ ExecutionCase(
+ args=["-c", "build"],
+ exit_code=MiddlewareExitCode.CONFIGURATION_ERROR,
+ output="Error: Application test_application has wrong config location\n",
+ )
+ ],
+ ],
+ [
+ ApplicationConfig(
+ name="test_application",
+ description="Test application",
+ supported_systems=["test_system"],
+ build_dir="build",
+ deploy_data=[("sample_file", "/tmp/sample_file")],
+ commands={"build": ["echo build {application.name}"]},
+ ),
+ SystemConfig(
+ name="test_system",
+ description="Test system",
+ data_transfer=LocalProtocolConfig(protocol="local"),
+ commands={"run": ["echo run {application.name} on {system.name}"]},
+ ),
+ [
+ ExecutionCase(
+ args=["-c", "run"],
+ exit_code=MiddlewareExitCode.CONFIGURATION_ERROR,
+ output="Error: System test_system does not support data deploy\n",
+ )
+ ],
+ ],
+ [
+ ApplicationConfig(
+ name="test_application",
+ description="Test application",
+ supported_systems=["test_system"],
+ commands={"build": ["echo build {application.name}"]},
+ ),
+ SystemConfig(
+ name="test_system",
+ description="Test system",
+ data_transfer=LocalProtocolConfig(protocol="local"),
+ commands={"run": ["echo run {application.name} on {system.name}"]},
+ ),
+ [
+ ExecutionCase(
+ args=["-c", "build"],
+ exit_code=MiddlewareExitCode.CONFIGURATION_ERROR,
+ output="Error: No build directory defined for the app test_application\n",
+ )
+ ],
+ ],
+ [
+ ApplicationConfig(
+ name="test_application",
+ description="Test application",
+ supported_systems=["new_system"],
+ build_dir="build",
+ commands={
+ "build": ["echo build {application.name} with {user_params:0}"]
+ },
+ user_params={
+ "build": [
+ UserParamConfig(
+ name="param",
+ description="sample parameter",
+ default_value="default",
+ values=["val1", "val2", "val3"],
+ )
+ ]
+ },
+ ),
+ SystemConfig(
+ name="test_system",
+ description="Test system",
+ data_transfer=LocalProtocolConfig(protocol="local"),
+ commands={"run": ["echo run {application.name} on {system.name}"]},
+ ),
+ [
+ ExecutionCase(
+ args=["-c", "build"],
+ exit_code=1,
+ output="Error: Application 'test_application' doesn't support the system 'test_system'\n",
+ )
+ ],
+ ],
+ [
+ ApplicationConfig(
+ name="test_application",
+ description="Test application",
+ supported_systems=["test_system"],
+ build_dir="build",
+ commands={"build": ["false"]},
+ ),
+ SystemConfig(
+ name="test_system",
+ description="Test system",
+ data_transfer=LocalProtocolConfig(protocol="local"),
+ commands={"run": ["echo run {application.name} on {system.name}"]},
+ ),
+ [
+ ExecutionCase(
+ args=["-c", "build"],
+ exit_code=MiddlewareExitCode.BACKEND_ERROR,
+ output="""Running: false
+Error: Execution failed. Please check output for the details.\n""",
+ )
+ ],
+ ],
+ [
+ ApplicationConfig(
+ name="test_application",
+ description="Test application",
+ supported_systems=["test_system"],
+ lock=True,
+ build_dir="build",
+ commands={
+ "build": ["echo build {application.name} with {user_params:0}"]
+ },
+ user_params={
+ "build": [
+ UserParamConfig(
+ name="param",
+ description="sample parameter",
+ default_value="default",
+ values=["val1", "val2", "val3"],
+ )
+ ]
+ },
+ ),
+ SystemConfig(
+ name="test_system",
+ description="Test system",
+ lock=True,
+ data_transfer=LocalProtocolConfig(protocol="local"),
+ commands={"run": ["echo run {application.name} on {system.name}"]},
+ ),
+ [
+ ExecutionCase(
+ args=["-c", "build"],
+ exit_code=MiddlewareExitCode.SUCCESS,
+ output="""Running: echo build test_application with param default
+build test_application with param default\n""",
+ ),
+ ExecutionCase(
+ args=["-c", "build"],
+ lock_path="/tmp/middleware_test_application_test_system.lock",
+ exit_code=MiddlewareExitCode.CONCURRENT_ERROR,
+ output="Error: Another instance of the system is running\n",
+ ),
+ ExecutionCase(
+ args=["-c", "build", "--param=param=val3"],
+ exit_code=MiddlewareExitCode.SUCCESS,
+ output="""Running: echo build test_application with param val3
+build test_application with param val3\n""",
+ ),
+ ExecutionCase(
+ args=["-c", "build", "--param=param=newval"],
+ exit_code=1,
+ output="Error: Application parameter 'param=newval' not valid for command 'build'\n",
+ ),
+ ExecutionCase(
+ args=["-c", "some_command"],
+ exit_code=MiddlewareExitCode.CONFIGURATION_ERROR,
+ output="Error: Unsupported command some_command\n",
+ ),
+ ExecutionCase(
+ args=["-c", "run"],
+ exit_code=MiddlewareExitCode.SUCCESS,
+ output="""Generating commands to execute
+Running: echo run test_application on test_system
+run test_application on test_system\n""",
+ ),
+ ],
+ ],
+ [
+ ApplicationConfig(
+ name="test_application",
+ description="Test application",
+ supported_systems=["test_system"],
+ deploy_data=[("sample_file", "/tmp/sample_file")],
+ commands={
+ "run": [
+ "echo run {application.name} with {user_params:param} on {system.name}"
+ ]
+ },
+ user_params={
+ "run": [
+ UserParamConfig(
+ name="param=",
+ description="sample parameter",
+ default_value="default",
+ values=["val1", "val2", "val3"],
+ alias="param",
+ )
+ ]
+ },
+ ),
+ SystemConfig(
+ name="test_system",
+ description="Test system",
+ lock=True,
+ data_transfer=SSHConfig(
+ protocol="ssh",
+ username="username",
+ password="password",
+ hostname="localhost",
+ port="8022",
+ ),
+ commands={"run": ["sleep 100"]},
+ ),
+ [
+ ExecutionCase(
+ args=["-c", "run"],
+ exit_code=MiddlewareExitCode.SUCCESS,
+ output="""Generating commands to execute
+Trying to establish connection with 'localhost:8022' - 90 retries every 15.0 seconds .
+Deploying {application.config_location}/sample_file onto /tmp/sample_file
+Running: echo run test_application with param=default on test_system
+Shutting down sequence...
+Stopping test_system... (It could take few seconds)
+test_system stopped successfully.\n""",
+ ),
+ ExecutionCase(
+ args=["-c", "run"],
+ lock_path="/tmp/middleware_test_system.lock",
+ exit_code=MiddlewareExitCode.CONCURRENT_ERROR,
+ output="Error: Another instance of the system is running\n",
+ ),
+ ExecutionCase(
+ args=[
+ "-c",
+ "run",
+ "--deploy={application.config_location}/sample_file:/tmp/sample_file",
+ ],
+ exit_code=0,
+ output="""Generating commands to execute
+Trying to establish connection with 'localhost:8022' - 90 retries every 15.0 seconds .
+Deploying {application.config_location}/sample_file onto /tmp/sample_file
+Deploying {application.config_location}/sample_file onto /tmp/sample_file
+Running: echo run test_application with param=default on test_system
+Shutting down sequence...
+Stopping test_system... (It could take few seconds)
+test_system stopped successfully.\n""",
+ ),
+ ExecutionCase(
+ args=["-c", "run"],
+ app_exit_code=1,
+ exit_code=0,
+ output="""Generating commands to execute
+Trying to establish connection with 'localhost:8022' - 90 retries every 15.0 seconds .
+Deploying {application.config_location}/sample_file onto /tmp/sample_file
+Running: echo run test_application with param=default on test_system
+Application exited with exit code 1
+Shutting down sequence...
+Stopping test_system... (It could take few seconds)
+test_system stopped successfully.\n""",
+ ),
+ ExecutionCase(
+ args=["-c", "run"],
+ exit_code=MiddlewareExitCode.CONNECTION_ERROR,
+ can_establish_connection=False,
+ output="""Generating commands to execute
+Trying to establish connection with 'localhost:8022' - 90 retries every 15.0 seconds ..........................................................................................
+Shutting down sequence...
+Stopping test_system... (It could take few seconds)
+test_system stopped successfully.
+Error: Couldn't connect to 'localhost:8022'.\n""",
+ ),
+ ExecutionCase(
+ args=["-c", "run", "--deploy=bad_format"],
+ exit_code=1,
+ output="Error: Invalid deploy parameter 'bad_format' for command run\n",
+ ),
+ ExecutionCase(
+ args=["-c", "run", "--deploy=:"],
+ exit_code=1,
+ output="Error: Invalid deploy parameter ':' for command run\n",
+ ),
+ ExecutionCase(
+ args=["-c", "run", "--deploy= : "],
+ exit_code=1,
+ output="Error: Invalid deploy parameter ' : ' for command run\n",
+ ),
+ ExecutionCase(
+ args=["-c", "run", "--deploy=some_src_file:"],
+ exit_code=1,
+ output="Error: Invalid deploy parameter 'some_src_file:' for command run\n",
+ ),
+ ExecutionCase(
+ args=["-c", "run", "--deploy=:some_dst_file"],
+ exit_code=1,
+ output="Error: Invalid deploy parameter ':some_dst_file' for command run\n",
+ ),
+ ExecutionCase(
+ args=["-c", "run", "--deploy=unknown_file:/tmp/dest"],
+ exit_code=1,
+ output="Error: Path unknown_file does not exist\n",
+ ),
+ ],
+ ],
+ [
+ ApplicationConfig(
+ name="test_application",
+ description="Test application",
+ supported_systems=["test_system"],
+ commands={
+ "run": [
+ "echo run {application.name} with {user_params:param} on {system.name}"
+ ]
+ },
+ user_params={
+ "run": [
+ UserParamConfig(
+ name="param=",
+ description="sample parameter",
+ default_value="default",
+ values=["val1", "val2", "val3"],
+ alias="param",
+ )
+ ]
+ },
+ ),
+ SystemConfig(
+ name="test_system",
+ description="Test system",
+ data_transfer=SSHConfig(
+ protocol="ssh",
+ username="username",
+ password="password",
+ hostname="localhost",
+ port="8022",
+ ),
+ commands={"run": ["echo Unable to start system"]},
+ ),
+ [
+ ExecutionCase(
+ args=["-c", "run"],
+ exit_code=4,
+ can_establish_connection=False,
+ establish_connection_delay=1,
+ output="""Generating commands to execute
+Trying to establish connection with 'localhost:8022' - 90 retries every 15.0 seconds .
+
+---------- test_system execution failed ----------
+Unable to start system
+
+
+
+Shutting down sequence...
+Stopping test_system... (It could take few seconds)
+test_system stopped successfully.
+Error: Execution failed. Please check output for the details.\n""",
+ )
+ ],
+ ],
+ ],
+)
+def test_application_command_execution(
+ application_config: ApplicationConfig,
+ system_config: SystemConfig,
+ executions: List[ExecutionCase],
+ tmpdir: Any,
+ cli_runner: CliRunner,
+ monkeypatch: Any,
+) -> None:
+ """Test application command execution."""
+
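+    # Holding this file lock simulates a middleware instance that is already
+    # running, which drives the CONCURRENT_ERROR execution cases below.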
+ @contextmanager
+ def lock_execution(lock_path: str) -> Generator[None, None, None]:
+ lock = FileLock(lock_path)
+ lock.acquire(timeout=1)
+
+ try:
+ yield
+ finally:
+ lock.release()
+
+ def replace_vars(str_val: str) -> str:
+ """Replace variables."""
+ application_config_location = str(
+ application_config["config_location"].absolute()
+ )
+
+ return str_val.replace(
+ "{application.config_location}", application_config_location
+ )
+
+ for execution in executions:
+ init_execution_test(
+ monkeypatch,
+ tmpdir,
+ application_config,
+ system_config,
+ can_establish_connection=execution.get("can_establish_connection", True),
+            establish_connection_delay=execution.get("establish_connection_delay", 0),
+ remote_app_exit_code=execution.get("app_exit_code", 0),
+ )
+
+ lock_path = execution.get("lock_path")
+
+ with ExitStack() as stack:
+ if lock_path:
+ stack.enter_context(lock_execution(lock_path))
+
+ args = [replace_vars(arg) for arg in execution["args"]]
+
+ result = cli_runner.invoke(
+ execute_cmd,
+ args=["-n", application_config["name"], "-s", system_config["name"]]
+ + args,
+ )
+ output = replace_vars(execution["output"])
+ assert result.exit_code == execution["exit_code"]
+ assert result.stdout == output
+
+
+@pytest.fixture(params=[False, True], ids=["run-cli", "run-json"])
+def payload_path_or_none(request: Any, tmp_path_factory: Any) -> Optional[Path]:
+ """Drives tests for run command so that it executes them both to use a json file, and to use CLI."""
+ if request.param:
+ ret: Path = tmp_path_factory.getbasetemp() / "system_config_payload_file.json"
+ return ret
+ return None
+
+
+def write_system_payload_config(
+ payload_file: IO[str],
+ application_config: ApplicationConfig,
+ system_config: SystemConfig,
+) -> None:
+ """Write a json payload file for the given test configuration."""
+ payload_dict = {
+ "id": system_config["name"],
+ "arguments": {
+ "application": application_config["name"],
+ },
+ }
+ json.dump(payload_dict, payload_file)
+
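+# For the configurations used in the tests below, this helper writes a payload
+# such as (values are illustrative):
+#
+#     {"id": "test_system", "arguments": {"application": "test_application"}}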
+
+@pytest.mark.parametrize(
+ "application_config, system_config, executions",
+ [
+ [
+ ApplicationConfig(
+ name="test_application",
+ description="Test application",
+ supported_systems=["test_system"],
+ build_dir="build",
+ commands={
+ "build": ["echo build {application.name} with {user_params:0}"]
+ },
+ user_params={
+ "build": [
+ UserParamConfig(
+ name="param",
+ description="sample parameter",
+ default_value="default",
+ values=["val1", "val2", "val3"],
+ )
+ ]
+ },
+ ),
+ SystemConfig(
+ name="test_system",
+ description="Test system",
+ data_transfer=LocalProtocolConfig(protocol="local"),
+ commands={"run": ["echo run {application.name} on {system.name}"]},
+ ),
+ [
+ ExecutionCase(
+ args=[],
+ exit_code=MiddlewareExitCode.SUCCESS,
+ output="""Running: echo build test_application with param default
+build test_application with param default
+Generating commands to execute
+Running: echo run test_application on test_system
+run test_application on test_system\n""",
+ )
+ ],
+ ],
+ [
+ ApplicationConfig(
+ name="test_application",
+ description="Test application",
+ supported_systems=["test_system"],
+ commands={
+ "run": [
+ "echo run {application.name} with {user_params:param} on {system.name}"
+ ]
+ },
+ user_params={
+ "run": [
+ UserParamConfig(
+ name="param=",
+ description="sample parameter",
+ default_value="default",
+ values=["val1", "val2", "val3"],
+ alias="param",
+ )
+ ]
+ },
+ ),
+ SystemConfig(
+ name="test_system",
+ description="Test system",
+ data_transfer=SSHConfig(
+ protocol="ssh",
+ username="username",
+ password="password",
+ hostname="localhost",
+ port="8022",
+ ),
+ commands={"run": ["sleep 100"]},
+ ),
+ [
+ ExecutionCase(
+ args=[],
+ exit_code=MiddlewareExitCode.SUCCESS,
+ output="""Generating commands to execute
+Trying to establish connection with 'localhost:8022' - 90 retries every 15.0 seconds .
+Running: echo run test_application with param=default on test_system
+Shutting down sequence...
+Stopping test_system... (It could take few seconds)
+test_system stopped successfully.\n""",
+ )
+ ],
+ ],
+ ],
+)
+def test_application_run(
+ application_config: ApplicationConfig,
+ system_config: SystemConfig,
+ executions: List[ExecutionCase],
+ tmpdir: Any,
+ cli_runner: CliRunner,
+ monkeypatch: Any,
+    payload_path_or_none: Optional[Path],
+) -> None:
+ """Test application command execution."""
+ for execution in executions:
+ init_execution_test(monkeypatch, tmpdir, application_config, system_config)
+
+ if payload_path_or_none:
+ with open(payload_path_or_none, "w", encoding="utf-8") as payload_file:
+ write_system_payload_config(
+ payload_file, application_config, system_config
+ )
+
+ result = cli_runner.invoke(
+ run_cmd,
+ args=["--config", str(payload_path_or_none)],
+ )
+ else:
+ result = cli_runner.invoke(
+ run_cmd,
+ args=["-n", application_config["name"], "-s", system_config["name"]]
+ + execution["args"],
+ )
+
+ assert result.stdout == execution["output"]
+ assert result.exit_code == execution["exit_code"]
+
+
+@pytest.mark.parametrize(
+ "cmdline,error_pattern",
+ [
+ [
+ "--config {payload} -s test_system",
+ "when --config is set, the following parameters should not be provided",
+ ],
+ [
+ "--config {payload} -n test_application",
+ "when --config is set, the following parameters should not be provided",
+ ],
+ [
+ "--config {payload} -p mypar:3",
+ "when --config is set, the following parameters should not be provided",
+ ],
+ [
+ "-p mypar:3",
+ "when --config is not set, the following parameters are required",
+ ],
+ ["-s test_system", "when --config is not set, --name is required"],
+ ["-n test_application", "when --config is not set, --system is required"],
+ ],
+)
+def test_application_run_invalid_param_combinations(
+ cmdline: str,
+ error_pattern: str,
+ cli_runner: CliRunner,
+ monkeypatch: Any,
+ tmp_path: Any,
+ tmpdir: Any,
+) -> None:
+ """Test that invalid combinations arguments result in error as expected."""
+ application_config = ApplicationConfig(
+ name="test_application",
+ description="Test application",
+ supported_systems=["test_system"],
+ build_dir="build",
+ commands={"build": ["echo build {application.name} with {user_params:0}"]},
+ user_params={
+ "build": [
+ UserParamConfig(
+ name="param",
+ description="sample parameter",
+ default_value="default",
+ values=["val1", "val2", "val3"],
+ )
+ ]
+ },
+ )
+ system_config = SystemConfig(
+ name="test_system",
+ description="Test system",
+ data_transfer=LocalProtocolConfig(protocol="local"),
+ commands={"run": ["echo run {application.name} on {system.name}"]},
+ )
+
+ init_execution_test(monkeypatch, tmpdir, application_config, system_config)
+
+ payload_file = tmp_path / "payload.json"
+ payload_file.write_text("dummy")
+ result = cli_runner.invoke(
+ run_cmd,
+ args=cmdline.format(payload=payload_file).split(),
+ )
+ found = re.search(error_pattern, result.stdout)
+ assert found, f"Cannot find pattern: [{error_pattern}] in \n[\n{result.stdout}\n]"
+
+
+@pytest.mark.parametrize(
+ "payload,expected",
+ [
+ pytest.param(
+ {"arguments": {}},
+ None,
+ marks=pytest.mark.xfail(reason="no system 'id''", strict=True),
+ ),
+ pytest.param(
+ {"id": "testsystem"},
+ None,
+ marks=pytest.mark.xfail(reason="no arguments object", strict=True),
+ ),
+ (
+ {"id": "testsystem", "arguments": {"application": "testapp"}},
+ ("testsystem", "testapp", [], [], [], None),
+ ),
+ (
+ {
+ "id": "testsystem",
+ "arguments": {"application": "testapp", "par1": "val1"},
+ },
+ ("testsystem", "testapp", ["par1=val1"], [], [], None),
+ ),
+ (
+ {
+ "id": "testsystem",
+ "arguments": {"application": "testapp", "application/par1": "val1"},
+ },
+ ("testsystem", "testapp", ["par1=val1"], [], [], None),
+ ),
+ (
+ {
+ "id": "testsystem",
+ "arguments": {"application": "testapp", "system/par1": "val1"},
+ },
+ ("testsystem", "testapp", [], ["par1=val1"], [], None),
+ ),
+ (
+ {
+ "id": "testsystem",
+ "arguments": {"application": "testapp", "deploy/par1": "val1"},
+ },
+ ("testsystem", "testapp", [], [], ["par1"], None),
+ ),
+ (
+ {
+ "id": "testsystem",
+ "arguments": {
+ "application": "testapp",
+ "appar1": "val1",
+ "application/appar2": "val2",
+ "system/syspar1": "val3",
+ "deploy/depploypar1": "val4",
+ "application/appar3": "val5",
+ "system/syspar2": "val6",
+ "deploy/depploypar2": "val7",
+ },
+ },
+ (
+ "testsystem",
+ "testapp",
+ ["appar1=val1", "appar2=val2", "appar3=val5"],
+ ["syspar1=val3", "syspar2=val6"],
+ ["depploypar1", "depploypar2"],
+ None,
+ ),
+ ),
+ ],
+)
+def test_parse_payload_run_config(payload: dict, expected: tuple) -> None:
+ """Test parsing of the JSON payload for the run_config command."""
+ assert parse_payload_run_config(payload) == expected
+
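+# A minimal sketch of the routing the cases above imply (illustrative, not
+# necessarily the real implementation): argument keys are dispatched on an
+# optional "application/", "system/" or "deploy/" prefix; bare keys are treated
+# as application parameters, and deploy entries keep the key only.
+#
+#     for key, value in payload["arguments"].items():
+#         if key == "application":
+#             app_name = value
+#         elif key.startswith("system/"):
+#             system_params.append(f"{key.split('/', 1)[1]}={value}")
+#         elif key.startswith("deploy/"):
+#             deploy_params.append(key.split("/", 1)[1])
+#         else:
+#             app_params.append(f"{key.split('/', 1)[-1]}={value}")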
+
+def test_application_run_report(
+ tmpdir: Any,
+ cli_runner: CliRunner,
+ monkeypatch: Any,
+) -> None:
+ """Test flag '--report' of command 'application run'."""
+ app_metrics = {"app_metric": 3.14}
+ app_metrics_b64 = base64.b64encode(json.dumps(app_metrics).encode("utf-8"))
+ application_config = ApplicationConfig(
+ name="test_application",
+ description="Test application",
+ supported_systems=["test_system"],
+ build_dir="build",
+ commands={"build": ["echo build {application.name} with {user_params:0}"]},
+ user_params={
+ "build": [
+ UserParamConfig(
+ name="param",
+ description="sample parameter",
+ default_value="default",
+ values=["val1", "val2", "val3"],
+ ),
+ UserParamConfig(
+ name="p2",
+ description="another parameter, not overridden",
+ default_value="the-right-choice",
+ values=["the-right-choice", "the-bad-choice"],
+ ),
+ ]
+ },
+ )
+ system_config = SystemConfig(
+ name="test_system",
+ description="Test system",
+ data_transfer=LocalProtocolConfig(protocol="local"),
+ commands={
+ "run": [
+ "echo run {application.name} on {system.name}",
+ f"echo build <{Base64OutputParser.TAG_NAME}>{app_metrics_b64.decode('utf-8')}</{Base64OutputParser.TAG_NAME}>",
+ ]
+ },
+ reporting={
+ "regex": {
+ "app_name": {
+ "pattern": r"run (.\S*) ",
+ "type": "str",
+ },
+ "sys_name": {
+ "pattern": r"on (.\S*)",
+ "type": "str",
+ },
+ }
+ },
+ )
+ report_file = Path(tmpdir) / "test_report.json"
+ param_val = "param=val1"
+ exit_code = MiddlewareExitCode.SUCCESS
+
+ init_execution_test(monkeypatch, tmpdir, application_config, system_config)
+
+ result = cli_runner.invoke(
+ run_cmd,
+ args=[
+ "-n",
+ application_config["name"],
+ "-s",
+ system_config["name"],
+ "--report",
+ str(report_file),
+ "--param",
+ param_val,
+ ],
+ )
+ assert result.exit_code == exit_code
+ assert report_file.is_file()
+ with open(report_file, "r", encoding="utf-8") as file:
+ report = json.load(file)
+
+ assert report == {
+ "application": {
+ "metrics": {"0": {"app_metric": 3.14}},
+ "name": "test_application",
+ "params": {"param": "val1", "p2": "the-right-choice"},
+ },
+ "system": {
+ "metrics": {"app_name": "test_application", "sys_name": "test_system"},
+ "name": "test_system",
+ "params": {},
+ },
+ }
+
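+# The second "run" command above emits base64-encoded JSON wrapped in
+# Base64OutputParser.TAG_NAME tags. The application metrics entry in the
+# report is presumably recovered with the inverse operation, e.g.:
+#
+#     json.loads(base64.b64decode(app_metrics_b64)) == {"app_metric": 3.14}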
+
+def init_execution_test(
+ monkeypatch: Any,
+ tmpdir: Any,
+ application_config: ApplicationConfig,
+ system_config: SystemConfig,
+ can_establish_connection: bool = True,
+    establish_connection_delay: float = 0,
+ remote_app_exit_code: int = 0,
+) -> None:
+ """Init execution test."""
+ application_name = application_config["name"]
+ system_name = system_config["name"]
+
+ execute_cmd.params[0].type = click.Choice([application_name])
+ execute_cmd.params[1].type = click.Choice([system_name])
+ execute_cmd.params[2].type = click.Choice(["build", "run", "some_command"])
+
+ run_cmd.params[0].type = click.Choice([application_name])
+ run_cmd.params[1].type = click.Choice([system_name])
+
+ if "config_location" not in application_config:
+ application_path = Path(tmpdir) / "application"
+ application_path.mkdir()
+ application_config["config_location"] = application_path
+
+        # This file can be used as a deploy parameter value or as a deploy
+        # parameter in the application configuration.
+ sample_file = application_path / "sample_file"
+ sample_file.touch()
+ monkeypatch.setattr(
+ "aiet.backend.application.get_available_applications",
+ MagicMock(return_value=[Application(application_config)]),
+ )
+
+ ssh_protocol_mock = MagicMock(spec=SSHProtocol)
+
+ def mock_establish_connection() -> bool:
+ """Mock establish connection function."""
+ # give some time for the system to start
+        time.sleep(establish_connection_delay)
+ return can_establish_connection
+
+ ssh_protocol_mock.establish_connection.side_effect = mock_establish_connection
+ ssh_protocol_mock.connection_details.return_value = ("localhost", 8022)
+ ssh_protocol_mock.run.return_value = (
+ remote_app_exit_code,
+ bytearray(),
+ bytearray(),
+ )
+ monkeypatch.setattr(
+ "aiet.backend.protocol.SSHProtocol", MagicMock(return_value=ssh_protocol_mock)
+ )
+
+ if "config_location" not in system_config:
+ system_path = Path(tmpdir) / "system"
+ system_path.mkdir()
+ system_config["config_location"] = system_path
+ monkeypatch.setattr(
+ "aiet.backend.system.get_available_systems",
+ MagicMock(return_value=[load_system(system_config)]),
+ )
+
+ monkeypatch.setattr("aiet.backend.execution.wait", MagicMock())
diff --git a/tests/aiet/test_cli_common.py b/tests/aiet/test_cli_common.py
new file mode 100644
index 0000000..d018e44
--- /dev/null
+++ b/tests/aiet/test_cli_common.py
@@ -0,0 +1,37 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Test for cli common module."""
+from typing import Any
+
+import pytest
+
+from aiet.cli.common import print_command_details
+from aiet.cli.common import raise_exception_at_signal
+
+
+def test_print_command_details(capsys: Any) -> None:
+ """Test print_command_details function."""
+ command = {
+ "command_strings": ["echo test"],
+ "user_params": [
+ {"name": "param_name", "description": "param_description"},
+ {
+ "name": "param_name2",
+ "description": "param_description2",
+ "alias": "alias2",
+ },
+ ],
+ }
+ print_command_details(command)
+ captured = capsys.readouterr()
+ assert "echo test" in captured.out
+ assert "param_name" in captured.out
+ assert "alias2" in captured.out
+
+
+def test_raise_exception_at_signal() -> None:
+ """Test raise_exception_at_signal graceful shutdown."""
+ with pytest.raises(Exception) as err:
+ raise_exception_at_signal(1, "")
+
+ assert str(err.value) == "Middleware shutdown requested"
diff --git a/tests/aiet/test_cli_system.py b/tests/aiet/test_cli_system.py
new file mode 100644
index 0000000..fd39f31
--- /dev/null
+++ b/tests/aiet/test_cli_system.py
@@ -0,0 +1,240 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Module for testing CLI system subcommand."""
+import json
+from pathlib import Path
+from typing import Any
+from typing import Dict
+from typing import List
+from typing import Optional
+from typing import Union
+from unittest.mock import MagicMock
+
+import click
+import pytest
+from click.testing import CliRunner
+
+from aiet.backend.config import SystemConfig
+from aiet.backend.system import load_system
+from aiet.backend.system import System
+from aiet.cli.system import details_cmd
+from aiet.cli.system import install_cmd
+from aiet.cli.system import list_cmd
+from aiet.cli.system import remove_cmd
+from aiet.cli.system import system_cmd
+
+
+def test_system_cmd() -> None:
+ """Test system commands."""
+ commands = ["list", "details", "install", "remove"]
+ assert all(command in system_cmd.commands for command in commands)
+
+
+@pytest.mark.parametrize("format_", ["json", "cli"])
+def test_system_cmd_context(cli_runner: CliRunner, format_: str) -> None:
+ """Test setting command context parameters."""
+ result = cli_runner.invoke(system_cmd, ["--format", format_])
+ # command should fail if no subcommand provided
+ assert result.exit_code == 2
+
+ result = cli_runner.invoke(system_cmd, ["--format", format_, "list"])
+ assert result.exit_code == 0
+
+
+@pytest.mark.parametrize(
+ "format_,expected_output",
+ [
+ ("json", '{"type": "system", "available": ["system1", "system2"]}\n'),
+ ("cli", "Available systems:\n\nsystem1\nsystem2\n"),
+ ],
+)
+def test_list_cmd_with_format(
+ cli_runner: CliRunner, monkeypatch: Any, format_: str, expected_output: str
+) -> None:
+ """Test available systems command with different formats output."""
+ # Mock some systems
+ mock_system1 = MagicMock()
+ mock_system1.name = "system1"
+ mock_system2 = MagicMock()
+ mock_system2.name = "system2"
+
+ # Monkey patch the call get_available_systems
+ mock_available_systems = MagicMock()
+ mock_available_systems.return_value = [mock_system1, mock_system2]
+ monkeypatch.setattr("aiet.cli.system.get_available_systems", mock_available_systems)
+
+ obj = {"format": format_}
+ result = cli_runner.invoke(list_cmd, obj=obj)
+ assert result.output == expected_output
+
+
+def get_test_system(
+ annotations: Optional[Dict[str, Union[str, List[str]]]] = None
+) -> System:
+ """Return test system details."""
+ config = SystemConfig(
+ name="system",
+ description="test",
+ data_transfer={
+ "protocol": "ssh",
+ "username": "root",
+ "password": "root",
+ "hostname": "localhost",
+ "port": "8022",
+ },
+ commands={
+ "clean": ["clean"],
+ "build": ["build"],
+ "run": ["run"],
+ "post_run": ["post_run"],
+ },
+ annotations=annotations or {},
+ )
+
+ return load_system(config)
+
+
+def get_details_cmd_json_output(
+ annotations: Optional[Dict[str, Union[str, List[str]]]] = None
+) -> str:
+ """Test JSON output for details command."""
+ ann_str = ""
+ if annotations is not None:
+ ann_str = '"annotations":{},'.format(json.dumps(annotations))
+
+ json_output = (
+ """
+{
+ "type": "system",
+ "name": "system",
+ "description": "test",
+ "data_transfer_protocol": "ssh",
+ "commands": {
+ "clean":
+ {
+ "command_strings": ["clean"],
+ "user_params": []
+ },
+ "build":
+ {
+ "command_strings": ["build"],
+ "user_params": []
+ },
+ "run":
+ {
+ "command_strings": ["run"],
+ "user_params": []
+ },
+ "post_run":
+ {
+ "command_strings": ["post_run"],
+ "user_params": []
+ }
+ },
+"""
+ + ann_str
+ + """
+ "available_application" : []
+ }
+"""
+ )
+ return json.dumps(json.loads(json_output)) + "\n"
+
+
+def get_details_cmd_console_output(
+ annotations: Optional[Dict[str, Union[str, List[str]]]] = None
+) -> str:
+ """Test console output for details command."""
+ ann_str = ""
+ if annotations:
+ val_str = "".join(
+ "\n\t{}: {}".format(ann_name, ann_value)
+ for ann_name, ann_value in annotations.items()
+ )
+ ann_str = "\nAnnotations:{}".format(val_str)
+ return (
+ 'System "system" details'
+ + "\nDescription: test"
+ + "\nData Transfer Protocol: ssh"
+ + "\nAvailable Applications: "
+ + ann_str
+ + "\n\nclean commands:"
+ + "\nCommands: ['clean']"
+ + "\n\nbuild commands:"
+ + "\nCommands: ['build']"
+ + "\n\nrun commands:"
+ + "\nCommands: ['run']"
+ + "\n\npost_run commands:"
+ + "\nCommands: ['post_run']"
+ + "\n"
+ )
+
+
+@pytest.mark.parametrize(
+ "format_,system,expected_output",
+ [
+ (
+ "json",
+ get_test_system(annotations={"ann1": "annotation1", "ann2": ["a1", "a2"]}),
+ get_details_cmd_json_output(
+ annotations={"ann1": "annotation1", "ann2": ["a1", "a2"]}
+ ),
+ ),
+ (
+ "cli",
+ get_test_system(annotations={"ann1": "annotation1", "ann2": ["a1", "a2"]}),
+ get_details_cmd_console_output(
+ annotations={"ann1": "annotation1", "ann2": ["a1", "a2"]}
+ ),
+ ),
+ (
+ "json",
+ get_test_system(annotations={}),
+ get_details_cmd_json_output(annotations={}),
+ ),
+ (
+ "cli",
+ get_test_system(annotations={}),
+ get_details_cmd_console_output(annotations={}),
+ ),
+ ],
+)
+def test_details_cmd(
+ cli_runner: CliRunner,
+ monkeypatch: Any,
+ format_: str,
+ system: System,
+ expected_output: str,
+) -> None:
+ """Test details command with different formats output."""
+ mock_get_system = MagicMock()
+ mock_get_system.return_value = system
+ monkeypatch.setattr("aiet.cli.system.get_system", mock_get_system)
+
+ args = ["--name", "system"]
+ obj = {"format": format_}
+ details_cmd.params[0].type = click.Choice(["system"])
+
+ result = cli_runner.invoke(details_cmd, args=args, obj=obj)
+ assert result.output == expected_output
+
+
+def test_install_cmd(cli_runner: CliRunner, monkeypatch: Any) -> None:
+ """Test install system command."""
+ mock_install_system = MagicMock()
+ monkeypatch.setattr("aiet.cli.system.install_system", mock_install_system)
+
+ args = ["--source", "test"]
+ cli_runner.invoke(install_cmd, args=args)
+ mock_install_system.assert_called_once_with(Path("test"))
+
+
+def test_remove_cmd(cli_runner: CliRunner, monkeypatch: Any) -> None:
+ """Test remove system command."""
+ mock_remove_system = MagicMock()
+ monkeypatch.setattr("aiet.cli.system.remove_system", mock_remove_system)
+ remove_cmd.params[0].type = click.Choice(["test"])
+
+ args = ["--directory_name", "test"]
+ cli_runner.invoke(remove_cmd, args=args)
+ mock_remove_system.assert_called_once_with("test")
diff --git a/tests/aiet/test_cli_tool.py b/tests/aiet/test_cli_tool.py
new file mode 100644
index 0000000..45d45c8
--- /dev/null
+++ b/tests/aiet/test_cli_tool.py
@@ -0,0 +1,333 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+# pylint: disable=attribute-defined-outside-init,no-member,line-too-long,too-many-arguments,too-many-locals
+"""Module for testing CLI tool subcommand."""
+import json
+from pathlib import Path
+from typing import Any
+from typing import List
+from typing import Optional
+from typing import Sequence
+from unittest.mock import MagicMock
+
+import click
+import pytest
+from click.testing import CliRunner
+from click.testing import Result
+
+from aiet.backend.tool import get_unique_tool_names
+from aiet.backend.tool import Tool
+from aiet.cli.tool import details_cmd
+from aiet.cli.tool import execute_cmd
+from aiet.cli.tool import list_cmd
+from aiet.cli.tool import tool_cmd
+
+
+def test_tool_cmd() -> None:
+ """Test tool commands."""
+ commands = ["list", "details", "execute"]
+ assert all(command in tool_cmd.commands for command in commands)
+
+
+@pytest.mark.parametrize("format_", ["json", "cli"])
+def test_tool_cmd_context(cli_runner: CliRunner, format_: str) -> None:
+ """Test setting command context parameters."""
+ result = cli_runner.invoke(tool_cmd, ["--format", format_])
+ # command should fail if no subcommand provided
+ assert result.exit_code == 2
+
+ result = cli_runner.invoke(tool_cmd, ["--format", format_, "list"])
+ assert result.exit_code == 0
+
+
+@pytest.mark.parametrize(
+ "format_, expected_output",
+ [
+ (
+ "json",
+ '{"type": "tool", "available": ["tool_1", "tool_2"]}\n',
+ ),
+ ("cli", "Available tools:\n\ntool_1\ntool_2\n"),
+ ],
+)
+def test_list_cmd(
+ cli_runner: CliRunner,
+ monkeypatch: Any,
+ format_: str,
+ expected_output: str,
+) -> None:
+ """Test available tool commands."""
+ # Mock some tools
+ mock_tool_1 = MagicMock(spec=Tool)
+ mock_tool_1.name = "tool_1"
+ mock_tool_2 = MagicMock(spec=Tool)
+ mock_tool_2.name = "tool_2"
+
+ # Monkey patch the call get_available_tools
+ mock_available_tools = MagicMock()
+ mock_available_tools.return_value = [mock_tool_1, mock_tool_2]
+
+ monkeypatch.setattr("aiet.backend.tool.get_available_tools", mock_available_tools)
+
+ obj = {"format": format_}
+ args: Sequence[str] = []
+ result = cli_runner.invoke(list_cmd, obj=obj, args=args)
+ assert result.output == expected_output
+
+
+def get_details_cmd_json_output() -> List[dict]:
+ """Get JSON output for details command."""
+ json_output = [
+ {
+ "type": "tool",
+ "name": "tool_1",
+ "description": "This is tool 1",
+ "supported_systems": ["System 1"],
+ "commands": {
+ "clean": {"command_strings": ["echo 'clean'"], "user_params": []},
+ "build": {"command_strings": ["echo 'build'"], "user_params": []},
+ "run": {"command_strings": ["echo 'run'"], "user_params": []},
+ "post_run": {"command_strings": ["echo 'post_run'"], "user_params": []},
+ },
+ }
+ ]
+
+ return json_output
+
+
+def get_details_cmd_console_output() -> str:
+ """Get console output for details command."""
+ return (
+ 'Tool "tool_1" details'
+ "\nDescription: This is tool 1"
+ "\n\nSupported systems: System 1"
+ "\n\nclean commands:"
+ "\nCommands: [\"echo 'clean'\"]"
+ "\n\nbuild commands:"
+ "\nCommands: [\"echo 'build'\"]"
+ "\n\nrun commands:\nCommands: [\"echo 'run'\"]"
+ "\n\npost_run commands:"
+ "\nCommands: [\"echo 'post_run'\"]"
+ "\n"
+ )
+
+
+@pytest.mark.parametrize(
+ [
+ "tool_name",
+ "format_",
+ "expected_success",
+ "expected_output",
+ ],
+ [
+ ("tool_1", "json", True, get_details_cmd_json_output()),
+ ("tool_1", "cli", True, get_details_cmd_console_output()),
+ ("non-existent tool", "json", False, None),
+ ("non-existent tool", "cli", False, None),
+ ],
+)
+def test_details_cmd(
+ cli_runner: CliRunner,
+ tool_name: str,
+ format_: str,
+ expected_success: bool,
+ expected_output: str,
+) -> None:
+ """Test tool details command."""
+ details_cmd.params[0].type = click.Choice(["tool_1", "tool_2", "vela"])
+ result = cli_runner.invoke(
+ details_cmd, obj={"format": format_}, args=["--name", tool_name]
+ )
+ success = result.exit_code == 0
+ assert success == expected_success, result.output
+ if expected_success:
+ assert result.exception is None
+ output = json.loads(result.output) if format_ == "json" else result.output
+ assert output == expected_output
+
+
+@pytest.mark.parametrize(
+ "system_name",
+ [
+ "",
+ "Corstone-300: Cortex-M55+Ethos-U55",
+ "Corstone-300: Cortex-M55+Ethos-U65",
+ "Corstone-310: Cortex-M85+Ethos-U55",
+ ],
+)
+def test_details_cmd_vela(cli_runner: CliRunner, system_name: str) -> None:
+ """Test tool details command for Vela."""
+ details_cmd.params[0].type = click.Choice(get_unique_tool_names())
+ details_cmd.params[1].type = click.Choice([system_name])
+ args = ["--name", "vela"]
+ if system_name:
+ args += ["--system", system_name]
+ result = cli_runner.invoke(details_cmd, obj={"format": "json"}, args=args)
+ success = result.exit_code == 0
+ assert success, result.output
+ result_json = json.loads(result.output)
+ assert result_json
+ if system_name:
+ assert len(result_json) == 1
+ tool = result_json[0]
+ assert len(tool["supported_systems"]) == 1
+ assert system_name == tool["supported_systems"][0]
+ else: # no system specified => list details for all systems
+ assert len(result_json) == 3
+ assert all(len(tool["supported_systems"]) == 1 for tool in result_json)
+
+
+@pytest.fixture(scope="session")
+def input_model_file(non_optimised_input_model_file: Path) -> Path:
+ """Provide the path to a quantized dummy model file in the test_resources_path."""
+ return non_optimised_input_model_file
+
+
+def execute_vela(
+ cli_runner: CliRunner,
+ tool_name: str = "vela",
+ system_name: Optional[str] = None,
+ input_model: Optional[Path] = None,
+ output_model: Optional[Path] = None,
+ mac: Optional[int] = None,
+ format_: str = "cli",
+) -> Result:
+ """Run Vela with different parameters."""
+ execute_cmd.params[0].type = click.Choice(get_unique_tool_names())
+ execute_cmd.params[2].type = click.Choice([system_name or "dummy_system"])
+ args = ["--name", tool_name]
+ if system_name is not None:
+ args += ["--system", system_name]
+ if input_model is not None:
+ args += ["--param", "input={}".format(input_model)]
+ if output_model is not None:
+ args += ["--param", "output={}".format(output_model)]
+ if mac is not None:
+ args += ["--param", "mac={}".format(mac)]
+ result = cli_runner.invoke(
+ execute_cmd,
+ args=args,
+ obj={"format": format_},
+ )
+ return result
+
+
+@pytest.mark.parametrize("format_", ["cli", "json"])
+@pytest.mark.parametrize(
+ ["tool_name", "system_name", "mac", "expected_success", "expected_output"],
+ [
+ ("vela", "System 1", 32, False, None), # system not supported
+ ("vela", "NON-EXISTENT SYSTEM", 128, False, None), # system does not exist
+ ("vela", "Corstone-300: Cortex-M55+Ethos-U55", 32, True, None),
+ ("NON-EXISTENT TOOL", "Corstone-300: Cortex-M55+Ethos-U55", 32, False, None),
+ ("vela", "Corstone-300: Cortex-M55+Ethos-U55", 64, True, None),
+ ("vela", "Corstone-300: Cortex-M55+Ethos-U55", 128, True, None),
+ ("vela", "Corstone-300: Cortex-M55+Ethos-U55", 256, True, None),
+ (
+ "vela",
+ "Corstone-300: Cortex-M55+Ethos-U55",
+ 512,
+ False,
+ None,
+ ), # mac not supported
+ (
+ "vela",
+ "Corstone-300: Cortex-M55+Ethos-U65",
+ 32,
+ False,
+ None,
+ ), # mac not supported
+ ("vela", "Corstone-300: Cortex-M55+Ethos-U65", 256, True, None),
+ ("vela", "Corstone-300: Cortex-M55+Ethos-U65", 512, True, None),
+ (
+ "vela",
+ None,
+ 512,
+ False,
+ "Error: Please specify the system for tool vela.",
+ ), # no system specified
+ (
+ "NON-EXISTENT TOOL",
+ "Corstone-300: Cortex-M55+Ethos-U65",
+ 512,
+ False,
+ None,
+ ), # tool does not exist
+ ("vela", "Corstone-310: Cortex-M85+Ethos-U55", 128, True, None),
+ ],
+)
+def test_vela_run(
+ cli_runner: CliRunner,
+ format_: str,
+ input_model_file: Path, # pylint: disable=redefined-outer-name
+ tool_name: str,
+ system_name: Optional[str],
+ mac: int,
+ expected_success: bool,
+ expected_output: Optional[str],
+ monkeypatch: pytest.MonkeyPatch,
+ tmp_path: Path,
+) -> None:
+ """Test the execution of the Vela command."""
+ monkeypatch.chdir(tmp_path)
+
+ output_file = Path("vela_output.tflite")
+
+ result = execute_vela(
+ cli_runner,
+ tool_name=tool_name,
+ system_name=system_name,
+ input_model=input_model_file,
+ output_model=output_file,
+ mac=mac,
+ format_=format_,
+ )
+
+ success = result.exit_code == 0
+ assert success == expected_success
+ if success:
+ # Check output file
+ output_file = output_file.resolve()
+ assert output_file.is_file()
+ if expected_output:
+ assert result.output.strip() == expected_output
+
+
+@pytest.mark.parametrize("include_input_model", [True, False])
+@pytest.mark.parametrize("include_output_model", [True, False])
+@pytest.mark.parametrize("include_mac", [True, False])
+def test_vela_run_missing_params(
+ cli_runner: CliRunner,
+ input_model_file: Path, # pylint: disable=redefined-outer-name
+ include_input_model: bool,
+ include_output_model: bool,
+ include_mac: bool,
+ monkeypatch: pytest.MonkeyPatch,
+ tmp_path: Path,
+) -> None:
+ """Test the execution of the Vela command with missing user parameters."""
+ monkeypatch.chdir(tmp_path)
+
+ output_model_file = Path("output_model.tflite")
+ system_name = "Corstone-300: Cortex-M55+Ethos-U65"
+ mac = 256
+    # input_model is a required parameter; mac and output_model have defaults.
+ expected_success = include_input_model
+
+ result = execute_vela(
+ cli_runner,
+ tool_name="vela",
+ system_name=system_name,
+ input_model=input_model_file if include_input_model else None,
+ output_model=output_model_file if include_output_model else None,
+ mac=mac if include_mac else None,
+ )
+
+ success = result.exit_code == 0
+ assert success == expected_success, (
+ f"Success is {success}, but expected {expected_success}. "
+ f"Included params: ["
+ f"input_model={include_input_model}, "
+ f"output_model={include_output_model}, "
+ f"mac={include_mac}]"
+ )
diff --git a/tests/aiet/test_main.py b/tests/aiet/test_main.py
new file mode 100644
index 0000000..f2ebae2
--- /dev/null
+++ b/tests/aiet/test_main.py
@@ -0,0 +1,16 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Module for testing AIET main.py."""
+from typing import Any
+from unittest.mock import MagicMock
+
+from aiet import main
+
+
+def test_main(monkeypatch: Any) -> None:
+ """Test main entry point function."""
+ with monkeypatch.context() as mock_context:
+ mock = MagicMock()
+ mock_context.setattr(main, "cli", mock)
+ main.main()
+ mock.assert_called_once()
diff --git a/tests/aiet/test_resources/application_config.json b/tests/aiet/test_resources/application_config.json
new file mode 100644
index 0000000..2dfcfec
--- /dev/null
+++ b/tests/aiet/test_resources/application_config.json
@@ -0,0 +1,96 @@
+[
+ {
+ "name": "application_1",
+ "description": "application number one",
+ "supported_systems": [
+ "system_1",
+ "system_2"
+ ],
+ "build_dir": "build_dir_11",
+ "commands": {
+ "clean": [
+ "clean_cmd_11"
+ ],
+ "build": [
+ "build_cmd_11"
+ ],
+ "run": [
+ "run_cmd_11"
+ ],
+ "post_run": [
+ "post_run_cmd_11"
+ ]
+ },
+ "user_params": {
+ "run": [
+ {
+ "name": "run_param_11",
+ "values": [],
+ "description": "run param number one"
+ }
+ ],
+ "build": [
+ {
+ "name": "build_param_11",
+ "values": [],
+ "description": "build param number one"
+ },
+ {
+ "name": "build_param_12",
+ "values": [],
+ "description": "build param number two"
+ },
+ {
+ "name": "build_param_13",
+ "values": [
+ "value_1"
+ ],
+ "description": "build param number three with some value"
+ }
+ ]
+ }
+ },
+ {
+ "name": "application_2",
+ "description": "application number two",
+ "supported_systems": [
+ "system_2"
+ ],
+ "build_dir": "build_dir_21",
+ "commands": {
+ "clean": [
+ "clean_cmd_21"
+ ],
+ "build": [
+ "build_cmd_21",
+ "build_cmd_22"
+ ],
+ "run": [
+ "run_cmd_21"
+ ],
+ "post_run": [
+ "post_run_cmd_21"
+ ]
+ },
+ "user_params": {
+ "build": [
+ {
+ "name": "build_param_21",
+ "values": [],
+ "description": "build param number one"
+ },
+ {
+ "name": "build_param_22",
+ "values": [],
+ "description": "build param number two"
+ },
+ {
+ "name": "build_param_23",
+ "values": [],
+ "description": "build param number three"
+ }
+ ],
+ "run": []
+ }
+ }
+]
diff --git a/tests/aiet/test_resources/application_config.json.license b/tests/aiet/test_resources/application_config.json.license
new file mode 100644
index 0000000..9b83bfc
--- /dev/null
+++ b/tests/aiet/test_resources/application_config.json.license
@@ -0,0 +1,3 @@
+SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+
+SPDX-License-Identifier: Apache-2.0
diff --git a/tests/aiet/test_resources/applications/application1/aiet-config.json b/tests/aiet/test_resources/applications/application1/aiet-config.json
new file mode 100644
index 0000000..97f0401
--- /dev/null
+++ b/tests/aiet/test_resources/applications/application1/aiet-config.json
@@ -0,0 +1,30 @@
+[
+ {
+ "name": "application_1",
+ "description": "This is application 1",
+ "supported_systems": [
+ {
+ "name": "System 1"
+ }
+ ],
+ "build_dir": "build",
+ "commands": {
+ "clean": [
+ "echo 'clean'"
+ ],
+ "build": [
+ "echo 'build'"
+ ],
+ "run": [
+ "echo 'run'"
+ ],
+ "post_run": [
+ "echo 'post_run'"
+ ]
+ },
+ "user_params": {
+ "build": [],
+ "run": []
+ }
+ }
+]
diff --git a/tests/aiet/test_resources/applications/application1/aiet-config.json.license b/tests/aiet/test_resources/applications/application1/aiet-config.json.license
new file mode 100644
index 0000000..9b83bfc
--- /dev/null
+++ b/tests/aiet/test_resources/applications/application1/aiet-config.json.license
@@ -0,0 +1,3 @@
+SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+
+SPDX-License-Identifier: Apache-2.0
diff --git a/tests/aiet/test_resources/applications/application2/aiet-config.json b/tests/aiet/test_resources/applications/application2/aiet-config.json
new file mode 100644
index 0000000..e9122d3
--- /dev/null
+++ b/tests/aiet/test_resources/applications/application2/aiet-config.json
@@ -0,0 +1,30 @@
+[
+ {
+ "name": "application_2",
+ "description": "This is application 2",
+ "supported_systems": [
+ {
+ "name": "System 2"
+ }
+ ],
+ "build_dir": "build",
+ "commands": {
+ "clean": [
+ "echo 'clean'"
+ ],
+ "build": [
+ "echo 'build'"
+ ],
+ "run": [
+ "echo 'run'"
+ ],
+ "post_run": [
+ "echo 'post_run'"
+ ]
+ },
+ "user_params": {
+ "build": [],
+ "run": []
+ }
+ }
+]
diff --git a/tests/aiet/test_resources/applications/application2/aiet-config.json.license b/tests/aiet/test_resources/applications/application2/aiet-config.json.license
new file mode 100644
index 0000000..9b83bfc
--- /dev/null
+++ b/tests/aiet/test_resources/applications/application2/aiet-config.json.license
@@ -0,0 +1,3 @@
+SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+
+SPDX-License-Identifier: Apache-2.0
diff --git a/tests/aiet/test_resources/applications/application3/readme.txt b/tests/aiet/test_resources/applications/application3/readme.txt
new file mode 100644
index 0000000..8c72c05
--- /dev/null
+++ b/tests/aiet/test_resources/applications/application3/readme.txt
@@ -0,0 +1,4 @@
+SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+SPDX-License-Identifier: Apache-2.0
+
+This application does not have a JSON configuration file
diff --git a/tests/aiet/test_resources/applications/application4/aiet-config.json b/tests/aiet/test_resources/applications/application4/aiet-config.json
new file mode 100644
index 0000000..34dc780
--- /dev/null
+++ b/tests/aiet/test_resources/applications/application4/aiet-config.json
@@ -0,0 +1,35 @@
+[
+ {
+ "name": "application_4",
+ "description": "This is application 4",
+ "build_dir": "build",
+ "supported_systems": [
+ {
+ "name": "System 4"
+ }
+ ],
+ "commands": {
+ "build": [
+ "cp ../hello_app.txt . # {user_params:0}"
+ ],
+ "run": [
+ "{application.build_dir}/hello_app.txt"
+ ]
+ },
+ "user_params": {
+ "build": [
+ {
+ "name": "--app",
+ "description": "Sample command param",
+ "values": [
+ "application1",
+ "application2",
+ "application3"
+ ],
+ "default_value": "application1"
+ }
+ ],
+ "run": []
+ }
+ }
+]
diff --git a/tests/aiet/test_resources/applications/application4/aiet-config.json.license b/tests/aiet/test_resources/applications/application4/aiet-config.json.license
new file mode 100644
index 0000000..9b83bfc
--- /dev/null
+++ b/tests/aiet/test_resources/applications/application4/aiet-config.json.license
@@ -0,0 +1,3 @@
+SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+
+SPDX-License-Identifier: Apache-2.0
diff --git a/tests/aiet/test_resources/applications/application4/hello_app.txt b/tests/aiet/test_resources/applications/application4/hello_app.txt
new file mode 100644
index 0000000..2ec0d1d
--- /dev/null
+++ b/tests/aiet/test_resources/applications/application4/hello_app.txt
@@ -0,0 +1,4 @@
+SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+SPDX-License-Identifier: Apache-2.0
+
+Hello from APP!
diff --git a/tests/aiet/test_resources/applications/application5/aiet-config.json b/tests/aiet/test_resources/applications/application5/aiet-config.json
new file mode 100644
index 0000000..5269409
--- /dev/null
+++ b/tests/aiet/test_resources/applications/application5/aiet-config.json
@@ -0,0 +1,160 @@
+[
+ {
+ "name": "application_5",
+ "description": "This is application 5",
+ "build_dir": "default_build_dir",
+ "supported_systems": [
+ {
+ "name": "System 1",
+ "lock": false
+ },
+ {
+ "name": "System 2"
+ }
+ ],
+ "variables": {
+ "var1": "value1",
+ "var2": "value2"
+ },
+ "lock": true,
+ "commands": {
+ "build": [
+ "default build command"
+ ],
+ "run": [
+ "default run command"
+ ]
+ },
+ "user_params": {
+ "build": [],
+ "run": []
+ }
+ },
+ {
+ "name": "application_5A",
+ "description": "This is application 5A",
+ "supported_systems": [
+ {
+ "name": "System 1",
+ "build_dir": "build_5A",
+ "variables": {
+ "var1": "new value1"
+ }
+ },
+ {
+ "name": "System 2",
+ "variables": {
+ "var2": "new value2"
+ },
+ "lock": true,
+ "commands": {
+ "run": [
+ "run command on system 2"
+ ]
+ }
+ }
+ ],
+ "variables": {
+ "var1": "value1",
+ "var2": "value2"
+ },
+ "build_dir": "build",
+ "commands": {
+ "build": [
+ "default build command"
+ ],
+ "run": [
+ "default run command"
+ ]
+ },
+ "user_params": {
+ "build": [],
+ "run": []
+ }
+ },
+ {
+ "name": "application_5B",
+ "description": "This is application 5B",
+ "supported_systems": [
+ {
+ "name": "System 1",
+ "build_dir": "build_5B",
+ "variables": {
+ "var1": "value for var1 System1",
+ "var2": "value for var2 System1"
+ },
+ "user_params": {
+ "build": [
+ {
+ "name": "--param_5B",
+ "description": "Sample command param",
+ "values": [
+ "value1",
+ "value2",
+ "value3"
+ ],
+ "default_value": "value1",
+ "alias": "param1"
+ }
+ ]
+ }
+ },
+ {
+ "name": "System 2",
+ "variables": {
+ "var1": "value for var1 System2",
+ "var2": "value for var2 System2"
+ },
+ "commands": {
+ "build": [
+ "build command on system 2 with {variables:var1} {user_params:param1}"
+ ],
+ "run": [
+ "run command on system 2"
+ ]
+ },
+ "user_params": {
+ "run": []
+ }
+ }
+ ],
+ "build_dir": "build",
+ "commands": {
+ "build": [
+ "default build command with {variables:var1}"
+ ],
+ "run": [
+ "default run command with {variables:var2}"
+ ]
+ },
+ "user_params": {
+ "build": [
+ {
+ "name": "--param",
+ "description": "Sample command param",
+ "values": [
+ "value1",
+ "value2",
+ "value3"
+ ],
+ "default_value": "value1",
+ "alias": "param1"
+ }
+ ],
+ "run": [],
+ "non_used_command": [
+ {
+ "name": "--not-used",
+ "description": "Not used param anywhere",
+ "values": [
+ "value1",
+ "value2",
+ "value3"
+ ],
+ "default_value": "value1",
+ "alias": "param1"
+ }
+ ]
+ }
+ }
+]
diff --git a/tests/aiet/test_resources/applications/application5/aiet-config.json.license b/tests/aiet/test_resources/applications/application5/aiet-config.json.license
new file mode 100644
index 0000000..9b83bfc
--- /dev/null
+++ b/tests/aiet/test_resources/applications/application5/aiet-config.json.license
@@ -0,0 +1,3 @@
+SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+
+SPDX-License-Identifier: Apache-2.0
diff --git a/tests/aiet/test_resources/applications/readme.txt b/tests/aiet/test_resources/applications/readme.txt
new file mode 100644
index 0000000..a1f8209
--- /dev/null
+++ b/tests/aiet/test_resources/applications/readme.txt
@@ -0,0 +1,4 @@
+SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+SPDX-License-Identifier: Apache-2.0
+
+Dummy file for test purposes
diff --git a/tests/aiet/test_resources/hello_world.json b/tests/aiet/test_resources/hello_world.json
new file mode 100644
index 0000000..8a9a448
--- /dev/null
+++ b/tests/aiet/test_resources/hello_world.json
@@ -0,0 +1,54 @@
+[
+ {
+ "name": "Hello world",
+ "description": "Dummy application that displays 'Hello world!'",
+ "supported_systems": [
+ "Dummy System"
+ ],
+ "build_dir": "build",
+ "deploy_data": [
+ [
+ "src",
+ "/tmp/"
+ ],
+ [
+ "README",
+ "/tmp/README.md"
+ ]
+ ],
+ "commands": {
+ "clean": [],
+ "build": [],
+ "run": [
+ "echo 'Hello world!'",
+ "ls -l /tmp"
+ ],
+ "post_run": []
+ },
+ "user_params": {
+ "run": [
+ {
+ "name": "--choice-param",
+ "values": [
+ "dummy_value_1",
+ "dummy_value_2"
+ ],
+ "default_value": "dummy_value_1",
+ "description": "Choice param"
+ },
+ {
+ "name": "--open-param",
+ "values": [],
+ "default_value": "dummy_value_4",
+ "description": "Open param"
+ },
+ {
+ "name": "--enable-flag",
+ "default_value": "dummy_value_4",
+ "description": "Flag param"
+ }
+ ],
+ "build": []
+ }
+ }
+]
diff --git a/tests/aiet/test_resources/hello_world.json.license b/tests/aiet/test_resources/hello_world.json.license
new file mode 100644
index 0000000..9b83bfc
--- /dev/null
+++ b/tests/aiet/test_resources/hello_world.json.license
@@ -0,0 +1,3 @@
+SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+
+SPDX-License-Identifier: Apache-2.0
diff --git a/tests/aiet/test_resources/scripts/test_backend_run b/tests/aiet/test_resources/scripts/test_backend_run
new file mode 100755
index 0000000..548f577
--- /dev/null
+++ b/tests/aiet/test_resources/scripts/test_backend_run
@@ -0,0 +1,8 @@
+#!/bin/bash
+
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+
+echo "Hello from script"
+>&2 echo "Oops!"
+sleep 100
diff --git a/tests/aiet/test_resources/scripts/test_backend_run_script.sh b/tests/aiet/test_resources/scripts/test_backend_run_script.sh
new file mode 100644
index 0000000..548f577
--- /dev/null
+++ b/tests/aiet/test_resources/scripts/test_backend_run_script.sh
@@ -0,0 +1,8 @@
+#!/bin/bash
+
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+
+echo "Hello from script"
+>&2 echo "Oops!"
+sleep 100
diff --git a/tests/aiet/test_resources/systems/system1/aiet-config.json b/tests/aiet/test_resources/systems/system1/aiet-config.json
new file mode 100644
index 0000000..4b5dd19
--- /dev/null
+++ b/tests/aiet/test_resources/systems/system1/aiet-config.json
@@ -0,0 +1,35 @@
+[
+ {
+ "name": "System 1",
+ "description": "This is system 1",
+ "build_dir": "build",
+ "data_transfer": {
+ "protocol": "ssh",
+ "username": "root",
+ "password": "root",
+ "hostname": "localhost",
+ "port": "8021"
+ },
+ "commands": {
+ "clean": [
+ "echo 'clean'"
+ ],
+ "build": [
+ "echo 'build'"
+ ],
+ "run": [
+ "echo 'run'"
+ ],
+ "post_run": [
+ "echo 'post_run'"
+ ],
+ "deploy": [
+ "echo 'deploy'"
+ ]
+ },
+ "user_params": {
+ "build": [],
+ "run": []
+ }
+ }
+]
diff --git a/tests/aiet/test_resources/systems/system1/aiet-config.json.license b/tests/aiet/test_resources/systems/system1/aiet-config.json.license
new file mode 100644
index 0000000..9b83bfc
--- /dev/null
+++ b/tests/aiet/test_resources/systems/system1/aiet-config.json.license
@@ -0,0 +1,3 @@
+SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+
+SPDX-License-Identifier: Apache-2.0
diff --git a/tests/aiet/test_resources/systems/system1/system_artifact/dummy.txt b/tests/aiet/test_resources/systems/system1/system_artifact/dummy.txt
new file mode 100644
index 0000000..487e9d8
--- /dev/null
+++ b/tests/aiet/test_resources/systems/system1/system_artifact/dummy.txt
@@ -0,0 +1,2 @@
+SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+SPDX-License-Identifier: Apache-2.0
diff --git a/tests/aiet/test_resources/systems/system2/aiet-config.json b/tests/aiet/test_resources/systems/system2/aiet-config.json
new file mode 100644
index 0000000..a9e0eb3
--- /dev/null
+++ b/tests/aiet/test_resources/systems/system2/aiet-config.json
@@ -0,0 +1,32 @@
+[
+ {
+ "name": "System 2",
+ "description": "This is system 2",
+ "build_dir": "build",
+ "data_transfer": {
+ "protocol": "ssh",
+ "username": "root",
+ "password": "root",
+ "hostname": "localhost",
+ "port": "8021"
+ },
+ "commands": {
+ "clean": [
+ "echo 'clean'"
+ ],
+ "build": [
+ "echo 'build'"
+ ],
+ "run": [
+ "echo 'run'"
+ ],
+ "post_run": [
+ "echo 'post_run'"
+ ]
+ },
+ "user_params": {
+ "build": [],
+ "run": []
+ }
+ }
+]
diff --git a/tests/aiet/test_resources/systems/system2/aiet-config.json.license b/tests/aiet/test_resources/systems/system2/aiet-config.json.license
new file mode 100644
index 0000000..9b83bfc
--- /dev/null
+++ b/tests/aiet/test_resources/systems/system2/aiet-config.json.license
@@ -0,0 +1,3 @@
+SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+
+SPDX-License-Identifier: Apache-2.0
diff --git a/tests/aiet/test_resources/systems/system3/readme.txt b/tests/aiet/test_resources/systems/system3/readme.txt
new file mode 100644
index 0000000..aba5a9c
--- /dev/null
+++ b/tests/aiet/test_resources/systems/system3/readme.txt
@@ -0,0 +1,4 @@
+SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+SPDX-License-Identifier: Apache-2.0
+
+This system does not have a JSON configuration file
diff --git a/tests/aiet/test_resources/systems/system4/aiet-config.json b/tests/aiet/test_resources/systems/system4/aiet-config.json
new file mode 100644
index 0000000..295e00f
--- /dev/null
+++ b/tests/aiet/test_resources/systems/system4/aiet-config.json
@@ -0,0 +1,19 @@
+[
+ {
+ "name": "System 4",
+ "description": "This is system 4",
+ "build_dir": "build",
+ "data_transfer": {
+ "protocol": "local"
+ },
+ "commands": {
+ "run": [
+ "echo {application.name}",
+ "cat {application.commands.run:0}"
+ ]
+ },
+ "user_params": {
+ "run": []
+ }
+ }
+]
diff --git a/tests/aiet/test_resources/systems/system4/aiet-config.json.license b/tests/aiet/test_resources/systems/system4/aiet-config.json.license
new file mode 100644
index 0000000..9b83bfc
--- /dev/null
+++ b/tests/aiet/test_resources/systems/system4/aiet-config.json.license
@@ -0,0 +1,3 @@
+SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+
+SPDX-License-Identifier: Apache-2.0
diff --git a/tests/aiet/test_resources/tools/tool1/aiet-config.json b/tests/aiet/test_resources/tools/tool1/aiet-config.json
new file mode 100644
index 0000000..067ef7e
--- /dev/null
+++ b/tests/aiet/test_resources/tools/tool1/aiet-config.json
@@ -0,0 +1,30 @@
+[
+ {
+ "name": "tool_1",
+ "description": "This is tool 1",
+ "build_dir": "build",
+ "supported_systems": [
+ {
+ "name": "System 1"
+ }
+ ],
+ "commands": {
+ "clean": [
+ "echo 'clean'"
+ ],
+ "build": [
+ "echo 'build'"
+ ],
+ "run": [
+ "echo 'run'"
+ ],
+ "post_run": [
+ "echo 'post_run'"
+ ]
+ },
+ "user_params": {
+ "build": [],
+ "run": []
+ }
+ }
+]
diff --git a/tests/aiet/test_resources/tools/tool1/aiet-config.json.license b/tests/aiet/test_resources/tools/tool1/aiet-config.json.license
new file mode 100644
index 0000000..9b83bfc
--- /dev/null
+++ b/tests/aiet/test_resources/tools/tool1/aiet-config.json.license
@@ -0,0 +1,3 @@
+SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+
+SPDX-License-Identifier: Apache-2.0
diff --git a/tests/aiet/test_resources/tools/tool2/aiet-config.json b/tests/aiet/test_resources/tools/tool2/aiet-config.json
new file mode 100644
index 0000000..6eee9a6
--- /dev/null
+++ b/tests/aiet/test_resources/tools/tool2/aiet-config.json
@@ -0,0 +1,26 @@
+[
+ {
+ "name": "tool_2",
+ "description": "This is tool 2 with no supported systems",
+ "build_dir": "build",
+ "supported_systems": [],
+ "commands": {
+ "clean": [
+ "echo 'clean'"
+ ],
+ "build": [
+ "echo 'build'"
+ ],
+ "run": [
+ "echo 'run'"
+ ],
+ "post_run": [
+ "echo 'post_run'"
+ ]
+ },
+ "user_params": {
+ "build": [],
+ "run": []
+ }
+ }
+]
diff --git a/tests/aiet/test_resources/tools/tool2/aiet-config.json.license b/tests/aiet/test_resources/tools/tool2/aiet-config.json.license
new file mode 100644
index 0000000..9b83bfc
--- /dev/null
+++ b/tests/aiet/test_resources/tools/tool2/aiet-config.json.license
@@ -0,0 +1,3 @@
+SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+
+SPDX-License-Identifier: Apache-2.0
diff --git a/tests/aiet/test_resources/various/applications/application_with_empty_config/aiet-config.json b/tests/aiet/test_resources/various/applications/application_with_empty_config/aiet-config.json
new file mode 100644
index 0000000..fe51488
--- /dev/null
+++ b/tests/aiet/test_resources/various/applications/application_with_empty_config/aiet-config.json
@@ -0,0 +1 @@
+[]
diff --git a/tests/aiet/test_resources/various/applications/application_with_empty_config/aiet-config.json.license b/tests/aiet/test_resources/various/applications/application_with_empty_config/aiet-config.json.license
new file mode 100644
index 0000000..9b83bfc
--- /dev/null
+++ b/tests/aiet/test_resources/various/applications/application_with_empty_config/aiet-config.json.license
@@ -0,0 +1,3 @@
+SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+
+SPDX-License-Identifier: Apache-2.0
diff --git a/tests/aiet/test_resources/various/applications/application_with_valid_config/aiet-config.json b/tests/aiet/test_resources/various/applications/application_with_valid_config/aiet-config.json
new file mode 100644
index 0000000..ff1cf1a
--- /dev/null
+++ b/tests/aiet/test_resources/various/applications/application_with_valid_config/aiet-config.json
@@ -0,0 +1,35 @@
+[
+ {
+ "name": "test_application",
+ "description": "This is test_application",
+ "build_dir": "build",
+ "supported_systems": [
+ {
+ "name": "System 4"
+ }
+ ],
+ "commands": {
+ "build": [
+ "cp ../hello_app.txt ."
+ ],
+ "run": [
+ "{application.build_dir}/hello_app.txt"
+ ]
+ },
+ "user_params": {
+ "build": [
+ {
+ "name": "--app",
+ "description": "Sample command param",
+ "values": [
+ "application1",
+ "application2",
+ "application3"
+ ],
+ "default_value": "application1"
+ }
+ ],
+ "run": []
+ }
+ }
+]
diff --git a/tests/aiet/test_resources/various/applications/application_with_valid_config/aiet-config.json.license b/tests/aiet/test_resources/various/applications/application_with_valid_config/aiet-config.json.license
new file mode 100644
index 0000000..9b83bfc
--- /dev/null
+++ b/tests/aiet/test_resources/various/applications/application_with_valid_config/aiet-config.json.license
@@ -0,0 +1,3 @@
+SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+
+SPDX-License-Identifier: Apache-2.0
diff --git a/tests/aiet/test_resources/various/applications/application_with_wrong_config1/aiet-config.json b/tests/aiet/test_resources/various/applications/application_with_wrong_config1/aiet-config.json
new file mode 100644
index 0000000..724b31b
--- /dev/null
+++ b/tests/aiet/test_resources/various/applications/application_with_wrong_config1/aiet-config.json
@@ -0,0 +1,2 @@
+This is not valid json file
+{
diff --git a/tests/aiet/test_resources/various/applications/application_with_wrong_config1/aiet-config.json.license b/tests/aiet/test_resources/various/applications/application_with_wrong_config1/aiet-config.json.license
new file mode 100644
index 0000000..9b83bfc
--- /dev/null
+++ b/tests/aiet/test_resources/various/applications/application_with_wrong_config1/aiet-config.json.license
@@ -0,0 +1,3 @@
+SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+
+SPDX-License-Identifier: Apache-2.0
diff --git a/tests/aiet/test_resources/various/applications/application_with_wrong_config2/aiet-config.json b/tests/aiet/test_resources/various/applications/application_with_wrong_config2/aiet-config.json
new file mode 100644
index 0000000..1ebb29c
--- /dev/null
+++ b/tests/aiet/test_resources/various/applications/application_with_wrong_config2/aiet-config.json
@@ -0,0 +1,30 @@
+[
+ {
+ "name": "test_application",
+ "description": "This is test_application",
+ "build_dir": "build",
+ "commands": {
+ "build": [
+ "cp ../hello_app.txt ."
+ ],
+ "run": [
+ "{application.build_dir}/hello_app.txt"
+ ]
+ },
+ "user_params": {
+ "build": [
+ {
+ "name": "--app",
+ "description": "Sample command param",
+ "values": [
+ "application1",
+ "application2",
+ "application3"
+ ],
+ "default_value": "application1"
+ }
+ ],
+ "run": []
+ }
+ }
+]
diff --git a/tests/aiet/test_resources/various/applications/application_with_wrong_config2/aiet-config.json.license b/tests/aiet/test_resources/various/applications/application_with_wrong_config2/aiet-config.json.license
new file mode 100644
index 0000000..9b83bfc
--- /dev/null
+++ b/tests/aiet/test_resources/various/applications/application_with_wrong_config2/aiet-config.json.license
@@ -0,0 +1,3 @@
+SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+
+SPDX-License-Identifier: Apache-2.0
diff --git a/tests/aiet/test_resources/various/applications/application_with_wrong_config3/aiet-config.json b/tests/aiet/test_resources/various/applications/application_with_wrong_config3/aiet-config.json
new file mode 100644
index 0000000..410d12d
--- /dev/null
+++ b/tests/aiet/test_resources/various/applications/application_with_wrong_config3/aiet-config.json
@@ -0,0 +1,35 @@
+[
+ {
+ "name": "test_application",
+ "description": "This is test_application",
+ "build_dir": "build",
+ "supported_systems": [
+ {
+ "anme": "System 4"
+ }
+ ],
+ "commands": {
+ "build": [
+ "cp ../hello_app.txt ."
+ ],
+ "run": [
+ "{application.build_dir}/hello_app.txt"
+ ]
+ },
+ "user_params": {
+ "build": [
+ {
+ "name": "--app",
+ "description": "Sample command param",
+ "values": [
+ "application1",
+ "application2",
+ "application3"
+ ],
+ "default_value": "application1"
+ }
+ ],
+ "run": []
+ }
+ }
+]
diff --git a/tests/aiet/test_resources/various/applications/application_with_wrong_config3/aiet-config.json.license b/tests/aiet/test_resources/various/applications/application_with_wrong_config3/aiet-config.json.license
new file mode 100644
index 0000000..9b83bfc
--- /dev/null
+++ b/tests/aiet/test_resources/various/applications/application_with_wrong_config3/aiet-config.json.license
@@ -0,0 +1,3 @@
+SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+
+SPDX-License-Identifier: Apache-2.0
diff --git a/tests/aiet/test_resources/various/systems/system_with_empty_config/aiet-config.json b/tests/aiet/test_resources/various/systems/system_with_empty_config/aiet-config.json
new file mode 100644
index 0000000..fe51488
--- /dev/null
+++ b/tests/aiet/test_resources/various/systems/system_with_empty_config/aiet-config.json
@@ -0,0 +1 @@
+[]
diff --git a/tests/aiet/test_resources/various/systems/system_with_empty_config/aiet-config.json.license b/tests/aiet/test_resources/various/systems/system_with_empty_config/aiet-config.json.license
new file mode 100644
index 0000000..9b83bfc
--- /dev/null
+++ b/tests/aiet/test_resources/various/systems/system_with_empty_config/aiet-config.json.license
@@ -0,0 +1,3 @@
+SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+
+SPDX-License-Identifier: Apache-2.0
diff --git a/tests/aiet/test_resources/various/systems/system_with_valid_config/aiet-config.json b/tests/aiet/test_resources/various/systems/system_with_valid_config/aiet-config.json
new file mode 100644
index 0000000..20142e9
--- /dev/null
+++ b/tests/aiet/test_resources/various/systems/system_with_valid_config/aiet-config.json
@@ -0,0 +1,16 @@
+[
+ {
+ "name": "Test system",
+ "description": "This is a test system",
+ "build_dir": "build",
+ "data_transfer": {
+ "protocol": "local"
+ },
+ "commands": {
+ "run": []
+ },
+ "user_params": {
+ "run": []
+ }
+ }
+]
diff --git a/tests/aiet/test_resources/various/systems/system_with_valid_config/aiet-config.json.license b/tests/aiet/test_resources/various/systems/system_with_valid_config/aiet-config.json.license
new file mode 100644
index 0000000..9b83bfc
--- /dev/null
+++ b/tests/aiet/test_resources/various/systems/system_with_valid_config/aiet-config.json.license
@@ -0,0 +1,3 @@
+SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+
+SPDX-License-Identifier: Apache-2.0
diff --git a/tests/aiet/test_run_vela_script.py b/tests/aiet/test_run_vela_script.py
new file mode 100644
index 0000000..971856e
--- /dev/null
+++ b/tests/aiet/test_run_vela_script.py
@@ -0,0 +1,152 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+# pylint: disable=redefined-outer-name,no-self-use
+"""Module for testing run_vela.py script."""
+from pathlib import Path
+from typing import Any
+from typing import List
+
+import pytest
+from click.testing import CliRunner
+
+from aiet.cli.common import MiddlewareExitCode
+from aiet.resources.tools.vela.check_model import get_model_from_file
+from aiet.resources.tools.vela.check_model import is_vela_optimised
+from aiet.resources.tools.vela.run_vela import run_vela
+
+
+@pytest.fixture(scope="session")
+def vela_config_path(test_tools_path: Path) -> Path:
+ """Return test systems path in a pytest fixture."""
+ return test_tools_path / "vela" / "vela.ini"
+
+
+@pytest.fixture(
+ params=[
+ ["ethos-u65-256", "Ethos_U65_High_End", "U65_Shared_Sram"],
+ ["ethos-u55-32", "Ethos_U55_High_End_Embedded", "U55_Shared_Sram"],
+ ]
+)
+def ethos_config(request: Any) -> Any:
+ """Fixture to provide different configuration for Ethos-U optimization with Vela."""
+ return request.param
+
+
+# pylint: disable=too-many-arguments
+def generate_args(
+ input_: Path,
+ output: Path,
+ cfg: Path,
+ acc_config: str,
+ system_config: str,
+ memory_mode: str,
+) -> List[str]:
+ """Generate arguments that can be passed to script 'run_vela'."""
+ return [
+ "-i",
+ str(input_),
+ "-o",
+ str(output),
+ "--config",
+ str(cfg),
+ "--accelerator-config",
+ acc_config,
+ "--system-config",
+ system_config,
+ "--memory-mode",
+ memory_mode,
+ "--optimise",
+ "Performance",
+ ]
+
+
+def check_run_vela(
+ cli_runner: CliRunner, args: List, expected_success: bool, output_file: Path
+) -> None:
+ """Run Vela with the given arguments and check the result."""
+ result = cli_runner.invoke(run_vela, args)
+ success = result.exit_code == MiddlewareExitCode.SUCCESS
+ assert success == expected_success
+ if success:
+ model = get_model_from_file(output_file)
+ assert is_vela_optimised(model)
+
+
+def run_vela_script(
+ cli_runner: CliRunner,
+ input_model_file: Path,
+ output_model_file: Path,
+ vela_config: Path,
+ expected_success: bool,
+ acc_config: str,
+ system_config: str,
+ memory_mode: str,
+) -> None:
+ """Run the command 'run_vela' on the command line."""
+ args = generate_args(
+ input_model_file,
+ output_model_file,
+ vela_config,
+ acc_config,
+ system_config,
+ memory_mode,
+ )
+ check_run_vela(cli_runner, args, expected_success, output_model_file)
+
+
+class TestRunVelaCli:
+ """Test the command-line execution of the run_vela command."""
+
+ def test_non_optimised_model(
+ self,
+ cli_runner: CliRunner,
+ non_optimised_input_model_file: Path,
+ tmp_path: Path,
+ vela_config_path: Path,
+ ethos_config: List,
+ ) -> None:
+ """Verify Vela is run correctly on an unoptimised model."""
+ run_vela_script(
+ cli_runner,
+ non_optimised_input_model_file,
+ tmp_path / "test.tflite",
+ vela_config_path,
+ True,
+ *ethos_config,
+ )
+
+ def test_optimised_model(
+ self,
+ cli_runner: CliRunner,
+ optimised_input_model_file: Path,
+ tmp_path: Path,
+ vela_config_path: Path,
+ ethos_config: List,
+ ) -> None:
+ """Verify Vela is run correctly on an already optimised model."""
+ run_vela_script(
+ cli_runner,
+ optimised_input_model_file,
+ tmp_path / "test.tflite",
+ vela_config_path,
+ True,
+ *ethos_config,
+ )
+
+ def test_invalid_model(
+ self,
+ cli_runner: CliRunner,
+ invalid_input_model_file: Path,
+ tmp_path: Path,
+ vela_config_path: Path,
+ ethos_config: List,
+ ) -> None:
+ """Verify an error is raised when the input model is not valid."""
+ run_vela_script(
+ cli_runner,
+ invalid_input_model_file,
+ tmp_path / "test.tflite",
+ vela_config_path,
+ False,
+ *ethos_config,
+ )
diff --git a/tests/aiet/test_utils_fs.py b/tests/aiet/test_utils_fs.py
new file mode 100644
index 0000000..46d276e
--- /dev/null
+++ b/tests/aiet/test_utils_fs.py
@@ -0,0 +1,168 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+# pylint: disable=no-self-use
+"""Module for testing fs.py."""
+from contextlib import ExitStack as does_not_raise
+from pathlib import Path
+from typing import Any
+from typing import Union
+from unittest.mock import MagicMock
+
+import pytest
+
+from aiet.utils.fs import get_resources
+from aiet.utils.fs import read_file_as_bytearray
+from aiet.utils.fs import read_file_as_string
+from aiet.utils.fs import recreate_directory
+from aiet.utils.fs import remove_directory
+from aiet.utils.fs import remove_resource
+from aiet.utils.fs import ResourceType
+from aiet.utils.fs import valid_for_filename
+
+
+@pytest.mark.parametrize(
+ "resource_name,expected_path",
+ [
+ ("systems", does_not_raise()),
+ ("applications", does_not_raise()),
+ ("whaaat", pytest.raises(ResourceWarning)),
+ (None, pytest.raises(ResourceWarning)),
+ ],
+)
+def test_get_resources(resource_name: ResourceType, expectation: Any) -> None:
+ """Test get_resources() with multiple parameters."""
+    with expectation:
+ resource_path = get_resources(resource_name)
+ assert resource_path.exists()
+
+
+def test_remove_resource_wrong_directory(
+ monkeypatch: Any, test_applications_path: Path
+) -> None:
+ """Test removing resource with wrong directory."""
+ mock_get_resources = MagicMock(return_value=test_applications_path)
+ monkeypatch.setattr("aiet.utils.fs.get_resources", mock_get_resources)
+
+ mock_shutil_rmtree = MagicMock()
+ monkeypatch.setattr("aiet.utils.fs.shutil.rmtree", mock_shutil_rmtree)
+
+ with pytest.raises(Exception, match="Resource .* does not exist"):
+ remove_resource("unknown", "applications")
+ mock_shutil_rmtree.assert_not_called()
+
+ with pytest.raises(Exception, match="Wrong resource .*"):
+ remove_resource("readme.txt", "applications")
+ mock_shutil_rmtree.assert_not_called()
+
+
+def test_remove_resource(monkeypatch: Any, test_applications_path: Path) -> None:
+ """Test removing resource data."""
+ mock_get_resources = MagicMock(return_value=test_applications_path)
+ monkeypatch.setattr("aiet.utils.fs.get_resources", mock_get_resources)
+
+ mock_shutil_rmtree = MagicMock()
+ monkeypatch.setattr("aiet.utils.fs.shutil.rmtree", mock_shutil_rmtree)
+
+ remove_resource("application1", "applications")
+ mock_shutil_rmtree.assert_called_once()
+
+
+def test_remove_directory(tmpdir: Any) -> None:
+ """Test directory removal."""
+ tmpdir_path = Path(tmpdir)
+ tmpfile = tmpdir_path / "temp.txt"
+
+ for item in [None, tmpfile]:
+ with pytest.raises(Exception, match="No directory path provided"):
+ remove_directory(item)
+
+ newdir = tmpdir_path / "newdir"
+ newdir.mkdir()
+
+ assert newdir.is_dir()
+ remove_directory(newdir)
+ assert not newdir.exists()
+
+
+def test_recreate_directory(tmpdir: Any) -> None:
+ """Test directory recreation."""
+ with pytest.raises(Exception, match="No directory path provided"):
+ recreate_directory(None)
+
+ tmpdir_path = Path(tmpdir)
+ tmpfile = tmpdir_path / "temp.txt"
+ tmpfile.touch()
+ with pytest.raises(Exception, match="Path .* does exist and it is not a directory"):
+ recreate_directory(tmpfile)
+
+ newdir = tmpdir_path / "newdir"
+ newdir.mkdir()
+ newfile = newdir / "newfile"
+ newfile.touch()
+ assert list(newdir.iterdir()) == [newfile]
+ recreate_directory(newdir)
+ assert not list(newdir.iterdir())
+
+ newdir2 = tmpdir_path / "newdir2"
+ assert not newdir2.exists()
+ recreate_directory(newdir2)
+ assert newdir2.is_dir()
+
+
+def write_to_file(
+ write_directory: Any, write_mode: str, write_text: Union[str, bytes]
+) -> Path:
+ """Write some text to a temporary test file."""
+ tmpdir_path = Path(write_directory)
+ tmpfile = tmpdir_path / "file_name.txt"
+ with open(tmpfile, write_mode) as file: # pylint: disable=unspecified-encoding
+ file.write(write_text)
+ return tmpfile
+
+
+class TestReadFileAsString:
+ """Test read_file_as_string() function."""
+
+ def test_returns_text_from_valid_file(self, tmpdir: Any) -> None:
+ """Ensure the string written to a file read correctly."""
+ file_path = write_to_file(tmpdir, "w", "hello")
+ assert read_file_as_string(file_path) == "hello"
+
+ def test_output_is_empty_string_when_input_file_non_existent(
+ self, tmpdir: Any
+ ) -> None:
+ """Ensure empty string returned when reading from non-existent file."""
+ file_path = Path(tmpdir / "non-existent.txt")
+ assert read_file_as_string(file_path) == ""
+
+
+class TestReadFileAsByteArray:
+ """Test read_file_as_bytearray() function."""
+
+ def test_returns_bytes_from_valid_file(self, tmpdir: Any) -> None:
+ """Ensure the bytes written to a file read correctly."""
+ file_path = write_to_file(tmpdir, "wb", b"hello bytes")
+ assert read_file_as_bytearray(file_path) == b"hello bytes"
+
+ def test_output_is_empty_bytearray_when_input_file_non_existent(
+ self, tmpdir: Any
+ ) -> None:
+ """Ensure empty bytearray returned when reading from non-existent file."""
+ file_path = Path(tmpdir / "non-existent.txt")
+ assert read_file_as_bytearray(file_path) == bytearray()
+
+
+@pytest.mark.parametrize(
+ "value, replacement, expected_result",
+ [
+ ["", "", ""],
+ ["123", "", "123"],
+ ["123", "_", "123"],
+ ["/some_folder/some_script.sh", "", "some_foldersome_script.sh"],
+ ["/some_folder/some_script.sh", "_", "_some_folder_some_script.sh"],
+ ["!;'some_name$%^!", "_", "___some_name____"],
+ ],
+)
+def test_valid_for_filename(value: str, replacement: str, expected_result: str) -> None:
+ """Test function valid_for_filename."""
+ assert valid_for_filename(value, replacement) == expected_result
diff --git a/tests/aiet/test_utils_helpers.py b/tests/aiet/test_utils_helpers.py
new file mode 100644
index 0000000..bbe03fc
--- /dev/null
+++ b/tests/aiet/test_utils_helpers.py
@@ -0,0 +1,27 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Module for testing helpers.py."""
+import logging
+from typing import Any
+from typing import List
+from unittest.mock import call
+from unittest.mock import MagicMock
+
+import pytest
+
+from aiet.utils.helpers import set_verbosity
+
+
+@pytest.mark.parametrize(
+ "verbosity,expected_calls",
+ [(0, []), (1, [call(logging.INFO)]), (2, [call(logging.DEBUG)])],
+)
+def test_set_verbosity(
+ verbosity: int, expected_calls: List[Any], monkeypatch: Any
+) -> None:
+ """Test set_verbosity() with different verbsosity levels."""
+ with monkeypatch.context() as mock_context:
+ logging_mock = MagicMock()
+ mock_context.setattr(logging.getLogger(), "setLevel", logging_mock)
+ set_verbosity(None, None, verbosity)
+ logging_mock.assert_has_calls(expected_calls)
diff --git a/tests/aiet/test_utils_proc.py b/tests/aiet/test_utils_proc.py
new file mode 100644
index 0000000..9fb48dd
--- /dev/null
+++ b/tests/aiet/test_utils_proc.py
@@ -0,0 +1,272 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+# pylint: disable=attribute-defined-outside-init,no-self-use,not-callable
+"""Pytests for testing aiet/utils/proc.py."""
+from pathlib import Path
+from typing import Any
+from unittest import mock
+
+import psutil
+import pytest
+from sh import ErrorReturnCode
+
+from aiet.utils.proc import Command
+from aiet.utils.proc import CommandFailedException
+from aiet.utils.proc import CommandNotFound
+from aiet.utils.proc import parse_command
+from aiet.utils.proc import print_command_stdout
+from aiet.utils.proc import run_and_wait
+from aiet.utils.proc import save_process_info
+from aiet.utils.proc import ShellCommand
+from aiet.utils.proc import terminate_command
+from aiet.utils.proc import terminate_external_process
+
+
+class TestShellCommand:
+ """Sample class for collecting tests."""
+
+ def test_shellcommand_default_value(self) -> None:
+ """Test the instantiation of the class ShellCommand with no parameter."""
+ shell_command = ShellCommand()
+ assert shell_command.base_log_path == "/tmp"
+
+ @pytest.mark.parametrize(
+ "base_log_path,expected", [("/test", "/test"), ("/asd", "/asd")]
+ )
+ def test_shellcommand_with_param(self, base_log_path: str, expected: str) -> None:
+ """Test init ShellCommand with different parameters."""
+ shell_command = ShellCommand(base_log_path)
+ assert shell_command.base_log_path == expected
+
+ def test_run_ls(self, monkeypatch: Any) -> None:
+ """Test a simple ls command."""
+ mock_command = mock.MagicMock()
+ monkeypatch.setattr(Command, "bake", mock_command)
+
+ mock_get_stdout_stderr_paths = mock.MagicMock()
+ mock_get_stdout_stderr_paths.return_value = ("/tmp/std.out", "/tmp/std.err")
+ monkeypatch.setattr(
+ ShellCommand, "get_stdout_stderr_paths", mock_get_stdout_stderr_paths
+ )
+
+ shell_command = ShellCommand()
+ shell_command.run("ls", "-l")
+ assert mock_command.mock_calls[0] == mock.call(("-l",))
+ assert mock_command.mock_calls[1] == mock.call()(
+ _bg=True, _err="/tmp/std.err", _out="/tmp/std.out", _tee=True, _bg_exc=False
+ )
+
+ def test_run_command_not_found(self) -> None:
+ """Test whe the command doesn't exist."""
+ shell_command = ShellCommand()
+ with pytest.raises(CommandNotFound):
+ shell_command.run("lsl", "-l")
+
+ def test_get_stdout_stderr_paths_valid_path(self) -> None:
+ """Test the method to get files to store stdout and stderr."""
+ valid_path = "/tmp"
+ shell_command = ShellCommand(valid_path)
+ out, err = shell_command.get_stdout_stderr_paths(valid_path, "cmd")
+ assert out.exists() and out.is_file()
+ assert err.exists() and err.is_file()
+ assert "cmd" in out.name
+ assert "cmd" in err.name
+
+    def test_get_stdout_stderr_paths_invalid_path(self) -> None:
+ """Test the method to get output files with an invalid path."""
+ invalid_path = "/invalid/foo/bar"
+ shell_command = ShellCommand(invalid_path)
+ with pytest.raises(FileNotFoundError):
+ shell_command.get_stdout_stderr_paths(invalid_path, "cmd")
+
+
+@mock.patch("builtins.print")
+def test_print_command_stdout_alive(mock_print: Any) -> None:
+ """Test the print command stdout with an alive (running) process."""
+ mock_command = mock.MagicMock()
+ mock_command.is_alive.return_value = True
+ mock_command.next.side_effect = ["test1", "test2", StopIteration]
+
+ print_command_stdout(mock_command)
+
+ mock_command.assert_has_calls(
+ [mock.call.is_alive(), mock.call.next(), mock.call.next()]
+ )
+ mock_print.assert_has_calls(
+ [mock.call("test1", end=""), mock.call("test2", end="")]
+ )
+
+
+@mock.patch("builtins.print")
+def test_print_command_stdout_not_alive(mock_print: Any) -> None:
+ """Test the print command stdout with a not alive (exited) process."""
+ mock_command = mock.MagicMock()
+ mock_command.is_alive.return_value = False
+ mock_command.stdout = "test"
+
+ print_command_stdout(mock_command)
+ mock_command.assert_has_calls([mock.call.is_alive()])
+ mock_print.assert_called_once_with("test")
+
+
+def test_terminate_external_process_no_process(capsys: Any) -> None:
+ """Test that non existed process could be terminated."""
+ mock_command = mock.MagicMock()
+ mock_command.terminate.side_effect = psutil.Error("Error!")
+
+ terminate_external_process(mock_command)
+ captured = capsys.readouterr()
+ assert captured.out == "Unable to terminate process\n"
+
+
+def test_terminate_external_process_case1() -> None:
+ """Test when process terminated immediately."""
+ mock_command = mock.MagicMock()
+ mock_command.is_running.return_value = False
+
+ terminate_external_process(mock_command)
+ mock_command.terminate.assert_called_once()
+ mock_command.is_running.assert_called_once()
+
+
+def test_terminate_external_process_case2() -> None:
+ """Test when process termination takes time."""
+ mock_command = mock.MagicMock()
+ mock_command.is_running.side_effect = [True, True, False]
+
+ terminate_external_process(mock_command)
+ mock_command.terminate.assert_called_once()
+ assert mock_command.is_running.call_count == 3
+
+
+def test_terminate_external_process_case3() -> None:
+ """Test when process termination takes more time."""
+ mock_command = mock.MagicMock()
+ mock_command.is_running.side_effect = [True, True, True]
+
+ terminate_external_process(
+ mock_command, number_of_attempts=2, wait_period=0.1, wait_for_termination=0.1
+ )
+ assert mock_command.is_running.call_count == 3
+ assert mock_command.terminate.call_count == 2
+
+
+def test_terminate_external_process_case4() -> None:
+ """Test when process termination takes more time."""
+ mock_command = mock.MagicMock()
+ mock_command.is_running.side_effect = [True, True, False]
+
+ terminate_external_process(
+ mock_command, number_of_attempts=2, wait_period=0.1, wait_for_termination=0.1
+ )
+ mock_command.terminate.assert_called_once()
+ assert mock_command.is_running.call_count == 3
+ assert mock_command.terminate.call_count == 1
+
+
+def test_terminate_command_no_process() -> None:
+ """Test command termination when process does not exist."""
+ mock_command = mock.MagicMock()
+ mock_command.process.signal_group.side_effect = ProcessLookupError()
+
+ terminate_command(mock_command)
+ mock_command.process.signal_group.assert_called_once()
+ mock_command.is_alive.assert_not_called()
+
+
+def test_terminate_command() -> None:
+ """Test command termination."""
+ mock_command = mock.MagicMock()
+ mock_command.is_alive.return_value = False
+
+ terminate_command(mock_command)
+ mock_command.process.signal_group.assert_called_once()
+
+
+def test_terminate_command_case1() -> None:
+ """Test command termination when it takes time.."""
+ mock_command = mock.MagicMock()
+ mock_command.is_alive.side_effect = [True, True, False]
+
+ terminate_command(mock_command, wait_period=0.1)
+ mock_command.process.signal_group.assert_called_once()
+ assert mock_command.is_alive.call_count == 3
+
+
+def test_terminate_command_case2() -> None:
+ """Test command termination when it takes much time.."""
+ mock_command = mock.MagicMock()
+ mock_command.is_alive.side_effect = [True, True, True]
+
+ terminate_command(mock_command, number_of_attempts=3, wait_period=0.1)
+ assert mock_command.is_alive.call_count == 3
+ assert mock_command.process.signal_group.call_count == 2
+
+
+class TestRunAndWait:
+ """Test run_and_wait function."""
+
+ @pytest.fixture(autouse=True)
+ def setup_method(self, monkeypatch: Any) -> None:
+ """Init test method."""
+ self.execute_command_mock = mock.MagicMock()
+ monkeypatch.setattr(
+ "aiet.utils.proc.execute_command", self.execute_command_mock
+ )
+
+ self.terminate_command_mock = mock.MagicMock()
+ monkeypatch.setattr(
+ "aiet.utils.proc.terminate_command", self.terminate_command_mock
+ )
+
+ def test_if_execute_command_raises_exception(self) -> None:
+ """Test if execute_command fails."""
+ self.execute_command_mock.side_effect = Exception("Error!")
+ with pytest.raises(Exception, match="Error!"):
+ run_and_wait("command", Path.cwd())
+
+ def test_if_command_finishes_with_error(self) -> None:
+ """Test if command finishes with error."""
+ cmd_mock = mock.MagicMock()
+ self.execute_command_mock.return_value = cmd_mock
+ exit_code_mock = mock.PropertyMock(
+ side_effect=ErrorReturnCode("cmd", bytearray(), bytearray())
+ )
+ type(cmd_mock).exit_code = exit_code_mock
+
+ with pytest.raises(CommandFailedException):
+ run_and_wait("command", Path.cwd())
+
+ @pytest.mark.parametrize("terminate_on_error, call_count", ((False, 0), (True, 1)))
+ def test_if_command_finishes_with_exception(
+ self, terminate_on_error: bool, call_count: int
+ ) -> None:
+ """Test if command finishes with error."""
+ cmd_mock = mock.MagicMock()
+ self.execute_command_mock.return_value = cmd_mock
+ exit_code_mock = mock.PropertyMock(side_effect=Exception("Error!"))
+ type(cmd_mock).exit_code = exit_code_mock
+
+ with pytest.raises(Exception, match="Error!"):
+ run_and_wait("command", Path.cwd(), terminate_on_error=terminate_on_error)
+
+ assert self.terminate_command_mock.call_count == call_count
+
+
+def test_save_process_info_no_process(monkeypatch: Any, tmpdir: Any) -> None:
+ """Test save_process_info function."""
+ mock_process = mock.MagicMock()
+ monkeypatch.setattr("psutil.Process", mock.MagicMock(return_value=mock_process))
+ mock_process.children.side_effect = psutil.NoSuchProcess(555)
+
+ pid_file_path = Path(tmpdir) / "test.pid"
+ save_process_info(555, pid_file_path)
+ assert not pid_file_path.exists()
+
+
+def test_parse_command() -> None:
+ """Test parse_command function."""
+ assert parse_command("1.sh") == ["bash", "1.sh"]
+ assert parse_command("1.sh", shell="sh") == ["sh", "1.sh"]
+ assert parse_command("command") == ["command"]
+ assert parse_command("command 123 --param=1") == ["command", "123", "--param=1"]
diff --git a/tests/conftest.py b/tests/conftest.py
new file mode 100644
index 0000000..5c6156c
--- /dev/null
+++ b/tests/conftest.py
@@ -0,0 +1,95 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Pytest conf module."""
+import shutil
+from pathlib import Path
+from typing import Generator
+
+import pytest
+import tensorflow as tf
+
+from mlia.devices.ethosu.config import EthosUConfiguration
+from mlia.nn.tensorflow.utils import convert_to_tflite
+from mlia.nn.tensorflow.utils import save_keras_model
+from mlia.nn.tensorflow.utils import save_tflite_model
+from mlia.tools.vela_wrapper import optimize_model
+
+
+def get_test_keras_model() -> tf.keras.Model:
+ """Return test Keras model."""
+ model = tf.keras.Sequential(
+ [
+ tf.keras.Input(shape=(28, 28, 1), batch_size=1, name="input"),
+ tf.keras.layers.Reshape((28, 28, 1)),
+ tf.keras.layers.Conv2D(
+ filters=12, kernel_size=(3, 3), activation="relu", name="conv1"
+ ),
+ tf.keras.layers.Conv2D(
+ filters=12, kernel_size=(3, 3), activation="relu", name="conv2"
+ ),
+ tf.keras.layers.MaxPool2D(2, 2),
+ tf.keras.layers.Flatten(),
+ tf.keras.layers.Dense(10, name="output"),
+ ]
+ )
+
+ model.compile(optimizer="sgd", loss="mean_squared_error")
+ return model
+
+
+@pytest.fixture(scope="session", name="test_models_path")
+def fixture_test_models_path(
+ tmp_path_factory: pytest.TempPathFactory,
+) -> Generator[Path, None, None]:
+ """Provide path to the test models."""
+ tmp_path = tmp_path_factory.mktemp("models")
+
+ keras_model = get_test_keras_model()
+ save_keras_model(keras_model, tmp_path / "test_model.h5")
+
+ tflite_model = convert_to_tflite(keras_model, quantized=True)
+ tflite_model_path = tmp_path / "test_model.tflite"
+ save_tflite_model(tflite_model, tflite_model_path)
+
+ tflite_vela_model = tmp_path / "test_model_vela.tflite"
+ device = EthosUConfiguration("ethos-u55-256")
+ optimize_model(tflite_model_path, device.compiler_options, tflite_vela_model)
+
+ tf.saved_model.save(keras_model, str(tmp_path / "tf_model_test_model"))
+
+ invalid_tflite_model = tmp_path / "invalid.tflite"
+ invalid_tflite_model.touch()
+
+ yield tmp_path
+
+ shutil.rmtree(tmp_path)
+
+
+@pytest.fixture(scope="session", name="test_keras_model")
+def fixture_test_keras_model(test_models_path: Path) -> Path:
+ """Return test Keras model."""
+ return test_models_path / "test_model.h5"
+
+
+@pytest.fixture(scope="session", name="test_tflite_model")
+def fixture_test_tflite_model(test_models_path: Path) -> Path:
+ """Return test TFLite model."""
+ return test_models_path / "test_model.tflite"
+
+
+@pytest.fixture(scope="session", name="test_tflite_vela_model")
+def fixture_test_tflite_vela_model(test_models_path: Path) -> Path:
+ """Return test Vela-optimized TFLite model."""
+ return test_models_path / "test_model_vela.tflite"
+
+
+@pytest.fixture(scope="session", name="test_tf_model")
+def fixture_test_tf_model(test_models_path: Path) -> Path:
+ """Return test TFLite model."""
+ return test_models_path / "tf_model_test_model"
+
+
+@pytest.fixture(scope="session", name="test_tflite_invalid_model")
+def fixture_test_tflite_invalid_model(test_models_path: Path) -> Path:
+ """Return test invalid TFLite model."""
+ return test_models_path / "invalid.tflite"
diff --git a/tests/mlia/__init__.py b/tests/mlia/__init__.py
new file mode 100644
index 0000000..0687f14
--- /dev/null
+++ b/tests/mlia/__init__.py
@@ -0,0 +1,3 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""MLIA tests module."""
diff --git a/tests/mlia/conftest.py b/tests/mlia/conftest.py
new file mode 100644
index 0000000..f683fca
--- /dev/null
+++ b/tests/mlia/conftest.py
@@ -0,0 +1,20 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Pytest conf module."""
+from pathlib import Path
+
+import pytest
+
+from mlia.core.context import ExecutionContext
+
+
+@pytest.fixture(scope="session", name="test_resources_path")
+def fixture_test_resources_path() -> Path:
+ """Return test resources path."""
+ return Path(__file__).parent / "test_resources"
+
+
+@pytest.fixture(name="dummy_context")
+def fixture_dummy_context(tmpdir: str) -> ExecutionContext:
+ """Return dummy context fixture."""
+ return ExecutionContext(working_dir=tmpdir)
diff --git a/tests/mlia/test_api.py b/tests/mlia/test_api.py
new file mode 100644
index 0000000..54d4796
--- /dev/null
+++ b/tests/mlia/test_api.py
@@ -0,0 +1,96 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Tests for the API functions."""
+from pathlib import Path
+from unittest.mock import MagicMock
+
+import pytest
+
+from mlia.api import get_advice
+from mlia.core.common import AdviceCategory
+from mlia.core.context import Context
+from mlia.core.context import ExecutionContext
+
+
+def test_get_advice_no_target_provided(test_keras_model: Path) -> None:
+ """Test getting advice when no target provided."""
+ with pytest.raises(Exception, match="Target is not provided"):
+ get_advice(None, test_keras_model, "all") # type: ignore
+
+
+def test_get_advice_wrong_category(test_keras_model: Path) -> None:
+ """Test getting advice when wrong advice category provided."""
+ with pytest.raises(Exception, match="Invalid advice category unknown"):
+ get_advice("ethos-u55-256", test_keras_model, "unknown") # type: ignore
+
+
+@pytest.mark.parametrize(
+ "category, context, expected_category",
+ [
+ [
+ "all",
+ None,
+ AdviceCategory.ALL,
+ ],
+ [
+ "optimization",
+ None,
+ AdviceCategory.OPTIMIZATION,
+ ],
+ [
+ "operators",
+ None,
+ AdviceCategory.OPERATORS,
+ ],
+ [
+ "performance",
+ None,
+ AdviceCategory.PERFORMANCE,
+ ],
+ [
+ "all",
+ ExecutionContext(),
+ AdviceCategory.ALL,
+ ],
+ [
+ "all",
+ ExecutionContext(advice_category=AdviceCategory.PERFORMANCE),
+ AdviceCategory.PERFORMANCE,
+ ],
+ [
+ "all",
+ ExecutionContext(config_parameters={"param": "value"}),
+ AdviceCategory.ALL,
+ ],
+ [
+ "all",
+ ExecutionContext(event_handlers=[MagicMock()]),
+ AdviceCategory.ALL,
+ ],
+ ],
+)
+def test_get_advice(
+ monkeypatch: pytest.MonkeyPatch,
+ category: str,
+ context: ExecutionContext,
+ expected_category: AdviceCategory,
+ test_keras_model: Path,
+) -> None:
+ """Test getting advice with valid parameters."""
+ advisor_mock = MagicMock()
+ monkeypatch.setattr("mlia.api._get_advisor", MagicMock(return_value=advisor_mock))
+
+ get_advice(
+ "ethos-u55-256",
+ test_keras_model,
+ category, # type: ignore
+ context=context,
+ )
+
+ advisor_mock.run.assert_called_once()
+ context = advisor_mock.run.mock_calls[0].args[0]
+ assert isinstance(context, Context)
+ assert context.advice_category == expected_category
+
+ assert context.event_handlers is not None
+ assert context.config_parameters is not None
diff --git a/tests/mlia/test_cli_commands.py b/tests/mlia/test_cli_commands.py
new file mode 100644
index 0000000..bf17339
--- /dev/null
+++ b/tests/mlia/test_cli_commands.py
@@ -0,0 +1,204 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Tests for cli.commands module."""
+from pathlib import Path
+from typing import Any
+from typing import Optional
+from unittest.mock import call
+from unittest.mock import MagicMock
+
+import pytest
+
+from mlia.cli.commands import backend
+from mlia.cli.commands import operators
+from mlia.cli.commands import optimization
+from mlia.cli.commands import performance
+from mlia.core.context import ExecutionContext
+from mlia.devices.ethosu.config import EthosUConfiguration
+from mlia.devices.ethosu.performance import MemoryUsage
+from mlia.devices.ethosu.performance import NPUCycles
+from mlia.devices.ethosu.performance import PerformanceMetrics
+from mlia.tools.metadata.common import InstallationManager
+
+
+def test_operators_expected_parameters(dummy_context: ExecutionContext) -> None:
+ """Test operators command wrong parameters."""
+ with pytest.raises(Exception, match="Model is not provided"):
+ operators(dummy_context, "ethos-u55-256")
+
+
+def test_performance_unknown_target(
+ dummy_context: ExecutionContext, test_tflite_model: Path
+) -> None:
+ """Test that command should fail if unknown target passed."""
+ with pytest.raises(Exception, match="Unable to find target profile unknown"):
+ performance(
+ dummy_context, model=str(test_tflite_model), target_profile="unknown"
+ )
+
+
+@pytest.mark.parametrize(
+ "target_profile, optimization_type, optimization_target, expected_error",
+ [
+ [
+ "ethos-u55-256",
+ None,
+ "0.5",
+ pytest.raises(Exception, match="Optimization type is not provided"),
+ ],
+ [
+ "ethos-u65-512",
+ "unknown",
+ "16",
+ pytest.raises(Exception, match="Unsupported optimization type: unknown"),
+ ],
+ [
+ "ethos-u55-256",
+ "pruning",
+ None,
+ pytest.raises(Exception, match="Optimization target is not provided"),
+ ],
+ [
+ "ethos-u65-512",
+ "clustering",
+ None,
+ pytest.raises(Exception, match="Optimization target is not provided"),
+ ],
+ [
+ "unknown",
+ "clustering",
+ "16",
+ pytest.raises(Exception, match="Unable to find target profile unknown"),
+ ],
+ ],
+)
+def test_opt_expected_parameters(
+ dummy_context: ExecutionContext,
+ target_profile: str,
+ monkeypatch: pytest.MonkeyPatch,
+ optimization_type: str,
+ optimization_target: str,
+ expected_error: Any,
+ test_keras_model: Path,
+) -> None:
+ """Test that command should fail if no or unknown optimization type provided."""
+ mock_performance_estimation(monkeypatch)
+
+ with expected_error:
+ optimization(
+ ctx=dummy_context,
+ target_profile=target_profile,
+ model=str(test_keras_model),
+ optimization_type=optimization_type,
+ optimization_target=optimization_target,
+ )
+
+
+@pytest.mark.parametrize(
+ "target_profile, optimization_type, optimization_target",
+ [
+ ["ethos-u55-256", "pruning", "0.5"],
+ ["ethos-u65-512", "clustering", "32"],
+ ["ethos-u55-256", "pruning,clustering", "0.5,32"],
+ ],
+)
+def test_opt_valid_optimization_target(
+ target_profile: str,
+ dummy_context: ExecutionContext,
+ optimization_type: str,
+ optimization_target: str,
+ monkeypatch: pytest.MonkeyPatch,
+ test_keras_model: Path,
+) -> None:
+ """Test that command should not fail with valid optimization targets."""
+ mock_performance_estimation(monkeypatch)
+
+ optimization(
+ ctx=dummy_context,
+ target_profile=target_profile,
+ model=str(test_keras_model),
+ optimization_type=optimization_type,
+ optimization_target=optimization_target,
+ )
+
+
+def mock_performance_estimation(monkeypatch: pytest.MonkeyPatch) -> None:
+ """Mock performance estimation."""
+ metrics = PerformanceMetrics(
+ EthosUConfiguration("ethos-u55-256"),
+ NPUCycles(1, 2, 3, 4, 5, 6),
+ MemoryUsage(1, 2, 3, 4, 5),
+ )
+ monkeypatch.setattr(
+ "mlia.devices.ethosu.data_collection.EthosUPerformanceEstimator.estimate",
+ MagicMock(return_value=metrics),
+ )
+
+
+@pytest.fixture(name="installation_manager_mock")
+def fixture_mock_installation_manager(monkeypatch: pytest.MonkeyPatch) -> MagicMock:
+ """Mock installation manager."""
+ install_manager_mock = MagicMock(spec=InstallationManager)
+ monkeypatch.setattr(
+ "mlia.cli.commands.get_installation_manager",
+ MagicMock(return_value=install_manager_mock),
+ )
+ return install_manager_mock
+
+
+def test_backend_command_action_status(installation_manager_mock: MagicMock) -> None:
+ """Test backend command "status"."""
+ backend(backend_action="status")
+
+ installation_manager_mock.show_env_details.assert_called_once()
+
+
+@pytest.mark.parametrize(
+ "i_agree_to_the_contained_eula, backend_name, expected_calls",
+ [
+ [False, None, [call(None, True)]],
+ [True, None, [call(None, False)]],
+ [False, "backend_name", [call("backend_name", True)]],
+ [True, "backend_name", [call("backend_name", False)]],
+ ],
+)
+def test_backend_command_action_add_download(
+ installation_manager_mock: MagicMock,
+ i_agree_to_the_contained_eula: bool,
+ backend_name: Optional[str],
+ expected_calls: Any,
+) -> None:
+ """Test backend command "install" with download option."""
+ backend(
+ backend_action="install",
+ download=True,
+ name=backend_name,
+ i_agree_to_the_contained_eula=i_agree_to_the_contained_eula,
+ )
+
+ assert installation_manager_mock.download_and_install.mock_calls == expected_calls
+
+
+@pytest.mark.parametrize("backend_name", [None, "backend_name"])
+def test_backend_command_action_install_from_path(
+ installation_manager_mock: MagicMock,
+ tmp_path: Path,
+ backend_name: Optional[str],
+) -> None:
+ """Test backend command "install" with backend path."""
+ backend(backend_action="install", path=tmp_path, name=backend_name)
+
+ installation_manager_mock.install_from.assert_called_once_with(
+ tmp_path, backend_name
+ )
+
+
+def test_backend_command_action_install_only_one_action(
+ installation_manager_mock: MagicMock, # pylint: disable=unused-argument
+ tmp_path: Path,
+) -> None:
+ """Test that only one of action type allowed."""
+ with pytest.raises(
+ Exception,
+ match="Please select only one action: download or "
+ "provide path to the backend installation",
+ ):
+ backend(backend_action="install", download=True, path=tmp_path)
diff --git a/tests/mlia/test_cli_config.py b/tests/mlia/test_cli_config.py
new file mode 100644
index 0000000..6d19eec
--- /dev/null
+++ b/tests/mlia/test_cli_config.py
@@ -0,0 +1,49 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Tests for cli.config module."""
+from typing import List
+from unittest.mock import MagicMock
+
+import pytest
+
+from mlia.cli.config import get_default_backends
+from mlia.cli.config import is_corstone_backend
+
+
+@pytest.mark.parametrize(
+ "available_backends, expected_default_backends",
+ [
+ [["Vela"], ["Vela"]],
+ [["Corstone-300"], ["Corstone-300"]],
+ [["Corstone-310"], ["Corstone-310"]],
+ [["Corstone-300", "Corstone-310"], ["Corstone-310"]],
+ [["Vela", "Corstone-300", "Corstone-310"], ["Vela", "Corstone-310"]],
+ [
+ ["Vela", "Corstone-300", "Corstone-310", "New backend"],
+ ["Vela", "Corstone-310", "New backend"],
+ ],
+ [
+ ["Vela", "Corstone-300", "New backend"],
+ ["Vela", "Corstone-300", "New backend"],
+ ],
+ ],
+)
+def test_get_default_backends(
+ monkeypatch: pytest.MonkeyPatch,
+ available_backends: List[str],
+ expected_default_backends: List[str],
+) -> None:
+ """Test function get_default backends."""
+ monkeypatch.setattr(
+ "mlia.cli.config.get_available_backends",
+ MagicMock(return_value=available_backends),
+ )
+
+ assert get_default_backends() == expected_default_backends
+
+
+def test_is_corstone_backend() -> None:
+ """Test function is_corstone_backend."""
+ assert is_corstone_backend("Corstone-300") is True
+ assert is_corstone_backend("Corstone-310") is True
+ assert is_corstone_backend("New backend") is False
diff --git a/tests/mlia/test_cli_helpers.py b/tests/mlia/test_cli_helpers.py
new file mode 100644
index 0000000..2c52885
--- /dev/null
+++ b/tests/mlia/test_cli_helpers.py
@@ -0,0 +1,165 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Tests for the helper classes."""
+from typing import Any
+from typing import Dict
+from typing import List
+
+import pytest
+
+from mlia.cli.helpers import CLIActionResolver
+from mlia.nn.tensorflow.optimizations.select import OptimizationSettings
+
+
+class TestCliActionResolver:
+ """Test cli action resolver."""
+
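+ # CLIActionResolver turns CLI arguments and analysis parameters into
+ # follow-up "mlia ..." command suggestions.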
+ @staticmethod
+ @pytest.mark.parametrize(
+ "args, params, expected_result",
+ [
+ [
+ {},
+ {"opt_settings": "some_setting"},
+ [],
+ ],
+ [
+ {},
+ {},
+ [
+ "Note: you will need a Keras model for that.",
+ "For example: mlia optimization --optimization-type "
+ "pruning,clustering --optimization-target 0.5,32 "
+ "/path/to/keras_model",
+ "For more info: mlia optimization --help",
+ ],
+ ],
+ [
+ {"model": "model.h5"},
+ {},
+ [
+ "For example: mlia optimization --optimization-type "
+ "pruning,clustering --optimization-target 0.5,32 model.h5",
+ "For more info: mlia optimization --help",
+ ],
+ ],
+ [
+ {"model": "model.h5"},
+ {"opt_settings": [OptimizationSettings("pruning", 0.5, None)]},
+ [
+ "For more info: mlia optimization --help",
+ "Optimization command: "
+ "mlia optimization --optimization-type pruning "
+ "--optimization-target 0.5 model.h5",
+ ],
+ ],
+ [
+ {"model": "model.h5", "target_profile": "target_profile"},
+ {"opt_settings": [OptimizationSettings("pruning", 0.5, None)]},
+ [
+ "For more info: mlia optimization --help",
+ "Optimization command: "
+ "mlia optimization --optimization-type pruning "
+ "--optimization-target 0.5 "
+ "--target-profile target_profile model.h5",
+ ],
+ ],
+ ],
+ )
+ def test_apply_optimizations(
+ args: Dict[str, Any],
+ params: Dict[str, Any],
+ expected_result: List[str],
+ ) -> None:
+ """Test action resolving for applying optimizations."""
+ resolver = CLIActionResolver(args)
+ assert resolver.apply_optimizations(**params) == expected_result
+
+ @staticmethod
+ def test_supported_operators_info() -> None:
+ """Test supported operators info."""
+ resolver = CLIActionResolver({})
+ assert resolver.supported_operators_info() == [
+ "For guidance on supported operators, run: mlia operators "
+ "--supported-ops-report",
+ ]
+
+ @staticmethod
+ def test_operator_compatibility_details() -> None:
+ """Test operator compatibility details info."""
+ resolver = CLIActionResolver({})
+ assert resolver.operator_compatibility_details() == [
+ "For more details, run: mlia operators --help"
+ ]
+
+ @staticmethod
+ def test_optimization_details() -> None:
+ """Test optimization details info."""
+ resolver = CLIActionResolver({})
+ assert resolver.optimization_details() == [
+ "For more info, see: mlia optimization --help"
+ ]
+
+ @staticmethod
+ @pytest.mark.parametrize(
+ "args, expected_result",
+ [
+ [
+ {},
+ [],
+ ],
+ [
+ {"model": "model.tflite"},
+ [
+ "Check the estimated performance by running the "
+ "following command: ",
+ "mlia performance model.tflite",
+ ],
+ ],
+ [
+ {"model": "model.tflite", "target_profile": "target_profile"},
+ [
+ "Check the estimated performance by running the "
+ "following command: ",
+ "mlia performance --target-profile target_profile model.tflite",
+ ],
+ ],
+ ],
+ )
+ def test_check_performance(
+ args: Dict[str, Any], expected_result: List[str]
+ ) -> None:
+ """Test check performance info."""
+ resolver = CLIActionResolver(args)
+ assert resolver.check_performance() == expected_result
+
+ @staticmethod
+ @pytest.mark.parametrize(
+ "args, expected_result",
+ [
+ [
+ {},
+ [],
+ ],
+ [
+ {"model": "model.tflite"},
+ [
+ "Try running the following command to verify that:",
+ "mlia operators model.tflite",
+ ],
+ ],
+ [
+ {"model": "model.tflite", "target_profile": "target_profile"},
+ [
+ "Try running the following command to verify that:",
+ "mlia operators --target-profile target_profile model.tflite",
+ ],
+ ],
+ ],
+ )
+ def test_check_operator_compatibility(
+ args: Dict[str, Any], expected_result: List[str]
+ ) -> None:
+ """Test checking operator compatibility info."""
+ resolver = CLIActionResolver(args)
+ assert resolver.check_operator_compatibility() == expected_result
diff --git a/tests/mlia/test_cli_logging.py b/tests/mlia/test_cli_logging.py
new file mode 100644
index 0000000..7c5f299
--- /dev/null
+++ b/tests/mlia/test_cli_logging.py
@@ -0,0 +1,104 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Tests for the module cli.logging."""
+import logging
+from pathlib import Path
+from typing import Optional
+
+import pytest
+
+from mlia.cli.logging import setup_logging
+from tests.mlia.utils.logging import clear_loggers
+
+
+def teardown_function() -> None:
+ """Perform action after test completion.
+
+ This function is launched automatically by pytest after each test
+ in this module.
+ """
+ clear_loggers()
+
+
+@pytest.mark.parametrize(
+ "logs_dir, verbose, expected_output, expected_log_file_content",
+ [
+ (
+ None,
+ None,
+ "cli info\n",
+ None,
+ ),
+ (
+ None,
+ True,
+ """mlia.tools.aiet_wrapper - aiet debug
+cli info
+mlia.cli - cli debug
+""",
+ None,
+ ),
+ (
+ "logs",
+ True,
+ """mlia.tools.aiet_wrapper - aiet debug
+cli info
+mlia.cli - cli debug
+""",
+ """mlia.tools.aiet_wrapper - DEBUG - aiet debug
+mlia.cli - DEBUG - cli debug
+""",
+ ),
+ ],
+)
+def test_setup_logging(
+ tmp_path: Path,
+ capfd: pytest.CaptureFixture,
+ logs_dir: Optional[str],
+ verbose: Optional[bool],
+ expected_output: str,
+ expected_log_file_content: Optional[str],
+) -> None:
+ """Test function setup_logging."""
+ logs_dir_path = tmp_path / logs_dir if logs_dir else None
+
+ setup_logging(logs_dir_path, verbose)
+
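+ # Emit records at different levels from different loggers to exercise both
+ # the console handler and the optional file handler.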
+ aiet_logger = logging.getLogger("mlia.tools.aiet_wrapper")
+ aiet_logger.debug("aiet debug")
+
+ cli_logger = logging.getLogger("mlia.cli")
+ cli_logger.info("cli info")
+ cli_logger.debug("cli debug")
+
+ stdout, _ = capfd.readouterr()
+ assert stdout == expected_output
+
+ check_log_assertions(logs_dir_path, expected_log_file_content)
+
+
+def check_log_assertions(
+ logs_dir_path: Optional[Path], expected_log_file_content: Optional[str]
+) -> None:
+ """Test assertions for log file."""
+ if logs_dir_path is not None:
+ assert logs_dir_path.is_dir()
+
+ items = list(logs_dir_path.iterdir())
+ assert len(items) == 1
+
+ log_file_path = items[0]
+ assert log_file_path.is_file()
+
+ log_file_name = log_file_path.name
+ assert log_file_name == "mlia.log"
+
+ with open(log_file_path, encoding="utf-8") as log_file:
+ log_content = log_file.read()
+
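+ # Compare line by line using substring matching, since real log lines may
+ # include extra prefixes such as timestamps.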
+ expected_lines = expected_log_file_content.split("\n")
+ produced_lines = log_content.split("\n")
+
+ assert len(expected_lines) == len(produced_lines)
+ for expected, produced in zip(expected_lines, produced_lines):
+ assert expected in produced
diff --git a/tests/mlia/test_cli_main.py b/tests/mlia/test_cli_main.py
new file mode 100644
index 0000000..a0937d5
--- /dev/null
+++ b/tests/mlia/test_cli_main.py
@@ -0,0 +1,357 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Tests for main module."""
+import argparse
+from functools import wraps
+from pathlib import Path
+from typing import Any
+from typing import Callable
+from typing import List
+from unittest.mock import ANY
+from unittest.mock import call
+from unittest.mock import MagicMock
+
+import pytest
+
+import mlia
+from mlia.cli.main import CommandInfo
+from mlia.cli.main import main
+from mlia.core.context import ExecutionContext
+from tests.mlia.utils.logging import clear_loggers
+
+
+def teardown_function() -> None:
+ """Perform action after test completion.
+
+ This function is launched automatically by pytest after each test
+ in this module.
+ """
+ clear_loggers()
+
+
+def test_option_version(capfd: pytest.CaptureFixture) -> None:
+ """Test --version."""
+ with pytest.raises(SystemExit) as ex:
+ main(["--version"])
+
+ assert ex.type == SystemExit
+ assert ex.value.code == 0
+
+ stdout, stderr = capfd.readouterr()
+ assert len(stdout.splitlines()) == 1
+ assert stderr == ""
+
+
+@pytest.mark.parametrize(
+ "is_default, expected_command_help",
+ [(True, "Test command [default]"), (False, "Test command")],
+)
+def test_command_info(is_default: bool, expected_command_help: str) -> None:
+ """Test properties of CommandInfo object."""
+
+ def test_command() -> None:
+ """Test command."""
+
+ command_info = CommandInfo(test_command, ["test"], [], is_default)
+ assert command_info.command_name == "test_command"
+ assert command_info.command_name_and_aliases == ["test_command", "test"]
+ assert command_info.command_help == expected_command_help
+
+
+def test_default_command(monkeypatch: pytest.MonkeyPatch, tmp_path: Path) -> None:
+ """Test adding default command."""
+
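+ # Helper that manufactures fake CLI commands; __name__ is set explicitly so
+ # the command registration logic sees the intended command name.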
+ def mock_command(
+ func_mock: MagicMock, name: str, with_working_dir: bool
+ ) -> Callable[..., None]:
+ """Mock cli command."""
+
+ def sample_cmd_1(*args: Any, **kwargs: Any) -> None:
+ """Sample command."""
+ func_mock(*args, **kwargs)
+
+ def sample_cmd_2(ctx: ExecutionContext, **kwargs: Any) -> None:
+ """Another sample command."""
+ func_mock(ctx=ctx, **kwargs)
+
+ ret_func = sample_cmd_2 if with_working_dir else sample_cmd_1
+ ret_func.__name__ = name
+
+ return ret_func # type: ignore
+
+ default_command = MagicMock()
+ non_default_command = MagicMock()
+
+ def default_command_params(parser: argparse.ArgumentParser) -> None:
+ """Add parameters for default command."""
+ parser.add_argument("--sample")
+ parser.add_argument("--default_arg", default="123")
+
+ def non_default_command_params(parser: argparse.ArgumentParser) -> None:
+ """Add parameters for non default command."""
+ parser.add_argument("--param")
+
+ monkeypatch.setattr(
+ "mlia.cli.main.get_commands",
+ MagicMock(
+ return_value=[
+ CommandInfo(
+ func=mock_command(default_command, "default_command", True),
+ aliases=["command1"],
+ opt_groups=[default_command_params],
+ is_default=True,
+ ),
+ CommandInfo(
+ func=mock_command(
+ non_default_command, "non_default_command", False
+ ),
+ aliases=["command2"],
+ opt_groups=[non_default_command_params],
+ is_default=False,
+ ),
+ ]
+ ),
+ )
+
+ tmp_working_dir = str(tmp_path)
+ main(["--working-dir", tmp_working_dir, "--sample", "1"])
+ main(["command2", "--param", "test"])
+
+ default_command.assert_called_once_with(ctx=ANY, sample="1", default_arg="123")
+ non_default_command.assert_called_once_with(param="test")
+
+
+@pytest.mark.parametrize(
+ "params, expected_call",
+ [
+ [
+ ["operators", "sample_model.tflite"],
+ call(
+ ctx=ANY,
+ target_profile="ethos-u55-256",
+ model="sample_model.tflite",
+ output=None,
+ supported_ops_report=False,
+ ),
+ ],
+ [
+ ["ops", "sample_model.tflite"],
+ call(
+ ctx=ANY,
+ target_profile="ethos-u55-256",
+ model="sample_model.tflite",
+ output=None,
+ supported_ops_report=False,
+ ),
+ ],
+ [
+ ["operators", "sample_model.tflite", "--target-profile", "ethos-u55-128"],
+ call(
+ ctx=ANY,
+ target_profile="ethos-u55-128",
+ model="sample_model.tflite",
+ output=None,
+ supported_ops_report=False,
+ ),
+ ],
+ [
+ ["operators"],
+ call(
+ ctx=ANY,
+ target_profile="ethos-u55-256",
+ model=None,
+ output=None,
+ supported_ops_report=False,
+ ),
+ ],
+ [
+ ["operators", "--supported-ops-report"],
+ call(
+ ctx=ANY,
+ target_profile="ethos-u55-256",
+ model=None,
+ output=None,
+ supported_ops_report=True,
+ ),
+ ],
+ [
+ [
+ "all_tests",
+ "sample_model.h5",
+ "--optimization-type",
+ "pruning",
+ "--optimization-target",
+ "0.5",
+ ],
+ call(
+ ctx=ANY,
+ target_profile="ethos-u55-256",
+ model="sample_model.h5",
+ optimization_type="pruning",
+ optimization_target="0.5",
+ output=None,
+ evaluate_on=["Vela"],
+ ),
+ ],
+ [
+ ["sample_model.h5"],
+ call(
+ ctx=ANY,
+ target_profile="ethos-u55-256",
+ model="sample_model.h5",
+ optimization_type="pruning,clustering",
+ optimization_target="0.5,32",
+ output=None,
+ evaluate_on=["Vela"],
+ ),
+ ],
+ [
+ ["performance", "sample_model.h5", "--output", "result.json"],
+ call(
+ ctx=ANY,
+ target_profile="ethos-u55-256",
+ model="sample_model.h5",
+ output="result.json",
+ evaluate_on=["Vela"],
+ ),
+ ],
+ [
+ ["perf", "sample_model.h5", "--target-profile", "ethos-u55-128"],
+ call(
+ ctx=ANY,
+ target_profile="ethos-u55-128",
+ model="sample_model.h5",
+ output=None,
+ evaluate_on=["Vela"],
+ ),
+ ],
+ [
+ ["optimization", "sample_model.h5"],
+ call(
+ ctx=ANY,
+ target_profile="ethos-u55-256",
+ model="sample_model.h5",
+ optimization_type="pruning,clustering",
+ optimization_target="0.5,32",
+ output=None,
+ evaluate_on=["Vela"],
+ ),
+ ],
+ [
+ ["optimization", "sample_model.h5", "--evaluate-on", "some_backend"],
+ call(
+ ctx=ANY,
+ target_profile="ethos-u55-256",
+ model="sample_model.h5",
+ optimization_type="pruning,clustering",
+ optimization_target="0.5,32",
+ output=None,
+ evaluate_on=["some_backend"],
+ ),
+ ],
+ ],
+)
+def test_commands_execution(
+ monkeypatch: pytest.MonkeyPatch, params: List[str], expected_call: Any
+) -> None:
+ """Test calling commands from the main function."""
+ mock = MagicMock()
+
+ def wrap_mock_command(command: Callable) -> Callable:
+ """Wrap the command with the mock."""
+
+ @wraps(command)
+ def mock_command(*args: Any, **kwargs: Any) -> Any:
+ """Mock the command."""
+ mock(*args, **kwargs)
+
+ return mock_command
+
+ monkeypatch.setattr(
+ "mlia.cli.options.get_default_backends", MagicMock(return_value=["Vela"])
+ )
+
+ monkeypatch.setattr(
+ "mlia.cli.options.get_available_backends",
+ MagicMock(return_value=["Vela", "some_backend"]),
+ )
+
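+ # Wrap the real command implementations so the shared mock records the call
+ # arguments while @wraps preserves the original function metadata.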
+ for command in ["all_tests", "operators", "performance", "optimization"]:
+ monkeypatch.setattr(
+ f"mlia.cli.main.{command}",
+ wrap_mock_command(getattr(mlia.cli.main, command)),
+ )
+
+ main(params)
+
+ mock.assert_called_once_with(*expected_call.args, **expected_call.kwargs)
+
+
+@pytest.mark.parametrize(
+ "verbose, exc_mock, expected_output",
+ [
+ [
+ True,
+ MagicMock(side_effect=Exception("Error")),
+ [
+ "Execution finished with error: Error",
+ f"Please check the log files in the {Path.cwd()/'mlia_output/logs'} "
+ "for more details",
+ ],
+ ],
+ [
+ False,
+ MagicMock(side_effect=Exception("Error")),
+ [
+ "Execution finished with error: Error",
+ f"Please check the log files in the {Path.cwd()/'mlia_output/logs'} "
+ "for more details, or enable verbose mode",
+ ],
+ ],
+ [
+ False,
+ MagicMock(side_effect=KeyboardInterrupt()),
+ ["Execution has been interrupted"],
+ ],
+ ],
+)
+def test_verbose_output(
+ monkeypatch: pytest.MonkeyPatch,
+ capsys: pytest.CaptureFixture,
+ verbose: bool,
+ exc_mock: MagicMock,
+ expected_output: List[str],
+) -> None:
+ """Test flag --verbose."""
+
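+ # Register a single fake command that raises, then verify the top-level
+ # error reporting and exit code.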
+ def command_params(parser: argparse.ArgumentParser) -> None:
+ """Add parameters for non default command."""
+ parser.add_argument("--verbose", action="store_true")
+
+ def command() -> None:
+ """Run test command."""
+ exc_mock()
+
+ monkeypatch.setattr(
+ "mlia.cli.main.get_commands",
+ MagicMock(
+ return_value=[
+ CommandInfo(
+ func=command,
+ aliases=["command"],
+ opt_groups=[command_params],
+ ),
+ ]
+ ),
+ )
+
+ params = ["command"]
+ if verbose:
+ params.append("--verbose")
+
+ exit_code = main(params)
+ assert exit_code == 1
+
+ stdout, _ = capsys.readouterr()
+ for expected_message in expected_output:
+ assert expected_message in stdout
diff --git a/tests/mlia/test_cli_options.py b/tests/mlia/test_cli_options.py
new file mode 100644
index 0000000..a441e58
--- /dev/null
+++ b/tests/mlia/test_cli_options.py
@@ -0,0 +1,186 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Tests for module options."""
+import argparse
+from contextlib import ExitStack as does_not_raise
+from pathlib import Path
+from typing import Any
+from typing import Dict
+from typing import List
+from typing import Optional
+
+import pytest
+
+from mlia.cli.options import add_output_options
+from mlia.cli.options import get_target_profile_opts
+from mlia.cli.options import parse_optimization_parameters
+
+
+@pytest.mark.parametrize(
+ "optimization_type, optimization_target, expected_error, expected_result",
+ [
+ (
+ "pruning",
+ "0.5",
+ does_not_raise(),
+ [
+ dict(
+ optimization_type="pruning",
+ optimization_target=0.5,
+ layers_to_optimize=None,
+ )
+ ],
+ ),
+ (
+ "clustering",
+ "32",
+ does_not_raise(),
+ [
+ dict(
+ optimization_type="clustering",
+ optimization_target=32.0,
+ layers_to_optimize=None,
+ )
+ ],
+ ),
+ (
+ "pruning,clustering",
+ "0.5,32",
+ does_not_raise(),
+ [
+ dict(
+ optimization_type="pruning",
+ optimization_target=0.5,
+ layers_to_optimize=None,
+ ),
+ dict(
+ optimization_type="clustering",
+ optimization_target=32.0,
+ layers_to_optimize=None,
+ ),
+ ],
+ ),
+ (
+ "pruning, clustering",
+ "0.5, 32",
+ does_not_raise(),
+ [
+ dict(
+ optimization_type="pruning",
+ optimization_target=0.5,
+ layers_to_optimize=None,
+ ),
+ dict(
+ optimization_type="clustering",
+ optimization_target=32.0,
+ layers_to_optimize=None,
+ ),
+ ],
+ ),
+ (
+ "pruning,clustering",
+ "0.5",
+ pytest.raises(
+ Exception, match="Wrong number of optimization targets and types"
+ ),
+ None,
+ ),
+ (
+ "",
+ "0.5",
+ pytest.raises(Exception, match="Optimization type is not provided"),
+ None,
+ ),
+ (
+ "pruning,clustering",
+ "",
+ pytest.raises(Exception, match="Optimization target is not provided"),
+ None,
+ ),
+ (
+ "pruning,",
+ "0.5,abc",
+ pytest.raises(
+ Exception, match="Non numeric value for the optimization target"
+ ),
+ None,
+ ),
+ ],
+)
+def test_parse_optimization_parameters(
+ optimization_type: str,
+ optimization_target: str,
+ expected_error: Any,
+ expected_result: Any,
+) -> None:
+ """Test function parse_optimization_parameters."""
+ with expected_error:
+ result = parse_optimization_parameters(optimization_type, optimization_target)
+ assert result == expected_result
+
+
+@pytest.mark.parametrize(
+ "args, expected_opts",
+ [
+ [
+ {},
+ [],
+ ],
+ [
+ {"target_profile": "profile"},
+ ["--target-profile", "profile"],
+ ],
+ [
+ # for the default profile empty list should be returned
+ {"target": "ethos-u55-256"},
+ [],
+ ],
+ ],
+)
+def test_get_target_opts(args: Optional[Dict], expected_opts: List[str]) -> None:
+ """Test getting target options."""
+ assert get_target_profile_opts(args) == expected_opts
+
+
+@pytest.mark.parametrize(
+ "output_parameters, expected_path",
+ [
+ [["--output", "report.json"], "report.json"],
+ [["--output", "REPORT.JSON"], "REPORT.JSON"],
+ [["--output", "some_folder/report.json"], "some_folder/report.json"],
+ [["--output", "report.csv"], "report.csv"],
+ [["--output", "REPORT.CSV"], "REPORT.CSV"],
+ [["--output", "some_folder/report.csv"], "some_folder/report.csv"],
+ ],
+)
+def test_output_options(output_parameters: List[str], expected_path: str) -> None:
+ """Test output options resolving."""
+ parser = argparse.ArgumentParser()
+ add_output_options(parser)
+
+ args = parser.parse_args(output_parameters)
+ assert args.output == expected_path
+
+
+@pytest.mark.parametrize(
+ "output_filename",
+ [
+ "report.txt",
+ "report.TXT",
+ "report",
+ "report.pdf",
+ ],
+)
+def test_output_options_bad_parameters(
+ output_filename: str, capsys: pytest.CaptureFixture
+) -> None:
+ """Test that args parsing should fail if format is not supported."""
+ parser = argparse.ArgumentParser()
+ add_output_options(parser)
+
+ with pytest.raises(SystemExit):
+ parser.parse_args(["--output", output_filename])
+
+ err_output = capsys.readouterr().err
+ suffix = Path(output_filename).suffix[1:]
+ assert f"Unsupported format '{suffix}'" in err_output
diff --git a/tests/mlia/test_core_advice_generation.py b/tests/mlia/test_core_advice_generation.py
new file mode 100644
index 0000000..05db698
--- /dev/null
+++ b/tests/mlia/test_core_advice_generation.py
@@ -0,0 +1,71 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Tests for module advice_generation."""
+from typing import List
+
+import pytest
+
+from mlia.core.advice_generation import Advice
+from mlia.core.advice_generation import advice_category
+from mlia.core.advice_generation import FactBasedAdviceProducer
+from mlia.core.common import AdviceCategory
+from mlia.core.common import DataItem
+from mlia.core.context import Context
+
+
+def test_advice_generation() -> None:
+ """Test advice generation."""
+
+ class SampleProducer(FactBasedAdviceProducer):
+ """Sample producer."""
+
+ def produce_advice(self, data_item: DataItem) -> None:
+ """Process data."""
+ self.add_advice([f"Advice for {data_item}"])
+
+ producer = SampleProducer()
+ producer.produce_advice(123)
+ producer.produce_advice("hello")
+
+ advice = producer.get_advice()
+ assert advice == [Advice(["Advice for 123"]), Advice(["Advice for hello"])]
+
+
+@pytest.mark.parametrize(
+ "category, expected_advice",
+ [
+ [
+ AdviceCategory.OPERATORS,
+ [Advice(["Good advice!"])],
+ ],
+ [
+ AdviceCategory.PERFORMANCE,
+ [],
+ ],
+ ],
+)
+def test_advice_category_decorator(
+ category: AdviceCategory,
+ expected_advice: List[Advice],
+ dummy_context: Context,
+) -> None:
+ """Test for advice_category decorator."""
+
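+ # The advice_category decorator should make produce_advice a no-op unless
+ # the context's advice category matches the decorated category.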
+ class SampleAdviceProducer(FactBasedAdviceProducer):
+ """Sample advice producer."""
+
+ @advice_category(AdviceCategory.OPERATORS)
+ def produce_advice(self, data_item: DataItem) -> None:
+ """Produce the advice."""
+ self.add_advice(["Good advice!"])
+
+ producer = SampleAdviceProducer()
+ dummy_context.update(
+ advice_category=category, event_handlers=[], config_parameters={}
+ )
+ producer.set_context(dummy_context)
+
+ producer.produce_advice("some_data")
+ advice = producer.get_advice()
+
+ assert advice == expected_advice
diff --git a/tests/mlia/test_core_advisor.py b/tests/mlia/test_core_advisor.py
new file mode 100644
index 0000000..375ff62
--- /dev/null
+++ b/tests/mlia/test_core_advisor.py
@@ -0,0 +1,40 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Tests for module advisor."""
+from unittest.mock import MagicMock
+
+from mlia.core.advisor import InferenceAdvisor
+from mlia.core.context import Context
+from mlia.core.workflow import WorkflowExecutor
+
+
+def test_inference_advisor_run() -> None:
+ """Test running sample inference advisor."""
+ executor_mock = MagicMock(spec=WorkflowExecutor)
+ context_mock = MagicMock(spec=Context)
+
+ class SampleAdvisor(InferenceAdvisor):
+ """Sample inference advisor."""
+
+ @classmethod
+ def name(cls) -> str:
+ """Return name of the advisor."""
+ return "sample_advisor"
+
+ @classmethod
+ def description(cls) -> str:
+ """Return description of the advisor."""
+ return "Sample advisor"
+
+ @classmethod
+ def info(cls) -> None:
+ """Print advisor info."""
+
+ def configure(self, context: Context) -> WorkflowExecutor:
+ """Configure advisor."""
+ return executor_mock
+
+ advisor = SampleAdvisor()
+ advisor.run(context_mock)
+
+ executor_mock.run.assert_called_once()
diff --git a/tests/mlia/test_core_context.py b/tests/mlia/test_core_context.py
new file mode 100644
index 0000000..10015aa
--- /dev/null
+++ b/tests/mlia/test_core_context.py
@@ -0,0 +1,62 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Tests for the module context."""
+from pathlib import Path
+
+from mlia.core.common import AdviceCategory
+from mlia.core.context import ExecutionContext
+from mlia.core.events import DefaultEventPublisher
+
+
+def test_execution_context(tmpdir: str) -> None:
+ """Test execution context."""
+ publisher = DefaultEventPublisher()
+ category = AdviceCategory.OPERATORS
+
+ context = ExecutionContext(
+ advice_category=category,
+ config_parameters={"param": "value"},
+ working_dir=tmpdir,
+ event_handlers=[],
+ event_publisher=publisher,
+ verbose=True,
+ logs_dir="logs_directory",
+ models_dir="models_directory",
+ )
+
+ assert context.advice_category == category
+ assert context.config_parameters == {"param": "value"}
+ assert context.event_handlers == []
+ assert context.event_publisher == publisher
+ assert context.logs_path == Path(tmpdir) / "logs_directory"
+ expected_model_path = Path(tmpdir) / "models_directory/sample.model"
+ assert context.get_model_path("sample.model") == expected_model_path
+ assert context.verbose is True
+ assert str(context) == (
+ f"ExecutionContext: "
+ f"working_dir={tmpdir}, "
+ "advice_category=OPERATORS, "
+ "config_parameters={'param': 'value'}, "
+ "verbose=True"
+ )
+
+ context_with_default_params = ExecutionContext(working_dir=tmpdir)
+ assert context_with_default_params.advice_category is None
+ assert context_with_default_params.config_parameters is None
+ assert context_with_default_params.event_handlers is None
+ assert isinstance(
+ context_with_default_params.event_publisher, DefaultEventPublisher
+ )
+ assert context_with_default_params.logs_path == Path(tmpdir) / "logs"
+
+ default_model_path = context_with_default_params.get_model_path("sample.model")
+ expected_default_model_path = Path(tmpdir) / "models/sample.model"
+ assert default_model_path == expected_default_model_path
+
+ expected_str = (
+ f"ExecutionContext: working_dir={tmpdir}, "
+ "advice_category=<not set>, "
+ "config_parameters=None, "
+ "verbose=False"
+ )
+ assert str(context_with_default_params) == expected_str
diff --git a/tests/mlia/test_core_data_analysis.py b/tests/mlia/test_core_data_analysis.py
new file mode 100644
index 0000000..a782159
--- /dev/null
+++ b/tests/mlia/test_core_data_analysis.py
@@ -0,0 +1,31 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Tests for module data_analysis."""
+from dataclasses import dataclass
+
+from mlia.core.common import DataItem
+from mlia.core.data_analysis import Fact
+from mlia.core.data_analysis import FactExtractor
+
+
+def test_fact_extractor() -> None:
+ """Test fact extractor."""
+
+ @dataclass
+ class SampleFact(Fact):
+ """Sample fact."""
+
+ msg: str
+
+ class SampleExtractor(FactExtractor):
+ """Sample extractor."""
+
+ def analyze_data(self, data_item: DataItem) -> None:
+ self.add_fact(SampleFact(f"Fact for {data_item}"))
+
+ extractor = SampleExtractor()
+ extractor.analyze_data(42)
+ extractor.analyze_data("some data")
+
+ facts = extractor.get_analyzed_data()
+ assert facts == [SampleFact("Fact for 42"), SampleFact("Fact for some data")]
diff --git a/tests/mlia/test_core_events.py b/tests/mlia/test_core_events.py
new file mode 100644
index 0000000..faaab7c
--- /dev/null
+++ b/tests/mlia/test_core_events.py
@@ -0,0 +1,155 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Tests for the module events."""
+from dataclasses import dataclass
+from unittest.mock import call
+from unittest.mock import MagicMock
+
+import pytest
+
+from mlia.core.events import action
+from mlia.core.events import ActionFinishedEvent
+from mlia.core.events import ActionStartedEvent
+from mlia.core.events import DebugEventHandler
+from mlia.core.events import DefaultEventPublisher
+from mlia.core.events import Event
+from mlia.core.events import EventDispatcher
+from mlia.core.events import EventHandler
+from mlia.core.events import ExecutionFinishedEvent
+from mlia.core.events import ExecutionStartedEvent
+from mlia.core.events import stage
+from mlia.core.events import SystemEventsHandler
+
+
+@dataclass
+class SampleEvent(Event):
+ """Sample event."""
+
+ msg: str
+
+
+def test_event_publisher() -> None:
+ """Test event publishing."""
+ publisher = DefaultEventPublisher()
+ handler_mock1 = MagicMock(spec=EventHandler)
+ handler_mock2 = MagicMock(spec=EventHandler)
+
+ publisher.register_event_handlers([handler_mock1, handler_mock2])
+
+ event = SampleEvent("hello, event!")
+ publisher.publish_event(event)
+
+ handler_mock1.handle_event.assert_called_once_with(event)
+ handler_mock2.handle_event.assert_called_once_with(event)
+
+
+def test_stage_context_manager() -> None:
+ """Test stage context manager."""
+ publisher = DefaultEventPublisher()
+
+ handler_mock = MagicMock(spec=EventHandler)
+ publisher.register_event_handler(handler_mock)
+
+ events = (SampleEvent("hello"), SampleEvent("goodbye"))
+ with stage(publisher, events):
+ print("perform actions")
+
+ assert handler_mock.handle_event.call_count == 2
+ calls = [call(event) for event in events]
+ handler_mock.handle_event.assert_has_calls(calls)
+
+
+def test_action_context_manager() -> None:
+ """Test action stage context manager."""
+ publisher = DefaultEventPublisher()
+
+ handler_mock = MagicMock(spec=EventHandler)
+ publisher.register_event_handler(handler_mock)
+
+ with action(publisher, "Sample action"):
+ print("perform actions")
+
+ assert handler_mock.handle_event.call_count == 2
+ calls = handler_mock.handle_event.mock_calls
+
+ action_started = calls[0].args[0]
+ action_finished = calls[1].args[0]
+
+ assert isinstance(action_started, ActionStartedEvent)
+ assert isinstance(action_finished, ActionFinishedEvent)
+
+ assert action_finished.parent_event_id == action_started.event_id
+
+
+def test_debug_event_handler(capsys: pytest.CaptureFixture) -> None:
+ """Test debugging event handler."""
+ publisher = DefaultEventPublisher()
+
+ publisher.register_event_handler(DebugEventHandler())
+ publisher.register_event_handler(DebugEventHandler(with_stacktrace=True))
+
+ messages = ["Sample event 1", "Sample event 2"]
+ for message in messages:
+ publisher.publish_event(SampleEvent(message))
+
+ captured = capsys.readouterr()
+ for message in messages:
+ assert message in captured.out
+
+ assert "traceback.print_stack" in captured.err
+
+
+def test_event_dispatcher(capsys: pytest.CaptureFixture) -> None:
+ """Test event dispatcher."""
+
+ class SampleEventHandler(EventDispatcher):
+ """Sample event handler."""
+
+ def on_sample_event( # pylint: disable=no-self-use
+ self, _event: SampleEvent
+ ) -> None:
+ """Event handler for SampleEvent."""
+ print("Got sample event")
+
+ publisher = DefaultEventPublisher()
+ publisher.register_event_handler(SampleEventHandler())
+ publisher.publish_event(SampleEvent("Sample event"))
+
+ captured = capsys.readouterr()
+ assert captured.out.strip() == "Got sample event"
+
+
+def test_system_events_handler(capsys: pytest.CaptureFixture) -> None:
+ """Test system events handler."""
+
+ class CustomSystemEventHandler(SystemEventsHandler):
+ """Custom system event handler."""
+
+ def on_execution_started(self, event: ExecutionStartedEvent) -> None:
+ """Handle ExecutionStarted event."""
+ print("Execution started")
+
+ def on_execution_finished(self, event: ExecutionFinishedEvent) -> None:
+ """Handle ExecutionFinished event."""
+ print("Execution finished")
+
+ publisher = DefaultEventPublisher()
+ publisher.register_event_handler(CustomSystemEventHandler())
+
+ publisher.publish_event(ExecutionStartedEvent())
+ publisher.publish_event(SampleEvent("Hello world!"))
+ publisher.publish_event(ExecutionFinishedEvent())
+
+ captured = capsys.readouterr()
+ assert captured.out.strip() == "Execution started\nExecution finished"
+
+
+def test_compare_without_id() -> None:
+ """Test event comparison without event_id."""
+ event1 = SampleEvent("message")
+ event2 = SampleEvent("message")
+
+ assert event1 != event2
+ assert event1.compare_without_id(event2)
+
+ assert not event1.compare_without_id("message") # type: ignore
diff --git a/tests/mlia/test_core_helpers.py b/tests/mlia/test_core_helpers.py
new file mode 100644
index 0000000..8577617
--- /dev/null
+++ b/tests/mlia/test_core_helpers.py
@@ -0,0 +1,17 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Tests for the helper classes."""
+from mlia.core.helpers import APIActionResolver
+
+
+def test_api_action_resolver() -> None:
+ """Test APIActionResolver class."""
+ helper = APIActionResolver()
+
+ # pylint: disable=use-implicit-booleaness-not-comparison
+ assert helper.apply_optimizations() == []
+ assert helper.supported_operators_info() == []
+ assert helper.check_performance() == []
+ assert helper.check_operator_compatibility() == []
+ assert helper.operator_compatibility_details() == []
+ assert helper.optimization_details() == []
diff --git a/tests/mlia/test_core_mixins.py b/tests/mlia/test_core_mixins.py
new file mode 100644
index 0000000..d66213d
--- /dev/null
+++ b/tests/mlia/test_core_mixins.py
@@ -0,0 +1,99 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Tests for the module mixins."""
+import pytest
+
+from mlia.core.common import AdviceCategory
+from mlia.core.context import Context
+from mlia.core.context import ExecutionContext
+from mlia.core.mixins import ContextMixin
+from mlia.core.mixins import ParameterResolverMixin
+
+
+def test_context_mixin(dummy_context: Context) -> None:
+ """Test ContextMixin."""
+
+ class SampleClass(ContextMixin):
+ """Sample class."""
+
+ sample_object = SampleClass()
+ sample_object.set_context(dummy_context)
+ assert sample_object.context == dummy_context
+
+
+class TestParameterResolverMixin:
+ """Tests for parameter resolver mixin."""
+
+ @staticmethod
+ def test_parameter_resolver_mixin(dummy_context: ExecutionContext) -> None:
+ """Test ParameterResolverMixin."""
+
+ class SampleClass(ParameterResolverMixin):
+ """Sample class."""
+
+ def __init__(self) -> None:
+ """Init sample object."""
+ self.context = dummy_context
+
+ self.context.update(
+ advice_category=AdviceCategory.OPERATORS,
+ event_handlers=[],
+ config_parameters={"section": {"param": 123}},
+ )
+
+ sample_object = SampleClass()
+ value = sample_object.get_parameter("section", "param")
+ assert value == 123
+
+ with pytest.raises(
+ Exception, match="Parameter param expected to have type <class 'str'>"
+ ):
+ value = sample_object.get_parameter("section", "param", expected_type=str)
+
+ with pytest.raises(Exception, match="Parameter no_param is not set"):
+ value = sample_object.get_parameter("section", "no_param")
+
+ @staticmethod
+ def test_parameter_resolver_mixin_no_config(
+ dummy_context: ExecutionContext,
+ ) -> None:
+ """Test ParameterResolverMixin without config params."""
+
+ class SampleClassNoConfig(ParameterResolverMixin):
+ """Sample context without config params."""
+
+ def __init__(self) -> None:
+ """Init sample object."""
+ self.context = dummy_context
+
+ with pytest.raises(Exception, match="Configuration parameters are not set"):
+ sample_object_no_config = SampleClassNoConfig()
+ sample_object_no_config.get_parameter("section", "param", expected_type=str)
+
+ @staticmethod
+ def test_parameter_resolver_mixin_bad_section(
+ dummy_context: ExecutionContext,
+ ) -> None:
+ """Test ParameterResolverMixin without config params."""
+
+ class SampleClassBadSection(ParameterResolverMixin):
+ """Sample context with bad section in config."""
+
+ def __init__(self) -> None:
+ """Init sample object."""
+ self.context = dummy_context
+ self.context.update(
+ advice_category=AdviceCategory.OPERATORS,
+ event_handlers=[],
+ config_parameters={"section": ["param"]},
+ )
+
+ with pytest.raises(
+ Exception,
+ match="Parameter section section has wrong format, "
+ "expected to be a dictionary",
+ ):
+ sample_object_bad_section = SampleClassBadSection()
+ sample_object_bad_section.get_parameter(
+ "section", "param", expected_type=str
+ )
diff --git a/tests/mlia/test_core_performance.py b/tests/mlia/test_core_performance.py
new file mode 100644
index 0000000..0d28fe8
--- /dev/null
+++ b/tests/mlia/test_core_performance.py
@@ -0,0 +1,29 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Tests for the module performance."""
+from pathlib import Path
+
+from mlia.core.performance import estimate_performance
+from mlia.core.performance import PerformanceEstimator
+
+
+def test_estimate_performance(tmp_path: Path) -> None:
+ """Test function estimate_performance."""
+ model_path = tmp_path / "original.tflite"
+
+ class SampleEstimator(PerformanceEstimator[Path, int]):
+ """Sample estimator."""
+
+ def estimate(self, model: Path) -> int:
+ """Estimate performance."""
+ if model.name == "original.tflite":
+ return 1
+
+ return 2
+
+ def optimized_model(_original: Path) -> Path:
+ """Return path to the 'optimized' model."""
+ return tmp_path / "optimized.tflite"
+
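+ # The estimator runs on the original model first, then on each model
+ # produced by the optimizer callbacks.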
+ results = estimate_performance(model_path, SampleEstimator(), [optimized_model])
+ assert results == [1, 2]
diff --git a/tests/mlia/test_core_reporting.py b/tests/mlia/test_core_reporting.py
new file mode 100644
index 0000000..2f7ec22
--- /dev/null
+++ b/tests/mlia/test_core_reporting.py
@@ -0,0 +1,413 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Tests for reporting module."""
+from typing import List
+
+import pytest
+
+from mlia.core.reporting import BytesCell
+from mlia.core.reporting import Cell
+from mlia.core.reporting import ClockCell
+from mlia.core.reporting import Column
+from mlia.core.reporting import CyclesCell
+from mlia.core.reporting import Format
+from mlia.core.reporting import NestedReport
+from mlia.core.reporting import ReportItem
+from mlia.core.reporting import SingleRow
+from mlia.core.reporting import Table
+from mlia.utils.console import remove_ascii_codes
+
+
+@pytest.mark.parametrize(
+ "cell, expected_repr",
+ [
+ (BytesCell(None), ""),
+ (BytesCell(0), "0 bytes"),
+ (BytesCell(1), "1 byte"),
+ (BytesCell(100000), "100,000 bytes"),
+ (ClockCell(None), ""),
+ (ClockCell(0), "0 Hz"),
+ (ClockCell(1), "1 Hz"),
+ (ClockCell(100000), "100,000 Hz"),
+ (CyclesCell(None), ""),
+ (CyclesCell(0), "0 cycles"),
+ (CyclesCell(1), "1 cycle"),
+ (CyclesCell(100000), "100,000 cycles"),
+ ],
+)
+def test_predefined_cell_types(cell: Cell, expected_repr: str) -> None:
+ """Test predefined cell types."""
+ assert str(cell) == expected_repr
+
+
+@pytest.mark.parametrize(
+ "with_notes, expected_text_report",
+ [
+ [
+ True,
+ """
+Sample table:
+┌──────────┬──────────┬──────────┐
+│ Header 1 │ Header 2 │ Header 3 │
+╞══════════╪══════════╪══════════╡
+│ 1 │ 2 │ 3 │
+├──────────┼──────────┼──────────┤
+│ 4 │ 5 │ 123,123 │
+└──────────┴──────────┴──────────┘
+Sample notes
+ """.strip(),
+ ],
+ [
+ False,
+ """
+Sample table:
+┌──────────┬──────────┬──────────┐
+│ Header 1 │ Header 2 │ Header 3 │
+╞══════════╪══════════╪══════════╡
+│ 1 │ 2 │ 3 │
+├──────────┼──────────┼──────────┤
+│ 4 │ 5 │ 123,123 │
+└──────────┴──────────┴──────────┘
+ """.strip(),
+ ],
+ ],
+)
+def test_table_representation(with_notes: bool, expected_text_report: str) -> None:
+ """Test table report representation."""
+
+ def sample_table(with_notes: bool) -> Table:
+ columns = [
+ Column("Header 1", alias="header1", only_for=["plain_text"]),
+ Column("Header 2", alias="header2", fmt=Format(wrap_width=5)),
+ Column("Header 3", alias="header3"),
+ ]
+ rows = [(1, 2, 3), (4, 5, Cell(123123, fmt=Format(str_fmt="0,d")))]
+
+ return Table(
+ columns,
+ rows,
+ name="Sample table",
+ alias="sample_table",
+ notes="Sample notes" if with_notes else None,
+ )
+
+ table = sample_table(with_notes)
+ csv_repr = table.to_csv()
+ assert csv_repr == [["Header 2", "Header 3"], [2, 3], [5, 123123]]
+
+ json_repr = table.to_json()
+ assert json_repr == {
+ "sample_table": [
+ {"header2": 2, "header3": 3},
+ {"header2": 5, "header3": 123123},
+ ]
+ }
+
+ text_report = remove_ascii_codes(table.to_plain_text())
+ assert text_report == expected_text_report
+
+
+def test_csv_nested_table_representation() -> None:
+ """Test representation of the nested tables in csv format."""
+
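+ # In CSV output a nested table is flattened into a single cell with
+ # semicolon-separated values, as asserted below.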
+ def sample_table(num_of_cols: int) -> Table:
+ columns = [
+ Column("Header 1", alias="header1"),
+ Column("Header 2", alias="header2"),
+ ]
+
+ rows = [
+ (
+ 1,
+ Table(
+ columns=[
+ Column(f"Nested column {i+1}") for i in range(num_of_cols)
+ ],
+ rows=[[f"value{i+1}" for i in range(num_of_cols)]],
+ name="Nested table",
+ ),
+ )
+ ]
+
+ return Table(columns, rows, name="Sample table", alias="sample_table")
+
+ assert sample_table(num_of_cols=2).to_csv() == [
+ ["Header 1", "Header 2"],
+ [1, "value1;value2"],
+ ]
+
+ assert sample_table(num_of_cols=1).to_csv() == [
+ ["Header 1", "Header 2"],
+ [1, "value1"],
+ ]
+
+ assert sample_table(num_of_cols=0).to_csv() == [
+ ["Header 1", "Header 2"],
+ [1, ""],
+ ]
+
+
+@pytest.mark.parametrize(
+ "report, expected_plain_text, expected_json_data, expected_csv_data",
+ [
+ (
+ NestedReport(
+ "Sample report",
+ "sample_report",
+ [
+ ReportItem("Item", "item", "item_value"),
+ ],
+ ),
+ """
+Sample report:
+ Item item_value
+""".strip(),
+ {
+ "sample_report": {"item": "item_value"},
+ },
+ [
+ ("item",),
+ ("item_value",),
+ ],
+ ),
+ (
+ NestedReport(
+ "Sample report",
+ "sample_report",
+ [
+ ReportItem(
+ "Item",
+ "item",
+ "item_value",
+ [ReportItem("Nested item", "nested_item", "nested_item_value")],
+ ),
+ ],
+ ),
+ """
+Sample report:
+ Item item_value
+ Nested item nested_item_value
+""".strip(),
+ {
+ "sample_report": {
+ "item": {"nested_item": "nested_item_value"},
+ },
+ },
+ [
+ ("item", "nested_item"),
+ ("item_value", "nested_item_value"),
+ ],
+ ),
+ (
+ NestedReport(
+ "Sample report",
+ "sample_report",
+ [
+ ReportItem(
+ "Item",
+ "item",
+ "item_value",
+ [ReportItem("Nested item", "nested_item", BytesCell(10))],
+ ),
+ ],
+ ),
+ """
+Sample report:
+ Item item_value
+ Nested item 10 bytes
+""".strip(),
+ {
+ "sample_report": {
+ "item": {"nested_item": {"unit": "bytes", "value": 10}},
+ },
+ },
+ [
+ ("item", "nested_item_value", "nested_item_unit"),
+ ("item_value", 10, "bytes"),
+ ],
+ ),
+ (
+ NestedReport(
+ "Sample report",
+ "sample_report",
+ [
+ ReportItem(
+ "Item",
+ "item",
+ "item_value",
+ [
+ ReportItem(
+ "Nested item",
+ "nested_item",
+ Cell(
+ 10, fmt=Format(str_fmt=lambda x: f"{x} cell value")
+ ),
+ )
+ ],
+ ),
+ ],
+ ),
+ """
+Sample report:
+ Item item_value
+ Nested item 10 cell value
+""".strip(),
+ {
+ "sample_report": {
+ "item": {"nested_item": 10},
+ },
+ },
+ [
+ ("item", "nested_item"),
+ ("item_value", 10),
+ ],
+ ),
+ (
+ NestedReport(
+ "Sample report",
+ "sample_report",
+ [
+ ReportItem(
+ "Item",
+ "item",
+ "item_value",
+ [
+ ReportItem("Nested item", "nested_item", Cell(10)),
+ ],
+ ),
+ ],
+ ),
+ """
+Sample report:
+ Item item_value
+ Nested item 10
+""".strip(),
+ {
+ "sample_report": {
+ "item": {"nested_item": 10},
+ },
+ },
+ [
+ ("item", "nested_item"),
+ ("item_value", 10),
+ ],
+ ),
+ (
+ NestedReport(
+ "Sample report",
+ "sample_report",
+ [
+ ReportItem(
+ "Item",
+ "item",
+ "item_value",
+ [
+ ReportItem(
+ "Nested item", "nested_item", Cell(10, fmt=Format())
+ ),
+ ],
+ ),
+ ],
+ ),
+ """
+Sample report:
+ Item item_value
+ Nested item 10
+""".strip(),
+ {
+ "sample_report": {
+ "item": {"nested_item": 10},
+ },
+ },
+ [
+ ("item", "nested_item"),
+ ("item_value", 10),
+ ],
+ ),
+ ],
+)
+def test_nested_report_representation(
+ report: NestedReport,
+ expected_plain_text: str,
+ expected_json_data: dict,
+ expected_csv_data: List,
+) -> None:
+ """Test representation of the NestedReport."""
+ plain_text = report.to_plain_text()
+ assert plain_text == expected_plain_text
+
+ json_data = report.to_json()
+ assert json_data == expected_json_data
+
+ csv_data = report.to_csv()
+ assert csv_data == expected_csv_data
+
+
+def test_single_row_representation() -> None:
+ """Test representation of the SingleRow."""
+ single_row = SingleRow(
+ columns=[
+ Column("column1", "column1"),
+ ],
+ rows=[("value1", "value2")],
+ name="Single row example",
+ alias="simple_row_example",
+ )
+
+ expected_text = """
+Single row example:
+ column1 value1
+""".strip()
+ assert single_row.to_plain_text() == expected_text
+ assert single_row.to_csv() == [["column1"], ["value1"]]
+ assert single_row.to_json() == {"simple_row_example": [{"column1": "value1"}]}
+
+ with pytest.raises(Exception, match="Table should have only one row"):
+ wrong_single_row = SingleRow(
+ columns=[
+ Column("column1", "column1"),
+ ],
+ rows=[
+ ("value1", "value2"),
+ ("value1", "value2"),
+ ],
+ name="Single row example",
+ alias="simple_row_example",
+ )
+ wrong_single_row.to_plain_text()
diff --git a/tests/mlia/test_core_workflow.py b/tests/mlia/test_core_workflow.py
new file mode 100644
index 0000000..470e572
--- /dev/null
+++ b/tests/mlia/test_core_workflow.py
@@ -0,0 +1,164 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Tests for module workflow."""
+from dataclasses import dataclass
+from unittest.mock import call
+from unittest.mock import MagicMock
+
+from mlia.core.advice_generation import Advice
+from mlia.core.advice_generation import AdviceEvent
+from mlia.core.advice_generation import ContextAwareAdviceProducer
+from mlia.core.context import ExecutionContext
+from mlia.core.data_analysis import ContextAwareDataAnalyzer
+from mlia.core.data_collection import ContextAwareDataCollector
+from mlia.core.errors import FunctionalityNotSupportedError
+from mlia.core.events import AdviceStageFinishedEvent
+from mlia.core.events import AdviceStageStartedEvent
+from mlia.core.events import AnalyzedDataEvent
+from mlia.core.events import CollectedDataEvent
+from mlia.core.events import DataAnalysisStageFinishedEvent
+from mlia.core.events import DataAnalysisStageStartedEvent
+from mlia.core.events import DataCollectionStageFinishedEvent
+from mlia.core.events import DataCollectionStageStartedEvent
+from mlia.core.events import DataCollectorSkippedEvent
+from mlia.core.events import DefaultEventPublisher
+from mlia.core.events import Event
+from mlia.core.events import EventHandler
+from mlia.core.events import ExecutionFailedEvent
+from mlia.core.events import ExecutionFinishedEvent
+from mlia.core.events import ExecutionStartedEvent
+from mlia.core.workflow import DefaultWorkflowExecutor
+
+
+@dataclass
+class SampleEvent(Event):
+ """Sample event."""
+
+ msg: str
+
+
+def test_workflow_executor(tmpdir: str) -> None:
+ """Test workflow executor."""
+ handler_mock = MagicMock(spec=EventHandler)
+ data_collector_mock = MagicMock(spec=ContextAwareDataCollector)
+ data_collector_mock.collect_data.return_value = 42
+
+ data_collector_mock_no_value = MagicMock(spec=ContextAwareDataCollector)
+ data_collector_mock_no_value.collect_data.return_value = None
+
+ data_collector_mock_skipped = MagicMock(spec=ContextAwareDataCollector)
+ data_collector_mock_skipped.name.return_value = "skipped_collector"
+ data_collector_mock_skipped.collect_data.side_effect = (
+ FunctionalityNotSupportedError("Error!", "Error!")
+ )
+
+ data_analyzer_mock = MagicMock(spec=ContextAwareDataAnalyzer)
+ data_analyzer_mock.get_analyzed_data.return_value = ["Really good number!"]
+
+ advice_producer_mock1 = MagicMock(spec=ContextAwareAdviceProducer)
+ advice_producer_mock1.get_advice.return_value = Advice(["All good!"])
+
+ advice_producer_mock2 = MagicMock(spec=ContextAwareAdviceProducer)
+ advice_producer_mock2.get_advice.return_value = [Advice(["Good advice!"])]
+
+ context = ExecutionContext(
+ working_dir=tmpdir,
+ event_handlers=[handler_mock],
+ event_publisher=DefaultEventPublisher(),
+ )
+
+ executor = DefaultWorkflowExecutor(
+ context,
+ [
+ data_collector_mock,
+ data_collector_mock_no_value,
+ data_collector_mock_skipped,
+ ],
+ [data_analyzer_mock],
+ [
+ advice_producer_mock1,
+ advice_producer_mock2,
+ ],
+ [SampleEvent("Hello from advisor!")],
+ )
+
+ executor.run()
+
+ data_collector_mock.collect_data.assert_called_once()
+ data_collector_mock_no_value.collect_data.assert_called_once()
+ data_collector_mock_skipped.collect_data.assert_called_once()
+
+ data_analyzer_mock.analyze_data.assert_called_once_with(42)
+
+ advice_producer_mock1.produce_advice.assert_called_once_with("Really good number!")
+ advice_producer_mock1.get_advice.assert_called_once()
+
+ advice_producer_mock2.produce_advice.assert_called_once_with("Really good number!")
+ advice_producer_mock2.get_advice.assert_called_once()
+
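+ # The expected event sequence; events are matched below with
+ # compare_without_id because event ids are generated per instance.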
+ expected_mock_calls = [
+ call(ExecutionStartedEvent()),
+ call(SampleEvent("Hello from advisor!")),
+ call(DataCollectionStageStartedEvent()),
+ call(CollectedDataEvent(data_item=42)),
+ call(DataCollectorSkippedEvent("skipped_collector", "Error!: Error!")),
+ call(DataCollectionStageFinishedEvent()),
+ call(DataAnalysisStageStartedEvent()),
+ call(AnalyzedDataEvent(data_item="Really good number!")),
+ call(DataAnalysisStageFinishedEvent()),
+ call(AdviceStageStartedEvent()),
+ call(AdviceEvent(advice=Advice(messages=["All good!"]))),
+ call(AdviceEvent(advice=Advice(messages=["Good advice!"]))),
+ call(AdviceStageFinishedEvent()),
+ call(ExecutionFinishedEvent()),
+ ]
+
+ for expected_call, actual_call in zip(
+ expected_mock_calls, handler_mock.handle_event.mock_calls
+ ):
+ expected_event = expected_call.args[0]
+ actual_event = actual_call.args[0]
+
+ assert actual_event.compare_without_id(expected_event)
+
+
+def test_workflow_executor_failed(tmpdir: str) -> None:
+ """Test scenario when one of the components raises exception."""
+ handler_mock = MagicMock(spec=EventHandler)
+
+ context = ExecutionContext(
+ working_dir=tmpdir,
+ event_handlers=[handler_mock],
+ event_publisher=DefaultEventPublisher(),
+ )
+
+ collection_exception = Exception("Collection failed")
+
+ data_collector_mock = MagicMock(spec=ContextAwareDataCollector)
+ data_collector_mock.collect_data.side_effect = collection_exception
+
+ executor = DefaultWorkflowExecutor(context, [data_collector_mock], [], [])
+ executor.run()
+
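+ # Execution should stop at the failing collector and publish
+ # ExecutionFailedEvent instead of the remaining stage events.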
+ expected_mock_calls = [
+ call(ExecutionStartedEvent()),
+ call(DataCollectionStageStartedEvent()),
+ call(ExecutionFailedEvent(collection_exception)),
+ ]
+
+ for expected_call, actual_call in zip(
+ expected_mock_calls, handler_mock.handle_event.mock_calls
+ ):
+ expected_event = expected_call.args[0]
+ actual_event = actual_call.args[0]
+
+ if isinstance(actual_event, ExecutionFailedEvent):
+ # dataclass equality does not compare exception instances in a
+ # useful way, so compare the exceptions directly
+ actual_exception = actual_event.err
+ expected_exception = expected_event.err
+
+ assert actual_exception == expected_exception
+ continue
+
+ assert actual_event.compare_without_id(expected_event)
diff --git a/tests/mlia/test_devices_ethosu_advice_generation.py b/tests/mlia/test_devices_ethosu_advice_generation.py
new file mode 100644
index 0000000..98c8a57
--- /dev/null
+++ b/tests/mlia/test_devices_ethosu_advice_generation.py
@@ -0,0 +1,483 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Tests for Ethos-U advice generation."""
+from typing import List
+from typing import Optional
+
+import pytest
+
+from mlia.cli.helpers import CLIActionResolver
+from mlia.core.advice_generation import Advice
+from mlia.core.common import AdviceCategory
+from mlia.core.common import DataItem
+from mlia.core.context import ExecutionContext
+from mlia.core.helpers import ActionResolver
+from mlia.core.helpers import APIActionResolver
+from mlia.devices.ethosu.advice_generation import EthosUAdviceProducer
+from mlia.devices.ethosu.advice_generation import EthosUStaticAdviceProducer
+from mlia.devices.ethosu.data_analysis import AllOperatorsSupportedOnNPU
+from mlia.devices.ethosu.data_analysis import HasCPUOnlyOperators
+from mlia.devices.ethosu.data_analysis import HasUnsupportedOnNPUOperators
+from mlia.devices.ethosu.data_analysis import OptimizationDiff
+from mlia.devices.ethosu.data_analysis import OptimizationResults
+from mlia.devices.ethosu.data_analysis import PerfMetricDiff
+from mlia.nn.tensorflow.optimizations.select import OptimizationSettings
+
+
+@pytest.mark.parametrize(
+ "input_data, advice_category, action_resolver, expected_advice",
+ [
+ [
+ AllOperatorsSupportedOnNPU(),
+ AdviceCategory.OPERATORS,
+ APIActionResolver(),
+ [
+ Advice(
+ [
+ "You don't have any unsupported operators, your model will "
+ "run completely on NPU."
+ ]
+ )
+ ],
+ ],
+ [
+ AllOperatorsSupportedOnNPU(),
+ AdviceCategory.OPERATORS,
+ CLIActionResolver(
+ {
+ "target_profile": "sample_target",
+ "model": "sample_model.tflite",
+ }
+ ),
+ [
+ Advice(
+ [
+ "You don't have any unsupported operators, your model will "
+ "run completely on NPU.",
+ "Check the estimated performance by running the "
+ "following command: ",
+ "mlia performance --target-profile sample_target "
+ "sample_model.tflite",
+ ]
+ )
+ ],
+ ],
+ [
+ HasCPUOnlyOperators(cpu_only_ops=["OP1", "OP2", "OP3"]),
+ AdviceCategory.OPERATORS,
+ APIActionResolver(),
+ [
+ Advice(
+ [
+ "You have at least 3 operators that is CPU only: "
+ "OP1,OP2,OP3.",
+ "Using operators that are supported by the NPU will "
+ "improve performance.",
+ ]
+ )
+ ],
+ ],
+ [
+ HasCPUOnlyOperators(cpu_only_ops=["OP1", "OP2", "OP3"]),
+ AdviceCategory.OPERATORS,
+ CLIActionResolver({}),
+ [
+ Advice(
+ [
+ "You have at least 3 operators that is CPU only: "
+ "OP1,OP2,OP3.",
+ "Using operators that are supported by the NPU will "
+ "improve performance.",
+ "For guidance on supported operators, run: mlia operators "
+ "--supported-ops-report",
+ ]
+ )
+ ],
+ ],
+ [
+ HasUnsupportedOnNPUOperators(npu_unsupported_ratio=0.4),
+ AdviceCategory.OPERATORS,
+ APIActionResolver(),
+ [
+ Advice(
+ [
+ "You have 40% of operators that cannot be placed on the NPU.",
+ "For better performance, please review the reasons reported "
+ "in the table, and adjust the model accordingly "
+ "where possible.",
+ ]
+ )
+ ],
+ ],
+ [
+ HasUnsupportedOnNPUOperators(npu_unsupported_ratio=0.4),
+ AdviceCategory.OPERATORS,
+ CLIActionResolver({}),
+ [
+ Advice(
+ [
+ "You have 40% of operators that cannot be placed on the NPU.",
+ "For better performance, please review the reasons reported "
+ "in the table, and adjust the model accordingly "
+ "where possible.",
+ ]
+ )
+ ],
+ ],
+ [
+ OptimizationResults(
+ [
+ OptimizationDiff(
+ opt_type=[OptimizationSettings("pruning", 0.5, None)],
+ opt_diffs={
+ "sram": PerfMetricDiff(100, 150),
+ "dram": PerfMetricDiff(100, 50),
+ "on_chip_flash": PerfMetricDiff(100, 100),
+ "off_chip_flash": PerfMetricDiff(100, 100),
+ "npu_total_cycles": PerfMetricDiff(10, 5),
+ },
+ ),
+ ]
+ ),
+ AdviceCategory.OPTIMIZATION,
+ APIActionResolver(),
+ [
+ Advice(
+ [
+ "With the selected optimization (pruning: 0.5)",
+ "- You have achieved 50.00% performance improvement in "
+ "DRAM used (KB)",
+ "- You have achieved 50.00% performance improvement in "
+ "NPU total cycles",
+ "- SRAM used (KB) have degraded by 50.00%",
+ "You can try to push the optimization target higher "
+ "(e.g. pruning: 0.6) "
+ "to check if those results can be further improved.",
+ ]
+ ),
+ Advice(
+ [
+ "The applied tooling techniques have an impact "
+ "on accuracy. Additional hyperparameter tuning may be required "
+ "after any optimization."
+ ]
+ ),
+ ],
+ ],
+ [
+ OptimizationResults(
+ [
+ OptimizationDiff(
+ opt_type=[OptimizationSettings("pruning", 0.5, None)],
+ opt_diffs={
+ "sram": PerfMetricDiff(100, 150),
+ "dram": PerfMetricDiff(100, 50),
+ "on_chip_flash": PerfMetricDiff(100, 100),
+ "off_chip_flash": PerfMetricDiff(100, 100),
+ "npu_total_cycles": PerfMetricDiff(10, 5),
+ },
+ ),
+ ]
+ ),
+ AdviceCategory.OPTIMIZATION,
+ CLIActionResolver({"model": "sample_model.h5"}),
+ [
+ Advice(
+ [
+ "With the selected optimization (pruning: 0.5)",
+ "- You have achieved 50.00% performance improvement in "
+ "DRAM used (KB)",
+ "- You have achieved 50.00% performance improvement in "
+ "NPU total cycles",
+ "- SRAM used (KB) have degraded by 50.00%",
+ "You can try to push the optimization target higher "
+ "(e.g. pruning: 0.6) "
+ "to check if those results can be further improved.",
+ "For more info: mlia optimization --help",
+ "Optimization command: "
+ "mlia optimization --optimization-type pruning "
+ "--optimization-target 0.6 sample_model.h5",
+ ]
+ ),
+ Advice(
+ [
+ "The applied tooling techniques have an impact "
+ "on accuracy. Additional hyperparameter tuning may be required "
+ "after any optimization."
+ ]
+ ),
+ ],
+ ],
+ [
+ OptimizationResults(
+ [
+ OptimizationDiff(
+ opt_type=[
+ OptimizationSettings("pruning", 0.5, None),
+ OptimizationSettings("clustering", 32, None),
+ ],
+ opt_diffs={
+ "sram": PerfMetricDiff(100, 150),
+ "dram": PerfMetricDiff(100, 50),
+ "on_chip_flash": PerfMetricDiff(100, 100),
+ "off_chip_flash": PerfMetricDiff(100, 100),
+ "npu_total_cycles": PerfMetricDiff(10, 5),
+ },
+ ),
+ ]
+ ),
+ AdviceCategory.OPTIMIZATION,
+ APIActionResolver(),
+ [
+ Advice(
+ [
+ "With the selected optimization (pruning: 0.5, clustering: 32)",
+ "- You have achieved 50.00% performance improvement in "
+ "DRAM used (KB)",
+ "- You have achieved 50.00% performance improvement in "
+ "NPU total cycles",
+ "- SRAM used (KB) have degraded by 50.00%",
+ "You can try to push the optimization target higher "
+ "(e.g. pruning: 0.6 and/or clustering: 16) "
+ "to check if those results can be further improved.",
+ ]
+ ),
+ Advice(
+ [
+ "The applied tooling techniques have an impact "
+ "on accuracy. Additional hyperparameter tuning may be required "
+ "after any optimization."
+ ]
+ ),
+ ],
+ ],
+ [
+ OptimizationResults(
+ [
+ OptimizationDiff(
+ opt_type=[
+ OptimizationSettings("clustering", 2, None),
+ ],
+ opt_diffs={
+ "sram": PerfMetricDiff(100, 150),
+ "dram": PerfMetricDiff(100, 50),
+ "on_chip_flash": PerfMetricDiff(100, 100),
+ "off_chip_flash": PerfMetricDiff(100, 100),
+ "npu_total_cycles": PerfMetricDiff(10, 5),
+ },
+ ),
+ ]
+ ),
+ AdviceCategory.OPTIMIZATION,
+ APIActionResolver(),
+ [
+ Advice(
+ [
+ "With the selected optimization (clustering: 2)",
+ "- You have achieved 50.00% performance improvement in "
+ "DRAM used (KB)",
+ "- You have achieved 50.00% performance improvement in "
+ "NPU total cycles",
+ "- SRAM used (KB) have degraded by 50.00%",
+ ]
+ ),
+ Advice(
+ [
+ "The applied tooling techniques have an impact "
+ "on accuracy. Additional hyperparameter tuning may be required "
+ "after any optimization."
+ ]
+ ),
+ ],
+ ],
+ [
+ OptimizationResults(
+ [
+ OptimizationDiff(
+ opt_type=[OptimizationSettings("pruning", 0.5, None)],
+ opt_diffs={
+ "sram": PerfMetricDiff(100, 150),
+ "dram": PerfMetricDiff(100, 150),
+ "on_chip_flash": PerfMetricDiff(100, 150),
+ "off_chip_flash": PerfMetricDiff(100, 150),
+ "npu_total_cycles": PerfMetricDiff(10, 100),
+ },
+ ),
+ ]
+ ),
+ AdviceCategory.OPTIMIZATION,
+ APIActionResolver(),
+ [
+ Advice(
+ [
+ "With the selected optimization (pruning: 0.5)",
+ "- DRAM used (KB) have degraded by 50.00%",
+ "- SRAM used (KB) have degraded by 50.00%",
+ "- On chip flash used (KB) have degraded by 50.00%",
+ "- Off chip flash used (KB) have degraded by 50.00%",
+ "- NPU total cycles have degraded by 900.00%",
+ "The performance seems to have degraded after "
+ "applying the selected optimizations, "
+ "try exploring different optimization types/targets.",
+ ]
+ ),
+ Advice(
+ [
+ "The applied tooling techniques have an impact "
+ "on accuracy. Additional hyperparameter tuning may be required "
+ "after any optimization."
+ ]
+ ),
+ ],
+ ],
+ [
+ OptimizationResults(
+ [
+ OptimizationDiff(
+ opt_type=[OptimizationSettings("pruning", 0.5, None)],
+ opt_diffs={
+ "sram": PerfMetricDiff(100, 150),
+ "dram": PerfMetricDiff(100, 150),
+ "on_chip_flash": PerfMetricDiff(100, 150),
+ "off_chip_flash": PerfMetricDiff(100, 150),
+ "npu_total_cycles": PerfMetricDiff(10, 100),
+ },
+ ),
+ OptimizationDiff(
+ opt_type=[OptimizationSettings("pruning", 0.6, None)],
+ opt_diffs={
+ "sram": PerfMetricDiff(100, 150),
+ "dram": PerfMetricDiff(100, 150),
+ "on_chip_flash": PerfMetricDiff(100, 150),
+ "off_chip_flash": PerfMetricDiff(100, 150),
+ "npu_total_cycles": PerfMetricDiff(10, 100),
+ },
+ ),
+ ]
+ ),
+ AdviceCategory.OPTIMIZATION,
+ APIActionResolver(),
+ [], # no advice for more than one optimization result
+ ],
+ ],
+)
+def test_ethosu_advice_producer(
+ tmpdir: str,
+ input_data: DataItem,
+ expected_advice: List[Advice],
+ advice_category: AdviceCategory,
+ action_resolver: ActionResolver,
+) -> None:
+ """Test Ethos-U Advice producer."""
+ producer = EthosUAdviceProducer()
+
+ context = ExecutionContext(
+ advice_category=advice_category,
+ working_dir=tmpdir,
+ action_resolver=action_resolver,
+ )
+
+ producer.set_context(context)
+ producer.produce_advice(input_data)
+
+ assert producer.get_advice() == expected_advice
+
+
+@pytest.mark.parametrize(
+ "advice_category, action_resolver, expected_advice",
+ [
+ [
+ None,
+ None,
+ [],
+ ],
+ [
+ AdviceCategory.OPERATORS,
+ None,
+ [],
+ ],
+ [
+ AdviceCategory.PERFORMANCE,
+ APIActionResolver(),
+ [
+ Advice(
+ [
+ "You can improve the inference time by using only operators "
+ "that are supported by the NPU.",
+ ]
+ ),
+ Advice(
+ [
+ "Check if you can improve the performance by applying "
+ "tooling techniques to your model."
+ ]
+ ),
+ ],
+ ],
+ [
+ AdviceCategory.PERFORMANCE,
+ CLIActionResolver({"model": "test_model.h5"}),
+ [
+ Advice(
+ [
+ "You can improve the inference time by using only operators "
+ "that are supported by the NPU.",
+ "Try running the following command to verify that:",
+ "mlia operators test_model.h5",
+ ]
+ ),
+ Advice(
+ [
+ "Check if you can improve the performance by applying "
+ "tooling techniques to your model.",
+ "For example: mlia optimization --optimization-type "
+ "pruning,clustering --optimization-target 0.5,32 "
+ "test_model.h5",
+ "For more info: mlia optimization --help",
+ ]
+ ),
+ ],
+ ],
+ [
+ AdviceCategory.OPTIMIZATION,
+ APIActionResolver(),
+ [
+ Advice(
+ [
+ "For better performance, make sure that all the operators "
+ "of your final TFLite model are supported by the NPU.",
+ ]
+ )
+ ],
+ ],
+ [
+ AdviceCategory.OPTIMIZATION,
+ CLIActionResolver({"model": "test_model.h5"}),
+ [
+ Advice(
+ [
+ "For better performance, make sure that all the operators "
+ "of your final TFLite model are supported by the NPU.",
+ "For more details, run: mlia operators --help",
+ ]
+ )
+ ],
+ ],
+ ],
+)
+def test_ethosu_static_advice_producer(
+ tmpdir: str,
+ advice_category: Optional[AdviceCategory],
+ action_resolver: ActionResolver,
+ expected_advice: List[Advice],
+) -> None:
+ """Test static advice generation."""
+ producer = EthosUStaticAdviceProducer()
+
+ context = ExecutionContext(
+ advice_category=advice_category,
+ working_dir=tmpdir,
+ action_resolver=action_resolver,
+ )
+ producer.set_context(context)
+ assert producer.get_advice() == expected_advice
diff --git a/tests/mlia/test_devices_ethosu_advisor.py b/tests/mlia/test_devices_ethosu_advisor.py
new file mode 100644
index 0000000..74d2408
--- /dev/null
+++ b/tests/mlia/test_devices_ethosu_advisor.py
@@ -0,0 +1,9 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Tests for Ethos-U MLIA module."""
+from mlia.devices.ethosu.advisor import EthosUInferenceAdvisor
+
+
+def test_advisor_metadata() -> None:
+ """Test advisor metadata."""
+ assert EthosUInferenceAdvisor.name() == "ethos_u_inference_advisor"
diff --git a/tests/mlia/test_devices_ethosu_config.py b/tests/mlia/test_devices_ethosu_config.py
new file mode 100644
index 0000000..49c999a
--- /dev/null
+++ b/tests/mlia/test_devices_ethosu_config.py
@@ -0,0 +1,124 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Tests for config module."""
+from contextlib import ExitStack as does_not_raise
+from typing import Any
+from typing import Dict
+from unittest.mock import MagicMock
+
+import pytest
+
+from mlia.devices.ethosu.config import EthosUConfiguration
+from mlia.devices.ethosu.config import get_target
+from mlia.tools.vela_wrapper import VelaCompilerOptions
+from mlia.utils.filesystem import get_vela_config
+
+
+def test_compiler_options_default_init() -> None:
+ """Test compiler options default init."""
+ opts = VelaCompilerOptions()
+
+ assert opts.config_files is None
+ assert opts.system_config == "internal-default"
+ assert opts.memory_mode == "internal-default"
+ assert opts.accelerator_config is None
+ assert opts.max_block_dependency == 3
+ assert opts.arena_cache_size is None
+ assert opts.tensor_allocator == "HillClimb"
+ assert opts.cpu_tensor_alignment == 16
+ assert opts.optimization_strategy == "Performance"
+ assert opts.output_dir is None
+
+
+def test_ethosu_target() -> None:
+ """Test Ethos-U target configuration init."""
+ default_config = EthosUConfiguration("ethos-u55-256")
+
+ assert default_config.target == "ethos-u55"
+ assert default_config.mac == 256
+ assert default_config.compiler_options is not None
+
+
+def test_get_target() -> None:
+ """Test function get_target."""
+ with pytest.raises(Exception, match="No target profile given"):
+ get_target(None) # type: ignore
+
+ with pytest.raises(Exception, match="Unable to find target profile unknown"):
+ get_target("unknown")
+
+ u65_device = get_target("ethos-u65-512")
+
+ assert isinstance(u65_device, EthosUConfiguration)
+ assert u65_device.target == "ethos-u65"
+ assert u65_device.mac == 512
+ assert u65_device.compiler_options.accelerator_config == "ethos-u65-512"
+ assert u65_device.compiler_options.memory_mode == "Dedicated_Sram"
+ assert u65_device.compiler_options.config_files == str(get_vela_config())
+
+
+@pytest.mark.parametrize(
+ "profile_data, expected_error",
+ [
+ [
+ {},
+ pytest.raises(
+ Exception,
+ match="Mandatory fields missing from target profile: "
+ r"\['mac', 'memory_mode', 'system_config', 'target'\]",
+ ),
+ ],
+ [
+ {"target": "ethos-u65", "mac": 512},
+ pytest.raises(
+ Exception,
+ match="Mandatory fields missing from target profile: "
+ r"\['memory_mode', 'system_config'\]",
+ ),
+ ],
+ [
+ {
+ "target": "ethos-u65",
+ "mac": 2,
+ "system_config": "Ethos_U65_Embedded",
+ "memory_mode": "Shared_Sram",
+ },
+ pytest.raises(
+ Exception,
+ match=r"Mac value for selected device should be in \[256, 512\]",
+ ),
+ ],
+ [
+ {
+ "target": "ethos-u55",
+ "mac": 1,
+ "system_config": "Ethos_U55_High_End_Embedded",
+ "memory_mode": "Shared_Sram",
+ },
+ pytest.raises(
+ Exception,
+ match="Mac value for selected device should be "
+ r"in \[32, 64, 128, 256\]",
+ ),
+ ],
+ [
+ {
+ "target": "ethos-u65",
+ "mac": 512,
+ "system_config": "Ethos_U65_Embedded",
+ "memory_mode": "Shared_Sram",
+ },
+ does_not_raise(),
+ ],
+ ],
+)
+def test_ethosu_configuration(
+ monkeypatch: pytest.MonkeyPatch, profile_data: Dict[str, Any], expected_error: Any
+) -> None:
+ """Test creating Ethos-U configuration."""
+ monkeypatch.setattr(
+ "mlia.devices.ethosu.config.get_profile", MagicMock(return_value=profile_data)
+ )
+
+ with expected_error:
+ EthosUConfiguration("target")
diff --git a/tests/mlia/test_devices_ethosu_data_analysis.py b/tests/mlia/test_devices_ethosu_data_analysis.py
new file mode 100644
index 0000000..4b1d38b
--- /dev/null
+++ b/tests/mlia/test_devices_ethosu_data_analysis.py
@@ -0,0 +1,147 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Tests for Ethos-U data analysis module."""
+from typing import List
+
+import pytest
+
+from mlia.core.common import DataItem
+from mlia.core.data_analysis import Fact
+from mlia.devices.ethosu.config import EthosUConfiguration
+from mlia.devices.ethosu.data_analysis import AllOperatorsSupportedOnNPU
+from mlia.devices.ethosu.data_analysis import EthosUDataAnalyzer
+from mlia.devices.ethosu.data_analysis import HasCPUOnlyOperators
+from mlia.devices.ethosu.data_analysis import HasUnsupportedOnNPUOperators
+from mlia.devices.ethosu.data_analysis import OptimizationDiff
+from mlia.devices.ethosu.data_analysis import OptimizationResults
+from mlia.devices.ethosu.data_analysis import PerfMetricDiff
+from mlia.devices.ethosu.performance import MemoryUsage
+from mlia.devices.ethosu.performance import NPUCycles
+from mlia.devices.ethosu.performance import OptimizationPerformanceMetrics
+from mlia.devices.ethosu.performance import PerformanceMetrics
+from mlia.nn.tensorflow.optimizations.select import OptimizationSettings
+from mlia.tools.vela_wrapper import NpuSupported
+from mlia.tools.vela_wrapper import Operator
+from mlia.tools.vela_wrapper import Operators
+
+
+def test_perf_metrics_diff() -> None:
+ """Test PerfMetricsDiff class."""
+ diff_same = PerfMetricDiff(1, 1)
+ assert diff_same.same is True
+ assert diff_same.improved is False
+ assert diff_same.degraded is False
+ assert diff_same.diff == 0
+
+ diff_improved = PerfMetricDiff(10, 5)
+ assert diff_improved.same is False
+ assert diff_improved.improved is True
+ assert diff_improved.degraded is False
+ assert diff_improved.diff == 50.0
+
+ diff_degraded = PerfMetricDiff(5, 10)
+ assert diff_degraded.same is False
+ assert diff_degraded.improved is False
+ assert diff_degraded.degraded is True
+ assert diff_degraded.diff == -100.0
+
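+ # when the original value is zero the diff is reported as 0,
+ # presumably to avoid division by zero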
+ diff_original_zero = PerfMetricDiff(0, 1)
+ assert diff_original_zero.diff == 0
+
+
+@pytest.mark.parametrize(
+ "input_data, expected_facts",
+ [
+ [
+ Operators(
+ [
+ Operator(
+ "CPU operator",
+ "CPU operator type",
+ NpuSupported(False, [("CPU only operator", "")]),
+ )
+ ]
+ ),
+ [
+ HasCPUOnlyOperators(["CPU operator type"]),
+ HasUnsupportedOnNPUOperators(1.0),
+ ],
+ ],
+ [
+ Operators(
+ [
+ Operator(
+ "NPU operator",
+ "NPU operator type",
+ NpuSupported(True, []),
+ )
+ ]
+ ),
+ [
+ AllOperatorsSupportedOnNPU(),
+ ],
+ ],
+ [
+ OptimizationPerformanceMetrics(
+ PerformanceMetrics(
+ EthosUConfiguration("ethos-u55-256"),
+ NPUCycles(1, 2, 3, 4, 5, 6),
+ # memory metrics are in kilobytes
+ MemoryUsage(*[i * 1024 for i in range(1, 6)]), # type: ignore
+ ),
+ [
+ [
+ [
+ OptimizationSettings("pruning", 0.5, None),
+ ],
+ PerformanceMetrics(
+ EthosUConfiguration("ethos-u55-256"),
+ NPUCycles(1, 2, 3, 4, 5, 6),
+ # memory metrics are in kilobytes
+ MemoryUsage(
+ *[i * 1024 for i in range(1, 6)] # type: ignore
+ ),
+ ),
+ ],
+ ],
+ ),
+ [
+ OptimizationResults(
+ [
+ OptimizationDiff(
+ opt_type=[
+ OptimizationSettings("pruning", 0.5, None),
+ ],
+ opt_diffs={
+ "sram": PerfMetricDiff(1.0, 1.0),
+ "dram": PerfMetricDiff(2.0, 2.0),
+ "on_chip_flash": PerfMetricDiff(4.0, 4.0),
+ "off_chip_flash": PerfMetricDiff(5.0, 5.0),
+ "npu_total_cycles": PerfMetricDiff(3, 3),
+ },
+ )
+ ]
+ )
+ ],
+ ],
+ [
+ OptimizationPerformanceMetrics(
+ PerformanceMetrics(
+ EthosUConfiguration("ethos-u55-256"),
+ NPUCycles(1, 2, 3, 4, 5, 6),
+ # memory metrics are in kilobytes
+ MemoryUsage(*[i * 1024 for i in range(1, 6)]), # type: ignore
+ ),
+ [],
+ ),
+ [],
+ ],
+ ],
+)
+def test_ethos_u_data_analyzer(
+ input_data: DataItem, expected_facts: List[Fact]
+) -> None:
+ """Test Ethos-U data analyzer."""
+ analyzer = EthosUDataAnalyzer()
+ analyzer.analyze_data(input_data)
+ assert analyzer.get_analyzed_data() == expected_facts
diff --git a/tests/mlia/test_devices_ethosu_data_collection.py b/tests/mlia/test_devices_ethosu_data_collection.py
new file mode 100644
index 0000000..897cf41
--- /dev/null
+++ b/tests/mlia/test_devices_ethosu_data_collection.py
@@ -0,0 +1,151 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Tests for the data collection module for Ethos-U."""
+from pathlib import Path
+from unittest.mock import MagicMock
+
+import pytest
+
+from mlia.core.context import Context
+from mlia.core.data_collection import DataCollector
+from mlia.core.errors import FunctionalityNotSupportedError
+from mlia.devices.ethosu.config import EthosUConfiguration
+from mlia.devices.ethosu.data_collection import EthosUOperatorCompatibility
+from mlia.devices.ethosu.data_collection import EthosUOptimizationPerformance
+from mlia.devices.ethosu.data_collection import EthosUPerformance
+from mlia.devices.ethosu.performance import MemoryUsage
+from mlia.devices.ethosu.performance import NPUCycles
+from mlia.devices.ethosu.performance import OptimizationPerformanceMetrics
+from mlia.devices.ethosu.performance import PerformanceMetrics
+from mlia.nn.tensorflow.optimizations.select import OptimizationSettings
+from mlia.tools.vela_wrapper import Operators
+
+
+@pytest.mark.parametrize(
+ "collector, expected_name",
+ [
+ (
+ EthosUOperatorCompatibility,
+ "ethos_u_operator_compatibility",
+ ),
+ (
+ EthosUPerformance,
+ "ethos_u_performance",
+ ),
+ (
+ EthosUOptimizationPerformance,
+ "ethos_u_model_optimizations",
+ ),
+ ],
+)
+def test_collectors_metadata(
+ collector: DataCollector,
+ expected_name: str,
+) -> None:
+ """Test collectors metadata."""
+ assert collector.name() == expected_name
+
+
+def test_operator_compatibility_collector(
+ dummy_context: Context, test_tflite_model: Path
+) -> None:
+ """Test operator compatibility data collector."""
+ device = EthosUConfiguration("ethos-u55-256")
+
+ collector = EthosUOperatorCompatibility(test_tflite_model, device)
+ collector.set_context(dummy_context)
+
+ result = collector.collect_data()
+ assert isinstance(result, Operators)
+
+
+def test_performance_collector(
+ monkeypatch: pytest.MonkeyPatch, dummy_context: Context, test_tflite_model: Path
+) -> None:
+ """Test performance data collector."""
+ device = EthosUConfiguration("ethos-u55-256")
+
+ mock_performance_estimation(monkeypatch, device)
+
+ collector = EthosUPerformance(test_tflite_model, device)
+ collector.set_context(dummy_context)
+
+ result = collector.collect_data()
+ assert isinstance(result, PerformanceMetrics)
+
+
+def test_optimization_performance_collector(
+ monkeypatch: pytest.MonkeyPatch,
+ dummy_context: Context,
+ test_keras_model: Path,
+ test_tflite_model: Path,
+) -> None:
+ """Test optimization performance data collector."""
+ device = EthosUConfiguration("ethos-u55-256")
+
+ mock_performance_estimation(monkeypatch, device)
+ collector = EthosUOptimizationPerformance(
+ test_keras_model,
+ device,
+ [
+ [
+ {"optimization_type": "pruning", "optimization_target": 0.5},
+ ]
+ ],
+ )
+ collector.set_context(dummy_context)
+ result = collector.collect_data()
+
+ assert isinstance(result, OptimizationPerformanceMetrics)
+ assert isinstance(result.original_perf_metrics, PerformanceMetrics)
+ assert isinstance(result.optimizations_perf_metrics, list)
+ assert len(result.optimizations_perf_metrics) == 1
+
+ opt, metrics = result.optimizations_perf_metrics[0]
+ assert opt == [OptimizationSettings("pruning", 0.5, None)]
+ assert isinstance(metrics, PerformanceMetrics)
+
+ collector_no_optimizations = EthosUOptimizationPerformance(
+ test_keras_model,
+ device,
+ [],
+ )
+ with pytest.raises(FunctionalityNotSupportedError):
+ collector_no_optimizations.collect_data()
+
+ collector_tflite = EthosUOptimizationPerformance(
+ test_tflite_model,
+ device,
+ [
+ [
+ {"optimization_type": "pruning", "optimization_target": 0.5},
+ ]
+ ],
+ )
+ collector_tflite.set_context(dummy_context)
+ with pytest.raises(FunctionalityNotSupportedError):
+ collector_tflite.collect_data()
+
+ with pytest.raises(
+ Exception, match="Optimization parameters expected to be a list"
+ ):
+ collector_bad_config = EthosUOptimizationPerformance(
+ test_keras_model, device, {"optimization_type": "pruning"} # type: ignore
+ )
+ collector_bad_config.set_context(dummy_context)
+ collector_bad_config.collect_data()
+
+
+def mock_performance_estimation(
+ monkeypatch: pytest.MonkeyPatch, device: EthosUConfiguration
+) -> None:
+ """Mock performance estimation."""
+ metrics = PerformanceMetrics(
+ device,
+ NPUCycles(1, 2, 3, 4, 5, 6),
+ MemoryUsage(1, 2, 3, 4, 5),
+ )
+ monkeypatch.setattr(
+ "mlia.devices.ethosu.data_collection.EthosUPerformanceEstimator.estimate",
+ MagicMock(return_value=metrics),
+ )
diff --git a/tests/mlia/test_devices_ethosu_performance.py b/tests/mlia/test_devices_ethosu_performance.py
new file mode 100644
index 0000000..e27efa0
--- /dev/null
+++ b/tests/mlia/test_devices_ethosu_performance.py
@@ -0,0 +1,28 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Performance estimation tests."""
+from unittest.mock import MagicMock
+
+import pytest
+
+from mlia.devices.ethosu.performance import MemorySizeType
+from mlia.devices.ethosu.performance import MemoryUsage
+
+
+def test_memory_usage_conversion() -> None:
+ """Test MemoryUsage objects conversion."""
+ memory_usage_in_kb = MemoryUsage(1, 2, 3, 4, 5, MemorySizeType.KILOBYTES)
+ assert memory_usage_in_kb.in_kilobytes() == memory_usage_in_kb
+
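+ # without an explicit MemorySizeType the values are treated as bytes,
+ # so in_kilobytes() divides them by 1024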
+ memory_usage_in_bytes = MemoryUsage(
+ 1 * 1024, 2 * 1024, 3 * 1024, 4 * 1024, 5 * 1024
+ )
+ assert memory_usage_in_bytes.in_kilobytes() == memory_usage_in_kb
+
+
+def mock_performance_estimation(monkeypatch: pytest.MonkeyPatch) -> None:
+ """Mock performance estimation."""
+ monkeypatch.setattr(
+ "mlia.tools.aiet_wrapper.estimate_performance",
+ MagicMock(return_value=MagicMock()),
+ )
diff --git a/tests/mlia/test_devices_ethosu_reporters.py b/tests/mlia/test_devices_ethosu_reporters.py
new file mode 100644
index 0000000..2d5905c
--- /dev/null
+++ b/tests/mlia/test_devices_ethosu_reporters.py
@@ -0,0 +1,434 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Tests for reports module."""
+import json
+import sys
+from contextlib import ExitStack as does_not_raise
+from pathlib import Path
+from typing import Any
+from typing import Callable
+from typing import Dict
+from typing import List
+from typing import Literal
+
+import pytest
+
+from mlia.core.reporting import get_reporter
+from mlia.core.reporting import produce_report
+from mlia.core.reporting import Report
+from mlia.core.reporting import Reporter
+from mlia.core.reporting import Table
+from mlia.devices.ethosu.config import EthosUConfiguration
+from mlia.devices.ethosu.performance import MemoryUsage
+from mlia.devices.ethosu.performance import NPUCycles
+from mlia.devices.ethosu.performance import PerformanceMetrics
+from mlia.devices.ethosu.reporters import find_appropriate_formatter
+from mlia.devices.ethosu.reporters import report_device_details
+from mlia.devices.ethosu.reporters import report_operators
+from mlia.devices.ethosu.reporters import report_perf_metrics
+from mlia.tools.vela_wrapper import NpuSupported
+from mlia.tools.vela_wrapper import Operator
+from mlia.tools.vela_wrapper import Operators
+from mlia.utils.console import remove_ascii_codes
+
+
+@pytest.mark.parametrize(
+ "data, formatters",
+ [
+ (
+ [Operator("test_operator", "test_type", NpuSupported(False, []))],
+ [report_operators],
+ ),
+ (
+ PerformanceMetrics(
+ EthosUConfiguration("ethos-u55-256"),
+ NPUCycles(0, 0, 0, 0, 0, 0),
+ MemoryUsage(0, 0, 0, 0, 0),
+ ),
+ [report_perf_metrics],
+ ),
+ ],
+)
+@pytest.mark.parametrize(
+ "fmt, output, expected_error",
+ [
+ [
+ "unknown_format",
+ sys.stdout,
+ pytest.raises(Exception, match="Unknown format unknown_format"),
+ ],
+ [
+ "plain_text",
+ sys.stdout,
+ does_not_raise(),
+ ],
+ [
+ "json",
+ sys.stdout,
+ does_not_raise(),
+ ],
+ [
+ "csv",
+ sys.stdout,
+ does_not_raise(),
+ ],
+ [
+ "plain_text",
+ "report.txt",
+ does_not_raise(),
+ ],
+ [
+ "json",
+ "report.json",
+ does_not_raise(),
+ ],
+ [
+ "csv",
+ "report.csv",
+ does_not_raise(),
+ ],
+ ],
+)
+def test_report(
+ data: Any,
+ formatters: List[Callable],
+ fmt: Literal["plain_text", "json", "csv"],
+ output: Any,
+ expected_error: Any,
+ tmp_path: Path,
+) -> None:
+ """Test report function."""
+ if is_file := isinstance(output, str):
+ output = tmp_path / output
+
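+ # string outputs were converted to paths above, so the report goes to
+ # a file; otherwise it is written to the stream (sys.stdout here)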
+ for formatter in formatters:
+ with expected_error:
+ produce_report(data, formatter, fmt, output)
+
+ if is_file:
+ assert output.is_file()
+ assert output.stat().st_size > 0
+
+
+@pytest.mark.parametrize(
+ "ops, expected_plain_text, expected_json_dict, expected_csv_list",
+ [
+ (
+ [
+ Operator(
+ "npu_supported",
+ "test_type",
+ NpuSupported(True, []),
+ ),
+ Operator(
+ "cpu_only",
+ "test_type",
+ NpuSupported(
+ False,
+ [
+ (
+ "CPU only operator",
+ "",
+ ),
+ ],
+ ),
+ ),
+ Operator(
+ "npu_unsupported",
+ "test_type",
+ NpuSupported(
+ False,
+ [
+ (
+ "Not supported operator",
+ "Reason why operator is not supported",
+ )
+ ],
+ ),
+ ),
+ ],
+ """
+Operators:
+┌───┬─────────────────┬───────────────┬───────────┬───────────────────────────────┐
+│ # │ Operator name │ Operator type │ Placement │ Notes │
+╞═══╪═════════════════╪═══════════════╪═══════════╪═══════════════════════════════╡
+│ 1 │ npu_supported │ test_type │ NPU │ │
+├───┼─────────────────┼───────────────┼───────────┼───────────────────────────────┤
+│ 2 │ cpu_only │ test_type │ CPU │ * CPU only operator │
+├───┼─────────────────┼───────────────┼───────────┼───────────────────────────────┤
+│ 3 │ npu_unsupported │ test_type │ CPU │ * Not supported operator │
+│ │ │ │ │ │
+│ │ │ │ │ * Reason why operator is not │
+│ │ │ │ │ supported │
+└───┴─────────────────┴───────────────┴───────────┴───────────────────────────────┘
+""".strip(),
+ {
+ "operators": [
+ {
+ "operator_name": "npu_supported",
+ "operator_type": "test_type",
+ "placement": "NPU",
+ "notes": [],
+ },
+ {
+ "operator_name": "cpu_only",
+ "operator_type": "test_type",
+ "placement": "CPU",
+ "notes": [{"note": "CPU only operator"}],
+ },
+ {
+ "operator_name": "npu_unsupported",
+ "operator_type": "test_type",
+ "placement": "CPU",
+ "notes": [
+ {"note": "Not supported operator"},
+ {"note": "Reason why operator is not supported"},
+ ],
+ },
+ ]
+ },
+ [
+ ["Operator name", "Operator type", "Placement", "Notes"],
+ ["npu_supported", "test_type", "NPU", ""],
+ ["cpu_only", "test_type", "CPU", "CPU only operator"],
+ [
+ "npu_unsupported",
+ "test_type",
+ "CPU",
+ "Not supported operator;Reason why operator is not supported",
+ ],
+ ],
+ ),
+ ],
+)
+def test_report_operators(
+ ops: List[Operator],
+ expected_plain_text: str,
+ expected_json_dict: Dict,
+ expected_csv_list: List,
+ monkeypatch: pytest.MonkeyPatch,
+) -> None:
+ """Test report_operatos formatter."""
+ # make terminal wide enough to print whole table
+ monkeypatch.setenv("COLUMNS", "100")
+
+ report = report_operators(ops)
+ assert isinstance(report, Table)
+
+ plain_text = remove_ascii_codes(report.to_plain_text())
+ assert plain_text == expected_plain_text
+
+ json_dict = report.to_json()
+ assert json_dict == expected_json_dict
+
+ csv_list = report.to_csv()
+ assert csv_list == expected_csv_list
+
+
+@pytest.mark.parametrize(
+ "device, expected_plain_text, expected_json_dict, expected_csv_list",
+ [
+ [
+ EthosUConfiguration("ethos-u55-256"),
+ """Device information:
+ Target ethos-u55
+ MAC 256
+
+ Memory mode Shared_Sram
+ Const mem area Axi1
+ Arena mem area Axi0
+ Cache mem area Axi0
+ Arena cache size 4,294,967,296 bytes
+
+ System config Ethos_U55_High_End_Embedded
+ Accelerator clock 500,000,000 Hz
+ AXI0 port Sram
+ AXI1 port OffChipFlash
+
+ Memory area settings:
+ Sram:
+ Clock scales 1.0
+ Burst length 32 bytes
+ Read latency 32 cycles
+ Write latency 32 cycles
+
+ Dram:
+ Clock scales 1.0
+ Burst length 1 byte
+ Read latency 0 cycles
+ Write latency 0 cycles
+
+ OnChipFlash:
+ Clock scales 1.0
+ Burst length 1 byte
+ Read latency 0 cycles
+ Write latency 0 cycles
+
+ OffChipFlash:
+ Clock scales 0.125
+ Burst length 128 bytes
+ Read latency 64 cycles
+ Write latency 64 cycles
+
+ Architecture settings:
+ Permanent storage mem area OffChipFlash
+ Feature map storage mem area Sram
+ Fast storage mem area Sram""",
+ {
+ "device": {
+ "target": "ethos-u55",
+ "mac": 256,
+ "memory_mode": {
+ "const_mem_area": "Axi1",
+ "arena_mem_area": "Axi0",
+ "cache_mem_area": "Axi0",
+ "arena_cache_size": {"value": 4294967296, "unit": "bytes"},
+ },
+ "system_config": {
+ "accelerator_clock": {"value": 500000000.0, "unit": "Hz"},
+ "axi0_port": "Sram",
+ "axi1_port": "OffChipFlash",
+ "memory_area": {
+ "Sram": {
+ "clock_scales": 1.0,
+ "burst_length": {"value": 32, "unit": "bytes"},
+ "read_latency": {"value": 32, "unit": "cycles"},
+ "write_latency": {"value": 32, "unit": "cycles"},
+ },
+ "Dram": {
+ "clock_scales": 1.0,
+ "burst_length": {"value": 1, "unit": "byte"},
+ "read_latency": {"value": 0, "unit": "cycles"},
+ "write_latency": {"value": 0, "unit": "cycles"},
+ },
+ "OnChipFlash": {
+ "clock_scales": 1.0,
+ "burst_length": {"value": 1, "unit": "byte"},
+ "read_latency": {"value": 0, "unit": "cycles"},
+ "write_latency": {"value": 0, "unit": "cycles"},
+ },
+ "OffChipFlash": {
+ "clock_scales": 0.125,
+ "burst_length": {"value": 128, "unit": "bytes"},
+ "read_latency": {"value": 64, "unit": "cycles"},
+ "write_latency": {"value": 64, "unit": "cycles"},
+ },
+ },
+ },
+ "arch_settings": {
+ "permanent_storage_mem_area": "OffChipFlash",
+ "feature_map_storage_mem_area": "Sram",
+ "fast_storage_mem_area": "Sram",
+ },
+ }
+ },
+ [
+ (
+ "target",
+ "mac",
+ "memory_mode",
+ "const_mem_area",
+ "arena_mem_area",
+ "cache_mem_area",
+ "arena_cache_size_value",
+ "arena_cache_size_unit",
+ "system_config",
+ "accelerator_clock_value",
+ "accelerator_clock_unit",
+ "axi0_port",
+ "axi1_port",
+ "clock_scales",
+ "burst_length_value",
+ "burst_length_unit",
+ "read_latency_value",
+ "read_latency_unit",
+ "write_latency_value",
+ "write_latency_unit",
+ "permanent_storage_mem_area",
+ "feature_map_storage_mem_area",
+ "fast_storage_mem_area",
+ ),
+ (
+ "ethos-u55",
+ 256,
+ "Shared_Sram",
+ "Axi1",
+ "Axi0",
+ "Axi0",
+ 4294967296,
+ "bytes",
+ "Ethos_U55_High_End_Embedded",
+ 500000000.0,
+ "Hz",
+ "Sram",
+ "OffChipFlash",
+ 0.125,
+ 128,
+ "bytes",
+ 64,
+ "cycles",
+ 64,
+ "cycles",
+ "OffChipFlash",
+ "Sram",
+ "Sram",
+ ),
+ ],
+ ],
+ ],
+)
+def test_report_device_details(
+ device: EthosUConfiguration,
+ expected_plain_text: str,
+ expected_json_dict: Dict,
+ expected_csv_list: List,
+) -> None:
+ """Test report_operatos formatter."""
+ report = report_device_details(device)
+ assert isinstance(report, Report)
+
+ plain_text = report.to_plain_text()
+ assert plain_text == expected_plain_text
+
+ json_dict = report.to_json()
+ assert json_dict == expected_json_dict
+
+ csv_list = report.to_csv()
+ assert csv_list == expected_csv_list
+
+
+def test_get_reporter(tmp_path: Path) -> None:
+ """Test reporter functionality."""
+ ops = Operators(
+ [
+ Operator(
+ "npu_supported",
+ "op_type",
+ NpuSupported(True, []),
+ ),
+ ]
+ )
+
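+ # the reporter buffers submitted data and writes the report when the
+ # context manager exits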
+ output = tmp_path / "output.json"
+ with get_reporter("json", output, find_appropriate_formatter) as reporter:
+ assert isinstance(reporter, Reporter)
+
+ with pytest.raises(
+ Exception, match="Unable to find appropriate formatter for some_data"
+ ):
+ reporter.submit("some_data")
+
+ reporter.submit(ops)
+
+ with open(output, encoding="utf-8") as file:
+ json_data = json.load(file)
+
+ assert json_data == {
+ "operators_stats": [
+ {
+ "npu_unsupported_ratio": 0.0,
+ "num_of_npu_supported_operators": 1,
+ "num_of_operators": 1,
+ }
+ ]
+ }
diff --git a/tests/mlia/test_nn_tensorflow_config.py b/tests/mlia/test_nn_tensorflow_config.py
new file mode 100644
index 0000000..1ac9f97
--- /dev/null
+++ b/tests/mlia/test_nn_tensorflow_config.py
@@ -0,0 +1,72 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Tests for config module."""
+from contextlib import ExitStack as does_not_raise
+from pathlib import Path
+from typing import Any
+
+import pytest
+
+from mlia.nn.tensorflow.config import get_model
+from mlia.nn.tensorflow.config import KerasModel
+from mlia.nn.tensorflow.config import TFLiteModel
+from mlia.nn.tensorflow.config import TfModel
+
+
+def test_convert_keras_to_tflite(tmp_path: Path, test_keras_model: Path) -> None:
+ """Test Keras to TFLite conversion."""
+ keras_model = KerasModel(test_keras_model)
+
+ tflite_model_path = tmp_path / "test.tflite"
+ keras_model.convert_to_tflite(tflite_model_path)
+
+ assert tflite_model_path.is_file()
+ assert tflite_model_path.stat().st_size > 0
+
+
+def test_convert_tf_to_tflite(tmp_path: Path, test_tf_model: Path) -> None:
+ """Test TensorFlow saved model to TFLite conversion."""
+ tf_model = TfModel(test_tf_model)
+
+ tflite_model_path = tmp_path / "test.tflite"
+ tf_model.convert_to_tflite(tflite_model_path)
+
+ assert tflite_model_path.is_file()
+ assert tflite_model_path.stat().st_size > 0
+
+
+@pytest.mark.parametrize(
+ "model_path, expected_type, expected_error",
+ [
+ ("test.tflite", TFLiteModel, does_not_raise()),
+ ("test.h5", KerasModel, does_not_raise()),
+ ("test.hdf5", KerasModel, does_not_raise()),
+ (
+ "test.model",
+ None,
+ pytest.raises(
+ Exception,
+ match="The input model format is not supported"
+ r"\(supported formats: TFLite, Keras, TensorFlow saved model\)!",
+ ),
+ ),
+ ],
+)
+def test_get_model_file(
+ model_path: str, expected_type: type, expected_error: Any
+) -> None:
+ """Test TFLite model type."""
+ with expected_error:
+ model = get_model(model_path)
+ assert isinstance(model, expected_type)
+
+
+@pytest.mark.parametrize(
+ "model_path, expected_type", [("tf_model_test_model", TfModel)]
+)
+def test_get_model_dir(
+ test_models_path: Path, model_path: str, expected_type: type
+) -> None:
+ """Test TFLite model type."""
+ model = get_model(str(test_models_path / model_path))
+ assert isinstance(model, expected_type)
diff --git a/tests/mlia/test_nn_tensorflow_optimizations_clustering.py b/tests/mlia/test_nn_tensorflow_optimizations_clustering.py
new file mode 100644
index 0000000..9bcf918
--- /dev/null
+++ b/tests/mlia/test_nn_tensorflow_optimizations_clustering.py
@@ -0,0 +1,131 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Test for module optimizations/clustering."""
+from pathlib import Path
+from typing import List
+from typing import Optional
+
+import pytest
+import tensorflow as tf
+
+from mlia.nn.tensorflow.optimizations.clustering import Clusterer
+from mlia.nn.tensorflow.optimizations.clustering import ClusteringConfiguration
+from mlia.nn.tensorflow.optimizations.pruning import Pruner
+from mlia.nn.tensorflow.optimizations.pruning import PruningConfiguration
+from mlia.nn.tensorflow.tflite_metrics import ReportClusterMode
+from mlia.nn.tensorflow.tflite_metrics import TFLiteMetrics
+from mlia.nn.tensorflow.utils import convert_to_tflite
+from mlia.nn.tensorflow.utils import save_tflite_model
+from tests.mlia.utils.common import get_dataset
+from tests.mlia.utils.common import train_model
+
+
+def _prune_model(
+ model: tf.keras.Model, target_sparsity: float, layers_to_prune: Optional[List[str]]
+) -> tf.keras.Model:
+ x_train, y_train = get_dataset()
+ batch_size = 1
+ num_epochs = 1
+
+ pruner = Pruner(
+ model,
+ PruningConfiguration(
+ target_sparsity,
+ layers_to_prune,
+ x_train,
+ y_train,
+ batch_size,
+ num_epochs,
+ ),
+ )
+ pruner.apply_optimization()
+ pruned_model = pruner.get_model()
+
+ return pruned_model
+
+
+def _test_num_unique_weights(
+ metrics: TFLiteMetrics,
+ target_num_clusters: int,
+ layers_to_cluster: Optional[List[str]],
+) -> None:
+ clustered_uniqueness_dict = metrics.num_unique_weights(
+ ReportClusterMode.NUM_CLUSTERS_PER_AXIS
+ )
+ num_clustered_layers = 0
+ num_optimizable_layers = len(clustered_uniqueness_dict)
+ if layers_to_cluster:
+ expected_num_clustered_layers = len(layers_to_cluster)
+ else:
+ expected_num_clustered_layers = num_optimizable_layers
+ for layer_name in clustered_uniqueness_dict:
+ # The +1 is a temporary workaround for a bug that has been fixed
+ # but whose fix has not been merged yet; remove it once the fix
+ # is in.
+ if clustered_uniqueness_dict[layer_name][0] <= (target_num_clusters + 1):
+ num_clustered_layers = num_clustered_layers + 1
+ # make sure exactly as many layers are clustered as requested
+ assert num_clustered_layers == expected_num_clustered_layers
+
+
+def _test_sparsity(
+ metrics: TFLiteMetrics,
+ target_sparsity: float,
+ layers_to_cluster: Optional[List[str]],
+) -> None:
+ pruned_sparsity_dict = metrics.sparsity_per_layer()
+ num_sparse_layers = 0
+ num_optimizable_layers = len(pruned_sparsity_dict)
+ error_margin = 0.03
+ if layers_to_cluster:
+ expected_num_sparse_layers = len(layers_to_cluster)
+ else:
+ expected_num_sparse_layers = num_optimizable_layers
+ for layer_name in pruned_sparsity_dict:
+ if abs(pruned_sparsity_dict[layer_name] - target_sparsity) < error_margin:
+ num_sparse_layers = num_sparse_layers + 1
+ # make sure exactly as many layers are sparse as requested
+ assert num_sparse_layers == expected_num_sparse_layers
+
+
+@pytest.mark.skip(reason="Test fails randomly, further investigation is needed")
+@pytest.mark.parametrize("target_num_clusters", (32, 4))
+@pytest.mark.parametrize("sparsity_aware", (False, True))
+@pytest.mark.parametrize("layers_to_cluster", (["conv1"], ["conv1", "conv2"], None))
+def test_cluster_simple_model_fully(
+ target_num_clusters: int,
+ sparsity_aware: bool,
+ layers_to_cluster: Optional[List[str]],
+ tmp_path: Path,
+ test_keras_model: Path,
+) -> None:
+ """Simple MNIST test to see if clustering works correctly."""
+ target_sparsity = 0.5
+
+ base_model = tf.keras.models.load_model(str(test_keras_model))
+ train_model(base_model)
+
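+ # when sparsity_aware is set, prune first so that clustering has to
+ # preserve the sparsity introduced by pruning (checked at the end)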
+ if sparsity_aware:
+ base_model = _prune_model(base_model, target_sparsity, layers_to_cluster)
+
+ clusterer = Clusterer(
+ base_model,
+ ClusteringConfiguration(
+ target_num_clusters,
+ layers_to_cluster,
+ ),
+ )
+ clusterer.apply_optimization()
+ clustered_model = clusterer.get_model()
+
+ temp_file = tmp_path / "test_cluster_simple_model_fully_after.tflite"
+ tflite_clustered_model = convert_to_tflite(clustered_model)
+ save_tflite_model(tflite_clustered_model, temp_file)
+ clustered_tflite_metrics = TFLiteMetrics(str(temp_file))
+
+ _test_num_unique_weights(
+ clustered_tflite_metrics, target_num_clusters, layers_to_cluster
+ )
+
+ if sparsity_aware:
+ _test_sparsity(clustered_tflite_metrics, target_sparsity, layers_to_cluster)
diff --git a/tests/mlia/test_nn_tensorflow_optimizations_pruning.py b/tests/mlia/test_nn_tensorflow_optimizations_pruning.py
new file mode 100644
index 0000000..64030a6
--- /dev/null
+++ b/tests/mlia/test_nn_tensorflow_optimizations_pruning.py
@@ -0,0 +1,117 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Test for module optimizations/pruning."""
+from pathlib import Path
+from typing import List
+from typing import Optional
+
+import pytest
+import tensorflow as tf
+from numpy.core.numeric import isclose
+
+from mlia.nn.tensorflow.optimizations.pruning import Pruner
+from mlia.nn.tensorflow.optimizations.pruning import PruningConfiguration
+from mlia.nn.tensorflow.tflite_metrics import TFLiteMetrics
+from mlia.nn.tensorflow.utils import convert_to_tflite
+from mlia.nn.tensorflow.utils import save_tflite_model
+from tests.mlia.utils.common import get_dataset
+from tests.mlia.utils.common import train_model
+
+
+def _test_sparsity(
+ metrics: TFLiteMetrics,
+ target_sparsity: float,
+ layers_to_prune: Optional[List[str]],
+) -> None:
+ pruned_sparsity_dict = metrics.sparsity_per_layer()
+ num_sparse_layers = 0
+ num_optimizable_layers = len(pruned_sparsity_dict)
+ error_margin = 0.03
+ if layers_to_prune:
+ expected_num_sparse_layers = len(layers_to_prune)
+ else:
+ expected_num_sparse_layers = num_optimizable_layers
+ for layer_name in pruned_sparsity_dict:
+ if abs(pruned_sparsity_dict[layer_name] - target_sparsity) < error_margin:
+ num_sparse_layers = num_sparse_layers + 1
+ # make sure exactly as many layers are sparse as requested
+ assert num_sparse_layers == expected_num_sparse_layers
+
+
+def _test_check_sparsity(base_tflite_metrics: TFLiteMetrics) -> None:
+ """Assert the sparsity of a model is zero."""
+ base_sparsity_dict = base_tflite_metrics.sparsity_per_layer()
+ for layer_name, sparsity in base_sparsity_dict.items():
+ assert isclose(
+ sparsity, 0, atol=1e-2
+ ), f"Sparsity for layer '{layer_name}' is {sparsity}, but should be zero."
+
+
+def _get_tflite_metrics(
+ path: Path, tflite_fn: str, model: tf.keras.Model
+) -> TFLiteMetrics:
+ """Save model as TFLiteModel and return metrics."""
+ temp_file = path / tflite_fn
+ save_tflite_model(convert_to_tflite(model), temp_file)
+ return TFLiteMetrics(str(temp_file))
+
+
+@pytest.mark.parametrize("target_sparsity", (0.5, 0.9))
+@pytest.mark.parametrize("mock_data", (False, True))
+@pytest.mark.parametrize("layers_to_prune", (["conv1"], ["conv1", "conv2"], None))
+def test_prune_simple_model_fully(
+ target_sparsity: float,
+ mock_data: bool,
+ layers_to_prune: Optional[List[str]],
+ tmp_path: Path,
+ test_keras_model: Path,
+) -> None:
+ """Simple MNIST test to see if pruning works correctly."""
+ x_train, y_train = get_dataset()
+ batch_size = 1
+ num_epochs = 1
+
+ base_model = tf.keras.models.load_model(str(test_keras_model))
+ train_model(base_model)
+
+ base_tflite_metrics = _get_tflite_metrics(
+ path=tmp_path,
+ tflite_fn="test_prune_simple_model_fully_before.tflite",
+ model=base_model,
+ )
+
+ # Make sure sparsity is zero before pruning
+ _test_check_sparsity(base_tflite_metrics)
+
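+ # with mock_data the pruner gets no training data, exercising the
+ # path where pruning runs without a fine-tuning dataset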
+ if mock_data:
+ pruner = Pruner(
+ base_model,
+ PruningConfiguration(
+ target_sparsity,
+ layers_to_prune,
+ ),
+ )
+
+ else:
+ pruner = Pruner(
+ base_model,
+ PruningConfiguration(
+ target_sparsity,
+ layers_to_prune,
+ x_train,
+ y_train,
+ batch_size,
+ num_epochs,
+ ),
+ )
+
+ pruner.apply_optimization()
+ pruned_model = pruner.get_model()
+
+ pruned_tflite_metrics = _get_tflite_metrics(
+ path=tmp_path,
+ tflite_fn="test_prune_simple_model_fully_after.tflite",
+ model=pruned_model,
+ )
+
+ _test_sparsity(pruned_tflite_metrics, target_sparsity, layers_to_prune)
diff --git a/tests/mlia/test_nn_tensorflow_optimizations_select.py b/tests/mlia/test_nn_tensorflow_optimizations_select.py
new file mode 100644
index 0000000..5cac8ba
--- /dev/null
+++ b/tests/mlia/test_nn_tensorflow_optimizations_select.py
@@ -0,0 +1,240 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Tests for module select."""
+from contextlib import ExitStack as does_not_raise
+from pathlib import Path
+from typing import Any
+from typing import List
+from typing import Tuple
+
+import pytest
+import tensorflow as tf
+
+from mlia.nn.tensorflow.optimizations.clustering import Clusterer
+from mlia.nn.tensorflow.optimizations.clustering import ClusteringConfiguration
+from mlia.nn.tensorflow.optimizations.pruning import Pruner
+from mlia.nn.tensorflow.optimizations.pruning import PruningConfiguration
+from mlia.nn.tensorflow.optimizations.select import get_optimizer
+from mlia.nn.tensorflow.optimizations.select import MultiStageOptimizer
+from mlia.nn.tensorflow.optimizations.select import OptimizationSettings
+
+
+@pytest.mark.parametrize(
+ "config, expected_error, expected_type, expected_config",
+ [
+ (
+ OptimizationSettings(
+ optimization_type="pruning",
+ optimization_target=0.5,
+ layers_to_optimize=None,
+ ),
+ does_not_raise(),
+ Pruner,
+ "pruning: 0.5",
+ ),
+ (
+ PruningConfiguration(0.5),
+ does_not_raise(),
+ Pruner,
+ "pruning: 0.5",
+ ),
+ (
+ OptimizationSettings(
+ optimization_type="clustering",
+ optimization_target=32,
+ layers_to_optimize=None,
+ ),
+ does_not_raise(),
+ Clusterer,
+ "clustering: 32",
+ ),
+ (
+ OptimizationSettings(
+ optimization_type="clustering",
+ optimization_target=0.5,
+ layers_to_optimize=None,
+ ),
+ pytest.raises(
+ Exception,
+ match="Optimization target should be a "
+ "positive integer. "
+ "Optimization target provided: 0.5",
+ ),
+ None,
+ None,
+ ),
+ (
+ ClusteringConfiguration(32),
+ does_not_raise(),
+ Clusterer,
+ "clustering: 32",
+ ),
+ (
+ OptimizationSettings(
+ optimization_type="superoptimization",
+ optimization_target="supertarget", # type: ignore
+ layers_to_optimize="all", # type: ignore
+ ),
+ pytest.raises(
+ Exception,
+ match="Unsupported optimization type: superoptimization",
+ ),
+ None,
+ None,
+ ),
+ (
+ OptimizationSettings(
+ optimization_type="",
+ optimization_target=0.5,
+ layers_to_optimize=None,
+ ),
+ pytest.raises(
+ Exception,
+ match="Optimization type is not provided",
+ ),
+ None,
+ None,
+ ),
+ (
+ "wrong_config",
+ pytest.raises(
+ Exception,
+ match="Unknown optimization configuration wrong_config",
+ ),
+ None,
+ None,
+ ),
+ (
+ OptimizationSettings(
+ optimization_type="pruning",
+ optimization_target=None, # type: ignore
+ layers_to_optimize=None,
+ ),
+ pytest.raises(
+ Exception,
+ match="Optimization target is not provided",
+ ),
+ None,
+ None,
+ ),
+ (
+ [
+ OptimizationSettings(
+ optimization_type="pruning",
+ optimization_target=0.5,
+ layers_to_optimize=None,
+ ),
+ OptimizationSettings(
+ optimization_type="clustering",
+ optimization_target=32,
+ layers_to_optimize=None,
+ ),
+ ],
+ does_not_raise(),
+ MultiStageOptimizer,
+ "pruning: 0.5 - clustering: 32",
+ ),
+ ],
+)
+def test_get_optimizer(
+ config: Any,
+ expected_error: Any,
+ expected_type: type,
+ expected_config: str,
+ test_keras_model: Path,
+) -> None:
+ """Test function get_optimzer."""
+ model = tf.keras.models.load_model(str(test_keras_model))
+
+ with expected_error:
+ optimizer = get_optimizer(model, config)
+ assert isinstance(optimizer, expected_type)
+ assert optimizer.optimization_config() == expected_config
+
+
+@pytest.mark.parametrize(
+ "params, expected_result",
+ [
+ (
+ [],
+ [],
+ ),
+ (
+ [("pruning", 0.5)],
+ [
+ OptimizationSettings(
+ optimization_type="pruning",
+ optimization_target=0.5,
+ layers_to_optimize=None,
+ )
+ ],
+ ),
+ (
+ [("pruning", 0.5), ("clustering", 32)],
+ [
+ OptimizationSettings(
+ optimization_type="pruning",
+ optimization_target=0.5,
+ layers_to_optimize=None,
+ ),
+ OptimizationSettings(
+ optimization_type="clustering",
+ optimization_target=32,
+ layers_to_optimize=None,
+ ),
+ ],
+ ),
+ ],
+)
+def test_optimization_settings_create_from(
+ params: List[Tuple[str, float]], expected_result: List[OptimizationSettings]
+) -> None:
+ """Test creating settings from parsed params."""
+ assert OptimizationSettings.create_from(params) == expected_result
+
+
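+# next_target() steps the clustering target down (with a floor of 4) and the
+# pruning target up (with a cap of 0.9); at the bound the settings stay the same.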
+@pytest.mark.parametrize(
+ "settings, expected_next_target, expected_error",
+ [
+ [
+ OptimizationSettings("clustering", 32, None),
+ OptimizationSettings("clustering", 16, None),
+ does_not_raise(),
+ ],
+ [
+ OptimizationSettings("clustering", 4, None),
+ OptimizationSettings("clustering", 4, None),
+ does_not_raise(),
+ ],
+ [
+ OptimizationSettings("clustering", 10, None),
+ OptimizationSettings("clustering", 8, None),
+ does_not_raise(),
+ ],
+ [
+ OptimizationSettings("pruning", 0.5, None),
+ OptimizationSettings("pruning", 0.6, None),
+ does_not_raise(),
+ ],
+ [
+ OptimizationSettings("pruning", 0.9, None),
+ OptimizationSettings("pruning", 0.9, None),
+ does_not_raise(),
+ ],
+ [
+ OptimizationSettings("super_optimization", 42, None),
+ None,
+ pytest.raises(
+ Exception, match="Unknown optimization type super_optimization"
+ ),
+ ],
+ ],
+)
+def test_optimization_settings_next_target(
+ settings: OptimizationSettings,
+ expected_next_target: OptimizationSettings,
+ expected_error: Any,
+) -> None:
+ """Test getting next optimization target."""
+ with expected_error:
+ assert settings.next_target() == expected_next_target
diff --git a/tests/mlia/test_nn_tensorflow_tflite_metrics.py b/tests/mlia/test_nn_tensorflow_tflite_metrics.py
new file mode 100644
index 0000000..805f7d1
--- /dev/null
+++ b/tests/mlia/test_nn_tensorflow_tflite_metrics.py
@@ -0,0 +1,137 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Test for module utils/tflite_metrics."""
+import os
+import tempfile
+from math import isclose
+from pathlib import Path
+from typing import Generator
+from typing import List
+
+import numpy as np
+import pytest
+import tensorflow as tf
+
+from mlia.nn.tensorflow.tflite_metrics import ReportClusterMode
+from mlia.nn.tensorflow.tflite_metrics import TFLiteMetrics
+
+
+def _dummy_keras_model() -> tf.keras.Model:
+ # Create a dummy model
+ keras_model = tf.keras.Sequential(
+ [
+ tf.keras.Input(shape=(8, 8, 3)),
+ tf.keras.layers.Conv2D(4, 3),
+ tf.keras.layers.DepthwiseConv2D(3),
+ tf.keras.layers.Flatten(),
+ tf.keras.layers.Dense(8),
+ ]
+ )
+ return keras_model
+
+
+def _sparse_binary_keras_model() -> tf.keras.Model:
+ def get_sparse_weights(shape: List[int]) -> np.ndarray:
+ weights = np.zeros(shape)
+ with np.nditer(weights, op_flags=["writeonly"]) as weight_iterator:
+ for idx, value in enumerate(weight_iterator):
+ if idx % 2 == 0:
+ value[...] = 1.0
+ return weights
+
+ keras_model = _dummy_keras_model()
+ # Assign weights to have 0.5 sparsity
+ for layer in keras_model.layers:
+ if not isinstance(layer, tf.keras.layers.Flatten):
+ weight = layer.weights[0]
+ weight.assign(get_sparse_weights(weight.shape))
+ return keras_model
+
+
+@pytest.fixture(scope="class", name="tflite_file")
+def fixture_tflite_file() -> Generator:
+ """Generate temporary TFLite file for tests."""
+ converter = tf.lite.TFLiteConverter.from_keras_model(_sparse_binary_keras_model())
+ tflite_model = converter.convert()
+ with tempfile.TemporaryDirectory() as tmp_dir:
+ file = os.path.join(tmp_dir, "test.tflite")
+ Path(file).write_bytes(tflite_model)
+ yield file
+
+
+@pytest.fixture(scope="function", name="metrics")
+def fixture_metrics(tflite_file: str) -> TFLiteMetrics:
+ """Generate metrics file for a given TFLite model."""
+ return TFLiteMetrics(tflite_file)
+
+
+class TestTFLiteMetrics:
+ """Tests for module TFLite_metrics."""
+
+ @staticmethod
+ def test_sparsity(metrics: TFLiteMetrics) -> None:
+ """Test sparsity."""
+ # Create new instance with a dummy TFLite file
+ # Check sparsity calculation
+ sparsity_per_layer = metrics.sparsity_per_layer()
+ for name, sparsity in sparsity_per_layer.items():
+ assert isclose(sparsity, 0.5), "Layer '{}' has incorrect sparsity.".format(
+ name
+ )
+ assert isclose(metrics.sparsity_overall(), 0.5)
+
+ @staticmethod
+ def test_clusters(metrics: TFLiteMetrics) -> None:
+ """Test clusters."""
+ # NUM_CLUSTERS_PER_AXIS and NUM_CLUSTERS_MIN_MAX can be handled together
+ for mode in [
+ ReportClusterMode.NUM_CLUSTERS_PER_AXIS,
+ ReportClusterMode.NUM_CLUSTERS_MIN_MAX,
+ ]:
+ num_unique_weights = metrics.num_unique_weights(mode)
+ for name, num_unique_per_axis in num_unique_weights.items():
+ for num_unique in num_unique_per_axis:
+ assert (
+ num_unique == 2
+ ), "Layer '{}' has incorrect number of clusters.".format(name)
+ # NUM_CLUSTERS_HISTOGRAM
+ hists = metrics.num_unique_weights(ReportClusterMode.NUM_CLUSTERS_HISTOGRAM)
+ assert hists
+ for name, hist in hists.items():
+ assert hist
+ for idx, num_axes in enumerate(hist):
+ # The histogram starts with the bin for num_clusters == 1
+ num_clusters = idx + 1
+ msg = (
+ "Histogram of layer '{}': There are {} axes with {} "
+ "clusters".format(name, num_axes, num_clusters)
+ )
+ if num_clusters == 2:
+ assert num_axes > 0, "{}, but there should be at least one.".format(
+ msg
+ )
+ else:
+ assert num_axes == 0, "{}, but there should be none.".format(msg)
+
+ @staticmethod
+ @pytest.mark.parametrize("report_sparsity", (False, True))
+ @pytest.mark.parametrize("report_cluster_mode", ReportClusterMode)
+ @pytest.mark.parametrize("max_num_clusters", (-1, 8))
+ @pytest.mark.parametrize("verbose", (False, True))
+ def test_summary(
+ tflite_file: str,
+ report_sparsity: bool,
+ report_cluster_mode: ReportClusterMode,
+ max_num_clusters: int,
+ verbose: bool,
+ ) -> None:
+ """Test the summary function."""
+ for metrics in [TFLiteMetrics(tflite_file), TFLiteMetrics(tflite_file, [])]:
+ metrics.summary(
+ report_sparsity=report_sparsity,
+ report_cluster_mode=report_cluster_mode,
+ max_num_clusters=max_num_clusters,
+ verbose=verbose,
+ )
diff --git a/tests/mlia/test_nn_tensorflow_utils.py b/tests/mlia/test_nn_tensorflow_utils.py
new file mode 100644
index 0000000..6d27299
--- /dev/null
+++ b/tests/mlia/test_nn_tensorflow_utils.py
@@ -0,0 +1,81 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Test for module utils/test_utils."""
+from pathlib import Path
+
+import pytest
+import tensorflow as tf
+
+from mlia.nn.tensorflow.utils import convert_to_tflite
+from mlia.nn.tensorflow.utils import get_tf_tensor_shape
+from mlia.nn.tensorflow.utils import is_keras_model
+from mlia.nn.tensorflow.utils import is_tflite_model
+from mlia.nn.tensorflow.utils import save_keras_model
+from mlia.nn.tensorflow.utils import save_tflite_model
+
+
+def test_convert_to_tflite(test_keras_model: Path) -> None:
+ """Test converting Keras model to TFLite."""
+ keras_model = tf.keras.models.load_model(str(test_keras_model))
+ tflite_model = convert_to_tflite(keras_model)
+
+ assert tflite_model
+
+
+def test_save_keras_model(tmp_path: Path, test_keras_model: Path) -> None:
+ """Test saving Keras model."""
+ keras_model = tf.keras.models.load_model(str(test_keras_model))
+
+ temp_file = tmp_path / "test_model_saving.h5"
+ save_keras_model(keras_model, temp_file)
+ loaded_model = tf.keras.models.load_model(temp_file)
+
+ assert loaded_model.summary() == keras_model.summary()
+
+
+def test_save_tflite_model(tmp_path: Path, test_keras_model: Path) -> None:
+ """Test saving TFLite model."""
+ keras_model = tf.keras.models.load_model(str(test_keras_model))
+
+ tflite_model = convert_to_tflite(keras_model)
+
+ temp_file = tmp_path / "test_model_saving.tflite"
+ save_tflite_model(tflite_model, temp_file)
+
+ interpreter = tf.lite.Interpreter(model_path=str(temp_file))
+ assert interpreter
+
+
+@pytest.mark.parametrize(
+ "model_path, expected_result",
+ [
+ [Path("sample_model.tflite"), True],
+ [Path("strange_model.tflite.tfl"), False],
+ [Path("sample_model.h5"), False],
+ [Path("sample_model"), False],
+ ],
+)
+def test_is_tflite_model(model_path: Path, expected_result: bool) -> None:
+ """Test function is_tflite_model."""
+ result = is_tflite_model(model_path)
+ assert result == expected_result
+
+
+@pytest.mark.parametrize(
+ "model_path, expected_result",
+ [
+ [Path("sample_model.h5"), True],
+ [Path("strange_model.h5.keras"), False],
+ [Path("sample_model.tflite"), False],
+ [Path("sample_model"), False],
+ ],
+)
+def test_is_keras_model(model_path: Path, expected_result: bool) -> None:
+ """Test function is_keras_model."""
+ result = is_keras_model(model_path)
+ assert result == expected_result
+
+
+def test_get_tf_tensor_shape(test_tf_model: Path) -> None:
+ """Test get_tf_tensor_shape with test model."""
+ assert get_tf_tensor_shape(str(test_tf_model)) == [1, 28, 28, 1]
diff --git a/tests/mlia/test_resources/vela/sample_vela.ini b/tests/mlia/test_resources/vela/sample_vela.ini
new file mode 100644
index 0000000..c992458
--- /dev/null
+++ b/tests/mlia/test_resources/vela/sample_vela.ini
@@ -0,0 +1,47 @@
+; SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+; SPDX-License-Identifier: Apache-2.0
+; Ethos-U55 High-End Embedded: SRAM (4 GB/s) and Flash (0.5 GB/s)
+[System_Config.Ethos_U55_High_End_Embedded]
+core_clock=500e6
+axi0_port=Sram
+axi1_port=OffChipFlash
+Sram_clock_scale=1.0
+Sram_burst_length=32
+Sram_read_latency=32
+Sram_write_latency=32
+OffChipFlash_clock_scale=0.125
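+; 0.125 = 0.5 GB/s / 4 GB/s: flash bandwidth as a fraction of the SRAM bandwidth above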
+OffChipFlash_burst_length=128
+OffChipFlash_read_latency=64
+OffChipFlash_write_latency=64
+
+; Ethos-U65 High-End: SRAM (16 GB/s) and DRAM (3.75 GB/s)
+[System_Config.Ethos_U65_High_End]
+core_clock=1e9
+axi0_port=Sram
+axi1_port=Dram
+Sram_clock_scale=1.0
+Sram_burst_length=32
+Sram_read_latency=32
+Sram_write_latency=32
+Dram_clock_scale=0.234375
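+; 0.234375 = 3.75 GB/s / 16 GB/s: DRAM bandwidth as a fraction of the SRAM bandwidth above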
+Dram_burst_length=128
+Dram_read_latency=500
+Dram_write_latency=250
+
+; -----------------------------------------------------------------------------
+; Memory Mode
+
+; Shared SRAM: the SRAM is shared between the Ethos-U and the Cortex-M software
+; The non-SRAM memory is assumed to be read-only
+[Memory_Mode.Shared_Sram]
+const_mem_area=Axi1
+arena_mem_area=Axi0
+cache_mem_area=Axi0
+
+; The SRAM (384KB) is only for use by the Ethos-U
+; The non-SRAM memory is assumed to be read-writeable
+[Memory_Mode.Dedicated_Sram]
+const_mem_area=Axi1
+arena_mem_area=Axi1
+cache_mem_area=Axi0
+arena_cache_size=393216
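+; 393216 bytes = 384 KiB, i.e. the dedicated SRAM size mentioned above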
diff --git a/tests/mlia/test_tools_aiet_wrapper.py b/tests/mlia/test_tools_aiet_wrapper.py
new file mode 100644
index 0000000..ab55b71
--- /dev/null
+++ b/tests/mlia/test_tools_aiet_wrapper.py
@@ -0,0 +1,760 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Tests for module tools/aiet_wrapper."""
+from contextlib import ExitStack as does_not_raise
+from pathlib import Path
+from typing import Any
+from typing import Dict
+from typing import List
+from typing import Optional
+from typing import Tuple
+from unittest.mock import MagicMock
+from unittest.mock import PropertyMock
+
+import pytest
+
+from mlia.tools.aiet_wrapper import AIETRunner
+from mlia.tools.aiet_wrapper import DeviceInfo
+from mlia.tools.aiet_wrapper import estimate_performance
+from mlia.tools.aiet_wrapper import ExecutionParams
+from mlia.tools.aiet_wrapper import GenericInferenceOutputParser
+from mlia.tools.aiet_wrapper import GenericInferenceRunnerEthosU
+from mlia.tools.aiet_wrapper import get_aiet_runner
+from mlia.tools.aiet_wrapper import get_generic_runner
+from mlia.tools.aiet_wrapper import get_system_name
+from mlia.tools.aiet_wrapper import is_supported
+from mlia.tools.aiet_wrapper import ModelInfo
+from mlia.tools.aiet_wrapper import PerformanceMetrics
+from mlia.tools.aiet_wrapper import supported_backends
+from mlia.utils.proc import RunningCommand
+
+
+@pytest.mark.parametrize(
+ "data, is_ready, result, missed_keys",
+ [
+ (
+ [],
+ False,
+ {},
+ [
+ "npu_active_cycles",
+ "npu_axi0_rd_data_beat_received",
+ "npu_axi0_wr_data_beat_written",
+ "npu_axi1_rd_data_beat_received",
+ "npu_idle_cycles",
+ "npu_total_cycles",
+ ],
+ ),
+ (
+ ["sample text"],
+ False,
+ {},
+ [
+ "npu_active_cycles",
+ "npu_axi0_rd_data_beat_received",
+ "npu_axi0_wr_data_beat_written",
+ "npu_axi1_rd_data_beat_received",
+ "npu_idle_cycles",
+ "npu_total_cycles",
+ ],
+ ),
+ (
+ ["NPU AXI0_RD_DATA_BEAT_RECEIVED beats: 123"],
+ False,
+ {"npu_axi0_rd_data_beat_received": 123},
+ [
+ "npu_active_cycles",
+ "npu_axi0_wr_data_beat_written",
+ "npu_axi1_rd_data_beat_received",
+ "npu_idle_cycles",
+ "npu_total_cycles",
+ ],
+ ),
+ (
+ [
+ "NPU AXI0_RD_DATA_BEAT_RECEIVED beats: 1",
+ "NPU AXI0_WR_DATA_BEAT_WRITTEN beats: 2",
+ "NPU AXI1_RD_DATA_BEAT_RECEIVED beats: 3",
+ "NPU ACTIVE cycles: 4",
+ "NPU IDLE cycles: 5",
+ "NPU TOTAL cycles: 6",
+ ],
+ True,
+ {
+ "npu_axi0_rd_data_beat_received": 1,
+ "npu_axi0_wr_data_beat_written": 2,
+ "npu_axi1_rd_data_beat_received": 3,
+ "npu_active_cycles": 4,
+ "npu_idle_cycles": 5,
+ "npu_total_cycles": 6,
+ },
+ [],
+ ),
+ ],
+)
+def test_generic_inference_output_parser(
+ data: List[str], is_ready: bool, result: Dict, missed_keys: List[str]
+) -> None:
+ """Test generic runner output parser."""
+ parser = GenericInferenceOutputParser()
+
+ for line in data:
+ parser.feed(line)
+
+ assert parser.is_ready() == is_ready
+ assert parser.result == result
+ assert parser.missed_keys() == missed_keys
+
+
+class TestAIETRunner:
+ """Tests for AIETRunner class."""
+
+ @staticmethod
+ def _setup_aiet(
+ monkeypatch: pytest.MonkeyPatch,
+ available_systems: Optional[List[str]] = None,
+ available_apps: Optional[List[str]] = None,
+ ) -> None:
+ """Set up AIET metadata."""
+
+ def mock_system(system: str) -> MagicMock:
+ """Mock the System instance."""
+ mock = MagicMock()
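+ # "name" is treated specially by the MagicMock constructor, so the
+ # property is attached to the mock's type instead.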
+ type(mock).name = PropertyMock(return_value=system)
+ return mock
+
+ def mock_app(app: str) -> MagicMock:
+ """Mock the Application instance."""
+ mock = MagicMock()
+ type(mock).name = PropertyMock(return_value=app)
+ mock.can_run_on.return_value = True
+ return mock
+
+ system_mocks = [mock_system(name) for name in (available_systems or [])]
+ monkeypatch.setattr(
+ "mlia.tools.aiet_wrapper.get_available_systems",
+ MagicMock(return_value=system_mocks),
+ )
+
+ apps_mock = [mock_app(name) for name in (available_apps or [])]
+ monkeypatch.setattr(
+ "mlia.tools.aiet_wrapper.get_available_applications",
+ MagicMock(return_value=apps_mock),
+ )
+
+ @pytest.mark.parametrize(
+ "available_systems, system, installed",
+ [
+ ([], "system1", False),
+ (["system1", "system2"], "system1", True),
+ ],
+ )
+ def test_is_system_installed(
+ self,
+ available_systems: List,
+ system: str,
+ installed: bool,
+ monkeypatch: pytest.MonkeyPatch,
+ ) -> None:
+ """Test method is_system_installed."""
+ mock_executor = MagicMock()
+ aiet_runner = AIETRunner(mock_executor)
+
+ self._setup_aiet(monkeypatch, available_systems)
+
+ assert aiet_runner.is_system_installed(system) == installed
+ mock_executor.assert_not_called()
+
+ @pytest.mark.parametrize(
+ "available_systems, systems",
+ [
+ ([], []),
+ (["system1"], ["system1"]),
+ ],
+ )
+ def test_installed_systems(
+ self,
+ available_systems: List[str],
+ systems: List[str],
+ monkeypatch: pytest.MonkeyPatch,
+ ) -> None:
+ """Test method installed_systems."""
+ mock_executor = MagicMock()
+ aiet_runner = AIETRunner(mock_executor)
+
+ self._setup_aiet(monkeypatch, available_systems)
+ assert aiet_runner.get_installed_systems() == systems
+
+ mock_executor.assert_not_called()
+
+ @staticmethod
+ def test_install_system(monkeypatch: pytest.MonkeyPatch) -> None:
+ """Test system installation."""
+ install_system_mock = MagicMock()
+ monkeypatch.setattr(
+ "mlia.tools.aiet_wrapper.install_system", install_system_mock
+ )
+
+ mock_executor = MagicMock()
+ aiet_runner = AIETRunner(mock_executor)
+ aiet_runner.install_system(Path("test_system_path"))
+
+ install_system_mock.assert_called_once_with(Path("test_system_path"))
+ mock_executor.assert_not_called()
+
+ @pytest.mark.parametrize(
+ "available_systems, systems, expected_result",
+ [
+ ([], [], False),
+ (["system1"], [], False),
+ (["system1"], ["system1"], True),
+ (["system1", "system2"], ["system1", "system3"], False),
+ (["system1", "system2"], ["system1", "system2"], True),
+ ],
+ )
+ def test_systems_installed(
+ self,
+ available_systems: List[str],
+ systems: List[str],
+ expected_result: bool,
+ monkeypatch: pytest.MonkeyPatch,
+ ) -> None:
+ """Test method systems_installed."""
+ self._setup_aiet(monkeypatch, available_systems)
+
+ mock_executor = MagicMock()
+ aiet_runner = AIETRunner(mock_executor)
+
+ assert aiet_runner.systems_installed(systems) is expected_result
+
+ mock_executor.assert_not_called()
+
+ @pytest.mark.parametrize(
+ "available_apps, applications, expected_result",
+ [
+ ([], [], False),
+ (["app1"], [], False),
+ (["app1"], ["app1"], True),
+ (["app1", "app2"], ["app1", "app3"], False),
+ (["app1", "app2"], ["app1", "app2"], True),
+ ],
+ )
+ def test_applications_installed(
+ self,
+ available_apps: List[str],
+ applications: List[str],
+ expected_result: bool,
+ monkeypatch: pytest.MonkeyPatch,
+ ) -> None:
+ """Test method applications_installed."""
+ self._setup_aiet(monkeypatch, [], available_apps)
+ mock_executor = MagicMock()
+ aiet_runner = AIETRunner(mock_executor)
+
+ assert aiet_runner.applications_installed(applications) is expected_result
+ mock_executor.assert_not_called()
+
+ @pytest.mark.parametrize(
+ "available_apps, applications",
+ [
+ ([], []),
+ (
+ ["application1", "application2"],
+ ["application1", "application2"],
+ ),
+ ],
+ )
+ def test_get_installed_applications(
+ self,
+ available_apps: List[str],
+ applications: List[str],
+ monkeypatch: pytest.MonkeyPatch,
+ ) -> None:
+ """Test method get_installed_applications."""
+ mock_executor = MagicMock()
+ self._setup_aiet(monkeypatch, [], available_apps)
+
+ aiet_runner = AIETRunner(mock_executor)
+ assert applications == aiet_runner.get_installed_applications()
+
+ mock_executor.assert_not_called()
+
+ @staticmethod
+ def test_install_application(monkeypatch: pytest.MonkeyPatch) -> None:
+ """Test application installation."""
+ mock_install_application = MagicMock()
+ monkeypatch.setattr(
+ "mlia.tools.aiet_wrapper.install_application", mock_install_application
+ )
+
+ mock_executor = MagicMock()
+
+ aiet_runner = AIETRunner(mock_executor)
+ aiet_runner.install_application(Path("test_application_path"))
+ mock_install_application.assert_called_once_with(Path("test_application_path"))
+
+ mock_executor.assert_not_called()
+
+ @pytest.mark.parametrize(
+ "available_apps, application, installed",
+ [
+ ([], "system1", False),
+ (
+ ["application1", "application2"],
+ "application1",
+ True,
+ ),
+ (
+ [],
+ "application1",
+ False,
+ ),
+ ],
+ )
+ def test_is_application_installed(
+ self,
+ available_apps: List[str],
+ application: str,
+ installed: bool,
+ monkeypatch: pytest.MonkeyPatch,
+ ) -> None:
+ """Test method is_application_installed."""
+ self._setup_aiet(monkeypatch, [], available_apps)
+
+ mock_executor = MagicMock()
+ aiet_runner = AIETRunner(mock_executor)
+ assert installed == aiet_runner.is_application_installed(application, "system1")
+
+ mock_executor.assert_not_called()
+
+ @staticmethod
+ @pytest.mark.parametrize(
+ "execution_params, expected_command",
+ [
+ (
+ ExecutionParams("application1", "system1", [], [], []),
+ ["aiet", "application", "run", "-n", "application1", "-s", "system1"],
+ ),
+ (
+ ExecutionParams(
+ "application1",
+ "system1",
+ ["input_file=123.txt", "size=777"],
+ ["param1=456", "param2=789"],
+ ["source1.txt:dest1.txt", "source2.txt:dest2.txt"],
+ ),
+ [
+ "aiet",
+ "application",
+ "run",
+ "-n",
+ "application1",
+ "-s",
+ "system1",
+ "-p",
+ "input_file=123.txt",
+ "-p",
+ "size=777",
+ "--system-param",
+ "param1=456",
+ "--system-param",
+ "param2=789",
+ "--deploy",
+ "source1.txt:dest1.txt",
+ "--deploy",
+ "source2.txt:dest2.txt",
+ ],
+ ),
+ ],
+ )
+ def test_run_application(
+ execution_params: ExecutionParams, expected_command: List[str]
+ ) -> None:
+ """Test method run_application."""
+ mock_executor = MagicMock()
+ mock_running_command = MagicMock()
+ mock_executor.submit.return_value = mock_running_command
+
+ aiet_runner = AIETRunner(mock_executor)
+ aiet_runner.run_application(execution_params)
+
+ mock_executor.submit.assert_called_once_with(expected_command)
+
+
+@pytest.mark.parametrize(
+ "device, system, application, backend, expected_error",
+ [
+ (
+ DeviceInfo(device_type="ethos-u55", mac=32, memory_mode="Shared_Sram"),
+ ("Corstone-300: Cortex-M55+Ethos-U55", True),
+ ("Generic Inference Runner: Ethos-U55 SRAM", True),
+ "Corstone-300",
+ does_not_raise(),
+ ),
+ (
+ DeviceInfo(device_type="ethos-u55", mac=32, memory_mode="Shared_Sram"),
+ ("Corstone-300: Cortex-M55+Ethos-U55", False),
+ ("Generic Inference Runner: Ethos-U55 SRAM", False),
+ "Corstone-300",
+ pytest.raises(
+ Exception,
+ match=r"System Corstone-300: Cortex-M55\+Ethos-U55 is not installed",
+ ),
+ ),
+ (
+ DeviceInfo(device_type="ethos-u55", mac=32, memory_mode="Shared_Sram"),
+ ("Corstone-300: Cortex-M55+Ethos-U55", True),
+ ("Generic Inference Runner: Ethos-U55 SRAM", False),
+ "Corstone-300",
+ pytest.raises(
+ Exception,
+ match=r"Application Generic Inference Runner: Ethos-U55/65 Shared SRAM "
+ r"for the system Corstone-300: Cortex-M55\+Ethos-U55 is not installed",
+ ),
+ ),
+ (
+ DeviceInfo(device_type="ethos-u55", mac=32, memory_mode="Shared_Sram"),
+ ("Corstone-310: Cortex-M85+Ethos-U55", True),
+ ("Generic Inference Runner: Ethos-U55 SRAM", True),
+ "Corstone-310",
+ does_not_raise(),
+ ),
+ (
+ DeviceInfo(device_type="ethos-u55", mac=32, memory_mode="Shared_Sram"),
+ ("Corstone-310: Cortex-M85+Ethos-U55", False),
+ ("Generic Inference Runner: Ethos-U55 SRAM", False),
+ "Corstone-310",
+ pytest.raises(
+ Exception,
+ match=r"System Corstone-310: Cortex-M85\+Ethos-U55 is not installed",
+ ),
+ ),
+ (
+ DeviceInfo(device_type="ethos-u55", mac=32, memory_mode="Shared_Sram"),
+ ("Corstone-310: Cortex-M85+Ethos-U55", True),
+ ("Generic Inference Runner: Ethos-U55 SRAM", False),
+ "Corstone-310",
+ pytest.raises(
+ Exception,
+ match=r"Application Generic Inference Runner: Ethos-U55/65 Shared SRAM "
+ r"for the system Corstone-310: Cortex-M85\+Ethos-U55 is not installed",
+ ),
+ ),
+ (
+ DeviceInfo(device_type="ethos-u65", mac=512, memory_mode="Shared_Sram"),
+ ("Corstone-300: Cortex-M55+Ethos-U65", True),
+ ("Generic Inference Runner: Ethos-U65 Dedicated SRAM", True),
+ "Corstone-300",
+ does_not_raise(),
+ ),
+ (
+ DeviceInfo(device_type="ethos-u65", mac=512, memory_mode="Shared_Sram"),
+ ("Corstone-300: Cortex-M55+Ethos-U65", False),
+ ("Generic Inference Runner: Ethos-U65 Dedicated SRAM", False),
+ "Corstone-300",
+ pytest.raises(
+ Exception,
+ match=r"System Corstone-300: Cortex-M55\+Ethos-U65 is not installed",
+ ),
+ ),
+ (
+ DeviceInfo(device_type="ethos-u65", mac=512, memory_mode="Shared_Sram"),
+ ("Corstone-300: Cortex-M55+Ethos-U65", True),
+ ("Generic Inference Runner: Ethos-U65 Dedicated SRAM", False),
+ "Corstone-300",
+ pytest.raises(
+ Exception,
+ match=r"Application Generic Inference Runner: Ethos-U55/65 Shared SRAM "
+ r"for the system Corstone-300: Cortex-M55\+Ethos-U65 is not installed",
+ ),
+ ),
+ (
+ DeviceInfo(
+ device_type="unknown_device", # type: ignore
+ mac=None, # type: ignore
+ memory_mode="Shared_Sram",
+ ),
+ ("some_system", False),
+ ("some_application", False),
+ "some backend",
+ pytest.raises(Exception, match="Unsupported device unknown_device"),
+ ),
+ ],
+)
+def test_estimate_performance(
+ device: DeviceInfo,
+ system: Tuple[str, bool],
+ application: Tuple[str, bool],
+ backend: str,
+ expected_error: Any,
+ test_tflite_model: Path,
+ aiet_runner: MagicMock,
+) -> None:
+ """Test getting performance estimations."""
+ system_name, system_installed = system
+ application_name, application_installed = application
+
+ aiet_runner.is_system_installed.return_value = system_installed
+ aiet_runner.is_application_installed.return_value = application_installed
+
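+ # Output covering all six NPU counters expected by GenericInferenceOutputParser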
+ mock_process = create_mock_process(
+ [
+ "NPU AXI0_RD_DATA_BEAT_RECEIVED beats: 1",
+ "NPU AXI0_WR_DATA_BEAT_WRITTEN beats: 2",
+ "NPU AXI1_RD_DATA_BEAT_RECEIVED beats: 3",
+ "NPU ACTIVE cycles: 4",
+ "NPU IDLE cycles: 5",
+ "NPU TOTAL cycles: 6",
+ ],
+ [],
+ )
+
+ mock_generic_inference_run = RunningCommand(mock_process)
+ aiet_runner.run_application.return_value = mock_generic_inference_run
+
+ with expected_error:
+ perf_metrics = estimate_performance(
+ ModelInfo(test_tflite_model), device, backend
+ )
+
+ assert isinstance(perf_metrics, PerformanceMetrics)
+ assert perf_metrics == PerformanceMetrics(
+ npu_axi0_rd_data_beat_received=1,
+ npu_axi0_wr_data_beat_written=2,
+ npu_axi1_rd_data_beat_received=3,
+ npu_active_cycles=4,
+ npu_idle_cycles=5,
+ npu_total_cycles=6,
+ )
+
+ aiet_runner.is_system_installed.assert_called_once_with(system_name)
+ aiet_runner.is_application_installed.assert_called_once()
+
+
+@pytest.mark.parametrize("backend", ("Corstone-300", "Corstone-310"))
+def test_estimate_performance_insufficient_data(
+ aiet_runner: MagicMock, test_tflite_model: Path, backend: str
+) -> None:
+ """Test that performance could not be estimated when not all data presented."""
+ aiet_runner.is_system_installed.return_value = True
+ aiet_runner.is_application_installed.return_value = True
+
+ no_total_cycles_output = [
+ "NPU AXI0_RD_DATA_BEAT_RECEIVED beats: 1",
+ "NPU AXI0_WR_DATA_BEAT_WRITTEN beats: 2",
+ "NPU AXI1_RD_DATA_BEAT_RECEIVED beats: 3",
+ "NPU ACTIVE cycles: 4",
+ "NPU IDLE cycles: 5",
+ ]
+ mock_process = create_mock_process(
+ no_total_cycles_output,
+ [],
+ )
+
+ mock_generic_inference_run = RunningCommand(mock_process)
+ aiet_runner.run_application.return_value = mock_generic_inference_run
+
+ with pytest.raises(
+ Exception, match="Unable to get performance metrics, insufficient data"
+ ):
+ device = DeviceInfo(device_type="ethos-u55", mac=32, memory_mode="Shared_Sram")
+ estimate_performance(ModelInfo(test_tflite_model), device, backend)
+
+
+@pytest.mark.parametrize("backend", ("Corstone-300", "Corstone-310"))
+def test_estimate_performance_invalid_output(
+ test_tflite_model: Path, aiet_runner: MagicMock, backend: str
+) -> None:
+ """Test estimation could not be done if inference produces unexpected output."""
+ aiet_runner.is_system_installed.return_value = True
+ aiet_runner.is_application_installed.return_value = True
+
+ mock_process = create_mock_process(
+ ["Something", "is", "wrong"], ["What a nice error!"]
+ )
+ aiet_runner.run_application.return_value = RunningCommand(mock_process)
+
+ with pytest.raises(Exception, match="Unable to get performance metrics"):
+ estimate_performance(
+ ModelInfo(test_tflite_model),
+ DeviceInfo(device_type="ethos-u55", mac=256, memory_mode="Shared_Sram"),
+ backend=backend,
+ )
+
+
+def test_get_aiet_runner() -> None:
+ """Test getting aiet runner."""
+ aiet_runner = get_aiet_runner()
+ assert isinstance(aiet_runner, AIETRunner)
+
+
+def create_mock_process(stdout: List[str], stderr: List[str]) -> MagicMock:
+ """Mock underlying process."""
+ mock_process = MagicMock()
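+ # poll() == 0 mimics a process that has already exited successfully; the
+ # output streams are consumed as plain line iterators.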
+ mock_process.poll.return_value = 0
+ type(mock_process).stdout = PropertyMock(return_value=iter(stdout))
+ type(mock_process).stderr = PropertyMock(return_value=iter(stderr))
+ return mock_process
+
+
+@pytest.mark.parametrize("backend", ("Corstone-300", "Corstone-310"))
+def test_get_generic_runner(backend: str) -> None:
+ """Test function get_generic_runner()."""
+ device_info = DeviceInfo("ethos-u55", 256, memory_mode="Shared_Sram")
+
+ runner = get_generic_runner(device_info=device_info, backend=backend)
+ assert isinstance(runner, GenericInferenceRunnerEthosU)
+
+ with pytest.raises(RuntimeError):
+ get_generic_runner(device_info=device_info, backend="UNKNOWN_BACKEND")
+
+
+@pytest.mark.parametrize(
+ ("backend", "device_type"),
+ (
+ ("Corstone-300", "ethos-u55"),
+ ("Corstone-300", "ethos-u65"),
+ ("Corstone-310", "ethos-u55"),
+ ),
+)
+def test_aiet_backend_support(backend: str, device_type: str) -> None:
+ """Test AIET backend & device support."""
+ assert is_supported(backend)
+ assert is_supported(backend, device_type)
+
+ assert get_system_name(backend, device_type)
+
+ assert backend in supported_backends()
+
+
+class TestGenericInferenceRunnerEthosU:
+ """Test for the class GenericInferenceRunnerEthosU."""
+
+ @staticmethod
+ @pytest.mark.parametrize(
+ "device, backend, expected_system, expected_app",
+ [
+ [
+ DeviceInfo("ethos-u55", 256, memory_mode="Shared_Sram"),
+ "Corstone-300",
+ "Corstone-300: Cortex-M55+Ethos-U55",
+ "Generic Inference Runner: Ethos-U55/65 Shared SRAM",
+ ],
+ [
+ DeviceInfo("ethos-u55", 256, memory_mode="Shared_Sram"),
+ "Corstone-310",
+ "Corstone-310: Cortex-M85+Ethos-U55",
+ "Generic Inference Runner: Ethos-U55/65 Shared SRAM",
+ ],
+ [
+ DeviceInfo("ethos-u55", 256, memory_mode="Sram"),
+ "Corstone-310",
+ "Corstone-310: Cortex-M85+Ethos-U55",
+ "Generic Inference Runner: Ethos-U55 SRAM",
+ ],
+ [
+ DeviceInfo("ethos-u55", 256, memory_mode="Sram"),
+ "Corstone-300",
+ "Corstone-300: Cortex-M55+Ethos-U55",
+ "Generic Inference Runner: Ethos-U55 SRAM",
+ ],
+ [
+ DeviceInfo("ethos-u65", 256, memory_mode="Shared_Sram"),
+ "Corstone-300",
+ "Corstone-300: Cortex-M55+Ethos-U65",
+ "Generic Inference Runner: Ethos-U55/65 Shared SRAM",
+ ],
+ [
+ DeviceInfo("ethos-u65", 256, memory_mode="Dedicated_Sram"),
+ "Corstone-300",
+ "Corstone-300: Cortex-M55+Ethos-U65",
+ "Generic Inference Runner: Ethos-U65 Dedicated SRAM",
+ ],
+ ],
+ )
+ def test_artifact_resolver(
+ device: DeviceInfo, backend: str, expected_system: str, expected_app: str
+ ) -> None:
+ """Test artifact resolving based on the provided parameters."""
+ generic_runner = get_generic_runner(device, backend)
+ assert isinstance(generic_runner, GenericInferenceRunnerEthosU)
+
+ assert generic_runner.system_name == expected_system
+ assert generic_runner.app_name == expected_app
+
+ @staticmethod
+ def test_artifact_resolver_unsupported_backend() -> None:
+ """Test that it should be not possible to use unsupported backends."""
+ with pytest.raises(
+ RuntimeError, match="Unsupported device ethos-u65 for backend test_backend"
+ ):
+ get_generic_runner(
+ DeviceInfo("ethos-u65", 256, memory_mode="Shared_Sram"), "test_backend"
+ )
+
+ @staticmethod
+ def test_artifact_resolver_unsupported_memory_mode() -> None:
+ """Test that it should be not possible to use unsupported memory modes."""
+ with pytest.raises(
+ RuntimeError, match="Unsupported memory mode test_memory_mode"
+ ):
+ get_generic_runner(
+ DeviceInfo(
+ "ethos-u65",
+ 256,
+ memory_mode="test_memory_mode", # type: ignore
+ ),
+ "Corstone-300",
+ )
+
+ @staticmethod
+ @pytest.mark.parametrize("backend", ("Corstone-300", "Corstone-310"))
+ def test_inference_should_fail_if_system_not_installed(
+ aiet_runner: MagicMock, test_tflite_model: Path, backend: str
+ ) -> None:
+ """Test that inference should fail if system is not installed."""
+ aiet_runner.is_system_installed.return_value = False
+
+ generic_runner = get_generic_runner(
+ DeviceInfo("ethos-u55", 256, memory_mode="Shared_Sram"), backend
+ )
+ with pytest.raises(
+ Exception,
+ match=r"System Corstone-3[01]0: Cortex-M[58]5\+Ethos-U55 is not installed",
+ ):
+ generic_runner.run(ModelInfo(test_tflite_model), [])
+
+ @staticmethod
+ @pytest.mark.parametrize("backend", ("Corstone-300", "Corstone-310"))
+ def test_inference_should_fail_if_apps_not_installed(
+ aiet_runner: MagicMock, test_tflite_model: Path, backend: str
+ ) -> None:
+ """Test that inference should fail if apps are not installed."""
+ aiet_runner.is_system_installed.return_value = True
+ aiet_runner.is_application_installed.return_value = False
+
+ generic_runner = get_generic_runner(
+ DeviceInfo("ethos-u55", 256, memory_mode="Shared_Sram"), backend
+ )
+ with pytest.raises(
+ Exception,
+ match="Application Generic Inference Runner: Ethos-U55/65 Shared SRAM"
+ r" for the system Corstone-3[01]0: Cortex-M[58]5\+Ethos-U55 is not "
+ r"installed",
+ ):
+ generic_runner.run(ModelInfo(test_tflite_model), [])
+
+
+@pytest.fixture(name="aiet_runner")
+def fixture_aiet_runner(monkeypatch: pytest.MonkeyPatch) -> MagicMock:
+ """Mock AIET runner."""
+ aiet_runner_mock = MagicMock(spec=AIETRunner)
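+ # Patch the factory so the code under test receives this mock runner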
+ monkeypatch.setattr(
+ "mlia.tools.aiet_wrapper.get_aiet_runner",
+ MagicMock(return_value=aiet_runner_mock),
+ )
+ return aiet_runner_mock
diff --git a/tests/mlia/test_tools_metadata_common.py b/tests/mlia/test_tools_metadata_common.py
new file mode 100644
index 0000000..7663b83
--- /dev/null
+++ b/tests/mlia/test_tools_metadata_common.py
@@ -0,0 +1,196 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Tests for commmon installation related functions."""
+from pathlib import Path
+from typing import Any
+from typing import List
+from typing import Optional
+from unittest.mock import call
+from unittest.mock import MagicMock
+from unittest.mock import PropertyMock
+
+import pytest
+
+from mlia.tools.metadata.common import DefaultInstallationManager
+from mlia.tools.metadata.common import DownloadAndInstall
+from mlia.tools.metadata.common import Installation
+from mlia.tools.metadata.common import InstallationType
+from mlia.tools.metadata.common import InstallFromPath
+
+
+def get_installation_mock(
+ name: str,
+ already_installed: bool = False,
+ could_be_installed: bool = False,
+ supported_install_type: Optional[type] = None,
+) -> MagicMock:
+ """Get mock instance for the installation."""
+ mock = MagicMock(spec=Installation)
+
+ def supports(install_type: InstallationType) -> bool:
+ if supported_install_type is None:
+ return False
+
+ return isinstance(install_type, supported_install_type)
+
+ mock.supports.side_effect = supports
+
+ props = {
+ "name": name,
+ "already_installed": already_installed,
+ "could_be_installed": could_be_installed,
+ }
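+ # PropertyMock objects only take effect when attached to the mock's type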
+ for prop, value in props.items():
+ setattr(type(mock), prop, PropertyMock(return_value=value))
+
+ return mock
+
+
+def _already_installed_mock() -> MagicMock:
+ return get_installation_mock(
+ name="already_installed",
+ already_installed=True,
+ )
+
+
+def _ready_for_installation_mock() -> MagicMock:
+ return get_installation_mock(
+ name="ready_for_installation",
+ already_installed=False,
+ could_be_installed=True,
+ )
+
+
+def _could_be_downloaded_and_installed_mock() -> MagicMock:
+ return get_installation_mock(
+ name="could_be_downloaded_and_installed",
+ already_installed=False,
+ could_be_installed=True,
+ supported_install_type=DownloadAndInstall,
+ )
+
+
+def _could_be_installed_from_mock() -> MagicMock:
+ return get_installation_mock(
+ name="could_be_installed_from",
+ already_installed=False,
+ could_be_installed=True,
+ supported_install_type=InstallFromPath,
+ )
+
+
+def get_installation_manager(
+ noninteractive: bool,
+ installations: List[Any],
+ monkeypatch: pytest.MonkeyPatch,
+ yes_response: bool = True,
+) -> DefaultInstallationManager:
+ """Get installation manager instance."""
+ if not noninteractive:
+ monkeypatch.setattr(
+ "mlia.tools.metadata.common.yes", MagicMock(return_value=yes_response)
+ )
+
+ return DefaultInstallationManager(installations, noninteractive=noninteractive)
+
+
+def test_installation_manager_filtering() -> None:
+ """Test default installation manager."""
+ already_installed = _already_installed_mock()
+ ready_for_installation = _ready_for_installation_mock()
+ could_be_downloaded_and_installed = _could_be_downloaded_and_installed_mock()
+
+ manager = DefaultInstallationManager(
+ [
+ already_installed,
+ ready_for_installation,
+ could_be_downloaded_and_installed,
+ ]
+ )
+ assert manager.already_installed() == [already_installed]
+ assert manager.ready_for_installation() == [
+ ready_for_installation,
+ could_be_downloaded_and_installed,
+ ]
+ assert manager.could_be_downloaded_and_installed() == [
+ could_be_downloaded_and_installed
+ ]
+ assert manager.could_be_downloaded_and_installed("some_installation") == []
+
+
+@pytest.mark.parametrize("noninteractive", [True, False])
+@pytest.mark.parametrize(
+ "install_mock, eula_agreement, backend_name, expected_call",
+ [
+ [
+ _could_be_downloaded_and_installed_mock(),
+ True,
+ None,
+ [call(DownloadAndInstall(eula_agreement=True))],
+ ],
+ [
+ _could_be_downloaded_and_installed_mock(),
+ False,
+ None,
+ [call(DownloadAndInstall(eula_agreement=False))],
+ ],
+ [
+ _could_be_downloaded_and_installed_mock(),
+ False,
+ "unknown",
+ [],
+ ],
+ ],
+)
+def test_installation_manager_download_and_install(
+ install_mock: MagicMock,
+ noninteractive: bool,
+ eula_agreement: bool,
+ backend_name: Optional[str],
+ expected_call: Any,
+ monkeypatch: pytest.MonkeyPatch,
+) -> None:
+ """Test installation process."""
+ install_mock.reset_mock()
+
+ manager = get_installation_manager(noninteractive, [install_mock], monkeypatch)
+
+ manager.download_and_install(backend_name, eula_agreement=eula_agreement)
+ assert install_mock.install.mock_calls == expected_call
+
+
+@pytest.mark.parametrize("noninteractive", [True, False])
+@pytest.mark.parametrize(
+ "install_mock, backend_name, expected_call",
+ [
+ [
+ _could_be_installed_from_mock(),
+ None,
+ [call(InstallFromPath(Path("some_path")))],
+ ],
+ [
+ _could_be_installed_from_mock(),
+ "unknown",
+ [],
+ ],
+ [
+ _already_installed_mock(),
+ "already_installed",
+ [],
+ ],
+ ],
+)
+def test_installation_manager_install_from(
+ install_mock: MagicMock,
+ noninteractive: bool,
+ backend_name: Optional[str],
+ expected_call: Any,
+ monkeypatch: pytest.MonkeyPatch,
+) -> None:
+ """Test installation process."""
+ install_mock.reset_mock()
+
+ manager = get_installation_manager(noninteractive, [install_mock], monkeypatch)
+ manager.install_from(Path("some_path"), backend_name)
+
+ assert install_mock.install.mock_calls == expected_call
diff --git a/tests/mlia/test_tools_metadata_corstone.py b/tests/mlia/test_tools_metadata_corstone.py
new file mode 100644
index 0000000..2ce3610
--- /dev/null
+++ b/tests/mlia/test_tools_metadata_corstone.py
@@ -0,0 +1,419 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Tests for Corstone related installation functions.."""
+import tarfile
+from pathlib import Path
+from typing import List
+from typing import Optional
+from unittest.mock import MagicMock
+
+import pytest
+
+from mlia.tools.aiet_wrapper import AIETRunner
+from mlia.tools.metadata.common import DownloadAndInstall
+from mlia.tools.metadata.common import InstallFromPath
+from mlia.tools.metadata.corstone import AIETBasedInstallation
+from mlia.tools.metadata.corstone import AIETMetadata
+from mlia.tools.metadata.corstone import BackendInfo
+from mlia.tools.metadata.corstone import BackendInstaller
+from mlia.tools.metadata.corstone import CompoundPathChecker
+from mlia.tools.metadata.corstone import Corstone300Installer
+from mlia.tools.metadata.corstone import get_corstone_installations
+from mlia.tools.metadata.corstone import PackagePathChecker
+from mlia.tools.metadata.corstone import PathChecker
+from mlia.tools.metadata.corstone import StaticPathChecker
+
+
+@pytest.fixture(name="test_mlia_resources")
+def fixture_test_mlia_resources(
+ tmp_path: Path, monkeypatch: pytest.MonkeyPatch
+) -> Path:
+ """Redirect MLIA resources resolution to the temp directory."""
+ mlia_resources = tmp_path / "resources"
+ mlia_resources.mkdir()
+
+ monkeypatch.setattr(
+ "mlia.tools.metadata.corstone.get_mlia_resources",
+ MagicMock(return_value=mlia_resources),
+ )
+
+ return mlia_resources
+
+
+def get_aiet_based_installation( # pylint: disable=too-many-arguments
+ aiet_runner_mock: MagicMock = MagicMock(),
+ name: str = "test_name",
+ description: str = "test_description",
+ download_artifact: Optional[MagicMock] = None,
+ path_checker: PathChecker = MagicMock(),
+ apps_resources: Optional[List[str]] = None,
+ system_config: Optional[str] = None,
+ backend_installer: BackendInstaller = MagicMock(),
+ supported_platforms: Optional[List[str]] = None,
+) -> AIETBasedInstallation:
+ """Get AIET based installation."""
+ return AIETBasedInstallation(
+ aiet_runner=aiet_runner_mock,
+ metadata=AIETMetadata(
+ name=name,
+ description=description,
+ system_config=system_config or "",
+ apps_resources=apps_resources or [],
+ fvp_dir_name="sample_dir",
+ download_artifact=download_artifact,
+ supported_platforms=supported_platforms,
+ ),
+ path_checker=path_checker,
+ backend_installer=backend_installer,
+ )
+
+
+@pytest.mark.parametrize(
+ "platform, supported_platforms, expected_result",
+ [
+ ["Linux", ["Linux"], True],
+ ["Linux", [], True],
+ ["Linux", None, True],
+ ["Windows", ["Linux"], False],
+ ],
+)
+def test_could_be_installed_depends_on_platform(
+ platform: str,
+ supported_platforms: Optional[List[str]],
+ expected_result: bool,
+ monkeypatch: pytest.MonkeyPatch,
+) -> None:
+ """Test that installation could not be installed on unsupported platform."""
+ monkeypatch.setattr(
+ "mlia.tools.metadata.corstone.platform.system", MagicMock(return_value=platform)
+ )
+ monkeypatch.setattr(
+ "mlia.tools.metadata.corstone.all_paths_valid", MagicMock(return_value=True)
+ )
+ aiet_runner_mock = MagicMock(spec=AIETRunner)
+
+ installation = get_aiet_based_installation(
+ aiet_runner_mock,
+ supported_platforms=supported_platforms,
+ )
+ assert installation.could_be_installed == expected_result
+
+
+def test_get_corstone_installations() -> None:
+ """Test function get_corstone_installation."""
+ installs = get_corstone_installations()
+ assert len(installs) == 2
+ assert all(isinstance(install, AIETBasedInstallation) for install in installs)
+
+
+def test_aiet_based_installation_metadata_resolving() -> None:
+ """Test AIET based installation metadata resolving."""
+ aiet_runner_mock = MagicMock(spec=AIETRunner)
+ installation = get_aiet_based_installation(aiet_runner_mock)
+
+ assert installation.name == "test_name"
+ assert installation.description == "test_description"
+
+ aiet_runner_mock.all_installed.return_value = False
+ assert installation.already_installed is False
+
+ assert installation.could_be_installed is True
+
+
+def test_aiet_based_installation_supported_install_types(tmp_path: Path) -> None:
+ """Test supported installation types."""
+ installation_no_download_artifact = get_aiet_based_installation()
+ assert installation_no_download_artifact.supports(DownloadAndInstall()) is False
+
+ installation_with_download_artifact = get_aiet_based_installation(
+ download_artifact=MagicMock()
+ )
+ assert installation_with_download_artifact.supports(DownloadAndInstall()) is True
+
+ path_checker_mock = MagicMock(return_value=BackendInfo(tmp_path))
+ installation_can_install_from_dir = get_aiet_based_installation(
+ path_checker=path_checker_mock
+ )
+ assert installation_can_install_from_dir.supports(InstallFromPath(tmp_path)) is True
+
+ any_installation = get_aiet_based_installation()
+ assert any_installation.supports("unknown_install_type") is False # type: ignore
+
+
+def test_aiet_based_installation_install_wrong_type() -> None:
+ """Test that operation should fail if wrong install type provided."""
+ with pytest.raises(Exception, match="Unable to install wrong_install_type"):
+ aiet_runner_mock = MagicMock(spec=AIETRunner)
+ installation = get_aiet_based_installation(aiet_runner_mock)
+
+ installation.install("wrong_install_type") # type: ignore
+
+
+def test_aiet_based_installation_install_from_path(
+ tmp_path: Path, test_mlia_resources: Path
+) -> None:
+ """Test installation from the path."""
+ system_config = test_mlia_resources / "example_config.json"
+ system_config.touch()
+
+ sample_app = test_mlia_resources / "sample_app"
+ sample_app.mkdir()
+
+ dist_dir = tmp_path / "dist"
+ dist_dir.mkdir()
+
+ path_checker_mock = MagicMock(return_value=BackendInfo(dist_dir))
+
+ aiet_runner_mock = MagicMock(spec=AIETRunner)
+ installation = get_aiet_based_installation(
+ aiet_runner_mock=aiet_runner_mock,
+ path_checker=path_checker_mock,
+ apps_resources=[sample_app.name],
+ system_config="example_config.json",
+ )
+
+ assert installation.supports(InstallFromPath(dist_dir)) is True
+ installation.install(InstallFromPath(dist_dir))
+
+ aiet_runner_mock.install_system.assert_called_once()
+ aiet_runner_mock.install_application.assert_called_once_with(sample_app)
+
+
+@pytest.mark.parametrize("copy_source", [True, False])
+def test_aiet_based_installation_install_from_static_path(
+ tmp_path: Path, test_mlia_resources: Path, copy_source: bool
+) -> None:
+ """Test installation from the predefined path."""
+ system_config = test_mlia_resources / "example_config.json"
+ system_config.touch()
+
+ custom_system_config = test_mlia_resources / "custom_config.json"
+ custom_system_config.touch()
+
+ sample_app = test_mlia_resources / "sample_app"
+ sample_app.mkdir()
+
+ predefined_location = tmp_path / "backend"
+ predefined_location.mkdir()
+
+ predefined_location_file = predefined_location / "file.txt"
+ predefined_location_file.touch()
+
+ predefined_location_dir = predefined_location / "folder"
+ predefined_location_dir.mkdir()
+ nested_file = predefined_location_dir / "nested_file.txt"
+ nested_file.touch()
+
+ aiet_runner_mock = MagicMock(spec=AIETRunner)
+
+ def check_install_dir(install_dir: Path) -> None:
+ """Check content of the install dir."""
+ assert install_dir.is_dir()
+ files = list(install_dir.iterdir())
+
+ if copy_source:
+ assert len(files) == 3
+ assert all(install_dir / item in files for item in ["file.txt", "folder"])
+ assert (install_dir / "folder/nested_file.txt").is_file()
+ else:
+ assert len(files) == 1
+
+ assert install_dir / "custom_config.json" in files
+
+ aiet_runner_mock.install_system.side_effect = check_install_dir
+
+ installation = get_aiet_based_installation(
+ aiet_runner_mock=aiet_runner_mock,
+ path_checker=StaticPathChecker(
+ predefined_location,
+ ["file.txt"],
+ copy_source=copy_source,
+ system_config=str(custom_system_config),
+ ),
+ apps_resources=[sample_app.name],
+ system_config="example_config.json",
+ )
+
+ assert installation.supports(InstallFromPath(predefined_location)) is True
+ installation.install(InstallFromPath(predefined_location))
+
+ aiet_runner_mock.install_system.assert_called_once()
+ aiet_runner_mock.install_application.assert_called_once_with(sample_app)
+
+
+def create_sample_fvp_archive(tmp_path: Path) -> Path:
+ """Create sample FVP tar archive."""
+ fvp_archive_dir = tmp_path / "archive"
+ fvp_archive_dir.mkdir()
+
+ sample_file = fvp_archive_dir / "sample.txt"
+ sample_file.write_text("Sample file")
+
+ sample_dir = fvp_archive_dir / "sample_dir"
+ sample_dir.mkdir()
+
+ fvp_archive = tmp_path / "archive.tgz"
+ with tarfile.open(fvp_archive, "w:gz") as fvp_archive_tar:
+ fvp_archive_tar.add(fvp_archive_dir, arcname=fvp_archive_dir.name)
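+ # arcname keeps the top-level "archive/" prefix so that the "archive/..."
+ # paths used by PackagePathChecker below can be resolved.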
+
+ return fvp_archive
+
+
+def test_aiet_based_installation_download_and_install(
+ test_mlia_resources: Path, tmp_path: Path
+) -> None:
+ """Test downloading and installation process."""
+ fvp_archive = create_sample_fvp_archive(tmp_path)
+
+ system_config = test_mlia_resources / "example_config.json"
+ system_config.touch()
+
+ download_artifact_mock = MagicMock()
+ download_artifact_mock.download_to.return_value = fvp_archive
+
+ path_checker = PackagePathChecker(["archive/sample.txt"], "archive/sample_dir")
+
+ def installer(_eula_agreement: bool, dist_dir: Path) -> Path:
+ """Sample installer."""
+ return dist_dir
+
+ aiet_runner_mock = MagicMock(spec=AIETRunner)
+ installation = get_aiet_based_installation(
+ aiet_runner_mock,
+ download_artifact=download_artifact_mock,
+ backend_installer=installer,
+ path_checker=path_checker,
+ system_config="example_config.json",
+ )
+
+ installation.install(DownloadAndInstall())
+
+ aiet_runner_mock.install_system.assert_called_once()
+
+
+@pytest.mark.parametrize(
+ "dir_content, expected_result",
+ [
+ [
+ ["models/", "file1.txt", "file2.txt"],
+ "models",
+ ],
+ [
+ ["file1.txt", "file2.txt"],
+ None,
+ ],
+ [
+ ["models/", "file2.txt"],
+ None,
+ ],
+ ],
+)
+def test_corstone_path_checker_valid_path(
+ tmp_path: Path, dir_content: List[str], expected_result: Optional[str]
+) -> None:
+ """Test Corstone path checker valid scenario."""
+ path_checker = PackagePathChecker(["file1.txt", "file2.txt"], "models")
+
+ for item in dir_content:
+ if item.endswith("/"):
+ item_dir = tmp_path / item
+ item_dir.mkdir()
+ else:
+ item_file = tmp_path / item
+ item_file.touch()
+
+ result = path_checker(tmp_path)
+ expected = (
+ None if expected_result is None else BackendInfo(tmp_path / expected_result)
+ )
+
+ assert result == expected
+
+
+@pytest.mark.parametrize("system_config", [None, "system_config"])
+@pytest.mark.parametrize("copy_source", [True, False])
+def test_static_path_checker(
+ tmp_path: Path, copy_source: bool, system_config: Optional[str]
+) -> None:
+ """Test static path checker."""
+ static_checker = StaticPathChecker(
+ tmp_path, [], copy_source=copy_source, system_config=system_config
+ )
+ assert static_checker(tmp_path) == BackendInfo(
+ tmp_path, copy_source=copy_source, system_config=system_config
+ )
+
+
+def test_static_path_checker_not_valid_path(tmp_path: Path) -> None:
+ """Test static path checker should return None if path is not valid."""
+ static_checker = StaticPathChecker(tmp_path, ["file.txt"])
+ assert static_checker(tmp_path / "backend") is None
+
+
+def test_static_path_checker_not_valid_structure(tmp_path: Path) -> None:
+ """Test static path checker should return None if files are missing."""
+ static_checker = StaticPathChecker(tmp_path, ["file.txt"])
+ assert static_checker(tmp_path) is None
+
+ missing_file = tmp_path / "file.txt"
+ missing_file.touch()
+
+ assert static_checker(tmp_path) == BackendInfo(tmp_path, copy_source=False)
+
+
+def test_compound_path_checker(tmp_path: Path) -> None:
+ """Test compound path checker."""
+ path_checker_path_valid_path = MagicMock(return_value=BackendInfo(tmp_path))
+ path_checker_path_not_valid_path = MagicMock(return_value=None)
+
+ checker = CompoundPathChecker(
+ path_checker_path_valid_path, path_checker_path_not_valid_path
+ )
+ assert checker(tmp_path) == BackendInfo(tmp_path)
+
+ checker = CompoundPathChecker(path_checker_path_not_valid_path)
+ assert checker(tmp_path) is None
+
+
+@pytest.mark.parametrize(
+ "eula_agreement, expected_command",
+ [
+ [
+ True,
+ [
+ "./FVP_Corstone_SSE-300.sh",
+ "-q",
+ "-d",
+ "corstone-300",
+ ],
+ ],
+ [
+ False,
+ [
+ "./FVP_Corstone_SSE-300.sh",
+ "-q",
+ "-d",
+ "corstone-300",
+ "--nointeractive",
+ "--i-agree-to-the-contained-eula",
+ ],
+ ],
+ ],
+)
+def test_corstone_300_installer(
+ tmp_path: Path,
+ monkeypatch: pytest.MonkeyPatch,
+ eula_agreement: bool,
+ expected_command: List[str],
+) -> None:
+ """Test Corstone-300 installer."""
+ command_mock = MagicMock()
+
+ monkeypatch.setattr(
+ "mlia.tools.metadata.corstone.subprocess.check_call", command_mock
+ )
+ installer = Corstone300Installer()
+ result = installer(eula_agreement, tmp_path)
+
+ command_mock.assert_called_once_with(expected_command)
+ assert result == tmp_path / "corstone-300"
diff --git a/tests/mlia/test_tools_vela_wrapper.py b/tests/mlia/test_tools_vela_wrapper.py
new file mode 100644
index 0000000..875d2ff
--- /dev/null
+++ b/tests/mlia/test_tools_vela_wrapper.py
@@ -0,0 +1,285 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Tests for module tools/vela_wrapper."""
+from pathlib import Path
+from unittest.mock import MagicMock
+
+import pytest
+from ethosu.vela.compiler_driver import TensorAllocator
+from ethosu.vela.scheduler import OptimizationStrategy
+
+from mlia.devices.ethosu.config import EthosUConfiguration
+from mlia.tools.vela_wrapper import estimate_performance
+from mlia.tools.vela_wrapper import generate_supported_operators_report
+from mlia.tools.vela_wrapper import NpuSupported
+from mlia.tools.vela_wrapper import Operator
+from mlia.tools.vela_wrapper import Operators
+from mlia.tools.vela_wrapper import optimize_model
+from mlia.tools.vela_wrapper import OptimizedModel
+from mlia.tools.vela_wrapper import PerformanceMetrics
+from mlia.tools.vela_wrapper import supported_operators
+from mlia.tools.vela_wrapper import VelaCompiler
+from mlia.tools.vela_wrapper import VelaCompilerOptions
+from mlia.utils.proc import working_directory
+
+
+def test_default_vela_compiler() -> None:
+ """Test default Vela compiler instance."""
+ default_compiler_options = VelaCompilerOptions(accelerator_config="ethos-u55-256")
+ default_compiler = VelaCompiler(default_compiler_options)
+
+ assert default_compiler.config_files is None
+ assert default_compiler.system_config == "internal-default"
+ assert default_compiler.memory_mode == "internal-default"
+ assert default_compiler.accelerator_config == "ethos-u55-256"
+ assert default_compiler.max_block_dependency == 3
+ assert default_compiler.arena_cache_size is None
+ assert default_compiler.tensor_allocator == TensorAllocator.HillClimb
+ assert default_compiler.cpu_tensor_alignment == 16
+ assert default_compiler.optimization_strategy == OptimizationStrategy.Performance
+ assert default_compiler.output_dir is None
+
+ assert default_compiler.get_config() == {
+ "accelerator_config": "ethos-u55-256",
+ "system_config": "internal-default",
+ "core_clock": 500000000.0,
+ "axi0_port": "Sram",
+ "axi1_port": "OffChipFlash",
+ "memory_mode": "internal-default",
+ "const_mem_area": "Axi1",
+ "arena_mem_area": "Axi0",
+ "cache_mem_area": "Axi0",
+ "arena_cache_size": 4294967296,
+ "permanent_storage_mem_area": "OffChipFlash",
+ "feature_map_storage_mem_area": "Sram",
+ "fast_storage_mem_area": "Sram",
+ "memory_area": {
+ "Sram": {
+ "clock_scales": 1.0,
+ "burst_length": 32,
+ "read_latency": 32,
+ "write_latency": 32,
+ },
+ "Dram": {
+ "clock_scales": 1.0,
+ "burst_length": 1,
+ "read_latency": 0,
+ "write_latency": 0,
+ },
+ "OnChipFlash": {
+ "clock_scales": 1.0,
+ "burst_length": 1,
+ "read_latency": 0,
+ "write_latency": 0,
+ },
+ "OffChipFlash": {
+ "clock_scales": 0.125,
+ "burst_length": 128,
+ "read_latency": 64,
+ "write_latency": 64,
+ },
+ },
+ }
+
+
+def test_vela_compiler_with_parameters(test_resources_path: Path) -> None:
+ """Test creation of Vela compiler instance with non-default params."""
+ vela_ini_path = str(test_resources_path / "vela/sample_vela.ini")
+
+ compiler_options = VelaCompilerOptions(
+ config_files=vela_ini_path,
+ system_config="Ethos_U65_High_End",
+ memory_mode="Shared_Sram",
+ accelerator_config="ethos-u65-256",
+ max_block_dependency=1,
+ arena_cache_size=10,
+ tensor_allocator="Greedy",
+ cpu_tensor_alignment=4,
+ optimization_strategy="Size",
+ output_dir="output",
+ )
+ compiler = VelaCompiler(compiler_options)
+
+ assert compiler.config_files == vela_ini_path
+ assert compiler.system_config == "Ethos_U65_High_End"
+ assert compiler.memory_mode == "Shared_Sram"
+ assert compiler.accelerator_config == "ethos-u65-256"
+ assert compiler.max_block_dependency == 1
+ assert compiler.arena_cache_size == 10
+ assert compiler.tensor_allocator == TensorAllocator.Greedy
+ assert compiler.cpu_tensor_alignment == 4
+ assert compiler.optimization_strategy == OptimizationStrategy.Size
+ assert compiler.output_dir == "output"
+
+ assert compiler.get_config() == {
+ "accelerator_config": "ethos-u65-256",
+ "system_config": "Ethos_U65_High_End",
+ "core_clock": 1000000000.0,
+ "axi0_port": "Sram",
+ "axi1_port": "Dram",
+ "memory_mode": "Shared_Sram",
+ "const_mem_area": "Axi1",
+ "arena_mem_area": "Axi0",
+ "cache_mem_area": "Axi0",
+ "arena_cache_size": 10,
+ "permanent_storage_mem_area": "Dram",
+ "feature_map_storage_mem_area": "Sram",
+ "fast_storage_mem_area": "Sram",
+ "memory_area": {
+ "Sram": {
+ "clock_scales": 1.0,
+ "burst_length": 32,
+ "read_latency": 32,
+ "write_latency": 32,
+ },
+ "Dram": {
+ "clock_scales": 0.234375,
+ "burst_length": 128,
+ "read_latency": 500,
+ "write_latency": 250,
+ },
+ "OnChipFlash": {
+ "clock_scales": 1.0,
+ "burst_length": 1,
+ "read_latency": 0,
+ "write_latency": 0,
+ },
+ "OffChipFlash": {
+ "clock_scales": 1.0,
+ "burst_length": 1,
+ "read_latency": 0,
+ "write_latency": 0,
+ },
+ },
+ }
+
+
+def test_compile_model(test_tflite_model: Path) -> None:
+ """Test model optimization."""
+ compiler = VelaCompiler(EthosUConfiguration("ethos-u55-256").compiler_options)
+
+ optimized_model = compiler.compile_model(test_tflite_model)
+ assert isinstance(optimized_model, OptimizedModel)
+
+
+def test_optimize_model(tmp_path: Path, test_tflite_model: Path) -> None:
+ """Test model optimization and saving into file."""
+ tmp_file = tmp_path / "temp.tflite"
+
+ device = EthosUConfiguration("ethos-u55-256")
+ optimize_model(test_tflite_model, device.compiler_options, tmp_file.absolute())
+
+ assert tmp_file.is_file()
+ assert tmp_file.stat().st_size > 0
+
+
+@pytest.mark.parametrize(
+ "model, expected_ops",
+ [
+ (
+ "test_model.tflite",
+ Operators(
+ ops=[
+ Operator(
+ name="sequential/conv1/Relu;sequential/conv1/BiasAdd;"
+ "sequential/conv2/Conv2D;sequential/conv1/Conv2D",
+ op_type="CONV_2D",
+ run_on_npu=NpuSupported(supported=True, reasons=[]),
+ ),
+ Operator(
+ name="sequential/conv2/Relu;sequential/conv2/BiasAdd;"
+ "sequential/conv2/Conv2D",
+ op_type="CONV_2D",
+ run_on_npu=NpuSupported(supported=True, reasons=[]),
+ ),
+ Operator(
+ name="sequential/max_pooling2d/MaxPool",
+ op_type="MAX_POOL_2D",
+ run_on_npu=NpuSupported(supported=True, reasons=[]),
+ ),
+ Operator(
+ name="sequential/flatten/Reshape",
+ op_type="RESHAPE",
+ run_on_npu=NpuSupported(supported=True, reasons=[]),
+ ),
+ Operator(
+ name="Identity",
+ op_type="FULLY_CONNECTED",
+ run_on_npu=NpuSupported(supported=True, reasons=[]),
+ ),
+ ]
+ ),
+ )
+ ],
+)
+def test_operators(test_models_path: Path, model: str, expected_ops: Operators) -> None:
+ """Test operators function."""
+ device = EthosUConfiguration("ethos-u55-256")
+
+ operators = supported_operators(test_models_path / model, device.compiler_options)
+ for expected, actual in zip(expected_ops.ops, operators.ops):
+ # do not compare names as they could be different on each model generation
+ assert expected.op_type == actual.op_type
+ assert expected.run_on_npu == actual.run_on_npu
+
+
+def test_estimate_performance(test_tflite_model: Path) -> None:
+ """Test getting performance estimations."""
+ device = EthosUConfiguration("ethos-u55-256")
+ perf_metrics = estimate_performance(test_tflite_model, device.compiler_options)
+
+ assert isinstance(perf_metrics, PerformanceMetrics)
+
+
+def test_estimate_performance_already_optimized(
+ tmp_path: Path, test_tflite_model: Path
+) -> None:
+ """Test that performance estimation should fail for already optimized model."""
+ device = EthosUConfiguration("ethos-u55-256")
+
+ optimized_model_path = tmp_path / "optimized_model.tflite"
+
+ optimize_model(test_tflite_model, device.compiler_options, optimized_model_path)
+
+ with pytest.raises(
+ Exception, match="Unable to estimate performance for the given optimized model"
+ ):
+ estimate_performance(optimized_model_path, device.compiler_options)
+
+
+def test_generate_supported_operators_report(tmp_path: Path) -> None:
+ """Test generating supported operators report."""
+ with working_directory(tmp_path):
+ generate_supported_operators_report()
+
+ md_file = tmp_path / "SUPPORTED_OPS.md"
+ assert md_file.is_file()
+ assert md_file.stat().st_size > 0
+
+
+def test_read_invalid_model(test_tflite_invalid_model: Path) -> None:
+ """Test that reading invalid model should fail with exception."""
+ with pytest.raises(
+ Exception, match=f"Unable to read model {test_tflite_invalid_model}"
+ ):
+ device = EthosUConfiguration("ethos-u55-256")
+ estimate_performance(test_tflite_invalid_model, device.compiler_options)
+
+
+def test_compile_invalid_model(
+ test_tflite_model: Path, monkeypatch: pytest.MonkeyPatch, tmp_path: Path
+) -> None:
+ """Test that if model could not be compiled then correct exception raised."""
+ mock_compiler = MagicMock()
+ mock_compiler.side_effect = Exception("Bad model!")
+
+ monkeypatch.setattr("mlia.tools.vela_wrapper.compiler_driver", mock_compiler)
+
+ model_path = tmp_path / "optimized_model.tflite"
+ with pytest.raises(
+ Exception, match="Model could not be optimized with Vela compiler"
+ ):
+ device = EthosUConfiguration("ethos-u55-256")
+ optimize_model(test_tflite_model, device.compiler_options, model_path)
+
+ assert not model_path.exists()
diff --git a/tests/mlia/test_utils_console.py b/tests/mlia/test_utils_console.py
new file mode 100644
index 0000000..36975f8
--- /dev/null
+++ b/tests/mlia/test_utils_console.py
@@ -0,0 +1,100 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Tests for console utility functions."""
+from typing import Iterable
+from typing import List
+from typing import Optional
+
+import pytest
+
+from mlia.utils.console import apply_style
+from mlia.utils.console import create_section_header
+from mlia.utils.console import produce_table
+from mlia.utils.console import remove_ascii_codes
+
+
+@pytest.mark.parametrize(
+ "rows, headers, table_style, expected_result",
+ [
+ [[], [], "no_borders", ""],
+ [
+ [["1", "2", "3"]],
+ ["Col 1", "Col 2", "Col 3"],
+ "default",
+ """
+┌───────┬───────┬───────┐
+│ Col 1 │ Col 2 │ Col 3 │
+╞═══════╪═══════╪═══════╡
+│ 1 │ 2 │ 3 │
+└───────┴───────┴───────┘
+""".strip(),
+ ],
+ [
+ [["1", "2", "3"]],
+ ["Col 1", "Col 2", "Col 3"],
+ "nested",
+ "Col 1 Col 2 Col 3 \n \n1 2 3",
+ ],
+ [
+ [["1", "2", "3"]],
+ ["Col 1", "Col 2", "Col 3"],
+ "no_borders",
+ " Col 1 Col 2 Col 3 \n 1 2 3",
+ ],
+ ],
+)
+def test_produce_table(
+ rows: Iterable, headers: Optional[List[str]], table_style: str, expected_result: str
+) -> None:
+ """Test produce_table function."""
+ result = produce_table(rows, headers, table_style)
+ assert remove_ascii_codes(result) == expected_result
+
+
+def test_produce_table_unknown_style() -> None:
+ """Test that function should fail if unknown style provided."""
+ with pytest.raises(Exception, match="Unsupported table style unknown_style"):
+ produce_table([["1", "2", "3"]], [], "unknown_style")
+
+
+@pytest.mark.parametrize(
+ "value, expected_result",
+ [
+ ["some text", "some text"],
+ ["\033[32msome text\033[0m", "some text"],
+ ],
+)
+def test_remove_ascii_codes(value: str, expected_result: str) -> None:
+ """Test remove_ascii_codes function."""
+ assert remove_ascii_codes(value) == expected_result
+
+
+def test_apply_style() -> None:
+ """Test function apply_style."""
+ assert apply_style("some text", "green") == "[green]some text"
+
+
+@pytest.mark.parametrize(
+ "section_header, expected_result",
+ [
+ [
+ "Section header",
+ "\n--- Section header -------------------------------"
+ "------------------------------\n",
+ ],
+ [
+ "",
+ f"\n{'-' * 80}\n",
+ ],
+ ],
+)
+def test_create_section_header(section_header: str, expected_result: str) -> None:
+ """Test function test_create_section."""
+ assert create_section_header(section_header) == expected_result
+
+
+def test_create_section_header_too_long_value() -> None:
+ """Test that header could not be created for the too long section names."""
+ section_name = "section name" * 100
+ with pytest.raises(ValueError, match="Section name too long"):
+ create_section_header(section_name)
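+
+
+# A minimal sketch of the helper under test, assuming the 80-character
+# target width inferred from the expected values above; illustrative, not
+# necessarily the actual implementation.
+def create_section_header_sketch(section_name: str = "", length: int = 80) -> str:
+ """Return a section header padded with dashes to the target width."""
+ if not section_name:
+ return f"\n{'-' * length}\n"
+
+ fill = length - len(section_name) - 5 # room for "--- " and a trailing space
+ if fill <= 0:
+ raise ValueError("Section name too long")
+
+ return f"\n--- {section_name} {'-' * fill}\n"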
diff --git a/tests/mlia/test_utils_download.py b/tests/mlia/test_utils_download.py
new file mode 100644
index 0000000..4f8e2dc
--- /dev/null
+++ b/tests/mlia/test_utils_download.py
@@ -0,0 +1,147 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Tests for download functionality."""
+from contextlib import ExitStack as does_not_raise
+from pathlib import Path
+from typing import Any
+from typing import Iterable
+from typing import Optional
+from unittest.mock import MagicMock
+from unittest.mock import PropertyMock
+
+import pytest
+import requests
+
+from mlia.utils.download import download
+from mlia.utils.download import DownloadArtifact
+
+
+def response_mock(
+ content_length: Optional[str], content_chunks: Iterable[bytes]
+) -> MagicMock:
+ """Mock response object."""
+ mock = MagicMock(spec=requests.Response)
+ mock.__enter__.return_value = mock
+
+ type(mock).headers = PropertyMock(return_value={"Content-Length": content_length})
+ mock.iter_content.return_value = content_chunks
+
+ return mock
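+
+
+# Example usage of response_mock (illustrative): the download code under
+# test presumably calls requests.get(..., stream=True), enters the response
+# as a context manager, reads the "Content-Length" header and iterates over
+# iter_content() chunks, which is exactly the surface the mock provides.
+#
+# with response_mock("5", [b"hello"]) as response:
+#     assert response.headers["Content-Length"] == "5"
+#     assert list(response.iter_content()) == [b"hello"]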
+
+
+@pytest.mark.parametrize("show_progress", [True, False])
+@pytest.mark.parametrize(
+ "content_length, content_chunks, label",
+ [
+ [
+ "5",
+ [bytes(range(5))],
+ "Downloading artifact",
+ ],
+ [
+ "10",
+ [bytes(range(5)), bytes(range(5))],
+ None,
+ ],
+ [
+ None,
+ [bytes(range(5))],
+ "Downlading no size",
+ ],
+ [
+ "abc",
+ [bytes(range(5))],
+ "Downloading wrong size",
+ ],
+ ],
+)
+def test_download(
+ show_progress: bool,
+ tmp_path: Path,
+ monkeypatch: pytest.MonkeyPatch,
+ content_length: Optional[str],
+ content_chunks: Iterable[bytes],
+ label: Optional[str],
+) -> None:
+ """Test function download."""
+ monkeypatch.setattr(
+ "mlia.utils.download.requests.get",
+ MagicMock(return_value=response_mock(content_length, content_chunks)),
+ )
+
+ dest = tmp_path / "sample.bin"
+ download("some_url", dest, show_progress=show_progress, label=label)
+
+ assert dest.is_file()
+ assert dest.read_bytes() == bytes(
+ byte for chunk in content_chunks for byte in chunk
+ )
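+
+ # For context, the core of the function under test is presumed to be a
+ # plain streaming copy (sketch, not the actual implementation):
+ #
+ # with requests.get(url, stream=True) as response:
+ #     with open(dest, "wb") as file:
+ #         for chunk in response.iter_content():
+ #             file.write(chunk)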
+
+
+@pytest.mark.parametrize(
+ "content_length, content_chunks, sha256_hash, expected_error",
+ [
+ [
+ "10",
+ [bytes(range(10))],
+ "1f825aa2f0020ef7cf91dfa30da4668d791c5d4824fc8e41354b89ec05795ab3",
+ does_not_raise(),
+ ],
+ [
+ "10",
+ [bytes(range(10))],
+ "bad_hash",
+ pytest.raises(ValueError, match="Digests do not match"),
+ ],
+ ],
+)
+def test_download_artifact_download_to(
+ monkeypatch: pytest.MonkeyPatch,
+ content_length: Optional[str],
+ content_chunks: Iterable[bytes],
+ sha256_hash: str,
+ expected_error: Any,
+ tmp_path: Path,
+) -> None:
+ """Test artifact downloading."""
+ monkeypatch.setattr(
+ "mlia.utils.download.requests.get",
+ MagicMock(return_value=response_mock(content_length, content_chunks)),
+ )
+
+ with expected_error:
+ artifact = DownloadArtifact(
+ "test_artifact",
+ "some_url",
+ "artifact_filename",
+ "1.0",
+ sha256_hash,
+ )
+
+ dest = artifact.download_to(tmp_path)
+ assert isinstance(dest, Path)
+ assert dest.name == "artifact_filename"
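+
+ # The hash check exercised above is presumably a streaming comparison
+ # (sketch under that assumption; names here are hypothetical):
+ #
+ # digest = hashlib.sha256()
+ # for chunk in response.iter_content():
+ #     digest.update(chunk)
+ # if digest.hexdigest() != expected_sha256:
+ #     raise ValueError("Digests do not match")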
+
+
+def test_download_artifact_unable_to_overwrite(
+ monkeypatch: pytest.MonkeyPatch, tmp_path: Path
+) -> None:
+ """Test that download process cannot overwrite file."""
+ monkeypatch.setattr(
+ "mlia.utils.download.requests.get",
+ MagicMock(return_value=response_mock("10", [bytes(range(10))])),
+ )
+
+ artifact = DownloadArtifact(
+ "test_artifact",
+ "some_url",
+ "artifact_filename",
+ "1.0",
+ "sha256_hash",
+ )
+
+ existing_file = tmp_path / "artifact_filename"
+ existing_file.touch()
+
+ with pytest.raises(ValueError, match=f"{existing_file} already exists"):
+ artifact.download_to(tmp_path)
diff --git a/tests/mlia/test_utils_filesystem.py b/tests/mlia/test_utils_filesystem.py
new file mode 100644
index 0000000..4d8d955
--- /dev/null
+++ b/tests/mlia/test_utils_filesystem.py
@@ -0,0 +1,166 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Tests for the filesystem module."""
+import contextlib
+import json
+from pathlib import Path
+from unittest.mock import MagicMock
+
+import pytest
+
+from mlia.utils.filesystem import all_files_exist
+from mlia.utils.filesystem import all_paths_valid
+from mlia.utils.filesystem import copy_all
+from mlia.utils.filesystem import get_mlia_resources
+from mlia.utils.filesystem import get_profile
+from mlia.utils.filesystem import get_profiles_data
+from mlia.utils.filesystem import get_profiles_file
+from mlia.utils.filesystem import get_supported_profile_names
+from mlia.utils.filesystem import get_vela_config
+from mlia.utils.filesystem import sha256
+from mlia.utils.filesystem import temp_directory
+from mlia.utils.filesystem import temp_file
+
+
+def test_get_mlia_resources() -> None:
+ """Test resources getter."""
+ assert get_mlia_resources().is_dir()
+
+
+def test_get_vela_config() -> None:
+ """Test Vela config files getter."""
+ assert get_vela_config().is_file()
+ assert get_vela_config().name == "vela.ini"
+
+
+def test_profiles_file() -> None:
+ """Test profiles file getter."""
+ assert get_profiles_file().is_file()
+ assert get_profiles_file().name == "profiles.json"
+
+
+def test_profiles_data() -> None:
+ """Test profiles data getter."""
+ assert list(get_profiles_data().keys()) == [
+ "ethos-u55-256",
+ "ethos-u55-128",
+ "ethos-u65-512",
+ ]
+
+
+def test_profiles_data_wrong_format(
+ monkeypatch: pytest.MonkeyPatch, tmp_path: Path
+) -> None:
+ """Test if profile data has wrong format."""
+ wrong_profile_data = tmp_path / "bad.json"
+ with open(wrong_profile_data, "w", encoding="utf-8") as file:
+ json.dump([], file)
+
+ monkeypatch.setattr(
+ "mlia.utils.filesystem.get_profiles_file",
+ MagicMock(return_value=wrong_profile_data),
+ )
+
+ with pytest.raises(Exception, match="Profiles data format is not valid"):
+ get_profiles_data()
+
+
+def test_get_supported_profile_names() -> None:
+ """Test profile names getter."""
+ assert list(get_supported_profile_names()) == [
+ "ethos-u55-256",
+ "ethos-u55-128",
+ "ethos-u65-512",
+ ]
+
+
+def test_get_profile() -> None:
+ """Test getting profile data."""
+ assert get_profile("ethos-u55-256") == {
+ "target": "ethos-u55",
+ "mac": 256,
+ "system_config": "Ethos_U55_High_End_Embedded",
+ "memory_mode": "Shared_Sram",
+ }
+
+ with pytest.raises(Exception, match="Unable to find target profile unknown"):
+ get_profile("unknown")
+
+
+@pytest.mark.parametrize("raise_exception", [True, False])
+def test_temp_file(raise_exception: bool) -> None:
+ """Test temp_file context manager."""
+ with contextlib.suppress(Exception):
+ with temp_file() as tmp_path:
+ assert tmp_path.is_file()
+
+ if raise_exception:
+ raise Exception("Error!")
+
+ assert not tmp_path.exists()
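+
+
+# A minimal sketch of a temp_file-style context manager, assuming it is
+# built on tempfile and cleans up even when an exception is raised
+# (illustrative, not the actual implementation):
+@contextlib.contextmanager
+def temp_file_sketch():
+ """Yield a temporary file path and remove the file on exit (sketch)."""
+ import os
+ import tempfile
+
+ file_descriptor, path = tempfile.mkstemp()
+ os.close(file_descriptor)
+ try:
+ yield Path(path)
+ finally:
+ Path(path).unlink(missing_ok=True)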
+
+
+def test_sha256(tmp_path: Path) -> None:
+ """Test getting sha256 hash."""
+ sample = tmp_path / "sample.txt"
+
+ with open(sample, "w", encoding="utf-8") as file:
+ file.write("123")
+
+ expected_hash = "a665a45920422f9d417e4867efdc4fb8a04a1f3fff1fa07e998e86f7f7a27ae3"
+ assert sha256(sample) == expected_hash
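+
+ # For reference, such a helper is typically a chunked hashlib read
+ # (sketch under that assumption, not necessarily the mlia code):
+ #
+ # digest = hashlib.sha256()
+ # with open(path, "rb") as file:
+ #     for chunk in iter(lambda: file.read(8192), b""):
+ #         digest.update(chunk)
+ # return digest.hexdigest()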
+
+
+def test_temp_dir_context_manager() -> None:
+ """Test context manager for temporary directories."""
+ with temp_directory() as tmpdir:
+ assert isinstance(tmpdir, Path)
+ assert tmpdir.is_dir()
+
+ assert not tmpdir.exists()
+
+
+def test_all_files_exist(tmp_path: Path) -> None:
+ """Test function all_files_exist."""
+ sample1 = tmp_path / "sample1.txt"
+ sample1.touch()
+
+ sample2 = tmp_path / "sample2.txt"
+ sample2.touch()
+
+ sample3 = tmp_path / "sample3.txt"
+
+ assert all_files_exist([sample1, sample2]) is True
+ assert all_files_exist([sample1, sample2, sample3]) is False
+
+
+def test_all_paths_valid(tmp_path: Path) -> None:
+ """Test function all_paths_valid."""
+ sample = tmp_path / "sample.txt"
+ sample.touch()
+
+ sample_dir = tmp_path / "sample_dir"
+ sample_dir.mkdir()
+
+ unknown = tmp_path / "unknown.txt"
+
+ assert all_paths_valid([sample, sample_dir]) is True
+ assert all_paths_valid([sample, sample_dir, unknown]) is False
+
+
+def test_copy_all(tmp_path: Path) -> None:
+ """Test function copy_all."""
+ sample = tmp_path / "sample1.txt"
+ sample.touch()
+
+ sample_dir = tmp_path / "sample_dir"
+ sample_dir.mkdir()
+
+ sample_nested_file = sample_dir / "sample_nested.txt"
+ sample_nested_file.touch()
+
+ dest_dir = tmp_path / "dest"
+ copy_all(sample, sample_dir, dest=dest_dir)
+
+ assert (dest_dir / sample.name).is_file()
+ assert (dest_dir / sample_nested_file.name).is_file()
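+
+
+# A minimal sketch of a copy_all-style helper, assuming shutil semantics
+# that match the assertions above (illustrative, not the actual code):
+#
+# def copy_all_sketch(*paths: Path, dest: Path) -> None:
+#     dest.mkdir(parents=True, exist_ok=True)
+#     for path in paths:
+#         if path.is_file():
+#             shutil.copy2(path, dest)
+#         elif path.is_dir():
+#             shutil.copytree(path, dest, dirs_exist_ok=True)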
diff --git a/tests/mlia/test_utils_logging.py b/tests/mlia/test_utils_logging.py
new file mode 100644
index 0000000..75ebceb
--- /dev/null
+++ b/tests/mlia/test_utils_logging.py
@@ -0,0 +1,63 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Tests for the logging utility functions."""
+import logging
+import sys
+from contextlib import ExitStack as does_not_raise
+from pathlib import Path
+from typing import Any
+from typing import Optional
+
+import pytest
+
+from mlia.cli.logging import create_log_handler
+
+
+@pytest.mark.parametrize(
+ "file_path, stream, log_filter, delay, expected_error, expected_class",
+ [
+ (
+ "test.log",
+ None,
+ None,
+ True,
+ does_not_raise(),
+ logging.FileHandler,
+ ),
+ (
+ None,
+ sys.stdout,
+ None,
+ None,
+ does_not_raise(),
+ logging.StreamHandler,
+ ),
+ (
+ None,
+ None,
+ None,
+ None,
+ pytest.raises(Exception, match="Unable to create logging handler"),
+ None,
+ ),
+ ],
+)
+def test_create_log_handler(
+ file_path: Optional[Path],
+ stream: Optional[Any],
+ log_filter: Optional[logging.Filter],
+ delay: bool,
+ expected_error: Any,
+ expected_class: type,
+) -> None:
+ """Test function test_create_log_handler."""
+ with expected_error:
+ handler = create_log_handler(
+ file_path=file_path,
+ stream=stream,
+ log_level=logging.INFO,
+ log_format="%(name)s - %(message)s",
+ log_filter=log_filter,
+ delay=delay,
+ )
+ assert isinstance(handler, expected_class)
diff --git a/tests/mlia/test_utils_misc.py b/tests/mlia/test_utils_misc.py
new file mode 100644
index 0000000..011d09e
--- /dev/null
+++ b/tests/mlia/test_utils_misc.py
@@ -0,0 +1,25 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Tests for misc util functions."""
+from unittest.mock import MagicMock
+
+import pytest
+
+from mlia.utils.misc import yes
+
+
+@pytest.mark.parametrize(
+ "response, expected_result",
+ [
+ ["Y", True],
+ ["y", True],
+ ["N", False],
+ ["n", False],
+ ],
+)
+def test_yes(
+ monkeypatch: pytest.MonkeyPatch, expected_result: bool, response: str
+) -> None:
+ """Test yes function."""
+ monkeypatch.setattr("builtins.input", MagicMock(return_value=response))
+ assert yes("some_prompt") == expected_result
diff --git a/tests/mlia/test_utils_proc.py b/tests/mlia/test_utils_proc.py
new file mode 100644
index 0000000..8316ca5
--- /dev/null
+++ b/tests/mlia/test_utils_proc.py
@@ -0,0 +1,149 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Tests for the module utils/proc."""
+import signal
+import subprocess
+import time
+from pathlib import Path
+from unittest.mock import MagicMock
+
+import pytest
+
+from mlia.utils.proc import CommandExecutor
+from mlia.utils.proc import working_directory
+
+
+class TestCommandExecutor:
+ """Tests for class CommandExecutor."""
+
+ @staticmethod
+ def test_execute() -> None:
+ """Test command execution."""
+ executor = CommandExecutor()
+
+ retcode, stdout, stderr = executor.execute(["echo", "hello world!"])
+ assert retcode == 0
+ assert stdout.decode().strip() == "hello world!"
+ assert stderr.decode() == ""
+
+ @staticmethod
+ def test_submit() -> None:
+ """Test command submittion."""
+ executor = CommandExecutor()
+
+ running_command = executor.submit(["sleep", "10"])
+ assert running_command.is_alive() is True
+ assert running_command.exit_code() is None
+
+ running_command.kill()
+ for _ in range(3):
+ time.sleep(0.5)
+ if not running_command.is_alive():
+ break
+
+ assert running_command.is_alive() is False
+ assert running_command.exit_code() == -9
+
+ with pytest.raises(subprocess.CalledProcessError):
+ executor.execute(["sleep", "-1"])
+
+ @staticmethod
+ @pytest.mark.parametrize("wait", [True, False])
+ def test_stop(wait: bool) -> None:
+ """Test command termination."""
+ executor = CommandExecutor()
+
+ running_command = executor.submit(["sleep", "10"])
+ running_command.stop(wait=wait)
+
+ if wait:
+ assert running_command.is_alive() is False
+
+ @staticmethod
+ def test_unable_to_stop(monkeypatch: pytest.MonkeyPatch) -> None:
+ """Test when command could not be stopped."""
+ running_command_mock = MagicMock()
+ running_command_mock.poll.return_value = None
+
+ monkeypatch.setattr(
+ "mlia.utils.proc.subprocess.Popen",
+ MagicMock(return_value=running_command_mock),
+ )
+
+ with pytest.raises(Exception, match="Unable to stop running command"):
+ executor = CommandExecutor()
+ running_command = executor.submit(["sleep", "10"])
+
+ running_command.stop(num_of_attempts=1, interval=0.1)
+
+ running_command_mock.send_signal.assert_called_once_with(signal.SIGINT)
+
+ @staticmethod
+ def test_stop_after_several_attempts(monkeypatch: pytest.MonkeyPatch) -> None:
+ """Test when command could be stopped after several attempts."""
+ running_command_mock = MagicMock()
+ running_command_mock.poll.side_effect = [None, 0]
+
+ monkeypatch.setattr(
+ "mlia.utils.proc.subprocess.Popen",
+ MagicMock(return_value=running_command_mock),
+ )
+
+ executor = CommandExecutor()
+ running_command = executor.submit(["sleep", "10"])
+
+ running_command.stop(num_of_attempts=1, interval=0.1)
+ running_command_mock.send_signal.assert_called_once_with(signal.SIGINT)
+
+ @staticmethod
+ def test_send_signal() -> None:
+ """Test sending signal."""
+ executor = CommandExecutor()
+ running_command = executor.submit(["sleep", "10"])
+ running_command.send_signal(signal.SIGINT)
+
+ # Wait a bit for the signal to be processed.
+ time.sleep(1)
+
+ assert running_command.is_alive() is False
+ assert running_command.exit_code() == -2
+
+ @staticmethod
+ @pytest.mark.parametrize(
+ "redirect_output, expected_output", [[True, "hello\n"], [False, ""]]
+ )
+ def test_wait(
+ capsys: pytest.CaptureFixture, redirect_output: bool, expected_output: str
+ ) -> None:
+ """Test wait completion functionality."""
+ executor = CommandExecutor()
+
+ running_command = executor.submit(["echo", "hello"])
+ running_command.wait(redirect_output=redirect_output)
+
+ out, _ = capsys.readouterr()
+ assert out == expected_output
+
+
+@pytest.mark.parametrize(
+ "should_exist, create_dir",
+ [
+ [True, False],
+ [False, True],
+ ],
+)
+def test_working_directory_context_manager(
+ tmp_path: Path, should_exist: bool, create_dir: bool
+) -> None:
+ """Test working_directory context manager."""
+ prev_wd = Path.cwd()
+
+ working_dir = tmp_path / "work_dir"
+ if should_exist:
+ working_dir.mkdir()
+
+ with working_directory(working_dir, create_dir=create_dir) as current_working_dir:
+ assert current_working_dir.is_dir()
+ assert Path.cwd() == current_working_dir
+
+ assert Path.cwd() == prev_wd
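+
+
+# A minimal sketch of a working_directory-style context manager, assuming
+# os.chdir semantics with restoration on exit (illustrative only):
+#
+# @contextmanager
+# def working_directory_sketch(path: Path, create_dir: bool = False):
+#     previous = Path.cwd()
+#     if create_dir:
+#         path.mkdir()
+#     os.chdir(path)
+#     try:
+#         yield path
+#     finally:
+#         os.chdir(previous)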
diff --git a/tests/mlia/test_utils_types.py b/tests/mlia/test_utils_types.py
new file mode 100644
index 0000000..4909efe
--- /dev/null
+++ b/tests/mlia/test_utils_types.py
@@ -0,0 +1,77 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Tests for the types related utility functions."""
+from typing import Any
+from typing import Iterable
+from typing import Optional
+
+import pytest
+
+from mlia.utils.types import is_list_of
+from mlia.utils.types import is_number
+from mlia.utils.types import only_one_selected
+from mlia.utils.types import parse_int
+
+
+@pytest.mark.parametrize(
+ "value, expected_result",
+ [
+ ["", False],
+ ["abc", False],
+ ["123", True],
+ ["123.1", True],
+ ["-123", True],
+ ["-123.1", True],
+ ["0", True],
+ ["1.e10", True],
+ ],
+)
+def test_is_number(value: str, expected_result: bool) -> None:
+ """Test function is_number."""
+ assert is_number(value) == expected_result
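+
+ # Such a predicate is commonly just a float() conversion guarded by
+ # try/except, which matches every case above (sketch, not necessarily
+ # the mlia implementation):
+ #
+ # try:
+ #     float(value)
+ # except ValueError:
+ #     return False
+ # return True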
+
+
+@pytest.mark.parametrize(
+ "data, cls, elem_num, expected_result",
+ [
+ [(1, 2), int, 2, True],
+ [[1, 2], int, 2, True],
+ [[1, 2], int, 3, False],
+ [["1", "2", "3"], str, None, True],
+ [["1", "2", "3"], int, None, False],
+ ],
+)
+def test_is_list(
+ data: Any, cls: type, elem_num: Optional[int], expected_result: bool
+) -> None:
+ """Test function is_list."""
+ assert is_list_of(data, cls, elem_num) == expected_result
+
+
+@pytest.mark.parametrize(
+ "options, expected_result",
+ [
+ [[True], True],
+ [[False], False],
+ [[True, True, False, False], False],
+ ],
+)
+def test_only_one_selected(options: Iterable[bool], expected_result: bool) -> None:
+ """Test function only_one_selected."""
+ assert only_one_selected(*options) == expected_result
+
+
+@pytest.mark.parametrize(
+ "value, default, expected_int",
+ [
+ ["1", None, 1],
+ ["abc", 123, 123],
+ [None, None, None],
+ [None, 11, 11],
+ ],
+)
+def test_parse_int(
+ value: Any, default: Optional[int], expected_int: Optional[int]
+) -> None:
+ """Test function parse_int."""
+ assert parse_int(value, default) == expected_int
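+
+ # Consistent with the cases above, parse_int presumably falls back to
+ # the default on any conversion failure (sketch only):
+ #
+ # try:
+ #     return int(value)
+ # except (TypeError, ValueError):
+ #     return default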
diff --git a/tests/mlia/utils/__init__.py b/tests/mlia/utils/__init__.py
new file mode 100644
index 0000000..27166ef
--- /dev/null
+++ b/tests/mlia/utils/__init__.py
@@ -0,0 +1,3 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Test utils module."""
diff --git a/tests/mlia/utils/common.py b/tests/mlia/utils/common.py
new file mode 100644
index 0000000..4313cde
--- /dev/null
+++ b/tests/mlia/utils/common.py
@@ -0,0 +1,32 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Common test utils module."""
+from typing import Tuple
+
+import numpy as np
+import tensorflow as tf
+
+
+def get_dataset() -> Tuple[np.ndarray, np.ndarray]:
+ """Return sample dataset."""
+ mnist = tf.keras.datasets.mnist
+ (x_train, y_train), _ = mnist.load_data()
+ x_train = x_train / 255.0
+
+ # Use a tiny subset of the 60000 examples to keep the unit tests fast.
+ x_train = x_train[0:1]
+ y_train = y_train[0:1]
+
+ return x_train, y_train
+
+
+def train_model(model: tf.keras.Model) -> None:
+ """Train model using sample dataset."""
+ num_epochs = 1
+
+ loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
+ model.compile(optimizer="adam", loss=loss_fn, metrics=["accuracy"])
+
+ x_train, y_train = get_dataset()
+
+ model.fit(x_train, y_train, epochs=num_epochs)
diff --git a/tests/mlia/utils/logging.py b/tests/mlia/utils/logging.py
new file mode 100644
index 0000000..d223fb2
--- /dev/null
+++ b/tests/mlia/utils/logging.py
@@ -0,0 +1,13 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Utils for logging."""
+import logging
+
+
+def clear_loggers() -> None:
+ """Close the log handlers."""
+ for logger in logging.Logger.manager.loggerDict.values():
+ if not isinstance(logger, logging.PlaceHolder):
+ # Iterate over a copy: removeHandler() mutates logger.handlers.
+ for handler in list(logger.handlers):
+ handler.close()
+ logger.removeHandler(handler)